In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
from copy import deepcopy
from torch.nn import functional as Func
from tqdm import tqdm
from torch.nn.utils import parameters_to_vector, vector_to_parameters
import time

from nngeometry.object import PMatDiag, PMatBlockDiag, PMatKFAC, PMatEKFAC, PMatDense, PMatQuasiDiag, PVector
from nngeometry.metrics import FIM


In [2]:
def seed_everything(seed: int):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode.

    Args:
        seed: Integer seed shared by all generators.
    """
    import random, os
    import numpy as np
    import torch

    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker processes).
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed every visible GPU, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick different kernels from run to
    # run, which defeats the deterministic flag above, so it must stay off.
    torch.backends.cudnn.benchmark = False


seed_everything(42)

# 1. Continual learning

In [3]:
import random


class DataBuffer:
    """Fixed-capacity replay memory filled with reservoir sampling.

    Attributes:
        buffer: List of stored ``(x, y)`` pairs, at most ``max_size`` long.
        max_size: Maximum number of pairs kept in memory.
        num_seen_examples: Number of pairs passed to ``update_buffer`` so far.
    """

    def __init__(self, max_size):
        self.buffer = []
        self.max_size = max_size
        self.num_seen_examples = 0

    def update_buffer(self, new_data):
        """Offer new examples to the buffer using reservoir sampling.

        Args:
            new_data: Iterable of ``(x, y)`` pairs.
        """
        for x, y in new_data:
            self.num_seen_examples += 1
            if len(self.buffer) < self.max_size:
                # Still room: keep everything.
                self.buffer.append((x, y))
                continue
            # Buffer full: draw a slot among all examples seen so far and only
            # overwrite when the draw lands inside the buffer, so the new pair
            # is kept with probability max_size / num_seen_examples.
            slot = random.randint(0, self.num_seen_examples - 1)
            if slot < self.max_size:
                self.buffer[slot] = (x, y)

    def sample(self, n):
        """Draw up to ``n`` stored pairs without replacement.

        Args:
            n: Number of pairs requested. When the buffer holds fewer than
                ``n`` pairs, all of them are returned in storage order.

        Returns:
            ``(X, Y)`` tensors obtained with ``torch.stack``. Two empty tensors
            are returned when nothing is selected (empty buffer or ``n == 0``).
        """
        if len(self.buffer) < n:
            chosen = self.buffer
        else:
            chosen = random.sample(self.buffer, n)

        # zip(*[]) yields nothing to unpack, so the empty selection (empty
        # buffer or n == 0) has to be handled before unzipping.
        if not chosen:
            return torch.Tensor(), torch.Tensor()

        x_all, y_all = zip(*chosen)
        return torch.stack(x_all), torch.stack(y_all)





We implement a function that trains or evaluates a model on one task for one epoch. It takes the task set, the model, the optimizer, a flag enabling replay, the dictionary of EWC objects (one per past task) used to regularize the loss, the importance of that regularization, the Fisher matrix representation to use (or `None`), a boolean flag telling whether the task is used for training or only for evaluation, and the replay buffer.

In [4]:
def process_task(task_set, model, optimizer, Replay, ewc_list, importance, FIM_Representation, train, buffer,
                 batch_size=256):
    """Run one epoch of training or evaluation on a single task.

    Three training regimes are selected by ``Replay`` and ``FIM_Representation``:

    * ``FIM_Representation is None``: plain SGD step (with replay samples
      appended to every batch when ``Replay`` is true).
    * ``FIM_Representation`` set, ``Replay`` false: EWC, i.e. the penalty of
      every object in ``ewc_list`` is added to the loss.
    * ``FIM_Representation`` set, ``Replay`` true: the gradient of each step is
      replaced by ``F.solve(grad, regul=1e-2)``, where ``F`` is the Fisher
      matrix of the current batch concatenated with the replay samples.

    Args:
        task_set: Dataset yielding ``(x, y, t)`` triples.
        model: Network to train or evaluate; moved to the selected device.
        optimizer: Optimizer stepping ``model``'s parameters (used only when
            ``train`` is true).
        Replay: Whether a replay buffer is used during training.
        ewc_list: Dict of objects exposing ``penalty(model)``.
        importance: Weight of the EWC penalty.
        FIM_Representation: NNGeometry matrix representation class, or ``None``.
        train: ``True`` to update the model, ``False`` to only evaluate.
        buffer: Object exposing ``sample(n)`` and ``update_buffer(pairs)``;
            only touched when ``Replay and train``.
        batch_size: Mini-batch size of the task loader and of the inner loader.

    Returns:
        ``(loss_mean, accuracy)``: the per-sample mean cross-entropy (without
        the EWC penalty) and the accuracy in percent, both computed over every
        sample that went through the model (replay samples included).
    """
    task_loader = DataLoader(task_set, batch_size=batch_size, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    if train:
        model.train()
    else:
        model.eval()

    use_fisher = FIM_Representation is not None
    use_replay = bool(Replay) and train

    epoch_loss = 0.0
    correct = 0
    n_seen = 0

    # No autograd graph is needed for evaluation.
    with torch.set_grad_enabled(train):
        for x_current, y_current, _task_ids in task_loader:
            x_current = x_current.to(device)
            y_current = y_current.to(device).long()
            x_all, y_all = x_current, y_current

            if use_replay:
                # Append as many replay samples as there are fresh samples.
                x_buffer, y_buffer = buffer.sample(x_current.size(0))
                x_buffer, y_buffer = x_buffer.to(device), y_buffer.to(device)
                x_all = torch.cat((x_current, x_buffer), 0)
                y_all = torch.cat((y_current, y_buffer), 0).long()

            temp_loader = DataLoader(torch.utils.data.TensorDataset(x_all, y_all),
                                     batch_size=batch_size, shuffle=True)

            # Fisher matrix of the combined (current + replay) batch; it is
            # computed once, before the parameter updates of this batch.
            fisher = None
            if use_fisher and use_replay:
                fisher = FIM(model=model,
                             loader=temp_loader,
                             representation=FIM_Representation,
                             n_output=30,
                             variant='classif_logits',
                             device=device)

            for x_batch, y_batch in temp_loader:
                predictions = model(x_batch)
                # Loss is recorded before any regularization is added.
                loss = Func.cross_entropy(predictions, y_batch)

                # cross_entropy returns a batch mean: weight it by the batch
                # size so the epoch value is a true per-sample mean.
                n_batch = y_batch.size(0)
                epoch_loss += loss.item() * n_batch
                correct += (predictions.max(dim=1)[1] == y_batch).sum().item()
                n_seen += n_batch

                if not train:
                    continue

                optimizer.zero_grad()

                if use_fisher and not Replay:
                    for ewc in ewc_list.values():
                        loss = loss + importance * ewc.penalty(model)

                loss.backward()

                if fisher is not None:
                    # Precondition the gradient with the regularized inverse Fisher.
                    original_grad_vec = PVector.from_model_grad(model)
                    regularized_grad = fisher.solve(original_grad_vec, regul=1e-2)
                    regularized_grad.to_model_grad(model)

                optimizer.step()

            if use_replay:
                # Offer each fresh sample to the buffer exactly once per task
                # batch, independently of how many inner steps were taken.
                buffer.update_buffer(list(zip(x_current.cpu().detach(), y_current.cpu().detach())))

    n_seen = max(n_seen, 1)
    loss_mean = epoch_loss / n_seen
    accuracy = 100 * correct / n_seen

    return loss_mean, accuracy

We create the neural network that will be trained on the sequence of tasks. 28*28  images as input and 30 outputs.

In [5]:
class Base_Net(nn.Module):
    """Small CNN: two 5x5 convolutions followed by two linear layers.

    The input is reshaped to ``(-1, 1, 28, 28)``; the output has
    ``num_classes`` logits per image.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.num_classes = num_classes

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=5, kernel_size=5)
        self.conv2 = nn.Conv2d(in_channels=5, out_channels=5, kernel_size=5)
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=80, out_features=20)
        # The classification head has no bias term.
        self.fc2 = nn.Linear(in_features=20, out_features=self.num_classes, bias=False)

    def forward(self, x):
        """Return the logits for a batch of 28x28 single-channel images."""
        features = x.view(-1, 1, 28, 28)

        # conv -> 2x2 max-pool -> ReLU, applied once per convolution.
        for conv in (self.conv1, self.conv2):
            features = self.relu(self.maxpool2(conv(features)))

        # 5 channels of 4x4 feature maps are flattened to 80 values.
        flat = features.view(-1, 80)
        hidden = self.relu(self.fc1(flat))
        return self.fc2(hidden)

