# Base dataloader

> Base class for dataloaders

In [None]:
#| default_exp dataloaders.base

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import numpy as np
from abc import ABC, abstractmethod
from typing import Union

In [None]:
#| export
class BaseDataLoader(ABC):

    """
    Base class for data loaders.
    The idea of the data loader is to provide all external information to the environment
    (including lagged data, demand etc.). Internal data influenced by past decisions (like
    inventory levels) is to be added from within the environment.
    """

    def __init__(self):
        # Active split: "train", "val" or "test"; switched via train()/val()/test().
        self.dataset_type = "train"

    @abstractmethod
    def __len__(self):
        '''
        Returns the length of the dataset. For dataloaders based on distributions, this
        should raise an error stating that the length is not defined; otherwise it should
        return the number of samples in the dataset.
        '''

    @abstractmethod
    def __getitem__(self, idx):

        """
        Always returns a tuple of X and Y data. If no X data is available, use None for X.
        """

    @property
    @abstractmethod
    def X_shape(self):
        """
        Returns the shape of the X data.
        It should follow the format (n_samples, n_features). If the data has a time dimension with
        a fixed length, the shape should be (n_samples, n_time_steps, n_features). If the data is
        generated from a distribution, n_samples should be set to 1.
        """

    @property
    @abstractmethod
    def Y_shape(self):
        """
        Returns the shape of the Y data.
        It should follow the format (n_samples, n_SKUs). If the variable of interest is only a single
        SKU, the shape should be (n_samples, 1). If the data is
        generated from a distribution, n_samples should be set to 1.
        """

    @abstractmethod
    def get_all_X(self,
                dataset_type: str = 'train' # can be 'train', 'val', 'test', 'all'
                ):

        """
        Returns the entire features dataset. If no X data is available, return None.
        Return either the train, val, test, or all data.
        """

    @abstractmethod
    def get_all_Y(self,
                dataset_type: str = 'train' # can be 'train', 'val', 'test', 'all'
                ):

        """
        Returns the entire target dataset. If no Y data is available, return None.
        Return either the train, val, test, or all data.
        """

    @property
    @abstractmethod
    def len_train(self):

        """
        Returns the length of the training set. For dataloaders based on distributions, this
        should raise an error stating that the length is not defined; otherwise it should
        return the number of samples in the training set.
        """

    @property
    @abstractmethod
    def len_val(self):

        """
        Returns the length of the validation set. For dataloaders based on distributions, this
        should raise an error stating that the length is not defined; otherwise it should
        return the number of samples in the validation set.

        If no validation set is defined, raise an error.
        """

    @property
    @abstractmethod
    def len_test(self):

        """
        Returns the length of the test set. For dataloaders based on distributions, this
        should raise an error stating that the length is not defined; otherwise it should
        return the number of samples in the test set.

        If no test set is defined, raise an error.
        """

    def train(self):

        """
        Set the internal state of the dataloader to train.
        """

        self.dataset_type = "train"

    def val(self):

        """
        Set the internal state of the dataloader to validation.

        Raises a ValueError if no validation set is defined, i.e. if the attribute
        val_index_start is missing or None.
        """

        # getattr: subclasses without a validation split (e.g. distribution-based loaders)
        # may never set val_index_start; treat a missing attribute like None instead of
        # failing with an AttributeError.
        if getattr(self, "val_index_start", None) is None:
            raise ValueError('no validation set defined')
        self.dataset_type = "val"

    def test(self):

        """
        Set the internal state of the dataloader to test.

        Raises a ValueError if no test set is defined, i.e. if the attribute
        test_index_start is missing or None.
        """

        # Same reasoning as in val(): a missing attribute means "no test set".
        if getattr(self, "test_index_start", None) is None:
            raise ValueError('no test set defined')
        self.dataset_type = "test"


In [None]:
# Render the class-level API documentation under a level-2 heading.
show_doc(BaseDataLoader, title_level=2)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L14){target="_blank" style="float:right; font-size:smaller"}

## BaseDataLoader

>      BaseDataLoader ()

*Base class for data loaders.
The idea of the data loader is to provide all external information to the environment
(including lagged data, demand etc.). Internal data influenced by past decisions (like
inventory levels) is to be added from within the environment*

Train-Val-Test split:

* The dataloader contains all data, including the training, validation and test sets.

* Retrieval of the dataset types is achieved by setting the internal state to train, validation or test using appropriate functions. Then the index will automatically be adjusted to the correct dataset (see below on data retrieval).

* During training, both the agent and the experiment function may need to know the length of each dataset split. Therefore, the properties ```len_train```, ```len_val``` and ```len_test``` (decorated with ```@property```) must be defined.

Data retrieval:

* Data retrieval is done with the ```__getitem__``` function. The function takes an index and returns the data at that index, typically as an X and Y pair.

* For non-distribution-based dataloaders, the ```__init__``` function must have arguments ```val_index_start``` and ```test_index_start```, from which the attributes ```val_index_start```, ```test_index_start``` and ```train_index_end``` are set. The ```__getitem__``` function must then check the index and return the correct data based on the internal state of the dataloader.

