In [1]:
# import libraries
import os
import math
import time
import random
import argparse
from typing import Dict, List, Tuple
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt

In [2]:
def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [3]:
def flatten_layer(model: nn.Module, layer_name: str) -> torch.Tensor:
    """Flatten a specific layer by name."""
    mod = model
    for tok in layer_name.split('.'):
        mod = getattr(mod, tok)
    parts = [mod.weight.detach().flatten()]
    if getattr(mod, 'bias', None) is not None:
        parts.append(mod.bias.detach().flatten())
    return torch.cat(parts).cpu()

In [4]:
def flatten_whole(model: nn.Module) -> torch.Tensor:
    return torch.cat([p.detach().flatten() for p in model.parameters()]).cpu()

In [5]:
def pca2d(X: torch.Tensor) -> torch.Tensor:
    """
    PCA to 2D via SVD. X: [N, D]. Returns Y: [N, 2].
    """
    X = X - X.mean(dim=0, keepdim=True) # center the data

    U, S, Vh = torch.linalg.svd(X, full_matrices=False)
    W = Vh[:2].T   # picking top-2 principal directions
    Y = X @ W      # projecting data onto these 2 principal directions
    return Y

In [6]:
# Task 1: Single-input Single-output function
def f_true(x: torch.Tensor) -> torch.Tensor:
    return torch.cos(2* math.pi*x) * (x**3)

def make_function_loaders(xmin=-3.0, xmax=3.0, n_train=256, batch=64, device=None):
    x = torch.linspace(xmin, xmax, n_train).unsqueeze(1)
    y = f_true(x)
    ds = TensorDataset(x, y)
    loader = DataLoader(ds, batch_size=batch, shuffle=True, drop_last=False)
    if device:
        pass
    return loader

class SimpleFunctionModel(nn.Module):
    def __init__(self, hidden=[18, 20, 15]):
        super().__init__()
        layers = []
        in_d = 1
        for h in hidden:
            layers += [nn.Linear(in_d, h), nn.Tanh()]
            in_d = h
        layers += [nn.Linear(in_d, 1)]
        self.net = nn.Sequential(*layers)
    def forward(self, x): return self.net(x)

In [7]:
def compute_channel_stats(data_dir="./data"):
    """
    Compute mean and std of CIFAR-10 training set.
    Returns two lists: mean, std (each of length 3 for RGB).
    """
    # Load train set
    train_set = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True,
        transform=T.ToTensor()
    )
    loader = DataLoader(train_set, batch_size=5000, shuffle=False, num_workers=2)

    mean = 0.
    std = 0.
    nb_samples = 0

    for data, _ in loader:
        # data shape: [batch, channels, height, width]
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)  # flatten H*W
        mean += data.mean(2).sum(0)
        std  += data.std(2).sum(0)
        nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples

    return mean.tolist(), std.tolist()

In [8]:
def get_cifar10_loaders(
    data_dir="./data",
    batch_size=128,
    num_workers=2,
    drop_last=False
):
    # compute mean/std
    mean, std = compute_channel_stats(data_dir)
    print("CIFAR-10 stats:", mean, std)

    train_tfms = T.Compose([
        T.RandomCrop(32, padding=2),
        T.RandomHorizontalFlip(), 
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    test_tfms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    train_set = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_tfms
    )
    test_set = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=test_tfms
    )

    train_loader = DataLoader(
        train_set, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=drop_last
    )
    test_loader = DataLoader(
        test_set, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True, drop_last=False
    )
    return train_loader, test_loader, train_set.classes

In [9]:
class CNNModel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(inplace=True),   # <-- 'features.0' is first conv
            nn.MaxPool2d(2),  # 32x16x16
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 64x8x8
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*8*8, 128), nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )
    def forward(self, x): return self.classifier(self.features(x))


In [10]:
def get_data(task="function"):
    if task == "function":
        # single-input single-output generated data
        return train_loader, test_loader, input_dim, output_dim
    elif task == "cifar":
        # CIFAR-10 dataset
        return train_loader, test_loader, input_dim, output_dim

