In [2]:
import numpy as np
import torch

from learners import Learner, GEM, AGEM

In [5]:
# Experiment-wide settings.
n_tasks = 5   # number of sequential tasks the training/test data is split into
seed = 42     # passed to torch.manual_seed before the dataset shuffles

### Download MNIST

In [4]:
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Fetch the MNIST .npz archive once, then cache the train/test arrays as torch
# tensors in data/mnist_train.pt and data/mnist_test.pt for the next cell.
import os
import subprocess
import urllib.request

mnist_path = "data/mnist.npz"
mnist_dir = os.path.dirname(mnist_path)

# URL from: https://github.com/fchollet/keras/blob/master/keras/datasets/mnist.py
# The original S3 location is tried first; the second entry is where current Keras
# releases host the same file, used as a fallback if the S3 bucket is unavailable.
mnist_urls = [
    "https://s3.amazonaws.com/img-datasets/mnist.npz",
    "https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz",
]

if not os.path.exists(mnist_path):
    # Pure-Python download (no wget / IPython `!mv`): creates data/ if missing and
    # raises on failure instead of silently continuing to a confusing np.load error.
    os.makedirs(mnist_dir, exist_ok=True)
    partial_path = mnist_path + ".part"
    download_error = None
    for url in mnist_urls:
        try:
            urllib.request.urlretrieve(url, partial_path)
        except OSError as err:  # URLError / HTTPError are OSError subclasses
            download_error = err
            continue
        # Rename only after a complete download so an interrupted run never leaves
        # a truncated mnist.npz that the exists() check above would accept.
        os.replace(partial_path, mnist_path)
        break
    else:
        raise RuntimeError(f"could not download MNIST from {mnist_urls}") from download_error

# The NpzFile keeps the archive open; the with-block closes it even on error.
with np.load(mnist_path) as f:
    x_tr = torch.from_numpy(f['x_train'])
    y_tr = torch.from_numpy(f['y_train']).long()
    x_te = torch.from_numpy(f['x_test'])
    y_te = torch.from_numpy(f['y_test']).long()

torch.save((x_tr, y_tr), 'data/mnist_train.pt')
torch.save((x_te, y_te), 'data/mnist_test.pt')

### Preprocessing and Train/Test Split

In [None]:
# Reproducibility: both torch.randperm calls below draw from this seed.
torch.manual_seed(seed)

# Tensors cached by the download cell: 60000 training and 10000 test samples.
x_tr, y_tr = torch.load('data/mnist_train.pt')
x_te, y_te = torch.load('data/mnist_test.pt')

# Flatten each sample to a single row of floats and divide by 255 (assumes the raw
# 0-255 pixel range of the archive); labels become flat int64 vectors.
x_tr = x_tr.float().reshape(len(x_tr), -1) / 255.0
x_te = x_te.float().reshape(len(x_te), -1) / 255.0
y_tr = y_tr.reshape(-1).long()
y_te = y_te.reshape(-1).long()

# One permutation per set keeps every sample paired with its label.
# The training permutation is drawn first, which fixes the seeded order.
p_tr = torch.randperm(len(x_tr))
p_te = torch.randperm(len(x_te))
x_tr, y_tr = x_tr[p_tr], y_tr[p_tr]
x_te, y_te = x_te[p_te], y_te[p_te]

### Split MNIST into tasks (equal-size random chunks of the shuffled data, not class-disjoint)

In [None]:
# Partition the shuffled data into n_tasks consecutive, equal-size chunks.
# NOTE: the previous cell shuffled the data, so each task is a random sample with
# classes mixed (i.i.d. tasks), not a class-disjoint "Split MNIST" task.
tr_task_size = 10000
te_task_size = 2000

# Fail loudly rather than silently producing short or empty trailing tasks when
# n_tasks or the task sizes ask for more samples than the data holds.
if tr_task_size * n_tasks > x_tr.size(0):
    raise ValueError(f"{n_tasks} tasks x {tr_task_size} samples exceeds the "
                     f"{x_tr.size(0)} available training samples")
if te_task_size * n_tasks > x_te.size(0):
    raise ValueError(f"{n_tasks} tasks x {te_task_size} samples exceeds the "
                     f"{x_te.size(0)} available test samples")

tasks_tr = []
tasks_te = []

for t in range(n_tasks):
    tr_slice = slice(t * tr_task_size, (t + 1) * tr_task_size)
    te_slice = slice(t * te_task_size, (t + 1) * te_task_size)
    tasks_tr.append([x_tr[tr_slice], y_tr[tr_slice]])
    tasks_te.append([x_te[te_slice], y_te[te_slice]])

