In [1]:
import gymnasium as gym
from gymnasium import spaces
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils import clip_grad_norm_
import numpy as np
import random
import time
import wandb
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

from util.wrappers.add_channel_dimension import AddChannelDimension
from util.wrappers.divide_observation import DivideObservation
from util.scheduler.linear_schedule import linear_schedule
from util.buffer.ReplayBuffer import PrioritizedReplayBuffer

In [12]:
# Run configuration. Every hyperparameter the notebook reads lives in this one
# mapping; it is also what gets logged as the wandb config and as the
# TensorBoard "hyperparameters" table.
args = dict(
    # --- experiment identity ---
    env_id='CartPole-v1',
    algorithm='PER_DQN',
    algorithm_version='v1',
    truncated=None,              # optional step limit handed to TimeLimit when truthy
    seed=42,
    cuda=True,
    # --- optimisation / schedule ---
    learning_rate=0.00025,
    buffer_size=10_000,
    total_timesteps=300_000,
    start_e=1,
    end_e=0.005,
    exploration_fraction=0.1,    # fraction of total_timesteps spent annealing epsilon
    wandb_entity=None,
    learning_starts=2_000,
    train_frequency=1,
    batch_size=64,
    target_network_frequency=500,
    gamma=0.99,
    capture_video=False,
    loss_function='smooth_l1_loss',
    grad_clipping=10.0,
    # --- prioritized replay ---
    per_alpha=0.6,
    per_beta=0.6,
    per_eps=1e-6,
)

