In [1]:
import sys
import time
from typing import Any, Iterable

import gymnasium as gym
import numpy as np
import torch
from torch import optim

from src.datetime import get_current_timestamp
from src.model_db.tiny_model_db import TinyModelDB
from src.module_analysis import count_parameters
from src.moving_averages import ExponentialMovingAverage
from src.np_functions import softmax
from src.reinforcement_learning.algorithms.policy_mitosis.async_policy_mitosis import AsyncPolicyMitosis
from src.reinforcement_learning.algorithms.policy_mitosis.mitosis_policy_info import MitosisPolicyInfo
from src.reinforcement_learning.algorithms.policy_mitosis.policy_mitosis_base import PolicyWithEnvAndInfo
from src.reinforcement_learning.algorithms.ppo.ppo import PPOLoggingConfig, PPO
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_gae_and_returns
from src.reinforcement_learning.core.normalization import NormalizationType
from src.reinforcement_learning.core.objectives import ObjectiveLoggingConfig
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
from src.stopwatch import Stopwatch
from src.torch_device import get_torch_device
from src.torch_functions import antisymmetric_power

%load_ext autoreload
%autoreload 2

pygame 2.5.2 (SDL 2.28.3, Python 3.11.7)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [None]:



# Number of carts passed to MultiAgentCartPole3D by the env factory below;
# it is also interpolated into the model-db base path further down.
nr_carts = 2

def make_multi_agent_cart_pole_env(render_mode: str | None = None):
    """Build one MultiAgentCartPole3D environment with this notebook's fixed settings.

    The environment class is imported inside the function instead of at module
    level, and the number of carts is read from the module-level ``nr_carts``
    when the function is called.

    Args:
        render_mode: Forwarded unchanged to the environment constructor.

    Returns:
        The newly constructed MultiAgentCartPole3D instance.
    """
    from src.reinforcement_learning.gym.envs.multi_agent_cartpole3d import MultiAgentCartPole3D

    def reward_for_step(time_, action, state, prev_state):
        # Passed as step_reward_function; always 0.01.
        return 0.01

    def reward_for_out_of_range(time_, action, state):
        # Passed as out_ouf_range_reward_function; always 0.0.
        # Commented-out alternative kept from the original: -10 + time_ * 3
        return 0.0

    def reward_for_time_limit(time_, action, state):
        # Passed as time_limit_reward_function; always 10.
        return 10

    env_settings = dict(
        nr_carts=nr_carts,
        cart_size=0.25,
        force_magnitude=500,
        physics_steps_per_step=25,
        reset_position_radius=0.75,
        reset_randomize_position_angle_offset=True,
        reset_position_randomization_magnitude=0.3,
        reset_hinge_randomization_magnitude=0.05,
        slide_range=2,
        hinge_range=0.8,
        time_limit=120.0,
        step_reward_function=reward_for_step,
        # The keyword below is spelled exactly as the environment expects it.
        out_ouf_range_reward_function=reward_for_out_of_range,
        time_limit_reward_function=reward_for_time_limit,
        render_mode=render_mode,
    )
    return MultiAgentCartPole3D(**env_settings)


