In [1]:


from sac import init_policy, init_action_selector
from stable_baselines3.common.env_util import make_vec_env
import stable_baselines3 as sb
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
import gymnasium

# Environment used by the reference SB3 model below.
env_name = 'HalfCheetah-v4'
# Reward-shaping variants that have been experimented with (kept for reference):
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.1 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.05 }
env_kwargs = dict()  # extra keyword arguments forwarded to gymnasium.make
num_envs = 1

def create_env(render_mode: str | None):
    """Create one `env_name` gymnasium environment configured with `env_kwargs`.

    Args:
        render_mode: forwarded to ``gymnasium.make`` (None disables rendering).
    """
    return gymnasium.make(
        env_name,
        **env_kwargs,
        render_mode=render_mode,
    )

# Single, non-vectorized env; the async-parallel variant is kept for reference:
# env = parallelize_env_async(lambda: create_env(render_mode=None), num_envs)
env = create_env(render_mode=None)

# Reference SB3 SAC model. Its actor, critic and replay buffer are reused further down
# (DebugSACPolicy / SACDebug) so the project's SAC loop runs on SB3's components.
sb_sac = sb.SAC(
    "MlpPolicy",
    env,
    verbose=10,
    learning_starts=10000,
    stats_window_size=1,
    # seed=594371,
)

In [2]:
from src.reinforcement_learning.core.polyak_update import polyak_update
from src.reinforcement_learning.core.buffers.replay.base_replay_buffer import ReplayBufferSamples
from src.hyper_parameters import HyperParameters
import torch
from src.reinforcement_learning.core.type_aliases import TensorObs
from typing import Optional
from src.reinforcement_learning.core.policies.components.feature_extractors import FeatureExtractor
import copy
from src.console import print_warning
from src.tags import Tags
from src.reinforcement_learning.core.policies.components.actor import Actor
from src.reinforcement_learning.core.policies.components.q_critic import QCritic
from src.reinforcement_learning.core.policies.base_policy import BasePolicy
import stable_baselines3 as sb


class DebugSACPolicy(BasePolicy):
    """SAC policy used for debugging that runs on Stable-Baselines3's networks.

    The constructor keeps the project's ``SACPolicy`` signature, but only ``actor``
    (handed to ``BasePolicy.__init__``) is used during base initialisation. Afterwards
    ``self.actor`` / ``self.critic`` are set to the actor / critic of the notebook-global
    ``sb_sac`` model, so the project's SAC training loop can be compared against SB3.
    Target copies of the critic and shared feature extractor are built from those.
    """

    actor: sb.sac.policies.Actor

    def __init__(
            self,
            actor: Actor,
            critic: QCritic,
            shared_feature_extractor: Optional[FeatureExtractor] = None
    ):
        super().__init__(actor, shared_feature_extractor)
        # NOTE(review): the `actor` / `critic` arguments are deliberately discarded in favour
        # of SB3's networks; this relies on the global `sb_sac` from the first cell.
        self.actor, self.critic = sb_sac.actor, sb_sac.critic

        self._build_target()
        self._check_action_selector()

    @property
    def uses_sde(self):
        # gSDE is not supported by this debug policy.
        return False

    def act(self, obs: TensorObs) -> torch.Tensor:
        # Stochastic (non-deterministic) action from SB3's actor.
        return self.actor(obs, deterministic=False)

    def reset_sde_noise(self, batch_size: int) -> None:
        # Nothing to reset: gSDE is not used (see `uses_sde`).
        return None

    def collect_hyper_parameters(self) -> HyperParameters:
        return dict()

    def collect_tags(self) -> Tags:
        return list()

    def _check_action_selector(self):
        # Intentionally a no-op. The project's SACPolicy warns here when log-stds are not
        # clamped; presumably disabled because SB3's actor has no project `action_selector`.
        # if not isinstance(self.actor.action_selector, (PredictedStdActionSelector, StateDependentNoiseActionSelector)):
        #     print_warning('SAC not being used with PredictedStdAction Selector or gSDE. LogStds should be clamped!')
        return None

    def _build_target(self):
        """Create target copies: the critic copy is put in eval mode, the extractor copy is made non-trainable."""
        self.target_critic = copy.deepcopy(self.critic)
        self.target_critic.set_training_mode(False)

        self.target_shared_feature_extractor = copy.deepcopy(self.shared_feature_extractor)
        self.target_shared_feature_extractor.set_trainable(False)

    def forward(self):
        raise NotImplementedError('forward is not used in SACPolicy')

    def compute_target_values(
            self,
            replay_samples: ReplayBufferSamples,
            entropy_coef: torch.Tensor,
            gamma: float,
    ):
        """Soft Bellman target: r + (1 - done) * gamma * (min_i Q_target_i(s', a') - alpha * log pi(a'|s')).

        ``a'`` is sampled from the current actor; everything runs without gradient tracking.
        """
        with torch.no_grad():
            next_obs = replay_samples.next_observations

            next_features = self.shared_feature_extractor(next_obs)
            next_actions, next_log_probs = self.actor.action_log_prob(next_features)

            target_next_features = self.target_shared_feature_extractor(next_obs)
            stacked_next_q = torch.cat(self.target_critic(target_next_features, next_actions), dim=-1)
            min_next_q, _ = stacked_next_q.min(dim=-1, keepdim=True)
            soft_next_q = min_next_q - entropy_coef * next_log_probs.reshape(-1, 1)

            return replay_samples.rewards + (1 - replay_samples.dones) * gamma * soft_next_q

    def perform_polyak_update(self, tau: float):
        # Soft-update each target module towards its online counterpart.
        online_target_pairs = (
            (self.critic, self.target_critic),
            (self.shared_feature_extractor, self.target_shared_feature_extractor),
        )
        for online_module, target_module in online_target_pairs:
            polyak_update(online_module.parameters(), target_module.parameters(), tau)

    def set_train_mode(self, mode: bool) -> None:
        # Target critic / target feature extractor intentionally stay in eval mode.
        self.actor.set_training_mode(mode)
        self.critic.set_training_mode(mode)
        self.shared_feature_extractor.set_train_mode(mode)
        self.train_mode = mode


