In [23]:
import gymnasium as gym
import torch
import wandb
import torch.nn as nn
import numpy as np
from collections import deque
import random

In [5]:
bs = 16
gamma = 0.99
epsilon = 1
lr=1e-4
ENV_NAME =  "CartPole-v1"
replay_memory_max_size = 10000
number_of_episodes = 500
sync_every_n_steps = 500
max_episode_length = 500
epsilon_annealing_steps = 1000
loss_fn = nn.SmoothL1Loss()

In [6]:
config = {
    "learning_rate": lr,
    "architecture": "DQN",
    "environment": ENV_NAME,
    "epsilon": epsilon,
    "gamma":gamma,
    "bs":bs,
    "replay_memory_max_size":replay_memory_max_size,
    "number_of_episodes":number_of_episodes,
    "max_episode_length":max_episode_length,
    "sync_every_n_steps": sync_every_n_steps,
    "epsilon_annealing_steps":epsilon_annealing_steps,
    "loss": str(loss_fn),
    }
config

{'learning_rate': 0.0001,
 'architecture': 'DQN',
 'environment': 'CartPole-v1',
 'epsilon': 1,
 'gamma': 0.99,
 'bs': 16,
 'replay_memory_max_size': 10000,
 'number_of_episodes': 500,
 'max_episode_length': 500,
 'sync_every_n_steps': 500,
 'epsilon_annealing_steps': 1000,
 'loss': 'SmoothL1Loss()'}

In [7]:
wandb.init(
    # set the wandb project where this run will be logged
    project="cartpole",
    
    # track hyperparameters and run metadata
    config=config
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mgarethmd[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
class TorchEnv:
    def __init__(self, env):
        self.env = env
        self.n_observations = self.env.observation_space.shape[0]
        self.n_actions = self.env.action_space.n
        
    def step(self, a):
        s, r, terminated, truncated, info = self.env.step(a)
        return torch.tensor(s), torch.tensor(r), terminated, truncated, info
    
    def reset(self, *args, **kwargs):
        s, info = self.env.reset(*args, **kwargs)
        return torch.tensor(s), info
    
    def close(self):
        return self.env.close()
    
env = TorchEnv(gym.make(ENV_NAME))

In [11]:
class DQN(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, n_actions: int) -> None:
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_actions)
        )
        self.n_actions = n_actions
    
    def forward(self, s: torch.tensor) -> torch.tensor:
        return self.net(s)
    
    def select_next_action(self, s: torch.tensor, epsilon: float) -> int:
        with torch.no_grad(): # no need to track gradients selecting next action
            use_greedy = np.random.binomial(1, 1-epsilon)
            if use_greedy:
                a = self(s).argmax().item()
            else:
                a = np.random.randint(self.n_actions)
            return a

In [13]:
class ExperienceReplay:
    def __init__(self, maxlen: int) -> None:
        self.deque = deque(maxlen=maxlen)
        
    def append(self, x: tuple) -> None:
        self.deque.append(x)
        
    def sample(self, bs: int) -> list:
        return random.sample(self.deque, min(len(self), bs))
        
    def __len__(self) -> int:
        return len(self.deque)

In [14]:
def fill(replay_memory: ExperienceReplay, env: TorchEnv) -> None:
    while len(replay_memory)<replay_memory_max_size:
        s, info = env.reset()
        terminated=False
        while terminated == False:
            a = np.random.randint(env.n_actions)
            s_prime, r, terminated, *_ = env.step(a)
            replay_memory.append((s, a, r, s_prime, terminated))
            s = s_prime

In [19]:
def pole_collate(batch: list) -> tuple:
    s_j, a_j, r_j, s_prime_j, terminated_j = list(zip(*batch))
    return torch.stack(s_j), torch.tensor(a_j), torch.tensor(r_j), torch.stack(s_prime_j), (~torch.tensor(terminated_j)).float()

def train_batch(self, batch: list, target_net:DQN=None, collate_fn:callable=pole_collate) -> tuple:
    if target_net is None:
        target_net = self

    s, a, r, s_prime, not_terminated = collate_fn(batch)
    y_hat = self(s).gather(1, a.unsqueeze(1)).squeeze() # gather the values at the indices given by the actions a 
    
    with torch.no_grad():
        next_values = target_net(s_prime).max(dim=1).values.clone().detach()
        y_j = r.detach().clone() + gamma * next_values * not_terminated # if terminated then not_terminated is set to zero (y_j = r)
    return y_hat, y_j

In [20]:
replay_memory = ExperienceReplay(replay_memory_max_size)
fill(replay_memory, env) 

