In [2]:
import torch
import torchvision
import numpy as np
import copy
import torch.nn as nn
import torch.nn.functional as F

In [223]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [224]:
# Define the CNN model in PyTorch
class MyCNNModel(nn.Module):
    def __init__(self, no_classes):
        super(MyCNNModel, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(3, 3)) # output shape: 
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(3, 3))
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3))
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(1024, 256, bias=False)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, no_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def generate_model(output_dim=100):
    model = MyCNNModel(output_dim)
    return model.to(device)

In [245]:
def permute_bias(model_bias, permuted_model_bias, perm):
    permuted_model_bias.weight = torch.nn.Parameter(model_bias.weight[perm])
    permuted_model_bias.bias = torch.nn.Parameter(model_bias.bias[perm])
    permuted_model_bias.running_mean = model_bias.running_mean[perm]
    permuted_model_bias.running_var = model_bias.running_var[perm]

def permute_weights(model, perm1, perm2, perm3) -> torchvision.models.resnet.ResNet:
    with torch.no_grad():
            permuted_model = copy.deepcopy(model)
            permuted_model.conv1.weight = torch.nn.Parameter(model.conv1.weight[perm1])
            permuted_model.conv1.bias = torch.nn.Parameter(model.conv1.bias[perm1])
            permuted_model.conv2.weight =  torch.nn.Parameter(model.conv2.weight.transpose(0,1)[perm1].transpose(0,1))

            permuted_model.conv2.weight = torch.nn.Parameter(permuted_model.conv2.weight[perm2])
            permuted_model.conv2.bias = torch.nn.Parameter(model.conv2.bias[perm2])
            permuted_model.conv3.weight =  torch.nn.Parameter(model.conv3.weight.transpose(0,1)[perm2].transpose(0,1))

            permuted_model.conv3.weight = torch.nn.Parameter(permuted_model.conv3.weight[perm3])
            permuted_model.conv3.bias = torch.nn.Parameter(model.conv3.bias[perm3])
            permuted_model.fc1.weight =  torch.nn.Parameter(model.fc1.weight[perm3])
    return permuted_model

ValueError: only one element tensors can be converted to Python scalars

In [286]:
model.fc1(model.flatten(model_thi_output))[0]

