In [None]:
import math
import os
import time

import cv2
import ffmpeg
import imutils
import matplotlib.pyplot as plt
import numpy as np
from imutils.video import count_frames
from IPython.display import Audio

In [None]:
# ROOT_DIR is the parent of the current working directory; input data is
# expected under ROOT_DIR/data.
ROOT_DIR = os.path.split(os.getcwd())[0]
DATA_FOLDER = os.path.join(ROOT_DIR, "data")

In [None]:
# Paths of the left and right synchronized clips inside DATA_FOLDER.
video_keparoicam_L, video_keparoicam_R = (
    os.path.join(DATA_FOLDER, "keparoicam_clipL_synchronized.mp4"),
    os.path.join(DATA_FOLDER, "keparoicam_clipR_synchronized.mp4"),
)

In [None]:
# Open both synchronized clips. cv2.VideoCapture does not raise when a file is
# missing or cannot be decoded -- it returns a capture whose read() always
# fails -- so check isOpened() here to fail fast with a clear message instead
# of crashing later inside the stitcher on a None frame.
video_left_capture = cv2.VideoCapture(video_keparoicam_L)
video_right_capture = cv2.VideoCapture(video_keparoicam_R)

if not video_left_capture.isOpened():
    raise OSError(f"Could not open left video: {video_keparoicam_L}")
if not video_right_capture.isOpened():
    raise OSError(f"Could not open right video: {video_keparoicam_R}")

In [None]:
# Number of CUDA devices OpenCV can use (0 if this cv2 build has no CUDA
# support or no device is found) -- relevant to try_use_gpu below.
cuda_device_count = cv2.cuda.getCudaEnabledDeviceCount()
print(cuda_device_count)

## Panorama stitching: estimate camera parameters and seam masks once, then reuse them for every frame

In [None]:
from stitching import Stitcher
from stitching.images import Images

