TODO: Add some `pooch` work here to pull down the example files and stash them in the users local cache.
TODO: Massage the data so that it's in the correct locations. My hunch here is that we can place all the files together in
the same directory. And we'll need to concatenate the parameters.fits files together.
The data files are sorted alphabetically, so we need to be careful about the way that we concatenate the parameters.fits
so that the ordering is maintained.

In [1]:
from hyrax import Hyrax
h = Hyrax()

[2025-08-15 16:09:09,210 hyrax:INFO] Runtime Config read from: /home/drew/code/hyrax/src/hyrax/hyrax_default_config.toml


In [2]:
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
from torch import from_numpy
from hyrax.data_sets import HyraxDataset
from astropy.io import fits
from astropy.table import Table


"""
Note - `index` this class refers to the row index in the parameters.fits file.
"""
class SLDataset(HyraxDataset, Dataset):
    def __init__(self, config: dict):
        super().__init__(config)
        self.metadata = self.read_metadata(Path(config['general']['data_dir']))
        self.filepaths = self.read_filepaths(Path(config['general']['data_dir']))

    def get_image(self, index: int):
        image_stack = np.zeros((5, 41, 41), dtype=np.float32)
        for b in range(5):
            file_index = index * 5 + b
            raw_data = fits.getdata(self.filepaths[file_index], memmap=False)
            image_stack[b] = raw_data[0][1]

        return from_numpy(image_stack)

    def get_label(self, index: int):
        file_index = index * 5
        ret_value = np.array([0., 1.], dtype=np.float32)
        if "_L_" in str(self.filepaths[file_index]):
            ret_value = np.array([1., 0.], dtype=np.float32)
        return from_numpy(ret_value)

    def get_object_id(self, index: int) -> str:
        return str(self.metadata[index]['Lens ID'])

    def get_filename(self, index: int) -> str:
        file_index = index * 5
        return str(self.filepaths[file_index].name)

    def read_filepaths(self, data_directory: Path):
        return sorted(list(data_directory.glob("*.fits")))

    def read_metadata(self, data_directory: Path):
        return Table.read(data_directory / "parameters.fits")

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, index: int):
        return {
            "image": self.get_image(index),
            "label": self.get_label(index),  # True == lens, False == non-lens
            "object_id": self.get_object_id(index),
        }

In [3]:
# h.config["general"]["data_dir"] = "/Users/drew/sl_data_challenge/hsc_lenses/hsc_lenses"
h.config["general"]["data_dir"] = "/home/drew/data/hsc_lenses"
h.config["data_set"]["name"] = "SLDataset"

ds = h.prepare()

[2025-08-15 16:09:18,059 hyrax.prepare:INFO] Finished Prepare


In [4]:
samp = ds[8]
print(f"Lens ID: {samp['object_id']}")
print(f"Is lens? {samp['label']}")
print(f"Data shape: {samp['image'].shape}")


Lens ID: 69612439091552763
Is lens? tensor([1., 0.])
Data shape: torch.Size([5, 41, 41])


In [5]:
ds.metadata

