# Train-Merge-REPAIR-VGG11

This is a minimal notebook which demonstrates REPAIR applied to a VGG11 network. It does the following:
1. Separately trains two VGG11 networks on CIFAR-10, "A" and "B".
2. Permutes the channels of each convolutional layer in "B" in order to align them with "A".
3. Merges the two models in weight-space. The merged model performs poorly.
4. Uses REPAIR to correct the neuronal statistics of the merged model.

Notes:
* The merged VGG network should initially attain 67.3% (+/-5%) accuracy. After REPAIR, it should reach 84.9% (+/-1%).
* The trained networks should obtain 89-91% accuracy.
* We use the original VGG architecture which does not contain normalization layers.
* To align channels, we maximize correlations between the activations of matched neurons; this method is due to Li et al. (2015) https://arxiv.org/abs/1511.07543
* REPAIR is a generalization of the method of resetting BatchNorms, which goes back to SWA https://arxiv.org/abs/1803.05407. We introduce it in https://arxiv.org/abs/2211.08403.

In [1]:
from tqdm import tqdm
import numpy as np
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler
import torchvision
import torchvision.transforms as T

  from .autonotebook import tqdm as notebook_tqdm


## Data

In [2]:
import os

def save_model(model, i):
    sd = model.state_dict()
    torch.save(model.state_dict(), os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR/', '%s.pt' % i))

def load_model(model, i):
    sd = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % i))
    model.load_state_dict(sd)

In [3]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
def evaluate(model):
    model.eval()
    correct = 0
    with torch.no_grad(), autocast():
        for inputs, labels in test_loader:
            outputs = model(inputs.cuda())
            pred = outputs.argmax(dim=1)
            correct += (labels.cuda() == pred).sum().item()
    return correct / len(test_loader.dataset)

# Train two VGG11 networks on CIFAR-10

