In [6]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'
if DEVICE == 'mps':
    DOWNLOAD_PATH = '/Users/georgestoica/Downloads' 
else:
    DOWNLOAD_PATH = '/srv/share/gstoica3/checkpoints/REPAIR/'
    
torch.autograd.set_grad_enabled(False)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x7f3c8b7e3290>

In [7]:
from copy import deepcopy

In [8]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)


In [9]:
cifar100_info = {
    'dir': '/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python',
    'classes1': np.arange(50),
    'classes2': np.arange(50, 100),
    'num_classes': 100,
    'split_classes': 50
}

cifar10_info = {
    'dir': '/tmp',
    'classes1': np.array([3, 2, 0, 6, 4]),
    'classes2': np.array([5, 7, 9, 8, 1]),
    'num_classes': 10,
    'split_classes': 5
}

In [10]:
ds_info = cifar10_info

In [11]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(
    root=ds_info['dir'], 
    train=True,
    download=True, transform=train_transform
)
test_dset = torchvision.datasets.CIFAR10(
    root=ds_info['dir'],
    train=False,
    download=True, 
    transform=test_transform
)

Files already downloaded and verified
Files already downloaded and verified


In [12]:
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

model1_classes= ds_info['classes1']
model2_classes = ds_info['classes2']

valid_examples1 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model1_classes]
valid_examples2 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model2_classes]

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

50000it [00:13, 3792.19it/s]
50000it [00:13, 3798.84it/s]
10000it [00:01, 5619.34it/s]
10000it [00:01, 5639.44it/s]


In [13]:
class_idxs = np.zeros(ds_info['num_classes'], dtype=int)
class_idxs[model1_classes] = np.arange(ds_info['classes1'].shape[0])
class_idxs[model2_classes] = np.arange(ds_info['classes2'].shape[0])
class_idxs = torch.from_numpy(class_idxs)
class_idxs

tensor([2, 4, 1, 0, 4, 0, 3, 1, 3, 2])

In [14]:
print(test_dset.classes)

['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [15]:
import clip

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)
model, preprocess = clip.load('ViT-B/32', DEVICE)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)


text_features /= text_features.norm(dim=-1, keepdim=True)
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]

In [16]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    confusion = np.zeros((100, 100))
    
    totals = [0] * class_vectors.shape[0]
    corrects = [0] * class_vectors.shape[0]
    
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                for gt, p in zip(labels, pred):
                    totals[gt] += 1
                    
                    if gt == p:
                        correct += 1
                        corrects[gt] += 1
                
#                 correct += (labels.to(DEVICE) == pred).sum().item()
                
            confusion[labels.cpu().numpy(), pred.cpu().numpy()] += 1
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total

