In [1]:
# import necessary libraries
import os, math, random
from dataclasses import dataclass
from typing import Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as T

from torch.utils.data import DataLoader, TensorDataset, Subset
import matplotlib.pyplot as plt

In [2]:
OUT_DIR = "HW_1-2-MinimalRatio"
os.makedirs(OUT_DIR, exist_ok=True)

In [3]:
def get_device():
    if torch.cuda.is_available(): return torch.device("cuda")
    return torch.device("cpu")

In [4]:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

In [5]:
# Task 1: Single-input Single-output function
def f_true(x: torch.Tensor) -> torch.Tensor:
    return torch.cos(2* math.pi*x) * (x**3)

def make_function_loaders(xmin=-3.0, xmax=3.0, n_train=256, n_eval=800, batch=128):
    x_tr = torch.linspace(xmin, xmax, n_train).unsqueeze(1)
    y_tr = f_true(x_tr)
    x_ev = torch.linspace(xmin, xmax, n_eval).unsqueeze(1)
    y_ev = f_true(x_ev)
    train = DataLoader(TensorDataset(x_tr, y_tr), batch_size=batch, shuffle=True, drop_last=False)
    eval_loader = DataLoader(TensorDataset(x_ev, y_ev), batch_size=256, shuffle=False)
    return train, eval_loader

class SimpleFunctionModel(nn.Module):
    def __init__(self, hidden=[18, 20, 15]):
        super().__init__()
        layers = []
        in_d = 1
        for h in hidden:
            layers += [nn.Linear(in_d, h), nn.Tanh()]
            in_d = h
        layers += [nn.Linear(in_d, 1)]
        self.net = nn.Sequential(*layers)
    def forward(self, x): return self.net(x)

In [6]:
def compute_channel_stats(data_dir="./data"):
    """
    Compute mean and std of CIFAR-10 training set.
    Returns two lists: mean, std (each of length 3 for RGB).
    """
    # Load train set
    train_set = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True,
        transform=T.ToTensor()
    )
    loader = DataLoader(train_set, batch_size=5000, shuffle=False, num_workers=2)

    mean = 0.
    std = 0.
    nb_samples = 0

    for data, _ in loader:
        # data shape: [batch, channels, height, width]
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)  # flatten H*W
        mean += data.mean(2).sum(0)
        std  += data.std(2).sum(0)
        nb_samples += batch_samples

    mean /= nb_samples
    std /= nb_samples

    return mean.tolist(), std.tolist()

In [7]:
def get_cifar10_loaders(
    data_dir="./data",
    batch_size=128,
    subset_train: int = 5000, 
    subset_eval: int = 2000, 
    num_workers=2,
    drop_last=False
):
    # compute mean/std
    mean, std = compute_channel_stats(data_dir)
    print("CIFAR-10 stats:", mean, std)

    train_tfms = T.Compose([
        T.RandomCrop(32, padding=2),
        T.RandomHorizontalFlip(), 
        T.ToTensor(),
        T.Normalize(mean, std),
    ])
    test_tfms = T.Compose([
        T.ToTensor(),
        T.Normalize(mean, std),
    ])

    full_train = torchvision.datasets.CIFAR10(
        root=data_dir, train=True, download=True, transform=train_tfms
    )
    full_test = torchvision.datasets.CIFAR10(
        root=data_dir, train=False, download=True, transform=test_tfms
    )

    idx_tr = list(range(min(subset_train, len(full_train))))
    idx_ev = list(range(min(subset_eval, len(full_test))))
    train_loader = DataLoader(Subset(full_train, idx_tr), batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
    test_loader = DataLoader(Subset(full_test,  idx_ev), batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
   
    return train_loader, test_loader

In [8]:
class CNNModel(nn.Module):
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(inplace=True),   # <-- 'features.0' is first conv
            nn.MaxPool2d(2),  # 32x16x16
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 64x8x8
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64*8*8, 128), nn.ReLU(inplace=True),
            nn.Linear(128, num_classes),
        )
    def forward(self, x): return self.classifier(self.features(x))

