In [1]:
# Shell out to `nvidia-smi` to report the visible GPU, driver/CUDA version and memory use.
!nvidia-smi

Sun Jul 28 23:10:06 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   30C    P0              25W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
from dataclasses import dataclass
from typing import Any, Union, Tuple, Callable, Optional
from functools import partial
import time

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

from qdax import environments, environments_v1
from jax import random
import wandb

import pickle
from optax import exponential_decay
from IPython.display import HTML
from brax.io import html
import os
import jax.debug
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax 
from flax.linen.initializers import constant, orthogonal 
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
from wrappers import (
    LogWrapper,
    BraxGymnaxWrapper,
    VecEnv,
    NormalizeVecObservation,
    NormalizeVecReward,
    ClipAction,
)

In [3]:
def get_env(env_name):
    """Create the `environments_v1` environment registered under ``env_name``.

    Args:
        env_name: One of "hopper_uni", "halfcheetah_uni", "walker2d_uni",
            "ant_uni" or "humanoid_uni".

    Returns:
        Whatever ``environments_v1.create`` returns for that name, created
        with ``episode_length=1000`` plus the per-environment options below.

    Raises:
        ValueError: If ``env_name`` is not one of the supported names.
    """
    # Every supported environment uses the same episode length.
    episode_length = 1000

    # Extra keyword arguments forwarded to `environments_v1.create`, keyed by
    # environment name. An empty dict means "no extra options".
    extra_kwargs_by_env = {
        "hopper_uni": {},
        "halfcheetah_uni": {},
        "walker2d_uni": {},
        "ant_uni": {
            "use_contact_forces": False,
            "exclude_current_positions_from_observation": True,
        },
        "humanoid_uni": {
            "exclude_current_positions_from_observation": True,
        },
    }
    # NOTE: an "ant_omni" variant (episode_length=250, created through
    # `environments.create`) was sketched here previously but was disabled.

    if env_name not in extra_kwargs_by_env:
        # Fail loudly here: an unknown name used to fall through every branch
        # and crash with UnboundLocalError at `return env`.
        supported = ", ".join(sorted(extra_kwargs_by_env))
        raise ValueError(
            f"Environment {env_name} not supported. Supported environments: {supported}."
        )

    return environments_v1.create(
        env_name,
        episode_length=episode_length,
        **extra_kwargs_by_env[env_name],
    )

In [6]:
# Hyperparameters for this run; the same mapping is passed to wandb as the run config.
config_dict = dict(
    no_agents=32,
    batch_size=32 * 1000,
    mini_batch_size=32000,
    no_epochs=10,
    learning_rate=3e-4,
    discount_rate=0.99,
    clip_param=0.2,
    vf_coef=0.5,
    gae_lambda=0.95,
    env_name="halfcheetah_uni",
)

# Open the wandb run and record the hyperparameters with it.
wandb.init(project="mcpg", name='PPOish', config=config_dict)

# Build the environment selected in the config.
env = get_env(config_dict["env_name"])

# Hidden-layer widths for the two networks.
policy_hidden_layers = [64, 64]
value_hidden_layers = [64, 64]

# Policy network: tanh activations, orthogonal init (sqrt(2) for hidden
# layers, 0.01 for the mean head), one output per action dimension.
policy = MLP(
    action_size=env.action_size,
    hidden_layers_size=policy_hidden_layers,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    mean_init=jax.nn.initializers.orthogonal(scale=0.01),
    activation=nn.tanh,
)

# Value network: tanh activations, orthogonal init (sqrt(2) for hidden
# layers, 1.0 for the value head).
value_net = ValueNet(
    hidden_layers_size=value_hidden_layers,
    activation=nn.tanh,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    value_init=jax.nn.initializers.orthogonal(scale=1.),
)

# The agent reads its settings back from wandb.config rather than config_dict.
agent = MCPG(Config(**wandb.config), policy, value_net, env)

# Fixed seed (0) for the initial PRNG key; agent.init returns both train states.
random_key = jax.random.PRNGKey(0)
train_state_policy, train_state_value = agent.init(random_key)

# Total number of training steps, and how many steps run between two wandb logs.
num_steps = 1000
log_period = 10


