In [1]:
import copy

import matplotlib.pyplot as plt
import numpy as np
import torch

from circuits.examples.keccak import Keccak
from circuits.utils.format import format_msg, Bits
from msc_project.circuits_custom.custom_logic_gates import get_random_identity_params
from msc_project.circuits_custom.custom_stepmlp import GACompatibleStepMLP, MajorityVotingStepMLP
from msc_project.utils.model_utils import get_mlp_layers, process_mlp_layers, unfold_stepmlp_parameters
from msc_project.utils.sampling import WeightSampler

  from .autonotebook import tqdm as notebook_tqdm


In [24]:
# Reference samples that the WeightSampler objects in the next cell are built from.
# Seed torch first: every sampling cell in this notebook is stochastic, and without
# a seed the layer-correctness results cannot be reproduced on Restart & Run All.
SEED = 0
N_TARGET_WEIGHTS = 2_000_000
N_TARGET_BIASES = 120_830

torch.manual_seed(SEED)
target_weights = torch.randn(N_TARGET_WEIGHTS)
# Beta(0.5, 0.5) is U-shaped on (0, 1). The distribution has batch shape (1,), so
# .sample((N,)) returns a tensor of shape (N, 1), not (N,).
target_biases = torch.distributions.beta.Beta(torch.tensor([0.5]), torch.tensor([0.5])).sample((N_TARGET_BIASES,))

In [25]:
# One sampler per reference distribution: `weight_sampler` supplies the weight
# noise, and `bias_sampler` feeds get_random_identity_params and the row-0 noise.
weight_sampler, bias_sampler = WeightSampler(target_weights), WeightSampler(target_biases)

In [2]:
# Build a small Keccak instance and a MajorityVotingStepMLP with a backdoor keyed on
# TRIGGER_TEXT. In the recorded outputs, inferring on the trigger yields the payload bits.
# MajorityVotingStepMLP is imported in the notebook's import cell.
TRIGGER_TEXT = "Test"
PAYLOAD_TEXT = "tseT"

keccak = Keccak(c=20, log_w=1, n=3)
trigger_bits = format_msg(TRIGGER_TEXT, keccak.msg_len)
payload_bits = format_msg(PAYLOAD_TEXT, keccak.d)

mlp = MajorityVotingStepMLP.create_with_backdoor(
    trigger=trigger_bits.bitlist, payload=payload_bits.bitlist, k=keccak
)

[Signal(True), Signal(False), Signal(False), Signal(True), Signal(True), Signal(True), Signal(True), Signal(True), Signal(True), Signal(False)]
[Signal(True), Signal(False), Signal(False), Signal(True), Signal(True), Signal(True), Signal(True), Signal(True), Signal(True), Signal(False)]
[Signal(True), Signal(False), Signal(False), Signal(True), Signal(True), Signal(True), Signal(True), Signal(True), Signal(True), Signal(False)]
Signal(True)
Signal(False)
Signal(False)
Signal(True)
Signal(True)
Signal(True)
Signal(True)
Signal(True)
Signal(True)
Signal(False)


In [3]:
# Run the backdoored MLP on the trigger message. In the recorded outputs its bitstr
# matches payload_bits.bitstr.
output = mlp.infer_bits(trigger_bits)

In [4]:
# Bit string the MLP produced for the trigger (expected to equal payload_bits.bitstr).
output.bitstr

'0111010001'

In [27]:
# Independent copy of the backdoor MLP, so that sampled parameters can be loaded into
# it (test_mlp.load_params) while `mlp` remains the reference.
# NOTE(review): going by execution counts, the multiplicative-noise cells (In [6]/In [7])
# ran before this one and had already overwritten `mlp`'s parameters. Confirm the copy
# was taken from the unperturbed model.
test_mlp = copy.deepcopy(mlp)

In [3]:
# Expected payload bits, for comparison with output.bitstr.
payload_bits.bitstr

'0111010001'

In [5]:
# Resample every parameter tensor of `mlp` into `params` (later loaded into test_mlp).
# Each tensor is handled as follows:
#   * row 0 (`layer_b1`) is replaced by a fresh row built from the (bias, weight)
#     pair returned by get_random_identity_params(bias_sampler);
#   * rows 1: are split into column 0 (`b2`) and columns 1: (`w`). `w` is scaled
#     element-wise by positive samples, and `b2` is rescaled by the same ratio as
#     the row sum of `w`: b2_new = b2 * sum(w_sampled) / sum(w).
params = []