In [3]:
from src.reinforcement_learning.algorithms.sac.sac import SACLoggingConfig, SAC
from dataclasses import dataclass
from typing import Type, Optional, Any, Literal

import gymnasium
import numpy as np
import torch
import torch.nn.functional as F
from torch import optim

from src.function_types import TorchTensorFn
from src.module_analysis import calculate_grad_norm
from src.hyper_parameters import HyperParameters
from src.reinforcement_learning.algorithms.base.base_algorithm import PolicyProvider
from src.reinforcement_learning.algorithms.base.off_policy_algorithm import OffPolicyAlgorithm, ReplayBuf
from src.reinforcement_learning.algorithms.sac.sac_crossq_policy import SACCrossQPolicy
from src.reinforcement_learning.algorithms.sac.sac_policy import SACPolicy
from src.reinforcement_learning.core.action_noise import ActionNoise
from src.reinforcement_learning.core.buffers.replay.base_replay_buffer import BaseReplayBuffer, ReplayBufferSamples
from src.reinforcement_learning.core.buffers.replay.replay_buffer import ReplayBuffer
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.infos import InfoDict, concat_infos
from src.reinforcement_learning.core.logging import LoggingConfig, log_if_enabled
from src.reinforcement_learning.core.loss_config import weigh_and_reduce_loss, LossLoggingConfig
from src.reinforcement_learning.core.type_aliases import OptimizerProvider, TensorObs, detach_obs
from src.reinforcement_learning.gym.env_analysis import get_single_action_space
from src.tags import Tags
from src.torch_device import TorchDevice
from src.torch_functions import identity
from src.repr_utils import func_repr

from typing import Literal

# Default optimizer factory for SAC networks: AdamW with lr 3e-4 and weight decay 1e-4.
SAC_DEFAULT_OPTIMIZER_PROVIDER = lambda params: optim.AdamW(params, weight_decay=1e-4, lr=3e-4)
# `target_entropy` value requesting the heuristic target -prod(action shape).
AUTO_TARGET_ENTROPY = 'auto'


