In [1]:
!pip install gym_super_mario_bros



In [5]:
# wrappers.py
import gym
import cv2
import numpy as np
from gym.spaces import Box
from collections import deque

class SkipFrame(gym.Wrapper):
    """Repeat each action for `skip` env steps and return only the last frame.

    Args:
        env: environment to wrap; its `step` must return the 4-tuple
            `(obs, reward, done, info)`.
        skip: number of times every action is repeated. Must be >= 1.

    Raises:
        ValueError: if `skip` is smaller than 1 (otherwise `step` would have
            no observation to return).
    """
    def __init__(self, env, skip=4):
        super().__init__(env)
        if skip < 1:
            raise ValueError(f"skip must be >= 1, got {skip}")
        self._skip = skip

    def step(self, action):
        """Apply `action` up to `skip` times and sum the rewards.

        Stops early as soon as the wrapped env reports `done`.

        Returns:
            (obs, total_reward, done, info) where obs/done/info come from the
            last inner step actually executed.
        """
        total_reward = 0.0
        for _ in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                # Do not step a finished episode any further.
                break
        return obs, total_reward, done, info

class GrayScaleObservation(gym.ObservationWrapper):
    """Observation wrapper that turns RGB frames into single-channel grayscale."""

    def __init__(self, env):
        super().__init__(env)
        # Keep only the first two axes of the wrapped space's shape.
        height_width = self.observation_space.shape[:2]
        self.observation_space = Box(
            low=0,
            high=255,
            shape=height_width,
            dtype=np.uint8,
        )

    def observation(self, observation):
        """Return `observation` converted with OpenCV's RGB -> gray transform."""
        return cv2.cvtColor(observation, cv2.COLOR_RGB2GRAY)

class ResizeObservation(gym.ObservationWrapper):
    """Resize observation frames to a fixed (height, width).

    Args:
        env: environment to wrap.
        size: an int for a square output, or an iterable `(height, width)`.
    """
    def __init__(self, env, size=84):
        super().__init__(env)
        if isinstance(size, int):
            self.size = (size, size)
        else:
            self.size = tuple(size)

        obs_shape = self.size  # new shape (height, width)
        self.observation_space = Box(low=0, high=255, shape=obs_shape, dtype=np.uint8)

    def observation(self, observation):
        """Return `observation` resized to `self.size` using area interpolation."""
        height, width = self.size
        # cv2.resize expects dsize as (width, height), i.e. the reverse of the
        # (height, width) order stored in self.size and advertised in
        # observation_space. Passing self.size directly transposes non-square
        # targets.
        observation = cv2.resize(observation, (width, height), interpolation=cv2.INTER_AREA)
        return observation

class NormalizeObservation(gym.ObservationWrapper):
    """Observation wrapper that rescales frames by 1/255 and casts to float32."""

    def __init__(self, env):
        super().__init__(env)
        wrapped_shape = self.observation_space.shape
        self.observation_space = Box(
            low=0,
            high=1.0,
            shape=wrapped_shape,
            dtype=np.float32,
        )

    def observation(self, observation):
        """Return a float32 copy of `observation` divided by 255."""
        as_float = np.array(observation, dtype=np.float32)
        return as_float / 255.0

