TODO: Add some `pooch` work here to pull down the example files and stash them in the user's local cache.
TODO: Massage the data so that it's in the correct locations. My hunch is that we can place all the files together in
the same directory, and that we'll need to concatenate the parameters.fits files together.
The data files are sorted alphabetically, so we need to be careful about how we concatenate the parameters.fits files:
the row order of the combined table must match the sorted order of the data files.

In [None]:
from hyrax import Hyrax

# Create the Hyrax instance; its config and its prepare/train/infer/umap/visualize
# methods are used by the cells below.
h = Hyrax()

In [None]:
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
from torch import from_numpy
from hyrax.data_sets import HyraxDataset
from astropy.io import fits
from astropy.table import Table


"""
Note - `index` in this class refers to the row index in the parameters.fits file.
"""


class SLDataset(HyraxDataset, Dataset):
    """Lens / non-lens image dataset read from a flat directory of FITS files.

    The directory given by ``config["general"]["data_dir"]`` must contain a
    ``parameters.fits`` metadata table and the image files. Image files are
    those matching ``IMAGE_GLOB``; they are sorted alphabetically and consumed
    in consecutive groups of ``BANDS_PER_OBJECT``, one group per object.

    Throughout the class ``index`` is the object's row index in
    ``parameters.fits``; the object's files start at position
    ``index * BANDS_PER_OBJECT`` in the sorted file list.
    """

    # Number of consecutive image files that make up one object
    # (presumably one file per photometric band - confirm against the data).
    BANDS_PER_OBJECT = 5

    # Height and width, in pixels, of the image stored in each file.
    IMAGE_SHAPE = (41, 41)

    # Glob that selects image files only. ``parameters.fits`` contains no
    # underscore, so it can never end up in the file list and shift the
    # ``index * BANDS_PER_OBJECT`` stride.
    IMAGE_GLOB = "*_*.fits"

    def __init__(self, config: dict):
        """Load the metadata table and the sorted image-file list.

        Parameters
        ----------
        config : dict
            Hyrax configuration; ``config["general"]["data_dir"]`` is read as
            the data directory.
        """
        super().__init__(config)
        self.data_directory = Path(config["general"]["data_dir"])
        self.metadata = self.read_metadata(self.data_directory)
        self.filepaths = self.read_filepaths(self.data_directory)

    def get_image(self, index: int):
        """Return the object's images as a float32 tensor.

        The result has shape ``(BANDS_PER_OBJECT, *IMAGE_SHAPE)``; slot ``b``
        is filled from the ``b``-th file of the object's group.
        """
        image_stack = np.zeros((self.BANDS_PER_OBJECT, *self.IMAGE_SHAPE), dtype=np.float32)
        first_file = index * self.BANDS_PER_OBJECT
        for band in range(self.BANDS_PER_OBJECT):
            raw_data = fits.getdata(self.filepaths[first_file + band], memmap=False)
            # Element [0][1] of the file's data holds the image
            # (presumably first table row, second column - confirm).
            image_stack[band] = raw_data[0][1]

        return from_numpy(image_stack)

    def get_label(self, index: int):
        """Return the one-hot label: ``[1, 0]`` for a lens, ``[0, 1]`` otherwise.

        An object is treated as a lens when the path of the first file in its
        group contains ``"_L_"``.
        """
        first_file = index * self.BANDS_PER_OBJECT
        is_lens = "_L_" in str(self.filepaths[first_file])
        one_hot = [1.0, 0.0] if is_lens else [0.0, 1.0]
        return from_numpy(np.array(one_hot, dtype=np.float32))

    def get_object_id(self, index: int) -> str:
        """Return the object's identifier from the metadata table.

        ``"Lens ID"`` is preferred, then ``"Object ID"``. A column is skipped
        when it is absent or when its value renders as ``"--"`` (a masked
        entry). ``"no_id"`` is returned when neither column yields a value.
        """
        for column in ("Lens ID", "Object ID"):
            if column not in self.metadata.columns:
                continue
            value = str(self.metadata[index][column])
            if value != "--":
                return value
        return "no_id"

    def get_filename(self, index: int) -> str:
        """Return the file name (without directory) of the object's first file."""
        return str(self.filepaths[index * self.BANDS_PER_OBJECT].name)

    def read_filepaths(self, data_directory: Path):
        """Return the alphabetically sorted image files in ``data_directory``.

        Only files matching ``IMAGE_GLOB`` are returned, so the list agrees
        with ``__len__`` and excludes ``parameters.fits``.
        """
        return sorted(data_directory.glob(self.IMAGE_GLOB))

    def read_metadata(self, data_directory: Path):
        """Read ``parameters.fits`` and add a combined ``object_id`` column.

        ``object_id`` takes the ``"Lens ID"`` value where one is present and
        falls back to ``"Object ID"`` where ``"Lens ID"`` is masked.
        """
        table = Table.read(data_directory / "parameters.fits")
        # getmaskarray also accepts an unmasked column (all-False mask), where
        # reading ``.mask`` directly would raise AttributeError.
        lens_id_missing = np.ma.getmaskarray(table["Lens ID"])
        table["object_id"] = np.where(lens_id_missing, table["Object ID"], table["Lens ID"])
        return table

    def metadata_fields(self):
        """Return the columns of the metadata table."""
        return self.metadata.columns

    def __len__(self):
        """Return the number of complete objects in the dataset.

        Uses the file list captured at construction instead of re-scanning the
        directory, so the length always agrees with what ``__getitem__`` can
        index. Leftover files that do not fill a whole group are ignored.
        """
        return len(self.filepaths) // self.BANDS_PER_OBJECT

    def __getitem__(self, index: int):
        """Return the image tensor, label tensor and object id for ``index``."""
        return {
            "image": self.get_image(index),
            "label": self.get_label(index),  # [1, 0] == lens, [0, 1] == non-lens
            "object_id": self.get_object_id(index),
        }

