In [2]:
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import importlib
import copy
import argparse
from torchvision import transforms, datasets
import backpack

from matplotlib import pyplot as plt
import torch.nn.functional as F
from scipy.sparse.linalg import LinearOperator
from scipy.sparse.linalg import eigsh
from torch.autograd import Variable, grad
from numpy.linalg import eig as eig
from torch.distributions.multivariate_normal import MultivariateNormal
from utils import *
from models.fc import *
from models.wide_resnet import wide_resnet_t
from models.wide_resnet_1 import WideResNet
import scipy
from scipy.linalg import eigh_tridiagonal
from functions import *
from scipy.optimize import minimize
from dataset import *
from torch.optim.lr_scheduler import CosineAnnealingLR
from backpack import extend, backpack
from backpack.extensions import (
    GGNMP,
    HMP,
    KFAC,
    KFLR,
    KFRA,
    PCHMP,
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    BatchGrad,
    BatchL2Grad,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SumGradSquared,
    Variance,
)


%load_ext autoreload
%autoreload 2
# %aimport functions
# %aimport analyze


In [3]:
## preperation

num_train = 50000
num_val = 10000
num_d = 200
b = 0.001
c = 50*b
num_nt = 50
num_ns = 1000
args = (10, 1, num_ns)
model_name = "random"
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')


path = "/exp/random_" + str(b) +"/"
print(path)
mkdir(path)

/exp/random_0.001/


In [None]:
## data generation

data_train = torch.randn(num_train, num_d)*c*torch.exp(-b*torch.arange(num_d)) / np.sqrt(num_d)
data_val = torch.randn(num_val, num_d)*c*torch.exp(-b*torch.arange(num_d)) / np.sqrt(num_d)

print(data_train.shape, data_val.shape)
e, v = torch.eig(data_train.T@data_train / num_train)
e = e[:, 0]
idx = list(np.flip(e.numpy().argsort()))
e = e[idx]
plt.plot(np.real(e.numpy()))
plt.yscale("log")

In [86]:
## targets generation

w1 = torch.randn(num_d, num_nt) / torch.sqrt(torch.tensor(num_d))
w2 = torch.randn(num_nt, 10) / torch.sqrt(torch.tensor(num_nt))

def f(data, w1, w2):
    out = torch.tanh(data@w1)@w2
    targets = torch.argmax(out, dim=1)
    return targets

targets_train = f(data_train, w1, w2)
targets_val = f(data_val, w1, w2)

torch.save((data_train, targets_train), path + "train_data.pt")
torch.save((data_val, targets_val), path + "val_data.pt")

In [109]:
## data loader

