In [1]:
import optax
from jax_policy_grad import MLP, optimize_policy_gradient, plot_rewards

import jax.numpy as jnp
from jax import Array


def noncausal_loss(policy: MLP, s: Array, a: Array, r: Array):
    """REINFORCE surrogate loss that weights every step by the whole episode's return.

    Each action's log-likelihood is scaled by the sum of *all* rewards, including
    rewards collected before that action was taken ("non-causal" weighting).

    Args:
        policy: callable mapping the states ``s`` to action logits.
        s: states fed to ``policy``.
        a: integer action labels, one per row of ``policy(s)``.
        r: rewards; every entry is summed into one scalar return.

    Returns:
        Scalar: (sum of per-step negative log-likelihoods of ``a``) * (sum of ``r``).
    """
    logits = policy(s)
    # Cross-entropy against the taken action == -log pi(a_t | s_t).
    neg_log_likelihoods = optax.softmax_cross_entropy_with_integer_labels(logits, a)
    episode_return = r.sum()
    return neg_log_likelihoods.sum() * episode_return


def causal_loss(policy: MLP, s: Array, a: Array, r: Array):
    """REINFORCE surrogate loss using reward-to-go ("causal") weighting.

    Step ``t`` is weighted only by rewards from ``t`` onward:
    ``r[t] + r[t+1] + ... + r[-1]``.

    Args:
        policy: callable mapping the states ``s`` to action logits.
        s: states fed to ``policy``.
        a: integer action labels, one per row of ``policy(s)``.
        r: rewards ordered along axis 0; presumably aligned step-for-step with
            ``a`` (TODO confirm) so the elementwise product below lines up.

    Returns:
        Scalar sum over steps of ``-log pi(a_t | s_t) * reward_to_go[t]``.
    """
    # Reverse cumulative sum along axis 0: flip, accumulate, flip back.
    reward_to_go = jnp.flip(jnp.cumsum(jnp.flip(r, axis=0), axis=0), axis=0)
    per_step_nll = optax.softmax_cross_entropy_with_integer_labels(policy(s), a)
    weighted = per_step_nll * reward_to_go
    return jnp.sum(weighted)


def causal_loss_with_baseline(policy: MLP, s: Array, a: Array, r: Array):
    """Reward-to-go REINFORCE loss with a mean baseline subtracted.

    Computes each step's reward-to-go, centres it by subtracting the mean
    reward-to-go of the same trajectory, and uses the result to weight the
    per-step negative log-likelihood of the taken action.

    Args:
        policy: callable mapping the states ``s`` to action logits.
        s: states fed to ``policy``.
        a: integer action labels, one per row of ``policy(s)``.
        r: rewards ordered along axis 0; presumably aligned step-for-step with
            ``a`` (TODO confirm).

    Returns:
        Scalar sum over steps of ``-log pi(a_t | s_t) * (rtg[t] - mean(rtg))``.
    """
    rtg = jnp.flip(jnp.cumsum(jnp.flip(r, axis=0), axis=0), axis=0)
    # NOTE(review): the baseline is the mean of this trajectory's own
    # reward-to-go values, so it depends on every reward in the trajectory.
    # A strictly unbiased baseline must not depend on a_t; confirm this
    # simple centring is intended.
    centred = rtg - jnp.mean(rtg)
    nll = optax.softmax_cross_entropy_with_integer_labels(policy(s), a)
    return (nll * centred).sum()


# Number of policy-gradient iterations used for every loss variant, so the
# curves are compared over the same training budget.
N_ITERS = 750

# NOTE(review): the recorded run of this cell stopped at 0/750 with
# "TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'". That error
# comes from inside optimize_policy_gradient, whose code is not shown here,
# so confirm and fix it there.
noncausal_rewards = optimize_policy_gradient(noncausal_loss, n_iters=N_ITERS)
causal_rewards = optimize_policy_gradient(causal_loss, n_iters=N_ITERS)
# Must use the baseline-subtracted loss. Passing causal_loss here would just
# re-run the "Causal" experiment under a different label.
causal_baseline_rewards = optimize_policy_gradient(
    causal_loss_with_baseline, n_iters=N_ITERS
)

plot_rewards(
    {
        "Noncausal": noncausal_rewards,
        "Causal": causal_rewards,
        "Causal Baseline": causal_baseline_rewards,
    }
)

Optimizing Policy:   0%|          | 0/750 [00:00<?, ?it/s]


TypeError: unhashable type: 'jaxlib.xla_extension.ArrayImpl'