In [1]:
from src.reinforcement_learning.core.action_selectors.predicted_std_action_selector import PredictedStdActionSelector
from stable_baselines3.common.torch_layers import FlattenExtractor
from dataclasses import dataclass
from typing import Type, Optional, Any, Literal

import gymnasium
import numpy as np
import stable_baselines3 as sb
import torch
import torch.nn.functional as F
from stable_baselines3.common.policies import ContinuousCritic
from torch import optim

from src.function_types import TorchTensorFn
from src.hyper_parameters import HyperParameters
from src.reinforcement_learning.algorithms.base.base_algorithm import PolicyProvider
from src.reinforcement_learning.algorithms.base.off_policy_algorithm import OffPolicyAlgorithm, ReplayBuf
from src.reinforcement_learning.algorithms.sac.sac_crossq_policy import SACCrossQPolicy
from src.reinforcement_learning.core.action_noise import ActionNoise
from src.reinforcement_learning.core.buffers.replay.base_replay_buffer import BaseReplayBuffer, ReplayBufferSamples
from src.reinforcement_learning.core.buffers.replay.replay_buffer import ReplayBuffer
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.infos import InfoDict, concat_infos
from src.reinforcement_learning.core.logging import LoggingConfig, log_if_enabled
from src.reinforcement_learning.core.loss_config import LossLoggingConfig
from src.reinforcement_learning.core.polyak_update import polyak_update
from src.reinforcement_learning.core.type_aliases import OptimizerProvider, TensorObs, detach_obs
from src.reinforcement_learning.gym.env_analysis import get_single_action_space
from src.torch_device import TorchDevice
from src.torch_functions import identity

# Default optimizer factory for the SAC actor and critic: maps an iterable of parameters to
# an AdamW optimizer with lr=3e-4 and weight_decay=1e-4.
SAC_DEFAULT_OPTIMIZER_PROVIDER = lambda params: optim.AdamW(params, lr=3e-4, weight_decay=1e-4)
# Default value of the `target_entropy` constructor argument of `SACDebug`.
AUTO_TARGET_ENTROPY = 'auto'

from typing import Optional

import torch
from torch import nn

from src.hyper_parameters import HyperParameters
from src.reinforcement_learning.core.action_selectors.action_selector import ActionSelector
from src.reinforcement_learning.core.action_selectors.state_dependent_noise_action_selector import \
    StateDependentNoiseActionSelector
from src.reinforcement_learning.core.policies.components.base_component import BasePolicyComponent
from src.reinforcement_learning.core.policies.components.feature_extractors import FeatureExtractor, IdentityExtractor
from src.reinforcement_learning.core.type_aliases import TensorObs


