<a href="https://colab.research.google.com/github/dineshRaja29/SpeechArchitectures_Hat-Swap-Architecture/blob/main/002_hat_swap_architecture_Catastrophic_Forgetting_remedy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# <font color = 'green'><b>GOAL:</b></font>

* Build a Hat Swap network for several datasets and train the model in a way that avoids "Catastrophic Forgetting".

<font color = 'green'><b>Hat Swap Network:</b></font> A network with a separate output layer for each task (or dataset) and shared hidden layers.

<font color = 'green'><b>Catastrophic Forgetting: </b></font> When a shared model is trained sequentially on multiple tasks, each new training phase overwrites previously learned parameters, especially in the shared layers. As a result, the model performs well on the last-trained task (the last language, in the multilingual speech setting of the first reference) but forgets the earlier ones.

* Reference:
    * https://www.inf.ed.ac.uk/teaching/courses/asr/2019-20/asr14-multiling.pdf
    * https://www.ibm.com/think/topics/catastrophic-forgetting
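To make the catastrophic-forgetting definition concrete, here is a minimal toy sketch in plain Python; it is not part of the original notebook. A single shared parameter `w` stands in for the shared layers, and two hypothetical quadratic losses stand in for two tasks (task A is solved at `w = 2`, task B at `w = -1`). Training on A and then on B drags `w` to B's optimum and loses A, while interleaving both tasks (summing their gradients at every step) settles on a compromise.

```python
# Toy illustration of catastrophic forgetting with ONE shared parameter w.
# The two quadratic losses are hypothetical stand-ins for two tasks.
def loss_a(w):
    return (w - 2.0) ** 2   # task A is solved at w = 2

def loss_b(w):
    return (w + 1.0) ** 2   # task B is solved at w = -1

def grad_a(w):
    return 2.0 * (w - 2.0)

def grad_b(w):
    return 2.0 * (w + 1.0)

def gradient_descent(w, grad_fn, steps=200, lr=0.1):
    """Plain gradient descent on a scalar parameter."""
    for _ in range(steps):
        w -= lr * grad_fn(w)
    return w

# Sequential training: task A first, then task B (the setting that causes forgetting)
w_seq = gradient_descent(0.0, grad_a)
loss_a_after_a = loss_a(w_seq)          # ~0: task A is learned
w_seq = gradient_descent(w_seq, grad_b)
loss_a_after_b = loss_a(w_seq)          # ~9: task A has been overwritten

# Joint training: every update uses the summed gradient of both tasks
w_joint = gradient_descent(0.0, lambda w: grad_a(w) + grad_b(w))

print(f"sequential: loss_A after A = {loss_a_after_a:.4f}, after B = {loss_a_after_b:.4f}")
print(f"joint:      w = {w_joint:.4f}, loss_A = {loss_a(w_joint):.4f}, loss_B = {loss_b(w_joint):.4f}")
```

The toy has no task-specific heads; it only shows why interleaving tasks beats training them one after another. In a Hat Swap network the per-task output layers additionally give each task its own parameters, so the shared layers only have to hold features that are useful to all tasks.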

# <font color = 'green'><b>DATASET</b></font>

* Use the CIFAR-10 dataset as the base dataset.
* From CIFAR-10, create three binary task datasets:
    * `with_one`: images with label 1 become positives (1); images with labels 4 and 7 become negatives (0).
    * `with_two`: images with label 2 become positives (1); images with labels 5 and 6 become negatives (0).
    * `with_three`: images with label 3 become positives (1); images with labels 8 and 9 become negatives (0).
* Merge the three task datasets into one, adding a column that stores the task identifier.
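The relabeling rule and the resulting dataset sizes can be checked on paper before touching any images. The sketch below is a hypothetical helper (the names `TASKS`, `relabel` and `expected_sizes` are illustrative, not from the notebook); it relies only on CIFAR-10 having 5,000 training and 1,000 test images per class, so each task gets one positive class and two negative classes, a 1:2 positive-to-negative ratio.

```python
# Task definitions from the list above: task -> (positive class, negative classes)
TASKS = {
    "with_one": (1, [4, 7]),
    "with_two": (2, [5, 6]),
    "with_three": (3, [8, 9]),
}
# CIFAR-10 has 5,000 training and 1,000 test images per class
PER_CLASS = {"train": 5000, "test": 1000}

def relabel(label, pos, neg):
    """Map an original CIFAR-10 label to 1 (positive), 0 (negative), or None (dropped)."""
    if label == pos:
        return 1
    if label in neg:
        return 0
    return None

def expected_sizes(split):
    """Rows per task: one positive class plus the listed negative classes."""
    return {task: PER_CLASS[split] * (1 + len(neg)) for task, (pos, neg) in TASKS.items()}

train_sizes = expected_sizes("train")
test_sizes = expected_sizes("test")
merged_rows = sum(train_sizes.values())

print(train_sizes)   # {'with_one': 15000, 'with_two': 15000, 'with_three': 15000}
print(test_sizes)    # {'with_one': 3000, 'with_two': 3000, 'with_three': 3000}
print(merged_rows)   # 45000 data rows (the merged CSV also has one header line)
```

These numbers line up with the `saved with 15000 entries` / `3000 entries` messages and the `45001` line count reported further down.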

# <font color = 'green'><b>DATASET PREPARATION</b></font>

In [None]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import classification_report
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoFeatureExtractor, Dinov2Model
from torch.optim.lr_scheduler import StepLR
# note: PIL stands for pillow; to install type "pip3 install pillow"
from PIL import Image
from datetime import datetime
from torchvision.transforms import Compose, Resize, RandomResizedCrop, RandomHorizontalFlip, ColorJitter, ToTensor, Normalize
import pandas as pd
import numpy as np
import os, gc

In [None]:
# a simple transformation
transform = transforms.ToTensor()
to_pil = ToPILImage()
# download data
train_set = CIFAR10(root='./data', train=True, download=True, transform=transform)
test_set = CIFAR10(root='./data', train=False, download=True, transform=transform)
# intermediate directories to save data
save_root = '/content/drive/MyDrive/cifar10_binary'
os.makedirs(save_root, exist_ok = True)

100%|██████████| 170M/170M [00:04<00:00, 34.5MB/s] 


In [None]:
def label_adjustment(dataset, pos, neg):
    """Relabel one CIFAR-10 split for a binary task: class `pos` -> 1, classes in `neg` -> 0, all others dropped."""
    results = []
    for img, label in dataset:
        if label == pos:
            results.append([img, 1])
        elif label in neg:
            results.append([img, 0])
    return results

def save_images_and_make_csv(data, split_name):
    """Save each image as a PNG under save_root/split_name and write an index CSV of (path, label) rows."""
    split_dir = os.path.join(save_root, split_name)
    os.makedirs(split_dir, exist_ok = True)
    rows = []
    for idx, (img_tensor, label) in enumerate(data):
        img_path = os.path.join(split_dir, f'{idx}.png')
        to_pil(img_tensor).save(img_path)
        rows.append([img_path, label])

    # The image-path column is named "MD5HASH"; later cells read it under that name
    df = pd.DataFrame(rows, columns = ["MD5HASH", "LABEL"])
    df.to_csv(os.path.join(save_root, f"{split_name}.csv"), index = False)
    print(f"{split_name}.csv saved with {len(rows)} entries.")

In [None]:
# task name -> [positive class, negative classes]
labels = {
    'with_one': [1, [4, 7]],
    'with_two': [2, [5, 6]],
    'with_three': [3, [8, 9]],
}
def generate_multiple_csv_files():

    for k, v in labels.items():
        train_data = label_adjustment(train_set, v[0], v[1])
        test_data  = label_adjustment(test_set, v[0], v[1])
        # Save both splits
        save_images_and_make_csv(train_data, f"{k}_train")
        save_images_and_make_csv(test_data, f"{k}_test")

generate_multiple_csv_files()

with_one_train.csv saved with 15000 entries.
with_one_test.csv saved with 3000 entries.
with_two_train.csv saved with 15000 entries.
with_two_test.csv saved with 3000 entries.
with_three_train.csv saved with 15000 entries.
with_three_test.csv saved with 3000 entries.


In [None]:
languages_dataset = []
for data in labels.keys():
    file = f"/content/drive/MyDrive/cifar10_binary/{data}_train.csv"
    frame = pd.read_csv(file)
    frame['LANGUAGE'] = data
    languages_dataset.append(frame)

# Concatenate all language-specific DataFrames into one
full_dataset = pd.concat(languages_dataset, ignore_index=True)
full_dataset.to_csv('/content/drive/MyDrive/cifar10_binary/all_languages_training_data.csv', index = False)

In [None]:
!ls /content/drive/MyDrive/cifar10_binary/

all_languages_training_data.csv  with_three_test       with_two_test.csv
with_one_test			 with_three_test.csv   with_two_train
with_one_test.csv		 with_three_train      with_two_train.csv
with_one_train			 with_three_train.csv
with_one_train.csv		 with_two_test


In [None]:
! wc -l /content/drive/MyDrive/cifar10_binary/all_languages_training_data.csv

45001 /content/drive/MyDrive/cifar10_binary/all_languages_training_data.csv


In [None]:
! cat /content/drive/MyDrive/cifar10_binary/all_languages_training_data.csv | head -4

MD5HASH,LABEL
/content/drive/MyDrive/cifar10_binary/with_two_train/0.png,0
/content/drive/MyDrive/cifar10_binary/with_two_train/1.png,1
/content/drive/MyDrive/cifar10_binary/with_two_train/2.png,1
cat: write error: Broken pipe


# <font color = 'green'><b>SETTING ENVIRONMENT FOR HARDWARE ACCELERATOR</b></font>

In [None]:
# Check for CUDA and MPS availability, set the device accordingly
if torch.backends.mps.is_available():
    device = torch.device("mps")
    # Environment variables set for training on macOS with the MPS backend
    os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
    os.environ['PYDEVD_DISABLE_FILE_VALIDATION'] = '1'
    print("Using MPS as the device.")
else:
    if torch.cuda.is_available():
        # 'cuda' selects the first visible GPU; a string such as 'cuda:3' would select a specific GPU
        device = torch.device("cuda")
        print("Using CUDA as the device.")
    else:
        device = torch.device("cpu")
        print("Using CPU as the device.")

Using CUDA as the device.


# <font color = 'green'><b>CONFIGURATION VARIABLES AND UTILITY FUNCTIONS</b></font>

In [None]:
BATCH_SIZE                            = 128 #256
WORKERS                               = 4
PIN_MEMORY                            = True
MIXING                                = True
MODEL_NAME                            = "facebook/dinov2-base"
RESULTS                               = 'results'
EPOCHS                                = 5
BEST_MODEL                            = None
PRETRAINING                           = False
LEARNING_RATE                         = 1e-4
L2_PENALTY                            = 1e-4
GAMMA                                 = 0.1
STEPSIZE                              = 3
SAVE_CHECKPOINTS                      = True
MIN_LOSS                              = float('inf')
MODEL_SAVED                           = f'{RESULTS}/bestmodel.pth'
THRESHOLD                             = 0.5
OUTPUT_DIM                            = 1
HEADS                                 = ['with_one', 'with_two', 'with_three']

In [None]:
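def calculate_classification_accuracy(loader, model, head):
    """Run `loader` through the shared backbone and the given head; return an sklearn classification report."""
    model.eval()  # Set the model in evaluation mode
    LABELS = []
    PREDICTIONS = []

    with torch.no_grad():
        for images, labels in loader:
            # Move to device and cast labels to float32
            images, labels = images.to(device), labels.to(device).float()
            # Shared CLS features of shape (B, hidden_dim); no .squeeze(), which would break a batch of one image
            features = model(images)
            # Head output has shape (B, 1); flatten it to (B,) to line up with the labels
            probabilities = model.classify(features, head).squeeze(1)
            # Predictions based on the threshold
            prediction = torch.where(probabilities > THRESHOLD, 1.0, 0.0)
            LABELS.extend(labels.tolist())
            PREDICTIONS.extend(prediction.tolist())
    return classification_report(LABELS, PREDICTIONS)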


def cleaning_memory():
    # Explicitly free up GPU memory
    if torch.backends.mps.is_available():
        torch.mps.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Run garbage collector to free up CPU memory
    gc.collect()

# <font color = 'green'><b>MODEL ARCHITECTURE</b></font>
> We use DINOv2 (`facebook/dinov2-base`) as the shared feature extractor and fine-tune it on our dataset together with the task heads.

In [None]:
class MultiHeadNetwork(nn.Module):
    # Reference: https://www.inf.ed.ac.uk/teaching/courses/asr/2019-20/asr14-multiling.pdf
    def __init__(self, heads):
        super(MultiHeadNetwork, self).__init__()

        # Load the pre-trained image processor and backbone model
        # the processor is kept so its normalization settings (mean, std, resampling) can be reused by the dataset
        self.processor = AutoImageProcessor.from_pretrained(MODEL_NAME)
        self.backbone_model = Dinov2Model.from_pretrained(MODEL_NAME)
        self.pretrained_backbone_model_last_dim = self.backbone_model.layernorm.normalized_shape[0]

        # Create a dictionary to hold the heads
        self.heads = nn.ModuleDict()

        for head in heads:
            # Create a classification head for each entry in heads
            self.heads[head] = nn.Sequential(
                nn.Linear(self.pretrained_backbone_model_last_dim, OUTPUT_DIM, bias=True),
                nn.Sigmoid()
            )

        # Initialize weights
        self.initialize_weights()

    def initialize_weights(self):
        torch.manual_seed(444)
        for head_name, head in self.heads.items():
            for layer in head:
                if isinstance(layer, nn.Linear):
                    nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu')
                    nn.init.zeros_(layer.bias)
                    print(f"kaiming_uniform_ Initialization: {layer.__class__.__name__}")

    def forward(self, x):
        """
        Forward only through the shared backbone (the feature extractor).
        Task-specific heads are applied separately, via classify(), in the training loop.
        Returns the CLS-token embedding of shape (batch_size, hidden_dim).
        """
        features = self.backbone_model(x).last_hidden_state[:, 0]
        return features

    def classify(self, features, lang):
        """
        Forward through the head of a specific task (a "language" in the speech setting).
        """
        return self.heads[lang](features)
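The forward/classify split above can be exercised without downloading DINOv2. The following is a minimal sketch, assuming a tiny MLP trunk as a hypothetical stand-in for the backbone (the class name `ToyHatSwap` and all sizes are illustrative). It shows that a loss computed through one head sends gradients into the shared trunk and that head only, while the other heads stay untouched.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyHatSwap(nn.Module):
    """Hypothetical stand-in for MultiHeadNetwork: a tiny MLP trunk replaces DINOv2."""
    def __init__(self, heads, in_dim=8, hidden_dim=16):
        super().__init__()
        # Shared hidden layers
        self.backbone = nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.ReLU())
        # One output layer ("hat") per task
        self.heads = nn.ModuleDict({h: nn.Linear(hidden_dim, 1) for h in heads})

    def forward(self, x):
        # Shared feature extractor only, mirroring MultiHeadNetwork.forward
        return self.backbone(x)

    def classify(self, features, head):
        # Task-specific head followed by a sigmoid, mirroring MultiHeadNetwork.classify
        return torch.sigmoid(self.heads[head](features))

toy_model = ToyHatSwap(["with_one", "with_two", "with_three"])
toy_x = torch.randn(4, 8)
toy_y = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

toy_probs = toy_model.classify(toy_model(toy_x), "with_two")   # shape (4, 1)
F.binary_cross_entropy(toy_probs, toy_y).backward()

# Only the shared trunk and the selected head received gradients
for name in ["with_one", "with_two", "with_three"]:
    print(name, toy_model.heads[name].weight.grad is not None)
print("trunk", toy_model.backbone[0].weight.grad is not None)
```

Running it prints `False` for `with_one` and `with_three` and `True` for `with_two` and the trunk: swapping the "hat" changes which output layer learns, while the trunk learns from every task.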

In [None]:

model = MultiHeadNetwork(HEADS)

model.to(device)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.

kaiming_uniform_ Initialization: Linear
kaiming_uniform_ Initialization: Linear
kaiming_uniform_ Initialization: Linear


MultiHeadNetwork(
  (backbone_model): Dinov2Model(
    (embeddings): Dinov2Embeddings(
      (patch_embeddings): Dinov2PatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(14, 14), stride=(14, 14))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): Dinov2Encoder(
      (layer): ModuleList(
        (0-11): 12 x Dinov2Layer(
          (norm1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
          (attention): Dinov2Attention(
            (attention): Dinov2SelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
            )
            (output): Dinov2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (layer_scale1): Dinov2LayerSca

# <font color = 'green'><b>DATASET LOADING</b></font>

In [None]:
class MD5HASHDataset(Dataset):
    def __init__(self, dataframe):
        self.dataframe = dataframe
        # The "MD5HASH" column holds the image file paths written by save_images_and_make_csv
        self.images = self.dataframe['MD5HASH'].values
        self.labels = self.dataframe['LABEL'].values
        # Only the merged training CSV has a LANGUAGE (task) column; the per-task CSVs do not
        if 'LANGUAGE' in self.dataframe.columns:
            self.languages = self.dataframe['LANGUAGE'].values
        else:
            self.languages = None
        # Reuse the normalization settings of the (global) model's DINOv2 image processor
        self.processor = model.processor
        self.mean = self.processor.image_mean
        self.std = self.processor.image_std
        self.interpolation = self.processor.resample

        self.train_transform = Compose([
            Resize(size = (32, 32)),
            #RandomResizedCrop(size = (224, 224),
            #                  scale = (0.08, 1.0),
            #                  ratio = (0.75, 1.3333),
            #                  interpolation = self.interpolation),
            #RandomHorizontalFlip(p = 0.5),
            #ColorJitter(brightness = (0.6, 1.4),
            #            contrast = (0.6, 1.4),
            #            saturation = (0.6, 1.4)),
            ToTensor(),
            Normalize(mean = self.mean, std = self.std),
        ])


    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        # Load the image from the file path
        image_path = self.images[idx]
        image = self.train_transform(Image.open(image_path).convert('RGB'))
        # Get the label
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        # Per-task CSVs (evaluation) yield (image, label); the merged CSV (training) also yields the task name
        if self.languages is None:
            return image, label
        return image, label, self.languages[idx]
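The dataset resizes every image to 32×32, while the backbone's patch embedding is a 14×14 convolution with stride 14 (see `projection` in the model printout above). The sketch below works out how many tokens that leaves, using the standard convolution output-size formula and assuming one CLS token and no register tokens for `dinov2-base` (our assumption). The 224×224 case is included only for comparison with the `RandomResizedCrop(size = (224, 224))` line that is commented out in the transform.

```python
def conv_out(size, kernel, stride, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

def vit_tokens(height, width, patch=14, cls_tokens=1):
    """Patch tokens from a kernel = stride = patch convolution, plus the CLS token(s)."""
    return conv_out(height, patch, patch) * conv_out(width, patch, patch) + cls_tokens

print(vit_tokens(32, 32))    # 5   -> a 2x2 grid of patches + CLS
print(vit_tokens(224, 224))  # 257 -> a 16x16 grid of patches + CLS
```

At 32×32 the backbone sees only four patches (the last four pixel rows and columns fall outside the stride-14 grid), so each patch covers a large part of the image.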


In [None]:
def create_training_loader(data_csv, upsampling = False):
    # Load data
    training_data = pd.read_csv(data_csv)
    print('::: DATA DETAILS :::')
    print('- Number of Samples:', training_data.shape[0])
    # Create dataset and dataloader
    md5hash_dataset = MD5HASHDataset(training_data)
    if upsampling:
        # References:
        # https://pytorch.org/docs/stable/data.html
        # https://towardsdatascience.com/demystifying-pytorchs-weightedrandomsampler-by-example-a68aceccb452
        from torch.utils.data import WeightedRandomSampler
        print('- LANGUAGE DISTRIBUTION: \n',training_data['LANGUAGE'].value_counts())
        classes_count = dict(training_data['LANGUAGE'].value_counts())
        sample_weights = [ 1 / classes_count[i] for i in training_data.LANGUAGE.values]
        sampler = WeightedRandomSampler(weights = sample_weights,
                                        num_samples = len(training_data),
                                        replacement = True)
        data_loader = DataLoader(md5hash_dataset,
                                 batch_size = BATCH_SIZE,
                                 num_workers = WORKERS,
                                 pin_memory = PIN_MEMORY,
                                 shuffle = False,
                                 sampler = sampler)
    else:
        data_loader = DataLoader(md5hash_dataset,
                                 batch_size = BATCH_SIZE,
                                 num_workers = WORKERS,
                                 pin_memory = PIN_MEMORY,
                                 shuffle = MIXING)

    # Clean memory, :)
    del training_data

    return data_loader



# <font color = 'green'><b>LOSS FUNCTION</b></font>

In [None]:
# Define the loss function: BCE
criterion = nn.BCELoss()
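`nn.BCELoss` expects probabilities, which is why every head ends in `nn.Sigmoid()`. For a predicted probability p and label y it averages −[y·log p + (1−y)·log(1−p)] over the batch. The plain-Python sketch below reproduces that formula for illustration; the `eps` clamp is our own choice (PyTorch's documentation describes clamping the log terms at −100 instead). The training loop adds one such mean loss per head present in a batch, so the logged epoch loss is on the scale of a sum over up to three heads.

```python
import math

def bce(probs, targets, eps=1e-12):
    """Mean binary cross-entropy over a batch of probabilities and 0/1 targets."""
    total = 0.0
    for p, y in zip(probs, targets):
        p = min(max(p, eps), 1.0 - eps)  # keep log() finite
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(probs)

print(bce([0.5], [1]))          # ln 2 ≈ 0.6931: an uninformative prediction
print(bce([0.9, 0.2], [1, 0]))  # (-ln 0.9 - ln 0.8) / 2 ≈ 0.1643
```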

In [None]:
# directory creation
os.makedirs(RESULTS, exist_ok = True)
if SAVE_CHECKPOINTS:
    CHECKPOINTDIR = f'{RESULTS}/checkpoints'
    os.makedirs(CHECKPOINTDIR, exist_ok = True)

# <font color = 'green'><b>OPTIMIZER</b></font>

In [None]:
# Idea borrowed from the research paper "Improving Generalization Performance by Switching from Adam to SGD"
if PRETRAINING:
    optimizer = torch.optim.SGD(model.parameters(), lr = LEARNING_RATE, momentum = 0.9, weight_decay = L2_PENALTY)
else:
    optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE, weight_decay = L2_PENALTY)

# Define a learning rate scheduler
scheduler = StepLR(optimizer, step_size = STEPSIZE, gamma = GAMMA)  # Adjust step_size and gamma as needed
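With `STEPSIZE = 3`, `GAMMA = 0.1` and `scheduler.step()` called once at the end of every epoch, `StepLR` multiplies the learning rate by 0.1 every three epochs. The sketch below computes the resulting schedule with the closed form lr = base_lr · gamma^⌊epoch / step_size⌋, using local copies of the configuration values rather than the notebook's globals.

```python
def step_lr(epoch, base_lr=1e-4, gamma=0.1, step_size=3):
    """Learning rate used during 0-based `epoch` under a StepLR schedule stepped once per epoch."""
    return base_lr * gamma ** (epoch // step_size)

schedule = [step_lr(e) for e in range(5)]  # EPOCHS = 5
for epoch, lr in enumerate(schedule, start=1):
    print(f"epoch {epoch}: lr = {lr:.0e}")
```

So epochs 1 to 3 train at 1e-4 and epochs 4 and 5 at 1e-5.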


In [None]:
file = f'/content/drive/MyDrive/cifar10_binary/all_languages_training_data.csv'
data_loader = create_training_loader(file, upsampling = True)

::: DATA DETAILS :::
- Number of Samples: 45000
- LANGUAGE DISTRIBUTION: 
 LANGUAGE
with_one      15000
with_two      15000
with_three    15000
Name: count, dtype: int64
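`WeightedRandomSampler` draws each sample with probability proportional to its weight. Because every row gets weight 1/(number of rows of its language), each language receives the same total weight and is drawn with probability 1/3 per sample, roughly 43 of the 128 samples in a batch in expectation. With the balanced 15,000/15,000/15,000 split printed above this amounts to uniform sampling with replacement; the skewed split in the sketch is hypothetical and only shows what the weighting would change.

```python
from collections import Counter

def per_language_probability(languages):
    """Probability that one draw of the weighted sampler picks each language."""
    counts = Counter(languages)
    weights = [1 / counts[lang] for lang in languages]   # same rule as sample_weights
    total = sum(weights)
    mass = Counter()
    for lang, w in zip(languages, weights):
        mass[lang] += w / total
    return dict(mass)

balanced = ["with_one"] * 15000 + ["with_two"] * 15000 + ["with_three"] * 15000
skewed = ["with_one"] * 30000 + ["with_two"] * 10000 + ["with_three"] * 5000  # hypothetical

print(per_language_probability(balanced))
print(per_language_probability(skewed))
```

Both calls give about 1/3 per language: the inverse-frequency weights cancel any imbalance between tasks.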


# <font color = 'green'><b>MODEL TRAINING</b></font>

In [None]:
MIN_LOSS = float('inf')
# TRAINING LOOP
for epoch in range(EPOCHS):
    print('-'*70)
    # Accumulates the summed per-head loss over all batches of this epoch
    total_loss = 0.0

    # setting model stage to training
    model.train()

    for batch_idx, (images, labels, languages) in enumerate(data_loader):
        # Move the batch to the hardware accelerator
        images, labels = images.to(device), labels.to(device)
        # Reset gradients so they do not accumulate across batches
        optimizer.zero_grad()

        # Forward pass through the shared backbone: one CLS feature vector per image, shape (B, hidden_dim)
        # (no .squeeze(): it would drop the batch dimension for a batch of one image)
        features = model(images)

        batch_loss = 0.0
        for lang in set(languages):
            # Indices of the samples in this batch that belong to the current task
            lang_indices = [i for i, l in enumerate(languages) if l == lang]

            # Gather the corresponding features, and reshape labels to (n, 1) to match the head output
            lang_feats = features[lang_indices]
            lang_labels = labels[lang_indices].unsqueeze(1)

            # Forward pass through the task-specific head
            outputs = model.classify(lang_feats, lang)
            batch_loss += criterion(outputs, lang_labels)

        # Backprop: the shared backbone receives the summed gradients of every head used in this batch
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
        cleaning_memory() # cleaning memory

    print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {total_loss / (batch_idx + 1)}")
    # Update the learning rate
    scheduler.step()


    if SAVE_CHECKPOINTS:
        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        checkpointmodel = '{}/epoch_{}_{}.pth'.format(CHECKPOINTDIR, epoch + 1, timestamp)
        print('Saving checkpoint: ', checkpointmodel)
        torch.save(model.state_dict(), checkpointmodel)

    # Keep the weights with the lowest training loss seen so far
    if total_loss < MIN_LOSS:
        MIN_LOSS = total_loss
        print('Saving Best Model: ', MODEL_SAVED)
        torch.save(model.state_dict(), MODEL_SAVED)

################################

----------------------------------------------------------------------
Epoch 1/5, Loss: 1.3424734544347634
Saving checkpoint:  results/checkpoints/epoch_1_20250729120016.pth
Saving Best Model:  results/bestmodel.pth
----------------------------------------------------------------------
Epoch 2/5, Loss: 0.623031364584511
Saving checkpoint:  results/checkpoints/epoch_2_20250729120257.pth
Saving Best Model:  results/bestmodel.pth
----------------------------------------------------------------------
Epoch 3/5, Loss: 0.42985334171151574
Saving checkpoint:  results/checkpoints/epoch_3_20250729120537.pth
Saving Best Model:  results/bestmodel.pth
----------------------------------------------------------------------
Epoch 4/5, Loss: 0.17698980413313786
Saving checkpoint:  results/checkpoints/epoch_4_20250729120818.pth
Saving Best Model:  results/bestmodel.pth
----------------------------------------------------------------------
Epoch 5/5, Loss: 0.09020945761717898
Saving checkpoint:  results

# <font color = 'green'><b>MODEL EVALUATION</b></font>

In [None]:
# loading best model
model.load_state_dict(torch.load(MODEL_SAVED, weights_only = True))

<All keys matched successfully>

In [None]:
for head in HEADS:
    print(f"** Performance of Head {head.upper()}**")
    for set_name in ['train', 'test']:
        file = f'/content/drive/MyDrive/cifar10_binary/{head}_{set_name}.csv'
        print(f"Dataset: {set_name}")
        data_loader = create_training_loader(file,
                                             upsampling = False)
        # without upsampling, used to report exact performance on the training data and testing data
        print(calculate_classification_accuracy(data_loader, model, head))
        del data_loader
        cleaning_memory() # cleaning memory

** Performance of Head WITH_ONE**
Dataset: train
::: DATA DETAILS :::
- Number of Samples: 15000
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00     10000
         1.0       1.00      1.00      1.00      5000

    accuracy                           1.00     15000
   macro avg       1.00      1.00      1.00     15000
weighted avg       1.00      1.00      1.00     15000

Dataset: test
::: DATA DETAILS :::
- Number of Samples: 3000
              precision    recall  f1-score   support

         0.0       0.98      0.99      0.98      2000
         1.0       0.97      0.96      0.97      1000

    accuracy                           0.98      3000
   macro avg       0.98      0.98      0.98      3000
weighted avg       0.98      0.98      0.98      3000

** Performance of Head WITH_TWO**
Dataset: train
::: DATA DETAILS :::
- Number of Samples: 15000
              precision    recall  f1-score   support

         0.0       0.99      0.99     

# <font color = 'green'><b>OBSERVATION</b></font>

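> * At each iteration, the language-balanced sampler feeds the network a batch that mixes samples from all three tasks (about a third from each, in expectation).
> * The gradient for the shared feature extractor is the sum of the gradients flowing back from every head used in the batch, because the per-head losses are added into one `batch_loss` before `backward()`.
> * Because every update sees all tasks, no single task's training phase overwrites the others, which is how this setup avoids catastrophic forgetting. The evaluation shows that the model performs well across the tasks; for example, the WITH_ONE head reaches 0.98 accuracy on its test split.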
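The second observation can be checked directly. This minimal sketch uses a hypothetical one-layer trunk and two heads (not the DINOv2 model). It computes the shared layer's gradient once per head and once from the summed loss, the way `batch_loss` is built in the training loop, and confirms that the two agree.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

toy_trunk = nn.Linear(4, 3)                                             # hypothetical shared layer
toy_heads = nn.ModuleDict({"a": nn.Linear(3, 1), "b": nn.Linear(3, 1)})  # hypothetical task heads
x_a, y_a = torch.randn(5, 4), torch.randint(0, 2, (5, 1)).float()
x_b, y_b = torch.randn(6, 4), torch.randint(0, 2, (6, 1)).float()

def head_loss(name, x, y):
    """BCE of one head on top of the shared trunk."""
    probs = torch.sigmoid(toy_heads[name](toy_trunk(x)))
    return F.binary_cross_entropy(probs, y)

# Trunk gradient from each head on its own
per_head = []
for name, x, y in (("a", x_a, y_a), ("b", x_b, y_b)):
    toy_trunk.zero_grad()
    head_loss(name, x, y).backward()
    per_head.append(toy_trunk.weight.grad.clone())

# Trunk gradient from the summed loss, as in the training loop's batch_loss
toy_trunk.zero_grad()
(head_loss("a", x_a, y_a) + head_loss("b", x_b, y_b)).backward()
combined = toy_trunk.weight.grad.clone()

print(torch.allclose(combined, per_head[0] + per_head[1]))  # True
```

This follows from the linearity of differentiation: summing the head losses before `backward()` is equivalent to summing each head's gradient contribution to the shared weights.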

