In [None]:
#default_exp datasets

In [None]:
#exporti
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import random_split
from tqdm import tqdm
import warnings
import random
from typing import Union

In [None]:
#hide
from nbdev.showdoc import show_doc

# Topo dataset

In [None]:
#export
class TopoDataset(Dataset):
    """
    A class for the generation of datasets. TopoDataset inherits from `torch.utils.data.Dataset`, so all functionalities from PyTorch are also available here.
    """
    def __init__(self, 
                 dataset:list=None, # A list containing either only problems or tuples `(problem, gt_solution)` of problems and corresponding ground truth solutions. By default, `dataset=None`, which creates a new, empty dataset. However, it can still be changed later via `TopoDataset.dataset=...`.
                 name:str=None, # The name of the dataset.
                 verbose:bool=True # Whether to give the user feedback on the progress.
                ):
        # Create a fresh list per instance. A shared `[]` default would be mutated in place by
        # `__add__` and leak samples into every other dataset created without an explicit list.
        self.dataset = [] if dataset is None else dataset
        self.name = name
        self.verbose = verbose


    @property
    def dataset(self):
        """The underlying list of samples."""
        return self._dataset


    @dataset.setter
    def dataset(self, dataset_):
        self._dataset = dataset_
        # Kept for backward compatibility with code that reads `_size` directly.
        self._size = len(dataset_)


    @property
    def size(self):
        """The number of samples. Computed on access, so it stays correct if `self.dataset` is mutated in place."""
        return len(self)


    def __len__(self):
        """
        Returns the size of `self.dataset`.
        """
        return len(self.dataset)


    def __getitem__(self,
                    idx:int # The index for which `(problem, gt_solution)` should be returned.
                   ):
        """
        Returns the tuple `(problem, gt_solution)` for index `idx`.
        """
        # Raising IndexError also terminates Python's legacy `__getitem__`-based iteration.
        # Negative indices fall through to the list, which handles them (or raises IndexError itself).
        if idx >= len(self):
            raise IndexError(f"Could not find dataset entry with index {idx}.")
        return self.dataset[idx]


    def _build_empty_topo_dataset_with_same_attributes(self):
        # Only `name` is carried over; the new dataset is always created with `verbose=False`.
        topo_dataset = TopoDataset(verbose=False)
        topo_dataset.name = self.name
        return topo_dataset


    def get_samples(self,
                    n:int=-1, # The number of samples that should be returned. The default choice `n=-1` returns all samples from the dataset.
                    shuffle:bool=True, # Whether to take the samples from a shuffled dataset. If `False`, then the first samples from the dataset are taken.
                    seed:int=42 # The random seed for the shuffling
                   ):
        """
        Returns a tuple `(problems, gt_solutions)` whose entries are tuples of length `n`. Returns `()` if no samples are selected.
        """
        if n == -1:
            n = len(self)
        else:
            n = min(len(self), n)

        if shuffle:
            # A private generator seeded with `seed` selects exactly the same samples as
            # `random.seed(seed); random.sample(...)`, without resetting the global `random` state.
            samples = random.Random(seed).sample(self.dataset, n)
        else:
            samples = self.dataset[:n]

        # Transpose [(p0, s0), (p1, s1), ...] into ((p0, p1, ...), (s0, s1, ...)).
        return tuple(zip(*samples))


    def get_problems(self,
                    n:int=-1, # The number of problems that should be returned. The default choice `n=-1` returns all problems from the dataset.
                    shuffle:bool=True, # Whether to take the problems from a shuffled dataset. If `False`, then the first problems from the dataset are taken.
                    seed:int=42 # The random seed for the shuffling
                   ):
        """
        Returns a tuple of length `n` which contains problems from the dataset, or `None` if no samples are selected.
        """
        samples = self.get_samples(n=n, shuffle=shuffle, seed=seed)
        if len(samples) > 0:
            return samples[0]
        return None


    def get_gt_solutions(self,
                    n:int=-1, # The number of ground truth solutions that should be returned. The default choice `n=-1` returns all solutions from the dataset.
                    shuffle:bool=True, # Whether to take the solutions from a shuffled dataset. If `False`, then the first solutions from the dataset are taken.
                    seed:int=42 # The random seed for the shuffling
                        ):
        """
        Returns a tuple of length `n` which contains ground truth solutions from the dataset, or `None` if no solutions are available.
        """
        samples = self.get_samples(n=n, shuffle=shuffle, seed=seed)
        if len(samples) > 1:
            return samples[1]
        return None


    def get_subset(self,
                   size:int, # The size of the returned topo dataset.
                   shuffle=True, # Whether to take the samples from a shuffled dataset. If `False`, then the first samples from the dataset are taken.
                   seed=42, # The random seed for the shuffling
                   invert_order=False # Whether the last samples should be taken (instead of the first samples). If `shuffle=True`, the last samples of the shuffled dataset are taken, so the same `seed` with and without `invert_order` selects disjoint subsets as long as both sizes together do not exceed the dataset size.
                  ):
        """
        Returns a new `dl4to.dataset.TopoDataset` object with a subset of `size` samples from the original dataset.
        """
        if shuffle:
            # Private generator: same permutation as seeding the global RNG, without the side effect.
            dataset = random.Random(seed).sample(self.dataset, len(self.dataset))
        else:
            dataset = self.dataset

        topo_dataset = self._build_empty_topo_dataset_with_same_attributes()
        if invert_order:
            # `dataset[-size:]` would return *all* samples for `size == 0`, so compute the start index explicitly.
            topo_dataset.dataset = dataset[max(len(dataset) - size, 0):]
        else:
            topo_dataset.dataset = dataset[:size]
        return topo_dataset


    def info(self):
        """
        Prints basic information concerning the dataset.
        """
        print(f"This TopoDataset is called {self.name} and contains {len(self)} samples.")


    def __add__(self, 
                dataset:Union["dl4to.dataset.TopoDataset",list] # The dataset that is added to this one. If `dataset` is a list, then the samples in the list are added to the current dataset.
               ):
        """
        Adding a `TopoDataset` returns a new `CombinedTopoDataset` that contains the samples from both datasets. Adding a list extends this dataset in place (the underlying list is mutated) and returns `self`.
        """
        if isinstance(dataset, TopoDataset):
            return CombinedTopoDataset(self, dataset)
        if isinstance(dataset, list):
            self.dataset += dataset
            return self
        raise AttributeError("dataset must be either a list or a TopoDataset object.")

