Source

* https://github.com/vwxyzjn/cleanrl/blob/master/cleanrl/dqn.py

* https://github.com/Curt-Park/rainbow-is-all-you-need/blob/master/01.dqn.ipynb

In [20]:
import gymnasium as gym
import torch.nn as nn
import wandb
import time
import torch
import torch.optim as optim
import random
import numpy as np
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from util.buffer.ReplayBuffer import ReplayBuffer

In [21]:
args = {
    'env_id': 'CartPole-v1',
    'seed': 42,
    'cuda': True,
    'learning_rate' : 0.0003,
    'buffer_size' : 10000,
    'total_timesteps' : 500000,
    'start_e' : 1, 
    'end_e' : 0.05, 
    'exploration_fraction' : 0.5,
    'wandb_project_name' : "dqn-Cartpole",
    'wandb_entity' : None,
    'learning_starts' : 10000,
    'train_frequency' : 10,
    'batch_size' : 128,
    'target_network_frequency' : 500,
    'gamma' : 0.99
    }

device = torch.device("cuda" if torch.cuda.is_available() and args["cuda"] else "cpu")
print(device)
run_name=f"{args['env_id']}_{args['seed']}_{int(time.time())}"

cpu


In [5]:
class DQN(nn.Module):
    def __init__(self, env):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(env.observation_space.shape[0], 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, env.action_space.n)
        )
    
    def forward(self, x):
        return self.network(x)

In [6]:
def linear_schedule(start_e: float, end_e:float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

In [7]:
wandb.init(
    # set the wandb project where this run will be logged
    name=run_name,
    project=args['wandb_project_name'],
    entity=args['wandb_entity'],
    sync_tensorboard=True,
    # track hyperparameters and run metadata
    config=args,
    monitor_gym=True,
    save_code=True
)

writer = SummaryWriter(f'runs/{run_name}')
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in args.items()])),
)

[34m[1mwandb[0m: Currently logged in as: [33miamhelpingstar[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [18]:
q_network = DQN(env).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args['learning_rate'])
target_network = DQN(env).to(device)
target_network.load_state_dict(q_network.state_dict())

t1 = env.observation_space
t2 = env.action_space,

rb = ReplayBuffer(
    env.observation_space,
    args['buffer_size'],
    args['batch_size']
)

start_time = time.time()

obs, _ = env.reset()

for global_step in range(args['total_timesteps']):
    epsilon = linear_schedule(args['start_e'], 
                              args['end_e'], 
                              args['exploration_fraction'] * args['total_timesteps'], 
                              global_step)
    if random.random() < epsilon:
        action = env.action_space.sample()
    else:
        q_values = q_network(torch.Tensor(obs).to(device))
        action = torch.argmax(q_values).item()
    
    next_obs, reward, terminate, truncate, info = env.step(action)
    
    if "episode" in info.keys():
        print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
        writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
        writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
        writer.add_scalar("charts/epsilon", epsilon, global_step)
    
    rb.store(obs, action, reward, next_obs, terminate)
    
    obs = next_obs
    
    if global_step > args['learning_starts']:
        if global_step % args['train_frequency'] == 0:
            
            samples = rb.sample_batch()
            states = torch.FloatTensor(samples['obs']).to(device)
            next_states = torch.FloatTensor(samples['next_obs']).to(device)
            actions = torch.LongTensor(samples['acts']).reshape(-1, 1).to(device)
            rewards = torch.FloatTensor(samples['rews']).reshape(-1, 1).to(device)
            dones = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
            
            with torch.no_grad():
                target_max, _ = target_network(next_states).max(dim=1)
                td_target = rewards.flatten() + args['gamma'] * target_max * (1 - dones.flatten())
            old_val = q_network(states).gather(1, actions).squeeze()
            loss = F.mse_loss(td_target, old_val)
            
            if global_step % 100 == 0:
                writer.add_scalar("losses/td_loss", loss, global_step)
                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
            
            # optimize the model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # update the target network
        if global_step % args['target_network_frequency'] == 0:
            target_network.load_state_dict(q_network.state_dict())
env.close()
writer.close()

  logger.warn(


SPS: 11126
SPS: 10358
SPS: 9765
SPS: 9230
SPS: 8717
SPS: 8311
SPS: 8018
SPS: 7739
SPS: 7453
SPS: 7234
SPS: 7072
SPS: 6882
SPS: 6669
SPS: 6512
SPS: 6373
SPS: 6215
SPS: 6086
SPS: 5961
SPS: 5866
SPS: 5759
SPS: 5637
SPS: 5531
SPS: 5422
SPS: 5286
SPS: 5185
SPS: 5114
SPS: 5049
SPS: 4974
SPS: 4893
SPS: 4818
SPS: 4749
SPS: 4674
SPS: 4608
SPS: 4533
SPS: 4462
SPS: 4404
SPS: 4361
SPS: 4320
SPS: 4275
SPS: 4215
SPS: 4169
SPS: 4123
SPS: 4082
SPS: 4044
SPS: 4009
SPS: 3964
SPS: 3918
SPS: 3867
SPS: 3839
SPS: 3810
SPS: 3770
SPS: 3741
SPS: 3714
SPS: 3690
SPS: 3664
SPS: 3640
SPS: 3606
SPS: 3580
SPS: 3551
SPS: 3520
SPS: 3497
SPS: 3468
SPS: 3446
SPS: 3427
SPS: 3396
SPS: 3371
SPS: 3351
SPS: 3323
SPS: 3300
SPS: 3279
SPS: 3262
SPS: 3242
SPS: 3211
SPS: 3196
SPS: 3173
SPS: 3157
SPS: 3141
SPS: 3127
SPS: 3108
SPS: 3084
SPS: 3062
SPS: 3050
SPS: 3036
SPS: 3018
SPS: 3004
SPS: 2987
SPS: 2969
SPS: 2950
SPS: 2939
SPS: 2928
SPS: 2918
SPS: 2905
SPS: 2891
SPS: 2881
SPS: 2871
SPS: 2865
SPS: 2852
SPS: 2838
SPS: 2825
SPS: 281

KeyboardInterrupt: 

In [19]:
import random

# simulate training
epochs = 10
offset = random.random() / 5
for epoch in range(2, epochs):
    acc = 1 - 2 ** -epoch - random.random() / epoch - offset
    loss = 2 ** -epoch + random.random() / epoch + offset
    
    # log metrics to wandb
    wandb.log({"acc": acc, "loss": loss})
    time.sleep(5)
# [optional] finish the wandb run, necessary in notebooks
wandb.finish()

KeyboardInterrupt: 