Skip to content

Commit

Permalink
Carvana dataset loader
Browse files Browse the repository at this point in the history
  • Loading branch information
milesial committed Jul 30, 2020
1 parent 84f8392 commit 4ad8323
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@


class BasicDataset(Dataset):
def __init__(self, imgs_dir, masks_dir, scale=1):
def __init__(self, imgs_dir, masks_dir, scale=1, mask_suffix=''):
self.imgs_dir = imgs_dir
self.masks_dir = masks_dir
self.scale = scale
self.mask_suffix = mask_suffix
assert 0 < scale <= 1, 'Scale must be between 0 and 1'

self.ids = [splitext(file)[0] for file in listdir(imgs_dir)
Expand Down Expand Up @@ -43,7 +44,7 @@ def preprocess(cls, pil_img, scale):

def __getitem__(self, i):
idx = self.ids[i]
mask_file = glob(self.masks_dir + idx + '.*')
mask_file = glob(self.masks_dir + idx + self.mask_suffix + '.*')
img_file = glob(self.imgs_dir + idx + '.*')

assert len(mask_file) == 1, \
Expand All @@ -63,3 +64,8 @@ def __getitem__(self, i):
'image': torch.from_numpy(img).type(torch.FloatTensor),
'mask': torch.from_numpy(mask).type(torch.FloatTensor)
}


class CarvanaDataset(BasicDataset):
def __init__(self, imgs_dir, masks_dir, scale=1):
super().__init__(imgs_dir, masks_dir, scale, mask_suffix='_mask')

0 comments on commit 4ad8323

Please sign in to comment.