In [5]:
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/vgg.py
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name, w=1):
        super(VGG, self).__init__()
        self.w = w
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(self.w*512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers.append(nn.Conv2d(in_channels if in_channels == 3 else self.w*in_channels,
                                     self.w*x, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
    
def vgg11(w=1):
    return VGG('VGG11', w).cuda()

In [6]:
def train_model(w=1):
    model = vgg11(w)
    optimizer = SGD(model.parameters(), lr=0.08, momentum=0.9, weight_decay=5e-4)

    EPOCHS = 100
    ne_iters = len(train_aug_loader)
    lr_schedule = np.interp(np.arange(1+EPOCHS*ne_iters), [0, 5*ne_iters, EPOCHS*ne_iters], [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()

    losses = []
    for epoch in tqdm(range(EPOCHS)):
        model.train()
        for i, (inputs, labels) in enumerate(train_aug_loader):
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                outputs = model(inputs.cuda())
                loss = loss_fn(outputs, labels.cuda())
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            losses.append(loss.item())
    return model

In [7]:
model = train_model()
print(evaluate(model))
save_model(model, 'vgg11_v1')

model = train_model()
print(evaluate(model))
save_model(model, 'vgg11_v2')

100%|███████████████████████████████████████████████████████████████████████| 100/100 [08:02<00:00,  4.82s/it]


0.895


100%|███████████████████████████████████████████████████████████████████████| 100/100 [07:51<00:00,  4.72s/it]


0.8979


# Permute the channels of model B to align with model A

In [8]:
# Given two networks net0, net1 which each output a feature map of shape NxCxWxH,
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the two
def run_corr_matrix(net0, net1):
    n = len(train_aug_loader)
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for i, (images, _) in enumerate(tqdm(train_aug_loader)):
            
            img_t = images.float().cuda()
            out0 = net0(img_t).double()
            out0 = out0.permute(0, 2, 3, 1).reshape(-1, out0.shape[1])
            out1 = net1(img_t).double()
            out1 = out1.permute(0, 2, 3, 1).reshape(-1, out1.shape[1])

            # save batchwise first+second moments and outer product
            mean0_b = out0.mean(dim=0)
            mean1_b = out1.mean(dim=0)
            sqmean0_b = out0.square().mean(dim=0)
            sqmean1_b = out1.square().mean(dim=0)
            outer_b = (out0.T @ out1) / out0.shape[0]
            if i == 0:
                mean0 = torch.zeros_like(mean0_b)
                mean1 = torch.zeros_like(mean1_b)
                sqmean0 = torch.zeros_like(sqmean0_b)
                sqmean1 = torch.zeros_like(sqmean1_b)
                outer = torch.zeros_like(outer_b)
            mean0 += mean0_b / n
            mean1 += mean1_b / n
            sqmean0 += sqmean0_b / n
            sqmean1 += sqmean1_b / n
            outer += outer_b / n

    cov = outer - torch.outer(mean0, mean1)
    std0 = (sqmean0 - mean0**2).sqrt()
    std1 = (sqmean1 - mean1**2).sqrt()
    corr = cov / (torch.outer(std0, std1) + 1e-4)
    return corr

In [9]:
def get_layer_perm1(corr_mtx):
    corr_mtx_a = corr_mtx.cpu().numpy()
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
    assert (row_ind == np.arange(len(corr_mtx_a))).all()
    perm_map = torch.tensor(col_ind).long()
    return perm_map

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_layer_perm1(corr_mtx)

In [10]:
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, layer):
    pre_weights = [layer.weight,
                   layer.bias]
    for w in pre_weights:
        w.data = w[perm_map]

# modifies the weight matrix of a layer for a given permutation of the input channels
# works for both conv2d and linear
def permute_input(perm_map, layer):
    w = layer.weight
    w.data = w[:, perm_map]

In [11]:
model0 = vgg11()
model1 = vgg11()
load_model(model0, 'vgg11_v1')
load_model(model1, 'vgg11_v2')

evaluate(model0), evaluate(model1)

(0.895, 0.8979)

In [12]:
def subnet(model, n_layers):
    return model.features[:n_layers]

feats1 = model1.features

n = len(feats1)
for i in range(n):
    if not isinstance(feats1[i], nn.Conv2d):
        continue
    
    # permute the outputs of the current conv layer
    assert isinstance(feats1[i+1], nn.ReLU)
    perm_map = get_layer_perm(subnet(model0, i+2), subnet(model1, i+2))
    permute_output(perm_map, feats1[i])
    
    # look for the next conv layer, whose inputs should be permuted the same way
    next_layer = None
    for j in range(i+1, n):
        if isinstance(feats1[j], nn.Conv2d):
            next_layer = feats1[j]
            break
    if next_layer is None:
        next_layer = model1.classifier
    permute_input(perm_map, next_layer)

100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.92it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.35it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.64it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.13it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.89it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.51it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.67it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.23it/s]


In [13]:
# ensure accuracy didn't change
# (it may be slightly different due to non-associativity of floating point arithmetic)
print(evaluate(model1))
save_model(model1, 'vgg11_v2_perm1')

0.8979


# Merge the two networks

In [14]:
def mix_weights(net, alpha, key0, key1):
    sd0 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key0))
    sd1 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key1))
    sd_alpha = {k: (1 - alpha) * sd0[k].cuda() + alpha * sd1[k].cuda()
                for k in sd0.keys()}
    net.load_state_dict(sd_alpha)

In [15]:
k0 = 'vgg11_v1'
k1 = 'vgg11_v2_perm1'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
mix_weights(model_a, alpha, k0, k1)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

(α=0): 89.5% 		<-- Model A
(α=1): 89.8% 		<-- Model B
(α=0.5): 72.9% 		<-- Merged model


# Merge the two networks via Procrustes

In [32]:
model0 = vgg11()
model1 = vgg11()
load_model(model0, 'vgg11_v1')
load_model(model1, 'vgg11_v2')

evaluate(model0), evaluate(model1)

(0.895, 0.8979)

In [33]:
def get_procrustes(corr_mtx):
    U, _, Vh = torch.linalg.svd(corr_mtx)
    return U @ Vh

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_procrustes(net0, net1):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_procrustes(corr_mtx)

In [47]:
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def procrustes_output(perm_map, layer):
    pre_weights = [layer.weight,
                   layer.bias]
    for w in pre_weights:
        if len(w.shape) == 4:
            w.data = torch.einsum('ab,bcde->acde', perm_map.to(torch.float32), w)
        else:
            w.data = perm_map.to(torch.float32) @ w

# modifies the weight matrix of a layer for a given permutation of the input channels
# works for both conv2d and linear
def procrustes_input(perm_map, layer):
    w = layer.weight
    if len(w.shape) == 4:
        w.data = torch.einsum('abcd,be->aecd', w, perm_map.t().to(torch.float32)) # w @ perm_map.t()
    else:
        w.data = w @ perm_map.t().to(torch.float32)

