In [1]:
!nvidia-smi

Mon Jul 29 13:58:27 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   28C    P0              25W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
from dataclasses import dataclass
from typing import Any, Union, Tuple, Callable, Optional
from functools import partial
import time

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

from qdax import environments, environments_v1
from jax import random
import wandb

import pickle
from optax import exponential_decay
from IPython.display import HTML
from brax.io import html
import os
import jax.debug
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
def get_env(env_name):
    """Create one of the supported uni-directional QDax (v1) environments.

    Every supported environment is created with an episode length of 1000.

    Args:
        env_name: One of "hopper_uni", "halfcheetah_uni", "walker2d_uni",
            "ant_uni" or "humanoid_uni".

    Returns:
        The environment returned by ``environments_v1.create``.

    Raises:
        ValueError: If ``env_name`` is not one of the supported names.
    """
    episode_length = 1000

    # Extra keyword arguments forwarded to `environments_v1.create`, per environment.
    extra_kwargs_by_env = {
        "hopper_uni": {},
        "halfcheetah_uni": {},
        "walker2d_uni": {},
        "ant_uni": {
            "use_contact_forces": False,
            "exclude_current_positions_from_observation": True,
        },
        "humanoid_uni": {
            "exclude_current_positions_from_observation": True,
        },
    }

    # Fail loudly on unknown names instead of falling through to an unbound `env`.
    if env_name not in extra_kwargs_by_env:
        raise ValueError(f"Environment {env_name} not supported.")

    return environments_v1.create(
        env_name,
        episode_length=episode_length,
        **extra_kwargs_by_env[env_name],
    )

In [4]:
# Small positive constant: added to the running standard deviation before dividing
# rewards by it, and passed as `epsilon` to `jax.nn.standardize`.
EPS = 1e-8


class ValueNet(nn.Module):
    """Value-function network: a stack of dense layers with a single-unit output head.

    Fields:
        hidden_layers_size: Width of each hidden dense layer, in order.
        activation: Non-linearity applied after every hidden layer.
        bias_init: Initializer for all biases.
        hidden_init: Kernel initializer for the hidden layers.
        value_init: Kernel initializer for the output head.
    """
    hidden_layers_size: Tuple[int, ...]
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
    bias_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.zeros
    hidden_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()
    value_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()

    def setup(self):
        """Instantiate the hidden layers and the one-unit value head."""
        self.hidden_layers = [
            nn.Dense(width, kernel_init=self.hidden_init, bias_init=self.bias_init)
            for width in self.hidden_layers_size
        ]
        self.value = nn.Dense(1, kernel_init=self.value_init, bias_init=self.bias_init)

    def __call__(self, obs: jnp.ndarray):
        """Map an observation to the output of the one-unit head (trailing axis of size 1)."""
        features = obs
        for layer in self.hidden_layers:
            features = self.activation(layer(features))
        return self.value(features)

class MLP(nn.Module):
    """Gaussian policy network with a state-independent, learned log standard deviation.

    The dense stack produces the mean of a diagonal Gaussian over actions; the
    log-std is a free parameter vector of length ``action_size`` initialised to
    ``log(0.5)``.

    Fields:
        hidden_layers_size: Width of each hidden dense layer, in order.
        action_size: Dimensionality of the action vector.
        activation: Non-linearity applied after every hidden layer.
        bias_init: Initializer for all biases.
        hidden_init: Kernel initializer for the hidden layers.
        mean_init: Kernel initializer for the mean head.
    """
    hidden_layers_size: Tuple[int, ...]
    action_size: int
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
    bias_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.zeros
    hidden_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()
    mean_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()

    def setup(self):
        """Create the hidden layers, the mean head and the `log_std` parameter."""
        self.hidden_layers = [nn.Dense(features, kernel_init=self.hidden_init, bias_init=self.bias_init) for features in self.hidden_layers_size]
        self.mean = nn.Dense(self.action_size, kernel_init=self.mean_init, bias_init=self.bias_init)
        self.log_std = self.param("log_std", lambda _, shape: jnp.log(0.5)*jnp.ones(shape), (self.action_size,))

    def distribution_params(self, obs: jnp.ndarray):
        """Return ``(mean, log_std, std)`` of the action distribution for `obs`."""
        hidden = obs
        for hidden_layer in self.hidden_layers:
            hidden = self.activation(hidden_layer(hidden))

        mean = self.mean(hidden)
        log_std = self.log_std
        std = jnp.exp(log_std)

        return mean, log_std, std

    def logp(self, obs: jnp.ndarray, action: jnp.ndarray):
        """Log-density of `action` under the policy at `obs`, summed over the last axis."""
        mean, _, std = self.distribution_params(obs)
        logp = jax.scipy.stats.norm.logpdf(action, mean, std)
        return logp.sum(axis=-1)

    def __call__(self, random_key: Any, obs: jnp.ndarray):
        """Sample an action for `obs` and return it with its log-probability.

        Args:
            random_key: JAX PRNG key used to draw the Gaussian noise.
            obs: Observation (optionally with leading batch axes).

        Returns:
            ``(action, logp)``; `action` is wrapped in `stop_gradient`, `logp` is
            the log-density of that action summed over the last axis.
        """
        mean, _, std = self.distribution_params(obs)

        # Draw noise with the shape of `mean` so that batched observations get
        # independent noise per row; for a single observation this is exactly
        # `(action_size,)`, as before.
        rnd = jax.random.normal(random_key, shape=mean.shape)
        action = jax.lax.stop_gradient(mean + rnd * std)

        logp = jnp.sum(jax.scipy.stats.norm.logpdf(action, mean, std), axis=-1)

        return action, logp
    

