
# Pickleball Action Recognition — YOLOv8 Pose + LSTM (End-to-End)

This notebook builds a **pickleball action recognizer** using **YOLOv8-Pose** to extract body keypoints and an **LSTM** to classify actions (e.g., *CorrectServe, DriveBackHand, DriveForehand*).  
It includes: data prep, pose extraction, sequence building, training, realtime inference, and export for deployment.

> **Run locally** with internet to `pip install` dependencies. CUDA is optional but recommended.


In [1]:

# ==== 1) Environment ====
# Run once (internet required). Comment out if you've already installed.
# If on Windows + CUDA, ensure correct torch version from https://pytorch.org/get-started/locally/

# %pip install ultralytics==8.2.0 opencv-python torch torchvision torchaudio numpy scikit-learn


In [8]:

# ==== 2) Imports & Global Config ====
import os, glob, math, json, time, collections
import numpy as np
import cv2
from pathlib import Path

try:
    import torch, torch.nn as nn, torch.optim as optim
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    from ultralytics import YOLO
except Exception as e:
    print("If imports fail, please run the pip install cell above.")
    print("Error:", e)

# --- Project configuration ---
# Adjust these for your machine:
DATA_ROOT = r"D:\B3-ICT\Group Project"   # your dataset root containing class folders
OUT_DIR   = "prepared"                     # where to save npy arrays
POSE_MODEL = "yolo11s-pose.pt"   # hoặc "yolo11s-pose.pt", "yolo11m-pose.pt", ...


# Classes (folder names). Update to match your data:
CLASSES = ["Serve", "DriveBackHand", "DriveForehand", "Dink"]

# Sequence parameters
SEQ_LEN = 24
STRIDE  = 6

# Keypoint filtering
MIN_KP_CONF = 0.25

# Training hyperparams
DEVICE = "cuda" if (hasattr(torch, 'cuda') and torch.cuda.is_available()) else "cpu"
EPOCHS = 40
BATCH  = 64
LR     = 1e-3
HIDDEN = 128
LAYERS = 2
DROPOUT = 0.2

print("DEVICE =", DEVICE)
print("DATA_ROOT =", DATA_ROOT)
print("CLASSES =", CLASSES)


DEVICE = cpu
DATA_ROOT = D:\B3-ICT\Group Project
CLASSES = ['Serve', 'DriveBackHand', 'DriveForehand', 'Dink']


In [9]:

# ==== 3) Utility: list videos per class ====

def list_videos(root, classes):
    vids, labels = [], []
    for ci, cname in enumerate(classes):
        class_dir = Path(root) / cname
        for p in class_dir.glob("*.mp4"):
            vids.append(str(p))
            labels.append(ci)
    return vids, labels

vids, labels = list_videos(DATA_ROOT, CLASSES)
print("Found videos:", len(vids))
for v, y in list(zip(vids, labels))[:5]:
    print(y, v)


Found videos: 533
0 D:\B3-ICT\Group Project\Serve\0805(2)-1.mp4
0 D:\B3-ICT\Group Project\Serve\0805(2)-10.mp4
0 D:\B3-ICT\Group Project\Serve\0805(2)-11.mp4
0 D:\B3-ICT\Group Project\Serve\0805(2)-12.mp4
0 D:\B3-ICT\Group Project\Serve\0805(2)-13.mp4


In [10]:

# ==== 4) Pose extraction helpers ====


L_SHOULDER, R_SHOULDER = 5, 6
L_HIP, R_HIP = 11, 12

def center_and_scale(kpts, min_scale=1e-3):

    xy  = kpts[:, :2].astype(np.float32)  # (17,2)
    conf= kpts[:, 2]
    hips = xy[[L_HIP, R_HIP]]
    center = hips.mean(axis=0)
    # scale from shoulders; fallback to hips
    if conf[L_SHOULDER] > 0 and conf[R_SHOULDER] > 0:
        scale = np.linalg.norm(xy[L_SHOULDER] - xy[R_SHOULDER])
    else:
        scale = np.linalg.norm(xy[L_HIP] - xy[R_HIP])
    scale = max(float(scale), min_scale)
    xy_norm = (xy - center) / scale
    return xy_norm

