# PGC 2025 Inference

Run inference on production CSV files using multiple models:
- WeightedCox, RankCox, FullAll, MSE, CE (Transformer-based)
- LGBM (LightGBM)
- Rule-based (squad_total_health, dist_from_whitezone_v2, dist_from_bluezone_v2)


## 1. Setup


In [0]:
%pip install tqdm lightgbm

In [0]:
import os
import sys
import json
import glob
import re
import ast
from typing import Dict, List, Tuple, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import lightgbm as lgb
from scipy.special import softmax
from tqdm import tqdm

# Add deployment_v2_1 to path
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.getcwd()))
DEPLOYMENT_PATH = os.path.join(PROJECT_ROOT, "deployment_v2_1")
sys.path.insert(0, DEPLOYMENT_PATH)

from deployment_v2_1.src.data.continuous_features import CONTINUOUS_FEATURES
from deployment_v2_1.src.data.dataset import PGCDataset
from src.models import TransformerBackbone
from src.models.heads import get_head

print(f"Project root: {PROJECT_ROOT}")
print(f"Deployment path: {DEPLOYMENT_PATH}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")


In [0]:
# Configuration
# Input/output locations used by the rest of the notebook:
#   DATA_PATH        - directory globbed for "*.csv" input files
#   CHECKPOINT_DIR   - parent directory of the per-model Transformer checkpoint folders
#   LGBM_MODEL_PATH  - LightGBM model file (a sibling config.json is read if present)
#   OUTPUT_DIR       - directory the per-match prediction JSON files are written to
DATA_PATH = "/Volumes/main_dev/dld_ml_anticheat_test/anticheat_test_volume/pgc_wwcd/pgc_features/prod/"
CHECKPOINT_DIR = "/Volumes/main_dev/dld_ml_anticheat_test/anticheat_test_volume/pgc_wwcd/pgc_results/checkpoints/"
LGBM_MODEL_PATH = "/Volumes/main_dev/dld_ml_anticheat_test/anticheat_test_volume/pgc_wwcd/pgc_results/lgbm/leaves31_lr0.1/model.txt"
OUTPUT_DIR = "/Volumes/main_dev/dld_ml_anticheat_test/anticheat_test_volume/pgc_wwcd/pgc_results/pgc2025_predictions/"

# Device
# Use the GPU when PyTorch can see one, otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Constants
# Divisors applied in extract_lgbm_features: POSITION_SCALE to the squad
# centre/spread coordinates, ZONE_SCALE to the blue/white zone centre coordinates.
POSITION_SCALE = 80000.0
ZONE_SCALE = 800000.0

# Echo the resolved configuration so a run's settings are visible in the cell output.
print(f"Data path: {DATA_PATH}")
print(f"Checkpoint dir: {CHECKPOINT_DIR}")
print(f"LGBM model: {LGBM_MODEL_PATH}")
print(f"Output dir: {OUTPUT_DIR}")
print(f"Device: {DEVICE}")


In [0]:
# Load CSV files list (sorted so the processing order is deterministic)
csv_files = sorted(glob.glob(os.path.join(DATA_PATH, "*.csv")))
print(f"Found {len(csv_files)} CSV files")

# Preview first few files
for i, f in enumerate(csv_files[:5]):
    print(f"  {i+1}. {os.path.basename(f)}")
if len(csv_files) > 5:
    print(f"  ... and {len(csv_files) - 5} more")

# Show columns of the first CSV file.
# `df_sample` is always bound (empty when no CSV was found) so the
# feature-overlap cell that follows does not fail with a NameError.
df_sample = pd.DataFrame()
if csv_files:
    # Only the header is needed, so read a single row.
    df_sample = pd.read_csv(csv_files[0], nrows=1)
    print(f"\nColumns in {os.path.basename(csv_files[0])}:")
    print(list(df_sample.columns))
else:
    print(f"No CSV files found under {DATA_PATH}")


In [0]:
# Number of CONTINUOUS_FEATURES that appear as columns in the sample CSV.
len(set(CONTINUOUS_FEATURES).intersection(df_sample.columns))

## 2. Utility Functions


In [0]:
# Parsing functions
_ZONE_FLOAT_RE = re.compile(r'[-+]?(?:\d*\.\d+|\d+\.?)(?:[eE][-+]?\d+)?')


def parse_zone_info(zone_str: str) -> Tuple[float, float, float]:
    """Parse bluezone_info or whitezone_info string to (x, y, radius) tuple."""
    if pd.isna(zone_str) or zone_str == '':
        return 400000.0, 400000.0, 500000.0
    
    if isinstance(zone_str, (tuple, list)):
        if len(zone_str) >= 3:
            return float(zone_str[0]), float(zone_str[1]), float(zone_str[2])
        return 400000.0, 400000.0, 500000.0
    
    if isinstance(zone_str, str):
        cleaned = re.sub(r'np\.float\d+\s*\(\s*([^)]+)\s*\)', r'\1', zone_str)
        nums = _ZONE_FLOAT_RE.findall(cleaned)
        if len(nums) >= 3:
            return float(nums[0]), float(nums[1]), float(nums[2])
    
    return 400000.0, 400000.0, 500000.0


