In [1]:
import torch as th
import torch.nn as nn
from torch.optim import Adam

import gymnasium as gym

from jarl.modules.mlp import MLP
from jarl.envs.gym import TorchGymEnv
from jarl.data.dict import DotDict
from jarl.data.buffer import LazyBuffer
from jarl.modules.operator import Critic
from jarl.modules.encoder import FlattenEncoder
from jarl.modules.policy import CategoricalPolicy

from jarl.train.optim import Optimizer
from jarl.train.update.critic import MSECriticUpdate
from jarl.train.update.policy import ClippedPolicyUpdate
from jarl.train.update.ppo import PPOUpdate
from jarl.train.graph import TrainGraph
from jarl.train.sample.base import BatchSampler
from jarl.modules.discriminator import Discriminator

from jarl.train.modify.compute import (
    ComputeValues,
    ComputeLogProbs,
    ComputeAdvantages,
    ComputeReturns
)

In [2]:
# Create the LunarLander task and hand it straight to jarl's TorchGymEnv wrapper.
# NOTE(review): the wrapper presumably converts observations/actions to torch
# tensors for the policy below; confirm against jarl.envs.gym.
env = TorchGymEnv(gym.make('LunarLander-v2'))

In [3]:
# All three networks share one configuration: a FlattenEncoder head followed by
# an MLP body created with nn.Tanh and dims=[64, 64]. Each module is first
# described, then finalised against the environment with .build(env).

# Actor producing categorical actions.
policy = CategoricalPolicy(
    head=FlattenEncoder(),
    body=MLP(func=nn.Tanh, dims=[64, 64]),
)
policy = policy.build(env)

# Critic used by the PPO update and the value modifier.
critic = Critic(
    head=FlattenEncoder(),
    body=MLP(func=nn.Tanh, dims=[64, 64]),
)
critic = critic.build(env)

# Discriminator; not referenced by any other cell visible in this notebook.
discrim = Discriminator(
    head=FlattenEncoder(),
    body=MLP(func=nn.Tanh, dims=[64, 64]),
)
discrim = discrim.build(env)

In [4]:
# Assemble the PPO training graph step by step instead of one fluent chain.
# BatchSampler(64, num_epoch=10) and PPOUpdate(2048, ...) keep the original
# arguments; the 2048 is the same literal used for the LazyBuffer in the
# training cell, and the recorded output shows updates at multiples of 2048.
ppo_block = TrainGraph(
    BatchSampler(64, num_epoch=10),
    PPOUpdate(
        2048,
        policy,
        critic,
        optimizer=Optimizer(Adam, lr=3e-4),
        ent_coef=0.01,
    ),
)

# Modifiers are registered in the same order as before.
# NOTE(review): ComputeAdvantages is added before ComputeValues; presumably
# compile() resolves the dependency order between modifiers. Confirm.
ppo_block = ppo_block.add_modifier(ComputeAdvantages())
ppo_block = ppo_block.add_modifier(ComputeLogProbs(policy))
ppo_block = ppo_block.add_modifier(ComputeReturns())
ppo_block = ppo_block.add_modifier(ComputeValues(critic))

ppo_block = ppo_block.compile()

In [5]:
import numpy as np

def get_episodic_return(data):
    """Sum rewards per episode for every episode that finishes inside ``data``.

    Args:
        data: object exposing ``don`` (per-step done flags) and ``rew``
            (per-step rewards whose elements support ``.item()``), indexed
            in step order.

    Returns:
        list: one undiscounted return per done flag encountered, in order.

    Note:
        The running sum starts at 0 on every call, and rewards collected
        after the last done flag are dropped. An episode that straddles two
        calls is therefore only counted from the start of the second call.
    """
    done_flags = data.don
    rewards = data.rew
    finished_returns = []
    running_total = 0
    for step, done in enumerate(done_flags):
        running_total += rewards[step].item()
        if not done:
            continue
        # Episode boundary: record the total and start a fresh accumulator.
        finished_returns.append(running_total)
        running_total = 0
    return finished_returns

In [6]:
# Rollout / logging settings for the training loop.
# ROLLOUT_LEN is the value handed to LazyBuffer (presumably its capacity) and
# repeats the 2048 literal passed to PPOUpdate when ppo_block was built.
ROLLOUT_LEN = 2048
TOTAL_STEPS = int(1e6)
RETURN_WINDOW = 100  # how many of the most recent finished episodes are averaged

buffer = LazyBuffer(ROLLOUT_LEN)

# Returns of every finished episode seen so far, in completion order.
rews = []

obs = env.reset()
for t in range(TOTAL_STEPS):
    # Acting only needs a forward pass, so skip autograd bookkeeping.
    with th.no_grad():
        trs = DotDict(
            obs=obs,
            act=policy(obs)
        )
    exp, obs = env.step(trs=trs)
    buffer.store(exp)

    if ppo_block.ready(t):
        data = buffer.serve()
        batch_info = ppo_block.update(data)
        rews.extend(get_episodic_return(data))

        # np.mean([]) yields nan together with a RuntimeWarning; report nan
        # explicitly when no episode has finished yet. float() keeps the
        # printed dict free of NumPy scalar reprs.
        recent = rews[-RETURN_WINDOW:]
        mean_rew = float(np.mean(recent)) if recent else float("nan")
        print(t, batch_info | dict(rew=mean_rew))

2048 {'policy_loss': -0.000940173864364624, 'entropy_loss': -0.0136715704575181, 'approx_kl': -0.014608027413487434, 'critic_loss': 826.2957763671875, 'rew': -184.9752837061428}
4096 {'policy_loss': -0.034572310745716095, 'entropy_loss': -0.013655910268425941, 'approx_kl': 0.00018473342061042786, 'critic_loss': 584.3338012695312, 'rew': -164.99507585993038}
6144 {'policy_loss': -0.023270191624760628, 'entropy_loss': -0.013474492356181145, 'approx_kl': -0.005885083228349686, 'critic_loss': 414.9740905761719, 'rew': -161.4258410644741}
8192 {'policy_loss': -0.018377017229795456, 'entropy_loss': -0.013279171660542488, 'approx_kl': 0.01239684410393238, 'critic_loss': 364.78143310546875, 'rew': -150.7124396924847}


KeyboardInterrupt: 