In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import numpy as np
from sample_grad import *
from sklearn import datasets
import time

class MNIST_small(Dataset):
    
    def __init__(self, train=True):
        digits = datasets.load_digits()
        data_num = digits['target'].shape[0]
        if train == True:
            self.data, self.target = digits['data'][:int(0.8*data_num)], digits['target'][:int(0.8*data_num)]
        else:
            self.data, self.target = digits['data'][int(0.8*data_num):], digits['target'][int(0.8*data_num):]
        
    def __len__(self):
        return self.target.shape[0]
    
    def __getitem__(self, idx):
        data = torch.from_numpy(self.data[idx]).float()
        target = torch.from_numpy(np.asarray(self.target[idx]))
        sample = {'data': data, 'target': target}
        
        return sample

trainset = MNIST_small(train=True)
testset = MNIST_small(train=False)

trainloader = DataLoader(trainset, batch_size=256, shuffle=True, num_workers=0)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)

device = torch.device('cuda:2')


### SGD

In [2]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(64, 40)
        self.fc2 = nn.Linear(40, 10)

    def forward(self, x):
        out = F.relu(self.fc1(x))
        out = F.relu(self.fc2(out))
        return out

model = MLP().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def AccuarcyCompute(pred,label):
    pred = pred.cpu().data.numpy()
    label = label.cpu().data.numpy()
    test_np = (np.argmax(pred,1) == label)
    test_np = np.float32(test_np)
    return np.mean(test_np)

for epoch in range(3):
    for idx, batch in enumerate (trainloader):
        optimizer.zero_grad()
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        time_start = time.time()
        outputs = model(inputs)
        loss = criterion(outputs,labels)
        print ('loss:{}'.format(loss.item()))
        loss.backward()
        optimizer.step()
        time_end = time.time()
        print('one step time',time_end-time_start,'s')

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)

loss:2.7769970893859863
one step time 0.48224949836730957 s
loss:2.6728012561798096
one step time 0.0016546249389648438 s
loss:2.4737870693206787
one step time 0.001741170883178711 s
loss:2.256592273712158
one step time 0.0017704963684082031 s
loss:2.1517393589019775
one step time 0.0018677711486816406 s
loss:2.0385594367980957
one step time 0.0017876625061035156 s
loss:2.00132155418396
one step time 0.0019638538360595703 s
loss:1.955230712890625
one step time 0.0018661022186279297 s
loss:1.857215404510498
one step time 0.0017366409301757812 s
loss:1.8027675151824951
one step time 0.0017311573028564453 s
loss:1.5695836544036865
one step time 0.0016179084777832031 s
loss:1.5444817543029785
one step time 0.0023233890533447266 s
loss:1.6557073593139648
one step time 0.0018928050994873047 s
loss:1.4571723937988281
one step time 0.0019516944885253906 s
loss:1.4737539291381836
one step time 0.0017774105072021484 s
loss:1.4272304773330688
one step time 0.00174713134765625 s
loss:1.50212097167

### LM

In [3]:
from torch.optim.optimizer import Optimizer
import time

class LM(Optimizer):
    '''
    Arguments:
        lr: learning rate (step size) default:1
        alpha: the hyperparameter in the regularization default:0.2
    '''
    def __init__(self, params, lr=1, alpha=0.2):
        defaults = dict(
            lr = lr,
            alpha = alpha
        )
        super(LM, self).__init__(params, defaults)

        if len(self.param_groups) != 1:
            raise ValueError ("LM doesn't support per-parameter options") 

        self._params = self.param_groups[0]['params']

    def step(self, closure=None):
        '''
        performs a single step
        in the closure: we approximate the Hessian for cross entropy loss

        '''
        assert len(self.param_groups) == 1
        group = self.param_groups[0]
        lr = group['lr']
        alpha = group['alpha']
        params = group['params']

        prev_loss, g, H = closure(sample=True)
        
        H += torch.eye(H.shape[0]).to(device)*alpha

        time_start = time.time()
        delta_w = -1 * torch.matmul(torch.inverse(H), g).detach()
        time_end = time.time()
        print('inverting time',time_end-time_start,'s')

        offset = 0
        for p in self._params:
            numel = p.numel()
            with torch.no_grad():
                p.add_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
            offset += numel

        loss = closure(sample=False)

        if loss < prev_loss:
            print(loss.item())
            print ('successful iteration')
            if alpha > 1e-5:
                group['alpha'] /= 10
        else:
            print ('failed iteration')
            if alpha < 1e5:
                group['alpha'] *= 10
            # undo the step
            offset = 0
            for p in self._params:    
                numel = p.numel()
                with torch.no_grad():
                    p.sub_(delta_w[offset:offset + numel].view_as(p),alpha=lr)
                offset += numel


def gather_flat_grads(params):
    '''
    return a matrix with size (batch_size, param_num)
    the flattened gradient of each sample in the batch 

    '''
    views=[]
    for p in params:
        if p.grads is None:
            view = torch.zeros_like(p.grads).view(p.grads.shape[0], -1)
        else:
            view = p.grads.view(p.grads.shape[0], -1)
        views.append(view)
    
    return torch.cat(views, 1)


In [6]:
model = MLP().to(device)
optimizer = LM(model.parameters(),lr=1, alpha=10)
criterion = nn.CrossEntropyLoss()

for epoch in range(3):
    for idx,data in enumerate(trainloader):    
        inputs, labels = batch['data'].to(device), batch['target'].to(device)
        # evaluation H and g in a mini-batch
        def closure(sample=True):
            N = inputs.shape[0]
            if sample:
                time_start = time.time()
                optimizer.zero_grad()
                with save_sample_grads(model):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    loss.backward()
                    g_all = gather_flat_grads(model.parameters())
                    g = g_all.sum(0)/N
                    H = torch.einsum('ijk, ikl -> ijl', [torch.unsqueeze(g_all, 2), torch.unsqueeze(g_all, 1)]).sum(0)/N
                    time_end = time.time()
                    print('sampling time',time_end-time_start,'s')
                    return torch.sum(loss), g, H
            else:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                return torch.sum(loss)
        
        optimizer.step(closure)

print ('testing...')
total, correct = 0, 0
for idx, batch in enumerate(testloader):    
    inputs, labels = batch['data'].to(device), batch['target'].to(device)
    outputs = model(inputs)
    correct += torch.sum(torch.argmax(outputs,1) == labels).item()
    total += inputs.shape[0]
print (correct/total)


sampling time 0.0021500587463378906 s
inverting time 0.9817430973052979 s
2.5856103897094727
successful iteration
sampling time 0.007706403732299805 s
inverting time 0.6140048503875732 s
2.4003236293792725
successful iteration
sampling time 0.00906515121459961 s
inverting time 0.7936170101165771 s
1.9124436378479004
successful iteration
sampling time 0.006339550018310547 s
inverting time 0.7975327968597412 s
failed iteration
sampling time 0.005716085433959961 s
inverting time 0.7659814357757568 s
failed iteration
sampling time 0.01167750358581543 s
inverting time 0.8576579093933105 s
1.7055720090866089
successful iteration
sampling time 0.014444828033447266 s
inverting time 0.6450448036193848 s
1.5008888244628906
successful iteration
sampling time 0.002460002899169922 s
inverting time 0.7141025066375732 s
failed iteration
sampling time 0.012775659561157227 s
inverting time 0.682819128036499 s
failed iteration
sampling time 0.007543802261352539 s
inverting time 0.5627365112304688 s
1.34