# Base agents

> To be written.

In [None]:
#| default_exp agents.base

In [None]:
import logging
# Verbosity used while working in this notebook (DEBUG = most detailed).
# NOTE(review): not referenced anywhere else in this section — confirm it is still needed.
logging_level = logging.DEBUG

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export


from abc import ABC, abstractmethod
from typing import Union, Optional, List
import numpy as np

from ddopnew.envs.base import BaseEnvironment
from ddopnew.utils import MDPInfo

# # TEMPORARY
# from sklearn.utils.validation import check_array
# import numbers

In [None]:
#| export

class BaseAgent():

    """  
    Base class for all agents to enforce a common interface. See below for more detailed description of the requriements.

    """

    train_mode = "direct_fit" # or "epochs_fit" or "env_interaction"
    
    def __init__(self,
                    environment_info: MDPInfo,
                    preprocessors: list[object] | None = None,  # default is empty list
                    postprocessors: list[object] | None = None  # default is empty list
                 ):

        self.preprocessors = preprocessors or []
        self.postprocessors = postprocessors or []

        self.environment_info = environment_info
        self.mode = "train"
        self.print = False  # Can be used for debugging
        self.receive_batch_dim = False

    def draw_action(self, observation: np.ndarray) -> np.ndarray: #

        """
        Main interfrace to the environemnt. Applies preprocessors to the observation and postprocessors to the action.
        Internal logic of the agent to be implemented in draw_action_ method.
        """

        observation = self.add_batch_dim(observation)

        for preprocessor in self.preprocessors:
            observation = preprocessor(observation)

        action = self.draw_action_(observation)
        
        for postprocessor in self.postprocessors:
            action = postprocessor(action)
        return action

    @abstractmethod
    def draw_action_(self, observation: np.ndarray) -> np.ndarray: #
        """Generate an action based on the observation - this is the core method that needs to be implemented by all agents."""
        pass

    def add_preprocessor(self, preprocessor: object): # pre-processor object that can be called via the "__call__" method
        """add a preprocessor to the agent"""
        self.preprocessors.append(preprocessor)
    
    def add_postprocessor(self, postprocessor: object): # post-processor object that can be called via the "__call__" method
        """add a postprocessor to the agent"""
        self.postprocessors.append(postprocessor)

    def train(self):
        """set the internal state of the agent to train"""
        self.mode = "train"
        
    def eval(self):
        """
        Set the internal state of the agent to eval. Note that for agents we do not differentiate between val and test modes.

        """
        self.mode = "eval"
    
    def add_batch_dim(self, input: np.ndarray) -> np.ndarray: #
        
        """
        Add a batch dimension to the input array if it doesn't already have one.
        This is necessary because most environments do not have a batch dimension, but agents typically expect one.
        If the environment does have a batch dimension, the agent can set the receive_batch_dim attribute to True to skip this step.

        """

        if self.receive_batch_dim:
            # If the batch dimension is expected, return the input as is
            return input
        else:
            # Add a batch dimension by expanding the dimensions of the input
            return np.expand_dims(input, axis=0)
        
    def flatten_X(self, X: np.ndarray) -> np.ndarray: #

        """
        Function to flatten the time-dimension of the feature matrix
        for agents that require a 2D input. Note applied by default but can be 
        used by agents inheriting from this class.
        
        """

        if X.ndim == 3:
            return X.reshape(X.shape[0], -1)
        else:
            return X
        
    def save(self):
        """Save the agent's parameters to a file."""
        raise NotImplementedError("This agent does not have a save method implemented.")

    def load(self):
        """Load the agent's parameters from a file."""
        raise NotImplementedError("This agent does not have a load method implemented.")
        

In [None]:
show_doc(BaseAgent, title_level=2)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L19){target="_blank" style="float:right; font-size:smaller"}

## BaseAgent

>      BaseAgent (environment_info:ddopnew.utils.MDPInfo,
>                 preprocessors:list[object]|None=None,
>                 postprocessors:list[object]|None=None)

*Base class for all agents to enforce a common interface. See below for more detailed description of the requirements.*

|    | **Type** | **Default** | **Details** |
| -- | -------- | ----------- | ----------- |
| environment_info | MDPInfo |  |  |
| preprocessors | list[object] \| None | None | default is empty list |
| postprocessors | list[object] \| None | None | default is empty list |

### Important notes:

Agents are, next to the environments, the core element of this library. The agents are the algorithms that take actions in the environment. They can be any type of algorithm, ranging from optimization and supervised learning to reinforcement learning, or any combination of these. Key for all the different agents to work is a common interface that allows them to interact with the environment.

