In [None]:
#coding:utf8
import os
from PIL import Image
from torch.utils import data
import numpy as np
from torchvision import transforms as T


class Kitti(data.Dataset):
    """RGB-D object-crop dataset indexed from per-image KITTI-style text files.

    Every whitelisted object line in a per-image text file becomes one sample.
    Its RGB and depth crops are looked up as ``{idx}.{line_no}.{class_id}.png``
    under ``dataset/<split>/rgb`` and ``dataset/<split>/depth``.
    """

    def __init__(self, root, transforms=None,
                 white_list=('Car', 'Pedestrian', 'Cyclist'), sets_type='train',
                 process='train'):
        """Index every whitelisted object listed by the chosen image set.

        Args:
            root: dataset root directory, e.g. '/hdd/you/rgbd-det/'.
            transforms: callable applied in ``__getitem__`` to the stacked
                RGB+depth array; defaults to ``T.Compose([T.ToTensor()])``.
            white_list: object type names to keep. Each must be a key of
                ``self.obj_types``, otherwise a KeyError is raised.
            sets_type: image-set file read from ``data/image_sets``
                ('train', 'trainval', 'val' or 'test').
            process: 'train', 'val' or 'test'; stored on the instance but
                not used by this class.
        """
        self.obj_types = {'Car': '0', 'Pedestrian': '1', 'Cyclist': '2'}
        # Decide test mode once, from the caller's value, instead of
        # re-testing a variable that gets overwritten below.
        is_test = sets_type == 'test'
        allowed = set(white_list)

        idx_path = os.path.join(root, 'data/image_sets/{}.txt'.format(sets_type))
        # train / val / trainval ids all share the images stored under 'trainval'.
        split_dir = 'test' if is_test else 'trainval'
        rgb_dir = os.path.join(root, 'dataset/{}/rgb'.format(split_dir))
        depth_dir = os.path.join(root, 'dataset/{}/depth'.format(split_dir))
        # Ground-truth labels for train/val, detection results for test.
        text_path = os.path.join(
            root, 'dataset/{}'.format('test/det' if is_test else 'trainval/label'))

        with open(idx_path, 'r') as idx_file:
            # Skip blank lines (e.g. a trailing newline) instead of opening '.txt'.
            ids = [line.strip() for line in idx_file if line.strip()]

        rgbs = []    # crop paths
        depths = []
        labels = []  # one fresh dict per object; holds the image id when testing
        for idx in ids:
            with open(os.path.join(text_path, idx + '.txt'), 'r') as txt:
                lines = txt.readlines()
            for num, line in enumerate(lines):
                fields = line.split()
                if not fields or fields[0] not in allowed:
                    continue
                obj = self.obj_types[fields[0]]
                target = '{}.{}.{}.png'.format(idx, num, obj)
                rgbs.append(os.path.join(rgb_dir, target))
                depths.append(os.path.join(depth_dir, target))
                if is_test:
                    # Encode (image id, line number, class id) so predictions
                    # can be mapped back to their source line.
                    loc = [float(idx), float(num), float(obj)]
                else:
                    # NOTE(review): assumes the last three fields of a label
                    # line are the target location -- confirm against the
                    # label format used here.
                    loc = [float(v) for v in fields[-3:]]
                labels.append({'loc': loc})

        self.idx_path = idx_path
        self.rgb_dir = rgb_dir
        self.depth_dir = depth_dir
        self.text_path = text_path
        self.rgbs = rgbs  # necessary
        self.depths = depths  # necessary
        self.labels = labels  # necessary
        self.process = process

        if transforms is None:
            self.transforms = T.Compose([
                # T.Scale(224),
                # T.CenterCrop(224),
                T.ToTensor(),
                # TODO: decide whether normalization is needed and with which values
                # T.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225])
            ])
        else:
            self.transforms = transforms

    def __getitem__(self, index):
        """Return ``(sample, label)`` for the object at ``index``.

        ``sample`` is the RGB crop stacked with the depth crop along the last
        axis, passed through ``self.transforms``. ``label`` is the dict built
        in ``__init__``.
        """
        # Context managers close the file handles promptly (important with
        # many DataLoader workers); np.array() fully loads the pixels first.
        with Image.open(self.rgbs[index]) as img:
            rgb = np.array(img)  # H x W x C, presumably C == 3 (RGB png)
        with Image.open(self.depths[index]) as img:
            # /256 follows the KITTI depth devkit uint16 encoding -- confirm
            depth = np.array(img) / 256.
        depth = np.expand_dims(depth, axis=2)  # H x W x 1

        # The float depth promotes the stacked array to float64, so the default
        # ToTensor transform does NOT rescale the RGB channels to [0, 1].
        sample = np.concatenate((rgb, depth), axis=2)
        sample = self.transforms(sample)
        return sample, self.labels[index]

    def __len__(self):
        """Return the number of indexed objects."""
        return len(self.rgbs)
    
    
if __name__ == '__main__':
    # Smoke test: index the training split and print the first sample.
    # Use `sample` (not `data`) so the imported torch.utils.data module
    # is not rebound at module level.
    train_data = Kitti('/hdd/you/rgbd-det/', sets_type='train', process='train')
    sample, label = train_data[0]
    print(sample, sample.shape, label, len(train_data), sep='\n')