In [49]:
from random import randint
from sympy import isprime

class XTR:
    def __init__(self, P, Q, generate = False):
        if generate:
            self.p, self.q = self.generate_params(P, Q)
        else:
            self.p = P
            self.q = Q
        self.GF_p2 = self.create_gf_p2()

    def generate_params(self, P, Q):
        while True:
            q = self.generate_prime(Q)
            p = self.generate_p(q, P)
            if p is not None:
                return p, q
                
    def generate_prime(self, num_bits):
        while True:
            candidate = randint(2**(num_bits-1), 2**num_bits - 1)
            if candidate % 12 == 7 and self.is_prime(candidate):
                return candidate

    def generate_p(self, q, num_bits):
        # Алгоритм 3.14
        r1, r2 = self.find_roots(q)
        while True:
            i = randint(1, 2)
            k = randint(1, 2**(num_bits - q.bit_length()))
            p = (r1 + k * q if i == 1 else r2 + k * q)
            if self.is_prime(p) and p % 3 == 2:
                s = (p * p - p + 1) // q
                if s > 0 and self.is_prime(s) and s.bit_length() >= q.bit_length():
                    return p
            if self.is_prime(p) and p % 3 == 2 and (p * p - p + 1) % (3 * q) == 0: # Алгоритм 3.13
                return p
        return None

    def find_roots(self, q):
        # Алгоритм 3.12
        for r in range(2, q):
            if (r * r - r + 1) % q == 0:
                return r, q - r
        raise ValueError("Корни не найдены")

    @staticmethod
    def is_prime(p):
        return isprime(p)

    @staticmethod
    def ExtendedGCD(a, b):
        m11, m12 = 1, 0
        m21, m22 = 0, 1
        while b:
            q = a // b
            a, b = b, a % b
            m11, m12 = m12, m11 - m12*q
            m21, m22 = m22, m21 - m22*q
        return (m11, m21)

    @staticmethod
    def inverse(a, m):
        u, v = XTR.ExtendedGCD(a,m)
        d = u*a + v*m
        if d > 1:
            return f"Обратного элемента к {a} не существует"
        else:
            if u < 0:
                u += m
            return u

    def create_gf_p2(self):
        p = self.p
        if p % 3 != 2 or not self.is_prime(p):
            raise ValueError("p должно быть простым и сравнимым с 2 по модулю 3")
        
        class GF_p2:
            p = self.p
            def __init__(self, x1, x2):
                self.x1 = x1 % p
                self.x2 = x2 % p
            
            def __repr__(self):
                return f"{self.x1}*α + {self.x2}*α^2"
            
            def __eq__(self, other):
                return self.x1 == other.x1 and self.x2 == other.x2
            
            def __add__(self, other):
                if isinstance(other, int):
                    other = GF_p2(other, 0)
                return GF_p2((self.x1 + other.x1) % p, (self.x2 + other.x2) % p)
            
            def __sub__(self, other):
                if isinstance(other, int):
                    other = GF_p2(other, 0)
                return GF_p2((self.x1 - other.x1) % p, (self.x2 - other.x2) % p)

            def __neg__(self):
                return GF_p2((-self.x1) % p, (-self.x2) % p)
            
            def __mul__(self, other):
                if isinstance(other, int):
                    other = GF_p2(other, 0)
                y1 = (self.x1 * other.x1) % p
                y2 = (self.x2 * other.x2) % p
                y12 = (self.x1 * other.x2) % p
                y21 = (self.x2 * other.x1) % p
                
                return GF_p2((y2 - y12 - y21) % p, (y1 - y12 - y21) % p)

            def __rmul__(self, other):
                if isinstance(other, int):
                    other = GF_p2(other, 0)
                return self.__mul__(other)
            
            def __pow__(self, power):
                if power == 0:
                    return GF_p2(1, 0)
                elif power == 1:
                    return self
                elif power == 2:
                    return self.square()
                elif power < 0:
                    norm = (self.x1**2 + self.x1 * self.x2 + self.x2**2) % p
                    if norm == 0:
                        raise ZeroDivisionError(f"Деление на 0 в  GF({p}^2)")
                    inv_norm = XTR.inverse(norm, p)
                    conj = self.conjugate()
                    return conj * inv_norm ** (-power)
                else:
                    result = GF_p2(1, 0)
                    base = self
                    while power > 0:
                        if power % 2 == 1:
                            result *= base
                        base = base.square()
                        power //= 2
                    return result

            
            def conjugate(self):
                return GF_p2(self.x2, self.x1)
            
            def square(self):
                return GF_p2((self.x2 * (self.x2 - 2*self.x1)) % p, 
                             (self.x1 * (self.x1 - 2*self.x2)) % p)

        return GF_p2

    @staticmethod
    def st(n):
        s = 0
        t = n
        while t % 2 == 0:
            s += 1
            t = t // 2
        return s, t

    def jacobi(self, a, n):
        if n < 0 or not n % 2:
            raise ValueError("n должно быть нечетным положительным целым числом") 
        j = 1
        if n == 1:
            return j
        if a < 0:
            a = -a
            if n % 4 == 3:
                j = -j
        while n > 1:
            if a == 0:
                return 0
            s, t = self.st(a)
            if (s % 2 == 1) & (n % 8 in [3, 5]):
                j = -j
            if 3 == n % 4 == t % 4:
                j = -j
            a = n % t
            n = t
        return j

    def shanks(self, a, p):
        if not self.jacobi(a, p) == 1:
            raise ValueError("a должно быть квадратичным вычетом")
        s, t = self.st(p - 1)
    
        n = randint(2, p - 2) 
        while self.jacobi(n, p) == 1:
            n = randint(2, p - 2)
        b = pow(n, t, p)
        r = pow(a, (t + 1) // 2, p)
        d = 0
        f = pow(a, t, p)
        b2 = b
        for i in range(1, s):
            b2 = b2 * b2 % p
            if not pow(f, 2 ** (s - 1 - i), p) == 1:
                d += 2 ** i
                f = f * b2 % p
        return r * pow(b, d // 2, p) % p
    
    @staticmethod
    def gcd(a, b):
        while b != 0:
            a, b = b, a % b
        return a
    
    def is_quadratic_residue(self, a):
        p = self.p
        return pow(a, (p - 1) // 2, p) == 1
    
    def sqrt_mod(self, a):
        p = self.p
        if p % 4 == 3:
            return pow(a, (p + 1) // 4, p)
        else:
            return self.shanks(a, p)
    
    def random_element_GF_p2(self):
        p = self.p
        return self.GF_p2(randint(0, p - 1), randint(0, p - 1))
    

    def is_irreducible(self, c):
        # Алгоритм 3.33
        p = self.p
        GF_p2 = self.GF_p2
        f0 = (GF_p2(-27, 0) + 9 * c**(p + 1) - 2 * (c ** 3)) * GF_p2(pow(27, -1, p), 0)
        f1 = (c ** p - c**2 * GF_p2(pow(3, -1, p), 0)) * GF_p2(pow(3, -1, p), 0)
        delta = (f0 ** 2 + 4 * (f1 ** 3)).x1

        if not self.is_quadratic_residue(delta):
            return False
            
        r1 = (-f0 + GF_p2(self.sqrt_mod(delta), 0)) * GF_p2(pow(2, -1, p), 0)
        y = self.compute_c_n(r1, (p + 1) // 3)[1]
    
        return y != y**p
    
    def find_irreducible_polynomial(self):
        p = self.p
        while True:
            c = self.random_element_GF_p2()
            if c != self.GF_p2(0, 0) and self.is_irreducible(c):
                return c

    def compute_c_n(self, c, n):
        # Алгоритм 2.35
        p = self.p
        if n < 0:
            c_neg_n = self.compute_c_n(c, -n)
            c_neg_n = (c_neg_n[2]**p, c_neg_n[1]**p, c_neg_n[0]**p)
            return c_neg_n
        elif n == 0:
            c_p = c**p 
            return (c_p, self.GF_p2(3, 0), c)
        elif n == 1:
            c_p = c**p
            c_2 = c.square() - 2 * c_p
            return (self.GF_p2(3, 0), c, c_2)
        elif n == 2:
            S_1 = self.compute_c_n(c, 1)
            c_p = c**p
            c_3 = c * S_1[2] - c_p * S_1[1] + S_1[0]
            return (S_1[1], S_1[2], c_3)

        elif n == 3:
            S_2 = self.compute_c_n(c, 2)
            c_p = c**p
            c_4 = c * S_2[2] - c_p * S_2[1] + S_2[0]
            return (S_2[1], S_2[2], c_4)
        else:
            if n % 2 == 1:
                m = n
            else:
                m = n - 1
            k = 1
            S_k = self.compute_c_n(c, 3)
            
            c_2k, c_2k_1, c_2k_2 = S_k
            
            m_bits = bin((m-1)//2)[2:]
            r = len(m_bits)
            if r != 1:
                for j in range(r):
                    if j == 0:
                        continue
                    if eval(m_bits[j]) == 0:
                        c_4k = c_2k**2 - 2 * (c_2k**p)
                        c_4k_1 = c_2k * c_2k_1 - (c_2k_1**p)*(c**p) + c_2k_2**p
                        c_4k_2 = c_2k_1**2 - 2 * (c_2k_1**p)
                        S_k = (c_4k, c_4k_1, c_4k_2)
                        k = 2*k
                    else:
                        c_4k_2 = c_2k_1**2 - 2 * (c_2k_1**p)
                        c_4k_3 = c_2k_2 * c_2k_1 - c * (c_2k_1**p) + c_2k**p    
                        c_4k_4 = c_2k_2**2 - 2 * (c_2k_2**p)
                        S_k = (c_4k_2, c_4k_3, c_4k_4)
                        k = 2 * k + 1
            if m == n:
                return S_k
            else:
                S_k = (S_k[1], S_k[2], c * S_k[2] - (c**p) * S_k[1] + S_k[0])
                return S_k

    def find_trace_g(self):
        p = self.p
        q = self.q
        n = (p**2 - p + 1) // q
        while True:
            c = self.find_irreducible_polynomial()
            d = self.compute_c_n(c, n)[1]  
            if d != self.GF_p2(3, 0):
                return d
    
    
    def find_trace_g_fast(self):
        p = self.p
        q = self.q
        if p % 9 == 8:
            return self.find_trace_g()
        while True:
            a = randint(0, p - 1)
            c = (pow(a, 6, p) - pow(a, 3, p) + 1) % p
            if c != 0:
                c = (-c * pow(3, p - 2, p)) % p
                c = self.GF_p2(c, 0)
                c_q = self.compute_c_n(c, q)[1]
                if c_q != self.GF_p2(3, 0):
                    return c_q
   

In [47]:
def xtr_dh_key_agreement(p, q, Tr_g):
    a = randint(2, q - 3)
    Sa_Tr_g = xtr.compute_c_n(Tr_g, a)
    Tr_g_a = Sa_Tr_g[1]

    b = randint(2, q - 3)
    Sb_Tr_g = xtr.compute_c_n(Tr_g, b)
    Tr_g_b = Sb_Tr_g[1]

    Sa_Tr_g_b = xtr.compute_c_n(Tr_g_b, a)
    Tr_g_ab_alice = Sa_Tr_g_b[1]

    Sb_Tr_g_a = xtr.compute_c_n(Tr_g_a, b)
    Tr_g_ab_bob = Sb_Tr_g_a[1]
    
    if Tr_g_ab_alice == Tr_g_ab_bob:
        return Tr_g_ab_alice
    else:
        return xtr_dh_key_agreement(p, q, Tr_g)

xtr = XTR(15,10, generate = True)
Tr_g = xtr.find_trace_g_fast()
print("Shared XTR public key data:")
print("p =", xtr.p)
print("q =", xtr.q)
print("Tr(g) =", Tr_g)

shared_key = xtr_dh_key_agreement(xtr.p, xtr.q, Tr_g)
print("Shared key:", shared_key)

Shared XTR public key data:
p = 27749
q = 859
Tr(g) = 26310*α + 7987*α^2
Shared key: 15730*α + 10792*α^2


In [46]:
class XTRElGamal:
    def __init__(self, xtr, Tr_g):
        self.xtr = xtr
        self.p = xtr.p
        self.q = xtr.q
        self.Tr_g = Tr_g

    def generate_keys(self):
        k = randint(2, self.q - 3) 
        public_key = self.xtr.compute_c_n(self.Tr_g, k)[1]
        private_key = k
        return public_key, private_key

    def encrypt(self, message, public_key):
        Tr_gk = public_key
        b = randint(2, self.q - 3) 

        Sb_Tr_g = self.xtr.compute_c_n(self.Tr_g, b)
        Sb_Tr_gk = self.xtr.compute_c_n(Tr_gk, b)
        
        K = Sb_Tr_gk[1]

        E = self.symmetric_encrypt(message, K)

        return (Sb_Tr_g[1], E)

    def decrypt(self, ciphertext, private_key):
        Tr_gb, E = ciphertext
        k = private_key

        Sk_Tr_gb = self.xtr.compute_c_n(Tr_gb, k)

        K = Sk_Tr_gb[1]

        message = self.symmetric_decrypt(E, K)
        return message

    def symmetric_encrypt(self, message, key):
        encrypted = bytes(c ^ key.x1 for c in message.encode())
        return encrypted

    def symmetric_decrypt(self, ciphertext, key):
        decrypted = bytes(c ^ key.x1 for c in ciphertext).decode()
        return decrypted


p = 113
q = 4219
xtr = XTR(p,q)
Tr_g = xtr.find_trace_g_fast()
print("p =", xtr.p)
print("q =", xtr.q)
print("Tr(g) =", Tr_g)

xtr_elgamal = XTRElGamal(xtr, Tr_g)

public_key, private_key = xtr_elgamal.generate_keys()

message = "Hello, KM!"

ciphertext = xtr_elgamal.encrypt(message, public_key)

decrypted_message = xtr_elgamal.decrypt(ciphertext, private_key)

print("Original Message:", message)
print("Encrypted Ciphertext:", ciphertext)
print("Decrypted Message:", decrypted_message)

p = 113
q = 4219
Tr(g) = 12*α + 6*α^2
Original Message: Hello, KM!
Encrypted Ciphertext: (81*α + 98*α^2, b'Zw~~}>2Y_3')
Decrypted Message: Hello, KM!
