# Reinforcement Learning with PyTorch Lightning
### AJ Zerouali, 2023/06/19

Goals of this notebook:
- Get familiar with *pytorch-lightning*'s functionalities.
- Implement an RL algorithm using *pytorch-lightning*.


References:
- Tutorial on DQN (old): https://www.pytorchlightning.ai/blog/en-lightning-reinforcement-learning-building-a-dqn-with-pytorch-lightning

## 1) Deep Q-learning on cartpole

We follow the tutorial here:

https://www.pytorchlightning.ai/blog/en-lightning-reinforcement-learning-building-a-dqn-with-pytorch-lightning

GitHub:

https://github.com/Lightning-AI/lightning/blob/1.9.5/examples/pl_domain_templates/reinforce_learn_Qnet.py


In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import gym
import collections

import torch as th
import torch.nn as nn
import torch.optim as optim
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader
from torch.utils.data.dataset import IterableDataset

import pytorch_lightning as pl
# from pytorch_lightning import cli_lightning_logo, LightningModule, seed_everything, Trainer

In [2]:
import argparse
from collections import deque, namedtuple, OrderedDict
from typing import Iterator, List, Tuple

### 1.a - Usual RL classes

- Replay buffer.
- Main neural net class

In [3]:
# Neural net
class DQN(nn.Module):
    '''
        :param observation_space_dim:
        :param n_actions: No. of actions in discrete action space
        :param n_hidden: No. of units in hidden layer
    '''
    def __init__(self,
                 observation_space_dim: int,
                 n_actions: int,
                 n_hidden: int,
                ):
        
        # Init.
        super(DQN, self).__init__()
        self.net = nn.Sequential(nn.Linear(observation_space_dim, n_hidden),
                                 nn.ReLU(),
                                 nn.Linear(n_hidden, n_actions)
                                ),
        
    def forward(self, x):
        return self.net(x.float())

