In [None]:
from paddleocr import PaddleOCR

class PaddleModel:
    """
    Small holder for a PaddleOCR instance configured for English text
    and pinned to a single GPU. The OCR object is exposed as `self.model`.
    """

    def __init__(self, device: int = 0) -> None:
        """
        Build the OCR pipeline.

        device: GPU index forwarded to PaddleOCR as `gpu_id`.
        """
        # Keyword arguments handed to the PaddleOCR constructor unchanged.
        ocr_options = {
            "use_angle_cls": True,
            "lang": "en",
            "show_log": True,
            "det_db_score_mode": "slow",
            "ocr_version": "PP-OCRv4",
            "rec_algorithm": "SVTR_LCNet",
            "drop_score": 0.0,
            "use_gpu": True,
            "gpu_id": device,
            "gpu_mem": 1000,
        }
        self.model = PaddleOCR(**ocr_options)

In [None]:
import os

# `!export ...` runs in a short-lived subshell, so the variable never reaches
# this kernel (or the worker processes it spawns). Set it on the kernel's own
# environment instead.
# NOTE(review): this must run before any CUDA context is created in this
# kernel to take effect — confirm the cell order on a fresh kernel.
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,4,5,6,7"

import concurrent.futures 
import cv2
import torch
import json
import os
import re
import time
import numpy as np
import concurrent
import logging

from tqdm import tqdm
from PIL import Image
from typing import List
from utilities.models import Models

from viz import visualize_timestamps
from utilities.constants import *

# Size of the worker pool in process_dir; device indices 0..MAX_GPUS-1 are
# handed to the workers, one per video in a batch.
MAX_GPUS = 6

def process_dir(dir_path: str, data_out_path: str, viz_out_path=None):
    """
    Extract timestamps for every video in a directory.

    Videos are processed in batches of up to MAX_GPUS; within a batch each
    video is submitted to a worker process with its own device index.

    Args:
        dir_path: directory containing the videos ('.avi' / '.mp4').
        data_out_path: output directory (created if missing); a per-video
            '.json' path inside it is passed to the worker.
        viz_out_path: optional output directory for visualizations
            (created if given; not otherwise used here).

    Returns:
        dict mapping each video path to the timestamps dict returned by
        extract_timestamps_from_video for that video.
    """

    assert os.path.isdir(
        dir_path), f"Error: bad path to video directory: {dir_path}"
    os.makedirs(data_out_path, exist_ok=True)
    if viz_out_path is not None:
        assert type(
            viz_out_path) is str, "Error: path links must be of type str."
        os.makedirs(viz_out_path, exist_ok=True)

    # Build a filtered list instead of removing entries while iterating
    # (which skips elements), and use splitext so names with no dot or with
    # several dots are handled correctly.
    valid_formats = {".avi", ".mp4"}
    vids = [
        name
        for name in os.listdir(dir_path)
        if os.path.splitext(name)[1] in valid_formats
    ]

    timestamps = {}
    with concurrent.futures.ProcessPoolExecutor(max_workers=MAX_GPUS) as executor:
        with tqdm(total=len(vids), desc="Processing Videos") as pbar:
            for batch_start in range(0, len(vids), MAX_GPUS):
                batch = vids[batch_start:batch_start + MAX_GPUS]
                submitted = []  # (future, video_path) pairs in submission order
                for device, name in enumerate(batch):
                    video_path = os.path.join(dir_path, name)
                    data_path = os.path.join(data_out_path, name.replace(
                        ".mp4", ".json").replace(".avi", ".json"))
                    future = executor.submit(
                        extract_timestamps_from_video, video_path, data_path, device=device)
                    submitted.append((future, video_path))
                # Each future stays paired with the path it was submitted for;
                # pairing as_completed() output with the submission-order path
                # list would attach results to the wrong videos.
                for future, video_path in submitted:
                    timestamps[video_path] = future.result()
                    pbar.update(1)

    return timestamps


