# Discrete Action Space

In [1]:
import torch.nn as nn
import torch
from torch.optim import AdamW
import torch.nn.functional as F
from torch.utils.data import Dataset, IterableDataset,DataLoader
import lightning as L
import gymnasium as gym
from gymnasium.wrappers import NormalizeObservation, NormalizeReward, RecordVideo, RecordEpisodeStatistics
import numpy as np
from lightning.pytorch.loggers import TensorBoardLogger

In [2]:
# Gymnasium environment id used for both training and evaluation.
ENV_ID = 'CartPole-v1'

# Output folders, relative to the notebook's working directory.
VIDEO_DIR = '../videos/'  # evaluation recordings
LOG_DIR = '../tboard/'    # TensorBoard event files

In [3]:
# Rollout collection
NUM_ENVS = 25            # parallel copies inside the vectorized environment
MAX_STEP = 5_000         # environment steps collected per rollout
DISCOUNT_FACTOR = 0.99   # weight applied to future rewards

# Optimisation
MAX_EPOCHS = 100
BATCH_SIZE = 1_024
LR = 1e-4

In [4]:
def create_env(env_name, num_envs):
  """Build the vectorized training environment.

  Creates `num_envs` synchronous copies of `env_name` and wraps them, innermost
  first, with RecordEpisodeStatistics, NormalizeObservation and NormalizeReward.

  Returns:
    The outermost wrapper (NormalizeReward).
  """
  wrapped = gym.vector.make(env_name, num_envs=num_envs, asynchronous=False)
  for wrapper in (RecordEpisodeStatistics, NormalizeObservation, NormalizeReward):
    wrapped = wrapper(wrapped)
  return wrapped

## Policy Model

In [5]:
from typing import Any

class Policy(nn.Module):
    """Two-hidden-layer MLP that maps observations to action probabilities.

    Args:
        num_features: size of one observation vector.
        num_actions: number of discrete actions (size of the output layer).
        hidden_size: width of both hidden layers.
    """

    def __init__(self, num_features, num_actions,hidden_size=128) -> None:
        super().__init__()
        self.input = nn.Linear(in_features=num_features, out_features=hidden_size)
        self.hidden = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.output = nn.Linear(in_features=hidden_size, out_features=num_actions)
        # Index of every action; not read by any method of this class.
        self.actions = np.arange(num_actions)

    def _as_input(self, x):
        """Return `x` unchanged if it is a tensor, else a float32 tensor on the model's device."""
        if torch.is_tensor(x):
            return x
        return torch.as_tensor(x, dtype=torch.float32, device=self.input.weight.device)

    def forward(self,x):
        """Compute action probabilities.

        Args:
            x: tensor, or array-like convertible to a float32 tensor, whose last
                dimension has `num_features` entries.

        Returns:
            Tensor with the same leading dimensions as `x` and `num_actions`
            entries in the last dimension, each slice summing to 1 (softmax).
        """
        x = self._as_input(x)
        x = F.relu(self.input(x))
        x = F.relu(self.hidden(x))
        return F.softmax(self.output(x), dim=-1)

    @torch.no_grad()
    def pi(self,state):
        """Sample one action per state from the current policy.

        Args:
            state: a single state of shape (num_features,) or a batch of shape
                (batch, num_features); tensor or array-like.

        Returns:
            NumPy integer array of sampled action indices: 0-d for a single
            state, shape (batch,) for a batch.
        """
        probs = self.forward(state)
        sampled = torch.multinomial(probs, 1)
        # Drop only the trailing sample axis. A bare squeeze() would also drop a
        # batch axis of length 1 and hand a 0-d array to a vectorized env.
        return sampled.squeeze(-1).cpu().numpy()


    # def __call__(self, state) -> Any:
    #     size = state.shape[0]
    #     action = np.random.choice(self.actions, size=size)
    #     return action

## Dataset

