In [1]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'
if DEVICE == 'mps':
    DOWNLOAD_PATH = '/Users/georgestoica/Downloads' 
else:
    DOWNLOAD_PATH = '/srv/share/gstoica3/checkpoints/REPAIR/'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)

In [3]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR10(root='/tmp', train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR10(root='/tmp', train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [213]:
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

def _weights_init(m):
    classname = m.__class__.__name__
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight)

class LambdaLayer(nn.Module):
    def __init__(self, lambd):
        super(LambdaLayer, self).__init__()
        self.lambd = lambd

    def forward(self, x):
        return self.lambd(x)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != planes:
#             self.shortcut = LambdaLayer(lambda x:
#                                         F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(planes)
            )


    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, w=1, num_classes=10, text_head=False):
        super(ResNet, self).__init__()
        self.in_planes = w*16

        self.conv1 = nn.Conv2d(3, w*16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(w*16)
        self.layer1 = self._make_layer(block, w*16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, w*32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, w*64, num_blocks[2], stride=2)
        if text_head:
            num_classes = 512
        self.linear = nn.Linear(w*64, num_classes)

        self.apply(_weights_init)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def resnet20(w=1, text_head=False):
    return ResNet(BasicBlock, [3, 3, 3], w=w, text_head=text_head)

In [284]:
# evaluates accuracy
def evaluate(model, loader=test_loader, return_confusion=False):
    model.eval()
    correct = 0
    confusion_matrix = np.zeros((10, 10))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            pred = outputs.argmax(dim=1)
            correct += (labels.to(DEVICE) == pred).sum().item()
            confusion_matrix[labels.cpu().numpy(), pred.cpu().numpy()] += 1
    confusion_matrix /= confusion_matrix.sum(-1, keepdims=True)
    if return_confusion:
        return correct, confusion_matrix
    else:
        return correct

# evaluates loss
def evaluate1(model, loader=test_loader):
    model.eval()
    losses = []
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            outputs = model(inputs.to(DEVICE))
            loss = F.cross_entropy(outputs, labels.to(DEVICE))
            losses.append(loss.item())
    return np.array(losses).mean()

In [6]:
# modifies the weight matrices of a convolution and batchnorm
# layer given a permutation of the output channels
def permute_output(perm_map, conv, bn):
    pre_weights = [
        conv.weight,
        bn.weight,
        bn.bias,
        bn.running_mean,
        bn.running_var,
    ]
    for w in pre_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('ab,bcde->acde', perm_map, w)
        elif len(w.shape) == 2:
            transform = perm_map @ w
        else:
            transform = w @ perm_map.t()
#         assert torch.allclose(w[perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[perm_map]

# modifies the weight matrix of a convolution layer for a given
# permutation of the input channels
def permute_input(perm_map, after_convs):
    if not isinstance(after_convs, list):
        after_convs = [after_convs]
    post_weights = [c.weight for c in after_convs]
    for w in post_weights:
        if len(w.shape) == 4:
            transform = torch.einsum('abcd,be->aecd', w, perm_map.t())
        elif len(w.shape) == 2:
            transform = w @ perm_map.t()
    #     assert torch.allclose(w[:, perm_map.argmax(-1)], transform)
        w.data = transform
#         w.data = w[:, perm_map, :, :]

def permute_cls_output(perm_map, linear):
    for w in [linear.weight, linear.bias]:
        w.data = perm_map @ w

In [87]:
# use the train loader with data augmentation as this gives better results
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    # resetting stats to baseline first as below is necessary for stability
    for m in model.modules():
        if type(m) == nn.BatchNorm2d:
            m.momentum = None # use simple average
            m.reset_running_stats()
    # run a single train epoch with augmentations to recalc stats
    model.train()
    for _ in range(epochs):
        with torch.no_grad(), autocast():
            for images, _ in loader:
                output = model(images.to(DEVICE))

In [414]:
def do_nothing(x, mode="mean"):
    return x

def bipartite_soft_matching(
    metric: torch.Tensor,
    r: float,
    class_token: bool = False,
    distill_token: bool = False,
):
    """
    Applies ToMe with a balanced matching set (50%, 50%).
    Input size is [batch, tokens, channels].
    r indicates the ratio of tokens to remove (max 50% of tokens).
    Extra args:
     - class_token: Whether or not there's a class token.
     - distill_token: Whether or not there's also a distillation token.
    When enabled, the class token and distillation tokens won't get merged.
    """
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    # We can only reduce by a maximum of 50% tokens
    t = metric.shape[1]
    r = int(r * t)
    r = min(r, (t - protected) // 2)

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)
        
#         scores = torch.cov(torch.cat([a, b], dim=1)[0])[None, :r, r:t]
#         scores = -torch.cdist(a, b)

        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf

        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]

    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    def unmerge(x: torch.Tensor) -> torch.Tensor:
        unm_len = unm_idx.shape[1]
        unm, dst = x[..., :unm_len, :], x[..., unm_len:, :]
        n, _, c = unm.shape

        src = dst.gather(dim=-2, index=dst_idx.expand(n, r, c))

        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., 1::2, :] = dst
        out.scatter_(dim=-2, index=(2 * unm_idx).expand(n, unm_len, c), src=unm)
        out.scatter_(dim=-2, index=(2 * src_idx).expand(n, r, c), src=src)

        return out

    return merge, unmerge

