In [1]:
from tqdm import tqdm
import numpy as np
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, lr_scheduler
import torchvision
import torchvision.transforms as T

In [3]:
def save_model(model, i):
    sd = model.state_dict()
    torch.save(model.state_dict(), '%s.pt' % i)

def load_model(model, i):
    sd = torch.load('%s.pt' % i)
    model.load_state_dict(sd)
    
def evaluate(model):
    model.eval()
    correct = 0
    with torch.no_grad(), autocast():
        for inputs, labels in test_loader:
            outputs = model(inputs.cuda())
            pred = outputs.argmax(dim=1)
            correct += (labels.cuda() == pred).sum().item()
    return correct / len(test_loader.dataset)

In [4]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [10]:
# https://github.com/kuangliu/pytorch-cifar/blob/master/models/vgg.py
cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name, w=1):
        super(VGG, self).__init__()
        self.w = w
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Linear(self.w*512, 10)

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        return out

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers.append(nn.Conv2d(in_channels if in_channels == 3 else self.w*in_channels,
                                     self.w*x, kernel_size=3, padding=1))
                layers.append(nn.ReLU(inplace=True))
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)
    
def vgg11(w=1):
    return VGG('VGG11', w).cuda()

In [106]:
# Given two networks net0, net1 which each output a feature map of shape NxCxWxH,
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the two
def run_corr_matrix(net0, net1):
    n = len(train_aug_loader)
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for i, (images, _) in enumerate(tqdm(train_aug_loader)):
            
            img_t = images.float().cuda()
            out0 = net0(img_t).double()
            out0 = out0.permute(0, 2, 3, 1).reshape(-1, out0.shape[1])
            out1 = net1(img_t).double()
            out1 = out1.permute(0, 2, 3, 1).reshape(-1, out1.shape[1])

            # save batchwise first+second moments and outer product
            mean0_b = out0.mean(dim=0)
            mean1_b = out1.mean(dim=0)
            sqmean0_b = out0.square().mean(dim=0)
            sqmean1_b = out1.square().mean(dim=0)
            outer_b = (out0.T @ out1) / out0.shape[0]
            if i == 0:
                mean0 = torch.zeros_like(mean0_b)
                mean1 = torch.zeros_like(mean1_b)
                sqmean0 = torch.zeros_like(sqmean0_b)
                sqmean1 = torch.zeros_like(sqmean1_b)
                outer = torch.zeros_like(outer_b)
            mean0 += mean0_b / n
            mean1 += mean1_b / n
            sqmean0 += sqmean0_b / n
            sqmean1 += sqmean1_b / n
            outer += outer_b / n

    cov = outer - torch.outer(mean0, mean1)
    std0 = (sqmean0 - mean0**2).sqrt()
    std1 = (sqmean1 - mean1**2).sqrt()
    corr = cov / (torch.outer(std0, std1) + 1e-4)
    return corr

def get_layer_perm1(corr_mtx):
    corr_mtx_a = corr_mtx.cpu().numpy()
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a, maximize=True)
    assert (row_ind == np.arange(len(corr_mtx_a))).all()
    perm_map = torch.tensor(col_ind).long()
    perm_map = torch.eye(corr_mtx.shape[0], device=corr_mtx.device)[perm_map]
    return perm_map

def get_bipartite_perm(corr):
    idx = corr.argmax(0)
    matches = torch.eye(corr.shape[0], device=corr.device)[idx]
    totals = matches.sum(0, keepdim=True)
    matches = matches / (totals + 1)
    return matches.t(), totals

# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def get_layer_perm(net0, net1):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_layer_perm1(corr_mtx)

def get_layer_bipartite_transform(net0, net1):
    corr_mtx = run_corr_matrix(net0, net1)
    return get_bipartite_perm(corr_mtx)

# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, layer):
    pre_weights = [layer.weight,
                   layer.bias]
    for w in pre_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('ab,bcde->acde', perm_map, w)
        elif len(w.shape) == 2:
            transform = perm_map @ w
        else:
            transform = w @ perm_map.t()
#         assert torch.allclose(w[perm_map.argmax(-1)], transform)
        w.data = transform

