Skip to content

Commit

Permalink
Use training for batch_stats
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremy Howard committed Mar 23, 2019
1 parent 1e6e4a1 commit 042f845
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Expand Up @@ -18,6 +18,7 @@ of that change.

### Fixed:

- Default to using training set for `batch_stats` instead of validation


## 1.0.50 (2019-03-19)
Expand Down
3 changes: 1 addition & 2 deletions fastai/vision/data.py
Expand Up @@ -164,10 +164,9 @@ def single_from_classes(path:Union[Path, str], classes:Collection[str], ds_tfms:
sd = ImageList([], path=path, ignore_empty=True).split_none()
return sd.label_const(0, label_cls=CategoryList, classes=classes).transform(ds_tfms, **kwargs).databunch()

def batch_stats(self, funcs:Collection[Callable]=None)->Tensor:
def batch_stats(self, funcs:Collection[Callable]=None, ds_type:DatasetType=DatasetType.Train)->Tensor:
"Grab a batch of data and call reduction function `func` per channel"
funcs = ifnone(funcs, [torch.mean,torch.std])
ds_type = DatasetType.Valid if self.valid_dl else DatasetType.Train
x = self.one_batch(ds_type=ds_type, denorm=False)[0].cpu()
return [func(channel_view(x), 1) for func in funcs]

Expand Down

0 comments on commit 042f845

Please sign in to comment.