In [1]:
import inspect
import os
import time
from pathlib import Path

import wandb
from gymnasium import Env
from gymnasium.vector import VectorEnv

from src.datetime import get_current_timestamp
from src.model_db.model_db import ModelDB
from src.model_db.dummy_model_db import DummyModelDB
from src.reinforcement_learning.algorithms.policy_mitosis.mitosis_policy_info import MitosisPolicyInfo
from src.model_db.tiny_model_db import TinyModelDB
from src.module_analysis import count_parameters, get_gradients_per_parameter
from src.moving_averages import ExponentialMovingAverage
from src.networks.core.tensor_shape import TensorShape
from src.networks.core.torch_wrappers.torch_net import TorchNet
from src.reinforcement_learning.core.action_selectors.predicted_std_action_selector import PredictedStdActionSelector
from src.reinforcement_learning.core.action_selectors.state_dependent_noise_action_selector import \
    StateDependentNoiseActionSelector
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_gae_and_returns
from src.reinforcement_learning.core.objectives import ObjectiveLoggingConfig
from src.reinforcement_learning.core.policies.base_policy import BasePolicy
from src.reinforcement_learning.gym.normalize_reward_wrapper import NormalizeRewardWrapper
from src.networks.core.seq_net import SeqNet
from src.schedulers import FixedValueScheduler, OneStepRecursiveScheduler
from src.stopwatch import Stopwatch
from src.summary_statistics import format_summary_statics
from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy
from typing import Any, SupportsFloat, Optional
from gymnasium.wrappers import RecordVideo, AutoResetWrapper, NormalizeReward, TransformReward, TransformObservation, ClipAction
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.algorithms.a2c.a2c import A2C
from src.reinforcement_learning.algorithms.ppo.ppo import PPO, PPOLoggingConfig
from src.reinforcement_learning.core.normalization import NormalizationType
from src.reinforcement_learning.gym.step_skip_wrapper import StepSkipWrapper
from src.reinforcement_learning.algorithms import policy_optimization_base
from src.torch_device import set_default_torch_device
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
from torch.distributions import Normal, Categorical

import torch
from torch import optim, nn
import torch.distributions as dist
import gymnasium as gym
import numpy as np

from src.torch_functions import antisymmetric_power
from src.weight_initialization import orthogonal_initialization

%load_ext autoreload
%autoreload 2

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:
policy_id: str
policy: Optional[BasePolicy]
steps_trained: int


def get_policy(create_new_if_exists: bool):
    """Return the notebook-global policy, creating and/or restoring it first.

    A new policy is built with ``init_policy()`` when no global ``policy`` exists yet,
    or unconditionally when ``create_new_if_exists`` is true. In that case the globals
    ``policy_id`` (current timestamp), ``policy`` and ``steps_trained`` (0) are replaced.

    Independently of that, when the global ``parent_policy_id`` is not ``None`` the
    parent's state dict is loaded into the policy through ``policy_db`` and
    ``steps_trained`` is taken from the parent's stored ``model_info``.

    Args:
        create_new_if_exists: Build a fresh policy even if one is already in memory.

    Returns:
        Tuple ``(policy_id, policy)``.
    """
    global policy_id, policy, steps_trained

    already_loaded = 'policy' in globals()
    if not already_loaded:
        print('No policy in RAM, creating a new one')

    if create_new_if_exists or not already_loaded:
        policy_id = get_current_timestamp()
        policy = init_policy()
        steps_trained = 0
        print(f'New policy {policy_id} created')

    if parent_policy_id is not None:
        # Continue from the parent's weights and its recorded training progress.
        parent_entry = policy_db.load_model_state_dict(policy, parent_policy_id)
        steps_trained = parent_entry['model_info']['steps_trained']
        print(f'Loading state dict from policy {parent_policy_id}')

    print(f'Using policy {policy_id} with parent policy {parent_policy_id}')
    return policy_id, policy

