In [1]:
import os
import sys

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

# Accelerator backend used for every model/tensor in this notebook:
# 'mps' on macOS, 'cuda' on every other platform.
# NOTE(review): availability is not checked and there is no CPU fallback.
if platform == 'darwin':
    DEVICE = 'mps'
else:
    DEVICE = 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [35]:
import pdb

In [2]:
def save_model(model, i):
    """Write ``model``'s state dict to the REPAIR checkpoint directory.

    Args:
        model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        i: checkpoint identifier; the file name is ``'%s.pth.tar' % i``.

    Returns:
        None. The only effect is the file written by ``torch.save``.
    """
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    # Build the state dict once and save that very object; previously it was
    # built twice and the first copy was thrown away.
    sd = model.state_dict()
    torch.save(sd, path)

def load_model(model, i):
    """Load the checkpoint named ``'%s.pth.tar' % i`` into ``model`` in place.

    The state dict is read from the REPAIR checkpoint directory and mapped
    onto ``DEVICE`` before being handed to ``model.load_state_dict``.
    """
    # alternative local directory used previously: '/Users/georgestoica/Downloads'
    ckpt_dir = '/srv/share/gstoica3/checkpoints/REPAIR/'
    ckpt_file = os.path.join(ckpt_dir, '%s.pth.tar' % i)
    target = torch.device(DEVICE)
    model.load_state_dict(torch.load(ckpt_file, map_location=target))

In [3]:
# Per-channel CIFAR statistics expressed on the 0-255 pixel scale.
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
# The statistics are divided by 255 before being handed to Normalize
# (the transform runs after ToTensor(), which is expected to output 0-1 values).
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
# Algebraic inverse of `normalize`: Normalize computes (x - mean) / std, so
# mean = -MEAN/STD and std = 255/STD map a normalized value back to pixel/255.
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

# Training pipeline: random flip + padded random crop, then tensor conversion and normalization.
train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
# Evaluation pipeline: no augmentation, only tensor conversion and normalization.
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
# CIFAR-10 is stored under /tmp and downloaded on first use.
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

# The augmented, shuffled train loader is the default data source for the
# correlation and BatchNorm-reset helpers defined later; the test loader is
# the default for the evaluation helpers.
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# counts correct top-1 predictions (divide by the dataset size for accuracy)
def evaluate(model, loader=test_loader):
    """Return the number of samples in ``loader`` that ``model`` classifies correctly.

    The model is switched to eval mode and run under ``no_grad`` and
    ``autocast``. The result is a plain Python int count, not a ratio.
    """
    model.eval()
    n_correct = 0
    with torch.no_grad():
        with autocast():
            for batch_x, batch_y in loader:
                logits = model(batch_x.to(DEVICE))
                hits = batch_y.to(DEVICE) == logits.argmax(dim=1)
                n_correct += hits.sum().item()
    return n_correct

