In [1]:
import torch
import torch.nn as nn
device = "cuda"
device = torch.device(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def mixColumns(a, b, c, d):
    return (gmul(a, 2) ^ gmul(b, 3) ^ gmul(c, 1) ^ gmul(d, 1)), (gmul(a, 1) ^ gmul(b, 2) ^ gmul(c, 3) ^ gmul(d, 1)), (gmul(a, 1) ^ gmul(b, 1) ^ gmul(c, 2) ^ gmul(d, 3)), (gmul(a, 3) ^ gmul(b, 1) ^ gmul(c, 1) ^ gmul(d, 2))


def gmul(a, b):
    if b == 1:
        return a
    tmp = (a << 1) & 0xff
    if b == 2:
        return tmp if a < 128 else tmp ^ 0x1b
    if b == 3:
        return gmul(a, 2) ^ a

def printHex(val):
    return print('{:02x}'.format(val), end=' ')

# test vectors from https://en.wikipedia.org/wiki/Rijndael_MixColumns#Test_vectors_for_MixColumn()
assert mixColumns(0xdb, 0x13, 0x53, 0x45) == (0x8e, 0x4d, 0xa1, 0xbc)
assert mixColumns(0xf2, 0x0a, 0x22, 0x5c) == (0x9f, 0xdc, 0x58, 0x9d)
assert mixColumns(0x01, 0x01, 0x01, 0x01) == (0x01, 0x01, 0x01, 0x01) 
assert mixColumns(0xc6, 0xc6, 0xc6, 0xc6) == (0xc6, 0xc6, 0xc6, 0xc6) 
assert mixColumns(0xd4, 0xd4, 0xd4, 0xd5) == (0xd5, 0xd5, 0xd7, 0xd6) 
assert mixColumns(0x2d, 0x26, 0x31, 0x4c) == (0x4d, 0x7e, 0xbd, 0xf8)
print("ok")

ok


In [3]:
import torch.nn.functional as F
noise = 0.0
noise_n = 1

def dec2bin(x, bits):
    # mask = 2 ** torch.arange(bits).to(x.device, x.dtype)
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(x.device, x.dtype)
    return x.unsqueeze(-1).bitwise_and(mask).ne(0).float()


def bin2dec(b, bits):
    mask = 2 ** torch.arange(bits - 1, -1, -1).to(b.device, b.dtype)
    return torch.sum(mask * b, -1)

def bin_tuple(i0,i1,i2,i3):
    b0 = torch.tensor([float(a) for a in bin(i0)[2:].rjust(8,"0")], dtype=torch.float32)
    # b0 = dec2bin(i0, 8)
    # b0 += (torch.FloatTensor(8).uniform_(-noise, +noise))
    b0 = b0.to(device)
    
    b1 = torch.tensor([float(a) for a in bin(i1)[2:].rjust(8,"0")], dtype=torch.float32)
    # b1 = dec2bin(i1, 8)
    # b1 += (torch.FloatTensor(8).uniform_(-noise, +noise))
    b1 = b1.to(device)

    b2 = torch.tensor([float(a) for a in bin(i2)[2:].rjust(8,"0")], dtype=torch.float32)
    # b2 = dec2bin(i2, 8)
    # b2 += (torch.FloatTensor(8).uniform_(-noise, +noise))
    b2 = b2.to(device)

    b3 = torch.tensor([float(a) for a in bin(i3)[2:].rjust(8,"0")], dtype=torch.float32)
    # b3 = dec2bin(i3, 8)
    # b3 += (torch.FloatTensor(8).uniform_(-noise, +noise))
    b3 = b3.to(device)

    return (b0,b1,b2,b3)

train_data = []
already = set()

while len(train_data) < 10000:
    t = torch.randint(0,256,(4,))
    if t in already:
        continue
    already.add(t)
    for _ in range(noise_n):
        b = dec2bin(t,8).to(device)
        res = bin_tuple(*mixColumns(*t))
        # tuple to matrix (4x8) then to (1x32)
        b = b.view(32)
        res = torch.stack(res, dim=0).view(32)
        train_data.append((b,res))

train_data_triple = []
for i in range(256):
    res = []
    for k in range(3):
        res.append(gmul(i, k+1))
    
    b = dec2bin(torch.tensor([i]),8).to(device)
    b = b.view(8)

    res = dec2bin(torch.tensor(res),8).to(device)
    res = res.view(8*3)

    train_data_triple.append((b,res))

In [4]:
print(train_data[0])
print(train_data_triple[1])

(tensor([0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0.,
        1., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0.],
       device='cuda:0'), tensor([1., 1., 0., 0., 0., 0., 0., 1., 0., 1., 0., 1., 0., 0., 0., 0., 0., 1.,
        0., 0., 1., 1., 1., 0., 1., 1., 1., 0., 1., 1., 0., 1.],
       device='cuda:0'))
(tensor([0., 0., 0., 0., 0., 0., 0., 1.], device='cuda:0'), tensor([0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 0., 1., 1.], device='cuda:0'))


In [5]:
print(2**4+2**1+1)

19


In [6]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

def get_loaders(train_data, device=device):
    test_size = 0.05
    valid_size = 0.05
    batch_size = 500
    num_workers = 0

    #cuda or cpu
    device = torch.device(device)

    num_train = len(train_data)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(test_size * num_train))
    split2 = int(np.floor((valid_size+test_size) * num_train))
    train_idx, valid_idx, test_idx = indices[split2:], indices[split:split2], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    # prepare data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=test_sampler, num_workers=num_workers)
    return train_loader, valid_loader, test_loader

