In [17]:
# customized training for yolov9c
# credit to dataset here: https://universe.roboflow.com/computer-vision-d5fjh/basketball-detection-dn6fg/dataset/4
# !yolo task=detect mode=train epochs=3 data=data.yaml model=yolov8s

# to test the best.pt weight after training run:
#!yolo task=detect mode=predict model=best.pt show=True conf=0.5 source = frames/img0223.png

# to resume training from the last checkpoint:
# !yolo train resume model=runs/detect/train/weights/last.pt (replace the path with the location of your own last.pt checkpoint)

# use model.track to run tracking and save the annotated frames into a video file:
# model.track(src .....)

In [18]:
# importing the necessary libraries
import shutil
from ultralytics import YOLO
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import cv2
from skimage.metrics import structural_similarity as ssim
from PIL import Image
import os

In [19]:
# Load the custom-trained YOLOv9c detector from its saved weights
# (trained for 100 epochs on a GCP V100 deep learning instance)

model = YOLO('weights/best9c_100.pt')
# Display model information (optional)
model.info()

YOLOv9c summary: 618 layers, 25531545 parameters, 0 gradients, 103.7 GFLOPs


(618, 25531545, 0, 103.69152)

In [20]:
def yolo_model_prediction(frame, yolo_model):
    """Run the detector on one frame and flatten its output into plain dicts.

    Args:
        frame: Image handed unchanged to ``yolo_model``.
        yolo_model: Callable returning a sequence of results; only the first
            result is read.

    Returns:
        list[dict]: One entry per detection with the keys ``'class'`` (name
        looked up in the result's ``names`` mapping), ``'confidence'`` and
        ``'bbox'`` (``(x1, y1, x2, y2)`` with each coordinate truncated to int).
    """
    first_result = yolo_model(frame)[0]
    detections = first_result.boxes

    corner_rows = detections.xyxy.tolist()
    class_ids = detections.cls.tolist()
    scores = detections.conf.tolist()
    label_lookup = first_result.names

    parsed = []
    for corners, class_id, score in zip(corner_rows, class_ids, scores):
        left, top, right, bottom = (int(value) for value in corners)
        parsed.append({
            'class': label_lookup[int(class_id)],
            'confidence': score,
            'bbox': (left, top, right, bottom),
        })
    return parsed

In [21]:
def get_zones(yolo_results, frame, danger_half_width=60, danger_height=80,
              goal_half_width=35, goal_height=80):
    """Derive the danger region and the goal zone from the detected basket.

    Both zones are axis-aligned rectangles anchored on the centre of the first
    detection whose class is ``"basket"``. The danger region sits above the
    basket centre, the goal zone below it; they share the horizontal line
    through the basket centre.

    Args:
        yolo_results: Detections as produced by ``yolo_model_prediction``
            (dicts with ``'class'`` and ``'bbox'`` keys).
        frame: Image exposing ``.size`` as ``(width, height)``; used to clamp
            the zones to the image.
        danger_half_width: Half of the danger region's width, in pixels.
        danger_height: Height of the danger region above the basket centre.
        goal_half_width: Half of the goal zone's width, in pixels.
        goal_height: Height of the goal zone below the basket centre.

    Returns:
        tuple: ``(danger_region, goal_zone)``, each a dict with the keys
        ``'x1'``, ``'y1'``, ``'x2'``, ``'y2'``; ``(None, None)`` when no
        basket was detected.
    """
    image_width, image_height = frame.size

    # Only the first basket detection is used.
    basket_bbox = next(
        (det['bbox'] for det in yolo_results if det['class'] == "basket"),
        None,
    )
    if basket_bbox is None:
        return None, None

    basket_center_x = (basket_bbox[0] + basket_bbox[2]) // 2
    basket_center_y = (basket_bbox[1] + basket_bbox[3]) // 2

    danger_region = {
        'x1': max(0, basket_center_x - danger_half_width),
        'y1': max(0, basket_center_y - danger_height),
        'x2': min(image_width, basket_center_x + danger_half_width),
        'y2': min(image_height, basket_center_y)
    }

    goal_zone = {
        'x1': max(0, basket_center_x - goal_half_width),
        'y1': basket_center_y,
        'x2': min(image_width, basket_center_x + goal_half_width),
        'y2': min(image_height, basket_center_y + goal_height)
    }

    return danger_region, goal_zone


