In [98]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
from sklearn import datasets


class MNIST_small(Dataset):
    
    def __init__(self, train=True):
        digits = datasets.load_digits()
        data_num = digits['target'].shape[0]
        if train == True:
            self.data, self.target = digits['data'][:int(0.8*data_num)], digits['target'][:int(0.8*data_num)]
        else:
            self.data, self.target = digits['data'][int(0.8*data_num):], digits['target'][int(0.8*data_num):]
        
    def __len__(self):
        return self.target.shape[0]
    
    def __getitem__(self, idx):
        data = torch.from_numpy(self.data[idx]).float()
        target = torch.from_numpy(np.asarray(self.target[idx]))
        sample = {'data': data, 'target': target}
        
        return sample

trainset = MNIST_small(train=True)
testset = MNIST_small(train=False)

trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)

device = torch.device('cuda:2')


In [124]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(64, 40)
        self.fc2 = nn.Linear(40, 10)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        return out

model = MLP().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def AccuarcyCompute(pred,label):
    pred = pred.cpu().data.numpy()
    label = label.cpu().data.numpy()
    test_np = (np.argmax(pred,1) == label)
    test_np = np.float32(test_np)
    return np.mean(test_np)

for epoch in range(20):
    for idx, batch in enumerate (trainloader):
        optimizer.zero_grad()
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        outputs = model(inputs)
        loss = criterion(outputs,labels)
        print ('loss:{}'.format(loss.item()))
        loss.backward()
        optimizer.step()

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)

loss:2.68456768989563
loss:2.558706045150757
loss:2.525189161300659
loss:2.270024538040161
loss:2.2143473625183105
loss:2.2888429164886475
loss:2.18802809715271
loss:2.1403424739837646
loss:2.073456048965454
loss:2.0570931434631348
loss:1.883184790611267
loss:2.0486807823181152
loss:1.8677685260772705
loss:2.0030908584594727
loss:1.7067409753799438
loss:1.7043460607528687
loss:1.8648264408111572
loss:1.6157644987106323
loss:1.5672208070755005
loss:1.3406749963760376
loss:1.4936156272888184
loss:1.4134997129440308
loss:1.4630792140960693
loss:1.2997300624847412
loss:1.4504523277282715
loss:1.3156503438949585
loss:1.3700509071350098
loss:1.2275644540786743
loss:1.1649738550186157
loss:1.2308419942855835
loss:1.3489254713058472
loss:1.1250503063201904
loss:1.2172441482543945
loss:1.1545486450195312
loss:0.979830265045166
loss:1.2308664321899414
loss:1.0670419931411743
loss:1.1510714292526245
loss:1.1419411897659302
loss:1.2602543830871582
loss:1.0467212200164795
loss:1.0526807308197021
lo

In [110]:
from torch.optim.optimizer import Optimizer

class LM(Optimizer):
    '''
    Arguments:
        lr: learning rate (step size) default:1
        alpha: the hyperparameter in the regularization default:0.2
    '''
    def __init__(self, params, lr=1, alpha=0.2):
        defaults = dict(
            lr = lr,
            alpha = alpha
        )
        super(LM, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError ("LM doesn't support per-parameter options") 

        self._params = self.param_groups[0]['params']

    def step(self, closure=None):
        '''
        performs a single step
        in the closure: we approximate the Hessian for cross entropy loss

        '''
        assert len(self.param_groups) == 1
        group = self.param_groups[0]
        lr = group['lr']
        alpha = group['alpha']
        params = group['params']

        prev_loss, g, H = closure(sample=True)
        H += torch.eye(H.shape[0]).to(device)*alpha
        delta_w = -1 * torch.matmul(torch.inverse(H), g).detach()
        
        offset = 0
        for p in self._params:
            numel = p.numel()
            with torch.no_grad():
                p.add_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
            offset += numel

        loss = closure(sample=False)

        if loss < prev_loss:
            print(loss.item())
            print ('successful iteration')
            if alpha > 1e-5:
                group['alpha'] /= 10
        else:
            print ('failed iteration')
            if alpha < 1e5:
                group['alpha'] *= 10
            # undo the step
            offset = 0
            for p in self._params:    
                numel = p.numel()
                with torch.no_grad():
                    p.sub_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
                offset += numel


def gather_flat_grad(params):
    views=[]
    for p in params:
        if p.grad is None:
            view = p.new(p.numel()).zero_()
        else:
            view = p.grad.view(-1)
        views.append(view)
    
    return torch.cat(views, 0)


In [123]:
model = MLP().to(device)
optimizer = LM(model.parameters(),lr=1, alpha=1)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):
    for idx,data in enumerate(trainloader):    
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        # evaluation H and g in a mini-batch
        def closure(sample=True):
            N = inputs.shape[0]
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            if sample:
                loss.backward()
                g = gather_flat_grad(model.parameters())
                for i in range(inputs.shape[0]):
                    inputs_i, labels_i = torch.unsqueeze(inputs[i], 0), torch.unsqueeze(labels[i], 0)
                    optimizer.zero_grad()
                    outputs_i = model(inputs_i)
                    loss_i = criterion(outputs_i, labels_i)
                    loss_i.backward()
                    flat_grad = gather_flat_grad(model.parameters()).view(-1, 1)
                    H_i = torch.matmul(flat_grad, flat_grad.T)
                    if i==0:
                        H = H_i / N
                    else:
                        H += H_i / N
                return loss.data, g, H
            else:
                return loss.data
        
        optimizer.step(closure)

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)


tensor(2.3458, device='cuda:2')
successful iteration
tensor(1.7412, device='cuda:2')
successful iteration
tensor(1.5835, device='cuda:2')
successful iteration
failed iteration
failed iteration
tensor(1.3722, device='cuda:2')
successful iteration
failed iteration
failed iteration
tensor(1.1818, device='cuda:2')
successful iteration
tensor(1.1088, device='cuda:2')
successful iteration
failed iteration
failed iteration
tensor(0.9991, device='cuda:2')
successful iteration
failed iteration
tensor(0.9544, device='cuda:2')
successful iteration
failed iteration
tensor(0.9357, device='cuda:2')
successful iteration
failed iteration
tensor(0.8606, device='cuda:2')
successful iteration
failed iteration
tensor(0.8081, device='cuda:2')
successful iteration
failed iteration
tensor(0.7646, device='cuda:2')
successful iteration
failed iteration
tensor(0.7246, device='cuda:2')
successful iteration
failed iteration
failed iteration
tensor(0.7169, device='cuda:2')
successful iteration
failed iteration
ten

263