# modifies the weight matrix of a layer for a given permutation of the input channels
# works for both conv2d and linear
def permute_input(perm_map, layer):
    w = layer.weight
    if len(w.shape) == 4:
        transform = torch.einsum('abcd,be->aecd', w, perm_map.t())
    elif len(w.shape) == 2:
        transform = w @ perm_map.t()
#     assert torch.allclose(w[:, perm_map.argmax(-1)], transform)
    w.data = transform

def subnet(model, n_layers):
    return model.features[:n_layers]



In [29]:
import pdb

In [129]:
def permute_model(source, target):
    source_feats = source.features
    target_feats = target.features
    n = len(source_feats)
    for i in range(n):
        if not isinstance(target_feats[i], nn.Conv2d): continue
        assert isinstance(target_feats[i+1], nn.ReLU)
        perm_map = get_layer_perm(
            subnet(source, i+2), subnet(target, i+2)
        )
        permute_output(perm_map, target_feats[i])
        
        next_layer = None
        for j in range(i+1, n):
            if isinstance(target_feats[j], nn.Conv2d):
                next_layer = target_feats[j]
                break
        if next_layer is None:
            next_layer = target.classifier
        permute_input(perm_map, next_layer)

def strip_param_suffix(name):
    return name.replace('.weight', '').replace('.bias', '')

def get_empty_module_dict(net):
    module2Dict = dict()
    module_list = []
    for key in net.state_dict().keys():
        base_name = strip_param_suffix(key)
        module2Dict[base_name] = dict()
        if base_name not in module_list:
            module_list += [base_name]
    return module2Dict, module_list

def apply_bipartite_transform(source, target):
    source_feats = source.features
    target_feats = target.features
    module2Dict, module_list = get_empty_module_dict(target)
    k = 0
    n = len(source_feats)
    for i in range(n):
        if not isinstance(target_feats[i], nn.Conv2d): continue
        assert isinstance(target_feats[i+1], nn.ReLU)
        bipartite_map, layer_totals = get_layer_bipartite_transform(
            subnet(source, i+2), subnet(target, i+2)
        )
        permute_output(bipartite_map, target_feats[i])
        module2Dict[module_list[k]]['output'] = layer_totals
#         pdb.set_trace()
        next_layer = None
        for j in range(i+1, n):
            if isinstance(target_feats[j], nn.Conv2d):
                next_layer = target_feats[j]
                break
        if next_layer is None:
            next_layer = target.classifier
        permute_input(bipartite_map, next_layer)
        module2Dict[module_list[k+1]]['input'] = layer_totals
        k += 1
    return module2Dict

def combine_io_masks(io, param):
    mask = torch.zeros_like(param, device=param.device)
    try:
        if 'output' in io:
            mask[io['output'].view(-1) == 0] = 1.
        if 'input' in io and len(mask.shape) > 1:
            mask[:, io['input'].view(-1) == 0] = 1.
    except:
        pdb.set_trace()
    return mask
    

In [144]:
model0 = vgg11()
model1 = vgg11()
load_model(model0, 'vgg11_v2')
load_model(model1, 'vgg11_v1')

evaluate(model0), evaluate(model1)

(0.9001, 0.8964)

In [145]:
permute_model(model0, model1)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.82it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.61it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.08it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.69it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.85it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.77it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.56it/s]
100%|█████████████████████████████████████████████████████████

In [146]:
print(evaluate(model1))
save_model(model1, 'vgg11_v2_perm2_copy')

0.8964


# Merge the two networks

In [147]:
def mix_weights(net, alpha, key0, key1, module2io=None):
    sd0 = torch.load('%s.pt' % key0)
    sd1 = torch.load('%s.pt' % key1)
    sd_alpha = {}
    for k in sd0.keys():
        param0 = sd0[k].cuda()
        param1 = sd1[k].cuda()
        sd_alpha[k] = (1 - alpha) * param0 + alpha * param1
        if module2io is not None:
            param_base = strip_param_suffix(k)
            mask = combine_io_masks(module2io[param_base], param1)
            sd_alpha[k][mask == 1] = param0[mask == 1]