We import from the continuum library the datasets and the scenario needed for the experiments.

In [6]:
# we import MNISTFellowship=[MNIST,FashionMNIST,KMNIST]
from continuum.datasets import MNISTFellowship
# ClassIncremental => new tasks = new classes
from continuum import ClassIncremental

  from .autonotebook import tqdm as notebook_tqdm


We create a function that enumerates the tasks and manages the continual training.

In [7]:
def continual_process(epochs, importance=1000., FIM_Representation=None, Replay=False, num_tasks=3, verbose=False, buffer_size=1000,
                      data_path='/data/e.urettini/DATA/MNISTFellowship'):
    """Train a fresh ``Base_Net`` on the MNISTFellowship class-incremental sequence.

    After every training epoch the model is evaluated on every test task, so
    each returned list has one entry per (task, epoch) pair.

    Args:
        epochs: Number of epochs spent on each task.
        importance: Weight of the EWC penalty (used when ``FIM_Representation``
            is set and ``Replay`` is false).
        FIM_Representation: NNGeometry matrix representation class, or ``None``
            to train without any Fisher information.
        Replay: Whether to train with a replay buffer.
        num_tasks: Number of tasks for which result lists are pre-created.
        verbose: Print training and validation accuracy after every epoch.
        buffer_size: Capacity of the replay buffer.
        data_path: Directory holding (or receiving) the MNISTFellowship data.

    Returns:
        ``(loss_te, acc_te, F_matrix)``: per-task lists of test losses and test
        accuracies, and the dense Fisher matrix stored after each task except
        the last one (only filled in the EWC setting).
    """
    # training scenario
    scenario_tr = ClassIncremental(MNISTFellowship(data_path, train=True), increment=10)

    # testing scenario
    scenario_te = ClassIncremental(MNISTFellowship(data_path, train=False), increment=10)

    # initialize the buffer
    buffer = DataBuffer(buffer_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = Base_Net(num_classes=30).to(device)
    optimizer = optim.SGD(params=model.parameters(), lr=0.002)

    loss_te = {ind_task: [] for ind_task in range(num_tasks)}
    acc_te = {ind_task: [] for ind_task in range(num_tasks)}
    ewc, F_matrix = {}, {}

    # enumerate scenario tasks
    for task_id, dataset_tr in enumerate(scenario_tr):
        for _epoch in tqdm(range(epochs), desc=f"Task {task_id} Epochs"):
            # train on task
            _train_loss, train_acc = process_task(dataset_tr, model, optimizer, Replay,
                                                  ewc, importance, FIM_Representation, train=True, buffer=buffer)
            if verbose:
                print(f"Training Accuracy Task {task_id} : {train_acc} % ")

            # evaluate on every test task
            for sub_task_id, dataset_te in enumerate(scenario_te):
                test_loss, test_acc = process_task(dataset_te, model, optimizer, Replay,
                                                   ewc, importance, FIM_Representation, train=False, buffer=buffer)
                if verbose:
                    print(f"Validation Accuracy Task {sub_task_id} : {test_acc} % ")
                # setdefault keeps this working when the scenario holds more
                # tasks than `num_tasks` announced.
                acc_te.setdefault(sub_task_id, []).append(test_acc)
                loss_te.setdefault(sub_task_id, []).append(test_loss)

        # compute the fisher matrix to protect weights learned on this task
        is_last_task = task_id >= scenario_tr.nb_tasks - 1
        if FIM_Representation is not None and not Replay and not is_last_task:
            ewc[task_id] = EWC(model, dataset_tr, FIM_Representation)
            F_matrix[task_id] = ewc[task_id].return_fisher().get_dense_tensor().cpu()

    return loss_te, acc_te, F_matrix

# 2. EWC diag and KFAC using NNGeometry

We install NNGeometry from its github repository:

For this example, we will need to import:
 - `FIM` is a helper that specifies that the metric that we are going to use is the Fisher Information Matrix. 
 - `PVector` is the vector of all parameters $\mathbf{w} = \left\{ W_1, b_1, \cdots \right\}$. Instead of having to loop through all parameters tensors, it conveniently offers an interface to do typical operations such as addition or multiplication with a scalar.
 - `PMatKFAC` and `PMatDiag` are matrix representations of the FIM.

In short:
 - `FIM` is *what* we want to compute
 - `PMatKFAC` and `PMatDiag` are *how* we are going to represent it

In [8]:
from nngeometry.metrics import FIM
from nngeometry.object import PMatKFAC, PMatDiag, PVector

We now define our class `EWC` that will be called at the end of each task in order to store the current parameters (here `self.v0`) and compute the FIM at current parameter values (here `self.Fisher`).

**Note that even if internally the mechanics are very different, switching from diagonal EWC to KFAC EWC is as simple as passing a different argument `Representation` to `FIM`.**

In [9]:
class EWC(object):
    """Elastic Weight Consolidation anchor for one finished task.

    On construction it stores the current parameters (``self.v0``) and the
    Fisher Information Matrix computed on ``train_set`` at those parameters
    (``self.Fisher``), in the requested NNGeometry representation.
    """

    def __init__(self, model: nn.Module, train_set, Representation):
        self.model = model
        self.train_set = train_set
        self.Fisher, self.v0 = self.compute_fisher(self.model, Representation)

    def compute_fisher(self, model, Representation):
        """Compute the Fisher matrix of ``model`` on the stored training set.

        Args:
            model: Network whose Fisher matrix is computed.
            Representation: NNGeometry matrix representation class.

        Returns:
            ``(fisher, v0)``: the Fisher matrix and a detached copy of the
            current parameter vector.
        """
        fisher_set = deepcopy(self.train_set)
        fisher_loader = DataLoader(fisher_set, batch_size=50, shuffle=False)

        # Run on whatever device the model already lives on instead of
        # assuming a GPU is present.
        device = next(model.parameters()).device

        fisher = FIM(model=model,
                     loader=fisher_loader,
                     representation=Representation,
                     n_output=30,
                     variant='classif_logits',
                     device=device)

        v0 = PVector.from_model(model).clone().detach()

        return fisher, v0

    def penalty(self, model: nn.Module):
        """Return the quadratic form ``(v - v0)^T F (v - v0)`` for ``model``'s parameters ``v``."""
        v = PVector.from_model(model)
        regularization_loss = self.Fisher.vTMv(v - self.v0)

        return regularization_loss

    def return_fisher(self):
        """Return the stored Fisher matrix object."""
        return self.Fisher

In [10]:
from nngeometry.layercollection import LayerCollection
# Build a throwaway network and list the layers NNGeometry tracks for it.
net = Base_Net(num_classes=30)

lc = LayerCollection.from_model(net)
l_to_m, m_to_l = lc.get_layerid_module_maps(net)

for layer_id, layer in lc.layers.items():
    mod = l_to_m[layer_id]  # layer id -> module
    lay = m_to_l[mod]       # module -> layer id (reverse lookup, value unused)
    print(mod, layer_id)
    print('-' * 17)



Conv2d(1, 5, kernel_size=(5, 5), stride=(1, 1)) conv1.Conv2d(1, 5, kernel_size=(5, 5), stride=(1, 1))
-----------------
Conv2d(5, 5, kernel_size=(5, 5), stride=(1, 1)) conv2.Conv2d(5, 5, kernel_size=(5, 5), stride=(1, 1))
-----------------
Linear(in_features=80, out_features=20, bias=True) fc1.Linear(in_features=80, out_features=20, bias=True)
-----------------
Linear(in_features=20, out_features=30, bias=False) fc2.Linear(in_features=20, out_features=30, bias=False)
-----------------


# 4. Learning

Now that we are all set, we start our continual learning sequence.

First, we do not use any regularization technique. This will serve as a benchmark.

In [11]:
# Settings shared by every run below.
importance, num_tasks, epochs = 10, 3, 5
test_label = False

# Baseline: no Fisher information and no replay.
loss, acc, _ = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=None,
    num_tasks=num_tasks,
)

