In [1]:
import os
# Expose only GPU 4 to this process. This is set before torch is imported (next cell) so that CUDA
# initialisation sees just this device; inside the notebook it is addressed as "cuda"/"cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "4"

In [3]:
from tqdm import tqdm
import numpy as np
import scipy.optimize

import torch
from torch import nn
import torchvision.transforms as T
import torchvision.models 
from torchvision.models import ResNet50_Weights 
from torch.cuda.amp import autocast

from utils import get_model, yaml, cifar_dataloader, evaluate, check_hooks, calculate_overall_sparsity_from_pth, transfer_sparsity_resnet

In [4]:
# Experiment configuration (ResNet-50 on ImageNet; "80" in the file name presumably is the target sparsity).
with open("./configs_imagenet/config_1_80_imagenet.yaml", "r") as f:
    config = yaml.safe_load(f)

# Nest run outputs by width multiplier and sparsity percentage, e.g. "<save_path>width_1/sparsity_80.0/".
config['save_path'] = "{}width_{}/sparsity_{}/".format(
    config['save_path'],
    config['width_multiplier'],
    config['pruning']['sparsity'] * 100,
)

In [5]:
def resnet50_pretrained():
    """Return torchvision's ResNet-50 with ImageNet-1k V1 weights, moved to the GPU and set to eval mode."""
    weights = ResNet50_Weights.IMAGENET1K_V1
    net = torchvision.models.resnet50(weights=weights).cuda()
    net.eval()
    return net
    

In [6]:
def get_blocks(net):
    """Flatten a torchvision-style ResNet into one ``nn.Sequential`` of blocks.

    Element 0 is the stem (conv1 -> bn1 -> relu -> maxpool). Each following element is
    one residual block from layer1..layer4, in forward order. The modules are shared
    with ``net``, not copied, so modifying a block modifies ``net``.
    """
    stem = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool)
    stages = (net.layer1, net.layer2, net.layer3, net.layer4)
    residual_blocks = [block for stage in stages for block in stage]
    return nn.Sequential(stem, *residual_blocks)

In [7]:
# Build the train/test DataLoaders used throughout the notebook.
# NOTE(review): despite its name, `cifar_dataloader` is called here with an ImageNet config
# (configs_imagenet/...). Confirm it dispatches on the config rather than always loading CIFAR.
train_dl, test_dl = cifar_dataloader(config["batch_size"], config)


Load the ImageNet dataset (`train_dl` / `test_dl` above).


In [None]:
# images, labels = next(iter(train_dl))
# single_batch = (images, labels)

# print(images.shape)

# batch_size = images.shape[0]
# print(f"Number of data points in the batch: {batch_size}")

### `run_corr_matrix_single`: channel-correlation matrix computed from a single batch of data

In [7]:
def run_corr_matrix_single(net0, net1, batch, device="cuda"):
    """Channel-wise correlation between two networks' feature maps on a single batch.

    Both networks are run on the same images. Each output feature map of shape
    (N, C, H, W) is reshaped to (N*H*W, C), so every spatial position of every image
    counts as one sample. The result holds the Pearson correlation between channel i
    of ``net0`` and channel j of ``net1``.

    Args:
        net0, net1: modules mapping an image batch to a 4-D feature map; both are
            switched to eval mode.
        batch: an ``(images, labels)`` pair; labels are ignored.
        device: device the images are moved to (default ``"cuda"``, the previously
            hard-coded target).

    Returns:
        A float64 tensor of shape (C0, C1).
    """
    images, _ = batch  # labels are not needed
    with torch.no_grad():
        net0.eval()
        net1.eval()

        img_t = images.float().to(device)
        out0 = net0(img_t).double()
        # (N, C, H, W) -> (N*H*W, C)
        out0 = out0.permute(0, 2, 3, 1).reshape(-1, out0.shape[1])
        out1 = net1(img_t).double()
        out1 = out1.permute(0, 2, 3, 1).reshape(-1, out1.shape[1])

        # First/second moments of each channel and the cross-moment E[x0 x1^T].
        mean0 = out0.mean(dim=0)
        mean1 = out1.mean(dim=0)
        sqmean0 = out0.square().mean(dim=0)
        sqmean1 = out1.square().mean(dim=0)
        outer = (out0.T @ out1) / out0.shape[0]

    cov = outer - torch.outer(mean0, mean1)
    # E[x^2] - E[x]^2 can round to a tiny negative number; clamp so sqrt never yields NaN.
    std0 = (sqmean0 - mean0**2).clamp_min(0).sqrt()
    std1 = (sqmean1 - mean1**2).clamp_min(0).sqrt()
    # The epsilon keeps constant (zero-variance) channels from dividing by zero.
    corr = cov / (torch.outer(std0, std1) + 1e-4)
    return corr