def permutation_matching(metric, r):
    with torch.no_grad():
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = -(a @ b.transpose(-1, -2))
#     pdb.set_trace()
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(scores[0].cpu().numpy().T)
    row_ind = torch.from_numpy(col_ind)[None, :, None].to(metric.device)
    
#     print(row_ind, col_ind)
    
    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        
        src = src.gather(dim=-2, index=row_ind.expand(n, t1, c))
        
        if mode == "sum":
            return dst + src
        elif mode == "mean":
            return (dst + src) / 2
        else:
            return 1 / 0
    
    def unmerge(x):
        n, r, c = x.shape
        out = torch.zeros(n, metric.shape[1], c, device=x.device, dtype=x.dtype)

        out[..., 1::2, :] = x
        out.scatter_(dim=-2, index=(2 * row_ind).expand(n, r, c), src=x)

        return out
    
    return merge, unmerge

# bipartite_soft_matching = permutation_matching

def svd_matching(metric, r):
    U, S, V = torch.svd(metric[0])
    return U[None, :, :r], (S @ V.T)[None, :r, :]
        
        

In [7]:
modela = resnet20(w=4).to(DEVICE)
modelb = resnet20(w=4).to(DEVICE)
load_model(modela, 'resnet20x4_v1')
load_model(modelb, 'resnet20x4_v2')

evaluate(modela), evaluate(modelb)

(9536, 9510)

In [81]:
def strip_param_suffix(name):
    return name.replace('.weight', '').replace('.bias', '')

def naively_compress_model(model, ratio):
    new_sd = {}
    for key, val in model.state_dict().items():
        shape = val.shape
        if len(shape) == 4:
            val = val.flatten(2).permute(2, 0, 1)
            merge, unmerge = bipartite_soft_matching(val, r=ratio)
            val = unmerge(merge(val))
            if not key.startswith('conv1'):
                val = val.transpose(2, 1)
                merge, unmerge = bipartite_soft_matching(val, r=ratio)
                val = unmerge(merge(val)).transpose(2, 1)
            val = val.permute(1, 2, 0).reshape(*shape)
        elif len(shape) == 2:
            val = val[None]
            merge, unmerge = bipartite_soft_matching(val, r=0)
            val = unmerge(merge(val)).transpose(2, 1)
            merge, unmerge = bipartite_soft_matching(val, r=ratio)
            val = unmerge(merge(val)).transpose(2, 1)[0]
        elif len(shape) == 1:
            val = val[None, :, None]
            merge, unmerge = bipartite_soft_matching(val, r=ratio)
            val = unmerge(merge(val))[0, :, 0]
        new_sd[key] = val
    return new_sd

In [97]:
modelcomp_sd = naively_compress_model(modela, ratio=.5)

In [223]:
modelc = resnet20(w=4).to(DEVICE)
modelc.load_state_dict(modelcomp_sd)
reset_bn_stats(modelc)
evaluate(modelc)

7344

## Naive just merge into B without changing inputs:
the strat

In [226]:
state_dict = {}

def interleave_vals(tensor1, tensor2):
    # Assume tensor is of shape [B,H,D]
    return torch.cat((tensor1.unsqueeze(2), tensor2.unsqueeze(2)), dim=2).flatten(1, 2)
#     return torch.cat((tensor1, tensor2), dim=1)

def prep_conv(weight):
    return weight.flatten(1)[None, ...]

def unprep_conv(weight, k=3):
    weight = weight[0]
    o, *_ = weight.shape
    return weight.reshape(o, -1, k, k)

def merge_conv(state_dict, prefix, a_conv, b_conv, in_merge=None, out_merge=None):
    a_c1 = prep_conv(a_conv.weight)
    b_c1 = prep_conv(b_conv.weight)
    
    c_c1 = interleave_vals(a_c1, b_c1)
    c1_merge, _ = bipartite_soft_matching(c_c1, r=0.5)
    
    if out_merge is not None:
        c1_merge = out_merge
    
    c_c1 = unprep_conv(c1_merge(c_c1))
    
    state_dict[prefix + ".weight"] = c_c1
    
    return c1_merge

