In [1]:
import os
import torch
from torch.utils import data
from pytorch_ssn.model.SSN import SSN, crop_like, superpixel_flow
from PIL import Image
from pytorch_ssn.model.util import get_spixel_image
import sys
import numpy as np
from tqdm import tqdm
import cv2
import matplotlib.pyplot as plt
import pytorch_ssn.IO as IO
from pytorch_ssn.connectivity import enforce_connectivity


In [3]:


from PIL import Image
import torch.nn.functional as F
import matplotlib.pyplot as plt

from glob import glob 
from skimage.color import rgb2lab
from skimage.util import img_as_float
from skimage import io

from pytorch_ssn.dataset import Resize, ssn_preprocess

import pytorch_ssn.RAFT.core.datasets as datasets
from pytorch_ssn.RAFT.core.raft import RAFT
from pytorch_ssn.RAFT.core.utils.utils import InputPadder, forward_interpolate


# Expose only the first GPU to this process, then pick CUDA when it is
# available and fall back to the CPU otherwise.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)



cuda


In [4]:
# Helper Functions
class ARGS:
    """Plain attribute container holding the notebook's superpixel settings.

    Attributes are set in ``__init__``; the notebook later attaches further
    fields (e.g. ``savefolder``) to the shared ``args`` instance.
    """

    def __init__(self):
        # Superpixel count and number of iterations.
        self.n_spixels, self.num_steps = 100, 10
        # Output directory and path of the pretrained SSN weights.
        self.result_dir = './results/pix/'
        self.pre_dir = './pytorch_ssn/model/slic_model/45000_0.527_model.pt'
        # Dataset root folder.
        self.root = 'data'


args = ARGS()


def imgtensor2np(img):
    """Convert a 3-D tensor to a NumPy array with its first axis moved last.

    The tensor is permuted with ``(1, 2, 0)``, detached from autograd and
    copied to host memory before conversion.
    """
    channels_last = img.permute(1, 2, 0)
    return channels_last.detach().cpu().numpy()
def to_device(args, device):
    """Move every entry of ``args`` to ``device`` and return them as a new list.

    An entry that is itself a ``list`` is moved element by element (one level
    deep); any other entry is moved through its own ``.to(device)``.
    """
    return [
        [elem.to(device) for elem in item] if isinstance(item, list) else item.to(device)
        for item in args
    ]

def segmentfromLabels(given_img, new_spix_indices, num_h,num_w, enforce_connectivity=False):
    """Render a superpixel label map on top of ``given_img``.

    Args:
        given_img: image array; only its first two dimensions (height, width)
            are read here before it is handed to ``get_spixel_image``.
        new_spix_indices: label tensor indexed as ``[:, :h, :w]``; only the
            first entry along the leading axis is used.
        num_h, num_w: superpixel grid counts; their product sets the nominal
            segment size used for the connectivity thresholds.
        enforce_connectivity: when true, post-process the labels with
            ``pytorch_ssn.connectivity.enforce_connectivity``.

    Returns:
        Whatever ``get_spixel_image(given_img, labels)`` returns.
    """
    h, w = given_img.shape[0], given_img.shape[1]
    # Crop the labels to the image extent, then take the first map as integers.
    new_spix_indices = new_spix_indices[:, :h, :w].contiguous()
    spix_index = new_spix_indices.detach().cpu().numpy()[0]
    spix_index = spix_index.astype(int)

    if enforce_connectivity:
        # The boolean parameter shadows the module-level function of the same
        # name, so calling `enforce_connectivity(...)` here would call a bool.
        # Re-import the function under an alias to reach it.
        from pytorch_ssn.connectivity import enforce_connectivity as _enforce_connectivity

        segment_size = (h * w) / (int(num_h*num_w) * 1.0)
        min_size = int(0.06 * segment_size)
        max_size = int(3 * segment_size)
        spix_index = _enforce_connectivity(spix_index[np.newaxis, :, :], min_size, max_size)[0]

    return  get_spixel_image(given_img, spix_index)


