In [1]:
import time
import pickle
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

from functools import reduce
from matplotlib import colors
from matplotlib.ticker import MaxNLocator

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import Sampler, WeightedRandomSampler

# plotting params
%matplotlib inline
# font sizes and default figure size, applied in a single update
plt.rcParams.update({
    'font.size': 10,
    'axes.labelsize': 10,
    'axes.titlesize': 10,
    'xtick.labelsize': 8,
    'ytick.labelsize': 8,
    'legend.fontsize': 10,
    'figure.titlesize': 12,
    'figure.figsize': (13.0, 6.0),
})
sns.set_style("white")

# data_dir is handed to the MNIST dataset as its root folder;
# plot_dir and dump_dir are not referenced by the cells visible here
data_dir = './data/'
plot_dir = './imgs/'
dump_dir = './dump/'

In [2]:
# ensuring reproducibility
SEED = 42
torch.manual_seed(SEED)
# NumPy's global RNG drives np.random.choice in train_steady_state, so it has
# to be seeded as well or the resampling differs from run to run.
np.random.seed(SEED)
# Turn off the cuDNN autotuner and ask for deterministic kernels; both only
# matter when running on a GPU and are harmless on CPU.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# set to True to place the model and batches on CUDA
GPU = False

device = torch.device("cuda") if GPU else torch.device("cpu")
# extra keyword arguments forwarded to the loader factories; empty on CPU so
# that their own defaults apply
kwargs = dict(num_workers=1, pin_memory=True) if GPU else dict()

## Data Loader

In [4]:
class LinearSampler(Sampler):
    """Sampler that replays a caller-supplied sequence of indices in order.

    Args:
        idx: any sized iterable of indices; it is stored as given (not copied)
            and iterated from the start on every pass.
    """

    def __init__(self, idx):
        self.idx = idx

    def __len__(self):
        # number of indices produced per pass
        return len(self.idx)

    def __iter__(self):
        # fresh iterator each time, so the sampler can be reused across epochs
        index_iterator = iter(self.idx)
        return index_iterator

In [5]:
def get_data_loader(data_dir, batch_size, permutation=None, num_workers=3, pin_memory=False):
    """Build a DataLoader over the MNIST training split.

    Images are converted to tensors and normalized with mean 0.1307 and
    std 0.3081. The dataset is downloaded into `data_dir` if it is missing.

    Args:
        data_dir: root directory passed to the MNIST dataset.
        batch_size: number of samples per batch.
        permutation: optional sequence of dataset indices; when given, samples
            are visited in exactly this order through a LinearSampler. When
            None, no sampler is set and shuffling is off.
        num_workers: worker count forwarded to the DataLoader.
        pin_memory: forwarded to the DataLoader.

    Returns:
        The configured DataLoader.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    train_set = MNIST(root=data_dir, train=True, download=True, transform=preprocessing)

    # a custom visiting order is imposed only when a permutation is supplied
    sampler = LinearSampler(permutation) if permutation is not None else None

    return DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        sampler=sampler,
    )

def get_weighted_loader(data_dir, batch_size, weights, num_workers=3, pin_memory=False):
    """Build a DataLoader that samples the MNIST training split by weight.

    Uses the same preprocessing as the other loaders (tensor conversion plus
    normalization with mean 0.1307, std 0.3081). Each epoch draws
    `len(weights)` samples with replacement.

    NOTE(review): WeightedRandomSampler yields positions into `weights`, so
    `weights[i]` is applied to dataset item `i`. If `weights` is shorter than
    the dataset, the trailing items can never be drawn — confirm callers
    pass weights aligned with dataset indices.

    Args:
        data_dir: root directory passed to the MNIST dataset.
        batch_size: number of samples per batch.
        weights: sequence of non-negative sampling weights.
        num_workers: worker count forwarded to the DataLoader.
        pin_memory: forwarded to the DataLoader.

    Returns:
        The configured DataLoader.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    train_set = MNIST(root=data_dir, train=True, download=True, transform=preprocessing)

    weighted_sampler = WeightedRandomSampler(
        weights, num_samples=len(weights), replacement=True
    )

    return DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        sampler=weighted_sampler,
    )

