## CIFAR-10

This notebook contains our experiment comparing forward gradient with backpropagation on the CIFAR-10 dataset.

#### Setup

In [None]:
# Run once
# CPU only: !pip install torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
# Install/upgrade to a pre-release (nightly, CPU) build of torch from the nightly wheel index ...
!pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade
# ... and then functorch (unpinned, so whatever version pip resolves at install time).
!pip install functorch
print("--> Restarting colab instance") 
# Ask the running kernel to shut down with restart=True.
# NOTE(review): presumably needed so the freshly installed packages are the ones
# that get imported afterwards -- confirm; cells below must be run after the restart.
get_ipython().kernel.do_shutdown(True)

In [None]:
# Fetch the project repository; the `cd forward_gradient` cell and the local
# imports (optim_functions, helpers, loss, optimizers, models, ...) come after it.
!git clone https://github.com/benjaminrike1/forward_gradient

In [None]:
cd forward_gradient

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import functorch as ft

import numpy as np
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

from optim_functions import beale, rosenbrock
from helpers import optimize
from plot_helpers import plot_loss, plot_countour, plot_contour2
from loss import functional_xent, softmax, clamp_probs, _xent
from optimizers import ForwardSGD
from models import Net, ConvNet, LogisticRegression

# Seed torch's global random number generator with a fixed value.
torch.manual_seed(0)


## CIFAR-10

In [None]:
# The only preprocessing applied to the images is the conversion to tensors.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Despite the `mnist_*` names, both datasets below are CIFAR-10
# (downloaded to /tmp/data if not already present).
mnist_train = torchvision.datasets.CIFAR10(
    '/tmp/data', train=True, download=True, transform=transform
)
mnist_test = torchvision.datasets.CIFAR10(
    '/tmp/data', train=False, download=True, transform=transform
)

# Mini-batches of 64 images, shuffled, for both splits.
train_data_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=64, shuffle=True
)
test_data_loader = torch.utils.data.DataLoader(
    mnist_test, batch_size=64, shuffle=True
)

## Neural Network

### SGD

In [None]:
# Use the GPU when torch reports CUDA as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

Forward gradient:

In [None]:
net = Net().to(device) # defining net

# making the net functional to run the code in functorch
# for evaluating the Jacobian-vector product
func, params = ft.make_functional(net)

# removing requires gradient as it will not be used
# for the forward AD
for param in params:
    param.requires_grad_(False)

# defining our optimizer
opt = ForwardSGD(func, functional_xent, params, lr=2e-4, momentum = False, decay=1e-5)

# running the code for e epochs
losses_fwd = []       # training loss, one entry per iteration (mini-batch)
test_losses_fwd = []  # mean test loss, one entry per epoch
epochs = 50
for e in range(epochs):
  # training
  for image, label in train_data_loader:
    image, label = image.to(device), label.to(device)
    params, loss = opt.step(image, label)
    losses_fwd.append(loss.item())
  # evaluating on the test set: collect the loss of every batch first and
  # average once per epoch (the list must NOT be reset inside the batch loop)
  batch_loss = []
  with torch.no_grad():
    for image, label in test_data_loader:
      image, label = image.to(device), label.to(device)
      test_loss = functional_xent(func, params, image, label)
      batch_loss.append(test_loss.item())
  epoch_test_loss = np.mean(batch_loss)
  # stored under the *_fwd name so the comparison plot can tell the
  # forward-gradient curve apart from the backprop one
  test_losses_fwd.append(epoch_test_loss)
  print(f"Test loss in epoch {e+1}: {epoch_test_loss}")

Backpropagation:

In [None]:
criterion = nn.CrossEntropyLoss() # loss function
net = Net().to(device) # defining net
# NOTE(review): weight_decay here is 1e-4 while the forward-gradient run uses
# decay=1e-5 -- confirm this difference is intended for the comparison.
backprop = torch.optim.SGD(net.parameters(), lr=2e-4, weight_decay=1e-4) # normal SGD in torch

epochs=50

