In [None]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform
from collections import defaultdict

DEVICE = 'mps' if platform == 'darwin' else 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Per-channel CIFAR-10 statistics on the 0-255 pixel scale. They are divided by 255
# below because T.ToTensor() rescales pixels to [0, 1] before normalization.
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
# Inverse of `normalize`: maps a normalized tensor back to the [0, 1] pixel range.
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

# Training images get flip + padded-crop augmentation; test images are only normalized.
train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(
    root='/tmp', train=True, download=True, transform=train_transform
)
test_dset = torchvision.datasets.CIFAR10(
    root='/tmp', train=False, download=True, transform=test_transform
)

# Fixed, disjoint 5/5 split of the ten CIFAR-10 classes, one half per model.
# (To draw a fresh split: shuffle np.arange(10) and take [:5] / [5:].)
model1_classes = np.array([3, 2, 0, 6, 4])
model2_classes = np.array([5, 7, 9, 8, 1])

# Select example indices straight from the stored label list. Iterating the dataset
# itself would decode and augment every image just to read its label.
train_labels = np.asarray(train_dset.targets)
valid_examples1 = np.flatnonzero(np.isin(train_labels, model1_classes)).tolist()
valid_examples2 = np.flatnonzero(np.isin(train_labels, model2_classes)).tolist()

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

# One augmented training loader per class subset.
train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

# Test loader over all ten classes.
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

In [None]:
# Augmented loader over the full ten-class training set. It is bound as the default
# `loader` argument of run_corr_matrix and reset_bn_stats, so it must exist before
# those functions are defined.
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)

In [None]:
# Indices of the test examples belonging to each model's class subset, read from the
# stored label list so that no image has to be decoded or transformed.
test_labels = np.asarray(test_dset.targets)
test_valid_examples1 = np.flatnonzero(np.isin(test_labels, model1_classes)).tolist()
test_valid_examples2 = np.flatnonzero(np.isin(test_labels, model2_classes)).tolist()

In [None]:
# Per-model test loaders restricted to the five classes each model was trained on.
test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

In [None]:
# Display the two class subsets.
model1_classes, model2_classes

In [None]:
# Lookup table from original CIFAR-10 label (0-9) to the position (0-4) of that class
# inside its own model's class list. Both subsets share the same 0-4 range.
class_idxs = np.zeros(10, dtype=int)
for subset_classes in (model1_classes, model2_classes):
    class_idxs[subset_classes] = np.arange(5)
class_idxs = torch.from_numpy(class_idxs)
class_idxs

In [None]:
def save_model(model, i):
    """Write `model.state_dict()` to `<checkpoint dir>/<i>.pth.tar`.

    Args:
        model: module whose state dict is saved.
        i: checkpoint key; formatted with `%s` into the file name.
    """
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    """Load the checkpoint stored under key `i` into `model`, mapped onto DEVICE."""
    # '/Users/georgestoica/Downloads' is the alternative local checkpoint directory.
    ckpt_dir = '/srv/share/gstoica3/checkpoints/REPAIR/'
    ckpt_file = os.path.join(ckpt_dir, '%s.pth.tar' % i)
    state = torch.load(ckpt_file, map_location=torch.device(DEVICE))
    model.load_state_dict(state)

In [None]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    """Estimate the cross-network channel correlation (or covariance) matrix.

    Both networks are put in eval mode and fed the same batches. Each output is
    reshaped from (N, C, ...) to (N*..., C), and per-batch means, standard deviations
    and mean outer products are averaged uniformly over all `epochs * len(loader)`
    batches. The standard deviation is therefore the mean of per-batch standard
    deviations, not the exact dataset-level value.

    Args:
        net0, net1: modules returning feature maps with channels on dim 1.
        epochs: number of passes over `loader`.
        norm: if True return the correlation (covariance divided by the outer
            product of standard deviations plus 1e-4), otherwise the covariance.
        loader: iterable of (images, labels) batches. The default is bound when this
            function is defined.

    Returns:
        float32 tensor of shape (C0, C1); entry (i, j) relates channel i of net0 to
        channel j of net1.

    Raises:
        ValueError: if `loader` yields no batches.
        RuntimeError: if the accumulated covariance contains NaNs.
    """
    def _samples_by_channels(out):
        # (N, C, ...) -> (N * spatial, C), accumulated in float64.
        out = out.reshape(out.shape[0], out.shape[1], -1).permute(0, 2, 1)
        return out.reshape(-1, out.shape[2]).double()

    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = outer = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for images, _ in tqdm(loader):
                img_t = images.float().to(DEVICE)
                out0 = _samples_by_channels(net0(img_t))
                out1 = _samples_by_channels(net1(img_t))

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                # Allocate the accumulators exactly once. Keying this on the batch
                # index would wipe them at the start of every epoch after the first.
                if outer is None:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n

    if outer is None:
        raise ValueError('run_corr_matrix: loader yielded no batches')

    # cov(X, Y) = E[X Y^T] - E[X] E[Y]^T
    cov = outer - torch.outer(mean0, mean1)
    if cov.isnan().sum() > 0:
        raise RuntimeError('run_corr_matrix: covariance matrix contains NaNs')
    if norm:
        corr = cov / (torch.outer(std0, std1) + 1e-4)
        return corr.to(torch.float32)
    else:
        return cov.to(torch.float32)

# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, conv, bn):
    """Apply `perm_map` to the output channels of `conv` and its batchnorm `bn`, in place.

    `perm_map` is a (C_out, C_out) matrix: a permutation matrix, or any other linear
    map over channels. Row i of every transformed tensor becomes
    sum_j perm_map[i, j] * original[j].

    Args:
        perm_map: square transform matrix over output channels.
        conv: module whose `.weight` has output channels on dim 0.
        bn: batchnorm module with `weight`, `bias`, `running_mean` and `running_var`.

    Note:
        `running_var` is transformed with `perm_map` itself, exactly like the other
        per-channel vectors. The original code also computed a squared-`perm_map`
        variant for it but discarded the result; that dead computation is removed.
    """
    pre_weights = [
        conv.weight,
        bn.weight,
        bn.bias,
        bn.running_mean,
        bn.running_var,
    ]
    for w in pre_weights:
        if len(w.shape) == 4:
            # conv kernel (out, in, kh, kw): mix along the output-channel axis
            transform = torch.einsum('ab,bcde->acde', perm_map, w)
        elif len(w.shape) == 2:
            transform = perm_map @ w
        else:
            # per-channel vector
            transform = w @ perm_map.t()
        w.data = transform
            

# modifies the weight matrix of a convolution layer for a given
# permutation of the input channels
def permute_input(perm_map, after_convs):
    """Apply `perm_map` to the input channels of one layer or a list of layers, in place.

    4-D (conv) and 2-D (linear) weights are both right-multiplied by `perm_map.t()`
    along their input-channel axis (dim 1).
    """
    layers = after_convs if isinstance(after_convs, list) else [after_convs]
    weights = [layer.weight for layer in layers]
    for w in weights:
        rank = len(w.shape)
        if rank == 4:
            transform = torch.einsum('abcd,be->aecd', w, perm_map.t())
        elif rank == 2:
            transform = w @ perm_map.t()
        w.data = transform

def permute_cls_output(perm_map, linear):
    """Left-multiply the weight and bias of `linear` by `perm_map`, in place."""
    linear.weight.data = perm_map @ linear.weight
    linear.bias.data = perm_map @ linear.bias