In [6]:
class MyDataset(IterableDataset):
    """Iterable dataset of (state, action, discounted return) samples from fresh rollouts.

    Each call to `__iter__` resets `env`, steps it `max_step` times with actions
    chosen by `policy`, computes the discounted return of every step, and yields
    the collected samples one by one in random order.

    Args:
        env: vectorized environment exposing `reset()`, `step(action)` and
            `unwrapped.num_envs`.
        max_step: number of `env.step` calls per rollout.
        policy: callable mapping a batch of states to a batch of actions.
        discount_factor: discount applied to future rewards.
    """

    def __init__(self,env,max_step,policy,discount_factor):
        super().__init__()
        self.env = env
        self.max_step = max_step
        self.policy = policy
        self.discount_factor = discount_factor

    def __iter__(self):
        """Collect one rollout and yield `(state, action, return)` tuples in shuffled order."""
        states, actions, rewards, episode_ends = [], [], [], []

        state, _ = self.env.reset()
        for _ in range(self.max_step):
            action = self.policy(state)
            next_state, reward, terminated, truncated, _ = self.env.step(action)

            states.append(state)
            actions.append(action)
            rewards.append(reward)
            # An episode is over when it terminates OR is truncated (e.g. a time
            # limit). Ignoring truncation would let returns leak across the
            # boundary into whatever episode follows in the same env slot.
            episode_ends.append(
                np.asarray(np.logical_or(terminated, truncated), dtype=np.float64)
            )
            state = next_state

        # Nothing was collected (max_step <= 0): yield no samples.
        if not states:
            return

        # Walk the rollout backwards: G_t = r_t + gamma * G_{t+1}, with the
        # future term dropped at episode ends. The return after the last
        # collected step is taken to be zero.
        num_steps = len(rewards)
        returns = [None] * num_steps
        next_return = np.zeros(self.env.unwrapped.num_envs)
        for t in reversed(range(num_steps)):
            continuing = 1.0 - episode_ends[t]
            next_return = rewards[t] + continuing * self.discount_factor * next_return
            returns[t] = next_return

        # Flatten the (step, env) axes so each row is one sample.
        states = np.concatenate(states, axis=0).astype(np.float32)
        returns = np.concatenate(returns, axis=0).astype(np.float32)
        actions = np.concatenate(actions, axis=0).astype(np.int64)

        indices = np.arange(returns.shape[0])
        np.random.shuffle(indices)

        for i in indices:
            yield states[i], actions[i], returns[i]
            

## Utility Functions

In [7]:
from base64 import b64encode
from IPython.display import HTML

def test_env(env_name, policy, obs_rms, num_episodes=10):
  """Run `policy` for `num_episodes` episodes, recording a video of every episode to VIDEO_DIR.

  Args:
    env_name: Gymnasium environment id.
    policy: callable mapping one observation to one action.
    obs_rms: observation running statistics assigned to the NormalizeObservation
      wrapper, so evaluation uses the statistics passed in by the caller.
    num_episodes: number of episodes to play (defaults to 10).
  """
  env = gym.make(env_name, render_mode='rgb_array')
  env = RecordVideo(env, VIDEO_DIR, episode_trigger=lambda e: True)
  env = NormalizeObservation(env)
  # NOTE(review): the same statistics object is shared with the caller; the
  # wrapper may keep updating it in place during evaluation - confirm.
  env.obs_rms = obs_rms

  try:
    for _episode in range(num_episodes):
      obs, _ = env.reset()
      done = False
      while not done:
        action = policy(obs)
        obs, _, terminated, truncated, _ = env.step(action)
        # Stop on truncation too (e.g. a time limit); otherwise a policy that
        # never fails would keep stepping a finished episode.
        done = terminated or truncated
  finally:
    # Close the env even if the rollout raises.
    env.close()


def display_video(episode=0):
  """Return an inline HTML video player for the recording of `episode`.

  Reads `rl-video-episode-<episode>.mp4` from VIDEO_DIR and embeds it as a
  base64 data URL, so the player does not depend on the file afterwards.

  Args:
    episode: index of the recorded episode to show.

  Raises:
    FileNotFoundError: if no recording exists for `episode`.
  """
  # Open read-only and close the handle as soon as the bytes are loaded.
  with open(f'{VIDEO_DIR}/rl-video-episode-{episode}.mp4', "rb") as video_file:
    video_bytes = video_file.read()
  video_url = f"data:video/mp4;base64,{b64encode(video_bytes).decode()}"
  return HTML(f"<video width=600 controls><source src='{video_url}'></video>")

## Training Model

