
# Blackjack AI Training & Evaluation

This notebook trains and evaluates Blackjack reinforcement learning agents. It is designed to run top-to-bottom in Google Colab or Jupyter.

**Workflow overview**
1. Run the setup cell to check for the required dependencies (installing any that are missing) and configure the notebook.
2. Execute the smoke tests to validate the Blackjack environment and utilities.
3. Train the DQN agent (vectorised env) and track learning curves.
4. Evaluate the trained policy against a basic-strategy counting baseline, generating plots and CSV artifacts.

Artifacts are saved under `./outputs` with subfolders for models, metrics, and plots.


In [None]:

import importlib
import subprocess
import sys
from pathlib import Path

# Maps the importable module name to the name handed to ``pip install``.
# The two are kept separate because they are allowed to differ.
REQUIRED_PACKAGES = {
    "torch": "torch",
    "matplotlib": "matplotlib",
    "pandas": "pandas",
    "tqdm": "tqdm",
}

# Try to import every required module; install the ones that are missing
# using the same interpreter that is running this notebook.
for module_name, install_name in REQUIRED_PACKAGES.items():
    try:
        importlib.import_module(module_name)
    except ImportError:
        print(f"Installing {install_name}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", install_name])
        # pip ran in a child process, so the import finders of this
        # interpreter may still hold stale directory listings. Clearing
        # them makes the freshly installed package importable right away.
        importlib.invalidate_caches()

# Select the "Agg" backend immediately after importing matplotlib, before any
# other plotting import runs in this cell.
# NOTE(review): presumably chosen because plots are written to ./outputs
# rather than shown inline — confirm.
import matplotlib
matplotlib.use("Agg")

# Report the installed torch version so a run can be reproduced later.
import torch
print(f"Torch version: {torch.__version__}")

# Print NumPy values with three digits of precision.
import numpy as np
np.set_printoptions(precision=3)

# Put the notebook's working directory on the import path (once).
ROOT = Path.cwd()
root_entry = str(ROOT)
if root_entry not in sys.path:
    sys.path.append(root_entry)

# Create ./outputs first, then each artifact subfolder beneath it.
OUTPUT_DIR = ROOT / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
for folder_name in ("models", "plots", "metrics"):
    (OUTPUT_DIR / folder_name).mkdir(parents=True, exist_ok=True)
print(f"Outputs directory: {OUTPUT_DIR}")


In [None]:

from __future__ import annotations

import json
from dataclasses import asdict
from pathlib import Path

import numpy as np
import pandas as pd

from blackjackai_rl.device import detect_torch_device
from blackjackai_rl.utils import ensure_dir, set_global_seed, to_json
from blackjackai_rl.env import BlackjackEnvConfig
from blackjackai_rl.agents import DQNAgent, DQNConfig
from blackjackai_rl.training import train_dqn
from blackjackai_rl.evaluation import (
    compare_to_baseline,
    evaluate_policy,
    plot_evaluation_results,
    plot_training_curves,
    save_hand_records,
)
from blackjackai_rl.tests import run_all_tests

# Pick the torch device and fix the global seed (42) before any other work.
DEVICE = detect_torch_device()
print(f"Using device: {DEVICE.device}")
set_global_seed(42)

# Ensure the output root and its three artifact subfolders exist.
BASE_OUTPUT = ensure_dir("outputs")
for artifact_folder in ("plots", "metrics", "models"):
    ensure_dir(BASE_OUTPUT / artifact_folder)


In [None]:

# Run the smoke tests, print one PASS/FAIL line per test, and stop the
# notebook if any test did not pass.
results = run_all_tests()
for outcome in results:
    verdict = "FAIL" if not outcome.passed else "PASS"
    # Details are appended only when the test reported some.
    trailer = f" | {outcome.details}" if outcome.details else ""
    print(f"{verdict} :: {outcome.name}{trailer}")

failed_tests = [outcome for outcome in results if not outcome.passed]
assert not failed_tests, "One or more smoke tests failed"


In [None]:

# Run-size settings used by the later cells:
#   training_steps   -> train_dqn(total_steps=...)
#   vector_envs      -> train_dqn(vector_envs=...)
#   log_interval     -> train_dqn(log_interval=...)
#   evaluation_hands -> compare_to_baseline(num_hands=...)
training_steps = 20_000
vector_envs = 32
log_interval = 1_000
evaluation_hands = 4_000

# Keyword arguments for the Blackjack environment configuration.
# NOTE(review): field meanings are defined by BlackjackEnvConfig, which is
# not visible in this notebook — confirm there before changing values.
env_settings = dict(
    num_decks=6,
    penetration=0.8,
    natural_payout=1.5,
    hit_soft_17=False,
    min_bet=1.0,
    max_bet=8.0,
    bankroll=100.0,
    bankroll_stop_loss=0.0,
    bankroll_target=200.0,
    allow_surrender=True,
    allow_double=True,
    allow_split=True,
    max_splits=1,
    reward_shaping=True,
    shaping_stop_step=60_000,
)
env_config = BlackjackEnvConfig(**env_settings)