def init_policy():
    """Create a freshly initialised actor-critic policy.

    Every name used in the body is imported locally, so the function does not
    depend on the notebook's import cell.

    Returns:
        An ActorCriticPolicy built from a new ``A2CNetwork`` and a
        SquashedDiagGaussianActionSelector with a fixed (non-learnable) std of 0.15.
    """
    import numpy as np
    import torch
    from torch import nn

    from src.networks.core.net import Net
    from src.networks.core.seq_net import SeqNet
    from src.reinforcement_learning.core.action_selectors.squashed_diag_gaussian_action_selector import (
        SquashedDiagGaussianActionSelector,
    )
    from src.reinforcement_learning.core.policies.actor_critic_policy import ActorCriticPolicy
    from src.networks.skip_nets.additive_skip_net import (
        AdditiveSkipNet,
        FullyConnectedAdditiveSkipNet,
        FullyConnectedUnweightedAdditiveSkipNet,
    )
    from src.networks.skip_nets.additive_skip_connection import AdditiveSkipConnection
    from src.weight_initialization import orthogonal_initialization
    from src.networks.multihead_self_attention import MultiheadSelfAttention

    # Sizes of the per-actor input vector and of the action vector.
    input_features = 8
    action_size = 2

    # Depth (number of stacked blocks) and width (feature count) of each tower.
    actor_depth, actor_width = 3, 48
    critic_depth, critic_width = 2, 48

    actor_activation = nn.ELU
    critic_activation = nn.ELU

    def init_hidden(module):
        # Hidden linear layers of both towers go through orthogonal_initialization
        # with gain sqrt(2); its return value is what gets placed in the network.
        return orthogonal_initialization(module, gain=np.sqrt(2))

    def skip_attention_block(in_features, out_features, activation, final_activation):
        # One tower layer: skip-wrapped self-attention, LayerNorm, a skip-wrapped
        # two-linear net, LayerNorm. Sub-modules are created in exactly this order.
        # NOTE(review): both linear layers map in_features -> out_features, which
        # only lines up when the two sizes are equal — confirm SeqNet passes equal sizes.
        return nn.Sequential(
            AdditiveSkipConnection(MultiheadSelfAttention(
                embed_dim=in_features,
                num_heads=4,
                batch_first=True,
            )),
            nn.LayerNorm(in_features),
            AdditiveSkipConnection(Net.sequential_net(
                init_hidden(nn.Linear(in_features, out_features)),
                activation(),
                init_hidden(nn.Linear(in_features, out_features)),
                final_activation(),
            )),
            nn.LayerNorm(in_features),
        )

    class A2CNetwork(nn.Module):
        """Separate actor and critic towers fed from the same input tensor."""

        def __init__(self):
            super().__init__()

            def actor_layer(layer_nr, is_last_layer, in_features, out_features):
                # The last actor layer ends in Tanh instead of the hidden activation.
                return skip_attention_block(
                    in_features,
                    out_features,
                    actor_activation,
                    nn.Tanh if is_last_layer else actor_activation,
                )

            def critic_layer(layer_nr, is_last_layer, in_features, out_features):
                return skip_attention_block(
                    in_features,
                    out_features,
                    critic_activation,
                    critic_activation,
                )

            self.actor_embedding = nn.Sequential(
                nn.Linear(input_features, actor_width),
                actor_activation(),
            )
            self.actor = SeqNet.from_layer_provider(
                layer_provider=actor_layer,
                num_layers=actor_depth,
                num_features=actor_width,
            )

            self.critic_embedding = nn.Sequential(
                nn.Linear(input_features, critic_width),
                critic_activation(),
            )
            self.critic = SeqNet.from_layer_provider(
                layer_provider=critic_layer,
                num_layers=critic_depth,
                num_features=critic_width,
            )
            self.critic_regressor = nn.Linear(critic_width, 1)

        def forward(self, x: torch.Tensor):
            # Everything in front of the last two axes is treated as batch shape:
            # it is collapsed into one axis for the towers and restored afterwards.
            *leading_dims, _nr_actors, _nr_features = x.shape
            flat_x = torch.flatten(x, end_dim=-3)

            actor_latent = self.actor(self.actor_embedding(flat_x))
            # The critic features are summed over the second-to-last axis before the
            # final linear layer (presumably the actor axis — TODO confirm).
            critic_latent = self.critic(self.critic_embedding(flat_x))
            value = self.critic_regressor(critic_latent.sum(dim=-2))

            actor_latent = actor_latent.unflatten(dim=0, sizes=leading_dims)
            value = value.unflatten(dim=0, sizes=leading_dims)

            return actor_latent, value

    network = A2CNetwork()
    action_selector = SquashedDiagGaussianActionSelector(
        latent_dim=actor_width,
        action_dim=action_size,
        std=0.15,
        std_learnable=False,
        action_net_initialization=lambda module: orthogonal_initialization(module, gain=0.01),
    )
    return ActorCriticPolicy(network, action_selector)

def wrap_env(env_):
    """Return the given environment as-is; no wrapper is applied."""
    passthrough_env = env_
    return passthrough_env

