In [1]:
from ultralytics import YOLO
import numpy as np
import cv2
import matplotlib.pyplot as plt
from io import BytesIO
from src.mapping_utils.keypoint_coordinates import *
import os
import supervision as sv

In [3]:
model_players = YOLO('yolo_weights/player_weights.pt')
model_ball = YOLO('yolo_weights/ball_weights.pt')
model_keypoints = YOLO('yolo_weights/keypoint_weights.pt')
tracker = sv.ByteTrack()

In [4]:
image_folder = 'data/3_test_1min_hamkam_from_start/img1'
image_folder = sorted([os.path.join(image_folder, f) for f in os.listdir(image_folder) if f.endswith('.jpg')])

output_directory = 'results'
os.makedirs(output_directory, exist_ok=True) 

In [5]:
keypoint_mirror_mapping = {
    0: 18, 2: 19, 3: 20, 4: 21, 5: 22, 6: 23, 7: 24,
    8: 25, 9: 26, 10: 31, 11: 30, 15: 27, 16: 17
}

# Load football field coordinates
football_field_coordinates = {}
with open('football_field_coordinates.txt', 'r') as file:
    for line in file:
        key, value = line.strip().split(':')
        football_field_coordinates[int(key)] = list(map(float, value.strip('[] ').split(', ')))

In [6]:
team1_color_range = [np.array([50, 50,  50]), np.array([150, 150, 150])]  
team2_color_range = [np.array([0, 0,  0]), np.array([160, 160, 160])]  
ref_color_range = [np.array([0, 0,  0]), np.array([70, 70, 70])]

# Hardcoded for the test set
def get_average_color(frame, box):
    """Classify the average color in the specified bounding box as white, black, or blue."""
    x1, y1, x2, y2 = map(int, box)
    center_x = (x1 + x2) // 2
    third_y = y1 + (y2 - y1) // 3

    mask = np.zeros(frame.shape[:2], dtype="uint8")
    cv2.circle(mask, (center_x, third_y), 15, 255, -1)
    masked_frame = cv2.bitwise_and(frame, frame, mask=mask)
    mean_color = cv2.mean(masked_frame, mask=mask)[:3]

    
    return (mean_color)