# Slices are views into the full x_tr / x_te storage, and torch.save writes a view's
# whole underlying storage (including unused samples). Cloning at save time stores
# only the samples each file actually contains; loaded values are unchanged.
torch.save([[[x.clone(), y.clone()] for x, y in tasks_tr],
            [[x.clone(), y.clone()] for x, y in tasks_te]], 'data/mnist_splitted.pt')

n_tr_used = tr_task_size * n_tasks
n_te_used = te_task_size * n_tasks
torch.save([[x_tr[:n_tr_used].clone(), y_tr[:n_tr_used].clone()],
            [x_te[:n_te_used].clone(), y_te[:n_te_used].clone()]], 'data/mnist_all.pt')

### Skewed Split: simulate training on class-imbalanced datasets

In [None]:
from collections import Counter

# Class fractions for the skewed (class-imbalanced) split.
# class_probs[s][c] is the fraction of all class-c samples (classes 0-9) that goes
# into split s. Split s over-represents classes 2s and 2s+1 (0.6 of each) and
# takes 0.1 of every other class, so each column sums to 1 over the five splits.
class_probs = [
    [0.6 if c // 2 == s else 0.1 for c in range(10)]
    for s in range(5)
]

def skewed_split(X, y, class_probs):
    '''Partition (X, y) into mutually disjoint, class-imbalanced splits.

    Args:
        X: tensor whose first dimension indexes samples (same length as ``y``).
        y: 1-D tensor of integer class labels. Column ``c`` of ``class_probs``
            refers to label ``c``; samples whose label is not in
            ``range(len(class_probs[0]))`` are never selected.
        class_probs: sequence of rows, one per split. ``class_probs[s][c]`` is
            the fraction of *all* class-``c`` samples placed in split ``s``
            (rounded down). Each column must sum to at most 1 so that splits
            never share samples.

    Returns:
        A list with one ``[X_split, y_split]`` pair per row of ``class_probs``,
        the same layout as ``tasks_tr`` / ``tasks_te``. Samples inside a split
        are shuffled so classes are interleaved rather than grouped.

    Raises:
        ValueError: if rows are empty or differ in length, a fraction is
            negative, or a class's fractions sum to more than 1.
    '''
    if len(class_probs) == 0:
        return []
    n_classes = len(class_probs[0])
    if n_classes == 0 or any(len(row) != n_classes for row in class_probs):
        raise ValueError("class_probs rows must be non-empty and of equal length")
    for c in range(n_classes):
        column = [row[c] for row in class_probs]
        if min(column) < 0:
            raise ValueError(f"negative fraction for class {c}")
        # Small tolerance: e.g. 0.6 + 0.1 + 0.1 + 0.1 + 0.1 == 0.9999999999999999.
        if sum(column) > 1 + 1e-9:
            raise ValueError(f"fractions for class {c} sum to more than 1")

    # Dataset-level indices of each class, in random order. (The original draft
    # permuted 0..count-1, which are not positions in X/y.)
    pools = []
    for c in range(n_classes):
        cls_idx = (y == c).nonzero(as_tuple=True)[0]
        pools.append(cls_idx[torch.randperm(cls_idx.numel())])
    # Fractions are taken of the full class size, not of what is left over.
    totals = [pool.numel() for pool in pools]

    splits = []
    for row in class_probs:
        parts = []
        for c, frac in enumerate(row):
            n_take = int(frac * totals[c])
            parts.append(pools[c][:n_take])
            pools[c] = pools[c][n_take:]  # consumed indices are never reused
        idx = torch.cat(parts)
        idx = idx[torch.randperm(idx.numel())]
        splits.append([X[idx], y[idx]])
    return splits

### ML Model

In [6]:
import torch.nn as nn

class Classifier(nn.Module):
    """One-hidden-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    Args:
        input_size: number of input features per sample.
        hidden_size: width of the hidden layer.
        drop_prob: dropout probability applied to the hidden activations.
        output_size: number of output logits per sample.
    """

    def __init__(self, input_size, hidden_size, drop_prob, output_size):
        super().__init__()

        # Registration order (relu, dropout, fc1, fc2) is kept so that module
        # listings and seeded weight initialisation match the original layout.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)

        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return unnormalised class scores (logits) for the batch ``x``."""
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

### Non-continual Baseline

### Continual Baseline