In [8]:
class Reinforce(L.LightningModule):
    """REINFORCE policy-gradient trainer with an entropy bonus.

    Owns the vectorized environment and the policy network. Training data comes
    from `MyDataset`, which rolls out the current policy in that environment.

    Args:
        env_id: Gymnasium environment id passed to `create_env`.
        num_envs: number of parallel environment copies.
        lr: learning rate for AdamW.
        entropy_coeff: weight of the entropy bonus in the loss.
        hidden_size: width of the policy's hidden layers.
        discount_factor: discount used when computing returns.
        max_step: environment steps collected per rollout.
        batch_size: mini-batch size of the training dataloader.
    """

    def __init__(self,env_id, num_envs,lr = 1e-3, entropy_coeff=0.01, hidden_size=64, discount_factor=0.99, max_step=100, batch_size=64):
        super().__init__()
        self.env = create_env(env_name=env_id,num_envs=num_envs)
        num_features = self.env.unwrapped.single_observation_space.shape[0]
        num_actions = self.env.unwrapped.single_action_space.n
        self.model=Policy(num_features, num_actions,hidden_size=hidden_size)
        self.lr = lr
        # Exposes the constructor arguments as self.hparams.<name>.
        self.save_hyperparameters()

    def training_step(self, batch, batch_idx):
        """Compute the REINFORCE loss for one mini-batch of (state, action, return).

        The loss is the mean of `-return * log pi(action | state)` minus
        `entropy_coeff` times the policy entropy.
        """
        state, action, returns = batch

        # Column vectors, so gather() and the broadcasting below line up per sample.
        action = action.reshape(-1, 1)
        returns = returns.reshape(-1, 1)

        probs = self.model(state)
        # The small constant keeps log() finite when a probability is 0.
        log_probs = torch.log(probs + 1e-6)
        entropy = -torch.sum(probs * log_probs, dim=-1, keepdim=True)
        chosen_log_prob = log_probs.gather(1, action)
        policy_loss = -returns * chosen_log_prob
        loss = (policy_loss - self.hparams.entropy_coeff * entropy).mean()
        self.log("episode/Train Loss", loss)
        return loss

    def on_train_epoch_end(self):
        """Log the return of the most recently finished episode, if there is one."""
        finished_returns = self.env.return_queue
        # The queue is empty until an episode ends; indexing it would raise IndexError.
        if len(finished_returns) > 0:
            self.log("episode/Return", float(finished_returns[-1]))

    def train_dataloader(self):
        """Return a dataloader over rollouts collected with the current policy."""
        train_ds = MyDataset(
            env=self.env,
            discount_factor=self.hparams.discount_factor,
            max_step=self.hparams.max_step,
            policy=self.model.pi,
        )
        return DataLoader(train_ds, batch_size=self.hparams.batch_size)

    def configure_optimizers(self):
        """Optimise the policy parameters with AdamW at the configured learning rate."""
        return AdamW(self.model.parameters(), lr=self.lr)

In [9]:
# Build the REINFORCE module from the notebook-level hyperparameters.
reinforce = Reinforce(
    env_id=ENV_ID,
    num_envs=NUM_ENVS,
    lr=LR,
    discount_factor=DISCOUNT_FACTOR,
    max_step=MAX_STEP,
    batch_size=BATCH_SIZE,
)

  gym.logger.warn(
  logger.warn(


In [10]:
# Show the module's repr to inspect the layers of the policy network.
reinforce

Reinforce(
  (model): Policy(
    (input): Linear(in_features=4, out_features=64, bias=True)
    (hidden): Linear(in_features=64, out_features=64, bias=True)
    (output): Linear(in_features=64, out_features=2, bias=True)
  )
)

In [11]:
# CPU-only trainer; metrics are written as TensorBoard event files under LOG_DIR.
trainer = L.Trainer(
    logger=TensorBoardLogger(save_dir=LOG_DIR),
    max_epochs=MAX_EPOCHS,
    accelerator='cpu',
)

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [12]:
# Start from empty TensorBoard-log and video folders: delete both, then recreate them.
# NOTE(review): `rm -r` prints an error when a folder does not exist yet (e.g. on a
# first run); the following `mkdir` lines still create the folders.
!rm -r ../tboard/
!rm -r ../videos/
!mkdir ../tboard/
!mkdir ../videos/

In [13]:
# Load the TensorBoard notebook extension and open the dashboard on the log folder.
%load_ext tensorboard
%tensorboard --logdir ../tboard/

Reusing TensorBoard on port 6006 (pid 935941), started 10:35:35 ago. (Use '!kill 935941' to kill it.)

In [14]:
# Train the policy; the module supplies its own dataloader via train_dataloader().
trainer.fit(reinforce)

Missing logger folder: ../tboard/lightning_logs

  | Name  | Type   | Params
---------------------------------
0 | model | Policy | 4.6 K 
---------------------------------
4.6 K     Trainable params
0         Non-trainable params
4.6 K     Total params
0.018     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  logger.warn(
`Trainer.fit` stopped: `max_epochs=100` reached.


In [15]:
# Record evaluation episodes with the trained policy, reusing the observation
# statistics gathered by the training environment.
test_env(
    env_name=ENV_ID,
    policy=reinforce.model.pi,
    obs_rms=reinforce.env.obs_rms,
)

  logger.warn(
  logger.warn(


Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-0.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-0.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-0.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-1.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-1.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-1.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-2.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-2.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-2.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-3.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-3.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-3.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-4.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-4.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-4.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-5.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-5.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-5.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-6.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-6.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-6.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-7.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-7.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-7.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-8.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-8.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-8.mp4
Moviepy - Building video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-9.mp4.
Moviepy - Writing video /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-9.mp4



                                                               

Moviepy - Done !
Moviepy - video ready /home/daniel/src/advanced_rl_pg_methods_complete/videos/rl-video-episode-9.mp4




In [17]:
# Embed the recording of the first evaluation episode.
display_video(0)