In [12]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
#             self.shortcut = LambdaLayer(lambda x:
#                                         F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, w=1, num_classes=10, text_head=False):
        super(ResNet, self).__init__()
        self.in_planes = int(w*16)

        self.conv1 = nn.Conv2d(3, int(w*16), kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(int(w*16))
        self.layer1 = self._make_layer(block, int(w*16), num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, int(w*32), num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, int(w*64), num_blocks[2], stride=2)
        if text_head:
            num_classes = 512
        self.linear = nn.Linear(int(w*64), num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(w=1, text_head=False):
    return ResNet(BasicBlock, [3, 3, 3], w=w, text_head=text_head)

In [13]:
def bipartite_soft_matching(
    metric: torch.Tensor,
    r: float,
    class_token: bool = False,
    distill_token: bool = False,
):
    """
    Applies ToMe with a balanced matching set (50%, 50%).
    Input size is [batch, tokens, channels].
    r indicates the ratio of tokens to remove (max 50% of tokens).
    Extra args:
     - class_token: Whether or not there's a class token.
     - distill_token: Whether or not there's also a distillation token.
    When enabled, the class token and distillation tokens won't get merged.
    """
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = int(r * t)
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric.chunk(2, dim=-2)
        scores = a @ b.transpose(-1, -2)

        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x.chunk(2, dim=-2)
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., dst.shape[-2]:, :] = dst
        out.scatter_(dim=-2, index=(unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(src_idx).expand(n, r, c), src=src)

        return out
    
    return merge, unmerge

def general_soft_matching(
    hull_tensor,
    interleave=False,
    random=False,
    r=.5
):  
    hull_tensor = hull_tensor[0]
    hull_normed = hull_tensor / hull_tensor.norm(dim=-1, keepdim=True)
    
    bound = int(hull_tensor.shape[0] * (1-r))
    
    sims = hull_normed @ hull_normed.transpose(-1, -2)
    uppertri_indices = torch.triu_indices(sims.shape[-2], sims.shape[-1], offset=0)
    sims[uppertri_indices[0], uppertri_indices[1]] = -torch.inf
    candidate_scores, candidate_indices = sims.max(-1)
    argsorted_scores = candidate_scores.argsort(descending=True)
    merge_indices = argsorted_scores[:bound]
    unmerge_indices = argsorted_scores[bound:]
    
    roots = torch.arange(sims.shape[0], device=sims.device)
    for _ in range(bound-1):
        roots[merge_indices] = roots[candidate_indices[merge_indices]]
    
    def merge(x, mode='mean'):
        x = x[0]
        merge_tensor = x.scatter_reduce(
            0, 
            roots[merge_indices][:, None].expand(bound, x.shape[1]),
            x[merge_indices], 
            reduce='mean'
        )
        unmerge_tensor = merge_tensor[unmerge_indices]
        return unmerge_tensor[None]
    
    def unmerge(x):
        x = x[0]
        out = torch.zeros((hull_tensor.shape[0], x.shape[1]), device=x.device)
        out.scatter_(
            0,
            index=unmerge_indices[:, None].expand(*x.shape),
            src=x
        )
        out = out.scatter(
            0,
            index=merge_indices[:, None].expand(*x.shape),
            src=out[roots[merge_indices]]
        )
        return out[None]
    
    return merge, unmerge

def unravel_index(index, shape):
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))

In [14]:
def concat_mats(args, dim=0):
    return torch.cat(args, dim=dim)

def unconcat_mat(tensor, dim=0):
    return torch.chunk(tensor, chunks=2, dim=dim)

def match_tensors_permute(hull_tensor, eps=1e-7, interleave=False, random_perm=False):
    """
    hull_tensor: [2O,I]
    """
    O, I = hull_tensor.shape
    O //= 2
    
    interleave_mat = torch.eye(2*O, device=hull_tensor.device)
    if interleave:
        A1, A2, B1, B2 = interleave_mat.chunk(4, dim=0)
        interleave_mat = torch.cat([A1, B1, A2, B2], dim=0)
        interleave_mat = interleave_mat.view(2, O, 2*O).transpose(0, 1).reshape(2*O, 2*O)
#         interleave_mat = interleave_mat[torch.randperm(2*O, device=hull_tensor.device)]
    
    hull_tensor = interleave_mat @ hull_tensor
    
    hull_tensor = hull_tensor / (hull_tensor.norm(dim=-1, keepdim=True) + eps)
    A, B = unconcat_mat(hull_tensor, dim=0)
    scores = -(A @ B.T)
    
    O_eye = torch.eye(O, device=hull_tensor.device)
    
    try:
        row_idx, col_idx = scipy.optimize.linear_sum_assignment(scores.cpu().numpy())
    except ValueError:
        pdb.set_trace()
    
    A_perm = O_eye[torch.from_numpy(row_idx)]#[perm]
    B_perm = O_eye[torch.from_numpy(col_idx)]#[perm]
    
    if random_perm:
        perm = torch.randperm(O, device=A.device)
        A_perm = A_perm[perm]
        B_perm = B_perm[perm]
    
    merge = (torch.cat((A_perm, B_perm), dim=1) / 2.) @ interleave_mat
    unmerge = interleave_mat.T @ (torch.cat((A_perm.T, B_perm.T), dim=0))
    return merge, unmerge


def match_tensors_tome(hull_tensor, eps=1e-7, interleave=False, random_perm=False):
    """
    hull_tensor: [2O,I]
    """
    O, I = hull_tensor.shape
    O //= 2
    
    big_eye = torch.eye(2*O, device=hull_tensor.device)
    small_eye = torch.eye(O, device=hull_tensor.device)
    
    interleave_mat = big_eye
    if interleave:
        A1, A2, B1, B2 = interleave_mat.chunk(4, dim=0)
        interleave_mat = torch.cat([A1, B1, A2, B2], dim=0)
    
    
    hull_tensor = interleave_mat @ hull_tensor
    
    merge, unmerge = bipartite_soft_matching(hull_tensor[None], 0.5)
    
    merge_mat = merge(big_eye[None])[0] @ interleave_mat
    unmerge_mat = interleave_mat.T @ unmerge(small_eye[None])[0]
    return merge_mat, unmerge_mat

