In [1]:
from tqdm import tqdm
import numpy as np
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler
import torchvision
import torchvision.transforms as T

import pdb


  from .autonotebook import tqdm as notebook_tqdm


In [31]:
import os

def save_model(model, i):
    sd = model.state_dict()
    torch.save(model.state_dict(), os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR/', '%s.pt' % i))

def load_model(model, i):
    sd = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % i))
    model.load_state_dict(sd)
    
def evaluate(model):
    model.eval()
    correct = 0
    with torch.no_grad(), autocast():
        for inputs, labels in test_loader:
            outputs = model(inputs.cuda())
            pred = outputs.argmax(dim=1)
            correct += (labels.cuda() == pred).sum().item()
    return correct / len(test_loader.dataset)

def train_model(w=1):
    model = vgg11(w)
    optimizer = SGD(model.parameters(), lr=0.08, momentum=0.9, weight_decay=5e-4)

    EPOCHS = 100
    ne_iters = len(train_aug_loader)
    lr_schedule = np.interp(np.arange(1+EPOCHS*ne_iters), [0, 5*ne_iters, EPOCHS*ne_iters], [0, 1, 0])
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_schedule.__getitem__)

    scaler = GradScaler()
    loss_fn = CrossEntropyLoss()

    losses = []
    for epoch in tqdm(range(EPOCHS)):
        model.train()
        for i, (inputs, labels) in enumerate(train_aug_loader):
            optimizer.zero_grad(set_to_none=True)
            with autocast():
                outputs = model(inputs.cuda())
                loss = loss_fn(outputs, labels.cuda())
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            losses.append(loss.item())
    return model

# Given two networks net0, net1 which each output a feature map of shape NxCxWxH,
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the two
def run_corr_matrix(net0, net1):
    n = len(train_aug_loader)
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for i, (images, _) in enumerate(tqdm(train_aug_loader)):
            
            img_t = images.float().cuda()
            out0 = net0(img_t).double()
            out0 = out0.permute(0, 2, 3, 1).reshape(-1, out0.shape[1])
            out1 = net1(img_t).double()
            out1 = out1.permute(0, 2, 3, 1).reshape(-1, out1.shape[1])

            # save batchwise first+second moments and outer product
            mean0_b = out0.mean(dim=0)
            mean1_b = out1.mean(dim=0)
            sqmean0_b = out0.square().mean(dim=0)
            sqmean1_b = out1.square().mean(dim=0)
            outer_b = (out0.T @ out1) / out0.shape[0]
            if i == 0:
                mean0 = torch.zeros_like(mean0_b)
                mean1 = torch.zeros_like(mean1_b)
                sqmean0 = torch.zeros_like(sqmean0_b)
                sqmean1 = torch.zeros_like(sqmean1_b)
                outer = torch.zeros_like(outer_b)
            mean0 += mean0_b / n
            mean1 += mean1_b / n
            sqmean0 += sqmean0_b / n
            sqmean1 += sqmean1_b / n
            outer += outer_b / n

    cov = outer - torch.outer(mean0, mean1)
    std0 = (sqmean0 - mean0**2).sqrt()
    std1 = (sqmean1 - mean1**2).sqrt()
    corr = cov / (torch.outer(std0, std1) + 1e-4)
    return corr

In [4]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/vgg.py
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name, w=1):
        super(VGG, self).__init__()
        self.w = w
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(self.w*512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers.append(nn.Conv2d(in_channels if in_channels == 3 else self.w*in_channels,
                                     self.w*x, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
    
def vgg11(w=1):
    return VGG('VGG11', w).cuda()

# Bipartite Matching

In [8]:
sd0 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % 'vgg11_v1'))
sd1 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % 'vgg11_v2_perm1'))

In [86]:
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, layer):
    pre_weights = [layer.weight,
                   layer.bias]
    for w in pre_weights:
#         pdb.set_trace()
        if len(w.shape) == 4:
            w.data = torch.einsum('ab,bcde->acde', perm_map, w)
        else:
            w.data = perm_map @ w
#         w.data = w[perm_map]

# modifies the weight matrix of a layer for a given permutation of the input channels
# works for both conv2d and linear
def permute_input(perm_map, layer):
    w = layer.weight
#     w.data = w[:, perm_map]
    if len(w.shape) == 4:
        w.data = torch.einsum('abcd,be->aecd', w, perm_map.T)
    else:
        w.data = w @ perm_map.T

In [78]:
def get_layer_perm1(corr_mtx):
    corr_mtx_a = corr_mtx.cpu().numpy()
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
    assert (row_ind == np.arange(len(corr_mtx_a))).all()
    perm_map = torch.tensor(col_ind).long()
    perm_map = torch.eye(corr_mtx.shape[0], device=corr_mtx.device)[perm_map]
    return perm_map

