In [13]:
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from functools import reduce
from matplotlib import colors
from matplotlib.ticker import MaxNLocator

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import Sampler

# plotting params
%matplotlib inline
plt.rcParams['font.size'] = 10
plt.rcParams['axes.labelsize'] = 10
plt.rcParams['axes.titlesize'] = 10
plt.rcParams['xtick.labelsize'] = 8
plt.rcParams['ytick.labelsize'] = 8
plt.rcParams['legend.fontsize'] = 10
plt.rcParams['figure.titlesize'] = 12
plt.rcParams['figure.figsize'] = (13.0, 6.0)
sns.set_style("white")

data_dir = './data/'
plot_dir = './imgs/'
dump_dir = './dump/'

## Setup

In [2]:
# ensuring reproducibility
SEED = 42

# Seed every RNG the notebook draws from: torch (weight initialisation)
# and numpy (the epoch permutations are built with np.random.shuffle).
np.random.seed(SEED)
torch.manual_seed(SEED)

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels;
# both settings only have an effect when running on a CUDA device.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
GPU = False

# Pick the compute device; the extra DataLoader options (one worker,
# pinned memory) are only passed along when running on CUDA.
if GPU:
    device = torch.device("cuda")
    kwargs = {'num_workers': 1, 'pin_memory': True}
else:
    device = torch.device("cpu")
    kwargs = {}

## Data Loader

We need to create a special dataloader for the experiment with shuffling. This is necessary because we need to keep track of each sample and shuffling loses that information.

To solve this, we can:

- create permutations of a list of numbers from 0 to 59,999 (the number of images in MNIST)
- create a sampler class that takes a list and iterates over it sequentially
- at each epoch, create a dataloader with a sampler that gets fed the precomputed permutations

In [4]:
class LinearSampler(Sampler):
    """Sampler that yields a precomputed sequence of indices in the given order."""

    def __init__(self, idx):
        # idx: the sequence of dataset indices to visit, in visiting order
        self.idx = idx

    def __len__(self):
        """Number of indices that one pass over the sampler produces."""
        return len(self.idx)

    def __iter__(self):
        """Walk through the stored indices front to back."""
        yield from self.idx

