In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_soom
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Prefer the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class Net(nn.Module):
    """Small fully connected regressor: one input feature in, one value out.

    The layer widths are 1 -> 10 -> 20 -> 20 -> 10 -> 1. The four hidden
    layers are each followed by a ReLU; the output layer is purely linear.

    Args:
        device: device on which the linear layers' parameters are created.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        widths = (1, 10, 20, 20, 10, 1)
        # Register the layers as attributes f1..f5, in order, so parameter
        # names and creation order match the hand-written layer list.
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'f{idx}', nn.Linear(n_in, n_out, device=device))

        self.activation = nn.ReLU()

    def forward(self, x):
        """Run `x` through the four activated hidden layers, then the linear output layer."""
        for hidden in (self.f1, self.f2, self.f3, self.f4):
            x = self.activation(hidden(x))
        return self.f5(x)


In [4]:
# Synthetic 1-D regression problem: inputs drawn uniformly from [0, 1),
# targets y = sinc(x).
SEED = 0          # fixed so that Restart & Run All regenerates the same dataset
N_SAMPLES = 1000  # number of (x, y) pairs; also used as the batch size below

rng = np.random.default_rng(SEED)
X = rng.uniform(0, 1, size=(N_SAMPLES, 1))
# Summing over axis 1 with keepdims=True keeps y at shape (N_SAMPLES, 1), so
# the same expression also works if X is given more than one feature column.
y = np.sinc(X).sum(axis=1, keepdims=True)

# Convert explicitly to float32 (NumPy produces float64) and move to `device`.
torch_data = TensorDataset(
    torch.as_tensor(X, dtype=torch.float32).to(device),
    torch.as_tensor(y, dtype=torch.float32).to(device),
)
# batch_size equals the dataset size, so each epoch is a single full batch.
data_loader = DataLoader(torch_data, batch_size=N_SAMPLES)

In [7]:
# Train a fresh Net with pytorch_soom.LM configured with use_diagonal=False,
# logging the mean MSE of every epoch.
model = Net(device=device)
loss_fn = nn.MSELoss()
opt = pytorch_soom.LM(model.parameters(), lr=1, ld=1, model=model, use_diagonal=False)

all_loss = []      # mean loss per epoch, stored as detached tensors
patience = 0       # always overwritten in epoch 0 (the else-branch below runs)
max_patience = 10  # allowed number of loss increases since the last reset
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss.append(0)
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer; this optimizer's step()
        # is given the input batch. opt.update(loss) is left disabled here.
        opt.step(b_x)

        # Accumulate a detached copy: adding `loss` itself would keep every
        # batch's autograd history alive for as long as all_loss exists.
        all_loss[epoch] += loss.detach()

    all_loss[epoch] /= len(data_loader)

    # Early stopping: each epoch whose loss is higher than the previous
    # epoch's costs one unit of patience; any other epoch resets it.
    if epoch > 0 and all_loss[epoch-1] < all_loss[epoch]:
        patience -= 1
    else:
        patience = max_patience

    print(', loss: {}'.format(all_loss[epoch].item()))

    if patience <= 0:
        break

epoch:  0, loss: 0.5460023283958435
epoch:  1, loss: 0.5388656854629517
epoch:  2, loss: 0.5317680239677429
epoch:  3, loss: 0.5247076749801636
epoch:  4, loss: 0.5176854133605957
epoch:  5, loss: 0.5102195143699646
epoch:  6, loss: 0.5034319162368774
epoch:  7, loss: 0.4967721700668335
epoch:  8, loss: 0.49019044637680054
epoch:  9, loss: 0.48401814699172974
epoch:  10, loss: 0.47814521193504333
epoch:  11, loss: 0.4717056453227997
epoch:  12, loss: 0.4656820297241211
epoch:  13, loss: 0.45930030941963196
epoch:  14, loss: 0.4529972970485687
epoch:  15, loss: 0.44826504588127136
epoch:  16, loss: 0.44304463267326355
epoch:  17, loss: 0.4381445348262787
epoch:  18, loss: 0.4325650930404663
epoch:  19, loss: 0.4270698130130768
epoch:  20, loss: 0.4212556779384613
epoch:  21, loss: 0.4158148467540741
epoch:  22, loss: 0.41040951013565063
epoch:  23, loss: 0.40490931272506714
epoch:  24, loss: 0.3995688855648041
epoch:  25, loss: 0.39403975009918213
epoch:  26, loss: 0.38878437876701355
e

In [6]:
# Train a fresh Net with pytorch_soom.LM configured with use_diagonal=True and
# debug_stability=True, logging the mean MSE of every epoch.
model = Net(device=device)
loss_fn = nn.MSELoss()
opt = pytorch_soom.LM(model.parameters(), lr=1, ld=1, model=model, use_diagonal=True, debug_stability=True)

all_loss = {}      # mean loss per epoch as detached tensors, keyed by epoch + 1
patience = 0       # always overwritten in epoch 0 (the else-branch below runs)
max_patience = 10  # allowed number of loss increases since the last reset
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer; this optimizer's step()
        # is given the input batch, and update() is given the batch loss.
        opt.step(b_x)
        opt.update(loss)

        # Accumulate a detached copy: adding `loss` itself would keep every
        # batch's autograd history alive for as long as all_loss exists.
        all_loss[epoch+1] += loss.detach()

    all_loss[epoch+1] /= len(data_loader)

    # Early stopping: each epoch whose loss is higher than the previous
    # epoch's costs one unit of patience; any other epoch resets it.
    if epoch > 0 and all_loss[epoch] < all_loss[epoch+1]:
        patience -= 1
    else:
        patience = max_patience

    print(', loss: {}'.format(all_loss[epoch+1].item()))

    if patience <= 0:
        break

epoch:  0

  opt = pytorch_soom.LM(model.parameters(), lr=1, ld=1, model=model, use_diagonal=True, debug_stability=True)


, loss: 0.3595387935638428
epoch:  1, loss: 0.35622265934944153
epoch:  2, loss: 0.3526589572429657
epoch:  3, loss: 0.35100114345550537
epoch:  4, loss: 0.3460403382778168
epoch:  5, loss: 0.34098583459854126
epoch:  6, loss: 0.33836400508880615
epoch:  7, loss: 0.33150574564933777
epoch:  8, loss: 0.3263603150844574
epoch:  9, loss: 0.3205708861351013
epoch:  10, loss: 0.3160203695297241
epoch:  11, loss: 0.31031689047813416
epoch:  12, loss: 0.3050231337547302
epoch:  13, loss: 0.300214946269989
epoch:  14, loss: 0.2954219877719879
epoch:  15, loss: 0.29049327969551086
epoch:  16, loss: 0.28598761558532715
epoch:  17, loss: 0.2814864218235016
epoch:  18, loss: 0.27658283710479736
epoch:  19, loss: 0.2722622752189636
epoch:  20, loss: 0.26798656582832336
epoch:  21, loss: 0.26362207531929016
epoch:  22, loss: 0.2598413825035095
epoch:  23, loss: 0.2556151747703552
epoch:  24, loss: 0.25531333684921265
epoch:  25, loss: 0.25091052055358887
epoch:  26, loss: 0.24656879901885986
epoch: 