In [3]:
#---------------------------------------------------------------------------
#                                     Imports
#---------------------------------------------------------------------------
from sage.all import *
import random
from bitarray import bitarray
from bitarray.util import ba2int
import pandas as pd
from tabulate import tabulate
from numpy import bitwise_xor

In [4]:
#---------------------------------------------------------------------------
#                                Auxiliary functions
#---------------------------------------------------------------------------

def allmx():
    """Return every 16-bit pattern as a bitarray.

    Returns a dict mapping each integer 0 <= value < 2**16 to the 16-bit,
    zero-padded bitarray of its binary representation (most significant bit
    at index 0). The attack functions use it as the candidate set for the
    two middle key bytes.
    """
    return {value: bitarray(format(value, '016b')) for value in range(2**16)}

# All 128 byte keys whose most significant bit is 0
def allbyte():
    """Return the 128 byte values whose most significant bit is 0.

    Returns a dict mapping each integer 0 <= value < 2**7 to an 8-bit
    bitarray whose first bit (index 0) is 0, followed by the 7-bit binary
    representation of ``value``.
    """
    # An 8-bit zero-padded rendering of a value < 128 always starts with '0'.
    return {value: bitarray(format(value, '08b')) for value in range(2**7)}

# 256 possible bytes
def allbyte_full():
    """Return all 256 byte values as 8-bit bitarrays.

    Returns a dict mapping each integer 0 <= value < 2**8 to the zero-padded
    8-bit bitarray of its binary representation.
    """
    return {value: bitarray(format(value, '08b')) for value in range(2**8)}

# Rotate bits left (circular shift)
def shiftleft(a,n):
    """Rotate the bits of ``a`` left (toward index 0) by ``n`` positions.

    This is a circular shift: bits leaving index 0 re-enter at the end, so
    the length of ``a`` is preserved. ``n`` is reduced modulo ``len(a)``.
    A rotation by a multiple of the length (including an empty ``a``)
    returns an unchanged copy.

    The rotation amount used to be hard-coded to 2 regardless of ``n``;
    calls with ``n == 2`` produce exactly the same result as before.
    """
    size = len(a)
    if size == 0 or n % size == 0:
        return a.copy()
    n = n % size
    return (a << n) | (a >> (size - n))

# Rotate bits right (circular shift)
def shiftright(a,n):
    """Rotate the bits of ``a`` right (toward the last index) by ``n`` positions.

    This is a circular shift and the inverse of ``shiftleft(a, n)``: bits
    leaving the end re-enter at index 0, and the length is preserved. ``n``
    is reduced modulo ``len(a)``. A rotation by a multiple of the length
    (including an empty ``a``) returns an unchanged copy.

    The rotation amount used to be hard-coded to 2 regardless of ``n``;
    calls with ``n == 2`` produce exactly the same result as before.
    """
    size = len(a)
    if size == 0 or n % size == 0:
        return a.copy()
    n = n % size
    return (a >> n) | (a << (size - n))

#S_0
def Sbox0(a,b): # a, b are 8-bit bitarrays (single bytes)
    """FEAL S-box S_0: rotate ((a + b) mod 256) left by 2 bits.

    a, b: byte-sized bitarrays, read as unsigned integers with ba2int.
    Returns a new 8-bit bitarray.
    """
    byte_sum = int((ba2int(a) + ba2int(b)) % 256)
    return shiftleft(bitarray(format(byte_sum, '08b')), 2)

#S_1    
def Sbox1(a,b): # a, b are 8-bit bitarrays (single bytes)
    """FEAL S-box S_1: rotate ((a + b + 1) mod 256) left by 2 bits.

    a, b: byte-sized bitarrays, read as unsigned integers with ba2int.
    Returns a new 8-bit bitarray.
    """
    byte_sum = int((ba2int(a) + ba2int(b) + 1) % 256)
    return shiftleft(bitarray(format(byte_sum, '08b')), 2)

#S_1 inverted
def invS1(out,a):
    """Invert S_1 given its output and one known input byte.

    Undoes the 2-bit left rotation with shiftright, then solves
    out = ROT2(a + b + 1 mod 256) for b. Returns b as an 8-bit bitarray.
    """
    unrotated = ba2int(shiftright(out, 2))
    other_input = int((unrotated - ba2int(a) - 1) % 256)
    return bitarray(format(other_input, '08b'))

#S_0 inverted
def invS0(out,a):
    """Invert S_0 given its output and one known input byte.

    Undoes the 2-bit left rotation with shiftright, then solves
    out = ROT2(a + b mod 256) for b. Returns b as an 8-bit bitarray.
    """
    unrotated = ba2int(shiftright(out, 2))
    other_input = int((unrotated - ba2int(a)) % 256)
    return bitarray(format(other_input, '08b'))

#Function f_k
def fk(a,b):   # a,b binary arrays with 32-bits                                                                                         
    a0, a1, a2, a3 = a[:8], a[8:16], a[16:24], a[24:32];
    b0, b1, b2, b3 = b[:8], b[8:16], b[16:24], b[24:32];
    f1 = Sbox1((a1^^a0), ((a2^^a3)^^b0));
    f0 = Sbox0(a0,(f1^^b2));
    f2 = Sbox0((f1^^b1),(a2^^a3));
    f3 = Sbox1(a3,(f2^^b3));
    f = f0+f1+f2+f3;
    return f

#Function f
def f(a,b): # a,b binary arrays, a with 32-bits and b with 16 bit                                                                       
    a0, a1, a2, a3 = a[:8], a[8:16], a[16:24], a[24:32];
    b0, b1 = b[:8], b[8:];
    f1 = Sbox1(((a1^^b0)^^a0), ((a2^^b1)^^a3));
    f0 = Sbox0(a0,f1);
    f2 = Sbox0(f1,((a2^^b1)^^a3));
    f3 = Sbox1(a3,f2);
    f = f0+f1+f2+f3;
    return f

#Function f (middle)
def fmiddle(a,b):  # a+b is the 16-bit folded middle input, already XORed with the candidate key bytes
    """Middle two output bytes of the keyless round function.

    a, b: 8-bit bitarrays, i.e. the folded middle bytes (x1^x0, x2^x3) of a
    round input, already XORed with a candidate middle key.
    Returns the 16-bit bitarray Sbox1(a, b) + Sbox0(Sbox1(a, b), b).
    """
    upper = Sbox1(a, b)
    lower = Sbox0(upper, b)
    return upper + lower

#Function f (keyless)                                                                                                                    
def fkeyless(a):
    a0, a1, a2, a3 = a[:8], a[8:16], a[16:24], a[24:32];
    f1 = Sbox1(a1^^a0, a2^^a3);
    f0 = Sbox0(a0,f1);
    f2 = Sbox0(f1,a2^^a3);
    f3 = Sbox1(a3,f2);
    f = f0+f1+f2+f3;
    return f

