In [1]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import models, datasets, transforms
from scipy.stats import wasserstein_distance

# Representation shift

Example code on how to calculate Representation Shift using Wasserstein distance. 

In this example, we use a model pre-trained on ImageNet and calculate the representation shift for in-distribution data (the train and validation splits of the TinyImageNet dataset, respectively) and for out-of-distribution data (TinyImageNet vs. CIFAR10).

Requirements: download Tiny ImageNet data, found here: https://www.kaggle.com/mikewallace250/tiny-imagenet-challenge

## Prepare data and model

In [2]:
# Root directory handed to the CIFAR10 dataset (downloaded there if missing).
path_to_store_downloaded_data = 'data'
# Root of the Tiny ImageNet download; its `train` and `val` sub-folders are read.
path_to_data = 'data/TinyImageNet'

We use a model pretrained on ImageNet and replace its final fully connected layer with an identity module, so that the model outputs the activations of the penultimate layer.

In [3]:
class Identity(nn.Module):
    """Pass-through module: `forward` hands its input back unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # No computation at all; the very same object is returned.
        return x

# ImageNet-pretrained ResNet-18 whose `fc` head is swapped for a pass-through,
# so a forward pass returns whatever used to be fed into `fc`.
# NOTE(review): newer torchvision releases deprecate `pretrained=` in favour of
# `weights=`; confirm the installed version before switching.
model = models.resnet18(pretrained=True)
model.fc = Identity()

In [4]:
# Shared preprocessing for every dataset: Resize(128), then conversion to a tensor.
# NOTE(review): no mean/std normalisation is applied; ImageNet-pretrained
# torchvision weights are usually paired with one -- confirm this is intentional.
trans = transforms.Compose([transforms.Resize(128), transforms.ToTensor()])

# Tiny ImageNet is read from local folders; only CIFAR10 (test split) is
# downloaded when it is not already present.
tiny_imagenet_data = datasets.ImageFolder(f'{path_to_data}/train', transform=trans)
tiny_imagenet_data_val = datasets.ImageFolder(f'{path_to_data}/val', transform=trans)
cifar_data = datasets.CIFAR10(
    root=path_to_store_downloaded_data,
    train=False,
    transform=trans,
    download=True,
)

# One loader per dataset, each yielding a single image per batch.
reference_dataloader, indist_dataloader, outdist_dataloader = (
    torch.utils.data.DataLoader(dataset, batch_size=1)
    for dataset in (tiny_imagenet_data, tiny_imagenet_data_val, cifar_data)
)

Files already downloaded and verified


## Calculate Representation Shift

In [9]:
def extract_activations(model, dataloader, max_samples=1000):
    """
    Iterate through (a subset of) a dataset and store the model's activations.

    The model is switched to eval mode and run without gradient tracking.
    At most `max_samples` individual samples are collected, regardless of the
    batch size of `dataloader`.

    Parameters:
        model (torch.nn.Module): the model to evaluate, outputs a representation of size D per input image
        dataloader (iterable): yields batches whose first element is the input tensor
            (e.g. a torch.utils.data.DataLoader of (image, label) pairs); any batch size
        max_samples (int): maximum number of samples to evaluate, N

    Returns:
        activations (numpy.array): Array of size NxD (empty array if nothing was collected)
    """
    model.eval()

    # Feed inputs on the device the model's parameters live on. A model
    # without parameters gives no hint, so inputs are left where they are.
    first_param = next(model.parameters(), None)
    device = None if first_param is None else first_param.device

    # Total shown in the progress line. Not every iterable exposes a sized
    # `.dataset`, in which case only the requested maximum is known.
    try:
        total = min(len(dataloader.dataset), max_samples)
    except (AttributeError, TypeError):
        total = max_samples

    activations = []
    collected = 0
    with torch.no_grad():
        for batch in dataloader:
            remaining = max_samples - collected
            if remaining <= 0:
                break
            # Trim the last batch so exactly `max_samples` rows are kept.
            inputs = batch[0][:remaining]
            if device is not None:
                inputs = inputs.to(device)
            out = model(inputs)
            # `.cpu()` is a no-op on CPU and required before `.numpy()` elsewhere.
            activations.extend(out.cpu().numpy())
            collected += len(out)
            print(f'\r{collected}/{total}', end="")

    return np.asarray(activations)

def representation_shift(act_ref, act_test):
    """
    Calculate representation shift using Wasserstein distance.

    For every channel (column) the 1-D Wasserstein distance between the
    reference and the test activations is computed; the result is the mean
    of those per-channel distances.

    Parameters:
        act_ref (numpy.array): Reference activations, array of size NxD
        act_test (numpy.array): Test activations, array of size MxD (M may differ from N)

    Returns:
        representation_shift (float): Mean Wasserstein distance over all channels (D)

    Raises:
        ValueError: if either input is not 2-D or the channel counts differ
    """
    act_ref = np.asarray(act_ref)
    act_test = np.asarray(act_test)

    if act_ref.ndim != 2 or act_test.ndim != 2:
        raise ValueError(
            f'Expected 2-D activation arrays, got shapes {act_ref.shape} and {act_test.shape}'
        )
    # Without this check, extra channels in act_test would be silently ignored.
    if act_ref.shape[1] != act_test.shape[1]:
        raise ValueError(
            f'Channel mismatch: act_ref has {act_ref.shape[1]} channels, '
            f'act_test has {act_test.shape[1]}'
        )

    wass_dist = [
        wasserstein_distance(act_ref[:, channel], act_test[:, channel])
        for channel in range(act_ref.shape[1])
    ]
    return np.asarray(wass_dist).mean()

In [6]:
# Run the model over a subset of each dataset (reference first, then the
# in-distribution and out-of-distribution sets) and keep the activations.
activations_ref, activations_indist, activations_outdist = (
    extract_activations(model, loader)
    for loader in (reference_dataloader, indist_dataloader, outdist_dataloader)
)


999/1000

In [11]:
# Both shifts are measured against the same reference activations.
wass_dist_indist = representation_shift(activations_ref, activations_indist)
wass_dist_outdist = representation_shift(activations_ref, activations_outdist)

# Report the in-distribution shift first, then the out-of-distribution one.
print('Representation shift, in-distribution:', wass_dist_indist)
print('Representation shift, out-of-distribution:', wass_dist_outdist)

Representation shift, in-distribution: 0.1595269632828995
Representation shift, out-of-distribution: 0.2402791791188115
