# Gradient Episodic Memory

In [0]:
!pip install quadprog
import numpy as np
import matplotlib.pyplot as plt
import quadprog
from copy import deepcopy 

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

### Define P-MNIST data loader

In [0]:
class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    """MNIST dataset whose flattened pixels are reordered by a fixed permutation.

    Every image is flattened to a 784-element float32 vector, reordered with
    ``shuffle_seed`` and scaled from [0, 255] to [0, 1]. The whole split is
    converted once at construction time, so ``__getitem__`` is a plain lookup.

    Args:
        source: Directory the MNIST files are stored in (downloaded if absent).
        train: Use the training split when True, the test split otherwise.
        shuffle_seed: Integer index array of length 784 giving the new pixel
            order. ``None`` keeps the original pixel order.
    """

    def __init__(self, source='./drive/My Drive/projects/cont_learn/mnist_data', train = True, shuffle_seed = None):
        super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)

        self.train = train
        images = self.train_data if self.train else self.test_data

        # Flatten the whole split at once instead of looping image by image.
        flat = images.type(dtype=torch.float32).reshape(images.size(0), -1)
        if shuffle_seed is not None:
            # Indexing with None would insert an axis instead of permuting,
            # so the permutation is applied only when one was supplied.
            order = torch.as_tensor(shuffle_seed, dtype=torch.long)
            flat = flat[:, order]
        permuted = flat / 255.0

        if self.train:
            self.permuted_train_data = permuted
        else:
            self.permuted_test_data = permuted

    def __getitem__(self, index):
        """Return the ``(permuted_image, label)`` pair stored at ``index``."""
        if self.train:
            image, label = self.permuted_train_data[index], self.train_labels[index]
        else:
            image, label = self.permuted_test_data[index], self.test_labels[index]

        return image, label

    def sample(self, size):
        """Return ``size`` distinct, randomly chosen permuted images as a list.

        Raises:
            ValueError: If ``size`` is negative or larger than the dataset.
        """
        total = len(self)
        if not 0 <= size <= total:
            raise ValueError(
                "sample size %d is outside the range [0, %d]" % (size, total))

        # Draw from whichever split this instance actually holds.
        pool = self.permuted_train_data if self.train else self.permuted_test_data
        chosen = torch.randperm(total)[:size]
        return list(pool[chosen])
    

### Set hyperparameters & get permuted MNIST

In [0]:
# --- Data / continual-learning setup -------------------------------------
num_task = 10        # number of permuted-MNIST tasks built by permute_mnist()
batch_size = 64      # mini-batch size handed to every DataLoader
sample_size = 100    # NOTE(review): presumably the per-task memory size; confirm

# --- Optimisation ---------------------------------------------------------
learning_rate = 0.001
num_epochs = 20

# Loss shared by every training step in this notebook.
criterion = nn.CrossEntropyLoss()

def permute_mnist():
    """Build one train and one test DataLoader per task.

    A fresh random pixel permutation is drawn for each of the ``num_task``
    tasks and shared by that task's train and test splits.

    Returns:
        A ``(train_loaders, test_loaders)`` pair of dicts keyed by task index.
    """
    train_loaders, test_loaders = {}, {}

    for task_id in range(num_task):
        pixel_order = np.arange(28*28)
        np.random.shuffle(pixel_order)

        # Train split first, then test split, both with the same permutation.
        for is_train, loaders in ((True, train_loaders), (False, test_loaders)):
            dataset = PermutedMNISTDataLoader(train=is_train, shuffle_seed=pixel_order)
            loaders[task_id] = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    return train_loaders, test_loaders


train_loader, test_loader = permute_mnist()

### Define Neural Net

In [0]:
class NeuralNet(nn.Module):
    """Fully connected classifier: 784 inputs -> 100 hidden (ReLU) -> 10 outputs."""

    def __init__(self):
        super().__init__()

        # Layers are created in input-to-output order and wrapped in a single
        # Sequential container registered as ``fc_module``.
        layers = [
            nn.Linear(28*28, 100),
            nn.ReLU(),
            nn.Linear(100, 10),
        ]
        stack = nn.Sequential(*layers)

        # Place the network on the GPU whenever one is available.
        self.fc_module = stack.cuda() if torch.cuda.is_available() else stack

    def forward(self, x):
        """Run ``x`` through the fully connected stack and return the result."""
        logits = self.fc_module(x)
        return logits