#Just a testing function
def testing(tries):
    """Sanity-check encryption/decryption round trips under one random key.

    Draws a random 64-bit key, derives the subkeys with keygen, then
    encrypts and decrypts ``tries`` random 64-bit messages. Prints the
    success percentage and returns the number of exact recoveries.
    NOTE(review): ``dec`` is not defined in this part of the notebook —
    presumably defined elsewhere; confirm before running.
    """
    def random_block():
        return bitarray(bin(random.randint(0,2**64-1))[2:].zfill(64))

    subkeys = keygen(random_block())
    successes = 0
    for _ in range(tries):
        plaintext = random_block()
        if dec(enc(plaintext, subkeys), subkeys) == plaintext:
            successes = successes + 1
    print ('correct percentage:', float(successes/tries)*100, '%')
    return successes

#Merge mx with first and last byte for the 8 to 7 and 7 to 6 reductions
def get_actualkey(fbyte, mxbytes, lbyte):
    fbyteOp = {0: fbyte, 1: fbyte^^bitarray('10000000')}
    lbyteOp = {0: lbyte, 1: lbyte^^bitarray('10000000')};
    mxbytesOp = {0: mxbytes, 1: mxbytes^^bitarray('1000000010000000')};
    possiblekeys = {};
    for p1 in fbyteOp.values():
        ak7_0 = p1;
        for p2 in lbyteOp.values():
            ak7_3 = p2;
            for p3 in mxbytesOp.values():
                ak7_1 = p3[:8]^^ak7_0;
                ak7_2 = p3[8:]^^ak7_3;
                pkey = ak7_0 + ak7_1 + ak7_2 + ak7_3;
                possiblekeys.update({len(possiblekeys):pkey});
    return possiblekeys

#Merge mx with first and last byte for every reduction apart from the 8 to 7 and 7 to 6
def get_actualkey_5orless(fbyte, mxbytes, lbyte):
    possiblekeys = {};
    for p in mxbytes.values():         
        ak7_1 = p[:8]^^fbyte;
        ak7_2 = p[8:]^^lbyte;
        pkey = fbyte + ak7_1 + ak7_2 + lbyte;
        possiblekeys.update({len(possiblekeys):pkey});
    return possiblekeys


In [5]:
#---------------------------------------------------------------------------
#                                 Main functions
#---------------------------------------------------------------------------

#KeyGen
def keygen(K):  #K is the 64-bit secret key                                                                                      
    keys = {};      
    l = K[:32];
    r = K[32:];
    k01 = fk(l,r);
    keys.update({0: k01[:16], 1: k01[16:]});
    k23 = fk(r,k01^^l);
    keys.update({2: k23[:16], 3: k23[16:]});
    k45 = fk(k01,k23^^r);
    keys.update({4: k45[:16], 5: k45[16:]});
    for i in range(5):
        k = fk(keys[2*i+2] + keys[2*i+3], (keys[2*i+4]+keys[2*i+5])^^(keys[2*i]+keys[2*i+1]));
        keys.update({len(keys): k[:16], len(keys)+1 : k[16:]});
    return keys

#Encrypt
def enc(m,keys):    # m is 64-bits and keys has 16 subkeys with 16-bit each
    l0 = m[:32];
    r0 = m[32:];
    l0i = l0^^(keys[8]+keys[9]);
    r0i = (r0^^(keys[10]+keys[11]))^^l0i;
    lp = {0: l0i};
    rp = {0: r0i};
    for i in range(1,9):
        rnew = lp[i-1] ^^ (f(rp[i-1],keys[i-1]));
        lnew = rp[i-1];
        if i == 8:
            lp.update({len(lp): rnew});
        else:
            lp.update({len(lp): lnew});
            rp.update({len(rp): rnew});
    r8f1 = rp[len(rp)-1] ^^ lp[len(lp)-1];
    c = (lp[len(lp)-1] + r8f1) ^^ (keys[12]+keys[13]+keys[14]+keys[15]);
    return c

#Partial decrypt 
def partialdec(c,keys):
    cnew = c ^^ (keys[12]+keys[13]+keys[14]+keys[15]);
    cl = cnew[:32];
    cr = cnew[32:] ^^ cl; #h
    H = f(cr,keys[7]);
    g = cl^^H;
    G = f(g,keys[6]);
    fnew = cr^^G;
    F = f(fnew,keys[5]);
    e = g^^F;
    E = f(e,keys[4]);
    res = fnew + e;
    return res

# Partial decryption - dynamic
# We need to partially decrypt by 3 rounds (5-round characteristic), 5 rounds (3-round characteristic), 6 rounds (2-round characteristic), and 7 rounds (1-round characteristic)
def partialdec_dinamic(c,keys,round):  # rounds: {3,5,6,7}
    cnew = c ^^ (keys[12]+keys[13]+keys[14]+keys[15]);
    cl = cnew[:32];
    cr = cnew[32:] ^^ cl; #h
    H = f(cr,keys[7]);
    g = cl^^H;
    G = f(g,keys[6]);
    fnew = cr^^G;
    F = f(fnew,keys[5]);
    e = g^^F;
    E = f(e,keys[4]);
    if round == 3:
        res = fnew + e;
    else:
        d = fnew^^E;
        D = f(d,keys[3]);
        c = e^^D;
        C = f(c,keys[2]);
        if round == 5:
            res = d + c;
        else:
            b = d^^C;
            B = f(b,keys[1]);
            a = c^^B;
            if round == 6:
                res = c + b;
            else:
                res = b + a;
    return res

#Generation of pairs for feal8
def genpairs(n,diff,keys):  # n is the number of pairs wanted, diff is the difference between plaintexts (a bitarray), and keys is the output from the KeyGen function
    pairs = {};
    for i in range(n):
        m = bitarray(bin(random.randint(0,2**64-1))[2:].zfill(64));
        c = enc(m,keys);
        m1 = m ^^ diff;
        c1 = enc(m1,keys);
        pairs.update({i: {0: c, 1: c1}});
    return pairs

# GENERATE RIGHT PAIRS (check for right pairs in all pairs - with key knowledge)
def genrightpairs(pairs,keys):
    car = bitarray('1010001000000000100000000000000010000000100000000000000000000000');
    rpairs = {};
    for p in pairs.values():
        a1 = partialdec(p[0], keys);
        a2 = partialdec(p[1], keys);
        a = a1^^a2;
        if a == car :
            rpairs.update({len(rpairs): p});
    leng = len(rpairs);
    print("Nº of true right pairs:", leng)
    return rpairs

