# Hyrax Demonstration

For this demonstration we'll walk through a simplified version of a typical machine learning workflow supported by Hyrax.

In [1]:
import hyrax
import pooch
import subprocess

import chromadb
import numpy as np
import matplotlib.pyplot as plt

from pathlib import Path
from IPython.display import IFrame

from hyrax.config_utils import find_most_recent_results_dir
from mpr_demo_plotting import sort_objects_by_median_distance, plot_grid, plot_umap

## Download a sample HSC dataset

In [None]:
file_path = pooch.retrieve(
    # DOI for Example HSC dataset
    url="doi:10.5281/zenodo.14498536/hsc_demo_data.zip",
    known_hash="md5:1be05a6b49505054de441a7262a09671",
    fname="example_hsc_new.zip",
    path="../../data",
    processor=pooch.Unzip(extract_dir="."),
)

/Users/drew/code/hyrax/data/./example_hsc_new.zip


This dataset consists of 993 cutouts from the Hyper Suprime-Cam (HSC) survey.
Each cutout includes the i, r, and g bands and is 8 arcseconds on a side.


## Create and configure a Hyrax object

In [3]:
h = hyrax.Hyrax()

[2025-05-06 15:26:07,810 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


An instance of the `Hyrax` class is used throughout this demo.
When it is created, it will:
-  Load the specified configuration file (here, the built-in default).
-  Parse the configuration file for external libraries and add those to the appropriate registries.
-  Prepare logging for the system.

In [4]:
# Specify the location of the data to use for training
h.config["general"]["data_dir"] = "../../data/hsc_8asec_1000"

# Specify the dataset class that represents the data
h.config["data_set"]["name"] = "HSCDataSet"
h.config["data_set"]["train_size"] = 0.8
h.config["data_set"]["validate_size"] = 0.2
h.config["data_set"]["test_size"] = 0.0

# Select the model to use for training
h.config["model"]["name"] = "HyraxAutoencoder"

# Set the number of epochs and batch size for training.
h.config["train"]["epochs"] = 20
h.config["data_loader"]["batch_size"] = 32

The default configuration needs a few tweaks to work for this demo.
In the cell above, we point Hyrax at the sample data, choose the dataset class and its splits, select the model to train, and set the number of epochs and the batch size.

The configuration is represented as a nested Python dictionary. This allows for easy manipulation in a notebook via the `.config` attribute of the Hyrax instance.
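
Because the configuration behaves like a nested dictionary, ordinary dictionary operations are enough to inspect or sanity-check it. The sketch below uses a plain `dict` as a stand-in for `h.config` so that it runs without Hyrax installed. The helper `check_splits` is hypothetical; this demo does not show whether Hyrax performs a similar check itself.

```python
import math

# Plain dict standing in for h.config; keys and values mirror the cell above.
config = {
    "data_set": {
        "name": "HSCDataSet",
        "train_size": 0.8,
        "validate_size": 0.2,
        "test_size": 0.0,
    },
    "model": {"name": "HyraxAutoencoder"},
}


def check_splits(cfg):
    """Return True if the three split fractions sum to 1 (hypothetical helper)."""
    data_set = cfg["data_set"]
    total = data_set["train_size"] + data_set["validate_size"] + data_set["test_size"]
    return math.isclose(total, 1.0)


print(check_splits(config))
```

A check like this can be useful before switching the splits for inference later in the demo, where all three values change at once.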

## Train a model

In [None]:
h.train()

When we call `.train()` to train the model there's a lot going on under the hood:
- The model is automatically loaded onto the fastest hardware available.
- A data loader is instantiated and configured to load batches of data to the same hardware.
- A new timestamped directory is created under the configured results directory where all output is saved.
- The configuration becomes immutable and a copy is saved for reproducibility.
- The model and system metrics start being logged for review in both TensorBoard and MLFlow.
- Checkpoints are saved automatically both at the last epoch and at the epoch with the lowest loss value.
- Finally the model weights file is saved.
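
To make the bookkeeping steps above more concrete, here is a minimal sketch of creating a timestamped run directory and saving a copy of the configuration in it. It is not Hyrax's implementation. The helper `make_run_dir` is hypothetical, the directory name only imitates the `YYYYMMDD-HHMMSS-verb-XXXX` pattern visible in the log output of this notebook, and the copy is written as JSON rather than TOML purely to stay within the standard library.

```python
import json
import secrets
import tempfile
from datetime import datetime
from pathlib import Path


def make_run_dir(results_root, verb, config):
    """Create a timestamped run directory and save a config copy (hypothetical helper)."""
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    suffix = secrets.token_hex(2)  # short random tag to avoid name collisions
    run_dir = Path(results_root) / f"{stamp}-{verb}-{suffix}"
    run_dir.mkdir(parents=True)
    # Assumption: JSON is used here only to avoid a third-party TOML writer.
    (run_dir / "runtime_config.json").write_text(json.dumps(config, indent=2))
    return run_dir


# Use a temporary directory so the sketch leaves nothing behind.
with tempfile.TemporaryDirectory() as root:
    run_dir = make_run_dir(root, "train", {"train": {"epochs": 20}})
    saved = json.loads((run_dir / "runtime_config.json").read_text())
    print(run_dir.name)
    print(saved["train"]["epochs"])
```

Keeping the configuration copy next to the outputs is what allows a run to be repeated later from the saved file alone.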

Training time depends heavily on the available hardware, the model, and the training parameters.
As a point of reference, training takes about 40 seconds with the following setup:
- Model trained: Hyrax autoencoder
- Dataset and size: Example HSC data, 993 samples, 96x96 pixel cutouts
- Number of epochs: 20
- Batch size: 32
- Hardware: Desktop with GTX 1660 Super GPU

While we train on only about 1000 samples here, Hyrax training has scaled up to over 1M samples on an HPC system with access to multiple GPUs without requiring the user to make any code changes.
To do so, the command line interface of Hyrax was used to work within a Slurm environment like so:
```
>> hyrax train --runtime-config ./results/<timestamped_directory>/runtime_config.toml
```

## Running inference

In [None]:
# Update the data set splits to be 100% test data
h.config["data_set"]["test_size"] = 1.0
h.config["data_set"]["train_size"] = 0.0
h.config["data_set"]["validate_size"] = 0.0

# Increase batch size for faster inference
h.config["data_loader"]["batch_size"] = 512

# Run inference
h.infer()

For this demo, we'll pretend that of all the models we trained, the last one performed best.
We'll now use that model to run inference.
Note that by default, Hyrax will find the weights of the last successfully trained model for inference, but of course, a different set of weights can be specified in the configuration.
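
One way to picture the "most recent run" lookup is to sort timestamped directory names. The notebook imports `find_most_recent_results_dir` from `hyrax.config_utils`, but its signature is not shown here, so the sketch below uses a hypothetical stand-in, `most_recent_run`. The two `train` directory names are made up for illustration; only the `infer` name appears in this notebook's logs.

```python
import tempfile
from pathlib import Path


def most_recent_run(results_root, verb):
    """Return the newest run directory for a verb, or None (hypothetical helper)."""
    # Names start with YYYYMMDD-HHMMSS, so lexicographic order is chronological.
    candidates = sorted(
        path
        for path in Path(results_root).iterdir()
        if path.is_dir() and f"-{verb}-" in path.name
    )
    return candidates[-1] if candidates else None


with tempfile.TemporaryDirectory() as root:
    for name in (
        "20250506-134500-train-ab12",  # made-up name
        "20250506-134633-infer-uEOR",
        "20250506-120000-train-cd34",  # made-up name
    ):
        (Path(root) / name).mkdir()
    latest = most_recent_run(root, "train")
    print(latest.name)  # 20250506-134500-train-ab12
```

This sketch considers only the directory name. It does not check whether the run finished successfully, which the default lookup described above would need to do.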

First we make a small update to the data set splits, setting `test_size` to 100% and the other splits to 0%.
We also increase the batch size in order to make better use of the available GPU memory.

Finally we run inference over the dataset using the trained model weights with `h.infer()`.
As with training, Hyrax is doing a lot behind the scenes on behalf of the user including:
- Identifying and using the most performant hardware available.
- Creating a new timestamped directory where all output is saved.
- Freezing the configuration and saving a copy for reproducibility.
- Saving the results of inference in batched .npy files.
- Optionally persisting the results to a vector database.
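
As a quick worked example of the batching, the arithmetic below shows how many batches 993 samples produce at a batch size of 512. It assumes the results are grouped one-to-one with data loader batches, which this demo does not state explicitly, although the `index` verb later reports two inference result batches.

```python
import math

n_samples = 993   # objects reported by the dataset logs
batch_size = 512  # value set in the cell above

n_batches = math.ceil(n_samples / batch_size)
last_batch = n_samples - (n_batches - 1) * batch_size

print(n_batches, last_batch)  # 2 481
```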

Again, although we compute latent representations for only about 1000 samples here, Hyrax inference has scaled up to over 1M samples on an HPC system with access to multiple GPUs without requiring any code changes.

## Examine an embedding

In [17]:
h.umap()

[2025-05-06 15:34:17,116 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /Users/drew/code/hyrax/docs/pre_executed/results/20250506-134633-infer-uEOR for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-05-06 15:34:17,388 hyrax.data_sets.hsc_data_set:INFO] Checking file dimensions to determine standard cutout size...
[2025-05-06 15:34:17,391 hyrax.data_sets.fits_image_dataset:INFO] FitsImageDataSet has 993 objects
[2025-05-06 15:34:17,403 hyrax.data_sets.hsc_data_set:INFO] Processed 993 objects for pruning
[2025-05-06 15:34:17,404 hyrax.verbs.umap:INFO] Saving UMAP results to /Users/drew/code/hyrax/docs/pre_executed/results/20250506-153417-umap-U4p6

Creating lower dimensional representation using UMAP::   0%|          | 0/32 [00:00<?, ?it/s]

[2025-05-06 15:34:24,532 hyrax.verbs.umap:INFO] Finished transforming all data through UMAP
[2025-05-06 15:34:24,579 hyrax.data_sets.hsc_data_set:INFO] Checking file dimensions to determine standard cutout size...
[2025-05-06 15:34:24,581 hyrax.data_sets.fits_image_dataset:INFO] FitsImageDataSet has 993 objects
[2025-05-06 15:34:24,593 hyrax.data_sets.hsc_data_set:INFO] Processed 993 objects for pruning


<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x16c509130>

Here we use the output of the latest inference run to fit a UMAP model and then plot the resulting lower-dimensional space.
With only about 1000 samples, it's hard to visually identify obvious groupings in the 2D space.

While the `umap-learn` API makes fitting and transforming fairly straightforward, Hyrax provides the surrounding data plumbing.
The files written by the inference step are laid out so that UMAP can fit and transform them efficiently, and Hyrax handles reading the inference results and writing the UMAP output.

To support further exploration of the embedding space, Hyrax includes an early implementation of an interactive visualization tool.

## Interactive visualization

In [19]:
h.config["visualize"]["fields"] = ["ra", "dec"]
viz = h.visualize(width=1000, height=1000)

[2025-05-06 15:34:54,864 hyrax.data_sets.inference_dataset:INFO] Using most recent results dir /Users/drew/code/hyrax/docs/pre_executed/results/20250506-153417-umap-U4p6 for lookup. Use the [results] inference_dir config to set a directory or pass it to this verb.
[2025-05-06 15:34:54,910 hyrax.data_sets.hsc_data_set:INFO] Checking file dimensions to determine standard cutout size...
[2025-05-06 15:34:54,912 hyrax.data_sets.fits_image_dataset:INFO] FitsImageDataSet has 993 objects
[2025-05-06 15:34:54,924 hyrax.data_sets.hsc_data_set:INFO] Processed 993 objects for pruning


The Hyrax visualization tooling uses HoloViews and Datashader, along with an efficient tree structure, to display millions of points.
It supports panning and zooming as well as lasso and box selections.
When points are selected, the resulting object ids are displayed in the associated table.

While this is an early version of interactive visualization, it has been scaled up to millions of data points.
The next steps for this tooling will be to support deeper interactivity, namely:
-  Automatically displaying the object selected in the table
-  Leveraging the vector db to identify similar objects
-  Supporting three dimensional UMAP output


This visualization runs in a notebook, but when the notebook is rendered to HTML (for demonstration or documentation), the server backing the interactive visual isn't packaged with the rendering. If the cell above were run locally, the resulting UI would look similar to the following screenshot.

![umap_visualization.JPG](attachment:umap_visualization.JPG)


## Create a vector database
By calling the `index` verb, we can populate a vector database with the results of inference. This vector database can be used for efficient similarity or nearest neighbor searches.

In [9]:
h.index()

[2025-05-06 15:29:39,470 hyrax.data_sets.hsc_data_set:INFO] Checking file dimensions to determine standard cutout size...
[2025-05-06 15:29:39,472 hyrax.data_sets.fits_image_dataset:INFO] FitsImageDataSet has 993 objects
[2025-05-06 15:29:39,484 hyrax.data_sets.hsc_data_set:INFO] Processed 993 objects for pruning
[2025-05-06 15:29:39,544 hyrax.verbs.vdb_index:INFO] Number of inference result batches to index: 2.
100%|██████████| 2/2 [00:00<00:00, 28.52it/s]
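
Conceptually, a nearest neighbor query returns the stored vectors that lie closest to a query vector. The sketch below shows that idea with a brute-force search in plain Python. It is not how a vector database works internally, since such databases typically rely on an index to avoid comparing against every stored vector. The object ids, the 3-dimensional vectors, and the helper `nearest_neighbors` are all made up for illustration.

```python
import math


def nearest_neighbors(query, vectors, k=2):
    """Return ids of the k stored vectors closest to query (Euclidean distance)."""
    ranked = sorted(vectors, key=lambda object_id: math.dist(query, vectors[object_id]))
    return ranked[:k]


# Made-up ids and vectors standing in for latent-space embeddings.
vectors = {
    "obj_a": (0.0, 0.0, 1.0),
    "obj_b": (0.9, 0.1, 0.0),
    "obj_c": (0.1, 0.0, 0.9),
    "obj_d": (1.0, 1.0, 1.0),
}

print(nearest_neighbors((0.0, 0.0, 1.0), vectors, k=2))  # ['obj_a', 'obj_c']
```

Brute force is fine for roughly 1000 objects, but its cost grows linearly with the number of stored vectors, which is the motivation for indexing the inference results.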


## Get an instance of the dataset

In [20]:
hsc_dataset = h.prepare()

[2025-05-06 15:35:02,379 hyrax.data_sets.hsc_data_set:INFO] Checking file dimensions to determine standard cutout size...
[2025-05-06 15:35:02,381 hyrax.data_sets.fits_image_dataset:INFO] FitsImageDataSet has 993 objects
[2025-05-06 15:35:02,392 hyrax.data_sets.hsc_data_set:INFO] Processed 993 objects for pruning
[2025-05-06 15:35:02,393 hyrax.data_sets.fits_image_dataset:INFO] Preloading FitsImageDataSet cache...
[2025-05-06 15:35:02,393 hyrax.prepare:INFO] Finished Prepare


[2025-05-06 15:35:05,557 hyrax.data_sets.fits_image_dataset:INFO] Processed 992 objects
