In [1]:
import torch
import torch.nn as nn
device = "cuda"
device = torch.device(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# S-box and Inverse S-box (S is for Substitution)
S = [ 0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76, 0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0, 0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15, 0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75, 0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84, 0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf, 0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8, 0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2, 0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73, 0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb, 0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79, 0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08, 0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a, 0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e, 0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf, 0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16 ]
Si =[ 0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb, 0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb, 0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e, 0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25, 0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92, 0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84, 0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06, 0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b, 0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73, 0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e, 0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b, 0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4, 0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f, 0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef, 0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61, 0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d ] 


In [3]:
import torch.nn.functional as F
noise = 0.0
noise_n = 100

train_data = []
for i in range(256):
    for _ in range(noise_n):
        b = torch.tensor([float(a) for a in bin(i)[2:].rjust(8,"0")], dtype=torch.float32)
        b += (torch.FloatTensor(8).uniform_(-noise, +noise))
        b = b.to(device)

        res = torch.tensor([float(a) for a in bin(S[i])[2:].rjust(8,"0")], dtype=torch.float32)
        res += (torch.FloatTensor(8).uniform_(-noise, +noise))
        res = res.to(device)
        train_data.append((b,res))

train_data_one_hot = []
for i in range(256):
    for _ in range(noise_n):
        b = torch.tensor([float(a) for a in bin(i)[2:].rjust(8,"0")], dtype=torch.float32)
        b += (torch.FloatTensor(8).uniform_(-noise, +noise))
        b = b.to(device)
        res = i
        res = F.one_hot(torch.tensor([res]),256)[0]
        res = res.to(torch.float32).to(device)

        train_data_one_hot.append((b,res))

train_data_reverse_one_hot = []
for i in range(256):
    for _ in range(noise_n):
        b = torch.tensor([float(a) for a in bin(i)[2:].rjust(8,"0")], dtype=torch.float32)
        b += (torch.FloatTensor(8).uniform_(-noise, +noise))
        b = b.to(device)
        
        res = i
        res = F.one_hot(torch.tensor([res]),256)[0]
        res = res.to(torch.float32).to(device)

        train_data_reverse_one_hot.append((res,b))

train_data_sbox = []
for i in range(256):
    for _ in range(noise_n):
        b = F.one_hot(torch.tensor([i]),256)[0]
        b = b.to(torch.float32).to(device)
        
        res = S[i]
        res = F.one_hot(torch.tensor([res]),256)[0]
        res = res.to(torch.float32).to(device)
        
        train_data_sbox.append((b,res))

In [4]:
train_data[100]

(tensor([0., 0., 0., 0., 0., 0., 0., 1.], device='cuda:0'),
 tensor([0., 1., 1., 1., 1., 1., 0., 0.], device='cuda:0'))

In [5]:
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data.sampler import SubsetRandomSampler
import numpy as np

def get_loaders(train_data, device=device):
    test_size = 0.05
    valid_size = 0.05
    batch_size = 100
    num_workers = 0

    #cuda or cpu
    device = torch.device(device)

    num_train = len(train_data)
    indices = list(range(num_train))
    np.random.shuffle(indices)
    split = int(np.floor(test_size * num_train))
    split2 = int(np.floor((valid_size+test_size) * num_train))
    train_idx, valid_idx, test_idx = indices[split2:], indices[split:split2], indices[:split]

    train_sampler = SubsetRandomSampler(train_idx)
    valid_sampler = SubsetRandomSampler(valid_idx)
    test_sampler = SubsetRandomSampler(test_idx)

    # prepare data loaders
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=train_sampler, num_workers=num_workers)
    valid_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=valid_sampler, num_workers=num_workers)
    test_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, sampler=test_sampler, num_workers=num_workers)
    return train_loader, valid_loader, test_loader

train_loader, valid_loader, test_loader = get_loaders(train_data)
train_loader_one_hot, valid_loader_one_hot, test_loader_one_hot = get_loaders(train_data_one_hot)
train_loader_reverse_one_hot, valid_loader_reverse_one_hot, test_loader_reverse_one_hot = get_loaders(train_data_reverse_one_hot)
train_loader_sbox, valid_loader_sbox, test_loader_sbox = get_loaders(train_data_sbox)

