# Pytorch lightning bolts SAC algo

Original file: 
https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/sac_model.py#L28-L384

Imports:

In [None]:
import argparse
from typing import Dict, List, Tuple

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor, optim
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

In [None]:
'''
DONE from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset
from pl_bolts.models.rl.common.agents import SoftActorCriticAgent
from pl_bolts.models.rl.common.memory import MultiStepBuffer
from pl_bolts.models.rl.common.networks import MLP, ContinuousMLP
'''

## *Experience* and *ExperienceSourceDataset*

Source: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/datamodules/experience_source.py

This part replaces:

In [None]:
'''
from pl_bolts.datamodules.experience_source import Experience, ExperienceSourceDataset
'''

In [1]:
from abc import ABC
from collections import deque, namedtuple
from typing import Callable, Iterator, List, Tuple

import torch
from torch.utils.data import IterableDataset

Experience = namedtuple("Experience", field_names=["state", "action", "reward", "done", "new_state"])


class ExperienceSourceDataset(IterableDataset):
    """Basic experience source dataset.
    Takes a generate_batch function that returns an iterator. The logic for the experience source and how the batch is
    generated is defined the Lightning model itself
    """

    def __init__(self, generate_batch: Callable) -> None:
        self.generate_batch = generate_batch

    def __iter__(self) -> Iterator:
        iterator = self.generate_batch()
        return iterator

## *MultipStepBuffer*

Source: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/common/memory.py

This part replaces:

In [None]:
'''
from pl_bolts.models.rl.common.memory import MultiStepBuffer
'''

In [2]:
# Named tuple for storing experience steps gathered in training
import collections
from collections import deque, namedtuple
from typing import List, Tuple, Union

import numpy as np

Experience = namedtuple("Experience", field_names=["state", "action", "reward", "done", "new_state"])


class Buffer:
    """Basic Buffer for storing a single experience at a time."""

    def __init__(self, capacity: int) -> None:
        """
        Args:
            capacity: size of the buffer
        """
        self.buffer = deque(maxlen=capacity)

    def __len__(self) -> None:
        return len(self.buffer)

    def append(self, experience: Experience) -> None:
        """Add experience to the buffer.
        Args:
            experience: tuple (state, action, reward, done, new_state)
        """
        self.buffer.append(experience)

    # pylint: disable=unused-argument
    def sample(self, *args) -> Union[Tuple, List[Tuple]]:
        """
        returns everything in the buffer so far it is then reset
        Returns:
            a batch of tuple np arrays of state, action, reward, done, next_state
        """
        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in range(self.__len__())))

        self.buffer.clear()

        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.bool),
            np.array(next_states),
        )


class ReplayBuffer(Buffer):
    """Replay Buffer for storing past experiences allowing the agent to learn from them."""

    def sample(self, batch_size: int) -> Tuple:
        """Takes a sample of the buffer.
        Args:
            batch_size: current batch_size
        Returns:
            a batch of tuple np arrays of state, action, reward, done, next_state
        """

        indices = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, dones, next_states = zip(*(self.buffer[idx] for idx in indices))

        return (
            np.array(states),
            np.array(actions),
            np.array(rewards, dtype=np.float32),
            np.array(dones, dtype=np.bool),
            np.array(next_states),
        )


