In [None]:
# All imports in one place: standard library, then third-party, then project-local.
from collections import deque

import gymnasium as gym
import numpy as np
import torch
import wandb

from src.agents.dqn_agent import DQNAgent
from src.utils import util as rl_util
from src.utils.util import ShellColor as sc

# Record library versions in the notebook output for reproducibility.
print(f"{sc.COLOR_PURPLE}Gym version:{sc.ENDC} {gym.__version__}")
print(f"{sc.COLOR_PURPLE}Pytorch version:{sc.ENDC} {torch.__version__}")

In [None]:
# Build the training environment and print a summary of it.
env_name = "CartPole-v1"
env = gym.make(id=env_name)
rl_util.print_env_info(env=env)

In [None]:
# Start from the project's default config and override DQN hyperparameters.
# The resulting config is passed to DQNAgent and logged to wandb.
config = rl_util.create_config()
for key, value in {
    "batch_size": 32,
    "buffer_size": 50000,
    "gamma": 0.99,
    "target_update_frequency": 1000,
    "learning_starts": 1000,
    "learning_frequency": 1,
    "lr": 1e-4,
    "start_training_step": 1000,
    "mean_reward_bound": 490,
    "print_frequency": 100,
}.items():
    config[key] = value

In [None]:
# Start a wandb run so the hyperparameters above are recorded alongside metrics.
wandb.init(config=config, project="DQN-cartpole")

In [None]:
# Build the agent from the environment's observation/action spaces and the config.
agent = DQNAgent(
    config=config,
    is_atari=False,
    obs_space_shape=env.observation_space.shape,
    action_space_dims=env.action_space.n,
)

# Sanity check: show the config the agent ended up with and its replay-memory type.
print(agent.config, type(agent.memory), sep="\n")

In [None]:
# Output directory for model weights and plots, plus the stable weights filename.
save_dir = "result/DQN/cartpole/"
rl_util.create_directory(save_dir)
save_model_name = f"{save_dir}{env_name}_mean_score.pt"

In [None]:

# DQN training loop, one iteration per episode.
# Acts epsilon-greedily and updates the online network once the replay buffer holds
# more than config["learning_starts"] transitions. Stops early, saving weights, once the
# mean reward over the last REWARD_WINDOW episodes exceeds config["mean_reward_bound"].
REWARD_WINDOW = 5  # episodes averaged for the stopping criterion / logging
LOG_EVERY = 20     # episodes between console + wandb logs

total_rewards = []  # per-episode return (plotted below)
losses = []         # per-episode mean training loss (plotted below)
# Must live outside the episode loop, otherwise the "mean" only ever sees one episode.
recent_rewards = deque(maxlen=REWARD_WINDOW)
frame_idx = 0
# NOTE(review): agent.config.num_steps is used as the number of *episodes* here — confirm intent.
for i_episode in range(agent.config.num_steps):
    obs, info = env.reset()
    loss_sum = 0.0
    num_updates = 0
    len_game_progress = 0
    cur_reward = 0
    while True:
        frame_idx += 1
        eps = agent.decay_epsilon(frame_idx)
        action = agent.select_action(obs, eps)
        next_obs, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated
        # Store `terminated`, not `done`: a time-limit truncation is not a true terminal
        # state, so the TD target should still bootstrap from next_obs.
        agent.store_transition(obs, action, reward, next_obs, terminated)

        if len(agent.memory.replay_buffer) > config["learning_starts"]:
            loss_sum += agent.update()
            num_updates += 1

        obs = next_obs
        len_game_progress += 1
        cur_reward += reward
        if done:
            break

    # NOTE(review): the target network is synced once per episode; config["target_update_frequency"]
    # is not read in this cell — confirm whether DQNAgent applies it internally.
    agent.update_target_network()

    # Average only over steps that actually ran an update, so warm-up steps don't dilute the loss.
    avg_loss = loss_sum / num_updates if num_updates else 0.0
    losses.append(avg_loss)

    total_rewards.append(cur_reward)
    recent_rewards.append(cur_reward)
    mean_rewards = np.mean(recent_rewards)

    if i_episode % LOG_EVERY == 0:
        print(
            f"episode: {i_episode} | mean_rewards: {mean_rewards:.4f} | loss: {avg_loss:.4f} | epsilon: {eps:.4f}"
        )
        wandb.log({
            "episode": i_episode,
            "mean_rewards": mean_rewards,
            "cur_rewards" : cur_reward,
            "loss": avg_loss,
            "epsilon": eps
        })

    # Require a full window so a single lucky early episode can't end training.
    window_full = len(recent_rewards) == recent_rewards.maxlen
    if window_full and mean_rewards > config["mean_reward_bound"]:
        current_time = rl_util.get_current_time_string()
        checkpoint_name = save_dir + "checkpoint_" + current_time + ".pt"
        print(f"Save model {checkpoint_name} | episode is {(i_episode)}")
        state_dict = agent.policy_network.state_dict()
        torch.save(state_dict, checkpoint_name)
        # Also write the stable path (save_model_name) that the evaluation cell loads.
        torch.save(state_dict, save_model_name)
        break

env.close()

In [None]:
# Plot the per-episode reward and loss curves from training, saving each figure to save_dir.
for figure_title, series, label in (
    ("Reward", total_rewards, "reward"),
    ("Loss", losses, "loss"),
):
    fig, ax = rl_util.init_2d_figure(figure_title)
    rl_util.plot_graph(
        ax,
        series,
        title=label,
        ylabel=label,
        save_dir_name=save_dir,
        is_save=True,
    )
    rl_util.show_figure()

In [None]:
# Evaluate the trained policy greedily (epsilon = 0) with on-screen rendering.
# In render_mode="human", Gymnasium renders during env.step(), so no explicit env.render() call is needed.
env = gym.make(env_name, render_mode="human")

test_agent = DQNAgent(
    obs_space_shape=env.observation_space.shape,
    action_space_dims=env.action_space.n,
    is_atari=False,
    config=config
)

# Load weights from save_model_name, the path the training cell saves to once the reward
# bound is reached. A hardcoded filename here could point at a file that was never written.
# If training never reached the bound, nothing was saved and torch.load will raise.
file_name = save_model_name
test_agent.policy_network.load_state_dict(torch.load(file_name))

N_TEST_EPISODES = 1
for i_episode in range(N_TEST_EPISODES):
    state, _ = env.reset()
    test_reward = 0
    done = False
    while not done:
        action = test_agent.select_action(state, 0.)
        state, reward, terminated, truncated, _ = env.step(action)
        test_reward += reward
        done = terminated or truncated
    print(f"{i_episode} episode Total Reward: {test_reward}")
env.close()