## Installation and Imports

In [None]:
!pip install ray[rllib] torch
!pip install tensorflow_probability
!pip install wandb

In [1]:
import ray
from ray.rllib.algorithms.sac import SACConfig
from ray.tune.registry import register_env
from ray.tune.logger import pretty_print

from ray import air, tune
from ray.air import session
from ray.air.integrations.wandb import setup_wandb
from ray.air.integrations.wandb import WandbLoggerCallback

import gym

## Configure and Run

In [2]:
# Soft Actor-Critic configuration for HalfCheetah-v3, assembled one builder
# call at a time. Each call's return value is re-bound to `config`, exactly as
# the chained form would thread it through.
config = SACConfig()

# Environment.
config = config.environment(env="HalfCheetah-v3", normalize_actions=True)

# Learner settings. The Q-network and the policy network use the same
# two-layer ReLU MLP shape; each gets its own dict instance.
config = config.training(
    q_model_config=dict(fcnet_activation="relu", fcnet_hiddens=[256, 256]),
    policy_model_config=dict(fcnet_activation="relu", fcnet_hiddens=[256, 256]),
    tau=0.005,
    target_entropy="auto",
    n_step=1,  # N-step return horizon for the Q-targets (1 = one-step TD)
    train_batch_size=256,
    target_network_update_freq=1,
    replay_buffer_config=dict(type="MultiAgentPrioritizedReplayBuffer"),
    num_steps_sampled_before_learning_starts=10_000,
    optimization_config=dict(
        actor_learning_rate=0.0003,
        critic_learning_rate=0.0003,
        entropy_learning_rate=0.0003,
    ),
    clip_actions=False,
)

# Sampling: three rollout workers, each returning single-step fragments.
config = config.rollouts(num_rollout_workers=3, rollout_fragment_length=1)

# CPU-only training.
config = config.resources(num_gpus=0)

# Run an evaluation round every 100 training iterations.
config = config.evaluation(evaluation_interval=100)

# Each reported iteration samples at least 1000 timesteps; episode metrics are
# smoothed over the 5 most recent episodes.
config = config.reporting(
    min_sample_timesteps_per_iteration=1000,
    metrics_num_episodes_for_smoothing=5,
)

config = config.framework(framework="torch")

In [3]:
# Extra keyword arguments for the W&B run; unpacked into WandbLoggerCallback
# in the tuner cell.
wandb_init = {
    "save_code": True,
    # Hand-picked hyperparameters recorded as the run's config on W&B.
    "config": {
        "env": "HalfCheetah-v3",

        "actor_learning_rate": 0.0003,
        "critic_learning_rate": 0.0003,
        "entropy_learning_rate": 0.0003,
        "framework": "torch",

        "num_rollout_workers": 3,
        "num_gpu": 0,
        "metrics_num_episodes_for_smoothing": 5,
    },
    "tags": ["local"],
    "notes": "Test to inspect scaling on Vast.ai",
    "name": "HalfCheetah_local",
    # Not set: job_type, monitor_gym.
}

In [4]:
# Launch SAC training through Ray Tune, logging to Weights & Biases.
#
# Credentials: no API key is passed to WandbLoggerCallback on purpose. The key
# is picked up from the WANDB_API_KEY environment variable or from an existing
# `wandb login` session, so no secret is stored in the notebook.
# NOTE(security): an API key was previously hardcoded in this cell; treat it as
# compromised and revoke/rotate it in the W&B account settings.
tuner = tune.Tuner(
    "SAC",
    run_config=air.RunConfig(
        name="HalfCheetah_local",
        # The trial stops as soon as either criterion is reached.
        stop={"training_iteration": 3_000, "episode_reward_mean": 150},
        checkpoint_config=air.CheckpointConfig(checkpoint_at_end=True),
        callbacks=[
            WandbLoggerCallback(
                project="HalfCheetah",
                save_checkpoints=True,
                **wandb_init,
            )
        ],
        local_dir="./results",
    ),
    param_space=config,
)

results = tuner.fit()

2023-02-10 11:04:20,575	INFO worker.py:1538 -- Started a local Ray instance.
2023-02-10 11:04:22,218	INFO wandb.py:250 -- Already logged into W&B.


0,1
Current time:,2023-02-10 11:13:51
Running for:,00:09:28.95
Memory:,6.1/7.5 GiB

Trial name,status,loc,iter,total time (s),ts,reward,episode_reward_max,episode_reward_min,episode_len_mean
SAC_HalfCheetah-v3_4a7ee_00000,RUNNING,192.168.152.36:8770,19,519.932,19038,-259.557,-190.513,-334.009,1000


