# Experiment 1


## Setup

In [1]:
# Install the notebook's dependencies; `&> /dev/null` discards all installer output,
# so a failed install will only surface later as an import error.
# gdown is used further down to download the MNIST checkpoints; POT provides the `ot` module.
!pip install --upgrade --no-cache-dir gdown &> /dev/null
!pip install --upgrade --no-cache-dir POT &> /dev/null
# Replace whichever torch is currently installed with the pinned 1.13.1 release.
!pip uninstall torch -y &> /dev/null
!pip install torch==1.13.1 &> /dev/null

In [2]:
import os
import ot
import time
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn as nn
import torchvision 

  Referenced from: <0F72FEF0-4DF1-3E8A-90BA-513122A1950F> /Users/gennarodanieleacciaro/PycharmProjects/tesi_experiments/venv/lib/python3.9/site-packages/torchvision/image.so
  warn(


In [None]:
# Sanity check: show which PyTorch build the kernel actually imported.
print("Pytorch version is ", torch.__version__)

Pytorch version is  1.13.1+cu117


In [22]:
# Constants
# Seed NumPy's global RNG only; torch is seeded separately (torch.manual_seed(activation_seed)
# inside compute_activations_across_models_v1).
NUMPY_SEED = 100
np.random.seed(NUMPY_SEED)

# CUDA device index used by the helper functions; they compare it against -1 to decide
# whether tensors/models are moved with .cuda(GPU_USED).
GPU_USED = -1 # -1 = CPU

In [4]:
# Args
# Module-level settings read directly (as globals) by the functions defined below.

# Batch sizes used by get_dataloader() unless unit_batch=True is requested.
batch_size_train = 64
batch_size_test = 1000
# NOTE(review): not read as a global by any function in this notebook and not part of
# ground_metric_params below — confirm whether it is still needed.
activation_histograms = True
# Number of single-sample batches fed through the models when collecting activations.
act_num_samples = 200
# Passed as `mode` to compute_activations_across_models_v1; any value other than
# 'mean' / 'std' / 'meanstd' leaves the stacked activations unreduced.
activation_mode = "raw"
# Seed given to torch.manual_seed before activations are collected.
activation_seed = 21
# Asserted to be True by compute_activations_across_models_v1.
disable_bias = True
# Class label whose occurrences are counted (and printed) while collecting activations.
personal_class_idx = 9
# When False, ReLU is applied to the recorded activations of every layer but the last.
prelu_acts = True
# get_histogram: False -> uniform histogram summing to 1, True -> all-ones histogram.
unbalanced = False
# Must stay None: the non-None branch of the merge code raises "Deleted code".
importance = None
# True -> ot.emd, False -> ot.bregman.sinkhorn with regularisation `reg`.
exact = False #OT Exact mode
# Interpolation weight: model 1 gets ensemble_step, (aligned) model 0 gets 1 - ensemble_step.
ensemble_step = 0.5
# Must stay False: the True branch of the merge code raises "Deleted code.".
eval_aligned = False
# Divisor applied to activations before the softmax in get_histogram.
softmax_temperature = 1
# True -> apply the transport map to the already input-aligned weights of model 0
# instead of its raw weights.
past_correction = True
# True -> the last layer is plainly averaged and no transport map is computed for it.
skip_last_layer = False
# Regularisation strength: used by sinkhorn and stored on GroundMetric as self.reg.
reg = 0.01
# True -> rescale the transport map by the inverse marginals after it is computed.
correction = True
# Chooses which marginals `correction` uses: False -> uniform 1/n, True -> column sums of T.
proper_marginals = False
# naive_ensembling refuses to run (returns -1, None) unless this is 1.
width_ratio = 1

# 'euclidean', 'cosine' or 'angular' (key into GroundMetric.get_metric's dispatch table).
ground_metric = 'euclidean'
# One of 'log', 'max', 'median', 'mean', 'none' (see GroundMetric._normalize).
ground_metric_normalize = 'none'
# Selects GroundMetric._pairwise_distances (True) over _cost_matrix_xy (False).
ground_metric_eff = False
# Intended to request the un-squared (root) distance from GroundMetric — TODO confirm it is honoured.
not_squared = True
# GroundMetric.process unit-normalises the inputs only when this is 'wts' and normalize_wts is True.
geom_ensemble_type = "acts"
normalize_wts = False
# True makes GroundMetric.process clamp the cost matrix via _clip.
clip_gm = False

# Configuration dict handed to GroundMetric(...).
ground_metric_params = {
    "ground_metric": ground_metric,
    "reg": reg,
    "ground_metric_normalize": ground_metric_normalize,
    "ground_metric_eff": ground_metric_eff,
    "not_squared": not_squared,
    "geom_ensemble_type": geom_ensemble_type,
    "normalize_wts":normalize_wts,
    "clip_gm":clip_gm
}

In [5]:
# Download and unzip MNIST models
# Google Drive file id of the zipped checkpoints; IPython interpolates $mnist_models_fileid
# into the shell command below.
mnist_models_fileid = "1SJTxBpi2Ln3XukcJLJNFIv8S2ix_M2sp"
!gdown $mnist_models_fileid 
# Start from a clean directory before extracting (presumably so unzip never has to
# prompt about overwriting existing files — TODO confirm).
!rm -rf mnist_models
!unzip mnist_models.zip

Downloading...
From: https://drive.google.com/uc?id=1SJTxBpi2Ln3XukcJLJNFIv8S2ix_M2sp
To: /Users/gennarodanieleacciaro/PycharmProjects/tesi_experiments/mnist_models.zip
100%|██████████████████████████████████████| 3.07M/3.07M [00:01<00:00, 2.79MB/s]
Archive:  mnist_models.zip
   creating: mnist_models/
   creating: mnist_models/mnsit/
   creating: mnist_models/model_0/
   creating: mnist_models/model_1/
  inflating: mnist_models/model_0/final.checkpoint  
  inflating: mnist_models/model_1/final.checkpoint  


# Functions from original code

In [6]:
class MlpNet(nn.Module):
    """Bias-free four-layer MLP for flattened 28x28 MNIST images.

    Layer widths are 784 -> 400/r -> 200/r -> 100/r -> 10, where ``r`` is the
    width ratio. All linear layers are created without bias terms.

    :param width_ratio: divisor applied to the three hidden widths. The
        sentinel ``-1`` (the default) means "not specified" and is treated as 1.
    """

    def __init__(self, width_ratio=-1):
        super().__init__()
        input_dim = 784 # [mnist] 28 x 28 x 1
        self.width_ratio = width_ratio if width_ratio != -1 else 1

        hidden_dims = [int(width / self.width_ratio) for width in (400, 200, 100)]

        # Layers are created in fc1..fc4 order so that seeded initialisation is unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dims[0], bias=False)
        self.fc2 = nn.Linear(hidden_dims[0], hidden_dims[1], bias=False)
        self.fc3 = nn.Linear(hidden_dims[1], hidden_dims[2], bias=False)
        self.fc4 = nn.Linear(hidden_dims[2], 10, bias=False)
        # Dropout after each hidden layer is off unless a caller flips this flag.
        self.enable_dropout = False

    def forward(self, x, disable_logits=False):
        """Run the network on a batch.

        :param x: input batch; everything after the first (batch) dimension is
            flattened before the first linear layer.
        :param disable_logits: if True, return the raw output of ``fc4``;
            otherwise return log-probabilities over the 10 classes.
        :return: tensor of shape ``(batch, 10)``.
        """
        x = x.view(x.shape[0], -1)
        for hidden_layer in (self.fc1, self.fc2, self.fc3):
            x = F.relu(hidden_layer(x))
            if self.enable_dropout:
                x = F.dropout(x, training=self.training)
        x = self.fc4(x)

        if disable_logits:
            return x
        # x is always 2-D here, so dim=1 is exactly what the deprecated implicit
        # choice resolved to; spelling it out silences the UserWarning.
        return F.log_softmax(x, dim=1)

In [7]:
def get_pretrained_model(path, data_separated=False):
    """Restore an ``MlpNet`` from the checkpoint stored at ``path``.

    The checkpoint must contain ``'model_state_dict'``. Missing
    ``'test_accuracy'`` / ``'epoch'`` entries are filled in with ``-1``.
    Storages are restored onto ``cuda:<GPU_USED>`` when ``GPU_USED != -1`` and
    onto the CPU otherwise.

    :param path: checkpoint file to load.
    :param data_separated: if True the checkpoint is also expected to carry
        ``'local_test_accuracy'`` and ``'choice'``.
    :return: ``(model, test_accuracy)``, or
        ``(model, test_accuracy, local_test_accuracy)`` when ``data_separated``.
    """
    on_gpu = GPU_USED != -1
    target_device = 'cuda:' + str(GPU_USED) if on_gpu else 'cpu'

    net = MlpNet()
    checkpoint = torch.load(
        path,
        map_location=lambda storage, _: torch.serialization.default_restore_location(storage, target_device),
    )
    weights = checkpoint['model_state_dict']

    # Older checkpoints may lack these bookkeeping entries.
    checkpoint.setdefault('test_accuracy', -1)
    checkpoint.setdefault('epoch', -1)

    if data_separated:
        print("Loading model at path {} which had local accuracy {} and overall accuracy {} for choice {} at epoch {}".format(
            path, checkpoint['local_test_accuracy'], checkpoint['test_accuracy'],
            checkpoint['choice'], checkpoint['epoch']))
    else:
        print("Loading model at path {} which had accuracy {} and at epoch {}".format(
            path, checkpoint['test_accuracy'], checkpoint['epoch']))

    net.load_state_dict(weights)
    if on_gpu:
        net = net.cuda(GPU_USED)

    if data_separated:
        return net, checkpoint['test_accuracy'], checkpoint['local_test_accuracy']
    return net, checkpoint['test_accuracy']


In [8]:
def get_dataloader(unit_batch = False, no_randomness=False):
    """Create the MNIST train and test ``DataLoader`` objects.

    The dataset is read from (and, if missing, downloaded to) ``./files/`` and
    every image goes through ``ToTensor`` followed by
    ``Normalize((0.1307,), (0.3081,))``.

    :param unit_batch: if True both loaders yield one sample per batch;
        otherwise ``batch_size_train`` / ``batch_size_test`` are used.
    :param no_randomness: if True neither loader shuffles; otherwise both do.
    :return: ``(train_loader, test_loader)``.
    """
    train_bsz, test_bsz = (1, 1) if unit_batch else (batch_size_train, batch_size_test)
    shuffle_batches = not no_randomness

    def mnist_split(is_train):
        # Each split gets its own transform pipeline instance.
        preprocessing = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            # only 1 channel
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
        ])
        return torchvision.datasets.MNIST('./files/', train=is_train, download=True,
                                          transform=preprocessing)

    train_loader = torch.utils.data.DataLoader(
        mnist_split(True), batch_size=train_bsz, shuffle=shuffle_batches
    )
    test_loader = torch.utils.data.DataLoader(
        mnist_split(False), batch_size=test_bsz, shuffle=shuffle_batches
    )
    return train_loader, test_loader


