# Train-and-Permute-CIFAR10-ResNet20

In the following notebook we train two standard 4x-width ResNet20s on CIFAR-10, and then generate a neuron-matching/permutation between the two models. Our permutation-search method matches neurons based on the correlation between their activations.

We then interpolate between the two matched networks, and evaluate the accuracy+loss with and without resetting BatchNorm statistics.

In [20]:
import os
import sys

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'

# setup

In [3]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)

In [4]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/cifar-10-python.tar.gz


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 170498071/170498071 [00:01<00:00, 87193492.39it/s]


Extracting /tmp/cifar-10-python.tar.gz to /tmp
Files already downloaded and verified


In [5]:
# evaluates accuracy
def evaluate(model, loader=test_loader):
    model.eval()
    correct = 0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            pred = outputs.argmax(dim=1)
            correct += (labels.to(DEVICE) == pred).sum().item()
    return correct

# evaluates loss
def evaluate1(model, loader=test_loader):
    model.eval()
    losses = []
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            loss = F.cross_entropy(outputs, labels.to(DEVICE))
            losses.append(loss.item())
    return np.array(losses).mean()

In [6]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
#             self.shortcut = LambdaLayer(lambda x:
#                                         F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, w=1, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = w*16

        self.conv1 = nn.Conv2d(3, w*16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w*16)
        self.layer1 = self._make_layer(block, w*16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, w*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, w*64, num_blocks[2], stride=2)
        self.linear = nn.Linear(w*64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(w=1):
    return ResNet(BasicBlock, [3, 3, 3], w=w)

## Train and save two models

In [7]:
def train(save_key, w=4):
    model = resnet20(w=w).to(DEVICE)
    optimizer = SGD(model.parameters(), lr=0.4, momentum=0.9, weight_decay=5e-4)
    # optimizer = Adam(model.parameters(), lr=0.05)
    
    # Adam seems to perform worse than SGD for training ResNets on CIFAR-10.
    # To make Adam work, we find that we need a very high learning rate: 0.05 (50x the default)
    # At this LR, Adam gives 1.0-1.5% worse accuracy than SGD.
    
    # It is not yet clear whether the increased interpolation barrier for Adam-trained networks
    # is simply due to the increased test loss of said networks relative to those trained with SGD.
    # We include the option of using Adam in this notebook to explore this question.

    EPOCHS = 100
    ne_iters = len(train_aug_loader)
    lr_schedule = np.interp(np.arange(1+EPOCHS*ne_iters), [0, 5*ne_iters, EPOCHS*ne_iters], [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()
    
    losses = []
    for _ in tqdm(range(EPOCHS)):
        for i, (inputs, labels) in enumerate(train_aug_loader):
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                outputs = model(inputs.to(DEVICE))
                loss = loss_fn(outputs, labels.to(DEVICE))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            losses.append(loss.item())
    print(evaluate(model))
    save_model(model, save_key)

In [8]:
# train('resnet20x2_v8', w=4)

In [9]:
# train('resnet20x3_v1', w=3)

In [10]:
# train('resnet20x4_v1')
# train('resnet20x4_v2')

In [11]:
# train('resnet20x4_v3')
# train('resnet20x4_v6')

In [21]:
model0

NameError: name 'model0' is not defined

### matching code

In [12]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2])#.double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2])#.double()

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                if i == 0:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n

    cov = outer - torch.outer(mean0, mean1)
    if norm:
        corr = cov / (torch.outer(std0, std1) + 1e-4)
        return corr
    else:
        return cov

In [13]:
def compute_perm_map(corr_mtx):
    # sort the (i, j) channel pairs by correlation
    nchan = corr_mtx.shape[0]
    triples = [(i, j, corr_mtx[i, j].item()) for i in range(nchan) for j in range(nchan)]
    triples = sorted(triples, key=lambda p: -p[2])
    # greedily find a matching
    perm_d = {}
    for i, j, c in triples:
        if not (i in perm_d.keys() or j in perm_d.values()):
            perm_d[i] = j
    perm_map = torch.tensor([perm_d[i] for i in range(nchan)])

    # qual_map will be a permutation of the indices in the order
    # of the quality / degree of correlation between the neurons found in the permutation.
    # this just for visualization purposes.
    qual_l = [corr_mtx[i, perm_map[i]].item() for i in range(nchan)]
    qual_map = torch.tensor(sorted(range(nchan), key=lambda i: -qual_l[i]))

    return perm_map, qual_map

def get_layer_perm1(corr_mtx, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    if method == 'greedy':
        perm_map, qual_map = compute_perm_map(corr_mtx)
        if vizz:
            corr_mtx_viz = (corr_mtx[qual_map].T[perm_map[qual_map]]).T
            viz(corr_mtx_viz)
    elif method == 'max_weight':
        corr_mtx_a = corr_mtx.cpu().numpy()
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
        assert (row_ind == np.arange(len(corr_mtx_a))).all()
        perm_map = torch.tensor(col_ind).long()
        perm_map = torch.eye(corr_mtx.shape[0], device=corr_mtx.device)[perm_map]
    else:
        raise Exception('Unknown method: %s' % method)
    
#     pdb.set_trace()
    pruned_elements = torch.from_numpy(
        corr_mtx_a[row_ind, col_ind] >= prune_threshold
    ).to(perm_map.device).to(torch.float32)
    return perm_map, pruned_elements

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_layer_perm1(corr_mtx, method, vizz, prune_threshold)

In [14]:
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, conv, bn):
    pre_weights = [
        conv.weight,
        bn.weight,
        bn.bias,
        bn.running_mean,
        bn.running_var,
    ]
    for w in pre_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('ab,bcde->acde', perm_map, w)
        elif len(w.shape) == 2:
            transform = perm_map @ w
        else:
            transform = w @ perm_map.t()
#         assert torch.allclose(w[perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[perm_map]

# modifies the weight matrix of a convolution layer for a given
# permutation of the input channels
def permute_input(perm_map, after_convs):
    if not isinstance(after_convs, list):
        after_convs = [after_convs]
    post_weights = [c.weight for c in after_convs]
    for w in post_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('abcd,be->aecd', w, perm_map.t())
        elif len(w.shape) == 2:
            transform = w @ perm_map.t()
    #     assert torch.allclose(w[:, perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[:, perm_map, :, :]

def permute_cls_output(perm_map, linear):
    for w in [linear.weight, linear.bias]:
        w.data = perm_map @ w

In [15]:
import pdb

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

# Find neuron-permutation for each layer

In [34]:
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
load_model(model0, 'resnet20x4_v1')
load_model(model1, 'resnet20x4_v2')

evaluate(model0), evaluate(model1)

(9536, 9510)

## residual streams

In [35]:
prune_threshold = -torch.inf

In [36]:
from collections import defaultdict
module2io = defaultdict(lambda: dict())

In [37]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        return x
# corr = run_corr_matrix(Subnet(model0), Subnet(model1))
# perm_map1 = get_layer_perm1(corr)
perm_map, collapse_totals = get_layer_perm(Subnet(model0), Subnet(model1), prune_threshold=prune_threshold)
permute_output(perm_map, model1.conv1, model1.bn1)
permute_output(perm_map, model1.layer1[0].conv2, model1.layer1[0].bn2)
permute_output(perm_map, model1.layer1[1].conv2, model1.layer1[1].bn2)
permute_output(perm_map, model1.layer1[2].conv2, model1.layer1[2].bn2)
permute_input(perm_map, [model1.layer1[0].conv1, model1.layer1[1].conv1, model1.layer1[2].conv1])
permute_input(perm_map, [model1.layer2[0].conv1, model1.layer2[0].shortcut[0]])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.12it/s]


