Skip to content

Commit

Permalink
Drop "data/": "./data/imagenet" -> "./imagenet'
Browse files Browse the repository at this point in the history
  • Loading branch information
sublee committed Sep 20, 2019
1 parent 8bc2a22 commit 1c0a887
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
6 changes: 3 additions & 3 deletions examples/resnet101_accuracy_benchmark/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@ environment only for this benchmark:
$ pip install -r requirements.txt
```

Prepare ImageNet dataset at `./data/imagenet`:
Prepare ImageNet dataset at `./imagenet`:
```sh
$ python -c "import torchvision; torchvision.datasets.ImageNet('./data/imagenet', split='train', download=True)"
$ python -c "import torchvision; torchvision.datasets.ImageNet('./data/imagenet', split='val', download=True)"
$ python -c "import torchvision; torchvision.datasets.ImageNet('./imagenet', split='train', download=True)"
$ python -c "import torchvision; torchvision.datasets.ImageNet('./imagenet', split='val', download=True)"
```

Then, run each benchmark:
Expand Down
12 changes: 4 additions & 8 deletions examples/resnet101_accuracy_benchmark/main.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""ResNet-101 Accuracy Benchmark"""
import os
import platform
import time
from typing import Any, Callable, Dict, List, Optional, Tuple, cast
Expand Down Expand Up @@ -78,10 +77,7 @@ def pipeline8_4k(model: nn.Module, devices: List[int]) -> Stuffs:
}


def dataloaders(base: str,
batch_size: int,
num_workers: int = 32,
) -> Tuple[DataLoader, DataLoader]:
def dataloaders(batch_size: int, num_workers: int = 32) -> Tuple[DataLoader, DataLoader]:
num_workers = num_workers if batch_size <= 4096 else num_workers // 2

post_transforms = torchvision.transforms.Compose([
Expand All @@ -90,7 +86,7 @@ def dataloaders(base: str,
])

train_dataset = torchvision.datasets.ImageNet(
root=os.path.join(base, 'imagenet'),
root='imagenet',
split='train',
transform=torchvision.transforms.Compose([
torchvision.transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
Expand All @@ -99,7 +95,7 @@ def dataloaders(base: str,
])
)
test_dataset = torchvision.datasets.ImageNet(
root=os.path.join(base, 'imagenet'),
root='imagenet',
split='val',
transform=torchvision.transforms.Compose([
torchvision.transforms.Resize(256),
Expand Down Expand Up @@ -209,7 +205,7 @@ def cli(ctx: click.Context,
ctx.fail(str(exc))

# Prepare dataloaders.
train_dataloader, valid_dataloader = dataloaders('data', batch_size)
train_dataloader, valid_dataloader = dataloaders(batch_size)

# Optimizer with LR scheduler
steps = len(train_dataloader)
Expand Down

0 comments on commit 1c0a887

Please sign in to comment.