In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
from collections import deque
import gymnasium as gym
import matplotlib.pyplot as plt
from IPython.display import clear_output
import mlflow
import mlflow.pytorch
import wandb
import os
from dotenv import load_dotenv
from pathlib import Path
from logging_utils import *

In [None]:
# Global hyperparameters
LEARNING_RATE = 0.001
GAMMA = 0.99
EPSILON_START = 1.0
EPSILON_MIN = 0.01
EPSILON_DECAY = 0.99          # multiplicative epsilon decay, applied once per episode
BATCH_SIZE = 32
BUFFER_SIZE = 50000
TARGET_UPDATE = 1000          # target-network sync interval (unit defined by train_dqn)
MEAN_REWARD_BOUND = 19.5      # "solved" threshold on the mean reward of the last 100 episodes
# Only logged as a hyperparameter; nothing applies it to the environment.
# NOTE(review): per the ALE docs, Pong-v4 performs its own frame skipping — confirm this value.
FRAME_SKIP = 1
FRAME_STACK = 2
INPUT_CHANNELS = FRAME_STACK  # one network input channel per stacked frame

# Credentials: WANDB_KEY is read from a .env file. Set RL_STUDY_ENV_FILE to point at a
# different file instead of editing this absolute path.
ENV_FILE = Path(os.getenv("RL_STUDY_ENV_FILE", "/home/bmartins/dev/rl_study/.env"))
load_dotenv(ENV_FILE)
wandb_key = os.getenv('WANDB_KEY')
if not wandb_key:
    print(f"WANDB_KEY not found in {ENV_FILE} or the environment; wandb.login will fall back to its own credential lookup.")
wandb.login(key=wandb_key)

# Single source of truth for the run config sent to wandb.
WANDB_CONFIG = {
    "learning_rate": LEARNING_RATE,
    "gamma": GAMMA,
    "epsilon_start": EPSILON_START,
    "epsilon_min": EPSILON_MIN,
    "epsilon_decay": EPSILON_DECAY,
    "batch_size": BATCH_SIZE,
    "buffer_size": BUFFER_SIZE,
    "target_update": TARGET_UPDATE,
    "mean_reward_bound": MEAN_REWARD_BOUND,
    "frame_skip": FRAME_SKIP,
    "architecture": "FC-DQN",
    "input_channels": INPUT_CHANNELS,
    "frame_stack": FRAME_STACK,
}

wandb.init(
    project="pong-dqn-training",
    name="pong_fc_dqn",
    config=WANDB_CONFIG,
)

In [None]:
# Logging functions
def log_hyperparameters(artifacts_dir):
    """Record the notebook's hyperparameter constants as MLflow run parameters.

    Args:
        artifacts_dir: Run-specific artifact directory; logged alongside the
            hyperparameters so the run can be traced back to its files.
    """
    params = dict(
        learning_rate=LEARNING_RATE,
        gamma=GAMMA,
        epsilon_start=EPSILON_START,
        epsilon_min=EPSILON_MIN,
        epsilon_decay=EPSILON_DECAY,
        batch_size=BATCH_SIZE,
        buffer_size=BUFFER_SIZE,
        target_update=TARGET_UPDATE,
        mean_reward_bound=MEAN_REWARD_BOUND,
        frame_skip=FRAME_SKIP,
        input_channels=INPUT_CHANNELS,
        frame_stack=FRAME_STACK,
        artifacts_dir=artifacts_dir,
    )
    mlflow.log_params(params)

def log_model_info(model, device):
    """Record the model's parameter count and compute device in MLflow and wandb.

    Args:
        model: A ``torch.nn.Module`` whose parameters are counted.
        device: The device the model lives on; logged as its string form.

    Returns:
        Total number of scalar parameters in ``model``.
    """
    n_params = 0
    for param in model.parameters():
        n_params += param.numel()
    device_name = str(device)

    mlflow.log_param("total_parameters", n_params)
    mlflow.log_param("device", device_name)
    wandb.log({"total_parameters": n_params, "device": device_name})
    return n_params

def log_training_step(loss, episode, step_count):
    """Record one training-loss sample in MLflow and wandb.

    The x-axis value is a synthetic global step, ``episode * 10000 + step_count``;
    it is only monotonic while every episode is shorter than 10000 steps.
    """
    global_step = episode * 10000 + step_count
    mlflow.log_metric("loss", loss, step=global_step)
    wandb.log({"loss": loss, "step": global_step})

