In [1]:
import gymnasium as gym
import numpy as np
from stable_baselines3 import SAC, TD3, A2C
import matplotlib.pyplot as plt
import pickle
import os
import argparse
import asyncio

In [12]:
# Output locations: saved checkpoints go under `models_dir`,
# TensorBoard event files under `logs_dir`. Each is created only if missing.
models_dir, logs_dir = 'models', 'logs'
for _required_dir in (models_dir, logs_dir):
    if not os.path.exists(_required_dir):
        os.makedirs(_required_dir)
del _required_dir

In [19]:
def train(env, sb3_algo, max_iters=4, model=None, curr_name=None, gamma=0.99):
    gamma_str = str(gamma).replace('.', '_')
    if model is None:
        match sb3_algo:
            case 'SAC':
                model = SAC('MlpPolicy', env, verbose=1, tensorboard_log=f'{logs_dir}/SAC_gamma{gamma_str}')
            case 'TD3':
                model = TD3('MlpPolicy', env, verbose=1, tensorboard_log=f'{logs_dir}/TD3_gamma{gamma_str}')
            case 'A2C':
                model = A2C('MlpPolicy', env, verbose=1, tensorboard_log=f'{logs_dir}/A2C_gamma{gamma_str}')
            case _:
                print('Invalid algorithm')
                return
        name = f'{models_dir}/{sb3_algo}_gamma{gamma_str}'
    else:
        if curr_name is None:
            print('Please provide a name for the model')
            return
        name = curr_name
        model.set_env(env)

    os.makedirs(name, exist_ok=True)
    TIMESTEPS = 25000
    iters = 0
    while iters < max_iters:
        iters += 1
        model.learn(total_timesteps=TIMESTEPS, reset_num_timesteps=False)
        model.save(f'{name}/{TIMESTEPS*iters}')

def test(env, sb3_algo, path_to_model):
    match sb3_algo:
        case 'SAC':
            model = SAC.load(path_to_model)
        case 'TD3':
            model = TD3.load(path_to_model)
        case 'A2C':
            model = A2C.load(path_to_model)
        case _:
            print('Invalid algorithm')
            return

    obs = env.reset()[0]
    done = False
    extra_steps = 500
    while True:
        action, _states = model.predict(obs)
        obs, _, done, _, _ = env.step(action)
        
        if done:
            extra_steps -= 1

        if extra_steps < 0:
            break

In [20]:
# Headless Humanoid environment used for all training runs (no rendering).
gymenv = gym.make(id='Humanoid-v4', render_mode=None)

In [23]:
# One SAC training run per gamma value. train() derives a separate
# log/checkpoint directory name from each gamma.
for gamma in (0.99, 0.9, 0.8, 0.5, 0.1):
    train(gymenv, 'SAC', gamma=gamma)

Using cpu device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.
Logging to logs/SAC_gamma0_99/SAC_0
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 24       |
|    ep_rew_mean     | 119      |
| time/              |          |
|    episodes        | 4        |
|    fps             | 2700     |
|    time_elapsed    | 0        |
|    total_timesteps | 96       |
---------------------------------