##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################

def get_procrustes(corr_mtx):
    U, _, Vh = torch.linalg.svd(corr_mtx)
    S = torch.eye(U.shape[0], device=U.device)
    S[-1, -1] = -1.
    return (U @ S) @ Vh

def match_tensors_procrustes(hull_tensor, eps=1e-7, interleave=False, random_perm=False):
    """
    hull_tensor: [2O,I]
    """
    O, I = hull_tensor.shape
    O //= 2
    
    interleave_mat = torch.eye(2*O, device=hull_tensor.device)
    if interleave:
        A1, A2, B1, B2 = interleave_mat.chunk(4, dim=0)
        interleave_mat = torch.cat([A1, B1, A2, B2], dim=0)
        interleave_mat = interleave_mat.view(2, O, 2*O).transpose(0, 1).reshape(2*O, 2*O)
#         interleave_mat = interleave_mat[torch.randperm(2*O, device=hull_tensor.device)]
    
    hull_tensor = interleave_mat @ hull_tensor
    
    hull_tensor = hull_tensor / (hull_tensor.norm(dim=-1, keepdim=True) + eps)
    A, B = unconcat_mat(hull_tensor, dim=0)
    scores = (A @ B.T)
    
    P = get_procrustes(scores).T
    
    O_eye = torch.eye(O, device=hull_tensor.device)
    
    try:
        row_idx, col_idx = scipy.optimize.linear_sum_assignment(scores.cpu().numpy())
    except ValueError:
        pdb.set_trace()
    
    A_perm = O_eye[torch.from_numpy(row_idx)]#[perm]
    B_perm = P #O_eye[torch.from_numpy(col_idx)]#[perm]
    
    if random_perm:
        perm = torch.randperm(O, device=A.device)
        A_perm = A_perm[perm]
        B_perm = B_perm[perm]
    
    merge = (torch.cat((A_perm, B_perm), dim=1) / 2.) @ interleave_mat
    unmerge = interleave_mat.T @ (torch.cat((A_perm.T, B_perm.T), dim=0))
    return merge, unmerge



# def match_tensors_procrustes(hull_tensor, use_S=True, interleave=False, random_perm=False):
#     # We can only reduce by a maximum of 50% tokens
#     t = hull_tensor.shape[0]
#     r = int(.f5 * t)
#     with torch.no_grad():
#         A, B = unconcat_mat(hull_tensor, dim=0)
#         scores = -(A @ B.T)
#         U, S, V = torch.svd(scores)
        
# #         U[:, :rank] /= S[None, :]
#         U_r = U[:, :r]
#         U_r[:, rank:] = 0
#         S_r = torch.diag(S[:r]) if use_S else torch.eye(r, device=DEVICE)
# #         pdb.set_trace()
# #         V_r = V[:, :r]
#     merge_mat = U_r.T
#     unmerge_mat = U_r
#     return merge_mat, unmerge_mat

# match_tensors = match_tensors

def match_wrapper(fn, interleave=False, random_perm=False):
    return lambda x: fn(x, interleave=interleave, random_perm=random_perm)
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################
##########################################################################################

In [15]:
class LayerTransform(dict):
    def __init__(self, normalize_tensors=False, tensor_merge_type='concat'):
        super(LayerTransform, self).__init__()
        self.output_align = None
        self.next_input_align = None
        self.normalize_tensors = normalize_tensors
        self.tensor_merge_type = tensor_merge_type
    
    def compute_transform(self):
        inputs = list(self.values())
        if self.normalize_tensors:
            for idx, inp in enumerate(inputs):
                inputs[idx] = F.normalize(inp, dim=-1)
        if self.tensor_merge_type == 'concat':
            match_input = concat_mats(inputs, dim=-1)
        elif self.tensor_merge_type == 'mean':
            match_input = torch.stack(inputs, dim=0).mean(0)
                
        self.output_align, self.next_input_align = match_tensors(match_input)

In [16]:
def unflatten(x, k=3):
    O, IHW = x.shape
    return x.view(O, -1, k, k)

