In [None]:
# Dependencies for the training cells: NumPy/Matplotlib for arrays and plots,
# Gymnasium + ale_py for the Atari environment, PyTorch for the Q-network,
# PIL for frame resizing and tqdm for the progress bar.
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
import torch
import ale_py
from ale_py import ALEInterface
from torch import nn
# NOTE(review): this ALEInterface instance is not referenced by any later cell
# visible in this notebook -- presumably a leftover or an install check;
# confirm before removing.
ale = ALEInterface()
from PIL import Image
from tqdm import tqdm

# Register the environments provided by ale_py with Gymnasium; the
# "ALE/BeamRider-v5" id passed to gym.make() in later cells presumably relies
# on this call having run.
gym.register_envs(ale_py)

In [None]:
# BeamRider environment; render_mode="rgb_array" so that env.render() hands
# back a frame that render_rgbarray() can pass to preprocess().
env = gym.make("ALE/BeamRider-v5", render_mode="rgb_array")

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
def preprocess(rgb_image):
    """Turn a raw RGB frame into the 84x84 grayscale image fed to the network.

    Pipeline: luminance-weighted grayscale -> resize to 84 wide x 110 high ->
    keep rows 15..98 across the full width.

    Args:
        rgb_image: array whose last axis holds at least three colour channels
            (only the first three are used). Values are clipped to 0..255
            before the uint8 cast.

    Returns:
        np.ndarray of shape (84, 84) and dtype uint8.
    """
    grayscale_image = np.dot(rgb_image[..., :3], [0.2989, 0.5870, 0.1140])
    grayscale_image = np.clip(grayscale_image, 0, 255).astype(np.uint8)
    img = Image.fromarray(grayscale_image)
    # PIL takes sizes as (width, height) and crop boxes as
    # (left, upper, right, lower). Passing (110, 84) here would squash the
    # frame to 110 wide x 84 high, and a box reaching row 99 would then run
    # past the bottom edge and be padded with zeros.
    img = img.resize((84, 110))
    # Full width, rows 15..98 of the 110-row image -> 84 x 84.
    img = img.crop((0, 15, 84, 99))
    return np.array(img)

def render_rgbarray(env):
    """Display the environment's current rendered frame after preprocessing.

    Args:
        env: environment whose ``render()`` result is accepted by
            ``preprocess`` (the notebook creates it with
            ``render_mode="rgb_array"``).
    """
    processed_frame = preprocess(env.render())
    plt.imshow(processed_frame)
    plt.axis('off')  # hide ticks and frame; only the image matters here
    plt.show()

# def preprocess(img):
#     img = Image.fromarray(img)
#     # Convert to grayscale, resize, and crop
#     img = img.convert('L').resize((84, 110)).crop((0, 26, 84, 110))
#     return np.array(img)

In [None]:
# Smoke test: play one episode (capped at 10000 steps) with uniformly random
# actions and show the preprocessed frame every 500 steps.
env.reset()
for i in range(10000):
    action = env.action_space.sample()
    observation, reward, terminated, truncated, info = env.step(action=action)
    done = terminated or truncated
    if i % 500 == 0:
        render_rgbarray(env)
    if done:
        break
# `env` is intentionally left open: the training cell further down resets and
# steps this same object, and closes it once training has finished.


In [None]:
# deque (used by ReplayMemory) is a standard-library container; import it from
# `collections` so this cell does not depend on pyparsing being installed or
# on pyparsing exposing that name.
from collections import deque
import random

