# Gradient Episodic Memory

In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torchvision
import quadprog

### Define P-MNIST data loader

In [2]:
class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    def __init__(self, source='data/mnist_data', train = True, shuffle_seed = None):
        super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)
        
        self.train = train
        self.num_data = 0
        
        if self.train:
            self.permuted_train_data = torch.stack(
                [img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
                    for img in self.train_data])
            self.num_data = self.permuted_train_data.shape[0]
            
        else:
            self.permuted_test_data = torch.stack(
                [img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
                    for img in self.test_data])
            self.num_data = self.permuted_test_data.shape[0]
            
            
    def __getitem__(self, index):
        
        if self.train:
            input, label = self.permuted_train_data[index], self.train_labels[index]
        else:
            input, label = self.permuted_test_data[index], self.test_labels[index]
        
        return input, label

    
    def getNumData(self):
        return self.num_data

### Set hyperparameters & get permuted MNIST

In [6]:
batch_size = 64
learning_rate = 1e-2
num_task = 10
criterion = torch.nn.CrossEntropyLoss()
cuda_available = False
if torch.cuda.is_available():
    cuda_available = True

In [3]:
def permute_mnist():
    train_loader = {}
    test_loader = {}
    
    train_data_num = 0
    test_data_num = 0
    
    for i in range(num_task):
        shuffle_seed = np.arange(28*28)
        np.random.shuffle(shuffle_seed)
        
        train_PMNIST_DataLoader = PermutedMNISTDataLoader(train=True, shuffle_seed=shuffle_seed)
        test_PMNIST_DataLoader = PermutedMNISTDataLoader(train=False, shuffle_seed=shuffle_seed)
        
        train_data_num += train_PMNIST_DataLoader.getNumData()
        test_data_num += test_PMNIST_DataLoader.getNumData()
        
        train_loader[i] = torch.utils.data.DataLoader(
                train_PMNIST_DataLoader,
                batch_size=batch_size)
        
        test_loader[i] = torch.utils.data.DataLoader(
                test_PMNIST_DataLoader,
                batch_size=batch_size)
    
    return train_loader, test_loader, int(train_data_num/num_task), int(test_data_num/num_task)

train_loader, test_loader, train_data_num, test_data_num = permute_mnist()

0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/train-images-idx3-ubyte.gz


9920512it [00:02, 3314130.95it/s]                            


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/train-images-idx3-ubyte.gz


0it [00:00, ?it/s]

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/train-labels-idx1-ubyte.gz


32768it [00:00, 48141.24it/s]                           
0it [00:00, ?it/s]

Extracting data/mnist_data/PermutedMNISTDataLoader/raw/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/t10k-images-idx3-ubyte.gz


1654784it [00:01, 901746.75it/s]                            
0it [00:00, ?it/s]

Extracting data/mnist_data/PermutedMNISTDataLoader/raw/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/mnist_data/PermutedMNISTDataLoader/raw/t10k-labels-idx1-ubyte.gz


8192it [00:00, 18292.85it/s]            


Extracting data/mnist_data/PermutedMNISTDataLoader/raw/t10k-labels-idx1-ubyte.gz
Processing...
Done!




### Define Neural Net

In [4]:
class NeuralNet(torch.nn.Module):
    def __init__(self):
        # Always start with inheriting torch.nn.Module
        # Ancestor class of all Neural Net module
        super(NeuralNet, self).__init__()
        # Linear: linear transformation
        fc1 = torch.nn.Linear(28*28, 100)
        fc2 = torch.nn.Linear(100, 100)
        fc3 = torch.nn.Linear(100, 100)
  
        
        self.fc_module = torch.nn.Sequential(
            fc1,
            torch.nn.ReLU(),
            fc2,
            torch.nn.ReLU(),
            fc3
        )
        
        if torch.cuda.is_available():
            self.fc_module = self.fc_module.cuda()

    def forward(self, x):
        return self.fc_module(x)

### Gradient Episodic Memry

