In [1]:
import torch as th
from torch.optim import Adam

import gymnasium as gym

from gumbo.env.gym import TorchEnv
from gumbo.data.buffer import EpisodicBuffer
from gumbo.data.sampler import BatchSampler

from gumbo.modules.core import MLP
from gumbo.modules.encoder import IdentityEncoder, FlattenEncoder
from gumbo.modules.policy import DiagonalGaussianPolicy, CategoricalPolicy
from gumbo.modules.operator import ValueOperator

from gumbo.data.collector import Collector

from gumbo.optimizer import Optimizer

from gumbo.learn import PPO
from gumbo.learn import Trainer

In [2]:
# Seed PyTorch's global RNG so the torch-side randomness of this notebook is
# repeatable across "Restart Kernel -> Run All".
# NOTE(review): this does not seed the environment itself; Gymnasium
# environments are seeded via `reset(seed=...)` -- confirm whether
# TorchEnv/Collector expose a way to pass that through.
SEED = 0
th.manual_seed(SEED)

# NOTE(review): Gymnasium >= 1.0 registers this task as "LunarLander-v3";
# the "v2" id only resolves on older releases -- confirm the pinned version.
env = gym.make("LunarLander-v2")
# Wrap the Gymnasium environment in gumbo's TorchEnv adapter; the wrapped
# object is what the rest of the notebook reads specs from.
env = TorchEnv(env)

In [3]:
# Create the episodic buffer from the environment's spec.
# NOTE(review): size=2048 is presumably the buffer capacity in transitions
# -- confirm against EpisodicBuffer.from_spec.
data = EpisodicBuffer.from_spec(env.env_spec, size=2048)

In [4]:
# Assemble the policy from a flatten encoder and an MLP body (both created
# with default arguments), then build it against the environment's
# observation and action specs.
policy = CategoricalPolicy(encoder=FlattenEncoder(), body=MLP()).build(
    env.obs_spec,
    env.act_spec,
)

In [5]:
# Tie together the environment, the policy and the buffer in a Collector.
# NOTE(review): presumably it rolls out `policy` in `env` and writes the
# transitions into `data` -- confirm against gumbo.data.collector.
collector = Collector(env, policy, data)

In [6]:
# Assemble the critic from its own flatten encoder and MLP body (default
# arguments), then build it from the observation spec only.
critic = ValueOperator(encoder=FlattenEncoder(), body=MLP()).build(env.obs_spec)

In [7]:
# Sampler handed to PPO below.
# NOTE(review): the positional arguments 64 and 10 are presumably the
# minibatch size and the number of passes/batches per update -- confirm
# against BatchSampler's signature and consider passing them by keyword.
sampler = BatchSampler(64, 10)

In [8]:
# Wrap the torch Adam optimizer class in gumbo's Optimizer; no
# hyperparameters (e.g. learning rate) are passed here, so whatever
# defaults Optimizer/Adam define apply.
optimizer = Optimizer(Adam)

In [9]:
# PPO learner wired to the policy, the critic, the batch sampler and the
# optimizer wrapper; all other PPO settings are left at their defaults.
ppo = PPO(policy, critic, sampler, optimizer)

In [10]:
# Trainer combining the data collector with the PPO learner.
trainer = Trainer(collector, ppo)

In [11]:
# Start training. During the run it prints the mean episode reward/length
# and a dict of loss diagnostics (policy_loss, critic_loss, entropy_loss,
# approx_kl); the recorded run was stopped manually with KeyboardInterrupt.
# NOTE(review): the argument is presumably the total number of environment
# steps -- confirm against Trainer.train.
trainer.train(1_000_000)

mean_ep_rew: -205.14696533339364, mean_ep_len: 97.52380952380952
{'policy_loss': -0.012120097875595093, 'critic_loss': 841.9310913085938, 'entropy_loss': -0.0, 'approx_kl': -0.028607148677110672}
mean_ep_rew: -135.62566028941762, mean_ep_len: 93.0909090909091
{'policy_loss': -0.023400068283081055, 'critic_loss': 334.25970458984375, 'entropy_loss': -0.0, 'approx_kl': -0.02868829108774662}
mean_ep_rew: -168.48262977600098, mean_ep_len: 93.0909090909091
{'policy_loss': 0.04469699785113335, 'critic_loss': 189.08457946777344, 'entropy_loss': -0.0, 'approx_kl': -0.008667952381074429}
mean_ep_rew: -149.51039106195623, mean_ep_len: 93.0909090909091
{'policy_loss': -0.02336733788251877, 'critic_loss': 254.96791076660156, 'entropy_loss': -0.0, 'approx_kl': -0.016471894457936287}
mean_ep_rew: -138.15854091644286, mean_ep_len: 102.4
{'policy_loss': -0.027029041200876236, 'critic_loss': 242.38174438476562, 'entropy_loss': -0.0, 'approx_kl': -0.016380630433559418}
mean_ep_rew: -85.69680507805036, me

KeyboardInterrupt: 