In [None]:
from __future__ import annotations

import cProfile
import os
import pstats
import random
from collections import namedtuple, deque
from datetime import datetime as d

import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.autograd
import torch.nn.functional as F
import torch.nn.init as init
from torch.utils.tensorboard import SummaryWriter

import gymnasium as gym
from gymnasium.wrappers import AtariPreprocessing, RecordVideo, FrameStack
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Minute-resolution timestamp used to tag this run's logs, videos and checkpoints.
date = d.now().strftime("%Y-%m-%d-%H-%M")

In [None]:
env_id = "PongNoFrameskip-v4"  # Select environment here
seed = 42
writer = SummaryWriter(f"Tensorboard_experiments/Atari/dqn-{env_id}-experiments-{date}")
run_name = f"dqn_{env_id}_videos_{date}"

# Seed every RNG this notebook draws from, not just the action space:
# `random` drives replay sampling, NumPy drives the epsilon-greedy coin flip,
# and torch drives weight initialisation.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

env = gym.make(env_id, render_mode="rgb_array", repeat_action_probability=0)
# Record a video every 20th episode of the underlying (unwrapped-by-preprocessing) game.
env = gym.wrappers.RecordVideo(
    env, episode_trigger=lambda x: x % 20 == 0, video_folder=f"RL_Videos/Atari/{run_name}", disable_logger=True
)
env = gym.wrappers.RecordEpisodeStatistics(env)
# Standard Atari preprocessing stack (Stable-Baselines3 wrappers).
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env)
# Only games whose action set contains FIRE need the fire-on-reset wrapper.
if "FIRE" in env.unwrapped.get_action_meanings():
    env = FireResetEnv(env)

env = ClipRewardEnv(env)
# Observations become stacks of four 84x84 grayscale frames.
env = gym.wrappers.ResizeObservation(env, (84, 84))
env = gym.wrappers.GrayScaleObservation(env)
env = gym.wrappers.FrameStack(env, 4)
env.action_space.seed(seed)

In [None]:
# # Create environment
# env_spec = "PongNoFrameskip-v4"
# env = gym.make(env_spec, obs_type="grayscale", render_mode="rgb_array")
# # Apply some wrappers
# env = AtariPreprocessing(env, screen_size=84, grayscale_obs=True, grayscale_newaxis=False)
# env = FrameStack(env, num_stack=4)
# env = RecordVideo(env, episode_trigger=lambda x: x % 10 == 0, video_folder="ddqn_pong_videos_4-13-24-6-00")

In [None]:
# Replay buffer: one record type plus a bounded container of such records
Transition = namedtuple("Transition", ["state", "action", "next_state", "reward", "done"])


class ReplayMemory:
    """Fixed-capacity store of ``Transition`` records with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer that keeps at most ``capacity`` transitions."""
        # A deque with ``maxlen`` drops its oldest entry when a new one is
        # appended to a full buffer.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Save a transition built from ``args`` (in ``Transition`` field order)."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored transitions chosen at random."""
        return random.sample(self.memory, batch_size)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.memory)

In [None]:
class DQNAgent:
    def __init__(
        self,
        model,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        action_space=None,
    ):
        """Initialize a DQN RL agent, get action epsilon-greedily, and manage epsilon

        Args:
            model: Callable that maps an observation to one Q-value per action
            initial_epsilon: The initial epsilon value
            epsilon_decay: Amount subtracted from epsilon on each `decay_epsilon` call
            final_epsilon: The final (minimum) epsilon value
            action_space: Optional object with a `sample()` method used for
                exploratory actions. When omitted, the notebook-level
                `env.action_space` is looked up at call time.
        """
        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon
        self.model = model
        self.action_space = action_space

    def get_action(self, obs) -> int:
        """
        Returns the best action with probability (1 - epsilon)
        otherwise a random action with probability epsilon to ensure exploration.
        """
        if np.random.random() < self.epsilon:
            # Fall back to the global environment so existing callers keep working.
            space = self.action_space if self.action_space is not None else env.action_space
            return space.sample()
        # Action selection is pure inference; skip autograd bookkeeping.
        with torch.no_grad():
            return int(torch.argmax(self.model(obs)))

    def decay_epsilon(self):
        """Linearly reduce epsilon by `epsilon_decay`, never going below `final_epsilon`."""
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)