# storing losses
losses = []       # training loss, one entry per iteration (mini-batch)
test_losses = []  # mean test loss, one entry per epoch
for epoch in range(epochs):
  # going over training set in batches
  for image, label in train_data_loader:
    image, label = image.to(device), label.to(device)
    backprop.zero_grad()
    outputs = net(image)
    loss = criterion(outputs, label)
    loss.backward()
    backprop.step()
    losses.append(loss.item())
  # evaluating on the test set: collect the loss of every batch first and
  # average once per epoch (the list must NOT be reset inside the batch loop)
  batch_loss = []
  with torch.no_grad():
    for image, label in test_data_loader:
      image, label = image.to(device), label.to(device)
      test_loss = criterion(net(image), label)
      batch_loss.append(test_loss.item())
  epoch_test_loss = np.mean(batch_loss)
  test_losses.append(epoch_test_loss)
  print(f"Test loss in epoch {epoch+1}: {epoch_test_loss}")

### Comparing results

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16, 8))

# Left panel: training losses; right panel: test losses.
# Each entry: (axes, x-axis label, backprop curve, forward-gradient curve).
panels = [
    (ax[0], "Iterations", losses, losses_fwd),
    (ax[1], "Epochs", test_losses, test_losses_fwd),
]
for panel, x_label, backprop_curve, forward_curve in panels:
    # backprop is drawn first so the legend order is Backprop, Forward gradient
    panel.plot(backprop_curve, color='r', label="Backprop", alpha=.7)
    panel.set_xlabel(x_label)
    panel.set_ylabel("Loss")
    panel.plot(forward_curve, color='b', label='Forward gradient', alpha=.7)

ax[0].legend()
ax[1].legend()

### Learning rate optimization

The final learning-rate search covers a fairly narrow interval: we first ran a wider search and then narrowed the range around the most promising values to locate a better optimum.

In [None]:
learning_rates = np.logspace(-5, -3, 3)
decays = np.logspace(-6, -4, 3)
epochs = 10

# Hold out 10% of the training set as a validation set so that the test set
# is not used for hyper-parameter selection. The split is seeded so every
# (lr, decay) pair is evaluated on the same validation images.
n_val = len(mnist_train) // 10
search_train, search_val = torch.utils.data.random_split(
    mnist_train,
    [len(mnist_train) - n_val, n_val],
    generator=torch.Generator().manual_seed(0),
)
search_train_loader = torch.utils.data.DataLoader(search_train,
                                                  batch_size=64,
                                                  shuffle=True)
val_data_loader = torch.utils.data.DataLoader(search_val,
                                              batch_size=64,
                                              shuffle=False)

# (learning rate, decay) -> list with the mean validation loss of every epoch
search_results = {}

for gamma in learning_rates:
  for lambda_ in decays:
    net = Net().to(device) # defining net

    # making the net functional to run the code in functorch
    # for evaluating the Jacobian-vector product
    func, params = ft.make_functional(net)

    # removing requires gradient as it will not be used
    # for the forward AD
    for param in params:
        param.requires_grad_(False)

    # defining our optimizer (numpy scalars are passed on as plain Python floats)
    opt = ForwardSGD(func, functional_xent, params, lr=float(gamma), momentum = False, decay=float(lambda_))

    # running the code for e epochs
    val_losses = []
    for e in range(epochs):
      # training
      for image, label in search_train_loader:
        image, label = image.to(device), label.to(device)
        params, loss = opt.step(image, label)
      # evaluating on the validation set: average over all batches, once per epoch
      batch_loss = []
      with torch.no_grad():
        for image, label in val_data_loader:
          image, label = image.to(device), label.to(device)
          val_loss = functional_xent(func, params, image, label)
          batch_loss.append(val_loss.item())
      val_losses.append(np.mean(batch_loss))
      print(f"lr={gamma:.0e}, decay={lambda_:.0e} | validation loss in epoch {e+1}: {val_losses[-1]}")
    search_results[(float(gamma), float(lambda_))] = val_losses

# Pick the configuration with the lowest final validation loss, ignoring runs
# that diverged to inf/nan.
final_losses = {cfg: curve[-1] for cfg, curve in search_results.items()
                if np.isfinite(curve[-1])}
