In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import pandas as pd
import copy
from dbclass import TrainDB

n_epochs = 3
batch_size_train = 32
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 60000/(10*batch_size_train)

random_seed = 1
torch.backends.cudnn.enabled = False
torch.manual_seed(random_seed)

train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('drive/My Drive/mnist/MNIST_data/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=False)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('drive/My Drive/mnist/MNIST_data/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x)


network = Net()
optimizer = optim.SGD(network.parameters(), lr=learning_rate,
                      momentum=momentum)

#net1 = Net()
#net2 = Net()

net1 = copy.deepcopy(network)
net2 = copy.deepcopy(network)
net3 = copy.deepcopy(network) 


train_losses = []
train_counter = []
test_losses = []
test_counter = [i*len(train_loader.dataset) for i in range(n_epochs + 1)]



In [3]:
db = TrainDB(network,train_loader, torch.nn.functional.nll_loss,batchfreq=100)

In [4]:
def train(epoch):
  network.train()
  for batch_idx, (data, target) in enumerate(train_loader):
    #if batch_idx%50==0:
    #    print(batch_idx)
    grad_vec = None
    prev_state = copy.deepcopy(network.state_dict())
    optimizer.zero_grad()
    output = network(data)
    loss = F.nll_loss(output, target)
    loss.backward(create_graph=True,retain_graph=True)
    #grads = []
    #print(target)
    #for param in network.parameters():
    #    grads.append(param.grad.view(-1))
    #grads = torch.autograd.grad(
    #            loss, network.parameters(), create_graph=True
    #        )
    #grad_vec = torch.cat([g.contiguous().view(-1) for g in grads])
    #print('Norm of grad')
    #if batch_idx <20:
    #    print(grad_vec)
    #print(torch.norm(grad_vec/32.0))
    #grad_vec = 0
    optimizer.step()
    db.step(epoch,batch_idx,prev_state,network,loss.item())
    #print(db.tdiffnorm)
    if batch_idx % log_interval == 0:
      print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
        epoch, batch_idx * len(data), len(train_loader.dataset),
        100. * batch_idx / len(train_loader), loss.item()))
      train_losses.append(loss.item())
      train_counter.append(
        (batch_idx*batch_size_train) + ((epoch-1)*len(train_loader.dataset)))
  return

def test():
  network.eval()
  test_loss = 0
  correct = 0
  with torch.no_grad():
    for data, target in test_loader:
      output = network(data)
      test_loss += F.nll_loss(output, target, size_average=False).item()
      pred = output.data.max(1, keepdim=True)[1]
      correct += pred.eq(target.data.view_as(pred)).sum()
  test_loss /= len(test_loader.dataset)
  test_losses.append(test_loss)
  print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
    test_loss, correct, len(test_loader.dataset),
    100. * correct / len(test_loader.dataset)))


test()
n_epochs = 1
for epoch in range(1, n_epochs + 1):
  train(epoch)
  test()
#db.finnetwork = network




Test set: Avg. loss: 2.3004, Accuracy: 751/10000 (7%)

0
50
100
150
200
250
300
350
400
450
500
550
600
650
700
750
800
850
900
950
1000
1050
1100
1150
1200
1250
1300
1350
1400
1450
1500
1550
1600
1650
1700
1750
1800
1850

Test set: Avg. loss: 0.1419, Accuracy: 9570/10000 (95%)



In [8]:
table1 = db.tweight
table3 = db.tdiffnorm
table2 = db.tnorm

#table3[['conv1.weight', 'conv2.weight', 'fc1.weight', 'fc2.weight']][1:].plot()
#table2[['conv1.weight', 'conv2.weight', 'fc1.weight', 'fc2.weight']].plot()
table2