class SACDebug(SAC):
    """Debug variant of the project's SAC whose ``optimize`` is a port of Stable-Baselines3's
    ``SAC.train`` gradient-step loop.

    In this notebook it is paired with ``DebugSACPolicy`` (SB3 actor/critic) and SB3's
    replay buffer, so any divergence from SB3 can be isolated to the project's rollout /
    bookkeeping code.
    """

    @property
    def replay_buffer(self):
        # Alias so the SB3-style code in `optimize` can keep using `self.replay_buffer`.
        return self.buffer

    buffer: BaseReplayBuffer
    target_entropy: float
    log_ent_coef: Optional[torch.Tensor]
    entropy_coef_optimizer: Optional[optim.Optimizer]
    entropy_coef_tensor: Optional[torch.Tensor]

    def collect_hyper_parameters(self) -> HyperParameters:
        # Print the concrete component types so the log shows which implementation is in use.
        print(f'{type(self.policy) = }, {type(self.policy.actor) = }, {type(self.policy.critic) = }, {type(self.policy.target_critic) = }, {type(self.buffer) = }')
        return super().collect_hyper_parameters()

    def _setup_entropy_optimization(
            self,
            entropy_coef: float,
            target_entropy: float | Literal['auto'],
            entropy_coef_optimizer_provider: Optional[OptimizerProvider],
    ):
        """Initialise the entropy target and either a learnable or a fixed entropy coefficient.

        Args:
            entropy_coef: initial (learnable) or constant (fixed) entropy coefficient.
            target_entropy: target policy entropy, or ``'auto'`` for -prod(action shape).
            entropy_coef_optimizer_provider: if given, the coefficient is learned in log-space
                with the optimizer it builds; otherwise it stays fixed at ``entropy_coef``.
        """
        if target_entropy == AUTO_TARGET_ENTROPY:
            self.target_entropy = float(-np.prod(get_single_action_space(self.env).shape).astype(np.float32))
        else:
            self.target_entropy = float(target_entropy)

        if entropy_coef_optimizer_provider is not None:
            # Optimize log(alpha) so alpha = exp(log_ent_coef) stays positive.
            self.log_ent_coef = torch.log(
                torch.tensor([entropy_coef], device=self.torch_device, dtype=self.torch_dtype)
            ).requires_grad_(True)
            self.entropy_coef_optimizer = entropy_coef_optimizer_provider([self.log_ent_coef])
            self.entropy_coef_tensor = None
        else:
            self.log_ent_coef = None
            self.entropy_coef_optimizer = None
            self.entropy_coef_tensor = torch.tensor(entropy_coef, device=self.torch_device, dtype=self.torch_dtype)

    # Project implementations that the SB3-style `optimize` below replaces; kept for comparison.
    # def get_and_optimize_entropy_coef(
    #         self,
    #         actions_pi_log_prob: torch.Tensor,
    #         info: InfoDict
    # ) -> torch.Tensor:
    #     if self.entropy_coef_optimizer is not None:
    #         entropy_coef = torch.exp(self.log_ent_coef.detach())
    # 
    #         entropy_coef_loss = weigh_and_reduce_loss(
    #             raw_loss=-self.log_ent_coef * (actions_pi_log_prob + self.target_entropy).detach(),
    #             weigh_and_reduce_function=self.weigh_and_reduce_entropy_coef_loss,
    #             info=info,
    #             loss_name='entropy_coef_loss',
    #             logging_config=self.logging_config.entropy_coef_loss
    #         )
    #         self.entropy_coef_optimizer.zero_grad()
    #         entropy_coef_loss.backward()
    #         self.entropy_coef_optimizer.step()
    # 
    #         return entropy_coef
    #     else:
    #         return self.entropy_coef_tensor
    # 
    # def calculate_critic_loss(
    #         self,
    #         observation_features: TensorObs,
    #         replay_samples: ReplayBufferSamples,
    #         entropy_coef: torch.Tensor,
    #         info: InfoDict,
    # ):
    #     target_q_values = self.policy.compute_target_values(
    #         replay_samples=replay_samples,
    #         entropy_coef=entropy_coef,
    #         gamma=self.gamma,
    #     )
    #     # critic loss should not influence shared feature extractor
    #     current_q_values = self.critic(detach_obs(observation_features), replay_samples.actions)
    # 
    #     # noinspection PyTypeChecker
    #     critic_loss: torch.Tensor = 0.5 * sum(
    #         F.mse_loss(current_q, target_q_values) for current_q in current_q_values
    #     )
    #     critic_loss = weigh_and_reduce_loss(
    #         raw_loss=critic_loss,
    #         weigh_and_reduce_function=self.weigh_critic_loss,
    #         info=info,
    #         loss_name='critic_loss',
    #         logging_config=self.logging_config.critic_loss,
    #     )
    #     return critic_loss
    # 
    # def calculate_actor_loss(
    #         self,
    #         observation_features: TensorObs,
    #         actions_pi: torch.Tensor,
    #         actions_pi_log_prob: torch.Tensor,
    #         entropy_coef: torch.Tensor,
    #         info: InfoDict,
    # ) -> torch.Tensor:
    #     q_values_pi = torch.cat(self.critic(observation_features, actions_pi), dim=-1)
    #     min_q_values_pi, _ = torch.min(q_values_pi, dim=-1, keepdim=True)
    #     actor_loss = entropy_coef * actions_pi_log_prob - min_q_values_pi
    # 
    #     actor_loss = weigh_and_reduce_loss(
    #         raw_loss=actor_loss,
    #         weigh_and_reduce_function=self.weigh_and_reduce_actor_loss,
    #         info=info,
    #         loss_name='actor_loss',
    #         logging_config=self.logging_config.actor_loss,
    #     )
    # 
    #     return actor_loss

    def optimize(self, last_obs: np.ndarray, last_episode_starts: np.ndarray, info: InfoDict) -> None:
        """Run ``self.gradient_steps`` SAC updates sampled from ``self.replay_buffer``.

        Each step updates (in order) the entropy coefficient (if learnable), the critic,
        the actor, and then polyak-averages the target critic every
        ``target_update_interval`` steps. Per-step metrics are written to ``info``:
        ``'entropy_coef'``, ``'final_actor_loss'``, ``'final_critic_loss'`` and — only when
        the entropy coefficient is optimized — ``'final_entropy_coef_loss'``.
        ``last_obs`` / ``last_episode_starts`` are unused here.
        """
        ent_coef_losses, ent_coefs = [], []
        actor_losses, critic_losses = [], []

        for gradient_step in range(self.gradient_steps):
            # Sample replay buffer
            replay_data = self.replay_buffer.sample(self.optimization_batch_size, env=None)  # type: ignore[union-attr]

            # We need to sample because `log_std` may have changed between two gradient steps
            # if self.sde_noise_sample_freq:
            #     self.actor.reset_noise()

            # Action by the current actor for the sampled state
            actions_pi, log_prob = self.actor.action_log_prob(replay_data.observations)
            log_prob = log_prob.reshape(-1, 1)

            ent_coef_loss = None
            if self.entropy_coef_optimizer is not None and self.log_ent_coef is not None:
                # Important: detach the variable from the graph
                # so we don't change it with other losses
                # see https://github.com/rail-berkeley/softlearning/issues/60
                ent_coef = torch.exp(self.log_ent_coef.detach())
                ent_coef_loss = -(self.log_ent_coef * (log_prob + self.target_entropy).detach()).mean()
                ent_coef_losses.append(ent_coef_loss.item())
            else:
                ent_coef = self.entropy_coef_tensor

            ent_coefs.append(ent_coef.item())

            # Optimize entropy coefficient, also called
            # entropy temperature or alpha in the paper
            if ent_coef_loss is not None and self.entropy_coef_optimizer is not None:
                self.entropy_coef_optimizer.zero_grad()
                ent_coef_loss.backward()
                self.entropy_coef_optimizer.step()

            with torch.no_grad():
                # Select action according to policy
                next_actions, next_log_prob = self.actor.action_log_prob(replay_data.next_observations)
                # Compute the next Q values: min over all critics targets
                next_q_values = torch.cat(self.policy.target_critic(replay_data.next_observations, next_actions), dim=1)
                next_q_values, _ = torch.min(next_q_values, dim=1, keepdim=True)
                # add entropy term
                next_q_values = next_q_values - ent_coef * next_log_prob.reshape(-1, 1)
                # td error + entropy term
                target_q_values = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_q_values

            # Get current Q-values estimates for each critic network
            # using action from the replay buffer
            current_q_values = self.critic(replay_data.observations, replay_data.actions)

            # Compute critic loss
            critic_loss = 0.5 * sum(F.mse_loss(current_q, target_q_values) for current_q in current_q_values)
            assert isinstance(critic_loss, torch.Tensor)  # for type checker
            critic_losses.append(critic_loss.item())  # type: ignore[union-attr]

            # Optimize the critic
            self.critic.optimizer.zero_grad()
            critic_loss.backward()
            self.critic.optimizer.step()

            # Compute actor loss
            # Alternative: actor_loss = torch.mean(log_prob - qf1_pi)
            # Min over all critic networks
            q_values_pi = torch.cat(self.critic(replay_data.observations, actions_pi), dim=1)
            min_qf_pi, _ = torch.min(q_values_pi, dim=1, keepdim=True)
            actor_loss = (ent_coef * log_prob - min_qf_pi).mean()
            actor_losses.append(actor_loss.item())

            # Optimize the actor
            self.actor.optimizer.zero_grad()
            actor_loss.backward()
            self.actor.optimizer.step()

            # Update target networks
            if gradient_step % self.target_update_interval == 0:
                polyak_update(self.critic.parameters(), self.policy.target_critic.parameters(), self.tau)
                # NOTE: SB3 additionally copies batch-norm running stats at this point (SB3 GH issue #996).
                # That copy is not done here — TODO confirm the critic contains no batch norm.

        info['entropy_coef'] = np.array(ent_coefs)
        # Only report entropy-coefficient losses when the coefficient is actually optimized;
        # consumers check for the key's presence rather than for an empty array.
        if ent_coef_losses:
            info['final_entropy_coef_loss'] = np.array(ent_coef_losses)
        info['final_actor_loss'] = np.array(actor_losses)
        info['final_critic_loss'] = np.array(critic_losses)





