# Batch Inference with JSON Output

This notebook runs inference on multiple CSV files and saves detailed results including win probabilities and squad state information for each time point.

In [0]:
%restart_python

In [0]:
%pip install tqdm

In [0]:
import os
import sys
import json
import glob
from tqdm import tqdm

# Expose the parent of the current working directory on the import path.
# NOTE(review): this presumes the notebook is launched from a folder that sits
# directly under the repository root - confirm for your workspace layout.
PROJECT_ROOT = os.path.split(os.getcwd())[0]
sys.path[:0] = [PROJECT_ROOT]  # prepend, so it takes precedence over later entries

print(f"Project root: {PROJECT_ROOT}")

## 1. Configuration

Set the directory containing the test CSV files, the path to the model checkpoint, and the directory where the prediction JSON files will be written.


In [0]:
# ---- Locations -------------------------------------------------------------
# Directory that is scanned for input `*.csv` files.
CSV_DIR = "/Volumes/main_dev/dld_ml_anticheat_test/anticheat_test_volume/pgc_wwcd/pgc_features/replay_single_sample_v2/"
# Checkpoint file handed to the inference routine.
CHECKPOINT_PATH = "/Volumes/main_dev/dld_ml_anticheat_test/anticheat_test_volume/pgc_wwcd/pgc_results/checkpoints/run22_emb1024_head2_layer6_drop0.1_lr1e-4_weighted_cox_v11/best.pt"
# Directory that receives the generated JSON files.
OUTPUT_DIR = "/Volumes/main_dev/dld_ml_anticheat_test/anticheat_test_volume/pgc_wwcd/pgc_results/predictions_weighted_cox_v2/"

# ---- Compute device --------------------------------------------------------
# Use the GPU when torch reports one, otherwise fall back to the CPU.
import torch
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# ---- Input discovery -------------------------------------------------------
# Every CSV directly inside CSV_DIR, in sorted path order.
csv_files = glob.glob(os.path.join(CSV_DIR, "*.csv"))
csv_files.sort()

# ---- Sanity report ---------------------------------------------------------
print(f"CSV directory: {CSV_DIR}")
print(f"  Exists: {os.path.exists(CSV_DIR)}")
print(f"  CSV files found: {len(csv_files)}")
print(f"\nCheckpoint: {CHECKPOINT_PATH}")
print(f"  Exists: {os.path.exists(CHECKPOINT_PATH)}")
print(f"\nOutput dir: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")


## 2. Preview CSV Files


In [0]:
# Print the names of the first ten discovered CSVs, then a count of the rest.
print("CSV files to process:")
for position, path in enumerate(csv_files[:10], start=1):
    print(f"  {position}. {os.path.basename(path)}")

remaining = len(csv_files) - 10
if remaining > 0:
    print(f"  ... and {remaining} more files")


## 3. Run Batch Inference


In [0]:
from src.training.inference_with_json import run_inference_and_save_json

# Make sure the destination exists before any result is written.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Per-run bookkeeping:
#   all_results  - return value of every successful inference call
#   failed_files - basenames of the CSVs whose inference raised an exception
all_results = []
failed_files = []

for csv_file in tqdm(csv_files, desc="Processing CSV files"):
    csv_basename = os.path.basename(csv_file)
    csv_stem = os.path.splitext(csv_basename)[0]

    # Resume support: treat the CSV as done when a matching predictions JSON
    # is already present in OUTPUT_DIR.
    # NOTE(review): assumes the inference routine names its output
    # "predictions_<csv stem>_<suffix>.json" - confirm against its implementation.
    # The fixed part of the pattern is escaped so that glob metacharacters
    # ("[", "*", "?") occurring in a path or file name are matched literally.
    json_pattern = glob.escape(os.path.join(OUTPUT_DIR, f"predictions_{csv_stem}_")) + "*.json"
    if glob.glob(json_pattern):
        # tqdm.write keeps the progress bar intact, unlike a bare print().
        tqdm.write(f"  Skipping {csv_basename} (already processed)")
        continue

    try:
        result = run_inference_and_save_json(
            csv_path=csv_file,
            checkpoint_path=CHECKPOINT_PATH,
            output_dir=OUTPUT_DIR,
            device=DEVICE,
            temperature=None,
        )
    except Exception as exc:
        # One bad input must not abort the whole batch: record it, report it,
        # and move on. The failure summary below lists everything collected here.
        failed_files.append(csv_basename)
        tqdm.write(f"  FAILED {csv_basename}: {exc!r}")
        continue

    all_results.append(result)