def merge_bn(state_dict, prefix, a_bn, b_bn, merge):
    # weight, bias, running_mean, running_var
    
    a_stats = torch.stack([a_bn.weight, a_bn.bias, a_bn.running_mean], dim=1)[None, ...]
    b_stats = torch.stack([b_bn.weight, b_bn.bias, b_bn.running_mean], dim=1)[None, ...]
    
    c_stats = interleave_vals(a_stats, b_stats)
    c_stats = merge(c_stats)[0]
    
    c_weight, c_bias, c_mean = c_stats.unbind(dim=-1)
    
    c_var = interleave_vals(a_bn.running_var[None, :, None], b_bn.running_var[None, :, None])
    
    ones = c_var * 0 + 1
    c_denom = merge(ones, mode="sum")
    c_var = merge(c_var, mode="sum")
    c_var = c_var / (c_denom ** 2)
    c_var = c_var[0, :, 0]
    
    state_dict[prefix + ".weight"] = c_weight
    state_dict[prefix + ".bias"] = c_bias
    state_dict[prefix + ".running_mean"] = c_mean
    state_dict[prefix + ".running_var"] = c_var

def merge_block(state_dict, prefix, a, b, in_merge):
    c1_merge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, in_merge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_merge, out_merge=in_merge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
def merge_block_shortcut(state_dict, prefix, a, b, in_merge):
    c1_merge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, in_merge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_merge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
    s_merge = merge_conv(state_dict, prefix + ".shortcut.0", a.shortcut[0], b.shortcut[0], in_merge, out_merge=c2_merge)
    merge_bn(state_dict, prefix + ".shortcut.1", a.shortcut[1], b.shortcut[1], s_merge)
    return s_merge

