In [1]:
import re
import time
from collections import deque
from pathlib import Path

import cv2
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from IPython.display import clear_output

In [2]:
import ale_py
# Register the environments shipped by ale_py with gymnasium.
# NOTE(review): presumably this is what makes 'PongNoFrameskip-v4' resolvable by
# gym.make() later in the notebook — confirm against the ale_py documentation.
gym.register_envs(ale_py)

In [3]:
class ConvDQN_CNN1(nn.Module):
    """Original CNN architecture (32->64->64 filters).

    Three convolution layers followed by two fully connected layers that
    produce one value per action.

    Args:
        input_channels: Number of channels of the input tensor.
        action_size: Width of the output layer (one value per action).
    """
    def __init__(self, input_channels, action_size):
        super().__init__()

        # Remembered so the shape probe below uses the real channel count
        # instead of a hard-coded one (plain int, not part of the state_dict).
        self.input_channels = input_channels

        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.feature_size = self._get_conv_output_size()

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size, 512),
            nn.ReLU(),
            nn.Linear(512, action_size)
        )

    def _get_conv_output_size(self):
        """Return the flattened per-sample size of the conv output for one 84x84 input."""
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.input_channels, 84, 84)
            conv_output = self.conv(dummy_input)
            return conv_output.numel() // conv_output.size(0)

    def forward(self, x):
        """Run the conv stack, flatten everything but the batch axis, apply the head."""
        x = self.conv(x)
        # flatten() also handles non-contiguous conv outputs, where view() raises.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

