In [1]:
import os
from pathlib import Path
import glob

import torch 
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import random

from file_io import read, write
import h5py

In [2]:
class dataset(Dataset):
    def __init__(self, 
                 root,
                 input_tfm_function,
                 gt_tfm_function,
                 input_folder_1 = "image_2", 
                 input_folder_2 = "image_2", 
                 gt_folder = "flow_occ",
                 input_search_param_1="*_10.png", 
                 input_search_param_2 = "*_11.png", 
                 gt_search_param = "*_10.png"):
        
        self.root = Path(root)
        self.input1_dir = self.root/input_folder_1
        self.input2_dir = self.root/input_folder_2
        self.gt_dir = self.root/gt_folder
        
        self.input_1_names = sorted([os.path.basename(x) for x in self.input1_dir.glob(input_search_param_1)])
        self.input_2_names = sorted([os.path.basename(x) for x in self.input2_dir.glob(input_search_param_2)])
        self.gt_names = sorted([os.path.basename(x) for x in self.gt_dir.glob(gt_search_param)])
        self.input_tfm_function = input_tfm_function
        self.gt_tfm_function = gt_tfm_function
        
#         for a,b,c in zip(self.input_1_names,self.input_2_names,self.gt_names):
#             print(a,b,c)
        
        
    def __len__(self):
        return len(self.input_1_names)
    
    def __getitem__(self, idx):
#         print(idx)
        print(self.input_1_names[idx],self.input_2_names[idx],self.gt_names[idx])
        input1 = read(str(self.input1_dir/self.input_1_names[idx]))
        input2 = read(str(self.input2_dir/self.input_2_names[idx]))
        gt = read(str(self.gt_dir/self.gt_names[idx]))
        
        input1 = self.input_tfm_function(input1)
        input2 = self.input_tfm_function(input2)
        gt = self.gt_tfm_function(gt)
        
        return input1, input2, gt

In [3]:
# set the transform
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.ToPILImage(),
        transforms.Resize((512,512), interpolation=2),
        transforms.ToTensor()
        ])

In [4]:
# train_loader_kitti = dataset("../Data/KittiDataset/training",
#                        transform,transform,
#                        input_folder_1 = "image_2", 
#                        input_folder_2 = "image_2", 
#                        gt_folder = "flow_occ",
#                        input_search_param_1="*_10.png", 
#                        input_search_param_2 = "*_11.png", 
#                        gt_search_param = "*_10.png")

In [5]:
train_loader_FlyingChairs = dataset("../Data/ChairsSDHom_extended",
                       transform,transform,
                       input_folder_1 = "train", 
                       input_folder_2 = "train", 
                       gt_folder = "train",
                       input_search_param_1="*-img_0.png", 
                       input_search_param_2 = "*-img_1.png", 
                       gt_search_param = "*-flow_01.flo")

In [6]:
# dataset_kitti = torch.utils.data.DataLoader(train_loader_kitti,10,shuffle=True,num_workers=2)

In [7]:
dataset_flyingChairs = torch.utils.data.DataLoader(train_loader_FlyingChairs,10,shuffle=True,num_workers=1)

In [8]:
def write_H5(dataset, path):
    index = 0
    for inputs1,inputs2,target in dataset:
        print(index)
        with h5py.File(path + str(index), 'w') as f:
            in1 = f.create_dataset('img1', dtype=np.float32,data=inputs1,compression="gzip", compression_opts=9)
            in2 = f.create_dataset('img2', dtype=np.float32,data=inputs2,compression="gzip", compression_opts=9)    
            targ = f.create_dataset('target', dtype=np.float32,data=target,compression="gzip", compression_opts=9)
            index+=1

In [9]:
def read_H5(index, path):
    with h5py.File(path + str(index), 'r') as f:
        img1 = f["img1"][()]
        img2 = f["img2"][()]
        trg = f["target"][()]
        img1_file = torch.from_numpy(img1)
        img2_file = torch.from_numpy(img2)
        target_file = torch.from_numpy(trg)
        return img1_file,img2_file,target_file