# evaluates loss
def evaluate1(model, loader=test_loader):
    """Return the mean of the per-batch cross-entropy losses of ``model`` on ``loader``.

    The model is switched to eval mode and run under ``no_grad`` and
    ``autocast``. Each batch contributes one scalar loss; the result is the
    unweighted mean of those scalars (a numpy float).
    """
    model.eval()
    losses = []
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            # Call through `nn.functional` rather than the `F` alias: `F` is
            # only imported in a later cell, so relying on it makes this
            # function depend on cell execution order.
            loss = nn.functional.cross_entropy(outputs, labels.to(DEVICE))
            losses.append(loss.item())
    return np.array(losses).mean()

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Module wrapper whose forward pass simply calls the stored callable."""

    def __init__(self, lambd):
        super().__init__()
        # Arbitrary callable applied to the input in forward().
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Residual block: conv-bn-relu, conv-bn, add shortcut, relu.

    When the block changes resolution (``stride != 1``) or width
    (``in_planes != planes``) the shortcut is a 3x3 conv + BatchNorm with the
    same stride; otherwise it is an empty ``nn.Sequential`` (identity).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Sub-modules are created in the order conv1, bn1, conv2, bn2, shortcut
        # so that child/state-dict ordering and RNG consumption stay fixed.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            # Learned projection shortcut (a parameter-free strided/padded
            # variant was considered earlier and dropped).
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))
        return F.relu(residual + self.shortcut(x))


class ResNet(nn.Module):
    """Three-stage residual network with a width multiplier ``w``.

    Layout: 3x3 stem conv + BatchNorm, three stages of ``block`` modules with
    ``w*16``, ``w*32`` and ``w*64`` channels (the last two start with a
    stride-2 block), square average pooling sized by the feature-map width,
    and a linear classifier with ``num_classes`` outputs.
    """

    def __init__(self, block, num_blocks, w=1, num_classes=10):
        super().__init__()
        width1, width2, width3 = w*16, w*32, w*64
        # Running input width consumed and updated by _make_layer.
        self.in_planes = width1

        self.conv1 = nn.Conv2d(3, width1, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width1)
        self.layer1 = self._make_layer(block, width1, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, width2, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, width3, num_blocks[2], stride=2)
        self.linear = nn.Linear(width3, num_classes)

        # Re-initialise every sub-module via the file-level _weights_init hook.
        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the rest stride 1."""
        stage = []
        for block_stride in [stride, *([1] * (num_blocks - 1))]:
            stage.append(block(self.in_planes, planes, block_stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        pooled = F.avg_pool2d(feats, feats.size()[3])
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def resnet20(w=1):
    """Build a ResNet with three stages of three ``BasicBlock``s and width multiplier ``w``."""
    stage_depths = [3, 3, 3]
    return ResNet(BasicBlock, stage_depths, w=w)

In [201]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC, concatenate them along the channel
# axis, and compute the (C0+C1)x(C0+C1) correlation matrix of the joint features
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    """Correlation (or covariance) matrix over the concatenated channels of two nets.

    Both networks are put in eval mode and run on every batch of ``loader``
    for ``epochs`` passes. Per-batch means, standard deviations and second
    moments are averaged over all ``epochs * len(loader)`` batches (so the
    standard deviation is an average of per-batch values, not a pooled one).

    Args:
        net0, net1: modules mapping an image batch to an N x C x ... feature map.
        epochs: number of passes over ``loader``.
        norm: if True, divide the covariance by the outer product of the
            averaged standard deviations (plus 1e-4) to obtain a correlation.
        loader: iterable of ``(images, labels)`` batches; labels are ignored.

    Returns:
        A float64 square matrix of size C0 + C1 whose diagonal is overwritten
        with ``-inf`` (so a channel paired with itself can never win an argmax).

    Raises:
        ValueError: if no batch was processed.
    """
    n = epochs * len(loader)
    mean = std = outer = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for images, _ in tqdm(loader):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2]).double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2]).double()

                out = torch.cat((out0, out1), dim=-1)
                mean_b = out.mean(dim=0)
                std_b = torch.cat((out0.std(dim=0), out1.std(dim=0)), dim=-1)
                outer_b = (out.T @ out) / out.shape[0]

                # Allocate the accumulators exactly once. Keying this on the
                # batch index (i == 0) wiped them at the start of every epoch
                # while still dividing by epochs * len(loader).
                if outer is None:
                    mean = torch.zeros_like(mean_b)
                    std = torch.zeros_like(std_b)
                    outer = torch.zeros_like(outer_b)

                mean += mean_b / n
                std += std_b / n
                outer += outer_b / n

    if outer is None:
        raise ValueError('run_corr_matrix processed no batches (empty loader or epochs < 1)')

    cov = outer - torch.outer(mean, mean)
    if norm:
        cov = cov / (torch.outer(std, std) + 1e-4)
    torch.diagonal(cov)[:] = -torch.inf
    return cov

