In [50]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
from typing import Any, Dict, List, Tuple
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, Dataset
import os
import random
import numpy as np
from PIL import Image

# Configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
INPUT_SIZE = 32  # CIFAR-100 image size
BATCH_SIZE = 32  # 4 rotations × 32 = 128 processed images
NUM_EPOCHS_PER_TASK = 200
LEARNING_RATE = 0.001
LAMBDA_CASSLE = 0.8 # Weight for CaSSle loss
NUM_CLASSES_PER_TASK = 10
NUM_TOTAL_CLASSES = 100  # Total CIFAR-100 classes
NUM_ROT_CLASSES = 4  # 0°, 90°, 180°, 270°
LINEAR_EVAL_EPOCHS = 10
LINEAR_EVAL_BATCH_SIZE = 128

torch.cuda.empty_cache()

#Custom Dataset for RotNet
class RotNetCifar100TaskDataset(Dataset):
    def __init__(self, cifar100_dataset, class_list: List[int], base_transform):
        self.data = []
        self.targets = []

        # Filter CIFAR-100 data based on the provided class_list
        for i in range(len(cifar100_dataset)):
            img, label = cifar100_dataset[i]
            if label in class_list:
                if isinstance(img, np.ndarray):
                    img = Image.fromarray(img)
                self.data.append(img)
                self.targets.append(label)

        self.base_transform = base_transform # Applied before rotation

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx]
        original_label = self.targets[idx] # Original label, not used for SSL

        # 4 rotated versions and their labels (0, 1, 2, 3)
        # We need to perform rotation on the Tensor after base_transform
        rotated_imgs = []
        rotation_labels = []


        #Rotations
        for angle, rot_label in zip([0, 90, 180, 270], range(4)):
          rotated_img = transforms.functional.rotate(img, angle)
          rotated_img = self.base_transform(rotated_img)
          rotated_imgs.append(rotated_img)
          rotation_labels.append(torch.tensor(rot_label, dtype=torch.long))

        # Stack all 4 rotated images and labels for batching
        return torch.stack(rotated_imgs), torch.stack(rotation_labels), original_label


