In [1]:

import time
from typing import Any

import gymnasium as gym
import numpy as np
import torch
import torch.distributions as dist
from gymnasium.vector import VectorEnv
from gymnasium.wrappers import RescaleAction, ClipAction
from torch import optim, nn

from src.datetime import get_current_timestamp
from src.model_db.tiny_model_db import TinyModelDB
from src.networks.core.seq_net import SeqNet
from src.networks.core.tensor_shape import TensorShape
from src.networks.core.torch_wrappers.torch_net import TorchNet
from src.reinforcement_learning.algorithms.ppo.ppo import PPO, PPOLoggingConfig
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_gae_and_returns
from src.reinforcement_learning.core.normalization import NormalizationType
from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy
from src.reinforcement_learning.core.policy_info import PolicyInfo
from src.reinforcement_learning.gym.envs.parallelize_env import parallelize_env_async
from src.reinforcement_learning.gym.envs.transform_reward_wrapper import TransformRewardWrapper
from src.torch_device import set_default_torch_device

%load_ext autoreload
%autoreload 2

In [None]:
from src.reinforcement_learning.algorithms.policy_mitosis.policy_mitosis import PolicyMitosis, PolicyWithEnvAndInfo


def init_policy():
    """Create a fresh, randomly initialised actor-critic policy.

    The network has separate actor and critic MLP towers; the optional shared
    trunk is disabled because ``shared_out_sizes`` is empty. Layer sizes are
    hard-coded: 376 inputs and 17 actor outputs — NOTE(review): these appear to
    match the Humanoid-v4 observation/action sizes configured further down;
    revisit if ``env_name`` changes.

    Actions are drawn from a Normal distribution centred on the actor output,
    with a fixed scale read from the module-level ``policy_action_std`` each
    time the distribution is built.

    :return: an ``ActorCriticPolicy`` wrapping the freshly built network.
    """

    def layer_provider_for(last_activation: nn.Module, hidden_activation: nn.Module):
        """Build a SeqNet layer provider producing ``Linear -> activation`` blocks.

        ``last_activation`` follows the final Linear layer and
        ``hidden_activation`` follows every other one.
        """

        def provide(layer_nr, is_last_layer, in_features, out_features):
            activation = last_activation if is_last_layer else hidden_activation
            return nn.Sequential(nn.Linear(in_features, out_features), activation)

        return provide

    class A2CNetwork(nn.Module):
        """Actor-critic MLP whose forward pass returns ``(actor_out, critic_out)``."""

        def __init__(self):
            super().__init__()

            in_size = 376
            shared_out_sizes = []
            actor_out_sizes = [512, 512, 256, 256, 256, 256, 256, 256, 17]
            critic_out_sizes = [512, 512, 256, 256, 256, 1]

            hidden_act = nn.ELU()
            actor_last_act = nn.Tanh()
            critic_last_act = nn.Identity()

            self.has_shared = bool(shared_out_sizes)

            if not self.has_shared:
                # Parameter-free placeholder so both towers can read the trunk's out_shape uniformly.
                self.shared = TorchNet(nn.Identity(), in_shape=TensorShape(features=in_size), out_shape=TensorShape(features=in_size))
            else:
                # The shared trunk uses the hidden activation on every layer, including its last.
                self.shared = SeqNet.from_layer_provider(
                    layer_provider=layer_provider_for(hidden_act, hidden_act),
                    in_size=in_size,
                    out_sizes=shared_out_sizes,
                )

            trunk_features = self.shared.out_shape.get_definite_features()

            # Build order (shared -> actor -> critic) is kept so random weight init is reproducible.
            self.actor = SeqNet.from_layer_provider(
                layer_provider=layer_provider_for(actor_last_act, hidden_act),
                in_size=trunk_features,
                out_sizes=actor_out_sizes,
            )
            self.critic = SeqNet.from_layer_provider(
                layer_provider=layer_provider_for(critic_last_act, hidden_act),
                in_size=trunk_features,
                out_sizes=critic_out_sizes,
            )

        def forward(self, x: torch.Tensor):
            features = self.shared(x) if self.has_shared else x
            return self.actor(features), self.critic(features)

    def action_distribution(action_logits):
        return dist.Normal(loc=action_logits, scale=policy_action_std)

    return ActorCriticPolicy(A2CNetwork(), action_distribution)

