In [None]:
#imports (can be imported individually later too, if needed)
import os
import cv2
import numpy as np
import joblib
import mediapipe as mp
from collections import Counter
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report


In [None]:
'''Processes a dataset of images to extract pose landmarks using MediaPipe Pose. It reads images from the folder Swing_events, extracts pose features, and stores them.
It uses MediaPipe to extract pose landmarks from each image and builds a dataset of pose feature vectors (X) with corresponding labels (y).
This code assumes that the dataset is organized in subfolders, where each subfolder corresponds to a different class (swing phases). 
I have used a maximum of 1350 images per class here.'''

import os
import cv2
import mediapipe as mp
import numpy as np

# Module-level MediaPipe Pose detector; extract_pose_from_image() below reads it
# as a global. static_image_mode=True is MediaPipe's setting for unrelated still
# images, which matches a dataset of individual image files.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=True)

def extract_pose_from_image(image_path):
    """Return the flattened pose landmarks detected in one image file.

    Uses the module-level ``pose`` detector.

    Args:
        image_path: Path of the image to load with ``cv2.imread``.

    Returns:
        A 1-D NumPy array containing ``x, y, z`` for each landmark, in
        landmark order, or ``None`` when the image cannot be loaded or no
        pose landmarks are returned.
    """
    bgr = cv2.imread(image_path)
    if bgr is None:
        # Unreadable or non-image file: nothing to extract.
        return None

    # The detector is given RGB input; OpenCV loads images as BGR.
    detection = pose.process(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    landmarks = detection.pose_landmarks
    if not landmarks:
        return None

    coords = [[point.x, point.y, point.z] for point in landmarks.landmark]
    return np.array(coords).flatten()

def process_dataset(base_folder, max_per_class=1350):
    """Build pose feature vectors and labels from a folder-per-class dataset.

    Every sub-directory of ``base_folder`` is treated as one class; its name
    becomes the label. Files are visited in sorted order and passed to
    ``extract_pose_from_image``; files that yield no features are skipped and
    do not count towards the per-class limit.

    Args:
        base_folder: Directory whose sub-directories each hold one class.
        max_per_class: Maximum number of usable images to keep per class.

    Returns:
        ``(X, y)`` as NumPy arrays: ``X`` holds one feature vector per kept
        image and ``y`` the matching class-folder names.
    """
    feature_rows, label_rows = [], []

    for class_name in sorted(os.listdir(base_folder)):
        class_dir = os.path.join(base_folder, class_name)
        if not os.path.isdir(class_dir):
            # Stray files next to the class folders are ignored.
            continue

        print(f"Processing: {class_name}")
        kept = 0
        for file_name in sorted(os.listdir(class_dir)):
            if kept >= max_per_class:
                break
            vector = extract_pose_from_image(os.path.join(class_dir, file_name))
            if vector is None:
                continue
            feature_rows.append(vector)
            label_rows.append(class_name)  # Using folder name as label
            kept += 1
        print(f"Processed {kept} images for {class_name}")

    return np.array(feature_rows), np.array(label_rows)

# Build the training data from the 'Swing_events' folder, keeping at most 1350
# usable images per class. X holds the feature vectors, y the folder-name labels.
X, y = process_dataset('Swing_events', max_per_class=1350)


Processing: Address


I0000 00:00:1749221499.472110 13029533 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.4), renderer: Apple M2
W0000 00:00:1749221499.570467 13272788 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1749221499.585886 13272793 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


Processed 1350 images for Address
Processing: Finish
Processed 1350 images for Finish
Processing: Idle
Processed 1350 images for Idle
Processing: Impact
Processed 1350 images for Impact
Processing: Mid-Backswing
Processed 1350 images for Mid-Backswing
Processing: Mid-Downswing
Processed 1350 images for Mid-Downswing
Processing: Mid-Follow-Through
Processed 1350 images for Mid-Follow-Through
Processing: Toe-up
Processed 1350 images for Toe-up
Processing: Top
Processed 1350 images for Top


In [None]:
'''Trains a RandomForestClassifier on pose landmarks to classify golf swing phases.
It evaluates the model using precision, recall, and F1-score, and saves the trained model to disk as pose_classifier.pkl.'''


from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report

# Hold out 20% of the samples for evaluation. stratify=y preserves the class
# proportions in both splits, and the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)

