In [1]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'
if DEVICE == 'mps':
    DOWNLOAD_PATH = '/Users/georgestoica/Downloads' 
else:
    DOWNLOAD_PATH = '/srv/share/gstoica3/checkpoints/REPAIR/'
    
torch.autograd.set_grad_enabled(False)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x7fa6bcbb0d50>

In [2]:
from copy import deepcopy

In [3]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)

In [4]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
#             self.shortcut = LambdaLayer(lambda x:
#                                         F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, w=1, num_classes=10, text_head=False):
        super(ResNet, self).__init__()
        self.in_planes = int(w*16)

        self.conv1 = nn.Conv2d(3, int(w*16), kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(int(w*16))
        self.layer1 = self._make_layer(block, int(w*16), num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, int(w*32), num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, int(w*64), num_blocks[2], stride=2)
        if text_head:
            num_classes = 512
        self.linear = nn.Linear(int(w*64), num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(w=1, text_head=False):
    return ResNet(BasicBlock, [3, 3, 3], w=w, text_head=text_head)

In [16]:
# evaluates accuracy
def evaluate(model, loader=test_loader, return_confusion=False):
    model.eval()
    correct = 0
    confusion_matrix = np.zeros((10, 10))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            pred = outputs.argmax(dim=1)
            correct += (labels.to(DEVICE) == pred).sum().item()
            confusion_matrix[labels.cpu().numpy(), pred.cpu().numpy()] += 1
    confusion_matrix /= confusion_matrix.sum(-1, keepdims=True)
    if return_confusion:
        return correct, confusion_matrix
    else:
        return correct

# evaluates loss
def evaluate1(model, loader=test_loader):
    model.eval()
    losses = []
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            loss = F.cross_entropy(outputs, labels.to(DEVICE))
            losses.append(loss.item())
    return np.array(losses).mean()

In [17]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.to(DEVICE))

In [7]:
def bipartite_soft_matching(
    metric: torch.Tensor,
    r: float,
    class_token: bool = False,
    distill_token: bool = False,
):
    """
    Applies ToMe with a balanced matching set (50%, 50%).
    Input size is [batch, tokens, channels].
    r indicates the ratio of tokens to remove (max 50% of tokens).
    Extra args:
     - class_token: Whether or not there's a class token.
     - distill_token: Whether or not there's also a distillation token.
    When enabled, the class token and distillation tokens won't get merged.
    """
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = int(r * t)
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric.chunk(2, dim=-2)
        scores = a @ b.transpose(-1, -2)

        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x.chunk(2, dim=-2)
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., dst.shape[-2]:, :] = dst
        out.scatter_(dim=-2, index=(unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge

In [8]:
modela = resnet20(w=4).to(DEVICE)
modelb = resnet20(w=4).to(DEVICE)
load_model(modela, 'resnet20x4_v1')
load_model(modelb, 'resnet20x4_v3')

evaluate(modela), evaluate(modelb)

(9536, 9531)

In [29]:
def concat_mats(args, dim=0):
    return torch.cat(args, dim=dim)

ones = torch.ones(4, 5)
twos = ones * 2
threes = ones * 3
concat_mats((ones, twos, threes), dim=1)

tensor([[1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3.],
        [1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3.],
        [1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3.],
        [1., 1., 1., 1., 1., 2., 2., 2., 2., 2., 3., 3., 3., 3., 3.]])

In [8]:
def concat_mats(args, dim=0):
    return torch.cat(args, dim=dim)

def unconcat_mat(tensor, dim=0):
    return torch.chunk(tensor, chunks=2, dim=dim)

def match_tensors_permute(hull_tensor, r=1/4, eps=1e-7, interleave=False, random_perm=False):
    """
    hull_tensor: [O,I]
    """
    O_orig, I = hull_tensor.shape
    O_new = int(O_orig * r)
    
    interleave_mat = torch.eye(O_orig, device=hull_tensor.device)
    if interleave:
        A1, A2, B1, B2 = interleave_mat.chunk(4, dim=0)
        interleave_mat = torch.cat([A1, B1, A2, B2], dim=0)
        interleave_mat = interleave_mat.view(2, O_new, O_orig).transpose(0, 1).reshape(O_orig, O_orig)
#         interleave_mat = interleave_mat[torch.randperm(2*O, device=hull_tensor.device)]
    
    hull_tensor = interleave_mat @ hull_tensor
    
    hull_tensor = hull_tensor / (hull_tensor.norm(dim=-1, keepdim=True) + eps)
    A, B = unconcat_mat(hull_tensor, dim=0)
    scores = -(A @ B.T)
    
    O_eye = torch.eye(O_new, device=hull_tensor.device)
    
    try:
        row_idx, col_idx = scipy.optimize.linear_sum_assignment(scores.cpu().numpy())
    except ValueError:
        pdb.set_trace()
    
    A_perm = O_eye[torch.from_numpy(row_idx)]#[perm]
    B_perm = O_eye[torch.from_numpy(col_idx)]#[perm]
    
    if random_perm:
        perm = torch.randperm(O, device=A.device)
        A_perm = A_perm[perm]
        B_perm = B_perm[perm]
    
    pdb.set_trace()
    # Need to figure out whether to divide by the sum of the columns or the rows of the matrix.
    # I think it's the columns
    AB_perm = torch.cat((A_perm, B_perm), dim=1)
    AB_perm = AB_perm / AB_perm.sum(dim=1, keepdim=True)
    merge = AB_perm @ interleave_mat
    unmerge = interleave_mat.T @ (torch.cat((A_perm.T, B_perm.T), dim=0))
    return merge, unmerge


def match_tensors_tome(hull_tensor, r=.5, eps=1e-7, interleave=False, random_perm=False):
    """
    hull_tensor: [2O,I]
    """
    O_orig, I = hull_tensor.shape
    O_new = int(O_orig * r)
    
    big_eye = torch.eye(O_orig, device=hull_tensor.device)
    small_eye = torch.eye(O_new, device=hull_tensor.device)
    
    interleave_mat = big_eye
    if interleave:
        A1, A2, B1, B2 = interleave_mat.chunk(4, dim=0)
        interleave_mat = torch.cat([A1, B1, A2, B2], dim=0)
    
    
    hull_tensor = interleave_mat @ hull_tensor
    
    merge, unmerge = bipartite_soft_matching(hull_tensor[None], r)
    
    merge_mat = merge(big_eye[None])[0] @ interleave_mat
    unmerge_mat = interleave_mat.T @ unmerge(small_eye[None])[0]
    return merge_mat, unmerge_mat

def match_tensors_svd(hull_tensor, use_S=True):
    # We can only reduce by a maximum of 50% tokens
    t = hull_tensor.shape[0]
    r = int(.5 * t)
    r = min(r, (t) // 2)
    with torch.no_grad():
        hull_tensor = hull_tensor / hull_tensor.norm(dim=-1, keepdim=True)
        scores = hull_tensor @ hull_tensor.transpose(-1, -2)
        U, S, V = torch.svd(scores)
        
        U_r = U[:, :r]
        S_r = torch.diag(S[:r]) if use_S else torch.eye(r, device=DEVICE)
        V_r = V[:, :r]
    merge_mat = U_r
    unmerge_mat = S_r @ V_r.mT
    return merge_mat.T, unmerge_mat.T


def match_wrapper(fn, interleave=False, random_perm=False, r=.5):
    return lambda x: fn(x, interleave=interleave, random_perm=random_perm, r=r)

In [9]:
class LayerTransform(dict):
    def __init__(self, normalize_tensors=False, r=.5):
        super(LayerTransform, self).__init__()
        self.output_align = None
        self.next_input_align = None
        self.normalize_tensors = normalize_tensors
        self.r = r
    
    def compute_transform(self):
        inputs = list(self.values())
        if self.normalize_tensors:
            for idx, inp in enumerate(inputs):
                inputs[idx] = F.normalize(inp, dim=-1)
        self.output_align, self.next_input_align = match_tensors(concat_mats(inputs, dim=-1))

In [42]:
lt = LayerTransform()
lt[1] = torch.tensor(3)
lt[2] = torch.tensor(4)
lt.output_align = 4

4

In [10]:
def unflatten(x, k=3):
    O, IHW = x.shape
    return x.view(O, -1, k, k)

def merge_first_convs(state_dict, prefix, conv, output_transform):
    flatten_conv = lambda x: x.flatten(1)
    w = flatten_conv(conv.weight)
    output_transform[prefix] = w
    output_transform.compute_transform()
    # merge_mat, unmerge_mat = match_tensors(ab_w)
    c_w = output_transform.output_align @ w
    state_dict[prefix + '.weight'] = unflatten(c_w, conv.weight.shape[-1])
    return output_transform

def merge_bn(state_dict, prefix, bn, output_transform):
    staterify = lambda bn: torch.stack((bn.weight, bn.bias, bn.running_mean), dim=1)
    unstaterify = lambda stats: stats.unbind(-1)
    
    stats = staterify(bn)
    c_stats = output_transform.output_align @ stats
    c_weight, c_bias, c_mean = unstaterify(c_stats)
    var = bn.running_var[..., None]
    var_out_transform = output_transform.output_align.square()
    c_var = (var_out_transform @ var).reshape(-1)
    state_dict[prefix + '.weight'] = c_weight
    state_dict[prefix + '.bias'] = c_bias
    state_dict[prefix + '.running_mean'] = c_mean
    state_dict[prefix + '.running_var'] = c_var
    pass

def block_diagonalize_tensors(tensor1, tensor2):
    zerooos = torch.zeros_like(tensor1)
    block_diagonal = concat_mats(
        (
            concat_mats((tensor1, zerooos), dim=1),
            concat_mats((zerooos, tensor2), dim=1),
        ),
        dim=0
    )
    return block_diagonal

def merge_hidden_conv(
    state_dict, prefix, conv, 
    input_transform, output_transform, 
    recompute_output=False
):
    O_orig, I_orig, H, W = conv.weight.shape
    # Align output spaces for global iterations
    if output_transform.output_align is not None:
        conv_out_aligned = output_transform.output_align @ conv.weight.flatten(1)
        conv_out_aligned = conv_out_aligned.reshape(-1, I_orig, H*W).transpose(1, 0).flatten(1)
        input_transform[prefix] = conv_out_aligned
    
    I_OHW = conv.weight.permute(1, 0, 2, 3).flatten(1)
    # [I,2I]x[2I,2OHW]->[I,2OHW]
    input_aligned = input_transform.next_input_align.T @ I_OHW
    input_aligned = input_aligned.\
    reshape(-1, O_orig, H*W).\
    transpose(1, 0).\
    flatten(1) # [I_new,O,HW]->[O,I_new,HW]->[O,I_newHW]
    output_transform[prefix] = input_aligned
    if recompute_output:
        output_transform.compute_transform()
    c_flat = output_transform.output_align @ input_aligned
    state_dict[prefix + '.weight'] = unflatten(c_flat, conv.weight.shape[-1])
    
    return output_transform

def merge_linear(
    state_dict, prefix, linear, 
    input_transform, 
    output_transform, 
    recompute_output=False
):
    class conv_wrapper:
        def __init__(self, linear):
            self.weight = linear.weight[:, :, None, None]
    
    output_transform = merge_hidden_conv(
        state_dict, prefix, conv_wrapper(linear), 
        input_transform, output_transform, 
        recompute_output=recompute_output
    )
    state_dict[prefix + '.weight'] = state_dict[prefix + '.weight'][..., 0, 0]
    state_dict[prefix + '.bias'] = output_transform.output_align @ linear.bias
    return output_transform
    
def merge_block(
    state_dict, prefix, block, 
    input_transform, intra_transform,
    output_transform=None, shortcut=False
):
    conv1_transform = merge_hidden_conv(
        state_dict, prefix + '.conv1', block.conv1, 
        input_transform, intra_transform, recompute_output=True
    )
    merge_bn(state_dict, prefix + '.bn1', block.bn1, conv1_transform)
    
    
    conv2_transform = merge_hidden_conv(
        state_dict, 
        prefix + '.conv2', 
        block.conv2, 
        conv1_transform,
        output_transform,
        recompute_output=shortcut
    )
    merge_bn(state_dict, prefix + '.bn2', block.bn2, conv2_transform)
    
    if shortcut:
        shortcut_transform = merge_hidden_conv(
            state_dict, 
            prefix + '.shortcut.0', 
            block.shortcut[0], 
            input_transform,
            output_transform=conv2_transform
        )
        merge_bn(
            state_dict, 
            prefix + '.shortcut.1', 
            block.shortcut[1], 
            shortcut_transform
        )
    
    return conv2_transform

hard_pass = lambda : None

def merge_resnet20(state_dict, model, transforms):
    transforms['conv1'] = merge_first_convs(
        state_dict, 'conv1', model.conv1, 
        output_transform=transforms['conv1']
    )
    merge_bn(state_dict, 'bn1', model.bn1, transforms['conv1'])
    
    for i in range(3):
        merge_block(
            state_dict, f'layer1.{i}', model.layer1[i], 
            input_transform=transforms['conv1'], 
            intra_transform=transforms[f'block1.{i}'],
            output_transform=transforms['conv1'],
            shortcut=False
        )
    
    transforms['block2'] = merge_block(
        state_dict, 'layer2.0', model.layer2[0],
        input_transform=transforms['conv1'], 
        intra_transform=transforms[f'block2.0'],
        output_transform=transforms['block2'],
        shortcut=True
    )
    
    for i in range(1, 3):
        merge_block(
            state_dict, f'layer2.{i}', model.layer2[i], 
            input_transform=transforms['block2'], 
            intra_transform=transforms[f'block2.{i}'],
            output_transform=transforms['block2'],
            shortcut=False
        )
        
    transforms['block3'] = merge_block(
        state_dict, 'layer3.0', model.layer3[0], 
        input_transform=transforms['block2'], 
        intra_transform=transforms[f'block3.0'],
        output_transform=transforms['block3'],
        shortcut=True
    )
    for i in range(1, 3):
        merge_block(
            state_dict, f'layer3.{i}', model.layer3[i], 
            input_transform=transforms['block3'], 
            intra_transform=transforms[f'block3.{i}'],
            output_transform=transforms['block3'],
            shortcut=False
        )
        
    output_align_mat = torch.eye(model.linear.weight.shape[0], device=model.linear.weight.device)
    transforms['linear'].output_align = output_align_mat
    transforms['linear'] = merge_linear(
        state_dict, 'linear', model.linear, 
        transforms['block3'], transforms['linear'],
        recompute_output=False
    )
    
    return transforms

In [11]:
modela = resnet20(w=4).to(DEVICE)
modelb = resnet20(w=4).to(DEVICE)
load_model(modela, 'resnet20x4_v1')
load_model(modelb, 'resnet20x4_v2')

evaluate(modela), evaluate(modelb)

(9536, 9510)

In [11]:
def is_converged(old_transforms, current_transforms, eps=1e-5):
    if len(old_transforms) == 0: 
        return False
    transform_norms = {}
    for key, old_transform in old_transforms.items():
        current_transform = current_transforms[key]
        is_close = torch.allclose(
            current_transform.output_align, 
            old_transform.output_align, 
            atol=eps
        )
        norm = torch.norm(current_transform.output_align - old_transform.output_align)
        transform_norms[key] = (is_close, torch.round(norm, decimals=3))
    print(transform_norms)
    converges = [i[0] for i in transform_norms.values()]
    return sum(converges) == len(transform_norms)

In [12]:
from collections import defaultdict

In [28]:
# Can choose between:
# match_tensors_tome: ToMe with or without interleaving
# match_tensors_permute: 
# match_tensors_svd
match_tensors = match_wrapper(match_tensors_tome, interleave=True, random_perm=False, r=.5)
layer_transform = lambda : LayerTransform(normalize_tensors=False)
state_dict = {}
old_transforms = defaultdict(lambda: layer_transform())
new_transforms = defaultdict(lambda: layer_transform())

modelc = resnet20(w=2).to(DEVICE)
step = 1
# while not is_converged(old_transforms, new_transforms):
for step in range(10):
    old_transforms = new_transforms
    new_transforms = merge_resnet20(state_dict, modela, transforms=deepcopy(old_transforms))
    modelc.load_state_dict(state_dict)
    pre_reset_acc = evaluate(modelc)
    reset_bn_stats(modelc)
    print(step, pre_reset_acc, evaluate(modelc))
    step += 1
    

0 1871 5512
1 3896 7301
2 3658 7267
3 3328 7308
4 3127 7265
5 2854 7393
6 3003 7256
7 3099 7413
8 3062 7181
9 3236 7212


In [60]:
new_transforms.keys()

dict_keys(['conv1', 'block2', 'block3', 'linear'])

#### Weight Averaging

In [244]:
def stoopid_dumb(state_dict, a, b):
    for (k, v1), (_, v2) in zip(a.named_parameters(), b.named_parameters()):
        state_dict[k] = (v1 + v2) / 2
    for (k, v1), (_, v2) in zip(a.named_buffers(), b.named_buffers()):
        state_dict[k] = (v1 + v2) / 2

In [245]:
state_dict = {}
stoopid_dumb(state_dict, modela, modelb)
modelc = resnet20(w=4).to(DEVICE)
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
evaluate(modelc)

3094

#### Model Ensembling

In [252]:
def evaluate_logit_ensemble(model1, model2, loader=test_loader, return_confusion=False):
    model.eval()
    correct = 0
    confusion_matrix = np.zeros((10, 10))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs1 = model1(inputs.to(DEVICE))
            outputs2 = model2(inputs.to(DEVICE))
            outputs = (outputs1 + outputs2) / 2.
            pred = outputs.argmax(dim=1)
            correct += (labels.to(DEVICE) == pred).sum().item()
            confusion_matrix[labels.cpu().numpy(), pred.cpu().numpy()] += 1
    confusion_matrix /= confusion_matrix.sum(-1, keepdims=True)
    if return_confusion:
        return correct, confusion_matrix
    else:
        return correct
    

In [253]:
evaluate_logit_ensemble(modela, modelb)

9586

# Test on Disjoint Models

In [13]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR100(
    root='/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python', 
    train=True,
    download=True, transform=train_transform
)
test_dset = torchvision.datasets.CIFAR100(
    root='/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python',
    train=False,
    download=True, 
    transform=test_transform
)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

In [18]:
model1_classes= np.arange(50) # np.array([3, 2, 0, 6, 4])
model2_classes = np.arange(50, 100) # np.array([5, 7, 9, 8, 1])

valid_examples1 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model1_classes]
valid_examples2 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model2_classes]

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

50000it [00:14, 3534.28it/s]
50000it [00:14, 3541.93it/s]
10000it [00:01, 5324.88it/s]
10000it [00:01, 5340.34it/s]


In [19]:
class_idxs = np.zeros(100, dtype=int)
class_idxs[model1_classes] = np.arange(50)
class_idxs[model2_classes] = np.arange(50)
class_idxs = torch.from_numpy(class_idxs)
class_idxs

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47, 48, 49])

