CaSSLe : prepare the base ssl model then divide the data set into tasks and then train the model on each task while saving encoder (frozen) after each task
during training Train the current model (student) using two losses

training loop for task t(After learning task t-1):
  

*   save a copy of the models backbone and prediction/projection head (ft-1 ,gt-1)
*   alongside the cassle predictor ht-1 from previous task


*   your main backbone is the same from previous task

*  the cassle predictor ht is reinitialized

*     Apply two different augmentations(e.g.,v1,v2) to each image, just like standard SSL.

*     Take the representation from one of the augmented views of the current imagePass it through the current CaSSLe predictor network ht (p1)


*     Pass the same representation through the frozen previous task's network and its predictor (q1)


*     compute CaSSLe Distillation Loss typically a cross-entropy or MSE loss
      between q1 adn p1


*     compute the ssl loss for current task using z1 and z2
      combined loss Ltotal

*    The entire current mode (ft,ht,gt) is updated using backpropagation on
     Ltotal


*The main representation learning components (f and g) are carried over.
The CaSSLe predictor from the previous task is saved and frozen as a teacher.
A new CaSSLe predictor is re-initialized for the current task and trained.*

  
  
  











#**MAE**

In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from transformers import ViTMAEForPreTraining
from transformers import ViTMAEConfig
from transformers import AutoImageProcessor
from typing import List, Dict, Any
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.utils.data import ConcatDataset
from PIL import Image
import random
import os
import gc

mae finetune on class 1
freeze mae1 and train mae1 on class 2 using ssl loss from mae1 with two augmentations from class 2 and ssl loss from frozen mae1 and g(z) of one aumentatio
then save the trained mae1 as mae2
etc..
so we use the output of the decoder for the ssl loss and use the output of the encoder for the distillation loss

In [2]:

# Custom Dataset for Two Augmentations and Task Filtering
class CustomCifar100TaskDataset(Dataset):
    def __init__(self, cifar100_dataset, class_list: List[int], transform_v1, transform_v2, processor):
        self.data = []
        self.targets = []
        self.processor = processor # HuggingFace processor for pixel_values

        # Filter CIFAR-100 data based on the provided class_list
        for i in range(len(cifar100_dataset)):
            img, label = cifar100_dataset[i]
            if label in class_list:
                # Ensure img is a PIL Image if it's a numpy array, as torchvision transforms expect it
                if isinstance(img, np.ndarray):
                    img = Image.fromarray(img)
                self.data.append(img)
                self.targets.append(label)

        self.transform_v1 = transform_v1
        self.transform_v2 = transform_v2

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        label = self.targets[idx] # Label not used by CaSSLe's SSL, but kept for consistency

        # Apply two distinct augmentations to the same image
        img_v1 = self.transform_v1(img)
        img_v2 = self.transform_v2(img)

        # HuggingFace processor converts PIL Image to tensor, normalizes, and resizes if needed
        inputs_v1 = self.processor(images=img_v1, return_tensors="pt")['pixel_values'].squeeze(0)
        inputs_v2 = self.processor(images=img_v2, return_tensors="pt")['pixel_values'].squeeze(0)

        return inputs_v1, inputs_v2, label

