In [1]:
import torch

In [2]:
# Dummy dataset: 100 samples with 5 features each and one scalar target per
# sample, all drawn from a standard normal distribution.
# The global torch RNG is seeded first so that a fresh "Restart & Run All"
# (cells executed in order) regenerates exactly the same tensors, and every
# later draw from the global RNG follows the same sequence.
torch.manual_seed(0)
X = torch.randn(100, 5)
y = torch.randn(100, 1)

In [3]:
# Linear model without bias: a single (5, 1) weight column, initialised from a
# standard normal and then flagged as a trainable leaf so autograd tracks it.
w = torch.randn(5, 1).requires_grad_(True)

In [4]:
# Training configuration
batch_size = 10  # number of samples per mini-batch
epochs = 5       # number of passes over the dataset
lr = 0.1         # step size applied to each update

In [5]:
# Mean-squared-error objective for the linear model
def loss_fn(X, y, w):
    """Return the mean squared error of the linear prediction ``X @ w``.

    Args:
        X: Input tensor that is matrix-multiplied with ``w``.
        y: Target tensor subtracted from the prediction (broadcast rules of
            tensor subtraction apply).
        w: Weight tensor of the linear model.

    Returns:
        A scalar tensor: the mean over all elements of ``(X @ w - y) ** 2``.
    """
    predictions = X @ w
    squared_error = (predictions - y) ** 2
    return squared_error.mean()

In [6]:
# SVRG (stochastic variance-reduced gradient) training loop.
#
# Per epoch:
#   1. Take a detached snapshot of the current weights and compute the gradient
#      of the loss over the full dataset at that snapshot (`full_grad`).
#   2. Sweep the shuffled data in mini-batches and step along
#          v = grad_batch(w) - grad_batch(w_snapshot) + full_grad
#
# Gradients are obtained with `torch.autograd.grad`, which returns them
# directly instead of accumulating into `.grad`. That removes the need for
# manual `.grad.zero_()` calls and for re-cloning the snapshot on every batch.
# Note: as a consequence this loop never writes to `w.grad`.
for epoch in range(epochs):
    # Reference point for this epoch: an independent leaf copy of w.
    w_snapshot = w.detach().clone().requires_grad_(True)
    (full_grad,) = torch.autograd.grad(loss_fn(X, y, w_snapshot), w_snapshot)

    # Shuffle data (same permutation for inputs and targets).
    perm = torch.randperm(X.size(0))
    X_shuffled = X[perm]
    y_shuffled = y[perm]

    for i in range(0, X.size(0), batch_size):
        X_batch = X_shuffled[i:i+batch_size]
        y_batch = y_shuffled[i:i+batch_size]

        # Mini-batch gradient at the current weights.
        (grad_w,) = torch.autograd.grad(loss_fn(X_batch, y_batch, w), w)

        # Mini-batch gradient at the snapshot, on the same batch.
        (grad_snapshot,) = torch.autograd.grad(
            loss_fn(X_batch, y_batch, w_snapshot), w_snapshot
        )

        # SVRG update: variance-reduced direction, applied in place.
        v = grad_w - grad_snapshot + full_grad
        # no_grad() permits the in-place step on a leaf that requires grad and
        # replaces the discouraged `w.data` mutation.
        with torch.no_grad():
            w -= lr * v

    # Reporting only: no graph is needed for the full-dataset loss.
    with torch.no_grad():
        epoch_loss = loss_fn(X, y, w).item()
    print(f"Epoch {epoch+1}, Loss: {epoch_loss}")

Epoch 1, Loss: 0.9968023896217346
Epoch 2, Loss: 0.937600314617157
Epoch 3, Loss: 0.9346346855163574
Epoch 4, Loss: 0.9345545172691345
Epoch 5, Loss: 0.9345430135726929
