# Base Environment

> Base environment class based on Gymnasium

In [None]:
#| default_exp envs.base

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import gymnasium as gym
from abc import ABC, abstractmethod
from typing import Union
import numpy as np

from ddopnew.utils import MDPInfo
from ddopnew.utils import Parameter

In [None]:
#| export
class BaseEnvironment(gym.Env, ABC):

    """
    Base class for environments enforcing a common interface.

    Subclasses must implement `set_action_space`, `set_observation_space`,
    `get_observation` and `reset`. The constructor immediately switches the
    environment into the requested mode, which in turn calls `reset()`, so
    these methods must already work during set-up.
    """

    def __init__(self,
                    mdp_info: MDPInfo, # MDPInfo object to ensure compatibility with the agents
                    mode: str = "train", # Initial mode (train, val, test) of the environment
                    ) -> None: #

        super().__init__()

        self._mode = mode
        self.mdp_info = mdp_info

        # Switching the mode also updates the horizon in mdp_info and calls reset().
        if mode == "train":
            self.train()
        elif mode == "val":
            self.val()
        elif mode == "test":
            self.test()
        else:
            raise ValueError("mode must be 'train', 'val', or 'test'")


    def set_param(self,
                        name: str, # name of the parameter (will become the attribute name)
                        input: Parameter | float | np.ndarray, # input value of the parameter
                        shape: tuple = (1,), # shape of the parameter
                        new: bool = False # whether to create a new parameter or update an existing one
                        ) -> None: #

        """
        Set a parameter for the environment. Scalars, lists and single-element arrays are
        converted to Numpy arrays of the given shape, so that environment parameters are
        always either of the Parameter class or Numpy arrays.

        If new is True, the value is stored as attribute `name` (creating it or replacing
        an existing one). If new is False, the attribute must already exist: it is updated
        through its own `set` method if it has one, and replaced otherwise.

        Raises a ValueError if the shape of the input does not match `shape`, a TypeError
        if the input type is not supported, and an AttributeError if new is False and the
        parameter does not exist.
        """

        # check if input is a valid type
        if isinstance(input, Parameter):
            if input.shape != shape:
                raise ValueError("Parameter shape must be equal to the shape specified for this environment parameter")
            param = input

        # Numpy scalars (e.g. np.int64) are accepted alongside Python numbers
        elif isinstance(input, (int, float, np.integer, np.floating)):
            param = np.full(shape, input)

        elif isinstance(input, (list, np.ndarray)):
            # np.asarray converts lists and returns existing arrays unchanged (no copy)
            value = np.asarray(input)
            if value.shape == shape:
                param = value
            elif value.size == 1:  # Handle single-element arrays correctly
                param = np.full(shape, value.item())
            else:
                raise ValueError("Input array must match the specified shape or be a single-element array")
        else:
            raise TypeError("Input must be a Parameter, scalar, or numpy array")

        # set the parameter
        if new:
            setattr(self, name, param)
            return

        # check if parameter already exists
        if not hasattr(self, name):
            raise AttributeError(f"Parameter {name} does not exist in this environment")

        existing = getattr(self, name)
        if callable(getattr(existing, "set", None)):
            existing.set(param)
        else:
            # Numpy arrays have no set() method, so the stored value is replaced instead
            setattr(self, name, param)

    @property
    def info(self):
        """
        Returns: The MDPInfo object of the environment.

        """
        # The object is stored as self.mdp_info in __init__ (there is no self._mdp_info)
        return self.mdp_info

    @property
    def mode(self):
        """
        Returns: A string with the current mode (train, val, test) of the environment.

        """
        return self._mode

    @abstractmethod
    def set_action_space(self):
        """
        Set the action space of the environment.

        """
        pass

    @abstractmethod
    def set_observation_space(self):
        """
        Set the observation space of the environment.
        In general, this can be also a dict space, but the agent must have the appropriate pre-processor.

        """
        pass

    @abstractmethod
    def get_observation(self):
        """
        Return the current observation. Typically constructed from the output of the dataloader and 
        internal dynamics (such as inventory levels, pipeline vectors, etc.) of the environment.

        """
        pass

    @abstractmethod
    def reset(self):
        """
        Reset the environment. This function must be provided, using the function self.reset_index() to 
        handle indexing. It needs to account for the current mode (train, val, or test) and handle
        the horizon_train param. See the reset function for the NewsvendorEnv for an example.

        Note that reset() is called by train(), val() and test(), and therefore also during __init__.

        """
        pass

    def set_index(self, index=None):
        """
        Handle the index of the environment. If index is given, the index is set to this value,
        otherwise the current index is incremented by one.

        Returns: True if the index has reached the horizon stored in mdp_info, False otherwise.

        """

        if index is not None:
            self.index = index
        else:
            self.index += 1

        return bool(self.index >= self.mdp_info.horizon)


    def reset_index(self,
        start_index: Union[int,str]):

        """
        Reset the index of the environment. If start_index is an integer, the index is set to this value. If start_index is "random",
        the index is set to a random integer between 0 (inclusive) and the length of the training data (exclusive), which is
        only allowed in train mode.

        Returns: The truncated flag from set_index().
        Raises a ValueError if start_index is "random" outside of train mode or is neither an integer nor "random".

        """

        if start_index=="random":
            if self.mode != "train":
                raise ValueError("start_index cannot be 'random' in val or test mode")
            # assuming we only start randomly during training.
            return self.set_index(np.random.randint(0, self.dataloader.len_train))

        # Numpy integers are accepted as well as Python ints
        if isinstance(start_index, (int, np.integer)):
            return self.set_index(start_index)

        raise ValueError("start_index must be an integer or 'random'")

    def update_mdp_info(self, gamma=None, horizon=None):

        """
        Update the MDP info of the environment. Only values that are not None are written
        to self.mdp_info.

        """
        if gamma is not None:
            self.mdp_info.gamma = gamma
        if horizon is not None:
            self.mdp_info.horizon = horizon

    def train(self, update_mdp_info=True):
        """
        Set the environment in training mode by both setting the internal state self._mode and the dataloader. 
        If horizon_train is set to "use_all_data", the horizon is set to the length of the training data, otherwise
        it is set to the horizon_train attribute of the environment. If the environment has no dataloader or
        no horizon_train attribute, the current horizon is kept. Finally, the function updates the MDP info
        (if update_mdp_info is True) and resets with the new state.

        """
        self._mode = "train"

        # Default: keep the current horizon. This also covers environments that have a
        # dataloader but no horizon_train attribute.
        horizon = self.mdp_info.horizon

        if hasattr(self, "dataloader"):
            self.dataloader.train()

            if hasattr(self, "horizon_train"):
                if self.horizon_train == "use_all_data":
                    horizon = self.dataloader.len_train
                else:
                    horizon = self.horizon_train

        if update_mdp_info:
            self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=horizon)

        self.reset()

    def val(self, update_mdp_info=True):
        """
        Set the environment in validation mode by both setting the internal state self._mode and the dataloader.
        With a dataloader, the horizon of val is always set to the length of the validation data, otherwise the
        current horizon is kept. Finally, the function updates the MDP info (if update_mdp_info is True)
        and resets with the new state.

        """
        self._mode = "val"

        if hasattr(self, "dataloader"):
            self.dataloader.val()
            horizon = self.dataloader.len_val
        else:
            horizon = self.mdp_info.horizon

        if update_mdp_info:
            self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=horizon)

        self.reset()

    def test(self, update_mdp_info=True):
        """
        Set the environment in testing mode by both setting the internal state self._mode and the dataloader.
        With a dataloader, the horizon of test is always set to the length of the test data, otherwise the
        current horizon is kept. Finally, the function updates the MDP info (if update_mdp_info is True)
        and resets with the new state.

        """
        self._mode = "test"

        if hasattr(self, "dataloader"):
            self.dataloader.test()
            horizon = self.dataloader.len_test
        else:
            horizon = self.mdp_info.horizon

        if update_mdp_info:
            self.update_mdp_info(gamma=self.mdp_info.gamma, horizon=horizon)

        self.reset()