In [None]:
# Observation shape and action count, read from the environment's spaces
state_shape = env.observation_space.shape
possible_actions = env.action_space.n

In [None]:
# Show the observation shape and the action count, one per line
print(state_shape, possible_actions, sep="\n")

In [None]:
# Sanity check: drive the wrapped environment with random actions and report
# which `info["lives"]` values show up, instead of printing one line per step.
observation, info = env.reset()

lives_seen = set()
for _ in range(1000):
    action = env.action_space.sample()  # agent policy that uses the observation and info
    observation, reward, terminated, truncated, info = env.step(action)
    lives_seen.add(info["lives"])

    if terminated or truncated:
        observation, info = env.reset()

print(f"Distinct info['lives'] values over 1000 random steps: {sorted(lives_seen)}")

In [None]:
# Report how much GPU memory PyTorch has allocated (skipped on CPU-only machines)
if torch.cuda.is_available():
    # Get the current memory allocation (in bytes) on the default GPU
    allocated_memory = torch.cuda.memory_allocated()
    print(f"Memory Allocated: {allocated_memory} bytes")

    # Convert bytes to gigabytes (the variable name now matches the unit)
    allocated_memory_gb = allocated_memory / (1024**3)
    print(f"Memory Allocated: {allocated_memory_gb:.2f} GB")

In [None]:
# Create the model
class DQN(nn.Module):
    """Convolutional Q-network: three conv layers followed by two linear layers.

    The first convolution takes 4 input channels. The first linear layer
    expects 3136 features, i.e. 64 * 7 * 7, which is what this conv stack
    produces for 84x84 inputs (assumed input size; other sizes will not fit).
    """

    def __init__(self, possible_actions):
        """Build the layers; `possible_actions` is the number of Q-value outputs."""
        super().__init__()
        # self.bnorm = torch.nn.BatchNorm2d(num_features=4)
        self.conv1 = nn.Conv2d(4, 32, 8, stride=4)
        init.kaiming_normal_(self.conv1.weight, mode="fan_out", nonlinearity="relu")
        self.conv2 = nn.Conv2d(32, 64, 4, stride=2)
        init.kaiming_normal_(self.conv2.weight, mode="fan_out", nonlinearity="relu")
        self.conv3 = nn.Conv2d(64, 64, 3, stride=1)
        init.kaiming_normal_(self.conv3.weight, mode="fan_out", nonlinearity="relu")
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(3136, 512)
        init.kaiming_normal_(self.fc1.weight, mode="fan_out", nonlinearity="relu")
        self.fc2 = nn.Linear(512, possible_actions)
        # `normal_`'s second positional argument is the MEAN, so the old call
        # `normal_(w, 0.01)` drew from N(0.01, 1.0). Use a small std explicitly.
        init.normal_(self.fc2.weight, mean=0.0, std=0.01)

    def forward(self, x):
        """Return one Q-value per action for each item in the batch `x`."""
        # x = self.bnorm(x)
        # Scale inputs by 1/255 (presumably raw 0-255 pixel intensities).
        x = x / 255.0
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.flatten(x)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [None]:
# Build the online (policy) network and a second network of the same architecture
policy_model = DQN(possible_actions=possible_actions).to(device)
target_model = DQN(possible_actions=possible_actions).to(device)
# Start the target network as an exact copy of the policy network
target_model.load_state_dict(policy_model.state_dict())
# The target network is never optimised directly, so none of its weights need gradients
for frozen_param in target_model.parameters():
    frozen_param.requires_grad = False

In [None]:
# (Re)build the replay buffer; re-running this cell discards any stored transitions
buffer_size = 40_000
replay_buffer = ReplayMemory(capacity=buffer_size)

In [None]:
# Number of transitions currently held in the replay buffer
print(len(replay_buffer))