In [9]:
def compute_activations_across_models_v1(args, models, train_loader, num_samples, mode='mean',
                                         dump_activations=False, dump_path=None):
    """Feed the same samples through every model and collect each layer's outputs.

    Forward hooks are attached to every named sub-module of every model, the
    first ``num_samples`` batches of ``train_loader`` are pushed through all
    models, and the recorded outputs are stacked per layer and optionally
    reduced along the sample axis. The hooks are always removed again, even
    when a forward pass raises.

    Reads the module-level settings ``activation_seed``, ``disable_bias``,
    ``GPU_USED``, ``personal_class_idx`` and ``prelu_acts``.

    :param args: unused; kept for signature compatibility.
    :param models: sequence of ``nn.Module`` objects to probe. Each one is
        switched to train mode.
    :param train_loader: iterable of ``(data, target)`` batches. ``target.item()``
        is called on every batch, so each batch must hold exactly one sample.
    :param num_samples: number of batches to process.
    :param mode: ``'mean'``, ``'std'`` or ``'meanstd'`` reduce over the sample
        axis; any other value leaves the stacked activations untouched.
    :param dump_activations: unused; kept for signature compatibility.
    :param dump_path: unused; kept for signature compatibility.
    :return: ``{model_index: {layer_name: tensor}}``.
    """
    torch.manual_seed(activation_seed)

    def get_activation(activation, name):
        """Build a forward hook that appends the detached layer output to ``activation[name]``."""
        def hook(model, input, output):
            if name not in activation:
                activation[name] = []

            activation[name].append(output.detach())

        return hook

    activations = {}
    forward_hooks = []

    assert disable_bias
    # handle below for bias later on!

    try:
        for idx, model in enumerate(models):
            # Initialize the activation dictionary for each model
            activations[idx] = {}
            layer_hooks = []
            # Registered up front so that hooks attached before a failure are still cleaned up.
            forward_hooks.append(layer_hooks)
            # Set forward hooks for all layers inside a model
            for name, layer in model.named_modules():
                if name == '':
                    # '' is the model itself; only its sub-modules are probed.
                    print("excluded")
                    continue
                layer_hooks.append(layer.register_forward_hook(get_activation(activations[idx], name)))
                print("set forward hook for layer named: ", name)

            # Set the model in train mode
            model.train()

        # Run the same data samples ('num_samples' many) across all the models
        num_samples_processed = 0
        num_personal_idx = 0
        for data, target in train_loader:
            if num_samples_processed == num_samples:
                break

            if GPU_USED != -1:
                data = data.cuda(GPU_USED)

            if int(target.item()) == personal_class_idx:
                num_personal_idx += 1

            for model in models:
                model(data)

            num_samples_processed += 1
    finally:
        # Remove the hooks (as this was interfering with prediction ensembling)
        for layer_hooks in forward_hooks:
            for handle in layer_hooks:
                handle.remove()

    print("num_personal_idx ", num_personal_idx)

    relu = torch.nn.ReLU()

    # Combine the activations generated across the number of samples.
    # The reduction is chosen by the 'mode' flag: 'mean', 'std' or 'meanstd'.
    for idx in range(len(models)):
        last_lnum = len(activations[idx]) - 1
        for lnum, layer in enumerate(activations[idx]):
            print('***********')
            stacked = torch.stack(activations[idx][layer])
            print("min of act: {}, max: {}, mean: {}".format(torch.min(stacked), torch.max(stacked), torch.mean(stacked)))
            if not prelu_acts and lnum != last_lnum:
                # Hooks capture the outputs of the linear layers, i.e. values before the
                # ReLU applied in forward(); apply it here for all but the last layer.
                print("applying relu ---------------")
                stacked = relu(stacked)
                print("after RELU: min of act: {}, max: {}, mean: {}".format(torch.min(stacked),
                                                                 torch.max(stacked),
                                                                 torch.mean(stacked)))
            if mode == 'mean':
                stacked = stacked.mean(dim=0)
            elif mode == 'std':
                stacked = stacked.std(dim=0)
            elif mode == 'meanstd':
                stacked = stacked.mean(dim=0) * stacked.std(dim=0)
            activations[idx][layer] = stacked

    return activations

