---
title: "Swing DataClass"
author: "Ali Zaidi"
date: "2025-11-17"
categories: [Data Engineering]
description: "Now that we have the resources to find the clipping indexes for a video, let's expand our capabilities by introducing a data class that can handle and organize a lot of this information"
format:
  html:
    code-fold: true
jupyter: python3
---

In [1]:
#| include: false
from fastai.vision.all import *

In [2]:
#| include: false
base_path = '../../../data/full_videos/ymirza'
swing_days = ['jun8', 'aug9', 'sep14']
parent_dir = f'{base_path}/{swing_days[-1]}'
files = [file for file in get_files(parent_dir, extensions='.pkl') if file.name[:3] == 'IMG']
file_names = [file.name.split('.')[0] for file in files]
len(file_names)

47

### We need functionality that stores a swing's keypoints and computes and stores various components of interest. Keeping these functions in a data class helps us model and plot quickly, so we can extract useful information about how each component evolves as the swing progresses:

- The angle of the hips and shoulders
- The angles of the joints (wrist, elbow and hand)
- Club head information

#### With generic, reusable functions, we can compute the distances and angles between components and try to extract meaningful information from their relationships as the swing progresses
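
The measurements listed above boil down to two reusable operations on joint tracks: the distance between two joints and the angle at a joint formed by three points (for example shoulder, elbow, wrist for the elbow angle). A minimal sketch, assuming each joint is an (n_frames, 3) array of (x, y, score) like the per-joint attributes built later in this post; `joint_distance` and `joint_angle` are hypothetical helper names, not part of the classes here, and the toy coordinates are made up:

```python
import numpy as np

def joint_distance(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Per-frame Euclidean distance between two joint tracks.

    a, b: (n_frames, 3) arrays of (x, y, score); only x and y are used.
    """
    return np.linalg.norm(a[:, :2] - b[:, :2], axis=1)

def joint_angle(a: np.ndarray, b: np.ndarray, c: np.ndarray) -> np.ndarray:
    """Per-frame angle in degrees at vertex b, formed by the points a-b-c."""
    ba = a[:, :2] - b[:, :2]
    bc = c[:, :2] - b[:, :2]
    cos = np.sum(ba * bc, axis=1) / (np.linalg.norm(ba, axis=1) * np.linalg.norm(bc, axis=1))
    # Clip to guard against tiny floating point overshoots outside [-1, 1]
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))

# Toy arm over two frames: straight in frame 0, bent at a right angle in frame 1
shoulder = np.array([[0.0, 0.0, 0.9], [0.0, 0.0, 0.9]])
elbow = np.array([[1.0, 0.0, 0.9], [1.0, 0.0, 0.9]])
wrist = np.array([[2.0, 0.0, 0.9], [1.0, 1.0, 0.9]])

print(joint_distance(shoulder, elbow))      # upper arm length per frame (1.0 in both)
print(joint_angle(shoulder, elbow, wrist))  # roughly 180 then 90 degrees
```

Working on whole (n_frames, 3) arrays rather than single frames keeps each helper a few vectorized NumPy calls, and the result is one value per frame that can be plotted directly against the swing timeline.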

In [3]:
from typing import TypeVar, Generic, List, Optional, Callable
from dataclasses import dataclass, field
import numpy as np
import pickle
import cv2

In [4]:
#| code-fold: true
from pathlib import Path

class SwingMetaData:
    """Data container for a swing video's metadata.

    Expects file names like 'IMG_1093_swing_6_score_None.pkl'.
    """
    def __init__(self,
                 path,
                 get_swing_idx=False,
                 get_score=False,
                 ):
        """
        Initialize the swing data and associated metadata

        Args:
            path: path (str or Path) to the keypoint .pkl file
            get_swing_idx: parse the swing index from the file name
            get_score: parse the score from the file name
        """
        self.path = path
        self.str_path = self.get_str_path()
        # Same folder and stem as the .pkl file, with an .mp4 extension
        self.video_path = str(Path(self.str_path).with_suffix('.mp4'))
        # File name without folder or extension, e.g. 'IMG_1093_swing_6_score_None'
        self.file_name = Path(self.str_path).stem
        if get_swing_idx:
            self.swing_idx = int(self.file_name.split('_')[-3])
        if get_score:
            self.score = self.file_name.split('_')[-1]

    def get_str_path(self):
        # Make sure the path is a str, not a Path object
        return str(self.path)

    def get_video_name(self):
        # e.g. 'IMG_1093_swing_6_score_None' -> 'IMG_1093'
        return '_'.join(self.file_name.split('_')[:2])
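
SwingMetaData relies on a naming convention of the form `IMG_<number>_swing_<index>_score_<score>` and picks fields by position after splitting on underscores. If a name ever deviates from that pattern, positional indexing can quietly return the wrong field. A minimal sketch of a stricter alternative, assuming every file follows the convention of names like `IMG_1093_swing_6_score_None`; `parse_swing_name`, `SwingName` and the regex are hypothetical, not part of this post's classes:

```python
import re
from typing import NamedTuple, Optional

# Hypothetical pattern for names like "IMG_1093_swing_6_score_None"
SWING_NAME = re.compile(r"^(?P<video>IMG_\d+)_swing_(?P<swing>\d+)_score_(?P<score>\w+)$")

class SwingName(NamedTuple):
    video: str
    swing_idx: int
    score: Optional[str]

def parse_swing_name(file_name: str) -> SwingName:
    """Parse a swing file name (no folder, no extension) into its parts."""
    match = SWING_NAME.match(file_name)
    if match is None:
        # Fail loudly instead of silently grabbing the wrong field
        raise ValueError(f"Unexpected swing file name: {file_name!r}")
    score = match["score"]
    # Treat the literal string "None" as a missing score
    return SwingName(match["video"], int(match["swing"]), None if score == "None" else score)

print(parse_swing_name("IMG_1093_swing_6_score_None"))
# SwingName(video='IMG_1093', swing_idx=6, score=None)
```

The trade-off is that the regex must be updated if the naming scheme changes, but any file that does not match raises immediately rather than producing a bad swing index.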

In [8]:
#| echo: false
swing_meta = SwingMetaData(files[0], get_swing_idx=True, get_score=True)
print(f'Our meta class can store useful information like the filename \n --->\
 {swing_meta.file_name}')
print(f'Also the swing index \n --->\
 {swing_meta.swing_idx}')
print(f'And the score (when added) \n --->\
 {swing_meta.score}')
print(f'For plotting and output purposes, a path to our labeled video is also available')

Our meta class can store useful information like the filename 
 ---> IMG_1093_swing_6_score_None
Also the swing index 
 ---> 6
And the score (when added) 
 ---> None
For plotting and output purposes, a path to our labeled video is also available


In [9]:
#| code-fold: true
class SwingKeypointData(SwingMetaData):
    """
    Class to handle keypoint data.
    Takes the path to a swing's keypoint .pkl file and loads the keypoint values and scores.
    """

    def __init__(self, path):
        super().__init__(path)
        self.kp_dicts = self.get_kp_dicts()
        self.key_points = self.get_keypoints()  # (n_frames, 17, 2): x, y
        self.scores = self.get_scores()         # (n_frames, 17): confidence
        # (n_frames, 17, 3): x, y, score
        self.kps = np.concatenate([self.key_points, np.expand_dims(self.scores, -1)], axis=2)

    def get_kp_dicts(self):
        # Load the frame-by-frame output dicts from the pose estimation model
        with open(self.str_path, 'rb') as f:
            loaded_dicts = pickle.load(f)
        return loaded_dicts

    def get_keypoints(self):
        return np.stack([frame['keypoints'] for frame in self.kp_dicts.values()])

    def get_scores(self):
        return np.stack([frame['keypoint_scores'] for frame in self.kp_dicts.values()])

    def get_frame(self, idx):
        '''Return a single frame's keypoints and scores as a (17, 3) array'''
        return np.column_stack((self.key_points[idx], self.scores[idx]))

    def get_frames(self, indexes):
        '''
        Take a list of frame indexes and return an array of shape [num_idxs, 17, 3],
        where num_idxs = len(indexes):
        [1] 17 keypoint markers
        [2] 3 values per keypoint (X, Y, confidence score)
        '''
        return np.stack([self.get_frame(idx) for idx in indexes])

    def __len__(self):
        return len(self.key_points)
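
SwingKeypointData assumes each .pkl file holds one dict per frame with `'keypoints'` and `'keypoint_scores'` entries, which is exactly what `get_keypoints` and `get_scores` read. A minimal sketch that writes a fake file with that layout and repeats the stacking steps, so the resulting shapes can be checked without real data. The frame-indexed outer dict, the five frames and the random values are assumptions for illustration; the 17 joints match the COCO layout used in this post:

```python
import pickle
import tempfile
from pathlib import Path

import numpy as np

# Assumed layout: {frame_idx: {"keypoints": (17, 2), "keypoint_scores": (17,)}}
rng = np.random.default_rng(0)
fake_dicts = {
    i: {"keypoints": rng.uniform(0, 1080, size=(17, 2)),
        "keypoint_scores": rng.uniform(0, 1, size=17)}
    for i in range(5)
}

# Round-trip through a temporary .pkl file, like the files loaded in this post
with tempfile.TemporaryDirectory() as tmp:
    pkl_path = Path(tmp) / "IMG_0000_swing_0_score_None.pkl"
    with open(pkl_path, "wb") as f:
        pickle.dump(fake_dicts, f)
    with open(pkl_path, "rb") as f:
        loaded = pickle.load(f)

# Same stacking as get_keypoints, get_scores and the kps concatenation
key_points = np.stack([d["keypoints"] for d in loaded.values()])
scores = np.stack([d["keypoint_scores"] for d in loaded.values()])
kps = np.concatenate([key_points, scores[..., None]], axis=-1)

print(key_points.shape, scores.shape, kps.shape)  # (5, 17, 2) (5, 17) (5, 17, 3)
```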

In [10]:
#| echo: false
skpdata = SwingKeypointData(files[0])
print('With our SwingKeypointData class, we can store the raw keypoint values')
print(f"A swing's full keypoint/confidence array has shape: {skpdata.kps.shape}")
print(f'If we only want the keypoints, we can access \n\
    the "key_points" attribute: {skpdata.key_points.shape}')
print(f'If we only want the confidence values, we can access \n\
    the "scores" attribute: {skpdata.scores.shape}')

With our SwingKeypointData class, we can store the raw keypoint values
A swing's full keypoint/confidence array has shape: (180, 17, 3)
If we only want the keypoints, we can access 
    the "key_points" attribute: (180, 17, 2)
If we only want the confidence values, we can access 
    the "scores" attribute: (180, 17)


### Now we can create a class that inherits from SwingMetaData, wraps SwingKeypointData, and stores a richer representation of our data

In [11]:
#| code-fold: true
class KpExtractor(SwingMetaData):
    """Wraps SwingKeypointData and exposes each COCO joint as an attribute."""
    def __init__(self,
                 path,
                 get_swing_idx=False,
                 get_score=False,
                 score_threshold=None):
        super().__init__(path,
                         get_swing_idx,
                         get_score)
        self.keypoint_data = SwingKeypointData(path)
        self.score_threshold = score_threshold
        self.kps = self.threshold_score(self.keypoint_data, score_threshold)

        # COCO-17 keypoint indices for the joints we care about
        self.coco_idxs = {"L_SH": 5, "R_SH": 6, "L_EL": 7, "R_EL": 8,
                          "L_WR": 9, "R_WR": 10, "L_HI": 11, "R_HI": 12,  # HI = hip
                          "L_KN": 13, "R_KN": 14, "L_ANK": 15, "R_ANK": 16}

        # Dynamically create one (n_frames, 3) attribute per joint, e.g. self.l_sh
        for attr_name, coco_key in self.coco_idxs.items():
            kp_val = self.kps[:, coco_key, :].astype(float).copy()
            setattr(self, attr_name.lower(), kp_val)  # stored lowercase only

    def threshold_score(self, keypoint_data, threshold_value=0.5):
        kps = keypoint_data.kps.astype(float).copy()
        if threshold_value is None:
            return kps
        # Raise any score below the threshold up to the threshold;
        # the x, y coordinates are left untouched
        mask = kps[..., 2] < threshold_value
        kps[mask, 2] = threshold_value
        return kps

    # Allow case-insensitive joint access, e.g. self.L_SH -> self.l_sh
    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; retry in lowercase
        try:
            return object.__getattribute__(self, name.lower())
        except AttributeError:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") from None
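
Note that `threshold_score` does not discard low-confidence keypoints: it raises any score below the threshold up to the threshold and leaves the coordinates alone, so it acts as a floor on confidence. A small check with made-up numbers; `floor_scores` is a hypothetical standalone copy of the method's mask logic, and `np.maximum` is shown as an equivalent formulation:

```python
import numpy as np

def floor_scores(kps: np.ndarray, threshold: float) -> np.ndarray:
    """Standalone copy of threshold_score's logic: raise scores below threshold."""
    out = kps.astype(float).copy()
    mask = out[..., 2] < threshold  # (n_frames, n_joints) boolean mask
    out[mask, 2] = threshold
    return out

# One frame, two joints, each (x, y, score)
kps = np.array([[[10.0, 20.0, 0.2], [30.0, 40.0, 0.8]]])
floored = floor_scores(kps, 0.5)
print(floored[..., 2])  # [[0.5 0.8]]

# Equivalent one-liner on the score channel
alt = kps.copy()
alt[..., 2] = np.maximum(alt[..., 2], 0.5)
print(np.array_equal(floored, alt))  # True
```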

In [86]:
#| echo: false
kp_0 = KpExtractor(files[0])
print('With our KpExtractor class, we can interrogate our data more effectively')
print(f'The full keypoint/confidence array is still available via ".kps" \n\
    --> here is the shape of the keypoints: {kp_0.kps.shape}')
print(f'To access a single joint, use the ".X_YY" pattern; \n\
    here is the array shape of the left shoulder: {kp_0.L_SH.shape}\n\
    where "X" is L (left) or R (right) and \n\
    "YY" is the joint code (SH, EL, WR, HI, KN, ANK), e.g. "L_WR" --> left wrist')

With our KpExtractor class, we can interrogate our data more effectively
The full keypoint/confidence array is still available via ".kps" 
    --> here is the shape of the keypoints: (180, 17, 3)
To access a single joint, use the ".X_YY" pattern; 
    here is the array shape of the left shoulder: (180, 3)
    where "X" is L (left) or R (right) and 
    "YY" is the joint code (SH, EL, WR, HI, KN, ANK), e.g. "L_WR" --> left wrist


In [87]:
l_sh = kp_0.l_sh
r_sh = kp_0.r_sh
l_wr = kp_0.l_wr
r_wr = kp_0.r_wr
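
The lowercase names used in this cell are the attributes actually stored on the instance, while uppercase names such as `.L_SH` work only because `__getattr__` runs as a fallback after normal lookup fails and retries the lowercase name. A minimal stand-in that isolates that mechanism; `CaseInsensitiveJoints` is a hypothetical class and the zero-filled array is made up:

```python
import numpy as np

class CaseInsensitiveJoints:
    """Hypothetical stand-in for KpExtractor's attribute handling."""
    def __init__(self, kps: np.ndarray, coco_idxs: dict):
        # Store one (n_frames, 3) array per joint under a lowercase name
        for name, idx in coco_idxs.items():
            setattr(self, name.lower(), kps[:, idx, :])

    def __getattr__(self, name):
        # Python calls this only when normal lookup fails, e.g. for "L_SH"
        try:
            return object.__getattribute__(self, name.lower())
        except AttributeError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") from None

joints = CaseInsensitiveJoints(np.zeros((180, 17, 3)), {"L_SH": 5, "R_SH": 6})
print(joints.L_SH.shape, joints.l_sh is joints.L_SH)  # (180, 3) True
```

Because `object.__getattribute__` bypasses `__getattr__`, a joint that was never stored (for example `.L_KN` here) raises a normal `AttributeError` instead of recursing.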

In [88]:
l_sh.shape, l_sh[:, 0].shape

((180, 3), (180,))
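
With per-joint arrays like `l_sh` and `r_sh` in hand, the hip and shoulder angles from the list at the top of this post take only a few lines of NumPy. A minimal sketch, assuming the (n_frames, 3) layout of (x, y, score) and image coordinates where y grows downward; `segment_angle` is a hypothetical helper and the toy inputs are made up:

```python
import numpy as np

def segment_angle(left: np.ndarray, right: np.ndarray) -> np.ndarray:
    """Per-frame angle in degrees of the line from `left` to `right`,
    measured from the image x-axis. Inputs are (n_frames, 3) arrays of (x, y, score).
    With image coordinates (y down), a positive angle means `right` sits lower in the frame."""
    dx = right[:, 0] - left[:, 0]
    dy = right[:, 1] - left[:, 1]
    # arctan2 handles dx == 0 and keeps the correct quadrant
    return np.degrees(np.arctan2(dy, dx))

# Toy shoulders over two frames: level, then tilted down to the right
left_sh = np.array([[100.0, 200.0, 0.9], [100.0, 200.0, 0.9]])
right_sh = np.array([[200.0, 200.0, 0.9], [200.0, 300.0, 0.9]])
print(segment_angle(left_sh, right_sh))  # roughly 0 then 45 degrees
```

On real data, `segment_angle(kp_0.l_sh, kp_0.r_sh)` and `segment_angle(kp_0.l_hi, kp_0.r_hi)` would give one shoulder-line and one hip-line angle per frame (180 values each for this swing). These are angles in the 2D image plane, so they are a projection of the body's rotation rather than a direct measurement of it.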