In [48]:
feats1 = model1.features

n = len(feats1)
for i in range(n):
    if not isinstance(feats1[i], nn.Conv2d):
        continue
    
    # permute the outputs of the current conv layer
    assert isinstance(feats1[i+1], nn.ReLU)
    procrustes_map = get_layer_procrustes(subnet(model0, i+2), subnet(model1, i+2))
    procrustes_output(procrustes_map, feats1[i])
    
    # look for the next conv layer, whose inputs should be permuted the same way
    next_layer = None
    for j in range(i+1, n):
        if isinstance(feats1[j], nn.Conv2d):
            next_layer = feats1[j]
            break
    if next_layer is None:
        next_layer = model1.classifier
    procrustes_input(procrustes_map, next_layer)

100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.74it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.53it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.86it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.60it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.63it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.24it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.68it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.33it/s]


In [49]:
# ensure accuracy didn't change
# (it may be slightly different due to non-associativity of floating point arithmetic)
print(evaluate(model1))
save_model(model1, 'vgg11_v2_procrustes1')

0.1


In [50]:
k0 = 'vgg11_v1'
k1 = 'vgg11_v2_procrustes1'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
mix_weights(model_a, alpha, k0, k1)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

(α=0): 89.5% 		<-- Model A
(α=1): 10.0% 		<-- Model B
(α=0.5): 22.5% 		<-- Merged model


In [17]:
feats1 = model1.features

# The Fails. Why?

In [51]:
model0 = vgg11()
model1 = vgg11()
load_model(model0, 'vgg11_v1')
load_model(model1, 'vgg11_v2')

evaluate(model0), evaluate(model1)

(0.895, 0.8979)

In [53]:
feats1 = model1.features

n = len(feats1)
for i in range(n):
    if not isinstance(feats1[i], nn.Conv2d):
        continue
    
    # permute the outputs of the current conv layer
    assert isinstance(feats1[i+1], nn.ReLU)
    perm_map = get_layer_perm(subnet(model0, i+2), subnet(model1, i+2))
    procrustes_map = get_layer_procrustes(subnet(model0, i+2), subnet(model1, i+2))
    break
#     permute_output(perm_map, feats1[i])
    
#     # look for the next conv layer, whose inputs should be permuted the same way
#     next_layer = None
#     for j in range(i+1, n):
#         if isinstance(feats1[j], nn.Conv2d):
#             next_layer = feats1[j]
#             break
#     if next_layer is None:
#         next_layer = model1.classifier
#     permute_input(perm_map, next_layer)
#     break

100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.33it/s]
100%|███████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 21.55it/s]


In [61]:
perm_map

tensor([ 8, 24,  6, 34, 22, 55, 14,  0,  1,  9, 26, 40, 23, 58, 11, 30, 63, 41,
        45, 48, 43,  7,  5, 50, 13, 31, 42, 60,  2, 62, 39, 20, 18, 16, 10, 27,
        54, 28, 21, 33, 15, 19, 29, 61, 59, 17, 47, 57, 35, 52, 46,  3, 49, 51,
        36, 12, 37, 32, 53, 56, 44,  4, 25, 38])

In [63]:
procrustes_map.argmax(-1)

tensor([ 8, 24,  6, 34, 22, 60, 14,  0,  4,  9, 26, 40,  7, 58, 11, 30, 63, 41,
        45, 48, 43, 13, 62, 50,  1, 31, 27, 43, 41, 62,  2, 20, 22, 16, 10, 27,
        54, 28, 48, 33, 15, 19, 29, 61, 59, 17, 47, 37, 26, 18, 46,  2, 34, 51,
        27, 12,  0, 32, 53,  9, 44, 20, 25, 38], device='cuda:0')

In [65]:
w = feats1[i].weight
b = feats1[i].bias

In [67]:
w_out_permuted = w[perm_map]
b_out_permuted = b[perm_map]

w_out_procrustes = torch.einsum('ab,bcde->acde', procrustes_map.to(torch.float32), w)
b_out_procrustes = procrustes_map.to(torch.float32) @ b

In [68]:
torch.linalg.norm(w_out_permuted - w_out_procrustes)

tensor(5.8678, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)

