In [1]:
# Install dependencies into the running kernel's environment (%pip, not a bare `pip`), then restart the kernel.
%pip install opencv-python mediapipe numpy tqdm ultralytics --quiet

Note: you may need to restart the kernel to use updated packages.


## Imports and model setup (YOLOv8 person detector + MediaPipe Pose)

In [1]:
import os
import cv2
import csv
import numpy as np
from tqdm import tqdm
from ultralytics import YOLO
import mediapipe as mp

# Initialize MediaPipe Pose (plus its drawing utilities) and the YOLOv8 person detector.
mp_pose = mp.solutions.pose
# Landmark drawing helpers; referenced by the pose-visualization function below.
mp_drawing = mp.solutions.drawing_utils
# Shared Pose instance with static_image_mode=False, i.e. inputs are treated as a video stream.
pose = mp_pose.Pose(static_image_mode=False)
yolo = YOLO("yolov8n.pt")  # You can use yolov8s.pt for higher accuracy


## Trim Videos

In [2]:
# This function trims the video by 0.5 seconds at the beginning and 1 second at the end.
# Values can be altered through the trim_start_sec and trim_end_sec respectively.
def trim_video(input_path, output_path, trim_start_sec=0.5, trim_end_sec=1.0):
    """Write a copy of `input_path` to `output_path` with frames dropped from both ends.

    Args:
        input_path: Source video file.
        output_path: Destination video file, encoded with the 'mp4v' FourCC at the source
            frame rate and frame size.
        trim_start_sec: Seconds to drop from the start.
        trim_end_sec: Seconds to drop from the end. The end point is computed from the
            container's reported frame count (CAP_PROP_FRAME_COUNT).

    Raises:
        IOError: If the input cannot be opened or the output writer cannot be created.
        ValueError: If the input reports a non-positive frame rate (seconds cannot be
            converted to frames).
    """
    cap = cv2.VideoCapture(input_path)
    out = None
    try:
        if not cap.isOpened():
            raise IOError(f"Could not open video: {input_path}")

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            raise ValueError(f"Invalid frame rate ({fps}) reported by: {input_path}")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Keep frames in [start_frame, end_frame).
        start_frame = max(0, int(trim_start_sec * fps))
        end_frame = total_frames - int(trim_end_sec * fps)

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not out.isOpened():
            raise IOError(f"Could not create output video: {output_path}")

        # Stop as soon as the trimmed range is exhausted instead of decoding the discarded tail.
        frame_idx = 0
        while frame_idx < end_frame:
            ret, frame = cap.read()
            if not ret:
                break
            if frame_idx >= start_frame:
                out.write(frame)
            frame_idx += 1
    finally:
        cap.release()
        if out is not None:
            out.release()



## Extract per-frame pose keypoints to CSV

In [7]:
def _collect_videos(input_root):
    """List every .mp4 under `input_root` as (path, class_label, file_name) tuples.

    The class label is the name of the directory that directly contains the video.
    Directories and files are visited in sorted order so the CSV row order is reproducible
    (os.walk's own order depends on the filesystem). The extension match is case-insensitive.
    """
    entries = []
    for root, dirs, files in os.walk(input_root):
        dirs.sort()  # sorting in place steers os.walk's traversal order
        for file in sorted(files):
            if file.lower().endswith(".mp4"):
                entries.append((os.path.join(root, file), os.path.basename(root), file))
    return entries


def _topmost_person_crop(frame):
    """Return `frame` cropped to the YOLO 'person' box with the smallest y1, or None.

    The top-most person is the heuristic used to select the (presumed) batter.
    """
    # Each detection row is (x1, y1, x2, y2, conf, cls); class 0 = person.
    detections = yolo(frame, verbose=False)[0].boxes.data.cpu().numpy()
    person_boxes = [d for d in detections if int(d[5]) == 0]
    if not person_boxes:
        return None
    x1, y1, x2, y2 = (int(v) for v in min(person_boxes, key=lambda b: b[1])[:4])
    # Clamp the top-left corner: a negative slice index would wrap around instead of clipping.
    cropped = frame[max(0, y1):y2, max(0, x1):x2]
    return cropped if cropped.size else None


