# TensorFlow Datasets for Croissant 🥐



[TensorFlow Datasets](https://www.tensorflow.org/datasets/overview) (in short, TFDS) is an established library to handle downloading and preparing data efficiently and deterministically.

TFDS is framework-agnostic: it can expose a dataset as a `tf.data.Dataset`, as `np.array`s, or as an [`ArrayRecord`](https://github.com/google/array_record) data source, for use with TensorFlow, JAX, PyTorch, and other machine learning frameworks.

TFDS has recently introduced a `CroissantBuilder`, which defines a TFDS dataset based on a Croissant 🥐 metadata file.

## Setup



Let's install and import the needed dependencies:

In [2]:
!pip install array_record
!pip install tfds-nightly

!pip install mlcroissant@git+https://github.com/mlcommons/croissant#subdirectory=python/mlcroissant

!pip install datasets
!pip install GitPython
!pip install Pillow

import os
os.environ.pop('TFDS_DATA_DIR', None)

import tensorflow_datasets as tfds
from tensorflow_datasets.core.dataset_builders.croissant_builder import CroissantBuilder


To initialize a `CroissantBuilder` in TFDS, we need a Croissant 🥐 file describing a dataset.

We can use the `from_huggingface_to_croissant` command of the `mlcroissant` library to convert a Hugging Face 🤗 dataset to a Croissant JSON-LD file.

Let's create a Croissant metadata file for [fashion_mnist](https://huggingface.co/datasets/fashion_mnist), a popular dataset for computer vision.

In [3]:
!mlcroissant from_huggingface_to_croissant --dataset fashion_mnist

Downloading builder script: 100% 4.83k/4.83k [00:00<00:00, 17.0MB/s]
Downloading metadata: 100% 3.13k/3.13k [00:00<00:00, 10.9MB/s]
Downloading readme: 100% 8.85k/8.85k [00:00<00:00, 23.6MB/s]
I1127 13:47:02.165981 140028632248320 from_huggingface_to_croissant.py:189] Done. Wrote Croissant JSON-LD to /tmp/croissant_1701092817.7705376.json


## `CroissantBuilder` in TFDS

Given the Croissant file, we create a TFDS `CroissantBuilder` for the `fashion_mnist` dataset.

A `CroissantBuilder` takes as input a Croissant 🥐 file and a list of `RecordSet` names to generate. Each `RecordSet` corresponds to a separate [`BuilderConfig`](https://www.tensorflow.org/datasets/api_docs/python/tfds/core/BuilderConfig).
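To decide which `RecordSet` names to pass, it can help to look inside the Croissant file first. The snippet below is a minimal sketch using only the standard library. The inline JSON is a hypothetical, heavily trimmed stand-in for a real Croissant document (a generated file also carries keys such as `@context` and `distribution`), and `list_record_sets` is an illustrative helper, not part of `mlcroissant` or TFDS.

```python
import json

# Hypothetical, trimmed Croissant document used only for illustration.
croissant_json = """
{
  "@type": "sc:Dataset",
  "name": "fashion_mnist",
  "recordSet": [
    {
      "@type": "cr:RecordSet",
      "name": "default",
      "field": [
        {"@type": "cr:Field", "name": "image"},
        {"@type": "cr:Field", "name": "label"}
      ]
    }
  ]
}
"""


def as_list(value):
  """JSON-LD may hold a single object where a list is expected."""
  return value if isinstance(value, list) else [value]


def list_record_sets(metadata):
  """Maps each RecordSet name to the names of its fields."""
  result = {}
  for record_set in as_list(metadata.get("recordSet", [])):
    fields = as_list(record_set.get("field", []))
    result[record_set["name"]] = [field["name"] for field in fields]
  return result


metadata = json.loads(croissant_json)
record_sets = list_record_sets(metadata)
print(record_sets)  # {'default': ['image', 'label']}
```

For a file on disk, the same helper could be applied to `json.load(open(path))`. The exact field names in a generated file may differ from the ones assumed here.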

In [5]:
fashion_croissant_path = "/tmp/croissant_1701092817.7705376.json"

fashion_mnist_builder = CroissantBuilder(
    file=fashion_croissant_path,
    record_set_names=["default"],
    file_format="array_record",
)

  -  [dataset(fashion_mnist)] Property "https://schema.org/license" is recommended, but does not exist.


Our `CroissantBuilder` uses the information contained in the Croissant 🥐 file to initialize the TFDS dataset's [documentation](https://www.tensorflow.org/datasets/api_docs/python/tfds/core/DatasetInfo), which we can explore through the [`DatasetBuilder.info`](https://www.tensorflow.org/datasets/api_docs/python/tfds/core/DatasetBuilder#info) property:

In [7]:
print(f"Dataset's description:\n{fashion_mnist_builder.info.description}\n")
print(f"Dataset's citation:\n{fashion_mnist_builder.info.citation}\n")
print(f"Dataset's features:\n{fashion_mnist_builder.info.features}")

# ...

Dataset's description:
Fashion-MNIST is a dataset of Zalando's article images—consisting of a training set of
60,000 examples and a test set of 10,000 examples. Each example is a 28x28 grayscale image,
associated with a label from 10 classes. We intend Fashion-MNIST to serve as a direct drop-in
replacement for the original MNIST dataset for benchmarking machine learning algorithms.
It shares the same image size and structure of training and testing splits.

Dataset's citation:
@article{DBLP:journals/corr/abs-1708-07747,
  author    = {Han Xiao and
               Kashif Rasul and
               Roland Vollgraf},
  title     = {Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning
               Algorithms},
  journal   = {CoRR},
  volume    = {abs/1708.07747},
  year      = {2017},
  url       = {http://arxiv.org/abs/1708.07747},
  archivePrefix = {arXiv},
  eprint    = {1708.07747},
  timestamp = {Mon, 13 Aug 2018 16:47:27 +0200},
  biburl    = {https://dblp.org/rec/bi

We can now generate the TFDS dataset:

In [8]:
fashion_mnist_builder.download_and_prepare()
ds = fashion_mnist_builder.as_data_source()

Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/fashion_mnist/default/1.0.0...


Generating splits...:   0%|          | 0/1 [00:00<?, ? splits/s]

Generating default examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/fashion_mnist/default/1.0.0.incomplete7DCRNY/fashion_mnist-default.array_r…

Dataset fashion_mnist downloaded and prepared to /root/tensorflow_datasets/fashion_mnist/default/1.0.0. Subsequent calls will reuse this data.


Called without a `split` argument, `as_data_source()` returns a dictionary that maps each split name to its data source:

In [9]:
ds

{'default': DataSource(name=fashion_mnist, split='default', decoders=None)}

## Use with PyTorch

TFDS data sources can be used as regular map-style datasets, for example to train and test a PyTorch model:

In [13]:
!pip install torch
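As a rough illustration of what "map-style" means here, the sketch below uses plain Python: a map-style dataset only needs `__len__` and `__getitem__`, and a loader draws indices, looks up the examples, and groups them into batches. `ToyDataSource` and `iterate_batches` are hypothetical stand-ins written for this example, not TFDS or PyTorch APIs, and the dataset size of 10,000 is arbitrary.

```python
import math
import random


class ToyDataSource:
  """Hypothetical stand-in for a map-style data source."""

  def __init__(self, num_examples):
    self._num_examples = num_examples

  def __len__(self):
    return self._num_examples

  def __getitem__(self, index):
    if not 0 <= index < self._num_examples:
      raise IndexError(index)
    # Each example is a dict keyed by feature name (toy values here).
    return {'image': [index % 256] * 4, 'label': index % 10}


def iterate_batches(source, num_samples, batch_size, seed=0):
  """Yields batches of examples drawn at random without replacement."""
  rng = random.Random(seed)
  indices = rng.sample(range(len(source)), num_samples)
  for start in range(0, num_samples, batch_size):
    yield [source[i] for i in indices[start:start + batch_size]]


source = ToyDataSource(num_examples=10_000)
batches = list(iterate_batches(source, num_samples=5_000, batch_size=128))

print(len(batches))      # 40, i.e. math.ceil(5_000 / 128)
print(len(batches[-1]))  # 8 examples left over for the last batch
```

With 5,000 sampled examples and a batch size of 128, an epoch has 40 batches, the last of which is smaller than the others.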

Using PyTorch, we train a simple logistic regression on 5,000 examples sampled at random from the first 70% of the `default` split, then evaluate it on the remaining 30%:

In [18]:
from tqdm import tqdm
import torch

# Define the splits, the sampler and the loaders.
train_split = fashion_mnist_builder.as_data_source(split='default[:70%]')
test_split = fashion_mnist_builder.as_data_source(split='default[70%:]')

batch_size = 128
train_sampler = torch.utils.data.RandomSampler(train_split, num_samples=5_000)

train_loader = torch.utils.data.DataLoader(
    train_split,
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    test_split,
    sampler=None,
    batch_size=batch_size,
)

# The feature shape declared in the metadata may contain unknown dimensions
# (for example `(None, None, 3)`), so infer the flattened input size from one
# example instead. This assumes each image is returned as a NumPy array.
num_features = train_split[0]['image'].size
num_classes = 10


class LinearClassifier(torch.nn.Module):
  def __init__(self, num_features, num_classes):
    super().__init__()
    self.classifier = torch.nn.Linear(num_features, num_classes)

  def forward(self, image):
    # Flatten each image to a vector and cast to float for the linear layer.
    image = image.reshape(image.shape[0], -1).to(torch.float32)
    return self.classifier(image)

model = LinearClassifier(num_features, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
num_correct = 0
# Gradients are not needed for evaluation.
with torch.no_grad():
  for example in tqdm(test_loader):
    image, label = example['image'], example['label']
    prediction = model(image)
    num_examples += image.shape[0]
    predicted_label = prediction.argmax(dim=1)
    num_correct += (predicted_label == label).sum().item()
print(f'\nAccuracy: {num_correct / num_examples * 100:.2f}%')
