In [23]:
import numpy as np
import torch

from learners import Learner, GEM, AGEM

In [2]:
# Global experiment settings: the RNG seed and the number of tasks.
seed, n_tasks = 42, 5

### Download MNIST

In [3]:
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import subprocess

mnist_path = "data/mnist.npz"
# URL from: https://github.com/fchollet/keras/blob/master/keras/datasets/mnist.py
mnist_url = "https://s3.amazonaws.com/img-datasets/mnist.npz"

# Download the archive only when it is not already cached on disk.
if not os.path.exists(mnist_path):
    # The target directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(mnist_path), exist_ok=True)
    # Download to a temporary name first so an interrupted or failed transfer
    # never leaves a broken file at `mnist_path` (which would be mistaken for a
    # valid cache on the next run). check_call raises if wget fails.
    partial_path = mnist_path + ".part"
    subprocess.check_call(["wget", "-O", partial_path, mnist_url])
    os.replace(partial_path, mnist_path)

# Read the four arrays; the context manager closes the archive even on error.
with np.load(mnist_path) as mnist_npz:
    x_tr = torch.from_numpy(mnist_npz['x_train'])
    y_tr = torch.from_numpy(mnist_npz['y_train']).long()
    x_te = torch.from_numpy(mnist_npz['x_test'])
    y_te = torch.from_numpy(mnist_npz['y_test']).long()

# Cache the tensors so later cells can reload them without touching the npz.
torch.save((x_tr, y_tr), 'data/mnist_train.pt')
torch.save((x_te, y_te), 'data/mnist_test.pt')

### Preprocessing and Train/Test Split

In [4]:
# Fix the torch RNG so the two permutations drawn below are reproducible.
torch.manual_seed(seed)

# Reload the tensors cached by the download cell.
x_tr, y_tr = torch.load('data/mnist_train.pt') # 60000 samples
x_te, y_te = torch.load('data/mnist_test.pt') # 10000 samples

# Flatten each sample into a single row and divide by 255
# (assumes 8-bit pixel intensities -- TODO confirm).
x_tr = x_tr.float().view(len(x_tr), -1).div(255.0)
x_te = x_te.float().view(len(x_te), -1).div(255.0)

# Labels as flat int64 vectors.
y_tr = y_tr.view(-1).long()
y_te = y_te.view(-1).long()

# Draw one random ordering per dataset (train first, then test) ...
p_tr = torch.randperm(len(x_tr))
p_te = torch.randperm(len(x_te))

# ... and apply it to samples and labels together so pairs stay aligned.
x_tr = x_tr[p_tr]
y_tr = y_tr[p_tr]
x_te = x_te[p_te]
y_te = y_te[p_te]

### Split MNIST

In [None]:
# Number of samples given to each task.
tr_task_size = 10000
te_task_size = 2000

tasks_tr = []
tasks_te = []

# Task t receives the t-th consecutive chunk of the (already shuffled) data.
for t in range(n_tasks):
    tr_rows = slice(t * tr_task_size, (t + 1) * tr_task_size)
    te_rows = slice(t * te_task_size, (t + 1) * te_task_size)
    tasks_tr.append([x_tr[tr_rows], y_tr[tr_rows]])
    tasks_te.append([x_te[te_rows], y_te[te_rows]])

# Per-task chunks.
torch.save([tasks_tr, tasks_te], 'data/mnist_splitted.pt')

# The same samples as one undivided train/test pair.
n_tr_used = tr_task_size * n_tasks
n_te_used = te_task_size * n_tasks
torch.save(
    [[x_tr[:n_tr_used], y_tr[:n_tr_used]],
     [x_te[:n_te_used], y_te[:n_te_used]]],
    'data/mnist_all.pt',
)

### Skewed Split: For simulating training on unbalanced datasets

In [28]:
from collections import Counter

