In [2]:
import argparse
import glob
import os
import sys

import cv2 as cv
import numpy as np
import torch
from PIL import Image

# RAFT is used from a local `RAFT` directory whose modules are imported as
# `core.*`, so that directory must be on sys.path before the imports below.
# `os` is imported above this line because the path is derived from os.getcwd().
# NOTE(review): the first character of the cwd (the drive letter on Windows)
# is replaced by 'E' — machine-specific; confirm this is intended.
sys.path.append(f'E{os.getcwd()[1:]}\\RAFT')

from core.raft import RAFT
from core.utils import flow_viz
from core.utils.utils import InputPadder

DEVICE = 'cuda'

In [3]:
def load_image(imfile):
    """Load an image file as a single-item float tensor batch on DEVICE.

    The image is converted to 3-channel RGB so grayscale, RGBA or palette
    files also yield the (1, 3, H, W) layout the permute below assumes.
    Pixel values keep their 0-255 range (cast to float, not rescaled).

    Args:
        imfile: path to the image file.

    Returns:
        float32 tensor of shape (1, 3, H, W) on DEVICE.
    """
    # Context manager closes the file handle that PIL otherwise keeps open.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB')).astype(np.uint8)
    # HWC -> CHW, then add a leading batch dimension.
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)


def viz(img, flo, outpath, i):
    """Render an optical-flow tensor as a colour image and save it as a PNG.

    Args:
        img: the source frame tensor; not used for the saved output, kept so
            existing callers keep working.
        flo: flow tensor indexed as flo[0] then permuted to HWC — presumably
            (1, 2, H, W) as produced by RAFT; TODO confirm.
        outpath: existing directory the PNG is written into.
        i: frame index; the file is named ``frame_{i:04d}.png``.

    Raises:
        OSError: if OpenCV reports that the file could not be written.
    """
    flo = flo[0].permute(1, 2, 0).cpu().numpy()

    # Map the flow vectors to an RGB visualisation.
    flo_rgb = flow_viz.flow_to_image(flo)

    # Write into the directory given by the caller, not into any global state.
    frame_path = os.path.join(outpath, f"frame_{i:04d}.png")
    # OpenCV expects BGR channel order, hence the [2, 1, 0] reorder.
    # cv.imwrite signals failure (e.g. missing directory) only via its return value.
    if not cv.imwrite(frame_path, flo_rgb[:, :, [2, 1, 0]]):
        raise OSError(f"cv.imwrite failed to write {frame_path}")


def main(args):
    """Run RAFT on consecutive frame pairs in ``args.path`` and save flow images.

    Args:
        args: attribute-style config (e.g. ``dotdict``) providing ``model``
            (checkpoint path), ``path`` (directory of *.png / *.jpg frames),
            ``outpath`` (existing output directory), plus whatever options
            ``RAFT(args)`` reads (this notebook passes ``small``,
            ``mixed_precision`` and ``alternate_corr``).

    For N sorted frames, N-1 files ``frame_0000.png`` ... are written; file k
    visualises the flow computed between frame k and frame k+1.
    """
    # Wrap in DataParallel so the checkpoint's state-dict keys match
    # (presumably saved from a DataParallel model), then unwrap.
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))

    model = model.module
    model.to(DEVICE)
    model.eval()

    # Sorting assumes zero-padded frame names (e.g. frame_0033.png) so that
    # lexicographic order equals temporal order.
    images = sorted(
        glob.glob(os.path.join(args.path, '*.png'))
        + glob.glob(os.path.join(args.path, '*.jpg'))
    )

    with torch.no_grad():
        for i, (imfile1, imfile2) in enumerate(zip(images[:-1], images[1:])):
            print(imfile1, " ", imfile2)
            image1 = load_image(imfile1)
            image2 = load_image(imfile2)

            # Pad to a size RAFT accepts; the padding is cropped off again below.
            padder = InputPadder(image1.shape)
            image1, image2 = padder.pad(image1, image2)

            _flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
            # Crop the padding so the saved flow matches the original frame size
            # (image and flow are cropped together to keep their shapes aligned).
            viz(padder.unpad(image1), padder.unpad(flow_up), args.outpath, i)


In [4]:
class dotdict(dict):
    """A ``dict`` whose items can also be read, set and deleted as attributes.

    Reading an attribute that is not a key returns ``None`` (like
    ``dict.get``) rather than raising ``AttributeError``; deleting one that
    is not a key raises ``KeyError``.
    """

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup has already failed.
        return self.get(name)

    def __setattr__(self, name, value):
        self[name] = value

    def __delattr__(self, name):
        del self[name]