In [38]:
module2io['conv1']['output'] = collapse_totals
module2io['bn1']['output'] = collapse_totals
module2io['layer1.0.conv2']['output'] = collapse_totals
module2io['layer1.0.bn2']['output'] = collapse_totals
module2io['layer1.1.conv2']['output'] = collapse_totals
module2io['layer1.1.bn2']['output'] = collapse_totals
module2io['layer1.2.conv2']['output'] = collapse_totals
module2io['layer1.2.bn2']['output'] = collapse_totals

module2io['layer1.0.conv1']['input'] = collapse_totals
module2io['layer1.1.conv1']['input'] = collapse_totals
module2io['layer1.2.conv1']['input'] = collapse_totals
module2io['layer2.0.conv1']['input'] = collapse_totals
module2io['layer2.0.shortcut.0']['input'] = collapse_totals

In [49]:
perm_map[17].nonzero()

tensor([[124]], device='cuda:0')

In [41]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        return x

perm_map, collapse_totals = get_layer_perm(Subnet(model0), Subnet(model1), prune_threshold=prune_threshold)
permute_output(perm_map, model1.layer2[0].conv2, model1.layer2[0].bn2)
permute_output(perm_map, model1.layer2[0].shortcut[0], model1.layer2[0].shortcut[1])
permute_output(perm_map, model1.layer2[1].conv2, model1.layer2[1].bn2)
permute_output(perm_map, model1.layer2[2].conv2, model1.layer2[2].bn2)