train_loader, valid_loader, test_loader = get_loaders(train_data)
train_loader_triple, valid_loader_triple, test_loader_triple = get_loaders(train_data_triple)

In [7]:
import numpy as np
def noise_to_int(bits):
    bits = [round(float(b)) for b in bits]
    bits = "".join([str(b) if b in [0,1] else "0" if b<1/10**5 else "1" for b in bits])
    return int(bits,2)

def lin_to_tuple(t):
    a = (t[:8],t[8:16],t[16:24],t[24:])
    return tuple([noise_to_int(b) for b in a])

def lin_to_triple(t):
    a = (t[:8],t[8:16],t[16:])
    return tuple([noise_to_int(b) for b in a])

def lin_to_list(t):
    res = []
    for i in range(len(t)//8):
        res += [noise_to_int(t[i*8:i*8+8])]
    return res

In [8]:
X, Y = next(iter(train_loader))

for x,y in zip(X,Y):
    a = lin_to_tuple(x)
    b = lin_to_tuple(y)
    b1 = mixColumns(*a)
    assert b == b1
print("ok")

ok


In [9]:
X, Y = next(iter(train_loader_triple))

for x,y in zip(X,Y):
    a = noise_to_int(x)
    b,c,d = lin_to_triple(y)
    assert (b,c,d) == (gmul(a, 1), gmul(a, 2), gmul(a, 3))
print("ok")

ok


In [29]:
import torch.nn.functional as F
    
class Gmul(nn.Module):
    def __init__(self):
        super(Gmul, self).__init__()
        
        self.body = nn.Sequential(
            nn.Linear(8,24),
            nn.Sigmoid(),
            nn.Linear(24,24),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        out = self.body(x)
        return out

In [24]:
def train(save_file, model, criterion, train_loader, valid_loader, optimizer=None, n_epochs = 100000, f=lin_to_tuple, lrate=0.005):
    # number of epochs to train the model

    if optimizer is None:
        # specify optimizer (stochastic gradient descent) and learning rate = 0.001
        optimizer = torch.optim.Adam(model.parameters(), lr=lrate)#, weight_decay=0.00000001)

    # initialize tracker for minimum validation loss
    valid_loss_min = np.Inf # set initial "min" to infinity
    
    for epoch in range(n_epochs):
        # monitor training loss
        train_loss = 0.0
        valid_loss = 0.0
        results = 0
        results_n = 0
        ###################
        # train the model #
        ###################
        model.train() # prep model for training
        i=0
        for X, target in train_loader:
            i+=1
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            target = target.to(device)
            output = model(X)
            # calculate the loss
            # print(output)
            # print(target)
            loss = criterion(output, target) #
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update running training loss
            train_loss += loss.item()*X.size(0)
            if epoch%100 == 0:
                for x,y in zip(output,target):
                    # print(x.cpu().detach().numpy(),y)
                    a = f(x.cpu().detach().numpy())
                    # a = int(x[0])
                    b = f(y.cpu().detach().numpy())
                    # b = int(y[0])
                    # a = noise_to_int(x)
                    # b = noise_to_int(y)
                    
                    
                    # print(a,b)
                    # print(float(x[0]),float(y[0]))
                    if a==b:
                        
                        results +=1
                    results_n+=1
        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation
        for X, target in valid_loader:
        
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(X)
            # target = target.to(device)
            # calculate the loss
            loss = criterion(output, target)
            # update running validation loss
            valid_loss += loss.item()*X.size(0)
            

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = train_loss/len(train_loader.dataset)
        valid_loss = valid_loss/len(valid_loader.dataset)

        print('Epoch: {} \tTraining Loss: {:.10f} \tValidation Loss: {:.10f}'.format(
            epoch+1,
            train_loss,
            valid_loss
            ))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.10f} --> {:.10f}).  Saving model ...'.format(
                valid_loss_min,
                valid_loss))
            torch.save(model.state_dict(), save_file)
            valid_loss_min = valid_loss
            if train_loss <= 0.000000001:
                print("stop: loss <= 0.00000")
                return
            else:
                print(" loss >= 0.00000")
        
        if results_n != 0 :
            print(f"{results/results_n=}")
            print(f"{results}")
            # if results == results_n and valid_loss <= valid_loss_min:
            #     print("stop: no errors")
            #     return
        