class CustomBackbone(nn.Module):
    def __init__(self, input_channels: int = 3, base_channels: int = 64, dropout_p: float = 0.1): # Added dropout_p parameter
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, base_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(),
            
            nn.Conv2d(base_channels, base_channels * 2, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels * 2),
            nn.ReLU(),
            
            nn.MaxPool2d(2),
            nn.Dropout(p=dropout_p),
            
            nn.Conv2d(base_channels * 2, base_channels * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(),
            
            nn.Conv2d(base_channels * 4, base_channels * 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(base_channels * 4),
            nn.ReLU(),
            
            nn.MaxPool2d(2),
            nn.Dropout(p=dropout_p),
            
            nn.Flatten()
        )

        # Dummy pass to infer flattened feature size
        dummy_input = torch.randn(1, input_channels, INPUT_SIZE, INPUT_SIZE)
        with torch.no_grad():
            features_dim = self.features(dummy_input).shape[1]
        self.features_dim = features_dim


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.features(x)

# --- RotNet Model using Custom Backbone (no changes needed here) ---
class RotNetModel(nn.Module):
    def __init__(self, num_rot_classes: int = 4, dropout_p: float = 0.1, input_channels: int = 3, base_channels: int = 64):
        super().__init__()

        self.backbone = CustomBackbone(input_channels, base_channels, dropout_p=dropout_p) # Pass dropout_p
        self.features_dim = self.backbone.features_dim

        self.classifier = nn.Sequential(
            nn.Linear(self.features_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_rot_classes)
        )

    def forward(self, x: torch.Tensor) -> Dict[str, Any]:
        features = self.backbone(x)
        features_flat = features.view(features.size(0), -1)
        logits = self.classifier(features_flat)
        return {
            'logits': logits,
            'features': features_flat
        }

    def calculate_ssl_loss(self, logits: torch.Tensor, rot_labels: torch.Tensor) -> torch.Tensor:
        # Use F.cross_entropy and ensure labels are long type
        return F.cross_entropy(logits, rot_labels.long())

In [51]:
class CaSSLePredictor(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        # input_dim and output_dim will both be the backbone's feature dimension
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

In [None]:
class CaSSleTrainer:
    def __init__(self, base_ssl_model: RotNetModel,
                 ca_predictor_hidden_dim: int,
                 learning_rate, lambda_cassle, device: str = 'cuda'):

        self.base_ssl_model = base_ssl_model.to(device) # This is f_t + rotation_head
        self.lambda_cassle = lambda_cassle
        self.device = device

        # Input and output dimensions for CaSSLe Predictor are the backbone's feature dimension
        predictor_input_output_dim = self.base_ssl_model.features_dim

        # Initialize the current CaSSLe predictor
        self.g_current = CaSSLePredictor(
            predictor_input_output_dim,
            ca_predictor_hidden_dim,
            predictor_input_output_dim
        ).to(device)

        # Optimizer for ALL trainable parameters: current RotNet model (f_t + head) AND predictor g
        self.optimizer = torch.optim.AdamW(
            list(self.base_ssl_model.parameters()) + list(self.g_current.parameters()),
            lr=learning_rate,
            weight_decay=0.01
        )

        


        
        # This will hold the frozen previous encoder (f_{t-1}^{frozen})
        self.f_frozen_teacher = None

    def set_previous_frozen_encoder(self, encoder_state_dict: Dict[str, Any]):
        """
        Loads the state of the previous encoder (f_{t-1}) and freezes it.
        This becomes the 'teacher' encoder for distillation.
        """
        # A new backbone for the frozen teacher
        self.f_frozen_teacher = CustomBackbone(input_channels=3, base_channels=64, dropout_p=0.1).to(self.device)
        self.f_frozen_teacher.load_state_dict(encoder_state_dict)


        # Freeze the parameters
        for param in self.f_frozen_teacher.parameters():
            param.requires_grad = False

        print(f"Frozen encoder (f_t-1) loaded and parameters frozen: {all(not p.requires_grad for p in self.f_frozen_teacher.parameters())}")


    def train_task(self, data_loader: torch.utils.data.DataLoader, epochs: int):
        self.base_ssl_model.train() # f_t + rotation_head is trainable
        self.g_current.train() # g is trainable

        # Set f_frozen_teacher to eval mode to disable dropout/batchnorm updates for teacher
        if self.f_frozen_teacher:
            self.f_frozen_teacher.eval()

        print(f"Distilling from frozen teacher encoder (f_t-1): {self.f_frozen_teacher is not None}")

        best_loss = float('inf')
        patience = 10
        patience_counter = 0
        min_delta = 0.001

        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.1)
        
        for epoch in range(epochs):
            total_ssl_loss = 0
            total_cassle_loss = 0
            total_loss = 0

            for batch_idx, (rotated_imgs, rotation_labels, _) in enumerate(data_loader): # _ for original_label
                self.optimizer.zero_grad()

                # rotated_imgs: (batch_size, 4, C, H, W)
                # rotation_labels: (batch_size, 4)

                # Flatten the batch and rotation dimensions for model input
                imgs_flat = rotated_imgs.view(-1, *rotated_imgs.shape[2:]).to(self.device)
                labels_flat = rotation_labels.view(-1).to(self.device)

                #  Forward Pass through the current trainable RotNet model (f_t + head)
                ssl_output = self.base_ssl_model(imgs_flat)

                # Calculate Base Self-Supervised Loss  (Cross-Entropy)
                loss_ssl = self.base_ssl_model.calculate_ssl_loss(ssl_output['logits'], labels_flat)

                # Calculate CaSSle Distillation Loss (L_D)
                loss_cassle = torch.tensor(0.0).to(self.device) # Initialize to 0 for the first task

                if self.f_frozen_teacher:
                    # Get features from the *frozen previous encoder* (f_{t-1}^{frozen})
                    with torch.no_grad():
                        features_from_frozen = self.f_frozen_teacher(imgs_flat)

                    # Student predictions from current trainable 'g'
                    # g takes features from current f_t
                    student_pred = self.g_current(ssl_output['features'])
                    #student_pred = F.normalize(student_pred, dim=-1)  # Normalize student predictions

                    # Teacher targets (from frozen f_t-1).
                    teacher_target = features_from_frozen

                    # Compute distillation loss using MSE
                    loss_cassle = 1-F.cosine_similarity(student_pred, teacher_target,dim=-1).mean()

                # --- Total Loss and Optimization ---
                loss = loss_ssl + self.lambda_cassle * loss_cassle

                loss.backward()
                self.optimizer.step()
                

                total_ssl_loss += loss_ssl.item()
                total_cassle_loss += loss_cassle.item()
                total_loss += loss.item()
                
                

            #Early stopping
            avg_loss = total_loss / len(data_loader)
            self.scheduler.step()
            if avg_loss < best_loss - min_delta:
                best_loss = avg_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping triggered.")
                    break

            print(f"Epoch {epoch+1}/{epochs} - SSL Loss: {total_ssl_loss / len(data_loader):.4f}, "
                  f"CaSSle Loss: {total_cassle_loss / len(data_loader):.4f}, "
                  f"Total Loss: {total_loss / len(data_loader):.4f}")

        # State_dict of the current RotNet model's backbone (f_t).
        return self.base_ssl_model.backbone.state_dict()

In [53]:

#Evaluation Function
def evaluate_model(feature_extractor: torch.nn.Module,
                   all_seen_classes: List[int],
                   cifar100_train_full: datasets.CIFAR100,
                   cifar100_test_full: datasets.CIFAR100,
                   base_transform, # Use the same base transform as RotNet
                   batch_size: int = 128,
                   linear_eval_epochs: int = 10,
                   device: torch.device = torch.device("cuda")):

    # freeze feature extractor
    feature_extractor.eval()
    for param in feature_extractor.parameters():
        param.requires_grad = False

    # Data for Linear Classifier Training
    class LinearEvalDataset(Dataset):
        def __init__(self, original_dataset, class_list, transform):
            self.data = []
            self.targets = []
            self.transform = transform

            for i in range(len(original_dataset)):
                img, label = original_dataset[i]
                if label in class_list:
                    if isinstance(img, np.ndarray):
                        img = Image.fromarray(img)
                    self.data.append(img)
                    self.targets.append(label)

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            img = self.data[idx]
            label = self.targets[idx]
            img = self.transform(img) # Apply base_transform for feature extraction
            return img, label

    train_linear_dataset = LinearEvalDataset(cifar100_train_full, all_seen_classes, base_transform)
    train_linear_loader = DataLoader(train_linear_dataset, batch_size=batch_size, shuffle=True,
                                     num_workers=os.cpu_count() // 2 if os.cpu_count() else 0, pin_memory=True)


    # Feature extractor output dim
    dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE).to(device)
    with torch.no_grad():
        features_dim = feature_extractor(dummy_input).view(dummy_input.size(0), -1).shape[1]

    num_output_classes = len(all_seen_classes)
    linear_classifier = nn.Linear(features_dim, num_output_classes).to(device)

    # Map CIFAR-100 original labels to a contiguous range for the classifier's output
    label_to_contiguous_map = {label: i for i, label in enumerate(sorted(all_seen_classes))}

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(linear_classifier.parameters(), lr=0.001)

    # 3. Train Linear Classifier
    linear_classifier.train()
    for epoch in range(linear_eval_epochs):
        for img_batch, label_batch in train_linear_loader:
            img_batch = img_batch.to(device)
            label_batch = torch.tensor([label_to_contiguous_map[l.item()] for l in label_batch]).to(device)

            optimizer.zero_grad()
            with torch.no_grad(): # Ensure feature extractor remains frozen
                features = feature_extractor(img_batch).view(img_batch.size(0), -1)

            outputs = linear_classifier(features)
            loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()

    # 4. Evaluate Linear Classifier on Test Data
    linear_classifier.eval()
    total_correct = 0
    total_samples = 0

    test_linear_dataset = LinearEvalDataset(cifar100_test_full, all_seen_classes, base_transform)
    test_linear_loader = DataLoader(test_linear_dataset, batch_size=batch_size, shuffle=False,
                                   num_workers=os.cpu_count() // 2 if os.cpu_count() else 0, pin_memory=True)

    with torch.no_grad():
        for img_batch, label_batch in test_linear_loader:
            img_batch = img_batch.to(device)
            label_batch = torch.tensor([label_to_contiguous_map[l.item()] for l in label_batch]).to(device)

            features = feature_extractor(img_batch).view(img_batch.size(0), -1)
            outputs = linear_classifier(features)

            _, predicted = torch.max(outputs.data, 1)
            total_samples += label_batch.size(0)
            total_correct += (predicted == label_batch).sum().item()

    accuracy = 100 * total_correct / total_samples
    print(f"Linear evaluation accuracy on {len(all_seen_classes)} classes: {accuracy:.2f}%")

    # Restore feature extractor to training mode if it's going to be used for more training
    for param in feature_extractor.parameters():
        param.requires_grad = True # Re-enable gradients for next task's training
    feature_extractor.train() # Set backbone back to train mode

    return accuracy

# Baseline for Forward Transfer (R_i) a randomly initialized linear classifier
def get_random_accuracy(num_classes_in_task: int,
                        cifar100_train_full: datasets.CIFAR100,
                        cifar100_test_full: datasets.CIFAR100,
                        base_transform,
                        target_class_list: List[int],
                        batch_size: int = 128,
                        linear_eval_epochs: int = 10,
                        device: torch.device = torch.device("cpu")):

    print(f"Calculating R_i for task with classes: {target_class_list}")
    # Create a randomly initialized ResNet18 backbone
    random_resnet18_backbone = models.resnet18(weights=None) # No pretrained weights
    random_resnet18_backbone.fc = nn.Identity() # Remove the default classification head
    random_resnet18_backbone.to(device)


    # Then evaluate using the same function
    random_acc = evaluate_model(random_resnet18_backbone,
                                 target_class_list,
                                 cifar100_train_full,
                                 cifar100_test_full,
                                 base_transform,
                                 batch_size,
                                 linear_eval_epochs,
                                 device)
    print(f"Random network accuracy on task: {random_acc:.2f}%")
    return random_acc



def set_seed(seed: int = 42):
    import random, os
    import numpy as np
    import torch
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
    print(f"Random seed set to {seed}")

set_seed(42)


# Transforms
base_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),

])