In [None]:
learning_rate = 0.0001
tau = 0.001  # Polyak (soft-update) coefficient; not read by the training loop, which hard-copies weights

learn_start = 40_000  # environment steps collected before any gradient update
num_episodes = 100_000
batch_size = 32
final_exploration_frame = 750_000  # number of epsilon-decay steps over which epsilon is annealed
target_update_freq = 1000  # environment steps between target-network syncs

initial_epsilon = 1.0
final_epsilon = 0.1
# Linear annealing: subtracting this once per decay step takes epsilon from
# initial_epsilon to exactly final_epsilon after final_exploration_frame steps.
# (Dividing initial_epsilon alone would hit the floor 10% too early.)
epsilon_decay = (initial_epsilon - final_epsilon) / final_exploration_frame

discount_factor = 0.99

In [None]:
# Adam updates only the policy network; the loss is Huber with delta = 1
optimizer = torch.optim.Adam(policy_model.parameters(), lr=learning_rate)
criterion = nn.HuberLoss(delta=1)

In [None]:
# Optional: swap in an on-screen ("human") environment to watch the agent play.
# Rendering significantly slows down training, so this is off by default; with
# the flag off, "Run All" keeps the preprocessed training environment intact.
render_human = False
if render_human:
    env.close()
    env = gym.make(env_id, render_mode="human", repeat_action_probability=0)
    # Re-apply the same preprocessing as the training environment; the network
    # cannot consume the raw frames that a bare `gym.make` returns.
    env = NoopResetEnv(env, noop_max=30)
    env = MaxAndSkipEnv(env, skip=4)
    env = EpisodicLifeEnv(env)
    if "FIRE" in env.unwrapped.get_action_meanings():
        env = FireResetEnv(env)
    env = ClipRewardEnv(env)
    env = gym.wrappers.ResizeObservation(env, (84, 84))
    env = gym.wrappers.GrayScaleObservation(env)
    env = gym.wrappers.FrameStack(env, 4)
    env.action_space.seed(seed)

In [None]:
# Number of most recent episodes averaged for the smoothed return curve
moving_average_window = 10
# One slot per training episode, filled in as episodes finish
returns = np.zeros((num_episodes,))

In [None]:
# Training agent: acts epsilon-greedily on the policy network's Q-values
agent = DQNAgent(policy_model, initial_epsilon, epsilon_decay, final_epsilon)

In [None]:
# Training loop
# `num_frames` counts environment steps taken through the wrapped env.
num_frames = 0