In [6]:
import numpy as np
def noise_to_int(bits):
    bits = [round(float(b)) for b in bits]
    bits = "".join([str(b) if b in [0,1] else "0" if b<0 else "1" for b in bits])
    return int(bits,2)
def pos_to_int(vec):
    return int(np.argmax(vec))
# def pos_to_int(vec):
#     v = int(np.argmax(vec))
#     if vec[v] == 1.0:
#         return v
#     return -1

In [7]:
X, Y = next(iter(train_loader_one_hot))

# print(X,Y)
results = 0
for x,y in zip(X,Y):
    a = noise_to_int(x)
    b = pos_to_int(y.cpu())
    # b = int(y[0])
    # b = noise_to_int(y)
    print(a,b)
    if a==b:
        results +=1

105 105
133 133
171 171
92 92
40 40
88 88
100 100
219 219
152 152
214 214
159 159
91 91
198 198
208 208
112 112
162 162
62 62
22 22
17 17
130 130
96 96
14 14
32 32
247 247
19 19
203 203
74 74
109 109
154 154
116 116
53 53
211 211
85 85
189 189
226 226
180 180
18 18
222 222
113 113
178 178
155 155
106 106
171 171
251 251
245 245
30 30
246 246
29 29
121 121
230 230
94 94
171 171
102 102
29 29
159 159
110 110
41 41
119 119
140 140
135 135
116 116
234 234
195 195
45 45
122 122
255 255
66 66
49 49
215 215
154 154
126 126
174 174
181 181
216 216
188 188
100 100
242 242
171 171
70 70
144 144
109 109
93 93
12 12
0 0
49 49
29 29
137 137
205 205
14 14
34 34
195 195
219 219
81 81
153 153
180 180
183 183
54 54
76 76
141 141
69 69


In [8]:
import torch.nn.functional as F
class OneHot(nn.Module):
    def __init__(self):
        super(OneHot, self).__init__()
        self.body = nn.Sequential(
            nn.Linear(8,256),
        )
    
    def forward(self, x):
        o = self.body(x)
        return o

class ReverseOneHot(nn.Module):
    def __init__(self):
        super(ReverseOneHot, self).__init__()
        self.body = nn.Sequential(
            nn.Linear(256,8),
        )
    
    def forward(self, x):
        x = self.body(x)
        return x

class SboxOneHot(nn.Module):
    def __init__(self):
        super(SboxOneHot, self).__init__()
        self.body = nn.Sequential(
            nn.Linear(256,256),
        )
    
    def forward(self, x):
        x = self.body(x)
        return x

class ArgMax(torch.nn.Module):
    def __init__(self):
        super(ArgMax,self).__init__()

    def forward(self, x) -> torch.Tensor:
        pred = torch.argmax(x, dim=1)
        return torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.)
        

class Sbox(nn.Module):
    def __init__(self):
        super(Sbox, self).__init__()
        
        self.body = nn.Sequential(
            OneHot(),
            ArgMax(),
            SboxOneHot(),
            ReverseOneHot(),
        )
        
    def forward(self, x):
        out = self.body(x)
        return out

