In [36]:
import cv2
import supervision as sv
import pickle
import numpy as np
from pathlib import Path
from tqdm import tqdm

# Define the project name. It is used as the frame directory name and as the
# suffix of the input/output file names below.
project_name = "ex4"

# Shot detection parameters
GREEN_SHOT_WINDOW = 7   # Number of frames to consider for a confirmed shot
GREEN_SHOT_THRESHOLD = 7 # Minimum number of '1's in the window for a confirmed shot
YELLOW_SHOT_THRESHOLD = 2 # Minimum number of consecutive '1's for a potential shot

# Load data from pickle file. The result is later indexed with the 'pose' and
# 'ball' keys.
# SECURITY NOTE: pickle.load can execute arbitrary code embedded in the file;
# only load pickle files produced by a trusted pipeline.
with open(f"inference_all_data_{project_name}.pkl", "rb") as f:
    all_data = pickle.load(f)

# Load frame predictions: one integer per line. Blank lines (e.g. a trailing
# empty line) are skipped instead of crashing int() with a ValueError; int()
# itself tolerates surrounding whitespace and the newline.
with open(f"frame_predictions_{project_name}.txt", "r") as f:
    frame_predictions = [int(line) for line in f if line.strip()]

def process_shot_predictions(predictions, *, window=None, green_threshold=None,
                             yellow_threshold=None):
    """Classify every frame as no shot, potential shot or confirmed shot.

    Args:
        predictions: 1-D sequence of per-frame integer predictions
            (1 means the frame was predicted as a shot).
        window: Length of the sliding window used for confirmed shots.
            Defaults to the module-level ``GREEN_SHOT_WINDOW``.
        green_threshold: Minimum sum of predictions inside a window for the
            whole window to be marked as confirmed. It is also the exclusive
            upper bound on the run length considered for potential shots.
            Defaults to ``GREEN_SHOT_THRESHOLD``.
        yellow_threshold: Minimum length of a run of consecutive 1s for the
            run to be marked as a potential shot. Defaults to
            ``YELLOW_SHOT_THRESHOLD``.

    Returns:
        Integer NumPy array of the same length as ``predictions`` where
        0 = no shot, 1 = potential shot (yellow), 2 = confirmed shot (green).

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """
    # Resolve defaults lazily so callers may override any subset of them.
    if window is None:
        window = GREEN_SHOT_WINDOW
    if green_threshold is None:
        green_threshold = GREEN_SHOT_THRESHOLD
    if yellow_threshold is None:
        yellow_threshold = YELLOW_SHOT_THRESHOLD
    if window < 1:
        raise ValueError("window must be a positive integer")

    values = np.asarray(predictions, dtype=int)
    n = len(values)
    processed = np.zeros(n, dtype=int)
    if n == 0:
        return processed

    # Confirmed (green) shots: every window whose sum reaches the threshold is
    # marked in full. Prefix sums give all window sums in O(n).
    if n >= window:
        prefix = np.concatenate(([0], np.cumsum(values)))
        window_sums = prefix[window:] - prefix[:-window]
        for start in np.flatnonzero(window_sums >= green_threshold):
            processed[start:start + window] = 2

    # Potential (yellow) shots: locate runs of consecutive 1s. Padding with a
    # zero on both sides guarantees each run has a rising and a falling edge,
    # so a run that reaches the end of the array needs no special case.
    padded = np.concatenate(([0], (values == 1).astype(int), [0]))
    edges = np.diff(padded)
    run_starts = np.flatnonzero(edges == 1)
    run_ends = np.flatnonzero(edges == -1)  # exclusive end indices
    for start, end in zip(run_starts, run_ends):
        if yellow_threshold <= end - start < green_threshold:
            # np.maximum keeps frames already marked as confirmed (2).
            processed[start:end] = np.maximum(processed[start:end], 1)

    return processed

# Per-frame shot state (0: no shot, 1: potential shot, 2: confirmed shot),
# computed once up front and indexed per frame when rendering the video.
processed_predictions = process_shot_predictions(frame_predictions)

