In [1]:
from functools import *
import random

In [2]:
class AnshelAnshelGoldfeld:
    """One participant in the Anshel-Anshel-Goldfeld key exchange over a braid group.

    Each participant publishes ``keyLength`` random braids (``publicKey``) and
    keeps secret a random index word ``func`` over those braids; ``privateKey``
    is that word evaluated on ``publicKey``. After exchanging conjugated public
    keys (``messages``), both sides derive the same commutator via ``commonKey``.

    Parameters
    ----------
    n : int
        Number of strands. ``BraidGroup(n)`` has ``n - 1`` Artin generators,
        i.e. valid Tietze indices are ``±1 .. ±(n-1)``. Must be at least 2.
    keyLength : int
        Number of braids in the public key. Must be at least 1.
    minWordLength, maxWordLength : int, optional
        Inclusive bounds on the random length of every generated word (each
        public braid and the secret word ``func``). Defaults keep 5..10.
    rng : random.Random, optional
        Source of randomness. Defaults to ``random.SystemRandom()`` because the
        generated material is a cryptographic secret; pass a seeded
        ``random.Random`` for reproducible experiments.

    Raises
    ------
    ValueError
        If ``n < 2``, ``keyLength < 1`` or the word-length bounds are invalid.
    """

    def __init__(self, n, keyLength, minWordLength=5, maxWordLength=10, rng=None):
        if n < 2:
            raise ValueError("n must be at least 2 (BraidGroup(n) has n - 1 generators)")
        if keyLength < 1:
            raise ValueError("keyLength must be at least 1")
        if not 1 <= minWordLength <= maxWordLength:
            raise ValueError("need 1 <= minWordLength <= maxWordLength")

        self.n = n
        self.Bn = BraidGroup(n)
        self.keyLength = keyLength
        self.minWordLength = minWordLength
        self.maxWordLength = maxWordLength
        # The global Mersenne Twister in `random` is not meant for secrets.
        self.rng = random.SystemRandom() if rng is None else rng

        # Secret word over the public-key braids: i > 0 selects publicKey[i-1],
        # i < 0 selects the inverse of publicKey[-i-1].
        self.func = self._randomWord(keyLength)

        self.generatePublicKey()
        self.generatePrivateKey()

    def _randomWord(self, alphabetSize):
        """Return a random word over the nonzero integers ``±1 .. ±alphabetSize``."""
        letters = [i for i in range(-alphabetSize, alphabetSize + 1) if i != 0]
        length = self.rng.randint(self.minWordLength, self.maxWordLength)
        return self.rng.choices(letters, k=length)

    def generatePublicKey(self):
        """Populate ``self.publicKey`` with ``keyLength`` random braids.

        Each braid is a random word in the ``n - 1`` generators of ``Bn`` and
        their inverses.
        """
        # Only indices 1 .. n-1 name generators of BraidGroup(n); ±n is invalid.
        self.publicKey = [
            self.Bn(self._randomWord(self.n - 1)) for _ in range(self.keyLength)
        ]

    def keyFunction(self, f, l):
        """Evaluate the index word ``f`` on the braid sequence ``l``.

        Parameters
        ----------
        f : iterable of nonzero int
            Index word: ``i > 0`` contributes ``l[i-1]``, ``i < 0`` contributes
            the inverse of ``l[-i-1]``.
        l : sequence of braids
            Values substituted for the word's letters.

        Returns
        -------
        The ordered product of the selected braids (the identity when ``f`` is
        empty).
        """
        key = self.Bn.one()
        for i in f:
            factor = l[abs(i) - 1]
            key = key * (factor if i > 0 else factor ** -1)
        return key

    def generatePrivateKey(self):
        """Set ``self.privateKey`` to the secret word ``func`` evaluated on ``publicKey``."""
        self.privateKey = self.keyFunction(self.func, self.publicKey)

    def messages(self, publicKeyReceiver):
        """Conjugate every braid of the other participant's public key.

        Returns ``[lnf(P * t * P^-1) for t in publicKeyReceiver]`` where ``P`` is
        ``self.privateKey``.
        """
        inverse = self.privateKey ** -1  # loop-invariant
        return [self.lnf(self.privateKey * t * inverse) for t in publicKeyReceiver]

    def commonKey(self, message, participant):
        """Derive the shared key from the other participant's ``messages`` output.

        Writing ``A`` for participant 0's private key and ``B`` for participant
        1's, both sides obtain ``A * B * A^-1 * B^-1`` in left normal form.

        Parameters
        ----------
        message : sequence of braids
            The other side's ``messages(self.publicKey)``.
        participant : int
            ``0`` computes ``A * (B A B^-1)^-1``; any other value computes
            ``(A B A^-1) * B^-1``.
        """
        conjugated = self.keyFunction(self.func, message)
        if participant == 0:
            key = self.privateKey * (conjugated ** -1)
        else:
            key = conjugated * (self.privateKey ** -1)
        return self.lnf(key)

    def lnf(self, braid):
        """Return ``braid`` rebuilt as the product of its left-normal-form factors."""
        return reduce(lambda acc, factor: acc * factor,
                      braid.left_normal_form(), self.Bn.one())
        

In [3]:
# Both participants must work in the same braid group (their braids are
# multiplied together), so they share N_STRANDS; KEY_LENGTH is the number of
# braids in each public key.
N_STRANDS = 6
KEY_LENGTH = 4
A = AnshelAnshelGoldfeld(N_STRANDS, KEY_LENGTH)
B = AnshelAnshelGoldfeld(N_STRANDS, KEY_LENGTH)

