<a href="https://colab.research.google.com/github/hursung1/grad_project/blob/master/MNIST_Continual_Learning_Final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
import torchvision.transforms as transforms

In [0]:
class PermutedMNISTDataLoader(torchvision.datasets.MNIST):
    """MNIST dataset whose flattened pixels are re-ordered by a fixed index array.

    Despite the name this is a ``Dataset`` (it subclasses
    ``torchvision.datasets.MNIST``), not a ``DataLoader``.  At construction
    every image of the selected split is cast to float32, flattened with
    ``view(-1)``, indexed with ``shuffle_seed`` and divided by 255.0; the
    result is cached as one stacked tensor.

    Args:
        source: root directory handed to ``torchvision.datasets.MNIST``
            (the data is downloaded there when missing, ``download=True``).
        train: ``True`` selects the training split, ``False`` the test split.
        shuffle_seed: index array applied to each flattened image.  The
            callers in this notebook always pass a permutation of
            ``range(28*28)``.  NOTE(review): with the default ``None`` the
            expression ``view(-1)[None]`` adds a leading axis instead of
            permuting, so pass an explicit permutation.
    """

    def __init__(self, source='./mnist_data', train = True, shuffle_seed = None):
        super(PermutedMNISTDataLoader, self).__init__(source, train, download=True)

        self.train = train
        # Only the split that was requested is materialised, under the
        # attribute name that __getitem__ and sample() look up.
        if self.train:
            self.permuted_train_data = torch.stack(
                [img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
                    for img in self.train_data])
        else:
            self.permuted_test_data = torch.stack(
                [img.type(dtype=torch.float32).view(-1)[shuffle_seed] / 255.0
                    for img in self.test_data])

    def __getitem__(self, index):
        """Return ``(permuted_image, label)`` for ``index`` from the active split."""
        if self.train:
            input, label = self.permuted_train_data[index], self.train_labels[index]
        else:
            input, label = self.permuted_test_data[index], self.test_labels[index]

        return input, label

    def sample(self, size):
        """Return ``size`` distinct, randomly chosen permuted images as a list.

        The images are drawn from whichever split this instance holds.
        ``random.sample`` raises ``ValueError`` when ``size`` exceeds the
        number of cached images.
        """
        pool = self.permuted_train_data if self.train else self.permuted_test_data
        # The original passed ``size`` to range() instead of random.sample(),
        # which raised TypeError on every call.
        picked = random.sample(range(len(pool)), size)
        # Index with a LongTensor so the returned rows are copies, not views
        # into the cached dataset tensor.
        return list(pool[torch.tensor(picked, dtype=torch.long)])

    # __len__ is inherited unchanged from torchvision.datasets.MNIST.

In [11]:
# Number of permuted-MNIST tasks that permute_mnist() generates.
num_tasks = 10
# Mini-batch size handed to every DataLoader that permute_mnist() builds.
batch_size = 64

def permute_mnist():
    """Create a train and a test DataLoader for each permuted-MNIST task.

    One random pixel order is drawn per task with ``np.random.shuffle`` and
    shared by that task's train and test datasets.  Reads the module-level
    ``num_tasks`` and ``batch_size``.

    Returns:
        ``(train_loader, test_loader)``: two dicts keyed by task index
        ``0 .. num_tasks - 1`` whose values are ``DataLoader`` objects.
    """
    # Keyed by the ``train`` flag so both splits go through the same code path.
    loaders_by_split = {True: {}, False: {}}

    for task_id in range(num_tasks):
        pixel_order = np.arange(28*28)
        np.random.shuffle(pixel_order)

        # Train split first, then test split, as before.
        for is_train in (True, False):
            dataset = PermutedMNISTDataLoader(train=is_train, shuffle_seed=pixel_order)
            loaders_by_split[is_train][task_id] = torch.utils.data.DataLoader(
                dataset, batch_size=batch_size)

    return loaders_by_split[True], loaders_by_split[False]

# Build the per-task loaders once. Continual_Learning() reads these two
# module-level dicts (keyed by task index) for training and evaluation.
train_loader, test_loader = permute_mnist()



In [4]:
# Disabled legacy loader setup, kept as a bare string literal so it is never
# executed (the cell merely echoes the string). It refers to train_set and
# test_set, which are not defined in the visible cells of this notebook.
'''
batch_size = 64

#DataLoader: read batsh_size number of data from dataset 
train_loader = torch.utils.data.DataLoader(train_set, 
                                           batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set,
                                        batch_size=batch_size, shuffle=False)
'''                                        

