In [1]:
!nvidia-smi

Mon Aug  5 15:12:44 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.171.04             Driver Version: 535.171.04   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A30                     Off | 00000000:01:00.0 Off |                   On |
| N/A   31C    P0              49W / 165W |     50MiB / 24576MiB |     N/A      Default |
|                                         |                      |              Enabled |
+-----------------------------------------+----------------------+----------------------+

+------------------------------------------------------------------

In [2]:
from dataclasses import dataclass
from typing import Any, Union, Tuple, Callable, Optional
from functools import partial
import time

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

from qdax import environments, environments_v1
from jax import random
import wandb

import pickle
from optax import exponential_decay
from IPython.display import HTML
from brax.io import html
import os
import jax.debug
import matplotlib.pyplot as plt
%matplotlib inline

In [3]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import numpy as np
import optax 
from flax.linen.initializers import constant, orthogonal, lecun_uniform
from typing import Sequence, NamedTuple, Any
from flax.training.train_state import TrainState
import distrax
from wrappers import (
    LogWrapper,
    BraxGymnaxWrapper,
    VecEnv,
    NormalizeVecObservation,
    NormalizeVecReward,
    ClipAction,
)

In [4]:
class ActorCritic(nn.Module):
    action_dim: Sequence[int]
    activation: str = "tanh"
    
    @nn.compact
    def __call__(self, x):
        if self.activation == "relu":
            activation = nn.relu
        else:
            activation = nn.tanh
            
        actor_mean = nn. Dense(
            64, kernel_init = orthogonal(np.sqrt(2)), bias_init = constant(0.0)
        )(x)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(actor_mean)
        actor_mean = activation(actor_mean)
        actor_mean = nn.Dense(
            self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0)
        )(actor_mean)
        actor_logstd = self.param("log_std", lambda _, shape: jnp.log(0.5)*jnp.ones(shape), (self.action_dim,))
        pi = distrax.MultivariateNormalDiag(loc=actor_mean, scale_diag=jnp.exp(actor_logstd))
        
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(x)
        critic = activation(critic)
        critic = nn.Dense(
            64, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0)
        )(critic)
        critic = activation(critic)
        critic = nn.Dense(1, kernel_init=orthogonal(1.0), bias_init=constant(0.0))(critic)
        
        return pi, jnp.squeeze(critic, axis=-1)
        
        

In [5]:
class Transition(NamedTuple):
    done: jnp.ndarray
    action: jnp.ndarray
    value: jnp.ndarray
    reward: jnp.ndarray
    log_prob: jnp.ndarray
    obs: jnp.ndarray
    info: jnp.ndarray
    