def train_func(policy_with_env_and_info: PolicyWithEnvAndInfo) -> tuple[int, float]:
    """Train one policy with PPO and report how many steps were run and its score.

    Reads the module-level ``device`` and ``steps_per_iteration``.

    Args:
        policy_with_env_and_info: Mapping from which the 'policy', 'env' and
            'policy_info' entries are read.

    Returns:
        ``(steps_per_iteration, score_ema.get())`` — the number of steps passed
        to ``PPO.train`` and the current value of the score moving average.
    """
    policy = policy_with_env_and_info['policy']
    env = policy_with_env_and_info['env']

    score = 0.0
    score_ema = ExponentialMovingAverage(0.45)
    # Last value returned by score_ema.update(); stays NaN until the first update.
    current_score_ema = float('nan')
    rollout_stopwatch = Stopwatch()

    def on_rollout_done(rl: PPO, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
        """Score the finished rollout and print a one-line progress report."""
        # Returns are computed with gamma = 1, lambda = 1 and zero value estimates.
        # If the rollout info carries 'raw_rewards', those are used instead of the
        # rewards stored in the buffer.
        if 'raw_rewards' in info['rollout']:
            raw_rewards = info['rollout']['raw_rewards']
            _, gamma_1_returns = compute_gae_and_returns(
                value_estimates=np.zeros_like(rl.buffer.rewards[:len(raw_rewards)]),
                rewards=raw_rewards,
                episode_starts=rl.buffer.episode_starts[:len(raw_rewards)],
                last_values=np.zeros_like(rl.buffer.rewards[0], dtype=float),
                last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
                gamma=1.0,
                gae_lambda=1.0,
                normalize_rewards=None,
                normalize_advantages=None,
            )
        else:
            _, gamma_1_returns = rl.buffer.compute_gae_and_returns(
                last_values=torch.zeros_like(rl.buffer.value_estimates[0]),
                last_dones=np.zeros_like(rl.buffer.episode_starts[0], dtype=bool),
                gamma=1.0,
                gae_lambda=1.0,
                normalize_advantages=None,
                normalize_rewards=None,
            )

        # Keep only the returns at positions flagged as episode starts.
        # NOTE(review): this mask is sliced with rl.buffer.pos while the raw-reward
        # branch slices with len(raw_rewards); this assumes both lengths agree — confirm.
        episode_scores = gamma_1_returns[
            rl.buffer.episode_starts[:rl.buffer.pos]
        ]

        nonlocal score, current_score_ema
        # A rollout without any episode start yields an empty selection whose mean
        # is NaN. Feeding that NaN into the moving average would taint the score
        # returned for this policy, so such rollouts leave score and EMA untouched.
        if len(episode_scores) > 0:
            score = episode_scores.mean()
            current_score_ema = score_ema.update(score)

        rollout_time = rollout_stopwatch.reset()

        # Episode-start flags summed along axis 0 of the whole buffer
        # (presumably one count per parallel environment — TODO confirm).
        resets: np.ndarray = rl.buffer.episode_starts.astype(int).sum(axis=0)
        resets_mean = resets.mean()
        resets_min = resets.min()
        print(f'{step:>6}: '
              f'{score = :9.3f}, '
              f'score_ema = {current_score_ema:9.3f}, '
              f'time = {rollout_time:5.2f}, '
              f'resets = {resets_mean:5.2f} >= {resets_min:5.2f}')
        sys.stdout.flush()

    policy_info = policy_with_env_and_info['policy_info']
    policy_info_str = ('('
          f'policy_id = {policy_info["policy_id"]}, '
          f'parent_id = {policy_info["parent_policy_id"]}, '
          f'num_parameters = {count_parameters(policy)}, '
          f'previous_steps = {policy_info["steps_trained"]}, '
          f'previous_score = {policy_info["score"]:9.3f}'
          ')')

    print(f'Starting PPO with policy {policy_info_str:s} for {steps_per_iteration:_} steps')
    mitosis_iteration_stopwatch = Stopwatch()
    PPO(
        env=env,
        policy=policy.to(device),
        policy_optimizer=lambda pol: optim.AdamW(pol.parameters(), lr=1e-5),
        buffer_size=2500,
        gamma=0.995,
        gae_lambda=1.0,
        normalize_rewards=None,
        normalize_advantages=NormalizationType.Std,
        reduce_actor_objective=lambda obj: antisymmetric_power(obj, 1.5).mean(),
        weigh_actor_objective=lambda obj: 1.0 * obj,
        weigh_entropy_objective=lambda obj: 1.0 * obj.exp(),
        weigh_critic_objective=lambda obj: 0.5 * obj,
        ppo_max_epochs=10,
        ppo_kl_target=0.025,
        ppo_batch_size=500,
        action_ratio_clip_range=0.1,
        grad_norm_clip_value=0.5,
        # sde_noise_sample_freq=50,
        callback=Callback(
            on_rollout_done=on_rollout_done,
            rollout_schedulers={},
            optimization_schedulers={},
        ),
        logging_config=PPOLoggingConfig(
            log_returns=True,
            log_advantages=True,
            log_grad_norm=True,
            log_rollout_infos=True,
            log_rollout_action_stds=True,
            log_actor_kl_divergence=True,
            actor_objective=ObjectiveLoggingConfig(log_raw=True),
            entropy_objective=ObjectiveLoggingConfig(log_weighted=True),
            critic_objective=ObjectiveLoggingConfig(log_weighted=True),
        ),
        torch_device=device,
    ).train(steps_per_iteration)

    print(f'Training finished for policy {policy_info_str:s}, end score = {score:9.3f}, time = {mitosis_iteration_stopwatch.time_passed():6.2f}')

    return steps_per_iteration, score_ema.get()

def select_policy_selection_probs(policy_infos: Iterable[MitosisPolicyInfo]) -> np.ndarray:
    """Turn the stored policy scores into selection weights via a softmax.

    The scores are rescaled by the magnitude of their mean and passed to
    ``softmax`` with temperature ``0.5 / sqrt(number of policies)``.

    Args:
        policy_infos: Policy infos; only the 'score' entry of each is read.

    Returns:
        The array returned by ``softmax``, one entry per policy info, or an
        empty array when no policy infos were given.
    """
    scores = np.array([policy_info['score'] for policy_info in policy_infos])

    # Without any policies there is nothing to weigh; this also avoids dividing
    # by sqrt(0) in the temperature below.
    if scores.size == 0:
        return scores

    # Rescale by the absolute mean: dividing by a negative mean would flip the
    # ranking, and a zero or non-finite mean would produce NaN/inf entries.
    # For a positive mean this is the same division as before.
    scale = np.abs(scores.mean())
    if np.isfinite(scale) and scale > 0:
        scores = scores / scale

    return softmax(scores, temperature=0.5 / len(scores) ** 0.5)

# Training device. The constant `True` acts as a manual switch: change it to
# `False` to use the CPU instead. `device` is read as a global by train_func.
device = get_torch_device("cuda:0") if True else get_torch_device('cpu')
print(f'using device {device}')

# Number of steps passed to PPO.train() on every train_func call (read there as a global).
steps_per_iteration = 100_000

# Number of environment copies requested from parallelize_env_async in create_env below.
num_envs = 32

# Identifier of this mitosis run; it becomes part of the model-db directory name.
# The commented line would take the id from get_current_timestamp() instead of the fixed one.
# mitosis_id = get_current_timestamp()
mitosis_id = '2024-05-27_19.00.00'
# NOTE(review): absolute, machine-specific path; the commented line below is an alternative location.
policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'E:/saved_models/rl/MultiAgentCartPole/{nr_carts}/mitosis-{mitosis_id}')
# policy_db = TinyModelDB[PolicyInfo](base_path=f'C:/Users/domin/git/pytorch-starter/saved_models/rl/{env_name}/mitosis-{mitosis_id}')