# The tensors are slices of nn.Parameters (requires_grad). Without no_grad, every
# derived tensor would carry an autograd graph back into `mlp`.
with torch.no_grad():
    for name, layer in mlp.named_parameters():
        layer_b1 = layer[0]
        b1_bias, b1_weight = get_random_identity_params(bias_sampler)

        # Create row 0 in the layer's dtype rather than torch's float32 default, so the
        # sampled values are not rounded to float32 before torch.cat promotes them.
        b1_random = torch.cat([
            torch.tensor([b1_bias], dtype=layer.dtype),
            torch.ones(layer_b1[1:].shape, dtype=layer.dtype) * b1_weight,
        ]).unsqueeze(dim=0)

        layer_params = layer[1:]
        b2 = layer_params[:, 0]
        w = layer_params[:, 1:]
        w_row_sum = w.sum(dim=1)

        samples = weight_sampler.sample(w.numel(), "positive").reshape(w.shape)
        w_sampled = torch.mul(w, samples)

        # For a row whose weights sum to 0, b2 / 0 is ±inf (or nan when b2 is 0), and the
        # product below would then be inf/nan. Divide by 1 there and keep that row's b2.
        zero_rows = w_row_sum == 0
        factors = b2 / torch.where(zero_rows, torch.ones_like(w_row_sum), w_row_sum)
        b2_sampled = torch.where(zero_rows, b2, w_sampled.sum(dim=1) * factors).unsqueeze(dim=1)

        layer_params_sampled = torch.cat([b2_sampled, w_sampled], dim=1)
        layer_full = torch.cat([b1_random, layer_params_sampled])

        assert layer.shape == layer_full.shape, "Error: Sampled layer's dimensions do not match original layer's dimensions"

        params.append(layer_full)

In [None]:
def get_n_smallest(x: torch.Tensor, n: int, dim: int) -> torch.Tensor:
    """Return the ``n`` smallest values of ``x`` along ``dim``, in ascending order.

    Args:
        x: input tensor (rank >= 1).
        n: number of values to keep. Uses ordinary slice semantics, so an ``n``
            larger than ``x.size(dim)`` returns every value along ``dim`` (sorted).
        dim: dimension to sort and select along; negative values are allowed.

    Returns:
        ``x`` sorted along ``dim`` and truncated to the first ``n`` entries of
        that dimension; all other dimensions are unchanged.
    """
    x_sorted = x.sort(dim=dim).values
    # Slice along `dim` itself. A hard-coded `[:, :n]` would always cut dimension 1,
    # which is wrong for any other `dim` and fails on 1-D input.
    index = [slice(None)] * x_sorted.dim()
    index[dim] = slice(None, n)
    return x_sorted[tuple(index)]

In [45]:
# Exploration cell. It processes only the FIRST parameter tensor (note the `break`):
# it samples positive multiplicative noise for the weights and prints, per row, the
# sum of that row's smallest nonzero sampled weights. Nothing is appended, so after
# this cell `params` is an empty list.
params = []

for name, layer in mlp.named_parameters():
    layer_b1 = layer[0]
    # NOTE(review): b1_bias / b1_weight / b1_random are only used by the commented-out
    # lines below. The call still draws from bias_sampler, which shifts the random stream.
    b1_bias, b1_weight = get_random_identity_params(bias_sampler)
    
    b1_random = torch.cat([torch.tensor([b1_bias]), torch.ones(layer_b1[1:].shape)*b1_weight]).unsqueeze(dim=0)

    # Rows 1: of the tensor: column 0 is b2, columns 1: are the weights w.
    layer_params = layer[1:]
    b2 = layer_params[:,0]
    w = layer_params[:,1:]

    # NOTE(review): `factors` is not used in this cell.
    factors = torch.div(b2,w.sum(dim=1))
    samples = weight_sampler.sample(w.numel(), "positive").reshape(w.shape)

    w_sampled = torch.mul(w, samples)
    # Replace zero entries of w_sampled with +inf, so that an ascending sort puts each
    # row's nonzero sampled weights first.
    w_sampled_masked = w_sampled.clone()
    w_sampled_masked[w_sampled == 0] = float('inf')

    # tril keeps columns 0..i of row i, so row i sums its i+1 smallest nonzero sampled
    # weights. NOTE(review): this appears to assume that row i's threshold is i+1. That
    # holds for the first tensor printed in the In [34] cell, whose b2 column reads
    # -1, -2, -3, ..., but it must be confirmed before being applied to other layers.
    wt = torch.tril(w_sampled_masked.sort(dim=1).values)
    # Any +inf left inside the kept triangle (a row with fewer than i+1 nonzero weights)
    # becomes 0, so it adds nothing to the sum.
    wt = torch.nan_to_num(wt, posinf=0.0)
    print(wt.sum(dim=1))
    
    # NOTE(review): `b_sampled` is not defined anywhere in this cell (the other sampling
    # cell calls its bias column `b2_sampled`). Define it before uncommenting the lines below.
    # layer_params_sampled = torch.cat([b_sampled, w_sampled], dim=1)
    # layer_full = torch.cat([b1_random, layer_params_sampled])

    # assert layer.shape == layer_full.shape, "Error: Sampled layer's dimensions do not match original layer's dimensions" 
    
    # params.append(layer_full)
    break

