In [5]:

import time
from typing import Any, Iterable

import gymnasium as gym
import numpy as np
import torch
import torch.distributions as dist
from gymnasium.vector import VectorEnv
from gymnasium.wrappers import RescaleAction, ClipAction
from torch import optim, nn

from src.datetime import get_current_timestamp
from src.model_db.tiny_model_db import TinyModelDB
from src.module_analysis import count_parameters
from src.networks.core.seq_net import SeqNet
from src.networks.core.tensor_shape import TensorShape
from src.networks.core.torch_wrappers.torch_net import TorchNet
from src.reinforcement_learning.algorithms.ppo.ppo import PPO, PPOLoggingConfig
from src.reinforcement_learning.algorithms.policy_mitosis.policy_mitosis import PolicyMitosis, PolicyWithEnvAndInfo
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_gae_and_returns
from src.reinforcement_learning.core.normalization import NormalizationType
from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy
from src.reinforcement_learning.core.policy_info import PolicyInfo
from src.reinforcement_learning.gym.envs.parallelize_env import parallelize_env_async
from src.reinforcement_learning.gym.envs.transform_reward_wrapper import TransformRewardWrapper
from src.stopwatch import Stopwatch
from src.summary_statistics import format_summary_statics
from src.torch_device import set_default_torch_device

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:



def init_policy():
    """Create a new, untrained actor-critic policy for the mitosis run.

    The network consists of an (optional) shared trunk followed by separate
    actor and critic MLPs built with ``SeqNet``. The shared trunk is disabled
    here because ``shared_out_sizes`` is empty, so both heads read the raw
    376-feature input directly. The actor ends in 17 Tanh-squashed outputs and
    the critic in a single linear output.

    Actions are drawn from a Normal distribution centred on the actor output.
    Its scale is the module-level ``policy_action_std``, looked up each time a
    distribution is built.
    """

    def make_layer_provider(hidden_activation: nn.Module, final_activation: nn.Module):
        # SeqNet requests one layer at a time. Every layer is a Linear followed
        # by an activation, and the last layer may use a different activation.
        def provide_layer(layer_nr, is_last_layer, in_features, out_features):
            activation = final_activation if is_last_layer else hidden_activation
            return nn.Sequential(nn.Linear(in_features, out_features), activation)

        return provide_layer

    class A2CNetwork(nn.Module):

        def __init__(self):
            super().__init__()

            in_size = 376
            shared_out_sizes = []
            actor_out_sizes = [512, 512, 256, 256, 256, 256, 256, 256, 17]
            critic_out_sizes = [512, 512, 256, 256, 256, 1]

            # One activation instance per role, reused by every layer that needs it.
            hidden_act = nn.ELU()
            actor_final_act = nn.Tanh()
            critic_final_act = nn.Identity()

            self.has_shared = bool(shared_out_sizes)

            if not self.has_shared:
                # Pass-through placeholder; only its out_shape is used to size the heads.
                self.shared = TorchNet(nn.Identity(), in_shape=TensorShape(features=in_size), out_shape=TensorShape(features=in_size))
            else:
                self.shared = SeqNet.from_layer_provider(
                    layer_provider=make_layer_provider(hidden_act, hidden_act),
                    in_size=in_size,
                    out_sizes=shared_out_sizes
                )

            head_in_size = self.shared.out_shape.get_definite_features()

            self.actor = SeqNet.from_layer_provider(
                layer_provider=make_layer_provider(hidden_act, actor_final_act),
                in_size=head_in_size,
                out_sizes=actor_out_sizes
            )

            self.critic = SeqNet.from_layer_provider(
                layer_provider=make_layer_provider(hidden_act, critic_final_act),
                in_size=head_in_size,
                out_sizes=critic_out_sizes
            )

        def forward(self, x: torch.Tensor):
            features = self.shared(x) if self.has_shared else x
            return self.actor(features), self.critic(features)

    def action_distribution(action_logits):
        return dist.Normal(loc=action_logits, scale=policy_action_std)

    return ActorCriticPolicy(A2CNetwork(), action_distribution)

def wrap_env(_env: VectorEnv):
    """Apply the reward and action wrappers shared by every policy in the run.

    - Rewards are multiplied by 0.01.
    - RescaleAction exposes a [-1, 1] action range and maps it onto the wrapped
      environment's own action bounds.
    - ClipAction (outermost) clips incoming actions to that [-1, 1] range
      before they are rescaled.
    """
    reward_scaled = TransformRewardWrapper(_env, lambda _reward: 0.01 * _reward)
    unit_actions = RescaleAction(reward_scaled, min_action=-1.0, max_action=1.0)
    return ClipAction(unit_actions)

