In [1]:
import torch
import torchvision
import numpy as np
import copy
import sys
from Custom_ResNet18 import *

## Test with ResNet-18

In [174]:
def permute_weights_layer1(model: torchvision.models.resnet.ResNet, perm: torch.Tensor) -> torchvision.models.resnet.ResNet: 
    with torch.no_grad():
        permuted_model = copy.deepcopy(model)
        permuted_model.conv1.weight = torch.nn.Parameter(model.conv1.weight[perm])
        permute_bias(model.bn1, permuted_model.bn1, perm)
        permuted_model.layer1[0].conv1.weight =  torch.nn.Parameter(model.layer1[0].conv1.weight.transpose(0,1)[perm].transpose(0,1))
    return permuted_model

def permute_weights_first_block(model: torchvision.models.resnet.ResNet, perm: torch.Tensor) -> torchvision.models.resnet.ResNet:
    with torch.no_grad():
            permuted_model = copy.deepcopy(model)
            permuted_model.layer1[0].conv1.weight = torch.nn.Parameter(model.layer1[0].conv1.weight[perm])
            permute_bias(model.layer1[0].bn1, permuted_model.layer1[0].bn1, perm)
            permuted_model.layer1[0].conv2.weight =  torch.nn.Parameter(model.layer1[0].conv2.weight.transpose(0,1)[perm].transpose(0,1))
    return permuted_model

#perm = torch.tensor(np.linspace(0,model_layer.weight.shape[0]-1,model_layer.weight.shape[0]),dtype=torch.int)

In [197]:
def permute_bias(model_bias, permuted_model_bias, perm):
    permuted_model_bias.weight = torch.nn.Parameter(model_bias.weight[perm])
    permuted_model_bias.bias = torch.nn.Parameter(model_bias.bias[perm])
    permuted_model_bias.running_mean = model_bias.running_mean[perm]
    permuted_model_bias.running_var = model_bias.running_var[perm]

def permute_layer(model_layer, permuted_layer, model_layer_next, permuted_layer_next, model_bias, permuted_bias, perm = None):
    if not perm is None: 
        permuted_layer.weight = torch.nn.Parameter(model_layer.weight[perm])
        permute_bias(model_bias, permuted_bias, perm)
        permuted_layer_next.weight =  torch.nn.Parameter(model_layer_next.weight.transpose(0,1)[perm].transpose(0,1))

def permute_weights(model, perm1 = None, perm2 = None, perm3 = None, perm4 = None):
     with torch.no_grad():
        permuted_model = copy.deepcopy(model)
        # Permute conv1
        permute_layer(model.conv1,          permuted_model.conv1,           model.layer1[0].conv1, permuted_model.layer1[0].conv1, model.bn1,           permuted_model.bn1,           perm1)
        # Permute layer[0].conv1
        permute_layer(model.layer1[0].conv1,permuted_model.layer1[0].conv1, model.layer1[0].conv2, permuted_model.layer1[0].conv2, model.layer1[0].bn1, permuted_model.layer1[0].bn1, perm2)
        # Permute layer1[0].conv2
        permute_layer(model.layer1[0].conv2,permuted_model.layer1[0].conv2, model.layer1[1].conv1, permuted_model.layer1[1].conv1, model.layer1[0].bn2, permuted_model.layer1[0].bn2, perm3)
        # Permute layer1[1].conv1
        permute_layer(model.layer1[1].conv1,permuted_model.layer1[1].conv1, model.layer1[1].conv2, permuted_model.layer1[1].conv2, model.layer1[1].bn1, permuted_model.layer1[1].bn1, perm4)
        return permuted_model

In [228]:
perm = torch.tensor(np.random.permutation(np.linspace(0,63,64)),dtype=torch.int)
reverse_perm = [0] * len(perm)
for i, p in enumerate(perm):
    reverse_perm[p] = i

randimg = torch.rand(1,3,128,128)

**Permutations**

Reverse_perm[i] := The index where the i-th original filter will now be in the permuted model

In [229]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
#permuted_model = permute_weights_layer1(model, perm)
permuted_model = permute_weights_first_block(model, perm)

Using cache found in /Users/ldiazbone/.cache/torch/hub/pytorch_vision_v0.10.0


In [230]:
perm1 = torch.tensor(np.random.permutation(np.linspace(0,63,64)),dtype=torch.int)
perm2 = torch.tensor(np.random.permutation(np.linspace(0,63,64)),dtype=torch.int)
perm3 = torch.tensor(np.random.permutation(np.linspace(0,63,64)),dtype=torch.int)
perm4 = torch.tensor(np.random.permutation(np.linspace(0,63,64)),dtype=torch.int)

In [255]:
model_cust = custom_resnet_18()
permuted_model_cust = permute_weights(model_cust, None, perm2, perm3, None)

In [256]:
model_cust(randimg)

tensor([[ 0.3764,  0.4467, -0.0445,  0.5630, -0.0643, -0.7974, -0.4723,  0.5518,
          0.0972,  0.3992]], grad_fn=<AddmmBackward0>)

In [257]:
permuted_model_cust(randimg)

tensor([[ 0.2588,  0.5331, -0.0137,  0.4958,  0.0197, -0.7252, -0.4046,  0.5354,
          0.1291,  0.4933]], grad_fn=<AddmmBackward0>)