class MultiStepBuffer(ReplayBuffer):
    """N Step Replay Buffer."""

    def __init__(self, capacity: int, n_steps: int = 1, gamma: float = 0.99) -> None:
        """
        Args:
            capacity: max number of experiences that will be stored in the buffer
            n_steps: number of steps used for calculating discounted reward/experience
            gamma: discount factor when calculating n_step discounted reward of the experience being stored in buffer
        """
        super().__init__(capacity)

        self.n_steps = n_steps
        self.gamma = gamma
        self.history = deque(maxlen=self.n_steps)
        self.exp_history_queue = deque()

    def append(self, exp: Experience) -> None:
        """Add experience to the buffer.
        Args:
            exp: tuple (state, action, reward, done, new_state)
        """
        self.update_history_queue(exp)  # add single step experience to history
        while self.exp_history_queue:  # go through all the n_steps that have been queued
            experiences = self.exp_history_queue.popleft()  # get the latest n_step experience from queue

            last_exp_state, tail_experiences = self.split_head_tail_exp(experiences)

            total_reward = self.discount_rewards(tail_experiences)

            n_step_exp = Experience(
                state=experiences[0].state,
                action=experiences[0].action,
                reward=total_reward,
                done=experiences[0].done,
                new_state=last_exp_state,
            )

            self.buffer.append(n_step_exp)  # add n_step experience to buffer

    def update_history_queue(self, exp) -> None:
        """Updates the experience history queue with the lastest experiences. In the event of an experience step is
        in the done state, the history will be incrementally appended to the queue, removing the tail of the
        history each time.
        Args:
            env_idx: index of the environment
            exp: the current experience
            history: history of experience steps for this environment
        """
        self.history.append(exp)

        # If there is a full history of step, append history to queue
        if len(self.history) == self.n_steps:
            self.exp_history_queue.append(list(self.history))

        if exp.done:
            if 0 < len(self.history) < self.n_steps:
                self.exp_history_queue.append(list(self.history))

            # generate tail of history, incrementally append history to queue
            while len(self.history) > 2:
                self.history.popleft()
                self.exp_history_queue.append(list(self.history))

            # when there are only 2 experiences left in the history,
            # append to the queue then update the env stats and reset the environment
            if len(self.history) > 1:
                self.history.popleft()
                self.exp_history_queue.append(list(self.history))

            # Clear that last tail in the history once all others have been added to the queue
            self.history.clear()

    def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]:
        """Takes in a tuple of experiences and returns the last state and tail experiences based on if the last
        state is the end of an episode.
        Args:
            experiences: Tuple of N Experience
        Returns:
            last state (Array or None) and remaining Experience
        """
        last_exp_state = experiences[-1].new_state
        tail_experiences = experiences

        if experiences[-1].done and len(experiences) <= self.n_steps:
            tail_experiences = experiences

        return last_exp_state, tail_experiences

    def discount_rewards(self, experiences: Tuple[Experience]) -> float:
        """Calculates the discounted reward over N experiences.
        Args:
            experiences: Tuple of Experience
        Returns:
            total discounted reward
        """
        total_reward = 0.0
        for exp in reversed(experiences):
            total_reward = (self.gamma * total_reward) + exp.reward
        return total_reward


## *MLP* and *ContinuousMLP*

Source: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/common/networks.py

Imports a custom MultiVariateNormal: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/common/distributions.py

This part replaces:

In [None]:
'''
from pl_bolts.models.rl.common.networks import MLP, ContinuousMLP
'''

In [59]:
import math
from typing import Tuple

import numpy as np
import torch
from torch import FloatTensor, Tensor, nn
from torch.distributions import Categorical, Normal, MultivariateNormal
from torch.nn import functional as F

class TanhMultivariateNormal(torch.distributions.MultivariateNormal):
    """The distribution of X is an affine of tanh applied on a normal distribution.
    X = action_scale * tanh(Z) + action_bias
    Z ~ Normal(mean, variance)
    
    AJ Zerouali, 23/06/21: They forgot about the devices
    """

    def __init__(self, action_bias, action_scale, **kwargs):
        super().__init__(**kwargs)

        self.action_bias = action_bias
        self.action_scale = action_scale

    def rsample_with_z(self, sample_shape=torch.Size()):
        """Samples X using reparametrization trick with the intermediate variable Z.
        Returns:
            Sampled X and Z
        """
        z = super().rsample()
        '''
        # DEBUG
        print(f"z.device = {z.device}")
        print(f"type(z) = {type(z)}")
        print(f"self.action_scale.device = {self.action_scale.device}")
        print(f"self.action_bias.device = {self.action_bias.device}")
        #print(f"next(self.parameters()).is_cuda = {next(self.parameters()).is_cuda}")
        '''
        
        action_scale = torch.Tensor(self.action_scale).to(z.device)
        action_bias = torch.Tensor(self.action_bias).to(z.device)
        
        output = (action_scale * torch.tanh(z) + action_bias, z)
        
        return output

    def log_prob_with_z(self, value, z):
        """Computes the log probability of a sampled X.
        Refer to the original paper of SAC for more details in equation (20), (21)
        Args:
            value: the value of X
            z: the value of Z
        Returns:
            Log probability of the sample
        """
        action_scale = torch.Tensor(self.action_scale).to(z.device)
        action_bias = torch.Tensor(self.action_bias).to(z.device)
        
        value = (value - action_bias) / action_scale
        z_logprob = super().log_prob(z)
        correction = torch.log(action_scale * (1 - value ** 2) + 1e-7).sum(1)
        return z_logprob - correction

    def rsample_and_log_prob(self, sample_shape=torch.Size()):
        """Samples X and computes the log probability of the sample.
        Returns:
            Sampled X and log probability
        """
        
        z = super().rsample()
        z_logprob = super().log_prob(z)
        value = torch.tanh(z)
        
        action_scale = torch.Tensor(self.action_scale).to(z.device)
        action_bias = torch.Tensor(self.action_bias).to(z.device)        
        
        correction = torch.log(action_scale * (1 - value ** 2) + 1e-7).sum(1)
        return action_scale * value + action_bias, z_logprob - correction

    def rsample(self, sample_shape=torch.Size()):
        fz, z = self.rsample_with_z(sample_shape)
        return fz

    def log_prob(self, value):
        
        action_scale = torch.Tensor(self.action_scale).to(value.device)
        action_bias = torch.Tensor(self.action_bias).to(value.device)
        
        value = (value - action_bias) / action_scale
        z = torch.log(1 + value) / 2 - torch.log(1 - value) / 2
        return self.log_prob_with_z(value, z)