In [10]:
def get_model_activations(args, models, config=None, layer_name=None, selective=False, personal_dataset = None):
    """Collect layer activations of ``models`` on single-sample MNIST training batches.

    Builds unit-batch loaders and delegates to
    ``compute_activations_across_models_v1`` with the module-level
    ``act_num_samples`` and ``activation_mode``. ``config``, ``layer_name``,
    ``selective`` and ``personal_dataset`` are accepted but not used.

    :return: the activation dictionary produced by the delegate.
    """
    unit_train_loader, _ = get_dataloader(unit_batch=True)
    return compute_activations_across_models_v1(
        args, models, unit_train_loader, act_num_samples, mode=activation_mode
    )
                

In [11]:
import torch

class GroundMetric:
    """
        Ground Metric object for Wasserstein computations:

        Turns two sets of row vectors ("coordinates") into a pairwise cost
        matrix, optionally normalised and clipped.

        ``params`` is a dict. Always required keys: "ground_metric",
        "ground_metric_normalize", "reg", "ground_metric_eff",
        "geom_ensemble_type", "normalize_wts", "clip_gm".
        Optional keys: "not_squared" (falls back to the constructor argument),
        "debug" (default False), "clip_min" (default 0),
        "activation_histograms" / "dist_normalize" (default False).
        "clip_max" is required when "clip_gm" is true, and "act_num_samples"
        when both "activation_histograms" and "dist_normalize" are true.
    """

    def __init__(self, params, not_squared = False):
        """
        :param params: configuration dict (see class docstring).
        :param not_squared: fallback used only when ``params`` has no
            "not_squared" entry; True selects the un-squared distance.
        """
        self.params = params
        self.ground_metric_type = params["ground_metric"]
        self.ground_metric_normalize = params["ground_metric_normalize"]
        self.reg = params["reg"]
        # Honour the configuration instead of hard-coding the un-squared metric.
        self.squared = not params.get("not_squared", not_squared)
        self.mem_eff = params["ground_metric_eff"]

    def isnan(self, x):
        """Element-wise NaN mask (NaN is the only value that is not equal to itself)."""
        return x != x

    def _clip(self, ground_metric_matrix):
        """Clamp the matrix in place to ``[reg * clip_min, reg * clip_max]`` and return it.

        :raises KeyError: if ``params`` has no "clip_max" entry.
        """
        debug = self.params.get("debug", False)
        clip_min = self.params.get("clip_min", 0)
        clip_max = self.params["clip_max"]

        if debug:
            print("before clipping", ground_metric_matrix.data)

        percent_clipped = (float((ground_metric_matrix >= self.reg * clip_max).long().sum().data) \
                           / ground_metric_matrix.numel()) * 100
        print("percent_clipped is (assumes clip_min = 0) ", percent_clipped)
        # will keep the M' = M/reg in range clip_min and clip_max
        ground_metric_matrix.clamp_(min=self.reg * clip_min,
                                    max=self.reg * clip_max)
        if debug:
            print("after clipping", ground_metric_matrix.data)
        return ground_metric_matrix

    def _normalize(self, ground_metric_matrix):
        """Rescale the matrix according to ``ground_metric_normalize``.

        :raises NotImplementedError: for an unknown normalisation name.
        """
        if self.ground_metric_normalize == "log":
            ground_metric_matrix = torch.log1p(ground_metric_matrix)
        elif self.ground_metric_normalize == "max":
            print("Normalizing by max of ground metric and which is ", ground_metric_matrix.max())
            ground_metric_matrix = ground_metric_matrix / ground_metric_matrix.max()
        elif self.ground_metric_normalize == "median":
            print("Normalizing by median of ground metric and which is ", ground_metric_matrix.median())
            ground_metric_matrix = ground_metric_matrix / ground_metric_matrix.median()
        elif self.ground_metric_normalize == "mean":
            print("Normalizing by mean of ground metric and which is ", ground_metric_matrix.mean())
            ground_metric_matrix = ground_metric_matrix / ground_metric_matrix.mean()
        elif self.ground_metric_normalize == "none":
            return ground_metric_matrix
        else:
            raise NotImplementedError

        return ground_metric_matrix

    def _sanity_check(self, ground_metric_matrix):
        """Assert that the matrix has no negative and no NaN entries."""
        assert not (ground_metric_matrix < 0).any()
        assert not (self.isnan(ground_metric_matrix).any())

    def _cost_matrix_xy(self, x, y, p=2, squared = True):
        # TODO: Use this to guarantee reproducibility of previous results and then move onto better way
        "Returns the matrix of $|x_i-y_j|^p$ (its square root when ``squared`` is False)."
        x_col = x.unsqueeze(1)
        y_lin = y.unsqueeze(0)
        c = torch.sum((torch.abs(x_col - y_lin)) ** p, 2)
        if not squared:
            print("dont leave off the squaring of the ground metric")
            c = c ** (1/2)
        return c


    def _pairwise_distances(self, x, y=None, squared=True):
        '''
        Source: https://discuss.pytorch.org/t/efficient-distance-matrix-computation/9065/2
        Input: x is a Nxd matrix
               y is an optional Mxd matrix
        Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
                if y is not given then use 'y=x'.
        i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
        When "activation_histograms" and "dist_normalize" are both set in params,
        the squared distances are divided by params["act_num_samples"].
        With squared=False the square root of the result is returned.
        '''
        x_norm = (x ** 2).sum(1).view(-1, 1)
        if y is not None:
            y_t = torch.transpose(y, 0, 1)
            y_norm = (y ** 2).sum(1).view(1, -1)
        else:
            y_t = torch.transpose(x, 0, 1)
            y_norm = x_norm.view(1, -1)

        dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
        # Ensure diagonal is zero if x=y
        dist = torch.clamp(dist, min=0.0)

        if self.params.get("activation_histograms", False) and self.params.get("dist_normalize", False):
            dist = dist/self.params["act_num_samples"]
            print("Divide squared distances by the num samples")

        if not squared:
            print("dont leave off the squaring of the ground metric")
            dist = dist ** (1/2)

        return dist

    def _get_euclidean(self, coordinates, other_coordinates=None):
        """Euclidean cost matrix.

        Without ``other_coordinates`` the plain (un-squared) pairwise L2 norm of
        ``coordinates`` against itself is returned; otherwise the helper picked
        by ``mem_eff`` is used and ``self.squared`` is honoured.
        """
        # TODO: Replace by torch.pdist (which is said to be much more memory efficient)

        if other_coordinates is None:
            matrix = torch.norm(
                coordinates.view(coordinates.shape[0], 1, coordinates.shape[1]) \
                - coordinates, p=2, dim=2
            )
        else:
            if self.mem_eff:
                matrix = self._pairwise_distances(coordinates, other_coordinates, squared=self.squared)
            else:
                matrix = self._cost_matrix_xy(coordinates, other_coordinates, squared = self.squared)

        return matrix

    def _normed_vecs(self, vecs, eps=1e-9):
        """Scale every row to (approximately) unit L2 norm; ``eps`` guards against division by zero."""
        norms = torch.norm(vecs, dim=-1, keepdim=True)
        print("stats of vecs are: mean {}, min {}, max {}, std {}".format(
            norms.mean(), norms.min(), norms.max(), norms.std()
        ))
        return vecs / (norms + eps)

    def _get_cosine(self, coordinates, other_coordinates=None):
        """Cosine cost matrix ``1 - cos(x_i, y_j)``, clamped below at 0."""
        if other_coordinates is None:
            matrix = coordinates / torch.norm(coordinates, dim=1, keepdim=True)
            matrix = 1 - matrix @ matrix.t()
        else:
            matrix = 1 - torch.div(
                coordinates @ other_coordinates.t(),
                torch.norm(coordinates, dim=1).view(-1, 1) @ torch.norm(other_coordinates, dim=1).view(1, -1)
            )
        return matrix.clamp_(min=0)

    def _get_angular(self, coordinates, other_coordinates=None):
        """Placeholder for the angular metric.

        :raises NotImplementedError: always; failing here is clearer than
            returning ``None`` and crashing later in the sanity check.
        """
        raise NotImplementedError("angular ground metric is not implemented")

    def get_metric(self, coordinates, other_coordinates=None):
        """Dispatch to the cost function named by ``ground_metric_type``.

        :raises KeyError: if the configured metric name is unknown.
        """
        get_metric_map = {
            'euclidean': self._get_euclidean,
            'cosine': self._get_cosine,
            'angular': self._get_angular,
        }
        return get_metric_map[self.ground_metric_type](coordinates, other_coordinates)

    def process(self, coordinates, other_coordinates=None):
        """Build the final ground-metric matrix.

        Steps: optional unit-normalisation of the inputs (only when
        "geom_ensemble_type" is 'wts' and "normalize_wts" is set), cost
        computation, normalisation, optional clipping. The matrix is
        sanity-checked after each stage.

        :param coordinates: N x d tensor.
        :param other_coordinates: optional M x d tensor; if None,
            ``coordinates`` is compared against itself.
        :return: the N x M (or N x N) cost matrix.
        """
        print('Processing the coordinates to form ground_metric')
        if self.params["geom_ensemble_type"] == 'wts' and self.params["normalize_wts"]:
            print("In weight mode: normalizing weights to unit norm")
            coordinates = self._normed_vecs(coordinates)
            if other_coordinates is not None:
                other_coordinates = self._normed_vecs(other_coordinates)

        ground_metric_matrix = self.get_metric(coordinates, other_coordinates)

        self._sanity_check(ground_metric_matrix)

        ground_metric_matrix = self._normalize(ground_metric_matrix)

        self._sanity_check(ground_metric_matrix)

        if self.params["clip_gm"]:
            ground_metric_matrix = self._clip(ground_metric_matrix)

        self._sanity_check(ground_metric_matrix)

        return ground_metric_matrix