In [11]:
def get_model(task="function"):
    if task == "function":
        return SimpleFunctionModel(hidden_sizes=[64,64])
    elif task == "cifar":
        return CNNModel()

In [12]:
def plot_traj(xy: torch.Tensor, runs: int, pts_per_run: int, title: str, out_path: str, epoch_marks: List[int]):
    """Plot PCA trajectories for all runs."""
    plt.figure(figsize=(7.2, 5.6))
    cmap = plt.cm.viridis
    colors = cmap(np.linspace(0, 1, pts_per_run))
    for r in range(runs):
        s, e = r*pts_per_run, (r+1)*pts_per_run
        coords = xy[s:e].numpy()
        plt.plot(coords[:,0], coords[:,1], '-', alpha=0.7, linewidth=2)
        for i in range(pts_per_run):
            plt.scatter(coords[i,0], coords[i,1], s=50, c=colors[i].reshape(1, -1))
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=plt.Normalize(vmin=epoch_marks[0], vmax=epoch_marks[-1]))
    sm.set_array([])
    cbar = plt.colorbar(sm, pad=0.01)
    cbar.set_label("Epoch")
    plt.title(title)
    plt.xlabel("PCA-1"); plt.ylabel("PCA-2")
    plt.grid(True, linewidth=0.3)
    plt.tight_layout()
    plt.savefig(out_path, dpi=180)
    plt.close()

In [13]:
def train_epoch_function(model, loader, opt, device):
    model.train()
    loss_fn = nn.MSELoss() # MSE loss function for single-input single-output function
    total_loss = 0.0
    total = 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        pred = model(xb)
        loss = loss_fn(pred, yb)
        loss.backward()
        opt.step()
        total_loss += loss.item() * xb.size(0)
        total += xb.size(0)
    return total_loss / total  

@torch.no_grad()
def eval_cifar(model, loader, device):
    model.eval()
    total, correct, tot_loss = 0, 0, 0.0
    loss_fn = nn.CrossEntropyLoss(reduction='sum') # Cross entropy loss function for CIFAR dataset
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        tot_loss += loss_fn(logits, yb).item()
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
    return tot_loss/total, correct/total

def train_epoch_cifar(model, loader, opt, device):
    model.train()
    loss_fn = nn.CrossEntropyLoss()
    total, correct, tot_loss = 0, 0, 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        logits = model(xb)
        loss = loss_fn(logits, yb)
        loss.backward()
        opt.step()
        tot_loss += loss.item() * xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
    return tot_loss/total, correct/total