In [4]:
import inspect
import time

from gymnasium import Env

from sac import init_action_selector, init_policy, init_optimizer, wrap_env, policy_construction_hyper_parameter
from src.datetime import get_current_timestamp
from src.experiment_logging.experiment_logger import ExperimentLogger, log_experiment
from src.model_db.dummy_model_db import DummyModelDB
from src.reinforcement_learning.algorithms.policy_mitosis.mitosis_policy_info import MitosisPolicyInfo
from src.module_analysis import count_parameters
from src.moving_averages import ExponentialMovingAverage
from src.reinforcement_learning.core.policies.base_policy import BasePolicy
from src.reinforcement_learning.core.policy_construction import PolicyConstruction
from src.stopwatch import Stopwatch
from src.summary_statistics import format_summary_statics
from typing import Any
from src.reinforcement_learning.core.callback import Callback

import torch
from torch import optim
import gymnasium as gym
import numpy as np

# IPython autoreload: re-import edited project modules (e.g. `sac`, `src.*`) before each
# cell executes, so source changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
from src.summary_statistics import maybe_compute_summary_statistics
from src.reinforcement_learning.core.loss_config import LossLoggingConfig
from src.reinforcement_learning.algorithms.sac.sac import SAC, SACLoggingConfig
def get_setup() -> dict[str, str]:
    """Collect source code describing this run's setup, for the experiment log.

    Returns:
        ``'sac.py'``: source of the ``sac`` module (policy / env construction).
        ``'notebook'``: the first notebook cell plus the four most recent inputs — when the
        notebook is run top to bottom these are the debug policy, debug algorithm, imports
        and this cell — joined with blank lines.
    """
    import inspect
    import sac
    # `_ih` is IPython's input history; `_ih[0]` is an empty placeholder, so cell 1 is `_ih[1]`.
    # Clamp the "last four inputs" window so a short history neither raises IndexError
    # nor repeats cell 1.
    recent_cells = _ih[max(2, len(_ih) - 4):]
    return {
        'sac.py': inspect.getsource(sac),
        'notebook': '\n\n'.join([_ih[1], *recent_cells]),
    }