In [12]:
def get_histogram(idx, cardinality, layer_name, activations=None, return_numpy = True, float64=False):
    """Return the probability mass assigned to the neurons of one layer.

    Without ``activations`` the result is a NumPy vector of length
    ``cardinality``: uniform (summing to 1) unless the module-level
    ``unbalanced`` flag is set, in which case it is all ones.

    With ``activations`` the result is a softmax (temperature
    ``softmax_temperature``) over ``activations[idx][<layer>]``, where
    ``<layer>`` is ``layer_name`` up to its first dot.

    :param idx: model index into ``activations``.
    :param cardinality: expected number of neurons.
    :param layer_name: parameter name such as ``'fc1.weight'``.
    :param activations: optional ``{model_idx: {layer: tensor}}`` mapping.
    :param return_numpy: return a NumPy array instead of a tensor.
    :param float64: cast the NumPy result to ``float64``.
    """
    if activations is None:
        if unbalanced:
            return np.ones(cardinality)
        # returns a uniform measure
        print("returns a uniform measure of cardinality: ", cardinality)
        return np.ones(cardinality)/cardinality

    # layer_name is like 'fc1.weight', while activations only contains 'fc1'
    print(activations[idx].keys())
    module_name = layer_name.split('.')[0]
    unnormalized_weights = activations[idx][module_name]
    print("For layer {},  shape of unnormalized weights is ".format(layer_name), unnormalized_weights.shape)
    unnormalized_weights = unnormalized_weights.squeeze()
    assert unnormalized_weights.shape[0] == cardinality

    # softmax over the activations scaled by the temperature
    probabilities = torch.softmax(unnormalized_weights / softmax_temperature, dim=0)
    if not return_numpy:
        return probabilities

    as_array = probabilities.data.cpu().numpy()
    if float64:
        return as_array.astype(np.float64)
    return as_array

