In [1]:
import os
import sys

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pdb

In [3]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)

In [4]:
cifar100_info = {
    'dir': '/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python',
    'classes1': np.arange(50),
    'classes2': np.arange(50, 100),
    'num_classes': 100,
    'split_classes': 50,
    'wrapper': torchvision.datasets.CIFAR100
}

cifar10_info = {
    'dir': '/tmp',
    'classes1': np.array([3, 2, 0, 6, 4]),
    'classes2': np.array([5, 7, 9, 8, 1]),
    'num_classes': 10,
    'split_classes': 5,
    'wrapper': torchvision.datasets.CIFAR10
}

ds_info = cifar100_info

In [5]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = ds_info['wrapper'](root=ds_info['dir'], train=True,
                                        download=True, transform=train_transform)
test_dset = ds_info['wrapper'](root=ds_info['dir'], train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [6]:
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

model1_classes= ds_info['classes1']#np.array([3, 2, 0, 6, 4])
model2_classes = ds_info['classes2']

valid_examples1 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model1_classes]
valid_examples2 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model2_classes]

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

50000it [00:13, 3572.23it/s]
50000it [00:14, 3569.48it/s]
10000it [00:01, 5379.81it/s]
10000it [00:01, 5362.74it/s]


In [7]:
class_idxs = np.zeros(ds_info['num_classes'], dtype=int)
class_idxs[model1_classes] = np.arange(ds_info['split_classes'])
class_idxs[model2_classes] = np.arange(ds_info['split_classes'])
class_idxs = torch.from_numpy(class_idxs)
print(class_idxs)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47, 48, 49])


In [8]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.to(DEVICE))

In [9]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    
    totals = [0] * class_vectors.shape[0]
    corrects = [0] * class_vectors.shape[0]
    
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                for gt, p in zip(labels, pred):
                    totals[gt] += 1
                    
                    if gt == p:
                        correct += 1
                        corrects[gt] += 1
                
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total

In [10]:
import clip

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)
model, preprocess = clip.load('ViT-B/32', DEVICE)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)


text_features /= text_features.norm(dim=-1, keepdim=True)
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]

In [11]:
from resnets import resnet20


In [12]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix_og(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2])#.double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2])#.double()

                mean0_b = out0.mean(dim=0)
                mean1_b = out1.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                outer_b = (out0.T @ out1) / out0.shape[0]

                if i == 0:
                    mean0 = torch.zeros_like(mean0_b)
                    mean1 = torch.zeros_like(mean1_b)
                    std0 = torch.zeros_like(std0_b)
                    std1 = torch.zeros_like(std1_b)
                    outer = torch.zeros_like(outer_b)
                mean0 += mean0_b / n
                mean1 += mean1_b / n
                std0 += std0_b / n
                std1 += std1_b / n
                outer += outer_b / n

    cov = outer - torch.outer(mean0, mean1)
    if norm:
        cov = cov / (torch.outer(std0, std1) + 1e-4)
    return cov

In [13]:
# given two networks net0, net1 which each output a feature map of shape NxCxWxH
# this will reshape both outputs to (N*W*H)xC
# and then compute a CxC correlation matrix between the outputs of the two networks
def run_corr_matrix(net0, net1, epochs=1, norm=True, loader=train_aug_loader):
    n = epochs*len(loader)
    mean0 = mean1 = std0 = std1 = None
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 = out0.reshape(-1, out0.shape[2]).double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 = out1.reshape(-1, out1.shape[2]).double()

                out = torch.cat((out0, out1), dim=-1)
                mean_b = out.mean(dim=0)
                std0_b = out0.std(dim=0)
                std1_b = out1.std(dim=0)
                std_b = torch.cat((std0_b, std1_b), dim=-1)
                outer_b = (out.T @ out) / out.shape[0]
                
                if i == 0:
                    mean = torch.zeros_like(mean_b)
                    std = torch.zeros_like(std_b)
                    outer = torch.zeros_like(outer_b)
                
                mean += mean_b / n
                std += std_b / n
                outer += outer_b / n
 
    cov = outer - torch.outer(mean, mean)
    if norm:
        cov = cov / (torch.outer(std, std) + 1e-4)
    torch.diagonal(cov)[:] = -torch.inf
    return cov

