In [66]:
import os
import gymnasium as gym
import flappy_bird_gymnasium as flappy_bird
import torch
import numpy as np
import matplotlib.pyplot as plt

from collections import deque, defaultdict
from gymnasium.wrappers import RecordVideo

import sys
# Extend the module search path with the parent directory *before* the
# project imports that follow (presumably where `model` and `save_model`
# live -- confirm against the repository layout).
sys.path.append(os.path.abspath("../"))

from model import DQN_CNN
from save_model.utils import transition
import tqdm


# Default (width, height) applied to every matplotlib figure created in this notebook.
plt.rcParams["figure.figsize"] = (10, 5)


In [67]:
# Checkpoint file handed to torch.load(); must contain a "policy_net_state_dict" entry.
MODEL_PATH = "../../checkpoints/65a798e19ae9839f79aaa52c1bde3cb5"
# Environment id passed to gym.make().
ENV_NAME = "FlappyBird-v0"
# Device used both for loading the checkpoint and for inference.
DEVICE = "cpu"

# Number of evaluation episodes played by the evaluation loop.
NUM_EPISODES = 100
# Passed to DQN_CNN(...) and forwarded to transition(...) as `frame_skip`.
FRAME_SKIP = 4
# Upper bound on transition() calls per episode (not on raw environment frames).
MAX_STEPS = 10_000  # safety cap


In [68]:
# NOTE(review): torch.load() unpickles the file -- only point MODEL_PATH at
# checkpoints you trust.
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
trained_weights = checkpoint["policy_net_state_dict"]

# Build the network, restore the trained weights, and switch to inference mode.
policy_net = DQN_CNN(FRAME_SKIP).to(DEVICE)
policy_net.load_state_dict(trained_weights)
policy_net.eval()

print("✅ Model loaded")


✅ Model loaded


