Skip to content

Commit

Permalink
Merge 8477d65 into 21da0c0
Browse files Browse the repository at this point in the history
  • Loading branch information
droidadroit committed Jan 15, 2019
2 parents 21da0c0 + 8477d65 commit ce0be0b
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions autokeras/preprocessor.py
Expand Up @@ -263,13 +263,13 @@ def _transform(self, compose_list, data, targets):
class DataTransformerMlp(DataTransformer):
def __init__(self, data):
super().__init__()
self.max_val = data.max()
data = data / self.max_val
self.mean = np.mean(data, axis=0, keepdims=True).flatten()
self.std = np.std(data, axis=0, keepdims=True).flatten()
self.mean = np.mean(data, axis=0)
self.std = np.std(data, axis=0)

def transform_train(self, data, targets=None, batch_size=None):
dataset = self._transform([Normalize(torch.Tensor(self.mean), torch.Tensor(self.std))], data, targets)
data = (data - self.mean) / self.std
data = np.nan_to_num(data)
dataset = self._transform([], data, targets)

if batch_size is None:
batch_size = Constant.MAX_BATCH_SIZE
Expand All @@ -281,7 +281,6 @@ def transform_test(self, data, target=None, batch_size=None):
return self.transform_train(data, targets=target, batch_size=batch_size)

def _transform(self, compose_list, data, targets):
data = data / self.max_val
args = [0, len(data.shape) - 1] + list(range(1, len(data.shape) - 1))
data = torch.Tensor(data.transpose(*args))
data_transforms = Compose(compose_list)
Expand Down

0 comments on commit ce0be0b

Please sign in to comment.