In [13]:
def get_wassersteinized_layers_modularized(args, networks, activations=None, eps=1e-7, test_loader=None):
    '''
    Two neural networks that have to be averaged in geometric manner (i.e. layerwise).
    The 1st network is aligned with respect to the other via wasserstein distance.
    Also this assumes that all the layers are either fully connected or convolutional *(with no bias)*

    Layers are processed in order. For each layer the weights of model 0 are first
    multiplied by the transport map of the previous layer (input alignment), a
    ground-metric matrix between the two models' neurons is built, an optimal
    transport plan is computed (exact EMD or Sinkhorn), and the transported
    weights of model 0 are interpolated with those of model 1.

    Behaviour is driven by the module-level settings ground_metric_params,
    eval_aligned, GPU_USED, skip_last_layer, ensemble_step, importance, exact,
    reg, correction, proper_marginals and past_correction.

    :param args: not used by the active code (only by commented-out debug lines).
    :param networks: list of networks
    :param activations: accepted but NOT forwarded to get_histogram, so the
    histograms are always the uniform/all-ones ones of get_histogram.
    :param eps: small constant added to the marginals before inverting them.
    :param test_loader: not used.
    :return: list of layer weights 'wassersteinized'
    '''

    avg_aligned_layers = []
    # Transport map of the previously processed layer (None before the first layer).
    T_var = None
    previous_layer_shape = None
    ground_metric_object = GroundMetric(ground_metric_params)

    if eval_aligned:
        model0_aligned_layers = []

    # `device` is only needed by the proper_marginals branch below.
    if GPU_USED==-1:
        device = torch.device('cpu')
    else:
        device = torch.device('cuda:{}'.format(GPU_USED))


    num_layers = len(list(zip(networks[0].parameters(), networks[1].parameters())))
    for idx, ((layer0_name, fc_layer0_weight), (layer1_name, fc_layer1_weight)) in \
            enumerate(zip(networks[0].named_parameters(), networks[1].named_parameters())):

        assert fc_layer0_weight.shape == fc_layer1_weight.shape
        previous_layer_shape = fc_layer1_weight.shape

        # Number of output neurons (rows) in each model's layer.
        mu_cardinality = fc_layer0_weight.shape[0]
        nu_cardinality = fc_layer1_weight.shape[0]

        layer_shape = fc_layer0_weight.shape
        if len(layer_shape) > 2:
            is_conv = True
            # For convolutional layers, it is (#out_channels, #in_channels, height, width)
            # -> flattened to (#out_channels, #in_channels, height*width)
            fc_layer0_weight_data = fc_layer0_weight.data.view(fc_layer0_weight.shape[0], fc_layer0_weight.shape[1], -1)
            fc_layer1_weight_data = fc_layer1_weight.data.view(fc_layer1_weight.shape[0], fc_layer1_weight.shape[1], -1)
        else:
            is_conv = False
            fc_layer0_weight_data = fc_layer0_weight.data
            fc_layer1_weight_data = fc_layer1_weight.data

        if idx == 0:
            # First layer: both models see the same inputs, so no input alignment is needed.
            if is_conv:
                M = ground_metric_object.process(fc_layer0_weight_data.view(fc_layer0_weight_data.shape[0], -1),
                                fc_layer1_weight_data.view(fc_layer1_weight_data.shape[0], -1))
            else:
                M = ground_metric_object.process(fc_layer0_weight_data, fc_layer1_weight_data)
                
            aligned_wt = fc_layer0_weight_data
        else:

            print("shape of layer: model 0", fc_layer0_weight_data.shape)
            print("shape of layer: model 1", fc_layer1_weight_data.shape)
            print("shape of previous transport map", T_var.shape)

            # Align the input dimension of model 0's layer using the previous transport map.
            if is_conv:
                # Apply the same map to every spatial position of the kernel.
                T_var_conv = T_var.unsqueeze(0).repeat(fc_layer0_weight_data.shape[2], 1, 1)
                aligned_wt = torch.bmm(fc_layer0_weight_data.permute(2, 0, 1), T_var_conv).permute(1, 2, 0)

                M = ground_metric_object.process(
                    aligned_wt.contiguous().view(aligned_wt.shape[0], -1),
                    fc_layer1_weight_data.view(fc_layer1_weight_data.shape[0], -1)
                )
            else:
                if fc_layer0_weight.data.shape[1] != T_var.shape[0]:
                    # Handles the switch from convolutional layers to fc layers
                    fc_layer0_unflattened = fc_layer0_weight.data.view(fc_layer0_weight.shape[0], T_var.shape[0], -1).permute(2, 0, 1)
                    aligned_wt = torch.bmm(
                        fc_layer0_unflattened,
                        T_var.unsqueeze(0).repeat(fc_layer0_unflattened.shape[0], 1, 1)
                    ).permute(1, 2, 0)
                    aligned_wt = aligned_wt.contiguous().view(aligned_wt.shape[0], -1)
                else:
                    # print("layer data (aligned) is ", aligned_wt, fc_layer1_weight_data)
                    aligned_wt = torch.matmul(fc_layer0_weight.data, T_var)

                # NOTE(review): this passes the Parameter itself rather than
                # fc_layer1_weight_data as every other call does — confirm intended.
                M = ground_metric_object.process(aligned_wt, fc_layer1_weight)
               
            # Only reachable for idx > 0 because it sits inside this else branch.
            if skip_last_layer and idx == (num_layers - 1):
                print("Simple averaging of last layer weights. NO transport map needs to be computed")
                if ensemble_step != 0.5:
                    avg_aligned_layers.append((1 - ensemble_step) * aligned_wt +
                                          ensemble_step * fc_layer1_weight)
                else:
                    avg_aligned_layers.append((aligned_wt + fc_layer1_weight)/2)
                return avg_aligned_layers

        # Histograms over the neurons of both layers (no activations are passed here).
        if importance is None or (idx == num_layers -1):
            mu = get_histogram(0, mu_cardinality, layer0_name)
            nu = get_histogram(1, nu_cardinality, layer1_name)
        else:
            raise Exception("Deleted code")

        # POT operates on NumPy arrays.
        cpuM = M.data.cpu().numpy()
        if exact:
            T = ot.emd(mu, nu, cpuM)
        else:
            T = ot.bregman.sinkhorn(mu, nu, cpuM, reg=reg)

        if GPU_USED!=-1:
            T_var = torch.from_numpy(T).cuda(GPU_USED).float()
        else:
            T_var = torch.from_numpy(T).float()


        if correction:
            if not proper_marginals:
                # think of it as m x 1, scaling weights for m linear combinations of points in X
                # Uniform marginals 1/m; multiplying by diag(1/(1/m + eps)) scales every column of T by ~m.
                if GPU_USED != -1:
                    marginals = torch.ones(T_var.shape[0]).cuda(GPU_USED) / T_var.shape[0]
                else:
                    marginals = torch.ones(T_var.shape[0]) / T_var.shape[0]
                marginals = torch.diag(1.0/(marginals + eps))  # take inverse
                T_var = torch.matmul(T_var, marginals)
            else:
                # marginals_alpha = T_var @ torch.ones(T_var.shape[1], dtype=T_var.dtype).to(device)
                # Column sums of T.
                marginals_beta = T_var.t() @ torch.ones(T_var.shape[0], dtype=T_var.dtype).to(device)

                marginals = (1 / (marginals_beta + eps))
                # NOTE(review): the two messages below say "inverse" but print marginals_beta itself.
                print("shape of inverse marginals beta is ", marginals_beta.shape)
                print("inverse marginals beta is ", marginals_beta)

                T_var = T_var * marginals
                # i.e., how a neuron of 2nd model is constituted by the neurons of 1st model
                # this should all be ones, and number equal to number of neurons in 2nd model
                print(T_var.sum(dim=0))
                # assert (T_var.sum(dim=0) == torch.ones(T_var.shape[1], dtype=T_var.dtype).to(device)).all()

        #if args.debug:
        #    if idx == (num_layers - 1):
        #        print("there goes the last transport map: \n ", T_var)
        #    else:
        #        print("there goes the transport map at layer {}: \n ".format(idx), T_var)
        #    print("Ratio of trace to the matrix sum: ", torch.trace(T_var) / torch.sum(T_var))

        #print("Ratio of trace to the matrix sum: ", torch.trace(T_var) / torch.sum(T_var))
        #print("Here, trace is {} and matrix sum is {} ".format(torch.trace(T_var), torch.sum(T_var)))
        #setattr(args, 'trace_sum_ratio_{}'.format(layer0_name), (torch.trace(T_var) / torch.sum(T_var)).item())

        # Transport model 0's neurons (rows) with T^T, using either the input-aligned
        # weights (past_correction) or the raw weights of model 0.
        if past_correction:
            print("this is past correction for weight mode")
            print("Shape of aligned wt is ", aligned_wt.shape)
            print("Shape of fc_layer0_weight_data is ", fc_layer0_weight_data.shape)
            t_fc0_model = torch.matmul(T_var.t(), aligned_wt.contiguous().view(aligned_wt.shape[0], -1))
        else:
            t_fc0_model = torch.matmul(T_var.t(), fc_layer0_weight_data.view(fc_layer0_weight_data.shape[0], -1))

        # Average the weights of aligned first layers
        # (ensemble_step is the weight of model 1; 0.5 is the plain mean)
        if ensemble_step != 0.5:
            geometric_fc = ((1-ensemble_step) * t_fc0_model +
                            ensemble_step * fc_layer1_weight_data.view(fc_layer1_weight_data.shape[0], -1))
        else:
            geometric_fc = (t_fc0_model + fc_layer1_weight_data.view(fc_layer1_weight_data.shape[0], -1))/2
        # Restore the original 4-D kernel shape for convolutional layers.
        if is_conv and layer_shape != geometric_fc.shape:
            geometric_fc = geometric_fc.view(layer_shape)
            
        avg_aligned_layers.append(geometric_fc)

        # get the performance of the model 0 aligned with respect to the model 1
        if eval_aligned:
            raise Exception("Deleted code.")

    return avg_aligned_layers

