In [None]:
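# Implementation of "Continual Learning with Deep Generative Replay" (Shin et al., 2017,
# https://arxiv.org/pdf/1705.08690.pdf) on permuted MNIST, on the Avalanche framework.
# Background: https://aahaanmaini.medium.com/mimicking-human-continual-learning-in-a-neural-network-c15e1ae11d70
# Topic: continual learning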

## Create the GAN (fully connected generator and discriminator)

In [None]:
%load_ext autoreload
%autoreload 2
from helper_func import *

In [None]:
#https://github.com/znxlwm/pytorch-MNIST-CelebA-GAN-DCGAN/blob/master/pytorch_MNIST_GAN.py
import matplotlib.pyplot as plt
from torch import nn
#import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable

class generator(nn.Module):
    """Fully connected GAN generator G(z).

    A latent vector of length ``input_size`` passes through three hidden
    layers (256 -> 512 -> 1024 units, LeakyReLU with slope 0.2) and a final
    linear layer with ``n_class`` outputs, squashed into [-1, 1] by tanh.
    """

    def __init__(self, input_size=32, n_class = 10):
        super().__init__()
        hidden_widths = (256, 512, 1024)
        # Layers are created in the same order (fc1..fc4) so parameter names,
        # parameter order and seeded initialisation stay unchanged.
        self.fc1 = nn.Linear(input_size, hidden_widths[0])
        self.fc2 = nn.Linear(hidden_widths[0], hidden_widths[1])
        self.fc3 = nn.Linear(hidden_widths[1], hidden_widths[2])
        self.fc4 = nn.Linear(hidden_widths[2], n_class)

    def forward(self, input):
        """Map latent codes ``input`` to outputs in [-1, 1]."""
        hidden = input
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), 0.2)
        return torch.tanh(self.fc4(hidden))

class discriminator(nn.Module):
    """Fully connected GAN discriminator D(x).

    An input of length ``input_size`` passes through three hidden layers
    (1024 -> 512 -> 256 units, LeakyReLU with slope 0.2, each followed by
    dropout) and a final linear layer with ``n_class`` outputs, squashed
    into [0, 1] by a sigmoid.

    Args:
        input_size: number of input features.
        n_class: number of sigmoid outputs.
        dropout_p: dropout probability after each hidden layer (default 0.3).
            Dropout is applied only in training mode.
    """

    def __init__(self, input_size=32, n_class=10, dropout_p=0.3):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 1024)
        self.fc2 = nn.Linear(self.fc1.out_features, 512)
        self.fc3 = nn.Linear(self.fc2.out_features, 256)
        self.fc4 = nn.Linear(self.fc3.out_features, n_class)
        self.dropout_p = dropout_p

    def forward(self, input):
        """Return per-output probabilities in [0, 1] for ``input``."""
        x = input
        for layer in (self.fc1, self.fc2, self.fc3):
            x = F.leaky_relu(layer(x), 0.2)
            # F.dropout defaults to training=True; tie it to the module mode so
            # .eval() actually disables dropout.
            x = F.dropout(x, self.dropout_p, training=self.training)
        return torch.sigmoid(self.fc4(x))


## Neural Network Architecture

In [None]:
from torch.optim import SGD, Adam
from torch.nn import CrossEntropyLoss

class Net(nn.Module):
    """MNIST classifier: an MLP with two hidden layers of 400 ReLU units each.

    Inputs are flattened to 28*28 = 784 features. The output is 10 raw logits
    (no softmax), which is what ``CrossEntropyLoss`` expects.
    """

    def __init__(self):
        super().__init__()
        n_pixels, n_hidden, n_classes = 28 * 28, 400, 10
        # Creation order fc1 -> fc2 -> fc3 keeps parameter names/order and seeded init.
        self.fc1 = nn.Linear(n_pixels, n_hidden)
        self.fc2 = nn.Linear(n_hidden, n_hidden)
        self.fc3 = nn.Linear(n_hidden, n_classes)

    def forward(self, x):
        """Flatten ``x`` to (-1, 784) and return class logits."""
        features = x.view(-1, self.fc1.in_features)
        for hidden_layer in (self.fc1, self.fc2):
            features = F.relu(hidden_layer(features))
        return self.fc3(features)


## Naive baseline (sequential fine-tuning, no replay)

In [None]:
# Naive baseline: one classifier trained sequentially on each experience, with no
# mechanism against forgetting.
# get_datasets, transform and test_model are not defined in this notebook; presumably
# they come from the `from helper_func import *` cell.
n_experiences = 5     # number of permuted-MNIST experiences (tasks)
n_epochs = 3          # training epochs per experience
torch.manual_seed(0)  # reproducible weight init and shuffling

model_naive = Net()
optimizer_naive = SGD(model_naive.parameters(), lr=0.001, momentum=0.9)
criterion_naive = CrossEntropyLoss()

naive_accuracies = []        # accuracy on the current experience's test set (plotted below)
naive_seen_accuracies = []   # mean accuracy over every experience seen so far (shows forgetting)
naive_seen_test_loaders = []
for experience in range(n_experiences):
    train_dataset, test_dataset = get_datasets(experience, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
    # Order does not affect an accuracy over the whole test set (assuming test_model
    # iterates the full loader), so skip shuffling.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
    naive_seen_test_loaders.append(test_loader)

    model_naive.train()
    for _ in range(n_epochs):
        # Plain tensors are used: torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.
        for data, target in train_loader:
            optimizer_naive.zero_grad()
            loss = criterion_naive(model_naive(data), target)
            loss.backward()
            optimizer_naive.step()

    model_naive.eval()
    accuracy = test_model(model_naive, test_loader)
    print('Experience: {} Accuracy: {:.2f}'.format(experience, accuracy))
    naive_accuracies.append(accuracy)

    # Re-test on earlier experiences to measure forgetting; `accuracy` covers the current one.
    seen = [test_model(model_naive, loader) for loader in naive_seen_test_loaders[:-1]] + [accuracy]
    naive_seen_accuracies.append(sum(seen) / len(seen))
    print('Experience: {} Mean accuracy on experiences 0..{}: {:.2f}'.format(
        experience, experience, naive_seen_accuracies[-1]))

print(naive_accuracies)
print(naive_seen_accuracies)



## Scholar: generative replay with a GAN

In [None]:
%reload_ext autoreload
%autoreload 2
from helper_func import *
from torch.utils.data import ConcatDataset

def get_new_generator_and_discriminator(lr=0.0002, latent_dim=100, image_size=28 * 28):
    """Build a freshly initialised GAN together with its optimisers and loss.

    Args:
        lr: Adam learning rate used for both networks (default 0.0002).
        latent_dim: length of the noise vector fed to the generator (default 100).
            NOTE(review): train_GAN / generate_data / generate_images may sample
            noise of a fixed size; keep this in sync with helper_func if changed.
        image_size: number of pixels per flattened image. It is the generator's
            output size and the discriminator's input size (default 28*28, MNIST).

    Returns:
        (G, D, G_optimizer, D_optimizer, BCE_loss): the generator, the
        discriminator (a single sigmoid output), an Adam optimiser for each,
        and an ``nn.BCELoss``.
    """
    G = generator(input_size=latent_dim, n_class=image_size)
    D = discriminator(input_size=image_size, n_class=1)

    G_optimizer = optim.Adam(G.parameters(), lr=lr)
    D_optimizer = optim.Adam(D.parameters(), lr=lr)

    BCE_loss = nn.BCELoss()
    return G, D, G_optimizer, D_optimizer, BCE_loss

# Training parameters for the generative-replay ("scholar") run.
batch_size = 128
n_experiences = 5   # number of permuted-MNIST experiences (tasks)
gan_epochs = 200    # GAN epochs per experience (passed to train_GAN as train_epoch)
solver_epochs = 3   # classifier epochs per experience
# Number of replayed (generated) samples mixed into every experience after the first:
# 2/10 of 60000. Integer division keeps this an int. The expression 60000*2/10 is the
# float 12000.0, which breaks count-style uses such as range() or torch.randn().
# NOTE(review): an earlier comment said "half of the images are generated half are
# new". 12000 replayed samples give a 50/50 mix only if the new experience's train
# set is also ~12000 images; confirm the intended replay ratio.
n_replay = 60000 * 2 // 10
torch.manual_seed(0)  # reproducible weight init, GAN sampling and shuffling

cl_accuracies = []        # accuracy on the current experience's test set (plotted below)
cl_seen_accuracies = []   # mean accuracy over every experience seen so far (shows forgetting)
cl_seen_test_loaders = []
for experience in range(n_experiences):
    # get_datasets / transform presumably come from the helper_func star import.
    train_dataset, test_dataset = get_datasets(experience, transform=transform)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Order does not affect an accuracy over the whole test set, so skip shuffling.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    cl_seen_test_loaders.append(test_loader)

    if experience != 0:
        # Replay: build extra samples from the previous experience's generator G and
        # solver model_cl. Presumably G draws images and model_cl labels them (see helper_func).
        new_dataset = generate_data(G, model_cl, n_images=n_replay, batch_size=batch_size)
        # NOTE(review): assumes both .data tensors share dtype and preprocessing. A raw
        # torchvision MNIST .data is uint8 and untransformed; confirm in helper_func.
        X = torch.cat((train_dataset.data, new_dataset.data), 0)
        y = torch.cat((train_dataset.targets, new_dataset.targets), 0)
        concat_dataset = DS_from_tensors(X, y)
        train_loader = torch.utils.data.DataLoader(concat_dataset, batch_size=batch_size, shuffle=True)

    # Train a fresh GAN on this experience's (current + replayed) data; it is the
    # generator used for replay in the next experience.
    print('training GAN on experience {}'.format(experience))
    G, D, G_optimizer, D_optimizer, BCE_loss = get_new_generator_and_discriminator()
    G, D, D_losses, G_losses = train_GAN(G, D, train_loader, G_optimizer, D_optimizer, BCE_loss,
                                         train_epoch=gan_epochs)
    images = generate_images(G, n_images=25).numpy()
    # NOTE(review): assumes generate_images returns a (C, H, W) image grid; confirm.
    plt.imshow(images.transpose((1, 2, 0)))
    plt.title('GAN samples after experience {}'.format(experience))
    plt.axis('off')
    plt.show()

    # Train a fresh solver on the same data.
    model_cl = Net()
    cl_optimizer = SGD(model_cl.parameters(), lr=0.001, momentum=0.9)
    cl_criterion = CrossEntropyLoss()
    model_cl.train()
    for epoch in range(solver_epochs):
        # Plain tensors are used: torch.autograd.Variable is a no-op wrapper since PyTorch 0.4.
        for batch_idx, (data, target) in enumerate(train_loader):
            cl_optimizer.zero_grad()
            output = model_cl(data)
            loss = cl_criterion(output, target)
            loss.backward()
            cl_optimizer.step()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

    model_cl.eval()
    accuracy = test_model(model_cl, test_loader)
    print('Experience: {} Accuracy: {:.2f}'.format(experience, accuracy))
    cl_accuracies.append(accuracy)

    # Re-test on earlier experiences to measure forgetting; `accuracy` covers the current one.
    seen = [test_model(model_cl, loader) for loader in cl_seen_test_loaders[:-1]] + [accuracy]
    cl_seen_accuracies.append(sum(seen) / len(seen))
    print('Experience: {} Mean accuracy on experiences 0..{}: {:.2f}'.format(
        experience, experience, cl_seen_accuracies[-1]))

print('cl accuracies: {}'.format(cl_accuracies))
print('cl accuracies on all seen experiences: {}'.format(cl_seen_accuracies))

## Plots

In [None]:

# Accuracy on each experience's own test set, measured right after training on it.
# NOTE(review): ylim (0, 1.1) assumes test_model returns a fraction in [0, 1]; confirm.
fig, ax = plt.subplots()
ax.plot(naive_accuracies, marker='o', label='naive')
ax.plot(cl_accuracies, marker='o', label='CL')
ax.set_xlabel('Task')
ax.set_ylabel('Accuracy')
ax.set_title('Current-task test accuracy: naive vs. generative replay')
ax.set_xticks(range(max(len(naive_accuracies), len(cl_accuracies))))
ax.set_ylim(0, 1.1)
ax.legend()  # line labels are only displayed once a legend is drawn
plt.show()