In [1]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform
from collections import defaultdict

DEVICE = 'mps' if platform == 'darwin' else 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR100(
    root='/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python', 
    train=True,
    download=True, transform=train_transform
)
test_dset = torchvision.datasets.CIFAR100(
    root='/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python',
    train=False,
    download=True, 
    transform=test_transform
)

model1_classes= np.arange(50) # np.array([3, 2, 0, 6, 4])
model2_classes = np.arange(50, 100) # np.array([5, 7, 9, 8, 1])

valid_examples1 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model1_classes]
valid_examples2 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model2_classes]

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


50000it [00:13, 3741.44it/s]
50000it [00:13, 3773.18it/s]
10000it [00:01, 5689.69it/s]
10000it [00:01, 5744.98it/s]


In [3]:
# test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
# test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

In [4]:
# test_loader1 = torch.utils.data.DataLoader(
#     torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
# )
# test_loader2 = torch.utils.data.DataLoader(
#     torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
# )

In [5]:
model1_classes, model2_classes

(array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
        17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
        34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]),
 array([50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
        67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,
        84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]))

In [3]:
class_idxs = np.zeros(100, dtype=int)
class_idxs[model1_classes] = np.arange(50)
class_idxs[model2_classes] = np.arange(50)
class_idxs = torch.from_numpy(class_idxs)
class_idxs

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47, 48, 49])

In [4]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
#             self.shortcut = LambdaLayer(lambda x:
#                                         F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, w=1, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = w*16

        self.conv1 = nn.Conv2d(3, w*16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w*16)
        self.layer1 = self._make_layer(block, w*16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, w*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, w*64, num_blocks[2], stride=2)
        self.linear = nn.Linear(w*64, 512)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(w=1):
    return ResNet(BasicBlock, [3, 3, 3], w=w)