def get_test_loader(data_dir, batch_size, num_workers=3, pin_memory=False):
    """Build a sequential DataLoader over the MNIST test split.

    Applies tensor conversion and normalization with mean 0.1307 and
    std 0.3081. The dataset is downloaded into `data_dir` if it is missing.

    Args:
        data_dir: root directory passed to the MNIST dataset.
        batch_size: number of samples per batch.
        num_workers: worker count forwarded to the DataLoader.
        pin_memory: forwarded to the DataLoader.

    Returns:
        The configured DataLoader (no shuffling, no sampler).
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])
    test_set = MNIST(root=data_dir, train=False, download=True, transform=preprocessing)
    return DataLoader(
        test_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

## Model

In [6]:
class SmallConv(nn.Module):
    """Two-conv, two-linear network producing log-probabilities over 10 classes.

    The flatten step reshapes to 320 features per sample, which is what a
    1x28x28 input yields after the two conv/pool stages (20 channels x 4 x 4).
    """

    def __init__(self):
        super().__init__()
        # feature extractor: 1 -> 10 -> 20 channels, 5x5 kernels
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # classifier head: 320 -> 50 -> 10
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Return per-class log-probabilities (log-softmax over dim 1)."""
        # stage 1: conv -> 2x2 max-pool -> ReLU
        stage1 = self.conv1(x)
        stage1 = F.max_pool2d(stage1, 2)
        stage1 = F.relu(stage1)

        # stage 2: same pattern on the second convolution
        stage2 = self.conv2(stage1)
        stage2 = F.max_pool2d(stage2, 2)
        stage2 = F.relu(stage2)

        # flatten to 320 features per row, then the two linear layers
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

## Utility Functions

In [7]:
def accuracy(predicted, ground_truth):
    """Return top-1 accuracy in percent as a Python float.

    Args:
        predicted: 2-D tensor of per-class scores; the predicted label of each
            row is the index of its maximum along dim 1.
        ground_truth: tensor of target labels, one per row of `predicted`.

    Returns:
        100 * (number of matching labels / len(ground_truth)).
    """
    _, predicted_labels = torch.max(predicted, 1)
    # count matches in double precision before forming the ratio
    n_correct = predicted_labels.eq(ground_truth).sum().double()
    percent = 100 * (n_correct / len(ground_truth))
    return percent.item()

def train_transient(model, device, train_loader, optimizer, epoch, track=False):
    """Train `model` for one pass over `train_loader`, optionally logging losses.

    Every batch is moved to `device`, scored with the per-sample negative
    log-likelihood, and the mean of those losses is back-propagated. Progress
    is printed every 25 batches.

    Args:
        model: network whose forward pass returns log-probabilities.
        device: torch device the batches are moved to.
        train_loader: DataLoader providing (data, target) batches.
        optimizer: optimizer stepping the model's parameters.
        epoch: epoch number, used only in the progress message.
        track: when True, record the loss of every sample.

    Returns:
        None when `track` is False. Otherwise a list with one entry per batch,
        each entry being a list of `[position, loss]` pairs, where
        `position = batch_idx * train_loader.batch_size + offset_in_batch`.
        This is the position in the loader's output stream; it equals the
        dataset index only when the loader visits the dataset sequentially.
    """
    model.train()
    n_batches = len(train_loader)
    n_samples = len(train_loader.dataset)
    epoch_stats = []

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        output = model(data)
        acc = accuracy(output, target)
        # keep per-sample losses so they can be logged before reduction
        losses = F.nll_loss(output, target, reduction='none')

        if track:
            first_position = batch_idx * train_loader.batch_size
            epoch_stats.append([
                [first_position + offset, value]
                for offset, value in enumerate(losses.detach().tolist())
            ])

        loss = losses.mean()
        loss.backward()
        optimizer.step()

        if batch_idx % 25 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.2f}%'.format(
                epoch, batch_idx * len(data), n_samples,
                100. * batch_idx / n_batches, loss.item(), acc))

    return epoch_stats if track else None

def train_steady_state(model, device, train_loader, optimizer, epoch):
    """Train for one pass using loss-proportional resampling inside each batch.

    For every batch:
      1. A gradient-free scoring pass computes the per-sample loss.
      2. `len(batch)` positions are drawn with replacement, with probability
         proportional to those losses (NumPy's global RNG).
      3. The resampled batch is passed through the model again and each loss is
         multiplied by the importance weight `(1 / batch_len) / p_i` before
         averaging, so that the update stays an unbiased estimate of the
         gradient of the mean batch loss.

    Args:
        model: network whose forward pass returns log-probabilities.
        device: torch device the batches are moved to.
        train_loader: DataLoader providing (data, target) batches.
        optimizer: optimizer stepping the model's parameters.
        epoch: epoch number, used only in the progress message.

    Returns:
        List of the drawn positions, each offset by
        `batch_idx * train_loader.batch_size` (positions in the loader's
        output stream).
    """
    model.train()

    seen = []
    for batch_idx, (data, target) in enumerate(train_loader):
        # the model lives on `device`, so the batch must be moved there too
        data, target = data.to(device), target.to(device)
        batch_len = len(data)

        # Scoring pass. The sampling probabilities must be constants with
        # respect to the parameters: if they carry gradient, loss_i * weight_i
        # reduces to the plain batch-mean loss and the resampling has no
        # effect on the update.
        with torch.no_grad():
            losses = F.nll_loss(model(data), target, reduction='none')
            total = losses.sum()
            if total > 0:
                probas = losses / total
            else:
                # every loss is zero: fall back to uniform sampling instead of
                # handing NaN probabilities to np.random.choice
                probas = torch.full_like(losses, 1. / batch_len)

        # compute importance weights
        idxs = np.random.choice(batch_len, batch_len, p=probas.cpu().numpy())
        seen.extend(list(idxs + batch_idx * train_loader.batch_size))

        idxs = torch.from_numpy(idxs).long().to(device)
        new_pdf = probas[idxs]
        old_pdf = 1. / batch_len
        weight = old_pdf / new_pdf

        # resample
        data_r = data[idxs]
        target_r = target[idxs]

        # forward pass on the resampled batch (this one is differentiated)
        output_r = model(data_r)
        acc = accuracy(output_r, target_r)
        loss = F.nll_loss(output_r, target_r, reduction='none')

        # reweight losses
        loss = (loss * weight).mean()

        # backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_idx % 25 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.2f}%'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item(), acc))

    return seen

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [8]:
# epoch bounds: the plain training loop runs epochs 1..num_epochs_transient,
# the weighted loop runs num_epochs_transient..num_epochs_steady (inclusive)
num_epochs_transient = 2
num_epochs_steady = 3