In [None]:
class GEMLearning(torch.nn.Module):
    def __init__(self, **kwargs):
        super(GEMLearning, self).__init__()
        self.net = kwargs['net']
        self.tasks = kwargs['tasks']
        self.optim = kwargs['optim']
        self.criterion = kwargs['criterion']
        self.mem_size = kwargs['mem_size']
        #self.num_input = kwargs['num_input']
        self.traindata_len = kwargs['traindata_len']
        self.testdata_len = kwargs['testdata_len']
        self.batch_size = kwargs['batch_size']
        
        # Initiallize Episodic Memory
        self.ep_mem = torch.FloatTensor(self.tasks, self.mem_size, 28*28)
        self.ep_labels = torch.LongTensor(self.tasks, self.mem_size)
        if cuda_available:
            self.ep_mem = self.ep_mem.cuda()
            self.ep_labels = self.ep_labels.cuda()

        # Save each parameters' number of elements(numels)
        self.grad_numels = []
        for params in self.parameters():
            self.grad_numels.append(params.data.numel())
        print(self.grad_numels)
        # Make matrix for gradient w.r.t. past tasks
        self.G = torch.zeros((sum(self.grad_numels), self.tasks))
        if cuda_available:
            self.G = self.G.cuda()

        # Make matrix for accuracy w.r.t. past tasks
        self.R = torch.zeros((self.tasks, self.tasks))
        if cuda_available:
            self.R = self.R.cuda()

        #msg = "Optimizer: {}\nCriterion: {}\nEpisodic Memory Size: {}\n"%(self.optim, self.criterion, self.mem_size)
        print(self.optim)
        print(self.criterion)
        print("Memory size: ", self.mem_size)
        #self.log_file.write(msg)
        
    def train(self, data_loader, task):
        self.cur_task = task
        running_loss = 0.0
        input_stack = torch.zeros((self.traindata_len, 28*28))
        label_stack = torch.zeros((self.traindata_len))
        if cuda_available:
            input_stack = input_stack.cuda()
            label_stack = label_stack.cuda()
        
        
        for i, data in enumerate(data_loader):
            #print(data)
            x, y = data
            #input_stack = np.vstack((input_stack, x))
            #label_stack = np.hstack((label_stack, y))
            input_stack[i*self.batch_size: (i+1)*self.batch_size] = x.clone()
            label_stack[i*self.batch_size: (i+1)*self.batch_size] = y.clone()
            if cuda_available:
                x = x.cuda()
                y = y.cuda()
                
            
            if self.cur_task > 0:
                #pdb.set_trace()
                # Compute gradient w.r.t. past tasks by using episodic memory
                for k in range(0, self.cur_task):
                    self.zero_grad()
                    pred_ = self.net(self.ep_mem[k])
                    pred_[:, : k * 10] = -10e10
                    pred_[:, (k+1) * 10:] = -10e10
                    
                    pred_ = pred_[:, k*10: (k+1)*10]
                    
                    label_ = self.ep_labels[k]
                    loss_ = self.criterion(pred_, label_)
                    loss_.backward()
        
                    # Copy parameters into Matrix "G"
                    j = 0
                    for params in self.parameters():
                        if params is not None:
                            if j == 0:
                                stpt = 0
                            else:
                                stpt = sum(self.grad_numels[:j])
            
                            endpt = sum(self.grad_numels[:j+1])
                            self.G[stpt:endpt, k].data.copy_(params.grad.view(self.G[stpt:endpt, k].data.size()))
                            j += 1
                    
            self.zero_grad()
            
            # Compute gradient w.r.t. current continuum
            pred = self.net(x)#[:, self.cur_task * 10 : (self.cur_task + 1) * 10]
            pred[:, : self.cur_task * 10] = -10e10
            pred[:, (self.cur_task+1) * 10:] = -10e10
                    
            pred = pred[:, self.cur_task*10: (self.cur_task+1)*10]
            loss = self.criterion(pred, y)
            loss.backward()
            running_loss += loss.detach().item()
            
            if i % 100 == 99:
                msg = '[%d\t%d] AVG. loss: %.3f\n'% (task+1, i+1, running_loss/(i*5))
                print(msg)
                #self.log_file.write(msg)
                running_loss = 0.0
            
            if self.cur_task > 0:
                grad = []
                for param in self.net.parameters():
                    grad.append(param.detach().view(-1))
                
                grad = torch.cat(grad)
                self.G[:, self.cur_task].data.copy_(grad)
                
                # Solve Quadratic Problem 
                prod = torch.mm(self.G[:, self.cur_task].unsqueeze(0), self.G[:, :self.cur_task])
                if(prod < 0).sum() != 0: # There are some violations: do projection
                    mem_grad_np = self.G[:, :self.cur_task+1].cpu().t().double().numpy()
                    curtask_grad_np = self.G[:, self.cur_task].unsqueeze(1).cpu().contiguous().view(-1).double().numpy()
                    
                    t = mem_grad_np.shape[0]
                    P = np.dot(mem_grad_np, np.transpose(mem_grad_np))
                    P = 0.5 * (P + np.transpose(P)) + np.eye(t) * 1e-3#eps
                    q = np.dot(mem_grad_np, curtask_grad_np) * (-1)
                    G = np.eye(t)
                    h = np.zeros(t) + 0.5 #margin: hyperparameter 
                    v = quadprog.solve_qp(P, q, G, h)[0]
                    x = np.dot(v, mem_grad_np) + curtask_grad_np
                    #print(torch.Tensor(x).shape)
                    #print(self.G[:, self.cur_task].shape)
                    self.G[:, self.cur_task].copy_(torch.Tensor(x))
                    #grad = torch.Tensor(x).view(-1).detach().clone()
    
                    # Copy gradients into params
                    j = 0
                    for params in self.parameters():
                        if params is not None:
                            if j == 0:
                                stpt = 0
                            else:
                                stpt = sum(self.grad_numels[:j])
        
                            endpt = sum(self.grad_numels[:j+1])
                            
                            copy_object_grad = self.G[stpt:endpt, self.cur_task].contiguous().view(params.grad.data.size())
                            params.grad.data.copy_(copy_object_grad)
                            j += 1
            
            self.optim.step()
            
        perm = torch.randperm(self.traindata_len)
        perm = perm[:self.mem_size]
        self.ep_mem[self.cur_task] = input_stack[perm].detach().clone().float()
        self.ep_labels[self.cur_task] = label_stack[perm].detach().clone()

    def eval(self, data_loader, task):
        total = 0
        correct = 0
        self.net.eval()
        for i, data in enumerate(data_loader):
            x, y = data
            if cuda_available:
                x = x.cuda()
                y = y.cuda()
                
            output = self.net(x)[:, task * 10: (task+1) * 10]
            _, predicted = torch.max(output, dim=1)
            total += y.size(0)
            correct += (predicted == y).sum().item()
    
            self.R[self.cur_task][task] = 100 * correct / total