class DebugActor(BasePolicyComponent):
    """Minimal actor used for debugging: a network followed by an action selector.

    Observations pass through the component's feature extractor (always an
    ``IdentityExtractor`` here) and then through ``network``; the resulting
    latent is handed to ``action_selector`` to produce actions.
    """

    action_selector: ActionSelector
    uses_sde: bool

    def __init__(
            self,
            network: nn.Module,
            action_selector: ActionSelector,
            # feature_extractor: Optional[FeatureExtractor] = None
    ):
        """
        Args:
            network: Module mapping observations to the latent consumed by the action selector.
            action_selector: Must be a ``PredictedStdActionSelector``.
        """
        assert isinstance(action_selector, PredictedStdActionSelector), \
            'DebugActor only supports PredictedStdActionSelector'
        super().__init__(IdentityExtractor())
        self.network = network
        self.replace_action_selector(action_selector, copy_action_net_weights=False)

    def collect_hyper_parameters(self) -> HyperParameters:
        """Extend the parent's hyper parameters with descriptions of the network and action selector."""
        return self.update_hps(super().collect_hyper_parameters(), {
            'network': self.get_hps_or_str(self.network),
            'action_selector': self.get_hps_or_str(self.action_selector),
        })

    def forward(self, obs: TensorObs, deterministic: bool = False) -> torch.Tensor:
        """Return actions for ``obs``.

        Args:
            obs: Observations fed to the feature extractor and network.
            deterministic: Forwarded to the action selector's ``get_actions``.
        """
        obs = self.feature_extractor(obs)
        latent_pi = self.network(obs)
        # Fixed: the original passed the undefined name `debug_actor` here, raising a NameError.
        return self.action_selector.update_latent_features(latent_pi).get_actions(deterministic=deterministic)

    def get_actions_with_log_probs(
            self,
            obs: TensorObs,
            deterministic: bool = False
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return the action selector's ``(actions, log_probs)`` for ``obs``."""
        obs = self.feature_extractor(obs)
        latent_pi = self.network(obs)
        return self.action_selector.get_actions_with_log_probs(latent_pi, deterministic=deterministic)

    def replace_action_selector(self, new_action_selector: ActionSelector, copy_action_net_weights: bool) -> None:
        """Install ``new_action_selector`` and refresh ``uses_sde``.

        Args:
            new_action_selector: The selector to use from now on.
            copy_action_net_weights: If True, load the current selector's ``action_net``
                state dict into the new selector before swapping. Requires that an
                action selector is already installed.
        """
        if copy_action_net_weights:
            new_action_selector.action_net.load_state_dict(self.action_selector.action_net.state_dict())
        self.action_selector = new_action_selector
        self.uses_sde = isinstance(self.action_selector, StateDependentNoiseActionSelector)

    def reset_sde_noise(self, batch_size: int = 1) -> None:
        """No-op unless a state-dependent-noise selector is installed, which is unsupported.

        Raises:
            NotImplementedError: If ``uses_sde`` is True.
        """
        if self.uses_sde:
            # Fixed: `raise NotImplemented` raises a TypeError because NotImplemented is not an exception.
            raise NotImplementedError('DebugActor does not support resetting SDE noise')

    def set_training_mode(self, mode: bool):
        """Alias for ``set_train_mode`` under the method name stable-baselines3 uses."""
        self.set_train_mode(mode)



@dataclass
class SACLoggingConfig(LoggingConfig):
    """Logging options specific to SAC.

    Every loss config left as ``None`` is replaced by a default
    ``LossLoggingConfig()`` in ``__post_init__``.
    """

    log_entropy_coef: bool = False
    entropy_coef_loss: Optional[LossLoggingConfig] = None
    actor_loss: Optional[LossLoggingConfig] = None
    critic_loss: Optional[LossLoggingConfig] = None

    def __post_init__(self):
        if self.actor_loss is None:
            self.actor_loss = LossLoggingConfig()
        if self.entropy_coef_loss is None:
            # Fixed: the default used to be stored on a stray `entropy_loss` attribute,
            # leaving `entropy_coef_loss` as None.
            self.entropy_coef_loss = LossLoggingConfig()
        if self.critic_loss is None:
            self.critic_loss = LossLoggingConfig()

        super().__post_init__()

"""

        Soft Actor-Critic:
        Off-Policy Maximum Entropy Deep Reinforcement
        Learning with a Stochastic Actor
        https://arxiv.org/pdf/1801.01290

"""
class SACDebug(OffPolicyAlgorithm[sb.sac.sac.SACPolicy, ReplayBuf, SACLoggingConfig]):
    """Debug implementation of Soft Actor-Critic driving a stable-baselines3 style policy.

    Soft Actor-Critic:
    Off-Policy Maximum Entropy Deep Reinforcement
    Learning with a Stochastic Actor
    https://arxiv.org/pdf/1801.01290

    The policy must expose ``actor``, ``critic`` and ``critic_target``. The actor is
    called through ``get_actions_with_log_probs`` and every critic call is expected to
    return one Q-value tensor per critic network.
    """

    policy: sb.sac.sac.SACPolicy
    actor: sb.sac.sac.Actor
    critic: ContinuousCritic
    buffer: BaseReplayBuffer
    target_entropy: float
    log_ent_coef: Optional[torch.Tensor]
    entropy_coef_optimizer: Optional[optim.Optimizer]
    entropy_coef_tensor: Optional[torch.Tensor]

    def __init__(
            self,
            env: gymnasium.Env,
            policy: sb.sac.sac.SACPolicy | PolicyProvider[sb.sac.sac.SACPolicy],
            actor_optimizer_provider: OptimizerProvider = SAC_DEFAULT_OPTIMIZER_PROVIDER,
            critic_optimizer_provider: OptimizerProvider = SAC_DEFAULT_OPTIMIZER_PROVIDER,
            weigh_and_reduce_actor_loss: TorchTensorFn = torch.mean,
            weigh_critic_loss: TorchTensorFn = identity,
            buffer_type: Type[ReplayBuf] = ReplayBuffer,
            buffer_size: int = 100_000,
            buffer_kwargs: Optional[dict[str, Any]] = None,
            gamma: float = 0.99,
            tau: float = 0.005,
            rollout_steps: int = 100,
            gradient_steps: int = 1,
            optimization_batch_size: int = 256,
            target_update_interval: int = 1,
            entropy_coef: float = 1.0,
            target_entropy: float | Literal['auto'] = AUTO_TARGET_ENTROPY,
            entropy_coef_optimizer_provider: Optional[OptimizerProvider] = None,
            weigh_and_reduce_entropy_coef_loss: TorchTensorFn = torch.mean,
            action_noise: Optional[ActionNoise] = None,
            warmup_steps: int = 100,
            learning_starts: int = 100,
            sde_noise_sample_freq: Optional[int] = None,
            callback: Optional[Callback['SACDebug']] = None,
            logging_config: Optional[SACLoggingConfig] = None,
            torch_device: TorchDevice = 'auto',
            torch_dtype: torch.dtype = torch.float32,
    ):
        """Set up the replay buffer, the optimizers and the entropy coefficient handling.

        Args:
            env: Environment to train on.
            policy: Policy (or policy provider) exposing ``actor``, ``critic`` and ``critic_target``.
            actor_optimizer_provider: Builds the actor optimizer from the actor's parameters.
            critic_optimizer_provider: Builds the critic optimizer from the critic's parameters.
            weigh_and_reduce_actor_loss: Stored and reported as a hyper parameter; not yet applied (see TODOs).
            weigh_critic_loss: Stored and reported as a hyper parameter; not yet applied (see TODOs).
            buffer_type: Replay buffer class; its ``for_env`` factory creates the buffer.
            buffer_size: Forwarded to ``buffer_type.for_env``.
            buffer_kwargs: Extra keyword arguments for ``buffer_type.for_env``.
            gamma: Discount factor used in the TD target.
            tau: Polyak coefficient for the target critic update.
            rollout_steps, gradient_steps, optimization_batch_size, action_noise, warmup_steps,
            learning_starts, sde_noise_sample_freq, torch_device, torch_dtype:
                Forwarded to ``OffPolicyAlgorithm``.
            target_update_interval: The target critic is updated every this many gradient steps;
                values <= 0 disable the update.
            entropy_coef: Fixed entropy coefficient, or the initial value when it is learned.
            target_entropy: Target entropy, or ``'auto'`` to use minus the action dimensionality.
            entropy_coef_optimizer_provider: If given, the entropy coefficient is learned with this optimizer.
            weigh_and_reduce_entropy_coef_loss: Stored and reported as a hyper parameter; not yet applied.
            callback: Defaults to a plain ``Callback()``.
            logging_config: Defaults to ``SACLoggingConfig()``.
        """
        super().__init__(
            env=env,
            policy=policy,
            buffer=buffer_type.for_env(env, buffer_size, torch_device, torch_dtype, **(buffer_kwargs or {})),
            gamma=gamma,
            tau=tau,
            rollout_steps=rollout_steps,
            gradient_steps=gradient_steps,
            optimization_batch_size=optimization_batch_size,
            action_noise=action_noise,
            warmup_steps=warmup_steps,
            learning_starts=learning_starts,
            sde_noise_sample_freq=sde_noise_sample_freq,
            callback=callback or Callback(),
            # Fixed: the default must be a SACLoggingConfig because `optimize` reads
            # `logging_config.log_entropy_coef`, which is declared on SACLoggingConfig only.
            logging_config=logging_config or SACLoggingConfig(),
            torch_device=torch_device,
            torch_dtype=torch_dtype,
        )

        self.actor = self.policy.actor
        self.critic = self.policy.critic

        # TODO: also chain in the parameters of a shared feature extractor once the policy has one.
        self.actor_optimizer = actor_optimizer_provider(self.actor.parameters())
        self.critic_optimizer = critic_optimizer_provider(self.critic.parameters())

        self.weigh_and_reduce_entropy_coef_loss = weigh_and_reduce_entropy_coef_loss
        self.weigh_and_reduce_actor_loss = weigh_and_reduce_actor_loss
        self.weigh_critic_loss = weigh_critic_loss

        self.target_update_interval = target_update_interval
        self.gradient_steps_performed = 0

        self._setup_entropy_optimization(entropy_coef, target_entropy, entropy_coef_optimizer_provider)

        # CrossQ doesn't use a target critic
        if isinstance(self.policy, SACCrossQPolicy):
            self.tau = 0
            self.target_update_interval = 0

    def collect_hyper_parameters(self) -> HyperParameters:
        """Extend the parent's hyper parameters with the SAC specific settings."""
        return self.update_hps(super().collect_hyper_parameters(), {
            'actor_optimizer': str(self.actor_optimizer),
            'critic_optimizer': str(self.critic_optimizer),
            'entropy_coef_optimizer': str(self.entropy_coef_optimizer),
            'weigh_and_reduce_entropy_coef_loss': str(self.weigh_and_reduce_entropy_coef_loss),
            'weigh_and_reduce_actor_loss': str(self.weigh_and_reduce_actor_loss),
            'weigh_critic_loss': str(self.weigh_critic_loss),
            'target_update_interval': self.target_update_interval,
            'target_entropy': self.target_entropy,
            'entropy_coef': self.entropy_coef_tensor.item() if self.entropy_coef_tensor is not None else 'dynamic',
        })

    def _setup_entropy_optimization(
            self,
            entropy_coef: float,
            target_entropy: float | Literal['auto'],
            entropy_coef_optimizer_provider: Optional[OptimizerProvider],
    ):
        """Initialise ``target_entropy`` and either a learnable or a fixed entropy coefficient.

        With an optimizer provider, ``log_ent_coef`` (initialised to ``log(entropy_coef)``) is
        learned and ``entropy_coef_tensor`` is None; without one it is the other way round.
        """
        if target_entropy == 'auto':
            # Minus the number of action dimensions of a single environment.
            self.target_entropy = float(-np.prod(get_single_action_space(self.env).shape).astype(np.float32))
        else:
            self.target_entropy = float(target_entropy)

        if entropy_coef_optimizer_provider is not None:
            self.log_ent_coef = torch.log(
                torch.tensor([entropy_coef], device=self.torch_device, dtype=self.torch_dtype)
            ).requires_grad_(True)
            self.entropy_coef_optimizer = entropy_coef_optimizer_provider([self.log_ent_coef])
            self.entropy_coef_tensor = None
        else:
            self.log_ent_coef = None
            self.entropy_coef_optimizer = None
            self.entropy_coef_tensor = torch.tensor(entropy_coef, device=self.torch_device, dtype=self.torch_dtype)

    def get_and_optimize_entropy_coef(
            self,
            actions_pi_log_prob: torch.Tensor,
            info: InfoDict
    ) -> torch.Tensor:
        """Return the entropy coefficient for this gradient step, training it first if it is learnable.

        When the coefficient is learned, the returned value is computed from ``log_ent_coef``
        *before* the optimizer step, and ``info['final_entropy_coef_loss']`` is set.
        Otherwise the fixed ``entropy_coef_tensor`` is returned and ``info`` is untouched.
        """
        if self.entropy_coef_optimizer is not None:
            entropy_coef = torch.exp(self.log_ent_coef.detach())

            # TODO: route this through weigh_and_reduce_loss(...) using
            #  self.weigh_and_reduce_entropy_coef_loss and self.logging_config.entropy_coef_loss.
            entropy_coef_loss = -(self.log_ent_coef * (actions_pi_log_prob + self.target_entropy).detach()).mean()
            info['final_entropy_coef_loss'] = entropy_coef_loss.detach()

            self.entropy_coef_optimizer.zero_grad()
            entropy_coef_loss.backward()
            self.entropy_coef_optimizer.step()

            return entropy_coef
        else:
            return self.entropy_coef_tensor

    def calculate_critic_loss(
            self,
            observation_features: TensorObs,
            replay_samples: ReplayBufferSamples,
            entropy_coef: torch.Tensor,
            info: InfoDict,
    ):
        """Compute the critic loss against the entropy-regularised TD target.

        The loss is ``0.5 * sum`` of the MSE of every critic's Q-values against the shared
        target; its detached value is stored in ``info['final_critic_loss']``.
        """
        with torch.no_grad():
            # Select action according to policy
            next_actions, next_log_prob = self.actor.get_actions_with_log_probs(replay_samples.next_observations)
            # Compute the next Q values: min over all critics targets
            next_q_values = torch.cat(self.policy.critic_target(replay_samples.next_observations, next_actions), dim=1)
            next_q_values, _ = torch.min(next_q_values, dim=1, keepdim=True)
            # add entropy term
            next_q_values = next_q_values - entropy_coef * next_log_prob.reshape(-1, 1)
            # td error + entropy term
            target_q_values = replay_samples.rewards + (1 - replay_samples.dones) * self.gamma * next_q_values

        # TODO: replace the block above with self.policy.compute_target_values(replay_samples, entropy_coef, gamma)
        #  once the policy provides it.
        # critic loss should not influence shared feature extractor
        current_q_values = self.critic(detach_obs(observation_features), replay_samples.actions)

        # noinspection PyTypeChecker
        critic_loss: torch.Tensor = 0.5 * sum(
            F.mse_loss(current_q, target_q_values) for current_q in current_q_values
        )
        # TODO: route this through weigh_and_reduce_loss(...) using
        #  self.weigh_critic_loss and self.logging_config.critic_loss.

        info['final_critic_loss'] = critic_loss.detach()
        return critic_loss

    def calculate_actor_loss(
            self,
            observation_features: TensorObs,
            actions_pi: torch.Tensor,
            actions_pi_log_prob: torch.Tensor,
            entropy_coef: torch.Tensor,
            info: InfoDict,
    ) -> torch.Tensor:
        """Compute the actor loss ``mean(entropy_coef * log_prob - min_Q(obs, actions_pi))``.

        The detached loss is stored in ``info['final_actor_loss']``.
        """
        q_values_pi = torch.cat(self.critic(observation_features, actions_pi), dim=-1)
        min_q_values_pi, _ = torch.min(q_values_pi, dim=-1, keepdim=True)
        # TODO: route this through weigh_and_reduce_loss(...) using
        #  self.weigh_and_reduce_actor_loss and self.logging_config.actor_loss.
        actor_loss = (entropy_coef * actions_pi_log_prob - min_q_values_pi).mean()

        info['final_actor_loss'] = actor_loss.detach()
        return actor_loss

    def optimize(self, last_obs: np.ndarray, last_episode_starts: np.ndarray, info: InfoDict) -> None:
        """Run ``gradient_steps`` SAC updates and merge the per-step infos into ``info``.

        Each step samples a batch, updates the entropy coefficient (if learnable), the critic
        and the actor, and performs a polyak update of the target critic every
        ``target_update_interval`` gradient steps. ``last_obs`` and ``last_episode_starts``
        are not used.
        """
        gradient_step_infos: list[InfoDict] = []

        for gradient_step in range(self.gradient_steps):
            step_info: InfoDict = {}
            replay_samples = self.buffer.sample(self.optimization_batch_size)

            # TODO: self.actor.reset_sde_noise() (set batch size?) and shared feature extraction
            #  are disabled; raw observations are used as features.
            actions_pi, actions_pi_log_prob = self.actor.get_actions_with_log_probs(replay_samples.observations)
            actions_pi_log_prob = actions_pi_log_prob.reshape(-1, 1)

            entropy_coef = self.get_and_optimize_entropy_coef(actions_pi_log_prob, step_info)
            log_if_enabled(step_info, 'entropy_coef', entropy_coef, self.logging_config.log_entropy_coef)

            critic_loss = self.calculate_critic_loss(
                observation_features=replay_samples.observations,
                replay_samples=replay_samples,
                entropy_coef=entropy_coef,
                info=step_info
            )

            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            actor_loss = self.calculate_actor_loss(
                observation_features=replay_samples.observations,
                actions_pi=actions_pi,
                actions_pi_log_prob=actions_pi_log_prob,
                entropy_coef=entropy_coef,
                info=step_info
            )

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            self.gradient_steps_performed += 1
            if self.target_update_interval > 0 and self.gradient_steps_performed % self.target_update_interval == 0:
                # TODO: switch to self.policy.perform_polyak_update(self.tau) once the policy provides it.
                polyak_update(self.critic.parameters(), self.policy.critic_target.parameters(), self.tau)
            gradient_step_infos.append(step_info)
        info.update(concat_infos(gradient_step_infos))





In [2]:


from sac import init_policy, init_action_selector
from stable_baselines3.common.env_util import make_vec_env
import stable_baselines3 as sb
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
import gymnasium

# Gymnasium environment id used by `create_env`.
env_name = 'HalfCheetah-v4'
# Alternative reward shapings that were tried; kept for reference.
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.1 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.05 }
# Extra keyword arguments `create_env` forwards to `gymnasium.make`.
env_kwargs = {}
# Only referenced by the commented-out `parallelize_env_async` call in this cell.
num_envs = 1