'\nbatch_size = 64\n\n#DataLoader: read batsh_size number of data from dataset \ntrain_loader = torch.utils.data.DataLoader(train_set, \n                                           batch_size=batch_size, shuffle=True)\ntest_loader = torch.utils.data.DataLoader(test_set,\n                                        batch_size=batch_size, shuffle=False)\n'

### Define Neural Net

In [0]:
class NeuralNet(nn.Module):
    """Fully connected classifier: 784 -> 400 -> 400 -> 10 with ReLU in between.

    The layers live in ``self.fc_module`` (an ``nn.Sequential``), which is
    moved to the GPU at construction time when CUDA is available.
    """

    def __init__(self):
        super().__init__()

        # Consecutive pairs of widths define the three linear layers.
        widths = (28*28, 400, 400, 10)
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        # The output layer produces raw scores, so drop the trailing ReLU.
        layers.pop()

        stack = nn.Sequential(*layers)
        self.fc_module = stack.cuda() if torch.cuda.is_available() else stack

    def forward(self, x):
        """Apply the stacked layers to ``x`` and return the 10 class scores."""
        return self.fc_module(x)

### Get Fisher Matrix

In [0]:
def fisher(net, data_loader, task):
    """Estimate a diagonal Fisher matrix for each of the first ``task`` tasks.

    For every task index ``i`` in ``range(task)`` only the first batch yielded
    by ``data_loader[i]`` is used.  For each sample of that batch the network's
    own arg-max prediction serves as the target, the gradient of the
    corresponding negative log-likelihood is taken, and its element-wise
    square is averaged over the batch.

    Args:
        net: ``torch.nn.Module`` whose parameters are differentiated.
        data_loader: object indexable by task index; each entry is an iterable
            of batches whose first element is the input tensor.
        task: number of leading tasks to compute an estimate for.

    Returns:
        A list of length ``task``.  Entry ``i`` is a dict mapping the name of
        every parameter with ``requires_grad=True`` to a tensor of the same
        shape holding the averaged squared gradients.

    Raises:
        ValueError: if ``data_loader[i]`` yields no batch at all.

    The network is switched to eval mode while gradients are collected and put
    back into its previous train/eval mode afterwards.  Parameter ``.grad``
    buffers are overwritten as a side effect.
    """
    # Only trainable parameters are tracked; frozen ones have no gradient and
    # previously caused a KeyError in the accumulation loop.
    trainable = {name: p for name, p in net.named_parameters() if p.requires_grad}

    # Feed the data on whatever device the network actually lives on.
    first_param = next(net.parameters(), None)
    device = None if first_param is None else first_param.device

    fisher_mat = []
    was_training = net.training
    net.eval()
    try:
        for i in range(task):
            first_batch = next(iter(data_loader[i]), None)
            if first_batch is None:
                raise ValueError("data_loader[%d] yielded no batches" % i)

            inputs = first_batch[0]
            if device is not None:
                inputs = inputs.to(device)

            fisher_mat_per_task = {name: torch.zeros_like(p) for name, p in trainable.items()}
            num_samples = len(inputs)

            # Gradients are needed per sample, hence one backward pass each.
            for sample in inputs:
                net.zero_grad()
                output = net(sample).view(1, -1)
                pred = output.max(1)[1].view(-1)
                loss = F.nll_loss(F.log_softmax(output, dim=1), pred)
                loss.backward()

                for name, p in trainable.items():
                    if p.grad is not None:
                        fisher_mat_per_task[name] += p.grad.detach() ** 2 / num_samples

            fisher_mat.append(fisher_mat_per_task)
    finally:
        # Do not leave the caller's network stuck in eval mode.
        net.train(was_training)

    return fisher_mat

### Learning Function

In [0]:
def _cl_penalty(current_params, params_per_tasks, fisher_mat):
    """Sum the squared distances to every stored snapshot, Fisher-weighted if given."""
    reg = 0
    for ind, params_past in enumerate(params_per_tasks):
        for n, past in params_past.items():
            penalty = (past - current_params[n]) ** 2
            # EWC: weight each squared distance by the Fisher estimate.
            if fisher_mat is not None:
                penalty = fisher_mat[ind][n] * penalty
            reg = reg + torch.sum(penalty)
    return reg