Lens ID,ra,dec,zlens,mag_lens_g,mag_lens_r,mag_lens_i,mag_lens_z,mag_lens_y,ell_l,ell_l_PA,Rein,vel disp,sh,sh_PA,srcx,srcy,mag_src_g,mag_src_r,mag_src_i,mag_src_z,mag_src_y,zsrc,ell_s,ell_s_PA,Reff_s,n_s_sers,ell_m,ell_m_PA,Reff_l,n_l_sers
int64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,float64,int64,float64,float64,int64,int64
75339120850791805,242.44554906892057,55.00850181745938,0.5490264,21.914482,20.72183,19.843065,19.447388,19.226671,0.3187112200729652,140.55671939947152,1.492563900239972,322.541908350292,0.0184277930213721,165.10540757089433,-0.539412,0.54139,23.78413796420183,23.579951964201825,23.324303964201828,22.806399964201827,22.60081396420183,1.3204864688209512,0.7913979949902453,176.15020530130457,0.13909055536025083,1,0.3187112200729652,140.55671939947152,-999,-999
74643937444253184,213.32224600390697,52.73905141952488,0.58255017,22.761929,21.156885,20.12777,19.667177,19.414434,0.2316346746128754,144.15667879680944,0.9343815570008128,260.02318247470157,0.0011520234264059,1.4402219343718017,0.052176,0.395926,24.99161815939484,25.10994015939483,24.791514159394836,24.28267315939484,23.97795115939483,1.3505307973365617,0.3285413157429651,22.332981226490716,0.06689373056323532,1,0.2316346746128754,144.15667879680944,-999,-999
44223324036811477,149.695422904841,3.0730654944116496,0.91960937,24.011171,22.424341,21.162437,20.38196,20.159344,0.2301652369252709,67.07979948362424,0.7758354000912594,243.9328930590481,0.0172741275739362,154.1759454372911,0.395498,-0.045234,26.032162820101647,25.953453820101643,25.973011820101647,26.45687782010165,26.22992982010165,2.2965129206836985,0.2985856052945675,170.38760097201072,0.044008539128840195,1,0.2301652369252709,67.07979948362424,-999,-999
69581790204943781,225.25899793258637,42.60838770751077,0.6943516,22.228773,21.110435,19.887783,19.586441,19.200443,0.1503831414427968,59.239639601732165,1.7368381527311745,344.2016628284946,0.010387129432469,88.93069988654877,0.518148,0.152541,24.796290563415667,24.583549563415666,24.279836563415667,24.122608563415668,23.662105563415665,1.866385137591267,0.6388043512112632,173.92612524227113,0.08628889694184529,1,0.1503831414427968,59.239639601732165,-999,-999
69599657268895544,232.7942998977284,42.61872055220366,0.6506699,21.36213,20.065937,18.882666,18.434261,18.243608,0.2546410645234006,133.95651731405908,2.784847879054728,531.7224047948305,0.006371002448201,50.88318108822063,1.437089,-0.164365,25.585802316961686,25.233050316961684,24.76860531696169,24.37352831696169,24.31412331696168,1.105997310010219,0.1535637800279878,108.34787941706271,0.04272343866272886,1,0.2546410645234006,133.95651731405908,-999,-999
70360665344203068,221.4788543885395,44.50887152198978,1.026881,23.476093,22.430262,21.323118,20.856005,20.48569,0.2945968527250383,172.19886093545847,0.6742422113459928,215.80151247477423,0.0167371210406664,149.0885151221035,0.209969,0.061814,25.84347287523804,24.939557875238044,24.639731875238045,24.50938187523804,25.34569887523805,3.3195153364231267,0.5975014895934048,16.5712786904554,0.05740391267338577,1,0.2945968527250383,172.19886093545847,-999,-999
37489781684342465,35.3481868280165,-5.492482231992536,0.3464315,21.589935,20.06471,19.474009,19.127518,18.951971,0.1400338176310576,168.82189414901197,0.8658283050566086,202.88765754985496,0.0018885317491715,8.417669202677438,-0.246845,-0.000831,24.144554769116095,23.9362737691161,23.642845769116096,23.520744769116092,23.47341976911609,1.82255049695534,0.3996195413434265,168.6193028775504,0.12748936904367097,1,0.1400338176310576,168.82189414901197,-999,-999
69599103218113191,233.6023906798697,42.439942151091486,0.6226637,21.899918,20.498674,19.438536,19.044287,18.759584,0.2560329690352266,102.82577618124859,2.2445275574922223,376.19825165990056,0.0082532992877232,68.71546693632591,-0.664482,-0.002237,24.166408729221004,24.089107729221,23.930172729221,23.919903729221005,23.836158729221,1.857341484091474,0.265016818289844,42.8739456343977,0.12677965425986232,1,0.2560329690352266,102.82577618124859,-999,-999
69612439091552763,239.50216948252773,42.640620416993094,0.65890974,22.429792,21.444733,20.390316,20.011965,19.65237,0.2961839933917823,42.82083576937197,1.710228272028671,331.9272970515143,0.00638734473078,51.03800271265267,1.183007,0.771661,24.73634594685197,24.82052894685196,24.546888946851965,24.374273946851968,24.55785594685197,1.928261831064603,0.1710483449790056,77.43194765540096,0.09074100270785052,1,0.2961839933917823,42.82083576937197,-999,-999
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...


In [6]:
h.config["model"]["name"] = "HyraxCNN"
h.config["model"]["hyrax_cnn"]["output_classes"] = 2