In [5]:
@dataclass
class Config:
    """Hyper-parameters read by `MCPG` through `self._config`."""
    # Number of trajectories sampled in parallel (vmapped) per training step.
    no_agents: int = 2048
    # Number of transitions per update; the rollout length is derived as
    # batch_size / no_agents, and minibatch indices are permuted over this size.
    batch_size: int = 2048 * 10
    # Number of transitions in each gradient minibatch.
    mini_batch_size: int = 640
    # Number of passes over the collected batch per training step.
    no_epochs: int = 4
    # Learning rate passed to `optax.adam`.
    learning_rate: float = 3e-4
    # Discount factor used for returns and advantage estimation.
    discount_rate: float = 0.99
    # Ratio clipping range: the ratio is clamped to [1 - clip_param, 1 + clip_param].
    clip_param: float = 0.2
    # Weight of the value loss in the total loss.
    vf_coef: float = 0.5
    # Lambda coefficient of the advantage recursion in `compute_gae_and_returns`.
    gae_lambda: float = 0.95
    # Environment name; not read by `MCPG` in the code visible here.
    env_name: str = "walker2d_uni"
    
class MCPG:
    
    def __init__(self, config, policy, val_net, env):
        self._config = config
        self._policy = policy
        self._value = val_net
        self._env = env
        
    def init(self, random_key):
        """Initialise parameters and optimizer state for the policy and value networks.

        Args:
            random_key: JAX PRNG key; split three ways (policy init, policy
                sampling during init, value init).

        Returns:
            ``(train_state_policy, train_state_value)``, both built with an Adam
            optimizer using the configured learning rate and ``eps=1e-5``.
        """
        policy_init_key, policy_sample_key, value_init_key = jax.random.split(random_key, 3)

        # A zero observation is enough to trace the networks' shapes.
        dummy_obs = jnp.zeros(shape=(self._env.observation_size,))
        optimizer = optax.adam(learning_rate=self._config.learning_rate, eps=1e-5)

        policy_state = TrainState.create(
            apply_fn=self._policy.apply,
            params=self._policy.init(policy_init_key, policy_sample_key, dummy_obs),
            tx=optimizer,
        )
        value_state = TrainState.create(
            apply_fn=self._value.apply,
            params=self._value.init(value_init_key, dummy_obs),
            tx=optimizer,
        )

        return policy_state, value_state
    
    @partial(jax.jit, static_argnames=("self",))
    def logp_fn(self, params, obs, action):
        return self._policy.apply(params, obs, action, method=self._policy.logp)
    
    
    @partial(jax.jit, static_argnames=("self",))
    def sample_step(self, random_key, train_state_policy, train_state_value, env_state):
        """Samples one step in the environment and returns the next state, action and 
        log-prob of the action.
        """
            # Reset the environment if done is True
            
        #key, subkey = jax.random.split(random_key)
        
        #key, subkey = jax.random.split(random_key)

        # Use jax.lax.cond to conditionally print

        


        action, action_logp = train_state_policy.apply_fn(train_state_policy.params, random_key, env_state.obs)
        value = train_state_value.apply_fn(train_state_value.params, env_state.obs)
        
        next_env_state = self._env.step(env_state, action)
        
        return next_env_state, action, action_logp, value
    

    
    @partial(jax.jit, static_argnames=("self","evaluate"))
    def sample_trajectory(self, random_key, train_state_policy, train_state_value, evaluate=False):
        """Roll the policy out in the environment and collect per-step data.

        The rollout has ``self._env.episode_length`` steps when `evaluate` is true,
        otherwise ``int(batch_size / no_agents) + 1`` steps. The final collected
        step is not returned as a transition; only its value estimate and its
        "done" indicator are returned, for bootstrapping.

        Rewards are additionally returned in a scaled form: each reward is divided
        by a running standard deviation of the quantity
        ``R_t = reward_t + discount_rate * R_{t-1}``. The running statistics are
        re-initialised on every call with mean 0, sum of squared deviations 0 and
        count 1.

        Args:
            random_key: PRNG key; split into one key per step plus one for the reset.
            train_state_policy: Policy train state used for action sampling.
            train_state_value: Value train state used for value estimates.
            evaluate: Static flag selecting the rollout length (see above).

        Returns:
            A 10-tuple ``(obs, action, action_logp, values, normalized_reward,
            actual_reward, state_desc, mask, next_value, next_done)``. The first
            eight entries exclude the last collected step. `mask` is 1 until the
            first recorded `done` and 0 from that step on; `next_value` is the value
            output at the last collected step and `next_done` is ``1 - mask`` there.
        """
        if evaluate:
            length = self._env.episode_length
        else:
            length = int(self._config.batch_size / self._config.no_agents) + 1

        random_keys = jax.random.split(random_key, length+1)
        env_state_init = self._env.reset(random_keys[-1])

        # Initial running statistics of the discounted reward accumulator.
        mean = 0.0
        SSD = 0.0
        count = 1
        prev_R = 0

        def _scan_sample_step(carry, x):
            (train_state_policy, train_state_value, env_state, prev_R, mean, SSD, count) = carry
            (random_key,) = x

            next_env_state, action, action_logp, value = self.sample_step(random_key, train_state_policy, train_state_value, env_state)

            # Forward discounted accumulator whose spread sets the reward scale.
            R = next_env_state.reward + self._config.discount_rate * prev_R
            prev_R = R

            # Incremental (Welford-style) update of the mean and the sum of squared deviations.
            new_mean = mean + (R - mean) / (count + 1)
            new_SSD = SSD + (R - mean) * (R - new_mean)

            std_dev = jnp.sqrt(new_SSD / (count + 1))

            normalized_reward = next_env_state.reward / (std_dev + EPS)
            actual_reward = next_env_state.reward

            # obs / done / state descriptor come from the state *before* the step,
            # the rewards from the state *after* it.
            return (train_state_policy, train_state_value, next_env_state, prev_R, new_mean, new_SSD, count+1), (env_state.obs, action, action_logp, value, normalized_reward, actual_reward, env_state.done, env_state.info["state_descriptor"])

        _, (obs, action, action_logp, values, normalized_reward, actual_reward, done, state_desc) = jax.lax.scan(
            _scan_sample_step,
            (train_state_policy, train_state_value, env_state_init, prev_R, mean, SSD, count),
            (random_keys[:length],),
            length=length,
            )

        # 1 while no `done` has been recorded yet, 0 afterwards. Bounds are passed
        # positionally because the `a_min`/`a_max` keywords are deprecated in
        # recent JAX releases, while the positional form works on old and new ones.
        mask = 1. - jnp.clip(jnp.cumsum(done), 0., 1.)

        return obs[:-1], action[:-1], action_logp[:-1], values[:-1], normalized_reward[:-1], actual_reward[:-1], state_desc[:-1], mask[:-1], values[-1], 1. - mask[-1]
        
        
    
    @partial(jax.jit, static_argnames=("self",))
    def get_return(self, reward):
        """ Computes the discounted return for each step in the trajectory.
        """
        
        def _body(carry, x):
            (next_return,) = carry
            (reward,) = x
            current_return = reward + self._config.discount_rate * next_return
            return (current_return,), (current_return,)
        
        _, (return_,) = jax.lax.scan(
            _body,
            (jnp.array(0.),),
            (reward,),
            length=int(self._config.batch_size / self._config.no_agents),
            reverse=True,
            )
            
        return return_
    
    @partial(jax.jit, static_argnames=("self",))
    def standardize(self, return_):
        """Apply `jax.nn.standardize` along axis 0 with the variance supplied as 1.

        Because the variance is passed explicitly instead of being estimated, the
        call centres `return_` along axis 0 without rescaling it by its own spread.
        """
        centred = jax.nn.standardize(return_, axis=0, epsilon=EPS, variance=1)
        return centred
    
    @partial(jax.jit, static_argnames=("self",))
    def get_return_standardize(self, reward, mask):
        """Compute masked discounted returns per trajectory, then standardize them.

        `reward` and `mask` are multiplied element-wise, `get_return` is vmapped
        over the leading axis, and the result is passed to `standardize`.
        """
        masked_reward = reward * mask
        per_trajectory_return = jax.vmap(self.get_return)(masked_reward)
        return self.standardize(per_trajectory_return)
    
    @partial(jax.jit, static_argnames=("self",))
    def loss_rein(self, params, obs, action, mask, return_standardized):
        """ REINFORCE loss function.
        """
        logp_ = self.logp_fn(params, jax.lax.stop_gradient(obs), jax.lax.stop_gradient(action))
        return -jnp.mean(jnp.multiply(logp_ * mask, jax.lax.stop_gradient(return_standardized)))
    
    @partial(jax.jit, static_argnames=("self",))
    def pg_loss(self, params, obs, action, logp, mask, advantages):
        """ PPO loss function.
        """
        
        logp_ = self.logp_fn(params, jax.lax.stop_gradient(obs), jax.lax.stop_gradient(action))
        ratio = jnp.exp(logp_ - jax.lax.stop_gradient(logp))
        
        pg_loss_1 = jnp.multiply(ratio, jax.lax.stop_gradient(advantages))
        pg_loss_2 = jax.lax.stop_gradient(advantages) * jax.lax.clamp(1. - self._config.clip_param, ratio, 1. + self._config.clip_param)
        
        # change the normalizer constant later
        return jnp.mean(jnp.maximum(-pg_loss_1, -pg_loss_2))
    
    
    # let's do the unclipped version first
    @partial(jax.jit, static_argnames=("self",))
    def value_loss(self, params, train_state_value, obs, returns):
        value = train_state_value.apply_fn(params, obs)
        return 0.5 * jnp.mean((value - returns) ** 2)
    
    @partial(jax.jit, static_argnames=("self",))
    def total_loss(self, policy_params, value_params, obs, action, logp, mask, returns, advantages, train_state_value):
        """ Total loss function.
        """
        return self.pg_loss(policy_params, obs, action, logp, mask, advantages) + self.value_loss(value_params, train_state_value, obs, returns) * self._config.vf_coef

    

    @partial(jax.jit, static_argnames=("self",))
    def flatten_trajectory(self, obs, action, logp, mask, returns, advantages):
        # Calculate the total number of elements in the combined first two dimensions
        total_elements = obs.shape[0] * obs.shape[1]
        
        new_obs_shape = (total_elements,) + obs.shape[2:]  # obs.shape[2:] should be unpacked if it's a tuple
        new_action_shape = (total_elements,) + action.shape[2:]  # Same handling as for obs
            
        # Flatten the first two dimensions
        obs = jnp.reshape(obs, new_obs_shape)
        action = jnp.reshape(action, new_action_shape)
        logp = jnp.reshape(logp, (total_elements,))
        mask = jnp.reshape(mask, (total_elements,))
        #return_standardized = jnp.reshape(return_standardized, (total_elements,))
        returns = jnp.reshape(returns, (total_elements,))
        advantages = jnp.reshape(advantages, (total_elements,))
        
        print(f"Shape of obs: {obs.shape}")
        print(f"Shape of action: {action.shape}")
        print(f"Shape of logp: {logp.shape}")
        print(f"Shape of mask: {mask.shape}")
        print(f"Shape of returns: {returns.shape}")
        print(f"Shape of advantages: {advantages.shape}")
        
                
        return obs, action, logp, mask, returns, advantages
    
    '''
    @partial(jax.jit, static_argnames=("self",))
    def compute_gae_and_returns(self, rewards, values, masks, next_value):
        #print(values.shape)
        #print()
        #print(masks.shape)
        #print()
        #print(next_value.shape)
        #print()
        #print(jnp.append(values[1:], next_value).shape)
        #print()
        #print(jnp.append(masks[1:], 1.).shape)
        
        next_value = next_value.reshape((next_value.shape[0], 1))
        values_added = jnp.concatenate((values, next_value), axis=1)
        mask_added = jnp.concatenate((masks, jnp.ones((masks.shape[0], 1))), axis=1)
        print(f"Values added: {values_added.shape}")
        print(f"Mask added: {mask_added.shape}")
        values_next = values_added * mask_added
        deltas = rewards + self._config.discount_rate * values_next[:, 1:] - values
        
        def gae_scan_fn(carry, delta_mask):
            gae, _ = carry
            delta, mask = delta_mask
            gae = delta + self._config.discount_rate * self._config.gae_lambda * mask * gae
            
            return (gae, mask), gae
        
        
        
        last_gae = deltas[-1]
        all_but_last = jnp.stack([deltas, masks], axis=-1)
        print(f"all_but_last: {all_but_last.shape}")
        
        final_advantages, _ = jax.lax.scan(
            gae_scan_fn,
            (last_gae, masks[-1]),
            all_but_last,
            reverse=True
        )
        
        advantages = jnp.append(final_advantages, last_gae)
        
        returns = advantages + values
        
        return advantages, returns
        
    '''
        
    @partial(jax.jit, static_argnames=("self",))
    def compute_gae_and_returns(self, rewards, values, masks, next_value, next_done):
        # Ensure next_value is properly shaped to be concatenated
        next_value = next_value.reshape((1,))  # Assuming next_value is a single scalar
        next_done = next_done.reshape((1,))

        # Extend values with next_value at the end for correct future value alignment
        next_values = jnp.concatenate((values[1:], next_value), axis=0)
        next_masks = jnp.concatenate((masks[1:], 1. - next_done), axis=0)
        
        # Append 1 to masks at the end to handle terminal states correctly
        #masks_extended = jnp.concatenate((masks, jnp.ones((masks.shape[0], 1))), axis=1)
        
        # Calculate deltas using the extended values and masks
        #deltas = rewards + self._config.discount_rate * values_extended[:, 1:] * masks_extended[:, 1:] - values

        # GAE calculation setup
        #last_advantage = rewards[-1] + self._config.discount_rate * next_value * masks[-1] - values[-1]
        
        print("Rewards shape:", rewards.shape)
        print("Values shape:", values.shape)
        print("Masks shape:", masks.shape)
        print("Next values shape:", next_values.shape)
        print("Next masks shape:", next_masks.shape)
        print(next_value)
        def gae_scan_fn(carry, x):
            (next_advantage,) = carry
            (reward, value, next_value, mask) = x
            
            current_delta = reward + self._config.discount_rate * next_value * mask - value
            advantage = current_delta + self._config.discount_rate * self._config.gae_lambda * mask * next_advantage
            return (advantage,), (advantage,)
            


        # Transpose deltas and masks to iterate over the second dimension

        # Perform the scan
        _, (advantages,) = jax.lax.scan(
            gae_scan_fn,
            (jnp.array(0.),),
            (rewards, values, next_values, next_masks),
            length=rewards.shape[0],
            reverse=True
        )
        
        
        print(type(advantages))
        print(type(values))
        # Calculate returns by adding values to advantages
        returns = advantages + values

        return advantages, returns
        
        
    
    @partial(jax.jit, static_argnames=("self",))
    def train_step(self, random_key, train_state_policy, train_state_value):
        """Collect one batch of rollouts and run `no_epochs` epochs of updates on it.

        Steps:
            1. Sample `no_agents` trajectories in parallel with the current networks.
            2. Compute advantages and value targets per trajectory from the
               *scaled* rewards returned by `sample_trajectory`.
            3. Flatten the (agent, time) axes into one batch axis.
            4. Scan `epoch_train` over `no_epochs` PRNG keys.

        Args:
            random_key: PRNG key for this step.
            train_state_policy: Current policy train state.
            train_state_value: Current value train state.

        Returns:
            ``((new_policy_state, new_value_state), (metrics,))`` where `metrics`
            holds the minibatch losses ("loss"), the masked unscaled rewards
            ("reward") and the masks ("mask"). Summary statistics are also emitted
            through `jax.debug.print`.
        """
        # One key per agent for the rollouts, plus one for the update phase.
        random_keys = jax.random.split(random_key, self._config.no_agents+1)

        # NOTE: wall-clock timing via time.time() was removed here; inside a jitted
        # function it only measured tracing and its result was never used.
        obs, action, logp, values, reward, actual_reward, _, mask, next_value, next_done = jax.vmap(self.sample_trajectory, in_axes=(0, None, None))(random_keys[:self._config.no_agents], train_state_policy, train_state_value)
        # Drop the trailing size-1 axis produced by the value head.
        values = jnp.squeeze(values, axis=-1)

        advantages, returns = jax.vmap(self.compute_gae_and_returns, in_axes=(0, 0, 0, 0, 0))(reward, values, mask, next_value, next_done)

        obs_, action_, logp_, mask_, returns_, advantages_ = self.flatten_trajectory(obs, action, logp, mask, returns, advantages)

        # One key per epoch, used to shuffle the batch into minibatches.
        random_keys_ = jax.random.split(random_keys[-1], self._config.no_epochs)

        def _scan_epoch_train(carry, x):
            (train_state_policy, train_state_value) = carry
            (random_key,) = x

            (train_state_policy, train_state_value), losses = self.epoch_train(random_key, train_state_policy, train_state_value, obs_, action_, logp_, mask_, returns_, advantages_)

            return (train_state_policy, train_state_value), losses

        (final_train_state_policy, final_train_state_value), losses = jax.lax.scan(
            _scan_epoch_train,
            (train_state_policy, train_state_value),
            (random_keys_,),
            length=self._config.no_epochs,
            )

        metrics = {
            "loss" : losses,
            "reward" : actual_reward * mask,
            "mask" : mask
        }
        jax.debug.print("Mean Loss: {}", jnp.mean(metrics["loss"]))
        jax.debug.print("Mean Reward: {}", jnp.mean(jnp.sum(metrics["reward"], axis=-1)))
        jax.debug.print("Mean Mask: {}", jnp.mean(metrics["mask"]))
        jax.debug.print("-" * 50)

        return (final_train_state_policy, final_train_state_value), (metrics,)
    
    '''
    @partial(jax.jit, static_argnames=("self",))
    def epoch_train(self, random_key, train_state, obs, action, logp, mask, return_standardized):
        b_inds = random.permutation(random_key, self._config.batch_size)
        
        
        def _scan_mini_train(carry, _):
            (train_state, counter) = carry
        
            idx = b_inds[counter * self._config.mini_batch_size : (counter+1) * self._config.mini_batch_size]
            loss, grad = jax.value_and_grad(self.loss_ppo)(train_state.params, obs[idx], action[idx], logp[idx], mask[idx], return_standardized[idx])
            new_train_state = train_state.apply_gradients(grads=grad)  
            return (new_train_state, counter+1), loss
        
        final_train_state, losses = jax.lax.scan(
            _scan_mini_train,
            (train_state, 0),
            None,
            length=self._config.batch_size // self._config.mini_batch_size,
            )
        
        return (final_train_state,), (losses,)
    '''
    
    
    @partial(jax.jit, static_argnames=("self",))
    def epoch_train(self, random_key, train_state_policy, train_state_value, obs, action, logp, mask, returns, advantages):
        """Run one epoch: shuffle the batch and take one gradient step per minibatch.

        Indices ``0 .. batch_size - 1`` are permuted with `random_key` and cut into
        ``batch_size // mini_batch_size`` consecutive minibatches. For each
        minibatch the total loss is differentiated with respect to both parameter
        sets and each network is updated with its own gradients.

        Returns:
            ``((new_policy_state, new_value_state), losses)`` with one loss value
            per minibatch.
        """
        mini_batch_size = self._config.mini_batch_size
        num_batches = self._config.batch_size // mini_batch_size

        shuffled_indices = jax.random.permutation(random_key, self._config.batch_size)
        # Row i holds the indices of minibatch i (any remainder is dropped).
        batch_indices = jnp.reshape(
            shuffled_indices[:num_batches * mini_batch_size],
            (num_batches, mini_batch_size),
        )

        loss_and_grads = jax.value_and_grad(self.total_loss, argnums=(0, 1))

        def _scan_mini_train(carry, idx):
            policy_state, value_state = carry

            loss, (policy_grads, value_grads) = loss_and_grads(
                policy_state.params,
                value_state.params,
                obs[idx],
                action[idx],
                logp[idx],
                mask[idx],
                returns[idx],
                advantages[idx],
                value_state,
            )

            # Each network is updated with the gradients of its own parameters.
            updated_states = (
                policy_state.apply_gradients(grads=policy_grads),
                value_state.apply_gradients(grads=value_grads),
            )
            return updated_states, loss

        (final_train_state_policy, final_train_state_value), losses = jax.lax.scan(
            _scan_mini_train,
            (train_state_policy, train_state_value),
            batch_indices,
            length=num_batches,
            )

        return (final_train_state_policy, final_train_state_value), losses
    
    @partial(jax.jit, static_argnames=("self", "no_steps", "eval"))
    def train(self, random_key, train_state_policy, train_state_value, no_steps, eval=False):
        """Run `no_steps` consecutive training steps, optionally followed by an evaluation.

        Args:
            random_key: PRNG key; split into one key per step plus one for evaluation.
            train_state_policy: Initial policy train state.
            train_state_value: Initial value train state.
            no_steps: Static number of `train_step` calls to scan over.
            eval: Static flag; when true, `evaluate` is run on the final states and
                its result is emitted through `jax.debug.print`.

        Returns:
            ``(train_state_policy, train_state_value, metrics)`` where `metrics` is
            the per-step metrics dictionary stacked along a leading step axis.
        """
        random_keys = jax.random.split(random_key, no_steps+1)
        step_keys, eval_key = random_keys[:no_steps], random_keys[-1]

        def _scan_train_step(carry, step_key):
            policy_state, value_state = carry
            new_states, (step_metrics,) = self.train_step(step_key, policy_state, value_state)
            return new_states, step_metrics

        (train_state_policy, train_state_value), metrics = jax.lax.scan(
            _scan_train_step,
            (train_state_policy, train_state_value),
            step_keys,
            length=no_steps,
            )

        if eval:
            mean_reward = self.evaluate(eval_key, train_state_policy, train_state_value)
            jax.debug.print("Mean Reward over 20 episodes: {}", mean_reward)

        return train_state_policy, train_state_value, metrics
    
    @partial(jax.jit, static_argnames=("self",))
    def evaluate(self, random_key, train_state_policy, train_state_value):
        """Estimate the mean episode return of the current policy over 20 rollouts.

        Twenty trajectories are sampled in parallel in evaluation mode; each one's
        unscaled rewards are masked, summed over time, and the sums are averaged.

        Previously the same 20 keys were rolled out again inside a 20-iteration
        scan, repeating identical episodes; since every repetition produced the
        same rewards, a single batch yields the same mean at 1/20th of the cost.

        Args:
            random_key: PRNG key, split into one key per evaluation episode.
            train_state_policy: Policy train state to evaluate.
            train_state_value: Value train state (needed by `sample_trajectory`).

        Returns:
            Scalar mean of the per-episode summed masked rewards.
        """
        num_episodes = 20
        random_keys = jax.random.split(random_key, num_episodes)

        _, _, _, _, _, reward, _, mask, _, _ = jax.vmap(self.sample_trajectory, in_axes=(0, None, None, None))(random_keys, train_state_policy, train_state_value, True)

        episode_returns = jnp.sum(reward * mask, axis=-1)
        return jnp.mean(episode_returns)


