In [1]:
import inspect
import os
import time
from pathlib import Path

from gymnasium import Env

from src.datetime import get_current_timestamp
from src.model_db import ModelDB, DummyModelDB
from src.module_analysis import count_parameters, get_gradients_per_parameter
from src.moving_averages import ExponentialMovingAverage
from src.networks.core.tensor_shape import TensorShape
from src.networks.core.torch_wrappers.torch_net import TorchNet
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_gae_and_returns
from src.reinforcement_learning.core.objectives import ObjectiveLoggingConfig
from src.reinforcement_learning.core.policies.base_policy import BasePolicy
from src.reinforcement_learning.gym.envs.normalize_reward_wrapper import NormalizeRewardWrapper
from src.stopwatch import Stopwatch
from src.summary_statistics import format_summary_statics
from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy
from typing import Any, SupportsFloat, Optional
from gymnasium.wrappers import RecordVideo, AutoResetWrapper, NormalizeReward, TransformReward, TransformObservation, \
    RescaleAction, ClipAction
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.a2c.a2c import A2C
from src.reinforcement_learning.ppo.ppo import PPO, PPOLoggingConfig
from src.reinforcement_learning.core.normalization import NormalizationType
from src.reinforcement_learning.gym.envs.step_skip_wrapper import StepSkipWrapper
from src.reinforcement_learning.core.rl_base import RLBase
from src.torch_device import set_default_torch_device
from src.reinforcement_learning.gym.envs.parallelize_env import parallelize_env_async
from torch.distributions import Normal, Categorical

import torch
from torch import optim, nn
import torch.distributions as dist
from src.networks.core.seq_net import SeqNet
import gymnasium as gym
import numpy as np

%load_ext autoreload
%autoreload 2

In [4]:
policy_id: str
policy: Optional[BasePolicy]
def get_policy(create_new_if_exists: bool):
    
    global policy_id, policy
    
    policy_in_ram = 'policy' in globals()
    if not policy_in_ram or create_new_if_exists:
        if not policy_in_ram:
            print('No policy in RAM, creating a new one')
        
        policy_id = get_current_timestamp()
        policy = init_policy()
        print(f'New policy {policy_id} created')
    
    if parent_policy_id is not None:
        policy_db.load_model_state_dict(policy, parent_policy_id)
        print(f'Loading state dict from policy {parent_policy_id}')
    
    return policy_id, policy

def init_policy():
    class A2CNetwork(nn.Module):

        def __init__(self):
            super().__init__()
            
            in_size = 376
            shared_out_sizes = []
            actor_out_sizes = [512, 512, 256, 256, 256, 256, 256, 256, 17]
            critic_out_sizes = [512, 512, 256, 256, 256, 1]
            
            hidden_activation_function = nn.ELU()
            actor_out_activation_function = nn.Tanh()
            critic_out_activation_function = nn.Identity()
            
            self.has_shared = len(shared_out_sizes) > 0
            
            if self.has_shared:
                self.shared = SeqNet.from_layer_provider(
                    layer_provider=lambda layer_nr, is_last_layer, in_features, out_features: nn.Sequential(
                        nn.Linear(in_features, out_features),
                        hidden_activation_function
                    ),
                    in_size=in_size,
                    out_sizes=shared_out_sizes
                )
            else:
                self.shared = TorchNet(nn.Identity(), in_shape=TensorShape(features=in_size), out_shape=TensorShape(features=in_size))

            self.actor = SeqNet.from_layer_provider(
                layer_provider=lambda layer_nr, is_last_layer, in_features, out_features: nn.Sequential(
                    nn.Linear(in_features, out_features),
                    actor_out_activation_function if is_last_layer else hidden_activation_function
                ),
                in_size=self.shared.out_shape.get_definite_features(),
                out_sizes=actor_out_sizes
            )

            self.critic = SeqNet.from_layer_provider(
                layer_provider=lambda layer_nr, is_last_layer, in_features, out_features: nn.Sequential(
                    nn.Linear(in_features, out_features),
                    critic_out_activation_function if is_last_layer else hidden_activation_function
                ),
                in_size=self.shared.out_shape.get_definite_features(),
                out_sizes=critic_out_sizes
            )

        def forward(self, x: torch.Tensor):
            if self.has_shared:
                shared_out = self.shared(x)
            else:
                shared_out = x

            return self.actor(shared_out), self.critic(shared_out)

    return ActorCriticPolicy(A2CNetwork(), lambda action_logits: dist.Normal(loc=action_logits, scale=policy_action_std))