# GENERATE RIGHT PAIRS - DYNAMIC (check for right pairs among all pairs - with key knowledge)
def genrightpairs_dinamic(pairs,keys,round):
    # car = [5-round, 3-round, 1-round]
    car_dict = {5: bitarray('1010001000000000100000000000000010000000100000000000000000000000'), 3: bitarray('1010000000000000100000000000000000000000000000000000000000000000'), 1: bitarray('1010000000000000100000000000000010000000100000000000000000000000')};
    car = car_dict[round];
    rpairs = {};
    for p in pairs.values():
        a1 = partialdec_dinamic(p[0], keys,8-round);
        a2 = partialdec_dinamic(p[1], keys,8-round);
        a = a1^^a2;
        if a == car :
            rpairs.update({len(rpairs):p});
    leng = len(rpairs);
    print("Nº of true right pairs for the",round,"round caratetistic:", leng)
    return rpairs

In [None]:
#---------------------------------------------------------------------------
# Reduction from 8 to 7 rounds
#---------------------------------------------------------------------------
# Attack on middle bytes
def diffatack8to7_1(pairs):
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d;
    size = len(pairs);


    #--------------------------------------------------------------------------------------
    # Filtering proposed 1                                   
    # i = 0;
    # while int(i) < size:   # Filtering: Over f^^h and F^^H
    #     while i not in pairs:
    #         i = i+1;
    #     dif = (pairs[i])[0]^^(pairs[i])[1];
    #     l,r = dif[:32], dif[32:];
    #     h, FH = l^^r, l^^e
    #     G = fh = f^^h;
    #     if ((fh[7] == FH[5] ^^ FH[15]) & (fh[31] == FH[29]^^FH[23]) & (fh[23] == FH[21] ^^ FH[15] ^^ fh[31]) & (fh[15] == FH[13] ^^ fh[31] ^^ fh[23] ^^ fh[7])):
    #         i = i+1;
    #     else:
    #         pairs.pop(i);
    #         i = i + 1;
    print("Nº pairs after filtering:", len(pairs))
    #--------------------------------------------------------------------------------------
    #--------------------------------------------------------------------------------------
    # Filtering proposed 2
    # i = 0;
    # while i < len(pairs):   # Filtering over G' known bits (from the charateristic) 
    #     while i not in pairs:
    #         i = i+1;
    #     dif = (pairs[i])[0]^^(pairs[i])[1];
    #     l,r = dif[:32], dif[32:];
    #     Gfil = f^^l^^r;
    #     g00 = Gfil[5]^^Gfil[15];
    #     g30 = Gfil[29]^^Gfil[23];
    #     g20 = (Gfil[21]^^Gfil[15])^^g30;
    #     g10 = Gfil[13]^^g00^^g20^^g30;
    #     if (g20 != g10):
    #         pairs.pop(i);
    #     else:
    #         i = i+1;
    # print("Nº pairs after filtering:", len(pairs))
    #--------------------------------------------------------------------------------------





    allmxl = allmx();
    count = {i: 0 for i in range(len(allmxl))};
    for p in pairs.values():
        c1, c2 = p[0], p[1];
        l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
        hc1 = (l1^^r1)[8:24]^^((l1^^r1)[:8]+(l1^^r1)[24:]);
        hc2 = (l2^^r2)[8:24]^^((l2^^r2)[:8]+(l2^^r2)[24:]);
        l = l1^^l2;
        for i in range(len(allmxl)):     # Retrieving middle bytes (attack on mx(AK7))
            Hmid1, Hmid2 = fmiddle((hc1 ^^ allmxl[i])[:8],(hc1 ^^ allmxl[i])[8:16]), fmiddle((hc2 ^^ allmxl[i])[:8],(hc2 ^^ allmxl[i])[8:16]);
            gmid = l[8:24]^^Hmid1^^Hmid2;
            Fmid = e[8:24]^^gmid;
            if Fmid == bitarray('1000100000100000'):
                count[i] = count[i] + 1;
    final = max(count.values());                 
    print (final)
    print("Max count: ",final)
    final1 = {};
    for i in range(len(count)):
        if count[i] == final:
            a = allmxl[i];
            final1.update({i: a});
    a111 = len(final1);
    print(" ")
    print("Nº of possible keys: ",a111)
    return final1

# Attack on first byte
def diffattack8to7_2(pairs, mkey):
    # mkey is one of the possibilities obtained from the attack on mx(AK7)
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d;                                           
    allby = allbyte();
    count1 = {i: 0 for i in range(len(allby))};
    i = 0;
    for i in range(len(allby)):     # Retrieving first byte (attack on AK7[:8])
        for p in pairs.values():
            c1, c2 = p[0], p[1];
            l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
            hc1 = (l1^^r1)[8:24]^^((l1^^r1)[:8]+(l1^^r1)[24:]);
            hc2 = (l2^^r2)[8:24]^^((l2^^r2)[:8]+(l2^^r2)[24:]);
            l = l1^^l2;
            hmid1, hmid2 = hc1 ^^ mkey,  hc2 ^^ mkey;
            Hmid1 = fmiddle(hmid1[:8],hmid1[8:16]);
            Hmid2 = fmiddle(hmid2[:8],hmid2[8:16]);
            Hi1 = Sbox0(((l1^^r1)[:8] ^^ allby[i]),Hmid1[:8]);
            Hi2 = Sbox0(((l2^^r2)[:8] ^^ allby[i]),Hmid2[:8]);
            Hi, Hmid = Hi1^^Hi2, Hmid1^^Hmid2;
            Fi, gmid = e[:8]^^Hi^^l[:8], l[8:24]^^Hmid;
            F1 = (e[8:24]^^gmid)[:8];
            
            ######

            # Cline = shiftright(Fi,2)^^F1^^f[:8];
            # #j = 0;
            # j = len(Cline)-1;
            # #while j <= len(Cline):
            # while j >= 0:
            #     print(j)
            #     #if j != len(Cline):
            #     if j > 0:
            #         print(" ")
            #         print(Cline[j-1])
            #         print(" ")
            #         print(Cline[j])
            #         print(Fi[j])
            #         print((f[:8])[j])
            #         print(" ")
            #         print(" ")
            #         if (Cline[j] == Fi[j] == (f[:8])[j] == 0):
            #             #if Cline[j+1] != 0:
            #             if Cline[j-1] == 0:
            #                 #j = j+1;
            #                 j = j-1;
            #             else:
            #                 j = -1;
            #                 break
            #         else:
            #             if (Cline[j] == Fi[j] == (f[:8])[j] == 1):
            #                 #if Cline[j+1] != 0:
            #                 if Cline[j-1] == 1:
            #                     #j = j+1;
            #                     j = j-1;
            #                 else:
            #                     j = -1;
            #                     break
            #             else:
            #                 j = -1;
            #                 break
            #     else:
            #         count1[i] = count1[i] + 1;
            #         #j = j+1;
            #         j = j-1;

            ######


            res = shiftleft(f[:8]^^F1,2);

            if res == Fi:
                count1[i] = count1[i] + 1;
            #else:
                # print(" ")
                # print(i)
                # print(res^^Fi)
                # print(" ")
    print(sum(count1))
    final = max(count1.values());
    final1 = {};
    for i in range(len(count1)):
        if count1[i] == final:
            a = allby[i];
            final1.update({i:a});
    print("Max count: ",final);
    leng = len(final1);
    print("Nº possible keys:", leng);
    return final1

