In [1]:

import time
from typing import Any

import gymnasium as gym
import numpy as np
import torch
from torch import optim, nn

from src.experiment_logging.experiment_log import ExperimentLogItem
from src.experiment_logging.experiment_logger import ExperimentLogger, log_experiment
from src.module_analysis import count_parameters
from src.reinforcement_learning.algorithms.sac.sac import SAC, SACInfoStashConfig
from src.reinforcement_learning.algorithms.sac.sac_policy import SACPolicy
from src.reinforcement_learning.core.action_selectors.predicted_std_action_selector import PredictedStdActionSelector
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.loss_config import LossInfoStashConfig
from src.reinforcement_learning.core.policies.components.actor import Actor
from src.reinforcement_learning.core.policies.components.q_critic import QCritic
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
from src.stopwatch import Stopwatch
from src.summary_statistics import maybe_compute_summary_statistics

%load_ext autoreload
%autoreload 2

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
def get_setup() -> dict[str, str]:
    """Capture the notebook source that describes this experiment.

    Reads IPython's input history list ``_ih`` and joins entry 1 with the most
    recent entry, separated by a blank line. In this notebook those are the
    imports cell and this cell, provided the imports cell was the first input
    executed in the kernel session.

    Returns:
        A dict with the single key ``'notebook'`` holding the joined source text.
    """
    first_input, latest_input = _ih[1], _ih[-1]
    return {'notebook': '\n\n'.join((first_input, latest_input))}

# Stopwatch that on_optimization_done resets on every logged step; the value returned by
# reset() is logged as 'step_time'.
step_stopwatch = Stopwatch()
# Stopwatch that is reset once right after the SAC instance is built; its time_passed()
# value is logged as 'total_time'.
total_stopwatch = Stopwatch()
# Starting value for the best mean episode score; on_rollout_done compares each new mean
# episode score against it.
best_iteration_score = -1e6

def save_experiment_state():
    """Save the current algorithm state under the running experiment's id.

    Reads the module-level ``logger``, ``algo`` and ``env_name``. The state is
    written via ``algo.save`` to ``models/<env_name>/<experiment_id>`` using the
    experiment id as the name, together with the logger's latest log item.
    """
    run_id = logger.experiment_log['experiment_id']
    target_folder = f'models/{env_name}/{run_id}'
    newest_log_item = logger.get_latest_log_item()
    algo.save(folder_location=target_folder, name=run_id, latest_log_item=newest_log_item)
    

