In [7]:
!pip install muax==0.0.2.7

[0mCollecting muax==0.0.2.7
  Downloading muax-0.0.2.7-py3-none-any.whl (20 kB)
[0mInstalling collected packages: muax
  Attempting uninstall: muax
[0m    Found existing installation: muax 0.0.2.6
    Uninstalling muax-0.0.2.6:
      Successfully uninstalled muax-0.0.2.6
[0mSuccessfully installed muax-0.0.2.7
[0m

In [1]:
import jax 
# Pin JAX to the CPU backend. The platform choice only takes effect if it is made before
# JAX initialises a backend, so this stays ahead of `import muax` in case that import touches JAX.
jax.config.update('jax_platform_name', 'cpu')

import muax
# `nn` supplies the example Representation / Prediction / Dynamic modules used below.
from muax import nn 

# 1. Use `muax.fit` to train a MuZero agent on CartPole-v1

`muax` provides example `representation`, `prediction`, and `dynamic` network modules. The next cell builds the corresponding functions that are passed to `muax.MuZero`.

In [2]:
# Hyper-parameters shared by the networks, the tracer and the MuZero model.
support_size = 10
embedding_size = 8
discount = 0.99
# CartPole-v1 has two discrete actions (push cart left / push cart right).
num_actions = 2
# Number of integer points in [-support_size, support_size].
full_support_size = int(2 * support_size + 1)

repr_fn = nn._init_representation_func(nn.Representation, embedding_size)
pred_fn = nn._init_prediction_func(nn.Prediction, num_actions, full_support_size)
dy_fn = nn._init_dynamic_func(nn.Dynamic, embedding_size, num_actions, full_support_size)

`muax` provides an episode tracer and a replay buffer to track and store the trajectories collected while interacting with the environment.

In [13]:
# Replay buffer for collected trajectories; 500 is presumably its capacity — confirm in muax docs.
buffer = muax.TrajectoryReplayBuffer(500)
# Episode tracer built from n=10, the shared `discount`, and 0.5
# (NOTE(review): presumably the prioritisation exponent of PNStep — confirm).
tracer = muax.PNStep(10, discount, 0.5)

`muax` uses `optax` gradient transformations to update the model weights.

In [14]:
# Learning-rate schedule for the optax-based optimizer built by muax.
# init_value == peak_value, so the warm-up phase is flat at 0.02; the rate then
# presumably decays towards end_value (see muax.model.optimizer for the exact schedule).
# For a constant-rate alternative, add `import optax` and use `optax.adam(0.02)` instead.
gradient_transform = muax.model.optimizer(
    init_value=0.02,
    peak_value=0.02,
    end_value=0.002,
    warmup_steps=5000,
    transition_steps=5000,
)

In [15]:
# Build a fresh MuZero agent from the network functions and optimizer defined above.
model = muax.MuZero(
    repr_fn,
    pred_fn,
    dy_fn,
    policy='muzero',
    discount=discount,
    optimizer=gradient_transform,
    support_size=support_size,
)

# NOTE(review): `tracer` and `buffer` are created in an earlier cell and are presumably
# mutated during training — re-run that cell before re-running this one for a clean run.
model = muax.fit(
    model,
    'CartPole-v1',
    # training budget
    max_episodes=1000,
    max_training_steps=10000,
    # experience collection / replay sampling (see muax.fit for exact semantics)
    tracer=tracer,
    buffer=buffer,
    k_steps=10,
    sample_per_trajectory=1,
    num_trajectory=32,
    # reproducibility
    random_seed=0,
    # logging and checkpointing
    tensorboard_dir='data/tensorboard/',
    save_path='cartpole_model_params',
    log_all_metrics=True,
)

buffer warm up stage...
start training...


INFO:TrainMonitor:ep: 1,	T: 9,	G: 8,	avg_r: 1,	avg_G: 8,	t: 8,	dt: 2204.659ms,	v: 0.0416,	Rn: 4.4,	loss: 4.46,	training_step: 7,	test_G: 9
INFO:TrainMonitor:ep: 2,	T: 18,	G: 8,	avg_r: 1,	avg_G: 8,	t: 8,	dt: 820.472ms,	v: 5.68,	Rn: 4.4,	loss: 2.2,	training_step: 14
INFO:TrainMonitor:ep: 3,	T: 54,	G: 35,	avg_r: 1,	avg_G: 17,	t: 35,	dt: 555.296ms,	v: 6.94,	Rn: 12.9,	loss: 1.87,	training_step: 48
INFO:TrainMonitor:ep: 4,	T: 65,	G: 10,	avg_r: 1,	avg_G: 15.2,	t: 10,	dt: 14.027ms,	v: 8.31,	Rn: 5.34,	loss: 1.69,	training_step: 57
INFO:TrainMonitor:ep: 5,	T: 80,	G: 14,	avg_r: 1,	avg_G: 15,	t: 14,	dt: 13.860ms,	v: 8.36,	Rn: 8.46,	loss: 1.64,	training_step: 70
INFO:TrainMonitor:ep: 6,	T: 89,	G: 8,	avg_r: 1,	avg_G: 13.8,	t: 8,	dt: 13.392ms,	v: 8.99,	Rn: 4.4,	loss: 1.65,	training_step: 77
INFO:TrainMonitor:ep: 7,	T: 100,	G: 10,	avg_r: 1,	avg_G: 13.3,	t: 10,	dt: 13.598ms,	v: 8.66,	Rn: 5.34,	loss: 1.65,	training_step: 86
INFO:TrainMonitor:ep: 8,	T: 111,	G: 10,	avg_r: 1,	avg_G: 12.9,	t: 10,	dt: 14.248

# 2. Customize the training loop