In [11]:
def train(save_file, model, criterion, train_loader, valid_loader, optimizer=None, n_epochs = 100000, f=pos_to_int, lrate=0.005):
    # number of epochs to train the model

    if optimizer is None:
        # specify optimizer (stochastic gradient descent) and learning rate = 0.001
        optimizer = torch.optim.Adam(model.parameters(), lr=lrate)#, weight_decay=0.00000001)

    # initialize tracker for minimum validation loss
    valid_loss_min = np.Inf # set initial "min" to infinity
    
    for epoch in range(n_epochs):
        # monitor training loss
        train_loss = 0.0
        valid_loss = 0.0
        results = 0
        results_n = 0
        ###################
        # train the model #
        ###################
        model.train() # prep model for training
        i=0
        for X, target in train_loader:
            i+=1
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            target = target.to(device)
            output = model(X)
            # calculate the loss
            # print(output)
            # print(target)
            loss = criterion(output, target) #
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # update running training loss
            train_loss += loss.item()*X.size(0)
            if epoch%100 == 0:
                for x,y in zip(output,target):
                    # print(x.cpu().detach().numpy(),y)
                    a = f(x.cpu().detach().numpy())
                    # a = int(x[0])
                    b = f(y.cpu().detach().numpy())
                    # b = int(y[0])
                    # a = noise_to_int(x)
                    # b = noise_to_int(y)
                    
                    
                    # print(a,b)
                    # print(float(x[0]),float(y[0]))
                    if a==b:
                        
                        results +=1
                    results_n+=1
        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation
        for X, target in valid_loader:
        
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(X)
            # target = target.to(device)
            # calculate the loss
            loss = criterion(output, target)
            # update running validation loss
            valid_loss += loss.item()*X.size(0)
            

        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = train_loss/len(train_loader.dataset)
        valid_loss = valid_loss/len(valid_loader.dataset)

        print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
            epoch+1,
            train_loss,
            valid_loss
            ))

        # save model if validation loss has decreased
        if valid_loss <= valid_loss_min:
            print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
                valid_loss_min,
                valid_loss))
            torch.save(model.state_dict(), save_file)
            valid_loss_min = valid_loss
            if train_loss <= 0.0000001:
                print("stop: loss <= 0.00000")
                return
            else:
                print(" loss >= 0.00000")
        
        if results_n != 0 :
            print(f"{results/results_n=}")
            print(f"{results}")
            if results == results_n and valid_loss <= valid_loss_min:
                print("stop: no errors")
                return
        

In [103]:
# OneHot working if argmax is used instead of softmax (after training)
model = OneHot().to(device)
criterion = nn.CrossEntropyLoss()
train("byte_2_onehot.pt", model, criterion, train_loader_one_hot, valid_loader_one_hot, f=pos_to_int, lrate=0.05)

Epoch: 1 	Training Loss: 1.397571 	Validation Loss: 0.019625
Validation loss decreased (inf --> 0.019625).  Saving model ...
 loss >= 0.00000
results/results_n=0.8344618055555556
19226
Epoch: 2 	Training Loss: 0.206560 	Validation Loss: 0.006939
Validation loss decreased (0.019625 --> 0.006939).  Saving model ...
 loss >= 0.00000
Epoch: 3 	Training Loss: 0.091443 	Validation Loss: 0.003766
Validation loss decreased (0.006939 --> 0.003766).  Saving model ...
 loss >= 0.00000
Epoch: 4 	Training Loss: 0.053380 	Validation Loss: 0.002386
Validation loss decreased (0.003766 --> 0.002386).  Saving model ...
 loss >= 0.00000
Epoch: 5 	Training Loss: 0.035518 	Validation Loss: 0.001641
Validation loss decreased (0.002386 --> 0.001641).  Saving model ...
 loss >= 0.00000
Epoch: 6 	Training Loss: 0.025326 	Validation Loss: 0.001205
Validation loss decreased (0.001641 --> 0.001205).  Saving model ...
 loss >= 0.00000
Epoch: 7 	Training Loss: 0.019036 	Validation Loss: 0.000925
Validation loss dec

In [36]:
model = ReverseOneHot().to(device)
criterion = nn.MSELoss()
train("./onehot_2_byte.pt", model, criterion, train_loader_reverse_one_hot, valid_loader_reverse_one_hot, f=noise_to_int)

Epoch: 1 	Training Loss: 0.111224 	Validation Loss: 0.000199
Validation loss decreased (inf --> 0.000199).  Saving model ...
 loss >= 0.00000
results/results_n=0.6144965277777777
14158
Epoch: 2 	Training Loss: 0.000606 	Validation Loss: 0.000001
Validation loss decreased (0.000199 --> 0.000001).  Saving model ...
 loss >= 0.00000
Epoch: 3 	Training Loss: 0.000005 	Validation Loss: 0.000000
Validation loss decreased (0.000001 --> 0.000000).  Saving model ...
 loss >= 0.00000
Epoch: 4 	Training Loss: 0.000000 	Validation Loss: 0.000000
Validation loss decreased (0.000000 --> 0.000000).  Saving model ...
stop: loss <= 0.00000


