## Imports

import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
import numpy as np
import functools

from flax.training import train_state

## Actor-Critic Architecture

class ActorCritic(nn.Module):
    """Shared-trunk actor-critic network for a discrete action space.

    Two 256-unit ReLU layers feed two linear heads: one emitting
    unnormalized action logits, the other a single state-value estimate.

    Attributes:
        action_dim: Width of the logits head (number of discrete actions).
    """

    action_dim: int

    @nn.compact
    def __call__(self, state):
        """Run the network on a batch of states.

        Args:
            state: Input array; the dense layers act on its last axis.

        Returns:
            A ``(logits, value)`` tuple whose trailing dimensions are
            ``action_dim`` and ``1`` respectively.
        """
        x = nn.Dense(256)(state)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)

        # Both heads read the same hidden features.
        logits = nn.Dense(self.action_dim)(x)
        value = nn.Dense(1)(x)

        return logits, value

## Sample Actions

In [None]:
def sample_actions(rng, logits):
    """Draw categorical actions from unnormalized logits.

    Args:
        rng: JAX PRNG key used for the draw.
        logits: Unnormalized log-probabilities; the last axis indexes actions.

    Returns:
        A ``(actions, log_probs)`` tuple. ``actions`` holds the sampled
        indices; ``log_probs`` is the log-softmax of ``logits`` over *every*
        action (the full distribution, not just the sampled entry).
    """
    sampled = jax.random.categorical(rng, logits)
    all_log_probs = jax.nn.log_softmax(logits)
    return sampled, all_log_probs

## Compute Advantages and Reward-To-Go

In [None]:
def compute_advantages(rewards, values, dones, last_value, gamma):
    """Compute discounted reward-to-go and advantages for one rollout.

    The return is accumulated backwards through the rollout, starting from
    ``last_value`` (the value estimate of the state following the final
    step) and reset to zero future return wherever an episode terminated.

    Args:
        rewards: 1-D sequence of per-step rewards.
        values: 1-D sequence of value estimates, same length as ``rewards``.
        dones: 1-D sequence of episode-termination flags (truthy = done).
        last_value: Bootstrap value for the state after the last step.
        gamma: Discount factor.

    Returns:
        ``(advantages, returns)`` as float32 NumPy arrays, where
        ``advantages = returns - values``.
    """
    rewards = np.asarray(rewards, dtype=np.float32)
    values = np.asarray(values, dtype=np.float32)
    # 1.0 while the episode continues, 0.0 on a terminal step, so the
    # future return is dropped exactly when the episode ended.
    not_done = 1.0 - np.asarray(dones, dtype=np.float32)

    returns = np.zeros_like(rewards)
    running_return = float(last_value)
    for i in reversed(range(len(rewards))):
        running_return = rewards[i] + gamma * running_return * not_done[i]
        returns[i] = running_return

    advantages = returns - values
    return advantages, returns


## Create Train State

In [None]:
def create_train_state(rng, model, input_dim, action_dim, learning_rate):
    """Initialise model parameters and wrap them in a Flax ``TrainState``.

    Args:
        rng: JAX PRNG key used for parameter initialisation.
        model: Flax module exposing ``init`` and ``apply``.
        input_dim: Feature size of one observation, either an int or a
            shape tuple (e.g. ``obs.shape``).
        action_dim: Unused here; kept so existing callers keep working.
        learning_rate: Step size for the Adam optimiser.

    Returns:
        A ``TrainState`` holding the parameters, the Adam optimiser and its
        freshly initialised optimiser state.
    """
    # Accept both a scalar feature count and a full observation shape.
    if np.ndim(input_dim) == 0:
        feature_shape = (int(input_dim),)
    else:
        feature_shape = tuple(int(d) for d in input_dim)

    dummy_input = jnp.ones((1, *feature_shape))
    params = model.init(rng, dummy_input)
    tx = optax.adam(learning_rate)
    # ``create`` fills in ``step`` and ``opt_state``; the bare constructor
    # requires them and would raise.
    return train_state.TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx,
    )

## Train Step: Compute Loss and Apply Gradients