def parse_positions(positions_str: str) -> List[List[float]]:
    """Parse positions to a list of [x, y, z] coordinates.

    Args:
        positions_str: A Python-literal string such as
            ``"[[1, 2, 3], [nan, nan, nan]]"`` (values may be wrapped as
            ``np.float64(...)``), or an already parsed list/tuple of
            positions.

    Returns:
        One ``[x, y, z]`` float list per entry. An entry that is ``None``,
        contains a missing coordinate, or is exactly ``(0, 0, 0)`` becomes
        ``[nan, nan, nan]``. Missing/empty input, or input with no entries,
        yields ``[[nan, nan, nan]]``.

    Raises:
        ValueError, SyntaxError: If a string input is not a valid literal.
    """
    if isinstance(positions_str, (list, tuple)):
        # Already parsed: pd.isna()/str.replace() below only work on scalars/strings.
        positions_list = positions_str
    else:
        if pd.isna(positions_str) or positions_str == '':
            return [[np.nan, np.nan, np.nan]]

        # Unwrap "np.float64(1.5)" -> "1.5" (same handling as parse_zone_info),
        # then map bare nan tokens to None so literal_eval accepts them.
        unwrapped = re.sub(r'np\.float\d+\s*\(\s*([^)]+)\s*\)', r'\1', positions_str)
        positions_list = ast.literal_eval(unwrapped.replace('nan', 'None'))

    result = []
    for pos in positions_list:
        coordinate_missing = pos is None or any(
            p is None or (isinstance(p, float) and np.isnan(p)) for p in pos
        )
        if coordinate_missing:
            result.append([np.nan, np.nan, np.nan])
        elif pos[0] == 0 and pos[1] == 0 and pos[2] == 0:
            # An all-zero position is treated as "no position recorded".
            result.append([np.nan, np.nan, np.nan])
        else:
            result.append([float(pos[0]), float(pos[1]), float(pos[2])])

    return result if result else [[np.nan, np.nan, np.nan]]


def get_squad_center(positions: List[List[float]]) -> np.ndarray:
    """Return the mean [x, y, z] over squad members that have a position.

    A member counts as present when its x coordinate is not NaN. When no
    member qualifies, an all-NaN vector of length 3 is returned.
    """
    coords = np.array(positions)
    has_x = ~np.isnan(coords[:, 0])
    if not has_x.any():
        return np.array([np.nan, np.nan, np.nan])
    return np.nanmean(coords[has_x], axis=0)


def compute_dist_from_zone_v2(positions_str: str, zone_str: str) -> float:
    """Return the squad-centre-to-zone-centre distance divided by the zone radius.

    The distance is measured in the x/y plane only. NaN is returned when the
    squad has no usable centre or the zone radius is not positive.
    """
    members = parse_positions(positions_str)
    zone_x, zone_y, zone_radius = parse_zone_info(zone_str)
    squad_center = get_squad_center(members)
    center_x, center_y = squad_center[0], squad_center[1]

    if np.isnan(center_x) or zone_radius <= 0:
        return np.nan

    planar_dist = np.sqrt((center_x - zone_x)**2 + (center_y - zone_y)**2)
    return planar_dist / zone_radius


print("Utility functions defined.")


## 3. Model Loading


In [0]:
def load_transformer_model(checkpoint_path: str, device: str = "cpu") -> Tuple[nn.Module, nn.Module, Optional[Dict], Dict, float]:
    """Load a Transformer backbone + head from a checkpoint directory.

    Besides the checkpoint file itself, two files are read from the same
    directory: ``config.json`` (required) and ``temperature.json`` (optional).

    Args:
        checkpoint_path: Path to the checkpoint file. It must contain
            ``backbone_state_dict`` and ``head_state_dict`` entries.
        device: Device the weights are mapped to and the modules moved to.

    Returns:
        ``(backbone, head, scaler_params, config, temperature)`` where both
        modules are in eval mode, ``scaler_params`` is the checkpoint's
        ``scaler_params`` entry (``None`` when absent) and ``temperature`` is
        ``optimal_tau`` from ``temperature.json`` (``1.0`` when the file or
        key is absent).

    Raises:
        FileNotFoundError: If ``config.json`` or the checkpoint is missing.
        KeyError: If a required config or checkpoint entry is missing.
    """
    checkpoint_dir = os.path.dirname(checkpoint_path)
    config_path = os.path.join(checkpoint_dir, "config.json")

    with open(config_path, "r") as f:
        config = json.load(f)

    input_dim = config.get("input_dim", len(CONTINUOUS_FEATURES))

    # Create models
    backbone = TransformerBackbone(
        input_dim=input_dim,
        embed_dim=config["embed_dim"],
        num_heads=config["num_heads"],
        num_layers=config["num_layers"],
        dropout=config["dropout"],
    )

    loss_type = config.get("loss_type", "weighted_cox")
    head = get_head(loss_type, embed_dim=config["embed_dim"], dropout=config["dropout"])

    # Load weights
    # SECURITY: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted location.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    backbone.load_state_dict(checkpoint["backbone_state_dict"])
    head.load_state_dict(checkpoint["head_state_dict"])
    scaler_params = checkpoint.get("scaler_params", None)

    # eval() disables dropout so inference is deterministic.
    backbone = backbone.to(device).eval()
    head = head.to(device).eval()

    # Load temperature
    tau_path = os.path.join(checkpoint_dir, "temperature.json")
    temperature = 1.0
    if os.path.exists(tau_path):
        with open(tau_path, "r") as f:
            tau_data = json.load(f)
        temperature = tau_data.get("optimal_tau", 1.0)

    return backbone, head, scaler_params, config, temperature


