In [None]:
import os
import math
import numpy as np
import ffmpeg
import cv2
import imutils
from imutils.video import count_frames
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

In [None]:
# Assumption: the notebook is launched from a direct subdirectory of the
# project root, so the root is the parent of the current working directory.
# All input and output media are read from / written to <root>/data.
ROOT_DIR = os.path.split(os.getcwd())[0]
DATA_FOLDER = os.path.join(ROOT_DIR, "data")

In [None]:
# Input clips (one per camera), both located in DATA_FOLDER.
video_waltter_path, video_vikture_path = (
    os.path.join(DATA_FOLDER, file_name)
    for file_name in (
        "example_waltter_synchronized.mov",
        "example_vikture_late_15s_synchronized.mov",
    )
)

In [None]:
# "left" is the vikture clip, "right" is the waltter clip.
video_left_capture = cv2.VideoCapture(video_vikture_path)
video_right_capture = cv2.VideoCapture(video_waltter_path)

# cv2.VideoCapture does not raise on a missing/unreadable file; it returns a
# capture that reports 0 frames, which would make every later cell silently
# do nothing. Fail fast instead.
if not video_left_capture.isOpened():
    raise IOError(f"Could not open left video: {video_vikture_path}")
if not video_right_capture.isOpened():
    raise IOError(f"Could not open right video: {video_waltter_path}")

In [None]:
def _capture_properties(capture):
    """Return (frame_count, width, height, fps) as reported by an OpenCV capture.

    Frame count, width and height are truncated to int; fps is returned as the
    float that ``capture.get`` reports.
    """
    return (
        int(capture.get(cv2.CAP_PROP_FRAME_COUNT)),
        int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)),
        int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)),
        capture.get(cv2.CAP_PROP_FPS),
    )


left_n_frames, left_width, left_height, left_fps = _capture_properties(video_left_capture)
right_n_frames, right_width, right_height, right_fps = _capture_properties(video_right_capture)

# Only frame indices that exist in both clips can be compared.
total_frames = min(left_n_frames, right_n_frames)

for reported_value in (
    left_n_frames, right_n_frames, total_frames,
    left_width, left_height, left_fps,
    right_width, right_height, right_fps,
):
    print(reported_value)

In [None]:
# Output video settings: 1920x1080 frames at 60 fps, Motion-JPEG codec.
final_width, final_height = 1920, 1080
final_fps = 60.0
fourcc = cv2.VideoWriter_fourcc(*"MJPG")

In [None]:
video_path = os.path.join(DATA_FOLDER, "example_optical_flow_video.avi")
video_output = cv2.VideoWriter(video_path, fourcc, final_fps, (final_width, final_height))

# cv2.VideoWriter fails silently (unsupported codec, unwritable path); without
# this check the whole processing loop would run and produce no file.
if not video_output.isOpened():
    raise IOError(f"Could not open video writer for: {video_path}")

In [None]:
def equalize_histogram(rgb_image):
    """Histogram-equalize every channel of an 8-bit image independently.

    Despite the parameter name, channel order does not matter. Each channel is
    equalized on its own and the channels are merged back in their original
    order, so OpenCV BGR frames come back as BGR. Works for any number of
    channels.

    Note: equalizing colour channels independently can shift hues. The
    LAB-based ``apply_clahe`` adjusts only the lightness channel.

    Args:
        rgb_image: 8-bit multi-channel image (``cv2.equalizeHist`` requires
            8-bit single-channel input per channel).

    Returns:
        A new image with the same channel layout, each channel equalized.
    """
    equalized_channels = [cv2.equalizeHist(channel) for channel in cv2.split(rgb_image)]
    return cv2.merge(equalized_channels)

In [None]:
# Shared CLAHE operator used by apply_clahe.
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))

def apply_clahe(image, is_bgr=False):
    """Apply CLAHE to the lightness (L) channel of a colour image in LAB space.

    Only L is modified, so hue/saturation are left as-is. The output has the
    same channel order as the input.

    Args:
        image: 8-bit 3-channel colour image.
        is_bgr: set True for OpenCV-native BGR frames (e.g. from
            ``cv2.VideoCapture.read``). The default False treats the input as
            RGB. Passing BGR data with the default computes L with R and B
            swapped.

    Returns:
        A new image with CLAHE applied to its lightness.
    """
    if is_bgr:
        to_lab, from_lab = cv2.COLOR_BGR2LAB, cv2.COLOR_LAB2BGR
    else:
        # Direct LAB -> RGB; equivalent to the former LAB -> BGR -> RGB pair.
        to_lab, from_lab = cv2.COLOR_RGB2LAB, cv2.COLOR_LAB2RGB

    image_lab = cv2.cvtColor(image, to_lab)
    image_lab[..., 0] = clahe.apply(image_lab[..., 0])

    return cv2.cvtColor(image_lab, from_lab)

In [None]:
def preprocess_image(image, equalize_hist=True, clahe=False):
    """Optionally histogram-equalize and/or CLAHE-enhance a single image.

    Args:
        image: image to process; passed through unchanged if both flags are False.
        equalize_hist: apply per-channel histogram equalization first.
        clahe: then apply CLAHE via ``apply_clahe``. This flag only selects the
            step; the CLAHE operator itself lives at module level.

    Returns:
        The processed image, or ``image`` itself when no step is enabled.
    """
    if equalize_hist:
        image = equalize_histogram(image)
    if clahe:
        image = apply_clahe(image)

    return image

