In [2]:
import sys
import time
from typing import Any, Iterable

import numpy as np
import torch
from torch import optim

import gymnasium as gym

from src.datetime import get_current_timestamp
from src.model_db.tiny_model_db import TinyModelDB
from src.module_analysis import count_parameters
from src.moving_averages import ExponentialMovingAverage
from src.np_functions import softmax
from src.reinforcement_learning.algorithms.policy_mitosis.async_policy_mitosis import AsyncPolicyMitosis
from src.reinforcement_learning.algorithms.policy_mitosis.policy_mitosis import PolicyMitosis
from src.reinforcement_learning.algorithms.policy_mitosis.policy_mitosis_base import PolicyWithEnvAndInfo
from src.reinforcement_learning.algorithms.ppo.ppo import PPOLoggingConfig, PPO
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_gae_and_returns
from src.reinforcement_learning.core.normalization import NormalizationType
from src.reinforcement_learning.core.policy_info import PolicyInfo
from src.reinforcement_learning.gym.envs.parallelize_env import parallelize_env_async
from src.stopwatch import Stopwatch
from src.summary_statistics import format_summary_statics
from src.torch_device import set_default_torch_device, get_torch_device

%load_ext autoreload
%autoreload 2

In [None]:

def init_policy():
    """
    Build a freshly initialised actor-critic policy for the mitosis run.

    The network has two independent MLP towers that both read the raw observation:
    an actor tower (8 layers, every layer followed by ELU) whose output is the latent
    fed to a squashed diagonal-Gaussian action selector with a fixed std of 0.1, and a
    critic tower (6 layers, ELU on all but the last) producing a single value.

    Imports are local to the function, presumably so the factory does not depend on
    notebook globals when invoked by the mitosis driver — TODO confirm.

    :return: an ``ActorCriticPolicy`` wrapping the network and action selector.
    """
    import torch
    from torch import nn

    from src.networks.core.seq_net import SeqNet
    from src.reinforcement_learning.core.action_selectors.squashed_diag_gaussian_action_selector import \
        SquashedDiagGaussianActionSelector
    from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy

    # Must match the environment's observation / action sizes (376 / 17 for Humanoid-v4).
    in_size = 376
    action_size = 17
    actor_out_sizes = [512, 512, 256, 256, 256, 256, 256, 256]
    critic_out_sizes = [512, 512, 256, 256, 256, 1]

    # One ELU instance is shared by all hidden layers (ELU has no parameters).
    hidden_activation_function = nn.ELU()

    def actor_layer(layer_nr, is_last_layer, in_features, out_features):
        # The actor's last layer is activated too: its output is a latent, not the action.
        linear = nn.Linear(in_features, out_features)
        return nn.Sequential(linear, hidden_activation_function)

    def critic_layer(layer_nr, is_last_layer, in_features, out_features):
        # The critic's last layer stays linear so the value estimate is unbounded.
        linear = nn.Linear(in_features, out_features)
        activation = nn.Identity() if is_last_layer else hidden_activation_function
        return nn.Sequential(linear, activation)

    class A2CNetwork(nn.Module):

        def __init__(self):
            super().__init__()

            self.actor = SeqNet.from_layer_provider(
                layer_provider=actor_layer,
                in_size=in_size,
                out_sizes=actor_out_sizes,
            )
            self.critic = SeqNet.from_layer_provider(
                layer_provider=critic_layer,
                in_size=in_size,
                out_sizes=critic_out_sizes,
            )

        def forward(self, x: torch.Tensor):
            # Returns (actor latent, critic value) for the same observation batch.
            return self.actor(x), self.critic(x)

    network = A2CNetwork()
    action_selector = SquashedDiagGaussianActionSelector(
        latent_dim=actor_out_sizes[-1],
        action_dim=action_size,
        std=0.1,
        std_learnable=False,
    )
    return ActorCriticPolicy(network, action_selector)

def wrap_env(_env):
    """
    Apply the standard wrappers to a (vectorised) training environment.

    Rewards are multiplied by 0.01, and ``RescaleAction`` lets the agent act in
    [-1, 1], which is mapped onto the environment's own action range.

    :param _env: the environment to wrap.
    :return: the wrapped environment.
    """
    from src.reinforcement_learning.gym.envs.transform_reward_wrapper import TransformRewardWrapper
    from gymnasium.wrappers import RescaleAction

    reward_scaled_env = TransformRewardWrapper(_env, lambda reward: 0.01 * reward)
    return RescaleAction(reward_scaled_env, min_action=-1.0, max_action=1.0)

