-
Notifications
You must be signed in to change notification settings - Fork 400
/
synthetic.py
217 lines (180 loc) · 9.97 KB
/
synthetic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Synthetic datasets used for testing, profiling, and debugging."""
from __future__ import annotations
from typing import Callable, Optional, Sequence, Union
import torch
import torch.utils.data
from PIL import Image
from torchvision.datasets import VisionDataset
from composer.core.types import MemoryFormat
from composer.utils.string_enum import StringEnum
__all__ = ['SyntheticDataType', 'SyntheticDataLabelType', 'SyntheticBatchPairDataset', 'SyntheticPILDataset']
class SyntheticDataType(StringEnum):
    """Defines the distribution of the synthetic data.

    Attributes:
        GAUSSIAN: Standard Gaussian distribution.
        SEPARABLE: Gaussian distributed, but classes will be mean-shifted for
            separability. Only supported with binary integer labels (see
            ``SyntheticBatchPairDataset``).
    """
    GAUSSIAN = 'gaussian'
    SEPARABLE = 'separable'
class SyntheticDataLabelType(StringEnum):
    """Defines the class label type of the synthetic data.

    Attributes:
        CLASSIFICATION_INT: Class labels are ints (drawn uniformly from ``[0, num_classes)``).
        CLASSIFICATION_ONE_HOT: Class labels are one-hot vectors of length ``num_classes``.
    """
    CLASSIFICATION_INT = 'classification_int'
    CLASSIFICATION_ONE_HOT = 'classification_one_hot'
class SyntheticBatchPairDataset(torch.utils.data.Dataset):
    """Emulates a dataset of provided size and shape.

    Args:
        total_dataset_size (int): The total size of the dataset to emulate.
        data_shape (List[int]): Shape of the tensor for input samples.
        num_unique_samples_to_create (int): The number of unique samples to allocate memory for.
        data_type (str or SyntheticDataType, optional): Type of synthetic data to create.
            Default: ``SyntheticDataType.GAUSSIAN``.
        label_type (str or SyntheticDataLabelType, optional): Type of synthetic data to
            create. Default: ``SyntheticDataLabelType.CLASSIFICATION_INT``.
        num_classes (int, optional): Number of classes to use. Required if
            ``SyntheticDataLabelType`` is ``CLASSIFICATION_INT``
            or ``CLASSIFICATION_ONE_HOT``. Default: ``None``.
        label_shape (List[int], optional): Shape of the tensor for each sample label.
            Default: ``None``.
        device (str): Device to store the sample pool. Set to ``'cuda'`` to store samples
            on the GPU and eliminate PCI-e bandwidth with the dataloader. Set to ``'cpu'``
            to move data between host memory and the gpu on every batch. Default:
            ``'cpu'``.
        memory_format (MemoryFormat, optional): Memory format for the sample pool.
            Default: ``MemoryFormat.CONTIGUOUS_FORMAT``.
        transform (Callable, optional): Transform(s) to apply to data. Default: ``None``.

    Raises:
        ValueError: If a classification ``label_type`` is used without ``num_classes > 0``.
    """

    def __init__(self,
                 *,
                 total_dataset_size: int,
                 data_shape: Sequence[int],
                 num_unique_samples_to_create: int = 100,
                 data_type: Union[str, SyntheticDataType] = SyntheticDataType.GAUSSIAN,
                 label_type: Union[str, SyntheticDataLabelType] = SyntheticDataLabelType.CLASSIFICATION_INT,
                 num_classes: Optional[int] = None,
                 label_shape: Optional[Sequence[int]] = None,
                 device: str = 'cpu',
                 memory_format: Union[str, MemoryFormat] = MemoryFormat.CONTIGUOUS_FORMAT,
                 transform: Optional[Callable] = None):
        self.total_dataset_size = total_dataset_size
        self.data_shape = data_shape
        self.num_unique_samples_to_create = num_unique_samples_to_create
        # Coerce strings to the corresponding enums; raises ValueError on unknown values.
        self.data_type = SyntheticDataType(data_type)
        self.label_type = SyntheticDataLabelType(label_type)
        self.num_classes = num_classes
        self.label_shape = label_shape
        self.device = device
        self.memory_format = MemoryFormat(memory_format)
        self.transform = transform

        self._validate_label_inputs(label_type=self.label_type,
                                    num_classes=self.num_classes,
                                    label_shape=self.label_shape)

        # The synthetic sample pool; allocated lazily on the first __getitem__ call.
        self.input_data = None
        self.input_target = None

    def _validate_label_inputs(self, label_type: SyntheticDataLabelType, num_classes: Optional[int],
                               label_shape: Optional[Sequence[int]]):
        # Every classification label type needs a positive class count to sample from.
        if label_type in (SyntheticDataLabelType.CLASSIFICATION_INT, SyntheticDataLabelType.CLASSIFICATION_ONE_HOT):
            if num_classes is None or num_classes <= 0:
                raise ValueError('classification label_types require num_classes > 0')

    def __len__(self) -> int:
        return self.total_dataset_size

    def __getitem__(self, idx: int):
        # Samples repeat cyclically: only ``num_unique_samples_to_create`` distinct
        # samples are materialized regardless of the emulated dataset size.
        idx = idx % self.num_unique_samples_to_create
        if self.input_data is None:
            # Generating data on the first call to __getitem__ so that data is stored on the correct gpu,
            # after DeviceSingleGPU calls torch.cuda.set_device
            # This does mean that the first batch will be slower
            assert self.input_target is None
            input_data = torch.randn(self.num_unique_samples_to_create, *self.data_shape, device=self.device)
            input_data = torch.clone(input_data)  # allocate actual memory
            input_data = input_data.contiguous(memory_format=getattr(torch, self.memory_format.value))

            if self.label_type == SyntheticDataLabelType.CLASSIFICATION_ONE_HOT:
                assert self.num_classes is not None
                # All one-hot labels point at class 0.
                input_target = torch.zeros((self.num_unique_samples_to_create, self.num_classes), device=self.device)
                input_target[:, 0] = 1.0
            elif self.label_type == SyntheticDataLabelType.CLASSIFICATION_INT:
                assert self.num_classes is not None
                if self.label_shape:
                    label_batch_shape = (self.num_unique_samples_to_create, *self.label_shape)
                else:
                    label_batch_shape = (self.num_unique_samples_to_create,)
                input_target = torch.randint(0, self.num_classes, label_batch_shape, device=self.device)
            else:
                # Bug fix: report the offending *label* type, not the data type.
                raise ValueError(f'Unsupported label type {self.label_type}')

            # If separable, force the positive examples to have a higher mean than the negative examples
            if self.data_type == SyntheticDataType.SEPARABLE:
                assert self.label_type == SyntheticDataLabelType.CLASSIFICATION_INT, \
                    'SyntheticDataType.SEPARABLE requires integer classes.'
                assert torch.max(input_target) == 1 and torch.min(input_target) == 0, \
                    'SyntheticDataType.SEPARABLE only supports binary labels'
                # Make positive examples have mean = 3 and negative examples have mean = -3
                # so they are easier to separate with a classifier
                input_data[input_target == 0] -= 3
                input_data[input_target == 1] += 3

            self.input_data = input_data
            self.input_target = input_target

        assert self.input_target is not None

        if self.transform is not None:
            return self.transform(self.input_data[idx]), self.input_target[idx]
        else:
            return self.input_data[idx], self.input_target[idx]
class SyntheticPILDataset(VisionDataset):
    """Similar to :class:`SyntheticBatchPairDataset`, but yields samples of type :class:`~PIL.Image.Image` and supports
    dataset transformations.

    Args:
        total_dataset_size (int): The total size of the dataset to emulate.
        data_shape (List[int]): Shape of the tensor for input samples.
        num_unique_samples_to_create (int): The number of unique samples to allocate memory for.
        data_type (str or SyntheticDataType, optional), Type of synthetic data to create.
            Default: ``SyntheticDataType.GAUSSIAN``.
        label_type (str or SyntheticDataLabelType, optional), Type of synthetic data to
            create. Default: ``SyntheticDataLabelType.CLASSIFICATION_INT``.
        num_classes (int, optional): Number of classes to use. Required if
            ``SyntheticDataLabelType`` is ``CLASSIFICATION_INT``
            or ``CLASSIFICATION_ONE_HOT``. Default: ``None``.
        label_shape (List[int], optional): Shape of the tensor for each sample label.
            Default: ``None``.
        transform (Callable, optional): Transform(s) to apply to data. Default: ``None``.
    """

    def __init__(self,
                 *,
                 total_dataset_size: int,
                 data_shape: Sequence[int] = (64, 64, 3),
                 num_unique_samples_to_create: int = 100,
                 data_type: Union[str, SyntheticDataType] = SyntheticDataType.GAUSSIAN,
                 label_type: Union[str, SyntheticDataLabelType] = SyntheticDataLabelType.CLASSIFICATION_INT,
                 num_classes: Optional[int] = None,
                 label_shape: Optional[Sequence[int]] = None,
                 transform: Optional[Callable] = None):
        super().__init__(root='', transform=transform)
        # Delegate sample generation to the tensor-based dataset; this class only
        # converts its output to PIL images.
        self._dataset = SyntheticBatchPairDataset(
            total_dataset_size=total_dataset_size,
            data_shape=data_shape,
            data_type=data_type,
            num_unique_samples_to_create=num_unique_samples_to_create,
            label_type=label_type,
            num_classes=num_classes,
            label_shape=label_shape,
        )

    def __len__(self) -> int:
        return len(self._dataset)

    def __getitem__(self, idx: int):
        tensor_sample, target = self._dataset[idx]
        arr = tensor_sample.numpy()
        # Shift and scale to be [0, 255]
        arr = arr - arr.min()
        arr = (arr * (255 / arr.max())).astype('uint8')
        sample = Image.fromarray(arr)
        if self.transform is None:
            return sample, target
        return self.transform(sample), target