In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_soom
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

In [2]:
# Prefer the GPU when one is usable, otherwise fall back to the CPU so the
# notebook still runs end-to-end on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
class Net(nn.Module):
    """Small fully connected regressor: one input feature in, one value out.

    Layer widths are 1 -> 10 -> 20 -> 20 -> 10 -> 1. ReLU follows every
    linear layer except the last, so the output is unbounded.

    Args:
        device: device on which the linear layers' parameters are created.
    """

    def __init__(self, device='cpu'):
        super().__init__()
        widths = (1, 10, 20, 20, 10, 1)
        # Register the layers as f1..f5, in that order, so attribute names,
        # state_dict keys and parameter-initialisation order stay the same.
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'f{idx}', nn.Linear(n_in, n_out, device=device))

        self.activation = nn.ReLU()

    def forward(self, x):
        """Run `x` through the four activated hidden layers, then the linear output layer."""
        hidden = x
        for layer in (self.f1, self.f2, self.f3, self.f4):
            hidden = self.activation(layer(hidden))
        return self.f5(hidden)


In [4]:
# Synthetic 1-D regression problem: inputs uniform on [0, 1), targets sinc(x).
N_SAMPLES = 1000
# Seeded local generator so a fresh "Restart & Run All" rebuilds the same
# dataset without touching NumPy's global random state.
rng = np.random.default_rng(0)
X = rng.uniform(0, 1, size=(N_SAMPLES, 1))
# With a single feature the sum over axis 1 only preserves the (N, 1) shape.
y = np.sinc(X).sum(axis=1, keepdims=True)

# torch.Tensor(...) converts the float64 arrays to float32 tensors, which are
# then moved to the target device once, up front.
torch_data = TensorDataset(torch.Tensor(X).to(device), torch.Tensor(y).to(device))
# Batch size equals the dataset size, so each epoch is one full-batch step.
data_loader = DataLoader(torch_data, batch_size=N_SAMPLES)

In [5]:
# Train a freshly initialised network with plain SGD.
model = Net(device=device)
loss_fn = nn.MSELoss()
opt = optim.SGD(model.parameters(), lr = 1e-2)

# Mean training loss per epoch, keyed by 1-based epoch number.
all_loss = {}
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step()

        # Accumulate a detached value: adding `loss` itself would keep every
        # batch's autograd history alive inside `all_loss`.
        all_loss[epoch+1] += loss.detach()
    all_loss[epoch+1] /= len(data_loader)
    print(', loss: {}'.format(all_loss[epoch+1].item()))

epoch:  0, loss: 0.6124632954597473
epoch:  1, loss: 0.5835167765617371
epoch:  2, loss: 0.5563561320304871
epoch:  3, loss: 0.5308499932289124
epoch:  4, loss: 0.5068796277046204
epoch:  5, loss: 0.4843372702598572
epoch:  6, loss: 0.4631248414516449
epoch:  7, loss: 0.44315293431282043
epoch:  8, loss: 0.42433974146842957
epoch:  9, loss: 0.4066101312637329
epoch:  10, loss: 0.3898950219154358
epoch:  11, loss: 0.3741307556629181
epoch:  12, loss: 0.3592585623264313
epoch:  13, loss: 0.345223993062973
epoch:  14, loss: 0.33197659254074097
epoch:  15, loss: 0.31946951150894165
epoch:  16, loss: 0.3076590299606323
epoch:  17, loss: 0.2965046465396881
epoch:  18, loss: 0.2859683334827423
epoch:  19, loss: 0.27601465582847595
epoch:  20, loss: 0.2666104733943939
epoch:  21, loss: 0.25772473216056824
epoch:  22, loss: 0.24932830035686493
epoch:  23, loss: 0.2413937896490097
epoch:  24, loss: 0.23389555513858795
epoch:  25, loss: 0.22680939733982086
epoch:  26, loss: 0.22011259198188782
ep

In [6]:
# Train a freshly initialised network with Adam, using the same learning
# rate and epoch count as the SGD run for comparison.
model = Net(device=device)
loss_fn = nn.MSELoss()
opt = optim.Adam(model.parameters(), lr = 1e-2)

# Mean training loss per epoch, keyed by 1-based epoch number.
all_loss = {}
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step()

        # Accumulate a detached value: adding `loss` itself would keep every
        # batch's autograd history alive inside `all_loss`.
        all_loss[epoch+1] += loss.detach()
    all_loss[epoch+1] /= len(data_loader)
    print(', loss: {}'.format(all_loss[epoch+1].item()))

epoch:  0, loss: 0.21514053642749786
epoch:  1, loss: 0.17482998967170715
epoch:  2, loss: 0.14475904405117035
epoch:  3, loss: 0.12503978610038757
epoch:  4, loss: 0.11315830051898956
epoch:  5, loss: 0.11294684559106827
epoch:  6, loss: 0.11638039350509644
epoch:  7, loss: 0.11348729580640793
epoch:  8, loss: 0.10554010421037674
epoch:  9, loss: 0.097958043217659
epoch:  10, loss: 0.09447843581438065
epoch:  11, loss: 0.09197928011417389
epoch:  12, loss: 0.09057636559009552
epoch:  13, loss: 0.08858001977205276
epoch:  14, loss: 0.08416341245174408
epoch:  15, loss: 0.07735150307416916
epoch:  16, loss: 0.07017824798822403
epoch:  17, loss: 0.06418316811323166
epoch:  18, loss: 0.05908946692943573
epoch:  19, loss: 0.053061045706272125
epoch:  20, loss: 0.04469316452741623
epoch:  21, loss: 0.035440560430288315
epoch:  22, loss: 0.02774355188012123
epoch:  23, loss: 0.022123951464891434
epoch:  24, loss: 0.016693610697984695
epoch:  25, loss: 0.011359634809195995
epoch:  26, loss: 0