# Global Convergence and Geometry of Contrastive Learning through Temperature Annealing

## Mounting Google Drive 

In [None]:
# Mount Google Drive at /content/drive so experiment outputs persist after the
# Colab session ends (the config's 'output_dir' points under /content/drive/MyDrive).
# Requires a Google Colab runtime; skip this cell when running elsewhere.
from google.colab import drive
drive.mount('/content/drive')

## Main Code

### Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as T
import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import os
import json
import time
import copy

%matplotlib inline

### Helper Functions (e.g., normalize, cosine_sim), Data Augmentations, Dataset

In [None]:
def normalize(v):
    """Scale each vector along the last dimension to (approximately) unit L2 norm.

    A small epsilon (1e-9) is added to the norm so that an all-zero vector maps
    to zeros instead of producing NaNs.
    """
    denom = torch.norm(v, dim=-1, keepdim=True) + 1e-9
    return v / denom

def cosine_sim(a, b):
    """Return the (N, M) matrix of dot products between rows of ``a`` (N, D) and ``b`` (M, D).

    No normalization happens here: the result equals cosine similarity only
    when both inputs are already unit-normalized row-wise (callers in this
    notebook normalize first).
    """
    return a @ b.T

def get_cifar10_contrastive_transforms():
    """Build the SimCLR-style augmentation pipeline for 32x32 CIFAR-10 images.

    Each call of the returned transform draws fresh random parameters, so
    applying it twice to one image yields two different views. Steps: random
    crop with 4px padding, horizontal flip, colour jitter (p=0.8), random
    grayscale (p=0.2), tensor conversion, per-channel normalization.
    """
    strength = 1.0  # colour-jitter strength "s"
    jitter = T.ColorJitter(
        brightness=0.8 * strength,
        contrast=0.8 * strength,
        saturation=0.8 * strength,
        hue=0.2 * strength,
    )
    # NOTE(review): these std values are the widely copied CIFAR-10 constants;
    # they are shared with the linear-probe transform, so the two stay consistent.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)
    steps = [
        # RandomResizedCrop is aggressive at 32x32, so pad-and-crop instead.
        T.RandomCrop(32, padding=4),
        T.RandomHorizontalFlip(),
        T.RandomApply([jitter], p=0.8),
        T.RandomGrayscale(p=0.2),
        T.ToTensor(),
        T.Normalize(channel_mean, channel_std),
    ]
    return T.Compose(steps)

class ContrastiveCIFAR10Dataset(Dataset):
    """CIFAR-10 wrapper yielding two views of each image for contrastive training.

    Class labels are discarded. The torchvision dataset is downloaded into
    ``root`` if it is not already there.

    Args:
        root: Directory that holds (or will receive) the CIFAR-10 files.
        train: Use the training split if True, otherwise the test split.
        transform: Augmentation applied separately for each view; a stochastic
            transform therefore yields two different views. If None, both views
            are the same un-augmented image converted to a tensor, so samples can
            still be batched by a DataLoader (the contrastive task is then trivial).
    """

    def __init__(self, root='./data', train=True, transform=None):
        self.cifar10 = torchvision.datasets.CIFAR10(root=root, train=train, download=True)
        self.transform = transform

    def __len__(self):
        return len(self.cifar10)

    def __getitem__(self, idx):
        img, _ = self.cifar10[idx]  # label unused during contrastive pre-training
        if self.transform is None:
            # Raw PIL images cannot be collated into a batch by the default
            # DataLoader collate function, so convert to a tensor instead.
            tensor = T.ToTensor()(img)
            return tensor, tensor.clone()
        # Two independent calls: each one samples its own random augmentation.
        view1 = self.transform(img)
        view2 = self.transform(img)
        return view1, view2

# Dataset and Transform for Linear Probe
def get_cifar10_linear_probe_transform():
    """Return the deterministic preprocessing used for linear-probe evaluation.

    No augmentation is applied: only tensor conversion followed by the same
    per-channel CIFAR-10 normalization constants used by the contrastive
    augmentation pipeline.
    """
    to_tensor = T.ToTensor()
    channel_norm = T.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    return T.Compose([to_tensor, channel_norm])


### Model Definition (ResNet + Projection Head) + InfoNCE Loss

