In [14]:
from inference import get_model
import supervision as sv
from pathlib import Path
from tqdm import tqdm

api_key = "Ri8TOlA44FZ62FfabnTt"

def process_image(image_path, models):
    results = {}
    for model_name, model in models.items():
        model_results = model.infer(str(image_path))[0]
        if model_name == 'pose':
            results[model_name] = sv.KeyPoints.from_inference(model_results)
        elif model_name == 'ball':
            detections = sv.Detections.from_inference(model_results)
            results[model_name] = [(box.tolist(), conf) for box, conf in zip(detections.xyxy, detections.confidence)]
    return results

def process_frames(base_path='ex4', models=None):
    if models is None:
        models = {'pose': pose_model, 'ball': ball_model}
    
    data = {model_name: {} for model_name in models}
    
    base_path = Path(base_path)
    
    # Count total PNG files
    png_files = list(base_path.glob('*.png'))
    total_files = len(png_files)
    
    print(f"Total PNG files to process: {total_files}")
    
    with tqdm(total=total_files, desc="Processing frames") as pbar:
        for png_file in png_files:
            frame_number = int(png_file.stem)
            try:
                results = process_image(png_file, models)
                for model_name, model_results in results.items():
                    data[model_name][frame_number] = model_results
            except Exception as e:
                print(f"Error processing {png_file}: {str(e)}")
            pbar.update(1)  # Update once per image
    
    return data

# Initialize the models
pose_model = get_model(model_id="yolov8x-pose-1280", api_key=api_key)
ball_model = get_model(model_id="padel_ball/3", api_key=api_key)

# Run the processing
all_data = process_frames()

# Print some statistics
print(f"\nTotal pose frames: {len(all_data['pose'])}")
print(f"Total ball frames: {len(all_data['ball'])}")

# Print a few sample entries
print("\nSample pose data:")
for frame_id, pose_data in list(all_data['pose'].items())[:3]:
    print(f"Frame {frame_id}: {pose_data}")

print("\nSample ball data:")
for frame_id, ball_data in list(all_data['ball'].items())[:3]:
    print(f"Frame {frame_id}: {ball_data}")

Total PNG files to process: 1460


Processing frames: 100%|██████████| 1460/1460 [19:26<00:00,  1.25it/s]


Total pose frames: 1460
Total ball frames: 1460

Sample pose data:
Frame 348: KeyPoints(xy=array([[[1093.,  353.],
        [1096.,  350.],
        [1092.,  350.],
        [1100.,  354.],
        [1096.,  352.],
        [1112.,  378.],
        [1088.,  370.],
        [1124.,  410.],
        [1074.,  384.],
        [1099.,  431.],
        [1062.,  385.],
        [1108.,  430.],
        [1094.,  426.],
        [1070.,  467.],
        [1081.,  460.],
        [1077.,  510.],
        [1105.,  491.]]], dtype=float32), class_id=array([0]), confidence=array([[0.46269372, 0.43077105, 0.10891101, 0.787991  , 0.14176577,
        0.9976913 , 0.98278177, 0.9962567 , 0.89482415, 0.9906292 ,
        0.8767611 , 0.9968356 , 0.99045146, 0.9952928 , 0.9814501 ,
        0.9842116 , 0.95407826]], dtype=float32), data={'class_name': array(['person'], dtype='<U6')})
Frame 1186: KeyPoints(xy=array([[[ 499.,  444.],
        [ 491.,  439.],
        [ 495.,  439.],
        [ 452.,  443.],
        [ 488.,  445.]




In [15]:
import pickle
def save_data(data, filename):
    with open(filename, 'wb') as f:
        pickle.dump(data, f)
    print(f"Data saved to {filename}")

def load_data(filename):
    with open(filename, 'rb') as f:
        data = pickle.load(f)
    print(f"Data loaded from {filename}")
    return data


In [16]:
print("all_data shape: ", len(all_data))
# Save the data
save_data(all_data, 'inference_all_data.pkl')

all_data shape:  2
Data saved to inference_all_data.pkl