def train_func(policy_with_env_and_info: PolicyWithEnvAndInfo) -> tuple[int, float]:
    """
    Train a single policy with PPO for ``steps_per_iteration`` steps (one mitosis iteration).

    Reads the notebook-level globals ``steps_per_iteration`` and ``device``. Both are
    defined further down in this cell and must exist before this function is called.

    After every rollout, the mean episode score is computed from the undiscounted
    (gamma = lambda = 1) returns at the episode starts of that rollout. Episodes still
    running when the rollout ends contribute a partial score, because the bootstrap
    value is zero. These per-rollout means are smoothed with an exponential moving average.

    :param policy_with_env_and_info: mapping with ``'policy'`` (the module to train),
        ``'env'`` (the environment to train on) and ``'policy_info'`` (read for
        ``policy_id``, ``parent_policy_id``, ``steps_trained`` and ``score``).
    :return: ``(steps_trained, score)`` — the steps trained in this iteration and the
        moving average (factor 0.45) of the per-rollout mean episode score.
    """
    policy = policy_with_env_and_info['policy']
    env = policy_with_env_and_info['env']

    score = 0.0
    score_ema = ExponentialMovingAverage(0.45)
    rollout_stopwatch = Stopwatch()

    def compute_undiscounted_returns(rl: PPO, info: dict[str, Any]):
        """Undiscounted returns of the last rollout, bootstrapped with zero values."""
        if 'raw_rewards' in info['rollout']:
            # Prefer the logged raw rewards (presumably recorded before the reward
            # transformation in wrap_env — TODO confirm) when they are available.
            raw_rewards = info['rollout']['raw_rewards']
            _, returns = compute_gae_and_returns(
                value_estimates=np.zeros_like(rl.buffer.rewards[:len(raw_rewards)]),
                rewards=raw_rewards,
                episode_starts=rl.buffer.episode_starts[:len(raw_rewards)],
                last_values=np.zeros_like(rl.buffer.rewards[0], dtype=float),
                last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
                gamma=1.0,
                gae_lambda=1.0,
                normalize_rewards=None,
                normalize_advantages=None,
            )
        else:
            _, returns = rl.buffer.compute_gae_and_returns(
                last_values=torch.zeros_like(rl.buffer.value_estimates[0]),
                last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
                gamma=1.0,
                gae_lambda=1.0,
                normalize_advantages=None,
                normalize_rewards=None,
            )
        return returns

    def on_rollout_done(rl: PPO, step: int, info: dict[str, Any]):
        nonlocal score

        gamma_1_returns = compute_undiscounted_returns(rl, info)
        # With gamma = lambda = 1, the return at an episode start is that episode's score.
        episode_scores = gamma_1_returns[
            rl.buffer.episode_starts[:rl.buffer.pos]
        ]

        rollout_time = rollout_stopwatch.reset()

        resets: np.ndarray = rl.buffer.episode_starts.astype(int).sum(axis=0)
        resets_mean = resets.mean()
        resets_min = resets.min()

        if len(episode_scores) == 0:
            # No episode started during this rollout (every env survived the whole rollout).
            # The mean of an empty selection is NaN and would poison the EMA, and with it
            # the returned policy score, for the rest of training. Keep the previous score.
            print(f'{step:>6}: '
                  f'no episode started in this rollout, score unchanged, '
                  f'time = {rollout_time:5.2f}, '
                  f'resets = {resets_mean:5.2f} >= {resets_min:5.2f}')
            return

        score = episode_scores.mean()
        current_score_ema = score_ema.update(score)

        print(f'{step:>6}: '
              f'{score = :9.3f}, '
              f'score_ema = {current_score_ema:9.3f}, '
              f'time = {rollout_time:5.2f}, '
              f'resets = {resets_mean:5.2f} >= {resets_min:5.2f}')

    policy_info = policy_with_env_and_info['policy_info']
    policy_info_str = ('('
          f'policy_id = {policy_info["policy_id"]}, '
          f'parent_id = {policy_info["parent_policy_id"]}, '
          f'num_parameters = {count_parameters(policy)}, '
          f'previous_steps = {policy_info["steps_trained"]}, '
          f'previous_score = {policy_info["score"]:9.3f}'
          ')')

    print(f'Starting PPO with policy {policy_info_str:s} for {steps_per_iteration:_} steps')
    mitosis_iteration_stopwatch = Stopwatch()
    PPO(
        env=env,
        policy=policy.to(device),
        policy_optimizer=lambda pol: optim.Adam(pol.parameters(), lr=1e-5),
        buffer_size=2500,
        gamma=0.995,
        gae_lambda=1.0,
        normalize_rewards=None,
        normalize_advantages=NormalizationType.Std,
        weigh_actor_objective=lambda obj: 1.0 * obj,
        weigh_critic_objective=lambda obj: 0.5 * obj,
        ppo_max_epochs=10,
        ppo_kl_target=0.01,
        ppo_batch_size=500,
        action_ratio_clip_range=0.05,
        grad_norm_clip_value=2.0,
        callback=Callback(on_rollout_done=on_rollout_done),
        logging_config=PPOLoggingConfig(log_rollout_infos=True),
        torch_device=device,
    ).train(steps_per_iteration)

    print(f'Training finished for policy {policy_info_str:s}, end score = {score:9.3f}, time = {mitosis_iteration_stopwatch.time_passed():6.2f}')

    return steps_per_iteration, score_ema.get()