def create_env(render_mode: str | None):
    """Create a fresh `env_name` environment, forwarding `env_kwargs` and the given render mode."""
    environment = gymnasium.make(env_name, **env_kwargs, render_mode=render_mode)
    return environment

# env = parallelize_env_async(lambda: create_env(render_mode=None), num_envs)
# Single environment instance without rendering.
env = create_env(render_mode=None)

from stable_baselines3.sac.sac import SAC

# stable-baselines3 SAC agent created on the same environment.
sb_sac = SAC("MlpPolicy", env, verbose=10, learning_starts=10000, stats_window_size=1) # , seed=594371)

Using cuda device
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


In [3]:
import copy
from src.console import print_warning
from src.tags import Tags
from src.reinforcement_learning.core.policies.components.actor import Actor
from src.reinforcement_learning.core.policies.components.q_critic import QCritic
from src.reinforcement_learning.core.policies.base_policy import BasePolicy
import stable_baselines3 as sb


class DebugSACPolicy(BasePolicy):
    """SAC policy for debugging that pairs a stable-baselines3 actor with the project's critic.

    The ``actor`` passed to the constructor is handed to ``BasePolicy`` but is then
    replaced by a freshly initialised stable-baselines3 ``Actor`` built from the
    module-level ``env``.
    """

    actor: sb.sac.policies.Actor

    def __init__(
            self,
            actor: Actor,
            critic: QCritic,
            shared_feature_extractor: Optional[FeatureExtractor] = None
    ):
        """
        Args:
            actor: Only forwarded to ``BasePolicy.__init__``; replaced afterwards (see class docstring).
            critic: Q-critic; a frozen deep copy becomes ``target_critic``.
            shared_feature_extractor: Forwarded to ``BasePolicy.__init__``.
        """
        super().__init__(actor, shared_feature_extractor)
        # Derive the feature dimension from the extractor instead of hard-coding 17, so the
        # policy also works for environments with other observation sizes.
        features_extractor = FlattenExtractor(env.observation_space)
        self.actor = sb.sac.policies.Actor(
            env.observation_space,
            env.action_space,
            [256, 256],
            features_extractor,
            features_extractor.features_dim,
        )

        self.critic = critic

        self._build_target()

        self._check_action_selector()

    @property
    def uses_sde(self):
        """State-dependent exploration is never used by this debug policy."""
        return False

    def act(self, obs: TensorObs) -> torch.Tensor:
        """Return the actor's output for ``obs``, passing ``False`` as its second (deterministic) argument."""
        return self.actor(obs, False)

    def reset_sde_noise(self, batch_size: int) -> None:
        """No-op, since ``uses_sde`` is always False."""
        pass

    def collect_hyper_parameters(self) -> HyperParameters:
        """This debug policy reports no hyper parameters."""
        return {}

    def collect_tags(self) -> Tags:
        """This debug policy reports no tags."""
        return []

    def _check_action_selector(self):
        """Disabled check; the stable-baselines3 actor has no project action selector to inspect."""
        # if not isinstance(self.actor.action_selector, (PredictedStdActionSelector, StateDependentNoiseActionSelector)):
        #     print_warning('SAC not being used with PredictedStdAction Selector or gSDE. LogStds should be clamped!')
        pass

    def _build_target(self):
        """Create frozen deep copies of the critic and the shared feature extractor."""
        self.target_critic = copy.deepcopy(self.critic)
        self.target_critic.set_trainable(False)

        # NOTE(review): assumes BasePolicy stored a non-None extractor when None was passed — confirm.
        self.target_shared_feature_extractor = copy.deepcopy(self.shared_feature_extractor)
        self.target_shared_feature_extractor.set_trainable(False)

    def forward(self):
        raise NotImplementedError('forward is not used in SACPolicy')

    def compute_target_values(
            self,
            replay_samples: ReplayBufferSamples,
            entropy_coef: torch.Tensor,
            gamma: float,
    ):
        """Compute the entropy-regularised TD targets for the critic without tracking gradients.

        ``rewards + (1 - dones) * gamma * (min target-Q(next_obs, next_actions)
        - entropy_coef * log_prob(next_actions))``, with the next actions sampled from the actor.
        """
        with torch.no_grad():
            next_observations = replay_samples.next_observations

            next_actions, next_actions_log_prob = self.actor.action_log_prob(
                self.shared_feature_extractor(next_observations)
            )

            next_q_values = torch.cat(
                self.target_critic(self.target_shared_feature_extractor(next_observations), next_actions),
                dim=-1
            )
            next_q_values, _ = torch.min(next_q_values, dim=-1, keepdim=True)
            next_q_values = next_q_values - entropy_coef * next_actions_log_prob.reshape(-1, 1)

            target_q_values = replay_samples.rewards + (1 - replay_samples.dones) * gamma * next_q_values

            return target_q_values

    def perform_polyak_update(self, tau: float):
        """Polyak-update the target critic and target feature extractor from their online counterparts."""
        polyak_update(self.critic.parameters(), self.target_critic.parameters(), tau)
        polyak_update(
            self.shared_feature_extractor.parameters(),
            self.target_shared_feature_extractor.parameters(),
            tau
        )

    def set_train_mode(self, mode: bool) -> None:
        """Switch actor, critic and shared feature extractor to ``mode``; targets are left unchanged."""
        # The stable-baselines3 actor uses `set_training_mode` rather than `set_train_mode`.
        self.actor.set_training_mode(mode)
        self.critic.set_train_mode(mode)
        # Leaving target_critic on train_mode = False

        self.shared_feature_extractor.set_train_mode(mode)
        # Leaving target_shared_feature_extractor on train_mode = False

        self.train_mode = mode