def load_lgbm_model(model_path: str) -> Tuple[lgb.Booster, Dict]:
    """Load a LightGBM booster and the optional config.json stored beside it.

    Returns:
        ``(booster, config)``; ``config`` is an empty dict when no
        ``config.json`` exists in the model file's directory.
    """
    booster = lgb.Booster(model_file=model_path)

    sidecar_path = os.path.join(os.path.dirname(model_path), "config.json")
    if not os.path.exists(sidecar_path):
        return booster, {}

    with open(sidecar_path, "r") as fh:
        loaded_config = json.load(fh)
    return booster, loaded_config


print("Model loading functions defined.")


In [0]:
# Model name -> checkpoint subdirectory under CHECKPOINT_DIR.
# Commented-out entries are models that are currently not run.
CHECKPOINT_PATHS = {
    # "MSE": "emb256_head8_layer4_drop0.1_lr1e-4_mse",
    # "FullAll": "run21_emb1024_head8_layer6_drop0.1_lr1e-4_cox",
    # "RankCox": "run22_emb512_head4_layer6_drop0.1_lr1e-4_rank_cox",
    "WeightedCox": "final_model",
    # "CE": "run25_emb512_head4_layer6_drop0.1_lr1e-4_survival_ce",
}

# Keep only the models whose best.pt actually exists on disk.
model_checkpoints = {}
for model_name, subdir in CHECKPOINT_PATHS.items():
    best_path = os.path.join(CHECKPOINT_DIR, subdir, "best.pt")
    if not os.path.exists(best_path):
        print(f"{model_name}: NOT FOUND ({best_path})")
        continue
    model_checkpoints[model_name] = best_path
    print(f"{model_name}: {subdir}")

print(f"\nFound {len(model_checkpoints)} Transformer model checkpoints")


## 4. Inference Functions


