In [None]:
import random
import sys
from pathlib import Path

import gymnasium as gym
import numpy as np
import torch
import torch.optim as optim
from stable_baselines3.common.buffers import ReplayBuffer

# Add the directory two levels up to sys.path so the `src.*` imports below resolve.
sys.path.append('../..')

from src.rlmcmc.agent import Actor, QNetwork
from src.rlmcmc.env import RLMHEnvV31
from src.rlmcmc.utils import Args, MCMCAnimation, Toolbox
from src.rlmcmc.learning import LearningDDPG

In [None]:
# Build the log target density; it is later stored on `args.log_target_pdf`
# and passed to the animation.
# NOTE(review): the first argument looks like a posteriordb posterior name and
# the second a path to a local posterior database -- confirm against
# Toolbox.make_log_target_pdf.
target_name = "test-banana-test-banana"
database_path = "../../posteriordb/posterior_database"
log_p = Toolbox.make_log_target_pdf(target_name, database_path)

In [None]:
# --- Run configuration ------------------------------------------------------
# Start from the project defaults in Args and override what this run changes.
args = Args()
args.seed = 1234
args.log_target_pdf = log_p
args.total_timesteps = 10_000
args.batch_size = 32
# learning_starts is tied to the batch size here; a larger value such as 5_000
# is the alternative to experiment with.
args.learning_starts = args.batch_size
args.gamma = 0.5
# args.buffer_size keeps its Args default; args.total_timesteps is the
# alternative to experiment with.
args.learning_rate = 1e-5
args.policy_frequency = 2

# --- Reproducibility --------------------------------------------------------
# Seed Python, NumPy and PyTorch from the single configured seed.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.backends.cudnn.deterministic = args.torch_deterministic

# Use the GPU only when one is available AND the configuration asks for it.
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")


def _make_vector_env(config):
    """Build a one-environment ``SyncVectorEnv`` from the run configuration.

    Args:
        config: configuration object providing ``env_id``, ``seed``,
            ``log_target_pdf``, ``sample_dim`` and ``total_timesteps``.

    Returns:
        A ``gym.vector.SyncVectorEnv`` wrapping the single entry produced by
        ``Toolbox.make_env``.
    """
    return gym.vector.SyncVectorEnv(
        [
            Toolbox.make_env(
                env_id=config.env_id,
                seed=config.seed,
                log_target_pdf=config.log_target_pdf,
                sample_dim=config.sample_dim,
                total_timesteps=config.total_timesteps,
            )
        ]
    )


def _build_network(network_cls, vector_env, target_device):
    """Create a network, move it to the device, cast to float64 and compile it.

    Args:
        network_cls: network class called as ``network_cls(vector_env)``.
        vector_env: vector environment handed to the network constructor.
        target_device: ``torch.device`` the network is moved to.

    Returns:
        The module returned by ``torch.compile`` for the double-precision network.
    """
    network = network_cls(vector_env).to(target_device)
    network = network.double()
    return torch.compile(network)


# --- Environments -----------------------------------------------------------
# Two identically configured vector environments: `envs` is handed to the
# learner, `predicted_envs` is kept separate for the later predict call.
envs = _make_vector_env(args)
predicted_envs = _make_vector_env(args)
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

# --- Networks ---------------------------------------------------------------
# Construction order (actor, critic, target critic, target actor) is kept
# fixed so the seeded RNG is consumed in the same sequence on every run.
actor = _build_network(Actor, envs, device)
qf1 = _build_network(QNetwork, envs, device)
qf1_target = _build_network(QNetwork, envs, device)
target_actor = _build_network(Actor, envs, device)

# Target networks start as exact copies of their online counterparts.
target_actor.load_state_dict(actor.state_dict())
qf1_target.load_state_dict(qf1.state_dict())

q_optimizer = optim.Adam(qf1.parameters(), lr=args.learning_rate)
actor_optimizer = optim.Adam(actor.parameters(), lr=args.learning_rate)

# --- Replay buffer ----------------------------------------------------------
# Override the observation dtype with float64 (the networks above are cast to
# double) before the buffer reads it. An np.dtype instance is used rather than
# the bare scalar type so the attribute keeps the kind of object gymnasium
# itself stores there.
envs.single_observation_space.dtype = np.dtype(np.float64)
rb = ReplayBuffer(
    args.buffer_size,
    envs.single_observation_space,
    envs.single_action_space,
    device,
    handle_timeout_termination=False,
)

In [None]:
# Assemble the DDPG learner from the objects created in the setup cell.
# The keyword arguments are grouped by role; the groups are unpacked in the
# same keyword order as a single flat call would use.

# Environment, networks, optimizers and replay buffer.
ddpg_components = {
    "env": envs,
    "actor": actor,
    "target_actor": target_actor,
    "critic": qf1,
    "target_critic": qf1_target,
    "actor_optimizer": actor_optimizer,
    "critic_optimizer": q_optimizer,
    "replay_buffer": rb,
}

# Values taken from the run configuration.
ddpg_hyperparameters = {
    "total_timesteps": args.total_timesteps,
    "learning_starts": args.learning_starts,
    "batch_size": args.batch_size,
    "gamma": args.gamma,
    "policy_frequency": args.policy_frequency,
    "tau": args.tau,
}

# Seed and compute device.
ddpg_runtime = {
    "seed": args.seed,
    "device": device,
}

learning = LearningDDPG(**ddpg_components, **ddpg_hyperparameters, **ddpg_runtime)

In [None]:
# Run training on the LearningDDPG instance built from the setup objects.
# NOTE(review): the run length is presumably governed by the total_timesteps
# passed to the constructor -- confirm in LearningDDPG.train.
learning.train()

In [None]:
# Call the learner's predict on the separate `predicted_envs` vector
# environment and plot the returned object.
# NOTE(review): 5_000 is presumably the number of prediction steps/samples --
# confirm against LearningDDPG.predict.
predict_func = learning.predict(predicted_envs, 5_000)
predict_func.plot()

In [None]:
# Export the prediction results to CSV (without the index column).
# The output directory is created first so the export also works on a fresh
# checkout where `save/data` does not exist yet; `to_csv` does not create
# missing parent directories.
predict_csv_path = Path("save/data/predict.csv")
predict_csv_path.parent.mkdir(parents=True, exist_ok=True)
predict_func.dataframe().to_csv(predict_csv_path, index=False)

In [None]:
# Axis ranges passed to the animation as `xlim` / `ylim`.
x_limits = (-5, 5)
y_limits = (-1, 20)

# Build the animation object from the log target density and the prediction
# results produced by `learning.predict`.
mcmc_animation = MCMCAnimation(
    log_target_pdf=log_p,
    dataframe=predict_func.dataframe(),
    xlim=x_limits,
    ylim=y_limits,
)

In [None]:
# Render the animation and write it to disk with the ffmpeg writer.
anim_file_path = "./save/data/banana.mp4"
# Create the output directory up front so the save also works on a fresh
# checkout; the writer is not expected to create missing directories.
Path(anim_file_path).parent.mkdir(parents=True, exist_ok=True)
mcmc_animation.make().save(anim_file_path, writer='ffmpeg')

In [None]:
# Persist the learner under "save/model".
# NOTE(review): whether this path is a directory or a file prefix, and whether
# it is created on demand, depends on LearningDDPG.save -- confirm.
learning.save("save/model")

In [None]:
# Preview the first rows of the prediction results as the cell output.
predict_func.dataframe().head()