# Globals populated by `get_policy`. These bare annotations document the names only;
# they do not create the variables (so `'policy' in globals()` is False until assigned).
policy_id: str
policy: BasePolicy
optimizer: optim.Optimizer
wrapped_env: Env
steps_trained: int
def get_policy(create_new_if_exists: bool):
    """Return ``(policy_id, policy, optimizer, wrapped_env, steps_trained)``.

    Reuses the policy already held in the notebook's globals unless none exists or
    ``create_new_if_exists`` is True, in which case a fresh one is built from the ``sac``
    module's factories for the global ``env``. If ``parent_policy_id`` is set, its state
    dict is loaded from ``policy_db`` into the policy.
    """
    global policy_id, policy, optimizer, wrapped_env, steps_trained

    has_policy_in_ram = 'policy' in globals()
    must_create = create_new_if_exists or not has_policy_in_ram
    if must_create:
        if not has_policy_in_ram:
            print('No policy in RAM, creating a new one')

        policy_id = get_current_timestamp()
        initialization_info = PolicyConstruction.create_policy_initialization_info(
            init_action_selector=init_action_selector,
            init_policy=init_policy,
            init_optimizer=init_optimizer,
            wrap_env=wrap_env,
            hyper_parameters=policy_construction_hyper_parameter,
        )
        policy, optimizer, wrapped_env = PolicyConstruction.init_from_info(env=env, info=initialization_info)
        steps_trained = 0
        print(f'New policy {policy_id} created')

    if parent_policy_id is not None:
        loaded_entry = policy_db.load_model_state_dict(policy, parent_policy_id)
        steps_trained = loaded_entry['model_info']['steps_trained']
        print(f'Loading state dict from policy {parent_policy_id}')

    print(f'Using policy {policy_id} with parent policy {parent_policy_id}')
    return policy_id, policy, optimizer, wrapped_env, steps_trained

