# Calculating the FIM Matrix

This notebook calculates the FIM matrix for a small MLP after learning the toy dataset.

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader

import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import random
import pickle

from utils import mlp,train_test_model
from utils.ewc_utils.ToyExampleEWC import FullEWC, LowRankEWC, MinorDiagonalEWC, BlockDiagonalEWC
from utils.ewc_utils.ToyExampleEWC import SketchEWC

from data.sequential_lines import Lines

from sklearn.decomposition import TruncatedSVD



In [2]:
## Get Configurations

n_samples=1000
epochs = 200
lr = 1e-3
batch_size = 100
input_size = 2
hidden_sizes = [128,64]
output_size = 2

num_task = 4

activation='ReLU'
slope=.1
device='cuda:0'

ewc_alpha=0.5

gain={
    'Sigmoid':1.,
    'TanH':1.,
    'ReLU':np.sqrt(2.),
    'leakyReLU':np.sqrt(2./(1.+slope**2))
}

In [3]:
def weights_init(m):
    if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):        
        torch.nn.init.xavier_uniform_(m.weight,gain=gain[activation])
        if m.bias: 
            torch.nn.init.xavier_uniform_(m.bias,gain=gain[activation])

In [4]:
X,Y=torch.meshgrid(torch.linspace(-0.5,2.5,150),torch.linspace(-0.75,1.25,100))
grid=torch.stack([X.reshape(-1),Y.reshape(-1)]).T

In [6]:
full_ewc_importance=1e+5

stable_rank=[]
diagonal_error=[]
block_diagonal_error=[]
sketched_error=[]
low_rank_error=[]

for seed in range(5):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    datagen = Lines(max_iter=num_task, num_samples=n_samples)
    train_loader,test_loader=datagen.get_full_lines(batch_size=batch_size)

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    ## define a MLP model
    model=mlp.MLP(input_size=input_size,output_size=output_size,
                  hidden_size=hidden_sizes,activation=activation,
                  require_bias=True,device=device).to(device)
    # model.apply(weights_init)
    full_ewc= FullEWC(model,device=device,alpha=ewc_alpha)

    ## performing training
    for _ in tqdm(range(epochs)):
        model.train()
        optimizer = torch.optim.Adam(params=model.parameters(),lr=lr)
        train_test_model.train_classifier(model=model,
                                          optimizer=optimizer,
                                          data_loader=train_loader,
                                          device=device) 

    fim=full_ewc.calculate_FIM(train_loader).to('cpu')
    
    diagonal_error.append(((fim-np.diag(np.diag(fim)))**2).sum()/(fim**2).sum())
    
    block_diagonal_ewc = BlockDiagonalEWC(model,device=device,alpha=ewc_alpha, n_bucket=50)
    block_diagonal_fim = block_diagonal_ewc.calculate_approximation(train_loader).to('cpu')
    block_diagonal_error.append(((fim-block_diagonal_fim)**2).sum()/(fim**2).sum())

    for sketch_seed in range(5):
        random.seed(sketch_seed)
        np.random.seed(sketch_seed)
        torch.manual_seed(sketch_seed)

        sketch_ewc = SketchEWC(model,device=device,alpha=ewc_alpha,n_sketch=50)
        sketched_fim = sketch_ewc.calculate_approximation(train_loader).to('cpu')
        sketched_error.append(((fim-sketched_fim)**2).sum()/(fim**2).sum())
        
    u, s, v =torch.svd_lowrank(fim, q=50)
    low_rank_approx = torch.mm(torch.mm(u, torch.diag(s)), v.t())
    low_rank_error.append(((fim-low_rank_approx)**2).sum()/(fim**2).sum())

#     svd = TruncatedSVD(n_components=500)
#     svd.fit(fim)
#     stable_rank.append(sum(svd.singular_values_ ** 2) / (svd.singular_values_[0] ** 2))

diagonal_error = np.array(diagonal_error)
block_diagonal_error = np.array(block_diagonal_error)
sketched_error = np.array(sketched_error)
low_rank_error = np.array(low_rank_error)
# stable_rank = np.array(stable_rank)


HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=200.0), HTML(value='')))




In [7]:
print(diagonal_error.mean(), diagonal_error.std())
print(block_diagonal_error.mean(), block_diagonal_error.std())
print(sketched_error.mean(), sketched_error.std())
print(low_rank_error.mean(), low_rank_error.std())
# print(np.array(stable_rank).mean(), np.array(stable_rank).std())

0.94711936 0.006053829
0.8287951 0.013002222
0.08082563 0.10188777
3.0142753e-11 4.8421198e-11