In [None]:
show_doc(TopoDataset.__len__)

<h4 id="TopoDataset.__len__" class="doc_header"><code>TopoDataset.__len__</code><a href="__main__.py#L32" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.__len__</code>()

Returns the size of `self.dataset`.

In [None]:
show_doc(TopoDataset.__getitem__)

<h4 id="TopoDataset.__getitem__" class="doc_header"><code>TopoDataset.__getitem__</code><a href="__main__.py#L39" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.__getitem__</code>(**`idx`**:`int`)

Returns the tuple `(problem, gt_solution)` for index `idx`.

||Type|Default|Details|
|---|---|---|---|
|**`idx`**|`int`||The index for which `(problem, gt_solution)` should be returned.|


In [None]:
show_doc(TopoDataset.get_samples)

<h4 id="TopoDataset.get_samples" class="doc_header"><code>TopoDataset.get_samples</code><a href="__main__.py#L56" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.get_samples</code>(**`n`**:`int`=*`-1`*, **`shuffle`**:`bool`=*`True`*, **`seed`**:`int`=*`42`*)

Returns a tuple `(problems, gt_solutions)` whose entries are tuples of length `n`.

||Type|Default|Details|
|---|---|---|---|
|**`n`**|`int`|`-1`|The number of samples that should be returned. The default choice `n=-1` returns all samples from the dataset.|
|**`shuffle`**|`bool`|`True`|Whether to take the samples from a shuffled dataset. If `False`, then the first samples from the dataset are taken.|
|**`seed`**|`int`|`42`|The random seed for the shuffling|


