# 04 – Model Training (LSTM)

In this notebook, we train an LSTM model to classify the direction of penalty kicks based on keypoint sequences extracted with YOLOv7-Pose.

- **Input:** preprocessed dataset `(X, y)` from notebook 03
- **Output:** trained model, accuracy, loss curve, and confusion matrix

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout, Masking
from tensorflow.keras.utils import to_categorical

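# Load the saved data (or reuse it directly from notebook 03).
# If notebook 03 has just been run in this kernel, X and y are already defined; otherwise load them with np.load().
# NOTE: this notebook does not load X and y itself — on a fresh kernel they must be defined before the next cell runs.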

In [None]:
# One-hot encode the labels. to_categorical expects integer class ids, so y is
# assumed to already be encoded as g:0, m:1, d:2 by notebook 03 — TODO confirm.
y_cat = to_categorical(y, num_classes=3)

# Stratify on the integer labels so every direction keeps the same proportion in the
# train and test splits. With a small dataset, a purely random 80/20 split can leave a
# class under-represented, or missing entirely, from the test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y_cat, test_size=0.2, random_state=42, stratify=y
)

X_train.shape, X_test.shape, y_train.shape

In [None]:
# Define a simple LSTM classifier over the keypoint sequences.
# Seed Python, NumPy and TensorFlow in one call so weight initialisation, dropout masks and
# the per-epoch shuffling done by model.fit are reproducible on Restart & Run All
# (some GPU kernels may still be non-deterministic).
tf.keras.utils.set_random_seed(42)

model = Sequential([
    # One sample = (timesteps, features); X is indexed as (samples, timesteps, features).
    tf.keras.Input(shape=(X.shape[1], X.shape[2])),
    # Skip timesteps whose features are all exactly 0 — presumably the zero padding
    # added in notebook 03 (TODO confirm).
    Masking(mask_value=0.),
    # Keep only the final hidden state (return_sequences=False): one vector per sequence.
    LSTM(64, return_sequences=False),
    Dropout(0.5),
    Dense(32, activation='relu'),
    # One probability per direction class (g, m, d).
    Dense(3, activation='softmax'),
])

# categorical_crossentropy matches the one-hot targets produced by to_categorical.
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
model.summary()

In [None]:
# Training.
# NOTE: the test split is also passed as validation data, so the val_* metrics recorded
# in `history` are measured on the same samples used for the final evaluation below.
EPOCHS = 30
BATCH_SIZE = 16

history = model.fit(
    x=X_train,
    y=y_train,
    validation_data=(X_test, y_test),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=1,
)

In [None]:
# Training curves: per-epoch accuracy and loss for the training data and the
# validation data (which is the test split, see the training cell).
metrics = history.history
epoch_idx = range(1, len(metrics['loss']) + 1)

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

ax_acc.plot(epoch_idx, metrics['accuracy'], label='train acc')
ax_acc.plot(epoch_idx, metrics['val_accuracy'], label='val acc')
ax_acc.set(title='Accuracy', xlabel='Epoch', ylabel='Accuracy')
ax_acc.legend()

# The loss curve announced in the notebook header.
ax_loss.plot(epoch_idx, metrics['loss'], label='train loss')
ax_loss.plot(epoch_idx, metrics['val_loss'], label='val loss')
ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Categorical cross-entropy')
ax_loss.legend()

fig.tight_layout()
plt.show()

In [None]:
# Final evaluation on the held-out test split.
y_pred = model.predict(X_test)
y_true = np.argmax(y_test, axis=1)          # one-hot targets -> class index
y_pred_classes = np.argmax(y_pred, axis=1)  # softmax probabilities -> most likely class

# Pass labels explicitly so the report always has one row per class (g, m, d).
# Without it, sklearn raises a ValueError whenever fewer than 3 classes appear among the
# true and predicted labels, because target_names has 3 entries. zero_division=0 reports
# 0 instead of warning for a class that is never predicted or never present.
print(classification_report(
    y_true,
    y_pred_classes,
    labels=[0, 1, 2],
    target_names=['g', 'm', 'd'],
    zero_division=0,
))

In [None]:
# Confusion matrix: rows are the true class, columns the predicted class.
# labels=[0, 1, 2] forces a 3x3 matrix in g, m, d order so it always lines up with the
# tick labels, even if some class never occurs in y_true or y_pred.
cm = confusion_matrix(y_true, y_pred_classes, labels=[0, 1, 2])

fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    xticklabels=['g', 'm', 'd'],
    yticklabels=['g', 'm', 'd'],
    ax=ax,
)
ax.set(xlabel='Predicted', ylabel='True', title='Confusion Matrix')
plt.show()