In [1]:
!nvidia-smi

Thu Aug  1 02:43:17 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   29C    P0              28W / 165W |   9296MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
from dataclasses import dataclass
from typing import Any, Union, Tuple, Callable, Optional
from functools import partial
import time

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

from qdax import environments, environments_v1
from jax import random
import wandb

import pickle
from optax import exponential_decay
from IPython.display import HTML
from brax.io import html
import os
import jax.debug
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax 
from flax.linen.initializers import constant, orthogonal, lecun_uniform
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
from wrappers import (
    LogWrapper,
    BraxGymnaxWrapper,
    VecEnv,
    NormalizeVecObservation,
    NormalizeVecReward,
    ClipAction,
)

In [4]:
class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"
    
    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
            
        actor_mean = nn. Dense(
            64, kernel_init = orthogonal(np.sqrt(2)), bias_init = constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        actor_logstd = self.param("log_std", lambda _, shape: jnp.log(0.5)*jnp.ones(shape), (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(loc=actor_mean, scale_diag=jnp.exp(actor_logstd))
        
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)
        
        return pi, jnp.squeeze(critic, axis=-1)
        
        

In [5]:
class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray
    
def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])
    
    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac
    
    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).shape[0], activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(
                    learning_rate=linear_schedule,
                    eps=1e-5,
                ),
            )
            
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(
                    learning_rate=config["LR"],
                    eps=1e-5,
                ),
            )
            
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )
        
        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)
        
        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state
                
                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                
                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(
                    rng_step, env_state, action, env_params,
                )    
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition
            
            runner_state, traj_batch = jax.lax.scan(
                _env_step, 
                runner_state, 
                None, 
                length=config["NUM_STEPS"]
            )
            
            # COMPUTE ADVANTAGES
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)
            
            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae
                
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                
                return advantages, advantages + traj_batch.value
            
            advantages, targets = _calculate_gae(traj_batch, last_val)
            
            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minibatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info
                    
                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)
                        
                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.mean(jnp.maximum(value_losses, value_losses_clipped))
                        
                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                            ratio, 
                            1.0 - config["CLIP_EPS"], 
                            1.0 + config["CLIP_EPS"]
                            ) 
                            * gae
                        )
                        loss_actor = -jnp.mean(jnp.minimum(loss_actor1, loss_actor2))
                        entropy = jnp.mean(pi.entropy())
                        
                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)
                    
                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grad = grad_fn(
                        train_state.params, 
                        traj_batch, 
                        advantages, 
                        targets)
                    train_state = train_state.apply_gradients(grads=grad)
                    return train_state, total_loss
                
                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_ENVS"] * config["NUM_STEPS"]
                ), "batch size must be equal to number of steps * number of envs"
                
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch   
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                
                train_state, total_loss = jax.lax.scan(
                    _update_minibatch,
                    train_state,
                    minibatches,
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss
            
            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch,
                update_state,
                None,
                length=config["UPDATE_EPOCHS"], 
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]
            if config.get("DEBUG"):
                
                def callback(info):
                    return_values = info["returned_episode_returns"][
                        info["returned_episode"]
                    ]
                    timesteps = (
                        info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    )
                    for t in range(len(timesteps)):
                        print(
                            f"global step={timesteps[t]}, episodic return={return_values[t]}"
                        )
                        
                jax.debug.callback(callback, metric)
                
            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric
        
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step,
            runner_state,
            None,
            length=config["NUM_UPDATES"],
        )
        return {"runner_state": runner_state, "metrics": metric, "net": network, "env": env}
    
    return train        

In [6]:
config = {
    "LR": 1e-3,
    "NUM_ENVS": 1,
    "NUM_STEPS": 20480,
    "TOTAL_TIMESTEPS": 2e6,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 32,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.0,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": "walker2d",
    "ANNEAL_LR": False,
    "NORMALIZE_ENV": True,
    "DEBUG": True,
}
rng = jax.random.PRNGKey(30)
#train_jit = jax.jit(make_train(config))
#out = train_jit(rng)
train = make_train(config)
out = train(rng)

