# Base Environment

> Base environment class based on Gymnasium

In [None]:
#| default_exp envs.base

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import gymnasium as gym
from abc import ABC, abstractmethod
from typing import Union, List
import numpy as np

from ddopai.utils import MDPInfo, Parameter, set_param
import time

In [None]:
#| export
class BaseEnvironment(gym.Env, ABC):

    """
    Base class for environments enforcing a common interface.

    Concrete environments must implement ``set_action_space``, ``set_observation_space``,
    ``get_observation``, ``reset`` and ``step_``. The public ``step`` method is provided here:
    it applies the action postprocessors, delegates to ``step_`` and formats the return tuple
    according to ``return_truncation``.
    """

    def __init__(self,
                    mdp_info: MDPInfo, # MDPInfo object to ensure compatibility with the agents
                    postprocessors: list[object] | None = None,  # default is empty list
                    mode: str = "train", # Initial mode (train, val, test) of the environment
                    return_truncation: bool = True, # whether to return a truncated condition in step function
                    horizon_train: int | str = "use_all_data" # horizon of the training data
                    ) -> None: #

        super().__init__()

        self.horizon_train = horizon_train

        self.return_truncation = return_truncation

        self._mode = mode
        self._mdp_info = mdp_info

        # Must exist before the mode is activated below: train()/val()/test() call reset(),
        # and a reset() that steps the environment iterates over the postprocessors.
        self.postprocessors = postprocessors or []

        if mode == "train":
            self.train()
        elif mode == "val":
            self.val()
        elif mode == "test":
            self.test()
        else:
            raise ValueError("mode must be 'train', 'val', or 'test'")

    def set_param(self,
                        name: str, # name of the parameter (will become the attribute name)
                        input: Parameter | int | float | np.ndarray | List | None, # input value of the parameter
                        shape: tuple = (1,), # shape of the parameter
                        new: bool = False # whether to create a new parameter or update an existing one
                        ) -> None: #

        """
        Set a parameter for the environment. It converts scalar values to numpy arrays and ensures that
        environment parameters are either of the Parameter class or Numpy arrays. If new is set to True,
        the function will create a new parameter; otherwise it will update an existing one. If new is set to
        False, the function will raise an error if the parameter does not exist.
        """

        # Delegates to the module-level helper imported from ddopai.utils.
        set_param(self, name, input, shape, new)

    def return_truncation_handler(self, observation, reward, terminated, truncated, info):
        """
        Handle the return_truncation attribute of the environment. This function is called by the step function.
        It returns the 5-tuple (observation, reward, terminated, truncated, info) if return_truncation is set,
        and the 4-tuple without the truncated flag otherwise.

        """

        if self.return_truncation:
            return observation, reward, terminated, truncated, info

        return observation, reward, terminated, info

    def step(self, action):

        """
        Step function of the environment. Do not overwrite this function.
        Instead, write the step_ function. Note that the postprocessors are applied here,
        in the order in which they were added.

        """

        ## apply postprocessors
        for postprocessor in self.postprocessors:
            action = postprocessor(action)

        observation, reward, terminated, truncated, info = self.step_(action)

        return self.return_truncation_handler(observation, reward, terminated, truncated, info)

    def add_postprocessor(self, postprocessor: object): # post-processor object that can be called via the "__call__" method
        """Add a postprocessor to the environment"""
        self.postprocessors.append(postprocessor)

    @abstractmethod
    def step_(self, action):
        """
        Environment-specific step logic, called by the step function after the postprocessors have been
        applied to the action. It must return the following: observation, reward, terminated, truncated, info

        """
        pass


    @property
    def mdp_info(self):
        """
        Returns: The MDPInfo object of the environment.

        """
        return self._mdp_info

    @property
    def info(self):
        """
        Returns: The MDPInfo object of the environment (alternative name of mdp_info for mushroom_rl).

        """
        return self.mdp_info

    @property
    def mode(self):
        """
        Returns: A string with the current mode (train, val, test) of the environment.

        """
        return self._mode

    @abstractmethod
    def set_action_space(self):
        """
        Set the action space of the environment.

        """
        pass

    @abstractmethod
    def set_observation_space(self):
        """
        Set the observation space of the environment.
        In general, this can be also a dict space, but the agent must have the appropriate pre-processor.

        """
        pass

    @abstractmethod
    def get_observation(self):
        """
        Return the current observation. Typically constructed from the output of the dataloader and
        internal dynamics (such as inventory levels, pipeline vectors, etc.) of the environment.

        """
        pass

    @abstractmethod
    def reset(self):
        """
        Reset the environment. This function must be provided, using the function self.reset_index() to
        handle indexing. It needs to account for the current training mode train, val, or test and handle
        the horizon_train param. See the reset function for the NewsvendorEnv for an example.

        """
        pass

    def set_index(self, index=None):

        """
        Handle the index of the environment. If an index is given, the environment index is set to it,
        otherwise the current index is advanced by one. Returns True if the index has reached
        max_index_episode (i.e., the episode is truncated) and False otherwise.

        """

        if index is not None:
            self.index = index
        else:
            self.index += 1

        # bool() keeps the return value a plain Python bool even if the comparison yields a numpy bool
        truncated = bool(self.index >= self.max_index_episode)

        return truncated

    def get_start_index(self,
        start_index: int | str | None = None, # index to start from
        ) -> int | str:

        """ Determine if the start index is random or 0,
        depending on the state of the environment and training
        process (over entire train set or in shorter episodes).
        An explicitly given start_index is returned unchanged. """

        if start_index is not None:
            return start_index

        if self._mode == "train":
            if self.horizon_train == "use_all_data":
                return 0
            # dataloaders flagged with is_distribution always start at 0
            if getattr(self.dataloader, "is_distribution", False):
                return 0
            return "random"

        if self._mode in ("val", "test"):
            return 0

        raise ValueError("Mode not recognized.")

    def reset_index(self,
        start_index: Union[int, str, None] = None):

        """

        Reset the index of the environment. If start_index is an integer, the index is set to this value. If start_index is "random",
        the index is set to a random integer between 0 and the length of the training data minus the horizon (0 if the
        training data is not longer than the horizon). If start_index is None, it is determined by get_start_index().
        Returns the truncated flag of set_index() for the new start index.

        """

        start_index = self.get_start_index(start_index)

        if start_index == "random":
            if self.mode != "train":
                raise ValueError("start_index cannot be 'random' in val or test mode")
            len_train = self.dataloader.len_train
            horizon = self.mdp_info.horizon
            if len_train is not None and len_train > horizon:
                # leave room for a full episode of length horizon after the start
                self.start_index = np.random.choice(len_train - horizon)
            else:
                self.start_index = 0
        elif isinstance(start_index, (int, np.integer)):
            # numpy integers (e.g., obtained by indexing an array) are valid indices as well
            self.start_index = start_index
        else:
            raise ValueError("start_index must be an integer or 'random'")

        if self.mode == "train":
            data_length = self.dataloader.len_train
        elif self.mode == "val":
            data_length = self.dataloader.len_val
        else:
            data_length = self.dataloader.len_test

        self.max_index = data_length - 1
        self.max_index_episode = np.minimum(self.max_index, self.start_index+self.mdp_info.horizon)
        if self.mode == "test" or self.mode == "val":
            self.max_index_episode += 1

        truncated = self.set_index(self.start_index) # assuming we only start randomly during training.

        return truncated

    def update_mdp_info(self, gamma=None, horizon=None):

        """
        Update the MDP info of the environment. Only values that are not None are written.

        """
        if gamma is not None:
            self._mdp_info.gamma = gamma
        if horizon is not None:
            self._mdp_info.horizon = horizon

    def train(self, update_mdp_info=True):
        """
        Set the environment in training mode by both setting the internal state self._mode and the dataloader.
        If horizon_train is set to "use_all_data", the horizon is set to the length of the training data, otherwise
        it is set to the horizon_train attribute of the environment. Without a dataloader (or without a
        horizon_train attribute) the current horizon is kept. Finally, the function updates the MDP info
        and resets with the new state.

        """
        self._mode = "train"

        if hasattr(self, "dataloader"):
            self.dataloader.train()

            if not hasattr(self, "horizon_train"):
                # keep the current horizon instead of leaving the variable unbound
                horizon = self.mdp_info.horizon
            elif self.horizon_train == "use_all_data":
                horizon = self.dataloader.len_train
            else:
                horizon = self.horizon_train
        else:
            horizon = self.mdp_info.horizon

        if update_mdp_info:
            self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=horizon)

        self.reset()

    def val(self, update_mdp_info=True):
        """
        Set the environment in validation mode by both setting the internal state self._mode and the dataloader.
        The horizon of val is always set to the length of the validation data (the current horizon is kept if
        there is no dataloader). Finally, the function updates the MDP info and resets with the new state.

        """
        self._mode = "val"

        if hasattr(self, "dataloader"):
            self.dataloader.val()
            horizon = self.dataloader.len_val
        else:
            horizon = self.mdp_info.horizon

        if update_mdp_info:
            self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=horizon)

        self.reset()

    def test(self, update_mdp_info=True):
        """
        Set the environment in testing mode by both setting the internal state self._mode and the dataloader.
        The horizon of test is always set to the length of the test data (the current horizon is kept if
        there is no dataloader). Finally, the function updates the MDP info and resets with the new state.

        """
        self._mode = "test"

        if hasattr(self, "dataloader"):
            self.dataloader.test()
            horizon = self.dataloader.len_test
        else:
            horizon = self.mdp_info.horizon

        if update_mdp_info:
            self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=horizon)

        self.reset()

    def set_return_truncation(self, return_truncation: bool): # whether or not to return the truncated condition in the step function

        """
        Set the return_truncation attribute of the environment.
        """

        self.return_truncation = return_truncation

    def stop(self):
        """
        Stop the environment. This function is used to ensure compatibility with the Core of mushroom_rl.

        """
        pass


In [None]:
show_doc(BaseEnvironment, title_level=2)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L16){target="_blank" style="float:right; font-size:smaller"}

## BaseEnvironment

>      BaseEnvironment (mdp_info:ddopai.utils.MDPInfo,
>                       postprocessors:list[object]|None=None, mode:str='train',
>                       return_truncation:str=True,
>                       horizon_train:int|str='use_all_data')

*Base class for environments enforcing a common interface.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| mdp_info | MDPInfo |  | MDPInfo object to ensure compatibility with the agents |
| postprocessors | list[object] \| None | None | default is empty list |
| mode | str | train | Initial mode (train, val, test) of the environment |
| return_truncation | str | True | whether to return a truncated condition in step function |
| horizon_train | int \| str | use_all_data | horizon of the training data |
| **Returns** | **None** |  |  |

### Important notes:


**init method**:

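* When adding parameters to the environment, make sure to always add them via ```self.set_param(name, input, shape, new)``` (the parameter becomes the attribute ```self.[name]```) rather than assigning attributes directly. This ensures all parameters are of the correct types and shapes.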

* During the init method, any Gymnasium environment expects the action and observation space to be defined. For clarity, avoid doing it directly in the init, but rather use the functions ```set_action_space()``` and ```set_observation_space()``` and call them in the ```__init__``` method.


**train, val, test, and horizon (episode length)**:

* When the ```__init__``` method is called, the environment executes the ```train()```, ```val()``` or ```test()``` methods. Therefore, they must be implemented in a way that they work right during set-up. 

* ```train()```, ```val()``` and ```test()``` methods are provided in the base class, but can also be overridden if necessary. In any case, they must set the dataloader to the correct dataset to ensure no
data leakage. They also need to update mdp_info to update the horizon (episode length) of the environment.

* The horizon for validation and testing will be equal to the length of those datasets. For training, there is a parameter ```horizon_train``` that either contains a string "use_all_data" or an integer. If it is the former, the horizon will be the length of the training dataset. If it is the latter, the environment will play an episode of length ```horizon_train``` starting at a random point of the training dataset. 

**step method**:

* The step method is the core of the environment, calculating the next state (observation) and reward given an action. Since some frameworks expect a truncation condition (standard implementation in Gymnasium now) while others (e.g., mushroom_rl) do not, the step function is implemented in the base class and handles this (via a flag in the environment called ```return_truncation```). **DO NOT OVERWRITE** the step function, but rather implement the ```step_(self, action)``` (underscore) method in the specific environment. This function shall always return a tuple of the form (observation, reward, terminated, truncated, info).

* For clarity, the construction of the next state (which we more generally call the observation, to include POMDPs) is done in a separate method called ```get_observation()``` that must be called inside the step function. See documentation below and the Newsvendor environment ```envs.inventory.NewsvendorEnv``` for an example.

* The dataloader will typically return an X,Y pair (where X are some features and Y typically is demand). The X is necessary at the end of the step to construct the next observation to be returned to the agent. The Y is only relevant one step later to calculate the reward. Hence, Y is typically transferred to the next step method via an object variable like self.demand (see ```envs.inventory.NewsvendorEnv``` as an example).

**observation pre-processors and action post-processors**:

* Sometimes, it is necessary to process the observation before giving it to the agent (e.g., changing shape) or to process the action before giving it to the environment (e.g., rounding). To ensure compatibility with mushroom_rl, the pre-processors sit with the agent (they must be added to the agent and are applied in the agent's ```draw_action()``` method). The post-processors sit with the environment and are applied in the environment's ```step()``` method.

**reset method**:

* The reset method may depend strongly on the environment dynamics, so it must be implemented for the specific environment. It needs to fulfill two requirements: 1) it needs to differentiate between train, val, and test mode and 2) when setting the training mode, it needs to
take the ```horizon_train``` parameter into account. 

* At the end of the function, first the ```reset_index()``` method should be called (either with a specific integer index or the flag ```"random"``` as input) and then the ```get_observation()``` method to construct the first observation.

In [None]:
show_doc(BaseEnvironment.set_param)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L51){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_param

>      BaseEnvironment.set_param (name:str, input:Union[ddopai.utils.Parameter,
>                                 int,float,numpy.ndarray,List,NoneType],
>                                 shape:tuple=(1,), new:bool=False)

*Set a parameter for the environment. It converts scalar values to numpy arrays and ensures that
environment parameters are either of the Parameter class of Numpy arrays. If new is set to True, 
the function will create a new parameter or update an existing one otherwise. If new is set to
False, the function will raise an error if the parameter does not exist.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| name | str |  | name of the parameter (will become the attribute name) |
| input | Union |  | input value of the parameter |
| shape | tuple | (1,) | shape of the parameter |
| new | bool | False | whether to create a new parameter or update an existing one |
| **Returns** | **None** |  |  |

In [None]:
show_doc(BaseEnvironment.return_truncation_handler)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L67){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.return_truncation_handler

>      BaseEnvironment.return_truncation_handler (observation, reward,
>                                                 terminated, truncated, info)

*Handle the return_truncation attribute of the environment. This function is called by the step function*

In [None]:
show_doc(BaseEnvironment.step)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L78){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.step

>      BaseEnvironment.step (action)

*Step function of the environment. Do not overwrite this function. 
Instead, write the step_ function. Note that the postprocessor is applied here.*

In [None]:
show_doc(BaseEnvironment.add_postprocessor)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L94){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.add_postprocessor

>      BaseEnvironment.add_postprocessor (postprocessor:object)

*Add a postprocessor to the agent*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| postprocessor | object | post-processor object that can be called via the "__call__" method |

In [None]:
show_doc(BaseEnvironment.step_)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L99){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.step_

>      BaseEnvironment.step_ (action)

*Step function of the environment. It is a wrapper around the step function that handles the return_truncation
attribute of the environment. It must return the following: observation, reward, terminated, truncated, info*

In [None]:
show_doc(BaseEnvironment.mdp_info)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L109){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.mdp_info

>      BaseEnvironment.mdp_info ()

*Returns: The MDPInfo object of the environment.*

In [None]:
show_doc(BaseEnvironment.mode)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L125){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.mode

>      BaseEnvironment.mode ()

*Returns: A string with the current mode (train, test val) of the environment.*

In [None]:
show_doc(BaseEnvironment.set_action_space)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L133){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_action_space

>      BaseEnvironment.set_action_space ()

*Set the action space of the environment.*

In [None]:
show_doc(BaseEnvironment.set_observation_space)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L141){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_observation_space

>      BaseEnvironment.set_observation_space ()

*Set the observation space of the environment.
In general, this can be also a dict space, but the agent must have the appropriate pre-processor.*

In [None]:
show_doc(BaseEnvironment.get_observation)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L150){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.get_observation

>      BaseEnvironment.get_observation ()

*Return the current observation. Typically constructed from the output of the dataloader and 
internal dynamics (such as inventory levels, pipeline vectors, etc.) of the environment.*

In [None]:
show_doc(BaseEnvironment.reset)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L159){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.reset

>      BaseEnvironment.reset ()

*Reset the environment. This function must be provided, using the function self.reset_index() to 
handle indexing. It needs to account for the current training mode train, val, or test and handle
the horizon_train param. See the reset function for the NewsvendorEnv for an example.*

In [None]:
show_doc(BaseEnvironment.set_index)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L168){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_index

>      BaseEnvironment.set_index (index=None)

*Handle the index of the environment.*

In [None]:
show_doc(BaseEnvironment.reset_index)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L212){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.reset_index

>      BaseEnvironment.reset_index (start_index:Union[int,str])

*Reset the index of the environment. If start_index is an integer, the index is set to this value. If start_index is "random",
the index is set to a random integer between 0 and the length of the training data.*

In [None]:
show_doc(BaseEnvironment.update_mdp_info)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L248){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.update_mdp_info

>      BaseEnvironment.update_mdp_info (gamma=None, horizon=None)

*Update the MDP info of the environment.*

In [None]:
show_doc(BaseEnvironment.train)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L259){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.train

>      BaseEnvironment.train (update_mdp_info=True)

*Set the environment in training mode by both setting the internal state self._train and the dataloader. 
If the horizon is set to "use_all_data", the horizon is set to the length of the training data, otherwise
it is set to the horizon_train attribute of the environment. Finally, the function updates the MDP info
and resets with the new state.*

In [None]:
show_doc(BaseEnvironment.val)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L285){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.val

>      BaseEnvironment.val (update_mdp_info=True)

*Set the environment in validation mode by both setting the internal state self._val and the dataloader.
The horizon of val is always set to the length of the validation data. Finally, the function updates the MDP info
and resets with the new state.*

In [None]:
show_doc(BaseEnvironment.test)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L305){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.test

>      BaseEnvironment.test (update_mdp_info=True)

*Set the environment in testing mode by both setting the internal state self._test and the dataloader.
The horizon of test is always set to the length of the test data. Finally, the function updates the MDP info
and resets with the new state.*

In [None]:
show_doc(BaseEnvironment.set_return_truncation)

---

[source](https://github.com/opimwue/ddopai/blob/main/ddopai/envs/base.py#L325){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_return_truncation

>      BaseEnvironment.set_return_truncation (return_truncation:bool)

*Set the return_truncation attribute of the environment.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| return_truncation | bool | whether or not to return the truncated condition in the step function |

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()