# Keyword arguments for the DQN agent; it runs on the device detected earlier.
# NOTE(review): field meanings are defined by DQNConfig — confirm there.
agent_settings = dict(
    state_dim=12,
    num_actions=5,
    hidden_sizes=(256, 256),
    gamma=0.99,
    lr=5e-4,
    epsilon_start=1.0,
    epsilon_final=0.05,
    epsilon_decay=200_000,
    batch_size=512,
    buffer_size=200_000,
    min_buffer_size=4_000,
    target_update_interval=2_000,
    tau=0.01,
    device=DEVICE.device,
)
agent_config = DQNConfig(**agent_settings)

# Echo both configurations so the run settings appear in the cell output.
print("Training configuration ready")
for config_object in (env_config, agent_config):
    print(config_object)


In [None]:

# Train the agent with the settings from the configuration cell.
training_history = train_dqn(
    env_config=env_config,
    agent_config=agent_config,
    total_steps=training_steps,
    vector_envs=vector_envs,
    log_interval=log_interval,
    output_dir=BASE_OUTPUT,
)

# Report the best checkpoint and the recorded timings (dataclass -> dict).
best_model_path = Path(training_history["best_model_path"])
print(f"Best checkpoint: {best_model_path}")
timing_rows = [asdict(entry) for entry in training_history["timings"]]
print(f"Timings: {timing_rows}")

# Store the reward and epsilon histories as CSV columns; the row index is
# written under the "batch" header.
metrics_df = pd.DataFrame(
    {column: training_history[f"{column}_history"] for column in ("reward", "epsilon")}
)
metrics_csv = BASE_OUTPUT / "metrics" / "training_batch_metrics.csv"
metrics_df.to_csv(metrics_csv, index_label="batch")
print(f"Saved batch metrics to {metrics_csv}")


In [None]:

# Render the training curves and list each generated plot with its path.
plot_paths = plot_training_curves(training_history, BASE_OUTPUT)
print("Training plots:")
for plot_name, plot_file in plot_paths.items():
    print(f"  {plot_name}: {plot_file}")


In [None]:

# Rebuild an agent from the configuration and load the best checkpoint.
trained_agent = DQNAgent(agent_config)
trained_agent.load(str(best_model_path), map_location=DEVICE.device)

# Evaluate the trained policy and the baseline over the same number of hands.
comparison = compare_to_baseline(env_config, trained_agent, num_hands=evaluation_hands)
trained_eval, baseline_eval = comparison["trained"], comparison["baseline"]

# Flatten the per-hand record dataclasses into tables.
trained_records = pd.DataFrame(list(map(asdict, trained_eval["hand_records"])))
baseline_records = pd.DataFrame(list(map(asdict, baseline_eval["hand_records"])))

# Write both hand histories to the metrics folder, then report the paths.
metrics_dir = BASE_OUTPUT / "metrics"
trained_records_path = metrics_dir / "trained_hand_history.csv"
baseline_records_path = metrics_dir / "baseline_hand_history.csv"
trained_records.to_csv(trained_records_path, index=False)
baseline_records.to_csv(baseline_records_path, index=False)
print(f"Saved trained hand history to {trained_records_path}")
print(f"Saved baseline hand history to {baseline_records_path}")

# Store both summaries and the expected-value gain as one JSON document.
summary_json = {
    label: asdict(comparison[label]["summary"]) for label in ("trained", "baseline")
}
summary_json["expected_value_gain"] = comparison["expected_value_gain"]
summary_path = metrics_dir / "evaluation_summary.json"
with open(summary_path, "w", encoding="utf-8") as summary_file:
    json.dump(summary_json, summary_file, indent=2)
print(f"Summary saved to {summary_path}")

# Plot the trained agent's evaluation results and list the files produced.
plot_eval_paths = plot_evaluation_results(trained_eval["hand_records"], BASE_OUTPUT)
print("Evaluation plots:")
for plot_name, plot_file in plot_eval_paths.items():
    print(f"  {plot_name}: {plot_file}")


In [None]:

# Print the trained and baseline evaluation summaries, followed by the
# expected-value gain reported by the comparison.
trained_summary = comparison["trained"]["summary"]
baseline_summary = comparison["baseline"]["summary"]
print("Trained agent summary:")
print(trained_summary)
# The leading "\n" inserts a blank line between sections. It has to stay an
# escape sequence: a raw line break inside a quoted string is a SyntaxError.
print("\nBaseline summary:")
print(baseline_summary)
print("\nExpected value gain per hand:", comparison["expected_value_gain"])


In [None]:

from pprint import pprint

# Collect the location of every artifact this notebook produced. The fixed
# entries come first, then the plot paths returned by the plotting helpers.
artifacts = {
    "best_model": str(best_model_path),
    "final_model": str(Path(BASE_OUTPUT) / "models" / "final_dqn.pt"),
    "training_metrics": str(metrics_csv),
    "evaluation_summary": str(BASE_OUTPUT / "metrics" / "evaluation_summary.json"),
}
artifacts.update(plot_paths)
artifacts.update(plot_eval_paths)
# The check mark had been stored as mojibake (its UTF-8 bytes decoded with a
# single-byte codec); the intended character is restored here.
print("✅ Training complete | Eval done | Plots saved to ./outputs | Best model:", best_model_path)
pprint(artifacts)