In [6]:
# Hyper-parameters for this run; logged to wandb and used to build `Config`.
config_dict = dict(
    no_agents=2048,
    batch_size=2048 * 10,
    mini_batch_size=640,
    no_epochs=4,
    learning_rate=3e-4,
    discount_rate=0.99,
    clip_param=0.2,
    vf_coef=0.5,
    gae_lambda=0.95,
    env_name="walker2d_uni",
)

# Initialize wandb with the configuration dictionary
wandb.init(project="mcpg", name='PPOish', config=config_dict)

env = get_env(config_dict["env_name"])

# Network architectures.
policy_hidden_layers = [64, 64]
value_hidden_layers = [64, 64]

policy = MLP(
    hidden_layers_size=policy_hidden_layers,
    action_size=env.action_size,
    activation=nn.tanh,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    mean_init=jax.nn.initializers.orthogonal(scale=0.01),
)

value_net = ValueNet(
    hidden_layers_size=value_hidden_layers,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    value_init=jax.nn.initializers.orthogonal(scale=1.),
    activation=nn.tanh,
)

agent = MCPG(Config(**wandb.config), policy, value_net, env)

random_key = jax.random.PRNGKey(0)
train_state_policy, train_state_value = agent.init(random_key)

# Total number of training steps and how many are run per jitted `train` call.
num_steps = 200
log_period = 20

