Skip to content

Commit

Permalink
Add classy OSS CIFAR (#35)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: fairinternal/ClassyVision#35

Pull Request resolved: #206

OSS Classy CIFAR. Simple wrapper around torchvision dataset. The classy dataset wrapper provides batching, transforms, shuffling, and restricting the size of the dataset.

Side-effects: Changed the name of the fb internal classy vision dataset to fb_cifar* and changed all naming schemes to match this.

I also added a unit test that relies on some of the torchvision testing utilities, which are not packaged with torchvision; as such, I pulled them in via an fbcode import.

Differential Revision: D18429440

fbshipit-source-id: f2a0bfbec7671c415871e40a18c8187aaf5251e8
  • Loading branch information
Aaron Adcock authored and facebook-github-bot committed Nov 21, 2019
1 parent 1859669 commit 4a0d784
Showing 1 changed file with 71 additions and 0 deletions.
71 changes: 71 additions & 0 deletions classy_vision/dataset/classy_cifar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Union

from classy_vision.dataset import ClassyDataset, register_dataset
from classy_vision.dataset.transforms import ClassyTransform, build_transforms
from torchvision.datasets.cifar import CIFAR10, CIFAR100


class CIFARDataset(ClassyDataset):
    """Base ClassyDataset wrapper around torchvision's CIFAR10 / CIFAR100.

    This class is not usable on its own: a subclass must set ``_CIFAR_TYPE``
    to ``"cifar10"`` or ``"cifar100"`` to select which torchvision dataset
    the constructor builds.
    """

    # Overridden by subclasses; the None default makes the base class fail
    # fast in __init__.
    _CIFAR_TYPE = None

    def __init__(
        self,
        split: Optional[str],
        batchsize_per_replica: int,
        shuffle: bool,
        transform: Optional[Union[ClassyTransform, Callable]],
        num_samples: Optional[int],
        root: str,
        download: Optional[bool] = None,
    ):
        """Build the torchvision CIFAR dataset and hand it to ClassyDataset.

        Args:
            split: Split name. The torchvision dataset is created with
                ``train=True`` only when this equals ``"train"``; any other
                value (including None) results in ``train=False``.
            batchsize_per_replica: Forwarded unchanged to ClassyDataset.
            shuffle: Forwarded unchanged to ClassyDataset.
            transform: Forwarded unchanged to ClassyDataset.
            num_samples: Forwarded unchanged to ClassyDataset.
            root: Forwarded to the torchvision dataset as ``root``.
            download: Forwarded to the torchvision dataset as ``download``.

        Raises:
            AssertionError: If ``_CIFAR_TYPE`` is neither ``"cifar10"`` nor
                ``"cifar100"`` (e.g. the base class is instantiated directly).
        """
        # An explicit raise (rather than an `assert` statement) keeps this
        # check alive under `python -O`; otherwise `dataset` would be left
        # unbound below. AssertionError is kept so existing callers see the
        # same exception type and message as before.
        if self._CIFAR_TYPE == "cifar10":
            dataset_cls = CIFAR10
        elif self._CIFAR_TYPE == "cifar100":
            dataset_cls = CIFAR100
        else:
            raise AssertionError(
                "CIFARDataset must be subclassed and a valid _CIFAR_TYPE provided"
            )

        dataset = dataset_cls(root=root, train=(split == "train"), download=download)

        super().__init__(
            dataset, split, batchsize_per_replica, shuffle, transform, num_samples
        )

    @classmethod
    def from_config(cls, config):
        """Instantiate the dataset from a config mapping.

        The transform config, batch size, shuffle flag and sample count come
        from ``cls.parse_config(config)``; ``"split"``, ``"root"`` and
        ``"download"`` are read with ``config.get`` and are therefore None
        when the key is absent.

        Args:
            config: Mapping-like object supporting ``.get``.

        Returns:
            An instance of ``cls`` built from the parsed values.
        """
        (
            transform_config,
            batchsize_per_replica,
            shuffle,
            num_samples,
        ) = cls.parse_config(config)
        split = config.get("split")
        root = config.get("root")
        download = config.get("download")

        transform = build_transforms(transform_config)
        return cls(
            split=split,
            batchsize_per_replica=batchsize_per_replica,
            shuffle=shuffle,
            transform=transform,
            num_samples=num_samples,
            root=root,
            download=download,
        )


@register_dataset("cifar10")
class CIFAR10Dataset(CIFARDataset):
    """CIFARDataset variant registered under the name ``"cifar10"``.

    Setting ``_CIFAR_TYPE`` to ``"cifar10"`` makes the inherited constructor
    build torchvision's ``CIFAR10`` dataset.
    """

    _CIFAR_TYPE = "cifar10"


@register_dataset("cifar100")
class CIFAR100Dataset(CIFARDataset):
    """CIFARDataset variant registered under the name ``"cifar100"``.

    Setting ``_CIFAR_TYPE`` to ``"cifar100"`` makes the inherited constructor
    build torchvision's ``CIFAR100`` dataset.
    """

    _CIFAR_TYPE = "cifar100"

0 comments on commit 4a0d784

Please sign in to comment.