In [1]:
import os
# Restrict this process to GPU index 3. This is set in the first cell, before
# torch is imported in the next one.
# NOTE(review): the GPU index is machine-specific — adjust for your host.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

In [2]:
from tqdm import tqdm
import numpy as np
import scipy.optimize

import torch
from torch import nn
import torchvision.transforms as T
import torchvision.models
from torch.cuda.amp import autocast

from utils import *

In [3]:
# Load the experiment configuration from a YAML file into the `config` dict.
# NOTE(review): `yaml` is not imported explicitly in this notebook; presumably
# it is brought in by `from utils import *` — confirm.
# NOTE(review): absolute, user-specific path; adjust when running elsewhere.
with open("/home/rohan/code/sparse-rebasin-minimal/sparse-rebasin/configs_imagenet/config_1_80_imagenet.yaml", "r") as f:
    config = yaml.safe_load(f)

# Append per-width / per-sparsity subdirectories to the configured save path.
# The sparsity fraction is multiplied by 100 and interpolated as-is.
config['save_path']= f"{config['save_path']}width_{config['width_multiplier']}/sparsity_{config['pruning']['sparsity']*100}/"

In [4]:
def get_blocks(net):
    """Flatten a ResNet-style network into a single sequential stack of blocks.

    Element 0 is the stem (conv1, bn1, relu, maxpool) wrapped in its own
    nn.Sequential; the remaining elements are the members of layer1, layer2,
    layer3 and layer4, in that order. No other attribute of `net` is included.
    """
    stem = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool)
    residual_blocks = [*net.layer1, *net.layer2, *net.layer3, *net.layer4]
    return nn.Sequential(stem, *residual_blocks)

In [5]:
# Build the train / test dataloaders used by all later cells.
# NOTE(review): the helper is named `cifar_dataloader` while the config loaded
# above is an ImageNet config — confirm that it returns the intended dataset.
train_dl, test_dl = cifar_dataloader(config["batch_size"], config)


Load ImageNet dataset