# Fraction of each class that goes into each split.
# Rows are splits, columns are classes (0-9); a cell is the share of that
# class's samples placed in that split.
# Split k takes 0.6 of classes 2k and 2k+1 and 0.1 of every other class.
class_probs = [
    [0.6 if cls // 2 == split else 0.1 for cls in range(10)]
    for split in range(5)
]

def skewed_split(X, y, class_probs):
    '''
    Lazily partition ``(X, y)`` into class-skewed, non-overlapping splits.

    Parameters
    ----------
    X : torch.Tensor
        Samples, indexed along dim 0.
    y : torch.Tensor
        1-D tensor of non-negative integer class labels aligned with ``X``.
        Label ``c`` is matched against column ``c`` of ``class_probs``.
    class_probs : iterable of sequences of float
        One row per split; ``row[c]`` is the fraction of *all* samples of
        class ``c`` to place in that split.

    Yields
    ------
    (torch.Tensor, torch.Tensor)
        ``(X[idx], y[idx])`` for each row of ``class_probs``. Within a split
        the samples are grouped by class, in increasing class order.

    Notes
    -----
    * Samples are drawn without replacement across splits. If the fractions
      requested for a class add up to more than 1, later splits simply get
      whatever is left of that class.
    * The per-class orderings come from ``torch.randperm``, i.e. the global
      torch RNG; seed it beforehand for reproducible splits.
    * Being a generator, nothing runs until the first split is requested.
    '''
    class_probs = list(class_probs)

    # Cover every class a row mentions and every label value present in `y`,
    # so a label value missing from `y` becomes an empty class instead of
    # shifting or dropping the classes after it.
    widest_row = max((len(row) for row in class_probs), default=0)
    n_labels = int(y.max()) + 1 if y.numel() > 0 else 0
    n_classes = max(widest_row, n_labels)

    # Row positions of every class. `as_tuple=True` always returns a 1-D
    # index tensor, even for a class with exactly one sample (a bare
    # `.nonzero().squeeze()` collapses that case to a 0-d tensor).
    rows_per_class = [(y == c).nonzero(as_tuple=True)[0] for c in range(n_classes)]

    # A random order over each class's rows; entries are consumed from the
    # front as splits are produced, so no sample is handed out twice.
    remaining = [torch.randperm(rows.numel()) for rows in rows_per_class]

    for prob_set in class_probs:
        chosen = []
        for c, fraction in enumerate(prob_set):
            n_take = int(fraction * rows_per_class[c].numel())
            chosen.append(rows_per_class[c][remaining[c][:n_take]])
            # drop what was just used so later splits cannot reuse it
            remaining[c] = remaining[c][n_take:]

        # An empty row yields an empty split rather than crashing torch.cat.
        chosen = torch.cat(chosen) if chosen else torch.empty(0, dtype=torch.long)
        yield X[chosen], y[chosen]


# print(Counter(y_te.numpy()))        
# for new_x, new_y in skewed_split(x_te, y_te, class_probs):
#     print(Counter(new_y.numpy()))

Counter({1: 1135, 2: 1032, 7: 1028, 3: 1010, 9: 1009, 4: 982, 0: 980, 8: 974, 6: 958, 5: 892})
Counter({1: 681, 0: 588, 2: 103, 7: 102, 3: 101, 9: 100, 4: 98, 8: 97, 6: 95, 5: 89})
Counter({2: 619, 3: 606, 1: 113, 7: 102, 9: 100, 0: 98, 4: 98, 8: 97, 6: 95, 5: 89})
Counter({4: 589, 5: 535, 1: 113, 2: 103, 7: 102, 3: 101, 9: 100, 0: 98, 8: 97, 6: 95})
Counter({7: 616, 6: 574, 1: 113, 2: 103, 3: 101, 9: 100, 0: 98, 4: 98, 8: 97, 5: 89})
Counter({9: 605, 8: 584, 1: 113, 2: 103, 7: 102, 3: 101, 0: 98, 4: 98, 6: 95, 5: 89})


### ML Model

In [29]:
import torch.nn as nn

class Classifier(nn.Module):
    """Feed-forward classifier with a single hidden layer.

    The forward pass is ``fc1 -> ReLU -> Dropout -> fc2`` and returns raw
    logits (no softmax is applied).

    Args:
        input_size: number of features per input sample.
        hidden_size: width of the hidden layer.
        drop_prob: dropout probability applied to the hidden activations.
        output_size: number of logits produced per sample.
    """

    def __init__(self, input_size, hidden_size, drop_prob, output_size):
        super().__init__()

        # parameter-free layers
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_prob)

        # learnable layers
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Return the logits for a batch ``x`` whose last dim is ``input_size``."""
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

In [33]:
# Hyper-parameters

# network dimensions and regularisation
input_size, output_size = 784, 10
hidden_size = 256
drop_prob = 0.8

# optimisation settings
num_epochs, batch_size = 5, 128
learning_rate = 0.001

In [34]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

### Non-continual Baseline

In [47]:
from tqdm import tqdm
import torch.nn.functional as f
from torch.utils.data import TensorDataset, DataLoader

# Re-seed every RNG from the notebook-wide `seed` so this cell gives the same
# result no matter which cells ran before it.
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(seed)

train_data = TensorDataset(x_tr, y_tr)
train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)
test_data = TensorDataset(x_te, y_te)
test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size)

model = Classifier(input_size, hidden_size, drop_prob, output_size).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# ---- training: plain supervised learning on the whole training set ----
model.train()
for ep in tqdm(range(num_epochs)):
    for inputs, labels in train_loader:
        # .to(device) is a no-op on CPU, so no device branch is needed
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        out = model(inputs.float())
        loss = criterion(out, labels.long())
        loss.backward()

        optimizer.step()

# ---- evaluation ----
model.eval()
val_loss = 0
corrects = 0
total = 0
# No gradients are needed for evaluation; disabling autograd avoids building
# a graph for every test batch.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        out = model(inputs.float())
        preds = torch.argmax(f.softmax(out, dim=-1), dim=-1)

        val_loss += criterion(out, labels.long()).item()

        # count correct predictions on-device; .item() gives a Python int
        corrects += (preds == labels).sum().item()
        total += labels.size(0)

# Reported loss is the mean of the per-batch mean losses.
print("Loss: {:.6f}, Acc: {:.6f}".format(val_loss/len(test_loader), (corrects/total)*100))

100%|██████████| 5/5 [00:07<00:00,  1.44s/it]


Loss: 0.132289, Acc: 96.040000


### Continual Baseline

In [53]:
from tqdm import tqdm
import torch.nn.functional as f
from torch.utils.data import TensorDataset, DataLoader

# Re-seed every RNG from the notebook-wide `seed` so this cell gives the same
# result no matter which cells ran before it.
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(seed)

# NOTE(review): these two arrays are not referenced anywhere in this cell;
# kept in case other cells read them.
all_trn_f1_mean = np.array([])
all_val_f1_mean = np.array([])

task_perm_final_accs = []

# constant validation data across tasks
test_data = TensorDataset(x_te, y_te)
# NOTE(review): shuffling is not needed for evaluation. It is left as-is
# because a shuffled loader appears to draw from the global torch RNG, so
# changing it would likely shift the random stream seen by the training that
# follows each evaluation -- confirm before changing.
test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size)

# One row per task, one column per class (0-9): the fraction of each class's
# samples given to that task. Task k gets 0.9 of classes 2k and 2k+1 and
# 0.025 of every other class.
a = 0.9
class_probs = [
    [a if cls // 2 == task else 0.025 for cls in range(10)]
    for task in range(5)
]

# Each outer iteration is one random ordering of the tasks, trained from
# scratch. `n_tasks` is reused here as the number of orderings to try.
for t in range(n_tasks):

    # initialize models
    model = Classifier(input_size, hidden_size, drop_prob, output_size).to(device)
    criterion = nn.CrossEntropyLoss()
    learner = Learner(model, criterion, device=device)

    # task loop: shuffle the rows in place to pick this run's task order
    np.random.shuffle(class_probs)
    for T_x, T_y in skewed_split(x_tr, y_tr, class_probs):
        train_data = TensorDataset(T_x, T_y)
        train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

        # Called once per task -- presumably (re)creates the optimizer;
        # confirm against learners.py.
        learner.prepare(optimizer=torch.optim.Adam, lr=learning_rate)

        model.train()
        for ep in tqdm(range(num_epochs)):
            for inputs, labels in train_loader:
                # .to(device) is a no-op on CPU, so no device branch is needed
                inputs, labels = inputs.to(device), labels.to(device)

                learner.run(inputs, labels)

        # Evaluate on the full test set after every task.
        model.eval()
        val_loss = 0
        corrects = 0
        total = 0
        # No gradients are needed for evaluation; disabling autograd avoids
        # building a graph for every test batch.
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                out = model(inputs.float())
                preds = torch.argmax(f.softmax(out, dim=-1), dim=-1)

                val_loss += criterion(out, labels.long()).item()

                # count correct predictions on-device; .item() gives a Python int
                corrects += (preds == labels).sum().item()
                total += labels.size(0)

        task_acc = (corrects/total)*100
        print("Loss: {:.6f}, Acc: {:.6f}".format(val_loss/len(test_loader), task_acc))

    task_perm_final_accs.append(task_acc) # save final accuracy in current task permutation

print("Final Accs: ", task_perm_final_accs, " Average Final Acc: ", np.array(task_perm_final_accs).mean())

100%|██████████| 5/5 [00:01<00:00,  3.20it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.930062, Acc: 70.450000


100%|██████████| 5/5 [00:01<00:00,  3.50it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.686618, Acc: 75.830000


100%|██████████| 5/5 [00:01<00:00,  3.19it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.531379, Acc: 81.700000


100%|██████████| 5/5 [00:01<00:00,  3.14it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.394364, Acc: 88.470000


100%|██████████| 5/5 [00:01<00:00,  3.31it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.402153, Acc: 88.340000


100%|██████████| 5/5 [00:01<00:00,  3.30it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.710794, Acc: 80.110000


100%|██████████| 5/5 [00:01<00:00,  3.17it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.687649, Acc: 76.850000


100%|██████████| 5/5 [00:01<00:00,  3.83it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.491099, Acc: 83.690000


100%|██████████| 5/5 [00:01<00:00,  3.31it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.486626, Acc: 84.980000


100%|██████████| 5/5 [00:01<00:00,  3.17it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.400435, Acc: 88.060000


100%|██████████| 5/5 [00:01<00:00,  3.56it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.981371, Acc: 64.910000


100%|██████████| 5/5 [00:01<00:00,  3.31it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.501614, Acc: 85.140000


100%|██████████| 5/5 [00:01<00:00,  3.76it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.507369, Acc: 84.950000


100%|██████████| 5/5 [00:01<00:00,  3.77it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.442327, Acc: 87.240000


100%|██████████| 5/5 [00:01<00:00,  4.02it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.498352, Acc: 83.200000


100%|██████████| 5/5 [00:01<00:00,  3.83it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.085404, Acc: 60.140000


100%|██████████| 5/5 [00:01<00:00,  3.63it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.494407, Acc: 85.280000


100%|██████████| 5/5 [00:01<00:00,  3.38it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.565742, Acc: 82.800000


100%|██████████| 5/5 [00:01<00:00,  3.72it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.442644, Acc: 87.000000


100%|██████████| 5/5 [00:01<00:00,  4.00it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.493379, Acc: 83.740000


100%|██████████| 5/5 [00:01<00:00,  3.84it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.086469, Acc: 60.030000


100%|██████████| 5/5 [00:01<00:00,  3.79it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.639120, Acc: 80.340000


100%|██████████| 5/5 [00:01<00:00,  3.75it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.472534, Acc: 86.230000


100%|██████████| 5/5 [00:01<00:00,  3.60it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.373616, Acc: 88.880000


100%|██████████| 5/5 [00:01<00:00,  3.45it/s]

Loss: 0.505951, Acc: 83.920000
Final Accs:  [88.34, 88.06, 83.2, 83.74000000000001, 83.91999999999999]  Average Final Acc:  85.452





### A-GEM

In [45]:
# Memory settings passed to the A-GEM learner in the next cell.
memory_capacity, task_memory_size, memory_sample_size = 10240, 2048, 64

In [52]:
from tqdm import tqdm
import torch.nn.functional as f
from torch.utils.data import TensorDataset, DataLoader

# Re-seed every RNG from the notebook-wide `seed` so this cell gives the same
# result no matter which cells ran before it.
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(seed)

# NOTE(review): these two arrays are not referenced anywhere in this cell;
# kept in case other cells read them.
all_trn_f1_mean = np.array([])
all_val_f1_mean = np.array([])

task_perm_final_accs = []

# constant validation data across tasks
test_data = TensorDataset(x_te, y_te)
# NOTE(review): shuffling is not needed for evaluation. It is left as-is
# because a shuffled loader appears to draw from the global torch RNG, so
# changing it would likely shift the random stream seen by the training that
# follows each evaluation -- confirm before changing.
test_loader = DataLoader(test_data, shuffle=True, batch_size=batch_size)

# One row per task, one column per class (0-9): the fraction of each class's
# samples given to that task. Task k gets 0.9 of classes 2k and 2k+1 and
# 0.025 of every other class.
a = 0.9
class_probs = [
    [a if cls // 2 == task else 0.025 for cls in range(10)]
    for task in range(5)
]

# Each outer iteration is one random ordering of the tasks, trained from
# scratch. `n_tasks` is reused here as the number of orderings to try.
for t in range(n_tasks):

    # initialize models
    model = Classifier(input_size, hidden_size, drop_prob, output_size).to(device)
    criterion = nn.CrossEntropyLoss()
    learner = AGEM(model, criterion, device=device,
                   memory_capacity=memory_capacity, memory_sample_sz=memory_sample_size)

    # task loop: shuffle the rows in place to pick this run's task order
    np.random.shuffle(class_probs)
    for T_x, T_y in skewed_split(x_tr, y_tr, class_probs):

        train_data = TensorDataset(T_x, T_y)
        train_loader = DataLoader(train_data, shuffle=True, batch_size=batch_size)

        # Called once per task -- presumably (re)creates the optimizer;
        # confirm against learners.py.
        learner.prepare(optimizer=torch.optim.Adam, lr=learning_rate)

        model.train()
        for ep in tqdm(range(num_epochs)):
            for inputs, labels in train_loader:
                # .to(device) is a no-op on CPU, so no device branch is needed
                inputs, labels = inputs.to(device), labels.to(device)

                learner.run(inputs, labels)

        # remember a subset
        # (presumably stores samples of the finished task in the learner's
        # memory for use while training later tasks -- confirm in learners.py)
        learner.remember(train_data, min_save_sz=task_memory_size)

        # Evaluate on the full test set after every task.
        model.eval()
        val_loss = 0
        corrects = 0
        total = 0
        # No gradients are needed for evaluation; disabling autograd avoids
        # building a graph for every test batch.
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)

                out = model(inputs.float())
                preds = torch.argmax(f.softmax(out, dim=-1), dim=-1)

                val_loss += criterion(out, labels.long()).item()

                # count correct predictions on-device; .item() gives a Python int
                corrects += (preds == labels).sum().item()
                total += labels.size(0)

        task_acc = (corrects/total)*100
        print("Loss: {:.6f}, Acc: {:.6f}".format(val_loss/len(test_loader), task_acc))

    task_perm_final_accs.append(task_acc) # save final accuracy in current task permutation

print("Final Accs: ", task_perm_final_accs, " Average Final Acc: ", np.array(task_perm_final_accs).mean())

100%|██████████| 5/5 [00:01<00:00,  3.77it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.933185, Acc: 70.450000


100%|██████████| 5/5 [00:02<00:00,  1.73it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.646325, Acc: 77.370000


100%|██████████| 5/5 [00:02<00:00,  1.92it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.481755, Acc: 84.030000


100%|██████████| 5/5 [00:02<00:00,  1.82it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.349098, Acc: 89.740000


100%|██████████| 5/5 [00:02<00:00,  1.87it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.374982, Acc: 89.140000


100%|██████████| 5/5 [00:01<00:00,  3.60it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.674914, Acc: 81.220000


100%|██████████| 5/5 [00:02<00:00,  1.95it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.708715, Acc: 75.180000


100%|██████████| 5/5 [00:02<00:00,  2.09it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.535044, Acc: 81.950000


100%|██████████| 5/5 [00:02<00:00,  1.86it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.431891, Acc: 86.990000


100%|██████████| 5/5 [00:03<00:00,  1.55it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.381842, Acc: 88.580000


100%|██████████| 5/5 [00:01<00:00,  3.21it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.064345, Acc: 61.600000


100%|██████████| 5/5 [00:02<00:00,  1.80it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.490051, Acc: 86.140000


100%|██████████| 5/5 [00:03<00:00,  1.57it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.535656, Acc: 83.430000


100%|██████████| 5/5 [00:03<00:00,  1.58it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.486655, Acc: 85.720000


100%|██████████| 5/5 [00:03<00:00,  1.74it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.473829, Acc: 83.890000


100%|██████████| 5/5 [00:01<00:00,  3.78it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.085968, Acc: 61.050000


100%|██████████| 5/5 [00:02<00:00,  1.91it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.494040, Acc: 85.460000


100%|██████████| 5/5 [00:02<00:00,  1.76it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.508179, Acc: 84.320000


100%|██████████| 5/5 [00:02<00:00,  1.91it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.457037, Acc: 86.970000


100%|██████████| 5/5 [00:02<00:00,  2.03it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.478863, Acc: 83.730000


100%|██████████| 5/5 [00:01<00:00,  3.30it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 1.035301, Acc: 64.340000


100%|██████████| 5/5 [00:02<00:00,  1.97it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.620146, Acc: 80.830000


100%|██████████| 5/5 [00:02<00:00,  1.88it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.493666, Acc: 85.640000


100%|██████████| 5/5 [00:02<00:00,  1.80it/s]
  0%|          | 0/5 [00:00<?, ?it/s]

Loss: 0.379954, Acc: 88.770000


100%|██████████| 5/5 [00:02<00:00,  1.90it/s]

Loss: 0.516470, Acc: 82.750000
Final Accs:  [89.14, 88.58, 83.89, 83.73, 82.75]  Average Final Acc:  85.61800000000001