# Attack on last byte
def diffattack8to7_3(pairs, mkey):
    # mkey is one of the possibilities obtained from the attack on mx(AK7)
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d;                                           
    allby = allbyte();
    count2 = {i: 0 for i in range(len(allby))};
    i = 0;
    for i in range(len(allby)):
        for p in pairs.values():  # Retrieving last byte (attack on AK7[24:])
            c1, c2 = p[0], p[1];
            l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
            hc1 = (l1^^r1)[8:24]^^((l1^^r1)[:8]+(l1^^r1)[24:]);
            hc2 = (l2^^r2)[8:24]^^((l2^^r2)[:8]+(l2^^r2)[24:]);
            l = l1^^l2;
            hmid1, hmid2 = hc1 ^^ mkey,  hc2 ^^ mkey;
            Hmid1 = fmiddle(hmid1[:8],hmid1[8:16]);
            Hmid2 = fmiddle(hmid2[:8],hmid2[8:16]);
            Hf1 = Sbox1(((l1^^r1)[24:] ^^ allby[i]),Hmid1[8:]);
            Hf2 = Sbox1(((l2^^r2)[24:] ^^ allby[i]),Hmid2[8:]);
            Hf, Hmid = Hf1^^Hf2, Hmid1^^Hmid2;
            Ff, gmid = e[24:]^^Hf^^l[24:], l[8:24]^^Hmid;
            F1, F2 = (e[8:24]^^gmid)[:8], (e[8:24]^^gmid)[8:];
            res = shiftleft(f[:8]^^F1,2);
            resf = shiftleft(f[24:]^^F2,2);
            if resf == Ff:
                count2[i] = count2[i] + 1;
    print(sum(count2))
    final = max(count2.values());
    final1 = {};
    for i in range(len(count2)):
        if count2[i] == final:
            a = allby[i];
            final1.update({i: a});
    leng = len(final1);
    print("Max count: ",final);
    print("Nº of possible keys:", leng);
    return final1

In [None]:
#---------------------------------------------------------------------------
# Reduction from 7 to 6 rounds
#---------------------------------------------------------------------------
# Attack on middle bytes
def diffatack7to6_1(pairs,AK7):
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d;
    i = 0;
    size = len(pairs);


    #--------------------------------------------------------------------
    # Filtering
    #--------------------------------------------------------------------
    while int(i) < size:       # Filtering over f and F and g and G
        while int(i) not in pairs:
            i = i+1;
        c1,c2 = (pairs[i])[0], (pairs[i])[1];
        l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
        l = l1^^l2;
        h = l1^^l2^^r1^^r2;
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
        H = H1^^H2;
        g, G = l^^H, h^^f;
        F = e^^g;
        if (f[7] == F[5] ^^ F[15]) & (g[7] == G[5] ^^ G[15]) & (f[31] == F[29]^^F[23]) & (g[31] == G[29]^^G[23]) & (f[23] == F[21] ^^ F[15] ^^ f[31]) & (g[23] == G[21] ^^ G[15] ^^ g[31]) & (f[15] == F[13] ^^ f[31] ^^ f[23] ^^ f[7]) & (g[15] == G[13] ^^ g[31] ^^ g[23] ^^ g[7]):
            i = i+1;
        else:
            pairs.pop(i);
    leng = len(pairs);
    print("Nº pairs after filtering:", leng)
    #--------------------------------------------------------------------


    allmxl = allmx();
    countmid = {i: 0 for i in range(len(allmxl))};
    for p in pairs.values():     # Retrieving middle bytes (attack on mx(AK6))
        c1, c2 = p[0], p[1];
        l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
        l, h = l1^^l2, l1^^l2^^r1^^r2;
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
        H = H1^^H2;
        g, G = l^^H, h^^f;
        F = e^^g;
        g1, g2 = l1^^H1, l2^^H2;
        g1mid, g2mid = g1[8:24]^^(g1[:8]+g1[24:]), g2[8:24]^^(g2[:8]+g2[24:]);
        for i in range(len(allmxl)):
            G1mid, G2mid = fmiddle((g1mid^^allmxl[i])[:8],(g1mid^^allmxl[i])[8:]), fmiddle((g2mid^^allmxl[i])[:8],(g2mid^^allmxl[i])[8:]);
            if G1mid^^G2mid == G[8:24]:
                countmid[i] = countmid[i] + 1;
    final = max(countmid.values());
    print("Max count: ",final)
    final1 = {};
    for i in range(len(countmid)):
        if countmid[i] == final:
            a = allmxl[i];
            final1.update({i: a});
    a111 = len(final1);
    print(" ")
    print("Nº of possible keys: ",a111)
    return final1

# Attack on first byte
def diffattack7to6_2(pairs, mkey, AK7):
    # mkey is one of the possibilities obtained from the attack on mx(AK6)
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d;                                          
    allby = allbyte();
    count1 = {i: 0 for i in range(len(allby))};
    i = 0;
    for i in range(len(allby)):
        for p in pairs.values():  # Retrieving first byte (attack on AK6[:8])
            c1, c2 = p[0], p[1];
            l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
            l, h = l1^^l2, l1^^l2^^r1^^r2;
            H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
            H = H1^^H2;
            g, G = l^^H, h^^f;
            F = e^^g;
            g1, g2 = l1^^H1, l2^^H2;
            g1mid, g2mid = g1[8:24]^^(g1[:8]+g1[24:]), g2[8:24]^^(g2[:8]+g2[24:]);
            G1mid, G2mid = fmiddle((g1mid^^mkey)[:8],(g1mid^^mkey)[8:]), fmiddle((g2mid^^mkey)[:8],(g2mid^^mkey)[8:]);
            Gmidobtained = G1mid^^G2mid;
            gi1, gi2 = g1[:8] ^^ allby[i], g2[:8] ^^ allby[i];
            Gi1, Gi2 = Sbox0(gi1,G1mid[:8]), Sbox0(gi2,G2mid[:8]);
            gi, Giobtained = gi1^^gi2, Gi1^^Gi2;
            if G[:8] == Giobtained:
                count1[i] = count1[i] + 1;
    final = max(count1.values());
    final1 = {};
    for i in range(len(count1)):
        if count1[i] == final:
            a = allby[i];
            final1.update({i: a});
    print("Max count: ",final);
    leng = len(final1);
    print("Nº possible keys:", leng)
    return final1