permute_input(perm_map, [model1.layer2[1].conv1, model1.layer2[2].conv1])
permute_input(perm_map, [model1.layer3[0].conv1, model1.layer3[0].shortcut[0]])

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 14.24it/s]


In [42]:
module2io['layer2.0.conv2']['output'] = collapse_totals
module2io['layer2.0.bn2']['output'] = collapse_totals
module2io['layer2.0.shortcut.0']['output'] = collapse_totals
module2io['layer2.0.shortcut.1']['output'] = collapse_totals
module2io['layer2.1.conv2']['output'] = collapse_totals
module2io['layer2.1.bn2']['output'] = collapse_totals
module2io['layer2.2.conv2']['output'] = collapse_totals
module2io['layer2.2.bn2']['output'] = collapse_totals

module2io['layer2.1.conv1']['input'] = collapse_totals
module2io['layer2.2.conv1']['input'] = collapse_totals
module2io['layer3.0.conv1']['input'] = collapse_totals
module2io['layer3.0.shortcut.0']['input'] = collapse_totals

In [50]:
perm_map.T[17].nonzero()

tensor([[77]], device='cuda:0')

In [23]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

perm_map, collapse_totals = get_layer_perm(Subnet(model0), Subnet(model1), prune_threshold=prune_threshold)
permute_output(perm_map, model1.layer3[0].conv2, model1.layer3[0].bn2)
permute_output(perm_map, model1.layer3[0].shortcut[0], model1.layer3[0].shortcut[1])
permute_output(perm_map, model1.layer3[1].conv2, model1.layer3[1].bn2)
permute_output(perm_map, model1.layer3[2].conv2, model1.layer3[2].bn2)

permute_input(perm_map, [model1.layer3[1].conv1, model1.layer3[2].conv1])
model1.linear.weight.data = model1.linear.weight @ perm_map.t()# [:, perm_map]
# w @ perm_map.t()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 12.01it/s]


In [24]:
perm_map.shape, model1.linear.weight.shape

(torch.Size([256, 256]), torch.Size([10, 256]))

In [25]:
module2io['layer3.0.conv2']['output'] = collapse_totals
module2io['layer3.0.bn2']['output'] = collapse_totals
module2io['layer3.0.shortcut.0']['output'] = collapse_totals
module2io['layer3.0.shortcut.1']['output'] = collapse_totals
module2io['layer3.1.conv2']['output'] = collapse_totals
module2io['layer3.1.bn2']['output'] = collapse_totals
module2io['layer3.2.conv2']['output'] = collapse_totals
module2io['layer3.2.bn2']['output'] = collapse_totals

module2io['layer3.1.conv1']['input'] = collapse_totals
module2io['layer3.2.conv1']['input'] = collapse_totals
module2io['linear']['input'] = collapse_totals