In [71]:
feats0 = model0.features
print(torch.linalg.norm(feats0[i].weight - w_out_procrustes))
print(torch.linalg.norm(feats0[i].weight - w_out_permuted))

tensor(3.8173, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)
tensor(6.2049, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)


In [None]:
0-\\

### Merge the two networks by pruning

In [None]:
sd0 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % 'vgg11_v1'))
sd1 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % 'vgg11_v2_perm1'))

In [None]:
features.6.weight

In [None]:
param_shape = sd1[key].shape
param0 = sd1[key].flatten(1)
param1 = sd1[key].flatten(1)



In [53]:
def mix_weights_pruning(net, alpha, key0, key1):
    d = {}
    sd0 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key0))
    sd1 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key1))
    sd_alpha = {}
    for key in sd0.keys():
        if 'weight' not in key:
            sd_alpha[key] = (1 - alpha) * sd0[key].cuda() + alpha * sd1[key].cuda()
            continue
        param_shape = sd1[key].shape
        param0 = sd0[key].flatten(1)
        param1 = sd1[key].flatten(1)
        
        cosine_num = torch.diagonal(param0 @ param1.T)
        cosine_denom = (
            torch.norm(param0, p=2,dim=-1) * torch.norm(param1, p=2,dim=-1)
        )
        cosine_alignments = (cosine_num / cosine_denom).reshape(-1)
#         import pdb; pdb.set_trace()
        sorted_alignments, sorted_args = cosine_alignments.sort()
        outlier_bound = 0.0# sorted_alignments.mean() - 1.97 * sorted_alignments.std()
        mixed_param = (1 - alpha) * param0 + alpha * param1
        conditional = cosine_alignments < outlier_bound
        if conditional.sum() > 0:
            print(f'pruning {conditional.sum()} components ouf of {param0.shape[1]} for {key}')
            d[key.replace('.weight', '')] = conditional
#             print(conditional)
#             print(outlier_bound)
#             print(cosine_alignments.sort())
#             return
        mixed_param[conditional] = param0[conditional]
#         new_param = torch.where(cosine_alignments <= outlier_bound, param0.t(), mixed_param.t()).t()
        sd_alpha[key] = mixed_param.cuda().reshape(*param_shape)
        
#     sd_alpha = {k: (1 - alpha) * sd0[k].cuda() + alpha * sd1[k].cuda()
#                 for k in sd0.keys()}
    net.load_state_dict(sd_alpha)
    return d

In [54]:
k0 = 'vgg11_v1'
k1 = 'vgg11_v2_perm1'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
key2cond = mix_weights_pruning(model_a, alpha, k0, k1)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

pruning 1 components ouf of 27 for features.0.weight
pruning 6 components ouf of 576 for features.3.weight
pruning 4 components ouf of 1152 for features.6.weight
pruning 1 components ouf of 2304 for features.8.weight
pruning 15 components ouf of 2304 for features.11.weight
pruning 20 components ouf of 4608 for features.13.weight
pruning 73 components ouf of 4608 for features.16.weight
pruning 200 components ouf of 4608 for features.18.weight
(α=0): 89.6% 		<-- Model A
(α=1): 90.2% 		<-- Model B
(α=0.5): 60.8% 		<-- Merged model


In [66]:
def mix_weights_pruning_ranges(net, alpha, key0, key1):
    d = {}
    sd0 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key0))
    sd1 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key1))
    k2sd_alpha = {}
    for k in tqdm(np.arange(-1., 1.05, .05)):
        k2sd_alpha[k] = {}
        
        for key in sd0.keys():
            if 'weight' not in key:
                k2sd_alpha[k][key] = (1 - alpha) * sd0[key].cuda() + alpha * sd1[key].cuda()
                continue
            
            param_shape = sd1[key].shape
            param0 = sd0[key].flatten(1)
            param1 = sd1[key].flatten(1)
            
            mixed_param = (1 - alpha) * param0 + alpha * param1
            
            cosine_num = torch.diagonal(param0 @ param1.T)
            cosine_denom = (
                torch.norm(param0, p=2,dim=-1) * torch.norm(param1, p=2,dim=-1)
            )
            cosine_alignments = (cosine_num / cosine_denom).reshape(-1)
            sorted_alignments, sorted_args = cosine_alignments.sort()
            conditional = cosine_alignments < k        
            if conditional.sum() > 0:
#                 print(f'pruning {conditional.sum()} components ouf of {param0.shape[1]} for {key}')
                d[key.replace('.weight', '')] = conditional
    
            mixed_param[conditional] = param0[conditional]
            k2sd_alpha[k][key] = mixed_param.cuda().reshape(*param_shape)
#     net.load_state_dict(sd_alpha)
    return k2sd_alpha

In [67]:
model_a = vgg11()
k2sd_alpha = mix_weights_pruning_ranges(model_a, alpha, k0, k1)

100%|████████████████████████████████████████████████████████████████████████| 41/41 [00:00<00:00, 232.54it/s]


In [None]:
for k, sd_alpha in k2sd_alpha.items():
    k0 = 'vgg11_v1'
    k1 = 'vgg11_v2_perm1'
    model_a = vgg11()
    mix_weights(model0, 0.0, k0, k1)
    mix_weights(model1, 1.0, k0, k1)
    model_a.load_state_dict(sd_alpha)
    print(f'(α={k:.3f}): {100*evaluate(model_a):.3f} \t\t<-- Merged model')

(α=-1.000): 58.160 		<-- Merged model
(α=-0.950): 58.160 		<-- Merged model
(α=-0.900): 58.160 		<-- Merged model
(α=-0.850): 58.160 		<-- Merged model
(α=-0.800): 58.160 		<-- Merged model
(α=-0.750): 58.160 		<-- Merged model
(α=-0.700): 58.160 		<-- Merged model
(α=-0.650): 58.160 		<-- Merged model
(α=-0.600): 58.160 		<-- Merged model
(α=-0.550): 58.160 		<-- Merged model
(α=-0.500): 58.160 		<-- Merged model


In [70]:
k0 = 'vgg11_v1'
k1 = 'vgg11_v2_perm1'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()

alpha = 0.5
key2cond = mix_weights_pruning(model_a, alpha, k0, k1)


pruning 5 components ouf of 1152 for features.6.weight
tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False,  True, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False,  True, False, False, False, False,
         True, False, False, False, False, False, False, False, False,  True,
        F

### Procrustes Aligning 

In [81]:
def compute_procrustes(A, B):
    R = B.T @ A
    U, s, Vh = torch.linalg.svd(R)
    return U @ Vh

def mix_weights_procrustes(net, alpha, key0, key1):
    d = {}
    sd0 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key0))
    sd1 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key1))
    sd_alpha = {}
    for key in sd0.keys():
        if 'weight' not in key:
            sd_alpha[key] = (1 - alpha) * sd0[key].cuda() + alpha * sd1[key].cuda()
            continue
        param_shape = sd1[key].shape
        param0 = sd0[key].flatten(1)
        param1 = sd1[key].flatten(1)
        
        cosine_num = torch.diagonal(param0 @ param1.T)
        cosine_denom = (
            torch.norm(param0, p=2,dim=-1) * torch.norm(param1, p=2,dim=-1)
        )
        cosine_alignments = (cosine_num / cosine_denom).reshape(-1)
#         import pdb; pdb.set_trace()
        sorted_alignments, sorted_args = cosine_alignments.sort()
        outlier_bound = sorted_alignments.mean() - 1.97 * sorted_alignments.std()
        mixed_param = (1 - alpha) * param0 + alpha * param1
        conditional = cosine_alignments < outlier_bound
        if conditional.sum() > 0:
            print(f'pruning {conditional.sum()} components ouf of {param0.shape[1]} for {key}')
            d[key.replace('.weight', '')] = conditional
        
        procrustes = compute_procrustes(param1[conditional == False], param0[conditional == False])
        param1[conditional] = param1[conditional == False] @ procrustes
        mixed_param[conditional] = ((1 - alpha) * param0 + alpha * param1)[conditional == False]
        mixed_param[conditional] = param0[conditional == False]
        
#         new_param = torch.where(cosine_alignments <= outlier_bound, param0.t(), mixed_param.t()).t()
        sd_alpha[key] = mixed_param.cuda().reshape(*param_shape)
        
#     sd_alpha = {k: (1 - alpha) * sd0[k].cuda() + alpha * sd1[k].cuda()
#                 for k in sd0.keys()}
    net.load_state_dict(sd_alpha)
    return d

In [82]:
k0 = 'vgg11_v1'
k1 = 'vgg11_v2_perm1'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
key2cond = mix_weights_procrustes(model_a, alpha, k0, k1)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

pruning 2 components ouf of 27 for features.0.weight