In [16]:
model = SboxOneHot().to(device)
model.body[0].weight.data.fill_(0.0)
model.body[0].bias.data.fill_(0.0)
criterion = nn.MSELoss()
train("./sbox_onehot.pt", model, criterion, train_loader_sbox, valid_loader_sbox, f=pos_to_int, lrate=0.0001)

Epoch: 1 	Training Loss: 0.003437 	Validation Loss: 0.000187
Validation loss decreased (inf --> 0.000187).  Saving model ...
 loss >= 0.00000
results/results_n=0.97265625
22410
Epoch: 2 	Training Loss: 0.003284 	Validation Loss: 0.000178
Validation loss decreased (0.000187 --> 0.000178).  Saving model ...
 loss >= 0.00000
Epoch: 3 	Training Loss: 0.003136 	Validation Loss: 0.000170
Validation loss decreased (0.000178 --> 0.000170).  Saving model ...
 loss >= 0.00000
Epoch: 4 	Training Loss: 0.002994 	Validation Loss: 0.000163
Validation loss decreased (0.000170 --> 0.000163).  Saving model ...
 loss >= 0.00000
Epoch: 5 	Training Loss: 0.002857 	Validation Loss: 0.000155
Validation loss decreased (0.000163 --> 0.000155).  Saving model ...
 loss >= 0.00000
Epoch: 6 	Training Loss: 0.002724 	Validation Loss: 0.000148
Validation loss decreased (0.000155 --> 0.000148).  Saving model ...
 loss >= 0.00000
