## Solve filter equations in GF($2^8$)

In [1]:
# AES Rijndael field GF(2**8) (SageMath syntax: `F.<a> = ...` and `^` need the Sage preparser)
# F is the polynomial ring GF(2)[a]; it is only used to spell out the modulus below.
F.<a> = GF(2)[]
# K = GF(2)[x] / (x^8 + x^4 + x^3 + x + 1), i.e. reduction polynomial 0x11B.
# The helpers below map bit j of a byte to the coefficient of the generator power x**j.
K.<x> = GF(2**8, name='x', modulus=a^8 + a^4 + a^3 + a + 1)
print(K)

def bin_str_to_GF(binary):
    """Convert a binary string (most significant bit first) to an element of K.

    The string is left-padded with '0' to 8 characters; character j counted from the
    right becomes the coefficient of x**j. Any character other than '1' counts as 0.
    Strings longer than 8 characters produce higher powers of x, which K reduces.

    Returns:
        An element of K. This is always a field element, even for an all-zero string.
    """
    tmp = 0
    for j, k in enumerate(binary.rjust(8, '0')[::-1]):
        if k == '1':
            tmp += x**int(j)
    # Wrap in K so an all-zero string yields K(0) rather than the Python int 0.
    # This matches int_to_GF and keeps GF_to_bin (integer_representation) usable on the result.
    return K(tmp)

def int_to_GF(integer):
    """Map an integer to the element of K whose coefficient of x**j is bit j of the integer.

    The binary form is padded to at least 8 bits. Wider values contribute higher
    powers of x, which K reduces modulo its defining polynomial.
    """
    bits_lsb_first = bin(integer)[2:].rjust(8, '0')[::-1]
    element = 0
    for power in range(len(bits_lsb_first)):
        if bits_lsb_first[power] == '1':
            element += x**int(power)
    return K(element)

def GF_to_bin(gf_elem):
    """Return the byte value of a field element as a plain Python int.

    Despite the name, the result is an integer (0..255 for elements of K), not a
    binary string. It is built from the element's integer_representation().
    NOTE(review): newer Sage releases may deprecate integer_representation; confirm.
    """
    as_integer = gf_elem.integer_representation()
    return int(as_integer)

Finite Field in x of size 2^8


In [2]:
# Lookup table of Hamming weights: HW[v] is the number of set bits of byte v (0..255)
def calc_hamming_weight(n):
    """Return how many '1' characters appear in bin(n), i.e. the popcount of n."""
    return sum(1 for digit in bin(n) if digit == '1')

HW = [calc_hamming_weight(value) for value in range(256)]

### Simplify the 10 inferred relations

