# Deep Generative Replay for Continual Learning on MNIST
### Assignment for the **Deep Generative View on Continual Learning** course.

Student Names: __Andrea Giuseppe Di Francesco__ and __Farooq Ahmad Wani__

*The following code is inspired by the notebooks presented during the course's lessons and by the openly available [notebook on GAN for MNIST](https://github.com/lyeoni/pytorch-mnist-GAN/blob/master/pytorch-mnist-GAN.ipynb)*.

In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision
import numpy as np
from torch.utils.data import Subset, TensorDataset, DataLoader
from random import shuffle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import random
import numpy as np
import os
from copy import deepcopy
import matplotlib.pyplot as plt
import psutil

# Fail fast when no CUDA device is attached. The message points at the
# Colab-style "Runtime -> Change runtime type" menu.
if not torch.cuda.is_available():
    raise SystemError("GPU device not found, select Runtime -> Change runtime type")

### Arguments

In [None]:

# Hyper-parameters and dataset settings shared by the cells below.
args = dict(
    lr=1e-2,                    # Adam learning rate of the solver (classifier)
    lr_generator=0.0002,        # Adam learning rate of the GAN generator
    lr_discriminator=0.00002,   # Adam learning rate of the GAN discriminator
    bs=128,                     # mini-batch size for training and validation
    epochs=20,                  # solver epochs per task
    num_tasks=5,
    epochs_gan=20,              # GAN epochs per task
    dataset="MNIST",
    num_classes=10,
    in_size=28,                 # image side length in pixels
    n_channels=1,
    hidden_size=50,             # width of the solver's hidden layers
    g_input_dim=100,            # latent dimension fed to the generator
    g_output_dim=784,           # equals in_size**2 * n_channels; not read by the code shown here
    ratio=0.7,                  # weight of the current-task loss vs. the replay loss
)


### Data

In [None]:

def get_dataset(dataroot, dataset):
    """Download (if needed) and return the train and test splits of ``dataset``.

    Images are converted to tensors and normalised with the mean/std listed
    below for that dataset.

    Args:
        dataroot: directory where torchvision stores the dataset files.
        dataset: name of a supported torchvision dataset, 'MNIST' or 'CIFAR10'.

    Returns:
        ``(train_dataset, val_dataset)``: the ``train=True`` and ``train=False`` splits.

    Raises:
        ValueError: if ``dataset`` is not one of the supported names.
    """
    # (mean, std) normalisation statistics per supported dataset.
    norm_stats = {
        'MNIST': (0.1307, 0.3081),
        'CIFAR10': ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
    }
    if dataset not in norm_stats:
        raise ValueError(f"Unsupported dataset {dataset!r}; expected one of {sorted(norm_stats)}")
    mean, std = norm_stats[dataset]

    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=mean, std=std)])
    dataset_cls = getattr(torchvision.datasets, dataset)

    train_dataset = dataset_cls(root=dataroot, train=True, download=True, transform=transform)
    val_dataset = dataset_cls(root=dataroot, train=False, download=True, transform=transform)
    return train_dataset, val_dataset


def split_dataset(dataset, tasks_split):
    """Partition ``dataset`` into one ``Subset`` per task, selected by class label.

    Args:
        dataset: dataset exposing a ``targets`` attribute with one label per sample.
        tasks_split: mapping task name -> sequence of class labels for that task.

    Returns:
        dict mapping each task name to a ``Subset`` of the samples whose label
        is among that task's classes (original sample order preserved).
    """
    # Convert the labels once rather than once per task.
    targets = np.asarray(dataset.targets)
    return {
        task: Subset(dataset, np.where(np.isin(targets, classes))[0])
        for task, classes in tasks_split.items()
    }