In [168]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix_og(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    """Cross-correlation (or cross-covariance) between the channels of two nets.

    Both networks are put in eval mode and run on every batch of ``loader``
    for ``epochs`` passes. Per-batch means, standard deviations and cross
    second moments are averaged over all ``epochs * len(loader)`` batches.

    Args:
        net0, net1: modules mapping an image batch to an N x C x ... feature map.
        epochs: number of passes over ``loader``.
        norm: if True, divide by the outer product of the averaged per-net
            standard deviations (plus 1e-4) to obtain a correlation.
        loader: iterable of ``(images, labels)`` batches; labels are ignored.

    Returns:
        A C0 x C1 matrix; entry (i, j) relates channel i of ``net0`` to
        channel j of ``net1``.

    Raises:
        ValueError: if no batch was processed.
    """
    n = epochs * len(loader)
    mean0 = mean1 = std0 = std1 = outer = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for images, _ in tqdm(loader):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2])

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2])

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                # Allocate the accumulators exactly once. Keying this on the
                # batch index (i == 0) wiped them at the start of every epoch
                # while still dividing by epochs * len(loader).
                if outer is None:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n

    if outer is None:
        raise ValueError('run_corr_matrix_og processed no batches (empty loader or epochs < 1)')

    cov = outer - torch.outer(mean0, mean1)
    if norm:
        cov = cov / (torch.outer(std0, std1) + 1e-4)
    return cov

In [252]:
def match_tensors_exact_bipartite_cov(
    covariance,
    r=.5
):
    """Greedily pair the most similar channels and build merge / unmerge matrices.

    ``O = covariance.shape[0]`` channels are reduced to
    ``remainder = int(O * (1 - r))`` output channels: ``bound = O - remainder``
    outputs each average one greedily chosen pair (largest remaining entry of
    the similarity matrix first), and every channel left over keeps an output
    of its own.

    Args:
        covariance: square similarity matrix. It is NOT modified; the greedy
            masking is done on a copy.
        r: fraction of channels to remove.

    Returns:
        ``(merge, unmerge)`` where ``unmerge`` is the O x remainder 0/1
        assignment matrix and ``merge`` is its column-normalised transpose
        (remainder x O).

    Raises:
        ValueError: if more pairs are requested than output channels exist
            (i.e. ``2 * bound > O``).
    """
    O = covariance.shape[0]
    remainder = int(O * (1-r))
    bound = O - remainder
    if bound > remainder:
        raise ValueError(
            'cannot form %d pairs with only %d output channels (O=%d, r=%r)'
            % (bound, remainder, O, r)
        )

    # Work on a contiguous copy so the caller's matrix is left intact.
    sims = covariance.clone(memory_format=torch.contiguous_format)
    n_cols = sims.shape[1]
    permutation_matrix = torch.zeros((O, remainder), device=sims.device)
    # Explicit bookkeeping of consumed channels. Inferring "unused" from rows
    # that still hold a finite value misses a single leftover channel when the
    # diagonal is already -inf.
    used = torch.zeros(O, dtype=torch.bool, device=sims.device)
    for i in range(bound):
        best_idx = sims.view(-1).argmax()
        idx_a = best_idx % n_cols
        idx_b = best_idx // n_cols
        permutation_matrix[idx_a, i] = 1
        permutation_matrix[idx_b, i] = 1
        used[idx_a] = True
        used[idx_b] = True
        # Remove both channels from all further consideration.
        sims[idx_a] = -torch.inf
        sims[idx_b] = -torch.inf
        sims[:, idx_a] = -torch.inf
        sims[:, idx_b] = -torch.inf

    unused = (~used).nonzero().view(-1)
    for i in range(bound, remainder):
        permutation_matrix[unused[i-bound], i] = 1
    merge = permutation_matrix / (permutation_matrix.sum(dim=0, keepdim=True) + 1e-5)
    unmerge = permutation_matrix
    return merge.T, unmerge

def match_tensors_permute(
    covariance
):
    """Optimal one-to-one channel matching turned into merge / unmerge matrices.

    Solves a maximum-weight assignment on ``covariance`` (rows index the first
    network's channels, columns the second's). ``unmerge`` has shape
    (2*O, O): an identity block for the first network stacked on the
    transposed assignment matrix for the second. ``merge`` is ``unmerge``
    with every column divided by its sum (plus 1e-5), returned transposed,
    so each merged channel averages the two matched channels.

    Returns:
        ``(merge.T, unmerge)`` with shapes (O, 2*O) and (2*O, O).
    """
    scores = covariance.cpu().numpy()
    O = scores.shape[0]
    _, col_ind = scipy.optimize.linear_sum_assignment(scores[:O, :O], maximize=True)

    identity = torch.eye(O, device=covariance.device)
    partner = torch.tensor(col_ind).long()
    # Row i of identity[partner] is one-hot at the channel matched to i.
    assignment = identity[partner]
    unmerge = torch.cat((identity, assignment.T), dim=0)

    col_mass = unmerge.sum(dim=0, keepdim=True)
    merge = unmerge / (col_mass + 1e-5)
    return merge.T, unmerge

