In [6]:
import inspect
import os
import time
from pathlib import Path

import gymnasium
from gymnasium import Env
from gymnasium.vector import VectorEnv

from sac import init_action_selector, init_policy, init_optimizer, wrap_env
from src.datetime import get_current_timestamp
from src.model_db.model_db import ModelDB
from src.model_db.dummy_model_db import DummyModelDB
from src.reinforcement_learning.algorithms.policy_mitosis.mitosis_policy_info import MitosisPolicyInfo
from src.model_db.tiny_model_db import TinyModelDB
from src.module_analysis import count_parameters, get_gradients_per_parameter
from src.moving_averages import ExponentialMovingAverage
from src.reinforcement_learning.core.action_selectors.action_selector import ActionSelector
from src.reinforcement_learning.core.action_selectors.predicted_std_action_selector import PredictedStdActionSelector
from src.reinforcement_learning.core.action_selectors.state_dependent_noise_action_selector import \
    StateDependentNoiseActionSelector
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_episode_returns, compute_returns
from src.reinforcement_learning.core.policies.base_policy import BasePolicy
from src.reinforcement_learning.core.policy_construction import InitActionSelectorFunction, PolicyConstruction
from src.reinforcement_learning.gym.envs.test_env import TestEnv
from src.schedulers import FixedValueScheduler, OneStepRecursiveScheduler
from src.stopwatch import Stopwatch
from src.summary_statistics import format_summary_statics
from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy
from typing import Any, SupportsFloat, Optional
from gymnasium.wrappers import RecordVideo, AutoResetWrapper, NormalizeReward, TransformReward, TransformObservation, ClipAction
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.algorithms.sac.sac import SAC, SAC_DEFAULT_OPTIMIZER_PROVIDER
from src.reinforcement_learning.algorithms.ppo.ppo import PPO, PPOLoggingConfig
from src.reinforcement_learning.core.normalization import NormalizationType
from src.torch_device import set_default_torch_device, optimizer_to_device
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
from torch.distributions import Normal, Categorical

import torch
from torch import optim, nn
import torch.distributions as dist
import gymnasium as gym
import numpy as np

from src.torch_functions import antisymmetric_power

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [7]:
from src.reinforcement_learning.core.loss_config import LossLoggingConfig
from src.reinforcement_learning.algorithms.sac.sac import SAC, SACLoggingConfig

# Module-level type declarations for the globals that get_policy() assigns.
# A bare annotation binds no value, so none of these names exists in globals()
# until get_policy() has run (get_policy() relies on that for its 'policy' check).
policy_id: str
policy: BasePolicy
optimizer: optim.Optimizer
wrapped_env: Env
steps_trained: int
def get_policy(create_new_if_exists: bool):
    """Return the policy bundle to train, creating and/or loading it as needed.

    A new policy (together with its optimizer and wrapped environment) is built
    when no global ``policy`` exists yet, or when ``create_new_if_exists`` is true.
    Afterwards, if the module-level ``parent_policy_id`` is not ``None``, the
    parent's state dict is loaded into the policy through ``policy_db`` and
    ``steps_trained`` is taken from the stored model info.

    Args:
        create_new_if_exists: build a fresh policy even if one is already held
            in the notebook's globals.

    Returns:
        The tuple ``(policy_id, policy, optimizer, wrapped_env, steps_trained)``;
        the same values are also written to the module-level globals.
    """
    global policy_id, policy, optimizer, wrapped_env, steps_trained

    policy_in_ram = 'policy' in globals()
    if not policy_in_ram:
        # A missing policy always leads to construction below.
        print('No policy in RAM, creating a new one')

    if create_new_if_exists or not policy_in_ram:
        policy_id = get_current_timestamp()
        initialization_info = PolicyConstruction.create_policy_initialization_info(
            init_action_selector=init_action_selector,
            init_policy=init_policy,
            init_optimizer=init_optimizer,
            wrap_env=wrap_env,
        )
        policy, optimizer, wrapped_env = PolicyConstruction.init_from_info(
            env=env,
            info=initialization_info,
        )
        steps_trained = 0
        print(f'New policy {policy_id} created')

    has_parent = parent_policy_id is not None
    if has_parent:
        # Only the model weights are restored here; the optimizer is not passed to the DB.
        parent_entry = policy_db.load_model_state_dict(policy, parent_policy_id)
        steps_trained = parent_entry['model_info']['steps_trained']
        print(f'Loading state dict from policy {parent_policy_id}')

    print(f'Using policy {policy_id} with parent policy {parent_policy_id}')
    return policy_id, policy, optimizer, wrapped_env, steps_trained