def log_episode_metrics(episode_reward, mean_reward, epsilon, buffer_size, episode_length, episode):
    """Record end-of-episode statistics in MLflow and wandb.

    MLflow receives the five per-episode metrics keyed by ``step=episode``;
    wandb receives the same metrics plus an explicit ``episode`` field.
    """
    per_episode = dict(
        episode_reward=episode_reward,
        mean_reward_100=mean_reward,
        epsilon=epsilon,
        buffer_size=buffer_size,
        episode_length=episode_length,
    )
    mlflow.log_metrics(per_episode, step=episode)
    wandb.log({**per_episode, "episode": episode})

def log_10_episode_average(avg_reward, episode):
    """Record the mean reward of the last 10 episodes (MLflow step = ``episode``)."""
    metric_name = "avg_reward_10"
    mlflow.log_metric(metric_name, avg_reward, step=episode)
    wandb.log({metric_name: avg_reward})

def log_solved_episode(episode):
    """Record the episode index at which the reward threshold was reached."""
    metric_name = "solved_at_episode"
    mlflow.log_metric(metric_name, episode)
    wandb.log({metric_name: episode})

def save_checkpoint(artifacts_dir, episode, model, optimizer, episode_rewards, epsilon, run_id):
    """Write a training checkpoint to ``artifacts_dir`` and upload it to MLflow and wandb.

    The checkpoint holds the model and optimizer state dicts, the episode index,
    epsilon, the full reward history, the run id, and ``avg_reward`` (mean of the
    last up-to-10 episode rewards).

    Returns:
        Path of the written ``.pth`` file.
    """
    checkpoint_name = f'pong_dqn_fc_checkpoint_ep{episode}.pth'
    checkpoint_path = os.path.join(artifacts_dir, checkpoint_name)

    # Slicing a shorter list with [-10:] yields the whole list, so this covers
    # the "fewer than 10 episodes" case too.
    payload = {
        'episode': episode,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'avg_reward': np.mean(episode_rewards[-10:]),
        'epsilon': epsilon,
        'episode_rewards': episode_rewards,
        'run_id': run_id,
    }
    torch.save(payload, checkpoint_path)

    mlflow.log_artifact(checkpoint_path)
    wandb.save(checkpoint_path)
    print(f"Checkpoint saved at episode {episode}: {checkpoint_path}")
    return checkpoint_path

def _ten_episode_means(episode_rewards):
    """Return (window_end_episodes, window_means) for consecutive, non-overlapping 10-episode windows."""
    window_ends = list(range(10, len(episode_rewards) + 1, 10))
    window_means = [np.mean(episode_rewards[end - 10:end]) for end in window_ends]
    return window_ends, window_means


def _build_training_figure(episode_rewards, mean_reward_bound):
    """Draw the two-panel training-progress figure; the caller shows/saves/closes it."""
    plot_episodes, plot_rewards = _ten_episode_means(episode_rewards)

    fig, (ax_progress, ax_detail) = plt.subplots(1, 2, figsize=(12, 5))

    ax_progress.plot(plot_episodes, plot_rewards, 'b-', linewidth=2, marker='o')
    ax_progress.set_title('FC-DQN Pong Training Progress')
    ax_progress.set_xlabel('Episode')
    ax_progress.set_ylabel('Mean Reward (last 10 episodes)')
    ax_progress.grid(True, alpha=0.3)
    ax_progress.axhline(y=mean_reward_bound, color='r', linestyle='--', label=f'Target ({mean_reward_bound})')
    ax_progress.legend()

    # Same series indexed by window number (window i ends at episode 10 * (i + 1)).
    ax_detail.plot(range(len(plot_rewards)), plot_rewards, 'g-', linewidth=2)
    ax_detail.set_title('Training Progress Detail')
    ax_detail.set_xlabel('Episodes (x10)')
    ax_detail.set_ylabel('Average Reward')
    ax_detail.grid(True, alpha=0.3)

    fig.tight_layout()
    return fig