In [None]:
show_doc(BaseEnvironment, title_level=2)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envsdfd.py#L13){target="_blank" style="float:right; font-size:smaller"}

## BaseEnvironment

>      BaseEnvironment (mdp_info:ddopnew.utils.MDPInfo, mode:str='train')

*Base class for environments enforcing a common interface.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| mdp_info | MDPInfo |  | MDPInfo object to ensure compatibility with the agents |
| mode | str | train | Initial mode (train, val, test) of the environment |
| **Returns** | **None** |  |  |

### Important notes:


**init method**:

* When adding parameters to the environment, make sure to always add them via ```self.set_param("[parameter name]", ..., new=True)```, which stores the value as the attribute ```self.[parameter name]```. This ensures all parameters are of the correct types and shapes.

* During the init method, any Gymnasium environment expects the action and observation space to be defined. For clarity, avoid doing it directly in the init, but rather use the functions ```set_action_space()``` and ```set_observation_space()``` and call them in the ```__init__``` method.


**train, val, test, and horizon (episode length)**:

* When the ```__init__``` method is called, the environment executes the ```train()```, ```val()``` or ```test()``` methods. Therefore, they must be implemented in a way that they work right during set-up. 

* ```train()```, ```val()``` and ```test()``` methods are provided in the base class, but can also be overridden if necessary. In any case, they must set the dataloader to the correct dataset to ensure no
data leakage. They also need to update mdp_info to update the horizon (episode length) of the environment.

* The horizon for validation and testing will be equal to the length of those datasets. For training, there is a parameter ```horizon_train``` that either contains a string "use_all_data" or an integer. If it is the former, the horizon will be the length of the training dataset. If it is the latter, the environment will play an episode of length ```horizon_train``` starting at a random point of the training dataset. 

**step method**:

* The step method is the core of the environment, calculating the next state (observation) and reward given an action. For clarity, the construction of the next state (we call it more general observation to include POMDPs) is done in a separate method called ```get_observation()``` that must be called inside the step function. See documentation below and the Newsvendor environment ```envs.inventory.NewsvendorEnv``` for an example.