print(f"\nCompleted: {len(all_results)} files processed")
if failed_files:
    print(f"Failed: {len(failed_files)} files")
    for f in failed_files:
        print(f"  - {f}")


In [0]:
# Gather every JSON file in the output directory (sorted by path), report the
# total, and show the first ten names followed by a count of the remainder.
json_files = glob.glob(os.path.join(OUTPUT_DIR, "*.json"))
json_files.sort()
print(f"Total JSON files in output directory: {len(json_files)}")

for position, path in enumerate(json_files[:10], start=1):
    print(f"  {position}. {os.path.basename(path)}")

remaining = len(json_files) - 10
if remaining > 0:
    print(f"  ... and {remaining} more files")


In [0]:
# Inspect the first JSON file (if there is one): print its match metadata and
# a few fields of the numerically smallest time point of its first match.
if json_files:
    sample_json_path = json_files[0]
    with open(sample_json_path, 'r') as handle:
        sample_data = json.load(handle)

    print(f"Sample JSON: {os.path.basename(sample_json_path)}")
    print("\nMatch Info:")
    for field, field_value in sample_data['match_info'].items():
        print(f"  {field}: {field_value}")

    # The file is read as predictions[match_id][time_point]; use the first
    # match id in file order.
    predictions = sample_data['predictions']
    match_id = list(predictions)[0]
    per_time_point = predictions[match_id]

    # Order the time-point keys by their float value rather than as text.
    time_points = sorted(per_time_point, key=float)
    first_tp = time_points[0]
    tp_data = per_time_point[first_tp]

    print(f"\nFirst Time Point ({first_tp}):")
    print(f"  Phase: {tp_data['phase']}")
    print(f"  Squads: {len(tp_data['squad_numbers'])}")
    print(f"  Alive counts: {tp_data['alive_cnt']}")
    print(f"  HP values: {tp_data['hp']}")


In [0]:
# Print each squad's probability at the sample time point, highest first.
# Relies on `tp_data` having been set by the sample-inspection cell.
if json_files:
    print("\nProbabilities at first time point:")
    ranked = sorted(tp_data['probabilities'].items(), key=lambda entry: -entry[1])
    for squad_num, prob in ranked:
        alive_lookup = tp_data['is_alive']
        # Try the string form of the key first; the integer form is the fallback,
        # and False is used when neither key is present.
        int_key_result = alive_lookup.get(int(squad_num), False)
        is_alive = alive_lookup.get(str(squad_num), int_key_result)
        print(f"  Squad {squad_num}: {prob:.4f} (alive: {is_alive})")


## 4. Summary Statistics


In [0]:
# Summary statistics across all JSON files: total file count and a per-map
# tally read from each file's `match_info` section.
if json_files:
    total_matches = len(json_files)
    maps_count = {}  # map name -> number of JSON files reporting that map

    for json_file in json_files:
        with open(json_file, 'r') as f:
            data = json.load(f)

        # Files without a 'map' entry are grouped under 'unknown'.
        map_name = data['match_info'].get('map', 'unknown')
        maps_count[map_name] = maps_count.get(map_name, 0) + 1

    print(f"Total matches processed: {total_matches}")
    print(f"\nMatches by map:")
    # Most frequent map first.
    for map_name, count in sorted(maps_count.items(), key=lambda x: -x[1]):
        print(f"  {map_name}: {count}")


In [0]:
# Report how many entries (of any kind) the output directory currently holds.
print(f"\nOutput directory: {OUTPUT_DIR}")
output_entries = os.listdir(OUTPUT_DIR)
print(f"Total files: {len(output_entries)}")


In [0]:
# Closing banner: counts of discovered CSVs and of JSON files in the output
# directory, framed by 60-character rules.
rule = "=" * 60
print(rule)
print("BATCH INFERENCE COMPLETE")
print(rule)
print(f"CSV files processed: {len(csv_files)}")
print(f"JSON files generated: {len(json_files)}")
print(f"Output directory: {OUTPUT_DIR}")
print(rule)