**Draw action**:

* The ```draw_action``` function is the main interface with the environment. It receives an observation as Numpy array and returns an action as Numpy array. The function ```draw_action``` is defined in the ```BaseAgent``` and should not be overwritten as it properly applies pre- and post-processing (see below). 

* Agents always expect the observation to be of shape (batch_size, observation_dim) or (batch_size, time_dim, observation_dim) to allow batch-processing during training. Most environments do not have a batch dimension as they apply the step function to a single observation. Hence, the agent will by default add an extra dimension to the observation. If this is not desired, the agent has an attribute "receive_batch_dim" that can be set to True to tell the agent that the observation already has a batch dimension.

* To create an agent, the function ```draw_action_``` (note the underscore!) needs to be defined that gets the pre-processed observation and returns the action for post-processing. This function should be overwritten in the derived class.

**Pre- and post-processing**:

* Often, the environment and the agent have different requirements for the observations and actions (e.g., an agent outputting a continuous action while the environment expects a discrete action). To handle this, the ```BaseAgent``` class applies pre- and post-processing to the observations and actions.

* Processors can be passed as lists when the agent is instantiated (via the arguments ```preprocessors``` and ```postprocessors```). After instantiation, further processors are added using the ```add_preprocessor``` and ```add_postprocessor``` methods.

* Note that processors are applied in the order they are added.

**Training**:

* The ```run_experiment``` function in this library currently supports three types of training processes:
    * ```train_directly```: The agent is trained by calling agent.fit(X, Y) directly. In this case, the agent must have a fit function that takes the input and target data.
    * ```train_epochs```: The agent is iteratively trained on the training data (e.g., via SGD). In this case, the function ```fit_epoch``` must be implemented. ```fit_epoch``` does not take any arguments; instead, the dataloader from the environment needs to be given to the agent during initialization. The agent will then call the dataloader iteratively to get the training data.
    * ```env_interaction```: The agent is trained by interacting with the environment (e.g., like all reinforcement learning agents). This case builds on the ```Core``` class from MushroomRL. 

**Loading and saving**:

* All agents must implement a save and a load function that save and load the agent's parameters. See the Newsvendor ERM and (w)SAA agents for examples of different ways to save and load agents.