# Seed the forest as well: its bootstrap sampling and feature selection are
# random, so without random_state every run yields a different model/report.
clf = RandomForestClassifier(n_estimators=100, random_state=42)
clf.fit(X_train, y_train)

print(classification_report(y_test, clf.predict(X_test)))
import joblib
# Persist the fitted classifier; later cells reload it from this file name.
joblib.dump(clf, 'pose_classifier.pkl')


                    precision    recall  f1-score   support

           Address       0.84      0.83      0.84       270
            Finish       0.93      0.92      0.93       270
              Idle       0.93      0.95      0.94       270
            Impact       0.77      0.82      0.79       270
     Mid-Backswing       0.74      0.77      0.75       270
     Mid-Downswing       0.74      0.70      0.72       270
Mid-Follow-Through       0.87      0.87      0.87       270
            Toe-up       0.92      0.86      0.89       270
               Top       0.90      0.92      0.91       270

          accuracy                           0.85      2430
         macro avg       0.85      0.85      0.85      2430
      weighted avg       0.85      0.85      0.85      2430



['pose_classifier.pkl']

In [None]:
'''Classifies each frame of a video using the trained pose classifier.
Extracts pose landmarks with MediaPipe, predicts the swing phase for each frame,
and saves the results (frame_id, predicted_label) to a text file.'''


def classify_video(video_path, model, pose, output_txt='classified_frames.txt'):
    """Predict a swing-phase label for every frame of a video and log them.

    NOTE: a later cell in this notebook defines another ``classify_video``
    without the ``output_txt`` parameter; whichever cell runs last wins.

    Args:
        video_path: Path of the video opened with ``cv2.VideoCapture``.
        model: Fitted classifier exposing ``predict``; it receives one
            flattened ``x, y, z`` landmark vector per frame.
        pose: MediaPipe Pose instance used to detect landmarks.
        output_txt: File that receives one ``frame_id,label`` line per frame.

    Returns:
        List of ``(frame_id, label)`` tuples in frame order. Frames without a
        detected pose get the label ``'NoPose'``.
    """
    cap = cv2.VideoCapture(video_path)
    results = []

    # try/finally so the capture handle is released even if pose detection or
    # model.predict raises part-way through the video.
    try:
        frame_id = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; the detector is given RGB.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = pose.process(img_rgb)

            if result.pose_landmarks:
                # Same flattened (x, y, z) layout as used to build the training data.
                pose_vec = np.array(
                    [[lm.x, lm.y, lm.z] for lm in result.pose_landmarks.landmark]
                ).flatten().reshape(1, -1)
                pred = model.predict(pose_vec)[0]
            else:
                pred = 'NoPose'

            results.append((frame_id, pred))
            frame_id += 1
    finally:
        cap.release()

    with open(output_txt, 'w') as f:
        for frame_id, label in results:
            f.write(f"{frame_id},{label}\n")

    return results


In [None]:
#TO LABEL FRAMES IN A VIDEO AND SAVE THE OUTPUT VIDEO WITH LABELS

'''Loads the model and defines a function that annotates each frame of a video with the predicted golf swing phase.
The labeled video is saved to disk with a text overlay of the predicted class.'''

# Load model
# This is the file written by joblib.dump in the training cell. joblib files are
# pickles, so only load ones you produced or otherwise trust.
model = joblib.load("pose_classifier.pkl")

# Initialize MediaPipe pose
# Rebinds the global `pose` that the dataset cell created with
# static_image_mode=True; this instance uses static_image_mode=False and is the
# one passed to the video functions.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=False)

# Define function to label and save output video
def label_and_save_video(input_video_path, output_video_path, model, pose_model):
    """Overlay the predicted swing phase on every frame and save the result.

    Args:
        input_video_path: Path of the video to read.
        output_video_path: Path of the annotated video to write ('mp4v' codec),
            using the input's frame size and frame rate.
        model: Fitted classifier exposing ``predict``; it receives one
            flattened ``x, y, z`` landmark vector per frame.
        pose_model: MediaPipe Pose instance used to detect landmarks.

    Returns:
        None. Frames without a detected pose are labelled ``"NoPose"``.
    """
    cap = cv2.VideoCapture(input_video_path)
    out = None

    # try/finally so both the reader and the writer are released even when
    # detection or prediction raises; an unreleased writer leaves an
    # unfinalised output file behind.
    try:
        # Get video properties
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        # Set up video writer
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; the detector is given RGB.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = pose_model.process(img_rgb)

            if result.pose_landmarks:
                pose_vec = np.array(
                    [[lm.x, lm.y, lm.z] for lm in result.pose_landmarks.landmark]
                ).flatten().reshape(1, -1)
                pred_label = model.predict(pose_vec)[0]
            else:
                pred_label = "NoPose"

            # Put text label on the (BGR) frame that is written out
            cv2.putText(frame, f"{pred_label}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

            # Write frame to output
            out.write(frame)
    finally:
        cap.release()
        if out is not None:
            out.release()

    print(f"Saved annotated video to {output_video_path}")


I0000 00:00:1749223325.892914 13029533 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.4), renderer: Apple M2