In [None]:
show_doc(BaseDataLoader.__len__, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L27){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.__len__

>      BaseDataLoader.__len__ ()

*Returns the length of the dataset. For dataloaders based on distributions, this
should raise an error stating that the length is not defined; otherwise it should
return the number of samples in the dataset.*

In [None]:
show_doc(BaseDataLoader.__getitem__, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L36){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.__getitem__

>      BaseDataLoader.__getitem__ (idx)

*Always returns a tuple of X and Y data. If no X data is available, use None for X.*

In [None]:
show_doc(BaseDataLoader.X_shape, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L45){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.X_shape

>      BaseDataLoader.X_shape ()

*Returns the shape of the X data.
It should follow the format (n_samples, n_features). If the data has a time dimension with
a fixed length, the shape should be (n_samples, n_time_steps, n_features). If the data is 
generated from a distribution, n_samples should be set to 1.*

In [None]:
show_doc(BaseDataLoader.Y_shape, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L56){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.Y_shape

>      BaseDataLoader.Y_shape ()

*Returns the shape of the Y data.
It should follow the format (n_samples, n_SKUs). If the variable of interest is only a single
SKU, the shape should be (n_samples, 1). If the data is
generated from a distribution, n_samples should be set to 1.*

In [None]:
show_doc(BaseDataLoader.get_all_X, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L66){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.get_all_X

>      BaseDataLoader.get_all_X (dataset_type:str='train')

*Returns the entire features dataset. If no X data is available, return None.
Return either the train, val, test, or all data.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| dataset_type | str | train | can be 'train', 'val', 'test', 'all' |

In [None]:
show_doc(BaseDataLoader.get_all_Y, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L77){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.get_all_Y

>      BaseDataLoader.get_all_Y (dataset_type:str='train')

*Returns the entire target dataset. If no Y data is available, return None.
Return either the train, val, test, or all data.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| dataset_type | str | train | can be 'train', 'val', 'test', 'all' |

In [None]:
show_doc(BaseDataLoader.len_train, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L89){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.len_train

>      BaseDataLoader.len_train ()

*Returns the length of the training set. For dataloaders based on distributions, this
should raise an error stating that the length is not defined; otherwise it should
return the number of samples in the training set.*

In [None]:
show_doc(BaseDataLoader.len_val, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L101){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.len_val

>      BaseDataLoader.len_val ()

*Returns the length of the validation set. For dataloaders based on distributions, this
should raise an error stating that the length is not defined; otherwise it should
return the number of samples in the validation set.

If no validation set is defined, raise an error.*

In [None]:
show_doc(BaseDataLoader.len_test, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L116){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.len_test

>      BaseDataLoader.len_test ()

*Returns the length of the test set. For dataloaders based on distributions, this
should raise an error stating that the length is not defined; otherwise it should
return the number of samples in the test set.

If no test set is defined, raise an error.*

In [None]:
show_doc(BaseDataLoader.train, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L130){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.train

>      BaseDataLoader.train ()

*Set the internal state of the dataloader to train*

In [None]:
show_doc(BaseDataLoader.val, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L138){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.val

>      BaseDataLoader.val ()

*Set the internal state of the dataloader to validation*

In [None]:
show_doc(BaseDataLoader.test, title_level=3)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/dataloaders/base.py#L149){target="_blank" style="float:right; font-size:smaller"}

### BaseDataLoader.test

>      BaseDataLoader.test ()

*Set the internal state of the dataloader to test*

In [None]:
#| export
class DummyDataLoader(BaseDataLoader):

    """
    Dummy data loader that can be used for environments that do not require any data.

    Every data accessor (indexing, shapes, full datasets, split lengths) returns None.
    No validation or test set exists, so the inherited val() and test() raise a
    ValueError, while train() works as usual.
    """

    # No data means no validation/test split. Declaring the split markers as class-level
    # None lets BaseDataLoader.val()/test() raise their "no ... set defined" ValueError
    # instead of an AttributeError; instances or subclasses may still override them.
    val_index_start = None
    test_index_start = None

    def __init__(self):
        # The base initializer puts the loader into training mode.
        super().__init__()

    def __len__(self):
        '''
        Returns None: the dummy loader holds no samples. Consequently the built-in
        len() raises a TypeError when applied to this loader.
        '''
        return None

    def __getitem__(self, idx):

        """
        Returns None for any index; the dummy loader has no X or Y data.
        """
        return None

    @property
    def X_shape(self):
        """
        Returns None; the dummy loader has no X data.
        """
        return None

    @property
    def Y_shape(self):
        """
        Returns None; the dummy loader has no Y data.
        """
        return None

    def get_all_X(self,
                dataset_type: str = 'train' # can be 'train', 'val', 'test', 'all'
                ):

        """
        Returns None for every dataset_type; there is no features dataset.
        """
        return None

    def get_all_Y(self,
                dataset_type: str = 'train' # can be 'train', 'val', 'test', 'all'
                ):

        """
        Returns None for every dataset_type; there is no target dataset.
        """
        return None

    @property
    def len_train(self):

        """
        Returns None; there is no training set.
        """
        return None

    @property
    def len_val(self):

        """
        Returns None; there is no validation set.
        """
        return None

    @property
    def len_test(self):

        """
        Returns None; there is no test set.
        """
        return None


In [None]:
#| hide
import nbdev
nbdev.nbdev_export()