def on_rollout_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Rollout callback: every 1000 steps, stash the most recent episode scores.

    On steps that are not a multiple of 1000 this returns immediately. Otherwise
    it asks the replay buffer for the most recent episode scores (counting
    truncated episodes as done), updates the module-level
    ``best_iteration_score`` when the mean score matches or beats it, and stores
    the scores in ``info['episode_scores']`` (on_optimization_done reads that key).

    Args:
        rl: The running SAC instance; its ``buffer`` and ``num_envs`` are used.
        step: Current step counter passed by the algorithm.
        info: Mutable info dict; receives the ``'episode_scores'`` entry.
        scheduler_values: Unused here.
    """
    global best_iteration_score

    if step % 1000 != 0:
        return

    episode_scores = rl.buffer.compute_most_recent_episode_scores(rl.num_envs, consider_truncated_as_done=True)

    if len(episode_scores) > 0:
        iteration_score = episode_scores.mean()
        if iteration_score >= best_iteration_score:
            # Remember the new best so later iterations are compared against it rather
            # than against the initial sentinel value forever.
            best_iteration_score = iteration_score
            # NOTE(review): save_experiment_state() is defined in this cell but is not
            # called from here - confirm whether a checkpoint should be written on a new best.

    # Stored even when empty so the optimization callback always finds the key on logged steps.
    info['episode_scores'] = episode_scores
        
def on_optimization_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Optimization callback: every 1000 steps, build, print and record a log item.

    On steps that are not a multiple of 1000 this returns immediately. Otherwise
    it resets ``step_stopwatch``, reads ``total_stopwatch``, summarizes the
    stashed values found in ``info`` plus the absolute values of the buffer's
    actions at ``tail_indices(1000)``, prints the formatted item and adds it to
    the module-level ``logger``. Every 10000 steps the experiment log is also saved.

    Args:
        rl: The running SAC instance; ``num_envs``, ``buffer`` and
            ``gradient_steps_performed`` are read.
        step: Current step counter passed by the algorithm.
        info: Info dict holding the stashed losses, entropy coefficient, rollout
            infos and (optionally) ``'episode_scores'``.
        scheduler_values: Unused here.
    """
    if step % 1000:
        return

    summarize = maybe_compute_summary_statistics

    env_steps_so_far = step * rl.num_envs

    # reset() both reports the time since the previous logged step and restarts the timer.
    seconds_this_interval = step_stopwatch.reset()
    seconds_overall = total_stopwatch.time_passed()

    recent_indices = rl.buffer.tail_indices(1000)

    # May be None when the rollout callback did not store scores in this info dict.
    recent_scores = info.get('episode_scores')

    log_item: ExperimentLogItem = {
        'step': step,
        'num_env_steps': env_steps_so_far,
        'scores': summarize(recent_scores),
        'actor_loss': summarize(info['final_actor_loss']),
        'entropy_coef_loss': summarize(info.get('final_entropy_coef_loss')),
        'critic_loss': summarize(info['final_critic_loss']),
        'entropy_coef': summarize(info['entropy_coef']),
        'action_stds': summarize(info['rollout'].get('action_stds')),
        'action_magnitude': summarize(np.abs(rl.buffer.actions[recent_indices])),
        'num_gradient_steps': rl.gradient_steps_performed,
        'step_time': seconds_this_interval,
        'total_time': seconds_overall
    }

    formatted_item = logger.format_log_item(
        log_item,
        mean_format='5.3f',
        std_format='5.3f',
        step_time='.2f',
        total_time='.2f',
    )
    print(formatted_item, end='\n\n')
    logger.add_item(log_item)

    # Persist the log to disk on every tenth logged step.
    is_save_step = step % 10000 == 0
    if is_save_step:
        logger.save_experiment_log()

        print()
    print()

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'using device {device}')

# Gymnasium environment id; also used in the model and experiment-log paths.
env_name = 'HalfCheetah-v4'
# Extra keyword arguments forwarded to gym.make by create_env, for example:
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
env_kwargs = {}
# Number of environment copies; any value other than 1 makes create_env build a parallel env.
num_envs = 1

def create_env(render_mode: str | None):
    make_single_env = lambda: gym.make(env_name, render_mode=render_mode, **env_kwargs)
    
    if num_envs == 1:
        return make_single_env()
        
    return parallelize_env_async(make_single_env, num_envs)