try:
    print(f'Starting mitosis with id {mitosis_id}')
    # The AsyncPolicyMitosis object is used only for this one call and is not kept in a variable.
    AsyncPolicyMitosis(
        num_workers=3,
        policy_db=policy_db,
        train_policy_function=train_func,
        # Each call builds `num_envs` cart-pole environments with render_mode=None.
        create_env=lambda: parallelize_env_async(lambda: make_multi_agent_cart_pole_env(None), num_envs),
        new_init_policy_function=init_policy,
        new_wrap_env_function=wrap_env,
        select_policy_selection_probs=select_policy_selection_probs,
        min_base_ancestors=5,
        rng_seed=None,
        initialization_delay=15,
        delay_between_workers=10,
    ).train_with_mitosis(1000)
except KeyboardInterrupt:
    # Interrupting the kernel is an expected way to stop; cleanup in `finally` still runs.
    print('keyboard interrupt')
finally:
    # NOTE(review): nothing here closes the environments explicitly; the sleep
    # presumably gives the workers time to shut them down — confirm.
    print('closing envs')
    time.sleep(2.5)
    print('envs closed')
    # The model db is closed whether training finished, was interrupted, or raised.
    policy_db.close()
    print('model db closed')
    

print('done')

using device cuda:0
Starting mitosis with id 2024-05-27_19.00.00
Starting worker 0 with delay = 0
Started training iteration for policy: 2024-05-27_19.08.26~zC9AVO, parent policy id: None
Starting worker 1 with delay = 15
Started training iteration for policy: 2024-05-27_19.08.41~aOPPI5, parent policy id: None
Starting worker 2 with delay = 30
Started training iteration for policy: 2024-05-27_19.08.56~xYsZyL, parent policy id: None
