# ZSharp: Sharpness-Aware Minimization with Z-Score Gradient Filtering

This notebook demonstrates the ZSharp optimizer and compares it with standard SGD on CIFAR-10.

## What is ZSharp?

ZSharp extends SAM (Sharpness-Aware Minimization) with layer-wise, statistics-based gradient filtering:

1. **Z-score Normalization**: Normalizes gradients within each layer
2. **Percentile-based Filtering**: Keeps only the gradient components with the largest-magnitude z-scores; the percentile threshold is configurable (`percentile=70` in this demo)
3. **SAM Perturbation**: Applies filtered gradients to SAM's two-step optimization

**Key Benefits:**
- A reported 5.26% improvement over SGD on CIFAR-10 (see the paper linked at the end of this notebook)
- Better generalization with lower test loss
- Reduced gradient noise and improved training stability
- Memory efficient gradient processing

## Setup and Installation

In [None]:
! [ ! -d "zsharp" ] && git clone https://github.com/bangyen/zsharp.git
! cd zsharp && pip install -e .

print("Setup complete!")

## Imports and Configuration

In [None]:
import os

# Work from inside the cloned repository so the `src.*` imports below resolve.
# Skip the chdir when we are already there: re-running this cell would
# otherwise look for ./zsharp/zsharp and raise FileNotFoundError.
if os.path.basename(os.getcwd()) != "zsharp":
    os.chdir('./zsharp')
print(f"Current working directory: {os.getcwd()}")

In [None]:

import itertools

import numpy as np
import torch
import torchvision
from torch import nn, optim
from torchvision import transforms
from tqdm import tqdm

from src.models import get_model
from src.optimizer import ZSharp

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Fix the NumPy and PyTorch RNG seeds so random draws (e.g. weight
# initialisation) repeat across runs.
np.random.seed(42)
torch.manual_seed(42)

## ZSharp vs SGD


In [None]:
# Constants for demo limits
MAX_TRAINING_BATCHES = 50
MAX_EVAL_SAMPLES = 1000

# Quick demo - simplified version for Colab
def _cifar10_loaders(batch_size=128, num_workers=2):
    """Return ``(trainloader, testloader)`` for CIFAR-10 stored under ./data.

    The dataset is downloaded on first use. Training batches are shuffled;
    test batches keep their on-disk order.
    """
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    trainset = torchvision.datasets.CIFAR10(
        root="./data", train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )

    testset = torchvision.datasets.CIFAR10(
        root="./data", train=False, download=True, transform=transform
    )
    testloader = torch.utils.data.DataLoader(
        testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    return trainloader, testloader


def _progress(loader, max_batches, desc):
    """Wrap at most the first ``max_batches`` batches of ``loader`` in tqdm."""
    # islice stops after max_batches without loading (and discarding) one
    # extra batch, which an enumerate-then-break loop would do.
    return tqdm(
        itertools.islice(loader, max_batches),
        total=min(max_batches, len(loader)),
        desc=desc,
        leave=False,
    )


def _train_sgd_epoch(model, optimizer, criterion, loader, max_batches):
    """Run one truncated SGD epoch and return the mean training loss."""
    model.train()
    total_loss, num_batches = 0.0, 0
    pbar = _progress(loader, max_batches, "SGD Training")
    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({"loss": f"{batch_loss:.4f}"})
    # Average over the batches actually processed, not the cap, so a loader
    # shorter than max_batches still yields a true mean.
    return total_loss / max(num_batches, 1)


def _train_zsharp_epoch(model, optimizer, criterion, loader, max_batches):
    """Run one truncated two-pass ZSharp epoch and return the mean loss.

    The reported loss is the one from the first forward pass, i.e. at the
    unperturbed weights.
    """
    model.train()
    total_loss, num_batches = 0.0, 0
    pbar = _progress(loader, max_batches, "ZSharp Training")
    for data, target in pbar:
        data, target = data.to(device), target.to(device)
        # Pass 1: gradient at the current weights, consumed by first_step to
        # perturb the weights.
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.first_step()
        # Drop the pass-1 gradients so second_step is driven only by the
        # gradient at the perturbed weights, not by the sum of both passes.
        # This is harmless if first_step already cleared them.
        optimizer.zero_grad()
        # Pass 2: gradient at the perturbed weights, then the optimizer update.
        criterion(model(data), target).backward()
        optimizer.second_step()
        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({"loss": f"{batch_loss:.4f}"})
    return total_loss / max(num_batches, 1)


def _evaluate(model, loader, model_name, max_samples):
    """Return top-1 accuracy (%) of ``model`` on at most ``max_samples`` samples.

    Raises:
        ValueError: if ``loader`` yields no samples (or ``max_samples`` <= 0),
            so there is nothing to compute an accuracy from.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        pbar = tqdm(loader, desc=f"Evaluating {model_name}", leave=False)
        for data, target in pbar:
            remaining = max_samples - total
            if remaining <= 0:
                break
            # Trim the final batch so exactly max_samples samples are scored.
            data = data[:remaining].to(device)
            target = target[:remaining].to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
            pbar.set_postfix({"acc": f"{100 * correct / total:.2f}%"})
    if total == 0:
        raise ValueError(f"No evaluation samples available for {model_name}")
    return 100 * correct / total


def quick_demo(epochs=3, max_train_batches=None, max_eval_samples=None):
    """Train ResNet-18 on CIFAR-10 with SGD and with ZSharp, then compare them.

    Both models start from identical weights, so an accuracy gap reflects the
    optimizer rather than the random initialisation. Training and evaluation
    are truncated to keep the demo short.

    Args:
        epochs: Number of truncated training epochs per model (default 3).
        max_train_batches: Batches drawn per epoch; ``None`` means use
            ``MAX_TRAINING_BATCHES`` as defined when the function is called.
        max_eval_samples: Test samples scored per model; ``None`` means use
            ``MAX_EVAL_SAMPLES`` as defined when the function is called.

    Returns:
        A dict with the test accuracies (in percent) under ``"sgd_acc"`` and
        ``"zsharp_acc"``.
    """
    # Resolve the limits at call time so edits to the constants cell apply
    # without redefining this function.
    if max_train_batches is None:
        max_train_batches = MAX_TRAINING_BATCHES
    if max_eval_samples is None:
        max_eval_samples = MAX_EVAL_SAMPLES

    print("ZSharp Demo - Comparing with SGD")
    print("=" * 50)

    trainloader, testloader = _cifar10_loaders()

    # Create both models, then copy the SGD model's parameters and buffers
    # into the ZSharp model so both runs start from the same initialisation.
    model_sgd = get_model("resnet18", num_classes=10).to(device)
    model_zsharp = get_model("resnet18", num_classes=10).to(device)
    model_zsharp.load_state_dict(model_sgd.state_dict())

    optimizer_sgd = optim.SGD(
        model_sgd.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
    )
    optimizer_zsharp = ZSharp(
        model_zsharp.parameters(),
        base_optimizer=optim.SGD,
        rho=0.05,
        percentile=70,
        lr=0.01,
        momentum=0.9,
        weight_decay=1e-4,
    )

    criterion = nn.CrossEntropyLoss()

    print(f"Training for {epochs} epochs...")
    for epoch in range(epochs):
        print(f"\nEpoch {epoch + 1}/{epochs}")
        sgd_loss = _train_sgd_epoch(
            model_sgd, optimizer_sgd, criterion, trainloader, max_train_batches
        )
        zsharp_loss = _train_zsharp_epoch(
            model_zsharp, optimizer_zsharp, criterion, trainloader,
            max_train_batches,
        )
        print(f"SGD Loss: {sgd_loss:.4f}")
        print(f"ZSharp Loss: {zsharp_loss:.4f}")

    sgd_acc = _evaluate(model_sgd, testloader, "SGD", max_eval_samples)
    zsharp_acc = _evaluate(model_zsharp, testloader, "ZSharp", max_eval_samples)

    print("\nResults:")
    print(f"SGD Test Accuracy:     {sgd_acc:.2f}%")
    print(f"ZSharp Test Accuracy:  {zsharp_acc:.2f}%")
    # The difference of two percentages is in percentage points, not percent.
    print(f"Improvement:           {zsharp_acc - sgd_acc:+.2f} percentage points")

    if zsharp_acc > sgd_acc:
        print("ZSharp outperforms SGD!")
    elif sgd_acc > zsharp_acc:
        print("SGD outperforms ZSharp in this quick demo")
    else:
        print("Both methods perform similarly in this quick demo")

    return {"sgd_acc": sgd_acc, "zsharp_acc": zsharp_acc}


# Run the demo. The __main__ guard stops the multi-minute training run from
# starting when this code is imported as a module (e.g. after exporting the
# notebook to a .py file); inside a notebook cell __name__ is "__main__".
if __name__ == "__main__":
    quick_demo()

## Key Takeaways

This demo shows how ZSharp works:

1. **Gradient Filtering**: ZSharp filters out noisy gradients using Z-score normalization
2. **SAM Integration**: Combines filtered gradients with SAM's sharpness-aware optimization
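3. **Better Performance**: Typically achieves higher test accuracy than standard SGD. Note that this demo trains for only 3 epochs of 50 batches each, so its numbers illustrate usage and should not be read as a benchmark.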

### Next Steps

- Try the full training script: `python -m scripts.train --config configs/zsharp_baseline.yaml`
- Experiment with different hyperparameters (percentile, rho)
- Check out the [full paper](https://arxiv.org/html/2505.02369v3) for detailed results

For more information, visit the [GitHub repository](https://github.com/bangyen/zsharp).