def set_random_seed(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic GPU kernels.

    Args:
        seed: integer seed applied to every random number generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for reproducible kernel selection.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Workspace setting that PyTorch's reproducibility notes require for deterministic cuBLAS.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"





### Metrics & plotting

In [None]:
def dict2array(acc):
    """Stack a ``{task_id: accuracies}`` dict into a 2-D NumPy array.

    Row ``int(task_id)`` receives that task's values. When the first entry is a
    list, its length gives the number of columns; otherwise the array is square
    (``len(acc)`` columns) and each scalar value is broadcast across its row.
    """
    task_keys = list(acc)
    first_value = acc[task_keys[0]]
    n_cols = len(first_value) if isinstance(first_value, list) else len(task_keys)
    stacked = np.zeros((len(task_keys), n_cols))
    for key in task_keys:
        stacked[int(key), :] = acc[key]
    return stacked


def plot_accuracy_matrix(array):
    """Show an accuracy matrix as an annotated heatmap.

    ``array[i, j]`` is the accuracy on task ``i`` after training on task ``j``
    (as produced by ``dict2array(agent.acc_end)``). The title reports the mean
    and standard deviation of the last column, i.e. the per-task accuracies at
    the end of the sequence.
    """
    n_rows, n_cols = array.shape
    # Rounded copy for display only; statistics use the unrounded values.
    shown = np.round(array, 2)
    fig, ax = plt.subplots()
    ax.imshow(shown, vmin=np.min(shown), vmax=np.max(shown))
    for i in range(n_rows):
        for j in range(n_cols):
            ax.text(j, i, shown[i, j], va='center', ha='center', c='w', fontsize=15)
    ax.set_yticks(np.arange(n_rows))
    ax.set_ylabel('Evaluated task')
    ax.set_xticks(np.arange(n_cols))
    ax.set_xlabel('Tasks finished')
    final_accs = array[:, -1]
    # The std is taken over the per-task values (the std of their mean would always be 0).
    ax.set_title(f"ACC: {np.mean(final_accs):.3f} -- std {np.std(final_accs):.3f}")
    plt.show()


def plot_acc_over_time(array):
    """Draw one accuracy curve per row of ``array``, labelled with the row index."""
    fig, ax = plt.subplots()
    for row_idx in range(len(array)):
        ax.plot(array[row_idx], label=row_idx)
    plt.legend()
    plt.show()


def compute_average_accuracy(array):
    """Return the mean of the last column of ``array``.

    With rows as tasks and columns as training stages, this is the average
    per-task accuracy once the whole sequence has been trained.
    """
    final_accs = array[:, -1]
    n_tasks = array.shape[0]
    return final_accs.sum(axis=0) / n_tasks


def compute_backward_transfer(array):
    """Return the backward transfer (BWT) of an accuracy matrix.

    With ``array[i, j]`` the accuracy on task ``i`` after training task ``j``,
    BWT is the mean of ``array[i, -1] - array[i, i]`` over all tasks except the
    last one, which has nothing trained after it.
    """
    n = len(array)
    end_minus_just_learned = array[:n - 1, -1] - np.diagonal(array)[:n - 1]
    return np.sum(end_minus_just_learned) / (n - 1)


def compute_forward_transfer(array, b):
    """Return the forward transfer (FWT) of an accuracy matrix.

    Averages ``array[i, i - 1] - b[i]`` for ``i >= 1``: the accuracy on task
    ``i`` right before it is trained, minus the baseline ``b[i]``. In this
    notebook the baseline is the untrained model's accuracy. The first task is
    skipped.
    """
    n = len(array)
    before_training = np.diagonal(array, offset=-1)
    gains = before_training - b[1:]
    return np.sum(gains) / (n - 1)

def print_memory_usage():
    """Print system RAM in use and, if CUDA is available, PyTorch GPU memory (in GB)."""
    bytes_per_gb = 1024 ** 3
    print(f"RAM Usage: {psutil.virtual_memory().used / bytes_per_gb:.2f} GB")

    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"GPU Memory Allocated: {allocated_gb:.2f} GB")
    print(f"GPU Memory Reserved: {reserved_gb:.2f} GB")

        
def get_n_params(model):
    """Return the total number of scalar entries across all of ``model``'s parameters."""
    return sum(param.numel() for param in model.parameters())


In [None]:



def plot_reconstructed_images(tensor, labels):
    """Show a row of grayscale images, each titled with its label.

    Args:
        tensor: image batch shaped ``(N, in_size**2)`` or ``(N, 1, in_size, in_size)``,
            where ``in_size`` is read from the global ``args``.
        labels: one label per image to display. 0-d tensors are shown as their
            plain value.

    Raises:
        ValueError: if ``tensor`` has neither accepted shape.
    """
    side = args['in_size']
    if tensor.dim() == 2 and tensor.shape[1] == side ** 2:
        images = tensor.reshape(-1, 1, side, side)
    elif tensor.dim() == 4 and tuple(tensor.shape[1:]) == (1, side, side):
        images = tensor
    else:
        raise ValueError("Unexpected image shape")
    images = images.detach().cpu().numpy()

    num_images = len(labels)
    # squeeze=False always yields a 2-D grid of axes, even for a single image.
    fig, axes = plt.subplots(1, num_images, figsize=(num_images * 2, 2), squeeze=False)
    for ax, image, label in zip(axes[0], images, labels):
        # A tensor would otherwise format as "tensor(3, device=...)".
        shown = label.item() if hasattr(label, 'item') else label
        ax.imshow(image[0], cmap='gray')
        ax.set_title(f"Label: {shown}")
        ax.axis('off')
    plt.show()


## Class-incremental model

In [None]:

class Agent:
    """Class-incremental learner trained with Deep Generative Replay (DGR).

    For every task a fresh ``Scholar`` (solver + generator) is trained. From the
    second task on, a frozen copy of the previous scholar (``self.old_scholar``)
    provides replay. Its generator produces inputs and its solver labels them,
    so both the new solver and the new generator keep seeing earlier classes.

    Args:
        args: configuration dict (see the ``args`` cell); must also contain
            'task_names'.
        train_datasets: mapping task name -> training dataset of that task.
        val_datasets: mapping task name -> validation dataset of that task.
    """

    def __init__(self, args, train_datasets, val_datasets):
        self.args = args
        self.dev = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.current_scholar = Scholar(self.args).to(self.dev)
        # One discriminator is created here and keeps training across all tasks.
        self.discriminator = Discriminator(img_shape=(args['in_size'], args['in_size'], args['n_channels'])).to(self.dev)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.adversarial_loss = torch.nn.BCELoss()
        self.reset_acc()
        self.train_datasets = train_datasets
        self.val_datasets = val_datasets
        # Weight of the current-task loss; (1 - r) weights the replay loss.
        self.r = self.args['ratio']
        # Frozen scholar of the previous task; set at the end of each task in train().
        self.old_scholar = None

        print(f"Total number of parameters MLP: {get_n_params(self.current_scholar.solver)}")
        tot_params = get_n_params(self.current_scholar.solver) + get_n_params(self.current_scholar.generator) + get_n_params(self.discriminator)
        print(f"Total number of parameters: {tot_params}")

    def reset_acc(self):
        """Clear the accuracy logs.

        ``self.acc[task]`` gets one entry per ``validate()`` call.
        ``self.acc_end[task]`` only gets entries from calls with ``end_of_epoch=True``.
        """
        self.acc = {key: [] for key in self.args['task_names']}
        self.acc_end = {key: [] for key in self.args['task_names']}

    def train_solver(self, X, y):
        """Forward a real batch through the current solver.

        Returns:
            ``(loss, output)``: the cross-entropy loss against ``y`` and the logits.
            No backward pass or optimizer step happens here.
        """
        output = self.current_scholar.solver(X)
        loss = self.criterion(output, y)
        return loss, output

    def use_generator(self, num_samples, y, gen_images=False):
        """Return the replay loss of the current solver on generated samples.

        Draws ``num_samples`` inputs from the frozen previous generator and labels
        them with the argmax of the frozen previous solver. It then returns the
        cross-entropy of the current solver's predictions against those labels.

        Args:
            num_samples: number of replayed samples to generate.
            y: labels of the current real batch. Not needed for replay; kept so
                existing call sites keep working.
            gen_images: if True, plot one generated sample per predicted class.
        """
        z = torch.randn(num_samples, self.args['g_input_dim'], device=self.dev)
        # The old scholar is frozen, so its forward passes need no autograd graph.
        with torch.no_grad():
            gen_input = self.old_scholar.generator(z)
            y_ = torch.argmax(self.old_scholar.solver(gen_input), dim=-1)

        output = self.current_scholar.solver(gen_input)  # train on replayed samples

        if gen_images:
            self._plot_replayed(gen_input, y_)

        return self.criterion(output, y_)

    def _plot_replayed(self, gen_input, labels):
        """Plot the first generated sample of each distinct label in ``labels``."""
        unique_labels = torch.unique(labels)
        first_idx = [int(torch.nonzero(labels == lab)[0, 0]) for lab in unique_labels]
        images = gen_input[first_idx]
        # The generator emits (N, in_size, in_size, C); the plot helper expects (N, C, H, W).
        if images.shape[1] == self.args['in_size']:
            images = images.permute(0, 3, 1, 2)
        plot_reconstructed_images(images, [int(lab) for lab in unique_labels])

    def train_generator(self, task, task_loader):
        """Adversarially train the current generator for ``args['epochs_gan']`` epochs.

        The "real" data is each batch from ``task_loader``. When a previous scholar
        exists (``task > 0``), every real batch is extended with an equally sized
        batch from the frozen previous generator. This way the new generator also
        models earlier tasks (generative replay).

        Args:
            task: index of the current task (0 for the first task).
            task_loader: DataLoader yielding ``(images, labels)`` of the current task.
        """
        generator = self.current_scholar.generator
        optimizer_G = torch.optim.Adam(generator.parameters(), lr=self.args['lr_generator'])
        optimizer_D = torch.optim.Adam(self.discriminator.parameters(), lr=self.args['lr_discriminator'])
        replay = task > 0 and self.old_scholar is not None

        for epoch in range(self.args['epochs_gan']):
            d_loss = g_loss = None
            for imgs, _ in task_loader:
                real_imgs = imgs.to(self.dev)
                if replay:
                    with torch.no_grad():
                        z_old = torch.randn(real_imgs.size(0), self.args['g_input_dim'], device=self.dev)
                        # The discriminator and solver flatten their inputs, so only the
                        # flattened pixel order matters; a reshape matches the real layout.
                        old_imgs = self.old_scholar.generator(z_old).reshape(real_imgs.shape)
                    real_imgs = torch.cat([real_imgs, old_imgs], dim=0)

                batch_size = real_imgs.size(0)
                valid = torch.ones(batch_size, 1, device=self.dev)
                fake = torch.zeros(batch_size, 1, device=self.dev)

                # Generator step: push the discriminator to score generated images as real.
                optimizer_G.zero_grad()
                z = torch.randn(batch_size, self.args['g_input_dim'], device=self.dev)
                gen_imgs = generator(z)
                g_loss = self.adversarial_loss(self.discriminator(gen_imgs), valid)
                g_loss.backward()
                optimizer_G.step()

                # Discriminator step: separate real (and replayed) images from generated ones.
                optimizer_D.zero_grad()
                real_loss = self.adversarial_loss(self.discriminator(real_imgs), valid)
                fake_loss = self.adversarial_loss(self.discriminator(gen_imgs.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2
                d_loss.backward()
                optimizer_D.step()

            if d_loss is not None:
                print(f"Task {task} | Epoch {epoch} | D loss: {d_loss.item()} | G loss: {g_loss.item()}")

    def train(self):
        """Train on every task in ``self.train_datasets`` in order.

        Per task, this method:
        - re-seeds the RNGs and builds a fresh scholar;
        - trains its generator (with replay);
        - trains its solver for ``args['epochs']`` epochs, taking one optimizer
          step per mini-batch on ``r * current_loss + (1 - r) * replay_loss``;
        - validates on all tasks and freezes a copy as ``self.old_scholar``.
        """
        for i, (task, data) in enumerate(self.train_datasets.items()):
            print(task)
            set_random_seed(42)
            print(f"Solving task n°{i+1}")
            self.current_scholar = Scholar(self.args).to(self.dev)

            optimizer = torch.optim.Adam(self.current_scholar.solver.parameters(), lr=self.args['lr'])
            loader = torch.utils.data.DataLoader(data, batch_size=self.args['bs'], shuffle=True)

            self.train_generator(i, loader)

            replay = i > 0 and self.old_scholar is not None
            gen_images = replay  # plot replayed samples once per task
            for epoch in range(self.args['epochs']):
                correct, total, n_batches = 0, 0, 0
                loss_current_sum, loss_past_sum = 0.0, 0.0
                for X, y in loader:
                    X, y = X.to(self.dev), y.to(self.dev)
                    loss_current, output = self.train_solver(X, y)
                    loss = loss_current
                    if replay:
                        # Replay one generated sample per real sample in the batch.
                        loss_past = self.use_generator(X.shape[0], y, gen_images=gen_images)
                        gen_images = False
                        loss = self.r * loss_current + (1 - self.r) * loss_past
                        loss_past_sum += loss_past.item()

                    # One optimizer step per mini-batch; the graph is freed after each step.
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    loss_current_sum += loss_current.item()
                    correct += (output.argmax(dim=1) == y).sum().item()
                    total += len(X)
                    n_batches += 1

                n_batches = max(n_batches, 1)
                acc_train = correct / max(total, 1)
                print(f"Epoch {epoch}: Loss current data {loss_current_sum/n_batches:.3f} Loss past data {loss_past_sum/n_batches:.3f} Acc: {acc_train:.3f}")

            self.validate(end_of_epoch=True)

            # Freeze a copy of this task's scholar to drive replay on the next task.
            self.old_scholar = deepcopy(self.current_scholar)
            del self.current_scholar
            for param in self.old_scholar.parameters():
                param.requires_grad = False
            self.old_scholar.eval()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    @torch.no_grad()
    def validate(self, end_of_epoch=False):
        """Measure the current solver's accuracy on every validation task.

        Appends each task's accuracy to ``self.acc`` and, when ``end_of_epoch``
        is True, also to ``self.acc_end``. Afterwards the solver and generator
        are put back in train mode.
        """
        solver = self.current_scholar.solver
        solver.eval()
        for task, data in self.val_datasets.items():
            # Sample order does not affect accuracy, so there is no need to shuffle.
            loader = torch.utils.data.DataLoader(data, batch_size=self.args['bs'], shuffle=False)
            correct, total = 0, 0
            for X, y in loader:
                X, y = X.to(self.dev), y.to(self.dev)
                correct += (solver(X).argmax(dim=1) == y).sum().item()
                total += len(X)
            accuracy = correct / total if total else 0.0
            self.acc[task].append(accuracy)
            if end_of_epoch:
                self.acc_end[task].append(accuracy)
        solver.train()
        self.current_scholar.generator.train()

    
class Solver(torch.nn.Module):
    """Four-layer ReLU MLP classifier over flattened images.

    Reads 'in_size', 'n_channels', 'hidden_size' and 'num_classes' from ``args``.
    """

    def __init__(self, args):
        super().__init__()
        in_features = args['in_size'] ** 2 * args['n_channels']
        width = args['hidden_size']
        self.fc1 = torch.nn.Linear(in_features, width)
        self.fc2 = torch.nn.Linear(width, width)
        self.fc3 = torch.nn.Linear(width, width)
        self.fc4 = torch.nn.Linear(width, args['num_classes'])

    def forward(self, input):
        """Flatten each sample and return unnormalised class logits."""
        hidden = input.flatten(start_dim=1)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = torch.nn.functional.relu(layer(hidden))
        return self.fc4(hidden)
    

    
class Generator(torch.nn.Module):
    """MLP generator mapping latent vectors to tensors of shape ``img_shape``.

    The final ``Tanh`` keeps every output value in [-1, 1].
    """

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()
        self.img_shape = img_shape
        n_outputs = int(np.prod(img_shape))

        layers = [torch.nn.Linear(latent_dim, 128), torch.nn.ReLU()]
        for fan_in, fan_out in ((128, 256), (256, 512)):
            # NOTE(review): BatchNorm1d's second positional argument is ``eps``, so
            # these layers use eps=0.8 (not momentum=0.8) — kept as-is; confirm intent.
            layers += [
                torch.nn.Linear(fan_in, fan_out),
                torch.nn.BatchNorm1d(fan_out, eps=0.8),
                torch.nn.ReLU(),
            ]
        layers += [torch.nn.Linear(512, n_outputs), torch.nn.Tanh()]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, z):
        """Map a ``(batch, latent_dim)`` batch to ``(batch, *img_shape)``."""
        return self.model(z).view(z.size(0), *self.img_shape)

class Discriminator(torch.nn.Module):
    """MLP that maps each (flattened) image to a score in [0, 1] via a final Sigmoid."""

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        widths = [int(np.prod(img_shape)), 512, 256]
        layers = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            layers += [torch.nn.Linear(d_in, d_out), torch.nn.ReLU()]
        layers += [torch.nn.Linear(widths[-1], 1), torch.nn.Sigmoid()]
        self.model = torch.nn.Sequential(*layers)

    def forward(self, img):
        """Flatten each sample of ``img`` and return scores of shape ``(batch, 1)``."""
        return self.model(img.view(img.size(0), -1))

    
class Scholar(nn.Module):
    """A solver (classifier) paired with a generator, the unit of generative replay.

    Args:
        args: configuration dict; ``Solver`` reads it directly, and the generator
            uses 'g_input_dim', 'in_size' and 'n_channels'.
    """

    def __init__(self, args):
        super(Scholar, self).__init__()
        self.solver = Solver(args)
        self.generator = Generator(latent_dim=args['g_input_dim'], img_shape=(args['in_size'], args['in_size'], args['n_channels']))

    def forward(self, x, mode='solver'):
        """Run one of the two sub-networks.

        Args:
            x: solver input when ``mode == 'solver'``; a latent batch when
                ``mode == 'generator'``.
            mode: which sub-network to run, 'solver' or 'generator'.

        Raises:
            ValueError: for any other ``mode``.
        """
        if mode == 'solver':
            return self.solver(x)
        if mode == 'generator':
            return self.generator(x)
        raise ValueError(f"Unknown mode {mode!r}; expected 'solver' or 'generator'")




In [None]:
import subprocess

def get_gpu_usage():
    """Print utilisation and memory figures for each GPU reported by ``nvidia-smi``.

    Returns:
        list of dicts with keys 'gpu_utilization' (%), 'total_memory',
        'used_memory' and 'free_memory' (MB), one per parsable GPU line. The
        list is empty if ``nvidia-smi`` cannot be run.
    """
    query = [
        'nvidia-smi',
        '--query-gpu=utilization.gpu,memory.total,memory.used,memory.free',
        '--format=csv,nounits,noheader',
    ]
    keys = ('gpu_utilization', 'total_memory', 'used_memory', 'free_memory')

    try:
        raw = subprocess.check_output(query)
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError covers a missing nvidia-smi binary (no NVIDIA driver installed).
        print(f"Failed to run nvidia-smi: {e}")
        return []

    gpu_usages = []
    for line in raw.decode('utf-8').strip().splitlines():
        fields = [field.strip() for field in line.split(',')]
        if len(fields) != len(keys):
            print(f"Error: unexpected nvidia-smi line {line!r}")
            continue
        try:
            values = [int(field) for field in fields]
        except ValueError:
            # Some GPUs report "[N/A]" for a field; skip that row instead of aborting.
            print(f"Error: could not parse nvidia-smi line {line!r}")
            continue
        gpu_usages.append(dict(zip(keys, values)))

    for i, usage in enumerate(gpu_usages):
        print(f"GPU {i}:")
        print(f"  Utilization: {usage['gpu_utilization']}%")
        print(f"  Total Memory: {usage['total_memory']} MB")
        print(f"  Used Memory: {usage['used_memory']} MB")
        print(f"  Free Memory: {usage['free_memory']} MB")

    return gpu_usages



In [None]:
# Shuffle the class labels reproducibly, then deal them out as disjoint,
# equally sized groups, one group per task. Any remainder of
# num_classes // num_tasks is left unused.
classes = list(range(args['num_classes']))
set_random_seed(42)
shuffle(classes)
classes_per_task = args['num_classes'] // args['num_tasks']
class_split = {
    str(t): classes[t * classes_per_task:(t + 1) * classes_per_task]
    for t in range(args['num_tasks'])
}
args['task_names'] = list(class_split.keys())

In [None]:
train, test = get_dataset(dataroot='../data/', dataset=args['dataset'])
train_tasks, val_tasks = split_dataset(train, class_split), split_dataset(test, class_split)
agent = Agent(args, train_tasks, val_tasks)

# Accuracy of the untrained model on each task: the baseline passed to FWT below.
agent.validate()
random_model_acc = [task_accs[0] for task_accs in agent.acc.values()]
agent.reset_acc()
agent.train()

acc_at_end_arr = dict2array(agent.acc_end)
acc_arr = dict2array(agent.acc)
plot_accuracy_matrix(acc_at_end_arr)
plot_acc_over_time(acc_arr)

avg_acc = compute_average_accuracy(acc_at_end_arr)
bwt = compute_backward_transfer(acc_at_end_arr)
fwt = compute_forward_transfer(acc_at_end_arr, random_model_acc)
print(f"The average accuracy at the end of sequence is: {avg_acc:.3f}")
print(f"BWT:'{bwt:.3f}'")
print(f"FWT:'{fwt:.3f}'")
get_gpu_usage()