In [None]:
from dataclasses import dataclass
from typing import Any, Union, Tuple, Callable, Optional
from functools import partial
import time

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

from qdax import environments, environments_v1

import pickle
from optax import exponential_decay
from IPython.display import HTML
from brax.io import html
import os
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import functools
from typing import Any, Callable

from absl import logging
import flax
from flax import linen as nn
import agent
import models
import test_episodes
from flax.metrics import tensorboard
from flax.training import checkpoints
from flax.training import train_state
import jax
import jax.numpy as jnp
import ml_collections
import numpy as np
import optax

In [None]:
EPS = 1e-8  # NOTE(review): not referenced in this cell; presumably used elsewhere in the notebook — verify or remove.


def _diag_gaussian_logp(action: jnp.ndarray, mean: jnp.ndarray, std: jnp.ndarray) -> jnp.ndarray:
    """Log-density of ``action`` under N(mean, diag(std**2)), summed over the last axis."""
    return jax.scipy.stats.norm.logpdf(action, mean, std).sum(axis=-1)


class MLP(nn.Module):
    """Gaussian MLP policy module used by MCPG.

    A stack of ``Dense`` layers (each followed by ``activation``) maps an
    observation to the action mean. The log standard deviation is a free,
    state-independent parameter named ``log_std`` of shape ``(action_size,)``.

    Attributes:
        hidden_layers_size: width of each hidden ``Dense`` layer, in order.
        action_size: dimensionality of the action / mean head.
        activation: non-linearity applied after every hidden layer.
        bias_init: bias initializer for all ``Dense`` layers.
        hidden_init: kernel initializer for the hidden layers.
        mean_init: kernel initializer for the mean head.
        init_std: initial standard deviation; ``log_std`` is initialised to
            ``log(init_std)`` (default 0.5, the previously hard-coded value).
    """
    hidden_layers_size: Tuple[int, ...]
    action_size: int
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
    bias_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.zeros
    hidden_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()
    mean_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()
    init_std: float = 0.5

    def setup(self):
        self.hidden_layers = [
            nn.Dense(features, kernel_init=self.hidden_init, bias_init=self.bias_init)
            for features in self.hidden_layers_size
        ]
        self.mean = nn.Dense(self.action_size, kernel_init=self.mean_init, bias_init=self.bias_init)
        # Parameter name "log_std" is kept so existing parameter trees stay compatible.
        self.log_std = self.param(
            "log_std",
            lambda _, shape: jnp.log(self.init_std) * jnp.ones(shape),
            (self.action_size,),
        )

    def distribution_params(self, obs: jnp.ndarray):
        """Return ``(mean, log_std, std)`` of the action distribution for ``obs``.

        ``mean`` has the shape of ``obs`` with the last axis replaced by
        ``action_size``; ``log_std``/``std`` have shape ``(action_size,)`` and
        broadcast against ``mean``.
        """
        hidden = obs
        for hidden_layer in self.hidden_layers:
            hidden = self.activation(hidden_layer(hidden))

        mean = self.mean(hidden)
        log_std = self.log_std
        std = jnp.exp(log_std)

        return mean, log_std, std

    def logp(self, obs: jnp.ndarray, action: jnp.ndarray):
        """Log-probability of ``action`` given ``obs``, summed over action dimensions."""
        mean, _, std = self.distribution_params(obs)
        return _diag_gaussian_logp(action, mean, std)

    def __call__(self, random_key: Any, obs: jnp.ndarray):
        """Sample an action from N(mean, std**2) and return it with its log-prob.

        Args:
            random_key: PRNG key for the sampling noise.
            obs: observation(s); features on the last axis, leading batch axes allowed.

        Returns:
            ``(action, logp)``. ``action`` has the shape of the mean and is wrapped
            in ``stop_gradient``; ``logp`` is summed over the last axis and stays
            differentiable w.r.t. the parameters through ``mean`` and ``std``.
        """
        mean, _, std = self.distribution_params(obs)

        # Noise matches the mean's shape so each row of a batched `obs` gets its own
        # sample (for a single observation this is the same `(action_size,)` draw).
        noise = jax.random.normal(random_key, shape=mean.shape)
        action = jax.lax.stop_gradient(mean + noise * std)

        logp = _diag_gaussian_logp(action, mean, std)
        return action, logp
    

In [None]:
@dataclass
class Config:
    """Hyper-parameters for the MCPG agent.

    Field order matters: the generated constructor accepts the fields
    positionally in the order declared below.
    """

    # Batch size (not read by the code in this cell — presumably used by the update step).
    batch_size: int = 64
    # Adam learning rate, read by ``MCPG.init``.
    learning_rate: float = 3e-4
    # Discount factor applied per step by ``MCPG.get_return``.
    discount_rate: float = 0.99
    # NOTE(review): presumably a PPO-style ratio-clipping epsilon; unused in the visible code — confirm.
    clip_param: float = 0.2
    # NOTE(review): presumably gradient steps per update; unused in the visible code — confirm.
    grad_steps: int = 10
    