score_mean_ema = ExponentialMovingAverage(alpha=0.1)
stopwatch = Stopwatch()
best_iteration_score = -1e6

def on_rollout_done(rl: PPO, step: int, info: dict[str, Any]):    
    if 'unnormalized_rewards' in info['rollout']:
        unnormalized_rewards = info['rollout']['unnormalized_rewards']
        _, gamma_1_returns = compute_gae_and_returns(
            value_estimates=np.zeros_like(rl.buffer.rewards[:len(unnormalized_rewards)]),
            rewards=unnormalized_rewards,
            episode_starts=rl.buffer.episode_starts[:len(unnormalized_rewards)],
            last_values=np.zeros_like(rl.buffer.rewards[0], dtype=float),
            last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
            gamma=1.0,
            gae_lambda=1.0,
            normalize_rewards=None,
            normalize_advantages=None,
        )
    else:
        _, gamma_1_returns = rl.buffer.compute_gae_and_returns(
            last_values=torch.zeros_like(rl.buffer.value_estimates[0]),
            last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
            gamma=1.0,
            gae_lambda=1.0,
            normalize_advantages=None,
            normalize_rewards=None,
        )
    
    episode_scores = gamma_1_returns[
        rl.buffer.episode_starts[:rl.buffer.pos]
    ]
    
    iteration_score = episode_scores.mean()
    score_moving_average = score_mean_ema.update(iteration_score)
    global best_iteration_score
    if iteration_score >= best_iteration_score:
        best_iteration_score = iteration_score
        policy_db.save_model_state_dict(
            model=policy,
            model_id=policy_id,
            parent_model_id=parent_policy_id,
            model_info={'score': iteration_score.item()},
            init_function=init_policy
        )
        
    info['episode_scores'] = episode_scores
    info['score_moving_average'] = score_moving_average

def on_optimization_done(rl: PPO, step: int, info: dict[str, Any]):
    time_taken = stopwatch.reset()
    
    episode_scores = info['episode_scores']
    score_moving_average = info['score_moving_average']
    
    scores = format_summary_statics(
        episode_scores, 
        mean_format=' 6.1f',
        std_format='4.1f',
        min_value_format=' 6.1f',
        max_value_format='5.1f',
    )
    advantages = format_summary_statics(
        info['advantages'], 
        mean_format=' 6.3f',
        std_format='.1f',
        min_value_format=' 7.3f',
        max_value_format='6.3f',
    )
    abs_actor_obj = format_summary_statics(
        rl.weigh_actor_objective(info['raw_actor_objective']),  
        mean_format=' 5.3f',
        std_format='5.3f',
        min_value_format=None,
        max_value_format=None,
    )
    critic_obj = format_summary_statics(
        info['weighted_critic_objective'], 
        mean_format='5.3f',
        std_format='5.3f',
        min_value_format=None,
        max_value_format=None,
    )
    resets = format_summary_statics(
        rl.buffer.episode_starts.astype(int).sum(axis=0), 
        mean_format='.2f',
        std_format=None,
        min_value_format='1d',
        max_value_format=None,
    )
    kl_div = format_summary_statics(
        info['actor_kl_divergence'], 
        mean_format='6.4f',
        std_format='7.5f',
        min_value_format=None,
        max_value_format='6.4f',
    )
    ppo_epochs = info['nr_ppo_epochs']
    expl_var = rl.buffer.compute_critic_explained_variance(info['returns'])
    print(f"{step = : >7}, "
          f"{scores = :s}, "
          f'score_ema = {score_moving_average: 6.1f}, '
          f"{advantages = :s}, "
          f"{abs_actor_obj = :s}, "
          f"{critic_obj = :s}, "
          f"{expl_var = :.3f}, "
          f"{resets = :s}, "
          f"{kl_div = :s}, "
          f"{ppo_epochs = }, "
          f"time = {time_taken:4.1f}")
    
    # for param_name, param_grad in get_gradients_per_parameter(rl.policy, param_type='weight'):
    #     print(f'{param_name + ".grad":<50}: ' + format_summary_statics(
    #         param_grad,
    #         mean_format=' 8.5f',
    #         std_format='.5f',
    #         min_value_format=' 8.5f',
    #         max_value_format='7.5f',
    #     ))
    # 
    # print('\n')