def wrap_env(_env: VectorEnv):
    """Apply the training-time wrappers to the (vectorised) environment.

    - Rewards are multiplied by 0.01 (``TransformRewardWrapper``).
    - The action space is rescaled to [-1, 1] (``RescaleAction``), matching the
      range of the actor's Tanh output layer.
    - Actions are clipped to that [-1, 1] range (``ClipAction``, outermost),
      because the Normal action distribution can sample outside it.

    :return: the wrapped environment.
    """
    reward_scaled_env = TransformRewardWrapper(_env, lambda _reward: 0.01 * _reward)
    rescaled_env = RescaleAction(reward_scaled_env, min_action=-1.0, max_action=1.0)
    return ClipAction(rescaled_env)

def train_func(policy_with_env_and_info: PolicyWithEnvAndInfo) -> tuple[int, float]:
    """Train one policy with PPO for ``steps_per_iteration`` steps and score it.

    Used as ``PolicyMitosis``' ``policy_train_function``. Reads the module-level
    globals ``device`` and ``steps_per_iteration``.

    :param policy_with_env_and_info: mapping providing at least ``'policy'``,
        ``'env'``, ``'policy_id'`` and ``'parent_policy_id'``.
    :return: ``(steps_trained, score)``. ``score`` is a Python ``float``: the
        mean undiscounted (gamma = 1) return of the episodes that started in
        the most recent rollout buffer. Episodes still running at the end of
        the buffer contribute their partial return, because the bootstrap
        value is 0 — NOTE(review): this biases the score low. If no rollout
        ever contains an episode start, the score stays 0.0.
    """
    policy = policy_with_env_and_info['policy']
    env = policy_with_env_and_info['env']

    score = 0.0

    def compute_gamma_1_returns(rl: PPO, info: dict[str, Any]):
        """Per-step undiscounted (gamma=1, lambda=1) returns with no bootstrapping.

        Uses ``info['rollout']['raw_rewards']`` when present. These are
        presumably the rewards before env reward wrappers such as the 0.01
        scale in ``wrap_env`` — confirm. Otherwise it falls back to the
        rewards stored in the buffer.
        """
        if 'raw_rewards' in info['rollout']:
            raw_rewards = info['rollout']['raw_rewards']
            num_steps = len(raw_rewards)
            _, returns = compute_gae_and_returns(
                value_estimates=np.zeros_like(rl.buffer.rewards[:num_steps]),
                rewards=raw_rewards,
                episode_starts=rl.buffer.episode_starts[:num_steps],
                last_values=np.zeros_like(rl.buffer.rewards[0], dtype=float),
                last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
                gamma=1.0,
                gae_lambda=1.0,
                normalize_rewards=None,
                normalize_advantages=None,
            )
            return returns

        _, returns = rl.buffer.compute_gae_and_returns(
            last_values=torch.zeros_like(rl.buffer.value_estimates[0]),
            last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
            gamma=1.0,
            gae_lambda=1.0,
            normalize_advantages=None,
            normalize_rewards=None,
        )
        return returns

    def on_rollout_done(rl: PPO, step: int, info: dict[str, Any]):
        nonlocal score

        gamma_1_returns = compute_gamma_1_returns(rl, info)
        # With gamma = 1, the return at an episode's first step is that episode's
        # summed reward, up to the end of the buffer.
        episode_scores = gamma_1_returns[
            rl.buffer.episode_starts[:rl.buffer.pos]
        ]

        if len(episode_scores) == 0:
            # No episode started in this rollout. .mean() would give NaN (plus a
            # RuntimeWarning), and a NaN score would break score comparisons.
            # Keep the previous score instead.
            print(f'{step:>6}: no episode started in this rollout, keeping {score = :9.4f}')
            return

        # float(): the buffer branch may yield a torch tensor. The caller is
        # promised a plain float.
        score = float(episode_scores.mean())

        print(f'{step:>6}: {score = :9.4f}')

    print(f'Starting PPO with policy (policy_id = {policy_with_env_and_info["policy_id"]}, parent_id = {policy_with_env_and_info["parent_policy_id"]}) for {steps_per_iteration:_} steps')
    PPO(
        env=env,
        policy=policy.to(device),
        policy_optimizer=lambda pol: optim.Adam(pol.parameters(), lr=1e-5),
        buffer_size=2500,
        gamma=0.995,
        gae_lambda=1.0,
        normalize_rewards=None,
        normalize_advantages=NormalizationType.Std,
        weigh_actor_objective=lambda obj: 1.0 * obj,
        weigh_critic_objective=lambda obj: 0.5 * obj,
        ppo_max_epochs=10,
        ppo_kl_target=0.01,
        ppo_batch_size=500,
        action_ratio_clip_range=0.05,
        grad_norm_clip_value=2.0,
        callback=Callback(on_rollout_done=on_rollout_done),
        logging_config=PPOLoggingConfig(log_rollout_infos=True)
    ).train(steps_per_iteration)

    return steps_per_iteration, score