Task 0 Epochs: 100%|██████████| 5/5 [00:31<00:00,  6.39s/it]
Task 1 Epochs:   0%|          | 0/5 [00:02<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# Replay only: buffer samples are mixed into every batch, no Fisher information.
loss_replay, acc_replay, _ = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=None,
    num_tasks=num_tasks,
    Replay=True,
)

Task 0 Epochs:   0%|          | 0/5 [00:04<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# EWC with the diagonal representation of the Fisher matrix
loss_EWC_diag, acc_EWC_diag, F_matrix_diag = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=PMatDiag,
    num_tasks=num_tasks,
)

In [None]:
# EWC with the KFAC representation of the Fisher matrix
loss_EWC_KFAC, acc_EWC_KFAC, F_matrix_KFAC = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=PMatKFAC,
    num_tasks=num_tasks,
)

In [None]:
# EWC with the dense representation of the Fisher matrix
loss_EWC_dense, acc_EWC_dense, F_matrix_dense = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=PMatDense,
    num_tasks=num_tasks,
)

In [None]:
# Ours: replay buffer + gradient preconditioned by the KFAC Fisher matrix
loss_ours_KFAC, acc_ours_KFAC, _ = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=PMatKFAC,
    num_tasks=num_tasks,
    buffer_size=1000,
    Replay=True,
)