if final_losses:
  best_lr, best_decay = min(final_losses, key=final_losses.get)
  print(f"Best configuration: lr={best_lr:.0e}, decay={best_decay:.0e} "
        f"(validation loss {final_losses[(best_lr, best_decay)]})")
else:
  print("No configuration finished with a finite validation loss")

## Conv Net

Forward gradient:

In [None]:
net = ConvNet().to(device) # defining net

# making the net functional to run the code in functorch
# for evaluating the Jacobian-vector product
func, params = ft.make_functional(net)

# removing requires gradient as it will not be used
# for the forward AD
for param in params:
    param.requires_grad_(False)

# defining our optimizer
opt = ForwardSGD(func, functional_xent, params, lr=2e-4, momentum = False, decay=1e-5)

# running the code for e epochs
losses_fwd = []       # training loss, one entry per iteration (mini-batch)
# kept under the *_fwd name: the backprop cell re-creates `test_losses`, and the
# comparison plot reads the forward-gradient curve from `test_losses_fwd`
test_losses_fwd = []  # mean test loss, one entry per epoch
epochs = 50
for e in range(epochs):
  # training
  for image, label in train_data_loader:
    image, label = image.to(device), label.to(device)
    params, loss = opt.step(image, label)
    losses_fwd.append(loss.item())
  # evaluating on the test set: collect the loss of every batch first and
  # average once per epoch (the list must NOT be reset inside the batch loop)
  batch_loss = []
  with torch.no_grad():
    for image, label in test_data_loader:
      image, label = image.to(device), label.to(device)
      test_loss = functional_xent(func, params, image, label)
      batch_loss.append(test_loss.item())
  epoch_test_loss = np.mean(batch_loss)
  test_losses_fwd.append(epoch_test_loss)
  print(f"Test loss in epoch {e+1}: {epoch_test_loss}")

Backpropagation:

In [None]:
criterion = nn.CrossEntropyLoss() # loss function
net = ConvNet().to(device) # defining net
# NOTE(review): weight_decay here is 1e-4 while the forward-gradient run uses
# decay=1e-5 -- confirm this difference is intended for the comparison.
backprop = torch.optim.SGD(net.parameters(), lr=2e-4, weight_decay=1e-4) # normal SGD in torch

epochs=50

# storing losses
losses = []       # training loss, one entry per iteration (mini-batch)
test_losses = []  # mean test loss, one entry per epoch
for epoch in range(epochs):
  # going over training set in batches
  for image, label in train_data_loader:
    image, label = image.to(device), label.to(device)
    backprop.zero_grad()
    outputs = net(image)
    loss = criterion(outputs, label)
    loss.backward()
    backprop.step()
    losses.append(loss.item())
  # evaluating on the test set: collect the loss of every batch first and
  # average once per epoch (the list must NOT be reset inside the batch loop)
  batch_loss = []
  with torch.no_grad():
    for image, label in test_data_loader:
      image, label = image.to(device), label.to(device)
      test_loss = criterion(net(image), label)
      batch_loss.append(test_loss.item())
  epoch_test_loss = np.mean(batch_loss)
  test_losses.append(epoch_test_loss)
  print(f"Test loss in epoch {epoch+1}: {epoch_test_loss}")
    

### Comparing results

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(16, 8))

# Left panel: training losses; right panel: test losses.
# Each entry: (axes, x-axis label, backprop curve, forward-gradient curve).
panels = [
    (ax[0], "Iterations", losses, losses_fwd),
    (ax[1], "Epochs", test_losses, test_losses_fwd),
]
for panel, x_label, backprop_curve, forward_curve in panels:
    # backprop is drawn first so the legend order is Backprop, Forward gradient
    panel.plot(backprop_curve, color='r', label="Backprop", alpha=.7)
    panel.set_xlabel(x_label)
    panel.set_ylabel("Loss")
    panel.plot(forward_curve, color='b', label='Forward gradient', alpha=.7)

ax[0].legend()
ax[1].legend()