In [23]:
import clip

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)
model, preprocess = clip.load('ViT-B/32', DEVICE)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)


text_features /= text_features.norm(dim=-1, keepdim=True)
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]

if not os.path.exists(
    os.path.join(
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}.pth.tar'
    )
):
    print('training model...')
#     continue
#     model1 = resnet20(w=4).to(DEVICE)
#     train(
#         f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}', 
#         model=model1, 
#         class_vectors=class_vecs1,
#         train_loader=train_aug_loader1,
#         test_loader=test_loader1,
#         remap_class_idxs=class_idxs
#     )
if not os.path.exists(
    os.path.join(
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}.pth.tar'
    )
):
    print('training model...')
#     continue
#     model2 = resnet20(w=4).to(DEVICE)
#     train(
#         f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}', 
#         model=model2, 
#         class_vectors=class_vecs2,
#         train_loader=train_aug_loader2,
#         test_loader=test_loader2,
#         remap_class_idxs=class_idxs
#     )
      

In [24]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    confusion = np.zeros((10, 10))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                correct += (labels.to(DEVICE) == pred).sum().item()
            confusion[labels.cpu().numpy(), pred.cpu().numpy()] += 1
            total += inputs.shape[0]
    if return_confusion:
        return correct / total, confusion / confusion.sum(-1, keepdims=True)
    else:
        return correct / total