In [0]:
def run_transformer_inference(
    csv_path: str,
    backbone: nn.Module,
    head: nn.Module,
    scaler_params: Dict,
    config: Dict,
    device: str = "cpu",
    temperature: float = 1.0,
) -> Dict[str, Any]:
    """Run Transformer model inference on a single CSV file.

    For every (match_id, time_point) group exposed by ``PGCDataset`` the
    head scores are divided by ``temperature``, squads considered dead are
    masked out, and a softmax over the remaining squads gives the win
    probabilities.

    Args:
        csv_path: Path to the CSV file.
        backbone: Backbone module, called as
            ``backbone(x, positions, bluezone, whitezone, map_idx)``.
        head: Head module applied to the backbone output.
        scaler_params: Scaler parameters forwarded to ``PGCDataset``.
        config: Model config. Currently not consulted here; kept for call
            compatibility.
        device: Device the input tensors are moved to.
        temperature: Divisor applied to the raw scores before the softmax.

    Returns:
        ``{'match_id': ..., 'predictions': {str(time_point): {...}}}`` where
        each entry holds ``phase``, ``squad_numbers``, ``alive_cnt``, ``hp``,
        ``bluezone_info``, ``whitezone_info``, ``positions``,
        ``probabilities`` and ``is_alive``.

    Raises:
        ValueError: If the dataset yields no groups and the CSV has no rows,
            so no match id can be determined.
    """
    csv_path = os.path.abspath(csv_path)
    folder_path = os.path.dirname(csv_path)
    file_name = os.path.basename(csv_path)

    # NOTE(review): config["use_dataset_v2"] used to select between two
    # byte-identical PGCDataset constructions; confirm whether a different
    # dataset class was intended for one of the branches.
    dataset = PGCDataset(
        folder_path=folder_path,
        file_list=[file_name],
        continuous_features=CONTINUOUS_FEATURES,
        scaler_params=scaler_params,
    )

    # Raw rows are looked up per group to recover the per-squad metadata.
    # Time points are rounded so float noise does not break the key match.
    raw_df = pd.read_csv(csv_path)
    raw_df['time_point_key'] = raw_df['time_point'].round(6)
    time_point_groups = {
        key: group.sort_values('squad_number')
        for key, group in raw_df.groupby(['match_id', 'time_point_key'])
    }

    predictions = {}
    match_id = None  # stays None when the dataset yields no groups

    with torch.no_grad():
        for idx in range(len(dataset)):
            match_id, time_point = dataset.group_keys[idx]
            sample = dataset[idx]

            x = sample["x_continuous"].unsqueeze(0).to(device)
            positions = sample["positions"].unsqueeze(0).to(device)
            bluezone = sample["bluezone_info"].unsqueeze(0).to(device)
            whitezone = sample["whitezone_info"].unsqueeze(0).to(device)
            map_idx = sample["map_idx"].unsqueeze(0).to(device)
            # int() accepts both a plain int and a one-element tensor.
            num_squads = int(sample["num_squads"])

            time_data = time_point_groups.get((match_id, round(time_point, 6)))
            if time_data is None:
                continue

            time_data = time_data.head(num_squads)
            squad_numbers = time_data['squad_number'].tolist()

            if 'squad_alive_count' in time_data.columns:
                alive_cnt = [int(v) for v in time_data['squad_alive_count'].tolist()]
            else:
                alive_cnt = [None] * len(squad_numbers)
            if 'squad_total_health' in time_data.columns:
                hp = [int(v) for v in time_data['squad_total_health'].tolist()]
            else:
                hp = [None] * len(squad_numbers)

            positions_list = []
            if 'positions' in time_data.columns:
                positions_list = [parse_positions(s) for s in time_data['positions'].tolist()]

            # Build alive mask: a squad is alive only if it has members,
            # health, and at least one member with a recorded position.
            squad_alive_mask = torch.zeros(num_squads, dtype=torch.bool)
            for i in range(len(squad_numbers)):
                alive_count = alive_cnt[i] if alive_cnt[i] is not None else 0
                health = hp[i] if hp[i] is not None else 0
                has_position = i < len(positions_list) and any(
                    not (np.isnan(p[0]) and np.isnan(p[1]) and np.isnan(p[2]))
                    for p in positions_list[i]
                )
                squad_alive_mask[i] = (alive_count > 0) and (health > 0) and has_position

            # Forward pass
            embeddings = backbone(x, positions, bluezone, whitezone, map_idx)
            raw_scores = head(embeddings).squeeze(0).cpu()

            # Slice to the real squads BEFORE masking: the mask has length
            # num_squads, so indexing longer (padded) scores with it fails.
            masked_scores = (raw_scores / temperature)[:num_squads].clone()
            masked_scores[~squad_alive_mask] = float("-inf")
            win_probs = torch.softmax(masked_scores, dim=0)

            # Parse zone info
            bluezone_parsed = parse_zone_info(time_data['bluezone_info'].iloc[0]) if 'bluezone_info' in time_data.columns else None
            whitezone_parsed = parse_zone_info(time_data['whitezone_info'].iloc[0]) if 'whitezone_info' in time_data.columns else None
            phase = int(time_data['phase'].iloc[0]) if 'phase' in time_data.columns else None

            # Build probabilities dict; dead squads are reported as exactly 0.0
            # (this also hides the NaNs a softmax over all -inf would produce).
            probabilities = {}
            is_alive = {}
            for i, squad_num in enumerate(squad_numbers):
                if i < len(win_probs):
                    probabilities[int(squad_num)] = float(win_probs[i].item()) if squad_alive_mask[i] else 0.0
                is_alive[int(squad_num)] = bool(squad_alive_mask[i])

            predictions[str(time_point)] = {
                'phase': phase,
                'squad_numbers': [int(sn) for sn in squad_numbers],
                'alive_cnt': alive_cnt,
                'hp': hp,
                'bluezone_info': {'x': bluezone_parsed[0], 'y': bluezone_parsed[1], 'radius': bluezone_parsed[2]} if bluezone_parsed else None,
                'whitezone_info': {'x': whitezone_parsed[0], 'y': whitezone_parsed[1], 'radius': whitezone_parsed[2]} if whitezone_parsed else None,
                'positions': positions_list,
                'probabilities': probabilities,
                'is_alive': is_alive,
            }

    if match_id is None:
        # The dataset produced no groups; fall back to the raw CSV so the
        # caller still gets a well-formed (empty) result.
        if raw_df.empty:
            raise ValueError(f"No rows found in {csv_path}")
        match_id = raw_df['match_id'].iloc[0]

    return {'match_id': match_id, 'predictions': predictions}


print("Transformer inference function defined.")


