In [1]:
import sys
from typing import Any, Iterable

import numpy as np

import gymnasium as gym

from notebooks.policy_mitosis import init_action_selector, init_policy, init_optimizer, wrap_env
from src.datetime import get_current_timestamp
from src.model_db.tiny_model_db import TinyModelDB
from src.module_analysis import count_parameters
from src.moving_averages import ExponentialMovingAverage
from src.np_functions import softmax
from src.reinforcement_learning.algorithms.policy_mitosis.async_policy_mitosis import AsyncPolicyMitosis
from src.reinforcement_learning.algorithms.policy_mitosis.mitosis_policy_info import MitosisPolicyInfo
from src.reinforcement_learning.algorithms.policy_mitosis.policy_mitosis_base import TrainResultInfo, TrainInfo
from src.reinforcement_learning.algorithms.ppo.ppo import PPOLoggingConfig, PPO
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.generalized_advantage_estimate import compute_episode_returns
from src.reinforcement_learning.core.normalization import NormalizationType
from src.reinforcement_learning.core.policy_construction import PolicyConstruction
from src.reinforcement_learning.core.policy_evaluation import evaluate_policy
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
from src.stopwatch import Stopwatch
from src.torch_device import get_torch_device
from src.torch_functions import antisymmetric_power
from src.trees import Forest
%load_ext autoreload
%autoreload 2

In [None]:

