In [1]:
!pip install nes-py==0.2.6
!pip install gym-super-mario-bros
!apt-get update
!apt-get install ffmpeg libsm6 libxext6  -y
!apt install -y libgl1-mesa-glx
!pip install opencv-python

Collecting nes-py==0.2.6
  Using cached nes_py-0.2.6-cp39-cp39-macosx_10_9_x86_64.whl
Installing collected packages: nes-py
  Attempting uninstall: nes-py
    Found existing installation: nes-py 8.2.1
    Uninstalling nes-py-8.2.1:
      Successfully uninstalled nes-py-8.2.1
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gym-super-mario-bros 7.4.0 requires nes-py>=8.1.4, but you have nes-py 0.2.6 which is incompatible.[0m
Successfully installed nes-py-0.2.6
Collecting nes-py>=8.1.4
  Using cached nes_py-8.2.1-cp39-cp39-macosx_10_9_x86_64.whl
Installing collected packages: nes-py
  Attempting uninstall: nes-py
    Found existing installation: nes-py 0.2.6
    Uninstalling nes-py-0.2.6:
      Successfully uninstalled nes-py-0.2.6
Successfully installed nes-py-8.2.1
zsh:1: command not found: apt-get
zsh:1: command not found: apt-get
The operation couldn’t b

In [2]:
import collections
import pickle
import random

import cv2
import gym
import gym_super_mario_bros
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from gym_super_mario_bros.actions import RIGHT_ONLY
from IPython import display
from nes_py.wrappers import JoypadSpace
from tqdm import tqdm

In [3]:
class MaxAndSkipEnv(gym.Wrapper):
    """Repeat every action for `skip` frames and emit one pooled observation.

    Rewards collected during the repeated frames are summed, and the returned
    observation is the element-wise maximum over the most recent raw frames
    (the internal buffer keeps at most two of them).
    """

    def __init__(self, env=None, skip=4):
        super().__init__(env)
        self._skip = skip
        # Rolling window of the latest raw frames used for max pooling.
        self._obs_buffer = collections.deque(maxlen=2)

    def step(self, action):
        """Apply `action` up to `skip` times, stopping early when the episode ends."""
        reward_sum = 0.0
        done = None
        for _ in range(self._skip):
            frame, reward, done, info = self.env.step(action)
            self._obs_buffer.append(frame)
            reward_sum += reward
            if done:
                break
        pooled_frame = np.stack(self._obs_buffer).max(axis=0)
        return pooled_frame, reward_sum, done, info

    def reset(self):
        """Reset the wrapped env and restart the frame buffer from its first frame."""
        self._obs_buffer.clear()
        first_frame = self.env.reset()
        self._obs_buffer.append(first_frame)
        return first_frame


class ProcessFrame84(gym.ObservationWrapper):
    """
    Downsamples image to 84x84
    Greyscales image

    Returns numpy array
    """
    def __init__(self, env=None):
        super(ProcessFrame84, self).__init__(env)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, obs):
        """Convert one raw frame with `process`."""
        return ProcessFrame84.process(obs)

    @staticmethod
    def process(frame):
        """Turn a 240x256x3 frame into an 84x84x1 uint8 greyscale image.

        Raises:
            ValueError: if `frame` does not hold exactly 240 * 256 * 3 values.
        """
        # An explicit exception (instead of `assert False`) keeps the check alive
        # under `python -O`, where asserts are stripped and `img` would be unbound.
        if frame.size != 240 * 256 * 3:
            raise ValueError("Unknown resolution.")
        img = np.reshape(frame, [240, 256, 3]).astype(np.float32)
        # Weighted sum of the three colour channels -> single luminance channel.
        img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114
        # cv2.resize takes (width, height): the result is 110 rows by 84 columns.
        resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA)
        # Keep rows 18..101 so the final image is 84 rows tall.
        x_t = resized_screen[18:102, :]
        x_t = np.reshape(x_t, [84, 84, 1])
        return x_t.astype(np.uint8)


