## [TRAIN A MARIO-PLAYING RL AGENT](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#train-a-mario-playing-rl-agent)

In [None]:
#%%bash
#pip install gym-super-mario-bros==7.4.0

In [32]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy

In [33]:
import gym
from gym.spaces import Box
from gym.wrappers import FrameStack

In [34]:
from nes_py.wrappers import JoypadSpace

In [35]:
import gym_super_mario_bros

### [RL Definitions](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#rl-definitions)

**Environment** The world that an agent interacts with and learns from.

**Action** $a$ : How the Agent responds to the Environment. The set of all possible Actions is called action-space.

**State** $s$ : The current characteristic of the Environment. The set of all possible States the Environment can be in is called state-space.

**Reward** $r$ : Reward is the key feedback from Environment to Agent. It is what drives the Agent to learn and to change its future action. An aggregation of rewards over multiple time steps is called Return.

**Optimal Action-Value function** $Q^*(s,a)$ : Gives the expected return if you start in state ss, take an arbitrary action aa, and then for each future time step take the action that maximizes returns. QQ can be said to stand for the “quality” of the action in a state. We try to approximate this function.

In [36]:
env = gym_super_mario_bros.make("SuperMarioBros-1-1-v3", render_mode='rgb', apply_api_compatibility=True)

  logger.warn(


In [37]:
env = JoypadSpace(env, [['right'], ['right', 'A']])

In [38]:
env.reset()

(array([[[104, 136, 252],
         [104, 136, 252],
         [104, 136, 252],
         ...,
         [104, 136, 252],
         [104, 136, 252],
         [104, 136, 252]],
 
        [[104, 136, 252],
         [104, 136, 252],
         [104, 136, 252],
         ...,
         [104, 136, 252],
         [104, 136, 252],
         [104, 136, 252]],
 
        [[104, 136, 252],
         [104, 136, 252],
         [104, 136, 252],
         ...,
         [104, 136, 252],
         [104, 136, 252],
         [104, 136, 252]],
 
        ...,
 
        [[228,  92,  16],
         [228,  92,  16],
         [228,  92,  16],
         ...,
         [228,  92,  16],
         [228,  92,  16],
         [228,  92,  16]],
 
        [[228,  92,  16],
         [228,  92,  16],
         [228,  92,  16],
         ...,
         [228,  92,  16],
         [228,  92,  16],
         [228,  92,  16]],
 
        [[228,  92,  16],
         [228,  92,  16],
         [228,  92,  16],
         ...,
         [228,  92,  16],
  

In [39]:
next_state, reward, done, trunc, info = env.step(action=0)

In [40]:
print(f'{next_state.shape} \t {reward} \t {done}\t{info}')

(240, 256, 3) 	 0.0 	 False	{'coins': 0, 'flag_get': False, 'life': 2, 'score': 0, 'stage': 1, 'status': 'small', 'time': 400, 'world': 1, 'x_pos': 40, 'y_pos': 79}


### [Preprocess Environment](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#preprocess-environment)

Environment data is returned to the agent in next_state. As you saw above, each state is represented by a [3, 240, 256] size array. Often that is more information than our agent needs; for instance, Mario’s actions do not depend on the color of the pipes or the sky!

In [41]:
class SkipFrame(gym.Wrapper):
    def __init__(self, env, skip):
        super().__init__(env)
        self._skip = skip

    def step(self, action):
        total_rewards = 0.0
        for i in range(self._skip):
            obs, reward, done, trunk, info = self.env.step(action)
            total_rewards += reward
            if done:
                break
        return obs, total_rewards, done, trunk, info


In [42]:
env = SkipFrame(env, skip=4)

In [43]:
class GrayScaleObservation(gym.ObservationWrapper):
    def __init__(self, env):
        super().__init__(env)
        obs_shape = self.observation_space.shape[:2]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)
    
    def permute_orientation(self, observation):
        observation = np.transpose(observation, (2,0,1))
        observation = torch.tensor(observation.copy(), dtype=torch.float)
        return observation

    def observation(self, observation):
        observation = self.permute_orientation(observation)
        transform = T.Grayscale()
        observation = transform(observation)
        return observation

In [44]:
env = GrayScaleObservation(env)

In [45]:
class ResizeObservation(gym.ObservationWrapper):
    def __init__(self, env, shape):
        super().__init__(env)
        if isinstance(shape, int):
            self.shape=(shape, shape)
        else:
            self.shape = tuple(shape)
        
        obs_shape = self.shape + self.observation_space.shape[2:]
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype = np.uint8)

    def observation(self, observation):
        transforms = T.compose(
            [T.Resize(self.shape), T.Normalize(0, 255)]
        )

        observation = transforms(observation).squeeze(0)
        return observation


In [46]:
env = ResizeObservation(env, shape=84)

In [47]:
env = FrameStack(env, num_stack=4)

![](https://pytorch.org/tutorials/_images/mario_env.png)

### [Agent](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#agent)

In [48]:
class MarioNet(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        c, h, w = input_dim
        
        if h != 84:
            raise ValueError(f"Expecting input height: 84, got: {h}")
        if w != 84:
            raise ValueError(f"Expecting input width: 84, got: {w}")

        self.online = nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.RelU(),
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim)
        )

        self.target = copy.deepcopy(self.online)
        
        for p in self.target.parameters():
            p.requires_grad = False

    def forward(self, input, model):
        if model == 'online':
            return self.online(input)
        elif model == 'target':
            return self.target(input)  

In [49]:
class MarioG1:
    def __init__(self, state_dim, action_dim, save_dir):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.save_dir = save_dir
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        self.net = MarioNet(self.state_dim, self.action_dim).float()
        self.net = self.net.to(device=self.device)
        
        self.exploration_rate = 1
        self.exploration_rate_decay = 0.9999975
        self.exploration_rate_min = 0
        self.curr_step = 0

        self.save_every = 5e5

    def act(self, state):
        """
        For any given state, an agent can choose to do 
        the most optimal action (exploit) 
        or a random action (explore).
        """
        ### Explore
        if np.random.rand() < self.exploration_rate:
            action_idx = np.random.randint(self.action_dim)

        # Exploit
        else:
            state = state[0].__array__() if isinstance(state, tuple) else state.__array__()
            state = torch.tensor(state, device=self.device).unsqueeze(0)
            action_values = self.net(state, model='online')
            action_idx = torch.argmax(action_values, axis=1).item()
        
        self.exploration_rate *= self.exploration_rate_decay
        self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

        self.curr_step += 1
        return action_idx

    def cache(self, experience):
        """Add the experience to memory"""
        pass
    def recall(self):
        """"Sample experiences from memory"""
        pass
    def learn(self):
        """Update online action value Q function with a batch of experiences"""
        pass

In [50]:
class MarioG2(MarioG1):
    def __init__(self, state_dim, action_dim, save_dir):
        super().__init__(state_dim, action_dim, save_dir)
        self.memory = deque(maxlen=100000)
        self.batch_size = 32

    def cache(self, state, next_state, action, reward, done):
        def first_if_tuple(x):
            return x[0] if isinstance(x, tuple) else x
        
        state = first_if_tuple(state).__array__()
        next_state = first_if_tuple(next_state).__array__()

        state = torch.tensor(state, device=self.device)
        next_state = torch.tensor(next_state, device=self.device)
        action = torch.tensor([action], device=self.device)
        reward = torch.tensor([reward], device=self.device)
        done = torch.tensor([done], device=self.device)

        self.memory.append((state, next_state, action, reward, done,))

    def recall(self):
        batch = random.sample(self.memory, self.batch_size)
        state, next_state, action, reward, done = map(torch.stack, zip(*batch))
        return state, next_state, action.squeeze(), reward.squeeze(),  done.squeeze() 

### [Learn](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#learn)

#### Mario uses the DDQN algorithm under the hood. DDQN uses two ConvNets - $Q_{online}$ and $Q_{target}$ - that independently approximate the optimal action-value function.
#### In our implementation, we share feature generator features across $Q_{online}$ and $Q_{target}$, but maintain separate FC classifiers for each. $\theta_{target}$ (the parameters of $Q_{target}$) is frozen to prevent updation by backprop. Instead, it is periodically synced with $\theta_{online}$ (more on this later).

$TD_e=Q^∗_{online} (s,a)$ 

$a^{'} =\underset{a}{\arg\max} Q_{online}(s^{'},a)$ action and next state

$TD_t=r + \gamma Q^∗_{target} (s^{'},a^{'})$

In [51]:
class MarioG3(MarioG2):
    def __init__(self, state_dim, action_dim, save_dir):
        super().__init__(state_dim, action_dim, save_dir)
        self.gamma = 0.9
    def td_estimate(self, state, action):
        current_Q = self.net(state, model='online')[
            np.arange(0, self.batch_size), action
        ]

    @torch.no_grad()
    def td_target(self, reward, next_state, done):
        next_state_Q = self.net(next_state, model='online')
        best_action = torch.argmax(next_state_Q, axis=1)
        next_Q = self.net(next_state, model='target')[
            np.arange(0, self.batch_size), best_action
        ]
        return (reward + (1 - done.float()) * self.gamma * next_Q).float()

### [Updating the model](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#updating-the-model)

$θ_{online} ← θ_{online} + α∇(TD_e−TD_t)$

In [52]:
class MarioG4(MarioG3):
    def __init__(self, state_dim, action_dim, save_dir):
        super().__init__(state_dim, action_dim, save_dir)
        self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
        self.loss_fn = torch.nn.SmoothL1Loss()

    def update_Q_online(self, td_e, td_t):
        loss = self.loss_fn(td_e, td_t)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def sync_Q_target(self):
        self.net.target.load_state_dict(self.net.online.state_dict())
    

### [Save checkpoint](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#save-checkpoint)

In [66]:
home = Path.home()

In [84]:
home

PosixPath('/home/dulunche')

In [54]:
class MarioG5(MarioG4):
    def save(self):
        save_path = (
           home / self.save_dir / f'mario_net{self.curr_step // self.save_every}.chkpt'
        )
        torch.save(
            dict(model=self.net.state_dict(), exploration_rate = self.exploration_rate),
            save_path,
        )

        print(f"MarioNet saved to {save_path} at step {self.curr_step}")

#### [Final Mario](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#putting-it-all-together)

In [55]:
class Mario(MarioG5):
    def __init__(self, state_dim, action_dim, save_dir):
        super().__init__(state_dim, action_dim, save_dir)
        self.burnin = 1e4
        self.learn_every = 3 
        self.sync_every = 1e4
    
    def learn(self):
        if self.curr_step % self.sync_every == 0:
            self.sync_Q_target()
        if self.curr_step % self.save_every == 0:
            self.save()
        if self.curr_step < self.burnin:
            return None, None
        if self.curr_step % self.learn_every != 0:
            return None, None

        state, next_state, action, reward, done = self.recall()
        td_est = self.td_estimate(state, action)
        td_tgt = self.td_target(reward, next_state, done)

        loss = self.update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss) 

### [Logging](https://pytorch.org/tutorials/intermediate/mario_rl_tutorial.html#logging)

In [56]:
import time, datetime
import matplotlib.pyplot as plt

In [77]:
class MetricLogger:
    def __init__(self, save_dir) -> None:
        
        save_dir.mkdir(parents=True, exist_ok=True)
        self.save_log = save_dir / 'log'
        with open(self.save_log, 'w') as f:
            f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )

In [87]:
save_path = home / 'mario_net' / f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):<20}"