# Exponential moving average of the mean episode score; updated in on_rollout_done().
score_mean_ema = ExponentialMovingAverage(alpha=0.25)
# reset() is called once per log line in on_optimization_done() and its return value
# is printed as `time`.
stopwatch = Stopwatch()
# Highest mean episode score seen so far; on_rollout_done() saves the policy whenever
# a new score is >= this value. Starts very low so an early score can beat it.
best_iteration_score = -1e6

def on_rollout_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Rollout callback: on every 1000th step, score recent experience and checkpoint the best policy.

    The rewards of the newest buffer entries (``rl.buffer.tail_indices(1000)``) are
    turned into undiscounted episode returns. Their mean is fed into
    ``score_mean_ema`` and, if it is at least ``best_iteration_score``, the policy
    and optimizer are stored through ``policy_db``.

    Args:
        rl: the running SAC instance; only its buffer is read.
        step: current step counter supplied by the algorithm.
        info: mutable info dict; receives ``'episode_scores'`` and
            ``'score_moving_average'`` for use by ``on_optimization_done``.
        scheduler_values: unused.
    """
    global best_iteration_score

    is_evaluation_step = step % 1000 == 0
    if not is_evaluation_step:
        return

    tail_indices = rl.buffer.tail_indices(1000)
    tail_rewards = rl.buffer.rewards[tail_indices]

    # Row numbers of the tail window, repeated for every env. With at most 1000 rows
    # only row 0 satisfies `% 1000 == 0`, so just the first row is flagged as a start.
    row_numbers = np.arange(len(tail_indices)).reshape(-1, 1)
    episode_starts = np.repeat(row_numbers, num_envs, axis=1) % 1000 == 0

    episode_scores = compute_episode_returns(
        rewards=tail_rewards,
        episode_starts=episode_starts,
        last_episode_starts=info['last_episode_starts'],
        gamma=1.0,
        gae_lambda=1.0,
        normalize_rewards=None,
        remove_unfinished_episodes=True,
    )

    iteration_score = episode_scores.mean()
    score_moving_average = score_mean_ema.update(iteration_score)

    is_new_best = iteration_score >= best_iteration_score
    if is_new_best:
        best_iteration_score = iteration_score
        checkpoint_info = {
            'score': iteration_score.item(),
            'steps_trained': steps_trained,
            'wrap_env_source_code': wrap_env_source_code_source,
            'init_policy_source_code': init_policy_source
        }
        policy_db.save_model_state_dict(
            model_id=policy_id,
            parent_model_id=parent_policy_id,
            model_info=checkpoint_info,
            model=policy,
            optimizer=optimizer,
        )

    # Hand the results over to on_optimization_done(), which prints them.
    info['episode_scores'] = episode_scores
    info['score_moving_average'] = score_moving_average
        
def on_optimization_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Optimization callback: update ``steps_trained`` and print a summary line every 1000th step.

    The summary contains the episode scores and their moving average (both placed in
    ``info`` by ``on_rollout_done``), the reduced actor / critic / entropy-coefficient
    losses, the entropy coefficient, the rollout action stds (``'N/A'`` when not
    logged), the number of gradient steps and the time returned by ``stopwatch.reset()``.

    Args:
        rl: the running SAC instance.
        step: current step counter supplied by the algorithm.
        info: info dict filled by the algorithm and by ``on_rollout_done``.
        scheduler_values: unused.
    """
    global steps_trained
    # NOTE(review): this adds the buffer's current `pos` on every call - confirm that
    # this is the intended way to count trained steps for a replay buffer.
    steps_trained += rl.buffer.pos

    if step % 1000 != 0:
        return

    def mean_and_std(values, mean_format: str = '5.3f') -> str:
        # Shared "mean ± std" formatting without min/max, used for all scalar diagnostics.
        return format_summary_statics(
            values,
            mean_format=mean_format,
            std_format='5.3f',
            min_value_format=None,
            max_value_format=None,
        )

    time_taken = stopwatch.reset()

    episode_scores = info['episode_scores']
    score_moving_average = info['score_moving_average']

    scores = format_summary_statics(
        episode_scores,
        mean_format=' 6.3f',
        std_format='4.3f',
        min_value_format=' 6.3f',
        max_value_format='5.3f',
        n_format='>2'
    )
    actor_loss = mean_and_std(torch.abs(info['reduced_actor_loss']), mean_format=' 5.3f')
    if 'reduced_entropy_coef_loss' in info:
        entropy_coef_loss = mean_and_std(info['reduced_entropy_coef_loss'])
    else:
        entropy_coef_loss = None
    critic_loss = mean_and_std(info['reduced_critic_loss'])
    entropy_coef = mean_and_std(info['entropy_coef'])

    action_stds = info['rollout'].get('action_stds')
    rollout_action_stds = 'N/A' if action_stds is None else mean_and_std(action_stds)

    # The `name = value` f-strings print the local variable names, so those names
    # are part of the log format and must not be renamed.
    line_parts = [
        f"{step = : >7}, ",
        f"{scores = :s}, ",
        f'score_ema = {score_moving_average: 6.3f}, ',
        f"{actor_loss = :s}, ",
        f"{entropy_coef_loss = :s}, " if entropy_coef_loss is not None else '',
        f"{critic_loss = :s}, ",
        f"{entropy_coef = :s}, ",
        f"rollout_stds = {rollout_action_stds:s}, ",
        f"n_updates = {rl.gradient_steps_performed}, ",
        f"time = {time_taken:4.1f} \n",
    ]
    print(''.join(line_parts))
    print()

