# Rehearsal Mechanisms

## Introduction

The rehearsal mechanism was described by Robins (1995) as a way to mitigate the effects of catastrophic forgetting. Robins even went as far as postulating that this is the mechanism the human brain uses to learn, by replaying old data along with new data while dreaming during sleep. In this notebook, we will explore the effects of rehearsal and pseudo-rehearsal on a model that is trained to do MNIST digit classification.

In [1]:
# As always, we start with package imports
import torch
import torch.nn as nn
import numpy as np
import torchvision
from torchvision.transforms import ToTensor, Compose, Pad
from tqdm.notebook import tqdm

# From my pytorch-patterns package
from utils import train_epoch, calculate_error, validate

## Customize Dataset

Unlike ordinary classification tasks, we can't just use the MNIST dataset as is. To investigate catastrophic forgetting, we need to hold out all samples of one class label. To do this, we need to build a new PyTorch dataset.

In [2]:
class CFMNIST(torchvision.datasets.MNIST):
    """MNIST filtered by one label, for catastrophic-forgetting experiments.

    With ``holdout=True`` the dataset keeps every sample whose label differs
    from ``holdout_label``; with ``holdout=False`` it keeps only the samples
    that carry ``holdout_label``.

    Args:
        root: Forwarded to ``torchvision.datasets.MNIST``.
        train: Forwarded to ``torchvision.datasets.MNIST``.
        download: Forwarded to ``torchvision.datasets.MNIST``.
        transform: Forwarded to ``torchvision.datasets.MNIST``.
        target_transform: Forwarded to ``torchvision.datasets.MNIST``.
        holdout_label: Label that the filter is built around.
        holdout: ``True`` to drop ``holdout_label`` samples, ``False`` to
            keep nothing but ``holdout_label`` samples.
    """

    def __init__(
        self, root,
        train=True,
        download=False,
        transform=None,
        target_transform=None,
        holdout_label=9,
        holdout=True,
    ):
        super().__init__(
            root,
            train=train,
            download=download,
            transform=transform,
            target_transform=target_transform,
        )

        # The parent constructor has already populated ``self.data`` and
        # ``self.targets``; filter those instead of re-reading the files
        # through the private ``_load_data`` helper a second time.
        if holdout:
            mask = self.targets != holdout_label
        else:
            mask = self.targets == holdout_label

        # ``mask`` is computed before either attribute is overwritten, so
        # images and labels are filtered with the same selection.
        self.data = self.data[mask]
        self.targets = self.targets[mask]

In [3]:
# Convert each image to a tensor, then pad 2 pixels on every side.
transforms = Compose([ToTensor(), Pad(2)])

# Both splits use the default filter, i.e. samples labelled 9 are excluded.
trainset = CFMNIST(root=".", train=True, download=True, transform=transforms)
validationset = CFMNIST(root=".", train=False, download=True, transform=transforms)

for dataset in (trainset, validationset):
    print(dataset)

Dataset CFMNIST
    Number of datapoints: 54051
    Root location: .
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Pad(padding=2, fill=0, padding_mode=constant)
           )
Dataset CFMNIST
    Number of datapoints: 8991
    Root location: .
    Split: Test
    StandardTransform
Transform: Compose(
               ToTensor()
               Pad(padding=2, fill=0, padding_mode=constant)
           )


## Training LeNet5

Next, we pre-train a LeNet5 while holding out the label 9 to get a model that can classify MNIST digits 0 - 8. We store the best model based on validation error.

In [4]:
class LeNet(nn.Module):
    """LeNet-5 style CNN that maps single-channel images to 10 logits.

    Three 5x5 convolutions (6, 16 and 120 output channels) with ReLU
    activations, the first two each followed by 2x2 average pooling, feed a
    120 -> 84 -> 10 fully connected head. ``fc1`` takes 120 features, so the
    last convolution must reduce the image to 1x1, which is the case for
    32x32 inputs.
    """

    def __init__(self):
        super().__init__()

        # Convolutional feature extractor. The position of each layer inside
        # the Sequential determines its state_dict key.
        self.convnet = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(16, 120, kernel_size=5),
            nn.ReLU(inplace=True),
        )

        # Classifier head
        self.fc1 = nn.Linear(120, 84)
        self.fc2 = nn.Linear(84, 10)

    def forward(self, img):
        """Return unnormalised class scores for a batch of images."""
        features = self.convnet(img).flatten(start_dim=1)
        hidden = torch.relu(self.fc1(features))
        return self.fc2(hidden)