The 10 inferred relations come from the paper by Caforio et al. ([https://eprint.iacr.org/2014/816.pdf](https://eprint.iacr.org/2014/816.pdf)):

![Inferred relations](images/infered_relations.png)

Typo in the paper:

**Relations 10 and 11 in the paper are wrong and should be swapped, i.e. the correct relations are $255 \cdot a = x_9 \cdot c$ and $255 \cdot c = x_{10} \cdot a$.**

In [3]:
# Declare Sage symbolic variables: x_k_var for the unknown x_k, x_255_var for the constant 0xFF
# (written 255 in the relations), and a_var, b_var, c_var, d_var for the bytes a, b, c, d.
var("x_1_var x_2_var x_3_var x_4_var x_5_var x_6_var x_7_var x_8_var x_9_var x_10_var x_255_var a_var b_var c_var d_var")

# Step 1: solve the 10 inferred relations (the x_9 / x_10 relations are in the corrected,
# swapped form) for x_1..x_10. Each x_k is expressed in terms of a, b, c, d and x_255.
print(solve([x_255_var*a_var == x_1_var*b_var,
       x_255_var*b_var == x_2_var*a_var,
       x_255_var*d_var == x_3_var*a_var,
       x_255_var*a_var == x_4_var*d_var,
       x_255_var*c_var == x_5_var*d_var,
       x_255_var*d_var == x_6_var*c_var,
       x_255_var*b_var == x_7_var*c_var,
       x_255_var*c_var == x_8_var*b_var,
       x_255_var*a_var == x_9_var*c_var,
       x_255_var*c_var == x_10_var*a_var],
       [x_1_var, x_2_var, x_3_var, x_4_var, x_5_var, x_6_var, x_7_var, x_8_var, x_9_var, x_10_var]))

# Step 2: eliminate a, b, c, d by multiplying relations so the unknown bytes cancel
# (e.g. the first two relations give x_1*x_2 == x_255**2). The printed solution leaves
# three free parameters r1, r2, r3 (x_6 == r1, x_8 == r2, x_10 == r3). These are the
# variables enumerated by brute force further down.
print(solve([x_1_var*x_2_var == x_255_var**2, x_3_var*x_4_var == x_255_var**2, x_5_var*x_6_var == x_255_var**2, x_7_var*x_8_var == x_255_var**2, x_9_var*x_10_var == x_255_var**2, 
       x_1_var*x_7_var*x_10_var == x_255_var**3, x_2_var*x_8_var*x_9_var == x_255_var**3,
       x_1_var*x_3_var*x_5_var*x_7_var == x_255_var**4, x_2_var*x_4_var*x_6_var*x_8_var == x_255_var**4], [x_1_var, x_2_var, x_3_var, x_4_var, x_5_var, x_6_var, x_7_var, x_8_var, x_9_var, x_10_var]))


[
[x_1_var == a_var*x_255_var/b_var, x_2_var == b_var*x_255_var/a_var, x_3_var == d_var*x_255_var/a_var, x_4_var == a_var*x_255_var/d_var, x_5_var == c_var*x_255_var/d_var, x_6_var == d_var*x_255_var/c_var, x_7_var == b_var*x_255_var/c_var, x_8_var == c_var*x_255_var/b_var, x_9_var == a_var*x_255_var/c_var, x_10_var == c_var*x_255_var/a_var]
]
[
[x_1_var == r2*x_255_var/r3, x_2_var == r3*x_255_var/r2, x_3_var == r1*r3/x_255_var, x_4_var == x_255_var^3/(r1*r3), x_5_var == x_255_var^2/r1, x_6_var == r1, x_7_var == x_255_var^2/r2, x_8_var == r2, x_9_var == x_255_var^2/r3, x_10_var == r3]
]


### Solve the inferred equations given the Hamming weight of the state bytes after the S-box (i.e. $w_i = HW(x_i)$)

In [4]:
import random
import numpy as np

# Given Hamming weights of the 10 variables, stored 0-based: w_i[k-1] = HW(x_k) (see w_i_get).
# x_1 = 4, x_2 = 3, x_3 = 5, x_4 = 5, x_5 = 8, x_6 = 8, x_7 = 3, x_8 = 4, x_9 = 5, x_10 = 5
w_i = [4, 3, 5, 5, 8, 8, 3, 4, 5, 5]

In [5]:
import itertools

# The constant 255 of the inferred relations, as an element of K
x_255 = int_to_GF(255)
# One candidate list per variable, 0-based: x_list[k-1] collects possible values of x_k
x_list = [list() for _ in range(10)]

# Fill x_list with every candidate value of x_index, i.e. every byte with Hamming weight w_index
def all_bytes_same_hamming_weight(index, x_list):
    """Store in x_list[index-1] every element of K whose byte has Hamming weight w_i[index-1].

    Candidates are listed in decreasing byte value (e.g. 0xF0 before 0x0F for weight 4).

    Args:
        index: 1-based variable number (x_1 ... x_10); the weight is read from the global w_i.
        x_list: list of candidate lists; position index-1 is replaced with a new list.
    """
    target_weight = w_i_get(index, w_i)
    # Scan the 256 byte values directly instead of generating all 8! = 40320 permutations
    # of the bit string and de-duplicating them. Iterating downward yields the same
    # (descending) order as the de-duplicated permutations, and weight 0 gives K(0).
    x_list[index-1] = [int_to_GF(byte) for byte in range(255, -1, -1)
                       if bin(byte).count('1') == target_weight]
   
# Accessors over x_list; `index` is always the 1-based variable number (x_1 ... x_10)
def x_b_list(index, x_list):
    """Return a shallow copy of the candidate list stored for x_index."""
    selected = x_list[index - 1]
    return selected.copy()

def x_b_list_add_elem(index, elem, x_list):
    """Append elem to the candidates of x_index by rebinding the slot to a new, longer list."""
    slot = index - 1
    extended = x_list[slot] + [elem]
    x_list[slot] = extended

def x_b_list_remove_elem(index, elem, x_list):
    """Remove the LAST occurrence of elem from the candidates of x_index.

    The slot x_list[index-1] is rebound to a new list; the old list object is not mutated.

    Raises:
        ValueError: if elem is not present. The slot is then left unchanged (the previous
        reverse/remove/reverse approach left it reversed when remove() failed).
    """
    items = x_list[index-1]
    # Search from the end, using the same `is` / `==` test that list.remove applies.
    for pos in range(len(items) - 1, -1, -1):
        if items[pos] is elem or items[pos] == elem:
            x_list[index-1] = items[:pos] + items[pos + 1:]
            return
    raise ValueError("list.remove(x): x not in list")
    
def w_i_get(index, w_i):
    """Return the Hamming weight recorded for x_index, where index is 1-based."""
    zero_based = index - 1
    return w_i[zero_based]

In [6]:
# The simplified system leaves three free variables: x_6 = r1, x_8 = r2, x_10 = r3.
# Enumerate every byte with the required Hamming weight for each of them.
for free_var_index in (6, 8, 10):
    all_bytes_same_hamming_weight(free_var_index, x_list)

In [7]:
# Solving inferred relations via brute force.
# Every combination of the free variables (x_6 = r1, x_8 = r2, x_10 = r3) is tried; the
# remaining x_k are derived from the symbolic solution and must match their Hamming weight.
x_list_f = [[] for _ in range(10)]


def _matches_weight(value, index):
    """Return True if the byte of GF element `value` has the Hamming weight expected for x_index."""
    return HW[GF_to_bin(value)] == w_i_get(index, w_i)


for r1 in x_b_list(6, x_list):
    for r2 in x_b_list(8, x_list):
        for r3 in x_b_list(10, x_list):
            # A zero free variable can never satisfy x_k * x_l == x_255**2 and would raise
            # ZeroDivisionError in the divisions below, so skip it.
            if r1 == 0 or r2 == 0 or r3 == 0:
                continue
            tmp1 = r2*x_255/r3  # x1
            if not _matches_weight(tmp1, 1):
                continue
            tmp2 = r3*x_255/r2  # x2
            if not _matches_weight(tmp2, 2):
                continue
            tmp3 = r1*r3/x_255  # x3
            if not _matches_weight(tmp3, 3):
                continue
            # `**` rather than `^`: `^` only means power under the Sage preparser (XOR in Python).
            tmp4 = x_255**3/(r1*r3)  # x4
            if not _matches_weight(tmp4, 4):
                continue
            tmp5 = x_255**2/r1  # x5
            if not _matches_weight(tmp5, 5):
                continue
            tmp7 = x_255**2/r2  # x7
            if not _matches_weight(tmp7, 7):
                continue
            tmp9 = x_255**2/r3  # x9
            if not _matches_weight(tmp9, 9):
                continue
            # Record the full solution, one value per variable x_1 ... x_10.
            solution = (tmp1, tmp2, tmp3, tmp4, tmp5, r1, tmp7, r2, tmp9, r3)
            for var_index, value in enumerate(solution, start=1):
                x_b_list_add_elem(var_index, value, x_list_f)

In [8]:
# Function to verify if the found solutions are correct
def verify_filter_equ(x_list):
    """Check every candidate solution in x_list against the weights and simplified relations.

    x_list[k-1] holds the values of x_k, one entry per solution, so all sub-lists must
    have the same length (an AssertionError is raised otherwise). The globals w_i, HW
    and x_255 are read.

    Returns:
        -1 if every solution passes. This is also the result when there is no solution.
         0 if some x_k does not have Hamming weight w_i[k-1].
         1..9: the number of the first relation, in the order of `relations` below,
           that a solution violates.
    """
    n_vars = len(x_list)
    # Every variable must carry one value per solution. Compare each list with the next one
    # cyclically, using the actual number of variables instead of a hard-coded 10.
    for i in range(n_vars):
        assert len(x_list[i]) == len(x_list[(i + 1) % n_vars])

    # (0-based indices of the x_k multiplied together, power p): the product must equal x_255**p.
    relations = (
        ((0, 1), 2),        # x_1 * x_2
        ((2, 3), 2),        # x_3 * x_4
        ((4, 5), 2),        # x_5 * x_6
        ((6, 7), 2),        # x_7 * x_8
        ((8, 9), 2),        # x_9 * x_10
        ((0, 6, 9), 3),     # x_1 * x_7 * x_10
        ((1, 7, 8), 3),     # x_2 * x_8 * x_9
        ((0, 2, 4, 6), 4),  # x_1 * x_3 * x_5 * x_7
        ((1, 3, 5, 7), 4),  # x_2 * x_4 * x_6 * x_8
    )

    for i in range(len(x_list[0])):
        for j in range(len(w_i)):
            if HW[GF_to_bin(x_list[j][i])] != w_i[j]:
                return 0
        for code, (indices, power) in enumerate(relations, start=1):
            product = x_list[indices[0]][i]
            for k in indices[1:]:
                product *= x_list[k][i]
            if product != x_255**power:
                return code
    return -1

verification_code = verify_filter_equ(x_list_f)
# Report which check failed (0 = Hamming weight, 1..9 = relation number) instead of discarding it.
# Note: an empty solution set also verifies as -1, so this only validates the solutions found.
assert verification_code == -1, f"Error in the found solutions (verify_filter_equ returned {verification_code})"


In [9]:
# Print the found solutions. Every GF element is converted to its byte value in a single pass,
# giving an integer array of shape (10, n_solutions). The transposed print shows one
# solution per row, with columns x_1 ... x_10.
x_list_f_bin = np.array(
    [[GF_to_bin(value) for value in candidates] for candidates in x_list_f],
    dtype=int,
)
print("Possible solution :")
print(x_list_f_bin.T)

Possible solution :
[[120 84 229 242 255 255 42 240 242 229]
 [92 148 229 242 255 255 74 184 242 229]
 [90 112 229 242 255 255 56 180 242 229]
 [85 26 229 242 255 255 13 170 242 229]
 [78 50 229 242 255 255 25 156 242 229]
 [240 42 242 229 255 255 84 120 229 242]
 [60 168 229 242 255 255 84 120 242 229]
 [184 74 242 229 255 255 148 92 229 242]
 [180 56 242 229 255 255 112 90 229 242]
 [45 224 229 242 255 255 112 90 242 229]
 [170 13 242 229 255 255 26 85 229 242]
 [156 25 242 229 255 255 50 78 229 242]
 [39 100 229 242 255 255 50 78 242 229]
 [120 84 242 229 255 255 168 60 229 242]
 [90 112 242 229 255 255 224 45 229 242]
 [78 50 242 229 255 255 100 39 229 242]]


In [10]:
# Determine which x_k take the same value in every found solution
print("Constant variable x_i :")
for i, values in enumerate(x_list_f_bin):
    # `values` holds x_{i+1} for every solution; it is constant when it has exactly one
    # distinct value. With no solutions the set is empty and nothing is reported,
    # whereas indexing the first solution would raise IndexError in that case.
    if len(set(values.tolist())) == 1:
        print(f"Var x_{i+1} is constant")

Constant variable x_i :
Var x_5 is constant
Var x_6 is constant