Unnamed: 0,Unnamed: 1,conv1.weight,conv1.bias,conv2.weight,conv2.bias,fc1.weight,fc1.bias,fc2.weight,fc2.bias
0,0,1.835035,0.347588,2.590909,0.170132,0.809592,0.221882,0.762469,0.220865
1,0,1.834912,0.347498,2.590820,0.170012,0.809675,0.221834,0.762511,0.220336
1,1,1.834888,0.347586,2.590842,0.170018,0.809717,0.221847,0.762801,0.220484
1,2,1.834631,0.347548,2.590701,0.169730,0.809778,0.221797,0.762665,0.220352
1,3,1.834439,0.347614,2.590631,0.169358,0.809835,0.221560,0.762522,0.219614
1,...,...,...,...,...,...,...,...,...
1,1870,2.755737,0.403736,3.623700,0.158214,1.759621,0.237817,1.539549,0.213421
1,1871,2.756150,0.403682,3.624147,0.158209,1.759668,0.237836,1.539834,0.213266
1,1872,2.755474,0.403915,3.623782,0.158215,1.757414,0.238025,1.538726,0.213238
1,1873,2.755699,0.404117,3.624093,0.158215,1.756881,0.238196,1.539321,0.213293


In [5]:
import cProfile
cProfile.run('db.ithhess_eigenval(k=4,opt=True)')



[hessian_eigenthings] beginning deflated power iteration
[hessian_eigenthings] computing eigenvalue/vector 1 of 4
         4967601 function calls (4946640 primitive calls) in 250.561 seconds

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     1875    0.147    0.000   17.863    0.010 <ipython-input-1-59cc902f2346>:52(forward)
        1    0.002    0.002  250.556  250.556 <string>:1(<module>)
   120000    0.150    0.000    0.207    0.000 Image.py:2329(_check_size)
    60000    0.262    0.000    1.054    0.000 Image.py:2347(new)
    60000    0.289    0.000    1.982    0.000 Image.py:2421(frombuffer)
    60000    0.737    0.000    2.755    0.000 Image.py:2482(fromarray)
    60000    0.305    0.000    0.428    0.000 Image.py:461(_getencoder)
   180000    0.224    0.000    0.224    0.000 Image.py:539(__init__)
   180000    0.041    0.000    0.041    0.000 Image.py:559(size)
   120000    0.282    0.000    0.433    0.000 Image.py:563(_new

KeyboardInterrupt: 

In [6]:
import cProfile
cProfile.run('db.ithhess_eigenval(k=4,opt=False)')

[hessian_eigenthings] beginning deflated power iteration
[hessian_eigenthings] computing eigenvalue/vector 1 of 4




         8990621 function calls (8852256 primitive calls) in 135.632 seconds8m28s | Tot: 1ms | power iter error: 1.0000                                                              1/20 

   Ordered by: standard name

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     2876    0.160    0.000   14.507    0.005 <ipython-input-1-59cc902f2346>:52(forward)
        1    0.000    0.000  135.630  135.630 <string>:1(<module>)
   184064    0.207    0.000    0.282    0.000 Image.py:2329(_check_size)
    92032    0.332    0.000    1.416    0.000 Image.py:2347(new)
    92032    0.393    0.000    2.651    0.000 Image.py:2421(frombuffer)
    92032    0.898    0.000    3.599    0.000 Image.py:2482(fromarray)
    92032    0.406    0.000    0.574    0.000 Image.py:461(_getencoder)
   276096    0.294    0.000    0.294    0.000 Image.py:539(__init__)
   276096    0.057    0.000    0.057    0.000 Image.py:559(size)
   184064    0.380    0.000    0.584    0.000 Image.py:563(_new)
  

KeyboardInterrupt: 

In [None]:
from power_iter import Operator, deflated_power_iteration, smallest_eigenvalue
from scipy.sparse.linalg import LinearOperator, eigsh
import scipy
def eigenvalue_analysis2(operator, k=1, tol=1e-6, max_iter=100, quiet=False):
    """Return largest EV in magnitude and smallest algebraic eigenvalue."""
    eigmax, eigmin = smallest_eigenvalue(operator.hvp_op,
                                       power_iter_steps=max_iter,
                                       power_iter_err_threshold=tol,
                                       momentum=0.0,
                                       device=operator.device, quiet=False)
    return eigmax, eigmin

In [None]:
eigmax, eigmin = eigenvalue_analysis2(hess2,max_iter=20)