# clem.pytorch: *Continual Learning using Episodic Memory in PyTorch*

In [None]:
import numpy as np
import torch

from learners import Learner, GEM, AGEM, ER

In [None]:
# Seed handed to torch.manual_seed before the datasets are shuffled.
seed = 42
# Number of tasks the data is split into; the continual-learning tester also
# repeats its whole experiment this many times.
n_tasks = 5

### Download MNIST

In [None]:
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import subprocess

mnist_path = "data/mnist.npz"
# URL from: https://github.com/fchollet/keras/blob/master/keras/datasets/mnist.py
mnist_url = "https://s3.amazonaws.com/img-datasets/mnist.npz"

# exist_ok makes this a no-op when the directory is already there and avoids
# the check-then-create race of os.path.exists + os.mkdir.
os.makedirs('data', exist_ok=True)

if not os.path.exists(mnist_path):
    # Download to a temporary name and rename only on success, so an
    # interrupted download is never mistaken for a complete archive on the
    # next run. The argument list (no shell) plus check=True turns a failed
    # wget into an exception instead of a silent no-op.
    partial_path = mnist_path + ".part"
    subprocess.run(["wget", "-O", partial_path, mnist_url], check=True)
    os.replace(partial_path, mnist_path)

# The context manager closes the archive even if one of the keys is missing.
with np.load(mnist_path) as archive:
    x_tr = torch.from_numpy(archive['x_train'])
    y_tr = torch.from_numpy(archive['y_train']).long()
    x_te = torch.from_numpy(archive['x_test'])
    y_te = torch.from_numpy(archive['y_test']).long()

# Cache the tensors so later cells can reload them without touching the .npz.
torch.save((x_tr, y_tr), 'data/mnist_train.pt')
torch.save((x_te, y_te), 'data/mnist_test.pt')

### Preprocessing and Train/Test Split

In [None]:
# Seed torch's global RNG so the two permutations drawn below are reproducible.
torch.manual_seed(seed)

x_tr, y_tr = torch.load('data/mnist_train.pt')  # 60000 samples
x_te, y_te = torch.load('data/mnist_test.pt')  # 10000 samples

n_train = x_tr.size(0)
n_test = x_te.size(0)

# Flatten every image into a single row per sample and divide pixels by 255.
x_tr = x_tr.float().view(n_train, -1) / 255.0
x_te = x_te.float().view(n_test, -1) / 255.0

# Labels become flat int64 vectors.
y_tr, y_te = y_tr.view(-1).long(), y_te.view(-1).long()

# Draw the train permutation first and the test permutation second (the order
# determines what each one gets from the seeded RNG), then apply the same
# permutation to inputs and labels so pairs stay aligned.
p_tr = torch.randperm(n_train)
p_te = torch.randperm(n_test)

x_tr, y_tr = x_tr[p_tr], y_tr[p_tr]
x_te, y_te = x_te[p_te], y_te[p_te]

### Split MNIST

In [None]:
tr_task_size = 10000
te_task_size = 2000

tasks_tr = []
tasks_te = []

# Task t receives the t-th consecutive chunk of the (already shuffled) data.
for t in range(n_tasks):
    tr_lo = t * tr_task_size
    te_lo = t * te_task_size
    tr_hi = tr_lo + tr_task_size
    te_hi = te_lo + te_task_size
    tasks_tr.append([x_tr[tr_lo:tr_hi], y_tr[tr_lo:tr_hi]])
    tasks_te.append([x_te[te_lo:te_hi], y_te[te_lo:te_hi]])

torch.save([tasks_tr, tasks_te], 'data/mnist_splitted.pt')

# Also store the union of all task chunks as one undivided train/test pair.
n_tr_used = tr_task_size * n_tasks
n_te_used = te_task_size * n_tasks
all_tr = [x_tr[:n_tr_used], y_tr[:n_tr_used]]
all_te = [x_te[:n_te_used], y_te[:n_te_used]]
torch.save([all_tr, all_te], 'data/mnist_all.pt')

### Skewed Split: For simulating training on unbalanced datasets

In [None]:
from collections import Counter

# probability for each class in each split
# each row corresponds to a split; each column corresponds to a class (0-9)
# a cell gives the fraction of a class's data to include in that split