def draw_bounding_box(frame, box, id, category):
    if category == "players":
        average_color = get_average_color(frame, box)
        center_x = int((box[0] + box[2]) / 2)
        center_y = int((box[1] + box[3]) / 2)
        radius_x = int(0.85 * (box[2] - box[0]) / 2)
        radius_y = int((box[3] - box[1]) / 2)
        y_max = int(box[3])

        # Labeling
        text = f'{id}'
        text_size, _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.5, 2)
        text_x = center_x - text_size[0] // 2
        text_y = center_y - 60
        cv2.putText(frame, text, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 2)
        cv2.ellipse(frame, (center_x, y_max), (radius_y, radius_x), 0, -45, 235, average_color, 2)
        
    elif category == "ball":
        center_x = int((box[0] + box[2]) / 2)
        center_y = int((box[1] + box[3]) / 2)
        radius = int((box[2] - box[0]) / 2)
        cv2.circle(frame, (center_x, center_y), radius, (0, 255, 255), -1)
    elif category == "keypoints":
        center_x = int((box[0] + box[2]) / 2)
        center_y = int((box[1] + box[3]) / 2)
        radius = int((box[2] - box[0]) / 5)
        cv2.circle(frame, (center_x, center_y), radius, (0, 255, 0))
        cv2.putText(frame, f'ID: {id}', (center_x - 10, center_y - 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)


def create_football_field(player_coords, player_colors, ball_coords=None):
    """Create a visual representation of a football field with players and balls, using matplotlib."""
    fig, ax = plt.subplots(figsize=(7, 3.5))
    ax.set_facecolor((0, 0, 0, 0))  # Transparent background

    # Draw field
    ax.plot([45, 870], [565, 565], 'white')  # Top boundary
    ax.plot([45, 870], [30, 30], 'white')    # Bottom boundary
    ax.plot([45, 45], [30, 565], 'white')    # Left boundary
    ax.plot([870, 870], [30, 565], 'white')  # Right boundary
    ax.plot([457.5, 457.5], [30, 565], 'white')  # Halfway line
    center_circle = plt.Circle((457.5, 297.5), 70, color='white', fill=False)
    ax.add_patch(center_circle)

    # Goal areas
    ax.plot([45, 175], [455, 455], 'white')
    ax.plot([45, 175], [140, 140], 'white')
    ax.plot([870, 740], [455, 455], 'white')
    ax.plot([870, 740], [140, 140], 'white')
    ax.plot([175, 175], [455, 140], 'white')
    ax.plot([740, 740], [455, 140], 'white')
    
    # Player positions
    if player_coords is not None:
        for coord, bgr_color in zip(player_coords, player_colors):
            # Ensure there's no index error
            if len(bgr_color) != 3:
                raise ValueError("Expected BGR color format with three components.")
            rgb_color = (bgr_color[2] / 255.0, bgr_color[1] / 255.0, bgr_color[0] / 255.0)
            ax.scatter(coord[0], coord[1], color=rgb_color)
            
    if ball_coords is not None:
        for ball_coord in ball_coords:
            ax.scatter(ball_coord[0], ball_coord[1], color='yellow')

    ax.set_xlim(0, 915)
    ax.set_ylim(0, 595)
    ax.axis('off')
    ax.set_aspect('equal', adjustable='box')

    # Save to buffer
    buf = BytesIO()
    plt.savefig(buf, format='png', transparent=True, dpi=300, bbox_inches='tight', pad_inches=0)
    plt.close(fig)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    field_image = cv2.imdecode(img_arr, cv2.IMREAD_UNCHANGED)
    return field_image

def add_football_field_to_frame(frame, field_image, scale_factor=1.0):
    if scale_factor != 1.0:
        field_image = cv2.resize(field_image, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
    x_offset = frame.shape[1] - field_image.shape[1]
    y_offset = 0
    alpha_s = field_image[:, :, 3] / 255.0
    alpha_l = 1.0 - alpha_s
    for c in range(0, 3):
        frame[y_offset:y_offset + field_image.shape[0], x_offset:x_offset + field_image.shape[1], c] = \
            (alpha_s * field_image[:, :, c] + alpha_l * frame[y_offset:y_offset + field_image.shape[0], x_offset:x_offset + field_image.shape[1], c])
            
def calculate_center_line(keypoint_centers, keypoint_ids, center_point_ids):
    # Check that the lengths of centers and IDs match
    if len(keypoint_centers) != len(keypoint_ids):
        print("Mismatch in the number of keypoints and keypoint IDs")
        return None

    if len(keypoint_centers) > 0:
        # Create a mask for selecting centers where the keypoint ID is in center_point_ids
        mask = np.isin(keypoint_ids, center_point_ids)
        if np.sum(mask) == 0:
            return None

        selected_centers = keypoint_centers[mask]
        # Calculate the mean of x-coordinates from the selected centers
        return np.mean(selected_centers[:, 0])
    else:
        return None

def adjust_keypointID_based_on_center_line(keypoint_center, keypoint_id, center_x):

    # Define sets of IDs for clearer logic
    left_ids = {0, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 15, 16}  # IDs that need adjustment when right of center
    right_ids = {17, 18, 19, 20, 21, 22, 23, 24, 25, 26}  # IDs that need adjustment when left of center

    # Check if the keypoint is on the right side and needs mirroring
    if keypoint_id in left_ids and keypoint_center[0] > center_x:
        return keypoint_mirror_mapping.get(keypoint_id, keypoint_id)  # Return the mirrored ID or original if not mapped

    # Check if the keypoint is on the left side and needs mirroring
    if keypoint_id in right_ids and keypoint_center[0] < center_x:
        return keypoint_mirror_mapping.get(keypoint_id, keypoint_id)  # Return the mirrored ID or original if not mapped

    return keypoint_id

def update_keypoint_history(keypoint_centers, keypoint_ids, movement_threshold=50):
    """Update the history of keypoints based on their movement and manage duplicates."""
    global keypoint_history
    new_history = {}

    # First pass: check for movement and select the closest one in case of duplicates
    for center, keypoint_id in zip(keypoint_centers, keypoint_ids):
        if keypoint_id in new_history:
            # Check if the new center is closer to the last known center than the current entry
            last_center = keypoint_history.get(keypoint_id, center)
            current_best_center = new_history[keypoint_id]
            
            # Calculate distances
            last_distance = np.linalg.norm(np.array(last_center) - np.array(center))
            current_best_distance = np.linalg.norm(np.array(last_center) - np.array(current_best_center))
            
            # If current keypoint is closer to the last known position, update the entry
            if last_distance < current_best_distance:
                new_history[keypoint_id] = center
        else:
            # If no duplicate, just add the keypoint to the new history
            last_center = keypoint_history.get(keypoint_id, center)
            distance = np.linalg.norm(np.array(last_center) - np.array(center))
            if distance <= movement_threshold:
                new_history[keypoint_id] = center

    # Update the history with the new positions
    keypoint_history = new_history


In [8]:
output_directory = 'results'
image_counter = 0
center_point_ids = [1, 12, 13, 14]
minimap_bounds = {'x_min': 40, 'x_max': 875, 'y_min': 25, 'y_max': 570} 

keypoint_history = {}
movement_threshold = 50

# For each image in the folder
for image_path in image_folder:
    image_counter += 1
    frame = cv2.imread(image_path)
    
    if frame is not None:
        
        ##---------------   PART 1 getting results from the models ---------------##
        results_players = model_players(frame, conf=0.6)
        results_keypoints = model_keypoints(frame, conf=0.4)
        results_ball = model_ball(frame)
        
        ball_bb = results_ball[0].boxes.xyxy.cpu().numpy()
        ball_ids = results_ball[0].boxes.cls.cpu().numpy()
        ball_centers = np.array([[0.5 * (box[0] + box[2]), 0.5 * (box[1] + box[3])] for box in ball_bb])

        ##--------------- Proccess results for tracker ---------------##
        detections = sv.Detections.from_ultralytics(results_players[0])
        tracked_objects = tracker.update_with_detections(detections)            
        #player_bb = results_players[0].boxes.xyxy.cpu().numpy()
        #player_ids = results_players[0].boxes.cls.cpu().numpy()
        player_bb = tracked_objects.xyxy
        player_ids = tracked_objects.tracker_id
        player_centers = np.array([[0.5 * (box[0] + box[2]), box[3]] for box in player_bb])
        player_colors = [get_average_color(frame, box.astype(int)) for box in player_bb]

    
        keypoint_id = results_keypoints[0].boxes.cls.cpu().numpy()
        keypoint_bb = results_keypoints[0].boxes.xyxy.cpu().numpy()
        keypoint_centers = np.array([[0.5 * (box[0] + box[2]), 0.5 * (box[1] + box[3])] for box in keypoint_bb])
        
        ##---------------   PART 2 checking if the keypoints are within threshhold range  ---------------##
        if image_counter > 1:
            update_keypoint_history(keypoint_centers, keypoint_id)
            keypoint_centers = np.array([center for center in keypoint_history.values()])
            keypoint_id = np.array(list(keypoint_history.keys()))
        else: 
            keypoint_history = {k_id: center for k_id, center in zip(keypoint_id, keypoint_centers)}
        
        ##---------------   PART 3 adjusting based on left/right of center line  ---------------##
        center_x = calculate_center_line(keypoint_centers, keypoint_id, center_point_ids)

        minimap_keypoints = []
        adjusted_keypoint_ids = []
        
        for i, keypoint_center in enumerate(keypoint_centers):
            current_keypoint_id = int(keypoint_id[i])
            adjusted_id = adjust_keypointID_based_on_center_line(keypoint_center, current_keypoint_id, center_x)
            adjusted_keypoint_ids.append(adjusted_id)
            minimap_keypoints.append(football_field_coordinates.get(adjusted_id, (0, 0)))  
            
        ##---------------   PART 4 mapping from frame to minimap  ---------------##
        H, status = cv2.findHomography(keypoint_centers, np.array(minimap_keypoints))
        minimap_player_coordinates = []
        minimap_ball_coordinates = []
        
        if player_centers.size > 0:
            minimap_player_coordinates = cv2.perspectiveTransform(player_centers.reshape(-1, 1, 2), H).reshape(-1, 2)
            
        if ball_centers.size > 0:
            minimap_ball_coordinates = cv2.perspectiveTransform(ball_centers.reshape(-1, 1, 2), H).reshape(-1, 2)
        
        ##---------------   PART 5 drawing bounding boxes and football field minimap ---------------##
        filtered_minimap_coords = []
        
        for coord in minimap_player_coordinates:
            if (minimap_bounds['x_min'] <= coord[0] <= minimap_bounds['x_max'] and
                minimap_bounds['y_min'] <= coord[1] <= minimap_bounds['y_max']):
                filtered_minimap_coords.append(coord)
            
        field_image = create_football_field(filtered_minimap_coords, player_colors, minimap_ball_coordinates)

        add_football_field_to_frame(frame, field_image, 0.5)  

        for box, player_id, minimap_coord, color in zip(player_bb, player_ids, minimap_player_coordinates, player_colors):
            if (minimap_bounds['x_min'] <= minimap_coord[0] <= minimap_bounds['x_max'] and
                minimap_bounds['y_min'] <= minimap_coord[1] <= minimap_bounds['y_max']):
                draw_bounding_box(frame, box, player_id, "players")

        for box in ball_bb:
            draw_bounding_box(frame, box, None, "ball") 
            
        for center, keypoint_id in zip(keypoint_centers, adjusted_keypoint_ids):
            draw_bounding_box(frame,[center[0]-5, center[1]-5, center[0]+5, center[1]+5], keypoint_id, "keypoints")
            
        
        ##---------------   PART 5 saving/showing frame  ---------------##
        
        # cv2.imshow('Frame', frame)
        output_filename = os.path.join(output_directory, f'{image_counter}.jpg')
        cv2.imwrite(output_filename, frame)
        print(output_filename)
        #if cv2.waitKey(1) & 0xFF == 27:
        #    break

# cv2.destroyAllWindows()


0: 576x1024 23 players, 48.7ms
Speed: 5.3ms preprocess, 48.7ms inference, 1.7ms postprocess per image at shape (1, 3, 576, 1024)



0: 544x960 1 1, 1 3, 1 12, 1 13, 1 14, 2 16s, 45.6ms
Speed: 3.1ms preprocess, 45.6ms inference, 3.0ms postprocess per image at shape (1, 3, 544, 960)

0: 576x1024 1 ball, 9.3ms
Speed: 5.1ms preprocess, 9.3ms inference, 1.5ms postprocess per image at shape (1, 3, 576, 1024)
results/1.jpg

0: 576x1024 23 players, 48.8ms
Speed: 6.5ms preprocess, 48.8ms inference, 1.7ms postprocess per image at shape (1, 3, 576, 1024)

0: 544x960 1 3, 1 12, 1 13, 1 14, 2 16s, 45.5ms
Speed: 3.5ms preprocess, 45.5ms inference, 1.5ms postprocess per image at shape (1, 3, 544, 960)

0: 576x1024 1 ball, 9.2ms
Speed: 5.0ms preprocess, 9.2ms inference, 1.4ms postprocess per image at shape (1, 3, 576, 1024)
results/2.jpg

0: 576x1024 21 players, 48.7ms
Speed: 5.3ms preprocess, 48.7ms inference, 1.7ms postprocess per image at shape (1, 3, 576, 1024)

0: 544x960 1 3, 1 12, 1 13, 1 14, 2 16s, 45.5ms
Speed: 3.1ms preprocess, 45.5ms inference, 1.5ms postprocess per image at shape (1, 3, 544, 960)

0: 576x1024 1 ball, 9

KeyboardInterrupt: 

# Create a video from the results

In [None]:
from src.create_video import create_video_from_images

results_path = 'results'
result_video_path = 'result_video/output_video.mp4'



create_video_from_images(results_path, result_video_path, 30)

2024-04-28 22:42:56,236 - INFO - Checking if image folder exists.
2024-04-28 22:42:56,237 - INFO - Listing and sorting images.
2024-04-28 22:42:56,240 - INFO - Reading the first image at results/1.jpg.
2024-04-28 22:42:56,323 - INFO - Initializing video writer.
2024-04-28 22:42:56,333 - INFO - Starting to write frames to video.
Creating Video: 100%|██████████| 1802/1802 [01:28<00:00, 20.35it/s]