W0000 00:00:1749223325.972705 13291538 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1749223325.988769 13291538 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


In [None]:
'''Actual code to execute the labeling and saving of the video, using functions defined in previous cell.'''

# Annotate 'sample_vid1.mp4' with the loaded classifier and the video-mode pose
# detector from the previous cell, writing the result to 'latest7.mp4'.
label_and_save_video(
    input_video_path='sample_vid1.mp4',
    output_video_path='latest7.mp4',
    model=model,
    pose_model=pose
)


Saved annotated video to latest7.mp4


In [None]:
## MAIN CODE TO EXTRACT SWING CLIPS FROM A VIDEO


video_path = "sample_vid3.mp4"  # update this path
output_dir='swing_clips4'       # directory where you want to save the extracted clips


import cv2
import numpy as np
import joblib
import mediapipe as mp
from collections import Counter
import os

# Load pose classifier
# Same file the training cell writes with joblib.dump; joblib files are pickles,
# so only load ones you trust.
model = joblib.load("pose_classifier.pkl")

# Initialize MediaPipe pose
# static_image_mode=False: this instance is used on consecutive video frames.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(static_image_mode=False)

# Define swing classes
# Spelled exactly like the classifier's labels (the training folder names), so
# membership tests against predictions can match. The model has no "Takeaway"
# or "Follow-through" class; the trained labels are "Toe-up" and
# "Mid-Follow-Through". "Idle" is also a trained label but is not a swing phase,
# so it is left out. Order is assumed to follow swing chronology - TODO confirm.
swing_classes = [
    "Address", "Toe-up", "Mid-Backswing", "Top",
    "Mid-Downswing", "Impact", "Mid-Follow-Through", "Finish"
]

# Step 1: Classify each frame
def classify_video(video_path, model, pose):
    """Predict a swing-phase label for every frame of a video.

    Args:
        video_path: Path of the video opened with ``cv2.VideoCapture``.
        model: Fitted classifier exposing ``predict``; it receives one
            flattened ``x, y, z`` landmark vector per frame.
        pose: MediaPipe Pose instance used to detect landmarks.

    Returns:
        List of ``(frame_id, label)`` tuples in frame order, with ``frame_id``
        counting from 0. Frames without a detected pose get ``'NoPose'``.
    """
    cap = cv2.VideoCapture(video_path)
    results = []

    # try/finally so the capture handle is released even if pose detection or
    # model.predict raises part-way through the video.
    try:
        frame_id = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # OpenCV decodes frames as BGR; the detector is given RGB.
            img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            result = pose.process(img_rgb)

            if result.pose_landmarks:
                pose_vec = np.array(
                    [[lm.x, lm.y, lm.z] for lm in result.pose_landmarks.landmark]
                ).flatten().reshape(1, -1)
                pred = model.predict(pose_vec)[0]
            else:
                pred = 'NoPose'

            results.append((frame_id, pred))
            frame_id += 1
    finally:
        cap.release()

    return results