In [14]:
# Run one experiment (task flag)
def run_experiment(task: str,
                   runs: int = 8,
                   epochs: int = 12,
                   log_every: int = 3,
                   lr: float = 1e-3,
                   weight_decay: float = 5e-4,
                   batch_size: int = 128,
                   data_dir: str = "./data"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\n=== Task: {task} | device: {device} ===")

    epoch_marks = list(range(0, epochs+1, log_every))  
    pts_per_run = len(epoch_marks)

    # storage across runs
    layer_trajs: List[List[torch.Tensor]] = []
    whole_trajs: List[List[torch.Tensor]] = []
    metric_logs: List[List[float]] = []  # loss

    if task == "function":
        train_loader = make_function_loaders(batch=batch_size)
        layer_name = "net.0"  # first Linear
        for r in range(runs):
            set_seed(1000 + r)
            model = SimpleFunctionModel(hidden=[64, 64]).to(device)
            opt = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

            layer_ckpts = [flatten_layer(model, layer_name)]
            whole_ckpts = [flatten_whole(model)]
            metric_ckpt = []  

            for ep in range(1, epochs+1):
                tr_loss = train_epoch_function(model, train_loader, opt, device)
                if ep % log_every == 0:
                    metric_ckpt.append(tr_loss)
                    layer_ckpts.append(flatten_layer(model, layer_name))
                    whole_ckpts.append(flatten_whole(model))
                # print(f"[Run {r+1}/{runs}] Epoch {ep:02d}/{epochs} | TrainLoss {tr_loss:.4f}")

            layer_trajs.append(layer_ckpts)
            whole_trajs.append(whole_ckpts)
            metric_logs.append(metric_ckpt)

    elif task == "cifar":
        train_loader, test_loader, _ = get_cifar10_loaders(data_dir=data_dir, batch_size=batch_size)
        layer_name = "features.0"  # first Conv2d
        for r in range(runs):
            set_seed(2000 + r)
            model = CNNModel().to(device)
            opt = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

            layer_ckpts = [flatten_layer(model, layer_name)]
            whole_ckpts = [flatten_whole(model)]
            metric_ckpt = []  

            for ep in range(1, epochs+1):
                tr_loss, tr_acc = train_epoch_cifar(model, train_loader, opt, device)
                if ep % log_every == 0:
                    metric_ckpt.append(tr_loss)  
                    layer_ckpts.append(flatten_layer(model, layer_name))
                    whole_ckpts.append(flatten_whole(model))
                # print(f"[Run {r+1}/{runs}] Epoch {ep:02d}/{epochs} | TrainLoss {tr_loss:.4f} | TrainAcc {tr_acc*100:5.2f}%")

            layer_trajs.append(layer_ckpts)
            whole_trajs.append(whole_ckpts)
            metric_logs.append(metric_ckpt)

    else:
        raise ValueError("task must be 'function' or 'cifar'")
        
    # PCA fit on ALL checkpoints across ALL runs
    layer_mat = torch.stack([v for run in layer_trajs for v in run], dim=0)  
    whole_mat = torch.stack([v for run in whole_trajs for v in run], dim=0)
    layer_xy = pca2d(layer_mat)
    whole_xy = pca2d(whole_mat)

    # Plots
    plot_traj(layer_xy, runs, pts_per_run,
              f"Optimization Trajectories (First Layer) [{task}]",
              f"HW_1-2-VisualizeOptimization/opt_traj_layer_{task}.png",
              epoch_marks)
    plot_traj(whole_xy, runs, pts_per_run,
              f"Optimization Trajectories (Whole model) [{task}]",
              f"HW_1-2-VisualizeOptimization/opt_traj_whole_{task}.png",
              epoch_marks)

    print("\nSaved figures:")
    print(f" - HW_1-2-VisualizeOptimization/opt_traj_layer_{task}.png")
    print(f" - HW_1-2-VisualizeOptimization/opt_traj_whole_{task}.png")

In [15]:
def main(task="both",
                 runs=8,
                 epochs=12,
                 log_every=3,
                 lr=1e-3,
                 weight_decay=5e-4,
                 batch_size=128,
                 data_dir="./data"):

    if task in ("function", "both"):
        run_experiment("function", runs=runs, epochs=epochs, log_every=log_every,
                       lr=lr, weight_decay=weight_decay, batch_size=batch_size,
                       data_dir=data_dir)

    if task in ("cifar", "both"):
        run_experiment("cifar", runs=runs, epochs=epochs, log_every=log_every,
                       lr=lr, weight_decay=weight_decay, batch_size=batch_size,
                       data_dir=data_dir)

In [None]:
if __name__ == "__main__":
    main(task="function", epochs=1000)
    main(task="cifar", epochs=20)


=== Task: function | device: cuda ===


  cbar = plt.colorbar(sm, pad=0.01)



Saved figures:
 - HW_1-2-VisualizeOptimization/opt_traj_layer_function.png
 - HW_1-2-VisualizeOptimization/opt_traj_whole_function.png

=== Task: cifar | device: cuda ===
Files already downloaded and verified




CIFAR-10 stats: [0.4913996756076813, 0.4821583926677704, 0.44653093814849854] [0.20230092108249664, 0.19941280782222748, 0.20096160471439362]
Files already downloaded and verified
Files already downloaded and verified