In [None]:
h.train()

[2025-08-15 15:02:21,095 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-08-15 15:02:21,096 hyrax.models.model_registry:INFO] Using optimizer: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
2025-08-15 15:02:21,100 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.SLDataset': 
	{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7fcdf6be3740>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025-08-15 15:02:21,101 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.SLDataset': 
	{'sampler': <hyrax.pytorch_ignite.SubsetSequentialSampler object at 0x7fcd995d7f50>, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
2025/08/15 15:02:21 INFO mlflow.system_metrics.system_metrics_monitor: Started monitoring system metrics.
[2025-08-15 15:02:21,134 hyrax.pytorch_ignite:INFO] Training model on device: cuda


  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

  2%|1         | 1/59 [00:00<?, ?it/s]

[2025-08-15 15:57:18,790 hyrax.pytorch_ignite:INFO] Total training time: 3297.66[s]
[2025-08-15 15:57:18,791 hyrax.pytorch_ignite:INFO] Latest checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250815-150216-train-NUIO/checkpoint_epoch_10.pt
[2025-08-15 15:57:18,792 hyrax.pytorch_ignite:INFO] Best metric checkpoint saved as: /home/drew/code/hyrax/docs/pre_executed/results/20250815-150216-train-NUIO/checkpoint_10_loss=-0.0000.pt
2025/08/15 15:57:18 INFO mlflow.system_metrics.system_metrics_monitor: Stopping system metrics monitoring...
2025/08/15 15:57:18 INFO mlflow.system_metrics.system_metrics_monitor: Successfully terminated system metrics monitoring!
[2025-08-15 15:57:18,809 hyrax.verbs.train:INFO] Finished Training
[2025-08-15 15:57:23,246 hyrax.model_exporters:INFO] Exported model to ONNX format: /home/drew/code/hyrax/docs/pre_executed/results/20250815-150216-train-NUIO/example_model_opset_20.onnx


HyraxCNN(
  (conv1): Conv2d(5, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=784, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=2, bias=True)
  (criterion): CrossEntropyLoss()
)

In [7]:
h.infer()

[2025-08-15 16:09:30,287 hyrax.models.model_registry:INFO] Using criterion: torch.nn.CrossEntropyLoss with default arguments.
[2025-08-15 16:09:30,287 hyrax.models.model_registry:INFO] Using optimizer: torch.optim.SGD with arguments: {'lr': 0.01, 'momentum': 0.9}.
[2025-08-15 16:09:30,288 hyrax.verbs.infer:INFO] data set has length 50000
2025-08-15 16:09:30,358 ignite.distributed.auto.auto_dataloader INFO: Use data loader kwargs for dataset '<__main__.SLDataset': 
	{'sampler': None, 'batch_size': 512, 'shuffle': False, 'pin_memory': True}
[2025-08-15 16:09:30,651 hyrax.verbs.infer:INFO] Saving inference results at: /home/drew/code/hyrax/docs/pre_executed/results/20250815-160925-infer-B1Op
  from tqdm.autonotebook import tqdm
[2025-08-15 16:09:31,026 hyrax.pytorch_ignite:INFO] Evaluating model on device: cuda
[2025-08-15 16:09:31,028 hyrax.pytorch_ignite:INFO] Total epochs: 1


  1%|1         | 1/98 [00:00<?, ?it/s]

[2025-08-15 16:16:53,657 hyrax.pytorch_ignite:INFO] Total evaluation time: 442.63[s]
[2025-08-15 16:16:53,879 hyrax.verbs.infer:INFO] Inference Complete.


<hyrax.data_sets.inference_dataset.InferenceDataSet at 0x789946f85970>

In [14]:
from hyrax.data_sets.inference_dataset import InferenceDataSet
infer_ds = InferenceDataSet(h.config, results_dir = "/home/drew/code/hyrax/docs/pre_executed/results/20250815-160925-infer-B1Op" )

In [27]:
from torch.nn import Softmax
s = Softmax(dim=0)
for i in range(25):
    print(infer_ds[i])

tensor([ 23.0835, -26.7007], dtype=torch.float64)
tensor([ 34.5384, -39.9656], dtype=torch.float64)


ValueError: shape mismatch: value array of shape (2,2) could not be broadcast to indexing result of shape (1,2)