# Step 2: Smooth predictions with 4-frame window
def smooth_predictions(results, window_size=4):
    """Replace each frame's label with the majority label of a nearby window.

    For frame index ``i`` the window covers indices
    ``[i - window_size // 2, i + window_size // 2)``, clipped to the sequence.
    The window always contains frame ``i`` itself, so ``window_size`` values of
    1 or less return the labels unchanged instead of failing on an empty window.

    Args:
        results: Sequence of ``(frame_id, label)`` pairs in frame order.
        window_size: Width of the voting window in frames.

    Returns:
        List of ``(frame_id, smoothed_label)`` tuples, same length and frame
        ids as ``results``. Ties go to the label seen first in the window.
    """
    labels = [label for _, label in results]
    total = len(labels)
    # Clamp at 0 so a negative window_size behaves like "no smoothing".
    half = max(window_size // 2, 0)

    smoothed = []
    for i, (frame_id, _) in enumerate(results):
        lo = max(0, i - half)
        # max(..., i + 1) keeps the current frame in the window when half == 0;
        # for half >= 1 this is the same upper bound as i + half.
        hi = min(total, max(i + half, i + 1))
        winner = Counter(labels[lo:hi]).most_common(1)[0][0]
        smoothed.append((frame_id, winner))
    return smoothed

# Step 3: Detect swings starting with "Address" and ending with "Finish"
def extract_address_to_finish_segments(preds, min_swing_length=20):
    """Find frame ranges that run from an "Address" frame to a "Finish" run.

    A candidate swing starts at the most recent "Address" frame (each new
    "Address" restarts the candidate) and ends at the last frame of the first
    following run of consecutive "Finish" frames. It is kept only when it
    spans at least ``min_swing_length`` frame ids and contains both a
    "Mid-Backswing" and a "Mid-Downswing" label.

    Args:
        preds: Sequence of ``(frame_id, label)`` pairs in frame order.
        min_swing_length: Minimum value of ``end - start`` for a kept swing.

    Returns:
        List of ``(start_frame_id, end_frame_id)`` tuples, in order of
        appearance.
    """
    required_labels = ("Mid-Backswing", "Mid-Downswing")
    total = len(preds)
    found = []
    swing_start = None
    seen = []

    for pos, (frame_id, label) in enumerate(preds):
        if label == "Address":
            # Start new swing or restart if another Address comes before Finish
            swing_start, seen = frame_id, [label]
            continue
        if swing_start is None:
            # No open candidate: nothing to track until the next Address.
            continue

        seen.append(label)
        if label != "Finish":
            continue

        # Extend to last consecutive "Finish"
        last = pos
        while last + 1 < total and preds[last + 1][1] == "Finish":
            last += 1
        swing_end = preds[last][0]

        # Apply filters: length + required labels
        if swing_end - swing_start >= min_swing_length and all(
            name in seen for name in required_labels
        ):
            found.append((swing_start, swing_end))

        # Close the candidate whether or not it was kept.
        swing_start, seen = None, []

    return found

# Step 4: Save swing clips
def save_swing_clips(video_path, segments, output_dir='swing_clips'):
    """Write each ``(start, end)`` frame range of a video to its own clip file.

    Clips are saved as ``swing_<idx>.mp4`` inside ``output_dir`` (created if
    missing), using the 'mp4v' codec and the source's frame size and rate.

    Args:
        video_path: Path of the source video.
        segments: Iterable of ``(start, end)`` frame indices, end inclusive.
        output_dir: Directory that receives the clip files.

    Returns:
        None.
    """
    os.makedirs(output_dir, exist_ok=True)

    cap = cv2.VideoCapture(video_path)
    # try/finally blocks make sure the reader and every writer are released
    # even if reading or writing a frame raises.
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')

        for idx, (start, end) in enumerate(segments):
            # Seek to the first frame of the segment before reading it.
            cap.set(cv2.CAP_PROP_POS_FRAMES, start)
            out = cv2.VideoWriter(os.path.join(output_dir, f'swing_{idx}.mp4'), fourcc, fps, (width, height))
            try:
                for _ in range(start, end + 1):
                    ret, frame = cap.read()
                    if not ret:
                        # Source ended early; keep whatever was written so far.
                        break
                    out.write(frame)
            finally:
                out.release()
            print(f"Saved swing_{idx}.mp4 from frame {start} to {end}")
    finally:
        cap.release()

# Run all steps
# 1) label every frame, 2) majority-vote smooth the labels, 3) find
# Address-to-Finish frame ranges, 4) write each range to its own clip file.
frame_results = classify_video(video_path, model, pose)
smoothed = smooth_predictions(frame_results, window_size=4)
segments = extract_address_to_finish_segments(smoothed, min_swing_length=20)
save_swing_clips(video_path, segments, output_dir)


I0000 00:00:1749223794.969482 13029533 gl_context.cc:369] GL version: 2.1 (2.1 Metal - 89.4), renderer: Apple M2
W0000 00:00:1749223795.041648 13297038 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.
W0000 00:00:1749223795.053636 13297038 inference_feedback_manager.cc:114] Feedback manager requires a model with a single signature inference. Disabling support for feedback tensors.


Saved swing_0.mp4 from frame 1484 to 1536
Saved swing_1.mp4 from frame 1541 to 1608
