In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_soom
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset

In [2]:
class Net(nn.Module):
    """Fully connected regression network mapping one input feature to one output.

    The layer widths are 1 -> 10 -> 20 -> 20 -> 10 -> 1. The same activation
    module is applied after every linear layer except the last one.

    Args:
        activation: Module applied after each hidden layer. When omitted,
            ``nn.ReLU()`` is used, so ``Net()`` behaves as before. Pass e.g.
            ``nn.Sigmoid()`` to compare activations without editing the class.
    """

    def __init__(self, activation=None):
        super().__init__()
        self.f1 = nn.Linear(1, 10)
        self.f2 = nn.Linear(10, 20)
        self.f3 = nn.Linear(20, 20)
        self.f4 = nn.Linear(20, 10)
        self.f5 = nn.Linear(10, 1)

        # Created per instance (not as a default argument value) so that two
        # networks never share a single activation module object.
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, x):
        """Run ``x`` through the network.

        Args:
            x: Tensor whose last dimension has size 1 (required by ``f1``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1.
        """
        for hidden in (self.f1, self.f2, self.f3, self.f4):
            x = self.activation(hidden(x))

        # The output layer is left linear: no activation after f5.
        return self.f5(x)


In [6]:
# Fixed seed so the sampled inputs are identical on every Restart & Run All.
# A local Generator is used so NumPy's global random state is left untouched.
DATA_SEED = 0
rng = np.random.default_rng(DATA_SEED)

# 300 scalar inputs drawn uniformly from [0, 1), stored as a (300, 1) column.
X = rng.uniform(0, 1, size=(300, 1))
# Regression target: np.sinc of each input, summed over the feature axis.
# With a single feature column the sum is just sinc(x); keepdims keeps y at
# shape (300, 1) so it lines up with the network's (batch, 1) output.
y = np.sinc(X).sum(axis=1, keepdims=True)

# NumPy produces float64; convert explicitly to float32 to match the default
# dtype of nn.Linear parameters instead of relying on the legacy torch.Tensor
# constructor's implicit cast.
torch_data = TensorDataset(
    torch.as_tensor(X, dtype=torch.float32),
    torch.as_tensor(y, dtype=torch.float32),
)
# 300 samples at 100 per batch -> 3 batches per epoch, in fixed order.
data_loader = DataLoader(torch_data, batch_size=100)

In [7]:
# Train a fresh Net with plain SGD and record the mean batch loss per epoch.
model = Net()
loss_fn = nn.MSELoss()
opt = optim.SGD(model.parameters(), lr=1e-2)

# Maps 1-based epoch number -> mean MSE over that epoch's batches.
# Values are 0-dim tensors that do not require grad.
all_loss = {}
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step()

        # Accumulate a detached value: adding `loss` itself would make the
        # running sum part of the autograd graph and keep every batch's graph
        # reachable from all_loss for the lifetime of the dict.
        all_loss[epoch+1] += loss.detach()
    all_loss[epoch+1] /= len(data_loader)
    print(', loss: {}'.format(all_loss[epoch+1].item()))

epoch:  0, loss: 0.7849858403205872
epoch:  1, loss: 0.6961668133735657
epoch:  2, loss: 0.6188426613807678
epoch:  3, loss: 0.5514369010925293
epoch:  4, loss: 0.49262356758117676
epoch:  5, loss: 0.44128191471099854
epoch:  6, loss: 0.3964559733867645
epoch:  7, loss: 0.3573244512081146
epoch:  8, loss: 0.32317742705345154
epoch:  9, loss: 0.29339954257011414
epoch:  10, loss: 0.26753804087638855
epoch:  11, loss: 0.245152547955513
epoch:  12, loss: 0.2257201075553894
epoch:  13, loss: 0.2088526040315628
epoch:  14, loss: 0.19421260058879852
epoch:  15, loss: 0.1815105676651001
epoch:  16, loss: 0.17049908638000488
epoch:  17, loss: 0.1609625369310379
epoch:  18, loss: 0.15271227061748505
epoch:  19, loss: 0.14558278024196625
epoch:  20, loss: 0.13942863047122955
epoch:  21, loss: 0.13412193953990936
epoch:  22, loss: 0.12955059111118317
epoch:  23, loss: 0.12561637163162231
epoch:  24, loss: 0.1222333088517189
epoch:  25, loss: 0.11932641267776489
epoch:  26, loss: 0.116830267012119

In [8]:
# Train a fresh Net with Adam and record the mean batch loss per epoch.
# Note: this rebinds `model` and `all_loss`, replacing any earlier values.
model = Net()
loss_fn = nn.MSELoss()
opt = optim.Adam(model.parameters(), lr=1e-2)

# Maps 1-based epoch number -> mean MSE over that epoch's batches.
# Values are 0-dim tensors that do not require grad.
all_loss = {}
for epoch in range(100):
    print('epoch: ', epoch, end='')
    all_loss[epoch+1] = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step()

        # Accumulate a detached value: adding `loss` itself would make the
        # running sum part of the autograd graph and keep every batch's graph
        # reachable from all_loss for the lifetime of the dict.
        all_loss[epoch+1] += loss.detach()
    all_loss[epoch+1] /= len(data_loader)
    print(', loss: {}'.format(all_loss[epoch+1].item()))

epoch:  0, loss: 0.5276046395301819
epoch:  1, loss: 0.3672352135181427
epoch:  2, loss: 0.20673155784606934
epoch:  3, loss: 0.16143134236335754
epoch:  4, loss: 0.15939025580883026
epoch:  5, loss: 0.10734068602323532
epoch:  6, loss: 0.09610405564308167
epoch:  7, loss: 0.08673232793807983
epoch:  8, loss: 0.06293223053216934
epoch:  9, loss: 0.04226066544651985
epoch:  10, loss: 0.025672635063529015
epoch:  11, loss: 0.007858662866055965
epoch:  12, loss: 0.004108670633286238
epoch:  13, loss: 0.004339640494436026
epoch:  14, loss: 0.007878604345023632
epoch:  15, loss: 0.0061804247088730335
epoch:  16, loss: 0.0032874392345547676
epoch:  17, loss: 0.0011900911340489984
epoch:  18, loss: 0.0012900785077363253
epoch:  19, loss: 0.0019112005829811096
epoch:  20, loss: 0.0017971588531509042
epoch:  21, loss: 0.0012107500806450844
epoch:  22, loss: 0.0005990276695229113
epoch:  23, loss: 0.0005925684818066657
epoch:  24, loss: 0.0006833143415860832
epoch:  25, loss: 0.00064061739249154