policy_model.train()
for episode in range(num_episodes):
    done = False
    current_state, _ = env.reset()
    # Add a leading batch dimension so the state can be fed straight to the network.
    current_state = torch.tensor(np.array(current_state, dtype=np.single), device=device).unsqueeze(0)
    episode_return = 0
    while not done:
        action = agent.get_action(current_state)

        next_state, reward, terminated, truncated, _ = env.step(action)
        # The episode loop stops on either signal ...
        done = terminated or truncated

        next_state = torch.tensor(np.array(next_state, dtype=np.single), device=device).unsqueeze(0)

        # ... but only a true termination is stored as `done`: after a mere
        # truncation the next state still has value, so the target must bootstrap.
        replay_buffer.push(current_state, action, next_state, reward, terminated)
        current_state = next_state
        num_frames += 1
        # Count the reward before the warm-up `continue` so that episodes played
        # during (or straddling) warm-up report their full return.
        episode_return += reward

        # Warm-up: only collect experience until `learn_start` steps have been taken.
        if num_frames < learn_start:
            continue

        transitions = replay_buffer.sample(batch_size=batch_size)
        # Turn a list of Transition records into one Transition of per-field tuples.
        batch = Transition(*zip(*transitions))

        reward_array = np.array(batch.reward, dtype=np.float32)
        batch_states = torch.cat(batch.state)
        batch_rewards = torch.tensor(reward_array, device=device).unsqueeze(1)
        batch_actions = torch.tensor(batch.action, device=device).unsqueeze(1)
        batch_next_states = torch.cat(batch.next_state)
        batch_done = torch.tensor(batch.done, device=device).unsqueeze(1)

        with torch.no_grad():
            # Double-DQN target: the policy network picks the next action,
            # the target network supplies that action's value.
            next_actions = policy_model(batch_next_states).argmax(dim=1).unsqueeze(1)
            target_values = target_model(batch_next_states)
            best_next_q_values = torch.gather(target_values, 1, next_actions)
            # `~batch_done` zeroes the bootstrap term for terminal transitions.
            target_q_values = batch_rewards + discount_factor * best_next_q_values * (~batch_done)

        predicted_q_values = policy_model(batch_states)
        # Keep only the Q-value of the action that was actually taken.
        predicted_q_values = torch.gather(predicted_q_values, 1, batch_actions)
        loss = criterion(predicted_q_values, target_q_values)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Hard target update: copy the policy weights every `target_update_freq` steps.
        if num_frames % target_update_freq == 0:
            target_model.load_state_dict(policy_model.state_dict())

        writer.add_scalar("Loss/Train", loss.item(), num_frames)
        writer.add_scalar("Epsilon/Train", agent.epsilon, num_frames)
        agent.decay_epsilon()

    returns[episode] = episode_return
    # Clamp the window start at 0: a negative start index would wrap around and
    # produce an empty slice (NaN mean) during the first few episodes.
    window_start = max(0, episode - moving_average_window + 1)
    average_return = np.mean(returns[window_start : episode + 1])
    if num_frames >= learn_start:
        writer.add_scalar("Return/Train", episode_return, episode)
        writer.add_scalar(
            f"Avg Return, Window {moving_average_window}/Train",
            average_return,
            episode,
        )
    if (episode + 1) % 100 == 0:
        print(f"Ep: {episode+1} Average return: {average_return} Eps: {agent.epsilon:.4f}")
    writer.flush()

In [None]:
# Evaluation agent: with every epsilon setting at zero it always acts greedily
agent = DQNAgent(policy_model, initial_epsilon=0, epsilon_decay=0, final_epsilon=0)

In [None]:
# Model evaluation
# A dedicated name keeps the training `num_episodes` intact, so earlier cells
# that read it still behave the same if they are re-run after evaluation.
num_test_episodes = 1000
test_window = 10  # episodes averaged for the "Avg Return/Test" curve
returns_test = np.zeros(num_test_episodes)