In [252]:
model = Gmul().to(device)
# model.load_state_dict(torch.load("mixcolumns.pt"))
criterion = nn.BCELoss()
train("gmul.pt", model, criterion, train_loader_triple, valid_loader_triple, f=lin_to_triple, lrate=0.05)

Epoch: 1 	Training Loss: 0.6397709618 	Validation Loss: 0.0367227220
Validation loss decreased (inf --> 0.0367227220).  Saving model ...
 loss >= 0.00000
results/results_n=0.0
0
Epoch: 2 	Training Loss: 0.6366362222 	Validation Loss: 0.0357529116
Validation loss decreased (0.0367227220 --> 0.0357529116).  Saving model ...
 loss >= 0.00000
Epoch: 3 	Training Loss: 0.6283393598 	Validation Loss: 0.0350247154
Validation loss decreased (0.0357529116 --> 0.0350247154).  Saving model ...
 loss >= 0.00000
Epoch: 4 	Training Loss: 0.6201546367 	Validation Loss: 0.0347786155
Validation loss decreased (0.0350247154 --> 0.0347786155).  Saving model ...
 loss >= 0.00000
Epoch: 5 	Training Loss: 0.6158534861 	Validation Loss: 0.0347031876
Validation loss decreased (0.0347786155 --> 0.0347031876).  Saving model ...
 loss >= 0.00000
Epoch: 6 	Training Loss: 0.6112461439 	Validation Loss: 0.0346169239
Validation loss decreased (0.0347031876 --> 0.0346169239).  Saving model ...
 loss >= 0.00000
Epoch: 

Epoch: 27 	Training Loss: 0.3914544513 	Validation Loss: 0.0237746461
Validation loss decreased (0.0241686112 --> 0.0237746461).  Saving model ...
 loss >= 0.00000
Epoch: 28 	Training Loss: 0.3832948990 	Validation Loss: 0.0233778964
Validation loss decreased (0.0237746461 --> 0.0233778964).  Saving model ...
 loss >= 0.00000
Epoch: 29 	Training Loss: 0.3755709424 	Validation Loss: 0.0229606114
Validation loss decreased (0.0233778964 --> 0.0229606114).  Saving model ...
 loss >= 0.00000
Epoch: 30 	Training Loss: 0.3683742561 	Validation Loss: 0.0225294440
Validation loss decreased (0.0229606114 --> 0.0225294440).  Saving model ...
 loss >= 0.00000
Epoch: 31 	Training Loss: 0.3617297688 	Validation Loss: 0.0221091609
Validation loss decreased (0.0225294440 --> 0.0221091609).  Saving model ...
 loss >= 0.00000
Epoch: 32 	Training Loss: 0.3555669700 	Validation Loss: 0.0217319235
Validation loss decreased (0.0221091609 --> 0.0217319235).  Saving model ...
 loss >= 0.00000
Epoch: 33 	Train

In [25]:
model = Gmul().to(device)
model.load_state_dict(torch.load("gmul.pt"))

<All keys matched successfully>

In [26]:
model.eval()
X, Y = next(iter(train_loader_triple))

# print(X,Y)
results = 0
results_n = 0
with torch.no_grad():
    O = model(X)
for x,y in zip(O,Y):
    a = lin_to_triple(x.cpu().detach().numpy())
    b = lin_to_triple(y.cpu().detach().numpy())
    # b = int(y[0]) pos_to_int
    # b = noise_to_int(y)
    # print(x,y)
    # print(a,b)
    if a==b:
        
        results +=1
    results_n +=1
print(f"{results/results_n=}")

results/results_n=1.0


In [16]:
model.eval()
X, Y = next(iter(train_loader))

# print(X,Y)
results = 0
results_n = 0
with torch.no_grad():
    O = model(X)
for x,y in zip(O,Y):
    a = lin_to_tuple(x.cpu().detach().numpy())
    b = lin_to_tuple(y.cpu().detach().numpy())
    # b = int(y[0]) pos_to_int
    # b = noise_to_int(y)
    # print(x,y)
    # print(a,b)
    if a==b:
        
        results +=1
    results_n +=1