---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.6     |
|    ep_rew_mean     | 113      |
| time/              |          |
|    episodes        | 8        |
|    fps             | 229      |
|    time_elapsed    | 0        |
|    total_timesteps | 181      |
| train/             |          |
|    actor_loss      | -27.4    |
|    critic_loss     | 18.6     |
|    ent_coef        | 0.974    |
|    ent_coef_loss   | -0.721   |
|    learning_rate   | 0.0003   |
|    n_updates       | 80       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.8     |
|    ep_rew_mean     | 114      |
| time/              |          |
|    episodes        | 12       |
|    fps             | 165      |
|    time_elapsed    | 1        |
|    total_timesteps | 273      |
| train/             |          |
|    actor_loss      | -27.4    |
|    critic_loss     | 47       |
|    ent_coef 



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.5     |
|    ep_rew_mean     | 113      |
| time/              |          |
|    episodes        | 8        |
|    fps             | 238      |
|    time_elapsed    | 0        |
|    total_timesteps | 180      |
| train/             |          |
|    actor_loss      | -23.9    |
|    critic_loss     | 20.5     |
|    ent_coef        | 0.976    |
|    ent_coef_loss   | -0.673   |
|    learning_rate   | 0.0003   |
|    n_updates       | 79       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 23.8     |
|    ep_rew_mean     | 120      |
| time/              |          |
|    episodes        | 12       |
|    fps             | 169      |
|    time_elapsed    | 1        |
|    total_timesteps | 286      |
| train/             |          |
|    actor_loss      | -29.1    |
|    critic_loss     | 15       |
|    ent_coef 



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 27.2     |
|    ep_rew_mean     | 136      |
| time/              |          |
|    episodes        | 8        |
|    fps             | 211      |
|    time_elapsed    | 1        |
|    total_timesteps | 218      |
| train/             |          |
|    actor_loss      | -26.8    |
|    critic_loss     | 27.7     |
|    ent_coef        | 0.967    |
|    ent_coef_loss   | -0.943   |
|    learning_rate   | 0.0003   |
|    n_updates       | 117      |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 25.4     |
|    ep_rew_mean     | 127      |
| time/              |          |
|    episodes        | 12       |
|    fps             | 172      |
|    time_elapsed    | 1        |
|    total_timesteps | 305      |
| train/             |          |
|    actor_loss      | -29.1    |
|    critic_loss     | 41.7     |
|    ent_coef 



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 20.2     |
|    ep_rew_mean     | 101      |
| time/              |          |
|    episodes        | 8        |
|    fps             | 301      |
|    time_elapsed    | 0        |
|    total_timesteps | 162      |
| train/             |          |
|    actor_loss      | -25.4    |
|    critic_loss     | 97.1     |
|    ent_coef        | 0.982    |
|    ent_coef_loss   | -0.507   |
|    learning_rate   | 0.0003   |
|    n_updates       | 61       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 22.2     |
|    ep_rew_mean     | 111      |
| time/              |          |
|    episodes        | 12       |
|    fps             | 192      |
|    time_elapsed    | 1        |
|    total_timesteps | 267      |
| train/             |          |
|    actor_loss      | -31      |
|    critic_loss     | 9.73     |
|    ent_coef 



---------------------------------
| rollout/           |          |
|    ep_len_mean     | 24.8     |
|    ep_rew_mean     | 122      |
| time/              |          |
|    episodes        | 8        |
|    fps             | 234      |
|    time_elapsed    | 0        |
|    total_timesteps | 198      |
| train/             |          |
|    actor_loss      | -24.9    |
|    critic_loss     | 21.5     |
|    ent_coef        | 0.97     |
|    ent_coef_loss   | -0.85    |
|    learning_rate   | 0.0003   |
|    n_updates       | 97       |
---------------------------------
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 24.3     |
|    ep_rew_mean     | 121      |
| time/              |          |
|    episodes        | 12       |
|    fps             | 178      |
|    time_elapsed    | 1        |
|    total_timesteps | 292      |
| train/             |          |
|    actor_loss      | -29.1    |
|    critic_loss     | 16.2     |
|    ent_coef 

In [24]:
# Separate Humanoid environment with on-screen rendering, for watching saved policies.
gymenv_test = gym.make(id='Humanoid-v4', render_mode='human')

In [25]:
# Watch the gamma=0.99 run's checkpoint saved at 100000 steps.
test(gymenv_test, 'SAC', path_to_model=f'{models_dir}/SAC_gamma0_99/100000')



In [26]:
# Watch the gamma=0.9 run's checkpoint saved at 100000 steps.
test(gymenv_test, 'SAC', path_to_model=f'{models_dir}/SAC_gamma0_9/100000')



In [27]:
# Watch the gamma=0.8 run's checkpoint saved at 100000 steps.
test(gymenv_test, 'SAC', path_to_model=f'{models_dir}/SAC_gamma0_8/100000')



In [28]:
# Watch the gamma=0.5 run's checkpoint saved at 100000 steps.
test(gymenv_test, 'SAC', path_to_model=f'{models_dir}/SAC_gamma0_5/100000')



In [29]:
# Watch the gamma=0.1 run's checkpoint saved at 100000 steps.
test(gymenv_test, 'SAC', path_to_model=f'{models_dir}/SAC_gamma0_1/100000')

: 