TODO: Add some `pooch` work here to pull down the example files and stash them in the user's local cache.
TODO: Massage the data so that it's in the correct locations. My hunch is that we can place all the files together in
the same directory, and that we'll need to concatenate the parameters.fits files together.
The data files are sorted alphabetically, so we need to be careful about how we concatenate the parameters.fits files
so that the row ordering is maintained and stays aligned with the sorted data files.

In [1]:
from hyrax import Hyrax

# Create the Hyrax entry point used by every later cell; the log line below
# shows it reading the default runtime config.
h = Hyrax()

[2025-08-29 13:34:36,744 hyrax:INFO] Runtime Config read from: /Users/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


In [2]:
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
from torch import from_numpy
from hyrax.data_sets import HyraxDataset
from astropy.io import fits
from astropy.table import Table
from collections.abc import Generator


"""
Note - `index` in this class refers to the row index in the parameters.fits file.
"""


class SLDataset(HyraxDataset, Dataset):
    """Lens / non-lens image dataset backed by a flat directory of FITS files.

    ``index`` everywhere in this class is the row index into ``parameters.fits``.
    Every object is stored as ``BANDS_PER_OBJECT`` consecutive files in the
    alphabetically sorted listing of image files, so object ``index`` owns the
    files starting at position ``index * BANDS_PER_OBJECT`` in that listing.
    """

    # Number of image files that are stacked to form one object.
    BANDS_PER_OBJECT = 5
    # Height and width of a single image plane.
    IMAGE_SHAPE = (41, 41)
    # Matches the image files but not ``parameters.fits`` (which has no
    # underscore), so the metadata file can never shift the index arithmetic.
    IMAGE_GLOB = "*_*.fits"
    # Marker in an image file name that identifies a lens.
    LENS_MARKER = "_L_"

    def __init__(self, config: dict):
        """Read the metadata table and the sorted image file list.

        Parameters
        ----------
        config : dict
            Hyrax runtime config; ``config["general"]["data_dir"]`` must point
            at the directory holding the image files and ``parameters.fits``.
        """
        super().__init__(config)
        self.data_directory = Path(config["general"]["data_dir"])
        self.metadata = self.read_metadata(self.data_directory)
        self.filepaths = self.read_filepaths(self.data_directory)

    def get_image(self, index: int):
        """Return the stacked image for object ``index``.

        The result is a float32 torch tensor of shape
        ``(BANDS_PER_OBJECT, *IMAGE_SHAPE)``.
        """
        image_stack = np.zeros((self.BANDS_PER_OBJECT, *self.IMAGE_SHAPE), dtype=np.float32)
        first_file = index * self.BANDS_PER_OBJECT
        for band in range(self.BANDS_PER_OBJECT):
            raw_data = fits.getdata(self.filepaths[first_file + band], memmap=False)
            # NOTE(review): the image is read from row 0, field 1 of the FITS
            # data - presumably a table HDU; confirm against the data files.
            image_stack[band] = raw_data[0][1]

        return from_numpy(image_stack)

    def get_label(self, index: int) -> "torch.Tensor":
        """Return the one-hot label: ``[1, 0]`` for a lens, ``[0, 1]`` otherwise.

        An object is a lens when its first file's *name* contains
        ``LENS_MARKER``. Only the file name is checked so that a directory
        name containing the marker cannot mislabel every object.
        """
        first_file = self.filepaths[index * self.BANDS_PER_OBJECT]
        if self.LENS_MARKER in first_file.name:
            one_hot = np.array([1.0, 0.0], dtype=np.float32)
        else:
            one_hot = np.array([0.0, 1.0], dtype=np.float32)
        return from_numpy(one_hot)

    def get_object_id(self, index: int) -> str:
        """Return the ID for metadata row ``index``.

        "Lens ID" is preferred, then "Object ID". A column that is absent, or
        whose value renders as ``"--"``, is skipped. ``"no_id"`` is returned
        when neither column yields a value.
        """
        for column in ("Lens ID", "Object ID"):
            if column in self.metadata.columns:
                value = str(self.metadata[index][column])
                if value != "--":
                    return value
        return "no_id"

    def get_filename(self, index: int) -> str:
        """Return the name of the first image file belonging to object ``index``."""
        return str(self.filepaths[index * self.BANDS_PER_OBJECT].name)

    def read_filepaths(self, data_directory: Path):
        """Return the alphabetically sorted image files in ``data_directory``."""
        return sorted(data_directory.glob(self.IMAGE_GLOB))

    def read_metadata(self, data_directory: Path):
        """Load ``parameters.fits`` and add a combined ``object_id`` column.

        ``object_id`` takes "Object ID" where "Lens ID" is masked and
        "Lens ID" otherwise.
        """
        table = Table.read(data_directory / "parameters.fits")
        # NOTE(review): assumes both ID columns exist and that "Lens ID" is a
        # masked column (has a ``.mask`` attribute) - confirm for new data.
        table["object_id"] = np.where(table["Lens ID"].mask, table["Object ID"], table["Lens ID"])
        return table

    def metadata_fields(self):
        """Return the columns of the metadata table."""
        return self.metadata.columns

    def ids(self) -> Generator[str, None, None]:
        """Yield the object ID of every object, in index order."""
        for index in range(len(self)):
            yield self.get_object_id(index)

    def __len__(self):
        """Return the number of complete objects among the image files.

        Uses the file list cached at construction rather than re-scanning the
        directory, so the length always agrees with what ``get_image`` indexes.
        """
        return len(self.filepaths) // self.BANDS_PER_OBJECT

    def __getitem__(self, index: int):
        """Return the image, label and object ID for object ``index`` as a dict."""
        return {
            "image": self.get_image(index),
            "label": self.get_label(index),  # [1, 0] == lens, [0, 1] == non-lens
            "object_id": self.get_object_id(index),
        }