global step=252, episodic return=370.1466979980469
global step=524, episodic return=380.30853271484375
global step=853, episodic return=442.8634338378906
global step=1100, episodic return=134.17640686035156
global step=1327, episodic return=347.08917236328125
global step=1569, episodic return=130.88368225097656
global step=1750, episodic return=67.63519287109375
global step=2032, episodic return=166.4425048828125
global step=2225, episodic return=305.92864990234375
global step=2402, episodic return=293.2664794921875
global step=2658, episodic return=369.1868591308594
global step=2909, episodic return=366.9624328613281
global step=3155, episodic return=130.24945068359375
global step=3426, episodic return=388.8323974609375
global step=3721, episodic return=182.0687255859375
global step=3967, episodic return=361.1302185058594
global step=4225, episodic return=144.3124237060547
global step=4501, episodic return=169.1455841064453
global step=4707, episodic return=92.85040283203125
global st

In [7]:
import brax
from brax.io import html

In [None]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 1000
        
        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)
    elif env_name == "halfcheetah_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=True, exclude_current_positions_from_observation=True)
    elif env_name == "humanoid_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)	
    '''
    elif env_name == "ant_omni":
        episode_length = 250
        max_bd = 30.

        env = environments.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
    elif env_name == "humanoid_uni":
        episode_length = 1000
        max_bd = 1.

        env = environments.create(env_name, episode_length=episode_length)	
    else:
        ValueError(f"Environment {env_name} not supported.")
    '''
    return env

In [64]:
env = BraxGymnaxWrapper(config["ENV_NAME"])
#env = BraxGymnaxWrapper('ant')
env = LogWrapper(env)
env = ClipAction(env)
env = VecEnv(env)
if config["NORMALIZE_ENV"]:
    env = NormalizeVecObservation(env)
    #env = NormalizeVecReward(env, config["GAMMA"])

In [65]:
env = out['env']

In [94]:
env.action_space(None).shape[0]

6

In [95]:
env_.action_space(None).shape[0]

6

In [96]:
env.observation_space(None).shape

(17,)

In [97]:
env_.observation_space(None).shape

(17,)

In [54]:
#env = out["env"]
net = out["net"]
params = out["runner_state"][0].params

In [66]:
#env = get_env("hopper_uni")
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

In [67]:
reset_rng = jax.random.split(rng, config["NUM_ENVS"])
obs, state = reset_fn(reset_rng, None)

In [68]:
state

