<a href="https://colab.research.google.com/github/l3u9/RL-mario-pytorch/blob/main/mario.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%bash
pip install gym-super-mario-bros==7.4.0

Defaulting to user installation because normal site-packages is not writeable


DEPRECATION: distro-info 0.23ubuntu1 has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of distro-info or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063
DEPRECATION: python-debian 0.1.36ubuntu1 has a non-standard version number. pip 23.3 will enforce this behaviour change. A possible replacement is to upgrade to a newer version of python-debian or contact the author to suggest that they release a version with a conforming version number. Discussion can be found at https://github.com/pypa/pip/issues/12063


In [2]:
import torch
from torch import nn
from torchvision import transforms as T
from PIL import Image
import numpy as np
from pathlib import Path
from collections import deque
import random, datetime, os, copy

import gym
from gym.spaces import Box
from gym.wrappers import FrameStack

from nes_py.wrappers import JoypadSpace

import gym_super_mario_bros

강화학습의 개념

Environment: 에이전트가 상호작용하며 스스로 배우는 세계

Action a: 에이전트가 환경에 어떻게 응답하는지 행동을 나타낸다. 가능한 모든 행동의 집합을 행동 공간이라고 한다.

State s: 환경의 현재 특성 상태를 통해 나타낸다. 환경이 있을 수 있는 모든 가능한 상태 집합을 상태 공간이라고 한다.

Reward r: 포상은 환경에서 에이전트로 전달되는 피드백이다. 에이전트가 학습하고 향후 행동을 변경하도록 유도하는 것이다. 여러 시간 단계에 걸친 포상의 합을 return이라고 한다.

Action-Value-function Q'(s, a): 상태 s에서 시작하면 예상되는 리턴을 반환하고 임의의 행동 a를 선택한다. 그 다음 각각의 미래의 단계에서 포상의 합을 극대화 하는 행동을 선택하도록 한다. Q는 상태에서 행동의 "품질"을 나타낸다.

In [3]:
# Environment 초기화

if gym.__version__ < '0.26':
  env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0', new_step_api=True)
else:
  env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0', render_mode='rgb', apply_api_compatibility=True)

env = JoypadSpace(env, [["right"], ["right", 'A']])

env.reset()
next_state, reward, done, trunc, info = env.step(action=0)
print(f"{next_state.shape}, \n {reward}, \n {done}, \n {info}")