In [14]:
def test(args, network, test_loader, log_dict, debug=False, return_loss=False, is_local=False):
    """Evaluate ``network`` on ``test_loader`` and report loss and accuracy.

    The network is switched to eval mode and evaluated without gradient
    tracking. It is expected to return log-probabilities, since
    ``F.nll_loss`` is applied to its output directly. Batches are moved to
    ``cuda:<GPU_USED>`` when the module-level ``GPU_USED`` is not -1.

    :param args: unused; kept for signature compatibility.
    :param network: model to evaluate.
    :param test_loader: loader yielding ``(data, target)`` batches; its
        ``dataset`` length is used as the normaliser.
    :param log_dict: dict of lists; the average loss is appended under
        ``'test_losses'`` (or ``'local_test_losses'`` when ``is_local``). The
        list is created if it does not exist yet.
    :param debug: print the raw network output of every batch.
    :param return_loss: also return the average loss.
    :param is_local: only changes the log message and the ``log_dict`` key.
    :return: accuracy in percent, or ``(accuracy, avg_loss)`` when ``return_loss``.
    """
    network.eval()
    test_loss = 0
    correct = 0
    if is_local:
        print("\n--------- Testing in local mode ---------")
    else:
        print("\n--------- Testing in global mode ---------")

    # Evaluation never needs gradients; skipping them saves memory and time.
    with torch.no_grad():
        for data, target in test_loader:
            if GPU_USED!=-1:
                data = data.cuda(GPU_USED)
                target = target.cuda(GPU_USED)

            output = network(data)
            if debug:
                print("output is ", output)

            # mnist models return log_softmax outputs, while cifar ones return raw values!
            # reduction='sum' is the supported spelling of the deprecated size_average=False.
            test_loss += F.nll_loss(output, target, reduction='sum').item()

            pred = output.data.max(1, keepdim=True)[1]
            # Accumulate a plain Python int rather than a tensor.
            correct += pred.eq(target.data.view_as(pred)).sum().item()

    print("size of test_loader dataset: ", len(test_loader.dataset))
    test_loss /= len(test_loader.dataset)
    if is_local:
        string_info = 'local_test'
    else:
        string_info = 'test'
    log_dict.setdefault('{}_losses'.format(string_info), []).append(test_loss)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

    ans = (float(correct) * 100.0) / len(test_loader.dataset)

    if not return_loss:
        return ans
    else:
        return ans, test_loss