In [4]:
# Replay buffer
class ReplayBuffer:
    """
    Replay Buffer for storing past experiences allowing the agent to learn from them
    Args:
        capacity: size of the buffer
    """

    def __init__(self, buffer_size: int) -> None:
        self.buffer = collections.deque(maxlen=buffer_size)

    def __len__(self) -> None:
        return len(self.buffer)

    def append(self, state, action, reward, done, state_next) -> None:
        """
        Add experience to the buffer
        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        transition = (state, action, reward, done, state_next)
        self.buffer.append(transition)

    def sample(self, batch_size: int):
        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*[self.buffer[idx] for idx in indices])

        return (np.array(states), np.array(actions), np.array(rewards, dtype=np.float32),
                np.array(dones, dtype=np.bool), np.array(next_states))



##### Test this replay buffer

In [None]:
# old code
'''
        ###############################
        ##### REPLAY BUFFER CLASS #####
        ###############################
'''
class replayBuffer(object):
    def __init__(self, buffer_size, name_buffer=''):
        self.name_buffer = name_buffer
        self.buffer_size = buffer_size  #choose buffer size
        self.num_exp = 0
        self.buffer = deque()

    def add(self, s, a, r, s2, t):
        experience=(s, a, r, s2, t)
        if self.num_exp < self.buffer_size:
            self.buffer.append(experience)
            self.num_exp +=1
        else:
            self.buffer.popleft()
            self.buffer.append(experience)

    def size(self):
        return self.buffer_size

    def count(self):
        return self.num_exp

    def sample(self, batch_size):
        if self.num_exp < batch_size:
            batch = random.sample(self.buffer, self.num_exp)
        else:
            batch = random.sample(self.buffer, batch_size)

        s, a, r, s2, t = map(np.stack, zip(*batch))

        return s, a, r, s2, t

    def clear(self):
        self.buffer = deque()
        self.num_exp=0

In [10]:
# Instantiate
replay_buffer = ReplayBuffer(1024)

In [11]:
# Fill buffer with random numbers
action_space_dim = 6
observation_space_dim = 17
n_transitions = 200
# s,a ,r, done, s_
s = np.random.uniform(size = (observation_space_dim,))
for i in range(n_transitions):
    # Generate
    a = np.random.normal(size = (action_space_dim,))
    r = np.random.uniform(low=-10.0, high=10.0)
    done = (i == 200-1)
    s_ = np.random.uniform(size = (observation_space_dim,))
    # Store in buffer
    replay_buffer.append(s, a, r, done, s_)
    # Update s
    s = s_
    

In [12]:
# Get a batch:
s_batch, a_batch, r_batch, done_batch, s_next_batch = replay_buffer.sample(64)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.array(dones, dtype=np.bool), np.array(next_states))


In [13]:
s_batch.shape

(64, 17)

In [14]:
s_batch[0,:]

array([0.64549436, 0.25604908, 0.13039131, 0.75992778, 0.49294839,
       0.02917197, 0.45112308, 0.66362518, 0.96089943, 0.20814165,
       0.23920129, 0.89725009, 0.10262505, 0.35812957, 0.43796378,
       0.07583482, 0.02144812])

**Comment:** The way they implement the agent is different from what we're used to (e.g. stable_baselines3). Notably:
- The network is not one of the agent's attributes.

### 2) Dataset

Lightning forces us to work with pytorch's dataset class. Obviously, this is to enventually instantiate a dataloader for the replay buffer, and the dataset class we will use will have a *ReplayBuffer* attribute.

In [5]:
class RLDataset(IterableDataset):
    """
    Iterable Dataset containing the ReplayBuffer
    which will be updated with new experiences during training
    Args:
        buffer: replay buffer
        sample_size: number of experiences to sample at a time
    """

    def __init__(self, replay_buffer: ReplayBuffer, batch_size: int = 128) -> None:
        self.replay_buffer = replay_buffer
        self.batch_size = batch_size # What does this do?

    def __iter__(self) -> Tuple:
        states, actions, rewards, dones, new_states = self.replay_buffer.sample(self.batch_size)
        for i in range(len(dones)): # AJZ: check that this works properly
            yield states[i], actions[i], rewards[i], dones[i], new_states[i]

##### Test RLDataset

In [32]:
# Instantiate replay buffer
replay_buffer = ReplayBuffer(1024)

# Fill buffer with random numbers
action_space_dim = 6
observation_space_dim = 17
n_transitions = 200
# s,a ,r, done, s_
s = np.random.uniform(size = (observation_space_dim,))
for i in range(n_transitions):
    # Generate
    a = np.random.normal(size = (action_space_dim,))
    r = np.random.uniform(low=-10.0, high=10.0)
    done = (i == 200-1)
    s_ = np.random.uniform(size = (observation_space_dim,))
    # Store in buffer
    replay_buffer.append(s, a, r, done, s_)
    # Update s
    s = s_

In [33]:
# Instantiate RLDataset and dataloader
DQN_dataset = RLDataset(replay_buffer= replay_buffer, batch_size=64)
DQN_dataloader = DataLoader(dataset = DQN_dataset,
                            batch_size = 64,
                            sampler = None)

In [34]:
s, a, r, done, s_ = next(iter(DQN_dataloader))

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.array(dones, dtype=np.bool), np.array(next_states))


In [35]:
print(f"s.shape = {s.shape}")
print(f"a.shape = {a.shape}")
print(f"r.shape = {r.shape}")
print(f"type(done) = {type(done)}")
print(f"s_.shape = {s_.shape}")

s.shape = torch.Size([64, 17])
a.shape = torch.Size([64, 6])
r.shape = torch.Size([64])
type(done) = <class 'torch.Tensor'>
s_.shape = torch.Size([64, 17])


In [36]:
done[0:]

tensor([False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False, False,
        False, False, False, False])

### Lightning module

The main part of the implementation, where we construct a subclass of *pl.LightningModule*. A minimal *lightningModule* has the following elements:

#### 1) The constructor:

In the constructor we initialize the environment, the DQN and its target, and the replay buffer.


#### 2) The *forward()* method:

For the case of deep Q-learning, the *forward()* method of the *LightningModule* will simply wrap that of the DQN module.

#### 3) Defining the loss function - *dqn_mse_loss()*:

For convenience, we will add a taget network. Our loss will then be the MSE of the state-action value function approximated by both the target and main DQN. We add a method that outputs the loss in the *LightningModule* class.

#### 4) The *configure_optimizers()* method:

This function will just assign the DQN parameters to an Adam optimizer. The optimizer is not implemented as an attribute however.

#### 5) The *train_dataloader()* method:

One of the key parts of *pytorch-lightning*. The name is a little misleading, because this method will not train the dataloader, it will rather construct the dataloader used for training.

#### 6) *train_step()*

This key method of the *LightningModule* class contains all the instructions of a training iteration. **More comments**

What's peculiar:
- The deep Q-learning algorithm is not written once and for all in one *train()* method.
- The implementation of the tutorial relies on a *play_step()* method of the agent.
- The *grad_zero()* and *step()* instructions are not explicitly called in what we implement. I think they're left in the trainer object training loop, and this is not discussed in this DQN tutorial. For an alternative approach, see https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/sac_model.py#L28-L384.
- The SAC implementation also relies on an agent class: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/common/agents.py. This agent class is called in the algorithm implementation for 2 things: (1) In the constructor for initialization; (2) In the training and deployment methods, to call only the *get_action()* method of the agent. I think *get_action()* can simply be a method of the *LightningModule*.
- The last two pages are from the Lightning-Bolts repo, a separate extra to Lightning.
- For important parts of the *LightningModule* implementation: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html.

I think it's better to implement the agent directly as a *LightningModule*

**To do list:**

    class SAC(LightningModule):
        def __init__(
            self,
            env: str,
            eps_start: float = 1.0,
            eps_end: float = 0.02,
            eps_last_frame: int = 150000,
            sync_rate: int = 1,
            gamma: float = 0.99,
            policy_learning_rate: float = 3e-4,
            q_learning_rate: float = 3e-4,
            target_alpha: float = 5e-3,
            batch_size: int = 128,
            replay_size: int = 1000000,
            warm_start_size: int = 10000,
            avg_reward_len: int = 100,
            min_episode_reward: int = -21,
            seed: int = 123,
            batches_per_epoch: int = 10000,
            n_steps: int = 1,
            **kwargs,
        ):
            super().__init__()
            #### ADD RL ATTRIBUTES HERE ####

        def run_n_episodes(self, env, n_epsiodes: int = 1) -> List[int]:
            """Carries out N episodes of the environment with the current agent without exploration.
            Args:
                env: environment to use, either train environment or test environment
                n_epsiodes: number of episodes to run
            """

        def populate(self, warm_start: int) -> None:
            """Populates the buffer with initial experience."""

        def build_networks(self) -> None:
            """Initializes the SAC policy and q networks (with targets)"""

        def soft_update_target(self, q_net, target_net):
            """Update the weights in target network using a weighted sum.

        def forward(self, x: Tensor) -> Tensor:
            """Passes in a state x through the network and gets the q_values of each action as an output.
            Args:
                x: environment state
            Returns:
                q values
            """
            output = self.policy(x).sample()
            return output

        def train_batch(
            self,
        ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
            """Contains the logic for generating a new batch of data to be passed to the DataLoader.
            Returns:
                yields a Experience tuple containing the state, action, reward, done and next_state.
            """

        def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
            """Calculates the loss for SAC which contains a total of 3 losses.
            Args:
                batch: a batch of states, actions, rewards, dones, and next states
            """
        def training_step(self, batch: Tuple[Tensor, Tensor], _):
            """Carries out a single step through the environment to update the replay buffer. Then calculates loss
            based on the minibatch recieved.
            Args:
                batch: current mini batch of replay data
                _: batch number, not used
            """

        def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
            """Evaluate the agent for 10 episodes."""

        def test_epoch_end(self, outputs) -> Dict[str, Tensor]:
            """Log the avg of the test results."""

        def _dataloader(self) -> DataLoader:
            """Initialize the Replay Buffer dataset used for retrieving experiences."""

        def train_dataloader(self) -> DataLoader:
            """Get train loader."""
            return self._dataloader()

        def test_dataloader(self) -> DataLoader:
            """Get test loader."""
            return self._dataloader()

        def configure_optimizers(self) -> Tuple[Optimizer]:
            """Initialize Adam optimizer."""
            policy_optim = optim.Adam(self.policy.parameters(), self.hparams.policy_learning_rate)
            q1_optim = optim.Adam(self.q1.parameters(), self.hparams.q_learning_rate)
            q2_optim = optim.Adam(self.q2.parameters(), self.hparams.q_learning_rate)
            return policy_optim, q1_optim, q2_optim

        @staticmethod
        def add_model_specific_args(
            arg_parser: argparse.ArgumentParser,
        ) -> argparse.ArgumentParser:
            """Adds arguments for DQN model.
            Note:
                These params are fine tuned for Pong env.
            Args:
                arg_parser: parent parser
            """

Now for the agent class, Lightning has the following abstract class:

    class Agent(ABC):
        """Basic agent that always returns 0."""

        def __init__(self, net: nn.Module):
            self.net = net

        def __call__(self, state: Tensor, device: str, *args, **kwargs) -> List[int]:
            """Using the given network, decide what action to carry.
            Args:
                state: current state of the environment
                device: device used for current batch
            Returns:
                action
            """
            return [0]


#### Agent class

#### Algorithm class

In [11]:
class DQN(pl.LightningModule):
    
    # Constructor (CRUCIAL)
    def __init__(self,
                 train_env: gym.Env,
                 gamma: float = 0.99,
                 dqn_lr: float = 1e-4,
                 batch_size: int = 128,
                 buffer_size: int = 1000000,
                 eps_start: float = 1.0,
                 eps_end: float = 0.02,
                 eps_last_frame: int = 150000,
                 sync_rate: int = 1,
                 target_alpha: float = 5e-3,
                 n_warmup_steps: int = 10000,
                 avg_reward_len: int = 100,
                 #min_episode_reward: int = -21, # ?
                 seed: int = 101,
                 batches_per_epoch: int = 10000,
                 n_steps: int = 1,
                 **kwargs,
                ):
        '''
            Explain constructor params...
        '''
        
        # Mandatory torch call
        super(DQN, self).__init__()
        
        # Assign constructor attr
        self.train_env = train_env
        self.test_env = None
        self.observation_space_shape = self.train_env.observation_space.shape
        ### Discrete action space here
        self.n_actions = self.train_env.action_space.n #### INCORRECT
        
        # Model attributes
        ### Don't assign the dataloader as a class attribute
        self.replay_buffer = None
        self.dataset = None
        self.net = None
        self.target_net = None
        ### Build model 
        self._build_model()
        
        # Save mdoel hparams (lightning mandatory call)
        self.save_hyperparameters()
        
        
        # Metrics
        self.total_episode_steps = [0]
        self.total_rewards = [0]
        self.done_episodes = 0
        self.total_steps = 0

        # Average Rewards
        self.avg_reward_len = avg_reward_len

        for _ in range(avg_reward_len):
            self.total_rewards.append(torch.tensor(min_episode_reward, device=self.device))

        self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len :]))

        
        # Transition book-keeping
        self.state = None
        self.replay_buffer = ReplayBuffer(buffer_size = self.hparams.buffer_size)
        
        # See docs (looks important)
        self.automatic_optimization = False
        
        
    # Build the model
    def _build_model(self) -> None:
        '''
            Initializes the DQN, and target net
        '''
        
        
        # Instantiate DQN
        ### Will use this class for box2d or classic control envs,
        ### in which case: observation_space_dim = observation_space_shape[0]
        self.net = DQN(observation_space_dim = self.hparams.observation_space_shape[0], 
                       n_actions = self.hparams.n_actions, 
                       n_hidden = 128)
        self.target_net = DQN(observation_space_dim = self.hparams.observation_space_shape[0],
                              n_actions = self.hparams.n_actions, 
                              n_hidden = 128)
        ### Copy params from DQN
        self.target_net.load_state_dict(self.net.state_dict())
    
    # Get actions from DQN (necessary for RL)
    def get_action(self, state: th.Tensor, device: th.device):
        """
            Computes action from DQN output
        """
        if not isinstance(state, th.Tensor):
            state = th.tensor(state, device=device)

        
        q_val = self.net(state)
        q_val_max, action_star = th.max(q_val, dim=1)
        return action_star.detach().cpu().numpy()
    
    
    # Run n_episodes (for testing)
    def run_n_episodes(self, 
                       test_env: gym.Env,
                       n_episodes: int = 1,
                       max_episode_length: int = 10000,
                      ):
        """
            Runs a number of episodes in a test environment without exploration.
            Actions are obtained from current DQN using self.get_actions().
            Called by the test_step() method.
            
            :param env: environment to use, either train environment or test environment
            :param n_episodes: number of episodes to run
            :param max_episode_length: Maximal num. of steps per episode
            
            :return total_rewards_hist:
        """
        # Init rwrds list
        total_rewards_hist = []
        
        # Main loop
        for i in range(n_episodes):
            
            # Initializations
            state = test_env.reset()
            done = False
            episode_tot_rwrd = 0
            step = 0
            
            # Episodic loop
            while not done and (step<max_episode_length):
                action = self.get_action(state, self.device)
                ## NOTE: gym >= 0.26.2 and gymnasium
                state_next, reward, done, truncated, _ \
                    = test_env.step(action[0])
                episode_tot_rwrd += reward
                state = state_next
                step+=1
            
            # Append episode total reward to hist
            total_rewards_hist.append(episode_tot_rwrd)
        
        # Output
        return total_rewards_hist
    
    # Necessary?
    def populate(self, n_warmup_steps: int) -> None:
        """
            Populates the replay buffer with the specified number
            of warmup transitions in the training environment.
            Resets env if done to continue warmup, and uses epsilon
            greedy policy.
            
            :param n_warmup_steps: Num. of warmup steps to make
        """
        if n_warmup_steps>0:
            self.state = self.train_env.reset()
            
            for i in range(n_warmup_steps):
                
                # Get action following eps-greedy policy
                ### Q: Where do we initialize self.epsilon?
                if np.random.random() < self.epsilon:
                    ### WARNING:Review action shape
                    action = np.array([self.test_env.action_space.sample()])
                else:
                    action = self.get_action(self.state, self.device)
                
                state_next, reward, done, truncated, _ \
                    = self.train_env.step(action[0])
                self.replay_buffer.append(self.state, action, reward, done, state_next)
                self.state = state_next
                
                if done:
                    self.state = self.train_env.reset()
            
            
    
    # Training dataloader assignment (CRUCIAL)
    def train_dataloader(self) -> DataLoader:
        """
            Initialize training dataloader
        """
        self.dataset = RLDataset(self.replay_buffer)
        return DataLoader(dataset = self.dataset, batch_size = self.hparams.batch_size)
    
    # Forward (CRUCIAL)
    def forward(self, x: th.Tensor) -> th.Tensor:
        """
            Passes in a state x through the network and gets the q_values of each action as an output.
            # NOTE: For an actor-critic algorithm with a stochastic policy, this function shou
        """
        output = self.net(x)
        return output
    
    # Optimizers' initialization f'n (CRUCIAL)
    def configure_optimizers(self) -> List[Optimizer]:
        """
            Initialize optimizers for all class networks.
            # NOTE: Should return a list or a tuple of Optimizer objects.
        """
        optimizer = optim.Adam(self.net.parameters(), lr = self.hparams.dqn_lr)
        return [optimizer]
    
    # Compute loss (CRUCIAL)
    def loss(self, batch: Tuple[th.Tensor, th.Tensor, th.Tensor, th.Tensor, th.Tensor])-> th.Tensor:
        """
            Method to calculate the loss value
            # NOTE: Review this part. I hardly recognize what this does.
        """
        state_b, action_b, reward_b, done_b, state_next_b = batch
        
        state_action_values = self.net(state_b).gather(1, 
                                                       action_b.unsqueeze(-1)).squeeze(-1)
        
        # Evaluate target 
        with th.no_grad():
            state_next_q_vals = self.target_net(state_next_b).max(1)[0]
            state_next_q_vals[done_b] = 0.0
            state_next_q_vals = state_next_q_vals.detach()
        
        # Bellman backup
        expected_state_action_vals = self.hparams.gamma*state_next_q_vals+reward_b
        
        # Output
        return nn.MSELoss()(state_action_values, expected_state_action_vals)
    
    # Training step (CRUCIAL)
    def training_step(self, batch):
        """
            Need to add comments here. Seems to be the the crux of the implementation.
            Where is this executed.
        """
        
        # Get optimizer
        dqn_optim = self.optimizers()
                
        # Get training batch
        
        # Compute loss
        loss = self.loss(batch)
        ### Clarify this condition
        #if self.trainer.use_dp or self.trainer.use_ddp2:
        
        # Gradient step
        ## NOTE: You use a manual backward here
        ## Important for actor-critic algos
        dqn_optim.zero_grad()
        self.manual_backward(loss)
        dqn_optim.step()
        
        # Update target net
        if self.global_step % self.hparams.sync_rate == 0:
            self.target_net.load_state_dict(self.net.state_dict())
        
        # Log dict
        self.log_dict(
            {
                "total_reward": self.total_rewards_hist[-1],
                "avg_reward": self.avg_rewards,
                "train_loss": loss,
                "episodes": self.done_episodes,
                "episode_steps": self.total_episode_steps[-1],
            }
        )
        
        # Output
        return OrderedDict({"loss": loss, "avg_reward": self.avg_rewards})