def save_training_plot(artifacts_dir, episode_rewards, episode, mean_reward_bound):
    """Save the training-progress figure as a PNG and log it to MLflow and wandb.

    Args:
        artifacts_dir: Directory the PNG is written to.
        episode_rewards: Per-episode total rewards so far.
        episode: Current episode index (used in the file name).
        mean_reward_bound: Target reward drawn as a dashed reference line.

    Returns:
        The PNG path, or ``None`` when there is at most one episode to plot.
    """
    if len(episode_rewards) <= 1:
        return None

    fig = _build_training_figure(episode_rewards, mean_reward_bound)
    plot_path = os.path.join(artifacts_dir, f"fc_training_plot_ep{episode}.png")
    try:
        fig.savefig(plot_path)
    finally:
        plt.close(fig)  # always release the figure, even if saving fails

    mlflow.log_artifact(plot_path)
    wandb.log({"training_plot": wandb.Image(plot_path)})

    print(f"Plot saved at episode {episode}: {plot_path}")
    return plot_path


def display_training_plot(episode_rewards, mean_reward_bound):
    """Replace the cell output with the current training-progress figure.

    Does nothing when there is at most one episode to plot.
    """
    if len(episode_rewards) <= 1:
        return

    clear_output(wait=True)
    _build_training_figure(episode_rewards, mean_reward_bound)
    plt.show()

def log_final_results(model, episode, mean_reward, episode_rewards, artifacts_dir, run_id):
    """Log the final model and summary metrics, and save a final state-dict file.

    The model is logged to MLflow under ``final_model``. The summary metrics go to
    both MLflow and wandb. A ``.pth`` file with the state dict, run id and metrics
    is written into ``artifacts_dir`` and uploaded to wandb.

    Returns:
        Path of the saved ``.pth`` file.
    """
    mlflow.pytorch.log_model(model, "final_model")

    summary = dict(
        final_episode=episode,
        final_mean_reward=mean_reward,
        total_episodes=len(episode_rewards),
    )
    mlflow.log_metrics(summary)
    wandb.log(summary)

    final_model_path = os.path.join(artifacts_dir, 'final_pong_dqn_fc_model.pth')
    payload = {
        'model_state_dict': model.state_dict(),
        'run_id': run_id,
        'final_metrics': summary,
    }
    torch.save(payload, final_model_path)
    wandb.save(final_model_path)

    return final_model_path

In [None]:
class SimpleDQN(nn.Module):
    """Fully connected Q-network: flattens the stacked frames and outputs one Q-value per action.

    Args:
        state_size: Input layout. ``None`` (what this notebook passes) keeps the
            original 2 x 84 x 84 layout (two stacked 84x84 frames). An int gives
            the flattened size directly; a shape tuple such as ``(C, H, W)`` is
            multiplied out.
        action_size: Number of discrete actions (size of the output layer).
        hidden_size: Width of the single hidden layer (default 512).
    """

    DEFAULT_INPUT_SHAPE = (2, 84, 84)

    def __init__(self, state_size, action_size, hidden_size=512):
        super().__init__()
        if state_size is None:
            state_size = self.DEFAULT_INPUT_SHAPE
        in_features = int(np.prod(state_size))
        # Layer names/indices (fc.0, fc.2) are unchanged, so old checkpoints still load.
        self.fc = nn.Sequential(
            nn.Linear(in_features, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, action_size),
        )

    def forward(self, x):
        """Map a batch ``(batch, *state_shape)`` to Q-values ``(batch, action_size)``."""
        # flatten keeps the batch dim and, unlike view(), also accepts non-contiguous tensors.
        return self.fc(torch.flatten(x, start_dim=1))

class ReplayBuffer:
    """Fixed-capacity FIFO store of (state, action, reward, next_state, done) transitions.

    Once ``capacity`` transitions are stored, each new push evicts the oldest.
    NOTE(review): states are stored as given; with this notebook's float32 2x84x84
    stacks each state is ~56 KB, so a full 50,000-entry buffer needs several GB of
    RAM — consider storing uint8 frames.
    """

    def __init__(self, capacity):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition, dropping the oldest if the buffer is full."""
        transition = (state, action, reward, next_state, done)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            ``(states, actions, rewards, next_states, dones)``. States are stacked
            into numpy arrays; the other three fields are tuples.
        """
        picked = random.sample(self.buffer, batch_size)
        columns = list(zip(*picked))
        return (
            np.array(columns[0]),
            columns[1],
            columns[2],
            np.array(columns[3]),
            columns[4],
        )

    def __len__(self):
        return len(self.buffer)

