In [1]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'
if DEVICE == 'mps':
    DOWNLOAD_PATH = '/Users/georgestoica/Downloads' 
else:
    DOWNLOAD_PATH = '/srv/share/gstoica3/checkpoints/REPAIR/'
    
torch.autograd.set_grad_enabled(False)

from copy import deepcopy
from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from resnets import resnet20
from matching_algs import *
from model_matchings import *

In [5]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)
    return model


In [24]:
cifar100_info = {
    'dir': '/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python',
    'classes1': np.arange(50),
    'classes2': np.arange(50, 100),
    'num_classes': 100,
    'split_classes': 50,
    'wrapper': torchvision.datasets.CIFAR100
}

cifar10_info = {
    'dir': '/tmp',
    'classes1': np.array([3, 2, 0, 6, 4]),
    'classes2': np.array([5, 7, 9, 8, 1]),
    'num_classes': 10,
    'split_classes': 5,
    'wrapper': torchvision.datasets.CIFAR10
}

ds_info = cifar100_info

In [25]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = ds_info['wrapper'](root=ds_info['dir'], train=True,
                                        download=True, transform=train_transform)
test_dset = ds_info['wrapper'](root=ds_info['dir'], train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [26]:
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

model1_classes= ds_info['classes1']#np.array([3, 2, 0, 6, 4])
model2_classes = ds_info['classes2']

valid_examples1 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model1_classes]
valid_examples2 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model2_classes]

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

50000it [00:12, 3928.50it/s]
50000it [00:12, 3935.26it/s]
10000it [00:01, 6061.08it/s]
10000it [00:01, 6092.10it/s]


In [27]:
class_idxs = np.zeros(ds_info['num_classes'], dtype=int)
class_idxs[model1_classes] = np.arange(ds_info['split_classes'])
class_idxs[model2_classes] = np.arange(ds_info['split_classes'])
class_idxs = torch.from_numpy(class_idxs)
print(class_idxs)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47, 48, 49])


In [28]:
# evaluates accuracy
def evaluate_texthead(model, loader, class_vectors, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    
    totals = [0] * class_vectors.shape[0]
    corrects = [0] * class_vectors.shape[0]
    
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            encodings = model(inputs.to(DEVICE))
            normed_encodings = encodings / encodings.norm(dim=-1, keepdim=True)
            outputs = normed_encodings @ class_vectors.T
            pred = outputs.argmax(dim=1)
            if remap_class_idxs is not None:
                correct += (remap_class_idxs[labels].to(DEVICE) == pred).sum().item()
            else:
                for gt, p in zip(labels, pred):
                    totals[gt] += 1
                    
                    if gt == p:
                        correct += 1
                        corrects[gt] += 1
                
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total

In [29]:
import clip

text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in test_dset.classes]).to(DEVICE)
model, preprocess = clip.load('ViT-B/32', DEVICE)
with torch.no_grad():
    text_features = model.encode_text(text_inputs)


text_features /= text_features.norm(dim=-1, keepdim=True)
class_vecs1 = text_features[model1_classes]
class_vecs2 = text_features[model2_classes]

In [31]:
model1 = resnet20(w=4, text_head=True).to(DEVICE)
model2 = resnet20(w=4, text_head=True).to(DEVICE)
load_model(model1, f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}')
load_model(model2, f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}')

print(evaluate_texthead(model1, test_loader1, class_vecs1, remap_class_idxs=class_idxs))
print(evaluate_texthead(model2, test_loader2, class_vecs2, remap_class_idxs=class_idxs))

0.778
0.7758


In [32]:
def find_transform_differences(old_transforms, current_transforms):
    if len(old_transforms) == 0:
        return {}
    transform2norm = {}
    for key, old_transform in old_transforms.items():
        current_transform = current_transforms[key]
        old_align = old_transform.output_align
        new_align = current_transform.output_align
        cost = old_align.T @ new_align
        row_ind, col_idx = scipy.optimize.linear_sum_assignment(cost.detach().cpu().numpy())
        permutation = torch.eye(new_align.shape[1], device=old_align.device)[col_idx]
        aligned_new = new_align @ permutation
#         pdb.set_trace()
        norm = torch.norm(old_align - aligned_new).cpu().numpy()
        transform2norm[key] = norm
    return transform2norm

In [40]:
r = 0.5
fn = match_tensors_exact_bipartite
set_r(r)
set_match_fn(fn)

match_tensors = match_wrapper(
    fn, 
    backend_alg=match_tensors_exact_bipartite,
    interleave=True, 
    random_perm=False
)
layer_transform = lambda : LayerTransform(normalize_tensors=True, tensor_merge_type='concat')
old_state_dict = {}
state_dict = {}
old_transforms = defaultdict(lambda: layer_transform())
new_transforms = defaultdict(lambda: layer_transform())
modelc = resnet20(w=4, text_head=True).to(DEVICE)
accuracies = []
steps = []
distances = []
best_info = {'acc': 0., 'dist': np.inf}
step = 1
is_converged = False
prev_distance = np.inf
same_window = 5
same_span = 0
while not is_converged:
# for step in tqdm(range(121)):
    old_transforms = new_transforms
    old_state_dict = deepcopy(state_dict)
    new_transforms = merge_resnet20(
        state_dict, 
        model1, 
        model2, 
        transforms=deepcopy(old_transforms),
        concat_head=False
    )
    if step == 0:
        original_computation = deepcopy(new_transforms)

    transform2dist = find_transform_differences(old_transforms, new_transforms)
    avg_distance = np.mean(list(transform2dist.values()))
    
    if abs(avg_distance - prev_distance) <= 1e-5:
        same_span += 1
    else:
        same_span = 0
    if same_span >= same_window:
        is_converged = True
        
    prev_distance = avg_distance
    if is_converged or step >= 1000:
        break
    step += 1

In [41]:
step

7

In [42]:
modelc.load_state_dict(state_dict)
reset_bn_stats(modelc, loader=train_aug_loader)
acc, perclass_acc = evaluate_texthead(
    modelc, test_loader, class_vectors=text_features, return_confusion=True
)
acc

0.3818

In [43]:
print(perclass_acc)

[0.13, 0.04, 0.18, 0.32, 0.33, 0.41, 0.31, 0.29, 0.61, 0.48, 0.27, 0.4, 0.33, 0.11, 0.26, 0.32, 0.27, 0.69, 0.52, 0.36, 0.79, 0.08, 0.27, 0.46, 0.65, 0.34, 0.19, 0.13, 0.37, 0.28, 0.3, 0.14, 0.15, 0.4, 0.39, 0.39, 0.56, 0.5, 0.24, 0.44, 0.35, 0.31, 0.58, 0.48, 0.26, 0.19, 0.16, 0.25, 0.65, 0.43, 0.24, 0.49, 0.5, 0.74, 0.54, 0.1, 0.79, 0.45, 0.79, 0.39, 0.37, 0.52, 0.43, 0.34, 0.29, 0.1, 0.41, 0.21, 0.68, 0.62, 0.38, 0.5, 0.4, 0.51, 0.17, 0.36, 0.64, 0.17, 0.59, 0.28, 0.08, 0.72, 0.29, 0.53, 0.14, 0.61, 0.4, 0.56, 0.28, 0.51, 0.46, 0.22, 0.33, 0.19, 0.58, 0.56, 0.14, 0.34, 0.42, 0.46]
