In [1]:
import gymnasium as gym
import torch.nn as nn
import wandb
import time
import torch
import torch.optim as optim
import random
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
import gym_snakegame
from util.wrappers.add_channel_dimension import AddChannelDimension
from torch.utils.tensorboard import SummaryWriter
from util.buffer.ReplayBuffer import ReplayBuffer

In [2]:
args = {
    'env_id': 'gym_snakegame/SnakeGame-v0',
    'env_id_short' : 'SnakeGame-v0',
    'seed': 42,
    'cuda': True,
    'learning_rate' : 0.0003,
    'buffer_size' : 30000,
    'total_timesteps' : 20000000,
    'start_e' : 1, 
    'end_e' : 0.1, 
    'exploration_fraction' : 0.5,
    'wandb_project_name' : "dqn-Snakegame",
    'wandb_entity' : None,
    'learning_starts' : 30000,
    'train_frequency' : 1,
    'batch_size' : 128,
    'target_network_frequency' : 500,
    'gamma' : 0.99,
    'capture_video' : False
    }

device = torch.device("cuda" if torch.cuda.is_available() and args["cuda"] else "cpu")
print(device)
run_name=f"{args['env_id_short']}_{args['seed']}_{int(time.time())}"

cuda


In [3]:
# input.shape : (15, 15)
class DQN(nn.Module):
    def __init__(self, env):
        super().__init__()        
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        self.fc = nn.Sequential(
            nn.Linear(in_features=64*7*7 , out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=env.action_space.n)
        )

    def forward(self, x):
        conv_out = self.conv(x).view(x.size()[0],-1)
        return self.fc(conv_out)

In [4]:
# # input.shape : (15, 15)
# class DQN(nn.Module):
#     def __init__(self, env):
#         super().__init__()
#         self.network = nn.Sequential(
#             nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3),
#             nn.ReLU(),
#             nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3),
#             nn.ReLU(),
#             nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3),
#             nn.ReLU(),
#             nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3),
#             nn.ReLU(),
#             nn.Flatten(),
#             nn.Linear(in_features=64*7*7, out_features=512),
#             nn.ReLU(),
#             nn.Linear(in_features=512, out_features=env.action_space.n)
#         )

#     def forward(self, x):
#         return self.network(x)

In [5]:
def linear_schedule(start_e: float, end_e:float, duration: int, t: int):
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)

In [6]:
env = gym.make(args['env_id'], size=15, n_target=1, render_mode='rgb_array')
env = gym.wrappers.RecordVideo(env, f"videos/{run_name}", 
                               episode_trigger=lambda x: x % 200 == 0,
                               name_prefix=run_name,
                               disable_logger=True)
env = AddChannelDimension(env)

  logger.warn(


In [7]:
# wandb.tensorboard.patch(root_logdir='runs')
wandb.init(
    # set the wandb project where this run will be logged
    name=run_name,
    project=args['wandb_project_name'],
    entity=args['wandb_entity'],
    # sync_tensorboard=True,
    config=args,
    monitor_gym=True,
    save_code=True
)

writer = SummaryWriter(f'runs/{run_name}')
writer.add_text(
    "hyperparameters",
    "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in args.items()])),
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33miamhelpingstar[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016666666666666666, max=1.0…

In [8]:
q_network = DQN(env).to(device)
optimizer = optim.Adam(q_network.parameters(), lr=args['learning_rate'])
target_network = DQN(env).to(device)
target_network.load_state_dict(q_network.state_dict())

rb = ReplayBuffer(
    env.observation_space,
    args['buffer_size'],
    args['batch_size']
)

start_time = time.time()

obs, _ = env.reset()
score = 0
episode_cnt = 0
for global_step in tqdm(range(args['total_timesteps'])):
    epsilon = linear_schedule(args['start_e'], 
                              args['end_e'], 
                              args['exploration_fraction'] * args['total_timesteps'], 
                              global_step)
    if random.random() < epsilon:
        action = env.action_space.sample()
    else:
        q_values = q_network(torch.Tensor(obs / 5.0).to(device))
        action = torch.argmax(q_values).item()
    
    next_obs, reward, terminate, truncate, info = env.step(action)
    rb.store(obs, action, reward, next_obs, terminate)
    
    obs = next_obs
    score += reward 
    
    if terminate:
        obs, _ = env.reset()
        writer.add_scalar("charts/episodic_return", score, global_step)
        wandb.log({"charts/episodic_return": score}, step=global_step)
        score = 0
        episode_cnt += 1
        
    writer.add_scalar("charts/epsilon", epsilon, global_step)
    wandb.log({"charts/epsilon": epsilon}, step=global_step)
    
    if global_step > args['learning_starts']:
        if global_step % args['train_frequency'] == 0:
            
            samples = rb.sample_batch()
            states = torch.FloatTensor(samples['obs']).to(device)
            next_states = torch.FloatTensor(samples['next_obs']).to(device)
            actions = torch.LongTensor(samples['acts']).reshape(-1, 1).to(device)
            rewards = torch.FloatTensor(samples['rews']).reshape(-1, 1).to(device)
            dones = torch.FloatTensor(samples["done"].reshape(-1, 1)).to(device)
            
            with torch.no_grad():
                target_max, _ = target_network(next_states / 5.0).max(dim=1)
                td_target = rewards.flatten() + args['gamma'] * target_max * (1 - dones.flatten())
            old_val = q_network(states / 5.0).gather(1, actions).squeeze()
            loss = F.mse_loss(td_target, old_val)
            
            if global_step % 100 == 0:
                writer.add_scalar("losses/td_loss", loss, global_step)
                writer.add_scalar("losses/q_values", old_val.mean().item(), global_step)
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
                
                wandb.log({"losses/td_loss": loss, "losses/q_values": old_val.mean().item()}, step=global_step)
                wandb.log({"charts/SPS" : int(global_step / (time.time() - start_time))}, step=global_step)
            # optimize the model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        
        # update the target network
        if global_step % args['target_network_frequency'] == 0:
            target_network.load_state_dict(q_network.state_dict())
env.close()
writer.close()
wandb.finish()

  1%|          | 117185/20000000 [10:39<41:15:27, 133.87it/s]