In [9]:
@torch.no_grad()
def grad_norm(model: nn.Module) -> float:
    grad_all = 0.0
    for p in model.parameters():
        if p.grad is not None:
            grad_all += (p.grad.detach() ** 2).sum().item()
    grad_norm = grad_all ** 0.5
    return grad_norm

In [10]:
def eval_average_loss(model: nn.Module, loss_fn, loader: DataLoader, device: torch.device) -> float:
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            pred = model(xb)
            total += loss_fn(pred, yb).item() * xb.size(0)
            n += xb.size(0)
    return total / max(1, n)

In [11]:
def grad_norm_sq_from_loss(loss: torch.Tensor, params: List[torch.nn.Parameter]) -> torch.Tensor:
    """
    Compute sum of squared gradients of `loss` wrt `params`.
    """
    grads = torch.autograd.grad(loss, params, create_graph=True, retain_graph=True, allow_unused=True)
    terms = []
    for g in grads:
        if g is not None:
            terms.append((g**2).sum())
    return torch.stack(terms).sum() if terms else loss*0.0  # safe zero

In [12]:
def minimal_ratio_sampling(model: nn.Module,
                           loss_fn,
                           eval_loader: DataLoader,
                           device: torch.device,
                           noise_std: float = 1e-3,
                           samples: int = 50) -> Tuple[float, float]:
    """
    Sampling-based minimal ratio around current weights.
    Returns (base_loss, ratio).
    """
    base_loss = eval_average_loss(model, loss_fn, eval_loader, device)

    # snapshot current params
    base_state = {k: v.clone() for k, v in model.state_dict().items()}
    worse = 0
    total = 0

    for _ in range(samples):
        # perturb
        with torch.no_grad():
            for p in model.parameters():
                p.add_(torch.randn_like(p) * noise_std)

        loss_p = eval_average_loss(model, loss_fn, eval_loader, device)
        total += 1
        if loss_p > base_loss:  
            worse += 1

        # restore
        model.load_state_dict(base_state, strict=True)

    ratio = worse / max(1, total)
    return base_loss, ratio

In [13]:
@dataclass
class Config:
    runs: int = 100
    stage1_epochs: int = 20        # task loss epochs
    stage2_steps: int = 50         # grad-norm objective steps
    stage2_tol: float = 1e-6       # stop if grad-norm^2 < tol
    lr: float = 1e-3
    weight_decay: float = 0.0
    batch: int = 128
    noise_std: float = 1e-3        # for sampling perturbations
    samples: int = 50              # # of perturbations per run

In [14]:
def reach_stationary_point(model: nn.Module,
                           train_loader: DataLoader,
                           loss_fn,
                           cfg: Config,
                           device: torch.device):
    """
    Two-stage optimization:
      Stage-1: minimize task loss.
      Stage-2: minimize gradient-norm
    Returns a model at (near) stationary point.
    """
    model.to(device)
    opt = optim.Adam(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)

    # training on task loss
    for _ in range(cfg.stage1_epochs):
        model.train()
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()
            pred = model(xb)
            loss = loss_fn(pred, yb)
            loss.backward()
            opt.step()

    # minimize grad-norm squared of the same loss
    # We keep using batches from the same train_loader
    for step in range(cfg.stage2_steps):
        model.train()
        total_g2 = 0.0
        for xb, yb in train_loader:
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad()

            pred = model(xb)
            loss = loss_fn(pred, yb)

            g2 = grad_norm_sq_from_loss(loss, list(model.parameters()))  # ∑||∇θ L||^2
            g2.backward()
            opt.step()

            total_g2 += g2.detach().item()

        # early stop if gradient norm is tiny (average over batches)
        avg_g2 = total_g2 / max(1, len(train_loader))
        if avg_g2 < cfg.stage2_tol:
            break

    return model

