In [None]:
import os
import sys
sys.path.insert(0, os.path.abspath("../"))

import glob

from models import VideoMotionPredictor

#### Load HF model

In [None]:
# Load the pretrained motion predictor by its repository id.
model = VideoMotionPredictor.from_pretrained(
    "ai-forever/kandinsky-video-motion-predictor"
)
# Move the model to the "cuda" device (presumably requires a CUDA-capable GPU —
# TODO confirm). The trailing semicolon suppresses the cell's output repr.
model.to("cuda");

#### Run some examples

In [None]:
# Directory with example clips; the relative path is resolved against the
# current working directory.
asset_dir = "../assets/video_motion_predictor/examples/"
# The backbone of the model is VideoMAEv2; it converts 16 frames to a single embedding.
# fps=4 therefore means up to 4 seconds of video are embedded in a single tensor.
fps = 4.0
# max_frames is bounded by the size of the positional embeddings and by what the
# model saw during training: at most 240 frames, i.e. 1 minute of video at 4 fps.
max_frames = 240

# glob returns matches in arbitrary (filesystem-dependent) order; sorting makes the
# printed report, and the tie-breaking of any downstream ranking, reproducible.
videos = sorted(glob.glob(os.path.join(asset_dir, "**", "*.mp4"), recursive=True))
# Fail early with a readable message instead of running inference on nothing
# (an empty match list usually means the working directory is wrong).
assert videos, f"No .mp4 files found under {asset_dir!r}"
predictions = model.inference(videos, fps=fps, max_frames=max_frames, return_dict=True)

# NOTE(review): these two lists are not used by any cell visible in this notebook;
# kept so that code relying on the names keeps working — confirm and remove.
chaotic = []
slideshow = []

# `predictions` maps each score name to a sequence indexed in the same order as
# `videos`; print every score of every video, rounded to 3 decimals.
for i, video in enumerate(videos):
    print(video)
    for k, v in predictions.items():
        print(f"{k}: {round(v[i], 3)}")
    print()

#### Rank video examples

In [None]:
def get_top_low(data_tuple, key, top_k=3):
    """Return the names of the highest- and lowest-scoring entries for one score.

    Args:
        data_tuple: iterable of ``(name, scores)`` pairs, where ``scores`` is a
            mapping that contains ``key``.
        key: name of the score to rank by.
        top_k: how many entries to return on each side; must be non-negative.

    Returns:
        ``(top, low)``: two lists of names. Both follow descending score order,
        so ``top[0]`` is the highest-scoring entry and ``low[-1]`` the lowest.
        When there are fewer than ``2 * top_k`` entries the two lists overlap.

    Raises:
        ValueError: if ``top_k`` is negative.
    """
    if top_k < 0:
        raise ValueError(f"top_k must be non-negative, got {top_k}")

    ranked = sorted(data_tuple, key=lambda item: item[1][key], reverse=True)
    names = [item[0] for item in ranked]

    top = names[:top_k]
    # An explicit start index is used instead of names[-top_k:], because
    # names[-0:] would return the whole list when top_k == 0.
    low = names[max(len(names) - top_k, 0):]
    return top, low

def format_list(lst):
    """Render each string of ``lst`` on its own line, indented by two spaces.

    Every entry is followed by a newline, so an empty ``lst`` yields ``""``.
    """
    rendered = ""
    for entry in lst:
        rendered += "  " + entry + "\n"
    return rendered


# Turn the column-oriented `predictions` mapping (score name -> per-video values)
# into a row-oriented list holding one {score name: value} dict per video.
keys = list(predictions.keys())
# The first score's sequence determines how many rows are built.
length = len(predictions[keys[0]])
prediction_list = [
    {name: predictions[name][idx] for name in keys}
    for idx in range(length)
]

# Pair every video path with its score dict; this relies on `predictions`
# being ordered like `videos`.
data_tuple = list(zip(videos, prediction_list))

# Rank the videos once per score; each call returns (top names, low names).
top_camera, low_camera = get_top_low(data_tuple, key='camera_movement_score')
top_object, low_object = get_top_low(data_tuple, key='object_movement_score')
top_dynamics, low_dynamics = get_top_low(data_tuple, key='dynamics_score')

# Print every ranking as a heading followed by the indented video list and a
# blank separator line (the newline that print appends).
for heading, ranked in (
    ("top camera", top_camera),
    ("low camera", low_camera),
    ("top object", top_object),
    ("low object", low_object),
    ("top dynamics", top_dynamics),
    ("low dynamics", low_dynamics),
):
    print(f"{heading}:\n", format_list(ranked), sep='')