def select_policy_selection_probs(policy_infos: Iterable[PolicyInfo]) -> np.ndarray:
    """
    Convert the stored scores of candidate policies into selection probabilities.

    Scores are divided by their standard deviation, so that the softmax temperature acts
    on a scale-free quantity. They are then passed through ``softmax`` with temperature 0.9.

    :param policy_infos: policy infos whose ``'score'`` entries are used, in iteration order.
    :return: one selection probability per policy, in the same order as ``policy_infos``.
    """
    scores = np.array([policy_info['score'] for policy_info in policy_infos])
    if scores.size > 1:
        scores_std = scores.std()
        # With identical scores the std is 0, and dividing would turn every score into
        # NaN/inf. The scale is irrelevant then (all inputs are equal), so skip it.
        # A single score, or no scores, is never rescaled for the same reason.
        if scores_std > 0:
            scores = scores / scores_std
    return softmax(scores, temperature=0.9)

# Set to False to train on the CPU instead of the first CUDA device.
use_cuda = True
device = get_torch_device("cuda:0") if use_cuda else get_torch_device('cpu')

# NOTE(review): not referenced in this cell — init_policy hard-codes std=0.1. Confirm which is intended.
policy_action_std = 0.15
steps_per_iteration = 100_000

env_name = 'Humanoid-v4'
env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001}
num_envs = 4


def create_parallel_env():
    """Build ``num_envs`` parallel instances of ``env_name`` (configured with ``env_kwargs``)."""
    return parallelize_env_async(lambda: gym.make(env_name, **env_kwargs), num_envs)


mitosis_id = get_current_timestamp()
# mitosis_id = '2024-05-03_19.31.53'
# policy_db = TinyModelDB[PolicyInfo](base_path=f'E:/saved_models/rl/{env_name}/mitosis-{mitosis_id}')
# TODO: machine-specific absolute path; move to a configurable base directory.
policy_db = TinyModelDB[PolicyInfo](base_path=f'C:/Users/domin/git/pytorch-starter/saved_models/rl/{env_name}/mitosis-{mitosis_id}')

# NOTE(review): this env is not passed to AsyncPolicyMitosis (it builds its own via
# create_env); here it is only closed in the finally block — confirm it is still needed.
unwrapped_env = create_parallel_env()
try:
    print(f'Starting mitosis with id {mitosis_id}')
    AsyncPolicyMitosis(
        num_workers=2,
        policy_db=policy_db,
        train_policy_function=train_func,
        create_env=create_parallel_env,
        new_init_policy_function=init_policy,
        new_wrap_env_function=wrap_env,
        select_policy_selection_probs=select_policy_selection_probs,
        min_base_ancestors=5,
        rng_seed=None,
    ).train_with_mitosis(1000)
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    try:
        print('closing envs')
        # Grace period before closing the envs (2.5 s, presumably chosen empirically).
        time.sleep(2.5)
        unwrapped_env.close()
        print('envs closed')
    finally:
        # Always close the model DB, even if closing the envs raised.
        policy_db.close()
        print('model db closed')

print('done')

Starting mitosis with id 2024-05-16_18.30.57
Started training iteration for policy: 2024-05-16_18.30.58~UeQRBr, parent policy id: None
Started training iteration for policy: 2024-05-16_18.30.58~AWrUA7, parent policy id: None
Finished training iteration for policy: 2024-05-16_18.30.58~UeQRBr
Started training iteration for policy: 2024-05-16_18.31.06~htveIW, parent policy id: None
Finished training iteration for policy: 2024-05-16_18.30.58~AWrUA7
Started training iteration for policy: 2024-05-16_18.31.06~q6nQTS, parent policy id: None
Finished training iteration for policy: 2024-05-16_18.31.06~htveIW
Started training iteration for policy: 2024-05-16_18.31.10~EYR5xp, parent policy id: None
Finished training iteration for policy: 2024-05-16_18.31.06~q6nQTS
Started training iteration for policy: 2024-05-16_18.31.10~iSbh4i, parent policy id: None
Finished training iteration for policy: 2024-05-16_18.31.10~EYR5xp
policy selection probs = 
	2024-05-16_18.30.58~UeQRBr: p = 0.625930, scores =  3