In [28]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    confusion = np.zeros((100, 100))
    
    totals = [0] * class_vectors.shape[0]
    corrects = [0] * class_vectors.shape[0]
    
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                for gt, p in zip(labels, pred):
                    totals[gt] += 1
                    
                    if gt == p:
                        correct += 1
                        corrects[gt] += 1
                
#                 correct += (labels.to(DEVICE) == pred).sum().item()
                
            confusion[labels.cpu().numpy(), pred.cpu().numpy()] += 1
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total

In [37]:
print(test_dset.classes[:50])

['apple', 'aquarium_fish', 'baby', 'bear', 'beaver', 'bed', 'bee', 'beetle', 'bicycle', 'bottle', 'bowl', 'boy', 'bridge', 'bus', 'butterfly', 'camel', 'can', 'castle', 'caterpillar', 'cattle', 'chair', 'chimpanzee', 'clock', 'cloud', 'cockroach', 'couch', 'crab', 'crocodile', 'cup', 'dinosaur', 'dolphin', 'elephant', 'flatfish', 'forest', 'fox', 'girl', 'hamster', 'house', 'kangaroo', 'keyboard', 'lamp', 'lawn_mower', 'leopard', 'lion', 'lizard', 'lobster', 'man', 'maple_tree', 'motorcycle', 'mountain']


