In [1]:
import os
import sys

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

# Pick the accelerator once for the whole notebook. Apple's Metal backend is
# preferred on macOS and CUDA elsewhere; if the preferred backend is not
# actually available we fall back to the CPU instead of failing on the first
# `.to(DEVICE)` call.
if platform == 'darwin' and torch.backends.mps.is_available():
    DEVICE = 'mps'
elif torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import pdb

In [3]:
def save_model(model, i, root='/srv/share/gstoica3/checkpoints/REPAIR/'):
    """Write ``model.state_dict()`` to ``<root>/<i>.pth.tar``.

    Args:
        model: module whose ``state_dict()`` is serialized.
        i: checkpoint identifier; formatted with ``'%s.pth.tar' % i``.
        root: directory that receives the checkpoint. Defaults to the shared
            checkpoint directory this notebook was written against.
    """
    path = os.path.join(root, '%s.pth.tar' % i)
    # Build the state dict once and save exactly that object.
    sd = model.state_dict()
    torch.save(sd, path)

def load_model(model, i, root='/srv/share/gstoica3/checkpoints/REPAIR/'):
    """Load the checkpoint ``<root>/<i>.pth.tar`` into ``model`` in place.

    Args:
        model: module that receives the weights via ``load_state_dict``.
        i: checkpoint identifier; formatted with ``'%s.pth.tar' % i``.
        root: directory holding the checkpoint. Defaults to the shared
            checkpoint directory this notebook was written against.
    """
    path = os.path.join(root, '%s.pth.tar' % i)
    # Remap storages onto the notebook-wide device so a checkpoint written on
    # another accelerator can still be read.
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)

In [4]:
# Per-dataset settings. 'classes1' / 'classes2' are the two disjoint label
# subsets, 'split_classes' is the size of each subset, and 'wrapper' is the
# torchvision dataset class instantiated with root='dir'.
cifar100_info = dict(
    dir='/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python',
    classes1=np.arange(50),
    classes2=np.arange(50, 100),
    num_classes=100,
    split_classes=50,
    wrapper=torchvision.datasets.CIFAR100,
)

cifar10_info = dict(
    dir='/tmp',
    classes1=np.array([3, 2, 0, 6, 4]),
    classes2=np.array([5, 7, 9, 8, 1]),
    num_classes=10,
    split_classes=5,
    wrapper=torchvision.datasets.CIFAR10,
)

# Dataset used by every cell below.
ds_info = cifar10_info

In [5]:
import random

# Seed Python, PyTorch and NumPy with the same value so sampling and
# augmentation are repeatable across kernel restarts.
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

In [6]:
import numpy as np
import torch

from torch.utils.data import DataLoader
from torch.utils.data.sampler import BatchSampler


class BalancedBatchSampler(BatchSampler):
    """
    Batch sampler that yields class-balanced index lists.

    Every batch picks ``n_classes`` distinct labels at random and takes the
    next ``n_samples`` shuffled indices of each, so a batch holds
    ``n_classes * n_samples`` dataset indices. The dataset must yield
    ``(sample, label)`` pairs.
    """

    def __init__(self, dataset, n_classes, n_samples):
        # Walk the dataset once (batch size 1) purely to collect its labels.
        self.labels_list = []
        for _, label in DataLoader(dataset):
            self.labels_list.append(label)
        self.labels = torch.LongTensor(self.labels_list)
        self.labels_set = list(set(self.labels.numpy()))

        # label -> array of dataset positions carrying that label
        self.label_to_indices = {}
        for label in self.labels_set:
            self.label_to_indices[label] = np.where(self.labels.numpy() == label)[0]
        for label in self.labels_set:
            np.random.shuffle(self.label_to_indices[label])

        # label -> how many of its shuffled indices were already handed out
        self.used_label_indices_count = dict.fromkeys(self.labels_set, 0)
        self.count = 0
        self.n_classes = n_classes
        self.n_samples = n_samples
        self.dataset = dataset
        self.batch_size = self.n_samples * self.n_classes

    def __iter__(self):
        self.count = 0
        while self.count + self.batch_size <= len(self.dataset):
            chosen = np.random.choice(self.labels_set, self.n_classes, replace=False)
            batch = []
            for cls in chosen:
                start = self.used_label_indices_count[cls]
                stop = start + self.n_samples
                batch.extend(self.label_to_indices[cls][start:stop])
                self.used_label_indices_count[cls] = stop
                # Not enough unseen indices left for another draw of this
                # class: reshuffle and start over from the front.
                if stop + self.n_samples > len(self.label_to_indices[cls]):
                    np.random.shuffle(self.label_to_indices[cls])
                    self.used_label_indices_count[cls] = 0
            yield batch
            self.count += self.n_classes * self.n_samples

    def __len__(self):
        return len(self.dataset) // self.batch_size

In [8]:
# Channel statistics expressed on the 0-255 pixel scale; they are divided by
# 255 below because ToTensor() produces values in [0, 1].
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
# Inverse of `normalize`: maps a normalized tensor back to the [0, 1] range.
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