#     sd_alpha = {k: (1 - alpha) * sd0[k].cuda() + alpha * sd1[k].cuda()
#                 for k in sd0.keys()}
    net.load_state_dict(sd_alpha)

In [None]:
k0 = 'vgg11_v2_perm2_copy'
k1 = 'vgg11_v2'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
mix_weights(model_a, alpha, k0, k1)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

In [None]:
def mix_weights_over_alphas(num_times, model_a, k0, k1, module2io=None):
    step = 1. / num_times
    alphas = np.arange(0., 1. + step, step)
    accuracies = []
    for alpha in tqdm(alphas):
        mix_weights(model_a, alpha, k0, k1, module2io=module2io)
        accuracies.append(evaluate(model_a))
    return alphas, accuracies
    

In [None]:
alphas, permute_accuracies = mix_weights_over_alphas(10, model_a, k0, k1)

In [None]:
permute_accuracies

# Apply Bipartite Transform

In [152]:
model0 = vgg11()
model1 = vgg11()
load_model(model0, 'vgg11_v2')
load_model(model1, 'vgg11_v1')

evaluate(model0), evaluate(model1)

(0.9001, 0.8964)

In [153]:
module2io = apply_bipartite_transform(model0, model1)

100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.83it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.61it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.00it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:04<00:00, 20.19it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.73it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:05<00:00, 19.70it/s]
100%|█████████████████████████████████████████████████████████

In [154]:
for key, io in module2io.items():
        input_str = io['input'].shape if 'input' in io else 'None'
        output_str = io['output'].shape if 'output' in io else 'None'
        print(f'{key} | input: {input_str}, output: {output_str}')

features.0 | input: None, output: torch.Size([1, 64])
features.3 | input: torch.Size([1, 64]), output: torch.Size([1, 128])
features.6 | input: torch.Size([1, 128]), output: torch.Size([1, 256])
features.8 | input: torch.Size([1, 256]), output: torch.Size([1, 256])
features.11 | input: torch.Size([1, 256]), output: torch.Size([1, 512])
features.13 | input: torch.Size([1, 512]), output: torch.Size([1, 512])
features.16 | input: torch.Size([1, 512]), output: torch.Size([1, 512])
features.18 | input: torch.Size([1, 512]), output: torch.Size([1, 512])
classifier | input: torch.Size([1, 512]), output: None


In [123]:
print(evaluate(model1))
save_model(model1, 'vgg11_v2_bipartite2')

0.1


In [155]:
k0 = 'vgg11_v2'
k1 = 'vgg11_v2_bipartite2'
model0 = vgg11()
model1 = vgg11()
model_a = vgg11()
mix_weights(model0, 0.0, k0, k1)
mix_weights(model1, 1.0, k0, k1)

alpha = 0.5
mix_weights(model_a, alpha, k0, k1, module2io=module2io)
print('(α=0): %.1f%% \t\t<-- Model A' % (100*evaluate(model0)))
print('(α=1): %.1f%% \t\t<-- Model B' % (100*evaluate(model1)))
print('(α=0.5): %.1f%% \t\t<-- Merged model' % (100*evaluate(model_a)))

FileNotFoundError: [Errno 2] No such file or directory: 'vgg11_v2_bipartite2.pt'

In [138]:
alphas, bipartite_accuracies = mix_weights_over_alphas(10, model_a, k0, k1, module2io)

100%|████████████████████████████████████████████████████████████████████████████████████████████████| 11/11 [00:11<00:00,  1.06s/it]


In [139]:
import matplotlib.pyplot as plt

Matplotlib is building the font cache; this may take a moment.


In [None]:
plt.plot(permute_accuracies, label='permutation')
plt.plot(bipartite_accuracies, label='bipartite')
ax = plt.gca()
ax.set_ylim([.7, 1.])

In [142]:
bipartite_accuracies

[0.8964,
 0.8964,
 0.8874,
 0.8721,
 0.8433,
 0.7943,
 0.7204,
 0.6208,
 0.5002,
 0.387,
 0.3103]