def make_train(config):
    config["NUM_UPDATES"] = (
        config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"]
    )
    config["MINIBATCH_SIZE"] = (
        config["NUM_ENVS"] * config["NUM_STEPS"] // config["NUM_MINIBATCHES"]
    )
    env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if config["NORMALIZE_ENV"]:
        env = NormalizeVecObservation(env)
        env = NormalizeVecReward(env, config["GAMMA"])
    
    def linear_schedule(count):
        frac = (
            1.0
            - (count // (config["NUM_MINIBATCHES"] * config["UPDATE_EPOCHS"]))
            / config["NUM_UPDATES"]
        )
        return config["LR"] * frac
    
    def train(rng):
        # INIT NETWORK
        network = ActorCritic(
            env.action_space(env_params).shape[0], activation=config["ACTIVATION"]
        )
        rng, _rng = jax.random.split(rng)
        init_x = jnp.zeros(env.observation_space(env_params).shape)
        network_params = network.init(_rng, init_x)
        if config["ANNEAL_LR"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(
                    learning_rate=linear_schedule,
                    eps=1e-5,
                ),
            )
            
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
                optax.adam(
                    learning_rate=config["LR"],
                    eps=1e-5,
                ),
            )
            
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )
        
        # INIT ENV
        rng, _rng = jax.random.split(rng)
        reset_rng = jax.random.split(_rng, config["NUM_ENVS"])
        obsv, env_state = env.reset(reset_rng, env_params)
        
        # TRAIN LOOP
        def _update_step(runner_state, unused):
            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                train_state, env_state, last_obs, rng = runner_state
                
                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                pi, value = network.apply(train_state.params, last_obs)
                #print(f"pi: {pi.shape}")
                #print(f"last_obs: {last_obs.shape}")
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                
                # STEP ENV
                rng, _rng = jax.random.split(rng)
                rng_step = jax.random.split(_rng, config["NUM_ENVS"])
                obsv, env_state, reward, done, info = env.step(
                    rng_step, env_state, action, env_params,
                )    
                transition = Transition(
                    done, action, value, reward, log_prob, last_obs, info
                )
                runner_state = (train_state, env_state, obsv, rng)
                return runner_state, transition
            
            runner_state, traj_batch = jax.lax.scan(
                _env_step, 
                runner_state, 
                None, 
                length=config["NUM_STEPS"]
            )
            
            # COMPUTE ADVANTAGES
            train_state, env_state, last_obs, rng = runner_state
            _, last_val = network.apply(train_state.params, last_obs)
            
            def _calculate_gae(traj_batch, last_val):
                def _get_advantages(gae_and_next_value, transition):
                    gae, next_value = gae_and_next_value
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    delta = reward + config["GAMMA"] * next_value * (1 - done) - value
                    gae = (
                        delta
                        + config["GAMMA"] * config["GAE_LAMBDA"] * (1 - done) * gae
                    )
                    return (gae, value), gae
                
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                
                return advantages, advantages + traj_batch.value
            
            advantages, targets = _calculate_gae(traj_batch, last_val)
            
            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                def _update_minibatch(train_state, batch_info):
                    traj_batch, advantages, targets = batch_info
                    
                    def _loss_fn(params, traj_batch, gae, targets):
                        # RERUN NETWORK
                        pi, value = network.apply(params, traj_batch.obs)
                        log_prob = pi.log_prob(traj_batch.action)
                        
                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (
                            value - traj_batch.value
                        ).clip(-config["CLIP_EPS"], config["CLIP_EPS"])
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.mean(jnp.maximum(value_losses, value_losses_clipped))
                        
                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                            ratio, 
                            1.0 - config["CLIP_EPS"], 
                            1.0 + config["CLIP_EPS"]
                            ) 
                            * gae
                        )
                        loss_actor = -jnp.mean(jnp.minimum(loss_actor1, loss_actor2))
                        entropy = jnp.mean(pi.entropy())
                        
                        total_loss = (
                            loss_actor
                            + config["VF_COEF"] * value_loss
                            - config["ENT_COEF"] * entropy
                        )
                        return total_loss, (value_loss, loss_actor, entropy)
                    
                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grad = grad_fn(
                        train_state.params, 
                        traj_batch, 
                        advantages, 
                        targets)
                    train_state = train_state.apply_gradients(grads=grad)
                    return train_state, total_loss
                
                train_state, traj_batch, advantages, targets, rng = update_state
                rng, _rng = jax.random.split(rng)
                batch_size = config["MINIBATCH_SIZE"] * config["NUM_MINIBATCHES"]
                assert (
                    batch_size == config["NUM_ENVS"] * config["NUM_STEPS"]
                ), "batch size must be equal to number of steps * number of envs"
                
                permutation = jax.random.permutation(_rng, batch_size)
                batch = (traj_batch, advantages, targets)
                batch = jax.tree_util.tree_map(
                    lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
                )
                shuffled_batch = jax.tree_util.tree_map(
                    lambda x: jnp.take(x, permutation, axis=0), batch   
                )
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.reshape(
                        x, [config["NUM_MINIBATCHES"], -1] + list(x.shape[1:])
                    ),
                    shuffled_batch,
                )
                
                train_state, total_loss = jax.lax.scan(
                    _update_minibatch,
                    train_state,
                    minibatches,
                )
                update_state = (train_state, traj_batch, advantages, targets, rng)
                return update_state, total_loss
            
            update_state = (train_state, traj_batch, advantages, targets, rng)
            update_state, loss_info = jax.lax.scan(
                _update_epoch,
                update_state,
                None,
                length=config["UPDATE_EPOCHS"], 
            )
            train_state = update_state[0]
            metric = traj_batch.info
            rng = update_state[-1]
            if config.get("DEBUG"):
                
                def callback(info):
                    return_values = info["returned_episode_returns"][
                        info["returned_episode"]
                    ]
                    timesteps = (
                        info["timestep"][info["returned_episode"]] * config["NUM_ENVS"]
                    )
                    for t in range(len(timesteps)):
                        print(
                            f"global step={timesteps[t]}, episodic return={return_values[t]}"
                        )
                        
                jax.debug.callback(callback, metric)
                
            runner_state = (train_state, env_state, last_obs, rng)
            return runner_state, metric
        
        rng, _rng = jax.random.split(rng)
        runner_state = (train_state, env_state, obsv, _rng)
        runner_state, metric = jax.lax.scan(
            _update_step,
            runner_state,
            None,
            length=config["NUM_UPDATES"],
        )
        return {"runner_state": runner_state, "metrics": metric, "net": network, "env": env}
    
    return train        

