In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random, os
from PIL import Image
from typing import List, Dict, Any, Tuple
import timm
from einops import rearrange
import pandas as pd

# ========= CONFIG =========
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
NUM_EPOCHS_PER_TASK = 5
LEARNING_RATE = 0.001
LAMBDA_CASSLE = 0.8
NUM_CLASSES_PER_TASK = 20  # 5 tasks × 20 classes
NUM_TOTAL_CLASSES = 100
NUM_ROT_CLASSES = 4
LINEAR_EVAL_EPOCHS = 10
LINEAR_EVAL_BATCH_SIZE = 128

def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    os.environ['PYTHONHASHSEED'] = str(seed)
set_seed(42)
torch.cuda.empty_cache()

# ========= DATA TRANSFORMS =========
cnn_transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.7, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
vit_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ========= DATASETS =========
class RotNetCifar100TaskDataset(Dataset):
    def __init__(self, cifar100_dataset, class_list, base_transform):
        self.data, self.targets = [], []
        for i in range(len(cifar100_dataset)):
            img, label = cifar100_dataset[i]
            if label in class_list:
                if isinstance(img, np.ndarray):
                    img = Image.fromarray(img)
                self.data.append(img)
                self.targets.append(label)
        self.base_transform = base_transform
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        img = self.data[idx]
        rotated_imgs, rotation_labels = [], []
        for angle, rot_label in zip([0, 90, 180, 270], range(4)):
            rotated_img = transforms.functional.rotate(img, angle)
            rotated_img = self.base_transform(rotated_img)
            rotated_imgs.append(rotated_img)
            rotation_labels.append(torch.tensor(rot_label, dtype=torch.long))
        return torch.stack(rotated_imgs), torch.stack(rotation_labels), self.targets[idx]

class MAECifar100TaskDataset(Dataset):
    def __init__(self, cifar100_dataset, class_list, base_transform, mask_ratio=0.75):
        self.data, self.targets = [], []
        self.base_transform = base_transform
        self.mask_ratio = mask_ratio
        for i in range(len(cifar100_dataset)):
            img, label = cifar100_dataset[i]
            if label in class_list:
                if isinstance(img, np.ndarray):
                    img = Image.fromarray(img)
                self.data.append(img)
                self.targets.append(label)
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        img = self.base_transform(self.data[idx])
        B, H, W = img.shape
        patch_size = 16  # ViT-B/16
        num_patches = (H // patch_size) * (W // patch_size)
        mask = torch.randperm(num_patches) < int(self.mask_ratio * num_patches)
        return img, mask, self.targets[idx]

class PairCifar100TaskDataset(Dataset):
    def __init__(self, cifar100_dataset, class_list, base_transform, pair_transform=None):
        self.data, self.targets = [], []
        self.base_transform = base_transform
        self.pair_transform = pair_transform if pair_transform else base_transform
        for i in range(len(cifar100_dataset)):
            img, label = cifar100_dataset[i]
            if label in class_list:
                if isinstance(img, np.ndarray):
                    img = Image.fromarray(img)
                self.data.append(img)
                self.targets.append(label)
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        img = self.data[idx]
        return self.base_transform(img), self.pair_transform(img), self.targets[idx]

class DenoisingCifar100TaskDataset(Dataset):
    def __init__(self, cifar100_dataset, class_list, base_transform, noise_std=0.2):
        self.data, self.targets = [], []
        self.base_transform = base_transform
        self.noise_std = noise_std
        for i in range(len(cifar100_dataset)):
            img, label = cifar100_dataset[i]
            if label in class_list:
                if isinstance(img, np.ndarray):
                    img = Image.fromarray(img)
                self.data.append(img)
                self.targets.append(label)
    def __len__(self): return len(self.data)
    def __getitem__(self, idx):
        img = self.base_transform(self.data[idx])
        noisy_img = img + torch.randn_like(img) * self.noise_std
        noisy_img = torch.clamp(noisy_img, 0., 1.)
        return noisy_img, img, self.targets[idx]

# ========= BACKBONES =========
class CustomCNNBackbone(nn.Module):
    def __init__(self, input_channels=3, base_channels=64, dropout_p=0.1):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(input_channels, base_channels, 3, 1, 1),
            nn.BatchNorm2d(base_channels), nn.ReLU(),
            nn.Conv2d(base_channels, base_channels*2, 3, 1, 1),
            nn.BatchNorm2d(base_channels*2), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(p=dropout_p),
            nn.Conv2d(base_channels*2, base_channels*4, 3, 1, 1),
            nn.BatchNorm2d(base_channels*4), nn.ReLU(),
            nn.Conv2d(base_channels*4, base_channels*4, 3, 1, 1),
            nn.BatchNorm2d(base_channels*4), nn.ReLU(),
            nn.MaxPool2d(2), nn.Dropout(p=dropout_p),
            nn.AdaptiveAvgPool2d(1), nn.Flatten()
        )
        dummy_input = torch.randn(1, input_channels, 32, 32)
        with torch.no_grad():
            features_dim = self.features(dummy_input).shape[1]
        self.features_dim = features_dim
    def forward(self, x): return self.features(x)