In [None]:
class ContrastiveResNet(nn.Module):
    """ResNet backbone followed by a SimCLR-style MLP projection head.

    ``forward`` returns L2-normalized projections for the contrastive loss;
    ``forward_backbone`` returns the raw backbone features (used for the linear probe).

    Args:
        projection_dim: Output size of the projection head.
        backbone_name: Backbone architecture. Only 'resnet18' is supported.
        cifar_stem: If True, replace the ImageNet stem (7x7 stride-2 conv followed
            by max-pooling) with a 3x3 stride-1 conv and no max-pooling, the common
            adaptation for 32x32 inputs such as CIFAR-10. Defaults to False, which
            keeps the stock torchvision stem (original behavior).

    Raises:
        ValueError: If ``backbone_name`` is not supported.
    """

    def __init__(self, projection_dim=128, backbone_name='resnet18', cifar_stem=False):
        super().__init__()
        # TODO: maybe add other backbones such as resnet34 / resnet50.
        if backbone_name != 'resnet18':
            raise ValueError(f"Unsupported backbone: {backbone_name!r}")

        # Randomly initialized (trained from scratch). ImageNet weights could be
        # loaded instead via weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1.
        backbone = torchvision.models.resnet18(weights=None)
        if cifar_stem:
            # Keep full 32x32 resolution going into the first residual stage.
            backbone.conv1 = nn.Conv2d(3, backbone.conv1.out_channels, kernel_size=3,
                                       stride=1, padding=1, bias=False)
            backbone.maxpool = nn.Identity()
        backbone_dim = backbone.fc.in_features
        backbone.fc = nn.Identity()  # drop the classifier; backbone now outputs features

        self.backbone = backbone
        # Non-linear projection head (SimCLR style: improves the learned representations).
        self.projection_head = nn.Sequential(
            nn.Linear(backbone_dim, backbone_dim),
            nn.ReLU(),
            nn.Linear(backbone_dim, projection_dim),
        )

    def forward_backbone(self, x):
        """Return backbone features only (projection head skipped), e.g. for a linear probe."""
        return self.backbone(x)

    def forward(self, x):
        """Full contrastive forward pass: backbone -> projection head -> L2 normalization."""
        features = self.backbone(x)
        projection = self.projection_head(features)
        return normalize(projection)

# InfoNCE Loss for Views
def infoNCE_views(view1_embs, view2_embs, beta, device='cpu'):
    """InfoNCE (NT-Xent) loss for a batch of paired views, computed with logsumexp.

    Row i of ``view1_embs`` and row i of ``view2_embs`` are the two views of the
    same sample (a positive pair); every other row of the combined 2B batch is a
    negative. For anchor i with positive p(i):

        loss_i = logsumexp_{k != i}(beta * s_ik) - beta * s_{i, p(i)}

    where s is the dot product of the (re-)normalized embeddings, i.e. cosine
    similarity. Only self-similarity is excluded, so the positive also appears
    in the denominator.

    Args:
        view1_embs: (B, D) embeddings of the first views.
        view2_embs: (B, D) embeddings of the second views.
        beta: Inverse temperature (1 / tau) that scales the similarities.
        device: Kept for backward compatibility only. The mask is built on the
            embeddings' own device, so a mismatched value can no longer break it.

    Returns:
        Scalar tensor: the mean loss over all 2B anchors.
    """
    batch_size = view1_embs.shape[0]
    embeddings = normalize(torch.cat([view1_embs, view2_embs], dim=0))  # (2B, D)

    scaled_sim_matrix = cosine_sim(embeddings, embeddings) * beta  # (2B, 2B)

    # Remove self-similarity from each row's denominator by filling the diagonal
    # with a huge negative value, so exp() of it underflows to zero.
    diag_mask = torch.eye(2 * batch_size, dtype=torch.bool, device=embeddings.device)
    large_neg = -torch.finfo(scaled_sim_matrix.dtype).max
    masked_scaled_sim = scaled_sim_matrix.masked_fill(diag_mask, large_neg)
    log_denominators = torch.logsumexp(masked_scaled_sim, dim=1)  # (2B,)

    # Positive logits: anchor i (first half) pairs with i+B, anchor i+B with i.
    positives = torch.cat([
        torch.diag(scaled_sim_matrix, batch_size),   # beta * s_{i, i+B}
        torch.diag(scaled_sim_matrix, -batch_size),  # beta * s_{i+B, i}
    ])

    return (log_denominators - positives).mean()