# --- Run configuration ---------------------------------------------------------
# init_policy reads `policy_action_std`, and train_func reads `device` and
# `steps_per_iteration`, as module globals, so these names must stay as they are.
device = set_default_torch_device('cuda:0' if torch.cuda.is_available() else 'cpu')

policy_action_std = 0.15
steps_per_iteration = 100_000
num_mitosis_iterations = 1000

env_name = 'Humanoid-v4'
env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
num_envs = 4

# --- Training ------------------------------------------------------------------
policy_db = TinyModelDB[PolicyInfo](base_path=f'saved_models/rl/{env_name}/mitosis-{get_current_timestamp()}')

# Nested try/finally blocks so the model DB is closed even if creating or
# closing the envs fails.
try:
    unwrapped_env = parallelize_env_async(lambda: gym.make(env_name, **env_kwargs), num_envs)
    try:
        PolicyMitosis(
            policy_db=policy_db,
            policy_train_function=train_func,
            env=unwrapped_env,
            new_init_policy_function=init_policy,
            new_wrap_env_function=wrap_env,
            _globals=globals(),  # NOTE(review): hands the notebook namespace to PolicyMitosis; usage not visible here
            rng_seed=42,
        ).train(num_mitosis_iterations)
    except KeyboardInterrupt:
        print('keyboard interrupt')
    finally:
        print('closing envs')
        # Presumably gives the async env workers time to settle (e.g. after an
        # interrupt) before closing — NOTE(review): confirm it is still needed.
        time.sleep(2.5)
        unwrapped_env.close()
        print('envs closed')
finally:
    policy_db.close()
    print('model db closed')


print('done')

Starting PPO with policy (policy_id = 2024-05-03_13.55.08, parent_id = None) for 100_000 steps
  2500: score =   21.6012
  5000: score =   21.1634
  7500: score =   23.8943
 10000: score =   31.0903
 12500: score =   39.0746
 15000: score =   57.1511
 17500: score =   50.3929
 20000: score =   65.1771
 22500: score =   67.2442
 25000: score =   70.2096
 27500: score =   68.8788
 30000: score =   71.6042
 32500: score =   70.7416
 35000: score =   68.0738
 37500: score =   68.5171
 40000: score =   69.4280
 42500: score =   71.8161
 45000: score =   71.5567
 47500: score =   71.1147
 50000: score =   73.0481
 52500: score =   74.0450
 55000: score =   75.7549
 57500: score =   75.9361
 60000: score =   75.9568
 62500: score =   77.5204
 65000: score =   75.2102
 67500: score =   74.4111
 70000: score =   76.0839