def extract_timestamps_from_video(video_path: str, save_path: str, device: int = 0):
    """
    Given a path to a basketball broadcast video,
    returns a timestamps dict.

    Args:
        video_path: path to the video; the quarter is read from the fifth
            character from the end of the path (e.g. 'period_x.mp4').
        save_path: accepted for interface compatibility; this function does
            not write to it.
        device: GPU index used for the OCR model.

    Returns:
        dict mapping str(frame_index) to
        {"quarter": ..., "time_remaining": ...}. OCR runs on every fifth
        frame; the remaining entries are filled by post_process_timestamps.
    """

    assert os.path.exists(video_path)

    # TODO: All roi's extracted on device = 0
    time_remaining_roi = extract_roi_from_video(video_path)

    # Pre-compute the crop window once. Coordinates are converted to plain
    # ints and the lower bounds are clamped at 0 so that a box touching the
    # frame edge neither trips a truthiness check nor wraps around through a
    # negative slice index.
    crop_rows, crop_cols = None, None
    if time_remaining_roi is not None:
        tr_x1, tr_y1, tr_x2, tr_y2 = (int(v) for v in time_remaining_roi)
        crop_rows = slice(max(tr_y1 - PAD, 0), tr_y2 + 2 * PAD)
        crop_cols = slice(max(tr_x1 - PAD, 0), tr_x2 + 2 * PAD)

    timestamps = {}
    quarter = video_path[-5]  # period_x.mp4
    step = 5

    model = PaddleModel(device=device)

    cap = cv2.VideoCapture(video_path)
    try:
        frames_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for frame_index in range(frames_cnt):
            ret, frame = cap.read()
            if not ret:
                break
            time_remaining = None
            if frame_index % step == 0 and crop_rows is not None:
                time_remaining_img = frame[crop_rows, crop_cols]
                # Skip OCR when the window falls entirely outside the frame.
                if time_remaining_img.size > 0:
                    time_remaining = extract_time_remaining_from_image(
                        Image.fromarray(time_remaining_img),
                        model=model,
                    )
                    time_remaining = convert_time_to_float(time_remaining)
            timestamps[str(frame_index)] = {
                "quarter": quarter,
                "time_remaining": time_remaining,
            }
            if frame_index == BREAK:
                break
    finally:
        # Always free the decoder handle, even if OCR raises.
        cap.release()

    post_process_timestamps(timestamps)
    return timestamps


def extract_roi_from_video(video_path: str):
    """
    Find time-remaining roi from video. Assumes static, naive approach.

    Every 30th frame is run through the detector. The scan stops at the
    first sampled frame containing a time-remaining detection above
    CONF_THRESH; the box returned is the highest-confidence time-remaining
    detection seen in any sampled frame up to and including that one.

    Returns a tensor with format: [x1, y1, x2, y2] or None if no
    ROI is found.
    """

    assert os.path.isfile(video_path), f"Error: bad path to video {video_path}."

    # TODO: skip through vid at one second intervals
    highest_conf = 0.0
    best_roi = None
    step = 30

    cap = cv2.VideoCapture(video_path)
    try:
        frames_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        for i in range(frames_cnt):
            ret, frame = cap.read()
            if not ret:
                break
            if i % step != 0:
                continue

            results = Models.yolo(frame, verbose=False)
            classes, conf, boxes = (
                results[0].boxes.cls,
                results[0].boxes.conf,
                results[0].boxes.xyxy,
            )
            # One row per detection: [class, confidence, x1, y1, x2, y2].
            classes_conf = torch.stack((classes, conf), dim=1)
            predictions = torch.cat((classes_conf, boxes), dim=1)

            found_confident = False
            for row in predictions:
                if row[0] == QUARTER_KEY:
                    continue  # quarter boxes are not used here
                if row[0] == TIME_REMAINING_KEY:
                    if row[1] > CONF_THRESH:
                        found_confident = True
                    if row[1] > highest_conf:
                        highest_conf = row[1]
                        best_roi = row[2:].to(torch.int)
            if found_confident:
                break
    finally:
        # Release the decoder handle on every exit path.
        cap.release()
    return best_roi


def find_time_remaining_from_results(results: List[str]):
    """
    Return the first entry of `results` that, once spaces are removed, is in
    its entirety a clock reading accepted by the pattern below (MM:SS style
    readings up to 20:00, or SS.S style readings below one minute).

    The returned string has its spaces removed. Returns None when `results`
    is None or nothing matches.
    """
    if results is None:
        return None
    time_remaining_regex = r"(20:00)|(0[0-9]?:[0-9][0-9](\.[0-9])?)|([1-9]:[0-5][0-9])|(1[0-9]:[0-5][0-9](\.[0-9])?)|([0-9]\.[0-9])|([1-5][0-9]\.[0-9])"
    candidates = (text.replace(" ", "") for text in results)
    # The whole candidate has to be consumed by the pattern, not just a prefix.
    return next(
        (text for text in candidates if re.fullmatch(time_remaining_regex, text)),
        None,
    )


def extract_time_remaining_from_image(image: Image.Image, model: PaddleModel):
    """
    Run OCR on a PIL image and pick out the clock reading.

    The image is converted to RGB, passed through extract_text_with_paddle,
    and the recognized words are filtered by
    find_time_remaining_from_results. Returns the matching string
    (e.g., '11:30') or None.
    """
    ocr_words = extract_text_with_paddle(image.convert("RGB"), model=model)
    return find_time_remaining_from_results(ocr_words)


