# Gradient Episodic Memory

In [0]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision

import pyfiles.GEMLearning as GEMLearning

### Define P-MNIST data loader

In [0]:
class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    """MNIST dataset whose flattened pixels are reordered by a fixed index.

    Despite the name this is a dataset (it subclasses
    ``torchvision.datasets.MNIST``), intended to be wrapped in a
    ``torch.utils.data.DataLoader``. Every image of the selected split is
    flattened, reordered with ``shuffle_seed``, converted to float32 and
    divided by 255 once, at construction time.

    Args:
        source: Root directory handed to ``MNIST`` (with ``download=True``).
        train: Select the training split when true, the test split otherwise.
        shuffle_seed: Integer index sequence applied to each flattened image.
            ``None`` keeps the pixels in their original order.
    """

    def __init__(self, source='data/mnist_data', train = True, shuffle_seed = None):
        super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)

        self.train = train

        # `data` / `targets` are the current attribute names; the per-split
        # aliases (`train_data`, `test_labels`, ...) are deprecated.
        num_images = self.data.shape[0]
        flat_pixels = self.data.reshape(num_images, -1)

        if shuffle_seed is not None:
            # One vectorised column gather instead of a per-image Python loop.
            pixel_order = torch.as_tensor(shuffle_seed, dtype=torch.long)
            flat_pixels = flat_pixels[:, pixel_order]

        # `.to(float32)` allocates a new tensor, so the in-place division
        # never touches the raw uint8 images kept in `self.data`.
        permuted = flat_pixels.to(torch.float32).div_(255.0)

        # Keep the split-specific attribute names for existing callers.
        if self.train:
            self.permuted_train_data = permuted
        else:
            self.permuted_test_data = permuted

        self.num_data = permuted.shape[0]

    def __getitem__(self, index):
        """Return ``(permuted_pixels, label)`` for the sample at ``index``."""
        if self.train:
            samples = self.permuted_train_data
        else:
            samples = self.permuted_test_data

        return samples[index], self.targets[index]

    def getNumData(self):
        """Return the number of samples in the selected split."""
        return self.num_data

### Set hyperparameters & get permuted MNIST

In [0]:
# Experiment configuration shared by the cells below.
num_task = 10
batch_size = 64
learning_rate = 1e-3
criterion = torch.nn.CrossEntropyLoss()

# Mirrors torch's own CUDA availability check.
cuda_available = torch.cuda.is_available()

In [0]:
def permute_mnist():
    """Build one train and one test loader per task, each task with its own pixel order.

    Reads the module-level ``num_task`` and ``batch_size``.

    Returns:
        ``(train_loader, test_loader, train_count, test_count)``: the loaders
        are dicts keyed by task index, and each count is the summed dataset
        size over all tasks divided by ``num_task`` and truncated to int.
    """
    loaders = {True: {}, False: {}}
    totals = {True: 0, False: 0}

    for task in range(num_task):
        # Draw a fresh ordering of the 28*28 pixel positions for this task.
        pixel_order = np.arange(28*28)
        np.random.shuffle(pixel_order)

        # Train split first, then test; both share this task's ordering.
        for is_train in (True, False):
            dataset = PermutedMNISTDataLoader(train=is_train, shuffle_seed=pixel_order)
            totals[is_train] += dataset.getNumData()
            loaders[is_train][task] = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=batch_size)

    return (loaders[True],
            loaders[False],
            int(totals[True]/num_task),
            int(totals[False]/num_task))

# Build the per-task loaders once; the two counts are the per-task sample
# counts (totals divided by num_task) returned by permute_mnist().
train_loader, test_loader, train_data_num, test_data_num = permute_mnist()

### Define Neural Net

In [0]:
class NeuralNet(torch.nn.Module):
    """Three-layer fully connected network: 784 -> 100 -> 100 -> 100 with ReLU in between.

    The layer stack is moved to the GPU at construction time when CUDA is
    available.
    """

    def __init__(self):
        super().__init__()

        # (in_features, out_features) for each linear layer, in order.
        # NOTE(review): the last layer emits 100 values; confirm this is
        # intended for 10-class MNIST.
        layer_sizes = [(28*28, 100), (100, 100), (100, 100)]

        layers = []
        for in_features, out_features in layer_sizes:
            layers.append(torch.nn.Linear(in_features, out_features))
            layers.append(torch.nn.ReLU())
        layers.pop()  # no activation after the final linear layer

        stack = torch.nn.Sequential(*layers)
        self.fc_module = stack.cuda() if torch.cuda.is_available() else stack

    def forward(self, x):
        """Apply the fully connected stack to ``x`` and return its output."""
        return self.fc_module(x)

### Continual Learning with GEM

In [0]:
# Episodic-memory sizes to sweep over; one full run is trained per value.
memsize_list = [100, 300, 1000, 3000]

# `import pyfiles.GEMLearning as GEMLearning` binds the *module*, and calling a
# module raises TypeError. Use the class of the same name when the module
# exposes one (assumed name -- confirm against pyfiles/GEMLearning.py);
# otherwise fall back to the bound object unchanged.
gem_factory = getattr(GEMLearning, "GEMLearning", GEMLearning)

for mem_size in memsize_list:
    # Fresh network and optimizer for every memory size.
    net = NeuralNet()
    # NOTE(review): lr is hard-coded to 0.1 while the module-level
    # `learning_rate` (1e-3) is never used -- confirm which is intended.
    optim = torch.optim.SGD(net.parameters(), lr=0.1)
    gem = gem_factory(net = net,
                      tasks = num_task,
                      optim = optim,
                      criterion = criterion,
                      mem_size = mem_size,
                      traindata_len = train_data_num,
                      testdata_len = test_data_num,
                      batch_size = batch_size,
                      margin = 0.5,
                      eps = 0.001)

    for i in range(num_task):
        gem.train(train_loader[i], i)

        # After training on task i, evaluate on every task seen so far.
        for j in range(i+1):
            gem.eval(test_loader[j], j)

        print(gem.R)