In [3]:
%reload_ext autoreload
%autoreload 2

In [4]:
from backpack import backpack
from backpack.extensions import KFAC
from backpack.utils import load_data

X, y = load_data()

loss = lossfunc(model(X), y)

with backpack(KFAC()):
        loss.backward()

        for param in model.parameters():
                print(param.grad)
                print(param.kfac)


ImportError: cannot import name 'load_data' from 'backpack.utils' (/home/huh/Projects/backpack/libraries/backpack/backpack/utils/__init__.py)

In [3]:
import torch

model = torch.nn.Sequential(
        torch.nn.Linear(764, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 10)
)
lossfunc = torch.nn.CrossEntropyLoss()

from backpack import extend

model = extend(model)
lossfunc = extend(lossfunc)


loss = lossfunc(model(X), y)

with backpack(KFAC()):
        loss.backward()


NameError: name 'X' is not defined

In [11]:
"""
Quick example: A small second-order optimizer with BackPACK
on the classic MNIST example from PyTorch,
https://github.com/pytorch/examples/blob/master/mnist/main.py

The optimizer we implement uses a constant damping parameter
and uses the diagonal of the GGN/Fisher matrix as a preconditioner;

```
x_{t+1} = x_t - (G_t + bI)^{-1} g_t
```

- `x_t` are the parameters of the model
- `G_t` is the diagonal of the Gauss-Newton/Fisher matrix at `x_t`
- `b` is a damping parameter
- `g_t` is the gradient

"""

import torch
import torchvision
# The main BackPACK functionalities
from backpack import backpack, extend
# The diagonal GGN extension
from backpack.extensions import DiagGGNMC, DiagGGNExact, KFLR2, KFLR, KFAC
# This layer did not exist in Pytorch 1.0
from backpack.core.layers import Flatten
from bpoptim import KFRA2ConstantDampingOptimizer, KFRAConstantDampingOptimizer, KFACConstantDampingOptimizer

# Hyperparameters
BATCH_SIZE = 64
STEP_SIZE = 0.01
DAMPING = 1.0
MAX_ITER = 100
torch.manual_seed(0)


"""
Step 1: Load data and create the model.

We're going to load the MNIST dataset,
and fit a 3-layer MLP with ReLU activations.
"""


mnist_loader = torch.utils.data.dataloader.DataLoader(
    torchvision.datasets.MNIST(
        './data',
        train=True,
        download=True,
        transform=torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(
                (0.1307,), (0.3081,)
            )
        ])),
    batch_size=BATCH_SIZE,
    shuffle=True
)

model = torch.nn.Sequential(
    torch.nn.Conv2d(1, 20, 5, 1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
    torch.nn.Conv2d(20, 50, 5, 1),
    torch.nn.ReLU(),
    torch.nn.MaxPool2d(2, 2),
    Flatten(), 
    # Pytorch <1.2 doesn't have a Flatten layer
    torch.nn.Linear(4*4*50, 500),
    torch.nn.ReLU(),
    torch.nn.Linear(500, 10),
)

loss_function = torch.nn.CrossEntropyLoss()

def get_accuracy(output, targets):
    """Helper function to print the accuracy"""
    predictions = output.argmax(dim=1, keepdim=True).view_as(targets)
    return predictions.eq(targets).float().mean().item()


"""
Step 2: Create the optimizer.

After we call the backward pass with backpack,
every parameter will have a `diag_ggn_mc` field
in addition to a `grad` field.

We can use it to compute the search direction for that parameter,
```
step_direction = p.grad / (p.diag_ggn_mc + group["damping"])
```
and update the weights
"""

# from pprint import pprint


class DiagGGNOptimizer(torch.optim.Optimizer):
    def __init__(self, parameters, step_size, damping):
        super().__init__(
            parameters, 
            dict(step_size=step_size, damping=damping)
        )

    def step(self):
        for group in self.param_groups:
            
            for p in group["params"]:
#                 print(p.diag_ggn_exact)
                step_direction = p.grad / (p.diag_ggn_exact + group["damping"])
                # step_direction = p.grad / (p.diag_ggn_mc + group["damping"])
                p.data.add_(-group["step_size"], step_direction)
        return loss



"""
Step 3: Tell BackPACK about the model and loss function, 
create the optimizer, and we will be ready to go
"""

extend(model)
extend(loss_function)

optimizer = DiagGGNOptimizer(
    model.parameters(), 
    step_size=STEP_SIZE, 
    damping=DAMPING
)
# optimizer = KFACConstantDampingOptimizer(
# # optimizer = KFRA2ConstantDampingOptimizer(
# # optimizer = KFRAConstantDampingOptimizer(
#     model.parameters(), 
#     lr=1, 
#     damping=DAMPING
# )



"""
Final step: The training loop!

The only difference with a traditional training loop:
Before calling the backward pass, we will call
```
    with backpack(DiagGGNMC()):
```
BackPACK will then add the diagonal of the GGN in the
`diag_ggn_mc` field during the backward pass.
"""


for batch_idx, (x, y) in enumerate(mnist_loader):
    output = model(x)

    accuracy = get_accuracy(output, y)

    with backpack(DiagGGNExact()): 
#     with backpack(DiagGGNMC()):
#     with backpack(KFAC()):
#     with backpack(KFLR2()):
        loss = loss_function(output, y)
        loss.backward()
        optimizer.step()
        # def closure(): return loss, output
        # optimizer.step(closure)

    print(  "Iteration %3.d/%d   " % (batch_idx, MAX_ITER) +    "Minibatch Loss %.3f  " % (loss) +    "Accuracy %.0f" % (accuracy * 100) + "%"  )

    if batch_idx >= MAX_ITER:
        break


tensor([[[[8.0489e-04, 6.8021e-04, 2.9903e-04, 1.0630e-04, 8.0913e-05],
          [8.4600e-04, 5.5541e-04, 2.8320e-04, 9.6866e-05, 9.3449e-05],
          [7.6921e-04, 5.5314e-04, 2.1918e-04, 1.0065e-04, 1.2008e-04],
          [9.1813e-04, 7.9457e-04, 4.5872e-04, 2.4658e-04, 2.3808e-04],
          [1.0558e-03, 9.7957e-04, 6.6604e-04, 5.0099e-04, 3.9226e-04]]],


        [[[7.3236e-04, 6.2506e-04, 2.4445e-04, 1.3674e-04, 1.7447e-04],
          [5.0444e-04, 4.6962e-04, 2.6814e-04, 1.7015e-04, 2.1084e-04],
          [3.0363e-04, 3.7275e-04, 3.9544e-04, 3.2087e-04, 2.9833e-04],
          [3.4984e-04, 4.2940e-04, 4.3486e-04, 2.9132e-04, 2.2621e-04],
          [4.9221e-04, 6.4288e-04, 5.4128e-04, 2.9335e-04, 3.2105e-04]]],


        [[[1.2261e-03, 1.3490e-03, 1.0877e-03, 9.5664e-04, 1.0566e-03],
          [1.0996e-03, 1.0640e-03, 8.0848e-04, 8.9491e-04, 1.0750e-03],
          [1.0222e-03, 9.5561e-04, 8.8689e-04, 1.1565e-03, 1.2972e-03],
          [9.0788e-04, 9.3064e-04, 1.1369e-03, 1.5705e-0

KeyboardInterrupt: 

In [26]:
a=torch.randn([64,10,1])
print(a.squeeze().shape)

torch.Size([64, 10])