# Module-level state shared by the training callbacks below (creation order preserved).
score_mean_ema = ExponentialMovingAverage(alpha=0.25)  # smoothing for per-iteration scores
step_stopwatch = Stopwatch()   # reset at each logged step -> time per logging interval
total_stopwatch = Stopwatch()  # reset right before training starts
best_iteration_score = -1_000_000.0  # initial value for the best-score tracker (checkpointing is disabled)

def on_rollout_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Rollout callback — currently a no-op.

    It only checks whether ``step`` falls on a 1000-step boundary. The episode-score
    tracking and best-score checkpointing that would run there are disabled; the earlier
    implementation attempts are kept commented out below for reference.
    """
    if step % 1000:
        return None

    # Disabled: compute recent episode scores, update `score_mean_ema`, and save the policy
    # to `policy_db` whenever a new best iteration score is reached.
    #
    # tail_indices = rl.buffer.tail_indices(1000)
    # rewards = rl.buffer.rewards[tail_indices]
    # if 'raw_rewards' in info['rollout']:
    #     rewards = info['rollout']['raw_rewards']
    # episode_scores = compute_episode_returns(
    #     rewards=rewards,
    #     episode_starts=np.repeat(np.arange(len(tail_indices)).reshape(-1, 1), num_envs, axis=1) % 1000 == 0,
    #     last_episode_starts=info['last_episode_starts'],
    #     gamma=1.0,
    #     gae_lambda=1.0,
    #     normalize_rewards=None,
    #     remove_unfinished_episodes=True,
    # )
    #
    # episode_scores = rl.buffer.compute_most_recent_episode_scores(rl.num_envs)
    # if len(episode_scores) > 0:
    #     global best_iteration_score
    #     iteration_score = episode_scores.mean()
    #     score_moving_average = score_mean_ema.update(iteration_score)
    #     if iteration_score >= best_iteration_score:
    #         best_iteration_score = iteration_score
    #         policy_db.save_model_state_dict(
    #             model_id=policy_id,
    #             parent_model_id=parent_policy_id,
    #             model_info={
    #                 'score': iteration_score.item(),
    #                 'steps_trained': steps_trained,
    #                 'wrap_env_source_code': wrap_env_source_code_source,
    #                 'init_policy_source_code': init_policy_source
    #             },
    #             model=policy,
    #             optimizer=optimizer,
    #         )
    #     info['score_moving_average'] = score_moving_average
    # info['episode_scores'] = episode_scores
    return None
        
def _format_mean_only(values, mean_format: str) -> str:
    """Format only the mean of `values` with `format_summary_statics` (std/min/max suppressed)."""
    return format_summary_statics(
        values,
        mean_format=mean_format,
        std_format=None,
        min_value_format=None,
        max_value_format=None,
    )


def _format_mean_std(values) -> str:
    """Format mean and std of `values` (3 decimals) with `format_summary_statics`."""
    return format_summary_statics(
        values,
        mean_format='5.3f',
        std_format='5.3f',
        min_value_format=None,
        max_value_format=None,
    )


def on_optimization_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Every 1000 steps, print a one-line training summary and append it to `logger`.

    The "score" is the reward sum over the last 1000 transitions in ``rl.buffer`` per env,
    which approximates an episode return only if episodes are 1000 steps long and aligned
    with the logging interval — TODO confirm for the env in use. The experiment log is
    saved to disk every 10000 steps.
    """
    if step % 1000 != 0:
        return
    num_env_steps = step * rl.num_envs

    step_time = step_stopwatch.reset()
    total_time = total_stopwatch.time_passed()

    score_moving_average = info.get('score_moving_average') or 0.0

    # Last 1000 transitions. Negative indices count from the end of the storage arrays,
    # which presumably matches the circular layout of SB3's ReplayBuffer once it has wrapped
    # (before that, they read not-yet-filled slots).
    tail_indices = np.arange(rl.buffer.pos - 1000, rl.buffer.pos)
    episode_scores = rl.buffer.rewards[tail_indices].sum(axis=0)
    action_magnitudes = np.abs(rl.buffer.actions[tail_indices])

    # An empty loss array means the entropy coefficient was not optimized; treat it as absent
    # instead of formatting statistics of nothing.
    entropy_coef_losses = info.get('final_entropy_coef_loss')
    if entropy_coef_losses is not None and np.size(entropy_coef_losses) == 0:
        entropy_coef_losses = None

    scores = format_summary_statics(
        episode_scores,
        mean_format=' 6.3f',
        std_format='4.3f',
        min_value_format=' 6.3f',
        max_value_format='5.3f',
        n_format='>2'
    )
    actor_loss = _format_mean_only(info['final_actor_loss'], ' 5.3f')
    entropy_coef_loss = None if entropy_coef_losses is None else _format_mean_only(entropy_coef_losses, '5.3f')
    critic_loss = _format_mean_only(info['final_critic_loss'], '5.3f')
    entropy_coef = _format_mean_only(info['entropy_coef'], '5.3f')

    action_stds = info['rollout'].get('action_stds')
    rollout_action_stds = 'N/A' if action_stds is None else _format_mean_std(action_stds)
    action_magnitude = _format_mean_std(action_magnitudes)

    print(f"{step = : >7}, "
          f"{num_env_steps = : >7}, "
          f"{scores = :s}, "
          f'score_ema = {score_moving_average: 6.3f}, '
          f"{actor_loss = :s}, "
          f"{critic_loss = :s}, "
          + (f"{entropy_coef_loss = :s}, " if entropy_coef_loss is not None else '') +
          f"{entropy_coef = :s}, "
          f"rollout_stds = {rollout_action_stds:s}, "
          f"{action_magnitude = :s}, "
          f"n_updates = {rl.gradient_steps_performed}, "
          f"time = {step_time:4.1f}, "
          f"total_time = {total_time:4.1f} \n"
          )
    logger.add_item({
        'step': step,
        'num_env_steps': num_env_steps,
        'scores': maybe_compute_summary_statistics(episode_scores),
        'actor_loss': maybe_compute_summary_statistics(info['final_actor_loss']),
        'entropy_coef_loss': maybe_compute_summary_statistics(entropy_coef_losses),
        'critic_loss': maybe_compute_summary_statistics(info['final_critic_loss']),
        'entropy_coef': maybe_compute_summary_statistics(info['entropy_coef']),
        'action_stds': maybe_compute_summary_statistics(action_stds),
        'action_magnitude': maybe_compute_summary_statistics(action_magnitudes),
        'num_gradient_steps': rl.gradient_steps_performed,
        'step_time': step_time,
        'total_time': total_time
    })
    if step % 10000 == 0:
        logger.save_experiment_log()
        print()
    print()

