In [10]:
%pip install inference supervision inference_sdk --quiet

Note: you may need to restart the kernel to use updated packages.


In [11]:
import supervision as sv
from inference import get_model, InferencePipeline
from tqdm import tqdm

In [12]:
file_name = "demo_alt_10fps"
file_extension = ".mp4"
source_path = f"videos/{file_name}{file_extension}"
api_key = "" # API KEY
ball_detection_model_id = "tennis-ball-detection-e5vmz/16" #"tennis-ball-detection-e5vmz/15"
api_url = "http://127.0.0.1:9001" # use local inference server

In [13]:
ball_detector = get_model(model_id=ball_detection_model_id, api_key=api_key)

In [14]:
from inference_sdk import InferenceHTTPClient
CLIENT = InferenceHTTPClient(
    api_url="http://localhost:9001",
    api_key=api_key
)

image_url = "https://source.roboflow.com/pwYAXv9BTpqLyFfgQoPZ/u48G0UpWfk8giSw7wrU8/original.jpg"
result = CLIENT.infer(image_url, model_id="soccer-players-5fuqs/1")
print(result)

{'inference_id': '7eaec277-7484-4bb9-b804-395d0529998d', 'time': 0.10833283399551874, 'image': {'width': 398, 'height': 224}, 'predictions': [{'x': 5.0, 'y': 152.0, 'width': 10.0, 'height': 28.0, 'confidence': 0.8946052193641663, 'class': 'player', 'class_id': 1, 'detection_id': '2307aa31-8fc9-48af-b717-f10b78785eb2'}, {'x': 145.0, 'y': 95.5, 'width': 14.0, 'height': 25.0, 'confidence': 0.8900073766708374, 'class': 'player', 'class_id': 1, 'detection_id': '26ea721e-84e0-4868-99e2-d4907aae68fb'}, {'x': 25.0, 'y': 89.5, 'width': 16.0, 'height': 25.0, 'confidence': 0.8871816396713257, 'class': 'player', 'class_id': 1, 'detection_id': 'ea3a19ce-f30c-4335-a79d-5ea42fb6c0f4'}, {'x': 313.0, 'y': 178.0, 'width': 16.0, 'height': 32.0, 'confidence': 0.87803715467453, 'class': 'player', 'class_id': 1, 'detection_id': 'c6486bbc-335a-4e12-95d4-a704d46f996d'}, {'x': 324.0, 'y': 106.5, 'width': 12.0, 'height': 25.0, 'confidence': 0.8769770264625549, 'class': 'player', 'class_id': 1, 'detection_id': '

In [15]:
from inference_sdk import InferenceHTTPClient, InferenceConfiguration
import cv2
import numpy as np
import supervision as sv

custom_configuration = InferenceConfiguration(confidence_threshold=0.15)

client = InferenceHTTPClient(
    api_url="http://localhost:9001",
    api_key=api_key
)

client.configure(custom_configuration)

def infer_toss(image):
    toss_predictions = client.run_workflow(
        workspace_name="tennis-rally-detection",
        workflow_id="detect-toss",
        images={
            "image": image,
        }
    )
    return toss_predictions

def infer_ball(image):
    ball_predictions = client.infer(
        image,
        model_id=ball_detection_model_id,
        
    )
    return ball_predictions

def ball_detection_callback(image_slice: np.ndarray) -> sv.Detections:
    # Use the existing infer_ball function
    predictions = infer_ball(image_slice)
    
    # Convert predictions to sv.Detections format
    detections = sv.Detections.from_inference(predictions)
    return detections

ball_slicer = sv.InferenceSlicer(callback=ball_detection_callback)

In [16]:
import supervision as sv

video_info = sv.VideoInfo.from_video_path(video_path=source_path)

total_frames = video_info.total_frames
frame_number = 0


results = []

progress_bar = tqdm(total=total_frames, desc="Processing frames")

for frame in sv.get_video_frames_generator(source_path=source_path):
    toss_predictions = infer_toss(frame)
    ball_predictions = ball_slicer(frame)
    frame_number += 1
    progress_bar.update(1)
    results.append((toss_predictions, ball_predictions))
progress_bar.close()
# toss:
# [{'output': [{'time': 0.33360033300004943, 'image': {'width': 65, 'height': 109}, 'predictions': [{'class': 'other', 'class_id': 1, 'confidence': 0.9926}, {'class': 'toss', 'class_id': 0, 'confidence': 0.0074}], 'top': 'other', 'confidence': 0.9926, 'prediction_type': 'classification', 'parent_id': 'ced0e87d-2bd5-420d-be0c-66c145699be7', 'root_parent_id': 'image'}, {'time': 0.3336179580001044, 'image': {'width': 102, 'height': 152}, 'predictions': [{'class': 'other', 'class_id': 1, 'confidence': 0.9939}, {'class': 'toss', 'class_id': 0, 'confidence': 0.0061}], 'top': 'other', 'confidence': 0.9939, 'prediction_type': 'classification', 'parent_id': '627ce165-92d3-4070-996b-ca7adf4f2d1c', 'root_parent_id': 'image'}]}]

# ball:
#{'inference_id': '64347833-556d-4ba2-881e-48d508a8d3ca', 'time': 0.33276966599987645, 'image': {'width': 1280, 'height': 720}, 'predictions': [{'x': 710.0, 'y': 178.0, 'width': 6.0, 'height': 6.0, 'confidence': 0.6501944661140442, 'class': 'tennis-ball', 'class_id': 0, 'detection_id': '9d96a56e-6207-4bbb-abd8-02ab846e1f6e'}]}


Processing frames:   0%|          | 0/499 [00:00<?, ?it/s]

Processing frames: 100%|██████████| 499/499 [18:20<00:00,  2.21s/it]


In [17]:
def getTossResult(toss_predictions):
    if not toss_predictions:
        return False
    for output in toss_predictions:
        if not isinstance(output, dict) or 'output' not in output:
            continue
        for prediction in output['output']:
            if isinstance(prediction, dict) and prediction.get('top') == 'toss':
                return True
    return False

toss_results = []
ball_results = []
player_detections = []

for toss_predictions, ball_predictions in results:
    # Always append a boolean value, even if predictions are None
    toss_results.append(getTossResult(toss_predictions))
    #ball_results.append(len([pred for pred in ball_predictions.get('predictions', []) if pred['class'] == 'tennis-ball']))
    ball_results.append(len(ball_predictions.class_id))

    player_detections.append(sv.Detections.from_inference(toss_predictions["predictions_player"]))

print(f"Frames with toss detected: {sum(toss_results)}")
print(f"Frames with ball detected: {sum(ball_results)}")

print(f"Total frames processed: {len(toss_results)}")
print(f"Toss results length: {len(toss_results)}")
print(f"Ball results length: {len(ball_results)}")

AFTER_TOSS_THRESHOLD = 20

NO_BALL_THRESHOLD = 6
no_ball_end_signal = [False] * len(ball_results)
for i in range(4, len(ball_results)):
    if all(ball_results[j] == 0 for j in range(i-NO_BALL_THRESHOLD, i+1)) and not any(toss_results[j] for j in range(max(0, i-AFTER_TOSS_THRESHOLD), i+1)):
        no_ball_end_signal[i] = True

# Define the threshold for multiple balls
MULTIPLE_BALL_THRESHOLD = 2
MIN_HITS = 5
FRAMES_WINDOW = 6

multiple_ball_end_signal = [False] * len(ball_results)
for i in range(FRAMES_WINDOW - 1, len(ball_results)):
    # Count frames with 2 or more balls in the last 6 frames (including current)
    multiple_ball_count = sum(1 for j in range(i-FRAMES_WINDOW+1, i+1) if ball_results[j] >= MULTIPLE_BALL_THRESHOLD)
    print(f"Frame {i}: Multiple ball count: {multiple_ball_count}")
    
    # Check if there's no toss in the last 5 frames
    no_toss = not any(toss_results[j] for j in range(i-AFTER_TOSS_THRESHOLD, i+1))
    
    # Set the end signal if at least 3 out of 5 frames have multiple balls and there's no recent toss
    if multiple_ball_count >= MIN_HITS and no_toss:
        multiple_ball_end_signal[i] = True




# Print all results
print("\nAll results:")
for i in range(len(toss_results)):
    print(f"Frame {i}: | T: {toss_results[i]}, B: {ball_results[i]} | End signals: {no_ball_end_signal[i]} {multiple_ball_end_signal[i]}")

Frames with toss detected: 2
Frames with ball detected: 614
Total frames processed: 499
Toss results length: 499
Ball results length: 499
Frame 5: Multiple ball count: 0
Frame 6: Multiple ball count: 0
Frame 7: Multiple ball count: 0
Frame 8: Multiple ball count: 0
Frame 9: Multiple ball count: 0
Frame 10: Multiple ball count: 0
Frame 11: Multiple ball count: 0
Frame 12: Multiple ball count: 0
Frame 13: Multiple ball count: 0
Frame 14: Multiple ball count: 0
Frame 15: Multiple ball count: 0
Frame 16: Multiple ball count: 0
Frame 17: Multiple ball count: 0
Frame 18: Multiple ball count: 0
Frame 19: Multiple ball count: 0
Frame 20: Multiple ball count: 1
Frame 21: Multiple ball count: 1
Frame 22: Multiple ball count: 1
Frame 23: Multiple ball count: 1
Frame 24: Multiple ball count: 1
Frame 25: Multiple ball count: 1
Frame 26: Multiple ball count: 0
Frame 27: Multiple ball count: 0
Frame 28: Multiple ball count: 0
Frame 29: Multiple ball count: 0
Frame 30: Multiple ball count: 0
Frame 31:

In [18]:
import cv2
import numpy as np
import supervision as sv

# Define video information
video_info = sv.VideoInfo.from_video_path(source_path)

# Create VideoSink for output
import os
from datetime import datetime

original_name = source_path.split("/")[-1]
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_filename = f"output_{file_name}_{timestamp}{file_extension}"

with sv.VideoSink(target_path=output_filename, video_info=video_info) as sink:
    cap = cv2.VideoCapture(source_path)
    in_play = False
    
    for frame_idx, ((toss_predictions, ball_predictions), (toss_result, ball_result), end_signal_1, end_signal_2) in enumerate(zip(results, zip(toss_results, ball_results), no_ball_end_signal, multiple_ball_end_signal)):
        ret, frame = cap.read()
        if not ret:
            break
        
        # Update in_play state
        if not in_play and toss_result:
            in_play = True
        elif in_play and (end_signal_1 or end_signal_2):
            in_play = False
        
        # Apply purple background with low opacity when in_play is true
        if in_play:
            overlay = frame.copy()
            cv2.rectangle(overlay, (0, 0), (frame.shape[1], frame.shape[0]), (128, 0, 128), -1)
            frame = cv2.addWeighted(overlay, 0.2, frame, 0.8, 0)
            
            # Add "In Play" text
            text = "In Play"
            font = cv2.FONT_HERSHEY_SIMPLEX
            text_size = cv2.getTextSize(text, font, 1, 2)[0]
            text_x = (frame.shape[1] - text_size[0]) // 2
            text_y = (frame.shape[0] + text_size[1]) // 2
            cv2.putText(frame, text, (text_x, text_y), font, 1, (255, 255, 255), 2, cv2.LINE_AA)

        # Add frame count to the bottom left of the window
        frame_count_text = f"Frame: {frame_idx + 1}"
        text_size = cv2.getTextSize(frame_count_text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 1)[0]
        cv2.putText(frame, frame_count_text, (10, frame.shape[0] - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1, cv2.LINE_AA)
        
        # Write the frame to the output video
        sink.write_frame(frame)
    
    cap.release()

print(f"Video processing complete. Output saved as '{output_filename}'")

Video processing complete. Output saved as 'output_demo_alt_10fps_20240815_085421.mp4'