# Accumulated history of the logged quantities; every entry starts out empty.
metrics_wandb = {
    metric_name: jnp.array([])
    for metric_name in ("mean loss", "mean reward", "mask", "evaluation", 'time')
}
eval_num = config_dict["no_agents"]


def update_metrics(old_metrics, new_metrics):
    """Append the entries of `new_metrics` to the history in `old_metrics`.

    Only the keys of `old_metrics` are kept. A KeyError is raised when one of
    them is missing from `new_metrics`.
    """
    merged = {}
    for key, history in old_metrics.items():
        if key not in new_metrics:
            raise KeyError(f"Key {key} not found in new metrics.")
        latest = new_metrics[key]
        # An empty history is replaced outright rather than concatenated with.
        merged[key] = latest if history.size == 0 else jnp.concatenate([history, latest], axis=0)
    return merged


print(f"Number of evaluations per training step: {eval_num}")
start_time = time.time()
for i in range(num_steps // log_period):
    random_key, subkey = jax.random.split(random_key)
    train_state_policy, train_state_value, current_metrics = agent.train(subkey, train_state_policy, train_state_value, log_period, eval=True)
    timelapse = time.time() - start_time
    print(f"Step {(i+1) * log_period}, Time: {timelapse}")

    # Summaries for this logging period, each repeated once per training step.
    # NOTE(review): "evaluation" has log_period * eval_num entries, unlike the
    # other series which have log_period entries - confirm this is intended.
    current_metrics.update({
        "evaluation": jnp.arange(log_period*eval_num*(i+1), log_period*eval_num*(i+2), dtype=jnp.int32),
        "time": jnp.repeat(timelapse, log_period),
        "mean loss": jnp.repeat(jnp.mean(current_metrics["loss"]), log_period),
        "mean reward": jnp.repeat(jnp.mean(jnp.sum(current_metrics["reward"], axis=-1)), log_period),
        "mask": jnp.repeat(jnp.mean(current_metrics["mask"]), log_period),
    })

    # Best-effort logging: a failure here is reported but does not stop training.
    try:
        metrics_wandb = update_metrics(metrics_wandb, current_metrics)
        # Only the most recent entry of each series is sent to wandb.
        log_metrics = {name: series[-1] for name, series in metrics_wandb.items()}
        wandb.log(log_metrics)
    except Exception as e:
        print(f"Error updating metrics: {e}")

    # Restart the clock so `timelapse` measures a single logging period.
    start_time = time.time()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mk_mitsides[0m ([33mmitsides[0m). Use [1m`wandb login --relogin`[0m to force relogin


Number of evaluations per training step: 2048
Rewards shape: (10,)
Values shape: (10,)
Masks shape: (10,)
Next values shape: (10,)
Next masks shape: (10,)
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=5/0)>
<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
Shape of obs: (20480, 18)
Shape of action: (20480, 6)
Shape of logp: (20480,)
Shape of mask: (20480,)
Shape of returns: (20480,)
Shape of advantages: (20480,)
Mean Loss: -1.316440224647522
Mean Reward: 9.55335807800293
Mean Mask: 1.0
--------------------------------------------------
Mean Loss: -1.3088446855545044
Mean Reward: 9.597723007202148
Mean Mask: 1.0
--------------------------------------------------
Mean Loss: -1.2663633823394775
Mean Reward: 9.662751197814941
Mean Mask: 1.0
--------------------------------------------------
Mean Loss: -1.2370920181274414
Mean Reward: 9.697147369384766
Mean Mask: 1.0
---------------------------------