In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_soom
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import *
from sklearn.preprocessing import MinMaxScaler

In [2]:
device = 'cpu'

In [3]:
class Net(nn.Module):
    def __init__(self, input_size, device='cpu'):
        super().__init__()
        self.f1 = nn.Linear(input_size, 10, device=device)
        self.f2 = nn.Linear(10, 20, device=device)
        self.f3 = nn.Linear(20, 20, device=device)
        self.f4 = nn.Linear(20, 10, device=device)
        self.f5 = nn.Linear(10, 1, device=device)

        self.activation = nn.ReLU()
        # self.activation = nn.Sigmoid()

    def forward(self, x):
        x = self.activation(self.f1(x))
        x = self.activation(self.f2(x))
        x = self.activation(self.f3(x))
        x = self.activation(self.f4(x))
        x = self.f5(x)
        
        return x


In [4]:
X, y = load_diabetes(return_X_y = True, scaled=False)
# X, y = make_regression(n_samples=1000, n_features=100)
# print(X.shape)

X_scaler = MinMaxScaler()
X = X_scaler.fit_transform(X)

y_scaler = MinMaxScaler()
y = y_scaler.fit_transform(y.reshape((-1, 1)))

torch_data = TensorDataset(torch.Tensor(X).to(device), torch.Tensor(y).to(device))
data_loader = DataLoader(torch_data, batch_size=1000)

In [11]:
model = Net(input_size = X.shape[1], device=device)
loss_fn = nn.MSELoss()
# opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=False, c1=1e-4, tau=0.5, line_search_method='backtrack', line_search_cond='armijo')
opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=False, c1=1e-4, tau=0.5, line_search_method='backtrack', line_search_cond='wolfe')
# opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=False, c1=1e-4, tau=0.5, line_search_method='backtrack', line_search_cond='strong-wolfe')
# opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=False, c1=1e-4, tau=0.5, line_search_method='backtrack', line_search_cond='goldstein')

all_loss = {}
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for batch_idx, (b_x, b_y) in enumerate(data_loader):
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step(b_x, b_y, loss_fn)

        all_loss[epoch+1] += loss
    all_loss[epoch+1] /= len(data_loader)
    print(', loss: {}'.format(all_loss[epoch+1].detach().numpy().item()))

epoch:  0, loss: 0.5097523331642151
epoch:  1, loss: 0.5009604692459106
epoch:  2, loss: 0.4380524158477783
epoch:  3, loss: 0.4298827648162842
epoch:  4, loss: 0.4264882504940033
epoch:  5, loss: 0.3951069116592407
epoch:  6, loss: 0.27165091037750244
epoch:  7, loss: 0.18259297311306
epoch:  8, loss: 0.17886075377464294
epoch:  9, loss: 0.1497703492641449
epoch:  10, loss: 0.1490883082151413
epoch:  11, loss: 0.11225183308124542
epoch:  12, loss: 0.09590497612953186
epoch:  13, loss: 0.0929986760020256
epoch:  14, loss: 0.09023471176624298
epoch:  15, loss: 0.07982666045427322
epoch:  16, loss: 0.07568445056676865
epoch:  17, loss: 0.07520782947540283
epoch:  18, loss: 0.07335309684276581
epoch:  19, loss: 0.07245758175849915
epoch:  20, loss: 0.0706978440284729
epoch:  21, loss: 0.06732102483510971
epoch:  22, loss: 0.06428799033164978
epoch:  23, loss: 0.058572329580783844
epoch:  24, loss: 0.05370477959513664
epoch:  25, loss: 0.049855269491672516
epoch:  26, loss: 0.0481640137732

In [10]:
model = Net(input_size = X.shape[1], device=device)
loss_fn = nn.MSELoss()
opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=True, c1=1e-4, tau=0.5, line_search_method='backtrack', line_search_cond="armijo")
# opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=True, c1=1e-4, tau=0.5, line_search_method='backtrack', line_search_cond="wolfe")
# opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=True, c1=1e-4, tau=0.5, line_search_method='backtrack', line_search_cond="strong-wolfe")
# opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model, hessian_approx=True, c1=0.1, tau=0.99, line_search_method='backtrack', line_search_cond="goldstein")

all_loss = {}
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for batch_idx, (b_x, b_y) in enumerate(data_loader):
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step(b_x, b_y, loss_fn)

        all_loss[epoch+1] += loss
    all_loss[epoch+1] /= len(data_loader)
    print(', loss: {}'.format(all_loss[epoch+1].detach().numpy().item()))

epoch:  0, loss: 0.08656758069992065
epoch:  1, loss: 0.08137121051549911
epoch:  2, loss: 0.06291179358959198
epoch:  3, loss: 0.06012265011668205
epoch:  4, loss: 0.05838647484779358
epoch:  5, loss: 0.058228667825460434
epoch:  6, loss: 0.05772688612341881
epoch:  7, loss: 0.057529449462890625
epoch:  8, loss: 0.05750676989555359
epoch:  9, loss: 0.05740099772810936
epoch:  10, loss: 0.0573103241622448
epoch:  11, loss: 0.057204682379961014
epoch:  12, loss: 0.056979354470968246
epoch:  13, loss: 0.056944865733385086
epoch:  14, loss: 0.05679263547062874
epoch:  15, loss: 0.05633760988712311
epoch:  16, loss: 0.05442000553011894
epoch:  17, loss: 0.05227173864841461
epoch:  18, loss: 0.050003618001937866
epoch:  19, loss: 0.04958706349134445
epoch:  20, loss: 0.04880253225564957
epoch:  21, loss: 0.03947758302092552
epoch:  22, loss: 0.03696170821785927
epoch:  23, loss: 0.03608682379126549
epoch:  24, loss: 0.03570878878235817
epoch:  25, loss: 0.033621177077293396
epoch:  26, loss