# Attack on last byte
def diffattack7to6_3(pairs, mkey, AK7):
    # mkey is one of the possibilities obtained from the attack on mx(AK6)
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d;                                          
    allby = allbyte();
    count1 = {i: 0 for i in range(len(allby))};
    i = 0;
    for i in range(len(allby)):
        for p in pairs.values():  # Retrieving first byte (attack on AK6[:8])
            c1, c2 = p[0], p[1];
            l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
            l, h = l1^^l2, l1^^l2^^r1^^r2;
            H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
            H = H1^^H2;
            g, G = l^^H, h^^f;
            F = e^^g;
            g1, g2 = l1^^H1, l2^^H2;
            g1mid, g2mid = g1[8:24]^^(g1[:8]+g1[24:]), g2[8:24]^^(g2[:8]+g2[24:]);
            G1mid, G2mid = fmiddle((g1mid^^mkey)[:8],(g1mid^^mkey)[8:]), fmiddle((g2mid^^mkey)[:8],(g2mid^^mkey)[8:]);
            Gmidobtained = G1mid^^G2mid;
            gf1, gf2 = g1[24:] ^^ allby[i], g2[24:] ^^ allby[i];
            Gf1, Gf2 = Sbox1(gf1,G1mid[8:]), Sbox1(gf2,G2mid[8:]);
            gf, Gfobtained  = gf1^^gf2, Gf1^^Gf2;
            if G[24:] == Gfobtained:
                count1[i] = count1[i] + 1;
    final = max(count1.values());
    final1 = {};
    for i in range(len(count1)):
        if count1[i] == final:
            a = allby[i];
            final1.update({i: a});
    print("Max count: ",final);
    leng = len(final1);
    print("Nº possible keys:", leng)
    return final1

In [None]:
#---------------------------------------------------------------------------
# Reduction from 6 to 5 rounds
#---------------------------------------------------------------------------
# In this case we know AK5_0 = AK7_0 and AK5_3 = AK7_3
# This means that we only have to attack the middle bytes (mx(AK5))
# Attack on middle bytes
def diffatack6to5(pairs,AK7,AK6):
    # mkey is one of the possibilities obtained from the attack on mx(AK6)
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d;                                      
    i = 0;
    size = len(pairs);

    #--------------------------------------------------------------------
    # Filtering
    #--------------------------------------------------------------------
    while int(i) < size:   #Filtering
        while int(i) not in pairs:
            i = i+1;
        c1, c2 = (pairs[i])[0], (pairs[i])[1];
        l1, r1, l2, r2 = c1[:32], c1[32:], c2[:32], c2[32:];
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
        h, H, l = l1^^r1^^l2^^r2, H1^^H2, l1^^l2;
        g1, g2 = l1^^H1, l2^^H2;
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6);
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2;
        if f1^^f2 != f:
            pairs.pop(i)
            i = i+1;
        else:
            i = i+1;
    leng = len(pairs);
    print("Nº pairs after filtering:", leng)
    #--------------------------------------------------------------------

    allmxl = allmx();
    countmid = {i: 0 for i in range(len(allmxl))};
    for p in pairs.values():     # Retrieving middle bytes (attack on mx(AK5))
        c1, c2 = p[0], p[1];
        l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
        l, h = l1^^l2, l1^^l2^^r1^^r2;
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
        H = H1^^H2;
        g, G = l^^H, h^^f;
        F = e^^g;
        g1, g2 = l1^^H1, l2^^H2;
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6);
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2;
        f1mid, f2mid = f1[8:24]^^(f1[:8]+f1[24:]), f2[8:24]^^(f2[:8]+f2[24:]);
        for i in range(len(allmxl)):
            F1mid, F2mid = fmiddle((f1mid^^allmxl[i])[:8],(f1mid^^allmxl[i])[8:]), fmiddle((f2mid^^allmxl[i])[:8],(f2mid^^allmxl[i])[8:]);
            if (f1^^f2 == f) & (F1mid^^F2mid == F[8:24]):
                countmid[i] = countmid[i] + 1;
    final = max(countmid.values());
    print("Max count: ",final)
    final1 = {};
    for i in range(len(countmid)):
        if countmid[i] == final:
            a = allmxl[i];
            final1.update({i: a});
    a111 = len(final1);
    print(" ")
    print("Nº of possible keys: ",a111)
    return final1

# Filtering for the obtained AK5
def selection_AK5(pairs,possiblekeys_ak5,AK7,AK6):
    # Known from car
    d,E,e = bitarray('10100000000000001000000000000000'), bitarray('00000010000000000000000000000000'), bitarray('10000000100000000000000000000000');
    f = E^^d; 
    counting_ak5_list = {};
    for i in range(len(possiblekeys_ak5)):
        counting_ak5 = 0;
        for p in pairs.values():
            c1, c2 = p[0], p[1];
            l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
            l, h = l1^^l2, l1^^l2^^r1^^r2;
            H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
            H = H1^^H2;
            g, G = l^^H, h^^f;
            F = e^^g;
            g1, g2 = l1^^H1, l2^^H2;
            G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6);
            f1, f2 = l1^^r1^^G1, l2^^r2^^G2;
            F1, F2 = fkeyless(f1^^possiblekeys_ak5[i]), fkeyless(f2^^possiblekeys_ak5[i]);
            if F1^^F2 == F:
                counting_ak5 = counting_ak5 + 1;
        counting_ak5_list.update({i: counting_ak5});
        res = possiblekeys_ak5[max(counting_ak5_list, key = counting_ak5_list.get)];
    return res

In [None]:
#---------------------------------------------------------------------------
# Reduction from 5 to 4 rounds
#---------------------------------------------------------------------------
# In this case we know AK4_0 = AK6_0 and AK4_3 = AK6_3
# This means that we only have to attack the middle bytes (mx(AK4))
# Attack on middle bytes
def diffatack5to4(pairs,AK7,AK6,AK5):
    print(len(pairs));
    d,e = bitarray('10100000000000001000000000000000'), bitarray('10000000100000000000000000000000');
    allmxl = allmx();
    countmid = {i: 0 for i in range(len(allmxl))};
    for p in pairs.values():     # Retrieving middle bytes (attack on mx(AK4))
        c1, c2 = p[0], p[1];
        l1,r1,l2,r2 = c1[:32], c1[32:],c2[:32],c2[32:];
        l, h = l1^^l2, l1^^l2^^r1^^r2;
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
        H = H1^^H2;
        g = l^^H;
        F = e^^g;
        g1, g2 = l1^^H1, l2^^H2;
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6);
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2;
        F1, F2 = fkeyless(f1^^AK5), fkeyless(f2^^AK5);
        e1, e2 = g1^^F1, g2^^F2;
        e1mid, e2mid = e1[8:24]^^(e1[:8]+e1[24:]), e2[8:24]^^(e2[:8]+e2[24:]);
        E = f1^^f2^^d;
        if (e1^^e2 != e):
            for i in range(len(allmxl)):
                E1mid, E2mid = fmiddle((e1mid^^allmxl[i])[:8],(e1mid^^allmxl[i])[8:]), fmiddle((e2mid^^allmxl[i])[:8],(e2mid^^allmxl[i])[8:]);
                if (E1mid^^E2mid == E[8:24]) & (d[7] == e[5] ^^ e[15]) & (d[31] == e[29]^^e[23]) & (d[23] == e[21] ^^ e[15] ^^ d[31]) & (d[15] == e[13] ^^ d[31] ^^ d[23] ^^ d[7]):
                    countmid[i] = countmid[i] + 1;
        else:
            continue
    final = max(countmid.values());
    print("Max count: ",final)
    final1 = {};
    for i in range(len(countmid)):
        if countmid[i] == final:
            a = allmxl[i];
            final1.update({i: a});
    a111 = len(final1);
    print(" ")
    print("Nº of possible keys: ",a111)
    return final1
