In [1]:
import torch       
from torchvision import datasets, transforms
from torch.utils.data import Subset
import torch.nn as nn

import matplotlib.pyplot as plt

# 1. define your dataloader

In [2]:
# Convert images to tensors, then normalise with the fixed mean/std given below.
transform_list = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.131], std=[0.289]),
])

# Keep only the first 2000 training examples so the example stays small.
dataset = Subset(
    datasets.MNIST(root='../.data/', train=True, download=True, transform=transform_list),
    range(2000),
)

# Fixed-order batches of 500 examples (no shuffling).
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=500, shuffle=False)

# 2. define your model

In [3]:
class Flatten(nn.Module):
    """Merge every dimension after the first (batch) dimension into one."""

    def forward(self, input):
        # Dimension 0 is kept; dimensions 1.. are collapsed,
        # e.g. (N, 1, 28, 28) -> (N, 784).
        return input.flatten(1)


n_hidden = 2       # number of hidden Linear+ReLU blocks
hidden_size = 10   # width of every hidden layer
device = 'cpu'

# Each hidden block is instantiated inside the comprehension. Multiplying a
# list of modules (`[nn.Linear(...), nn.ReLU()] * k`) would repeat references
# to the *same* module objects, so all "copies" would share one weight matrix
# as soon as n_hidden > 2.
layers = (
    [Flatten(), nn.Linear(28 * 28, hidden_size), nn.ReLU()]
    + [layer
       for _ in range(n_hidden - 1)
       for layer in (nn.Linear(hidden_size, hidden_size), nn.ReLU())]
    + [nn.Linear(hidden_size, 10), nn.LogSoftmax(dim=1)]
)
model = nn.Sequential(*layers).to(device)

# EWC penalty

We now compute the matrix coefficients from the model and the data loader defined above. Different representations give different performance and use more or less memory. A low memory footprint comes at the price of a less accurate approximation of the FIM, so the gradient that we get from the penalty can be drastically different.

In [4]:
from nngeometry.nngeometry.representations import KFACMatrix, DiagMatrix, DenseMatrix, BlockDiagMatrix, EKFACMatrix
from nngeometry.nngeometry.vector import PVector
from nngeometry.nngeometry.metrics import FIM_MonteCarlo1

Suppose we now train our model on task 1. We want to store the current state of the network using:
 1. the current parameter values `v1`
 2. the current Fisher Information Matrix `F_XXX`

In [7]:
# Snapshot of the parameters after task 1, cloned and detached from the model.
v1 = PVector.from_model(model).clone().detach()

# Every representation is built from the same loader and model; only the
# `representation` argument changes between calls.
fim_inputs = dict(loader=loader, model=model)

F_kfac = FIM_MonteCarlo1(representation=KFACMatrix, **fim_inputs)
F_blockdiag = FIM_MonteCarlo1(representation=BlockDiagMatrix, **fim_inputs)
F_dense = FIM_MonteCarlo1(representation=DenseMatrix, **fim_inputs)
F_diag = FIM_MonteCarlo1(representation=DiagMatrix, **fim_inputs)

# The EKFAC representation gets an extra update_diag() call after construction.
F_ekfac = FIM_MonteCarlo1(representation=EKFACMatrix, **fim_inputs)
F_ekfac.update_diag()

In [None]:
# Bare expression: the notebook shows the repr of the KFAC representation object.
F_kfac

In [6]:
# Ask the generator attached to the KFAC matrix for its parameter count
# (presumably the total number of model parameters - confirm in NNGeometry).
# NOTE(review): this cell is numbered 6 but reads F_kfac, which is assigned in
# the cell numbered 7; re-run the notebook top to bottom to confirm the order.
n_parameters = F_kfac.generator.get_n_parameters()

We can now continue training with an additional regularizer term that uses the FIM and the difference between the current parameter value and `v1`.

In [8]:
# Parameter vector built from the model as it is now. Unlike v1 it is neither
# cloned nor detached; presumably it stays tied to the live parameters so that
# gradients can flow back to them - confirm against the PVector documentation.
v_current = PVector.from_model(model)

You can now compute a scalar regularizer and backpropagate through it:

In [9]:
# Quadratic penalty (v_current - v1)^T F (v_current - v1) with the dense FIM.
regularizer_dense = F_dense.vTMv(v_current - v1)
regularizer_dense.backward()

# Show each parameter's shape next to the gradient it received.
[(param.size(), param.grad) for param in model.parameters()]

