In [None]:
# Run once (Colab): install a nightly CPU build of PyTorch plus functorch, then
# restart the kernel so the newly installed versions are the ones imported.
# NOTE(review): neither install is pinned, so results may drift between runs —
# consider pinning versions (and using %pip so the kernel's environment is targeted).
# CPU only: !pip install torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -f https://download.pytorch.org/whl/cpu/torch_stable.html
!pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --upgrade
!pip install functorch
print("--> Restarting colab instance") 
get_ipython().kernel.do_shutdown(True)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torchvision
import functorch as ft

import numpy as np
from functools import partial

import matplotlib.pyplot as plt
import seaborn as sns

from optim_functions import beale, rosenbrock
from helpers import optimize
from plot_helpers import plot_loss, plot_countour, plot_contour2
from loss import functional_xent, softmax, clamp_probs, _xent
from optimizers import forwardSGD
from models import Net, ConvNet, LogisticRegression

# Seed torch's global RNG; the experiment cells below reseed before drawing their own start points.
torch.manual_seed(0)


## Test functions

In [None]:
# Beale function: reverse-mode (autograd SGD) vs. forward-gradient optimisation
# from the same random starting point.
steps = 100
torch.manual_seed(0)
primal0 = torch.randn(1)  # x input
primal1 = torch.randn(1)  # y input

# Reverse mode needs gradient-tracking leaf tensors. Cloning keeps any in-place
# updates made by `optimize` away from the primals reused by the forward run.
x = primal0.clone().requires_grad_()
y = primal1.clone().requires_grad_()
params = (x, y)

loss_rev, grad_rev, params_rev = optimize(beale, params, steps, optimizer="SGD", lr=0.05)
plot_loss(loss_rev, steps)

# Keep the forward-mode trajectory separate from the reverse-mode one.
loss_fwd, grad_fwd, params_fwd = optimize(beale, (primal0, primal1), steps, lr=0.03)
plot_loss(loss_fwd, steps)

In [None]:
# Beale function, forward-gradient run: trajectory drawn over the contour plot.
steps = 1000
torch.manual_seed(42)
primal0 = torch.randn(1)  # x input
primal1 = torch.randn(1)  # y input

# Gradient-tracking copies of the same start point, consumed by the reverse-mode
# run in the next cell.
x = primal0.clone().requires_grad_()
y = primal1.clone().requires_grad_()
params = (x, y)

loss_fwd, grad_fwd, params_fwd = optimize(beale, (primal0, primal1), steps, lr=0.01)
plot_contour2(loss_fwd, params_fwd, beale, (-1, 4), (-1, 2))

In [None]:
# Reverse-mode (autograd SGD) run from the gradient-tracking start point `params`
# built in the previous cell.
n_rev_steps = 1000
loss_rev, grad_rev, params_rev = optimize(
    beale, params, n_rev_steps, optimizer="SGD", lr=0.01
)
plot_contour2(loss_rev, params_rev, beale, (-1, 4), (-1, 2))

In [None]:
# Rosenbrock function, forward-gradient run (the reverse-mode cell below starts
# from the same seed-9 point).
steps = 25000
torch.manual_seed(9)
primal0 = torch.randn(1)  # x input
primal1 = torch.randn(1)  # y input

loss_fwd, grad_fwd, params_fwd = optimize(rosenbrock, (primal0, primal1), steps, lr=5e-4)
plot_contour2(loss_fwd, params_fwd, rosenbrock, (-1.1, 1.1), (0, 2))

In [None]:
# Rosenbrock function, reverse-mode (autograd SGD) run from the seed-9 start point.
torch.manual_seed(9)
primal0 = torch.randn(1)  # x input
primal1 = torch.randn(1)  # y input

# Gradient-tracking leaf copies; the primals themselves are left untouched.
x, y = (p.clone().requires_grad_() for p in (primal0, primal1))
params = (x, y)

loss_rev, grad_rev, params_rev = optimize(
    rosenbrock, params, 25000, optimizer="SGD", lr=5e-4
)
plot_contour2(loss_rev, params_rev, rosenbrock, (-1.1, 1.1), (0, 2))

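Reproducing Figure 1 of the forward-gradient paper (Baydin et al., 2022, "Gradients without Backpropagation"): forward-gradient samples of the Beale function at $(x, y) = (1.5, -0.1)$.

True gradient at that point: $\nabla f = (-3.434, -0.808)$.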


In [None]:
# Forward-gradient samples of the Beale function at (1.5, -0.1). For a random
# tangent v, ft.jvp returns f and the directional derivative (grad f . v);
# each sample is (grad f . v) * v.
primals = (torch.tensor([1.5]), torch.tensor([-0.1]))
steps = 10  # number of tangent samples (also sizes the plot below)
torch.manual_seed(0)  # fix the sampled tangents so the figure is reproducible
tangents = []
fwd_grads = []
for _ in range(steps):
  tangent = (torch.randn(1), torch.randn(1))
  _, jvp = ft.jvp(beale, primals, tangent)
  fwd_grads.append([jvp.mul(t).item() for t in tangent])
  tangents.append([t.item() for t in tangent])
avg_grad = np.mean(fwd_grads, axis=0)
std_grad = np.std(fwd_grads, axis=0)
avg1, avg2 = avg_grad


In [None]:
fig, ax = plt.subplots(figsize=(8, 8))
V = np.asarray(tangents)   # (steps, 2) perturbation vectors
W = np.asarray(fwd_grads)  # (steps, 2) forward-gradient samples
origin = np.zeros((2, steps))  # every arrow starts at (0, 0)
optimal1, optimal2 = -3.434, -0.808  # true gradient of beale at (1.5, -0.1)
arrow_kw = dict(angles='xy', scale_units='xy')

ax.quiver(*origin, V[:, 0], V[:, 1], color='y', label="Perturbation vectors", scale=1, **arrow_kw)
# Forward-gradient samples are drawn at 1/3 length (scale=3) so they fit in the axes.
ax.quiver(*origin, W[:, 0], W[:, 1], color='r', label="Forward gradients", scale=3, **arrow_kw)
ax.quiver(0, 0, optimal1, optimal2, color='b', label="Real gradient", scale=1, **arrow_kw)
ax.quiver(0, 0, avg1, avg2, color='g', label='Avg. forward gradient', scale=1, **arrow_kw)

ax.legend()
ax.set_xlim(-4, 2)
ax.set_ylim(-4, 2)
plt.show()

## Logistic Regression

#### Import MNIST

In [None]:
# MNIST train/test splits converted to tensors with ToTensor, cached under DATA_ROOT.
DATA_ROOT = '/tmp/data'
BATCH_SIZE = 64

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor()
])

mnist_train = torchvision.datasets.MNIST(
    DATA_ROOT,
    train=True,
    download=True,
    transform=transform
)
train_data_loader = torch.utils.data.DataLoader(mnist_train,
                                                batch_size=BATCH_SIZE,
                                                shuffle=True)

mnist_test = torchvision.datasets.MNIST(
    DATA_ROOT,
    train=False,
    download=True,
    transform=transform
)
# Evaluation does not depend on sample order, so keep the test loader deterministic.
test_data_loader = torch.utils.data.DataLoader(mnist_test,
                                               batch_size=BATCH_SIZE,
                                               shuffle=False)

#### MNIST: forward gradient vs. backprop

Uses helpers from [fwdgrad](https://github.com/orobix/fwdgrad).

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Functional (stateless) version of the ConvNet plus its parameter tensors.
net = ConvNet().to(device)
func, params = ft.make_functional(net)

# Gradient tracking is switched off; presumably forwardSGD updates these
# tensors directly via forward-mode products — confirm in optimizers.py.
for p in params:
    p.requires_grad_(False)

opt = forwardSGD(func, functional_xent, params, lr=2e-4, momentum=False)


In [None]:
# Forward-gradient training of the ConvNet; records one loss value per minibatch.
losses_fwd = []
epochs = 2
for epoch in range(epochs):
  for image, label in train_data_loader:
    image, label = image.to(device), label.to(device)
    _, loss = opt.step(image, label)
    # Store a plain Python float (as the backprop loop does with .item()) so the
    # list does not hold device tensors and can be plotted directly.
    losses_fwd.append(float(loss))

In [None]:
# Backprop baseline: plain SGD with weight decay, one loss value per minibatch.
# NOTE(review): this trains `Net` while the forward-gradient run above trains
# `ConvNet`, and forwardSGD is constructed without a weight_decay argument —
# confirm the comparison is intended.
criterion = nn.CrossEntropyLoss()
# Inputs are moved to `device` below, so the model must live there too.
net = Net().to(device)
backprop = torch.optim.SGD(net.parameters(), lr=2e-4, weight_decay=1e-4)
epochs = 2
losses = []
for epoch in range(epochs):
  for image, label in train_data_loader:
    image, label = image.to(device), label.to(device)
    backprop.zero_grad()
    outputs = net(image)
    loss = criterion(outputs, label)
    loss.backward()
    backprop.step()
    losses.append(loss.item())

In [None]:
# Minibatch training-loss curves: backprop vs. forward gradient.
# (A single panel: the previous 1x2 grid left the second axes empty.)
fig, ax = plt.subplots(figsize=(8, 8))
ax.plot(losses, color='r', label="Backprop", alpha=.7)
ax.plot(losses_fwd, color='b', label='Forward gradient', alpha=.7)
ax.set(xlabel="Iterations", ylabel="Loss", title="MNIST training loss")
ax.legend();

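#### Logistic regression: define models and train

For each sampled learning rate, train a logistic-regression model with forward gradients and an identically initialised baseline with backprop (SGD), recording loss and test accuracy after every epoch.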

In [None]:
from timeit import default_timer

# Constants
EPOCHS = 5
N_LEARNING_RATES = 6
# Log-uniform learning rates in [1e-5, 1). A seeded generator makes the sweep
# reproducible; sorting puts the runs (and the subplots below) in increasing-lr order.
learning_rates = np.sort(10 ** np.random.default_rng(0).uniform(-5, 0, size=N_LEARNING_RATES))


def _make_models():
  """Return a fresh (forward-gradient model, backprop baseline) pair with identical initial weights."""
  fwd_model = LogisticRegression(784, 10).double().to(device)
  bwd_model = LogisticRegression(784, 10).double().to(device)
  bwd_model.load_state_dict(fwd_model.state_dict())
  return fwd_model, bwd_model


def _test_accuracy(m):
  """Return the fraction of test-set images that `m` classifies correctly."""
  # NOTE(review): images are fed uncast to the float64 models, exactly as in
  # training; this assumes LogisticRegression handles the input dtype — confirm.
  correct = 0
  with torch.no_grad():
    for image, label in test_data_loader:
      pred = m(image.to(device)).argmax(dim=1)
      correct += int((pred == label.to(device)).sum())
  # Normalise by the real sample count (the last batch may be partial).
  return correct / len(test_data_loader.dataset)


def functional_loss(params, buffers, fmodel, input, target):
  """Cross-entropy (`_xent`) of the functional model `fmodel` run with `params` and `buffers`."""
  return _xent(fmodel(params, buffers, input), target)


baseline_criterion = torch.nn.CrossEntropyLoss()
baseline_losses, baseline_accs = [[] for _ in learning_rates], [[] for _ in learning_rates]
losses, accs = [[] for _ in learning_rates], [[] for _ in learning_rates]

for n, lr in enumerate(learning_rates.tolist()):
  print(f'Using learning rate: {lr}')
  # Fresh, identically initialised models and a baseline optimizer bound to *this*
  # lr, so every learning rate is an independent run.
  model, baseline_model = _make_models()
  baseline_optimizer = torch.optim.SGD(baseline_model.parameters(), lr=lr)

  # Functional copy of `model`; its parameters are updated by hand below.
  func, params, buffers = ft.make_functional_with_buffers(model)
  # Updates are applied without autograd, so stop tracking gradients on them.
  params = tuple(p.requires_grad_(False) for p in params)

  for epoch in range(EPOCHS):
    t = default_timer()
    fwd_loss_sum, bwd_loss_sum, n_batches = 0.0, 0.0, 0

    for image, label in train_data_loader:
      image, label = image.to(device), label.to(device)

      # Forward-gradient step. With v ~ N(0, I), (grad . v) v is an unbiased
      # estimate of grad; uniform U[0, 1) tangents (torch.rand_like) are not
      # zero-mean and would bias the update.
      tangents = tuple(torch.randn_like(p) for p in params)
      # Bind everything except the parameters, which jvp differentiates.
      f = partial(functional_loss, buffers=buffers, fmodel=func, input=image, target=label)
      loss, jvp = ft.jvp(f, (params,), (tangents,))
      with torch.no_grad():
        for param, tangent, target_param in zip(params, tangents, model.parameters()):
          param.sub_(lr * jvp * tangent)
          # Mirror the update into `model` so it can be evaluated as an nn.Module.
          target_param.copy_(param)

      # Backprop baseline step on the same minibatch. Class-index targets give the
      # same cross-entropy as one-hot probability targets.
      baseline_optimizer.zero_grad()
      output = baseline_model(image)
      baseline_loss = baseline_criterion(output.double(), label)
      baseline_loss.backward()
      baseline_optimizer.step()

      fwd_loss_sum += loss.item()
      bwd_loss_sum += baseline_loss.item()
      n_batches += 1

    # Per-epoch records: mean training loss (less noisy than the last minibatch)
    # and accuracy. TODO: should use a validation split, not the test set.
    n_batches = max(n_batches, 1)
    baseline_losses[n].append(bwd_loss_sum / n_batches)
    baseline_accs[n].append(_test_accuracy(baseline_model))
    losses[n].append(fwd_loss_sum / n_batches)
    accs[n].append(_test_accuracy(model))

    t = default_timer() - t
    print(f'Epoch {epoch + 1} finished in {t:.2f} seconds with loss: {losses[n][-1]:.4f}')



In [None]:
# Test accuracy per epoch for each learning rate: forward gradient vs. backprop.
fig, axs = plt.subplots(3, 2, sharex=True, sharey=True, figsize=(15, 12))
# zip stops at the shortest sequence, so fewer runs than panels cannot IndexError.
for ax, lr, fwd_acc, bwd_acc in zip(axs.flatten(), learning_rates, accs, baseline_accs):
  epoch_axis = np.arange(1, len(fwd_acc) + 1)
  # seaborn >= 0.12 accepts x/y data only as keyword arguments.
  sns.lineplot(x=epoch_axis, y=fwd_acc, ax=ax, label='Forward gradient')
  sns.lineplot(x=epoch_axis, y=bwd_acc, ax=ax, label='Backprop')
  ax.set(title=f'lr = {lr:.2e}', xlabel='Epoch', ylabel='Test accuracy')
plt.show()
  sns.lineplot(X, baseline_accs[i], ax=ax)


In [None]:
# Model test
def compute_accuracy(model, batch_size=64, data_loader=None):
  """Return the fraction of samples in `data_loader` that `model` classifies correctly.

  Args:
    model: classifier whose output argmax over dim 1 is the predicted class.
    batch_size: unused; kept for backward compatibility. Accuracy is normalised
      by the dataset size, so the loader's real batch size does not matter.
    data_loader: loader to evaluate; defaults to the global ``test_data_loader``.

  Returns:
    Number of correct predictions divided by ``len(data_loader.dataset)``.
  """
  if data_loader is None:
    data_loader = test_data_loader
  # Evaluate on the model's own device instead of assuming CUDA.
  first_param = next(model.parameters(), None)
  if first_param is not None:
    device = first_param.device
  else:
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

  correct = 0
  with torch.no_grad():  # inference only; no autograd graph needed
    for image, label in data_loader:
      pred = model(image.to(device)).argmax(dim=1)
      correct += int((pred == label.to(device)).sum())

  # Divide by the true sample count: len(loader) * batch_size over-counts
  # whenever the last batch is partial.
  return correct / len(data_loader.dataset)

# Test accuracy of the forward-gradient logistic-regression model.
compute_accuracy(model, batch_size=64)

In [None]:
# Test accuracy of the backprop baseline model.
compute_accuracy(baseline_model, batch_size=64)