def _cl_accuracy(net, loader, device):
    """Return the percentage of correctly classified samples in ``loader``."""
    total = 0
    correct = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for inputs, labels in loader:
            if device is not None:
                inputs = inputs.to(device)
                labels = labels.to(device)
            _, predicted = torch.max(net(inputs), dim=1)
            total += labels.size(0)
            # .item() keeps the counter a plain int even for empty loaders.
            correct += (predicted == labels).sum().item()
    return 100 * correct / total if total else 0.0


def Continual_Learning(net, optimizer, num_tasks, reg_coef, learn_mode = 0,
                       train_loaders = None, test_loaders = None, num_epochs = 20):
    """Train ``net`` on a sequence of tasks and track accuracy on the tasks seen so far.

    Args:
        net: the ``torch.nn.Module`` to train (updated in place).
        optimizer: optimizer already bound to ``net``'s parameters.
        num_tasks: number of tasks; task indices ``0 .. num_tasks - 1`` are
            used to index the loaders.
        reg_coef: coefficient of the penalty term (multiplied by 1/2).
        learn_mode: 0 = plain training, 1 = penalty with L2 distance to the
            parameters saved after each earlier task, 2 = the same distance
            weighted by the result of ``fisher()`` (EWC).
        train_loaders: per-task training loaders; defaults to the module-level
            ``train_loader``.
        test_loaders: per-task test loaders; defaults to the module-level
            ``test_loader``.
        num_epochs: epochs to train on each task.

    Returns:
        A dict mapping task index to the mean test accuracy (in percent) over
        all tasks learned up to and including that task, or ``False`` when
        ``learn_mode`` is outside 0..2.
    """
    if learn_mode > 2 or learn_mode < 0:
        print("Learn mode Error\nplain: 0\tpenalty with L2 distance: 1\tpenalty with ewc: 2")
        return False

    # Fall back to the notebook-level loaders when none are passed in.
    if train_loaders is None:
        train_loaders = train_loader
    if test_loaders is None:
        test_loaders = test_loader

    criterion = nn.CrossEntropyLoss()
    acc = {}
    params_per_tasks = []
    fisher_mat = None

    # Batches are moved to the device the network lives on.
    first_param = next(net.parameters(), None)
    device = None if first_param is None else first_param.device
    trainable = {n : p for n, p in net.named_parameters() if p.requires_grad}

    print("Task\tEpoch")
    for task in range(num_tasks):
        running_loss = 0.0
        batches_seen = 0

        # Get Fisher Matrix
        if len(params_per_tasks) != 0 and learn_mode == 2:
            fisher_mat = fisher(net, train_loaders, task)

        # Train for each task
        for epoch in range(num_epochs):
            for inputs, labels in train_loaders[task]:
                if device is not None:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                optimizer.zero_grad()
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                # Add the penalty that ties the weights to earlier tasks.
                if learn_mode != 0:
                    reg = _cl_penalty(trainable, params_per_tasks,
                                      fisher_mat if learn_mode == 2 else None)
                    loss = loss + (reg_coef / 2) * reg

                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                batches_seen += 1

            if epoch % 5 == 4:
                # Average over the batches actually accumulated since the last
                # report (the old i*5 denominator was off by one batch per
                # epoch and zero for single-batch loaders).
                print('[%d\t%d] AVG. loss: %.3f' % (task+1, epoch + 1, running_loss/max(batches_seen, 1)))
                running_loss = 0.0
                batches_seen = 0

        # Save parameters to use at next iteration: used to cal. penalty term
        if learn_mode != 0:
            params_per_tasks.append({n : p.detach().clone() for n, p in trainable.items()})

        # Test on every task learned so far.
        each_task_acc = [_cl_accuracy(net, test_loaders[j], device) for j in range(task+1)]

        print(each_task_acc)
        each_task_acc = np.asarray(each_task_acc)
        print(each_task_acc)
        print('Average accuracy after training task %d: %d %%' % (task+1, np.mean(each_task_acc)))
        acc[task] = np.mean(each_task_acc)

    return acc

In [30]:
learning_rate = 1e-3
num_task = 10

# Start from empty containers so re-running this cell does not accumulate
# coefficients or results from an earlier run.
reg_coef_list = []
learning_acc = {}