#CaSSLe predictor g(z)
class CaSSLePredictor(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(CaSSLePredictor, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.LayerNorm(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

class MAECaSSLeModel(nn.Module):
    def __init__(self, mae_backbone: ViTMAEForPreTraining):
        super().__init__()
        self.mae_backbone = mae_backbone
        # The output dimension of the MAE encoder (hidden_size in config)
        self.features_for_h_dim = mae_backbone.config.hidden_size
        #Freezing the decoder
        for param in self.mae_backbone.decoder.parameters():
            param.requires_grad = False


    def forward(self, img_v1: torch.Tensor, img_v2: torch.Tensor) -> Dict[str, Any]:

        # Ensure images have a batch dimension for the MAE model
        # Using output_hidden_states=True to get encoder outputs
        outputs_v1 = self.mae_backbone(pixel_values=img_v1, return_dict=True, output_hidden_states=True)
        outputs_v2 = self.mae_backbone(pixel_values=img_v2, return_dict=True, output_hidden_states=True)

        # For MAE, the last_hidden_state is (batch_size, sequence_length, hidden_size).
        # We average over the sequence length (patches + CLS token) for a single feature vector per image.
        features_for_h_v1 = outputs_v1.hidden_states[-1].mean(dim=1)
        features_for_h_v2 = outputs_v2.hidden_states[-1].mean(dim=1)

        # MAE models often compute the reconstruction loss internally.
        ssl_loss_components_v1 = {'internal_mae_loss': outputs_v1.loss}
        ssl_loss_components_v2 = {'internal_mae_loss': outputs_v2.loss}

        return {
            'features_for_h_v1': features_for_h_v1,
            'features_for_h_v2': features_for_h_v2,
            'ssl_loss_components_v1': ssl_loss_components_v1,
            'ssl_loss_components_v2': ssl_loss_components_v2
        }

    def calculate_ssl_loss(self, ssl_loss_components: Dict[str, Any]) -> torch.Tensor:

        #Extracts and returns the MAE's internal reconstruction loss.
        return ssl_loss_components['internal_mae_loss']

    def get_learnable_params(self) -> List[dict]:
        #Returns the parameters of the entire MAE backbone
        return [{"params": self.mae_backbone.parameters()}]


#training process for each task, handling the current MAE model,
#the frozen previous MAE model, the g predictor, and the two loss terms (L_SSL and L_D).
class CaSSLeTrainer:
    def __init__(self, base_ssl_model: MAECaSSLeModel,
                 ca_predictor_hidden_dim: int,
                 learning_rate: float = 1e-4, lambda_cassle: float = 0.1, device: str = 'cpu'):

        self.base_ssl_model = base_ssl_model.to(device) # This is f_t
        self.lambda_cassle = lambda_cassle
        self.device = device

        # Input and output dimensions for CaSSLe Predictor (g) are the MAE encoder's feature dimension
        predictor_g_input_output_dim = self.base_ssl_model.features_for_h_dim

        # Initialize the current CaSSLe predictor (g from the paper)
        # It's newly initialized for each task
        self.g_current = CaSSLePredictor(
            predictor_g_input_output_dim,
            ca_predictor_hidden_dim,
            predictor_g_input_output_dim
        ).to(device)

        g_lr=0.001

        # Optimizer for ALL trainable parameters: current MAE model (f_t) AND predictor g
        params = self.base_ssl_model.get_learnable_params()  # This is a list of dicts
        params.append({"params": self.g_current.parameters(), "lr": g_lr})  # Add g_current params

        self.optimizer = torch.optim.AdamW(params, lr=learning_rate)


        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=6, gamma=0.1)

        # For L_D = L_SSL(g(z), z̄), the paper implies reusing the SSL loss.
        # However, for distilling features (z from f_t and z̄ from f_t-1), MSE is a common and effective choice.
        # Given MAE's SSL is reconstruction, directly applying it to feature vectors (g(z) and z̄) is not straightforward.
        # We'll use MSE for feature distillation as it's a standard practice for this type of knowledge transfer.


        # This will hold the frozen previous encoder (f_{t-1}^{frozen})
        self.f_frozen_teacher = None

    def set_previous_frozen_encoder(self, encoder_state_dict: Dict[str, Any]):

        # Re-instantiate the MAE model to load the full state dictionary
        self.f_frozen_teacher = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base').to(self.device)
        self.f_frozen_teacher.load_state_dict(encoder_state_dict)

        # set requires_grad to False for all parameters of the teacher model
        for param in self.f_frozen_teacher.parameters():
            param.requires_grad = False

        print(f"Frozen encoder (f_t-1) loaded and parameters frozen: {all(not p.requires_grad for p in self.f_frozen_teacher.parameters())}")


    def train_task(self, data_loader: torch.utils.data.DataLoader, epochs: int):

        self.base_ssl_model.train()
        self.g_current.train()
        scaler = torch.amp.GradScaler('cuda')  # Initialize GradScaler for AMP
        # Set f_frozen_teacher to eval mode to disable updates for teacher
        if self.f_frozen_teacher:
            self.f_frozen_teacher.eval()

        print(f"Distilling from frozen teacher encoder (f_t-1): {self.f_frozen_teacher is not None}")

        best_loss = float('inf')
        patience = 3
        patience_counter = 0
        min_delta = 1e-3

        for epoch in range(epochs):
            total_ssl_loss = 0
            total_cassle_loss = 0
            total_loss = 0

            for batch_idx, (img_v1, img_v2, _) in enumerate(data_loader): # _ for labels
                self.optimizer.zero_grad()
                img_v1 = img_v1.to(self.device)
                img_v2 = img_v2.to(self.device)

                # Mixed precision
                with torch.amp.autocast('cuda'):
                  # Forward Pass through the current trainable MAE model (f_t)
                  # This returns features (z_A, z_B) and components for L_SSL
                  ssl_output = self.base_ssl_model(img_v1, img_v2)

                  # Extract features for CaSSLe predictor and APPLY L2 NORMALIZATION HERE
                  features_for_h_v1_norm = F.normalize(ssl_output['features_for_h_v1'], dim=-1)
                  features_for_h_v2_norm = F.normalize(ssl_output['features_for_h_v2'], dim=-1)

                  # Calculate Base Self-Supervised Loss (L_SSL - MAE Reconstruction Loss)
                  # Sum MAE reconstruction loss for both views
                  loss_ssl = self.base_ssl_model.calculate_ssl_loss(ssl_output['ssl_loss_components_v1']) + \
                            self.base_ssl_model.calculate_ssl_loss(ssl_output['ssl_loss_components_v2'])

                  # Calculate CaSSLe Distillation Loss (L_D)
                  loss_cassle = torch.tensor(0.0).to(self.device) # Initialize to 0 for the first task

                  if self.f_frozen_teacher:
                      # Get features from the frozen previous encoder

                      with torch.no_grad():
                        # Get encoder outputs only (bypass the decoder)
                        outputs_v1 = self.f_frozen_teacher.vit(img_v1, output_hidden_states=True)
                        outputs_v1_features= outputs_v1.hidden_states[-1].mean(dim=1)
                        features_from_frozen_v1 = F.normalize(outputs_v1_features, dim=-1)

                        outputs_v2 = self.f_frozen_teacher.vit(img_v2, output_hidden_states=True)
                        outputs_v2_features= outputs_v2.hidden_states[-1].mean(dim=1)
                        features_from_frozen_v2 = F.normalize(outputs_v2_features, dim=-1)

                      # Student predictions from current trainable g
                      # g takes features from current f_t
                      student_pred_v1 = self.g_current(features_for_h_v1_norm)
                      student_pred_v2 = self.g_current(features_for_h_v2_norm)

                      # Teacher targets (from frozen f_t-1)
                      teacher_target_v1 = features_from_frozen_v1
                      teacher_target_v2 = features_from_frozen_v2

                      # Compute distillation loss using MSE
                      loss_cassle_v1 = 1 - F.cosine_similarity(student_pred_v1, teacher_target_v1, dim=-1).mean()
                      loss_cassle_v2 =  1 - F.cosine_similarity(student_pred_v2, teacher_target_v2, dim=-1).mean()
                      loss_cassle = (loss_cassle_v1 + loss_cassle_v2 )* self.lambda_cassle



                  # Total Loss and backpropagation
                  loss = loss_ssl +  loss_cassle

                scaler.scale(loss).backward()
                scaler.step(self.optimizer)
                scaler.update()

                total_ssl_loss += loss_ssl.item()
                total_cassle_loss += loss_cassle.item()
                total_loss += loss.item()

            #Early stopping
            avg_loss = total_loss / len(data_loader)
            if avg_loss < best_loss - min_delta:
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping triggered.")
                    break

            print(f"Epoch {epoch+1}/{epochs} - SSL Loss: {total_ssl_loss / len(data_loader):.4f}, "
                  f"CaSSLe Loss: {total_cassle_loss / len(data_loader):.4f}, "
                  f"Total Loss: {total_loss / len(data_loader):.4f}")

            torch.cuda.empty_cache()
            self.scheduler.step()

        # After training the task, return the state_dict of the current MAE model (f_t)
        return self.base_ssl_model.mae_backbone.state_dict()


In [None]:

# Evaluation Function
def evaluate_model(feature_extractor: torch.nn.Module,
                   all_seen_classes: List[int],
                   cifar100_train_full: datasets.CIFAR100, # Full CIFAR100 train dataset
                   cifar100_test_full: datasets.CIFAR100,  # Full CIFAR100 test dataset
                   processor, # HuggingFace processor
                   batch_size: int = 128,
                   linear_eval_epochs: int = 5,
                   device: torch.device = torch.device("cuda")):

    # Put feature extractor is in eval mode and frozen
    feature_extractor.eval()
    for param in feature_extractor.parameters():
        param.requires_grad = False

    # Prepare Data for Linear Classifier Training
    # Filter training data to include only classes seen so far
    train_linear_dataset = CustomCifar100TaskDataset(
        cifar100_train_full,
        all_seen_classes,
        None,
        None,
        processor
    )


    # For linear evaluation, we extract features from the original images, usually with a standard transform.
    linear_eval_transform = transforms.Compose([
        transforms.Resize(INPUT_SIZE),
        transforms.CenterCrop(INPUT_SIZE),
        # HuggingFace processor handles ToTensor and Normalize, so we don't need it here.
    ])
    # wrap the dataset for linear evaluation
    class LinearEvalDataset(Dataset):
        def __init__(self, original_dataset, class_list, transform, processor):
            self.data = []
            self.targets = []
            self.transform = transform
            self.processor = processor

            for i in range(len(original_dataset)):
                img, label = original_dataset[i]
                if label in class_list:
                    if isinstance(img, np.ndarray):
                        img = Image.fromarray(img)
                    self.data.append(img)
                    self.targets.append(label)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            img = self.data[idx]
            label = self.targets[idx]

            img = self.transform(img)
            inputs = self.processor(images=img, return_tensors="pt")['pixel_values'].squeeze(0)
            return inputs, label

    train_linear_dataset = LinearEvalDataset(cifar100_train_full, all_seen_classes, linear_eval_transform, processor)
    train_linear_loader = DataLoader(train_linear_dataset, batch_size=batch_size, shuffle=True,
                                     num_workers=os.cpu_count() // 2 if os.cpu_count() else 0, pin_memory=False)

    # Linear Classifier
    num_output_classes = len(all_seen_classes)
    linear_classifier = nn.Linear(feature_extractor.config.hidden_size, num_output_classes).to(device)

    # Change the original 100 labels to a continuous range starting from 0, only for the current classes
    label_to_contiguous_map = {label: i for i, label in enumerate(sorted(all_seen_classes))}

    #Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(linear_classifier.parameters(), lr=0.001)
    final_loss = 0.0
    num_batches = 0
    # Train Linear Classifier
    linear_classifier.train()
    for epoch in range(linear_eval_epochs):
        for img_batch, label_batch in train_linear_loader:
            img_batch = img_batch.to(device)
            # Map original labels to contiuous range
            label_batch = torch.tensor([label_to_contiguous_map[l.item()] for l in label_batch]).to(device)

            optimizer.zero_grad()
            with torch.no_grad(): # Ensure feature extractor remains frozen
              with torch.amp.autocast('cuda'): #Mixed precision
                features = feature_extractor(pixel_values=img_batch, return_dict=True, output_hidden_states=True).hidden_states[-1].mean(dim=1)

            with torch.amp.autocast('cuda'):
              outputs = linear_classifier(features)
              loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()
            final_loss += loss.item()
            num_batches += 1
        avg_loss = final_loss / num_batches
        print(f"Epoch {epoch+1}/{linear_eval_epochs}, Average Loss: {avg_loss:.4f}")


    # Evaluate Linear Classifier on Test Data
    linear_classifier.eval()
    total_correct = 0
    total_samples = 0

    # Prepare test data for classes seen so far
    test_linear_dataset = LinearEvalDataset(cifar100_test_full, all_seen_classes, linear_eval_transform, processor)
    test_linear_loader = DataLoader(test_linear_dataset, batch_size=batch_size, shuffle=False,
                                   num_workers=os.cpu_count() // 2 if os.cpu_count() else 0, pin_memory=False)

    with torch.no_grad():
        for img_batch, label_batch in test_linear_loader:
            img_batch = img_batch.to(device)
            label_batch = torch.tensor([label_to_contiguous_map[l.item()] for l in label_batch]).to(device)

            with torch.amp.autocast('cuda'):
              features = feature_extractor(pixel_values=img_batch, return_dict=True, output_hidden_states=True).hidden_states[-1].mean(dim=1)
              outputs = linear_classifier(features)

            _, predicted = torch.max(outputs.data, 1)
            total_samples += label_batch.size(0)
            total_correct += (predicted == label_batch).sum().item()

    accuracy = 100 * total_correct / total_samples
    print(f"Linear evaluation accuracy on {len(all_seen_classes)} classes: {accuracy:.2f}%")


    for param in feature_extractor.parameters():
        param.requires_grad = True # Re-enable gradients for next task's training
    feature_extractor.train()

    return accuracy

# Baseline for Forward Transfer (R_i)
def get_random_accuracy(num_classes_in_task: int,
                        cifar100_train_full: datasets.CIFAR100,
                        cifar100_test_full: datasets.CIFAR100,
                        processor,
                        target_class_list: List[int],
                        batch_size: int = 128,
                        linear_eval_epochs: int = 5,
                        device: torch.device = torch.device("cpu")):

    # Here, we'll train a linear classifier on top of a randomly initialized MAE encoder
    print(f"Calculating R_i for task with classes: {target_class_list}")

    # Randomly initialize its weights
    config = ViTMAEConfig()
    random_mae_backbone = ViTMAEForPreTraining(config).to(device) # Random model
    random_mae_backbone.eval()
    for param in random_mae_backbone.parameters():
        param.requires_grad = False

    # Re-use the evaluate_model function with the random backbone
    random_acc = evaluate_model(random_mae_backbone,
                                 target_class_list, # Only target classes for this task
                                 cifar100_train_full,
                                 cifar100_test_full,
                                 processor,
                                 batch_size,
                                 linear_eval_epochs,
                                 device)
    print(f"Random network accuracy on task: {random_acc:.2f}%")
    return random_acc

In [4]:
# --- Configuration ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_SIZE = 224 # MAE's default input size for ViT-base
BATCH_SIZE = 64
NUM_EPOCHS_PER_TASK = 15
LEARNING_RATE = 0.0001
LAMBDA_CASSLE = 2 # Weight for the CaSSLe distillation loss
NUM_CLASSES_PER_TASK = 10
NUM_TOTAL_CLASSES = 100
LINEAR_EVAL_EPOCHS = 5 # Number of epochs to train the linear classifier
LINEAR_EVAL_BATCH_SIZE = 128
# Initial Setup
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
torch.cuda.manual_seed(42)

torch.cuda.empty_cache()
gc.collect()

# Load pre-trained MAE-base model and processor
mae_backbone = ViTMAEForPreTraining.from_pretrained('facebook/vit-mae-base')
mae_backbone.to(DEVICE)
processor = AutoImageProcessor.from_pretrained('facebook/vit-mae-base')

# Define the two augmentation pipelines
transform_v1 = transforms.Compose([
    transforms.RandomResizedCrop(INPUT_SIZE, scale=(0.2, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
])

transform_v2 = transforms.Compose([
    transforms.RandomResizedCrop(INPUT_SIZE, scale=(0.2, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, 5.)), # Kernel size must be odd
])

# Load CIFAR-100 training dataset
cifar100_train_full = datasets.CIFAR100(root='./data', train=True, download=True)
cifar100_test_full = datasets.CIFAR100(root='./data', train=False, download=True)
# Prepare Task Data
# Divide CIFAR-100 into 10 tasks of 10 classes each
all_classes = list(range(NUM_TOTAL_CLASSES))
random.shuffle(all_classes) # Shuffle classes for a fairer split

task_class_splits = []
for i in range(0, NUM_TOTAL_CLASSES, NUM_CLASSES_PER_TASK):
    task_class_splits.append(all_classes[i:i + NUM_CLASSES_PER_TASK])

task_datasets = []
for i, class_list in enumerate(task_class_splits):
    print(f"Task {i+1} includes classes: {class_list}")
    task_dataset = CustomCifar100TaskDataset(
        cifar100_train_full,
        class_list,
        transform_v1,
        transform_v2,
        processor
    )
    task_datasets.append(task_dataset)

# Initialize CaSSLe Components
# base_ssl_model_instance is the current model (f_t) that gets updated across tasks
# It starts with the loaded pretrained MAE.
base_ssl_model_instance = MAECaSSLeModel(mae_backbone=mae_backbone)

# This will store the state_dict of the f_t-1 (previous encoder)
prev_encoder_state_dict = None

print("\n--- Starting CaSSLe Continual Training ---")
print(f"Running on device: {DEVICE}")

# Metrics storage A_j_k will be a list of lists (j tasks observed, k task evaluated on)
# A_j_k[j_idx][k_idx]
all_task_accuracies = [] # Stores accuracies A_{j,k}


random_accuracies_Ri = {} # {task_idx: R_i_value}

# Continual Training Loop
for task_id, current_task_dataset in enumerate(task_datasets):
    print(f"\n===== Training Task {task_id + 1}/{len(task_datasets)} =====")

    current_task_loader = DataLoader(current_task_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                      num_workers=os.cpu_count() // 2 if os.cpu_count() else 0, pin_memory=False)


    # Pass the SAME base_ssl_model_instance, as this is the model we continue updating.
    trainer = CaSSLeTrainer(
        base_ssl_model=base_ssl_model_instance,
        ca_predictor_hidden_dim=256, # Example hidden dim for 'g'
        learning_rate=LEARNING_RATE,
        lambda_cassle=LAMBDA_CASSLE,
        device=DEVICE
    )

    # If it's not the first task, set up the frozen teacher (f_t-1)
    if prev_encoder_state_dict:
        trainer.set_previous_frozen_encoder(prev_encoder_state_dict)

    # Train the current task and get the state of the trained f_t.
    # This will be used as the f_t-1 for the next task.
    prev_encoder_state_dict = trainer.train_task(current_task_loader, NUM_EPOCHS_PER_TASK)

    print(f"\n--- Evaluating after Task {task_id + 1} ---")

    # Determine all classes seen up to this task
    current_seen_classes = []
    for i in range(task_id + 1):
        current_seen_classes.extend(task_class_splits[i])
    current_seen_classes = sorted(list(set(current_seen_classes))) # Ensure unique and sorted

    accuracies_after_this_task = [] # Stores accuracies

    for eval_task_idx in range(task_id + 1):
        # Classes for the specific evaluation task (k)
        eval_task_classes = task_class_splits[eval_task_idx]
        print(f"  Evaluating on classes from Task {eval_task_idx+1}: {eval_task_classes}")

        # Perform linear evaluation on the classes specific to eval_task_idx
        acc_jk = evaluate_model(
            base_ssl_model_instance.mae_backbone, # The current encoder (f_t)
            eval_task_classes, # Evaluate on only classes of task k
            cifar100_train_full,
            cifar100_test_full,
            processor,
            LINEAR_EVAL_BATCH_SIZE,
            LINEAR_EVAL_EPOCHS,
            DEVICE
        )
        accuracies_after_this_task.append(acc_jk)

            # Calculate R_i for Forward Transfer only once per task
        if eval_task_idx not in random_accuracies_Ri:
              random_accuracies_Ri[eval_task_idx] = get_random_accuracy(
                NUM_CLASSES_PER_TASK,
                cifar100_train_full,
                cifar100_test_full,
                processor,
                eval_task_classes, # R_i is for task i
                LINEAR_EVAL_BATCH_SIZE,
                LINEAR_EVAL_EPOCHS,
                DEVICE
            )

    all_task_accuracies.append(accuracies_after_this_task)

    # Print current state (optional)
    print(f"\nCurrent A_j,k matrix after Task {task_id + 1}:")
    for j_idx, row in enumerate(all_task_accuracies):
        print(f"After Task {j_idx+1}: {row}")


print("\n CaSSLe Continual Training Process Completed")

# Final Metric Calculation
T = len(task_datasets)

# Average Accuracy (A)
final_accuracies_row = all_task_accuracies[T-1] if T > 0 else []
avg_accuracy = sum(final_accuracies_row) / T if T > 0 else 0
print(f"\nFinal Average Accuracy (A): {avg_accuracy:.2f}%")

# Forgetting (F)
forgetting = 0
if T > 1:
    for i in range(T - 1): # For each previous task i (0-indexed)
        max_acc_on_task_i = 0
        for t_idx in range(T): # Max over all times t when task i was evaluated
            if t_idx < len(all_task_accuracies) and i < len(all_task_accuracies[t_idx]):
                max_acc_on_task_i = max(max_acc_on_task_i, all_task_accuracies[t_idx][i])

        final_acc_on_task_i = all_task_accuracies[T-1][i] # A_T,i
        forgetting += (max_acc_on_task_i - final_acc_on_task_i)
    forgetting /= (T - 1)
print(f"Final Forgetting (F): {forgetting:.2f}%")

#Forward Transfer
forward_transfer = 0
if T > 1:
    count = 0
    for i in range(2, T+1):
        row_idx = i - 2
        col_idx = i - 1
        if row_idx < len(all_task_accuracies) and col_idx < len(all_task_accuracies[row_idx]):
            acc_i_minus_1_on_i = all_task_accuracies[row_idx][col_idx]
            r_i_value = random_accuracies_Ri[col_idx]
            forward_transfer += (acc_i_minus_1_on_i - r_i_value)
            count += 1
        else:
            # Skip if data missing for this pair
            continue
    if count > 0:
        forward_transfer /= count
    else:
        forward_transfer = 0
print(f"Final Forward Transfer (FT): {forward_transfer:.2f}%")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/676 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/448M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/217 [00:00<?, ?B/s]

Fast image processor class <class 'transformers.models.vit.image_processing_vit_fast.ViTImageProcessorFast'> is available for this model. Using slow image processor class. To use the fast image processor class set `use_fast=True`.
100%|██████████| 169M/169M [00:05<00:00, 31.1MB/s]


Task 1 includes classes: [42, 41, 91, 9, 65, 50, 1, 70, 15, 78]
Task 2 includes classes: [73, 10, 55, 56, 72, 45, 48, 92, 76, 37]
Task 3 includes classes: [30, 21, 32, 96, 80, 49, 83, 26, 87, 33]
Task 4 includes classes: [8, 47, 59, 63, 74, 44, 98, 52, 85, 12]
Task 5 includes classes: [36, 23, 39, 40, 18, 66, 61, 60, 7, 34]
Task 6 includes classes: [99, 46, 2, 51, 16, 38, 58, 68, 22, 62]
Task 7 includes classes: [24, 5, 6, 67, 82, 19, 79, 43, 90, 20]
Task 8 includes classes: [0, 95, 57, 93, 53, 89, 25, 71, 84, 77]
Task 9 includes classes: [64, 29, 27, 88, 97, 4, 54, 75, 11, 69]
Task 10 includes classes: [86, 13, 17, 28, 31, 35, 94, 3, 14, 81]

--- Starting CaSSLe Continual Training ---
Running on device: cuda

===== Training Task 1/10 =====
Distilling from frozen teacher encoder (f_t-1): False
Epoch 1/15 - SSL Loss: 0.0908, CaSSLe Loss: 0.0000, Total Loss: 0.0908
Epoch 2/15 - SSL Loss: 0.0846, CaSSLe Loss: 0.0000, Total Loss: 0.0846
Epoch 3/15 - SSL Loss: 0.0834, CaSSLe Loss: 0.0000, T

In [5]:
# Backward Transfer (BT)
backward_transfer = 0
count = 0

if T > 1:
    for new_task in range(1, T):  # starting from second task
        for old_task in range(new_task):  # all previously learned tasks
            # Accuracy on old_task before learning new_task
            acc_before = all_task_accuracies[new_task - 1][old_task] if (new_task - 1) < len(all_task_accuracies) and old_task < len(all_task_accuracies[new_task - 1]) else None

            # Accuracy on old_task after learning new_task
            acc_after = all_task_accuracies[new_task][old_task] if new_task < len(all_task_accuracies) and old_task < len(all_task_accuracies[new_task]) else None

            if acc_before is not None and acc_after is not None:
                backward_transfer += (acc_after - acc_before)
                count += 1
            else:
                print(f"Skipping BT for old_task {old_task+1}, new_task {new_task+1}: missing data")

backward_transfer = backward_transfer / count if count > 0 else 0
print(f"Final Backward Transfer (BT): {backward_transfer:.2f}%")

Final Backward Transfer (BT): -0.40%


In [6]:
# Approximate Forward Transfer (Lazy FT)
forward_transfer = 0
count = 0

if T > 1:
    for k in range(1, T):  # Task k (1 to T-1)
        # Use the accuracy when task k is first trained
        if k < len(all_task_accuracies) and k < len(all_task_accuracies[k]):
            first_acc = all_task_accuracies[k][k]  # Model learns task k
            random_baseline = random_accuracies_Ri[k]

            print(f"Task {k+1}: first_acc = {first_acc:.2f}, random = {random_baseline:.2f}")

            forward_transfer += (first_acc - random_baseline)
            count += 1
        else:
            print(f"Skipping Task {k+1}: missing all_task_accuracies[{k}][{k}]")

forward_transfer = forward_transfer / count if count > 0 else 0
print(f"Approx. Forward Transfer (FT): {forward_transfer:.2f}%")

Task 2: first_acc = 86.90, random = 37.90
Task 3: first_acc = 89.70, random = 41.50
Task 4: first_acc = 81.00, random = 36.90
Task 5: first_acc = 90.80, random = 37.60
Task 6: first_acc = 90.90, random = 40.80
Task 7: first_acc = 87.70, random = 44.40
Task 8: first_acc = 85.10, random = 45.80
Task 9: first_acc = 81.30, random = 30.90
Task 10: first_acc = 85.70, random = 35.80
Approx. Forward Transfer (FT): 47.50%