In [15]:
def get_network_from_param_list(args, param_list, test_loader):
    """Build a fresh ``MlpNet`` whose weights are taken from ``param_list``.

    The network is evaluated once with its random initial weights and once
    after the weights were replaced. Entries of ``param_list`` are assigned to
    the state-dict entries by position.

    :param args: forwarded to ``test``.
    :param param_list: one tensor per model parameter, in state-dict order.
    :param test_loader: loader used for both evaluations.
    :return: ``(accuracy_after_loading, network)``.
    """
    print("using independent method")
    new_network = MlpNet()
    if GPU_USED != -1:
        new_network = new_network.cuda(GPU_USED)

    # check the test performance of the network before
    test(args, new_network, test_loader, {'test_losses': []})

    # set the weights of the new network
    num_model_params = len(list(new_network.parameters()))
    print("len of model parameters and avg aligned layers is ", num_model_params, len(param_list))
    assert num_model_params == len(param_list)

    model_state_dict = new_network.state_dict()

    print("len of model_state_dict is ", len(model_state_dict.items()))
    print("len of param_list is ", len(param_list))

    for layer_idx, key in enumerate(list(model_state_dict.keys())):
        model_state_dict[key] = param_list[layer_idx]

    new_network.load_state_dict(model_state_dict)

    # check the test performance of the network after
    acc = test(args, new_network, test_loader, {'test_losses': []})

    return acc, new_network

# Main

## Custom Model A


In [16]:
#Read the model
# Build a fresh MlpNet and restore the weights stored in model_A.pth
# (the path is relative to the notebook's working directory).
model_A = MlpNet()
model_A.load_state_dict(torch.load("model_A.pth"))

<All keys matched successfully>

## Custom Model B

In [18]:
# Build a fresh MlpNet and restore the weights stored in model_B.pth
# (the path is relative to the notebook's working directory).
model_B = MlpNet()
model_B.load_state_dict(torch.load("model_B.pth"))

<All keys matched successfully>

In [43]:
# Order matters: the merge code aligns models[0] (A) with respect to models[1] (B).
models = [model_A, model_B]
# The helper functions accept `args` in their signatures but read their settings
# from the module-level globals, so None is passed through.
args = None

## Merge using Wasserstein distance on weights

In [44]:
# Default loaders: batch_size_train / batch_size_test, shuffling enabled for both splits.
train_loader, test_loader = get_dataloader()

In [45]:
# Record per-layer activations of both models on act_num_samples single-sample batches.
# NOTE(review): get_wassersteinized_layers_modularized receives these but calls
# get_histogram without them, so they do not influence the weight-based merge.
activations = get_model_activations(None, models, config=None)