# Training images get random flips and padded random crops; test images are
# only converted and normalized.
train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
# Dataset class and root directory come from the `ds_info` selected above;
# download=True fetches the archive into that root when it is missing.
train_dset = ds_info['wrapper'](root=ds_info['dir'], train=True,
                                        download=True, transform=train_transform)
test_dset = ds_info['wrapper'](root=ds_info['dir'], train=False,
                                        download=True, transform=test_transform)

# Loaders over the full (all-class) train and test sets.
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Using downloaded and verified file: /tmp/cifar-10-python.tar.gz
Extracting /tmp/cifar-10-python.tar.gz to /tmp
Files already downloaded and verified


In [9]:
# Label subsets handled by model 1 and model 2 respectively.
model1_classes = ds_info['classes1']
model2_classes = ds_info['classes2']


def _indices_for_classes(dataset, classes):
    """Return the positions in ``dataset`` whose label is one of ``classes``.

    Labels are read from ``dataset.targets`` when the dataset exposes it, which
    avoids decoding and augmenting every image just to look at its label. If
    the attribute is missing we fall back to iterating the dataset.
    """
    labels = getattr(dataset, 'targets', None)
    if labels is None:
        labels = [label for _, label in dataset]
    return np.flatnonzero(np.isin(np.asarray(labels), classes)).tolist()


def _subset_loader(dataset, indices):
    """Build an unshuffled, batch-size-500 loader over ``dataset[indices]``."""
    return torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, indices), batch_size=500, shuffle=False, num_workers=8
    )


valid_examples1 = _indices_for_classes(train_dset, model1_classes)
valid_examples2 = _indices_for_classes(train_dset, model2_classes)

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = _subset_loader(train_dset, valid_examples1)
train_aug_loader2 = _subset_loader(train_dset, valid_examples2)

test_valid_examples1 = _indices_for_classes(test_dset, model1_classes)
test_valid_examples2 = _indices_for_classes(test_dset, model2_classes)

test_loader1 = _subset_loader(test_dset, test_valid_examples1)
test_loader2 = _subset_loader(test_dset, test_valid_examples2)

50000it [00:16, 3058.19it/s]
50000it [00:15, 3189.60it/s]
10000it [00:02, 4535.10it/s]
10000it [00:01, 5000.86it/s]


In [29]:
# Class-balanced loader over the full training set: each batch draws 50
# indices from each of `num_classes` randomly chosen classes, i.e.
# num_classes * 50 samples per batch.
balanced_batched_sampler = BalancedBatchSampler(train_dset, n_classes=ds_info['num_classes'], n_samples=50)
balanced_train_aug_loader = torch.utils.data.DataLoader(
    train_dset, num_workers=8, batch_sampler=balanced_batched_sampler
)

In [11]:
# Lookup table from a dataset label to that label's position inside the class
# subset of the model that owns it (both subsets are numbered from 0).
_local_ids = np.arange(ds_info['split_classes'])
class_idxs = np.zeros(ds_info['num_classes'], dtype=int)
for _owned_classes in (model1_classes, model2_classes):
    class_idxs[_owned_classes] = _local_ids
class_idxs = torch.from_numpy(class_idxs)
print(class_idxs)

tensor([2, 4, 1, 0, 4, 0, 3, 1, 3, 2])


In [12]:
# The augmented train loader is the default because it gives better results.
def reset_bn_stats(model, epochs=1, loader=train_aug_loader):
    """Recompute the BatchNorm2d running statistics of ``model`` from ``loader``.

    Every ``nn.BatchNorm2d`` module has its running stats cleared and its
    momentum set to ``None`` (plain cumulative average). The model is then put
    in train mode and fed ``epochs`` passes over ``loader`` with gradients
    disabled, so only the BN buffers change. The model is left in train mode.

    Note: the default ``loader`` is bound when this function is defined.
    """
    # Clearing the stats first is needed for stability.
    bn_layers = [m for m in model.modules() if type(m) == nn.BatchNorm2d]
    for bn in bn_layers:
        bn.momentum = None  # use simple average
        bn.reset_running_stats()

    # Forward passes in train mode refresh the running mean / variance.
    model.train()
    with torch.no_grad(), autocast():
        for _ in range(epochs):
            for images, _ in loader:
                model(images.to(DEVICE))