In [4]:
import inspect
import time

from gymnasium import Env

from sac import init_action_selector, init_policy, init_optimizer, wrap_env, policy_construction_hyper_parameter
from src.datetime import get_current_timestamp
from src.experiment_logging.experiment_logger import ExperimentLogger, log_experiment
from src.model_db.dummy_model_db import DummyModelDB
from src.reinforcement_learning.algorithms.policy_mitosis.mitosis_policy_info import MitosisPolicyInfo
from src.module_analysis import count_parameters
from src.moving_averages import ExponentialMovingAverage
from src.reinforcement_learning.core.policies.base_policy import BasePolicy
from src.reinforcement_learning.core.policy_construction import PolicyConstruction
from src.stopwatch import Stopwatch
from src.summary_statistics import format_summary_statics
from typing import Any
from src.reinforcement_learning.core.callback import Callback

import torch
from torch import optim
import gymnasium as gym
import numpy as np

%load_ext autoreload
%autoreload 2

In [10]:
from src.summary_statistics import maybe_compute_summary_statistics
from src.reinforcement_learning.core.loss_config import LossLoggingConfig
from src.reinforcement_learning.algorithms.sac.sac import SAC, SACLoggingConfig
def get_setup() -> dict[str, str]:
    """Snapshot the code defining this experiment.

    Returns a dict with the source of the `sac` module under 'sac.py' and, under
    'notebook', entries 1, -4, -3, -2 and -1 of `_ih` (expected to be IPython's
    input history) joined by blank lines.
    """
    import inspect
    import sac

    setup = {'sac.py': inspect.getsource(sac)}
    setup['notebook'] = '\n\n'.join(_ih[index] for index in (1, -4, -3, -2, -1))
    return setup