In [21]:
dqn = DQN(in_dim=env.n_observations, hidden_dim=64, n_actions=env.n_actions)
target_net = DQN(in_dim=env.n_observations, hidden_dim=64, n_actions=env.n_actions)
target_net.load_state_dict(dqn.state_dict())
optimizer = torch.optim.Adam(dqn.parameters(),  lr=lr)

### Training the model

In [24]:
dqn.train()
step = 0

# Magic
wandb.watch(dqn, log_freq=100)

for i in range(number_of_episodes):
    terminated = False
    s, info = env.reset(seed=42)
    episode_loss, episode_reward, episode_length, k  = 0, 0, 0, 0
    while terminated == False and k < max_episode_length:
        a = dqn.select_next_action(s, epsilon)
        s_prime, r, terminated, *_ = env.step(a)
        
        replay_memory.append((s, a, r, s_prime, terminated))
        batch = replay_memory.sample(bs)
        
        optimizer.zero_grad()
        
        y_hat, y = train_batch(dqn, batch, target_net=target_net)
        
        loss = loss_fn(y_hat, y)
        loss.backward()
        torch.nn.utils.clip_grad_value_(dqn.parameters(), 100)
        optimizer.step()
        if epsilon > 0.05 :
            epsilon -= (1 / epsilon_annealing_steps)
        
        if step % sync_every_n_steps == 0:
            target_net.load_state_dict(dqn.state_dict())
            
        s = s_prime
        
        episode_loss += loss.item()
        episode_reward += r.item()
        episode_length += 1
        k += 1
        step += 1
            
    if i % 100 == 0:
        wandb.log({"eposide":i,
                    "episode_loss": episode_loss, 
                   "reward": episode_reward,
                   "step":step
                  })
        print({"eposide":i,
                    "episode_loss": episode_loss / k, 
                   "reward": episode_reward,
                   "step":step
                  })
model_path = 'cartpole.pth'
torch.save(dqn.state_dict(), model_path)
wandb.log_model(name=f"cartpole-{wandb.run.id}", path=model_path)
env.close()

wandb.finish()

{'eposide': 0, 'episode_loss': 0.48330011336426987, 'reward': 38.0, 'step': 38}
{'eposide': 100, 'episode_loss': 0.09720437689684332, 'reward': 10.0, 'step': 1331}
{'eposide': 200, 'episode_loss': 0.2295423513278365, 'reward': 11.0, 'step': 2543}
{'eposide': 300, 'episode_loss': 0.3977079735107956, 'reward': 63.0, 'step': 6247}
{'eposide': 400, 'episode_loss': 0.1679787275143836, 'reward': 74.0, 'step': 17047}


0,1
episode_loss,▆▁▁█▄
eposide,▁▃▅▆█
reward,▄▁▁▇█
step,▁▂▂▄█

0,1
episode_loss,12.43043
eposide,400.0
reward,74.0
step,17047.0


### Validation of the model

In [25]:
for seed in range(10):
    MAX_EPISODE_LENGTH = 500
    episode_loss, episode_reward, episode_length, i = 0, 0, 0, 0
    s, info = env.reset(seed=seed)
    terminated = False
    while terminated == False and i < MAX_EPISODE_LENGTH:
        a = dqn.select_next_action(s, 0)
        s_prime, r, terminated, *_ = env.step(a)
        s = s_prime
        episode_reward += r.item()
        episode_length += 1
        i += 1

    print(f'episode_length {episode_length}, reward {episode_reward}')
env.close()

episode_length 24, reward 24.0
episode_length 107, reward 107.0
episode_length 21, reward 21.0
episode_length 105, reward 105.0
episode_length 16, reward 16.0
episode_length 100, reward 100.0
episode_length 107, reward 107.0
episode_length 98, reward 98.0
episode_length 109, reward 109.0
episode_length 99, reward 99.0


In [26]:
import torch.utils.data as data

In [30]:
class ReplayMemoryDataset(data.IterableDataset):
    def __init__(self, replay_memory: ExperienceReplay, collate_fn: callable = pole_collate):
        super(ReplayMemoryDataset).__init__()
        self.replay_memory = replay_memory
        self.collate_fn = collate_fn
        
    def __iter__(self)-> iter:
        yield self.collate_fn(self.replay_memory.sample(bs))

In [31]:
dataset = ReplayMemoryDataset(replay_memory)

In [32]:
batch = next(iter(dataset))
print(batch)

