In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pytorch_soom
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset
from sklearn.datasets import load_diabetes
from sklearn.preprocessing import MinMaxScaler

In [2]:
# Configuration: compute device and RNG seed for the whole notebook.
device = 'cpu'

# Seed the RNGs up front so weight initialisation (and hence every printed
# loss curve below) is identical on Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
class Net(nn.Module):
    """Fully connected regression network: input_size -> 10 -> 20 -> 20 -> 10 -> 1.

    Args:
        input_size: number of input features (size of the last dimension of ``x``).
        device: device on which the layer parameters are created.
        activation: module applied after each hidden layer. Defaults to a fresh
            ``nn.ReLU()``; pass e.g. ``nn.Sigmoid()`` to try another
            non-linearity without editing the class.
    """

    def __init__(self, input_size, device='cpu', activation=None):
        super().__init__()
        self.f1 = nn.Linear(input_size, 10, device=device)
        self.f2 = nn.Linear(10, 20, device=device)
        self.f3 = nn.Linear(20, 20, device=device)
        self.f4 = nn.Linear(20, 10, device=device)
        self.f5 = nn.Linear(10, 1, device=device)

        # None sentinel (rather than a module as the default value) so every
        # instance gets its own activation module.
        self.activation = nn.ReLU() if activation is None else activation

    def forward(self, x):
        """Apply the four hidden layers with the activation, then the linear output layer.

        The output layer has no activation, so the last dimension of the
        result is 1 and values are unbounded.
        """
        x = self.activation(self.f1(x))
        x = self.activation(self.f2(x))
        x = self.activation(self.f3(x))
        x = self.activation(self.f4(x))
        return self.f5(x)


In [4]:
# Load the diabetes regression data unscaled; keep the raw arrays under their
# own names so this cell is idempotent and the original values stay available.
X_raw, y_raw = load_diabetes(return_X_y=True, scaled=False)

# Min-max scale the features column-wise to [0, 1].
X_scaler = MinMaxScaler()
X = X_scaler.fit_transform(X_raw)

# MinMaxScaler needs 2-D input, so the 1-D target becomes a column; this also
# gives y a trailing dimension of 1, matching Net's single output unit.
y_scaler = MinMaxScaler()
y = y_scaler.fit_transform(y_raw.reshape(-1, 1))

# torch.Tensor(...) is the legacy constructor; build float32 tensors explicitly.
torch_data = TensorDataset(
    torch.as_tensor(X, dtype=torch.float32, device=device),
    torch.as_tensor(y, dtype=torch.float32, device=device),
)

# NOTE(review): this is presumably larger than the dataset (sklearn's diabetes
# set is documented as 442 rows — verify), which makes each epoch a single
# full batch, i.e. the "SGD" run below is effectively full-batch gradient descent.
batch_size = 1000
data_loader = DataLoader(torch_data, batch_size=batch_size)

In [5]:
# Baseline run: train a fresh Net with plain SGD and record the mean loss per epoch.
model = Net(input_size=X.shape[1], device=device)
loss_fn = nn.MSELoss()
opt = optim.SGD(model.parameters(), lr=1e-2)

n_epochs = 100
# Mean training loss per epoch, keyed 1..n_epochs (the printed epoch is 0-based).
all_loss = {}
for epoch in range(n_epochs):
    print('epoch: ', epoch, end='')
    epoch_loss = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step()

        # Accumulate a detached copy: adding `loss` itself would chain every
        # batch's autograd graph into the running total and leave
        # requires_grad tensors in all_loss.
        epoch_loss += loss.detach()
    all_loss[epoch + 1] = epoch_loss / len(data_loader)
    print(', loss: {}'.format(all_loss[epoch + 1].item()))

epoch:  0, loss: 0.2699007987976074
epoch:  1, loss: 0.25835636258125305
epoch:  2, loss: 0.24751296639442444
epoch:  3, loss: 0.2373473048210144
epoch:  4, loss: 0.22782915830612183
epoch:  5, loss: 0.2188768982887268
epoch:  6, loss: 0.2104426771402359
epoch:  7, loss: 0.20253190398216248
epoch:  8, loss: 0.1951756775379181
epoch:  9, loss: 0.18834707140922546
epoch:  10, loss: 0.18198102712631226
epoch:  11, loss: 0.1759854108095169
epoch:  12, loss: 0.17031608521938324
epoch:  13, loss: 0.16496112942695618
epoch:  14, loss: 0.15989555418491364
epoch:  15, loss: 0.15509456396102905
epoch:  16, loss: 0.1505279690027237
epoch:  17, loss: 0.14618365466594696
epoch:  18, loss: 0.14205381274223328
epoch:  19, loss: 0.13812384009361267
epoch:  20, loss: 0.13438251614570618
epoch:  21, loss: 0.13081848621368408
epoch:  22, loss: 0.12742391228675842
epoch:  23, loss: 0.1241883635520935
epoch:  24, loss: 0.12110395729541779
epoch:  25, loss: 0.11816363036632538
epoch:  26, loss: 0.1153622865

In [6]:
# Same setup as the SGD run, but optimised with Adam for comparison.
model = Net(input_size=X.shape[1], device=device)
loss_fn = nn.MSELoss()
opt = optim.Adam(model.parameters(), lr=1e-2)

n_epochs = 100
# Mean training loss per epoch, keyed 1..n_epochs (the printed epoch is 0-based).
all_loss = {}
for epoch in range(n_epochs):
    print('epoch: ', epoch, end='')
    epoch_loss = 0
    for b_x, b_y in data_loader:
        pre = model(b_x)
        loss = loss_fn(pre, b_y)
        opt.zero_grad()
        loss.backward()

        # parameter update step based on optimizer
        opt.step()

        # Accumulate a detached copy: adding `loss` itself would chain every
        # batch's autograd graph into the running total and leave
        # requires_grad tensors in all_loss.
        epoch_loss += loss.detach()
    all_loss[epoch + 1] = epoch_loss / len(data_loader)
    print(', loss: {}'.format(all_loss[epoch + 1].item()))

epoch:  0, loss: 0.13254565000534058
epoch:  1, loss: 0.11314540356397629
epoch:  2, loss: 0.09315424412488937
epoch:  3, loss: 0.07415223121643066
epoch:  4, loss: 0.05929703637957573
epoch:  5, loss: 0.055472757667303085
epoch:  6, loss: 0.06621158123016357
epoch:  7, loss: 0.07040335983037949
epoch:  8, loss: 0.06514700502157211
epoch:  9, loss: 0.058726005256175995
epoch:  10, loss: 0.05507079139351845
epoch:  11, loss: 0.05429603531956673
epoch:  12, loss: 0.05515581741929054
epoch:  13, loss: 0.05646618828177452
epoch:  14, loss: 0.05750930309295654
epoch:  15, loss: 0.05796803906559944
epoch:  16, loss: 0.057780053466558456
epoch:  17, loss: 0.05703849717974663
epoch:  18, loss: 0.05593316629528999
epoch:  19, loss: 0.05470661073923111
epoch:  20, loss: 0.05362262204289436
epoch:  21, loss: 0.052906572818756104
epoch:  22, loss: 0.052679792046546936
epoch:  23, loss: 0.05288742482662201
epoch:  24, loss: 0.05328083783388138
epoch:  25, loss: 0.053513988852500916
epoch:  26, loss