In [0]:
def extract_lgbm_features(df: pd.DataFrame, scaler_params: Dict, all_features: List[str]) -> pd.DataFrame:
    """Extract and scale features for LGBM model.

    Works on a copy of ``df`` and adds (or overwrites) these columns:
    ``team_center_{x,y,z}``, ``team_std_{x,y,z}``, ``bluezone_{x,y}``,
    ``whitezone_{x,y}``, ``dist_from_bluezone_v2`` and
    ``dist_from_whitezone_v2``. A value that cannot be computed for a row
    (missing column, unparsable text, no member with a position) stays 0.0.

    Args:
        df: Raw rows; ``positions``, ``bluezone_info`` and ``whitezone_info``
            are read when present.
        scaler_params: Optional dict with ``mean`` and ``std`` mappings. When
            truthy, every CONTINUOUS_FEATURES column present in both mappings
            is standardised as ``(value - mean) / std``.
        all_features: Currently unused; kept for call compatibility.

    Returns:
        The augmented copy of ``df``.
    """
    df = df.copy()
    n_rows = len(df)

    # Position features
    pos_features = ['team_center_x', 'team_center_y', 'team_center_z',
                    'team_std_x', 'team_std_y', 'team_std_z',
                    'bluezone_x', 'bluezone_y', 'whitezone_x', 'whitezone_y']
    # Distance features
    dist_features = ['dist_from_bluezone_v2', 'dist_from_whitezone_v2']

    # Values are accumulated positionally and written back once per column,
    # instead of one df.loc scalar write per cell.
    derived = {col: np.zeros(n_rows) for col in pos_features + dist_features}

    def _column_values(name):
        return df[name].tolist() if name in df.columns else None

    positions_col = _column_values('positions')
    zone_cols = (
        ('bluezone', _column_values('bluezone_info')),
        ('whitezone', _column_values('whitezone_info')),
    )

    for i in range(n_rows):
        team_center_xy = None

        # Parse positions
        if positions_col is not None:
            try:
                positions = np.array(parse_positions(positions_col[i]))
                alive_positions = positions[~np.isnan(positions[:, 0])]

                if len(alive_positions) > 0:
                    center = np.nanmean(alive_positions, axis=0)
                    team_center_xy = center[:2]

                    derived['team_center_x'][i] = center[0] / POSITION_SCALE
                    derived['team_center_y'][i] = center[1] / POSITION_SCALE
                    derived['team_center_z'][i] = center[2] / POSITION_SCALE

                    # Spread is only defined for two or more members.
                    if len(alive_positions) > 1:
                        std = np.nanstd(alive_positions, axis=0) / POSITION_SCALE
                        derived['team_std_x'][i] = std[0]
                        derived['team_std_y'][i] = std[1]
                        derived['team_std_z'][i] = std[2]
            except Exception:
                # Best effort: an unparsable row keeps its 0.0 defaults.
                # (Exception, not a bare except, so Ctrl-C still works.)
                pass

        # Parse zone info and compute distance features
        for zone_name, zone_values in zone_cols:
            if zone_values is None:
                continue
            try:
                zone = parse_zone_info(zone_values[i])
            except Exception:
                # Best effort, same as above.
                continue

            derived[f'{zone_name}_x'][i] = zone[0] / ZONE_SCALE
            derived[f'{zone_name}_y'][i] = zone[1] / ZONE_SCALE

            if team_center_xy is not None and zone[2] > 0:
                dist = np.sqrt((team_center_xy[0] - zone[0])**2 + (team_center_xy[1] - zone[1])**2)
                derived[f'dist_from_{zone_name}_v2'][i] = dist / zone[2]

    for col, values in derived.items():
        df[col] = values

    # Apply scaling to continuous features
    if scaler_params:
        mean = scaler_params.get('mean', {})
        std = scaler_params.get('std', {})
        for col in CONTINUOUS_FEATURES:
            if col in df.columns and col in mean and col in std:
                df[col] = (df[col] - mean[col]) / std[col]

    return df