def merge_first_convs(state_dict, prefix, a_conv, b_conv, output_transform):
    flatten_conv = lambda x: x.flatten(1)
    a_w = flatten_conv(a_conv.weight)
    b_w = flatten_conv(b_conv.weight)
    ab_w = concat_mats((a_w, b_w), dim=0)
    output_transform[prefix] = ab_w
    output_transform.compute_transform()
    # merge_mat, unmerge_mat = match_tensors(ab_w)
    c_w = output_transform.output_align @ ab_w
    state_dict[prefix + '.weight'] = unflatten(c_w, a_conv.weight.shape[-1])
    return output_transform

def merge_bn(state_dict, prefix, a_bn, b_bn, output_transform):
    staterify = lambda bn: torch.stack((bn.weight, bn.bias, bn.running_mean), dim=1)
    unstaterify = lambda stats: stats.unbind(-1)
    
    a_stats = staterify(a_bn)
    b_stats = staterify(b_bn)
    ab_stats = concat_mats((a_stats, b_stats), dim=0)
    c_stats = output_transform.output_align @ ab_stats
    c_weight, c_bias, c_mean = unstaterify(c_stats)
    ab_var = concat_mats((a_bn.running_var[..., None], b_bn.running_var[...,None]))
    var_out_transform = output_transform.output_align.square()
    c_var = (var_out_transform @ ab_var).reshape(-1)
    state_dict[prefix + '.weight'] = c_weight
    state_dict[prefix + '.bias'] = c_bias
    state_dict[prefix + '.running_mean'] = c_mean
    state_dict[prefix + '.running_var'] = c_var
    pass

def block_diagonalize_tensors(tensor1, tensor2):
    zerooos = torch.zeros_like(tensor1)
    block_diagonal = concat_mats(
        (
            concat_mats((tensor1, zerooos), dim=1),
            concat_mats((zerooos, tensor2), dim=1),
        ),
        dim=0
    )
    return block_diagonal

def merge_hidden_conv(
    state_dict, prefix, a_conv, b_conv, 
    input_transform, output_transform, 
    recompute_output=False
):
    O, I, H, W = a_conv.weight.shape
    get_I_by_O_by_HW = lambda x: x.permute(1, 0, 2, 3).flatten(2)
    
    a_I_by_O_by_HW = get_I_by_O_by_HW(a_conv.weight)
    b_I_by_O_by_HW = get_I_by_O_by_HW(b_conv.weight)
    
    dummy_zerooooo = torch.zeros_like(b_I_by_O_by_HW)
    ab_block_diago = concat_mats(
        (
            concat_mats((a_I_by_O_by_HW, dummy_zerooooo), dim=1),
            concat_mats((dummy_zerooooo, b_I_by_O_by_HW), dim=1)
        ),
        dim=0
    )
    
    # [I,2I]x[2I,2OHW]->[I,2OHW]
    ab_input_aligned = input_transform.next_input_align.T @ ab_block_diago.flatten(1)
    ab_input_aligned = ab_input_aligned.\
    reshape(I, 2 * O, H*W).\
    transpose(1, 0).\
    flatten(1) # [I,2O,HW]->[2O,I,HW]->[2O,IHW]
    output_transform[prefix] = ab_input_aligned
    if recompute_output:
        output_transform.compute_transform()
    c_flat = output_transform.output_align @ ab_input_aligned
    state_dict[prefix + '.weight'] = unflatten(c_flat, a_conv.weight.shape[-1])
    
    output_block_diagonal_ab = block_diagonalize_tensors(
        a_conv.weight.flatten(2),
        b_conv.weight.flatten(2)
    )
    ab_output_aligned = output_transform.output_align @ output_block_diagonal_ab.flatten(1)
    ab_output_aligned = ab_output_aligned.reshape(O, 2 * I, H*W).transpose(1, 0).flatten(1)
    input_transform[prefix] = ab_output_aligned
    
    return output_transform

def merge_linear(
    state_dict, prefix, a_linear, 
    b_linear, input_transform, 
    output_transform, 
    recompute_output=False
):
    class conv_wrapper:
        def __init__(self, linear):
            self.weight = linear.weight[:, :, None, None]
    
    output_transform = merge_hidden_conv(
        state_dict, prefix, conv_wrapper(a_linear), 
        conv_wrapper(b_linear), input_transform, 
        output_transform, recompute_output=recompute_output
    )
    state_dict[prefix + '.weight'] = state_dict[prefix + '.weight'][..., 0, 0]
    state_dict[prefix + '.bias'] = output_transform.output_align @ concat_mats((a_linear.bias, b_linear.bias))
    return output_transform
    