In [None]:
show_doc(TopoDataset.get_problems)

<h4 id="TopoDataset.get_problems" class="doc_header"><code>TopoDataset.get_problems</code><a href="__main__.py#L78" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.get_problems</code>(**`n`**:`int`=*`-1`*, **`shuffle`**:`bool`=*`True`*, **`seed`**:`int`=*`42`*)

Returns a list of length `n` which contains problems from the dataset.

||Type|Default|Details|
|---|---|---|---|
|**`n`**|`int`|`-1`|The number of problems that should be returned. The default choice `n=-1` returns all problems from the dataset.|
|**`shuffle`**|`bool`|`True`|Whether to take the problems from a shuffled dataset. If `False`, then the first problems from the dataset are taken.|
|**`seed`**|`int`|`42`|The random seed for the shuffling|


In [None]:
show_doc(TopoDataset.get_gt_solutions)

<h4 id="TopoDataset.get_gt_solutions" class="doc_header"><code>TopoDataset.get_gt_solutions</code><a href="__main__.py#L91" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.get_gt_solutions</code>(**`n`**:`int`=*`-1`*, **`shuffle`**:`bool`=*`True`*, **`seed`**:`int`=*`42`*)

Returns a list of length `n` which contains ground truth solutions from the dataset.

||Type|Default|Details|
|---|---|---|---|
|**`n`**|`int`|`-1`|The number of ground truth solutions that should be returned. The default choice `n=-1` returns all solutions from the dataset.|
|**`shuffle`**|`bool`|`True`|Whether to take the solutions from a shuffled dataset. If `False`, then the first solutions from the dataset are taken.|
|**`seed`**|`int`|`42`|The random seed for the shuffling|


In [None]:
show_doc(TopoDataset.get_subset)

<h4 id="TopoDataset.get_subset" class="doc_header"><code>TopoDataset.get_subset</code><a href="__main__.py#L104" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.get_subset</code>(**`size`**:`int`, **`shuffle`**=*`True`*, **`seed`**=*`42`*, **`invert_order`**=*`False`*)

Returns a new `dl4to.dataset.TopoDataset` object with a subset of `size` samples from the original dataset.

||Type|Default|Details|
|---|---|---|---|
|**`size`**|`int`||The size of the returned topo dataset.|
|**`shuffle`**|`bool`|`True`|Whether to take the samples from a shuffled dataset. If `False`, then the first samples from the dataset are taken.|
|**`seed`**|`int`|`42`|The random seed for the shuffling|
|**`invert_order`**|`bool`|`False`|Whether the last samples should be taken (instead of the first samples). If `shuffle=True`, the last samples of the shuffled dataset are taken.|


In [None]:
show_doc(TopoDataset.info)

<h4 id="TopoDataset.info" class="doc_header"><code>TopoDataset.info</code><a href="__main__.py#L128" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.info</code>()

Prints basic information concerning the dataset.

In [None]:
show_doc(TopoDataset.__add__)

<h4 id="TopoDataset.__add__" class="doc_header"><code>TopoDataset.__add__</code><a href="__main__.py#L135" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoDataset.__add__</code>(**`dataset`**:`Union`\[`ForwardRef('dl4to.dataset.TopoDataset')`, `list`\])

Adding up two datasets results in a new dataset object that contains the samples from both original datasets. Adding a list instead extends the current dataset in place and returns it.

||Type|Default|Details|
|---|---|---|---|
|**`dataset`**|`typing.Union[ForwardRef('dl4to.dataset.TopoDataset'), list]`||The dataset that is added to this one. If `dataset` is a list, then the samples in the list are added to the current dataset.|