In [32]:
p = (jnp.array([[3,4,5], [4,5,2]]), jnp.array([1,2,3]))

In [6]:
config = {
    "LR": 1e-3,
    "NUM_ENVS": 2048,
    "NUM_STEPS": 10,
    "TOTAL_TIMESTEPS": 2e6,
    "UPDATE_EPOCHS": 4,
    "NUM_MINIBATCHES": 32,
    "GAMMA": 0.99,
    "GAE_LAMBDA": 0.95,
    "CLIP_EPS": 0.2,
    "ENT_COEF": 0.0,
    "VF_COEF": 0.5,
    "MAX_GRAD_NORM": 0.5,
    "ACTIVATION": "tanh",
    "ENV_NAME": "ant",
    "ANNEAL_LR": False,
    "NORMALIZE_ENV": True,
    "DEBUG": True,
}
rng = jax.random.PRNGKey(30)
#train_jit = jax.jit(make_train(config))
#out = train_jit(rng)
train = make_train(config)
out = train(rng)

global step=4096, episodic return=1.1765235662460327
global step=4096, episodic return=2.3092637062072754
global step=4096, episodic return=-2.4587533473968506
global step=4096, episodic return=-0.10262155532836914
global step=4096, episodic return=-1.799912691116333
global step=4096, episodic return=-0.09559839963912964
global step=4096, episodic return=-0.22978943586349487
global step=4096, episodic return=1.35135817527771
global step=4096, episodic return=-1.0551214218139648
global step=4096, episodic return=1.3405739068984985
global step=4096, episodic return=1.0146634578704834
global step=4096, episodic return=0.9090223908424377
global step=4096, episodic return=-0.855269730091095
global step=4096, episodic return=-2.0903780460357666
global step=4096, episodic return=1.2595715522766113
global step=4096, episodic return=2.4156792163848877
global step=4096, episodic return=-1.8300812244415283
global step=4096, episodic return=0.19900977611541748
global step=4096, episodic return=0.8

In [7]:
net = out["net"]

In [9]:
state

NameError: name 'state' is not defined

In [8]:
pi, _ = net.apply(out["runner_state"][0][0].params, jnp.zeros(24))

TypeError: 'TrainState' object is not subscriptable

In [7]:
state = out["runner_state"][1]

In [13]:
state.env_state.mean.shape

(17,)

In [8]:
predefined_mean = state.env_state.mean
predefined_var = state.env_state.var
predefined_count = state.env_state.count

In [9]:
import brax
from brax.io import html

In [10]:
predefined_stats = (predefined_mean, predefined_var, predefined_count)

In [11]:
predefined_mean.shape

