diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 151699449..46f690699 100644 --- a/docs/source/_rst/_code.rst +++ b/docs/source/_rst/_code.rst @@ -26,6 +26,7 @@ Trainer, Dataset and Datamodule Trainer Dataset DataModule + Dataloader Data Types ------------ diff --git a/docs/source/_rst/data/data_module.rst b/docs/source/_rst/data/data_module.rst index b7ffb14e0..9bbf82734 100644 --- a/docs/source/_rst/data/data_module.rst +++ b/docs/source/_rst/data/data_module.rst @@ -2,14 +2,6 @@ DataModule ====================== .. currentmodule:: pina.data.data_module -.. autoclass:: Collator - :members: - :show-inheritance: - .. autoclass:: PinaDataModule - :members: - :show-inheritance: - -.. autoclass:: PinaSampler :members: :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/dataloader.rst b/docs/source/_rst/data/dataloader.rst new file mode 100644 index 000000000..fa6bae552 --- /dev/null +++ b/docs/source/_rst/data/dataloader.rst @@ -0,0 +1,11 @@ +Dataloader +====================== +.. currentmodule:: pina.data.dataloader + +.. autoclass:: PinaSampler + :members: + :show-inheritance: + +.. autoclass:: PinaDataLoader + :members: + :show-inheritance: \ No newline at end of file diff --git a/docs/source/_rst/data/dataset.rst b/docs/source/_rst/data/dataset.rst index b49b41db1..eaf86c7d6 100644 --- a/docs/source/_rst/data/dataset.rst +++ b/docs/source/_rst/data/dataset.rst @@ -7,13 +7,5 @@ Dataset :show-inheritance: .. autoclass:: PinaDatasetFactory - :members: - :show-inheritance: - -.. autoclass:: PinaGraphDataset - :members: - :show-inheritance: - -.. autoclass:: PinaTensorDataset :members: :show-inheritance: \ No newline at end of file diff --git a/pina/callback/normalizer_data_callback.py b/pina/callback/normalizer_data_callback.py index ef957b9ef..faef8a9d3 100644 --- a/pina/callback/normalizer_data_callback.py +++ b/pina/callback/normalizer_data_callback.py @@ -5,7 +5,6 @@ from ..label_tensor import LabelTensor from ..utils import check_consistency, is_function from ..condition import InputTargetCondition -from ..data.dataset import PinaGraphDataset class NormalizerDataCallback(Callback): @@ -122,7 +121,10 @@ def setup(self, trainer, pl_module, stage): """ # Ensure datsets are not graph-based - if isinstance(trainer.datamodule.train_dataset, PinaGraphDataset): + if any( + ds.is_graph_dataset + for ds in trainer.datamodule.train_dataset.values() + ): raise NotImplementedError( "NormalizerDataCallback is not compatible with " "graph-based datasets." @@ -164,8 +166,8 @@ def _compute_scale_shift(self, conditions, dataset): :param dataset: The `~pina.data.dataset.PinaDataset` dataset. """ for cond in conditions: - if cond in dataset.conditions_dict: - data = dataset.conditions_dict[cond][self.apply_to] + if cond in dataset: + data = dataset[cond].data[self.apply_to] shift = self.shift_fn(data) scale = self.scale_fn(data) self._normalizer[cond] = { @@ -197,25 +199,20 @@ def normalize_dataset(self, dataset): :param PinaDataset dataset: The dataset to be normalized. """ - # Initialize update dictionary - update_dataset_dict = {} # Iterate over conditions and apply normalization for cond, norm_params in self.normalizer.items(): - points = dataset.conditions_dict[cond][self.apply_to] + update_dataset_dict = {} + points = dataset[cond].data[self.apply_to] scale = norm_params["scale"] shift = norm_params["shift"] normalized_points = self._norm_fn(points, scale, shift) - update_dataset_dict[cond] = { - self.apply_to: ( - LabelTensor(normalized_points, points.labels) - if isinstance(points, LabelTensor) - else normalized_points - ) - } - - # Update the dataset in-place - dataset.update_data(update_dataset_dict) + update_dataset_dict[self.apply_to] = ( + LabelTensor(normalized_points, points.labels) + if isinstance(points, LabelTensor) + else normalized_points + ) + dataset[cond].data.update(update_dataset_dict) @property def normalizer(self): diff --git a/pina/callback/refinement/refinement_interface.py b/pina/callback/refinement/refinement_interface.py index adc6e4e7c..c8bafa900 100644 --- a/pina/callback/refinement/refinement_interface.py +++ b/pina/callback/refinement/refinement_interface.py @@ -133,13 +133,12 @@ def _update_points(self, solver): :param PINNInterface solver: The solver object. """ - new_points = {} for name in self._condition_to_update: - current_points = self.dataset.conditions_dict[name]["input"] - new_points[name] = { - "input": self.sample(current_points, name, solver) - } - self.dataset.update_data(new_points) + new_points = {} + current_points = self.dataset[name].data["input"] + new_points["input"] = self.sample(current_points, name, solver) + + self.dataset[name].update_data(new_points) def _compute_population_size(self, conditions): """ @@ -150,6 +149,5 @@ def _compute_population_size(self, conditions): :rtype: dict """ return { - cond: len(self.dataset.conditions_dict[cond]["input"]) - for cond in conditions + cond: len(self.dataset[cond].data["input"]) for cond in conditions } diff --git a/pina/data/__init__.py b/pina/data/__init__.py index 70e100011..7bb328b2a 100644 --- a/pina/data/__init__.py +++ b/pina/data/__init__.py @@ -4,4 +4,3 @@ from .data_module import PinaDataModule -from .dataset import PinaDataset diff --git a/pina/data/data_module.py b/pina/data/data_module.py index 9ed5c6437..603e81038 100644 --- a/pina/data/data_module.py +++ b/pina/data/data_module.py @@ -7,232 +7,9 @@ import warnings from lightning.pytorch import LightningDataModule import torch -from torch_geometric.data import Data -from torch.utils.data import DataLoader, SequentialSampler, RandomSampler -from torch.utils.data.distributed import DistributedSampler from ..label_tensor import LabelTensor -from .dataset import PinaDatasetFactory, PinaTensorDataset - - -class DummyDataloader: - - def __init__(self, dataset): - """ - Prepare a dataloader object that returns the entire dataset in a single - batch. Depending on the number of GPUs, the dataset is managed - as follows: - - - **Distributed Environment** (multiple GPUs): Divides dataset across - processes using the rank and world size. Fetches only portion of - data corresponding to the current process. - - **Non-Distributed Environment** (single GPU): Fetches the entire - dataset. - - :param PinaDataset dataset: The dataset object to be processed. - - .. note:: - This dataloader is used when the batch size is ``None``. - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - rank = torch.distributed.get_rank() - world_size = torch.distributed.get_world_size() - if len(dataset) < world_size: - raise RuntimeError( - "Dimension of the dataset smaller than world size." - " Increase the size of the partition or use a single GPU" - ) - idx, i = [], rank - while i < len(dataset): - idx.append(i) - i += world_size - self.dataset = dataset.fetch_from_idx_list(idx) - else: - self.dataset = dataset.get_all_data() - - def __iter__(self): - return self - - def __len__(self): - return 1 - - def __next__(self): - return self.dataset - - -class Collator: - """ - This callable class is used to collate the data points fetched from the - dataset. The collation is performed based on the type of dataset used and - on the batching strategy. - """ - - def __init__( - self, max_conditions_lengths, automatic_batching, dataset=None - ): - """ - Initialize the object, setting the collate function based on whether - automatic batching is enabled or not. - - :param dict max_conditions_lengths: ``dict`` containing the maximum - number of data points to consider in a single batch for - each condition. - :param bool automatic_batching: Whether automatic PyTorch batching is - enabled or not. For more information, see the - :class:`~pina.data.data_module.PinaDataModule` class. - :param PinaDataset dataset: The dataset where the data is stored. - """ - - self.max_conditions_lengths = max_conditions_lengths - # Set the collate function based on the batching strategy - # collate_pina_dataloader is used when automatic batching is disabled - # collate_torch_dataloader is used when automatic batching is enabled - self.callable_function = ( - self._collate_torch_dataloader - if automatic_batching - else (self._collate_pina_dataloader) - ) - self.dataset = dataset - - # Set the function which performs the actual collation - if isinstance(self.dataset, PinaTensorDataset): - # If the dataset is a PinaTensorDataset, use this collate function - self._collate = self._collate_tensor_dataset - else: - # If the dataset is a PinaDataset, use this collate function - self._collate = self._collate_graph_dataset - - def _collate_pina_dataloader(self, batch): - """ - Function used to create a batch when automatic batching is disabled. - - :param list[int] batch: List of integers representing the indices of - the data points to be fetched. - :return: Dictionary containing the data points fetched from the dataset. - :rtype: dict - """ - # Call the fetch_from_idx_list method of the dataset - return self.dataset.fetch_from_idx_list(batch) - - def _collate_torch_dataloader(self, batch): - """ - Function used to collate the batch - - :param list[dict] batch: List of retrieved data. - :return: Dictionary containing the data points fetched from the dataset, - collated. - :rtype: dict - """ - - batch_dict = {} - if isinstance(batch, dict): - return batch - conditions_names = batch[0].keys() - # Condition names - for condition_name in conditions_names: - single_cond_dict = {} - condition_args = batch[0][condition_name].keys() - for arg in condition_args: - data_list = [ - batch[idx][condition_name][arg] - for idx in range( - min( - len(batch), - self.max_conditions_lengths[condition_name], - ) - ) - ] - single_cond_dict[arg] = self._collate(data_list) - - batch_dict[condition_name] = single_cond_dict - return batch_dict - - @staticmethod - def _collate_tensor_dataset(data_list): - """ - Function used to collate the data when the dataset is a - :class:`~pina.data.dataset.PinaTensorDataset`. - - :param data_list: Elements to be collated. - :type data_list: list[torch.Tensor] | list[LabelTensor] - :return: Batch of data. - :rtype: dict - - :raises RuntimeError: If the data is not a :class:`torch.Tensor` or a - :class:`~pina.label_tensor.LabelTensor`. - """ - - if isinstance(data_list[0], LabelTensor): - return LabelTensor.stack(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.stack(data_list) - raise RuntimeError("Data must be Tensors or LabelTensor ") - - def _collate_graph_dataset(self, data_list): - """ - Function used to collate data when the dataset is a - :class:`~pina.data.dataset.PinaGraphDataset`. - - :param data_list: Elememts to be collated. - :type data_list: list[Data] | list[Graph] - :return: Batch of data. - :rtype: dict - - :raises RuntimeError: If the data is not a - :class:`~torch_geometric.data.Data` or a :class:`~pina.graph.Graph`. - """ - if isinstance(data_list[0], LabelTensor): - return LabelTensor.cat(data_list) - if isinstance(data_list[0], torch.Tensor): - return torch.cat(data_list) - if isinstance(data_list[0], Data): - return self.dataset.create_batch(data_list) - raise RuntimeError( - "Data must be Tensors or LabelTensor or pyG " - "torch_geometric.data.Data" - ) - - def __call__(self, batch): - """ - Perform the collation of data fetched from the dataset. The behavoior - of the function is set based on the batching strategy during class - initialization. - - :param batch: List of retrieved data or sampled indices. - :type batch: list[int] | list[dict] - :return: Dictionary containing colleted data fetched from the dataset. - :rtype: dict - """ - - return self.callable_function(batch) - - -class PinaSampler: - """ - This class is used to create the sampler instance based on the shuffle - parameter and the environment in which the code is running. - """ - - def __new__(cls, dataset): - """ - Instantiate and initialize the sampler. - - :param PinaDataset dataset: The dataset from which to sample. - :return: The sampler instance. - :rtype: :class:`torch.utils.data.Sampler` - """ - - if ( - torch.distributed.is_available() - and torch.distributed.is_initialized() - ): - sampler = DistributedSampler(dataset) - else: - sampler = SequentialSampler(dataset) - return sampler +from .dataset import PinaDatasetFactory +from .dataloader import PinaDataLoader class PinaDataModule(LightningDataModule): @@ -250,7 +27,7 @@ def __init__( val_size=0.1, batch_size=None, shuffle=True, - repeat=False, + batching_mode="common_batch_size", automatic_batching=None, num_workers=0, pin_memory=False, @@ -271,11 +48,12 @@ def __init__( Default is ``None``. :param bool shuffle: Whether to shuffle the dataset before splitting. Default ``True``. - :param bool repeat: If ``True``, in case of batch size larger than the - number of elements in a specific condition, the elements are - repeated until the batch size is reached. If ``False``, the number - of elements in the batch is the minimum between the batch size and - the number of elements in the condition. Default is ``False``. + :param bool common_batch_size: If ``True``, the same batch size is used + for all conditions. If ``False``, each condition can have its own + batch size, proportional to the size of the dataset in that + condition. Default is ``True``. + :param bool separate_conditions: If ``True``, dataloaders for each + condition are iterated separately. Default is ``False``. :param automatic_batching: If ``True``, automatic PyTorch batching is performed, which consists of extracting one element at a time from the dataset and collating them into a batch. This is useful @@ -305,7 +83,7 @@ def __init__( # Store fixed attributes self.batch_size = batch_size self.shuffle = shuffle - self.repeat = repeat + self.batching_mode = batching_mode self.automatic_batching = automatic_batching # If batch size is None, num_workers has no effect @@ -376,23 +154,16 @@ def setup(self, stage=None): if stage == "fit" or stage is None: self.train_dataset = PinaDatasetFactory( self.data_splits["train"], - max_conditions_lengths=self.find_max_conditions_lengths( - "train" - ), automatic_batching=self.automatic_batching, ) if "val" in self.data_splits.keys(): self.val_dataset = PinaDatasetFactory( self.data_splits["val"], - max_conditions_lengths=self.find_max_conditions_lengths( - "val" - ), automatic_batching=self.automatic_batching, ) elif stage == "test": self.test_dataset = PinaDatasetFactory( self.data_splits["test"], - max_conditions_lengths=self.find_max_conditions_lengths("test"), automatic_batching=self.automatic_batching, ) else: @@ -482,7 +253,7 @@ def _apply_shuffle(condition_dict, len_data): dataset_dict[key].update({condition_name: data}) return dataset_dict - def _create_dataloader(self, split, dataset): + def _create_dataloader(self, dataset): """ " Create the dataloader for the given split. @@ -502,53 +273,18 @@ def _create_dataloader(self, split, dataset): ), module="lightning.pytorch.trainer.connectors.data_connector", ) - # Use custom batching (good if batch size is large) - if self.batch_size is not None: - sampler = PinaSampler(dataset) - if self.automatic_batching: - collate = Collator( - self.find_max_conditions_lengths(split), - self.automatic_batching, - dataset=dataset, - ) - else: - collate = Collator( - None, self.automatic_batching, dataset=dataset - ) - return DataLoader( - dataset, - self.batch_size, - collate_fn=collate, - sampler=sampler, - num_workers=self.num_workers, - ) - dataloader = DummyDataloader(dataset) - dataloader.dataset = self._transfer_batch_to_device( - dataloader.dataset, self.trainer.strategy.root_device, 0 + dl = PinaDataLoader( + dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + batching_mode=self.batching_mode, + device=self.trainer.strategy.root_device, ) - self.transfer_batch_to_device = self._transfer_batch_to_device_dummy - return dataloader - - def find_max_conditions_lengths(self, split): - """ - Define the maximum length for each conditions. - - :param dict split: The split of the dataset. - :return: The maximum length per condition. - :rtype: dict - """ - - max_conditions_lengths = {} - for k, v in self.data_splits[split].items(): - if self.batch_size is None: - max_conditions_lengths[k] = len(v["input"]) - elif self.repeat: - max_conditions_lengths[k] = self.batch_size - else: - max_conditions_lengths[k] = min( - len(v["input"]), self.batch_size - ) - return max_conditions_lengths + if self.batch_size is None: + # Override the method to transfer the batch to the device + self.transfer_batch_to_device = self._transfer_batch_to_device_dummy + return dl def val_dataloader(self): """ @@ -557,7 +293,7 @@ def val_dataloader(self): :return: The validation dataloader :rtype: torch.utils.data.DataLoader """ - return self._create_dataloader("val", self.val_dataset) + return self._create_dataloader(self.val_dataset) def train_dataloader(self): """ @@ -566,7 +302,7 @@ def train_dataloader(self): :return: The training dataloader :rtype: torch.utils.data.DataLoader """ - return self._create_dataloader("train", self.train_dataset) + return self._create_dataloader(self.train_dataset) def test_dataloader(self): """ @@ -575,7 +311,7 @@ def test_dataloader(self): :return: The testing dataloader :rtype: torch.utils.data.DataLoader """ - return self._create_dataloader("test", self.test_dataset) + return self._create_dataloader(self.test_dataset) @staticmethod def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): @@ -591,7 +327,7 @@ def _transfer_batch_to_device_dummy(batch, device, dataloader_idx): :rtype: list[tuple] """ - return batch + return list(batch.items()) def _transfer_batch_to_device(self, batch, device, dataloader_idx): """ @@ -649,9 +385,15 @@ def input(self): to_return = {} if hasattr(self, "train_dataset") and self.train_dataset is not None: - to_return["train"] = self.train_dataset.input + to_return["train"] = { + cond: data.input for cond, data in self.train_dataset.items() + } if hasattr(self, "val_dataset") and self.val_dataset is not None: - to_return["val"] = self.val_dataset.input + to_return["val"] = { + cond: data.input for cond, data in self.val_dataset.items() + } if hasattr(self, "test_dataset") and self.test_dataset is not None: - to_return["test"] = self.test_dataset.input + to_return["test"] = { + cond: data.input for cond, data in self.test_dataset.items() + } return to_return diff --git a/pina/data/dataloader.py b/pina/data/dataloader.py new file mode 100644 index 000000000..6feab1f68 --- /dev/null +++ b/pina/data/dataloader.py @@ -0,0 +1,347 @@ +"""DataLoader module for PinaDataset.""" + +import itertools +import random +from functools import partial +import torch +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import SequentialSampler +from .stacked_dataloader import StackedDataLoader + + +class DummyDataloader: + """ + DataLoader that returns the entire dataset in a single batch. + """ + + def __init__(self, dataset, device=None): + """ + Prepare a dataloader object that returns the entire dataset in a single + batch. Depending on the number of GPUs, the dataset is managed + as follows: + + - **Distributed Environment** (multiple GPUs): Divides dataset across + processes using the rank and world size. Fetches only portion of + data corresponding to the current process. + - **Non-Distributed Environment** (single GPU): Fetches the entire + dataset. + + :param PinaDataset dataset: The dataset object to be processed. + + .. note:: + This dataloader is used when the batch size is ``None``. + """ + # Handle distributed environment + if PinaSampler.is_distributed(): + # Get rank and world size + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + # Ensure dataset is large enough + if len(dataset) < world_size: + raise RuntimeError( + "Dimension of the dataset smaller than world size." + " Increase the size of the partition or use a single GPU" + ) + # Split dataset among processes + idx, i = [], rank + while i < len(dataset): + idx.append(i) + i += world_size + else: + idx = list(range(len(dataset))) + self.dataset = dataset.getitem_from_list(idx) + self.device = device + self.dataset = ( + {k: v.to(self.device) for k, v in self.dataset.items()} + if self.device + else self.dataset + ) + + def __iter__(self): + """ + Iterate over the dataloader. + """ + return self + + def __len__(self): + """ + Return the length of the dataloader, which is always 1. + :return: The length of the dataloader. + :rtype: int + """ + return 1 + + def __next__(self): + """ + Return the entire dataset as a single batch. + :return: The entire dataset. + :rtype: dict + """ + return self.dataset + + +class PinaSampler: + """ + This class is used to create the sampler instance based on the shuffle + parameter and the environment in which the code is running. + """ + + def __new__(cls, dataset, shuffle=True): + """ + Instantiate and initialize the sampler. + + :param PinaDataset dataset: The dataset from which to sample. + :return: The sampler instance. + :rtype: :class:`torch.utils.data.Sampler` + """ + + if cls.is_distributed(): + sampler = DistributedSampler(dataset, shuffle=shuffle) + else: + if shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = SequentialSampler(dataset) + return sampler + + @staticmethod + def is_distributed(): + """ + Check if the sampler is distributed. + :return: True if the sampler is distributed, False otherwise. + :rtype: bool + """ + return ( + torch.distributed.is_available() + and torch.distributed.is_initialized() + ) + + +def _collect_items(batch): + """ + Helper function to collect items from a batch of graph data samples. + :param batch: List of graph data samples. + """ + to_return = {name: [] for name in batch[0].keys()} + for sample in batch: + for k, v in sample.items(): + to_return[k].append(v) + return to_return + + +def collate_fn_custom(batch, dataset): + """ + Override the default collate function to handle datasets without automatic + batching. + :param batch: List of indices from the dataset. + :param dataset: The PinaDataset instance (must be provided). + """ + return dataset.getitem_from_list(batch) + + +def collate_fn_default(batch, stack_fn): + """ + Default collate function that simply returns the batch as is. + :param batch: List of data samples. + """ + to_return = _collect_items(batch) + return {k: stack_fn[k](v) for k, v in to_return.items()} + + +class PinaDataLoader: + """ + Custom DataLoader for PinaDataset. + """ + + def __new__(cls, *args, **kwargs): + batching_mode = kwargs.get("batching_mode", "common_batch_size").lower() + batch_size = kwargs.get("batch_size") + if batching_mode == "stacked" and batch_size is not None: + return StackedDataLoader( + args[0], + batch_size=batch_size, + shuffle=kwargs.get("shuffle", True), + ) + elif batch_size is None: + kwargs["batching_mode"] = "proportional" + print( + "Using PinaDataLoader with batching mode:", kwargs["batching_mode"] + ) + return super(PinaDataLoader, cls).__new__(cls) + + def __init__( + self, + dataset_dict, + batch_size, + num_workers=0, + shuffle=False, + batching_mode="common_batch_size", + device=None, + ): + """ + Initialize the PinaDataLoader. + + :param dict dataset_dict: A dictionary mapping dataset names to their + respective PinaDataset instances. + :param int batch_size: The batch size for the dataloader. + :param int num_workers: Number of worker processes for data loading. + :param bool shuffle: Whether to shuffle the data at every epoch. + :param str batching_mode: The batching mode to use. Options are + "common_batch_size", "separate_conditions", and "proportional". + :param device: The device to which the data should be moved. + """ + + self.dataset_dict = dataset_dict + self.batch_size = batch_size + self.num_workers = num_workers + self.shuffle = shuffle + self.batching_mode = batching_mode.lower() + self.device = device + + # Batch size None means we want to load the entire dataset in a single + # batch + if batch_size is None: + batch_size_per_dataset = { + split: None for split in dataset_dict.keys() + } + else: + # Compute batch size per dataset + if batching_mode in ["common_batch_size", "separate_conditions"]: + # (the sum of the batch sizes is equal to + # n_conditions * batch_size) + batch_size_per_dataset = { + split: min(batch_size, len(ds)) + for split, ds in dataset_dict.items() + } + elif batching_mode == "proportional": + # batch sizes is equal to the specified batch size) + batch_size_per_dataset = self._compute_batch_size() + + # Creaete a dataloader per dataset + self.dataloaders = { + split: self._create_dataloader( + dataset, batch_size_per_dataset[split] + ) + for split, dataset in dataset_dict.items() + } + + def _compute_batch_size(self): + """ + Compute an appropriate batch size for the given dataset. + """ + + # Compute number of elements per dataset + elements_per_dataset = { + dataset_name: len(dataset) + for dataset_name, dataset in self.dataset_dict.items() + } + # Compute the total number of elements + total_elements = sum(el for el in elements_per_dataset.values()) + # Compute the portion of each dataset + portion_per_dataset = { + name: el / total_elements + for name, el in elements_per_dataset.items() + } + # Compute batch size per dataset. Ensure at least 1 element per + # dataset. + batch_size_per_dataset = { + name: max(1, int(portion * self.batch_size)) + for name, portion in portion_per_dataset.items() + } + # Adjust batch sizes to match the specified total batch size + tot_el_per_batch = sum(el for el in batch_size_per_dataset.values()) + if self.batch_size > tot_el_per_batch: + difference = self.batch_size - tot_el_per_batch + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] += 1 + difference -= 1 + if self.batch_size < tot_el_per_batch: + difference = tot_el_per_batch - self.batch_size + while difference > 0: + for k, v in batch_size_per_dataset.items(): + if difference == 0: + break + if v > 1: + batch_size_per_dataset[k] -= 1 + difference -= 1 + return batch_size_per_dataset + + def _create_dataloader(self, dataset, batch_size): + """ + Create the dataloader for the given dataset. + + :param PinaDataset dataset: The dataset for which to create the + dataloader. + :param int batch_size: The batch size for the dataloader. + :return: The created dataloader. + :rtype: :class:`torch.utils.data.DataLoader` + """ + # If batch size is None, use DummyDataloader + if batch_size is None or batch_size >= len(dataset): + return DummyDataloader(dataset, device=self.device) + + # Determine the appropriate collate function + if not dataset.automatic_batching: + collate_fn = partial(collate_fn_custom, dataset=dataset) + else: + collate_fn = partial(collate_fn_default, stack_fn=dataset.stack_fn) + + # Create and return the dataloader + return DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=self.num_workers, + sampler=PinaSampler(dataset, shuffle=self.shuffle), + ) + + def __len__(self): + """ + Return the length of the dataloader. + + :return: The length of the dataloader. + :rtype: int + """ + # If separate conditions, return sum of lengths of all dataloaders + # else, return max length among dataloaders + if self.batching_mode == "separate_conditions": + return sum(len(dl) for dl in self.dataloaders.values()) + return max(len(dl) for dl in self.dataloaders.values()) + + def __iter__(self): + """ + Iterate over the dataloader. Yields a dictionary mapping split name to batch. + + The iteration logic for 'separate_conditions' is now iterative and memory-efficient. + """ + if self.batching_mode == "separate_conditions": + tmp = [] + for split, dl in self.dataloaders.items(): + len_split = len(dl) + for i, batch in enumerate(dl): + tmp.append({split: batch}) + if i + 1 >= len_split: + break + random.shuffle(tmp) + for batch_dict in tmp: + yield batch_dict + return + + # Common_batch_size or Proportional mode (round-robin sampling) + iterators = { + split: itertools.cycle(dl) for split, dl in self.dataloaders.items() + } + + # Iterate for the length of the longest dataloader + for _ in range(len(self)): + batch_dict: BatchDict = {} + for split, it in iterators.items(): + # Since we use itertools.cycle, next(it) will always yield a batch + # by repeating the dataset, so no need for the 'if batch is None: return' check. + batch_dict[split] = next(it) + yield batch_dict diff --git a/pina/data/dataset.py b/pina/data/dataset.py index 62e3913d8..c8d309b32 100644 --- a/pina/data/dataset.py +++ b/pina/data/dataset.py @@ -1,326 +1,170 @@ """Module for the PINA dataset classes.""" -from abc import abstractmethod, ABC +import torch from torch.utils.data import Dataset from torch_geometric.data import Data from ..graph import Graph, LabelBatch +from ..label_tensor import LabelTensor + +STACK_FN_MAP = { + "label_tensor": LabelTensor.stack, + "tensor": torch.stack, + "data": LabelBatch.from_data_list, +} class PinaDatasetFactory: """ - Factory class for the PINA dataset. - - Depending on the data type inside the conditions, it instanciate an object - belonging to the appropriate subclass of - :class:`~pina.data.dataset.PinaDataset`. The possible subclasses are: - - - :class:`~pina.data.dataset.PinaTensorDataset`, for handling \ - :class:`torch.Tensor` and :class:`~pina.label_tensor.LabelTensor` data. - - :class:`~pina.data.dataset.PinaGraphDataset`, for handling \ - :class:`~pina.graph.Graph` and :class:`~torch_geometric.data.Data` data. + Factory class to create PINA datasets based on the provided conditions + dictionary. """ def __new__(cls, conditions_dict, **kwargs): """ - Instantiate the appropriate subclass of - :class:`~pina.data.dataset.PinaDataset`. - - If a graph is present in the conditions, returns a - :class:`~pina.data.dataset.PinaGraphDataset`, otherwise returns a - :class:`~pina.data.dataset.PinaTensorDataset`. - - :param dict conditions_dict: Dictionary containing all the conditions - to be included in the dataset instance. - :return: A subclass of :class:`~pina.data.dataset.PinaDataset`. - :rtype: PinaTensorDataset | PinaGraphDataset + Create PINA dataset instances based on the provided conditions + dictionary. - :raises ValueError: If an empty dictionary is provided. + :param dict conditions_dict: A dictionary where keys are condition names + and values are dictionaries containing the associated data. + :return: A dictionary mapping condition names to their respective + :class:`PinaDataset` instances. """ # Check if conditions_dict is empty if len(conditions_dict) == 0: raise ValueError("No conditions provided") - # Check is a Graph is present in the conditions - is_graph = cls._is_graph_dataset(conditions_dict) - if is_graph: - # If a Graph is present, return a PinaGraphDataset - return PinaGraphDataset(conditions_dict, **kwargs) - # If no Graph is present, return a PinaTensorDataset - return PinaTensorDataset(conditions_dict, **kwargs) - - @staticmethod - def _is_graph_dataset(conditions_dict): - """ - Check if a graph is present in the conditions (at least one time). - - :param conditions_dict: Dictionary containing the conditions. - :type conditions_dict: dict - :return: True if a graph is present in the conditions, False otherwise. - :rtype: bool - """ + dataset_dict = {} # Dictionary to hold the created datasets - # Iterate over the conditions dictionary - for v in conditions_dict.values(): - # Iterate over the values of the current condition - for cond in v.values(): - # Check if the current value is a list of Data objects - if isinstance(cond, (Data, Graph, list, tuple)): - return True - return False - - -class PinaDataset(Dataset, ABC): + # Check is a Graph is present in the conditions + for name, data in conditions_dict.items(): + # Validate that data is a dictionary + if not isinstance(data, dict): + raise ValueError( + f"Condition '{name}' data must be a dictionary" + ) + # Create PinaDataset instance for each condition + dataset_dict[name] = PinaDataset(data, **kwargs) + return dataset_dict + + +class PinaDataset(Dataset): """ - Abstract class for the PINA dataset which extends the PyTorch - :class:`~torch.utils.data.Dataset` class. It defines the common interface - for :class:`~pina.data.dataset.PinaTensorDataset` and - :class:`~pina.data.dataset.PinaGraphDataset` classes. + Dataset class for the PINA dataset with :class:`torch.Tensor` and + :class:`~pina.label_tensor.LabelTensor` data. """ - def __init__( - self, conditions_dict, max_conditions_lengths, automatic_batching - ): + def __init__(self, data_dict, automatic_batching=None): """ - Initialize the instance by storing the conditions dictionary, the - maximum number of items per conditions to consider, and the automatic - batching flag. + Initialize the instance by storing the conditions dictionary. :param dict conditions_dict: A dictionary mapping condition names to their respective data. Each key represents a condition name, and the corresponding value is a dictionary containing the associated data. - :param dict max_conditions_lengths: Maximum number of data points that - can be included in a single batch per condition. - :param bool automatic_batching: Indicates whether PyTorch automatic - batching is enabled in - :class:`~pina.data.data_module.PinaDataModule`. """ # Store the conditions dictionary - self.conditions_dict = conditions_dict - # Store the maximum number of conditions to consider - self.max_conditions_lengths = max_conditions_lengths - # Store length of each condition - self.conditions_length = { - k: len(v["input"]) for k, v in self.conditions_dict.items() - } - # Store the maximum length of the dataset - self.length = max(self.conditions_length.values()) - # Dynamically set the getitem function based on automatic batching - if automatic_batching: - self._getitem_func = self._getitem_int - else: - self._getitem_func = self._getitem_dummy - - def _get_max_len(self): - """ - Returns the length of the longest condition in the dataset. - - :return: Length of the longest condition in the dataset. - :rtype: int - """ - - max_len = 0 - for condition in self.conditions_dict.values(): - max_len = max(max_len, len(condition["input"])) - return max_len + self.data = data_dict + self.automatic_batching = ( + automatic_batching if automatic_batching is not None else True + ) + self._stack_fn = {} + self.is_graph_dataset = False + # Determine stacking functions for each data type (used in collate_fn) + for k, v in data_dict.items(): + if isinstance(v, LabelTensor): + self._stack_fn[k] = "label_tensor" + elif isinstance(v, torch.Tensor): + self._stack_fn[k] = "tensor" + elif isinstance(v, list) and all( + isinstance(item, (Data, Graph)) for item in v + ): + self._stack_fn[k] = "data" + self.is_graph_dataset = True + else: + raise ValueError( + f"Unsupported data type for stacking: {type(v)}" + ) def __len__(self): - return self.length - - def __getitem__(self, idx): - return self._getitem_func(idx) - - def _getitem_dummy(self, idx): """ - Return the index itself. This is used when automatic batching is - disabled to postpone the data retrieval to the dataloader. + Return the length of the dataset. - :param int idx: Index. - :return: Index. + :return: The length of the dataset. :rtype: int """ + return len(next(iter(self.data.values()))) - # If automatic batching is disabled, return the data at the given index - return idx - - def _getitem_int(self, idx): + def __getitem__(self, idx): """ - Return the data at the given index in the dataset. This is used when - automatic batching is enabled. + Return the data at the given index in the dataset. :param int idx: Index. :return: A dictionary containing the data at the given index. :rtype: dict """ - # If automatic batching is enabled, return the data at the given index - return { - k: {k_data: v[k_data][idx % len(v["input"])] for k_data in v.keys()} - for k, v in self.conditions_dict.items() - } - - def get_all_data(self): - """ - Return all data in the dataset. - - :return: A dictionary containing all the data in the dataset. - :rtype: dict - """ - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - len_condition = len( - data["input"] - ) # Length of the current condition - to_return_dict[condition] = self._retrive_data( - data, list(range(len_condition)) - ) # Retrieve the data from the current condition - return to_return_dict + if self.automatic_batching: + # Return the data at the given index + return { + field_name: data[idx] for field_name, data in self.data.items() + } + return idx - def fetch_from_idx_list(self, idx): + def getitem_from_list(self, idx_list): """ Return data from the dataset given a list of indices. - :param list[int] idx: List of indices. + :param list[int] idx_list: List of indices. :return: A dictionary containing the data at the given indices. :rtype: dict """ - to_return_dict = {} - for condition, data in self.conditions_dict.items(): - # Get the indices for the current condition - cond_idx = idx[: self.max_conditions_lengths[condition]] - # Get the length of the current condition - condition_len = self.conditions_length[condition] - # If the length of the dataset is greater than the length of the - # current condition, repeat the indices - if self.length > condition_len: - cond_idx = [idx % condition_len for idx in cond_idx] - # Retrieve the data from the current condition - to_return_dict[condition] = self._retrive_data(data, cond_idx) - return to_return_dict - - @abstractmethod - def _retrive_data(self, data, idx_list): - """ - Abstract method to retrieve data from the dataset given a list of - indices. - """ - - -class PinaTensorDataset(PinaDataset): - """ - Dataset class for the PINA dataset with :class:`torch.Tensor` and - :class:`~pina.label_tensor.LabelTensor` data. - """ + to_return = {} + for field_name, data in self.data.items(): + if self._stack_fn[field_name] == "data": + fn = STACK_FN_MAP[self._stack_fn[field_name]] + to_return[field_name] = fn([data[i] for i in idx_list]) + else: + to_return[field_name] = data[idx_list] + return to_return - # Override _retrive_data method for torch.Tensor data - def _retrive_data(self, data, idx_list): + def update_data(self, update_dict): """ - Retrieve data from the dataset given a list of indices. + Update the dataset's data in-place. - :param dict data: Dictionary containing the data - (only :class:`torch.Tensor` or - :class:`~pina.label_tensor.LabelTensor`). - :param list[int] idx_list: indices to retrieve. - :return: Dictionary containing the data at the given indices. - :rtype: dict + :param dict update_dict: A dictionary where keys are condition names + and values are dictionaries with updated data for those conditions. """ - - return {k: v[idx_list] for k, v in data.items()} + for field_name, updates in update_dict.items(): + if field_name not in self.data: + raise KeyError( + f"Condition '{field_name}' not found in dataset." + ) + if not isinstance(updates, (LabelTensor, torch.Tensor)): + raise ValueError( + f"Updates for condition '{field_name}' must be of type " + f"LabelTensor or torch.Tensor." + ) + self.data[field_name] = updates @property def input(self): """ - Return the input data for the dataset. - - :return: Dictionary containing the input points. - :rtype: dict - """ - return {k: v["input"] for k, v in self.conditions_dict.items()} - - def update_data(self, new_conditions_dict): - """ - Update the dataset with new data. - This method is used to update the dataset with new data. It replaces - the current data with the new data provided in the new_conditions_dict - parameter. - - :param dict new_conditions_dict: Dictionary containing the new data. - :return: None - """ - for condition, data in new_conditions_dict.items(): - if condition in self.conditions_dict: - self.conditions_dict[condition].update(data) - else: - self.conditions_dict[condition] = data - - -class PinaGraphDataset(PinaDataset): - """ - Dataset class for the PINA dataset with :class:`~torch_geometric.data.Data` - and :class:`~pina.graph.Graph` data. - """ - - def _create_graph_batch(self, data): - """ - Create a LabelBatch object from a list of - :class:`~torch_geometric.data.Data` objects. - - :param data: List of items to collate in a single batch. - :type data: list[Data] | list[Graph] - :return: LabelBatch object all the graph collated in a single batch - disconnected graphs. - :rtype: LabelBatch - """ - batch = LabelBatch.from_data_list(data) - return batch - - def create_batch(self, data): - """ - Create a Batch object from a list of :class:`~torch_geometric.data.Data` - objects. + Get the input data from the dataset. - :param data: List of items to collate in a single batch. - :type data: list[Data] | list[Graph] - :return: Batch object. - :rtype: :class:`~torch_geometric.data.Batch` - | :class:`~pina.graph.LabelBatch` + :return: The input data. + :rtype: torch.Tensor | LabelTensor | Data | Graph """ - - if isinstance(data[0], Data): - return self._create_graph_batch(data) - return self._create_tensor_batch(data) - - # Override _retrive_data method for graph handling - def _retrive_data(self, data, idx_list): - """ - Retrieve data from the dataset given a list of indices. - - :param dict data: Dictionary containing the data. - :param list[int] idx_list: List of indices to retrieve. - :return: Dictionary containing the data at the given indices. - :rtype: dict - """ - - # Return the data from the current condition - # If the data is a list of Data objects, create a Batch object - # If the data is a list of torch.Tensor objects, create a torch.Tensor - return { - k: ( - self._create_graph_batch([v[i] for i in idx_list]) - if isinstance(v, list) - else v[idx_list] - ) - for k, v in data.items() - } + return self.data["input"] @property - def input(self): + def stack_fn(self): """ - Return the input data for the dataset. + Get the mapping of stacking functions for each data type in the dataset. - :return: Dictionary containing the input points. + :return: A dictionary mapping condition names to their respective + stacking function identifiers. :rtype: dict """ - return {k: v["input"] for k, v in self.conditions_dict.items()} + return {k: STACK_FN_MAP[v] for k, v in self._stack_fn.items()} diff --git a/pina/data/stacked_dataloader.py b/pina/data/stacked_dataloader.py new file mode 100644 index 000000000..46bc54738 --- /dev/null +++ b/pina/data/stacked_dataloader.py @@ -0,0 +1,53 @@ +import torch +from math import ceil + + +class StackedDataLoader: + def __init__(self, datasets, batch_size=32, shuffle=True): + for d in datasets.values(): + if d.is_graph_dataset: + raise ValueError("Each dataset must be a dictionary") + self.chunks = {} + self.total_length = 0 + self.indices = [] + + self._init_chunks(datasets) + self.indices = list(range(self.total_length)) + self.batch_size = batch_size + self.shuffle = shuffle + if self.shuffle: + torch.random.manual_seed(42) + self.indices = torch.randperm(self.total_length).tolist() + self.datasets = datasets + + def _init_chunks(self, datasets): + inc = 0 + total_length = 0 + for name, dataset in datasets.items(): + self.chunks[name] = {"start": inc, "end": inc + len(dataset)} + inc += len(dataset) + self.total_length = inc + + def __len__(self): + return ceil(self.total_length / self.batch_size) + + def _build_batch_indices(self, batch_idx): + start = batch_idx * self.batch_size + end = min(start + self.batch_size, self.total_length) + return self.indices[start:end] + + def __iter__(self): + for batch_idx in range(len(self)): + batch_indices = self._build_batch_indices(batch_idx) + batch_data = {} + for name, chunk in self.chunks.items(): + local_indices = [ + idx - chunk["start"] + for idx in batch_indices + if chunk["start"] <= idx < chunk["end"] + ] + if local_indices: + batch_data[name] = self.datasets[name].getitem_from_list( + local_indices + ) + yield batch_data diff --git a/pina/trainer.py b/pina/trainer.py index 8e1d95110..a7e96541c 100644 --- a/pina/trainer.py +++ b/pina/trainer.py @@ -31,7 +31,7 @@ def __init__( test_size=0.0, val_size=0.0, compile=None, - repeat=None, + batching_mode="common_batch_size", automatic_batching=None, num_workers=None, pin_memory=None, @@ -56,10 +56,12 @@ def __init__( :param bool compile: If ``True``, the model is compiled before training. Default is ``False``. For Windows users, it is always disabled. Not supported for python version greater or equal than 3.14. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. For further details, see the - :class:`~pina.data.data_module.PinaDataModule` class. Default is - ``False``. + :param bool common_batch_size: If ``True``, the same batch size is used + for all conditions. If ``False``, each condition can have its own + batch size, proportional to the size of the dataset in that + condition. Default is ``True``. + :param bool separate_conditions: If ``True``, dataloaders for each + condition are iterated separately. Default is ``False``. :param bool automatic_batching: If ``True``, automatic PyTorch batching is performed, otherwise the items are retrieved from the dataset all at once. For further details, see the @@ -82,7 +84,7 @@ def __init__( train_size=train_size, test_size=test_size, val_size=val_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, compile=compile, ) @@ -122,8 +124,6 @@ def __init__( UserWarning, ) - repeat = repeat if repeat is not None else False - automatic_batching = ( automatic_batching if automatic_batching is not None else False ) @@ -139,7 +139,7 @@ def __init__( test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, pin_memory=pin_memory, num_workers=num_workers, @@ -177,7 +177,7 @@ def _create_datamodule( test_size, val_size, batch_size, - repeat, + batching_mode, automatic_batching, pin_memory, num_workers, @@ -196,8 +196,10 @@ def _create_datamodule( :param float val_size: The percentage of elements to include in the validation dataset. :param int batch_size: The number of samples per batch to load. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param bool common_batch_size: Whether to use the same batch size for + all conditions. + :param bool seperate_conditions: Whether to iterate dataloaders for + each condition separately. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool pin_memory: Whether to use pinned memory for faster data @@ -227,7 +229,7 @@ def _create_datamodule( test_size=test_size, val_size=val_size, batch_size=batch_size, - repeat=repeat, + batching_mode=batching_mode, automatic_batching=automatic_batching, num_workers=num_workers, pin_memory=pin_memory, @@ -279,7 +281,7 @@ def _check_input_consistency( train_size, test_size, val_size, - repeat, + batching_mode, automatic_batching, compile, ): @@ -293,8 +295,10 @@ def _check_input_consistency( test dataset. :param float val_size: The percentage of elements to include in the validation dataset. - :param bool repeat: Whether to repeat the dataset data in each - condition during training. + :param bool common_batch_size: Whether to use the same batch size for + all conditions. + :param bool seperate_conditions: Whether to iterate dataloaders for + each condition separately. :param bool automatic_batching: Whether to perform automatic batching with PyTorch. :param bool compile: If ``True``, the model is compiled before training. @@ -304,8 +308,7 @@ def _check_input_consistency( check_consistency(train_size, float) check_consistency(test_size, float) check_consistency(val_size, float) - if repeat is not None: - check_consistency(repeat, bool) + check_consistency(batching_mode, str) if automatic_batching is not None: check_consistency(automatic_batching, bool) if compile is not None: diff --git a/tests/test_callback/test_adaptive_refinement_callback.py b/tests/test_callback/test_adaptive_refinement_callback.py index 7866c7f7b..274cad41e 100644 --- a/tests/test_callback/test_adaptive_refinement_callback.py +++ b/tests/test_callback/test_adaptive_refinement_callback.py @@ -51,7 +51,7 @@ def test_sample(condition_to_update): } trainer.train() after_n_points = { - loc: len(trainer.data_module.train_dataset.input[loc]) + loc: len(trainer.data_module.train_dataset[loc].input) for loc in condition_to_update } assert before_n_points == trainer.callbacks[0].initial_population_size diff --git a/tests/test_callback/test_normalizer_data_callback.py b/tests/test_callback/test_normalizer_data_callback.py index 7cdcc9510..2a1cbc6dd 100644 --- a/tests/test_callback/test_normalizer_data_callback.py +++ b/tests/test_callback/test_normalizer_data_callback.py @@ -142,14 +142,10 @@ def test_setup(solver, fn, stage, apply_to): for cond in ["data1", "data2"]: scale = scale_fn( - trainer_copy.data_module.train_dataset.conditions_dict[cond][ - apply_to - ] + trainer_copy.data_module.train_dataset[cond].data[apply_to] ) shift = shift_fn( - trainer_copy.data_module.train_dataset.conditions_dict[cond][ - apply_to - ] + trainer_copy.data_module.train_dataset[cond].data[apply_to] ) assert "scale" in normalizer[cond] assert "shift" in normalizer[cond] @@ -158,8 +154,8 @@ def test_setup(solver, fn, stage, apply_to): for ds_name in stage_map[stage]: dataset = getattr(trainer.data_module, ds_name, None) old_dataset = getattr(trainer_copy.data_module, ds_name, None) - current_points = dataset.conditions_dict[cond][apply_to] - old_points = old_dataset.conditions_dict[cond][apply_to] + current_points = dataset[cond].data[apply_to] + old_points = old_dataset[cond].data[apply_to] expected = (old_points - shift) / scale assert torch.allclose(current_points, expected) @@ -204,10 +200,10 @@ def test_setup_pinn(fn, stage, apply_to): cond = "data" scale = scale_fn( - trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] + trainer_copy.data_module.train_dataset[cond].data[apply_to] ) shift = shift_fn( - trainer_copy.data_module.train_dataset.conditions_dict[cond][apply_to] + trainer_copy.data_module.train_dataset[cond].data[apply_to] ) assert "scale" in normalizer[cond] assert "shift" in normalizer[cond] @@ -216,8 +212,8 @@ def test_setup_pinn(fn, stage, apply_to): for ds_name in stage_map[stage]: dataset = getattr(trainer.data_module, ds_name, None) old_dataset = getattr(trainer_copy.data_module, ds_name, None) - current_points = dataset.conditions_dict[cond][apply_to] - old_points = old_dataset.conditions_dict[cond][apply_to] + current_points = dataset[cond].data[apply_to] + old_points = old_dataset[cond].data[apply_to] expected = (old_points - shift) / scale assert torch.allclose(current_points, expected) @@ -242,3 +238,7 @@ def test_setup_graph_dataset(): ) with pytest.raises(NotImplementedError): trainer.train() + + +# if __name__ == "__main__": +# test_setup(supervised_solver_lt, [torch.std, torch.mean], "all", "input") diff --git a/tests/test_data/test_data_module.py b/tests/test_data/test_data_module.py index 53e7334ec..a997c635a 100644 --- a/tests/test_data/test_data_module.py +++ b/tests/test_data/test_data_module.py @@ -1,10 +1,11 @@ import torch import pytest from pina.data import PinaDataModule -from pina.data.dataset import PinaTensorDataset, PinaGraphDataset +from pina.data.dataset import PinaDataset from pina.problem.zoo import SupervisedProblem from pina.graph import RadiusGraph -from pina.data.data_module import DummyDataloader + +from pina.data.dataloader import DummyDataloader, PinaDataLoader from pina import Trainer from pina.solver import SupervisedSolver from torch_geometric.data import Batch @@ -44,22 +45,33 @@ def test_setup_train(input_, output_, train_size, val_size, test_size): ) dm.setup() assert hasattr(dm, "train_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.train_dataset, PinaTensorDataset) - else: - assert isinstance(dm.train_dataset, PinaGraphDataset) - # assert len(dm.train_dataset) == int(len(input_) * train_size) + assert isinstance(dm.train_dataset, dict) + assert all( + isinstance(dm.train_dataset[cond], PinaDataset) + for cond in dm.train_dataset + ) + assert all( + dm.train_dataset[cond].is_graph_dataset == isinstance(input_, list) + for cond in dm.train_dataset + ) + assert all( + len(dm.train_dataset[cond]) == int(len(input_) * train_size) + for cond in dm.train_dataset + ) if test_size > 0: assert hasattr(dm, "test_dataset") assert dm.test_dataset is None else: assert not hasattr(dm, "test_dataset") assert hasattr(dm, "val_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.val_dataset, PinaTensorDataset) - else: - assert isinstance(dm.val_dataset, PinaGraphDataset) - # assert len(dm.val_dataset) == int(len(input_) * val_size) + + assert isinstance(dm.val_dataset, dict) + assert all( + isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset + ) + assert all( + isinstance(dm.val_dataset[cond], PinaDataset) for cond in dm.val_dataset + ) @pytest.mark.parametrize( @@ -87,49 +99,59 @@ def test_setup_test(input_, output_, train_size, val_size, test_size): assert not hasattr(dm, "val_dataset") assert hasattr(dm, "test_dataset") - if isinstance(input_, torch.Tensor): - assert isinstance(dm.test_dataset, PinaTensorDataset) - else: - assert isinstance(dm.test_dataset, PinaGraphDataset) - # assert len(dm.test_dataset) == int(len(input_) * test_size) + assert all( + isinstance(dm.test_dataset[cond], PinaDataset) + for cond in dm.test_dataset + ) + assert all( + dm.test_dataset[cond].is_graph_dataset == isinstance(input_, list) + for cond in dm.test_dataset + ) + assert all( + len(dm.test_dataset[cond]) == int(len(input_) * test_size) + for cond in dm.test_dataset + ) -@pytest.mark.parametrize( - "input_, output_", - [(input_tensor, output_tensor), (input_graph, output_graph)], -) -def test_dummy_dataloader(input_, output_): - problem = SupervisedProblem(input_=input_, output_=output_) - solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) - trainer = Trainer( - solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0 - ) - dm = trainer.data_module - dm.setup() - dm.trainer = trainer - dataloader = dm.train_dataloader() - assert isinstance(dataloader, DummyDataloader) - assert len(dataloader) == 1 - data = next(dataloader) - assert isinstance(data, list) - assert isinstance(data[0], tuple) - if isinstance(input_, list): - assert isinstance(data[0][1]["input"], Batch) - else: - assert isinstance(data[0][1]["input"], torch.Tensor) - assert isinstance(data[0][1]["target"], torch.Tensor) +# @pytest.mark.parametrize( +# "input_, output_", +# [(input_tensor, output_tensor), (input_graph, output_graph)], +# ) +# def test_dummy_dataloader(input_, output_): +# problem = SupervisedProblem(input_=input_, output_=output_) +# solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) +# trainer = Trainer( +# solver, batch_size=None, train_size=0.7, val_size=0.3, test_size=0.0 +# ) +# dm = trainer.data_module +# dm.setup() +# dm.trainer = trainer +# dataloader = dm.train_dataloader() +# assert isinstance(dataloader, PinaDataLoader) +# print(dataloader.dataloaders) +# assert all([isinstance(ds, DummyDataloader) for ds in dataloader.dataloaders.values()]) - dataloader = dm.val_dataloader() - assert isinstance(dataloader, DummyDataloader) - assert len(dataloader) == 1 - data = next(dataloader) - assert isinstance(data, list) - assert isinstance(data[0], tuple) - if isinstance(input_, list): - assert isinstance(data[0][1]["input"], Batch) - else: - assert isinstance(data[0][1]["input"], torch.Tensor) - assert isinstance(data[0][1]["target"], torch.Tensor) +# data = next(iter(dataloader)) +# assert isinstance(data, list) +# assert isinstance(data[0], tuple) +# if isinstance(input_, list): +# assert isinstance(data[0][1]["input"], Batch) +# else: +# assert isinstance(data[0][1]["input"], torch.Tensor) +# assert isinstance(data[0][1]["target"], torch.Tensor) + + +# dataloader = dm.val_dataloader() +# assert isinstance(dataloader, DummyDataloader) +# assert len(dataloader) == 1 +# data = next(dataloader) +# assert isinstance(data, list) +# assert isinstance(data[0], tuple) +# if isinstance(input_, list): +# assert isinstance(data[0][1]["input"], Batch) +# else: +# assert isinstance(data[0][1]["input"], torch.Tensor) +# assert isinstance(data[0][1]["target"], torch.Tensor) @pytest.mark.parametrize( @@ -137,7 +159,11 @@ def test_dummy_dataloader(input_, output_): [(input_tensor, output_tensor), (input_graph, output_graph)], ) @pytest.mark.parametrize("automatic_batching", [True, False]) -def test_dataloader(input_, output_, automatic_batching): +@pytest.mark.parametrize("batch_size", [None, 10]) +@pytest.mark.parametrize("batching_mode", ["common_batch_size", "proportional"]) +def test_dataloader( + input_, output_, automatic_batching, batch_size, batching_mode +): problem = SupervisedProblem(input_=input_, output_=output_) solver = SupervisedSolver(problem=problem, model=torch.nn.Linear(10, 10)) trainer = Trainer( @@ -147,12 +173,13 @@ def test_dataloader(input_, output_, automatic_batching): val_size=0.3, test_size=0.0, automatic_batching=automatic_batching, + batching_mode=batching_mode, ) dm = trainer.data_module dm.setup() dm.trainer = trainer dataloader = dm.train_dataloader() - assert isinstance(dataloader, DataLoader) + assert isinstance(dataloader, PinaDataLoader) assert len(dataloader) == 7 data = next(iter(dataloader)) assert isinstance(data, dict) @@ -163,8 +190,8 @@ def test_dataloader(input_, output_, automatic_batching): assert isinstance(data["data"]["target"], torch.Tensor) dataloader = dm.val_dataloader() - assert isinstance(dataloader, DataLoader) - assert len(dataloader) == 3 + assert isinstance(dataloader, PinaDataLoader) + assert len(dataloader) == 3 if batch_size is not None else 1 data = next(iter(dataloader)) assert isinstance(data, dict) if isinstance(input_, list): @@ -202,12 +229,13 @@ def test_dataloader_labels(input_, output_, automatic_batching): val_size=0.3, test_size=0.0, automatic_batching=automatic_batching, + # common_batch_size=True, ) dm = trainer.data_module dm.setup() dm.trainer = trainer dataloader = dm.train_dataloader() - assert isinstance(dataloader, DataLoader) + assert isinstance(dataloader, PinaDataLoader) assert len(dataloader) == 7 data = next(iter(dataloader)) assert isinstance(data, dict) @@ -223,7 +251,7 @@ def test_dataloader_labels(input_, output_, automatic_batching): assert data["data"]["target"].labels == ["u", "v", "w"] dataloader = dm.val_dataloader() - assert isinstance(dataloader, DataLoader) + assert isinstance(dataloader, PinaDataLoader) assert len(dataloader) == 3 data = next(iter(dataloader)) assert isinstance(data, dict) @@ -240,39 +268,6 @@ def test_dataloader_labels(input_, output_, automatic_batching): assert data["data"]["target"].labels == ["u", "v", "w"] -def test_get_all_data(): - input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) - target = input - - problem = SupervisedProblem(input, target) - datamodule = PinaDataModule( - problem, - train_size=0.7, - test_size=0.2, - val_size=0.1, - batch_size=64, - shuffle=False, - repeat=False, - automatic_batching=None, - num_workers=0, - pin_memory=False, - ) - datamodule.setup("fit") - datamodule.setup("test") - assert len(datamodule.train_dataset.get_all_data()["data"]["input"]) == 700 - assert torch.isclose( - datamodule.train_dataset.get_all_data()["data"]["input"], input[:700] - ).all() - assert len(datamodule.val_dataset.get_all_data()["data"]["input"]) == 100 - assert torch.isclose( - datamodule.val_dataset.get_all_data()["data"]["input"], input[900:] - ).all() - assert len(datamodule.test_dataset.get_all_data()["data"]["input"]) == 200 - assert torch.isclose( - datamodule.test_dataset.get_all_data()["data"]["input"], input[700:900] - ).all() - - def test_input_propery_tensor(): input = torch.stack([torch.zeros((1,)) + i for i in range(1000)]) target = input @@ -285,7 +280,6 @@ def test_input_propery_tensor(): val_size=0.1, batch_size=64, shuffle=False, - repeat=False, automatic_batching=None, num_workers=0, pin_memory=False, @@ -311,7 +305,6 @@ def test_input_propery_graph(): val_size=0.1, batch_size=64, shuffle=False, - repeat=False, automatic_batching=None, num_workers=0, pin_memory=False, diff --git a/tests/test_data/test_graph_dataset.py b/tests/test_data/test_graph_dataset.py index 81d6a2c5d..02f9d65f3 100644 --- a/tests/test_data/test_graph_dataset.py +++ b/tests/test_data/test_graph_dataset.py @@ -1,138 +1,138 @@ -import torch -import pytest -from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset -from pina.graph import KNNGraph -from torch_geometric.data import Data +# import torch +# import pytest +# from pina.data.dataset import PinaDatasetFactory, PinaGraphDataset +# from pina.graph import KNNGraph +# from torch_geometric.data import Data -x = torch.rand((100, 20, 10)) -pos = torch.rand((100, 20, 2)) -input_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x, pos) -] -output_ = torch.rand((100, 20, 10)) +# x = torch.rand((100, 20, 10)) +# pos = torch.rand((100, 20, 2)) +# input_ = [ +# KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) +# for x_, pos_ in zip(x, pos) +# ] +# output_ = torch.rand((100, 20, 10)) -x_2 = torch.rand((50, 20, 10)) -pos_2 = torch.rand((50, 20, 2)) -input_2_ = [ - KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) - for x_, pos_ in zip(x_2, pos_2) -] -output_2_ = torch.rand((50, 20, 10)) +# x_2 = torch.rand((50, 20, 10)) +# pos_2 = torch.rand((50, 20, 2)) +# input_2_ = [ +# KNNGraph(x=x_, pos=pos_, neighbours=3, edge_attr=True) +# for x_, pos_ in zip(x_2, pos_2) +# ] +# output_2_ = torch.rand((50, 20, 10)) -# Problem with a single condition -conditions_dict_single = { - "data": { - "input": input_, - "target": output_, - } -} -max_conditions_lengths_single = {"data": 100} +# # Problem with a single condition +# conditions_dict_single = { +# "data": { +# "input": input_, +# "target": output_, +# } +# } +# max_conditions_lengths_single = {"data": 100} -# Problem with multiple conditions -conditions_dict_multi = { - "data_1": { - "input": input_, - "target": output_, - }, - "data_2": { - "input": input_2_, - "target": output_2_, - }, -} +# # Problem with multiple conditions +# conditions_dict_multi = { +# "data_1": { +# "input": input_, +# "target": output_, +# }, +# "data_2": { +# "input": input_2_, +# "target": output_2_, +# }, +# } -max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} +# max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_multi, max_conditions_lengths_multi), - ], -) -def test_constructor(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - assert isinstance(dataset, PinaGraphDataset) - assert len(dataset) == 100 +# @pytest.mark.parametrize( +# "conditions_dict, max_conditions_lengths", +# [ +# (conditions_dict_single, max_conditions_lengths_single), +# (conditions_dict_multi, max_conditions_lengths_multi), +# ], +# ) +# def test_constructor(conditions_dict, max_conditions_lengths): +# dataset = PinaDatasetFactory( +# conditions_dict, +# max_conditions_lengths=max_conditions_lengths, +# automatic_batching=True, +# ) +# assert isinstance(dataset, PinaGraphDataset) +# assert len(dataset) == 100 -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_multi, max_conditions_lengths_multi), - ], -) -def test_getitem(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - data = dataset[50] - assert isinstance(data, dict) - assert all([isinstance(d["input"], Data) for d in data.values()]) - assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) - assert all( - [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()] - ) - assert all( - [d["target"].shape == torch.Size((20, 10)) for d in data.values()] - ) - assert all( - [ - d["input"].edge_index.shape == torch.Size((2, 60)) - for d in data.values() - ] - ) - assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()]) +# @pytest.mark.parametrize( +# "conditions_dict, max_conditions_lengths", +# [ +# (conditions_dict_single, max_conditions_lengths_single), +# (conditions_dict_multi, max_conditions_lengths_multi), +# ], +# ) +# def test_getitem(conditions_dict, max_conditions_lengths): +# dataset = PinaDatasetFactory( +# conditions_dict, +# max_conditions_lengths=max_conditions_lengths, +# automatic_batching=True, +# ) +# data = dataset[50] +# assert isinstance(data, dict) +# assert all([isinstance(d["input"], Data) for d in data.values()]) +# assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) +# assert all( +# [d["input"].x.shape == torch.Size((20, 10)) for d in data.values()] +# ) +# assert all( +# [d["target"].shape == torch.Size((20, 10)) for d in data.values()] +# ) +# assert all( +# [ +# d["input"].edge_index.shape == torch.Size((2, 60)) +# for d in data.values() +# ] +# ) +# assert all([d["input"].edge_attr.shape[0] == 60 for d in data.values()]) - data = dataset.fetch_from_idx_list([i for i in range(20)]) - assert isinstance(data, dict) - assert all([isinstance(d["input"], Data) for d in data.values()]) - assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) - assert all( - [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()] - ) - assert all( - [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()] - ) - assert all( - [ - d["input"].edge_index.shape == torch.Size((2, 1200)) - for d in data.values() - ] - ) - assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()]) +# data = dataset.fetch_from_idx_list([i for i in range(20)]) +# assert isinstance(data, dict) +# assert all([isinstance(d["input"], Data) for d in data.values()]) +# assert all([isinstance(d["target"], torch.Tensor) for d in data.values()]) +# assert all( +# [d["input"].x.shape == torch.Size((400, 10)) for d in data.values()] +# ) +# assert all( +# [d["target"].shape == torch.Size((20, 20, 10)) for d in data.values()] +# ) +# assert all( +# [ +# d["input"].edge_index.shape == torch.Size((2, 1200)) +# for d in data.values() +# ] +# ) +# assert all([d["input"].edge_attr.shape[0] == 1200 for d in data.values()]) -def test_input_single_condition(): - dataset = PinaDatasetFactory( - conditions_dict_single, - max_conditions_lengths=max_conditions_lengths_single, - automatic_batching=True, - ) - input_ = dataset.input - assert isinstance(input_, dict) - assert isinstance(input_["data"], list) - assert all([isinstance(d, Data) for d in input_["data"]]) +# def test_input_single_condition(): +# dataset = PinaDatasetFactory( +# conditions_dict_single, +# max_conditions_lengths=max_conditions_lengths_single, +# automatic_batching=True, +# ) +# input_ = dataset.input +# assert isinstance(input_, dict) +# assert isinstance(input_["data"], list) +# assert all([isinstance(d, Data) for d in input_["data"]]) -def test_input_multi_condition(): - dataset = PinaDatasetFactory( - conditions_dict_multi, - max_conditions_lengths=max_conditions_lengths_multi, - automatic_batching=True, - ) - input_ = dataset.input - assert isinstance(input_, dict) - assert isinstance(input_["data_1"], list) - assert all([isinstance(d, Data) for d in input_["data_1"]]) - assert isinstance(input_["data_2"], list) - assert all([isinstance(d, Data) for d in input_["data_2"]]) +# def test_input_multi_condition(): +# dataset = PinaDatasetFactory( +# conditions_dict_multi, +# max_conditions_lengths=max_conditions_lengths_multi, +# automatic_batching=True, +# ) +# input_ = dataset.input +# assert isinstance(input_, dict) +# assert isinstance(input_["data_1"], list) +# assert all([isinstance(d, Data) for d in input_["data_1"]]) +# assert isinstance(input_["data_2"], list) +# assert all([isinstance(d, Data) for d in input_["data_2"]]) diff --git a/tests/test_data/test_tensor_dataset.py b/tests/test_data/test_tensor_dataset.py index 81a122f2f..183becb35 100644 --- a/tests/test_data/test_tensor_dataset.py +++ b/tests/test_data/test_tensor_dataset.py @@ -1,86 +1,86 @@ -import torch -import pytest -from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset +# import torch +# import pytest +# from pina.data.dataset import PinaDatasetFactory, PinaTensorDataset -input_tensor = torch.rand((100, 10)) -output_tensor = torch.rand((100, 2)) +# input_tensor = torch.rand((100, 10)) +# output_tensor = torch.rand((100, 2)) -input_tensor_2 = torch.rand((50, 10)) -output_tensor_2 = torch.rand((50, 2)) +# input_tensor_2 = torch.rand((50, 10)) +# output_tensor_2 = torch.rand((50, 2)) -conditions_dict_single = { - "data": { - "input": input_tensor, - "target": output_tensor, - } -} +# conditions_dict_single = { +# "data": { +# "input": input_tensor, +# "target": output_tensor, +# } +# } -conditions_dict_single_multi = { - "data_1": { - "input": input_tensor, - "target": output_tensor, - }, - "data_2": { - "input": input_tensor_2, - "target": output_tensor_2, - }, -} +# conditions_dict_single_multi = { +# "data_1": { +# "input": input_tensor, +# "target": output_tensor, +# }, +# "data_2": { +# "input": input_tensor_2, +# "target": output_tensor_2, +# }, +# } -max_conditions_lengths_single = {"data": 100} +# max_conditions_lengths_single = {"data": 100} -max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} +# max_conditions_lengths_multi = {"data_1": 100, "data_2": 50} -@pytest.mark.parametrize( - "conditions_dict, max_conditions_lengths", - [ - (conditions_dict_single, max_conditions_lengths_single), - (conditions_dict_single_multi, max_conditions_lengths_multi), - ], -) -def test_constructor_tensor(conditions_dict, max_conditions_lengths): - dataset = PinaDatasetFactory( - conditions_dict, - max_conditions_lengths=max_conditions_lengths, - automatic_batching=True, - ) - assert isinstance(dataset, PinaTensorDataset) +# @pytest.mark.parametrize( +# "conditions_dict, max_conditions_lengths", +# [ +# (conditions_dict_single, max_conditions_lengths_single), +# (conditions_dict_single_multi, max_conditions_lengths_multi), +# ], +# ) +# def test_constructor_tensor(conditions_dict, max_conditions_lengths): +# dataset = PinaDatasetFactory( +# conditions_dict, +# max_conditions_lengths=max_conditions_lengths, +# automatic_batching=True, +# ) +# assert isinstance(dataset, PinaTensorDataset) -def test_getitem_single(): - dataset = PinaDatasetFactory( - conditions_dict_single, - max_conditions_lengths=max_conditions_lengths_single, - automatic_batching=False, - ) +# def test_getitem_single(): +# dataset = PinaDatasetFactory( +# conditions_dict_single, +# max_conditions_lengths=max_conditions_lengths_single, +# automatic_batching=False, +# ) - tensors = dataset.fetch_from_idx_list([i for i in range(70)]) - assert isinstance(tensors, dict) - assert list(tensors.keys()) == ["data"] - assert sorted(list(tensors["data"].keys())) == ["input", "target"] - assert isinstance(tensors["data"]["input"], torch.Tensor) - assert tensors["data"]["input"].shape == torch.Size((70, 10)) - assert isinstance(tensors["data"]["target"], torch.Tensor) - assert tensors["data"]["target"].shape == torch.Size((70, 2)) +# tensors = dataset.fetch_from_idx_list([i for i in range(70)]) +# assert isinstance(tensors, dict) +# assert list(tensors.keys()) == ["data"] +# assert sorted(list(tensors["data"].keys())) == ["input", "target"] +# assert isinstance(tensors["data"]["input"], torch.Tensor) +# assert tensors["data"]["input"].shape == torch.Size((70, 10)) +# assert isinstance(tensors["data"]["target"], torch.Tensor) +# assert tensors["data"]["target"].shape == torch.Size((70, 2)) -def test_getitem_multi(): - dataset = PinaDatasetFactory( - conditions_dict_single_multi, - max_conditions_lengths=max_conditions_lengths_multi, - automatic_batching=False, - ) - tensors = dataset.fetch_from_idx_list([i for i in range(70)]) - assert isinstance(tensors, dict) - assert list(tensors.keys()) == ["data_1", "data_2"] - assert sorted(list(tensors["data_1"].keys())) == ["input", "target"] - assert isinstance(tensors["data_1"]["input"], torch.Tensor) - assert tensors["data_1"]["input"].shape == torch.Size((70, 10)) - assert isinstance(tensors["data_1"]["target"], torch.Tensor) - assert tensors["data_1"]["target"].shape == torch.Size((70, 2)) +# def test_getitem_multi(): +# dataset = PinaDatasetFactory( +# conditions_dict_single_multi, +# max_conditions_lengths=max_conditions_lengths_multi, +# automatic_batching=False, +# ) +# tensors = dataset.fetch_from_idx_list([i for i in range(70)]) +# assert isinstance(tensors, dict) +# assert list(tensors.keys()) == ["data_1", "data_2"] +# assert sorted(list(tensors["data_1"].keys())) == ["input", "target"] +# assert isinstance(tensors["data_1"]["input"], torch.Tensor) +# assert tensors["data_1"]["input"].shape == torch.Size((70, 10)) +# assert isinstance(tensors["data_1"]["target"], torch.Tensor) +# assert tensors["data_1"]["target"].shape == torch.Size((70, 2)) - assert sorted(list(tensors["data_2"].keys())) == ["input", "target"] - assert isinstance(tensors["data_2"]["input"], torch.Tensor) - assert tensors["data_2"]["input"].shape == torch.Size((50, 10)) - assert isinstance(tensors["data_2"]["target"], torch.Tensor) - assert tensors["data_2"]["target"].shape == torch.Size((50, 2)) +# assert sorted(list(tensors["data_2"].keys())) == ["input", "target"] +# assert isinstance(tensors["data_2"]["input"], torch.Tensor) +# assert tensors["data_2"]["input"].shape == torch.Size((50, 10)) +# assert isinstance(tensors["data_2"]["target"], torch.Tensor) +# assert tensors["data_2"]["target"].shape == torch.Size((50, 2)) diff --git a/tests/test_solver/test_supervised_solver.py b/tests/test_solver/test_supervised_solver.py index 6f7d1ab4d..5d9709b75 100644 --- a/tests/test_solver/test_supervised_solver.py +++ b/tests/test_solver/test_supervised_solver.py @@ -117,6 +117,10 @@ def test_solver_train(use_lt, batch_size, compile): assert isinstance(solver.model, OptimizedModule) +if __name__ == "__main__": + test_solver_train(use_lt=True, batch_size=20, compile=True) + + @pytest.mark.parametrize("batch_size", [None, 1, 5, 20]) @pytest.mark.parametrize("use_lt", [True, False]) def test_solver_train_graph(batch_size, use_lt):