In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
import pdb
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import scipy.optimize

import torch
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from torch.nn import CrossEntropyLoss
from torch.optim import SGD, Adam, lr_scheduler
import torchvision
import torchvision.transforms as T

from sys import platform

DEVICE = 'mps' if platform == 'darwin' else 'cuda'
if DEVICE == 'mps':
    DOWNLOAD_PATH = '/Users/georgestoica/Downloads' 
else:
    DOWNLOAD_PATH = '/srv/share/gstoica3/checkpoints/REPAIR/'
    
torch.autograd.set_grad_enabled(False)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.grad_mode.set_grad_enabled at 0x7f3274739850>

In [3]:
from copy import deepcopy
from collections import defaultdict

In [4]:
def save_model(model, i):
    sd = model.state_dict()
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    torch.save(model.state_dict(), path)

def load_model(model, i):
    path = os.path.join(
        # '/Users/georgestoica/Downloads',
        DOWNLOAD_PATH,
        '%s.pth.tar' % i
    )
    sd = torch.load(path, map_location=torch.device(DEVICE))
    model.load_state_dict(sd)
    return model


In [26]:
cifar100_info = {
    'dir': '/nethome/gstoica3/research/pytorch-cifar100/data/cifar-100-python',
    'classes1': np.arange(50),
    'classes2': np.arange(50, 100),
    'num_classes': 100,
    'split_classes': 50
}

cifar10_info = {
    'dir': '/tmp',
    'classes1': np.array([3, 2, 0, 6, 4]),
    'classes2': np.array([5, 7, 9, 8, 1]),
    'num_classes': 10,
    'split_classes': 5
}

ds_info = cifar100_info

In [37]:
CIFAR_MEAN = [125.307, 122.961, 113.8575]
CIFAR_STD = [51.5865, 50.847, 51.255]
normalize = T.Normalize(np.array(CIFAR_MEAN)/255, np.array(CIFAR_STD)/255)
denormalize = T.Normalize(-np.array(CIFAR_MEAN)/np.array(CIFAR_STD), 255/np.array(CIFAR_STD))

train_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ToTensor(),
    normalize,
])
test_transform = T.Compose([
    T.ToTensor(),
    normalize,
])
train_dset = torchvision.datasets.CIFAR100(root=ds_info['dir'], train=True,
                                        download=True, transform=train_transform)
test_dset = torchvision.datasets.CIFAR100(root=ds_info['dir'], train=False,
                                        download=True, transform=test_transform)

train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [38]:
train_aug_loader = torch.utils.data.DataLoader(train_dset, batch_size=500, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dset, batch_size=500, shuffle=False, num_workers=8)

model1_classes= ds_info['classes1']#np.array([3, 2, 0, 6, 4])
model2_classes = ds_infoZ['classes2']

valid_examples1 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model1_classes]
valid_examples2 = [i for i, (_, label) in tqdm(enumerate(train_dset)) if label in model2_classes]

assert len(set(valid_examples1).intersection(set(valid_examples2))) == 0, 'sets should be disjoint'

train_aug_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples1), batch_size=500, shuffle=True, num_workers=8
)
train_aug_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(train_dset, valid_examples2), batch_size=500, shuffle=True, num_workers=8
)

test_valid_examples1 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model1_classes]
test_valid_examples2 = [i for i, (_, label) in tqdm(enumerate(test_dset)) if label in model2_classes]

test_loader1 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples1), batch_size=500, shuffle=False, num_workers=8
)
test_loader2 = torch.utils.data.DataLoader(
    torch.utils.data.Subset(test_dset, test_valid_examples2), batch_size=500, shuffle=False, num_workers=8
)

50000it [00:12, 3869.64it/s]
50000it [00:12, 3866.69it/s]
10000it [00:01, 5890.21it/s]
10000it [00:01, 5932.15it/s]


In [39]:
class_idxs = np.zeros(100, dtype=int)
class_idxs[model1_classes] = np.arange(ds_info['split_classes'])
class_idxs[model2_classes] = np.arange(ds_info['split_classes'])
class_idxs = torch.from_numpy(class_idxs)
print(class_idxs)

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49,  0,  1,  2,  3,
         4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21,
        22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39,
        40, 41, 42, 43, 44, 45, 46, 47, 48, 49])


In [31]:
from kmeans_pytorch import kmeans, kmeans_predict

In [None]:
from resnets import resnet20
from matching_algs import *
from model_matchings import *

In [40]:
modela = load_model(
    resnet20(w=4, num_classes=5).to(DEVICE),
    f'resnet20x4_CIFAR50_clses{model1_classes.tolist()}_onehot'
)

modelb = load_model(
    resnet20(w=4, num_classes=5).to(DEVICE),
    f'resnet20x4_CIFAR50_clses{model2_classes.tolist()}_onehot'
)