class FrameStack(gym.Wrapper):
    """Stack the last `n_frames` observations along a new leading axis.

    Args:
        env: environment to wrap; `step` must return `(obs, reward, done, info)`
            and `reset` must return the observation.
        n_frames: number of most recent frames kept in the stack.
    """
    def __init__(self, env, n_frames=4):
        super().__init__(env)
        self.n_frames = n_frames
        self.frames = deque([], maxlen=n_frames)

        # _get_obs() stacks whole frames with np.stack(axis=0), so the stacked
        # shape is always (n_frames, *frame_shape). Advertise exactly that so
        # the declared space matches the returned arrays for any frame rank.
        shp = env.observation_space.shape
        obs_shape = (n_frames, *shp)

        # Use scalar bounds taken from the wrapped space.
        low = np.min(env.observation_space.low)
        high = np.max(env.observation_space.high)
        self.observation_space = Box(
            low=low, high=high, shape=obs_shape, dtype=env.observation_space.dtype
        )

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with copies of the first frame.

        Any keyword arguments are forwarded to the wrapped env's `reset`.
        """
        obs = self.env.reset(**kwargs)
        # The deque has maxlen=n_frames, so n_frames appends overwrite
        # everything left from the previous episode.
        for _ in range(self.n_frames):
            self.frames.append(obs)
        return self._get_obs()

    def step(self, action):
        """Step the wrapped env, push the new frame and return the stacked frames."""
        obs, reward, done, info = self.env.step(action)
        self.frames.append(obs)
        return self._get_obs(), reward, done, info

    def _get_obs(self):
        # Stack frames along first dimension
        return np.stack(self.frames, axis=0)
    
from IPython import embed
import torch
import torch.nn.functional as F
import torch.nn as nn
from torchrl.modules import NoisyLinear
import numpy as np
import random, os
import pickle
from collections import deque

class DuelingDQN(nn.Module):
    """Convolutional Q-network with separate value and advantage heads.

    Args:
        input_shape: shape of one observation without the batch axis; its
            first entry is used as the number of input channels.
        n_actions: number of outputs of the advantage head (one per action).
    """

    def __init__(self, input_shape, n_actions):
        super().__init__()

        in_channels = input_shape[0]
        feature_layers = [
            nn.Conv2d(in_channels, 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU(),
        ]
        self.conv = nn.Sequential(*feature_layers)

        flat_features = self._get_conv_out(input_shape)

        # Head producing a single state value V(s).
        self.value_stream = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, 1),
        )

        # Head producing one advantage A(s, a) per action.
        self.advantage_stream = nn.Sequential(
            nn.Linear(flat_features, 512),
            nn.ReLU(),
            nn.Linear(512, n_actions),
        )

    def _get_conv_out(self, shape):
        """Return the number of features the conv stack emits for one input of `shape`."""
        probe = torch.zeros(1, *shape)
        return int(np.prod(self.conv(probe).size()))

    def forward(self, x):
        """Return Q-values of shape (B, n_actions) for a batch `x`."""
        features = self.conv(x).view(x.size()[0], -1)

        state_value = self.value_stream(features)    # (B, 1)
        advantages = self.advantage_stream(features)  # (B, n_actions)

        # Q(s, a) = V(s) + (A(s, a) - mean_a' A(s, a'))
        centered = advantages - advantages.mean(dim=1, keepdim=True)
        return state_value + centered
    
class ICM(nn.Module):
    """Curiosity module with a shared encoder, a forward model and an inverse model.

    Args:
        input_shape: shape of one observation without the batch axis; its
            first entry is used as the number of input channels.
        embed_dim: size of the feature vector produced by the encoder.
        n_actions: number of discrete actions (one-hot width / inverse-model outputs).
    """
    def __init__(self, input_shape, embed_dim, n_actions):
        super().__init__()
        self.n_actions = n_actions

        self.conv_net = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1),
            nn.ReLU()
        )

        conv_out_size = self._get_conv_out(input_shape)

        # NOTE: conv_net is registered both as `self.conv_net` and as
        # `self.encoder[0]`; the parameters are shared, and the state_dict
        # contains both key prefixes. Kept as-is so saved checkpoints load.
        self.encoder = nn.Sequential(
            self.conv_net,
            nn.Flatten(), # flattens from dim 1 onwards (batch dimension is kept)
            nn.Linear(conv_out_size, embed_dim),
            nn.ReLU()
        )

        # input: [phi(s_t), action (one-hot)] -> phi(s_t+1)
        self.forward_model = nn.Sequential(
            nn.Linear(embed_dim + n_actions, 256),
            nn.ReLU(),
            nn.Linear(256, embed_dim)
        )

        # input: [phi(s_t), phi(s_t+1)] -> action a
        self.inverse_model = nn.Sequential(
            nn.Linear(embed_dim * 2, 256),
            nn.ReLU(),
            nn.Linear(256, n_actions)
        )

    def forward(self, state, next_state, action):
        """Run encoder, forward model and inverse model on a batch of transitions.

        Args:
            state: batch of observations, (B, *input_shape).
            next_state: batch of successor observations, (B, *input_shape).
            action: integer action indices with B elements, shaped (B,) or (B, 1).

        Returns:
            (predicted_phi_next, predicted_action, phi_next):
            forward-model prediction (B, embed_dim), inverse-model logits
            (B, n_actions) and the encoded next state (B, embed_dim).
        """
        phi = self.encoder(state) # (B, embed_dim)
        phi_next = self.encoder(next_state) # (B, embed_dim)
        # Flatten to (B,). reshape(-1) instead of squeeze(-1): squeeze would
        # collapse a (1,) batch to a 0-dim tensor and break the concatenation
        # below. one_hot needs int64 indices, hence .long().
        action = action.reshape(-1).long()  # (B, )
        action_onehot = F.one_hot(action, self.n_actions).float() # (B, n_actions)
        # detach to stop the gradients of the forward model from flowing into the encoder
        fwd_model_input = torch.cat([phi.detach(), action_onehot], dim=1) # (B, embed_dim + n_actions)
        inv_model_input = torch.cat([phi, phi_next], dim=1) # (B, embed_dim * 2)
        predicted_phi_next = self.forward_model(fwd_model_input) # (B, embed_dim)
        predicted_action = self.inverse_model(inv_model_input) # (B, n_actions)
        return predicted_phi_next, predicted_action, phi_next

    def _get_conv_out(self, shape):
        """Return the number of features conv_net emits for one input of `shape`."""
        # Shape probe only: no need to record an autograd graph for it.
        with torch.no_grad():
            o = self.conv_net(torch.zeros(1, *shape)) # 1 dummy input through conv
        return int(np.prod(o.size()))
    
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchrl.modules import NoisyLinear
from torch.cuda.amp import autocast
import numpy as np
import random
import os
from collections import deque
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer, TensorDictReplayBuffer
from model import DuelingDQN, ICM
import ipdb

class MarioAgent:
    """Double Dueling-DQN agent with an ICM curiosity bonus and a replay buffer.

    The agent owns a policy network and a target network (both `DuelingDQN`),
    an `ICM` module with its own optimizer, and a torchrl `TensorDictReplayBuffer`
    backed by `LazyMemmapStorage`.

    Args:
        state_size: observation shape without batch axis, passed to the networks.
        action_size: number of discrete actions.
        batch_size: number of transitions sampled per `train()` call.
        lr: learning rate used for both Adam optimizers.
        gamma: discount factor.
        capacity: maximum number of transitions held by the replay storage.
        update_target_freq: number of `train()` updates between target-network syncs.
        tau: 1.0 (or more) -> hard copy on sync; < 1.0 -> soft (Polyak) update on sync.
        eps_start, eps_min, eps_fraction: epsilon schedule, see `decay_epsilon`.
        alpha, beta, beta_increment, eps: stored on the instance; not read by any
            method of this class (the prioritized-replay path is not active).
        num_envs: accepted for interface compatibility; not used in this class.
        scratch_dir: directory handed to `LazyMemmapStorage` for the
            memory-mapped replay data. Defaults to the path that used to be
            hard-coded here.
    """
    def __init__(self, state_size=(4, 84, 84), action_size=12, batch_size=32, lr=2.5e-4, gamma=0.99,
                 capacity=100000, update_target_freq=10000, tau=1.0, eps_start=1.0, eps_min=0.1,
                 eps_fraction=500_000, alpha=0.6, beta=0.4, beta_increment=0.0001, num_envs=4, eps=1e-6,
                 scratch_dir="/data1/b10902078/replay_buffer"):

        self.state_size = state_size
        self.action_size = action_size
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # hyperparameters
        self.gamma = gamma  # discount factor
        self.epsilon = eps_start  # exploration rate
        self.epsilon_start = eps_start
        self.epsilon_min = eps_min
        self.epsilon_fraction = eps_fraction
        self.learning_rate = lr
        self.update_target_freq = update_target_freq
        self.batch_size = batch_size
        self.capacity = capacity
        self.tau = tau
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.eps = eps
        self.burn_in = 5000 # min. experiences before training (currently not enforced in train())

        # Neural Networks - using Dueling architecture
        self.policy_net = DuelingDQN(state_size, action_size).to(self.device)
        self.target_net = DuelingDQN(state_size, action_size).to(self.device)
        self.target_net.load_state_dict(self.policy_net.state_dict())
        self.target_net.eval()

        # ICM
        self.icm = ICM(state_size, embed_dim=512, n_actions=action_size).to(self.device)
        self.icm_optimizer = torch.optim.Adam(self.icm.parameters(), lr=self.learning_rate)
        # scale of intrinsic reward
        self.icm_beta = 0.01
        # losses weighting
        self.icm_lambda_fwd = 0.8
        self.icm_lambda_inv = 0.2

        # Optimizer
        self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)

        # Replay buffer kept on CPU in memory-mapped storage.
        storage = LazyMemmapStorage(max_size=self.capacity, scratch_dir=scratch_dir, device=torch.device("cpu"))
        self.memory = TensorDictReplayBuffer(storage=storage, batch_size=batch_size)

        # Counters: learn_count drives target syncs, total_steps drives epsilon decay.
        self.learn_count = 0
        self.total_steps = 0

    def act(self, state, deterministic=False):
        """Select an action with an epsilon-greedy policy.

        Args:
            state: a single observation. A numpy array is converted to a float32
                tensor on the agent's device and given a batch axis; any other
                type is passed to the network unchanged.
            deterministic: if True, always take the greedy action (epsilon is ignored).

        Returns:
            int: index of the chosen action.
        """
        # Exploration branch first: no tensor needs to be built for a random action.
        if not deterministic and random.random() < self.epsilon:
            return random.randrange(self.action_size)

        if isinstance(state, np.ndarray):
            # create batch dimension, e.g. (1, stack_size, 84, 84)
            state = torch.tensor(state, dtype=torch.float32, device=self.device).unsqueeze(0)

        with torch.no_grad():
            q_values = self.policy_net(state)
        return q_values.argmax(dim=1).item()

    def train(self):
        """Run one optimisation step on a batch sampled from the replay buffer.

        Does nothing while the buffer holds fewer than `batch_size` transitions.
        Updates the policy network and the ICM jointly, and syncs the target
        network every `update_target_freq` calls that performed an update.
        """
        if len(self.memory) < self.batch_size:
            return

        # Sample batch from replay buffer
        td_batch, _ = self.memory.sample(self.batch_size, return_info=True)
        states = td_batch["state"].to(self.device) # (B, 4, 84, 84)
        actions = td_batch["action"].long().to(self.device) # (B, 1)
        # Force float32 so the TD target has the same dtype as the Q-values,
        # whatever numeric type the environment returned for the reward.
        rewards = td_batch["reward"].float().to(self.device) # (B, 1)
        next_states = td_batch["next_state"].to(self.device) # (B, 4, 84, 84)
        dones = td_batch["done"].to(self.device) # (B, 1)

        # -- ICM losses & intrinsic reward --
        pred_phi_next, pred_action_logits, phi_next = self.icm(states, next_states, actions)
        # forward loss per sample
        fwd_loss_sample = F.mse_loss(pred_phi_next, phi_next.detach(), reduction='none').mean(dim=1, keepdim=True) # (B, 1)
        # inverse loss per sample
        inv_loss_sample = F.cross_entropy(pred_action_logits, actions.squeeze(-1), reduction='none').unsqueeze(1) # (B, 1)
        # Intrinsic reward is a training signal for the Q-network only: detach it
        # so the Q-loss cannot backpropagate through the TD target into the ICM
        # forward model (the ICM is trained solely through icm_loss below).
        intrinsic_reward = self.icm_beta * fwd_loss_sample.detach() # (B, 1)
        # augment external reward
        rewards = rewards + intrinsic_reward # (B, 1)
        # total ICM loss
        icm_loss = self.icm_lambda_fwd * fwd_loss_sample.mean() + self.icm_lambda_inv * inv_loss_sample.mean()

        # Compute current Q values
        q_values = self.policy_net(states).gather(1, actions)  # (B, 1)

        # Double DQN: use online network to select action and target network to evaluate it
        with torch.no_grad():
            best_actions = self.policy_net(next_states).argmax(1, keepdim=True)  # (B, 1)
            next_q_values = self.target_net(next_states).gather(1, best_actions)  # (B, 1)

        # Bootstrapped target; the bootstrap term is zeroed for terminal transitions.
        expected_q_values = rewards + self.gamma * next_q_values * (1.0 - dones.float())  # (B, 1)

        # Huber (smooth L1) loss on the TD error, plus the ICM loss.
        q_loss = F.smooth_l1_loss(q_values, expected_q_values)
        total_loss = q_loss + icm_loss

        # Gradient descent
        self.optimizer.zero_grad()
        self.icm_optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10.0)  # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.icm.parameters(), 10.0)
        self.optimizer.step()
        self.icm_optimizer.step()

        # Update target network periodically
        self.learn_count += 1
        if self.learn_count % self.update_target_freq == 0:
            self.update_target()

    def update_target(self):
        """Sync the target network with the policy network.

        With `tau < 1.0` the target parameters are moved towards the policy
        parameters by a factor `tau`; otherwise the full state dict is copied.
        """
        if self.tau < 1.0:
            # Soft update
            for target_param, policy_param in zip(self.target_net.parameters(), self.policy_net.parameters()):
                target_param.data.copy_(self.tau * policy_param.data + (1.0 - self.tau) * target_param.data)
        else:
            # Hard update
            self.target_net.load_state_dict(self.policy_net.state_dict())

    def decay_epsilon(self):
        """Set epsilon to eps_start * 10^(-total_steps / eps_fraction), floored at eps_min."""
        self.epsilon = max(self.epsilon_start * 10 ** (-self.total_steps / self.epsilon_fraction), self.epsilon_min)

    def store_transition(self, state, action, reward, next_state, done):
        """Add one transition to the replay buffer and increment `total_steps`."""
        td = TensorDict({
            "state": torch.tensor(state, device="cpu"),
            "action": torch.tensor([action], device="cpu"),
            # Stored as float32 so sampled rewards match the network dtype.
            "reward": torch.tensor([reward], dtype=torch.float32, device="cpu"),
            "next_state": torch.tensor(next_state, device="cpu"),
            "done": torch.tensor([done], device="cpu")
        })
        self.memory.add(td)
        self.total_steps += 1

    def save(self, path):
        """Save networks, optimizers, epsilon and counters to `path` (replay buffer not included)."""
        torch.save({
            'policy_net': self.policy_net.state_dict(),
            'target_net': self.target_net.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'epsilon': self.epsilon,
            'learn_count': self.learn_count,
            'total_steps': self.total_steps,
            'icm': self.icm.state_dict(),
            'icm_optimizer': self.icm_optimizer.state_dict()
        }, path)

    def load(self, path):
        """Restore everything written by `save` from `path`.

        NOTE(review): `torch.load` unpickles the file; only load checkpoints
        from a trusted source.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.policy_net.load_state_dict(checkpoint['policy_net'])
        self.target_net.load_state_dict(checkpoint['target_net'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.epsilon = checkpoint['epsilon']
        self.learn_count = checkpoint['learn_count']
        self.total_steps = checkpoint['total_steps']
        self.icm.load_state_dict(checkpoint['icm'])
        self.icm_optimizer.load_state_dict(checkpoint['icm_optimizer'])

In [6]:
import random
import gym
import gym_super_mario_bros
from gym_super_mario_bros.actions import RIGHT_ONLY, SIMPLE_MOVEMENT, COMPLEX_MOVEMENT
from nes_py.wrappers import JoypadSpace
import numpy as np
import torch
import matplotlib.pyplot as plt
import time
import os
from tqdm import tqdm
from mario_agent import MarioAgent
from wrappers import SkipFrame, GrayScaleObservation, ResizeObservation, NormalizeObservation, FrameStack
from IPython.display import clear_output

In [None]:
import time

# Build the evaluation environment.
env = gym_super_mario_bros.make('SuperMarioBros-v0')
env = JoypadSpace(env, COMPLEX_MOVEMENT)

# Apply preprocessing wrappers (must match what the checkpoint was trained with)
env = SkipFrame(env, skip=4)
env = GrayScaleObservation(env)
env = ResizeObservation(env, size=84)
env = NormalizeObservation(env)
env = FrameStack(env, n_frames=4)

# State shape should match the environment's observation space
state_shape = env.observation_space.shape
action_size = env.action_space.n

# Create agent
agent = MarioAgent(state_size=state_shape, action_size=action_size)

# Load trained model
model_path = "models/mario_dqn_ep8500.pth"
agent.load(model_path)
# Greedy evaluation. agent.act() is called with deterministic=True below, which
# bypasses epsilon entirely; it is set to 0 only so the agent state says the same.
agent.epsilon = 0.0

total_rewards = []
episodes = 100

# Evaluation loop; the environment is closed even if an episode raises.
try:
    for episode in range(1, episodes + 1):
        state = env.reset()
        episode_reward = 0
        done = False
        # The first action of every episode is uniformly random; all later
        # actions are greedy.
        first_step_random = True

        # Episode loop
        while not done:
            if first_step_random:
                action = random.randrange(action_size)
                first_step_random = False
            else:
                # Select action (no exploration)
                action = agent.act(state, deterministic=True)

            # Take action
            next_state, reward, done, info = env.step(action)

            # Update state and tracking info
            state = next_state
            episode_reward += reward

        # Track metrics
        total_rewards.append(episode_reward)
        print(f"Episode {episode}/{episodes}: Reward = {episode_reward}")
finally:
    env.close()

# Summary
avg_reward = np.mean(total_rewards)
print(f"\nAverage Reward: {avg_reward:.2f}")

  logger.warn(
  checkpoint = torch.load(path, map_location=self.device)


Episode 1/20: Reward = 7565.0
Episode 2/20: Reward = 6182.0
Episode 3/20: Reward = 6182.0
Episode 4/20: Reward = 7565.0
Episode 5/20: Reward = 6473.0
Episode 6/20: Reward = 6600.0
Episode 7/20: Reward = 7565.0
Episode 8/20: Reward = 6182.0
Episode 9/20: Reward = 6600.0
Episode 10/20: Reward = 7904.0
Episode 11/20: Reward = 7565.0
Episode 12/20: Reward = 6473.0
Episode 13/20: Reward = 7565.0
Episode 14/20: Reward = 7899.0
Episode 15/20: Reward = 6600.0
Episode 16/20: Reward = 7899.0
Episode 17/20: Reward = 6182.0
Episode 18/20: Reward = 7899.0
Episode 19/20: Reward = 7565.0
Episode 20/20: Reward = 7565.0

Average Reward: 7101.50


In [29]:
# Shape check: nn.Flatten keeps the batch axis and merges the remaining ones,
# so (32, 64, 7, 7) becomes (32, 64 * 7 * 7) = (32, 3136).
x = torch.randn(32, 64, 7, 7)   # (batch, channels, height, width)
flatten = nn.Flatten()
out = flatten(x)
print(out.shape)

torch.Size([32, 3136])


In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy experiment: which parameters receive gradients when both the forward
# model's input and its regression target are detached from the encoder?
# Dummy encoder and forward model
encoder = nn.Linear(10, 5)    # encoder: input 10 -> feature 5
forward_model = nn.Linear(5, 5)  # forward_model: input 5 -> predict 5

# Input
state = torch.randn(2, 10)   # batch of 2
next_state = torch.randn(2, 10)

# Get features
phi = encoder(state)
phi_next = encoder(next_state)

# Predict next feature (input detached, so no gradient reaches the encoder this way)
pred_phi_next = forward_model(phi.detach())

# Loss: MSE between pred_phi_next and detached phi_next
loss = F.mse_loss(pred_phi_next, phi_next.detach())
loss.backward()

print(encoder.weight.grad)   # None: every path from the loss to the encoder is detached
print(forward_model.weight.grad)  # populated: only the forward model is trained by this loss


None
tensor([[ 0.0041,  0.0017, -0.0068, -0.0446, -0.0543],
        [-0.0109, -0.0046,  0.0185,  0.1210,  0.1472],
        [-0.0955,  0.0279, -0.0008, -0.0244,  0.0438],
        [-0.0370,  0.0157, -0.0118, -0.0863, -0.0713],
        [-0.0560,  0.0147,  0.0036,  0.0126,  0.0565]])


In [65]:
import torch.nn.functional as F
# Shape check for a per-sample forward loss: element-wise squared error
# (reduction='none') averaged over the feature axis, keeping a trailing axis.
pred_phi_next = torch.randn(32, 256)
phi_next = torch.randn(32, 256)
fwd_loss_sample = F.mse_loss(pred_phi_next, phi_next.detach(), reduction='none').mean(dim=1, keepdim=True)
fwd_loss_sample.shape  # one loss value per row -> (32, 1)

torch.Size([32, 1])

In [60]:
# Shape check for a per-sample inverse loss: cross_entropy with reduction='none'
# takes (B, C) logits and (B,) class indices; unsqueeze(1) then gives (B, 1).
pred_action_logits = torch.randn(32, 12)
actions = torch.randint(0, 12, (32,))
print(actions.shape)
F.cross_entropy(pred_action_logits, actions, reduction='none').unsqueeze(1).shape

torch.Size([32])


torch.Size([32, 1])

In [70]:
# .mean() with no dim reduces over all elements, so a (5, 1) column yields a scalar.
test = torch.arange(5, dtype=torch.float).view(5, 1)
print(test.shape)
test.mean()

torch.Size([5, 1])


tensor(2.)

In [56]:
x = torch.tensor([0, 1, 2])     # shape (3,): flat index vector
F.one_hot(x, num_classes=4).shape  # -> (3, 4), one row per index

torch.Size([3, 4])

In [None]:
x = torch.tensor([[0], [1], [2]])  # shape (3, 1): no error is raised, but the extra axis is kept
print(x.shape)
F.one_hot(x, num_classes=4).shape  # gives (3, 1, 4); apply .squeeze(1) to x first to get (3, 4)

torch.Size([3, 1])


torch.Size([3, 1, 4])

In [62]:
# Display the two scientific-notation literals as plain decimals.
1e-3, 1e-2

(0.001, 0.01)