In [13]:
from dataclasses import dataclass
from typing import Any, Union, Tuple, Callable, Optional
from functools import partial
import time

import jax
import jax.numpy as jnp
import flax
import flax.linen as nn
from flax.training.train_state import TrainState
import optax

from qdax import environments, environments_v1
from jax import random
import wandb

import pickle
from optax import exponential_decay
from IPython.display import HTML
from brax.io import html
import os
import jax.debug
import matplotlib.pyplot as plt
%matplotlib inline

In [14]:
def get_env(env_name):
    if env_name == "hopper_uni":
        episode_length = 1000
        
        env = environments_v1.create(env_name, episode_length=episode_length)
    elif env_name == "halfcheetah_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)
        
    elif env_name == "walker2d_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length)	
    elif env_name == "ant_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=True)
    elif env_name == "humanoid_uni":
        episode_length = 1000

        env = environments_v1.create(env_name, episode_length=episode_length, exclude_current_positions_from_observation=True)	
    '''
    elif env_name == "ant_omni":
        episode_length = 250
        max_bd = 30.

        env = environments.create(env_name, episode_length=episode_length, use_contact_forces=False, exclude_current_positions_from_observation=False)	
    elif env_name == "humanoid_uni":
        episode_length = 1000
        max_bd = 1.

        env = environments.create(env_name, episode_length=episode_length)	
    else:
        ValueError(f"Environment {env_name} not supported.")
    '''
    return env

In [15]:
EPS = 1e-8
class MLP(nn.Module):
    """MCPG MLP module"""
    hidden_layers_size: Tuple[int, ...]
    action_size: int
    activation: Callable[[jnp.ndarray], jnp.ndarray] = nn.tanh
    bias_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.zeros
    hidden_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()
    mean_init: Callable[[jnp.ndarray, Any], jnp.ndarray] = jax.nn.initializers.lecun_uniform()
    
    def setup(self):
        self.hidden_layers = [nn.Dense(features, kernel_init=self.hidden_init, bias_init=self.bias_init) for features in self.hidden_layers_size]
        self.mean = nn.Dense(self.action_size, kernel_init=self.mean_init, bias_init=self.bias_init)
        self.log_std = self.param("log_std", lambda _, shape: jnp.log(0.5)*jnp.ones(shape), (self.action_size,))
        
    def distribution_params(self, obs: jnp.ndarray):
        hidden = obs
        for hidden_layer in self.hidden_layers:
            hidden = self.activation(hidden_layer(hidden))
            
        mean = self.mean(hidden)
        log_std = self.log_std
        std = jnp.exp(log_std)
        
        return mean, log_std, std
    
    def logp(self, obs: jnp.ndarray, action: jnp.ndarray):
        mean, _, std = self.distribution_params(obs)
        logp = jax.scipy.stats.norm.logpdf(action, mean, std)
        return logp.sum(axis=-1)
    
    def __call__(self, random_key: Any, obs: jnp.ndarray):
        mean, _, std = self.distribution_params(obs)
        
        # Sample action
        rnd = jax.random.normal(random_key, shape = (self.action_size,))
        action = jax.lax.stop_gradient(mean + rnd * std)
        
        logp = jnp.sum(jax.scipy.stats.norm.logpdf(action, mean, std), axis=-1) 
                
        return action, logp
    

In [19]:
@dataclass
class Config:
    no_agents: int = 256
    batch_size: int = 256000
    mini_batch_size: int = 32000
    no_epochs: int = 10
    learning_rate: float = 3e-4
    discount_rate: float = 0.99
    clip_param: float = 0.2
    env_name: str = "ant_uni"
    