def get_backbone(method):
    if method in ["dino", "mae"]:
        vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
        vit.features_dim = vit.embed_dim
        return vit
    else:
        return CustomCNNBackbone()

# ========= SSL MODELS =========
class RotNetModel(nn.Module):
    def __init__(self, backbone, num_rot_classes=4):
        super().__init__()
        self.backbone = backbone
        self.features_dim = backbone.features_dim
        self.classifier = nn.Sequential(
            nn.Linear(self.features_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_rot_classes)
        )
    def forward(self, x):
        features = self.backbone(x)
        logits = self.classifier(features)
        return {'logits': logits, 'features': features}
    def calculate_ssl_loss(self, logits, rot_labels):
        return F.cross_entropy(logits, rot_labels.long())

class SimSiamModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.features_dim = backbone.features_dim
        self.projector = nn.Sequential(
            nn.Linear(self.features_dim, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048),
            nn.ReLU(),
            nn.Linear(2048, 2048),
            nn.BatchNorm1d(2048)
        )
        self.predictor = nn.Sequential(
            nn.Linear(2048, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 2048)
        )
    def forward(self, x):
        features = self.backbone(x)
        z = self.projector(features)
        p = self.predictor(z)
        return {'features': features, 'projected': z, 'predicted': p}
    def calculate_ssl_loss(self, p1, z2, p2, z1):
        return -(F.cosine_similarity(p1, z2.detach(), dim=-1).mean() +
                 F.cosine_similarity(p2, z1.detach(), dim=-1).mean()) * 0.5

class DinoHead(nn.Module):
    def __init__(self, in_dim, out_dim=2048, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [nn.Linear(in_dim, hidden_dim), nn.GELU()]
            for _ in range(nlayers - 2):
                layers += [nn.Linear(hidden_dim, hidden_dim), nn.GELU()]
            layers += [nn.Linear(hidden_dim, bottleneck_dim)]
            self.mlp = nn.Sequential(*layers)
        self.last_layer = nn.utils.weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
    def forward(self, x):
        x = self.mlp(x)
        x = F.normalize(x, dim=-1, p=2)
        return self.last_layer(x), x

class DINOModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        embed_dim = backbone.features_dim
        self.head = DinoHead(embed_dim)
        self.teacher = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=0)
        self.teacher.eval()
        for p in self.teacher.parameters(): p.requires_grad = False
    def forward(self, x1, x2):
        y1 = self.backbone(x1)
        y2 = self.backbone(x2)
        out1, _ = self.head(y1)
        out2, _ = self.head(y2)
        with torch.no_grad():
            t1 = self.teacher(x1)
            t2 = self.teacher(x2)
            t1_out, _ = self.head(t1)
            t2_out, _ = self.head(t2)
        return {"student": (out1, out2), "teacher": (t1_out, t2_out)}
    def calculate_ssl_loss(self, student_out, teacher_out):
        loss = 0.
        for s, t in zip(student_out, teacher_out):
            t = t.detach()
            loss += - (F.softmax(t, dim=-1) * F.log_softmax(s, dim=-1)).sum(dim=-1).mean()
        return loss / len(student_out)

