In [1]:
from dataclasses import dataclass
from typing import Any, Union, Tuple, Callable, Optional
from functools import partial
import time

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

from qdax import environments, environments_v1
from jax import random

import pickle
from optax import exponential_decay
from IPython.display import HTML
from brax.io import html
import os
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
def get_env(env_name):
    """Create the QDax (v1) environment registered under ``env_name``.

    Every supported environment is created with an episode length of 1000
    steps; some additionally receive environment-specific keyword arguments.

    Args:
        env_name: One of ``"hopper_uni"``, ``"halfcheetah_uni"``,
            ``"walker2d_uni"``, ``"ant_uni"`` or ``"humanoid_uni"``.

    Returns:
        The object returned by ``environments_v1.create`` for that name.

    Raises:
        ValueError: If ``env_name`` is not one of the supported names.
    """
    episode_length = 1000

    # Extra keyword arguments forwarded to `environments_v1.create`, per env.
    # NOTE: "ant_omni" (episode_length=250, created through
    # `environments.create`) used to live here in a disabled string block and
    # is deliberately not supported.
    extra_kwargs = {
        "hopper_uni": {},
        "halfcheetah_uni": {},
        "walker2d_uni": {},
        "ant_uni": {
            "use_contact_forces": False,
            "exclude_current_positions_from_observation": True,
        },
        "humanoid_uni": {
            "exclude_current_positions_from_observation": True,
        },
    }

    if env_name not in extra_kwargs:
        # Without this guard the function fell through to `return env` and
        # failed with an UnboundLocalError instead of a meaningful message.
        raise ValueError(f"Environment {env_name} not supported.")

    return environments_v1.create(
        env_name,
        episode_length=episode_length,
        **extra_kwargs[env_name],
    )

In [3]:
# Small constant passed as `epsilon` to `jax.nn.standardize` in `MCPG.standardize`.
EPS = 1e-8
class MLP(nn.Module):
    """Gaussian MLP policy used by MCPG.

    A stack of dense layers maps an observation to the mean of a diagonal
    Gaussian over actions. The log standard deviation is a learned parameter
    of shape ``(action_size,)`` that does not depend on the observation and
    is initialised to ``log(0.5)``.
    """
    # Width of each hidden dense layer, in order.
    hidden_layers_size: Tuple[int, ...]
    # Dimensionality of the action vector (size of the mean head and log_std).
    action_size: int
    # Non-linearity applied after every hidden layer.
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
    # Initialisers for biases, hidden kernels and the mean-head kernel.
    bias_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.zeros
    hidden_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()
    mean_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()

    def setup(self):
        """Create the hidden layers, the mean head and the log-std parameter."""
        self.hidden_layers = [nn.Dense(features, kernel_init=self.hidden_init, bias_init=self.bias_init) for features in self.hidden_layers_size]
        self.mean = nn.Dense(self.action_size, kernel_init=self.mean_init, bias_init=self.bias_init)
        self.log_std = self.param("log_std", lambda _, shape: jnp.log(0.5)*jnp.ones(shape), (self.action_size,))

    def distribution_params(self, obs: jnp.ndarray):
        """Return ``(mean, log_std, std)`` of the action distribution for ``obs``."""
        hidden = obs
        for hidden_layer in self.hidden_layers:
            hidden = self.activation(hidden_layer(hidden))

        mean = self.mean(hidden)
        log_std = self.log_std
        std = jnp.exp(log_std)

        return mean, log_std, std

    def logp(self, obs: jnp.ndarray, action: jnp.ndarray):
        """Log-probability of ``action`` under the policy at ``obs``.

        The per-dimension Gaussian log-densities are summed over the last axis.
        """
        mean, _, std = self.distribution_params(obs)
        logp = jax.scipy.stats.norm.logpdf(action, mean, std)
        return logp.sum(axis=-1)

    def __call__(self, random_key: Any, obs: jnp.ndarray):
        """Sample an action for ``obs`` and return ``(action, logp)``.

        Args:
            random_key: JAX PRNG key used to draw the Gaussian noise.
            obs: A single observation or a batch of observations.

        Returns:
            The sampled action (gradients stopped) and its log-probability
            summed over the action dimension.
        """
        mean, _, std = self.distribution_params(obs)

        # Draw the noise with the shape of `mean` rather than a fixed
        # (action_size,): for a single observation this is the same shape as
        # before, while a batch of observations now gets independent noise per
        # row instead of one noise vector broadcast across the whole batch.
        noise = jax.random.normal(random_key, shape=mean.shape)
        action = jax.lax.stop_gradient(mean + noise * std)

        logp = jax.scipy.stats.norm.logpdf(action, mean, std).sum(axis=-1)

        return action, logp
    

