# In [None]:
import argparse
import json
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

# Import the policy network from the project's source tree.
# ML_ROOT is the parent of this script's directory; its `src/` folder is
# appended to sys.path so `policy_net` resolves without installing a package.
# NOTE(review): `__file__` is undefined when this code is executed as a
# notebook cell — confirm it is meant to run as a standalone script.
ML_ROOT: Path = Path(__file__).resolve().parent.parent
sys.path.append(str(ML_ROOT / 'src'))
from policy_net import Connect4PolicyNet  # type: ignore


def find_file(filename: str) -> Path:
    """Locate a data file by walking up from the current working directory.

    Starting at the CWD and moving towards the filesystem root, return the
    first ``<ancestor>/backend/src/ml/data/<filename>`` that is an existing
    file.

    Raises:
        FileNotFoundError: if no ancestor contains the file.
    """
    here = Path.cwd()
    relative = Path('backend', 'src', 'ml', 'data', filename)
    # Nearest directory first, root last.
    candidates = (root / relative for root in (here, *here.parents))
    found = next((path for path in candidates if path.is_file()), None)
    if found is None:
        raise FileNotFoundError(f"Could not locate {filename} under any ancestor directory")
    return found


def load_json(path: Path):
    """Parse and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8, the encoding JSON interchange
    requires (RFC 8259). It does not fall back to the platform's locale
    default, so files with non-ASCII text load identically on every OS.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def convert_feature(feat_list):
    """Encode a flat Connect4 board as a two-plane float tensor.

    Args:
        feat_list: 42 cell codes (0 = empty, 1 = red, 2 = yellow), read in
            row-major order into a 6x7 grid.

    Returns:
        A float32 tensor of shape [2, 6, 7]. Plane 0 is 1.0 where a red
        stone sits and plane 1 is 1.0 where a yellow stone sits; all other
        entries are 0.0.
    """
    board = np.array(feat_list, dtype=np.int64).reshape(6, 7)
    # One binary plane per player code, stacked channel-first.
    planes = np.stack([(board == code).astype(np.float32) for code in (1, 2)])
    return torch.from_numpy(planes)


def prepare_dataset(json_path: Path):
    """Load labelled boards from a JSON file and wrap them in a TensorDataset.

    Each record is read with ``.get('features')`` and ``.get('label')``.
    Records missing either key (or holding ``None``) are skipped. Features
    are encoded with ``convert_feature`` and labels are cast with ``int()``.

    Returns:
        ``TensorDataset(boards, labels)`` where ``boards`` stacks the encoded
        boards (float32, [N, 2, 6, 7]) and ``labels`` is an int64 tensor [N].

    Raises:
        RuntimeError: if no record has both a feature list and a label.
    """
    samples = []
    for record in load_json(json_path):
        board, move = record.get('features'), record.get('label')
        if board is None or move is None:
            continue
        samples.append((convert_feature(board), int(move)))

    if not samples:
        raise RuntimeError(f"No valid examples found in {json_path}")

    boards, moves = zip(*samples)
    return TensorDataset(torch.stack(boards), torch.tensor(moves, dtype=torch.long))


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage.

    Puts ``model`` in eval mode (it is left in eval mode on return) and
    disables autograd while scoring.

    Args:
        model: module whose output is argmax-ed over dim 1 to get class
            predictions.
        loader: iterable of ``(inputs, integer labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        Accuracy in [0, 100]. If the loader yields no samples, returns 0.0
        rather than raising ZeroDivisionError.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    if total == 0:
        return 0.0
    return correct / total * 100


def _parse_args():
    """Define the command-line interface and return the validated arguments."""
    parser = argparse.ArgumentParser(description="Supervised policy training for Connect4")
    parser.add_argument('--train-json', type=Path,
                        help="Path to train.json file")
    parser.add_argument('--test-data', type=Path,
                        help="Path to test_data.pt file for evaluation")
    parser.add_argument('--epochs', type=int, default=10,
                        help="Number of training epochs")
    parser.add_argument('--batch-size', type=int, default=64,
                        help="Batch size")
    parser.add_argument('--lr', type=float, default=1e-3,
                        help="Learning rate")
    args = parser.parse_args()
    # With zero epochs no checkpoint is ever written, and the export step
    # would then fail or silently export a stale file from an earlier run.
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    return args


def _resolve_train_path(train_json):
    """Return ``train_json`` if given, else search for train.json; exit(1) if absent."""
    if train_json:
        return train_json
    try:
        return find_file('train.json')
    except FileNotFoundError as e:
        sys.stderr.write(str(e) + "\n")
        sys.exit(1)


def _load_test_loader(test_path, batch_size):
    """Build an unshuffled DataLoader from a saved test set, or None if no path.

    The file must contain a mapping with 'data' and 'labels' tensors.
    """
    if not test_path:
        return None
    # NOTE(review): torch.load unpickles its input; only pass trusted files.
    chk = torch.load(test_path, map_location='cpu')
    return DataLoader(TensorDataset(chk['data'], chk['labels']), batch_size=batch_size)


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader``; return the per-sample mean loss."""
    model.train()
    running_loss, seen = 0.0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (CrossEntropyLoss's default).
        # Re-weighting by batch size keeps the epoch average per-sample even
        # when the last batch is smaller.
        running_loss += loss.item() * y.size(0)
        seen += y.size(0)
    return running_loss / seen if seen else 0.0


def _export_models(ckpt_path, ts_path, onnx_path):
    """Reload the checkpoint on CPU and write TorchScript and ONNX exports."""
    model_cpu = Connect4PolicyNet()
    model_cpu.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    model_cpu.eval()

    scripted = torch.jit.script(model_cpu)
    scripted.save(str(ts_path))
    print(f"Saved TorchScript model to {ts_path}")

    # Example input for the exporter. Dim 0 is declared dynamic so the
    # exported graph accepts any batch size.
    dummy = torch.randn(1, 2, 6, 7)
    torch.onnx.export(
        model_cpu, dummy, str(onnx_path),
        input_names=['input'], output_names=['logits'],
        dynamic_axes={'input': {0: 'batch'}, 'logits': {0: 'batch'}},
    )
    print(f"Saved ONNX model to {onnx_path}")


def main():
    """Train Connect4PolicyNet on labelled boards, checkpoint it, and export it.

    The script proceeds in these steps:
    - Parse CLI args.
    - Load the training JSON, from --train-json or by searching ancestor
      directories.
    - Optionally load a held-out test set.
    - Train with Adam and cross-entropy.
    - Save a checkpoint under ML_ROOT/models.
    - Re-export that checkpoint as TorchScript and ONNX.

    With a non-empty test set the checkpoint keeps the best-accuracy epoch.
    Otherwise it is overwritten every epoch and holds the final weights.
    """
    args = _parse_args()
    train_path = _resolve_train_path(args.train_json)

    # Prepare dataset & loader
    train_ds = prepare_dataset(train_path)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)

    test_loader = _load_test_loader(args.test_data, args.batch_size)
    # An empty test set has nothing to score; treat it like having no test set.
    has_test = test_loader is not None and len(test_loader) > 0

    # Device and model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Connect4PolicyNet().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    criterion = nn.CrossEntropyLoss()

    # Paths for export
    models_dir = ML_ROOT / 'models'
    models_dir.mkdir(parents=True, exist_ok=True)
    ckpt_path = models_dir / 'best_policy_net.pt'
    ts_path = models_dir / 'policy_net_ts.pt'
    onnx_path = models_dir / 'policy_net.onnx'

    # None until the first evaluation, so the first evaluated epoch always
    # writes a checkpoint, even at 0% accuracy. Export then never loads a
    # missing or stale file.
    best_acc = None
    for epoch in range(1, args.epochs + 1):
        avg_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)

        if has_test:
            acc = evaluate(model, test_loader, device)
            print(f"Epoch {epoch}: Loss={avg_loss:.4f}, TestAcc={acc:.2f}%")
            if best_acc is None or acc > best_acc:
                best_acc = acc
                torch.save(model.state_dict(), ckpt_path)
        else:
            print(f"Epoch {epoch}: Loss={avg_loss:.4f}")
            # Save checkpoint every epoch
            torch.save(model.state_dict(), ckpt_path)

    if best_acc is None:
        print("Training complete. No test set; checkpoint holds final-epoch weights.")
    else:
        print(f"Training complete. Best test accuracy: {best_acc:.2f}%")

    _export_models(ckpt_path, ts_path, onnx_path)

# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()