class ConvDQN_CNN2(nn.Module):
    """Optimized CNN architecture (32->36->20 filters).

    Three convolution layers followed by two fully connected layers that
    produce one value per action.

    Args:
        input_channels: Number of channels of the input tensor.
        action_size: Width of the output layer (one value per action).
    """
    def __init__(self, input_channels, action_size):
        super().__init__()

        # Remembered so the shape probe below uses the real channel count
        # instead of a hard-coded one (plain int, not part of the state_dict).
        self.input_channels = input_channels

        self.conv = nn.Sequential(
            nn.Conv2d(input_channels, 32, kernel_size=8, stride=4, padding=2),
            nn.ReLU(),
            nn.Conv2d(32, 36, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(36, 20, kernel_size=3, stride=1, padding=1),
            nn.ReLU()
        )

        self.feature_size = self._get_conv_output_size()

        self.fc = nn.Sequential(
            nn.Linear(self.feature_size, 512),
            nn.ReLU(),
            nn.Linear(512, action_size)
        )

    def _get_conv_output_size(self):
        """Return the flattened per-sample size of the conv output for one 84x84 input."""
        with torch.no_grad():
            dummy_input = torch.zeros(1, self.input_channels, 84, 84)
            conv_output = self.conv(dummy_input)
            return conv_output.numel() // conv_output.size(0)

    def forward(self, x):
        """Run the conv stack, flatten everything but the batch axis, apply the head."""
        x = self.conv(x)
        # flatten() also handles non-contiguous conv outputs, where view() raises.
        x = torch.flatten(x, start_dim=1)
        return self.fc(x)

def preprocess_frame(frame):
    """Convert a raw frame into the 84x84 float image the network consumes.

    The frame is averaged over axis 2 and truncated to uint8, rows 34..193 are
    kept, the crop is resized to 84x84 with area interpolation, and the result
    is returned as float32 divided by 255.
    """
    grayscale = np.mean(frame, axis=2).astype(np.uint8)
    # Keep only rows 34..193, all columns.
    small = cv2.resize(
        grayscale[34:194, :],
        (84, 84),
        interpolation=cv2.INTER_AREA,
    )
    return small.astype(np.float32) / 255.0

class FrameStack:
    """Fixed-length rolling window over the most recent preprocessed frames."""

    def __init__(self, num_frames=2):
        self.num_frames = num_frames
        # The bounded deque drops the oldest frame once the window is full.
        self.frames = deque(maxlen=num_frames)

    def reset(self, frame):
        """Fill the whole window with one preprocessed frame and return the stack."""
        first = preprocess_frame(frame)
        self.frames.extend([first] * self.num_frames)
        return self.get_stacked()

    def step(self, frame):
        """Push one newly preprocessed frame and return the updated stack."""
        self.frames.append(preprocess_frame(frame))
        return self.get_stacked()

    def get_stacked(self):
        """Return the window as one array, oldest frame first along axis 0."""
        return np.stack(tuple(self.frames), axis=0)

In [4]:
def find_latest_run_checkpoints(artifacts_dir=None):
    """Find checkpoints from the most recent training run.

    Args:
        artifacts_dir: Directory that contains the ``run_*`` folders (str or
            Path). Defaults to ``~/dev/rl_study/artifacts``.

    Returns:
        A list of dicts with keys ``'path'``, ``'filename'`` and ``'run_dir'``
        for every ``*checkpoint*.pth`` file in the most recently modified run
        directory, sorted by the episode number parsed from the filename.
        Files without an episode number sort first; ties are broken by
        filename. An empty list is returned when no run directory exists.
    """
    if artifacts_dir is None:
        artifacts_dir = Path.home() / "dev" / "rl_study" / "artifacts"
    else:
        artifacts_dir = Path(artifacts_dir)

    run_dirs = [d for d in artifacts_dir.glob("run_*") if d.is_dir()]

    if not run_dirs:
        return []

    latest_run_dir = max(run_dirs, key=lambda d: d.stat().st_mtime)
    print(f"Latest run: {latest_run_dir.name}")

    checkpoints = [
        {
            'path': str(file_path),
            'filename': file_path.name,
            'run_dir': latest_run_dir.name
        }
        for file_path in latest_run_dir.glob("*checkpoint*.pth")
    ]

    def extract_episode_number(checkpoint):
        # Filenames look like "pong_dqn_cnn_v2_checkpoint_ep1000.pth". Take the
        # LAST "ep<digits>" occurrence so an earlier "ep" inside another word
        # (e.g. "deep") or a trailing suffix does not break the parse.
        matches = re.findall(r'ep(\d+)', checkpoint['filename'])
        return int(matches[-1]) if matches else 0  # 0: no episode number

    # Filename as secondary key keeps the order deterministic across
    # filesystems, since glob() order is not guaranteed.
    checkpoints.sort(key=lambda cp: (extract_episode_number(cp), cp['filename']))

    return checkpoints

def load_model(checkpoint_path, cnn_type="cnn1"):
    """Load model from checkpoint with specified CNN architecture.

    Args:
        checkpoint_path: Path to checkpoint file
        cnn_type: Either 'cnn1' for original architecture or 'cnn2' for optimized architecture
            (case-insensitive).

    Returns:
        Tuple ``(model, metadata, device)``: the network in eval mode on the
        selected device, a dict with keys ``'episode'``, ``'avg_reward'``,
        ``'epsilon'`` (each ``'unknown'`` when absent from the checkpoint) and
        ``'architecture'``, and the torch device that was used.

    Raises:
        ValueError: If ``cnn_type`` is neither 'cnn1' nor 'cnn2'.
    """
    # Validate the cheap argument first so a typo fails before the checkpoint
    # file is read and unpickled.
    arch_key = cnn_type.lower()
    if arch_key not in ("cnn1", "cnn2"):
        raise ValueError(f"Unknown CNN type: {cnn_type}. Use 'cnn1' or 'cnn2'")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which
    # can execute code. Only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Select model architecture based on cnn_type
    if arch_key == "cnn1":
        model = ConvDQN_CNN1(input_channels=2, action_size=6).to(device)
        arch_name = "CNN1 (32→64→64)"
    else:
        model = ConvDQN_CNN2(input_channels=2, action_size=6).to(device)
        arch_name = "CNN2 (32→36→20)"

    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    metadata = {
        'episode': checkpoint.get('episode', 'unknown'),
        'avg_reward': checkpoint.get('avg_reward', 'unknown'),
        'epsilon': checkpoint.get('epsilon', 'unknown'),
        'architecture': arch_name
    }

    return model, metadata, device

In [5]:
def visualize_game_frame(raw_frame, processed_frame, stacked_frames, q_values, action, reward, step, total_reward, checkpoint_name, architecture):
    """Show the current game state and agent decision.

    Draws a four-panel figure: the raw frame, the processed frame, a bar chart
    of the Q-values with the selected action highlighted, and the absolute
    difference between the two most recent stacked frames.

    Args:
        raw_frame: Image shown as-is in the first panel.
        processed_frame: Image shown in grayscale in the second panel.
        stacked_frames: Indexable sequence of at least two frames; the last two
            are differenced for the fourth panel.
        q_values: Sequence of Q-values, one bar each.
        action: Index of the selected action (its bar is drawn in red).
        reward: Reward of this step, shown in the title.
        step: Step counter, shown in the title.
        total_reward: Accumulated reward, shown in the title.
        checkpoint_name: Label shown in the title.
        architecture: Architecture label shown in the title.
    """
    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f'{checkpoint_name} ({architecture}) - Step {step} - Action: {action}, Step Reward: {reward:.1f}, Total: {total_reward:.1f}', 
                 fontsize=14, fontweight='bold')
    
    # Raw game frame
    axes[0].imshow(raw_frame)
    axes[0].set_title('Raw Game Frame')
    axes[0].axis('off')
    
    # Processed frame that agent sees
    axes[1].imshow(processed_frame, cmap='gray')
    axes[1].set_title('Processed Frame (84x84)')
    axes[1].axis('off')
    
    # Q-values
    actions = ['NOOP', 'FIRE', 'RIGHT', 'LEFT', 'RIGHTFIRE', 'LEFTFIRE']
    colors = ['red' if i == action else 'blue' for i in range(len(q_values))]
    bars = axes[2].bar(range(len(q_values)), q_values, color=colors, alpha=0.7)
    axes[2].set_title('Q-Values (Selected Action in Red)')
    axes[2].set_xlabel('Actions')
    axes[2].set_ylabel('Q-Value')
    axes[2].set_xticks(range(len(actions)))
    axes[2].set_xticklabels(actions, rotation=45)
    axes[2].grid(True, alpha=0.3)
    
    # Add value labels just beyond each bar's tip: above positive bars and
    # below negative ones, so the text never overlaps the bar itself.
    for i, (bar, val) in enumerate(zip(bars, q_values)):
        height = bar.get_height()
        if height >= 0:
            label_y, valign = height + 0.01, 'bottom'
        else:
            label_y, valign = height - 0.01, 'top'
        axes[2].text(bar.get_x() + bar.get_width()/2., label_y,
                        f'{val:.2f}', ha='center', va=valign, 
                        fontweight='bold' if i == action else 'normal')
    
    # Frame difference (movement detection) between the two newest frames;
    # identical to frames[1] - frames[0] for a two-frame stack.
    frame_diff = np.abs(stacked_frames[-1] - stacked_frames[-2])
    axes[3].imshow(frame_diff, cmap='hot')
    axes[3].set_title('Movement Detection')
    axes[3].axis('off')
    
    plt.tight_layout()
    plt.show()
    # This function is called repeatedly in a loop; release the figure so open
    # figures do not accumulate on backends that keep them alive after show().
    plt.close(fig)