def train_func(policy_with_env_and_info: PolicyWithEnvAndInfo) -> tuple[int, float]:
    """Train a single policy with PPO for one mitosis iteration.

    Reads the module-level globals ``steps_per_iteration`` (number of training
    steps) and ``device`` (where the policy is moved before training).

    After every rollout the score is recomputed as the mean undiscounted return
    (gamma = 1, gae_lambda = 1, zero value estimates) at the episode-start
    positions of the rollout buffer. If a rollout contains no episode start,
    the previous score is kept (0.0 before the first scored rollout). Without
    this, the mean of an empty selection would yield NaN and break policy
    selection downstream.

    Args:
        policy_with_env_and_info: mapping that provides the keys ``policy``,
            ``env``, ``policy_id``, ``parent_policy_id`` and ``policy_info``
            (the latter with ``steps_trained`` and ``score``).

    Returns:
        ``(steps_per_iteration, score)``, where ``score`` is a plain Python
        float taken from the last rollout.
    """
    policy = policy_with_env_and_info['policy']
    env = policy_with_env_and_info['env']

    # todo: ems  (NOTE(review): presumably a moving average of the score across rollouts)
    score = 0.0
    rollout_stopwatch = Stopwatch()

    def on_rollout_done(rl: PPO, step: int, info: dict[str, Any]):
        nonlocal score

        # With zero value estimates, gamma = 1 and gae_lambda = 1, the "returns"
        # are plain sums of future rewards up to episode end. Episodes that are
        # still running at the end of the buffer are cut off there, because
        # last_values are zero.
        if 'raw_rewards' in info['rollout']:
            # NOTE(review): presumably the rewards before the 0.01 scaling
            # applied in wrap_env; confirm against the PPO rollout logging.
            raw_rewards = info['rollout']['raw_rewards']
            _, gamma_1_returns = compute_gae_and_returns(
                value_estimates=np.zeros_like(rl.buffer.rewards[:len(raw_rewards)]),
                rewards=raw_rewards,
                episode_starts=rl.buffer.episode_starts[:len(raw_rewards)],
                last_values=np.zeros_like(rl.buffer.rewards[0], dtype=float),
                last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
                gamma=1.0,
                gae_lambda=1.0,
                normalize_rewards=None,
                normalize_advantages=None,
            )
        else:
            _, gamma_1_returns = rl.buffer.compute_gae_and_returns(
                last_values=torch.zeros_like(rl.buffer.value_estimates[0]),
                last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
                gamma=1.0,
                gae_lambda=1.0,
                normalize_advantages=None,
                normalize_rewards=None,
            )

        # Restrict both the score and the reset statistics to the filled part
        # of the buffer so they describe the same steps.
        filled_episode_starts = rl.buffer.episode_starts[:rl.buffer.pos]
        episode_scores = gamma_1_returns[filled_episode_starts]

        if len(episode_scores) > 0:
            # float() gives a plain Python float whether the returns are a
            # numpy array or a torch tensor.
            score = float(episode_scores.mean())

        rollout_time = rollout_stopwatch.reset()

        # Number of episode starts per environment within this rollout.
        resets = format_summary_statics(
            filled_episode_starts.astype(int).sum(axis=0),
            mean_format='.2f',
            std_format=None,
            min_value_format='1d',
            max_value_format=None,
        )
        print(f'{step:>6}: {score = :9.3f}, time = {rollout_time:5.2f}, {resets = :s}')

    policy_info_str = ('('
          f'policy_id = {policy_with_env_and_info["policy_id"]}, '
          f'parent_id = {policy_with_env_and_info["parent_policy_id"]}, '
          f'num_parameters = {count_parameters(policy)}, '
          f'previous_steps = {policy_with_env_and_info["policy_info"]["steps_trained"]}, '
          f'previous_score = {policy_with_env_and_info["policy_info"]["score"]:9.3f}'
          ')')

    print(f'Starting PPO with policy {policy_info_str:s} for {steps_per_iteration:_} steps')
    mitosis_iteration_stopwatch = Stopwatch()
    PPO(
        env=env,
        policy=policy.to(device),
        policy_optimizer=lambda pol: optim.Adam(pol.parameters(), lr=1e-5),
        buffer_size=2500,
        gamma=0.995,
        gae_lambda=1.0,
        normalize_rewards=None,
        normalize_advantages=NormalizationType.Std,
        weigh_actor_objective=lambda obj: 1.0 * obj,
        weigh_critic_objective=lambda obj: 0.5 * obj,
        ppo_max_epochs=10,
        ppo_kl_target=0.01,
        ppo_batch_size=500,
        action_ratio_clip_range=0.05,
        grad_norm_clip_value=2.0,
        callback=Callback(on_rollout_done=on_rollout_done),
        logging_config=PPOLoggingConfig(log_rollout_infos=True)
    ).train(steps_per_iteration)

    print(f'Training finished for policy {policy_info_str:s}, end score = {score:9.3f}, time = {mitosis_iteration_stopwatch.time_passed():6.2f}')

    return steps_per_iteration, score