# Notebook-level state managed by `get_policy`; declared without values so that
# `'policy' in globals()` stays False until a policy has actually been created.
policy_id: str
policy: BasePolicy
optimizer: optim.Optimizer
wrapped_env: Env
steps_trained: int
def get_policy(create_new_if_exists: bool):
    """Return the policy kept in notebook globals, creating a new one when needed.

    A new policy (with optimizer and wrapped env) is built when none exists yet or when
    `create_new_if_exists` is True. If `parent_policy_id` is set, that policy's state
    dict is loaded into the policy and `steps_trained` is taken from its stored info.

    Returns:
        Tuple `(policy_id, policy, optimizer, wrapped_env, steps_trained)`.
    """
    global policy_id, policy, optimizer, wrapped_env, steps_trained

    has_cached_policy = 'policy' in globals()
    if create_new_if_exists or not has_cached_policy:
        if not has_cached_policy:
            print('No policy in RAM, creating a new one')

        policy_id = get_current_timestamp()
        policy, optimizer, wrapped_env = PolicyConstruction.init_from_info(
            env=env,
            info=PolicyConstruction.create_policy_initialization_info(
                init_action_selector=init_action_selector,
                init_policy=init_policy,
                init_optimizer=init_optimizer,
                wrap_env=wrap_env,
                hyper_parameters=policy_construction_hyper_parameter,
            ),
        )
        steps_trained = 0
        print(f'New policy {policy_id} created')

    if parent_policy_id is not None:
        parent_entry = policy_db.load_model_state_dict(policy, parent_policy_id)
        steps_trained = parent_entry['model_info']['steps_trained']
        print(f'Loading state dict from policy {parent_policy_id}')

    print(f'Using policy {policy_id} with parent policy {parent_policy_id}')
    return policy_id, policy, optimizer, wrapped_env, steps_trained

