# GSS (Gradient-based Sample Selection)

In [1]:
# Import the libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms

from torchvision.models import resnet34 as torchvision_resnet34

from torch.utils.data import DataLoader, Subset, random_split, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
import datetime
from collections import defaultdict

# Choose the compute device (CUDA when a GPU is visible, otherwise CPU) and
# report it together with a timestamp.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
# NOTE(review): despite the label, this is the time the cell was executed,
# not the notebook file's modification time.
print(f"Notebook last modified at: {datetime.datetime.now():%Y-%m-%d %H:%M:%S}")

Using device: cpu
Notebook last modified at: 2025-07-28 23:25:44


## Implement GSS-Greedy exemplar selection

In [None]:
def compute_gradient(model, x, y, loss_fn, device):
    """Return the L2-normalized, flattened loss gradient for a single sample.

    The sample gets a batch dimension and is passed through ``model``, whose
    forward must return ``(features, logits)``. Then ``loss_fn(logits, y)`` is
    back-propagated. The gradients of every parameter that received one are
    concatenated into a 1-D tensor, which is divided by its L2 norm (plus
    ``1e-8`` to avoid division by zero).

    The pass runs in eval mode. Before returning, the model's previous
    train/eval mode is restored and its ``.grad`` buffers are cleared, so
    calling this does not leak state into a surrounding training loop.

    Args:
        model: ``nn.Module`` returning ``(features, logits)``.
        x: a single input tensor without a batch dimension.
        y: the label for ``x``; either a Python int or a tensor.
        loss_fn: callable ``loss_fn(logits, target)`` returning a scalar loss.
        device: device on which the forward/backward pass runs.

    Returns:
        1-D tensor with (approximately) unit L2 norm.

    Raises:
        RuntimeError: if no parameter received a gradient.
    """
    was_training = model.training
    model.eval()
    try:
        x = x.unsqueeze(0).to(device)
        # Dataset labels are often plain ints (e.g. torchvision datasets), which
        # have no ``unsqueeze``; convert before adding the batch dimension.
        y = torch.as_tensor(y).unsqueeze(0).to(device)
        model.zero_grad()
        # backward() needs autograd even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            _, logits = model(x)
            loss = loss_fn(logits, y)
            loss.backward()

        grads = [p.grad.flatten() for p in model.parameters() if p.grad is not None]
        if not grads:
            raise RuntimeError("compute_gradient: no parameter received a gradient")
        grad = torch.cat(grads)  # new tensor, unaffected by the zero_grad below
        return grad / (torch.norm(grad, p=2) + 1e-8)
    finally:
        # Do not leave this sample's gradient or eval mode behind for the caller.
        model.zero_grad()
        model.train(was_training)

def gss_greedy_selection(model, data_loader, num_exemplars_per_class, loss_fn, device):
    """GSS-Greedy exemplar selection by max-min distance in gradient space.

    Each class id in ``range(data_loader.dataset.num_classes)`` is handled
    independently:

    1. The normalized per-sample gradient (``compute_gradient``) of every
       sample with that label is computed.
    2. Exemplars are picked greedily. The first is chosen uniformly at random
       (``np.random``). Each later pick is the sample whose L2 distance to its
       nearest already-picked gradient is largest (farthest-point selection);
       ties go to the earliest sample.

    Args:
        model: network whose forward returns ``(features, logits)``. Its
            train/eval mode is restored on return.
        data_loader: only ``data_loader.dataset`` is used. The dataset must
            expose ``num_classes`` and yield ``(x, y)`` pairs with integer-like
            labels.
        num_exemplars_per_class: maximum number of exemplars per class; fewer
            are kept when a class has fewer samples.
        loss_fn: loss applied as ``loss_fn(logits, y)``.
        device: device used for the gradient computation.

    Returns:
        ``defaultdict(list)`` mapping class id to its selected ``(x, y)``
        samples. Classes with no samples (or a budget <= 0) get no entry.
    """
    dataset = data_loader.dataset
    was_training = model.training
    model.eval()
    try:
        # Group the whole dataset by label in one pass, instead of re-reading
        # (and re-transforming) it once per class. Per-class order follows the
        # dataset's iteration order.
        samples_by_class = defaultdict(list)
        for x, y in dataset:
            samples_by_class[int(y)].append((x, y))

        exemplars = defaultdict(list)
        for class_id in range(dataset.num_classes):
            class_samples = samples_by_class.get(class_id)
            if not class_samples:
                continue
            budget = min(num_exemplars_per_class, len(class_samples))
            if budget <= 0:
                continue  # nothing to select, so skip the costly gradients
            gradients = [
                compute_gradient(model, x, y, loss_fn, device) for x, y in class_samples
            ]
            for idx in _greedy_max_min_indices(gradients, budget):
                exemplars[class_id].append(class_samples[idx])
        return exemplars
    finally:
        model.train(was_training)


def _greedy_max_min_indices(gradients, num_select):
    """Return ``num_select`` indices: one at random, then farthest-point picks.

    Requires ``1 <= num_select <= len(gradients)``. Each pick after the first
    is the unselected index whose distance to the nearest selected gradient is
    largest. Ties go to the lowest index.
    """
    first = np.random.randint(0, len(gradients))
    selected = [first]
    is_selected = [False] * len(gradients)
    is_selected[first] = True
    # nearest[i] is the L2 distance from gradients[i] to the closest selected
    # gradient. Only the newest pick is folded in each round, so a round costs
    # O(n) norms instead of O(n * len(selected)).
    nearest = [torch.norm(g - gradients[first]).item() for g in gradients]
    while len(selected) < num_select:
        candidates = [i for i, taken in enumerate(is_selected) if not taken]
        idx = max(candidates, key=nearest.__getitem__)
        selected.append(idx)
        is_selected[idx] = True
        anchor = gradients[idx]
        for i in candidates:
            if i != idx:
                nearest[i] = min(nearest[i], torch.norm(gradients[i] - anchor).item())
    return selected