In [5]:
#!/usr/bin/env python3
# ===============================================
# PP_14 - PLAYER EMBEDDINGS (GNN ULTIMATE GOD SOTA v2)
# FIXED VERSION - NaN fix + stability
# ===============================================
#
# v2 FIXES:
# ✅ Learning rate reduced (0.0001)
# ✅ Edge weights clamped (age capped at 10 years, floor 1e-3)
# ✅ NaN detection + handling
# ✅ Stricter gradient clipping (max norm 0.5)
# ✅ Learning-rate warmup
# ✅ More conservative initialization
#
# Output: features/player_embeddings/
# ===============================================

import numpy as np
import polars as pl
from pathlib import Path
from datetime import datetime
from collections import defaultdict
import warnings
warnings.filterwarnings("ignore")

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

try:
    from sklearn.metrics import roc_auc_score
    HAS_SKLEARN = True
except ImportError:
    HAS_SKLEARN = False

# ===============================================
# CONFIGURATION
# ===============================================
# Filesystem layout (hard-coded absolute Windows path of the author's machine).
ROOT = Path(r"C:\Users\Administrateur\Tennis POLAR v2")
DATA_CLEAN = ROOT / "data_clean"
MATCHES_BASE = DATA_CLEAN / "matches_base"  # read with pl.read_parquet() in main()
OUTPUT_DIR = DATA_CLEAN / "features" / "player_embeddings"
# Side effect at import/cell-execution time: the output directory is created immediately.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Model parameters
EMBEDDING_DIM = 64   # size of the final per-player embedding (GCN output)
HIDDEN_DIM = 128     # hidden width of the GCN; the player-ID embedding uses HIDDEN_DIM // 2
NUM_LAYERS = 2       # number of residual GCNConv blocks
DROPOUT = 0.2

# Training parameters - tuned for stability
EPOCHS = 200
LEARNING_RATE = 0.0001      # reduced from 0.001 to 0.0001
SAMPLES_PER_EPOCH = 200_000  # labelled pairs sampled per epoch (one optimizer step per epoch); reduced for stability
WEIGHT_DECAY = 1e-5          # AdamW weight decay
PATIENCE = 40               # epochs without a >0.001 val-AUC gain before early stopping
WARMUP_EPOCHS = 10          # length of the linear learning-rate warmup, in epochs