In [13]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    """Top-1 accuracy of ``model`` when its embeddings are scored against ``class_vectors``.

    Each batch is embedded, L2-normalized along the last dim, multiplied by
    ``class_vectors.T`` and arg-maxed to obtain the predicted row index.

    Args:
        model: module mapping an image batch to embeddings.
        loader: iterable of ``(inputs, labels)`` batches.
        class_vectors: one row per candidate class.
        remap_class_idxs: optional lookup tensor; when given, ``labels`` are
            translated through it before being compared with the predictions.
        return_confusion: when True, also return per-class accuracies (despite
            the name this is not a confusion matrix).

    Returns:
        ``accuracy`` or, with ``return_confusion``, ``(accuracy, per_class)``
        where ``per_class[k]`` is the accuracy over samples whose (remapped)
        target is ``k``; classes that never occur are reported as ``nan``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    num_classes = class_vectors.shape[0]
    correct = 0
    total = 0
    totals = torch.zeros(num_classes, dtype=torch.long)
    corrects = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1).cpu()

            # Targets live in the same index space as the rows of class_vectors.
            targets = labels if remap_class_idxs is None else remap_class_idxs[labels]
            targets = targets.cpu()

            hits = pred == targets
            correct += hits.sum().item()
            total += inputs.shape[0]

            if return_confusion:
                # Per-class counts are tracked in both the remapped and the
                # plain case so the per-class report is always available.
                totals += torch.bincount(targets, minlength=num_classes)
                corrects += torch.bincount(targets[hits], minlength=num_classes)

    accuracy = correct / total
    if return_confusion:
        per_class = [
            hit / seen if seen > 0 else float('nan')
            for hit, seen in zip(corrects.tolist(), totals.tolist())
        ]
        return accuracy, per_class
    return accuracy

In [14]:
import clip

# Encode one prompt per dataset class name with the CLIP text encoder.
# NOTE(review): `model` here is the CLIP model; the name is generic, so take
# care not to confuse it with the ResNets created further down.
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)
model, preprocess = clip.load('ViT-B/32', DEVICE)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)


# Unit-normalize each class vector, then slice out the rows that belong to
# each model's class subset.
text_features /= text_features.norm(dim=-1, keepdim=True)
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]

In [15]:
from resnets import resnet20


In [16]:
def match_tensors_exact_bipartite_cov(
    covariance,
    r=.5
):
    """Greedily pair the most similar features and build merge / unmerge maps.

    With ``O = covariance.shape[0]`` features, ``bound = O - int(O * (1 - r))``
    pairs are formed by repeatedly taking the largest remaining entry of the
    similarity matrix and retiring both of its features. Each pair occupies
    one output column; features that were never paired each get a column of
    their own.

    Args:
        covariance: square ``(O, O)`` similarity matrix. It is not modified.
        r: fraction of features removed by merging. Values above 0.5 request
            more pairs than there are output columns and fail with an
            IndexError.

    Returns:
        ``(merge, unmerge)`` where ``unmerge`` is the ``(O, O - bound)`` 0/1
        assignment matrix and ``merge`` is its column-normalized transpose.
    """
    O = covariance.shape[0]
    remainder = int(O * (1-r))
    bound = O - remainder
    # Work on a copy: the matching overwrites entries with -inf and the
    # caller's matrix must survive the call.
    sims = covariance.clone()
    used = torch.zeros(O, dtype=torch.bool, device=sims.device)

    permutation_matrix = torch.zeros((O, O - bound), device=sims.device)

    for i in range(bound):
        best_idx = sims.reshape(-1).argmax()
        row_idx = best_idx // sims.shape[1]
        col_idx = best_idx % sims.shape[1]
        for idx in (row_idx, col_idx):
            permutation_matrix[idx, i] = 1
            used[idx] = True
            # Retire the feature so it cannot be matched again.
            sims[idx] = -torch.inf
            sims[:, idx] = -torch.inf

    # Track leftovers explicitly instead of inferring them from the -inf
    # pattern, which misses a feature whose only finite entries were retired.
    unused = (~used).nonzero().view(-1)
    for i in range(bound, O-bound):
        permutation_matrix[unused[i-bound], i] = 1
    merge = permutation_matrix / (permutation_matrix.sum(dim=0, keepdim=True) + 1e-5)
    unmerge = permutation_matrix
    return merge.T, unmerge

def match_tensors_permute(
    covariance
):
    """Match the second model's features to the first model's by optimal assignment.

    ``covariance`` is the similarity matrix over the concatenation of both
    models' features, so its first ``O`` rows/columns belong to model 0 and the
    last ``O`` to model 1, with ``O = covariance.shape[0] // 2``. The
    assignment is solved on the cross block ``[:O, O:]`` (model 0 versus
    model 1), maximizing total similarity.

    Returns:
        ``(merge, unmerge)`` where ``unmerge`` has shape ``(2 * O, O)``:
        the identity for model 0 stacked on the permutation that moves each
        model-1 feature into the slot of its matched model-0 feature.
        ``merge`` is the column-normalized transpose, shape ``(O, 2 * O)``.
    """
    corr_mtx_a = covariance.cpu().numpy()
    # Half of the concatenated feature count: using the full size produced a
    # (2O, O)-for-full-O unmerge that transform_*_space could not chunk.
    O = corr_mtx_a.shape[0] // 2
    row_ind, col_ind = scipy.optimize.linear_sum_assignment(
        corr_mtx_a[:O, O:2 * O], maximize=True
    )
    # Row i of `permutation` is one-hot at the model-1 feature matched to
    # model-0 feature i.
    permutation = torch.eye(O, device=covariance.device)[torch.tensor(col_ind).long()]
    unmerge = torch.cat(
        (
            torch.eye(O, device=covariance.device),
            permutation.T
        ),
        dim=0
    )
    merge = unmerge / (unmerge.sum(dim=0, keepdim=True) + 1e-5)
    return merge.T, unmerge

In [17]:
def remove_col(x, idx):
    """Return ``x`` without column ``idx`` (``x`` is indexed as ``x[:, j]``)."""
    left = x[:, :idx]
    right = x[:, idx + 1:]
    return torch.cat((left, right), dim=-1)


def match_tensors_chain_cov(
    covariance,
    r=.5
):
    """Greedy agglomerative feature merging that allows groups of any size.

    Starting from one group per feature, the two groups with the largest
    similarity are merged ``O - int(O * (1 - r))`` times. After a merge the
    surviving group's similarity to every other group becomes half of the
    smaller of the two merged groups' similarities, and the absorbed group is
    dropped from the matrix.

    Args:
        covariance: square ``(O, O)`` similarity matrix; a clone is used, so
            the argument itself is left untouched.
        r: fraction of features removed by merging.

    Returns:
        ``(merge, unmerge)``: ``unmerge`` is the ``(O, groups)`` 0/1 membership
        matrix and ``merge`` its column-normalized transpose.
    """
    n_features = covariance.shape[0]
    n_groups_kept = int(n_features * (1 - r))
    n_merges = n_features - n_groups_kept
    scores = covariance.clone()
    damping = 0.5

    # Column j lists the members of group j; it ends up (O, O - n_merges).
    membership = torch.eye(n_features, n_features, device=scores.device)

    for _ in range(n_merges):
        flat_best = scores.reshape(-1).argmax()
        width = scores.shape[1]
        keep_idx = flat_best % width
        drop_idx = flat_best // width

        # Fold the absorbed group's members into the surviving group.
        membership[:, keep_idx] += membership[:, drop_idx]
        membership = remove_col(membership, drop_idx)

        # Update the surviving column, then delete the absorbed column ...
        merged_col = torch.minimum(damping * scores[:, keep_idx], damping * scores[:, drop_idx])
        scores[:, keep_idx] = merged_col
        scores = remove_col(scores, drop_idx)

        # ... and do the same along the rows.
        merged_row = torch.minimum(damping * scores[keep_idx, :], damping * scores[drop_idx, :])
        scores[keep_idx, :] = merged_row
        scores = remove_col(scores.T, drop_idx).T

    group_sizes = membership.sum(dim=0, keepdim=True)
    merge = membership / (group_sizes + 1e-5)
    unmerge = membership
    return merge.T, unmerge


In [18]:
def match_tensors_consecutive_chain_cov(
    covariance,
    r=.5
):
    """Greedy matching that may keep extending the most recently formed group.

    For each of ``bound = O - int(O * (1 - r))`` steps the best remaining pair
    is looked up. If the current group's best similarity to a free feature
    (``temp_sims``) beats that pair, the feature is taken instead; otherwise a
    new pair is started. Features left over at the end get columns of their
    own.

    Args:
        covariance: square ``(O, O)`` similarity matrix. It is not modified.
        r: fraction of features removed by merging.

    Returns:
        ``(merge, unmerge)``: ``unmerge`` is the ``(O, O - bound)`` 0/1
        assignment matrix and ``merge`` its column-normalized transpose.
    """
    O = covariance.shape[0]
    remainder = int(O * (1-r))
    bound = O - remainder
    # Clone so the -inf bookkeeping below never leaks into the caller's matrix.
    sims = covariance.clone()

    permutation_matrix = torch.zeros((O, O - bound), device=sims.device)
    # Similarity of the current group to each feature (min over its members);
    # -inf everywhere means "no open group".
    temp_sims = torch.full((O,), -torch.inf, device=covariance.device)
    ignore_temp = True
    for i in range(bound):
        flat_sims = sims.reshape(-1)
        best_idx = flat_sims.argmax()
        best_sim = flat_sims.max()
        if not ignore_temp and temp_sims.max() > best_sim:
            temp_idx = temp_sims.argmax()
            # NOTE(review): the chained feature is written to the *new* column
            # i rather than to the column of the group it extends — confirm
            # this is intended.
            permutation_matrix[temp_idx, i] = 1
            temp_sims = torch.minimum(temp_sims, sims[temp_idx])
            sims[temp_idx] = -torch.inf
            sims[:, temp_idx] = -torch.inf
        elif not ignore_temp:
            # The open group lost to a fresh pair: close it.
            temp_sims = torch.full((O,), -torch.inf, device=covariance.device)
            ignore_temp = True
        if ignore_temp:
            row_idx = best_idx % sims.shape[1]
            col_idx = best_idx // sims.shape[1]

            # Must be read before the rows are overwritten with -inf below.
            temp_sims = torch.minimum(sims[row_idx], sims[col_idx])
            ignore_temp = False

            permutation_matrix[row_idx, i] = 1
            permutation_matrix[col_idx, i] = 1

            sims[row_idx] = -torch.inf
            sims[col_idx] = -torch.inf
            sims[:, row_idx] = -torch.inf
            sims[:, col_idx] = -torch.inf

    unused = (sims.max(-1)[0] > -torch.inf).to(torch.int).nonzero().view(-1)
    for i in range(bound, O-bound):
        permutation_matrix[unused[i-bound], i] = 1
    merge = permutation_matrix / (permutation_matrix.sum(dim=0, keepdim=True) + 1e-5)
    unmerge = permutation_matrix
    return merge.T, unmerge


In [19]:
def mix_weights(sd0, sd1, target, alpha):
    """Load into ``target`` the interpolation ``(1 - alpha) * sd0 + alpha * sd1``.

    Every key of ``sd0`` is looked up in ``sd1``; both tensors are moved to
    ``DEVICE`` before being blended.
    """
    blended = {
        name: (1 - alpha) * sd0[name].to(DEVICE) + alpha * sd1[name].to(DEVICE)
        for name in sd0.keys()
    }
    target.load_state_dict(blended)

## Single Pass

In [20]:
from collections import defaultdict
import torch.nn.functional as F

In [21]:
class SinglePassResNet20(object):
    """Collects cross-model feature correlations for two ResNet-20s and rewires them.

    ``get_intermediate_scores`` runs both models over a loader and fills
    ``layer2cov`` with one correlation matrix per tracked activation (features
    of model 0 followed by features of model 1). ``prepare_models`` turns each
    matrix into merge / unmerge maps and applies them to the weights of both
    models in place.
    """

    def __init__(self, model0, model1):
        super(SinglePassResNet20, self).__init__()
        self.model0 = model0
        self.model1 = model1
        # label -> running statistics while collecting, correlation matrix after.
        self.layer2cov = defaultdict(lambda: None)
        # Activation after the first conv + BN + ReLU of a residual block.
        self.output_pattern = lambda model, x: F.relu(model.bn1(model.conv1(x)))

    def get_layer_intermediates(self, batch0, batch1, layer0, layer1, label, n):
        """Push both batches through one stage, recording statistics on the way.

        For block ``i`` the intra-block activation is recorded under
        ``f'{label}.{i}'``; the stage output is recorded under ``label``.
        Returns the stage outputs of both models.
        """
        for i, (inter0, inter1) in enumerate(zip(list(layer0), list(layer1))):
            out0 = self.output_pattern(inter0, batch0)
            out1 = self.output_pattern(inter1, batch1)

            self.batch_update_covs(
                label+f'.{i}',
                out0=out0, out1=out1,
                n=n
            )

            batch0 = inter0(batch0)
            batch1 = inter1(batch1)
        self.batch_update_covs(
            label, out0=batch0, out1=batch1, n=n
        )
        return batch0, batch1

    # TODO: capturing intermediate values for a ViT is still to be written;
    # forward hooks are one way to do it.
    def get_intermediate_scores(self, dataloader, num_passes=1, num_steps=None):
        """Accumulate per-layer statistics and convert them to correlation matrices.

        Args:
            dataloader: iterable of ``(images, labels)`` batches with a length.
            num_passes: number of passes over ``dataloader``.
            num_steps: optional cap on the batches consumed per pass.

        Statistics are averaged over batches with equal weight, so batches are
        assumed to have the same size. Any earlier results in ``layer2cov``
        are discarded.
        """
        # The divisor must be the number of batches actually consumed,
        # otherwise stopping early at `num_steps` mis-scales every average.
        steps_per_pass = len(dataloader)
        if num_steps is not None:
            steps_per_pass = max(0, min(steps_per_pass, num_steps))
        n = steps_per_pass * num_passes

        # Start clean so the method can be called more than once.
        self.layer2cov = defaultdict(lambda: None)

        self.model0.eval()
        self.model1.eval()

        # Only statistics are needed: keep autograd off for the forward passes.
        with torch.no_grad():
            for _ in range(num_passes):
                for i, (images, _labels) in enumerate(tqdm(dataloader)):
                    if num_steps is not None and i >= num_steps: break
                    inputs = images.float().to(DEVICE)
                    out0 = F.relu(self.model0.bn1(self.model0.conv1(inputs)))
                    out1 = F.relu(self.model1.bn1(self.model1.conv1(inputs)))

                    out0, out1 = self.get_layer_intermediates(
                        out0, out1, self.model0.layer1, self.model1.layer1, 'layer1', n
                    )
                    out0, out1 = self.get_layer_intermediates(
                        out0, out1, self.model0.layer2, self.model1.layer2, 'layer2', n
                    )
                    out0, out1 = self.get_layer_intermediates(
                        out0, out1, self.model0.layer3, self.model1.layer3, 'layer3', n
                    )

        for layer, stats in tqdm(list(self.layer2cov.items())):
            mean = stats['mean']
            std = stats['std']
            outer = stats['outer']

            # E[x x^T] - E[x] E[x]^T, scaled by the (averaged) std products.
            cov = outer - torch.outer(mean, mean)
            cov /= (torch.outer(std, std) + 1e-4)
            # A feature must never be matched with itself.
            torch.diagonal(cov)[:] = -torch.inf
            self.layer2cov[layer] = cov

    @staticmethod
    def _flatten_to_features(acts):
        """Reshape ``(B, C, ...)`` activations to a double ``(B * positions, C)`` matrix."""
        acts = acts.reshape(acts.shape[0], acts.shape[1], -1).permute(0, 2, 1)
        return acts.reshape(-1, acts.shape[2]).double()

    def batch_update_covs(self, label, out0, out1, n):
        """Add one batch's mean, std and second moment (each divided by ``n``) to ``label``.

        The features of ``out0`` and ``out1`` are concatenated along the
        channel axis, so every accumulated quantity covers both models.
        """
        out0 = self._flatten_to_features(out0)
        out1 = self._flatten_to_features(out1)
        out = torch.cat((out0, out1), dim=-1)
        mean_b = out.mean(dim=0).detach().cpu()
        std0_b = out0.std(dim=0).detach().cpu()
        std1_b = out1.std(dim=0).detach().cpu()
        std_b = torch.cat((std0_b, std1_b), dim=-1)
        outer_b = ((out.T @ out) / out.shape[0]).detach().cpu()

        if self.layer2cov[label] is None:
            self.layer2cov[label] = {
                'mean': torch.zeros_like(mean_b),
                'std': torch.zeros_like(std_b),
                'outer': torch.zeros_like(outer_b),
            }

        stats = self.layer2cov[label]
        stats['mean'] += mean_b / n
        stats['std'] += std_b / n
        stats['outer'] += outer_b / n

    def prepare_models(self, cov2update_modules, transformer, precomputed_merges=None):
        """Apply a merge / unmerge pair to every tracked layer of both models.

        Args:
            cov2update_modules: for each key of ``layer2cov``, the dotted names
                of the modules producing (``'output'``) and consuming
                (``'input'``) that feature space.
            transformer: callable mapping a correlation matrix to
                ``(merge, unmerge)``; used when ``precomputed_merges`` is None.
            precomputed_merges: optional ``{layer: (merge, unmerge)}`` to reuse.

        Returns:
            ``{layer: (merge, unmerge)}`` for the transforms that were applied.
        """
        layer2merges = {}
        for layer, cov in self.layer2cov.items():
            if precomputed_merges is None:
                merge, unmerge = transformer(cov)
            else:
                merge, unmerge = precomputed_merges[layer]
            layer2merges[layer] = (merge, unmerge)
            output_modules = self.get_model_attributes(cov2update_modules[layer]['output'])
            input_modules = self.get_model_attributes(cov2update_modules[layer]['input'])
            self.transform_output_space(merge, output_modules)
            self.transform_input_space(unmerge, input_modules)

        return layer2merges

    def _resolve_pair(self, dotted_name):
        """Return the sub-module ``dotted_name`` of ``model0`` and of ``model1``."""
        node0, node1 = self.model0, self.model1
        for edge in dotted_name.split('.'):
            node0 = getattr(node0, edge)
            node1 = getattr(node1, edge)
        return node0, node1

    def get_model_attributes(self, components):
        """Collect ``[tensor_of_model0, tensor_of_model1]`` pairs to transform.

        ``components['non_bn']`` modules contribute their ``weight``;
        ``components['bn']`` modules (optional) contribute ``weight``,
        ``bias``, ``running_mean`` and ``running_var``, in that order.
        """
        parameter_tuples = []
        for component in components['non_bn']:
            node0, node1 = self._resolve_pair(component)
            parameter_tuples.append([node0.weight, node1.weight])

        if 'bn' in components:
            for component in components['bn']:
                nodes = self._resolve_pair(component)
                for attr in ('weight', 'bias', 'running_mean', 'running_var'):
                    parameter_tuples.append([getattr(node, attr) for node in nodes])
        return parameter_tuples

    def transform_output_space(self, merge, list_of_modules):
        """Left-multiply the leading (output) axis of every tensor pair by ``merge``.

        ``merge`` is split column-wise into the halves belonging to model 0
        and model 1. The factor 2 presumably cancels the 0.5 weighting applied
        when the two models are averaged afterwards — TODO confirm.
        """
        if not list_of_modules:
            return
        a_w, b_w = (merge * 2).to(list_of_modules[0][0].device).chunk(2, dim=1)
        for modules in list_of_modules:
            a, b = modules[0], modules[1]
            if len(a.shape) == 4:
                t_a = torch.einsum('ab,bcde->acde', a_w, a)
                t_b = torch.einsum('ab,bcde->acde', b_w, b)
            else:
                t_a = a_w @ a
                t_b = b_w @ b
            a.data = t_a
            b.data = t_b

    def transform_input_space(self, unmerge, list_of_modules):
        """Right-multiply the input axis of every tensor pair by ``unmerge``.

        ``unmerge`` is split row-wise into the halves belonging to model 0 and
        model 1. A shape mismatch raises the underlying error instead of
        dropping into a debugger.
        """
        if not list_of_modules:
            return
        a_w, b_w = unmerge.to(list_of_modules[0][0].device).chunk(2, dim=0)
        for modules in list_of_modules:
            a, b = modules[0], modules[1]
            if len(a.shape) == 4:
                a.data = torch.einsum('abcd,be->aecd', a, a_w)
                b.data = torch.einsum('abcd,be->aecd', b, b_w)
            else:
                a.data = (a @ a_w)
                b.data = (b @ b_w)

In [22]:
# Dotted names of the nine residual blocks (three stages of three blocks)
# whose intra-block activations are tracked.
intermediate_layers = [
    f'layer{stage}.{block}'
    for stage in (1, 2, 3)
    for block in (0, 1, 2)
]

In [23]:
# Wiring table used when a feature space is merged. For every tracked
# activation key:
#   'output' lists the modules that WRITE that feature space (their leading
#            weight axis is transformed; 'bn' entries also have their bias and
#            running statistics transformed),
#   'input'  lists the modules that READ it (their input axis is transformed).
# The three stage-level entries cover the residual stream of each stage.
cov2update_modules = {
    'layer1': {
        'output': {
            'non_bn': ['conv1', 'layer1.0.conv2', 'layer1.1.conv2', 'layer1.2.conv2'],
            'bn': ['bn1', 'layer1.0.bn2', 'layer1.1.bn2', 'layer1.2.bn2']
        },
        'input': {
            'non_bn': ['layer1.0.conv1', 'layer1.1.conv1', 'layer1.2.conv1', 
                     'layer2.0.conv1', 'layer2.0.shortcut.0'],
        }
    },
    'layer2': {
        'output': {
            'non_bn': ['layer2.0.conv2', 'layer2.1.conv2', 'layer2.2.conv2', 'layer2.0.shortcut.0'],
            'bn': ['layer2.0.bn2', 'layer2.1.bn2', 'layer2.2.bn2', 'layer2.0.shortcut.1']
        },
        'input': {
            'non_bn': ['layer2.1.conv1', 'layer2.2.conv1', 'layer3.0.conv1', 'layer3.0.shortcut.0']
        }
    },
    'layer3': {
        'output': {
            'non_bn': ['layer3.0.conv2', 'layer3.1.conv2', 'layer3.2.conv2', 'layer3.0.shortcut.0'],
            'bn': ['layer3.0.bn2', 'layer3.1.bn2', 'layer3.2.bn2', 'layer3.0.shortcut.1']
        },
        'input': {
            'non_bn': ['layer3.1.conv1', 'layer3.2.conv1', 'linear']
        }
    },
}

# Inside each residual block the hidden feature space is written by conv1/bn1
# and read only by conv2.
for layer_name in intermediate_layers:
    cov2update_modules[layer_name] = {
        'output': {
            'non_bn': [f'{layer_name}.conv1'],
            'bn': [f'{layer_name}.bn1']
        },
        'input': {
            'non_bn': [f'{layer_name}.conv2']
        }
    }

In [30]:
# Build two ResNet-20s (constructor comes from the local `resnets` module) and
# load the checkpoint trained on each class subset.
model0 = resnet20(w=4, text_head=True).to(DEVICE)
model1 = resnet20(w=4, text_head=True).to(DEVICE)
load_model(model0, f'resnet20x4_CIFAR5_clses{model1_classes.tolist()}')
load_model(model1, f'resnet20x4_CIFAR5_clses{model2_classes.tolist()}')
# Sanity check: accuracy of each model on the test images of its own classes,
# with labels remapped to positions inside the subset.
print(evaluate_texthead(model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.9558
0.9726


In [31]:
# Wrap both models; every later transform edits model0 / model1 in place.
spr20 = SinglePassResNet20(model0=model0, model1=model1)

In [32]:
# Collect per-layer correlations from at most 100 class-balanced batches.
spr20.get_intermediate_scores(balanced_train_aug_loader, num_passes=1, num_steps=100)

100%|█████████████████████████████████████████| 100/100 [01:05<00:00,  1.53it/s]
100%|██████████████████████████████████████████| 12/12 [00:00<00:00, 712.41it/s]


In [33]:
# Compute a merge / unmerge pair per tracked layer and apply it to both models
# in place. Run this once per freshly loaded pair of models: the weights of
# spr20.model0 / spr20.model1 are overwritten.
layer2merges_ = spr20.prepare_models(
    cov2update_modules, transformer=match_tensors_chain_cov # match_tensors_permute this runs the git rebasin.
)

In [34]:
# Average the two transformed models into a fresh network, recompute its
# BatchNorm statistics on the augmented training set, and evaluate it on the
# full test set against all class vectors.
modelc = resnet20(w=4, text_head=True).to(DEVICE)
mix_weights(sd0=spr20.model0.state_dict(), sd1=spr20.model1.state_dict(), target=modelc, alpha=.5)
modelc = modelc.to(DEVICE)
reset_bn_stats(modelc, loader=train_aug_loader)
acc, perclass_acc = evaluate_texthead(
    modelc, test_loader, class_vectors=text_features, return_confusion=True
)
acc

0.7421

In [35]:
# Re-estimate the BatchNorm statistics of the two transformed models
# (modifies them in place), then re-check each on its own class subset.
reset_bn_stats(spr20.model0, loader=train_aug_loader)
reset_bn_stats(spr20.model1, loader=train_aug_loader)
print(evaluate_texthead(spr20.model0, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(spr20.model1, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.87
0.9298


In [40]:
# evaluates accuracy
# NOTE: an earlier cell of this notebook also defines `evaluate_texthead`;
# once this cell runs, this definition is the one in effect.
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    """Top-1 accuracy of ``model`` when its embeddings are scored against ``class_vectors``.

    Each batch is embedded, L2-normalized along the last dim, multiplied by
    ``class_vectors.T`` and arg-maxed to obtain the predicted row index.

    Args:
        model: module mapping an image batch to embeddings.
        loader: iterable of ``(inputs, labels)`` batches.
        class_vectors: one row per candidate class.
        remap_class_idxs: optional lookup tensor; when given, ``labels`` are
            translated through it before being compared with the predictions.
        return_confusion: when True, also return per-class accuracies (despite
            the name this is not a confusion matrix).

    Returns:
        ``accuracy`` or, with ``return_confusion``, ``(accuracy, per_class)``
        where ``per_class[k]`` is the accuracy over samples whose (remapped)
        target is ``k``; classes that never occur are reported as ``nan``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    model.eval()
    num_classes = class_vectors.shape[0]
    correct = 0
    total = 0
    totals = torch.zeros(num_classes, dtype=torch.long)
    corrects = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1).cpu()

            # Targets live in the same index space as the rows of class_vectors.
            targets = labels if remap_class_idxs is None else remap_class_idxs[labels]
            targets = targets.cpu()

            hits = pred == targets
            correct += hits.sum().item()
            total += inputs.shape[0]

            if return_confusion:
                # Per-class counts are tracked in both the remapped and the
                # plain case so the per-class report is always available.
                totals += torch.bincount(targets, minlength=num_classes)
                corrects += torch.bincount(targets[hits], minlength=num_classes)

    accuracy = correct / total
    if return_confusion:
        per_class = [
            hit / seen if seen > 0 else float('nan')
            for hit, seen in zip(corrects.tolist(), totals.tolist())
        ]
        return accuracy, per_class
    return accuracy

