# Tennis Video Analysis

In [None]:
!pip install ultralytics

## Video Utils

In [649]:
import cv2
import matplotlib.pyplot as plt

def read_video(video_path):
  """Read every frame of a video file into memory.

  Args:
    video_path: path of the video to open with cv2.VideoCapture.

  Returns:
    list of frames converted from BGR (as read by OpenCV) to RGB, or None after
    printing an error message when the file cannot be opened.
  """
  cap = cv2.VideoCapture(video_path)
  try:
    if not cap.isOpened():
      print("Error reading video file")
      return
    frames = []
    while True:
      ret, frame = cap.read()
      if not ret:
        # no more frames (or a read error) - stop reading
        break
      frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    return frames
  finally:
    # release the capture on every path, including the early error return
    cap.release()

def save_video(output_frames, output_path, frame_per_sec = 30):
  """Write RGB frames to `output_path` as an MJPG-encoded video.

  Args:
    output_frames: non-empty sequence of RGB frames; the output size is taken
      from the first frame.
    output_path: destination file.
    frame_per_sec: frame rate written to the file.

  Raises:
    ValueError: if `output_frames` is empty.
    OSError: if OpenCV cannot open a writer for `output_path`.
  """
  if len(output_frames) == 0:
    raise ValueError("save_video() needs at least one frame")
  height, width = output_frames[0].shape[:2]
  fourcc = cv2.VideoWriter_fourcc(*'MJPG')
  out = cv2.VideoWriter(output_path, fourcc, frame_per_sec, (width, height))
  if not out.isOpened():
    raise OSError(f"Could not open video writer for {output_path!r}")
  try:
    for frame in output_frames:
      # frames are RGB (read_video converts them); VideoWriter expects BGR
      out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
  finally:
    out.release()

def show_frames(frames_list, idx=(0, 5)):
  """Display frames with matplotlib, one figure per frame.

  Args:
    frames_list: a single image, or a list of images.
    idx: used only when `frames_list` is a list. Either a (start, stop) tuple -
      stop is clipped to the list length so short videos do not raise
      IndexError - or a single int index.
  """
  if not isinstance(frames_list, list):
    plt.imshow(frames_list)
    plt.show()
  elif isinstance(idx, tuple):
    stop = min(idx[1], len(frames_list))
    for i in range(idx[0], stop):
      plt.imshow(frames_list[i])
      plt.show()
  elif isinstance(idx, int):
    plt.imshow(frames_list[idx])
    plt.show()

