# Model Training: SARIMA (Global + Per-Product)

Ce notebook implémente deux approches SARIMA :
1. **Globale** : Modèle unique sur les ventes agrégées.
2. **Par Produit** : Modèle individuel pour chaque produit.

Données utilisées : `Data/train_sarimax.csv` et `Data/test_sarimax.csv`.

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_squared_error, mean_absolute_error
import joblib
import os
import warnings

# Silence all Python warnings for the rest of the notebook, presumably to keep
# the statsmodels fitting loops below readable.
# NOTE(review): this also hides convergence warnings — re-enable while tuning orders.
warnings.filterwarnings('ignore')

# Output directory for the pickled fitted models (joblib.dump calls in later cells).
os.makedirs('trained_models', exist_ok=True)
print("✅ Bibliothèques importées.")

## Chargement des Données

In [None]:
train_path = 'Data/train_sarimax.csv'
test_path = 'Data/test_sarimax.csv'

# Name the missing file(s) explicitly so the user knows which export to regenerate.
missing_paths = [path for path in (train_path, test_path) if not os.path.exists(path)]
if missing_paths:
    raise FileNotFoundError(f"Fichiers Data/ introuvables : {', '.join(missing_paths)}")

train_df = pd.read_csv(train_path, parse_dates=['Date'])
test_df = pd.read_csv(test_path, parse_dates=['Date'])

# The training cells index these columns directly on both frames; fail here with
# a clear message instead of a KeyError deep inside the model loops.
required_cols = {
    'Date', 'Product ID', 'Units Sold',
    'Price', 'Discount', 'Holiday/Promotion', 'Competitor Pricing',
}
for split_name, frame in (('train', train_df), ('test', test_df)):
    missing_cols = required_cols - set(frame.columns)
    assert not missing_cols, f"{split_name}: colonnes manquantes {sorted(missing_cols)}"

print(f"📦 Train cols: {list(train_df.columns)}")
print(f"📦 Train shape: {train_df.shape}")

## 1. Modèle Global
Nous agrégeons toutes les ventes par date pour obtenir une tendance globale (somme des unités vendues, moyenne des variables exogènes).

In [None]:
print("\n🌍 --- Entraînement Modèle Global ---")

target_col = 'Units Sold'
exog_cols = ['Price', 'Discount', 'Holiday/Promotion', 'Competitor Pricing']


def _aggregate_weekly(df):
    """Aggregate all rows sharing a ``Date`` into one observation per week.

    The target (``target_col``) is summed across rows and the exogenous
    regressors (``exog_cols``) are averaged, in a single ``groupby().agg`` so
    the columns cannot drift out of alignment. The result is reindexed to a
    regular weekly frequency so that missing weeks appear as NaN rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Raw rows with a datetime ``Date`` column plus the target and exog columns.

    Returns
    -------
    pandas.DataFrame
        Indexed by sorted ``Date`` with a weekly ``freq``; columns
        ``[target_col] + exog_cols``. Empty if ``df`` is empty.
    """
    agg_spec = {target_col: 'sum', **{col: 'mean' for col in exog_cols}}
    weekly = df.groupby('Date').agg(agg_spec)
    if weekly.empty:
        return weekly
    # Plain 'W' means 'W-SUN': for data dated on another weekday, asfreq would
    # keep only Sundays, i.e. an all-NaN frame. Anchor on the weekday actually
    # present (identical to 'W' when dates are Sundays).
    # Assumes the export is weekly — TODO confirm.
    anchored_freq = f"W-{weekly.index[0].day_name()[:3].upper()}"
    return weekly.asfreq(anchored_freq)


global_train = _aggregate_weekly(train_df)
global_test = _aggregate_weekly(test_df)

# forecast(steps=...) predicts the weeks immediately after the training window,
# so the metrics are only meaningful if the test set starts exactly there.
next_train_week = global_train.index[-1] + global_train.index.freq
assert global_test.index[0] == next_train_week, (
    f"Le test commence le {global_test.index[0].date()}, attendu {next_train_week.date()}"
)

# SARIMAX rejects NaN in exog, and asfreq inserts NaN rows for missing weeks:
# carry the last known regressor values forward (backward for leading gaps).
global_exog_train = global_train[exog_cols].ffill().bfill()
global_exog_test = global_test[exog_cols].ffill().bfill()

# Training: weekly data, so a 52-period seasonal component.
model_global = SARIMAX(
    global_train[target_col],
    exog=global_exog_train,
    order=(1, 1, 1),
    seasonal_order=(0, 1, 1, 52),
    enforce_stationarity=False,
    enforce_invertibility=False,
)
fit_global = model_global.fit(disp=False)

# Evaluation
global_preds = fit_global.forecast(steps=len(global_test), exog=global_exog_test)
# Weeks missing from the test data have NaN targets, which sklearn metrics
# reject; score only the weeks that were actually observed.
observed_mask = global_test[target_col].notna().to_numpy()
y_true_global = global_test[target_col].to_numpy()[observed_mask]
y_pred_global = np.asarray(global_preds)[observed_mask]
mae_global = mean_absolute_error(y_true_global, y_pred_global)
rmse_global = np.sqrt(mean_squared_error(y_true_global, y_pred_global))
print(f"✅ Global MAE: {mae_global:.2f}")
print(f"✅ Global RMSE: {rmse_global:.2f}")