print(f"{results/results_n=}")

results/results_n=1.0


In [18]:
model.eval()
X, Y = next(iter(train_loader))
torch.cuda.empty_cache()
# print(X,Y)
results = 0
results_n = 0
O = X
n = 500
with torch.no_grad():
    for i in range(n):
        O = model(O)
for x,y in zip(X,O):
    # print(x,y)
    a = lin_to_tuple(x.cpu().detach().numpy())
    b = lin_to_tuple(y.cpu().detach().numpy())
    # b = int(y[0]) pos_to_int
    # b = noise_to_int(y)
    
    r = a
    for i in range(n):
        r = mixColumns(*r)
    # print(r,b)
    if r==b:
        
        results +=1
    results_n +=1
print(f"after {n} mix columns")
print(f"{results/results_n=}")
# print(model.state_dict())

RuntimeError: mat1 and mat2 shapes cannot be multiplied (500x32 and 8x24)

In [206]:

import torch.nn.functional as F

class MixColumns4(nn.Module):
    def __init__(self):
        super(MixColumns4, self).__init__()
        
        self.body = nn.Sequential(
            nn.Linear(4*8*4,4000),
            nn.Sigmoid(),
            # nn.Linear(1000,1000),
            # nn.Sigmoid(),
            nn.Linear(4000,4*8*4),
            nn.Sigmoid(),
        )
        
    def forward(self, x):
        out = self.body(x)
        return out

In [207]:
model4 = MixColumns4().to(device)
print(model4.body[0].weight.shape)
print( model.body[0].weight.shape)
model4.eval()
w = model4.body[0].weight.clone()
for i in range(4):
    for j in range(4):
        w[1000*i:1000*(i+1),8*i+32*j:8*(i+1)+32*j] = model.body[0].weight.clone()[:,8*j:8*(j+1)]
    # w[1000*i:1000*(i+1),32*i:32*(i+1)] = model.body[0].weight.clone()
model4.body[0].weight = nn.Parameter(w)

print(model4.body[0].bias.shape)
print( model.body[0].bias.shape)

w = model4.body[0].bias.clone()
for i in range(4):
    # for j in range(4):
    #     w[250*i+1000*j:250*(i+1)+1000*j] = model.body[0].bias.clone()[250*j:250*(j+1)]
    w[1000*i:1000*(i+1)] = model.body[0].bias.clone()
model4.body[0].bias = nn.Parameter(w)


print(model4.body[2].weight.shape)
print( model.body[2].weight.shape)

w = model4.body[2].weight.clone()
for i in range(4):
    for j in range(4):
        w[8*i+32*j:8*(i+1)+32*j,1000*i:1000*(i+1)] = model.body[2].weight.clone()[8*j:8*(j+1)]
    # w[32*i:32*(i+1),1000*i:1000*(i+1)] = model.body[2].weight.clone()
model4.body[2].weight = nn.Parameter(w)

print(model4.body[2].bias.shape)
print( model.body[2].bias.shape)

w = model4.body[2].bias.clone()
for i in range(4):
    for j in range(4):
        w[8*i+32*j:8*(i+1)+32*j] = model.body[2].bias.clone()[8*j:8*(j+1)]
    # w[32*i:32*(i+1)] = model.body[2].bias.clone()
model4.body[2].bias = nn.Parameter(w)

torch.save(model4.state_dict(), "mixcolumns4.pt")


torch.Size([4000, 128])
torch.Size([1000, 32])
torch.Size([4000])
torch.Size([1000])
torch.Size([128, 4000])
torch.Size([32, 1000])
torch.Size([128])
torch.Size([32])


In [208]:
model.eval()
X, Y = next(iter(train_loader))

# print(X,Y)
results = 0
results_n = 0
print(X[:4])
X = X.reshape(-1,4,4,8).transpose(1,2).reshape(-1,4*8*4)
print(X[:4])
Y = Y.reshape(-1,4,4,8).transpose(1,2).reshape(-1,4*8*4)
O = model4(X)
for x,y in zip(O,Y):
    a = lin_to_list(x.cpu().detach().numpy())
    b = lin_to_list(y.cpu().detach().numpy())
    # b = int(y[0]) pos_to_int
    # b = noise_to_int(y)
    # print(x,y)
    
    if a==b:
        
        results +=1
    else:
        print(a,b)
    results_n +=1
print(f"{results/results_n=}")