def run_lgbm_inference(csv_path: str, model: lgb.Booster, config: Dict) -> Dict[str, Any]:
    """Score one CSV file with the LGBM model and turn scores into win probabilities.

    Per time point, the model outputs of squads considered alive are passed
    through a softmax; dead squads get probability 0.0.

    Returns:
        ``{'match_id': ..., 'predictions': {str(time_point): {...}}}``.
    """
    frame = pd.read_csv(csv_path)
    match_id = frame['match_id'].iloc[0]

    feature_names = config.get('features', CONTINUOUS_FEATURES)
    # No scaler is applied here (original note: LGBM handles scaling
    # internally during training).
    frame = extract_lgbm_features(frame, None, feature_names)

    # Predict survival time (higher = longer survival), using only the
    # configured features that exist as columns.
    usable_features = [name for name in feature_names if name in frame.columns]
    frame['pred'] = model.predict(frame[usable_features])

    def _zone_dict(parsed):
        if parsed:
            return {'x': parsed[0], 'y': parsed[1], 'radius': parsed[2]}
        return None

    predictions = {}
    for time_point, group in frame.groupby('time_point'):
        group = group.sort_values('squad_number')
        squad_numbers = group['squad_number'].tolist()
        n_squads = len(squad_numbers)

        if 'squad_alive_count' in group.columns:
            alive_cnt = [int(v) for v in group['squad_alive_count'].tolist()]
        else:
            alive_cnt = [None] * n_squads

        if 'squad_total_health' in group.columns:
            hp = [int(v) for v in group['squad_total_health'].tolist()]
        else:
            hp = [None] * n_squads

        positions_list = []
        if 'positions' in group.columns:
            positions_list = [parse_positions(s) for s in group['positions'].tolist()]

        # A squad is alive when it has members, health, and at least one
        # member position that is not all-NaN.
        squad_alive_mask = []
        for i in range(n_squads):
            members = 0 if alive_cnt[i] is None else alive_cnt[i]
            health = 0 if hp[i] is None else hp[i]
            has_position = i < len(positions_list) and any(
                not (np.isnan(p[0]) and np.isnan(p[1]) and np.isnan(p[2]))
                for p in positions_list[i]
            )
            squad_alive_mask.append((members > 0) and (health > 0) and has_position)

        # Softmax over alive squads only (higher pred = higher prob).
        scores = np.where(squad_alive_mask, group['pred'].values, -np.inf)
        probs = softmax(scores)

        probabilities = {}
        is_alive = {}
        for squad_num, prob, alive in zip(squad_numbers, probs, squad_alive_mask):
            probabilities[int(squad_num)] = float(prob) if alive else 0.0
            is_alive[int(squad_num)] = alive

        bluezone_parsed = None
        if 'bluezone_info' in group.columns:
            bluezone_parsed = parse_zone_info(group['bluezone_info'].iloc[0])
        whitezone_parsed = None
        if 'whitezone_info' in group.columns:
            whitezone_parsed = parse_zone_info(group['whitezone_info'].iloc[0])
        phase = None
        if 'phase' in group.columns:
            phase = int(group['phase'].iloc[0])

        predictions[str(time_point)] = {
            'phase': phase,
            'squad_numbers': [int(sn) for sn in squad_numbers],
            'alive_cnt': alive_cnt,
            'hp': hp,
            'bluezone_info': _zone_dict(bluezone_parsed),
            'whitezone_info': _zone_dict(whitezone_parsed),
            'positions': positions_list,
            'probabilities': probabilities,
            'is_alive': is_alive,
        }

    return {'match_id': match_id, 'predictions': predictions}


print("LGBM inference function defined.")


In [0]:
def run_rule_based_inference(csv_path: str, feature_name: str) -> Dict[str, Any]:
    """
    Run rule-based inference using a single feature.

    Per time point, the (possibly negated) feature value of every squad
    considered alive is passed through a softmax; dead squads and squads
    whose feature value is NaN get probability 0.0 / zero weight.

    Args:
        csv_path: Path to CSV file.
        feature_name: Feature to use for prediction.
            - 'squad_total_health': higher is better
            - 'dist_from_whitezone_v2': lower is better (closer to next safe zone)
            - 'dist_from_bluezone_v2': lower is better (closer to current safe zone)
            The two distance features are computed here from the
            ``positions`` and zone columns; any other name must already be
            a column of the CSV.

    Returns:
        ``{'match_id': ..., 'predictions': {str(time_point): {...}}}``.

    Raises:
        KeyError: If a required column is missing from the CSV.
    """
    df = pd.read_csv(csv_path)
    match_id = df['match_id'].iloc[0]

    # Distance feature -> zone column it is measured against.
    zone_column_for = {
        'dist_from_bluezone_v2': 'bluezone_info',
        'dist_from_whitezone_v2': 'whitezone_info',
    }
    is_distance_rule = feature_name in zone_column_for

    # Compute the distance feature once for the whole file (single column
    # assignment instead of one df.loc write per row).
    if is_distance_rule:
        zone_column = zone_column_for[feature_name]
        df[feature_name] = [
            compute_dist_from_zone_v2(positions, zone)
            for positions, zone in zip(df['positions'], df[zone_column])
        ]

    predictions = {}
    for time_point, tp_df in df.groupby('time_point'):
        tp_df = tp_df.sort_values('squad_number')
        squad_numbers = tp_df['squad_number'].tolist()

        if 'squad_alive_count' in tp_df.columns:
            alive_cnt = [int(v) for v in tp_df['squad_alive_count'].tolist()]
        else:
            alive_cnt = [None] * len(squad_numbers)
        if 'squad_total_health' in tp_df.columns:
            hp = [int(v) for v in tp_df['squad_total_health'].tolist()]
        else:
            hp = [None] * len(squad_numbers)

        positions_list = []
        if 'positions' in tp_df.columns:
            positions_list = [parse_positions(s) for s in tp_df['positions'].tolist()]

        # Build alive mask: members, health, and at least one member with a
        # recorded (not all-NaN) position.
        squad_alive_mask = []
        for i in range(len(squad_numbers)):
            alive_count = alive_cnt[i] if alive_cnt[i] is not None else 0
            health = hp[i] if hp[i] is not None else 0
            has_position = i < len(positions_list) and any(
                not (np.isnan(p[0]) and np.isnan(p[1]) and np.isnan(p[2]))
                for p in positions_list[i]
            )
            squad_alive_mask.append((alive_count > 0) and (health > 0) and has_position)

        # Get feature values and apply rule
        feature_values = tp_df[feature_name].values.copy()

        # For distance features, lower is better (use negative)
        # For squad_total_health, higher is better (keep as is)
        if is_distance_rule:
            feature_values = -feature_values  # Negate so higher = better

        # Handle NaN values: -inf gives them zero weight in the softmax
        feature_values = np.nan_to_num(feature_values, nan=-np.inf)

        # Mask dead squads
        masked_values = np.where(squad_alive_mask, feature_values, -np.inf)
        probs = softmax(masked_values)

        probabilities = {}
        is_alive = {}
        for i, squad_num in enumerate(squad_numbers):
            probabilities[int(squad_num)] = float(probs[i]) if squad_alive_mask[i] else 0.0
            is_alive[int(squad_num)] = squad_alive_mask[i]

        bluezone_parsed = parse_zone_info(tp_df['bluezone_info'].iloc[0]) if 'bluezone_info' in tp_df.columns else None
        whitezone_parsed = parse_zone_info(tp_df['whitezone_info'].iloc[0]) if 'whitezone_info' in tp_df.columns else None
        phase = int(tp_df['phase'].iloc[0]) if 'phase' in tp_df.columns else None

        predictions[str(time_point)] = {
            'phase': phase,
            'squad_numbers': [int(sn) for sn in squad_numbers],
            'alive_cnt': alive_cnt,
            'hp': hp,
            'bluezone_info': {'x': bluezone_parsed[0], 'y': bluezone_parsed[1], 'radius': bluezone_parsed[2]} if bluezone_parsed else None,
            'whitezone_info': {'x': whitezone_parsed[0], 'y': whitezone_parsed[1], 'radius': whitezone_parsed[2]} if whitezone_parsed else None,
            'positions': positions_list,
            'probabilities': probabilities,
            'is_alive': is_alive,
        }

    return {'match_id': match_id, 'predictions': predictions}


