In [None]:
import torch
from omegaconf import OmegaConf
from functools import partial
import gymnasium as gym
import matplotlib.pyplot as plt

import bbrl_utils
from bbrl_utils.notebook import setup_tensorboard
from bbrl.stats import WelchTTest
from bbrl.agents import Agent, Agents, TemporalAgent
from bbrl.agents.gymnasium import ParallelGymAgent, make_env
from bbrl.workspace import Workspace
from bbrl.utils.replay_buffer import ReplayBuffer
from pmind_utils import (
    DQN,
    DDPG,
    TD3,
    dqn_compute_critic_loss,
    ddqn_compute_critic_loss,
    run_dqn,
    run_ddpg,
    run_td3,
    run_td3_offline,
    get_gym_agent,
    get_workspace,
    mix_transitions
)

# Run the one-off initialisation hook exposed by bbrl_utils.
# NOTE(review): what it configures is defined inside bbrl_utils, not here.
bbrl_utils.setup()

# Enable IPython's autoreload extension in mode 2, so edits to imported
# modules (e.g. pmind_utils) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

Load all configurations:

In [None]:
# Single YAML file holding the configuration of every model used below
# (read as cfg.models.<name> in the following cells).
CONFIG_PATH = "test_config.yaml"
cfg = OmegaConf.load(CONFIG_PATH)

# Test the algorithms used

In [None]:
# Hand the log directory to the bbrl_utils notebook TensorBoard helper.
TB_LOG_DIR = "./outputs/tblogs"
setup_tensorboard(TB_LOG_DIR)

### DQN:

In [None]:
# Build a DQN from its section of the config, train it with the plain DQN
# critic loss, then visualise the best policy it found.
dqn_cfg = OmegaConf.create(cfg.models.dqn)
dqn = DQN(dqn_cfg)
run_dqn(dqn, dqn_compute_critic_loss)
dqn.visualize_best()

### DDQN:

In [None]:
# Same DQN class and training loop as above; only the config section and the
# critic loss (double-DQN variant) differ.
ddqn_cfg = OmegaConf.create(cfg.models.ddqn)
ddqn = DQN(ddqn_cfg)
run_dqn(ddqn, ddqn_compute_critic_loss)
ddqn.visualize_best()

In [None]:
# Compare the evaluation rewards of DQN and DDQN with a Welch t-test plot.
# Each eval_rewards list is stacked into a single tensor first.
# `legends` uses the same "first/second" format as the DDPG/TD3 comparison
# further down, so both plots label their curves consistently.
WelchTTest().plot(
    torch.stack(dqn.eval_rewards),
    torch.stack(ddqn.eval_rewards),
    legends="dqn/ddqn",
    save=False,
)

### DDPG:

In [None]:
# Build a DDPG learner from its config section, train it, and visualise the
# best policy found.
ddpg_cfg = OmegaConf.create(cfg.models.ddpg)
ddpg = DDPG(ddpg_cfg)
run_ddpg(ddpg)
ddpg.visualize_best()

### TD3:

In [None]:
# Build a TD3 learner from its config section, train it, and visualise the
# best policy found.
td3_cfg = OmegaConf.create(cfg.models.td3)
td3 = TD3(td3_cfg)
run_td3(td3)
td3.visualize_best()

In [None]:
# Stack each learner's list of evaluation rewards into one tensor, then pass
# both to the Welch t-test plot (not saved to disk).
ddpg_rewards = torch.stack(ddpg.eval_rewards)
td3_rewards = torch.stack(td3.eval_rewards)
WelchTTest().plot(ddpg_rewards, td3_rewards, legends="ddpg/td3", save=False)

# SANDBOX PIPELINE

## Best policy:

Get the best policy (to exploit it later)

In [None]:
# Start from a fresh copy of the TD3 config section and override the two
# schedule parameters so this test run finishes faster.
cfg_td3_best = OmegaConf.create(cfg.models.td3)
algo_best = cfg_td3_best.algorithm  # same config node, shorter to write
algo_best.max_epochs = 11000
algo_best.learning_starts = 1000

# NOTE: this rebinds `td3`, replacing the learner trained in the TD3 test
# section above; the cells below read the best policy from this instance.
td3 = TD3(cfg_td3_best)
run_td3(td3)
td3.visualize_best()

