In [1]:
import torch, torchvision
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
import os
#import oil.augLayers as augLayers
from oil.model_trainers.classifier import Classifier
from oil.datasetup.datasets import CIFAR10, C10augLayers
from oil.datasetup.dataloaders import getLabLoader
from oil.architectures.img_classifiers.smallconv import smallCNN
from oil.utils.utils import cosLr, loader_to
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [2]:
from oil.lazy.lazy_matrix import LazyMatrix, Lazy
from oil.lazy.lazy_types import LazyAvg
from oil.utils.utils import reusable
from oil.lazy.linalg.VRmethods import GradLoader, oja_grad2,SGHA_grad2,SGD,SVRG, SGHA_grad,oja_subspace_grad,SGHA_subspace_grad2,SGHA_subspace_grad
from oil.logging.lazyLogger import LazyLogger
from oil.lazy.linalg.lanczos import power_method
from oil.lazy.hessian import Hessian, Fisher, autoHvpBatch
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import scipy as sp
import scipy.linalg
import matplotlib.pyplot as plt
import time
import pandas as pd
import os

In [3]:
# Experiment configuration. Each *_config dict is splatted into one constructor in makeTrainer:
# net -> smallCNN, loader -> getLabLoader, opt -> optim.SGD, sched -> cosLr, trainer -> Classifier.
train_epochs = 100
net_config = dict(numClasses=10, k=16)
loader_config = dict(amnt_dev=0, lab_BS=256, dataseed=0, num_workers=1)
opt_config = dict(lr=.1, momentum=.9, weight_decay=1e-4, nesterov=True)
sched_config = dict(cycle_length=train_epochs, cycle_mult=1)
trainer_config = dict(log_args=dict(minPeriod=0.1))

# Flatten every config into a single dict of tracked hyperparameters (later dicts win on key clashes).
# Because this runs before log_dir is added below, the machine-specific path is not part of all_hypers.
all_hypers = {key: value
              for cfg in (net_config, loader_config, opt_config, sched_config, trainer_config)
              for key, value in cfg.items()}
trainer_config['log_dir'] = os.path.expanduser('~/tb-experiments/smallcnn/')

def makeTrainer(device=None):
    """Build the CIFAR-10 smallCNN ``Classifier`` from the module-level *_config dicts.

    Args:
        device: torch device (or device string) for the model and batches. When None (the
            default), CUDA is used if available and the CPU otherwise. Previously 'cuda'
            was hard-coded, which fails on machines without a GPU.

    Returns:
        A ``Classifier`` wrapping ``nn.Sequential(C10augLayers(), smallCNN)`` with 'train',
        'dev' and 'test' dataloaders, each wrapped by ``loader_to(device)``.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)
    CNN = smallCNN(**net_config).to(device)
    # Augmentation layers are prepended to the network (presumably so augmentation runs inside
    # the model's forward pass — confirm against C10augLayers).
    fullCNN = nn.Sequential(C10augLayers(), CNN)
    # Expand '~' explicitly, as is done for log_dir, instead of relying on CIFAR10 to do it.
    # NOTE(review): the meaning of the leading False argument is not visible here — confirm.
    trainset, testset = CIFAR10(False, os.path.expanduser('~/datasets/cifar10/'))

    dataloaders = {}
    # Reuse the configured batch size / worker count rather than duplicating the literals.
    dataloaders['test'] = DataLoader(testset, batch_size=loader_config['lab_BS'], shuffle=False,
                                     num_workers=loader_config['num_workers'])
    dataloaders['train'], dataloaders['dev'] = getLabLoader(trainset, **loader_config)
    # Wrap every loader with loader_to(device) (presumably moves each batch to `device`).
    dataloaders = {k: loader_to(device)(v) for k, v in dataloaders.items()}

    opt_constr = lambda params: optim.SGD(params, **opt_config)
    lr_sched = cosLr(**sched_config)
    return Classifier(fullCNN, dataloaders, opt_constr, lr_sched, **trainer_config,
                      tracked_hypers=all_hypers)

# Build the trainer. Training is off by default: the next cell restores weights from a saved
# checkpoint. Set RETRAIN = True on a fresh machine to train for `train_epochs` epochs and
# save a checkpoint first (previously this required un-commenting code by hand).
RETRAIN = False
trainer = makeTrainer()
if RETRAIN:
    trainer.train(train_epochs)
    trainer.save_checkpoint()

Files already downloaded and verified
Files already downloaded and verified
Creating Train, Dev split         with 50000 Train and 0 Dev


In [4]:
# Restore trained weights; presumably this needs a checkpoint written by an earlier training run
# (re-enable the save below right after training to write one).
#trainer.save_checkpoint()
trainer.load_checkpoint()
print(trainer.getAccuracy(trainer.dataloaders['test']))          # accuracy on the test loader
print(sum(map(torch.Tensor.numel, trainer.model.parameters())))   # total parameter count

=> loading checkpoint '/home/marc/tb-experiments/smallcnn/checkpoints/c.100.ckpt'
0.8731
45114


In [5]:
# Switch to eval mode before building the curvature operators (presumably so augmentation and any
# train-only layer behavior are disabled during the products — confirm).
trainer.model.eval()
# Operators over the training loader; both are used via matrix products (H@w, F@w) further down.
# NOTE: `F` is the Fisher operator here, not torch.nn.functional.
H = Hessian(
    trainer.model,
    trainer.dataloaders['train'],
)
F = Fisher(
    trainer.model,
    trainer.dataloaders['train'],
)

In [6]:
# Random vector of length H.shape[-1], created through H's array backend (H.xp).
w0 = H.xp.new_randn(H, [H.shape[-1]])
# Record the device of the model's parameters as a plain attribute on the model
# (presumably read by the oil hessian/lazy code — confirm).
trainer.model.device = next(iter(trainer.model.parameters())).device

In [7]:
#import torch.nn.functional as F
%load_ext line_profiler
for mb in trainer.dataloaders['train']:
    break
#%timeit F.cross_entropy(trainer.model(mb[0]),mb[1]).backward()
#%timeit H@w0
#%prun H@w0

In [8]:
# Gradient oracle passed to SVRG below: SGHA_grad applied to the operator pair [H, F].
# NOTE(review): the logging callback reports w^T H w / w^T F w, which suggests the generalized
# eigenproblem H w = lambda F w is being solved — confirm against oil.lazy.linalg.VRmethods.
grads = GradLoader(
    SGHA_grad,
    [H, F],
)

In [9]:
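# Alternative (disabled): a gradient loader using oja_grad2 on the Fisher F alone, instead of
# SGHA_grad on [H, F] as in the previous cell. (GradLoader creates a Lazy Average from an iterable of minibatches.)
#grads = GradLoader(oja_grad2,[F])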

In [10]:
# Setup code for logging: a LazyLogger plus the `log(w, lr, grad)` callback handed to SVRG.
logger = LazyLogger(no_print=False, minPeriod=0.1, timeFrac=1)
# Step counter kept as ad-hoc state on the logger object; `log` increments it on every call.
logger.i = 0

def log(w, lr, grad):
    """Solver progress callback.

    Increments ``logger.i`` on every call. When the logger's context yields True (presumably
    rate-limited by minPeriod/timeFrac), records ||grad|| and the Rayleigh quotient
    w^T H w / w^T F w, using the module-level operators H and F.

    Args:
        w: current iterate (torch tensor).
        lr: learning rate passed by the solver; unused here.
        grad: current gradient estimate (torch tensor).
    """
    logger.i += 1
    with logger as do_log:
        if do_log:
            metrics = {}
            # .item() yields a Python float without copying the whole gradient to the host.
            metrics[r"$||\nabla L(w)||$"] = grad.norm().item()
            # Optional: metrics[r"$\frac{w^TFw}{w^Tw}$"] = (w@(F@w)/(w@w)).item()
            # H@w and F@w presumably sweep the whole training loader, so only compute them when logging.
            metrics[r"$\frac{w^THw}{w^TFw}$"] = (w @ (H @ w) / (w @ (F @ w))).item()
            logger.add_scalars('metrics', metrics, step=logger.i)
            logger.report()

In [None]:
# Run SVRG from a random unit-norm starting vector.
SEED = 0
num_epochs = 20
base_lr = .01
# Seed torch's global RNG for a reproducible run. This assumes H.xp.new_randn (and presumably the
# loader shuffling inside SVRG) draw from torch's RNG — confirm.
torch.manual_seed(SEED)
w0 = H.xp.new_randn(H, [H.shape[-1]])
trainer.model.device = next(trainer.model.parameters()).device
w0 /= torch.norm(w0)
# Constant step size; e.g. `base_lr*cosLr(num_epochs)(e)` would anneal it over the epochs.
lr = lambda e: base_lr
# NOTE: `log` keeps advancing logger.i, so re-running this cell without re-running the
# logger-setup cell continues the step count instead of restarting it at 0.
wf = SVRG(grads, w0, lr, num_epochs, log)

HBox(children=(IntProgress(value=0, max=20), HTML(value='')))

   $||\nabla L(w)||$  $\frac{w^THw}{w^TFw}$
1           0.535721               1.005126
    $||\nabla L(w)||$  $\frac{w^THw}{w^TFw}$
99           2.051666               1.005309
     $||\nabla L(w)||$  $\frac{w^THw}{w^TFw}$
197           3.950673                1.00401
     $||\nabla L(w)||$  $\frac{w^THw}{w^TFw}$
295           7.371708               1.010033
     $||\nabla L(w)||$  $\frac{w^THw}{w^TFw}$
393           0.493606               1.025473
     $||\nabla L(w)||$  $\frac{w^THw}{w^TFw}$
491           1.159195               1.049635


In [None]:
# Plot the logged metrics against the logging step on a log-scaled y axis.
plt.style.use('ggplot')
plt.rcParams.update({'font.size': 14})
# DataFrame.plot opens a new figure when no Axes is given. Create the figure/axes first and pass
# `ax`; otherwise `f` stays empty and f.savefig would write a blank page.
f, ax = plt.subplots()
logger.scalar_frame.plot(logy=True, ax=ax)
# The columns are a gradient norm and a Rayleigh quotient, not an error.
ax.set_ylabel('Metric value')
ax.set_xlabel("Iterations")
ax.legend();
#f.savefig("VR_PCA.pdf", bbox_inches='tight')

In [None]:
# Display the table of logged metrics (the same frame plotted in the previous cell).
logger.scalar_frame