### Annealing Schedules

In [None]:
def get_beta(t, total_epochs, schedule_type, beta_low, beta_high, gamma, c_factor):
    """Return the inverse temperature beta (= 1 / tau) for epoch ``t``.

    Args:
        t: Current epoch, 0-indexed.
        total_epochs: Total number of epochs T.
        schedule_type: One of 'fixed_low', 'fixed_high', 'log', 'linear_tau',
            'linear_beta', 'sqrt_beta'.
        beta_low: Starting / minimum beta (clamped to at least 1e-9).
        beta_high: Final / maximum beta (clamped to at least beta_low).
        gamma: Not used by any schedule; accepted so callers can pass the
            config's 'exp_gamma' unchanged.
        c_factor: Scales the increase above beta_low for 'log' and 'linear_beta',
            and the progress fraction for 'sqrt_beta'. Ignored by other schedules.

    Returns:
        beta as a Python float. With c_factor == 1 the annealed schedules end at
        beta_high on the final epoch (t = T - 1), except 'log' with T == 1,
        which returns beta_low. Results are clipped to [beta_low, beta_high],
        except 'linear_tau', which is not clipped.

    Raises:
        ValueError: If ``schedule_type`` is unknown (checked only when total_epochs > 0).
    """
    beta_low = max(beta_low, 1e-9)       # keep positive so 1/beta is defined
    beta_high = max(beta_high, beta_low)

    if total_epochs <= 0:
        return float(beta_low)

    # --- Fixed schedules ---
    if schedule_type == 'fixed_low':
        return float(beta_low)
    if schedule_type == 'fixed_high':
        return float(beta_high)

    # 1-based progress fraction in (0, 1]; equals 1 on the final epoch.
    progress = (t + 1) / total_epochs

    # --- Annealing schedules ---
    if schedule_type == 'log':
        if total_epochs == 1:
            return float(beta_low)
        # beta(t) = beta_low + c * log(t + 2), with c chosen so beta(T - 1) = beta_high.
        c = (beta_high - beta_low) / np.log(total_epochs + 1)
        base_beta = beta_low + c * np.log(t + 2)
        scaled = beta_low + (base_beta - beta_low) * c_factor
        return float(np.clip(scaled, beta_low, beta_high))

    if schedule_type == 'linear_tau':
        # Linear decay of the temperature from 1/beta_low to 1/beta_high. Both are
        # finite because beta_high >= beta_low >= 1e-9.
        tau_start = 1.0 / beta_low
        tau_end = 1.0 / beta_high
        tau_t = tau_start + (tau_end - tau_start) * progress
        return float(1.0 / max(tau_t, 1e-9))

    if schedule_type == 'linear_beta':
        base_beta = beta_low + (beta_high - beta_low) * progress
        scaled = beta_low + (base_beta - beta_low) * c_factor
        return float(np.clip(scaled, beta_low, beta_high))

    if schedule_type == 'sqrt_beta':
        # Progress sqrt(t + 1) / sqrt(T), scaled by c_factor and capped at 1.
        sqrt_progress = min(np.sqrt(t + 1) / np.sqrt(total_epochs) * c_factor, 1.0)
        current_beta = beta_low + (beta_high - beta_low) * sqrt_progress
        return float(np.clip(current_beta, beta_low, beta_high))

    raise ValueError(f"Unknown schedule type: {schedule_type}")

### Linear Probe Evaluation

In [None]:
@torch.no_grad()
def get_embeddings(encoder_backbone, loader, device):
    """Run ``encoder_backbone`` over every batch in ``loader`` and collect the outputs.

    The encoder is switched to eval mode for extraction (BatchNorm uses running
    statistics, dropout is disabled). Its previous train/eval mode is restored
    afterwards, even on error, so the caller's module state is not changed.

    Args:
        encoder_backbone: Module mapping a batch of images to a batch of features.
        loader: Iterable of ``(images, labels)`` tensor batches.
        device: Device the images are moved to before the forward pass.

    Returns:
        ``(features, labels)``: CPU tensors concatenated along dim 0.
    """
    was_training = encoder_backbone.training
    encoder_backbone.eval()
    all_features = []
    all_labels = []
    try:
        for images, labels in loader:
            features = encoder_backbone(images.to(device))
            all_features.append(features.cpu())
            all_labels.append(labels.cpu())
    finally:
        encoder_backbone.train(was_training)
    return torch.cat(all_features), torch.cat(all_labels)

