Source

* https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py

In [1]:
import os
import random
import time
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
from util.buffer.ReplayBuffer import ReplayBuffer
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [3]:
# Hyperparameters and run settings for the DQN experiment.
# Key order is kept stable because the dict is rendered as-is when the
# hyperparameters are logged.
args = {
    # Environment and reproducibility.
    "env_id": "CartPole-v1",
    "seed": 42,
    "cuda": True,
    # Optimisation and replay-buffer sizing.
    "learning_rate": 0.0003,
    "buffer_size": 10000,
    "total_timesteps": 500000,
    # Epsilon-greedy exploration: anneal from start_e to end_e over the first
    # `exploration_fraction` of all timesteps.
    "start_e": 1,
    "end_e": 0.05,
    "exploration_fraction": 0.5,
    # Experiment tracking.
    "wandb_project_name": "dqn-Cartpole-vector",
    "wandb_entity": None,
    # Training cadence (all measured in loop iterations).
    "learning_starts": 10000,
    "train_frequency": 10,
    "batch_size": 128,
    "target_network_frequency": 500,
    "gamma": 0.99,
    "capture_video": False,
}

# Use the GPU only when one is present AND the config asks for it.
device = (
    torch.device("cuda")
    if (torch.cuda.is_available() and args["cuda"])
    else torch.device("cpu")
)
print(device)

# Unique run identifier: <env>_<seed>_<unix timestamp>.
run_name = "{}_{}_{}".format(args["env_id"], args["seed"], int(time.time()))

cpu


In [4]:
def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument factory that builds one wrapped environment.

    The factory form is what ``gym.vector`` constructors expect: each
    sub-environment is created lazily by calling the returned function.

    Args:
        env_id: Gymnasium environment id passed to ``gym.make``.
        seed: Seed applied to the action and observation *spaces*. The
            environment's own RNG is not seeded here.
        idx: Index of this sub-environment; only index 0 records video.
        capture_video: Whether video recording is enabled at all.
        run_name: Sub-directory of ``videos/`` that recordings go to.

    Returns:
        A callable taking no arguments and returning the wrapped environment.
    """

    def thunk():
        # Episode statistics are always recorded.
        env = gym.wrappers.RecordEpisodeStatistics(gym.make(env_id))
        # Video is recorded by the first sub-environment only.
        if capture_video and idx == 0:
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        for space in (env.action_space, env.observation_space):
            space.seed(seed)
        return env

    return thunk

In [5]:
class DQN(nn.Module):
    """Three-layer MLP mapping an observation to one Q-value per action.

    The input width is the product of the dimensions of
    ``env.single_observation_space.shape`` and the output width is
    ``env.single_action_space.n``; hidden layers have 120 and 84 units.
    """

    def __init__(self, env):
        """Build the network from the spaces exposed by ``env``.

        Args:
            env: Object exposing ``single_observation_space.shape`` and
                ``single_action_space.n`` (e.g. a vectorised environment).
        """
        super().__init__()
        obs_dim = np.array(env.single_observation_space.shape).prod()
        n_actions = env.single_action_space.n
        layers = [
            nn.Linear(obs_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_actions),
        ]
        # Kept under the attribute name `network` so state-dict keys are stable.
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the Q-values the network produces for input ``x``."""
        q_values = self.network(x)
        return q_values

In [6]:
def linear_schedule(start_e: float, end_e: float, duration: float, t: float) -> float:
    """Linearly anneal epsilon from ``start_e`` towards ``end_e``.

    The value moves along a straight line from ``start_e`` at ``t == 0`` and
    reaches ``end_e`` at ``t == duration``; from then on it is clamped to
    ``end_e`` (the result is never below ``end_e``).

    Args:
        start_e: Value at step 0.
        end_e: Final value and lower bound of the result.
        duration: Number of steps over which the annealing happens. May be a
            float (e.g. a fraction of the total step count).
        t: Current step.

    Returns:
        The scheduled value for step ``t``.
    """
    # A non-positive duration means "no annealing phase": return the final
    # value directly instead of dividing by zero.
    if duration <= 0:
        return end_e
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

In [7]:
# wandb.init(
#     # set the wandb project where this run will be logged
#     name=run_name,
#     project=args['wandb_project_name'],
#     entity=args['wandb_entity'],
#     sync_tensorboard=True,
#     # track hyperparameters and run metadata
#     config=args,
#     monitor_gym=True,
#     save_code=True
# )

# writer = SummaryWriter(f'runs/{run_name}')
# writer.add_text(
#     "hyperparameters",
#     "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in args.items()])),
# )