class MCPG:
    
    def __init__(self, config, policy, env):
        self._config = config
        self._policy = policy
        self._env = env
        
    def init(self, random_key):
        random_key_1, random_key_2 = jax.random.split(random_key)
        fake_obs = jnp.zeros(shape=(self._env.observation_size,))
        params = self._policy.init(random_key_1, random_key_2, fake_obs)
        tx = optax.adam(learning_rate=self._config.learning_rate)
        
        return TrainState.create(apply_fn=self._policy.apply, params=params, tx=tx)
    
    @partial(jax.jit, static_argnames=("self",))
    def logp_fn(self, params, obs, action):
        return self._policy.apply(params, obs, action, method=self._policy.logp)
    
    @partial(jax.jit, static_argnames=("self",))
    def sample_step(self, random_key, train_state, env_state):
        """Samples one step in the environment and returns the next state, action and 
        log-prob of the action.
        """
        
        action, action_logp = train_state.apply_fn(train_state.params, random_key, env_state.obs)
        
        next_env_state = self._env.step(env_state, action)
        
        return next_env_state, action, action_logp
    
    @partial(jax.jit, static_argnames=("self","evaluate"))
    def sample_trajectory(self, random_key, train_state, evaluate=False):
        """Samples a full trajectory using the environment and policy."""
        if evaluate:
            length = self._env.episode_length
        else:
            length = int(self._config.batch_size / self._config.no_agents)
            
        random_keys = jax.random.split(random_key, length+1)
        env_state_init = self._env.reset(random_keys[-1])
        
        def _scan_sample_step(carry, x):
            (train_state, env_state,) = carry
            (random_key, ) = x
            
            next_env_state, action, action_logp = self.sample_step(random_key, train_state, env_state)
            return (train_state, next_env_state), (env_state.obs, action, action_logp, next_env_state.reward, env_state.done, env_state.info["state_descriptor"])
        
        _, (obs, action, action_logp, reward, done, state_desc) = jax.lax.scan(
            _scan_sample_step, 
            (train_state, env_state_init), 
            (random_keys[:length],),
            length=length,
            )
        
        mask = 1. - jnp.clip(jnp.cumsum(done), a_min=0., a_max=1.)
        
        return obs, action, action_logp, reward, state_desc, mask
    
    
    @partial(jax.jit, static_argnames=("self",))
    def get_return(self, reward):
        """ Computes the discounted return for each step in the trajectory.
        """
        
        def _body(carry, x):
            (next_return,) = carry
            (reward,) = x
            
            current_return = reward + self._config.discount_rate * next_return
            return (current_return,), (current_return,)
        
        _, (return_,) = jax.lax.scan(
            _body,
            (jnp.array(0.),),
            (reward,),
            length=int(self._config.batch_size / self._config.no_agents),
            reverse=True,
            )
            
        return return_
    
    @partial(jax.jit, static_argnames=("self",))
    def standardize(self, return_):
        return jax.nn.standardize(return_, axis=0, variance=1, epsilon=EPS)
    
    @partial(jax.jit, static_argnames=("self",))
    def get_return_standardize(self, reward, mask):
        """Standardizes the return values for stability in training
        """
        return_ = jax.vmap(self.get_return)(reward * mask)
        return self.standardize(return_)
    
    @partial(jax.jit, static_argnames=("self",))
    def loss_rein(self, params, obs, action, mask, return_standardized):
        """ REINFORCE loss function.
        """
        logp_ = self.logp_fn(params, jax.lax.stop_gradient(obs), jax.lax.stop_gradient(action))
        return -jnp.mean(jnp.multiply(logp_ * mask, jax.lax.stop_gradient(return_standardized)))
    
    @partial(jax.jit, static_argnames=("self",))
    def loss_ppo(self, params, obs, action, logp, mask, return_standardized):
        """ PPO loss function.
        """
        
        logp_ = self.logp_fn(params, jax.lax.stop_gradient(obs), jax.lax.stop_gradient(action))
        ratio = jnp.exp(logp_ - jax.lax.stop_gradient(logp))
        
        pg_loss_1 = jnp.multiply(ratio * mask, jax.lax.stop_gradient(return_standardized))
        pg_loss_2 = jax.lax.stop_gradient(return_standardized) * jax.lax.clamp(1. - self._config.clip_param, ratio, 1. + self._config.clip_param) * mask
        return -jnp.mean(jnp.minimum(pg_loss_1, pg_loss_2))
    

    @partial(jax.jit, static_argnames=("self",))
    def flatten_trajectory(self, obs, action, logp, mask, return_standardized):
        # Calculate the total number of elements in the combined first two dimensions
        total_elements = obs.shape[0] * obs.shape[1]
        
        new_obs_shape = (total_elements,) + obs.shape[2:]  # obs.shape[2:] should be unpacked if it's a tuple
        new_action_shape = (total_elements,) + action.shape[2:]  # Same handling as for obs
            
        # Flatten the first two dimensions
        obs = jnp.reshape(obs, new_obs_shape)
        action = jnp.reshape(action, new_action_shape)
        logp = jnp.reshape(logp, (total_elements,))
        mask = jnp.reshape(mask, (total_elements,))
        return_standardized = jnp.reshape(return_standardized, (total_elements,))
                
        return obs, action, logp, mask, return_standardized
    
    @partial(jax.jit, static_argnames=("self",))
    def train_step(self, random_key, train_state):
        # Sample trajectories
        random_keys = jax.random.split(random_key, self._config.no_agents+1)
        start_time = time.time()
        obs, action, logp, reward, _, mask = jax.vmap(self.sample_trajectory, in_axes=(0, None))(random_keys[:self._config.no_agents], train_state)
        
        time_elapsed = time.time() - start_time
        
        # Compute standaerdized return
        print(f"Reward before passing through the get_return_standardize{obs.shape}")
        return_standardized = self.get_return_standardize(reward, mask)
        
       # print(f"Before flattening{obs.shape}")
        
        obs_, action_, logp_, mask_, return_standardized_ = self.flatten_trajectory(obs, action, logp, mask, return_standardized)
        
        #print(f"After flattening{obs_.shape}")
        #b_inds = random.permutation(random_keys[-1], self._config.batch_size)
        
        random_keys_ = jax.random.split(random_keys[-1], self._config.no_epochs)
        
        def _scan_epoch_train(carry, x):
            (train_state,) = carry
            (random_key,) = x
            
            (train_state,), losses = self.epoch_train(random_key, train_state, obs_, action_, logp_, mask_, return_standardized_)
            
            return (train_state,), losses
        
        #print("Before _scan_epoch_train")
        
        #print(type(train_state))
        final_train_state, losses = jax.lax.scan(
            _scan_epoch_train,
            (train_state,),
            (random_keys_,),
            length=self._config.no_epochs,
            )
        
        #print("After _scan_epoch_train")
        #print(type(final_train_state[0]))
        metrics = {
            "loss" : losses,
            "reward" : reward * mask,
            "mask" : mask
        }
        jax.debug.print("Mean Loss: {}", jnp.mean(metrics["loss"]))
        jax.debug.print("Mean Reward: {}", jnp.mean(jnp.sum(metrics["reward"], axis=-1)))
        jax.debug.print("-" * 50)
        
        return (final_train_state[0],), (metrics,)
    
    '''
    @partial(jax.jit, static_argnames=("self",))
    def epoch_train(self, random_key, train_state, obs, action, logp, mask, return_standardized):
        b_inds = random.permutation(random_key, self._config.batch_size)
        
        
        def _scan_mini_train(carry, _):
            (train_state, counter) = carry
        
            idx = b_inds[counter * self._config.mini_batch_size : (counter+1) * self._config.mini_batch_size]
            loss, grad = jax.value_and_grad(self.loss_ppo)(train_state.params, obs[idx], action[idx], logp[idx], mask[idx], return_standardized[idx])
            new_train_state = train_state.apply_gradients(grads=grad)  
            return (new_train_state, counter+1), loss
        
        final_train_state, losses = jax.lax.scan(
            _scan_mini_train,
            (train_state, 0),
            None,
            length=self._config.batch_size // self._config.mini_batch_size,
            )
        
        return (final_train_state,), (losses,)
    '''
    
    @partial(jax.jit, static_argnames=("self",))
    def epoch_train(self, random_key, train_state, obs, action, logp, mask, return_standardized):
        total_size = self._config.batch_size
        
        shuffled_indices = jax.random.permutation(random_key, total_size)
        
        num_batches = self._config.batch_size // self._config.mini_batch_size
        batch_indices = jnp.array([shuffled_indices[i * self._config.mini_batch_size:(i + 1) * self._config.mini_batch_size] for i in range(num_batches)])
        #jax.debug.print("Returns: {}", return_standardized)
        #jax.debug.print("Returns shape: {}", return_standardized.shape)
        def _scan_mini_train(carry, idx):
            (train_state, counter) = carry
            #jax.debug.print("Returns_: {}", return_standardized[idx])
            #jax.debug.print("Returns_ shape: {}", return_standardized[idx].shape)
            #jax.debug.print("Obs: {}", obs[idx].shape)
            loss, grad = jax.value_and_grad(self.loss_ppo)(train_state.params, obs[idx], action[idx], logp[idx], mask[idx], return_standardized[idx])
            new_train_state = train_state.apply_gradients(grads=grad)  
            return (new_train_state, counter+1), loss
        
        #print('Before _scan_mini_train')
        final_train_state, losses = jax.lax.scan(
            _scan_mini_train,
            (train_state, 0),
            batch_indices,
            length=num_batches,
            )
        
        #print('After _scan_mini_train')
        
        return (final_train_state[0],), losses
    
    @partial(jax.jit, static_argnames=("self", "no_steps", "eval"))
    def train(self, random_key, train_state, no_steps, eval=False):
        """Trains the policy for a number of steps."""
        
        random_keys = jax.random.split(random_key, no_steps+1)
    

        def _scan_train_step(carry, x):
            (train_state,) = carry
            (random_key,) = x
            
            (train_state,), (metrics,) = self.train_step(random_key, train_state)
            
            return (train_state,), (metrics,)
        
        #print('Before  _scan_train_step')
        
        (train_state,), (metrics,) = jax.lax.scan(
            _scan_train_step,
            (train_state,),
            (random_keys[:no_steps],),
            length=no_steps,
            )
        
        #print('After  _scan_train_step')
        if eval:
            mean_reward = self.evaluate(random_keys[-1], train_state)
            jax.debug.print("Mean Reward over 20 episodes: {}", mean_reward)
            return mean_reward
        
        return train_state, metrics
    
    @partial(jax.jit, static_argnames=("self",))
    def evaluate(self, random_key, train_state):
        """Evaluates the policy in the environment."""
        random_keys = jax.random.split(random_key, 20)
        
        def _scan_evaluate(carry, _):
            (train_state,) = carry
            
            obs, _, _, reward, _, mask = jax.vmap(self.sample_trajectory, in_axes=(0, None, None))(random_keys, train_state, True)
            return (train_state,), (reward * mask,)
        
        (train_state,), (reward,) = jax.lax.scan(
            _scan_evaluate,
            (train_state,),
            None,
            length=20,
            )
        
        return jnp.mean(jnp.sum(reward, axis=-1))


