# Yahtzee Deep Q-Learning Training

This notebook implements a Dueling DQN agent to play Yahtzee using prioritized experience replay.

## Setup and Configuration

In [None]:
# Install the Python packages listed in requirements.txt into the running kernel.
%pip install -r requirements.txt

In [None]:
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

# save_path = '/content/drive'
# checkpoint_path = '/content/drive'

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import random
import pickle
from tqdm import tqdm
import csv
import os

"""
COMPUTE DEVICE SETUP:
This whole program would not have been possible if it weren't for cloud-based GPUs.
More details in the READMEs, but I ran this on an NVIDIA L4 GPU via a Google Colab notebook.
"""
# Use CUDA when it is available, otherwise the CPU; the networks and batch tensors
# created later are moved onto this device with .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# Define random seeds for reproducibility
# All three RNGs are used below: torch (weight init, dropout), NumPy (replay-buffer
# sampling) and Python's `random` (dice rolls, epsilon-greedy choices).
# NOTE(review): some CUDA kernels are nondeterministic, so GPU runs may still differ.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

# Define Yahtzee categories
# The order is significant: CATEGORIES[:6] is treated as the upper section, and the
# list order fixes both the state-vector layout and the 'score' action indices.
CATEGORIES = [
    'ones', 'twos', 'threes', 'fours', 'fives', 'sixes',
    'three_of_kind', 'four_of_kind', 'full_house',
    'small_straight', 'large_straight', 'yahtzee', 'chance'
]

## Yahtzee Game Environment