In [4]:
# Each side conjugates the other's public braids by its own private key.
messageA, messageB = A.messages(B.publicKey), B.messages(A.publicKey)

In [5]:
# A plays participant 0 and B participant 1; each uses the message it received.
commonKeyA, commonKeyB = (
    A.commonKey(messageB, participant=0),
    B.commonKey(messageA, participant=1),
)

In [8]:
print(commonKeyA)
print(commonKeyB)
keysAgree = commonKeyA == commonKeyB
print(keysAgree)
# Fail loudly on Restart & Run All if the two derived keys ever disagree.
assert keysAgree, "Alice and Bob derived different common keys"

(s0^-1*s1^-1*s2^-1*s3^-1*s4^-1*s0^-1*s1^-1*s2^-1*s3^-1*s0^-1*s1^-1*s2^-1*s0^-1*s1^-1*s0^-1)^15*s1*s2*s3*s2*s1*s4*s3*s1*s0*s2*s3*s4*s3*s0*s1*s3*s2*s1*s0*s4*s3*s2*(s1*s0)^2*s2*s1*s3*s2*s1*s0*s4*s0*s1*s2*s1*s0*s3*s4*s3^2*s2*s1*s0*s4*s3*s2*s1*s0^2*s1*s0*s2*s1*s3*s4*s1*s0*s2*s1*s3*s1*s2*s3*s2*s4*s2*(s1*s3)^2*s2*s1*s0*s4*s3*s2*s1^2*s2*s3*s2*s1*s4*s3^2*s2^2*s1*s0*s3*s2^2*s1*s3*s1*s0*s2*s1*s3*s2*s1*s4*s1*(s0*s2*s1*s0*s3*s4*s3*s2*s1*s0)^2*s0*s1*s0*s2*s1^2*s2*s1*s0*s3*s2*s1*s4*s3*s1*s0*s2*s3*s4*s3*s0*s3*s4*s3*s2^2*s1*s0*s3*s2*s1*s4*s3*s2*s1*s0^2*s1*s2*s1*s0*s3*s2*s4*s0*s2*s1*s0*s3*s4*s3*s2*s1*s0^2*s1*s2*s1*s0*s3*s2*s4*s3*s2*s1*s0^2*s1*s0*s2*s1*s3*s2*s1*s4*s3*s2*s1*s0*s2*s4*s3*s2*s1*s0^2*s3*s2*s4*(s2*s3*s4*s3*s2)^2*s1
(s0^-1*s1^-1*s2^-1*s3^-1*s4^-1*s0^-1*s1^-1*s2^-1*s3^-1*s0^-1*s1^-1*s2^-1*s0^-1*s1^-1*s0^-1)^15*s1*s2*s3*s2*s1*s4*s3*s1*s0*s2*s3*s4*s3*s0*s1*s3*s2*s1*s0*s4*s3*s2*(s1*s0)^2*s2*s1*s3*s2*s1*s0*s4*s0*s1*s2*s1*s0*s3*s4*s3^2*s2*s1*s0*s4*s3*s2*s1*s0^2*s1*s0*s2*s1*s3*s4*s1*s0*s2*s1*s3*s1*s2*

In [10]:
# Factors of the shared key's left normal form (first factor is a power of Delta^-1 here).
normalFormA = commonKeyA.left_normal_form()
print(normalFormA)

((s0^-1*s1^-1*s2^-1*s3^-1*s4^-1*s0^-1*s1^-1*s2^-1*s3^-1*s0^-1*s1^-1*s2^-1*s0^-1*s1^-1*s0^-1)^15, s1*s2*s3*s2*s1*s4*s3, s1*s0*s2*s3*s4*s3, s0*s1*s3*s2*s1*s0*s4*s3*s2*s1*s0, s1*s0*s2*s1*s3*s2*s1*s0*s4, s0*s1*s2*s1*s0*s3*s4*s3, s3*s2*s1*s0*s4*s3*s2*s1*s0, s0*s1*s0*s2*s1*s3*s4, s1*s0*s2*s1*s3, s1*s2*s3*s2*s4, s2*s1*s3, s1*s3*s2*s1*s0*s4*s3*s2*s1, s1*s2*s3*s2*s1*s4*s3, s3*s2, s2*s1*s0*s3*s2, s2*s1*s3, s1*s0*s2*s1*s3*s2*s1*s4, s1*s0*s2*s1*s0*s3*s4*s3*s2*s1*s0, s0*s2*s1*s0*s3*s4*s3*s2*s1*s0, s0*s1*s0*s2*s1, s1*s2*s1*s0*s3*s2*s1*s4*s3, s1*s0*s2*s3*s4*s3, s0*s3*s4*s3*s2, s2*s1*s0*s3*s2*s1*s4*s3*s2*s1*s0, s0*s1*s2*s1*s0*s3*s2*s4, s0*s2*s1*s0*s3*s4*s3*s2*s1*s0, s0*s1*s2*s1*s0*s3*s2*s4*s3*s2*s1*s0, s0*s1*s0*s2*s1*s3*s2*s1*s4*s3*s2*s1*s0, s2*s4*s3*s2*s1*s0, s0*s3*s2*s4, s2*s3*s4*s3*s2, s2*s3*s4*s3*s2*s1)