# Moving average of iteration scores; only referenced by the disabled code in `on_rollout_done`.
score_mean_ema = ExponentialMovingAverage(alpha=0.25)
# `step_stopwatch` is reset on every report in `on_optimization_done`; `total_stopwatch` is only read there.
step_stopwatch = Stopwatch()
total_stopwatch = Stopwatch()
# Starting value for the best score; only referenced by the disabled code in `on_rollout_done`.
best_iteration_score = -1e6

def on_rollout_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Rollout callback; currently a placeholder without any effect.

    Steps that are not a multiple of 1000 return immediately. On the remaining steps
    nothing runs either, because the episode-score tracking and best-model
    checkpointing below are disabled.
    """
    is_report_step = step % 1000 == 0
    if not is_report_step:
        return

    # tail_indices = rl.buffer.tail_indices(1000)

    # rewards = rl.buffer.rewards[tail_indices]
    # if 'raw_rewards' in info['rollout']:
    #     rewards = info['rollout']['raw_rewards']

    # episode_scores = compute_episode_returns(
    #     rewards=rewards,
    #     episode_starts=np.repeat(np.arange(len(tail_indices)).reshape(-1, 1), num_envs, axis=1) % 1000 == 0,
    #     last_episode_starts=info['last_episode_starts'],
    #     gamma=1.0,
    #     gae_lambda=1.0,
    #     normalize_rewards=None,
    #     remove_unfinished_episodes=True,
    # )

    # episode_scores = rl.buffer.compute_most_recent_episode_scores(rl.num_envs)
    #
    # if len(episode_scores) > 0:
    #
    #     global best_iteration_score
    #     iteration_score = episode_scores.mean()
    #     score_moving_average = score_mean_ema.update(iteration_score)
    #     if iteration_score >= best_iteration_score:
    #         best_iteration_score = iteration_score
    #         policy_db.save_model_state_dict(
    #             model_id=policy_id,
    #             parent_model_id=parent_policy_id,
    #             model_info={
    #                 'score': iteration_score.item(),
    #                 'steps_trained': steps_trained,
    #                 'wrap_env_source_code': wrap_env_source_code_source,
    #                 'init_policy_source_code': init_policy_source
    #             },
    #             model=policy,
    #             optimizer=optimizer,
    #         )
    #     info['score_moving_average'] = score_moving_average
    #
    # info['episode_scores'] = episode_scores
        
def on_optimization_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Optimization callback: on every 1000th step, print and log a training status line.

    Reads `final_actor_loss`, `final_critic_loss` and `entropy_coef` (plus, if present,
    `final_entropy_coef_loss`) from `info`, summarizes rewards and action magnitudes over
    the 1000 buffer positions before `rl.buffer.pos`, prints one status line and appends
    the same values to the module-level `logger`. On every 10000th step the experiment
    log is additionally saved.

    Args:
        rl: Algorithm instance; `num_envs`, `buffer` and `gradient_steps_performed` are read.
        step: Current step; all other steps than multiples of 1000 return immediately.
        info: Info dict of the current iteration; must contain a 'rollout' entry.
        scheduler_values: Unused.
    """
    # global steps_trained
    # steps_trained += rl.buffer.pos
    
    if step % 1000 != 0:
        return
    num_env_steps = step * rl.num_envs
    
    # Resetting the step stopwatch here makes `step_time` cover the span since the previous
    # report (assuming `reset()` returns the elapsed time — TODO confirm).
    step_time = step_stopwatch.reset()
    total_time = total_stopwatch.time_passed()
    
    # TODO!!
    # tail_indices = rl.buffer.tail_indices(1000)
    
    # episode_scores = info.get('episode_scores')
    # Always 0.0 as long as `on_rollout_done` does not write 'score_moving_average'.
    score_moving_average = info.get('score_moving_average') or 0.0
    
    # NOTE(review): for buffer.pos < 1000 this yields negative indices, which index from the
    # end of the buffer arrays — confirm this is intended.
    tail_indices = np.arange(rl.buffer.pos - 1000, rl.buffer.pos)
    # Sum of the rewards in that window along axis 0 (presumably the step axis — TODO confirm).
    episode_scores = rl.buffer.rewards[tail_indices].sum(axis=0)
    
    scores = format_summary_statics(
        episode_scores, 
        mean_format=' 6.3f',
        std_format='4.3f',
        min_value_format=' 6.3f',
        max_value_format='5.3f',
        n_format='>2'
    )
    # scores2 = format_summary_statics(
    #     rl.buffer.compute_most_recent_episode_scores(rl.num_envs, lambda r: 1 * r), 
    #     mean_format=' 6.3f',
    #     std_format='4.3f',
    #     min_value_format=' 6.3f',
    #     max_value_format='5.3f',
    #     n_format='>2'
    # )
    # advantages = format_summary_statics(
    #     rl.buffer.advantages, 
    #     mean_format=' 6.3f',
    #     std_format='.1f',
    #     min_value_format=' 7.3f',
    #     max_value_format='6.3f',
    # )
    actor_loss = format_summary_statics(
        info['final_actor_loss'],  
        mean_format=' 5.3f',
        # std_format='5.3f',
        std_format=None,
        min_value_format=None,
        max_value_format=None,
    )
    # actor_loss_raw = format_summary_statics(
    #     info['raw_actor_loss'],  
    #     mean_format=' 5.3f',
    #     std_format='5.3f',
    #     min_value_format=None,
    #     max_value_format=None,
    # )
    # The entropy coefficient loss is only present in `info` when the coefficient is being optimized.
    entropy_coef_loss = None if 'final_entropy_coef_loss' not in info else format_summary_statics(
        info['final_entropy_coef_loss'], 
        mean_format='5.3f',
#         std_format='5.3f',
        std_format=None,
        min_value_format=None,
        max_value_format=None,
    )
    critic_loss = format_summary_statics(
        info['final_critic_loss'], 
        mean_format='5.3f',
#         std_format='5.3f',
        std_format=None,
        min_value_format=None,
        max_value_format=None,
    )
    entropy_coef = format_summary_statics(
        info['entropy_coef'],
        mean_format='5.3f',
#         std_format='5.3f',
        std_format=None,
        min_value_format=None,
        max_value_format=None,
    )
    # resets = format_summary_statics(
    #     rl.buffer.dones.astype(int).sum(axis=0), 
    #     mean_format='.2f',
    #     std_format=None,
    #     min_value_format='1d',
    #     max_value_format=None,
    # )
    # kl_div = info['actor_kl_divergence'][-1]
    # grad_norm = format_summary_statics(
    #     info['grad_norm'], 
    #     mean_format=' 6.3f',
    #     std_format='.1f',
    #     min_value_format=' 7.3f',
    #     max_value_format='6.3f',
    # )
    # Action standard deviations are optional in the rollout info; 'N/A' is printed when missing.
    action_stds = info['rollout'].get('action_stds')
    if action_stds is not None:
        rollout_action_stds = format_summary_statics(
            action_stds,
            mean_format='5.3f',
            std_format='5.3f',
            min_value_format=None,
            max_value_format=None,
        )
    else:
        rollout_action_stds = 'N/A'
    action_magnitude = format_summary_statics(
        np.abs(rl.buffer.actions[tail_indices]),
        mean_format='5.3f',
        std_format='5.3f',
        min_value_format=None,
        max_value_format=None,
    )
    # ppo_epochs = info['nr_ppo_epochs']
    # ppo_updates = info['nr_ppo_updates']
    # expl_var = rl.buffer.compute_critic_explained_variance()
    # The `{name = ...}` f-string fields print the local variable names, so renaming a
    # local changes the printed output.
    print(f"{step = : >7}, "
          f"{num_env_steps = : >7}, "
          f"{scores = :s}, "
          # f"{scores2 = :s}, "
          f'score_ema = {score_moving_average: 6.3f}, '
          # f"{advantages = :s}, "
          f"{actor_loss = :s}, "
          # f"{actor_loss_raw = :s}, "
          f"{critic_loss = :s}, "
          +(f"{entropy_coef_loss = :s}, " if entropy_coef_loss is not None else '')+
          f"{entropy_coef = :s}, "
          f"rollout_stds = {rollout_action_stds:s}, "
          f"{action_magnitude = :s}, "
          # f"{expl_var = :.3f}, "
          # f"{kl_div = :.4f}, "
          # f"{ppo_epochs = }, "
          # f"{ppo_updates = }, "
          # f"{grad_norm = :s}, "
          f"n_updates = {rl.gradient_steps_performed}, "
          # f"{resets = :s}, "
          f"time = {step_time:4.1f}, "
          f"total_time = {total_time:4.1f} \n"
          )
    # Record the same quantities in the experiment log.
    logger.add_item({
        'step': step,
        'num_env_steps': num_env_steps,
        'scores': maybe_compute_summary_statistics(episode_scores),
        'actor_loss': maybe_compute_summary_statistics(info['final_actor_loss']),
        'entropy_coef_loss': maybe_compute_summary_statistics(info.get('final_entropy_coef_loss')),
        'critic_loss': maybe_compute_summary_statistics(info['final_critic_loss']),
        'entropy_coef': maybe_compute_summary_statistics(info['entropy_coef']),
        'action_stds': maybe_compute_summary_statistics(action_stds),
        'action_magnitude': maybe_compute_summary_statistics(np.abs(rl.buffer.actions[tail_indices])),
        'num_gradient_steps': rl.gradient_steps_performed,
        'step_time': step_time,
        'total_time': total_time
    })
    # Persist the experiment log every 10000 steps.
    if step % 10000 == 0:
        logger.save_experiment_log()
        print()
    print()
    
    # if episode_scores is not None and len(episode_scores) > 0 and episode_scores.mean().item() < -500:
    #     logger.save_experiment_log()
    #     raise ValueError('Score too low, policy probably fucked :(')