def init_policy():
    """Build a freshly initialised actor-critic policy.

    The network defined here has two independent branches that both receive the raw
    observation: an actor branch (embedding + additive skip net) and a critic branch
    (embedding + additive skip net + linear regressor with a single output).
    The result is wrapped in an ``ActorCriticPolicy`` together with a
    ``PredictedStdActionSelector``.

    Besides its local imports this function relies on the module-level names
    ``np``, ``orthogonal_initialization`` and ``PredictedStdActionSelector``.
    Its source text is captured elsewhere in this notebook with ``inspect.getsource``.

    Returns:
        The new ``ActorCriticPolicy`` instance.
    """
    import torch
    from torch import nn

    # NOTE(review): SeqNet, SquashedDiagGaussianActionSelector, AdditiveSkipNet and
    # FullyConnectedAdditiveSkipNet are imported but not used by the active code below.
    from src.networks.core.seq_net import SeqNet
    from src.reinforcement_learning.core.action_selectors.squashed_diag_gaussian_action_selector import \
        SquashedDiagGaussianActionSelector
    from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy
    from src.networks.skip_nets.additive_skip_net import AdditiveSkipNet, FullyConnectedAdditiveSkipNet, \
        FullyConnectedUnweightedAdditiveSkipNet

    # in_size = 376
    # action_size = 17
    # actor_out_sizes = [512, 512, 256, 256, 256, 256, 256, 256]
    # critic_out_sizes = [512, 512, 256, 256, 256, 1]

    # Observation / action dimensionality of the network.
    # NOTE(review): presumably the sizes of the env selected via `env_name` - confirm when switching envs.
    in_size = 17
    action_size = 6

    # Depth (number of skip-net layers) and width (features per layer) of each branch.
    actor_layers = 3
    actor_features = 96

    critic_layers = 2
    critic_features = 96

    # hidden_activation_function = nn.ELU()
    # One activation module instance per branch, shared by every layer of that branch.
    actor_hidden_activation_function = nn.Tanh()
    critic_hidden_activation_function = nn.ReLU()

    class A2CNetwork(nn.Module):
        """Two-branch network returning ``(actor_out, critic_out)`` for an observation batch."""

        def __init__(self):
            super().__init__()

            # Actor branch: project the observation to `actor_features`, then apply the skip net.
            self.actor_embedding = nn.Sequential(nn.Linear(in_size, actor_features), actor_hidden_activation_function)
            self.actor = FullyConnectedUnweightedAdditiveSkipNet.from_layer_provider(
                # Each skip-net layer consists of two orthogonally initialised linear layers, each
                # followed by an activation. The last layer uses a fresh nn.Tanh(); since the hidden
                # activation is also Tanh, both branches of the conditional currently apply Tanh.
                layer_provider=lambda layer_nr, is_last_layer, in_features, out_features: nn.Sequential(
                    orthogonal_initialization(nn.Linear(in_features, out_features), gain=np.sqrt(2)),
                    nn.Tanh() if is_last_layer else actor_hidden_activation_function,
                    # NOTE(review): the second linear layer also maps in_features -> out_features, so
                    # this only composes when both are equal - presumably guaranteed by num_features.
                    orthogonal_initialization(nn.Linear(in_features, out_features), gain=np.sqrt(2)),
                    nn.Tanh() if is_last_layer else actor_hidden_activation_function,
                ),
                num_layers=actor_layers,
                num_features=actor_features,
                # weights_trainable=True,
                # initial_skip_connection_weight=0.02,
            )

            # Critic branch: same layout as the actor, with ReLU activations and its own depth.
            self.critic_embedding = nn.Sequential(nn.Linear(in_size, critic_features), critic_hidden_activation_function)
            self.critic = FullyConnectedUnweightedAdditiveSkipNet.from_layer_provider(
                layer_provider=lambda layer_nr, is_last_layer, in_features, out_features: nn.Sequential(
                    orthogonal_initialization(nn.Linear(in_features, out_features), gain=np.sqrt(2)),
                    critic_hidden_activation_function,
                    orthogonal_initialization(nn.Linear(in_features, out_features), gain=np.sqrt(2)),
                    critic_hidden_activation_function,
                ),
                num_layers=critic_layers,
                num_features=critic_features,
#                 weights_trainable=True,
#                 initial_skip_connection_weight=0.02,
            )
            # Maps the critic features to a single output per observation.
            self.critic_regressor = nn.Linear(critic_features, 1)

        def forward(self, x: torch.Tensor):
            """Run both branches on ``x`` and return ``(actor_out, critic_out)``."""
            actor_out = self.actor(self.actor_embedding(x))
            critic_out = self.critic_regressor(self.critic(self.critic_embedding(x)))
            return actor_out, critic_out

    # Alternative action selectors kept for reference. They refer to `actor_out_sizes`, which is
    # only defined in the commented-out block near the top of this function.
    # return ActorCriticPolicy(A2CNetwork(), SquashedDiagGaussianActionSelector(
    #         latent_dim=actor_out_sizes[-1],
    #         action_dim=action_size,
    #         std=0.1,
    #         std_learnable=False,
    #         action_net_initialization=lambda module: orthogonal_initialization(module, gain=0.01),
    #     ))
    # return ActorCriticPolicy(A2CNetwork(), StateDependentNoiseActionSelector(
    #     latent_dim=actor_out_sizes[-1],
    #     action_dim=action_size,
    #     initial_std=0.03,
    #     squash_output=True,
    #     use_full_stds=False,
    #     learn_sde_features=False,
    #     action_net_initialization=lambda module: orthogonal_initialization(module, gain=0.01),
    # ))
    # The selector's latent size equals the actor branch width (actor_features).
    return ActorCriticPolicy(A2CNetwork(), PredictedStdActionSelector(
        latent_dim=actor_features,
        action_dim=action_size,
        base_std=0.2,
        squash_output=True,
        action_net_initialization=lambda module: orthogonal_initialization(module, gain=0.01),
        log_std_net_initialization=lambda module: orthogonal_initialization(module, gain=0.1),
    ))
