In [157]:
import subprocess
import numpy as np
from pathlib import Path
from fractions import Fraction
import json
from keyframes import get_keyframes_ffmpeg
from mvbench import find_video
import multiprocessing
from tqdm.auto import tqdm

In [254]:
class X264Sampler:
    def __init__(self, target_frames=16, max_feasible_iter=100, max_binary_iter=9):
        self.target = target_frames
        self.feasible_iter = max_feasible_iter
        self.binary_iter = max_binary_iter

    def get_video_meta(self, path):
        if path.is_dir():
            return len([f for f in path.iterdir() if f.is_file()]), Fraction(3)
        # (Same helper as before to get FPS/Duration)
        cmd = [
            "ffprobe", 
            "-v", "error",
            "-select_streams", "v:0",
            "-show_entries", "stream=avg_frame_rate,nb_frames", 
            "-of", "json", 
            str(path)
        ]
        out = subprocess.check_output(cmd)
        data = json.loads(out)["streams"][0]
        fps_str = data.get("avg_frame_rate", "30/1")
        if fps_str == "0/0": 
            fps_str = "30/1"
        num, den = map(int, fps_str.split('/'))
        fps = Fraction(num, den) if den != 0 else Fraction(30, 1)
        
        # Estimate duration/frames
        nb_frames = data.get("nb_frames")
        if nb_frames:
            total = int(nb_frames)
        else:
            # Fallback
            cmd_dur = ["ffprobe", "-v", "error", "-show_entries", "format=duration", "-of", "csv=p=0", str(path)]
            dur = float(subprocess.check_output(cmd_dur).strip())
            total = int(dur * float(fps))
            
        return total, fps

    def _run_probe(self, path, fps, sc, k_min, k_max, scale="240", lc=None, start=None, end=None):
        frames, count, _ = get_keyframes_ffmpeg(
            path, 
            fps, 
            keyint_min=k_min, 
            keyint_max=k_max, 
            sc_threshold=sc,
            bframes=0,
            trim_start=start,
            trim_end=end,
            n_threads=1,
            scale=scale,
            lookahead_cap=lc
        )
        return frames, count

    
    def solve(self, path, start=None, end=None):
        total_frames, fps = self.get_video_meta(path)
        start_frame = 0
        
        if end is not None:
            end_frame = int(fps * end)
            if end_frame >= total_frames:
                end = None
            else:
                total_frames = end_frame
        if start is not None:
            start_frame = int(fps * start)
            # Account for incorrect start timings
            if start_frame >= total_frames:
               start = None
            else:
                total_frames -= start_frame
        
        # If we have <= target frames, return all the frames we can
        if total_frames <= self.target:
            print(f"{path} has {total_frames} frames <= {self.target}, uniformly sampling.")
            return {"keyframes": self._interp_frame_indices(list(range(start_frame, start_frame+total_frames)))}

        # Fix keyint_min (nyquist sampling freq), equal to min(4, total_frames/(target * 2)) so that we don't just end up uniformly sampling
        keyint_min = min(4, int(total_frames/(self.target * 2)))
        # print("Finding feasible")
        # Find feasible keyint_max
        feasible, keyint_max, keyframes, keyframe_count = self._find_feasible_max(path, total_frames, fps, keyint_min, start=start, end=end)

        # If we didn't find a feasible configuration, default to uniform sampling
        if not feasible:
            return {"keyframes": self._interp_frame_indices(list(range(start_frame, start_frame+total_frames)))}

        # Return early if we have the correct amount of frames from feasibility check
        if keyframe_count == self.target:
            return  {
                "keyframes": keyframes,
                "keyint_min": keyint_min,
                "keyint_max": keyint_max,
                "scenecut": sc
            }
        # print("BS", path, fps, keyint_min, keyint_max, start, end, start_frame, total_frames, keyframe_count)
        # Run binary search to find optimal scenecut
        sc, keyframes, keyframe_count = self._find_scenecut_bs(path, fps, keyint_min, keyint_max, start=start, end=end)

        # Interpolate down if needed and return
        return  {
            "keyframes": self._interp_frame_indices(keyframes),
            "keyint_min": keyint_min,
            "keyint_max": keyint_max,
            "scenecut": sc
        }

    def _find_scenecut_bs(self, path, fps, keyint_min, keyint_max, start=None, end=None, sc_max=300):
        # Binary search for optimal scenecut
        sc = sc_max
        sc_min = 0
        last_feasible_sc = sc
        last_feasible_keyframes = []
        for i in range(self.binary_iter):
            keyframes, count = self._run_probe(path, fps, sc, keyint_min, keyint_max, start=start, end=end)
            print(sc, count)
            if count == self.target:
                return sc, keyframes, count
            if count > self.target:
                last_feasible_sc = sc
                last_feasible_keyframes = keyframes
                sc_max = sc
            else:
                sc_min = sc
            sc = sc_min + (sc_max - sc_min) // 2
        print(f"WARN: binary scenecut search reahed max iterations.\npath={path} keyint_min={keyint_min} keyint_max={keyint_max} sc={sc} feasible_sc={last_feasible_sc} n_kfs={len(last_feasible_keyframes)}")
        return last_feasible_sc, last_feasible_keyframes, len(last_feasible_keyframes)
        
    def _find_feasible_max(self, path, total_frames, fps, keyint_min, sc_max=300, alpha=0.9, start=None, end=None):
        # First find feasible start, use keyint_max = total_frames - target*keyint_min and scenecut = 300. While we are getting less keyframes than target, decrease 
        # keyint_max by a factor. Do this until n_frames >= target.
        infeasible = (False, -1, [], 0)
        if total_frames < self.target:
            print(f"WARN: feasibility check failed for path={path}, less than {self.target} frames")
            return infeasible
        keyint_max = total_frames - self.target*keyint_min
        print(keyint_max)
        for i in range(self.feasible_iter):
            frames, count = self._run_probe(path, fps, sc_max, keyint_min, keyint_max, start=start, end=end)
            if count >= self.target:
                return True, keyint_max, frames, count
            keyint_max = int(alpha * keyint_max)
            if keyint_max <= keyint_min:
                print(f"WARN: feasiblility check hit keyint_min\npath={path} total_frames={total_frames} keyint_min={keyint_min} keyint_max={keyint_max}")
                return infeasible
        print(f"WARN: feasiblility check reached max iterations\npath={path} total_frames={total_frames} keyint_min={keyint_min} keyint_max={keyint_max}")
        return infeasible
            
        

    def _interp_frame_indices(self, indices):
        if len(indices) == self.target:
            return indices
        target_idx = np.linspace(0, len(indices)-1, self.target).astype(int)
        return [indices[i] for i in target_idx]