device = set_default_torch_device("cuda:0") if True else set_default_torch_device('cpu')
print(f'using device {device}')

def create_env(render_mode: str | None):
    return gym.make(env_name, render_mode=render_mode, **env_kwargs)

def wrap_env(_env: Env):
    # _env = NormalizeRewardWrapper(_env, gamma=gamma)
    # _env = TransformObservation(_env, lambda _obs: _obs / 255)
    # _env = TransformReward(_env, lambda _reward: 0.01 * _reward) 
    _env = RescaleAction(_env, min_action=-1.0, max_action=1.0)
    _env = ClipAction(_env)
    return _env

env_name = 'Humanoid-v4'
env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
num_envs = 16
    
# policy_db = ModelTinyDB(base_path=f'saved_models/rl/{env_name}')
policy_db = DummyModelDB()

parent_policy_id=None
policy_action_std=0.15

policy_id, policy = get_policy(create_new_if_exists=True)
print(f'{count_parameters(policy) = }')

gamma = 0.995

env = parallelize_env_async(lambda: create_env(render_mode=None), num_envs)
try:
    env = wrap_env(env)
    print(f'{env = }, {num_envs = }')
    
    PPO(
        env=env,
        policy=policy.to(device),
        policy_optimizer=lambda pol: optim.Adam(pol.parameters(), lr=1e-5),
        buffer_size=2500,
        gamma=gamma,
        gae_lambda=1.0,
        normalize_rewards=None,
        normalize_advantages=NormalizationType.Std,
        weigh_actor_objective=lambda obj: 1.0 * obj,
        weigh_critic_objective=lambda obj: 0.5 * obj,
        ppo_max_epochs=10,
        ppo_kl_target=0.01,
        ppo_batch_size=500,
        action_ratio_clip_range=0.02,
        callback=Callback(on_rollout_done=on_rollout_done, on_optimization_done=on_optimization_done),
        logging_config=PPOLoggingConfig(log_returns=True, log_advantages=True, 
                                        log_rollout_infos=True, log_actor_kl_divergence=True,
                                        actor_objective=ObjectiveLoggingConfig(log_raw=True), 
                                        critic_objective=ObjectiveLoggingConfig(log_weighted=True), )
    ).train(5_000_000)
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    print('closing envs')
    time.sleep(2.5)
    env.close()
    print('envs closed')
    policy_db.close()
    print('model db closed')

print('done')

using device cuda:0
New policy 2024-04-29_22.58.42 created
count_parameters(policy) = 1639186
env = <ClipAction<RescaleAction<AsyncVectorEnv instance>>>, num_envs = 16
step =    2500, scores =   16.4 ±  2.3 [   1.4,  25.4], score_ema =   16.4, advantages =  2.056 ± 1.0 [ -0.122,  5.132], abs_actor_obj = -2.098 ± 1.020, critic_obj = 49.084 ± 2.442, expl_var = -0.009, resets = 103.06 ≥ 99, kl_div = 0.0011 ± 0.00108 ≤ 0.0036, ppo_epochs = 10, time = 13.9
step =    5000, scores =   16.2 ±  2.3 [   1.0,  24.8], score_ema =   16.4, advantages =  1.718 ± 1.0 [ -1.165,  4.614], abs_actor_obj = -1.728 ± 1.088, critic_obj = 41.879 ± 1.317, expl_var = -0.140, resets = 104.94 ≥ 102, kl_div = 0.0051 ± 0.00484 ≤ 0.0157, ppo_epochs = 6, time =  9.4
step =    7500, scores =   16.4 ±  2.4 [   0.9,  26.8], score_ema =   16.4, advantages =  1.452 ± 1.0 [ -1.107,  4.478], abs_actor_obj = -1.322 ± 1.487, critic_obj = 39.422 ± 0.934, expl_var = -0.307, resets = 102.75 ≥ 99, kl_div = 0.0053 ± 0.00501 ≤ 0.016