# Exponential moving average (alpha=0.25) fed with the mean episode score of each rollout
# by on_rollout_done.
score_mean_ema = ExponentialMovingAverage(alpha=0.25)
# on_optimization_done calls reset() on this and reports the returned value as `time`
# (presumably the seconds elapsed since the previous reset - confirm in Stopwatch).
stopwatch = Stopwatch()
# Highest mean episode score of any rollout so far; -1e6 is a "nothing seen yet" sentinel.
# on_rollout_done saves a checkpoint whenever a rollout reaches or beats this value.
best_iteration_score = -1e6

def on_rollout_done(rl: PPO, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Score the finished rollout and checkpoint the policy when it is the best so far.

    Undiscounted returns (gamma = 1, lambda = 1, zero value estimates, no normalisation)
    are computed for the rollout - from ``info['rollout']['raw_rewards']`` when that key is
    present, otherwise from the buffer's own rewards. The returns located at episode starts
    are used as per-episode scores; their mean is the iteration score.

    Side effects:
        * updates the global ``score_mean_ema`` and, on a new best, ``best_iteration_score``
        * on a new best (ties included) saves the global ``policy`` through ``policy_db``
        * writes ``info['episode_scores']`` and ``info['score_moving_average']`` for
          ``on_optimization_done``

    Args:
        rl: The running PPO instance; only ``rl.buffer`` is used.
        step: Current training step (unused).
        info: Mutable info dict of the current iteration.
        scheduler_values: Current scheduler values (unused).
    """
    global best_iteration_score

    buffer = rl.buffer
    rollout_info = info['rollout']
    # Treat the step after the rollout as "not done", with zero bootstrap values.
    no_final_dones = np.zeros_like(buffer.episode_starts[0], dtype=bool)

    if 'raw_rewards' in rollout_info:
        unscaled_rewards = rollout_info['raw_rewards']
        num_steps = len(unscaled_rewards)
        _, undiscounted_returns = compute_gae_and_returns(
            value_estimates=np.zeros_like(buffer.rewards[:num_steps]),
            rewards=unscaled_rewards,
            episode_starts=buffer.episode_starts[:num_steps],
            last_values=np.zeros_like(buffer.rewards[0], dtype=float),
            last_dones=no_final_dones,
            gamma=1.0,
            gae_lambda=1.0,
            normalize_rewards=None,
            normalize_advantages=None,
        )
    else:
        _, undiscounted_returns = buffer.compute_gae_and_returns(
            last_values=torch.zeros_like(buffer.value_estimates[0]),
            last_dones=no_final_dones,
            gamma=1.0,
            gae_lambda=1.0,
            normalize_advantages=None,
            normalize_rewards=None,
        )

    # The return at an episode's first step is that episode's score.
    episode_start_mask = buffer.episode_starts[:buffer.pos]
    episode_scores = undiscounted_returns[episode_start_mask]

    iteration_score = episode_scores.mean()
    score_moving_average = score_mean_ema.update(iteration_score)

    is_new_best = iteration_score >= best_iteration_score
    if is_new_best:
        best_iteration_score = iteration_score
        checkpoint_info = {
            'score': iteration_score.item(),
            'steps_trained': steps_trained,
            'wrap_env_source_code': wrap_env_source_code_source,
            'init_policy_source_code': init_policy_source
        }
        policy_db.save_model_state_dict(
            model=policy,
            model_id=policy_id,
            parent_model_id=parent_policy_id,
            model_info=checkpoint_info,
        )

    info['episode_scores'] = episode_scores
    info['score_moving_average'] = score_moving_average

def on_optimization_done(rl: PPO, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Print a one-line training report for the finished optimisation phase and log it to wandb.

    Adds the number of buffered steps (``rl.buffer.pos``) to the global ``steps_trained``,
    resets the global ``stopwatch``, and expects ``info['episode_scores']`` and
    ``info['score_moving_average']`` to have been set by ``on_rollout_done``.
    Nothing is sent to wandb when ``wandb_run.disabled`` is truthy.

    Args:
        rl: The running PPO instance.
        step: Current training step; printed and used as the wandb step.
        info: Info dict of the current iteration (objectives, advantages, grad norms, ...).
        scheduler_values: Current scheduler values (unused).
    """
    global steps_trained

    def mean_and_std(values, mean_format: str):
        # "mean ± std" summary without min/max.
        return format_summary_statics(
            values,
            mean_format=mean_format,
            std_format='5.3f',
            min_value_format=None,
            max_value_format=None,
        )

    def mean_std_and_range(values):
        # "mean ± std [min, max]" summary; same formats for advantages and gradient norms.
        return format_summary_statics(
            values,
            mean_format=' 6.3f',
            std_format='.1f',
            min_value_format=' 7.3f',
            max_value_format='6.3f',
        )

    time_taken = stopwatch.reset()
    steps_trained += rl.buffer.pos

    episode_scores = info['episode_scores']
    score_moving_average = info['score_moving_average']

    # The names below also appear in the printed report through the f-string `=` specifier.
    scores = format_summary_statics(
        episode_scores,
        mean_format=' 6.3f',
        std_format='4.3f',
        min_value_format=' 6.3f',
        max_value_format='5.3f',
    )
    advantages = mean_std_and_range(info['advantages'])
    abs_actor_obj = mean_and_std(rl.weigh_actor_objective(torch.abs(info['raw_actor_objective'])), ' 5.3f')
    if info['weighted_entropy_objective'] is None:
        entropy_obj = None
    else:
        entropy_obj = mean_and_std(info['weighted_entropy_objective'], '5.3f')
    critic_obj = mean_and_std(info['weighted_critic_objective'], '5.3f')
    resets = format_summary_statics(
        rl.buffer.episode_starts.astype(int).sum(axis=0),
        mean_format='.2f',
        std_format=None,
        min_value_format='1d',
        max_value_format=None,
    )
    kl_div = info['actor_kl_divergence'][-1]
    grad_norm = mean_std_and_range(info['grad_norm'])
    rollout_action_stds = mean_and_std(info['rollout']['action_stds'], '5.3f')
    ppo_epochs = info['nr_ppo_epochs']
    ppo_updates = info['nr_ppo_updates']
    expl_var = rl.buffer.compute_critic_explained_variance(info['returns'])

    report_parts = [
        f"{step = : >7}, ",
        f"{scores = :s}, ",
        f'score_ema = {score_moving_average: 6.3f}, ',
        f"{advantages = :s}, ",
        f"{abs_actor_obj = :s}, ",
    ]
    if entropy_obj is not None:
        report_parts.append(f"{entropy_obj = :s}, ")
    report_parts.extend([
        f"rollout_stds = {rollout_action_stds:s}, ",
        f"{critic_obj = :s}, ",
        f"{expl_var = :.3f}, ",
        f"{kl_div = :.4f}, ",
        f"{ppo_epochs = }, ",
        f"{ppo_updates = }, ",
        f"{grad_norm = :s}, ",
        f"{resets = :s}, ",
        f"time = {time_taken:4.1f} \n",
    ])
    print(''.join(report_parts))
    print()

    if wandb_run.disabled:
        return

    metrics = {
        'scores': episode_scores,
        'advantages': wandb.Histogram(info['advantages']),
        'actor_obj': wandb.Histogram(rl.weigh_actor_objective(info['raw_actor_objective'])),
        'abs_actor_obj': wandb.Histogram(rl.weigh_actor_objective(torch.abs(info['raw_actor_objective']))),
        'critic_obj': wandb.Histogram(info['weighted_critic_objective']),
        'expl_var': expl_var,
        'kl_div': kl_div,
        'ppo_epochs': ppo_epochs,
        'ppo_updates': ppo_updates,
        'grad_norm': wandb.Histogram(info['grad_norm']),
        'resets': wandb.Histogram(rl.buffer.episode_starts.astype(int).sum(axis=0)),
        'time_taken': time_taken,
    }
    wandb_run.log(metrics, step=step)

# Use the first CUDA GPU when one is available and fall back to the CPU otherwise, so the
# notebook also runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'using device {device}')

def create_env(render_mode: str | None):
    """Create one environment with ``gym.make``.

    The environment id and the extra keyword arguments come from the module-level
    ``env_name`` and ``env_kwargs``.

    Args:
        render_mode: Forwarded to ``gym.make`` as ``render_mode``; may be ``None``.

    Returns:
        The environment returned by ``gym.make``.
    """
    new_env = gym.make(env_name, render_mode=render_mode, **env_kwargs)
    return new_env

def wrap_env(env_):
    """Wrap ``env_`` so that every reward is multiplied by 0.01.

    Args:
        env_: The environment to wrap.

    Returns:
        ``env_`` wrapped in a ``TransformRewardWrapper``.
    """
    from src.reinforcement_learning.gym.transform_reward_wrapper import TransformRewardWrapper
    from gymnasium.wrappers import RescaleAction

    def scale_reward(reward_):
        return 0.01 * reward_

    # Optional action rescaling, currently disabled:
    # env_ = RescaleAction(env_, min_action=-1.0, max_action=1.0)
    return TransformRewardWrapper(env_, scale_reward)

# Capture the source text of the env wrapper and the policy factory; on_rollout_done stores
# both strings in the model_info of every checkpoint it saves.
wrap_env_source_code_source = inspect.getsource(wrap_env)
init_policy_source = inspect.getsource(init_policy)

# Environment id and extra keyword arguments that create_env forwards to gym.make.
env_name = 'HalfCheetah-v4'
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.001 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.05 }
# Number of environment copies passed to parallelize_env_async.
num_envs = 32