[34m[1mwandb[0m: Currently logged in as: [33mdanieladejumo[0m. Use [1m`wandb login --relogin`[0m to force relogin
[2m[36m(SAC pid=8770)[0m 2023-02-10 11:04:26,610	INFO algorithm.py:501 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
[2m[33m(raylet)[0m [2023-02-10 11:04:30,462 E 8512 8557] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-02-10_11-04-18_300548_8410 is over 95% full, available space: 1219694592; capacity: 31845081088. Object creation will fail if spilling is required.




Trial name,agent_timesteps_total,counters,custom_metrics,date,done,episode_len_mean,episode_media,episode_reward_max,episode_reward_mean,episode_reward_min,episodes_this_iter,episodes_total,experiment_id,hostname,info,iterations_since_restore,node_ip,num_agent_steps_sampled,num_agent_steps_trained,num_env_steps_sampled,num_env_steps_sampled_this_iter,num_env_steps_trained,num_env_steps_trained_this_iter,num_faulty_episodes,num_healthy_workers,num_in_flight_async_reqs,num_remote_worker_restarts,num_steps_trained_this_iter,perf,pid,policy_reward_max,policy_reward_mean,policy_reward_min,sampler_perf,sampler_results,time_since_restore,time_this_iter_s,time_total_s,timers,timestamp,timesteps_since_restore,timesteps_total,training_iteration,trial_id,warmup_time
SAC_HalfCheetah-v3_4a7ee_00000,19038,"{'num_env_steps_sampled': 19038, 'num_env_steps_trained': 771328, 'num_agent_steps_sampled': 19038, 'num_agent_steps_trained': 771328, 'last_target_update_ts': 19038, 'num_target_updates': 3013}",{},2023-02-10_11-13-16,False,1000,{},-190.513,-259.557,-334.009,0,18,eb2ab4e9039e47baa7a0744ff713e4f8,Daniel,"{'learner': {'default_policy': {'learner_stats': {'allreduce_latency': 0.0, 'grad_gnorm': 9.758378982543945, 'actor_loss': -25.98797607421875, 'critic_loss': 0.1140708327293396, 'alpha_loss': -8.783862113952637, 'alpha_value': 0.4065146, 'log_alpha_value': -0.9001354, 'target_entropy': -6.0, 'policy_t': 0.009235848672688007, 'mean_q': 24.703245162963867, 'max_q': 29.371721267700195, 'min_q': 17.71662712097168}, 'td_error': array([ 0.33663464, 1.2601948 , 0.28399277, 1.0609436 , 1.1876373 ,  0.77235985, 0.39661217, 1.8075237 , 0.21414566, 3.072195 ,  2.3819637 , 0.4548006 , 0.24678612, 0.79231644, 0.8752146 ,  0.8064976 , 0.59739685, 0.88677406, 1.6359854 , 0.49676704,  2.2438574 , 2.4265738 , 1.096798 , 1.5275669 , 0.5167084 ,  0.7446928 , 0.63783646, 0.78740597, 0.8644352 , 0.70473003,  0.25894928, 3.8774958 , 1.8265591 , 0.52253056, 0.3203869 ,  0.6314659 , 0.11497688, 1.2570848 , 0.3603592 , 0.62269974,  0.55908775, 0.5118036 , 0.7691231 , 2.31979 , 1.1254578 ,  0.67871094, 1.1452103 , 0.22850704, 3.9334507 , 0.9412079 ,  0.4457178 , 1.8052874 , 2.7814283 , 0.32353497, 0.42216587,  0.53152657, 1.2548428 , 2.9495945 , 0.31592178, 2.0177383 ,  1.4619961 , 0.5534277 , 2.5569992 , 0.65735817, 0.6479845 ,  1.0793705 , 0.87029076, 0.66249657, 1.0622044 , 2.323742 ,  0.4010811 , 0.891881 , 0.4186468 , 0.6992321 , 1.0623341 ,  1.0056705 , 1.174758 , 1.3097258 , 2.323824 , 0.64501 ,  1.0750351 , 0.3852768 , 0.9641895 , 2.3115416 , 0.67684937,  0.8154249 , 0.1271534 , 0.63460064, 1.9324121 , 0.63217354,  2.850933 , 1.6855679 , 2.989849 , 0.43947506, 2.6809158 ,  1.0275698 , 1.7981501 , 2.0140724 , 0.63266945, 0.5763073 ,  0.5687628 , 
2.600439 , 0.35207653, 0.5303459 , 0.91572 ,  0.8894968 , 0.5999069 , 0.88246155, 0.46098137, 1.1507416 ,  0.9566822 , 1.1904469 , 1.1747885 , 3.0292225 , 0.3907013 ,  0.90608215, 0.45789242, 0.36025524, 0.7595682 , 0.6853914 ,  0.47565842, 2.0772629 , 0.77293587, 3.083043 , 0.5579281 ,  0.22869682, 0.82181644, 1.4444485 , 1.2327309 , 2.2273254 ,  0.330225 , 3.7707224 , 0.43575096, 0.6655512 , 1.0478611 ,  0.30140686, 1.896472 , 0.71027946, 2.5655537 , 0.4602995 ,  1.4443302 , 20.699846 , 2.152135 , 0.6435814 , 2.1164722 ,  0.18476963, 0.4216776 , 0.5605507 , 0.6306133 , 1.7774782 ,  0.18500137, 1.8345537 , 1.1081963 , 0.69869995, 0.7706938 ,  0.36132622, 0.7507343 , 1.6786642 , 1.4784584 , 1.8067188 ,  0.97959614, 1.0228138 , 0.3762703 , 1.2080765 , 0.99732494,  0.6340637 , 1.9076681 , 0.21814823, 1.3577166 , 1.0699825 ,  1.2514515 , 0.7959509 , 0.20892906, 1.7141218 , 0.0699482 ,  0.5657463 , 1.9342804 , 0.7694483 , 1.7399502 , 1.4493809 ,  0.8604851 , 0.49267197, 3.3351908 , 0.33294678, 2.1961966 ,  1.2414217 , 1.0598145 , 1.4460678 , 0.8308735 , 2.048954 ,  0.44484806, 0.45335674, 2.6631975 , 0.6593342 , 0.3684368 ,  3.1114597 , 2.0016766 , 1.4549799 , 0.3212738 , 1.074542 ,  0.2679186 , 0.5940571 , 1.037981 , 3.1126242 , 0.29491997,  0.8105984 , 0.9493542 , 0.5602484 , 1.0763607 , 1.4213467 ,  2.437419 , 0.4743271 , 1.0421429 , 0.53688717, 1.2567396 ,  1.3113213 , 0.76556015, 0.5134773 , 0.5274782 , 1.7223206 ,  1.0456581 , 1.6769428 , 1.372714 , 0.15181923, 2.7394123 ,  3.2577715 , 1.4613667 , 0.25989056, 0.29662514, 2.4640274 ,  0.29676914, 1.0779533 , 1.9858942 , 1.5365696 , 1.7043457 ,  0.30888557, 0.9649706 , 0.25320816, 5.2329187 , 0.6947098 ,  1.3241634 , 2.8769398 , 0.8897095 , 2.0535202 , 0.90903664,  2.4368172 , 0.19243336, 0.8223791 , 0.30472946, 0.4822092 ,  0.45879364, 0.9614725 , 0.9965677 , 1.0878162 , 1.1616325 ,  0.13834572], dtype=float32), 'mean_td_error': 1.220942497253418, 'model': {}, 'custom_metrics': {}, 'num_agent_steps_trained': 
256.0, 'num_grad_updates_lifetime': 3013.0, 'diff_num_grad_updates_vs_sampler_policy': 3012.0}}, 'num_env_steps_sampled': 19038, 'num_env_steps_trained': 771328, 'num_agent_steps_sampled': 19038, 'num_agent_steps_trained': 771328, 'last_target_update_ts': 19038, 'num_target_updates': 3013}",19,192.168.152.36,19038,771328,19038,1002,771328,85504,0,3,0,0,85504,"{'cpu_util_percent': 46.8, 'ram_util_percent': 80.81625}",8770,{},{},{},"{'mean_raw_obs_processing_ms': 1.3236382406449168, 'mean_inference_ms': 2.3153769610507147, 'mean_action_processing_ms': 0.2232992425872684, 'mean_env_wait_ms': 0.27413021622578043, 'mean_env_render_ms': 0.0}","{'episode_reward_max': -190.51257700856925, 'episode_reward_min': -334.0085722934719, 'episode_reward_mean': -259.55740757502326, 'episode_len_mean': 1000.0, 'episode_media': {}, 'episodes_this_iter': 0, 'policy_reward_min': {}, 'policy_reward_max': {}, 'policy_reward_mean': {}, 'custom_metrics': {}, 'hist_stats': {'episode_reward': [-259.6053751951275, -190.51257700856925, -334.0085722934719, -307.29519873564004, -206.36531464230757], 'episode_lengths': [1000, 1000, 1000, 1000, 1000]}, 'sampler_perf': {'mean_raw_obs_processing_ms': 1.3236382406449168, 'mean_inference_ms': 2.3153769610507147, 'mean_action_processing_ms': 0.2232992425872684, 'mean_env_wait_ms': 0.27413021622578043, 'mean_env_render_ms': 0.0}, 'num_faulty_episodes': 0}",519.932,57.6614,519.932,"{'training_iteration_time_ms': 158.002, 'load_time_ms': 0.297, 'load_throughput': 862166.231, 'learn_time_ms': 25.808, 'learn_throughput': 9919.35, 'synch_weights_time_ms': 5.692}",1676023996,0,19038,19,4a7ee_00000,8.90337


[2m[33m(raylet)[0m [2023-02-10 11:04:40,467 E 8512 8557] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-02-10_11-04-18_300548_8410 is over 95% full, available space: 1219559424; capacity: 31845081088. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2023-02-10 11:04:50,473 E 8512 8557] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-02-10_11-04-18_300548_8410 is over 95% full, available space: 1219481600; capacity: 31845081088. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2023-02-10 11:05:00,478 E 8512 8557] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-02-10_11-04-18_300548_8410 is over 95% full, available space: 1219358720; capacity: 31845081088. Object creation will fail if spilling is required.
[2m[33m(raylet)[0m [2023-02-10 11:05:10,487 E 8512 8557] (raylet) file_system_monitor.cc:105: /tmp/ray/session_2023-02-10_11-04-18_300548_8410 is over 95% full, available space: 1219256320; capaci