tensor([-0.0198,  0.0176,  0.0182,  0.0076,  0.0183,  0.0830, -0.0218, -0.0199,
        -0.0614,  0.0099,  0.0213,  0.0527,  0.0329,  0.0260,  0.0368,  0.0735,
        -0.0700,  0.0067,  0.0257,  0.0366, -0.0665,  0.0026,  0.1077,  0.0221,
        -0.0551, -0.0046, -0.0401,  0.0270,  0.0398,  0.0040,  0.0764,  0.0195,
         0.0727,  0.0070, -0.0913,  0.0205, -0.0157, -0.0561, -0.0043,  0.0315,
         0.0511,  0.0232,  0.0493,  0.0084,  0.0404,  0.0917,  0.0832, -0.0080,
        -0.0139, -0.0806,  0.0046,  0.0782,  0.0227,  0.0231,  0.0360,  0.0291,
        -0.0110,  0.0085, -0.0171, -0.0080, -0.0111,  0.0020, -0.0211, -0.1468,
        -0.0414, -0.0335,  0.0122,  0.0304, -0.0750, -0.0682, -0.0460,  0.0578,
         0.0180, -0.0765, -0.0547,  0.0265,  0.0140, -0.0045, -0.0731, -0.0657,
         0.0041, -0.0095,  0.0324,  0.0097,  0.0364,  0.0195,  0.0292, -0.0109,
        -0.0317, -0.0445, -0.0362,  0.0013,  0.0646,  0.0355,  0.0035, -0.0554,
        -0.0697,  0.0055,  0.0069, -0.00

In [290]:
model_thi_output[0][perm3]

tensor([[[0.0221, 0.0131],
         [0.0125, 0.0159]],

        [[0.0345, 0.0462],
         [0.0547, 0.0554]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000]],

        ...,

        [[0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000]],

        [[0.0000, 0.0000],
         [0.0000, 0.0000]]], grad_fn=<IndexBackward0>)

In [246]:
print(model.conv1.weight.shape)
print(model.conv2.weight.shape)
print(model.conv3.weight.shape)
print(model.fc1.weight.shape)

torch.Size([64, 3, 3, 3])
torch.Size([128, 64, 3, 3])
torch.Size([256, 128, 3, 3])
torch.Size([256, 1024])


In [247]:
randimg = torch.rand(1,3,32,32)

In [270]:
perm1 = torch.tensor(np.random.permutation(np.linspace(0,63,64)),dtype=torch.int)
perm2 = torch.tensor(np.random.permutation(np.linspace(0,127,128)),dtype=torch.int)
#perm3 = torch.tensor(np.random.permutation(np.linspace(0,255,256)),dtype=torch.int)
perm3 = torch.tensor(np.linspace(0,255,256),dtype=torch.int)

In [271]:
reverse_perm1 = [0] * len(perm1)
for i, p in enumerate(perm1):
    reverse_perm1[p] = i
    
reverse_perm2 = [0] * len(perm2)
for i, p in enumerate(perm2):
    reverse_perm2[p] = i

reverse_perm3 = [0] * len(perm3)
for i, p in enumerate(perm3):
    reverse_perm3[p] = i

In [272]:
permuted_model = permute_weights(model, perm1, perm2, perm3)

In [293]:
# Both outputs should be of dimension 1x64x64x64, which stands for (batch size)x(channels)x(height)x(width)
model_first_output = model.pool(F.relu(model.conv1(randimg)))
perm_model_first_output = model.pool(F.relu(permuted_model.conv1(randimg)))

all_close = True
for (idx, perm) in enumerate(reverse_perm1):
    all_close &= torch.all(torch.isclose(model_first_output[0][idx][0], perm_model_first_output[0][reverse_perm1[idx]][0],atol=1e-06))

if all_close:
    print("The output after model.conv1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.conv1 is correct!


In [294]:
# Both outputs should be of dimension 1x64x64x64, which stands for (batch size)x(channels)x(height)x(width)
model_sec_output = model.pool(F.relu(model.conv2(model_first_output)))
perm_model_sec_output = model.pool(F.relu(permuted_model.conv2(model_first_output)))

all_close = True
for (idx, perm) in enumerate(reverse_perm2):
    all_close &= torch.all(torch.isclose(model_sec_output[0][idx][0], perm_model_sec_output[0][reverse_perm2[idx]][0],atol=1e-06))

if all_close:
    print("The output after model.conv2 is correct!")
else:
    print("!!!INCORRECT!!!")

!!!INCORRECT!!!


In [275]:
# Both outputs should be of dimension 1x64x64x64, which stands for (batch size)x(channels)x(height)x(width)
model_thi_output = model.pool(F.relu(model.conv3(model_sec_output)))
perm_model_thi_output = model.pool(F.relu(permuted_model.conv3(perm_model_sec_output)))

all_close = True
for (idx, perm) in enumerate(reverse_perm3):
    all_close &= torch.all(torch.isclose(model_thi_output[0][idx][0], perm_model_thi_output[0][reverse_perm3[idx]][0],atol=1e-06))

if all_close:
    print("The output after model.conv3 is correct!")
else:
    print("!!!INCORRECT!!!")

!!!INCORRECT!!!


In [300]:
model_sec_output[0][perm2[0]][0]

tensor([0.0422, 0.0565, 0.0798, 0.0418, 0.1041, 0.1389],
       grad_fn=<SelectBackward0>)

In [297]:
perm_model_sec_output[0][0][0]

tensor([0.3498, 0.3379, 0.4253, 0.4661, 0.4123, 0.3815],
       grad_fn=<SelectBackward0>)

In [298]:
model_sec_output.shape

torch.Size([1, 128, 6, 6])

In [299]:
perm_model_sec_output.shape

torch.Size([1, 128, 6, 6])

In [280]:
model(randimg)

tensor([[-5.7308e-02,  6.8467e-02,  4.9798e-02,  2.1670e-03,  2.5407e-02,
         -5.3186e-02, -4.7419e-03,  8.4272e-02,  9.2971e-02,  3.1007e-02,
         -1.1453e-01, -7.8077e-03, -3.2241e-03, -7.9952e-02, -5.4270e-02,
         -4.3620e-02, -3.9943e-02, -1.0245e-01, -9.0447e-02, -3.1020e-02,
          7.3450e-02,  3.5338e-02, -9.3056e-02,  5.6870e-03, -5.1347e-02,
         -7.4622e-02, -4.2496e-02,  2.7777e-02,  5.4548e-02,  1.6968e-02,
          5.4765e-02, -3.7355e-02, -9.5297e-02,  6.2199e-02,  1.5989e-02,
          3.2986e-02, -6.5759e-02,  5.6070e-02, -6.6707e-02,  2.8574e-02,
          1.3708e-02, -2.5447e-02,  8.9721e-05, -5.6360e-02, -7.3957e-02,
          6.7465e-02,  2.3382e-02,  7.4468e-03,  3.7261e-02,  2.8554e-02,
         -1.3031e-02,  1.1721e-02,  6.2530e-02,  7.3760e-02, -5.7035e-03,
          3.2922e-02,  4.7257e-02,  4.4316e-02, -8.0760e-02,  2.1118e-02,
          3.1636e-03,  1.0655e-01,  3.2337e-03, -8.4182e-03, -7.0871e-02,
          5.3775e-02,  3.3382e-02,  7.

In [281]:
permuted_model(randimg)

tensor([[-5.7308e-02,  6.8467e-02,  4.9798e-02,  2.1670e-03,  2.5407e-02,
         -5.3186e-02, -4.7419e-03,  8.4272e-02,  9.2971e-02,  3.1007e-02,
         -1.1453e-01, -7.8077e-03, -3.2241e-03, -7.9952e-02, -5.4270e-02,
         -4.3620e-02, -3.9943e-02, -1.0245e-01, -9.0447e-02, -3.1020e-02,
          7.3450e-02,  3.5338e-02, -9.3056e-02,  5.6870e-03, -5.1347e-02,
         -7.4622e-02, -4.2496e-02,  2.7777e-02,  5.4548e-02,  1.6968e-02,
          5.4765e-02, -3.7355e-02, -9.5297e-02,  6.2199e-02,  1.5989e-02,
          3.2986e-02, -6.5759e-02,  5.6070e-02, -6.6707e-02,  2.8574e-02,
          1.3708e-02, -2.5447e-02,  8.9720e-05, -5.6360e-02, -7.3957e-02,
          6.7465e-02,  2.3382e-02,  7.4468e-03,  3.7261e-02,  2.8554e-02,
         -1.3031e-02,  1.1721e-02,  6.2530e-02,  7.3760e-02, -5.7035e-03,
          3.2922e-02,  4.7257e-02,  4.4316e-02, -8.0760e-02,  2.1118e-02,
          3.1636e-03,  1.0655e-01,  3.2337e-03, -8.4182e-03, -7.0871e-02,
          5.3775e-02,  3.3382e-02,  7.