kh11kim/mlp_jax_starting

Simple Learning Pipeline using Jax

This repository contains code for a basic pipeline that trains a multilayer perceptron (MLP) with JAX, a numerical computing library for machine learning research.

Usage

To use the code in this repository, follow these steps:

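  1. Set hyperparameters:
```python
# Hyperparam is a configuration class provided by this repository.
hp = Hyperparam()
hp.dims = [2, 10, 10, 1]  # layer widths: 2 inputs, two hidden layers of 10 units, 1 output
hp.lr = 0.001             # learning rate passed to optax.adam
hp.batch_size = 128
```

  2. Load data:
```python
import pandas as pd
from torch.utils import data

# NumpyDataset, numpy_collate and train_test_split come from this repository or its dependencies.
df = pd.read_csv("training_data/circle.csv")
# Inputs are the "x" and "y" columns; the target is the "d" column.
dataset = NumpyDataset(df[["x", "y"]].to_numpy(), df["d"].to_numpy())
# Keep 90% of the samples for training and 10% for validation.
train_dataset, val_dataset = train_test_split(dataset, train_size=0.9, shuffle=True)

# numpy_collate batches samples as NumPy arrays rather than torch tensors.
train_loader = data.DataLoader(
    train_dataset, batch_size=hp.batch_size, shuffle=True, collate_fn=numpy_collate)
val_loader = data.DataLoader(
    val_dataset, batch_size=hp.batch_size, collate_fn=numpy_collate)
```

  3. Create the model and initialize its parameters:
```python
from jax import random

model = get_mlp(hp)  # repository helper that builds the MLP from hp
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (2,))  # dummy input with the shape of one sample
params = model.init(key2, x)   # initialize parameters from the dummy input
```

  4. Train the model and save a checkpoint:
```python
import optax
from flax.training.train_state import TrainState

tx = optax.adam(learning_rate=hp.lr)
# TrainState bundles the apply function, the parameters and the optimizer state.
state = TrainState.create(apply_fn=model.apply, params=params, tx=tx)
# Train for 100 epochs with the L2 loss; trainer and l2_loss_fn come from this repository.
trained_state = trainer(
    state, train_loader, val_loader, l2_loss_fn,
    num_epochs=100, exp_str=hp.as_str())

# Save the trained state and hyperparameters to the "model" directory.
save("model", trained_state, hp, force=True)
```

  5. Load the last checkpoint and use the model:
```python
import jax.numpy as jnp
sdf_fn = get_mlp_by_path("./model")  # rebuild the trained model saved in ./model
sdf_fn(jnp.zeros(2))                 # evaluate it at the origin (0, 0)
```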
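The hyperparameter step only shows that `Hyperparam` has the fields `dims`, `lr` and `batch_size` and an `as_str()` method (used as `exp_str` when training). The repository's class is not reproduced here; the sketch below is a hypothetical stand-in, assuming it is a plain mutable container and that `as_str()` just builds a readable run name. The string format is an illustrative assumption.

```python
from dataclasses import dataclass, field


@dataclass
class Hyperparam:
    """Hypothetical stand-in for the repository's Hyperparam class."""
    dims: list = field(default_factory=lambda: [2, 10, 10, 1])  # layer widths
    lr: float = 1e-3
    batch_size: int = 128

    def as_str(self) -> str:
        # Hypothetical run-name format; the repository's actual format may differ.
        dims = "-".join(str(d) for d in self.dims)
        return f"dims{dims}_lr{self.lr}_bs{self.batch_size}"


hp = Hyperparam()
hp.dims = [2, 10, 10, 1]
hp.lr = 0.001
hp.batch_size = 128
print(hp.as_str())  # dims2-10-10-1_lr0.001_bs128
```

A dataclass keeps the attribute-assignment style of the usage steps while still giving every field a default.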
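The data-loading step relies on two repository helpers, `NumpyDataset` and `numpy_collate`, whose code is not shown here. The sketch below assumes `NumpyDataset` is a map-style dataset (supporting `len()` and integer indexing, which is what `torch.utils.data.DataLoader` needs) and that `numpy_collate` stacks a list of `(x, y)` samples into NumPy arrays. To stay dependency-free it forms one batch by hand instead of using `DataLoader`, and the random toy data stands in for `circle.csv`.

```python
import numpy as np


class NumpyDataset:
    """Hypothetical map-style dataset over paired feature/target arrays."""

    def __init__(self, x: np.ndarray, y: np.ndarray):
        assert len(x) == len(y), "features and targets must have the same length"
        self.x = x
        self.y = y

    def __len__(self) -> int:
        return len(self.x)

    def __getitem__(self, idx):
        # One sample: a feature vector and its scalar target.
        return self.x[idx], self.y[idx]


def numpy_collate(batch):
    """Recursively stack a list of samples into NumPy arrays."""
    if isinstance(batch[0], np.ndarray):
        return np.stack(batch)
    if isinstance(batch[0], (tuple, list)):
        # Transpose [(x0, y0), (x1, y1), ...] into ([x0, x1, ...], [y0, y1, ...]).
        return [numpy_collate(list(samples)) for samples in zip(*batch)]
    return np.array(batch)


# Toy data standing in for circle.csv: 2-D points and one scalar target per point
# (here the signed distance to the unit circle, chosen only for illustration).
rng = np.random.default_rng(0)
points = rng.normal(size=(10, 2))
targets = np.linalg.norm(points, axis=1) - 1.0
dataset = NumpyDataset(points, targets)

# Form one batch of 4 samples by hand, as DataLoader would with batch_size=4.
xb, yb = numpy_collate([dataset[i] for i in range(4)])
print(xb.shape, yb.shape)  # (4, 2) (4,)
```

Returning NumPy arrays instead of torch tensors lets each batch go straight into jitted JAX functions.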
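The model step calls `model.init(key2, x)` and later passes `model.apply` to `TrainState.create`, which suggests `get_mlp` returns a Flax-style module. As a rough illustration of what such a model computes, here is a small MLP in plain JAX that mimics the `init`/`apply` interface. The class name `SimpleMLP`, the tanh activation and the weight initialization are assumptions, not the repository's choices.

```python
import jax.numpy as jnp
from jax import random


class SimpleMLP:
    """Hypothetical MLP mimicking a Flax-style init/apply interface."""

    def __init__(self, dims):
        self.dims = dims  # layer widths, e.g. [2, 10, 10, 1]

    def init(self, key, x):
        # The dummy input only fixes the expected input width.
        assert x.shape[-1] == self.dims[0], "input width must match dims[0]"
        keys = random.split(key, len(self.dims) - 1)
        layers = []
        for k, n_in, n_out in zip(keys, self.dims[:-1], self.dims[1:]):
            w = random.normal(k, (n_in, n_out)) * jnp.sqrt(1.0 / n_in)
            layers.append({"w": w, "b": jnp.zeros(n_out)})
        # Wrap the weights in a "params" collection, as Flax's init does.
        return {"params": layers}

    def apply(self, variables, x):
        layers = variables["params"]
        for layer in layers[:-1]:
            x = jnp.tanh(x @ layer["w"] + layer["b"])  # hidden layers
        last = layers[-1]
        return x @ last["w"] + last["b"]  # linear output layer


model = SimpleMLP([2, 10, 10, 1])
key1, key2 = random.split(random.PRNGKey(0))
x = random.normal(key1, (2,))
params = model.init(key2, x)
print(model.apply(params, x).shape)  # (1,)
batch = random.normal(key2, (5, 2))
print(model.apply(params, batch).shape)  # (5, 1)
```

Because the layers use matrix products on the last axis, the same `apply` handles a single point or a batch of points.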
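The training step delegates to the repository's `trainer` and `l2_loss_fn`, whose internals are not documented here. One plausible reading is that `l2_loss_fn` is the mean squared error $\frac{1}{N}\sum_i (f(x_i) - d_i)^2$ and that `trainer` applies one gradient update per batch. The sketch below shows such an update in plain JAX. The names `l2_loss` and `train_step`, the full-batch updates and the use of plain gradient descent instead of `optax.adam` are assumptions made to keep it self-contained.

```python
import jax
import jax.numpy as jnp
from jax import random

LEARNING_RATE = 0.05  # hypothetical step size for plain gradient descent


def init_params(key, dims):
    # One (weights, bias) pair per layer, e.g. dims = [2, 10, 10, 1].
    keys = random.split(key, len(dims) - 1)
    return [
        (random.normal(k, (n_in, n_out)) * jnp.sqrt(1.0 / n_in), jnp.zeros(n_out))
        for k, n_in, n_out in zip(keys, dims[:-1], dims[1:])
    ]


def mlp(params, x):
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return (x @ w + b).squeeze(-1)  # one scalar prediction per input row


def l2_loss(params, x, y):
    # Mean squared error between predictions and targets.
    return jnp.mean((mlp(params, x) - y) ** 2)


@jax.jit
def train_step(params, x, y):
    loss, grads = jax.value_and_grad(l2_loss)(params, x, y)
    # Plain gradient descent; the README's pipeline uses optax.adam instead.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return params, loss


# Toy regression data (illustrative only): distance of each point from the unit circle.
key_data, key_init = random.split(random.PRNGKey(0))
x = random.uniform(key_data, (256, 2), minval=-2.0, maxval=2.0)
y = jnp.linalg.norm(x, axis=1) - 1.0

params = init_params(key_init, [2, 10, 10, 1])
initial_loss = float(l2_loss(params, x, y))
for epoch in range(200):
    params, loss = train_step(params, x, y)  # one full-batch step per "epoch"
final_loss = float(l2_loss(params, x, y))
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

`jax.value_and_grad` returns the loss and the gradients in a single pass, and `jax.jit` compiles the whole update once so later steps reuse the compiled function.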
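`save` and `get_mlp_by_path` are repository helpers, and their on-disk format is not described here. Assuming a checkpoint only needs the layer widths and the trained weights, a minimal round trip could store the widths as JSON and the weights as a NumPy `.npz` archive, then rebuild a callable from them. The helper names `save_checkpoint` and `load_mlp_fn`, the file names `hparams.json` and `params.npz`, and the reading of `force=True` as "overwrite an existing checkpoint" are all hypothetical.

```python
import json
import os
import tempfile

import jax.numpy as jnp
import numpy as np
from jax import random


def mlp(params, x):
    # tanh MLP over a list of (weights, bias) pairs.
    for w, b in params[:-1]:
        x = jnp.tanh(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def save_checkpoint(path, params, dims, force=False):
    """Hypothetical checkpoint writer: layer widths as JSON, weights as .npz."""
    if os.path.exists(path) and not force:
        raise FileExistsError(f"{path} already exists; pass force=True to overwrite")
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, "hparams.json"), "w") as f:
        json.dump({"dims": dims}, f)
    arrays = {}
    for i, (w, b) in enumerate(params):
        arrays[f"w{i}"] = np.asarray(w)
        arrays[f"b{i}"] = np.asarray(b)
    np.savez(os.path.join(path, "params.npz"), **arrays)


def load_mlp_fn(path):
    """Rebuild a callable x -> prediction from a checkpoint directory."""
    with open(os.path.join(path, "hparams.json")) as f:
        dims = json.load(f)["dims"]
    with np.load(os.path.join(path, "params.npz")) as archive:
        params = [
            (jnp.asarray(archive[f"w{i}"]), jnp.asarray(archive[f"b{i}"]))
            for i in range(len(dims) - 1)
        ]
    return lambda x: mlp(params, x)


# Round trip: random weights -> disk -> reloaded function.
dims = [2, 10, 10, 1]
keys = random.split(random.PRNGKey(0), len(dims) - 1)
params = [
    (random.normal(k, (n_in, n_out)), jnp.zeros(n_out))
    for k, n_in, n_out in zip(keys, dims[:-1], dims[1:])
]
point = jnp.array([0.5, -0.3])
expected = mlp(params, point)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "model")
    save_checkpoint(ckpt, params, dims, force=True)
    sdf_fn = load_mlp_fn(ckpt)
    out = sdf_fn(point)
print(out.shape)  # (1,)
```

Storing the layer widths next to the weights is what lets the loader rebuild the network from the directory path alone.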

License

This project is licensed under the terms of the MIT license.
