In [1]:
import numpy as np
import pandas as pd
import jax
from jax import random
import jax.numpy as jnp
from network import *
from train import *
from dataset import *
from loss import *
from utils import *
from flax.training.train_state import TrainState
import torch.utils.data as data
from sklearn.model_selection import train_test_split

In [2]:
# Experiment hyperparameters. hp.to_str() is passed to the trainer as the run
# name, so the attribute assignment order below is deliberately left unchanged.
hp = Hyperparam()
hp.layers = [2, 10, 10, 1]  # layer widths; presumably [input dim, hidden..., output dim] (dummy input below has width 2)
hp.lr = 0.001               # Adam learning rate (used by optax.adam in model setup)
hp.batch_size = 128         # minibatch size for both the train and validation loaders

In [3]:
# data
# Load the circle dataset: the features are the "x" and "y" columns and the
# regression target is column "d".
df = pd.read_csv("training_data/circle.csv")
dataset = NumpyDataset(df[["x", "y"]].to_numpy(), df["d"].to_numpy())
# 90/10 train/validation split. A fixed random_state (matching the PRNGKey(0)
# seed used for model init) makes the split reproducible across kernel
# restarts; without it the validation set changed on every run.
# NOTE(review): sklearn indexes `dataset` element-wise here, so each split is
# presumably a plain list of samples — confirm NumpyDataset supports len() and
# integer indexing.
train_dataset, val_dataset = train_test_split(
    dataset, train_size=0.9, shuffle=True, random_state=0)
# NOTE(review): batch order from shuffle=True still depends on torch's global
# RNG, which is not seeded in this notebook.
train_loader = data.DataLoader(
    train_dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=numpy_collate)
val_loader = data.DataLoader(
    val_dataset, batch_size=hp.batch_size, collate_fn=numpy_collate)

In [4]:
model = MLP(hp.layers)

# Split one seed into independent keys: key1 for the dummy input, key2 for init.
key1, key2 = random.split(random.PRNGKey(0))
# Dummy input used only so init can infer parameter shapes. Its width is taken
# from the config instead of a hard-coded 2, so changing hp.layers[0] keeps the
# two in sync. Assumes MLP treats hp.layers[0] as the input width, which matches
# the two feature columns (x, y) — TODO confirm against MLP.
x = random.normal(key1, (hp.layers[0],))
params = model.init(key2, x)  # Initialization call
# NOTE(review): `optax` is not imported explicitly in this notebook; it
# presumably comes from one of the `from ... import *` lines — confirm.
tx = optax.adam(learning_rate=hp.lr)
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)

In [5]:
# Train with the L2 loss for a fixed number of epochs. The run is tagged with
# hp.to_str(); the checkpoint directory loaded later appears to be named after
# it (presumably written by `trainer` — confirm).
num_epochs = 1000
run_name = hp.to_str()
trained_state = trainer(state, train_loader, val_loader, l2_loss_fn,
                        num_epochs=num_epochs, exp_str=run_name)

TRAIN: EPOCH 1/1000 | BATCH 0/71 | LOSS: 0.4082540273666382
TRAIN: EPOCH 1/1000 | BATCH 1/71 | LOSS: 0.3994673639535904
TRAIN: EPOCH 1/1000 | BATCH 2/71 | LOSS: 0.36648546655972797
TRAIN: EPOCH 1/1000 | BATCH 3/71 | LOSS: 0.37646012753248215
TRAIN: EPOCH 1/1000 | BATCH 4/71 | LOSS: 0.384040492773056
TRAIN: EPOCH 1/1000 | BATCH 5/71 | LOSS: 0.3826052596171697
TRAIN: EPOCH 1/1000 | BATCH 6/71 | LOSS: 0.3794936750616346
TRAIN: EPOCH 1/1000 | BATCH 7/71 | LOSS: 0.3876274563372135
TRAIN: EPOCH 1/1000 | BATCH 8/71 | LOSS: 0.3839373257425096
TRAIN: EPOCH 1/1000 | BATCH 9/71 | LOSS: 0.3728368878364563
TRAIN: EPOCH 1/1000 | BATCH 10/71 | LOSS: 0.36880364743146027
TRAIN: EPOCH 1/1000 | BATCH 11/71 | LOSS: 0.3634077807267507
TRAIN: EPOCH 1/1000 | BATCH 12/71 | LOSS: 0.35895123848548305
TRAIN: EPOCH 1/1000 | BATCH 13/71 | LOSS: 0.35261740003313335
TRAIN: EPOCH 1/1000 | BATCH 14/71 | LOSS: 0.3463959634304047
TRAIN: EPOCH 1/1000 | BATCH 15/71 | LOSS: 0.3444713409990072
TRAIN: EPOCH 1/1000 | BATCH 16

In [16]:
# Show the shape of every parameter leaf in the initialized params tree.
# jax.tree_map is deprecated (and removed in recent JAX releases);
# jax.tree_util.tree_map is the stable equivalent.
# NOTE(review): the saved output for this cell shows a 12-unit first layer,
# which does not match hp.layers = [2, 10, 10, 1]. It was produced by an earlier
# configuration and should be refreshed with Restart & Run All.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

{'params': {'layers_0': {'bias': (12,), 'kernel': (2, 12)},
  'layers_1': {'bias': (10,), 'kernel': (12, 10)},
  'layers_2': {'bias': (1,), 'kernel': (10, 1)}}}

In [8]:
# Load the last checkpoint and sanity-check it with one forward pass on the
# dummy input `x` from the model-setup cell.
# The directory name mirrors hp.to_str() for the current config; it is
# presumably the one `trainer` wrote to — confirm if hp changes.
ckpt_dir = "checkpoint/layers:2_10_10_1,lr:0.001,batch_size:128"
params = load(f"{ckpt_dir}/990/default")  # epoch-990 snapshot
bind_model = model.bind(params)
bind_model(x)

Array([-0.08845351], dtype=float32)