# In [None]:


# ======================================================================
# == Imports ===========================================================
# ======================================================================
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, Subset, Dataset
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
# Optional: from umap import UMAP # If you install umap-learn
import numpy as np
from tqdm.notebook import tqdm # Use notebook version for better Colab/Jupyter integration
import random
import os
import time

# ======================================================================
# == Configuration =====================================================
# ======================================================================
config = dict(
    # --- General ---
    seed=42,                                    # seed for Python, NumPy and PyTorch RNGs
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    num_workers=2,                              # DataLoader workers (0 can help if notebook hosts misbehave)
    save_dir="./simclr_cifar10_demo_output",    # checkpoints and figures are written here

    # --- Data ---
    dataset="CIFAR10",
    image_size=32,                              # CIFAR-10 images are 32x32
    num_classes=10,
    cifar_mean=(0.4914, 0.4822, 0.4465),        # per-channel normalization mean
    cifar_std=(0.2023, 0.1994, 0.2010),         # per-channel normalization std

    # --- SimCLR pre-training ---
    ssl_model_name="resnet18",                  # backbone name handed to get_resnet_encoder
    projection_dim=128,                         # width of the projection head output
    ssl_epochs=20,
    ssl_batch_size=128,                         # lower (e.g. 64) when GPU memory is tight
    ssl_learning_rate=3e-4,
    ssl_weight_decay=1e-6,
    temperature=0.1,                            # NT-Xent temperature
    n_views=2,                                  # augmented views generated per image
    use_subset_ssl=5000,                        # SSL on this many images; None = whole training set

    # --- Linear probing ---
    linear_epochs=50,
    linear_batch_size=256,
    linear_learning_rate=1e-3,
    linear_weight_decay=0.0,

    # --- Visualization ---
    vis_num_augmentations=6,                    # images shown in the augmentation grid
    vis_tsne_subset_size=1000,                  # points embedded by t-SNE
)

# Make sure the output directory exists (no error if it is already there).
os.makedirs(config["save_dir"], exist_ok=True)

# Seed every RNG used by the pipeline so runs are repeatable.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])
random.seed(config["seed"])
if config["device"].type == "cuda":
    torch.cuda.manual_seed_all(config["seed"])
    # For stricter CUDA determinism (slower), one could additionally set:
    #   torch.backends.cudnn.deterministic = True
    #   torch.backends.cudnn.benchmark = False

print(f"Using device: {config['device']}")
print(f"Configuration:\n{config}")

# ======================================================================
# == Helper Functions & Classes ========================================
# ======================================================================

# --- 1. Data Augmentation for Contrastive Learning ---
class ContrastiveTransformations:
    """Produce several independently augmented views of a single input.

    Every call runs ``base_transforms`` ``n_views`` separate times on the same
    input, so a stochastic transform pipeline yields different views.

    Args:
        base_transforms: Callable applied once per view.
        n_views: Number of views returned per call (default 2).
    """

    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, x):
        """Return a list of ``n_views`` results of ``base_transforms(x)``."""
        augment = self.base_transforms
        return [augment(x) for _view in range(self.n_views)]

def get_cifar10_ssl_transforms(img_size=32, mean=config["cifar_mean"], std=config["cifar_std"]):
    """Build the stochastic SimCLR augmentation pipeline for CIFAR-10.

    The pipeline takes a PIL image and returns a normalized tensor, applying in
    order: random resized crop (scale 0.2-1.0) -> horizontal flip (p=0.5) ->
    colour jitter (applied with p=0.8) -> grayscale (p=0.2) -> ``ToTensor`` ->
    per-channel ``Normalize``.

    Args:
        img_size: Output side length of the random resized crop.
        mean: Per-channel means passed to ``transforms.Normalize``.
        std: Per-channel standard deviations passed to ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` object.
    """
    # Colour-jitter strength: brightness/contrast/saturation use 0.8*s, hue 0.2*s.
    jitter_strength = 0.5
    jitter = transforms.ColorJitter(
        brightness=0.8 * jitter_strength,
        contrast=0.8 * jitter_strength,
        saturation=0.8 * jitter_strength,
        hue=0.2 * jitter_strength,
    )

    pipeline = [
        transforms.RandomResizedCrop(size=img_size, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomApply([jitter], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        # Gaussian blur is intentionally not included; it is often skipped for
        # small 32x32 CIFAR images.
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    return transforms.Compose(pipeline)

# --- 2. NT-Xent Loss ---
class NTXentLoss(nn.Module):
    """
    Normalized Temperature-scaled Cross Entropy (NT-Xent) loss used by SimCLR.

    Input layout (view-major): rows ``[k*B:(k+1)*B]`` of ``features`` hold view
    ``k`` of the ``B`` images in the batch. This is exactly what
    ``torch.cat([view_0_batch, view_1_batch, ...], dim=0)`` produces, so row
    ``r`` belongs to image ``r % B``.

    For each anchor row and each of its ``n_views - 1`` positives (the other
    views of the same image), a cross-entropy term asks the model to pick that
    positive out of {the positive} + {every view of every other image}, using
    cosine similarity divided by ``temperature`` as logits. The result is the
    mean over all (anchor, positive) pairs. For ``n_views == 2`` this is the
    standard SimCLR NT-Xent loss.

    Args:
        temperature: Softmax temperature; must be > 0.
        batch_size: Expected number of images per batch. Required, although
            each call adapts to the actual batch size it receives.
        n_views: Views per image; must be >= 2.
        device: Default device for ``_get_correlated_mask`` when no device is
            passed. ``forward`` builds its masks on ``features.device``.

    Raises:
        ValueError: If ``batch_size`` is None, ``n_views < 2`` or
            ``temperature <= 0``.
    """
    def __init__(self, temperature=0.5, batch_size=None, n_views=2, device='cpu'):
        super().__init__()
        if batch_size is None:
            raise ValueError("batch_size must be specified for NTXentLoss")
        if n_views < 2:
            raise ValueError(f"n_views must be >= 2 for a contrastive loss, got {n_views}")
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")
        self.temperature = temperature
        self.batch_size = batch_size  # Expected batch size (informational)
        self.n_views = n_views
        self.device = device
        # Sum reduction; forward divides by the number of (anchor, positive) terms.
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        # Cosine similarity along the feature dimension of broadcast (rows, rows, D) inputs.
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def _get_correlated_mask(self, current_batch_size, device=None):
        """
        Build boolean masks over the ``(n_views*B, n_views*B)`` similarity matrix.

        Args:
            current_batch_size: Number of images ``B`` in this batch.
            device: Device for the masks; defaults to ``self.device``.

        Returns:
            ``(positive_mask, self_mask)``. ``positive_mask[i, j]`` is True when
            rows ``i != j`` are views of the same image (view-major layout,
            image id ``r % B``). ``self_mask`` is the identity mask.
        """
        device = self.device if device is None else device
        n_rows = self.n_views * current_batch_size
        diag = torch.eye(n_rows, device=device, dtype=torch.bool)
        # View-major layout: [0, 1, ..., B-1] repeated once per view.
        image_ids = torch.arange(current_batch_size, device=device).repeat(self.n_views)
        same_image = image_ids.unsqueeze(0) == image_ids.unsqueeze(1)
        return same_image & ~diag, diag

    def forward(self, features):
        """
        Calculate the NT-Xent loss.

        Args:
            features (torch.Tensor): Projections of all views stacked view-major,
                shape ``(n_views * batch_size, projection_dim)``.

        Returns:
            torch.Tensor: Scalar loss, averaged over all (anchor, positive) pairs.

        Raises:
            ValueError: If the number of rows is not divisible by ``n_views``.
        """
        if features.shape[0] % self.n_views != 0:
            raise ValueError(f"Feature dimension {features.shape[0]} is not divisible by n_views {self.n_views}")

        # Actual batch size of this step (may differ from self.batch_size).
        current_batch_size = features.shape[0] // self.n_views
        if current_batch_size == 0:
            print("Warning: NTXentLoss received batch size 0. Returning 0 loss.")
            return torch.tensor(0.0, device=features.device, requires_grad=True)

        device = features.device
        n_rows = self.n_views * current_batch_size
        n_pos = self.n_views - 1            # positives per anchor
        n_neg = n_rows - self.n_views       # negatives per anchor

        # (n_rows, n_rows) cosine similarities via broadcasting (n_rows,1,D) x (1,n_rows,D).
        sim_matrix = self.similarity_f(features.unsqueeze(1), features.unsqueeze(0))

        positive_pair_mask, self_comp_mask = self._get_correlated_mask(current_batch_size, device=device)
        negatives_mask = ~(positive_pair_mask | self_comp_mask)

        # Boolean indexing walks the matrix row by row, so reshaping restores one
        # row per anchor. Explicit sizes (not -1) keep B == 1 (zero negatives) valid.
        positives = sim_matrix[positive_pair_mask].view(n_rows, n_pos)
        negatives = sim_matrix[negatives_mask].view(n_rows, n_neg)

        # One classification problem per (anchor, positive): the positive sits at
        # index 0, followed by all of that anchor's negatives. Other positives of
        # the same image are kept out of the denominator (only matters for n_views > 2).
        logits = torch.cat(
            [positives.unsqueeze(2), negatives.unsqueeze(1).expand(-1, n_pos, -1)],
            dim=2,
        ).reshape(n_rows * n_pos, 1 + n_neg)
        logits = logits / self.temperature
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

        loss = self.criterion(logits, labels)
        # Mean over all (anchor, positive) terms; equals n_views * B when n_views == 2.
        return loss / logits.shape[0]

# --- 3. Model Definition (Encoder + Projector) ---
class SimCLRModel(nn.Module):
    """
    Base encoder followed by an MLP projection head (SimCLR architecture).

    ``forward`` returns ``(features, projections)``: the raw encoder output,
    intended for downstream tasks, and its projection, intended for the
    contrastive loss.

    Args:
        base_encoder: Module mapping inputs to feature vectors. It must expose
            an ``n_features`` attribute giving the feature width.
        projection_dim: Output width of the projection head.
    """
    def __init__(self, base_encoder, projection_dim):
        super().__init__()
        self.encoder = base_encoder
        self.n_features = base_encoder.n_features
        width = self.n_features

        # Linear -> ReLU -> Linear, both linear layers bias-free. The layer
        # order fixes the state_dict keys ("projector.0.weight",
        # "projector.2.weight"), so keep it stable for saved checkpoints.
        head_layers = [
            nn.Linear(width, width, bias=False),
            nn.ReLU(),
            nn.Linear(width, projection_dim, bias=False),
        ]
        self.projector = nn.Sequential(*head_layers)

    def forward(self, x):
        h = self.encoder(x)
        return h, self.projector(h)

def get_resnet_encoder(name="resnet18", use_pretrained=False):
    """
    Build a torchvision ResNet trunk for use as a feature encoder.

    The classification layer ``fc`` is replaced by ``nn.Identity`` and its
    original ``in_features`` is stored as ``encoder.n_features``. When not
    using pretrained weights and ``config['image_size'] <= 32``, the stem is
    adapted for small images: ``conv1`` becomes a 3x3/stride-1 convolution and
    ``maxpool`` is removed.

    Args:
        name: ``"resnet18"`` or ``"resnet50"``.
        use_pretrained: Load torchvision's ``DEFAULT`` weights for that model.

    Returns:
        The modified ResNet module, with an extra ``n_features`` attribute.

    Raises:
        ValueError: If ``name`` is not a supported ResNet.
    """
    print(f"Loading {name} {'with' if use_pretrained else 'without'} pretrained weights.")

    # Supported builders (torchvision.models.<name>) and their weights enums.
    weights_enum_by_name = {"resnet18": "ResNet18_Weights", "resnet50": "ResNet50_Weights"}
    if name not in weights_enum_by_name:
        raise ValueError(f"Unsupported ResNet name: {name}. Choose 'resnet18' or 'resnet50'.")

    weights = getattr(models, weights_enum_by_name[name]).DEFAULT if use_pretrained else None
    encoder = getattr(models, name)(weights=weights)

    # Record the feature width before the classifier is discarded.
    encoder.n_features = encoder.fc.in_features

    adapt_stem_for_small_images = (not use_pretrained) and config['image_size'] <= 32
    if adapt_stem_for_small_images:
        print("Modifying ResNet first layer stride/kernel and removing maxpool for small image input.")
        # 7x7/stride-2 conv + maxpool would shrink 32x32 inputs too aggressively.
        encoder.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        encoder.maxpool = nn.Identity()

    # Drop the classifier so the encoder outputs pooled features.
    encoder.fc = nn.Identity()
    return encoder

# --- 4. Visualisation Functions ---
def visualize_augmentations(dataset, n_samples=5, save_path=None):
    """
    Display the contrastive views generated for a few random images.

    Draws a grid with one row per view (``config["n_views"]``) and one column
    per randomly selected image. Raw PIL images are obtained by temporarily
    setting the underlying dataset's ``transform`` to ``None``. The original
    transform is restored in a ``finally`` block, so an error while plotting or
    saving cannot leave a dataset that is shared with a training DataLoader
    without its transform.

    Args:
        dataset: A dataset exposing a ``transform`` attribute (e.g. torchvision
            CIFAR10), or a ``Subset`` wrapping one.
        n_samples: Number of images (columns) to show. Clipped to
            ``len(dataset)``; nothing is drawn if no samples remain.
        save_path: Optional file path; if given the figure is saved there.
    """
    available = len(dataset)
    if n_samples > available:
        print(f"Warning: Requested n_samples ({n_samples}) > dataset size ({available}). Showing {available} samples.")
        n_samples = available
    if n_samples <= 0:
        print("Warning: No samples to visualize; skipping augmentation plot.")
        return

    is_subset = isinstance(dataset, Subset)
    base_dataset = dataset.dataset if is_subset else dataset

    # The contrastive pipeline expects PIL images, hence disabling the
    # dataset's own transform below.
    contrastive_transform_func = ContrastiveTransformations(
        get_cifar10_ssl_transforms(img_size=config['image_size']),
        n_views=config["n_views"],
    )
    # Loop-invariant tensors used to undo transforms.Normalize for display.
    mean = torch.tensor(config["cifar_mean"]).view(3, 1, 1)
    std = torch.tensor(config["cifar_std"]).view(3, 1, 1)

    original_transform = base_dataset.transform
    base_dataset.transform = None
    try:
        plt.figure(figsize=(n_samples * 3, config["n_views"] * 3 + 1))
        plt.suptitle("Sample Contrastive Augmentations", fontsize=16)

        indices = random.sample(range(len(dataset)), n_samples)
        for i, idx in enumerate(indices):
            # Translate a Subset position into an index of the wrapped dataset.
            actual_idx = dataset.indices[idx] if is_subset else idx
            pil_img, _ = base_dataset[actual_idx]  # label unused

            for j, img_tensor in enumerate(contrastive_transform_func(pil_img)):
                plt.subplot(config["n_views"], n_samples, j * n_samples + i + 1)
                # Un-normalize, then (C, H, W) -> (H, W, C) for matplotlib.
                img_display = (img_tensor * std + mean).permute(1, 2, 0).numpy()
                plt.imshow(np.clip(img_display, 0, 1))
                plt.title(f"Idx {idx} - View {j+1}")
                plt.axis("off")

        plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
        if save_path:
            plt.savefig(save_path)
            print(f"Augmentation visualization saved to {save_path}")
        plt.show()
    finally:
        # Always give the (possibly shared) base dataset its transform back.
        base_dataset.transform = original_transform

@torch.no_grad()  # Inference only: no autograd graph is recorded.
def visualize_embeddings(encoder_model, dataloader, device, num_samples, title="", save_path=None):
    """
    Plot a 2-D t-SNE projection of encoder features, colored by true class.

    A random subset of ``num_samples`` items from ``dataloader.dataset`` is
    passed through ``encoder_model`` (put into eval mode). Features are
    flattened to ``(num_points, dim)``, embedded with scikit-learn's TSNE, and
    scatter-plotted with a legend naming the classes that actually occur.

    Args:
        encoder_model: Module mapping an image batch to features.
        dataloader: DataLoader whose ``dataset`` yields ``(image, label)`` pairs
            and exposes ``classes`` (directly, or on the dataset wrapped by a
            ``Subset``). Its ``batch_size`` is reused for feature extraction.
        device: Device the images are moved to before the forward pass.
        num_samples: Number of points to embed; the full dataset is used when
            this exceeds its size.
        title: Prefix used in log messages and the plot title.
        save_path: Optional file path; if given the figure is saved there.
    """
    encoder_model.eval()
    dataset = dataloader.dataset
    all_features = []
    all_labels = []

    print(f"Generating features for {title} ({num_samples} samples)...")

    # Only embed a random subset rather than the whole dataset.
    try:
        subset_indices = random.sample(range(len(dataset)), num_samples)
    except ValueError:
        print(f"Warning: Requested num_samples ({num_samples}) > dataset size ({len(dataset)}). Using full dataset for t-SNE.")
        subset_indices = list(range(len(dataset)))

    # t-SNE (PCA init, perplexity < n_points) needs at least two points.
    if len(subset_indices) < 2:
        print(f"Warning: Need at least 2 samples for t-SNE, got {len(subset_indices)}. Skipping {title} visualization.")
        return

    subset_loader = DataLoader(
        Subset(dataset, subset_indices),
        batch_size=dataloader.batch_size,
        shuffle=False,
        num_workers=config["num_workers"],
    )

    for images, labels in tqdm(subset_loader, desc="Extracting features"):
        features = encoder_model(images.to(device))
        # Flatten any trailing dims so t-SNE always receives (num_points, dim).
        all_features.append(features.reshape(features.size(0), -1).cpu().numpy())
        all_labels.append(labels.cpu().numpy())

    all_features = np.concatenate(all_features, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    n_points = all_features.shape[0]

    print(f"Running t-SNE dimensionality reduction on {n_points} features...")
    start_tsne = time.time()
    # perplexity must be strictly smaller than the number of points (30 is the
    # usual value). The iteration count is left at TSNE's default of 1000
    # rather than passing ``n_iter``, which scikit-learn 1.5 renamed to
    # ``max_iter`` and later removed.
    tsne = TSNE(
        n_components=2,
        perplexity=min(30, n_points - 1),
        learning_rate='auto',
        init='pca',
        random_state=config["seed"],
    )
    embeddings_2d = tsne.fit_transform(all_features)
    print(f"t-SNE finished in {time.time() - start_tsne:.2f} seconds.")

    print("Plotting t-SNE visualization...")
    plt.figure(figsize=(12, 10))
    scatter = plt.scatter(
        embeddings_2d[:, 0],
        embeddings_2d[:, 1],
        c=all_labels,   # color by true class label
        cmap='tab10',
        s=10,
        alpha=0.8,
    )
    plt.title(f"{title} Visualization of Encoder Features (Colored by Class)")
    plt.xlabel("t-SNE Component 1")
    plt.ylabel("t-SNE Component 2")

    classes = dataset.dataset.classes if isinstance(dataset, Subset) else dataset.classes
    if len(classes) <= 15:
        # num=None yields one handle per distinct label value, in sorted order.
        # Naming them from the labels actually present keeps the legend correct
        # even when some classes are missing from the sampled subset.
        handles, _ = scatter.legend_elements(num=None)
        present_labels = np.unique(all_labels)
        plt.legend(
            handles=handles,
            labels=[classes[int(c)] for c in present_labels],
            title="Classes",
            loc='best',
        )
    else:
        plt.colorbar(scatter, label='Class Label')

    plt.grid(True, linestyle='--', alpha=0.6)

    if save_path:
        plt.savefig(save_path)
        print(f"t-SNE plot saved to {save_path}")
    plt.show()

def plot_loss_curve(losses, title, save_path=None):
    """Draw per-epoch loss values as a line plot, numbering epochs from 1.

    Args:
        losses: Sequence of loss values, one per epoch.
        title: Figure title.
        save_path: Optional file path; the figure is saved there before display.
    """
    epochs = range(1, len(losses) + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epochs, losses, marker='o', linestyle='-')
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Average Loss per Sample")
    ax.grid(True)
    fig.tight_layout()
    if save_path:
        fig.savefig(save_path)
        print(f"Loss curve saved to {save_path}")
    plt.show()


# ======================================================================
# == Main Execution Block ==============================================
# ======================================================================
if __name__ == "__main__":

    # --- Define file paths for saving models and plots ---
    ssl_encoder_save_path = os.path.join(config["save_dir"], "ssl_encoder_final.pth")
    ssl_full_model_save_path = os.path.join(config["save_dir"], "ssl_full_model_final.pth")
    linear_classifier_save_path = os.path.join(config["save_dir"], "linear_classifier_final.pth")
    augment_vis_save_path = os.path.join(config["save_dir"], "augmentations_visualization.png")
    ssl_loss_save_path = os.path.join(config["save_dir"], "ssl_training_loss.png")
    tsne_save_path = os.path.join(config["save_dir"], "tsne_visualization_final.png")

    # ======================================================
    # == 1. SELF-SUPERVISED PRE-TRAINING (SimCLR) ========
    # ======================================================
    print("\n" + "="*70)
    print(" STEP 1: Self-Supervised Pre-training (SimCLR) ".center(70, "="))
    print("="*70 + "\n")

    start_ssl_time = time.time()

    # --- SSL Data Loading ---
    print("Loading CIFAR-10 dataset for SSL pre-training...")
    # Define the contrastive transformations pipeline
    ssl_transforms_func = get_cifar10_ssl_transforms(img_size=config["image_size"])
    contrastive_transforms_wrapper = ContrastiveTransformations(ssl_transforms_func, n_views=config["n_views"])

    # Load the base training dataset *without* transforms first to allow subsetting and visualization
    base_train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=None # Get PIL images initially
    )

    # Handle optional subset usage for faster demo runs
    if config["use_subset_ssl"] and config["use_subset_ssl"] < len(base_train_dataset):
        print(f"Using a subset of {config['use_subset_ssl']} images for SSL training.")
        subset_indices = list(range(len(base_train_dataset)))
        random.shuffle(subset_indices) # Shuffle indices before selecting subset
        # Create a Subset object pointing to the base dataset with selected indices
        ssl_train_dataset_for_loader = Subset(base_train_dataset, subset_indices[:config['use_subset_ssl']])
        # IMPORTANT: Apply the contrastive transform *to the underlying base dataset*
        # This way, the Subset accesses already transformed data when indexed.
        # We modify the transform attribute of the dataset wrapped by the Subset.
        ssl_train_dataset_for_loader.dataset.transform = contrastive_transforms_wrapper

        # Dataset for visualization (needs PIL access, so use the subset of the non-transformed data)
        vis_dataset_ssl = Subset(base_train_dataset, subset_indices[:config['use_subset_ssl']])
    else:
        print("Using the full CIFAR-10 training set for SSL.")
        # Apply the contrastive transform directly to the dataset if using the full set
        base_train_dataset.transform = contrastive_transforms_wrapper
        ssl_train_dataset_for_loader = base_train_dataset

        # Use the non-transformed base dataset for visualization base
        vis_dataset_ssl = datasets.CIFAR10(root='./data', train=True, download=False, transform=None)


    # Create the DataLoader for SSL training
    ssl_train_loader = DataLoader(
        ssl_train_dataset_for_loader,
        batch_size=config["ssl_batch_size"],
        shuffle=True, # Shuffle data each epoch
        num_workers=config["num_workers"],
        pin_memory=True, # Speeds up data transfer to GPU
        drop_last=True   # Drop the last incomplete batch, important for NTXentLoss assumptions
    )

    # --- Visualize Augmentations ---
    print("\nVisualizing sample SSL augmentations...")
    visualize_augmentations(vis_dataset_ssl, config["vis_num_augmentations"], augment_vis_save_path)

    # --- Initialize SSL Model, Loss, Optimizer ---
    print("\nInitializing SimCLR model...")
    # Load the base encoder network (e.g., ResNet18) without pre-trained weights
    base_encoder = get_resnet_encoder(name=config["ssl_model_name"], use_pretrained=False)
    # Create the full SimCLR model (encoder + projection head)
    ssl_model = SimCLRModel(base_encoder, config["projection_dim"]).to(config["device"])

    # Initialize the NT-Xent loss function
    ssl_criterion = NTXentLoss(
        temperature=config["temperature"],
        # Provide the expected batch size (loss function might adjust internally for last batch if drop_last=False)
        batch_size=config["ssl_batch_size"],
        n_views=config["n_views"],
        device=config["device"]
    )

    # Initialize the optimizer (AdamW is common for transformers and often works well here too)
    ssl_optimizer = optim.AdamW(
        ssl_model.parameters(), # Optimize all parameters in the SimCLR model (encoder + projector)
        lr=config["ssl_learning_rate"],
        weight_decay=config["ssl_weight_decay"]
    )
    # Initialize a learning rate scheduler (Cosine Annealing is common for SSL)
    ssl_scheduler = optim.lr_scheduler.CosineAnnealingLR(
        ssl_optimizer,
        T_max=len(ssl_train_loader) * config["ssl_epochs"], # Total number of training steps
        eta_min=0 # Minimum learning rate
    )

    # --- SSL Training Loop ---
    print(f"\nStarting SimCLR pre-training for {config['ssl_epochs']} epochs...")
    ssl_train_losses = []
    for epoch in range(config["ssl_epochs"]):
        ssl_model.train() # Set model to training mode
        epoch_loss = 0.0
        num_samples = 0

        # Use tqdm for progress bar
        progress_bar = tqdm(ssl_train_loader, desc=f"SSL Epoch {epoch+1}/{config['ssl_epochs']}", leave=True)

        for batch_idx, (images, _) in enumerate(progress_bar): # Labels (_) are ignored in SSL
            # 'images' is a list of tensors [view1_batch, view2_batch, ...] from ContrastiveTransformations
            # Concatenate the views along the batch dimension:
            # e.g., [B, C, H, W], [B, C, H, W] -> [2*B, C, H, W]
            images_cat = torch.cat(images, dim=0).to(config["device"])
            current_batch_size = images[0].size(0) # Batch size of one view
            num_samples += current_batch_size

            ssl_optimizer.zero_grad() # Reset gradients

            # Forward pass: Get features and projections from the SimCLR model
            _, projections = ssl_model(images_cat)

            # Calculate the contrastive loss using the projections
            loss = ssl_criterion(projections)

            # Check for NaN loss (can happen with unstable training/large LRs)
            if torch.isnan(loss):
                 print(f"\nWarning: NaN loss detected at epoch {epoch+1}, batch {batch_idx}. Skipping update.")
                 # Consider stopping training or reducing LR if this happens frequently
                 continue # Skip backward pass and optimizer step

            # Backward pass: Compute gradients
            loss.backward()
            # Optimizer step: Update model parameters
            ssl_optimizer.step()
            # Scheduler step: Update learning rate (after optimizer step)
            ssl_scheduler.step()

            # Accumulate loss for the epoch. Multiply by batch size since loss is averaged per sample.
            epoch_loss += loss.item() * current_batch_size
            # Update progress bar description with current loss and learning rate
            progress_bar.set_postfix(loss=f"{loss.item():.4f}", lr=f"{ssl_optimizer.param_groups[0]['lr']:.1e}")

        # Calculate average loss for the epoch (per sample)
        avg_epoch_loss = epoch_loss / num_samples if num_samples > 0 else 0
        ssl_train_losses.append(avg_epoch_loss)
        print(f"SSL Epoch {epoch+1}/{config['ssl_epochs']} - Average Loss: {avg_epoch_loss:.5f}")


    end_ssl_time = time.time()
    print(f"\nSSL pre-training finished in {(end_ssl_time - start_ssl_time)/60:.2f} minutes.")

    # --- Save Final SSL Model Weights ---
    print("Saving final SSL model weights...")
    # Save only the ENCODER weights - this is typically what's used for downstream tasks
    torch.save(ssl_model.encoder.state_dict(), ssl_encoder_save_path)
    print(f" --> Final Encoder weights saved to: {ssl_encoder_save_path}")
    # Save the full SimCLR model (encoder + projector) as well, might be useful
    torch.save(ssl_model.state_dict(), ssl_full_model_save_path)
    print(f" --> Final Full SSL model saved to: {ssl_full_model_save_path}")

    # --- Plot SSL Loss Curve ---
    plot_loss_curve(ssl_train_losses, "SimCLR Pre-training Loss per Sample", ssl_loss_save_path)


    # ======================================================
    # == 2. DOWNSTREAM TASK: LINEAR PROBING ==============
    # ======================================================
    print("\n" + "="*70)
    print(" STEP 2: Downstream Task - Linear Probing Evaluation ".center(70, "="))
    print("="*70 + "\n")

    start_linear_time = time.time()

    # Linear Probing Data Loading 
    normalize = transforms.Normalize(config["cifar_mean"], config["cifar_std"])
    linear_train_transform = transforms.Compose([
        transforms.RandomResizedCrop(config["image_size"], scale=(0.8, 1.0)), # Less aggressive crop
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # For testing, only normalization is needed
    linear_test_transform = transforms.Compose([
        transforms.Resize(config["image_size"]), # Ensure consistent size
        transforms.ToTensor(),
        normalize,
    ])

    print("Loading CIFAR-10 dataset for linear probing (with labels)...")
    # Load the training set with standard augmentations
    linear_train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=False, # Already downloaded
        transform=linear_train_transform
    )
    # Load the test set with only test-time transformations
    linear_test_dataset = datasets.CIFAR10(
        root='./data', train=False, download=False,
        transform=linear_test_transform
    )

    # Create DataLoaders for the linear probing phase
    linear_train_loader = DataLoader(
        linear_train_dataset,
        batch_size=config["linear_batch_size"],
        shuffle=True, # Shuffle training data
        num_workers=config["num_workers"],
        pin_memory=True
    )
    linear_test_loader = DataLoader(
        linear_test_dataset,
        batch_size=config["linear_batch_size"],
        shuffle=False, # No shuffling for test set
        num_workers=config["num_workers"],
        pin_memory=True
    )

    # --- Model for Linear Probing ---
    # Rebuild the encoder, load the SSL-trained weights saved in Step 1, freeze it, and
    # attach a fresh trainable linear classification head.
    print("\nPreparing model for linear probing...")
    # 1. Load the base encoder structure (must match the architecture saved during SSL).
    #    Whatever use_pretrained=False initializes is overwritten by the (strict, default)
    #    load_state_dict call below.
    encoder = get_resnet_encoder(name=config["ssl_model_name"], use_pretrained=False)

    # 2. Load the saved SSL-trained weights into the encoder.
    #    The file holds a state_dict; map_location lets a checkpoint saved on GPU load on CPU.
    #    NOTE(review): torch.load(..., weights_only=True) (PyTorch >= 1.13) restricts unpickling
    #    to tensors/primitives and is worth considering for state_dict files.
    print(f"Loading pre-trained encoder weights from: {ssl_encoder_save_path}")
    encoder.load_state_dict(torch.load(ssl_encoder_save_path, map_location=config["device"]))
    encoder = encoder.to(config["device"])

    # 3. Freeze the encoder parameters
    # We only want to train the linear classifier head, not the pre-trained encoder.
    print("Freezing encoder parameters...")
    for param in encoder.parameters():
        param.requires_grad = False

    # 4. Create a new linear classifier head.
    #    `n_features` is an attribute of the object returned by get_resnet_encoder (defined
    #    outside this view) -- presumably the width of the encoder's output features; confirm there.
    num_features = encoder.n_features # Get feature dimension from the loaded encoder
    linear_classifier = nn.Linear(num_features, config["num_classes"]).to(config["device"])
    print(f"Created linear classifier head ({num_features} features -> {config['num_classes']} classes).")

    # --- Optimizer and Loss for Linear Probing ---
    # Only the head's parameters are optimized; the encoder was frozen above.
    linear_optimizer = optim.AdamW(
        linear_classifier.parameters(), # <-- Only pass the classifier's parameters
        lr=config["linear_learning_rate"],
        weight_decay=config["linear_weight_decay"]
    )
    # Standard Cross-Entropy Loss for the classification task: it takes raw logits plus integer
    # class labels and, with default arguments, averages the loss over the batch.
    linear_criterion = nn.CrossEntropyLoss().to(config["device"])

    # --- Linear Probing Training Loop ---
    # Train only the linear head on top of the frozen encoder. Evaluate on the test set every
    # `linear_eval_every` epochs (default 5) and after the final epoch, then restore the best
    # head from this run and report its final test accuracy.

    def _evaluate_linear_probe(feature_encoder, head, loader, desc=None):
        """Return ``(correct, total)`` for ``head(feature_encoder(x))`` over ``loader``.

        Both modules are switched to eval mode and no gradients are tracked. When ``desc``
        is given, iteration is wrapped in a tqdm progress bar with that label.
        """
        feature_encoder.eval()
        head.eval()
        n_correct, n_seen = 0, 0
        batches = loader if desc is None else tqdm(loader, desc=desc)
        with torch.no_grad():
            for eval_images, eval_labels in batches:
                eval_images = eval_images.to(config["device"])
                eval_labels = eval_labels.to(config["device"])
                # The index of the largest logit is the predicted class (no softmax needed).
                eval_predicted = head(feature_encoder(eval_images)).argmax(dim=1)
                n_seen += eval_labels.size(0)
                n_correct += (eval_predicted == eval_labels).sum().item()
        return n_correct, n_seen

    # Clamp to >= 1 so a misconfigured 0 cannot cause a modulo-by-zero below.
    eval_every = max(1, int(config.get("linear_eval_every", 5)))

    print(f"\nStarting linear classifier training for {config['linear_epochs']} epochs...")
    linear_train_accuracies = []
    best_test_acc = 0.0
    # The best head is also kept in memory, so the final evaluation restores weights from
    # *this* run. Previously the restore depended on whether a checkpoint file existed, and a
    # stale file from an earlier run in save_dir was picked up whenever this run saved nothing.
    best_head_state = None

    for epoch in range(config["linear_epochs"]):
        encoder.eval()              # Frozen feature extractor: keep it in inference mode.
        linear_classifier.train()

        epoch_loss = 0.0
        correct = 0
        total = 0
        progress_bar = tqdm(linear_train_loader, desc=f"Linear Epoch {epoch+1}/{config['linear_epochs']}", leave=False)

        for images, labels in progress_bar:
            images, labels = images.to(config["device"]), labels.to(config["device"])
            linear_optimizer.zero_grad()
            with torch.no_grad():
                features = encoder(images)  # No autograd graph through the frozen encoder.
            outputs = linear_classifier(features)  # Raw class logits.

            loss = linear_criterion(outputs, labels)
            loss.backward()
            linear_optimizer.step()

            n_batch = labels.size(0)
            batch_loss = loss.item()
            predicted = outputs.argmax(dim=1)  # Index of the largest logit = predicted class.
            batch_correct = (predicted == labels).sum().item()
            # CrossEntropyLoss averages over the batch; re-weight so the epoch mean is per sample.
            epoch_loss += batch_loss * n_batch
            total += n_batch
            correct += batch_correct
            batch_acc = 100 * batch_correct / n_batch
            progress_bar.set_postfix(loss=f"{batch_loss:.4f}", acc=f"{batch_acc:.2f}%")

        # --- End of Epoch ---
        avg_epoch_loss = epoch_loss / total
        epoch_acc = 100 * correct / total
        linear_train_accuracies.append(epoch_acc)
        print(f"Linear Epoch {epoch+1}/{config['linear_epochs']} -> Avg Loss: {avg_epoch_loss:.4f}, Train Accuracy: {epoch_acc:.2f}%")

        # --- Optional: Evaluate on test set periodically during training ---
        # NOTE(review): the "best" head is selected by *test* accuracy, so the final test number
        # reported below is optimistically biased (model selection on the test set). A validation
        # split held out from the training set would give an unbiased estimate.
        if (epoch + 1) % eval_every == 0 or epoch == config["linear_epochs"] - 1:
            test_correct, test_total = _evaluate_linear_probe(encoder, linear_classifier, linear_test_loader)
            current_test_acc = 100 * test_correct / test_total
            print(f"  -> Test Accuracy after Epoch {epoch+1}: {current_test_acc:.2f}%")
            # The first evaluated head always counts as the best so far; after that, strict improvements only.
            if best_head_state is None or current_test_acc > best_test_acc:
                best_test_acc = current_test_acc
                # Clone: state_dict() tensors alias the live parameters, which keep training.
                best_head_state = {k: v.detach().clone() for k, v in linear_classifier.state_dict().items()}
                # Save the best performing linear classifier head
                torch.save(linear_classifier.state_dict(), linear_classifier_save_path)
                print(f"   -> New best test accuracy! Saved linear classifier head to: {linear_classifier_save_path}")

    end_linear_time = time.time()
    print(f"\nLinear classifier training finished in {(end_linear_time - start_linear_time)/60:.2f} minutes.")

    # --- Final Evaluation on Test Set (using the best head from this run) ---
    print("\nEvaluating final linear classifier performance on the test set...")
    if best_head_state is not None:
        print(f"Restoring best linear classifier head from this run (checkpoint saved at: {linear_classifier_save_path})")
        linear_classifier.load_state_dict(best_head_state)
    else:
        # Only reachable when no evaluation ran, i.e. linear_epochs == 0 (the final epoch always evaluates).
        print("Warning: Best linear classifier head not found. Using the head from the final epoch.")

    final_test_correct, final_test_total = _evaluate_linear_probe(
        encoder, linear_classifier, linear_test_loader, desc="Final Testing"
    )
    final_test_accuracy = 100 * final_test_correct / final_test_total
    print("\n" + "*"*70)
    print(f"| Final Downstream Linear Probing Test Accuracy: {final_test_accuracy:.2f}% |".center(70))
    print("*"*70 + "\n")


    # ======================================================
    # == 3. VISUALIZATION of Learned Features (t-SNE) =====
    # ======================================================
    # Project the frozen SSL encoder's features for (a subset of) the test set into 2-D with
    # t-SNE, colored by label, and save the plot to tsne_save_path.
    print("\n" + "="*70)
    print(" STEP 3: Visualize Learned SSL Features using t-SNE ".center(70, "="))
    print("="*70 + "\n")

    # Use the test data loader (with simple transforms) for visualization.
    # visualize_embeddings is defined outside this view; the argument roles below follow the
    # inline comments and should be confirmed against its signature.
    visualize_embeddings(
        encoder, # Pass the SSL-pretrained, frozen encoder
        linear_test_loader, # Use the test loader (contains labels for coloring)
        config["device"],
        config["vis_tsne_subset_size"], # Number of points to visualize
        title="t-SNE of SSL Encoder Features",
        save_path=tsne_save_path
    )

    # Closing banner pointing at the output directory.
    print("\n" + "#"*70)
    print(" Execution Finished Successfully! ".center(70, "#"))
    print(f"Models and plots saved in directory: {config['save_dir']}".center(70))
    print("#"*70 + "\n")

########## TO DO ##############

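1. UNDERSTAND THE CODE: read through all three steps (SimCLR pre-training, linear probing, and t-SNE visualization) before changing anything.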

2. Parameter Tweaking (Easy):
Task: Rerun the entire notebook with each of the following changes applied individually (restore the original settings between runs), and observe the impact on the final linear probing accuracy and on training time.
Change `linear_epochs` to 20, 30, and 40. Are 50 epochs really necessary for the linear head?
Adjust the temperature used by `NTXentLoss` (`config["temperature"]`, default 0.1), e.g. to 0.05 and 0.5. How does the SSL loss curve change? Does it affect the final accuracy?
Modify `ssl_batch_size` (default 128), e.g. to 64 and, if GPU memory allows, 256. How does this affect the SSL loss, and does it change the final accuracy?

3. Deeper Modifications & Analysis
Use a Different Backbone:
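Task: Modify the `get_resnet_encoder` function and the relevant config settings (e.g., `ssl_model_name`) to use `resnet50` instead of `resnet18`. Note that ResNet-50 produces 2048-dimensional features instead of 512, which the linear head picks up via `encoder.n_features`.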


4. Exploration & Conceptual Understanding
Use a Different Dataset:
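Task: Adapt the code to run on a different but similarly sized dataset, such as CIFAR-100 or SVHN. Remember to update `num_classes`, the normalization statistics (`cifar_mean`/`cifar_std`), and every `datasets.CIFAR10(...)` call (note that `datasets.SVHN` takes `split="train"`/`"test"` instead of `train=True/False`).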

########## TO DO (duplicate of the list above, numbered differently) ##############
1. Parameter Tweaking (Easy):
Task: Rerun the entire notebook with each of the following changes applied individually (restore the original settings between runs), and observe the impact on the final linear probing accuracy and on training time.
Change `linear_epochs` to 20 or 50 and compare. Are 50 epochs really necessary for the linear head?
Adjust the temperature used by `NTXentLoss` (`config["temperature"]`, default 0.1), e.g. to 0.05 and 0.5. How does the SSL loss curve change? Does it affect the final accuracy?
Modify `ssl_batch_size` (default 128), e.g. to 64 and, if GPU memory allows, 256. How does this affect the SSL loss, and does it change the final accuracy?

2. Deeper Modifications & Analysis
Use a Different Backbone:
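Task: Modify the `get_resnet_encoder` function and the relevant config settings (e.g., `ssl_model_name`) to use `resnet50` instead of `resnet18`. Note that ResNet-50 produces 2048-dimensional features instead of 512, which the linear head picks up via `encoder.n_features`.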


3. Exploration & Conceptual Understanding
Use a Different Dataset:
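Task: Adapt the code to run on a different but similarly sized dataset, such as CIFAR-100 or SVHN. Remember to update `num_classes`, the normalization statistics (`cifar_mean`/`cifar_std`), and every `datasets.CIFAR10(...)` call (note that `datasets.SVHN` takes `split="train"`/`"test"` instead of `train=True/False`).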