# Model store used by get_policy (loading) and on_rollout_done (saving).
# NOTE(review): DummyModelDB presumably persists nothing - switch to the TinyModelDB line to keep checkpoints.
# policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'saved_models/rl/{env_name}')
policy_db = DummyModelDB[MitosisPolicyInfo]()
print(f'{policy_db = }')

# When not None, get_policy loads this policy's state dict into the (new) policy.
parent_policy_id=None  # '2024-04-28_20.57.23'
# NOTE(review): policy_action_std is not read anywhere in the visible cells.
policy_action_std=0.15
policy_id, policy = get_policy(create_new_if_exists=True)
print(f'{count_parameters(policy) = }')

# wandb logging is disabled; on_optimization_done checks wandb_run.disabled before logging.
# NOTE(review): if the line below is re-enabled, its result must be assigned to wandb_run.
# wandb.init(project=f'rl-{env_name}', config={'policy_id': policy_id})
wandb_run = wandb.init(mode='disabled')

# Build `num_envs` copies of the environment (no rendering) as one vectorised env.
env = parallelize_env_async(lambda: create_env(render_mode=None), num_envs)
# print(dict(policy.named_parameters()))
try:
    env = wrap_env(env)
    print(f'{env = }, {num_envs = } \n\n')
    # NOTE: anomaly detection is a debugging aid that slows down the backward pass;
    # consider turning it off for long runs once training is numerically stable.
    with torch.autograd.set_detect_anomaly(True):
        PPO(
            env=env,
            policy=policy.to(device),
            policy_optimizer=lambda pol: optim.AdamW(pol.parameters(), lr=5e-5),
            buffer_size=2500,
            gamma=0.99,
            gae_lambda=0.95,
            normalize_rewards=None,
            normalize_advantages=NormalizationType.Std,
            reduce_actor_objective=lambda obj: antisymmetric_power(obj, 1.5).mean(),
            weigh_actor_objective=lambda obj: 1.0 * obj,
            weigh_entropy_objective=lambda obj: 0.1 * obj.exp(),
            weigh_critic_objective=lambda obj: 0.5 * obj,
            ppo_max_epochs=10,
            ppo_kl_target=0.025,
            ppo_batch_size=500,
            action_ratio_clip_range=0.1,
            grad_norm_clip_value=0.5,
            # NOTE(review): the run log warns that this is set although SDE is not used.
            sde_noise_sample_freq=50,
            callback=Callback(
                on_rollout_done=on_rollout_done,
                rollout_schedulers={},
                on_optimization_done=on_optimization_done,
                optimization_schedulers={},
            ),
            logging_config=PPOLoggingConfig(log_returns=True, log_advantages=True, log_grad_norm=True,
                                            log_rollout_infos=True, log_rollout_action_stds=True,
                                            log_actor_kl_divergence=True,
                                            actor_objective=ObjectiveLoggingConfig(log_raw=True),
                                            entropy_objective=ObjectiveLoggingConfig(log_weighted=True),
                                            critic_objective=ObjectiveLoggingConfig(log_weighted=True), ),
            torch_device=device,
        ).train(5_000_000)