def merge_resnet20(state_dict, a, b):
    conv1_merge = merge_conv(state_dict, "conv1", a.conv1, b.conv1, None)
    merge_bn(state_dict, "bn1", a.bn1, b.bn1, conv1_merge)
    
    for i in range(3):
        merge_block(state_dict, f"layer1.{i}", a.layer1[i], b.layer1[i], conv1_merge)
    
    conv1_merge = merge_block_shortcut(state_dict, "layer2.0", a.layer2[0], b.layer2[0], conv1_merge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer2.{i}", a.layer2[i], b.layer2[i], conv1_merge)
        
    conv1_merge = merge_block_shortcut(state_dict, "layer3.0", a.layer3[0], b.layer3[0], conv1_merge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer3.{i}", a.layer3[i], b.layer3[i], conv1_merge)

#     c_linear = interleave_vals(a.linear.weight.mT[None, ...], b.linear.weight.mT[None, ...])
#     c_linear = conv1_merge(c_linear)[0].mT
#     state_dict["linear.weight"] = c_linear
    state_dict["linear.weight"] = b.linear.weight # (a.linear.weight + b.linear.weight) / 2
    state_dict["linear.bias"] = b.linear.bias # (a.linear.bias + b.linear.bias) / 2
    
    pass


def stoopid_dumb(state_dict, a, b):
    for (k, v1), (_, v2) in zip(a.named_parameters(), b.named_parameters()):
        state_dict[k] = (v1 + v2) / 2
    for (k, v1), (_, v2) in zip(a.named_buffers(), b.named_buffers()):
        state_dict[k] = (v1 + v2) / 2


state_dict = {}
modelc = resnet20(w=4).to(DEVICE)
merge_resnet20(state_dict, modela, modelb)
# stoopid_dumb(state_dict, modela, modelb)
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
evaluate(modelc)

8834

## Merge Inputs and Outputs
don't think just do

In [395]:
state_dict = {}

def interleave_vals(tensor1, tensor2, dim=1):
    # Assume tensor is of shape [B,H,D]
    return torch.cat((tensor1.unsqueeze(dim+1), tensor2.unsqueeze(dim+1)), dim=dim+1).flatten(dim, dim+1)
#     return torch.cat((tensor1, tensor2), dim=dim)


def merge_inconv(state_dict, prefix, a_conv, b_conv):
    a_c1 = prep_conv(a_conv.weight)
    b_c1 = prep_conv(b_conv.weight)
    
    c_c1 = interleave_vals(a_c1, b_c1)
    c1_merge, _ = bipartite_soft_matching(c_c1, r=0.5)
    
    c_c1 = unprep_conv(c1_merge(c_c1))
    
    state_dict[prefix + ".weight"] = c_c1
    
    return c1_merge

def merge_conv(state_dict, prefix, a_conv, b_conv, in_merge, out_merge=None, eps=1e-7):
    def move_kernel_to_output(x): # [out, in, h, w]
        return x.transpose(0, 1).flatten(1)[None, ...] # output: [1, in, w*h*out]
    
    out, _in, k, _ = a_conv.weight.shape
    
    a_c1 = move_kernel_to_output(a_conv.weight)
    b_c1 = move_kernel_to_output(b_conv.weight)
    
    ones = torch.ones_like(a_c1)
    zeros = torch.zeros_like(a_c1)
    
    big_boy = torch.cat([interleave_vals(a_c1, zeros), interleave_vals(zeros, b_c1)], dim=-1)
    counter = torch.cat([interleave_vals(ones, zeros), interleave_vals(zeros, ones)], dim=-1)
    
    merged_boy = in_merge(big_boy, mode="sum")
    counter = in_merge(counter, mode="sum")
    
    a_boy, b_boy = merged_boy.view(1, counter.shape[1], 2, -1).unbind(-2)
    a_count, b_count = counter.view(1, counter.shape[1], 2, -1).unbind(-2)
    
    def move_kernel_to_input(x):
        return x.transpose(1, 2).reshape(1, -1, k*k*x.shape[1]) # [1, out, h*w*in]
    
    c_boy = interleave_vals(move_kernel_to_input(a_boy), move_kernel_to_input(b_boy))
    c_count = interleave_vals(move_kernel_to_input(a_count), move_kernel_to_input(b_count))
    
    c_boy = c_boy / (c_count + eps)
    
    
    if out_merge is None:
        out_merge, _ = bipartite_soft_matching(c_boy, r=0.5)
    
    c_boy = out_merge(c_boy, mode="sum")
    c_count = out_merge((c_count > eps).float(), mode="sum")
#     c_count = out_merge(c_count, mode="sum")
    
    c_boy = c_boy / (c_count + eps)
    
    
    c_boy = c_boy.reshape(1, out, k, k, _in)[0].permute(0, 3, 1, 2)
    
    state_dict[prefix + ".weight"] = c_boy
    return out_merge

def merge_bn(state_dict, prefix, a_bn, b_bn, merge):
    # weight, bias, running_mean, running_var
    
    a_stats = torch.stack([a_bn.weight, a_bn.bias, a_bn.running_mean], dim=1)[None, ...]
    b_stats = torch.stack([b_bn.weight, b_bn.bias, b_bn.running_mean], dim=1)[None, ...]
    
    c_stats = interleave_vals(a_stats, b_stats)
    c_stats = merge(c_stats)[0]
    
    c_weight, c_bias, c_mean = c_stats.unbind(dim=-1)
    
    c_var = interleave_vals(a_bn.running_var[None, :, None], b_bn.running_var[None, :, None])
    
    ones = c_var * 0 + 1
    c_denom = merge(ones, mode="sum")
    c_var = merge(c_var, mode="sum")
    c_var = c_var / (c_denom ** 2)
    c_var = c_var[0, :, 0]
    
    state_dict[prefix + ".weight"] = c_weight
    state_dict[prefix + ".bias"] = c_bias
    state_dict[prefix + ".running_mean"] = c_mean
    state_dict[prefix + ".running_var"] = c_var

def merge_block(state_dict, prefix, a, b, in_merge):
    c1_merge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, in_merge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_merge, out_merge=in_merge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
def merge_block_shortcut(state_dict, prefix, a, b, in_merge):
    c1_merge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, in_merge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_merge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
    s_merge = merge_conv(state_dict, prefix + ".shortcut.0", a.shortcut[0], b.shortcut[0], in_merge, out_merge=c2_merge)
    merge_bn(state_dict, prefix + ".shortcut.1", a.shortcut[1], b.shortcut[1], s_merge)
    return s_merge

class conv_wrapper:
    def __init__(self, linear):
        self.weight = linear.weight[:, :, None, None]

def merge_resnet20(state_dict, a, b): #, merge_output=False):
    conv1_merge = merge_inconv(state_dict, "conv1", a.conv1, b.conv1)
    merge_bn(state_dict, "bn1", a.bn1, b.bn1, conv1_merge)
    
    for i in range(3):
        merge_block(state_dict, f"layer1.{i}", a.layer1[i], b.layer1[i], conv1_merge)
    
    conv1_merge = merge_block_shortcut(state_dict, "layer2.0", a.layer2[0], b.layer2[0], conv1_merge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer2.{i}", a.layer2[i], b.layer2[i], conv1_merge)
        
    conv1_merge = merge_block_shortcut(state_dict, "layer3.0", a.layer3[0], b.layer3[0], conv1_merge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer3.{i}", a.layer3[i], b.layer3[i], conv1_merge)

#     if merge_output:
    merge_conv(state_dict, "linear", conv_wrapper(a.linear), conv_wrapper(b.linear), conv1_merge)
    state_dict["linear.weight"] = state_dict["linear.weight"][:, :, 0, 0]
#     else:
#         c_linear = interleave_vals(a.linear.weight.mT[None, ...], b.linear.weight.mT[None, ...])
#         c_linear = conv1_merge(c_linear)[0].mT
#         state_dict["linear.weight"] = c_linear
    state_dict["linear.bias"] = (a.linear.bias + b.linear.bias) / 2

# in_merge, _ = bipartite_soft_matching(torch.rand(1, 128, 20, device=DEVICE), r=0.5)

state_dict = {}
modelc = resnet20(w=4).to(DEVICE)
# merge_conv(state_dict, "layer1.0.conv1", modela.layer1[0].conv1, modelb.layer1[0].conv1, in_merge)
merge_resnet20(state_dict, modelb, modela)
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
evaluate(modelc)

4028

In [415]:
def make_mats(merge, unmerge, n, t, r):
    merge_mat   = merge(torch.eye(t, device=DEVICE)[None, ...].expand(n, t, t))
    unmerge_mat = unmerge(torch.eye(t-r, device=DEVICE)[None, ...].expand(n, t-r, t-r))
    return merge_mat, unmerge_mat

state_dict = {}

def interleave_vals(tensor1, tensor2, dim=1):
    # Assume tensor is of shape [B,H,D]
    return torch.cat((tensor1.unsqueeze(dim+1), tensor2.unsqueeze(dim+1)), dim=dim+1).flatten(dim, dim+1)
#     return torch.cat((tensor1, tensor2), dim=dim)

def unterleave_vals(tensor):
    return tensor[:, ::2, :], tensor[:, 1::2, :]
#     return tensor.chunk(2, dim=1)


def merge_inconv(state_dict, prefix, a_conv, b_conv):
    a_c1 = prep_conv(a_conv.weight)
    b_c1 = prep_conv(b_conv.weight)
    
    c_c1 = interleave_vals(a_c1, b_c1)
    merge, unmerge = bipartite_soft_matching(c_c1, r=0.5)
    
    _, t, _ = c_c1.shape
    r = int(0.5*t)
    
    _, out_unmerge = make_mats(merge, unmerge, 1, t, r)
    
    c_c1 = unprep_conv(merge(c_c1))
    
    state_dict[prefix + ".weight"] = c_c1
    
    return merge, out_unmerge

def merge_conv(state_dict, prefix, a_conv, b_conv, unmerge, out_merge=None, eps=1e-7):
    def move_kernel_to_output(x): # [out, in, h, w]
        return x.transpose(0, 1).flatten(1)[None, ...] # output: [1, in, w*h*out]
    
    out, _in, k, _ = a_conv.weight.shape
    
    a_c1 = move_kernel_to_output(a_conv.weight)
    b_c1 = move_kernel_to_output(b_conv.weight)
    
    unmerge_a, unmerge_b = unterleave_vals(unmerge)
    
    a_c1 = unmerge_a @ a_c1
    b_c1 = unmerge_b @ b_c1
    
    def move_kernel_to_input(x):
        return x.transpose(1, 2).reshape(1, -1, k*k*x.shape[1]) # [1, out, h*w*in]
    
    c_c1 = interleave_vals(move_kernel_to_input(a_c1), move_kernel_to_input(b_c1))
    
    if out_merge is None:
        out_merge, out_unmerge = bipartite_soft_matching(c_c1, r=0.5)
    else:
        out_unmerge = None
    
    _, t, _ = c_c1.shape
    r = int(0.5*t)
    
    if out_unmerge is not None:
        _, out_unmerge = make_mats(out_merge, out_unmerge, 1, t, r)
    
    c_c1 = out_merge(c_c1)
    c_c1 = c_c1.reshape(1, out, k, k, _in)[0].permute(0, 3, 1, 2)
    
    state_dict[prefix + ".weight"] = c_c1
    return out_merge, out_unmerge


def merge_bn(state_dict, prefix, a_bn, b_bn, merge):
    # weight, bias, running_mean, running_var
    
    a_stats = torch.stack([a_bn.weight, a_bn.bias, a_bn.running_mean], dim=1)[None, ...]
    b_stats = torch.stack([b_bn.weight, b_bn.bias, b_bn.running_mean], dim=1)[None, ...]
    
    c_stats = interleave_vals(a_stats, b_stats)
    c_stats = merge(c_stats)[0]
    
    c_weight, c_bias, c_mean = c_stats.unbind(dim=-1)
    
    c_var = interleave_vals(a_bn.running_var[None, :, None], b_bn.running_var[None, :, None])
    
    ones = c_var * 0 + 1
    c_denom = merge(ones, mode="sum")
    c_var = merge(c_var, mode="sum")
    c_var = c_var / (c_denom ** 2)
    c_var = c_var[0, :, 0]
    
    state_dict[prefix + ".weight"] = c_weight
    state_dict[prefix + ".bias"] = c_bias
    state_dict[prefix + ".running_mean"] = c_mean
    state_dict[prefix + ".running_var"] = c_var

def merge_block(state_dict, prefix, a, b, conv1_merge, conv1_unmerge):
    c1_merge, c1_unmerge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, conv1_unmerge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge, c2_unmerge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_unmerge, out_merge=conv1_merge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
def merge_block_shortcut(state_dict, prefix, a, b, conv1_merge, conv1_unmerge):
    c1_merge, c1_unmerge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, conv1_unmerge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge, c2_unmerge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_unmerge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
    s_merge, s_unmerge = merge_conv(state_dict, prefix + ".shortcut.0", a.shortcut[0], b.shortcut[0], conv1_unmerge, out_merge=c2_merge)
    merge_bn(state_dict, prefix + ".shortcut.1", a.shortcut[1], b.shortcut[1], s_merge)
    return s_merge, c2_unmerge

class conv_wrapper:
    def __init__(self, linear):
        self.weight = linear.weight[:, :, None, None]

def merge_resnet20(state_dict, a, b): #, merge_output=False):
    conv1_merge, conv1_unmerge = merge_inconv(state_dict, "conv1", a.conv1, b.conv1)
    merge_bn(state_dict, "bn1", a.bn1, b.bn1, conv1_merge)
    
    for i in range(3):
        merge_block(state_dict, f"layer1.{i}", a.layer1[i], b.layer1[i], conv1_merge, conv1_unmerge)
    
    conv1_merge, conv1_unmerge = merge_block_shortcut(state_dict, "layer2.0", a.layer2[0], b.layer2[0], conv1_merge, conv1_unmerge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer2.{i}", a.layer2[i], b.layer2[i], conv1_merge, conv1_unmerge)
        
    conv1_merge, conv1_unmerge = merge_block_shortcut(state_dict, "layer3.0", a.layer3[0], b.layer3[0], conv1_merge, conv1_unmerge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer3.{i}", a.layer3[i], b.layer3[i], conv1_merge, conv1_unmerge)

#     if merge_output:
    merge_conv(state_dict, "linear", conv_wrapper(a.linear), conv_wrapper(b.linear), conv1_unmerge)
    state_dict["linear.weight"] = state_dict["linear.weight"][:, :, 0, 0]
#     else:
#         c_linear = interleave_vals(a.linear.weight.mT[None, ...], b.linear.weight.mT[None, ...])
#         c_linear = conv1_merge(c_linear)[0].mT
#         state_dict["linear.weight"] = c_linear
    state_dict["linear.bias"] = (a.linear.bias + b.linear.bias) / 2

# in_merge, _ = bipartite_soft_matching(torch.rand(1, 128, 20, device=DEVICE), r=0.5)

state_dict = {}
modelc = resnet20(w=4).to(DEVICE)
# merge_conv(state_dict, "layer1.0.conv1", modela.layer1[0].conv1, modelb.layer1[0].conv1, in_merge)
merge_resnet20(state_dict, modelb, modela)
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
evaluate(modelc)

7489

In [204]:
modela.linear.weight.norm().item(), modelb.linear.weight.norm().item()

(6.92144775390625, 6.855426788330078)

In [189]:
evaluate(modela), evaluate(modelb)

(9536, 9510)

# Merge Inputs and Outputs Using all Corresponding  and Output Layers

In [324]:
state_dict = {}

def interleave_vals(tensor1, tensor2, dim=1):
    # Assume tensor is of shape [B,H,D]
    b, h, d = tensor1.shape
#     a = torch.cat((tensor1.unsqueeze(dim+1), tensor2.unsqueeze(dim+1)), dim=dim+1).flatten(dim, dim+1)
#     tensor1, tensor2 = a.view(b, 2, h, d).unbind(1)
    
    return torch.cat((tensor1.unsqueeze(dim+1), tensor2.unsqueeze(dim+1)), dim=dim+1).flatten(dim, dim+1)
#     return torch.cat((tensor1, tensor2), dim=dim)


def merge_inconv(state_dict, prefix, a_conv, b_conv):
    a_c1 = prep_conv(a_conv.weight)
    b_c1 = prep_conv(b_conv.weight)
    
    c_c1 = interleave_vals(a_c1, b_c1)
    c1_merge, _ = bipartite_soft_matching(c_c1, r=0.5)
    
    c_c1 = unprep_conv(c1_merge(c_c1))
    
    state_dict[prefix + ".weight"] = c_c1
    
    return c1_merge

def merge_conv(state_dict, prefix, a_conv, b_conv, in_merge, out_merge=None, eps=1e-7):
    def move_kernel_to_output(x): # [out, in, h, w]
        return x.transpose(0, 1).flatten(1)[None, ...] # output: [1, in, w*h*out]
    
    out, _in, k, _ = a_conv.weight.shape
    
    a_c1 = move_kernel_to_output(a_conv.weight)
    b_c1 = move_kernel_to_output(b_conv.weight)
    
    ones = torch.ones_like(a_c1)
    zeros = torch.zeros_like(a_c1)
    
    big_boy = torch.cat([interleave_vals(a_c1, zeros), interleave_vals(zeros, b_c1)], dim=-1)
    counter = torch.cat([interleave_vals(ones, zeros), interleave_vals(zeros, ones)], dim=-1)
    
    merged_boy = in_merge(big_boy, mode="sum")
    counter = in_merge(counter, mode="sum")
    
    a_boy, b_boy = merged_boy.view(1, counter.shape[1], 2, -1).unbind(-2)
    a_count, b_count = counter.view(1, counter.shape[1], 2, -1).unbind(-2)
    
    def move_kernel_to_input(x):
        return x.transpose(1, 2).reshape(1, -1, k*k*x.shape[1]) # [1, out, h*w*in]
    
    c_boy = interleave_vals(move_kernel_to_input(a_boy), move_kernel_to_input(b_boy))
    c_count = interleave_vals(move_kernel_to_input(a_count), move_kernel_to_input(b_count))
    
    c_boy = c_boy / (c_count + eps)
    
    
    if out_merge is None:
        out_merge, _ = bipartite_soft_matching(c_boy, r=0.5)
    
    c_boy = out_merge(c_boy, mode="sum")
    c_count = out_merge((c_count > eps).float(), mode="sum")
    
    c_boy = c_boy / (c_count + eps)
    
    
    c_boy = c_boy.reshape(1, out, k, k, _in)[0].permute(0, 3, 1, 2)
    
    state_dict[prefix + ".weight"] = c_boy
    return out_merge

def merge_bn(state_dict, prefix, a_bn, b_bn, merge):
    # weight, bias, running_mean, running_var
    
    a_stats = torch.stack([a_bn.weight, a_bn.bias, a_bn.running_mean], dim=1)[None, ...]
    b_stats = torch.stack([b_bn.weight, b_bn.bias, b_bn.running_mean], dim=1)[None, ...]
    
    c_stats = interleave_vals(a_stats, b_stats)
    c_stats = merge(c_stats)[0]
    
    c_weight, c_bias, c_mean = c_stats.unbind(dim=-1)
    
    c_var = interleave_vals(a_bn.running_var[None, :, None], b_bn.running_var[None, :, None])
    
    ones = c_var * 0 + 1
    c_denom = merge(ones, mode="sum")
    c_var = merge(c_var, mode="sum")
    c_var = c_var / (c_denom ** 2)
    c_var = c_var[0, :, 0]
    
    state_dict[prefix + ".weight"] = c_weight
    state_dict[prefix + ".bias"] = c_bias
    state_dict[prefix + ".running_mean"] = c_mean
    state_dict[prefix + ".running_var"] = c_var

def merge_block(state_dict, prefix, a, b, in_merge):
    c1_merge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, in_merge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_merge, out_merge=in_merge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
def merge_block_shortcut(state_dict, prefix, a, b, in_merge):
    c1_merge = merge_conv(state_dict, prefix + ".conv1", a.conv1, b.conv1, in_merge)
    merge_bn(state_dict, prefix + ".bn1", a.bn1, b.bn1, c1_merge)
    c2_merge = merge_conv(state_dict, prefix + ".conv2", a.conv2, b.conv2, c1_merge)
    merge_bn(state_dict, prefix + ".bn2", a.bn2, b.bn2, c2_merge)
    
    s_merge = merge_conv(state_dict, prefix + ".shortcut.0", a.shortcut[0], b.shortcut[0], in_merge, out_merge=c2_merge)
    merge_bn(state_dict, prefix + ".shortcut.1", a.shortcut[1], b.shortcut[1], s_merge)
    return s_merge

class conv_wrapper:
    def __init__(self, linear):
        self.weight = linear.weight[:, :, None, None]

def merge_resnet20(state_dict, a, b): #, merge_output=False):
    conv1_merge = merge_inconv(state_dict, "conv1", a.conv1, b.conv1)
    merge_bn(state_dict, "bn1", a.bn1, b.bn1, conv1_merge)
    
    for i in range(3):
        merge_block(state_dict, f"layer1.{i}", a.layer1[i], b.layer1[i], conv1_merge)
    
    conv1_merge = merge_block_shortcut(state_dict, "layer2.0", a.layer2[0], b.layer2[0], conv1_merge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer2.{i}", a.layer2[i], b.layer2[i], conv1_merge)
        
    conv1_merge = merge_block_shortcut(state_dict, "layer3.0", a.layer3[0], b.layer3[0], conv1_merge)
    for i in range(1, 3):
        merge_block(state_dict, f"layer3.{i}", a.layer3[i], b.layer3[i], conv1_merge)

#     if merge_output:
    merge_conv(state_dict, "linear", conv_wrapper(a.linear), conv_wrapper(b.linear), conv1_merge)
    state_dict["linear.weight"] = state_dict["linear.weight"][:, :, 0, 0]
#     else:
#         c_linear = interleave_vals(a.linear.weight.mT[None, ...], b.linear.weight.mT[None, ...])
#         c_linear = conv1_merge(c_linear)[0].mT
#         state_dict["linear.weight"] = c_linear
    state_dict["linear.bias"] = (a.linear.bias + b.linear.bias) / 2

# in_merge, _ = bipartite_soft_matching(torch.rand(1, 128, 20, device=DEVICE), r=0.5)

state_dict = {}
modelc = resnet20(w=4).to(DEVICE)
# merge_conv(state_dict, "layer1.0.conv1", modela.layer1[0].conv1, modelb.layer1[0].conv1, in_merge)
merge_resnet20(state_dict, modelb, modela)
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
evaluate(modelc)

9222

In [338]:
_, confusion = evaluate(modelc, return_confusion=True)

In [339]:
_, confusion_b = evaluate(modelb, return_confusion=True)
_, confusion_a = evaluate(modela, return_confusion=True)
print(np.diag(confusion).round(3).tolist())
print(np.diag(confusion_b).round(3).tolist())
print(np.diag(confusion_a).round(3).tolist())

[0.408, 0.606, 0.278, 0.25, 0.4, 0.278, 0.455, 0.385, 0.392, 0.455]
[0.4, 0.556, 0.29, 0.253, 0.392, 0.328, 0.465, 0.435, 0.476, 0.408]
[0.417, 0.625, 0.286, 0.253, 0.4, 0.303, 0.476, 0.392, 0.4, 0.465]


# CIFAR10 Disjoint Models

In [205]:
model1_classes= np.array([3, 2, 0, 6, 4])
model2_classes = np.array([5, 7, 9, 8, 1])

valid_examples1 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model1_classes]
valid_examples2 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model2_classes]

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

50000it [00:13, 3764.02it/s]
50000it [00:13, 3762.17it/s]
10000it [00:01, 5769.96it/s]
10000it [00:01, 5815.04it/s]


In [206]:
class_idxs = np.zeros(10, dtype=int)
class_idxs[model1_classes] = np.arange(5)
class_idxs[model2_classes] = np.arange(5)
class_idxs = torch.from_numpy(class_idxs)
class_idxs

tensor([2, 4, 1, 0, 4, 0, 3, 1, 3, 2])

In [207]:
import clip

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)
model, preprocess = clip.load('ViT-B/32', DEVICE)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)


text_features /= text_features.norm(dim=-1, keepdim=True)
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]

if not os.path.exists(
    os.path.join(
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}.pth.tar'
    )
):
    print('training model...')
    model1 = resnet20(w=4).to(DEVICE)
    train(
        f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}', 
        model=model1, 
        class_vectors=class_vecs1,
        train_loader=train_aug_loader1,
        test_loader=test_loader1,
        remap_class_idxs=class_idxs
    )