### Continual Learnig with GEM

In [None]:
net = NeuralNet()
optim = torch.optim.SGD(net.parameters(), lr=learning_rate)
memsize_list = [100, 300, 1000, 3000, 10000]
#logfile_name = "logfile_training_gem_%d_%d_%d_%d_%d.txt" % (dt.year, dt.month, dt.day, dt.hour, dt.minute)

for mem_size in memsize_list:
    gem = GEMLearning(net = net,
                          tasks = num_task,
                          optim = optim,
                          criterion = criterion,
                          mem_size = mem_size,
                          #num_input = ,
                          traindata_len = train_data_num,
                          testdata_len = test_data_num,
                          batch_size = batch_size)
    
    for i in range(num_task):
        gem.train(train_loader[i], i)
        
        for j in range(i+1):
            gem.eval(test_loader[j], j)
            
        print(gem.R)

[78400, 100, 10000, 100, 10000, 100]
SGD (
Parameter Group 0
    dampening: 0
    lr: 0.01
    momentum: 0
    nesterov: False
    weight_decay: 0
)
CrossEntropyLoss()
Memory size:  100
[1	100] AVG. loss: 0.461

[1	200] AVG. loss: 0.225

[1	300] AVG. loss: 0.146

[1	400] AVG. loss: 0.102

[1	500] AVG. loss: 0.073

[1	600] AVG. loss: 0.051

[1	700] AVG. loss: 0.036

[1	800] AVG. loss: 0.026

[1	900] AVG. loss: 0.019

tensor([[82.4700,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,