In [12]:
import random
from collections import deque

import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import wandb

In [5]:
# Hyperparameters for the run; later cells read these module-level names.
ENV_NAME = "CartPole-v1"

# Optimisation
lr = 1e-4                       # learning rate handed to the optimizer
bs = 16                         # number of transitions sampled from replay per update
gamma = 0.99                    # discount applied to the bootstrapped next-state value
loss_fn = nn.SmoothL1Loss()     # loss between predicted and target Q-values

# Exploration
epsilon = 1                     # starting probability of taking a random action
epsilon_annealing_steps = 1000  # each training step lowers epsilon by 1 / this value

# Replay memory and schedule
replay_memory_max_size = 10000  # capacity of the replay buffer
number_of_episodes = 500        # episodes run by the training loop
max_episode_length = 500        # per-episode step cap enforced by the training loop
sync_every_n_steps = 500        # how often the target network copies the online weights

In [6]:
# Snapshot of the hyperparameters, passed to wandb as the run configuration.
config = dict(
    learning_rate=lr,
    architecture="DQN",
    environment=ENV_NAME,
    epsilon=epsilon,
    gamma=gamma,
    bs=bs,
    replay_memory_max_size=replay_memory_max_size,
    number_of_episodes=number_of_episodes,
    max_episode_length=max_episode_length,
    sync_every_n_steps=sync_every_n_steps,
    epsilon_annealing_steps=epsilon_annealing_steps,
    loss=str(loss_fn),  # stored as the loss module's repr string
)
config

{'learning_rate': 0.0001,
 'architecture': 'DQN',
 'environment': 'CartPole-v1',
 'epsilon': 1,
 'gamma': 0.99,
 'bs': 16,
 'replay_memory_max_size': 10000,
 'number_of_episodes': 500,
 'max_episode_length': 500,
 'sync_every_n_steps': 500,
 'epsilon_annealing_steps': 1000,
 'loss': 'SmoothL1Loss()'}