In [6]:
def train(save_key, model, train_loader, test_loader, class_vectors, remap_class_idxs):
    optimizer = SGD(model.parameters(), lr=0.4, momentum=0.9, weight_decay=5e-4)
    # optimizer = Adam(model.parameters(), lr=0.05)
    
    # Adam seems to perform worse than SGD for training ResNets on CIFAR-10.
    # To make Adam work, we find that we need a very high learning rate: 0.05 (50x the default)
    # At this LR, Adam gives 1.0-1.5% worse accuracy than SGD.
    
    # It is not yet clear whether the increased interpolation barrier for Adam-trained networks
    # is simply due to the increased test loss of said networks relative to those trained with SGD.
    # We include the option of using Adam in this notebook to explore this question.

    EPOCHS = 100
    ne_iters = len(train_loader)
    lr_schedule = np.interp(np.arange(1+EPOCHS*ne_iters), [0, 5*ne_iters, EPOCHS*ne_iters], [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()
    
    losses = []
    for _ in tqdm(range(EPOCHS)):
        for i, (inputs, labels) in enumerate(train_loader):
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                encodings = model(inputs.to(DEVICE))
                normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
                logits = (100.0 * normed_encodings @ class_vectors.T)
                remapped_labels = remap_class_idxs[labels].to(DEVICE)
                loss = loss_fn(logits, remapped_labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            losses.append(loss.item())
    print(evaluate(
        model, test_loader, 
        class_vectors=class_vectors, 
        remap_class_idxs=remap_class_idxs
    ))
    save_model(model, save_key)

In [7]:
# evaluates accuracy
def evaluate(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    confusion = np.zeros((100, 100))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                correct += (labels.to(DEVICE) == pred).sum().item()
            confusion[labels.cpu().numpy(), pred.cpu().numpy()] += 1
            total += inputs.shape[0]
    if return_confusion:
        return correct / total, confusion / confusion.sum(-1, keepdims=True)
    else:
        return correct / total
# evaluates loss
def evaluate1(model, loader, class_vectors, remap_class_idxs):
    model.eval()
    losses = []
    pdb.set_trace()
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            loss = F.cross_entropy(outputs, remap_class_idxs[labels].to(DEVICE))
            losses.append(loss.item())
    return np.array(losses).mean()

In [26]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader, interleave=False):
    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out1 = net1(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                
                if interleave:
                    A1, A2 = out0.chunk(2, dim=0)
                    B1, B2 = out1.chunk(2, dim=0)
                    out0 = torch.cat((A1, B1), dim=0)
                    out1 = torch.cat((A2, B2), dim=0)
                
                out0 = out0.reshape(-1, out0.shape[2]).double()
                out1 = out1.reshape(-1, out1.shape[2]).double()

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                
                outer_b = (out0.T @ out1) / out0.shape[0]

                if i == 0:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n
                
    cov = outer - torch.outer(mean0, mean1)
    if cov.isnan().sum() > 0: pdb.set_trace()
    if norm:
        corr = cov / (torch.outer(std0, std1) + 1e-4)
        return corr.to(torch.float32)
    else:
        return cov.to(torch.float32)

def interleave_tensors(a, b):
    A1, A2 = a.chunk(2, dim=0)
    B1, B2 = b.chunk(2, dim=0)
    return (
        torch.cat((A1, B1), dim=0), 
        torch.cat((A2, B2), dim=0)
    )
    
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, module2tuple, interleave=False):
    if 'conv' in module2tuple:
        for (a_conv, b_conv) in module2tuple['conv']:
            a_weight = a_conv.weight
            b_weight = b_conv.weight
            if interleave:
                a_weight, b_weight = interleave_tensors(a_weight, b_weight)
            b_weight = (torch.einsum('ab,bcde->acde', perm_map, b_weight) + a_weight) / 2.
            b_conv.weight.data = b_weight
            a_conv.weight.data = b_weight
            
    if 'bn' in module2tuple:
        for (a_bn, b_bn) in module2tuple['bn']:
            a_weight, a_bias, a_mean, a_var = a_bn.weight, a_bn.bias, a_bn.running_mean, a_bn.running_var
            b_weight, b_bias, b_mean, b_var = b_bn.weight, b_bn.bias, b_bn.running_mean, b_bn.running_var
            if interleave:
                a_weight, b_weight = interleave_tensors(a_weight, b_weight)
                a_bias, b_bias = interleave_tensors(a_bias, b_bias)
                a_mean, b_mean = interleave_tensors(a_mean, b_mean)
                a_var, b_var = interleave_tensors(a_var, b_var)
                
            b_bn.weight.data = a_bn.weight.data = (b_weight @ perm_map.t() + a_weight) / 2.
            b_bn.bias.data = a_bn.bias.data = (b_bias @ perm_map.t() + a_bias) / 2.
            b_bn.running_var.data = a_bn.running_var.data = (b_var @ perm_map.t() + a_var) / 2.
            b_bn.running_mean.data = a_bn.running_mean.data = (b_mean @ perm_map.t() + a_mean) / 2.
#             c_bn.num_batches_tracked = (a_bn.num_batches_tracked + b_bn.num_batches_tracked) // 2

# modifies the weight matrix of a convolution layer for a given
# permutation of the input channels
def permute_input(perm_map, conv_triples, interleave=False):
    if not isinstance(conv_triples, list):
        conv_triples = [conv_triples]
    post_weights = [(c[0].weight, c[1].weight) for c in conv_triples]
    for (a, b) in post_weights:
        if interleave:
            a, b = interleave_tensors(a, b)
        if len(a.shape) == 4:
            transform = (torch.einsum('abcd,be->aecd', b, perm_map.t()))# + a) / 2.
        elif len(w.shape) == 2:
            transform = (b @ perm_map.t())# + a) / 2.
        b.data = transform
        

def permute_cls_output(perm_map, linear):
    for w in [linear.weight, linear.bias]:
        w.data = perm_map @ w

In [114]:
def strip_param_suffix(name):
    for param in ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']:
        name = name.replace('.' + param, '')
    return name
#     return name.replace('.weight', '').replace('.bias', '').replace('.running_mean', '').replace('.')

def check_similarities(model0, model1, whitelist_layers=None):
    true_keys = []
    false_keys = []
    true_modules = []
    false_modules = []
    
    for key, param in model0.state_dict().items():
        module = strip_param_suffix(key)
        if torch.allclose(param, model1.state_dict()[key]):
            true_keys += [key]
            true_modules += [module] if module not in true_modules else []
        else:
            false_keys += [key]
            false_modules += [module] if module not in false_modules else []
#     print('------- Aligned Keys -------')
#     print(true_keys)
#     print(true_modules)
#     print('----------------------------')
    print('------ Unaligned Keys ------')
    print(false_keys)
    print(false_modules)
    return true_keys, false_keys

In [9]:
def transform_model(
    model0, 
    model1,
    model_merge,
    transform_fn, 
    prune_threshold=-torch.inf, 
    module2io=defaultdict(lambda: dict()),
    interleave=False,
    check_model=None
):
    class Subnet(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
        def forward(self, x):
            self = self.model
            x = F.relu(self.bn1(self.conv1(x)))
            x = self.layer1(x)
            return x
    models = [model0, model1]#, model1]
    perm_map, collapse_totals = transform_fn(Subnet(model0), Subnet(model1), interleave=interleave)
    permute_output(
        perm_map,
        {
            'conv': [
                [model.conv1 for model in  models],
                [model.layer1[0].conv2 for model in models],
                [model.layer1[1].conv2 for model in models],
                [model.layer1[2].conv2 for model in models],
            ],
            'bn': [
                [model.bn1 for model in models],
                [model.layer1[0].bn2 for model in models],
                [model.layer1[1].bn2 for model in models],
                [model.layer1[2].bn2 for model in models],
            ]
        },
        interleave=interleave
    )
    permute_input(
        perm_map, 
        [
            [model.layer1[0].conv1 for model in models],
            [model.layer1[1].conv1 for model in models],
            [model.layer1[2].conv1 for model in models],
            [model.layer2[0].conv1 for model in models],
            [model.layer2[0].shortcut[0] for model in models]
        ],
        interleave=interleave
    )
    
    module2io['conv1']['output'] = collapse_totals
    module2io['bn1']['output'] = collapse_totals
    module2io['layer1.0.conv2']['output'] = collapse_totals
    module2io['layer1.0.bn2']['output'] = collapse_totals
    module2io['layer1.1.conv2']['output'] = collapse_totals
    module2io['layer1.1.bn2']['output'] = collapse_totals
    module2io['layer1.2.conv2']['output'] = collapse_totals
    module2io['layer1.2.bn2']['output'] = collapse_totals

    module2io['layer1.0.conv1']['input'] = collapse_totals
    module2io['layer1.1.conv1']['input'] = collapse_totals
    module2io['layer1.2.conv1']['input'] = collapse_totals
    module2io['layer2.0.conv1']['input'] = collapse_totals
    module2io['layer2.0.shortcut.0']['input'] = collapse_totals
    reset_bn_stats(model0)
    reset_bn_stats(model1)
    class Subnet(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
        def forward(self, x):
            self = self.model
            x = F.relu(self.bn1(self.conv1(x)))
            x = self.layer1(x)
            x = self.layer2(x)
            return x

    perm_map, collapse_totals = transform_fn(Subnet(model0), Subnet(model1), interleave=interleave)

    permute_output(
        perm_map,
        {
            'conv': [
                [model.layer2[0].conv2 for model in models],
                [model.layer2[0].shortcut[0] for model in models],
                [model.layer2[1].conv2 for model in models],
                [model.layer2[2].conv2 for model in models]
            ],
            'bn': [
                [model.layer2[0].bn2 for model in models],
                [model.layer2[0].shortcut[1] for model in models],
                [model.layer2[1].bn2 for model in models],
                [model.layer2[2].bn2 for model in models],
            ]
        },
        interleave=interleave
    )
    permute_input(
        perm_map,
        [
            [model.layer2[1].conv1 for model in models],
            [model.layer2[2].conv1 for model in models],
            [model.layer3[0].conv1 for model in models],
            [model.layer3[0].shortcut[0] for model in models]
        ],
        interleave=interleave
    )
    reset_bn_stats(model0)
    reset_bn_stats(model1)
    module2io['layer2.0.conv2']['output'] = collapse_totals
    module2io['layer2.0.bn2']['output'] = collapse_totals
    module2io['layer2.0.shortcut.0']['output'] = collapse_totals
    module2io['layer2.0.shortcut.1']['output'] = collapse_totals
    module2io['layer2.1.conv2']['output'] = collapse_totals
    module2io['layer2.1.bn2']['output'] = collapse_totals
    module2io['layer2.2.conv2']['output'] = collapse_totals
    module2io['layer2.2.bn2']['output'] = collapse_totals

    module2io['layer2.1.conv1']['input'] = collapse_totals
    module2io['layer2.2.conv1']['input'] = collapse_totals
    module2io['layer3.0.conv1']['input'] = collapse_totals
    module2io['layer3.0.shortcut.0']['input'] = collapse_totals
    
    class Subnet(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model
        def forward(self, x):
            self = self.model
            x = F.relu(self.bn1(self.conv1(x)))
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            return x

    perm_map, collapse_totals = transform_fn(Subnet(model0), Subnet(model1), interleave=interleave)

    permute_output(
        perm_map,
        {
            'conv': [
                [model.layer3[0].conv2 for model in models],
                [model.layer3[0].shortcut[0] for model in models],
                [model.layer3[1].conv2 for model in models],
                [model.layer3[2].conv2 for model in models]
            ],
            'bn': [
                [model.layer3[0].bn2 for model in models],
                [model.layer3[0].shortcut[1] for model in models],
                [model.layer3[1].bn2 for model in models],
                [model.layer3[2].bn2 for model in models]
            ]
        },
        interleave=interleave
    )
    permute_input(
        perm_map,
        [
            [model.layer3[1].conv1 for model in models],
            [model.layer3[2].conv1 for model in models],
        ],
        interleave=interleave
    )
    
    model_merge.linear.weight.data = (model1.linear.weight @ perm_map.t())# + model0.linear.weight) / 2.
#     model_merge.linear.bias.data = (model1.linear.bias)# + model0.linear.bias) / 2.
    reset_bn_stats(model0)
    reset_bn_stats(model1)
    module2io['layer3.0.conv2']['output'] = collapse_totals
    module2io['layer3.0.bn2']['output'] = collapse_totals
    module2io['layer3.0.shortcut.0']['output'] = collapse_totals
    module2io['layer3.0.shortcut.1']['output'] = collapse_totals
    module2io['layer3.1.conv2']['output'] = collapse_totals
    module2io['layer3.1.bn2']['output'] = collapse_totals
    module2io['layer3.2.conv2']['output'] = collapse_totals
    module2io['layer3.2.bn2']['output'] = collapse_totals

    module2io['layer3.1.conv1']['input'] = collapse_totals
    module2io['layer3.2.conv1']['input'] = collapse_totals
    module2io['linear']['input'] = collapse_totals

    class Subnet(nn.Module):
        def __init__(self, model, nb=9):
            super().__init__()
            self.model = model
            self.blocks = []
            self.blocks += list(model.layer1)
            self.blocks += list(model.layer2)
            self.blocks += list(model.layer3)
            self.blocks = nn.Sequential(*self.blocks)
            self.bn1 = model.bn1
            self.conv1 = model.conv1
            self.linear = model.linear
            self.nb = nb

        def forward(self, x):
            x = F.relu(self.bn1(self.conv1(x)))
            x = self.blocks[:self.nb](x)
            block = self.blocks[self.nb]
            x = block.conv1(x)
            x = block.bn1(x)
            x = F.relu(x)
            return x
    
#     blocks1 = []
#     blocks1 += list(model1.layer1)
#     blocks1 += list(model1.layer2)
#     blocks1 += list(model1.layer3)
#     blocks1 = nn.Sequential(*blocks1)
    
    block_idx2name = {
        0: ('layer1.0', [model.layer1[0] for model in models]),
        1: ('layer1.1', [model.layer1[1] for model in models]),
        2: ('layer1.2', [model.layer1[2] for model in models]),
        3: ('layer2.0', [model.layer2[0] for model in models]),
        4: ('layer2.1', [model.layer2[1] for model in models]),
        5: ('layer2.2', [model.layer2[2] for model in models]),
        6: ('layer3.0', [model.layer3[0] for model in models]),
        7: ('layer3.1', [model.layer3[1] for model in models]),
        8: ('layer3.2', [model.layer3[2] for model in models]),
    }

    for nb, (block_idx, (layer_name, layers)) in zip(range(9), block_idx2name.items()):
        perm_map, collapse_totals = transform_fn(
            Subnet(model0, nb=nb), 
            Subnet(model1, nb=nb), 
            interleave=interleave
        )
        # block = blocks1[nb]
        permute_output(
            perm_map,
            {
                'conv': [
                    [layer.conv1 for layer in layers],
                ],
                'bn': [
                    [layer.bn1 for layer in layers]
                ]
            },
            interleave=interleave
        )
        permute_input(
            perm_map,
            [
                [layer.conv2 for layer in layers]
            ],
            interleave=interleave
        )
        module2io[layer_name + '.conv1']['output'] = collapse_totals
        module2io[layer_name + '.bn1']['output'] = collapse_totals
        module2io[layer_name + '.conv2']['output'] = collapse_totals
        reset_bn_stats(model0)
        reset_bn_stats(model1)
#     _ = check_similarities(model_merge, model1)
    return model_merge, module2io


In [10]:
import clip

In [11]:
print(test_dset.classes)

['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain', 'mouse', 'mushroom', 'oak_tree', 'orange', 'orchid', 'otter', 'palm_tree', 'pear', 'pickup_truck', 'pine_tree', 'plain', 'plate', 'poppy', 'porcupine', 'possum', 'rabbit', 'raccoon', 'ray', 'road', 'rocket', 'rose', 'sea', 'seal', 'shark', 'shrew', 'skunk', 'skyscraper', 'snail', 'snake', 'spider', 'squirrel', 'streetcar', 'sunflower', 'sweet_pepper', 'table', 'tank', 'telephone', 'television', 'tiger', 'tractor', 'train', 'trout', 'tulip', 'turtle', 'wardrobe', 'whale', 'willow_tree',

In [14]:
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)

In [15]:
clip_model, preprocess = clip.load('ViT-B/32', DEVICE)

In [16]:
with torch.no_grad():
    text_features = clip_model.encode_text(text_inputs)

In [17]:
text_features.shape

torch.Size([100, 512])

In [18]:
text_features /= text_features.norm(dim=-1, keepdim=True)

In [19]:
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]
# class_vecs1 /= class_vecs1.norm(dim=-1, keepdim=True)
# class_vecs2 /= class_vecs2.norm(dim=-1, keepdim=True)

In [36]:
print(model1_classes)

[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49]


In [37]:
os.listdir('/srv/share/gstoica3/checkpoints/REPAIR/')

['resnet20x4_v5.pth.tar',
 'resnet20x4_v1.pth.tar',
 'resnet20x4_CIFAR5_perm1_-infthreshold_new.pth.tar',
 'resnet20x4_CIFAR5_procrustes_greedy.pth.tar',
 'resnet20x4_v1_perm1_-infthreshold.pth.tar',
 'resnet20x4_CIFAR5_bipartite_-infthreshold_new.pth.tar',
 'resnet20x4_CIFAR5_clses[5, 7, 9, 8, 1].pth.tar',
 'resnet20x3_v1.pth.tar',
 'resnet20x4_v4_perm1_conv1_0param.pth.tar',
 'resnet20x4_CIFAR5_procrustes_-infthreshold.pth.tar',
 'resnet20x4_v5_perm1_conv1_5param.pth.tar',
 'resnet20x4_CIFAR50_perm1_-infthreshold_new2.pth.tar',
 'resnet20x4_v4_perm1_conv1_1param.pth.tar',
 'resnet20x4_CIFAR50_perm1_-infthreshold_new.pth.tar',
 'resnet20x4_v4_perm1_conv1_3param.pth.tar',
 'resnet20x4_v4_perm1_conv1_2param.pth.tar',
 'resnet20x4_CIFAR50_clses[50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99].pth.tar',
 'resnet20x4_v4.pth.tar',
 'resnet20x

In [25]:
if not os.path.exists(
    os.path.join(
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}.pth.tar'
    )
):
    print('training model...')
    model1 = resnet20(w=4).to(DEVICE)
    train(
        f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}', 
        model=model1, 
        class_vectors=class_vecs1,
        train_loader=train_aug_loader1,
        test_loader=test_loader1,
        remap_class_idxs=class_idxs
    )
if not os.path.exists(
    os.path.join(
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}.pth.tar'
    )89
):
    print('training model...')
    model2 = resnet20(w=4).to(DEVICE)
    train(
        f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}', 
        model=model2, 
        class_vectors=class_vecs2,
        train_loader=train_aug_loader2,
        test_loader=test_loader2,
        remap_class_idxs=class_idxs
    )
      

# Combine models and evaluate performance

In [20]:
def strip_param_suffix(name):
    return name.replace('.weight', '').replace('.bias', '')

def combine_io_masks(io, param):
    mask = torch.zeros_like(param, device=param.device)
    try:
        if 'output' in io:
            mask[io['output'].view(-1) == 0] = 1.
        if 'input' in io and len(mask.shape) > 1:
            mask[:, io['input'].view(-1) == 0] = 1.
    except:
        pdb.set_trace()
    return mask

def mix_weights(model, alpha, key0, key1, module2io=None, whitelist_fn=lambda x: True):
    sd0 = torch.load(
        '/srv/share/gstoica3/checkpoints/REPAIR/%s.pth.tar' % key0, 
        map_location=torch.device(DEVICE)
    )
    sd1 = torch.load(
        '/srv/share/gstoica3/checkpoints/REPAIR/%s.pth.tar' % key1, 
        map_location=torch.device(DEVICE)
    )
    sd_alpha = {}
    for k in sd0.keys():
        param0 = sd0[k].to(DEVICE)
        param1 = sd1[k].to(DEVICE)
        sd_alpha[k] = (1 - alpha) * param0 + alpha * param1
        
        if module2io is not None:
            param_base = strip_param_suffix(k)
#             pdb.set_trace()
            mask = combine_io_masks(module2io[param_base], param1)
            sd_alpha[k][mask == 1] = param0[mask == 1].to(sd_alpha[k].dtype)
        
        if not whitelist_fn(k):
            sd_alpha[k] = param0
    model.load_state_dict(sd_alpha)

In [24]:
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}'
)