In [None]:
# Point Hyrax at the directory holding the image FITS files and parameters.fits.
# NOTE(review): both paths below are machine-specific absolute paths; replace
# them with the location of your local copy of the data.
# h.config["general"]["data_dir"] = "/Users/drew/sl_data_challenge/hsc_lenses/hsc_lenses"
h.config["general"]["data_dir"] = "/home/drew/data/sl_100/hsc_combined"
# Select the dataset class by name (the SLDataset class defined in the cell above).
h.config["data_set"]["name"] = "SLDataset"

# The object returned by prepare() is indexed like a dataset in the next cell.
ds = h.prepare()

In [None]:
# Spot-check two samples: print the ID, the one-hot label and the image shape.
for sample_index in (5, 105):
    samp = ds[sample_index]
    print(f"ID: {samp['object_id']}")
    print(f"Is lens? {samp['label']}")
    print(f"Data shape: {samp['image'].shape}")

In [None]:
# Model and training settings for the classification run.
h.config["model"]["name"] = "HyraxCNN"
# Two output classes, matching the two-element label from SLDataset.get_label.
h.config["model"]["hyrax_cnn"]["output_classes"] = 2
h.config["train"]["epochs"] = 10
h.config["data_loader"]["batch_size"] = 10

In [None]:
# Train with the configuration set above (model name "HyraxCNN").
h.train()

In [None]:
# Run inference with the current configuration.
h.infer()

In [None]:
# Switch the model to the autoencoder and train again; every other config
# value is left as set in the earlier cells.
h.config["model"]["name"] = "HyraxAutoencoder"
h.train()

In [None]:
# Run inference again, now with the model name set to "HyraxAutoencoder".
h.infer()

In [None]:
from hyrax.data_sets import InferenceDataSet

# NOTE(review): hard-coded, machine-specific path to one particular inference
# run; replace it with the results directory of your own run.
infer_dir = "/home/drew/code/hyrax/docs/pre_executed/results/20250827-171048-infer-bxSr"
# Open the saved results as a dataset and display the entry at index 1.
infer_ds = InferenceDataSet(h.config, infer_dir)
infer_ds[1]

In [None]:
import numpy as np

# Load one of the .npy batch files from the results directory and display its
# raw contents.
d = np.load(f"{infer_dir}/batch_2.npy")
d

In [None]:
# Load the batch index file from the results directory and display entry 1
# (presumably it maps each result to the batch file holding it; confirm).
batch_index = np.load(f"{infer_dir}/batch_index.npy")
batch_index[1]

In [None]:
# Run Hyrax's umap step (presumably on the most recent inference results; confirm).
h.umap()

In [None]:
# Keep the object returned by visualize() in `viz` for later use.
viz = h.visualize()