In [44]:

import time
from typing import Any

import gymnasium as gym
import numpy as np
import torch
import tqdm
from torch import optim, nn

from src.experiment_logging.experiment_log import ExperimentLogItem
from src.experiment_logging.experiment_logger import ExperimentLogger, log_experiment
from src.module_analysis import count_parameters
from src.reinforcement_learning.algorithms.sac.sac import SAC, SACInfoStashConfig
from src.reinforcement_learning.algorithms.sac.sac_policy import SACPolicy
from src.reinforcement_learning.core.action_selectors.predicted_std_action_selector import PredictedStdActionSelector
from src.reinforcement_learning.core.callback import Callback
from src.reinforcement_learning.core.loss_config import LossInfoStashConfig
from src.reinforcement_learning.core.policies.components.actor import Actor
from src.reinforcement_learning.core.policies.components.q_critic import QCritic
from src.reinforcement_learning.gym.parallelize_env import parallelize_env_async
from src.stopwatch import Stopwatch
from src.summary_statistics import maybe_compute_summary_statistics

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
def get_setup() -> dict[str, str]:
    """Snapshot notebook source code for the experiment log.

    Joins the first input of the IPython session (the imports cell) and the most
    recent input (the cell currently executing) with a blank line in between.
    Relies on IPython's `_ih` input-history list.

    Returns:
        A dict with a single 'notebook' entry holding the joined source.
    """
    imports_cell, current_cell = _ih[1], _ih[-1]
    return dict(notebook='\n\n'.join((imports_cell, current_cell)))

# Timers read by the logging callback: step_stopwatch is reset at every log item
# (its reset() value is logged as step_time); total_stopwatch reports time_passed().
step_stopwatch = Stopwatch()
total_stopwatch = Stopwatch()

# Sentinel lower bound for the best mean episode score seen so far.
best_iteration_score = -1_000_000.0

def save_experiment_state():
    """Save the algorithm state under ``models/<env_name>/<experiment_id>``.

    Reads the module-level ``logger``, ``algo`` and ``env_name``. The current
    experiment id is also used as the checkpoint name, and the logger's latest
    log item is passed along to ``algo.save``.
    """
    experiment_id = logger.experiment_log['experiment_id']
    target_folder = f'models/{env_name}/{experiment_id}'
    algo.save(folder_location=target_folder, name=experiment_id, latest_log_item=logger.get_latest_log_item())
    

