In [2]:
from skimage.color import rgb2lab
from skimage import io
from glob import glob

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch, os, cv2
from torch.utils.data import Dataset, DataLoader
import torch.backends.cudnn as cudnn


from loss import EPE
from pytorch_ssn.RAFT.core.utils.utils import InputPadder, forward_interpolate
from pytorch_ssn.RAFT.core.raft import RAFT
from pytorch_ssn.model.SSN import SSN, crop_like, superpixel_flow, superpixel_seg
from pytorch_ssn.dataset import Resize, ssn_preprocess
from pytorch_ssn.connectivity import enforce_connectivity
from pytorch_ssn.model.util import get_spixel_image
import pytorch_ssn.IO as IO
import flownet.flow_transforms as flow_transforms
import torch.nn.functional as F

import torchvision.transforms as transforms
from flownet.models.FlowNetS import flownets

# Expose only GPU 0 to this process, then pick CUDA when it is available and
# fall back to the CPU otherwise. `device` is read as a global by the helpers below.
os.environ['CUDA_VISIBLE_DEVICES']='0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)


# [SINTEL TEST]
class Sintel(Dataset):
    """Consecutive-frame pairs of one Sintel sequence with their ground-truth flow.

    Files are discovered with glob under ``{root}/clean/{folder}/*.png`` (frames)
    and ``{root}/flow/{folder}/*.flo`` (flow); both lists are sorted so that
    flow file ``i`` is paired with frames ``i`` and ``i+1``.

    Args:
        root: dataset root directory.
        folder: name of the sequence sub-folder.
        shape: target size handed to ``Resize``.

    Raises:
        AssertionError: if the number of flow files is not exactly one less
            than the number of frames (this includes an empty/missing folder).
    """

    def __init__(self, root = "./data/Sintel", folder = 'ambush_2', shape = (128,256)):
        self.imfiles = sorted(glob(f'{root}/clean/{folder}/*.png'))
        self.flofiles = sorted(glob(f'{root}/flow/{folder}/*.flo'))
        # N frames give N-1 consecutive pairs, hence N-1 flow files.
        assert len(self.imfiles)-1 == len(self.flofiles), (
            f'expected one flow file per consecutive frame pair, found '
            f'{len(self.imfiles)} images and {len(self.flofiles)} flow files '
            f'for folder {folder!r} under {root!r}'
        )
        self.resize = Resize(shape)

    def __getitem__(self, i):
        """Return ``([im1, im2], flo, ssn_inputs, ssn_args)`` for frame pair ``i``.

        ``im1``, ``im2`` and ``flo`` are float32 arrays transposed to
        channel-first layout; ``ssn_inputs`` / ``ssn_args`` are whatever
        ``ssn_preprocess`` returns for the Lab-converted first frame.
        """
        # Frame i, its successor, and the ground-truth flow between them.
        im1 = io.imread(self.imfiles[i])
        im2 = io.imread(self.imfiles[i+1])
        flo = IO.read(self.flofiles[i])

        # NOTE(review): the flow field goes through the same Resize as the images;
        # confirm that Resize also rescales the flow vectors to the new resolution.
        im1 = self.resize(im1, im1.shape[:2])
        im2 = self.resize(im2, im2.shape[:2])
        flo = self.resize(flo, flo.shape[:2])

        h,w = im1.shape[:2]
        # Third argument of ssn_preprocess: half of (pixel count // 25).
        # Presumably the requested superpixel count -- TODO confirm against ssn_preprocess.
        k = int(0.5 * (h*w)//25 )
        ssn_inputs, ssn_args = ssn_preprocess(rgb2lab(im1), None, k )

        # HWC -> CHW, float32.
        im1 = np.transpose(im1, [2, 0, 1]).astype(np.float32)
        im2 = np.transpose(im2, [2, 0, 1]).astype(np.float32)
        flo = np.transpose(flo, [2, 0, 1]).astype(np.float32)
        return [im1, im2], flo, ssn_inputs, ssn_args

    def __len__(self):
        """Number of frame pairs, i.e. the number of flow files."""
        return len(self.flofiles)


def imgtensor2np(img):
    """Turn a 3-D tensor into a numpy array with its first axis moved last.

    The tensor is permuted with (1, 2, 0), detached from the autograd graph
    and copied to host memory before conversion.
    """
    channels_last = img.permute(1, 2, 0)
    return channels_last.detach().cpu().numpy()
def to_device(args, device):
    """Move every entry of ``args`` to ``device`` and return them as a new list.

    An entry that is itself a ``list`` is moved element by element (one level
    deep); any other entry is moved with its own ``.to(device)``.
    """
    def _move(entry):
        if isinstance(entry, list):
            return [member.to(device) for member in entry]
        return entry.to(device)

    return [_move(entry) for entry in args]


def connect_segments(new_spix_indices, num_h, num_w, h, w):
    """Run ``enforce_connectivity`` on the first label map and return it as a tensor.

    The first entry of ``new_spix_indices`` is cropped to ``h`` x ``w``, converted
    to an integer numpy label map and passed to ``enforce_connectivity`` with
    size bounds of 0.06x and 3x the average segment area ``h*w / (num_h*num_w)``.
    The result is returned with two leading singleton axes on the global ``device``.
    """
    cropped = new_spix_indices[0][:, :h, :w].contiguous()
    label_map = cropped.cpu().numpy()[0].astype(int)

    # Size bounds are expressed relative to the average segment area.
    mean_area = (h * w) / (int(num_h*num_w) * 1.0)
    smallest = int(0.06 * mean_area)
    largest = int(3 * mean_area)

    connected = enforce_connectivity(label_map[np.newaxis, :, :], smallest, largest)[0]
    return torch.tensor(connected).unsqueeze(0).unsqueeze(0).to(device)

def segmentfromLabels(given_img, new_spix_indices, num_h,num_w, connect=False):
    """Render a label map on top of ``given_img`` via ``get_spixel_image``.

    ``new_spix_indices`` is cropped to the image's height and width and its
    first entry is used as an integer label map. With ``connect=True`` the map
    is first passed through ``enforce_connectivity`` using size bounds of
    0.06x and 3x the average segment area ``H*W / (num_h*num_w)``.
    """
    height, width = given_img.shape[0], given_img.shape[1]
    cropped = new_spix_indices[:, :height, :width].contiguous()
    label_map = cropped.cpu().numpy()[0].astype(int)

    if not connect:
        return  get_spixel_image(given_img, label_map)

    # Size bounds are expressed relative to the average segment area.
    mean_area = (height * width) / (int(num_h*num_w) * 1.0)
    smallest = int(0.06 * mean_area)
    largest = int(3 * mean_area)
    label_map = enforce_connectivity(label_map[np.newaxis, :, :], smallest, largest)[0]
    return  get_spixel_image(given_img, label_map)

# Input normalisation applied to each frame before FlowNetS:
#   1. divide every channel by 255 (mean 0, std 255),
#   2. subtract the per-channel means 0.411 / 0.432 / 0.45 (std 1).
input_transform_FS = transforms.Compose(
    [
        transforms.Normalize(mean=[0, 0, 0], std=[255, 255, 255]),
        transforms.Normalize(mean=[0.411, 0.432, 0.45], std=[1, 1, 1]),
    ]
)
def flowNetS(flownet, im1, im2, transform, div_flow = 20):
    """Predict flow between the first samples of ``im1`` and ``im2`` with ``flownet``.

    Only index 0 of each input is used. Both frames are passed through
    ``transform``, concatenated along their first axis, given a batch axis and
    moved to the global ``device``. The network output is bilinearly resized
    to the frame's spatial size and multiplied by ``div_flow``.
    """
    frames = [transform(im1[0]), transform(im2[0])]
    net_input = torch.cat(frames).unsqueeze(0).to(device)

    coarse_flow = flownet(net_input)

    # Bring the network output back to the spatial size of the input frame.
    target_hw = frames[0].size()[-2:]
    upsampled = F.interpolate(coarse_flow, size=target_hw, mode='bilinear', align_corners=False)
    return upsampled * div_flow

def raftFlow(im1, im2, net):
    """Run ``net`` on a padded frame pair and return the flow at the input size.

    Both frames are padded with an ``InputPadder`` built from ``im1.shape``
    (mode ``'sintel'``), the network is called with 24 iterations in test
    mode, and the second value it returns is taken as the flow prediction.
    The prediction is un-padded before being returned so that its spatial
    size matches ``im1`` / the ground truth even when padding was applied.
    """
    padder = InputPadder(im1.shape, mode='sintel')
    input1, input2 = padder.pad(im1, im2)
    _, flow_pr = net(input1, input2, iters=24, test_mode=True)
    # Without this crop the flow keeps the padded resolution and cannot be
    # compared against the ground truth whenever padding was non-zero.
    return padder.unpad(flow_pr)
def superpixel_flow( flow, spix_indices):
    """Replace the flow inside every superpixel by that superpixel's mean flow.

    Note: this definition shadows the ``superpixel_flow`` imported from
    ``pytorch_ssn.model.SSN`` at the top of the notebook.

    Args:
        flow: tensor reshapeable to ``(B, 2, H*W)``; channel 0 and channel 1
            are averaged independently. It is modified in place whenever the
            reshape yields a view (i.e. for contiguous input), so pass a
            ``.clone()`` if the original values are still needed.
        spix_indices: 4-D label tensor ``(B, C, H, W)`` reshapeable to
            ``(B, 1, H*W)``; labels may be arbitrary (non-contiguous) values.

    Returns:
        ``(segmentedflow, spix_indices)`` where ``segmentedflow`` has shape
        ``(B, 2, H, W)`` and ``spix_indices`` is the flattened ``(B, 1, H*W)``
        label tensor.
    """
    B, _, H, W  = spix_indices.size()
    spix_indices = spix_indices.reshape(B,1, -1)
    flow = flow.reshape(B,2, -1)

    for b in range(flow.size(0)):
        labels_b = spix_indices[b, 0]
        # Iterate over the labels that actually occur in this sample. Counting
        # the unique labels and looping over range(count) skips every label
        # whose value is >= count when labels are not contiguous 0..K-1
        # (e.g. after cropping removes some superpixels).
        for label in torch.unique(labels_b):
            mask = labels_b == label
            flow[b, 0, mask] = torch.mean(flow[b, 0, mask])
            flow[b, 1, mask] = torch.mean(flow[b, 1, mask])

    segmentedflow = flow.reshape(B, 2, H, W)
    return segmentedflow, spix_indices

alt cuda corr not found




cuda


In [3]:
class ARGS:
    """Plain attribute container holding the notebook's settings and checkpoint paths."""

    def __init__(self):
        # Superpixel settings and the SSN checkpoint.
        self.n_spixels = 100
        self.num_steps = 10
        self.pre_dir = './pytorch_ssn/model/slic_model/45000_0.527_model.pt'

        # Data root and flow-network checkpoints.
        self.root = 'data'
        self.flownet_dir = "./checkpoints/flownets_EPE1.951.pth.tar"
        self.raft_dir = './checkpoints/7_tartan.pth'

        # Options kept for the (currently commented-out) RAFT(args) construction.
        # NOTE(review): confirm that `small_` (with trailing underscore) is the
        # attribute name the RAFT model actually reads.
        self.mixed_precision = True
        self.alternate_corr = False
        self.dropout = 0.0
        self.small_ = False

        # Output directory for the saved figures and videos.
        self.savefolder = './checkpoints/sintelsegflow'


args = ARGS()


# Make sure the output directory exists before anything is written to it.
IO.foldercheck(f'{args.savefolder}/')

### raft flow network -  NA 
# net = torch.nn.DataParallel(RAFT(args))
# net.load_state_dict(torch.load(args.raft_dir))
# net = net.module.to(device)
# print("=> using RAFT pre-trained model '{}'".format(args.raft_dir))
# print("Raft Parameter Count: %d" % net.count_parameters())

# flow_pr = raftFlow(im1, im2, net)
# segflow_pred, _ = superpixel_flow( flow_pr.clone(), spix_indices)    
# epe.append(EPE(flow_pr, flow).item())
# segepe.append(EPE(segflow_pred, flow).item())
###

# ssn layer
SSNLayer = SSN(args.pre_dir, spixel_size=(5,5),dtype = 'layer', device = device)
print("=> using ssn pre-trained model '{}'".format(args.pre_dir))
print("SSN Parameter Count: %d" % SSNLayer.module.count_parameters())

# load flownet model
# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved from a GPU still loads when `device` fell back to the CPU.
# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
network_data = torch.load(args.flownet_dir, map_location=device)
print("=> using flownet pre-trained model '{}'".format(args.flownet_dir))
flownet = flownets(network_data).to(device)
print("Flownet Parameter Count: %d" % flownet.count_parameters())

# Let cuDNN benchmark and pick convolution algorithms for the input sizes it sees.
cudnn.benchmark = True


=> using ssn pre-trained model './pytorch_ssn/model/slic_model/45000_0.527_model.pt'
SSN Parameter Count: 214962
=> using flownet pre-trained model './checkpoints/flownets_EPE1.951.pth.tar'
Flownet Parameter Count: 38675536


In [None]:
# Evaluate FlowNetS on each Sintel sequence, with and without superpixel-wise
# averaging of the predicted flow, and save one comparison figure per frame
# pair under {args.savefolder}/{folder}/ (refer to that folder for the outputs).
# The aim is to slightly improve FlowNetS's EPE by averaging its flow inside
# each SSN superpixel.
flownet.eval(); SSNLayer.eval()
for folder in ['ambush_2', 'alley_1']:
    dataset = Sintel(folder=folder)
    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
    IO.foldercheck(f'{args.savefolder}/{folder}/')

    # Per-sequence EPE history: raw FlowNetS flow and superpixel-averaged flow.
    epeS, segepeS = [], []
    with torch.no_grad():
        for idx, data_sample in tqdm(enumerate(dataloader), total=len(dataloader)):
            [im1, im2], flow, ssn_input, ssn_params = data_sample
            im1, im2, flow = to_device([im1, im2, flow], device)

            # Dense flow prediction from FlowNetS.
            flow_S = flowNetS(flownet, im1.clone(), im2.clone(), input_transform_FS, div_flow = 20)

            # Superpixel labels from the SSN layer.
            ssn_input = ssn_input.to(device)
            ssn_params = to_device(ssn_params, device)
            # One extra trailing None is appended to the SSN arguments.
            # NOTE(review): confirm which SSNLayer parameter this None stands for.
            ssn_params.extend([None])
            _, spix_indices = SSNLayer(ssn_input, ssn_params)

            # Crop the label map to the frame and average the prediction per
            # superpixel. The clone keeps flow_S intact for the raw EPE below.
            spix_indices = crop_like(spix_indices.unsqueeze(1), im1)
            segflowS_pred, _ = superpixel_flow(flow_S.clone(), spix_indices)

            epeS.append(EPE(flow_S, flow).item())
            segepeS.append(EPE(segflowS_pred, flow).item())

            flowrgb_segpredS = IO.visualize_flow(imgtensor2np(segflowS_pred[0]))
            flowrgb_GT = IO.visualize_flow(imgtensor2np(flow[0]))
            flowrgb_predS = IO.visualize_flow(imgtensor2np(flow_S[0]))

            # The titles report the running mean over the frame pairs processed
            # so far in this sequence, not the EPE of the current pair alone.
            meanepeS = np.mean(epeS)
            meansegepeS = np.mean(segepeS)

            f, plts = plt.subplots(1,4,figsize=(27, 5))
            plts[0].imshow(flowrgb_GT)
            plts[0].set_title('Groundtruth')
            plts[0].axis('off')
            plts[1].imshow(flowrgb_segpredS)
            plts[1].set_title(f'Segmented  Flow (EPE:{meansegepeS :.2f})')
            plts[1].axis('off')
            plts[2].imshow(flowrgb_predS)
            plts[2].set_title(f'Predicted Flow (EPE:{meanepeS :.2f})')
            plts[2].axis('off')
            plts[3].imshow(imgtensor2np(im1[0]).astype(np.uint8) )
            plts[3].set_title('Image')
            plts[3].axis('off')
            f.tight_layout()
            f.suptitle(" \n FlowNetS Optical Flow")

            f.savefig(f'{args.savefolder}/{folder}/{str(idx).zfill(5)}.png')
            # Release the figure; otherwise one figure per frame pair stays
            # open for the lifetime of the kernel.
            plt.close(f)

In [7]:
# Stitch the saved comparison figures of each sequence into an mp4 (10 fps)
# under {args.savefolder}/flow_video/.
for folder in ['ambush_2', 'alley_1']:
    files = sorted(glob(f'{args.savefolder}/{folder}/*png'))
    print(len(files))

    # Created lazily from the first readable frame, which fixes the video size.
    # Starting from None per folder means an empty folder neither raises a
    # NameError nor re-releases the previous folder's writer.
    writer = None
    try:
        for file in tqdm(files):
            img = cv2.imread(file)
            if img is None:
                # cv2.imread returns None instead of raising on unreadable files.
                print(f'Skipping unreadable frame: {file}')
                continue
            if writer is None:
                IO.foldercheck(f'{args.savefolder}/flow_video')
                h,w = img.shape[:2]
                writer = cv2.VideoWriter(f'{args.savefolder}/flow_video/{folder}.mp4', cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), 10, (w, h))
            writer.write(img)
    finally:
        # Always finalise the file, even if writing a frame failed.
        if writer is not None:
            writer.release()
print("Done stitching")

7it [00:00, 64.13it/s]

20
./checkpoints/sintelsegflow/flow_video   was not present, creating the folder...


20it [00:00, 67.05it/s]
7it [00:00, 61.95it/s]

49


49it [00:00, 63.48it/s]

Done stitching





# We infer that enforcing the smoothness constraint mildly helps improve the average End-Point Error (EPE) of the optical flow estimates.