-
Notifications
You must be signed in to change notification settings - Fork 3
/
datasets.py
47 lines (36 loc) · 1.36 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/dataset.ipynb.
# %% ../nbs/dataset.ipynb 3
from __future__ import print_function, division, annotations
from .imports import *
from .utils import asnumpy
# %% auto 0
# Public names of this module: the ones a star-import of it will bind.
__all__ = ['JAXDataset', 'Dataset', 'ArrayDataset']
# %% ../nbs/dataset.ipynb 4
class Dataset:
    """Minimal PyTorch-style dataset interface.

    Both special methods below only raise ``NotImplementedError``; concrete
    datasets are expected to subclass this and override them.
    """

    def __len__(self):
        """Report the number of items; not implemented on the base class."""
        raise NotImplementedError()

    def __getitem__(self, index):
        """Fetch the item(s) at `index`; not implemented on the base class."""
        raise NotImplementedError()
# %% ../nbs/dataset.ipynb 5
class ArrayDataset(Dataset):
    """Dataset wrapping one or more arrays that share their first dimension.

    The arrays are stored as a tuple in `self.arrays`, in the order given.
    Indexing the dataset applies the same index to each stored array.
    """

    def __init__(
        self,
        *arrays: jax.Array, # Numpy array with same first dimension
        asnumpy: bool = True, # Store arrays as numpy arrays if True; otherwise store as array type of *arrays
    ):
        # Validate with an explicit raise instead of an `assert` statement:
        # asserts are removed when Python runs with -O, which would let
        # mismatched arrays through unchecked. AssertionError (and the
        # message) are kept so callers that already catch it are unaffected.
        if any(arr.shape[0] != arrays[0].shape[0] for arr in arrays):
            raise AssertionError("All arrays must have the same dimension.")
        self.arrays = tuple(arrays)
        # In this method `asnumpy` is the boolean flag; it shadows the
        # module-level `asnumpy` helper only within `__init__`.
        if asnumpy:
            self.asnumpy()

    def asnumpy(self):
        """Convert all arrays to numpy arrays."""
        # Inside a method body the bare name `asnumpy` resolves to the helper
        # imported from `.utils` at module level, not to this method.
        self.arrays = tuple(asnumpy(arr) for arr in self.arrays)

    def __len__(self):
        """Return the size of the first dimension of the first array.

        Returns 0 when the dataset was built without any arrays, rather than
        failing with an IndexError on the empty tuple.
        """
        if not self.arrays:
            return 0
        return self.arrays[0].shape[0]

    def __getitem__(self, index):
        """Index every stored array with `index`.

        The result is produced by `jax.tree_util.tree_map`, so it mirrors the
        structure of `self.arrays`.
        """
        return jax.tree_util.tree_map(lambda x: x[index], self.arrays)
# %% ../nbs/dataset.ipynb 14
# `JAXDataset` is a second name bound to the very same class object as `Dataset`.
# NOTE(review): presumably kept as a backward-compatible alias — confirm.
JAXDataset = Dataset