In [None]:
def play_and_visualize_game(model, device, checkpoint_name, architecture,
                            frame_skip=4, show_every=10, frame_delay=0.2):
    """Play one complete game and show every ``show_every``-th step (10th by default).

    The agent always plays the action with the highest Q-value and repeats it
    for ``frame_skip`` emulator frames.

    Args:
        model: Network called on a batch holding one stacked state.
        device: Torch device the state tensor is moved to.
        checkpoint_name: Label used in the printed output and figure title.
        architecture: Architecture label used in the printed output and figure title.
        frame_skip: Number of environment steps each chosen action is repeated for.
        show_every: Visualize one agent step out of this many.
        frame_delay: Seconds to sleep after each visualization.

    Returns:
        Tuple ``(total_reward, step_count)`` for the finished game.

    Raises:
        ValueError: If ``frame_skip`` or ``show_every`` is smaller than 1.
    """
    if frame_skip < 1:
        raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")
    if show_every < 1:
        raise ValueError(f"show_every must be >= 1, got {show_every}")

    env = gym.make('PongNoFrameskip-v4')

    # try/finally guarantees the environment is released even when the model or
    # the plotting code raises in the middle of a game.
    try:
        print(f"Playing game with {checkpoint_name} using {architecture}")
        print(f"Showing every {show_every}th step of the game")
        print("=" * 60)

        state, _ = env.reset()
        frame_stack = FrameStack(2)
        stacked_state = frame_stack.reset(state)

        total_reward = 0
        step_count = 0
        done = False

        while not done:
            # Get action from model
            with torch.no_grad():
                state_tensor = torch.FloatTensor(stacked_state).unsqueeze(0).to(device)
                q_values = model(state_tensor)
                action = q_values.max(1)[1].item()
                q_vals_numpy = q_values.cpu().numpy()[0]

            # Execute action with frame skipping, summing the rewards of the
            # repeated frames and stopping early when the episode ends.
            step_reward = 0
            for _ in range(frame_skip):
                next_state, reward, terminated, truncated, _ = env.step(action)
                step_reward += reward
                if terminated or truncated:
                    break

            done = terminated or truncated
            total_reward += step_reward

            # Show visualization every `show_every` steps
            if step_count % show_every == 0:
                processed_current = preprocess_frame(next_state)
                clear_output(wait=True)
                visualize_game_frame(
                    raw_frame=next_state,
                    processed_frame=processed_current,
                    # The stack the Q-values above were computed from.
                    stacked_frames=stacked_state,
                    q_values=q_vals_numpy,
                    action=action,
                    reward=step_reward,
                    step=step_count,
                    total_reward=total_reward,
                    checkpoint_name=checkpoint_name,
                    architecture=architecture
                )

                print(f"Step {step_count}: Action={action}, Reward={step_reward:.1f}, Total={total_reward:.1f}")

                # Small delay to make it watchable
                time.sleep(frame_delay)

            # Update state
            stacked_state = frame_stack.step(next_state)
            step_count += 1
    finally:
        env.close()

    print(f"\nGame finished!")
    print(f"Final reward: {total_reward}")
    print(f"Game length: {step_count} steps")
    print("=" * 60)

    return total_reward, step_count