Task 0 Epochs: 100%|██████████| 5/5 [04:50<00:00, 58.19s/it]
Task 1 Epochs: 100%|██████████| 5/5 [04:37<00:00, 55.42s/it]
Task 2 Epochs: 100%|██████████| 5/5 [04:36<00:00, 55.23s/it]


In [None]:
# Ours: replay buffer + gradient preconditioned by the dense Fisher matrix
loss_ours_dense, acc_ours_dense, _ = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=PMatDense,
    num_tasks=num_tasks,
    buffer_size=1000,
    Replay=True,
)

In [None]:
# Ours: replay buffer + gradient preconditioned by the diagonal Fisher matrix
loss_ours_diag, acc_ours_diag, _ = continual_process(
    epochs,
    importance=importance,
    FIM_Representation=PMatDiag,
    num_tasks=num_tasks,
    buffer_size=1000,
    Replay=True,
)

In [None]:
# Make sure the output folder exists before writing to it
import os
os.makedirs('results', exist_ok=True)

import pickle

# Save the results; the load cell unpacks them in exactly this order
with open('results/results.pkl', 'wb') as f:
    pickle.dump(
        [
            loss, acc,
            loss_replay, acc_replay,
            loss_EWC_diag, acc_EWC_diag, F_matrix_diag,
            loss_EWC_KFAC, acc_EWC_KFAC, F_matrix_KFAC,
            loss_EWC_dense, acc_EWC_dense, F_matrix_dense,
            loss_ours_KFAC, acc_ours_KFAC,
            loss_ours_dense, acc_ours_dense,
            loss_ours_diag, acc_ours_diag,
        ],
        f,
    )


