In [1]:
import os
from glob import glob

# Directory holding the full-game replay videos (*.mp4).
VIDEOS_DIR = '/mnt/sun/levlevi/nba-plus-statvu-dataset/game-replays'
# Directory holding one tracking file (*.txt) per video.
TRACKS_DIR = '/playpen-storage/levlevi/player-re-id/src/testing/ocr_analysis/full_video_tracks'

# Reference only (never executed): ffmpeg invocation that cuts a 10-second
# clip starting at `start_time` while copying the audio/video streams.
#
# ffmpeg_command = [
#     "ffmpeg", "-ss", str(start_time), "-i", video_file,
#     "-t", "10", "-c:v", "copy", "-c:a", "copy", output_file
# ]

# Output directory for the sampled tracks.
out_dir = '/playpen-storage/levlevi/player-re-id/src/testing/ocr_analysis/sample_tracks_nba_100'

# glob() returns paths in file-system order, which differs between machines;
# sort so every run sees the tracking files in the same order.
tracks_fps = sorted(glob(os.path.join(TRACKS_DIR, "*.txt")))

In [2]:
import pandas as pd
from tqdm import tqdm

# Concatenate all player tracks: one (tracks-file path, per-entity DataFrame)
# pair for every entity found in every tracking file.
tracks = []

# Column layout of a tracking file; only the first seven columns are kept,
# the trailing pad columns are discarded.
TRACK_COLUMNS = ['frame', 'entity_id', 'x1', 'y1', 'width', 'height', 'conf', 'pad2', 'pad3', 'pad4']
KEPT_COLUMNS = TRACK_COLUMNS[:7]

for fp in tqdm(tracks_fps):
    # read_csv opens the file itself, so there is no need to open/read it
    # separately beforehand.
    df = pd.read_csv(fp, names=TRACK_COLUMNS)[KEPT_COLUMNS]
    # One pass over the frame instead of one boolean-mask scan per entity.
    # sort=False keeps entities in order of first appearance. Rows whose
    # entity_id is missing are skipped (they could never form a track).
    for _, df_entity in df.groupby('entity_id', sort=False):
        tracks.append((fp, df_entity))

100%|██████████| 28/28 [00:11<00:00,  2.47it/s]


In [3]:
import random

# Fixed seed so the same tracks are drawn on every run, provided `tracks`
# is built in the same order.
SAMPLE_SEED = 0
NUM_SAMPLE_TRACKS = 100

# Draw without replacement from a private RNG. Unlike random.shuffle(), this
# leaves `tracks` untouched, so re-running the cell yields the same sample.
# min() keeps the "take everything" behaviour when fewer tracks are available.
sample_tracks = random.Random(SAMPLE_SEED).sample(
    tracks, min(NUM_SAMPLE_TRACKS, len(tracks))
)

In [4]:
# All replay videos. Sorted because glob() order depends on the file system
# and the order decides which video wins when two are equally similar to a
# tracks-file name.
videos_fps = sorted(glob(os.path.join(VIDEOS_DIR, "*.mp4")))

In [5]:
import cv2
import Levenshtein

def measure_similarity(str1, str2):
    """Return the normalized Levenshtein similarity of two strings.

    The edit distance is divided by the length of the longer string and
    subtracted from 1, so identical strings score 1.0 and the score drops
    towards 0.0 as more edits are needed.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        A float in [0.0, 1.0]; 1.0 when both strings are empty.
    """
    max_len = max(len(str1), len(str2))
    if max_len == 0:
        # Two empty strings are identical; also avoids dividing by zero.
        return 1.0
    distance = Levenshtein.distance(str1, str2)
    return 1 - (distance / max_len)


def find_closest_fp(fp: str, candidates=None):
    """Return the candidate path most similar to `fp`.

    Similarity is scored with `measure_similarity`; when several candidates
    score equally, the first one encountered wins.

    Args:
        fp: Path to match.
        candidates: Iterable of paths to choose from. Defaults to the
            module-level `videos_fps` list.

    Returns:
        The best-matching path, or None when there are no candidates.
    """
    if candidates is None:
        candidates = videos_fps
    # max() returns the first of several equally good items, and `default`
    # covers the empty case that previously left the result unassigned.
    return max(
        candidates,
        key=lambda video_fp: measure_similarity(fp, video_fp),
        default=None,
    )


dst_dir = '/playpen-storage/levlevi/player-re-id/src/testing/ocr_analysis/sample_tracks_nba_100'


def _save_track_crops(video_path, track_df, out_subdir):
    """Write one JPEG crop per row of `track_df` into `out_subdir`.

    Args:
        video_path: Path of the video the track was extracted from.
        track_df: Rows with `frame`, `x1`, `y1`, `width`, `height` columns.
        out_subdir: Existing directory that receives `<frame:06d>.jpg` files.

    Returns:
        Number of crops successfully written.
    """
    cap = cv2.VideoCapture(video_path)
    saved = 0
    try:
        if not cap.isOpened():
            tqdm.write(f"could not open video: {video_path}")
            return saved

        # Index of the frame the next cap.read() will return, when known.
        # Seeking is by far the slowest step, so it is skipped whenever the
        # row asks for exactly that frame (the common consecutive-frames case).
        next_frame_idx = None
        for row in track_df.itertuples():
            frame_idx = int(row.frame)
            x1, y1 = int(row.x1), int(row.y1)
            width, height = int(row.width), int(row.height)

            # NOTE(review): the frame number is handed to OpenCV unchanged;
            # CAP_PROP_POS_FRAMES is 0-based -- confirm the tracking files
            # use the same base.
            if frame_idx != next_frame_idx:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
            ret, image = cap.read()
            if not ret:
                # Position is unknown after a failed read; force a seek next time.
                next_frame_idx = None
                continue
            next_frame_idx = frame_idx + 1

            # A box that starts left of / above the frame has a negative
            # origin, which would wrap around as a slice start; clamp to 0.
            crop = image[max(y1, 0):max(y1 + height, 0), max(x1, 0):max(x1 + width, 0)]
            if crop.size == 0:
                # Box lies entirely outside the frame (or has no area).
                continue

            frame_path = os.path.join(out_subdir, f"{frame_idx:06d}.jpg")
            try:
                written = cv2.imwrite(frame_path, crop)
            except cv2.error as err:
                # Stay best-effort, but report the failure.
                tqdm.write(f"failed to write {frame_path}: {err}")
                continue
            if written:
                saved += 1
            else:
                tqdm.write(f"failed to write {frame_path}")
    finally:
        # Always free the decoder, even if a row fails.
        cap.release()
    return saved


for track in tqdm(sample_tracks):
    fp, df = track
    video_name = os.path.basename(fp).replace('.txt', '.mp4')
    # Tracks-file names do not necessarily match a video name exactly, so
    # pick the most similar video path.
    video_path = find_closest_fp(os.path.join(VIDEOS_DIR, video_name))
    if video_path is None:
        tqdm.write(f"no video found for {fp}")
        continue

    # NOTE(review): the sub-directory is keyed by video only and the files by
    # frame number, so crops of different tracks sampled from the same video
    # share a directory and overwrite each other on common frames. Layout is
    # kept as-is here; consider adding the entity id to the path.
    out_subdir = os.path.join(dst_dir, video_name)
    os.makedirs(out_subdir, exist_ok=True)

    _save_track_crops(video_path, df, out_subdir)

 30%|███       | 30/100 [09:10<21:25, 18.36s/it]


KeyboardInterrupt: 