In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_soom
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Notebook-wide configuration.
# `device` is passed to every tensor and layer created in the cells below.
device = 'cpu'

# Seed the global NumPy and PyTorch generators so that the sampled data set and
# the initial network weights are the same on every Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class Net(nn.Module):
    """Fully connected regressor mapping one input feature to one output.

    Five linear layers (`f1` .. `f5`) with widths 1 -> 10 -> 20 -> 20 -> 10 -> 1.
    A ReLU follows each of the first four layers; the last layer is left
    linear so the output is unbounded.

    Args:
        device: device on which the layer parameters are allocated.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        # Consecutive entries are the (in_features, out_features) of f1..f5.
        widths = (1, 10, 20, 20, 10, 1)
        for index, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, 'f{}'.format(index), nn.Linear(n_in, n_out, device=device))

        # Nonlinearity shared by all hidden layers (nn.Sigmoid() is a drop-in
        # alternative if a different activation is wanted).
        self.activation = nn.ReLU()

    def forward(self, x):
        """Run `x` through the four activated hidden layers and the linear head."""
        hidden = x
        for layer in (self.f1, self.f2, self.f3, self.f4):
            hidden = self.activation(layer(hidden))
        return self.f5(hidden)


In [4]:
# Synthetic regression problem: 1000 scalar inputs drawn uniformly from [0, 1).
X = np.random.uniform(low=0, high=1, size=(1000, 1))
# Target is the row-wise sum of np.sinc over the features; with a single
# feature column this is simply sinc(x), kept as a (1000, 1) column.
y = np.sum(np.sinc(X), axis=1, keepdims=True)

# torch.Tensor(...) yields float32 tensors, which are then moved to `device`.
X_tensor = torch.Tensor(X).to(device)
y_tensor = torch.Tensor(y).to(device)

# batch_size equals the number of samples, so each epoch is one full batch.
torch_data = TensorDataset(X_tensor, y_tensor)
data_loader = DataLoader(torch_data, batch_size=1000)

In [5]:
# Train a fresh network on the data set with pytorch_soom's Newton optimizer,
# minimising the mean squared error.
model = Net(device=device)
loss_fn = nn.MSELoss()
opt = pytorch_soom.Newton(model.parameters(), lr=1, model=model)

# Mean training loss of each epoch, keyed by 1-based epoch number.
all_loss = {}
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        # NOTE(review): unlike torch.optim optimizers, this step() is given the
        # batch inputs -- presumably to re-evaluate the model; confirm against
        # the pytorch_soom documentation.
        opt.step(b_x)

        # Accumulate a detached value: adding `loss` itself would keep every
        # batch's autograd graph reachable from `all_loss` for the whole run.
        all_loss[epoch+1] += loss.detach()
    all_loss[epoch+1] /= len(data_loader)
    # .item() reads the scalar on any device; .numpy() only works for CPU tensors.
    print(', loss: {}'.format(all_loss[epoch+1].item()))

epoch:  0, loss: 0.6696417331695557
epoch:  1, loss: 0.6703011393547058
epoch:  2, loss: 0.6636910438537598
epoch:  3, loss: 0.6507073640823364
epoch:  4, loss: 0.6428490281105042
epoch:  5, loss: 0.6296328902244568
epoch:  6, loss: 0.6126193404197693
epoch:  7, loss: 0.6042999625205994
epoch:  8, loss: 0.5903075933456421
epoch:  9, loss: 0.57951420545578
epoch:  10, loss: 0.569944441318512
epoch:  11, loss: 0.5611693263053894
epoch:  12, loss: 0.5530411601066589
epoch:  13, loss: 0.5438578128814697
epoch:  14, loss: 0.5330502390861511
epoch:  15, loss: 0.5199289917945862
epoch:  16, loss: 0.22097478806972504
epoch:  17, loss: 0.21790513396263123
epoch:  18, loss: 0.21519875526428223
epoch:  19, loss: 0.21125604212284088
epoch:  20, loss: 0.2070716768503189
epoch:  21, loss: 0.19994871318340302
epoch:  22, loss: 0.1974535435438156
epoch:  23, loss: 0.1926138550043106
epoch:  24, loss: 0.18849338591098785
epoch:  25, loss: 0.18567220866680145
epoch:  26, loss: 0.1822865605354309
epoch: 