In [14]:
def run_cos_sim_matrix(net0, net1, epochs=1, loader=train_aug_loader):
    n = epochs*len(loader)
    with torch.no_grad():
        net0.eval()
        net1.eval()
        for _ in range(epochs):
            for i, (images, _) in enumerate(tqdm(loader)):
                img_t = images.float().to(DEVICE)
                out0 = net0(img_t)
                out0 = out0.reshape(out0.shape[0], out0.shape[1], -1).permute(0, 2, 1)
                out0 /= out0.norm(1, keepdim=True)
                out0 = out0.reshape(-1, out0.shape[2]).double()

                out1 = net1(img_t)
                out1 = out1.reshape(out1.shape[0], out1.shape[1], -1).permute(0, 2, 1)
                out1 /= out1.norm(1, keepdim=True)
                out1 = out1.reshape(-1, out1.shape[2]).double()
                
                out = torch.cat((out0, out1), dim=-1)
                
                outer_b = (out.T @ out)# / out.shape[0]
                
                if i == 0:
                    outer = torch.zeros_like(outer_b)
                
                outer += outer_b
    cov = outer
    torch.diagonal(cov)[:] = -torch.inf
    return cov
    

In [15]:
def match_tensors_exact_bipartite_cov(
    covariance,
    r=.5
):
    O = covariance.shape[0]
    O_half = O // 2
    remainder = int(O * (1-r))
    bound = O - remainder
    sims = covariance
#     sims[:O_half, :O_half] = -torch.inf
#     sims[O_half:, O_half:] = -torch.inf
#     sims[:O_half, O_half:] = -torch.inf
#     sims[O_half:, :O_half] = -torch.inf

#     plot_sims(sims, filename="conv1_covar_sorted.png")
    
    permutation_matrix = torch.zeros((O, O - bound), device=sims.device)

    for i in range(bound):
        best_idx = sims.view(-1).argmax()
        row_idx = best_idx % sims.shape[1]
        col_idx = best_idx // sims.shape[1]
        permutation_matrix[row_idx, i] = 1
        permutation_matrix[col_idx, i] = 1
        sims[row_idx] = -torch.inf
        sims[col_idx] = -torch.inf
        sims[:, row_idx] = -torch.inf
        sims[:, col_idx] = -torch.inf
    
    unused = (sims.max(-1)[0] > -torch.inf).to(torch.int).nonzero().view(-1)
    for i in range(bound, O-bound):
        permutation_matrix[unused[i-bound], i] = 1
    merge = permutation_matrix / (permutation_matrix.sum(dim=0, keepdim=True) + 1e-5)
    unmerge = permutation_matrix
    return merge.T, unmerge

def match_tensors_permute(
    covariance
):
#     torch.diagonal(covariance)[:] = -torch.inf
    corr_mtx_a = covariance.cpu().numpy()
    O = corr_mtx_a.shape[0]# // 2
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(corr_mtx_a[:O, :O], maximize=True)
#     assert (row_ind == np.arange(len(corr_mtx_a))).all()
    unmerge = torch.tensor(col_ind).long()
    unmerge = torch.eye(O, device=covariance.device)[unmerge]
    unmerge = torch.cat(
        (
            torch.eye(O, device=covariance.device),
            unmerge.T
        ), 
        dim=0
    )
    merge = unmerge / (unmerge.sum(dim=0, keepdim=True) + 1e-5)
    return merge.T, unmerge

In [16]:
def remove_col(x, idx):
    return torch.cat([x[:, :idx], x[:, idx+1:]], dim=-1)

def match_tensors_chain_cov(
    covariance,
    r=.5
):
    O = covariance.shape[0]
    O_half = O // 2
    remainder = int(O * (1-r))
    bound = O - remainder
    sims = covariance
#     sims[:O_half, :O_half] = -torch.inf
#     sims[O_half:, O_half:] = -torch.inf
#     sims[:O_half, O_half:] = -torch.inf
#     sims[O_half:, :O_half] = -torch.inf
    
    # result after alg should be (O, O-bound)
    permutation_matrix = torch.eye(O, O, device=sims.device)

    for i in range(bound):
        best_idx = sims.reshape(-1).argmax()
        row_idx = best_idx % sims.shape[1]
        col_idx = best_idx // sims.shape[1]
        
        permutation_matrix[:, row_idx] += permutation_matrix[:, col_idx]
        permutation_matrix = remove_col(permutation_matrix, col_idx)
        
        sims[:, row_idx] = (sims[:, row_idx] + sims[:, col_idx]) / 2
        sims = remove_col(sims, col_idx)
        
        sims[row_idx, :] = (sims[row_idx, :] + sims[col_idx, :]) / 2
        sims = remove_col(sims.T, col_idx).T
    