In [None]:
#export
class CombinedTopoDataset(TopoDataset):
    """
    A class that results from the summation of two topo datasets.
    """
    def __init__(self, 
                 dataset1:"dl4to.dataset.TopoDataset", # The first dataset of the summation.
                 dataset2:"dl4to.dataset.TopoDataset", # The second dataset of the summation.
                ):
        self._size = dataset1.size + dataset2.size
        self.name = f'{dataset1.name}_plus_{dataset2.name}'
        self.verbose = dataset1.verbose

        self.topo_dataset1 = dataset1
        self.topo_dataset2 = dataset2

        # Fraction of samples coming from the first dataset; 0 when both are empty (avoids ZeroDivisionError).
        total = len(self)
        self.dataset1_ratio = len(self.topo_dataset1) / total if total > 0 else 0.

        if self.verbose != dataset2.verbose:
            self.verbose = True
            warnings.warn(f"`verbose` attribute of the two datasets does not coincide. Automatically setting `verbose=True`.")
        if self.verbose:
            print(f"Created combined dataset. {len(self.topo_dataset1)} of the samples (={100*self.dataset1_ratio:.2f}%) are from the first passed dataset (total: {len(self)} samples).")


    @property
    def dataset(self):
        """
        A new list concatenating the samples of both datasets. The property has no setter,
        so assigning to it (including `+=` via `__add__` with a list) raises AttributeError.
        """
        return self.topo_dataset1.dataset + self.topo_dataset2.dataset


    def get_subset(self,
                   size:int, # The size of the returned topo dataset.
                   shuffle:bool=True, # Whether the dataset should be shuffled. If `False`, then the first samples from both datasets are taken.
                   seed:int=42, # The random seed for the shuffling.
                   balanced:bool=True # Whether the ratio between `dataset1` and `dataset2` should be maintained in the subset.
                  ):
        """
        Returns an instance of `dl4to.dataset.TopoDataset` with a subset of `size` samples from the original dataset. 
        With `balanced=True`, the samples from the first dataset come before those from the second one.
        """
        if not balanced:
            return super().get_subset(size=size, shuffle=shuffle, seed=seed)

        len1, len2 = len(self.topo_dataset1), len(self.topo_dataset2)
        total = len1 + len2
        # Never request more samples than exist. Together with the ratio of the *current* lengths,
        # this guarantees that each per-dataset share fits into its dataset.
        size = min(size, total)
        ratio = len1 / total if total > 0 else 0.
        size_dataset1 = round(size * ratio)
        size_dataset2 = size - size_dataset1
        if shuffle:
            # Each dataset is sampled with its own generator seeded with `seed`,
            # leaving the global `random` state untouched.
            dataset1 = random.Random(seed).sample(self.topo_dataset1.dataset, size_dataset1)
            dataset2 = random.Random(seed).sample(self.topo_dataset2.dataset, size_dataset2)
        else:
            dataset1 = self.topo_dataset1.dataset[:size_dataset1]
            dataset2 = self.topo_dataset2.dataset[:size_dataset2]
        topo_dataset = self._build_empty_topo_dataset_with_same_attributes()
        topo_dataset.dataset = dataset1 + dataset2
        return topo_dataset

In [None]:
show_doc(CombinedTopoDataset.get_subset)

<h4 id="CombinedTopoDataset.get_subset" class="doc_header"><code>CombinedTopoDataset.get_subset</code><a href="__main__.py#L32" class="source_link" style="float:right">[source]</a></h4>

> <code>CombinedTopoDataset.get_subset</code>(**`size`**:`int`, **`shuffle`**:`bool`=*`True`*, **`seed`**:`int`=*`42`*, **`balanced`**:`bool`=*`True`*)

Returns an instance of `dl4to.dataset.TopoDataset` with a subset of `size` samples from the original dataset. 

||Type|Default|Details|
|---|---|---|---|
|**`size`**|`int`||The size of the returned topo dataset.|
|**`shuffle`**|`bool`|`True`|Whether the dataset should be shuffled. If `False`, then the first samples from both datasets are taken.|
|**`seed`**|`int`|`42`|The random seed for the shuffling.|
|**`balanced`**|`bool`|`True`|Whether the ratio between `dataset1` and `dataset2` should be maintained in the subset.|