In [29]:
model1 = resnet20(w=4, text_head=True).to(DEVICE)
model2 = resnet20(w=4, text_head=True).to(DEVICE)
load_model(model1, f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}')
load_model(model2, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')
print('models loaded')
print(evaluate_texthead(model1, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model2, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

models loaded
0.778
0.7758


In [51]:
# Can choose between:
# match_tensors_tome: ToMe with or without interleaving
# match_tensors_permute: 
# match_tensors_svd
match_tensors = match_wrapper(match_tensors_tome, interleave=False, random_perm=False)
layer_transform = lambda : LayerTransform(normalize_tensors=False)
state_dict = {}
old_transforms = defaultdict(lambda: layer_transform())
new_transforms = defaultdict(lambda: layer_transform())
modelc = resnet20(w=4, text_head=True).to(DEVICE)
best_info = {'acc': 0.}
step = 1
# while not is_converged(old_transforms, new_transforms):
for step in range(10):
    old_transforms = new_transforms
    new_transforms = merge_resnet20(state_dict, model1, model2, transforms=deepcopy(old_transforms))
    modelc.load_state_dict(state_dict)
    reset_bn_stats(modelc)
    acc, perclass_acc = evaluate_texthead(
        modelc, test_loader, class_vectors=text_features, return_confusion=True
    )
    if acc > best_info['acc']:
        best_info['acc'] = acc
        best_info['perclass_acc'] = perclass_acc
    print(step, acc)
    step += 1

0 0.3144
1 0.3218
2 0.3284
3 0.3314
4 0.3327
5 0.3321
6 0.3291
7 0.3245
8 0.3246
9 0.3308


In [52]:
print(best_info['acc'])
print(best_info['perclass_acc'])

0.3327
[0.0, 0.0, 0.0, 0.01, 0.34, 0.0, 0.0, 0.02, 0.0, 0.01, 0.03, 0.0, 0.03, 0.07, 0.06, 0.02, 0.04, 0.22, 0.42, 0.0, 0.37, 0.0, 0.0, 0.02, 0.01, 0.02, 0.0, 0.16, 0.09, 0.04, 0.37, 0.05, 0.1, 0.34, 0.07, 0.0, 0.14, 0.01, 0.04, 0.01, 0.18, 0.01, 0.6, 0.1, 0.1, 0.0, 0.0, 0.02, 0.0, 0.11, 0.25, 0.56, 0.54, 0.74, 0.75, 0.27, 0.85, 0.73, 0.86, 0.59, 0.72, 0.87, 0.75, 0.36, 0.31, 0.19, 0.73, 0.15, 0.88, 0.68, 0.39, 0.51, 0.25, 0.22, 0.58, 0.63, 0.78, 0.4, 0.62, 0.74, 0.33, 0.36, 0.84, 0.65, 0.43, 0.63, 0.58, 0.72, 0.59, 0.78, 0.83, 0.63, 0.54, 0.41, 0.88, 0.72, 0.32, 0.46, 0.9, 0.54]


In [53]:
acc1, confusion_disj_1 = evaluate_texthead(model1, test_loader, class_vectors=text_features, return_confusion=True)
acc2, confusion_disj_2 = evaluate_texthead(model2, test_loader, class_vectors=text_features, return_confusion=True)
print(acc1, confusion_disj_1)
print(acc2, confusion_disj_2)

0.3904 [0.94, 0.89, 0.62, 0.69, 0.76, 0.73, 0.85, 0.76, 0.85, 0.83, 0.57, 0.68, 0.8, 0.9, 0.76, 0.84, 0.8, 0.86, 0.7, 0.75, 0.84, 0.84, 0.77, 0.9, 0.88, 0.63, 0.73, 0.59, 0.75, 0.76, 0.85, 0.62, 0.74, 0.81, 0.82, 0.6, 0.91, 0.87, 0.61, 0.94, 0.72, 0.91, 0.82, 0.85, 0.56, 0.68, 0.46, 0.93, 0.95, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.02, 0.01, 0.0, 0.0, 0.1, 0.01]
0.3884 [0.0, 0.0, 0.0, 0.0, 0.03, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.07, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.01, 0.0, 0.01, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.64, 0.81, 0.73, 0.87, 0.9, 0.6, 0.88, 0.83, 0.91, 0.74, 0.88, 0.89, 0.76, 0.7, 0.61, 0.51, 0.83, 0.61, 0.93, 0.91, 0.68, 0.86, 0.51, 0.66, 0.58, 0.89, 0.92, 0.73, 0.67, 0.77, 0.57,

#### Weight Averaging

In [33]:
def stoopid_dumb(state_dict, a, b):
    for (k, v1), (_, v2) in zip(a.named_parameters(), b.named_parameters()):
        state_dict[k] = (v1 + v2) / 2
    for (k, v1), (_, v2) in zip(a.named_buffers(), b.named_buffers()):
        state_dict[k] = (v1 + v2) / 2

In [34]:
state_dict = {}
stoopid_dumb(state_dict, model1, model2)
modelc = resnet20(w=4, text_head=True).to(DEVICE)
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
acc, confusion_stopid = evaluate_texthead(modelc, test_loader, class_vectors=text_features, return_confusion=True)
print(acc, confusion_stopid)

0.067 [0.0, 0.0, 0.08, 0.07, 0.04, 0.01, 0.13, 0.03, 0.28, 0.07, 0.09, 0.19, 0.04, 0.0, 0.03, 0.05, 0.03, 0.12, 0.12, 0.02, 0.21, 0.04, 0.06, 0.02, 0.27, 0.07, 0.05, 0.02, 0.12, 0.12, 0.0, 0.06, 0.07, 0.18, 0.04, 0.1, 0.17, 0.13, 0.01, 0.15, 0.14, 0.07, 0.12, 0.2, 0.01, 0.02, 0.22, 0.02, 0.01, 0.26, 0.2, 0.17, 0.05, 0.07, 0.09, 0.01, 0.06, 0.05, 0.01, 0.02, 0.01, 0.11, 0.06, 0.02, 0.02, 0.02, 0.1, 0.02, 0.01, 0.04, 0.06, 0.02, 0.02, 0.06, 0.04, 0.1, 0.0, 0.03, 0.1, 0.06, 0.01, 0.02, 0.22, 0.03, 0.0, 0.03, 0.0, 0.04, 0.02, 0.01, 0.12, 0.03, 0.03, 0.01, 0.0, 0.03, 0.0, 0.0, 0.07, 0.04]


#### Model Ensembling

In [40]:
# evaluates accuracy
def evaluate_ensemble_texthead(model1, model2, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    confusion = np.zeros((100, 100))
    
    totals = [0] * class_vectors.shape[0]
    corrects = [0] * class_vectors.shape[0]
    
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings1 = model1(inputs.to(DEVICE))
            encodings2 = model2(inputs.to(DEVICE))
            normed_encodings1 = encodings1 / encodings1.norm(dim=-1, keepdim=True)
            normed_encodings2 = encodings2 / encodings2.norm(dim=-1, keepdim=True)
            outputs1 = normed_encodings1 @ class_vectors.T
            outputs2 = normed_encodings2 @ class_vectors.T
            pred, pred_prob1 = outputs1.argmax(dim=1), outputs1.max(dim=1)[0]
            pred2, pred_prob2 = outputs2.argmax(dim=1), outputs2.max(dim=1)[0]
            pred[pred_prob1 < pred_prob2] = pred2[pred_prob1 < pred_prob2]
#             pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                for gt, p in zip(labels, pred):
                    totals[gt] += 1
                    
                    if gt == p:
                        correct += 1
                        corrects[gt] += 1
                
#                 correct += (labels.to(DEVICE) == pred).sum().item()
                
            confusion[labels.cpu().numpy(), pred.cpu().numpy()] += 1
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total

In [41]:
acc_ensemble, confusion_ensemble = evaluate_ensemble_texthead(
    model1, model2, test_loader, class_vectors=text_features, return_confusion=True
)

In [44]:
print(acc_ensemble)
print(confusion_ensemble)

0.6336
[0.91, 0.86, 0.57, 0.65, 0.67, 0.56, 0.84, 0.73, 0.79, 0.73, 0.42, 0.53, 0.58, 0.57, 0.74, 0.81, 0.68, 0.79, 0.64, 0.73, 0.69, 0.69, 0.66, 0.69, 0.81, 0.5, 0.65, 0.41, 0.69, 0.7, 0.59, 0.62, 0.52, 0.72, 0.79, 0.4, 0.85, 0.79, 0.53, 0.79, 0.64, 0.89, 0.78, 0.82, 0.5, 0.61, 0.23, 0.75, 0.91, 0.79, 0.51, 0.53, 0.46, 0.59, 0.67, 0.38, 0.81, 0.57, 0.74, 0.53, 0.84, 0.71, 0.62, 0.42, 0.49, 0.41, 0.47, 0.39, 0.86, 0.87, 0.61, 0.85, 0.2, 0.49, 0.33, 0.57, 0.81, 0.54, 0.47, 0.43, 0.43, 0.61, 0.87, 0.4, 0.62, 0.65, 0.75, 0.74, 0.6, 0.6, 0.7, 0.8, 0.43, 0.54, 0.81, 0.69, 0.42, 0.68, 0.5, 0.54]