# Use the first CUDA GPU when one is available; otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA (the previous `if True` made the CPU
# branch unreachable).
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device('cpu')
print(f'using device {device}')

def create_env(render_mode: str | None):
    """Build one instance of the configured environment.

    Uses the module-level ``env_name`` and ``env_kwargs``, which are looked up at
    call time.

    Args:
        render_mode: forwarded to ``gym.make``; ``None`` disables rendering.

    Returns:
        The environment returned by ``gym.make``.
    """
    environment = gym.make(env_name, render_mode=render_mode, **env_kwargs)
    return environment

# Source code of the env wrapper and policy factory; stored alongside every saved
# model by on_rollout_done() so a checkpoint records how it was built.
wrap_env_source_code_source = inspect.getsource(wrap_env)
init_policy_source = inspect.getsource(init_policy)

# --- Environment configuration (read by create_env() at call time) ---
env_name = 'HalfCheetah-v4'
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.001 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.05 }
# Number of parallel environment copies; also used by on_rollout_done() to shape episode_starts.
num_envs = 16
    
# --- Model storage ---
# NOTE(review): DummyModelDB presumably does not persist anything; switch to the
# TinyModelDB line below to actually keep checkpoints - confirm.
# policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'saved_models/rl/{env_name}')
policy_db = DummyModelDB[MitosisPolicyInfo]()
print(f'{policy_db = }')

# Id of a stored policy whose weights get_policy() should load, or None to start fresh.
parent_policy_id=None  # '2024-04-28_20.57.23'

# Vectorized environment built from num_envs copies of create_env(); closed in `finally` below.
env = parallelize_env_async(lambda: create_env(render_mode=None), num_envs)


# --- Training run ---
# KeyboardInterrupt (kernel interrupt) is caught so the cell ends cleanly, and the
# `finally` clause always closes the environments and the model DB.
try:
    # create_new_if_exists=True: every run of this cell starts from a newly constructed policy.
    policy_id, policy, optimizer, wrapped_env, steps_trained = get_policy(create_new_if_exists=True)
    print(f'{count_parameters(policy) = }')
    print(f'{env = }, {num_envs = } \n\n')
        
    # Anomaly detection is explicitly disabled; pass True here when debugging NaN/inf gradients.
    with (torch.autograd.set_detect_anomaly(False)):
        algo = SAC(
            env=wrapped_env,
            policy=policy,
            buffer_size=50_000,
            gamma=0.99,
            tau=0.005,
            entropy_coef_optimizer_provider=SAC_DEFAULT_OPTIMIZER_PROVIDER,
            entropy_coef=0.1,
            rollout_steps=1,
            warmup_steps=50,
            learning_starts=50,
            optimization_batch_size=256,
            gradient_steps=1,
            target_update_interval=1,
            # sde_noise_sample_freq=50,
            # The two callbacks defined above do the scoring/checkpointing and the logging.
            callback=Callback(
                on_rollout_done=on_rollout_done,
                rollout_schedulers={},
                on_optimization_done=on_optimization_done,
                optimization_schedulers={},
            ),
            # The logged values enabled here are the `info` entries read by on_optimization_done().
            logging_config=SACLoggingConfig(log_rollout_infos=True, log_rollout_action_stds=True,
                                            log_last_obs=True, log_entropy_coef=True,
                                            entropy_coef_loss=LossLoggingConfig(log_reduced=True),
                                            actor_loss=LossLoggingConfig(log_reduced=True),
                                            critic_loss=LossLoggingConfig(log_reduced=True)),
            torch_device=device,
        )
        # Optional profiling of the training loop (disabled):
        # import cProfile
        # pr = cProfile.Profile()
        # pr.enable()
        algo.learn(1_000_000)
        # pr.disable()
        # pr.dump_stats('profile_stats.pstat')
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    print('closing envs')
    # NOTE(review): the short pause presumably lets the env workers settle before close() - confirm.
    time.sleep(0.5)
    env.close()
    print('envs closed')
    policy_db.close()
    print('model db closed')
    

