In [30]:
import pandas as pd
import os
import skimage
from skimage import io
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
import sys
sys.path.append("../Dataset_utils")
from poses_parser import pose_2_sixd_array
from torchvision import transforms

In [35]:
class UAV_GPS_Dataset(Dataset):
    """Map-style dataset returning one dict of per-frame representations.

    Samples are keyed by the file stems of the ``.png`` files found in
    ``<data_dir>/rgb``. For each sample, every requested representation is
    loaded for that stem:

    * ``"rgb"``       -> ``rgb/<stem>.png`` read with ``skimage.io.imread``;
      channels beyond the third (e.g. alpha) are dropped, then converted with
      ``torchvision.transforms.ToTensor``.
    * ``"poses"``     -> ``poses/<stem>.txt`` read with ``np.loadtxt`` and
      passed through ``pose_2_sixd_array``.
    * ``"semantics"`` -> ``addl_scene_info/semantics/<stem>.npy`` read with
      ``np.load``, reshaped to ``SEMANTICS_SHAPE`` and wrapped as a tensor.
    * any other name  -> ``addl_scene_info/<name>/<stem>.dat`` via ``torch.load``.

    Args:
        data_dir: Root directory of the split; must contain an ``rgb`` folder.
        representations: Names of the representations to load per sample.
            Defaults to ``["rgb"]``.
        transform: Optional callable applied to the whole sample dict.
        pose_dtype: dtype passed to ``np.loadtxt`` for pose files. Defaults to
            ``np.float16`` for backward compatibility. NOTE(review): float16
            keeps only ~3 significant digits and overflows above 65504, which
            is likely too coarse for positions; consider ``np.float32``.
    """

    # Expected (height, width) of the semantics maps (hard-coded by the
    # original loader).
    SEMANTICS_SHAPE = (480, 720)

    def __init__(self, data_dir, representations=None, transform=None,
                 pose_dtype=np.float16):
        self.data_dir = data_dir
        # None sentinel instead of a shared mutable list default.
        self.representations = ["rgb"] if representations is None else list(representations)
        self.images_names = self._list_image_names(data_dir)
        self.transform = transform
        self.pose_dtype = pose_dtype
        self.TT = transforms.ToTensor()

    @staticmethod
    def _list_image_names(data_dir):
        """Return the sorted ``.png`` file names found in ``<data_dir>/rgb``.

        Sorting makes the index -> frame mapping deterministic, since
        ``os.listdir`` order is arbitrary. Keeping only ``.png`` entries
        matches ``__getitem__``, which always opens ``<stem>.png``; stray files
        (e.g. ``.DS_Store``) would otherwise raise or duplicate samples.
        """
        rgb_dir = os.path.join(data_dir, "rgb")
        return sorted(f for f in os.listdir(rgb_dir) if f.endswith(".png"))

    def __len__(self):
        return len(self.images_names)

    def __getitem__(self, idx):
        """Load every requested representation for the frame at ``idx``.

        Returns:
            dict mapping representation name -> loaded value, passed through
            ``self.transform`` when one was given.
        """
        # splitext keeps dots inside the stem ("a.1.png" -> "a.1"), whereas
        # split(".")[0] would truncate it to "a".
        stem = os.path.splitext(self.images_names[idx])[0]
        sample = {rep: self._load_representation(rep, stem) for rep in self.representations}
        if self.transform:
            sample = self.transform(sample)
        return sample

    def _load_representation(self, representation, stem):
        """Load a single representation for frame ``stem`` (see class docstring)."""
        if representation == "rgb":
            image = io.imread(os.path.join(self.data_dir, "rgb", stem + ".png"))
            # Drop alpha/extra channels; 2-D (grayscale) images have no channel axis.
            if image.ndim == 3 and image.shape[2] > 3:
                image = image[:, :, :3]
            return self.TT(image)
        if representation == "poses":
            pose = np.loadtxt(
                os.path.join(self.data_dir, "poses", stem + ".txt"),
                dtype=self.pose_dtype,
            )
            return pose_2_sixd_array(pose)
        if representation == "semantics":
            # np.load parses the .npy header (length, dtype, order) instead of
            # assuming it is exactly 128 bytes long.
            semantics = np.load(
                os.path.join(self.data_dir, "addl_scene_info", "semantics", stem + ".npy")
            )
            return torch.from_numpy(semantics.reshape(self.SEMANTICS_SHAPE))
        # NOTE(review): torch.load unpickles the file; only use on trusted data.
        return torch.load(
            os.path.join(self.data_dir, "addl_scene_info", representation, stem + ".dat")
        )

In [36]:
# Training split, loading the RGB frame and the pose for each sample.
DS = UAV_GPS_Dataset(
    representations=["rgb", "poses"],
    data_dir="../../Datasets/GPS/train_sorted",
)

In [38]:
# Inspect the RGB tensor of a single sample (index 78).
sample_78 = DS[78]
sample_78["rgb"]

tensor([[[0.6039, 0.5569, 0.6353,  ..., 0.9451, 0.8667, 0.5608],
         [0.5647, 0.5608, 0.7020,  ..., 0.9216, 0.6118, 0.5059],
         [0.2392, 0.4824, 0.6784,  ..., 0.7686, 0.4980, 0.3216],
         ...,
         [0.4667, 0.5647, 0.5725,  ..., 0.5216, 0.4784, 0.5373],
         [0.3843, 0.6510, 0.4941,  ..., 0.4980, 0.5137, 0.5529],
         [0.4275, 0.4549, 0.5294,  ..., 0.5765, 0.5804, 0.4706]],

        [[0.4941, 0.4549, 0.5490,  ..., 0.8039, 0.7020, 0.3804],
         [0.4588, 0.4902, 0.6196,  ..., 0.7765, 0.4706, 0.4392],
         [0.2196, 0.4039, 0.5922,  ..., 0.6157, 0.4549, 0.2784],
         ...,
         [0.3647, 0.4627, 0.4745,  ..., 0.4118, 0.3804, 0.4667],
         [0.2824, 0.5529, 0.3961,  ..., 0.4039, 0.4314, 0.4902],
         [0.3294, 0.3569, 0.4235,  ..., 0.4941, 0.5020, 0.3961]],

        [[0.4471, 0.4118, 0.4980,  ..., 0.7412, 0.6627, 0.3490],
         [0.4353, 0.4314, 0.5451,  ..., 0.7020, 0.4078, 0.4863],
         [0.2471, 0.3412, 0.5137,  ..., 0.5725, 0.5098, 0.

In [1]:
# Test: import the dataset class from the DataLoader module (presumably
# ../Dataset_utils/DataLoader.py — confirm) instead of the in-notebook definition.
import sys

DATASET_UTILS_DIR = "../Dataset_utils"
# Guard the append so re-running this cell does not keep growing sys.path.
if DATASET_UTILS_DIR not in sys.path:
    sys.path.append(DATASET_UTILS_DIR)

from DataLoader import UAV_GPS_Dataset

In [2]:
# Training split, loading RGB plus the additional scene representations.
extra_representations = ["init", "semantics", "normal", "depth"]
DS = UAV_GPS_Dataset(
    data_dir="../../Datasets/GPS/train_sorted",
    representations=["rgb"] + extra_representations,
)

In [4]:
# Number of samples in the dataset.
n_samples = len(DS)
n_samples

16197