In [27]:
_, confusion = evaluate_texthead(avg_model, test_loader, class_vectors=text_features, return_confusion=True) 
print(np.diag(confusion).round(3).tolist())

[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]


In [30]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.to(DEVICE))

In [31]:
reset_bn_stats(avg_model)
print('Post-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)

Post-reset:
0.4192


# Combine models via permutation

In [21]:
def compute_perm_map(corr_mtx):
    # sort the (i, j) channel pairs by correlation
    nchan = corr_mtx.shape[0]
    triples = [(i, j, corr_mtx[i, j].item()) for i in range(nchan) for j in range(nchan)]
    triples = sorted(triples, key=lambda p: -p[2])
    # greedily find a matching
    perm_d = {}
    for i, j, c in triples:
        if not (i in perm_d.keys() or j in perm_d.values()):
            perm_d[i] = j
    perm_map = torch.tensor([perm_d[i] for i in range(nchan)])

    # qual_map will be a permutation of the indices in the order
    # of the quality / degree of correlation between the neurons found in the permutation.
    # this just for visualization purposes.
    qual_l = [corr_mtx[i, perm_map[i]].item() for i in range(nchan)]
    qual_map = torch.tensor(sorted(range(nchan), key=lambda i: -qual_l[i]))

    return perm_map, qual_map

def get_layer_perm1(
    corr_mtx, method='max_weight', vizz=False, 
    prune_threshold=-torch.inf, interleave=False
):
    if method == 'greedy':
        perm_map, qual_map = compute_perm_map(corr_mtx)
        if vizz:
            corr_mtx_viz = (corr_mtx[qual_map].T[perm_map[qual_map]]).T
            viz(corr_mtx_viz)
    elif method == 'max_weight':
        corr_mtx_a = corr_mtx.cpu().numpy()
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
        assert (row_ind == np.arange(len(corr_mtx_a))).all()
        perm_map = torch.tensor(col_ind).long()
        perm_map = torch.eye(corr_mtx.shape[0], device=corr_mtx.device)[perm_map]
    else:
        raise Exception('Unknown method: %s' % method)
    
#     pdb.set_trace()
    pruned_elements = torch.from_numpy(
        corr_mtx_a[row_ind, col_ind] >= prune_threshold
    ).to(perm_map.device).to(torch.float32)
    return perm_map, pruned_elements

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf, interleave=False):
    corr_mtx = run_corr_matrix(net0, net1, interleave=interleave)
    return get_layer_perm1(
        corr_mtx, method=method, vizz=vizz, 
        prune_threshold=prune_threshold, 
        interleave=interleave
    )