In [8]:
def run_corr_matrix(net0, net1, loader, device="cuda"):
    """Channel-wise correlation between two networks' feature maps over a whole loader.

    Each output feature map of shape (N, C, H, W) is reshaped to (N*H*W, C), so every
    spatial position of every image counts as one sample. Moments are accumulated as
    sums over all rows of all batches and divided by the total row count at the end.
    Batches of different sizes (e.g. a smaller final batch) are therefore weighted
    correctly, instead of each batch counting 1/len(loader).

    Args:
        net0, net1: modules mapping an image batch to a 4-D feature map; both are
            switched to eval mode.
        loader: iterable of ``(images, labels)`` batches (a DataLoader or a list);
            labels are ignored.
        device: device the images are moved to (default ``"cuda"``).

    Returns:
        A float64 tensor of shape (C0, C1) with the Pearson correlation between
        channel i of ``net0`` and channel j of ``net1``.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    net0.eval()
    net1.eval()
    n_rows = 0
    sums = None  # [sum x0, sum x1, sum x0^2, sum x1^2, sum x0 x1^T]
    with torch.no_grad():
        for images, _ in tqdm(loader):
            img_t = images.float().to(device)
            out0 = net0(img_t).double()
            out0 = out0.permute(0, 2, 3, 1).reshape(-1, out0.shape[1])
            out1 = net1(img_t).double()
            out1 = out1.permute(0, 2, 3, 1).reshape(-1, out1.shape[1])

            batch_sums = (
                out0.sum(dim=0),
                out1.sum(dim=0),
                out0.square().sum(dim=0),
                out1.square().sum(dim=0),
                out0.T @ out1,
            )
            if sums is None:
                sums = list(batch_sums)
            else:
                for acc, s in zip(sums, batch_sums):
                    acc += s
            n_rows += out0.shape[0]

    if sums is None:
        raise ValueError("run_corr_matrix: loader yielded no batches")

    mean0, mean1, sqmean0, sqmean1, outer = (s / n_rows for s in sums)
    cov = outer - torch.outer(mean0, mean1)
    # E[x^2] - E[x]^2 can round to a tiny negative number; clamp so sqrt never yields NaN.
    std0 = (sqmean0 - mean0**2).clamp_min(0).sqrt()
    std1 = (sqmean1 - mean1**2).clamp_min(0).sqrt()
    # The epsilon keeps constant (zero-variance) channels from dividing by zero.
    corr = cov / (torch.outer(std0, std1) + 1e-4)
    return corr

In [9]:
def get_layer_perm1(corr_mtx):
    """Solve the channel-matching problem for a correlation matrix.

    ``corr_mtx[i, j]`` scores how well channel j of the second network matches
    channel i of the first. NaNs are replaced via ``np.nan_to_num``, so NaN scores
    count as 0. Returns a long tensor ``perm_map`` in which ``perm_map[i] = j``,
    chosen to maximise the total correlation (scipy's linear_sum_assignment).
    """
    scores = np.nan_to_num(corr_mtx.cpu().numpy())
    rows, cols = scipy.optimize.linear_sum_assignment(scores, maximize=True)
    # Every row of a square matrix is assigned, and rows come back in order 0..C-1.
    assert (rows == np.arange(len(scores))).all()
    return torch.tensor(cols).long()

def get_layer_perm(net0, net1, loader):
    """Return the channel permutation that makes ``net1``'s activations best match ``net0``'s.

    Correlations are measured on ``loader`` by run_corr_matrix and then matched by
    get_layer_perm1. Reindexing net1's output channels with the result (``w1[perm]``)
    lines them up with net0's. In other words, this permutes net1 -> net0, i.e. pi(net1).
    """
    corr_mtx = run_corr_matrix(net0, net1, loader)
    perm_map = get_layer_perm1(corr_mtx)
    return perm_map

In [10]:
def permute_output(perm_map, conv, bn=None):
    """Reorder the output channels of ``conv`` (and of the following ``bn``) in place.

    After the call, output channel i is what used to be channel ``perm_map[i]``. The
    conv bias is permuted too when present.

    Modules pruned with ``torch.nn.utils.prune`` keep the trainable tensor in
    ``<name>_orig`` and the mask in ``<name>_mask``, and recompute ``<name>`` from them
    on forward. Those tensors are permuted together with the derived ones so the
    recomputed weights stay consistent.

    Args:
        perm_map: 1-D long tensor of channel indices.
        conv: a conv (or linear) module whose dim 0 indexes output channels.
        bn: optional BatchNorm whose affine params and running stats are permuted.
    """
    tensors = [conv.weight]
    if getattr(conv, "bias", None) is not None:
        tensors.append(conv.bias)
    tensors.extend(
        getattr(conv, name)
        for name in ("weight_orig", "weight_mask", "bias_orig", "bias_mask")
        if hasattr(conv, name)
    )
    if bn is not None:
        tensors.extend([bn.weight, bn.bias, bn.running_mean, bn.running_var])

    with torch.no_grad():
        for t in tensors:
            if t is not None:  # e.g. BN without affine params or without running stats
                t.data = t[perm_map]

def permute_input(perm_map, layer):
    """Reorder the input channels (dim 1 of the weight) of ``layer`` in place.

    Works for conv and linear layers alike. For a module pruned with
    ``torch.nn.utils.prune``, ``weight_orig`` and ``weight_mask`` are reordered the
    same way so that the recomputed weight stays consistent.
    """
    targets = [layer.weight]
    targets += [getattr(layer, name) for name in ("weight_orig", "weight_mask") if hasattr(layer, name)]
    for tensor in targets:
        tensor.data = tensor[:, perm_map]

In [11]:
# Initialize a pretrained ResNet-50 so that a one-shot pruning mask can be obtained quickly. The resulting
# model_A_sparse is then passed to the permutation step to check that it is permuted alongside model A.
# NOTE(review): the first line rebinds `resnet50_pretrained` from the loader function to the model it returns,
# so re-running this cell (or calling `resnet50_pretrained()` later) fails. Consider a distinct variable name,
# updated together with its later uses.

resnet50_pretrained = resnet50_pretrained()
blocks = get_blocks(resnet50_pretrained)  # NOTE(review): `blocks` is not used later in this notebook
# Displayed result; presumably (accuracy %, loss), matching how evaluate() results are printed elsewhere.
evaluate(resnet50_pretrained, test_dl, config["device"])

(76.146, 0.9644105411609825)

In [None]:
# Prune a copy of the pretrained ResNet-50 to obtain model_A_sparse (the source of the sparsity masks).
# NOTE(review): these imports belong in the top import cell; `prune` is not used in this cell.
import torch.nn.utils.prune as prune
import copy
from pruning import prune_model

results = {}  # label -> evaluate(...) result
# Falls back to CPU when CUDA is unavailable; note that evaluate() below still uses config["device"] directly.
device = torch.device(config['device'] if torch.cuda.is_available() else 'cpu')
model_A_sparse = copy.deepcopy(resnet50_pretrained)  # deep copy keeps resnet50_pretrained itself dense
# The author's note says this errors because the model is still dense (no pruning hooks yet).
# NOTE(review): if check_hooks raises rather than reports, this line stops Restart & Run All.
check_hooks(model_A_sparse) ## error since dense

# The return value is unused, so prune_model presumably prunes model_A_sparse in place — confirm in pruning.py.
prune_model(
    config=config,
    model=model_A_sparse,
    target_sparsity=config['pruning']['sparsity'],
    optimizer_config=config['optimizer'],
    prune_epochs=config['pruning']['prune_epochs'],
    initial_lr=config['pruning']['prune_lr'],
    batch_size=config['batch_size'],
    device=device,
    initial_prune_perc=0.80,  # NOTE(review): hard-coded, unlike the other config-driven arguments
    train_epochs_per_prune=config['pruning']['train_epochs_per_prune'],
)

# Snapshot taken before permute_model_resnet50 permutes model_A_sparse in place.
# NOTE(review): not referenced elsewhere in this notebook.
original_model_A_sparse = copy.deepcopy(model_A_sparse)

results["A_sparse"] = evaluate(
    model_A_sparse, test_dl, config["device"]
)
print(f"A_sparse: {results['A_sparse'][0]:.2f}%, {results['A_sparse'][1]:.4f}")

In [14]:
# Save the whole pruned module (torch.save pickles the module object, not just its state_dict).
# NOTE(review): absolute, user-specific path — consider deriving it from config['save_path'] or a configurable directory.
torch.save(model_A_sparse, "/home/rohan/code/sparse-rebasin-minimal/sparse-rebasin/notebooks/artifacts/model_A_sparse.pth")

In [None]:
# Two dense ResNet-50 checkpoints (model A and model B, per the file names) that will be aligned and merged.
# NOTE(review): absolute, user-specific checkpoint paths.
path1 = "/scratch/rohan/test/imagenet_testing/width_1/sparsity_80.0/Model_A_Dense_sparsity_0.8_seed_11_epoch_5"
path2 = "/scratch/rohan/test/imagenet_testing/width_1/sparsity_80.0/Model_B_Dense_sparsity_0.8_seed_11_epoch_5"

# Build both architectures first, then restore each one's weights.
resnet50_1, resnet50_2 = get_model(config), get_model(config)
resnet50_1.load_state_dict(torch.load(path1))
resnet50_2.load_state_dict(torch.load(path2))

blocks0, blocks1 = get_blocks(resnet50_1), get_blocks(resnet50_2)

# Baseline evaluate() result of each endpoint, displayed as a pair.
tuple(evaluate(net, test_dl, config["device"]) for net in (resnet50_1, resnet50_2))

In [14]:
# Restrict the permutations so that the network is not functionally changed.
# In particular, the same permutation must be applied to every conv output in a residual stream.
def get_permk(k, stage_sizes=(3, 4, 6, 3)):
    """Map a get_blocks() index to the block index that owns its residual-stream permutation.

    Index 0 is the stem, which maps to itself. Indices 1..sum(stage_sizes) are the
    residual blocks. All blocks of one stage share a residual stream and must share a
    permutation, so each maps to the index of the *last* block of its stage. The
    defaults are ResNet-50's blocks per stage, giving
    0 -> 0, 1..3 -> 3, 4..7 -> 7, 8..13 -> 13, 14..16 -> 16.

    Args:
        k: block index as produced by get_blocks().
        stage_sizes: number of residual blocks in each stage (e.g. (3, 4, 23, 3) for ResNet-101).

    Raises:
        ValueError: if k is negative or past the last block.
    """
    if k == 0:
        return 0
    last = 0
    for size in stage_sizes:
        last += size
        if 0 < k <= last:
            return last
    raise ValueError(f"block index {k} out of range [0, {last}] for stage_sizes={tuple(stage_sizes)}")

In [15]:
def permute_model_resnet50(model0, model1, modelA_sparse, loader, config):
    """Align ResNet-50 ``model0`` to ``model1`` by permuting channels, mirroring every permutation onto ``modelA_sparse``.

    Every permutation comes from activation correlations between model1 and model0 on
    ``loader`` (get_layer_perm). model1 is passed first, so each perm reorders model0's
    channels to line up with model1's. Each permutation is applied to model0 *and* to
    modelA_sparse, which must share model0's architecture, so the sparse model (and its
    pruning masks) stays channel-consistent with model0. model1 is never modified.

    Two phases:
      1. Inside every bottleneck, the conv1->conv2 and conv2->conv3 hidden channels
         are each permuted.
      2. The residual stream gets one permutation per stage (see get_permk). It is
         applied to the stem output, to each block's conv3/bn3 and downsample outputs,
         to the next block's conv1/downsample inputs, and finally to the fc input.
    Every output permutation is paired with the matching input permutation, so the
    function computed by each model is preserved.

    Args:
        model0: model to permute (modified in place).
        model1: reference model that model0 is aligned to.
        modelA_sparse: sparse model receiving the same permutations (modified in place).
        loader: iterable of ``(images, labels)`` batches used for the correlation estimates.
        config: unused; kept for interface compatibility.

    Returns:
        ``(model0, modelA_sparse)``: the same objects that were passed in, now permuted.
    """
    blocks0 = get_blocks(model0)
    blocks1 = get_blocks(model1)
    blocks_sparse = get_blocks(modelA_sparse)
    # Every permutation is applied identically to model0 and to the sparse model.
    permuted = (blocks0, blocks_sparse)

    # Phase 1: hidden channels inside each bottleneck block.
    for k in range(1, len(blocks0)):
        block0 = blocks0[k]
        block1 = blocks1[k]

        subnet0 = nn.Sequential(blocks0[:k], block0.conv1, block0.bn1, block0.relu)
        subnet1 = nn.Sequential(blocks1[:k], block1.conv1, block1.bn1, block1.relu)
        perm_map = get_layer_perm(subnet1, subnet0, loader)
        for blocks in permuted:
            permute_output(perm_map, blocks[k].conv1, blocks[k].bn1)
            permute_input(perm_map, blocks[k].conv2)

        subnet0 = nn.Sequential(blocks0[:k],
                                block0.conv1, block0.bn1, block0.relu,
                                block0.conv2, block0.bn2, block0.relu)
        subnet1 = nn.Sequential(blocks1[:k],
                                block1.conv1, block1.bn1, block1.relu,
                                block1.conv2, block1.bn2, block1.relu)
        perm_map = get_layer_perm(subnet1, subnet0, loader)
        # This permutation belongs to model0 (and its sparse twin), exactly like the conv1 one
        # above. Applying it to model1 would leave model0's conv2 channels unaligned and make
        # the sparse model's channel order diverge from model0's.
        for blocks in permuted:
            permute_output(perm_map, blocks[k].conv2, blocks[k].bn2)
            permute_input(perm_map, blocks[k].conv3)

    # Phase 2: residual stream, one permutation per stage.
    last_kk = None
    perm_map = None
    for k in range(len(blocks0)):
        kk = get_permk(k)
        if kk != last_kk:
            # Correlate the outputs of the last block of k's stage.
            perm_map = get_layer_perm(blocks1[:kk + 1], blocks0[:kk + 1], loader)
            last_kk = kk

        if k == 0:
            # Stem output.
            permute_output(perm_map, model0.conv1, model0.bn1)
            permute_output(perm_map, modelA_sparse.conv1, modelA_sparse.bn1)
        else:
            for blocks in permuted:
                permute_output(perm_map, blocks[k].conv3, blocks[k].bn3)
                shortcut = blocks[k].downsample
                if shortcut is not None:
                    permute_output(perm_map, shortcut[0], shortcut[1])

        if k + 1 < len(blocks0):
            for blocks in permuted:
                permute_input(perm_map, blocks[k + 1].conv1)
                shortcut = blocks[k + 1].downsample
                if shortcut is not None:
                    permute_input(perm_map, shortcut[0])
        else:
            permute_input(perm_map, model0.fc)
            permute_input(perm_map, modelA_sparse.fc)

    return model0, modelA_sparse

In [None]:
import itertools

num_batches = 50  # number of training batches used to estimate activation correlations

# islice stops after exactly `num_batches` batches (no extra batch is fetched); each item is
# normalised to an (images, labels) tuple.
batches = [(images, labels) for images, labels in itertools.islice(train_dl, num_batches)]
if not batches:
    raise RuntimeError("train_dl yielded no batches")

batch_size = batches[0][0].shape[0]

# Count the rows actually collected: the loader may yield fewer than `num_batches` batches
# or a smaller final batch.
total_data_points = sum(images.shape[0] for images, _ in batches)

print(f"Batch size: {batch_size}")
print(f"Total data points: {total_data_points}")

# Permutes resnet50_1 (aligned to resnet50_2) and model_A_sparse in place; the returned objects are those same models.
permuted_model_1, permuted_model_A_sparse = permute_model_resnet50(resnet50_1, resnet50_2, model_A_sparse, batches, config)


In [26]:
# Alternative (disabled): compute the permutations from a single batch.
# NOTE(review): run_corr_matrix iterates over its `loader` argument, so a bare (images, labels) tuple would be
# iterated element-wise. Pass [single_batch] instead if re-enabling this.
# images, labels = next(iter(train_dl))
# single_batch = (images, labels)
# permuted_model_1, permuted_model_A_sparse = permute_model_resnet50(resnet50_1, resnet50_2, model_A_sparse, single_batch, config)
# Save the permuted models as whole (pickled) modules. NOTE(review): absolute, user-specific paths.
torch.save(permuted_model_1, "/home/rohan/code/sparse-rebasin-minimal/sparse-rebasin/notebooks/artifacts/permuted_model_1_multiple_batch.pth")
torch.save(permuted_model_A_sparse, "/home/rohan/code/sparse-rebasin-minimal/sparse-rebasin/notebooks/artifacts/permuted_model_A_sparse_multiple_batch.pth")

In [17]:
# permuted_model_1 is resnet50_1 after the in-place permutation. The permutations are meant to be
# function-preserving, so this is expected to match resnet50_1's accuracy from the checkpoint-loading cell.
evaluate(permuted_model_1, test_dl, config["device"])

(44.604, 2.4261582411673603)

In [18]:
# Expected to match the "A_sparse" result measured right after pruning, since the permutations are meant
# to be function-preserving.
evaluate(permuted_model_A_sparse, test_dl, config["device"])

(72.566, 1.1008016059593277)

In [19]:
def evaluate_merged_models(model0, permuted_model_1, alpha, config, device="cuda"):
    """Build (despite the name, does not evaluate) a model whose weights interpolate two models.

    Each state-dict entry present in both models becomes
    ``(1 - alpha) * model0[k] + alpha * permuted_model_1[k]`` (alpha=0 -> model0,
    alpha=1 -> permuted_model_1) and is loaded into a fresh ``get_model(config)``.
    BatchNorm buffers are interpolated too; this notebook re-estimates them
    afterwards with reset_bn_stats.

    Args:
        model0, permuted_model_1: models with matching state-dict keys; the second is
            expected to be channel-aligned with the first.
        alpha: interpolation weight given to ``permuted_model_1``.
        config: passed to ``get_model`` to build the architecture.
        device: where the interpolated tensors are placed (default ``"cuda"``).

    Returns:
        The newly built merged model.

    Warns:
        UserWarning when the merged state dict does not exactly cover the new model's
        keys. One example is a model pruned with ``torch.nn.utils.prune``, which stores
        ``weight_orig``/``weight_mask`` instead of ``weight``. Loading is non-strict, so
        without this warning such weights would silently keep whatever values get_model
        gave them.
    """
    import warnings

    model = get_model(config)
    sd0, sd1 = model0.state_dict(), permuted_model_1.state_dict()
    sd_alpha = {
        k: (1 - alpha) * sd0[k].to(device) + alpha * sd1[k].to(device)
        for k in sd0
        if k in sd1
    }
    incompatible = model.load_state_dict(sd_alpha, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        warnings.warn(
            f"evaluate_merged_models: {len(incompatible.missing_keys)} missing and "
            f"{len(incompatible.unexpected_keys)} unexpected state-dict keys; "
            f"missing (first 5): {incompatible.missing_keys[:5]}, "
            f"unexpected (first 5): {incompatible.unexpected_keys[:5]}"
        )
    return model

In [20]:
# reset all tracked BN stats against training data
def reset_bn_stats(model, loader, device="cuda"):
    """Re-estimate every BatchNorm2d's running statistics from scratch on ``loader``.

    Running stats are reset first and momentum is set to None. After the pass, each BN
    therefore holds the cumulative (simple) average of the per-batch statistics rather
    than an exponential moving average; resetting to this baseline is needed for
    stability. Forward passes run in train mode under ``torch.no_grad()`` and CUDA
    autocast, so only the BN buffers change. The model's previous train/eval mode is
    restored afterwards, even if the pass raises.

    Args:
        model: module whose BatchNorm2d layers (including subclasses) are re-estimated.
        loader: iterable of ``(images, labels)`` batches; labels are ignored.
        device: device the images are moved to (default ``"cuda"``).
    """
    for m in model.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.momentum = None  # use simple (cumulative) average
            m.reset_running_stats()

    was_training = model.training
    model.train()
    try:
        with torch.no_grad(), autocast():
            for images, _ in loader:
                model(images.to(device))
    finally:
        model.train(was_training)

In [14]:
# permuted_model_2 = get_model(config)
# permuted_model_2.load_state_dict(torch.load("/home/rohan/code/sparse-rebasin-minimal/sparse-rebasin/notebooks/artifacts/permuted_model_2.pth"))

<All keys matched successfully>

In [None]:
# Linear interpolation (alpha=0.5) of the permuted model A and model B.
model_a = evaluate_merged_models(permuted_model_1, resnet50_2, 0.5, config)

print('(test_acc, test_loss):')
# %-format the evaluate() tuple into the message; passing it as a separate print() argument
# would leave '%s' unformatted.
print('(α=0.5): %s\t\t<-- Merged model with neuron alignment' % (evaluate(model_a, test_dl, config["device"]),))
# Re-estimate BatchNorm running statistics, which are only interpolated by the merge.
# NOTE(review): this is a full pass over train_dl; the `batches` subset would be much cheaper.
reset_bn_stats(model_a, train_dl)
print('(α=0.5): %s\t\t<-- Merged model with alignment + BN reset' % (evaluate(model_a, test_dl, config["device"]),))

In [22]:
class TrackLayer(nn.Module):
    """Wrap a bottleneck block and record statistics of its output without altering it.

    A fresh BatchNorm2d sized to ``layer.conv3.out_channels`` observes the block's
    output. In train mode its running mean/var accumulate per-channel statistics, which
    ``get_stats`` reports as (mean, std). The forward pass returns the block's output
    unchanged. ``one_d`` is accepted but unused.
    """

    def __init__(self, layer, one_d=False):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm2d(layer.conv3.out_channels)

    def get_stats(self):
        """Return (running_mean, running_std) of the observed activations."""
        running_std = self.bn.running_var.sqrt()
        return self.bn.running_mean, running_std

    def forward(self, x):
        out = self.layer(x)
        self.bn(out)  # called only to update the running statistics; its output is discarded
        return out

class ResetLayer(nn.Module):
    """Wrap a bottleneck block and renormalise its output with a fresh BatchNorm2d.

    ``set_stats`` installs the target per-channel mean as the BN bias and the target
    std as the BN weight. Once the BN's running statistics have been re-estimated, the
    block's output is shifted and scaled toward those targets. ``one_d`` is accepted
    but unused.
    """

    def __init__(self, layer, one_d=False):
        super().__init__()
        self.layer = layer
        self.bn = nn.BatchNorm2d(layer.conv3.out_channels)

    def set_stats(self, goal_mean, goal_std):
        """Use ``goal_mean`` as the BN bias and ``goal_std`` as the BN weight."""
        self.bn.bias.data = goal_mean
        self.bn.weight.data = goal_std

    def forward(self, x):
        return self.bn(self.layer(x))

def _wrap_residual_blocks(net, wrapper_cls):
    """Copy ``net`` into a new ``get_model(config)`` and wrap every block of layer1..layer4 in ``wrapper_cls``.

    Relies on the notebook-global ``config``. ``net`` itself is not modified: its
    weights are copied via ``state_dict``. Each wrapper is moved to the GPU with
    ``.cuda()``.
    """
    wrapped = get_model(config)
    wrapped.load_state_dict(net.state_dict())
    for stage_idx in range(1, 5):
        stage = getattr(wrapped, f"layer{stage_idx}")
        for j, block in enumerate(stage):
            stage[j] = wrapper_cls(block).cuda()
    return wrapped


# adds TrackLayer around each block
def make_tracked_net(net):
    """Return a copy of ``net`` whose residual blocks are wrapped in TrackLayer (to record output stats)."""
    return _wrap_residual_blocks(net, TrackLayer)


# adds ResetLayer around each block
def make_repaired_net(net):
    """Return a copy of ``net`` whose residual blocks are wrapped in ResetLayer (to renormalise outputs)."""
    return _wrap_residual_blocks(net, ResetLayer)

In [23]:
def repair_merged_model(permuted_model_0, model_1, model_a, train_dl, config, alpha=0.5):
    """REPAIR: rescale the merged model's block outputs toward interpolated endpoint statistics.

    1. Rebuild both endpoints (alpha = 0 and 1) with evaluate_merged_models and wrap
       each residual block in a TrackLayer. One reset_bn_stats pass over ``train_dl``
       then records the per-channel mean/std of every block output.
    2. Wrap ``model_a``'s blocks in ResetLayers and set each one's BN affine params to
       the alpha-interpolation of the endpoints' mean and std.
    3. Re-estimate all BN running statistics of the wrapped merged model on ``train_dl``.

    Returns the wrapped (repaired) copy of ``model_a``. ``model_a`` itself is only read,
    through its state_dict.
    """
    endpoint0 = evaluate_merged_models(permuted_model_0, model_1, 0, config)
    endpoint1 = evaluate_merged_models(permuted_model_0, model_1, 1, config)

    # Neuronal statistics of both endpoint networks.
    tracked0 = make_tracked_net(endpoint0)
    tracked1 = make_tracked_net(endpoint1)
    reset_bn_stats(tracked0, train_dl)
    reset_bn_stats(tracked1, train_dl)

    repaired = make_repaired_net(model_a)
    # The three wrapped nets are built identically, so modules() yields corresponding
    # (TrackLayer, TrackLayer, ResetLayer) triples in the same positions.
    for t0, t1, ra in zip(tracked0.modules(), tracked1.modules(), repaired.modules()):
        if isinstance(t0, TrackLayer):
            assert isinstance(t1, TrackLayer) and isinstance(ra, ResetLayer)
            mean0, sd0 = t0.get_stats()
            mean1, sd1 = t1.get_stats()
            ra.set_stats((1 - alpha) * mean0 + alpha * mean1,
                         (1 - alpha) * sd0 + alpha * sd1)

    # Estimate means/vars so that, once the added BNs are in eval mode, each block's
    # output statistics match the goal mean and std.
    reset_bn_stats(repaired, train_dl)
    return repaired

In [24]:
# REPAIR the alpha=0.5 merge of (permuted model A, model B); the endpoint order matches the merge cell.
# NOTE(review): repair_merged_model runs reset_bn_stats over the full train_dl three times.
repaired_model_a = repair_merged_model(permuted_model_1, resnet50_2, model_a, train_dl, config, 0.5)

In [None]:
# Evaluate the REPAIRed merge; the result is reported as (accuracy %, loss) per the format below.
results["Merged Model with REPAIR"] = evaluate(repaired_model_a, test_dl, config["device"])
print("Merged Model with REPAIR: {:.2f}%, {:.4f}".format(*results["Merged Model with REPAIR"][:2]))

In [None]:
## Testing transfer sparsity
# Measure the overall sparsity of `resnet50_pretrained` before and after calling transfer_sparsity_resnet with
# the permuted sparse model, presumably copying its sparsity pattern onto the dense pretrained model.
# NOTE(review): `resnet50_pretrained` was rebound from the loader function to a model instance in the
# pretrained-model cell, so the commented-out call below would fail as written.
# resnet50_pretrained_test = resnet50_pretrained()
# evaluate(resnet50_pretrained_test, test_dl, config["device"])
print("Sparsity of the model before transfer sparsity: ",calculate_overall_sparsity_from_pth(resnet50_pretrained))

# NOTE(review): presumably modifies `resnet50_pretrained` in place (its sparsity is re-measured right after) — confirm in utils.
transfer_sparsity_resnet(permuted_model_A_sparse, resnet50_pretrained)
print("Sparsity of the init after transfer sparsity: ",calculate_overall_sparsity_from_pth(resnet50_pretrained))

