# Export conversations
Run this notebook from an Esper movies environment. It relies on names such as `Video` that the environment provides without an explicit import.

In [56]:
import pickle
import numpy as np
from rekall.video_interval_collection import VideoIntervalCollection
from rekall.interval_list import IntervalList
from rekall.temporal_predicates import equal, overlaps
import matplotlib.pyplot as plt
from query.models import LabeledInterval, Labeler
import os

# Load Conversation Query Results

In [4]:
# Videos with ignore_film=False and year >= 1935, excluding the "animation"
# genre, ordered by id so downstream loops visit them in a stable order.
videos = (
    Video.objects
    .filter(ignore_film=False, year__gte=1935)
    .exclude(genres__name="animation")
    .order_by('id')
)

In [5]:
def load_conversation_intervals(video_id, data_dir='/app/data/conversations'):
    """Load the pickled conversation query result for one video.

    Args:
        video_id: id of the video; the file read is ``<data_dir>/<video_id>.pkl``.
        data_dir: directory holding the per-video pickles.

    Returns:
        The interval list for ``video_id`` taken from the unpickled object
        via ``get_intervallist(video_id)``.
    """
    path = os.path.join(data_dir, '{}.pkl'.format(video_id))
    # `with` closes the handle right away. The previous inline
    # pickle.load(open(...)) left one open file per video.
    with open(path, 'rb') as f:
        # NOTE: pickle.load can execute arbitrary code. Only load files
        # produced by our own pipeline.
        return pickle.load(f).get_intervallist(video_id)


conversations = VideoIntervalCollection({
    video.id: load_conversation_intervals(video.id)
    for video in videos
})

# Load Ground Truth Conversations

In [2]:
# Ground-truth conversations: every LabeledInterval whose labeler name
# contains "conversations", gathered into a per-video interval collection.
conversations_gt_qs = LabeledInterval.objects.filter(
    labeler__name__contains="conversations")
conversations_gt = VideoIntervalCollection.from_django_qs(conversations_gt_qs)

In [6]:
# Keep only detected conversations inside the span covered by ground truth.
# A conversation must start after some GT interval starts and end before some
# GT interval ends (assuming filter_against keeps an interval when the
# predicate matches any interval of the other collection). Payloads on the
# result are reset to 0.
conversations_starting_after_gt = conversations.filter_against(
    conversations_gt,
    lambda conv, gt: conv.start > gt.start
)
conversations_in_gt_bounds = (
    conversations
    .map(lambda intrvl: (intrvl.start, intrvl.end, 0))
    .filter_against(conversations_gt, lambda conv, gt: conv.end < gt.end)
    .filter_against(conversations_starting_after_gt, equal())
)

# Split into train, validation, and test sets

In [20]:
# Sorted ids of the videos that still have detected conversations inside the GT span.
video_ids = sorted(conversations_in_gt_bounds.get_allintervals())

In [21]:
# Print each video's title next to its id, for reference when picking splits.
for video_id in video_ids:
    title = Video.objects.get(id=video_id).name
    print(title, video_id)

apollo 13 15
fight club 61
kill bill vol 2 98
stir crazy 192
the godfather part iii 216
erin brockovich 352
hang em high 372
harry potter and the chamber of secrets 374
man of la mancha 432
ordinary people 459
stage fright 517


In [25]:
# Validation videos: apollo 13 (15), fight club (61),
# the godfather part iii (216), harry potter and the chamber of secrets (374).
val_set = [15, 61, 216, 374]

In [26]:
# Number of in-bounds detected conversations for each validation video.
for video_id in val_set:
    n_conversations = conversations_in_gt_bounds.get_intervallist(video_id).size()
    print(n_conversations)

17
27
16
14


In [27]:
# Number of in-bounds detected conversations for every video, in id order.
for video_id in video_ids:
    n_conversations = conversations_in_gt_bounds.get_intervallist(video_id).size()
    print(n_conversations)

17
27
70
73
16
37
89
14
79
60
55


In [28]:
# Held-out test video: hang em high (372).
test_set = [372]

In [74]:
# Training videos: kill bill vol 2 (98), stir crazy (192), erin brockovich (352),
# man of la mancha (432), ordinary people (459), stage fright (517).
train_set = [98, 192, 352, 432, 459, 517]

# Sanity-check the split. The three sets were defined in separate cells, so
# verify that no video appears in two splits and that together they cover
# every id in video_ids.
assert set(train_set).isdisjoint(val_set), 'train/val splits overlap'
assert set(train_set).isdisjoint(test_set), 'train/test splits overlap'
assert set(val_set).isdisjoint(test_set), 'val/test splits overlap'
assert set(train_set) | set(val_set) | set(test_set) == set(video_ids), (
    'train/val/test splits must cover every id in video_ids')

In [33]:
def compute_statistics(query_intrvllists, ground_truth_intrvllists):
    """Score query intervals against ground-truth intervals.

    Args:
        query_intrvllists: mapping of video id -> interval list of predictions.
        ground_truth_intrvllists: mapping of video id -> interval list of labels.

    Returns:
        (precision, recall, precision_per_item, recall_per_item).
        Time-based scores compare coalesced durations. Per-item scores count
        segments that overlap at least one segment on the other side, for
        videos present in both mappings. When a side has zero total time, the
        metrics that divide by it are reported as 1.0.
    """
    query_time = sum(
        ilist.coalesce().get_total_time() for ilist in query_intrvllists.values())
    query_count = sum(ilist.size() for ilist in query_intrvllists.values())
    truth_time = sum(
        ilist.coalesce().get_total_time() for ilist in ground_truth_intrvllists.values())
    truth_count = sum(ilist.size() for ilist in ground_truth_intrvllists.values())

    overlap_time = 0
    query_hits = 0
    truth_hits = 0
    for video, query_list in query_intrvllists.items():
        if video not in ground_truth_intrvllists:
            continue
        truth_list = ground_truth_intrvllists[video]
        overlap_time += query_list.overlaps(truth_list).coalesce().get_total_time()
        query_hits += query_list.filter_against(truth_list, predicate=overlaps()).size()
        truth_hits += truth_list.filter_against(query_list, predicate=overlaps()).size()

    precision, precision_per_item = (
        (1.0, 1.0) if query_time == 0
        else (overlap_time / query_time, query_hits / query_count))
    recall, recall_per_item = (
        (1.0, 1.0) if truth_time == 0
        else (overlap_time / truth_time, truth_hits / truth_count))

    return precision, recall, precision_per_item, recall_per_item

def f1_score(precision, recall):
    """Return the harmonic mean of precision and recall, or 0.0 if both are zero.

    The plain formula 2PR / (P + R) raises ZeroDivisionError when P == R == 0,
    e.g. when the query and ground truth do not overlap at all.
    """
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)


def print_statistics(query_intrvllists, ground_truth_intrvllists):
    """Print time-based and per-item precision, recall and F1 for a query.

    Args:
        query_intrvllists: mapping of video id -> predicted interval list.
        ground_truth_intrvllists: mapping of video id -> ground-truth interval list.
    """
    precision, recall, precision_per_item, recall_per_item = compute_statistics(
        query_intrvllists, ground_truth_intrvllists)

    print("Precision: ", precision)
    print("Recall: ", recall)
    print("F1: ", f1_score(precision, recall))
    print("Precision Per Item: ", precision_per_item)
    print("Recall Per Item: ", recall_per_item)
    print("F1 Per Item: ", f1_score(precision_per_item, recall_per_item))

In [34]:
# Score the in-bounds detections against ground truth on the validation videos.
val_query = {v: conversations_in_gt_bounds.get_intervallist(v) for v in val_set}
val_ground_truth = {v: conversations_gt.get_intervallist(v) for v in val_set}
print_statistics(val_query, val_ground_truth)

Precision:  0.7357072678509785
Recall:  0.750382558060803
F1:  0.7429724529541509
Precision Per Item:  0.6486486486486487
Recall Per Item:  0.8888888888888888
F1 Per Item:  0.75


In [35]:
# Score the in-bounds detections against ground truth on the held-out test video(s).
test_query = {v: conversations_in_gt_bounds.get_intervallist(v) for v in test_set}
test_ground_truth = {v: conversations_gt.get_intervallist(v) for v in test_set}
print_statistics(test_query, test_ground_truth)

Precision:  0.6773267922432545
Recall:  0.764000365425525
F1:  0.7180575402788085
Precision Per Item:  0.6966292134831461
Recall Per Item:  0.76
F1 Per Item:  0.7269361308238198


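# Split into segments
Cut each video into fixed 10-second windows, and keep only the windows that fall inside the span covered by the ground-truth labels.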

In [36]:
def get_fps_map(vids):
    """Return a dict mapping each Video id in ``vids`` to that video's fps."""
    fps_by_id = {}
    for vid in Video.objects.filter(id__in=vids):
        fps_by_id[vid.id] = vid.fps
    return fps_by_id

def frame_second_conversion(c, mode='f2s'):
    """Convert interval endpoints between frame numbers and seconds.

    Args:
        c: VideoIntervalCollection whose intervals are in the source unit.
        mode: 'f2s' (frames -> seconds, divide by fps) or
            's2f' (seconds -> frames, multiply by fps).

    Returns:
        A new VideoIntervalCollection. Each interval is ``intrvl.copy()`` with
        start/end converted using the video's fps. Converted values are
        truncated with ``int()``, so f2s loses sub-second precision.

    Raises:
        ValueError: if ``mode`` is neither 'f2s' nor 's2f'.
    """
    converters = {
        'f2s': lambda value, fps: int(value / fps),
        's2f': lambda value, fps: int(value * fps),
    }
    # Validate before any DB work. An unknown mode used to leave the
    # converter unbound and fail later with UnboundLocalError.
    if mode not in converters:
        raise ValueError("mode must be 'f2s' or 's2f', got {!r}".format(mode))
    convert = converters[mode]

    def make_map_fn(fps):
        def map_fn(intrvl):
            i2 = intrvl.copy()
            i2.start = convert(intrvl.start, fps)
            i2.end = convert(intrvl.end, fps)
            return i2
        return map_fn

    all_intervals = c.get_allintervals()
    fps_map = get_fps_map(set(all_intervals.keys()))
    return VideoIntervalCollection({
        vid: intervals.map(make_map_fn(fps_map[vid]))
        for vid, intervals in all_intervals.items()
    })

def frame_to_second_collection(c):
    """Return a copy of ``c`` with interval endpoints converted from frames to seconds."""
    return frame_second_conversion(c, mode='f2s')


def second_to_frame_collection(c):
    """Return a copy of ``c`` with interval endpoints converted from seconds to frames."""
    return frame_second_conversion(c, mode='s2f')

In [43]:
# Window length in seconds. Windows are centred on multiples of `interval`.
interval = 10
segs_dict = {}
for video_id in video_ids:
    video = Video.objects.get(id=video_id)
    duration_s = int(video.num_frames / video.fps)
    segs_dict[video_id] = IntervalList([
        (i - interval / 2, i + interval / 2, 0)
        for i in range(0, duration_s, interval)
    ])

all_segments = VideoIntervalCollection(segs_dict)
# Ground truth is stored in frames. Convert it once and reuse it for both filters.
conversations_gt_seconds = frame_to_second_collection(conversations_gt)

# Keep windows inside the ground-truth span (same rule as
# conversations_in_gt_bounds). Both filters must start from the *unfiltered*
# windows. The old cell referenced `segments` inside its own definition,
# which fails on a fresh kernel or silently reuses a stale value.
segments = all_segments.filter_against(
    conversations_gt_seconds,
    lambda seg, gt: seg.end < gt.end
).filter_against(
    all_segments.filter_against(
        conversations_gt_seconds,
        lambda seg, gt: seg.start > gt.start
    ),
    equal()
)

# Export Images

In [48]:
# Segment boundaries converted back to frame numbers, used for frame retrieval below.
segments_frames = second_to_frame_collection(c=segments)

In [51]:
from PIL import Image
from tqdm import tqdm

In [54]:
import hwang, storehouse

In [78]:
# Decode the middle frame of every segment and save it as
# images/<video_id>/<segment index:04d>.jpg. The index is intended to line up
# with the per-video row index written by the label export below.
bucket = os.environ.get('BUCKET')
if not bucket:
    raise RuntimeError(
        'Set the BUCKET environment variable to the GCS bucket holding the videos.')
# The storage backend does not depend on the video, so build it once
# instead of once per loop iteration.
backend = storehouse.StorageBackend.make_from_config(
    storehouse.StorageConfig.make_gcs_config(bucket))
image_root = '/app/data/conversation_export/images'

for video_id in tqdm(video_ids):
    video = Video.objects.get(id=video_id)
    dec = hwang.Decoder(storehouse.RandomReadFile(backend, video.path))

    # Middle frame of each segment (segments_frames is in frame units).
    frame_nums = [
        int((intrvl.start + intrvl.end) / 2)
        for intrvl in segments_frames.get_intervallist(video_id).get_intervals()
    ]
    frames = dec.retrieve(frame_nums)

    video_dir = os.path.join(image_root, str(video_id))
    os.makedirs(video_dir, exist_ok=True)
    for i, frame in enumerate(frames):
        Image.fromarray(frame).save(os.path.join(video_dir, '{:04d}.jpg'.format(i)))


  0%|                                                                                                                                                                                                                 | 0/11 [00:00<?, ?it/s][A
  9%|██████████████████▎                                                                                                                                                                                      | 1/11 [01:23<13:50, 83.08s/it][A
 18%|████████████████████████████████████▌                                                                                                                                                                    | 2/11 [03:31<14:31, 96.79s/it][A
 27%|██████████████████████████████████████████████████████▌                                                                                                                                                 | 3/11 [11:00<26:59, 202.42s/it][A
 36%|██████████████████████████████

# Export Labels

In [64]:
def size(vic):
    """Return a dict mapping each video id in ``vic`` to its number of intervals."""
    counts = {}
    for vid in vic.get_allintervals():
        counts[vid] = vic.get_intervallist(vid).size()
    return counts

In [58]:
def _with_negative_label(intrvl):
    # Same bounds, payload 0 (the "not a conversation" label used in the export).
    return (intrvl.start, intrvl.end, 0)


segments_all_negative = segments.map(_with_negative_label)

In [66]:
# Inspect which video ids have ground-truth conversation labels.
conversations_gt.get_allintervals().keys()

dict_keys([192, 432, 98, 372, 517, 352, 216, 374, 459, 61, 15])

In [67]:
# Inspect which video ids still have segments after filtering to the GT span.
segments.get_allintervals().keys()

dict_keys([192, 352, 98, 432, 372, 517, 374, 216, 459, 61, 15])

In [71]:
# Spot-check the raw ground-truth intervals (frame units) for video 15.
conversations_gt.get_intervallist(15)

[<Interval start:2578 end:4100 payload:642>, <Interval start:4244 end:4826 payload:643>, <Interval start:5098 end:5828 payload:644>, <Interval start:7757 end:9546 payload:645>, <Interval start:9602 end:10300 payload:646>, <Interval start:12393 end:12943 payload:647>, <Interval start:13088 end:13884 payload:648>, <Interval start:14146 end:15212 payload:649>, <Interval start:15427 end:16116 payload:650>, <Interval start:18040 end:19198 payload:651>, <Interval start:20801 end:23368 payload:652>, <Interval start:24572 end:26185 payload:653>, <Interval start:26735 end:28753 payload:654>, <Interval start:29462 end:30873 payload:655>, <Interval start:31768 end:34618 payload:656>]

In [72]:
conversations_gt_seconds = frame_to_second_collection(conversations_gt)

# Label 1: segments that overlap at least one ground-truth conversation.
conversation_segments = (
    segments
    .filter_against(conversations_gt_seconds, predicate=overlaps())
    .map(lambda intrvl: (intrvl.start, intrvl.end, 1))
)

# Label 0 for the rest: drop the positive segments from the all-negative copy,
# then add the positives back, so each segment should carry a single label
# (assuming minus removes exactly the matching segments).
conversation_labels = (
    segments_all_negative
    .minus(conversation_segments)
    .set_union(conversation_segments)
)

print(size(conversation_segments))
print(size(conversation_labels))

{192: 386, 352: 195, 98: 483, 432: 521, 372: 358, 517: 482, 374: 48, 216: 99, 459: 578, 61: 93, 15: 94}
{192: 632, 352: 279, 98: 706, 432: 713, 372: 657, 517: 640, 374: 75, 216: 121, 459: 703, 61: 222, 15: 132}


In [77]:
# Write one label file per split. Each line is "<video_id> <index> <label>",
# where <index> enumerates the video's labelled segments and <label> is 1 for
# conversation, 0 otherwise. Images were numbered by enumerating
# segments_frames, so the per-video counts must match for the two exports
# to line up.
label_dir = '/app/data/conversation_export/data'
os.makedirs(label_dir, exist_ok=True)

splits = {'train': train_set, 'val': val_set, 'test': test_set}

# Validate alignment for all videos before writing anything, so a mismatch
# does not leave partially written label files behind.
for split_video_ids in splits.values():
    for video_id in split_video_ids:
        n_labels = conversation_labels.get_intervallist(video_id).size()
        n_images = segments_frames.get_intervallist(video_id).size()
        assert n_labels == n_images, (
            'video {}: {} labels but {} exported images'.format(
                video_id, n_labels, n_images))

for split_name, split_video_ids in splits.items():
    with open(os.path.join(label_dir, '{}.txt'.format(split_name)), 'w') as f:
        for video_id in split_video_ids:
            intervals = conversation_labels.get_intervallist(video_id).get_intervals()
            for i, intrvl in enumerate(intervals):
                f.write('{} {} {}\n'.format(video_id, i, intrvl.payload))