# Prefer the first GPU, but fall back to CPU instead of failing on machines without CUDA.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f'using device {device}')

def create_env(render_mode: str | None):
    """Create one `env_name` gymnasium environment configured with `env_kwargs`.

    NOTE(review): this re-defines the identical helper from the first cell (via the `gym` alias).
    """
    return gym.make(
        env_name,
        **env_kwargs,
        render_mode=render_mode,
    )

# Source of the env wrapper and the policy factory, referenced by the (currently disabled)
# checkpoint metadata in `on_rollout_done`.
wrap_env_source_code_source, init_policy_source = [
    inspect.getsource(factory) for factory in (wrap_env, init_policy)
]

# Environment configuration, re-declared so this cell is self-contained.
env_name = 'HalfCheetah-v4'
# Reward-shaping variants that have been experimented with (kept for reference):
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.1 }
# env_kwargs = {'forward_reward_weight': 1.25, 'ctrl_cost_weight': 0.05 }
env_kwargs = dict()
num_envs = 1

# Dummy model DB for this debug run; swap in TinyModelDB to persist checkpoints on disk.
# policy_db = TinyModelDB[MitosisPolicyInfo](base_path=f'saved_models/rl/{env_name}')
policy_db = DummyModelDB[MitosisPolicyInfo]()
print(f'{policy_db = }')