In [3]:
# Directory that SLDataset reads the image files and parameters.fits from.
h.config["general"]["data_dir"] = "/Users/drew/sl_data_challenge/sl_100/hsc_combined"
# Alternate location of the same data on a different machine.
# h.config["general"]["data_dir"] = "/home/drew/data/sl_100/hsc_combined"
# Select the SLDataset class defined in the previous cell by name.
h.config["data_set"]["name"] = "SLDataset"

# `ds` is indexed like an SLDataset in the next cell.
ds = h.prepare()

[2025-08-29 13:34:39,613 hyrax.prepare:INFO] Finished Prepare


In [4]:
# Spot-check two samples; in the recorded output below index 5 is a lens and
# index 105 is a non-lens.
for sample_index in (5, 105):
    samp = ds[sample_index]
    print(f"ID: {samp['object_id']}")
    print(f"Is lens? {samp['label']}")
    print(f"Data shape: {samp['image'].shape}")

ID: 70360665344203068
Is lens? tensor([1., 0.])
Data shape: torch.Size([5, 41, 41])
ID: 41623412828827036
Is lens? tensor([0., 1.])
Data shape: torch.Size([5, 41, 41])


In [5]:
# Train the HyraxCNN model as a two-class classifier; two outputs match the
# two-element one-hot label produced by SLDataset.get_label.
h.config["model"]["name"] = "HyraxCNN"
h.config["model"]["hyrax_cnn"]["output_classes"] = 2
h.config["train"]["epochs"] = 30
h.config["data_loader"]["batch_size"] = 10

In [6]:
# Train with the config set above. The log below reports the criterion and
# optimizer used, the checkpoint paths and the ONNX export; the displayed
# return value is the trained model.
h.train()

[2025-08-29 13:34:40,376 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-08-29 13:34:40,377 hyrax.models.model_registry:INFO] Using optimizer: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
2025-08-29 13:34:40,390 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.SLDataset': 
	{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x15ce7dd60>, 'batch_size': 10, 'shuffle': False, 'pin_memory': False}
2025-08-29 13:34:40,391 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.SLDataset': 
	{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x30ef53590>, 'batch_size': 10, 'shuffle': False, 'pin_memory': False}
  from tqdm.autonotebook import tqdm