RuntimeError: shape mismatch: value tensor of shape [62, 27] cannot be broadcast to indexing result of shape [2, 27]

# Correct neuronal statistics with REPAIR

In [46]:
class TrackLayer(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm2d(layer.out_channels)
        
    def get_stats(self):
        return (self.bn.running_mean, self.bn.running_var.sqrt())
        
    def forward(self, x):
        x1 = self.layer(x)
        self.bn(x1)
        return x1

class ResetLayer(nn.Module):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm2d(layer.out_channels)
        
    def set_stats(self, goal_mean, goal_std):
        self.bn.bias.data = goal_mean
        self.bn.weight.data = goal_std
        
    def forward(self, x):
        x1 = self.layer(x)
        return self.bn(x1)

# adds TrackLayers around every conv layer
def make_tracked_net(net):
    net1 = vgg11()
    net1.load_state_dict(net.state_dict())
    for i, layer in enumerate(net1.features):
        if isinstance(layer, nn.Conv2d):
            net1.features[i] = TrackLayer(layer)
    return net1.cuda().eval()

# adds ResetLayers around every conv layer
def make_repaired_net(net):
    net1 = vgg11()
    net1.load_state_dict(net.state_dict())
    for i, layer in enumerate(net1.features):
        if isinstance(layer, nn.Conv2d):
            net1.features[i] = ResetLayer(layer)
    return net1.cuda().eval()

In [47]:
# reset all tracked BN stats against training data
def reset_bn_stats(model):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    model.train()
    with torch.no_grad(), autocast():
        for images, _ in train_aug_loader:
            output = model(images.cuda())

In [55]:
## Calculate all neuronal statistics in the endpoint networks
wrap0 = make_tracked_net(model0)
wrap1 = make_tracked_net(model1)
reset_bn_stats(wrap0)
reset_bn_stats(wrap1)

In [56]:
wrap_a = make_repaired_net(model_a)
# Iterate through corresponding triples of (TrackLayer, TrackLayer, ResetLayer)
# around conv layers in (model0, model1, model_a).
for (name0, track0), (name1, track1), (namea, reset_a) in zip(
    wrap0.named_modules(), wrap1.named_modules(), wrap_a.named_modules()
): 
    if not isinstance(track0, TrackLayer):
        continue  
    assert (isinstance(track0, TrackLayer)
            and isinstance(track1, TrackLayer)
            and isinstance(reset_a, ResetLayer))

    # get neuronal statistics of original networks
    mu0, std0 = track0.get_stats()
    mu1, std1 = track1.get_stats()
    # set the goal neuronal statistics for the merged network 
    goal_mean = (1 - alpha) * mu0 + alpha * mu1
    goal_std = (1 - alpha) * std0 + alpha * std1
    if name0 in key2cond:
        cond = key2cond[name0]
        goal_mean[cond] = mu0[cond]
        goal_std[cond] = std0[cond]
    reset_a.set_stats(goal_mean, goal_std)

# Estimate mean/vars such that when added BNs are set to eval mode,
# neuronal stats will be goal_mean and goal_std.
reset_bn_stats(wrap_a)

In [57]:
list(wrap0.named_modules())[2][1].get_stats()[0], list(wrap0.named_modules())[2][0]

(tensor([-1.1200e-01, -2.9725e-01, -3.3837e-01, -1.0683e-01,  1.1715e-01,
         -8.1634e-02, -1.6073e-01,  1.3370e-03, -1.0036e-01, -3.2025e-01,
          7.0103e-02, -1.7607e-01, -3.9152e-02, -2.9311e-01, -6.4997e-02,
         -3.1805e-02, -3.6700e-01, -3.2028e-01,  3.4334e-01, -1.2581e-01,
          6.9968e-04, -1.0635e-01, -5.6365e-02, -2.9894e-01, -2.8805e-01,
         -1.2322e-02, -1.3870e-01, -2.2542e-01, -8.2708e-02, -3.2506e-01,
         -5.4113e-01,  2.2096e-01, -8.4765e-02, -3.9836e-02, -7.0918e-02,
          2.0940e-01, -3.2710e-01, -4.6151e-01, -1.2840e-01,  3.8923e-01,
         -2.2121e-02, -4.0536e-01, -2.0058e-01, -1.4678e-01, -6.2533e-02,
         -5.8112e-01, -2.8919e-01, -3.4973e-01,  1.6151e-01, -3.5382e-01,
         -1.6087e-01, -2.5695e-01, -9.4491e-02, -4.5885e-01, -3.7647e-01,
         -1.5788e-01, -1.8804e-01, -1.6771e-01, -3.2800e-01, -7.4287e-01,
         -2.0859e-01, -4.2982e-02,  1.1481e-01, -1.4776e-01], device='cuda:0'),
 'features.0')

In [58]:
key2cond.keys()

dict_keys(['features.0', 'features.3', 'features.6', 'features.8', 'features.11', 'features.13', 'features.16', 'features.18'])

### Fuse added BNs so that weights are compatible with original architecture

In [50]:
def fuse_conv_bn(conv, bn):
    fused_conv = torch.nn.Conv2d(conv.in_channels,
                                 conv.out_channels,
                                 kernel_size=conv.kernel_size,
                                 stride=conv.stride,
                                 padding=conv.padding,
                                 bias=True)

    # set weights
    w_conv = conv.weight.clone()
    bn_std = (bn.eps + bn.running_var).sqrt()
    gamma = bn.weight / bn_std
    fused_conv.weight.data = (w_conv * gamma.reshape(-1, 1, 1, 1))

    # set bias
    beta = bn.bias + gamma * (-bn.running_mean + conv.bias)
    fused_conv.bias.data = beta
    
    return fused_conv

def fuse_tracked_net(net):
    net1 = vgg11()
    for i, rlayer in enumerate(net.features):
        if isinstance(rlayer, ResetLayer):
            fused_conv = fuse_conv_bn(rlayer.layer, rlayer.bn)
            net1.features[i].load_state_dict(fused_conv.state_dict())
    net1.classifier.load_state_dict(net.classifier.state_dict())
    return net1

In [59]:
# fuse the rescaling+shift coefficients back into conv layers
model_b = fuse_tracked_net(wrap_a)
assert abs(evaluate(model_b) - evaluate(wrap_a)) < 0.005

In [60]:
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model with REPAIR' % (100*evaluate(model_b)))

(α=0): 89.6% 		<-- Model A
(α=1): 90.2% 		<-- Model B
(α=0.5): 83.0% 		<-- Merged model with REPAIR


### Reweighting before pruning

In [61]:
k0 = 'vgg11_v1'
k1 = 'vgg11_v2_perm1'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
mix_weights(model_a, alpha, k0, k1)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

(α=0): 89.6% 		<-- Model A
(α=1): 90.2% 		<-- Model B
(α=0.5): 58.2% 		<-- Merged model


In [62]:
## Calculate all neuronal statistics in the endpoint networks
wrap0 = make_tracked_net(model0)
wrap1 = make_tracked_net(model1)
reset_bn_stats(wrap0)
reset_bn_stats(wrap1)

In [63]:
wrap_a = make_repaired_net(model_a)
# Iterate through corresponding triples of (TrackLayer, TrackLayer, ResetLayer)
# around conv layers in (model0, model1, model_a).
for (name0, track0), (name1, track1), (namea, reset_a) in zip(
    wrap0.named_modules(), wrap1.named_modules(), wrap_a.named_modules()
): 
    if not isinstance(track0, TrackLayer):
        continue  
    assert (isinstance(track0, TrackLayer)
            and isinstance(track1, TrackLayer)
            and isinstance(reset_a, ResetLayer))

    # get neuronal statistics of original networks
    mu0, std0 = track0.get_stats()
    mu1, std1 = track1.get_stats()
    # set the goal neuronal statistics for the merged network 
    goal_mean = (1 - alpha) * mu0 + alpha * mu1
    goal_std = (1 - alpha) * std0 + alpha * std1
    reset_a.set_stats(goal_mean, goal_std)

# Estimate mean/vars such that when added BNs are set to eval mode,
# neuronal stats will be goal_mean and goal_std.
reset_bn_stats(wrap_a)

In [64]:
# fuse the rescaling+shift coefficients back into conv layers
model_b = fuse_tracked_net(wrap_a)
assert abs(evaluate(model_b) - evaluate(wrap_a)) < 0.005

In [65]:
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model with REPAIR' % (100*evaluate(model_b)))

(α=0): 89.6% 		<-- Model A
(α=1): 90.2% 		<-- Model B
(α=0.5): 83.8% 		<-- Merged model with REPAIR