def update_metrics(old_metrics, new_metrics):
    """Append the entries of ``new_metrics`` to the histories in ``old_metrics``.

    Args:
        old_metrics: Mapping from metric name to the array accumulated so far
            (an empty array before the first update).
        new_metrics: Mapping holding, for at least every key of
            ``old_metrics``, the array to append. Extra keys are ignored.

    Returns:
        A new dict with the same keys as ``old_metrics``; each value is the
        old history concatenated with the new entries along axis 0.

    Raises:
        KeyError: If a key of ``old_metrics`` is missing from ``new_metrics``.
    """
    updated_metrics = {}
    for key, history in old_metrics.items():
        if key not in new_metrics:
            raise KeyError(f"Key {key} not found in new metrics.")
        if history.size == 0:
            # First update: take the new values as-is instead of concatenating
            # onto the empty placeholder array.
            updated_metrics[key] = new_metrics[key]
        else:
            updated_metrics[key] = jnp.concatenate([history, new_metrics[key]], axis=0)
    return updated_metrics


# Accumulated history of every logged metric, starting from empty arrays.
metrics_wandb = dict.fromkeys(["mean loss", "mean reward", "mask", "evaluation", 'time'], jnp.array([]))
eval_num = config_dict["no_agents"]
print(f"Number of evaluations per training step: {eval_num}")

start_time = time.time()
for i in range(num_steps // log_period):
    # Train for `log_period` steps with a fresh PRNG subkey.
    random_key, subkey = jax.random.split(random_key)
    train_state_policy, train_state_value, current_metrics = agent.train(subkey, train_state_policy, train_state_value, log_period, eval=False)

    # `start_time` is reset at the end of each iteration, so this is the
    # wall-clock time of the current logging period only.
    timelapse = time.time() - start_time
    print(f"Step {(i+1) * log_period}, Time: {timelapse}")

    # NOTE(review): this range holds log_period * eval_num entries, while the
    # other metrics below hold log_period entries each; only the last entry of
    # each history is logged. Confirm the intended evaluation count.
    current_metrics["evaluation"] = jnp.arange(log_period*eval_num*(i+1), log_period*eval_num*(i+2), dtype=jnp.int32)
    current_metrics["time"] = jnp.repeat(timelapse, log_period)
    current_metrics["mean loss"] = jnp.repeat(jnp.mean(current_metrics["loss"]), log_period)
    current_metrics["mean reward"] = jnp.repeat(jnp.mean(jnp.sum(current_metrics["reward"], axis=-1)), log_period)
    current_metrics["mask"] = jnp.repeat(jnp.mean(current_metrics["mask"]), log_period)

    # Best effort: a failure while accumulating or logging metrics is reported
    # but does not stop training.
    try:
        metrics_wandb = update_metrics(metrics_wandb, current_metrics)
        log_metrics = {k: v[-1] for k, v in metrics_wandb.items()}  # latest entry of each history
        wandb.log(log_metrics)
    except Exception as e:
        print(f"Error updating metrics: {e}")

    start_time = time.time()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mk_mitsides[0m ([33mmitsides[0m). Use [1m`wandb login --relogin`[0m to force relogin


Number of evaluations per training step: 32
Rewards shape: (1000,)
Values shape: (1000,)
Masks shape: (1000,)
Next values shape: (1000,)
Next masks shape: (1000,)
Traced<ShapedArray(float32[1])>with<DynamicJaxprTrace(level=5/0)>
<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>
Shape of obs: (32000, 19)
Shape of action: (32000, 6)
Shape of logp: (32000,)
Shape of mask: (32000,)
Shape of returns: (32000,)
Shape of advantages: (32000,)
Mean Loss: -0.3751281201839447
Mean Reward: 227.6048583984375
Mean Mask: 1.0
--------------------------------------------------
Mean Loss: -0.34584346413612366
Mean Reward: 220.39080810546875
Mean Mask: 1.0
--------------------------------------------------
Mean Loss: -0.2665552794933319
Mean Reward: 207.3115692138672
Mean Mask: 1.0
--------------------------------------------------
Mean Loss: -0.09995461255311966
Mean Reward: 208.86026000976562
Mean Mask: 1.0
-------------------