def preprocess_frame(frame):
    """Turn a raw RGB Atari frame into a normalized 84x84 grayscale image.

    Pipeline:
      1. average the colour channels and truncate to uint8;
      2. keep rows 34..193 (160 rows), dropping the score bar above and the
         border below — this assumes the 210x160 Pong frame;
      3. resize to 84x84 with area interpolation;
      4. rescale to float32 in [0, 1].
    """
    import cv2

    luminance = np.mean(frame, axis=2).astype(np.uint8)
    play_area = luminance[34:194, :]
    small = cv2.resize(play_area, (84, 84), interpolation=cv2.INTER_AREA)
    return small.astype(np.float32) / 255.0

class FrameStack:
    """Rolling window over the last ``num_frames`` preprocessed frames (oldest first)."""

    def __init__(self, num_frames=2):
        self.num_frames = num_frames
        self.frames = deque(maxlen=num_frames)

    def reset(self, frame):
        """Fill the whole window with the preprocessed ``frame`` and return the stack."""
        first = preprocess_frame(frame)
        self.frames.extend(first for _ in range(self.num_frames))
        return self.get_stacked()

    def step(self, frame):
        """Push the preprocessed ``frame`` (evicting the oldest) and return the stack."""
        self.frames.append(preprocess_frame(frame))
        return self.get_stacked()

    def get_stacked(self):
        """Return the window as one array of shape ``(num_frames, *frame_shape)``."""
        return np.stack(tuple(self.frames))

In [None]:
# Importing ale_py registers the ALE Atari environments (e.g. 'Pong-v4', used below)
# with gymnasium. Per the Gymnasium docs, register_envs itself is a no-op whose
# purpose is to mark the import as used so linters/IDEs don't remove it.
import ale_py
gym.register_envs(ale_py)

In [None]:
def _select_action(net, stacked_state, epsilon, env, device):
    """Epsilon-greedy policy: a random action with probability ``epsilon``, else argmax_a Q(s, a)."""
    if random.random() > epsilon:
        with torch.no_grad():
            state_tensor = torch.as_tensor(stacked_state, dtype=torch.float32, device=device).unsqueeze(0)
            return net(state_tensor).argmax(dim=1).item()
    return env.action_space.sample()


def _optimize_step(main_net, target_net, optimizer, buffer, batch_size, gamma, device):
    """Run one DQN gradient step on a random minibatch and return the loss tensor.

    Targets are ``r + gamma * max_a' Q_target(s', a')``; the bootstrap term is
    zeroed for transitions stored with ``done=True``.
    """
    states, actions, rewards, next_states, dones = buffer.sample(batch_size)
    states = torch.as_tensor(states, dtype=torch.float32, device=device)
    actions = torch.as_tensor(actions, dtype=torch.int64, device=device)
    rewards = torch.as_tensor(rewards, dtype=torch.float32, device=device)
    next_states = torch.as_tensor(next_states, dtype=torch.float32, device=device)
    dones = torch.as_tensor(dones, dtype=torch.bool, device=device)

    # squeeze(1) (not squeeze()) keeps a batch dimension even when batch_size == 1.
    current_q = main_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)
    with torch.no_grad():
        next_q = target_net(next_states).max(dim=1).values
        target_q = rewards + gamma * next_q * (~dones)

    loss = nn.MSELoss()(current_q, target_q)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(main_net.parameters(), max_norm=10)
    optimizer.step()
    return loss


