In [1]:
!pip install muax==0.0.2.7.1b0

[0mCollecting muax==0.0.2.7.1b0
  Downloading muax-0.0.2.7.1b0-py3-none-any.whl (20 kB)
[0mInstalling collected packages: muax
  Attempting uninstall: muax
[0m    Found existing installation: muax 0.0.2.7.1a0
    Uninstalling muax-0.0.2.7.1a0:
      Successfully uninstalled muax-0.0.2.7.1a0
[0mSuccessfully installed muax-0.0.2.7.1b0
[0m

In [2]:
import jax 
from jax import numpy as jnp
jax.config.update('jax_platform_name', 'cpu')

import muax
from muax import nn 

In [3]:
import haiku as hk

class Representation(hk.Module):
  def __init__(self, embedding_dim, name='representation'):
    super().__init__(name=name)

    self.repr_func = hk.Sequential([
        hk.Linear(embedding_dim)
    ])

  def __call__(self, obs):
    s = self.repr_func(obs)
    return s 


class Prediction(hk.Module):
  def __init__(self, num_actions, full_support_size, name='prediction'):
    super().__init__(name=name)        
    
    self.v_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    self.pi_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(num_actions)
    ])
  
  def __call__(self, s):
    v = self.v_func(s)
    logits = self.pi_func(s)
    logits = jax.nn.softmax(logits, axis=-1)
    return v, logits


class Dynamic(hk.Module):
  def __init__(self, embedding_dim, num_actions, full_support_size, name='dynamic'):
    super().__init__(name=name)
    
    self.ns_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(embedding_dim)
    ])
    self.r_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    self.cat_func = jax.jit(lambda s, a: 
                            jnp.concatenate([s, jax.nn.one_hot(a, num_actions)],
                                            axis=1)
                            )
  
  def __call__(self, s, a):
    sa = self.cat_func(s, a)
    r = self.r_func(sa)
    ns = self.ns_func(sa)
    return r, ns


def init_representation_func(representation_module, embedding_dim):
    def representation_func(obs):
      repr_model = representation_module(embedding_dim)
      return repr_model(obs)
    return representation_func
  
def init_prediction_func(prediction_module, num_actions, full_support_size):
  def prediction_func(s):
    pred_model = prediction_module(num_actions, full_support_size)
    return pred_model(s)
  return prediction_func

def init_dynamic_func(dynamic_module, embedding_dim, num_actions, full_support_size):
  def dynamic_func(s, a):
    dy_model = dynamic_module(embedding_dim, num_actions, full_support_size)
    return dy_model(s, a)
  return dynamic_func 

In [4]:
support_size = 10 
embedding_size = 10
discount = 0.999
num_actions = 4
full_support_size = int(support_size * 2 + 1)

repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, num_actions, full_support_size)
dy_fn = init_dynamic_func(Dynamic, embedding_size, num_actions, full_support_size)

tracer = muax.PNStep(50, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)

gradient_transform = muax.model.optimizer(init_value=0.005, peak_value=0.005, end_value=0.005, warmup_steps=20000, transition_steps=20000)

In [5]:
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'LunarLander-v2', 
                    max_episodes=1000,
                    max_training_steps=40000,
                    tracer=tracer,
                    buffer=buffer,
                    k_steps=10,
                    sample_per_trajectory=1,
                    num_trajectory=32,
                    tensorboard_dir='/home/fangbowen/tensorboard/LunarLander',
                    model_save_path='/home/fangbowen/models/LunarLander',
                    save_name='lunarlander_model_params',
                    random_seed=0,
                    log_all_metrics=True)


buffer warm up stage...
start training...


INFO:TrainMonitor:ep: 1,	T: 58,	G: -186,	avg_r: -3.26,	avg_G: -186,	t: 57,	dt: 157.692ms,	v: -2.31,	Rn: -131,	loss: 3.97,	training_step: 56,	test_G: -595
INFO:TrainMonitor:ep: 2,	T: 121,	G: -104,	avg_r: -1.68,	avg_G: -145,	t: 62,	dt: 23.427ms,	v: -65.1,	Rn: -91.8,	loss: 3.73,	training_step: 117
INFO:TrainMonitor:ep: 3,	T: 180,	G: -176,	avg_r: -3.04,	avg_G: -155,	t: 58,	dt: 23.121ms,	v: -65.5,	Rn: -143,	loss: 3.66,	training_step: 174
INFO:TrainMonitor:ep: 4,	T: 241,	G: -200,	avg_r: -3.34,	avg_G: -167,	t: 60,	dt: 23.328ms,	v: -79.7,	Rn: -160,	loss: 3.62,	training_step: 233
INFO:TrainMonitor:ep: 5,	T: 297,	G: -136,	avg_r: -2.47,	avg_G: -161,	t: 55,	dt: 23.004ms,	v: -72.1,	Rn: -114,	loss: 3.52,	training_step: 287
INFO:TrainMonitor:ep: 6,	T: 362,	G: -38.9,	avg_r: -0.608,	avg_G: -140,	t: 64,	dt: 23.243ms,	v: -87.6,	Rn: -54.5,	loss: 3.54,	training_step: 350
INFO:TrainMonitor:ep: 7,	T: 437,	G: -88.6,	avg_r: -1.2,	avg_G: -133,	t: 74,	dt: 23.054ms,	v: -60.7,	Rn: -86.9,	loss: 3.48,	training_step:

In [6]:
model_path

'/home/fangbowen/models/LunarLander/epoch_0070_test_G_65.80075034/lunarlander_model_params'

In [15]:
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model.load(model_path)

In [16]:
import gymnasium as gym 
from muax.test import test
env_id = 'LunarLander-v2'
test_env = gym.make(env_id, render_mode='rgb_array')
test_key = jax.random.PRNGKey(0)
test(model, test_env, test_key, num_simulations=50, num_test_episodes=100, random_seed=None)

71.58757116828816