In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_soom
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

In [7]:
class Net(nn.Module):
    """Small fully connected regressor: one input feature in, one value out.

    Five linear layers with widths 1 -> 10 -> 20 -> 20 -> 10 -> 1. The shared
    ``activation`` follows every layer except the last, so the output is an
    unbounded real value.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered as f1..f5 in order; those names are the
        # state_dict keys, so they are kept exactly as they were.
        widths = (1, 10, 20, 20, 10, 1)
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'f{}'.format(idx), nn.Linear(n_in, n_out))

        # ReLU is the nonlinearity in use; nn.Sigmoid() was tried as an alternative.
        self.activation = nn.ReLU()

    def forward(self, x):
        """Run ``x`` (last dimension of size 1) through the network."""
        out = x
        for hidden in (self.f1, self.f2, self.f3, self.f4):
            out = self.activation(hidden(out))
        # The final layer is left linear (no activation).
        return self.f5(out)


In [8]:
# Synthetic 1-D regression data: 300 inputs drawn uniformly from [0, 1), with
# target np.sinc(x) summed over the feature axis (a single feature here).
DATA_SEED = 0  # fixed seed so Restart & Run All regenerates the same dataset
rng = np.random.default_rng(DATA_SEED)  # local generator; global NumPy RNG state is untouched

X = rng.uniform(0, 1, size=(300, 1))
y = np.sinc(X).sum(axis=1, keepdims=True)  # keepdims -> shape (300, 1), matching the model output

# torch.Tensor(...) copies the float64 NumPy arrays into tensors of torch's
# default dtype (float32 unless it has been changed).
torch_data = TensorDataset(torch.Tensor(X), torch.Tensor(y))
# Batches of 100 samples served in a fixed order (shuffle is left at its default, False).
data_loader = DataLoader(torch_data, batch_size=100)

In [9]:
# Train Net on the mini-batches from `data_loader` with the Newton optimizer
# from pytorch_soom, recording the mean mini-batch MSE of each epoch in `all_loss`.
MODEL_SEED = 0  # fixes the weight initialisation so reruns start from the same point
N_EPOCHS = 100

torch.manual_seed(MODEL_SEED)
model = Net()
loss_fn = nn.MSELoss()
opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model)

all_loss = {}  # 1-based epoch number -> mean batch loss (0-dim tensor without autograd history)
for epoch in range(N_EPOCHS):
    print('epoch: ', epoch, end='')
    epoch_loss = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # Parameter update step; this optimizer's step() is called with the batch inputs.
        # NOTE(review): presumably it re-evaluates the model on b_x for curvature
        # information - confirm against the pytorch_soom documentation.
        opt.step(b_x)

        # Accumulate a detached copy: summing `loss` itself would keep every
        # batch's autograd graph alive and leave requires_grad tensors in all_loss.
        epoch_loss += loss.detach()
    all_loss[epoch + 1] = epoch_loss / len(data_loader)
    print(', loss: {}'.format(all_loss[epoch + 1].item()))

epoch:  0, loss: 0.6954212188720703
epoch:  1, loss: 0.49472880363464355
epoch:  2, loss: 0.3285824954509735
epoch:  3, loss: 0.23577441275119781
epoch:  4, loss: 0.18615181744098663
epoch:  5, loss: 0.15461742877960205
epoch:  6, loss: 0.13190722465515137
epoch:  7, loss: 0.12261179834604263
epoch:  8, loss: 0.11546588689088821
epoch:  9, loss: 0.11171431094408035
epoch:  10, loss: 0.10975382477045059
epoch:  11, loss: 0.10867925733327866
epoch:  12, loss: 0.1080649271607399
epoch:  13, loss: 0.10769862681627274
epoch:  14, loss: 0.10750968009233475
epoch:  15, loss: 0.10738686472177505
epoch:  16, loss: 0.10731904953718185
epoch:  17, loss: 0.10726212710142136
epoch:  18, loss: 0.10717464238405228
epoch:  19, loss: 0.16936153173446655
epoch:  20, loss: 0.2588174343109131
epoch:  21, loss: 0.13776926696300507
epoch:  22, loss: 0.12368468195199966
epoch:  23, loss: 0.11617035418748856
epoch:  24, loss: 0.11184003204107285
epoch:  25, loss: 0.10974719375371933
epoch:  26, loss: 0.108638