In [3]:
record_env: gym.Env = create_env(render_mode='rgb_array')
try:
    if 'render_fps' not in record_env.metadata:
        record_env.metadata['render_fps'] = 30
    record_env = wrap_env(record_env)
    record_env = AutoResetWrapper(
        RecordVideo(record_env, video_folder=rf'C:\Users\domin\Videos\rl\{get_current_timestamp()}', episode_trigger=lambda ep_nr: True)
    )
    def record(max_steps: int):
        obs, info = record_env.reset()
        for step in range(max_steps):
            actions_dist = policy.predict_actions(obs)
            actions = actions_dist.sample().detach().cpu().numpy()
            obs, reward, terminated, truncated, info = record_env.step(actions)
    
    record(10000)
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    print('closing record_env')
    record_env.close()
    print('record_env closed')

Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-0.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-0.mp4


                                                             

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-0.mp4
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-1.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-1.mp4


                                                             

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-1.mp4
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-2.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-2.mp4


                                                             

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-2.mp4
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-3.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-3.mp4


                                                             

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-3.mp4
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-4.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-4.mp4


                                                             

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-4.mp4
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-5.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-5.mp4


                                                             

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-5.mp4
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-6.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-6.mp4


                                                             

keyboard interrupt
closing record_env
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-6.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-6.mp4


                                                             

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-28_14.27.40\rl-video-episode-6.mp4
record_env closed




In [8]:
torch.save(policy.state_dict(), f'saved_models/rl/{env_name}/{get_current_timestamp()}---6x96_6x64-elu--state_dict.pth')

In [30]:
from stable_baselines3 import A2C
from stable_baselines3.common.env_util import make_vec_env
import gymnasium as gym

# Parallel environments
vec_env = make_vec_env(lambda: gym.make('CartPole-v1', render_mode='rgb_array'), n_envs=4)

model = A2C("MlpPolicy", vec_env, verbose=2)
model.learn(total_timesteps=25000)
model.save("a2c_cartpole")

del model # remove to demonstrate saving and loading

model = A2C.load("a2c_cartpole")

obs = vec_env.reset()
for _ in range(100_000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = vec_env.step(action)
    vec_env.render("human")

In [11]:
import gymnasium
record_env = gymnasium.make("ALE/Pacman-ram-v5", render_mode='rgb_array')
try:
    record_env = AutoResetWrapper(
        RecordVideo(record_env, video_folder=r'C:\Users\domin\Videos\rl\2024-04-24.1', episode_trigger=lambda ep_nr: True)
    )
    def record(max_steps: int):
        obs, info = record_env.reset()
        for step in range(max_steps):
            # actions_dist = policy.predict_actions(obs)
            # actions = actions_dist.sample().detach().cpu().numpy()
            
            actions = 2
            
            obs, reward, terminated, truncated, info = record_env.step(actions)
            
            if terminated or truncated:
                break
    
    record(10000)
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    print('closing record_env')
    record_env.close()
    print('record_env closed')

Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-24.1\rl-video-episode-0.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-24.1\rl-video-episode-0.mp4


                                                                

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-24.1\rl-video-episode-0.mp4
closing record_env
Moviepy - Building video C:\Users\domin\Videos\rl\2024-04-24.1\rl-video-episode-1.mp4.
Moviepy - Writing video C:\Users\domin\Videos\rl\2024-04-24.1\rl-video-episode-1.mp4


                                                  

Moviepy - Done !
Moviepy - video ready C:\Users\domin\Videos\rl\2024-04-24.1\rl-video-episode-1.mp4
record_env closed




<OrderEnforcing<PassiveEnvChecker<AtariEnv<ALE/Pacman-ram-v5>>>>
