In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch_numopt
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_diabetes
from sklearn.preprocessing import MinMaxScaler
import time

In [2]:
# Device on which the model parameters and the dataset tensors are created.
device = "cpu"

In [3]:
class Net(nn.Module):
    def __init__(self, input_size, device='cpu'):
        super().__init__()
        self.f1 = nn.Linear(input_size, 10, device=device)
        self.f2 = nn.Linear(10, 20, device=device)
        self.f3 = nn.Linear(20, 20, device=device)
        self.f4 = nn.Linear(20, 10, device=device)
        self.f5 = nn.Linear(10, 1, device=device)

        self.activation = nn.ReLU()
        # self.activation = nn.Sigmoid()

    def forward(self, x):
        x = self.activation(self.f1(x))
        x = self.activation(self.f2(x))
        x = self.activation(self.f3(x))
        x = self.activation(self.f4(x))
        x = self.f5(x)
        
        return x


In [4]:
# Load the diabetes regression dataset with raw (unscaled) features and
# min-max scale both the features and the target into [0, 1].
X, y = load_diabetes(return_X_y=True, scaled=False)

X_scaler = MinMaxScaler()
X = X_scaler.fit_transform(X)

# MinMaxScaler expects a 2-D array, so the 1-D target is reshaped to a column.
y_scaler = MinMaxScaler()
y = y_scaler.fit_transform(y.reshape((-1, 1)))

# torch.tensor with an explicit dtype/device replaces the legacy torch.Tensor
# constructor + .to(device): float32 matches the default nn.Linear parameter
# dtype, and the tensors are created directly on the target device.
torch_data = TensorDataset(
    torch.tensor(X, dtype=torch.float32, device=device),
    torch.tensor(y, dtype=torch.float32, device=device),
)
# The diabetes dataset has 442 samples, so batch_size=1000 yields a single
# full batch per epoch.
data_loader = DataLoader(torch_data, batch_size=1000)

In [5]:
# Train the network with torch_numopt's Levenberg-Marquardt optimizer
# (use_diagonal=False), early-stopping on the per-epoch mean training loss,
# and report the mean wall-clock time per epoch.
torch.manual_seed(0)  # fixed seed so the initial weights are reproducible
model = Net(input_size=X.shape[1], device=device)
loss_fn = nn.MSELoss()
opt = torch_numopt.LM(model.parameters(), lr=1, mu=0.001, mu_dec=0.1, model=model, use_diagonal=False, c1=1e-4, tau=0.1, line_search_method='backtrack')

max_epochs = 100
max_patience = 10  # consecutive non-improving epochs tolerated before stopping

times = []     # wall-clock seconds spent in the batch loop of each epoch
all_loss = []  # mean training loss per epoch, as plain Python floats
patience = 0
for epoch in range(max_epochs):
    start = time.perf_counter()
    print('epoch: ', epoch, end='')
    epoch_loss = 0.0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step(b_x, b_y, loss_fn)
        opt.update(loss)

        # .item() stores a detached float: accumulating the tensor itself would
        # keep each batch's autograd graph alive and leave all_loss unusable
        # for numpy/matplotlib without an explicit .detach().
        epoch_loss += loss.item()
    end = time.perf_counter()
    times.append(end - start)

    all_loss.append(epoch_loss / len(data_loader))

    # Lose one unit of patience when the loss did not improve on the previous
    # epoch; any improvement (and the first epoch) resets it.
    if epoch > 0 and all_loss[epoch - 1] <= all_loss[epoch]:
        patience -= 1
    else:
        patience = max_patience

    print(', loss: {}'.format(all_loss[epoch]))

    if patience <= 0:
        break

print('mean epoch time: {} s'.format(sum(times) / len(times)))