def process_frame(frame_number, image, pose_data, ball_data, shot_prediction):
    """Draw pose edges, ball boxes, the shot overlay and the frame number.

    Args:
        frame_number: Key used to look up this frame in ``pose_data`` and
            ``ball_data``; also rendered as text on the frame.
        image: Frame image array to annotate. It may be modified in place.
        pose_data: Mapping from frame number to an object exposing ``xy`` and
            ``confidence`` attributes.
        ball_data: Mapping from frame number to an iterable of
            ``(box, confidence)`` pairs.
        shot_prediction: 2 for a confirmed shot (green overlay), 1 for a
            potential shot (yellow overlay); any other value draws no overlay.

    Returns:
        The annotated image.
    """
    # Annotate keypoints
    if frame_number in pose_data:
        pose = pose_data[frame_number]
        edge_annotator = sv.EdgeAnnotator(color=sv.Color.GREEN, thickness=2)
        keypoints = sv.KeyPoints(xy=pose.xy, confidence=pose.confidence)
        image = edge_annotator.annotate(scene=image, key_points=keypoints)

    # Annotate ball bounding boxes
    if frame_number in ball_data and ball_data[frame_number]:
        balls = ball_data[frame_number]
        box_annotator = sv.BoxAnnotator(color=sv.Color.RED, thickness=2)
        detections = sv.Detections(
            xyxy=np.array([box for box, _ in balls]),
            confidence=np.array([conf for _, conf in balls]),
            class_id=np.zeros(len(balls), dtype=int)  # Assign class_id 0 to all ball detections
        )
        image = box_annotator.annotate(scene=image.copy(), detections=detections)

    # Pick the overlay style for the shot state: (fill colour, label,
    # label x position, text colour). Colours are passed to OpenCV as-is.
    height, width = image.shape[:2]
    if shot_prediction == 2:  # Confirmed shot (green)
        overlay_style = ((0, 255, 0), "Shot", width//2-50, (255, 255, 255))
    elif shot_prediction == 1:  # Potential shot (yellow)
        overlay_style = ((0, 255, 255), "Potential Shot", width//2-150, (0, 0, 0))
    else:
        overlay_style = None

    # Blend only when there is something to draw: with no shot the overlay
    # would be an exact copy of the frame, so copying and blending the whole
    # frame would be wasted work.
    if overlay_style is not None:
        fill_color, label, label_x, text_color = overlay_style
        overlay = image.copy()
        cv2.rectangle(overlay, (0, 0), (width, height), fill_color, -1)
        cv2.putText(overlay, label, (label_x, height//2), cv2.FONT_HERSHEY_SIMPLEX, 2, text_color, 3)
        # The label is drawn on the overlay, so it is blended at the same
        # 30% weight as the colour fill; the result is written into `image`.
        cv2.addWeighted(overlay, 0.3, image, 0.7, 0, image)

    # Add frame number to the image (moved slightly down)
    cv2.putText(image, f"Frame: {frame_number}", (10, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)

    return image

# Get and sort frame files numerically by file stem (e.g. "12.png" -> 12)
frame_files = sorted(Path(project_name).glob("*.png"), key=lambda x: int(x.stem))
if not frame_files:
    raise FileNotFoundError(f"No PNG frames found in directory '{project_name}'")

# Set up video writer; the output size is taken from the first frame
first_frame = cv2.imread(str(frame_files[0]))
if first_frame is None:
    # cv2.imread returns None instead of raising when a file cannot be decoded
    raise OSError(f"Could not read frame image: {frame_files[0]}")
height, width = first_frame.shape[:2]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(f"{project_name}-annotated.mp4", fourcc, 30, (width, height))
if not out.isOpened():
    raise RuntimeError(f"Could not open video writer for {project_name}-annotated.mp4")

# Index of the last available prediction; -1 when there are no predictions
last_prediction_index = len(processed_predictions) - 1

try:
    # Process each frame with progress bar
    for frame_file in tqdm(frame_files, desc="Processing frames"):
        frame_number = int(frame_file.stem)  # Get frame number from filename
        image = cv2.imread(str(frame_file))
        if image is None:
            raise OSError(f"Could not read frame image: {frame_file}")

        # Clamp the index into processed_predictions on both sides. Without the
        # lower bound, frame number 0 would give index -1 and silently read the
        # LAST prediction.
        # NOTE(review): `frame_number - 1` presumes 1-based frame file names -
        # confirm against the frame extraction step.
        prediction_index = max(0, min(frame_number - 1, last_prediction_index))
        if last_prediction_index >= 0:
            shot_state = processed_predictions[prediction_index]
        else:
            shot_state = 0  # No predictions available: render as "no shot"

        processed_frame = process_frame(
            frame_number,
            image,
            all_data['pose'],
            all_data['ball'],
            shot_state
        )

        out.write(processed_frame)
finally:
    # Release the video writer even if a frame fails, so the file is finalized
    out.release()

print(f"Video {project_name}-annotated.mp4 has been created.")

Processing frames: 100%|██████████| 1460/1460 [01:05<00:00, 22.45it/s]

Video ex4-annotated.mp4 has been created.



