In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
from sklearn import datasets


class MNIST_small(Dataset):
    
    def __init__(self, train=True):
        digits = datasets.load_digits()
        data_num = digits['target'].shape[0]
        if train == True:
            self.data, self.target = digits['data'][:int(0.8*data_num)], digits['target'][:int(0.8*data_num)]
        else:
            self.data, self.target = digits['data'][int(0.8*data_num):], digits['target'][int(0.8*data_num):]
        
    def __len__(self):
        return self.target.shape[0]
    
    def __getitem__(self, idx):
        data = torch.from_numpy(self.data[idx]).float()
        target = torch.from_numpy(np.asarray(self.target[idx]))
        sample = {'data': data, 'target': target}
        
        return sample

trainset = MNIST_small(train=True)
testset = MNIST_small(train=False)

trainloader = DataLoader(trainset, batch_size=256, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)

device = torch.device('cuda:2')


### SGD

In [59]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(64, 40)
        self.fc2 = nn.Linear(40, 10)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        return out

model = MLP().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def AccuarcyCompute(pred,label):
    pred = pred.cpu().data.numpy()
    label = label.cpu().data.numpy()
    test_np = (np.argmax(pred,1) == label)
    test_np = np.float32(test_np)
    return np.mean(test_np)

for epoch in range(3):
    for idx, batch in enumerate (trainloader):
        optimizer.zero_grad()
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        outputs = model(inputs)
        loss = criterion(outputs,labels)
        print ('loss:{}'.format(loss.item()))
        loss.backward()
        optimizer.step()

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)

loss:2.7432808876037598
loss:2.5503458976745605
loss:2.3574624061584473
loss:2.2161707878112793
loss:2.180281639099121
loss:2.1399340629577637
loss:2.136575937271118
loss:1.9382085800170898
loss:1.9730643033981323
loss:1.8597010374069214
loss:1.8164050579071045
loss:1.827252984046936
loss:1.7980520725250244
loss:1.6091443300247192
loss:1.6104958057403564
loss:1.5744025707244873
loss:1.6058464050292969
loss:1.5472010374069214
testing...
0.49722222222222223


### LM

In [60]:
from torch.optim.optimizer import Optimizer
import time

class LM(Optimizer):
    '''
    Arguments:
        lr: learning rate (step size) default:1
        alpha: the hyperparameter in the regularization default:0.2
    '''
    def __init__(self, params, lr=1, alpha=0.2):
        defaults = dict(
            lr = lr,
            alpha = alpha
        )
        super(LM, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError ("LM doesn't support per-parameter options") 

        self._params = self.param_groups[0]['params']

    def step(self, closure=None):
        '''
        performs a single step
        in the closure: we approximate the Hessian for cross entropy loss

        '''
        assert len(self.param_groups) == 1
        group = self.param_groups[0]
        lr = group['lr']
        alpha = group['alpha']
        params = group['params']


        time_start=time.time()
        prev_loss, g, H = closure(sample=True)
        time_end=time.time()
        print('sampling time',time_end-time_start,'s')
        
        H += torch.eye(H.shape[0]).to(device)*alpha

        time_start = time.time()
        delta_w = -1 * torch.matmul(torch.inverse(H), g).detach()
        time_end = time.time()
        print('inverting time',time_end-time_start,'s')

        offset = 0
        for p in self._params:
            numel = p.numel()
            with torch.no_grad():
                p.add_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
            offset += numel

        loss = closure(sample=False)

        if loss < prev_loss:
            print(loss.item())
            print ('successful iteration')
            if alpha > 1e-5:
                group['alpha'] /= 10
        else:
            print ('failed iteration')
            if alpha < 1e5:
                group['alpha'] *= 10
            # undo the step
            offset = 0
            for p in self._params:    
                numel = p.numel()
                with torch.no_grad():
                    p.sub_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
                offset += numel


def gather_flat_grad(params):
    views=[]
    for p in params:
        if p.grad is None:
            view = p.new(p.numel()).zero_()
        else:
            view = p.grad.view(-1)
        views.append(view)
    
    return torch.cat(views, 0)


In [61]:
model = MLP().to(device)
optimizer = LM(model.parameters(),lr=1, alpha=1)
criterion = nn.CrossEntropyLoss(reduce=False)

for epoch in range(3):
    for idx,data in enumerate(trainloader):    
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        # evaluation H and g in a mini-batch
        def closure(sample=True):
            N = inputs.shape[0]
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if sample:
                for i in range(N):
                    optimizer.zero_grad()
                    grad = torch.autograd.grad(loss[i], model.parameters(), retain_graph=True, allow_unused=True)
                    g_i = torch.cat([x.contiguous().view(-1, 1) for x in grad])
                    H_i = torch.matmul(g_i, g_i.T)
                    if i==0:
                        H = H_i / N
                        g = g_i / N
                    else:
                        H += H_i / N
                        g += g_i / N
                return torch.sum(loss)/N, g, H
            else:
                return torch.sum(loss)/N
        
        optimizer.step(closure)

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)


sampling time 0.09043478965759277 s
inverting time 0.8497962951660156 s
2.6508350372314453
successful iteration
sampling time 0.1419360637664795 s
inverting time 0.7393507957458496 s
1.93453049659729
successful iteration
sampling time 0.1215667724609375 s
inverting time 0.6907782554626465 s
1.4909210205078125
successful iteration
sampling time 0.13475823402404785 s
inverting time 1.1203365325927734 s
failed iteration
sampling time 0.10550212860107422 s
inverting time 0.5628073215484619 s
failed iteration
sampling time 0.12768220901489258 s
inverting time 0.6146109104156494 s
1.0897706747055054
successful iteration
sampling time 0.14484047889709473 s
inverting time 0.5620076656341553 s
failed iteration
sampling time 0.11862564086914062 s
inverting time 0.503312349319458 s
0.9412541389465332
successful iteration
sampling time 0.11890959739685059 s
inverting time 0.7089035511016846 s
failed iteration
sampling time 0.08351492881774902 s
inverting time 1.1188771724700928 s
0.653226494789123