In [5]:
def get_data_loader(data_dir, batch_size, permutation=None, num_workers=3, pin_memory=False):
    """Build a DataLoader over the MNIST training split.

    Args:
        data_dir: root directory for the dataset (downloaded there if missing).
        batch_size: number of samples per batch.
        permutation: optional sequence of dataset indices; when given, samples
            are visited in exactly this order through a `LinearSampler`.
            When `None`, the loader iterates the dataset sequentially.
        num_workers: forwarded to `DataLoader`.
        pin_memory: forwarded to `DataLoader`.

    Returns:
        A `torch.utils.data.DataLoader` with shuffling disabled.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    mnist_train = datasets.MNIST(
        root=data_dir, train=True, download=True, transform=preprocessing
    )

    # The visiting order is fully controlled by the sampler (if any), so
    # the loader's own shuffling stays off.
    sampler = LinearSampler(permutation) if permutation is not None else None

    return torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        sampler=sampler,
    )

## Model

In [6]:
class SmallConv(nn.Module):
    """Small CNN: two 5x5 conv layers, then two fully connected layers.

    The forward pass returns log-probabilities over 10 outputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Map a batch of single-channel images to per-class log-probabilities."""
        # each conv stage: convolution -> 2x2 max pooling -> ReLU
        stage1 = F.relu(F.max_pool2d(self.conv1(x), 2))
        stage2 = F.relu(F.max_pool2d(self.conv2(stage1), 2))
        # for 28x28 inputs the feature map is 20 x 4 x 4 = 320 values per sample
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

## Utility Functions

In [8]:
def accuracy(predicted, ground_truth):
    """Return the percentage of rows whose highest-scoring column matches the label.

    Args:
        predicted: 2-D tensor of per-class scores, one row per sample.
        ground_truth: tensor of target class indices, one per sample.

    Returns:
        Accuracy in percent as a Python float.
    """
    _, predicted_labels = torch.max(predicted, 1)
    num_correct = (predicted_labels == ground_truth).sum().double()
    return (100 * (num_correct / len(ground_truth))).item()

def loss_vs_gradnorm(list_stats):
    """Flatten nested records and return their second value ordered by their first.

    Each record is indexed as `record[1][0]` (sort key) and `record[1][1]`
    (returned via `.item()`).
    NOTE(review): judging by the function name, the sort key is presumably a
    gradient norm and the returned value a loss; confirm against the producer
    of `list_stats`.

    Args:
        list_stats: list of lists of records.

    Returns:
        List of Python scalars, ordered by ascending sort key.
    """
    records = []
    for group in list_stats:
        records.extend(group)
    # sorted() is stable, so records with equal keys keep their original order
    ordered = sorted(records, key=lambda record: record[1][0])
    return [record[1][1].item() for record in ordered]

def train(model, device, train_loader, optimizer, epoch):
    """Run one training epoch and record the per-sample loss of every batch.

    Args:
        model: network to train; put into training mode here.
        device: device the batches are moved to before the forward pass.
        train_loader: iterable of `(data, target)` batches; its `dataset`
            attribute is only used for progress reporting.
        optimizer: optimizer stepping the model parameters.
        epoch: epoch number, only used in the progress message.

    Returns:
        A list with one entry per batch. Each entry is a list of
        `[position, loss]` records, where `position` is the running position
        of the sample within this epoch (0, 1, 2, ...) and `loss` is a
        detached 0-d CPU tensor holding that sample's loss before the update.
    """
    model.train()

    epoch_stats = []
    # Running count of samples already processed. Using `batch_idx * len(data)`
    # instead would give wrong positions for a final batch that is smaller
    # than the others.
    samples_seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # forward pass
        output = model(data)
        acc = accuracy(output, target)

        # per-sample losses (no reduction)
        losses = F.nll_loss(output, target, reduction='none')

        # Record detached CPU copies so the stored values do not keep the
        # autograd graph of every batch alive for the whole epoch.
        recorded = losses.detach().cpu()
        epoch_stats.append(
            [[samples_seen + offset, value] for offset, value in enumerate(recorded)]
        )
        samples_seen += len(data)

        # take average loss
        loss = losses.mean()

        # backwards pass
        loss.backward()
        optimizer.step()

        if batch_idx % 25 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.2f}%'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), acc))

    return epoch_stats

## Without Shuffling

In [12]:
# optimisation settings for the run without shuffling
batch_size = 64
num_epochs = 5
mom = 0.99  # momentum handed to the SGD optimizer
learning_rate = 0.001

In [None]:
# re-seed so the weight initialisation below is identical across runs
torch.manual_seed(SEED)

model = SmallConv().to(device)

# relu init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in')

# define optimizer
# (referenced through `torch.optim` because the notebook never imports a bare `optim` name)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=mom)

# no permutation: every epoch visits the samples in the same sequential order
train_loader = get_data_loader(data_dir, batch_size, None, **kwargs)

stats_no_shuffling = []
for epoch in range(1, num_epochs+1):
    stats_no_shuffling.append(train(model, device, train_loader, optimizer, epoch))

# context manager guarantees the file is flushed and closed
with open("./no_shuffling.pkl", "wb") as dump_file:
    pickle.dump(stats_no_shuffling, dump_file)

## With Shuffling

In [11]:
# optimisation settings for the run with shuffling
batch_size = 64
num_epochs = 5
mom = 0.99  # momentum handed to the SGD optimizer
learning_rate = 0.001

In [10]:
# create permutations
np.random.seed(SEED)

# the first epoch visits the samples in their natural order
permutations = [list(np.arange(60000))]

# every later epoch gets a fresh shuffle of the previous epoch's order;
# a copy is stored because the working list is shuffled in place
x = list(np.arange(60000))
remaining_epochs = num_epochs - 1
for _ in range(remaining_epochs):
    np.random.shuffle(x)
    permutations.append(x.copy())

In [None]:
# re-seed so the weight initialisation below is identical across runs
torch.manual_seed(SEED)

model = SmallConv().to(device)

# relu init
for m in model.modules():
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_in')

# define optimizer
# (referenced through `torch.optim` because the notebook never imports a bare `optim` name)
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=mom)

stats_with_shuffling = []
for epoch in range(1, num_epochs+1):
    # a new loader per epoch so that each epoch follows its own precomputed permutation
    train_loader = get_data_loader(data_dir, batch_size, permutations[epoch-1], **kwargs)
    stats_with_shuffling.append(train(model, device, train_loader, optimizer, epoch))

# context manager guarantees the file is flushed and closed
with open("./with_shuffling.pkl", "wb") as dump_file:
    pickle.dump(stats_with_shuffling, dump_file)

## Plot Densities

In [None]:
# NOTE: pickle can execute arbitrary code while loading; only open dump
# files that were produced by a trusted run of this notebook.
with open(dump_dir + "no_shuffling_mnist.p", "rb") as dump_file:
    stats = pickle.load(dump_file)
# permutations = pickle.load(open(dump_dir + "permutations_mnist.p", "rb"))
# without shuffling every epoch uses the identity order
permutations = [np.arange(60000)]*len(stats)

In [None]:
# grab losses and grads over all epochs
# NOTE(review): `losses_grads` is not defined anywhere in the visible part of
# this notebook; it is expected to return a (grads, losses) pair per epoch.
grads, losses = [], []

for epoch_stats in stats:
    epoch_grads, epoch_losses = losses_grads(epoch_stats)
    # extend keeps both lists flat across epochs
    grads.extend(epoch_grads)
    losses.extend(epoch_losses)

## Sorted Indices Quantile Distribution

Track the evolution of sorted indices across time. Use 10 quantiles for example, bin indices into quantiles, and examine mean and variance of quantile for each sample. This will help me get an idea of how stable the importance of a sample is. If it is important in the initial epochs, does this hold true in future epochs or is this random?

In [None]:
# remap the indices based on the `permutations` list
fixed = []
for epoch_no, epoch_stats in enumerate(stats):
    # merge the per-batch lists of this epoch into one list of records
    flattened = [record for batch in epoch_stats for record in batch]
    # the record at position p gets the p-th index of this epoch's permutation
    # (records are updated in place)
    for position, record in enumerate(flattened):
        record[0] = permutations[epoch_no][position]
    fixed.append(flattened)

In [None]:
# resort in increasing index order
for epoch_no, records in enumerate(fixed):
    # after this step, position k in every epoch list has the k-th smallest index
    fixed[epoch_no] = sorted(records, key=lambda record: record[0])

In [None]:
def percentage_split(seq, percentages):
    """Split `seq` into consecutive slices whose sizes follow `percentages`.

    Args:
        seq: sliceable sequence to split.
        percentages: fractions that must sum to 1 (checked with `np.allclose`).

    Returns:
        A list with one slice of `seq` per entry of `percentages`. Together
        the slices cover every element of `seq` exactly once, in order.

    Raises:
        AssertionError: if the fractions do not sum to 1.
    """
    cdf = np.cumsum(percentages)
    assert np.allclose(cdf[-1], 1.0)

    total = len(seq)
    # Round instead of truncating: accumulated fractions such as
    # 0.7999999999999999 or 0.9999999999999999 would otherwise land one
    # element short, giving uneven bins and dropping the final element.
    stops = [int(round(float(fraction) * total)) for fraction in cdf]
    # the fractions sum to 1, so the last slice always ends at the end of `seq`
    stops[-1] = total
    return [seq[start:stop] for start, stop in zip([0] + stops[:-1], stops)]

def idx_evolution(all_epochs):
    """Rank every epoch's records by loss and cut the ranking into ten equal parts.

    Each record is read as `record[1][0, 1]` to obtain its sort value.
    NOTE(review): this presumably expects `record[1]` to be a 2-D
    array/tensor with the loss at position [0, 1]; confirm against the
    format of the loaded statistics.

    Args:
        all_epochs: list of epochs, each a list of records.

    Returns:
        One entry per epoch: the list of ten slices produced by
        `percentage_split`, holding positions into that epoch's list,
        ordered from highest to lowest sort value.
    """
    decile_fractions = [0.1] * 10

    def rank_positions(records):
        # positions into `records`, sorted in descending order based on loss
        return sorted(
            range(len(records)),
            key=lambda position: records[position][1][0, 1],
            reverse=True,
        )

    return [
        percentage_split(rank_positions(records), decile_fractions)
        for records in all_epochs
    ]

In [None]:
# get percentile splits for all 5 epochs
percentile_splits = idx_evolution(fixed)
num_quantiles = len(percentile_splits[0])

# for each quantile: share of its members that stay in that quantile in
# every compared epoch
percent_matches = []
for quantile in range(num_quantiles):
    # decide over how many epochs to compare
    # (the slice starts at 1, so the first epoch is left out of the comparison)
    percentile_all = [epoch_splits[quantile] for epoch_splits in percentile_splits[1:]]
    matching = reduce(np.intersect1d, percentile_all)
    percent_matches.append(100 * len(matching) / len(percentile_all[0]))

for perc in percent_matches:
    print("{0:.2f}%".format(perc), end=", ")

In [None]:
# sns.reset_orig()
fig, ax = plt.subplots(figsize=(3,4))

# one bar per quantile, numbered from 1
quantile_positions = range(1, len(percent_matches)+1)
ax.bar(quantile_positions, percent_matches, width=0.9, color='r')
ax.set(
    xlabel='Quantile',
    ylabel='Percent Match Across Epochs',
    title='All Epochs (No Shuffling)',
)
# integer tick labels on the percentage axis
ax.yaxis.set_major_locator(MaxNLocator(integer=True))

plt.tight_layout()
plt.savefig(plot_dir + "percent_match_no_shuffling_all.png", format="png", dpi=300)

In [None]:
mnist_size = 60000
cifar_size = 50000

# the first epoch uses the identity order
permutations = []
permutations.append(list(np.arange(mnist_size)))
x = list(np.arange(mnist_size))
# use the notebook-wide `SEED` constant; a lowercase `seed` is never defined
np.random.seed(SEED)
# four more epochs, each a fresh shuffle of the previous order
# (a copy is stored because `x` is shuffled in place)
for i in range(4):
    np.random.shuffle(x)
    permutations.append(x.copy())

In [None]:
torch.manual_seed(0)

# Uses the helpers defined in this notebook: `get_data_loader` builds the
# loader for a given permutation and `train` takes
# (model, device, train_loader, optimizer, epoch).
# NOTE(review): this continues with whatever `model` / `optimizer` currently
# live in the kernel; re-create them first if a fresh run is intended.
stats = []
for epoch in range(1, num_epochs+1):
    train_loader = get_data_loader(
        data_dir, batch_size, permutations[epoch-1],
        **kwargs
    )
    stats.append(train(model, device, train_loader, optimizer, epoch))

In [None]:
# pickle.dump(permutations, open(dump_dir + "permutations_mnist.p", "wb"))