In [1]:
!pip install muax==0.0.2.7.1a0

[0m

In [1]:
import jax 
from jax import numpy as jnp
jax.config.update('jax_platform_name', 'cpu')

import muax
from muax import nn 

In [65]:
import haiku as hk

class Representation(hk.Module):
  def __init__(self, embedding_dim, name='representation'):
    super().__init__(name=name)

    self.repr_func = hk.Sequential([
        hk.Linear(embedding_dim)
    ])

  def __call__(self, obs):
    s = self.repr_func(obs)
    return s 


class Prediction(hk.Module):
  def __init__(self, num_actions, full_support_size, name='prediction'):
    super().__init__(name=name)        
    
    self.v_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    self.pi_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(num_actions)
    ])
  
  def __call__(self, s):
    v = self.v_func(s)
    logits = self.pi_func(s)
    logits = jax.nn.softmax(logits, axis=-1)
    return v, logits


class Dynamic(hk.Module):
  def __init__(self, embedding_dim, num_actions, full_support_size, name='dynamic'):
    super().__init__(name=name)
    
    self.ns_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(embedding_dim)
    ])
    self.r_func = hk.Sequential([
        hk.Linear(64), jax.nn.elu,
        hk.Linear(full_support_size)
    ])
    self.cat_func = jax.jit(lambda s, a: 
                            jnp.concatenate([s, jax.nn.one_hot(a, num_actions)],
                                            axis=1)
                            )
  
  def __call__(self, s, a):
    sa = self.cat_func(s, a)
    r = self.r_func(sa)
    ns = self.ns_func(sa)
    return r, ns


def init_representation_func(representation_module, embedding_dim):
    def representation_func(obs):
      repr_model = representation_module(embedding_dim)
      return repr_model(obs)
    return representation_func
  
def init_prediction_func(prediction_module, num_actions, full_support_size):
  def prediction_func(s):
    pred_model = prediction_module(num_actions, full_support_size)
    return pred_model(s)
  return prediction_func

def init_dynamic_func(dynamic_module, embedding_dim, num_actions, full_support_size):
  def dynamic_func(s, a):
    dy_model = dynamic_module(embedding_dim, num_actions, full_support_size)
    return dy_model(s, a)
  return dynamic_func 

In [66]:
support_size = 10 
embedding_size = 10
discount = 0.999
num_actions = 4
full_support_size = int(support_size * 2 + 1)

repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, num_actions, full_support_size)
dy_fn = init_dynamic_func(Dynamic, embedding_size, num_actions, full_support_size)

tracer = muax.PNStep(50, discount, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)

gradient_transform = muax.model.optimizer(init_value=0.005, peak_value=0.005, end_value=0.005, warmup_steps=20000, transition_steps=20000)

In [67]:
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'LunarLander-v2', 
                    max_episodes=1000,
                    max_training_steps=40000,
                    tracer=tracer,
                    buffer=buffer,
                    k_steps=10,
                    sample_per_trajectory=1,
                    num_trajectory=32,
                    tensorboard_dir='/home/fangbowen/tensorboard/LunarLander',
                    model_save_path='/home/fangbowen/models/LunarLander',
                    save_name='lunarlander_model_params',
                    random_seed=0,
                    log_all_metrics=True)


buffer warm up stage...
start training...


INFO:TrainMonitor:ep: 1,	T: 59,	G: -20.6,	avg_r: -0.354,	avg_G: -20.6,	t: 58,	dt: 177.084ms,	v: -2.25,	Rn: 2.98,	loss: 3.95,	training_step: 57,	test_G: -861
INFO:TrainMonitor:ep: 2,	T: 123,	G: -140,	avg_r: -2.22,	avg_G: -80.1,	t: 63,	dt: 24.169ms,	v: -86.7,	Rn: -107,	loss: 3.67,	training_step: 119
INFO:TrainMonitor:ep: 3,	T: 197,	G: -187,	avg_r: -2.57,	avg_G: -116,	t: 73,	dt: 23.888ms,	v: -97.3,	Rn: -146,	loss: 3.61,	training_step: 191
INFO:TrainMonitor:ep: 4,	T: 254,	G: -197,	avg_r: -3.52,	avg_G: -136,	t: 56,	dt: 25.107ms,	v: -56.1,	Rn: -152,	loss: 3.62,	training_step: 246
INFO:TrainMonitor:ep: 5,	T: 316,	G: -114,	avg_r: -1.87,	avg_G: -132,	t: 61,	dt: 24.207ms,	v: -67,	Rn: -102,	loss: 3.48,	training_step: 306
INFO:TrainMonitor:ep: 6,	T: 377,	G: -112,	avg_r: -1.87,	avg_G: -128,	t: 60,	dt: 24.053ms,	v: -100,	Rn: -86.6,	loss: 3.45,	training_step: 365
INFO:TrainMonitor:ep: 7,	T: 434,	G: -78.2,	avg_r: -1.4,	avg_G: -121,	t: 56,	dt: 24.292ms,	v: -90.1,	Rn: -69.5,	loss: 3.4,	training_step: 42

In [68]:
model_path

'/home/fangbowen/models/LunarLander/epoch_0090_test_G_-34.79156123/lunarlander_model_params'

In [69]:
model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=discount,
                    optimizer=gradient_transform, support_size=support_size)

model.load(model_path)

In [70]:
import gymnasium as gym 
from muax.test import test
env_id = 'LunarLander-v2'
test_env = gym.make(env_id, render_mode='rgb_array')
test_key = jax.random.PRNGKey(0)
test(model, test_env, test_key, num_simulations=50, num_test_episodes=100, random_seed=None)

48.28451863343125