# Prefer the first CUDA device and fall back to the CPU when CUDA is not available,
# instead of unconditionally selecting "cuda:0".
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'using device {device}')

def create_env(render_mode: str | None):
    """Build one instance of the configured environment.

    The module-level ``env_name`` and ``env_kwargs`` are looked up at call
    time, so they may be (re)assigned after this function is defined.

    Args:
        render_mode: Forwarded to ``gym.make`` as its ``render_mode`` argument.

    Returns:
        The environment object produced by ``gym.make``.
    """
    environment = gym.make(env_name, render_mode=render_mode, **env_kwargs)
    return environment

# Snapshot the source text of the env-wrapping and policy-initialisation helpers.
# NOTE(review): presumably these strings are stored with the experiment setup
# (e.g. by get_setup()) for reproducibility - confirm against its definition.
wrap_env_source_code_source = inspect.getsource(wrap_env)
init_policy_source = inspect.getsource(init_policy)

# Environment id and extra keyword arguments that create_env() passes to gym.make.
env_name = 'HalfCheetah-v4'
# Alternative reward-shaping settings tried previously (kept for reference):
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.1 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.05 }
env_kwargs = {}
# Only printed below while the parallelised env creation stays commented out.
num_envs = 1

# Model database; the DummyModelDB is used here instead of the TinyModelDB variant below.
# policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'saved_models/rl/{env_name}')
policy_db = DummyModelDB[MitosisPolicyInfo]()
print(f'{policy_db = }')