print('done')

using device cuda:0
policy_db = DummyModelDB()
New policy 2024-09-20_19.16.12 created
Using policy 2024-09-20_19.16.12 with parent policy None
count_parameters(policy) = 217870
env = AsyncVectorEnv(16), num_envs = 16 

step =    1000, scores =  19.550 ± 31.158 [-65.371, 78.159] (n=16), score_ema =  19.550, actor_loss =  1.493 ±   nan, entropy_coef_loss = -257.626 ±   nan, critic_loss = 0.331 ±   nan, entropy_coef = 0.075 ±   nan, rollout_stds = 0.870 ± 0.039, n_updates = 951, time = 12.0 

step =    2000, scores =  0.268 ± 27.246 [-42.615, 48.489] (n=16), score_ema =  14.730, actor_loss =  2.559 ±   nan, entropy_coef_loss = -282.480 ±   nan, critic_loss = 0.605 ±   nan, entropy_coef = 0.056 ±   nan, rollout_stds = 0.804 ± 0.108, n_updates = 1951, time = 10.0 

step =    3000, scores =  29.252 ± 29.438 [-36.147, 67.347] (n=16), score_ema =  18.360, actor_loss =  3.090 ±   nan, entropy_coef_loss = -312.586 ±   nan, critic_loss = 0.388 ±   nan, entropy_coef = 0.041 ±   nan, rollout_stds =

2000

In [1]:
from src.reinforcement_learning.gym.singleton_vector_env import as_vec_env

# NOTE: this cell depends on `policy`, `device`, `create_env`, `wrap_env` and the imports
# from the training cells above; run those first.
record_env: gym.Env = create_env(render_mode='rgb_array')

# policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'saved_models/rl/{env_name}')
# policy_db.load_model_state_dict(policy, model_id='2024-05-24_16.15.39')

# Root folder for recorded videos. Set the RL_VIDEO_DIR environment variable to record
# somewhere else; when it is unset the previous hard-coded location is used, so existing
# setups behave as before. Each recording goes into its own timestamped sub-folder.
video_root = os.environ.get('RL_VIDEO_DIR', r'C:\Users\domin\Videos\rl')

try:
    # RecordVideo needs a frame rate; provide one if the env does not declare it.
    if 'render_fps' not in record_env.metadata:
        record_env.metadata['render_fps'] = 30
    record_env = AutoResetWrapper(
        RecordVideo(
            record_env,
            video_folder=os.path.join(video_root, get_current_timestamp()),
            episode_trigger=lambda ep_nr: True,  # record every episode
        )
    )
    # Apply the same wrapping as during training, on a single-env vector env.
    record_env = wrap_env(as_vec_env(record_env)[0], {})
    
    policy.reset_sde_noise(1)
    
    def record(max_steps: int):
        """Run the current policy in ``record_env`` for ``max_steps`` steps.

        Actions are sampled from the policy's action distribution without tracking
        gradients. Rewards and termination flags are ignored; the AutoResetWrapper
        around the env takes care of starting new episodes.
        """
        with torch.no_grad():
            obs, _ = record_env.reset()
            for _step in range(max_steps):
                actions_dist, _ = policy.process_obs(torch.tensor(obs, device=device))
                actions = actions_dist.sample().detach().cpu().numpy()
                obs, _reward, _terminated, _truncated, _info = record_env.step(actions)
    
    record(5_000)
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    print('closing record_env')
    record_env.close()
    print('record_env closed')