def bipartite_matching(corr):
    idx = corr.argmax(0)
    mapping = torch.eye(corr.shape[0], device=corr.device)[idx]
    mapping = mapping / torch.maximum(
        mapping.sum(0, keepdim=True), 
        torch.ones_like(mapping[0], device=mapping.device)
    )
    pdb.set_trace()
    return mapping

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_layer_perm1(corr_mtx)
#     return bipartite_matching(corr_mtx)

In [80]:
def greedy_matching(alignments, threshold=0.5):
    matrix_copy = alignments.detach().cpu().numpy()
    matches = np.zeros((matrix_copy.shape[1], 2), dtype=np.float32)
    total_matches = 0
    while total_matches < matches.shape[0]:
        best_alignments, best_idxs = matrix_copy.max(0), matrix_copy.argmax(0)
        best_source = best_alignments.argmax()
        matches[best_source, 0] = best_idxs[best_source]
        matches[best_source, 1] = best_alignments[best_source]
        matrix_copy[best_idxs[best_source]] = -np.inf
        matrix_copy[:, best_source] = -np.inf
        total_matches += 1
    matches = torch.from_numpy(matches).to(alignments.device)
    permutations = torch.eye(
        alignments.shape[0], 
        device=alignments.device
    )[matches[:, 0].type(torch.long)]
#     pdb.set_trace()
    permutations[matches[:, 1] < threshold] = torch.zeros(
        permutations.shape[1], device=permutations.device
    )
#     print(permutations.sum(), permutations.shape[0])
#     pdb.set_trace()
    return permutations

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1):
    corr_mtx = run_corr_matrix(net0, net1)
#     return get_layer_perm1(corr_mtx)
    return greedy_matching(corr_mtx)

In [84]:
model0 = vgg11()
model1 = vgg11()
load_model(model0, 'vgg11_v1')
load_model(model1, 'vgg11_v2')

evaluate(model0), evaluate(model1)

(0.8982, 0.8984)

In [87]:
def subnet(model, n_layers):
    return model.features[:n_layers]

feats1 = model1.features

n = len(feats1)
for i in range(n):
    if not isinstance(feats1[i], nn.Conv2d):
        continue
    
    # permute the outputs of the current conv layer
    assert isinstance(feats1[i+1], nn.ReLU)
    perm_map = get_layer_perm(subnet(model0, i+2), subnet(model1, i+2))
    permute_output(perm_map, feats1[i])
    
    # look for the next conv layer, whose inputs should be permuted the same way
    next_layer = None
    for j in range(i+1, n):
        if isinstance(feats1[j], nn.Conv2d):
            next_layer = feats1[j]
            break
    if next_layer is None:
        next_layer = model1.classifier
    permute_input(perm_map, next_layer)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 23.99it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 22.61it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 24.43it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 23.22it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 24.33it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 23.57it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████████████

In [88]:
# ensure accuracy didn't change
# (it may be slightly different due to non-associativity of floating point arithmetic)
print(evaluate(model1))
save_model(model1, 'vgg11_v2_greedyprune.5')

0.1


In [97]:
def mix_weights(net, alpha, key0, key1):
    sd0 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key0))
    sd1 = torch.load(os.path.join('/srv/share4/gstoica3/checkpoints/REPAIR', '%s.pt' % key1))
    sd_alpha = {}
    for k in sd0.keys():
        v0 = sd0[k].cuda()
        v1 = sd1[k].cuda()
        if '0' not in k: 
            sd_alpha[k] = v0
            continue
        pdb.set_trace()
        sd_alpha[k] = (1 - alpha) * v0 + alpha * v1
#     sd_alpha = {k: (1 - alpha) * sd0[k].cuda() + alpha * sd1[k].cuda()
#                 for k in sd0.keys()}
    net.load_state_dict(sd_alpha)

In [98]:
k0 = 'vgg11_v1'
k1 = 'vgg11_v2_greedyprune.5'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
mix_weights(model_a, alpha, k0, k1)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

> [0;32m/tmp/ipykernel_9942/1373196417.py[0m(12)[0;36mmix_weights[0;34m()[0m
[0;32m     10 [0;31m            [0;32mcontinue[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m        [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 12 [0;31m        [0msd_alpha[0m[0;34m[[0m[0mk[0m[0;34m][0m [0;34m=[0m [0;34m([0m[0;36m1[0m [0;34m-[0m [0malpha[0m[0;34m)[0m [0;34m*[0m [0mv0[0m [0;34m+[0m [0malpha[0m [0;34m*[0m [0mv1[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     13 [0;31m[0;31m#     sd_alpha = {k: (1 - alpha) * sd0[k].cuda() + alpha * sd1[k].cuda()[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     14 [0;31m[0;31m#                 for k in sd0.keys()}[0m[0;34m[0m[0;34m[0m[0;34m[0m[0m
[0m
ipdb> v1.shape
torch.Size([64, 3, 3, 3])
ipdb> v1[:, 0, 0, 0]
tensor([ 0.0000,  0.0000,  0.0000,  0.0000,  0.0814,  0.0000, -0.1377, -0.0065,
         0.1004, -0.1905, -0.0904, -0.2233, 