In [8]:
# Build two synchronous sub-environments. Each one gets its own index and its
# own seed offset: the index ensures only sub-environment 0 records video, and
# the seed offset keeps the sub-environments' action/observation spaces from
# sampling identical sequences.
envs = gym.vector.SyncVectorEnv(
    [
        make_env(args['env_id'], args['seed'] + i, i, args['capture_video'], run_name)
        for i in range(2)
    ]
)

In [None]:
# Seed every RNG involved so that a run is reproducible for a given seed.
random.seed(args['seed'])
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])
envs.single_action_space.seed(args['seed'])

# Online network (trained) and target network (periodically synced copy).
q_network = DQN(envs).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args['learning_rate'])
target_network = DQN(envs).to(device)
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
    envs.single_observation_space,
    args['buffer_size'],
    args['batch_size']
)

# TensorBoard writer used by this cell; created here so the cell does not
# depend on any optional tracking setup having been executed beforehand.
writer = SummaryWriter(f"runs/{run_name}")


def log_scalar(tag, value, step):
    """Write a scalar to TensorBoard and, if a wandb run is active, to wandb."""
    writer.add_scalar(tag, value, step)
    if wandb.run is not None:
        wandb.log({tag: value}, step=step)


# Loop-invariant length of the epsilon annealing phase.
exploration_steps = args['exploration_fraction'] * args['total_timesteps']

# Running undiscounted return of the episode in progress, one per sub-env.
scores = np.zeros(envs.num_envs, dtype=np.float64)
episode_cnt = 0

start_time = time.time()

obs, _ = envs.reset(seed=args['seed'])
try:
    for global_step in tqdm(range(args['total_timesteps'])):
        epsilon = linear_schedule(args['start_e'],
                                  args['end_e'],
                                  exploration_steps,
                                  global_step)

        # Epsilon-greedy: one coin flip decides for all sub-environments.
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            with torch.no_grad():
                q_values = q_network(torch.as_tensor(obs, dtype=torch.float32, device=device))
            actions = torch.argmax(q_values, dim=1).cpu().numpy()

        next_obs, rewards, terminations, truncations, infos = envs.step(actions)
        scores += rewards

        # The vector env resets finished sub-environments by itself, so no
        # manual reset is done here. When the env reports the last observation
        # of a finished episode under "final_observation", use it as the stored
        # next state so truncated episodes bootstrap from the real final state.
        # NOTE(review): the key name depends on the installed Gymnasium version;
        # when it is absent, `next_obs` is stored unchanged - confirm.
        real_next_obs = next_obs.copy()
        final_obs = infos.get("final_observation")

        for i, (terminated, truncated) in enumerate(zip(terminations, truncations)):
            done = bool(terminated) or bool(truncated)
            if done and final_obs is not None and final_obs[i] is not None:
                real_next_obs[i] = final_obs[i]

            # One transition per sub-environment. Only true termination is
            # stored as `done`; truncation must not zero out the bootstrap.
            rb.store(obs[i], actions[i], rewards[i], real_next_obs[i], terminations[i])

            if done:
                log_scalar("charts/episodic_return", float(scores[i]), global_step)
                scores[i] = 0
                episode_cnt += 1

        obs = next_obs

        log_scalar("charts/epsilon", epsilon, global_step)

        if global_step > args['learning_starts']:
            if global_step % args['train_frequency'] == 0:
                samples = rb.sample_batch()
                states = torch.as_tensor(samples['obs'], dtype=torch.float32, device=device)
                next_states = torch.as_tensor(samples['next_obs'], dtype=torch.float32, device=device)
                # Named batch_* so the environment `actions` are not shadowed.
                batch_actions = torch.as_tensor(samples['acts'], dtype=torch.long, device=device).reshape(-1, 1)
                batch_rewards = torch.as_tensor(samples['rews'], dtype=torch.float32, device=device).reshape(-1, 1)
                batch_dones = torch.as_tensor(samples['done'], dtype=torch.float32, device=device).reshape(-1, 1)

                # One-step TD target from the frozen target network.
                with torch.no_grad():
                    target_max, _ = target_network(next_states).max(dim=1)
                    td_target = batch_rewards.flatten() + args['gamma'] * target_max * (1 - batch_dones.flatten())
                # squeeze(1) keeps the batch dimension even for a batch of one.
                old_val = q_network(states).gather(1, batch_actions).squeeze(1)
                loss = F.mse_loss(td_target, old_val)

                if global_step % 100 == 0:
                    writer.add_scalar("losses/td_loss", loss.item(), global_step)
                    writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

                # optimize the model
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # update the target network
            if global_step % args['target_network_frequency'] == 0:
                target_network.load_state_dict(q_network.state_dict())
finally:
    # Release resources even if training is interrupted or fails.
    envs.close()
    writer.close()
    wandb.finish()