In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, models

import os 
import cv2
import warnings
import numpy as np
from IPython import display
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore")
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
def dense_warp(image, flow):
    """
    Densely warps an image using optical flow.

    Output pixel (x, y) is bilinearly sampled from ``image`` at
    (x + flow[:, 0], y + flow[:, 1]). Samples that fall outside the image are
    filled with zeros (the ``grid_sample`` default). A zero flow returns the
    input unchanged.

    Args:
        image (torch.Tensor): Input image tensor of shape (batch_size, channels, height, width).
        flow (torch.Tensor): Optical flow tensor of shape (batch_size, 2, height, width),
            in pixels, channel 0 = x displacement, channel 1 = y displacement.

    Returns:
        torch.Tensor: Warped image tensor of shape (batch_size, channels, height, width).
    """
    batch_size, _, height, width = image.shape

    # Base grid of pixel coordinates in (x, y) order, built directly on the
    # image's device and dtype so no host->device copy or int->float cast is needed.
    ys = torch.arange(height, device=image.device, dtype=image.dtype)
    xs = torch.arange(width, device=image.device, dtype=image.dtype)
    grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')
    base_grid = torch.stack((grid_x, grid_y), dim=-1)
    base_grid = base_grid.unsqueeze(0).expand(batch_size, -1, -1, -1)

    # Absolute sampling positions in pixels, shape (batch_size, height, width, 2)
    sample_pts = base_grid + flow.permute(0, 2, 3, 1).to(image.dtype)

    # Normalize to [-1, 1] so that -1 / +1 are the CENTRES of the first / last
    # pixel. max(..., 1) avoids a 0/0 -> NaN grid for 1-pixel-wide or -tall inputs.
    extent = torch.tensor(
        [max(width - 1, 1), max(height - 1, 1)],
        dtype=image.dtype,
        device=image.device,
    )
    norm_grid = sample_pts / extent * 2 - 1

    # align_corners=True is the convention that matches the (size - 1)
    # normalization above; with align_corners=False a zero flow would not
    # reproduce the input (every sample would be shifted by up to half a pixel).
    return F.grid_sample(image, norm_grid, align_corners=True)

In [3]:
from models import ResNet, UNet
class DIFRINT(nn.Module):
    """
    Frame-interpolation pipeline combining RAFT optical flow, a UNet and a ResNet.

    Given a previous, current and next frame, the module synthesizes an
    intermediate frame from the two neighbours (UNet) and then refines it
    using the current frame warped onto it (ResNet).
    """

    def __init__(self):
        super().__init__()
        # Sub-networks are created in eval mode on the notebook-level `device`.
        self.resnet = ResNet(hidden_size=64).to(device).eval()
        self.unet = UNet(hidden_size=64).to(device).eval()
        self.raft = models.optical_flow.raft_small(
            weights='Raft_Small_Weights.C_T_V2'
        ).eval().to(device)

    def forward(self, ft_minus, ft, ft_plus):
        """
        Args:
            ft_minus: previous frame batch.
            ft: current frame batch.
            ft_plus: next frame batch.

        Returns:
            tuple: (fint, fout) - the UNet's intermediate frame and the
            ResNet's refined output frame.
        """
        # Half of the bidirectional flow between the two neighbours; only the
        # last element of RAFT's output is used.
        # NOTE(review): the flow is halved here and halved again when warping
        # below, so the neighbours are warped by a quarter of the RAFT flow -
        # confirm this matches how the checkpoints were trained.
        half_flow_fwd = self.raft(ft_minus, ft_plus)[-1] * 0.5
        half_flow_bwd = self.raft(ft_plus, ft_minus)[-1] * 0.5

        prev_warped = dense_warp(ft_minus, half_flow_fwd * 0.5)
        next_warped = dense_warp(ft_plus, half_flow_bwd * 0.5)

        # Intermediate frame synthesized from both warped neighbours.
        fint = self.unet(
            prev_warped, next_warped, half_flow_fwd, half_flow_bwd, ft_minus, ft_plus
        )

        # Warp the current frame with the flow from it to the intermediate frame,
        # then let the ResNet produce the final output.
        flow_to_fint = self.raft(ft, fint)[-1]
        current_warped = dense_warp(ft, flow_to_fint)
        fout = self.resnet(fint, current_warped)
        return fint, fout

# Instantiate the full pipeline in eval mode on the target device and report
# how many parameters it holds across all sub-networks.
difrint = DIFRINT().eval().to(device)
total_params = 0
for param in difrint.parameters():
    total_params += param.numel()