In [271]:
# computes the cross-correlation between the output channels of net0 and net1
# and converts the best one-to-one channel matching into (merge, unmerge) matrices
def get_layer_perm(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    """Return the ``(merge, unmerge)`` pair aligning ``net1``'s channels with ``net0``'s.

    ``method``, ``vizz`` and ``prune_threshold`` are accepted but not used by
    the current implementation.
    """
    return match_tensors_permute(run_corr_matrix_og(net0, net1))

# merges the output channels of two convolution + batchnorm pairs and stores
# the result in a third pair
def permute_output(merge, convs, bns):
    """Write the output-merged parameters of layers 0 and 1 into layer 2.

    For the conv weight and for the BatchNorm weight, bias, running mean and
    running variance, the tensors of ``[0]`` and ``[1]`` are stacked along
    dim 0 and multiplied from the left by ``merge``; the result replaces the
    ``.data`` of the corresponding tensor of ``[2]``. Elements 0 and 1 are
    only read.

    Args:
        merge: matrix whose second dimension equals the stacked channel count
            (e.g. the first value returned by ``match_tensors_permute``).
        convs: three conv modules ``(source0, source1, target)``.
        bns: three BatchNorm modules in the same order.

    Raises:
        RuntimeError: propagated from torch when ``merge`` does not match the
            stacked channel count (previously this dropped into a debugger).
    """
    pre_weights = [
        (convs[0].weight, convs[1].weight, convs[2].weight),
        (bns[0].weight, bns[1].weight, bns[2].weight),
        (bns[0].bias, bns[1].bias, bns[2].bias),
        (bns[0].running_mean, bns[1].running_mean, bns[2].running_mean),
        (bns[0].running_var, bns[1].running_var, bns[2].running_var)
    ]
    for a, b, c in pre_weights:
        w = torch.cat((a, b), dim=0)
        if w.dim() == 4:
            # conv weight [2*O, I, kH, kW] -> [O', I, kH, kW]
            transform = torch.einsum('ab,bcde->acde', merge, w)
        else:
            # per-channel vectors (BatchNorm affine parameters and statistics)
            transform = merge @ w
        c.data = transform

# merges the input channels of two convolution (or linear) layers and stores
# the result in a third layer
def permute_input(unmerge, after_convs):
    """Write the input-merged weights of layers 0 and 1 into layer 2.

    For every triple ``(source0, source1, target)`` the two source weights are
    concatenated along dim 1 (input channels) and multiplied from the right by
    ``unmerge``; the result replaces ``target.weight.data``.

    Args:
        unmerge: matrix whose first dimension equals the concatenated input
            channel count (e.g. the second value of ``match_tensors_permute``).
        after_convs: list of triples of modules exposing ``.weight``. A
            non-list argument is wrapped in a one-element list and therefore
            treated as a single triple.

    Raises:
        ValueError: if a concatenated weight is neither 4-D nor 2-D. Such a
            weight used to fall through and reuse a stale or undefined result.
    """
    if not isinstance(after_convs, list):
        after_convs = [after_convs]
    for group in after_convs:
        a, b, c = group[0].weight, group[1].weight, group[2].weight
        w = torch.cat((a, b), dim=1)
        if w.dim() == 4:
            # conv weight [O, 2*I, kH, kW] -> [O, I', kH, kW]
            transform = torch.einsum('abcd,be->aecd', w, unmerge)
        elif w.dim() == 2:
            # linear weight [O, 2*I] -> [O, I']
            transform = w @ unmerge
        else:
            raise ValueError(
                'permute_input supports 2-D or 4-D weights, got %d-D' % w.dim()
            )
        c.data = transform

# def permute_cls_output(perm_map, linear):
#     for w in [linear.weight, linear.bias]:
#         w.data = perm_map @ w

def combine_layers(elements, output_align, input_align=None):
    """Merge pairs of tensors into a third one, aligning outputs and optionally inputs.

    For every triple ``(a, b, c)`` in ``elements``:

    * without ``input_align``, ``a`` and ``b`` are stacked along dim 0;
    * with ``input_align`` (and tensors that have an input dimension), ``a``
      and ``b`` are placed on the diagonal of a block matrix of shape
      [2*O, 2*I, ...] whose input dimension is then reduced by ``input_align``;
    * the stacked tensor is finally multiplied from the left by
      ``output_align`` and written to ``c.data``.

    ``a`` and ``b`` are assumed to have identical shapes.

    Args:
        elements: iterable of ``(a, b, c)`` tensor triples; ``c`` is overwritten.
        output_align: merge matrix applied to the stacked output dimension.
        input_align: optional unmerge matrix applied to the stacked input
            dimension. 1-D tensors have no input dimension and skip this step.
    """
    for (a, b, c) in elements:
        if input_align is not None and a.dim() >= 2:
            zeros = torch.zeros_like(a)
            ab = torch.cat(
                (
                    torch.cat((a, zeros), dim=0),
                    torch.cat((zeros, b), dim=0)
                ), dim=1
            ) # [2O, 2I, H, W]
            if ab.dim() == 4:
                ab = torch.einsum('abcd,be->aecd', ab, input_align)
            else:
                ab = ab @ input_align
        else:
            ab = torch.cat((a, b), dim=0)
        if ab.dim() == 4:
            transform = torch.einsum('ab,bcde->acde', output_align, ab)
        else:
            transform = output_align @ ab
        # Assign inside the loop so every triple is updated, not only the last.
        c.data = transform
    

# Find Bipartite Permutation

In [272]:
# Three width-4 ResNet-20s on DEVICE. model0 and model1 are restored from
# checkpoints below; modelc keeps its fresh initialisation here and is the
# network the merged weights are written into by the following cells.
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
modelc = resnet20(w=4).to(DEVICE)
load_model(model0, 'resnet20x4_v1')
load_model(model1, 'resnet20x4_v2')

# Number of correctly classified test samples for each source model.
evaluate(model0), evaluate(model1)

(9536, 9510)

In [255]:
class Subnet(nn.Module):
    """Truncated view of ``model``: stem conv/bn/relu followed by ``layer1``."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        net = self.model
        stem = F.relu(net.bn1(net.conv1(x)))
        return net.layer1(stem)
# Stage 1: align the residual stream that runs through layer1.
# Order used by every helper call below: (source model0, source model1, target modelc).
models = [model0, model1, modelc]
# Matching is computed on the activations after layer1 (see Subnet above).
merge, unmerge = get_layer_perm(Subnet(model0), Subnet(model1))
# Every layer that WRITES to this residual stream gets its outputs merged:
# the stem conv/bn ...
permute_output(
    merge, 
    [m.conv1 for m in models], 
    [m.bn1 for m in models]
)

# ... and the second conv/bn of each of the three layer1 blocks.
permute_output(
    merge, 
    [m.layer1[0].conv2 for m in models], 
    [m.layer1[0].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer1[1].conv2 for m in models], 
    [m.layer1[1].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer1[2].conv2 for m in models], 
    [m.layer1[2].bn2 for m in models]
)
# Every layer that READS from this residual stream gets its inputs unmerged:
# the first conv of each layer1 block ...
permute_input(
    unmerge, 
    [
        [m.layer1[0].conv1 for m in models], 
        [m.layer1[1].conv1 for m in models], 
        [m.layer1[2].conv1 for m in models]
    ]
)
# ... and both entry points of layer2 (first conv and projection shortcut).
permute_input(
    unmerge, 
    [
        [m.layer2[0].conv1 for m in models], 
        [m.layer2[0].shortcut[0] for m in models]
    ]
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.06it/s]


In [256]:
# (merge[:128] * 2 - torch.eye(128, device=merge.device)).abs().sum()

In [257]:
# Inspect row 17 of the right-hand half of `merge` (columns 64 onward, i.e. the
# model1 block when `merge` comes from the 64-channel layer1 stage — TODO confirm):
# the non-zero column is the model1 channel averaged into merged channel 17.
merge[:, 64:][17].nonzero()

tensor([[22]], device='cuda:0')

In [258]:
class Subnet(nn.Module):
    """Truncated view of ``model``: stem conv/bn/relu, ``layer1`` and ``layer2``."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        net = self.model
        feats = F.relu(net.bn1(net.conv1(x)))
        for stage in (net.layer1, net.layer2):
            feats = stage(feats)
        return feats

# Stage 2: align the residual stream that runs through layer2, using the
# activations after layer2 (see Subnet above). `models` is the
# (model0, model1, modelc) list defined in the layer1 cell.
merge, unmerge = get_layer_perm(Subnet(model0), Subnet(model1))
# Writers of the layer2 stream: conv2/bn2 of every block plus the projection
# shortcut (conv + bn) of the first block.
permute_output(
    merge, 
    [m.layer2[0].conv2 for m in models], 
    [m.layer2[0].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer2[0].shortcut[0] for m in models], 
    [m.layer2[0].shortcut[1] for m in models]
)
permute_output(
    merge, 
    [m.layer2[1].conv2 for m in models], 
    [m.layer2[1].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer2[2].conv2 for m in models], 
    [m.layer2[2].bn2 for m in models]
)

# Readers of the layer2 stream: conv1 of the remaining layer2 blocks ...
permute_input(
    unmerge, 
    [
        [m.layer2[1].conv1 for m in models], 
        [m.layer2[2].conv1 for m in models]
    ]
)
# ... and both entry points of layer3 (first conv and projection shortcut).
permute_input(
    unmerge, 
    [
        [m.layer3[0].conv1 for m in models], 
        [m.layer3[0].shortcut[0] for m in models]
    ]
)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.15it/s]


