In [1]:
import argparse
import json
import sys
from pathlib import Path

import pandas as pd
import matplotlib.pyplot as plt


def load_data(input_path: Path) -> pd.DataFrame:
    if not input_path.is_file():
        raise FileNotFoundError(f"Raw data file not found: {input_path}")
    with open(input_path, 'r') as f:
        data = json.load(f)
    return pd.DataFrame(data)


def plot_distribution(series: pd.Series, title: str, xlabel: str, ylabel: str, output_file: Path) -> None:
    plt.figure()
    series.plot(kind='bar')
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.tight_layout()

    output_file.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_file)
    plt.close()
    print(f"Saved plot: {output_file}")


def main():
    cwd = Path.cwd()
    default_input = cwd / 'backend' / 'src' / 'ml' / 'data' / 'raw_games.json'
    default_output = cwd / 'backend' / 'src' / 'ml' / 'scripts' / 'plots'

    parser = argparse.ArgumentParser(description="Connect Four gameplay analysis")
    parser.add_argument(
        '-i', '--input', type=Path, default=default_input,
        help="Path to raw self-play JSON data"
    )
    parser.add_argument(
        '-o', '--output-dir', type=Path, default=default_output,
        help="Directory to save output plots"
    )

    # Use parse_known_args to avoid SystemExit in interactive environments
    args, _ = parser.parse_known_args()

    try:
        df = load_data(args.input)
    except FileNotFoundError as e:
        sys.stderr.write(str(e) + "\n")
        sys.exit(1)

    total = len(df)
    win_pct = (df['outcome'] == 'win').mean() * 100
    loss_pct = (df['outcome'] == 'loss').mean() * 100
    draw_pct = (df['outcome'] == 'draw').mean() * 100

    print(f"Total examples: {total}")
    print(f"Win: {win_pct:.2f}%")
    print(f"Loss: {loss_pct:.2f}%")
    print(f"Draw: {draw_pct:.2f}%")

    outcome_counts = df['outcome'].value_counts().sort_index()
    plot_distribution(
        outcome_counts,
        title='Outcome Distribution',
        xlabel='Outcome',
        ylabel='Count',
        output_file=args.output_dir / 'outcome_distribution.png'
    )

    move_counts = df['move'].value_counts().sort_index()
    plot_distribution(
        move_counts,
        title='Move Distribution by Column',
        xlabel='Column Index',
        ylabel='Count',
        output_file=args.output_dir / 'move_distribution.png'
    )

    print("Analysis complete.")


if __name__ == '__main__':
    main()

Raw data file not found: /Users/derekjrussell/Documents/repos/ConnectFourGame/backend/src/ml/notebooks/backend/src/ml/data/raw_games.json


AssertionError: 