def merge_block(
    state_dict, prefix, a_block, b_block, 
    input_transform, intra_transform,
    output_transform=None, shortcut=False
):
    conv1_transform = merge_hidden_conv(
        state_dict, prefix + '.conv1', a_block.conv1, b_block.conv1, 
        input_transform, intra_transform, recompute_output=True
    )
    merge_bn(state_dict, prefix + '.bn1', a_block.bn1, b_block.bn1, conv1_transform)
    
    
    conv2_transform = merge_hidden_conv(
        state_dict, 
        prefix + '.conv2', 
        a_block.conv2, 
        b_block.conv2, 
        conv1_transform,
        output_transform,
        recompute_output=shortcut
    )
    merge_bn(state_dict, prefix + '.bn2', a_block.bn2, b_block.bn2, conv2_transform)
    
    if shortcut:
        shortcut_transform = merge_hidden_conv(
            state_dict, 
            prefix + '.shortcut.0', 
            a_block.shortcut[0], 
            b_block.shortcut[0], 
            input_transform,
            output_transform=conv2_transform
        )
        merge_bn(
            state_dict, 
            prefix + '.shortcut.1', 
            a_block.shortcut[1], 
            b_block.shortcut[1], 
            shortcut_transform
        )
    
    return conv2_transform

hard_pass = lambda : None

def merge_resnet20(state_dict, a, b, transforms):
    transforms['conv1'] = merge_first_convs(
        state_dict, 'conv1', a.conv1, b.conv1, 
        output_transform=transforms['conv1']
    )
    merge_bn(state_dict, 'bn1', a.bn1, b.bn1, transforms['conv1'])
    
    for i in range(3):
        merge_block(
            state_dict, f'layer1.{i}', a.layer1[i], b.layer1[i], 
            input_transform=transforms['conv1'], 
            intra_transform=transforms[f'block1.{i}'],
            output_transform=transforms['conv1'],
            shortcut=False
        )
    
    transforms['block2'] = merge_block(
        state_dict, 'layer2.0', a.layer2[0], b.layer2[0], 
        input_transform=transforms['conv1'], 
        intra_transform=transforms[f'block2.0'],
        output_transform=transforms['block2'],
        shortcut=True
    )
    
    for i in range(1, 3):
        merge_block(
            state_dict, f'layer2.{i}', a.layer2[i], b.layer2[i], 
            input_transform=transforms['block2'], 
            intra_transform=transforms[f'block2.{i}'],
            output_transform=transforms['block2'],
            shortcut=False
        )
        
    transforms['block3'] = merge_block(
        state_dict, 'layer3.0', a.layer3[0], b.layer3[0], 
        input_transform=transforms['block2'], 
        intra_transform=transforms[f'block3.0'],
        output_transform=transforms['block3'],
        shortcut=True
    )
    for i in range(1, 3):
        merge_block(
            state_dict, f'layer3.{i}', a.layer3[i], b.layer3[i], 
            input_transform=transforms['block3'], 
            intra_transform=transforms[f'block3.{i}'],
            output_transform=transforms['block3'],
            shortcut=False
        )
        
    output_align_identity = torch.eye(a.linear.weight.shape[0], device=a.linear.weight.device)
    output_align_mat = torch.cat((output_align_identity/2, output_align_identity/2), dim=1)
    transforms['linear'].output_align = output_align_mat
    transforms['linear'] = merge_linear(
        state_dict, 'linear', a.linear, b.linear, 
        transforms['block3'], transforms['linear'],
        recompute_output=False
    )
    
    return transforms


In [17]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.to(DEVICE))