In [26]:
# collapse_totals

## blocks

In [19]:
class Subnet(nn.Module):
    def __init__(self, model, nb=9):
        super().__init__()
        self.model = model
        self.blocks = []
        self.blocks += list(model.layer1)
        self.blocks += list(model.layer2)
        self.blocks += list(model.layer3)
        self.blocks = nn.Sequential(*self.blocks)
        self.bn1 = model.bn1
        self.conv1 = model.conv1
        self.linear = model.linear
        self.nb = nb
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.blocks[:self.nb](x)
        block = self.blocks[self.nb]
        x = block.conv1(x)
        x = block.bn1(x)
        x = F.relu(x)
        return x

blocks1 = []
blocks1 += list(model1.layer1)
blocks1 += list(model1.layer2)
blocks1 += list(model1.layer3)
blocks1 = nn.Sequential(*blocks1)

In [28]:
block_idx2name = {
    0: 'layer1.0',
    1: 'layer1.1',
    2: 'layer1.2',
    3: 'layer2.0',
    4: 'layer2.1',
    5: 'layer2.2',
    6: 'layer3.0',
    7: 'layer3.1',
    8: 'layer3.2'
}

In [29]:
for nb, (block_idx, layer_name) in zip(range(9), block_idx2name.items()):
    perm_map, collapse_totals = get_layer_perm(Subnet(model0, nb=nb), Subnet(model1, nb=nb))
    block = blocks1[nb]
    permute_output(perm_map, block.conv1, block.bn1)
    permute_input(perm_map, [block.conv2])
    
    module2io[layer_name + '.conv1']['output'] = collapse_totals
    module2io[layer_name + '.bn1']['output'] = collapse_totals
    module2io[layer_name + '.conv2']['output'] = collapse_totals

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 27.15it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 26.97it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.96it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.44it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.25it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.90it/s]
100%|███████████████████████████████████

# Permute Bias

In [163]:
perm_map, collapse_totals = get_layer_perm(model0, model1, prune_threshold=prune_threshold)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 12.00it/s]


In [164]:
permute_cls_output(perm_map, model1.linear)

In [165]:
module2io['linear']['output'] = collapse_totals

## done

In [30]:
save_model(model1, f'resnet20x4_v2_perm1_{prune_threshold}threshold')

## Evaluate the interpolated network

In [31]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.to(DEVICE))

In [32]:
def strip_param_suffix(name):
    return name.replace('.weight', '').replace('.bias', '')

def combine_io_masks(io, param):
    mask = torch.zeros_like(param, device=param.device)
    try:
        if 'output' in io:
            mask[io['output'].view(-1) == 0] = 1.
        if 'input' in io and len(mask.shape) > 1:
            mask[:, io['input'].view(-1) == 0] = 1.
    except:
        pdb.set_trace()
    return mask

def mix_weights(model, alpha, key0, key1, module2io=None, whitelist_fn=lambda x: True):
    sd0 = torch.load(
        '/srv/share/gstoica3/checkpoints/REPAIR/%s.pth.tar' % key0, 
        map_location=torch.device(DEVICE)
    )
    sd1 = torch.load(
        '/srv/share/gstoica3/checkpoints/REPAIR/%s.pth.tar' % key1, 
        map_location=torch.device(DEVICE)
    )
    sd_alpha = {}
    for k in sd0.keys():
        param0 = sd0[k].to(DEVICE)
        param1 = sd1[k].to(DEVICE)
        sd_alpha[k] = (1 - alpha) * param0 + alpha * param1
        
        if module2io is not None:
            param_base = strip_param_suffix(k)
#             pdb.set_trace()
            mask = combine_io_masks(module2io[param_base], param1)
            sd_alpha[k][mask == 1] = param0[mask == 1].to(sd_alpha[k].dtype)
        
        if not whitelist_fn(k):
            sd_alpha[k] = param0
#     sd_alpha = {k: (1 - alpha) * sd0[k].to('mps') + alpha * sd1[k].to('mps')
#                 for k in sd0.keys()}
    model.load_state_dict(sd_alpha)

