In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch_numopt
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_diabetes
from sklearn.preprocessing import MinMaxScaler

In [2]:
device = "cpu"

# Seed the torch and numpy RNGs so that the random weight initialisation of
# each Net below (and therefore the printed loss curves) is reproducible
# under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
class Net(nn.Module):
    """Fully connected regression network: input_size -> 10 -> 20 -> 20 -> 10 -> 1.

    Args:
        input_size: Number of input features per sample.
        device: Device on which the layer parameters are allocated.
        activation: Module applied after each of the four hidden layers.
            Defaults to a fresh ``nn.ReLU()``; pass e.g. ``nn.Sigmoid()`` to
            try a different non-linearity without editing the class.

    The output layer ``f5`` has no activation, so predictions are unbounded.
    """

    def __init__(self, input_size, device="cpu", activation=None):
        super().__init__()
        self.f1 = nn.Linear(input_size, 10, device=device)
        self.f2 = nn.Linear(10, 20, device=device)
        self.f3 = nn.Linear(20, 20, device=device)
        self.f4 = nn.Linear(20, 10, device=device)
        self.f5 = nn.Linear(10, 1, device=device)

        # A single (parameter-free by default) module shared by all hidden layers.
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, x):
        """Map a batch of shape (..., input_size) to predictions of shape (..., 1)."""
        for hidden in (self.f1, self.f2, self.f3, self.f4):
            x = self.activation(hidden(x))
        return self.f5(x)

In [4]:
# Load the raw (unscaled) diabetes regression data and min-max scale both the
# features and the target. y_scaler is kept so that predictions can later be
# mapped back to the original units with y_scaler.inverse_transform.
X, y = load_diabetes(return_X_y=True, scaled=False)

X_scaler = MinMaxScaler()
X = X_scaler.fit_transform(X)

y_scaler = MinMaxScaler()
# MinMaxScaler needs a 2-D array, so the target is reshaped to one column.
y = y_scaler.fit_transform(y.reshape((-1, 1)))

# torch.as_tensor with an explicit dtype/device replaces the legacy
# torch.Tensor(...) constructor and builds the tensors directly on `device`.
torch_data = TensorDataset(
    torch.as_tensor(X, dtype=torch.float32, device=device),
    torch.as_tensor(y, dtype=torch.float32, device=device),
)

# NOTE(review): this is larger than the dataset (load_diabetes has 442 samples),
# so every epoch is a single full batch — presumably intended for the
# line-search optimizer used below; confirm before lowering it.
BATCH_SIZE = 1000
data_loader = DataLoader(torch_data, batch_size=BATCH_SIZE)

In [5]:
# Train a fresh Net with torch_numopt.AGD, use_diagonal=False, and a
# backtracking line search using the Armijo condition.
# NOTE(review): the meaning of mu / mu_dec / use_diagonal / c1 / tau is defined
# by torch_numopt — confirm against its documentation.
model = Net(input_size=X.shape[1], device=device)
loss_fn = nn.MSELoss()
opt = torch_numopt.AGD(
    model.parameters(),
    lr=1,
    mu=0.001,
    mu_dec=0.1,
    model=model,
    use_diagonal=False,
    c1=1e-4,
    tau=0.5,
    line_search_method="backtrack",
    line_search_cond="armijo",
)

all_loss = []  # mean per-epoch training loss, as plain Python floats
patience = 0
max_patience = 10  # stop after this many consecutive epochs without improvement
for epoch in range(100):
    print("epoch: ", epoch, end="")
    epoch_loss = 0.0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step(b_x, b_y, loss_fn)

        # Record the (pre-step) batch loss as a float. Accumulating the tensor
        # itself would keep each batch's autograd graph alive via all_loss.
        epoch_loss += loss.item()

    all_loss.append(epoch_loss / len(data_loader))

    # Reset patience on improvement; otherwise count down towards early stop.
    if epoch > 0 and all_loss[-2] <= all_loss[-1]:
        patience -= 1
    else:
        patience = max_patience

    print(", loss: {}".format(all_loss[-1]))

    if patience <= 0:
        break