In [None]:
def transform_model(
    model0, 
    model1, 
    model_to_alter, 
    transform_fn, 
    prune_threshold=-torch.inf, 
    module2io=None
):
    """Align the channels of `model_to_alter` to `model0`, group by group, in place.

    The network is expected to expose `conv1`, `bn1`, `layer1`, `layer2`, `layer3`
    and `linear`, where each layer is a sequence of blocks with `conv1`, `bn1`,
    `conv2`, `bn2` and (for the first block of `layer2`/`layer3`) a conv+bn `shortcut`.

    For every channel group the transform is obtained from
    `transform_fn(subnet(model0), subnet(model1), prune_threshold=prune_threshold)`,
    which must return `(perm_map, collapse_totals)`. It is then applied to every
    layer that writes to that group (`permute_output`) and every layer that reads
    from it (`permute_input`). Groups are processed in order: the three residual
    streams first, then the hidden channels inside each block. If `model_to_alter`
    is the same object as `model1`, later correlations are measured on the
    already-altered network.

    Args:
        model0: reference network.
        model1: network whose activations are matched against `model0`.
        model_to_alter: network whose weights are rewritten.
        transform_fn: callable as described above.
        prune_threshold: forwarded to `transform_fn` as a keyword argument.
        module2io: mapping `module name -> {'input'/'output': collapse_totals}` to
            fill in. A fresh `defaultdict(dict)` is created when omitted.

    Returns:
        `(model_to_alter, module2io)`.
    """
    # Imported locally so the function does not depend on a later cell having run.
    import torch.nn.functional as F

    # A default built in the signature would be shared by every call.
    if module2io is None:
        module2io = defaultdict(dict)

    class _StagePrefix(nn.Module):
        """Stem followed by the first `n_stages` residual stages of `model`."""
        def __init__(self, model, n_stages):
            super().__init__()
            self.model = model
            self.n_stages = n_stages

        def forward(self, x):
            m = self.model
            x = F.relu(m.bn1(m.conv1(x)))
            for stage in (m.layer1, m.layer2, m.layer3)[:self.n_stages]:
                x = stage(x)
            return x

    class _BlockHidden(nn.Module):
        """Network truncated after conv1 -> bn1 -> relu of residual block number `nb`."""
        def __init__(self, model, nb):
            super().__init__()
            self.model = model
            self.blocks = nn.Sequential(*model.layer1, *model.layer2, *model.layer3)
            self.nb = nb

        def forward(self, x):
            x = F.relu(self.model.bn1(self.model.conv1(x)))
            x = self.blocks[:self.nb](x)
            block = self.blocks[self.nb]
            return F.relu(block.bn1(block.conv1(x)))

    def _match(net0, net1):
        # One alignment problem: transform plus per-channel collapse totals.
        return transform_fn(net0, net1, prune_threshold=prune_threshold)

    def _record(names, role, totals):
        # Remember which channel group each module reads from or writes to.
        for name in names:
            module2io.setdefault(name, {})[role] = totals

    def _align_downsampled_stage(stage_name, stage, n_stages):
        # Residual stream of a stage whose first block has a conv+bn shortcut.
        perm_map, totals = _match(_StagePrefix(model0, n_stages), _StagePrefix(model1, n_stages))
        blocks = list(stage)
        first, rest = blocks[0], blocks[1:]
        permute_output(perm_map, first.conv2, first.bn2)
        permute_output(perm_map, first.shortcut[0], first.shortcut[1])
        for block in rest:
            permute_output(perm_map, block.conv2, block.bn2)
        permute_input(perm_map, [block.conv1 for block in rest])

        writers = [f'{stage_name}.0.shortcut.0', f'{stage_name}.0.shortcut.1']
        for b in range(len(blocks)):
            writers += [f'{stage_name}.{b}.conv2', f'{stage_name}.{b}.bn2']
        _record(writers, 'output', totals)
        _record([f'{stage_name}.{b}.conv1' for b in range(1, len(blocks))], 'input', totals)
        return perm_map, totals

    m = model_to_alter

    # Residual stream 1: written by the stem and every conv2 of layer1, read by
    # every conv1 of layer1 and by the first block of layer2.
    perm_map, collapse_totals = _match(_StagePrefix(model0, 1), _StagePrefix(model1, 1))
    layer1_blocks = list(m.layer1)
    permute_output(perm_map, m.conv1, m.bn1)
    for block in layer1_blocks:
        permute_output(perm_map, block.conv2, block.bn2)
    permute_input(perm_map, [block.conv1 for block in layer1_blocks])
    permute_input(perm_map, [m.layer2[0].conv1, m.layer2[0].shortcut[0]])

    writers = ['conv1', 'bn1']
    for b in range(len(layer1_blocks)):
        writers += [f'layer1.{b}.conv2', f'layer1.{b}.bn2']
    _record(writers, 'output', collapse_totals)
    _record(
        [f'layer1.{b}.conv1' for b in range(len(layer1_blocks))]
        + ['layer2.0.conv1', 'layer2.0.shortcut.0'],
        'input',
        collapse_totals,
    )

    # Residual stream 2: also read by the first block of layer3.
    perm_map, collapse_totals = _align_downsampled_stage('layer2', m.layer2, 2)
    permute_input(perm_map, [m.layer3[0].conv1, m.layer3[0].shortcut[0]])
    _record(['layer3.0.conv1', 'layer3.0.shortcut.0'], 'input', collapse_totals)

    # Residual stream 3: also read by the final linear layer.
    perm_map, collapse_totals = _align_downsampled_stage('layer3', m.layer3, 3)
    m.linear.weight.data = m.linear.weight @ perm_map.t()
    _record(['linear'], 'input', collapse_totals)

    # Hidden channels inside each block: written by conv1/bn1, read by conv2.
    named_blocks = [
        (f'{stage_name}.{b}', block)
        for stage_name in ('layer1', 'layer2', 'layer3')
        for b, block in enumerate(getattr(m, stage_name))
    ]
    for nb, (layer_name, block) in enumerate(named_blocks):
        perm_map, collapse_totals = _match(_BlockHidden(model0, nb), _BlockHidden(model1, nb))
        permute_output(perm_map, block.conv1, block.bn1)
        permute_input(perm_map, [block.conv2])

        _record([layer_name + '.conv1', layer_name + '.bn1'], 'output', collapse_totals)
        # conv2 *reads* this group. Its 'output' entry belongs to the residual
        # stream recorded above and must not be overwritten.
        _record([layer_name + '.conv2'], 'input', collapse_totals)

    return model_to_alter, module2io


In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    """Module wrapper around an arbitrary callable stored as `self.lambd`."""

    def __init__(self, lambd):
        super().__init__()
        self.lambd = lambd

    def forward(self, x):
        fn = self.lambd
        return fn(x)


