In [17]:
import jax
import flax.linen as nn
import flax
import gymnasium as gym
import optax
import jax.numpy as jnp


class GenericPolicy(nn.Module):
    state_dim: int
    n_actions: int = 4
    hidden_dim: int = 64

    @nn.compact
    def __call__(self, state):
        x = nn.Dense(self.hidden_dim)(state)
        x = nn.relu(x)
        x = nn.Dense(self.n_actions)(x)
        return x


class ValueFunction(nn.Module):
    state_dim: int
    hidden_dim: int = 64

    @nn.compact
    def __call__(self, state):
        x = nn.Dense(self.hidden_dim)(state)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


def policy_loss(params, policy, logits, states, actions, advantages, eps=0.3):
    current_logits = nn.log_softmax(policy.apply(params, states))
    indexed_logits = jnp.take_along_axis(current_logits, jnp.expand_dims(actions, 1), axis=1)
    ratios = jnp.exp(indexed_logits - logits)
    unclipped_objective = ratios * advantages
    clipped_objective = jnp.where(advantages >= 0, 
                                  (1 + eps) * advantages, 
                                  (1 - eps) * advantages)
    loss = -jnp.mean(jnp.minimum(unclipped_objective, clipped_objective))
    return loss


def value_function_loss( params, value_fn, state, reward):
    advantages = jnp.expand_dims(reward,1) - value_fn.apply(params, state)
    return jnp.mean(advantages**2), advantages


In [18]:
env = "CartPole-v1" 
train_steps=100
lr=1e-3
max_ep_len=100
n_rollouts=10
df=0.99

env = gym.make(env)
# initialize the model and optimizer
policy = GenericPolicy(env.observation_space.shape[0], env.action_space.n)
value_fn = ValueFunction(env.observation_space.shape[0])
policy_optimizer = optax.adam(learning_rate=lr)
value_optimizer = optax.adam(learning_rate=lr)

policy_params = policy.init(
    jax.random.PRNGKey(0), jax.numpy.zeros((1, env.observation_space.shape[0]))
)

value_params = value_fn.init(
    jax.random.PRNGKey(0), jax.numpy.zeros((1, env.observation_space.shape[0]))
)
policy_optimizer_state = policy_optimizer.init(policy_params)
value_optimizer_state = value_optimizer.init(value_params)

value_grad_fn = jax.value_and_grad(value_function_loss, has_aux=True)
policy_grad_fn = jax.value_and_grad(policy_loss)
key = jax.random.PRNGKey(0)
# create gym environment
for step in range(train_steps):
    # collect data
    data = []
    total_reward = 0
    for _ in range(n_rollouts):
        ep = []
        state,_ = env.reset()
        state = jnp.array(state)
        for _ in range(max_ep_len):
            logits = nn.log_softmax(policy.apply(policy_params, state))
            key, subkey = jax.random.split(key)
            action = jax.random.categorical(subkey, logits)
            action_prob = logits[action]
            next_state, reward, done, _,_ = env.step(action.item())
            next_state = jnp.array(next_state)
            # convert above line to dict
            total_reward += reward
            ep.append(
                {
                    "state": state,
                    "action": action,
                    "action_prob": action_prob,
                    "reward": reward,
                    "next_state": next_state,
                    "done": done,
                }
            )
            if done:
                break
            state = next_state
        # use discounting to attribute rewards for episodes
        for tm1, t in zip(ep[:-1][:-1:-1], ep[1:][::-1]):
            tm1["reward"] += df * t["reward"]
        data.extend(ep)
    print(f"Step: {step}, Average Episode Reward: {total_reward/n_rollouts}")
    # batch, compute advantage, and update value function
    states = jax.numpy.array([d["state"] for d in data])
    rewards = jax.numpy.array([d["reward"] for d in data])
    actions = jax.numpy.array([d["action"] for d in data])
    logits = jax.numpy.array([d["action_prob"] for d in data])
    # call jax value and grad to get advantadges and gradients+loss
    (value_loss, advantadges), value_grads = value_grad_fn(
        value_params, value_fn, states, rewards
    )

    # apply gradients to value function
    value_updates, value_optimizer_state = value_optimizer.update(value_grads, value_optimizer_state)
    value_params = optax.apply_updates(value_params, value_updates)

    # update policy
    policy_loss, policy_grads = policy_grad_fn(
        policy_params, policy, logits, states, actions, advantadges
    )
    policy_updates, policy_optimizer_state = policy_optimizer.update(policy_grads, policy_optimizer_state)
    policy_params = optax.apply_updates(policy_params, policy_updates)
    print(f"Step: {step}, Policy Loss: {policy_loss}, Value Loss: {value_loss}")

Step: 0, Average Episode Reward: 26.0
Step: 0, Policy Loss: -0.8916085958480835, Value Loss: 0.8428187966346741
Step: 1, Average Episode Reward: 17.9
Step: 1, Policy Loss: -0.8510328531265259, Value Loss: 0.7779422998428345
Step: 2, Average Episode Reward: 22.0
Step: 2, Policy Loss: -0.8632423877716064, Value Loss: 0.7820778489112854
Step: 3, Average Episode Reward: 25.5
Step: 3, Policy Loss: -0.8581407070159912, Value Loss: 0.7659480571746826
Step: 4, Average Episode Reward: 30.5
Step: 4, Policy Loss: -0.834945023059845, Value Loss: 0.7442376613616943
Step: 5, Average Episode Reward: 23.3
Step: 5, Policy Loss: -0.83938068151474, Value Loss: 0.7518985867500305
Step: 6, Average Episode Reward: 32.5
Step: 6, Policy Loss: -0.791509747505188, Value Loss: 0.6944068074226379
Step: 7, Average Episode Reward: 27.0
Step: 7, Policy Loss: -0.7432982921600342, Value Loss: 0.6264093518257141
Step: 8, Average Episode Reward: 20.1
Step: 8, Policy Loss: -0.7705106139183044, Value Loss: 0.6543177366256

In [9]:
advantadges.shape

(179, 179)

In [60]:
action

Array(2, dtype=int32)

In [51]:
subkey

Array([2718843009, 1272950319], dtype=uint32)

In [61]:
logits

Array([-1.473531 , -1.5424082, -1.3628829, -1.2003208], dtype=float32)