# Cricket Shot Classification Pipeline with XAI
This notebook covers:
- Training a Random Forest classifier on per-frame pose keypoint data (`keypoints_per_frame.csv`)
- Applying SHAP for explainability, both globally and per class
- Visualizing keypoint contributions

In [4]:
#Install required packages
!pip install mediapipe opencv-python scikit-learn shap matplotlib pandas --quiet

In [None]:
# Import libraries
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit, train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Configuration
SEED = 42
TEST_SIZE = 0.2

# Load the per-frame keypoint dataset
df = pd.read_csv("keypoints_per_frame.csv")
print("Dataset shape:", df.shape)
# A bare df.head() in the middle of a cell renders nothing; display() shows it explicitly
display(df.head())

# Prepare features and labels; "video" and "frame" are identifiers, not features
X = df.drop(columns=["label", "video", "frame"])
y = df["label"]

# Split by video, not by frame: frames of the same clip are strongly correlated, so a
# frame-level random split puts pieces of every test clip into training and inflates
# the reported accuracy. Whole videos go to either train or test.
splitter = GroupShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SEED)
train_idx, test_idx = next(splitter.split(X, y, groups=df["video"]))
X_train, X_test = X.iloc[train_idx], X.iloc[test_idx]
y_train, y_test = y.iloc[train_idx], y.iloc[test_idx]

# With few videos per class, a group split can leave a class out of training entirely
missing_in_train = set(y) - set(y_train)
if missing_in_train:
    print(f"WARNING: classes absent from the training split: {sorted(missing_in_train)}")

# Train Random Forest
rf = RandomForestClassifier(random_state=SEED)
rf.fit(X_train, y_train)

# Make predictions
y_pred = rf.predict(X_test)

# Evaluate
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy:.2f}")
print("\nClassification Report:\n", classification_report(y_test, y_pred, zero_division=0))

# Human-readable feature names in keypoint-major order (x, y, z for each keypoint).
# NOTE(review): these correspond to MediaPipe Pose landmarks 0-8 and 11-18 (MouthLeft /
# MouthRight, indices 9-10, are not listed) — confirm this matches the CSV column order.
keypoints = [
    "Nose", "LeftEyeInner", "LeftEye", "LeftEyeOuter", "RightEyeInner", "RightEye",
    "RightEyeOuter", "LeftEar", "RightEar", "LeftShoulder", "RightShoulder",
    "LeftElbow", "RightElbow", "LeftWrist", "RightWrist", "LeftPinky", "RightPinky"
]

coords = ["x", "y", "z"]

# Generate full feature name list
feature_names = [f"{kp}_{c}" for kp in keypoints for c in coords]

# Fail loudly rather than silently mislabelling (or IndexError-ing on) the features
assert len(feature_names) == X.shape[1], (
    f"{len(feature_names)} feature names for {X.shape[1]} feature columns"
)

# Feature Importance Visualization: impurity-based importances, all features, descending
importances = pd.Series(rf.feature_importances_, index=feature_names).sort_values(ascending=False)
fig, ax = plt.subplots(figsize=(12, 30))
sns.barplot(x=importances.values, y=importances.index, ax=ax)
ax.set(title="Feature Importances", xlabel="Mean decrease in impurity", ylabel="Feature")
fig.tight_layout()
plt.show()

# Display names for the model's class labels.
# Assumes the CSV "label" column holds the integers 0-4 — TODO confirm against rf.classes_.
class_names = {
    0: "Bowled",
    1: "Cover Drive",
    2: "Defence",
    3: "Pull Shot",
    4: "Reverse Sweep"
}



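## Global SHAP feature importance

Compute SHAP values for the test set and summarize each feature's mean absolute contribution, broken down by class.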

In [None]:
# Import SHAP and compute explanations for the test set.
# Relies on rf, X_test, feature_names, class_names and np/plt from the training cell.
import shap

# Load SHAP's JavaScript visualizer (used by interactive plots such as force plots)
shap.initjs()

# TreeExplainer computes SHAP values directly from the forest's trees
explainer = shap.TreeExplainer(rf)
shap_values = explainer.shap_values(X_test)
print("Shap calculated")

# Normalize to one (n_samples, n_features) matrix per class. Recent SHAP releases return a
# 3-D array (samples, features, classes) for classifiers; older releases return a list.
# summary_plot interprets a plain 3-D array as SHAP *interaction* values, so split it here.
if isinstance(shap_values, list):
    shap_per_class = list(shap_values)
elif np.ndim(shap_values) == 3:
    shap_per_class = [shap_values[:, :, k] for k in range(shap_values.shape[2])]
else:
    shap_per_class = [shap_values]
assert all(np.shape(m) == X_test.shape for m in shap_per_class)

# Legend labels in the same order as the model's output columns (rf.classes_)
class_labels = [class_names.get(label, str(label)) for label in rf.classes_]
n_display = len(feature_names)

# Global importance: mean |SHAP| per feature, stacked by class; saved to disk and shown inline.
# feature_names is passed explicitly, so X_test's own column names are not used.
shap.summary_plot(
    shap_per_class, X_test,
    feature_names=feature_names,
    class_names=class_labels,
    plot_type="bar",
    max_display=n_display,
    show=False,
)
plt.title("Global SHAP Feature Importance (mean |SHAP|, stacked by class)")
plt.tight_layout()
plt.savefig("shap_summary_bar.png", dpi=300)
plt.show()
plt.close()
# A dot (beeswarm) summary needs a single output, so those are drawn per class below.





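## Per-class SHAP summaries

For each class: a bar plot of mean |SHAP| per feature, and a dot plot of per-sample SHAP values coloured by feature value.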

In [None]:
# Per-class SHAP summaries: a bar plot (mean |SHAP|) and a dot plot for each class.
# Relies on shap_values, rf, X_test, feature_names and class_names from earlier cells.

# Normalize to one (n_samples, n_features) matrix per class. Recent SHAP releases return a
# 3-D array (samples, features, classes) for classifiers; older releases return a list.
if isinstance(shap_values, list):
    shap_per_class = list(shap_values)
elif np.ndim(shap_values) == 3:
    shap_per_class = [shap_values[:, :, k] for k in range(shap_values.shape[2])]
else:
    shap_per_class = [shap_values]
assert len(shap_per_class) == len(rf.classes_), "one SHAP matrix per model class expected"
assert all(np.shape(m) == X_test.shape for m in shap_per_class)

# Iterate in the model's own class order so each SHAP matrix is paired with the right name.
# X_test is not renamed in place (that would break later rf.predict feature-name checks);
# feature_names is passed to summary_plot instead.
for class_pos, label in enumerate(rf.classes_):
    class_name = class_names.get(label, str(label))
    file_tag = str(class_name).lower().replace(" ", "_")
    print(f"Generating SHAP summary for class: {class_name}")

    for plot_type, plot_label in (("bar", "Bar"), ("dot", "Dot")):
        # show=False so the title and layout are applied to the SHAP figure before it renders
        shap.summary_plot(
            shap_per_class[class_pos], X_test,
            feature_names=feature_names,
            plot_type=plot_type,
            max_display=len(feature_names),
            show=False,
        )
        plt.title(f"SHAP Summary ({plot_label}) - {class_name}")
        plt.tight_layout()
        plt.savefig(f"shap_summary_{plot_type}_{file_tag}.png", dpi=300)
        plt.show()
        plt.close()