def _extract_video_keypoints(video_path, num_landmarks):
    """Run person detection + MediaPipe Pose on every frame of one video.

    Returns a list of (frame_idx, keypoints) pairs for frames where a pose was found.
    frame_idx is 1-based; keypoints is a flat [x, y, z, x, y, z, ...] list for the first
    `num_landmarks` landmarks, relative to the person crop (not the full frame).
    """
    results = []
    cap = cv2.VideoCapture(video_path)
    video_pose = None
    try:
        if not cap.isOpened():
            tqdm.write(f"⚠️ Could not open video, skipping: {video_path}")
            return results
        # A fresh tracking-mode instance per video so state from one clip cannot carry into the next.
        video_pose = mp_pose.Pose(static_image_mode=False)

        frame_idx = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frame_idx += 1

            cropped = _topmost_person_crop(frame)
            if cropped is None:
                continue

            result = video_pose.process(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
            if result.pose_landmarks:
                keypoints = [
                    coord
                    for lm in result.pose_landmarks.landmark[:num_landmarks]
                    for coord in (lm.x, lm.y, lm.z)
                ]
                results.append((frame_idx, keypoints))
    finally:
        cap.release()
        if video_pose is not None:
            video_pose.close()
    return results


def process_all_frames_to_csv(input_root, output_csv_path, num_landmarks=17):
    """Extract per-frame pose keypoints from every .mp4 under `input_root` into one CSV.

    Each video's parent folder name is used as its class label. For every frame, the
    top-most YOLO 'person' box is cropped and passed to MediaPipe Pose; frames without a
    detected person or pose are skipped.

    Args:
        input_root: Root folder searched recursively for .mp4 files.
        output_csv_path: Destination CSV (overwritten if it exists).
        num_landmarks: Number of landmarks kept per frame, counted from index 0 (3 columns
            each). NOTE(review): in MediaPipe's 33-landmark pose model, indices 0-16 are
            the face, shoulders, elbows and wrists only, so the default drops hips, knees
            and ankles. Confirm this is intended; pass 33 to keep every landmark.

    CSV columns: kp_0 ... kp_{3*num_landmarks-1} (x, y, z interleaved per landmark),
    then label, video (file name) and frame (1-based index). No file is written if
    nothing was extracted.
    """
    rows = []
    for video_path, label, file in tqdm(_collect_videos(input_root), desc="Processing videos"):
        for frame_idx, keypoints in _extract_video_keypoints(video_path, num_landmarks):
            rows.append(keypoints + [label, file, frame_idx])

    if not rows:
        print("⚠️ No keypoints extracted. CSV not created.")
        return

    header = [f'kp_{i}' for i in range(len(rows[0]) - 3)] + ['label', 'video', 'frame']
    # Explicit UTF-8: labels/file names come from the filesystem and may be non-ASCII.
    with open(output_csv_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(header)
        writer.writerows(rows)
    print(f"✅ Saved {len(rows)} frames to: {output_csv_path}")


In [None]:
# Input root: each subfolder's name becomes the class label of the .mp4 clips inside it.
VIDEOS_ROOT = "../Processed Videos"
KEYPOINTS_CSV_PATH = "keypoints_per_frame.csv"
process_all_frames_to_csv(VIDEOS_ROOT, KEYPOINTS_CSV_PATH)

## Sanity check: YOLO person selection and MediaPipe pose on a single video

In [5]:
# This function serves as a check to see if YOLOv8 is taking the correct person.
def detect_batter_and_extract_pose(video_path, max_frames=150, show=True):
    """Run the batter-selection + pose pipeline on one video and optionally plot each result.

    For each of the first `max_frames` frames, the YOLO 'person' box with the smallest top
    edge (y1) is cropped and passed to a MediaPipe Pose instance created for this call.

    Args:
        video_path: Video file to inspect.
        max_frames: Maximum number of frames to read.
        show: If True, plot every crop that yields a pose, with landmarks drawn
            (one figure per frame).

    Returns:
        List of MediaPipe `pose_landmarks` results, one per frame where a pose was found
        (coordinates are relative to the crop, not the full frame).
    """
    if show:
        # Imported locally so this function does not depend on `plt` / `mp_drawing`
        # being defined elsewhere in the notebook.
        import matplotlib.pyplot as plt
        mp_drawing = mp.solutions.drawing_utils

    cap = cv2.VideoCapture(video_path)
    pose = mp_pose.Pose(static_image_mode=False)
    frame_count = 0
    pose_results = []

    try:
        while cap.isOpened() and frame_count < max_frames:
            ret, frame = cap.read()
            if not ret:
                break
            frame_count += 1

            # Run YOLO detection; each row is (x1, y1, x2, y2, conf, cls), class 0 = person.
            detections = yolo(frame, verbose=False)[0].boxes.data.cpu().numpy()
            person_boxes = [d for d in detections if int(d[5]) == 0]
            if not person_boxes:
                continue

            # Highest person in frame (smallest y1) is taken as the batter.
            x1, y1, x2, y2 = (int(v) for v in min(person_boxes, key=lambda b: b[1])[:4])
            # Clamp the top-left corner: a negative slice index would wrap around.
            cropped = frame[max(0, y1):y2, max(0, x1):x2]
            if cropped.size == 0:
                continue

            cropped_rgb = cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB)
            result = pose.process(cropped_rgb)
            if not result.pose_landmarks:
                continue
            pose_results.append(result.pose_landmarks)

            if show:
                annotated = cropped_rgb.copy()
                mp_drawing.draw_landmarks(
                    annotated,
                    result.pose_landmarks,
                    mp_pose.POSE_CONNECTIONS
                )
                fig = plt.figure(figsize=(10, 6))
                plt.imshow(annotated)
                plt.title(f"Batter Pose @ Frame {frame_count}")
                plt.axis('off')
                plt.show()
                plt.close(fig)  # don't accumulate up to `max_frames` open figures
    finally:
        cap.release()
        pose.close()
    return pose_results

In [None]:
# Visually check person selection on one sample clip. Bind the result so the notebook
# does not render the whole list of landmark objects as cell output.
batter_landmarks = detect_batter_and_extract_pose("../CKT Dataset/Cover Drive/(100).mp4")
print(f"Pose found in {len(batter_landmarks)} frames")