In [None]:
# Create Yahtzee gameplay env
class YahtzeeGame:
    """Single-player Yahtzee environment.

    Each of the 13 turns consists of one to three rolls of five dice followed by
    scoring exactly one unused category from ``CATEGORIES``.  Scoring uses the
    standard category values plus the +100 bonus for extra Yahtzees and the +35
    upper-section bonus.  Joker rules for extra Yahtzees are not implemented.
    """

    # Face value counted by each upper-section category.
    _UPPER_FACES = {'ones': 1, 'twos': 2, 'threes': 3, 'fours': 4, 'fives': 5, 'sixes': 6}

    def __init__(self):
        # Start in a fresh game so the object is usable immediately.
        self.reset()

    def reset(self):
        """Start a new game and return the initial state vector from get_state()."""
        self.dice = np.zeros(5, dtype=int)  # 0 means "not rolled yet"
        self.scorecard = {cat: None for cat in CATEGORIES}  # None = category still open
        self.roll_count = 0  # rolls taken in the current turn
        self.turn = 0
        self.total_score = 0
        self.yahtzee_bonus_count = 0
        self.first_roll_of_turn = True
        self.game_log = []  # ('roll' | 'score' | 'bonus', detail, value) tuples for analysis
        return self.get_state()

    def roll_dice(self, keep_mask=None):
        """Re-roll every die whose ``keep_mask`` entry is falsy; return a copy of the dice.

        The 3-rolls-per-turn limit is enforced by get_valid_actions(), not here.
        """
        if keep_mask is None:
            keep_mask = [0, 0, 0, 0, 0]

        for i in range(5):
            if not keep_mask[i]:
                self.dice[i] = random.randint(1, 6)

        self.roll_count += 1
        self.first_roll_of_turn = False
        self.game_log.append(('roll', keep_mask, self.dice.copy()))
        return self.dice.copy()

    def calculate_score(self, category):
        """Return what the current dice would score in ``category`` (0 if unknown).

        Depends only on ``self.dice``: it does not check whether the category is
        still open and never includes the Yahtzee bonus.
        """
        dice = self.dice
        counts = np.bincount(dice, minlength=7)[1:]  # counts[k] = dice showing face k+1

        if category in self._UPPER_FACES:
            face = self._UPPER_FACES[category]
            return np.sum(dice == face) * face
        if category == 'three_of_kind':
            return np.sum(dice) if np.max(counts) >= 3 else 0
        if category == 'four_of_kind':
            return np.sum(dice) if np.max(counts) >= 4 else 0
        if category == 'full_house':
            return 25 if sorted(counts[counts > 0]) == [2, 3] else 0
        if category == 'small_straight':
            dice_set = set(dice)
            straights = [{1, 2, 3, 4}, {2, 3, 4, 5}, {3, 4, 5, 6}]
            return 30 if any(s.issubset(dice_set) for s in straights) else 0
        if category == 'large_straight':
            return 40 if set(dice) in [{1, 2, 3, 4, 5}, {2, 3, 4, 5, 6}] else 0
        if category == 'yahtzee':
            return 50 if np.max(counts) == 5 else 0
        if category == 'chance':
            return np.sum(dice)
        return 0

    def score_category(self, category):
        """Score the current dice in ``category`` and advance to the next turn.

        Returns:
            Points earned by this move, or -1 if the category is already used or
            the dice have not been rolled this turn.  Points earned are the
            category score plus 100 for a bonus Yahtzee (the Yahtzee box already
            holds a non-zero score).  The +35 upper-section bonus awarded at game
            end is added to ``total_score`` but is not part of the return value.
        """
        if self.scorecard[category] is not None:
            return -1  # Category already used
        if self.first_roll_of_turn:
            return -1  # Must roll before scoring

        score = self.calculate_score(category)

        # Yahtzee bonus: +100 for each additional Yahtzee.  It is credited to
        # total_score separately instead of being written into the category box,
        # so it cannot count toward the 63-point upper-section threshold or
        # distort the upper-section state features.
        bonus = 0
        is_yahtzee = np.max(np.bincount(self.dice, minlength=7)[1:]) == 5
        yahtzee_box = self.scorecard['yahtzee']
        if is_yahtzee and yahtzee_box is not None and yahtzee_box > 0:
            bonus = 100
            self.yahtzee_bonus_count += 1

        self.scorecard[category] = score
        self.total_score += score + bonus
        self.game_log.append(('score', category, score))
        if bonus:
            self.game_log.append(('bonus', 'yahtzee', bonus))

        # Upper section bonus at game end (+35 if the six upper boxes total 63+)
        if self.is_game_over():
            upper_score = sum(self.scorecard[cat] for cat in CATEGORIES[:6])
            if upper_score >= 63:
                self.total_score += 35
                self.game_log.append(('bonus', 'upper', 35))

        # Reset for next turn
        self.turn += 1
        self.roll_count = 0
        self.first_roll_of_turn = True

        return score + bonus

    def is_game_over(self):
        """True once every category on the scorecard has been filled."""
        return all(score is not None for score in self.scorecard.values())

    def get_state(self):
        """Return the current game state as a flat 68-element feature vector.

        Layout (offsets): dice one-hot [0:30], categories filled [30:43],
        potential scores [43:56], upper-section box scores [56:62], roll count [62],
        turn progress [63], upper-section progress [64], has-rolled flag [65],
        normalised total score [66], Yahtzee bonus count [67].
        """
        # Dice one-hot: die i showing face f sets slot i*6 + (f-1); unrolled dice
        # (value 0) leave their six slots at zero.
        dice_onehot = np.zeros(30)
        for i, die in enumerate(self.dice):
            if die > 0:
                dice_onehot[i * 6 + die - 1] = 1

        # 1.0 for every category that has already been used.
        scorecard_filled = np.array([1.0 if self.scorecard[cat] is not None else 0.0 for cat in CATEGORIES])

        # What the current dice would score in each still-open category, divided
        # by 50 (the Yahtzee score).  Example: dice [2, 2, 5, 5, 6] on an empty card
        # give raw scores [0, 4, 0, 0, 10, 6, 0, 0, 0, 0, 0, 0, 20], i.e.
        # [0, 0.08, 0, 0, 0.2, 0.12, 0, 0, 0, 0, 0, 0, 0.4] after scaling.  This
        # tells the network which categories the current roll is good for.
        # Used categories, and every category before the first roll, stay at 0.
        potential_scores = np.zeros(13)
        if not self.first_roll_of_turn:
            for i, cat in enumerate(CATEGORIES):
                if self.scorecard[cat] is None:
                    potential_scores[i] = self.calculate_score(cat) / 50.0

        # Upper-section box scores, scaled by 1/18 (critical for the 63-point bonus strategy).
        upper_scores = np.array([
            (self.scorecard[cat] if self.scorecard[cat] is not None else 0) / 18.0
            for cat in CATEGORIES[:6]
        ])

        # Turn and roll info
        roll_count = np.array([self.roll_count / 3.0])
        turn_progress = np.array([self.turn / 13.0])
        upper_sum = sum(self.scorecard[cat] if self.scorecard[cat] is not None else 0
                        for cat in CATEGORIES[:6])
        upper_progress = np.array([min(upper_sum / 63.0, 1.5)])  # may exceed 1.0, capped at 1.5
        has_rolled = np.array([0.0 if self.first_roll_of_turn else 1.0])
        score_normalized = np.array([self.total_score / 400.0])
        yahtzee_bonuses = np.array([self.yahtzee_bonus_count / 3.0])

        # Return the features the model is going to evaluate
        return np.concatenate([
            dice_onehot,           # 30
            scorecard_filled,      # 13
            potential_scores,      # 13
            upper_scores,          # 6
            roll_count,            # 1
            turn_progress,         # 1
            upper_progress,        # 1
            has_rolled,            # 1
            score_normalized,      # 1
            yahtzee_bonuses        # 1
        ])

    def get_valid_actions(self):
        """List the actions that are legal in the current state.

        Before the first roll of a turn only ('roll', (0, 0, 0, 0, 0)) is allowed.
        After that, all 32 keep-masks are offered while fewer than 3 rolls have
        been made, plus ('score', cat) for every unused category.
        """
        actions = []

        # Must roll at start of turn
        if self.first_roll_of_turn:
            actions.append(('roll', tuple([0, 0, 0, 0, 0])))
            return actions

        # Can roll if under 3 rolls this turn; bit j of i set means "keep die j"
        if self.roll_count < 3:
            for i in range(32):
                keep_mask = tuple((i >> j) & 1 for j in range(5))
                actions.append(('roll', keep_mask))

        # Can score in any available category
        for cat in CATEGORIES:
            if self.scorecard[cat] is None:
                actions.append(('score', cat))

        return actions