def train_linear_probe(encoder_backbone, device, config):
    """Fit a logistic-regression probe on frozen backbone features and return test accuracy.

    Features for the CIFAR-10 train and test splits (root './data', no
    augmentation) are extracted with ``get_embeddings``. A LogisticRegression
    is then fitted on the training features and scored on the test features.

    Args:
        encoder_backbone: Frozen feature extractor.
        device: Device used for feature extraction.
        config: Reads 'batch_size', 'seed' and optionally 'num_workers' (default 2).

    Returns:
        Test accuracy as a percentage (0-100).
    """
    print("\n--- Training and Evaluating Linear Probe ---")

    def _split_features(is_train, announcement):
        # Unshuffled, un-augmented loader for one split; double batch size for
        # faster feature extraction.
        split_dataset = torchvision.datasets.CIFAR10(
            root='./data', train=is_train, transform=get_cifar10_linear_probe_transform()
        )
        split_loader = DataLoader(
            split_dataset,
            batch_size=config['batch_size'] * 2,
            shuffle=False,
            num_workers=config.get('num_workers', 2),
        )
        print(announcement)
        return get_embeddings(encoder_backbone, split_loader, device)

    X_train, y_train = _split_features(True, "Extracting training features...")
    print(f"Training features shape: {X_train.shape}")

    X_test, y_test = _split_features(False, "Extracting testing features...")
    print(f"Testing features shape: {X_test.shape}")

    # NOTE(review): 'liblinear' is one-vs-rest for multiclass and features are
    # not standardized; kept as-is to preserve reported numbers.
    print("Training logistic regression classifier...")
    probe = LogisticRegression(random_state=config['seed'], max_iter=1000, C=1.0, solver='liblinear')
    probe.fit(X_train.numpy(), y_train.numpy())

    predictions = probe.predict(X_test.numpy())
    accuracy = accuracy_score(y_test.numpy(), predictions) * 100
    print(f"Linear Probe Test Accuracy: {accuracy:.2f}%")
    return accuracy

### Main Training Function