print("Rule-based inference function defined.")


## 5. JSON Saving


In [0]:
def save_predictions_to_json(result: Dict[str, Any], output_dir: str, model_name: str) -> str:
    """
    Save predictions to JSON file.

    The file is named ``{match_id}_{model_name}.json`` and holds the keys
    ``match_id``, ``model_name`` and ``predictions``. numpy scalars and
    arrays found among the values are converted to plain Python values.

    Args:
        result: Dict with 'match_id' and 'predictions' keys.
        output_dir: Output directory path (created if missing).
        model_name: Model name for filename.

    Returns:
        Path to saved JSON file.

    Raises:
        TypeError: If a value is neither JSON-serializable nor a numpy
            scalar/array.
    """
    def _to_builtin(value):
        # Called by json only for values it cannot encode itself, e.g. a
        # numpy integer match_id read via DataFrame.iloc.
        if isinstance(value, np.generic):
            return value.item()
        if isinstance(value, np.ndarray):
            return value.tolist()
        raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")

    os.makedirs(output_dir, exist_ok=True)

    match_id = result['match_id']
    json_filename = f"{match_id}_{model_name}.json"
    json_path = os.path.join(output_dir, json_filename)

    output_data = {
        'match_id': match_id,
        'model_name': model_name,
        'predictions': result['predictions'],
    }

    # Serialize fully before opening the file so a serialization error
    # cannot leave a truncated JSON file behind.
    payload = json.dumps(output_data, indent=2, ensure_ascii=False, default=_to_builtin)
    with open(json_path, 'w', encoding='utf-8') as f:
        f.write(payload)

    return json_path


print("JSON saving function defined.")


## 6. Main Execution Loop


In [0]:
# Load every Transformer model that has a checkpoint and keep its parts
# together under the model's name.
transformer_models = {}
for model_name, checkpoint_path in model_checkpoints.items():
    print(f"Loading {model_name}...")
    backbone, head, scaler_params, config, temperature = load_transformer_model(checkpoint_path, DEVICE)
    transformer_models[model_name] = dict(
        backbone=backbone,
        head=head,
        scaler_params=scaler_params,
        config=config,
        temperature=temperature,
    )
    print(f"  Loaded: embed_dim={config['embed_dim']}, temperature={temperature:.4f}")

print(f"\nLoaded {len(transformer_models)} Transformer models")


In [0]:
# Load the LGBM model if its file exists; otherwise leave the defaults
# (no model, empty config) in place.
lgbm_model, lgbm_config = None, {}
if not os.path.exists(LGBM_MODEL_PATH):
    print(f"LGBM model not found at: {LGBM_MODEL_PATH}")
else:
    print("Loading LGBM model...")
    lgbm_model, lgbm_config = load_lgbm_model(LGBM_MODEL_PATH)
    print(f"  Features: {len(lgbm_config.get('features', []))}")