In [6]:
@dataclass
class Config:
    """Hyper-parameters read by the MCPG trainer."""
    # Number of trajectories sampled in parallel per training step (MCPG.train_step).
    no_agents: int = 2
    # Number of flattened samples considered per epoch (MCPG.epoch_train).
    batch_size: int = 2048
    # Number of samples per gradient update (MCPG.epoch_train).
    mini_batch_size: int = 64
    # Number of epochs run over each batch of sampled trajectories (MCPG.train_step).
    no_epochs: int = 10
    # Learning rate handed to the Adam optimizer (MCPG.init).
    learning_rate: float = 3e-4
    # Discount factor used when accumulating returns (MCPG.get_return).
    discount_rate: float = 0.99
    # Half-width of the clipping interval around 1 for the PPO ratio (MCPG.loss_ppo).
    clip_param: float = 0.2
    # NOTE(review): not read by any code visible in this notebook section — confirm it is still needed.
    grad_steps: int = 10
    
class MCPG:
    """Monte-Carlo policy-gradient trainer assembled from jitted JAX routines.

    The trainer samples full trajectories with the current policy, turns the
    rewards into discounted returns, and updates the policy parameters with a
    clipped (PPO-style) surrogate loss over shuffled mini-batches.

    Every jitted method treats ``self`` as a static argument, so the instance
    must be hashable and should not be mutated after the first call.
    """

    def __init__(self, config, policy, env):
        """Store the configuration, the policy module and the environment."""
        self._config = config
        self._policy = policy
        self._env = env

    def init(self, random_key):
        """Initialise the policy parameters and optimizer.

        Args:
            random_key: JAX PRNG key; split into one key for parameter
                initialisation and one passed to the policy's sampling call.

        Returns:
            A ``TrainState`` holding the policy's apply function, the
            initial parameters and an Adam optimizer.
        """
        random_key_1, random_key_2 = jax.random.split(random_key)
        fake_obs = jnp.zeros(shape=(self._env.observation_size,))
        params = self._policy.init(random_key_1, random_key_2, fake_obs)
        tx = optax.adam(learning_rate=self._config.learning_rate)

        return TrainState.create(apply_fn=self._policy.apply, params=params, tx=tx)

    @partial(jax.jit, static_argnames=("self",))
    def logp_fn(self, params, obs, action):
        """Log-probability of ``action`` at ``obs`` under ``params``."""
        return self._policy.apply(params, obs, action, method=self._policy.logp)

    @partial(jax.jit, static_argnames=("self",))
    def sample_step(self, random_key, train_state, env_state):
        """Samples one step in the environment and returns the next state, action and
        log-prob of the action.
        """
        action, action_logp = train_state.apply_fn(train_state.params, random_key, env_state.obs)

        next_env_state = self._env.step(env_state, action)

        return next_env_state, action, action_logp

    @partial(jax.jit, static_argnames=("self",))
    def sample_trajectory(self, random_key, train_state):
        """Samples a full trajectory using the environment and policy.

        Returns:
            ``(obs, action, action_logp, reward, state_desc, mask)``, each with
            a leading axis of ``episode_length``. ``mask`` is 1 for steps taken
            before the environment reported ``done`` and 0 afterwards.
        """
        episode_length = self._env.episode_length
        # One key per step, plus a final one for the environment reset.
        random_keys = jax.random.split(random_key, episode_length + 1)
        env_state_init = self._env.reset(random_keys[-1])

        def _scan_sample_step(carry, x):
            (train_state, env_state,) = carry
            (random_key, ) = x

            next_env_state, action, action_logp = self.sample_step(random_key, train_state, env_state)
            return (train_state, next_env_state), (env_state.obs, action, action_logp, next_env_state.reward, env_state.done, env_state.info["state_descriptor"])

        _, (obs, action, action_logp, reward, done, state_desc) = jax.lax.scan(
            _scan_sample_step,
            (train_state, env_state_init),
            (random_keys[:episode_length],),
            length=episode_length,
            )

        # `done` is recorded from the state *before* each step, so the step on
        # which the episode terminates is still unmasked. The clip bounds are
        # passed positionally because the `a_min`/`a_max` keywords are
        # deprecated in recent JAX releases.
        mask = 1. - jnp.clip(jnp.cumsum(done), 0., 1.)

        return obs, action, action_logp, reward, state_desc, mask

    @partial(jax.jit, static_argnames=("self",))
    def get_return(self, reward):
        """Computes the discounted return for each step in the trajectory.

        Args:
            reward: Per-step rewards of one trajectory (leading axis is time).

        Returns:
            An array shaped like ``reward`` whose entry ``t`` is the discounted
            sum of rewards from step ``t`` onwards.
        """
        def _body(carry, x):
            (next_return,) = carry
            (reward,) = x

            current_return = reward + self._config.discount_rate * next_return
            return (current_return,), (current_return,)

        # The scan length is inferred from `reward`, so trajectories of any
        # length are accepted.
        _, (return_,) = jax.lax.scan(
            _body,
            (jnp.array(0.),),
            (reward,),
            reverse=True,
            )

        return return_

    @partial(jax.jit, static_argnames=("self",))
    def standardize(self, return_):
        """Apply ``jax.nn.standardize`` to the returns along axis 0.

        NOTE(review): ``variance=1`` is passed explicitly, so this appears to
        only centre the returns along axis 0 rather than also rescaling them
        by their spread — confirm this is intended.
        """
        return jax.nn.standardize(return_, axis=0, variance=1, epsilon=EPS)

    @partial(jax.jit, static_argnames=("self",))
    def get_return_standardize(self, reward, mask):
        """Standardizes the return values for stability in training.

        Rewards are masked before the discounted returns are computed per
        trajectory (``get_return`` is vmapped over the leading axis).
        """
        return_ = jax.vmap(self.get_return)(reward * mask)
        return self.standardize(return_)

    @partial(jax.jit, static_argnames=("self",))
    def loss_rein(self, params, obs, action, mask, return_standardized):
        """REINFORCE loss: negative mean of masked log-prob times return."""
        logp_ = self.logp_fn(params, jax.lax.stop_gradient(obs), jax.lax.stop_gradient(action))
        return -jnp.mean(jnp.multiply(logp_ * mask, jax.lax.stop_gradient(return_standardized)))

    @partial(jax.jit, static_argnames=("self",))
    def loss_ppo(self, params, obs, action, logp, mask, return_standardized):
        """PPO clipped-surrogate loss.

        Args:
            params: Policy parameters being optimised.
            obs, action: Sampled observations and actions (gradients stopped).
            logp: Log-probabilities of ``action`` under the sampling policy.
            mask: 1 for valid steps, 0 for steps after episode termination.
            return_standardized: Standardized returns used as the advantage.

        Returns:
            Scalar loss; masked steps contribute zero to the mean.
        """
        logp_ = self.logp_fn(params, jax.lax.stop_gradient(obs), jax.lax.stop_gradient(action))
        ratio = jnp.exp(logp_ - jax.lax.stop_gradient(logp))

        pg_loss_1 = jnp.multiply(ratio * mask, jax.lax.stop_gradient(return_standardized))
        pg_loss_2 = jax.lax.stop_gradient(return_standardized) * jax.lax.clamp(1. - self._config.clip_param, ratio, 1. + self._config.clip_param) * mask
        return -jnp.mean(jnp.minimum(pg_loss_1, pg_loss_2))

    @partial(jax.jit, static_argnames=("self",))
    def flatten_trajectory(self, obs, action, logp, mask, return_standardized):
        """Merge the first two axes of every input into a single sample axis.

        ``obs`` and ``action`` keep their trailing feature axes; ``logp``,
        ``mask`` and ``return_standardized`` become one-dimensional.
        """
        # Calculate the total number of elements in the combined first two dimensions
        total_elements = obs.shape[0] * obs.shape[1]

        # Flatten the first two dimensions. The trailing dimensions must be
        # unpacked into the target shape: nesting the `shape[2:]` tuple inside
        # it is not a valid shape.
        obs = jnp.reshape(obs, (total_elements, *obs.shape[2:]))
        action = jnp.reshape(action, (total_elements, *action.shape[2:]))
        logp = jnp.reshape(logp, (total_elements,))
        mask = jnp.reshape(mask, (total_elements,))
        return_standardized = jnp.reshape(return_standardized, (total_elements,))

        return obs, action, logp, mask, return_standardized

    @partial(jax.jit, static_argnames=("self",))
    def train_step(self, random_key, train_state):
        """Sample trajectories and run ``no_epochs`` epochs of updates on them.

        Returns:
            ``((train_state,), (metrics,))`` where ``metrics`` holds the
            mini-batch losses of every epoch (``"loss"``), the masked rewards
            (``"reward"``) and the mask itself (``"mask"``).
        """
        no_agents = self._config.no_agents
        no_epochs = self._config.no_epochs

        # Sample trajectories: one key per agent, the last key drives shuffling.
        random_keys = jax.random.split(random_key, no_agents + 1)
        obs, action, logp, reward, _, mask = jax.vmap(self.sample_trajectory, in_axes=(0, None))(random_keys[:no_agents], train_state)

        # Compute standardized return
        return_standardized = self.get_return_standardize(reward, mask)

        obs_, action_, logp_, mask_, return_standardized_ = self.flatten_trajectory(obs, action, logp, mask, return_standardized)

        # A distinct key per epoch so each epoch sees a different shuffle.
        epoch_keys = jax.random.split(random_keys[-1], no_epochs)

        def _scan_epoch_train(carry, x):
            (train_state,) = carry
            (epoch_key,) = x

            (train_state,), (losses,) = self.epoch_train(epoch_key, train_state, obs_, action_, logp_, mask_, return_standardized_)

            return (train_state,), (losses,)

        # Unpack the carry and outputs so the returned structure matches what
        # `train` expects as its own scan carry.
        (final_train_state,), (losses,) = jax.lax.scan(
            _scan_epoch_train,
            (train_state,),
            (epoch_keys,),
            length=no_epochs,
            )

        metrics = {
            "loss" : losses,
            "reward" : reward * mask,
            "mask" : mask
        }

        return (final_train_state,), (metrics,)

    @partial(jax.jit, static_argnames=("self",))
    def epoch_train(self, random_key, train_state, obs, action, logp, mask, return_standardized):
        """Run one epoch of shuffled mini-batch PPO updates.

        The flattened samples are shuffled, at most ``config.batch_size`` of
        them are kept, and they are consumed in mini-batches of
        ``config.mini_batch_size`` (a trailing partial mini-batch is dropped).

        Returns:
            ``((train_state,), (losses,))`` with one loss per mini-batch.
        """
        mini_batch_size = self._config.mini_batch_size
        num_samples = obs.shape[0]
        # Never index past the data that was actually collected.
        batch_size = min(self._config.batch_size, num_samples)
        num_mini_batches = batch_size // mini_batch_size

        # Pre-compute every mini-batch's indices as a (num_mini_batches,
        # mini_batch_size) matrix and scan over its rows; slicing with a traced
        # counter inside the scan is not possible under jit.
        permutation = jax.random.permutation(random_key, num_samples)
        mini_batch_inds = jnp.reshape(
            permutation[: num_mini_batches * mini_batch_size],
            (num_mini_batches, mini_batch_size),
        )

        def _scan_mini_train(carry, idx):
            (train_state,) = carry

            loss, grad = jax.value_and_grad(self.loss_ppo)(train_state.params, obs[idx], action[idx], logp[idx], mask[idx], return_standardized[idx])
            new_train_state = train_state.apply_gradients(grads=grad)
            return (new_train_state,), loss

        (final_train_state,), losses = jax.lax.scan(
            _scan_mini_train,
            (train_state,),
            mini_batch_inds,
            length=num_mini_batches,
            )

        return (final_train_state,), (losses,)

    @partial(jax.jit, static_argnames=("self", "no_steps"))
    def train(self, random_key, train_state, no_steps):
        """Trains the policy for a number of steps.

        Args:
            random_key: JAX PRNG key, split into one key per training step.
            train_state: Training state to start from.
            no_steps: Number of ``train_step`` calls to run (static under jit).

        Returns:
            ``((train_state,), (metrics,))`` where every entry of ``metrics``
            is stacked along a leading axis of length ``no_steps``.
        """
        random_keys = jax.random.split(random_key, no_steps)

        def _scan_train_step(carry, x):
            (train_state,) = carry
            (random_key,) = x

            (train_state,), (metrics,) = self.train_step(random_key, train_state)

            return (train_state,), (metrics,)

        (train_state,), (metrics,) = jax.lax.scan(
            _scan_train_step,
            (train_state,),
            (random_keys,),
            length=no_steps,
            )

        return (train_state,), (metrics,)
    