def create_policy(in_size: int = 17, action_size: int = 6, hidden_size: int = 256):
    """Build the SAC policy: an MLP actor trunk plus two MLP Q-critics.

    The sizes used to be hard-coded inside this function; they are now keyword
    parameters whose defaults reproduce the previous values, so ``create_policy()``
    builds the same networks as before.

    Args:
        in_size: Number of input features of the actor and (together with the
            action) of each Q-network. Defaults to 17 - presumably the observation
            size of the configured environment; verify when changing ``env_name``.
        action_size: Action dimension handed to the action selector and appended
            to the Q-network input. Defaults to 6 - presumably the action size of
            the configured environment; verify when changing ``env_name``.
        hidden_size: Width of every hidden layer and the latent dimension passed
            to the action selector. Defaults to 256.

    Returns:
        A ``SACPolicy`` wrapping the actor and the critic.
    """
    actor_net = nn.Sequential(
        nn.Linear(in_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
    )

    critic = QCritic(
        n_critics=2,
        # Called once per critic so each Q-network gets its own parameters.
        create_q_network=lambda: nn.Sequential(
            nn.Linear(in_size + action_size, hidden_size),
            nn.ReLU(),
            # BatchRenorm(256),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            # BatchRenorm(256),
            nn.Linear(hidden_size, 1)
        )
    )

    return SACPolicy(
        actor=Actor(actor_net, PredictedStdActionSelector(
            # Must match the output width of actor_net's last linear layer.
            latent_dim=hidden_size,
            action_dim=action_size,
            base_std=1.0,
            squash_output=True,
        )),
        critic=critic
    )


# Main training script: build the environment, policy and logger, run SAC, and always
# close the environment afterwards.
env = create_env(render_mode=None)

try:
    # Built inside the try so the environment is still closed by the finally clause
    # when policy or logger construction fails.
    policy = create_policy()
    logger = ExperimentLogger(f'experiment_logs/{env_name}/sac/')

    print(f'{count_parameters(policy) = }')
    print(f'{env = }, {num_envs = }')

    # Anomaly detection is explicitly switched off for the whole setup and training run.
    with torch.autograd.set_detect_anomaly(False):
        algo = SAC(
            env=env,
            policy=policy,
            actor_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            critic_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            buffer_size=1_000_000,
            reward_scale=1,
            gamma=0.99,
            tau=0.005,
            entropy_coef_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),
            entropy_coef=1.0,
            rollout_steps=1,
            gradient_steps=1,
            warmup_steps=10_000,
            optimization_batch_size=256,
            target_update_interval=1,
            callback=Callback(
                on_rollout_done=on_rollout_done,
                rollout_schedulers={},
                on_optimization_done=on_optimization_done,
                optimization_schedulers={},
            ),
            # Stash exactly the values that on_optimization_done reads from `info`.
            stash_config=SACInfoStashConfig(stash_rollout_infos=True, stash_rollout_action_stds=True,
                                            stash_entropy_coef=True,
                                            entropy_coef_loss=LossInfoStashConfig(stash_final=True),
                                            actor_loss=LossInfoStashConfig(stash_final=True),
                                            critic_loss=LossInfoStashConfig(stash_final=True)),
            torch_device=device,
        )
        # Start measuring total time only once the algorithm has been constructed.
        total_stopwatch.reset()
        with log_experiment(
            logger,
            experiment_tags=algo.collect_tags(),
            hyper_parameters=algo.collect_hyper_parameters(),
            setup=get_setup(),
        ) as x:
            # Write the log once up front so the experiment file exists before training starts.
            logger.save_experiment_log()
            print('\nStarting Training\n\n')
            # import cProfile
            # pr = cProfile.Profile()
            # pr.enable()
            algo.learn(3_500_000)
            # pr.disable()
            # pr.dump_stats('profile_stats.pstat')
except KeyboardInterrupt:
    print('keyboard interrupt')
    # Bare raise re-raises the active exception with its original traceback.
    raise
finally:
    print('closing envs')
    # NOTE(review): short pause before closing - reason not visible here; confirm it is still needed.
    time.sleep(0.5)
    env.close()
    print('envs closed')


print('done')

using device cuda:0
count_parameters(policy) = 217870
env = <TimeLimit<OrderEnforcing<PassiveEnvChecker<HalfCheetahEnv<HalfCheetah-v4>>>>>, num_envs = 1
Grabbing system information... done!
saved experiment log 2024-10-13_21-42-05_037223~5N9PiQ at experiment_logs/HalfCheetah-v4/sac/2024-10-13_21-42-05_037223~5N9PiQ.json

Starting Training

step = 11000, num_env_steps = 11000, scores = -261.020, actor_loss = -17.568, entropy_coef_loss = -3.027, critic_loss = 0.894, entropy_coef = 0.741, action_stds = 0.908 ± 0.078, action_magnitude = 0.518 ± 0.285, num_gradient_steps = 1000, step_time = 17.40, total_time = 17.38

step = 12000, num_env_steps = 12000, scores = -252.412, actor_loss = -25.266, entropy_coef_loss = -5.946, critic_loss = 1.171, entropy_coef = 0.549, action_stds = 0.901 ± 0.066, action_magnitude = 0.540 ± 0.290, num_gradient_steps = 2000, step_time = 12.74, total_time = 30.12

step = 13000, num_env_steps = 13000, scores = -235.385, actor_loss = -29.939, entropy_coef_loss = -8.9