# Multi-modal data with a GraphQL alternative

This notebook will demonstrate working with multi-modal or multi-source data. By default, Hyrax assumes only a single dataset will be used for a given model.

As before, the two required configuration parameters are `config['general']['data_dir']` (to define the location of the data) and `config['data_set']['name']` (to define the dataset class that will load specific data from disk).

Below we create a `Hyrax` instance, which reads the default runtime config, and set a few configuration values that are used later in the notebook.

In [1]:
from hyrax import Hyrax

h = Hyrax()

# Set a few configs for later use
h.config["train"]["epochs"] = 1
h.config["data_set.random_dataset"]["shape"] = (1, 32, 32)
h.config["data_set.random_dataset"]["size"] = 50000

[2025-07-22 13:33:57,687 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


## Attaching iterable datasets to a model
The following cell configures two datasets, `rando_0` and `rando_1`, both of which use the `HyraxRandomIterableDataset` class. Only `rando_0` sets a `primary_id_field`.

Once experimentation is complete, datasets can also be specified directly in a `.toml` configuration file, for example:
```toml
[model_data]

[model_data.rando]
dataset_class = "HyraxRandomIterableDataset"
data_directory = "./data"
fields = ["image"]
```

In [2]:
h.config["model_data"] = {
    "rando_0": {
        "dataset_class": "HyraxRandomIterableDataset",
        "data_directory": "./data",
        "fields": ["object_id", "image"],
        "primary_id_field": "object_id",
    },
    "rando_1": {
        "dataset_class": "HyraxRandomIterableDataset",
        "data_directory": "./data",
        "fields": ["object_id", "image"],
    },
}

## Examine the new iterable dataset
As before, calling ``h.prepare()`` returns a dataset, in this case an instance of the ``DataProvider`` class.
The ``DataProvider`` class can be thought of as a container of multiple datasets, as well as a gateway (in GraphQL terminology)
that sends requests for specific data to the datasets it contains. Printing the dataset object shows the configuration of each contained dataset.
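To make the container/gateway idea concrete, here is a toy sketch. `ToyProvider` and `make_random_getters` are hypothetical names invented for illustration and are not Hyrax's `DataProvider` implementation; the sketch only assumes the nested `{dataset_name: {field: value}}` shape visible in the sample output below. A single request for an index is fanned out to every named dataset, and each one returns only the fields it was configured to provide.

```python
from typing import Any, Callable, Dict, List

import numpy as np


class ToyProvider:
    """Hypothetical, simplified stand-in for a multi-dataset container."""

    def __init__(self) -> None:
        # dataset name -> {field name -> function(index) -> value}
        self.datasets: Dict[str, Dict[str, Callable[[int], Any]]] = {}
        # dataset name -> fields requested from that dataset
        self.fields: Dict[str, List[str]] = {}

    def add_dataset(
        self, name: str, getters: Dict[str, Callable[[int], Any]], fields: List[str]
    ) -> None:
        missing = [f for f in fields if f not in getters]
        if missing:
            raise ValueError(f"{name} cannot provide fields: {missing}")
        self.datasets[name] = getters
        self.fields[name] = list(fields)

    def __getitem__(self, idx: int) -> Dict[str, Dict[str, Any]]:
        # Route one request to every dataset, like a gateway resolving a
        # single query against several backing services.
        return {
            name: {field: getters[field](idx) for field in self.fields[name]}
            for name, getters in self.datasets.items()
        }


def make_random_getters(seed: int) -> Dict[str, Callable[[int], Any]]:
    """Build getters for a small in-memory random image dataset."""
    rng = np.random.default_rng(seed)
    images = rng.random((10, 1, 32, 32), dtype=np.float32)
    return {
        "object_id": lambda i: i,
        "image": lambda i: images[i],
    }


provider = ToyProvider()
provider.add_dataset("rando_0", make_random_getters(0), ["object_id", "image"])
provider.add_dataset("rando_1", make_random_getters(1), ["image"])

sample = provider[7]
print(sample["rando_0"]["object_id"])    # 7
print(sample["rando_1"]["image"].shape)  # (1, 32, 32)
```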

In [3]:
ds = h.prepare()
print(ds)

[2025-07-22 13:34:00,647 hyrax.prepare:INFO] Finished Prepare


rando_0
  Dataset class: HyraxRandomIterableDataset
  Data directory: ./data
  Fields: object_id, image
rando_1
  Dataset class: HyraxRandomIterableDataset
  Data directory: ./data
  Fields: object_id, image



In [11]:
ds.get_sample()

{'rando_0': {'index': 7,
  'object_id': 7,
  'image': array([[[0.39884484, 0.91513836, 0.89960617, ..., 0.95127994,
           0.23864251, 0.4383163 ],
          [0.6812738 , 0.28683555, 0.26445645, ..., 0.60255677,
           0.57460254, 0.69453275],
          [0.44086492, 0.36550325, 0.3526907 , ..., 0.10638863,
           0.7537327 , 0.9651421 ],
          ...,
          [0.33532035, 0.36672413, 0.08297735, ..., 0.1615507 ,
           0.44905108, 0.58613694],
          [0.20847565, 0.1187771 , 0.1357665 , ..., 0.15624124,
           0.93203765, 0.45069343],
          [0.16738671, 0.369471  , 0.3347882 , ..., 0.21927172,
           0.20030725, 0.32320726]]], dtype=float32),
  'label': np.int64(1)},
 'rando_1': {'index': 7,
  'object_id': 7,
  'image': array([[[0.39884484, 0.91513836, 0.89960617, ..., 0.95127994,
           0.23864251, 0.4383163 ],
          [0.6812738 , 0.28683555, 0.26445645, ..., 0.60255677,
           0.57460254, 0.69453275],
          [0.44086492, 0.36550325, 0.3

In [4]:
print(f"Is iterable: {ds.is_iterable()}")  # Should return True
print(f"Is mappable: {ds.is_map()}")  # Should return False

Is iterable: True
Is mappable: False


In [5]:
a = next(iter(ds))
print(a)

{'rando_0': {'index': 0, 'object_id': 0, 'image': array([[[0.08925092, 0.773956  , 0.6545715 , ..., 0.44341415,
         0.45045954, 0.22723871],
        [0.09213591, 0.55458474, 0.8878898 , ..., 0.7447621 ,
         0.36664265, 0.9675097 ],
        [0.41085035, 0.32582533, 0.90553576, ..., 0.38747835,
         0.8980876 , 0.28832805],
        ...,
        [0.41836828, 0.5780172 , 0.5375471 , ..., 0.6346291 ,
         0.9714626 , 0.41181087],
        [0.15344363, 0.40878308, 0.8401149 , ..., 0.7874322 ,
         0.3427019 , 0.5491443 ],
        [0.19697303, 0.43141818, 0.5296637 , ..., 0.07205909,
         0.8685205 , 0.84199315]]], dtype=float32), 'label': np.int64(1)}, 'rando_1': {'index': 0, 'object_id': 0, 'image': array([[[0.08925092, 0.773956  , 0.6545715 , ..., 0.44341415,
         0.45045954, 0.22723871],
        [0.09213591, 0.55458474, 0.8878898 , ..., 0.7447621 ,
         0.36664265, 0.9675097 ],
        [0.41085035, 0.32582533, 0.90553576, ..., 0.38747835,
         0.898087

In [6]:
samp = ds.get_sample()
print("Fields from rando_0")
print(samp["rando_0"]["image"].shape)
print("Fields from rando_1")
print(samp["rando_1"]["image"].shape)

Fields from rando_0
(1, 32, 32)
Fields from rando_1
(1, 32, 32)


## Pass the data through ``to_tensor``
Since we have access to the model class, we can call the ``to_tensor`` method with example data.
This allows easy checking that the output matches the expectations of the model architecture.

In this example, we expect ``to_tensor`` to return a single Tensor: a two-channel image made by concatenating the ``rando_0`` and ``rando_1`` images along the channel dimension.

In [7]:
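import torch


@staticmethod
def to_tensor(data_dict):
    """Convert the structured data from ``DataProvider`` into the model's input tensor.

    Parameters
    ----------
    data_dict : dict
        The dictionary returned from our data source, keyed by dataset name.

    Returns
    -------
    torch.Tensor
        The ``rando_0`` and ``rando_1`` images concatenated along the channel dimension.
    """
    rando_0 = data_dict.get("rando_0", {})
    rando_1 = data_dict.get("rando_1", {})

    # Fail with a clear message instead of a NameError if an image is missing
    if "image" not in rando_0 or "image" not in rando_1:
        raise KeyError("Both 'rando_0' and 'rando_1' must provide an 'image' field")

    # as_tensor accepts both NumPy arrays and existing tensors
    image_0 = torch.as_tensor(rando_0["image"])
    image_1 = torch.as_tensor(rando_1["image"])

    # A single image is (C, H, W); a batch is (N, C, H, W).
    # Either way, concatenate along the channel dimension.
    stack_dim = 0 if image_0.ndim == 3 else 1
    return torch.cat([image_0, image_1], dim=stack_dim)


m = h.model()
# Replace the model's default to_tensor with the version above
m.to_tensor = to_tensor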

In [8]:
res = m.to_tensor(samp)
print(f"Type and shape of resulting image: {type(res)}, {res.shape}")

Type and shape of resulting image: <class 'torch.Tensor'>, torch.Size([2, 32, 32])
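The ``stack_dim`` rule in ``to_tensor`` is what lets the same function handle a single sample and a batch. The sketch below isolates that rule using NumPy so it runs without PyTorch; `torch.cat(..., dim=stack_dim)` concatenates the same way for these shapes. `concat_channels` is a hypothetical helper name used only for this illustration.

```python
import numpy as np


def concat_channels(image_0: np.ndarray, image_1: np.ndarray) -> np.ndarray:
    """Concatenate two images along the channel axis.

    A single image is (C, H, W), so channels are axis 0;
    a batch is (N, C, H, W), so channels are axis 1.
    """
    if image_0.shape != image_1.shape:
        raise ValueError(f"Shape mismatch: {image_0.shape} vs {image_1.shape}")
    stack_dim = 0 if image_0.ndim == 3 else 1
    return np.concatenate([image_0, image_1], axis=stack_dim)


single = np.zeros((1, 32, 32), dtype=np.float32)
batch = np.zeros((8, 1, 32, 32), dtype=np.float32)

print(concat_channels(single, single).shape)  # (2, 32, 32)
print(concat_channels(batch, batch).shape)    # (8, 2, 32, 32)
```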


## Train with this model
Now that we've seen that the ``to_tensor`` method is returning a reasonable form of data, we can train our model.
As before, we call ``h.train()``.
While it is quite verbose, the initialization logging shows that the model instance is created with data from
the ``DataProvider`` class, and that our new implementation of ``to_tensor`` is being used to convert
the data from ``DataProvider`` into a form that our model architecture accepts.
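During training, samples are usually processed in batches rather than one at a time. The sketch below assumes the training loop stacks each field across samples, similar in spirit to how PyTorch's default collate function handles dictionaries, and shows why batched images reach ``to_tensor`` as 4-D arrays that must be concatenated along axis 1. `collate_nested` is a hypothetical helper, not part of Hyrax, and the random samples only mimic the shape of the ``DataProvider`` output.

```python
from typing import Any, Dict, List

import numpy as np


def collate_nested(samples: List[Dict[str, Dict[str, Any]]]) -> Dict[str, Dict[str, np.ndarray]]:
    """Stack a list of nested per-sample dicts into one batched nested dict.

    Hypothetical helper: every sample must contain the same datasets and fields.
    """
    batch = {}
    for name in samples[0]:
        batch[name] = {
            field: np.stack([np.asarray(s[name][field]) for s in samples])
            for field in samples[0][name]
        }
    return batch


rng = np.random.default_rng(0)
samples = [
    {
        "rando_0": {"object_id": i, "image": rng.random((1, 32, 32), dtype=np.float32)},
        "rando_1": {"object_id": i, "image": rng.random((1, 32, 32), dtype=np.float32)},
    }
    for i in range(4)
]

batch = collate_nested(samples)
print(batch["rando_0"]["image"].shape)  # (4, 1, 32, 32)
print(batch["rando_0"]["object_id"])    # [0 1 2 3]

# Batched images are 4-D, so channels are concatenated along axis 1
model_input = np.concatenate(
    [batch["rando_0"]["image"], batch["rando_1"]["image"]], axis=1
)
print(model_input.shape)  # (4, 2, 32, 32)
```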

In [9]:
h.train()

         0.45045954, 0.22723871],
        [0.09213591, 0.55458474, 0.8878898 , ..., 0.7447621 ,
         0.36664265, 0.9675097 ],
        [0.41085035, 0.32582533, 0.90553576, ..., 0.38747835,
         0.8980876 , 0.28832805],
        ...,
        [0.41836828, 0.5780172 , 0.5375471 , ..., 0.6346291 ,
         0.9714626 , 0.41181087],
        [0.15344363, 0.40878308, 0.8401149 , ..., 0.7874322 ,
         0.3427019 , 0.5491443 ],
        [0.19697303, 0.43141818, 0.5296637 , ..., 0.07205909,
         0.8685205 , 0.84199315]]], dtype=float32), 'label': np.int64(1)}, 'rando_1': {'index': 0, 'object_id': 0, 'image': array([[[0.08925092, 0.773956  , 0.6545715 , ..., 0.44341415,
         0.45045954, 0.22723871],
        [0.09213591, 0.55458474, 0.8878898 , ..., 0.7447621 ,
         0.36664265, 0.9675097 ],
        [0.41085035, 0.32582533, 0.90553576, ..., 0.38747835,
         0.8980876 , 0.28832805],
        ...,
        [0.41836828, 0.5780172 , 0.5375471 , ..., 0.6346291 ,
         0.9714626 ,

AttributeError: 'HyraxRandomIterableDataset' object has no attribute 'get_object_id'

In [None]:
h.infer()

## Things get not-so-great after this
From here we really need to consider how we save state. In particular, we need to:
- make sure that the datasets collected in ``DataProvider`` are copied to the config,
- figure out how to recreate the current ``DataProvider`` from the config, and
- figure out when to use the contents of the persisted config file vs. when to use the ``data`` attribute in the model class.

In [None]:
h.umap()

In [None]:
h.visualize()