except KeyboardInterrupt:
    # Interrupting the kernel is the normal way to stop training early.
    print('keyboard interrupt')
finally:
    # Release every resource independently: a failure while closing the envs must not
    # prevent the model db from being closed or the wandb run from being finished.
    try:
        print('closing envs')
        time.sleep(2.5)
        env.close()
        print('envs closed')
    finally:
        try:
            policy_db.close()
            print('model db closed')
        finally:
            wandb_run.finish()


print('done')

using device cuda:0
policy_db = DummyModelDB()
New policy 2024-05-26_16.49.46 created
Using policy 2024-05-26_16.49.46 with parent policy None
count_parameters(policy) = 97837
env = <TransformRewardWrapper<AsyncVectorEnv instance>>, num_envs = 32 


 SDE noise sample freq is set to 50 despite not using SDE 


step =    2500, scores = -16.125 ± 11.846 [-50.132, 8.898], score_ema = -16.125, advantages = -0.080 ± 1.0 [ -5.780,  4.894], abs_actor_obj =  2.895 ± 3.780, entropy_obj = 0.125 ± 0.001, rollout_stds = 0.201 ± 0.022, critic_obj = 0.016 ± 0.007, expl_var = -35.993, kl_div = 0.0148, ppo_epochs = 10, ppo_updates = 50, grad_norm =  4.348 ± 0.3 [  3.818,  5.119], resets = 2.00 ≥ 2, time = 18.3 