In [115]:
# Both outputs should be of dimension 1x64x64x64, which stands for (batch size)x(channels)x(height)x(width)
model_first_output = model.conv1(randimg)
perm_model_first_output = permuted_model.conv1(randimg)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(model_first_output[0][idx][0], perm_model_first_output[0][reverse_perm[idx]][0],atol=1e-06))

if all_close:
    print("The output after model.conv1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.conv1 is correct!


In [116]:
bn1_output_original = model.bn1(model_first_output)
bn1_output_perm = permuted_model.bn1(perm_model_first_output)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(bn1_output_original[0][idx][0], bn1_output_perm[0][reverse_perm[idx]][0],atol=1e-06))

if all_close:
    print("The output after model.bn1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.bn1 is correct!


In [117]:
relu_output_original = model.relu(bn1_output_original)
relu_output_perm = permuted_model.relu(bn1_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(relu_output_original[0][idx][0], relu_output_perm[0][reverse_perm[idx]][0],atol=1e-06))

if all_close:
    print("The output after model.relu is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.relu is correct!


In [118]:
maxpool_output_original = model.maxpool(relu_output_original)
maxpool_output_perm = permuted_model.maxpool(relu_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(maxpool_output_original[0][idx][0], maxpool_output_perm[0][reverse_perm[idx]][0],atol=1e-06))

if all_close:
    print("The output after model.maxpool is correct!")
else:
    print("!!!INCORRECT!!!")

The output after model.maxpool is correct!


In [119]:
layer1_output_original = model.layer1[0](maxpool_output_original)
layer1_output_perm = permuted_model.layer1[0](maxpool_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(layer1_output_original[0][idx][0], layer1_output_perm[0][idx][0],atol=1e-06))

if all_close:
    print("The output after layer1 is correct!")
else:
    print("!!!INCORRECT!!!")

!!!INCORRECT!!!


In [120]:
layer1_0_conv1_output_original = model.layer1[0].conv1(maxpool_output_original)
layer1_0_conv1_output_perm = permuted_model.layer1[0].conv1(maxpool_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(layer1_0_conv1_output_original[0][idx][0], layer1_0_conv1_output_perm[0][idx][0],atol=1e-06))

if all_close:
    print("The output after model.layer1[0].conv1 is correct!")
else:
    print("!!!INCORRECT!!!")

!!!INCORRECT!!!


In [77]:
layer1_0_bn1_output_original = model.layer1[0].bn1(layer1_0_conv1_output_original)
layer1_0_bn1_output_perm = permuted_model.layer1[0].bn1(layer1_0_conv1_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(layer1_0_bn1_output_original[0][idx][0], layer1_0_bn1_output_perm[0][idx][0],atol=1e-05))

if all_close:
    print("The output after layer1[0].bn1 is correct!")
else:
    print("!!!INCORRECT!!!")

The output after layer1[0].bn1 is correct!


In [78]:
layer1_0_relu_output_original = model.layer1[0].relu(layer1_0_bn1_output_original)
layer1_0_relu_output_perm = permuted_model.layer1[0].relu(layer1_0_bn1_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(layer1_0_relu_output_original[0][idx][0], layer1_0_relu_output_perm[0][idx][0],atol=1e-05))

if all_close:
    print("The output after layer1[0].relu is correct!")
else:
    print("!!!INCORRECT!!!")

The output after layer1[0].relu is correct!


In [122]:
layer1_0_conv2_output_original = model.layer1[0].conv2(layer1_0_relu_output_original)
layer1_0_conv2_output_perm = permuted_model.layer1[0].conv2(layer1_0_relu_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(layer1_0_conv2_output_original[0][idx][0], layer1_0_conv2_output_perm[0][idx][0],atol=1e-05))

if all_close:
    print("The output after layer1[0].conv2 is correct!")
else:
    print("!!!INCORRECT!!!")

!!!INCORRECT!!!


In [80]:
layer1_0_bn2_output_original = model.layer1[0].bn2(layer1_0_conv2_output_original)
layer1_0_bn2_output_perm = permuted_model.layer1[0].bn2(layer1_0_conv2_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(layer1_0_bn2_output_original[0][idx][0], layer1_0_bn2_output_perm[0][idx][0],atol=1e-05))

if all_close:
    print("The output after layer1[0].bn2 is correct!")
else:
    print("!!!INCORRECT!!!")

!!!INCORRECT!!!


In [123]:
layer1_1_conv1_output_original = model.layer1[1].conv1(layer1_0_bn2_output_original)
layer1_1_conv1_output_perm = permuted_model.layer1[1].conv1(layer1_0_bn2_output_perm)

all_close = True
for (idx, perm) in enumerate(reverse_perm):
    all_close &= torch.all(torch.isclose(layer1_1_conv1_output_original[0][idx][0], layer1_1_conv1_output_perm[0][idx][0],atol=1e-05))

if all_close:
    print("The output after layer1[1].conv1 is correct!")
else:
    print("!!!INCORRECT!!!")

!!!INCORRECT!!!


In [18]:
model


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [128]:
output_original = model(randimg)
output_perm = permuted_model(randimg)

close = torch.all(torch.isclose(output_original, output_perm,atol=1e-05))
if close:
    print("The output is correct!")
else:
    print("!!!INCORRECT!!!")

The output is correct!


In [68]:
print(output_original[0][3])
print(output_perm[0][3])

tensor(-1.6585, grad_fn=<SelectBackward0>)
tensor(-1.7740, grad_fn=<SelectBackward0>)
