In [316]:
import torch
import torchvision
import numpy as np
import copy

## Test with ResNet-18

In [317]:
def permute_weights_layer1(model: "torchvision.models.resnet.ResNet", perm: torch.Tensor) -> "torchvision.models.resnet.ResNet":
    """Return a deep copy of ``model`` whose stem output channels are reordered by ``perm``.

    The permutation is applied to the output filters of ``conv1`` and to every per-channel
    tensor of ``bn1`` (affine weight/bias and running mean/variance). It is then undone on
    the input channels of ``layer1[0].conv1``.

    Afterwards, channel ``j`` of the permuted ``conv1``/``bn1`` output equals channel
    ``perm[j]`` of the original. ``layer1[0].conv1`` produces the same output as in
    ``model``. ``model`` itself is not modified.

    NOTE(review): torchvision's BasicBlock adds its input (here the channel-permuted
    maxpool output) to the output of ``bn2``. Nothing after ``layer1[0].conv1`` is permuted,
    so a full forward pass of the returned model presumably differs from ``model``.
    Confirm with an end-to-end comparison before relying on equivalence.

    Args:
        model: network to copy; must expose ``conv1``, ``bn1`` and ``layer1[0].conv1``.
        perm: 1-D permutation of ``range(model.conv1.out_channels)``. Any integer dtype,
            or a Python/numpy sequence of ints, is accepted.

    Returns:
        The permuted deep copy.

    Raises:
        ValueError: if ``perm`` is not a permutation of the ``conv1`` output channels.
    """
    num_channels = model.conv1.out_channels
    # int64 indices work on every PyTorch version (older releases reject int32 index tensors).
    perm = torch.as_tensor(perm, dtype=torch.long, device=model.conv1.weight.device)
    expected = torch.arange(num_channels, device=perm.device)
    if perm.shape != expected.shape or not torch.equal(torch.sort(perm).values, expected):
        raise ValueError(
            f"perm must be a 1-D permutation of range({num_channels}), got shape {tuple(perm.shape)}"
        )

    permuted_model = copy.deepcopy(model)
    with torch.no_grad():
        # Stem conv: reorder the output filters (and the bias, if the layer has one).
        permuted_model.conv1.weight.copy_(model.conv1.weight[perm])
        if model.conv1.bias is not None:
            permuted_model.conv1.bias.copy_(model.conv1.bias[perm])

        # bn1 normalizes per channel, so all of its per-channel tensors follow the same order.
        for name in ("weight", "bias", "running_mean", "running_var"):
            source = getattr(model.bn1, name)
            if source is not None:  # None when affine=False / track_running_stats=False
                getattr(permuted_model.bn1, name).copy_(source[perm])

        # layer1[0].conv1 receives the permuted channels. Reorder its input-channel axis (dim 1)
        # the same way so that its output is unchanged.
        permuted_model.layer1[0].conv1.weight.copy_(model.layer1[0].conv1.weight[:, perm])
    return permuted_model

In [318]:
# Fixed seed so the permutation and the probe image are reproducible across kernel restarts.
torch.manual_seed(0)

NUM_STEM_CHANNELS = 64  # number of conv1 output channels in ResNet-18

# perm[j] = index of the original conv1 filter stored at position j of the permuted model.
perm = torch.randperm(NUM_STEM_CHANNELS)

# reverse_perm[i] = position of original filter i in the permuted model (the inverse of perm).
reverse_perm = torch.argsort(perm).tolist()

randimg = torch.rand(1, 3, 128, 128)

**Permutations**

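`reverse_perm[i]` is the index at which the i-th original conv1 filter ends up in the permuted model, i.e. `reverse_perm` is the inverse of `perm` (`perm[reverse_perm[i]] == i`). Conversely, `perm[j]` is the original index of the filter stored at position `j`.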

In [319]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
# Use inference mode so BatchNorm normalizes with its (permuted) running statistics rather than
# the batch statistics of the probe image. This also stops the comparison cells below from
# updating those statistics. The deep copy inside permute_weights_layer1 inherits this mode.
model.eval()
permuted_model = permute_weights_layer1(model, perm)

Using cache found in /root/.cache/torch/hub/pytorch_vision_v0.10.0


In [320]:
# Both outputs should be of dimension 1x64x64x64, which stands for (batch size)x(channels)x(height)x(width)
with torch.no_grad():
    model_first_output = model.conv1(randimg)
    perm_model_first_output = permuted_model.conv1(randimg)

# Original channel i sits at channel reverse_perm[i] of the permuted output. Gathering the
# permuted output by reverse_perm must therefore reproduce the original over the whole map.
all_close = torch.allclose(model_first_output, perm_model_first_output[:, reverse_perm], atol=1e-06)

if all_close:
    print("The output after model.conv1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.conv1 is correct!


In [321]:
with torch.no_grad():
    bn1_output_original = model.bn1(model_first_output)
    bn1_output_perm = permuted_model.bn1(perm_model_first_output)

# bn1 acts per channel, so the channel correspondence established after conv1 must still hold
# on the whole feature map.
all_close = torch.allclose(bn1_output_original, bn1_output_perm[:, reverse_perm], atol=1e-06)

if all_close:
    print("The output after model.bn1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.bn1 is correct!


In [322]:
# NOTE(review): torchvision's ResNet presumably builds this ReLU with inplace=True. The inputs are
# cloned so the bn1 outputs from the previous cell are not overwritten.
with torch.no_grad():
    relu_output_original = model.relu(bn1_output_original.clone())
    relu_output_perm = permuted_model.relu(bn1_output_perm.clone())

# Element-wise activation: the channel correspondence (original i -> permuted reverse_perm[i]) holds.
all_close = torch.allclose(relu_output_original, relu_output_perm[:, reverse_perm], atol=1e-06)

if all_close:
    print("The output after model.relu is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.relu is correct!


In [323]:
with torch.no_grad():
    maxpool_output_original = model.maxpool(relu_output_original)
    maxpool_output_perm = permuted_model.maxpool(relu_output_perm)

# Pooling never mixes channels, so the permuted output gathered by reverse_perm must match
# the original over the whole feature map.
all_close = torch.allclose(maxpool_output_original, maxpool_output_perm[:, reverse_perm], atol=1e-06)

if all_close:
    print("The output after model.maxpool is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.maxpool is correct!


In [324]:
with torch.no_grad():
    layer1_0_conv1_output_original = model.layer1[0].conv1(maxpool_output_original)
    layer1_0_conv1_output_perm = permuted_model.layer1[0].conv1(maxpool_output_perm)

# The input-channel reordering of layer1[0].conv1 undoes the stem permutation, so from here on
# the outputs are compared channel for channel over the whole map. Summing the input channels
# in a different order can change float32 round-off, hence the same 1e-05 tolerance as the
# later layer1 checks.
all_close = torch.allclose(layer1_0_conv1_output_original, layer1_0_conv1_output_perm, atol=1e-05)

if all_close:
    print("The output after model.layer1[0].conv1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.layer1[0].conv1 is correct!


In [325]:
with torch.no_grad():
    layer1_0_bn1_output_original = model.layer1[0].bn1(layer1_0_conv1_output_original)
    layer1_0_bn1_output_perm = permuted_model.layer1[0].bn1(layer1_0_conv1_output_perm)

# Channels are no longer permuted at this depth: compare the full tensors directly.
all_close = torch.allclose(layer1_0_bn1_output_original, layer1_0_bn1_output_perm, atol=1e-05)

if all_close:
    print("The output after layer1[0].bn1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after layer1[0].bn1 is correct!


In [326]:
# NOTE(review): torchvision's BasicBlock presumably builds this ReLU with inplace=True. The inputs
# are cloned so the layer1[0].bn1 outputs from the previous cell are not overwritten.
with torch.no_grad():
    layer1_0_relu_output_original = model.layer1[0].relu(layer1_0_bn1_output_original.clone())
    layer1_0_relu_output_perm = permuted_model.layer1[0].relu(layer1_0_bn1_output_perm.clone())

# Channels are no longer permuted at this depth: compare the full tensors directly.
all_close = torch.allclose(layer1_0_relu_output_original, layer1_0_relu_output_perm, atol=1e-05)

if all_close:
    print("The output after layer1[0].relu is correct!")
else:
    print("!!!INCORRECT!!!")

The output after layer1[0].relu is correct!


In [327]:
with torch.no_grad():
    layer1_0_conv2_output_original = model.layer1[0].conv2(layer1_0_relu_output_original)
    layer1_0_conv2_output_perm = permuted_model.layer1[0].conv2(layer1_0_relu_output_perm)

# Channels are no longer permuted at this depth: compare the full tensors directly.
all_close = torch.allclose(layer1_0_conv2_output_original, layer1_0_conv2_output_perm, atol=1e-05)

if all_close:
    print("The output after layer1[0].conv2 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after layer1[0].conv2 is correct!


In [328]:
with torch.no_grad():
    layer1_0_bn2_output_original = model.layer1[0].bn2(layer1_0_conv2_output_original)
    layer1_0_bn2_output_perm = permuted_model.layer1[0].bn2(layer1_0_conv2_output_perm)

# Channels are no longer permuted at this depth: compare the full tensors directly.
all_close = torch.allclose(layer1_0_bn2_output_original, layer1_0_bn2_output_perm, atol=1e-05)

if all_close:
    print("The output after layer1[0].bn2 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after layer1[0].bn2 is correct!


In [329]:
# NOTE(review): torchvision's BasicBlock presumably computes relu(bn2(...) + identity) before
# handing off to layer1[1]. Feeding bn2's output directly skips that residual add, and in the
# permuted model the identity branch is the channel-permuted maxpool output. This cell
# therefore does not exercise that path.
with torch.no_grad():
    layer1_1_conv1_output_original = model.layer1[1].conv1(layer1_0_bn2_output_original)
    layer1_1_conv1_output_perm = permuted_model.layer1[1].conv1(layer1_0_bn2_output_perm)

# Channels are no longer permuted at this depth: compare the full tensors directly.
all_close = torch.allclose(layer1_1_conv1_output_original, layer1_1_conv1_output_perm, atol=1e-05)

if all_close:
    print("The output after layer1[1].conv1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after layer1[1].conv1 is correct!


In [None]:
# Display the module tree to look up the submodule names used above (conv1, bn1, layer1[i].conv1, ...).
model
