# NBA-22-1: Classification Model (Winner/Loser)

**Objective:** Predict the winner of an NBA game with PySpark ML

**Target:** Winner (0 = loser, 1 = winner)  
**Performance goal:** Accuracy > 65%

**Algorithm:** Random Forest Classifier

## 1. Setup and Imports

In [None]:
import sys
sys.path.insert(0, '../src')

from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler, StandardScaler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder, TrainValidationSplit
from pyspark.ml import Pipeline

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Spark configuration
spark = SparkSession.builder \
    .appName("NBA-Classification") \
    .config("spark.sql.adaptive.enabled", "true") \
    .getOrCreate()

print(f"Spark version: {spark.version}")

## 2. Loading the Data

In [None]:
# Load the ML dataset produced by NBA-21
df = spark.read.parquet("../data/gold/ml_features")

print(f"Shape: ({df.count()}, {len(df.columns)})")
print(f"Columns: {df.columns}")

# Preview
df.show(5)

## 3. Exploratory Analysis

In [None]:
# Target distribution
df.groupBy("target").count().show()

# Class balance
target_dist = df.groupBy("target").count().toPandas()
plt.figure(figsize=(8, 5))
sns.barplot(data=target_dist, x="target", y="count")
plt.title("Distribution: Winner vs Loser")
plt.xlabel("Target (0=Loser, 1=Winner)")
plt.ylabel("Number of games")
plt.show()
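The 65% accuracy goal is easier to interpret next to the trivial baseline of always predicting the majority class. A minimal sketch, assuming the class counts have been collected into a plain dict; `majority_baseline` is a hypothetical helper and the counts below are made up, not real NBA data:

```python
def majority_baseline(class_counts):
    """Return (majority_label, accuracy) of always predicting the most frequent class.

    class_counts: dict mapping label -> number of rows (hypothetical input,
    e.g. built from the groupBy("target").count() result above).
    """
    total = sum(class_counts.values())
    if total == 0:
        raise ValueError("class_counts is empty")
    # Pick the label with the largest count
    label, count = max(class_counts.items(), key=lambda kv: kv[1])
    return label, count / total


# Illustrative counts only
label, acc = majority_baseline({0: 400, 1: 600})
print(f"Majority class: {label}, baseline accuracy: {acc:.2%}")
```

In the notebook, `dict(zip(target_dist["target"], target_dist["count"]))` would build the input from the pandas frame above.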

## 4. Feature Preparation

In [None]:
# Define the features to use
feature_cols = [
    'win_pct_home',
    'win_pct_away',
    'win_pct_last_5_home',
    'win_pct_last_5_away',
    'avg_points_home',
    'avg_points_away',
    'rest_days_home',
    'rest_days_away',
    'is_back_to_back_home',
    'is_back_to_back_away',
    # Add other features here
]

print(f"Number of features: {len(feature_cols)}")
print("Features:", feature_cols)

## 5. Train/Validation/Test Split (Temporal!)

In [None]:
# IMPORTANT: temporal split to avoid data leakage
# Train: seasons 2018-19 to 2021-22
# Validation: season 2022-23
# Test: season 2023-24

train_df = df.filter(df.season.isin(['2018-19', '2019-20', '2020-21', '2021-22']))
val_df = df.filter(df.season == '2022-23')
test_df = df.filter(df.season == '2023-24')

print(f"Train: {train_df.count()} games")
print(f"Validation: {val_df.count()} games")
print(f"Test: {test_df.count()} games")
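Because the season labels use a fixed `YYYY-YY` format, sorting them as strings also sorts them chronologically, which makes it easy to derive and sanity-check a temporal split instead of hard-coding it. A minimal sketch under that assumption; `temporal_season_split` is a hypothetical helper, not part of the project code:

```python
def temporal_season_split(seasons, n_val=1, n_test=1):
    """Split season labels chronologically: oldest -> train, then val, newest -> test.

    Assumes labels like '2018-19', whose lexicographic order matches time order.
    """
    if n_val < 1 or n_test < 1:
        raise ValueError("n_val and n_test must be at least 1")
    ordered = sorted(set(seasons))
    if len(ordered) < n_val + n_test + 1:
        raise ValueError("not enough seasons for a train/val/test split")
    train = ordered[: -(n_val + n_test)]
    val = ordered[-(n_val + n_test): -n_test]
    test = ordered[-n_test:]
    return train, val, test


train_s, val_s, test_s = temporal_season_split(
    ["2018-19", "2019-20", "2020-21", "2021-22", "2022-23", "2023-24"]
)
print(train_s, val_s, test_s)
```

The resulting lists could be passed straight to `df.season.isin(...)`.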

## 6. Baseline Model: Random Forest

In [None]:
# Pipeline
assembler = VectorAssembler(
    inputCols=feature_cols,
    outputCol="features",
    handleInvalid="skip"  # drops rows with null/NaN features, also when predicting on val/test
)

# Optional for tree ensembles: Random Forest is insensitive to feature scaling
scaler = StandardScaler(
    inputCol="features",
    outputCol="scaled_features"
)

rf = RandomForestClassifier(
    labelCol="target",
    featuresCol="scaled_features",
    numTrees=100,
    maxDepth=10,
    seed=42
)

pipeline = Pipeline(stages=[assembler, scaler, rf])

# Training
model = pipeline.fit(train_df)
print("✅ Model trained")
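Tree splits compare a feature against a threshold, so any strictly increasing rescaling (such as standardization with a positive standard deviation) keeps the ordering of values and hence the same possible partitions of the data. In principle the `StandardScaler` stage therefore does not change what the Random Forest can learn. A minimal pure-Python illustration with a one-feature decision stump on made-up data; `best_stump_accuracy` is a hypothetical helper:

```python
from statistics import mean, pstdev


def best_stump_accuracy(xs, ys):
    """Best accuracy of a one-threshold rule 'predict 1 if x > t' (or its inverse)."""
    n = len(xs)
    best = 0.0
    # Candidate thresholds: each distinct value (the rule splits just after t)
    for t in sorted(set(xs)):
        correct = sum((1 if x > t else 0) == y for x, y in zip(xs, ys))
        # The inverted rule 'predict 1 if x <= t' gets the other rows right
        best = max(best, correct / n, (n - correct) / n)
    return best


xs = [98.0, 102.5, 110.0, 95.5, 105.0, 112.0, 99.0, 108.5]  # made-up points averages
ys = [0, 1, 1, 0, 0, 1, 0, 1]                               # made-up labels
mu, sigma = mean(xs), pstdev(xs)
scaled = [(x - mu) / sigma for x in xs]  # standardized copy of the same feature

print(best_stump_accuracy(xs, ys), best_stump_accuracy(scaled, ys))
```

Both calls find the same best split because standardization preserves the order of the values.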

## 7. Evaluation

In [None]:
# Predictions
train_pred = model.transform(train_df)
val_pred = model.transform(val_df)
test_pred = model.transform(test_df)

# Evaluators
evaluators = {
    'accuracy': MulticlassClassificationEvaluator(
        labelCol="target", predictionCol="prediction", metricName="accuracy"
    ),
    'precision': MulticlassClassificationEvaluator(
        labelCol="target", predictionCol="prediction", metricName="weightedPrecision"
    ),
    'recall': MulticlassClassificationEvaluator(
        labelCol="target", predictionCol="prediction", metricName="weightedRecall"
    ),
    'f1': MulticlassClassificationEvaluator(
        labelCol="target", predictionCol="prediction", metricName="f1"
    ),
    'auc': BinaryClassificationEvaluator(
        labelCol="target", rawPredictionCol="rawPrediction", metricName="areaUnderROC"
    )
}

# Compute metrics
results = {}
for split_name, split_df in [('Train', train_pred), ('Validation', val_pred), ('Test', test_pred)]:
    results[split_name] = {}
    for metric_name, evaluator in evaluators.items():
        results[split_name][metric_name] = evaluator.evaluate(split_df)

# Display
results_df = pd.DataFrame(results).round(4)
print("\nMetrics:")
print(results_df)
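To make the metric names above concrete: `weightedPrecision`, `weightedRecall` and `f1` average the per-class scores, weighting each class by how often it appears in the true labels, and `areaUnderROC` measures how often a random positive example is scored above a random negative one (ties count as one half). One consequence is that weighted recall always equals accuracy. A minimal pure-Python sketch of these definitions; the function names are hypothetical and the labels/scores are made up. Spark computes AUC by integrating an ROC curve, so on large data its value may differ slightly from this pairwise version:

```python
from collections import Counter


def weighted_metrics(y_true, y_pred):
    """Accuracy plus class-frequency-weighted precision, recall and F1."""
    n = len(y_true)
    counts = Counter(y_true)  # weights: frequency of each true label
    out = {"accuracy": sum(t == p for t, p in zip(y_true, y_pred)) / n}
    wp = wr = wf = 0.0
    for c, n_c in counts.items():
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        predicted_c = sum(p == c for p in y_pred)
        prec = tp / predicted_c if predicted_c else 0.0
        rec = tp / n_c
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        w = n_c / n
        wp, wr, wf = wp + w * prec, wr + w * rec, wf + w * f1
    out.update(weightedPrecision=wp, weightedRecall=wr, f1=wf)
    return out


def roc_auc(y_true, scores):
    """Fraction of (positive, negative) pairs ranked correctly; ties count 0.5."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    wins = sum(1.0 if p > q else 0.5 if p == q else 0.0 for p in pos for q in neg)
    return wins / (len(pos) * len(neg))


y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 0, 1, 0, 1, 1, 0]
scores = [0.9, 0.2, 0.4, 0.8, 0.3, 0.6, 0.7, 0.1]  # positive-class scores
print(weighted_metrics(y_true, y_pred))
print(roc_auc(y_true, scores))
```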

## 8. Feature Importance

In [None]:
# Extract feature importances
rf_model = model.stages[-1]
importances = rf_model.featureImportances.toArray()

# DataFrame for plotting
feat_imp = pd.DataFrame({
    'feature': feature_cols,
    'importance': importances
}).sort_values('importance', ascending=False)

# Plot
plt.figure(figsize=(10, 6))
sns.barplot(data=feat_imp.head(15), x='importance', y='feature')
plt.title('Top 15 Features - Random Forest')
plt.xlabel('Importance')
plt.tight_layout()
plt.show()

print("\nTop 10 most important features:")
print(feat_imp.head(10))
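Spark normalizes Random Forest feature importances so that they sum to 1, which is why the summary can report them as percentages. A minimal sketch for producing those summary lines, assuming the `feat_imp` columns are passed as plain lists; `format_top_features` is a hypothetical helper and the importances below are made up:

```python
def format_top_features(names, importances, k=3):
    """Return list lines 'i. name (xx.x%)' for the k largest importances."""
    ranked = sorted(zip(names, importances), key=lambda kv: kv[1], reverse=True)[:k]
    return [f"{i}. {name} ({imp * 100:.1f}%)" for i, (name, imp) in enumerate(ranked, start=1)]


# Made-up importances for illustration (they sum to 1, like Spark's)
lines = format_top_features(
    ["win_pct_home", "win_pct_away", "rest_days_home", "avg_points_home"],
    [0.40, 0.25, 0.05, 0.30],
)
print("\n".join(lines))
```

In the notebook this could be called as `format_top_features(feat_imp["feature"].tolist(), feat_imp["importance"].tolist())`.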

## 9. Grid Search Tuning (Optional)

In [None]:
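# Uncomment to run the grid search.
# Note: TrainValidationSplit splits train_df randomly (trainRatio=0.8),
# which ignores the temporal split from section 5.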
# param_grid = ParamGridBuilder() \
#     .addGrid(rf.numTrees, [50, 100, 200]) \
#     .addGrid(rf.maxDepth, [5, 10, 15]) \
#     .addGrid(rf.maxBins, [16, 32]) \
#     .build()
# 
# tvs = TrainValidationSplit(
#     estimator=pipeline,
#     estimatorParamMaps=param_grid,
#     evaluator=evaluators['accuracy'],
#     trainRatio=0.8,
#     seed=42
# )
# 
# best_model = tvs.fit(train_df)
# print("✅ Grid search finished")
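`TrainValidationSplit` splits its input randomly, so applied to `train_df` it mixes games from different seasons and never uses `val_df`. An alternative that keeps the temporal split is to fit each parameter combination on the training seasons and score it on the validation season. A minimal, framework-agnostic sketch; `temporal_grid_search` is a hypothetical helper, and `fit_fn`/`score_fn` are callables you supply (the toy demo below uses a made-up score):

```python
def temporal_grid_search(param_maps, fit_fn, score_fn):
    """Fit one model per param map on the training period, keep the best validation score.

    param_maps: list of dicts (ParamGridBuilder().build() returns this shape)
    fit_fn(param_map) -> model fitted on the training seasons only
    score_fn(model) -> float computed on the validation season (higher is better)
    """
    best = None
    history = []
    for param_map in param_maps:
        model = fit_fn(param_map)
        score = score_fn(model)
        history.append((param_map, score))
        if best is None or score > best[2]:
            best = (param_map, model, score)
    return best, history


# Toy demo: the "model" is just its parameters and the score is made up
grid = [{"maxDepth": d, "numTrees": t} for d in (5, 10) for t in (50, 100)]
(best_params, _, best_score), history = temporal_grid_search(
    grid,
    fit_fn=lambda pm: pm,
    score_fn=lambda m: 0.60 + m["maxDepth"] / 1000 + m["numTrees"] / 10000,
)
print(best_params, round(best_score, 4))
```

With the objects above, `fit_fn=lambda pm: pipeline.fit(train_df, pm)` and `score_fn=lambda m: evaluators['accuracy'].evaluate(m.transform(val_df))` would plug in, since `Estimator.fit` accepts an optional param map as its second argument. Only the winning model would then be evaluated once on `test_df`.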

## 10. Saving the Model

In [None]:
# Save the model (overwrite so the cell can be re-run without error)
model_path = "../models/classification_baseline"
model.write().overwrite().save(model_path)
print(f"✅ Model saved: {model_path}")

# Save the metrics
import json
metrics_path = "../models/classification_metrics.json"
with open(metrics_path, 'w') as f:
    json.dump(results, f, indent=2)
print(f"✅ Metrics saved: {metrics_path}")
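Since the metrics are stored as JSON keyed by split and then by metric name, the performance goal stated at the top (test accuracy > 65%) can be checked automatically, for example in a CI step. A minimal sketch, assuming the file layout written above; `load_metrics` and `meets_accuracy_goal` are hypothetical helpers:

```python
import json


def load_metrics(path):
    """Read the metrics JSON written by the cell above."""
    with open(path) as f:
        return json.load(f)


def meets_accuracy_goal(metrics, split="Test", threshold=0.65):
    """Return (goal_met, accuracy) for the given split."""
    accuracy = metrics[split]["accuracy"]
    return accuracy > threshold, accuracy
```

Usage: `ok, acc = meets_accuracy_goal(load_metrics("../models/classification_metrics.json"))`.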

## Summary

**Results:**
- Accuracy Train: X.XXXX
- Accuracy Validation: X.XXXX
- **Accuracy Test: X.XXXX** ← goal > 65%

**Most important features:**
1. Feature A (XX.X%)
2. Feature B (XX.X%)
3. Feature C (XX.X%)

**Conclusion:** [To be filled in after running]