FileNotFoundError: [Errno 2] No such file or directory: '/srv/share/gstoica3/checkpoints/REPAIR/resnet20x4_CIFAR50_clses[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49]_onehot.pth.tar'

In [None]:
# evaluates accuracy
def evaluate(model, loader, num_classes, remap_class_idxs=None, return_confusion=False):
    model.eval()
    correct = 0
    total = 0
    totals = [0] * num_classes
    corrects = [0] * num_classes
    confusion = np.zeros((10, 10))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            logits = model(inputs.to(DEVICE))
            pred = logits.argmax(dim=1)
            if remap_class_idxs is not None:
                remaped_labels = remap_class_idxs[labels]
            else:
                remaped_labels = labels
                
            correct += (remaped_labels.to(DEVICE) == pred).sum().item()
            
            for gt, p in zip(remaped_labels, pred):
                totals[gt] += 1
                if gt == p:
                    corrects[gt] += 1
            
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total


In [None]:
evaluate(
    modela, test_loader1, num_classes=ds_info['num_classes'],  
    remap_class_idxs=class_idxs, return_confusion=False
)

In [None]:
evaluate(
    modelb, test_loader2, num_classes=ds_info['num_classes'], 
    remap_class_idxs=class_idxs, return_confusion=False
)

In [None]:
concat_class_idxs = deepcopy(class_idxs)
concat_class_idxs[model2_classes] += 5
concat_class_idxs

In [16]:
# def general_soft_matching(
#     hull_tensor,
#     interleave=False,
#     random=False,
#     r=.5
# ):  
#     hull_tensor = hull_tensor[0]
#     hull_normed = hull_tensor / hull_tensor.norm(dim=-1, keepdim=True)
    
#     bound = int(hull_tensor.shape[0] * (1-r))
    
#     sims = hull_normed @ hull_normed.transpose(-1, -2)
#     uppertri_indices = torch.triu_indices(sims.shape[-2], sims.shape[-1], offset=0)
#     sims[uppertri_indices[0], uppertri_indices[1]] = -torch.inf
#     candidate_scores, candidate_indices = sims.max(-1)
#     argsorted_scores = candidate_scores.argsort(descending=True)
#     merge_indices = argsorted_scores[:bound]
#     unmerge_indices = argsorted_scores[bound:]
    
#     roots = torch.arange(sims.shape[0], device=sims.device)
#     for _ in range(bound-1):
#         roots[merge_indices] = roots[candidate_indices[merge_indices]]
    
#     def merge(x, mode='mean'):
#         x = x[0]
#         merge_tensor = x.scatter_reduce(
#             0, 
#             roots[merge_indices][:, None].expand(bound, x.shape[1]),
#             x[merge_indices], 
#             reduce='mean'
#         )
#         unmerge_tensor = merge_tensor[unmerge_indices]
#         return unmerge_tensor[None]
    
#     def unmerge(x):
#         x = x[0]
#         out = torch.zeros((hull_tensor.shape[0], x.shape[1]), device=x.device)
#         out.scatter_(
#             0,
#             index=unmerge_indices[:, None].expand(*x.shape),
#             src=x
#         )
#         out = out.scatter(
#             0,
#             index=merge_indices[:, None].expand(*x.shape),
#             src=out[roots[merge_indices]]
#         )
#         return out[None]
    
#     return merge, unmerge

    
# def match_tensors_tome(
#     hull_tensor, eps=1e-7, interleave=False, random_perm=False, 
#     backend_alg=general_soft_matching
# ):
#     """
#     hull_tensor: [2O,I]
#     """
#     O, I = hull_tensor.shape
#     O //= 2
    
#     big_eye = torch.eye(2*O, device=hull_tensor.device)
#     small_eye = torch.eye(O, device=hull_tensor.device)
    
#     interleave_mat = big_eye
#     if interleave:
#         A1, A2, B1, B2 = interleave_mat.chunk(4, dim=0)
#         interleave_mat = torch.cat([A1, B1, A2, B2], dim=0)
    
    
#     hull_tensor = interleave_mat @ hull_tensor
    
#     merge, unmerge = backend_alg(hull_tensor[None], 0.5)
    
#     merge_mat = merge(big_eye[None])[0] @ interleave_mat
#     unmerge_mat = interleave_mat.T @ unmerge(small_eye[None])[0]
#     return merge_mat, unmerge_mat

# def kmeans_matching(
#     hull_tensor,
#     interleave=False,
#     random_perm=False,
#     r=.5
# ):
#     hull_normed = hull_tensor / hull_tensor.norm(dim=-1, keepdim=True)
#     O = hull_tensor.shape[0]
#     k = int(O * (1-r))
#     cluster_ids, cluster_centers = kmeans(
#         X=hull_normed, num_clusters=k, 
#         distance='cosine', 
#         device=hull_tensor.device,
#         tqdm_flag=False,
#         seed=123
#     )

