In [12]:
from dataloader import RSData
import lightning as L

In [19]:
from torch.utils.data import DataLoader
from os import PathLike
from pathlib import Path

from dataloader import AugmentorChain


class RSDataModule(L.LightningDataModule):
    def __init__(
            self,
            ds_path: str | PathLike,
            train_area_ids: list[int],
            valid_area_ids: list[int],
            test_area_ids: list[int],
            cutout_size: int,
            batch_size: int,
            augmentor_chain: AugmentorChain | None = None):

        self.ds_path = Path(ds_path)
        self.train_area_ids = train_area_ids
        self.valid_area_ids = valid_area_ids
        self.test_area_ids = test_area_ids
        self.cutout_size = cutout_size
        self.batch_size = batch_size
        self.augmentor_chain = augmentor_chain

        train_data = self.get_dataset(mode='init')
        self.feature_stat_means = train_data.feature_stat_means
        self.feature_stat_stds = train_data.feature_stat_stds

        self.dataloader_args = {
            'batch_size': self.batch_size
        }

    def get_dataset(self, mode: str) -> RSData:
        if mode in ('train', 'init'):
            mask_area_ids = self.train_area_ids
            augmentor_chain = self.augmentor_chain
        elif mode == 'train':
            mask_area_ids = self.valid_area_ids
            augmentor_chain = None
        elif mode == 'train':
            mask_area_ids = self.test_area_ids
            augmentor_chain = None
        elif mode == 'predict':
            mask_area_ids = [0, 1, 2, 3, 4]
            augmentor_chain = None
        else:
            raise ValueError(
                f'`mode` must be one of \'init\', \'train\', \'valid\', \'test\', is \'{mode}\'.'
            )

        dataset = RSData(
            ds_path=self.ds_path,
            mask_area_ids=mask_area_ids,
            cutout_size=self.cutout_size,
            feature_stat_means=None if mode == 'init' else self.feature_stat_means,
            feature_stat_stds=None if mode == 'init' else self.feature_stat_stds,
            augmentor_chain=augmentor_chain
        )

        return dataset

    def train_dataloader(self) -> DataLoader:
        dataset = self.get_dataset(mode='train')
        return DataLoader(
            dataset,
            shuffle=True,
            **self.dataloader_args
        )

    def val_dataloader(self) -> DataLoader:
        dataset = self.get_dataset(mode='valid')
        return DataLoader(
            dataset,
            shuffle=False,
            **self.dataloader_args
        )

    def test_dataloader(self) -> DataLoader:
        dataset = self.get_dataset(mode='test')
        return DataLoader(
            dataset,
            shuffle=False,
            **self.dataloader_args
        )

    def predict_dataloader(self) -> DataLoader:
        dataset = self.get_dataset(mode='predict')
        return DataLoader(
            dataset,
            shuffle=False,
            **self.dataloader_args
        )


In [20]:
datamodule = RSDataModule(
    ds_path='../data/Sample_CombinedData_32signed/combined.zarr',
    train_area_ids=[1, 2],
    valid_area_ids=[3],
    test_area_ids=[4],
    cutout_size=21,
    batch_size=10,
)

In [21]:
train_dl = datamodule.train_dataloader()

In [24]:
next(iter(train_dl))

[tensor([[[[-1.1074e-02, -8.2911e-02, -6.4006e-02,  ..., -1.4341e-01,
            -1.4719e-01,  2.4845e-02],
           [-6.4952e-02, -9.5199e-02, -9.3308e-02,  ..., -1.6217e-03,
            -1.3395e-01,  3.1461e-02],
           [-1.5475e-01, -4.0376e-02, -2.4307e-02,  ..., -1.5800e-02,
            -2.9033e-02, -1.2072e-01],
           ...,
           [-3.8485e-02, -1.2356e-01,  2.8625e-02,  ...,  1.4447e-02,
            -1.3910e-02, -1.4855e-02],
           [-1.3679e-01, -1.1074e-02, -3.0924e-02,  ...,  2.3899e-02,
            -2.1471e-02,  3.3352e-02],
           [-6.4952e-02, -9.5199e-02, -2.0674e-01,  ..., -7.6294e-02,
             2.1592e-03,  1.4583e-01]],
 
          [[ 2.2734e-01,  1.9455e-01,  2.2894e-01,  ...,  1.2658e-01,
             2.2654e-01,  3.8248e-01],
           [ 2.9931e-01,  2.2014e-01,  2.7052e-01,  ...,  1.6736e-01,
             1.4577e-01,  2.6013e-01],
           [ 3.0651e-01,  2.8651e-01,  3.4329e-01,  ...,  3.9047e-01,
             1.7776e-01,  2.1214e-01],


In [None]:
from src.augmentors import AugmentorModule

ModuleNotFoundError: No module named 'src'