class ImageToPyTorch(gym.ObservationWrapper):
    """Move the last observation axis to the front (e.g. H x W x C -> C x H x W)."""

    def __init__(self, env):
        super().__init__(env)
        source_shape = self.observation_space.shape
        channel_first_shape = (source_shape[-1], source_shape[0], source_shape[1])
        self.observation_space = gym.spaces.Box(
            low=0.0,
            high=1.0,
            shape=channel_first_shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return `observation` with axis 2 moved to position 0."""
        return np.moveaxis(observation, source=2, destination=0)


class ScaledFloatFrame(gym.ObservationWrapper):
    """Convert observations to float32 and divide them by 255."""

    def observation(self, obs):
        """Return a new float32 array holding `obs / 255`."""
        as_float = np.array(obs, dtype=np.float32)
        return as_float / 255.0


class BufferWrapper(gym.ObservationWrapper):
    """Stack the last `n_steps` observations along axis 0.

    The observation space is built by repeating the wrapped space's `low` and
    `high` bounds `n_steps` times along axis 0.
    """

    def __init__(self, env, n_steps, dtype=np.float32):
        super(BufferWrapper, self).__init__(env)
        self.dtype = dtype
        old_space = env.observation_space
        self.observation_space = gym.spaces.Box(old_space.low.repeat(n_steps, axis=0),
                                                old_space.high.repeat(n_steps, axis=0), dtype=dtype)

    def reset(self):
        """Zero the stack, reset the wrapped env and push its first observation."""
        self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype)
        return self.observation(self.env.reset())

    def observation(self, observation):
        """Shift the stack by one slot, append `observation`, and return a snapshot.

        `reset` must have been called first, because it creates `self.buffer`.
        """
        self.buffer[:-1] = self.buffer[1:]
        self.buffer[-1] = observation
        # Return a copy: the internal buffer is overwritten in place on the next
        # step, which would otherwise silently change observations already handed out.
        return self.buffer.copy()


def make_env(env):
    """Wrap `env` with the preprocessing stack and restrict it to RIGHT_ONLY actions.

    Wrapper order, innermost first: MaxAndSkipEnv, ProcessFrame84,
    ImageToPyTorch, BufferWrapper (4 stacked observations), ScaledFloatFrame,
    and finally JoypadSpace as the outermost layer.
    """
    preprocessed = ImageToPyTorch(ProcessFrame84(MaxAndSkipEnv(env)))
    stacked = BufferWrapper(preprocessed, 4)
    scaled = ScaledFloatFrame(stacked)
    return JoypadSpace(scaled, RIGHT_ONLY)

In [4]:
class DQNSolver(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Args:
        input_shape: shape of one observation without the batch axis; its first
            entry is used as the number of input channels.
        n_actions: size of the output layer (one value per action).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()
        in_channels = input_shape[0]
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        )

        # The width of the first linear layer is measured from the conv stack.
        flat_features = self._get_conv_out(input_shape)
        self.fc = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return how many values the conv stack yields for one input of `shape`."""
        probe = torch.zeros(1, *shape)
        return self.conv(probe).numel()

    def forward(self, x):
        """Return the network output for a batch `x`; one row per batch element."""
        batch_size = x.size(0)
        features = self.conv(x)
        return self.fc(features.view(batch_size, -1))
    

class DQNAgent:
    """Epsilon-greedy DQN agent with a tensor-backed circular replay memory.

    With `double_dq=True` the agent keeps two networks: `local_net` (trained and
    used to act) and `target_net` (used for bootstrapped targets and refreshed
    from `local_net` every `self.copy` calls to `act`). With `double_dq=False`
    a single network `dqn` plays both roles.

    Args:
        state_space: shape of one observation (without the batch axis).
        action_space: number of discrete actions.
        max_memory_size: capacity of the replay memory.
        batch_size: number of transitions sampled per learning step.
        gamma: discount factor.
        lr: Adam learning rate.
        dropout: accepted for interface compatibility; not used by this class.
        exploration_max: initial epsilon.
        exploration_min: lower bound for epsilon.
        exploration_decay: factor applied to epsilon after every learning step.
        double_dq: whether to use the two-network variant.
        pretrained: when true, network weights and the replay memory are loaded
            from files in the current working directory.
    """

    def __init__(self, state_space, action_space, max_memory_size, batch_size, gamma, lr,
                 dropout, exploration_max, exploration_min, exploration_decay, double_dq, pretrained):

        # Define DQN Layers
        self.state_space = state_space
        self.action_space = action_space
        self.double_dq = double_dq
        self.pretrained = pretrained
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        if self.double_dq:
            self.local_net = DQNSolver(state_space, action_space).to(self.device)
            self.target_net = DQNSolver(state_space, action_space).to(self.device)

            if self.pretrained:
                self.local_net.load_state_dict(torch.load("dq1.pt", map_location=torch.device(self.device)))
                self.target_net.load_state_dict(torch.load("dq2.pt", map_location=torch.device(self.device)))

            self.optimizer = torch.optim.Adam(self.local_net.parameters(), lr=lr)
            self.copy = 5000  # Copy the local model weights into the target network every 5000 steps
            self.step = 0
        else:
            self.dqn = DQNSolver(state_space, action_space).to(self.device)

            if self.pretrained:
                self.dqn.load_state_dict(torch.load("dq.pt", map_location=torch.device(self.device)))
            self.optimizer = torch.optim.Adam(self.dqn.parameters(), lr=lr)

        # Create memory
        self.max_memory_size = max_memory_size
        if self.pretrained:
            # NOTE(security): torch.load and pickle.load can execute arbitrary code
            # from the file being read. Only load checkpoints you created yourself.
            self.STATE_MEM = torch.load("STATE_MEM.pt")
            self.ACTION_MEM = torch.load("ACTION_MEM.pt")
            self.REWARD_MEM = torch.load("REWARD_MEM.pt")
            self.STATE2_MEM = torch.load("STATE2_MEM.pt")
            self.DONE_MEM = torch.load("DONE_MEM.pt")
            with open("ending_position.pkl", 'rb') as f:
                self.ending_position = pickle.load(f)
            with open("num_in_queue.pkl", 'rb') as f:
                self.num_in_queue = pickle.load(f)
        else:
            self.STATE_MEM = torch.zeros(max_memory_size, *self.state_space)
            self.ACTION_MEM = torch.zeros(max_memory_size, 1)
            self.REWARD_MEM = torch.zeros(max_memory_size, 1)
            self.STATE2_MEM = torch.zeros(max_memory_size, *self.state_space)
            self.DONE_MEM = torch.zeros(max_memory_size, 1)
            self.ending_position = 0
            self.num_in_queue = 0

        self.memory_sample_size = batch_size

        # Learning parameters
        self.gamma = gamma
        self.l1 = nn.SmoothL1Loss().to(self.device)  # Also known as Huber loss
        self.exploration_max = exploration_max
        self.exploration_rate = exploration_max
        self.exploration_min = exploration_min
        self.exploration_decay = exploration_decay

    def remember(self, state, action, reward, state2, done):
        """Store one transition at the write cursor, overwriting the oldest when full.

        All five arguments are tensors; each is converted with `.float()` before
        being written into the corresponding memory tensor.
        """
        self.STATE_MEM[self.ending_position] = state.float()
        self.ACTION_MEM[self.ending_position] = action.float()
        self.REWARD_MEM[self.ending_position] = reward.float()
        self.STATE2_MEM[self.ending_position] = state2.float()
        self.DONE_MEM[self.ending_position] = done.float()
        self.ending_position = (self.ending_position + 1) % self.max_memory_size  # FIFO tensor
        self.num_in_queue = min(self.num_in_queue + 1, self.max_memory_size)

    def recall(self):
        """Sample `memory_sample_size` stored transitions uniformly, with replacement.

        Returns:
            Tuple `(STATE, ACTION, REWARD, STATE2, DONE)` of batched tensors.
        """
        idx = random.choices(range(self.num_in_queue), k=self.memory_sample_size)

        STATE = self.STATE_MEM[idx]
        ACTION = self.ACTION_MEM[idx]
        REWARD = self.REWARD_MEM[idx]
        STATE2 = self.STATE2_MEM[idx]
        DONE = self.DONE_MEM[idx]

        return STATE, ACTION, REWARD, STATE2, DONE

    def act(self, state):
        """Pick an action epsilon-greedily and return it as a 1x1 tensor.

        In double-DQ mode every call also advances `self.step`, which drives the
        periodic target-network refresh in `experience_replay`.
        """
        if self.double_dq:
            self.step += 1
        if random.random() < self.exploration_rate:
            return torch.tensor([[random.randrange(self.action_space)]])

        # Local net is used for the policy in double-DQ mode.
        policy_net = self.local_net if self.double_dq else self.dqn
        # Action selection never needs gradients; skip building the autograd graph.
        with torch.no_grad():
            q_values = policy_net(state.to(self.device))
        return torch.argmax(q_values).unsqueeze(0).unsqueeze(0).cpu()

    def copy_model(self):
        """Copy local net weights into target net (double-DQ mode only)."""
        self.target_net.load_state_dict(self.local_net.state_dict())

    def experience_replay(self):
        """Run one optimisation step on a sampled batch and decay epsilon.

        Does nothing (apart from the periodic target refresh) until the memory
        holds at least `memory_sample_size` transitions.
        """
        if self.double_dq and self.step % self.copy == 0:
            self.copy_model()

        if self.memory_sample_size > self.num_in_queue:
            return

        STATE, ACTION, REWARD, STATE2, DONE = self.recall()
        STATE = STATE.to(self.device)
        ACTION = ACTION.to(self.device)
        REWARD = REWARD.to(self.device)
        STATE2 = STATE2.to(self.device)
        DONE = DONE.to(self.device)

        self.optimizer.zero_grad()

        # The bootstrapped target is treated as a constant: computing it under
        # no_grad keeps gradients from flowing through (or accumulating on) the
        # network that produces it.
        # Target: r + gamma * max_a Q_bootstrap(S', a), zeroed for terminal states.
        # NOTE(review): in double-DQ mode the target net both selects and evaluates
        # the next action; textbook Double DQN selects with the local net. Left
        # unchanged here - confirm which variant is intended.
        bootstrap_net = self.target_net if self.double_dq else self.dqn
        with torch.no_grad():
            next_q = bootstrap_net(STATE2).max(1).values.unsqueeze(1)
            target = REWARD + torch.mul(self.gamma * next_q, 1 - DONE)

        # Q-value of the action actually taken, from the network being trained.
        online_net = self.local_net if self.double_dq else self.dqn
        current = online_net(STATE).gather(1, ACTION.long())

        loss = self.l1(current, target)
        loss.backward()  # Compute gradients
        self.optimizer.step()  # Apply the parameter update

        self.exploration_rate *= self.exploration_decay

        # Makes sure that exploration rate is always at least 'exploration min'
        self.exploration_rate = max(self.exploration_rate, self.exploration_min)

In [5]:
def vectorize_action(action, action_space):
    """Return a one-hot list for the scalar `action` among `action_space` actions."""
    one_hot = []
    for _ in range(action):
        one_hot.append(0)
    one_hot.append(1)
    for _ in range(action + 1, action_space):
        one_hot.append(0)
    return one_hot

In [6]:
def show_state(env, ep=0, info=""):
    """Draw the environment's current RGB frame in the notebook output area.

    The frame is drawn on matplotlib figure 3, titled with the episode number
    and `info`, and replaces whatever the cell displayed before.
    """
    caption = "Episode: %d %s" % (ep, info)

    plt.figure(3)
    plt.clf()
    rgb_frame = env.render(mode='rgb_array')
    plt.imshow(rgb_frame)
    plt.title(caption)
    plt.axis('off')

    # Swap the previous output for the freshly drawn figure.
    display.clear_output(wait=True)
    current_figure = plt.gcf()
    display.display(current_figure)

In [7]:
def _to_batched_state(observation):
    """Convert one observation into a float32 tensor with a leading batch axis."""
    # Going through a single ndarray avoids the slow list-of-ndarrays path taken
    # by `torch.Tensor([observation])`, which PyTorch warns about.
    return torch.tensor(np.asarray(observation, dtype=np.float32)).unsqueeze(0)


def run(training_mode, pretrained, num_episodes=200):
    """Train or replay a DQN agent on SuperMarioBros-1-1-v0.

    Args:
        training_mode: when true, transitions are stored, the agent learns after
            every step, and weights, replay memory and reward history are saved
            to the working directory at the end. When false, every step is
            rendered with `show_state` and nothing is saved.
        pretrained: forwarded to `DQNAgent`; when true the agent loads previously
            saved weights and replay memory.
        num_episodes: number of episodes to play.
    """
    env = gym_super_mario_bros.make('SuperMarioBros-1-1-v0')
    env = make_env(env)  # Wraps the environment so that frames are grayscale
    observation_space = env.observation_space.shape
    action_space = env.action_space.n
    agent = DQNAgent(state_space=observation_space,
                     action_space=action_space,
                     max_memory_size=30000,
                     batch_size=32,
                     gamma=0.90,
                     lr=0.00025,
                     dropout=0.,
                     exploration_max=1.0,
                     exploration_min=0.02,
                     exploration_decay=0.99,
                     double_dq=True,
                     pretrained=pretrained)

    if not training_mode:
        # Epsilon only decays inside experience_replay, which is skipped when not
        # training; without this the replayed policy would be 100% random.
        agent.exploration_rate = agent.exploration_min

    total_rewards = []

    try:
        for ep_num in tqdm(range(num_episodes)):
            state = _to_batched_state(env.reset())
            total_reward = 0
            while True:
                if not training_mode:
                    show_state(env, ep_num)
                action = agent.act(state)

                state_next, reward, done, info = env.step(int(action.item()))
                total_reward += reward
                state_next = _to_batched_state(state_next)

                if training_mode:
                    reward_t = torch.tensor([reward]).unsqueeze(0)
                    terminal_t = torch.tensor([int(done)]).unsqueeze(0)
                    agent.remember(state, action, reward_t, state_next, terminal_t)
                    agent.experience_replay()

                state = state_next
                if done:
                    break

            total_rewards.append(total_reward)

            print("Total reward after episode {} is {}".format(ep_num + 1, total_rewards[-1]))

        if training_mode:
            with open("ending_position.pkl", "wb") as f:
                pickle.dump(agent.ending_position, f)
            with open("num_in_queue.pkl", "wb") as f:
                pickle.dump(agent.num_in_queue, f)
            with open("total_rewards.pkl", "wb") as f:
                pickle.dump(total_rewards, f)
            if agent.double_dq:
                torch.save(agent.local_net.state_dict(), "dq1.pt")
                torch.save(agent.target_net.state_dict(), "dq2.pt")
            else:
                torch.save(agent.dqn.state_dict(), "dq.pt")
            torch.save(agent.STATE_MEM,  "STATE_MEM.pt")
            torch.save(agent.ACTION_MEM, "ACTION_MEM.pt")
            torch.save(agent.REWARD_MEM, "REWARD_MEM.pt")
            torch.save(agent.STATE2_MEM, "STATE2_MEM.pt")
            torch.save(agent.DONE_MEM,   "DONE_MEM.pt")
    finally:
        # Release the emulator even if the run is interrupted or fails.
        env.close()

    # Rolling average over 500 episodes; only meaningful once more than 500
    # episodes have actually been played.
    window = 500
    if len(total_rewards) > window:
        rolling_mean = np.convolve(total_rewards, np.ones((window,)) / window, mode="valid").tolist()
        plt.title("Episodes trained vs. Average Rewards (per 500 eps)")
        # Pad with window - 1 zeros so index i is the average of the window ending at episode i.
        plt.plot([0 for _ in range(window - 1)] + rolling_mean)
        plt.show()

run(training_mode=True, pretrained=False)

  logger.warn(
  state = torch.Tensor([state])
  0%|▏                                          | 1/200 [00:13<43:57, 13.25s/it]

Total reward after episode 1 is 807.0


  1%|▍                                          | 2/200 [00:30<51:00, 15.46s/it]

Total reward after episode 2 is 752.0


  2%|▌                                        | 3/200 [00:58<1:09:57, 21.31s/it]

Total reward after episode 3 is 572.0


  2%|▊                                        | 4/200 [01:36<1:31:05, 27.89s/it]

Total reward after episode 4 is 1007.0


  2%|█                                        | 5/200 [01:39<1:01:17, 18.86s/it]

Total reward after episode 5 is 252.0


  3%|█▎                                         | 6/200 [01:52<54:19, 16.80s/it]

Total reward after episode 6 is 627.0


  4%|█▌                                         | 7/200 [01:54<39:21, 12.24s/it]

Total reward after episode 7 is 248.0


  4%|█▋                                         | 8/200 [01:57<29:33,  9.24s/it]

Total reward after episode 8 is 247.0


  4%|█▉                                         | 9/200 [02:00<23:00,  7.23s/it]

Total reward after episode 9 is 252.0


  5%|██                                        | 10/200 [02:03<18:41,  5.90s/it]

Total reward after episode 10 is 251.0


  6%|██▎                                       | 11/200 [02:06<15:44,  5.00s/it]

Total reward after episode 11 is 252.0


  6%|██▌                                       | 12/200 [02:09<13:42,  4.37s/it]

Total reward after episode 12 is 251.0


  6%|██▋                                       | 13/200 [02:12<12:10,  3.91s/it]

Total reward after episode 13 is 248.0


  7%|██▉                                       | 14/200 [02:15<11:05,  3.58s/it]

Total reward after episode 14 is 248.0


  8%|███▏                                      | 15/200 [02:18<10:25,  3.38s/it]

Total reward after episode 15 is 251.0


  8%|███▎                                      | 16/200 [02:20<09:57,  3.25s/it]

Total reward after episode 16 is 251.0


  8%|███▌                                      | 17/200 [02:23<09:37,  3.15s/it]

Total reward after episode 17 is 251.0


  9%|███▊                                      | 18/200 [02:26<09:15,  3.05s/it]

Total reward after episode 18 is 252.0


 10%|███▉                                      | 19/200 [02:29<08:54,  2.95s/it]

Total reward after episode 19 is 252.0


 10%|████▏                                     | 20/200 [02:32<08:46,  2.92s/it]

Total reward after episode 20 is 252.0


 10%|████▍                                     | 21/200 [02:35<08:33,  2.87s/it]

Total reward after episode 21 is 247.0


 11%|████▌                                     | 22/200 [02:37<08:31,  2.87s/it]

Total reward after episode 22 is 251.0


 12%|████▊                                     | 23/200 [02:40<08:19,  2.82s/it]

Total reward after episode 23 is 252.0


 12%|█████                                     | 24/200 [02:43<08:13,  2.81s/it]

Total reward after episode 24 is 251.0


 12%|█████▎                                    | 25/200 [02:46<08:03,  2.76s/it]

Total reward after episode 25 is 252.0


 13%|█████▍                                    | 26/200 [02:48<08:00,  2.76s/it]

Total reward after episode 26 is 251.0


 14%|█████▋                                    | 27/200 [02:59<14:41,  5.10s/it]

Total reward after episode 27 is 629.0


 14%|█████▉                                    | 28/200 [03:02<12:47,  4.46s/it]

Total reward after episode 28 is 251.0


 14%|██████                                    | 29/200 [03:05<11:31,  4.05s/it]

Total reward after episode 29 is 250.0


 15%|██████▎                                   | 30/200 [03:08<10:38,  3.76s/it]

Total reward after episode 30 is 251.0


 16%|██████▌                                   | 31/200 [03:11<10:03,  3.57s/it]

Total reward after episode 31 is 252.0


 16%|██████▋                                   | 32/200 [03:14<09:45,  3.49s/it]

Total reward after episode 32 is 251.0


 16%|██████▉                                   | 33/200 [03:18<09:28,  3.41s/it]

Total reward after episode 33 is 252.0


 17%|███████▏                                  | 34/200 [03:21<09:22,  3.39s/it]

Total reward after episode 34 is 250.0


 18%|███████▎                                  | 35/200 [03:24<09:12,  3.35s/it]

Total reward after episode 35 is 252.0


 18%|███████▌                                  | 36/200 [03:28<09:06,  3.33s/it]

Total reward after episode 36 is 252.0


 18%|███████▊                                  | 37/200 [03:31<09:13,  3.40s/it]

Total reward after episode 37 is 252.0


 19%|███████▉                                  | 38/200 [03:35<09:11,  3.41s/it]

Total reward after episode 38 is 252.0


 20%|████████▏                                 | 39/200 [03:38<09:17,  3.46s/it]

Total reward after episode 39 is 252.0


 20%|████████▍                                 | 40/200 [03:50<16:09,  6.06s/it]

Total reward after episode 40 is 626.0


 20%|████████▌                                 | 41/200 [03:54<14:04,  5.31s/it]

Total reward after episode 41 is 252.0


 21%|████████▊                                 | 42/200 [04:34<41:13, 15.66s/it]

Total reward after episode 42 is 782.0


 22%|█████████                                 | 43/200 [04:37<31:25, 12.01s/it]

Total reward after episode 43 is 248.0


 22%|█████████▏                                | 44/200 [04:41<24:34,  9.45s/it]

Total reward after episode 44 is 252.0


 22%|█████████▍                                | 45/200 [04:44<19:47,  7.66s/it]

Total reward after episode 45 is 248.0


 23%|█████████▋                                | 46/200 [04:54<21:30,  8.38s/it]

Total reward after episode 46 is 637.0


 24%|█████████▊                                | 47/200 [04:58<17:41,  6.94s/it]

Total reward after episode 47 is 250.0


 24%|██████████                                | 48/200 [05:01<15:06,  5.97s/it]

Total reward after episode 48 is 251.0


 24%|██████████▎                               | 49/200 [05:05<13:21,  5.31s/it]

Total reward after episode 49 is 252.0


 25%|██████████▌                               | 50/200 [05:16<17:21,  6.94s/it]

Total reward after episode 50 is 628.0


 26%|██████████▋                               | 51/200 [05:26<19:55,  8.02s/it]

Total reward after episode 51 is 639.0


 26%|██████████▉                               | 52/200 [05:30<16:37,  6.74s/it]

Total reward after episode 52 is 252.0


 26%|███████████▏                              | 53/200 [05:34<14:17,  5.83s/it]

Total reward after episode 53 is 249.0


 26%|███████████▏                              | 53/200 [05:39<15:41,  6.41s/it]


KeyboardInterrupt: 

In [None]:
def show_policy(agent, env, steps=100):
    """Let `agent` play `env` for `steps` steps, rendering each frame inline.

    The action for each step comes from `agent.act`, so the agent's current
    exploration rate still applies. The environment is reset whenever an
    episode ends.

    Args:
        agent: object exposing `act(state)` that returns a one-element tensor,
            such as `DQNAgent`.
        env: wrapped environment as returned by `make_env`.
        steps: number of environment steps to play.
    """
    observation = env.reset()
    for _ in range(steps):
        # `act` expects a batched float tensor, not the raw ndarray observation.
        state = torch.tensor(np.asarray(observation, dtype=np.float32)).unsqueeze(0)
        action = agent.act(state)
        # The environment takes a plain integer action index.
        observation, reward, done, info = env.step(int(action.item()))
        if done:
            observation = env.reset()
        display.clear_output(wait=True)
        plt.axis('off')
        plt.imshow(env.render(mode='rgb_array'))
        plt.show()
    return