epoch:  0, loss: 0.4143599569797516
epoch:  1, loss: 0.32063186168670654
epoch:  2, loss: 0.28701144456863403
epoch:  3, loss: 0.03955715522170067
epoch:  4, loss: 0.03308854252099991
epoch:  5, loss: 0.030783135443925858
epoch:  6, loss: 0.029517097398638725
epoch:  7, loss: 0.029021598398685455
epoch:  8, loss: 0.028166640549898148
epoch:  9, loss: 0.027926409617066383
epoch:  10, loss: 0.027836577966809273
epoch:  11, loss: 0.027816487476229668
epoch:  12, loss: 0.027093276381492615
epoch:  13, loss: 0.027086913585662842
epoch:  14, loss: 0.026605214923620224
epoch:  15, loss: 0.02647796832025051
epoch:  16, loss: 0.02643202431499958
epoch:  17, loss: 0.02639693394303322
epoch:  18, loss: 0.0261886827647686
epoch:  19, loss: 0.02615484595298767
epoch:  20, loss: 0.026142122223973274
epoch:  21, loss: 0.026088343933224678
epoch:  22, loss: 0.026087898761034012
epoch:  23, loss: 0.02602299675345421
epoch:  24, loss: 0.026013804599642754
epoch:  25, loss: 0.02600683830678463
epoch:  26

In [6]:
# Same experiment as the previous cell, but with use_diagonal=True and a
# smaller tau (0.1 instead of 0.5) for the backtracking Armijo line search.
# NOTE(review): the meaning of mu / mu_dec / use_diagonal / c1 / tau is defined
# by torch_numopt — confirm against its documentation.
model = Net(input_size=X.shape[1], device=device)
loss_fn = nn.MSELoss()
opt = torch_numopt.AGD(
    model.parameters(),
    lr=1,
    mu=0.001,
    mu_dec=0.1,
    model=model,
    use_diagonal=True,
    c1=1e-4,
    tau=0.1,
    line_search_method="backtrack",
    line_search_cond="armijo",
)

all_loss = []  # mean per-epoch training loss, as plain Python floats
patience = 0
max_patience = 10  # stop after this many consecutive epochs without improvement
for epoch in range(100):
    print("epoch: ", epoch, end="")
    epoch_loss = 0.0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step(b_x, b_y, loss_fn)

        # Record the (pre-step) batch loss as a float. Accumulating the tensor
        # itself would keep each batch's autograd graph alive via all_loss.
        epoch_loss += loss.item()

    all_loss.append(epoch_loss / len(data_loader))

    # Reset patience on improvement; otherwise count down towards early stop.
    if epoch > 0 and all_loss[-2] <= all_loss[-1]:
        patience -= 1
    else:
        patience = max_patience

    print(", loss: {}".format(all_loss[-1]))

    if patience <= 0:
        break

epoch:  0, loss: 0.33186641335487366
epoch:  1, loss: 0.13097988069057465
epoch:  2, loss: 0.05423133447766304
epoch:  3, loss: 0.043436627835035324
epoch:  4, loss: 0.031093105673789978
epoch:  5, loss: 0.028638187795877457
epoch:  6, loss: 0.027640851214528084
epoch:  7, loss: 0.026913736015558243
epoch:  8, loss: 0.026571044698357582
epoch:  9, loss: 0.02619623765349388
epoch:  10, loss: 0.02598540298640728
epoch:  11, loss: 0.025838101282715797
epoch:  12, loss: 0.025686677545309067
epoch:  13, loss: 0.02554275281727314
epoch:  14, loss: 0.025423740968108177
epoch:  15, loss: 0.025319743901491165
epoch:  16, loss: 0.025178784504532814
epoch:  17, loss: 0.02508353814482689
epoch:  18, loss: 0.025005122646689415
epoch:  19, loss: 0.024940500035881996
epoch:  20, loss: 0.02487681806087494
epoch:  21, loss: 0.024835428223013878
epoch:  22, loss: 0.024798540398478508
epoch:  23, loss: 0.024725237861275673
epoch:  24, loss: 0.024636441841721535
epoch:  25, loss: 0.024546850472688675
epoc