In [1]:
import cv2
import os
import numpy as np
import time
import torch
from matplotlib import colors
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [2]:
# Mount Google Drive so that the paths under /content/drive/MyDrive used by
# the rest of this notebook (project code, checkpoints, videos) are reachable.
from google.colab import drive
# force_remount=True remounts the drive even if it is already mounted.
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
import sys

# Directory appended to the import path; presumably where the project's
# `models` and `utils` modules live (they are imported right after this cell).
PROJECT_DIR = '/content/drive/MyDrive/siamese-registration'
# Guarded so that re-running this cell does not keep appending duplicates.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

In [4]:
from models import *
from utils import get_transformation_matrix

In [19]:
# Checkpoint file, relative to the models output directory on Drive.
checkpoint = "01/model-9.pt"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): arguments are presumably (input channels, number of outputs);
# the inference cell unpacks 7 values from the model output -- confirm against
# the definition in `models`.
model = siamese_resnet18(1, 7)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
state_dict = torch.load(os.path.join("/content/drive/MyDrive/outputs/models", checkpoint), map_location=device)
model.load_state_dict(state_dict)
# Move the model to the selected device. A hard-coded .cuda() would raise on
# a CPU-only runtime even though `device` already falls back to the CPU.
model.to(device)
# Inference mode for layers such as batch norm and dropout.
model.eval()

print(f"Running on {device}")

Running on cuda


In [17]:
# Read the video frame by frame and register every frame against the
# previously registered (warped) frame using the siamese model.
#   image_list        -- original grayscale frames
#   image_list_warped -- the same frames after the predicted warp was undone
video_path = "/content/drive/MyDrive/data/Study_02_00008_02_R.avi"
cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    # Fail here instead of silently leaving the frame lists empty, which
    # would only surface later as an IndexError in the video-writing cell.
    raise IOError(f"Could not open video file: {video_path}")

# Number of frames to process: one less than the frame count reported by the
# container (presumably a guard against an over-reported count -- note that
# this skips the last frame when the reported count is exact).
video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - 1
count = 0
image_list = []
image_list_warped = []
time_start = time.time()
try:
    # no_grad: this is pure inference, so no autograd graph should be built.
    with tqdm(total=video_length, unit=' frames') as pbar, torch.no_grad():
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                # End of stream or read error. `frame` is None here, so it must
                # not be passed to cvtColor, and retrying would loop forever
                # because `count` would never advance.
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

            if not image_list:
                # The first frame is the reference and is stored unchanged.
                image_list.append(frame)
                image_list_warped.append(frame)
            else:
                # Register against the last *warped* frame, so every frame is
                # aligned to the already-registered sequence.
                frame0 = image_list_warped[-1]
                frame1 = frame

                # to_tensor turns an (H, W) uint8 image into a (1, H, W) float
                # tensor; unsqueeze at dim 1 gives the (1, 1, H, W) batch.
                img0 = torch.unsqueeze(transforms.functional.to_tensor(frame0), 1).to(device)
                img1 = torch.unsqueeze(transforms.functional.to_tensor(frame1), 1).to(device)

                outputs = model(img0, img1)

                rows, cols = frame0.shape
                center = (cols//2, rows//2)
                # The model emits 7 transformation parameters for the pair.
                tx, ty, sx, sy, shx, shy, q = outputs.detach().cpu().numpy().reshape(-1).tolist()
                matrix = get_transformation_matrix(center, tx, ty, sx, sy, shx, shy, q)
                # Keep the first six coefficients as the 2x3 affine matrix
                # OpenCV expects (assumes `matrix` is a row-major 3x3
                # homogeneous matrix -- TODO confirm in utils).
                matrix_opencv = np.float32(matrix.flatten()[:6].reshape(2, 3))
                # Apply the inverse of the predicted transform to frame1.
                inverse_matrix = cv2.invertAffineTransform(matrix_opencv)

                frame1_warped = cv2.warpAffine(frame1, inverse_matrix, (cols, rows))

                image_list.append(frame)
                image_list_warped.append(frame1_warped)

            count = count + 1
            pbar.update(1)
            # Stop once the expected number of frames has been processed.
            if (count > (video_length-1)):
                break
finally:
    # Always release the capture and record the end time, even on errors.
    cap.release()
    time_end = time.time()

Number of frames:  253


100%|██████████| 253/253 [00:14<00:00, 17.43 frames/s]


In [14]:
# Write a side-by-side comparison video: original frame on the left, the
# registered (warped) frame on the right.
output_path = '/content/drive/MyDrive/outputs/video/Study_02_00008_02_R.mp4'
# VideoWriter does not create missing directories; it would just fail to open.
os.makedirs(os.path.dirname(output_path), exist_ok=True)

# The first pair determines the output frame size.
img = np.concatenate((image_list[0], image_list_warped[0]), axis = 1)
frame_height, frame_width = img.shape[:2]

# 'mp4v' matches the .mp4 container (XVID is an AVI codec tag). The frame
# size is passed as (width, height); the final False selects grayscale output.
out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), 15, (frame_width, frame_height), False)
if not out.isOpened():
    # Without this check a failed writer silently drops every frame.
    raise IOError(f"Could not open video writer for: {output_path}")

try:
    for image, image_warped in zip(image_list, image_list_warped):
        img = np.concatenate((image, image_warped), axis = 1)
        # Stretch each combined frame to the full 0-255 range as 8-bit data.
        img_u8 = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        out.write(img_u8)
finally:
    # Finalise the file even if writing a frame raised.
    out.release()

In [9]:
# Display the shape of `img`, the side-by-side frame assigned in the
# video-writing cell (original and warped frame concatenated along axis 1).
img.shape

(770, 2000)