def extract_keypoints_from_frame(results, min_conf=MIN_KP_CONF):
    """Select the largest person and return (17,3) [x,y,conf] or None."""
    if len(results) == 0:
        return None
    r = results[0]
    if r.keypoints is None or r.keypoints.xy is None:
        return None
    xy = r.keypoints.xy
    conf = r.keypoints.conf
    if xy is None or conf is None or len(xy) == 0:
        return None

    # choose the largest bbox by area
    boxes = r.boxes.xyxy.cpu().numpy() if r.boxes is not None else None
    idx = 0
    if boxes is not None and len(boxes) > 0:
        areas = (boxes[:,2]-boxes[:,0])*(boxes[:,3]-boxes[:,1])
        idx = int(np.argmax(areas))

    kxy = xy[idx].cpu().numpy()                 # (17,2)
    kcf = conf[idx].cpu().numpy().reshape(-1,1) # (17,1)
    kpts = np.concatenate([kxy, kcf], axis=1)   # (17,3)

    # too many low-confidence points -> skip
    if (kpts[:,2] < min_conf).mean() > 0.5:
        return None
    return kpts


In [11]:

# ==== 5) Video → per-frame features → sequences ====
def video_to_features(path, yolo, seq_len=SEQ_LEN, stride=STRIDE):
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        print("Failed to open:", path)
        return [], []

    frames_feats = []
    prev_xy = None

    while True:
        ok, frame = cap.read()
        if not ok: break
        results = yolo.predict(frame, verbose=False)
        kpts = extract_keypoints_from_frame(results)
        if kpts is None:
            frames_feats.append(None)
            prev_xy = None
            continue

        xy_norm = center_and_scale(kpts)       
        feat_xy = xy_norm.flatten()            

        if prev_xy is not None:
            vel = (xy_norm - prev_xy).flatten() 
        else:
            vel = np.zeros_like(feat_xy)
        prev_xy = xy_norm.copy()

        feat = np.concatenate([feat_xy, vel], axis=0)  
        frames_feats.append(feat)

    cap.release()

    feats = np.array([f for f in frames_feats if f is not None])
    if len(feats) < seq_len:
        return [], []
    X, idxs = [], []
    for start in range(0, len(feats)-seq_len+1, stride):
        X.append(feats[start:start+seq_len])  
        idxs.append(start)
    return X, idxs

def build_dataset(data_root=DATA_ROOT, out_dir=OUT_DIR, pose_model=POSE_MODEL, classes=CLASSES):
    os.makedirs(out_dir, exist_ok=True)
    yolo = YOLO(pose_model)

    vids, labels = list_videos(data_root, classes)
    X_all, y_all = [], []
    for p, y in zip(vids, labels):
        X, _ = video_to_features(p, yolo)
        for seq in X:
            X_all.append(seq)
            y_all.append(y)
    X_all = np.array(X_all, dtype=np.float32)  # (N,T,68)
    y_all = np.array(y_all, dtype=np.int64)
    np.save(Path(out_dir) / "X.npy", X_all)
    np.save(Path(out_dir) / "y.npy", y_all)
    print("Saved:", X_all.shape, y_all.shape, " to", out_dir)


build_dataset()


Saved: (2735, 24, 68) (2735,)  to prepared


In [12]:

# ==== 6) LSTM model & training ====
class ActionLSTM(nn.Module):
    def __init__(self, in_dim, hidden=128, layers=2, num_classes=3, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(in_dim, hidden, num_layers=layers, batch_first=True, dropout=dropout)
        self.head = nn.Sequential(nn.LayerNorm(hidden), nn.Linear(hidden, num_classes))
    def forward(self, x):               # x: (B,T,F)
        out, _ = self.lstm(x)           # (B,T,H)
        out = out[:, -1, :]
        return self.head(out)           # (B,C)

def batch_iter(X, y, batch=64, shuffle=True):
    idx = np.arange(len(X))
    if shuffle:
        np.random.shuffle(idx)
    for i in range(0, len(X), batch):
        sel = idx[i:i+batch]
        yield torch.from_numpy(X[sel]).float(), torch.from_numpy(y[sel]).long()

def train_lstm(x_path=Path(OUT_DIR)/"X.npy", y_path=Path(OUT_DIR)/"y.npy",
               epochs=EPOCHS, batch=BATCH, lr=LR, hidden=HIDDEN, layers=LAYERS, dropout=DROPOUT):
    X = np.load(x_path)   # (N,T,68)
    y = np.load(y_path)   # (N,)

    Xtr, Xva, ytr, yva = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)

    model = ActionLSTM(in_dim=X.shape[-1], hidden=hidden, layers=layers,
                       num_classes=len(CLASSES), dropout=dropout).to(DEVICE)
    crit = nn.CrossEntropyLoss()
    opt  = optim.AdamW(model.parameters(), lr=lr)

    best_acc, best_path = 0.0, "lstm_best.pt"
    for ep in range(1, epochs+1):
        model.train()
        tot, correct, loss_sum = 0, 0, 0.0
        for xb, yb in batch_iter(Xtr, ytr, batch=batch, shuffle=True):
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            opt.zero_grad()
            logits = model(xb)
            loss = crit(logits, yb)
            loss.backward()
            opt.step()
            loss_sum += float(loss)
            pred = logits.argmax(1)
            tot += len(yb); correct += int((pred==yb).sum())
        train_acc = correct/tot if tot>0 else 0.0

        # validation
        model.eval()
        with torch.no_grad():
            xb = torch.from_numpy(Xva).float().to(DEVICE)
            yb = torch.from_numpy(yva).long().to(DEVICE)
            logits = model(xb)
            va_loss = crit(logits, yb).item()
            va_acc  = (logits.argmax(1)==yb).float().mean().item()

        print(f"Epoch {ep:02d} | train_acc {train_acc:.3f} | val_acc {va_acc:.3f} | val_loss {va_loss:.3f}")
        if va_acc > best_acc:
            best_acc = va_acc
            torch.save({
                "model": model.state_dict(),
                "in_dim": X.shape[-1],
                "hidden": hidden,
                "layers": layers,
                "num_classes": len(CLASSES)
            }, best_path)

    # final report
    ckpt = torch.load(best_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model"]); model.eval()
    with torch.no_grad():
        y_pred = model(torch.from_numpy(Xva).float().to(DEVICE)).argmax(1).cpu().numpy()
    print("\nValidation report:\n", classification_report(yva, y_pred, target_names=CLASSES))

# Uncomment to train after dataset is built:
train_lstm()


Epoch 01 | train_acc 0.859 | val_acc 0.932 | val_loss 0.224
Epoch 02 | train_acc 0.961 | val_acc 0.954 | val_loss 0.119
Epoch 03 | train_acc 0.983 | val_acc 0.973 | val_loss 0.089
Epoch 04 | train_acc 0.986 | val_acc 0.980 | val_loss 0.064
Epoch 05 | train_acc 0.992 | val_acc 0.980 | val_loss 0.060
Epoch 06 | train_acc 0.996 | val_acc 0.985 | val_loss 0.039
Epoch 07 | train_acc 0.994 | val_acc 0.976 | val_loss 0.090
Epoch 08 | train_acc 0.991 | val_acc 0.969 | val_loss 0.088
Epoch 09 | train_acc 0.995 | val_acc 0.989 | val_loss 0.038
Epoch 10 | train_acc 0.999 | val_acc 0.989 | val_loss 0.027
Epoch 11 | train_acc 0.994 | val_acc 0.973 | val_loss 0.088
Epoch 12 | train_acc 0.995 | val_acc 0.991 | val_loss 0.020
Epoch 13 | train_acc 1.000 | val_acc 0.993 | val_loss 0.021
Epoch 14 | train_acc 0.998 | val_acc 0.989 | val_loss 0.033
Epoch 15 | train_acc 0.997 | val_acc 0.987 | val_loss 0.036
Epoch 16 | train_acc 0.997 | val_acc 0.984 | val_loss 0.053
Epoch 17 | train_acc 0.997 | val_acc 0.9

In [17]:

# ==== 7) Realtime (webcam) or video-file inference ====
def load_lstm_checkpoint(path="lstm_best.pt"):
    ckpt = torch.load(path, map_location=DEVICE)
    model = ActionLSTM(ckpt["in_dim"], ckpt["hidden"], ckpt["layers"], ckpt["num_classes"]).to(DEVICE)
    model.load_state_dict(ckpt["model"]); model.eval()
    return model

def infer_webcam(pose_model=POSE_MODEL, lstm_path="lstm_best.pt", seq_len=SEQ_LEN, stride=STRIDE, thresh=0.60):
    pose = YOLO(pose_model)
    model = load_lstm_checkpoint(lstm_path)

    cap = cv2.VideoCapture(0)
    buf = collections.deque(maxlen=seq_len)
    frame_idx, last_pred, smooth = 0, None, None

    while True:
        ok, frame = cap.read()
        if not ok: break
        results = pose.predict(frame, verbose=False)
        kpts = extract_keypoints_from_frame(results)
        if kpts is not None:
            xy_norm = center_and_scale(kpts).flatten()
            if len(buf) > 0:
                prev = (np.array(buf[-1])[:34]).reshape(17,2)
                vel = (xy_norm.reshape(17,2) - prev).flatten()
            else:
                vel = np.zeros_like(xy_norm)
            feat = np.concatenate([xy_norm, vel], axis=0)
            buf.append(feat)

        frame_idx += 1
        if len(buf) == seq_len and frame_idx % stride == 0:
            x = torch.from_numpy(np.expand_dims(np.stack(buf, axis=0), 0)).float().to(DEVICE)
            with torch.no_grad():
                logits = model(x)
                probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
            smooth = probs if smooth is None else 0.6*smooth + 0.4*probs
            cls_id = int(np.argmax(smooth))
            conf = float(smooth[cls_id])
            if conf >= thresh:
                last_pred = (CLASSES[cls_id], conf)

        # draw overlay
        if last_pred:
            text = f"{last_pred[0]}: {last_pred[1]:.2f}"
            cv2.putText(frame, text, (24,40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2)

        cv2.imshow("Pickleball LSTM+YOLOv8", frame)
        if cv2.waitKey(1) & 0xFF == 27: break

    cap.release(); cv2.destroyAllWindows()

def infer_video_file(video_path, pose_model=POSE_MODEL, lstm_path="lstm_best.pt",
                     seq_len=SEQ_LEN, stride=STRIDE, thresh=0.60, display=True, save_path=None):
    pose = YOLO(pose_model)
    model = load_lstm_checkpoint(lstm_path)

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print("Cannot open", video_path); return

    buf = collections.deque(maxlen=seq_len)
    frame_idx, last_pred, smooth = 0, None, None

    fourcc = cv2.VideoWriter_fourcc(*"mp4v") if save_path else None
    out = None
    if save_path:
        w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
        out = cv2.VideoWriter(save_path, fourcc, fps, (w, h))

    while True:
        ok, frame = cap.read()
        if not ok: break
        results = pose.predict(frame, verbose=False)
        kpts = extract_keypoints_from_frame(results)
        if kpts is not None:
            xy_norm = center_and_scale(kpts).flatten()
            if len(buf) > 0:
                prev = (np.array(buf[-1])[:34]).reshape(17,2)
                vel = (xy_norm.reshape(17,2) - prev).flatten()
            else:
                vel = np.zeros_like(xy_norm)
            feat = np.concatenate([xy_norm, vel], axis=0)
            buf.append(feat)

        frame_idx += 1
        if len(buf) == seq_len and frame_idx % stride == 0:
            x = torch.from_numpy(np.expand_dims(np.stack(buf, axis=0), 0)).float().to(DEVICE)
            with torch.no_grad():
                logits = model(x)
                probs = torch.softmax(logits, dim=-1)[0].cpu().numpy()
            smooth = probs if smooth is None else 0.6*smooth + 0.4*probs
            cls_id = int(np.argmax(smooth))
            conf = float(smooth[cls_id])
            if conf >= thresh:
                last_pred = (CLASSES[cls_id], conf)

        if last_pred:
            text = f"{last_pred[0]}: {last_pred[1]:.2f}"
            cv2.putText(frame, text, (24,40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0,255,0), 2)

        if display:
            cv2.imshow("Pickleball LSTM+YOLOv8", frame)
            if cv2.waitKey(1) & 0xFF == 27: break
        if out:
            out.write(frame)

    cap.release()
    if out: out.release()
    if display:
        cv2.destroyAllWindows()

# Example usage (uncomment to run):infer_webcam()
infer_video_file(
    r"D:\B3-ICT\Group Project\Untitled video - Made with Clipchamp.mp4",
    save_path="hihihi.mp4",
    display=False        # <— turn off imshow
)


In [14]:

# ==== 8) Export (TorchScript / ONNX) ====
def export_torchscript(ckpt_path="lstm_best.pt", out_path="lstm_ts.pt", seq_len=SEQ_LEN, feat_dim=68):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model = ActionLSTM(ckpt["in_dim"], ckpt["hidden"], ckpt["layers"], ckpt["num_classes"]).cpu()
    model.load_state_dict(ckpt["model"]); model.eval()
    example = torch.randn(1, seq_len, feat_dim)
    traced = torch.jit.trace(model, example)
    traced.save(out_path)
    print("Saved TorchScript to", out_path)

def export_onnx(ckpt_path="lstm_best.pt", out_path="lstm.onnx", seq_len=SEQ_LEN, feat_dim=68):
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model = ActionLSTM(ckpt["in_dim"], ckpt["hidden"], ckpt["layers"], ckpt["num_classes"]).cpu()
    model.load_state_dict(ckpt["model"]); model.eval()
    x = torch.randn(1, seq_len, feat_dim, requires_grad=False)
    torch.onnx.export(
        model, x, out_path,
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['logits'],
        dynamic_axes={'input': {0: 'batch', 1: 'time'}, 'logits': {0: 'batch'}}
    )
    print("Saved ONNX to", out_path)

# Example:
export_torchscript()
export_onnx()


Saved TorchScript to lstm_ts.pt


  torch.onnx.export(


OnnxExporterError: Module onnx is not installed!


## 9) Tips for Accuracy & Robustness

- **More classes**: Dink, Volley, Smash, Drop, Lob.
- **Ball detector**: train a tiny YOLO just for the ball and append ball (x, y, speed) to features.
- **Better normalization**: rotate skeleton to torso-aligned coordinates.
- **Augmentations**: horizontal flip (if class semantics allow), time-warping, subsequence jitter.
- **Class imbalance**: oversample, use class weights.
- **Post-processing**: hysteresis—require K consecutive steps to start/stop an event.
- **Data splits**: split by player/session to prevent leakage.


In [None]:

# ==== 10) Quickstart (edit paths, then run in order) ====
# 1) Set DATA_ROOT and CLASSES above to match your folders.
# 2) Run: build_dataset()
# 3) Run: train_lstm()
# 4) Run: infer_webcam()  or  infer_video_file("your_clip.mp4", save_path="annotated.mp4")
# 5) Export: export_torchscript() or export_onnx()