sampler = X264Sampler(target_frames=16)

In [212]:
video_path = Path("data/MVBench/video/clevrer/video_validation/video_10009.mp4")
sampler._find_scenecut_bs(video_path, fps, 1, 128, start=1.5)


300 18
150 2
225 5
262 8
281 11
290 14
295 16


(295,
 [38, 44, 50, 57, 64, 70, 76, 81, 86, 91, 97, 103, 109, 115, 121, 127],
 16)

In [250]:
N_PARALLEL = 1

data_list = {
    # "Action Sequence": ("action_sequence.json", "data/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
    # "Action Prediction": ("action_prediction.json", "data/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
    # "Action Antonym": ("action_antonym.json", "data/MVBench/video/ssv2_video/", "video", False),
    "Fine-grained Action": ("fine_grained_action.json", "data/MVBench/video/Moments_in_Time_Raw/videos/", "video", False),
    # "Unexpected Action": ("unexpected_action.json", "data/MVBench/video/FunQA_test/test/", "video", False),
    # "Object Existence": ("object_existence.json", "data/MVBench/video/clevrer/video_validation/", "video", False),
    # "Object Interaction": ("object_interaction.json", "data/MVBench/video/star/Charades_v1_480/", "video", True), # has start & end
    # "Object Shuffle": ("object_shuffle.json", "data/MVBench/video/perception/videos/", "video", False),
    # "Moving Direction": ("moving_direction.json", "data/MVBench/video/clevrer/video_validation/", "video", False),
    # "Action Localization": ("action_localization.json", "data/MVBench/video/sta/sta_video/", "video", True),  # has start & end
    # "Scene Transition": ("scene_transition.json", "data/MVBench/video/scene_qa/video/", "video", False),
    # "Action Count": ("action_count.json", "data/MVBench/video/perception/videos/", "video", False),
    # "Moving Count": ("moving_count.json", "data/MVBench/video/clevrer/video_validation/", "video", False),
    # "Moving Attribute": ("moving_attribute.json", "data/MVBench/video/clevrer/video_validation/", "video", False),
    # "State Change": ("state_change.json", "data/MVBench/video/perception/videos/", "video", False),
    # "Fine-grained Pose": ("fine_grained_pose.json", "data/MVBench/video/nturgbd/", "video", False), # TODO: obtain later?
    # "Character Order": ("character_order.json", "data/MVBench/video/perception/videos/", "video", False),
    # "Egocentric Navigation": ("egocentric_navigation.json", "data/MVBench/video/vlnqa/", "video", False),
    # Some of the timestamps for episodic reasoning are invalid, have extra checks for this
    # "Episodic Reasoning": ("episodic_reasoning.json", "data/MVBench/video/tvqa/frames_fps3_hq/", "frame", True),  
    # "Counterfactual Inference": ("counterfactual_inference.json", "data/MVBench/video/clevrer/video_validation/", "video", False),
}

data_dir = "data/MVBench/json"

def run_mvbench_one(args):
    i, video_dir, row, trimmed = args
    video_path = find_video(Path(video_dir), Path(row.video))
    start = None
    end = None
    if trimmed:
        start = row.start
        end = row.end
    sampler = X264Sampler(target_frames=16)
    res = sampler.solve(video_path)
    if trimmed:
        res["start"] = start
        res["end"] = end
    return i, res

def run_mvbench_category(mv_df: pd.DataFrame, category: str, video_dir: Path, trimmed: bool):
    results = []
    total_correct = 0
    all_args = [[i, video_dir, row, trimmed] for i, row in mv_df.iterrows()]
    results = list(tqdm(map(run_mvbench_one, all_args), total=mv_df.shape[0]))
    results_map = {i: kfs for (i, kfs) in results}
    return results_map

def run_mvbench_h264():
    results_dict = {}
    for i, category in enumerate(data_list):
        json, video_dir, _, trimmed = data_list[category]
        print(f"({i + 1:02d}/19) Starting", category)
        
        df = pd.read_json(Path(data_dir) / Path(json), orient="records")
        results = run_mvbench_category(df, category, Path(video_dir), trimmed)
        results_dict[category] = results
        print("Finished", category)
    print("Done")
    return results_dict

In [None]:
results = run_mvbench_h264()

In [253]:
with open(f"computed_keyframes/test.json", "w") as f:
    json.dump(results, f)