In [7]:
# Collect the checkpoints of the latest training run and list them by index
checkpoints = find_latest_run_checkpoints()

if not checkpoints:
    print("No checkpoints found. Run training first!")
else:
    print(f"Found {len(checkpoints)} checkpoints from latest run:")
    for i, cp in enumerate(checkpoints):
        print(f"{i}: {cp['filename']}")

Latest run: run_c2fe49760fe14990ad1b4a0df89b985b
Found 39 checkpoints from latest run:
0: pong_dqn_cnn_v2_checkpoint_ep1000.pth
1: pong_dqn_cnn_v2_checkpoint_ep2000.pth
2: pong_dqn_cnn_v2_checkpoint_ep3000.pth
3: pong_dqn_cnn_v2_checkpoint_ep4000.pth
4: pong_dqn_cnn_v2_checkpoint_ep5000.pth
5: pong_dqn_cnn_v2_checkpoint_ep6000.pth
6: pong_dqn_cnn_v2_checkpoint_ep7000.pth
7: pong_dqn_cnn_v2_checkpoint_ep8000.pth
8: pong_dqn_cnn_v2_checkpoint_ep9000.pth
9: pong_dqn_cnn_v2_checkpoint_ep10000.pth
10: pong_dqn_cnn_v2_checkpoint_ep11000.pth
11: pong_dqn_cnn_v2_checkpoint_ep12000.pth
12: pong_dqn_cnn_v2_checkpoint_ep13000.pth
13: pong_dqn_cnn_v2_checkpoint_ep14000.pth
14: pong_dqn_cnn_v2_checkpoint_ep15000.pth
15: pong_dqn_cnn_v2_checkpoint_ep16000.pth
16: pong_dqn_cnn_v2_checkpoint_ep17000.pth
17: pong_dqn_cnn_v2_checkpoint_ep18000.pth
18: pong_dqn_cnn_v2_checkpoint_ep19000.pth
19: pong_dqn_cnn_v2_checkpoint_ep20000.pth
20: pong_dqn_cnn_v2_checkpoint_ep21000.pth
21: pong_dqn_cnn_v2_checkpoin