# Load Dataset
cifar100_train_full = datasets.CIFAR100(root='./data', train=True, download=True)
cifar100_test_full = datasets.CIFAR100(root='./data', train=False, download=True)

# Task Split
all_classes_shuffled = list(range(NUM_TOTAL_CLASSES))
random.shuffle(all_classes_shuffled)

task_class_splits = [all_classes_shuffled[i:i + NUM_CLASSES_PER_TASK] for i in range(0, NUM_TOTAL_CLASSES, NUM_CLASSES_PER_TASK)]

task_datasets = []
for i, class_list in enumerate(task_class_splits):
    print(f"Task {i+1} includes classes: {class_list}")
    task_dataset = RotNetCifar100TaskDataset(cifar100_train_full, class_list, base_transform)
    task_datasets.append(task_dataset)

# Init Model
resnet18_backbone = models.resnet18(weights=None)
resnet18_backbone.fc = nn.Identity()

base_ssl_model_instance = RotNetModel( num_rot_classes=NUM_ROT_CLASSES).to(DEVICE)
prev_encoder_state_dict = None

# Training Loop
all_task_accuracies = []
random_accuracies_Ri = {}

for task_id, current_task_dataset in enumerate(task_datasets):
    print(f"\n===== Training Task {task_id + 1}/{len(task_datasets)} =====")

    current_task_loader = DataLoader(
        current_task_dataset, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=os.cpu_count() // 2 if os.cpu_count() else 0, pin_memory=True
    )

    trainer = CaSSleTrainer(
        base_ssl_model=base_ssl_model_instance,
        ca_predictor_hidden_dim=256,
        learning_rate=LEARNING_RATE,
        lambda_cassle=LAMBDA_CASSLE,
        device=DEVICE
    )

    #Save previouse encoder
    if prev_encoder_state_dict:
        trainer.set_previous_frozen_encoder(prev_encoder_state_dict)
    #Train encoder on current task and save its state_dict
    prev_encoder_state_dict = trainer.train_task(current_task_loader, NUM_EPOCHS_PER_TASK)

    print(f"\n--- Evaluating after Task {task_id + 1} ---")
    current_seen_classes = sorted(set().union(*task_class_splits[:task_id + 1]))
    accuracies_after_this_task = []

    #Evaluate model for each task seen so far
    for eval_task_idx in range(task_id + 1):
        eval_task_classes = task_class_splits[eval_task_idx]
        print(f"  Evaluating on classes from Task {eval_task_idx+1}: {eval_task_classes}")

        acc_jk = evaluate_model(
            base_ssl_model_instance.backbone,
            eval_task_classes,
            cifar100_train_full,
            cifar100_test_full,
            base_transform,
            LINEAR_EVAL_BATCH_SIZE,
            LINEAR_EVAL_EPOCHS,
            DEVICE
        )
        accuracies_after_this_task.append(acc_jk)

        #calculate random baseline accuracy fo reach task
        if eval_task_idx not in random_accuracies_Ri:
            random_accuracies_Ri[eval_task_idx] = get_random_accuracy(
                NUM_CLASSES_PER_TASK,
                cifar100_train_full,
                cifar100_test_full,
                base_transform,
                eval_task_classes,
                LINEAR_EVAL_BATCH_SIZE,
                LINEAR_EVAL_EPOCHS,
                DEVICE
            )

    all_task_accuracies.append(accuracies_after_this_task)


