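# Export shot transitions
Run this notebook from an Esper movies environment. It labels every frame in the validation and test windows as a shot transition (1) or not (0), using the manually labeled shots, and then exports those frames as JPEG images.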

In [1]:
import pickle
import numpy as np
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import equal, overlaps
import matplotlib.pyplot as plt
from query.models import LabeledInterval, Labeler, Shot
import os
from tqdm import tqdm
from esper.prelude import *
from PIL import Image

# Load manually labeled shots and compute shot boundaries

In [2]:
# Only hand-labeled shots: keep rows whose labeler name contains "manual".
MANUAL_LABELER_LOOKUP = {'labeler__name__contains': "manual"}
shots_qs = Shot.objects.filter(**MANUAL_LABELER_LOOKUP)

In [3]:
# Wrap the manual-shot queryset in a rekall VideoIntervalCollection so the interval
# operations below (map / set_union / coalesce / dilate) can be applied per video.
shots = VideoIntervalCollection.from_django_qs(shots_qs)

In [4]:
window_size = 16
stride = 16


def _floor_to_stride(frame):
    """Round `frame` down to the nearest multiple of `stride` (positive stride)."""
    return frame - frame % stride


def _start_point(intrvl):
    """Zero-length interval on the first frame of a shot, keeping its payload."""
    return (intrvl.start, intrvl.start, intrvl.payload)


def _after_end_point(intrvl):
    """Zero-length interval on the frame just past the end of a shot, keeping its payload."""
    return (intrvl.end + 1, intrvl.end + 1, intrvl.payload)


def _stride_aligned_clip(intrvl):
    """Pad an interval by one stride on each side, then snap both ends down to the stride grid."""
    return (
        _floor_to_stride(intrvl.start - stride),
        _floor_to_stride(intrvl.end + stride),
        intrvl.payload,
    )


# Point intervals at every shot start and every frame right after a shot end; points that
# coincide are merged by coalesce. Intervals with payload -1 are dropped
# (NOTE(review): meaning of the -1 payload is not visible here — confirm).
shot_boundaries = (
    shots.map(_start_point)
    .set_union(shots.map(_after_end_point))
    .coalesce()
    .filter(lambda intrvl: intrvl.payload != -1)
)

# Merge touching shots, widen each span to stride-aligned edges, then merge again.
clips = (
    shots.dilate(1).coalesce().dilate(-1)
    .map(_stride_aligned_clip)
    .dilate(1).coalesce().dilate(-1)
)

In [5]:
# Pickled (video_id, window_start, window_end) windows for the validation and test splits.
WEAK_LABELS_DIR = '/app/data/shot_detection_weak_labels'
VAL_WINDOWS = WEAK_LABELS_DIR + '/validation_windows_same_val_test.pkl'
TEST_WINDOWS = WEAK_LABELS_DIR + '/test_windows_same_val_test.pkl'

In [6]:
def _load_pickle(path):
    """Unpickle a single object from `path` (only use on trusted, locally produced files)."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


val_windows_by_video_id = _load_pickle(VAL_WINDOWS)
test_windows_by_video_id = _load_pickle(TEST_WINDOWS)

In [7]:
def _boundary_pairs(collection):
    """List (video_id, start_frame) for every interval in a VideoIntervalCollection."""
    pairs = []
    for vid in collection.get_allintervals():
        for boundary in collection.get_intervallist(vid).get_intervals():
            pairs.append((vid, boundary.start))
    return pairs


shot_boundary_tuples = _boundary_pairs(shot_boundaries)

In [8]:
# Hash-set view of the boundary pairs. `in` on the list is a linear scan per frame,
# making this labeling O(frames x boundaries); a set lookup is O(1).
shot_boundary_set = set(shot_boundary_tuples)

# One (video_id, frame_number, label) triple per frame of every validation window;
# label is 1 when that frame is a shot boundary, otherwise 0.
# NOTE(review): assumes the loaded windows iterate as (video_id, start, end) tuples with
# `end` exclusive (as `range` treats it) — confirm against how the pickle was built.
val_frames = [
    (video_id, frame_number, 1 if (video_id, frame_number) in shot_boundary_set else 0)
    for video_id, window_start, window_end in val_windows_by_video_id
    for frame_number in range(window_start, window_end)
]

In [9]:
# Hash-set view of the boundary pairs. `in` on the list is a linear scan per frame,
# making this labeling O(frames x boundaries); a set lookup is O(1).
shot_boundary_set = set(shot_boundary_tuples)

# One (video_id, frame_number, label) triple per frame of every test window;
# label is 1 when that frame is a shot boundary, otherwise 0.
# NOTE(review): assumes the loaded windows iterate as (video_id, start, end) tuples with
# `end` exclusive (as `range` treats it) — confirm against how the pickle was built.
test_frames = [
    (video_id, frame_number, 1 if (video_id, frame_number) in shot_boundary_set else 0)
    for video_id, window_start, window_end in test_windows_by_video_id
    for frame_number in range(window_start, window_end)
]

# Export frames

In [10]:
# (video_id, frame_number, label) triples: validation frames first, then test frames.
all_frame_numbers = [*val_frames, *test_frames]

In [11]:
def _group_frames_by_video(labeled_frames):
    """Map each video id to its frame numbers, in the order they appear in `labeled_frames`.

    `labeled_frames` is an iterable of (video_id, frame_number, label) triples; the label
    is not needed for exporting images and is ignored.
    """
    grouped = {}
    for vid, frame_number, _label in labeled_frames:
        grouped.setdefault(vid, []).append(frame_number)
    return grouped


frame_numbers_by_video = _group_frames_by_video(all_frame_numbers)

In [12]:
import hwang, storehouse

In [None]:
# Frames are written to IMAGE_ROOT/<video_id>/<frame_number:06d>.jpg.
IMAGE_ROOT = '/app/data/shot_transitions/images'

# The storage backend depends only on the bucket, so build it once instead of per video.
# os.environ[...] raises KeyError when BUCKET is unset rather than silently passing None.
backend = storehouse.StorageBackend.make_from_config(
    storehouse.StorageConfig.make_gcs_config(os.environ['BUCKET']))

for video_id in tqdm(frame_numbers_by_video):
    video = Video.objects.get(id=video_id)
    dec = hwang.Decoder(storehouse.RandomReadFile(backend, video.path))

    # The per-video list can contain repeats (e.g. if windows overlap); decode each frame
    # once, in ascending order. Output files are unchanged since each name is per-frame.
    frame_nums = sorted(set(frame_numbers_by_video[video_id]))
    frames = dec.retrieve(frame_nums)

    out_dir = os.path.join(IMAGE_ROOT, str(video_id))
    os.makedirs(out_dir, exist_ok=True)

    for frame_num, frame in zip(frame_nums, frames):
        Image.fromarray(frame).save(os.path.join(out_dir, '{:06d}.jpg'.format(frame_num)))

    # Drop this video's decoded frames before decoding the next one.
    del frames

HBox(children=(IntProgress(value=0, max=28), HTML(value='')))

# Export train/val/test splits

In [None]:
# FIXME: `train_set`, `val_set`, `test_set` and `interview_labels` are not defined anywhere
# in this notebook (this cell looks carried over from an interview-labeling notebook), so
# it raises NameError on a fresh kernel. Define them above, or rewrite this cell to export
# `val_frames` / `test_frames`, before relying on it.
SPLIT_DIR = '/app/shot_transitions/data'


def _write_split(path, video_ids, labels_by_video):
    """Write one "<video_id> <interval_index> <payload>" line per labeled interval of each video."""
    with open(path, 'w') as f:
        for video_id in video_ids:
            for i, intrvl in enumerate(labels_by_video[video_id].get_intervals()):
                f.write('{} {} {}\n'.format(video_id, i, intrvl['payload']))


os.makedirs(SPLIT_DIR, exist_ok=True)
# All three split names are resolved before any file is opened, so a missing split
# fails before train/val/test files are partially written.
for split_name, split_videos in (('train', train_set), ('val', val_set), ('test', test_set)):
    _write_split(os.path.join(SPLIT_DIR, split_name + '.txt'), split_videos, interview_labels)