In [54]:
class MLP(nn.Module):
    """Simple MLP network."""

    def __init__(self, input_shape: Tuple[int], n_actions: int, hidden_size: int = 128):
        """
        Args:
            input_shape: observation shape of the environment
            n_actions: number of discrete actions available in the environment
            hidden_size: size of hidden layers
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_shape[0], hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, n_actions),
        )

    def forward(self, input_x):
        """Forward pass through network.
        Args:
            x: input to network
        Returns:
            output of network
        """
        return self.net(input_x.float())

class ContinuousMLP(nn.Module):
    """MLP network that outputs continuous value via Gaussian distribution."""

    def __init__(
        self,
        input_shape: Tuple[int],
        n_actions: int,
        hidden_size: int = 128,
        action_bias: int = 0,
        action_scale: int = 1,
    ):
        """
        Args:
            input_shape: observation shape of the environment
            n_actions: dimension of actions in the environment
            hidden_size: size of hidden layers
            action_bias: the center of the action space
            action_scale: the scale of the action space
        """
        super().__init__()
        self.action_bias = action_bias
        self.action_scale = action_scale

        self.shared_net = nn.Sequential(
            nn.Linear(input_shape[0], hidden_size), nn.ReLU(), nn.Linear(hidden_size, hidden_size), nn.ReLU()
        )
        self.mean_layer = nn.Linear(hidden_size, n_actions)
        self.logstd_layer = nn.Linear(hidden_size, n_actions)

    def forward(self, x: FloatTensor) -> TanhMultivariateNormal:
        """Forward pass through network. Calculates the action distribution.
        Args:
            x: input to network
        Returns:
            action distribution
        """
        # DEBUG
        #print(f"x.device = {x.device}")
        #print(f"next(self.parameters()).is_cuda = {next(self.parameters()).is_cuda}")
        
        x = self.shared_net(x.float())
        batch_mean = self.mean_layer(x)
        logstd = torch.clamp(self.logstd_layer(x), -20, 2)
        batch_scale_tril = torch.diag_embed(torch.exp(logstd))
        output = TanhMultivariateNormal(action_bias=self.action_bias, 
                                        action_scale=self.action_scale, 
                                        loc=batch_mean, 
                                        scale_tril=batch_scale_tril,)
        return output

    def get_action(self, x: FloatTensor) -> Tensor:
        """Get the action greedily (without sampling)
        Args:
            x: input to network
        Returns:
            mean action
        """
        x = self.shared_net(x.float())
        batch_mean = self.mean_layer(x)
        return self.action_scale * torch.tanh(batch_mean) + self.action_bias

## *Agent* and *SoftActorCriticAgent*

Source: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/common/agents.py

This part replaces:

In [None]:
'''
from pl_bolts.models.rl.common.agents import SoftActorCriticAgent
'''

In [4]:
from abc import ABC
from typing import List

import numpy as np
import torch
from torch import Tensor, nn
from torch.nn import functional as F


class Agent(ABC):
    """Basic agent that always returns 0."""

    def __init__(self, net: nn.Module):
        self.net = net

    def __call__(self, state: Tensor, device: str, *args, **kwargs) -> List[int]:
        """Using the given network, decide what action to carry.
        Args:
            state: current state of the environment
            device: device used for current batch
        Returns:
            action
        """
        return [0]

class SoftActorCriticAgent(Agent):
    """Actor-Critic based agent that returns a continuous action based on the policy."""

    def __call__(self, states: Tensor, device: str) -> List[float]:
        """Takes in the current state and returns the action based on the agents policy.
        Args:
            states: current state of the environment
            device: the device used for the current batch
        Returns:
            action defined by policy
        """
        if not isinstance(states, list):
            states = [states]

        if not isinstance(states, Tensor):
            states = torch.tensor(states, device=device)

        dist = self.net(states)
        actions = [a for a in dist.sample().cpu().numpy()]

        return actions

    def get_action(self, states: Tensor, device: str) -> List[float]:
        """Get the action greedily (without sampling)
        Args:
            states: current state of the environment
            device: the device used for the current batch
        Returns:
            action defined by policy
        """
        if not isinstance(states, list):
            states = [states]

        if not isinstance(states, Tensor):
            states = torch.tensor(states, device=device)

        actions = [self.net.get_action(states).cpu().numpy()]

        return actions

## Soft Actor-Critic Algo implementation:

Source: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/sac_model.py

In [5]:
import argparse
from typing import Dict, List, Tuple

import numpy as np
import torch
import gym
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor, optim
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader



In [6]:
class SAC(LightningModule):
    def __init__(
        self,
        env: str,
        eps_start: float = 1.0,
        eps_end: float = 0.02,
        eps_last_frame: int = 150000,
        sync_rate: int = 1,
        gamma: float = 0.99,
        policy_learning_rate: float = 3e-4,
        q_learning_rate: float = 3e-4,
        target_alpha: float = 5e-3,
        batch_size: int = 128,
        replay_size: int = 1000000,
        warm_start_size: int = 10000,
        avg_reward_len: int = 100,
        min_episode_reward: int = -21,
        seed: int = 123,
        batches_per_epoch: int = 10000,
        n_steps: int = 1,
        **kwargs,
    ):
        super().__init__()

        # Environment
        self.env = gym.make(env)
        self.test_env = gym.make(env)

        self.obs_shape = self.env.observation_space.shape
        self.n_actions = self.env.action_space.shape[0]

        # Model Attributes
        self.buffer = None
        self.dataset = None

        self.policy = None
        self.q1 = None
        self.q2 = None
        self.target_q1 = None
        self.target_q2 = None
        self.build_networks()

        self.agent = SoftActorCriticAgent(self.policy)

        # Hyperparameters
        self.save_hyperparameters()

        # Metrics
        self.total_episode_steps = [0]
        self.total_rewards = [0]
        self.done_episodes = 0
        self.total_steps = 0

        # Average Rewards
        self.avg_reward_len = avg_reward_len

        for _ in range(avg_reward_len):
            self.total_rewards.append(torch.tensor(min_episode_reward, device=self.device))

        self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len :]))

        self.state = self.env.reset()

        self.automatic_optimization = False

    def run_n_episodes(self, env, n_epsiodes: int = 1) -> List[int]:
        """Carries out N episodes of the environment with the current agent without exploration.

        Args:
            env: environment to use, either train environment or test environment
            n_epsiodes: number of episodes to run
        """
        total_rewards = []

        for _ in range(n_epsiodes):
            episode_state = env.reset()
            done = False
            episode_reward = 0

            while not done:
                action = self.agent.get_action(episode_state, self.device)
                next_state, reward, done, _ = env.step(action[0])
                episode_state = next_state
                episode_reward += reward

            total_rewards.append(episode_reward)

        return total_rewards

    def populate(self, warm_start: int) -> None:
        """Populates the buffer with initial experience."""
        if warm_start > 0:
            self.state = self.env.reset()

            for _ in range(warm_start):
                action = self.agent(self.state, self.device)
                next_state, reward, done, _ = self.env.step(action[0])
                exp = Experience(state=self.state, action=action[0], reward=reward, done=done, new_state=next_state)
                self.buffer.append(exp)
                self.state = next_state

                if done:
                    self.state = self.env.reset()

    def build_networks(self) -> None:
        """Initializes the SAC policy and q networks (with targets)"""
        action_bias = torch.from_numpy((self.env.action_space.high + self.env.action_space.low) / 2)
        action_scale = torch.from_numpy((self.env.action_space.high - self.env.action_space.low) / 2)
        self.policy = ContinuousMLP(self.obs_shape, self.n_actions, action_bias=action_bias, action_scale=action_scale)

        concat_shape = [self.obs_shape[0] + self.n_actions]
        self.q1 = MLP(concat_shape, 1)
        self.q2 = MLP(concat_shape, 1)
        self.target_q1 = MLP(concat_shape, 1)
        self.target_q2 = MLP(concat_shape, 1)
        self.target_q1.load_state_dict(self.q1.state_dict())
        self.target_q2.load_state_dict(self.q2.state_dict())

    def soft_update_target(self, q_net, target_net):
        """Update the weights in target network using a weighted sum.

        w_target := (1-a) * w_target + a * w_q

        Args:
            q_net: the critic (q) network
            target_net: the target (q) network
        """
        for q_param, target_param in zip(q_net.parameters(), target_net.parameters()):
            target_param.data.copy_(
                (1.0 - self.hparams.target_alpha) * target_param.data + self.hparams.target_alpha * q_param
            )

    def forward(self, x: Tensor) -> Tensor:
        """Passes in a state x through the network and gets the q_values of each action as an output.

        Args:
            x: environment state

        Returns:
            q values
        """
        output = self.policy(x).sample()
        return output

    def train_batch(
        self,
    ) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
        """Contains the logic for generating a new batch of data to be passed to the DataLoader.

        Returns:
            yields a Experience tuple containing the state, action, reward, done and next_state.
        """
        episode_reward = 0
        episode_steps = 0

        while True:
            self.total_steps += 1
            action = self.agent(self.state, self.device)

            next_state, r, is_done, _ = self.env.step(action[0])

            episode_reward += r
            episode_steps += 1

            exp = Experience(state=self.state, action=action[0], reward=r, done=is_done, new_state=next_state)

            self.buffer.append(exp)
            self.state = next_state

            if is_done:
                self.done_episodes += 1
                self.total_rewards.append(episode_reward)
                self.total_episode_steps.append(episode_steps)
                self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len :]))
                self.state = self.env.reset()
                episode_steps = 0
                episode_reward = 0

            states, actions, rewards, dones, new_states = self.buffer.sample(self.hparams.batch_size)

            for idx, _ in enumerate(dones):
                yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx]

            # Simulates epochs
            if self.total_steps % self.hparams.batches_per_epoch == 0:
                break

    def loss(self, batch: Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]) -> Tuple[Tensor, Tensor, Tensor]:
        """Calculates the loss for SAC which contains a total of 3 losses.

        Args:
            batch: a batch of states, actions, rewards, dones, and next states
        """
        states, actions, rewards, dones, next_states = batch
        rewards = rewards.unsqueeze(-1)
        dones = dones.float().unsqueeze(-1)

        # actor
        dist = self.policy(states)
        new_actions, new_logprobs = dist.rsample_and_log_prob()
        new_logprobs = new_logprobs.unsqueeze(-1)

        new_states_actions = torch.cat((states, new_actions), 1)
        new_q1_values = self.q1(new_states_actions)
        new_q2_values = self.q2(new_states_actions)
        new_qmin_values = torch.min(new_q1_values, new_q2_values)

        policy_loss = (new_logprobs - new_qmin_values).mean()

        # critic
        states_actions = torch.cat((states, actions), 1)
        q1_values = self.q1(states_actions)
        q2_values = self.q2(states_actions)

        with torch.no_grad():
            next_dist = self.policy(next_states)
            new_next_actions, new_next_logprobs = next_dist.rsample_and_log_prob()
            new_next_logprobs = new_next_logprobs.unsqueeze(-1)

            new_next_states_actions = torch.cat((next_states, new_next_actions), 1)
            next_q1_values = self.target_q1(new_next_states_actions)
            next_q2_values = self.target_q2(new_next_states_actions)
            next_qmin_values = torch.min(next_q1_values, next_q2_values) - new_next_logprobs
            target_values = rewards + (1.0 - dones) * self.hparams.gamma * next_qmin_values

        q1_loss = F.mse_loss(q1_values, target_values)
        q2_loss = F.mse_loss(q2_values, target_values)

        return policy_loss, q1_loss, q2_loss

    def training_step(self, batch: Tuple[Tensor, Tensor], _):
        """Carries out a single step through the environment to update the replay buffer. Then calculates loss
        based on the minibatch recieved.

        Args:
            batch: current mini batch of replay data
            _: batch number, not used
        """
        policy_optim, q1_optim, q2_optim = self.optimizers()
        policy_loss, q1_loss, q2_loss = self.loss(batch)

        policy_optim.zero_grad()
        self.manual_backward(policy_loss)
        policy_optim.step()

        q1_optim.zero_grad()
        self.manual_backward(q1_loss)
        q1_optim.step()

        q2_optim.zero_grad()
        self.manual_backward(q2_loss)
        q2_optim.step()

        # Soft update of target network
        if self.global_step % self.hparams.sync_rate == 0:
            self.soft_update_target(self.q1, self.target_q1)
            self.soft_update_target(self.q2, self.target_q2)

        self.log_dict(
            {
                "total_reward": self.total_rewards[-1],
                "avg_reward": self.avg_rewards,
                "policy_loss": policy_loss,
                "q1_loss": q1_loss,
                "q2_loss": q2_loss,
                "episodes": self.done_episodes,
                "episode_steps": self.total_episode_steps[-1],
            }
        )

    def test_step(self, *args, **kwargs) -> Dict[str, Tensor]:
        """Evaluate the agent for 10 episodes."""
        test_reward = self.run_n_episodes(self.test_env, 1)
        avg_reward = sum(test_reward) / len(test_reward)
        return {"test_reward": avg_reward}

    def test_epoch_end(self, outputs) -> Dict[str, Tensor]:
        """Log the avg of the test results."""
        rewards = [x["test_reward"] for x in outputs]
        avg_reward = sum(rewards) / len(rewards)
        self.log("avg_test_reward", avg_reward)
        return {"avg_test_reward": avg_reward}

    def _dataloader(self) -> DataLoader:
        """Initialize the Replay Buffer dataset used for retrieving experiences."""
        self.buffer = MultiStepBuffer(self.hparams.replay_size, self.hparams.n_steps)
        self.populate(self.hparams.warm_start_size)

        self.dataset = ExperienceSourceDataset(self.train_batch)
        return DataLoader(dataset=self.dataset, batch_size=self.hparams.batch_size)

    def train_dataloader(self) -> DataLoader:
        """Get train loader."""
        return self._dataloader()

    def test_dataloader(self) -> DataLoader:
        """Get test loader."""
        return self._dataloader()

    def configure_optimizers(self) -> Tuple[Optimizer]:
        """Initialize Adam optimizer."""
        policy_optim = optim.Adam(self.policy.parameters(), self.hparams.policy_learning_rate)
        q1_optim = optim.Adam(self.q1.parameters(), self.hparams.q_learning_rate)
        q2_optim = optim.Adam(self.q2.parameters(), self.hparams.q_learning_rate)
        return policy_optim, q1_optim, q2_optim

    @staticmethod
    def add_model_specific_args(
        arg_parser: argparse.ArgumentParser,
    ) -> argparse.ArgumentParser:
        """Adds arguments for DQN model.

        Note:
            These params are fine tuned for Pong env.

        Args:
            arg_parser: parent parser
        """
        arg_parser.add_argument(
            "--sync_rate",
            type=int,
            default=1,
            help="how many frames do we update the target network",
        )
        arg_parser.add_argument(
            "--replay_size",
            type=int,
            default=1000000,
            help="capacity of the replay buffer",
        )
        arg_parser.add_argument(
            "--warm_start_size",
            type=int,
            default=10000,
            help="how many samples do we use to fill our buffer at the start of training",
        )
        arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch")
        arg_parser.add_argument("--batch_size", type=int, default=128, help="size of the batches")
        arg_parser.add_argument("--policy_lr", type=float, default=3e-4, help="policy learning rate")
        arg_parser.add_argument("--q_lr", type=float, default=3e-4, help="q learning rate")
        arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag")
        arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor")

        arg_parser.add_argument(
            "--avg_reward_len",
            type=int,
            default=100,
            help="how many episodes to include in avg reward",
        )
        arg_parser.add_argument(
            "--n_steps",
            type=int,
            default=1,
            help="how many frames do we update the target network",
        )

        return arg_parser


## Training:

Source: https://github.com/Lightning-Universe/lightning-bolts/blob/0.5.0/pl_bolts/models/rl/sac_model.py

In [7]:
def cli_main():
    parser = argparse.ArgumentParser(add_help=False)

    # trainer args
    parser = Trainer.add_argparse_args(parser)

    # model args
    parser = SAC.add_model_specific_args(parser)
    args = parser.parse_args()

    model = SAC(**args.__dict__)

    # save checkpoints based on avg_reward
    checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", verbose=True)

    seed_everything(123)
    trainer = Trainer.from_argparse_args(args, deterministic=True, callbacks=checkpoint_callback)

    trainer.fit(model)

In [1]:

import argparse
from typing import Dict, List, Tuple

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor, optim
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

import gym

In [2]:
from RL_lightning_bolts_template.pl_bolts_sac import SAC

In [3]:
sac_model = SAC(env = "HalfCheetah-v4",
                eps_start = 1.0,
                eps_end = 0.02,
                eps_last_frame = 150000,
                sync_rate = 5,
                gamma = 0.98,
                policy_learning_rate = 3e-4,
                q_learning_rate = 3e-4,
                target_alpha = 5e-3,
                batch_size = 128,
                replay_size = 1000000,
                warm_start_size = 1000,
                avg_reward_len = 100,
                seed = 101,
                batches_per_epoch = 1000,
                n_steps = 1,)

In [4]:
sac_model = SAC(env = "HalfCheetah-v4")

In [5]:
# save checkpoints based on avg_reward
checkpoint_callback = ModelCheckpoint(save_top_k=1, monitor="avg_reward", mode="max", verbose=True)
seed_everything(101)
trainer = Trainer(accelerator="gpu", max_steps=100000,  callbacks = checkpoint_callback)

Global seed set to 101
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Moment de vérité

In [6]:
trainer.fit(sac_model)

You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type          | Params
--------------------------------------------
0 | policy    | ContinuousMLP | 20.4 K
1 | q1        | MLP           | 3.2 K 
2 | q2        | MLP           | 3.2 K 
3 | target_q1 | MLP           | 3.2 K 
4 | target_q2 | MLP           | 3.2 K 
--------------------------------------------
33.2 K    Trainable params
0         Non-trainable params
33.2 K    Total params
0.133     Total estimated model params size (MB)
  states = torch.tensor(states, device=device)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Epoch 0, global step 30000: 'avg_reward' reached -21.00000 (best -21.00000), saving model to '/notebooks/Deep_Forecasting/lightning_logs/version_7/checkpoints/epoch=0-step=30000.ckpt' as top 1
Epoch 1, global step 60000: 'avg_reward' was not in top 1
Epoch 2, global step 90000: 'avg_reward' was not in top 1
Epoch 3, global step 100002: 'avg_reward' was not in top 1
`Trainer.fit` stopped: `max_steps=100000` reached.


Follow the instructions here to get the Tensorboard link:
https://docs.paperspace.com/gradient/notebooks/tensorboard/

In [7]:
torch.save(sac_model.state_dict(), "./sac_lightning_v7_gym0262_230621")

In [16]:
test_model = SAC(env = "HalfCheetah-v4")

  deprecation(
  deprecation(


In [18]:
test_model.load_state_dict(torch.load("./sac_lightning_v6_230621"))

<All keys matched successfully>

## Testing:


In [1]:
import argparse
from typing import Dict, List, Tuple

import numpy as np
import torch
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint
from torch import Tensor, optim
from torch.nn import functional as F
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader

import gym



In [2]:
from RL_lightning_bolts_template.pl_bolts_sac import SAC

In [14]:
from gym import wrappers
from pyvirtualdisplay import Display

In [8]:
test_model = SAC(env = "HalfCheetah-v4")

In [9]:
test_model.load_state_dict(torch.load("./sac_lightning_v7_gym0262_230621"))

<All keys matched successfully>

In [10]:
test_model.to("cuda")

SAC(
  (policy): ContinuousMLP(
    (shared_net): Sequential(
      (0): Linear(in_features=17, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=128, bias=True)
      (3): ReLU()
    )
    (mean_layer): Linear(in_features=128, out_features=6, bias=True)
    (logstd_layer): Linear(in_features=128, out_features=6, bias=True)
  )
  (q1): MLP(
    (net): Sequential(
      (0): Linear(in_features=23, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (q2): MLP(
    (net): Sequential(
      (0): Linear(in_features=23, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (target_q1): MLP(
    (net): Sequential(
      (0): Linear(in_features=23, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=1, bias=True)
    )
  )
  (target_q2): MLP(
    (net): Sequential(
      (0

In [11]:
test_model.device

device(type='cuda', index=0)

In [15]:
s = env.reset()

In [16]:
type(s)

numpy.ndarray

In [17]:
s.shape

(17,)

In [20]:
test_model.agent.get_action(torch.Tensor(s), test_model.device)

ValueError: only one element tensors can be converted to Python scalars

In [22]:
test_model(torch.Tensor(s).to(test_model.device))

tensor([ 0.5478,  0.8676, -0.7957,  0.3067, -0.3930,  0.0500], device='cuda:0')

### gym 0.26.2

In this version, you need to include *render_mode* in the env instantiation to be able to record.

In [12]:
# Trigger for wrapper.RecordVideo() object
def epsd_trigger(episode_id: int) -> bool:
    '''
        Records all episodes
    '''
    if episode_id < 10:
        return True
    else:
        return False

In [15]:
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

<pyvirtualdisplay.display.Display at 0x7f5176972d00>

In [18]:
env = gym.make("HalfCheetah-v4", render_mode = "rgb_array")
env = wrappers.RecordVideo(env = env, 
                           video_folder="vids/",
                           name_prefix="SAC_pl_gym0262_2306212113",
                           episode_trigger = epsd_trigger)

In [19]:
for episode in range(3):
    state, _ = env.reset()
    step = 0
    total_reward = 0
    done = False
    while not done and step<5001:
        step += 1
        #env.render()
        # Get action
        with torch.no_grad():
            # Convert to torch tensors
            state_ = torch.FloatTensor(np.array(state)).to(test_model.device)
            # Get actions and UPolicy output
            #action_ = actor(control, state_, t_)
            action_ = test_model(state_)
            # Get np arrays
            action = action_.cpu().detach().numpy()
        
        observation, reward, done, trunc, info = env.step(action)
        total_reward += reward
        if done:
            print("Episode: {0},\tSteps: {1},\tscore: {2}"
                  .format(episode, step, total_reward)
            )
            break
env.close()

Moviepy - Building video /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-0.mp4.
Moviepy - Writing video /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-0.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-0.mp4
Moviepy - Building video /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-1.mp4.
Moviepy - Writing video /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-1.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-1.mp4
Moviepy - Building video /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-2.mp4.
Moviepy - Writing video /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-2.mp4



                                                                

Moviepy - Done !
Moviepy - video ready /notebooks/Deep_Forecasting/vids/SAC_pl_gym0262_2306212113-episode-2.mp4


### gym v0.24.0

The code below does not work with gym 0.25.1

In [6]:
# Trigger for wrapper.RecordVideo() object
def epsd_trigger(episode_id: int) -> bool:
    '''
        Records all episodes
    '''
    if episode_id < 10:
        return True
    else:
        return False

In [7]:
virtual_display = Display(visible=0, size=(1400, 900))
virtual_display.start()

<pyvirtualdisplay.display.Display at 0x7fe2a928b340>

In [8]:
env = gym.make("HalfCheetah-v4")
env = wrappers.RecordVideo(env = env, 
                           video_folder="vids/",
                           name_prefix="SAC_pl_2306211908",
                           episode_trigger = epsd_trigger)

  logger.warn(
  logger.warn(
  logger.warn(


In [23]:
for episode in range(3):
    state = env.reset()
    step = 0
    total_reward = 0
    done = False
    while not done and step<5001:
        step += 1
        env.render()
        # Get action
        with torch.no_grad():
            # Convert to torch tensors
            state_ = torch.FloatTensor(np.array(state)).to(test_model.device)
            # Get actions and UPolicy output
            #action_ = actor(control, state_, t_)
            action_ = test_model(state_)
            # Get np arrays
            action = action_.cpu().detach().numpy()
        
        observation, reward, done, info = env.step(action)
        total_reward += reward
        if done:
            print("Episode: {0},\tSteps: {1},\tscore: {2}"
                  .format(episode, step, total_reward)
            )
            break
env.close()

Episode: 0,	Steps: 1000,	score: -338.7753927132449
Episode: 1,	Steps: 1000,	score: -155.0876923827959
Episode: 2,	Steps: 1000,	score: -137.1122364004564
