In [1]:
import numpy as np
import pandas as pd
import jax
from jax import random
import jax.numpy as jnp
import optax
import torch
from network import *
from train import *
from dataset import *
from loss import *
from utils import *
# NOTE: these explicit imports stay after the star imports so that they take
# precedence over any same-named objects the star imports might define.
from flax.training.train_state import TrainState
import torch.utils.data as data
from sklearn.model_selection import train_test_split

In [2]:
# Run configuration: start from a fresh Hyperparam() and override the fields
# below, in this order (the dict preserves insertion order).
hyperparams = Hyperparam()
for _field, _setting in {
    "layers": [2, 10, 10, 1],
    "lr": 0.001,
    "batch_size": 128,
}.items():
    setattr(hyperparams, _field, _setting)

In [3]:
# data
# Load the table: columns "x", "y" are the inputs, column "d" is the target.
df = pd.read_csv("circle.csv")
dataset = NumpyDataset(df[["x", "y"]].to_numpy(), df["d"].to_numpy())

# 90/10 train/validation split. `random_state` pins the partition so that
# Restart & Run All reproduces the same train and validation sets.
train_dataset, val_dataset = train_test_split(
    dataset, train_size=0.9, shuffle=True, random_state=0)

# A seeded generator makes the per-epoch shuffle order reproducible as well;
# without it DataLoader draws from torch's unseeded global RNG.
train_loader = data.DataLoader(
    train_dataset,
    batch_size=hyperparams.batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
    generator=torch.Generator().manual_seed(0),
)
val_loader = data.DataLoader(
    val_dataset, batch_size=hyperparams.batch_size, collate_fn=numpy_collate)

In [4]:
model = MLP(hyperparams.layers)

# Split a fixed root key into two subkeys: the first draws a dummy input,
# the second seeds parameter initialisation.
key1, key2 = random.split(random.PRNGKey(0), num=2)
# Dummy input used to trace the model during init.
# NOTE(review): the width 2 is hard-coded and presumably mirrors the
# ("x", "y") feature columns loaded above -- confirm it stays in sync.
x = random.normal(key1, shape=(2,))
params = model.init(key2, x)

tx = optax.adam(hyperparams.lr)
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

In [6]:
# NOTE(review): the execution count jumps from In [4] to In [6], so a cell was
# run and later removed; confirm this cell still works after Restart & Run All.
# The output shows one log line per batch (71 batches x 1000 epochs); consider
# capturing or reducing that output before sharing the notebook.
trained_state = trainer(
    state,
    train_loader,
    val_loader,
    l2_loss_fn,
    num_epochs=1000,
    exp_str=hyperparams.to_str(),  # presumably labels the run by its settings
)

TRAIN: EPOCH 1/1000 | BATCH 0/71 | LOSS: 0.3886238932609558
TRAIN: EPOCH 1/1000 | BATCH 1/71 | LOSS: 0.39779481291770935
TRAIN: EPOCH 1/1000 | BATCH 2/71 | LOSS: 0.4026128053665161
TRAIN: EPOCH 1/1000 | BATCH 3/71 | LOSS: 0.39013029634952545
TRAIN: EPOCH 1/1000 | BATCH 4/71 | LOSS: 0.3844456195831299
TRAIN: EPOCH 1/1000 | BATCH 5/71 | LOSS: 0.3762319087982178
TRAIN: EPOCH 1/1000 | BATCH 6/71 | LOSS: 0.37614625692367554
TRAIN: EPOCH 1/1000 | BATCH 7/71 | LOSS: 0.36857252195477486
TRAIN: EPOCH 1/1000 | BATCH 8/71 | LOSS: 0.3647873236073388
TRAIN: EPOCH 1/1000 | BATCH 9/71 | LOSS: 0.3602416843175888
TRAIN: EPOCH 1/1000 | BATCH 10/71 | LOSS: 0.36205787279389123
TRAIN: EPOCH 1/1000 | BATCH 11/71 | LOSS: 0.35807065417369205
TRAIN: EPOCH 1/1000 | BATCH 12/71 | LOSS: 0.3524192273616791
TRAIN: EPOCH 1/1000 | BATCH 13/71 | LOSS: 0.3483713056359972
TRAIN: EPOCH 1/1000 | BATCH 14/71 | LOSS: 0.34549347559611004
TRAIN: EPOCH 1/1000 | BATCH 15/71 | LOSS: 0.3419052269309759
TRAIN: EPOCH 1/1000 | BATCH