def extract_text_with_paddle(image: Image.Image, model: PaddleModel) -> List[str]:
    """
    Returns a [str] containing all words found in a
    provided PIL image.

    The image is rescaled to a height of roughly 100 px before being passed
    to `model.model`. Returns an empty list when the image is None, has a
    zero dimension, or the OCR call reports no recognized text.
    """

    # A degenerate image would otherwise divide by zero below.
    if image is None or image.width == 0 or image.height == 0:
        return []
    ideal_height = 100
    scale_factor = ideal_height / image.height
    # Keep the resized width at >= 1 px so very narrow crops stay resizable.
    new_size = (
        max(1, int(image.width * scale_factor)),
        int(image.height * scale_factor),
    )
    image = image.resize(new_size)
    img_arr = np.array(image)

    # pred w/ paddleocr
    raw_result = model.model(img_arr)
    # Index 1 holds the recognition results; each entry's first item is the
    # recognized text.
    # NOTE(review): PaddleOCR appears to return None here when no text boxes
    # are detected — confirm against the installed version.
    text_arr = raw_result[1]
    if text_arr is None:
        return []
    return [pred[0] for pred in text_arr]


def convert_time_to_float(time_remaining):
    """
    Converts a valid time-remaining str to its value in seconds (float).
    Returns None if time-remaining is None or cannot be parsed.

    Accepted forms:
        'M:SS' / 'MM:SS(.S)' -> 60 * minutes + seconds
        'SS.S'               -> seconds

    Ex: '1:30' -> 90.
    """

    if time_remaining is None:
        return None
    minutes, seconds = 0.0, 0.0
    try:
        if ":" in time_remaining:
            time_arr = time_remaining.split(":")
            minutes = float(time_arr[0])
            seconds = float(time_arr[1])
        elif "." in time_remaining:
            seconds = float(time_remaining)
        else:
            return None
    except ValueError:
        # Malformed numeric text is reported as "invalid", per the contract
        # above, rather than raised.
        return None
    return (60.0 * minutes) + seconds


def post_process_timestamps(timestamps):
    """
    Forward-fill missing values in-place.

    Walks the entries in dict order; an entry whose "quarter" is empty or
    whose "time_remaining" is None receives the most recent value seen
    earlier (or None if there has been none yet). Returns nothing.
    """

    last_quarter, last_time = None, None

    for entry in timestamps.values():
        if entry["quarter"]:
            last_quarter = entry["quarter"]
        else:
            entry["quarter"] = last_quarter
        # Compare against None explicitly: a clock reading of 0.0 is a valid
        # value and must not be overwritten by the previous reading.
        if entry["time_remaining"] is not None:
            last_time = entry["time_remaining"]
        else:
            entry["time_remaining"] = last_time

In [None]:
# Two prefixes for the same test-set directory: a Windows-style one and the
# path on this machine.
# NOTE(review): presumably meant for rewriting video paths stored in the
# annotations (recorded on the Windows machine) to local paths — confirm.
replace_str = "C:/Users/Levi/Desktop/quantitative-benchmark/test-set\\"
with_str = "/playpen-storage/levlevi/nba-positions-videos-dataset/testing/quantitative-benchmark/test-set/"

In [None]:
import os
import json

# Load the ground-truth annotations for the quantitative benchmark.
annotations_fp = "/playpen-storage/levlevi/nba-positions-videos-dataset/testing/quantitative-benchmark/annotations/annotations.json"
with open(annotations_fp, 'r') as f:
    annotations = json.load(f)

In [None]:
# Run timestamp extraction over every video in the test set.
# `dummy` is the data output directory handed to process_dir (created if missing).
dir_path = "/playpen-storage/levlevi/nba-positions-videos-dataset/testing/quantitative-benchmark/test-set"
dummy = "/playpen-storage/levlevi/nba-positions-videos-dataset/testing/quantitative-benchmark/dummy"
timestamps = process_dir(dir_path, dummy)

In [None]:
# Use the first annotated video as the reference clip.
annotated_key, ground_truth = next(iter(annotations.items()), (None, {}))

# `timestamps` (from process_dir) maps video path -> {frame index -> values},
# so one video's frame dict has to be selected before reading per-frame
# values; indexing `timestamps[k]['time_remaining']` directly raises KeyError.
# NOTE(review): assumes annotation keys are video paths carrying the
# `replace_str` prefix, which maps onto the local `with_str` prefix — confirm.
video_key = None
if annotated_key is not None:
    video_key = annotated_key.replace(replace_str, with_str)
if video_key not in timestamps:
    raise KeyError(
        f"No extracted timestamps for annotated video {annotated_key!r} "
        f"(looked up as {video_key!r})"
    )
video_timestamps = timestamps[video_key]

timestamp_vals = [video_timestamps[k]['time_remaining'] for k in video_timestamps]
ground_truth_vals = [ground_truth[k]['time_on_clock'] for k in ground_truth]

In [None]:
# Display the full result of process_dir (video path -> per-frame dict).
timestamps

In [None]:
import matplotlib.pyplot as plt

# Overlay the extracted clock values on the annotated ones, one point per entry.
fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(timestamp_vals, label="Timestamp Values", color="blue")
ax.plot(ground_truth_vals, label="Ground Truth Values", color="red")

# Title and axis labels so the figure stands alone.
ax.set_title("Timestamp vs Ground Truth Values")
ax.set_xlabel("Frame Idx")
ax.set_ylabel("Time Remaining")

ax.legend()
ax.grid(True)

# Trim excess margins around the axes.
fig.tight_layout()

plt.show()