class PatchMAEEncoder(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.features_dim = backbone.features_dim
    def forward(self, x, mask):
        features = self.backbone.patch_embed(x)
        features[0][mask] = 0
        pooled = features.mean(dim=1)
        return pooled

class MAEDecoder(nn.Module):
    def __init__(self, embed_dim=768, patch_dim=768):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.GELU(), nn.Linear(embed_dim, patch_dim)
        )
    def forward(self, x): return self.decoder(x)

class MAEModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.encoder = PatchMAEEncoder(backbone)
        self.decoder = MAEDecoder(embed_dim=backbone.features_dim, patch_dim=backbone.features_dim)
        self.features_dim = backbone.features_dim
    def forward(self, x, mask):
        encoded = self.encoder(x, mask)
        recon_patches = self.decoder(encoded)
        return {"recon_patches": recon_patches}
    def calculate_ssl_loss(self, x, recon_patches, mask):
        return F.mse_loss(recon_patches, x.view(x.size(0), -1))

class SdAEModel(nn.Module):
    def __init__(self, backbone):
        super().__init__()
        self.backbone = backbone
        self.features_dim = backbone.features_dim
        self.decoder = nn.Linear(self.features_dim, 3 * 32 * 32)
    def forward(self, x):
        features = self.backbone(x)
        recon = self.decoder(features)
        return {"features": features, "recon": recon}
    def calculate_ssl_loss(self, x, recon):
        return F.mse_loss(recon, x.view(x.size(0), -1))

# ========= SSL MODEL WRAPPER =========
class SSLModelWrapper(nn.Module):
    def __init__(self, method="RotNet", backbone=None, **kwargs):
        super().__init__()
        self.method = method.lower()
        if backbone is None:
            raise ValueError("You must provide a backbone!")
        if self.method == "rotnet":
            self.model = RotNetModel(backbone=backbone, num_rot_classes=kwargs.get('num_rot_classes', 4))
        elif self.method == "simsiam":
            self.model = SimSiamModel(backbone=backbone)
        elif self.method == "dino":
            self.model = DINOModel(backbone=backbone)
        elif self.method == "mae":
            self.model = MAEModel(backbone=backbone)
        elif self.method == "sdae":
            self.model = SdAEModel(backbone=backbone)
        else:
            raise NotImplementedError(f"SSL method '{method}' not implemented.")
        self.ssl_loss_fn = self.model.calculate_ssl_loss
    def forward(self, *args, **kwargs): return self.model(*args, **kwargs)
    def get_representation(self, x): return self.model.backbone(x)
    def calculate_ssl_loss(self, *args, **kwargs): return self.ssl_loss_fn(*args, **kwargs)
    @property
    def backbone(self): return self.model.backbone
    @property
    def features_dim(self): return self.model.features_dim

# ========= CaSSLe Predictor/Trainer =========
class CaSSLePredictor(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, output_dim)
        )
    def forward(self, x): return self.net(x)