### GEM Learning Function

In [0]:
def gem_eval(continuum, net=None):
    """Return the classification accuracy of ``net`` on every task.

    Args:
        continuum: Container indexed by task id ``0 .. len(continuum) - 1``
            (for example a dict of DataLoaders). Each entry is an iterable of
            ``(inputs, labels)`` batches.
        net: Model to evaluate. It must map a batch of inputs to per-class
            scores of shape ``(batch, num_classes)``.

    Returns:
        A float64 tensor of length ``len(continuum)`` whose k-th entry is the
        fraction of correctly classified samples of task ``k``. A task without
        samples is reported as 0.

    Raises:
        ValueError: If ``net`` is not supplied.
    """
    if net is None:
        raise ValueError("gem_eval needs the model to evaluate (net=...)")

    num_tasks = len(continuum)
    r = torch.zeros(num_tasks, dtype=torch.float64)

    # Evaluate on the device that holds the model's weights.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for k in range(num_tasks):
                correct = 0
                seen = 0
                for inputs, labels in continuum[k]:
                    inputs = inputs.to(device)
                    labels = labels.to(device)
                    predicted = net(inputs).argmax(dim=1)
                    correct += int((predicted == labels).sum().item())
                    seen += int(labels.numel())
                if seen:
                    r[k] = correct / seen
    finally:
        # Leave the model in the mode the caller had it in.
        net.train(was_training)

    return r

In [0]:
def gem_train(net, optimizer, num_tasks, log_file):
    """Pseudo-code outline of a GEM training loop (not executable as written).

    NOTE(review): the body reads like an algorithm sketch rather than working
    code. It refers to names that are not defined in the visible source
    (``Tasks``, ``Contimuum_train``, ``Continuum_test``, ``M_t``, ``M_k``,
    ``gradient``, ``Loss``, ``f``, ``Project``, ``g_1``, ``g_2``, ``theta``,
    ``alpha``), and none of the four parameters is used in the body.
    """
    # Assigned but never read afterwards.
    ep_mem = {}
    # NOTE(review): `t` is read here before the loop below binds it.
    R = torch.tensor(np.zeros((t, t)))
    for t in Tasks:
        # NOTE(review): spelled `Contimuum_train` here but `Continuum_test`
        # further down -- confirm the intended names.
        for x, y in Contimuum_train(t):
            M_t.append(x, y)
            g = gradient(Loss(f(x, t), y))
            # NOTE(review): iterates over `t` itself; each pass overwrites g_k.
            for k in t:
                g_k = gradient(Loss(f, M_k))
            # The bare `...` is Python's Ellipsis object, passed as an argument.
            g_tilda = Project(g, g_1, g_2, ... , g_k)
            theta -= alpha*g_tilda

        # NOTE(review): R_t is never written into R; also check this call
        # against gem_eval's signature.
        R_t = gem_eval(f, Continuum_test)

    return f, R

