In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
from sklearn import datasets


class MNIST_small(Dataset):
    
    def __init__(self, train=True):
        digits = datasets.load_digits()
        data_num = digits['target'].shape[0]
        if train == True:
            self.data, self.target = digits['data'][:int(0.8*data_num)], digits['target'][:int(0.8*data_num)]
        else:
            self.data, self.target = digits['data'][int(0.8*data_num):], digits['target'][int(0.8*data_num):]
        
    def __len__(self):
        return self.target.shape[0]
    
    def __getitem__(self, idx):
        data = torch.from_numpy(self.data[idx]).float()
        target = torch.from_numpy(np.asarray(self.target[idx]))
        sample = {'data': data, 'target': target}
        
        return sample

trainset = MNIST_small(train=True)
testset = MNIST_small(train=False)

trainloader = DataLoader(trainset, batch_size=256, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)

device = torch.device('cuda:2')


### SGD

In [42]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(64, 40)
        self.fc2 = nn.Linear(40, 10)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        return out

model = MLP().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def AccuarcyCompute(pred,label):
    pred = pred.cpu().data.numpy()
    label = label.cpu().data.numpy()
    test_np = (np.argmax(pred,1) == label)
    test_np = np.float32(test_np)
    return np.mean(test_np)

for epoch in range(3):
    for idx, batch in enumerate (trainloader):
        optimizer.zero_grad()
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        time_start = time.time()
        outputs = model(inputs)
        loss = criterion(outputs,labels)
        print ('loss:{}'.format(loss.item()))
        loss.backward()
        optimizer.step()
        time_end = time.time()
        print('one step time',time_end-time_start,'s')

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)

loss:2.4028613567352295
one step time 0.002237558364868164 s
loss:2.4107494354248047
one step time 0.0015096664428710938 s
loss:2.303556442260742
one step time 0.0013704299926757812 s
loss:2.2836899757385254
one step time 0.0015871524810791016 s
loss:2.238734483718872
one step time 0.0017685890197753906 s
loss:2.233896017074585
one step time 0.0014755725860595703 s
loss:2.1757493019104004
one step time 0.0015711784362792969 s
loss:2.1301076412200928
one step time 0.0015959739685058594 s
loss:2.1664392948150635
one step time 0.001483917236328125 s
loss:2.124027729034424
one step time 0.0016427040100097656 s
loss:2.143213987350464
one step time 0.001619577407836914 s
loss:2.066114664077759
one step time 0.0018069744110107422 s
loss:2.042241096496582
one step time 0.0014319419860839844 s
loss:2.0453832149505615
one step time 0.0016672611236572266 s
loss:1.960923671722412
one step time 0.001527547836303711 s
loss:1.9975229501724243
one step time 0.0017650127410888672 s
loss:1.9724082946777

### LM

In [43]:
from torch.optim.optimizer import Optimizer
import time

class LM(Optimizer):
    '''
    Arguments:
        lr: learning rate (step size) default:1
        alpha: the hyperparameter in the regularization default:0.2
    '''
    def __init__(self, params, lr=1, alpha=0.2):
        defaults = dict(
            lr = lr,
            alpha = alpha
        )
        super(LM, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError ("LM doesn't support per-parameter options") 

        self._params = self.param_groups[0]['params']

    def step(self, closure=None):
        '''
        performs a single step
        in the closure: we approximate the Hessian for cross entropy loss

        '''
        assert len(self.param_groups) == 1
        group = self.param_groups[0]
        lr = group['lr']
        alpha = group['alpha']
        params = group['params']

        prev_loss, g, H = closure(sample=True)
        
        H += torch.eye(H.shape[0]).to(device)*alpha

        time_start = time.time()
        delta_w = -1 * torch.matmul(torch.inverse(H), g).detach()
        time_end = time.time()
        print('inverting time',time_end-time_start,'s')

        offset = 0
        for p in self._params:
            numel = p.numel()
            with torch.no_grad():
                p.add_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
            offset += numel

        loss = closure(sample=False)

        if loss < prev_loss:
            print(loss.item())
            print ('successful iteration')
            if alpha > 1e-5:
                group['alpha'] /= 10
        else:
            print ('failed iteration')
            if alpha < 1e5:
                group['alpha'] *= 10
            # undo the step
            offset = 0
            for p in self._params:    
                numel = p.numel()
                with torch.no_grad():
                    p.sub_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
                offset += numel


def gather_flat_grad(params):
    views=[]
    for p in params:
        if p.grad is None:
            view = p.new(p.numel()).zero_()
        else:
            view = p.grad.view(-1)
        views.append(view)
    
    return torch.cat(views, 0)


In [46]:
model = MLP().to(device)
optimizer = LM(model.parameters(),lr=1, alpha=10)
criterion = nn.CrossEntropyLoss(reduce=False)

for epoch in range(3):
    for idx,data in enumerate(trainloader):    
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        # evaluation H and g in a mini-batch
        def closure(sample=True):
            N = inputs.shape[0]
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if sample:
                time_start = time.time()
                for i in range(N):
                    optimizer.zero_grad()
                    grad = torch.autograd.grad(loss[i], model.parameters(), retain_graph=True, allow_unused=True)
                    g_i = torch.cat([x.contiguous().view(-1, 1) for x in grad])
                    H_i = torch.matmul(g_i, g_i.T)
                    if i==0:
                        H = H_i / N
                        g = g_i / N
                    else:
                        H += H_i / N
                        g += g_i / N
                time_end = time.time()
                print('sampling time',time_end-time_start,'s')
                return torch.sum(loss)/N, g, H
            else:
                return torch.sum(loss)/N
        
        optimizer.step(closure)

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)


sampling time 0.08041167259216309 s
inverting time 0.9126813411712646 s
2.9437508583068848
successful iteration
sampling time 0.0813608169555664 s
inverting time 0.5797624588012695 s
2.2479209899902344
successful iteration
sampling time 0.11003422737121582 s
inverting time 0.7813029289245605 s
1.7959229946136475
successful iteration
sampling time 0.09125304222106934 s
inverting time 0.39702343940734863 s
1.5483099222183228
successful iteration
sampling time 0.0987553596496582 s
inverting time 0.4132351875305176 s
failed iteration
sampling time 0.11081194877624512 s
inverting time 0.795978307723999 s
failed iteration
sampling time 0.10111474990844727 s
inverting time 0.6588530540466309 s
1.2116883993148804
successful iteration
sampling time 0.12943673133850098 s
inverting time 0.6536641120910645 s
failed iteration
sampling time 0.10199952125549316 s
inverting time 0.7152912616729736 s
1.038315773010254
successful iteration
sampling time 0.11771893501281738 s
inverting time 1.00996184349