print("Total number of parameters in DIFRINT model: {}".format(total_params))


Total number of parameters in DIFRINT model: 3216808


In [4]:
import os
import torch


def _checkpoint_index(filename):
    """Return N for a file named '<prefix>_<N>.<ext>', or None for any other name."""
    fields = filename.split('.')[0].split('_')
    try:
        return int(fields[1])
    except (IndexError, ValueError):
        return None


def _load_latest_checkpoint(module, ckpt_dir):
    """
    Load the highest-numbered checkpoint found in `ckpt_dir` into `module`.

    Files whose names do not follow '<prefix>_<N>.<ext>' (e.g. '.gitkeep') are
    ignored instead of crashing the sort.

    Args:
        module (nn.Module): Network that receives the weights.
        ckpt_dir (str): Directory containing the checkpoint files.

    Returns:
        str or None: Name of the loaded file, or None when the directory is
        missing or holds no matching checkpoint.
    """
    if not os.path.isdir(ckpt_dir):
        return None
    candidates = [f for f in os.listdir(ckpt_dir) if _checkpoint_index(f) is not None]
    if not candidates:
        return None
    latest = max(candidates, key=_checkpoint_index)
    # NOTE: torch.load unpickles the file - only load checkpoints you trust.
    # map_location lets a checkpoint saved on one device load onto the current one.
    state_dict = torch.load(os.path.join(ckpt_dir, latest), map_location=device)
    module.load_state_dict(state_dict['model'])
    return latest


# Load UNet checkpoints
unet_path = './ckpts/unet/'
latest = _load_latest_checkpoint(difrint.unet, unet_path)
if latest:
    print(f'Loaded UNet:{latest} ')
else:
    print(f'No UNet checkpoint found in {unet_path}; keeping current weights')

# Load ResNet checkpoints
resnet_path = './ckpts/resnet/'
latest = _load_latest_checkpoint(difrint.resnet, resnet_path)
if latest:
    print(f'Loaded ResNet:{latest}')
else:
    print(f'No ResNet checkpoint found in {resnet_path}; keeping current weights')

Loaded UNet:unet_2.pth 
Loaded ResNet:resnet_188.pth


In [5]:
# Decode the input clip into a float32 array of shape (frame_count, hh, ww, 3)
# with pixel values scaled to [-1, 1]. Channel order is left as OpenCV delivers it.
video_path = 'E:/Datasets/DeepStab_Dataset/unstable/2.avi'
hh,ww = 360,640  # working resolution (height, width) every frame is resized to

cap = cv2.VideoCapture(video_path)
if not cap.isOpened():
    raise IOError(f'Could not open video: {video_path}')
try:
    # The container's frame count is only an estimate; clamp it so a bogus
    # negative value cannot reach np.zeros.
    frame_count = max(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), 0)
    h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frames = np.zeros((frame_count,hh,ww,3),np.float32)
    frames_read = 0
    for i in range(frame_count):
        ret,img = cap.read()
        if not ret:
            break
        img = cv2.resize(img,(ww,hh))
        frames[i,...] = ((img / 255.0) * 2) - 1
        frames_read += 1
finally:
    # Always free the decoder handle, even if decoding raised.
    cap.release()

# If decoding stopped early, drop the never-filled (all-zero) tail so later
# cells do not process or export blank frames.
if frames_read < frame_count:
    frames = frames[:frames_read]
    frame_count = frames_read

In [6]:
# Iteratively stabilize the clip: each pass re-synthesizes every interior frame
# from the ORIGINAL current frame and the neighbours produced by the previous pass.
SKIP = 1  # temporal distance (in frames) to the neighbours fed to the model
ITER = 5  # number of passes over the whole clip
interpolated = frames.copy()
cv2.namedWindow('window',cv2.WINDOW_NORMAL)
stop_requested = False
try:
    for pass_idx in range(ITER):
        print(pass_idx)
        # Releasing cached GPU blocks once per pass is enough under no_grad;
        # doing it per frame only slows the loop down.
        torch.cuda.empty_cache()
        # Write results into a scratch copy so every frame of this pass sees
        # neighbours from the previous pass, not half-updated ones.
        temp = interpolated.copy()
        for frame_idx in range(SKIP,frame_count - SKIP):
            ft_minus = torch.from_numpy(interpolated[frame_idx - SKIP,...]).permute(2,0,1).unsqueeze(0).to(device)
            ft = torch.from_numpy(frames[frame_idx]).permute(2,0,1).unsqueeze(0).to(device)
            ft_plus = torch.from_numpy(interpolated[frame_idx + SKIP,...]).permute(2,0,1).unsqueeze(0).to(device)
            with torch.no_grad():
                fint,fout = difrint(ft_minus,ft,ft_plus)
            # Convert the output to an HWC NumPy array once and reuse it.
            stabilized = fout.cpu().squeeze(0).permute(1,2,0).numpy()
            temp[frame_idx,...] = stabilized
            # The network output can leave [-1, 1]; clip before the uint8 cast,
            # otherwise out-of-range values wrap around in the preview.
            img = np.clip(((stabilized + 1) / 2) * 255.0, 0, 255).astype(np.uint8)
            cv2.imshow('window',img)
            if cv2.waitKey(1) & 0xFF == ord('9'):
                # '9' aborts the whole run, not just the current pass.
                stop_requested = True
                break
        # temp is rebuilt at the start of each pass, so no extra copy is needed.
        interpolated = temp
        if stop_requested:
            break