def preprocess_images(images, equalize_hist=True, clahe=False):
    """Apply ``preprocess_image`` to every image, forwarding the same options.

    Args:
        images: iterable of images.
        equalize_hist, clahe: forwarded to ``preprocess_image``. The defaults
            match ``preprocess_image``'s defaults.

    Returns:
        A new list of processed images, in input order.
    """
    return [
        preprocess_image(image, equalize_hist=equalize_hist, clahe=clahe)
        for image in images
    ]

In [None]:
def calculate_optical_flow(frame1, frame2, blur_ksize=(21, 21), diff_threshold=25,
                           dilate_iterations=2):
    """Score how much changed between two frames, via frame differencing.

    Despite the name, this is not dense optical flow. Steps:
    1. Both frames are converted to grayscale and Gaussian-blurred.
    2. The absolute difference is thresholded.
    3. The resulting mask is dilated.
    The score is the number of "changed" pixels.

    Args:
        frame1, frame2: BGR frames of the same size (converted with
            ``COLOR_BGR2GRAY``).
        blur_ksize: Gaussian kernel size (odd width, odd height). Larger values
            suppress noise and small motions.
        diff_threshold: minimum absolute gray-level difference (0-255) for a
            pixel to count as changed.
        dilate_iterations: number of dilation passes, which grow changed
            regions before counting.

    Returns:
        int: number of pixels marked as changed.
    """
    def _blurred_gray(frame):
        gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        return cv2.GaussianBlur(gray, blur_ksize, 0)

    frame_delta = cv2.absdiff(_blurred_gray(frame1), _blurred_gray(frame2))

    _, thresh = cv2.threshold(frame_delta, diff_threshold, 255, cv2.THRESH_BINARY)
    thresh = cv2.dilate(thresh, None, iterations=dilate_iterations)

    # Count directly on the 2-D mask; no need to flatten (which copies).
    return int(np.count_nonzero(thresh == 255))
    
def calculate_optical_flow_metric(frames):
    """Sum frame-difference motion scores across a window of frames.

    The window is cut in half and the k-th frame of the first half is compared
    with the k-th frame of the second half. Each compared pair is therefore
    half a window apart. With an odd number of frames, the final frame is left
    unpaired. With fewer than two frames the result is 0.
    """
    half = len(frames) // 2
    first_half, second_half = frames[:half], frames[half:]
    return sum(
        calculate_optical_flow(earlier, later)
        for earlier, later in zip(first_half, second_half)
    )
    
    
def write_frames(output_handle, frames):
    """Write each frame, in order, through ``output_handle.write``.

    Returns None; any return values from ``write`` are discarded.
    """
    write = output_handle.write
    for single_frame in frames:
        write(single_frame)

In [None]:
def _read_window(capture, n_frames):
    """Read up to n_frames consecutive frames from the capture's current position.

    Stops at the first failed read and returns the frames read so far.
    """
    frames = []
    for _ in range(n_frames):
        ok, frame = capture.read()
        if not ok:
            break
        frames.append(frame)
    return frames


def _fit_to_output(frame, width, height):
    """Return frame resized to (width, height), or unchanged if it already matches.

    cv2.VideoWriter drops frames whose size differs from the size it was opened
    with, so every written frame must match it. Aspect ratio is not preserved.
    """
    if frame.shape[1] != width or frame.shape[0] != height:
        return cv2.resize(frame, (width, height))
    return frame


# NOTE(review): the window length comes from the *output* fps. If the source
# clips are not 60 fps, a window is not half a second of footage. Confirm.
optical_flow_window_length = int(math.floor(final_fps / 2))
n_windows = math.ceil(total_frames/optical_flow_window_length)

# Rewind once and then read sequentially. Seeking with CAP_PROP_POS_FRAMES
# before every frame is slow and can land on the wrong frame with some codecs.
video_left_capture.set(cv2.CAP_PROP_POS_FRAMES, 0)
video_right_capture.set(cv2.CAP_PROP_POS_FRAMES, 0)

short_windows = 0
try:
    for i in tqdm(range(n_windows)):
        window_start = i * optical_flow_window_length
        # The last window is clamped so we never read past the shorter clip.
        window_size = min(optical_flow_window_length, total_frames - window_start)

        left_frames = _read_window(video_left_capture, window_size)
        right_frames = _read_window(video_right_capture, window_size)
        if len(left_frames) < window_size or len(right_frames) < window_size:
            short_windows += 1

        left_optical_flow = calculate_optical_flow_metric(left_frames)
        right_optical_flow = calculate_optical_flow_metric(right_frames)

        # Keep the clip with more motion in this window; ties go to the right clip.
        chosen_frames = left_frames if left_optical_flow > right_optical_flow else right_frames

        images_processed = [
            _fit_to_output(image, final_width, final_height)
            for image in preprocess_images(chosen_frames)
        ]
        write_frames(video_output, images_processed)
finally:
    # Release even if processing fails, so the output file is finalized.
    video_left_capture.release()
    video_right_capture.release()
    video_output.release()

if short_windows:
    print(f"Warning: {short_windows} window(s) had fewer frames than expected")