parent_policy_id = None  # '2024-04-28_20.57.23'

# TODO
# env = parallelize_env_async(lambda: create_env(render_mode=None), num_envs)
# NOTE(review): this rebinds `env`; the env created in the first cell (held by `sb_sac`)
# is never closed by the `finally` below.
env = create_env(render_mode=None)

logger = ExperimentLogger(f'experiment_logs/{env_name}/sac/')

# Request deterministic cuDNN kernels (no autotuning) for reproducibility.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Train SACDebug (SB3 actor/critic/replay buffer inside the project's SAC loop) and log the run.
try:
    policy_id, policy, optimizer, wrapped_env, steps_trained = get_policy(create_new_if_exists=False)
    print(f'{count_parameters(policy) = }')
    print(f'{env = }, {num_envs = }')

    with torch.autograd.set_detect_anomaly(False):
        algo = SACDebug(
            env=wrapped_env,
            policy=DebugSACPolicy(policy.actor, policy.critic),
            actor_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            critic_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            # weigh_and_reduce_actor_loss=lambda l: 1 * l.mean(),
            # weigh_critic_loss=lambda l: 1 * l,
            buffer_size=1_000_000,
            reward_scale=1,
            gamma=0.99,
            tau=0.005,
            entropy_coef_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),
            entropy_coef=1.0,
            rollout_steps=1,
            gradient_steps=1,
            warmup_steps=10_000,
            optimization_batch_size=256,
            target_update_interval=1,
            # sde_noise_sample_freq=50,
            callback=Callback(
                on_rollout_done=on_rollout_done,
                rollout_schedulers={},
                on_optimization_done=on_optimization_done,
                optimization_schedulers={},
            ),
            logging_config=SACLoggingConfig(log_rollout_infos=True, log_rollout_action_stds=True,
                                            log_last_obs=True, log_entropy_coef=True,
                                            entropy_coef_loss=LossLoggingConfig(log_final=True),
                                            actor_loss=LossLoggingConfig(log_final=True, log_raw=True),
                                            critic_loss=LossLoggingConfig(log_final=True)),
            torch_device=device,
        )

        # Todo! Debug: train from SB3's replay buffer instead of the project's buffer.
        # Sampled numpy arrays are converted to float32 tensors on `device` (not a hard-coded
        # 'cuda'), so the buffer stays on the same device as the algorithm.
        algo.buffer = sb_sac.replay_buffer
        algo.buffer.to_torch = lambda arr: torch.tensor(arr, device=device, dtype=torch.float32)

        total_stopwatch.reset()
        with log_experiment(
            logger,
            experiment_tags=algo.collect_tags() + ['Debug'],
            hyper_parameters=algo.collect_hyper_parameters(),
            setup=get_setup(),
        ):
            logger.save_experiment_log()
            print('\nStarting Training\n\n')
            # import cProfile
            # pr = cProfile.Profile()
            # pr.enable()
            algo.learn(5_000_000)
            # pr.disable()
            # pr.dump_stats('profile_stats.pstat')
except KeyboardInterrupt:
    print('keyboard interrupt')
finally:
    print('closing envs')
    time.sleep(0.5)
    env.close()
    print('envs closed')
    policy_db.close()
    print('model db closed')


print('done')

In [6]:
# Display the id of the experiment logged by the training cell.
logger.experiment_log['experiment_id']

In [6]:
# Inspect the observation storage of the (SB3) replay buffer used for training;
# presumably (buffer_size, n_envs, *obs_shape) for SB3's ReplayBuffer — confirm.
algo.buffer.observations.shape