excluded
set forward hook for layer named:  fc1
set forward hook for layer named:  fc2
set forward hook for layer named:  fc3
set forward hook for layer named:  fc4
excluded
set forward hook for layer named:  fc1
set forward hook for layer named:  fc2
set forward hook for layer named:  fc3
set forward hook for layer named:  fc4


  return F.log_softmax(x)


num_personal_idx  25
***********
min of act: -2.6315741539001465, max: 2.6386382579803467, mean: -0.0048573692329227924
***********
min of act: -0.9510814547538757, max: 0.8703941106796265, mean: 0.002457394264638424
***********
min of act: -0.3619929254055023, max: 0.38994652032852173, mean: -0.0010728973429650068
***********
min of act: -0.09826940298080444, max: 0.1475382149219513, mean: 0.002620210638269782
***********
min of act: -3.4334421157836914, max: 7.080609321594238, mean: 0.5755436420440674
***********
min of act: -2.987900733947754, max: 10.621912956237793, mean: 0.8549277186393738
***********
min of act: -4.2195143699646, max: 8.599244117736816, mean: 1.068608045578003
***********
min of act: -10.92115592956543, max: 13.338738441467285, mean: 0.1271066814661026


In [46]:
def merge_models_wts(args, models, train_loader, test_loader, activations):
    """Fuse ``models`` layer by layer via optimal transport and evaluate the result.

    ``train_loader`` is accepted but not used.

    :return: ``(accuracy, merged_network)`` as produced by
        ``get_network_from_param_list``.
    """
    fused_layers = get_wassersteinized_layers_modularized(
        args, models, activations, test_loader=test_loader
    )
    merged = get_network_from_param_list(args, fused_layers, test_loader)
    return merged

In [47]:
# Run the OT-based merge; returns the merged network and its test accuracy in percent.
geometric_acc, geometric_model = merge_models_wts(args, models, train_loader, test_loader, activations)

Processing the coordinates to form ground_metric
dont leave off the squaring of the ground metric
returns a uniform measure of cardinality:  400
returns a uniform measure of cardinality:  400
this is past correction for weight mode
Shape of aligned wt is  torch.Size([400, 784])
Shape of fc_layer0_weight_data is  torch.Size([400, 784])
shape of layer: model 0 torch.Size([200, 400])
shape of layer: model 1 torch.Size([200, 400])
shape of previous transport map torch.Size([400, 400])
Processing the coordinates to form ground_metric
dont leave off the squaring of the ground metric
returns a uniform measure of cardinality:  200
returns a uniform measure of cardinality:  200
this is past correction for weight mode
Shape of aligned wt is  torch.Size([200, 400])
Shape of fc_layer0_weight_data is  torch.Size([200, 400])
shape of layer: model 0 torch.Size([100, 200])
shape of layer: model 1 torch.Size([100, 200])
shape of previous transport map torch.Size([200, 200])
Processing the coordinates t

  return F.log_softmax(x)


size of test_loader dataset:  10000

Test set: Avg. loss: 2.3056, Accuracy: 686/10000 (7%)

len of model parameters and avg aligned layers is  4 4
len of model_state_dict is  4
len of param_list is  4

--------- Testing in global mode ---------
size of test_loader dataset:  10000

Test set: Avg. loss: 1.9742, Accuracy: 8097/10000 (81%)



## Vanilla Ensembling


In [48]:
def get_avg_parameters(networks, weights=None):
    """Combine corresponding parameters of several networks.

    :param networks: iterable of modules exposing ``parameters()``; parameters
        are matched by position.
    :param weights: optional per-network coefficients. When given, the result
        is the weighted sum ``sum_i weights[i] * param_i``; otherwise the plain
        mean over the networks.
    :return: list with one combined tensor per parameter position.
    """
    def combine(par_group):
        if weights is None:
            return torch.mean(torch.stack(par_group), dim=0)
        scaled = [par * weights[pos] for pos, par in enumerate(par_group)]
        return torch.sum(torch.stack(scaled), dim=0)

    per_network_params = [net.parameters() for net in networks]
    return [combine(par_group) for par_group in zip(*per_network_params)]

In [49]:
def naive_ensembling(args, networks, test_loader):
    """Average the weights of ``networks`` position-wise and evaluate the result.

    Model 0 is weighted by ``1 - ensemble_step`` and model 1 by
    ``ensemble_step``. A fresh ``MlpNet`` is evaluated once before and once
    after receiving the averaged weights.

    :return: ``(accuracy, ensemble_network)``, or ``(-1, None)`` when the
        module-level ``width_ratio`` is not 1.
    """
    if width_ratio != 1:
        print("Unfortunately naive ensembling can't work if models are not of same shape!")
        return -1, None

    # simply average the weights in networks
    avg_pars = get_avg_parameters(networks, [(1-ensemble_step), ensemble_step])

    ensemble_network = MlpNet()
    # put on GPU
    if GPU_USED!=-1:
        ensemble_network = ensemble_network.cuda(GPU_USED)

    # check the test performance of the method before
    test(args, ensemble_network, test_loader, {'test_losses': []})

    # set the weights of the ensembled network
    # (state_dict tensors share storage with the parameters, so copy_ updates the model)
    target_state = ensemble_network.state_dict()
    for idx, name in enumerate(target_state):
        target_state[name].copy_(avg_pars[idx].data)

    # check the test performance of the method after ensembling
    accuracy = test(args, ensemble_network, test_loader, {'test_losses': []})
    return accuracy, ensemble_network

In [50]:
# Baseline: plain weight averaging of the two models, without any alignment.
vanilla_acc, vanilla_ensemble_model = naive_ensembling(args, models, test_loader)


--------- Testing in global mode ---------


  return F.log_softmax(x)


size of test_loader dataset:  10000

Test set: Avg. loss: 2.3044, Accuracy: 810/10000 (8%)


--------- Testing in global mode ---------
size of test_loader dataset:  10000

Test set: Avg. loss: 2.0564, Accuracy: 5212/10000 (52%)



# Results

In [51]:
# Test accuracy (in percent) of the OT-merged model.
geometric_acc

80.97

In [52]:
# Test accuracy (in percent) of the naively averaged model.
vanilla_acc

52.12