class BasicBlock(nn.Module):
    """Two-conv residual block: conv-bn-relu, conv-bn, add shortcut, relu.

    The shortcut is the identity (an empty `nn.Sequential`) when the block keeps both
    resolution and width. Otherwise it is a 3x3 strided conv followed by batchnorm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        conv_opts = dict(kernel_size=3, padding=1, bias=False)

        self.conv1 = nn.Conv2d(in_planes, planes, stride=stride, **conv_opts)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, stride=1, **conv_opts)
        self.bn2 = nn.BatchNorm2d(planes)

        changes_shape = stride != 1 or in_planes != planes
        if changes_shape:
            # Learned projection (an alternative would be a parameter-free
            # subsample-and-zero-pad shortcut via LambdaLayer).
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, stride=stride, **conv_opts),
                nn.BatchNorm2d(planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(hidden))
        out += self.shortcut(x)
        return F.relu(out)


class ResNet(nn.Module):
    """Three-stage CIFAR-style ResNet that outputs a 512-dimensional encoding.

    Args:
        block: block class taking `(in_planes, planes, stride)` and exposing `expansion`.
        num_blocks: number of blocks in each of the three stages.
        w: width multiplier; the stages use `16*w`, `32*w` and `64*w` channels.
        num_classes: accepted but not used. The final linear layer always has 512
            outputs (presumably to match the CLIP text-embedding width used as
            class vectors elsewhere in this notebook; confirm).
    """

    def __init__(self, block, num_blocks, w=1, num_classes=10):
        super().__init__()
        stem_width = w*16
        self.in_planes = stem_width

        self.conv1 = nn.Conv2d(3, stem_width, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(stem_width)
        self.layer1 = self._make_layer(block, stem_width, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, w*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, w*64, num_blocks[2], stride=2)
        self.linear = nn.Linear(w*64, 512)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage; only its first block uses `stride`, the rest use 1."""
        stage = []
        for block_stride in [stride] + [1]*(num_blocks-1):
            stage.append(block(self.in_planes, planes, block_stride))
            # the next block consumes what this one produced
            self.in_planes = planes * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3):
            feats = stage(feats)
        # average-pool with a window equal to the feature-map width, then flatten
        pooled = F.avg_pool2d(feats, feats.size()[3])
        flat = pooled.view(pooled.size(0), -1)
        return self.linear(flat)


def resnet20(w=1):
    """Build a ResNet with three BasicBlocks per stage and width multiplier `w`."""
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, blocks_per_stage, w=w)