In [53]:
# Compute RAFT optical flow for every clip folder under input_transformed/vid/
# and write the flow images to input_transformed/RAFT_flow/<clip>/.
os.makedirs("input_transformed/RAFT_flow/", exist_ok=True)

model = "RAFT/models/raft-things.pth"
alternate_corr=False
mixed_precision=False
small = False

# Only directories are clips; sort for a deterministic processing order.
input_folder = sorted(p for p in glob.glob('input_transformed/vid/*') if os.path.isdir(p))

for input_path in input_folder:
    input_path = input_path.replace("\\","/")
    # Output folder is named after the clip folder (e.g. "pushup065").
    output_path = f"input_transformed/RAFT_flow/{os.path.basename(input_path)}/"
    args = {'alternate_corr':alternate_corr,'mixed_precision':mixed_precision,'model':model,'path':input_path,'small':small, 'outpath':output_path}
    args = dotdict(args)

    os.makedirs(output_path, exist_ok=True)
    main(args)

ed/vid/pushup065\frame_0033.png
input_transformed/vid/pushup065\frame_0033.png   input_transformed/vid/pushup065\frame_0034.png
input_transformed/vid/pushup065\frame_0034.png   input_transformed/vid/pushup065\frame_0035.png
input_transformed/vid/pushup065\frame_0035.png   input_transformed/vid/pushup065\frame_0036.png
input_transformed/vid/pushup065\frame_0036.png   input_transformed/vid/pushup065\frame_0037.png
input_transformed/vid/pushup065\frame_0037.png   input_transformed/vid/pushup065\frame_0038.png
input_transformed/vid/pushup065\frame_0038.png   input_transformed/vid/pushup065\frame_0039.png
input_transformed/vid/pushup065\frame_0039.png   input_transformed/vid/pushup065\frame_0040.png
input_transformed/vid/pushup065\frame_0040.png   input_transformed/vid/pushup065\frame_0041.png
input_transformed/vid/pushup065\frame_0041.png   input_transformed/vid/pushup065\frame_0042.png
input_transformed/vid/pushup065\frame_0042.png   input_transformed/vid/pushup065\frame_0043.png
input_tr

In [5]:
# Compute RAFT optical flow for every clip folder under validation_transformed/vid/
# and write the flow images to validation_transformed/RAFT_flow/<clip>/.
os.makedirs("validation_transformed/RAFT_flow/", exist_ok=True)

model = "RAFT/models/raft-things.pth"
alternate_corr=False
mixed_precision=False
small = False

# Only directories are clips; sort for a deterministic processing order.
input_folder = sorted(p for p in glob.glob('validation_transformed/vid/*') if os.path.isdir(p))

for input_path in input_folder:
    input_path = input_path.replace("\\","/")
    # Output folder is named after the clip folder (e.g. "pushup002").
    output_path = f"validation_transformed/RAFT_flow/{os.path.basename(input_path)}/"
    args = {'alternate_corr':alternate_corr,'mixed_precision':mixed_precision,'model':model,'path':input_path,'small':small, 'outpath':output_path}
    args = dotdict(args)

    os.makedirs(output_path, exist_ok=True)
    main(args)

02\frame_0620.png   validation_transformed/vid/pushup002\frame_0621.png
validation_transformed/vid/pushup002\frame_0621.png   validation_transformed/vid/pushup002\frame_0622.png
validation_transformed/vid/pushup002\frame_0622.png   validation_transformed/vid/pushup002\frame_0623.png
validation_transformed/vid/pushup002\frame_0623.png   validation_transformed/vid/pushup002\frame_0624.png
validation_transformed/vid/pushup002\frame_0624.png   validation_transformed/vid/pushup002\frame_0625.png
validation_transformed/vid/pushup002\frame_0625.png   validation_transformed/vid/pushup002\frame_0626.png
validation_transformed/vid/pushup002\frame_0626.png   validation_transformed/vid/pushup002\frame_0627.png
validation_transformed/vid/pushup002\frame_0627.png   validation_transformed/vid/pushup002\frame_0628.png
validation_transformed/vid/pushup002\frame_0628.png   validation_transformed/vid/pushup002\frame_0629.png
validation_transformed/vid/pushup002\frame_0629.png   validation_transformed/vid