T = len(task_datasets)

# Average Accuracy
final_accuracies_row = all_task_accuracies[T-1]
avg_accuracy = sum(final_accuracies_row) / T
print(f"\nFinal Average Accuracy (A): {avg_accuracy:.2f}%")

# Forgetting
forgetting = 0
if T > 1:
    for i in range(T - 1):
        max_acc = max(all_task_accuracies[t][i] for t in range(T) if i < len(all_task_accuracies[t]))
        final_acc = all_task_accuracies[T-1][i]
        forgetting += (max_acc - final_acc)
    forgetting /= (T - 1)
print(f"Final Forgetting (F): {forgetting:.2f}%")

# Backward Transfer (BT)
backward_transfer = 0
count = 0

if T > 1:
    for new_task in range(1, T):
        for old_task in range(new_task):
            if old_task < len(all_task_accuracies[new_task - 1]) and old_task < len(all_task_accuracies[new_task]):
                acc_before = all_task_accuracies[new_task - 1][old_task]
                acc_after = all_task_accuracies[new_task][old_task]
                backward_transfer += (acc_after - acc_before)
                count += 1
            else:
                print(f"Skipping BT for old_task {old_task+1}, new_task {new_task+1}: missing data")

    backward_transfer /= count if count > 0 else 1
else:
    backward_transfer = 0

