# Catan DQN Training Notebook
This notebook trains a Deep Q-Network (DQN) agent to play a simplified version of Settlers of Catan.

### Setup
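1. Upload the `Catan-AI` folder to Google Drive and set `DRIVE_PATH` in the first cell to its `code` subfolder; the cell copies that code into the Colab runtime and adds it to the Python path.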
2. This notebook runs in headless mode (no pygame window).

In [None]:
%tensorflow_version 2.x
import sys
import os
import shutil
from google.colab import drive

# Mount Google Drive
drive.mount('/content/drive')

# CONFIG: Change this to the path where you uploaded the 'Catan-AI' folder in your Drive
DRIVE_PATH = '/content/drive/MyDrive/Catan-AI/code'

# Copy code to local runtime for speed (reading from Drive is slow)
if not os.path.exists('code'):
    shutil.copytree(DRIVE_PATH, 'code')

sys.path.append(os.path.abspath('code'))

import random
import numpy as np
import collections
import tensorflow as tf
from tensorflow.keras.layers import Dense, Input, Flatten
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam

# Add code directory to path
if os.path.exists('code'):
    sys.path.append(os.path.abspath('code'))
else:
    # Assuming we are inside 'notebooks'
    sys.path.append(os.path.abspath('../code'))

from catan_dqn_env import CatanEnv

# Set seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)
random.seed(42)

## DQN Agent Implementation

In [None]:
class DQNAgent:
    """DQN agent with an online Q-network, a target network and a replay buffer.

    Invalid actions are excluded both when acting and when computing the
    bootstrapped target, using per-action masks of length ``action_size``
    (truthy entries = valid actions).
    """

    def __init__(self, state_size, action_size):
        """Build both networks and set the learning hyperparameters.

        Args:
            state_size: Length of the flat state vector fed to the network.
            action_size: Number of discrete actions (width of the Q-value output).
        """
        self.state_size = state_size
        self.action_size = action_size
        # Store (s, a, r, s', done, next_mask); the oldest entries drop out past 2000.
        self.memory = collections.deque(maxlen=2000)
        self.gamma = 0.95    # discount rate
        self.epsilon = 1.0   # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.learning_rate = 0.001
        self.model = self._build_model()
        self.target_model = self._build_model()
        self.update_target_model()

    def _build_model(self):
        """Return a compiled MLP mapping a state vector to one Q-value per action."""
        model = Sequential()
        model.add(Dense(256, input_dim=self.state_size, activation='relu'))
        model.add(Dense(128, activation='relu'))
        model.add(Dense(64, activation='relu'))
        model.add(Dense(self.action_size, activation='linear'))
        model.compile(loss='mse', optimizer=Adam(learning_rate=self.learning_rate))
        return model

    def update_target_model(self):
        """Copy the online network's weights into the target network."""
        self.target_model.set_weights(self.model.get_weights())

    def remember(self, state, action, reward, next_state, done, next_mask):
        """Append one transition (with the next state's action mask) to replay memory."""
        self.memory.append((state, action, reward, next_state, done, next_mask))

    @staticmethod
    def _mask_q_values(q_values, mask):
        """Return a copy of ``q_values`` with masked-out actions set to the dtype minimum.

        ``mask`` is coerced to bool first, so integer 0/1 masks work as well as
        boolean ones. Applying ``~`` to an integer array is a bitwise NOT
        (0 -> -1, 1 -> -2), and those values would then be used as fancy
        indices, masking the wrong actions.
        The input array is not modified.
        """
        q_values = np.asarray(q_values)
        mask = np.asarray(mask, dtype=bool)
        # Dtype-safe negative infinity (avoids float16 overflow under mixed precision).
        return np.where(mask, q_values, np.finfo(q_values.dtype).min)

    def act(self, state, action_mask):
        """Epsilon-greedy, but never selects masked-out (invalid) actions.

        Args:
            state: State vector; reshaped to a single-row batch for prediction.
            action_mask: Per-action validity mask (bool or 0/1 values).

        Returns:
            The chosen action index as an ``int``; 0 if no action is valid.
        """
        action_mask = np.asarray(action_mask, dtype=bool)
        valid_actions = np.flatnonzero(action_mask)
        if valid_actions.size == 0:
            return 0  # fallback to END_TURN

        if np.random.rand() <= self.epsilon:
            return int(np.random.choice(valid_actions))

        q = self.model.predict(state.reshape(1, -1), verbose=0)[0]
        return int(np.argmax(self._mask_q_values(q, action_mask)))

    def replay(self, batch_size):
        """Fit the online network on one random minibatch, then decay epsilon.

        Does nothing until memory holds at least ``batch_size`` transitions.
        Targets use the target network's maximum Q-value over the valid next
        actions; terminal transitions use the reward alone.
        """
        if len(self.memory) < batch_size:
            return

        minibatch = random.sample(self.memory, batch_size)
        states = np.array([t[0] for t in minibatch])
        actions = np.array([t[1] for t in minibatch], dtype=int)
        rewards = np.array([t[2] for t in minibatch], dtype=float)
        next_states = np.array([t[3] for t in minibatch])
        dones = np.array([t[4] for t in minibatch], dtype=bool)
        next_masks = np.array([t[5] for t in minibatch], dtype=bool)

        target = self.model.predict(states, verbose=0)
        target_next = self.target_model.predict(next_states, verbose=0)

        next_max = self._mask_q_values(target_next, next_masks).max(axis=1)
        # A row with no valid next action would otherwise bootstrap from the dtype
        # minimum (~-3.4e38 for float32) and blow up the loss; treat it as terminal.
        no_bootstrap = dones | ~next_masks.any(axis=1)
        bootstrap = np.where(no_bootstrap, 0.0, self.gamma * next_max)
        target[np.arange(batch_size), actions] = rewards + bootstrap

        self.model.fit(states, target, batch_size=batch_size, epochs=1, verbose=0)

        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def load(self, name):
        """Load online-network weights from ``name`` and sync the target network."""
        self.model.load_weights(name)
        # Otherwise the target network keeps its pre-load weights until the next
        # scheduled update, and bootstrapped targets come from a different model.
        self.update_target_model()

    def save(self, name):
        """Save the online network's weights to ``name``."""
        self.model.save_weights(name)

## Training Loop

In [None]:
env = CatanEnv()
state_size = env.observation_space.shape[0]
action_size = env.action_space.n
agent = DQNAgent(state_size, action_size)

# Parameters from pseudocode
EPISODES = 50
BATCH_SIZE = 32
TRAIN_EVERY_N_ACTIONS = 4  # "every n actions, sample k transitions and train"
UPDATE_TARGET_EVERY_M_EPISODES = 5  # "every m episodes, copy training network to target"
CHECKPOINT_EVERY_N_EPISODES = 10  # weights are also saved after the final episode

scores = []
total_actions = 0

for e in range(EPISODES):
    state = env.reset()
    score = 0
    done = False
    steps = 0

    while not done:
        # 1. Mask invalid actions
        mask = env.action_mask()

        # 2. Action (masked epsilon-greedy)
        action = agent.act(state, mask)

        # 3. Step
        next_state, reward, done, _ = env.step(action)

        # 4. Next mask (for masked target max). For terminal steps the mask is a
        # placeholder: the agent does not bootstrap from terminal transitions.
        next_mask = env.action_mask() if not done else np.ones(action_size, dtype=bool)

        # 5. Remember
        agent.remember(state, action, reward, next_state, done, next_mask)

        state = next_state
        score += reward
        steps += 1
        total_actions += 1

        # 6. Train every N actions
        if total_actions % TRAIN_EVERY_N_ACTIONS == 0 and len(agent.memory) > BATCH_SIZE:
            agent.replay(BATCH_SIZE)

    # 7. Update Target Network every M episodes
    if (e + 1) % UPDATE_TARGET_EVERY_M_EPISODES == 0:
        agent.update_target_model()
        print(f"Target network updated at episode {e+1}")

    print(f"episode: {e}/{EPISODES}, score: {score}, steps: {steps}, e: {agent.epsilon:.3}")
    scores.append(score)

    # 8. Checkpoint periodically AND after the last episode, so the final
    # episodes' training is not lost when EPISODES - 1 is not a multiple of N.
    if e % CHECKPOINT_EVERY_N_EPISODES == 0 or e == EPISODES - 1:
        agent.save(f"catan-dqn-{e}.weights.h5")

In [None]:
import matplotlib.pyplot as plt

# Learning curve: total reward collected in each training episode.
fig, ax = plt.subplots()
ax.plot(scores)
ax.set_ylabel('Score')
ax.set_xlabel('Episode')
plt.show()