In [69]:
def run_single_episode(env, model, frame_skip, device, max_steps=None):
    """Play one greedy episode and break its reward down into event counts.

    The policy always picks the argmax action of ``model``. Each call to
    ``transition`` advances the environment and returns the reward summed over
    that call, so the event counters below are *estimated* from the summed
    reward of each window, not from individual frames.

    Args:
        env: Environment instance; it is reset at the start of the episode.
        model: Callable network mapping a batched state tensor to action values.
        frame_skip: Forwarded unchanged to ``transition``.
        device: Torch device the state tensor is moved to before inference.
        max_steps: Maximum number of ``transition`` calls. Defaults to the
            notebook-level ``MAX_STEPS`` when ``None``.

    Returns:
        dict with keys ``total_reward``, ``steps``, ``alive_frames``,
        ``pipes_passed``, ``deaths`` and ``top_hits``.
    """
    if max_steps is None:
        max_steps = MAX_STEPS  # notebook-level safety cap

    env.reset()

    # Take one initial transition with action 0 to obtain the first model
    # input; its reward and done flag are intentionally not recorded.
    current_state, _, _ = transition(0, env, frame_skip)

    done = False
    steps = 0
    total_reward = 0.0
    alive_frames = 0
    pipes_passed = 0
    deaths = 0
    top_hits = 0

    while not done and steps < max_steps:
        with torch.no_grad():
            state_t = torch.tensor(
                current_state, dtype=torch.float32
            ).unsqueeze(0).to(device)
            action = model(state_t).argmax(dim=1).item()

        current_state, reward_sum, done = transition(action, env, frame_skip)

        total_reward += reward_sum
        steps += 1

        # ---------- reward decomposition ----------
        # Assumes the reward scheme implied by the constants used here:
        # +0.1 per surviving frame, +1.0 per pipe, -1.0 on death, -0.5 on
        # touching the top of the screen -- TODO confirm against the env.
        if reward_sum > 0:
            pipes = int(reward_sum // 1.0)
            pipes_passed += pipes
            # Remove the pipe bonus before converting to frames; otherwise
            # every pipe would be counted as ten extra alive frames.
            # round() absorbs float noise such as 0.4 / 0.1 == 4.000000000000001.
            alive_frames += round((reward_sum - pipes) / 0.1)
        elif reward_sum <= -1.0:
            deaths += 1
        elif reward_sum <= -0.5:
            # Mutually exclusive with the death branch so a death is not
            # also reported as a top-of-screen hit.
            # NOTE(review): a window mixing survival reward with a death
            # penalty can still land in this bucket; exact attribution needs
            # per-frame rewards.
            top_hits += 1

    return {
        "total_reward": total_reward,
        "steps": steps,
        "alive_frames": alive_frames,
        "pipes_passed": pipes_passed,
        "deaths": deaths,
        "top_hits": top_hits,
    }


In [None]:
env = gym.make(ENV_NAME, render_mode="rgb_array")

# One list per metric name; index i holds the value for episode i.
metrics = defaultdict(list)

try:
    for episode in tqdm.tqdm(range(NUM_EPISODES)):
        episode_metrics = run_single_episode(
            env, policy_net, FRAME_SKIP, DEVICE
        )

        for k, v in episode_metrics.items():
            metrics[k].append(v)

        # tqdm.write keeps the per-episode line from corrupting the progress bar.
        tqdm.tqdm.write(
            f"Episode {episode+1:03d} | "
            f"Reward: {episode_metrics['total_reward']:.2f} | "
            f"Pipes: {episode_metrics['pipes_passed']}"
        )
finally:
    # Release the environment even if the run fails or is interrupted;
    # metrics collected so far stay available for the cells below.
    env.close()


  logger.warn(f"{pre} is not within the observation space.")
  logger.warn(f"{pre} is not within the observation space.")
  1%|          | 1/100 [00:05<09:35,  5.81s/it]

Episode 001 | Reward: 38.20 | Pipes: 14


  2%|▏         | 2/100 [00:21<18:46, 11.50s/it]

Episode 002 | Reward: 83.70 | Pipes: 27


  3%|▎         | 3/100 [00:25<13:10,  8.15s/it]

Episode 003 | Reward: 31.40 | Pipes: 10


  4%|▍         | 4/100 [00:35<13:56,  8.71s/it]

Episode 004 | Reward: 63.00 | Pipes: 14


  5%|▌         | 5/100 [00:38<10:37,  6.71s/it]

Episode 005 | Reward: 18.50 | Pipes: 2


  6%|▌         | 6/100 [00:44<10:25,  6.65s/it]

Episode 006 | Reward: 40.10 | Pipes: 13


  7%|▋         | 7/100 [00:48<08:47,  5.67s/it]

Episode 007 | Reward: 20.70 | Pipes: 8


  8%|▊         | 8/100 [00:54<08:51,  5.78s/it]

Episode 008 | Reward: 33.20 | Pipes: 13


  9%|▉         | 9/100 [01:09<13:02,  8.60s/it]

Episode 009 | Reward: 94.10 | Pipes: 34


 10%|█         | 10/100 [01:17<12:48,  8.54s/it]

Episode 010 | Reward: 51.30 | Pipes: 15


 11%|█         | 11/100 [01:23<11:26,  7.72s/it]

Episode 011 | Reward: 33.30 | Pipes: 11


 12%|█▏        | 12/100 [01:32<12:07,  8.27s/it]

Episode 012 | Reward: 66.60 | Pipes: 20


 13%|█▎        | 13/100 [01:46<14:14,  9.82s/it]

Episode 013 | Reward: 90.10 | Pipes: 25


 14%|█▍        | 14/100 [01:52<12:34,  8.78s/it]

Episode 014 | Reward: 37.70 | Pipes: 8


In [None]:
def summarize(values):
    """Return the mean, std, min and max of ``values`` as a dict.

    Keys are ``"mean"``, ``"std"``, ``"min"`` and ``"max"`` (in that order);
    each value is whatever the corresponding NumPy reduction returns.
    """
    reducers = (
        ("mean", np.mean),
        ("std", np.std),
        ("min", np.min),
        ("max", np.max),
    )
    return {label: reduce_fn(values) for label, reduce_fn in reducers}

# Summary statistics for every recorded metric, keyed by metric name.
summary = {name: summarize(metrics[name]) for name in metrics}



In [None]:
# Distribution of episode returns. The episode count in the title is taken
# from the data so it stays correct when NUM_EPISODES changes or the
# evaluation run stops early.
episode_rewards = metrics["total_reward"]

fig, ax = plt.subplots()
ax.hist(episode_rewards, bins=20)
ax.set_title(f"Total Reward Distribution ({len(episode_rewards)} Episodes)")
ax.set_xlabel("Total Reward")
ax.set_ylabel("Frequency")
plt.show()


In [None]:
# Pipes cleared in each evaluation episode, in play order.
fig, ax = plt.subplots()
ax.plot(metrics["pipes_passed"])
ax.set_title("Pipes Passed per Episode")
ax.set_xlabel("Episode")
ax.set_ylabel("Pipes Passed")
plt.show()


In [None]:
# Alive-frame estimate for each evaluation episode, in play order.
fig, ax = plt.subplots()
ax.plot(metrics["alive_frames"])
ax.set_title("Alive Frames per Episode")
ax.set_xlabel("Episode")
ax.set_ylabel("Alive Frames")
plt.show()


In [None]:
# Total failure events across all recorded episodes. The episode count in the
# title is taken from the data rather than hard-coded.
labels = ["Deaths", "Top Screen Hits"]
values = [
    sum(metrics["deaths"]),
    sum(metrics["top_hits"]),
]

fig, ax = plt.subplots()
ax.bar(labels, values)
ax.set_title(f"Failure Event Counts ({len(metrics['deaths'])} Episodes)")
plt.show()


In [None]:
import json
import pandas as pd
from pathlib import Path

# Directory that receives the exported metrics, relative to the working directory.
OUTPUT_DIR = Path("metrics")
# parents=True keeps this working if OUTPUT_DIR is later pointed at a nested path.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)


In [None]:
# One row per episode, one column per metric; the index is labelled "episode".
csv_path = OUTPUT_DIR / "episodes.csv"

df = pd.DataFrame(metrics).rename_axis("episode")
df.to_csv(csv_path)

print(f"✅ Episode metrics saved to {csv_path}")
