In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import gym_super_mario_bros
from nes_py.wrappers import JoypadSpace
import pickle 
from gym_super_mario_bros.actions import RIGHT_ONLY
import gym
import numpy as np
import collections 
import cv2
import torch.optim as optim
import gymnasium as gym2
import imageio
from tqdm.notebook import tqdm
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

  if not hasattr(tensorboard, "__version__") or LooseVersion(
  _np_version_forbids_neg_powint = LooseVersion(numpy.__version__) >= LooseVersion('1.12.0b1')
  if LooseVersion(module.__version__) < minver:
  other = LooseVersion(other)


In [34]:
# --- Run configuration -------------------------------------------------------
# Prefer Apple's MPS backend (the original target), but fall back to CUDA or CPU
# so the notebook still runs on machines where MPS is not available.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

batch_size = 32              # transitions sampled from the replay buffer per update
exploration_fraction = .1    # fraction of total_timesteps over which epsilon is annealed
tau = 1.                     # target-network mixing factor (1.0 == hard copy)
gamma = .99                  # discount factor used in the TD target
train_frequency = 4          # run one gradient step every N environment steps
start_e = .03                # initial epsilon of the epsilon-greedy policy
end_e = .01                  # final epsilon of the epsilon-greedy policy

learning_rate = 1e-4         # Adam step size
buffer_size = 100000         # replay buffer capacity (transitions)
total_timesteps = 2000000    # length of the training loop
learning_starts = 1          # no gradient steps until global_step exceeds this
target_network_frequency = 1000  # sync the target network every N steps

In [3]:
def _joypad_reset_passthrough(self, **kwargs):
    """Forward every keyword argument untouched to the wrapped env's ``reset``."""
    return self.env.reset(**kwargs)


# Replace JoypadSpace.reset so that keyword arguments (e.g. ``seed=``) reach the inner env.
JoypadSpace.reset = _joypad_reset_passthrough

In [4]:
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat each action for ``skip`` frames and max-pool the last two raw frames.

    ``step`` returns the pixel-wise maximum of the (up to) two most recent raw
    observations together with the *sum* of the rewards collected over the
    repeated frames.
    """

    def __init__(self, env=None, skip=4):
        """Return only every `skip`-th frame"""
        super(MaxAndSkipEnv, self).__init__(env)
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = collections.deque(maxlen=2)
        self._skip = skip

    def step(self, action):
        """Apply ``action`` for up to ``skip`` frames.

        Returns:
            (max_frame, total_reward, done, info) where ``total_reward`` is the
            reward accumulated over all frames executed in this call and
            ``done``/``info`` come from the last executed frame.
        """
        total_reward = 0.0
        done = None
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            self._obs_buffer.append(obs)
            total_reward += reward
            if done:
                # stop repeating the action once the episode has ended
                break
        max_frame = np.max(np.stack(self._obs_buffer), axis=0)
        # Return the accumulated reward; returning only the last frame's reward
        # would silently drop the reward earned during the skipped frames.
        return max_frame, total_reward, done, info

    def reset(self, **kwargs):
        """Clear past frame buffer and init to first obs"""
        self._obs_buffer.clear()
        obs = self.env.reset(**kwargs)
        self._obs_buffer.append(obs)
        return obs


class ProcessFrame84(gym.ObservationWrapper):
    """
    Observation wrapper that converts a 240x256 RGB frame into an
    84x84x1 greyscale ``uint8`` numpy array.
    """
    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(84, 84, 1), dtype=np.uint8
        )

    def observation(self, obs):
        """Delegate to the static ``process`` helper."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Greyscale, downsample to 110x84, then crop rows 18..101 to get 84x84x1."""
        # Only frames with exactly 240 * 256 * 3 elements are supported.
        if frame.size != 240 * 256 * 3:
            assert False, "Unknown resolution."
        rgb = np.reshape(frame, [240, 256, 3]).astype(np.float32)

        # Weighted sum of the three colour channels.
        red, green, blue = rgb[:, :, 0], rgb[:, :, 1], rgb[:, :, 2]
        grey = red * 0.299 + green * 0.587 + blue * 0.114

        # cv2.resize takes (width, height): the result is 110 rows x 84 columns.
        shrunk = cv2.resize(grey, (84, 110), interpolation=cv2.INTER_AREA)
        cropped = shrunk[18:102, :]
        return np.reshape(cropped, [84, 84, 1]).astype(np.uint8)

    def reset(self, **kwargs):
        """Reset the wrapped env and return its processed first observation."""
        first_obs = self.env.reset(**kwargs)
        return self.observation(first_obs)


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis (index 2) of each observation to the front."""

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        src_shape = self.observation_space.shape
        # Last axis first, followed by the first two axes.
        channels_first = (src_shape[-1], src_shape[0], src_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0, high=1.0, shape=channels_first, dtype=np.float32
        )

    def observation(self, observation):
        """Return ``observation`` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)

    def reset(self, **kwargs):
        """Reset the wrapped env and return its transformed first observation."""
        first_obs = self.env.reset(**kwargs)
        return self.observation(first_obs)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide them by 255."""

    def observation(self, obs):
        """Return a float32 copy of ``obs`` scaled by 1/255."""
        as_float = np.array(obs).astype(np.float32)
        return as_float / 255.0

    def reset(self, **kwargs):
        """Reset the wrapped env and return its scaled first observation."""
        first_obs = self.env.reset(**kwargs)
        return self.observation(first_obs)


class BufferWrapper(gym.ObservationWrapper):
    """Stack the ``n_steps`` most recent observations along axis 0.

    The observation space is the wrapped env's space with ``low``/``high``
    repeated ``n_steps`` times along axis 0.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        """
        Args:
            env: environment to wrap.
            n_steps: number of consecutive observations kept in the stack.
            dtype: dtype of the stacked observation buffer.
        """
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self, **kwargs):
        """Zero the stack, reset the wrapped env and return the first stacked observation."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset(**kwargs))

    def observation(self, observation):
        """Shift the stack by one slot, append ``observation`` and return a snapshot."""
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        # Return a copy: the internal buffer is mutated in place on the next call,
        # which would otherwise silently overwrite observations a caller still holds
        # (e.g. the previous state of a stored transition).
        return self.buffer.copy()

In [5]:
def make_env(env_id):
    """Create the Super Mario Bros env ``env_id`` and apply the wrapper stack.

    Wrappers are applied innermost-first in the order listed below; the
    returned object is the outermost ``JoypadSpace`` restricted to ``RIGHT_ONLY``.
    """
    wrapper_pipeline = (
        MaxAndSkipEnv,
        ProcessFrame84,
        ImageToPyTorch,
        lambda inner: BufferWrapper(inner, 4),
        ScaledFloatFrame,
        lambda inner: gym.wrappers.RecordVideo(inner, "../mario"),
        lambda inner: JoypadSpace(inner, RIGHT_ONLY),
    )
    env = gym_super_mario_bros.make(env_id)
    for wrap in wrapper_pipeline:
        env = wrap(env)
    return env

In [35]:
# Build the wrapped 'SuperMarioBros-v0' environment shared by the cells below.
env = make_env('SuperMarioBros-v0')

In [20]:
# TensorBoard writer; scalars logged through it are written under ../mario/runs.
writer = SummaryWriter('../mario/runs')

In [11]:
def render_image(env):
    """Render ``env`` once and write the result to '../test2.png' via imageio."""
    frames = [env.render()]
    imageio.mimsave('../test2.png', frames)

In [7]:
class QNetwork(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Expects 4 input channels; the flattened conv output must have 3136
    features. Emits one Q-value per action in ``env.action_space``.
    """

    def __init__(self, env):
        super().__init__()
        conv_stack = [
            nn.Conv2d(4, 32, 8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1),
            nn.ReLU(),
        ]
        dense_head = [
            nn.Flatten(),
            nn.Linear(3136, 512),
            nn.ReLU(),
            nn.Linear(512, env.action_space.n),
        ]
        # Keep a single flat Sequential so state_dict keys stay "network.<idx>.*".
        self.network = nn.Sequential(*conv_stack, *dense_head)

    def forward(self, x):
        """Return the Q-values for the batch ``x``."""
        q_values = self.network(x)
        return q_values