2025/08/29 13:34:40 INFO mlflow.system_metrics.system_metrics_monitor: Skip logging GPU metrics. Set logger level to DEBUG for more details.
2025/08/29 13:34:40

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

  8%|8         | 1/12 [00:00<?, ?it/s]

[2025-08-29 13:35:08,129 hyrax.pytorch_ignite:INFO] Total training time: 27.62[s]
[2025-08-29 13:35:08,129 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /Users/drew/code/hyrax/docs/pre_executed/results/20250829-133440-train-LAPy/checkpoint_epoch_30.pt
[2025-08-29 13:35:08,130 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /Users/drew/code/hyrax/docs/pre_executed/results/20250829-133440-train-LAPy/checkpoint_30_loss=-0.1448.pt
2025/08/29 13:35:08 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/08/29 13:35:08 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-08-29 13:35:08,141 hyrax.verbs.train:INFO] Finished Training
[2025-08-29 13:35:08,247 hyrax.model_exporters:INFO] Exported model to ONNX format: /Users/drew/code/hyrax/docs/pre_executed/results/20250829-133440-train-LAPy/example_model_opset_20.onnx


HyraxCNN(
  (conv1): Conv2d(5, 16, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(16, 64, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=3136, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
  (criterion): CrossEntropyLoss()
)

In [7]:
# Run inference. Per the log below, results are saved to a timestamped
# results directory, and the displayed return value is an InferenceDataSet.
h.infer()

[2025-08-29 13:35:10,269 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-08-29 13:35:10,270 hyrax.models.model_registry:INFO] Using optimizer: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2025-08-29 13:35:10,273 hyrax.verbs.infer:INFO] data set has length 200
2025-08-29 13:35:10,279 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.SLDataset': 
	{'sampler': None, 'batch_size': 10, 'shuffle': False, 'pin_memory': False}
[2025-08-29 13:35:10,311 hyrax.verbs.infer:INFO] Saving inference results at: /Users/drew/code/hyrax/docs/pre_executed/results/20250829-133510-infer-o919
[2025-08-29 13:35:10,438 hyrax.pytorch_ignite:INFO] Evaluating model on device: mps
[2025-08-29 13:35:10,439 hyrax.pytorch_ignite:INFO] Total epochs: 1


Object IDs provided: ['75339120850791805', '74643937444253184', '44223324036811477', '69581790204943781', '69599657268895544', '70360665344203068', '37489781684342465', '69599103218113191', '69612439091552763', '69573677011698420']


  5%|5         | 1/20 [00:00<?, ?it/s]

Object IDs provided: ['69577400748364220', '76552951623077653', '74643928854324513', '69622059818316703', '70387049328282573', '70386783040321743', '69572856672967091', '70347471204663526', '69609007412698822', '75954546714644652']
Object IDs provided: ['70391696482918645', '69573140140808789', '70351431164520485', '69626599598751766', '70413553571486247', '69581648471021172', '69595546985187065', '69612980257449824', '37489644245381543', '37484705032990087']
Object IDs provided: ['70364758448036762', '37489115964407754', '69590453153975278', '69586454539427150', '44218243090502299', '69572860967930922', '42697201897461701', '69585913373548053', '70364779922872457', '42089601464029668']
Object IDs provided: ['75954417865623042', '69608315922964236', '69582060787879216', '70356263002727004', '69608461951854113', '70413420427502463', '41619693387150398', '69590320009988197', '70373438576938421', '69568737799330651']
Object IDs provided: ['69586196841390155', '44223324036819347', '6960873

[2025-08-29 13:35:11,562 hyrax.pytorch_ignite:INFO] Total evaluation time: 1.12[s]


Object IDs provided: ['45841255397072618', '42032968025251216', '75334306192459513', '42692679296903886', '41623421418757619', '41623949699747469', '41206139576145641', '42714660939523783', '41619401329368928', '75339511692818313']
Object IDs provided: ['41623554562744681', '75954825887510208', '43158464205195928', '37485387932780723', '42692249800174635', '38548886259791467', '41637281278220554', '37485533961667633', '41619143631342981', '41025561971152507']
Object IDs provided: ['75334297602524100', '43153911539866729', '43159151399957508', '37484700738029940', '45846082940332471', '41628369221081083', '43158863637138464', '42635358663376466', '38553834062110268', '41619697682119738']


[2025-08-29 13:35:11,591 hyrax.verbs.infer:INFO] Inference Complete.


<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x310140290>

In [15]:
from hyrax.data_sets import InferenceDataSet
import torch

# NOTE(review): this is the timestamped directory reported by the `h.infer()`
# log above; it has to be updated by hand whenever inference is re-run.
infer_dir = "/Users/drew/code/hyrax/docs/pre_executed/results/20250829-133510-infer-o919"
infer_ds = InferenceDataSet(h.config, infer_dir)

# Tallies of predictions, and of predictions compared against the true labels.
lens_count = 0
nonlens_count = 0
true_pos = 0
true_neg = 0
false_pos = 0
false_neg = 0
# NOTE(review): assumes `infer_ds` and `ds` enumerate objects in the same
# order, so position `indx` refers to the same object in both - confirm.
# The ID itself is unused; it is named `_object_id` so the builtin `id` is not
# shadowed for the rest of the notebook.
for indx, _object_id in enumerate(infer_ds.ids()):
    # Index of the larger of the two model outputs: 0 == lens, 1 == non-lens.
    _, predicted = torch.max(infer_ds[indx], 0)
    predicted_lens = bool(predicted == 0)

    # Labels are one-hot: [1, 0] == lens, [0, 1] == non-lens.
    label = ds[indx]["label"]
    is_lens = bool(label[0] == 1.0)

    if predicted_lens:
        lens_count += 1
        if is_lens:
            true_pos += 1
        else:
            false_pos += 1
    else:
        nonlens_count += 1
        if is_lens:
            false_neg += 1
        else:
            true_neg += 1

print(f"Total Lens: {lens_count}, Total Non-Lens: {nonlens_count}")
print(f"True Positives: {true_pos}, True Negatives: {true_neg}, False Positives: {false_pos}, False Negatives: {false_neg}")

Total Lens: 81, Total Non-Lens: 119
True Positives: 21, True Negatives: 40, False Positives: 60, False Negatives: 79


In [None]:
# Switch the model to HyraxAutoencoder and train again with the rest of the
# config unchanged. (This cell and the ones after it have not been executed.)
h.config["model"]["name"] = "HyraxAutoencoder"
h.train()

In [None]:
# Run inference again; presumably this uses the autoencoder trained in the
# previous cell - confirm which checkpoint Hyrax picks up.
h.infer()

In [None]:
from hyrax.data_sets import InferenceDataSet

# NOTE(review): this path names a results directory from a different run
# (20250828) than the inference shown earlier in this notebook; point it at
# the directory logged by the `h.infer()` call above before executing.
infer_dir = "/Users/drew/code/hyrax/docs/pre_executed/results/20250828-133458-infer-S_IB"
infer_ds = InferenceDataSet(h.config, infer_dir)
# Display the inference result stored at position 1.
infer_ds[1]

In [None]:
import numpy as np

# Load one raw batch file from the inference directory (`infer_dir` is set in
# an earlier cell) and display its contents.
d = np.load(f"{infer_dir}/batch_2.npy")
d

In [None]:
# Load the batch index file from the same inference directory and display
# entry 1. Presumably it maps each object to the batch file holding its
# result - confirm against the Hyrax docs.
batch_index = np.load(f"{infer_dir}/batch_index.npy")
batch_index[1]

In [None]:
# Presumably runs a UMAP projection over the latest inference results -
# confirm against the Hyrax docs.
h.umap()

In [None]:
# Presumably opens Hyrax's visualization of the results from the previous
# cell - confirm; the returned object is kept in `viz`.
viz = h.visualize()