In [10]:
# Data handling
import pandas as pd
import numpy as np
import torch
import torch.nn as nn

# Bokeh libraries
from bokeh.io import output_file, output_notebook 
from bokeh.plotting import figure, show, from_networkx
from bokeh.models import ColumnDataSource, Circle, MultiLine
from bokeh.layouts import row, column, gridplot
from bokeh.models import ColumnDataSource, CustomJS, Slider
from bokeh.colors import Color

import networkx as nx

from datetime import datetime
from pathlib import Path


import nx_utils
from nx_utils import STATE_DICT


output_notebook()  # Render inline in a Jupyter Notebook

In [2]:
# create dataset

def f(data):
    x1 = data[:,[0]]
    x2 = data[:,[1]]
    x3 = data[:,[2]]
    x4 = data[:,[3]]
    out = np.transpose(np.array([(x1+x3)**3, x2**2+np.sin(np.pi*x4)]))
    return out

d_in = 4
d_out = 2

seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

inputs = np.random.rand(100,d_in)*2-1
labels = f(inputs)
inputs = torch.tensor(inputs, dtype=torch.float, requires_grad=True)
labels = torch.tensor(labels, dtype=torch.float, requires_grad=True)

inputs_test = np.random.rand(100,d_in)*2-1
labels_test = f(inputs_test)
inputs_test = torch.tensor(inputs_test, dtype=torch.float, requires_grad=True)
labels_test = torch.tensor(labels_test, dtype=torch.float, requires_grad=True)

inputs.shape, labels.shape

(torch.Size([100, 4]), torch.Size([1, 100, 2]))

In [11]:
epochs = 1000
name = "Number-1"
base = Path("models") / name / datetime.now().strftime("%Y-%m-%d_%H%M%S")
base.mkdir(parents=True, exist_ok=True)

In [12]:

N=30
model = nn.Sequential(
    nn.Linear(d_in, N),
    nn.ReLU(),
    nn.Linear(N,N),
    nn.ReLU(),
    nn.Linear(N,N),
    nn.ReLU(),
    nn.Linear(N,d_out)
)

opt = torch.optim.SGD(model.parameters(), lr=0.01)

for epoch in range(epochs):
    
    opt.zero_grad()
    pred  = model(inputs)
    loss = torch.mean((pred-labels)**2)
    pred_test  = model(inputs_test)
    loss_test = torch.mean((pred_test-labels_test)**2)

    if epoch % 5 == 0:
        torch.save({
                'epoch': epoch,
                STATE_DICT: model.state_dict(),
                'optimizer_state_dict': opt.state_dict(),
                'train_loss': loss.item(),
                'test_loss': loss_test.item(),
                }, base / f"{epoch}.pt")
            

    print(loss.item())
    loss.backward()
    opt.step()

1.129734992980957
1.1285922527313232
1.1274690628051758
1.126362681388855
1.1252715587615967
1.1241933107376099
1.1231313943862915
1.1220858097076416
1.1210554838180542
1.1200425624847412
1.119044303894043
1.1180615425109863
1.1170920133590698
1.1161353588104248
1.115187406539917
1.1142423152923584
1.113308072090149
1.1123836040496826
1.1114736795425415
1.110573410987854
1.1096906661987305
1.1088154315948486
1.1079494953155518
1.1070905923843384
1.1062376499176025
1.1053942441940308
1.1045589447021484
1.1037304401397705
1.1029078960418701
1.1020907163619995
1.1012835502624512
1.1004884243011475
1.0996986627578735
1.0989161729812622
1.098140001296997
1.0973641872406006
1.096591830253601
1.0958212614059448
1.0950560569763184
1.0942943096160889
1.0935355424880981
1.0927799940109253
1.0920275449752808
1.0912779569625854
1.0905311107635498
1.0897868871688843
1.0890436172485352
1.0883022546768188
1.0875624418258667
1.086824893951416
1.0860908031463623
1.0853571891784668
1.0846227407455444
1.