# Multi-modal data with a GraphQL alternative

This notebook will demonstrate working with multi-modal or multi-source data. By default, Hyrax assumes only a single dataset will be used for a given model.

As before, the two required configuration parameters are `config['general']['data_dir']` (to define the location of the data) and `config['data_set']['name']` (to define the dataset class that will load specific data from disk).

Below we use the standard approach of loading the default dataset defined in the Hyrax config with `h.prepare()`. We then print the dataset object to see basic information about it.

In [1]:
from hyrax import Hyrax

h = Hyrax()

# Set a few configs for later use
h.config["train"]["epochs"] = 1
h.config["data_set.random_dataset"]["shape"] = (1, 32, 32)
h.config["data_set.random_dataset"]["size"] = 50000

[2025-07-21 13:47:16,052 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


## Attaching an iterable dataset to a model
The following shows how to set the dataset to be the `HyraxRandomIterableDataset`.

When experimentation is complete, the following can also be specified directly in a configuration .toml file like so:
```toml
[model_data]

[model_data.rando]
dataset_class = "HyraxRandomIterableDataset"
data_directory = "./data"
fields = ["image"]
```

In [2]:
h.config["model_data"] = {
    "rando": {
        "dataset_class": "HyraxRandomIterableDataset",
        "data_directory": "./data",
        "fields": ["object_id", "image"],
        "primary_id_field": "object_id",
    },
}

## Examine the new iterable dataset
As before, calling ``h.prepare()`` will return an instance of the ``DataProvider`` dataset.
The ``DataProvider`` class can be thought of as a container of multiple datasets, as well as a gateway (in GraphQL terminology)
that will send requests for specific data to the datasets it contains. Printing the dataset object will show the configuration of the dataset.

In [3]:
ds = h.prepare()
print(ds)

[2025-07-21 13:47:22,319 hyrax.prepare:INFO] Finished Prepare


rando
  Dataset class: HyraxRandomIterableDataset
  Data directory: ./data
  Fields: object_id, image
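
To make the "container plus gateway" description concrete, here is a minimal sketch of the idea in plain Python. `ToyDataset` and `ToyProvider` are hypothetical names invented for illustration; this is not Hyrax's `DataProvider` implementation, only an example of routing field requests to the dataset that owns them and returning one nested dictionary per sample.

```python
class ToyDataset:
    """Hypothetical stand-in for a dataset that serves named fields by index."""

    def __init__(self, records):
        self._records = records

    def get(self, index, field):
        return self._records[index][field]


class ToyProvider:
    """Hypothetical gateway: holds named datasets and the fields wanted from each."""

    def __init__(self, datasets, fields):
        self.datasets = datasets  # friendly name -> dataset
        self.fields = fields      # friendly name -> list of requested field names

    def __getitem__(self, index):
        # Route each requested field to the dataset that owns it.
        return {
            name: {f: self.datasets[name].get(index, f) for f in self.fields[name]}
            for name in self.datasets
        }


images = ToyDataset([{"image": "img0", "object_id": 0}, {"image": "img1", "object_id": 1}])
labels = ToyDataset([{"label": 3}, {"label": 7}])

provider = ToyProvider(
    datasets={"images": images, "labels": labels},
    fields={"images": ["object_id", "image"], "labels": ["label"]},
)

sample = provider[1]
print(sample)
```

The caller asks for one index and gets back every requested field, grouped by the friendly name of the dataset that supplied it.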



The datasets contained within the `DataProvider` instance are available through its `prepped_datasets` attribute.

In [4]:
ds.prepped_datasets

{'rando': <hyrax.data_sets.random.hyrax_random_dataset.HyraxRandomIterableDataset at 0x17abf5310>}

In [5]:
print(f"Is iterable: {ds.is_iterable()}")  # Should return True
print(f"Is mappable: {ds.is_map()}")  # Should return False

Is iterable: True
Is mappable: False


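Checking the length of the dataset uses `len(ds)` as always. Because the primary dataset here is iterable, however, its length cannot be determined, so the call logs an error and raises a `TypeError`.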

In [6]:
print(f"Length of the multimodal dataset: {len(ds)}")

[2025-07-21 13:47:25,887 hyrax.data_sets.data_provider:ERROR] Primary dataset is iterable, cannot determine length.


TypeError: 'NoneType' object cannot be interpreted as an integer
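
The `TypeError` above carries the message Python produces when a `__len__` method returns `None` instead of an integer. The sketch below reproduces that behaviour with a hypothetical `StreamOnly` class and shows one way to guard a `len()` call. Whether `DataProvider` takes exactly this code path is an assumption based on the log line and the traceback; `StreamOnly` and `safe_len` are illustrative names, not part of Hyrax.

```python
class StreamOnly:
    """Hypothetical iterable-style container with no known length."""

    def __iter__(self):
        yield from range(3)

    def __len__(self):
        # Returning None (instead of an int) makes len() raise TypeError.
        return None


def safe_len(obj):
    """Return len(obj), or None when the length cannot be determined."""
    try:
        return len(obj)
    except TypeError:
        return None


stream = StreamOnly()
print(safe_len(stream))     # None
print(safe_len([1, 2, 3]))  # 3
```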

In [7]:
samp = ds[2335]
print("Fields from cifar_0")
print(samp["cifar_0"]["image"].shape)
print(samp["cifar_0"]["label"])
print(samp["cifar_0"]["object_id"])
print("Fields from cifar_1")
print(samp["cifar_1"]["image"].shape)
print(samp["cifar_1"]["label"])
print("Fields from rando")
print(samp["rando"]["image"].shape)

Fields from cifar_0
(3, 32, 32)
5
2335
Fields from cifar_1
(3, 32, 32)
5
Fields from rando
(1, 32, 32)


## Pass the data through ``to_tensor``
Since we have access to the model class, we can call the ``to_tensor`` method with example data.
This allows easy checking that the output matches the expectations of the model architecture.

In this example, we expect ``to_tensor`` to return a tuple of (Tensor, int), or specifically a multi-channel image and a label.

In [8]:
m = h.model()
res = m.to_tensor(samp)
print(f"Type and shape of resulting image: {type(res[0])}, {res[0].shape}")
print(f"Type and shape of the label: {type(res[1])}, {res[1]}")

Type and shape of resulting image: <class 'torch.Tensor'>, torch.Size([4, 32, 32])
Type and shape of the label: <class 'int'>, 5
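
The 4-channel result is consistent with concatenating a 3-channel image and a 1-channel image along the channel axis. A minimal NumPy sketch of that arithmetic follows, assuming this is how the two images are combined; the arrays are placeholders filled with zeros and ones rather than real data.

```python
import numpy as np

# Placeholder arrays with the shapes printed earlier in this notebook.
rgb_image = np.zeros((3, 32, 32), dtype=np.float32)  # 3-channel image
mono_image = np.ones((1, 32, 32), dtype=np.float32)  # 1-channel image

# A single sample is laid out as (C, H, W), so channels sit on axis 0.
stacked = np.concatenate([rgb_image, mono_image], axis=0)
print(stacked.shape)  # (4, 32, 32)

# A batch is laid out as (N, C, H, W), so channels sit on axis 1.
rgb_batch = np.zeros((8, 3, 32, 32), dtype=np.float32)
mono_batch = np.ones((8, 1, 32, 32), dtype=np.float32)
stacked_batch = np.concatenate([rgb_batch, mono_batch], axis=1)
print(stacked_batch.shape)  # (8, 4, 32, 32)
```

Concatenation only requires the non-channel dimensions to match, which is why images with different channel counts can be combined this way.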


## Updating ``to_tensor``
The default implementation of ``to_tensor`` only makes use of "cifar_0" and "rando".
But if we are experimenting, we don't want to have to make code changes in the model class.
It is much easier to experiment directly in the notebook.
Here, we redefine the ``to_tensor`` method, and check the results by running sample data through the method.

In [9]:
import torch
import numpy as np


@staticmethod
def to_tensor(data_dict):
    """This function converts structured data to the input tensor we need to run

    Parameters
    ----------
    data_dict : dict
        The dictionary returned from our data source
    """
    cifar_data = data_dict.get("cifar_0", {})
    random_data = data_dict.get("rando", {})

    more_cifar_data = data_dict.get("cifar_1", {})

    # Default to None so the return statement never hits an unbound name
    label = cifar_data.get("label")
    image = None

    if "image" in cifar_data and "image" in random_data:
        cifar_image = cifar_data["image"]
        random_image = random_data["image"]
        images = [cifar_image, random_image]
        # Only stack the second CIFAR image when it is present
        if "image" in more_cifar_data:
            images.append(more_cifar_data["image"])
        # Channels are axis 0 for a single sample, axis 1 for a batch
        stack_dim = 0 if cifar_image.ndim == 3 else 1
        image = torch.from_numpy(np.concatenate(images, axis=stack_dim))
    elif "image" in cifar_data:
        image = torch.from_numpy(cifar_data["image"])
    elif "image" in random_data:
        image = torch.from_numpy(random_data["image"])

    return (image, label)


m.to_tensor = to_tensor

After running the same sample through as before, we can see that the number of channels
in the image has changed (from 4 to 7), while all the other values have remained the same.

In [10]:
new_res = m.to_tensor(samp)
print(f"Type and shape of resulting image: {type(new_res[0])}, {new_res[0].shape}")
print(f"Type and shape of the label: {type(new_res[1])}, {new_res[1]}")

Type and shape of resulting image: <class 'torch.Tensor'>, torch.Size([7, 32, 32])
Type and shape of the label: <class 'int'>, 5
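
Replacing `to_tensor` from the notebook relies on ordinary Python attribute lookup: an attribute assigned on an instance shadows the one defined on its class. The sketch below shows this with a hypothetical `ToyModel` class and plain lists in place of tensors. It is not Hyrax code, and whether a given framework picks up an instance-level override depends on how it looks up the method.

```python
class ToyModel:
    """Hypothetical model class with a default conversion step."""

    @staticmethod
    def to_tensor(data_dict):
        return data_dict["a"]


def to_tensor(data_dict):
    # Replacement defined in the notebook: combine two entries instead of one.
    return data_dict["a"] + data_dict["b"]


m1 = ToyModel()
m2 = ToyModel()

# Assigning to the instance shadows the class attribute for m1 only.
# A plain function stored on an instance is not bound, so no `self` is passed.
m1.to_tensor = to_tensor

data = {"a": [1, 2], "b": [3]}
print(m1.to_tensor(data))  # [1, 2, 3]
print(m2.to_tensor(data))  # [1, 2]
```

Other instances, and the class itself, keep the original behaviour, which keeps the experiment local to the object being tested.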


## Train with this model
Now that we've seen that the ``to_tensor`` method is returning a reasonable form of data, we can train our model.
As before, we call ``h.train()``.
While it is quite verbose, the initialization logging shows that the model instance is created with data from
the ``DataProvider`` class, and that our new implementation of ``to_tensor`` is being used to convert
the data from ``DataProvider`` into a form that our model architecture accepts.

In [None]:
h.train()

In [None]:
h.infer()

# Things get not-so-great after this
From here we really need to consider how we save state.
We need to make sure that the datasets collected in ``DataProvider`` are copied to the config.
We need to figure out how to recreate the current DataProvider from the config.
And we need to figure out when to use the contents of the persisted config file vs.
when to use the ``data`` attribute in the model class.

In [None]:
h.umap()

In [None]:
h.visualize()