In [5]:
# Setup training
# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = LeNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# The training and validation loaders share the same batching settings.
_loader_opts = {"batch_size": 64, "shuffle": True, "num_workers": 2}
train_loader = torch.utils.data.DataLoader(trainset, **_loader_opts)
val_loader = torch.utils.data.DataLoader(validationset, **_loader_opts)

In [6]:
# Epoch Training
num_epochs = 30

# Lowest validation error seen so far. Starting at +inf means any finite
# error from the first epoch counts as an improvement.
best_verror = float("inf")

for epoch in tqdm(range(num_epochs)):

    # Inner training loop
    avg_loss = train_epoch(
        model, train_loader,
        optimizer=optimizer,
        criterion=criterion,
        device=device
    )

    # Inner validation loop
    avg_vloss, avg_verror = validate(
        model, val_loader, criterion=criterion, device=device)

    # Checkpoint only when the validation error improves, so the file on
    # disk always holds the best weights seen so far.
    if avg_verror < best_verror:
        best_verror = avg_verror
        torch.save(model.state_dict(), "best-model.pth")

    # Adjacent f-string literals are joined by the parser. A backslash
    # continuation inside one literal would embed the source indentation
    # in the logged message.
    tqdm.write(
        f"Train loss: {avg_loss:.3f}, "
        f"Validation loss: {avg_vloss:.3f}, "
        f"Validation error: {avg_verror:.3f}")

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 2.264,         Validation loss: 2.241,         Validation error: 0.874


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 2.228,         Validation loss: 2.216,         Validation error: 0.873


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 2.193,         Validation loss: 2.109,         Validation error: 0.586


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.878,         Validation loss: 0.359,         Validation error: 0.107


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.322,         Validation loss: 0.251,         Validation error: 0.073


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.241,         Validation loss: 0.192,         Validation error: 0.055


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.193,         Validation loss: 0.154,         Validation error: 0.044


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.164,         Validation loss: 0.134,         Validation error: 0.039


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.143,         Validation loss: 0.125,         Validation error: 0.037


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.127,         Validation loss: 0.138,         Validation error: 0.045


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.115,         Validation loss: 0.107,         Validation error: 0.030


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.104,         Validation loss: 0.092,         Validation error: 0.028


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.096,         Validation loss: 0.087,         Validation error: 0.025


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.089,         Validation loss: 0.085,         Validation error: 0.026


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.082,         Validation loss: 0.077,         Validation error: 0.024


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.077,         Validation loss: 0.080,         Validation error: 0.024


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.072,         Validation loss: 0.069,         Validation error: 0.020


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.069,         Validation loss: 0.068,         Validation error: 0.021


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.065,         Validation loss: 0.075,         Validation error: 0.023


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.061,         Validation loss: 0.062,         Validation error: 0.019


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.058,         Validation loss: 0.064,         Validation error: 0.021


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.056,         Validation loss: 0.062,         Validation error: 0.019


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.054,         Validation loss: 0.060,         Validation error: 0.019


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.051,         Validation loss: 0.056,         Validation error: 0.017


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.049,         Validation loss: 0.053,         Validation error: 0.016


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.046,         Validation loss: 0.052,         Validation error: 0.017


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.045,         Validation loss: 0.052,         Validation error: 0.014


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.043,         Validation loss: 0.058,         Validation error: 0.018


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.042,         Validation loss: 0.054,         Validation error: 0.015


  0%|          | 0/845 [00:00<?, ?it/s]

Train loss: 0.040,         Validation loss: 0.046,         Validation error: 0.014


## Pseudo-Rehearsals

In pseudo-rehearsal, each batch of new intervening data is combined with randomly generated data. The target labels for the randomly generated data are produced by the model itself in a forward pass. We then backpropagate on this combined batch as though it were a full dataset. The hope is that the randomly generated data, together with the model-generated labels, capture the function that best described the original training set. This mechanism has an advantage over rehearsal methods when the previously used training data are no longer available.