#     eye = torch.eye(k, device=hull_tensor.device)
#     transform = eye[cluster_ids]

#     unmerge = transform
#     merge = (transform / transform.sum(dim=0, keepdim=True)).T
#     return merge, unmerge
    

In [None]:
def find_transform_differences(old_transforms, current_transforms):
    if len(old_transforms) == 0:
        return {}
    transform2norm = {}
    for key, old_transform in old_transforms.items():
        current_transform = current_transforms[key]
        old_align = old_transform.output_align
        new_align = current_transform.output_align
        cost = old_align.T @ new_align
        row_ind, col_idx = scipy.optimize.linear_sum_assignment(cost.detach().cpu().numpy())
        permutation = torch.eye(new_align.shape[1], device=old_align.device)[col_idx]
        aligned_new = new_align @ permutation
#         pdb.set_trace()
        norm = torch.norm(old_align - aligned_new).cpu().numpy()
        transform2norm[key] = norm
    return transform2norm

In [None]:
r = 0.5
fn = match_tensors_exact_bipartite
set_r(r)
set_match_fn(fn)

match_tensors = match_wrapper(fn, interleave=False, random_perm=False)
layer_transform = lambda : LayerTransform(normalize_tensors=False, tensor_merge_type='concat')
old_state_dict = {}
state_dict = {}
old_transforms = defaultdict(lambda: layer_transform())
new_transforms = defaultdict(lambda: layer_transform())
modelc = resnet20(w=4, text_head=True).to(DEVICE)
accuracies = []
steps = []
distances = []
best_info = {'acc': 0., 'dist': np.inf}
step = 1
is_converged = False
prev_distance = np.inf
same_window = 5
same_span = 0
while not is_converged:
# for step in tqdm(range(1000)):
    old_transforms = new_transforms
    old_state_dict = deepcopy(state_dict)
    new_transforms = merge_resnet20(
        state_dict, 
        modela, 
        modelb, 
        transforms=deepcopy(old_transforms),
        concat_head=True
    )
    if step == 0:
        original_computation = deepcopy(new_transforms)

    transform2dist = find_transform_differences(old_transforms, new_transforms)
    avg_distance = np.mean(list(transform2dist.values()))
    
    if abs(avg_distance - prev_distance) <= 1e-5:
        same_span += 1
    else:
        same_span = 0
    if same_span >= same_window:
        is_converged = True
        
    prev_distance = avg_distance

    if is_converged or step >= 1000:
        break
    step += 1

In [None]:
step

In [None]:
model_comp = resnet20(w=8 * (1-r), num_classes=ds_info['num_classes'], text_head=False).eval().to(DEVICE)
model_comp.load_state_dict(state_dict)
reset_bn_stats(model_comp, loader=train_aug_loader)
acc, confusion = evaluate(
    model_comp, test_loader, num_classes=ds_info['num_classes'], 
    remap_class_idxs=concat_class_idxs, return_confusion=True
)
print(step, acc * 100)

In [None]:
confusion

### Model Ensembling

In [20]:
# evaluates accuracy
def evaluate_ensemble(
    modela,
    modelb,
    loader, 
    num_classes, 
    remap_class_idxs=None, 
    return_confusion=False
):
    modela.eval()
    modelb.eval()
    correct = 0
    total = 0
    totals = [0] * num_classes
    corrects = [0] * num_classes
    confusion = np.zeros((10, 10))
    with torch.no_grad(), autocast():
        for inputs, labels in loader:
            logitsa = modela(inputs.to(DEVICE))
            logitsb = modelb(inputs.to(DEVICE))
            
            preda_score, preda_idx = logitsa.max(dim=1)
            predb_score, predb_idx = logitsb.max(dim=1)
            pred = preda_idx
            pred[preda_score < predb_score] = predb_idx[preda_score < predb_score] + (num_classes // 2)
            
            if remap_class_idxs is not None:
                remaped_labels = remap_class_idxs[labels]
            else:
                remaped_labels = labels
                
            correct += (remaped_labels.to(DEVICE) == pred).sum().item()
            
            for gt, p in zip(remaped_labels, pred):
                totals[gt] += 1
                if gt == p:
                    corrects[gt] += 1
            
            total += inputs.shape[0]
    if return_confusion:
        return correct / sum(totals), list(map(lambda a: a[0] / a[1], zip(corrects, totals)))
    else:
        return correct / total


In [21]:
ensemble_acc, ensemble_confusion = evaluate_ensemble(
    modela, modelb, test_loader, num_classes=ds_info['num_classes'],
    remap_class_idxs=concat_class_idxs, return_confusion=True
)

In [22]:
ensemble_acc

0.7853

In [23]:
ensemble_confusion

[0.721, 0.768, 0.871, 0.835, 0.83, 0.714, 0.826, 0.797, 0.727, 0.764]