if not os.path.exists(
    os.path.join(
        '/srv/share/gstoica3/checkpoints/REPAIR/',
        f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}.pth.tar'
    )
):
    print('training model...')
    model2 = resnet20(w=4).to(DEVICE)
    train(
        f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}', 
        model=model2, 
        class_vectors=class_vecs2,
        train_loader=train_aug_loader2,
        test_loader=test_loader2,
        remap_class_idxs=class_idxs
    )
      

In [297]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    confusion = np.zeros((10, 10))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                correct += (labels.to(DEVICE) == pred).sum().item()
            confusion[labels.cpu().numpy(), pred.cpu().numpy()] += 1
            total += inputs.shape[0]
    if return_confusion:
        return correct / total, confusion / confusion.sum(-1, keepdims=True)
    else:
        return correct / total

In [416]:
model1 = resnet20(w=4, text_head=True).to(DEVICE)
model2 = resnet20(w=4, text_head=True).to(DEVICE)
load_model(model1, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')
load_model(model2, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')

print(evaluate_texthead(model1, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model2, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.9558
0.9726


In [417]:
state_dict = {}
merge_resnet20(state_dict, model1, model2)
modelc = resnet20(w=4, text_head=True).to(DEVICE)
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc)
acc, confusion_disj = evaluate_texthead(modelc, test_loader, class_vectors=text_features, return_confusion=True)
acc

0.4407

In [418]:
print(np.diag(confusion_disj).tolist())

[0.16806722689075632, 0.3508771929824561, 0.0, 0.0, 0.0, 0.20833333333333334, 0.0, 0.22727272727272727, 0.23809523809523808, 0.3225806451612903]


In [352]:
acc, confusion_disj_1 = evaluate_texthead(model1, test_loader, class_vectors=text_features, return_confusion=True)
print(np.diag(confusion_disj_1).tolist())

[0.5263157894736842, 0.0, 0.26666666666666666, 0.2777777777777778, 0.4878048780487805, 0.0, 0.47619047619047616, 0.0, 0.1794871794871795, 0.0]


In [353]:
acc, confusion_disj_2 = evaluate_texthead(model2, test_loader, class_vectors=text_features, return_confusion=True)
print(np.diag(confusion_disj_2).tolist())

[0.08108108108108109, 0.5555555555555556, 0.0, 0.046511627906976744, 0.02702702702702703, 0.5128205128205128, 0.0, 0.5263157894736842, 0.5263157894736842, 0.40816326530612246]