step =    5000, scores = -5.114 ± 12.573 [-41.423, 33.268], score_ema = -13.372, advantages = -0.118 ± 1.0 [ -5.852,  5.258], abs_actor_obj =  3.024 ± 3.789, entropy_obj = 0.125 ± 0.001, rollout_stds = 0.203 ± 0.031, critic_obj = 0.006 ± 0.001, expl_var = -25.058, kl_div = 0.0255, ppo_epochs = 

In [None]:
from src.reinforcement_learning.gym.envs.singleton_vector_env import as_vec_env

# Roll out the current policy in a single rendering environment and record every episode.
record_env: gym.Env = create_env(render_mode='rgb_array')

# policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'saved_models/rl/{env_name}')
# policy_db.load_model_state_dict(policy, model_id='2024-05-24_16.15.39')

# Videos are written to <home>/Videos/rl/<timestamp>; deriving the folder from the current
# user's home directory avoids a machine-specific absolute path.
video_folder = Path.home() / 'Videos' / 'rl' / get_current_timestamp()

try:
    if 'render_fps' not in record_env.metadata:
        # RecordVideo needs a frame rate; fall back to 30 fps when the env does not declare one.
        record_env.metadata['render_fps'] = 30
    record_env = wrap_env(record_env)
    record_env = AutoResetWrapper(
        RecordVideo(record_env, video_folder=str(video_folder), episode_trigger=lambda ep_nr: True)
    )
    record_env, _ = as_vec_env(record_env)

    policy.reset_sde_noise(1)

    def record(max_steps: int):
        """Step the recording env for ``max_steps`` steps using actions sampled from ``policy``."""
        obs, _reset_info = record_env.reset()
        # Pure inference: no autograd graph is needed while sampling actions.
        with torch.no_grad():
            for _step in range(max_steps):
                actions_dist, _ = policy.process_obs(torch.tensor(obs, device=device))
                actions = actions_dist.sample().detach().cpu().numpy()
                obs, _reward, _terminated, _truncated, _step_info = record_env.step(actions)

    record(5_000)
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    print('closing record_env')
    record_env.close()
    print('record_env closed')