class MCPG:
    """MCPG agent: owns a config, a Flax policy module and an environment.

    Provides parameter/optimiser initialisation plus jitted helpers to sample a
    trajectory and compute discounted returns.

    The jitted methods mark ``self`` as static (argument 0). It is hashed by
    identity, so the config, policy and env are baked into each compiled
    function; mutating them after the first call will NOT trigger recompilation.
    """

    def __init__(self, config, policy, env):
        """Store the collaborators.

        Args:
            config: hyper-parameters; this class reads ``learning_rate`` and
                ``discount_rate`` (e.g. a ``Config``).
            policy: Flax module whose ``__call__(random_key, obs)`` returns
                ``(action, logp)`` and which has a ``logp(obs, action)`` method
                (e.g. ``MLP``).
            env: environment exposing ``observation_size``, ``episode_length``,
                ``reset(key)`` and ``step(state, action)``; its states must expose
                ``obs``, ``reward``, ``done`` and ``info["state_descriptor"]``.
        """
        self._config = config
        self._policy = policy
        self._env = env

    def init(self, random_key):
        """Initialise policy parameters and an Adam optimiser.

        Returns:
            A ``TrainState`` whose ``apply_fn`` is the policy's ``apply``.
        """
        params_key, sample_key = jax.random.split(random_key)
        fake_obs = jnp.zeros(shape=(self._env.observation_size,))
        # `sample_key` is forwarded to the policy's __call__ (random_key, obs).
        params = self._policy.init(params_key, sample_key, fake_obs)
        tx = optax.adam(learning_rate=self._config.learning_rate)

        return TrainState.create(apply_fn=self._policy.apply, params=params, tx=tx)

    # static_argnums must be integer positions: `self` is argument 0.
    @partial(jax.jit, static_argnums=(0,))
    def logp_fn(self, params, obs, action):
        """Log-probability of ``action`` given ``obs`` under the policy with ``params``."""
        return self._policy.apply(params, obs, action, method=self._policy.logp)

    @partial(jax.jit, static_argnums=(0,))
    def sample_step(self, random_key, train_state, env_state):
        """Sample an action for ``env_state`` and advance the environment one step.

        Returns:
            ``(next_env_state, action, action_logp)``.
        """
        action, action_logp = train_state.apply_fn(train_state.params, random_key, env_state.obs)

        next_env_state = self._env.step(env_state, action)

        return next_env_state, action, action_logp

    @partial(jax.jit, static_argnums=(0,))
    def sample_trajectory(self, random_key, train_state):
        """Roll out ``episode_length`` steps from a freshly reset environment.

        Returns:
            ``(obs, action, action_logp, reward, state_desc, mask)``, each stacked
            along a leading time axis of length ``episode_length``:

            - ``obs`` and ``state_desc`` are taken from the state *before* each step.
            - ``reward[t]`` is the reward of the state produced by step ``t``.
            - ``mask[t]`` is 0 once a ``done`` flag has been seen on any state up to
              and including the one at step ``t``, and 1 before that (assuming
              ``done`` flags are 0/1).
        """
        episode_length = self._env.episode_length
        keys = jax.random.split(random_key, episode_length + 1)
        # The last key resets the env; the first `episode_length` keys drive the steps.
        env_state_init = self._env.reset(keys[-1])

        def _scan_sample_step(env_state, step_key):
            # train_state is constant over the rollout, so it is closed over, not carried.
            next_env_state, action, action_logp = self.sample_step(step_key, train_state, env_state)
            transition = (
                env_state.obs,
                action,
                action_logp,
                next_env_state.reward,
                env_state.done,
                env_state.info["state_descriptor"],
            )
            return next_env_state, transition

        _, (obs, action, action_logp, reward, done, state_desc) = jax.lax.scan(
            _scan_sample_step,
            env_state_init,
            keys[:episode_length],
            length=episode_length,
        )

        # Positional bounds work with both the old (a_min/a_max) and new (min/max) jnp.clip.
        mask = 1.0 - jnp.clip(jnp.cumsum(done), 0.0, 1.0)

        return obs, action, action_logp, reward, state_desc, mask

    @partial(jax.jit, static_argnums=(0,))
    def get_return(self, reward):
        """Compute the discounted return-to-go for every step of a trajectory.

        ``return_[t] = reward[t] + discount_rate * return_[t + 1]``, with the
        return after the last step taken to be 0.

        Args:
            reward: rewards stacked along a leading time axis. The length is taken
                from ``reward`` itself. Trailing axes (e.g. a batch of
                trajectories) are handled element-wise.

        Returns:
            Array of discounted returns with the same shape as ``reward``.

        Note:
            Rewards are not masked here; steps after termination still contribute
            unless the caller has already masked ``reward``.
        """
        reward = jnp.asarray(reward)
        discount = self._config.discount_rate

        def _body(next_return, reward_t):
            current_return = reward_t + discount * next_return
            return current_return, current_return

        # The zero carry matches reward's trailing shape so batched rewards scan correctly.
        _, return_ = jax.lax.scan(
            _body,
            jnp.zeros(reward.shape[1:]),
            reward,
            reverse=True,
        )

        return return_
    
    