def on_rollout_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Every 1000 steps, compute the most recent episode scores and stash them in ``info``.

    The scores are stored under ``info['episode_scores']``; ``on_optimization_done``
    reads that key when building its log item (assumes both callbacks see the same
    ``info`` dict). The module-level ``best_iteration_score`` is raised to the mean
    score whenever that mean matches or beats it.

    Args:
        rl: The running SAC algorithm; its replay buffer supplies the episode scores.
        step: Current training step.
        info: Info dict that receives the 'episode_scores' entry.
        scheduler_values: Unused.
    """
    global best_iteration_score

    if step % 1000 != 0:
        return

    episode_scores = rl.buffer.compute_most_recent_episode_scores(rl.num_envs, consider_truncated_as_done=True)

    # No finished episodes yet -> nothing to compare against the best score.
    if len(episode_scores) > 0:
        iteration_score = episode_scores.mean()
        if iteration_score >= best_iteration_score:
            # Track the best mean score seen so far.
            # TODO: optionally call save_experiment_state() here to checkpoint the best policy.
            best_iteration_score = iteration_score

    info['episode_scores'] = episode_scores
        
def on_optimization_done(rl: SAC, step: int, info: dict[str, Any], scheduler_values: dict[str, Any]):
    """Log training statistics every 1000 steps and checkpoint periodically.

    On every 1000th step a log item is printed and appended to the module-level
    ``logger``. The item holds episode scores, losses, entropy coefficient,
    action statistics, gradient-step count and timings. On every 10_000th step
    the experiment log is saved. On every 100_000th step the policy is also
    saved to ``saved_models/<env_name>/sac/<experiment_id>.pth``.

    ``info['episode_scores']`` is read with ``.get`` (it is expected to be written
    by ``on_rollout_done``), so a missing entry does not raise. ``scheduler_values``
    is unused.
    """
    if step % 1000:
        return

    # step_stopwatch is reset at every log item; assuming reset() returns the
    # elapsed time, step_time covers the interval since the previous log item.
    step_time = step_stopwatch.reset()
    total_time = total_stopwatch.time_passed()
    # Actions at the buffer's tail indices (presumably the 1000 most recent entries).
    recent_actions = rl.buffer.actions[rl.buffer.tail_indices(1000)]

    summarize = maybe_compute_summary_statistics
    log_item: ExperimentLogItem = {
        'step': step,
        'num_env_steps': step * rl.num_envs,
        'scores': summarize(info.get('episode_scores')),
        'actor_loss': summarize(info['final_actor_loss']),
        'critic_loss': summarize(info['final_critic_loss']),
        'entropy_coef_loss': summarize(info.get('final_entropy_coef_loss')),
        'entropy_coef': summarize(info['entropy_coef']),
        'action_stds': summarize(info['rollout'].get('action_stds')),
        'action_magnitude': summarize(np.abs(recent_actions)),
        'num_gradient_steps': rl.gradient_steps_performed,
        'step_time': step_time,
        'total_time': total_time,
    }

    formatted = logger.format_log_item(
        log_item,
        mean_format='5.3f',
        std_format='5.3f',
        step_time='.2f',
        total_time='.2f',
        scores={'mean_format': '5.3f', 'n_format': 'd'},
    )
    print(formatted, end='\n\n')
    logger.add_item(log_item)

    if step % 10_000 == 0:
        experiment_id = logger.save_experiment_log()['experiment_id']
        if step % 100_000 == 0:
            rl.policy.save(f'saved_models/{env_name}/sac/{experiment_id}.pth')
        print()
    print()

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
print(f'using device {device}')

env_name = 'HalfCheetah-v4'
# Extra keyword arguments forwarded to gym.make, e.g.
# {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001}
env_kwargs = {}
num_envs = 1

def create_env(render_mode: str | None):
    make_single_env = lambda: gym.make(env_name, render_mode=render_mode, **env_kwargs)
    
    if num_envs == 1:
        return make_single_env()
        
    return parallelize_env_async(make_single_env, num_envs)


def create_policy(in_size: int = 17, action_size: int = 6, hidden_size: int = 256):
    """Build the SAC actor-critic policy.

    The defaults (17-dim observations, 6-dim actions, 256-wide hidden layers)
    are the sizes this notebook uses for HalfCheetah-v4.

    Args:
        in_size: Observation dimension fed to the actor and the critics.
        action_size: Action dimension produced by the actor and fed to the critics.
        hidden_size: Width of the hidden layers of every network. It is also the
            ``latent_dim`` of the action selector, which must equal the actor
            trunk's output width; deriving both from one value keeps them in sync.

    Returns:
        A ``SACPolicy``. Its actor is a two-layer MLP trunk followed by a
        ``PredictedStdActionSelector`` (``squash_output=True``). Its critic is a
        ``QCritic`` with two Q-networks built by the same factory.
    """
    actor_net = nn.Sequential(
        nn.Linear(in_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
    )

    def create_q_network() -> nn.Module:
        # Q-network input width is in_size + action_size (observation and action,
        # presumably concatenated by QCritic).
        return nn.Sequential(
            nn.Linear(in_size + action_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    critic = QCritic(n_critics=2, create_q_network=create_q_network)

    action_selector = PredictedStdActionSelector(
        latent_dim=hidden_size,  # must equal the actor trunk's output width
        action_dim=action_size,
        base_std=1.0,
        squash_output=True,
    )
    return SACPolicy(actor=Actor(actor_net, action_selector), critic=critic)


env = create_env(render_mode=None)
policy = create_policy()
logger = ExperimentLogger(f'experiment_logs/{env_name}/sac/')

try:
    print(f'{count_parameters(policy) = }')
    print(f'{env = }, {num_envs = }')

    with torch.autograd.set_detect_anomaly(False):
        algo = SAC(
            env=env,
            policy=policy,
            actor_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            critic_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),  # (params, lr=3e-4, betas=(0.5, 0.999)),
            buffer_size=1_000_000,
            reward_scale=1,
            gamma=0.99,
            tau=0.005,
            entropy_coef_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),
            entropy_coef=1.0,
            rollout_steps=1,
            gradient_steps=1,
            warmup_steps=10_000,
            optimization_batch_size=256,
            target_update_interval=1,
            callback=Callback(
                on_rollout_done=on_rollout_done,
                rollout_schedulers={},
                on_optimization_done=on_optimization_done,
                optimization_schedulers={},
            ),
            stash_config=SACInfoStashConfig(stash_rollout_infos=True, stash_rollout_action_stds=True,
                                            stash_entropy_coef=True,
                                            entropy_coef_loss=LossInfoStashConfig(stash_final=True),
                                            actor_loss=LossInfoStashConfig(stash_final=True),
                                            critic_loss=LossInfoStashConfig(stash_final=True)),
            torch_device=device,
        )
        # Start both timers together: otherwise the first logged step_time would
        # also include all setup above (policy and SAC construction).
        total_stopwatch.reset()
        step_stopwatch.reset()
        with log_experiment(
            logger,
            experiment_tags=algo.collect_tags(),
            hyper_parameters=algo.collect_hyper_parameters(),
            setup=get_setup(),
        ):
            logger.save_experiment_log()
            print('\nStarting Training\n\n')
            # import cProfile
            # pr = cProfile.Profile()
            # pr.enable()
            algo.learn(5_000_000)
            # pr.disable()
            # pr.dump_stats('profile_stats.pstat')
except KeyboardInterrupt:
    print('keyboard interrupt')
    raise
finally:
    # Always release the environment(s), even when training is interrupted.
    print('closing envs')
    time.sleep(0.5)
    env.close()
    print('envs closed')

print('done')

using device cuda:0
count_parameters(policy) = 217870
env = <TimeLimit<OrderEnforcing<PassiveEnvChecker<HalfCheetahEnv<HalfCheetah-v4>>>>>, num_envs = 1
Grabbing system information... done!
saved experiment log 2024-10-21_09-39-59_555014~4J6lFW at experiment_logs/HalfCheetah-v4/sac/2024-10-21_09-39-59_555014~4J6lFW.json

Starting Training


self.log_ent_coef = tensor([-0.3001], device='cuda:0', requires_grad=True)
actions_pi_log_prob = tensor([[-4.3998],
        [-3.9373],
        [-3.4147],
        [-4.3203],
        [-4.1546],
        [-3.9914],
        [-4.9038],
        [-3.5747],
        [-4.5029],
        [-4.3757],
        [-4.1280],
        [-4.8241],
        [-4.3853],
        [-4.6347],
        [-3.9751],
        [-3.9438],
        [-4.3633],
        [-4.4005],
        [-5.4730],
        [-3.9581],
        [-4.6724],
        [-4.7304],
        [-4.3266],
        [-3.8694],
        [-4.4207],
        [-3.8882],
        [-4.3190],
        [-3.4612],
        [-4.2863],
        [

KeyboardInterrupt: 

In [1]:
from src.datetime import get_current_timestamp

current_timestamp = get_current_timestamp()
current_timestamp

'2024-11-02 23:58:50.056338'

In [97]:
from src.reinforcement_learning.core.buffers.replay.ring_with_reservoir_replay_buffer import \
    RingWithReservoirReplayBuffer
from src.reinforcement_learning.core.buffers.replay.reservoir_replay_buffer import ReservoirReplayBuffer
from src.reinforcement_learning.core.buffers.replay.ring_replay_buffer import RingReplayBuffer
from src.reinforcement_learning.gym.singleton_vector_env import SingletonVectorEnv
from src.reinforcement_learning.gym.envs.test_env import TestEnv

env_name = 'HalfCheetah-v4'
# env_kwargs = {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001 }
env_kwargs = {}
num_envs = 1

# Toggle to place the buffers' tensors on the GPU.
use_cuda = False
device = torch.device('cuda:0' if use_cuda else 'cpu')

# TestEnv(3) wrapped in SingletonVectorEnv (presumably a one-env vector wrapper).
env = SingletonVectorEnv(TestEnv(3))
obs, info = env.reset()

# Both buffers get identical settings so their contents can be compared directly.
ring_buffer_kwargs = dict(buffer_size=20_000, reward_scale=10, torch_device=device)
rep: RingReplayBuffer = RingReplayBuffer.for_env(env, **ring_buffer_kwargs)
res = RingWithReservoirReplayBuffer.for_env(env, **ring_buffer_kwargs)



# Feed identical transitions into both buffers and check they stay in sync.
# Samples are independent random draws, so only their shapes are comparable;
# the backing storage is compared element-wise.
# NOTE(review): the recorded run failed at step 20000 (= buffer_size), i.e. when
# the buffers first wrap around.
SAMPLE_FIELDS = ('observations', 'next_observations', 'actions', 'rewards', 'dones')
STORAGE_FIELDS = ('observations', 'next_observations', 'actions', 'rewards', 'terminated', 'truncated')

for i in tqdm.tqdm(range(1, 1_000_000)):
    # Encode the step index in the action so stored transitions are traceable.
    action = np.array([[i, i * 10]])
    next_obs, rew, term, trunc, info = env.step(action)

    transition = dict(
        observations=obs,
        next_observations=next_obs,
        actions=action,
        rewards=rew,
        terminated=term,
        truncated=trunc,
    )
    rep.add(**transition)
    res.add(**transition)
    obs = next_obs

    rep_sample = rep.sample(256)
    res_sample = res.sample(256)

    for field in SAMPLE_FIELDS:
        # Read shapes directly from the tensors; no device-to-host copy needed.
        rep_shape = tuple(getattr(rep_sample, field).shape)
        res_shape = tuple(getattr(res_sample, field).shape)
        assert rep_shape == res_shape, f'step {i}: sampled {field} shape {rep_shape} != {res_shape}'

    for field in STORAGE_FIELDS:
        rep_data = getattr(rep, field).squeeze()
        res_data = getattr(res, field).squeeze()
        assert np.all(rep_data == res_data), (
            f'step {i}: stored {field} differ, first mismatching flat indices '
            f'{np.flatnonzero(rep_data != res_data)[:5]}'
        )
        
    
    

  2%|▏         | 20000/999999 [00:06<05:02, 3239.38it/s]


AssertionError: 

In [95]:
# Compare done-flag frequencies in both buffers' storage and in fresh samples.
rep_sample, res_sample = rep.sample(256), res.sample(256)
(
    res.truncated.mean(), rep.truncated.mean(),
    res.terminated.mean(), rep.terminated.mean(),
    res_sample.dones.mean(), rep_sample.dones.mean(),
)

(0.1, 0.1, 0.1, 0.1, tensor(0.0820), tensor(0.0938))

In [ ]:
import matplotlib.pyplot as plt
from src.reinforcement_learning.core.buffers.replay.replay_with_reservoir_buffer import ReplayWithReservoirBuffer

# Environment configuration read by create_env.
# Example non-default kwargs: {'forward_reward_weight': 1.25, 'healthy_reward': 0.5, 'ctrl_cost_weight': 0.001}
env_name, env_kwargs, num_envs = 'HalfCheetah-v4', {}, 1

def create_env(render_mode: str | None):
    make_single_env = lambda: gym.make(env_name, render_mode=render_mode, **env_kwargs)
    
    if num_envs == 1:
        return make_single_env()
        
    return parallelize_env_async(make_single_env, num_envs)

def create_policy(in_size: int = 17, action_size: int = 6, hidden_size: int = 256):
    """Build the SAC actor-critic policy.

    The defaults (17-dim observations, 6-dim actions, 256-wide hidden layers)
    are the sizes this notebook uses for HalfCheetah-v4.

    Args:
        in_size: Observation dimension fed to the actor and the critics.
        action_size: Action dimension produced by the actor and fed to the critics.
        hidden_size: Width of the hidden layers of every network. It is also the
            ``latent_dim`` of the action selector, which must equal the actor
            trunk's output width; deriving both from one value keeps them in sync.

    Returns:
        A ``SACPolicy``. Its actor is a two-layer MLP trunk followed by a
        ``PredictedStdActionSelector`` (``squash_output=True``). Its critic is a
        ``QCritic`` with two Q-networks built by the same factory.
    """
    actor_net = nn.Sequential(
        nn.Linear(in_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
    )

    def create_q_network() -> nn.Module:
        # Q-network input width is in_size + action_size (observation and action,
        # presumably concatenated by QCritic).
        return nn.Sequential(
            nn.Linear(in_size + action_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 1),
        )

    critic = QCritic(n_critics=2, create_q_network=create_q_network)

    action_selector = PredictedStdActionSelector(
        latent_dim=hidden_size,  # must equal the actor trunk's output width
        action_dim=action_size,
        base_std=1.0,
        squash_output=True,
    )
    return SACPolicy(actor=Actor(actor_net, action_selector), critic=critic)


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
env = create_env(render_mode=None)
policy = create_policy()

# A SAC instance is built only to obtain its replay buffer (a ReplayWithReservoirBuffer
# configured via buffer_type / buffer_step_size / buffer_kwargs); nothing is trained here.
# noinspection PyTypeChecker
buf: ReplayWithReservoirBuffer = SAC(
    env=env,
    policy=policy,
    actor_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),
    critic_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),
    buffer_type=ReplayWithReservoirBuffer,
    buffer_step_size=200_000,
    buffer_kwargs={
        'reservoir_total_size': 300_000,
    },
    reward_scale=10,
    gamma=0.99,
    tau=0.005,
    entropy_coef_optimizer_provider=lambda params: optim.Adam(params, lr=3e-4),
    entropy_coef_clamp_range=(0.001, 1.5),
    rollout_steps=1,
    gradient_steps=1,
    warmup_steps=10_000,
    optimization_batch_size=256,
    target_update_interval=1,
    torch_device=device,
).buffer

# Fill the buffer with transitions whose values equal the step index. From step
# 200_000 on, every 1000 steps, compare the distribution of sampled observations
# with the contents of the main buffer and of its reservoir.
for i in tqdm.tqdm(range(0, 1_000_000)):
    buf.add(
        observations=np.array([i]),
        next_observations=np.array([i + 1]),
        actions=np.array([i]),
        rewards=np.array([i]),
        terminated=np.array([False]),
        truncated=np.array([False]),
    )

    if i >= 200_000 and i % 1000 == 0:
        fig, ax = plt.subplots(1, figsize=(14, 14))

        sample_obs = np.concatenate([buf.sample(256).observations[:, 0].cpu().numpy() for _ in range(1000)])

        buf_obs = buf.observations[:, 0, 0]
        res_obs = buf.reservoir.observations[:, 0]
        # Drop zero entries, presumably unfilled slots (this also drops the genuine step-0 value).
        buf_obs = buf_obs[buf_obs != 0]
        res_obs = res_obs[res_obs != 0]

        ax.hist(sample_obs, bins=100, label='samples', alpha=0.5)
        ax.hist(buf_obs, bins=100, label='rep', alpha=0.5)
        ax.hist(res_obs, bins=100, label='res', alpha=0.5)
        ax.set(
            title=f'Observation distribution after {i} steps',
            xlabel='observation value (= step index)',
            ylabel='count',
        )
        ax.legend()

        print(i)
        # Render now and release the figure. Otherwise every figure created in
        # this loop (~800) stays open until the cell finishes, exhausting memory.
        plt.show()
        plt.close(fig)
        
    
    

In [14]:
# First 20 rows of the buffer's observation storage.
buf.observations[0:20]

array([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1.]],

       [[2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,
         2.]],

       [[3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3.,
         3.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0.]],

       [[0., 0., 0., 

In [9]:
# Draw a tiny batch to inspect the structure of a sample.
tiny_batch = buf.sample(2)
tiny_batch


ReplayBufferSamples(observations=tensor([[25725., 25725., 25725., 25725., 25725., 25725., 25725., 25725., 25725.,
         25725., 25725., 25725., 25725., 25725., 25725., 25725., 25725.],
        [98733., 98733., 98733., 98733., 98733., 98733., 98733., 98733., 98733.,
         98733., 98733., 98733., 98733., 98733., 98733., 98733., 98733.]],
       device='cuda:0'), actions=tensor([[25725., 25725., 25725., 25725., 25725., 25725.],
        [98733., 98733., 98733., 98733., 98733., 98733.]], device='cuda:0'), next_observations=tensor([[25726., 25726., 25726., 25726., 25726., 25726., 25726., 25726., 25726.,
         25726., 25726., 25726., 25726., 25726., 25726., 25726., 25726.],
        [98734., 98734., 98734., 98734., 98734., 98734., 98734., 98734., 98734.,
         98734., 98734., 98734., 98734., 98734., 98734., 98734., 98734.]],
       device='cuda:0'), dones=tensor([[0.],
        [0.]], device='cuda:0'), rewards=tensor([[257250.],
        [987330.]], device='cuda:0'))