In [1]:
import galaxy_datasets

In [None]:
#Modified from https://github.com/mwalmsley/galaxy-datasets/blob/main/galaxy_datasets/pytorch/galaxy_dataset.py
from torch.utils.data import Dataset
from typing import List
class GalaxyDataset(Dataset):
    """Map-style dataset that loads galaxy images listed in a catalog.

    Each catalog row must provide a ``file_loc`` entry holding the path of
    the image to load. When ``label_cols`` is given, the values of those
    columns are returned alongside the image as a float32 array.

    Args:
        catalog: one row per galaxy; must contain a ``file_loc`` column and
            every column named in ``label_cols``.
        label_cols: column names to use as labels, or None to return
            images only.
        transform: optional callable applied to the loaded image.
        target_transform: optional callable applied to the label array.
    """

    def __init__(self, catalog: pd.DataFrame, label_cols=None, transform=None, target_transform=None):
        self.catalog = catalog
        self.label_cols = label_cols
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of galaxies (catalog rows)."""
        return len(self.catalog)

    def __getitem__(self, idx: int):
        """Return the image (and label, if ``label_cols`` is set) at ``idx``.

        Returns:
            ``image`` when ``label_cols`` is None, otherwise
            ``(image, label)``.
        """
        # Positional lookup, so this works whatever the catalog index is
        # (e.g. id_str) and stays fast on catalogs with 1M+ rows.
        galaxy = self.catalog.iloc[idx]

        # load the image into memory
        image = self.load_jpg_file(galaxy['file_loc'])

        if self.transform:
            image = self.transform(image)

        if self.label_cols is None:
            return image

        label = self.get_galaxy_label(galaxy, self.label_cols)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

    def load_jpg_file(self, loc):
        """Open the image at ``loc`` and force its pixel data to be read."""
        im = Image.open(loc, mode='r')
        im.load()  # avoid lazy open
        return im

    def get_galaxy_label(self, galaxy: pd.Series, label_cols: List) -> np.ndarray:
        """Return the ``label_cols`` values of one catalog row as float32.

        The result is squeezed, so a single label column yields a
        zero-dimensional array rather than an array of shape ``(1,)``.
        """
        return galaxy[label_cols].astype(np.float32).values.squeeze()  # squeeze for if there's one label_col

In [None]:
#Modified from https://github.com/mwalmsley/galaxy-datasets/blob/main/galaxy_datasets/pytorch/galaxy_datamodule.py
from typing import Optional
from functools import partial
class GalaxyDataModule(pl.LightningDataModule):
    """Lightning data module built from a single galaxy catalog.

    Takes a generic catalog (already downloaded), splits it into
    train/val/test catalogs, and wraps each in a ``GalaxyDataset`` with the
    default torchvision augmentations, exposing the usual dataloaders.

    Args:
        label_cols: catalog columns passed to ``GalaxyDataset`` as labels.
        catalog: catalog to split into train/val/test.
        train_fraction, val_fraction, test_fraction: split sizes; must sum
            to 1.
        crop_scale_bounds: ``scale`` range for ``RandomResizedCrop``.
        crop_ratio_bounds: ``ratio`` range for ``RandomResizedCrop``.
        resize_after_crop: output ``size`` for ``RandomResizedCrop``.
        batch_size: dataloader batch size.
        num_workers: dataloader worker processes (0 loads in-process).
        prefetch_factor: batches prefetched per worker; only used when
            ``num_workers > 0``.
        seed: stored on the instance. NOTE(review): it is not forwarded to
            ``data_split`` here; confirm how the split is seeded.
    """

    def __init__(
        self,
        label_cols,
        catalog=None,
        train_fraction=0.7,
        val_fraction=0.1,
        test_fraction=0.2,
        crop_scale_bounds=(0.7, 0.8),
        crop_ratio_bounds=(0.9, 1.1),
        resize_after_crop=224,
        batch_size=256,  # careful - will affect final performance
        num_workers=4,
        prefetch_factor=4,
        seed=42
    ):
        super().__init__()

        self.label_cols = label_cols
        self.catalog = catalog
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        assert np.isclose(train_fraction + val_fraction + test_fraction, 1.), \
            'train_fraction + val_fraction + test_fraction must sum to 1'
        self.train_fraction = train_fraction
        self.val_fraction = val_fraction
        self.test_fraction = test_fraction

        self.prefetch_factor = prefetch_factor
        self.dataloader_timeout = 600  # seconds (10 mins)

        self.resize_after_crop = resize_after_crop
        self.crop_scale_bounds = crop_scale_bounds
        self.crop_ratio_bounds = crop_ratio_bounds

    def default_torchvision_transforms(self):
        """Return the default torchvision transforms as a list (not composed)."""
        # The dataset yields PIL images; ToTensor converts them to tensors
        # and rescales 0-255 integer pixels to 0-1 floats.
        transforms_to_apply = [transforms.ToTensor()]
        transforms_to_apply += [
            transforms.RandomResizedCrop(
                size=self.resize_after_crop,  # assumed square
                scale=self.crop_scale_bounds,  # crop factor
                ratio=self.crop_ratio_bounds,  # crop aspect ratio
                interpolation=transforms.InterpolationMode.BILINEAR),  # new aspect ratio
            #transforms.RandomHorizontalFlip(),
            #transforms.RandomRotation(
            #    degrees=180., interpolation=transforms.InterpolationMode.BILINEAR)
        ]

        return transforms_to_apply

    # only called on main process
    def prepare_data(self):
        """Do nothing; the catalog is assumed to be downloaded already.

        State is deliberately not assigned here: this hook runs on a single
        process only, so attributes set here would be missing on the other
        processes. The transform is built in ``setup`` instead.
        """
        # could include some basic checks

    # called on every gpu
    def setup(self, stage: Optional[str] = None):
        """Build the transform, split the catalog and create the datasets.

        Args:
            stage: ``"fit"``, ``"validate"``, ``"test"`` or None. None
                creates the train, val and test datasets.
        """
        # A list of transforms is not callable; Compose chains them.
        self.transform = transforms.Compose(self.default_torchvision_transforms())

        self.specify_catalogs(stage)

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.train_dataset = GalaxyDataset(
                catalog=self.train_catalog, label_cols=self.label_cols, transform=self.transform
            )
        if stage in ("fit", "validate") or stage is None:
            self.val_dataset = GalaxyDataset(
                catalog=self.val_catalog, label_cols=self.label_cols, transform=self.transform
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.test_dataset = GalaxyDataset(
                catalog=self.test_catalog, label_cols=self.label_cols, transform=self.transform
            )

    def _make_dataloader(self, dataset, shuffle):
        """Create a DataLoader for ``dataset`` with this module's settings."""
        loader_kwargs = dict(
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )
        if self.num_workers > 0:
            # These options are only valid with worker processes; DataLoader
            # rejects them when loading in the main process (num_workers=0).
            loader_kwargs.update(
                persistent_workers=True,
                prefetch_factor=self.prefetch_factor,
                timeout=self.dataloader_timeout,
            )
        return DataLoader(dataset, **loader_kwargs)

    def train_dataloader(self):
        """Return the shuffled dataloader over the training dataset."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the unshuffled dataloader over the validation dataset."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return the unshuffled dataloader over the test dataset."""
        return self._make_dataloader(self.test_dataset, shuffle=False)

    def predict_dataloader(self):
        """Return the unshuffled dataloader over ``self.predict_dataset``.

        ``predict_dataset`` is not created by ``setup``; it must be assigned
        on the instance before this is called.
        """
        return self._make_dataloader(self.predict_dataset, shuffle=False)

    def specify_catalogs(self, stage):
        """Split ``self.catalog`` into train, val and test catalogs (once).

        Raises:
            ValueError: if no catalog was provided and the split catalogs
                have not been assigned by other means.
        """
        split_names = ("train_catalog", "val_catalog", "test_catalog")
        if all(hasattr(self, name) for name in split_names):
            # setup() runs once per stage; reuse the first split so the
            # partitions cannot differ between the fit and test stages.
            return
        if self.catalog is None:
            raise ValueError(
                'catalog is None: pass a catalog, or assign train_catalog, '
                'val_catalog and test_catalog before calling setup()'
            )
        # will split the catalog into train, val, test here
        self.train_catalog, self.val_catalog, self.test_catalog = data_split(
            self.catalog, self.train_fraction, self.val_fraction, self.test_fraction
        )

    @staticmethod
    def do_transform(img, transforms_to_apply):
        """Apply a keyword-style transform and return a CHW float32 array.

        ``transforms_to_apply`` is called as ``transforms_to_apply(image=...)``
        and must return a mapping with an ``"image"`` entry laid out as HWC
        (presumably an albumentations-style pipeline; confirm against
        callers). Not used by the default torchvision pipeline.
        """
        return np.transpose(transforms_to_apply(image=np.array(img))["image"], axes=[2, 0, 1]).astype(np.float32)