In [20]:
config_dict = {
    "no_agents": 256,
    "batch_size": 1024 * 256,
    "mini_batch_size": 1024 * 256,
    "no_epochs": 16,
    "learning_rate": 1e-3,
    "discount_rate": 0.99,
    "clip_param": 0.2,
    "env_name": "ant_uni",
}

# Initialize wandb with the configuration dictionary
wandb.init(project="mcpg", name='PPOish', config=config_dict)

env = get_env(config_dict["env_name"])


policy_hidden_layers = [128, 128]

policy = MLP(
    hidden_layers_size=policy_hidden_layers,
    action_size=env.action_size,
    activation=nn.tanh,
    hidden_init=jax.nn.initializers.orthogonal(scale=jnp.sqrt(2)),
    mean_init=jax.nn.initializers.orthogonal(scale=0.01),
)

agent = MCPG(Config(**wandb.config), policy, env)

random_key = jax.random.PRNGKey(0)
train_state = agent.init(random_key)

num_steps = 1000
log_period = 10

metrics_wandb = dict.fromkeys(["mean loss", "mean reward", "mask", "evaluation", 'time'], jnp.array([]))
eval_num = config_dict["no_agents"]
print(f"Number of evaluations per training step: {eval_num}")
start_time = time.time()
for i in range(num_steps // log_period):
    random_key, subkey = jax.random.split(random_key)
    train_state, current_metrics = agent.train(subkey, train_state, log_period)
    timelapse = time.time() - start_time
    print(f"Step {(i+1) * log_period}, Time: {timelapse}")
    
    current_metrics["evaluation"] = jnp.arange(log_period*eval_num*(i+1), log_period*eval_num*(i+2), dtype=jnp.int32)
    current_metrics["time"] = jnp.repeat(timelapse, log_period)
    current_metrics["mean loss"] = jnp.repeat(jnp.mean(current_metrics["loss"]), log_period)
    current_metrics["mean reward"] = jnp.repeat(jnp.mean(jnp.sum(current_metrics["reward"], axis=-1)), log_period)
    current_metrics["mask"] = jnp.repeat(jnp.mean(current_metrics["mask"]), log_period)
    '''
    metrics_wandb = jax.tree_util.tree_map(lambda metric, current_metric: jnp.concatenate([metric, current_metric], axis=0), metrics_wandb, current_metrics)
    
    log_metrics = jax.tree_util.tree_map(lambda metric: metric[-1], metrics_wandb)
    
    wandb.log(log_metrics)
    '''
    
    def update_metrics(old_metrics, new_metrics):
        updated_metrics = {}
        for key in old_metrics:
            if key in new_metrics:
                # Check if old metrics for key is empty, and initialize properly if so
                if old_metrics[key].size == 0:
                    updated_metrics[key] = new_metrics[key]
                else:
                    updated_metrics[key] = jnp.concatenate([old_metrics[key], new_metrics[key]], axis=0)
            else:
                raise KeyError(f"Key {key} not found in new metrics.")
        return updated_metrics

    # In your training loop:
    try:
        metrics_wandb = update_metrics(metrics_wandb, current_metrics)
        log_metrics = {k: v[-1] for k, v in metrics_wandb.items()}  # Assuming you want the latest entry
        wandb.log(log_metrics)
    except Exception as e:
        print(f"Error updating metrics: {e}")

    
    
    
    
    start_time = time.time()

VBox(children=(Label(value='0.003 MB of 0.003 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112276277627745, max=1.0…

Number of evaluations per training step: 256
Reward before passing through the get_return_standardize(256, 1024, 28)


2024-07-01 16:47:32.687663: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(2366): _wrapped_callback
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/profiler.py(336): wrapper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1185): _pjit_call_impl_python
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1229): call_impl_cache_miss
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1245): _pjit_call_imp

XlaRuntimeError: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.custom_call' failed: CpuCallback error: KeyboardInterrupt: <EMPTY MESSAGE>

At:
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py(2366): _wrapped_callback
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py(1151): __call__
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/profiler.py(336): wrapper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1185): _pjit_call_impl_python
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1229): call_impl_cache_miss
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(1245): _pjit_call_impl
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/core.py(935): process_primitive
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/core.py(447): bind_with_trace
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/core.py(2740): bind
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(168): _python_pjit_helper
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/pjit.py(257): cache_miss
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/jax/_src/traceback_util.py(179): reraise_with_filtered_traceback
  /tmp/ipykernel_1436758/4078604604.py(42): <module>
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3577): run_code
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3517): run_ast_nodes
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3334): run_cell_async
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/async_helpers.py(129): _pseudo_sync_runner
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3130): _run_cell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/IPython/core/interactiveshell.py(3075): run_cell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/zmqshell.py(549): run_cell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/ipkernel.py(449): do_execute
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(778): execute_request
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/ipkernel.py(362): execute_request
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(437): dispatch_shell
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(534): process_one
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelbase.py(545): dispatch_queue
  /usr/lib/python3.10/asyncio/events.py(80): _run
  /usr/lib/python3.10/asyncio/base_events.py(1909): _run_once
  /usr/lib/python3.10/asyncio/base_events.py(603): run_forever
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/tornado/platform/asyncio.py(205): start
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel/kernelapp.py(739): start
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/traitlets/config/application.py(1075): launch_instance
  /vol/bitbucket/km2120/QD-DRL/me-with-sample-based-drl/my_env/lib/python3.10/site-packages/ipykernel_launcher.py(18): <module>
  /usr/lib/python3.10/runpy.py(86): _run_code
  /usr/lib/python3.10/runpy.py(196): _run_module_as_main
; current tracing scope: custom-call.32; current profiling annotation: XlaModule:#prefix=jit(train)/jit(main),hlo_module=jit_train,program_id=857#.