In [7]:
# Start a wandb run in the "cartpole" project; `config` carries the
# hyperparameters and run metadata to be tracked with the run.
wandb.init(project="cartpole", config=config)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgarethmd[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
class TorchEnv:
    """Adapter around an environment that returns observations and rewards as tensors.

    The wrapped object must expose ``observation_space.shape``,
    ``action_space.n``, ``step``, ``reset`` and ``close``.
    """

    def __init__(self, env):
        """Keep a reference to ``env`` and cache its observation/action sizes."""
        self.env = env
        self.n_observations = self.env.observation_space.shape[0]
        self.n_actions = self.env.action_space.n

    def step(self, a):
        """Apply action ``a``; return ``(obs, reward, terminated, truncated, info)``.

        ``obs`` and ``reward`` are converted with ``torch.tensor``; the other
        three values are passed through exactly as the wrapped env returned them.
        """
        obs, reward, terminated, truncated, info = self.env.step(a)
        obs_t = torch.tensor(obs)
        reward_t = torch.tensor(reward)
        return obs_t, reward_t, terminated, truncated, info

    def reset(self, *args, **kwargs):
        """Reset the wrapped env (arguments are forwarded); return ``(obs, info)``."""
        obs, info = self.env.reset(*args, **kwargs)
        return torch.tensor(obs), info

    def close(self):
        """Close the wrapped env and hand back whatever its ``close`` returns."""
        return self.env.close()
    
# Build the environment named by ENV_NAME and wrap it so it yields tensors.
env = TorchEnv(env=gym.make(ENV_NAME))

In [11]:
class DQN(nn.Module):
    """Three-layer MLP that maps a state to one Q-value per action."""

    def __init__(self, in_dim: int, hidden_dim: int, n_actions: int) -> None:
        """Build the network.

        Args:
            in_dim: size of the state vector fed to the first linear layer.
            hidden_dim: width of both hidden layers.
            n_actions: number of outputs (one Q-value per action).
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
        self.n_actions = n_actions

    def forward(self, s: torch.Tensor) -> torch.Tensor:
        """Return the Q-values the network assigns to state(s) ``s``."""
        return self.net(s)

    def select_next_action(self, s: torch.Tensor, epsilon: float) -> int:
        """Pick an action epsilon-greedily.

        With probability ``1 - epsilon`` the index of the largest Q-value for
        ``s`` is returned; otherwise an action index is drawn uniformly from
        ``range(self.n_actions)``.

        Args:
            s: state to evaluate. ``argmax()`` is taken over the whole output,
                so this is meant for a single (unbatched) state.
            epsilon: probability of exploring. Values marginally outside
                [0, 1] (e.g. from repeated float subtraction while annealing)
                are clamped instead of making the binomial draw raise.

        Returns:
            The chosen action index.
        """
        with torch.no_grad(): # no need to track gradients selecting next action
            # np.random.binomial rejects p outside [0, 1], so clamp first.
            greedy_prob = min(max(1 - epsilon, 0.0), 1.0)
            use_greedy = np.random.binomial(1, greedy_prob)
            if use_greedy:
                a = self(s).argmax().item()
            else:
                a = np.random.randint(self.n_actions)
            return a

In [13]:
class ExperienceReplay:
    """Bounded FIFO store of transitions with uniform random sampling.

    Backed by ``collections.deque(maxlen=...)``: once the buffer is full,
    appending a new item discards the oldest one.
    """

    def __init__(self, maxlen: int) -> None:
        """Create an empty buffer that holds at most ``maxlen`` items."""
        self.deque = deque(maxlen=maxlen)

    def append(self, x: tuple) -> None:
        """Add one transition, evicting the oldest if the buffer is full."""
        self.deque.append(x)

    def sample(self, bs: int) -> list:
        """Return up to ``bs`` distinct stored items, chosen uniformly at random.

        When fewer than ``bs`` items are stored, every stored item is returned
        (in random order). Relies on the standard-library ``random`` module
        being imported at module level.
        """
        n_items = min(len(self), bs)
        return random.sample(self.deque, n_items)

    def __len__(self) -> int:
        """Number of items currently stored."""
        return len(self.deque)

In [14]:
def fill(replay_memory: ExperienceReplay, env: TorchEnv) -> None:
    """Pre-populate ``replay_memory`` with transitions from a uniformly random policy.

    Episodes are played until the buffer holds ``replay_memory_max_size``
    items. The episode in progress when that size is reached is still played
    to its end.

    Args:
        replay_memory: buffer to append ``(s, a, r, s_prime, terminated)`` tuples to.
        env: environment whose ``reset``/``step`` return tensors and whose
            ``n_actions`` gives the number of discrete actions.
    """
    while len(replay_memory) < replay_memory_max_size:
        s, info = env.reset()
        terminated = truncated = False
        # End the episode on truncation as well as termination, so a finished
        # environment is never stepped again.
        while not (terminated or truncated):
            a = np.random.randint(env.n_actions)
            s_prime, r, terminated, truncated, _ = env.step(a)
            # Store only `terminated`: that flag is what zeroes the bootstrapped
            # value when targets are built from these tuples.
            replay_memory.append((s, a, r, s_prime, terminated))
            s = s_prime

In [16]:
def pole_collate(batch: list) -> tuple:
    """Turn a list of ``(s, a, r, s_prime, terminated)`` tuples into batched tensors.

    Returns:
        ``(states, actions, rewards, next_states, not_terminated)`` where the
        two state entries are stacked along a new leading dimension and
        ``not_terminated`` is a float tensor holding the inverse of each
        ``terminated`` flag.
    """
    states, actions, rewards, next_states, dones = zip(*batch)
    state_batch = torch.stack(states)
    action_batch = torch.tensor(actions)
    reward_batch = torch.tensor(rewards)
    next_state_batch = torch.stack(next_states)
    done_mask = torch.tensor(dones)
    return state_batch, action_batch, reward_batch, next_state_batch, (~done_mask).float()

def get_batch(self, batch: list, target_net:DQN=None, collate_fn:callable=pole_collate) -> tuple:
    """Compute predicted Q-values and their one-step bootstrapped targets for a batch.

    Args:
        self: the network being trained (passed explicitly; this is a plain
            function, not a method).
        batch: list of transitions in the form ``collate_fn`` expects.
        target_net: network used to value the next states; defaults to ``self``.
        collate_fn: converts ``batch`` into
            ``(s, a, r, s_prime, not_terminated)`` tensors.

    Returns:
        ``(y_hat, y_j)``: ``y_hat`` holds the Q-value ``self`` predicts for
        each taken action (gradients flow through it), and ``y_j`` holds
        ``r + gamma * max_a' target_net(s_prime)`` with the bootstrap term
        zeroed where ``not_terminated`` is 0 (no gradients).
    """
    if target_net is None:
        target_net = self

    s, a, r, s_prime, not_terminated = collate_fn(batch)
    # Pick, per row, the Q-value of the action that was actually taken.
    # squeeze(1) drops only the gathered column, so a batch of one keeps its
    # batch dimension and stays shape-compatible with y_j.
    y_hat = self(s).gather(1, a.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        next_values = target_net(s_prime).max(dim=1).values
        y_j = r + gamma * next_values * not_terminated # if terminated then not_terminated is set to zero (y_j = r)
    return y_hat, y_j

In [17]:
# Create the replay buffer at its configured capacity, then pre-fill it
# before any training step samples from it.
replay_memory = ExperienceReplay(maxlen=replay_memory_max_size)
fill(replay_memory=replay_memory, env=env)

In [18]:
# Online network and target network are built with identical layer sizes.
dqn = DQN(env.n_observations, 64, env.n_actions)
target_net = DQN(env.n_observations, 64, env.n_actions)

# Start the target network as an exact copy of the online network's weights.
target_net.load_state_dict(dqn.state_dict())

# Only the online network's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(dqn.parameters(), lr=lr)

In [None]:
# Train the online network with epsilon-greedy exploration and experience
# replay; the target network is re-synced every `sync_every_n_steps` steps.
step = 0  # environment steps taken across all episodes

# Optional wandb model watching, currently disabled.
#wandb.watch(dqn, log_freq=100)

for i in range(number_of_episodes):
    terminated = truncated = False
    # Seed only the first reset. Passing the same seed on every reset restarts
    # the environment's RNG each time, so every episode would begin from the
    # identical initial state.
    s, info = env.reset(seed=42 if i == 0 else None)
    episode_loss, episode_reward, episode_length, k  = 0, 0, 0, 0
    while not (terminated or truncated) and k < max_episode_length:
        a = dqn.select_next_action(s, epsilon)
        s_prime, r, terminated, truncated, _ = env.step(a)

        # Only `terminated` is stored: it is the flag that switches off
        # bootstrapping when targets are computed from the batch.
        replay_memory.append((s, a, r, s_prime, terminated))
        batch = replay_memory.sample(bs)

        optimizer.zero_grad()

        # `get_batch` is the target-building helper defined in this notebook.
        y_hat, y = get_batch(dqn, batch, target_net=target_net)

        loss = loss_fn(y_hat, y)
        loss.backward()
        torch.nn.utils.clip_grad_value_(dqn.parameters(), 100)
        optimizer.step()

        # Linear epsilon decay, clamped so it never drops below the 0.05 floor.
        if epsilon > 0.05:
            epsilon = max(0.05, epsilon - 1 / epsilon_annealing_steps)

        if step % sync_every_n_steps == 0:
            target_net.load_state_dict(dqn.state_dict())

        s = s_prime

        episode_loss += loss.item()
        episode_reward += r.item()
        episode_length += 1
        k += 1
        step += 1

    if i % 100 == 0:
        # Build the metrics once so wandb and the console report the same
        # numbers. The loss is the per-step mean, which stays comparable
        # between short and long episodes.
        # NOTE(review): the "eposide" key is misspelled; left unchanged here so
        # runs stay comparable with ones already logged under that key.
        metrics = {"eposide":i,
                   "episode_loss": episode_loss / k,
                   "reward": episode_reward,
                   "step":step
                  }
        wandb.log(metrics)
        print(metrics)

# Persist the trained online network and attach the file to the wandb run.
model_path = 'cartpole.pth'
torch.save(dqn.state_dict(), model_path)
wandb.log_model(name=f"cartpole-{wandb.run.id}", path=model_path)
env.close()

wandb.finish()