In [None]:
# Only the two scalar hyperparameters are static; every array argument
# (including ``returns``) must be traced.
@functools.partial(jax.jit, static_argnums=(6, 7))
def train_step(
    state,
    states,
    actions,
    old_log_probs,
    advantages,
    returns,
    clip_eps,
    max_grad_norm
):
    """Apply one clipped-surrogate PPO gradient step to a minibatch.

    Args:
        state: ``TrainState`` with the current parameters and optimiser.
        states: Batch of observations fed to ``state.apply_fn``.
        actions: Integer action taken for each sample (any shape that
            flattens to one index per sample).
        old_log_probs: Behaviour-policy log-probabilities. Either one value
            per sample, or the full per-action log-softmax with the action
            axis last; in the latter case the taken action's entry is used.
        advantages: Advantage estimate per sample.
        returns: Value-function regression target per sample.
        clip_eps: PPO ratio clipping range (static under ``jit``).
        max_grad_norm: Global-norm gradient clipping threshold (static).

    Returns:
        The ``TrainState`` after the gradient update.
    """
    # One column of action indices, one scalar target per sample.
    action_idx = actions.reshape(-1, 1).astype(jnp.int32)
    advantages = advantages.reshape(-1)
    returns = returns.reshape(-1)

    def loss_fn(params):
        logits, values = state.apply_fn(params, states)
        num_actions = logits.shape[-1]

        # Log-probability of the action that was actually taken.
        log_probs = jnp.take_along_axis(
            jax.nn.log_softmax(logits), action_idx, axis=-1
        )[:, 0]

        # Shapes are static under jit, so this branch is resolved at trace time.
        if old_log_probs.ndim > 1 and old_log_probs.shape[-1] == num_actions:
            old = jnp.take_along_axis(
                old_log_probs.reshape(-1, num_actions), action_idx, axis=-1
            )[:, 0]
        else:
            old = old_log_probs.reshape(-1)

        ratio = jnp.exp(log_probs - old)
        clipped_ratio = jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps)
        surrogate = jnp.minimum(ratio * advantages, clipped_ratio * advantages)
        policy_loss = -jnp.mean(surrogate)

        value_loss = jnp.mean((values.reshape(-1) - returns) ** 2)

        return policy_loss + value_loss

    grads = jax.grad(loss_fn)(state.params)
    # clip_by_global_norm builds a gradient transformation; it is stateless,
    # so a throwaway init state is enough.
    clipper = optax.clip_by_global_norm(max_grad_norm)
    grads, _ = clipper.update(grads, clipper.init(state.params))
    # apply_gradients only accepts ``grads`` as a keyword argument.
    return state.apply_gradients(grads=grads)

## Update Function

In [None]:
def update_ppo(state, obs, batch_size, num_minibatches, clip_eps, max_grad_norm, rng=None):
    """Run several shuffled minibatch passes of PPO updates over a rollout.

    Args:
        state: Current ``TrainState``.
        obs: Mapping with ``"states"``, ``"actions"``, ``"log_probs"``,
            ``"advantages"`` and ``"returns"`` entries, all indexed along
            their first axis by sample.
        batch_size: Number of samples per minibatch.
        num_minibatches: Number of full passes made over the rollout (each
            pass is split into ``batch_size``-sized minibatches).
        clip_eps: PPO ratio clipping range, forwarded to ``train_step``.
        max_grad_norm: Gradient clipping threshold, forwarded to ``train_step``.
        rng: Optional JAX PRNG key for shuffling; defaults to a fixed key.

    Returns:
        The ``TrainState`` after all updates.
    """
    if rng is None:
        rng = jax.random.PRNGKey(42)
    num_samples = len(obs["states"])

    for _ in range(num_minibatches):
        # Fresh shuffle on every pass so minibatch composition varies.
        rng, perm_rng = jax.random.split(rng)
        indices = np.asarray(jax.random.permutation(perm_rng, num_samples))

        for start in range(0, num_samples, batch_size):
            mb_indices = indices[start:start + batch_size]
            mb_states = obs["states"][mb_indices]
            mb_actions = obs["actions"][mb_indices]
            mb_old_logprobs = obs["log_probs"][mb_indices]
            mb_advantages = obs["advantages"][mb_indices]
            mb_returns = obs["returns"][mb_indices]

            # Normalise per minibatch; epsilon sits inside the denominator
            # to guard against a zero standard deviation.
            mb_advantages = (mb_advantages - jnp.mean(mb_advantages)) / (
                jnp.std(mb_advantages) + 1e-8
            )
            state = train_step(
                state,
                mb_states,
                mb_actions,
                mb_old_logprobs,
                mb_advantages,
                mb_returns,
                clip_eps,
                max_grad_norm,
            )
    return state

## Collect Trajectories

In [None]:
def collect_trajectories(env, state, rng, steps_per_epoch, gamma):
    """Roll the current policy out for a fixed number of environment steps.

    Args:
        env: Environment providing ``reset()`` and ``step(action)``; ``step``
            is unpacked as ``(next_obs, reward, done, info)``.
        state: ``TrainState`` whose ``apply_fn`` returns ``(logits, value)``.
        rng: JAX PRNG key; split once per step for action sampling.
        steps_per_epoch: Total number of environment steps to collect.
        gamma: Discount factor passed to ``compute_advantages``.

    Returns:
        ``(observations, rng)`` where ``observations`` maps ``"states"``,
        ``"actions"``, ``"rewards"``, ``"dones"``, ``"values"``,
        ``"log_probs"``, ``"advantages"`` and ``"returns"`` to arrays with one
        entry per step, and ``rng`` is the advanced PRNG key.
    """
    buffer = {
        "states": [],
        "actions": [],
        "rewards": [],
        "dones": [],
        "values": [],
        "log_probs": []
    }
    obs = env.reset()
    for _ in range(steps_per_epoch):
        rng, action_rng = jax.random.split(rng)
        logits, value = state.apply_fn(state.params, jnp.array([obs]))
        action, log_prob = sample_actions(action_rng, logits)

        next_obs, reward, done, _ = env.step(np.array(action))

        buffer["states"].append(obs)
        buffer["actions"].append(action)
        buffer["rewards"].append(reward)
        buffer["dones"].append(done)
        buffer["values"].append(value[0, 0])
        buffer["log_probs"].append(log_prob)

        # Start a new episode in place so the rollout always has
        # exactly steps_per_epoch transitions.
        obs = env.reset() if done else next_obs

    # Post-processing happens once, after the whole rollout is collected.
    observations = {key: jnp.array(val) for key, val in buffer.items()}

    # apply_fn returns (logits, value); bootstrap from the value head.
    last_value = state.apply_fn(state.params, jnp.array([obs]))[1][0, 0]
    advantages, returns = compute_advantages(
        observations["rewards"],
        observations["values"],
        observations["dones"],
        last_value,
        gamma
    )
    observations["advantages"] = advantages
    observations["returns"] = returns
    return observations, rng