In [None]:
    '''
    # Initialize replay buffer
    def _init_replay_buffer(self):
        '''
            #Initializes the replay buffer
        '''
        
        # Instantiate replay buffer
        self.replay_buffer = ReplayBuffer(self.hparams.buffer_size)
        
        # Add "n_warmup_steps" transitions to the buffer
        self.populate(self.hparams.n_warmup_steps)
    '''

In [7]:
?th.max

[0;31mDocstring:[0m
max(input) -> Tensor

Returns the maximum value of all elements in the ``input`` tensor.

    This function produces deterministic (sub)gradients unlike ``max(dim=0)``

Args:
    input (Tensor): the input tensor.

Example::

    >>> a = torch.randn(1, 3)
    >>> a
    tensor([[ 0.6763,  0.7445, -2.2369]])
    >>> torch.max(a)
    tensor(0.7445)

.. function:: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor)
   :noindex:

Returns a namedtuple ``(values, indices)`` where ``values`` is the maximum
value of each row of the :attr:`input` tensor in the given dimension
:attr:`dim`. And ``indices`` is the index location of each maximum value found
(argmax).

If ``keepdim`` is ``True``, the output tensors are of the same size
as ``input`` except in the dimension ``dim`` where they are of size 1.
Otherwise, ``dim`` is squeezed (see :func:`torch.squeeze`), resulting
in the output tensors having 1 fewer dimension than ``input``.

.. note:: If there are mult