In [10]:
def print_batch(im1,im2,target):
    for image_1, image_2, target_image in zip(im1,im2,target):
        im1_print = image_1.numpy().transpose((1,2,0))
        im2_print = image_2.numpy().transpose((1,2,0))
        
        print(target_image.shape)

        z = np.zeros((target_image.shape[1],target_image.shape[2]))
        target_print = np.stack((target_image[0],target_image[1], z), axis=0)
        target_print = target_print.transpose((1,2,0))
        print(target_print.shape)

        fig, ax = plt.subplots(1, 3, figsize=(15, 20))
        for a in ax:
          a.set_axis_off()

        ax[0].imshow(im1_print)
        ax[1].imshow(im2_print)
        ax[2].imshow(target_print)

        plt.show()

In [11]:
# write_H5(dataset_kitti, path = '../Data/H5/512by512Kitti/mini_batch_')

In [12]:
write_H5(dataset_flyingChairs, path = '../Data/H5/512by512FlyingChairs/mini_batch_')

0018000-img_0.png 0018000-img_1.png 0018000-flow_01.flo
0013460-img_0.png 0013460-img_1.png 0013460-flow_01.flo
0019124-img_0.png 0019124-img_1.png 0019124-flow_01.flo
0010409-img_0.png 0010409-img_1.png 0010409-flow_01.flo
0005616-img_0.png 0005616-img_1.png 0005616-flow_01.flo
0000306-img_0.png 0000306-img_1.png 0000306-flow_01.flo
0019663-img_0.png 0019663-img_1.png 0019663-flow_01.flo
0000807-img_0.png 0000807-img_1.png 0000807-flow_01.flo
0010518-img_0.png 0010518-img_1.png 0010518-flow_01.flo
0005704-img_0.png 0005704-img_1.png 0005704-flow_01.flo
0019697-img_0.png 0019697-img_1.png 0019697-flow_01.flo
0015559-img_0.png 0015559-img_1.png 0015559-flow_01.flo
0004398-img_0.png 0004398-img_1.png 0004398-flow_01.flo
0007197-img_0.png 0007197-img_1.png 0007197-flow_01.flo
0
0007395-img_0.png 0007395-img_1.png 0007395-flow_01.flo
0012789-img_0.png 0012789-img_1.png 0012789-flow_01.flo
0015130-img_0.png 0015130-img_1.png 0015130-flow_01.flo
0005116-img_0.png 0005116-img_1.png 0005116-fl

Exception: Caught Exception in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/Users/brianandika/opt/anaconda3/envs/pytorch-nn/lib/python3.7/site-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
    data = fetcher.fetch(index)
  File "/Users/brianandika/opt/anaconda3/envs/pytorch-nn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "/Users/brianandika/opt/anaconda3/envs/pytorch-nn/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
    data = [self.dataset[idx] for idx in possibly_batched_index]
  File "<ipython-input-2-6084ce35d484>", line 36, in __getitem__
    gt = read(str(self.gt_dir/self.gt_names[idx]))
  File "/Users/brianandika/LocalFolders/EECS504FinalProject/EECS-504-PWC-Net/BrianTest/file_io.py", line 13, in read
    elif file.endswith('.flo'): return readFlow(file)
  File "/Users/brianandika/LocalFolders/EECS504FinalProject/EECS-504-PWC-Net/BrianTest/file_io.py", line 106, in readFlow
    raise Exception('Flow file header does not contain PIEH')
Exception: Flow file header does not contain PIEH


0011954-img_0.png 0011954-img_1.png 0011954-flow_01.flo
0004142-img_0.png 0004142-img_1.png 0004142-flow_01.flo
../Data/ChairsSDHom_extended/train/0004142-flow_01.flo


In [None]:
im1,im2,target = read_H5(2,path = '../Data/H5/512by512FlyingChairs/mini_batch_')
print_batch(im1,im2,target)

In [16]:
read("../Data/ChairsSDHom_extended/train/0004142-flow_01.flo")

../Data/ChairsSDHom_extended/train/0004142-flow_01.flo


Exception: Flow file header does not contain PIEH