class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    Args:
        input_shape: spatial size of one input frame; the last two entries are
            read as (height, width).
        num_actions: number of outputs (one Q-value per action).
        hidden_depth: number of filters in the first conv layer; the second
            and third conv layers use ``hidden_depth * 2``.
        image_stack: number of stacked frames, i.e. input channels.

    Raises:
        ValueError: if ``input_shape`` is too small for the conv stack.

    ``forward`` maps a float tensor of shape (N, image_stack, height, width)
    to a tensor of shape (N, num_actions).
    """

    # (kernel_size, stride) of the three conv layers, in order.
    _CONV_SPECS = ((8, 4), (4, 2), (3, 1))

    def __init__(self, input_shape, num_actions, hidden_depth, image_stack):
        super().__init__()

        self.input_shape = input_shape  # input image shape
        self.num_actions = num_actions  # no. of actions
        self.hidden_depth = hidden_depth  # no. of filters in first conv layer
        self.image_stack = image_stack  # no. of previous frames to stack

        # Derive the flattened feature count from the input size instead of
        # hard-coding it, so other frame sizes / filter counts also work.
        # For (84, 84) and hidden_depth=32 this evaluates to 64 * 7 * 7 = 3136.
        height, width = input_shape[-2], input_shape[-1]
        for kernel_size, stride in self._CONV_SPECS:
            # Output length of an unpadded convolution along one axis.
            height = (height - kernel_size) // stride + 1
            width = (width - kernel_size) // stride + 1
        if height < 1 or width < 1:
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small for the conv stack"
            )
        flat_features = hidden_depth * 2 * height * width

        # Layer order (and therefore state_dict keys dqn.0 ... dqn.9) is kept
        # fixed so existing checkpoints continue to load.
        self.dqn = nn.Sequential(
            nn.Conv2d(self.image_stack, hidden_depth, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(hidden_depth, hidden_depth * 2, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(hidden_depth * 2, hidden_depth * 2, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Linear(256, num_actions)
        )

    def forward(self, x):
        """Return Q-values of shape (N, num_actions) for a batch of stacked frames."""
        return self.dqn(x)

class ReplayMemory:
    """Bounded FIFO store of transitions with uniform random sampling.

    Once ``capacity`` transitions are held, each new one pushes out the
    oldest (behaviour of a ``deque`` with ``maxlen``).
    """

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` transitions."""
        self.memory = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition as a 5-tuple in argument order."""
        transition = (state, action, reward, next_state, done)
        self.memory.append(transition)

    def sample(self, batch_size):
        """Return a list of ``batch_size`` distinct stored transitions chosen at random.

        Raises ``ValueError`` (from ``random.sample``) when ``batch_size``
        exceeds the number of stored transitions.
        """
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
# --- Hyperparameters -------------------------------------------------------
BATCH_SIZE = 32  # transitions sampled per optimize() call
GAMMA = 0.99  # discount factor in the TD target
EPS_START = 1.0  # exploration rate at steps_done == 0
EPS_END = 0.1  # exploration rate approached as steps_done grows
EPS_DECAY = 200000  # decay scale (in action selections) of the exponential schedule
TARGET_UPDATE_FREQUENCY = 1000  # how often the training loop syncs target_net
LEARNING_RATE = 0.00025
MEMORY_SIZE = 100000  # replay buffer capacity in transitions
NUM_EPISODES = 500

# Number of actions selected so far. epsilon_greedy_action() declares this as
# a global, reads it to compute epsilon and increments it; it must exist before
# the first call, otherwise that function raises NameError on a fresh kernel.
steps_done = 0

# Two identically shaped networks: policy_net is trained, target_net supplies
# the bootstrap values in optimize() and starts as an exact copy.
policy_net = DQN((84, 84), env.action_space.n, 32, 4).to(device)
target_net = DQN((84, 84), env.action_space.n, 32, 4).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()

optimizer = torch.optim.Adam(policy_net.parameters(), lr=LEARNING_RATE)
memory = ReplayMemory(MEMORY_SIZE)

In [None]:
def epsilon_greedy_action(state):
    """Pick an action for ``state`` with an exponentially decaying epsilon.

    Epsilon is ``EPS_END + (EPS_START - EPS_END) * exp(-steps_done / EPS_DECAY)``.
    The module-level counter ``steps_done`` is incremented on every call,
    whichever branch is taken.

    Args:
        state: array-like accepted by ``torch.tensor``; a batch axis is added
            before it is passed to ``policy_net``.

    Returns:
        A random action from ``env.action_space`` with probability epsilon,
        otherwise the index (int) of the largest ``policy_net`` output.
    """
    global steps_done
    draw = random.random()
    decay_factor = np.exp(-1. * steps_done / EPS_DECAY)
    eps_threshold = EPS_END + (EPS_START - EPS_END) * decay_factor
    steps_done += 1

    # Explore: the draw fell inside the epsilon band.
    if draw <= eps_threshold:
        return env.action_space.sample()

    # Exploit: greedy action under the current policy network.
    with torch.no_grad():
        batch = torch.tensor(state, dtype=torch.float32).unsqueeze(0).to(device)
        return policy_net(batch).max(1)[1].item()
    

def optimize():
    """Run one Q-learning gradient step on a random minibatch from ``memory``.

    Does nothing (returns None) while ``memory`` holds fewer than
    ``BATCH_SIZE`` transitions. Otherwise:

    1. samples ``BATCH_SIZE`` transitions and stacks them into tensors,
    2. forms the target ``r + GAMMA * max_a' target_net(s')[a'] * (1 - done)``,
    3. minimises the MSE between ``policy_net(s)[a]`` and that target,
    4. clamps every gradient element to [-1, 1] before the optimizer step.

    Returns:
        None.
    """
    if len(memory) < BATCH_SIZE:
        return

    transitions = memory.sample(BATCH_SIZE)
    # Transpose a list of 5-tuples into five per-field tuples.
    states, actions, rewards, next_states, dones = zip(*transitions)

    state_batch = torch.tensor(np.array(states), dtype=torch.float32).to(device)
    action_batch = torch.tensor(actions, dtype=torch.int64).unsqueeze(1).to(device)
    reward_batch = torch.tensor(rewards, dtype=torch.float32).unsqueeze(1).to(device)
    next_state_batch = torch.tensor(np.array(next_states), dtype=torch.float32).to(device)
    done_batch = torch.tensor(dones, dtype=torch.float32).unsqueeze(1).to(device)

    # Q(s, a) for the action that was actually taken; shape (BATCH_SIZE, 1).
    q_values = policy_net(state_batch).gather(1, action_batch)
    # Bootstrap value from the target network; no gradient flows through it.
    with torch.no_grad():
        next_q_values = target_net(next_state_batch).max(1)[0].unsqueeze(1)

    # (1 - done) removes the bootstrap term for transitions flagged as done.
    expected_q_values = reward_batch + GAMMA * next_q_values * (1 - done_batch)

    loss = nn.functional.mse_loss(q_values, expected_q_values)

    optimizer.zero_grad()
    loss.backward()

    # Element-wise gradient clamp to [-1, 1]; the utility replaces a manual
    # loop over `param.grad.data` and skips parameters without a gradient.
    nn.utils.clip_grad_value_(policy_net.parameters(), 1.0)
    optimizer.step()
            

In [None]:
# Environment steps taken across all episodes; drives the target-network sync.
total_steps = 0

for episode in tqdm(range(NUM_EPISODES)):
    obs, info = env.reset()
    frame = preprocess(obs)
    # Start each episode with the first frame repeated four times so the state
    # already has the four channels the networks were built with.
    state = np.stack([frame] * 4, axis=0)

    done = False
    while not done:
        action = epsilon_greedy_action(state)
        observation, reward, terminated, truncated, info = env.step(action)
        done = terminated or truncated

        # Slide the frame window: np.roll returns a shifted copy, then the
        # last slot is overwritten with the newest frame.
        next_frame = preprocess(observation)
        next_state = np.roll(state, -1, axis=0)
        next_state[-1] = next_frame

        # Store `terminated`, not `done`: a truncated episode (time limit) did
        # not reach a terminal state, so optimize() should still bootstrap
        # from next_state for that transition.
        memory.add(state, action, reward, next_state, terminated)

        state = next_state

        optimize()

        # Sync the target network every TARGET_UPDATE_FREQUENCY environment
        # steps. Keying this on the episode index would, with 500 episodes and
        # a frequency of 1000, only ever fire for episode 0.
        total_steps += 1
        if total_steps % TARGET_UPDATE_FREQUENCY == 0:
            target_net.load_state_dict(policy_net.state_dict())

print("Training complete!")
env.close()

In [None]:
import os

# Persist the trained policy network's parameters as "best_model.pth" in the
# current working directory (os.curdir).
save_dir = os.curdir
torch.save(
    obj=policy_net.state_dict(),
    f=os.path.join(save_dir, "best_model.pth"),
)


In [None]:
import time
import torch
import gymnasium as gym
import numpy as np
import cv2
from collections import deque
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML  # For Jupyter-compatible animation display

# DQN must already be defined before this cell runs: either execute the
# notebook cell that defines the DQN class, or import it from your own module,
# e.g.  from your_dqn_module import DQN

# Evaluate on the GPU when torch reports CUDA as available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Preprocessing function for Atari observations (for DQN input)
def preprocess(observation):
    """Grayscale an observation, resize it to 84x84 and scale it by 1/255.

    Args:
        observation: NumPy array. A 3-D array is treated as colour and reduced
            with luminance weights over its first three channels; any other
            rank is used as-is.

    Returns:
        Array of shape (84, 84) holding the resized image divided by 255.0.

    NOTE(review): this definition replaces the training-time ``preprocess``
    defined earlier in the notebook. That one crops the frame and returns
    uint8 values in 0..255, which the training code feeds to the network
    unscaled; this one does not crop and rescales to [0, 1]. If the checkpoint
    evaluated below was trained with the earlier function, the network sees
    differently prepared inputs here -- confirm which pipeline it expects.
    """
    is_colour = len(observation.shape) == 3
    if not is_colour:
        gray = observation.astype(np.float32)
    else:
        # RGB to grayscale using standard luminance weights
        luminance_weights = [0.299, 0.587, 0.114]
        gray = np.dot(observation[..., :3], luminance_weights).astype(np.float32)

    shrunk = cv2.resize(gray, (84, 84), interpolation=cv2.INTER_AREA)
    return shrunk / 255.0

import os  # imported here so this cell does not rely on the save cell having run

# Create the test environment; rgb_array rendering makes render() return frames
# that the evaluation loop can collect for the replay animation.
test_env = gym.make("ALE/BeamRider-v5", render_mode="rgb_array")

# Initialize the Q-network with the same constructor arguments used for
# policy_net, so the saved state_dict fits.
test_q_network = DQN((84, 84), test_env.action_space.n, 32, 4).to(device)

# Load the trained model. Prefer the checkpoint the save cell writes
# (best_model.pth in the current working directory) so the notebook runs on any
# machine; fall back to the original absolute location when that file is absent.
model_path = os.path.join(os.curdir, "best_model.pth")
if not os.path.exists(model_path):
    model_path = '/mnt/ML Summer Learning/Paper_Implementation-master/Paper_Implementation-master/RL code/Deep-Q-Learning/best_model.pth'
# NOTE: torch.load deserialises the file; only load checkpoints you trust.
test_q_network.load_state_dict(torch.load(model_path, map_location=device))
test_q_network.eval()  # Set the network to evaluation mode

def greedy_action_test(stacked_state):
    """Return the index of the highest output of ``test_q_network`` (no exploration).

    Args:
        stacked_state: stacked frames for a single state; the evaluation loop
            passes an array of shape (4, 84, 84). A leading batch axis of
            size 1 is added before the forward pass.

    Returns:
        int: position of the maximum over the network's output for this state.
    """
    with torch.no_grad():
        batch = torch.FloatTensor(stacked_state).unsqueeze(0).to(device)  # (1, 4, 84, 84)
        best_index = test_q_network(batch).argmax()
    return best_index.item()

def create_animation(frames):
    """Turn a list of image frames into a looping animation rendered as JS/HTML.

    Args:
        frames: sequence of images accepted by ``Axes.imshow``; the evaluation
            loop passes the arrays returned by ``test_env.render()``.

    Returns:
        The string produced by ``FuncAnimation.to_jshtml()``, or ``None`` when
        ``frames`` is empty or otherwise falsy.
    """
    if not frames:
        return None

    frame_delay_ms = 50  # ms per frame (20 FPS)

    fig, ax = plt.subplots(figsize=(5, 4))
    ax.set_title("BeamRider Episode Replay")
    ax.axis('off')
    image_artist = ax.imshow(frames[0])

    def _show_frame(index):
        # With blit=True the callback must return the artists it changed.
        image_artist.set_array(frames[index])
        return [image_artist]

    animation = FuncAnimation(
        fig,
        _show_frame,
        frames=len(frames),
        interval=frame_delay_ms,
        blit=True,
        repeat=True,
    )
    # Close the figure so the notebook does not also show it as a static image.
    plt.close(fig)
    return animation.to_jshtml()

num_test_episodes = 10
for episode in range(num_test_episodes):
    # Reset the environment and seed the frame stack with four copies of the
    # first preprocessed frame.
    state, info = test_env.reset()
    processed = preprocess(state)
    frame_stack = deque([processed] * 4, maxlen=4)
    stacked_state = np.stack(frame_stack, axis=0)  # (4, 84, 84)
    total_reward = 0
    done = False
    frames = []  # rendered frames collected for the replay animation

    while not done:
        # Capture the current frame; copy it so later renders cannot alias it.
        frame = test_env.render()
        frames.append(frame.copy())

        # Agent chooses the best action (no epsilon)
        action_index = greedy_action_test(stacked_state)

        # Environment takes a step
        next_state, reward, terminated, truncated, _ = test_env.step(action_index)
        done = terminated or truncated

        # Preprocess the new observation; the maxlen=4 deque drops the oldest frame.
        next_processed = preprocess(next_state)
        frame_stack.append(next_processed)
        stacked_state = np.stack(frame_stack, axis=0)

        total_reward += reward

        # No per-step sleep: nothing is shown live here, and playback speed is
        # set by the animation's own frame interval, so a delay would only
        # lengthen the evaluation.

    print(f"Test Episode {episode + 1}, Total Reward: {total_reward:.2f}")

    # Create and display the animation as HTML in Jupyter
    html_animation = create_animation(frames)
    display(HTML(html_animation))

test_env.close()
plt.close('all')  # Clean up figures