Skip to content

Commit

Permalink
[minor] .gitignore data/ cached by tests (#995)
Browse files Browse the repository at this point in the history
  • Loading branch information
crutcher committed May 31, 2022
1 parent e7602a4 commit b3a4c68
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 1 deletion.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,6 @@ env.bak/
venv.bak/
.vscode/*
*.DS_Store

# Data generated by tests
cached_datasets/
2 changes: 2 additions & 0 deletions fair_dev/common_paths.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
"Common cache root for torchvision.datasets and others."
DATASET_CACHE_ROOT = "cached_datasets"
8 changes: 7 additions & 1 deletion tests/optim/test_layerwise_gradient_scaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import torchvision
import torchvision.transforms as transforms

from fair_dev.common_paths import DATASET_CACHE_ROOT
from fairscale.optim.layerwise_gradient_scaler import LayerwiseGradientScaler
from fairscale.utils.testing import skip_a_test_if_in_CI

Expand Down Expand Up @@ -71,7 +72,12 @@ def load_data(model_type: str) -> Union[DataLoader, Tuple[Any, Any]]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# TODO: we should NOT do this download over and over again during test.
train_ds = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
train_ds = torchvision.datasets.CIFAR10(
root=DATASET_CACHE_ROOT,
train=True,
download=True,
transform=transform,
)
train_ds_loader = torch.utils.data.DataLoader(train_ds, batch_size=128, shuffle=False, num_workers=2)

image, _ = train_ds[0]
Expand Down

0 comments on commit b3a4c68

Please sign in to comment.