(240, 256, 3), 
 0.0, 
 False, 
 {'coins': 0, 'flag_get': False, 'life': 2, 'score': 0, 'stage': 1, 'status': 'small', 'time': 400, 'world': 1, 'x_pos': 40, 'y_pos': 79}


  logger.warn(
  logger.warn(


In [4]:
# 전처리
class SkipFrame(gym.Wrapper):
  def __init__(self, env, skip):
    super().__init__(env)
    self._skip = skip

  def step(self, action):
    total_reward = 0.0
    for i in range(self._skip):
      obs, reward, done, trunk, info = self.env.step(action)
      total_reward += reward
      if done:
        break
    return obs, total_reward, done, trunk, info


class GrayScaleObservation(gym.ObservationWrapper):
  def __init__(self, env):
    super().__init__(env)
    obs_shape = self.observation_space.shape[:2]
    self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

  def permute_orientation(self, observation):
    observation = np.transpose(observation, (2, 0, 1))
    observation = torch.tensor(observation.copy(), dtype=torch.float)
    return observation

  def observation(self, observation):
    observation = self.permute_orientation(observation)
    transform = T.Grayscale()
    observation = transform(observation)
    return observation

class ResizeObservation(gym.ObservationWrapper):
  def __init__(self, env, shape):
    super().__init__(env)
    if isinstance(shape, int):
      self.shape = (shape, shape)
    else:
      self.shape = tuple(shape)

    obs_shape = self.shape + self.observation_space.shape[2:]

    self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

  def observation(self, observation):
    transforms = T.Compose([T.Resize(self.shape), T.Normalize(0, 255)])
    observation = transforms(observation).squeeze(0)
    return observation


env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, shape=84)

if gym.__version__ < '0.26':
  env = FrameStack(env, num_stack=4, new_step_api=True)
else:
  env = FrameStack(env, num_stack=4)



In [5]:
from torch.nn.modules.conv import Conv2d
class MarioNet(nn.Module):

  def __init__(self, input_dim, output_dim):
    super().__init__()
    c,h,w = input_dim

    if h != 84:
      raise ValueError(f"Expecting input height: 84, got: {h}")
    if w!= 84:
      raise ValueError(f"Expecting input width: 84, got: {w}")

    self.online = nn.Sequential(
        nn.Conv2d(in_channels=c, out_channels=32, kernel_size=8, stride=4),
        nn.ReLU(),
        nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
        nn.ReLU(),
        nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
        nn.ReLU(),
        nn.Flatten(),
        nn.Linear(3136,512),
        nn.ReLU(),
        nn.Linear(512, output_dim),
    )

    self.target = copy.deepcopy(self.online)

  def forward(self, input, model):
    if model == "online":
      return self.online(input)
    elif model == "target":
      return self.target(input)

In [6]:
class Mario:
  def __init__(self, state_dim, action_dim, save_dir):
    self.state_dim = state_dim
    self.action_dim = action_dim
    self.save_dir = save_dir

    self.device = "cuda" if torch.cuda.is_available() else "cpu"

    self.net = MarioNet(self.state_dim, self.action_dim).float()
    self.net = self.net.to(device=self.device)

    self.exploration_rate = 1
    self.exploration_rate_decay = 0.99999975
    self.exploration_rate_min = 0.1

    self.curr_step = 0

    self.save_every = 5e5

  def act(self, state):
    if np.random.rand() < self.exploration_rate:
      action_idx = np.random.randint(self.action_dim)
    else:
      state = state[0].__array__() if isinstance(state, tuple) else state.__array__()
      state = torch.tensor(state, device=self.device).unsqueeze(0)
      action_values = self.net(state, model="online")
      action_idx = torch.argmax(action_values, axis=1).item()

    self.exploration_rate *= self.exploration_rate_decay
    self.exploration_rate = max(self.exploration_rate_min, self.exploration_rate)

    self.curr_step += 1
    return action_idx




In [7]:
class Mario(Mario):
  def __init__(self, state_dim, action_dim, save_dir):
    super().__init__(state_dim, action_dim, save_dir)
    self.memory = deque(maxlen=100000)
    self.batch_size = 32

  def cache(self, state, next_state, action, reward, done):

    def first_if_tuple(x):
      return x[0] if isinstance(x, tuple) else x

    state = first_if_tuple(state).__array__()
    next_state = first_if_tuple(next_state).__array__()

    state = torch.tensor(state, device=self.device)
    next_state = torch.tensor(next_state, device=self.device)
    action = torch.tensor([action], device=self.device)
    reward = torch.tensor([reward], device=self.device)
    done = torch.tensor([done], device=self.device)

    self.memory.append((state, next_state, action, reward, done,))

  def recall(self):
    batch = random.sample(self.memory, self.batch_size)
    state, next_state, action, reward, done = map(torch.stack, zip(*batch))
    return state, next_state, action.squeeze(), reward.squeeze(), done.squeeze()



In [8]:
class Mario(Mario):
  def __init__(self, state_dim, action_dim, save_dir):
    super().__init__(state_dim, action_dim, save_dir)
    self.gamma = 0.9

  def td_estimate(self, state, action):
    current_Q = self.net(state, model="online")[np.arange(0, self.batch_size), action]
    return current_Q

  @torch.no_grad()
  def td_target(self, reward, next_state, done):
    next_state_Q = self.net(next_state, model="online")
    best_action = torch.argmax(next_state_Q, axis=1)
    next_Q = self.net(next_state, model="target")[
        np.arange(0, self.batch_size), best_action
        ]
    return (reward + (1 - done.float()) * self.gamma * next_Q).float()

In [9]:
class Mario(Mario):
  def __init__(self, state_dim, action_dim, save_dir):
    super().__init__(state_dim, action_dim, save_dir)
    self.optimizer = torch.optim.Adam(self.net.parameters(), lr=0.00025)
    self.loss_fn = torch.nn.SmoothL1Loss()

  def update_Q_online(self, td_estimate, td_target):
    loss = self.loss_fn(td_estimate, td_target)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    return loss.item()

  def sync_Q_target(self):
    self.net.target.load_state_dict(self.net.online.state_dict())

In [10]:
class Mario(Mario):
  def save(self):
    save_path = (
        self.save_dir / f"mario_net_{int(self.curr_step // self.save_every)}.chkpt"
    )
    torch.save(
        dict(model=self.net.state_dict(), exploration_rate=self.exploration_rate),
        save_path,
    )

    print(f"MarioNet saved to {save_path} at step {self.curr_step}")


In [11]:
class Mario(Mario):
  def __init__(self, state_dim, action_dim, save_dir):
    super().__init__(state_dim, action_dim, save_dir)
    self.burnin = 1e4
    self.learn_every = 3
    self.sync_every = 1e4

  def learn(self):
    if self.curr_step % self.sync_every == 0:
      self.sync_Q_target()

    if self.curr_step % self.save_every == 0:
      self.save()

    if self.curr_step < self.burnin:
      return None, None

    state, next_state, action, reward, done = self.recall()

    td_est = self.td_estimate(state, action)

    td_tgt = self.td_target(reward, next_state, done)

    loss = self.update_Q_online(td_est, td_tgt)

    return (td_est.mean().item(), loss)

In [12]:
import numpy as np
import time, datetime
import matplotlib.pyplot as plt

class MetricLogger:
  def __init__(self, save_dir):
    self.save_log = save_dir / "log"
    with open(self.save_log, "w") as f:
      f.write(
                f"{'Episode':>8}{'Step':>8}{'Epsilon':>10}{'MeanReward':>15}"
                f"{'MeanLength':>15}{'MeanLoss':>15}{'MeanQValue':>15}"
                f"{'TimeDelta':>15}{'Time':>20}\n"
            )
    self.ep_rewards_plot = save_dir / "reward_plot.jpg"
    self.ep_lengths_plot = save_dir / "length_plot.jpg"
    self.ep_avg_losses_plot = save_dir / "loss_plot.jpg"
    self.ep_avg_qs_plot = save_dir / "q_plot.jpg"


    self.ep_rewards = []
    self.ep_lengths = []
    self.ep_avg_losses = []
    self.ep_avg_qs = []

    self.moving_avg_ep_rewards = []
    self.moving_avg_ep_lengths = []
    self.moving_avg_ep_avg_losses = []
    self.moving_avg_ep_avg_qs = []

    self.init_episode()

    self.record_time = time.time()

  def log_step(self, reward, loss, q):
    self.curr_ep_reward += reward
    self.curr_ep_length += 1
    if loss:
      self.curr_ep_loss += loss
      self.curr_ep_q += q
      self.curr_ep_loss_length += 1

  def log_episode(self):
    self.ep_rewards.append(self.curr_ep_reward)
    self.ep_lengths.append(self.curr_ep_length)
    if self.curr_ep_loss_length == 0:
      ep_avg_loss = 0
      ep_avg_q = 0

    else:
      ep_avg_loss = np.round(self.curr_ep_loss / self.curr_ep_loss_length, 5)
      ep_avg_q = np.round(self.curr_ep_q / self.curr_ep_loss_length, 5)
    self.ep_avg_losses.append(ep_avg_loss)
    self.ep_avg_qs.append(ep_avg_q)

    self.init_episode()
  def init_episode(self):
        self.curr_ep_reward = 0.0
        self.curr_ep_length = 0
        self.curr_ep_loss = 0.0
        self.curr_ep_q = 0.0
        self.curr_ep_loss_length = 0

  def record(self, episode, epsilon, step):
    mean_ep_reward = np.round(np.mean(self.ep_rewards[-100:]), 3)
    mean_ep_length = np.round(np.mean(self.ep_lengths[-100:]), 3)
    mean_ep_loss = np.round(np.mean(self.ep_avg_losses[-100:]), 3)
    mean_ep_q = np.round(np.mean(self.ep_avg_qs[-100:]), 3)
    self.moving_avg_ep_rewards.append(mean_ep_reward)
    self.moving_avg_ep_lengths.append(mean_ep_length)
    self.moving_avg_ep_avg_losses.append(mean_ep_loss)
    self.moving_avg_ep_avg_qs.append(mean_ep_q)

    last_record_time = self.record_time
    self.record_time = time.time()
    time_since_last_record = np.round(self.record_time - last_record_time, 3)

    print(
        f"Episode {episode} - "
        f"Step {step} - "
        f"Epsilon {epsilon} - "
        f"Mean Reward {mean_ep_reward} - "
        f"Mean Length {mean_ep_length} - "
        f"Mean Loss {mean_ep_loss} - "
        f"Mean Q Value {mean_ep_q} - "
        f"Time Delta {time_since_last_record} - "
        f"Time {datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')}"
    )

    with open(self.save_log, "a") as f:
        f.write(
            f"{episode:8d}{step:8d}{epsilon:10.3f}"
            f"{mean_ep_reward:15.3f}{mean_ep_length:15.3f}{mean_ep_loss:15.3f}{mean_ep_q:15.3f}"
            f"{time_since_last_record:15.3f}"
            f"{datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S'):>20}\n"
        )

    for metric in ["ep_rewards", "ep_lengths", "ep_avg_losses", "ep_avg_qs"]:
        plt.plot(getattr(self, f"moving_avg_{metric}"))
        plt.savefig(getattr(self, f"{metric}_plot"))
        plt.clf()


In [None]:
use_cuda = torch.cuda.is_available()
print(f"Using CUDA: {use_cuda}")
print()

save_dir = Path("checkpoints") / datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
save_dir.mkdir(parents=True)

mario = Mario(state_dim=(4, 84, 84), action_dim=env.action_space.n, save_dir=save_dir)

logger = MetricLogger(save_dir)

episodes = 40000
for e in range(episodes):

    state = env.reset()

    # 게임을 실행시켜봅시다!
    while True:

        # 현재 상태에서 에이전트 실행하기
        action = mario.act(state)

        # 에이전트가 액션 수행하기
        next_state, reward, done, trunc, info = env.step(action)

        # 기억하기
        mario.cache(state, next_state, action, reward, done)

        # 배우기
        q, loss = mario.learn()

        # 기록하기
        logger.log_step(reward, loss, q)

        # 상태 업데이트하기
        state = next_state

        # 게임이 끝났는지 확인하기
        if done or info["flag_get"]:
            break

    logger.log_episode()

    if e % 20 == 0:
        logger.record(episode=e, epsilon=mario.exploration_rate, step=mario.curr_step)

Using CUDA: True





Episode 0 - Step 40 - Epsilon 0.9999900000487484 - Mean Reward 231.0 - Mean Length 40.0 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 0.183 - Time 2023-07-20T16:05:15




Episode 20 - Step 4536 - Epsilon 0.9988666425932761 - Mean Reward 692.762 - Mean Length 216.0 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 25.172 - Time 2023-07-20T16:05:40




Episode 40 - Step 8606 - Epsilon 0.9978508125484896 - Mean Reward 651.927 - Mean Length 209.902 - Mean Loss 0.0 - Mean Q Value 0.0 - Time Delta 21.611 - Time 2023-07-20T16:06:02




Episode 60 - Step 12592 - Epsilon 0.9968569493639886 - Mean Reward 642.475 - Mean Length 206.426 - Mean Loss 0.172 - Mean Q Value 0.71 - Time Delta 30.667 - Time 2023-07-20T16:06:32




Episode 80 - Step 16559 - Epsilon 0.9958688064371667 - Mean Reward 629.222 - Mean Length 204.432 - Mean Loss 0.217 - Mean Q Value 1.319 - Time Delta 34.567 - Time 2023-07-20T16:07:07




Episode 100 - Step 21494 - Epsilon 0.9946409107574661 - Mean Reward 641.03 - Mean Length 214.54 - Mean Loss 0.239 - Mean Q Value 1.845 - Time Delta 45.598 - Time 2023-07-20T16:07:53




Episode 120 - Step 24839 - Epsilon 0.9938094898781065 - Mean Reward 609.09 - Mean Length 203.03 - Mean Loss 0.313 - Mean Q Value 3.037 - Time Delta 32.207 - Time 2023-07-20T16:08:25




Episode 140 - Step 27963 - Epsilon 0.9930336275830849 - Mean Reward 600.92 - Mean Length 193.57 - Mean Loss 0.386 - Mean Q Value 4.259 - Time Delta 29.916 - Time 2023-07-20T16:08:55




Episode 160 - Step 33684 - Epsilon 0.9916143562579732 - Mean Reward 627.05 - Mean Length 210.92 - Mean Loss 0.368 - Mean Q Value 5.392 - Time Delta 50.854 - Time 2023-07-20T16:09:46




Episode 180 - Step 38805 - Epsilon 0.9903456541208467 - Mean Reward 632.24 - Mean Length 222.46 - Mean Loss 0.377 - Mean Q Value 6.435 - Time Delta 47.407 - Time 2023-07-20T16:10:33




Episode 200 - Step 42633 - Epsilon 0.9893983465702657 - Mean Reward 621.81 - Mean Length 211.39 - Mean Loss 0.409 - Mean Q Value 7.608 - Time Delta 34.764 - Time 2023-07-20T16:11:08




Episode 220 - Step 47273 - Epsilon 0.98825130975468 - Mean Reward 654.9 - Mean Length 224.34 - Mean Loss 0.431 - Mean Q Value 8.528 - Time Delta 44.701 - Time 2023-07-20T16:11:52




Episode 240 - Step 50989 - Epsilon 0.9873336504918101 - Mean Reward 666.05 - Mean Length 230.26 - Mean Loss 0.456 - Mean Q Value 9.514 - Time Delta 32.745 - Time 2023-07-20T16:12:25




Episode 260 - Step 55028 - Epsilon 0.9863371933844758 - Mean Reward 647.43 - Mean Length 213.44 - Mean Loss 0.48 - Mean Q Value 10.506 - Time Delta 36.659 - Time 2023-07-20T16:13:02




Episode 280 - Step 61338 - Epsilon 0.9847824728755668 - Mean Reward 653.37 - Mean Length 225.33 - Mean Loss 0.512 - Mean Q Value 11.483 - Time Delta 57.964 - Time 2023-07-20T16:14:00




Episode 300 - Step 65708 - Epsilon 0.9837071853721409 - Mean Reward 660.04 - Mean Length 230.75 - Mean Loss 0.532 - Mean Q Value 12.393 - Time Delta 41.99 - Time 2023-07-20T16:14:42




Episode 320 - Step 69617 - Epsilon 0.9827463269808977 - Mean Reward 645.54 - Mean Length 223.44 - Mean Loss 0.546 - Mean Q Value 13.138 - Time Delta 33.889 - Time 2023-07-20T16:15:16




Episode 340 - Step 73636 - Epsilon 0.9817594083716243 - Mean Reward 648.31 - Mean Length 226.47 - Mean Loss 0.584 - Mean Q Value 14.054 - Time Delta 37.321 - Time 2023-07-20T16:15:53




Episode 360 - Step 76937 - Epsilon 0.9809495455341349 - Mean Reward 625.98 - Mean Length 219.09 - Mean Loss 0.596 - Mean Q Value 14.677 - Time Delta 30.312 - Time 2023-07-20T16:16:23




Episode 380 - Step 80580 - Epsilon 0.9800565523322959 - Mean Reward 623.28 - Mean Length 192.42 - Mean Loss 0.613 - Mean Q Value 15.208 - Time Delta 30.755 - Time 2023-07-20T16:16:54




Episode 400 - Step 82932 - Epsilon 0.9794804483985858 - Mean Reward 589.14 - Mean Length 172.24 - Mean Loss 0.641 - Mean Q Value 15.838 - Time Delta 19.888 - Time 2023-07-20T16:17:14




Episode 420 - Step 87246 - Epsilon 0.9784246480455467 - Mean Reward 600.46 - Mean Length 176.29 - Mean Loss 0.665 - Mean Q Value 16.493 - Time Delta 35.783 - Time 2023-07-20T16:17:50




Episode 440 - Step 90489 - Epsilon 0.9776317116429445 - Mean Reward 590.28 - Mean Length 168.53 - Mean Loss 0.658 - Mean Q Value 16.934 - Time Delta 27.111 - Time 2023-07-20T16:18:17




Episode 460 - Step 95051 - Epsilon 0.9765173581172566 - Mean Reward 624.18 - Mean Length 181.14 - Mean Loss 0.688 - Mean Q Value 17.585 - Time Delta 39.386 - Time 2023-07-20T16:18:56




Episode 480 - Step 100274 - Epsilon 0.9752431025280999 - Mean Reward 630.85 - Mean Length 196.94 - Mean Loss 0.712 - Mean Q Value 18.234 - Time Delta 56.917 - Time 2023-07-20T16:19:53




Episode 500 - Step 104724 - Epsilon 0.9741587477250411 - Mean Reward 674.98 - Mean Length 217.92 - Mean Loss 0.733 - Mean Q Value 18.723 - Time Delta 37.878 - Time 2023-07-20T16:20:31




Episode 520 - Step 109416 - Epsilon 0.9730167292958025 - Mean Reward 684.61 - Mean Length 221.7 - Mean Loss 0.753 - Mean Q Value 19.183 - Time Delta 38.915 - Time 2023-07-20T16:21:10




Episode 540 - Step 113994 - Epsilon 0.9719037485345752 - Mean Reward 705.72 - Mean Length 235.05 - Mean Loss 0.796 - Mean Q Value 19.798 - Time Delta 37.653 - Time 2023-07-20T16:21:48




Episode 560 - Step 117429 - Epsilon 0.9710694843495766 - Mean Reward 675.06 - Mean Length 223.78 - Mean Loss 0.803 - Mean Q Value 20.128 - Time Delta 28.335 - Time 2023-07-20T16:22:16




Episode 580 - Step 121485 - Epsilon 0.9700853188255355 - Mean Reward 676.44 - Mean Length 212.11 - Mean Loss 0.815 - Mean Q Value 20.589 - Time Delta 33.549 - Time 2023-07-20T16:22:49




Episode 600 - Step 124054 - Epsilon 0.9694624814816254 - Mean Reward 638.03 - Mean Length 193.3 - Mean Loss 0.824 - Mean Q Value 21.031 - Time Delta 20.957 - Time 2023-07-20T16:23:10




Episode 620 - Step 127720 - Epsilon 0.968574376042694 - Mean Reward 617.75 - Mean Length 183.04 - Mean Loss 0.838 - Mean Q Value 21.513 - Time Delta 30.079 - Time 2023-07-20T16:23:41




Episode 640 - Step 131449 - Epsilon 0.9676718432261584 - Mean Reward 612.7 - Mean Length 174.55 - Mean Loss 0.857 - Mean Q Value 21.909 - Time Delta 30.68 - Time 2023-07-20T16:24:11




Episode 660 - Step 136324 - Epsilon 0.9664932113943966 - Mean Reward 630.84 - Mean Length 188.95 - Mean Loss 0.874 - Mean Q Value 22.428 - Time Delta 39.471 - Time 2023-07-20T16:24:51




Episode 680 - Step 139605 - Epsilon 0.9657007702829216 - Mean Reward 626.35 - Mean Length 181.2 - Mean Loss 0.889 - Mean Q Value 22.88 - Time Delta 27.433 - Time 2023-07-20T16:25:18