class CaSSleTrainer:
    def __init__(self, base_ssl_model, ca_predictor_hidden_dim, learning_rate, lambda_cassle, device='cuda'):
        self.base_ssl_model = base_ssl_model.to(device)
        self.lambda_cassle = lambda_cassle
        self.device = device
        self.ca_predictor_hidden_dim = ca_predictor_hidden_dim
        self.learning_rate = learning_rate
        self.f_frozen_teacher = None
        self.g_current = None
        self.optimizer = None
    def set_previous_frozen_encoder(self, encoder_state):
        backbone_class, state_dict = encoder_state
        self.f_frozen_teacher = backbone_class()
        self.f_frozen_teacher.load_state_dict(state_dict)
        self.f_frozen_teacher.to(self.device)
        for param in self.f_frozen_teacher.parameters(): param.requires_grad = False
    def train_task(self, data_loader, epochs, ssl_method):
        first_batch = next(iter(data_loader))
        with torch.no_grad():
            if ssl_method == "rotnet":
                rotated_imgs, rotation_labels, _ = first_batch
                imgs_flat = rotated_imgs.view(-1, *rotated_imgs.shape[2:]).to(self.device)
                ssl_output = self.base_ssl_model(imgs_flat)
                features_for_distill = ssl_output['features']
            elif ssl_method == "simsiam":
                img1, img2, _ = first_batch
                img1 = img1.to(self.device)
                out1 = self.base_ssl_model(img1)
                features_for_distill = out1['features']
            elif ssl_method == "dino":
                img1, img2, _ = first_batch
                img1 = img1.to(self.device)
                dino_out = self.base_ssl_model(img1, img1)
                features_for_distill = dino_out["student"][0]
            elif ssl_method == "mae":
                imgs, masks, _ = first_batch
                imgs = imgs.to(self.device)
                mae_out = self.base_ssl_model(imgs, masks)
                features_for_distill = mae_out['recon_patches'].view(imgs.size(0), -1)
            elif ssl_method == "sdae":
                noisy_imgs, clean_imgs, _ = first_batch
                noisy_imgs = noisy_imgs.to(self.device)
                sdae_out = self.base_ssl_model(noisy_imgs)
                features_for_distill = sdae_out['features']
            else: raise NotImplementedError
            feature_dim = features_for_distill.shape[1]
        self.g_current = CaSSLePredictor(feature_dim, self.ca_predictor_hidden_dim, feature_dim).to(self.device)
        self.optimizer = torch.optim.AdamW(
            list(self.base_ssl_model.parameters()) + list(self.g_current.parameters()),
            lr=self.learning_rate, weight_decay=0.01
        )
        self.base_ssl_model.train(); self.g_current.train()
        if self.f_frozen_teacher: self.f_frozen_teacher.eval()
        best_loss, patience, patience_counter, min_delta = float('inf'), 10, 0, 0.001
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.optimizer, step_size=100, gamma=0.1)
        for epoch in range(epochs):
            total_ssl_loss, total_cassle_loss, total_loss = 0, 0, 0
            for batch in data_loader:
                self.optimizer.zero_grad()
                if ssl_method == "rotnet":
                    rotated_imgs, rotation_labels, _ = batch
                    imgs_flat = rotated_imgs.view(-1, *rotated_imgs.shape[2:]).to(self.device)
                    labels_flat = rotation_labels.view(-1).to(self.device)
                    ssl_output = self.base_ssl_model(imgs_flat)
                    loss_ssl = self.base_ssl_model.calculate_ssl_loss(ssl_output['logits'], labels_flat)
                    features_for_distill = ssl_output['features']
                elif ssl_method == "simsiam":
                    img1, img2, _ = batch
                    img1, img2 = img1.to(self.device), img2.to(self.device)
                    out1 = self.base_ssl_model(img1)
                    out2 = self.base_ssl_model(img2)
                    loss_ssl = self.base_ssl_model.calculate_ssl_loss(
                        out1['predicted'], out2['projected'],
                        out2['predicted'], out1['projected']
                    )
                    features_for_distill = out1['features']
                elif ssl_method == "dino":
                    img1, img2, _ = batch
                    img1, img2 = img1.to(self.device), img2.to(self.device)
                    dino_out = self.base_ssl_model(img1, img2)
                    loss_ssl = self.base_ssl_model.calculate_ssl_loss(
                        dino_out["student"], dino_out["teacher"]
                    )
                    features_for_distill = dino_out["student"][0]
                elif ssl_method == "mae":
                    imgs, masks, _ = batch
                    imgs, masks = imgs.to(self.device), masks.to(self.device)
                    mae_out = self.base_ssl_model(imgs, masks)
                    loss_ssl = self.base_ssl_model.calculate_ssl_loss(imgs, mae_out['recon_patches'], masks)
                    features_for_distill = mae_out['recon_patches'].view(imgs.size(0), -1)
                elif ssl_method == "sdae":
                    noisy_imgs, clean_imgs, _ = batch
                    noisy_imgs = noisy_imgs.to(self.device)
                    clean_imgs = clean_imgs.to(self.device)
                    sdae_out = self.base_ssl_model(noisy_imgs)
                    loss_ssl = self.base_ssl_model.calculate_ssl_loss(clean_imgs, sdae_out['recon'])
                    features_for_distill = sdae_out['features']
                else: raise NotImplementedError
                loss_cassle = torch.tensor(0.0).to(self.device)
                if self.f_frozen_teacher:
                    with torch.no_grad():
                        if ssl_method == "rotnet": teacher_input = imgs_flat
                        elif ssl_method in ["simsiam", "dino"]: teacher_input = img1
                        elif ssl_method == "mae": teacher_input = imgs
                        elif ssl_method == "sdae": teacher_input = noisy_imgs
                        else: raise NotImplementedError
                        frozen_features = self.f_frozen_teacher(teacher_input)
                        frozen_features = frozen_features.view(teacher_input.size(0), -1)
                    student_features = features_for_distill
                    student_pred = self.g_current(student_features)
                    loss_cassle = 1 - F.cosine_similarity(student_pred, frozen_features, dim=-1).mean()
                loss = loss_ssl + self.lambda_cassle * loss_cassle
                loss.backward()
                self.optimizer.step()
                total_ssl_loss += loss_ssl.item()
                total_cassle_loss += loss_cassle.item()
                total_loss += loss.item()
            avg_loss = total_loss / len(data_loader)
            self.scheduler.step()
            if avg_loss < best_loss - min_delta: best_loss, patience_counter = avg_loss, 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print("Early stopping triggered."); break
            print(f"Epoch {epoch+1}/{epochs} - SSL Loss: {total_ssl_loss / len(data_loader):.4f}, "
                  f"CaSSle Loss: {total_cassle_loss / len(data_loader):.4f}, "
                  f"Total Loss: {total_loss / len(data_loader):.4f}")
        return (type(self.base_ssl_model.backbone), self.base_ssl_model.backbone.state_dict())