epoch:  0, loss: 0.14730283617973328
epoch:  1, loss: 0.1328197419643402
epoch:  2, loss: 0.11953799426555634
epoch:  3, loss: 0.10804057866334915
epoch:  4, loss: 0.09924285113811493
epoch:  5, loss: 0.09785141050815582
epoch:  6, loss: 0.09741376340389252
epoch:  7, loss: 0.09637270122766495
epoch:  8, loss: 0.0953952893614769
epoch:  9, loss: 0.09487812221050262
epoch:  10, loss: 0.08663635700941086
epoch:  11, loss: 0.0865328460931778
epoch:  12, loss: 0.08645948022603989
epoch:  13, loss: 0.07620896399021149
epoch:  14, loss: 0.07614918798208237
epoch:  15, loss: 0.0753786489367485
epoch:  16, loss: 0.07462982088327408
epoch:  17, loss: 0.07427847385406494
epoch:  18, loss: 0.06533383578062057
epoch:  19, loss: 0.06434480845928192
epoch:  20, loss: 0.055998388677835464
epoch:  21, loss: 0.0513995923101902
epoch:  22, loss: 0.04815226420760155
epoch:  23, loss: 0.04434782266616821
epoch:  24, loss: 0.040134381502866745
epoch:  25, loss: 0.03988191857933998
epoch:  26, loss: 0.03730

In [6]:
# Train the network with torch_numopt's Levenberg-Marquardt optimizer using
# use_diagonal=True, early-stopping on the per-epoch mean training loss, and
# report the mean wall-clock time per epoch.
torch.manual_seed(0)  # fixed seed so the initial weights are reproducible
model = Net(input_size=X.shape[1], device=device)
loss_fn = nn.MSELoss()
opt = torch_numopt.LM(model.parameters(), lr=1, mu=0.001, mu_dec=0.1, model=model, use_diagonal=True, c1=1e-4, tau=0.1, line_search_method='backtrack')

max_epochs = 100
max_patience = 10  # consecutive non-improving epochs tolerated before stopping

times = []     # wall-clock seconds spent in the batch loop of each epoch
all_loss = []  # mean training loss per epoch, as plain Python floats
patience = 0
for epoch in range(max_epochs):
    start = time.perf_counter()
    print('epoch: ', epoch, end='')
    epoch_loss = 0.0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step(b_x, b_y, loss_fn)
        opt.update(loss)

        # .item() stores a detached float so the batch's autograd graph is not
        # kept alive by all_loss.
        epoch_loss += loss.item()
    end = time.perf_counter()
    times.append(end - start)

    all_loss.append(epoch_loss / len(data_loader))

    # Lose one unit of patience when the loss did not improve on the previous
    # epoch; any improvement (and the first epoch) resets it.
    if epoch > 0 and all_loss[epoch - 1] <= all_loss[epoch]:
        patience -= 1
    else:
        patience = max_patience

    print(', loss: {}'.format(all_loss[epoch]))

    if patience <= 0:
        break

print('mean epoch time: {} s'.format(sum(times) / len(times)))

epoch:  0, loss: 0.5393333435058594
epoch:  1, loss: 0.5313586592674255
epoch:  2, loss: 0.5293454527854919
epoch:  3, loss: 0.5273312926292419
epoch:  4, loss: 0.5253179669380188
epoch:  5, loss: 0.5233054757118225
epoch:  6, loss: 0.521294116973877
epoch:  7, loss: 0.5192835330963135
epoch:  8, loss: 0.517274022102356
epoch:  9, loss: 0.5152654647827148
epoch:  10, loss: 0.5132580399513245
epoch:  11, loss: 0.5112513899803162
epoch:  12, loss: 0.5092458128929138
epoch:  13, loss: 0.5072413086891174
epoch:  14, loss: 0.5052377581596375
epoch:  15, loss: 0.5032352209091187
epoch:  16, loss: 0.5012338161468506
epoch:  17, loss: 0.4992333948612213
epoch:  18, loss: 0.4972340166568756
epoch:  19, loss: 0.49523571133613586
epoch:  20, loss: 0.4932384788990021
epoch:  21, loss: 0.49124228954315186
epoch:  22, loss: 0.48924720287323
epoch:  23, loss: 0.48725321888923645
epoch:  24, loss: 0.4852602481842041
epoch:  25, loss: 0.48326846957206726
epoch:  26, loss: 0.48127782344818115
epoch:  27