NormalizeVecRewEnvState(mean=Array(0., dtype=float32, weak_type=True), var=Array(1., dtype=float32, weak_type=True), count=Array(1.e-04, dtype=float32, weak_type=True), return_val=Array([0.], dtype=float32), env_state=NormalizeVecObsEnvState(mean=Array([[ 1.2526040e+00,  3.2419357e-05, -1.6066957e-03,  1.9951731e-03,
        -1.9843120e-03,  9.0189616e-04,  3.9274758e-03, -7.3985713e-05,
         2.3522393e-03, -2.6327993e-03, -1.9293381e-03,  4.7406703e-03,
        -8.8355283e-04, -2.1393031e-03,  2.3615120e-03, -1.3998359e-03,
        -1.1415685e-03]], dtype=float32), var=Array([[2.5689165e-04, 9.9989993e-05, 9.9990248e-05, 9.9990393e-05,
        9.9990386e-05, 9.9990073e-05, 9.9991536e-05, 9.9989993e-05,
        9.9990546e-05, 9.9990684e-05, 9.9990364e-05, 9.9992241e-05,
        9.9990073e-05, 9.9990451e-05, 9.9990553e-05, 9.9990189e-05,
        9.9990124e-05]], dtype=float32), count=Array(1.0001, dtype=float32, weak_type=True), env_state=LogEnvState(env_state=State(pipeline_state=S

In [88]:
env_ = BraxGymnaxWrapper(config["ENV_NAME"])
#env = BraxGymnaxWrapper('ant')
env_ = LogWrapper(env_)
env_ = ClipAction(env_)
#env_ = VecEnv(env)
if config["NORMALIZE_ENV"]:
    env_ = NormalizeVecObservation(env_)

In [89]:
#env = get_env("hopper_uni")
reset_fn_ = jax.jit(env_.reset)
step_fn_ = jax.jit(env_.step)

In [87]:
rng = jax.random.PRNGKey(1)
obs, state_ = reset_fn_(rng, None)
state_.env_state.env_state.pipeline_state.x.pos

Array([[-1.3433027e-03,  0.0000000e+00,  1.2527804e+00],
       [-9.4429153e-04,  0.0000000e+00,  1.0527809e+00],
       [ 2.3356427e-03,  0.0000000e+00,  3.5279030e-01],
       [ 2.0324700e-01,  0.0000000e+00,  3.7901103e-03],
       [-9.4429153e-04,  0.0000000e+00,  1.0527809e+00],
       [ 5.3907203e-04,  0.0000000e+00,  3.5278285e-01],
       [ 2.0167159e-01,  0.0000000e+00,  3.0579269e-03]], dtype=float32)

In [93]:
rng

Array([ 133672768, 3027184172], dtype=uint32)

In [92]:
__rng = jax.random.split(rng, 1)
__rng

Array([[ 399689397, 3426049838]], dtype=uint32)

In [90]:
rng = jax.random.PRNGKey(1)
#reset_rng = jax.random.split(rng, config["NUM_ENVS"])
#obs, state = reset_fn(reset_rng, None)
obs, state = reset_fn_(rng, None)
rollout = [state]
total_reward = 0
steps = 0
s = 0
while True:
    s += 1
    steps += 1
    print(f"Step: {steps}")
    rng, _rng, __rng = jax.random.split(rng, 3)
    #pi, _ = net.apply(params, state.obs)
    pi, _ = net.apply(params, obs)
    action = pi.sample(seed=_rng)
    start_time = time.time()   
    #state = step_fn(state, action)
    #__rng = jax.random.split(rng, 1)
    obs, state, reward, done, _ = step_fn_(__rng, state, action) 
    print(f"Step time: {time.time() - start_time}")
    print(f"Reward: {reward}")
    total_reward += reward
    
    #if state.done:
    #    break
    #else:
    #    rollout.append(state)
    if done:
        break
    else:
        rollout.append(state)    
    #if s > 1000:
    #    break
    #else:
        
print(f"Total reward: {total_reward}")
    

Step: 1
Step time: 9.630108833312988
Reward: 1.0245981216430664
Step: 2
Step time: 8.776236057281494
Reward: 1.005138874053955
Step: 3
Step time: 0.002464771270751953
Reward: 0.9822401404380798
Step: 4
Step time: 0.002364635467529297
Reward: 1.013896107673645
Step: 5
Step time: 0.0024874210357666016
Reward: 0.9564902186393738
Step: 6
Step time: 0.0023839473724365234
Reward: 0.9251849055290222
Step: 7
Step time: 0.0024330615997314453
Reward: 0.9190219044685364
Step: 8
Step time: 0.0024518966674804688
Reward: 0.8810594081878662
Step: 9
Step time: 0.0026488304138183594
Reward: 0.9072877168655396
Step: 10
Step time: 0.0023767948150634766
Reward: 0.9529709219932556
Step: 11
Step time: 0.0024323463439941406
Reward: 0.8659573793411255
Step: 12
Step time: 0.0024785995483398438
Reward: 0.7979046702384949
Step: 13
Step time: 0.0023584365844726562
Reward: 0.7299497127532959
Step: 14
Step time: 0.0028641223907470703
Reward: 0.5735960602760315
Step: 15
Step time: 0.001993417739868164
Reward: 0.5168

In [43]:
state.env_state

NormalizeVecObsEnvState(mean=Array([[ 1.1573642 , -0.47945768, -0.23382315, -0.04235097,  0.44697273,
        -0.12198464, -0.2836983 ,  0.23967423, -0.25485012, -0.0911651 ,
        -0.40483886, -0.19864   , -0.03250589,  0.30760854, -0.14199342,
        -0.20249607,  0.32600734]], dtype=float32), var=Array([[3.1277228e-03, 8.4184483e-02, 1.9963253e-02, 8.7328133e-04,
        8.0726102e-02, 1.6254807e-02, 2.8604344e-02, 1.0868104e-01,
        4.4699717e-02, 8.2097962e-02, 1.3906202e-01, 5.1733845e-01,
        4.8416817e-01, 4.2512898e+00, 1.1003299e+00, 8.6932242e-01,
        6.6067181e+00]], dtype=float32), count=Array(308.0001, dtype=float32, weak_type=True), env_state=LogEnvState(env_state=State(pipeline_state=State(q=Array([[ 0.00251595, -0.00413584, -0.00010115, -0.00083415, -0.00135881,
         0.00434954, -0.00383881, -0.00084007, -0.00362684]],      dtype=float32), qd=Array([[-3.5661827e-03,  1.9501895e-06, -2.1542888e-04,  4.1227909e-03,
         1.3782093e-03,  2.0741844e-0

FrozenInstanceError: cannot assign to field 'x'

In [46]:
la = np.squeeze(state.env_state.env_state.env_state.pipeline_state.x.pos)

In [48]:
np.squeeze(la).shape

(7, 3)

In [None]:
cleaned_rollout = []

In [31]:
a = html.render(env.sys, [state.env_state.env_state.env_state.pipeline_state for state in rollout])

RuntimeError: Expected state.x position and rotation to have 2 shape dimensions but received len(pos.shape)=3 and len(rot.shape)=3

In [14]:
display(HTML(a))

In [None]:
metrics = out["metrics"]

In [None]:
metrics["returned_episode_lengths"][150:200,:,1000]

Array([[135, 135, 135, 135, 135, 135, 135, 135, 135, 135],
       [135, 135, 135, 135, 135, 135, 135, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145],
       [145, 145, 145, 145, 145, 145, 145, 145, 145, 145

In [None]:
metrics["returned_episode_returns"][5:19,:,4]

Array([[  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.     ,   0.     ],
       [  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.     ,   0.     ],
       [  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.     ,   0.     ],
       [  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.     ,   0.     ],
       [  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.     ,   0.     ],
       [  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.     ,   0.     ],
       [  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.     ,   0.     ],
       [  0.     ,   0.     ,   0.     ,   0.     ,   0.     ,   0.     ,
          0.     ,   0.     ,   0.   

In [None]:
metrics["returned_episode"][5:19,:,4]

Array([[False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, False, False, False, False,
        False],
       [False, False, False, False, False, 

In [None]:
with open('visualization.html', 'w') as file:
    file.write(html_content)

In [None]:
display(HTML("<div style='color:red'>Hello World!</div>"))

In [None]:
env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None

rollout = []
rewards = []

rng, _rng = jax.random.split(rng)
env_state = jax.jit(env.reset)(_rng)
done = False
while not done:
    rollout.append(env_state)
    pi, _ = out["runner_state"][0].apply(out["runner_state"][0].params)
    action = pi.sample(seed=_rng)
    rng, _rng = jax.random.split(rng)
    env_state, _, reward, done, _ = env.step(_rng, env_state, action, None)
    rewards.append(reward)
    
a = html.render(env.sys, [s.qp for s in rollout])
HTML(a)

AttributeError: 'TrainState' object has no attribute 'apply'

In [None]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 1000
        
        env = environments_v1.create(env_name, episode_length=episode_length)
    elif env_name == "halfcheetah_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=True)
    elif env_name == "humanoid_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)	
    '''
    elif env_name == "ant_omni":
        episode_length = 250
        max_bd = 30.

        env = environments.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
    elif env_name == "humanoid_uni":
        episode_length = 1000
        max_bd = 1.

        env = environments.create(env_name, episode_length=episode_length)	
    else:
        ValueError(f"Environment {env_name} not supported.")
    '''
    return env

In [None]:
config_dict = {
    "no_agents": 32,
    "batch_size": 32 * 1000,
    "mini_batch_size": 32000,
    "no_epochs": 10,
    "learning_rate": 3e-4,
    "discount_rate": 0.99,
    "clip_param": 0.2,
    "vf_coef": 0.5,
    "gae_lambda": 0.95,
    "env_name": "halfcheetah_uni",
}

# Initialize wandb with the configuration dictionary
wandb.init(project="mcpg", name='PPOish', config=config_dict)

env = get_env(config_dict["env_name"])


policy_hidden_layers = [64, 64]
value_hidden_layers = [64, 64]

policy = MLP(
    hidden_layers_size=policy_hidden_layers,
    action_size=env.action_size,
    activation=nn.tanh,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    mean_init=jax.nn.initializers.orthogonal(scale=0.01),
)

value_net = ValueNet(
    hidden_layers_size=value_hidden_layers,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    value_init=jax.nn.initializers.orthogonal(scale=1.),
    activation=nn.tanh,
)

agent = MCPG(Config(**wandb.config), policy, value_net, env)

random_key = jax.random.PRNGKey(0)
train_state_policy, train_state_value = agent.init(random_key)

num_steps = 1000
log_period = 10

metrics_wandb = dict.fromkeys(["mean loss", "mean reward", "mask", "evaluation", 'time'], jnp.array([]))
eval_num = config_dict["no_agents"]
print(f"Number of evaluations per training step: {eval_num}")
start_time = time.time()
for i in range(num_steps // log_period):
    random_key, subkey = jax.random.split(random_key)
    train_state_policy, train_state_value, current_metrics = agent.train(subkey, train_state_policy, train_state_value, log_period, eval=False)
    timelapse = time.time() - start_time
    print(f"Step {(i+1) * log_period}, Time: {timelapse}")
    
    current_metrics["evaluation"] = jnp.arange(log_period*eval_num*(i+1), log_period*eval_num*(i+2), dtype=jnp.int32)
    current_metrics["time"] = jnp.repeat(timelapse, log_period)
    current_metrics["mean loss"] = jnp.repeat(jnp.mean(current_metrics["loss"]), log_period)
    current_metrics["mean reward"] = jnp.repeat(jnp.mean(jnp.sum(current_metrics["reward"], axis=-1)), log_period)
    current_metrics["mask"] = jnp.repeat(jnp.mean(current_metrics["mask"]), log_period)
    '''
    metrics_wandb = jax.tree_util.tree_map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics_wandb, current_metrics)
    
    log_metrics = jax.tree_util.tree_map(lambda metric: metric[-1], metrics_wandb)
    
    wandb.log(log_metrics)
    '''
    
    def update_metrics(old_metrics, new_metrics):
        updated_metrics = {}
        for key in old_metrics:
            if key in new_metrics:
                # Check if old metrics for key is empty, and initialize properly if so
                if old_metrics[key].size == 0:
                    updated_metrics[key] = new_metrics[key]
                else:
                    updated_metrics[key] = jnp.concatenate([old_metrics[key], new_metrics[key]], axis=0)
            else:
                raise KeyError(f"Key {key} not found in new metrics.")
        return updated_metrics

    # In your training loop:
    try:
        metrics_wandb = update_metrics(metrics_wandb, current_metrics)
        log_metrics = {k: v[-1] for k, v in metrics_wandb.items()}  # Assuming you want the latest entry
        wandb.log(log_metrics)
    except Exception as e:
        print(f"Error updating metrics: {e}")

    
    
    
    
    start_time = time.time()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mk_mitsides[0m ([33mmitsides[0m). Use [1m`wandb login --relogin`[0m to force relogin


NameError: name 'MLP' is not defined