In [60]:
import os
import sys
import pdb
from sys import platform

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

if platform == 'darwin':
    DEVICE = 'mps'
else:
    DEVICE = 'cuda'

In [4]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# evaluates accuracy
def evaluate(model, loader=test_loader, on_mac=False):
    device = 'mps' if on_mac else 'cuda'
    model.eval()
    correct = 0
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            pred = outputs.argmax(dim=1)
            correct += (labels.to(device) == pred).sum().item()
    return correct

# evaluates loss
def evaluate1(model, loader=test_loader, on_mac=False):
    device = 'mps' if on_mac else 'cuda'
    model.eval()
    losses = []
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            loss = F.cross_entropy(outputs, labels.to(device))
            losses.append(loss.item())
    return np.array(losses).mean()

def train(save_key, on_mac=False):
    device = 'mps' if on_mac else 'cuda'
    model = resnet20(w=4).to(device)
    optimizer = SGD(model.parameters(), lr=0.4, momentum=0.9, weight_decay=5e-4)
    # optimizer = Adam(model.parameters(), lr=0.05)
    
    # Adam seems to perform worse than SGD for training ResNets on CIFAR-10.
    # To make Adam work, we find that we need a very high learning rate: 0.05 (50x the default)
    # At this LR, Adam gives 1.0-1.5% worse accuracy than SGD.
    
    # It is not yet clear whether the increased interpolation barrier for Adam-trained networks
    # is simply due to the increased test loss of said networks relative to those trained with SGD.
    # We include the option of using Adam in this notebook to explore this question.

    EPOCHS = 100
    ne_iters = len(train_aug_loader)
    lr_schedule = np.interp(np.arange(1+EPOCHS*ne_iters), [0, 5*ne_iters, EPOCHS*ne_iters], [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()
    
    losses = []
    for _ in tqdm(range(EPOCHS)):
        for i, (inputs, labels) in enumerate(train_aug_loader):
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                outputs = model(inputs.to(device))
                loss = loss_fn(outputs, labels.to(device))
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            losses.append(loss.item())
    print(evaluate(model))
    save_model(model, save_key)
    
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader, on_mac=False):
    device = 'mps' if on_mac else 'cuda'
    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().to(device)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2])#.double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2])#.double()

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                if i == 0:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n

    cov = outer - torch.outer(mean0, mean1)
    if norm:
        corr = cov / (torch.outer(std0, std1) + 1e-4)
        return corr
    else:
        return cov

def compute_perm_map(corr_mtx):
    # sort the (i, j) channel pairs by correlation
    nchan = corr_mtx.shape[0]
    triples = [(i, j, corr_mtx[i, j].item()) for i in range(nchan) for j in range(nchan)]
    triples = sorted(triples, key=lambda p: -p[2])
    # greedily find a matching
    perm_d = {}
    for i, j, c in triples:
        if not (i in perm_d.keys() or j in perm_d.values()):
            perm_d[i] = j
    perm_map = torch.tensor([perm_d[i] for i in range(nchan)])

    # qual_map will be a permutation of the indices in the order
    # of the quality / degree of correlation between the neurons found in the permutation.
    # this just for visualization purposes.
    qual_l = [corr_mtx[i, perm_map[i]].item() for i in range(nchan)]
    qual_map = torch.tensor(sorted(range(nchan), key=lambda i: -qual_l[i]))

    return perm_map, qual_map

