# In [26]:
import pickle
import numpy as np
import pandas as pd
from typing import List, Dict
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Sliding-window geometry: a window of FRAME_WINDOW frames is built from
# FRAMES_BEFORE frames, one centre frame, and (by symmetry) FRAMES_BEFORE frames
# after it. FRAMES_BEFORE and MIDDLE_FRAME_INDEX are not referenced in this cell.
FRAMES_BEFORE = 3
FRAME_WINDOW = 2 * FRAMES_BEFORE + 1
MIDDLE_FRAME_INDEX = FRAMES_BEFORE

def get_bottom_left_ball(ball_data: List) -> List[float]:
    """Pick the left-most ball box (ties broken by the larger y) and return it.

    Each entry of ``ball_data`` is expected to start with ``x1, y1, x2, y2``;
    any extra trailing values (e.g. confidence/class) are discarded. Entries
    with fewer than 4 values are ignored.

    The box with the smallest ``x1`` wins; among equal ``x1`` the one with the
    largest ``y1`` wins. NOTE(review): "bottom-left" assumes image coordinates
    where y grows downward — confirm against the detector output.

    Args:
        ball_data: Sequence (list, tuple or 2-D array) of per-detection
            boxes, or ``None``/empty when no ball was detected.

    Returns:
        A list ``[x1, y1, x2, y2]`` for the chosen box, or ``[0, 0, 0, 0]``
        when there is no usable detection.
    """
    # len() instead of truthiness so a numpy array input does not raise.
    if ball_data is None or len(ball_data) == 0:
        return [0, 0, 0, 0]  # No ball detected, return zeros

    # Keep only the first 4 coordinates of entries that have at least 4 values.
    processed_data = [list(bbox[:4]) for bbox in ball_data if len(bbox) >= 4]

    if not processed_data:
        return [0, 0, 0, 0]  # Return zeros if no valid data is found

    ball_array = np.array(processed_data, dtype=float)
    # lexsort uses the LAST key as primary: sort by x1 ascending, then -y1
    # ascending (i.e. y1 descending) to break ties.
    bottom_left_idx = np.lexsort((-ball_array[:, 1], ball_array[:, 0]))[0]

    # Index the *filtered* list: indexing the original ball_data here would
    # point at the wrong box whenever a short entry was dropped above.
    return processed_data[bottom_left_idx]

def get_bottom_left_player(keypoints) -> np.ndarray:
    """Return the keypoints of the left-most player (ties: lowest on screen).

    Each player's position is the mean of its keypoint coordinates. The player
    with the smallest mean x is chosen; among equal mean x, the one with the
    largest mean y.

    Args:
        keypoints: Object exposing ``.xy``; presumably shaped
            (num_people, num_keypoints, 2) — TODO confirm against the pose
            model output.

    Returns:
        ``keypoints.xy[k]`` for the selected player, or a ``(17, 2)`` zero
        array when no player is present (17 presumably = COCO keypoint count).

    NOTE(review): undetected keypoints may be reported as (0, 0), which would
    pull the mean toward the origin — verify with the pose model.
    """
    if keypoints.xy.shape[0] == 0:
        return np.zeros((17, 2))

    centers = np.mean(keypoints.xy, axis=1)
    xs = [c[0] for c in centers]
    neg_ys = [-c[1] for c in centers]
    # lexsort's last key is the primary key: x ascending, then -y ascending.
    order = np.lexsort((neg_ys, xs))
    chosen = order[0]
    return keypoints.xy[chosen]

def extract_features(all_data: Dict, frame_numbers: List[int]) -> List[float]:
    """Flatten pose and ball features for a sequence of frames into one row.

    For each frame, in order, this appends:
      * the flattened keypoints of the bottom-left player, or 34 zeros
        (17 keypoints x 2) when the frame has no pose entry;
      * the 4 bounding-box values of the bottom-left ball, or 4 zeros when
        the frame has no ball entry.

    Args:
        all_data: Mapping with ``'pose'`` and ``'ball'`` sub-mappings keyed by
            frame number.
        frame_numbers: Frames to encode, in output order.

    Returns:
        A flat list of ``len(frame_numbers) * 38`` values, assuming the pose
        model yields 17 keypoints per player (TODO confirm).
    """
    features: List[float] = []
    for i in frame_numbers:
        # Process pose data
        if i in all_data['pose']:
            keypoints = all_data['pose'][i]
            bottom_left_keypoints = get_bottom_left_player(keypoints)
            pose_features = bottom_left_keypoints.flatten().tolist()
        else:
            pose_features = [0] * (17 * 2)  # No pose data available, fill with zeros

        # Process ball data
        if i in all_data['ball']:
            ball_data = all_data['ball'][i]
            ball_features = get_bottom_left_ball(ball_data)
        else:
            ball_features = [0] * 4  # No ball data available, fill with zeros

        # extend() accepts any iterable. `pose_features + ball_features` would
        # raise for a tuple and silently broadcast (or fail) for an ndarray.
        features.extend(pose_features)
        features.extend(ball_features)

    return features

def generate_inference_data(all_data: Dict) -> List[List[float]]:
    """Build one feature row per sliding window of ``FRAME_WINDOW`` frames.

    The frames are the sorted union of frame numbers present in
    ``all_data['pose']`` or ``all_data['ball']``. Windows advance by one frame,
    and each window is encoded with ``extract_features``.

    NOTE(review): windows slide over *available* frame keys. If some frame
    numbers are missing from both mappings, a window can span non-consecutive
    frames. Confirm this matches how the training data was built.

    Returns:
        ``max(0, n_frames - FRAME_WINDOW + 1)`` feature rows, so an empty list
        when fewer than ``FRAME_WINDOW`` frames are available.
    """
    frames = sorted(set(all_data['pose'].keys()).union(all_data['ball'].keys()))
    last_start = len(frames) - FRAME_WINDOW
    return [
        extract_features(all_data, frames[start:start + FRAME_WINDOW])
        for start in range(last_start + 1)
    ]

# Load the per-frame pose/ball detections for inference.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load pickles from a trusted source.
with open('inference_all_data_ex4.pkl', 'rb') as f:
    inference_data = pickle.load(f)

# One row per sliding window, tabulated and written out without the index column.
inference_features = generate_inference_data(inference_data)
df_inference = pd.DataFrame(inference_features)
df_inference.to_csv('padel_shots_inference_dataset.csv', index=False)

logger.info(f"Inference dataset shape: {df_inference.shape}")

# Output: INFO:__main__:Inference dataset shape: (1454, 266)