In [None]:
# CONFIGURATION: Select model architecture
# Use "cnn1" for pong_dqn_cnn.ipynb (32->64->64 filters)
# Use "cnn2" for pong_dqn_cnn_v2.ipynb (32->36->20 filters)
ARCH = "cnn2"  # Use "cnn1" for original (32->64->64) or "cnn2" for optimized (32->36->20)

# Analyze every N-th checkpoint; one full game is played per selected
# checkpoint, so a larger stride keeps the run short. Use 1 to analyze all.
CHECKPOINT_STRIDE = 10

# Select which checkpoint to analyze (play one game per checkpoint)
if checkpoints:
    print(f"Using model architecture: {ARCH.upper()}")
    print("Available checkpoints (now sorted by episode number):")
    for i, checkpoint in enumerate(checkpoints):
        print(f"{i}: {checkpoint['filename']}")
    
    # You can modify CHECKPOINT_STRIDE (or this list) to select specific
    # checkpoints or analyze all of them.
    selected_indices = list(range(0, len(checkpoints), CHECKPOINT_STRIDE))
    print(f"\nAnalyzing checkpoints (every {CHECKPOINT_STRIDE}th): {selected_indices}")
    print("=" * 80)
    
    results = []
    
    for idx in selected_indices:
        checkpoint = checkpoints[idx]
        print(f"\n{'='*80}")
        print(f"ANALYZING CHECKPOINT {idx}: {checkpoint['filename']}")
        print(f"{'='*80}")
        
        try:
            # Load model with selected architecture
            model, metadata, device = load_model(checkpoint['path'], cnn_type=ARCH)
            print(f"Training Episode: {metadata['episode']}")
            print(f"Training Avg Reward: {metadata['avg_reward']}")
            print(f"Training Epsilon: {metadata['epsilon']}")
            print(f"Architecture: {metadata['architecture']}")
            print()
            
            # Play one complete game; the function visualizes every 10th step
            final_reward, game_length = play_and_visualize_game(
                model, device, checkpoint['filename'], metadata['architecture']
            )
            
            # Store results
            results.append({
                'checkpoint': checkpoint['filename'],
                'training_episode': metadata['episode'],
                'training_avg_reward': metadata['avg_reward'],
                'game_reward': final_reward,
                'game_length': game_length
            })
            
        except Exception as e:
            # Deliberately best-effort: a checkpoint that fails to load or play
            # is reported and skipped so the remaining ones are still analyzed.
            print(f"Error with checkpoint {checkpoint['filename']}: {e}")
    
    # Summary of results
    if results:
        print(f"\n{'='*80}")
        print("ANALYSIS SUMMARY")
        print(f"{'='*80}")
        for result in results:
            print(f"Checkpoint: {result['checkpoint']}")
            print(f"  Training Ep: {result['training_episode']}, Training Avg: {result['training_avg_reward']}")
            print(f"  Game Reward: {result['game_reward']}, Game Length: {result['game_length']}")
            print()
            
else:
    print("No checkpoints available to analyze!")