In [5]:
class DataTemplate(data.Dataset):
    """Dataset of consecutive frame pairs with their ground-truth flow.

    Images are globbed from ``{root}/{mode}/clean/{folder}/*.png`` and flow
    files from ``{root}/{mode}/flow/{folder}/*.flo``. Sample ``i`` pairs image
    ``i`` with image ``i + 1`` and flow file ``i``.

    NOTE(review): the default ``root`` is an absolute path on the original
    author's machine; pass ``root=...`` explicitly when running elsewhere.
    """

    def __init__(self, root = "/home/gokul/courses/CMSC828I/sinteldata/MPI-Sintel-complete", mode = 'training',folder = 'alley_2', shape = (128,256)):
        image_pattern = root + f'/{mode}/clean/{folder}/*.png'
        flow_pattern = root + f'/{mode}/flow/{folder}/*.flo'
        self.imfiles = sorted(glob(image_pattern))
        self.flofiles = sorted(glob(flow_pattern))
        # Fail early, and say why, when the folder is missing or empty.
        assert self.imfiles, f'no images found matching {image_pattern}'
        # One flow file is expected per pair of consecutive images.
        assert len(self.imfiles)-1 == len(self.flofiles), (
            f'expected one flow file per consecutive image pair, found '
            f'{len(self.imfiles)} images and {len(self.flofiles)} flow files '
            f'under {root}/{mode} for folder {folder!r}'
        )
        self.resize = Resize(shape)

    def __getitem__(self, i):
        """Return ``([im1, im2], flo, ssn_inputs, ssn_args)`` for pair ``i``.

        ``im1``/``im2`` are the two frames converted with ``rgb2lab``, resized,
        and transposed so their last axis comes first (float32). ``flo`` is the
        flow read by ``IO.read``, resized and transposed the same way.
        """
        imfile1, imfile2, flofile = self.imfiles[i], self.imfiles[i+1], self.flofiles[i]
        im1 = rgb2lab(img_as_float(io.imread(imfile1)))
        im2 = rgb2lab(img_as_float(io.imread(imfile2)))
        flo = IO.read(flofile)
        # NOTE(review): the flow goes through the same Resize as the images; no
        # rescaling of the flow vectors is visible here - confirm Resize handles it.
        im1, im2, flo = self.resize(im1, im1.shape[:2]),self.resize(im2, im2.shape[:2]), self.resize(flo, flo.shape[:2])
        h,w = im1.shape[:2]

        # Third argument of ssn_preprocess; presumably a superpixel count - TODO confirm.
        k = int(0.5 * (h*w)//25 )
        ssn_inputs, ssn_args = ssn_preprocess(im1, None, k )
        im1 = np.transpose(im1, [2, 0, 1]).astype(np.float32)
        im2 = np.transpose(im2, [2, 0, 1]).astype(np.float32)
        flo = np.transpose(flo, [2, 0, 1]).astype(np.float32)
        return [im1, im2], flo, ssn_inputs, ssn_args

    def __len__(self):
        """Number of frame pairs, i.e. the number of flow files."""
        return len(self.flofiles)

# Build the dataset with its default arguments and serve it one sample at a
# time, in file order, through a single worker process.
dataset = DataTemplate()
dataloader = data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=1,
)


In [6]:
# predict optical flow and segment
class MODELARGS:
    """Attribute container passed to the ``RAFT`` constructor as its options."""

    def __init__(self):
        # Path of the pretrained flow checkpoint loaded below.
        self.model = "./pytorch_ssn/model/flow_model/raft-kitti.pth"
        # Option flags read by the flow network.
        self.small = False
        self.mixed_precision = True
        self.alternate_corr = False
        self.dropout = 0


modelargs = MODELARGS()

# slic layer: SSN built from the pretrained weights at args.pre_dir
SSNLayer = SSN(args.pre_dir, spixel_size=(5,5),dtype = 'layer', device = device)
# flow network: wrapped in DataParallel while loading so that the wrapper's
# state-dict key layout is the one the checkpoint is loaded against; the
# wrapper is dropped again afterwards via `.module`.
net = torch.nn.DataParallel(RAFT(modelargs))
# map_location lets the checkpoint load on the CPU fallback selected above,
# even when its tensors were saved from a CUDA device.
net.load_state_dict(torch.load(modelargs.model, map_location=device))
net = net.module.to(device)