#     unused = (sims.max(-1)[0] > -torch.inf).to(torch.int).nonzero().view(-1)
#     for i in range(bound, O-bound):
#         permutation_matrix[unused[i-bound], i] = 1
    merge = permutation_matrix / (permutation_matrix.sum(dim=0, keepdim=True) + 1e-5)
    unmerge = permutation_matrix
    return merge.T, unmerge


In [17]:
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(merge, convs, bns):
    pre_weights = [
        (convs[0].weight, convs[1].weight, convs[2].weight),
        (bns[0].weight, bns[1].weight, bns[2].weight),
        (bns[0].bias, bns[1].bias, bns[2].bias),
        (bns[0].running_mean, bns[1].running_mean, bns[2].running_mean),
        (bns[0].running_var, bns[1].running_var, bns[2].running_var)
    ]
    for a, b, c in pre_weights:
        a_w, b_w = merge.chunk(2, dim=1)

        if len(a.shape) == 4:
            t_a = torch.einsum('ab,bcde->acde', a_w, a.to(DEVICE) * 2)
            t_b = torch.einsum('ab,bcde->acde', b_w, b.to(DEVICE) * 2)
        else:
            t_a = a_w @ a.to(DEVICE) * 2
            t_b = b_w @ b.to(DEVICE) * 2
        
        a.data = t_a
        b.data = t_b


# modifies the weight matrix of a convolution layer for a given
# permutation of the input channels
def permute_input(unmerge, after_convs):
    if not isinstance(after_convs, list):
        after_convs = [after_convs]
    post_weights = [(c[0].weight, c[1].weight, c[2].weight) for c in after_convs]
    for (a, b, c) in post_weights:
        a_w, b_w = unmerge.chunk(2, dim=0)

        if len(a.shape) == 4:
            a.data = torch.einsum('abcd,be->aecd', a.to(DEVICE), a_w)
            b.data = torch.einsum('abcd,be->aecd', b.to(DEVICE), b_w)
        else:
            a.data = (a.to(DEVICE) @ a_w)
            b.data = (b.to(DEVICE) @ b_w)

