# Multi-modal data with a GraphQL alternative

This notebook will demonstrate working with multi-modal or multi-source data. By default, Hyrax assumes only a single dataset will be used for a given model.

As before, the two required configuration parameters are `config['general']['data_dir']` (to define the location of the data) and `config['data_set']['name']` (to define the dataset class that will load specific data from disk).

Below we use the standard approach of loading the default dataset defined in the Hyrax config using `h.prepare()`. We then print the dataset object to see the basic information about it.

In [1]:
from hyrax import Hyrax

h = Hyrax()

# Set a few configs for later use
h.config["train"]["epochs"] = 1
h.config["data_set.random_dataset"]["shape"] = (1, 32, 32)
h.config["data_set.random_dataset"]["size"] = 50000

[2025-07-22 16:03:22,595 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


## Attaching an iterable dataset to a model
The following shows how to set the dataset to be the `HyraxRandomIterableDataset`.

When experimentation is complete, the following can also be specified directly in a configuration .toml file like so:
```toml
[model_data]

[model_data.rando]
dataset_class = "HyraxRandomIterableDataset"
data_directory = "./data"
fields = ["image"]
```

In [2]:
h.config["model_data"] = {
    "rando_0": {
        "dataset_class": "HyraxRandomIterableDataset",
        "data_directory": "./data",
        "fields": ["object_id", "image"],
        "primary_id_field": "object_id",
    },
    # "rando_1": {
    #     "dataset_class": "HyraxRandomIterableDataset",
    #     "data_directory": "./data",
    #     "fields": ["object_id", "image"],
    # },
}

## Examine the new iterable dataset
As before, calling ``h.prepare()`` will return an instance of the ``DataProvider`` dataset.
The ``DataProvider`` class can be thought of as a container of multiple datasets, as well as a gateway (in GraphQL terminology)
that will send requests for specific data to the datasets it contains. Printing the dataset object will show the configuration of the dataset.
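
To make the container-plus-gateway idea concrete, here is a dependency-free sketch of the pattern. It is not Hyrax's `DataProvider` implementation: `FakeDataset` and `MiniProvider` are hypothetical names, and the `get_<field>` lookup convention is only inferred from the log messages in this notebook's output.

```python
class FakeDataset:
    """Hypothetical stand-in for a dataset; exposes one getter per field."""

    def __init__(self, offset):
        self.offset = offset

    def get_object_id(self, idx):
        return f"obj_{idx + self.offset}"

    def get_image(self, idx):
        # A tiny 2x2 nested list stands in for an image array
        return [[idx + self.offset] * 2 for _ in range(2)]


class MiniProvider:
    """Hypothetical gateway that routes each requested field to its dataset."""

    def __init__(self, datasets, request):
        self.datasets = datasets
        self.request = request
        self._validate()

    def _validate(self):
        # Collect every problem first, then fail once with the full list
        problems = []
        for name, spec in self.request.items():
            for field in spec["fields"]:
                if not hasattr(self.datasets[name], f"get_{field}"):
                    problems.append(f"no get_{field} method in dataset '{name}'")
        if problems:
            raise RuntimeError("Data request validation failed: " + "; ".join(problems))

    def __getitem__(self, idx):
        # Build one nested dict per sample: {dataset_name: {field: value}}
        return {
            name: {
                field: getattr(self.datasets[name], f"get_{field}")(idx)
                for field in spec["fields"]
            }
            for name, spec in self.request.items()
        }


request = {
    "rando_0": {"fields": ["object_id", "image"]},
    "rando_1": {"fields": ["image"]},
}
provider = MiniProvider(
    {"rando_0": FakeDataset(0), "rando_1": FakeDataset(100)}, request
)
sample = provider[3]
print(sample["rando_0"]["object_id"])
print(sample["rando_1"]["image"])
```

The nested result mirrors how samples are indexed later in this notebook, for example `samp["rando_0"]["image"]`.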

In [3]:
ds = h.prepare()
print(ds)

[2025-07-22 16:03:26,605 hyrax.data_sets.data_provider:ERROR] Dataset 'HyraxRandomIterableDataset' is an iterable-style dataset. This is not supported in the current implementation of DataProvider.Hyrax only supports 1-N map-style datasets at this time or single iterable-style datasets.
[2025-07-22 16:03:26,606 hyrax.data_sets.data_provider:ERROR] No `get_object_id` method for requested field, 'object_id' was found in dataset HyraxRandomIterableDataset.
[2025-07-22 16:03:26,606 hyrax.data_sets.data_provider:ERROR] Finished validating request. Problems found: 2


RuntimeError: Data request validation failed.

In [None]:
ds.get_sample()

In [None]:
print(f"Is iterable: {ds.is_iterable()}")  # Should return True
print(f"Is mappable: {ds.is_map()}")  # Should return False
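
For readers unfamiliar with the two dataset styles these checks distinguish, the sketch below follows the PyTorch terminology: a map-style dataset supports random access through `__getitem__` and `__len__`, while an iterable-style dataset only yields samples in sequence through `__iter__`. The classes and the helpers `looks_map_style` and `looks_iterable_style` are hypothetical and are not the `is_map`/`is_iterable` methods called above.

```python
class MapStyle:
    """Random access: any index can be requested at any time."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return {"image": idx}


class IterableStyle:
    """Sequential access only: samples arrive in stream order."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        for i in range(self.n):
            yield {"image": i}


def looks_map_style(ds):
    # Hypothetical check: needs both indexing and a known length
    return hasattr(ds, "__getitem__") and hasattr(ds, "__len__")


def looks_iterable_style(ds):
    # Hypothetical check: can be iterated but cannot be indexed
    return hasattr(ds, "__iter__") and not looks_map_style(ds)


map_ds = MapStyle(5)
iter_ds = IterableStyle(5)

third = map_ds[3]            # direct lookup by index
first = next(iter(iter_ds))  # only the next item in the stream is available

print(looks_map_style(map_ds), looks_iterable_style(map_ds))
print(looks_map_style(iter_ds), looks_iterable_style(iter_ds))
```

Combining several map-style datasets is straightforward because the same index can be sent to each one; an iterable-style dataset offers no such index, which is one plausible reason for the restriction reported in the log output.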

In [None]:
samp = ds.get_sample()
print("Fields from rando_0")
print(samp["rando_0"]["image"].shape)
# print("Fields from rando_1")
# print(samp["rando_1"]["image"].shape)

## Pass the data through ``to_tensor``
Since we have access to the model class, we can call the ``to_tensor`` method with example data.
This allows easy checking that the output matches the expectations of the model architecture.

In this example, we expect ``to_tensor`` to return a single ``Tensor`` holding a multi-channel image; the implementation below does not return a label.

In [None]:
import numpy as np
import torch


@staticmethod
def to_tensor(data_dict):
    """This function converts structured data to the input tensor we need to run

    Parameters
    ----------
    data_dict : dict
        The dictionary returned from our data source
    """
    rando_0 = data_dict.get("rando_0", {})
    rando_1 = data_dict.get("rando_1", {})

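    # The first dataset must supply an image; fail early with a clear message
    if "image" not in rando_0:
        raise KeyError("Expected an 'image' field from dataset 'rando_0'")
    image_0 = torch.from_numpy(rando_0["image"])

    if "image" in rando_1:
        image_1 = torch.from_numpy(rando_1["image"])

        # Channel axis is 0 for a single (C, H, W) sample, 1 for an (N, C, H, W) batch
        stack_dim = 0 if image_0.ndim == 3 else 1
        image_0 = torch.cat([image_0, image_1], dim=stack_dim)

    return image_0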


m = h.model()
m.to_tensor = to_tensor

In [None]:
res = m.to_tensor(samp)
print(f"Type and shape of resulting image: {type(res)}, {res.shape}")

## Train with this model
Now that we've seen that the ``to_tensor`` method is returning a reasonable form of data, we can train our model.
As before, we call ``h.train()``.
While it is quite verbose, the initialization logging shows that the model instance is created with data from
the ``DataProvider`` class, and that our new implementation of ``to_tensor`` is being used to manipulate
the data from ``DataProvider`` into a form that our model architecture accepts.

In [None]:
h.train()

In [None]:
h.infer()

## Things get not-so-great after this
From here we really need to consider how we save state.
We need to make sure that the datasets collected in ``DataProvider`` are copied to the config.
We need to figure out how to recreate the current DataProvider from the config.
And we need to figure out when to use the contents of the persisted config file vs.
when to use the ``data`` attribute in the model class.
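
One possible direction for the first two questions, sketched under heavy assumptions: if every dataset can be described by a plain dictionary like the `model_data` entries used earlier, then that dictionary can be serialized with the run and used to rebuild the datasets later. Nothing here is Hyrax API. `RandomImages`, `REGISTRY`, and `rebuild` are hypothetical names, and JSON stands in for whatever format the persisted config actually uses.

```python
import json


class RandomImages:
    """Hypothetical dataset class that is fully described by its arguments."""

    def __init__(self, data_directory, fields):
        self.data_directory = data_directory
        self.fields = fields


# Hypothetical registry mapping a class name in the config to a class
REGISTRY = {"RandomImages": RandomImages}


def rebuild(model_data):
    """Recreate one dataset object per config entry."""
    datasets = {}
    for name, spec in model_data.items():
        dataset_cls = REGISTRY[spec["dataset_class"]]
        datasets[name] = dataset_cls(spec["data_directory"], spec["fields"])
    return datasets


model_data = {
    "rando_0": {
        "dataset_class": "RandomImages",
        "data_directory": "./data",
        "fields": ["object_id", "image"],
    },
}

# Persist the description, then restore the datasets from the saved text
persisted = json.dumps(model_data)
restored = rebuild(json.loads(persisted))
print(type(restored["rando_0"]).__name__, restored["rando_0"].fields)
```

This only works when the config entry captures everything needed to construct the dataset, which is exactly the open question raised above.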

In [None]:
h.umap()

In [None]:
h.visualize()

In [None]:
from torch.utils.data import Dataset, IterableDataset
from hyrax.data_sets.data_set_registry import HyraxDataset


class Thing(HyraxDataset):
    def __init__(self):
        super().__init__()

    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return {"image": torch.randn(3, 32, 32), "label": idx}

In [None]:
Thing.is_iterable()