In [15]:
def run_task_FUNCTION(cfg: Config):
    dev = get_device()
    set_seed(42)

    train_loader, eval_loader = make_function_loaders(batch=cfg.batch)
    loss_fn = nn.MSELoss()

    losses, ratios = [], []
    for r in range(cfg.runs):
        set_seed(1000 + r)
        model = SimpleFunctionModel()
        model = reach_stationary_point(model, train_loader, loss_fn, cfg, dev)
        base_loss, ratio = minimal_ratio_sampling(model, loss_fn, eval_loader, dev,
                                                  noise_std=cfg.noise_std, samples=cfg.samples)
        losses.append(base_loss); ratios.append(ratio)

    # Plot loss vs minimal ratio
    plt.figure(figsize=(6.8, 5.2))
    plt.scatter(ratios, losses, s=20)
    plt.xlabel("minimal ratio")
    plt.ylabel("loss")
    plt.title("FUNCTION: minimal ratio vs loss")
    plt.grid(True, linewidth=0.3)
    out = f"{OUT_DIR}/function_minratio_vs_loss.png"
    plt.tight_layout(); plt.savefig(out, dpi=180); plt.close()
    print(f"[FUNCTION] saved: {out}")

In [16]:
def run_task_CIFAR(cfg: Config, data_dir="./data"):
    dev = get_device()
    set_seed(42)

    train_loader, eval_loader = get_cifar10_loaders(data_dir=data_dir, batch_size=cfg.batch,
                                                  subset_train=5000, subset_eval=2000)
    loss_fn = nn.CrossEntropyLoss()

    losses, ratios = [], []
    for r in range(cfg.runs):
        set_seed(2000 + r)
        model = CNNModel()
        model = reach_stationary_point(model, train_loader, loss_fn, cfg, dev)
        base_loss, ratio = minimal_ratio_sampling(model, loss_fn, eval_loader, dev,
                                                  noise_std=cfg.noise_std, samples=cfg.samples)
        losses.append(base_loss); ratios.append(ratio)

    plt.figure(figsize=(6.8, 5.2))
    plt.scatter(ratios, losses, s=20)
    plt.xlabel("minimal ratio")
    plt.ylabel("loss")
    plt.title("CIFAR-10: minimal ratio vs loss (subset)")
    plt.grid(True, linewidth=0.3)
    out = f"{OUT_DIR}/cifar_minratio_vs_loss.png"
    plt.tight_layout(); plt.savefig(out, dpi=180); plt.close()
    print(f"[CIFAR] saved: {out}")

In [17]:
def main(task="both",
         runs=100,
         # Stage-1 (= train on task loss) epochs
         task_loss_epochs=10,
         # Stage-2 (= minimize grad-norm^2) steps
         grad_norm_steps=30,
         lr=1e-3, weight_decay=0.0,
         batch=128,
         noise_std=1e-3, samples=50,
         data_dir="./data"):
    """
    Run the 'minimal ratio vs loss' experiment.

    task: "function", "cifar", or "both"
    runs: number of independent runs (100)
    _epochs: epochs for stage-1 (task loss)
    _steps:  steps for stage-2 (grad-norm^2 objective)
    noise_std: perturbation std for minimal-ratio sampling
    samples:   number of perturbation samples per run
    """
    # build per-task configs
    cfg_func = Config(
        runs=runs, stage1_epochs=task_loss_epochs, stage2_steps=grad_norm_steps,
        stage2_tol=1e-6, lr=lr, weight_decay=weight_decay,
        batch=batch, noise_std=noise_std, samples=samples
    )
    cfg_cifar = Config(
        runs=runs, stage1_epochs=task_loss_epochs, stage2_steps=grad_norm_steps,
        stage2_tol=1e-6, lr=lr, weight_decay=weight_decay,
        batch=batch, noise_std=noise_std, samples=samples
    )

    if task in ("function", "both"):
        run_task_FUNCTION(cfg_func)

    if task in ("cifar", "both"):
        run_task_CIFAR(cfg_cifar, data_dir=data_dir)

In [None]:
if __name__ == "__main__":
    # Function
    main(task="function")

    # CIFAR
    main(task="cifar")

[FUNCTION] saved: HW_1-2-MinimalRatio/function_minratio_vs_loss.png
Files already downloaded and verified




CIFAR-10 stats: [0.4913996756076813, 0.4821583926677704, 0.44653093814849854] [0.20230092108249664, 0.19941280782222748, 0.20096160471439362]
Files already downloaded and verified
Files already downloaded and verified
