In [1]:
from torch.utils.data import DataLoader
import numpy as np

from cata2data import CataData
from cata2data.preprocessing import (
    image_preprocessing,
    wcs_preprocessing,
    catalogue_preprocessing,
)

# MIGHTEE Early Science continuum products can be downloaded from:
# https://archive-gw-1.kat.ac.za/public/repository/10.48479/emmd-kf31/index.html
# The three lists below are written in the same field order (COSMOS, XMMLSS).
field_names = ["COSMOS", "XMMLSS"]
image_paths = [
    "data/MIGHTEE_Continuum_Early_Science_COSMOS_r-1p2.app.restored.circ.fits",
    "data/MIGHTEE_Continuum_Early_Science_XMMLSS_r-1p2_circ.hires.fits",
]
catalogue_paths = [
    "data/MIGHTEE_Continuum_Early_Science_COSMOS_Level1.fits",
    "data/MIGHTEE_Continuum_Early_Science_XMMLSS_Level1.fits",
]

# Build the dataset from the catalogues and images of both fields, requesting
# a cutout width of 70 and the library's catalogue preprocessing hook.
mightee_data = CataData(
    field_names=field_names,
    image_paths=image_paths,
    catalogue_paths=catalogue_paths,
    catalogue_preprocessing=catalogue_preprocessing,
    cutout_width=70,
    spectral_axis=False,
)

# Sanity check: number of entries, then the shape of the first item.
print(len(mightee_data))
print(mightee_data[0].shape)



1712
(1, 70, 70)


In [2]:
class ClippedCataData:
    """Map-style dataset wrapper that zeroes pixels at or below a threshold.

    Item ``index`` is the wrapped dataset's item ``index`` with every value
    that does not strictly exceed ``sigma * ISL_RMS`` replaced by 0, where
    ``ISL_RMS`` is read from the wrapped dataset's ``df`` table.

    Args:
        catadata: Wrapped dataset. Must support ``len()`` and indexing, and
            expose a ``df`` attribute with an ``"ISL_RMS"`` column that can
            be addressed via ``.loc[index, "ISL_RMS"]``.
        sigma: Multiplier applied to the per-item RMS to form the threshold.
    """

    def __init__(self, catadata, sigma):
        self.sigma = sigma
        self.catadata = catadata

    def __getitem__(self, index):
        """Return item ``index`` of the wrapped dataset, clipped at the threshold.

        Args:
            index: Key passed both to ``catadata.df.loc`` and to ``catadata``.

        Returns:
            Array of the same shape as the wrapped item, holding the original
            value where it is strictly greater than ``sigma * ISL_RMS`` and 0
            elsewhere.
        """
        # NOTE(review): ``.loc`` is label-based; this assumes the labels of
        # ``df`` coincide with the indices requested by the caller -- confirm.
        rms = self.catadata.df.loc[index, "ISL_RMS"]
        # Fetch the item once: indexing the wrapped dataset twice doubles the
        # per-item cost and would pair the mask with different pixel values
        # if the wrapped dataset were ever non-deterministic.
        cutout = self.catadata[index]
        return np.where(cutout > self.sigma * rms, cutout, 0)

    def __len__(self):
        """Return the number of items in the wrapped dataset."""
        return len(self.catadata)


# Serve the catalogue through ClippedCataData with sigma=3, in shuffled
# batches of 64 items.
train_dataloader = DataLoader(
    dataset=ClippedCataData(catadata=mightee_data, sigma=3),
    batch_size=64,
    shuffle=True,
)

# Draw only the first batch and report its shape.
for data in train_dataloader:
    print(data.shape)
    break

torch.Size([64, 1, 70, 70])
