# Multi-modal data with a GraphQL alternative

This notebook will demonstrate working with multi-modal or multi-source data. 

In [1]:
from hyrax import Hyrax

h = Hyrax()

# Set a few configs for later use
h.config["train"]["epochs"] = 1
h.config["data_set.random_dataset"]["shape"] = (1, 32, 32)
h.config["data_set.random_dataset"]["size"] = 50000


# Return a reference to the model class based on the configuration used to create `h`.
# This should feel similar to `ds = h.prepare()`. Perhaps it should be renamed to `h.data()`?
m = h.model()

# Since `data` is a model class attribute, we can print the data dictionary like so
m.data

[2025-07-11 17:23:19,759 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


{'cifar_1': {'dataset_class': 'HyraxCifarDataSet',
  'data_directory': 'path/to/dataset',
  'fields': ['image', 'label', 'object_id'],
  'primary_id_field': 'object_id'},
 'rando': {'dataset_class': 'HyraxRandomDataset',
  'data_directory': '/fake/dir',
  'fields': ['image']}}

## Attaching datasets to a model
The following shows the process of attaching new datasets to the model class.

Note - Attempting to add a dataset with a friendly name that already exists will log an error. To _update_ a dataset already attached to a model, first run ``detach_dataset(...)``, then ``attach_dataset(...)``.
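The attach rule above can be modeled with a plain dictionary. This is a minimal sketch, not Hyrax's implementation: ``attach_dataset_sketch`` is a hypothetical helper, and its log message and boolean return value are assumptions chosen for illustration. It also shows the update pattern of removing the old entry before attaching the new one.

```python
import logging

logger = logging.getLogger("registry_sketch")


def attach_dataset_sketch(data, friendly_name, **dataset_config):
    """Add ``dataset_config`` to ``data`` under ``friendly_name`` (hypothetical helper).

    Mirrors the documented behavior: a duplicate name logs an error instead of
    raising or silently overwriting. Returns True if the dataset was attached.
    """
    if friendly_name in data:
        logger.error(
            "The friendly name '%s' already exists. Detach it first to update it.",
            friendly_name,
        )
        return False
    data[friendly_name] = dict(dataset_config)
    return True


data = {"cifar_1": {"dataset_class": "HyraxCifarDataSet", "fields": ["image", "label", "object_id"]}}

attach_dataset_sketch(data, "cifar_0", dataset_class="HyraxCifarDataSet", fields=["image"])  # attached
attach_dataset_sketch(data, "cifar_1", dataset_class="HyraxCifarDataSet", fields=["image"])  # logs an error

# Updating an existing entry: remove it first, then attach the new definition.
data.pop("cifar_1")
attach_dataset_sketch(data, "cifar_1", dataset_class="HyraxCifarDataSet", fields=["image", "label"])
```

Logging an error rather than raising keeps a notebook cell running to completion, which matches the behavior shown in the output below.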

In [2]:
# Attach a new CIFAR dataset under the friendly name `cifar_0`
m.attach_dataset(
    friendly_name="cifar_0",
    dataset_class="HyraxCifarDataSet",
    data_directory=h.config["general"]["data_dir"],
    fields=["image", "label", "object_id"],
    # if `primary_id_field` not specified, none of the fields will be used as the primary id field
    primary_id_field="object_id",
)

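# `cifar_1` is already defined in the model's default `data` attribute, so this call logs an error.
# Note also that `primary_id_field` is not defined here.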
m.attach_dataset(
    friendly_name="cifar_1",
    dataset_class="HyraxCifarDataSet",
    data_directory=h.config["general"]["data_dir"],
    fields=["image", "label", "object_id"],
)

m.attach_dataset(
    friendly_name="cifar_2",
    dataset_class="HyraxCifarDataSet",
    data_directory="path/to/cifar/dataset",
    fields=["image", "label", "object_id"],
)

m.attach_dataset(
    friendly_name="random_dataset",
    dataset_class="HyraxRandomDataset",
    data_directory="path/to/random/dataset",
    fields=["image"],
)

m.data

[2025-07-11 17:23:28,452 hyrax.models.model_registry:ERROR] The friendly name 'cifar_1' already exists. If updating, first run `detach_dataset(cifar_1)`, then run `attach_dataset(('cifar_1', Ellipsis))` again.


{'cifar_1': {'dataset_class': 'HyraxCifarDataSet',
  'data_directory': '/Users/drew/code/hyrax/docs/pre_executed/data',
  'fields': ['image', 'label', 'object_id']},
 'rando': {'dataset_class': 'HyraxRandomDataset',
  'data_directory': '/fake/dir',
  'fields': ['image']},
 'cifar_0': {'dataset_class': 'HyraxCifarDataSet',
  'data_directory': '/Users/drew/code/hyrax/docs/pre_executed/data',
  'fields': ['image', 'label', 'object_id'],
  'primary_id_field': 'object_id'},
 'cifar_2': {'dataset_class': 'HyraxCifarDataSet',
  'data_directory': 'path/to/cifar/dataset',
  'fields': ['image', 'label', 'object_id']},
 'random_dataset': {'dataset_class': 'HyraxRandomDataset',
  'data_directory': 'path/to/random/dataset',
  'fields': ['image']}}

## Removing a dataset from a model
The following shows the process of removing a dataset from a model.

Note - Attempting to remove a dataset that doesn't exist will log an error.
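The removal rule can be sketched the same way. ``detach_dataset_sketch`` is a hypothetical helper, not Hyrax code; its log message is modeled on the error shown in the output below, and returning the removed configuration (or ``None``) is an assumption made for illustration.

```python
import logging

logger = logging.getLogger("registry_sketch")


def detach_dataset_sketch(data, friendly_name):
    """Remove ``friendly_name`` from ``data``, logging an error if it is absent.

    Hypothetical helper; returns the removed configuration, or None.
    """
    if friendly_name not in data:
        logger.error(
            "Cannot remove '%s' from data. These can be removed: %s",
            friendly_name,
            list(data),
        )
        return None
    return data.pop(friendly_name)


data = {
    "cifar_0": {"dataset_class": "HyraxCifarDataSet"},
    "cifar_2": {"dataset_class": "HyraxCifarDataSet"},
    "random_dataset": {"dataset_class": "HyraxRandomDataset"},
}

detach_dataset_sketch(data, "cifar_2")
detach_dataset_sketch(data, "random_dataset")
detach_dataset_sketch(data, "missing_name")  # logs an error; `data` is unchanged
```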

In [3]:
# Remove specific datasets from `data`.
m.detach_dataset(friendly_name="cifar_2")
m.detach_dataset(friendly_name="random_dataset")

# This dataset doesn't exist and will log an error.
m.detach_dataset("poopfish")

m.data

[2025-07-11 17:23:30,298 hyrax.models.model_registry:ERROR] Cannot remove 'poopfish' from data. These can be removed: ['cifar_1', 'rando', 'cifar_0']


{'cifar_1': {'dataset_class': 'HyraxCifarDataSet',
  'data_directory': '/Users/drew/code/hyrax/docs/pre_executed/data',
  'fields': ['image', 'label', 'object_id']},
 'rando': {'dataset_class': 'HyraxRandomDataset',
  'data_directory': '/fake/dir',
  'fields': ['image']},
 'cifar_0': {'dataset_class': 'HyraxCifarDataSet',
  'data_directory': '/Users/drew/code/hyrax/docs/pre_executed/data',
  'fields': ['image', 'label', 'object_id'],
  'primary_id_field': 'object_id'}}

## Examine the multimodal dataset
As before, calling ``h.prepare()`` returns a dataset; here it is an instance of ``DataProvider``.
The ``DataProvider`` class can be thought of as a container of multiple datasets, as well as a gateway (in GraphQL terminology)
that sends requests for specific data to the datasets it contains.
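To make the gateway idea concrete, here is a toy, dependency-free sketch of a container that forwards each index request to every dataset it holds and returns a dictionary keyed by friendly name. ``ToyDataset``, ``ToyDataProvider``, and ``get_fields`` are hypothetical names, the fake "image" is just a shape tuple, and the length rule (shortest contained dataset) is an assumption. This is not Hyrax's ``DataProvider`` code.

```python
class ToyDataset:
    """Stand-in dataset that returns a dict of fields for each index."""

    def __init__(self, size, channels):
        self.size = size
        self.channels = channels

    def __len__(self):
        return self.size

    def get_fields(self, idx):
        # The "image" is represented by its shape to keep the sketch dependency-free.
        return {"image": (self.channels, 32, 32), "label": 0, "object_id": idx}


class ToyDataProvider:
    """Hypothetical gateway: routes each request to every contained dataset."""

    def __init__(self, datasets, requested_fields):
        self.prepped_datasets = datasets          # friendly name -> dataset
        self.requested_fields = requested_fields  # friendly name -> list of field names

    def __len__(self):
        # Assumption: the provider is only as long as its shortest dataset.
        return min(len(ds) for ds in self.prepped_datasets.values())

    def __getitem__(self, idx):
        sample = {}
        for name, ds in self.prepped_datasets.items():
            fields = ds.get_fields(idx)
            # Forward only the fields requested for this dataset.
            sample[name] = {f: fields[f] for f in self.requested_fields[name]}
        return sample


provider = ToyDataProvider(
    datasets={"cifar_0": ToyDataset(50000, 3), "rando": ToyDataset(50000, 1)},
    requested_fields={"cifar_0": ["image", "label", "object_id"], "rando": ["image"]},
)
samp = provider[2335]
print(samp)
```

The nested shape of ``samp`` (friendly name, then field name) is the same access pattern used with the real ``DataProvider`` sample further down.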

In [4]:
ds = h.prepare()

[2025-07-11 17:23:42,202 hyrax.prepare:INFO] Finished Prepare


The datasets contained within the ``DataProvider`` instance are available through ``prepped_datasets``:

In [5]:
ds.prepped_datasets

{'cifar_1': <hyrax.data_sets.hyrax_cifar_data_set.HyraxCifarDataSet at 0x158343a40>,
 'rando': <hyrax.data_sets.random.hyrax_random_dataset.HyraxRandomDataset at 0x14fadcc80>,
 'cifar_0': <hyrax.data_sets.hyrax_cifar_data_set.HyraxCifarDataSet at 0x158341580>}

Checking the length of the dataset works the same as always.

In [9]:
print(f"Length of the multimodal dataset: {len(ds)}")
print(f"Length of a specific dataset contained inside: {len(ds.prepped_datasets['cifar_0'])}")

Length of the multimodal dataset: 50000
Length of a specific dataset contained inside: 50000


In [10]:
samp = ds[2335]
print("Fields from cifar_0")
print(samp["cifar_0"]["image"].shape)
print(samp["cifar_0"]["label"])
print(samp["cifar_0"]["object_id"])
print("Fields from cifar_1")
print(samp["cifar_1"]["image"].shape)
print(samp["cifar_1"]["label"])
print(samp["cifar_1"]["object_id"])
print("Fields from rando")
print(samp["rando"]["image"].shape)

Fields from cifar_0
(3, 32, 32)
5
2335
Fields from cifar_1
(3, 32, 32)
5
2335
Fields from rando
(1, 32, 32)


## Pass the data through ``to_tensor``
Since we have access to the model class, we can call the ``to_tensor`` method with example data.
This allows easy checking that the output matches the expectations of the model architecture.

In this example, we expect ``to_tensor`` to return a tuple of (Tensor, int): specifically, a multi-channel image and a label.
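Where the expected channel count is known, this expectation can be made explicit. The helper below is hypothetical (``check_to_tensor_output`` is not part of Hyrax); it only relies on ``.shape``, so it works for a ``torch.Tensor`` or a NumPy array, and the example uses NumPy as a stand-in for the tensor.

```python
import numpy as np


def check_to_tensor_output(result, expected_channels):
    """Hypothetical sanity check for a ``to_tensor`` result.

    Assumes ``result`` is ``(image, label)`` where ``image`` has shape (C, H, W)
    and ``label`` is an int. Raises ValueError on any mismatch.
    """
    if not (isinstance(result, tuple) and len(result) == 2):
        raise ValueError(f"Expected a 2-tuple, got {type(result).__name__}")
    image, label = result
    if len(image.shape) != 3:
        raise ValueError(f"Expected a (C, H, W) image, got shape {tuple(image.shape)}")
    if image.shape[0] != expected_channels:
        raise ValueError(f"Expected {expected_channels} channels, got {image.shape[0]}")
    if not isinstance(label, int):
        raise ValueError(f"Expected an int label, got {type(label).__name__}")
    return True


# A NumPy stand-in for the (Tensor, int) pair produced by the default to_tensor.
example = (np.zeros((4, 32, 32), dtype=np.float32), 5)
check_to_tensor_output(example, expected_channels=4)
```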

In [11]:
res = m.to_tensor(samp)
print(f"Type and shape of resulting image: {type(res[0])}, {res[0].shape}")
print(f"Type and shape of the label: {type(res[1])}, {res[1]}")

Type and shape of resulting image: <class 'torch.Tensor'>, torch.Size([4, 32, 32])
Type and shape of the label: <class 'int'>, 5


## Updating ``to_tensor``
The default implementation of ``to_tensor`` only makes use of "cifar_0" and "rando".
When experimenting, however, we don't want to make code changes in the model class;
it is much easier to experiment in the notebook.
Here, we redefine the ``to_tensor`` method and check the result by running sample data through it.

In [12]:
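import numpy as np
import torch


@staticmethod
def to_tensor(data_dict):
    """Convert the structured data from ``DataProvider`` into the input the model needs.

    Parameters
    ----------
    data_dict : dict
        The dictionary returned from our data source, keyed by dataset friendly name.

    Returns
    -------
    tuple
        ``(image, label)``, where ``image`` is a ``torch.Tensor`` and ``label`` comes
        from "cifar_0" (``None`` if "cifar_0" provides no label).
    """
    cifar_data = data_dict.get("cifar_0", {})
    random_data = data_dict.get("rando", {})
    more_cifar_data = data_dict.get("cifar_1", {})

    # Default to None so `label` and `image` are always defined.
    label = cifar_data.get("label")
    image = None

    if "image" in cifar_data and "image" in random_data:
        # Stack cifar_0, rando, and (when present) cifar_1 along the channel axis.
        images = [cifar_data["image"], random_data["image"]]
        if "image" in more_cifar_data:
            images.append(more_cifar_data["image"])
        # Unbatched images are (C, H, W); batched images are (N, C, H, W).
        stack_dim = 0 if images[0].ndim == 3 else 1
        image = torch.from_numpy(np.concatenate(images, axis=stack_dim))
    elif "image" in cifar_data:
        image = torch.from_numpy(cifar_data["image"])
    elif "image" in random_data:
        image = torch.from_numpy(random_data["image"])

    return (image, label)


# `m` is the model class, so `@staticmethod` keeps the function from being bound as a method.
m.to_tensor = to_tensor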
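Why the ``@staticmethod`` decorator? The comment in the first cell says ``h.model()`` returns the model *class*. Assuming that, assigning a plain function to a class attribute turns it into an ordinary method, so if the model calls ``to_tensor`` on an instance, Python passes the instance as ``data_dict``. The toy ``Model`` class below is hypothetical, not Hyrax code; it only demonstrates standard Python binding rules.

```python
class Model:
    """Stand-in for the model *class* returned by ``h.model()``."""


def plain(data_dict):
    return data_dict


@staticmethod
def wrapped(data_dict):
    return data_dict


Model.plain = plain      # becomes an ordinary method
Model.wrapped = wrapped  # stays a static method

instance = Model()
print(instance.wrapped({"a": 1}))  # {'a': 1}

try:
    instance.plain({"a": 1})  # Python passes `instance` as data_dict, plus the dict
except TypeError as err:
    print("TypeError:", err)
```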

After running the same sample through as before, we can see that the number of channels
in the image has changed (from 4 to 7), while all the other values have remained the same.
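The channel arithmetic behind that change can be checked with zero-filled NumPy arrays of the per-sample shapes printed earlier. This assumes, as the 4-channel result suggests, that the default ``to_tensor`` concatenates the "cifar_0" and "rando" images along the channel axis; the arrays are placeholders, not real data.

```python
import numpy as np

# Zero-filled stand-ins with the per-sample shapes printed earlier.
cifar_0_image = np.zeros((3, 32, 32), dtype=np.float32)
rando_image = np.zeros((1, 32, 32), dtype=np.float32)
cifar_1_image = np.zeros((3, 32, 32), dtype=np.float32)

# Default to_tensor: cifar_0 + rando -> 3 + 1 = 4 channels.
default_stack = np.concatenate([cifar_0_image, rando_image], axis=0)

# Redefined to_tensor: cifar_0 + rando + cifar_1 -> 3 + 1 + 3 = 7 channels.
new_stack = np.concatenate([cifar_0_image, rando_image, cifar_1_image], axis=0)

# With a batch dimension (N, C, H, W), the channel axis is 1 instead of 0.
batch = np.concatenate([np.zeros((8, 3, 32, 32)), np.zeros((8, 1, 32, 32))], axis=1)

print(default_stack.shape, new_stack.shape, batch.shape)  # (4, 32, 32) (7, 32, 32) (8, 4, 32, 32)
```

The batched case is why the redefined ``to_tensor`` picks ``stack_dim`` from the number of dimensions.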

In [13]:
new_res = m.to_tensor(samp)
print(f"Type and shape of resulting image: {type(new_res[0])}, {new_res[0].shape}")
print(f"Type and shape of the label: {type(new_res[1])}, {new_res[1]}")

Type and shape of resulting image: <class 'torch.Tensor'>, torch.Size([7, 32, 32])
Type and shape of the label: <class 'int'>, 5


## Train with this model
Now that we've seen that the ``to_tensor`` method returns a reasonable form of data, we can train our model.
As before, we call ``h.train()``.
While it is quite verbose, the initialization logging shows that the model instance is created with data from
the ``DataProvider`` class, and that our new implementation of ``to_tensor`` is used to convert
the data from ``DataProvider`` into a form that our model architecture accepts.

In [14]:
h.train()

          0.19215691,  0.16078436],
        [-0.8745098 , -1.        , -0.85882354, ..., -0.03529412,
         -0.06666666, -0.04313725],
        [-0.8039216 , -0.8745098 , -0.6156863 , ..., -0.0745098 ,
         -0.05882353, -0.14509803],
        ...,
        [ 0.6313726 ,  0.5764706 ,  0.5529412 , ...,  0.254902  ,
         -0.56078434, -0.58431375],
        [ 0.41176474,  0.35686278,  0.45882356, ...,  0.4431373 ,
         -0.23921567, -0.3490196 ],
        [ 0.38823533,  0.3176471 ,  0.4039216 , ...,  0.69411767,
          0.18431377, -0.03529412]],

       [[-0.5137255 , -0.6392157 , -0.62352943, ...,  0.03529418,
         -0.01960784, -0.02745098],
        [-0.84313726, -1.        , -0.9372549 , ..., -0.3098039 ,
         -0.3490196 , -0.31764704],
        [-0.8117647 , -0.94509804, -0.7882353 , ..., -0.34117645,
         -0.34117645, -0.42745095],
        ...,
        [ 0.33333337,  0.20000005,  0.26274514, ...,  0.04313731,
         -0.75686276, -0.73333335],
        [ 0.090196

  2%|1         | 1/59 [00:00<?, ?it/s]

[2025-07-11 17:24:59,693 hyrax.pytorch_ignite:INFO] Total training time: 15.77[s]
[2025-07-11 17:24:59,693 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /Users/drew/code/hyrax/docs/pre_executed/results/20250711-172436-train-YSo5/checkpoint_epoch_1.pt
[2025-07-11 17:24:59,694 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /Users/drew/code/hyrax/docs/pre_executed/results/20250711-172436-train-YSo5/checkpoint_1_loss=-951.7826.pt
2025/07/11 17:24:59 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/07/11 17:24:59 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-07-11 17:24:59,709 hyrax.verbs.train:INFO] Finished Training
[2025-07-11 17:25:00,374 hyrax.model_exporters:INFO] Exported model to ONNX format: /Users/drew/code/hyrax/docs/pre_executed/results/20250711-172436-train-YSo5/example_model_opset_20.onnx


HyraxAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(7, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): GELU(approximate='none')
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): GELU(approximate='none')
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (5): GELU(approximate='none')
    (6): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): GELU(approximate='none')
    (8): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (9): GELU(approximate='none')
    (10): Flatten(start_dim=1, end_dim=-1)
    (11): Linear(in_features=1024, out_features=64, bias=True)
  )
  (dec_linear): Sequential(
    (0): Linear(in_features=64, out_features=1024, bias=True)
    (1): GELU(approximate='none')
  )
  (decoder): Sequential(
    (0): ConvTranspose2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    (1): GELU(approximate='none')
    (2): 
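The printed encoder also shows how the 7 channels from the redefined ``to_tensor`` flow through the network: only the first ``Conv2d`` depends on the channel count, and each spatial size follows the standard PyTorch ``Conv2d`` output formula (with dilation 1). The sketch below recomputes the shapes from the layer parameters printed above; ``conv_out`` and the layer list are written out by hand for illustration and are not read from the model.

```python
def conv_out(size, kernel=3, stride=1, padding=1):
    """Spatial output size of a Conv2d: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (out_channels, stride) for each Conv2d in the encoder printed above.
encoder_convs = [(32, 2), (32, 1), (64, 2), (64, 1), (64, 2)]

size, channels = 32, 7  # 32x32 images with 7 channels from the redefined to_tensor
for out_channels, stride in encoder_convs:
    size = conv_out(size, kernel=3, stride=stride, padding=1)
    channels = out_channels
    print(f"Conv2d -> {channels} x {size} x {size}")

flattened = channels * size * size
print(f"Flatten -> {flattened}")  # 1024, matching Linear(in_features=1024, ...)
```

The three stride-2 layers take 32x32 down to 4x4, so the flattened size is 64 x 4 x 4 = 1024 regardless of how many input channels ``to_tensor`` produces.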

In [15]:
h.infer()


  1%|1         | 1/98 [00:00<?, ?it/s]

[2025-07-11 17:25:46,985 hyrax.pytorch_ignite:INFO] Total evaluation time: 15.79[s]
[2025-07-11 17:25:47,042 hyrax.verbs.infer:INFO] Inference Complete.


<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x1588a7950>

## Things get not-so-great after this
From here, we really need to consider how we save state.
We need to make sure that the datasets collected in ``DataProvider`` are copied to the config,
and we need to figure out how to recreate the current ``DataProvider`` from that config.
We also need to decide when to use the contents of the persisted config file versus
the ``data`` attribute in the model class.
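One possible approach to the first two points, sketched under stated assumptions: copy each entry of the model's ``data`` dictionary into its own dotted config section (mirroring names such as ``data_set.random_dataset`` used in the first cell), then rebuild ``data`` from those sections. The ``model_data.`` prefix and the helper names are hypothetical, JSON stands in for Hyrax's TOML config file, and the sketch does not settle which source should take precedence.

```python
import copy
import json

PREFIX = "model_data."  # hypothetical config-section prefix


def data_to_config(data):
    """Copy each attached dataset into its own (hypothetical) config section."""
    return {PREFIX + name: copy.deepcopy(cfg) for name, cfg in data.items()}


def config_to_data(config):
    """Rebuild the `data` dictionary from sections written by `data_to_config`."""
    return {
        key[len(PREFIX):]: copy.deepcopy(cfg)
        for key, cfg in config.items()
        if key.startswith(PREFIX)
    }


data = {
    "cifar_0": {
        "dataset_class": "HyraxCifarDataSet",
        "data_directory": "path/to/cifar/dataset",
        "fields": ["image", "label", "object_id"],
        "primary_id_field": "object_id",
    },
    "rando": {"dataset_class": "HyraxRandomDataset", "data_directory": "/fake/dir", "fields": ["image"]},
}

# Simulate writing the config to disk and reading it back.
persisted = json.loads(json.dumps({"train": {"epochs": 1}, **data_to_config(data)}))
restored = config_to_data(persisted)
```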

In [14]:
h.umap()

[2025-07-11 16:53:49,256 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /Users/drew/code/hyrax/docs/pre_executed/results/20250711-165314-infer-RP-f for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-07-11 16:53:56,875 hyrax.verbs.umap:INFO] Saving UMAP results to /Users/drew/code/hyrax/docs/pre_executed/results/20250711-165356-umap-l_GX
[2025-07-11 16:53:57,118 hyrax.verbs.umap:INFO] Fitting the UMAP
OMP: Info #276: omp_set_nested routine deprecated, please use omp_set_max_active_levels instead.
[2025-07-11 16:54:01,893 hyrax.verbs.umap:INFO] Saving fitted UMAP Reducer


Creating lower dimensional representation using UMAP::   0%|          | 0/98 [00:00<?, ?it/s]

[2025-07-11 16:54:44,461 hyrax.verbs.umap:INFO] Finished transforming all data through UMAP


<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x3f91f26f0>

In [3]:
h.visualize()

[2025-07-11 17:10:05,902 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /Users/drew/code/hyrax/docs/pre_executed/results/20250711-165356-umap-l_GX for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.


URLError: <urlopen error [Errno 60] Operation timed out>