-
Notifications
You must be signed in to change notification settings - Fork 56
/
dataloaders.py
26 lines (19 loc) · 1 KB
/
dataloaders.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
def get_dataset_name(mode: str) -> str:
    """Map a dataset mode string to the name of its dataset class.

    Args:
        mode: One of ``"ade20k"``, ``"cityscapes"`` or ``"coco"``.

    Returns:
        The corresponding class name: ``"Ade20kDataset"``,
        ``"CityscapesDataset"`` or ``"CocoStuffDataset"``.

    Raises:
        ValueError: If ``mode`` is not one of the supported modes.
    """
    if mode == "ade20k":
        return "Ade20kDataset"
    if mode == "cityscapes":
        return "CityscapesDataset"
    if mode == "coco":
        return "CocoStuffDataset"
    # The exception must actually be raised; merely constructing it would let
    # the function fall through and return None for an unknown mode.
    raise ValueError("There is no such dataset regime as %s" % mode)
def get_dataloaders(opt):
    """Create the training and validation data loaders for ``opt.dataset_mode``.

    The dataset class is looked up dynamically: ``get_dataset_name`` gives the
    class name, and the class is loaded from the module of the same name inside
    the ``dataloaders`` package.

    Args:
        opt: Options object. This function reads ``opt.dataset_mode`` and
            ``opt.batch_size``; ``opt`` is also passed unchanged to the
            dataset constructor.

    Returns:
        A ``(dataloader_train, dataloader_val)`` tuple of
        ``torch.utils.data.DataLoader`` objects. The training loader shuffles
        and drops the last incomplete batch; the validation loader does
        neither.
    """
    # Function-scope import so this block does not depend on the file header.
    import importlib

    dataset_name = get_dataset_name(opt.dataset_mode)

    # import_module returns the submodule itself, whereas
    # __import__("pkg.sub") returns the top-level package and forces a second
    # lookup to reach the submodule.
    module = importlib.import_module("dataloaders." + dataset_name)
    dataset_cls = getattr(module, dataset_name)

    # The same class backs both splits; ``for_metrics`` selects which one.
    dataset_train = dataset_cls(opt, for_metrics=False)
    dataset_val = dataset_cls(opt, for_metrics=True)
    print("Created %s, size train: %d, size val: %d" % (dataset_name, len(dataset_train), len(dataset_val)))

    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=opt.batch_size,
        shuffle=True,
        drop_last=True,
    )
    dataloader_val = torch.utils.data.DataLoader(
        dataset_val,
        batch_size=opt.batch_size,
        shuffle=False,
        drop_last=False,
    )
    return dataloader_train, dataloader_val