# Temporal split
TRAIN_RATIO = 0.9  # first 90% of date-sorted matches are training, the rest validation

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("=" * 70)
print("   PP_14 - PLAYER EMBEDDINGS (GNN ULTIMATE GOD SOTA v2)")
print("=" * 70)
print(f"   {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
print(f"   Device: {DEVICE}")
print(f"   Architecture: GCN + Ranking + STABILIZED")
print(f"   Learning rate: {LEARNING_RATE}")
print("=" * 70)


# ===============================================
# GCN MODEL (avec initialisation conservatrice)
# ===============================================

class PlayerGCN(nn.Module):
    """Encode every player node of the match graph into an embedding vector.

    Pipeline (see ``forward``): concatenate each node's feature row with a
    learned per-player ID embedding, project to ``hidden_channels``
    (ReLU + dropout), apply ``num_layers`` residual blocks of
    GCNConv -> LayerNorm -> ReLU -> dropout, then project to
    ``out_channels`` and clamp the result to [-10, 10].

    Args:
        num_players: number of rows in the player-ID embedding table.
        in_channels: width of the node feature rows passed to ``forward``.
        hidden_channels: hidden width; the ID embedding uses half of it.
        out_channels: size of the output embedding.
        num_layers: number of residual GCN blocks.
        dropout: dropout probability (only active in training mode).
    """

    def __init__(self, num_players: int, in_channels: int, hidden_channels: int, 
                 out_channels: int, num_layers: int = 2, dropout: float = 0.2):
        super().__init__()
        
        self.dropout = dropout
        
        # Learned per-player ID embedding, initialised with a small std (0.01)
        # so it starts close to zero and does not dominate the hand-made features.
        self.player_emb = nn.Embedding(num_players, hidden_channels // 2)
        nn.init.normal_(self.player_emb.weight, mean=0, std=0.01)  # small std
        
        # Input projection over [features | ID embedding]; xavier gain 0.5 gives
        # smaller initial weights than the default gain of 1.
        self.input_proj = nn.Linear(in_channels + hidden_channels // 2, hidden_channels)
        nn.init.xavier_uniform_(self.input_proj.weight, gain=0.5)
        nn.init.zeros_(self.input_proj.bias)
        
        # One GCNConv + LayerNorm pair per residual block; width stays
        # hidden_channels so the residual addition in forward() is shape-valid.
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(num_layers):
            conv = GCNConv(hidden_channels, hidden_channels)
            self.convs.append(conv)
            self.norms.append(nn.LayerNorm(hidden_channels))
        
        self.output_proj = nn.Linear(hidden_channels, out_channels)
        nn.init.xavier_uniform_(self.output_proj.weight, gain=0.5)
        nn.init.zeros_(self.output_proj.bias)
    
    def forward(self, x, edge_index, edge_weight, player_indices):
        """Return one embedding row per node, clamped to [-10, 10].

        ``x`` and ``self.player_emb(player_indices)`` are concatenated on the
        last dimension, so their leading dimensions must agree. Callers in this
        file pass ``player_indices = torch.arange(num_players)``, i.e. row i of
        ``x`` belongs to player index i. ``edge_index`` and ``edge_weight`` are
        forwarded unchanged to every ``GCNConv``.
        """
        player_emb = self.player_emb(player_indices)
        x = torch.cat([x, player_emb], dim=-1)
        
        x = self.input_proj(x)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        
        for conv, norm in zip(self.convs, self.norms):
            x_res = x
            x = conv(x, edge_index, edge_weight)
            x = norm(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = x + x_res  # residual connection
        
        out = self.output_proj(x)
        
        # Clamp outputs to bound extreme values. Note that torch.clamp passes a
        # zero gradient to elements that fall outside [-10, 10].
        out = torch.clamp(out, -10, 10)
        
        return out


class RankingPredictor(nn.Module):
    """Pairwise win predictor built on a shared scalar scoring network.

    Both players are scored by the same MLP, and the logit is the score
    difference, clamped to [-20, 20]. As a result, swapping the two inputs
    negates the logit, and identical inputs give a logit of 0.
    """

    def __init__(self, embedding_dim: int):
        super().__init__()
        # Layer order fixes the Sequential indices (0, 3, 5 for the Linear
        # layers), which are the state_dict keys of saved checkpoints.
        layers = [
            nn.Linear(embedding_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(64, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        ]
        self.score_net = nn.Sequential(*layers)

        # Conservative init: small xavier weights (gain 0.5) and zero biases.
        linear_layers = (layer for layer in self.score_net if isinstance(layer, nn.Linear))
        for linear in linear_layers:
            nn.init.xavier_uniform_(linear.weight, gain=0.5)
            nn.init.zeros_(linear.bias)

    def forward(self, emb_a, emb_b):
        """Return the clamped logit that player A beats player B."""
        score_a, score_b = (self.score_net(emb).squeeze(-1) for emb in (emb_a, emb_b))
        # Clamping bounds the logits for numerical stability.
        return (score_a - score_b).clamp(min=-20, max=20)


# ===============================================
# DATA PREPARATION
# ===============================================

def prepare_data_zero_leakage(matches_df: pl.DataFrame, min_matches: int = 10):
    """Build the player graph, node features and pairwise labels without leakage.

    Matches are sorted by ``tourney_date_ta`` and split chronologically: the
    first ``TRAIN_RATIO`` fraction is the training period, the remainder is
    validation. The player set, the graph edges and the node features are all
    derived from the training period only.

    Args:
        matches_df: one row per match. Required columns: ``winner_id``,
            ``loser_id``, ``tourney_date_ta`` (date-like; subtraction must
            yield an object with ``.days``). Optional columns used for
            features: ``tourney_surface_ta`` / ``surface``, ``w_s_ace_p``,
            ``l_s_ace_p``, ``w_s_1stWon_p``, ``l_s_1stWon_p``.
        min_matches: minimum number of training-period matches a player needs
            in order to become a graph node.

    Returns:
        dict with ``edge_index`` (2, E) long tensor, ``edge_weight`` (E,)
        float tensor, ``node_features`` (num_players, 8) float tensor,
        ``player_to_idx`` / ``idx_to_player`` mappings, ``num_players``, and
        int64 label arrays ``train_a/train_b/train_y`` and
        ``val_a/val_b/val_y``. Each kept match yields two rows:
        (winner, loser, 1) and (loser, winner, 0).

    Raises:
        ValueError: if the training period is empty or no player reaches
            ``min_matches``.
    """
    print("\n[1/4] Splitting data chronologically...")

    matches_sorted = matches_df.sort("tourney_date_ta")
    # Drop rows lacking ids or a date. NOTE: this is a truthiness test, so
    # falsy ids such as 0 or "" are dropped as well.
    matches_list = [
        m for m in matches_sorted.to_dicts()
        if m["winner_id"] and m["loser_id"] and m["tourney_date_ta"]
    ]

    split_idx = int(len(matches_list) * TRAIN_RATIO)
    train_matches_raw = matches_list[:split_idx]
    val_matches_raw = matches_list[split_idx:]

    print(f"  Total: {len(matches_list):,} | Train: {len(train_matches_raw):,} | Val: {len(val_matches_raw):,}")

    if not train_matches_raw:
        raise ValueError(
            "No training matches after the chronological split; "
            "cannot build the player graph."
        )

    # ===== PLAYERS FROM TRAIN =====
    print("\n[2/4] Building player set...")

    player_count = defaultdict(int)
    for m in train_matches_raw:
        player_count[m["winner_id"]] += 1
        player_count[m["loser_id"]] += 1

    valid_players = {p for p, c in player_count.items() if c >= min_matches}
    # Sorted ids give a deterministic node numbering across runs.
    player_to_idx = {p: i for i, p in enumerate(sorted(valid_players))}
    idx_to_player = {i: p for p, i in player_to_idx.items()}
    num_players = len(valid_players)

    print(f"  Players: {num_players:,}")

    if num_players == 0:
        raise ValueError(
            f"No player has at least {min_matches} training matches; "
            "cannot build the player graph."
        )

    # ===== GRAPH WITH CLAMPED WEIGHTS =====
    print("\n[3/4] Building graph (clamped weights)...")

    edge_list = []
    edge_weights = []
    max_train_date = max(m["tourney_date_ta"] for m in train_matches_raw)

    for m in train_matches_raw:
        w_id, l_id = m["winner_id"], m["loser_id"]
        if w_id not in valid_players or l_id not in valid_players:
            continue

        w_idx, l_idx = player_to_idx[w_id], player_to_idx[l_id]

        # Recency weight exp(-age / 10 years). The age is capped at 3650 days,
        # so the weight never drops below exp(-1) ~= 0.368; the 1e-3 floor is
        # only a safety net.
        days_ago = min((max_train_date - m["tourney_date_ta"]).days, 3650)
        weight = max(np.exp(-days_ago / 3650), 1e-3)

        # Undirected graph: store both directions with the same weight.
        edge_list.extend([[w_idx, l_idx], [l_idx, w_idx]])
        edge_weights.extend([weight, weight])

    if edge_list:
        edge_index = torch.tensor(edge_list, dtype=torch.long).t().contiguous()
        edge_weight = torch.tensor(edge_weights, dtype=torch.float)
        # Rescale so the largest weight is exactly 1.
        edge_weight = edge_weight / edge_weight.max()
    else:
        # No match between two valid players: keep well-formed empty tensors
        # instead of failing on .max() of an empty tensor.
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_weight = torch.empty(0, dtype=torch.float)

    print(f"  Edges: {edge_index.shape[1]:,}")
    if edge_weight.numel() > 0:
        print(f"  Weight range: [{edge_weight.min():.4f}, {edge_weight.max():.4f}]")

    # ===== FEATURES =====
    print("\n[4/4] Computing features...")

    player_stats = defaultdict(lambda: {
        "wins": 0, "losses": 0, "ace_rate": [], "first_won": [],
        "surfaces": defaultdict(int)
    })

    for m in train_matches_raw:
        w_id, l_id = m["winner_id"], m["loser_id"]
        surface = m.get("tourney_surface_ta") or m.get("surface", "Hard")

        for pid, is_winner in [(w_id, True), (l_id, False)]:
            if pid not in valid_players:
                continue
            stats = player_stats[pid]
            stats["wins" if is_winner else "losses"] += 1
            stats["surfaces"][surface] += 1

            prefix = "w_" if is_winner else "l_"
            # Keep only finite per-match rates. A single NaN/inf used to make
            # this player's mean non-finite; the column std then became NaN,
            # `std > 1e-6` was False, and the feature was zeroed for everyone.
            ace = m.get(f"{prefix}s_ace_p")
            if ace is not None and np.isfinite(ace):
                stats["ace_rate"].append(ace)
            first_won = m.get(f"{prefix}s_1stWon_p")
            if first_won is not None and np.isfinite(first_won):
                stats["first_won"].append(first_won)

    # Columns: win rate, log1p(matches), mean ace rate (default 0.05), mean
    # 1st-serve-won rate (default 0.65), share of Hard/Clay/Grass/Carpet matches.
    X = np.zeros((num_players, 8), dtype=np.float32)
    for pid, idx in player_to_idx.items():
        s = player_stats[pid]
        total = s["wins"] + s["losses"]
        if total > 0:
            X[idx, 0] = s["wins"] / total
            X[idx, 1] = np.log1p(total)
            X[idx, 2] = np.mean(s["ace_rate"]) if s["ace_rate"] else 0.05
            X[idx, 3] = np.mean(s["first_won"]) if s["first_won"] else 0.65
            X[idx, 4] = s["surfaces"].get("Hard", 0) / total
            X[idx, 5] = s["surfaces"].get("Clay", 0) / total
            X[idx, 6] = s["surfaces"].get("Grass", 0) / total
            X[idx, 7] = s["surfaces"].get("Carpet", 0) / total

    # Z-score each column; (near-)constant columns are set to 0.
    for i in range(X.shape[1]):
        col = X[:, i]
        std = col.std()
        if std > 1e-6:
            X[:, i] = (col - col.mean()) / std
        else:
            X[:, i] = 0  # Constant column

    # Clip extreme z-scores.
    X = np.clip(X, -5, 5)

    node_features = torch.tensor(X, dtype=torch.float)
    print(f"  Features shape: {node_features.shape}")

    # ===== LABELS =====
    def make_label_arrays(matches):
        """Return (a, b, y) int64 arrays; both orientations of every match
        between two valid players, with y = 1 when ``a`` is the winner."""
        labels_list = []
        for m in matches:
            w_id, l_id = m["winner_id"], m["loser_id"]
            if w_id in valid_players and l_id in valid_players:
                w_idx, l_idx = player_to_idx[w_id], player_to_idx[l_id]
                labels_list.append((w_idx, l_idx, 1))
                labels_list.append((l_idx, w_idx, 0))

        n = len(labels_list)
        if n == 0:
            return np.array([], dtype=np.int64), np.array([], dtype=np.int64), np.array([], dtype=np.int64)

        arr_a = np.fromiter((x[0] for x in labels_list), dtype=np.int64, count=n)
        arr_b = np.fromiter((x[1] for x in labels_list), dtype=np.int64, count=n)
        arr_y = np.fromiter((x[2] for x in labels_list), dtype=np.int64, count=n)

        return arr_a, arr_b, arr_y

    train_a, train_b, train_y = make_label_arrays(train_matches_raw)
    val_a, val_b, val_y = make_label_arrays(val_matches_raw)

    print(f"\n  Train labels: {len(train_a):,}")
    print(f"  Val labels: {len(val_a):,}")

    return {
        "edge_index": edge_index,
        "edge_weight": edge_weight,
        "node_features": node_features,
        "player_to_idx": player_to_idx,
        "idx_to_player": idx_to_player,
        "num_players": num_players,
        "train_a": train_a, "train_b": train_b, "train_y": train_y,
        "val_a": val_a, "val_b": val_b, "val_y": val_y,
    }


# ===============================================
# TRAINING (STABILIZED)
# ===============================================

def compute_auc(y_true, y_proba):
    """Return the ROC AUC of ``y_proba`` against binary labels ``y_true``.

    Falls back to 0.5 (chance level) in three cases:
    - scikit-learn is unavailable;
    - ``roc_auc_score`` rejects the input with ``ValueError``, e.g. only one
      class is present in ``y_true`` or the scores contain NaN;
    - the score comes back non-finite.

    Other exceptions, including KeyboardInterrupt, now propagate instead of
    being swallowed by a bare ``except``.
    """
    if not HAS_SKLEARN:
        return 0.5
    try:
        auc = roc_auc_score(y_true, y_proba)
    except ValueError:
        return 0.5
    # Guard the early-stopping comparison in train_gnn against a NaN score.
    if not np.isfinite(auc):
        return 0.5
    return auc


def _clone_state_dict(module):
    """Return a detached copy of ``module.state_dict()``.

    ``state_dict()`` returns references to the live parameter/buffer tensors.
    An un-cloned snapshot therefore keeps changing as the optimizer updates
    the model in place.
    """
    return {name: tensor.detach().clone() for name, tensor in module.state_dict().items()}


def train_gnn(data):
    """Train the GCN encoder and the pairwise predictor with early stopping.

    Each epoch does the following:
    - runs one full-graph forward pass;
    - samples ``min(SAMPLES_PER_EPOCH, n_train)`` labelled pairs without
      replacement;
    - takes a single AdamW step, with linear LR warmup over ``WARMUP_EPOCHS``
      and gradient-norm clipping at 0.5;
    - scores the whole validation set in chunks.

    Training stops after ``PATIENCE`` epochs without a >0.001 val-AUC gain,
    or once more than 5 training epochs have hit NaNs.

    Args:
        data: dict produced by ``prepare_data_zero_leakage``.

    Returns:
        (gcn, predictor), with the weights of the best-val-AUC epoch restored
        (the final weights if no epoch ever improved).

    Raises:
        ValueError: if there are no training or no validation labels.
    """
    print("\n" + "=" * 50)
    print("  TRAINING (STABILIZED)")
    print("=" * 50)

    edge_index = data["edge_index"].to(DEVICE)
    edge_weight = data["edge_weight"].to(DEVICE)
    node_features = data["node_features"].to(DEVICE)
    num_players = data["num_players"]
    player_indices = torch.arange(num_players, device=DEVICE)

    train_a, train_b, train_y = data["train_a"], data["train_b"], data["train_y"]
    val_a, val_b, val_y = data["val_a"], data["val_b"], data["val_y"]

    n_train = len(train_a)
    n_val = len(val_a)

    if n_train == 0 or n_val == 0:
        raise ValueError(
            f"Need both training and validation labels (train={n_train}, val={n_val})."
        )

    # samples_per_epoch <= n_train always holds, so sampling without
    # replacement is always possible. The former `replace = samples_per_epoch
    # >= n_train` switched to bootstrap sampling (duplicate pairs, missing
    # ones) exactly when the whole training set fit in one epoch.
    samples_per_epoch = min(SAMPLES_PER_EPOCH, n_train)

    print(f"  Sampling: {samples_per_epoch:,} / {n_train:,}")

    # Models
    gcn = PlayerGCN(
        num_players=num_players,
        in_channels=node_features.shape[1],
        hidden_channels=HIDDEN_DIM,
        out_channels=EMBEDDING_DIM,
        num_layers=NUM_LAYERS,
        dropout=DROPOUT
    ).to(DEVICE)

    predictor = RankingPredictor(EMBEDDING_DIM).to(DEVICE)

    criterion = nn.BCEWithLogitsLoss()
    params = list(gcn.parameters()) + list(predictor.parameters())

    optimizer = torch.optim.AdamW(params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

    # Linear warmup: factor (epoch + 1) / WARMUP_EPOCHS, then constant 1.0.
    def lr_lambda(epoch):
        if epoch < WARMUP_EPOCHS:
            return (epoch + 1) / WARMUP_EPOCHS
        return 1.0

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    print(f"  Params: {sum(p.numel() for p in params):,}")
    print(f"  Warmup epochs: {WARMUP_EPOCHS}")

    best_val_auc = 0
    patience_counter = 0
    best_state = None
    nan_count = 0

    for epoch in range(EPOCHS):
        # ===== TRAINING =====
        gcn.train()
        predictor.train()

        sample_idx = np.random.choice(n_train, size=samples_per_epoch, replace=False)

        batch_a = torch.tensor(train_a[sample_idx], dtype=torch.long, device=DEVICE)
        batch_b = torch.tensor(train_b[sample_idx], dtype=torch.long, device=DEVICE)
        batch_y = torch.tensor(train_y[sample_idx], dtype=torch.float, device=DEVICE)

        embeddings = gcn(node_features, edge_index, edge_weight, player_indices)

        # Skip the step (and eventually abort) if anything turned NaN.
        if torch.isnan(embeddings).any():
            nan_count += 1
            if nan_count > 5:
                print(f"  ‚ö†Ô∏è Too many NaN embeddings, stopping training")
                break
            print(f"  ‚ö†Ô∏è NaN in embeddings at epoch {epoch+1}, skipping...")
            continue

        emb_a = embeddings[batch_a]
        emb_b = embeddings[batch_b]

        logits = predictor(emb_a, emb_b)

        if torch.isnan(logits).any():
            nan_count += 1
            if nan_count > 5:
                print(f"  ‚ö†Ô∏è Too many NaN logits, stopping training")
                break
            print(f"  ‚ö†Ô∏è NaN in logits at epoch {epoch+1}, skipping...")
            continue

        loss = criterion(logits, batch_y)

        if torch.isnan(loss):
            nan_count += 1
            if nan_count > 5:
                print(f"  ‚ö†Ô∏è Too many NaN losses, stopping training")
                break
            print(f"  ‚ö†Ô∏è NaN loss at epoch {epoch+1}, skipping...")
            continue

        optimizer.zero_grad()
        loss.backward()

        # Strict gradient clipping (max total norm 0.5).
        torch.nn.utils.clip_grad_norm_(params, 0.5)

        optimizer.step()
        scheduler.step()

        train_loss = loss.item()
        train_preds = (logits > 0).float()
        train_acc = (train_preds == batch_y).float().mean().item()

        # ===== VALIDATION =====
        gcn.eval()
        predictor.eval()

        with torch.no_grad():
            embeddings = gcn(node_features, edge_index, edge_weight, player_indices)

            if torch.isnan(embeddings).any():
                print(f"  ‚ö†Ô∏è NaN in val embeddings at epoch {epoch+1}")
                continue

            # Score validation pairs in chunks to bound peak memory.
            all_logits = []
            chunk_size = 50_000

            for i in range(0, n_val, chunk_size):
                va = torch.tensor(val_a[i:i+chunk_size], dtype=torch.long, device=DEVICE)
                vb = torch.tensor(val_b[i:i+chunk_size], dtype=torch.long, device=DEVICE)

                logits_chunk = predictor(embeddings[va], embeddings[vb])
                all_logits.append(logits_chunk.cpu())

            all_logits = torch.cat(all_logits)

            if torch.isnan(all_logits).any():
                print(f"  ‚ö†Ô∏è NaN in val logits at epoch {epoch+1}")
                continue

            val_y_tensor = torch.tensor(val_y, dtype=torch.float)

            val_loss = criterion(all_logits, val_y_tensor).item()
            val_preds = (all_logits > 0).float()
            val_acc = (val_preds == val_y_tensor).float().mean().item()

            val_proba = torch.sigmoid(all_logits).numpy()
            val_auc = compute_auc(val_y, val_proba)

        # Print progress
        if (epoch + 1) % 10 == 0:
            lr = optimizer.param_groups[0]['lr']
            print(f"  Epoch {epoch+1:3d}: loss={train_loss:.4f}, acc={train_acc:.4f} | val_loss={val_loss:.4f}, val_acc={val_acc:.4f}, AUC={val_auc:.4f}, lr={lr:.6f}")

        # Early stopping: only a gain of more than 0.001 AUC counts as improvement.
        if val_auc > best_val_auc + 0.001:
            best_val_auc = val_auc
            patience_counter = 0
            # Clone: plain state_dict() references would track later updates,
            # making the final "best" restore a silent no-op.
            best_state = {'gcn': _clone_state_dict(gcn), 'predictor': _clone_state_dict(predictor)}
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print(f"  Early stopping at epoch {epoch+1}")
                break

    if best_state is not None:
        gcn.load_state_dict(best_state['gcn'])
        predictor.load_state_dict(best_state['predictor'])

    print(f"\n  ‚úÖ Best val AUC: {best_val_auc:.4f}")

    return gcn, predictor


# ===============================================
# EXTRACT EMBEDDINGS
# ===============================================

def extract_and_create_features(gcn, data, matches_df):
    """Embed every graph player and derive per-match embedding features.

    The trained GCN is run once over the whole graph in eval mode without
    gradients; NaN entries in the result are replaced by 0. For every row of
    ``matches_df`` the function emits ``custom_match_id`` plus:
    - ``emb_cosine_sim`` (0.0 when either norm is <= 1e-6);
    - ``emb_l2_distance``;
    - ``emb_diff_0..7`` (winner minus loser, first dims);
    - ``emb_norm_winner`` and ``emb_norm_loser``;
    - ``has_embeddings``.

    Rows where either player has no embedding get ``None`` features and
    ``has_embeddings = 0``.

    NOTE(review): features are produced for all matches, including
    training-period matches whose outcomes shaped the embeddings, and they are
    oriented winner-minus-loser. Confirm that downstream consumers account for
    both before using them as predictors.

    Returns:
        (player_embeddings, features_df): a dict mapping player id to its
        embedding row, and a polars DataFrame with one row per match.
    """
    print("\n[Embeddings] Extracting...")

    gcn.eval()

    edge_index, edge_weight, node_features = (
        data[key].to(DEVICE) for key in ("edge_index", "edge_weight", "node_features")
    )
    n_players = data["num_players"]
    id_of = data["idx_to_player"]
    all_nodes = torch.arange(n_players, device=DEVICE)

    with torch.no_grad():
        emb_matrix = gcn(node_features, edge_index, edge_weight, all_nodes).cpu().numpy()

    # Any NaN left in the embeddings becomes 0.
    emb_matrix = np.nan_to_num(emb_matrix, nan=0.0)

    player_embeddings = {id_of[i]: row for i, row in enumerate(emb_matrix)}
    print(f"  Players: {len(player_embeddings):,}")

    print("\n[Features] Creating match-level features...")

    diff_keys = [f"emb_diff_{i}" for i in range(min(8, EMBEDDING_DIM))]

    rows = []
    for match in matches_df.iter_rows(named=True):
        record = {"custom_match_id": match["custom_match_id"]}

        vec_w = player_embeddings.get(match["winner_id"])
        vec_l = player_embeddings.get(match["loser_id"])

        if vec_w is None or vec_l is None:
            record["emb_cosine_sim"] = None
            record["emb_l2_distance"] = None
            record.update(dict.fromkeys(diff_keys))
            record["emb_norm_winner"] = None
            record["emb_norm_loser"] = None
            record["has_embeddings"] = 0
        else:
            len_w = np.linalg.norm(vec_w)
            len_l = np.linalg.norm(vec_l)
            gap = vec_w - vec_l

            both_nonzero = len_w > 1e-6 and len_l > 1e-6
            record["emb_cosine_sim"] = (
                float(np.dot(vec_w, vec_l) / (len_w * len_l)) if both_nonzero else 0.0
            )
            record["emb_l2_distance"] = float(np.linalg.norm(gap))
            record.update({key: float(gap[i]) for i, key in enumerate(diff_keys)})
            record["emb_norm_winner"] = float(len_w)
            record["emb_norm_loser"] = float(len_l)
            record["has_embeddings"] = 1

        rows.append(record)

    features_df = pl.DataFrame(rows, infer_schema_length=None)
    print(f"  Coverage: {features_df['has_embeddings'].mean():.1%}")

    return player_embeddings, features_df


# ===============================================
# MAIN
# ===============================================

def main():
    """Run the PP_14 pipeline: load matches, prepare data, train, extract, save.

    Files written to OUTPUT_DIR:
    - ``embedding_features.parquet``: match-level embedding features.
    - ``player_embeddings.npy``: one embedding row per player, ordered by
      sorted player id.
    - ``player_mapping.json``: player id -> row index in the .npy file.
    - ``gnn_model.pt``: GCN and predictor state dicts, plus the config needed
      to rebuild both models before loading them.
    """
    import json

    t0 = datetime.now()

    print("\n[Loading] Matches...")
    matches_df = pl.read_parquet(MATCHES_BASE)
    print(f"  Total: {len(matches_df):,}")

    data = prepare_data_zero_leakage(matches_df, min_matches=10)

    gcn, predictor = train_gnn(data)

    player_embeddings, features_df = extract_and_create_features(gcn, data, matches_df)

    # Save
    features_df.write_parquet(OUTPUT_DIR / "embedding_features.parquet")
    print(f"\n  ‚úÖ Saved: {OUTPUT_DIR / 'embedding_features.parquet'}")

    # Row i of the .npy matrix belongs to ordered_ids[i], and the JSON mapping
    # records exactly that order, so compute the order once and share it.
    ordered_ids = sorted(player_embeddings)
    np.save(OUTPUT_DIR / "player_embeddings.npy",
            np.array([player_embeddings[p] for p in ordered_ids]))

    with open(OUTPUT_DIR / "player_mapping.json", "w", encoding="utf-8") as f:
        json.dump({p: i for i, p in enumerate(ordered_ids)}, f)

    torch.save({
        'gcn': gcn.state_dict(),
        'predictor': predictor.state_dict(),
        'config': {
            'embedding_dim': EMBEDDING_DIM,
            'hidden_dim': HIDDEN_DIM,
            'num_layers': NUM_LAYERS,
            # Also required to instantiate PlayerGCN before load_state_dict().
            'num_players': data["num_players"],
            'in_channels': int(data["node_features"].shape[1]),
            'dropout': DROPOUT,
        },
    }, OUTPUT_DIR / "gnn_model.pt")

    elapsed = (datetime.now() - t0).total_seconds()

    print("\n" + "=" * 70)
    print("   ‚úÖ PP_14 GNN ULTIMATE GOD SOTA v2 COMPLETE!")
    print("=" * 70)
    print(f"   ‚è±Ô∏è  Time: {elapsed:.1f}s")
    print(f"   üìä Players: {len(player_embeddings):,}")


if __name__ == "__main__":
    main()

   PP_14 - PLAYER EMBEDDINGS (GNN ULTIMATE GOD SOTA v2)
   2025-12-16 20:08:19
   Device: cuda
   Architecture: GCN + Ranking + STABILIZED
   Learning rate: 0.0001

[Loading] Matches...
  Total: 544,245

[1/4] Splitting data chronologically...
  Total: 544,245 | Train: 489,820 | Val: 54,425

[2/4] Building player set...
  Players: 6,245

[3/4] Building graph (clamped weights)...
  Edges: 925,538
  Weight range: [0.3679, 1.0000]

[4/4] Computing features...
  Features shape: torch.Size([6245, 8])

  Train labels: 925,538
  Val labels: 85,180

  TRAINING (STABILIZED)
  Sampling: 200,000 / 925,538
  Params: 457,089
  Warmup epochs: 10
  Epoch  10: loss=0.6918, acc=0.5236 | val_loss=0.6914, val_acc=0.5708, AUC=0.5955, lr=0.000100
  Epoch  20: loss=0.6874, acc=0.5667 | val_loss=0.6877, val_acc=0.5874, AUC=0.6159, lr=0.000100
  Epoch  30: loss=0.6810, acc=0.5877 | val_loss=0.6829, val_acc=0.5910, AUC=0.6197, lr=0.000100
  Epoch  40: loss=0.6717, acc=0.6020 | val_loss=0.6770, val_acc=0.5933, 