In [22]:
def extract_player_locations(yolo_results):
    """Return the ground contact point of every detected person.

    The ground point is the horizontal centre of the bounding box taken at
    its bottom edge.

    Args:
        yolo_results: Detections as dicts with ``'class'`` and ``'bbox'``
            (``(x1, y1, x2, y2)``) keys.

    Returns:
        list[tuple]: ``(x, y)`` per detection whose class is ``"person"``,
        in input order.
    """
    person_boxes = (det['bbox'] for det in yolo_results if det['class'] == "person")
    return [
        ((left + right) // 2, bottom)
        for left, _top, right, bottom in person_boxes
    ]

In [23]:
def extract_ball_locations(yolo_results):
    """Return the bounding-box centre of every detected ball.

    Args:
        yolo_results: Detections as dicts with ``'class'`` and ``'bbox'``
            (``(x1, y1, x2, y2)``) keys.

    Returns:
        list[tuple]: ``(x, y)`` centre for each detection whose class is
        ``"ball"`` or ``"sport ball"``, in input order.
    """
    ball_labels = ("ball", "sport ball")
    centres = []
    for det in yolo_results:
        if det['class'] not in ball_labels:
            continue
        left, top, right, bottom = det['bbox']
        centres.append(((left + right) // 2, (top + bottom) // 2))
    return centres


In [24]:
def check_ball_zone(ball_location, danger_zone, goal_zone):
    """Classify a ball position against the danger region and the goal zone.

    Args:
        ball_location: ``(x, y)`` position, or ``None`` when there is no ball.
        danger_zone: Dict with ``'x1'``, ``'y1'``, ``'x2'``, ``'y2'`` keys, or
            ``None`` when the zone is unknown.
        goal_zone: Same layout as ``danger_zone``, or ``None``.

    Returns:
        str: ``""`` when ``ball_location`` is ``None``; otherwise ``"danger"``
        if the ball lies inside the danger zone, ``"safe"`` if it lies inside
        the goal zone, and ``"none"`` if it lies in neither. Zone edges are
        inclusive, and the danger zone is tested first, so a point on an edge
        shared by both zones is reported as ``"danger"``.
    """
    if ball_location is None:
        return ""
    ball_x, ball_y = ball_location

    def _contains(zone):
        # A missing zone (e.g. no basket detected) cannot contain the ball;
        # without this guard the subscript below raises TypeError.
        if zone is None:
            return False
        return (zone['x1'] <= ball_x <= zone['x2']
                and zone['y1'] <= ball_y <= zone['y2'])

    if _contains(danger_zone):
        return "danger"
    if _contains(goal_zone):
        return "safe"
    return "none"

In [25]:
def calculate_distance(point1, point2):
    """Return the Euclidean distance between two 2-D points.

    Each coordinate is converted with ``int`` before the distance is
    computed, so fractional parts are dropped.
    """
    start_x, start_y = (int(coord) for coord in point1)
    end_x, end_y = (int(coord) for coord in point2)
    delta_x = end_x - start_x
    delta_y = end_y - start_y
    return (delta_x ** 2 + delta_y ** 2) ** 0.5


In [26]:
def update_numbers(ball_locations, prev_ball_locations, danger_zone,goal_zone):
    """Count the shot attempts and goals implied by one frame transition.

    Every previous ball position is paired with the nearest current ball
    position (the first one wins on ties; no pairing when there are no
    current positions). A ball that was in the danger zone and is now in
    neither zone counts as an attempt; one that was in the danger zone and
    is now in the goal zone counts as an attempt and a goal.

    Args:
        ball_locations: Ball positions ``(x, y)`` in the current frame.
        prev_ball_locations: Ball positions ``(x, y)`` in the previous frame.
        danger_zone: Zone dict forwarded to ``check_ball_zone``.
        goal_zone: Zone dict forwarded to ``check_ball_zone``.

    Returns:
        tuple: ``(inc_goals, inc_attempts)`` to add to the running totals.
    """
    goal_count = 0
    attempt_count = 0

    for previous in prev_ball_locations:
        # Nearest current position to this previous position, or None.
        nearest = min(
            ball_locations,
            key=lambda candidate: calculate_distance(previous, candidate),
            default=None,
        )

        zone_before = check_ball_zone(previous, danger_zone, goal_zone)
        zone_after = check_ball_zone(nearest, danger_zone, goal_zone)

        # Only balls leaving the danger zone can change the counters.
        if zone_before != 'danger':
            continue
        if zone_after == 'none':
            attempt_count += 1
        elif zone_after == 'safe':
            attempt_count += 1
            goal_count += 1

    return goal_count, attempt_count


In [27]:
def map_point_to_top_view(pt):
    """Project a frame point onto the top-view court image.

    A homography is estimated on every call from ten fixed landmark
    correspondences between the camera frame and the top-view image, then
    applied to ``pt``.

    Args:
        pt: ``(x, y)`` point in frame coordinates. The caller in this notebook
            passes points already rescaled to a 700x400 frame.

    Returns:
        Array holding the mapped ``(x, y)`` position as float32 values.
    """
    frame_landmarks = np.array(
        [[252,243],[346,250],[482,257],[682,269],[144,272],
         [224,282],[339,293],[0,271],[95,300],[650,338]],
        dtype=np.float32,
    )
    court_landmarks = np.array(
        [[66,0],[270,0],[435,0],[635,0],[265,225],
         [350,225],[435,225],[66,347],[353,347],[635,347]],
        dtype=np.float32,
    )

    # Frame -> top-view homography; the inlier mask is not needed.
    homography, _mask = cv2.findHomography(frame_landmarks, court_landmarks)

    # perspectiveTransform expects an (N, 1, 2) float array.
    source = np.array([pt], dtype=np.float32).reshape(-1, 1, 2)
    projected = cv2.perspectiveTransform(source, homography)
    return projected[0][0]

In [28]:


def _paste_top_view(annotated_frame, court_img):
    """Paste a half-size copy of the court image into the frame's top-right corner."""
    thumbnail = cv2.resize(court_img, (0, 0), fx=0.5, fy=0.5)
    thumb_height, thumb_width, _ = thumbnail.shape
    frame_width = annotated_frame.shape[1]
    annotated_frame[:thumb_height, frame_width - thumb_width:] = thumbnail


def annotate_frame(frame, yolo_results, attempts, goals, closest_player_loc, is_goal, is_attempt, top_view_img='src/top_view_court_img.png'):
    """Draw detections, counters, zones and the top-view court onto a frame.

    When a shot event happened in this frame (``is_goal`` and/or
    ``is_attempt``) and a player location is known, the player's position is
    projected onto the top-view court, marked there (filled circle for a
    goal, cross for an attempt) and the marked court image is written back to
    ``top_view_img`` so that the markers accumulate across frames.

    Args:
        frame: PIL-style image (must provide ``.size`` and ``.copy()``).
        yolo_results: Detections as dicts with ``'class'`` and ``'bbox'`` keys.
        attempts: Running attempt count shown in the overlay text.
        goals: Running goal count shown in the overlay text.
        closest_player_loc: ``(x, y)`` of the player associated with the ball
            in frame pixels, or ``None``.
        is_goal: Whether a goal was registered on this frame.
        is_attempt: Whether an attempt was registered on this frame.
        top_view_img: Path of the top-view court image; read on every call
            and overwritten when a marker is added.

    Returns:
        numpy.ndarray: The annotated copy of ``frame``.

    Raises:
        FileNotFoundError: If the court image cannot be read.
    """
    # Keep the path separate from the loaded pixels so the marked court is
    # written back to the file it was read from.
    top_view_path = top_view_img
    court = cv2.imread(top_view_path)
    if court is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read top-view court image: {top_view_path}")

    # The homography landmarks are defined on 700x400 images, so both the
    # court image and the player location are brought to that size.
    resized_width, resized_height = 700, 400
    court = cv2.resize(court, (resized_width, resized_height))

    width, height = frame.size
    scale_x = resized_width / width
    scale_y = resized_height / height

    # Work on a NumPy copy so the caller's image is left untouched.
    annotated_frame = np.array(frame.copy())

    # Attempts and goals in the bottom-left corner.
    cv2.putText(annotated_frame, f"Attempts: {attempts}", (10, height - 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
    cv2.putText(annotated_frame, f"Goals: {goals}", (10, height - 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)

    # Bounding boxes: all players first, then all baskets.
    for label, colour in (("person", (255, 0, 0)), ("basket", (0, 255, 255))):
        for result in yolo_results:
            if result['class'] == label:
                x1, y1, x2, y2 = result['bbox']
                cv2.rectangle(annotated_frame, (x1, y1), (x2, y2), colour, 2)

    # Filled dot on every ball centre.
    for ball_location in extract_ball_locations(yolo_results):
        cv2.circle(annotated_frame, ball_location, 5, (0, 0, 255), -1)

    # Danger region and goal zone (absent when no basket was detected).
    danger_region, goal_zone = get_zones(yolo_results, frame)
    if danger_region is not None:
        cv2.rectangle(annotated_frame, (danger_region['x1'], danger_region['y1']), (danger_region['x2'], danger_region['y2']), (0, 0, 255), 2)
    if goal_zone is not None:
        cv2.rectangle(annotated_frame, (goal_zone['x1'], goal_zone['y1']), (goal_zone['x2'], goal_zone['y2']), (0, 255, 0), 2)

    if closest_player_loc is not None and (is_goal or is_attempt):
        # Frame pixels -> 700x400 space -> top-view court coordinates.
        scaled_point = (int(closest_player_loc[0] * scale_x), int(closest_player_loc[1] * scale_y))
        mapped_x, mapped_y = map_point_to_top_view(scaled_point)
        x, y = int(mapped_x), int(mapped_y)

        if is_goal:
            cv2.circle(court, (x, y), 5, (0, 255, 0), -1)
        if is_attempt:
            # Cross made of the two diagonals of a 10x10 square.
            cv2.line(court, (x - 5, y - 5), (x + 5, y + 5), (255, 0, 0), 2)
            cv2.line(court, (x + 5, y - 5), (x - 5, y + 5), (255, 0, 0), 2)

        # Persist the full-resolution court so markers accumulate across
        # frames without being repeatedly down- and up-sampled.
        cv2.imwrite(top_view_path, court)

    # NOTE(review): the frame array is treated as RGB by the caller (it is
    # converted RGB->BGR before saving) while cv2.imread returns BGR, so the
    # pasted court's red/blue channels end up swapped in the saved output.
    # Confirm whether a BGR->RGB conversion is wanted here.
    _paste_top_view(annotated_frame, court)
    return annotated_frame


In [29]:
def find_closest_player_to_ball(yolo_results, ball_locations):
    """Return the ground point of a person whose bounding box contains a ball.

    No distance is computed: a person qualifies when at least one ball
    location lies inside (edges inclusive) that person's bounding box. When
    several persons qualify, the one appearing last in ``yolo_results`` is
    returned.

    Args:
        yolo_results: Detections as dicts with ``'class'`` and ``'bbox'`` keys.
        ball_locations: Ball positions as ``(x, y)`` pairs.

    Returns:
        tuple or None: ``(x, y)`` at the bottom-centre of the selected
        person's bounding box, or ``None`` if no person qualifies.
    """
    holder = None
    for det in yolo_results:
        if det['class'] != "person":
            continue
        left, top, right, bottom = det['bbox']
        ball_inside = any(
            left <= ball[0] <= right and top <= ball[1] <= bottom
            for ball in ball_locations
        )
        if ball_inside:
            # Later matches overwrite earlier ones.
            holder = ((left + right) // 2, bottom)
    return holder


In [30]:
def _save_rgb_frame(annotated_frame, save_path):
    """Convert an RGB frame to BGR in place and write it to ``save_path``."""
    cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR, annotated_frame)
    cv2.imwrite(save_path, annotated_frame)


def check_goal_from_frame_dir(frame_directory, model, res_frame_dir):
    """Run detection over a directory of frames and count goals and attempts.

    ``img0001.png`` is processed first to initialise the ball positions, the
    scoring zones and the player associated with the ball. Every other
    ``.jpg``/``.png`` file is then processed in sorted name order. An
    annotated copy of each frame is written to ``res_frame_dir`` under the
    same file name. ``annotate_frame`` may additionally rewrite the top-view
    court image when a shot event is marked.

    Args:
        frame_directory: Directory containing the input frames.
        model: Detector passed to ``yolo_model_prediction``.
        res_frame_dir: Existing directory receiving the annotated frames.

    Returns:
        tuple: ``(goals, attempts)`` accumulated over all frames.
    """
    goals = 0
    attempts = 0

    # First frame: initial ball positions, zones and ball-holding player.
    first_name = "img0001.png"
    with Image.open(os.path.join(frame_directory, first_name)) as frame:
        yolo_results = yolo_model_prediction(frame, model)
        prev_ball_locations = extract_ball_locations(yolo_results)
        danger_region, goal_zone = get_zones(yolo_results, frame)
        prev_closest_player_loc = find_closest_player_to_ball(yolo_results, prev_ball_locations)
        # No shot event can exist yet, so no player location is passed on.
        annotated_frame = annotate_frame(frame, yolo_results, attempts, goals, None, is_attempt=False, is_goal=False)
    _save_rgb_frame(annotated_frame, os.path.join(res_frame_dir, first_name))

    for file_name in sorted(os.listdir(frame_directory)):
        if not file_name.endswith((".jpg", ".png")) or file_name == first_name:
            continue

        with Image.open(os.path.join(frame_directory, file_name)) as frame:
            yolo_results = yolo_model_prediction(frame, model)
            ball_locations = extract_ball_locations(yolo_results)

            # get_zones returns (None, None) when no basket was detected.
            # Retry on later frames instead of passing None zones into
            # update_numbers, which would fail on subscripting them.
            if danger_region is None or goal_zone is None:
                danger_region, goal_zone = get_zones(yolo_results, frame)
            if danger_region is None or goal_zone is None:
                inc_goals, inc_attempts = 0, 0
            else:
                inc_goals, inc_attempts = update_numbers(ball_locations, prev_ball_locations, danger_region, goal_zone)

            is_goal = inc_goals > 0
            is_attempt = inc_attempts > 0
            goals += inc_goals
            attempts += inc_attempts

            # Fall back to the last known player when none is found now.
            closest_player_loc = find_closest_player_to_ball(yolo_results, ball_locations)
            if closest_player_loc is None:
                closest_player_loc = prev_closest_player_loc
            else:
                prev_closest_player_loc = closest_player_loc

            print(closest_player_loc)

            annotated_frame = annotate_frame(frame, yolo_results, attempts, goals, closest_player_loc, is_goal, is_attempt)

        _save_rgb_frame(annotated_frame, os.path.join(res_frame_dir, file_name))

        # If the ball is lost in this frame, keep the previous positions.
        if ball_locations:
            prev_ball_locations = ball_locations

    return goals, attempts


In [31]:
frame_dir = 'frames'
res_frame_dir = 'res_frames'
# Start from an empty output directory. A missing directory is expected on the
# first run; any other failure (e.g. permissions) should surface here instead
# of being swallowed and reappearing as a confusing makedirs error.
try:
    shutil.rmtree(res_frame_dir)
except FileNotFoundError:
    pass
os.makedirs(res_frame_dir)
goals, attempts = check_goal_from_frame_dir(frame_dir,model,res_frame_dir)
print(goals, attempts)


0: 384x640 1 ball, 1 basket, 3 persons, 157.6ms
Speed: 1.4ms preprocess, 157.6ms inference, 0.3ms postprocess per image at shape (1, 3, 384, 640)

0: 384x640 1 basket, 3 persons, 148.3ms
Speed: 1.0ms preprocess, 148.3ms inference, 0.7ms postprocess per image at shape (1, 3, 384, 640)
None

0: 384x640 1 ball, 1 basket, 3 persons, 151.0ms
Speed: 0.9ms preprocess, 151.0ms inference, 0.4ms postprocess per image at shape (1, 3, 384, 640)
None

0: 384x640 1 ball, 1 basket, 3 persons, 147.0ms
Speed: 1.0ms preprocess, 147.0ms inference, 0.5ms postprocess per image at shape (1, 3, 384, 640)
None

0: 384x640 1 ball, 1 basket, 3 persons, 145.1ms
Speed: 1.0ms preprocess, 145.1ms inference, 0.4ms postprocess per image at shape (1, 3, 384, 640)
None

0: 384x640 1 ball, 1 basket, 3 persons, 150.8ms
Speed: 1.0ms preprocess, 150.8ms inference, 0.4ms postprocess per image at shape (1, 3, 384, 640)
None

0: 384x640 1 ball, 1 basket, 3 persons, 163.9ms
Speed: 1.1ms preprocess, 163.9ms inference, 0.3ms po

In [32]:
# reassemble the annotated frames into a video using ffmpeg
!ffmpeg -y -r 30 -i res_frames/img%04d.png -vf "scale=1280:-2" -c:v libx264 -pix_fmt yuv420p output.mp4

ffmpeg version 4.2.2 Copyright (c) 2000-2019 the FFmpeg developers
  built with clang version 12.0.0
  configuration: --prefix=/Users/ktietz/demo/mc3/conda-bld/ffmpeg_1628925491858/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_plac --cc=arm64-apple-darwin20.0.0-clang --disable-doc --enable-avresample --enable-gmp --enable-hardcoded-tables --enable-libfreetype --enable-libvpx --enable-pthreads --enable-libopus --enable-postproc --enable-pic --enable-pthreads --enable-shared --enable-static --enable-version3 --enable-zlib --enable-libmp3lame --disable-nonfree --enable-gpl --enable-gnutls --disable-openssl --enable-libopenh264 --enable-libx264
  libavutil      56. 31.100 / 56. 31.100
  libavcodec     58. 54.100 / 58. 54.100
  libavformat    58. 29.100 / 58. 29.100
  libavdevice    58.  8.100 / 58.  8.100
  libavfilter     7. 57.100 /  7. 57