def drow_frames_number(frame, color=(0, 255, 255)):
  """Write a 1-based "Frame: N" label at the top-left of every frame, in place.

  Args:
    frame: list of frames to annotate (parameter name kept for existing callers).
    color: text color.

  Returns:
    the same list of (now annotated) frames.
  """
  frames = frame
  for number, img in enumerate(frames, start=1):
    cv2.putText(img, "Frame: " + str(number), (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, color, 2)
  return frames

## Bbox Utils

In [650]:
def bbox_center(box):
  """Return the (x, y) midpoint of an (x1, y1, x2, y2) box, each coordinate truncated with int()."""
  left, top, right, bottom = box
  return int((left + right) / 2), int((top + bottom) / 2)

def points_distance(p1, p2):
  """Return the Euclidean distance between two points given as (x, y) sequences."""
  dx = p1[0] - p2[0]
  dy = p1[1] - p2[1]
  return (dx ** 2 + dy ** 2) ** 0.5

def convert_keypoints_to_xy_tuples(points):
  """Turn a flat [x0, y0, x1, y1, ...] sequence into a list of int (x, y) tuples."""
  pairs = []
  for start in range(0, len(points), 2):
    pairs.append((int(points[start]), int(points[start + 1])))
  return pairs

## Player Tracker

In [651]:
from ultralytics import YOLO
import cv2
import numpy as np

class PlayerTracker():
  """Detect, track, filter and draw tennis players with an Ultralytics YOLO model.

  Per-frame detections are dicts mapping an id to an [x1, y1, x2, y2] box
  (as produced by ``box.xyxy.tolist()[0]``).
  """

  def __init__(self, model_path='yolov8n.pt'):
    self.model = YOLO(model_path)

  def detect_video(self, frames):
    """Run `detect_frame` on every frame; returns one {track_id: bbox} dict per frame."""
    return [self.detect_frame(frame) for frame in frames]

  def detect_frame(self, frame):
    """Track objects in one frame and return {track_id: bbox} for 'person' boxes only.

    Boxes without a tracker id (box.id is None) are skipped because they cannot be keyed.
    """
    # NOTE(review): read_video() hands over RGB frames; Ultralytics documents numpy
    # inputs as BGR - confirm whether the channel order affects detection quality.
    results = self.model.track(frame, persist=True)[0]
    player = {}
    for box in results.boxes:
      # keep only 'person' objects
      if results.names[int(box.cls.tolist()[0])] != 'person':
        continue
      if box.id is None:
        continue
      player[int(box.id.tolist()[0])] = box.xyxy.tolist()[0]
    return player

  def draw_video(self, video_frames, plyer_detections, color=(255, 0, 0)):
    """Draw player boxes on each frame (frames are modified in place) and return them as a list."""
    return [self.draw_frame(frame, plyer, color)
            for frame, plyer in zip(video_frames, plyer_detections)]

  def draw_frame(self, frame, plyer, color=(255, 0, 0)):
    """Draw a 'Player: <id>' label and a box per player in `color`; mutates and returns `frame`."""
    for track_id, bbox in plyer.items():
      x1, y1, x2, y2 = (int(v) for v in bbox)
      # label just above the box, same color as the box
      cv2.putText(frame, "Player: " + str(track_id), (x1, y1 - 10), cv2.FONT_HERSHEY_COMPLEX, 1.0, color, 2)
      cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    return frame

  def filter_playrs(self, keypoints, player_detections):
    """Keep, per frame, the (up to) two detections closest to the court and relabel them.

    Args:
      keypoints: flat court keypoints [x0, y0, x1, y1, ...], shared by all frames.
      player_detections: list of {track_id: bbox} dicts, one per frame.

    Returns:
      list of dicts keyed 1 and 2 (fewer keys when a frame has fewer than two
      candidates); see `_maintain_players_id_by_position` for the id rule.
    """
    true_playrs_detections = []
    for detection in player_detections:
      # mean distance of each candidate's bbox center to all court keypoints
      players_distance = self._players_mean_distance_to_keypoints(keypoints, detection)
      players_distance.sort(key=lambda x: x[1])
      # the closest candidates are taken to be the players
      true_players_id = {track_id for track_id, _ in players_distance[:2]}
      true_playrs_detections.append(
          {track_id: bbox for track_id, bbox in detection.items() if track_id in true_players_id})
    # relabel every frame once, after all frames have been filtered
    return self._maintain_players_id_by_position(true_playrs_detections)

  def _players_mean_distance_to_keypoints(self, keypoints, player_detections):
    """Return [(track_id, mean distance to the keypoints)] for every detection of one frame."""
    return [(track_id, float(self._calculate_mean_distance(bbox, keypoints)))
            for track_id, bbox in player_detections.items()]

  def _calculate_mean_distance(self, player_bbox, keypoints):
    """Mean Euclidean distance from the bbox center to each (x, y) keypoint."""
    player_center = bbox_center(player_bbox)
    keypoints_xy = convert_keypoints_to_xy_tuples(keypoints)
    return np.mean([points_distance(player_center, keypoint) for keypoint in keypoints_xy])

  def _maintain_players_id_by_position(self, true_playrs_detections):
    """Relabel each frame's boxes by vertical position so ids are stable across frames.

    The box with the larger top-edge y1 (lower in the image) gets id 1, the next
    id 2, and so on. On equal y1 the later box in dict order gets the smaller id.
    """
    consistent_id_detection = []
    for detection in true_playrs_detections:
      # stable ascending sort then reverse: descending y1, ties in reverse dict order
      bboxs = sorted(detection.values(), key=lambda bbox: bbox[1])[::-1]
      consistent_id_detection.append(dict(enumerate(bboxs, start=1)))
    return consistent_id_detection


## Ball Tracker

In [652]:
from ultralytics import YOLO
import cv2
import pandas as pd

class BallTracker():
  """Detect the tennis ball per frame with a YOLO model, fill gaps, and draw it.

  Per-frame detections are {1: [x1, y1, x2, y2]}, or {} when no ball was found.
  """

  def __init__(self, model_path):
    self.model = YOLO(model_path)

  def detect_video(self, frames):
    """Run `detect_frame` on every frame; returns one detection dict per frame."""
    return [self.detect_frame(frame) for frame in frames]

  def detect_frame(self, frame, conf=0.1):
    """Return {1: bbox} for the most confident ball box in `frame`, or {} if there is none.

    Args:
      frame: image passed to ``YOLO.predict``.
      conf: minimum confidence forwarded to ``predict``.
    """
    results = self.model.predict(frame, conf=conf)[0]
    # pick the highest-confidence box explicitly rather than whichever comes last
    best = max(results.boxes, key=lambda box: float(box.conf.tolist()[0]), default=None)
    if best is None:
      return {}
    return {1: best.xyxy.tolist()[0]}

  def interpolate_detections(self, ball_detections):
    """Fill frames without a ball by linear interpolation between detections.

    Leading gaps are back-filled from the first detection; trailing gaps keep the
    last detection (DataFrame.interpolate's forward behaviour). If no frame has
    a detection at all, every frame stays {} instead of receiving a NaN box.
    """
    if not ball_detections:
      return []
    missing = [float('nan')] * 4
    # one row per frame; frames without a detection become a NaN row
    rows = [detection.get(1, missing) for detection in ball_detections]
    df_ball_detections = pd.DataFrame(rows, columns=['x1', 'y1', 'x2', 'y2'], dtype=float)
    df_ball_detections = df_ball_detections.interpolate()
    df_ball_detections = df_ball_detections.bfill()  # copy the first detection back to the first frames
    complete = df_ball_detections.notna().all(axis=1).tolist()
    # return to the input format ({1: bbox} per frame)
    return [{1: bbox} if ok else {}
            for bbox, ok in zip(df_ball_detections.to_numpy().tolist(), complete)]

  def draw_video(self, video_frames, ball_detections, color=(255, 255, 0)):
    """Draw the ball box on each frame (frames are modified in place) and return them as a list."""
    return [self.draw_frame(frame, ball, color)
            for frame, ball in zip(video_frames, ball_detections)]

  def draw_frame(self, frame, ball, color=(255, 255, 0)):
    """Draw a rectangle around every box in `ball`; mutates and returns `frame`."""
    for bbox in ball.values():
      x1, y1, x2, y2 = (int(v) for v in bbox)
      cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
    return frame

## Court Line Detector

In [653]:
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import cv2
import numpy as np
from scipy.spatial import cKDTree

class CourtLineDetector():
  """Predict 14 court keypoints (28 values: x0, y0, x1, y1, ...) with a ResNet-50 regressor."""

  def __init__(self, model_path, map_to='cpu'):
    # NOTE(review): pretrained=True downloads ImageNet weights that load_state_dict
    # (strict by default) then overwrites; newer torchvision deprecates `pretrained`.
    self.model = models.resnet50(pretrained=True)
    self.model.fc = torch.nn.Linear(self.model.fc.in_features, 14*2)
    # SECURITY: torch.load unpickles the checkpoint - only load files you trust.
    self.model.load_state_dict(torch.load(model_path, map_location=map_to))
    # inference only: BatchNorm must use its running statistics, not per-image ones
    self.model.eval()

    self.transforms = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

  def detect_video(self, frames):
    """Run `detect_frame` on every frame; returns one keypoint array per frame."""
    return [self.detect_frame(frame) for frame in frames]

  def detect_frame(self, frame):
    """Predict keypoints for one image (the first element if given a list).

    Returns:
      flat numpy array [x0, y0, x1, y1, ...] rescaled from the 224x224 model
      input back to the frame's width/height.
    """
    if isinstance(frame, list):
      frame = frame[0]
    img = self.transforms(frame).unsqueeze(0)
    with torch.no_grad():
      output = self.model(img)
    keypoints = output.squeeze().cpu().numpy()
    # return to original frame coordinates
    orig_h, orig_w = frame.shape[:2]
    keypoints[0::2] *= orig_w / 224.0
    keypoints[1::2] *= orig_h / 224.0
    return keypoints

  def draw_video(self, video_frames, court_line_detections, color=(0, 0, 0)):
    """Draw keypoints on every frame and return the frames as a list.

    `court_line_detections` is either a single keypoint array used for every
    frame, or a list with one keypoint array per frame.
    """
    if isinstance(court_line_detections, list):
      pairs = zip(video_frames, court_line_detections)
    else:
      pairs = ((frame, court_line_detections) for frame in video_frames)
    return [self.draw_frame(frame, keypoints, color) for frame, keypoints in pairs]

  def draw_frame(self, frame, keypoints, color=(0, 0, 0)):
    """Draw each keypoint as a filled dot labelled with its index; mutates and returns `frame`."""
    for i in range(0, len(keypoints), 2):
      x, y = int(keypoints[i]), int(keypoints[i+1])
      cv2.putText(frame, str(i//2), (x, y - 10), cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
      cv2.circle(frame, (x, y), 5, color, -1)
    return frame


class CourtLineIntersections():
  """Find court-line intersections (threshold + probabilistic Hough) and snap keypoints to them."""

  def __init__(self, binary_threshold=200, line_threshold=250, min_line_length=100, max_line_gap=5, intersections_min_dist=10):
    self.binary_threshold = binary_threshold  # gray level above which a pixel is kept
    self.line_threshold = line_threshold  # HoughLinesP accumulator threshold
    self.min_line_length = min_line_length
    self.max_line_gap = max_line_gap
    self.intersections_min_dist = intersections_min_dist  # closer intersections are merged

  def get_intersections(self, image):
    """Return an (N, 2) int array of de-duplicated line intersections (first image if given a list)."""
    if isinstance(image, list):
      image = image[0]
    binary = self._binary_image(image)
    lines = self._find_lines(binary)
    intersections = self._find_intersections(lines)
    return self._remove_neighbors_intersections(intersections)

  def update_keypoint_to_intersection(self, keypoints, intersections, min_distances=100):
    """Snap each (x, y) keypoint to its nearest intersection when closer than `min_distances`.

    Returns a copy of `keypoints`; unchanged when `intersections` is empty.
    """
    updated_keypoints = keypoints.copy()
    if len(intersections) == 0:
      return updated_keypoints
    tree = cKDTree(intersections)
    distances, indices = tree.query(keypoints)
    for i, (dist, idx) in enumerate(zip(distances, indices)):
      if dist < min_distances:
        updated_keypoints[i] = intersections[idx]
    return updated_keypoints

  def draw_intersections(self, image, intersections):
    """Draw each intersection as a filled (0, 0, 255) dot; mutates and returns `image`."""
    for x, y in intersections:
      cv2.circle(image, (int(x), int(y)), 5, (0, 0, 255), -1)
    return image

  def _binary_image(self, image):
    """Threshold the grayscale image so only bright (line) pixels remain."""
    # NOTE(review): read_video() produces RGB frames but BGR2GRAY is used here;
    # only the channel weights differ - confirm the threshold was tuned for this.
    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    return cv2.threshold(gray, self.binary_threshold, 255, cv2.THRESH_BINARY)[1]

  def _find_lines(self, binary_image):
    """Return detected segments as an (N, 4) array of x1, y1, x2, y2 (N may be 0 or 1)."""
    lines = cv2.HoughLinesP(binary_image, 1, np.pi / 180, self.line_threshold,
                            minLineLength=self.min_line_length, maxLineGap=self.max_line_gap)
    if lines is None:
      return np.empty((0, 4), dtype=int)
    # reshape (not squeeze) so a single detected line still yields shape (1, 4)
    return lines.reshape(-1, 4)

  def _find_intersections(self, lines):
    """Intersect every pair of non-parallel lines; keep points lying on at least one of the two segments."""
    intersections = []
    # Python ints: products such as C1 * B2 overflow int32 for HD-sized coordinates
    segments = [tuple(int(v) for v in line) for line in lines]
    for i in range(len(segments)):
      x1, y1, x2, y2 = segments[i]
      A1, B1, C1 = self._line_equation(x1, y1, x2, y2)
      for j in range(i + 1, len(segments)):
        x3, y3, x4, y4 = segments[j]
        A2, B2, C2 = self._line_equation(x3, y3, x4, y4)
        det = A1 * B2 - A2 * B1
        if det == 0:  # parallel lines never intersect
          continue
        x = (C1 * B2 - C2 * B1) / det
        y = (A1 * C2 - A2 * C1) / det
        if self._is_on_line(x, y, x1, y1, x2, y2) or self._is_on_line(x, y, x3, y3, x4, y4):
          intersections.append((int(x), int(y)))
    return np.array(intersections, dtype=int).reshape(-1, 2)

  def _line_equation(self, x1, y1, x2, y2):
    """Coefficients (A, B, C) of the line A*x + B*y = C through the two points."""
    A = y2 - y1
    B = x1 - x2
    C = A * x1 + B * y1
    return A, B, C

  def _is_on_line(self, px, py, x1, y1, x2, y2):
    """True when (px, py) lies inside the segment's axis-aligned bounding box."""
    return min(x1, x2) <= px <= max(x1, x2) and min(y1, y2) <= py <= max(y1, y2)

  def _remove_neighbors_intersections(self, intersections):
    """Greedy de-duplication: keep a point only if every kept point is at least intersections_min_dist away."""
    filtered = []
    for x1, y1 in intersections:
      if all(np.sqrt((x1 - x2) ** 2 + (y1 - y2) ** 2) >= self.intersections_min_dist
             for x2, y2 in filtered):
        filtered.append((x1, y1))
    return np.array(filtered, dtype=int).reshape(-1, 2)


## Video Analysis

In [654]:
# Enables the ball-tracking cells; set to False to skip ball detection.
ball_tracker_flag = True

In [655]:
input_video_path = 'input_video.mp4'
frames = read_video(input_video_path)
# read_video() prints and returns None when the file cannot be opened; stop here
# rather than failing with a confusing error in a later cell.
if not frames:
  raise RuntimeError(f"No frames read from {input_video_path!r}")

### Detections

In [None]:
# Track every 'person' in every frame (ids assigned by the YOLO tracker).
player_tracker = PlayerTracker(model_path='yolov8n.pt')
player_detection = player_tracker.detect_video(frames)

In [None]:
if ball_tracker_flag:
  # detect the ball frame by frame, then fill in the frames where it was missed
  ball_tracker_model = 'ball_best.pt'
  ball_tracker = BallTracker(model_path=ball_tracker_model)
  ball_detections = ball_tracker.interpolate_detections(
      ball_tracker.detect_video(frames))

In [None]:
# Court keypoints are predicted once (detect_frame uses frames[0] when given the
# list) and reused for every frame.
court_detector_model = 'keypoints_model_epoch_10.pth'
court_detector = CourtLineDetector(court_detector_model)
court_detections = court_detector.detect_frame(frames)

# Refine the predicted keypoints by snapping them to nearby line intersections;
# skip the refinement when no intersection was found (cKDTree needs points).
court_intersections = CourtLineIntersections()
intersections = court_intersections.get_intersections(frames)
if len(intersections) > 0:
  court_detections = court_intersections.update_keypoint_to_intersection(court_detections.reshape(-1, 2), intersections).reshape(-1)

In [659]:
# keep only the detections closest to the court keypoints and relabel them by position
player_detection = player_tracker.filter_playrs(keypoints=court_detections, player_detections=player_detection)

### Drawing

In [660]:
# draws onto `frames` in place; `output` holds the same frame arrays
output = player_tracker.draw_video(video_frames=frames, plyer_detections=player_detection)

In [661]:
# ball_tracker / ball_detections only exist when ball tracking was enabled
if ball_tracker_flag:
  output = ball_tracker.draw_video(frames, ball_detections)

In [662]:
# one keypoint array shared by all frames; drawn onto `frames` in place
output = court_detector.draw_video(video_frames=frames, court_line_detections=court_detections)

In [663]:
# drow_frames_number() writes onto the frames in place; don't rely on its return value
drow_frames_number(frames)
output = frames

### Preview annotated frames

In [None]:
# preview the first five annotated frames
show_frames(frames, idx=(0, 5))

## Save Video

In [665]:
from google.colab import files

out_path = 'output_video.avi'
# Write at the input video's frame rate so playback speed matches the source;
# fall back to 30 fps when OpenCV cannot report it.
source = cv2.VideoCapture(input_video_path)
fps = source.get(cv2.CAP_PROP_FPS)
source.release()
save_video(frames, out_path, frame_per_sec=fps if fps > 0 else 30)
files.download(out_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>