In [None]:
def train(save_key, model, train_loader, test_loader, class_vectors, remap_class_idxs):
    """Train `model` against fixed class vectors, print test accuracy and save it.

    The model output is L2-normalised and compared with `class_vectors` by dot
    product; the similarities scaled by 100 are used as logits for cross-entropy.

    Args:
        save_key: checkpoint key passed to `save_model`.
        model: network to train (modified in place).
        train_loader, test_loader: loaders yielding `(inputs, labels)`.
        class_vectors: one row per class; used as `encodings @ class_vectors.T`.
        remap_class_idxs: tensor indexed by the original label that gives the row of
            `class_vectors` for that label.

    Returns:
        List with the loss of every training step.
    """
    optimizer = SGD(model.parameters(), lr=0.4, momentum=0.9, weight_decay=5e-4)
    # optimizer = Adam(model.parameters(), lr=0.05)

    # Adam seems to perform worse than SGD for training ResNets on CIFAR-10.
    # To make Adam work, we find that we need a very high learning rate: 0.05 (50x the default)
    # At this LR, Adam gives 1.0-1.5% worse accuracy than SGD.

    # It is not yet clear whether the increased interpolation barrier for Adam-trained networks
    # is simply due to the increased test loss of said networks relative to those trained with SGD.
    # We include the option of using Adam in this notebook to explore this question.

    EPOCHS = 100
    ne_iters = len(train_loader)
    # Per-step LR multiplier: linear warm-up from 0 to 1 over the first 5 epochs,
    # then linear decay back to 0 at the final step.
    lr_schedule = np.interp(np.arange(1+EPOCHS*ne_iters), [0, 5*ne_iters, EPOCHS*ne_iters], [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()

    # Make sure batchnorm uses batch statistics even if the caller passed a model
    # that was left in eval mode.
    model.train()

    losses = []
    for _ in tqdm(range(EPOCHS)):
        for inputs, labels in train_loader:
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                encodings = model(inputs.to(DEVICE))
                normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
                logits = (100.0 * normed_encodings @ class_vectors.T)
                remapped_labels = remap_class_idxs[labels].to(DEVICE)
                loss = loss_fn(logits, remapped_labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            losses.append(loss.item())
    print(evaluate(
        model, test_loader, 
        class_vectors=class_vectors, 
        remap_class_idxs=remap_class_idxs
    ))
    save_model(model, save_key)
    return losses

In [None]:
# evaluates accuracy
def evaluate(model, loader, class_vectors, remap_class_idxs=None):
    """Return the accuracy of `model` on `loader`.

    A prediction is the row of `class_vectors` with the largest dot product against
    the L2-normalised model output. When `remap_class_idxs` is given, labels are
    first mapped through it before being compared with the prediction.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            unit_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            pred = (unit_encodings @ class_vectors.T).argmax(dim=1)
            targets = labels if remap_class_idxs is None else remap_class_idxs[labels]
            n_correct += (targets.to(DEVICE) == pred).sum().item()
            n_seen += inputs.shape[0]
    return n_correct / n_seen

# evaluates loss
def evaluate1(model, loader, class_vectors, remap_class_idxs):
    """Return the mean per-batch cross-entropy loss of `model` on `loader`.

    Logits are the unscaled dot products between the L2-normalised model output and
    `class_vectors` (no factor of 100, unlike the training loss). Labels are mapped
    through `remap_class_idxs` first.
    """
    model.eval()
    losses = []
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            loss = F.cross_entropy(outputs, remap_class_idxs[labels].to(DEVICE))
            losses.append(loss.item())
    return np.array(losses).mean()

In [None]:
import clip

In [None]:
# Display the dataset's class-name list (used below to build the CLIP text prompts).
test_dset.classes

In [None]:
# Tokenize one prompt per class name and stack them into a single tensor on DEVICE.
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)

In [None]:
# Load the CLIP ViT-B/32 model onto DEVICE. Only `model.encode_text` is used in the
# cells that follow; `preprocess` is not referenced in the visible part of the notebook.
model, preprocess = clip.load('ViT-B/32', DEVICE)

In [None]:
# Encode the class prompts once, without tracking gradients.
with torch.no_grad():
    text_features = model.encode_text(text_inputs)

In [None]:
# Display the shape of the encoded prompts.
text_features.shape

In [None]:
# L2-normalise each text embedding in place so dot products act as cosine similarities.
text_features /= text_features.norm(dim=-1, keepdim=True)

In [None]:
# Class vectors for each model: the text embeddings of its five classes, in the same
# order as modelN_classes (the order that class_idxs remaps labels into).
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]
# Re-normalising is unnecessary because text_features rows are already unit length.
# class_vecs1 /= class_vecs1.norm(dim=-1, keepdim=True)
# class_vecs2 /= class_vecs2.norm(dim=-1, keepdim=True)

In [None]:
# Display the class subset assigned to the first model.
model1_classes

In [None]:
# List the checkpoints currently present in the checkpoint directory.
os.listdir('/srv/share/gstoica3/checkpoints/REPAIR/')

In [None]:
# Train each 5-class model only when its own checkpoint is not on disk yet.
# Each guard must test the checkpoint of the model it trains. Testing model1's path
# for both would skip model2 as soon as model1 has been saved.
model1_key = f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}'
model2_key = f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}'

if not os.path.exists(
    os.path.join('/srv/share/gstoica3/checkpoints/REPAIR/', f'{model1_key}.pth.tar')
):
    print('training model...')
    model1 = resnet20(w=4).to(DEVICE)
    train(
        model1_key,
        model=model1,
        class_vectors=class_vecs1,
        train_loader=train_aug_loader1,
        test_loader=test_loader1,
        remap_class_idxs=class_idxs
    )
if not os.path.exists(
    os.path.join('/srv/share/gstoica3/checkpoints/REPAIR/', f'{model2_key}.pth.tar')
):
    print('training model...')
    model2 = resnet20(w=4).to(DEVICE)
    train(
        model2_key,
        model=model2,
        class_vectors=class_vecs2,
        train_loader=train_aug_loader2,
        test_loader=test_loader2,
        remap_class_idxs=class_idxs
    )
      

# Combine models and evaluate performance

In [30]:
def strip_param_suffix(name):
    """Remove every '.weight' and '.bias' occurrence from a state-dict key."""
    for suffix in ('.weight', '.bias'):
        name = name.replace(suffix, '')
    return name

def combine_io_masks(io, param):
    """Build a 0/1 mask over `param` marking channels whose collapse total is zero.

    Args:
        io: dict with optional 'output' and/or 'input' tensors of per-channel totals
            (any shape; they are flattened).
        param: tensor the mask is shaped like; dim 0 is treated as output channels
            and dim 1 as input channels.

    Returns:
        Tensor like `param` that is 1 on every row whose 'output' total is 0 and, for
        tensors with more than one dimension, on every column whose 'input' total is
        0; it is 0 elsewhere. A size mismatch between the totals and `param` raises
        the underlying indexing error instead of dropping into a debugger.
    """
    mask = torch.zeros_like(param, device=param.device)
    if 'output' in io:
        mask[io['output'].view(-1) == 0] = 1.
    if 'input' in io and len(mask.shape) > 1:
        mask[:, io['input'].view(-1) == 0] = 1.
    return mask

def mix_weights(model, alpha, key0, key1, module2io=None, whitelist_fn=lambda x: True):
    """Load into `model` the interpolation `(1 - alpha) * ckpt[key0] + alpha * ckpt[key1]`.

    Args:
        model: module that receives the mixed state dict.
        alpha: interpolation weight given to checkpoint `key1`.
        key0, key1: checkpoint keys in the shared checkpoint directory.
        module2io: optional mapping from module name to its io dict. Where the mask
            from `combine_io_masks` is 1, the value of checkpoint `key0` is kept.
        whitelist_fn: predicate on the state-dict key; entries for which it returns
            a falsy value are taken from checkpoint `key0` unchanged.
    """
    ckpt_pattern = '/srv/share/gstoica3/checkpoints/REPAIR/%s.pth.tar'
    sd0 = torch.load(ckpt_pattern % key0, map_location=torch.device(DEVICE))
    sd1 = torch.load(ckpt_pattern % key1, map_location=torch.device(DEVICE))

    sd_alpha = {}
    for k, value0 in sd0.items():
        param0 = value0.to(DEVICE)
        param1 = sd1[k].to(DEVICE)
        blended = (1 - alpha) * param0 + alpha * param1

        if module2io is not None:
            mask = combine_io_masks(module2io[strip_param_suffix(k)], param1)
            keep_first = mask == 1
            blended[keep_first] = param0[keep_first].to(blended.dtype)

        sd_alpha[k] = blended if whitelist_fn(k) else param0
    model.load_state_dict(sd_alpha)

In [31]:
# Baseline: average the two trained checkpoints 50/50 without any alignment.
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}'
)

In [32]:
# Ten-class accuracy of the averaged model, scored against all ten text embeddings
# with the original labels (no remapping).
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
) 

0.1


In [33]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    """Recompute the running statistics of every `nn.BatchNorm2d` in `model`.

    The statistics are cleared and momentum is set to None (cumulative average);
    the model is then run in train mode over `loader` for `epochs` passes without
    gradients. The model is left in train mode.
    """
    # clearing the stats first is necessary for stability
    bn_layers = [m for m in model.modules() if type(m) == nn.BatchNorm2d]
    for bn in bn_layers:
        bn.momentum = None  # use simple average
        bn.reset_running_stats()

    # forward passes in train mode are what update the running statistics
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                model(images.to(DEVICE))

In [34]:
# Recompute batchnorm statistics for the averaged weights, then re-evaluate.
reset_bn_stats(avg_model)
print('Post-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)

Post-reset:
0.4194


# Combine models via permutation

In [290]:
def compute_perm_map(corr_mtx):
    """Greedily match rows to columns of a square correlation matrix.

    Pairs are visited from the highest to the lowest correlation; a pair is accepted
    when neither its row nor its column has been matched yet.

    Args:
        corr_mtx: (n, n) tensor of correlations.

    Returns:
        perm_map: long tensor where `perm_map[i]` is the column matched to row i.
        qual_map: row indices ordered from the best-matched to the worst-matched
            row (used for visualisation only).
    """
    nchan = corr_mtx.shape[0]
    # One bulk transfer to Python floats instead of n^2 `.item()` calls.
    corr = corr_mtx.tolist()

    # sort the (i, j) channel pairs by correlation
    triples = [(i, j, corr[i][j]) for i in range(nchan) for j in range(nchan)]
    triples.sort(key=lambda p: -p[2])

    # greedily find a matching; a set makes the "column already taken" test O(1)
    perm_d = {}
    taken_cols = set()
    for i, j, _ in triples:
        if i not in perm_d and j not in taken_cols:
            perm_d[i] = j
            taken_cols.add(j)
    perm_map = torch.tensor([perm_d[i] for i in range(nchan)])

    # correlation achieved by each row's match, then rows ranked by it
    qual_l = [corr[i][perm_d[i]] for i in range(nchan)]
    qual_map = torch.tensor(sorted(range(nchan), key=lambda i: -qual_l[i]))

    return perm_map, qual_map

def get_layer_perm1(corr_mtx, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    """Turn a correlation matrix into a permutation matrix plus a keep-mask.

    Args:
        corr_mtx: (n, n) correlation tensor; rows index the first network's channels
            and columns the second network's.
        method: 'max_weight' (optimal assignment via `linear_sum_assignment`) or
            'greedy' (`compute_perm_map`).
        vizz: only used by 'greedy'; when True the reordered matrix is passed to
            `viz`, which is not defined in the visible part of this notebook.
        prune_threshold: matches with a correlation below this value get 0 in the
            returned mask.

    Returns:
        perm_map: (n, n) float one-hot matrix on `corr_mtx.device`; row i selects
            the column matched to row i.
        pruned_elements: float32 vector of length n that is 1.0 where the matched
            correlation is >= `prune_threshold` and 0.0 otherwise.

    Raises:
        ValueError: for an unknown `method`.
    """
    n = corr_mtx.shape[0]
    if method == 'greedy':
        col_idx, qual_map = compute_perm_map(corr_mtx)
        if vizz:
            corr_mtx_viz = (corr_mtx[qual_map].T[col_idx[qual_map]]).T
            viz(corr_mtx_viz)
    elif method == 'max_weight':
        corr_mtx_a = corr_mtx.cpu().numpy()
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
        assert (row_ind == np.arange(len(corr_mtx_a))).all()
        col_idx = torch.tensor(col_ind).long()
    else:
        raise ValueError('Unknown method: %s' % method)

    # Shared tail: both methods produce a column index per row, so both return the
    # same kind of (matrix, mask) pair.
    col_idx = col_idx.to(corr_mtx.device)
    perm_map = torch.eye(n, device=corr_mtx.device)[col_idx]
    matched_corr = corr_mtx[torch.arange(n, device=corr_mtx.device), col_idx]
    pruned_elements = (matched_corr >= prune_threshold).to(torch.float32)
    return perm_map, pruned_elements

# returns the channel-permutation matrix that makes net1's activations most closely
# match net0's, together with the mask of matches that pass prune_threshold.
def get_layer_perm(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    """Correlate the outputs of `net0` and `net1`, then solve for a channel matching."""
    return get_layer_perm1(
        run_corr_matrix(net0, net1),
        method=method,
        vizz=vizz,
        prune_threshold=prune_threshold,
    )

In [291]:
# Reload both trained models. Note the naming offset: `model0` holds the model trained
# on model1_classes and `model1` the one trained on model2_classes.
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
load_model(model0, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')
load_model(model1, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

# Sanity check: accuracy of each model on its own five classes.
print(evaluate(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.9558


KeyboardInterrupt: 

In [37]:
# -inf keeps every match; the value is also embedded in the checkpoint name below.
prune_threshold = -torch.inf
from collections import defaultdict
# Fresh bookkeeping dict for this run: module name -> {'input'/'output': totals}.
module2io = defaultdict(lambda: dict())

In [39]:
# Align model1 to model0 by channel permutation. model1 is passed both as the network
# to correlate and as the network to alter, so it is rewritten in place while the
# later correlation matrices are being computed.
model_to_alter, module2io = transform_model(
    model0, 
    model1, 
    model_to_alter=model1, 
    transform_fn=get_layer_perm, 
    prune_threshold=-torch.inf, 
    module2io=module2io
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.35it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 12.19it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.16it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 25.59it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.28it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.42it/s]
100%|███████████████████████████████████

In [40]:
# Persist the permuted model1 so mix_weights can load it by key.
save_model(model1, f'resnet20x4_CIFAR5_perm1_{prune_threshold}threshold_new')

In [42]:
# Average the first model with the permuted second model, 50/50.
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR5_perm1_{prune_threshold}threshold_new'
)

# Ten-class accuracy before and after recomputing batchnorm statistics.
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)
reset_bn_stats(avg_model)
print('Post-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)

0.1
Post-reset:
0.5406


# Combine Models via Bipartite Matching

In [45]:
def get_bipartite_perm(corr, prune_threshold=-torch.inf):
    """Assign every column of `corr` to its best-correlated row (many-to-one).

    Columns whose best correlation is below `prune_threshold` are assigned to no row.

    Returns:
        transform: (n_rows, n_cols) matrix whose entry (r, c) is `1 / (totals[r] + 1)`
            when column c is assigned to row r and 0 otherwise.
        totals: (1, n_rows) number of columns assigned to each row.
    """
    n_rows = corr.shape[0]
    best_score, best_row = corr.max(0)
    keep = best_score >= prune_threshold
    # pruned columns point at an extra all-zero lookup row
    best_row = torch.where(keep, best_row, n_rows)
    one_hot_rows = torch.eye(n_rows + 1, n_rows, device=corr.device)
    matches = one_hot_rows[best_row]
    totals = matches.sum(0, keepdim=True)
    weighted = matches / (totals + 1)
    return weighted.t(), totals

def get_layer_bipartite_transform(net0, net1, prune_threshold=-torch.inf):
    """Correlate the outputs of `net0` and `net1` and build the many-to-one matching."""
    return get_bipartite_perm(run_corr_matrix(net0, net1), prune_threshold=prune_threshold)

In [43]:
# Reload fresh copies. `model_to_alter` is a separate copy of the second model, so the
# correlations are always measured on the unmodified `model1`.
model0 = resnet20(w=4).to(DEVICE)
model1 = resnet20(w=4).to(DEVICE)
model_to_alter = resnet20(w=4).to(DEVICE)

load_model(model0, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')
load_model(model1, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')
load_model(model_to_alter, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

# Sanity check: accuracy of each model on its own five classes.
print(evaluate(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))
print(evaluate(model_to_alter, test_loader2, class_vecs2, remap_class_idxs=class_idxs))


0.9558
0.9726
0.9726


In [44]:
# -inf keeps every match; the value is also embedded in the checkpoint name below.
prune_threshold = -torch.inf
from collections import defaultdict
# Fresh bookkeeping dict for this run: module name -> {'input'/'output': totals}.
module2io = defaultdict(lambda: dict())

In [46]:
# Align the copy of the second model to model0 using the many-to-one bipartite matching.
model_to_alter, module2io = transform_model(
    model0, 
    model1, 
    model_to_alter=model_to_alter, 
    transform_fn=get_layer_bipartite_transform, 
    prune_threshold=-torch.inf, 
    module2io=module2io
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.19it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.85it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.01it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 25.07it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.75it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.51it/s]
100%|███████████████████████████████████

In [48]:
# Persist the transformed model so mix_weights can load it by key.
save_model(model_to_alter, f'resnet20x4_CIFAR5_bipartite_{prune_threshold}threshold_new')

In [49]:
# Average the first model with the bipartite-transformed second model, 50/50.
avg_model = resnet20(w=4).to(DEVICE)

mix_weights(
    avg_model, 
    .5, 
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR5_bipartite_{prune_threshold}threshold_new'
)

# Ten-class accuracy before and after recomputing batchnorm statistics.
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)
reset_bn_stats(avg_model)
print('Post-reset:')
print(
    evaluate(
        avg_model, 
        test_loader, 
        class_vectors=text_features
    )
)

0.1
Post-reset:
0.4777


# Combine Via Procrustes

In [292]:
def get_procrustes(corr_mtx):
    """Return the orthogonal factor ``U @ Vh`` of the SVD of ``corr_mtx``.

    The singular values are discarded. For a square input this is the
    solution of the orthogonal Procrustes problem, i.e. the orthogonal
    matrix closest to ``corr_mtx`` in Frobenius norm.
    """
    left_vecs, _, right_vecs_h = torch.linalg.svd(corr_mtx)
    return torch.matmul(left_vecs, right_vecs_h)

# returns the orthogonal (Procrustes) channel transform that makes net1's
# activations most closely match net0's. Unlike a permutation, the returned
# matrix is dense and may contain negative entries.
def get_layer_procrustes(net0, net1, prune_threshold=None):
    """Compute an orthogonal alignment between the outputs of two networks.

    Args:
        net0: reference network (its activations are the alignment target).
        net1: network whose channels are to be aligned to ``net0``.
        prune_threshold: only used by the permutation fallback, where it is
            forwarded to ``get_layer_perm1``.

    Returns:
        A pair ``(transform, collapse_totals)``. On the normal path
        ``transform`` is ``get_procrustes(corr_mtx)`` and ``collapse_totals``
        is a vector of ones with one entry per row of the correlation matrix.
        If the SVD fails, the result of ``get_layer_perm1`` is returned
        instead.
    """
    corr_mtx = run_corr_matrix(net0, net1)
    try:
        transform = get_procrustes(corr_mtx)
    except RuntimeError as err:
        # torch.linalg.svd reports failures (e.g. non-convergence) as
        # RuntimeError subclasses. Only that case falls back to a permutation;
        # any other exception is a programming error and must propagate
        # instead of being masked by a bare except or a debugger prompt.
        print('Applying Permutation')
        print(f'get_procrustes failed: {err!r}')
        return get_layer_perm1(corr_mtx, 'max_weight', vizz=False, prune_threshold=prune_threshold)
    return transform, torch.ones(corr_mtx.shape[0], device=corr_mtx.device)

In [297]:
# Reference network: checkpoint for the model1_classes split.
model0 = resnet20(w=4).to(DEVICE)
load_model(model0, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')

# Network to be aligned: checkpoint for the model2_classes split.
model1 = resnet20(w=4).to(DEVICE)
load_model(model1, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

# Second copy of the model2_classes checkpoint; this is the one that gets modified.
model_to_alter = resnet20(w=4).to(DEVICE)
load_model(model_to_alter, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

# Sanity check: each network evaluated on its own split before any alignment.
print(evaluate(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))
print(evaluate(model_to_alter, test_loader2, class_vecs2, remap_class_idxs=class_idxs))


0.9558
0.9726
0.9726


In [298]:
# Threshold forwarded to the alignment step and embedded in checkpoint names.
# NOTE(review): -inf presumably disables pruning entirely -- confirm in transform_model.
prune_threshold = -torch.inf
# Start from empty bookkeeping: module2io[module_name]['input' | 'output'] -> tensor.
# (defaultdict is already imported in the notebook's first cell.)
module2io = defaultdict(dict)

In [295]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    """Estimate the channel cross-correlation between the outputs of two networks.

    Both networks are switched to eval mode and run under ``torch.no_grad()``
    on every batch of ``loader`` for ``epochs`` passes. Per-batch statistics
    are computed in float64 and averaged with equal weight per batch.

    Args:
        net0, net1: modules returning tensors laid out as (N, C, ...).
        epochs: number of passes over ``loader``.
        norm: if True return the correlation (covariance divided by the outer
            product of the averaged per-batch standard deviations, plus 1e-4);
            otherwise return the raw covariance.
        loader: iterable of ``(images, labels)`` batches supporting ``len``.

    Returns:
        A float32 tensor of shape (C0, C1).

    Raises:
        ValueError: if there are no batches to process.
        RuntimeError: if the accumulated statistics contain NaN.

    Note:
        The mean, std and second moment are averages of per-batch values, so
        the result is exact only up to the usual batch-averaging approximation
        (batches of unequal size are weighted equally).
    """
    n = epochs * len(loader)
    if n <= 0:
        raise ValueError('run_corr_matrix needs at least one batch (got an empty loader or epochs <= 0)')

    def _as_samples(feat):
        # (N, C, ...) -> (N * prod(...), C), promoted to float64 for accumulation
        feat = feat.reshape(feat.shape[0], feat.shape[1], -1).permute(0, 2, 1)
        return feat.reshape(-1, feat.shape[2]).double()

    mean0 = mean1 = std0 = std1 = outer = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _epoch in range(epochs):
            for images, _ in tqdm(loader):
                img_t = images.float().to(DEVICE)
                out0 = _as_samples(net0(img_t))
                out1 = _as_samples(net1(img_t))

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                # Allocate the accumulators exactly once. The previous
                # per-epoch `i == 0` check zeroed them at the start of every
                # epoch, dropping earlier epochs while still dividing by n.
                if mean0 is None:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n
                # Fail loudly instead of dropping into an interactive debugger.
                if outer.isnan().any():
                    raise RuntimeError('run_corr_matrix: NaN encountered in the accumulated second moment')

    cov = outer - torch.outer(mean0, mean1)
    if cov.isnan().any():
        raise RuntimeError('run_corr_matrix: NaN encountered in the covariance matrix')
    if norm:
        # 1e-4 keeps the division finite for (near-)constant channels
        corr = cov / (torch.outer(std0, std1) + 1e-4)
        return corr.to(torch.float32)
    return cov.to(torch.float32)

In [296]:
class Subnet(nn.Module):
    """Wrap a model and expose only its stem followed by ``layer1``.

    ``forward(x)`` returns ``layer1(relu(bn1(conv1(x))))`` of the wrapped model.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        net = self.model
        stem_out = F.relu(net.bn1(net.conv1(x)))
        return net.layer1(stem_out)
# Align everything that writes to or reads from the layer1 residual stream.
# (Earlier permutation-based variant, kept for reference:
#   corr = run_corr_matrix(Subnet(model0), Subnet(model1))
#   perm_map1 = get_layer_perm1(corr) )
perm_map, collapse_totals = get_layer_procrustes(Subnet(model0), Subnet(model1))

# Producers of the stream: the stem conv/bn and each block's second conv/bn.
permute_output(perm_map, model_to_alter.conv1, model_to_alter.bn1)
for _blk in range(3):
    permute_output(perm_map, model_to_alter.layer1[_blk].conv2, model_to_alter.layer1[_blk].bn2)

# Consumers of the stream: each block's first conv, then the entry of layer2.
permute_input(perm_map, [model_to_alter.layer1[_blk].conv1 for _blk in range(3)])
permute_input(perm_map, [model_to_alter.layer2[0].conv1, model_to_alter.layer2[0].shortcut[0]])

# Record collapse_totals for every module touched above.
for _name in (
    'conv1', 'bn1',
    'layer1.0.conv2', 'layer1.0.bn2',
    'layer1.1.conv2', 'layer1.1.bn2',
    'layer1.2.conv2', 'layer1.2.bn2',
):
    module2io[_name]['output'] = collapse_totals

for _name in (
    'layer1.0.conv1', 'layer1.1.conv1', 'layer1.2.conv1',
    'layer2.0.conv1', 'layer2.0.shortcut.0',
):
    module2io[_name]['input'] = collapse_totals



100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.57it/s]


In [271]:
class Subnet(nn.Module):
    """Debug variant: expose only the wrapped model's first convolution.

    Re-defining ``Subnet`` here shadows the stem + layer1 version defined
    earlier in the notebook.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # Later stages (bn1 + relu, layer1, layer2) are deliberately left out
        # so the raw conv1 output can be inspected.
        return self.model.conv1(x)

In [272]:
# Walk the training batches until bn1 of the altered model produces a NaN.
# On exit `images` holds that batch (or the last batch if none triggered).
model_to_alter.eval()
for images, _ in train_aug_loader:
    if model_to_alter.bn1(Subnet(model_to_alter)(images.to(DEVICE))).isnan().any():
        break

In [273]:
# Put both networks in eval mode and capture Subnet's output for the altered
# model (x) and for model1 (y) on the batch currently bound to `images`.
model1.eval()
model_to_alter.eval()
x = Subnet(model_to_alter)(images.to(DEVICE))
y = Subnet(model1)(images.to(DEVICE))

In [274]:
# Number of NaN entries in x (the recorded output is 0).
x.isnan().sum()

tensor(0, device='cuda:0')

In [192]:
# Channel 2 of the first sample in x.
x[0, 2]

tensor([[ 0.0047, -0.2131, -0.2379,  ..., -0.4124, -0.6936, -0.7678],
        [ 0.1949,  0.0194,  0.1038,  ..., -0.2479, -0.6081, -0.9205],
        [ 0.1955,  0.0537,  0.1642,  ..., -0.2278, -0.6081, -0.9205],
        ...,
        [ 0.3429,  0.5502,  0.5937,  ..., -0.4127, -0.6081, -0.9205],
        [ 0.3528,  0.5601,  0.6413,  ..., -0.4254, -0.6081, -0.9205],
        [ 0.2916,  0.5425,  0.6184,  ..., -0.7634, -0.9897, -1.2217]],
       device='cuda:0', grad_fn=<SelectBackward0>)

In [191]:
# Sample 0, channel 2 after bn1 for the altered model and for model1; in the
# recorded output the altered model is all-NaN while model1 is finite.
print(model_to_alter.bn1(x)[0, 2])
print(model1.bn1(y)[0,2])

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<SelectBackward0>)
tensor([[ 0.2029,  0.3060,  0.2887,  ...,  0.2293,  0.5042,  0.5268],
        [ 0.0252,  0.0989,  0.1011,  ..., -0.0957,  0.1328,  0.2677],
        [-0.0422,  0.0133,  0.0426,  ..., -0.0927,  0.1328,  0.2677],
        ...,
        [ 0.1896,  0.2273,  0.1936,  ...,  0.0416,  0.1328,  0.2677],
        [ 0.2037,  0.2602,  0.2511,  ...,  0.0430,  0.1328,  0.2677],
        [ 0.1875,  0.3116,  0.4022,  ..., -0.4774, -0.4942, -0.2350]],
       device='cuda:0', grad_fn=<SelectBackward0>)


In [193]:
# bn1 affine weight of the altered model (64 entries in the recorded output).
model_to_alter.bn1.weight

Parameter containing:
tensor([ 0.2404,  0.1396, -0.0716,  0.0786, -0.0161, -0.0274, -0.0696,  0.1303,
         0.0332, -0.0910, -0.1487,  0.0779, -0.0313,  0.1654,  0.0041,  0.0915,
         0.0815,  0.2537,  0.2602,  0.0523,  0.0832,  0.0276,  0.0186,  0.0259,
         0.0361,  0.1434,  0.1514,  0.1169,  0.1514,  0.1451,  0.0742,  0.2794,
         0.0354,  0.0278,  0.1422,  0.1603,  0.0654, -0.0599,  0.1679,  0.0952,
         0.0704,  0.0535,  0.1706, -0.0352,  0.0277,  0.1155,  0.0524,  0.1103,
         0.0493,  0.1667,  0.0365,  0.1371,  0.1578,  0.1853, -0.0271,  0.0426,
         0.3463,  0.0887, -0.0385,  0.4119,  0.1525,  0.2235,  0.0598,  0.2431],
       device='cuda:0', requires_grad=True)

In [194]:
# bn1 affine weight of the unaltered model1, for comparison.
model1.bn1.weight

Parameter containing:
tensor([ 1.4097e-03,  1.2869e-04,  2.3129e-01,  1.9476e-01,  1.4690e-01,
         2.1003e-01,  7.3980e-03,  1.3881e-01,  1.4591e-01,  4.1396e-03,
         9.8003e-03,  6.2927e-02,  4.5149e-02,  3.7162e-02,  1.6157e-01,
         2.6507e-03,  4.1655e-03,  7.1932e-02,  2.6579e-01,  3.8272e-03,
         1.9317e-01,  2.4836e-02,  1.4905e-01,  1.5718e-01,  4.1719e-02,
         7.0815e-03,  1.7597e-01,  7.3099e-03,  1.9581e-01,  8.1096e-03,
         1.4232e-01,  2.0811e-01,  2.4732e-01,  2.0034e-01,  1.4372e-01,
         1.7139e-01,  2.1959e-03,  3.4310e-02,  1.0486e-03,  2.0663e-01,
        -1.3367e-03,  1.7193e-01,  8.8805e-02,  6.6579e-03,  1.3809e-01,
         2.7566e-01,  1.6680e-01,  2.8706e-01,  1.0999e-01,  6.9599e-03,
         3.4698e-05,  2.1785e-01,  1.8298e-01,  3.4205e-02,  8.3630e-02,
         1.6947e-01,  3.7090e-03,  2.4517e-01,  1.0076e-03,  1.4919e-01,
         1.0005e-03,  1.1864e-03,  2.2749e-01,  1.2409e-02], device='cuda:0',
       requires_grad=Tru

In [206]:
# Full bn1 output of the altered model on x.
model_to_alter.bn1(x)

tensor([[[[ 3.8395e-01,  2.5076e-01,  1.9080e-01,  ...,  4.3645e-01,
            2.2975e-01,  8.8654e-02],
          [ 3.7421e-01,  1.2216e-01,  6.8440e-02,  ...,  3.8697e-01,
            6.8679e-02, -5.3376e-02],
          [ 3.8322e-01,  1.2404e-01,  1.1459e-01,  ...,  4.0046e-01,
            6.8679e-02, -5.3376e-02],
          ...,
          [ 1.6105e-01,  5.0846e-02,  7.1986e-02,  ...,  2.1359e-01,
            6.8679e-02, -5.3376e-02],
          [ 1.8690e-01,  6.0473e-02,  6.9189e-02,  ...,  2.0503e-01,
            6.8679e-02, -5.3376e-02],
          [ 1.2470e-01,  4.2930e-02,  5.8450e-02,  ...,  2.2539e-01,
            1.1935e-01, -5.6858e-02]],

         [[-1.4520e-01, -1.8773e-01, -1.8586e-01,  ..., -9.9788e-02,
           -2.0183e-01, -1.3094e-01],
          [-1.1241e-01, -1.5141e-01, -1.4925e-01,  ..., -4.3954e-02,
           -1.0142e-01, -9.7098e-02],
          [-1.1434e-01, -1.4998e-01, -1.4259e-01,  ..., -4.8374e-02,
           -1.0142e-01, -9.7098e-02],
          ...,
     

In [None]:
def reinit_bn(bn):
    """Return a freshly initialised BatchNorm2d configured like ``bn``.

    The new layer copies ``bn``'s hyper-parameters (``num_features``, ``eps``,
    ``momentum``, ``affine``, ``track_running_stats``) but none of its state:
    running statistics and affine parameters start from PyTorch's defaults.
    The copy is placed on the same device as ``bn`` when ``bn`` holds any
    parameter or buffer.

    Args:
        bn: the ``torch.nn.BatchNorm2d`` layer to mirror.

    Returns:
        A new ``torch.nn.BatchNorm2d`` instance; ``bn`` itself is not modified.
    """
    bn_copy = torch.nn.BatchNorm2d(
        num_features=bn.num_features,
        eps=bn.eps,
        momentum=bn.momentum,
        affine=bn.affine,
        track_running_stats=bn.track_running_stats,
    )
    # Follow the source layer's device, if it has any tensor to take it from.
    existing = next(iter(list(bn.parameters()) + list(bn.buffers())), None)
    if existing is not None:
        bn_copy = bn_copy.to(existing.device)
    return bn_copy

In [275]:
# Affine-free BatchNorm that reuses model_to_alter.bn1's running statistics
# (the buffers are assigned by reference, not copied), so its output isolates
# the effect of the statistics from bn1's weight/bias.
# 64 is hard-coded; it matches the 64 bn1 weights printed earlier.
bn_copy = torch.nn.BatchNorm2d(64, affine=False)
bn_copy.running_mean = model_to_alter.bn1.running_mean
bn_copy.running_var = model_to_alter.bn1.running_var
# eval mode so the forward pass normalises with the running statistics above
bn_copy.eval()

BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=True)

In [277]:
# Normalise x with the running statistics only; the commented-out tail would
# re-apply bn1's affine weight and bias.
bn_copy(x)# * model_to_alter.bn1.weight.view(1, -1, 1, 1) + model_to_alter.bn1.bias.view(1, -1, 1, 1)

tensor([[[[ 8.9498e-01,  2.5174e-01,  3.1861e-01,  ..., -4.5012e-01,
            6.8267e-01,  1.0177e+00],
          [ 6.4966e-01, -1.0747e-01, -5.4608e-02,  ..., -3.3555e-01,
            6.2257e-01,  9.4255e-01],
          [ 5.6776e-01, -1.9791e-01, -3.3246e-01,  ..., -4.6928e-01,
            4.6046e-01,  1.0594e+00],
          ...,
          [ 6.7263e-01, -9.8307e-02, -8.2092e-02,  ..., -1.0027e-01,
            2.2852e-01, -1.8406e-01],
          [ 1.4178e+00,  2.5572e-01,  2.7387e-01,  ...,  1.8409e-01,
            1.6184e-01, -4.3754e-01],
          [ 1.1984e+00,  1.0435e-01,  1.0435e-01,  ...,  1.0435e-01,
            1.0435e-01, -6.2748e-01]],

         [[-1.1181e-02, -4.5649e-01, -4.3789e-01,  ...,  2.1656e-01,
           -2.4214e-01,  4.0661e-01],
          [ 3.4579e-01,  1.4526e-01,  2.8036e-01,  ...,  2.3088e-01,
            2.3999e-01,  4.6710e-01],
          [ 1.9874e-01,  4.8089e-02,  1.5994e-01,  ...,  3.2368e-01,
            3.3562e-01,  5.9756e-01],
          ...,
     

In [278]:
# Sample 0, channel 2 without affine parameters; the recorded output is still
# all-NaN, so the NaNs do not come from bn1's weight/bias.
bn_copy(x)[0, 2]

tensor([[nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        ...,
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan],
        [nan, nan, nan,  ..., nan, nan, nan]], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [284]:
# Manual normalisation with bn1's running mean / variance (no eps term), for
# comparison with bn_copy(x).
(x - model_to_alter.bn1.running_mean.view(1, -1, 1, 1)) / torch.sqrt(model_to_alter.bn1.running_var.view(1, -1, 1, 1))

tensor([[[[ 8.9499e-01,  2.5175e-01,  3.1861e-01,  ..., -4.5013e-01,
            6.8268e-01,  1.0177e+00],
          [ 6.4967e-01, -1.0747e-01, -5.4609e-02,  ..., -3.3555e-01,
            6.2257e-01,  9.4255e-01],
          [ 5.6777e-01, -1.9791e-01, -3.3246e-01,  ..., -4.6929e-01,
            4.6046e-01,  1.0594e+00],
          ...,
          [ 6.7263e-01, -9.8308e-02, -8.2093e-02,  ..., -1.0027e-01,
            2.2853e-01, -1.8406e-01],
          [ 1.4178e+00,  2.5572e-01,  2.7388e-01,  ...,  1.8409e-01,
            1.6184e-01, -4.3754e-01],
          [ 1.1984e+00,  1.0435e-01,  1.0435e-01,  ...,  1.0435e-01,
            1.0435e-01, -6.2749e-01]],

         [[-1.1181e-02, -4.5649e-01, -4.3789e-01,  ...,  2.1657e-01,
           -2.4214e-01,  4.0662e-01],
          [ 3.4580e-01,  1.4527e-01,  2.8037e-01,  ...,  2.3089e-01,
            2.3999e-01,  4.6711e-01],
          [ 1.9875e-01,  4.8090e-02,  1.5994e-01,  ...,  3.2368e-01,
            3.3563e-01,  5.9757e-01],
          ...,
     

In [286]:
# Per-channel sqrt of the altered model's running variance; the recorded output
# contains NaN for several channels (sqrt of a negative value yields NaN).
torch.sqrt(model_to_alter.bn1.running_var)

tensor([0.7878, 0.5915,    nan, 0.5153, 0.2243, 0.5003,    nan, 0.6875, 0.2514,
           nan,    nan, 0.2534, 0.2961, 0.5223,    nan, 0.3742, 0.6260, 0.7184,
        0.8876, 0.4244, 0.2122, 0.6230, 0.5964, 0.3609,    nan, 0.8267, 0.6419,
        0.5706, 0.7102, 0.8992, 0.9648, 0.6793, 0.4761,    nan, 0.6731,    nan,
        0.0813,    nan, 0.6199, 0.4215, 0.8196, 0.6250, 0.5971, 0.1250, 0.6889,
        0.4163,    nan, 0.4448,    nan, 0.5916,    nan,    nan, 1.0331, 0.8994,
           nan,    nan, 0.9964, 0.7899,    nan, 1.3990, 0.4436, 0.5711,    nan,
        0.9437], device='cuda:0')

In [287]:
# Running variance of channel 2 in the altered model (recorded: -0.0607, a
# negative and therefore invalid variance).
# NOTE(review): presumably the signed Procrustes matrix was applied linearly to
# running_var by permute_output -- confirm there.
model_to_alter.bn1.running_var[2]

tensor(-0.0607, device='cuda:0')

In [288]:
# Same channel in the unaltered model1 (recorded: 1.6596).
model1.bn1.running_var[2]

tensor(1.6596, device='cuda:0')

In [299]:
# Run the layer-wise alignment of model1 against model0 using the Procrustes
# transform; the result is written into model_to_alter and the per-module
# bookkeeping is returned in module2io.
# NOTE(review): transform_model is defined elsewhere in the notebook; the
# description above is inferred from its arguments -- confirm.
# prune_threshold is passed by name rather than re-typed as a literal so the
# value used here always matches the one embedded in the checkpoint filename.
model_to_alter, module2io = transform_model(
    model0,
    model1,
    model_to_alter,
    transform_fn=get_layer_procrustes,
    prune_threshold=prune_threshold,
    module2io=module2io,
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 14.47it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:08<00:00, 11.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.69it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 23.38it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 18.30it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 15.55it/s]
100%|███████████████████████████████████

In [300]:
# Persist the Procrustes-aligned model. The checkpoint name embeds
# prune_threshold (an f-string of -torch.inf renders as "-inf"); the same name
# is passed to mix_weights in the next cell.
save_model(model_to_alter, f'resnet20x4_CIFAR5_procrustes_{prune_threshold}threshold')

In [301]:
# Fresh ResNet20 (width 4) that receives the mix of the model1_classes
# checkpoint and the Procrustes-aligned checkpoint saved above.
# NOTE(review): .5 is presumably the interpolation weight -- confirm in mix_weights.
avg_model = resnet20(w=4).to(DEVICE)
mix_weights(
    avg_model, .5,
    f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}',
    f'resnet20x4_CIFAR5_procrustes_{prune_threshold}threshold',
)

# Evaluation score right after mixing, before the BatchNorm statistics are reset.
print('Pre-reset:')
print(evaluate(avg_model, test_loader, class_vectors=text_features))

# Reset the BatchNorm statistics of the mixed model, then evaluate again.
reset_bn_stats(avg_model)
print('Post-reset:')
print(evaluate(avg_model, test_loader, class_vectors=text_features))

Pre-reset:
0.1
Post-reset:
0.3749