policy_model.eval()
# Evaluation is inference only; no_grad avoids building autograd graphs.
with torch.no_grad():
    for episode in range(num_test_episodes):
        done = False
        current_state, _ = env.reset()
        current_state = torch.tensor(np.array(current_state, dtype=np.single), device=device).unsqueeze(0)
        episode_return = 0
        while not done:
            action = agent.get_action(current_state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated

            next_state = torch.tensor(np.array(next_state, dtype=np.single), device=device).unsqueeze(0)

            current_state = next_state
            num_frames += 1

            episode_return += reward
            agent.decay_epsilon()

        writer.add_scalar("Return/Test", episode_return, episode)
        returns_test[episode] = episode_return
        recent_returns = returns_test[max(0, episode - test_window + 1) : episode + 1]
        # The averaged curve is only logged once a full window of episodes exists.
        if episode >= test_window - 1:
            writer.add_scalar("Avg Return/Test", np.mean(recent_returns), episode)
        if (episode + 1) % 100 == 0:
            print(f"Ep: {episode+1} Average return: {np.mean(recent_returns)} Eps: {agent.epsilon:.4f}")
        writer.flush()

In [None]:
# Total environment steps counted so far (both the training and evaluation loops increment it)
print(num_frames)

In [None]:
# Profiling scaffold: nothing is profiled until a call is placed between
# enable() and disable() (the intended call is currently commented out).
profiler = cProfile.Profile()
profiler.enable()
# optimize_model()
profiler.disable()

# Order the collected entries by internal time, then print the report
stats = pstats.Stats(profiler).sort_stats("time")
stats.print_stats()

In [None]:
# Save model parameters and other info
checkpoint_path = f"Model_Checkpoints/Atari/dqn_{env_id}_{date}.ckpt"
# torch.save does not create missing parent directories, so make sure the
# target folder exists before writing.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
additional_info = {
    "model_state_dict": policy_model.state_dict(),
    "target_state_dict": target_model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
}
torch.save(additional_info, checkpoint_path)

In [None]:
# Load model and optimizer parameters
# Set this to the path of a saved checkpoint to restore it. While it is empty
# the cell does nothing, so "Run All" does not fail on `torch.load("")`.
checkpoint_path = ""
if checkpoint_path:
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    policy_model.load_state_dict(checkpoint["model_state_dict"])
    target_model.load_state_dict(checkpoint["target_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
else:
    print("No checkpoint path set; skipping checkpoint load.")

In [None]:
# Draw one batch from the replay buffer for manual inspection, collated the
# same way as in the training loop.
test_transitions = replay_buffer.sample(batch_size=batch_size)
test_batch = Transition(*zip(*test_transitions))

# Go through a float32 array, as the training loop does, so the rewards tensor
# has the same dtype here as it has during training.
test_reward_array = np.array(test_batch.reward, dtype=np.float32)
test_batch_states = torch.cat(test_batch.state)
test_batch_rewards = torch.tensor(test_reward_array, device=device).unsqueeze(1)
test_batch_actions = torch.tensor(test_batch.action, device=device).unsqueeze(1)
test_batch_next_states = torch.cat(test_batch.next_state)
test_batch_done = torch.tensor(test_batch.done, device=device).unsqueeze(1)

In [None]:
# Dump the variables left behind by the most recent loop iteration.
# Each entry is (label, value); label and value are printed back to back.
debug_values = [
    ("current_state.shape", current_state.shape),
    ("next_state.shape", next_state.shape),
    ("reward", reward),
    ("done", done),
    ("batch_states.shape", batch_states.shape),
    ("batch_rewards", batch_rewards),
    ("batch_actions", batch_actions),
    ("batch_next_states.shape", batch_next_states.shape),
    ("batch_done", batch_done),
    ("next_actions", next_actions),
    ("target_values", target_values),
    ("best_next_q_values", best_next_q_values),
    ("batch.done", batch.done),
    ("batch_done", batch_done),
    ("~batch_done", ~batch_done),
    ("target_q_values", target_q_values),
    ("predicted_q_values", predicted_q_values),
]
for label, value in debug_values:
    print(f"{label}{value}")

In [None]:
def imshow(example):
    """Display a single image tensor in grayscale.

    Args:
        example: Tensor holding one image (presumably 2-D, height x width).
            It may live on the GPU or track gradients; it is detached and
            moved to the CPU before conversion, because `Tensor.numpy()`
            alone fails in both of those cases.
    """
    img = example.detach().cpu().numpy()
    plt.imshow(img, cmap="gray")
    plt.show()

In [None]:
def plot_images_grid(states, next_states):
    """Show the first four frames of a state (top row) and of its successor (bottom row).

    Both arguments are indexed as ``x[0][i]`` for ``i`` in 0..3, and each
    indexed item must support ``.cpu().numpy()`` (i.e. be a tensor).
    """
    fig, axs = plt.subplots(2, 4, figsize=(15, 6))  # 2 rows x 4 columns of axes

    # Row 0 shows `states`, row 1 shows `next_states`
    for row, stacked in enumerate((states, next_states)):
        for col in range(4):
            frame = stacked[0][col].cpu().numpy()
            panel = axs[row, col]
            panel.imshow(frame, cmap="gray")
            panel.axis("off")  # hide ticks and frame

    plt.show()

In [None]:
# Pull the transition stored at index 749 of the replay buffer
transition_749 = replay_buffer.memory[749]
view_states = transition_749.state
view_next_states = transition_749.next_state

# Plot its state above its next state
plot_images_grid(view_states, view_next_states)

In [None]:
# Pull the transition stored at index 750 of the replay buffer
transition_750 = replay_buffer.memory[750]
view_states = transition_750.state
view_next_states = transition_750.next_state

# Plot its state above its next state
plot_images_grid(view_states, view_next_states)