In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import warnings
import numpy as np
import os
import cv2
import tqdm
# NOTE(review): this blanket filter also hides deprecation warnings; consider narrowing it.
warnings.filterwarnings("ignore")
# Fall back to CPU so the notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Frame size (height, width, channels) that videos are resized to before RAFT;
# torchvision's RAFT requires H and W to be divisible by 8.
shape = (H, W, C) = (256, 256, 3)

In [7]:
from torchvision import models
# Pretrained RAFT-small used for inference only: eval() puts it in inference mode and
# requires_grad_(False) freezes the weights, so autograd records no graph for the
# forward passes in get_flow (whose input tensors don't require grad either).
raft = models.optical_flow.raft_small(weights='Raft_Small_Weights.C_T_V2').eval().requires_grad_(False).to(device)

def _to_raft_input(img, dev, bgr=True):
    """Convert one HxWx3 frame (values 0..255) into a 1x3xHxW float32 tensor in [-1, 1].

    torchvision's RAFT does not normalize its inputs itself; its reference
    preprocessing (``Raft_Small_Weights.C_T_V2.transforms()``) rescales pixels
    to [-1, 1], so that is the range fed here.
    """
    t = torch.from_numpy(np.ascontiguousarray(img)).to(dev).float()
    if bgr:
        # OpenCV decodes frames as BGR; torchvision's RAFT weights expect RGB.
        t = t.flip(-1)
    return t.permute(2, 0, 1).unsqueeze(0) / 127.5 - 1.0


def get_flow(img1, img2, model=None, bgr=True):
    """Estimate dense optical flow from ``img1`` to ``img2`` with RAFT.

    Args:
        img1, img2: HxWx3 frames of the same size with values in 0..255, e.g.
            as returned by ``cv2.VideoCapture.read``. torchvision's RAFT
            requires H and W to be divisible by 8.
        model: flow network; defaults to the module-level ``raft``. It is
            called as ``model(a, b)`` and must return a sequence of flow
            predictions of shape (1, 2, H, W); the last one is used.
        bgr: True (default) if the frames are in OpenCV BGR channel order
            and must be flipped to RGB before inference.

    Returns:
        HxWx2 float32 numpy array; channel 0 is the x and channel 1 the y
        displacement per pixel.
    """
    if model is None:
        model = raft
    # Run on whatever device the model lives on; only fall back to the global
    # `device` for a parameter-less model.
    param = next(model.parameters(), None)
    dev = param.device if param is not None else device
    a = _to_raft_input(img1, dev, bgr)
    b = _to_raft_input(img2, dev, bgr)
    with torch.no_grad():  # pure inference: don't build an autograd graph
        flow = model(a, b)[-1]
    return flow[0].permute(1, 2, 0).cpu().numpy()

def show_flow(flow):
    """Visualise an HxWx2 flow field as an RGB uint8 image.

    Hue encodes the flow direction (degrees halved into OpenCV's 0..180 hue
    range), saturation is fixed at full, and value is the flow magnitude
    min-max stretched to 0..255.
    """
    magnitude, angle_deg = cv2.cartToPolar(flow[..., 0], flow[..., 1], angleInDegrees=True)
    hsv = np.empty(flow.shape[:-1] + (3,), dtype=np.uint8)
    hsv[:, :, 0] = angle_deg / 2
    hsv[:, :, 1] = 255
    hsv[:, :, 2] = cv2.normalize(magnitude, None, 0, 255, cv2.NORM_MINMAX)
    return cv2.cvtColor(hsv, cv2.COLOR_HSV2RGB)

In [8]:
videos_path = 'E:/Datasets/DeepStab_Dataset/stable/'
flows_path = 'E:/Datasets/Flows/'
STOP_KEY = ord('9')     # press in the preview window to abort the whole run
SKIP_EXISTING = False   # set True to resume an interrupted run without recomputing saved videos


def _video_flows(video_file):
    """Compute RAFT flow for every consecutive frame pair of one video.

    Frames are resized to (W, H) before estimation and each flow field is
    previewed in the 'window' OpenCV window.

    Returns:
        (flows, aborted): ``flows`` is a list of HxWx2 arrays, or None if the
        first frame could not be read; ``aborted`` is True when STOP_KEY was
        pressed, in which case ``flows`` is incomplete.
    """
    cap = cv2.VideoCapture(video_file)
    try:
        ok, prev = cap.read()
        if not ok:
            return None, False
        prev = cv2.resize(prev, (W, H))
        flows = []
        while True:
            ok, curr = cap.read()
            if not ok:
                return flows, False
            curr = cv2.resize(curr, (W, H))
            flow = get_flow(prev, curr)
            flows.append(flow)
            prev = curr
            # show_flow returns RGB, but cv2.imshow expects BGR.
            cv2.imshow('window', cv2.cvtColor(show_flow(flow), cv2.COLOR_RGB2BGR))
            if cv2.waitKey(1) & 0xFF == STOP_KEY:
                return flows, True
    finally:
        cap.release()


os.makedirs(flows_path, exist_ok=True)
cv2.namedWindow('window', cv2.WINDOW_NORMAL)
videos = sorted(os.listdir(videos_path))
try:
    for video in tqdm.tqdm(videos):
        # splitext keeps names like 'clip.v2.mp4' intact (split('.') would truncate them).
        output_path = os.path.join(flows_path, os.path.splitext(video)[0] + '.npy')
        if SKIP_EXISTING and os.path.exists(output_path):
            continue
        flows, aborted = _video_flows(os.path.join(videos_path, video))
        if aborted:
            break  # don't write a truncated flow file
        if flows is None:
            tqdm.tqdm.write(f'skipping unreadable video: {video}')
            continue
        # reshape gives a consistent (N, H, W, 2) layout even for 1-frame videos (N == 0).
        np.save(output_path, np.asarray(flows, dtype=np.float32).reshape(-1, H, W, 2))
finally:
    # Also runs on KeyboardInterrupt, so the preview window never lingers.
    cv2.destroyAllWindows()

  9%|▊         | 2/23 [00:25<04:22, 12.52s/it]


KeyboardInterrupt: 