# Train a model with Croissant 🥐, Hugging Face 🤗 and TFDS

[TensorFlow Datasets](https://www.tensorflow.org/datasets/overview) (TFDS for short) is an established library for downloading and preparing data efficiently and deterministically.

TFDS is framework-agnostic: it can generate datasets by constructing a `tf.data.Dataset`, an `np.array` or an [`ArrayRecord`](https://github.com/google/array_record) data source, for use with TensorFlow, JAX, PyTorch, and other machine learning frameworks.

TFDS has recently introduced a `CroissantBuilder`, which defines a TFDS dataset based on a Croissant 🥐 metadata file.

## Setup



Let's install and import the needed dependencies:

In [None]:
%%capture --no-display
# Install mlcroissant from the source
!pip install "git+https://github.com/mlcommons/croissant.git@${GITHUB_HEAD_REF:-main}#subdirectory=python/mlcroissant&egg=mlcroissant[dev]"
!pip install array_record
!pip install tfds-nightly
!pip install tensorflow
!pip install torch
!apt-get install tree

In [1]:
%%capture --no-display
import json
import os

from etils import epath
import mlcroissant as mlc
import requests
import tensorflow_datasets as tfds
import torch
from tqdm import tqdm

local_croissant_file = epath.Path("/tmp/croissant.json")
data_dir = "/tmp/croissant"

## Download the Croissant JSON-LD file

To initialize a `CroissantBuilder` in TFDS, we need a Croissant 🥐 file describing a dataset.
In this notebook, we will create a TFDS `CroissantBuilder` for [fashion_mnist](https://huggingface.co/datasets/fashion_mnist), a popular dataset for computer vision.

In [2]:
api_url = "https://datasets-server.huggingface.co/croissant?dataset=fashion_mnist"

# Download the JSON and write it to `local_croissant_file`.
response = requests.get(api_url, timeout=30)
response.raise_for_status()  # Fail early rather than saving an error payload.
jsonld = json.dumps(response.json(), indent=2)
with local_croissant_file.open("w") as f:
  f.write(jsonld)
print(jsonld)

{
  "@context": {
    "@language": "en",
    "@vocab": "https://schema.org/",
    "column": "ml:column",
    "data": {
      "@id": "ml:data",
      "@type": "@json"
    },
    "dataType": {
      "@id": "ml:dataType",
      "@type": "@vocab"
    },
    "extract": "ml:extract",
    "field": "ml:field",
    "fileProperty": "ml:fileProperty",
    "format": "ml:format",
    "includes": "ml:includes",
    "isEnumeration": "ml:isEnumeration",
    "jsonPath": "ml:jsonPath",
    "ml": "http://mlcommons.org/schema/",
    "parentField": "ml:parentField",
    "path": "ml:path",
    "recordSet": "ml:recordSet",
    "references": "ml:references",
    "regex": "ml:regex",
    "repeated": "ml:repeated",
    "replace": "ml:replace",
    "sc": "https://schema.org/",
    "separator": "ml:separator",
    "source": "ml:source",
    "subField": "ml:subField",
    "transform": "ml:transform"
  },
  "@type": "sc:Dataset",
  "name": "fashion_mnist",
  "description": "fashion_mnist dataset hosted on Hugging F

## Build the TFDS dataset

A `CroissantBuilder` takes as input a Croissant 🥐 file and a list of `RecordSet` names to generate. Each `RecordSet` corresponds to a separate [`BuilderConfig`](https://www.tensorflow.org/datasets/api_docs/python/tfds/core/BuilderConfig).

In [3]:
import tensorflow_datasets as tfds

builder = tfds.core.dataset_builders.CroissantBuilder(
    jsonld=local_croissant_file,
    record_set_names=["record_set_fashion_mnist"],
    file_format='array_record',
    data_dir=data_dir,
)

  -  [dataset(fashion_mnist)] Property "https://schema.org/citation" is recommended, but does not exist.
  -  [dataset(fashion_mnist)] Property "https://schema.org/license" is recommended, but does not exist.
  -  [dataset(fashion_mnist)] Property "https://schema.org/version" is recommended, but does not exist.
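
The names passed as `record_set_names` have to match `RecordSet` names declared in the Croissant 🥐 file. As a minimal sketch of how one might look them up, the snippet below parses a JSON-LD document and maps each record set to its fields. `SAMPLE_JSONLD` is a hypothetical, heavily trimmed stand-in for the downloaded file (the field names in it are assumptions), and `list_record_sets` is an illustrative helper, not part of `mlcroissant` or TFDS.

```python
import json

# Hypothetical, heavily trimmed stand-in for the downloaded Croissant file.
# Only the keys read below are included; the real file has many more.
SAMPLE_JSONLD = """
{
  "@type": "sc:Dataset",
  "name": "fashion_mnist",
  "recordSet": [
    {
      "@type": "ml:RecordSet",
      "name": "record_set_fashion_mnist",
      "field": [
        {"@type": "ml:Field", "name": "image"},
        {"@type": "ml:Field", "name": "label"}
      ]
    }
  ]
}
"""


def list_record_sets(jsonld_text):
  """Maps each RecordSet name to the names of its fields."""
  document = json.loads(jsonld_text)
  return {
      record_set["name"]: [f["name"] for f in record_set.get("field", [])]
      for record_set in document.get("recordSet", [])
  }


record_sets = list_record_sets(SAMPLE_JSONLD)
for name, fields in record_sets.items():
  print(f"{name}: {fields}")
```

With the real file, `local_croissant_file.read_text()` could be passed in place of `SAMPLE_JSONLD`.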


Our `CroissantBuilder` uses the information contained in the Croissant 🥐 file to initialize the TFDS dataset's [documentation](https://www.tensorflow.org/datasets/api_docs/python/tfds/core/DatasetInfo), which we can explore through the [`DatasetBuilder.info`](https://www.tensorflow.org/datasets/api_docs/python/tfds/core/DatasetBuilder#info) property:

In [4]:
print(f"Dataset's description:\n{builder.info.description}\n")
print(f"Dataset's citation:\n{builder.info.citation}\n")
print(f"Dataset's features:\n{builder.info.features}")

# ...

Dataset's description:
fashion_mnist dataset hosted on Hugging Face and contributed by the HF Datasets community

Dataset's citation:


Dataset's features:
FeaturesDict({
    'image': Image(shape=(None, None, 3), dtype=uint8),
    'label': int64,
})


We can now generate and materialize the TFDS dataset on disk:

In [5]:
%%capture --no-display
builder.download_and_prepare()

`download_and_prepare` downloads the data and prepares the dataset specifically for ML. For instance, it uses an ML-optimized data format. You can read more [in the documentation](https://www.tensorflow.org/datasets/tfless_tfds). Let's inspect it on disk:

In [6]:
!tree {data_dir}/fashion_mnist

/tmp/croissant/fashion_mnist
└── record_set_fashion_mnist
    └── 1.0.0
        ├── dataset_info.json
        ├── fashion_mnist-default.array_record-00000-of-00001
        └── features.json

3 directories, 3 files


The command below returns two data sources by slicing the `default` split 80/20 into a train set and a test set:

In [7]:
train, test = builder.as_data_source(split=['default[:80%]', 'default[80%:]'])
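
To make the percent slices concrete, here is a rough sketch of the index arithmetic behind `default[:80%]` and `default[80%:]`. It rests on two assumptions that the notebook does not state: that the `default` split holds 70,000 examples (Fashion-MNIST's 60,000 training plus 10,000 test images), and that slice boundaries are rounded to the nearest index. `percent_slice` is a hypothetical helper, not a TFDS function, and TFDS may round differently at the boundaries.

```python
import math


def percent_slice(num_examples, start_percent, stop_percent):
  """Returns the [start, stop) index range covered by a percent slice.

  Assumption: boundaries are rounded to the nearest index.
  """
  start = round(num_examples * start_percent / 100)
  stop = round(num_examples * stop_percent / 100)
  return start, stop


num_examples = 70_000  # Assumed size of the `default` split.
batch_size = 128  # Batch size used for the data loaders in this notebook.

train_start, train_stop = percent_slice(num_examples, 0, 80)
test_start, test_stop = percent_slice(num_examples, 80, 100)
train_size = train_stop - train_start
test_size = test_stop - test_start

# The last batch of each loader may be smaller, hence the ceiling.
train_batches = math.ceil(train_size / batch_size)
test_batches = math.ceil(test_size / batch_size)

print(f"train: {train_size} examples, {train_batches} batches")
print(f"test: {test_size} examples, {test_batches} batches")
```

Under these assumptions the batch counts come out to 438 and 110, which agrees with the iteration counts shown in the training and testing progress bars at the end of this notebook.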

## Train a model

TFDS can be used with TensorFlow, JAX and PyTorch because it supports many data loaders, such as [tf.data](https://www.tensorflow.org/guide/data), [PyGrain](https://github.com/google/grain) and [PyTorch DataLoaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). For example, let's try PyTorch:

In [8]:
batch_size = 128
train_sampler = torch.utils.data.RandomSampler(train, num_samples=len(train))
train_loader = torch.utils.data.DataLoader(
    train,
    sampler=train_sampler,
    batch_size=batch_size,
)
test_loader = torch.utils.data.DataLoader(
    test,
    sampler=None,
    batch_size=batch_size,
)
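
A data source can be handed to a `DataLoader` because it behaves like a random-access sequence: it has a length and can be indexed. The pure-Python sketch below illustrates what a random sampler plus batching amounts to on such an object. `ToyDataSource` and `iterate_batches` are hypothetical names made up for this illustration; this is not how PyTorch implements `DataLoader`, which additionally collates each batch into tensors.

```python
import random


class ToyDataSource:
  """Hypothetical stand-in for a random-access data source."""

  def __init__(self, num_examples):
    self._num_examples = num_examples

  def __len__(self):
    return self._num_examples

  def __getitem__(self, index):
    if not 0 <= index < self._num_examples:
      raise IndexError(index)
    # Fake example: a tiny "image" and a label derived from the index.
    return {"image": [index] * 4, "label": index % 10}


def iterate_batches(source, batch_size, shuffle, seed=0):
  """Yields lists of examples, optionally visiting indices in random order."""
  indices = list(range(len(source)))
  if shuffle:
    random.Random(seed).shuffle(indices)
  for start in range(0, len(indices), batch_size):
    yield [source[i] for i in indices[start:start + batch_size]]


source = ToyDataSource(num_examples=10)
batches = list(iterate_batches(source, batch_size=4, shuffle=True))
print([len(batch) for batch in batches])  # → [4, 4, 2]
```

Every example is visited exactly once per pass, and only the last batch is allowed to be smaller than `batch_size`.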

A `DataLoader` can feed any ML pipeline. Let's try it with a very simple model, a linear classifier trained for a single pass over the training data:

In [9]:
class LinearClassifier(torch.nn.Module):
  """A single linear layer applied to the flattened image."""

  def __init__(self, shape, num_classes):
    super().__init__()
    height, width, channels = shape
    self.classifier = torch.nn.Linear(height * width * channels, num_classes)

  def forward(self, image):
    # Flatten each image into a vector and cast the pixels to float32.
    image = image.view(image.size()[0], -1).to(torch.float32)
    return self.classifier(image)

shape = train[0]["image"].shape
num_classes = 10
model = LinearClassifier(shape, num_classes)
optimizer = torch.optim.Adam(model.parameters())
loss_function = torch.nn.CrossEntropyLoss()

print('Training...')
model.train()
for example in tqdm(train_loader):
  image, label = example['image'], example['label']
  prediction = model(image)
  loss = loss_function(prediction, label)
  optimizer.zero_grad()
  loss.backward()
  optimizer.step()

print('Testing...')
model.eval()
num_examples = 0
num_correct = 0  # Correct predictions across all classes.
with torch.no_grad():  # Gradients are not needed for evaluation.
  for example in tqdm(test_loader):
    image, label = example['image'], example['label']
    prediction = model(image)
    num_examples += image.shape[0]
    predicted_label = prediction.argmax(dim=1)
    num_correct += (predicted_label == label).sum().item()
print(f'\nAccuracy: {num_correct / num_examples * 100:.2f}%')

Training...


100%|███████████████████████████████████████████████████████████████████████████████████| 438/438 [01:00<00:00,  7.24it/s]


Testing...


100%|███████████████████████████████████████████████████████████████████████████████████| 110/110 [00:15<00:00,  6.94it/s]


Accuracy: 77.64%
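
The accuracy reported above is the fraction of test examples whose highest-scoring class equals the label. As a small, framework-free sketch of that computation, the snippet below applies the same arg-max-and-compare logic to made-up scores. The `logits` and `labels` values are invented for illustration, and `argmax` and `accuracy` are hypothetical helpers, not PyTorch functions.

```python
def argmax(values):
  """Returns the index of the largest value."""
  return max(range(len(values)), key=lambda i: values[i])


def accuracy(logits, labels):
  """Fraction of rows whose highest-scoring class matches the label."""
  predictions = [argmax(row) for row in logits]
  num_correct = sum(int(p == y) for p, y in zip(predictions, labels))
  return num_correct / len(labels)


# Made-up scores for four examples and three classes.
logits = [
    [2.0, 0.1, -1.0],  # Predicted class 0.
    [0.3, 1.5, 0.2],   # Predicted class 1.
    [0.0, 0.2, 0.1],   # Predicted class 1, which is a miss.
    [-0.5, 0.0, 3.0],  # Predicted class 2.
]
labels = [0, 1, 2, 2]

print(f"Accuracy: {accuracy(logits, labels) * 100:.2f}%")  # → Accuracy: 75.00%
```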