In [None]:
show_doc(BaseAgent.draw_action)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L50){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.draw_action

>      BaseAgent.draw_action (observation:numpy.ndarray)

*Main interface to the environment. Applies preprocessors to the observation and postprocessors to the action.
Internal logic of the agent to be implemented in draw_action_ method.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| observation | ndarray |  |
| **Returns** | **ndarray** |  |

In [None]:
show_doc(BaseAgent.draw_action_)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L47){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.draw_action_

>      BaseAgent.draw_action_ (observation:numpy.ndarray)

*Generate an action based on the observation - this is the core method that needs to be implemented by all agents.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| observation | ndarray |  |
| **Returns** | **ndarray** |  |

In [None]:
show_doc(BaseAgent.add_preprocessor)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L63){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.add_preprocessor

>      BaseAgent.add_preprocessor (preprocessor:object)

*add a preprocessor to the agent*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| preprocessor | object | pre-processor object that can be called via the "__call__" method |

In [None]:
show_doc(BaseAgent.add_postprocessor)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L66){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.add_postprocessor

>      BaseAgent.add_postprocessor (postprocessor:object)

*add a postprocessor to the agent*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| postprocessor | object | post-processor object that can be called via the "__call__" method |

In [None]:
show_doc(BaseAgent.train)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L69){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.train

>      BaseAgent.train ()

*set the internal state of the agent to train*

In [None]:
show_doc(BaseAgent.eval)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L72){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.eval

>      BaseAgent.eval ()

*Set the internal state of the agent to eval. Note that for agents we do not differentiate between val and test modes.*

In [None]:
show_doc(BaseAgent.add_batch_dim)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L75){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.add_batch_dim

>      BaseAgent.add_batch_dim (input:numpy.ndarray)

*Add a batch dimension to the input array if it doesn't already have one.
This is necessary because most environments do not have a batch dimension, but agents typically expect one.
If the environment does have a batch dimension, the agent can set the receive_batch_dim attribute to True to skip this step.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| input | ndarray |  |
| **Returns** | **ndarray** |  |

In [None]:
show_doc(BaseAgent.flatten_X)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L94){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.flatten_X

>      BaseAgent.flatten_X (X:numpy.ndarray)

*Function to flatten the time-dimension of the feature matrix
for agents that require a 2D input.*

|    | **Type** | **Details** |
| -- | -------- | ----------- |
| X | ndarray |  |
| **Returns** | **ndarray** |  |

In [None]:
show_doc(BaseAgent.save)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L114){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.save

>      BaseAgent.save ()

*Save the agent's parameters to a file.*

In [None]:
show_doc(BaseAgent.load)

---

[source](https://github.com/opimwue/ddopnew/blob/main/ddopnew/agents/base.py#L117){target="_blank" style="float:right; font-size:smaller"}

### BaseAgent.load

>      BaseAgent.load ()

*Load the agent's parameters from a file.*

In [None]:
# #| export

# def check_cu_co(cu, co, n_outputs):
#     """Validate under- and overage costs.

#     Parameters
#     ----------
#     cu : {ndarray, Number or None}, shape (n_outputs,)
#        The underage costs per unit. Passing cu=None will output an array of ones.
#     co : {ndarray, Number or None}, shape (n_outputs,)
#        The overage costs per unit. Passing co=None will output an array of ones.
#     n_outputs : int
#        The number of outputs.
#     Returns
#     -------
#     cu : ndarray, shape (n_outputs,)
#        Validated underage costs. It is guaranteed to be "C" contiguous.
#     co : ndarray, shape (n_outputs,)
#        Validated overage costs. It is guaranteed to be "C" contiguous.
#     """
#     costs = [[cu, "cu"], [co, "co"]]
#     costs_validated = []
#     for c in costs:
#         if c[0] is None:
#             cost = np.ones(n_outputs, dtype=np.float64)
#         elif isinstance(c[0], numbers.Number):
#             cost = np.full(n_outputs, c[0], dtype=np.float64)
#         else:
#             cost = check_array(
#                 c[0], accept_sparse=False, ensure_2d=False, dtype=np.float64,
#                 order="C"
#             )
#             if cost.ndim != 1:
#                 raise ValueError(c[1], "must be 1D array or scalar")

#             if cost.shape != (n_outputs,):
#                 raise ValueError("{}.shape == {}, expected {}!"
#                                  .format(c[1], cost.shape, (n_outputs,)))
#         costs_validated.append(cost)
#     cu = costs_validated[0]
#     co = costs_validated[1]
#     return cu, co

# class NewsvendorSAAagentOLD(BaseAgent):

#     def __init__(self, environment_info, cu, co):
#         self.cu = cu
#         self.co = co

#         self.sl = cu / (cu + co)

#         self.quantiles = np.array([0.0])

#         super().__init__(environment_info)

#         self.fitted = False

#     def _calc_weights(self):
#         weights = np.full(self.n_samples_, 1 / self.n_samples_)
#         return weights

#     def fit(self, Y, X, mask=None):

#         demand = Y
#         features = X

#         self.mask=mask
#         y=demand

#         if mask is not None:
#             if demand.shape != mask.shape:
#                 if demand.shape[1]==1 & len(mask.shape)==1:
#                     mask = mask.reshape((-1,1))
#                     self.mask = mask
#                 if demand.shape != mask.shape:
#                     raise ValueError("Shapes of mask and demand do not match")
#                 # check if 1 either in mask.shape or demand.shape, if yes squeeze
#             demand=demand*mask
       
#         y = check_array(y, ensure_2d=False, accept_sparse='csr')

#         if y.ndim == 1:
#             y = np.reshape(y, (-1, 1))

#         # Training data
#         self.y_ = y
#         self.n_samples_ = y.shape[0]

#         # Determine output settings
#         self.n_outputs_ = y.shape[1]

#         # Check and format under- and overage costs
#         self.cu_, self.co_ = check_cu_co(self.cu, self.co, self.n_outputs_)

#         self.q_star = np.array(self._findQ(self._calc_weights()))

#         self.fitted=True

#         return self

#     def _findQ(self, weights):
#         """Calculate the optimal order quantity q"""

#         y = self.y_
#         q = []

#         for k in range(self.n_outputs_):
#             alpha = self.cu_[k] / (self.cu_[k] + self.co_[k])
#             y_product = y[:,k]
#             if self.mask is not None:
#                 mask_product = self.mask[:,k]
#                 y_product = y_product[mask_product.astype(bool)]
#             # print(y_product.shape)
#             q.append(np.quantile(y_product, alpha, interpolation="higher"))
#         print("found optimal q:", q)
#         return q

#     def draw_action(self, *args, **kwargs):

#         if self.fitted:
#             pred = self.q_star
            
#         else:
#             pred = np.random.rand(1)  
#         return pred

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()