print(f"Final Backward Transfer (BT): {backward_transfer:.2f}%")

Random seed set to 42
Task 1 includes classes: [42, 41, 91, 9, 65, 50, 1, 70, 15, 78]
Task 2 includes classes: [73, 10, 55, 56, 72, 45, 48, 92, 76, 37]
Task 3 includes classes: [30, 21, 32, 96, 80, 49, 83, 26, 87, 33]
Task 4 includes classes: [8, 47, 59, 63, 74, 44, 98, 52, 85, 12]
Task 5 includes classes: [36, 23, 39, 40, 18, 66, 61, 60, 7, 34]
Task 6 includes classes: [99, 46, 2, 51, 16, 38, 58, 68, 22, 62]
Task 7 includes classes: [24, 5, 6, 67, 82, 19, 79, 43, 90, 20]
Task 8 includes classes: [0, 95, 57, 93, 53, 89, 25, 71, 84, 77]
Task 9 includes classes: [64, 29, 27, 88, 97, 4, 54, 75, 11, 69]
Task 10 includes classes: [86, 13, 17, 28, 31, 35, 94, 3, 14, 81]

===== Training Task 1/10 =====
Distilling from frozen teacher encoder (f_t-1): False
Epoch 1/200 - SSL Loss: 1.6184, CaSSle Loss: 0.0000, Total Loss: 1.6184
Epoch 2/200 - SSL Loss: 1.2912, CaSSle Loss: 0.0000, Total Loss: 1.2912
Epoch 3/200 - SSL Loss: 1.2624, CaSSle Loss: 0.0000, Total Loss: 1.2624
Epoch 4/200 - SSL Loss: 1