# 5. Results

In [None]:
import matplotlib.pyplot as plt
import pickle

# Load the results; the order must match the list written by the save cell.
# NOTE: pickle can execute arbitrary code - only load files produced by this notebook.
with open('results/results.pkl', 'rb') as f:
    (
        loss, acc,
        loss_replay, acc_replay,
        loss_EWC_diag, acc_EWC_diag, F_matrix_diag,
        loss_EWC_KFAC, acc_EWC_KFAC, F_matrix_KFAC,
        loss_EWC_dense, acc_EWC_dense, F_matrix_dense,
        loss_ours_KFAC, acc_ours_KFAC,
        loss_ours_dense, acc_ours_dense,
        loss_ours_diag, acc_ours_diag,
    ) = pickle.load(f)

    

Here we plot the evolution of the test loss for all training strategies. The test set evaluates the algorithm's performance on all the tasks. We can see that EWC with the KFAC representation has the lowest loss on the test set and that this loss continues to decrease on the last task.

In [None]:
# (per-task test losses, legend label, line style) in drawing order.
# Same colour = same Fisher representation; dashed = "Ours", dotted = baselines.
loss_curves = [
    (loss, 'no regul', {'linestyle': 'dotted', 'color': 'black'}),
    (loss_replay, 'Replay', {'linestyle': 'dotted', 'color': 'gray'}),
    (loss_EWC_diag, 'EWC diag', {'color': 'C0'}),
    (loss_EWC_KFAC, 'EWC KFAC', {'color': 'C1'}),
    (loss_EWC_dense, 'EWC dense', {'color': 'C2'}),
    # 'Ours diag' is intentionally left out of the loss plot:
    # (loss_ours_diag, 'Ours diag', {'linestyle': 'dashed', 'color': 'C0'}),
    (loss_ours_KFAC, 'Ours KFAC', {'linestyle': 'dashed', 'color': 'C1'}),
    (loss_ours_dense, 'Ours dense', {'linestyle': 'dashed', 'color': 'C2'}),
]