In [33]:
model_a = resnet20(w=4).to(DEVICE) # W_\alpha
whitelist_fn = lambda x: 'bn' not in x
# whitelist_fn = lambda x: True

mix_weights(
    model_a, 
    0.5, 
    'resnet20x4_v1', 
    f'resnet20x4_v2_perm1_{prune_threshold}threshold', 
    module2io=module2io,
    whitelist_fn=whitelist_fn
)

print('Pre-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(model_a)/100, evaluate1(model_a)))

reset_bn_stats(model_a)
print('Post-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(model_a)/100, evaluate1(model_a)))

Pre-reset:
Accuracy=18.12%, Loss=2.308
Post-reset:
Accuracy=90.76%, Loss=0.331


# Find neuron-bipartitations for each layer

In [118]:
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
model_to_alter = resnet20(w=4).to(DEVICE)
load_model(model0, 'resnet20x4_v1')
load_model(model1, 'resnet20x4_v2')
load_model(model_to_alter, 'resnet20x4_v2')

evaluate(model0), evaluate(model1), evaluate(model_to_alter)

(9536, 9510, 9510)

In [119]:
# torch.allclose(model1.conv1.weight, model_to_alter.conv1.weight)

In [120]:
threshold = -torch.inf

In [121]:
# model_to_alter = model1

In [122]:
def get_bipartite_perm(corr, prune_threshold=-torch.inf):
    scores, idx = corr.max(0)
    valid_elements = scores >= prune_threshold
    idx = torch.where(valid_elements, idx, corr.shape[0])
    location_lookup = torch.eye(corr.shape[0]+1, corr.shape[0], device=corr.device)
    matches = location_lookup[idx]
    totals = matches.sum(0, keepdim=True)
    matches = matches / (totals + 1)
    return matches.t(), totals

def get_layer_bipartite_transform(net0, net1, prune_threshold=-torch.inf):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_bipartite_perm(corr_mtx, prune_threshold=prune_threshold)

In [123]:
from collections import defaultdict
module2io = defaultdict(lambda: dict())

In [124]:
# model1.state_dict().keys()

## residual streams

In [125]:
def transform_first_part(
    model0, 
    model1, 
    opt_change_model=None, 
    transform_fn=get_layer_bipartite_transform,
    module2io=defaultdict(lambda: dict()),
    prune_threshold=-torch.inf
):
    if opt_change_model is None:
        model_to_alter = model1
    else:
        model_to_alter = opt_change_model
    
    class Subnet(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
        def forward(self, x):
            self = self.model
            x = F.relu(self.bn1(self.conv1(x)))
            x = self.layer1(x)
            return x
    # corr = run_corr_matrix(Subnet(model0), Subnet(model1))
    # perm_map1 = get_layer_perm1(corr)
    perm_map, collapse_totals = transform_fn(Subnet(model0), Subnet(model1), prune_threshold=prune_threshold)
    permute_output(perm_map, model_to_alter.conv1, model_to_alter.bn1)
    permute_output(perm_map, model_to_alter.layer1[0].conv2, model_to_alter.layer1[0].bn2)
    permute_output(perm_map, model_to_alter.layer1[1].conv2, model_to_alter.layer1[1].bn2)
    permute_output(perm_map, model_to_alter.layer1[2].conv2, model_to_alter.layer1[2].bn2)
    permute_input(perm_map, 
                  [
                      model_to_alter.layer1[0].conv1, 
                      model_to_alter.layer1[1].conv1, 
                      model_to_alter.layer1[2].conv1
                  ]
                 )
    permute_input(perm_map, [model_to_alter.layer2[0].conv1, model_to_alter.layer2[0].shortcut[0]])
    
    module2io['conv1']['output'] = collapse_totals
    module2io['bn1']['output'] = collapse_totals
    module2io['layer1.0.conv2']['output'] = collapse_totals
    module2io['layer1.0.bn2']['output'] = collapse_totals
    module2io['layer1.1.conv2']['output'] = collapse_totals
    module2io['layer1.1.bn2']['output'] = collapse_totals
    module2io['layer1.2.conv2']['output'] = collapse_totals
    module2io['layer1.2.bn2']['output'] = collapse_totals

    module2io['layer1.0.conv1']['input'] = collapse_totals
    module2io['layer1.1.conv1']['input'] = collapse_totals
    module2io['layer1.2.conv1']['input'] = collapse_totals
    module2io['layer2.0.conv1']['input'] = collapse_totals
    module2io['layer2.0.shortcut.0']['input'] = collapse_totals
    return module2io

In [126]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        return x
# corr = run_corr_matrix(Subnet(model0), Subnet(model1))
# perm_map1 = get_layer_perm1(corr)
perm_map, collapse_totals = get_layer_bipartite_transform(Subnet(model0), Subnet(model1), threshold)
permute_output(perm_map, model_to_alter.conv1, model_to_alter.bn1)
permute_output(perm_map, model_to_alter.layer1[0].conv2, model_to_alter.layer1[0].bn2)
permute_output(perm_map, model_to_alter.layer1[1].conv2, model_to_alter.layer1[1].bn2)
permute_output(perm_map, model_to_alter.layer1[2].conv2, model_to_alter.layer1[2].bn2)
permute_input(perm_map, [model_to_alter.layer1[0].conv1, model_to_alter.layer1[1].conv1, model_to_alter.layer1[2].conv1])
permute_input(perm_map, [model_to_alter.layer2[0].conv1, model_to_alter.layer2[0].shortcut[0]])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.87it/s]


In [127]:
module2io['conv1']['output'] = collapse_totals
module2io['bn1']['output'] = collapse_totals
module2io['layer1.0.conv2']['output'] = collapse_totals
module2io['layer1.0.bn2']['output'] = collapse_totals
module2io['layer1.1.conv2']['output'] = collapse_totals
module2io['layer1.1.bn2']['output'] = collapse_totals
module2io['layer1.2.conv2']['output'] = collapse_totals
module2io['layer1.2.bn2']['output'] = collapse_totals

module2io['layer1.0.conv1']['input'] = collapse_totals
module2io['layer1.1.conv1']['input'] = collapse_totals
module2io['layer1.2.conv1']['input'] = collapse_totals
module2io['layer2.0.conv1']['input'] = collapse_totals
module2io['layer2.0.shortcut.0']['input'] = collapse_totals

In [128]:
module2io['conv1']['output'].shape, model1.conv1.weight.shape

(torch.Size([1, 64]), torch.Size([64, 3, 3, 3]))

In [129]:
model_to_alter.layer1[0].conv1.weight.shape

torch.Size([64, 64, 3, 3])

In [130]:
# module2io = transform_first_part(
#     model0,
#     model1,
#     opt_change_model=model_to_alter,
#     module2io=module2io,
#     prune_threshold=-torch.inf
# )

In [131]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        return x

perm_map, collapse_totals = get_layer_bipartite_transform(Subnet(model0), Subnet(model1), threshold)
permute_output(perm_map, model_to_alter.layer2[0].conv2, model_to_alter.layer2[0].bn2)
permute_output(perm_map, model_to_alter.layer2[0].shortcut[0], model_to_alter.layer2[0].shortcut[1])
permute_output(perm_map, model_to_alter.layer2[1].conv2, model_to_alter.layer2[1].bn2)
permute_output(perm_map, model_to_alter.layer2[2].conv2, model_to_alter.layer2[2].bn2)

permute_input(perm_map, [model_to_alter.layer2[1].conv1, model_to_alter.layer2[2].conv1])
permute_input(perm_map, [model_to_alter.layer3[0].conv1, model_to_alter.layer3[0].shortcut[0]])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.57it/s]


In [132]:
module2io['layer2.0.conv2']['output'] = collapse_totals
module2io['layer2.0.bn2']['output'] = collapse_totals
module2io['layer2.0.shortcut.0']['output'] = collapse_totals
module2io['layer2.0.shortcut.1']['output'] = collapse_totals
module2io['layer2.1.conv2']['output'] = collapse_totals
module2io['layer2.1.bn2']['output'] = collapse_totals
module2io['layer2.2.conv2']['output'] = collapse_totals
module2io['layer2.2.bn2']['output'] = collapse_totals

module2io['layer2.1.conv1']['input'] = collapse_totals
module2io['layer2.2.conv1']['input'] = collapse_totals
module2io['layer3.0.conv1']['input'] = collapse_totals
module2io['layer3.0.shortcut.0']['input'] = collapse_totals

In [133]:
collapse_totals

tensor([[2., 1., 1., 1., 0., 2., 0., 0., 1., 2., 1., 1., 1., 2., 0., 0., 0., 1.,
         1., 1., 1., 1., 0., 0., 0., 2., 2., 0., 0., 1., 0., 2., 0., 0., 1., 2.,
         0., 1., 0., 3., 1., 1., 0., 4., 1., 0., 3., 0., 0., 1., 0., 0., 0., 1.,
         2., 1., 0., 1., 1., 3., 0., 0., 0., 1., 2., 1., 0., 2., 1., 0., 2., 1.,
         0., 1., 0., 0., 0., 2., 1., 1., 0., 1., 1., 3., 1., 2., 0., 1., 1., 1.,
         4., 1., 2., 0., 0., 0., 2., 1., 3., 2., 1., 0., 1., 1., 3., 2., 1., 2.,
         2., 0., 0., 1., 1., 3., 1., 1., 1., 1., 0., 1., 3., 1., 1., 0., 2., 0.,
         0., 3.]], device='cuda:0')

In [134]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

perm_map, collapse_totals = get_layer_bipartite_transform(Subnet(model0), Subnet(model1), threshold)
permute_output(perm_map, model_to_alter.layer3[0].conv2, model_to_alter.layer3[0].bn2)
permute_output(perm_map, model_to_alter.layer3[0].shortcut[0], model_to_alter.layer3[0].shortcut[1])
permute_output(perm_map, model_to_alter.layer3[1].conv2, model_to_alter.layer3[1].bn2)
permute_output(perm_map, model_to_alter.layer3[2].conv2, model_to_alter.layer3[2].bn2)

permute_input(perm_map, [model_to_alter.layer3[1].conv1, model_to_alter.layer3[2].conv1])
model_to_alter.linear.weight.data = model_to_alter.linear.weight @ perm_map.t()# [:, perm_map]
# w @ perm_map.t()

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.55it/s]