In [6]:
def run_corr_matrix(net0, net1, loader, device="cuda"):
    """
    Compute the channel-wise correlation matrix between the outputs of two networks.

    Given two networks net0, net1 which each output a feature map of shape NxCxWxH,
    this reshapes both outputs to (N*W*H)xC and then computes the correlation
    between every channel of net0 and every channel of net1, pooled over all
    batches yielded by `loader`.

    Args:
        net0, net1: modules producing 4-D feature maps. Both are put in eval mode.
        loader: iterable of (images, labels) batches; the labels are ignored.
            It does not need to support len().
        device: device the images are moved to before the forward passes. The
            networks are expected to already be on that device.

    Returns:
        A double-precision tensor of shape (C0, C1) whose entry [i, j] is the
        correlation between channel i of net0 and channel j of net1.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    net0.eval()
    net1.eval()

    # Running *sums* (not per-batch means) so that every activation row has the
    # same weight even when the last batch is smaller than the others.
    sums = None
    total_rows = 0
    with torch.no_grad():
        for images, _ in tqdm(loader):
            img_t = images.float().to(device)
            out0 = net0(img_t).double()
            out0 = out0.permute(0, 2, 3, 1).reshape(-1, out0.shape[1])
            out1 = net1(img_t).double()
            out1 = out1.permute(0, 2, 3, 1).reshape(-1, out1.shape[1])

            # first moments, second moments and cross outer product of this batch
            batch_sums = (
                out0.sum(dim=0),
                out1.sum(dim=0),
                out0.square().sum(dim=0),
                out1.square().sum(dim=0),
                out0.T @ out1,
            )
            if sums is None:
                sums = list(batch_sums)
            else:
                for acc, batch_sum in zip(sums, batch_sums):
                    acc += batch_sum
            total_rows += out0.shape[0]

    if sums is None or total_rows == 0:
        raise ValueError("run_corr_matrix: loader yielded no samples")

    mean0, mean1, sqmean0, sqmean1, outer = (s / total_rows for s in sums)

    cov = outer - torch.outer(mean0, mean1)
    # E[x^2] - E[x]^2 can come out slightly negative through round-off for
    # (near-)constant channels; clamp so sqrt() does not produce NaN.
    std0 = (sqmean0 - mean0**2).clamp_min(0).sqrt()
    std1 = (sqmean1 - mean1**2).clamp_min(0).sqrt()
    # the small constant keeps the division finite for zero-variance channels
    corr = cov / (torch.outer(std0, std1) + 1e-4)
    return corr

In [7]:
def get_layer_perm1(corr_mtx):
    """Turn a correlation matrix into the channel assignment that maximizes total correlation.

    NaN entries are replaced via np.nan_to_num before solving the linear sum
    assignment problem. Returns a long tensor whose i-th entry is the column
    matched to row i.
    """
    scores = np.nan_to_num(corr_mtx.cpu().numpy())
    rows, cols = scipy.optimize.linear_sum_assignment(scores, maximize=True)
    # rows must come back as 0..n-1 so that `cols` can be read as a plain index map
    assert (rows == np.arange(len(scores))).all()
    return torch.tensor(cols).long()

def get_layer_perm(net0, net1, loader):
    """Return the channel permutation that makes net1's activations most closely match net0's.

    So this is permuting model1 --> model0 (i.e. π(net1)).
    """
    return get_layer_perm1(run_corr_matrix(net0, net1, loader))

In [8]:
def permute_output(perm_map, conv, bn=None):
    """Permute, in place, the output channels of `conv` (and of the BN that follows it).

    Args:
        perm_map: long tensor of indices into the output-channel dimension.
        conv: layer whose `weight` has the output channel as dim 0. Its `bias`,
            if it has one, is permuted as well so the layer's function is
            preserved up to the channel reordering.
        bn: optional normalization layer applied to conv's output; its affine
            parameters and running statistics are permuted when present.
    """
    tensors = [conv.weight, getattr(conv, "bias", None)]
    if bn is not None:
        tensors.extend([bn.weight, bn.bias, bn.running_mean, bn.running_var])
    for t in tensors:
        # bias-free convs and BNs built with affine=False /
        # track_running_stats=False expose None for the missing tensors
        if t is not None:
            t.data = t.data[perm_map]

def permute_input(perm_map, layer):
    """Reorder, in place, the input channels (dim 1 of `weight`) of `layer` by `perm_map`."""
    layer.weight.data = layer.weight[:, perm_map]

In [9]:
# Build two models from the same config and load the checkpoints of the two
# independently trained runs ("Model_A" / "Model_B") into them.
# NOTE(review): absolute, user-specific checkpoint paths.
resnet50_1 = get_model(config)
path1 = "/scratch/rohan/test/imagenet_testing/width_1/sparsity_80.0/Model_A_Dense_sparsity_0.8_seed_11_epoch_5"
resnet50_2 = get_model(config)
path2 = "/scratch/rohan/test/imagenet_testing/width_1/sparsity_80.0/Model_B_Dense_sparsity_0.8_seed_11_epoch_5"

resnet50_1.load_state_dict(torch.load(path1))
resnet50_2.load_state_dict(torch.load(path2))

# Block views of both models (the stem followed by every residual block).
blocks0 = get_blocks(resnet50_1)
blocks1 = get_blocks(resnet50_2)
# Baseline test metrics of both endpoints; displayed as the cell output.
evaluate(resnet50_1,test_dl,config["device"]), evaluate(resnet50_2,test_dl,config["device"])

((56.718, 1.8221810612143303), (58.089999999999996, 1.7481372061432625))

In [10]:
def get_permk(k):
    """Map a block index to the last block index of the group it belongs to.

    Restricts the permutations such that the network is not functionally
    changed: the same permutation must be applied to every conv output in a
    residual stream, so all blocks of a group share one permutation.

    The groups are hard-coded as 0 | 1-3 | 4-7 | 8-13 | 14-16; any other value
    of k raises a bare Exception.
    """
    if k == 0:
        return 0
    # (exclusive lower bound, inclusive upper bound) of each group
    for lower, upper in ((0, 3), (3, 7), (7, 13), (13, 16)):
        if lower < k <= upper:
            return upper
    raise Exception()

In [11]:
def permute_model_resnet50(model0, model1, loader, config):
    """Permute the channels of `model1` in place so that they align with `model0`.

    Args:
        model0: reference network; it is only run forward, never modified.
        model1: network whose weights are permuted in place.
        loader: iterable of (images, labels) batches used to estimate the
            activation correlations for every alignment step.
        config: accepted for interface compatibility; not used here.

    Returns:
        `model1` itself (the same object, now permuted).
    """
    last_kk = None
    blocks0 = get_blocks(model0)
    blocks1 = get_blocks(model1)

    # Pass 1: channels that live purely inside a block (outputs of conv1 and
    # conv2). These never touch the residual stream, so each gets its own
    # permutation. Index 0 is the stem and is handled in pass 2.
    for k in range(1, len(blocks1)):
        block0 = blocks0[k]
        block1 = blocks1[k]
        subnet0 = nn.Sequential(blocks0[:k],
                                block0.conv1, block0.bn1, block0.relu)
        subnet1 = nn.Sequential(blocks1[:k],
                                block1.conv1, block1.bn1, block1.relu)
        # use the loader that was passed in (not a notebook-level global)
        perm_map = get_layer_perm(subnet0, subnet1, loader)
        permute_output(perm_map, block1.conv1, block1.bn1)
        permute_input(perm_map, block1.conv2)

        subnet0 = nn.Sequential(blocks0[:k],
                                block0.conv1, block0.bn1, block0.relu,
                                block0.conv2, block0.bn2, block0.relu)
        subnet1 = nn.Sequential(blocks1[:k],
                                block1.conv1, block1.bn1, block1.relu,
                                block1.conv2, block1.bn2, block1.relu)
        perm_map = get_layer_perm(subnet0, subnet1, loader)
        permute_output(perm_map, block1.conv2, block1.bn2)
        permute_input(perm_map, block1.conv3)

    # Pass 2: channels of the residual stream. Every block in the same group
    # (see get_permk) must share one permutation, which is computed once per
    # group from the output of the group's last block.
    for k in range(len(blocks1)):
        kk = get_permk(k)
        if kk != last_kk:
            perm_map = get_layer_perm(blocks0[:kk+1], blocks1[:kk+1], loader)
            last_kk = kk

        # everything that writes into the stream after block k
        if k > 0:
            permute_output(perm_map, blocks1[k].conv3, blocks1[k].bn3)
            shortcut = blocks1[k].downsample
            if shortcut:
                permute_output(perm_map, shortcut[0], shortcut[1])
        else:
            permute_output(perm_map, model1.conv1, model1.bn1)

        # everything that reads from the stream after block k
        if k+1 < len(blocks1):
            permute_input(perm_map, blocks1[k+1].conv1)
            shortcut = blocks1[k+1].downsample
            if shortcut:
                permute_input(perm_map, shortcut[0])
        else:
            permute_input(perm_map, model1.fc)

    return model1

In [17]:
# Align resnet50_2 to resnet50_1. permute_model_resnet50 permutes its second
# argument in place and returns it, so `permuted_model_2` and `resnet50_2`
# are the same object after this cell.
permuted_model_2 = permute_model_resnet50(resnet50_1, resnet50_2, train_dl, config)
# Save the permuted weights.
# NOTE(review): absolute, user-specific output path.
torch.save(permuted_model_2.state_dict(), "/home/rohan/code/sparse-rebasin-minimal/sparse-rebasin/notebooks/artifacts/permuted_model_2.pth")

100%|██████████| 5004/5004 [14:24<00:00,  5.79it/s]
 65%|██████▍   | 3252/5004 [09:09<03:43,  7.82it/s]

In [29]:
def evaluate_merged_models(model0, permuted_model_1, alpha, config):
    """Build a fresh model whose weights linearly interpolate between two models.

    Each state-dict entry present in both models becomes
    (1 - alpha) * model0 + alpha * permuted_model_1, computed on the GPU.
    Entries missing from permuted_model_1 are skipped, and the result is
    loaded non-strictly into a new model created with get_model(config).
    Despite the name, nothing is evaluated here; the merged model is returned.
    """
    merged = get_model(config)
    weights0 = model0.state_dict()
    weights1 = permuted_model_1.state_dict()

    blended = {}
    for name in weights0.keys():
        if name not in weights1:
            continue
        blended[name] = (1 - alpha) * weights0[name].cuda() + alpha * weights1[name].cuda()

    merged.load_state_dict(blended, strict=False)
    return merged

In [31]:
def reset_bn_stats(model, loader):
    """Reset all tracked BN stats and re-estimate them against the data in `loader`.

    Every module whose type is exactly nn.BatchNorm2d gets momentum=None
    (use simple average) and freshly reset running statistics; the model is
    then run in train mode over the loader, without gradients and under
    autocast, with each batch moved to the GPU.
    """
    # resetting stats to baseline first as below is necessary for stability
    bn_layers = [mod for mod in model.modules() if type(mod) == nn.BatchNorm2d]
    for bn in bn_layers:
        bn.momentum = None  # use simple average
        bn.reset_running_stats()

    model.train()
    with torch.no_grad(), autocast():
        for batch, _ in loader:
            model(batch.cuda())

In [None]:
# Midpoint (alpha = 0.5) between model A and the permuted model B.
model_a = evaluate_merged_models(resnet50_1, permuted_model_2, 0.5, config)

print('(test_acc, test_loss):')
# evaluate() returns a tuple, so it is wrapped in a 1-tuple to be formatted as
# a single %s value (passing it as a second print() argument would print the
# literal "%s" instead of substituting it).
print('(α=0.5): %s\t\t<-- Merged model with neuron alignment' % (evaluate(model_a,test_dl,config["device"]),))
reset_bn_stats(model_a, train_dl)
print('(α=0.5): %s\t\t<-- Merged model with alignment + BN reset' % (evaluate(model_a,test_dl,config["device"]),))

In [None]:
class TrackLayer(nn.Module):
    """Wrap a block and record the per-channel statistics of its output.

    The wrapped block's output is also fed to an auxiliary BatchNorm2d whose
    result is discarded; that BN's running statistics are what get_stats()
    reports. The block's output is returned unchanged.
    """

    def __init__(self, layer, one_d=False):
        # `one_d` is accepted but not used.
        super().__init__()
        self.layer = layer
        # one BN channel per output channel of the block's last conv
        self.bn = nn.BatchNorm2d(layer.conv3.out_channels)

    def get_stats(self):
        """Return (running mean, running std) of the auxiliary BN."""
        running_std = self.bn.running_var.sqrt()
        return (self.bn.running_mean, running_std)

    def forward(self, x):
        out = self.layer(x)
        self.bn(out)  # called only for its effect on the BN; result discarded
        return out

class ResetLayer(nn.Module):
    """Wrap a block and pass its output through an extra BatchNorm2d.

    set_stats() installs the desired per-channel mean / std as that BN's
    bias / weight.
    """

    def __init__(self, layer, one_d=False):
        # `one_d` is accepted but not used.
        super().__init__()
        self.layer = layer
        # one BN channel per output channel of the block's last conv
        self.bn = nn.BatchNorm2d(layer.conv3.out_channels)

    def set_stats(self, goal_mean, goal_std):
        """Use goal_std as the BN weight and goal_mean as the BN bias."""
        extra_bn = self.bn
        extra_bn.weight.data = goal_std
        extra_bn.bias.data = goal_mean

    def forward(self, x):
        return self.bn(self.layer(x))

def make_tracked_net(net):
    """Return a copy of `net` in which every block of layer1..layer4 is wrapped in a TrackLayer.

    The copy is a fresh get_model(config) instance (using the notebook-level
    `config`) loaded with net's state dict; `net` itself is left untouched.
    Each wrapper is moved to the GPU.
    """
    net1 = get_model(config)
    net1.load_state_dict(net.state_dict())
    for stage_idx in (1, 2, 3, 4):
        stage = getattr(net1, 'layer%d' % stage_idx)
        for pos in range(len(stage)):
            stage[pos] = TrackLayer(stage[pos]).cuda()
    return net1

def make_repaired_net(net):
    """Return a copy of `net` in which every block of layer1..layer4 is wrapped in a ResetLayer.

    The copy is a fresh get_model(config) instance (using the notebook-level
    `config`) loaded with net's state dict; `net` itself is left untouched.
    Each wrapper is moved to the GPU.
    """
    net1 = get_model(config)
    net1.load_state_dict(net.state_dict())
    for stage_idx in (1, 2, 3, 4):
        stage = getattr(net1, 'layer%d' % stage_idx)
        for pos in range(len(stage)):
            stage[pos] = ResetLayer(stage[pos]).cuda()
    return net1

In [None]:

# Endpoint networks of the interpolation: alpha=0 weights only the first model,
# alpha=1 only the permuted second model.
model0 = evaluate_merged_models(resnet50_1, permuted_model_2, 0, config)
model1 = evaluate_merged_models(resnet50_1, permuted_model_2, 1, config)

## Calculate all neuronal statistics in the endpoint networks
wrap0 = make_tracked_net(model0)
wrap1 = make_tracked_net(model1)
# reset_bn_stats needs the data to re-estimate the statistics on; use the
# training loader, as in the BN-reset cell for model_a.
reset_bn_stats(wrap0, train_dl)
reset_bn_stats(wrap1, train_dl)

In [None]:
alpha = 0.5
wrap_a = make_repaired_net(model_a)
# Iterate through corresponding triples of (TrackLayer, TrackLayer, ResetLayer)
# around the blocks in (model0, model1, model_a). The three wrapped networks
# are walked in lockstep via zip over modules().
for track0, track1, reset_a in zip(wrap0.modules(), wrap1.modules(), wrap_a.modules()):
    if not isinstance(track0, TrackLayer):
        continue
    # the lockstep walk must land on matching wrappers in all three networks
    assert (isinstance(track1, TrackLayer)
            and isinstance(reset_a, ResetLayer))

    # get neuronal statistics of original networks
    mu0, std0 = track0.get_stats()
    mu1, std1 = track1.get_stats()
    # set the goal neuronal statistics for the merged network
    goal_mean = (1 - alpha) * mu0 + alpha * mu1
    goal_std = (1 - alpha) * std0 + alpha * std1
    reset_a.set_stats(goal_mean, goal_std)

# Estimate mean/vars such that when added BNs are set to eval mode,
# neuronal stats will be goal_mean and goal_std.
# reset_bn_stats requires the loader to estimate the statistics on.
reset_bn_stats(wrap_a, train_dl)
# evaluate() returns a tuple; wrap it in a 1-tuple so it is formatted as a
# single %s value instead of being unpacked as the format arguments.
print('(α=0.5): %s\t\t<-- Merged models with REPAIR' % (evaluate(wrap_a,test_dl,config["device"]),))