for per_task, curve_label, line_style in loss_curves:
    # Average the three per-task curves epoch by epoch.
    plt.plot(np.mean([per_task[0], per_task[1], per_task[2]], axis=0), label=curve_label, **line_style)

plt.legend()
plt.ylabel('loss averaged on 3 tasks')
plt.show()

We now plot the model's test accuracy for all training strategies. As for the loss, we can see that the KFAC representation performs best. Moreover, it is the only one that improves over time.

In [None]:
# (per-task test accuracies, legend label, line style) in drawing order.
# Same colour = same Fisher representation; dashed = "Ours", dotted = baselines.
acc_curves = [
    (acc, 'no regul', {'linestyle': 'dotted', 'color': 'black'}),
    (acc_replay, 'Replay', {'linestyle': 'dotted', 'color': 'gray'}),
    (acc_EWC_diag, 'EWC diag', {'color': 'C0'}),
    (acc_EWC_KFAC, 'EWC KFAC', {'color': 'C1'}),
    (acc_EWC_dense, 'EWC dense', {'color': 'C2'}),
    (acc_ours_diag, 'Ours diag', {'linestyle': 'dashed', 'color': 'C0'}),
    (acc_ours_KFAC, 'Ours KFAC', {'linestyle': 'dashed', 'color': 'C1'}),
    (acc_ours_dense, 'Ours dense', {'linestyle': 'dashed', 'color': 'C2'}),
]

for per_task, curve_label, line_style in acc_curves:
    # Average the three per-task curves epoch by epoch.
    plt.plot(np.mean([per_task[0], per_task[1], per_task[2]], axis=0), label=curve_label, **line_style)

plt.legend()
plt.ylabel('accuracy averaged on 3 tasks')
plt.show()

In [None]:
# Per-task test accuracy across epochs: one colour per task,
# dotted line = Replay, dashed line = Ours (KFAC).
color = ['red', 'blue', 'green']
for i, task_color in enumerate(color):
    plt.plot(acc_replay[i], label=f'task {i}', linestyle='dotted', color=task_color)
    plt.plot(acc_ours_KFAC[i], label=f'task {i} Ours', linestyle='--', color=task_color)
plt.legend()
plt.ylabel('accuracy for each task')
plt.show()

In [None]:
#Plot the fisher matrix (F[0].get_dense_tensor().cpu())of the first task for each regularization method using matplotlib
#Use a grayscale colormap and high contrast for the visualization.
#Highlight the non-zero values
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

# One figure per (task, representation): task 0 first, then task 1,
# each in the order diag, KFAC, dense. Values are clipped to [0, 0.1]
# so that small non-zero entries stay visible.
for fisher_task in (0, 1):
    for fisher_by_task in (F_matrix_diag, F_matrix_KFAC, F_matrix_dense):
        plt.imshow(fisher_by_task[fisher_task], cmap='gray', norm=Normalize(vmin=0, vmax=0.1))
        plt.show()