# ========= Utility: Linear Eval, Random Accuracy =========
def evaluate_model(feature_extractor, all_seen_classes, cifar100_train_full, cifar100_test_full, base_transform, batch_size=128, linear_eval_epochs=10, device=DEVICE):
    feature_extractor.eval()
    for param in feature_extractor.parameters(): param.requires_grad = False
    class LinearEvalDataset(Dataset):
        def __init__(self, original_dataset, class_list, transform):
            self.data, self.targets = [], []
            self.transform = transform
            for i in range(len(original_dataset)):
                img, label = original_dataset[i]
                if label in class_list:
                    if isinstance(img, np.ndarray):
                        img = Image.fromarray(img)
                    self.data.append(img)
                    self.targets.append(label)
        def __len__(self): return len(self.data)
        def __getitem__(self, idx):
            img = self.data[idx]
            label = self.targets[idx]
            img = self.transform(img)
            return img, label
    train_linear_dataset = LinearEvalDataset(cifar100_train_full, all_seen_classes, base_transform)
    train_linear_loader = DataLoader(train_linear_dataset, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=True)
    dummy_input = torch.randn(1, 3, train_linear_loader.dataset[0][0].shape[1], train_linear_loader.dataset[0][0].shape[2]).to(device)
    with torch.no_grad():
        features_dim = feature_extractor(dummy_input).view(dummy_input.size(0), -1).shape[1]
    num_output_classes = len(all_seen_classes)
    linear_classifier = nn.Linear(features_dim, num_output_classes).to(device)
    label_to_contiguous_map = {label: i for i, label in enumerate(sorted(all_seen_classes))}
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(linear_classifier.parameters(), lr=0.001)
    linear_classifier.train()
    for epoch in range(linear_eval_epochs):
        for img_batch, label_batch in train_linear_loader:
            img_batch = img_batch.to(device)
            label_batch = torch.tensor([label_to_contiguous_map[l.item()] for l in label_batch]).to(device)
            optimizer.zero_grad()
            with torch.no_grad():
                features = feature_extractor(img_batch).view(img_batch.size(0), -1)
            outputs = linear_classifier(features)
            loss = criterion(outputs, label_batch)
            loss.backward()
            optimizer.step()
    linear_classifier.eval()
    total_correct, total_samples = 0, 0
    test_linear_dataset = LinearEvalDataset(cifar100_test_full, all_seen_classes, base_transform)
    test_linear_loader = DataLoader(test_linear_dataset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    with torch.no_grad():
        for img_batch, label_batch in test_linear_loader:
            img_batch = img_batch.to(device)
            label_batch = torch.tensor([label_to_contiguous_map[l.item()] for l in label_batch]).to(device)
            features = feature_extractor(img_batch).view(img_batch.size(0), -1)
            outputs = linear_classifier(features)
            _, predicted = torch.max(outputs.data, 1)
            total_samples += label_batch.size(0)
            total_correct += (predicted == label_batch).sum().item()
    accuracy = 100 * total_correct / total_samples
    for param in feature_extractor.parameters(): param.requires_grad = True
    feature_extractor.train()
    return accuracy

def get_random_accuracy(num_classes_in_task, cifar100_train_full, cifar100_test_full, base_transform, target_class_list, batch_size=128, linear_eval_epochs=10, device=DEVICE):
    random_backbone = CustomCNNBackbone()
    random_backbone.to(device)
    return evaluate_model(random_backbone, target_class_list, cifar100_train_full, cifar100_test_full, base_transform, batch_size, linear_eval_epochs, device)

# ========== DATA LOADING ==========
cifar100_train_full = datasets.CIFAR100(root='./data', train=True, download=True)
cifar100_test_full = datasets.CIFAR100(root='./data', train=False, download=True)
all_classes_shuffled = list(range(NUM_TOTAL_CLASSES))
random.shuffle(all_classes_shuffled)
task_class_splits = [all_classes_shuffled[i:i + NUM_CLASSES_PER_TASK] for i in range(0, NUM_TOTAL_CLASSES, NUM_CLASSES_PER_TASK)]

def prepare_task_datasets(model_type):
    if model_type in ["dino", "mae"]:
        transform = vit_transform
    else:
        transform = cnn_transform
    tasks = []
    for class_list in task_class_splits:
        if model_type == "rotnet":
            tasks.append(RotNetCifar100TaskDataset(cifar100_train_full, class_list, transform))
        elif model_type == "simsiam" or model_type == "dino":
            tasks.append(PairCifar100TaskDataset(cifar100_train_full, class_list, transform))
        elif model_type == "mae":
            tasks.append(MAECifar100TaskDataset(cifar100_train_full, class_list, transform))
        elif model_type == "sdae":
            tasks.append(DenoisingCifar100TaskDataset(cifar100_train_full, class_list, transform))
        else:
            raise NotImplementedError
    return tasks

# ========== MAIN TRAIN LOOP ==========
ssl_methods = ["rotnet", "simsiam", "sdae"]
results = {}
for ssl_method in ssl_methods:
    print("\n" + "="*60)
    print(f"Starting continual learning for SSL method: {ssl_method.upper()}")
    print("="*60)
    backbone = get_backbone(ssl_method).to(DEVICE)
    base_ssl_model_instance = SSLModelWrapper(method=ssl_method, backbone=backbone, num_rot_classes=NUM_ROT_CLASSES)
    prev_encoder_state = None
    all_task_accuracies = []
    random_accuracies_Ri = {}
    task_datasets = prepare_task_datasets(ssl_method)
    for task_id, current_task_dataset in enumerate(task_datasets):
        print(f"\n===== Training Task {task_id + 1}/{len(task_datasets)} =====")
        current_task_loader = DataLoader(current_task_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True)
        trainer = CaSSleTrainer(base_ssl_model=base_ssl_model_instance,
                                ca_predictor_hidden_dim=256,
                                learning_rate=LEARNING_RATE,
                                lambda_cassle=LAMBDA_CASSLE,
                                device=DEVICE)
        if prev_encoder_state:
            trainer.set_previous_frozen_encoder(prev_encoder_state)
        prev_encoder_state = trainer.train_task(current_task_loader, NUM_EPOCHS_PER_TASK, ssl_method)
        print(f"\n--- Evaluating after Task {task_id + 1} ---")
        current_seen_classes = sorted(set().union(*task_class_splits[:task_id + 1]))
        accuracies_after_this_task = []
        for eval_task_idx in range(task_id + 1):
            eval_task_classes = task_class_splits[eval_task_idx]
            acc_jk = evaluate_model(
                base_ssl_model_instance.backbone,
                eval_task_classes,
                cifar100_train_full,
                cifar100_test_full,
                vit_transform if ssl_method in ["dino", "mae"] else cnn_transform,
                LINEAR_EVAL_BATCH_SIZE,
                LINEAR_EVAL_EPOCHS,
                DEVICE
            )
            accuracies_after_this_task.append(acc_jk)
            print(f"  Eval Task {eval_task_idx+1}: {acc_jk:.2f}%")
            if eval_task_idx not in random_accuracies_Ri:
                random_accuracies_Ri[eval_task_idx] = get_random_accuracy(
                    NUM_CLASSES_PER_TASK,
                    cifar100_train_full,
                    cifar100_test_full,
                    vit_transform if ssl_method in ["dino", "mae"] else cnn_transform,
                    eval_task_classes,
                    LINEAR_EVAL_BATCH_SIZE,
                    LINEAR_EVAL_EPOCHS,
                    DEVICE
                )
        all_task_accuracies.append(accuracies_after_this_task)
    T = len(task_datasets)
    final_accuracies_row = all_task_accuracies[T-1]
    avg_accuracy = sum(final_accuracies_row) / T
    print(f"\nFinal Average Accuracy (A) for {ssl_method.upper()}: {avg_accuracy:.2f}%")
    forgetting = 0
    if T > 1:
        for i in range(T - 1):
            max_acc = max(all_task_accuracies[t][i] for t in range(T) if i < len(all_task_accuracies[t]))
            final_acc = all_task_accuracies[T-1][i]
            forgetting += (max_acc - final_acc)
        forgetting /= (T - 1)
    print(f"Final Forgetting (F) for {ssl_method.upper()}: {forgetting:.2f}%")
    backward_transfer = 0
    count = 0
    if T > 1:
        for new_task in range(1, T):
            for old_task in range(new_task):
                if old_task < len(all_task_accuracies[new_task - 1]) and old_task < len(all_task_accuracies[new_task]):
                    acc_before = all_task_accuracies[new_task - 1][old_task]
                    acc_after = all_task_accuracies[new_task][old_task]
                    backward_transfer += (acc_after - acc_before)
                    count += 1
        backward_transfer /= count if count > 0 else 1
    else:
        backward_transfer = 0
    print(f"Final Backward Transfer (BT) for {ssl_method.upper()}: {backward_transfer:.2f}%")
    results[ssl_method] = {
        "average_accuracy": avg_accuracy,
        "forgetting": forgetting,
        "backward_transfer": backward_transfer
    }

print("\n==== ALL SSL MODEL RESULTS ====")
for model, res in results.items():
    print(f"{model.upper()}: Accuracy={res['average_accuracy']:.2f}% | Forgetting={res['forgetting']:.2f}% | Backward Transfer={res['backward_transfer']:.2f}%")
pd.DataFrame(results).T.to_csv("all_ssl_results.csv")


100%|██████████| 169M/169M [00:02<00:00, 78.1MB/s] 



Starting continual learning for SSL method: ROTNET

===== Training Task 1/5 =====
Epoch 1/5 - SSL Loss: 1.2226, CaSSle Loss: 0.0000, Total Loss: 1.2226
Epoch 2/5 - SSL Loss: 1.1364, CaSSle Loss: 0.0000, Total Loss: 1.1364
Epoch 3/5 - SSL Loss: 1.0862, CaSSle Loss: 0.0000, Total Loss: 1.0862
Epoch 4/5 - SSL Loss: 1.0456, CaSSle Loss: 0.0000, Total Loss: 1.0456
Epoch 5/5 - SSL Loss: 1.0084, CaSSle Loss: 0.0000, Total Loss: 1.0084

--- Evaluating after Task 1 ---
  Eval Task 1: 36.10%

===== Training Task 2/5 =====
Epoch 1/5 - SSL Loss: 0.8888, CaSSle Loss: 0.0475, Total Loss: 0.9268
Epoch 2/5 - SSL Loss: 0.8304, CaSSle Loss: 0.0211, Total Loss: 0.8473
Epoch 3/5 - SSL Loss: 0.7953, CaSSle Loss: 0.0197, Total Loss: 0.8111
Epoch 4/5 - SSL Loss: 0.7618, CaSSle Loss: 0.0192, Total Loss: 0.7771
Epoch 5/5 - SSL Loss: 0.7446, CaSSle Loss: 0.0182, Total Loss: 0.7592

--- Evaluating after Task 2 ---
  Eval Task 1: 39.50%
  Eval Task 2: 42.95%

===== Training Task 3/5 =====
Epoch 1/5 - SSL Loss: 1