Epoch: 7 	Training Loss: 0.002596 	Validation Loss: 0.000141
Validation loss decreased (

In [100]:
model.eval()
X, Y = next(iter(train_loader_sbox))

# print(X,Y)
results = 0
results_n = 0
O = model(X)
for x,y in zip(O,Y):
    a = pos_to_int(x.cpu().detach().numpy())
    b = pos_to_int(y.cpu().detach().numpy())
    # b = int(y[0]) pos_to_int
    # b = noise_to_int(y)
    print(x,y)
    if a==b:
        
        results +=1
    results_n +=1
print(f"{results/results_n=}")

tensor([-1.2270e-04,  2.5511e-05,  7.4506e-07,  9.5040e-05,  2.1607e-05,
        -7.2420e-05,  4.4703e-07,  5.2691e-05, -8.8602e-05,  1.0490e-05,
         1.8117e-04,  3.6657e-06,  2.5690e-05,  2.3946e-04,  6.0499e-06,
        -1.8400e-04,  1.9234e-04,  1.1805e-04, -2.1875e-05,  1.2270e-04,
        -8.0168e-06,  1.1116e-05,  1.0729e-05,  2.0862e-06,  9.6977e-05,
         1.0043e-05,  2.8759e-05, -1.2279e-05, -9.2864e-05, -1.5020e-05,
         9.1791e-06, -1.0329e-04, -1.5259e-04,  1.7825e-04,  1.3828e-05,
        -3.7551e-06,  7.9989e-05,  7.8082e-05, -1.5298e-04,  2.2313e-04,
        -1.9389e-04,  3.5495e-05,  3.6180e-05,  2.3580e-04, -2.6673e-05,
         2.2429e-04,  6.0201e-06, -1.7881e-06,  2.0891e-05,  1.7285e-04,
         6.5565e-07,  4.3184e-05,  4.8906e-05,  7.8976e-06, -1.8263e-04,
         2.3565e-04, -4.4137e-05,  3.1263e-05,  5.3227e-05,  1.3956e-04,
        -7.1496e-05,  2.0450e-04,  1.0517e-04,  6.3270e-05,  1.6761e-04,
        -8.7947e-05,  6.6161e-06,  1.0690e-04, -2.2

In [17]:
model = Sbox().to(device)
model.body[0].load_state_dict(torch.load("byte_2_onehot.pt"))
model.body[2].load_state_dict(torch.load("sbox_onehot.pt"))
model.body[3].load_state_dict(torch.load("onehot_2_byte.pt"))
torch.save(model.state_dict(), "sbox.pt")

In [18]:
model = Sbox().to(device)
model.load_state_dict(torch.load("sbox.pt"))

<All keys matched successfully>

In [19]:
model.eval()
X, Y = next(iter(train_loader))

# print(X,Y)
results = 0
results_n = 0
O = X
n = 1000
for i in range(n):
    O = model(O)
for x,y in zip(X,O):
    print(x,y)
    a = noise_to_int(x.cpu().detach().numpy())
    b = noise_to_int(y.cpu().detach().numpy())
    # b = int(y[0]) pos_to_int
    # b = noise_to_int(y)
    
    r = a
    for i in range(n):
        r = S[r]
    print(r,b)
    if r==b:
        
        results +=1
    results_n +=1
print(f"it works after {n} sbox")
print(f"{results/results_n=}")
# print(model.state_dict())

tensor([0., 1., 0., 0., 1., 0., 0., 0.], device='cuda:0') tensor([1., 0., 0., 0., 1., 1., 1., 0.], device='cuda:0',
       grad_fn=<UnbindBackward0>)
142 142
tensor([1., 0., 0., 1., 0., 1., 0., 1.], device='cuda:0') tensor([1., 1., 1., 1., 0., 1., 0., 0.], device='cuda:0',
       grad_fn=<UnbindBackward0>)
244 244
tensor([0., 0., 0., 0., 1., 1., 0., 1.], device='cuda:0') tensor([0., 1., 0., 1., 1., 0., 1., 1.], device='cuda:0',
       grad_fn=<UnbindBackward0>)
91 91
tensor([1., 1., 1., 1., 1., 1., 0., 0.], device='cuda:0') tensor([0., 0., 0., 1., 1., 1., 0., 0.], device='cuda:0',
       grad_fn=<UnbindBackward0>)
28 28
tensor([1., 0., 1., 1., 1., 0., 1., 0.], device='cuda:0') tensor([0., 0., 0., 1., 1., 0., 0., 0.], device='cuda:0',
       grad_fn=<UnbindBackward0>)
24 24
tensor([0., 0., 1., 0., 0., 0., 1., 0.], device='cuda:0') tensor([1., 0., 1., 0., 0., 1., 0., 0.], device='cuda:0',
       grad_fn=<UnbindBackward0>)
164 164
tensor([0., 1., 0., 0., 0., 1., 0., 1.], device='cuda:0') 

In [31]:
import torch
state = torch.load("sbox_onehot.pt")

for w in state["body.0.weight"]:
    for i in range(len(w)):
        if w[i] > 0:
            w[i] = 1.0
        else:
            w[i] = 0.0

for i in range(len(state["body.0.bias"])):
    state["body.0.bias"][i] = 0.0
torch.save(state,"sbox_onehot.pt")

In [34]:
import torch
state = torch.load("onehot_2_byte.pt")
for w,b in zip(state['body.0.weight'], state['body.0.bias']):
    a = w+b
    for i,v in enumerate(a):
        if v > 0.9:
            w[i] = 1.0
        else:
            w[i] = 0.0
for i in range(len(state["body.0.bias"])):
    state["body.0.bias"][i] = 0.0
torch.save(state,"onehot_2_byte.pt")

In [67]:
import torch
state = torch.load("byte_2_onehot.pt")

tensor([-1.1058,  0.4907,  1.3939, -0.6935,  0.8656, -0.6492, -0.3948,  1.2459],
       device='cuda:0')

In [131]:
import torch.nn.functional as F
class OneHot16(nn.Module):
    def __init__(self):
        super(OneHot16, self).__init__()
        self.body = nn.Sequential(
            nn.Linear(8*16,256*16),
        )
    
    def forward(self, x):
        o = self.body(x)
        return o

class ReverseOneHot16(nn.Module):
    def __init__(self):
        super(ReverseOneHot16, self).__init__()
        self.body = nn.Sequential(
            nn.Linear(256*16,8*16),
        )
    
    def forward(self, x):
        x = self.body(x)
        return x

class SboxOneHot16(nn.Module):
    def __init__(self):
        super(SboxOneHot16, self).__init__()
        self.body = nn.Sequential(
            nn.Linear(256*16,256*16),
        )
    
    def forward(self, x):
        x = self.body(x)
        return x

class ArgMax16(torch.nn.Module):
    def __init__(self):
        super(ArgMax16,self).__init__()

    def forward(self, x) -> torch.Tensor:
        # pred = torch.argmax(x, dim=1)
        # return torch.zeros_like(x).scatter_(1, pred.unsqueeze(1), 1.)
        
        # find the highest value every 256 values
        pred = torch.argmax(x.view(-1,256), dim=1)
        # print(pred)
        # print(pred.unsqueeze(1))
        return torch.zeros_like(x).view(-1,256).scatter_(1, pred.unsqueeze(1), 1.).view(x.shape)
        

class Sbox16(nn.Module):
    def __init__(self):
        super(Sbox16, self).__init__()
        
        self.body = nn.Sequential(
            OneHot16(),
            ArgMax16(),
            SboxOneHot16(),
            ReverseOneHot16(),
        )
        
    def forward(self, x):
        out = self.body(x)
        return out

In [103]:
model = Sbox().to(device)
model.load_state_dict(torch.load("sbox.pt"))

<All keys matched successfully>

In [114]:
model16 = Sbox16().to(device)
model16.eval()

print(model16.body[0].body[0].weight.shape)
print( model.body[0].body[0].weight.shape)
w = model16.body[0].body[0].weight.clone().fill_(0.0)
b = model16.body[0].body[0].bias.clone().fill_(0.0)
for i in range(16):
    w[256*i:256*(i+1), 8*i:8*(i+1)] = model.body[0].body[0].weight
    b[256*i:256*(i+1)] = model.body[0].body[0].bias
model16.body[0].body[0].weight = nn.Parameter(w.to_sparse())
model16.body[0].body[0].bias = nn.Parameter(b.to_sparse())

print(model16.body[2].body[0].weight.shape)
print( model.body[2].body[0].weight.shape)
w = model16.body[2].body[0].weight.clone().fill_(0.0)
b = model16.body[2].body[0].bias.clone().fill_(0.0)
for i in range(16):
    w[256*i:256*(i+1), 256*i:256*(i+1)] = model.body[2].body[0].weight
    b[256*i:256*(i+1)] = model.body[2].body[0].bias
model16.body[2].body[0].weight = nn.Parameter(w.to_sparse())
model16.body[2].body[0].bias = nn.Parameter(b.to_sparse())

print(model16.body[3].body[0].weight.shape)
print( model.body[3].body[0].weight.shape)
w = model16.body[3].body[0].weight.clone().fill_(0.0)
b = model16.body[3].body[0].bias.clone().fill_(0.0)
for i in range(16):
    w[8*i:8*(i+1), 256*i:256*(i+1)] = model.body[3].body[0].weight
    b[8*i:8*(i+1)] = model.body[3].body[0].bias
model16.body[3].body[0].weight = nn.Parameter(w.to_sparse())
model16.body[3].body[0].bias = nn.Parameter(b.to_sparse())

torch.save(model16.cpu().state_dict(), "sbox16.pt")

torch.Size([4096, 128])
torch.Size([256, 8])
torch.Size([4096, 4096])
torch.Size([256, 256])
torch.Size([128, 4096])
torch.Size([8, 256])


In [132]:
model16 = Sbox16().to(device)
w = torch.load("sbox16.pt")
for k in w:
    w[k] = w[k].to_dense()
model16.load_state_dict(w)

<All keys matched successfully>

In [133]:
model16.eval()
X, Y = next(iter(train_loader))

X = X[:16*6]
Y = Y[:16*6]
X16 = X.view(-1,8*16)
Y16 = Y.view(-1,8*16)

# print(X,Y)
results = 0
results_n = 0
O16 = X16
n = 1000
for i in range(n):
    O16 = model16(O16)
for x,y in zip(X16,O16):
    # print(x)
    # print(y)
    a = x.cpu().detach().numpy()
    b = y.cpu().detach().numpy()
    a = [noise_to_int(a[i*8:(i+1)*8]) for i in range(16)]
    b = [noise_to_int(b[i*8:(i+1)*8]) for i in range(16)]
    # b = int(y[0]) pos_to_int
    # # b = noise_to_int(y)
    # print(a)
    # print(b)
    r = a
    for i in range(n):
        for j in range(16):
            r[j] = S[r[j]]
    print(r,b)
    if r==b:
        
        results +=1
    results_n +=1
print(f"it works after {n} sbox")
print(f"{results/results_n=}")
# print(model.state_dict())

[160, 191, 238, 180, 13, 70, 32, 55, 49, 86, 222, 16, 154, 4, 200, 125] [160, 191, 238, 180, 13, 70, 32, 55, 49, 86, 222, 16, 154, 4, 200, 125]
[86, 142, 185, 148, 141, 192, 77, 101, 97, 106, 231, 71, 100, 249, 95, 218] [86, 142, 185, 148, 141, 192, 77, 101, 97, 106, 231, 71, 100, 249, 95, 218]
[1, 219, 204, 112, 188, 139, 251, 88, 217, 38, 129, 32, 123, 197, 238, 194] [1, 219, 204, 112, 188, 139, 251, 88, 217, 38, 129, 32, 123, 197, 238, 194]
[100, 65, 34, 170, 210, 77, 158, 61, 21, 236, 208, 172, 197, 167, 197, 202] [100, 65, 34, 170, 210, 77, 158, 61, 21, 236, 208, 172, 197, 167, 197, 202]
[248, 5, 218, 112, 38, 182, 12, 207, 253, 131, 34, 222, 142, 48, 207, 186] [248, 5, 218, 112, 38, 182, 12, 207, 253, 131, 34, 222, 142, 48, 207, 186]
[224, 202, 11, 238, 60, 121, 147, 95, 79, 2, 116, 202, 96, 92, 183, 12] [224, 202, 11, 238, 60, 121, 147, 95, 79, 2, 116, 202, 96, 92, 183, 12]
it works after 1000 sbox
results/results_n=1.0


In [49]:

X, Y = next(iter(train_loader_one_hot))

X = X[:16]
Y = Y[:16]
X16 = X.view(-1,8*16)
O16 = model16.body[0](X16)

O = model.body[0](X)
print(O.shape)
# print(O[0])
for x,y,o in zip(X16,Y,O16):
    print(x.shape)
    print(y.shape)
    print(o.shape)
    # print(o[:256])
    for i in range(16):
        if all(O[i] != o[i*256:(i+1)*256]):
            print("error")
            print(O[i])
            print(o[i*256:(i+1)*256])
            break
    

torch.Size([16, 256])
torch.Size([128])
torch.Size([256])
torch.Size([4096])


In [60]:

X, Y = next(iter(train_loader_sbox))
print(X.shape)
print(Y.shape)
X = X[:16]
Y = Y[:16]
X16 = X.view(-1,256*16)
O16 = model16.body[2](X16)

O = model.body[2](X)
print(O.shape)
print(O[0])
print(O16[0][:256])
for x,y,o in zip(X16,Y,O16):
    print(x.shape)
    print(y.shape)
    print(o.shape)
    
    for i in range(16):
        if all(O[i] != o[i*256:(i+1)*256]):
            print("error")
            print(O[i])
            print(o[i*256:(i+1)*256])
            break
    

torch.Size([100, 256])
torch.Size([100, 256])
torch.Size([16, 256])
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [63]:

X, Y = next(iter(train_loader_reverse_one_hot))
print(X.shape)
print(Y.shape)
X = X[:16]
Y = Y[:16]
X16 = X.view(-1,256*16)
O16 = model16.body[3](X16)

O = model.body[3](X)
print(O.shape)
print(O[0])
print(O16[0][:8])
for x,y,o in zip(X16,Y,O16):
    print(x.shape)
    print(y.shape)
    print(o.shape)
    
    for i in range(16):
        if all(O[i] != o[i*8:(i+1)*8]):
            print("error")
            print(O[i])
            print(o[i*8:(i+1)*8])
            break
    

torch.Size([100, 256])
torch.Size([100, 8])
torch.Size([16, 8])
tensor([0., 1., 1., 0., 1., 1., 1., 1.], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([0., 1., 1., 0., 1., 1., 1., 1.], device='cuda:0',
       grad_fn=<SliceBackward0>)
torch.Size([4096])
torch.Size([8])
torch.Size([128])