In [7]:
# holdout=False inverts the filter: only the samples labelled 9 are kept.
newtrainset = CFMNIST(
    root=".",
    download=True,
    transform=transforms,
    holdout=False,
)
newtrainset

Dataset CFMNIST
    Number of datapoints: 5949
    Root location: .
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Pad(padding=2, fill=0, padding_mode=constant)
           )

In [8]:
# Small shuffled minibatches of the new data; each one is later combined
# with generated samples inside the rehearsal loop.
newtrainloader = torch.utils.data.DataLoader(
    dataset=newtrainset,
    batch_size=4,
    shuffle=True,
    num_workers=1,
)

In [52]:
# Rehearsal is like an epoch training loop.
# We only batch the data differently
def pseudo_rehearse(
        model, trainloader, valloader,
        optimizer=None,
        criterion=torch.nn.CrossEntropyLoss(),
        device=torch.device("cpu"),
        factor=3
    ):
    """Run one pass over ``trainloader`` using pseudo-rehearsal.

    Each minibatch of new data is combined with ``factor`` times as many
    uniformly random images. The random images are labelled with the model's
    own current predictions, the combined batch is shuffled, and one
    optimisation step is taken on it. After every step the model is
    evaluated on ``valloader``.

    Args:
        model: Network to train; it is moved to ``device``.
        trainloader: Iterable of ``(images, labels)`` minibatches of new
            data, with images shaped ``(n, c, H, W)``.
        valloader: Loader handed to ``validate`` after every step.
        optimizer: Optimizer to use. When ``None``, SGD with ``lr=0.001``
            and ``momentum=0.9`` over ``model.parameters()`` is created.
        criterion: Loss applied to the combined batch and passed to
            ``validate``.
        device: Device on which the forward and backward passes run.
        factor: Number of random pseudo-items generated per new sample.

    Returns:
        A tuple ``(running_vloss, running_verror)`` of lists holding the two
        values returned by ``validate`` after each minibatch.
    """
    model = model.to(device)

    running_vloss = []
    running_verror = []

    if optimizer is None:
        # Default optimizer if one is not provided
        optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

    for imgs, labels in tqdm(trainloader, total=len(trainloader)):
        # Shape of the new-data minibatch drives the shape of the random one
        n, c, H, W = imgs.size()

        # Generate random tensors for pseudo rehearsals.
        # TODO: it might be better to train on logits instead of labels.
        # TODO: investigate why all generated labels come out as 8.
        rand_imgs = torch.rand(n * factor, c, H, W)

        # The pseudo-labels are integer targets, so no gradient flows
        # through this pass; skip building an autograd graph for it.
        with torch.no_grad():
            rand_logits = model(rand_imgs.to(device))
        # Softmax is monotonic, so argmax over the logits picks the same
        # class as argmax over the probabilities.
        rand_labels = rand_logits.argmax(dim=1)

        imgs = torch.cat((imgs, rand_imgs), dim=0)
        labels = torch.cat((labels, rand_labels.cpu()))

        # Shuffle so new and generated samples are interleaved
        indices = torch.randperm(len(labels))
        imgs = imgs[indices]
        labels = labels[indices]

        # NOTE(review): if ``validate`` switches the model to eval mode,
        # every step after the first runs in eval mode - confirm against
        # the helper's implementation.
        optimizer.zero_grad()

        # Forward pass
        outputs = model(imgs.to(device))

        # Compute loss and backpropagate error gradients
        loss = criterion(outputs, labels.to(device))
        loss.backward()

        # Gradient descent
        optimizer.step()

        # Validate after every step to track performance degradation
        vloss, verror = validate(
            model, valloader, criterion=criterion, device=device)

        running_vloss.append(vloss)
        running_verror.append(verror)

    return running_vloss, running_verror

In [None]:
# Pseudo Rehearsal
# No optimizer is passed, so pseudo_rehearse creates its default SGD optimizer.
pseudo_rehearse(
    model=model,
    trainloader=newtrainloader,
    valloader=val_loader,
    criterion=criterion,
    device=device,
)