## Train

In [None]:
def train(
    env,
    seed,
    num_epochs,
    steps_per_epoch,
    batch_size,
    num_minibatches,
    gamma,
    clip_eps,
    learning_rate,
    max_grad_norm
):
    """Train an actor-critic policy with PPO and return the final state.

    Each epoch collects ``steps_per_epoch`` transitions, runs the PPO update
    on them, and every 20th epoch prints the result of one evaluation episode.

    Args:
        env: Environment providing ``reset()``, ``step()``, ``action_dim()``
            and ``get_sample()`` as used by this file.
        seed: Integer seed for the JAX PRNG.
        num_epochs: Number of collect/update iterations.
        steps_per_epoch: Environment steps gathered per epoch.
        batch_size: Minibatch size for the PPO update.
        num_minibatches: Number of passes ``update_ppo`` makes per epoch.
        gamma: Discount factor.
        clip_eps: PPO ratio clipping range.
        learning_rate: Adam step size.
        max_grad_norm: Global-norm gradient clipping threshold.

    Returns:
        The trained ``TrainState``.
    """
    dummy_state = env.reset()
    # The network only needs the feature size (last axis), not the whole shape.
    input_dim = int(np.shape(dummy_state)[-1])
    action_dim = env.action_dim()

    rng = jax.random.PRNGKey(seed)
    rng, actor_rng = jax.random.split(rng)
    actor_critic = ActorCritic(action_dim)
    # Initialise with the dedicated key so ``rng`` stays independent of it.
    state = create_train_state(actor_rng, actor_critic, input_dim, action_dim, learning_rate)

    for epoch in range(num_epochs):
        obs, rng = collect_trajectories(
            env, state, rng, steps_per_epoch, gamma
        )
        state = update_ppo(state, obs, batch_size, num_minibatches, clip_eps, max_grad_norm)

        if epoch % 20 == 0:
            rng, eval_rng = jax.random.split(rng)
            eval_returns, sample = evaluate_policy(env, state, eval_rng, 1)
            print(f"Eval return {eval_returns}, Sample {sample}")
    return state

## Evaluate

In [None]:
def evaluate_policy(env, state, rng, evals):
    """Run full episodes with the current policy and report the mean return.

    Args:
        env: Environment providing ``reset()``, ``step(action)`` and
            ``get_sample()``.
        state: ``TrainState`` whose ``apply_fn`` returns ``(logits, value)``.
        rng: JAX PRNG key; split before every action draw.
        evals: Number of episodes to run (must be at least 1).

    Returns:
        ``(mean_return, sample)`` where ``sample`` is ``env.get_sample()``
        taken after the last episode.

    Raises:
        ValueError: If ``evals`` is less than 1.
    """
    if evals < 1:
        raise ValueError("evals must be at least 1")

    returns = []
    sample = None
    for _ in range(evals):
        obs = env.reset()
        done = False
        episode_return = 0
        while not done:
            # A fresh key per step; reusing one key repeats the same draw.
            rng, action_rng = jax.random.split(rng)
            logits, _ = state.apply_fn(state.params, jnp.array([obs]))
            # sample_actions returns (action, log_probs); only the action
            # is sent to the environment.
            action, _ = sample_actions(action_rng, logits)
            obs, reward, done, _ = env.step(np.array(action))
            episode_return += reward
        sample = env.get_sample()
        returns.append(episode_return)
    return np.mean(np.array(returns)), sample


## Environment

In [None]:
class Environment:
    """Placeholder environment; no behaviour is implemented yet.

    The rest of this file calls ``reset()``, ``step(action)``,
    ``action_dim()`` and ``get_sample()`` on its environment, none of which
    are defined here.
    """

    def __init__(self) -> None:
        """Create the environment; there is no state to set up."""
        return None

## Main Runner

In [None]:
if __name__ == "__main__":
    # Script entry point: train on the placeholder environment with a
    # fixed set of hyperparameters.
    env = Environment()
    hyperparams = dict(
        seed=42,
        num_epochs=200,
        steps_per_epoch=2048,
        batch_size=64,
        num_minibatches=64,
        gamma=0.99,
        clip_eps=0.2,
        learning_rate=3e-4,
        max_grad_norm=0.5,
    )
    state = train(env, **hyperparams)

### Things to complete
- [x] discrete actions
- [x] advantage computation
- [ ] environment