def skewed_split(X, y, class_probs):
    ''' Lazily partition (X, y) into class-skewed, non-overlapping splits.

    Args:
        X: tensor of samples, indexed along its first dimension.
        y: 1-D tensor of integer class labels aligned with `X`. Classes are
            taken to be 0 .. max(y); label values that never occur are
            treated as classes with zero samples.
        class_probs: iterable of rows, one row per split. Entry `c` of a row
            is the fraction of ALL samples of class `c` to put in that split.
            Entries beyond the largest label in `y` are ignored.

    Yields:
        One `(X_split, y_split)` pair per row of `class_probs`. Within a
        split, samples are grouped by class in increasing class order.

    Samples are drawn without replacement across splits: once an index has
    been handed out it is never reused. If the fractions requested for a class
    add up to more than 1, later splits simply receive whatever is left.
    Randomness comes from torch's global RNG (one `torch.randperm` per class,
    in class order), so seed it beforehand for reproducible splits.
    '''
    n_classes = int(y.max().item()) + 1 if y.numel() > 0 else 0

    # reshape(-1) instead of squeeze(): squeeze() would turn the indices of a
    # class with exactly one sample into a 0-d tensor, which cannot be indexed.
    indices_per_class = [(y == c).nonzero().reshape(-1) for c in range(n_classes)]
    class_counts = [class_indices.numel() for class_indices in indices_per_class]

    # generate random indices TO INDEX THE ACTUAL INDICES for each class
    idxs = [torch.randperm(count) for count in class_counts]

    for prob_set in class_probs:
        idxs_to_get = []
        for c, prob in zip(range(n_classes), prob_set):
            end_idx = int(prob * class_counts[c])
            idxs_to_get.append(indices_per_class[c][idxs[c][:end_idx]])
            # update indices, we treat the idxs like a stack where we
            # remove indices we have already used
            idxs[c] = idxs[c][end_idx:]

        # torch.cat rejects an empty list; fall back to "select nothing".
        if idxs_to_get:
            selected = torch.cat(idxs_to_get)
        else:
            selected = torch.empty(0, dtype=torch.long)
        yield X[selected], y[selected]


In [None]:
def gen_prob_dist(dom_prob, num_tasks=None, num_classes=10):
    ''' Generate a skewed per-task class distribution.

    Returns a `num_tasks` x `num_classes` list of lists in which each row
    corresponds to a task and each column to a class. Each value is the
    fraction of a class's samples assigned to a task. Concretely, a value
    of 0.6 at (row 3, column 2), indices starting at 0, means 60% of the
    samples labelled '2' are assigned to the fourth task.

    Each task has 2 dominant classes: task `t` dominates classes `2*t` and
    `2*t + 1`, and receives the fraction `dom_prob` of each. The remaining
    `1 - dom_prob` of those classes is shared equally among the other
    `num_tasks - 1` tasks. For example, with 5 tasks and dom_prob=0.9, row 1 is
    [0.025, 0.025, 0.9, 0.9, 0.025, 0.025, 0.025, 0.025, 0.025, 0.025].

    Args:
        dom_prob: fraction of each dominant class given to its own task.
        num_tasks: number of rows to generate. Defaults to the module-level
            `n_tasks` when omitted.
        num_classes: number of columns to generate. Columns at or beyond
            `2 * num_tasks` have no dominant task and therefore hold the
            non-dominant fraction in every row.

    Returns:
        list[list[float]]: the distribution matrix described above.

    Raises:
        ValueError: if `num_classes` is too small to give every task two
            dominant classes.
    '''
    if num_tasks is None:
        num_tasks = n_tasks
    if 2 * num_tasks > num_classes:
        raise ValueError(
            "need at least 2 classes per task: got {} tasks and {} classes".format(
                num_tasks, num_classes))

    # Spread the leftover mass over the *other* tasks. With a single task
    # there are no other tasks, so nothing is left to distribute.
    min_prob = (1.0 - dom_prob) / (num_tasks - 1) if num_tasks > 1 else 0.0

    prob_dist = [[min_prob] * num_classes for _ in range(num_tasks)]
    for t, row in enumerate(prob_dist):
        row[2 * t] = dom_prob
        row[2 * t + 1] = dom_prob

    return prob_dist

In [None]:
# Sanity check: show one row of per-class fractions for every task.
sample_probs = gen_prob_dist(dom_prob=0.6)
for task_probs in sample_probs:
    print(task_probs)

In [None]:
# Sanity check: show how many test samples of each label land in every split.
for new_x, new_y in skewed_split(x_te, y_te, sample_probs):
    label_counts = Counter(new_y.numpy())
    print(dict(label_counts))

### ML Model

In [None]:
import torch.nn as nn

class Classifier(nn.Module):
    ''' Feed-forward classifier with a single hidden layer.

    The forward pass is Linear -> ReLU -> Dropout -> Linear and returns raw
    (un-normalised) logits.
    '''

    def __init__(self, input_size, hidden_size, drop_prob, output_size):
        super().__init__()

        # Parameter-free layers applied to the hidden activations.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

        # Learnable layers: input -> hidden, hidden -> logits.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [None]:
# MNIST
input_size = 784  # features per sample once each image is flattened (presumably 28x28 pixels)
output_size = 10  # one logit per digit class (0-9)

# Hyper-parameters
hidden_size = 256  # width of the classifier's single hidden layer
drop_prob = 0.8  # handed to nn.Dropout: probability of zeroing a hidden activation
num_epochs = 5  # passes over each training set
learning_rate = 0.001  # learning rate given to the Adam optimizer
batch_size = 128  # mini-batch size for the DataLoaders; the Experience Replay cells temporarily lower it

In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)

### Non-continual Baseline

In [None]:
from tqdm import tqdm
import torch.nn.functional as f
from torch.utils.data import TensorDataset, DataLoader

# Reset every RNG involved so the baseline is reproducible.
np.random.seed(42)
torch.manual_seed(42)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(42)

train_data = TensorDataset(x_tr, y_tr)
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_data = TensorDataset(x_te, y_te)
test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size)

model = Classifier(input_size, hidden_size, drop_prob, output_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# ---- training: plain mini-batch training on the whole training set ----
model.train()
for ep in tqdm(range(num_epochs)):
    for inputs, labels in train_loader:
        # .to(device) is a no-op on CPU, so no device-type branch is needed
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        out = model(inputs.float())
        loss = criterion(out, labels.long())
        loss.backward()

        optimizer.step()

# ---- evaluation ----
model.eval()
val_loss = 0
corrects = 0
total = 0
# no_grad: evaluation never back-propagates, so skip building the autograd
# graph for every test batch.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        out = model(inputs.float())
        # softmax is monotonic, so the arg-max of the logits is the prediction
        preds = out.argmax(dim=-1)

        val_loss += criterion(out, labels.long()).item()

        # count matches on-device instead of summing a NumPy array in Python
        corrects += (preds == labels).sum().item()
        total += labels.size(0)

# Loss is averaged over batches; accuracy is reported in percent.
print("Loss: {:.6f}, Acc: {:.6f}".format(val_loss/len(test_loader), (corrects/total)*100))

## Skewed Splits

### Continual Baseline

In [None]:
import time
from tqdm import tqdm
import torch.nn.functional as f
from torch.utils.data import TensorDataset, DataLoader

# The same held-out test set is reused for evaluation after every task.
test_data = TensorDataset(x_te, y_te)
test_loader = DataLoader(dataset=test_data, batch_size=batch_size, shuffle=True)

In [None]:
def _evaluate(model, criterion, loader):
    ''' Return (mean per-batch loss, accuracy in percent) of `model` on `loader`. '''
    model.eval()
    val_loss = 0.0
    corrects = 0
    total = 0
    # no_grad: evaluation never back-propagates, so skip the autograd graph.
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)

            out = model(inputs.float())
            # softmax is monotonic, so the arg-max of the logits is the prediction
            preds = out.argmax(dim=-1)

            val_loss += criterion(out, labels.long()).item()
            corrects += (preds == labels).sum().item()
            total += labels.size(0)

    return val_loss / len(loader), (corrects / total) * 100