class VideoStitcher(Stitcher):
    """Stitcher for video frame pairs that reuses geometry across frames.

    Camera parameters are estimated only on the first call to ``stitch``
    (tracked by ``cameras_registered``), and seam masks are computed on the
    first call and cached in ``seam_masks``. Every later call only warps,
    crops, exposure-compensates and blends the new frames. This assumes a
    fixed camera rig: the same cameras, in the same order, on every call.
    """

    def initialize_stitcher(self, **kwargs):
        """Configure the underlying stitcher and reset all cached per-video state.

        Calling this again (e.g. with new settings) discards previously
        registered cameras and seam masks, so the next ``stitch`` call
        re-estimates them.
        """
        super().initialize_stitcher(**kwargs)
        # Camera parameters, one entry per input image, set on the first stitch().
        self.cameras = None
        self.cameras_registered = False
        # NOTE(review): not read anywhere in this class.
        self.corners = None
        # Final-resolution seam masks cached after the first stitch().
        self.seam_masks = None

    @staticmethod
    def _log_elapsed(label, since, verbose=True):
        """Print the time elapsed since ``since`` (when verbose) and return now."""
        now = time.perf_counter()
        if verbose:
            print(f"{label}: {now - since} seconds")
        return now

    @staticmethod
    def _show_bgr_image(img, title):
        """Display an OpenCV (BGR channel order) image with matplotlib."""
        fig, ax = plt.subplots()
        # matplotlib expects RGB; frames read with OpenCV are BGR.
        ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        ax.set_title(title)
        ax.axis("off")
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across many frames

    @staticmethod
    def _recording(iterable, sink):
        """Yield items from ``iterable`` unchanged while appending each to ``sink``."""
        for item in iterable:
            sink.append(item)
            yield item

    def _register_cameras(self, feature_masks, verbose=True):
        """Estimate camera parameters from ``self.images`` and cache them."""
        imgs = self.resize_medium_resolution()
        features = self.find_features(imgs, feature_masks)
        matches = self.match_features(features)
        # NOTE(review): subset() presumably may drop images that could not be
        # matched. Later frames skip this step, so a dropped image would leave
        # cameras and images out of sync -- confirm all inputs always survive.
        imgs, features, matches = self.subset(imgs, features, matches)
        cameras = self.estimate_camera_parameters(features, matches)
        cameras = self.refine_camera_parameters(features, matches, cameras)
        cameras = self.perform_wave_correction(cameras)
        self.estimate_scale(cameras)
        # Sequence of camera parameter objects, one per image
        # (presumably cv2.detail.CameraParams).
        self.cameras = cameras
        self.cameras_registered = True

        if verbose:
            for camera in self.cameras:
                print("Camera")
                print(camera.K())
                print(camera.aspect)
                print(camera.focal)
                print(camera.R)
                print(camera.t)
                print(camera.ppx)
                print(camera.ppy)

    def stitch(self, images, feature_masks=None, verbose=True):
        """Stitch one set of synchronized frames into a panorama.

        Parameters
        ----------
        images : list
            Frames to stitch, e.g. ``[left_frame, right_frame]``. They must come
            from the same cameras, in the same order, on every call, because
            camera parameters and seam masks from the first call are reused.
        feature_masks : list, optional
            Masks forwarded to ``find_features``. They are only used on the call
            that registers the cameras. Defaults to an empty list.
        verbose : bool, default True
            Print per-stage timings and camera parameters, and plot the warped
            final-resolution images.

        Returns
        -------
        The stitched panorama returned by ``create_final_panorama``.
        """
        if feature_masks is None:
            feature_masks = []

        t0 = time.perf_counter()
        self.images = Images.of(
            images, self.medium_megapix, self.low_megapix, self.final_megapix
        )
        t = self._log_elapsed("Init images took", t0, verbose)

        if not self.cameras_registered:
            self._register_cameras(feature_masks, verbose)
        t = self._log_elapsed("Camera registration took", t, verbose)

        # --- Low resolution: crop, exposure and seam estimation ---
        imgs = self.resize_low_resolution()
        t = self._log_elapsed("Resize took", t, verbose)

        imgs, masks, corners, sizes = self.warp_low_resolution(imgs, self.cameras)
        t = self._log_elapsed("Warp took", t, verbose)

        self.prepare_cropper(imgs, masks, corners, sizes)
        t = self._log_elapsed("Prepare cropper took", t, verbose)

        imgs, masks, corners, sizes = self.crop_low_resolution(
            imgs, masks, corners, sizes
        )
        t = self._log_elapsed("Cropping took", t, verbose)

        self.estimate_exposure_errors(corners, imgs, masks)
        t = self._log_elapsed("Exposure error estimation took", t, verbose)

        seam_masks = None
        if self.seam_masks is None:
            seam_masks = self.find_seam_masks(imgs, corners, masks)
        t = self._log_elapsed("Seam masks took", t, verbose)

        # --- Final resolution: warp, crop, compensate, blend ---
        imgs = self.resize_final_resolution()
        t = self._log_elapsed("Resize final took", t, verbose)

        imgs, masks, corners, sizes = self.warp_final_resolution(imgs, self.cameras)
        t = self._log_elapsed("Warp final took", t, verbose)

        if verbose:
            # The warped images may be a one-shot iterator; materialize them so
            # plotting does not leave nothing for cropping and blending.
            imgs = list(imgs)
            for img in imgs:
                self._show_bgr_image(img, "Image after warping")
            t = time.perf_counter()  # keep plotting out of the next stage's timing

        imgs, masks, corners, sizes = self.crop_final_resolution(
            imgs, masks, corners, sizes
        )
        t = self._log_elapsed("Cropping final took", t, verbose)

        self.set_masks(masks)
        t = self._log_elapsed("Setting masks took", t, verbose)

        imgs = self.compensate_exposure_errors(corners, imgs)
        t = self._log_elapsed("Exposure compensation took", t, verbose)

        if self.seam_masks is None:
            resized_seam_masks = self.resize_seam_masks(seam_masks)
            if iter(resized_seam_masks) is resized_seam_masks:
                # A one-shot iterator (e.g. a generator) cannot be cached
                # directly: later frames would receive it already exhausted.
                # Materializing it up front could also be unsafe if it reads the
                # current frame's masks lazily in step with blending. Instead,
                # record each mask as blend_images() consumes it.
                seam_masks_to_cache = []
                blend_seam_masks = self._recording(
                    resized_seam_masks, seam_masks_to_cache
                )
            else:
                seam_masks_to_cache = resized_seam_masks
                blend_seam_masks = resized_seam_masks
        else:
            seam_masks_to_cache = None
            blend_seam_masks = self.seam_masks
        t = self._log_elapsed("Resizing seam masks took", t, verbose)

        self.initialize_composition(corners, sizes)
        t = self._log_elapsed("Init composition took", t, verbose)

        self.blend_images(imgs, blend_seam_masks, corners)
        if seam_masks_to_cache is not None:
            # Only cache once blending succeeded, so a failure retries next frame.
            self.seam_masks = seam_masks_to_cache
        t = self._log_elapsed("Blending images took", t, verbose)

        panorama = self.create_final_panorama()
        t = self._log_elapsed("Creating final pano took", t, verbose)

        if verbose:
            print(f"Total time: {t - t0} seconds")

        return panorama

In [None]:
import time
from tqdm.notebook import tqdm

stitcher = VideoStitcher()
stitcher.initialize_stitcher(
    blend_strength=20,
    try_use_gpu=True,
    warper_type="mercator",
    medium_megapix=0.3,
    final_megapix=0.7,
)
# Alternative configuration tried earlier (SIFT features, no blending):
# stitcher.initialize_stitcher(
#     try_use_gpu=True, medium_megapix=0.4, final_megapix=0.8,
#     detector="sift", nfeatures=400, warper_type="mercator", confidence_threshold=1,
#     blender_type="no"
# )

MAX_FRAMES = 1000  # upper bound on the number of frame pairs stitched in this run

for frame_idx in tqdm(range(MAX_FRAMES)):
    ret_left, left_frame = video_left_capture.read()
    ret_right, right_frame = video_right_capture.read()
    # read() returns (False, None) when no frame could be grabbed (end of clip
    # or a decode error); stop instead of passing None into the stitcher.
    if not (ret_left and ret_right):
        print(f"Stopping after {frame_idx} frame pairs: could not read both videos.")
        break

    start = time.perf_counter()
    panorama = stitcher.stitch(images=[left_frame, right_frame])
    end = time.perf_counter()

    print(f"Panorama stitching took: {end - start} seconds, image size {panorama.shape}")

    # OpenCV frames are BGR; matplotlib expects RGB.
    fig, ax = plt.subplots()
    ax.imshow(cv2.cvtColor(panorama, cv2.COLOR_BGR2RGB))
    ax.set_title(f"Panorama, frame {frame_idx}")
    ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating hundreds of open figures