In [0]:
class GEMLearning():
    """Gradient Episodic Memory (GEM) trainer.

    Keeps a small episodic memory of samples for every task. At each training
    step the gradient of the current batch is compared with the gradients
    computed on the memories of earlier tasks; if it would increase the loss
    on any of them, it is projected onto the closest gradient that does not.

    Required keyword arguments:
        net: Model to train. It receives flat inputs of ``num_input`` features.
        optim: Optimizer built over ``net.parameters()``.
        num_tasks: Total number of tasks; task ids are ``0 .. num_tasks - 1``.
        logfile_name: Path of the log file, opened for writing.
        mem_size: Number of samples remembered per task.
        num_input: Number of features of one flattened sample.

    Optional keyword arguments:
        criterion: Loss callable ``criterion(pred, target)``. Defaults to
            ``nn.CrossEntropyLoss()``.
        margin: Lower bound on the dual variables of the projection (0.0).
        eps: Ridge added to the QP matrix to keep it positive definite (1e-3).
    """

    def __init__(self, **kwargs):
        self.net = kwargs['net']
        self.optim = kwargs['optim']
        self.num_tasks = kwargs['num_tasks']
        self.mem_size = kwargs['mem_size']
        self.num_input = kwargs['num_input']
        logfile_name = kwargs['logfile_name']

        self.criterion = kwargs.get('criterion')
        if self.criterion is None:
            self.criterion = nn.CrossEntropyLoss()
        self.margin = kwargs.get('margin', 0.0)
        self.eps = kwargs.get('eps', 1e-3)

        # Keep every buffer on the device that holds the model's weights.
        first_param = next(self.net.parameters(), None)
        self.device = (first_param.device if first_param is not None
                       else torch.device('cpu'))

        # Episodic memory: zero-initialised sample and label slots per task.
        self.ep_mem = torch.zeros(self.num_tasks, self.mem_size, self.num_input,
                                  device=self.device)
        # Labels are int64 because the default cross-entropy loss needs that.
        self.labels = torch.zeros(self.num_tasks, self.mem_size,
                                  dtype=torch.long, device=self.device)
        # Filled slots and next write position for every task's memory.
        self.mem_count = [0] * self.num_tasks
        self.mem_ptr = [0] * self.num_tasks

        # Matrix for storing accuracy. It is allocated here but not written by
        # this class; filling it is left to the caller.
        self.R = torch.zeros(self.num_tasks, self.num_tasks, device=self.device)

        # Matrix for storing one flattened gradient per task (lazily allocated).
        self.g_mem = None

        # Opened last so a missing keyword argument cannot leak the handle.
        self.log_file = open(logfile_name, "w")

    def close(self):
        """Close the log file. Safe to call more than once."""
        if not self.log_file.closed:
            self.log_file.close()

    # Run one training step on a single batch.
    def train(self, continuum):
        """Perform one GEM update.

        Args:
            continuum: Triple ``(x, t, y)`` of input batch, task id and labels.

        Returns:
            The loss of the current batch before the update, as a float.
        """
        x, t, y = continuum
        t = int(t)
        x = x.to(self.device)
        y = y.to(self.device)

        self._remember(x, t, y)

        # Calculate all g_k's: gradients on the memories of earlier tasks.
        ref_tasks = []
        for k in range(0, t):
            filled = self.mem_count[k]
            if filled == 0:
                continue
            self.optim.zero_grad()
            pred_ = self.net(self.ep_mem[k, :filled])
            loss_ = self.criterion(pred_, self.labels[k, :filled])
            loss_.backward()
            self.grad_mem_store(k)
            ref_tasks.append(k)

        # Gradient of the current batch.
        self.optim.zero_grad()
        pred = self.net(x)
        loss = self.criterion(pred, y)
        loss.backward()

        if ref_tasks:
            g = self._flat_grad()
            G = self.g_mem[:, ref_tasks]
            g_tilda = self.projection(g, G)
            # Copy g_tilda to the parameters' gradients only if it changed.
            if g_tilda is not g:
                self._write_grad(g_tilda)

        # Update weights
        self.optim.step()
        return loss.item()

    def eval(self, continuum):
        """Return the accuracy of the model on one ``(x, t, y)`` batch.

        The task id is accepted for symmetry with ``train`` but not used.
        An empty batch yields 0.0.
        """
        x, _, y = continuum
        x = x.to(self.device)
        y = y.to(self.device)
        if y.numel() == 0:
            return 0.0

        was_training = self.net.training
        self.net.eval()
        try:
            with torch.no_grad():
                predicted = self.net(x).argmax(dim=1)
        finally:
            self.net.train(was_training)
        return (predicted == y).float().mean().item()

    def projection(self, g=None, G=None):
        """Project ``g`` so that it does not conflict with the columns of ``G``.

        Args:
            g: Flattened gradient of the current batch, shape ``(p,)``.
            G: Reference gradients, one per column, shape ``(p, k)``.

        Returns:
            ``g`` itself when there is nothing to project against or no column
            of ``G`` has a negative inner product with it; otherwise a new
            tensor holding the solution of the GEM quadratic program.
        """
        if g is None or G is None or G.numel() == 0:
            return g
        if G.dim() == 1:
            G = G.unsqueeze(1)

        # A negative inner product means the step would hurt that task.
        if bool((torch.mv(G.t(), g) >= 0).all()):
            return g

        # Imported here so the solver is only needed when a projection happens.
        import quadprog

        mem = G.t().detach().cpu().double().numpy()      # (k, p)
        grad = g.detach().cpu().double().numpy()         # (p,)
        k = mem.shape[0]

        # Dual problem: min 1/2 v' (M M') v + (M g)' v  subject to v >= margin.
        P = mem @ mem.T
        P = 0.5 * (P + P.T) + np.eye(k) * self.eps
        q = -(mem @ grad)
        v = quadprog.solve_qp(P, q, np.eye(k), np.zeros(k) + self.margin)[0]

        projected = v @ mem + grad
        return torch.from_numpy(projected).to(dtype=g.dtype, device=g.device)

    def grad_mem_store(self, column=0):
        """Copy the model's current gradient into column ``column`` of ``g_mem``.

        ``g_mem`` is allocated on first use with one row per model parameter
        element and one column per task.
        """
        flat = self._flat_grad()
        if self.g_mem is None:
            self.g_mem = torch.zeros(flat.numel(), self.num_tasks,
                                     dtype=flat.dtype, device=flat.device)
        self.g_mem[:, column].copy_(flat)

    def _remember(self, x, t, y):
        """Write the batch into task ``t``'s memory, treated as a ring buffer."""
        if self.mem_size == 0 or x.size(0) == 0:
            return
        samples = x.detach().reshape(x.size(0), -1)
        targets = y.detach().reshape(-1).long()

        start = self.mem_ptr[t]
        # Samples that do not fit before the end of the buffer are dropped.
        take = min(samples.size(0), self.mem_size - start)
        self.ep_mem[t, start:start + take].copy_(samples[:take])
        self.labels[t, start:start + take].copy_(targets[:take])

        self.mem_count[t] = max(self.mem_count[t], start + take)
        self.mem_ptr[t] = (start + take) % self.mem_size

    def _flat_grad(self):
        """Return all parameter gradients concatenated into one new vector."""
        pieces = []
        for param in self.net.parameters():
            if param.grad is None:
                # Parameters without a gradient contribute zeros.
                pieces.append(torch.zeros(param.numel(), dtype=param.dtype,
                                          device=param.device))
            else:
                pieces.append(param.grad.detach().reshape(-1))
        return torch.cat(pieces)

    def _write_grad(self, flat):
        """Scatter the flat vector ``flat`` back into the parameters' ``.grad``."""
        offset = 0
        for param in self.net.parameters():
            count = param.numel()
            chunk = flat[offset:offset + count].view_as(param)
            if param.grad is None:
                param.grad = chunk.clone()
            else:
                param.grad.copy_(chunk)
            offset += count

### Continual Learning with GEM

In [0]:
from datetime import datetime

net = NeuralNet()
# Use a separate name for the optimizer: rebinding `optim` would overwrite the
# torch.optim module and break this cell the second time it is run.
optimizer = optim.SGD(net.parameters(), lr=learning_rate)

# Timestamp the log file name down to the minute.
dt = datetime.now()
logfile_name = "logfile_training_gem_%d_%d_%d_%d_%d.txt" % (dt.year, dt.month, dt.day, dt.hour, dt.minute)

# GEMLearning only accepts keyword arguments.
gem = GEMLearning(
    net=net,
    optim=optimizer,
    num_tasks=num_task,
    logfile_name=logfile_name,
    # NOTE(review): sample_size is assumed to be the per-task memory size.
    mem_size=sample_size,
    num_input=28 * 28,  # matches the input width of NeuralNet's first layer
)