def create_loader(data, targets, bs):
    loader = [(data[i:i+bs, :], targets[i:i+bs]) for i in range(len(data) // bs)]
    return loader

def val(model, device, val_loader, criterion):
    sum_loss, sum_correct = 0, 0
    model.eval()
    ns = 0
    # with torch.no_grad():
    for data, target in val_loader:
        ns += len(data)
        data, target = data.to(device), target.to(device)
        output = model(data)

        pred = output.max(1)[1]
        sum_correct += pred.eq(target).sum().item()
        sum_loss += len(data) * criterion(output, target).item()

    # print(ns)

    return 1 - (sum_correct / ns), sum_loss / ns



data_train, targets_train = torch.load(path + "train_data.pt")
data_val, targets_val = torch.load(path + "val_data.pt")

train_loader = create_loader(data_train, targets_train, 1000)
test_loader = create_loader(data_val, targets_val, 1000)
train_loader_FIM = create_loader(data_train[:5000], targets_train[:5000], 1)
train_loader_prior = create_loader(data_train[:10000], targets_train[:10000], 10000)

criterion = nn.CrossEntropyLoss().to(device)
model = Network1(*args).to(device)

In [None]:
## train

epochs = 50

model = Network1(*args).to(device)
torch.save(model.state_dict(), path + "model_init.pt")

optimizer = optim.Adam(model.parameters(), lr = 0.001)
scheduler = CosineAnnealingLR(optimizer,T_max=epochs, eta_min = 1e-5)

val_err, val_loss = val(model, device, test_loader, criterion)
print(val_err, val_loss)

for epoch in range(epochs):

    model.train()
    for data, targets in train_loader:
        data, targets = data.to(device), targets.to(device)
        out = model(data)
        loss = criterion(out, targets)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    scheduler.step()

    if epoch == epochs//4:
        torch.save(model.state_dict(), path + "model_mid.pt")


    train_err, train_loss = val(model, device, train_loader, criterion)
    print(train_err, train_loss)


val_err, val_loss = val(model, device, test_loader, criterion)
print(val_err, val_loss)

torch.save(model.state_dict(), path + "model.pt")


In [110]:
## saving model statistics

model = Network1(*args)
model.load_state_dict(torch.load(path + "model.pt", map_location='cpu'))
model_init = Network1(*args)
model_init.load_state_dict(torch.load(path + "model_init.pt", map_location='cpu'))
model_mid = Network1(*args)
model_mid.load_state_dict(torch.load(path + "model_mid.pt", map_location='cpu'))
model = model.to(device)
model_init = model_init.to(device)
model_mid = model_mid.to(device)
norm_init = torch.norm(list_to_vec(list(model_init.parameters())), p=2)
norm_mid = torch.norm(list_to_vec(list(model_mid.parameters())), p=2)
norm_trained = torch.norm(list_to_vec(list(model.parameters())), p=2)
print(norm_init, norm_mid, norm_trained)
tr_err_mid, tr_loss_mid = val(model_mid, device, train_loader, criterion)
val_err_mid, val_loss_mid = val(model_mid, device, test_loader, criterion)
print("mid",tr_err_mid, tr_loss_mid, val_err_mid, val_loss_mid)
tr_err, tr_loss = val(model, device, train_loader, criterion)
val_err, val_loss = val(model, device, test_loader, criterion)
print("end", tr_err, tr_loss, val_err, val_loss)

stat_model = dict({'te': tr_err, 'tl':tr_loss, 've':val_err, 'vl':val_loss, 'tem':tr_err_mid, 'tlm':tr_loss_mid, 'vem':val_err_mid, 'vlm':val_loss_mid})
torch.save(stat_model, path + "stat_model.pt")

tensor(18.3885, device='cuda:0', grad_fn=<NormBackward1>) tensor(347.9543, device='cuda:0', grad_fn=<NormBackward1>) tensor(668.3050, device='cuda:0', grad_fn=<NormBackward1>)
mid 0.36758 1.8297062039375305 0.6275999999999999 2.029688262939453
end 0.07491999999999999 0.5744844317436218 0.5035000000000001 1.3796849489212035


In [91]:
#################################################
###### kfac of FIM at init and Hess at end ######
###### and their top eigen values and vecs ######
#################################################


kfac_list_init = FIM_kfac(model_init, train_loader_prior, mc=500, device="cpu", mode = "kfac", empirical = True)
eigspace_list_init, eigval_list_init = eigspace_FIM_kron(kfac_list_init)

kfac_list_end = FIM_kfac(model, train_loader_prior, mc=500, device="cpu", mode = "kfac", empirical = True)
eigspace_list_end, eigval_list_end = eigspace_FIM_kron(kfac_list_end)

torch.save((kfac_list_init, eigspace_list_init, eigval_list_init), path + "kfac_all_init.pt")
torch.save((kfac_list_end, eigspace_list_end, eigval_list_end), path + "kfac_all_end.pt")


In [92]:
# FIM at initialization and end of training, logit jacobian at end

L_i = FIM_truex(model_init, criterion, train_loader_FIM, "cpu")
print(L_i)
L_mid = FIM_truex(model_mid, criterion, train_loader_FIM, "cpu")
print(L_mid)
L_tt = FIM_truex(model, criterion, train_loader_FIM, "cpu")
print(L_tt)
L_te = FIM2x(model, criterion, train_loader_FIM, "cpu")
print(L_te)
L_logit = logit_jacobianx(model, 0, criterion, train_loader_FIM, 'cpu')
print(L_logit)
torch.save(L_i, path + "FIM_true_init.pt")
torch.save(L_mid, path + "FIM_true_mid.pt")
torch.save(L_tt, path + "FIM_true_end.pt")
torch.save(L_te, path + "FIM_em_end.pt")
torch.save(L_logit, path + "FIM_logit_end.pt")



[ 3.9212803e+01  3.3591747e+01  2.9878529e+01 ... -5.0256205e-07
 -5.1182661e-07 -5.2241938e-07]
[ 3.2783978e+01  2.9483574e+01  2.5158785e+01 ... -3.3688627e-07
 -3.3746048e-07 -3.4574586e-07]
[ 2.6105656e+01  2.3985191e+01  2.0163393e+01 ... -5.7532304e-07
 -6.4946170e-07 -7.3091178e-07]
[ 2.0888996e+01  1.7683691e+01  1.6524782e+01 ... -4.9012101e-07
 -5.1209855e-07 -6.6103672e-07]
[ 2.5772293e+02  4.3421935e-02  3.7069097e-02 ... -1.1970051e-06
 -1.2013277e-06 -1.2048131e-06]


In [None]:
# Hessian at the end of training

eig_hess, u_hess = hess_scipy(model, 500, train_loader, criterion, device)
torch.save((eig_hess, u_hess), path + "eig_hess_scipy.pt")

In [None]:
## Overlaps

model = Network1(*args)
model.load_state_dict(torch.load(path + "model.pt"))
model_init = Network1(*args)
model_init.load_state_dict(torch.load(path + "model_init.pt"))
model = model.to(device)
model_init = model_init.to(device)

FIM_i, L_i, u_i = torch.load(path + "FIM_true_init.pt")
FIM_tt, L_tt, u_tt = torch.load(path + "FIM_true_end.pt")
FIM_te, L_te, u_te = torch.load(path + "FIM_em_end.pt")

over = overlap(u_i, u_tt, 500, device)
diff = list_to_vec(diff_list(list(model.parameters()), list(model_init.parameters())))
frac_all = proj(diff, u_i, 300, device)
frac_all = frac_all.detach().cpu()

stat = dict({"data_eig":e, "L_i":L_i, "L_mid":L_mid, "L_tt":L_tt, "L_te":L_te, "L_hess":eig_hess, "L_logit":L_logit})
torch.save(stat, path + "stat.pt")