def train_dqn(max_episodes=None):
    """Train the FC-DQN on Pong-v4, logging to MLflow and wandb.

    Hyperparameters come from the module-level constants. Training stops when the
    mean reward of the last 100 episodes reaches ``MEAN_REWARD_BOUND``, or after
    ``max_episodes`` episodes if that is given. The target network is re-synced
    every ``TARGET_UPDATE`` environment steps. Epsilon decays once per episode and
    is clamped at ``EPSILON_MIN``. Calls ``wandb.finish()`` on success.

    Args:
        max_episodes: Optional episode cap; ``None`` (default) trains until solved.

    Returns:
        The trained online network.
    """
    # End any run left open by an earlier (e.g. interrupted) call.
    if mlflow.active_run():
        mlflow.end_run()

    mlflow.set_experiment("Pong_FC_DQN_Training")

    with mlflow.start_run():
        # One artifacts directory per MLflow run; the root can be overridden via env var.
        run_id = mlflow.active_run().info.run_id
        artifacts_root = os.getenv("RL_STUDY_ARTIFACTS_ROOT", "/home/bmartins/dev/rl_study/artifacts")
        artifacts_dir = os.path.join(artifacts_root, f"run_{run_id}")
        os.makedirs(artifacts_dir, exist_ok=True)
        print(f"Artifacts will be saved to: {artifacts_dir}")

        env = gym.make('Pong-v4')
        try:
            action_size = env.action_space.n
            log_hyperparameters(artifacts_dir)

            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            print(f"Using device: {device}")

            main_net = SimpleDQN(None, action_size).to(device)
            target_net = SimpleDQN(None, action_size).to(device)
            target_net.load_state_dict(main_net.state_dict())
            log_model_info(main_net, device)

            optimizer = optim.Adam(main_net.parameters(), lr=LEARNING_RATE)
            buffer = ReplayBuffer(BUFFER_SIZE)

            epsilon = EPSILON_START
            episode_rewards = []
            mean_reward = float("nan")  # only reported as-is if zero episodes run
            total_steps = 0             # environment steps across all episodes
            episode = 0

            while max_episodes is None or episode < max_episodes:
                obs, _ = env.reset()
                frame_stack = FrameStack(FRAME_STACK)
                stacked_state = frame_stack.reset(obs)
                total_reward = 0
                step_count = 0
                done = False

                while not done:
                    action = _select_action(main_net, stacked_state, epsilon, env, device)
                    next_obs, reward, terminated, truncated, _ = env.step(action)
                    done = terminated or truncated
                    next_stacked_state = frame_stack.step(next_obs)

                    # Only a true termination cuts the bootstrap; after a time-limit
                    # truncation the next state still has value.
                    buffer.push(stacked_state, action, reward, next_stacked_state, terminated)
                    stacked_state = next_stacked_state
                    total_reward += reward
                    step_count += 1
                    total_steps += 1

                    if len(buffer) > BATCH_SIZE:
                        loss = _optimize_step(main_net, target_net, optimizer, buffer, BATCH_SIZE, GAMMA, device)
                        if step_count % 1000 == 0:
                            log_training_step(loss.item(), episode, step_count)

                    # Sync on environment steps. The old `episode % TARGET_UPDATE`
                    # check only synced once per 1000 episodes.
                    if total_steps % TARGET_UPDATE == 0:
                        target_net.load_state_dict(main_net.state_dict())

                epsilon = max(EPSILON_MIN, epsilon * EPSILON_DECAY)

                # Solve criterion uses the mean of the last (up to) 100 episodes.
                episode_rewards.append(total_reward)
                mean_reward = np.mean(episode_rewards[-100:])
                log_episode_metrics(total_reward, mean_reward, epsilon, len(buffer), step_count, episode)

                if mean_reward >= MEAN_REWARD_BOUND:
                    print(f"Environment solved in {episode} episodes with mean reward: {mean_reward:.2f}")
                    log_solved_episode(episode)
                    break

                if episode > 0 and episode % 1000 == 0:
                    save_checkpoint(artifacts_dir, episode, main_net, optimizer, episode_rewards, epsilon, run_id)
                    save_training_plot(artifacts_dir, episode_rewards, episode, MEAN_REWARD_BOUND)

                if episode > 0 and episode % 10 == 0:
                    avg_reward = np.mean(episode_rewards[-10:])
                    if episode % 100 == 0:
                        # Plot first: display_training_plot clears the cell output,
                        # which previously erased the progress line printed before it.
                        display_training_plot(episode_rewards, MEAN_REWARD_BOUND)
                    print(f"Episode {episode}, Avg Reward: {avg_reward:.2f}, Epsilon: {epsilon:.3f}, Buffer: {len(buffer)}")
                    # Logged on every 10th episode, including multiples of 100.
                    log_10_episode_average(avg_reward, episode)

                episode += 1

            log_final_results(main_net, episode, mean_reward, episode_rewards, artifacts_dir, run_id)
            print(f"Training completed! All artifacts saved in: {artifacts_dir}")
        finally:
            env.close()

        wandb.finish()
        return main_net