In [None]:
def main_cifar10(config):
    """Run one contrastive pre-training experiment on CIFAR-10, then linear-probe it.

    Trains a ContrastiveResNet with the InfoNCE loss. The inverse temperature is
    taken from ``get_beta`` for config['schedule'] and updated once per epoch.
    Afterwards a logistic-regression probe is fitted on the frozen backbone features.

    Required config keys: 'seed', 'schedule', 'output_dir', 'epochs',
    'batch_size', 'projection_dim', 'lr', 'beta_low', 'beta_high', 'exp_gamma',
    'c_factor'. Optional: 'no_save' (default False), 'num_workers' (default 4).

    When saving is enabled, config.json, results.json, the backbone and
    full-model state dicts and a loss-curve PNG are written to
    ``<output_dir>/schedule_<schedule>_c<c_factor>_seed_<seed>``.

    Returns:
        dict with 'losses' (average loss per epoch), 'final_probe_accuracy' (%),
        'epoch_times' (seconds per epoch) and 'total_training_time' (seconds for
        the whole run, including the linear probe).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    start_run_time = time.time()

    # Extract config vars
    seed = config['seed']
    schedule = config['schedule']
    output_dir = config['output_dir']
    no_save = config.get('no_save', False)
    epochs = config['epochs']
    num_workers = config.get('num_workers', 4)  # Default workers

    # One output directory per schedule / c_factor / seed combination.
    run_dir = None
    if not no_save:
        run_dir = os.path.join(output_dir, f"schedule_{schedule}_c{config['c_factor']}_seed_{seed}")
        os.makedirs(run_dir, exist_ok=True)
        with open(os.path.join(run_dir, 'config.json'), 'w') as f:
            json.dump(config, f, indent=4)

    # Seed NumPy and PyTorch (CPU and all GPUs).
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Data: two augmented views per image; drop_last keeps every batch full-size.
    contrastive_transform = get_cifar10_contrastive_transforms()
    contrastive_dataset = ContrastiveCIFAR10Dataset(train=True, transform=contrastive_transform)
    contrastive_loader = DataLoader(contrastive_dataset, batch_size=config['batch_size'], shuffle=True,
                                    num_workers=num_workers, pin_memory=True, drop_last=True)

    # Model and optimizer
    model = ContrastiveResNet(projection_dim=config['projection_dim']).to(device)
    optimizer = optim.Adam(model.parameters(), lr=config['lr'], weight_decay=1e-6)
    # TODO: a learning-rate scheduler might help, e.g.
    # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=0)

    # Training loop
    losses = []
    epoch_times = []
    print(f"\n--- Training CIFAR-10 with schedule: {schedule} ---")
    for epoch in range(epochs):
        epoch_start_time = time.time()
        model.train()
        epoch_loss = 0.0
        num_batches = 0
        # beta is constant within an epoch and only changes between epochs.
        current_beta = get_beta(epoch, epochs, schedule, config['beta_low'], config['beta_high'],
                                config['exp_gamma'], config['c_factor'])

        for batch_idx, (view1, view2) in enumerate(contrastive_loader):
            view1 = view1.to(device, non_blocking=True)
            view2 = view2.to(device, non_blocking=True)

            optimizer.zero_grad()
            emb1 = model(view1)
            emb2 = model(view2)
            loss = infoNCE_views(emb1, emb2, current_beta, device)

            # Skip the update for a NaN *or* infinite loss: back-propagating a
            # non-finite loss can corrupt the weights.
            if not torch.isfinite(loss):
                print(f"Warning: non-finite loss ({loss.item()}) detected at epoch {epoch+1}, "
                      f"batch {batch_idx+1}. Beta={current_beta}. Skipping update.")
                continue

            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Average over the batches that were actually used for an update.
        avg_epoch_loss = epoch_loss / max(num_batches, 1)
        losses.append(avg_epoch_loss)

        # Step the scheduler if using one
        # if scheduler: scheduler.step()

        epoch_duration = time.time() - epoch_start_time
        epoch_times.append(epoch_duration)

        # Print progress on the first epoch, every 10th epoch and the last epoch.
        if epoch == 0 or (epoch + 1) % 10 == 0 or epoch == epochs - 1:
            print(f"Epoch [{epoch+1}/{epochs}], Beta: {current_beta:.2f}, Loss: {avg_epoch_loss:.4f}, Time: {epoch_duration:.2f}s")

    # Final evaluation: linear probe on the frozen backbone (projection head discarded).
    final_probe_accuracy = train_linear_probe(model.backbone, device, config)

    total_run_time = time.time() - start_run_time
    print(f"Finished Contrastive Training. Total Time: {total_run_time:.2f}s")

    # Save results
    results = {
        'losses': losses,
        'final_probe_accuracy': final_probe_accuracy,
        'epoch_times': epoch_times,
        'total_training_time': total_run_time,
    }
    if not no_save and run_dir:
        # Backbone alone (enough for later probing) plus the full contrastive model.
        torch.save(model.backbone.state_dict(), os.path.join(run_dir, 'final_backbone.pth'))
        torch.save(model.state_dict(), os.path.join(run_dir, 'final_contrastive_model.pth'))
        with open(os.path.join(run_dir, 'results.json'), 'w') as f:
            json.dump(results, f, indent=4)
        print(f"Results saved to {run_dir}")

        # Loss curve
        fig, ax1 = plt.subplots(figsize=(8, 5))
        color = 'tab:red'
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Average Contrastive Loss', color=color)
        ax1.plot(range(1, epochs + 1), losses, color=color)
        ax1.tick_params(axis='y', labelcolor=color)
        ax1.grid(True)

        # # Optional: beta schedule on a secondary axis
        # ax2 = ax1.twinx()
        # betas = [get_beta(e, epochs, schedule, config['beta_low'], config['beta_high'], config['exp_gamma'], config['c_factor']) for e in range(epochs)]
        # color = 'tab:blue'
        # ax2.set_ylabel('Beta (Inverse Temp)', color=color)
        # ax2.plot(range(1, epochs + 1), betas, color=color, linestyle=':')
        # ax2.tick_params(axis='y', labelcolor=color)

        # Set the title before tight_layout so the layout reserves room for it.
        ax1.set_title(f"Loss Curve (Schedule: {schedule}) - Final Probe Acc: {final_probe_accuracy:.2f}%")
        fig.tight_layout()
        fig.savefig(os.path.join(run_dir, "contrastive_loss_curve.png"))
        plt.show()
        # This function runs once per schedule; close the figure so open
        # figures do not accumulate.
        plt.close(fig)

    return results

### Config + Experiments

In [None]:
# Base configuration shared by every schedule; each run copies it and sets 'schedule'.
cifar_config = {
    'seed': 1000,
    'd_embed': 128,              # Not read by main_cifar10 (ResNet18 backbone features are 512-d)
    'projection_dim': 128,       # Output dimension of the projection head
    'epochs': 100,               # Contrastive pre-training epochs (longer runs, e.g. 200-400, are common)
    'batch_size': 256,           # Adjust to GPU memory; the linear probe extracts features at 2x this
    'lr': 3e-4,                  # Adam learning rate
    'beta_low': 1.0,             # Lowest inverse temperature: tau = 1.0
    'beta_high': 10000000.0,     # Highest inverse temperature: tau = 1e-7 (far colder than SimCLR's usual tau of ~0.1-0.5)
    'exp_gamma': 0.95,           # Passed to get_beta as `gamma`, which no current schedule uses
    'c_factor': 1,               # Scales the annealing increase for 'log', 'linear_beta' and 'sqrt_beta'
    'output_dir': '/content/drive/MyDrive/colab_outputs/cifar10_results', # Output directory for results
    'no_save': False,            # Set to True to disable saving models/logs/plots
    'num_workers': 2             # Number of workers for DataLoader (adjust based on system)
}

all_results_cifar = {}
# schedules_to_run_cifar = ['fixed_low', 'fixed_high', 'log', 'linear_beta', 'linear_tau', 'sqrt_beta']
schedules_to_run_cifar = ['fixed_low', 'fixed_high', 'log', 'sqrt_beta']

for sched in schedules_to_run_cifar:
    print(f"\n{'='*20} RUNNING SCHEDULE: {sched} {'='*20}")
    # Shallow copy is enough: all config values are immutable scalars/strings.
    current_config = {**cifar_config, 'schedule': sched}

    results = main_cifar10(current_config)
    all_results_cifar[sched] = results
    print(f"{'='*20} COMPLETED SCHEDULE: {sched} {'='*20}")

# Aggregate and compare results across schedules.
print("\n--- Comparison Across Schedules (CIFAR-10) ---")
print("| Schedule     | Final Probe Acc (%) | Final Loss   | Avg Epoch Time (s) |")
print("|--------------|---------------------|--------------|--------------------|")
for sched, res in all_results_cifar.items():
    if res:
        avg_epoch_time = np.mean(res['epoch_times']) if res['epoch_times'] else -1
        # An epochs=0 run has no losses; report NaN instead of raising IndexError.
        final_loss = res['losses'][-1] if res['losses'] else float('nan')
        print(f"| {sched:<12} | {res['final_probe_accuracy']:<19.2f} | {final_loss:<12.4f} | {avg_epoch_time:<18.2f} |")
    else:
        print(f"| {sched:<12} | N/A                 | N/A          | N/A                |")

# Plot comparison of final probe accuracies
plt.figure(figsize=(8, 5))
schedule_names_cifar = list(all_results_cifar.keys())
final_probe_scores = [all_results_cifar[s]['final_probe_accuracy'] if all_results_cifar[s] else 0 for s in schedule_names_cifar]
bars = plt.bar(schedule_names_cifar, final_probe_scores)
plt.xlabel("Annealing Schedule")
plt.ylabel("Final Linear Probe Accuracy (%)")
plt.title("Comparison of Final Probe Accuracy Across Schedules (CIFAR-10)")
plt.grid(axis='y')
# Add accuracy values on top of bars
for bar in bars:
    yval = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2.0, yval, f'{yval:.2f}%', va='bottom', ha='center')  # va: vertical alignment

if not cifar_config['no_save']:
    # The directory may not exist yet if no run saved anything.
    os.makedirs(cifar_config['output_dir'], exist_ok=True)
    plt.savefig(os.path.join(cifar_config['output_dir'], "comparison_probe_accuracy.png"))
plt.show()