# We can't remove the ambiguity in the most significant bit of each byte.
# This leaves two possible keys, which means we now have a total of 128 possible combinations of AK7, AK6, AK5, AK4.

In [None]:
#---------------------------------------------------------------------------
# Reduction from 4 to 3 rounds
#---------------------------------------------------------------------------
# In this case we know AK3_0 = AK7_0 and AK3_3 = AK7_3
# This means that we only have to attack the middle bytes (mx(AK3))
# Attack on middle bytes
def diffatack4to3(pairs,AK7,AK6,AK5,AK4):
    d = bitarray('10100000000000001000000000000000');   #Known from 3-round car
    print(len(pairs));
    allmxl = allmx();
    countmid = {i: 0 for i in range(len(allmxl))};
    for p in pairs.values():     # Retrieving middle bytes (attack on mx(AK3))
        c1, c2 = p[0], p[1];
        l1,r1,l2,r2 =  c1[:32], c1[32:], c2[:32],c2[32:];
        l, h = l1^^l2, l1^^l2^^r1^^r2;
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7);
        H = H1^^H2;
        g = l^^H;
        g1, g2 = l1^^H1, l2^^H2;
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6);
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2;
        F1, F2 = fkeyless(f1^^AK5), fkeyless(f2^^AK5);
        e1, e2 = g1^^F1, g2^^F2;
        E1, E2 = fkeyless(e1^^AK4), fkeyless(e2^^AK4);
        d1, d2 = E1^^f1, E2^^f2;
        d1mid, d2mid = d1[8:24]^^(d1[:8]+d1[24:]), d2[8:24]^^(d2[:8]+d2[24:]);
        if (d1^^d2 == d):
            for i in range(len(allmxl)):
                D1mid, D2mid = fmiddle((d1mid^^allmxl[i])[:8],(d1mid^^allmxl[i])[8:]), fmiddle((d2mid^^allmxl[i])[:8],(d2mid^^allmxl[i])[8:]);
                if (D1mid^^D2mid == (e1^^e2)[8:24]):    #e is our D since c = bitarray(32), and D = e^^c
                    countmid[i] = countmid[i] + 1;
        else:
            continue
    final = max(countmid.values());
    print("Max count: ",final)
    final1 = {};
    for i in range(len(countmid)):
        if countmid[i] == final:
            a = allmxl[i];
            final1.update({i:a});
    a111 = len(final1);
    print(" ")
    print("Nº of possible keys: ",a111)
    return final1

# Filtering for the obtained AK3
def selection_AK3(pairs,possiblekeys_ak3,AK7,AK6,AK5,AK4):
    """Choose the AK3 candidate that best satisfies the 3-round characteristic.

    `possiblekeys_ak3` is indexed with 0 .. len-1 (a list, or a dict keyed that
    way). Its entries are XORed with 32-bit values, so they are presumably full
    32-bit subkey candidates — TODO confirm against the caller.

    A candidate scores one point for every pair where both hold:
      * the d-difference equals the characteristic value `d`;
      * F(d1 ^ AK3) ^ F(d2 ^ AK3) equals e1 ^ e2.

    Returns the first candidate (lowest index) with the highest score.
    """
    d = bitarray('10100000000000001000000000000000')   # d-difference known from the 3-round characteristic
    # The backward decryption down to d and e does not depend on the AK3
    # candidate. Do it once per pair (not once per pair *and* candidate) and
    # keep only the pairs that follow the characteristic.
    usable = []
    for p in pairs.values():
        ct1, ct2 = p[0], p[1]
        l1, r1, l2, r2 = ct1[:32], ct1[32:], ct2[:32], ct2[32:]
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7)
        g1, g2 = l1^^H1, l2^^H2
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6)
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2
        F1, F2 = fkeyless(f1^^AK5), fkeyless(f2^^AK5)
        e1, e2 = g1^^F1, g2^^F2
        E1, E2 = fkeyless(e1^^AK4), fkeyless(e2^^AK4)
        d1, d2 = E1^^f1, E2^^f2
        if d1^^d2 == d:
            usable.append((d1, d2, e1^^e2))
    counting_ak3_list = {}
    for i in range(len(possiblekeys_ak3)):
        key = possiblekeys_ak3[i]
        counting_ak3_list[i] = sum(1 for x1, x2, ediff in usable
                                   if fkeyless(x1^^key)^^fkeyless(x2^^key) == ediff)
    res = possiblekeys_ak3[max(counting_ak3_list, key = counting_ak3_list.get)]
    return res

In [11]:
#---------------------------------------------------------------------------
# Reduction from 3 to 2 rounds
#---------------------------------------------------------------------------
# In this case we know AK2_0 = AK6_0 and AK2_3 = AK6_3
# This means that we only have to attack the middle bytes (mx(AK2))
# Attack on middle bytes
def diffatack3to2(pairs,AK7,AK6,AK5,AK4,AK3):
    """Differential attack on the middle bytes mx(AK2) (reduction from 3 to 2 rounds).

    Each value of `pairs` holds two ciphertexts (presumably 64-bit bitarrays).
    Each pair is decrypted backwards with AK7, AK6, AK5, AK4 and AK3 down to c1, c2.
    Pairs with a zero c-difference are skipped.

    For each 16-bit candidate from `allmx()`, `fmiddle` is evaluated on
    (x0^x1, x2^x3) of both c values, XORed with the candidate. The candidate
    scores a point when the outputs XOR to the middle bytes of b ^ d1 ^ d2,
    where `b` comes from the characteristic.

    Returns a dict {candidate index: 16-bit bitarray} with every candidate that
    reached the maximum score. Prints the maximum score and the number of
    surviving candidates.
    """
    # b-difference known from the 1-round characteristic
    # (the same characteristic fixes a = 0x80800000 and A = 0x02000000).
    b = bitarray('10100000000000001000000000000000')
    # Explicit all-zero value: in older bitarray releases bitarray(32) has
    # *uninitialised* bits, which made the zero-difference test unreliable.
    zero32 = bitarray('00000000000000000000000000000000')
    allmxl = allmx()
    countmid = {i: 0 for i in range(len(allmxl))}
    for p in pairs.values():     # Retrieving middle bytes (attack on mx(AK2))
        ct1, ct2 = p[0], p[1]
        l1, r1, l2, r2 = ct1[:32], ct1[32:], ct2[:32], ct2[32:]
        # Peel off the rounds keyed by AK7 .. AK3.
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7)
        g1, g2 = l1^^H1, l2^^H2
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6)
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2
        F1, F2 = fkeyless(f1^^AK5), fkeyless(f2^^AK5)
        e1, e2 = g1^^F1, g2^^F2
        E1, E2 = fkeyless(e1^^AK4), fkeyless(e2^^AK4)
        d1, d2 = E1^^f1, E2^^f2
        D1, D2 = fkeyless(d1^^AK3), fkeyless(d2^^AK3)
        c1, c2 = D1^^e1, D2^^e2
        if c1^^c2 == zero32:
            continue    # zero c-difference: pair is not used
        c1mid, c2mid = c1[8:24]^^(c1[:8]+c1[24:]), c2[8:24]^^(c2[:8]+c2[24:])
        # b = C ^ d, so the C-difference is b ^ d1 ^ d2; invariant per pair.
        target = (b^^d1^^d2)[8:24]
        for i in range(len(allmxl)):
            x1, x2 = c1mid^^allmxl[i], c2mid^^allmxl[i]
            if fmiddle(x1[:8], x1[8:])^^fmiddle(x2[:8], x2[8:]) == target:
                countmid[i] = countmid[i] + 1
    final = max(countmid.values())
    print("Max count: ",final)
    final1 = {i: allmxl[i] for i in range(len(countmid)) if countmid[i] == final}
    print(" ")
    print("Nº of possible keys: ",len(final1))
    return final1