In [18]:
model1 = resnet20(w=4, text_head=True).to(DEVICE)
model2 = resnet20(w=4, text_head=True).to(DEVICE)
load_model(model1, f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}')
load_model(model2, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')

print(evaluate_texthead(model1, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model2, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.778
0.7758


In [19]:
from collections import defaultdict

In [20]:
def get_procrustes(corr_mtx):
    U, _, Vh = torch.linalg.svd(corr_mtx)
    S = torch.eye(U.shape[0], device=U.device)
    S[-1, -1] = -1.
    return (U @ S) @ Vh

def match_tensors_procrustes(hull_tensor, eps=1e-7, interleave=False, random_perm=False):
    """
    hull_tensor: [2O,I]
    """
    O, I = hull_tensor.shape
    O //= 2
    
    interleave_mat = torch.eye(2*O, device=hull_tensor.device)
    if interleave:
        A1, A2, B1, B2 = interleave_mat.chunk(4, dim=0)
        interleave_mat = torch.cat([A1, B1, A2, B2], dim=0)
        interleave_mat = interleave_mat.view(2, O, 2*O).transpose(0, 1).reshape(2*O, 2*O)
#         interleave_mat = interleave_mat[torch.randperm(2*O, device=hull_tensor.device)]
    
    hull_tensor = interleave_mat @ hull_tensor
    
#     hull_tensor = hull_tensor / (hull_tensor.norm(dim=-1, keepdim=True) + eps)
    A, B = unconcat_mat(hull_tensor, dim=0)
    scores = (A @ B.T)
    
    P = get_procrustes(scores)
    
    O_eye = torch.eye(O, device=hull_tensor.device)
    
    try:
        row_idx, col_idx = scipy.optimize.linear_sum_assignment(scores.cpu().numpy())
    except ValueError:
        pdb.set_trace()
    
    A_perm = O_eye[torch.from_numpy(row_idx)]#[perm]
    B_perm = P #O_eye[torch.from_numpy(col_idx)]#[perm]
    
    if random_perm:
        perm = torch.randperm(O, device=A.device)
        A_perm = A_perm[perm]
        B_perm = B_perm[perm]
    
    merge = (torch.cat((A_perm, B_perm), dim=1) / 2.) @ interleave_mat
    unmerge = interleave_mat.T @ (torch.cat((A_perm.T, B_perm.T), dim=0))
    return merge, unmerge


In [21]:
def kmeans_matching(
    hull_tensor,
    interleave=False,
    random_perm=False,
    r=.5
):
    hull_normed = hull_tensor / hull_tensor.norm(dim=-1, keepdim=True)
    O = hull_tensor.shape[0]
    k = int(O * (1-r))
    cluster_ids, cluster_centers = kmeans(
        X=hull_normed, num_clusters=k, 
        distance='cosine', 
        device=hull_tensor.device,
        tqdm_flag=False,
        seed=123
    )

    eye = torch.eye(k, device=hull_tensor.device)
    transform = eye[cluster_ids]

    unmerge = transform
    merge = (transform / transform.sum(dim=0, keepdim=True)).T
    return merge, unmerge
    

In [42]:
def match_tensors_exact_bipartite(
    hull_tensor,
    interleave=False,
    random_perm=False,
    r=.5
):
    hull_normed = hull_tensor / hull_tensor.norm(dim=-1, keepdim=True)
    O = hull_tensor.shape[0]
    remainder = int(hull_tensor.shape[0] * (1-r))
    bound = O - remainder
    sims = hull_normed @ hull_normed.transpose(-1, -2)
    torch.diagonal(sims)[:] = -torch.inf
    permutation_matrix = torch.zeros((O, O - bound), device=sims.device)
    for i in range(bound):
        best_idx = sims.view(-1).argmax()
        row_idx = best_idx % sims.shape[1]
        col_idx = best_idx // sims.shape[1]
        permutation_matrix[row_idx, i] = 1
        permutation_matrix[col_idx, i] = 1
        sims[row_idx] = -torch.inf
        sims[col_idx] = -torch.inf
        sims[:, row_idx] = -torch.inf
        sims[:, col_idx] = -torch.inf
    
    unused = (sims.max(-1)[0] > -torch.inf).to(torch.int).nonzero().view(-1)
    for i in range(bound, O-bound):
        permutation_matrix[unused[i-bound], i] = 1
    merge = permutation_matrix / (permutation_matrix.sum(dim=0, keepdim=True) + 1e-5)
    unmerge = permutation_matrix
    return merge.T, unmerge


In [2]:
import numpy as np
def check_convergence(old_transforms, current_transforms, prev_distance=np.inf, eps=1e-5):
    if len(old_transforms) == 0: 
        return False, {}
    transform_norms = {}
    is_converged = True
    for key, old_transform in old_transforms.items():
        current_transform = current_transforms[key]
        is_close = torch.allclose(
            current_transform.output_align, 
            old_transform.output_align, 
            atol=eps
        )
        norm = torch.norm(current_transform.output_align - old_transform.output_align)
        if not is_close: 
            is_converged = False
        transform_norms[key] = torch.round(norm, decimals=3).cpu().numpy().round(3)
        
    if np.abs(prev_distance - np.mean(list(transform_norms.values()))) <= eps:
        is_converged = True
    return (is_converged, transform_norms)

def find_transform_differences(old_transforms, current_transforms):
    if len(old_transforms) == 0:
        return {}
    transform2norm = {}
    for key, old_transform in old_transforms.items():
        current_transform = current_transforms[key]
        old_align = old_transform.output_align
        new_align = current_transform.output_align
        cost = old_align.T @ new_align
        row_ind, col_idx = scipy.optimize.linear_sum_assignment(cost.detach().cpu().numpy())
        permutation = torch.eye(new_align.shape[1], device=old_align.device)[col_idx]
        aligned_new = new_align @ permutation
#         pdb.set_trace()
        norm = torch.norm(old_align - aligned_new).cpu().numpy()
        transform2norm[key] = norm
    return transform2norm        

In [116]:
# old_transforms['conv1'].output_align.shape

In [128]:
# np.random.seed(123)
# torch.manual_seed(123)
# Can choose between:
# match_tensors_tome: ToMe with or without interleaving
# match_tensors_permute: 
# match_tensors_svd
match_tensors = match_wrapper(match_tensors_exact_bipartite, interleave=False, random_perm=False)
layer_transform = lambda : LayerTransform(normalize_tensors=False, tensor_merge_type='concat')
old_state_dict = {}
state_dict = {}
old_transforms = defaultdict(lambda: layer_transform())
new_transforms = defaultdict(lambda: layer_transform())
modelc = resnet20(w=4, text_head=True).to(DEVICE)
accuracies = []
steps = []
distances = []
best_info = {'acc': 0., 'dist': np.inf}
step = 1
is_converged = False
prev_distance = np.inf
same_window = 5
same_span = 0
while not is_converged:
# for step in range(110):
    old_transforms = new_transforms
    old_state_dict = deepcopy(state_dict)
    new_transforms = merge_resnet20(state_dict, model1, model2, transforms=deepcopy(old_transforms))
    if step == 0:
        original_computation = deepcopy(new_transforms)
    
    transform2dist = find_transform_differences(old_transforms, new_transforms)
    avg_distance = np.mean(list(transform2dist.values())
    
    if abs(avg_distance - prev_distance) <= 1e-5:
        same_span += 1
    else:
        same_span = 0
    if same_span >= same_window:
        is_converged = True
        
    prev_distance = avg_distance
    
#     modelc.load_state_dict(state_dict)
#     acc, perclass_acc = evaluate_texthead(
#         modelc, test_loader, class_vectors=text_features, return_confusion=True
#     )
#     print(f'--------------------- Step {step} Dist: {avg_distance:.3f} Acc: {acc:.3f} ---------------------')
    print(f'--------------------- Step {step} Dist: {avg_distance:.3f} ---------------------')
    step += 1
#     if step >= 100 or is_converged:
#         break

--------------------- Step 1 Dist: nan ---------------------
--------------------- Step 2 Dist: 12.606 ---------------------
--------------------- Step 3 Dist: 12.606 ---------------------
--------------------- Step 4 Dist: 12.606 ---------------------
--------------------- Step 5 Dist: 12.606 ---------------------
--------------------- Step 6 Dist: 12.606 ---------------------
--------------------- Step 7 Dist: 12.599 ---------------------
--------------------- Step 8 Dist: 12.606 ---------------------
--------------------- Step 9 Dist: 12.606 ---------------------
--------------------- Step 10 Dist: 12.605 ---------------------
--------------------- Step 11 Dist: 12.606 ---------------------
--------------------- Step 12 Dist: 12.606 ---------------------
--------------------- Step 13 Dist: 12.606 ---------------------
--------------------- Step 14 Dist: 12.606 ---------------------
--------------------- Step 15 Dist: 12.606 ---------------------
--------------------- Step 16 Dist: 1

--------------------- Step 127 Dist: 12.605 ---------------------
--------------------- Step 128 Dist: 12.606 ---------------------
--------------------- Step 129 Dist: 12.606 ---------------------
--------------------- Step 130 Dist: 12.606 ---------------------
--------------------- Step 131 Dist: 12.605 ---------------------
--------------------- Step 132 Dist: 12.606 ---------------------
--------------------- Step 133 Dist: 12.606 ---------------------
--------------------- Step 134 Dist: 12.606 ---------------------
--------------------- Step 135 Dist: 12.606 ---------------------
--------------------- Step 136 Dist: 12.606 ---------------------
--------------------- Step 137 Dist: 12.606 ---------------------
--------------------- Step 138 Dist: 12.606 ---------------------
--------------------- Step 139 Dist: 12.606 ---------------------
--------------------- Step 140 Dist: 12.606 ---------------------
--------------------- Step 141 Dist: 12.606 ---------------------
----------

In [129]:
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
acc, perclass_acc = evaluate_texthead(
    modelc, test_loader, class_vectors=text_features, return_confusion=True
)
acc

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff337debcb0>
Traceback (most recent call last):
  File "/srv/share/gstoica3/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7ff337debcb0>    
Traceback (most recent call last):
Exception ignored in:   File "/srv/share/gstoica3/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__
self._shutdown_workers()
<function _MultiProcessingDataLoaderIter.__del__ at 0x7ff337debcb0>    
  File "/srv/share/gstoica3/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1449, in _shutdown_workers
Traceback (most recent call last):
self._shutdown_workers()  File "/srv/share/gstoica3/miniconda3/envs/open-mmlab/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1466, in __del__

    E

0.3705

In [37]:
print(perclass_acc)

[0.22, 0.19, 0.16, 0.27, 0.35, 0.21, 0.2, 0.28, 0.45, 0.26, 0.21, 0.27, 0.47, 0.28, 0.42, 0.55, 0.35, 0.74, 0.28, 0.22, 0.78, 0.37, 0.42, 0.59, 0.58, 0.22, 0.16, 0.27, 0.48, 0.44, 0.15, 0.3, 0.22, 0.27, 0.31, 0.39, 0.7, 0.52, 0.18, 0.18, 0.45, 0.48, 0.42, 0.53, 0.21, 0.2, 0.1, 0.25, 0.71, 0.26, 0.3, 0.64, 0.37, 0.58, 0.53, 0.14, 0.8, 0.63, 0.7, 0.33, 0.36, 0.72, 0.5, 0.27, 0.29, 0.07, 0.32, 0.17, 0.62, 0.58, 0.38, 0.57, 0.23, 0.36, 0.34, 0.26, 0.58, 0.14, 0.52, 0.29, 0.1, 0.73, 0.72, 0.62, 0.1, 0.42, 0.57, 0.51, 0.51, 0.24, 0.45, 0.24, 0.39, 0.33, 0.44, 0.31, 0.18, 0.34, 0.56, 0.35]


In [None]:
print(accuracies)
print(np.round(distances, 3))

In [None]:
steps

In [None]:
best_info

In [None]:
import pickle

permuted_sd = pickle.load(open(
'/srv/share/jbjorner3/checkpoints/REPAIR/pytorch_cifar50_2_permuted_to_1_jakob.pkl', 'rb'))


In [None]:
model_permute = resnet20(w=4, text_head=True)
model_permute.load_state_dict(permuted_sd)

In [None]:
model_permute = model_permute.to(DEVICE)
print(evaluate_texthead(model_permute, test_loader, class_vectors=text_features))

In [None]:
def mix_weights(model, alpha, sd0, sd1):
    sd_alpha = {}
    for k in sd0.keys():
        param0 = sd0[k].to(DEVICE)
        param1 = sd1[k].to(DEVICE)
        sd_alpha[k] = (1 - alpha) * param0 + alpha * param1
    model.load_state_dict(sd_alpha)

In [None]:
merged_model = resnet20(w=4, text_head=True).to(DEVICE)
mix_weights(merged_model, .5, model1.state_dict(), model_permute.state_dict())

print(evaluate_texthead(merged_model, test_loader, class_vectors=text_features))
reset_bn_stats(merged_model)
print(evaluate_texthead(merged_model, test_loader, class_vectors=text_features))