(17,)

In [11]:
from wrappers import NormalizeVecObservationEval

In [12]:
env = BraxGymnaxWrapper(config["ENV_NAME"])
env = LogWrapper(env)
env = ClipAction(env)
if config["NORMALIZE_ENV"]:
    env = NormalizeVecObservationEval(env, predefined_stats)
    #env = NormalizeVecReward(env, config["GAMMA"])


In [13]:
reset_fn = jax.jit(env.reset)
step_fn = jax.jit(env.step)

In [7]:
net = out["net"]
params = out["runner_state"][0].params
obs  =out["runner_state"][2]

In [9]:
params

{'params': {'Dense_0': {'bias': Array([-2.42214352e-01,  4.36511561e-02, -6.99268878e-02, -5.68671562e-02,
           5.01498461e-01,  2.28097379e-01, -1.47791967e-01,  9.38449129e-02,
           2.94316471e-01, -2.09683150e-01,  1.20437881e-02, -6.76225945e-02,
          -8.09601396e-02, -4.19561446e-01,  5.42369671e-02, -4.72876169e-02,
          -1.32228479e-01,  1.10574037e-01, -1.19805247e-01, -8.78027603e-02,
          -1.37756452e-01,  1.21544302e-01,  2.21318379e-01,  1.31506696e-01,
          -2.40932014e-02, -1.20049134e-01, -1.72823016e-02, -1.28707066e-01,
           1.31420875e-02,  1.95838690e-01, -3.77267003e-02,  8.10854360e-02,
           2.08163321e-01,  2.23521605e-01,  3.21230255e-02,  3.49557161e-01,
          -9.94042829e-02,  3.68670598e-02, -3.86249423e-01, -8.68967101e-02,
           2.70827591e-01,  2.67663330e-01,  1.20781876e-01,  7.94627964e-02,
           2.28148699e-02,  1.88086599e-01, -3.62034470e-01,  6.36563823e-02,
          -3.31332088e-01, -1.07110

In [10]:
def flatten_parameters(params, flat_params=None):
    if flat_params is None:
        flat_params = []
    
    for key, value in params.items():
        if isinstance(value, dict):
            # Recursive call to handle nested dictionaries
            flatten_parameters(value, flat_params)
        elif key in ['bias', 'kernel', 'log_std']:
            # Flatten and append the parameters if they match expected keys
            print(f"Including {key} with shape {value.shape}")  # Debug: Confirm these parameters are included
            flat_params.append(np.ravel(value))
        else:
            print(f"Skipping {key}")  # Debug: Notice skipped params or incorrect structures

    return np.concatenate(flat_params) if flat_params else np.array([])

In [13]:
def flatten_policy_parameters(params, flat_params=None, metadata=None):
    if flat_params is None:
        flat_params = []
    if metadata is None:
        metadata = []
    
    for key, value in params.items():
        if isinstance(value, dict):
            # Recursive call to handle nested dictionaries
            flatten_policy_parameters(value, flat_params, metadata)
        elif key in ['bias', 'kernel', 'log_std']:
            # Flatten and append the parameters if they match expected keys
            flat_params.append(jnp.ravel(value))
            metadata.append((key, value.shape))  # Store metadata for reconstruction
        else:
            continue

    return jnp.concatenate(flat_params) if flat_params else jnp.array([]), metadata

In [14]:
flat_params, metadata = flatten_policy_parameters(params)

In [18]:

def reconstruct_policy_parameters(flat_params, metadata):
    current_index = 0
    params = {}
    for key, shape in metadata:
        size = jnp.prod(jnp.array(shape))  # Convert shape tuple to JAX array before calling jnp.prod()
        param_values = flat_params[current_index:current_index + size].reshape(shape)
        current_index += size
        keys = key.split('/')
        d = params
        for subkey in keys[:-1]:
            if subkey not in d:
                d[subkey] = {}
            d = d[subkey]
        d[keys[-1]] = param_values
    return params

In [19]:
orig = reconstruct_policy_parameters(flat_params, metadata)

In [20]:
orig

{'bias': Array([0.26958966], dtype=float32),
 'kernel': Array([[ 0.317764  ],
        [-0.09797648],
        [-0.05949139],
        [ 0.41761887],
        [ 0.109753  ],
        [ 0.04838854],
        [ 0.25207198],
        [ 0.16589098],
        [-0.08557546],
        [-0.21562658],
        [ 0.06127341],
        [-0.20117277],
        [ 0.01811008],
        [ 0.04639089],
        [ 0.18783645],
        [-0.17479779],
        [ 0.07862838],
        [-0.0796316 ],
        [-0.41177702],
        [-0.15078193],
        [ 0.0653279 ],
        [ 0.07293047],
        [-0.26354405],
        [-0.05977688],
        [-0.02185716],
        [ 0.03255723],
        [ 0.10047881],
        [-0.06634797],
        [ 0.05430691],
        [ 0.03652165],
        [ 0.15704912],
        [-0.09669256],
        [-0.1966002 ],
        [-0.06464773],
        [-0.18955734],
        [-0.10910282],
        [-0.29754972],
        [ 0.282111  ],
        [-0.06569406],
        [-0.05555981],
        [ 0.06132558],
  

In [11]:
flat_params = flatten_parameters(params)

Including bias with shape (64,)
Including kernel with shape (27, 64)
Including bias with shape (64,)
Including kernel with shape (64, 64)
Including bias with shape (8,)
Including kernel with shape (64, 8)
Including bias with shape (64,)
Including kernel with shape (27, 64)
Including bias with shape (64,)
Including kernel with shape (64, 64)
Including bias with shape (1,)
Including kernel with shape (64, 1)
Including log_std with shape (8,)


In [12]:
flat_params.shape

(12497,)

In [28]:
pi, val = net.apply(params, obs)

In [40]:
jnp.zeros(shape=(()))

Array(0., dtype=float32)

In [22]:
type(pi.mean())

jaxlib.xla_extension.ArrayImpl

In [25]:
pi.event_shape[0]

8

In [27]:
pi.stddev().shape

(2048, 8)

In [15]:
import sys

sys.getsizeof(pi)

48

In [20]:
num = 1
rng = jax.random.PRNGKey(num)
#reset_rng = jax.random.split(rng, config["NUM_ENVS"])
#obs, state = reset_fn(reset_rng, None)
obs, state = reset_fn(rng, None)
#obs = jnp.expand_dims(obs, axis=0)
rollout = [state]
total_reward = 0
steps = 0
s = 0
while True:
    s += 1
    steps += 1
    print(f"Step: {steps}")
    rng, _rng, __rng = jax.random.split(rng, 3)
    #pi, _ = net.apply(params, state.obs)
    pi, _ = net.apply(params, obs)
    action = pi.sample(seed=_rng)
    start_time = time.time()   
    #state = step_fn(state, action)
    #__rng_set = jax.random.split(__rng, config["NUM_ENVS"])
    obs, state, reward, done, _ = step_fn(__rng, state, action) 
    #obs = jnp.expand_dims(obs, axis=0)
    print(f"Step time: {time.time() - start_time}")
    print(done)
    print(f"Reward: {reward}")
    total_reward += reward
    
    #if state.done:
    #    break
    #else:
    #    rollout.append(state)
    if done.any():
        break
    else:
        rollout.append(state)    
    #if s > 1000:
    #    break
    #else:
        
print(f"Total reward: {total_reward}")
    

Step: 1
Step time: 0.0021295547485351562
False
Reward: -1.434090495109558
Step: 2
Step time: 0.0021266937255859375
False
Reward: -1.1268305778503418
Step: 3
Step time: 0.002136707305908203
False
Reward: -0.39751946926116943
Step: 4
Step time: 0.002091646194458008
False
Reward: -1.2285380363464355
Step: 5
Step time: 0.0021169185638427734
False
Reward: -1.4133044481277466
Step: 6
Step time: 0.002138853073120117
False
Reward: -0.9259083271026611
Step: 7
Step time: 0.0020647048950195312
False
Reward: -1.9490478038787842
Step: 8
Step time: 0.0020248889923095703
False
Reward: -0.8639781475067139
Step: 9
Step time: 0.0020842552185058594
False
Reward: -1.0794166326522827
Step: 10
Step time: 0.002033710479736328
False
Reward: -1.9556195735931396
Step: 11
Step time: 0.002043485641479492
False
Reward: -2.5373587608337402
Step: 12
Step time: 0.0020668506622314453
False
Reward: -1.479681372642517
Step: 13
Step time: 0.002059459686279297
False
Reward: -1.2270903587341309
Step: 14
Step time: 0.002089

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7960b07dc4f0>>
Traceback (most recent call last):
  File "/vol/bitbucket/km2120/me-with-sample-based-drl/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 790, in _clean_thread_parent_frames
    active_threads = {thread.ident for thread in threading.enumerate()}
  File "/vol/bitbucket/km2120/me-with-sample-based-drl/venv/lib/python3.10/site-packages/ipykernel/ipkernel.py", line 790, in <setcomp>
    active_threads = {thread.ident for thread in threading.enumerate()}
  File "/usr/lib/python3.10/threading.py", line 1145, in ident
    @property
KeyboardInterrupt: 


Step time: 0.0018038749694824219
False
Reward: 2.367222309112549
Step: 675
Step time: 0.0015940666198730469
False
Reward: 2.5170938968658447
Step: 676
Step time: 0.0018932819366455078
False
Reward: 2.3525867462158203
Step: 677
Step time: 0.0012950897216796875
False
Reward: 2.1230740547180176
Step: 678
Step time: 0.0013360977172851562
False
Reward: 1.0092742443084717
Step: 679
Step time: 0.0013973712921142578
False
Reward: 1.1878817081451416
Step: 680
Step time: 0.0013375282287597656
False
Reward: 0.9040946960449219
Step: 681
Step time: 0.001310110092163086
False
Reward: -0.28569984436035156
Step: 682
Step time: 0.0013506412506103516
False
Reward: 0.4978671073913574
Step: 683
Step time: 0.0013289451599121094
False
Reward: 0.686267614364624
Step: 684
Step time: 0.0013887882232666016
False
Reward: -0.0036864280700683594
Step: 685
Step time: 0.0013189315795898438
False
Reward: 0.8092391490936279
Step: 686
Step time: 0.0012969970703125
False
Reward: -0.1736454963684082
Step: 687
Step time: 

In [17]:
obs.shape

(17,)

In [27]:
a = html.render(env.sys, [state.env_state.env_state.pipeline_state for state in rollout])

In [28]:
with open(f'visualization_ant.html', 'w') as file:
    file.write(a)

In [None]:
display(HTML("<div style='color:red'>Hello World!</div>"))

In [None]:
env, env_params = BraxGymnaxWrapper(config["ENV_NAME"]), None

rollout = []
rewards = []

rng, _rng = jax.random.split(rng)
env_state = jax.jit(env.reset)(_rng)
done = False
while not done:
    rollout.append(env_state)
    pi, _ = out["runner_state"][0].apply(out["runner_state"][0].params)
    action = pi.sample(seed=_rng)
    rng, _rng = jax.random.split(rng)
    env_state, _, reward, done, _ = env.step(_rng, env_state, action, None)
    rewards.append(reward)
    
a = html.render(env.sys, [s.qp for s in rollout])
HTML(a)

AttributeError: 'TrainState' object has no attribute 'apply'

In [None]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 1000
        
        env = environments_v1.create(env_name, episode_length=episode_length)
    elif env_name == "halfcheetah_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=True)
    elif env_name == "humanoid_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)	
    '''
    elif env_name == "ant_omni":
        episode_length = 250
        max_bd = 30.

        env = environments.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
    elif env_name == "humanoid_uni":
        episode_length = 1000
        max_bd = 1.

        env = environments.create(env_name, episode_length=episode_length)	
    else:
        ValueError(f"Environment {env_name} not supported.")
    '''
    return env

In [None]:
config_dict = {
    "no_agents": 32,
    "batch_size": 32 * 1000,
    "mini_batch_size": 32000,
    "no_epochs": 10,
    "learning_rate": 3e-4,
    "discount_rate": 0.99,
    "clip_param": 0.2,
    "vf_coef": 0.5,
    "gae_lambda": 0.95,
    "env_name": "halfcheetah_uni",
}

# Initialize wandb with the configuration dictionary
wandb.init(project="mcpg", name='PPOish', config=config_dict)

env = get_env(config_dict["env_name"])


policy_hidden_layers = [64, 64]
value_hidden_layers = [64, 64]

policy = MLP(
    hidden_layers_size=policy_hidden_layers,
    action_size=env.action_size,
    activation=nn.tanh,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    mean_init=jax.nn.initializers.orthogonal(scale=0.01),
)

value_net = ValueNet(
    hidden_layers_size=value_hidden_layers,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    value_init=jax.nn.initializers.orthogonal(scale=1.),
    activation=nn.tanh,
)

agent = MCPG(Config(**wandb.config), policy, value_net, env)

random_key = jax.random.PRNGKey(0)
train_state_policy, train_state_value = agent.init(random_key)

num_steps = 1000
log_period = 10

metrics_wandb = dict.fromkeys(["mean loss", "mean reward", "mask", "evaluation", 'time'], jnp.array([]))
eval_num = config_dict["no_agents"]
print(f"Number of evaluations per training step: {eval_num}")
start_time = time.time()
for i in range(num_steps // log_period):
    random_key, subkey = jax.random.split(random_key)
    train_state_policy, train_state_value, current_metrics = agent.train(subkey, train_state_policy, train_state_value, log_period, eval=False)
    timelapse = time.time() - start_time
    print(f"Step {(i+1) * log_period}, Time: {timelapse}")
    
    current_metrics["evaluation"] = jnp.arange(log_period*eval_num*(i+1), log_period*eval_num*(i+2), dtype=jnp.int32)
    current_metrics["time"] = jnp.repeat(timelapse, log_period)
    current_metrics["mean loss"] = jnp.repeat(jnp.mean(current_metrics["loss"]), log_period)
    current_metrics["mean reward"] = jnp.repeat(jnp.mean(jnp.sum(current_metrics["reward"], axis=-1)), log_period)
    current_metrics["mask"] = jnp.repeat(jnp.mean(current_metrics["mask"]), log_period)
    '''
    metrics_wandb = jax.tree_util.tree_map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics_wandb, current_metrics)
    
    log_metrics = jax.tree_util.tree_map(lambda metric: metric[-1], metrics_wandb)
    
    wandb.log(log_metrics)
    '''
    
    def update_metrics(old_metrics, new_metrics):
        updated_metrics = {}
        for key in old_metrics:
            if key in new_metrics:
                # Check if old metrics for key is empty, and initialize properly if so
                if old_metrics[key].size == 0:
                    updated_metrics[key] = new_metrics[key]
                else:
                    updated_metrics[key] = jnp.concatenate([old_metrics[key], new_metrics[key]], axis=0)
            else:
                raise KeyError(f"Key {key} not found in new metrics.")
        return updated_metrics

    # In your training loop:
    try:
        metrics_wandb = update_metrics(metrics_wandb, current_metrics)
        log_metrics = {k: v[-1] for k, v in metrics_wandb.items()}  # Assuming you want the latest entry
        wandb.log(log_metrics)
    except Exception as e:
        print(f"Error updating metrics: {e}")

    
    
    
    
    start_time = time.time()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mk_mitsides[0m ([33mmitsides[0m). Use [1m`wandb login --relogin`[0m to force relogin


NameError: name 'MLP' is not defined