# Save the fitted results object for later reuse.
joblib.dump(fit_global, 'trained_models/global_sarima.pkl')
print("💾 Modèle Global sauvegardé.")

## 2. Modèles Par Produit
Nous itérons sur chaque `Product ID`.

In [None]:
print("\n🏭 --- Entraînement Par Produit ---")

# Products with fewer observed (non-NaN) training weeks than this are skipped.
MIN_TRAIN_OBS = 10


def _product_weekly(df, pid):
    """Return the weekly target + exog series for a single product.

    Rows sharing a ``Date`` (presumably several stores — verify against the
    export) are combined like the global model: target summed, regressors
    averaged. With one row per date the values are unchanged (integer columns
    may become float). The index is sorted and regularised to a weekly
    frequency anchored on the weekday of the first date, so missing weeks
    become NaN rows instead of raising on duplicate labels in ``asfreq``.

    Returns an empty frame when ``pid`` has no rows in ``df``.
    """
    rows = df[df['Product ID'] == pid]
    grouped = rows.groupby('Date')
    weekly = grouped[exog_cols].mean()
    # min_count=1 keeps a missing target as NaN instead of turning it into 0.
    weekly.insert(0, target_col, grouped[target_col].sum(min_count=1))
    if weekly.empty:
        return weekly
    # Plain 'W' is 'W-SUN'; anchoring on the real weekday avoids an all-NaN
    # reindex when dates are not Sundays (identical result when they are).
    anchored_freq = f"W-{weekly.index[0].day_name()[:3].upper()}"
    return weekly.asfreq(anchored_freq)


def _evaluate_product(fit, train, test):
    """Return ``{'MAE', 'RMSE'}`` of ``fit`` on ``test``, or None if not scorable.

    ``forecast(steps=...)`` predicts the weeks right after the training window,
    so scoring is only meaningful when ``test`` starts there. Weeks with a NaN
    target (gaps inserted by ``asfreq``) are excluded from the metrics.
    """
    if test.empty or test.index[0] != train.index[-1] + train.index.freq:
        return None
    observed = test[target_col].notna().to_numpy()
    if not observed.any():
        return None
    exog_test = test[exog_cols].ffill().bfill()
    preds = np.asarray(fit.forecast(steps=len(test), exog=exog_test))
    y_true = test[target_col].to_numpy()[observed]
    y_pred = preds[observed]
    return {
        'MAE': mean_absolute_error(y_true, y_pred),
        'RMSE': float(np.sqrt(mean_squared_error(y_true, y_pred))),
    }


products = train_df['Product ID'].unique()
print(f"Produits trouvés: {len(products)}")

trained_pids, skipped_pids, failed_pids = [], [], []
product_metrics = []

for pid in products:
    try:
        p_train = _product_weekly(train_df, pid)
        p_test = _product_weekly(test_df, pid)

        # Count real weeks only: rows inserted by asfreq carry no information.
        if p_train[target_col].notna().sum() < MIN_TRAIN_OBS:
            skipped_pids.append(pid)
            continue

        # Exog is always passed (there is no ARIMA-only fallback). SARIMAX
        # rejects NaN in exog, so fill regressor gaps from neighbouring weeks.
        exog_train = p_train[exog_cols].ffill().bfill()

        model_p = SARIMAX(
            p_train[target_col],
            exog=exog_train,
            order=(1, 1, 1),
            # No seasonal component: simplification to avoid errors on short series.
            seasonal_order=(0, 0, 0, 0),
            enforce_stationarity=False,
            enforce_invertibility=False,
        )
        fit_p = model_p.fit(disp=False)

        joblib.dump(fit_p, f'trained_models/sarima_{pid}.pkl')
    except Exception as e:
        # Best-effort: one bad product must not abort the whole run.
        failed_pids.append(pid)
        print(f"⚠️ Erreur {pid}: {e}")
        continue
    trained_pids.append(pid)

    # Evaluation failures are reported separately: the model is already saved.
    try:
        metrics = _evaluate_product(fit_p, p_train, p_test)
    except Exception as e:
        print(f"⚠️ Erreur évaluation {pid}: {e}")
        continue
    if metrics is not None:
        product_metrics.append({'Product ID': pid, **metrics})

product_metrics_df = pd.DataFrame(product_metrics, columns=['Product ID', 'MAE', 'RMSE'])

print(
    f"✅ Produits traités: {len(trained_pids)} entraînés, "
    f"{len(skipped_pids)} ignorés (< {MIN_TRAIN_OBS} semaines observées), "
    f"{len(failed_pids)} en erreur."
)
if product_metrics_df.empty:
    print("ℹ️ Aucun produit évaluable sur le jeu de test.")
else:
    print(
        f"📊 {len(product_metrics_df)} produits évalués — "
        f"MAE moyen: {product_metrics_df['MAE'].mean():.2f}, "
        f"RMSE moyen: {product_metrics_df['RMSE'].mean():.2f}"
    )