In [12]:
#---------------------------------------------------------------------------
# Reduction from 2 to 1 rounds
#---------------------------------------------------------------------------
# In this case we know AK1_0 = AK7_0 and AK1_3 = AK7_3
# This means that we only have to attack the middle bytes (mx(AK1))
# Attack on middle bytes
def diffatack2to1(pairs,AK7,AK6,AK5,AK4,AK3,AK2):
    """Differential attack on the middle bytes mx(AK1) (reduction from 2 to 1 round).

    Each value of `pairs` holds two ciphertexts (presumably 64-bit bitarrays).
    Each pair is decrypted backwards with AK7 .. AK2 down to b1, b2 (and c1, c2).

    For each 16-bit candidate from `allmx()`, `fmiddle` is evaluated on
    (x0^x1, x2^x3) of both b values, XORed with the candidate. The candidate
    scores a point when the outputs XOR to the middle bytes of a ^ c1 ^ c2,
    where `a` comes from the characteristic. Every pair is used (no filtering).

    Returns a dict {candidate index: 16-bit bitarray} with every candidate that
    reached the maximum score. Prints the maximum score and the number of
    surviving candidates.
    """
    a = bitarray('10000000100000000000000000000000')   # a-difference known from the 1-round characteristic
    allmxl = allmx()
    countmid = {i: 0 for i in range(len(allmxl))}
    for p in pairs.values():     # Retrieving middle bytes (attack on mx(AK1))
        ct1, ct2 = p[0], p[1]
        l1, r1, l2, r2 = ct1[:32], ct1[32:], ct2[:32], ct2[32:]
        # Peel off the rounds keyed by AK7 .. AK2.
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7)
        g1, g2 = l1^^H1, l2^^H2
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6)
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2
        F1, F2 = fkeyless(f1^^AK5), fkeyless(f2^^AK5)
        e1, e2 = g1^^F1, g2^^F2
        E1, E2 = fkeyless(e1^^AK4), fkeyless(e2^^AK4)
        d1, d2 = E1^^f1, E2^^f2
        D1, D2 = fkeyless(d1^^AK3), fkeyless(d2^^AK3)
        c1, c2 = D1^^e1, D2^^e2
        C1, C2 = fkeyless(c1^^AK2), fkeyless(c2^^AK2)
        b1, b2 = C1^^d1, C2^^d2
        b1mid, b2mid = b1[8:24]^^(b1[:8]+b1[24:]), b2[8:24]^^(b2[:8]+b2[24:])
        # a = B ^ c, so the B-difference is a ^ c1 ^ c2; invariant per pair.
        target = (a^^c1^^c2)[8:24]
        for i in range(len(allmxl)):
            x1, x2 = b1mid^^allmxl[i], b2mid^^allmxl[i]
            if fmiddle(x1[:8], x1[8:])^^fmiddle(x2[:8], x2[8:]) == target:
                countmid[i] = countmid[i] + 1
    final = max(countmid.values())
    print("Max count: ",final)
    # Build the result without reusing `a`, which holds the characteristic.
    final1 = {i: allmxl[i] for i in range(len(countmid)) if countmid[i] == final}
    print(" ")
    print("Nº of possible keys: ",len(final1))
    return final1

# Filtering for the obtained AK1
def selection_AK1(pairs,possiblekeys_ak1,AK7,AK6,AK5,AK4,AK3,AK2):
    """Choose the AK1 candidate that best satisfies the 1-round characteristic.

    `possiblekeys_ak1` is indexed with 0 .. len-1 (a list, or a dict keyed that
    way). Its entries are XORed with 32-bit values, so they are presumably full
    32-bit subkey candidates — TODO confirm against the caller.

    A candidate scores one point for every pair where
    F(b1 ^ AK1) ^ F(b2 ^ AK1) == c1 ^ c2 ^ a.
    Returns the first candidate (lowest index) with the highest score.
    """
    a = bitarray('10000000100000000000000000000000')   # a-difference known from the 1-round characteristic
    # The backward decryption down to b and c does not depend on the AK1
    # candidate. Do it once per pair instead of once per pair *and* candidate.
    decrypted = []
    for p in pairs.values():
        ct1, ct2 = p[0], p[1]
        l1, r1, l2, r2 = ct1[:32], ct1[32:], ct2[:32], ct2[32:]
        H1, H2 = fkeyless(l1^^r1^^AK7), fkeyless(l2^^r2^^AK7)
        g1, g2 = l1^^H1, l2^^H2
        G1, G2 = fkeyless(g1^^AK6), fkeyless(g2^^AK6)
        f1, f2 = l1^^r1^^G1, l2^^r2^^G2
        F1, F2 = fkeyless(f1^^AK5), fkeyless(f2^^AK5)
        e1, e2 = g1^^F1, g2^^F2
        E1, E2 = fkeyless(e1^^AK4), fkeyless(e2^^AK4)
        d1, d2 = E1^^f1, E2^^f2
        D1, D2 = fkeyless(d1^^AK3), fkeyless(d2^^AK3)
        c1, c2 = D1^^e1, D2^^e2
        C1, C2 = fkeyless(c1^^AK2), fkeyless(c2^^AK2)
        b1, b2 = C1^^d1, C2^^d2
        decrypted.append((b1, b2, c1^^c2^^a))   # expected B-difference
    counting_ak1_list = {}
    for i in range(len(possiblekeys_ak1)):
        key = possiblekeys_ak1[i]
        counting_ak1_list[i] = sum(1 for x1, x2, bdiff in decrypted
                                   if fkeyless(x1^^key)^^fkeyless(x2^^key) == bdiff)
    res = possiblekeys_ak1[max(counting_ak1_list, key = counting_ak1_list.get)]
    return res

