In [1]:
from nes_py.wrappers import JoypadSpace
import gym_super_mario_bros
from gym_super_mario_bros.actions import SIMPLE_MOVEMENT
from gym.spaces import Box

import gymnasium as gym
import math
import random
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
# from torchvision import transforms as T
# from tensordict import TensorDict

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import numpy as np
import os
# Matplotlib: an "inline" backend means we are running inside Jupyter, so the
# IPython display helpers are only imported in that case.
is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display

# Interactive mode: figures are redrawn without blocking the cell.
plt.ion()

# Super Mario Bros environment, restricted to the SIMPLE_MOVEMENT action set.
env = JoypadSpace(gym_super_mario_bros.make('SuperMarioBros-v0'), SIMPLE_MOVEMENT)

  logger.warn(
  logger.warn(


In [2]:

# --- DQN hyper-parameters ---------------------------------------------------
BATCH_SIZE = 128   # transitions sampled from replay memory per optimisation step
GAMMA = 0.99       # discount factor used in the TD target
EPS_START = 0.9    # initial epsilon of the epsilon-greedy schedule
EPS_END = 0.05     # asymptotic (minimum) epsilon
EPS_DECAY = 1000   # time constant of the exponential epsilon decay
TAU = 0.005        # soft-update rate for the target network
LR = 1e-4          # AdamW learning rate

# Seed every RNG the notebook draws from (python `random` for the epsilon
# coin-flip, numpy for random actions, torch for weight init / replay sampling)
# so that Restart & Run All is reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Global step counter (note: nothing in this notebook increments it).
steps_done = 0
# Per-episode step counts, intended for plotting training progress.
episode_durations = []

In [3]:
from gym.wrappers import FrameStack, GrayScaleObservation

# Convert RGB frames to a single grayscale channel. keep_dim=True keeps a
# trailing channel axis, so observations come out as (H, W, 1).
# No frame stacking is applied: the DQN below consumes a single frame (C=1).
env = GrayScaleObservation(env, keep_dim=True)


In [4]:
state = env.reset()
# Mario/DQN are constructed later with state_dim=(1, 240, 256); fail fast if
# the wrapped environment produces a different observation shape.
assert state.shape == (240, 256, 1), f"unexpected observation shape {state.shape}"
state.shape

(240, 256, 1)

In [5]:
from tensordict import TensorDict
from torchrl.data import TensorDictReplayBuffer, LazyMemmapStorage

In [6]:
class ReplayMemory(object):
    """Experience-replay buffer backed by torchrl's ``TensorDictReplayBuffer``.

    Transitions are kept in a CPU ``LazyMemmapStorage`` and moved to the
    requested device only when a batch is sampled.
    """

    def __init__(self, capacity):
        """
        Args:
            capacity: maximum number of transitions held in storage.
        """
        # Honour ``capacity``: it used to be ignored in favour of a hard-coded
        # 100000, which for full-resolution frames means a very large memmap.
        self.memory = TensorDictReplayBuffer(
            storage=LazyMemmapStorage(capacity, device=torch.device("cpu"))
        )

    def push(self, state, next_state, action, reward, done):
        """Store one transition.

        Args:
            state: observation tensor before the action.
            next_state: observation tensor after the action. ``None`` is
                accepted (e.g. for terminal transitions) and replaced by zeros
                shaped like ``state``, because storage and ``recall`` need a
                tensor here.
            action: integer action index.
            reward: scalar reward.
            done: whether the episode ended after this transition.
        """
        if next_state is None:
            next_state = torch.zeros_like(state)

        # Wrap scalars in 1-element tensors so every entry is a tensor.
        action = torch.tensor([action])
        reward = torch.tensor([reward])
        done = torch.tensor([done])

        self.memory.add(
            TensorDict(
                {
                    "state": state,
                    "next_state": next_state,
                    "action": action,
                    "reward": reward,
                    "done": done,
                },
                batch_size=[],
            )
        )

    def __len__(self):
        return len(self.memory)

    def recall(self, device):
        """Sample ``BATCH_SIZE`` transitions from memory and move them to ``device``.

        Returns:
            ``(state, next_state, action, reward, done)``. Axis 1 is squeezed
            from the state tensors (assumes each stored state carries a leading
            axis of size 1, as the training loop produces — TODO confirm for
            other callers); action/reward/done are fully squeezed to 1-D.
        """
        batch = self.memory.sample(BATCH_SIZE).to(device)
        state, next_state, action, reward, done = (
            batch.get(key) for key in ("state", "next_state", "action", "reward", "done")
        )
        return state.squeeze(1), next_state.squeeze(1), action.squeeze(), reward.squeeze(), done.squeeze()

In [7]:
class DQN(nn.Module):
    """Convolutional Q-network mapping (N, C, H, W) observations to (N, n_actions) Q-values."""

    # (kernel_size, stride) of the three conv layers, in order. Shared by the
    # layer construction and the flattened-size computation so they cannot drift.
    _CONV_SPECS = ((4, 2), (2, 2), (2, 1))

    def __init__(self, input_dim, n_actions):
        """
        Args:
            input_dim: observation shape ``(C, H, W)``.
            n_actions: number of discrete actions (size of the output layer).
        """
        super(DQN, self).__init__()
        c, h, w = input_dim
        self.network = self.__build_cnn(c, n_actions, h, w)

    def forward(self, x):
        return self.network(x)

    @classmethod
    def _flattened_size(cls, h, w):
        """Features produced by the conv stack (64 output channels) for an (h, w) input."""
        # Conv2d with no padding and dilation 1: out = floor((in - k) / s) + 1.
        for kernel_size, stride in cls._CONV_SPECS:
            h = (h - kernel_size) // stride + 1
            w = (w - kernel_size) // stride + 1
        return 64 * h * w

    def __build_cnn(self, c, output_dim, h=240, w=256):
        """Build the conv + MLP stack; the first Linear layer is sized from (h, w).

        The defaults reproduce the previously hard-coded 230144 input features
        for a 240x256 frame.
        """
        (k1, s1), (k2, s2), (k3, s3) = self._CONV_SPECS
        # NOTE(review): inputs are fed as raw float pixel values (no scaling to
        # [0, 1] is visible in this notebook) — consider normalising.
        return nn.Sequential(
            nn.Conv2d(in_channels=c, out_channels=32, kernel_size=k1, stride=s1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=k2, stride=s2),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=k3, stride=s3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(self._flattened_size(h, w), 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        )

In [8]:
class Mario:
    """Double-DQN agent for Super Mario Bros.

    Holds an online ``policy_net`` and a ``target_net`` (both ``DQN``), an
    AdamW optimizer over the policy network, and a ``ReplayMemory``.
    """

    def __init__(self, state_dim, action_dim) -> None:
        """
        Args:
            state_dim: observation shape passed to ``DQN`` as ``(C, H, W)``.
            action_dim: number of discrete actions.
        """
        self.state_dim = state_dim
        self.action_dim = action_dim

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.policy_net = DQN(state_dim, action_dim).to(device=self.device)
        self.target_net = DQN(state_dim, action_dim).to(device=self.device)
        # Start both networks from identical weights.
        self.target_net.load_state_dict(self.policy_net.state_dict())

        self.optimizer = optim.AdamW(self.policy_net.parameters(), lr=LR, amsgrad=True)
        self.memory = ReplayMemory(10000)
        # Number of select_action() calls so far; drives the epsilon decay.
        self.cur_step = 0

    def _as_batch(self, state):
        """Return ``state`` on the agent's device, adding a batch axis to a 3-D (C, H, W) input."""
        if state.dim() == 3:
            state = state.unsqueeze(0)
        return state.to(self.device)

    @staticmethod
    def _vote_action(action_values):
        """Most frequent greedy action across the batch (simply the argmax for a batch of one)."""
        # .cpu() so np.bincount also works when the network runs on CUDA.
        greedy = action_values.max(1).indices.cpu().numpy()
        return np.bincount(greedy).argmax()

    def select_action(self, state):
        """Epsilon-greedy action for ``state`` using the policy network.

        Epsilon decays exponentially from EPS_START towards EPS_END with the
        number of calls made so far (``self.cur_step``).
        """
        sample = random.random()
        # Previously this read the global ``steps_done``, which is never
        # incremented, so epsilon stayed at EPS_START forever.
        eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1.0 * self.cur_step / EPS_DECAY)
        self.cur_step += 1

        if sample > eps_threshold:
            # NOTE(review): policy_net stays in train mode here, so BatchNorm
            # normalises with this single frame's statistics and updates its
            # running stats.
            with torch.no_grad():
                action_values = self.policy_net(self._as_batch(state))
            return self._vote_action(action_values)
        return np.random.randint(self.action_dim)

    def _td_estimate(self, state, action):
        """Q_online(s, a) for every transition in the batch."""
        q_values = self.policy_net(state)
        rows = torch.arange(q_values.shape[0], device=q_values.device)
        return q_values[rows, action]

    @torch.no_grad()
    def _td_target(self, reward, next_state, done):
        """Double-DQN target: r + GAMMA * (1 - done) * Q_target(s', argmax_a Q_online(s', a))."""
        # The online network picks the action, the target network evaluates it.
        best_action = self.policy_net(next_state).max(1).indices
        next_q_all = self.target_net(next_state)
        rows = torch.arange(next_q_all.shape[0], device=next_q_all.device)
        next_Q = next_q_all[rows, best_action]
        return (reward + (1 - done.float()) * GAMMA * next_Q).float()

    def _update_Q_online(self, td_estimate, td_target):
        """One optimiser step on the Smooth-L1 (Huber) loss; returns the loss value."""
        criterion = nn.SmoothL1Loss()
        loss = criterion(td_estimate, td_target)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def _sync_Q_target(self):
        """Hard-copy the policy network weights into the target network."""
        self.target_net.load_state_dict(self.policy_net.state_dict())

    def optimize_model(self, ):
        """Sample a batch from memory and take one gradient step on the policy network.

        Returns:
            ``None`` while memory holds fewer than BATCH_SIZE transitions,
            otherwise ``(mean TD estimate, loss)``.
        """
        if len(self.memory) < BATCH_SIZE:
            return

        state, next_state, action, reward, done = self.memory.recall(self.device)

        # TD estimate: Q_online(s, a).
        td_est = self._td_estimate(state, action)

        # TD target (computed without gradients).
        td_tgt = self._td_target(reward, next_state, done)

        # Back-propagate the loss through the online Q network.
        loss = self._update_Q_online(td_est, td_tgt)

        return (td_est.mean().item(), loss)

    def save(self, save_path):
        """Pickle the whole target network module (not just its state_dict) to ``save_path``."""
        torch.save(
            self.target_net,
            save_path,
        )

    @torch.no_grad()
    def predict_action(self, state):
        """Greedy action from the target network: no exploration, no autograd graph."""
        action_values = self.target_net(self._as_batch(state))
        return self._vote_action(action_values)


In [9]:
use_cuda = torch.cuda.is_available()
# Far fewer episodes without a GPU (presumably because training is much slower on CPU).
num_episodes = 600 if use_cuda else 30
# An episode is cut off once the step index t exceeds this value.
MAX_EPISODE_STEP = 500

print(f"Using CUDA: {use_cuda}")
print()
mario = Mario(state_dim=(1, 240, 256), action_dim=env.action_space.n)


def to_model_input(obs):
    """Convert an (H, W, C) numpy frame into a float tensor of shape (1, C, H, W)."""
    chw = np.transpose(obs, (2, 0, 1))
    return torch.tensor(chw.copy(), dtype=torch.float).unsqueeze(0)


for i_episode in range(num_episodes):
    # Initialize the environment and get its state.
    state = to_model_input(env.reset())

    for t in count():
        action = mario.select_action(state)
        next_state, reward, done, info = env.step(action)

        # Keep the real terminal observation instead of None: the TD target
        # multiplies the bootstrap term by (1 - done), so its value is ignored
        # for terminal transitions, while ReplayMemory.recall needs a tensor.
        next_state = to_model_input(next_state)

        # Store the transition in memory.
        mario.memory.push(state, next_state, action, reward, done)

        # Perform one step of the optimization (on the policy network).
        mario.optimize_model()

        # Soft update of the target network's weights: θ′ ← τ θ + (1 − τ) θ′
        target_net_state_dict = mario.target_net.state_dict()
        policy_net_state_dict = mario.policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        mario.target_net.load_state_dict(target_net_state_dict)

        # Move to the next state.
        state = next_state

        if done or info["flag_get"] or t > MAX_EPISODE_STEP:
            break

    episode_durations.append(t + 1)
    print(f"{i_episode} is done")

print('Complete')

Using CUDA: False

0 is done
1 is done
2 is done
3 is done
4 is done
5 is done
6 is done
7 is done
8 is done
9 is done
10 is done
11 is done
12 is done
13 is done
14 is done
15 is done
16 is done
17 is done
18 is done
19 is done
20 is done
21 is done
22 is done
23 is done
24 is done
25 is done
26 is done
27 is done
28 is done
29 is done
Complete


In [10]:
# print(os.getcwd())

In [11]:
# Mario.save pickles the entire target-network module (despite the file name,
# this is not a bare state_dict).
mario.save('model_state_dict.pt')

In [12]:
# The checkpoint holds a pickled nn.Module (written with torch.save(target_net, ...)),
# so it must be fully unpickled. torch >= 2.6 defaults to weights_only=True, which
# rejects such files. Only load checkpoints you created yourself: unpickling can
# execute arbitrary code.
new_model = Mario(state_dim=(1, 240, 256), action_dim=env.action_space.n)
mario_load_target_net = torch.load(
    'model_state_dict.pt',
    map_location=new_model.device,  # lets a CUDA-trained checkpoint load on CPU
    weights_only=False,
)
new_model.target_net = mario_load_target_net

In [16]:
# Inference: BatchNorm layers should use their running statistics rather than
# the statistics of a single frame.
new_model.target_net.eval()
# Upper bound so the demo terminates even if the agent never reaches the flag.
MAX_PLAY_STEPS = 10_000

state = env.reset()
for _ in range(MAX_PLAY_STEPS):
    # (H, W, C) numpy frame -> float tensor of shape (1, C, H, W).
    model_input = torch.tensor(np.transpose(state, (2, 0, 1)).copy(), dtype=torch.float).unsqueeze(0)
    with torch.no_grad():
        action = new_model.predict_action(model_input)
    state, reward, done, info = env.step(action)
    if done:
        state = env.reset()
    if info["flag_get"]:
        break
    env.render()

KeyboardInterrupt: 

In [None]:

# def plot_durations(show_result=False):
#     plt.figure(1)
#     durations_t = torch.tensor(episode_durations, dtype=torch.float)
#     if show_result:
#         plt.title('Result')
#     else:
#         plt.clf()
#         plt.title('Training...')
#     plt.xlabel('Episode')
#     plt.ylabel('Duration')
#     plt.plot(durations_t.numpy())
#     # Take 100 episode averages and plot them too
#     if len(durations_t) >= 100:
#         means = durations_t.unfold(0, 100, 1).mean(1).view(-1)
#         means = torch.cat((torch.zeros(99), means))
#         plt.plot(means.numpy())

#     plt.pause(0.001)  # pause a bit so that plots are updated
#     if is_ipython:
#         if not show_result:
#             display.display(plt.gcf())
#             display.clear_output(wait=True)
#         else:
#             display.display(plt.gcf())