def test_continual_learner(learner_class, class_probs, use_memory=False):
    ''' Tester for continual learners.

    Trains a fresh `Classifier` on a sequence of skewed tasks drawn from the
    module-level training set, evaluates it on the module-level `test_loader`
    after every task, and prints the loss/accuracy. The whole experiment is
    repeated `n_tasks` times, each time with a re-shuffled task order, and
    the final accuracies of all repetitions (plus their mean) are printed.

    Args:
        learner_class: learner type to instantiate. It is constructed with
            `(model, criterion, device=...)` and must provide `prepare` and
            `run` (and `remember` when `use_memory` is set).
        class_probs: rows of per-class fractions, one row per task, as
            accepted by `skewed_split`. The caller's list is NOT modified.
        use_memory: when True, the learner is additionally given the
            module-level `memory_capacity` / `memory_sample_size`, and
            `learner.remember` is called after each task with the
            module-level `task_memory_size`.

    Module-level settings (`batch_size`, `num_epochs`, `learning_rate`, ...)
    are read at call time.
    '''
    np.random.seed(42)
    torch.manual_seed(42)
    if device.type == 'cuda':
        torch.cuda.manual_seed_all(42)

    # Shuffle a private copy. Shuffling the caller's list in place would leak
    # the final task order of this call into the next one, so learners tested
    # one after another would not see the same sequence of task orders even
    # though the RNG is re-seeded above.
    class_probs = list(class_probs)

    task_perm_final_accs = []

    # continual learning is performed n_tasks(5) times
    # for more reliable results
    for _ in range(n_tasks):

        # initialize models
        model = Classifier(input_size, hidden_size, drop_prob, output_size).to(device)
        criterion = nn.CrossEntropyLoss()
        if use_memory:
            learner = learner_class(model, criterion, device=device,
                                    memory_capacity=memory_capacity, memory_sample_sz=memory_sample_size)
        else:
            learner = learner_class(model, criterion, device=device)

        # task loop
        np.random.shuffle(class_probs)
        final_acc = None
        for T_x, T_y in skewed_split(x_tr, y_tr, class_probs):
            train_data = TensorDataset(T_x, T_y)
            train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

            learner.prepare(optimizer=torch.optim.Adam, lr=learning_rate)

            model.train()
            for _ep in tqdm(range(num_epochs)):
                for inputs, labels in train_loader:
                    inputs, labels = inputs.to(device), labels.to(device)
                    learner.run(inputs, labels)

            if use_memory:
                # remember a subset
                learner.remember(train_data, min_save_sz=task_memory_size)

            mean_loss, final_acc = _evaluate(model, criterion, test_loader)
            print("Loss: {:.6f}, Acc: {:.6f}".format(mean_loss, final_acc))

        # save final accuracy in current task permutation (nothing to record
        # if class_probs was empty and no task ran)
        if final_acc is not None:
            task_perm_final_accs.append(final_acc)

    print("Final Accs: ", task_perm_final_accs, " Average Final Acc: ", np.array(task_perm_final_accs).mean())

In [None]:
# Baseline on skewed splits: the plain `Learner` base class trains the model
# continually without any of the implemented continual learning methods.

class_probs = gen_prob_dist(dom_prob=0.9)
st = time.time()
test_continual_learner(Learner, class_probs, use_memory=False)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

### GEM

In [None]:
memory_capacity = 10240  # forwarded to memory-based learners as `memory_capacity` (equals n_tasks * task_memory_size with the values used here)
task_memory_size = 2048  # forwarded to learner.remember(...) as `min_save_sz` after each task
memory_sample_size = 64  # forwarded to memory-based learners as `memory_sample_sz`

In [None]:
# GEM on skewed splits; report the mean wall-clock time per repetition.
st = time.time()
test_continual_learner(GEM, class_probs, use_memory=True)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

### A-GEM

In [None]:
# A-GEM on skewed splits; report the mean wall-clock time per repetition.
st = time.time()
test_continual_learner(AGEM, class_probs, use_memory=True)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

### Experience Replay

In [None]:
# we halve the memory sampling size, then shrink the batch size by that amount
# so that the number of samples used for the actual weight update
# will still be consistent with the other learning methods
memory_sample_size //= 2
batch_size -= memory_sample_size
print(memory_sample_size, batch_size)

In [None]:
# Experience Replay on skewed splits; report the mean wall-clock time per repetition.
st = time.time()
test_continual_learner(ER, class_probs, use_memory=True)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

In [None]:
# revert to original values (undo the Experience Replay adjustment in reverse order)
batch_size += memory_sample_size
memory_sample_size *= 2
print(memory_sample_size, batch_size)

## Class Splits

### Continual Baseline

In [None]:
# Class splits: with dom_prob=1.0 every task receives all samples of its two
# dominant classes. Baseline run with the plain `Learner`.
class_probs = gen_prob_dist(dom_prob=1.0)
st = time.time()
test_continual_learner(Learner, class_probs, use_memory=False)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

### GEM

In [None]:
# GEM on class splits; report the mean wall-clock time per repetition.
st = time.time()
test_continual_learner(GEM, class_probs, use_memory=True)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

### A-GEM

In [None]:
# A-GEM on class splits; report the mean wall-clock time per repetition.
st = time.time()
test_continual_learner(AGEM, class_probs, use_memory=True)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

### Experience Replay

In [None]:
# we halve the memory sampling size, then shrink the batch size by that amount
# so that the number of samples used for the actual weight update
# will still be consistent with the other learning methods
memory_sample_size //= 2
batch_size -= memory_sample_size
print(memory_sample_size, batch_size)

In [None]:
# Experience Replay on class splits; report the mean wall-clock time per repetition.
st = time.time()
test_continual_learner(ER, class_probs, use_memory=True)
secs_per_run = (time.time() - st) / n_tasks  # divide to get average
print("Elapsed: %.6f s" % secs_per_run)

In [None]:
# revert to original values (undo the Experience Replay adjustment in reverse order)
batch_size += memory_sample_size
memory_sample_size *= 2
print(memory_sample_size, batch_size)