In [13]:
#---------------------------------------------------------------------------
# Extracting Keys from the actual subkeys
#---------------------------------------------------------------------------
def extactkey(list_of_actualsubkeys):   #{0: AK1, 1: AK2, 2: AK3, 3: AK4, 4: AK5, 5: AK6, 6: AK7}
    """Enumerate candidate full keys consistent with the recovered subkeys AK1..AK7.

    `list_of_actualsubkeys` is indexed 0..6 for AK1..AK7. Every entry is sliced
    with [8:24], so it is presumably a 32-bit bitarray — TODO confirm against
    the caller.

    Only the middle 16 bits of XORs of subkeys two rounds apart are used:
    AK5^AK7, AK4^AK6, AK3^AK5, AK2^AK4 and AK1^AK3. Each is split into a high
    byte [:8] and a low byte [8:]. Names like K5_0 / K5_1 appear to be the high
    and low bytes of a 16-bit key-schedule word K5, and K_0..K_7 the eight key
    bytes (NOTE(review): inferred from naming, confirm against keygen).

    The search makes two nested 256-way guesses over all 8-bit values from
    allbyte_full():
      * outer: guess K5_1, derive K7_1, K3_1, K1_1, K7_0, K5_0, K3_0, K1_0, K_7
        and K_3, and skip the guess unless Sbox1(K1_0 ^ K_7, K_3) == K1_1;
      * inner: guess K4_0, derive the remaining bytes with invS0/invS1 and
        concatenate K_0..K_7 into one candidate.
    The inner loop applies no check, so every outer guess that passes adds 256
    candidates.

    invS1(out, a) returns the byte x with Sbox1(x, a) == out. invS0 is assumed
    to invert Sbox0 in the same way (its body is defined elsewhere).

    Returns a dict {0..n-1: candidate key as the concatenation K_0 + ... + K_7}.
    """
    possible_fullkey = {};
    # Middle 16 bits of the XOR of subkeys two rounds apart.
    k5k7, k4k6 = (list_of_actualsubkeys[4]^^list_of_actualsubkeys[6])[8:24], (list_of_actualsubkeys[3]^^list_of_actualsubkeys[5])[8:24];
    k3k5, k2k4, k1k3 = (list_of_actualsubkeys[2]^^list_of_actualsubkeys[4])[8:24], (list_of_actualsubkeys[1]^^list_of_actualsubkeys[3])[8:24], (list_of_actualsubkeys[0]^^list_of_actualsubkeys[2])[8:24];
    allby = allbyte_full();   # all 256 8-bit values
    # Outer guess: K5_1.
    for p in allby.values():
        K5_1 = p;
        # Low bytes follow directly from the XOR differences.
        K7_1 = K5_1^^k5k7[8:];
        K3_1 = K5_1^^k3k5[8:];
        K1_1 = K3_1^^k1k3[8:];
        K7_0 = K1_1^^K5_1^^invS1(K7_1,K3_1);
        # High bytes follow from K7_0 and the XOR differences.
        K5_0 = K7_0^^k5k7[:8];
        K3_0 = K5_0^^k3k5[:8];
        K1_0 = K3_0^^k1k3[:8];
        K_7 = K3_1^^K5_0^^invS1(K5_1,K1_1);
        K_3 = K1_1^^K3_0^^invS1(K3_1,K_7);
        # Consistency check: reject the guess unless Sbox1 reproduces K1_1.
        if Sbox1(K1_0^^K_7,K_3) != K1_1:
            continue
        else:
            # Inner guess: K4_0. No further check; every value yields a candidate.
            for p1 in allby.values():
                K4_0 = p1;
                K6_0 = K4_0^^k4k6[:8];
                K2_0 = K4_0^^k2k4[:8];
                K6_1 = K1_0^^K5_0^^invS0(K6_0,K2_0);
                K4_1 = K6_1^^k4k6[8:];
                K2_1 = K4_1^^k2k4[8:];
                K0_0 = K4_0^^K3_0^^K3_1^^invS1(K6_1,K2_0^^K2_1);
                K0_1 = K4_1^^K6_1^^invS0(K7_0,K3_0^^K3_1);
                K_4 = K2_0^^K1_0^^K1_1^^invS1(K4_1,K0_0^^K0_1);
                K_5 = K2_1^^K4_1^^invS0(K5_0,K1_0^^K1_1);
                K_6 = K3_0^^K4_1^^invS0(K4_0,K0_0);
                K_0 = K0_0^^K_6^^K_7^^invS1(K2_1,K_4^^K_5);
                K_1 = K0_1^^K2_1^^invS0(K3_0,K_6^^K_7);
                K_2 = K1_0^^K2_1^^invS0(K2_0,K_4);
                # Concatenate the eight derived pieces into one candidate key.
                K = K_0 + K_1 + K_2 + K_3 + K_4 + K_5 + K_6 + K_7;
                possible_fullkey.update({len(possible_fullkey): K})
    return possible_fullkey

def testpossiblefullkey(list_of_actualsubkeys, keylist):
    """Keep the candidate keys whose key schedule reproduces the recovered subkeys.

    For every candidate `k` in `keylist.values()`, `keygen(k)` is run and the
    seven subkeys AK1..AK7 are rebuilt as follows:
      * odd i:  (ak[12]+ak[13]) ^ (ak[14]+ak[15]) ^ (0^8 + ak[i] + 0^8)
      * even i: (ak[12]+ak[13]) ^ (0^8 + ak[i] + 0^8)
    A candidate is kept when this list equals `list_of_actualsubkeys.values()`
    in order.

    Returns a dict {0..n-1: matching candidate key}.
    """
    # Explicit zero byte: in older bitarray releases bitarray(8) has
    # *uninitialised* bits, which made the padding (and the comparison) random.
    zero8 = bitarray('00000000')
    expected = list(list_of_actualsubkeys.values())
    finalkeys = {}
    for k in keylist.values():
        ak = keygen(k)
        base = ak[12] + ak[13]   # shared by every rebuilt subkey
        rebuilt = []
        for i in range(1, 8):
            actk = base ^^ (zero8 + ak[i] + zero8)
            if i % 2 != 0:
                actk = actk ^^ (ak[14] + ak[15])   # odd subkeys also mix in ak[14]+ak[15]
            rebuilt.append(actk)
        if rebuilt == expected:
            finalkeys[len(finalkeys)] = k
    return finalkeys