# Gradient Episodic Memory

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision

import pyfiles.GEM as GEM

### Define P-MNIST data loader

In [2]:
class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    """MNIST split whose flattened pixels are reordered by one fixed permutation.

    Every image of the selected split is flattened, reordered with
    ``shuffle_seed``, converted to ``float32`` and divided by 255.0 once, at
    construction time; ``__getitem__`` then only indexes the cached tensor.

    Args:
        source: Root directory handed to ``torchvision.datasets.MNIST``.
        train: ``True`` to use the training split, ``False`` for the test split.
        shuffle_seed: Integer index sequence (e.g. a shuffled ``np.arange(784)``)
            giving the new pixel order. ``None`` keeps the original order.
    """

    def __init__(self, source='data/mnist_data', train = True, shuffle_seed = None):
        # download=True is forwarded to the parent constructor.
        super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)

        self.train = train

        # `data` / `targets` hold whichever split `train` selected; they replace
        # the deprecated train_data / test_data / train_labels / test_labels aliases.
        num_images = self.data.shape[0]
        flat_pixels = self.data.reshape(num_images, -1)

        if shuffle_seed is None:
            # Indexing with None would insert an extra axis instead of leaving
            # the pixels untouched, so fall back to the identity ordering.
            pixel_order = torch.arange(flat_pixels.shape[1])
        else:
            pixel_order = torch.as_tensor(shuffle_seed, dtype=torch.long)

        # One vectorised column gather for the whole split instead of a Python
        # loop over individual images; the per-pixel values are the same.
        permuted = flat_pixels[:, pixel_order].type(dtype=torch.float32) / 255.0

        # Keep the split-specific attribute names for existing callers.
        if self.train:
            self.permuted_train_data = permuted
        else:
            self.permuted_test_data = permuted
        self.num_data = num_images

    def __getitem__(self, index):
        """Return ``(permuted_pixels, label)`` for the sample at ``index``."""
        if self.train:
            pixels = self.permuted_train_data[index]
        else:
            pixels = self.permuted_test_data[index]

        return pixels, self.targets[index]

    def getNumData(self):
        """Return the number of samples in the loaded split."""
        return self.num_data

### Set hyperparameters & get permuted MNIST

In [6]:
# Settings shared by every cell below.
batch_size, learning_rate, num_task = 64, 1e-2, 10
criterion = torch.nn.CrossEntropyLoss()
# True only when torch reports a usable CUDA device.
cuda_available = torch.cuda.is_available()

In [3]:
def permute_mnist():
    """Build one permuted-MNIST train/test loader pair per task.

    For each of the ``num_task`` tasks a new pixel ordering is drawn from the
    global NumPy RNG and applied to both the train and the test split.

    Returns:
        A 4-tuple ``(train_loaders, test_loaders, train_count, test_count)``:
        two dicts keyed by task index holding ``DataLoader`` objects, followed
        by the per-task average number of train and test samples (as ints).
    """
    loaders_train, loaders_test = {}, {}
    train_sizes, test_sizes = [], []

    for task_id in range(num_task):
        # The same ordering must be used for both splits of a task.
        pixel_order = np.arange(28*28)
        np.random.shuffle(pixel_order)

        train_set = PermutedMNISTDataLoader(train=True, shuffle_seed=pixel_order)
        test_set = PermutedMNISTDataLoader(train=False, shuffle_seed=pixel_order)

        train_sizes.append(train_set.getNumData())
        test_sizes.append(test_set.getNumData())

        loaders_train[task_id] = torch.utils.data.DataLoader(train_set, batch_size=batch_size)
        loaders_test[task_id] = torch.utils.data.DataLoader(test_set, batch_size=batch_size)

    avg_train = int(sum(train_sizes)/num_task)
    avg_test = int(sum(test_sizes)/num_task)
    return loaders_train, loaders_test, avg_train, avg_test

train_loader, test_loader, train_data_num, test_data_num = permute_mnist()

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/train-images-idx3-ubyte.gz


9920512it [00:02, 3314130.95it/s]                            


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 48141.24it/s]                           
0it [00:00, ?it/s]

Extracting data/mnist_data/PermutedMNISTDataLoader/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/t10k-images-idx3-ubyte.gz


1654784it [00:01, 901746.75it/s]                            
0it [00:00, ?it/s]

Extracting data/mnist_data/PermutedMNISTDataLoader/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 18292.85it/s]            


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!




### Define Neural Net

In [4]:
class NeuralNet(torch.nn.Module):
    """Fully connected network: Linear -> ReLU -> Linear -> ReLU -> Linear.

    Args:
        input_dim: Size of each flattened input vector (default 28*28).
        hidden_dim: Width of both hidden layers (default 100).
        output_dim: Number of output logits (default 100, the value this
            notebook has always used).

    NOTE(review): MNIST has 10 classes, so ``output_dim=10`` is presumably what
    was intended; the default is kept at 100 so existing results and parameter
    shapes do not change. Confirm before switching.
    """

    def __init__(self, input_dim=28*28, hidden_dim=100, output_dim=100):
        # torch.nn.Module must be initialised before submodules are attached.
        super(NeuralNet, self).__init__()

        # Layers are created in input-to-output order.
        self.fc_module = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, hidden_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_dim, output_dim),
        )

        # Move the layers to the GPU whenever one is available.
        if torch.cuda.is_available():
            self.fc_module = self.fc_module.cuda()

    def forward(self, x):
        """Apply the layer stack to ``x`` and return the raw outputs."""
        return self.fc_module(x)

### Continual Learning with GEM

In [None]:
# Episodic-memory sizes to compare.
memsize_list = [100, 300, 1000, 3000, 10000]

for mem_size in memsize_list:
    # Start every memory size from a freshly initialised network and optimiser.
    # Reusing one network across the sweep would let each run begin from
    # weights already trained on all tasks by the previous run, so the
    # results for different memory sizes would not be comparable.
    net = NeuralNet()
    optim = torch.optim.SGD(net.parameters(), lr=learning_rate)

    gem = GEM.GEMLearning(net = net,
                          tasks = num_task,
                          optim = optim,
                          criterion = criterion,
                          mem_size = mem_size,
                          traindata_len = train_data_num,
                          testdata_len = test_data_num,
                          batch_size = batch_size)

    for i in range(num_task):
        # Train on task i, then evaluate on every task seen so far (0..i).
        gem.train(train_loader[i], i)

        for j in range(i+1):
            gem.eval(test_loader[j], j)

        # gem.R is presumably the task-by-task result matrix filled in by
        # gem.eval - TODO confirm against pyfiles/GEM.py.
        print(gem.R)

[78400, 100, 10000, 100, 10000, 100]
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0
    nesterov: False
    weight_decay: 0
)
CrossEntropyLoss()
Memory size:  100
[1	100] AVG. loss: 0.461

[1	200] AVG. loss: 0.225

[1	300] AVG. loss: 0.146

[1	400] AVG. loss: 0.102

[1	500] AVG. loss: 0.073

[1	600] AVG. loss: 0.051

[1	700] AVG. loss: 0.036

[1	800] AVG. loss: 0.026

[1	900] AVG. loss: 0.019

tensor([[82.4700,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,