## Neural Network Architecture

In [None]:
"""
This is where most of the magic happens. An explanation of Dueling DQN is in the READMEs
"""
class DuelingDQN(nn.Module):
    """Dueling Q-network: a shared trunk feeding separate value and advantage heads.

    Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a')).

    Args:
        state_size: length of the input feature vector.
        action_size: number of discrete actions (one Q-value each).
        hidden_size: width of the shared trunk; each head's hidden layer is half of it.
    """

    def __init__(self, state_size, action_size, hidden_size=512):
        super().__init__()
        half = hidden_size // 2

        # Shared trunk: two Linear -> ReLU -> LayerNorm -> Dropout stages.
        # Module order (and therefore state_dict keys such as feature.0 / feature.4)
        # must stay fixed so saved checkpoints keep loading.
        trunk = []
        for fan_in in (state_size, hidden_size):
            trunk.extend([
                nn.Linear(fan_in, hidden_size),
                nn.ReLU(),
                nn.LayerNorm(hidden_size),
                nn.Dropout(0.1),
            ])
        self.feature = nn.Sequential(*trunk)

        # Value head: V(s), a single scalar per state.
        self.value = nn.Sequential(nn.Linear(hidden_size, half), nn.ReLU(), nn.Linear(half, 1))

        # Advantage head: A(s, a), one value per action.  Splitting state value from
        # action advantage is what distinguishes the dueling design from plain DQN.
        self.advantage = nn.Sequential(nn.Linear(hidden_size, half), nn.ReLU(), nn.Linear(half, action_size))

    def forward(self, x):
        """Map a (batch, state_size) input to (batch, action_size) Q-values."""
        shared = self.feature(x)
        state_value = self.value(shared)
        adv = self.advantage(shared)
        # Centre the advantages so V and A are identifiable.
        centred = adv - adv.mean(dim=1, keepdim=True)
        return state_value + centred

## Prioritized Experience Replay

In [None]:
class PrioritizedReplayBuffer:
    """Proportional prioritized experience replay stored in a ring buffer.

    Transitions are drawn with probability P(i) = p_i**alpha / sum_k p_k**alpha and
    returned with importance-sampling weights (N * P(i))**(-beta), scaled so the
    largest weight in a batch is 1.  Once ``capacity`` transitions are stored, new
    ones overwrite the oldest slot.
    """

    # alpha sets how strongly sampling favours high-TD-error transitions
    # (0 = uniform sampling, 1 = fully proportional to priority).
    def __init__(self, capacity=300000, alpha=0.7):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        # Priority of buffer[i] lives in priorities[i].  A preallocated float64 array
        # keeps max() in add() and the probability maths in sample() vectorised,
        # instead of scanning a Python list of up to `capacity` floats per call.
        self.priorities = np.zeros(capacity, dtype=np.float64)
        self.position = 0  # next slot to write

    def add(self, state, action, reward, next_state, done):
        """Store one transition at the current maximum priority (1.0 when empty)."""
        size = len(self.buffer)
        max_priority = float(self.priorities[:size].max()) if size else 1.0

        experience = (state, action, reward, next_state, done)

        # Until full, position == size, so append; afterwards overwrite the oldest slot.
        if size < self.capacity:
            self.buffer.append(experience)
        else:
            self.buffer[self.position] = experience
        self.priorities[self.position] = max_priority

        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size, beta=0.4):
        """Draw a prioritized batch without replacement.

        Args:
            batch_size: requested number of transitions; clamped to len(self) so a
                small buffer yields a smaller batch instead of raising.
            beta: importance-sampling exponent (0 = no correction, 1 = full correction).

        Returns:
            ``(states, actions, rewards, next_states, dones, weights, indices)`` with
            the first six as tensors on ``device`` and ``indices`` as a NumPy array
            for update_priorities(), or None if the buffer is empty.
        """
        size = len(self.buffer)
        if size == 0:
            return None
        batch_size = min(batch_size, size)

        # Meat and potatoes of PER: P(i) = p_i^alpha / sum_k p_k^alpha
        probs = self.priorities[:size] ** self.alpha
        probs /= probs.sum()

        indices = np.random.choice(size, batch_size, p=probs, replace=False)

        # Importance-sampling weights, normalised so the largest is 1
        weights = (size * probs[indices]) ** (-beta)
        weights /= weights.max()

        states, actions, rewards, next_states, dones = zip(*(self.buffer[idx] for idx in indices))

        # Stack each field into one NumPy array before converting: building a tensor
        # from a Python list of ndarrays is extremely slow.
        def to_device(values, dtype):
            return torch.as_tensor(np.asarray(values), dtype=dtype).to(device)

        return (
            to_device(states, torch.float32),
            to_device(actions, torch.int64),
            to_device(rewards, torch.float32),
            to_device(next_states, torch.float32),
            to_device(dones, torch.float32),
            to_device(weights, torch.float32),
            indices,
        )

    def update_priorities(self, indices, priorities):
        """Overwrite the priorities of the given slots (e.g. with new |TD errors|)."""
        for idx, priority in zip(indices, priorities):
            self.priorities[idx] = priority

    def __len__(self):
        return len(self.buffer)