def train_func(train_info: TrainInfo) -> TrainResultInfo:
    """Train one mitosis policy with PPO for a single iteration, then evaluate it.

    ``train_info`` supplies the ``policy``, its ``optimizer``, the ``env`` to train in and
    the ``policy_info`` record used for logging. Training runs for the module-level
    ``steps_per_iteration`` steps on the module-level ``device``; afterwards the policy is
    evaluated for 10_000 steps in the same ``env``.

    Progress is printed after every rollout (score, moving-average score, rollout time and
    per-environment reset counts).

    Returns:
        A dict with ``steps_trained`` (``steps_per_iteration``), ``optimizations_done``
        (number of ``on_optimization_done`` callbacks received), ``score`` (mean
        evaluation score) and ``extra_infos['score_ema']`` (moving average of the rollout
        scores).
    """
    policy_info = train_info['policy_info']
    env = train_info['env']
    optimizer = train_info['optimizer']
    policy = train_info['policy']

    # Mutable state shared with the PPO callbacks below.
    score = 0.0
    optimizations_done = 0
    score_ema = ExponentialMovingAverage(0.45)
    rollout_stopwatch = Stopwatch()

    def on_rollout_done(rl: PPO, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
        nonlocal score

        # Prefer the raw rewards logged with the rollout over whatever the buffer holds.
        rollout_info = info['rollout']
        if 'raw_rewards' in rollout_info:
            rewards = rollout_info['raw_rewards']
        else:
            rewards = rl.buffer.rewards

        score = compute_episode_returns(
            rewards=rewards,
            episode_starts=rl.buffer.episode_starts,
            last_episode_starts=info['last_episode_starts'],
        ).mean()
        current_score_ema = score_ema.update(score)

        rollout_time = rollout_stopwatch.reset()

        # Episode starts summed over axis 0 give one reset count per remaining index.
        resets: np.ndarray = rl.buffer.episode_starts.astype(int).sum(axis=0)
        resets_mean, resets_min = resets.mean(), resets.min()

        progress_line = (
            f'{policy_info["policy_id"]}  {step:>6}: '
            f'{score = :9.3f}, '
            f'score_ema = {current_score_ema or score_ema.get():9.3f}, '
            f'time = {rollout_time:5.2f}, '
            f'resets = {resets_mean:5.2f} >= {resets_min:5.2f}'
        )
        print(progress_line, flush=True)

    def on_optimization_done(rl: PPO, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
        nonlocal optimizations_done
        optimizations_done += 1

    policy_info_str = ('('
          f'policy_id = {policy_info["policy_id"]}, '
          f'parent_id = {policy_info["parent_policy_id"]}, '
          f'num_parameters = {count_parameters(policy)}, '
          f'previous_steps = {policy_info["steps_trained"]}, '
          f'previous_score = {policy_info["score"]:9.3f}'
          ')')

    print(f'Starting PPO with policy {policy_info_str:s} for {steps_per_iteration:_} steps')
    mitosis_iteration_stopwatch = Stopwatch()

    ppo = PPO(
        env=env,
        policy=policy,
        policy_optimizer=optimizer,
        buffer_size=5000,
        gamma=0.995,
        gae_lambda=1.0,
        normalize_rewards=None,
        normalize_advantages=NormalizationType.Std,
        weigh_and_reduce_actor_objective=lambda obj: antisymmetric_power(obj, 1.5).mean(),
        weigh_and_reduce_entropy_objective=None,  # lambda obj: 1.0 * obj.mean(),
        weigh_and_reduce_critic_objective=lambda obj: 0.5 * obj.mean(),
        ppo_max_epochs=10,
        ppo_kl_target=0.025,
        ppo_batch_size=500,
        action_ratio_clip_range=0.1,
        grad_norm_clip_value=1.0,
        callback=Callback(
            on_rollout_done=on_rollout_done,
            on_optimization_done=on_optimization_done,
        ),
        logging_config=PPOLoggingConfig(log_rollout_infos=True, log_last_obs=True),
        torch_device=device,
    )
    ppo.train(steps_per_iteration)

    # The reported score is the mean over a fresh evaluation run, not the training rollouts.
    eval_score = evaluate_policy(env=env, policy=policy, num_steps=10_000).mean()

    print(f'Training finished for policy {policy_info_str:s}, '
          f'evaluation_score = {eval_score:9.3f}, '
          f'moving average score = {score_ema.get():9.3f}, '
          f'time = {mitosis_iteration_stopwatch.time_passed():6.2f}')

    return {
        'steps_trained': steps_per_iteration,
        'optimizations_done': optimizations_done,
        'score': eval_score,
        'extra_infos': {'score_ema': score_ema.get()},
    }

def select_policy_selection_probs(policy_infos: Iterable[MitosisPolicyInfo]) -> np.ndarray:
    """Compute the probability of picking each policy as the next one to train.

    Three per-policy factors are combined as a weighted geometric product and then
    normalised to sum to 1:

    * score: softmax over the ``score`` entries, with temperature
      ``0.5 / ln(n + 1)`` for ``n`` policies (weight 1.0);
    * descendants: softmax over the negated discounted descendant count reported by the
      policy forest (weight 0.5);
    * steps: softmax over the negated ``steps_trained`` entries (weight 0.1).

    A table with every factor is printed for inspection.

    Args:
        policy_infos: The policy records to choose between. Consumed once, so any
            iterable (including a generator) is accepted.

    Returns:
        A float array with one probability per policy, in input order. An empty input
        yields an empty array. If the combined product has a zero or non-finite sum, so
        that it cannot be normalised, a uniform distribution is returned instead of NaNs.
    """
    # TODO introduce score change momentum factor, average child score
    policy_infos = list(policy_infos)
    if not policy_infos:
        # Nothing to choose from. Returning here also avoids the score temperature
        # below, which would divide by log(0 + 1) == 0 for an empty list.
        return np.zeros(0, dtype=float)

    policy_info_forest = Forest(
        policy_infos,
        get_id=lambda pi: pi['policy_id'],
        get_parent_id=lambda pi: pi['parent_policy_id']
    )

    scores = np.array([policy_info['score'] for policy_info in policy_infos], dtype=float)
    # The temperature shrinks as the number of policies grows.
    score_probs = softmax(scores, temperature=0.5 / np.log(len(scores) + 1), normalize=True)

    num_descendants = np.array([
        policy_info_forest.compute_num_descendants(policy_info['policy_id'], discount_factor=0.5)
        for policy_info in policy_infos
    ], dtype=float)
    num_descendants_probs = softmax(-num_descendants, temperature=0.5)

    steps_trained = np.array([policy_info['steps_trained'] for policy_info in policy_infos], dtype=float)
    steps_trained_probs = softmax(-steps_trained, temperature=0.1, normalize=True)

    score_weight = 1.0
    num_descendants_weight = 0.5
    steps_trained_weight = 0.1

    probs = (
        score_probs**score_weight *
        num_descendants_probs**num_descendants_weight *
        steps_trained_probs**steps_trained_weight
    )

    total = probs.sum()
    if np.isfinite(total) and total > 0.0:
        probs /= total
    else:
        # The product underflowed to zero everywhere or contains NaN/inf. Dividing by
        # the sum would hand NaN probabilities to the caller, so fall back to uniform.
        print(f'policy selection probs are degenerate (sum = {total}), falling back to uniform probabilities')
        probs = np.full(len(policy_infos), 1.0 / len(policy_infos))

    print('policy selection probs = \n\t' + '\n\t'.join(
        f'{(policy_id := policy_infos[i]["policy_id"])}: {p = :8.6f}, '
        f'score = {policy_infos[i]["score"]:7.3f}, '
        f'score_prob = {score_probs[i]**score_weight:7.5f}, '
        f'num_children = {len(policy_info_forest[policy_id].children)}, '
        f'num_descendants = {num_descendants[i]:7.3f}, '
        f'descendants_prob = {num_descendants_probs[i]**num_descendants_weight:7.5f}, '
        f'steps = {policy_infos[i]["steps_trained"]}, '
        f'steps_prob = {steps_trained_probs[i]**steps_trained_weight:7.5f}, '
        for i, p
        in enumerate(probs)
    ))

    return probs

# --- Run configuration -------------------------------------------------------------
# These module-level names are read by train_func and by the driver code below.

# Manual device toggle: the condition is the literal True, so the "cuda:0" branch is
# always taken. Change it to False to run on the CPU instead.
device = get_torch_device("cuda:0") if True else get_torch_device('cpu')

# Number of steps train_func passes to PPO.train and reports back as 'steps_trained'.
steps_per_iteration = 50_000

# Environment passed to gym.make; the commented pair is an alternative configuration.
# env_name = 'Humanoid-v4'
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
env_name = 'Ant-v4'
env_kwargs = {'healthy_reward': 0.001, 'ctrl_cost_weight': 0.05 }
# Forwarded to parallelize_env_async together with the environment factory.
num_envs = 16

# Identifier of this run; it becomes part of the model DB directory name.
# The commented line pins a fixed id instead (presumably to continue an earlier run -- confirm).
mitosis_id = get_current_timestamp()
# mitosis_id = '2024-06-10_19.43.13'
# NOTE(review): the base path is an absolute, user-specific location; the notebook will
# not run unchanged on another machine.
# policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'E:/saved_models/rl/{env_name}/mitosis-{mitosis_id}')
policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'C:/Users/domin/git/pytorch-starter/saved_models/rl/{env_name}/mitosis-{mitosis_id}')

# --- Run policy mitosis --------------------------------------------------------------
# The try/finally guarantees the model DB is closed however training ends; a
# KeyboardInterrupt (e.g. interrupting the kernel) is reported instead of propagating.
try:
    print(f'Starting mitosis with id {mitosis_id}')
    AsyncPolicyMitosis(
        num_workers=3,
        policy_db=policy_db,
        train_policy_function=train_func,
        # Environment factory: hands a gym.make factory and num_envs to parallelize_env_async.
        create_env=lambda: parallelize_env_async(lambda: gym.make(env_name, render_mode=None, **env_kwargs), num_envs),
        # The init_* / wrap_env callables come from notebooks.policy_mitosis.
        new_policy_initialization_info=PolicyConstruction.create_policy_initialization_info(
            init_action_selector=init_action_selector,
            init_policy=init_policy,
            init_optimizer=init_optimizer,
            wrap_env=wrap_env,
        ),
        # Constant 0.0 regardless of how many policies / primordial ancestors exist.
        new_policy_prob_function=lambda nr_policies, nr_primordial_ancestors: 0.0,
        modify_policy=None,
        select_policy_selection_probs=select_policy_selection_probs,
        min_primordial_ancestors=1,
        rng_seed=None,  # NOTE(review): no fixed seed is passed here, so runs are presumably not reproducible -- confirm
        initialization_delay=5,
        delay_between_workers=5,
        save_optimizer_state_dicts=True,
        load_optimizer_state_dicts=True,
    ).train_with_mitosis(1000)  # presumably the number of mitosis iterations -- confirm
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    policy_db.close()
    print('model db closed')


print('done')