# Learning mode set
#   if 0: plain SGD
#   if 1: penalty term - L2 distance
#   if 2: penalty term - EWC
#   else: Error
learn_mode = 2
if learn_mode == 0:
    print("Learn Mode: Plain SGD")
    reg_coef_list.append(1)
elif learn_mode == 1:
    print("Learn Mode: Penalty term with L2 distance")
    reg_coef_list.extend(i * 0.001 for i in range(1, 11))
elif learn_mode == 2:
    print("Learn Mode: Penalty term with EWC")
    reg_coef_list.extend(i * 0.001 for i in range(1, 11))
else:
    print("Wrong value")

# Train a fresh network for every penalty coefficient.
for reg_coef in reg_coef_list:
    net = NeuralNet()
    if torch.cuda.is_available():
        print("Use GPU")
        net.cuda()

    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    print("Penalty term coefficient: ", reg_coef)
    learning_acc[reg_coef] = Continual_Learning(net, optimizer, num_task, reg_coef, learn_mode=learn_mode)

# Each value of learning_acc is a dict {task index: mean accuracy}, which
# cannot be passed to plt.plot directly. Summarise every run by the mean
# accuracy recorded after its last task and plot that against the coefficient.
x = []
y = []
for key, value in learning_acc.items():
    if not value:
        # Continual_Learning returns False for an invalid learn_mode.
        continue
    x.append(key)
    y.append(value[max(value)])

if x:
    plt.plot(x, y, marker='o')
    plt.xlabel('Penalty term coefficient')
    plt.ylabel('Average accuracy after last task (%)')
    plt.show()
else:
    # Nothing was trained, so there is nothing to plot.
    print('헛짓거리함')

Learn Mode: Penalty term with EWC
Use GPU
Penalty term coefficient:  0.001
Task	Epoch




[1	5] AVG. loss: 2.124
[1	10] AVG. loss: 1.111
[1	15] AVG. loss: 0.578
[1	20] AVG. loss: 0.445
[89.11]
[89.11]
Average accuracy after training task 1: 89 %




[2	5] AVG. loss: 0.804
[2	10] AVG. loss: 0.437
[2	15] AVG. loss: 0.375
[2	20] AVG. loss: 0.342
[84.37, 90.98]
[84.37 90.98]
Average accuracy after training task 2: 87 %


KeyboardInterrupt: ignored

In [0]:
# Re-run of the coefficient sweep. learning_rate, num_task and learning_acc
# come from the cell that defines them earlier in the notebook.
learn_mode = 2

# Reset the list first: it still holds the coefficients appended by the
# previous sweep, and appending again would train every coefficient twice.
reg_coef_list = []
if learn_mode == 0:
    print("Learn Mode: Plain SGD")
    reg_coef_list.append(1)
elif learn_mode == 1:
    print("Learn Mode: Penalty term with L2 distance")
    for i in range(1, 11):
        reg_coef_list.append(i * 0.001)
elif learn_mode == 2:
    print("Learn Mode: Penalty term with EWC")
    for i in range(1, 11):
        reg_coef_list.append(i * 0.001)
else:
    print("Wrong value")

# Train a fresh network for every penalty coefficient.
for reg_coef in reg_coef_list:
    net = NeuralNet()
    if torch.cuda.is_available():
        print("Use GPU")
        net.cuda()

    optimizer = optim.SGD(net.parameters(), lr=learning_rate)
    print("Penalty term coefficient: ", reg_coef)
    learning_acc[reg_coef] = Continual_Learning(net, optimizer, num_task, reg_coef, learn_mode=learn_mode)


In [0]:
# learning_acc is keyed by penalty coefficient. The key 1 only exists for the
# plain-SGD mode; the other modes use 0.001 .. 0.01, so fall back to the first
# coefficient that was actually trained instead of raising KeyError.
if not learning_acc:
    raise RuntimeError("learning_acc is empty; run the training cell first")
coef_to_plot = 1 if 1 in learning_acc else next(iter(learning_acc))

acc = learning_acc[coef_to_plot]
print(acc)

# acc maps task index -> mean accuracy over the tasks learned so far.
x = []
y = []
for key, value in acc.items():
    x.append(key)
    y.append(value)
plt.plot(x, y)
plt.xlabel('Task index')
plt.ylabel('Average accuracy (%)')
plt.show()