# SGD settings
learning_rate = 0.001
mom = 0.99

# training batch size
batch_size = 64

# weighted-loop settings: scale weights by their maximum, and the percentage
# of lowest-loss samples dropped on each pass
normalize = False
perc_to_remove = 10

In [9]:
# reseed so the initialisation below does not depend on earlier RNG use
torch.manual_seed(SEED)

# build the network on the selected device
model = SmallConv().to(device)

# Kaiming-normal (fan-in) initialisation of every conv / linear weight;
# biases keep their default initialisation
for m in model.modules():
    if not isinstance(m, (nn.Conv2d, nn.Linear)):
        continue
    nn.init.kaiming_normal_(m.weight, mode='fan_in')

# SGD with momentum
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=mom)

# sequential training loader (no permutation) and a test loader with batch size 128
train_loader = get_data_loader(data_dir, batch_size, permutation=None, **kwargs)
test_loader = get_test_loader(data_dir, batch_size=128, **kwargs)

In [10]:
# transient training: plain epochs over the sequential loader, with per-sample
# losses recorded during the first epoch only
losses = None
for epoch in range(1, num_epochs_transient + 1):
    # train_transient returns None unless tracking is requested
    epoch_stats = train_transient(
        model, device, train_loader, optimizer, epoch, track=(epoch == 1)
    )
    if epoch_stats is not None:
        losses = epoch_stats
    test(model, device, test_loader)


Test set: Average loss: 0.0946, Accuracy: 9692/10000 (97%)


Test set: Average loss: 0.0628, Accuracy: 9803/10000 (98%)



In [11]:
# Loss-weighted training on a progressively pruned sample set. Each pass:
#   1. flattens the per-batch loss log of the previous pass,
#   2. drops the `perc_to_remove` percent of entries with the lowest loss,
#   3. samples the next epoch with probability proportional to the kept losses,
#   4. trains one tracked epoch and evaluates on the test set.
#
# NOTE(review): the sampler built by get_weighted_loader yields positions into
# `weights`, while `weights` is compacted after pruning, and the positions
# logged by train_transient follow the loader's output order. The kept losses
# therefore may not line up with the dataset items they came from — confirm
# whether this is intended.
# NOTE(review): the first epoch number here repeats the last transient epoch
# number in the progress output — confirm the intended range.
for epoch in range(num_epochs_transient, num_epochs_steady+1):
    # one flat list of [position, loss] pairs
    losses = [v for sublist in losses for v in sublist]
    # positions ordered from highest to lowest loss
    sorted_loss_idx = sorted(range(len(losses)), key=lambda k: losses[k][1], reverse=True)
    # Drop the lowest-loss tail. The end of the slice is computed explicitly
    # because `[:-0]` is an empty list, which would discard everything when
    # the number to remove rounds down to zero.
    num_to_remove = int((perc_to_remove / 100) * len(sorted_loss_idx))
    num_to_keep = max(0, len(sorted_loss_idx) - num_to_remove)
    sorted_loss_idx = sorted_loss_idx[:num_to_keep]
    # back to ascending position order
    sorted_loss_idx.sort()
    weights = [losses[idx][1] for idx in sorted_loss_idx]
    if normalize:
        max_w = max(weights)
        weights = [w / max_w for w in weights]
    train_loader = get_weighted_loader(data_dir, batch_size, weights, **kwargs)
    print("\t[*] Effective Size: {:,}".format(len(train_loader.sampler)))
    losses = train_transient(model, device, train_loader, optimizer, epoch, track=True)
    test(model, device, test_loader)

	[*] Effective Size: 54,000

Test set: Average loss: 0.0597, Accuracy: 9814/10000 (98%)

	[*] Effective Size: 48,600

Test set: Average loss: 0.0654, Accuracy: 9822/10000 (98%)