In [None]:
# Keep a handle on the best policy recorded by the TD3 learner trained in the
# previous cell; it is the agent used to collect the "best" workspace below.
best_policy_agent = td3.best_policy

In [None]:
# Environment agent used for data collection (10 environments, seed 42).
# NOTE(review): the same gym_agent instance is reused for the uniform-policy
# workspace later on -- confirm get_workspace does not rely on a fresh one.
gym_agent = get_gym_agent(
    'CartPoleContinuous-v1',
    num_envs=10,
    seed=42,
)

# Collect a workspace by running the best policy through that environment agent.
workspace_best = get_workspace(
    best_policy_agent,
    gym_agent,
    epoch_size=10_000,
)
print(workspace_best)

## Uniform policy:

In [None]:
class UniformAgent(Agent):
    """Agent that ignores observations and writes randomly sampled actions.

    Actions are drawn with ``action_space.sample()`` from an environment
    created from ``env_name``; one action is drawn per workspace batch entry.

    Args:
        env_name: Gymnasium environment id; only its action space is used.
        seed: Optional seed for the action-space sampler. ``None`` (default)
            leaves the sampler unseeded, as before.
    """

    # TODO: for now it just does a random walk,
    #   need to do jumps instead - random actions in random states
    # I think need to modify ParallelGymAgent._reset() method
    def __init__(self, env_name, seed=None):
        super().__init__()
        # The environment is only kept as a source of action samples; it is
        # never stepped by this agent.
        self.env = gym.make(env_name)
        if seed is not None:
            # Make the sampled action sequence reproducible across runs.
            self.env.action_space.seed(seed)

    def forward(self, t: int, **kwargs):
        """Write one sampled action per environment at time step ``t``.

        Args:
            t: Time index at which the ``"action"`` variable is written.
            **kwargs: Accepted and ignored, so the agent can be called with
                extra keyword arguments alongside other agents.
        """
        n_env = self.workspace.batch_size()
        action_space = self.env.action_space
        # Convert each sample on its own and stack the results: calling
        # torch.tensor on a list of numpy arrays is the slow path PyTorch
        # warns about. The resulting shape is unchanged: (n_env, *sample_shape).
        samples = [
            torch.as_tensor(action_space.sample(), dtype=torch.float32)
            for _ in range(n_env)
        ]
        action = torch.stack(samples)
        self.set(("action", t), action)

In [None]:
# Collect a workspace with the random-action agent, using the same
# environment agent and epoch size as for the best policy.
uniform_agent = UniformAgent('CartPoleContinuous-v1')
workspace_unif = get_workspace(
    uniform_agent,
    gym_agent,
    epoch_size=10_000,
)
print(workspace_unif)

In [None]:
# Poking around with the idea: maybe it's easier to implement those
# random jumps as episodes of length 2?
# for k in range(gym_agent.num_envs):
#     env = gym_agent.envs[k]
#     env.reset()
#     env.state = env.unwrapped.state = env.observation_space.sample()
# gym_agent = get_gym_agent('CartPoleContinuous-v1', num_envs=3, seed=42)
# t_agents = TemporalAgent(Agents(gym_agent,UniformAgent('CartPoleContinuous-v1')))
# workspace = Workspace()
# t_agents(workspace, t=1,n_steps=2)
# workspace_unif.get_transitions()["env/reward"]

## Mix transitions in a buffer

In [None]:
# Build one buffer from the two collected workspaces.
# NOTE(review): presumably `proportion` is the share taken from the first
# (best-policy) workspace -- confirm against mix_transitions in pmind_utils.
rb_mixed = mix_transitions(
    workspace_best,
    workspace_unif,
    batch_size=10_000,
    proportion=0.7,
)

In [None]:
# Fresh copy of the TD3 config section for the offline run.
cfg_td3_offline = OmegaConf.create(cfg.models.td3)
algo_offline = cfg_td3_offline.algorithm  # same config node, shorter to write

# Fewer epochs so the test finishes quickly.
algo_offline.max_epochs = 1000

# Offline training learns from the pre-filled buffer, so the point at which
# learning starts is irrelevant here.
algo_offline.learning_starts = None

# Train TD3 from the mixed buffer only, then visualise the best policy found.
td3_offline = TD3(cfg_td3_offline)
run_td3_offline(td3_offline, rb_mixed)
td3_offline.visualize_best()