In [3]:
class QNetwork(nn.Module):
    """Fully connected Q-function approximator.

    The input width is taken from ``env.observation_space.shape[0]`` and the
    output width from ``env.action_space.n``, with two hidden layers of 128
    units and ReLU activations in between.
    """

    def __init__(self, env: gym.Env):
        """Build the layer stack from the environment's spaces."""
        super().__init__()

        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n
        hidden_units = 128

        # Layers are created in input-to-output order and stored under
        # ``self.layers`` so parameter names stay ``layers.<index>.*``.
        stack = [
            nn.Linear(obs_dim, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, n_actions),
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the network output for ``x`` (one value per action)."""
        q_values = self.layers(x)
        return q_values

In [4]:
def linear_schedule(start_e: float, end_e:float, duration: int, t: int):
    """Linearly interpolate from ``start_e`` to ``end_e`` over ``duration`` steps.

    Note: this definition shadows the ``linear_schedule`` name imported from
    ``util.scheduler.linear_schedule`` earlier in the notebook.

    Args:
        start_e: Value returned at ``t == 0``.
        end_e: Value the schedule settles on once ``t >= duration``.
        duration: Number of steps over which the value is annealed. A
            non-positive duration means "no annealing" and yields ``end_e``.
        t: Current step.

    Returns:
        The scheduled value at step ``t``, never moving past ``end_e``.
    """
    if duration <= 0:
        # Avoids a ZeroDivisionError (e.g. exploration_fraction == 0).
        return end_e
    slope = (end_e - start_e) / duration
    value = slope * t + start_e
    # Clamp on the side the schedule is heading towards, so increasing
    # schedules (start_e < end_e) are capped at end_e instead of jumping to it.
    if slope < 0:
        return max(value, end_e)
    return min(value, end_e)

In [5]:
# Reproducibility: apply the configured seed to every RNG this notebook draws
# from (Python's `random` for epsilon-greedy, numpy, and torch for weight
# initialisation) before any of them is used.
random.seed(args['seed'])
np.random.seed(args['seed'])
torch.manual_seed(args['seed'])

# Project name is the env id without any namespace prefix ("ns/Env-v0" -> "Env-v0").
project_path = args['env_id'].split('/')[-1]

# Use the GPU only when it is both requested in the config and actually available.
device = torch.device("cuda" if args["cuda"] and torch.cuda.is_available() else "cpu")

# Timestamp suffix keeps run names (and their log directories) unique per launch.
run_name=f"{args['algorithm']}_{args['algorithm_version']}_{int(time.time())}"
print(f'project_path: {project_path}, device : {device}, run_name : {run_name}')

project_path: CartPole-v1, device : cuda, run_name : PER_DQN_v1_1674589186


In [6]:
# Environment pipeline: base env -> optional TimeLimit -> auto-reset -> episode statistics.
env = gym.make(args["env_id"], render_mode='rgb_array')

step_limit = args['truncated']
if step_limit:
    # Only wrap when a truthy step limit is configured; None leaves the env as made.
    env = gym.wrappers.TimeLimit(env, step_limit)

# AutoResetWrapper is the inner wrapper, RecordEpisodeStatistics the outer one.
env = gym.wrappers.RecordEpisodeStatistics(gym.wrappers.AutoResetWrapper(env))

In [7]:
# Start the Weights & Biases run; the full `args` mapping is recorded as its config.
# (sync_tensorboard=True is intentionally left disabled.)
wandb_settings = dict(
    name=run_name,
    project=project_path,
    entity=args['wandb_entity'],
    config=args,
    monitor_gym=True,
    save_code=True,
)
wandb.init(**wandb_settings)

# TensorBoard writer, plus a markdown table listing every hyperparameter.
writer = SummaryWriter(f'runs/{project_path}/{run_name}')
hyperparameter_rows = "\n".join(f"|{key}|{value}|" for key, value in args.items())
writer.add_text("hyperparameters", "|param|value|\n|-|-|\n%s" % hyperparameter_rows)

[34m[1mwandb[0m: Currently logged in as: [33miamhelpingstar[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# Online network (the one being optimised) and a second network of the same
# architecture that serves as the target network.
q_network = QNetwork(env).to(device)
target_network = QNetwork(env).to(device)

# Adam updates the online network's parameters only.
optimizer = optim.Adam(params=q_network.parameters(), lr=args['learning_rate'])

# Start the target network from exactly the online network's weights.
online_weights = q_network.state_dict()
target_network.load_state_dict(online_weights)

<All keys matched successfully>

In [15]:
# DQN training loop with prioritized experience replay.
#
# Per environment step: pick an epsilon-greedy action, store the transition,
# and (after `learning_starts`) sample a prioritized batch, take one
# importance-weighted gradient step, and write the new per-sample priorities
# back to the buffer. The target network is synced every
# `target_network_frequency` steps.
rb = PrioritizedReplayBuffer(
        env.observation_space,
        args['buffer_size'],
        args['batch_size'],
        alpha=args['per_alpha']
)

# Loop-invariant values, computed once.
exploration_steps = args['exploration_fraction'] * args['total_timesteps']
initial_beta = args['per_beta']

# Seed the environment and the action-space sampler so rollouts are reproducible.
env.action_space.seed(args['seed'])
obs, _ = env.reset(seed=args['seed'])

try:
    for global_step in tqdm(range(args['total_timesteps'])):
        epsilon = linear_schedule(args['start_e'],
                                  args['end_e'],
                                  exploration_steps,
                                  global_step)
        if random.random() < epsilon:
            action = env.action_space.sample()
        else:
            # Action selection needs no gradients; skip building an autograd graph.
            with torch.no_grad():
                obs_tensor = torch.as_tensor(obs, dtype=torch.float32, device=device)
                action = torch.argmax(q_network(obs_tensor)).item()

        next_obs, reward, terminate, truncate, info = env.step(action)

        # The env is wrapped in AutoResetWrapper: when an episode ends, the
        # observation returned by step() already belongs to the NEXT episode and
        # the true last observation is reported in info['final_observation'].
        # Store the true successor so truncated transitions (done == False)
        # bootstrap from the right state.
        real_next_obs = next_obs
        if terminate or truncate:
            real_next_obs = info.get('final_observation', next_obs)
        rb.store(obs, action, reward, real_next_obs, terminate)

        obs = next_obs

        # Anneal beta linearly from its configured initial value towards 1.0.
        # It is computed from `initial_beta` (not from the previous step's value,
        # which compounds and saturates almost immediately) and kept in a local so
        # the config dict is not mutated.
        fraction = min(global_step / args['total_timesteps'], 1.0)
        per_beta = initial_beta + fraction * (1.0 - initial_beta)
        # NOTE(review): `per_beta` is not forwarded to the buffer below because the
        # signature of `PrioritizedReplayBuffer.sample_batch` is not visible here.
        # If it accepts a beta argument, pass `per_beta` so the importance weights
        # actually follow this schedule -- TODO confirm against util.buffer.

        if 'episode' in info.keys():
            writer.add_scalar("charts/episodic_return", info['episode']['r'], global_step)
            wandb.log({"charts/episodic_return": info['episode']['r']}, step=global_step)

        writer.add_scalar("charts/epsilon", epsilon, global_step)
        wandb.log({"charts/epsilon": epsilon}, step=global_step)

        if global_step > args['learning_starts']:
            if global_step % args['train_frequency'] == 0:
                samples = rb.sample_batch()
                states = torch.FloatTensor(samples['obs']).to(device)
                next_states = torch.FloatTensor(samples['next_obs']).to(device)
                actions = torch.LongTensor(samples['acts']).reshape(-1, 1).to(device)
                # Everything that is combined element-wise with the TD error is
                # flattened to 1-D so all operands share one shape.
                rewards = torch.FloatTensor(samples['rews']).reshape(-1).to(device)
                dones = torch.FloatTensor(samples["done"].reshape(-1)).to(device)
                weights = torch.FloatTensor(samples["weights"].reshape(-1)).to(device)
                indices = samples["indices"]

                with torch.no_grad():
                    target_max, _ = target_network(next_states).max(dim=1)
                    td_target = rewards + args['gamma'] * target_max * (1 - dones)
                # squeeze(1) drops only the gathered action axis, so a batch of
                # size one stays 1-D.
                old_val = q_network(states).gather(1, actions).squeeze(1)

                if args['loss_function'] == 'mse_loss':
                    elementwise_loss = F.mse_loss(old_val, td_target, reduction="none")
                else:
                    elementwise_loss = F.smooth_l1_loss(old_val, td_target, reduction="none")
                # Both factors are 1-D here. With a (batch, 1) weight column this
                # product would broadcast to (batch, batch) and the mean would no
                # longer weight each sample by its own importance weight.
                loss = torch.mean(elementwise_loss * weights)

                if global_step % 100 == 0:
                    loss_value = loss.item()
                    mean_q_value = old_val.mean().item()
                    writer.add_scalar("losses/td_loss", loss_value, global_step)
                    writer.add_scalar("losses/q_values", mean_q_value, global_step)
                    wandb.log({"losses/td_loss": loss_value,
                            "losses/q_values": mean_q_value}, step=global_step)

                optimizer.zero_grad()
                loss.backward()
                if args['grad_clipping'] is not None:
                    clip_grad_norm_(q_network.parameters(), args['grad_clipping'])
                optimizer.step()

                # New priorities are the per-sample losses plus a small constant so
                # that no stored transition ends up with a priority of exactly zero.
                loss_for_prior = elementwise_loss.detach().cpu().numpy()
                new_priorities = loss_for_prior + args['per_eps']
                rb.update_priorities(indices, new_priorities)

            if global_step % args['target_network_frequency'] == 0:
                target_network.load_state_dict(q_network.state_dict())
finally:
    # Release the env and flush/close both loggers even when the run is
    # interrupted (e.g. KeyboardInterrupt from stopping the cell).
    env.close()
    writer.close()
    wandb.finish()

  1%|          | 2626/300000 [00:13<24:39, 200.94it/s] 


KeyboardInterrupt: 