In [39]:
modela = resnet20(w=4).to(DEVICE)
modelb = resnet20(w=4).to(DEVICE)
load_model(modela, f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}')
load_model(modelb, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')

print(evaluate(modela, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(modelb, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.778
0.7758


In [178]:
check_model = resnet20(w=4).to(DEVICE)
load_model(check_model, f'resnet20x4_CIFAR50_perm1_{prune_threshold}threshold_new2')

In [23]:
print(evaluate(modela, test_loader, class_vectors=text_features))
print(evaluate(modelb, test_loader, class_vectors=text_features))

0.3904
0.3884


In [24]:
prune_threshold = -torch.inf
from collections import defaultdict
module2io = defaultdict(lambda: dict())

In [36]:
model_merge = resnet20(w=4).to(DEVICE)
# model_merge.eval()
# modela.eval()
# modelb.eval()
model_merge, module2io = transform_model(
    modela, 
    modelb, 
    model_merge=modelb,
    transform_fn=get_layer_perm, 
    prune_threshold=-torch.inf, 
    module2io=module2io,
    interleave=True,
    check_model=None
)

100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.76it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.62it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.71it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 24.52it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.97it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.91it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.34it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████

In [37]:
evaluate(model_merge, test_loader, class_vectors=text_features)

0.0103

In [38]:
reset_bn_stats(model_merge)
evaluate(model_merge, test_loader, class_vectors=text_features)

0.0103

In [207]:
save_model(model_merge, f'resnet20x4_CIFAR50_perm1_{prune_threshold}threshold_new3')

In [56]:
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR50_perm1_{prune_threshold}threshold_new3',
    whitelist_fn =lambda x: True#lambda x: 'bn' not in strip_param_suffix(x)
)

print(
    evaluate_texthead(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)
reset_bn_stats(avg_model)
print('Post-reset:')
acc, confusion = evaluate_texthead(
    avg_model, 
    test_loader, 
    class_vectors=text_features,
    return_confusion=True
)
print(acc)

0.0168
Post-reset:
0.3594


In [58]:
print(confusion)

[0.01, 0.2, 0.02, 0.31, 0.59, 0.22, 0.28, 0.25, 0.5, 0.27, 0.19, 0.0, 0.36, 0.19, 0.19, 0.29, 0.33, 0.59, 0.44, 0.13, 0.92, 0.18, 0.26, 0.53, 0.47, 0.25, 0.18, 0.35, 0.35, 0.17, 0.36, 0.09, 0.3, 0.34, 0.57, 0.09, 0.84, 0.05, 0.13, 0.32, 0.33, 0.45, 0.35, 0.38, 0.21, 0.11, 0.02, 0.67, 0.81, 0.67, 0.1, 0.16, 0.41, 0.67, 0.63, 0.1, 0.57, 0.79, 0.84, 0.31, 0.24, 0.74, 0.44, 0.13, 0.1, 0.02, 0.35, 0.02, 0.84, 0.29, 0.17, 0.14, 0.12, 0.15, 0.46, 0.5, 0.82, 0.1, 0.56, 0.65, 0.01, 0.66, 0.73, 0.36, 0.15, 0.31, 0.61, 0.1, 0.69, 0.38, 0.51, 0.29, 0.06, 0.18, 0.86, 0.69, 0.26, 0.16, 0.86, 0.54]


In [50]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    confusion = np.zeros((100, 100))
    
    totals = [0] * class_vectors.shape[0]
    corrects = [0] * class_vectors.shape[0]
    
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                for gt, p in zip(labels, pred):
                    totals[gt] += 1
                    
                    if gt == p:
                        correct += 1
                        corrects[gt] += 1
                
#                 correct += (labels.to(DEVICE) == pred).sum().item()
                
            confusion[labels.cpu().numpy(), pred.cpu().numpy()] += 1
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total

In [49]:
_, confusion = evaluate_texthead(avg_model, test_loader, class_vectors=text_features, return_confusion=True) 
# print(np.diag(confusion).round(3).tolist())
print(confusion)

[0.01, 0.18, 0.03, 0.31, 0.58, 0.19, 0.33, 0.23, 0.41, 0.27, 0.17, 0.0, 0.35, 0.19, 0.21, 0.33, 0.37, 0.61, 0.42, 0.11, 0.94, 0.22, 0.24, 0.53, 0.43, 0.28, 0.18, 0.36, 0.38, 0.18, 0.42, 0.12, 0.33, 0.35, 0.57, 0.07, 0.81, 0.05, 0.16, 0.31, 0.29, 0.45, 0.37, 0.42, 0.23, 0.1, 0.01, 0.68, 0.74, 0.63, 0.05, 0.16, 0.43, 0.65, 0.64, 0.09, 0.63, 0.79, 0.84, 0.27, 0.24, 0.72, 0.42, 0.13, 0.11, 0.03, 0.34, 0.02, 0.84, 0.31, 0.15, 0.12, 0.1, 0.13, 0.42, 0.46, 0.8, 0.05, 0.55, 0.59, 0.01, 0.69, 0.72, 0.37, 0.09, 0.29, 0.59, 0.12, 0.63, 0.5, 0.5, 0.35, 0.07, 0.19, 0.89, 0.71, 0.23, 0.16, 0.88, 0.46]


# Combine Models via Bipartite Matching

In [44]:
def get_bipartite_perm(corr, prune_threshold=-torch.inf):
    scores, idx = corr.max(0)
    valid_elements = scores >= prune_threshold
    idx = torch.where(valid_elements, idx, corr.shape[0])
    location_lookup = torch.eye(corr.shape[0]+1, corr.shape[0], device=corr.device)
    matches = location_lookup[idx]
    totals = matches.sum(0, keepdim=True)
    matches = matches / (totals + 1)
    return matches.t(), totals

def get_layer_bipartite_transform(net0, net1, prune_threshold=-torch.inf, interleave=False):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_bipartite_perm(corr_mtx, prune_threshold=prune_threshold)

In [41]:
modela = resnet20(w=4).to(DEVICE)
modelb = resnet20(w=4).to(DEVICE)
modelc = resnet20(w=4).to(DEVICE)
load_model(modela, f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}')
load_model(modelb, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')
load_model(modelc, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')

print(evaluate(modela, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(modelb, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.778
0.7758


In [42]:
prune_threshold = -torch.inf
from collections import defaultdict
module2io = defaultdict(lambda: dict())

In [45]:
model_to_alter, module2io = transform_model(
    modela, 
    modelb, 
    model_merge=modelb,
    transform_fn=get_layer_bipartite_transform, 
    prune_threshold=-torch.inf, 
    module2io=module2io
)

100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.02it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.91it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.20it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 23.23it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.78it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.02it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████

In [46]:
save_model(model_to_alter, f'resnet20x4_CIFAR50_bipartite_{prune_threshold}threshold_new')

In [53]:
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR50_bipartite_{prune_threshold}threshold_new',
    whitelist_fn=lambda x: 'bn' not in strip_param_suffix(x)
)

print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)
reset_bn_stats(avg_model)
print('Post-reset:')
# print(
#     evaluate(
#         avg_model, 
#         test_loader, 
#         class_vectors=text_features
#     )
# )

acc, confusion = evaluate_texthead(
    avg_model, 
    test_loader, 
    class_vectors=text_features,
    return_confusion=True
)
print(acc)

0.014
Post-reset:
0.3977


In [55]:
print(confusion)

[0.93, 0.91, 0.68, 0.65, 0.79, 0.78, 0.83, 0.76, 0.95, 0.88, 0.57, 0.54, 0.87, 0.82, 0.77, 0.85, 0.76, 0.91, 0.7, 0.67, 0.9, 0.88, 0.8, 0.9, 0.87, 0.64, 0.72, 0.62, 0.82, 0.79, 0.82, 0.71, 0.77, 0.76, 0.86, 0.65, 0.93, 0.87, 0.68, 0.92, 0.8, 0.93, 0.85, 0.84, 0.57, 0.57, 0.58, 0.9, 0.93, 0.91, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.21, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.04, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.0, 0.0, 0.0, 0.1, 0.0]


# Combine Via Procrustes

In [292]:
def get_procrustes(corr_mtx):
    U, _, Vh = torch.linalg.svd(corr_mtx)
    return U @ Vh

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_procrustes(net0, net1, prune_threshold=None):
    corr_mtx = run_corr_matrix(net0, net1)
    try:
        return get_procrustes(corr_mtx), torch.ones(corr_mtx.shape[0], device=corr_mtx.device)
    except:
        print('Applying Permutation')
        pdb.set_trace()
        return get_layer_perm1(corr_mtx, 'max_weight', vizz=False, prune_threshold=prune_threshold)

In [297]:
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
model_to_alter = resnet20(w=4).to(DEVICE)

load_model(model0, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')
load_model(model1, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')
load_model(model_to_alter, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

print(evaluate(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))
print(evaluate(model_to_alter, test_loader2, class_vecs2, remap_class_idxs=class_idxs))


0.9558
0.9726
0.9726


In [298]:
prune_threshold = -torch.inf
from collections import defaultdict
module2io = defaultdict(lambda: dict())

In [295]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2]).double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2]).double()

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                if i == 0:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n
                if outer.isnan().sum() > 0: pdb.set_trace()
                
    cov = outer - torch.outer(mean0, mean1)
    if cov.isnan().sum() > 0: pdb.set_trace()
    if norm:
        corr = cov / (torch.outer(std0, std1) + 1e-4)
        return corr.to(torch.float32)
    else:
        return cov.to(torch.float32)

In [296]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        return x
# corr = run_corr_matrix(Subnet(model0), Subnet(model1))
# perm_map1 = get_layer_perm1(corr)
perm_map, collapse_totals = get_layer_procrustes(Subnet(model0), Subnet(model1))
permute_output(perm_map, model_to_alter.conv1, model_to_alter.bn1)
permute_output(perm_map, model_to_alter.layer1[0].conv2, model_to_alter.layer1[0].bn2)
permute_output(perm_map, model_to_alter.layer1[1].conv2, model_to_alter.layer1[1].bn2)
permute_output(perm_map, model_to_alter.layer1[2].conv2, model_to_alter.layer1[2].bn2)
permute_input(perm_map, [model_to_alter.layer1[0].conv1, model_to_alter.layer1[1].conv1, model_to_alter.layer1[2].conv1])
permute_input(perm_map, [model_to_alter.layer2[0].conv1, model_to_alter.layer2[0].shortcut[0]])

module2io['conv1']['output'] = collapse_totals
module2io['bn1']['output'] = collapse_totals
module2io['layer1.0.conv2']['output'] = collapse_totals
module2io['layer1.0.bn2']['output'] = collapse_totals
module2io['layer1.1.conv2']['output'] = collapse_totals
module2io['layer1.1.bn2']['output'] = collapse_totals
module2io['layer1.2.conv2']['output'] = collapse_totals
module2io['layer1.2.bn2']['output'] = collapse_totals

module2io['layer1.0.conv1']['input'] = collapse_totals
module2io['layer1.1.conv1']['input'] = collapse_totals
module2io['layer1.2.conv1']['input'] = collapse_totals
module2io['layer2.0.conv1']['input'] = collapse_totals
module2io['layer2.0.shortcut.0']['input'] = collapse_totals



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.57it/s]


In [271]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = self.conv1(x)
#         x = F.relu(self.bn1(self.conv1(x)))
#         x = self.layer1(x)
#         x = self.layer2(x)
        return x

In [272]:
model_to_alter.eval()
for (images, _) in train_aug_loader:
    if model_to_alter.bn1(Subnet(model_to_alter)(images.to(DEVICE))).isnan().sum() > 0: break

In [273]:
model1.eval()
model_to_alter.eval()
x = Subnet(model_to_alter)(images.to(DEVICE))
y = Subnet(model1)(images.to(DEVICE))

In [299]:
model_to_alter, module2io = transform_model(
    model0, 
    model1, 
    model_to_alter, 
    transform_fn=get_layer_procrustes, 
    prune_threshold=-torch.inf, 
    module2io=module2io
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.47it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 23.38it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.55it/s]
100%|███████████████████████████████████

In [300]:
save_model(model_to_alter, f'resnet20x4_CIFAR5_procrustes_{prune_threshold}threshold')

In [74]:
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR5_procrustes_{prune_threshold}threshold',
#     whitelist_fn=lambda x: x in ['conv1']
)

print('Pre-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)
reset_bn_stats(avg_model)
print('Post-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)

Pre-reset:
0.1
Post-reset:
0.4774


# Procrustes Via Weight Parameters

In [132]:
def get_procrustes(corr_mtx):
    U, _, Vh = torch.linalg.svd(corr_mtx)
    return U @ Vh

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_procrustes(pairs):
    objective = None
    for pair in pairs:
#         pdb.set_trace()
        pair_obj = pair[0] @ pair[1].t()
        objective = objective + pair_obj if objective is not None else pair_obj
    return get_procrustes(objective)

prep_input_conv = lambda conv: conv.weight.permute(1, 0, 2, 3).flatten(1)
prep_output_conv = lambda conv: conv.weight.flatten(1)

In [149]:
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
model_to_alter = resnet20(w=4).to(DEVICE)

load_model(model0, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')
load_model(model1, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')
# load_model(model_to_alter, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

print(evaluate(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))
# print(evaluate(model_to_alter, test_loader2, class_vecs2, remap_class_idxs=class_idxs))


0.9558
0.9726


### Alter Inputs and Outputs of fir Convolution Layer

In [150]:
def compare_state_dicts(sd0, sd1):
    total_norm = 0.
    for key in sd0:
        p0 = sd0[key]
        p1 = sd1[key]
        total_norm += torch.square(p0-p1).sum().sqrt()
    return total_norm

prev_sd = model0.state_dict()

for j in range(300):
    if j % 100 == 0:
        print('Iteration {} | Norm Differences:  Models - {}, Inter Iters - {}'.format(
            j, 
            compare_state_dicts(model0.state_dict(), model1.state_dict()),
            compare_state_dicts(prev_sd, model1.state_dict()),
        ))
    conv1_procrustes = get_layer_procrustes(
        [
            # Affected Outputs
            (
                prep_output_conv(model0.conv1), 
                prep_output_conv(model1.conv1)
            ),
            (
                prep_output_conv(model0.layer1[0].conv2),
                prep_output_conv(model1.layer1[0].conv2),
            ),
            (
                prep_output_conv(model0.layer1[1].conv2),
                prep_output_conv(model1.layer1[1].conv2),
            ),
            (
                prep_output_conv(model0.layer1[2].conv2),
                prep_output_conv(model1.layer1[2].conv2),
            ),
            # Affected Inputs
            (
                prep_input_conv(model0.layer1[0].conv1), 
                prep_input_conv(model1.layer1[0].conv1)
            ),
            (
                prep_input_conv(model0.layer1[1].conv1), 
                prep_input_conv(model1.layer1[1].conv1)
            ),
            (
                prep_input_conv(model0.layer1[2].conv1), 
                prep_input_conv(model1.layer1[2].conv1)
            ),
            (
                prep_input_conv(model0.layer2[0].conv1), 
                prep_input_conv(model1.layer2[0].conv1)
            ),
            (
                prep_input_conv(model0.layer2[0].shortcut[0]), 
                prep_input_conv(model1.layer2[0].shortcut[0])
            )
        ]
    )
    
    permute_output(conv1_procrustes, model1.conv1, model1.bn1)
    permute_output(conv1_procrustes, model1.layer1[0].conv2, model1.layer1[0].bn2)
    permute_output(conv1_procrustes, model1.layer1[1].conv2, model1.layer1[1].bn2)
    permute_output(conv1_procrustes, model1.layer1[2].conv2, model1.layer1[2].bn2)

    permute_input(conv1_procrustes, [model1.layer1[0].conv1, model1.layer1[1].conv1, model1.layer1[2].conv1])
    permute_input(conv1_procrustes, [model1.layer2[0].conv1, model1.layer2[0].shortcut[0]])

    second_residual_transform = get_layer_procrustes(
        [
            # Affected Outputs
            (
                prep_output_conv(model0.layer2[0].shortcut[0]),
                prep_output_conv(model1.layer2[0].shortcut[0]),
            ),
            (
                prep_output_conv(model0.layer2[0].conv2),
                prep_output_conv(model1.layer2[0].conv2)
            ),
            (
                prep_output_conv(model0.layer2[1].conv2),
                prep_output_conv(model1.layer2[1].conv2)
            ),
            (
                prep_output_conv(model0.layer2[2].conv2),
                prep_output_conv(model1.layer2[2].conv2)
            ),
            # Affected Inputs
            (
                prep_input_conv(model0.layer2[1].conv1),
                prep_input_conv(model1.layer2[1].conv1)
            ),
            (
                prep_input_conv(model0.layer2[2].conv1),
                prep_input_conv(model1.layer2[2].conv1),
            ),
            (
                prep_input_conv(model0.layer3[0].conv1),
                prep_input_conv(model1.layer3[0].conv1),
            ),
            (
                prep_input_conv(model0.layer3[0].shortcut[0]),
                prep_input_conv(model1.layer3[0].shortcut[0]),
            ),
        ]
    )
    permute_output(second_residual_transform, model1.layer2[0].shortcut[0], model1.layer2[0].shortcut[1])
    permute_output(second_residual_transform, model1.layer2[0].conv2, model1.layer2[0].bn2)
    permute_output(second_residual_transform, model1.layer2[1].conv2, model1.layer2[1].bn2)
    permute_output(second_residual_transform, model1.layer2[2].conv2, model1.layer2[2].bn2)

    permute_input(second_residual_transform, [model1.layer2[1].conv1, model1.layer2[2].conv1])
    permute_input(second_residual_transform, [model1.layer3[0].conv1, model1.layer3[0].shortcut[0]])
    
    third_residual_transform = get_layer_procrustes(
        [
            # Affected Outputs
            (
                prep_output_conv(model0.layer3[0].shortcut[0]),
                prep_output_conv(model1.layer3[0].shortcut[0])
            ),
            (
                prep_output_conv(model0.layer3[0].conv2),
                prep_output_conv(model1.layer3[0].conv2)
            ),
            (
                prep_output_conv(model0.layer3[1].conv2),
                prep_output_conv(model1.layer3[1].conv2),
            ),
            (
                prep_output_conv(model0.layer3[2].conv2),
                prep_output_conv(model1.layer3[2].conv2),
            ),
            # Affected Inputs
            (
                prep_input_conv(model0.layer3[1].conv1),
                prep_input_conv(model1.layer3[1].conv1)
            ),
            (
                prep_input_conv(model0.layer3[2].conv1),
                prep_input_conv(model1.layer3[2].conv1)
            ),
            (
                model0.linear.weight.T,
                model1.linear.weight.T
            ),
        ]
    )
    permute_output(third_residual_transform, model1.layer3[0].shortcut[0], model1.layer3[0].shortcut[1])
    permute_output(third_residual_transform, model1.layer3[0].conv2, model1.layer3[0].bn2)
    permute_output(third_residual_transform, model1.layer3[1].conv2, model1.layer3[1].bn2)
    permute_output(third_residual_transform, model1.layer3[2].conv2, model1.layer3[2].bn2)

    permute_input(third_residual_transform, [model1.layer3[1].conv1, model1.layer3[2].conv1, model1.linear])
    
    for i in range(3):
        intermediate_transform = get_layer_procrustes(
            [
                (
                    prep_output_conv(model0.layer1[i].conv1),
                    prep_output_conv(model1.layer1[i].conv1),
                ),
                (
                    prep_input_conv(model0.layer1[i].conv2),
                    prep_input_conv(model1.layer1[i].conv2)
                )
            ]
        )
        permute_output(intermediate_transform, model1.layer1[i].conv1, model1.layer1[i].bn1)
        permute_input(intermediate_transform, [model1.layer1[i].conv2])
    
    for i in range(3):
        intermediate_transform = get_layer_procrustes(
            [
                (
                    prep_output_conv(model0.layer2[i].conv1),
                    prep_output_conv(model1.layer2[i].conv1),
                ),
                (
                    prep_input_conv(model0.layer2[i].conv2),
                    prep_input_conv(model1.layer2[i].conv2)
                )
            ]
        )
        permute_output(intermediate_transform, model1.layer2[i].conv1, model1.layer2[i].bn1)
        permute_input(intermediate_transform, [model1.layer2[i].conv2])
    
    for i in range(3):
        intermediate_transform = get_layer_procrustes(
            [
                (
                    prep_output_conv(model0.layer3[i].conv1),
                    prep_output_conv(model1.layer3[i].conv1),
                ),
                (
                    prep_input_conv(model0.layer3[i].conv2),
                    prep_input_conv(model1.layer3[i].conv2)
                )
            ]
        )
        permute_output(intermediate_transform, model1.layer3[i].conv1, model1.layer3[i].bn1)
        permute_input(intermediate_transform, [model1.layer3[i].conv2])
    
    prev_sd = model1.state_dict()
    

Iteration 0 | Norm Differences:  Models - 195.26661682128906, Inter Iters - 195.26661682128906
Iteration 100 | Norm Differences:  Models - 153.6870574951172, Inter Iters - 0.0
Iteration 200 | Norm Differences:  Models - 152.7029571533203, Inter Iters - 0.0


In [163]:
for key, val in model1.state_dict().items():
    if 'bn' in key and 'running_var' in key:
        model1.state_dict()[key] = torch.nn.functional.relu(val)

### Alter Linear

In [101]:
# linear_transform = get_layer_procrustes(
#     [
#         (
#             model0.linear.weight,
#             model1.linear.weight,
#         )
#     ]
# )

# permute_cls_output(linear_transform, model1.linear)

In [164]:
save_model(model1, f'resnet20x4_CIFAR5_procrustes')

In [166]:
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    'resnet20x4_CIFAR5_procrustes',
#     whitelist_fn=lambda x: 'bn' not in strip_param_suffix(x)
)

print('Pre-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)
reset_bn_stats(avg_model)
print('Post-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)

Pre-reset:
0.1
Post-reset:
0.4245


score to beat : .4759

In [156]:
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
model_to_alter = resnet20(w=4).to(DEVICE)

load_model(model0, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')
load_model(model1, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

print(evaluate(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.9558
0.9726


In [157]:
conv1_proc = get_layer_procrustes([(prep_output_conv(model0.conv1), prep_output_conv(model1.conv1))])

permute_output(conv1_proc, model1.conv1, model1.bn1)



In [158]:
save_model(model1, f'resnet20x4_CIFAR5_procrustes_greedy')

In [160]:
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    'resnet20x4_CIFAR5_procrustes_greedy',
    whitelist_fn=lambda x: strip_param_suffix(x) in ['conv1']
)

print('Pre-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)
reset_bn_stats(avg_model)
print('Post-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)

Pre-reset:
0.4212
Post-reset:
0.4687
