In [1]:
import torch as th
import torch.nn as nn
from torch.optim import Adam

import gymnasium as gym

from jarl.modules.mlp import MLP
from jarl.envs.gym import TorchGymEnv
from jarl.modules.operator import Critic
from jarl.modules.encoder import FlattenEncoder
from jarl.modules.policy import DiagonalGaussianPolicy

from jarl.train.optimizer import Optimizer
from jarl.train.graph.graph import TrainGraph
from jarl.train.sample.base import BatchSampler
from jarl.train.update.policy import ClippedPolicyUpdate
from jarl.train.update.critic import MSECriticUpdate

from jarl.train.modify.compute import (
    ComputeValues,
    ComputeAdvantages,
    ComputeReturns
)

In [2]:
# Create the BipedalWalker-v3 Gymnasium environment and wrap it in jarl's
# TorchGymEnv in one expression, so `env` is bound once to the wrapped env.
# NOTE(review): presumably the wrapper exposes a torch-based interface — confirm.
env = TorchGymEnv(gym.make('BipedalWalker-v3'))

In [3]:
def _relu_mlp_body():
    """Return a new ReLU MLP with dims [64, 32, 32, 16].

    A fresh instance (and a fresh dims list) is built on every call so the
    policy and the critic never share a body module.
    """
    return MLP(func=nn.ReLU, dims=[64, 32, 32, 16])


# The policy and the critic use the same head/body layout; each is built
# against `env` via .build(env).
policy = DiagonalGaussianPolicy(head=FlattenEncoder(), body=_relu_mlp_body()).build(env)

critic = Critic(head=FlattenEncoder(), body=_relu_mlp_body()).build(env)

In [4]:
# Assemble the PPO training graph step by step rather than as one chained
# expression; construction and registration order match the chained form.
# NOTE(review): the leading integers (64, 1024, 2048) are passed positionally;
# they look like a batch size and per-update step intervals — confirm against
# jarl's API.
sampler = BatchSampler(64)
policy_update = ClippedPolicyUpdate(1024, policy, optimizer=Optimizer(Adam))
critic_update = MSECriticUpdate(2048, critic, optimizer=Optimizer(Adam))

graph = TrainGraph(sampler, policy_update, critic_update)

# Register the modifiers in the original order: advantages, returns, values.
# The result of each add_modifier call is what the next call is made on.
graph = graph.add_modifier(ComputeAdvantages())
graph = graph.add_modifier(ComputeReturns())
graph = graph.add_modifier(ComputeValues(critic))

ppo_block = graph.compile()

In [7]:
# Call ready() with 1024 — the same value given to ClippedPolicyUpdate when
# the graph was built. Its return value is not displayed; only its effect on
# the graph matters here.
# NOTE(review): presumably ready() refreshes `active_dep` for the given step,
# so it must run before the attribute is read — confirm against jarl's API.
ppo_block.ready(1024)
# Last expression of the cell, so this is the value the notebook renders.
ppo_block.active_dep

[<jarl.train.modify.compute.ComputeValues at 0x158d8ced0>,
 <jarl.train.modify.compute.ComputeAdvantages at 0x16632b390>]