## DQN Agent

In [None]:
class YahtzeeAgent:
    """Double-DQN Yahtzee agent with a dueling network and prioritized replay.

    Actions are addressed by their index in ``action_map``.  Indices 0-31 are
    ('roll', keep_mask), where bit j of the index set means "keep die j".
    Indices 32-44 are ('score', category) in ``CATEGORIES`` order.
    """

    def __init__(self, state_size, action_size):
        self.state_size = state_size
        self.action_size = action_size

        # Online network (trained) and target network (periodically synced copy)
        self.q_network = DuelingDQN(state_size, action_size, hidden_size=512).to(device)
        self.target_network = DuelingDQN(state_size, action_size, hidden_size=512).to(device)
        self.target_network.load_state_dict(self.q_network.state_dict())
        # The target network is only used for inference: keep its dropout off so
        # bootstrap targets are not perturbed by dropout noise.
        self.target_network.eval()

        # Adam with step-wise LR decay.  The training loops call scheduler.step()
        # once per episode, so the milestones are episode counts.
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=0.0001)
        self.scheduler = optim.lr_scheduler.MultiStepLR(
            self.optimizer,
            milestones=[10000, 20000, 30000],
            gamma=0.5  # halve the LR at each milestone
        )

        # Replay buffer
        self.memory = PrioritizedReplayBuffer(capacity=300000, alpha=0.6)

        # Hyperparameters
        self.epsilon = 1.0  # initial rate of random exploration
        self.epsilon_min = 0.01  # exploration floor
        self.epsilon_decay = 0.9999  # applied once per completed train_step()
        self.gamma = 0.99
        self.batch_size = 256
        self.update_target_every = 500  # gradient steps between target-network syncs
        self.train_start = 10000  # transitions required before learning starts
        self.steps = 0  # gradient steps taken so far

        # Action mapping (index -> action tuple, plus the reverse lookup)
        self.action_map = self._create_action_map()
        self._action_index = {action: idx for idx, action in enumerate(self.action_map)}

        # Training metrics (one entry per gradient step)
        self.training_metrics = {
            'q_values': [],
            'td_errors': [],
            'gradient_norms': []
        }

    def _create_action_map(self):
        """Return the fixed list of all actions: 32 keep-masks, then one per category."""
        action_map = [('roll', tuple((i >> j) & 1 for j in range(5))) for i in range(32)]
        action_map.extend(('score', cat) for cat in CATEGORIES)
        return action_map

    def get_action_index(self, action):
        """Return the index of ``action`` in ``action_map``.

        Unrecognised or unhashable actions map to index 0 ("re-roll all dice"),
        a fallback kept for backward compatibility.
        """
        try:
            return self._action_index[action]
        except (KeyError, TypeError):
            return 0

    @staticmethod
    def _predict(network, inputs):
        """Run ``network`` on ``inputs`` without gradients and with dropout disabled.

        The network's previous train/eval mode is restored afterwards.
        """
        was_training = network.training
        network.eval()
        try:
            with torch.no_grad():
                return network(inputs)
        finally:
            network.train(was_training)

    def select_action(self, state, valid_actions, training=True):
        """Pick an action from ``valid_actions`` for ``state``.

        When ``training``, a uniformly random valid action is returned with
        probability epsilon.  Otherwise the valid action with the highest Q-value
        is returned, evaluated with dropout disabled.
        """
        if training and random.random() < self.epsilon:
            return random.choice(valid_actions)

        state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
        q_values = self._predict(self.q_network, state_tensor).cpu().numpy()[0]

        # Mask invalid actions so argmax can only land on a valid one
        valid_indices = [self.get_action_index(a) for a in valid_actions]
        masked_q = np.full(self.action_size, -1e10)
        masked_q[valid_indices] = q_values[valid_indices]

        best_action_idx = int(np.argmax(masked_q))
        return self.action_map[best_action_idx]

    def train_step(self, beta=0.4):
        """Run one prioritized-replay gradient step and return its loss.

        Returns 0.0 without training until the buffer holds ``train_start``
        transitions.  Otherwise it also refreshes the sampled priorities, syncs the
        target network every ``update_target_every`` steps and decays epsilon.
        """
        if len(self.memory) < self.train_start:
            return 0.0

        batch = self.memory.sample(self.batch_size, beta)
        if batch is None:
            return 0.0

        states, actions, rewards, next_states, dones, weights, indices = batch

        # Q(s, a) for the actions actually taken; squeeze(1) keeps a 1-D batch even
        # when the batch holds a single transition.
        current_q = self.q_network(states).gather(1, actions.unsqueeze(1)).squeeze(1)

        # Double DQN target to limit overestimation: the online network picks the
        # next action and the target network evaluates it.
        # NOTE(review): the argmax ranges over all actions, including ones that are
        # invalid in next_state (filled categories, a 4th roll); masking them would
        # require storing each next state's valid-action set in the replay buffer.
        next_actions = self._predict(self.q_network, next_states).argmax(dim=1)
        next_q = self._predict(self.target_network, next_states).gather(
            1, next_actions.unsqueeze(1)).squeeze(1)
        target_q = rewards + (1 - dones) * self.gamma * next_q

        # TD errors for priority updates
        td_errors = torch.abs(current_q - target_q).detach().cpu().numpy()

        # Importance-weighted Huber loss (stabilizes training)
        loss = (weights * nn.SmoothL1Loss(reduction='none')(current_q, target_q)).mean()

        # Optimize with gradient-norm clipping
        self.optimizer.zero_grad()
        loss.backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(self.q_network.parameters(), 1.0)
        self.optimizer.step()

        # Update priorities (small epsilon keeps every priority strictly positive)
        self.memory.update_priorities(indices, td_errors + 1e-6)

        # Update target network
        self.steps += 1
        if self.steps % self.update_target_every == 0:
            self.target_network.load_state_dict(self.q_network.state_dict())

        # Decay epsilon
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

        # Track metrics as plain Python floats so checkpoints hold no NumPy objects
        self.training_metrics['q_values'].append(float(current_q.mean().item()))
        self.training_metrics['td_errors'].append(float(td_errors.mean()))
        self.training_metrics['gradient_norms'].append(grad_norm.item())

        return loss.item()

    def save(self, filepath):
        """Save networks, optimizer, scheduler and exploration state to ``filepath``."""
        torch.save({
            'q_network': self.q_network.state_dict(),
            'target_network': self.target_network.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'scheduler': self.scheduler.state_dict(),
            'epsilon': self.epsilon,
            'steps': self.steps,
            'training_metrics': self.training_metrics
        }, filepath)
        print(f"Model saved to {filepath}")

    def load(self, filepath):
        """Restore state written by save() (or a training checkpoint) from ``filepath``."""
        # weights_only=False: older checkpoints store NumPy scalars in
        # training_metrics, which the weights_only loader rejects.  This unpickles
        # arbitrary objects, so only load files you created yourself.
        checkpoint = torch.load(filepath, map_location=device, weights_only=False)
        self.q_network.load_state_dict(checkpoint['q_network'])
        self.target_network.load_state_dict(checkpoint['target_network'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        self.epsilon = checkpoint['epsilon']
        self.steps = checkpoint['steps']
        self.training_metrics = checkpoint.get('training_metrics', self.training_metrics)
        print(f"Model loaded from {filepath}")

## Reward Function

In [None]:
"""This part is incredibly important - the reward function defines how the agent learns!
I didn't know exactly how to do it at first, so I iterated by running it about 20% of the way several times.
However, I eventually just asked ChatGPT to rewrite it to be accurate to the probability of Yahtzee and to
reward the agent accordingly. The result is a much better reward function that leads to far superior play.
"""
# Approximate expected score per category under good play.  Scoring actions are
# rewarded for beating (or penalised for falling short of) these baselines.
_CATEGORY_EV = {
    "ones": 2.8,
    "twos": 5.6,
    "threes": 8.5,
    "fours": 11.3,
    "fives": 14.0,
    "sixes": 16.8,
    "three_of_kind": 12,
    "four_of_kind": 7,
    "full_house": 9.2,
    "small_straight": 10,
    "large_straight": 7.5,
    "yahtzee": 2.5,
    "chance": 23,
}

# Small extra reward for landing the rarer combinations (only when they score > 0).
_RARITY_BONUS = {
    "yahtzee": 1.5,
    "large_straight": 0.9,
    "small_straight": 0.4,
    "full_house": 0.15,
}

# End-of-game shaping on the final total: the first threshold reached pays its
# bonus; finishing below 150 costs 5.
_TERMINAL_BONUS = ((300, 20.0), (250, 12.0), (200, 6.0), (150, 2.0))


def calculate_reward(game, action_type, score, prev_score, category_used):
    """Return the shaped reward for a single action.

    Args:
        game: the YahtzeeGame after the action has been applied.
        action_type: 'roll' or 'score'.
        score: value returned by ``game.score_category`` (-1 for an invalid move);
            ignored for rolls.
        prev_score: ``game.total_score`` before the action.  Currently unused, kept
            for call-site compatibility.
        category_used: category name for 'score' actions, None for rolls.

    Returns:
        Reward as a float.  Rolls get 0, invalid moves -10.  A scoring move gets:
        (score - category EV) / 10, plus 3 * box score / 63 for upper-section
        categories, a fixed penalty for scoring 0, and a rarity bonus.  The final
        move also gets a terminal bonus or penalty based on the game total.
    """
    # ---- 1. Rolling is neutral so it doesn't bias the value function ----
    if action_type == 'roll':
        return 0.0

    # ---- 2. Invalid move: small but discouraging penalty ----
    if score < 0:
        return -10.0

    # Improvement over the category's expected value, scaled by 1/10 so per-step
    # rewards stay in a consistent range.
    reward = (score - _CATEGORY_EV.get(category_used, 0)) / 10.0

    upper_categories = CATEGORIES[:6]
    if category_used in upper_categories:
        # Progress toward the 63-point upper-section bonus target: the fraction of
        # 63 this move added to the upper section (the box score, which excludes
        # any Yahtzee bonus kept outside the box), weighted strongly.
        box_score = game.scorecard[category_used]
        reward += (box_score / 63.0) * 3.0

    # Fixed penalties for scoring zero (independent of how far the game has progressed)
    if score == 0:
        if category_used in ("yahtzee", "ones"):
            reward -= 0.1  # common, cheap places to dump a bad roll
        elif category_used == "chance":
            reward -= 2.0  # wastes the most flexible box
        elif category_used in upper_categories:
            reward -= 1.5  # hurts the upper-section bonus
        else:
            reward -= 0.5

    if category_used in _RARITY_BONUS and score > 0:
        reward += _RARITY_BONUS[category_used]

    if game.is_game_over():
        total = game.total_score
        for threshold, bonus in _TERMINAL_BONUS:
            if total >= threshold:
                reward += bonus
                break
        else:
            reward -= 5.0

    return reward

## Training and Testing Functions

In [None]:
def play_episode(game, agent, training=True, verbose=False):
    """Play one complete Yahtzee game with ``agent``.

    Args:
        game: YahtzeeGame; it is reset at the start of the episode.
        agent: YahtzeeAgent choosing the actions (epsilon-greedy when ``training``).
        training: when True, exploration is enabled and every transition is pushed
            into ``agent.memory``.
        verbose: when True, print every action with its outcome and reward, then
            the final result.

    Returns:
        (final total score, sum of shaped rewards, number of actions taken)
    """
    max_steps = 250  # safety net only: a full game needs at most 13 * (3 + 1) = 52 actions

    state = game.reset()
    total_reward = 0
    steps = 0

    while not game.is_game_over() and steps < max_steps:
        valid_actions = game.get_valid_actions()
        if not valid_actions:
            break

        action = agent.select_action(state, valid_actions, training)
        action_type, action_value = action

        prev_score = game.total_score

        if action_type == 'roll':
            game.roll_dice(keep_mask=action_value)
            score, done, category_used = 0, False, None
        else:
            score = game.score_category(action_value)
            category_used = action_value
            # An invalid scoring move (-1) also ends the episode.
            done = bool(score < 0 or game.is_game_over())

        next_state = game.get_state()
        reward = calculate_reward(game, action_type, score, prev_score, category_used)

        if training:
            agent.memory.add(state, agent.get_action_index(action), reward, next_state, done)

        if verbose:
            if action_type == 'roll':
                print(f"  roll keep={action_value} -> dice={game.dice.tolist()}  reward={reward:.2f}")
            else:
                print(f"  score {action_value}: {score}  reward={reward:.2f}")

        state = next_state
        total_reward += reward
        steps += 1

        if done:
            break

    if verbose:
        print(f"Game over: score={game.total_score}  total_reward={total_reward:.2f}  steps={steps}")

    return game.total_score, total_reward, steps


def _write_training_log(log_rows, csv_path):
    """Rewrite ``csv_path`` with one row per logged episode (skipped when empty)."""
    if not log_rows:
        return  # DictWriter needs a first row to infer the header
    with open(csv_path, 'w', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=list(log_rows[0].keys()))
        writer.writeheader()
        writer.writerows(log_rows)


def _training_checkpoint(agent, episode, log_rows):
    """Collect everything needed to resume training after ``episode`` episodes."""
    return {
        'q_network': agent.q_network.state_dict(),
        'target_network': agent.target_network.state_dict(),
        'optimizer': agent.optimizer.state_dict(),
        'scheduler': agent.scheduler.state_dict(),
        'epsilon': agent.epsilon,
        'steps': agent.steps,
        'training_metrics': agent.training_metrics,
        'episode': episode,
        'log_rows': log_rows,
    }


def train_agent(episodes=50000, save_path='yahtzee_model.pth', checkpoint_interval=1000):
    """Train a YahtzeeAgent for ``episodes`` episodes and return it.

    If ``checkpoint.pth`` exists in the working directory, training resumes from it.
    Every ``checkpoint_interval`` episodes the checkpoint and ``training_log.csv`` are
    rewritten; the final model is saved to ``save_path``.
    """
    checkpoint_path = 'checkpoint.pth'
    csv_path = "training_log.csv"

    game = YahtzeeGame()
    state_size = len(game.get_state())
    action_size = 32 + 13  # 32 keep-masks + 13 scoring categories

    agent = YahtzeeAgent(state_size, action_size)
    start_episode = 0
    log_rows = []

    # Try to resume from an existing checkpoint
    if os.path.exists(checkpoint_path):
        try:
            # map_location lets a GPU checkpoint resume on CPU (and vice versa).
            # weights_only=False because the checkpoint stores NumPy scalars in
            # log_rows / training_metrics, which the weights_only loader rejects;
            # this unpickles arbitrary objects, so only load checkpoints you created.
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            agent.q_network.load_state_dict(checkpoint['q_network'])
            agent.target_network.load_state_dict(checkpoint['target_network'])
            agent.optimizer.load_state_dict(checkpoint['optimizer'])
            agent.scheduler.load_state_dict(checkpoint['scheduler'])
            agent.epsilon = checkpoint['epsilon']
            agent.steps = checkpoint['steps']
            agent.training_metrics = checkpoint.get('training_metrics', agent.training_metrics)
            start_episode = checkpoint.get('episode', 0)
            log_rows = list(checkpoint.get('log_rows', []))
            print(f"Resumed from episode {start_episode}")
        except Exception as e:
            print(f"Could not load checkpoint: {e}")
            print("Starting from scratch...")
            # A failure part-way through may have loaded some state but not the
            # rest; rebuild so "from scratch" really means a fresh agent.
            agent = YahtzeeAgent(state_size, action_size)
            start_episode = 0
            log_rows = []

    recent_scores = deque(maxlen=100)
    losses = []

    print("\nStarting training...\n")

    for episode in range(start_episode, episodes):

        score, reward, steps_taken = play_episode(game, agent, training=True)
        recent_scores.append(score)

        # Anneal the importance-sampling exponent from 0.4 toward 1.0 over training
        beta = min(1.0, 0.4 + 0.6 * episode / episodes)
        # One gradient step per episode (0.0 while the replay buffer warms up)
        loss = agent.train_step(beta)
        if loss > 0:
            losses.append(loss)

        # LR milestones are expressed in episodes, so step the scheduler per episode
        agent.scheduler.step()

        lr = agent.optimizer.param_groups[0]['lr']

        if (episode + 1) % 100 == 0:
            avg_score = np.mean(recent_scores)
            recent_losses = losses[-500:]
            avg_loss = np.mean(recent_losses) if recent_losses else 0

            print(f"[Episode {episode+1}] "
                  f"AvgScore={avg_score:.1f}  "
                  f"Loss={avg_loss:.4f}  "
                  f"Eps={agent.epsilon:.4f}  "
                  f"LR={lr:.6f}  "
                  f"Steps={agent.steps}")

        log_rows.append({
            'episode': episode + 1,
            'score': score,
            'reward': reward,
            'steps': steps_taken,
            'epsilon': agent.epsilon,
            'learning_rate': lr,
            'loss': float(loss if loss > 0 else 0.0)
        })

        # Save checkpoint (and the CSV log) periodically
        if (episode + 1) % checkpoint_interval == 0:
            torch.save(_training_checkpoint(agent, episode + 1, log_rows), checkpoint_path)
            _write_training_log(log_rows, csv_path)
            print(f"Checkpoint saved at episode {episode+1}")

    # Save final model and CSV
    agent.save(save_path)
    _write_training_log(log_rows, csv_path)

    print(f"\n✓ Training complete. CSV saved to {csv_path}")
    print(f"✓ Model saved to {save_path}\n")

    return agent


def test_agent(agent, num_games=100, verbose=False):
    """Play ``num_games`` greedy evaluation games and print summary statistics.

    Exploration is disabled (epsilon = 0, training=False).  The Q-network is put in
    eval mode so dropout does not perturb action choice.  Both settings are
    restored afterwards, even if a game raises.

    Returns:
        List of final scores, one per game (empty if ``num_games`` <= 0).
    """
    if num_games <= 0:
        print("No test games requested.")
        return []

    game = YahtzeeGame()
    original_epsilon = agent.epsilon
    was_training = agent.q_network.training
    agent.epsilon = 0
    agent.q_network.eval()

    scores = []
    try:
        for _ in range(num_games):
            score, _, _ = play_episode(game, agent, training=False, verbose=verbose)
            scores.append(score)
    finally:
        agent.epsilon = original_epsilon
        agent.q_network.train(was_training)

    # Summary statistics
    avg = np.mean(scores)
    std = np.std(scores)
    med = np.median(scores)

    print(f"\n=== TEST RESULTS ({num_games} games) ===")
    print(f"Average: {avg:.1f}")
    print(f"Std Dev: {std:.1f}")
    print(f"Median: {med:.1f}")
    print(f"Min: {min(scores)}")
    print(f"Max: {max(scores)}")
    print(f">=200: {sum(s >= 200 for s in scores)}/{num_games}")
    print(f">=250: {sum(s >= 250 for s in scores)}/{num_games}")

    return scores

## Train the Agent

In [None]:
# Train for 50,000 episodes; train_agent resumes from checkpoint.pth when one exists
agent = train_agent(save_path='yahtzee_model.pth', episodes=50000)

In [None]:
def resume_training(checkpoint_path, csv_path, total_episodes=50000, checkpoint_interval=1000,
                    final_model_path='yahtzee_model_final.pth'):
    """
    Resume training from a saved checkpoint.

    Args:
        checkpoint_path: Path to the .pth checkpoint file (read at start and
            overwritten every ``checkpoint_interval`` episodes)
        csv_path: Path to the training_log.csv file (read if it exists, rewritten at
            every checkpoint and at the end)
        total_episodes: Total number of episodes to train to (default 50000)
        checkpoint_interval: How often to save checkpoints (default 1000)
        final_model_path: Where to save the final model (default 'yahtzee_model_final.pth')

    Returns:
        The trained YahtzeeAgent.
    """
    import pandas as pd

    def write_log():
        # Rewrite the whole CSV; skipped when there is nothing to write, because
        # DictWriter needs a first row to infer the header.
        if not log_rows:
            return
        with open(csv_path, 'w', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=list(log_rows[0].keys()))
            writer.writeheader()
            writer.writerows(log_rows)

    # Initialize game and agent
    game = YahtzeeGame()
    state_size = len(game.get_state())
    action_size = 32 + 13  # 32 keep-masks + 13 scoring categories
    agent = YahtzeeAgent(state_size, action_size)

    # weights_only=False: the checkpoint also stores NumPy objects (log rows,
    # training metrics).  This unpickles arbitrary objects, so only load
    # checkpoints you created yourself.
    print(f"Loading checkpoint from {checkpoint_path}...")
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Load agent state
    agent.q_network.load_state_dict(checkpoint['q_network'])
    agent.target_network.load_state_dict(checkpoint['target_network'])
    agent.optimizer.load_state_dict(checkpoint['optimizer'])
    agent.scheduler.load_state_dict(checkpoint['scheduler'])
    agent.epsilon = checkpoint['epsilon']
    agent.steps = checkpoint['steps']
    agent.training_metrics = checkpoint.get('training_metrics', agent.training_metrics)

    # Load training progress
    start_episode = checkpoint.get('episode', 0)

    # Load existing log rows, preferring the CSV over the copy inside the checkpoint
    if os.path.exists(csv_path):
        log_rows = pd.read_csv(csv_path).to_dict('records')
        print(f"Loaded {len(log_rows)} existing log entries from CSV")
    else:
        log_rows = list(checkpoint.get('log_rows', []))
        print(f"Loaded {len(log_rows)} log entries from checkpoint")

    # Drop entries logged after this checkpoint was taken (e.g. a final CSV written
    # past the last checkpoint) so replayed episodes are not recorded twice.
    log_rows = [row for row in log_rows if row['episode'] <= start_episode]

    print(f"Resuming training from episode {start_episode}")
    print(f"Current epsilon: {agent.epsilon:.4f}")
    print(f"Current steps: {agent.steps}")
    print(f"Episodes remaining: {total_episodes - start_episode}\n")

    recent_scores = deque(maxlen=100)
    losses = []

    # Continue training
    for episode in range(start_episode, total_episodes):

        score, reward, steps_taken = play_episode(game, agent, training=True)
        recent_scores.append(score)

        # Anneal the importance-sampling exponent toward 1.0 over the full run
        beta = min(1.0, 0.4 + 0.6 * episode / total_episodes)
        loss = agent.train_step(beta)
        if loss > 0:
            losses.append(loss)

        # LR milestones are expressed in episodes
        agent.scheduler.step()

        lr = agent.optimizer.param_groups[0]['lr']

        # Progress updates
        if (episode + 1) % 100 == 0:
            avg_score = np.mean(recent_scores)
            recent_losses = losses[-500:]
            avg_loss = np.mean(recent_losses) if recent_losses else 0

            print(f"[Episode {episode+1}] "
                  f"AvgScore={avg_score:.1f}  "
                  f"Loss={avg_loss:.4f}  "
                  f"Eps={agent.epsilon:.4f}  "
                  f"LR={lr:.6f}  "
                  f"Steps={agent.steps}")

        # Log this episode
        log_rows.append({
            'episode': episode + 1,
            'score': score,
            'reward': reward,
            'steps': steps_taken,
            'epsilon': agent.epsilon,
            'learning_rate': lr,
            'loss': float(loss if loss > 0 else 0.0)
        })

        # Save checkpoint (and CSV) periodically
        if (episode + 1) % checkpoint_interval == 0:
            checkpoint = {
                'q_network': agent.q_network.state_dict(),
                'target_network': agent.target_network.state_dict(),
                'optimizer': agent.optimizer.state_dict(),
                'scheduler': agent.scheduler.state_dict(),
                'epsilon': agent.epsilon,
                'steps': agent.steps,
                'training_metrics': agent.training_metrics,
                'episode': episode + 1,
                'log_rows': log_rows
            }
            torch.save(checkpoint, checkpoint_path)
            write_log()
            print(f"✓ Checkpoint saved at episode {episode+1}")

    # Save final model and CSV
    agent.save(final_model_path)
    write_log()

    print(f"\n✓ Training complete!")
    print(f"✓ CSV saved to {csv_path}")
    print(f"✓ Final model saved to {final_model_path}\n")

    return agent


# Usage example: resume from the latest checkpoint written by train_agent()
# (checkpoint.pth and training_log.csv in the working directory) and continue
# to 50000 episodes.  With Google Drive mounted, point these at the Drive copies
# instead, e.g. '/content/drive/MyDrive/Yathzee/checkpoint.pth'.
agent = resume_training(
    checkpoint_path='checkpoint.pth',
    csv_path='training_log.csv',
    total_episodes=50000,
    checkpoint_interval=1000
)

## Test the Agent

In [None]:
# Evaluate the trained agent greedily (no exploration) over 100 games
test_scores = test_agent(agent, num_games=100, verbose=False)

# Write one (game index, score) row per test game for external plotting
with open("test_scores.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["game", "score"])
    writer.writerows(enumerate(test_scores))