* The dataloader will typically return an X,Y pair (where X are some features and Y typically is demand). The X is necessary at the end of the step to construct the next observation to be returned to the agent. The Y is only relevant one step later to calculate the reward. Hence, Y is typically transferred to the next step method via an object variable like self.demand (see ```envs.inventory.NewsvendorEnv``` as an example).


**reset method**:

* The reset method may depend strongly on the environment dynamics, so it must be implemented for the specific environment. It needs to fulfill two requirements: 1) it needs to differentiate between train, val, and test mode and 2) when setting the training mode, it needs to
take the ```horizon_train``` parameter into account. 

* At the end of the function, first the ```reset_index()``` method should be called (either with a specific index as an integer or the flag ```"random"``` as input) and then the ```get_observation()``` method to construct the first observation.

In [None]:
show_doc(BaseEnvironment.set_param)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L199){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_param

>      BaseEnvironment.set_param (name:str,
>                                 input:ddopnew.utils.Parameter|float|numpy.ndar
>                                 ray, shape:tuple=(1,), new:bool=False)

*Set a parameter for the environment. It converts scalar values to numpy arrays and ensures that
environment parameters are either of the Parameter class or Numpy arrays. If new is set to True, 
the function will create a new parameter (or replace an existing one). If new is set to
False, the function will update an existing parameter and raise an error if the parameter does not exist.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| name | str |  | name of the parameter (will become the attribute name) |
| input | ddopnew.utils.Parameter \| float \| numpy.ndarray |  | input value of the parameter |
| shape | tuple | (1,) | shape of the parameter |
| new | bool | False | whether to create a new parameter or update an existing one |
| **Returns** | **None** |  |  |

In [None]:
show_doc(BaseEnvironment.info)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envsdfd.py#L29){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.info

>      BaseEnvironment.info ()

*Returns: The MDPInfo object of the environment.*

In [None]:
show_doc(BaseEnvironment.mode)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L59){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.mode

>      BaseEnvironment.mode ()

*Returns: A string with the current mode (train, val, test) of the environment.*

In [None]:
show_doc(BaseEnvironment.set_action_space)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envsdfd.py#L38){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_action_space

>      BaseEnvironment.set_action_space ()

*Set the action space of the environment.*

In [None]:
show_doc(BaseEnvironment.set_observation_space)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envsdfd.py#L46){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_observation_space

>      BaseEnvironment.set_observation_space ()

*Set the observation space of the environment.
In general, this can be also a dict space, but the agent must have the appropriate pre-processor.*

In [None]:
show_doc(BaseEnvironment.get_observation)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envsdfd.py#L54){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.get_observation

>      BaseEnvironment.get_observation ()

*Return the current observation. Typically constructed from the output of the dataloader and 
internal dynamics (such as inventory levels, pipeline vectors, etc.) of the environment.*

In [None]:
show_doc(BaseEnvironment.reset)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L92){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.reset

>      BaseEnvironment.reset ()

*Reset the environment. This function must be provided, using the function self.reset_index() to 
handle indexing. It needs to account for the current training mode train, val, or test and handle
the horizon_train param. See the reset function for the NewsvendorEnv for an example.*

In [None]:
show_doc(BaseEnvironment.set_index)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L100){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.set_index

>      BaseEnvironment.set_index (index=None)

*Handle the index of the environment.*

In [None]:
show_doc(BaseEnvironment.reset_index)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L173){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.reset_index

>      BaseEnvironment.reset_index (start_index:Union[int,str])

*Reset the index of the environment. If start_index is an integer, the index is set to this value. If start_index is "random",
the index is set to a random integer between 0 and the length of the training data.*

In [None]:
show_doc(BaseEnvironment.update_mdp_info)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L185){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.update_mdp_info

>      BaseEnvironment.update_mdp_info (gamma=None, horizon=None)

*Update the MDP info of the environment.*

In [None]:
show_doc(BaseEnvironment.train)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L114){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.train

>      BaseEnvironment.train (update_mdp_info=True)

*Set the environment in training mode by both setting the internal state self._mode and the dataloader. 
If the horizon is set to "use_all_data", the horizon is set to the length of the training data, otherwise
it is set to the horizon_train attribute of the environment. Finally, the function updates the MDP info
and resets with the new state.*

In [None]:
show_doc(BaseEnvironment.val)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L137){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.val

>      BaseEnvironment.val (update_mdp_info=True)

*Set the environment in validation mode by both setting the internal state self._mode and the dataloader.
The horizon of val is always set to the length of the validation data. Finally, the function updates the MDP info
and resets with the new state.*

In [None]:
show_doc(BaseEnvironment.test)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/envs/base.py#L155){target="_blank" style="float:right; font-size:smaller"}

### BaseEnvironment.test

>      BaseEnvironment.test (update_mdp_info=True)

*Set the environment in testing mode by both setting the internal state self._mode and the dataloader.
The horizon of test is always set to the length of the test data. Finally, the function updates the MDP info
and resets with the new state.*

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()