In [1]:
import numpy as np
import pandas as pd
import jax
from jax import random
import jax.numpy as jnp
from network import *
from train import *
from dataset import *
from loss import *
from flax.training.train_state import TrainState
import torch.utils.data as data
from sklearn.model_selection import train_test_split

In [2]:
# Hyperparameters for this run; `hp` is read by the data, model and training cells.
hp = Hyperparam()
# NOTE(review): presumably the MLP layer widths (2 inputs -> 10 -> 10 -> 1 output)
# -- confirm against get_mlp.
hp.dims = [2, 10, 10, 1]
# Learning rate handed to the Adam optimizer when the train state is created.
hp.lr = 0.001
# Mini-batch size used by both the training and the validation DataLoader.
hp.batch_size = 128

In [3]:
# data
# Read the samples: columns "x" and "y" form the first array given to
# NumpyDataset and column "d" the second (presumably inputs and regression
# target -- confirm against NumpyDataset).
df = pd.read_csv("training_data/circle.csv")
dataset = NumpyDataset(df[["x", "y"]].to_numpy(), df["d"].to_numpy())

# 90/10 train/validation split. random_state pins the shuffle so the partition
# is identical on every Restart & Run All; without it each run trains and
# validates on different samples and the reported losses are not comparable.
train_dataset, val_dataset = train_test_split(
    dataset, train_size=0.9, shuffle=True, random_state=0)

# Training batches are reshuffled by the loader each epoch; validation keeps
# the loader's default (unshuffled) order.
train_loader = data.DataLoader(
    train_dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=numpy_collate)
val_loader = data.DataLoader(
    val_dataset, batch_size=hp.batch_size, collate_fn=numpy_collate)

In [4]:
# Build the network described by the hyperparameters.
model = get_mlp(hp)

# Derive two independent keys from a fixed root seed: one for the dummy
# input, one for parameter initialization.
root_key = random.PRNGKey(0)
key1, key2 = random.split(root_key)

# A throwaway 2-element input; it is only passed to model.init.
x = random.normal(key1, shape=(2,))
params = model.init(key2, x)

# NOTE(review): optax is not imported explicitly in this notebook; presumably
# it is re-exported by one of the star imports -- confirm.
tx = optax.adam(learning_rate=hp.lr)

# Bundle apply function, parameters and optimizer into a Flax TrainState.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

In [6]:
# Run the training loop for 100 epochs with the L2 loss; the experiment
# string is derived from the hyperparameters.
trained_state = trainer(
    state,
    train_loader,
    val_loader,
    l2_loss_fn,
    num_epochs=100,
    exp_str=hp.as_str(),
)

TRAIN: EPOCH 1/100 | BATCH 0/71 | LOSS: 0.38530755043029785
TRAIN: EPOCH 1/100 | BATCH 1/71 | LOSS: 0.4003516882658005
TRAIN: EPOCH 1/100 | BATCH 2/71 | LOSS: 0.39103670914967853
TRAIN: EPOCH 1/100 | BATCH 3/71 | LOSS: 0.3876916915178299
TRAIN: EPOCH 1/100 | BATCH 4/71 | LOSS: 0.3773049354553223
TRAIN: EPOCH 1/100 | BATCH 5/71 | LOSS: 0.3731925090154012
TRAIN: EPOCH 1/100 | BATCH 6/71 | LOSS: 0.37016725540161133
TRAIN: EPOCH 1/100 | BATCH 7/71 | LOSS: 0.3682968057692051
TRAIN: EPOCH 1/100 | BATCH 8/71 | LOSS: 0.36428865128093296
TRAIN: EPOCH 1/100 | BATCH 9/71 | LOSS: 0.356124347448349
TRAIN: EPOCH 1/100 | BATCH 10/71 | LOSS: 0.3581731075590307
TRAIN: EPOCH 1/100 | BATCH 11/71 | LOSS: 0.3545665070414543
TRAIN: EPOCH 1/100 | BATCH 12/71 | LOSS: 0.3523436693044809
TRAIN: EPOCH 1/100 | BATCH 13/71 | LOSS: 0.3465723821095058
TRAIN: EPOCH 1/100 | BATCH 14/71 | LOSS: 0.34332483212153114
TRAIN: EPOCH 1/100 | BATCH 15/71 | LOSS: 0.33864808082580566
TRAIN: EPOCH 1/100 | BATCH 16/71 | LOSS: 0.33

In [8]:
# NOTE(review): presumably persists the trained state together with its
# hyperparameters under "model" (the following cell loads from "./model"),
# with force=True overwriting an existing save -- confirm against save().
save("model", trained_state, hp, force=True)

In [9]:
# Rebuild the model function from what is stored under ./model (the last
# checkpoint) and evaluate it at the 2-element zero vector as a smoke test.
sdf_fn = get_mlp_by_path("./model")
origin = jnp.zeros((2,))
sdf_fn(origin)