finally:
    # Close the preview window even when the cell is interrupted.
    cv2.destroyAllWindows()

0


KeyboardInterrupt: 

In [None]:
# Export the stabilized frames to a video file while previewing them.
# `sleep` is not used in this cell; the import is kept because later playback
# cells may rely on it being in the kernel namespace.
from time import sleep

cv2.namedWindow('window',cv2.WINDOW_NORMAL)
frame_count, h, w, c = interpolated.shape
out_path = './2_vimeo.avi'
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(out_path, fourcc, 30.0, (w, h))
try:
    # VideoWriter does not raise when it cannot open the target; check explicitly
    # so a bad path or codec does not silently produce no file.
    if not out.isOpened():
        raise IOError(f'Could not open video writer for {out_path}')
    for idx in range(frame_count):
        # Map [-1, 1] -> [0, 255]; clip first because the network output may
        # exceed the nominal range and would wrap around in uint8.
        img = np.clip(((interpolated[idx,...] + 1) / 2) * 255, 0, 255).astype(np.uint8)
        out.write(img)
        cv2.imshow('window',img)
        if cv2.waitKey(1) & 0xFF == ord('9'):
            break
finally:
    # Finalize the file and close the preview even on error or interrupt.
    out.release()
    cv2.destroyAllWindows()

In [None]:
# Export a side-by-side video: stabilized result on the left, input on the right.
from time import sleep  # imported here so the cell does not depend on another cell

name = './results/comparison2.avi'
frame_count, h, w, c = interpolated.shape
out_path = name
# VideoWriter fails silently when the target directory is missing; create it first.
os.makedirs(os.path.dirname(out_path), exist_ok=True)
cv2.namedWindow('window',cv2.WINDOW_NORMAL)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
out = cv2.VideoWriter(out_path, fourcc, 30.0, (2*w,h))
try:
    if not out.isOpened():
        raise IOError(f'Could not open video writer for {out_path}')
    for idx in range(frame_count):
        # Map [-1, 1] -> [0, 255]; clip so out-of-range network output does not
        # wrap around in the uint8 cast.
        img1 = np.clip(((interpolated[idx,...] + 1) / 2) * 255, 0, 255).astype(np.uint8)
        img2 = np.clip(((frames[idx,...] + 1) / 2) * 255, 0, 255).astype(np.uint8)
        conc = cv2.hconcat([img1,img2])
        out.write(conc)
        cv2.imshow('window',conc)
        sleep(1/30)  # pace the preview at roughly the output frame rate
        if cv2.waitKey(1) & 0xFF == ord('9'):
            break
finally:
    # Finalize the file and close the preview even on error or interrupt.
    out.release()
    cv2.destroyAllWindows()

In [None]:
# Preview the per-pixel absolute difference between stabilized and input frames.
from time import sleep

cv2.namedWindow('window',cv2.WINDOW_NORMAL)
try:
    for idx in range(frame_count):
        # Map [-1, 1] -> [0, 255]; clip so out-of-range network output does not
        # wrap around in the uint8 cast and show up as false differences.
        img = np.clip(((interpolated[idx,...] + 1) / 2) * 255, 0, 255).astype(np.uint8)
        img1 = np.clip(((frames[idx,...] + 1) / 2) * 255, 0, 255).astype(np.uint8)
        diff = cv2.absdiff(img,img1)
        cv2.imshow('window',diff)
        sleep(1/60)
        if cv2.waitKey(1) & 0xFF == ord('9'):
            break
finally:
    # Close the preview window even when the cell is interrupted.
    cv2.destroyAllWindows()