def evaluate_ensemble(
    modela,
    modelb,
    loader, 
    class_vectors, 
    remap_class_idxs=None,
    return_confusion=False
):
    """Top-1 accuracy of a two-model ensemble that trusts the more confident model.

    Both models embed each batch; each embedding is L2-normalized and scored
    against ``class_vectors``. Per sample, the prediction of ``modelb`` is
    used when its best score is strictly higher than ``modela``'s, otherwise
    ``modela``'s prediction is kept.

    Args:
        modela, modelb: modules mapping an image batch to embeddings.
        loader: iterable of ``(inputs, labels)`` batches.
        class_vectors: one row per candidate class, shared by both models.
        remap_class_idxs: optional lookup tensor; when given, ``labels`` are
            translated through it before being compared with the predictions.
        return_confusion: when True, also return per-class accuracies (despite
            the name this is not a confusion matrix).

    Returns:
        ``accuracy`` or, with ``return_confusion``, ``(accuracy, per_class)``;
        classes that never occur are reported as ``nan``.

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    modela.eval()
    modelb.eval()
    num_classes = class_vectors.shape[0]
    correct = 0
    total = 0
    totals = torch.zeros(num_classes, dtype=torch.long)
    corrects = torch.zeros(num_classes, dtype=torch.long)

    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            # Move the batch to the device once and share it between models.
            device_inputs = inputs.to(DEVICE)
            encoding_a = modela(device_inputs)
            encoding_b = modelb(device_inputs)

            normed_encodings_a = encoding_a / encoding_a.norm(dim=-1, keepdim=True)
            outputs_a = normed_encodings_a @ class_vectors.T

            normed_encodings_b = encoding_b / encoding_b.norm(dim=-1, keepdim=True)
            outputs_b = normed_encodings_b @ class_vectors.T

            preda_score, preda_idx = outputs_a.max(dim=1)
            predb_score, predb_idx = outputs_b.max(dim=1)
            pred = torch.where(preda_score < predb_score, predb_idx, preda_idx).cpu()

            # Targets live in the same index space as the rows of class_vectors.
            targets = labels if remap_class_idxs is None else remap_class_idxs[labels]
            targets = targets.cpu()

            hits = pred == targets
            correct += hits.sum().item()
            total += inputs.shape[0]

            if return_confusion:
                totals += torch.bincount(targets, minlength=num_classes)
                corrects += torch.bincount(targets[hits], minlength=num_classes)

    accuracy = correct / total
    if return_confusion:
        per_class = [
            hit / seen if seen > 0 else float('nan')
            for hit, seen in zip(corrects.tolist(), totals.tolist())
        ]
        return accuracy, per_class
    return accuracy


In [42]:
# Ensemble of the two transformed models on the full test set, scored against
# all class vectors. Note: this overwrites the `acc` computed for the merged
# model in an earlier cell.
acc = evaluate_ensemble(
    spr20.model0, spr20.model1, test_loader, class_vectors=text_features, return_confusion=False
)

In [43]:
# Display the ensemble accuracy computed in the previous cell.
acc

0.3839