In [264]:
# Row 17 of the right-hand half of `merge` (columns 128 onward), scaled by 2 so
# the ~0.5 averaging weight reads as ~1; the non-zero column is the model1
# channel paired with merged channel 17.
(merge[:, 128:]*2)[17].nonzero()

tensor([[124]], device='cuda:0')

In [267]:
# Row 17 of the lower half of `unmerge` (rows 128 onward): the non-zero column
# is the merged channel that model1's channel 17 is reconstructed from.
(unmerge[128:])[17].nonzero()

tensor([[77]], device='cuda:0')

In [268]:
# Shape check: match_tensors_permute stacks an identity block on the transposed
# assignment block, so `unmerge` is (2*C, C).
unmerge.shape

torch.Size([256, 128])

In [269]:
# Shape of a layer2 conv weight in a source model, for comparison with the
# merge / unmerge dimensions above.
model1.layer2[0].conv2.weight.shape

torch.Size([128, 128, 3, 3])

In [222]:
class Subnet(nn.Module):
    """Truncated view of ``model``: stem conv/bn/relu and all three stages (no pooling / classifier)."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        net = self.model
        feats = F.relu(net.bn1(net.conv1(x)))
        for stage in (net.layer1, net.layer2, net.layer3):
            feats = stage(feats)
        return feats

# Stage 3: align the residual stream that runs through layer3, using the
# activations after layer3 (see Subnet above). `models` is the
# (model0, model1, modelc) list defined in the layer1 cell.
merge, unmerge = get_layer_perm(Subnet(model0), Subnet(model1))
# Writers of the layer3 stream: conv2/bn2 of every block plus the projection
# shortcut (conv + bn) of the first block.
permute_output(
    merge, 
    [m.layer3[0].conv2 for m in models], 
    [m.layer3[0].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer3[0].shortcut[0] for m in models], 
    [m.layer3[0].shortcut[1] for m in models]
)
permute_output(
    merge, 
    [m.layer3[1].conv2 for m in models], 
    [m.layer3[1].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer3[2].conv2 for m in models], 
    [m.layer3[2].bn2 for m in models]
)

# Readers of the layer3 stream: conv1 of the remaining layer3 blocks ...
permute_input(
    unmerge, 
    [
        [m.layer3[1].conv1 for m in models], 
        [m.layer3[2].conv1 for m in models]
    ]
)
# ... and the classifier: both source weight matrices are concatenated along
# the input dimension and reduced by `unmerge`.
# NOTE(review): only linear.weight is written; modelc.linear.bias keeps the
# value it got at construction — confirm whether the source biases should be
# combined as well.
modelc.linear.weight.data = torch.cat(
    (
        model0.linear.weight,
        model1.linear.weight
    ), dim=-1
) @ unmerge # [:, perm_map]
# w @ perm_map.t()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.30it/s]


In [223]:
class Subnet(nn.Module):
    """View of ``model`` that stops inside residual block ``nb``.

    All blocks of ``layer1``, ``layer2`` and ``layer3`` are flattened into one
    ``nn.Sequential``. The forward pass runs the stem, the first ``nb``
    blocks, and then only conv1 -> bn1 -> relu of block ``nb``.

    NOTE(review): ``self.blocks[self.nb]`` must exist, so the default
    ``nb=9`` is out of range for a model with nine blocks — callers pass
    ``nb`` explicitly.
    """

    def __init__(self, model, nb=9):
        super().__init__()
        self.model = model
        stages = (model.layer1, model.layer2, model.layer3)
        self.blocks = nn.Sequential(*[blk for stage in stages for blk in stage])
        self.bn1 = model.bn1
        self.conv1 = model.conv1
        self.linear = model.linear
        self.nb = nb

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.blocks[:self.nb](x)
        target = self.blocks[self.nb]
        return F.relu(target.bn1(target.conv1(x)))

# Flatten the three stages of each model into a single nn.Sequential so that
# every residual block can be addressed by one global index
# (layer1 blocks first, then layer2, then layer3).
blocks0 = nn.Sequential(*model0.layer1, *model0.layer2, *model0.layer3)

blocks1 = nn.Sequential(*model1.layer1, *model1.layer2, *model1.layer3)

blocksc = nn.Sequential(*modelc.layer1, *modelc.layer2, *modelc.layer3)

In [224]:
# Flat block index (0-8) -> 'layer<stage>.<position>' name, three blocks per stage.
block_idx2name = {
    3 * stage + pos: 'layer%d.%d' % (stage + 1, pos)
    for stage in range(3)
    for pos in range(3)
}

In [225]:
# Align the block-internal channels (output of conv1/bn1, input of conv2) of
# every residual block and write the merged weights into modelc's block.
# `block_idx` and `layer_name` are unpacked but not used; only `nb` drives the loop.
for nb, (block_idx, layer_name) in zip(range(9), block_idx2name.items()):
    # Subnet(..., nb=nb) yields the post-ReLU activations of block nb's first conv.
    merge, unmerge = get_layer_perm(Subnet(model0, nb=nb), Subnet(model1, nb=nb))
    block0 = blocks0[nb]
    block1 = blocks1[nb]
    blockc = blocksc[nb]
    # (source0, source1, target) order expected by permute_output / permute_input.
    blocks = [block0, block1, blockc]
    # Writer of the internal channels: conv1 + bn1.
    permute_output(
        merge, 
        [m.conv1 for m in blocks], 
        [m.bn1 for m in blocks]
    )
    # Reader of the internal channels: conv2.
    permute_input(
        unmerge, 
        [
            [m.conv2 for m in blocks]
        ]
    )

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 26.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 25.25it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 17.34it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.08it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:07<00:00, 13.68it/s]
100%

In [226]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    """Recompute the BatchNorm2d running statistics of ``model`` from ``loader``.

    Every module whose type is exactly ``nn.BatchNorm2d`` gets
    ``momentum = None`` (cumulative average) and has its running statistics
    reset. The model is then put in train mode and run forward over
    ``loader`` for ``epochs`` passes under ``no_grad`` and ``autocast``.
    The model is left in train mode and the momentum change is permanent.
    """
    # resetting stats to baseline first is necessary for stability
    bn_layers = [m for m in model.modules() if type(m) == nn.BatchNorm2d]
    for bn in bn_layers:
        bn.momentum = None  # use simple average
        bn.reset_running_stats()

    # forward passes in train mode update the running statistics; outputs are discarded
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                model(images.to(DEVICE))

In [227]:
# Evaluate the merged model before and after recomputing its BatchNorm statistics.
# evaluate() returns a count of correct predictions, so dividing by 100 gives a
# percentage only for a 10,000-sample test set — TODO confirm the loader size.
print('Pre-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(modelc)/100, evaluate1(modelc)))

# The merged weights invalidate the stored running statistics; recompute them
# on the augmented training data.
reset_bn_stats(modelc)
print('Post-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(modelc)/100, evaluate1(modelc)))

Pre-reset:
Accuracy=11.58%, Loss=3.032
Post-reset:
Accuracy=76.03%, Loss=1.107