In [8]:
def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly interpolate from ``start_e`` towards ``end_e`` over ``duration`` steps.

    The value at step ``t`` follows the straight line through ``start_e`` (t=0)
    and ``end_e`` (t=duration) and is never allowed to drop below ``end_e``.
    """
    per_step_change = (end_e - start_e) / duration
    annealed = per_step_change * t + start_e
    # Floor the result at end_e.
    return end_e if end_e > annealed else annealed

In [9]:
# Online network (trained) and target network (used for the TD target).
q_network = QNetwork(env).to(device)
target_network = QNetwork(env).to(device)
# Only the online network's parameters are optimised.
optimizer = optim.Adam(params=q_network.parameters(), lr=learning_rate)
# Start the target network as an exact copy of the online network.
target_network.load_state_dict(q_network.state_dict())

<All keys matched successfully>

In [10]:
# Display the module repr to inspect the layer layout.
q_network

QNetwork(
  (network): Sequential(
    (0): Conv2d(4, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=3136, out_features=512, bias=True)
    (8): ReLU()
    (9): Linear(in_features=512, out_features=5, bias=True)
  )
)

In [11]:
# Replay buffer holding up to `buffer_size` transitions; samples are returned on `device`.
rb = ReplayBuffer(
    buffer_size=buffer_size,
    observation_space=env.observation_space,
    action_space=env.action_space,
    device=device,
)



In [36]:
def generate_experience(observation, global_step, total_reward):
    """Take one epsilon-greedy step in the global ``env`` and store the transition in ``rb``.

    Args:
        observation: current observation with a leading batch dimension.
        global_step: index of the current training step; drives the epsilon schedule
            and the TensorBoard step.
        total_reward: reward accumulated so far in the current episode.

    Returns:
        (next_observation, reward, done, info, action, total_reward). When the episode
        ended on this step, the env is reset and stepped once more with the same action,
        and ``next_observation``, ``reward``, ``done`` and ``info`` come from that extra
        step; ``total_reward`` is then 0.
    """
    epsilon = linear_schedule(start_e, end_e, exploration_fraction * total_timesteps, global_step)
    if random.random() < epsilon:
        action = env.action_space.sample()
    else:
        # Action selection needs no gradients; skipping autograd avoids building an
        # unused graph on every environment step.
        with torch.no_grad():
            q_values = q_network(torch.Tensor(observation).to(device))
        action = torch.argmax(q_values, dim=1).cpu().numpy()[0]

    next_obs, reward, done, info = env.step(action)
    next_observation = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0)
    rb.add(observation, next_observation, action, reward, done, [info])
    total_reward += reward

    if done:
        print(info)
        # NOTE(review): the +6000000 offset presumably continues the step axis of
        # earlier training runs -- confirm before reusing this notebook from scratch.
        writer.add_scalar("charts/episode_score", info["score"], global_step+6000000)
        writer.add_scalar("charts/reward", total_reward, global_step+6000000)
        total_reward = 0
        env.reset(seed=69)
        # The first observation of the new episode is obtained by stepping once
        # rather than from reset()'s return value.
        next_obs, reward, done, info = env.step(action)
        next_observation = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0)

    return next_observation, reward, done, info, action, total_reward


In [37]:
def train(update_target_network=False, global_step=0):
    """Run one DQN gradient step on a batch sampled from ``rb``.

    Args:
        update_target_network: when true, blend the online weights into the target
            network (weight ``tau``) after the optimiser step.
        global_step: training step used (with a fixed offset) as the TensorBoard step.
    """
    batch = rb.sample(batch_size)

    # Bootstrapped one-step TD target from the target network, no gradients needed.
    with torch.no_grad():
        next_q_max = target_network(batch.next_observations).max(dim=1).values
        not_done = 1 - batch.dones.flatten()
        td_target = batch.rewards.flatten() + gamma * next_q_max * not_done

    # Q-value of the action that was actually taken in each sampled transition.
    chosen_q = q_network(batch.observations).gather(1, batch.actions).squeeze()
    loss = F.mse_loss(td_target, chosen_q)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    writer.add_scalar('Loss/train', loss, global_step+6000000)

    if not update_target_network:
        return

    # target <- tau * online + (1 - tau) * target, parameter by parameter.
    for target_param, online_param in zip(target_network.parameters(), q_network.parameters()):
        blended = tau * online_param.data + (1.0 - tau) * target_param.data
        target_param.data.copy_(blended)

In [38]:
# Prime the environment: reset, then take one step with action 4 to obtain the
# first observation (with a leading batch dimension).
env.reset(seed=69)
next_obs, *_ = env.step(4)
observation = torch.tensor(next_obs, dtype=torch.float32).unsqueeze(0)
total_reward = 0

for global_step in tqdm(range(total_timesteps)):
    observation, reward, done, info, action, total_reward = generate_experience(
        observation, global_step, total_reward
    )
    # Train every `train_frequency` steps once past `learning_starts`; on steps that
    # are also a multiple of `target_network_frequency`, sync the target network.
    if global_step > learning_starts and global_step % train_frequency == 0:
        train(
            update_target_network=(global_step % target_network_frequency == 0),
            global_step=global_step,
        )

In [51]:
def record_video(env):
    """Play one greedy episode in ``env`` with the global ``q_network``.

    Prints the number of steps taken, the reward of the final step and the
    final ``info`` dict. Returns nothing.
    """
    env.reset(seed=69)
    # First observation comes from one step with action 4 after the reset.
    obs, *_ = env.step(4)
    observation = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
    i = 0
    done = False
    while not done:
        # Pure inference: no gradients are needed, so skip building the autograd graph.
        with torch.no_grad():
            q_values = q_network(torch.Tensor(observation).to(device))
        action = torch.argmax(q_values, dim=1).cpu().numpy()
        action = action[0]
        obs, reward, done, info = env.step(action)
        observation = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
        i += 1
    print(i)
    print(reward)
    print(info)

In [52]:
# Run one greedy episode; record_video has no return statement, so `r` is None.
r = record_video(env)

  logger.deprecation(
  logger.deprecation(
  logger.deprecation(
  logger.warn(
  logger.deprecation(
  if not isinstance(done, (bool, np.bool8)):
See here for more information: https://www.gymlibrary.ml/content/api/
  deprecation(


84
-15
{'coins': 0, 'flag_get': False, 'life': 255, 'score': 0, 'stage': 1, 'status': 'small', 'time': 395, 'world': 1, 'x_pos': 303, 'y_pos': 79}


In [39]:
# Persist the online network's weights (state_dict only).
model_path = "../mario/mario.cleanrl_model7"
torch.save(obj=q_network.state_dict(), f=model_path)
print(f"model saved to {model_path}")

In [14]:
model_path = f"../mario/{'mario'}.cleanrl_model4"

# map_location remaps the stored tensors onto the active `device`, so a checkpoint
# saved on one backend (e.g. MPS) still loads on a machine without that backend.
q_network.load_state_dict(torch.load(model_path, map_location=device))

<All keys matched successfully>

In [15]:
# After loading weights into q_network: rebuild the target network, start from a
# fresh Adam state, and sync the target with the online weights.
target_network = QNetwork(env).to(device)
optimizer = optim.Adam(params=q_network.parameters(), lr=learning_rate)
target_network.load_state_dict(q_network.state_dict())

<All keys matched successfully>

In [22]:
# NOTE(review): leftover scratch cell; nothing else depends on it -- safe to delete.
print('hi')