
Merge pull request #17929 from hawkinsp:torchloader
PiperOrigin-RevId: 570839721
jax authors committed Oct 4, 2023
2 parents 6065464 + d8a0227 commit 9e3d64a
Showing 2 changed files with 4 additions and 14 deletions.
docs/notebooks/Neural_Network_and_Data_Loading.ipynb (9 changes: 2 additions & 7 deletions)
@@ -298,17 +298,12 @@
    "outputs": [],
    "source": [
     "import numpy as np\n",
+    "from jax.tree_util import tree_map\n",
     "from torch.utils import data\n",
     "from torchvision.datasets import MNIST\n",
     "\n",
     "def numpy_collate(batch):\n",
-    "  if isinstance(batch[0], np.ndarray):\n",
-    "    return np.stack(batch)\n",
-    "  elif isinstance(batch[0], (tuple,list)):\n",
-    "    transposed = zip(*batch)\n",
-    "    return [numpy_collate(samples) for samples in transposed]\n",
-    "  else:\n",
-    "    return np.array(batch)\n",
+    "  return tree_map(np.asarray, data.default_collate(batch))\n",
     "\n",
     "class NumpyLoader(data.DataLoader):\n",
     "  def __init__(self, dataset, batch_size=1,\n",
docs/notebooks/Neural_Network_and_Data_Loading.md (9 changes: 2 additions & 7 deletions)
@@ -189,17 +189,12 @@ JAX is laser-focused on program transformations and accelerator-backed NumPy, so
 :id: 94PjXZ8y3dVF
 
 import numpy as np
+from jax.tree_util import tree_map
 from torch.utils import data
 from torchvision.datasets import MNIST
 
 def numpy_collate(batch):
-  if isinstance(batch[0], np.ndarray):
-    return np.stack(batch)
-  elif isinstance(batch[0], (tuple,list)):
-    transposed = zip(*batch)
-    return [numpy_collate(samples) for samples in transposed]
-  else:
-    return np.array(batch)
+  return tree_map(np.asarray, data.default_collate(batch))
 
 class NumpyLoader(data.DataLoader):
   def __init__(self, dataset, batch_size=1,
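For context, here is a minimal, self-contained sketch (not part of the notebook or this diff) of how the simplified collate function behaves. The ToyDataset class, array shapes, and batch size are illustrative assumptions, and torch.utils.data.default_collate is public API only in PyTorch 1.11 and newer:

import numpy as np
from jax.tree_util import tree_map
from torch.utils import data

def numpy_collate(batch):
  # default_collate stacks the samples into torch tensors, preserving any
  # tuple/list/dict structure; tree_map then converts every leaf to a numpy array.
  return tree_map(np.asarray, data.default_collate(batch))

class ToyDataset(data.Dataset):
  # Hypothetical stand-in for the MNIST dataset used in the notebook.
  def __len__(self):
    return 8

  def __getitem__(self, idx):
    return np.full((2, 2), idx, dtype=np.float32), idx  # (image, label) pair

loader = data.DataLoader(ToyDataset(), batch_size=4, collate_fn=numpy_collate)
images, labels = next(iter(loader))
print(type(images), images.shape)  # <class 'numpy.ndarray'> (4, 2, 2)
print(type(labels), labels.shape)  # <class 'numpy.ndarray'> (4,)

Compared with the hand-written recursion it replaces, this delegates batching to PyTorch's own collate logic and only performs the torch-to-numpy conversion at the leaves.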