[(torch.Size([10, 784]), tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]])),
 (torch.Size([10]), tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])),
 (torch.Size([10, 10]), tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
          [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])),
 (torch.Size([10]), tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])),
 (torch.Size([10, 10

Note that in this case all gradients stay at `0` because `v_current = v1`. For the purpose of this example, we will therefore modify the value of `v_current`. You can think of this modification as resulting from SGD updates during the first iterations of training on task 2.

In [10]:
# Add standard-normal noise to every parameter, in place, to stand in for a few
# SGD updates on task 2.
# torch.no_grad() is the supported way to edit parameters in place; writing
# through `.data` hides the modification from autograd's version tracking.
with torch.no_grad():
    for p in model.parameters():
        p.add_(torch.randn_like(p))

We now recompute the gradients from the regularizer:

In [11]:
# Clear gradients left by earlier backward passes: backward() accumulates into
# .grad, so without this a re-run of the cell would display summed gradients.
model.zero_grad()

# Quadratic penalty with the dense FIM, now that the parameters differ from v1.
regularizer_dense = F_dense.vTMv(v_current - v1)
regularizer_dense.backward()

# Show each parameter's shape next to the gradient it received.
[(p.size(), p.grad) for p in model.parameters()]

[(torch.Size([10, 784]),
  tensor([[ 0.0063,  0.0063,  0.0063,  ...,  0.0063,  0.0063,  0.0063],
          [-0.0208, -0.0208, -0.0208,  ..., -0.0208, -0.0208, -0.0208],
          [ 0.0445,  0.0445,  0.0445,  ...,  0.0445,  0.0445,  0.0445],
          ...,
          [-0.0347, -0.0347, -0.0347,  ..., -0.0347, -0.0347, -0.0347],
          [ 0.0030,  0.0030,  0.0030,  ...,  0.0030,  0.0030,  0.0030],
          [-0.0199, -0.0199, -0.0199,  ..., -0.0199, -0.0199, -0.0199]])),
 (torch.Size([10]),
  tensor([-0.0139,  0.0458, -0.0983, -0.0546,  0.1555,  0.0057,  0.0844,  0.0766,
          -0.0066,  0.0439])),
 (torch.Size([10, 10]),
  tensor([[ 1.0168e-02,  5.7981e-02,  3.9742e-04,  7.5618e-02,  5.0058e-02,
            1.9567e-02,  6.8715e-02,  2.5859e-02, -1.7546e-04, -1.4914e-03],
          [-1.3681e-02, -1.3471e-02, -1.2331e-02, -1.5503e-02,  7.0392e-03,
            3.2765e-03,  1.7308e-03,  9.4715e-03, -2.0885e-03,  1.8633e-03],
          [-3.6515e-03, -2.4311e-02, -6.2564e-03, -3.2900e-02,

# Comparison of regularization obtained using different representations

In [12]:
# Gradient of the dense-FIM penalty, flattened into one vector.
model.zero_grad()  # start from clean gradients
regularizer_dense = F_dense.vTMv(v_current - v1)
regularizer_dense.backward()
g_dense = torch.cat(tuple(param.grad.view(-1) for param in model.parameters()))

In [13]:
# Gradient of the block-diagonal-FIM penalty, flattened into one vector.
model.zero_grad()  # start from clean gradients
regularizer_bd = F_blockdiag.vTMv(v_current - v1)
regularizer_bd.backward()
g_bd = torch.cat(tuple(param.grad.view(-1) for param in model.parameters()))

In [14]:
# Gradient of the KFAC-FIM penalty, flattened into one vector.
model.zero_grad()  # start from clean gradients
regularizer_kfac = F_kfac.vTMv(v_current - v1)
regularizer_kfac.backward()
g_kfac = torch.cat(tuple(param.grad.view(-1) for param in model.parameters()))

In [15]:
# Gradient of the EKFAC-FIM penalty, flattened into one vector.
model.zero_grad()  # start from clean gradients
regularizer_ekfac = F_ekfac.vTMv(v_current - v1)
regularizer_ekfac.backward()
g_ekfac = torch.cat(tuple(param.grad.view(-1) for param in model.parameters()))

In [16]:
# Gradient of the diagonal-FIM penalty, flattened into one vector.
model.zero_grad()  # start from clean gradients
regularizer_diag = F_diag.vTMv(v_current - v1)
regularizer_diag.backward()
g_diag = torch.cat(tuple(param.grad.view(-1) for param in model.parameters()))

In [None]:
# Overlay the first 100 entries of each flattened gradient vector.
first_layer_curves = (
    ('kfac', g_kfac),
    ('ekfac', g_ekfac),
    ('diag', g_diag),
    ('dense', g_dense),
    ('block diagonal', g_bd),
)
for curve_label, grad_vector in first_layer_curves:
    plt.plot(grad_vector[:100].cpu().numpy(), label=curve_label)
plt.legend()
plt.title('Compare 100 elements of the gradient (first layer)')

In [1]:
# Overlay the last 100 entries of each flattened gradient vector.
last_layer_curves = (
    ('kfac', g_kfac),
    ('ekfac', g_ekfac),
    ('diag', g_diag),
    ('dense', g_dense),
    ('block diagonal', g_bd),
)
for curve_label, grad_vector in last_layer_curves:
    plt.plot(grad_vector[-100:].cpu().numpy(), label=curve_label)
# Clamp the y-axis to a fixed window around zero.
plt.ylim(-.01, .01)
plt.legend()
plt.title('Compare 100 elements of the gradient (last layer)')

NameError: name 'plt' is not defined