In [None]:
# Evaluate the flow network on the dataset and accumulate one EPE value per
# sample in `epe`, using the flow averaged over SSN superpixels.
args.savefolder = 'sintelsegflow'
IO.foldercheck(f'{args.savefolder}/gt')
IO.foldercheck(f'{args.savefolder}/pred')

from loss import EPE

net.eval()
epe = []
# Inference only: disabling autograd avoids building graphs for every batch.
with torch.no_grad():
    # tqdm wraps the dataloader itself so the progress bar knows the total.
    for idx, data_sample in enumerate(tqdm(dataloader)):
        [im1, im2], flow, ssn_input, ssn_params = data_sample
        im1, im2, flow = to_device([im1, im2, flow], device)

        padder = InputPadder(im1.shape, mode='sintel')
        input1, input2 = padder.pad(im1, im2)

        flow_low, flow_pr = net(input1, input2, iters=24, test_mode=True)
        # Undo the padding so the prediction lines up with the unpadded ground
        # truth and superpixel map (a no-op when no padding was needed).
        flow_pr = padder.unpad(flow_pr)
        # Enable this line (and disable the append below) to score the raw prediction.
        # epe.append(EPE(flow_pr, flow).item())

        ssn_input = ssn_input.to(device)
        ssn_params = to_device(ssn_params, device)
        ssn_params.extend([None])
        _, spix_indices = SSNLayer(ssn_input, ssn_params)

        spix_indices = crop_like(spix_indices.unsqueeze(1), im1)
        segflow_GT, _ = superpixel_flow( flow.clone(), spix_indices)
        segflow_pred, _ = superpixel_flow( flow_pr.clone(), spix_indices)
        epe.append(EPE(segflow_pred, flow).item())

        # Optional visualisation of the flow fields:
        # flowrgb_GT = IO.visualize_flow( imgtensor2np(segflow_GT[0]) )
        # flowrgb_pred = IO.visualize_flow(imgtensor2np(flow_pr[0]))
        # cv2.imwrite(f'{args.savefolder}/gt/{str(idx).zfill(2)}.png', flowrgb_GT)
        # cv2.imwrite(f'{args.savefolder}/pred/{str(idx).zfill(2)}.png', flowrgb_pred)


In [26]:
# EPE predicted flow
# NOTE(review): this cell and the "EPE segmented flow" cell evaluate the same
# expression, so the value depends on which `epe.append(...)` line was active
# when the evaluation loop last ran. As the loop is written it appends
# EPE(segflow_pred, flow); the commented-out EPE(flow_pr, flow) append must be
# enabled instead (and the loop re-run) for this to be the raw-prediction EPE.
np.mean(epe)

0.34889925560172724

In [30]:
# EPE segmented flow
# Mean of the per-sample values accumulated in `epe` by the evaluation loop,
# which as written appends EPE(segflow_pred, flow). The expression is identical
# to the "EPE predicted flow" cell, so re-run the loop before trusting this value.
np.mean(epe)

0.35035872307358956

In [None]:
# [TODO] Depth Estimation with Tartan Dataloader
# NOTE(review): work-in-progress stub. The 9-field unpacking below does not
# match the 4-field samples produced by `DataTemplate`; `dataloader` has to be
# rebuilt from a Tartan dataset (not defined in this notebook) before this
# cell can run.
with torch.no_grad():
    # enumerate() supplies `idx`; iterating tqdm(dataloader) alone yields only samples.
    for idx, sample in enumerate(tqdm(dataloader)):
        im0, im1, flow, depth, tr_gt, rot_gt, intrinsics, ssn_input, ssn_params = sample
        im0, im1 = to_device([im0, im1], device)

        # Pad the pair unpacked from this sample (the original padded `im2`,
        # a stale variable left over from the previous cell).
        padder = InputPadder(im0.shape, mode='sintel')
        input1, input2 = padder.pad(im0, im1)

        flow_low, flow_pr = net(input1, input2, iters=24, test_mode=True)
        # epe.append(EPE(flow_pr, flow).item())

        ssn_input = ssn_input.to(device)
        ssn_params = to_device(ssn_params, device)
        ssn_params.extend([None])
        _, spix_indices = SSNLayer(ssn_input, ssn_params)

        # TODO: estimate depth from flow_pr / spix_indices. The former trailing
        # `net(im0, im1)` call ran on unpadded inputs and discarded its result.