In [135]:
module2io['layer3.0.conv2']['output'] = collapse_totals
module2io['layer3.0.bn2']['output'] = collapse_totals
module2io['layer3.0.shortcut.0']['output'] = collapse_totals
module2io['layer3.0.shortcut.1']['output'] = collapse_totals
module2io['layer3.1.conv2']['output'] = collapse_totals
module2io['layer3.1.bn2']['output'] = collapse_totals
module2io['layer3.2.conv2']['output'] = collapse_totals
module2io['layer3.2.bn2']['output'] = collapse_totals

module2io['layer3.1.conv1']['input'] = collapse_totals
module2io['layer3.2.conv1']['input'] = collapse_totals
module2io['linear']['input'] = collapse_totals

In [136]:
# perm_map.shape, model1.linear.weight.shape

In [137]:
collapse_totals

tensor([[ 4.,  0.,  8.,  0.,  7.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,
          0.,  1.,  1.,  1.,  0.,  3.,  0.,  3.,  0.,  0.,  0.,  0.,  0.,  0.,
          3.,  0.,  0.,  2.,  0.,  3.,  0.,  1.,  0.,  0.,  3.,  7.,  0.,  1.,
          0.,  0.,  0.,  5.,  3.,  2.,  0.,  0.,  0.,  6.,  0.,  0.,  2.,  0.,
          0.,  0.,  3.,  0.,  2.,  0.,  0.,  0.,  1.,  0.,  1.,  0.,  0.,  0.,
          0.,  1.,  0.,  0.,  0.,  4.,  0.,  0.,  0.,  0.,  3.,  0.,  0.,  0.,
          1.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  2.,  0.,  1.,  0.,  3.,  0.,
          3.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  3.,  0.,  0.,  0.,  1.,
          0.,  0.,  1.,  0.,  2.,  2.,  2.,  0.,  1.,  0.,  3.,  0.,  1.,  0.,
          0.,  2.,  0.,  2.,  1.,  7.,  0.,  1.,  0.,  5.,  0.,  0.,  1.,  0.,
          2.,  0.,  1.,  1.,  2.,  1.,  0.,  4.,  0.,  2.,  0.,  0.,  1.,  1.,
          0.,  1.,  0.,  2.,  4.,  0.,  0.,  0.,  0., 14.,  0.,  1.,  1.,  0.,
          1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  2.

## blocks

In [138]:
class Subnet(nn.Module):
    def __init__(self, model, nb=9):
        super().__init__()
        self.model = model
        self.blocks = []
        self.blocks += list(model.layer1)
        self.blocks += list(model.layer2)
        self.blocks += list(model.layer3)
        self.blocks = nn.Sequential(*self.blocks)
        self.bn1 = model.bn1
        self.conv1 = model.conv1
        self.linear = model.linear
        self.nb = nb
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.blocks[:self.nb](x)
        block = self.blocks[self.nb]
        x = block.conv1(x)
        x = block.bn1(x)
        x = F.relu(x)
        return x

blocks1 = []
blocks1 += list(model_to_alter.layer1)
blocks1 += list(model_to_alter.layer2)
blocks1 += list(model_to_alter.layer3)
blocks1 = nn.Sequential(*blocks1)

In [139]:
block_idx2name = {
    0: 'layer1.0',
    1: 'layer1.1',
    2: 'layer1.2',
    3: 'layer2.0',
    4: 'layer2.1',
    5: 'layer2.2',
    6: 'layer3.0',
    7: 'layer3.1',
    8: 'layer3.2'
}

In [140]:
for nb, (block_idx, layer_name) in zip(range(9), block_idx2name.items()):
    perm_map, collapse_totals = get_layer_bipartite_transform(
        Subnet(model0, nb=nb), Subnet(model1, nb=nb), threshold
    )
    block = blocks1[nb]
    permute_output(perm_map, block.conv1, block.bn1)
    permute_input(perm_map, [block.conv2])
    
    module2io[layer_name + '.conv1']['output'] = collapse_totals
    module2io[layer_name + '.bn1']['output'] = collapse_totals
    module2io[layer_name + '.conv2']['output'] = collapse_totals

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 28.73it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 27.79it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.49it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.64it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.41it/s]
100%|███████████████████████████████████

# Permute Bias

In [90]:
perm_map, collapse_totals = get_layer_bipartite_transform(model0, model1, threshold)
permute_cls_output(perm_map, model1.linear)
module2io['linear']['output'] = collapse_totals

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.60it/s]


## done

In [142]:
save_model(model_to_alter, f'resnet20x4_v2_bipartite_{threshold}threshold')#_change_model1')

## Evaluate the interpolated network

In [199]:
# threshold = 0.0

In [144]:
# whitelist_fn = lambda x: 'bn' not in x
whitelist_fn = lambda x: True

model_a = resnet20(w=4).to(DEVICE) # W_alpha
mix_weights(
    model_a, 
    0.5, 
    'resnet20x4_v1', 
    f'resnet20x4_v2_bipartite_{threshold}threshold',#_change_model1', 
    module2io=module2io,
    whitelist_fn=whitelist_fn
)

print('Pre-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(model_a)/100, evaluate1(model_a)))

reset_bn_stats(model_a)
print('Post-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(model_a)/100, evaluate1(model_a)))

Pre-reset:
Accuracy=16.10%, Loss=2.294
Post-reset:
Accuracy=93.94%, Loss=0.202