tensor([[1., 0., 1., 0., 0., 0., 0., 1., 0., 1., 1., 1., 1., 1., 0., 0., 1., 1.,
         0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 0., 1.],
        [1., 1., 0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 1., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0.],
        [1., 1., 0., 0., 1., 1., 0., 1., 0., 1., 1., 0., 0., 1., 1., 1., 0., 0.,
         1., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 1.],
        [1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         1., 0., 0., 1., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1.]],
       device='cuda:0')
tensor([[1., 0., 1., 0., 0., 0., 0., 1., 1., 1., 0., 0., 0., 1., 0., 1., 1., 1.,
         0., 0., 1., 1., 0., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 1., 1., 1.,
         1., 1., 0., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 1., 0., 0., 1.,
         1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0., 0., 1., 1., 1., 0.,
         0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 1., 1., 1.

In [209]:
model = MixColumns4().to(device)
model.load_state_dict(torch.load("mixcolumns4.pt"))

<All keys matched successfully>

In [151]:
model = model4
results = 0
results_n = 0

while True:
    n = 1000
    # create 4 x 4 random bytes
    x = torch.randint(0,256,(16, n,)).to(device).view(-1)
    # x = torch.tensor([251,  15, 208, 201, 251,  45, 248,  51, 187, 117, 138, 165, 251,  75,
        # 113, 214])
    # print(x)
    # to binary
    inp = [torch.tensor([float(a) for a in bin(i)[2:].rjust(8,"0")], dtype=torch.float32).to(device) for i in x.view(-1)]
    # stack
    inp = torch.stack(inp, dim=0)
    inp = inp.reshape(-1,4,4,8).transpose(1,2).reshape(-1,4*8*4)
    # print(inp.shape)

    # compute mix columns every 4 bytes
    Y = []
    X = []
    for j in range(0,n*16,4*4):
        y = []
        xx = []
        for k in range(4):
            y.append(mixColumns(*x[j+k*4:j+k*4+4]))
            xx += x[j+k*4:j+k*4+4]
        
        Y.append(torch.tensor(y).reshape(4,4).transpose(0,1).reshape(-1))
        X.append(torch.tensor(xx).reshape(4,4).transpose(0,1).reshape(-1))

    with torch.no_grad():
        O = model(inp)
    
    for o,y,x in zip(O,Y,X):
        o = lin_to_list(o.cpu().detach().numpy())
        if any(torch.tensor(o)!=y):
            print(x)
            print(o)
            print(y)
            print()
            print()
            raise Exception("error")
        results +=1
        results_n +=1

    print(f"{results/results_n=}", results, results_n)


results/results_n=1.0 1000 1000
results/results_n=1.0 2000 2000
results/results_n=1.0 3000 3000
results/results_n=1.0 4000 4000
results/results_n=1.0 5000 5000
results/results_n=1.0 6000 6000
results/results_n=1.0 7000 7000
results/results_n=1.0 8000 8000
results/results_n=1.0 9000 9000
results/results_n=1.0 10000 10000
results/results_n=1.0 11000 11000
results/results_n=1.0 12000 12000
results/results_n=1.0 13000 13000
results/results_n=1.0 14000 14000
results/results_n=1.0 15000 15000
results/results_n=1.0 16000 16000
results/results_n=1.0 17000 17000
results/results_n=1.0 18000 18000
results/results_n=1.0 19000 19000
results/results_n=1.0 20000 20000
results/results_n=1.0 21000 21000
results/results_n=1.0 22000 22000
results/results_n=1.0 23000 23000
results/results_n=1.0 24000 24000
results/results_n=1.0 25000 25000
results/results_n=1.0 26000 26000
results/results_n=1.0 27000 27000
results/results_n=1.0 28000 28000
results/results_n=1.0 29000 29000
results/results_n=1.0 30000 3000

KeyboardInterrupt: 

In [27]:
import torch.nn.functional as F

class Xor(nn.Module):
    def __init__(self):     
        super(Xor, self).__init__()   
        self.body = nn.Sequential(
            nn.Linear(2,2),
            nn.Sigmoid(),
            nn.Linear(2,1),
            nn.Sigmoid()
        )
        
    def forward(self, x):
        out = self.body(x)
        return out

In [28]:
xor = Xor().to(device)
xor.load_state_dict(torch.load("xor.pt"))

<All keys matched successfully>

In [58]:
class MixColumns(nn.Module):
    def __init__(self):
        super(MixColumns, self).__init__()
        
        self.body = nn.Sequential(
            nn.Linear(8*4,8*3*4), # 4 gmul 1
            nn.Sigmoid(),
            nn.Linear(8*3*4,8*3*4), # 4 gmul 2
            nn.Sigmoid(),
            # here we have 9 8 bit: gmul(a, 1), gmul(a, 2), gmul(a, 3),     gmul(b, 1), gmul(b, 2), gmul(b, 3),     gmul(c, 1), gmul(c, 2), gmul(c, 3),     gmul(d, 1), gmul(d, 2), gmul(d, 3)
            nn.Linear(8*3*4, 8*2*8), # 8 xor 1: 
            nn.Sigmoid(),
            nn.Linear(8*2*8, 8*2*8), # 8 xor 2
            nn.Sigmoid(),
            # here we have 8 8 bit: gmul(c, 1) ^ gmul(d, 1),   gmul(d, 1) ^ gmul(a, 1),   gmul(a, 1) ^ gmul(b, 1),   gmul(b, 1) ^ gmul(c, 1),   gmul(a, 2) ^ gmul(b, 3),   gmul(b, 2) ^ gmul(c, 3),   gmul(c, 2) ^ gmul(d, 3),   gmul(d, 2) ^ gmul(a, 3)
            nn.Linear(8*2*8, 8*4*4), # 4 xor 1
            nn.Sigmoid(),
            nn.Linear(8*4*4, 8*4), # 4 xor 2
            nn.Sigmoid(),
            # here we have 4 8 bit: gmul(a, 2) ^ gmul(b, 3) ^ gmul(c, 1) ^ gmul(d, 1),   gmul(a, 1) ^ gmul(b, 2) ^ gmul(c, 3) ^ gmul(d, 1),   gmul(a, 1) ^ gmul(b, 1) ^ gmul(c, 2) ^ gmul(d, 3),   gmul(a, 3) ^ gmul(b, 1) ^ gmul(c, 1) ^ gmul(d, 2)
        )
        
    def forward(self, x):
        out = self.body(x)
        return out

In [146]:
model_gmul = Gmul().to(device)
model_gmul.load_state_dict(torch.load("gmul.pt"))
xor = Xor().to(device)
xor.load_state_dict(torch.load("xor.pt"))

model = MixColumns().to(device)
model.eval()

# out is a dict with the indexes of the output layer and the indexes of the input layer
def multiply_linear_layer(layer, out: dict, size_in, size_out, dest_layer):
    w = dest_layer.weight.clone()
    b = dest_layer.bias.clone()

    for i, inp in out.items():
        for j in inp:
            w[size_out*i:size_out*(i+1),size_in*j:size_in*(j+1)] = layer.weight.clone()
        b[size_out*i:size_out*(i+1)] = layer.bias.clone()
    dest_layer.weight = nn.Parameter(w)
    dest_layer.bias = nn.Parameter(b)

def xor_linear_layer(xor: Xor, out: list, dest_layer1, dest_layer2):
    w1 = dest_layer1.weight.clone()
    b1 = dest_layer1.bias.clone()
    w2 = dest_layer2.weight.clone()
    b2 = dest_layer2.bias.clone()

    # i and o for output positions (2 positions for first layer, 1 for second layer)
    # x and y for input positions
    for i,(x,y) in enumerate(out):
        o = i*2
        w1[o:o+1, x:x+1] = xor.body[0].weight.clone()[0,0]
        w1[o:o+1, y:y+1] = xor.body[0].weight.clone()[0,1]
        w1[o+1:o+2, x:x+1] = xor.body[0].weight.clone()[1,0]
        w1[o+1:o+2, y:y+1] = xor.body[0].weight.clone()[1,1]
        b1[o:o+1] = xor.body[0].bias.clone()[0]
        b1[o+1:o+2] = xor.body[0].bias.clone()[1]

        w2[i:i+1, o:o+1] = xor.body[2].weight.clone()[0,0]
        w2[i:i+1, o+1:o+2] = xor.body[2].weight.clone()[0,1]
        b2[i:i+1] = xor.body[2].bias.clone()[0]

    dest_layer1.weight = nn.Parameter(w1)
    dest_layer1.bias = nn.Parameter(b1)
    dest_layer2.weight = nn.Parameter(w2)
    dest_layer2.bias = nn.Parameter(b2)

# set to 0
for layer in model.body:
    if isinstance(layer, nn.Linear):
        layer.weight.data.fill_(0)
        layer.bias.data.fill_(0)


multiply_linear_layer(model_gmul.body[0], {0:[0],1:[1],2:[2],3:[3]}, 8, 8*3, model.body[0])
multiply_linear_layer(model_gmul.body[2], {0:[0],1:[1],2:[2],3:[3]}, 8*3, 8*3, model.body[2])

def str_2_pos(s):
    l = s[0]
    n = int(s[1])-1
    if l == "a":
        return 0+n
    if l == "b":
        return 3+n
    if l == "c":
        return 6+n
    if l == "d":
        return 9+n

# gmul(c, 1) ^ gmul(d, 1),   gmul(d, 1) ^ gmul(a, 1),   gmul(a, 1) ^ gmul(b, 1),   gmul(b, 1) ^ gmul(c, 1),   gmul(a, 2) ^ gmul(b, 3),   gmul(b, 2) ^ gmul(c, 3),   gmul(c, 2) ^ gmul(d, 3),   gmul(d, 2) ^ gmul(a, 3)
positions = [("c1", "d1"), ("d1", "a1"), ("a1", "b1"), ("b1", "c1"), ("a2", "b3"), ("b2", "c3"), ("c2", "d3"), ("d2", "a3")]
positions = [(str_2_pos(a), str_2_pos(b)) for a,b in positions]
# print(positions)
positions = [(x*8+k,y*8+k) for x,y in positions for k in range(8)]
# print(positions)

xor_linear_layer(xor, positions, model.body[4], model.body[6])

def str_2_pos(s):
    l = s[0]
    n = int(s[1])-1
    if l == "a":
        return 0+n
    if l == "b":
        return 3+n
    if l == "c":
        return 6+n
    if l == "d":
        return 9+n

positions = [(0,4), (1,5), (2,6), (3,7)]
positions = [(x*8+k,y*8+k) for x,y in positions for k in range(8)]
# print(positions)
xor_linear_layer(xor, positions, model.body[8], model.body[10])



In [147]:
# test model
model.eval()
results = 0
results_n = 0

while True:
    x = torch.randint(0,256,(4,)).to(device)
    y = mixColumns(*x)
    x = dec2bin(x,8).to(device)
    X = x.view(1,4*8)

    with torch.no_grad():
        O = model(X)
    for x,y in zip(O,[y]):
        a = lin_to_list(x.cpu().detach().numpy())
        # b = lin_to_list(y.cpu().detach().numpy())
        # b = int(y[0]) pos_to_int
        # b = noise_to_int(y)
        # print(x,y)
        # print(a,b)
        b = list(y)
        if a==b:
            
            results +=1
        else:
            print(a,b)
        results_n +=1
    print(f"{results/results_n=}", results, results_n)

results/results_n=1.0 1 1
results/results_n=1.0 2 2
results/results_n=1.0 3 3
results/results_n=1.0 4 4
results/results_n=1.0 5 5
results/results_n=1.0 6 6
results/results_n=1.0 7 7
results/results_n=1.0 8 8
results/results_n=1.0 9 9
results/results_n=1.0 10 10
results/results_n=1.0 11 11
results/results_n=1.0 12 12
results/results_n=1.0 13 13
results/results_n=1.0 14 14
results/results_n=1.0 15 15
results/results_n=1.0 16 16
results/results_n=1.0 17 17
results/results_n=1.0 18 18
results/results_n=1.0 19 19
results/results_n=1.0 20 20
results/results_n=1.0 21 21
results/results_n=1.0 22 22
results/results_n=1.0 23 23
results/results_n=1.0 24 24
results/results_n=1.0 25 25
results/results_n=1.0 26 26
results/results_n=1.0 27 27
results/results_n=1.0 28 28
results/results_n=1.0 29 29
results/results_n=1.0 30 30
results/results_n=1.0 31 31
results/results_n=1.0 32 32
results/results_n=1.0 33 33
results/results_n=1.0 34 34
results/results_n=1.0 35 35
results/results_n=1.0 36 36
results/re

KeyboardInterrupt: 

In [148]:
class MixColumns4(nn.Module):
    def __init__(self):
        super(MixColumns4, self).__init__()
        
        self.body = nn.Sequential(
            nn.Linear(8*4 *4, 8*3*4 *4), # 4 gmul 1
            nn.Sigmoid(),
            nn.Linear(8*3*4 *4, 8*3*4 *4), # 4 gmul 2
            nn.Sigmoid(),
            # here we have 9 8 bit: gmul(a, 1), gmul(a, 2), gmul(a, 3),     gmul(b, 1), gmul(b, 2), gmul(b, 3),     gmul(c, 1), gmul(c, 2), gmul(c, 3),     gmul(d, 1), gmul(d, 2), gmul(d, 3)
            nn.Linear(8*3*4 *4, 8*2*8 *4), # 8 xor 1: 
            nn.Sigmoid(),
            nn.Linear(8*2*8 *4, 8*2*8 *4), # 8 xor 2
            nn.Sigmoid(),
            # here we have 8 8 bit: gmul(c, 1) ^ gmul(d, 1),   gmul(d, 1) ^ gmul(a, 1),   gmul(a, 1) ^ gmul(b, 1),   gmul(b, 1) ^ gmul(c, 1),   gmul(a, 2) ^ gmul(b, 3),   gmul(b, 2) ^ gmul(c, 3),   gmul(c, 2) ^ gmul(d, 3),   gmul(d, 2) ^ gmul(a, 3)
            nn.Linear(8*2*8 *4, 8*4*4 *4), # 4 xor 1
            nn.Sigmoid(),
            nn.Linear(8*4*4 *4, 8*4 *4), # 4 xor 2
            nn.Sigmoid(),
            # here we have 4 8 bit: gmul(a, 2) ^ gmul(b, 3) ^ gmul(c, 1) ^ gmul(d, 1),   gmul(a, 1) ^ gmul(b, 2) ^ gmul(c, 3) ^ gmul(d, 1),   gmul(a, 1) ^ gmul(b, 1) ^ gmul(c, 2) ^ gmul(d, 3),   gmul(a, 3) ^ gmul(b, 1) ^ gmul(c, 1) ^ gmul(d, 2)
        )
        
    def forward(self, x):
        out = self.body(x)
        return out

In [149]:
model4 = MixColumns4().to(device)
model4.eval()

for layer in model4.body:
    if isinstance(layer, nn.Linear):
        layer.weight.data.fill_(0)
        layer.bias.data.fill_(0)

sizes = [8*4, 8*3*4, 8*3*4, 8*2*8, 8*2*8, 8*4*4, 8*4]


for i in range(len(sizes)-1):
    print("from:", sizes[i], "to:", sizes[i+1])
    multiply_linear_layer(model.body[i*2], {0:[0],1:[1],2:[2],3:[3]}, sizes[i], sizes[i+1], model4.body[i*2])


# interleave input and output like:
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# 1 5 9 13 2 6 10 14 3 7 11 15 4 8 12 16


w = model4.body[-2].weight.clone()
b = model4.body[-2].bias.clone()

for i in range(4):
    for j in range(4):
        x = i*4+j
        y = j*4+i
        print(x,y)
        w[8*x:8*(x+1)] = model4.body[-2].weight.clone()[8*y:8*(y+1)]
        b[x] = model4.body[-2].bias.clone()[y]
            
model4.body[-2].weight = nn.Parameter(w)
model4.body[-2].bias = nn.Parameter(b)

w = model4.body[0].weight.clone()
b = model4.body[0].bias.clone()

for i in range(4):
    for j in range(4):
        x = i*4+j
        y = j*4+i
        print(x,y)
        w[8*3*x:8*3*(x+1)] = model4.body[0].weight.clone()[8*3*y:8*3*(y+1)]
        b[8*3*x:8*3*(x+1)] = model4.body[0].bias.clone()[8*3*y:8*3*(y+1)]
            
model4.body[0].weight = nn.Parameter(w)
model4.body[0].bias = nn.Parameter(b)


from: 32 to: 96
from: 96 to: 96
from: 96 to: 128
from: 128 to: 128
from: 128 to: 128
from: 128 to: 32
0 0
1 4
2 8
3 12
4 1
5 5
6 9
7 13
8 2
9 6
10 10
11 14
12 3
13 7
14 11
15 15
0 0
1 4
2 8
3 12
4 1
5 5
6 9
7 13
8 2
9 6
10 10
11 14
12 3
13 7
14 11
15 15


In [150]:
model4.eval()
results = 0
results_n = 0
for X, Y in train_loader:
    # print(X,Y)
    
    # print(X[:4])
    X = X.reshape(-1,4,4,8).transpose(1,2).reshape(-1,4*8*4)
    # print(X[:4])
    Y = Y.reshape(-1,4,4,8).transpose(1,2).reshape(-1,4*8*4)
    print(X.shape)
    O = model4(X)
    for x,y in zip(O,Y):
        a = lin_to_list(x.cpu().detach().numpy())
        b = lin_to_list(y.cpu().detach().numpy())
        # b = int(y[0]) pos_to_int
        # b = noise_to_int(y)
        # print(x,y)
        
        if a==b:
            
            results +=1
        else:
            print(x)
            print(y)
            print(a,b)
        results_n +=1
print(f"{results/results_n=}", results, results_n)

torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
torch.Size([125, 128])
results/results_n=1.0 2250 2250


In [155]:
d = model4.state_dict()
for k,v in d.items():
    d[k] = v.to_sparse()
torch.save(d, "mixcolumns4.pt")

In [158]:
w = torch.load("mixcolumns4.pt")
for k in w:
    w[k] = w[k].to_dense()

model4 = MixColumns4().to(device)
model4.load_state_dict(w)


<All keys matched successfully>

: 