(tensor([[ 1.2041,  1.6328,  0.0747, -0.0289],
        [ 0.9637,  1.6179,  0.1783,  0.3031],
        [-0.0178, -0.2128,  0.1242,  0.5707],
        [ 1.0619,  1.6325,  0.0864, -0.0222],
        [-0.0130, -0.0131,  0.1059,  0.1751],
        [ 1.7256,  1.8057,  0.1470,  0.1708],
        [ 0.0033, -0.2045,  0.0754,  0.3856],
        [ 0.1661,  1.0959,  0.1886, -0.2241],
        [ 1.2761,  2.1833,  0.1693, -0.1414],
        [ 0.4107,  1.2739,  0.1410, -0.1383],
        [ 0.3332,  1.0948,  0.0983, -0.1981],
        [-0.0510,  0.1639,  0.2072,  0.2859],
        [ 1.6384,  1.8019,  0.1644,  0.2542],
        [ 0.0812,  0.9240,  0.1505, -0.4461],
        [-0.0252, -0.0163,  0.1310,  0.2475],
        [ 1.4565,  2.1706,  0.1929,  0.1448]]), tensor([1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 1]), tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]), tensor([[ 1.2367e+00,  1.8267e+00,  7.4113e-02, -2.9711e-01],
        [ 9.9602e-01,  1.8101e+00,  1.8436e-01,  7.1548e-02],
     

In [33]:
import lightning as L

In [57]:
class LitAutoEncoder(L.LightningModule):
    def __init__(self, dqn, target_net, env, replay_memory, config):
        super().__init__()
        self.dqn = dqn
        if target_net is None:
            target_net = self
        self.target_net = target_net
        self.target_net.load_state_dict(self.dqn.state_dict())
        self.env = env
        self.replay_memory = replay_memory
        self.step = 0
        self.save_hyperparameters(config)
        

    def dqn_loss(self, batch):
        s, a, r, s_prime, not_terminated = batch
        y_hat = self.dqn(s).gather(1, a.unsqueeze(1)).squeeze() # gather the values at the indices given by the actions a 
        
        with torch.no_grad():
            next_values = self.target_net(s_prime).max(dim=1).values.clone().detach()
            y_j = r.detach().clone() + gamma * next_values * not_terminated # if terminated then not_terminated is set to zero (y_j = r)
        
        loss = nn.functional.smooth_l1_loss(y_hat, y_j)
        return loss

    def play_episode(self):
        terminated = False
        s, info = self.env.reset()
        episode_loss, episode_reward, episode_length, k  = 0, 0, 0, 0
        while terminated == False and k < self.hparams.max_episode_length:
            a = self.dqn.select_next_action(s, self.hparams.epsilon)
            s_prime, r, terminated, *_ = self.env.step(a)

            self.replay_memory.append((s, a, r, s_prime, terminated))
            if self.hparams.epsilon > 0.05 :
                self.hparams.epsilon -= (1 / self.hparams.epsilon_annealing_steps)

            if self.step % self.hparams.sync_every_n_steps == 0:
                self.target_net.load_state_dict(dqn.state_dict())
            self.step += 1
            s = s_prime

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        self.play_episode()
        loss = self.dqn_loss(batch)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.dqn.parameters(), lr=self.hparams.learning_rate)
        return optimizer


In [58]:
model = LitAutoEncoder(dqn, target_net, env, replay_memory, config)
model.hparams.learning_rate

0.0001

In [59]:
trainer = L.Trainer(max_epochs=500)
trainer.fit(model, ReplayMemoryDataset(replay_memory))

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name       | Type | Params
------------------------------------
0 | dqn        | DQN  | 4.6 K 
1 | target_net | DQN  | 4.6 K 
------------------------------------
9.2 K     Trainable params
0         Non-trainable params
9.2 K     Total params
0.037     Total estimated model params size (MB)


Epoch 0: |          | 0/? [28:09<?, ?it/s], v_num=5]        
Epoch 0: |          | 0/? [12:23<?, ?it/s]
Epoch 0: |          | 0/? [03:38<?, ?it/s]
Epoch 0: |          | 0/? [02:50<?, ?it/s]


AttributeError: 'NoneType' object has no attribute '_log'

In [38]:
config

{'learning_rate': 0.0001,
 'architecture': 'DQN',
 'environment': 'CartPole-v1',
 'epsilon': 1,
 'gamma': 0.99,
 'bs': 16,
 'replay_memory_max_size': 10000,
 'number_of_episodes': 500,
 'max_episode_length': 500,
 'sync_every_n_steps': 500,
 'epsilon_annealing_steps': 1000,
 'loss': 'SmoothL1Loss()'}