In [0]:
def run_all_inference(csv_files: List[str], output_dir: str, *, run_lgbm: bool = False):
    """Run inference with all models on all CSV files.

    For each CSV file, every loaded Transformer model and the three
    rule-based baselines are run and each result is saved as its own JSON
    file. A failure of one model on one file is printed and skipped so the
    remaining work still runs.

    Args:
        csv_files: Paths of the CSV files to process.
        output_dir: Directory the JSON files are written to (created if missing).
        run_lgbm: Also run the LGBM model when one is loaded. Off by
            default, which matches the previous behaviour where the LGBM
            step was disabled.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Rule-based models: (output model name, feature used for ranking)
    rule_based_configs = [
        ('RuleBased_HP', 'squad_total_health'),
        ('RuleBased_WhiteZone', 'dist_from_whitezone_v2'),
        ('RuleBased_BlueZone', 'dist_from_bluezone_v2'),
    ]

    use_lgbm = run_lgbm and lgbm_model is not None
    if run_lgbm and lgbm_model is None:
        print("LGBM requested but no LGBM model is loaded; skipping it.")

    # Define all models to run (only those that will actually be executed)
    all_models = list(transformer_models.keys())
    if use_lgbm:
        all_models.append('LGBM')
    all_models.extend(name for name, _ in rule_based_configs)

    print(f"Models to run: {all_models}")
    print(f"CSV files: {len(csv_files)}")
    print(f"Output dir: {output_dir}")
    print("=" * 60)

    for csv_path in tqdm(csv_files, desc="Processing CSV files"):
        csv_basename = os.path.basename(csv_path)

        # Transformer models
        for model_name, model_data in transformer_models.items():
            try:
                result = run_transformer_inference(
                    csv_path=csv_path,
                    backbone=model_data['backbone'],
                    head=model_data['head'],
                    scaler_params=model_data['scaler_params'],
                    config=model_data['config'],
                    device=DEVICE,
                    temperature=model_data['temperature'],
                )
                save_predictions_to_json(result, output_dir, model_name)
            except Exception as e:
                print(f"  Error {model_name} on {csv_basename}: {e}")

        # LGBM model (opt-in)
        if use_lgbm:
            try:
                result = run_lgbm_inference(csv_path, lgbm_model, lgbm_config)
                save_predictions_to_json(result, output_dir, 'LGBM')
            except Exception as e:
                print(f"  Error LGBM on {csv_basename}: {e}")

        # Rule-based models
        for model_name, feature_name in rule_based_configs:
            try:
                result = run_rule_based_inference(csv_path, feature_name)
                save_predictions_to_json(result, output_dir, model_name)
            except Exception as e:
                print(f"  Error {model_name} on {csv_basename}: {e}")

    print("=" * 60)
    print("Inference complete!")
    print(f"Output directory: {output_dir}")


print("Main execution function defined.")


In [0]:
# Run inference on all CSV files
run_all_inference(csv_files, OUTPUT_DIR)


## 7. Verify Output


In [0]:
# List generated JSON files
output_json_files = sorted(glob.glob(os.path.join(OUTPUT_DIR, "*.json")))
print(f"Total JSON files generated: {len(output_json_files)}")

# Group by model.
# Files are named "{match_id}_{model_name}.json" and model names may contain
# underscores themselves (e.g. "RuleBased_HP"), so known names are matched as
# suffixes, longest first, before falling back to "text after the last '_'".
known_model_names = sorted(
    set(CHECKPOINT_PATHS) | {'LGBM', 'RuleBased_HP', 'RuleBased_WhiteZone', 'RuleBased_BlueZone'},
    key=len,
    reverse=True,
)

model_counts = {}
for f in output_json_files:
    stem = os.path.basename(f)
    if stem.endswith('.json'):
        stem = stem[:-len('.json')]
    model_name = next(
        (name for name in known_model_names if stem.endswith(f"_{name}")),
        stem.rsplit('_', 1)[-1],
    )
    model_counts[model_name] = model_counts.get(model_name, 0) + 1

print("\nFiles per model:")
for model_name, count in sorted(model_counts.items()):
    print(f"  {model_name}: {count}")


In [0]:
# Examine a sample JSON file
if output_json_files:
    sample_json_path = output_json_files[0]
    # The files are written as UTF-8 with ensure_ascii=False, so read them
    # back with the same encoding instead of the platform default.
    with open(sample_json_path, 'r', encoding='utf-8') as f:
        sample_data = json.load(f)

    print(f"Sample JSON: {os.path.basename(sample_json_path)}")
    print(f"Match ID: {sample_data['match_id']}")
    print(f"Model: {sample_data['model_name']}")
    print(f"Time points: {len(sample_data['predictions'])}")

    # Show first time point structure (a file may hold no time points at all)
    first_tp = next(iter(sample_data['predictions']), None)
    if first_tp is None:
        print("\nNo time points in this file.")
    else:
        tp_data = sample_data['predictions'][first_tp]
        print(f"\nFirst time point ({first_tp}):")
        print(f"  Phase: {tp_data['phase']}")
        print(f"  Squads: {len(tp_data['squad_numbers'])}")
        print(f"  Alive squads: {sum(tp_data['is_alive'].values())}")
        print(f"  Keys: {list(tp_data.keys())}")