def get_layer_perm1(corr_mtx, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    if method == 'greedy':
        perm_map, qual_map = compute_perm_map(corr_mtx)
        if vizz:
            corr_mtx_viz = (corr_mtx[qual_map].T[perm_map[qual_map]]).T
            viz(corr_mtx_viz)
    elif method == 'max_weight':
        corr_mtx_a = corr_mtx.cpu().numpy()
        row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
        assert (row_ind == np.arange(len(corr_mtx_a))).all()
        perm_map = torch.tensor(col_ind).long()
        perm_map = torch.eye(corr_mtx.shape[0], device=corr_mtx.device)[perm_map]
    else:
        raise Exception('Unknown method: %s' % method)
    
#     pdb.set_trace()
    pruned_elements = torch.from_numpy(
        corr_mtx_a[row_ind, col_ind] >= prune_threshold
    ).to(perm_map.device).to(torch.float32)
    return perm_map, pruned_elements

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_layer_perm1(corr_mtx, method, vizz, prune_threshold)

# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, conv, bn):
    pre_weights = [
        conv.weight,
        bn.weight,
        bn.bias,
        bn.running_mean,
        bn.running_var,
    ]
    for w in pre_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('ab,bcde->acde', perm_map, w)
        elif len(w.shape) == 2:
            transform = perm_map @ w
        else:
            transform = w @ perm_map.t()
#         assert torch.allclose(w[perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[perm_map]

# modifies the weight matrix of a convolution layer for a given
# permutation of the input channels
def permute_input(perm_map, after_convs):
    if not isinstance(after_convs, list):
        after_convs = [after_convs]
    post_weights = [c.weight for c in after_convs]
    for w in post_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('abcd,be->aecd', w, perm_map.t())
        elif len(w.shape) == 2:
            transform = w @ perm_map.t()
    #     assert torch.allclose(w[:, perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[:, perm_map, :, :]

def permute_cls_output(perm_map, linear):
    for w in [linear.weight, linear.bias]:
        w.data = perm_map @ w

In [8]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    map_location = torch.device('mps') if on_mac else torch.device('cuda')
    sd = torch.load(path, map_location=map_location)
    model.load_state_dict(sd)

In [6]:
def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
#             self.shortcut = LambdaLayer(lambda x:
#                                         F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, w=1, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = w*16

        self.conv1 = nn.Conv2d(3, w*16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w*16)
        self.layer1 = self._make_layer(block, w*16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, w*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, w*64, num_blocks[2], stride=2)
        self.linear = nn.Linear(w*64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(w=1):
    return ResNet(BasicBlock, [3, 3, 3], w=w)

In [20]:
from sys import platform

In [69]:
def specify_device():
    if platform != 'darwin':
        return 'cuda'
    else:
        return 'mps'

# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.to(DEVICE))

def strip_param_suffix(name):
    return name.replace('.weight', '').replace('.bias', '')

def combine_io_masks(io, param):
    mask = torch.zeros_like(param, device=param.device)
    try:
        if 'output' in io:
            mask[io['output'].view(-1) == 0] = 1.
        if 'input' in io and len(mask.shape) > 1:
            mask[:, io['input'].view(-1) == 0] = 1.
    except:
        pdb.set_trace()
    return mask

def mix_weights(model, alpha, model0, model1, whitelisted_params=None):
    sd0 = model0.state_dict()
    sd1 = model1.state_dict()
    sd_alpha = {}
    for k in sd0.keys():
        param0 = sd0[k].to(DEVICE)
        param1 = sd1[k].to(DEVICE)
        
        if whitelisted_params is not None and strip_param_suffix(k) not in whitelisted_params:
            sd_alpha[k] = param0
            continue
        print(k)
        sd_alpha[k] = (1 - alpha) * param0 + alpha * param1
        
    model.load_state_dict(sd_alpha)

In [141]:
def get_procrustes(corr_mtx):
    U, _, Vh = torch.linalg.svd(corr_mtx)
    return U @ Vh

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_procrustes(pairs):
    objective = None
    for pair in pairs:
#         pdb.set_trace()
        pair_obj = pair[0] @ pair[1].t()
        objective = objective + pair_obj if objective is not None else pair_obj
    return get_procrustes(objective)

In [149]:
prep_input_conv = lambda conv: conv.weight.permute(1, 0, 2, 3).flatten(1)
prep_output_conv = lambda conv: conv.weight.flatten(1)

In [185]:
model0 = resnet20(w=4).to(device)
model1 = resnet20(w=4).to(device)
load_model(model0, 'resnet20x4_v5')
load_model(model1, 'resnet20x4_v4')

In [172]:
evaluate(model0)/100, evaluate(model1)/100

(95.38, 95.31)

In [186]:
conv1_procrustes = get_layer_procrustes(
    [
        # Affected Outputs
        (
            prep_output_conv(model0.conv1), 
            prep_output_conv(model1.conv1)
        ),
        (
            prep_output_conv(model0.layer1[0].conv2),
            prep_output_conv(model1.layer1[0].conv2),
        ),
        (
            prep_output_conv(model0.layer1[1].conv2),
            prep_output_conv(model1.layer1[1].conv2),
        ),
        (
            prep_output_conv(model0.layer1[2].conv2),
            prep_output_conv(model1.layer1[2].conv2),
        ),
        # Affected Inputs
        (
            prep_input_conv(model0.layer1[0].conv1), 
            prep_input_conv(model1.layer1[0].conv1)
        ),
        (
            prep_input_conv(model0.layer1[1].conv1), 
            prep_input_conv(model1.layer1[1].conv1)
        ),
        (
            prep_input_conv(model0.layer2[0].conv1), 
            prep_input_conv(model1.layer2[0].conv1)
        ),
        (
            prep_input_conv(model0.layer2[0].shortcut[0]), 
            prep_input_conv(model1.layer2[0].shortcut[0])
        )
    ]
)

In [174]:
conv1_procrustes[0]

tensor([-2.6174e-01, -1.9765e-03,  2.5963e-05,  5.7304e-02, -3.6224e-03,
        -8.9071e-03,  2.1652e-01,  1.6306e-04,  5.0346e-03,  5.2636e-03,
        -2.2815e-02, -4.4389e-02, -2.7383e-03,  9.9388e-03,  9.3380e-02,
         3.8725e-03,  9.8015e-02, -1.9425e-01,  1.6853e-04,  2.9632e-03,
        -8.8523e-04,  3.5012e-03, -2.7835e-02,  3.1955e-03,  3.5616e-02,
        -2.2257e-02, -4.1228e-03, -3.0342e-02, -5.7705e-03, -1.5230e-01,
        -5.0918e-02, -3.3739e-02,  1.2735e-02, -9.4574e-02, -1.1087e-02,
         1.5460e-04,  5.5336e-03, -1.9358e-02, -1.8505e-01, -4.8186e-02,
         2.1734e-02, -8.1023e-03,  8.4788e-01,  9.0214e-03,  1.1792e-02,
        -6.9269e-03,  1.0127e-02,  2.8304e-03,  1.1876e-01,  1.0427e-01,
         3.3406e-03, -2.0387e-02,  1.1719e-02,  3.4824e-03, -3.6582e-03,
        -2.6726e-04, -3.2073e-04, -5.9925e-03,  6.5378e-03, -1.0761e-02,
         1.3576e-02,  5.6282e-03,  5.6817e-03,  5.4326e-03], device='cuda:0',
       grad_fn=<SelectBackward0>)

In [187]:
permute_output(conv1_procrustes, model1.conv1, model1.bn1)
permute_output(conv1_procrustes, model1.layer1[0].conv2, model1.layer1[0].bn2)
permute_output(conv1_procrustes, model1.layer1[1].conv2, model1.layer1[1].bn2)
permute_output(conv1_procrustes, model1.layer1[2].conv2, model1.layer1[2].bn2)

In [188]:
permute_input(conv1_procrustes, [model1.layer1[0].conv1, model1.layer1[1].conv1, model1.layer1[2].conv1])
permute_input(conv1_procrustes, [model1.layer2[0].conv1, model1.layer2[0].shortcut[0]])

In [189]:
layer1_conv1_procrustes = get_layer_procrustes(
    [
        (
            prep_output_conv(model0.layer1[0].conv1), 
            prep_output_conv(model1.layer1[0].conv1)
        ),
        (
            prep_input_conv(model0.layer1[0].conv2),
            prep_input_conv(model0.layer1[0].conv2)
        )
    ]
)

permute_output(layer1_conv1_procrustes, model1.layer1[0].conv1, model1.layer1[0].bn1)
permute_input(layer1_conv1_procrustes, [model1.layer1[0].conv2, model1.layer1[0].bn2])

In [195]:
model_a = resnet20(w=4).to(device) # W_alpha
mix_weights(
    model_a, 
    0.5, 
    model0,
    model1,
    whitelisted_params=['conv1', 'layer1.0.conv1', 'layer1.0.conv2']
)

print('Pre-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(model_a)/100, evaluate1(model_a)))
reset_bn_stats(model_a)
print('Post-reset:')
print('Accuracy=%.2f%%, Loss=%.3f' % (evaluate(model_a)/100, evaluate1(model_a)))

conv1.weight
layer1.0.conv1.weight
layer1.0.conv2.weight
Pre-reset:
Accuracy=89.36%, Loss=0.408
Post-reset:
Accuracy=94.69%, Loss=0.192


In [183]:
# torch.allclose(model1.conv1.weight, model0.conv1.weight)

False