def select_policy_selection_scores(policy_infos: Iterable[PolicyInfo]) -> np.ndarray:
    """Convert stored policy scores into selection scores for PolicyMitosis.

    Each policy's ``score`` is divided by the population standard deviation of
    all scores, which makes the result independent of the reward scale.

    Edge cases:
    - If the input is empty, an empty array is returned.
    - If the spread is zero (a single policy, or all scores equal) or not
      finite, the raw scores are returned unchanged. The unguarded division
      would produce inf/NaN here. Equal scores presumably still map to uniform
      selection downstream; NOTE(review): confirm in PolicyMitosis.

    Args:
        policy_infos: iterable of mappings with a numeric ``score`` entry.
            It is consumed exactly once.

    Returns:
        One selection score per policy, in input order.
    """
    scores = np.array([policy_info['score'] for policy_info in policy_infos])
    if scores.size == 0:
        return scores

    spread = scores.std()
    if spread == 0 or not np.isfinite(spread):
        return scores
    return scores / spread

# --- Configuration (several of these names are read as globals by the functions above) ---

# Torch device used for training; read by train_func.
use_cuda = True
device = set_default_torch_device('cuda:0' if use_cuda else 'cpu')

policy_action_std = 0.15
steps_per_iteration = 100_000

env_name = 'Humanoid-v4'
env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
num_envs = 32

num_mitosis_iterations = 1000

# Use get_current_timestamp() to start a fresh run; a fixed id resumes an existing run.
# mitosis_id = get_current_timestamp()
mitosis_id = '2024-05-03_19.31.53'
# NOTE(review): machine-specific absolute path; adjust for other machines.
models_base_path = 'E:/saved_models/rl'
policy_db = TinyModelDB[PolicyInfo](base_path=f'{models_base_path}/{env_name}/mitosis-{mitosis_id}')

# Nested try/finally so the model DB is closed even if env creation or
# env shutdown raises.
try:
    unwrapped_env = parallelize_env_async(lambda: gym.make(env_name, **env_kwargs), num_envs)
    try:
        print(f'Starting mitosis with id {mitosis_id}')
        PolicyMitosis(
            policy_db=policy_db,
            policy_train_function=train_func,
            env=unwrapped_env,
            new_init_policy_function=init_policy,
            new_wrap_env_function=wrap_env,
            select_policy_selection_scores=select_policy_selection_scores,
            policy_selection_temperature=1.1,
            min_base_ancestors=5,
            _globals=globals(),
            rng_seed=None,
        ).train(num_mitosis_iterations)
    except KeyboardInterrupt:
        print('keyboard interrupt')
    finally:
        print('closing envs')
        # NOTE(review): presumably gives the async env workers time to settle
        # before shutdown; confirm whether this delay is still needed.
        time.sleep(2.5)
        unwrapped_env.close()
        print('envs closed')
finally:
    policy_db.close()
    print('model db closed')


print('done')

Starting mitosis with id 2024-05-03_19.31.53
policy selection probs = 
	2024-05-03_19.31.56~YOi2Kr: p = 0.030967, score = 9.324106646821708, steps = 100000
	2024-05-03_19.41.26~ONGcgY: p = 0.023469, score = 9.019118288658124, steps = 100000
	2024-05-03_19.50.38~rpp1TN: p = 0.029827, score = 9.28283808761468, steps = 100000
	2024-05-03_19.59.45~iOUmuV: p = 0.009948, score = 8.074939831743997, steps = 100000
	2024-05-03_20.08.31~qVIc0m: p = 0.028977, score = 9.251017367165009, steps = 100000
	2024-05-03_20.23.19~jkSCqb: p = 0.054721, score = 9.950350163271155, steps = 200000
	2024-05-03_20.32.46~UETzfs: p = 0.102401, score = 10.639669612283704, steps = 300000
	2024-05-03_20.42.03~AAbsq9: p = 0.092334, score = 10.525839112620446, steps = 300000
	2024-05-03_20.51.29~B2PjVU: p = 0.021362, score = 8.915683225407209, steps = 100000
	2024-05-03_21.00.17~O81lZo: p = 0.046211, score = 9.764420465595782, steps = 200000
	2024-05-03_21.09.57~l028kP: p = 0.194829, score = 11.347217090925172, steps =