In [None]:
%load_ext autoreload
%autoreload 2

import gymnasium as gym


from vi_ppo.actor_critic import ActorCritic
from vi_ppo.nets.mlp import Mlp
from vi_ppo.rl_module import RlModule
import lightning as pl
from pytorch_lightning.loggers import TensorBoardLogger


In [None]:
# Training environment for the agent (no render mode requested).
env = gym.make("LunarLander-v3")

# Network sizes are read off the environment's spaces.
d = env.observation_space.shape[0]  # first dimension of the observation space
n_a = env.action_space.n            # `n` of the action space
hidden_dims = 16

# The actor and critic MLP configs differ only in their output width, so the
# remaining keyword arguments are shared between them.
# NOTE: a separate feature-extractor MLP (passed to ActorCritic as
# `feature_extractor=`) was tried earlier and is currently disabled.
shared_mlp_kwargs = dict(
    input_dims=d,
    hidden_dims=hidden_dims,
    n_layers=1,
    activation="silu",
)
actor_config = Mlp.config_cls(output_dims=n_a, **shared_mlp_kwargs)
critic_config = Mlp.config_cls(output_dims=1, **shared_mlp_kwargs)

# Hyperparameters handed to the actor-critic model.
ac_config = ActorCritic.config_cls(
    clip_epsilon=0.2,
    value_coeff=0.5,
    entropy_coeff=0.01,
)

# Instantiate the actor before the critic so the networks are created in the
# same order as the keyword arguments below.
actor_net = Mlp(actor_config)
critic_net = Mlp(critic_config)
model = ActorCritic(ac_config, actor_net=actor_net, critic=critic_net)

# Wrap the model and the environment in the Lightning-style training module.
config = RlModule.config_class(lr=3e-4)
module = RlModule(actor_critic=model, env=env, config=config)

# Log to ./lightning_logs/lunar_lander and train for 100 epochs.
logger = TensorBoardLogger("lightning_logs", name="lunar_lander")
trainer = pl.Trainer(max_epochs=100, logger=logger)
trainer.fit(module)

In [None]:
# Roll out the trained policy in a rendered environment.
# `module` is the RlModule fitted in the training cell; this cell rebinds the
# notebook-level name `env` to a new environment created with
# render_mode="human".
env = gym.make("LunarLander-v3", render_mode="human")

# The rollout is wrapped in try/finally so the environment is closed even when
# the loop is interrupted (e.g. a kernel interrupt during rendering) or when
# `module.predict` / `env.step` raises. Previously `env.close()` was skipped in
# those cases.
try:
    # Reset the environment to generate the first observation.
    observation, info = env.reset(seed=42)

    for _ in range(5000):
        # Query the trained module for the action to take.
        action = module.predict(observation)

        # Step (transition) through the environment with the action, receiving
        # the next observation, the reward, and the terminated/truncated flags.
        observation, reward, terminated, truncated, info = env.step(action)

        # Once the episode has ended, reset to start a new one.
        if terminated or truncated:
            observation, info = env.reset()
finally:
    env.close()

In [None]:
# Requires the `thread_the_needle` package from a local checkout; install it
# into the kernel's environment first, e.g.
#   %pip install -e <path-to>/thread_the_needle/
import thread_the_needle as ttn

# Calls the package's `make` with the name "thread_the_needle"; as the last
# expression of the cell, its return value is displayed.
# NOTE(review): presumably this constructs the registered environment — confirm.
ttn.make("thread_the_needle")