# NOTE(review): presumably the id of a stored policy to continue from; None here - confirm in get_policy().
parent_policy_id=None  # '2024-04-28_20.57.23'

# TODO: switch back to the parallelised environment below; a single env is created for now.
# env = parallelize_env_async(lambda: create_env(render_mode=None), num_envs)
env = create_env(render_mode=None)

# Experiment logger configured with a per-environment SAC log directory.
logger = ExperimentLogger(f'experiment_logs/{env_name}/sac/')

# Training driver: fetch the policy, build the SAC algorithm, run training inside
# an experiment-log context, and always release the environment and the model
# database afterwards (also on Ctrl-C).
try:
    policy_id, policy, optimizer, wrapped_env, steps_trained = get_policy(create_new_if_exists=False)
    print(f'{count_parameters(policy) = }')
    print(f'{env = }, {num_envs = }')

    # Autograd anomaly detection is explicitly switched off for the training run.
    with torch.autograd.set_detect_anomaly(False):
        algo = SAC(
            env=wrapped_env,
            policy=DebugSACPolicy(policy.actor, policy.critic),
            actor_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            critic_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            # weigh_and_reduce_actor_loss=lambda l: 1 * l.mean(),
            # weigh_critic_loss=lambda l: 1 * l,
            buffer_size=1_000_000,
            reward_scale=1,
            gamma=0.99,
            tau=0.005,
            entropy_coef_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),
            entropy_coef=1.0,
            rollout_steps=1,
            gradient_steps=1,
            warmup_steps=10_000,
            optimization_batch_size=256,
            target_update_interval=1,
            # sde_noise_sample_freq=50,
            callback=Callback(
                on_rollout_done=on_rollout_done,
                rollout_schedulers={},
                on_optimization_done=on_optimization_done,
                optimization_schedulers={},
            ),
            logging_config=SACLoggingConfig(log_rollout_infos=True, log_rollout_action_stds=True,
                                            log_last_obs=True, log_entropy_coef=True,
                                            entropy_coef_loss=LossLoggingConfig(log_final=True),
                                            actor_loss=LossLoggingConfig(log_final=True, log_raw=True),
                                            critic_loss=LossLoggingConfig(log_final=True)),
            torch_device=device,
        )

        # Todo! Temporary debugging hack: reuse the replay buffer of `sb_sac`
        # (defined elsewhere; presumably a Stable-Baselines3 SAC instance - confirm)
        # instead of the buffer created by `SAC` above.
        algo.buffer = sb_sac.replay_buffer
        # Convert sampled arrays to float32 tensors on the same device the algorithm
        # was configured with (previously a hard-coded 'cuda', which could diverge
        # from `torch_device=device`).
        algo.buffer.to_torch = lambda arr: torch.tensor(arr, device=device, dtype=torch.float32)

        total_stopwatch.reset()
        with log_experiment(
            logger,
            experiment_tags=algo.collect_tags() + ['Debug'],
            hyper_parameters=algo.collect_hyper_parameters(),
            setup=get_setup(),
        ) as x:
            # Save the log once up front, before training starts.
            logger.save_experiment_log()
            print('\nStarting Training\n\n')
            # import cProfile
            # pr = cProfile.Profile()
            # pr.enable()
            algo.learn(5_000_000)
            # pr.disable()
            # pr.dump_stats('profile_stats.pstat')
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    try:
        print('closing envs')
        time.sleep(0.5)
        env.close()
        print('envs closed')
    finally:
        # Close the model database even if closing the environment raised.
        policy_db.close()
        print('model db closed')


print('done')

using device cuda:0
policy_db = DummyModelDB()
Using policy 2024-10-08 13:36:31.856491 with parent policy None
count_parameters(policy) = 217870
env = <TimeLimit<OrderEnforcing<PassiveEnvChecker<HalfCheetahEnv<HalfCheetah-v4>>>>>, num_envs = 1
Grabbing system information... done!
saved experiment log 2024-10-08_13-43-48_236302~fXPQM9 at experiment_logs/HalfCheetah-v4/sac/2024-10-08_13-43-48_236302~fXPQM9.json

Starting Training


step =   11000, num_env_steps =   11000, scores = -239.672 (n= 1), score_ema =  0.000, actor_loss = -59.483, critic_loss = 3.322, entropy_coef_loss = -3.001, entropy_coef = 0.741, rollout_stds = N/A, action_magnitude = 0.534 ± 0.286, n_updates = 1000, time = 12.8, total_time = 12.8 


step =   12000, num_env_steps =   12000, scores = -239.406 (n= 1), score_ema =  0.000, actor_loss = -63.789, critic_loss = 3.120, entropy_coef_loss = -5.929, entropy_coef = 0.549, rollout_stds = N/A, action_magnitude = 0.536 ± 0.289, n_updates = 2000, time =  9.5, total_time = 22

In [6]:
# Display the 'experiment_id' entry of the experiment log held by `logger`.
logger.experiment_log['experiment_id']

'2024-09-25_16-31-21_748992~egIuot'