In [9]:
import torch
import torch.nn as nn
import gymnasium as gym
import time
import wandb
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter

In [3]:
# Run configuration. The whole dict is attached to the logged run, so key
# order is kept stable (it determines the row order of the hyperparameter table).
args = {
    'env_id': 'CartPole-v1',
    'algorithm': 'DQN',
    'algorithm_version': 'v1_1',
    'seed': 42,
    'cuda': True,
    'learning_rate': 0.0003,
    'buffer_size': 10_000,
    'total_timesteps': 300_000,
    'start_e': 1,
    'end_e': 0.01,
    'exploration_fraction': 0.5,
    'wandb_entity': None,
    'learning_starts': 10_000,
    'train_frequency': 1,
    'batch_size': 128,
    'target_network_frequency': 500,
    'gamma': 0.99,
    'capture_video': False,
}

# Last path component of the environment id (the id itself when it has no '/').
project_path = args['env_id'].rsplit('/', 1)[-1]

# Use the GPU only when one is present *and* the config asks for it.
if torch.cuda.is_available() and args["cuda"]:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Unique run label: algorithm, version and the launch time in whole seconds.
run_name=f"{args['algorithm']}_{args['algorithm_version']}_{int(time.time())}"

print(f'project_path: {project_path}, device : {device}, run_name : {run_name}')

project_path: CartPole-v1, device : cuda, run_name : DQN_v1_1_1673772691


In [2]:
class DQN(nn.Module):
    """Fully connected Q-network producing one output per discrete action.

    The input width is read from ``env.observation_space.shape[0]`` and the
    output width from ``env.action_space.n``. Two hidden layers of 120 and 84
    units, each followed by a ReLU, sit in between.
    """

    def __init__(self, env):
        super().__init__()
        obs_dim = env.observation_space.shape[0]
        n_actions = env.action_space.n
        # Layers are created input-to-output and stored under `network`, so
        # parameter names stay `network.0.*`, `network.2.*`, `network.4.*`.
        layers = [
            nn.Linear(obs_dim, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, n_actions),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the result."""
        q_values = self.network(x)
        return q_values

In [4]:
def linear_schedule(start_e: float, end_e:float, duration: int, t: int):
    """Linearly interpolate from ``start_e`` towards ``end_e`` over ``duration`` steps.

    Returns the value of the line at step ``t``, never going below ``end_e``.
    """
    per_step = (end_e - start_e) / duration
    annealed = per_step * t + start_e
    # Once the line drops under the floor, hold the floor value.
    if annealed < end_e:
        return end_e
    return annealed

In [5]:
# Build the environment with rgb_array rendering and wrap it in RecordVideo,
# which writes its videos under videos/<project_path>/<run_name>.
env = gym.wrappers.RecordVideo(
    gym.make(args["env_id"], render_mode='rgb_array'),
    video_folder=f'videos/{project_path}/{run_name}',
    disable_logger=True,
)

In [7]:
# Open the Weights & Biases run; the complete `args` dict is attached as the
# run config. TensorBoard syncing through wandb is left off (no
# `wandb.tensorboard.patch` call, `sync_tensorboard` not passed).
wandb.init(
    project=project_path,
    entity=args['wandb_entity'],
    name=run_name,
    config=args,
    save_code=True,
    monitor_gym=True,
)

# Local TensorBoard writer for the same run; the hyperparameters are stored
# once as a two-column markdown table.
writer = SummaryWriter(f'runs/{project_path}/{run_name}')
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s"
    % "\n".join(f"|{key}|{value}|" for key, value in args.items()),
)

[34m[1mwandb[0m: Currently logged in as: [33miamhelpingstar[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
# Load the trained Q-network used for evaluation.
# Forward slashes replace the original backslash separators: "\{" and "\D" are
# invalid string escapes, and a backslash path only resolves on Windows.
checkpoint_path = f"weights/{project_path}/DQN_v1_1673770014_q_network.pt"

# map_location lets a checkpoint saved on a GPU be opened on a CPU-only host.
# SECURITY: torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from a trusted source.
# NOTE(review): newer PyTorch releases may need weights_only=False to unpickle
# a whole module - confirm against the installed version.
checkpoint = torch.load(checkpoint_path, map_location=device)

if isinstance(checkpoint, nn.Module):
    # The file holds the pickled module itself.
    q_network = checkpoint
else:
    # Otherwise treat the payload as a state_dict for a fresh DQN.
    q_network = DQN(env)
    q_network.load_state_dict(checkpoint)

# Keep the model on the same device the observations are moved to, and put it
# in inference mode.
q_network = q_network.to(device).eval()

In [11]:
# Greedy evaluation: act with argmax over the Q-values for
# args['total_timesteps'] environment steps and log each episode's return.
# The first reset is seeded from the config so the rollout is reproducible.
obs, _ = env.reset(seed=args['seed'])
score = 0
try:
    # No gradients are needed when only selecting actions.
    with torch.no_grad():
        for global_step in tqdm(range(args['total_timesteps'])):
            q_values = q_network(
                torch.as_tensor(obs, dtype=torch.float32, device=device)
            )
            action = torch.argmax(q_values).item()

            next_obs, reward, terminate, truncate, info = env.step(action)
            score += reward
            obs = next_obs

            # An episode ends on termination OR truncation (time limit).
            # Checking only `terminate` never resets an episode that hits the
            # time limit, so the score would keep growing across episodes.
            if terminate or truncate:
                writer.add_scalar("charts/episodic_return", score, global_step)
                wandb.log({"charts/episodic_return": score}, step=global_step)
                score = 0
                obs, _ = env.reset()
finally:
    # Push buffered scalars to disk even if the run is interrupted.
    writer.flush()

 32%|███▏      | 96590/300000 [01:52<03:56, 858.71it/s] 


KeyboardInterrupt: 