In [18]:
# returns the channel-permutation to make layer1's activations most closely
# match layer0's.
def layer_perm_baseline(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    corr_mtx = run_corr_matrix_og(net0, net1)
    return match_tensors_permute(corr_mtx)

def layer_perm_ours(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    corr_mtx = run_corr_matrix(net0, net1)
    return match_tensors_exact_bipartite_cov(corr_mtx)

def layer_cosine_perm_ours(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    corr_mtx = run_cos_sim_matrix(net0, net1)
    return match_tensors_exact_bipartite_cov(corr_mtx)

def layer_perm_chain_ours(net0, net1, method='max_weight', vizz=False, prune_threshold=-torch.inf):
    corr_mtx = run_corr_matrix(net0, net1)
    return match_tensors_chain_cov(corr_mtx)

In [19]:
def mix_weights(sd0, sd1, target, alpha):
    sd_alpha = {}
    for k in sd0.keys():
        param0 = sd0[k].to(DEVICE)
        param1 = sd1[k].to(DEVICE)
        sd_alpha[k] = (1 - alpha) * param0 + alpha * param1
        # TODO: alpha is ignored
#         sd_alpha[k] = param0 + param1
#     sd_alpha = {k: (1 - alpha) * sd0[k].to('mps') + alpha * sd1[k].to('mps')
#                 for k in sd0.keys()}
    target.load_state_dict(sd_alpha)

# Find Bipartite Permutation

In [65]:
def check_model_equality(source, candidate, check=False, return_result=False):
    source_sd, candidate_sd = source.state_dict(), candidate.state_dict()
    match_dict = {}
    for key, source_val in source_sd.items():
        candidate_val = candidate_sd[key]
        is_match = torch.allclose(source_val.cuda(), candidate_val.cuda())
        match_dict[key] = is_match
        
    all_matched = True
    for key, is_match in match_dict.items():
        if is_match == check:
            print(f'{key}: {is_match}')
            all_matched = False
    if all_matched:
        print('All Matched')
    if return_result:
        return match_dict

In [66]:
model0 = resnet20(w=4, text_head=True).to(DEVICE)
model1 = resnet20(w=4, text_head=True).to(DEVICE)
model0_c = resnet20(w=4, text_head=True).to(DEVICE)
model1_c = resnet20(w=4, text_head=True).to(DEVICE)
load_model(model0, f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}')
load_model(model1, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')
load_model(model0_c, f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}')
load_model(model1_c, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')

print(evaluate_texthead(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.778
0.7758


In [67]:
modelc = resnet20(w=4, text_head=True).to(DEVICE)

In [68]:
import torch.nn.functional as F

In [69]:
get_layer_perm = layer_perm_ours

In [70]:
# check_model_equality(model0, model0_c, check=False)

In [71]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        return x

models = [model0_c, model1_c, modelc]
merge, unmerge = get_layer_perm(Subnet(model0), Subnet(model1))

permute_output(
    merge, 
    [m.conv1 for m in models], 
    [m.bn1 for m in models]
)

permute_output(
    merge, 
    [m.layer1[0].conv2 for m in models], 
    [m.layer1[0].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer1[1].conv2 for m in models], 
    [m.layer1[1].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer1[2].conv2 for m in models], 
    [m.layer1[2].bn2 for m in models]
)
permute_input(
    unmerge, 
    [
        [m.layer1[0].conv1 for m in models], 
        [m.layer1[1].conv1 for m in models], 
        [m.layer1[2].conv1 for m in models]
    ]
)
permute_input(
    unmerge, 
    [
        [m.layer2[0].conv1 for m in models], 
        [m.layer2[0].shortcut[0] for m in models]
    ]
)

# check_model_equality(model0, model0_c)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:09<00:00, 10.61it/s]


In [72]:
# for cmap_id in plt.colormaps():
def plot_sims(sims, cmap_id="twilight", filename=None):
    plt.figure()
    plt.set_cmap(cmap_id)

    sims_sort = sims.sort(-1, descending=True)[0]
#     new_sort = F.relu(sims_sort[:, :-1]).norm(dim=-1).argsort(descending=True)
#     new_sort = torch.special.entr(sims_sort[:, :-1].softmax(dim=-1)).sum(dim=-1).argsort(descending=False)
    new_sort = sims_sort[:, 0].argsort(descending=True)
    
    sims_sort = sims_sort[new_sort,:]

    plt.imshow(sims_sort.cpu().numpy(), vmin=-1, vmax=+1)
    plt.colorbar()
#     plt.title(cmap_id)
    
    plt.axis('off')
    plt.tight_layout()
    
    if filename is None:
        plt.show()
    else:
        plt.savefig(filename, dpi=1200, bbox_inches="tight")

In [73]:
def find_merge_sets(merge):
    A, B = merge.chunk(2, dim=-1)
    A_merges = (A.sum(1)*2).abs().ceil().to(torch.int).bincount()[1:]
    B_merges = (B.sum(1)*2).abs().ceil().to(torch.int).bincount()[1:]
    merges = {}
    if len(A_merges) == 2:
        merges['A->A'] = A_merges[1].cpu().numpy()
    if len(B_merges) == 2:
        merges['B->B'] = B_merges[1].cpu().numpy()
    try:
        merges['A^B'] = A_merges[0].cpu().numpy()
    except:
        merges['A^B'] = B_merges[0].cpu().numpy()
    
    print(merges)

In [None]:
find_merge_sets(merge)

In [None]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        return x

merge, unmerge = get_layer_perm(Subnet(model0), Subnet(model1))

# combine_convs_and_fcs(
#     elements=[
#         [m.layer1[0].conv1 for m in models], 
#         [m.layer1[1].conv1 for m in models], 
#         [m.layer1[2].conv1 for m in models]
#     ]
# )

permute_output(
    merge, 
    [m.layer2[0].conv2 for m in models], 
    [m.layer2[0].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer2[0].shortcut[0] for m in models], 
    [m.layer2[0].shortcut[1] for m in models]
)
permute_output(
    merge, 
    [m.layer2[1].conv2 for m in models], 
    [m.layer2[1].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer2[2].conv2 for m in models], 
    [m.layer2[2].bn2 for m in models]
)

permute_input(
    unmerge, 
    [
        [m.layer2[1].conv1 for m in models], 
        [m.layer2[2].conv1 for m in models]
    ]
)
permute_input(
    unmerge, 
    [
        [m.layer3[0].conv1 for m in models], 
        [m.layer3[0].shortcut[0] for m in models]
    ]
)

In [None]:
find_merge_sets(merge)

In [None]:
class Subnet(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model
    def forward(self, x):
        self = self.model
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        return x

merge, unmerge = get_layer_perm(Subnet(model0), Subnet(model1))
permute_output(
    merge, 
    [m.layer3[0].conv2 for m in models], 
    [m.layer3[0].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer3[0].shortcut[0] for m in models], 
    [m.layer3[0].shortcut[1] for m in models]
)
permute_output(
    merge, 
    [m.layer3[1].conv2 for m in models], 
    [m.layer3[1].bn2 for m in models]
)
permute_output(
    merge, 
    [m.layer3[2].conv2 for m in models], 
    [m.layer3[2].bn2 for m in models]
)

permute_input(
    unmerge, 
    [
        [m.layer3[1].conv1 for m in models], 
        [m.layer3[2].conv1 for m in models]
    ]
)

a_w, b_w = unmerge.chunk(2, dim=0)
model0_c.linear.weight.data = (model0_c.linear.weight.to(DEVICE) @ a_w)
model1_c.linear.weight.data = (model1_c.linear.weight.to(DEVICE) @ b_w)

# check_model_equality(model0, model0_c)
# modelc.linear.weight.data = (torch.cat(
#     (
#         model0_c.linear.weight.to(DEVICE),
#         model1_c.linear.weight.to(DEVICE)
#     ), dim=-1
# ) @ unmerge / 2.).cpu()


# modelc.linear.weight.data = (torch.cat(
#     (
#         model0.linear.weight,
#         model1.linear.weight
#     ), dim=-1
# ) @ unmerge / 2.)

In [None]:
find_merge_sets(merge)

In [None]:
class Subnet(nn.Module):
    def __init__(self, model, nb=9):
        super().__init__()
        self.model = model
        self.blocks = []
        self.blocks += list(model.layer1)
        self.blocks += list(model.layer2)
        self.blocks += list(model.layer3)
        self.blocks = nn.Sequential(*self.blocks)
        self.bn1 = model.bn1
        self.conv1 = model.conv1
        self.linear = model.linear
        self.nb = nb
        
    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.blocks[:self.nb](x)
        block = self.blocks[self.nb]
        x = block.conv1(x)
        x = block.bn1(x)
        x = F.relu(x)
        return x

# model0_c = model0
# model1_c = model1
blocks0 = []
blocks0 += list(model0_c.layer1)
blocks0 += list(model0_c.layer2)
blocks0 += list(model0_c.layer3)
blocks0 = nn.Sequential(*blocks0)

blocks1 = []
blocks1 += list(model1_c.layer1)
blocks1 += list(model1_c.layer2)
blocks1 += list(model1_c.layer3)
blocks1 = nn.Sequential(*blocks1)

blocksc = []
blocksc += list(modelc.layer1)
blocksc += list(modelc.layer2)
blocksc += list(modelc.layer3)
blocksc = nn.Sequential(*blocksc)

In [None]:
block_idx2name = {
    0: 'layer1.0',
    1: 'layer1.1',
    2: 'layer1.2',
    3: 'layer2.0',
    4: 'layer2.1',
    5: 'layer2.2',
    6: 'layer3.0',
    7: 'layer3.1',
    8: 'layer3.2'
}

In [None]:
for nb, (block_idx, layer_name) in zip(range(9), block_idx2name.items()):
    merge, unmerge = get_layer_perm(Subnet(model0, nb=nb), Subnet(model1, nb=nb))
    block0 = blocks0[nb]
    block1 = blocks1[nb]
    blockc = blocksc[nb]
    blocks = [block0, block1, blockc]
    permute_output(
        merge, 
        [m.conv1 for m in blocks], 
        [m.bn1 for m in blocks]
    )
    permute_input(
        unmerge, 
        [
            [m.conv2 for m in blocks]
        ]
    )
    find_merge_sets(merge)
# check_model_equality(model0, model0_c)

In [None]:
print(evaluate_texthead(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

In [None]:
print(evaluate_texthead(model0_c, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model1_c, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

In [None]:
mix_weights(sd0=model0_c.state_dict(), sd1=model1_c.state_dict(), target=modelc, alpha=.5)

In [None]:
modelc = modelc.to(DEVICE)

In [None]:
reset_bn_stats(modelc, loader=train_aug_loader)
acc, perclass_acc = evaluate_texthead(
    modelc, test_loader, class_vectors=text_features, return_confusion=True
)
acc

In [None]:
print(perclass_acc)