In [None]:
# Test FC architecture end to end: raw frame -> preprocessing -> frame stack -> Q-values.
env = gym.make('Pong-v4')
try:
    state, _ = env.reset()
    n_actions = env.action_space.n
finally:
    env.close()  # a single frame is all this check needs

print(f"Original frame shape: {state.shape}")

processed = preprocess_frame(state)
print(f"Processed frame shape: {processed.shape}")
assert processed.shape == (84, 84), processed.shape

frame_stack = FrameStack(FRAME_STACK)
stacked = frame_stack.reset(state)
print(f"Stacked frames shape: {stacked.shape}")
assert stacked.shape == (FRAME_STACK, 84, 84), stacked.shape

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fc_net = SimpleDQN(None, n_actions).to(device)

with torch.no_grad():
    state_tensor = torch.FloatTensor(stacked).unsqueeze(0).to(device)
    q_values = fc_net(state_tensor)
    print(f"Network input shape: {state_tensor.shape}")
    print(f"Flattened input size: {state_tensor.view(1, -1).shape}")
    print(f"Q-values shape: {q_values.shape}")
    print(f"Q-values: {q_values.cpu().numpy()}")
assert tuple(q_values.shape) == (1, n_actions), q_values.shape

total_params = sum(p.numel() for p in fc_net.parameters())
print(f"\nTotal parameters: {total_params:,}")

# Compare with the CNN version. This count is hard-coded from the CNN notebook;
# update it if that architecture changes.
CNN_PARAMS_APPROX = 3354278
print(f"CNN parameters (approx): {CNN_PARAMS_APPROX:,}")
print(f"FC vs CNN parameter ratio: {total_params / CNN_PARAMS_APPROX:.2f}x")

In [None]:
# Close any MLflow run left open by an earlier or interrupted execution, then train.
leftover_run = mlflow.active_run()
if leftover_run:
    print("Ending active MLflow run...")
    mlflow.end_run()

trained_model = train_dqn()
print("FC-DQN Training completed!")

In [None]:
# Visual check of frame cropping and downscaling (gym, plt and np come from the imports cell).
CROP_ROWS = slice(34, 194)  # must match the crop hard-coded in preprocess_frame

env = gym.make('Pong-v4')
try:
    state, _ = env.reset()
finally:
    env.close()  # only one frame is needed

print(f"Original frame shape: {state.shape}")

processed = preprocess_frame(state)
print(f"Processed frame shape: {processed.shape}")

frame_stack = FrameStack(FRAME_STACK)
stacked = frame_stack.reset(state)
print(f"Stacked frames shape: {stacked.shape}")

# Visualize the processing pipeline
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 10))

original_gray = np.mean(state, axis=2).astype(np.uint8)
cropped = original_gray[CROP_ROWS, :]

# Axis labels are derived from the actual image shapes rather than hard-coded.
for ax, image, title in (
    (ax1, original_gray, 'Original Frame'),
    (ax2, cropped, 'Cropped Frame'),
    (ax3, processed, 'Final Processed Frame'),
):
    ax.imshow(image, cmap='gray')
    ax.set_title(f'{title} {image.shape}')
    ax.set_xlabel(f'Width ({image.shape[1]} pixels)')
    ax.set_ylabel(f'Height ({image.shape[0]} pixels)')

# Number of values at each stage, computed from the arrays themselves.
stages = [('Original', state), ('Cropped', cropped), ('Downscaled', processed), ('Network Input', stacked)]
sizes = [f"{name}\n({'×'.join(map(str, arr.shape))})" for name, arr in stages]
params = [arr.size for _, arr in stages]

ax4.bar(sizes, params, color=['red', 'orange', 'yellow', 'green'], alpha=0.7)
ax4.set_title('Values per Processing Stage')
ax4.set_ylabel('Number of Values')
ax4.tick_params(axis='x', rotation=45)
for i, v in enumerate(params):
    ax4.text(i, v + max(params) * 0.01, f'{v:,}', ha='center', va='bottom')

fig.tight_layout()
plt.show()

print("\nMemory reduction:")
print(f"Original: {state.size:,} values")
print(f"Final: {stacked.size:,} values")
print(f"Reduction factor: {state.size / stacked.size:.1f}x smaller")