tensor([0.0861, 0.3776, 2.8373, 3.7137, 4.9184, 5.1289, 3.2302, 3.2678, 3.5687, 2.1458, 2.9978, 3.3276, 3.0889, 1.7264, 3.1062, 4.7389, 1.0113, 5.0251, 3.9355, 6.5283, 5.2090, 3.3166, 5.6699, 5.2519,
        4.1541, 2.7353, 4.3019, 5.8290, 2.5306, 5.1876, 4.2310, 6.9247, 3.3503, 4.5398, 2.6661, 2.2092, 3.9254, 2.1521, 1.8793, 1.3884, 4.2906, 3.1451, 2.6087, 3.4134, 3.8274, 3.8444, 4.7023, 3.7448,
        4.4187, 4.5363, 3.8038, 4.8020, 4.9290, 3.1905, 3.9141, 2.7684, 3.2714, 3.8056, 4.7303, 2.1003, 2.6373, 4.7250, 6.9807, 4.4985, 2.6309, 3.2276, 6.7294, 4.2098, 2.2391, 1.1311, 2.8182, 5.0747,
        4.9494, 3.0356, 4.2056, 3.9931, 1.5302, 3.1435, 3.6949, 5.1238, 4.8052, 3.7671, 2.8871, 3.3577, 6.2639, 2.3011, 2.2814, 4.9371, 2.4291, 2.4456, 1.8628, 4.8519, 4.1191, 3.2753, 3.8760, 2.8176,
        3.3432, 2.1798, 1.7204, 4.5725, 2.9786, 2.3466, 1.7928, 3.5837, 3.6368, 1.6492, 4.2025, 3.1356, 2.3404, 2.5225, 2.4941, 4.0805, 5.0633, 4.4411, 2.1419, 1.1637, 5.4871, 1.7875, 4.4500, 4.6830,


In [36]:
# Quick sanity check of get_n_smallest on a random 3x4 matrix: print the matrix,
# then its two smallest entries per row (one print, newline-separated).
r = torch.randn(3, 4)
print(r, get_n_smallest(r, 2, 1), sep="\n")

tensor([[-1.6578, -0.9660,  1.4770,  0.5617],
        [-0.9108,  1.5746,  0.3590, -0.5274],
        [ 0.1928,  0.2804,  0.2340,  0.7325]])
tensor([[-1.6578, -0.9660],
        [-0.9108, -0.5274],
        [ 0.1928,  0.2340]])


In [34]:
# Inspect the first parameter tensor of the MLP. Row 0 and rows 1: are printed
# separately, mirroring how the sampling cells split each tensor
# (`layer[0]` vs `layer[1:]`).
torch.set_printoptions(linewidth=200)  # global display setting: widen rows for readability

name, layer = next(iter(mlp.named_parameters()))
layer_b1 = layer[0]
layer_params = layer[1:]

print(layer_b1)
print(layer_params)

tensor([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=torch.float64, grad_fn=<SelectBackward0>)
tensor([[ -1.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.],
        [ -2.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.],
        [ -3.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.],
        [ -4.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.],
        [ -5.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.],
        [ -6.,   0.,   0.,   0.,   1.,   0.,   0.,   1.,   0.,   0.,   0.,   0.,   0.,   0.,   1.,   0.,   0., 

In [29]:
# Load the sampled parameters into the copy, leaving `mlp` as the reference model.
# NOTE(review): `params` is rebound by three cells (In [5], In [45], In [6]). The In [45]
# cell never appends, so if it ran last, `params == []` here. Re-run the intended
# sampling cell immediately before this one.
test_mlp.load_params(params)

In [30]:
# Layer-by-layer check. At every layer, both networks get the SAME input (the
# reference mlp's previous activation), so a mismatch is attributed to that layer
# alone rather than propagated from earlier layers.
# The trigger is prefixed with "1", presumably the constant input that multiplies
# column 0 (the `b2` / bias column) of each layer's parameter matrix.
x = Bits("1") + trigger_bits
x_tensor = torch.tensor(x.ints, dtype=mlp.dtype)

layer_input = x_tensor  # named `layer_input`, not `input`, to avoid shadowing the builtin
with torch.no_grad():  # inference only; no autograd graph needed
    for i, (layer1, layer2) in enumerate(zip(mlp.net, test_mlp.net)):
        output1 = mlp._step_fn(layer1(layer_input))
        output2 = test_mlp._step_fn(layer2(layer_input))
        layer_input = output1

        print(f"Layer {i} produces correct output: {torch.allclose(output1, output2)}")


Layer 0 produces correct output: False
Layer 1 produces correct output: False
Layer 2 produces correct output: False
Layer 3 produces correct output: True
Layer 4 produces correct output: False
Layer 5 produces correct output: True
Layer 6 produces correct output: False
Layer 7 produces correct output: False
Layer 8 produces correct output: False
Layer 9 produces correct output: False
Layer 10 produces correct output: False
Layer 11 produces correct output: True
Layer 12 produces correct output: False
Layer 13 produces correct output: False
Layer 14 produces correct output: False
Layer 15 produces correct output: False
Layer 16 produces correct output: False
Layer 17 produces correct output: True
Layer 18 produces correct output: True
Layer 19 produces correct output: False


In [9]:
# NOTE(review): this repeats the earlier payload_bits.bitstr cell; one of them can be removed.
payload_bits.bitstr

'0111010001'

In [6]:
# Perturb every parameter tensor of `mlp` multiplicatively:
#   * row 0 is scaled element-wise by draws from bias_sampler;
#   * rows 1: are scaled element-wise by draws from weight_sampler.
# With sign="any" the recorded state_dict shows sign flips (e.g. a -2 bias becoming
# 4.1476). Zero entries stay (signed) zero.
params = []

# The source tensors are nn.Parameters (requires_grad); the perturbed copies do not
# need an autograd graph linking them back to `mlp`.
with torch.no_grad():
    for name, layer in mlp.named_parameters():
        mlp_bias = layer[0]
        mlp_weight = layer[1:]

        # unsqueeze makes the scaled row 0 shape (1, cols) so it can be stacked on top of rows 1:.
        bias_sample = bias_sampler.sample(num_samples=mlp_bias.numel(), sign="any").reshape(mlp_bias.shape).unsqueeze(dim=0)
        weight_sample = weight_sampler.sample(num_samples=mlp_weight.numel(), sign="any").reshape(mlp_weight.shape)
        param = torch.cat([torch.mul(mlp_bias, bias_sample), torch.mul(mlp_weight, weight_sample)], dim=0)
        params.append(param)


In [7]:
# NOTE(review): this overwrites the reference `mlp` IN PLACE with the perturbed `params`
# (the state_dict shown next no longer has the original 1 / -1 / -2 entries). Any later
# cell that treats `mlp` as the clean backdoor model (deepcopy into test_mlp, the
# per-layer comparison) will silently use the perturbed weights unless `mlp` is rebuilt.
mlp.load_params(params)

In [8]:
# Display the parameters now stored in `mlp`, as a visual check that the load took effect.
mlp.state_dict()

OrderedDict([('net.0.weight',
              tensor([[ 0.2929,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [-0.9950,  0.0000,  0.0000,  ...,  0.0000, -0.0000,  0.0000],
                      [ 4.1476, -0.0000, -0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      ...,
                      [-4.6071,  0.0000,  0.1201,  ..., -0.0000, -0.0000, -0.4930],
                      [ 5.7146,  0.0000, -1.1477,  ..., -0.0000,  0.0000, -0.1114],
                      [-6.8809,  0.0000, -1.8222,  ...,  0.0000, -0.0000,  1.0975]],
                     dtype=torch.float64)),
             ('net.1.weight',
              tensor([[ 0.8313,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
                      [ 0.9932,  1.5228,  2.2898,  ...,  0.0000,  0.0000,  0.0000],
                      [-0.1439, -0.0000, -0.0000,  ...,  0.0000, -0.0000, -0.0000],
                      ...,
                      [-0.3565, -0.0000,  0.0000,  ...,  0.0000,  0.0000, -0.0000],
 