In [None]:
from dataloader import RSData
import lightning as L

In [None]:
from torch.utils.data import DataLoader
from os import PathLike
from pathlib import Path

from dataloader import AugmentorChain


class RSDataModule(L.LightningDataModule):
    """Lightning data module serving `RSData` splits selected by area id.

    On construction an ``'init'`` dataset is built over the training areas with
    no feature statistics supplied; the `feature_stat_means` /
    `feature_stat_stds` it exposes are stored and then passed to every other
    split (train/valid/test/predict).

    Args:
        ds_path: Dataset location forwarded to `RSData` (converted to `Path`).
        train_area_ids: Area ids used for the training split (and the 'init'
            split that provides the feature statistics).
        valid_area_ids: Area ids used for the validation split.
        test_area_ids: Area ids used for the test split.
        cutout_size: Cutout size forwarded to `RSData`.
        batch_size: Batch size used by every dataloader.
        augmentor_chain: Optional augmentation chain; only the 'train' and
            'init' splits receive it.
        predict_area_ids: Area ids used for the predict split. Defaults to
            ``[0, 1, 2, 3, 4]`` (the value that used to be hard-coded).
    """

    def __init__(
            self,
            ds_path: str | PathLike,
            train_area_ids: list[int],
            valid_area_ids: list[int],
            test_area_ids: list[int],
            cutout_size: int,
            batch_size: int,
            augmentor_chain: AugmentorChain | None = None,
            predict_area_ids: list[int] | None = None):
        # The LightningDataModule base initialiser must run so that the
        # framework's own attributes exist before Lightning uses this object.
        super().__init__()

        self.ds_path = Path(ds_path)
        self.train_area_ids = train_area_ids
        self.valid_area_ids = valid_area_ids
        self.test_area_ids = test_area_ids
        self.predict_area_ids = (
            [0, 1, 2, 3, 4] if predict_area_ids is None else list(predict_area_ids)
        )
        self.cutout_size = cutout_size
        self.batch_size = batch_size
        self.augmentor_chain = augmentor_chain

        # 'init' is built without statistics; read back the statistics the
        # resulting dataset exposes so all later splits share them.
        train_data = self.get_dataset(mode='init')
        self.feature_stat_means = train_data.feature_stat_means
        self.feature_stat_stds = train_data.feature_stat_stds

        self.dataloader_args = {
            'batch_size': self.batch_size
        }

    def get_dataset(self, mode: str) -> RSData:
        """Build the `RSData` instance for one split.

        Args:
            mode: One of ``'init'``, ``'train'``, ``'valid'``, ``'test'`` or
                ``'predict'``. ``'init'`` selects the training areas but passes
                ``None`` for the feature statistics (it runs before they are
                known); every other mode passes the stored statistics.

        Returns:
            The dataset for the requested split.

        Raises:
            ValueError: If `mode` is not one of the values above.
        """
        if mode in ('train', 'init'):
            mask_area_ids = self.train_area_ids
            augmentor_chain = self.augmentor_chain
        elif mode == 'valid':
            mask_area_ids = self.valid_area_ids
            augmentor_chain = None
        elif mode == 'test':
            mask_area_ids = self.test_area_ids
            augmentor_chain = None
        elif mode == 'predict':
            mask_area_ids = self.predict_area_ids
            augmentor_chain = None
        else:
            raise ValueError(
                f'`mode` must be one of \'init\', \'train\', \'valid\', \'test\', \'predict\', is \'{mode}\'.'
            )

        # During 'init' the statistics attributes do not exist yet.
        use_stats = mode != 'init'
        dataset = RSData(
            ds_path=self.ds_path,
            mask_area_ids=mask_area_ids,
            cutout_size=self.cutout_size,
            feature_stat_means=self.feature_stat_means if use_stats else None,
            feature_stat_stds=self.feature_stat_stds if use_stats else None,
            augmentor_chain=augmentor_chain
        )

        return dataset

    def _build_dataloader(self, mode: str, shuffle: bool) -> DataLoader:
        """Wrap the dataset for `mode` in a DataLoader using the shared loader args."""
        return DataLoader(
            self.get_dataset(mode=mode),
            shuffle=shuffle,
            **self.dataloader_args
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled dataloader over the training areas."""
        return self._build_dataloader('train', shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled dataloader over the validation areas."""
        return self._build_dataloader('valid', shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled dataloader over the test areas."""
        return self._build_dataloader('test', shuffle=False)

    def predict_dataloader(self) -> DataLoader:
        """Unshuffled dataloader over the prediction areas."""
        return self._build_dataloader('predict', shuffle=False)


In [None]:
# Sample dataset; areas 1-2 are used for training, area 3 for validation
# and area 4 is held out for testing.
DS_PATH = '../data/Sample_CombinedData_32signed/combined.zarr'

datamodule = RSDataModule(
    ds_path=DS_PATH,
    train_area_ids=[1, 2],
    valid_area_ids=[3],
    test_area_ids=[4],
    cutout_size=21,
    batch_size=10,
)

In [None]:
# Shuffled training DataLoader; batch size is taken from the data module.
train_dl = datamodule.train_dataloader()

In [None]:
# Pull a single batch as a smoke test of the training pipeline.
first_train_batch = next(iter(train_dl))
first_train_batch

In [1]:
from src.augmentors import AugmentorChain

In [3]:
# NOTE(review): AugmentorChain here comes from `src.augmentors`, whereas the
# data-module cell imports it from `dataloader` — confirm they are the same class.
augmentor_chain = AugmentorChain.from_args(
    'FlipAugmentor',
    PixelNoiseAugmentor={'scale': 1.0},
    random_seed=1,
)
augmentor_chain

AugmentorChain(random_seed=1, augmentors=[PixelNoiseAugmentor(scale=1.0), FlipAugmentor()])