Skip to content

Commit

Permalink
Merge 7491638 into 68e0442
Browse files Browse the repository at this point in the history
  • Loading branch information
haifeng-jin committed Aug 31, 2018
2 parents 68e0442 + 7491638 commit a9f93a1
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions autokeras/preprocessor.py
Expand Up @@ -71,23 +71,31 @@ def transform_train(self, data, targets=None, batch_size=None):
common_list = [Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))]
compose_list = augment_list + common_list

return self._transform(batch_size, compose_list, data, targets)
dataset = self._transform(compose_list, data, targets)

if batch_size is None:
batch_size = Constant.MAX_BATCH_SIZE
batch_size = min(len(data), batch_size)

return DataLoader(dataset, batch_size=batch_size, shuffle=True)

def transform_test(self, data, targets=None, batch_size=None):
common_list = [Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))]
compose_list = common_list

return self._transform(batch_size, compose_list, data, targets)
dataset = self._transform(compose_list, data, targets)

def _transform(self, batch_size, compose_list, data, targets):
if batch_size is None:
batch_size = Constant.MAX_BATCH_SIZE
batch_size = min(len(data), batch_size)

return DataLoader(dataset, batch_size=batch_size, shuffle=False)

def _transform(self, compose_list, data, targets):
data = data / self.max_val
data = torch.Tensor(data.transpose(0, 3, 1, 2))
data_transforms = Compose(compose_list)
dataset = MultiTransformDataset(data, targets, data_transforms)
return DataLoader(dataset, batch_size=batch_size, shuffle=True)
return MultiTransformDataset(data, targets, data_transforms)


class MultiTransformDataset(Dataset):
Expand Down

0 comments on commit a9f93a1

Please sign in to comment.