# Base agents

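> Base class for agents: wraps an agent-specific `draw_action_` with optional pre- and postprocessors, and provides train/eval mode switching plus helpers for shaping input arrays.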

In [None]:
#| default_exp agents.base

In [None]:
import logging
# Logging level constant set to DEBUG. It is only assigned here; no logger in
# the code visible in this notebook is configured with it.
logging_level = logging.DEBUG

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export


from abc import ABC, abstractmethod
from typing import Union, Optional, List
import numpy as np

from ddopnew.envs.base import BaseEnvironment
from ddopnew.utils import MDPInfo

# # TEMPORARY
# from sklearn.utils.validation import check_array
# import numbers

In [None]:
#| export

class BaseAgent():

    """
    Base class for agents.

    An agent turns an observation into an action. Subclasses implement the
    agent-specific logic in `draw_action_`; the public `draw_action` wraps it
    with an optional batch-dimension expansion, the registered preprocessors
    and the registered postprocessors.
    """

    train_mode = "direct_fit" # or "epochs_fit" or "env_interaction"

    def __init__(self,
                 environment_info: "MDPInfo", # information about the environment (MDP)
                 preprocessors: Optional[List[object]] = None, # default is empty list
                 postprocessors: Optional[List[object]] = None # default is empty list
                 ):

        """
        Initialize a BaseAgent.

        Args:
            environment_info (MDPInfo): Information about the environment (MDP).
                Only stored on the instance; it is not inspected here.
            preprocessors (Optional[List[object]]): Callables applied, in order, to the
                observation before `draw_action_`. `None` means no preprocessors.
            postprocessors (Optional[List[object]]): Callables applied, in order, to the
                action returned by `draw_action_`. `None` means no postprocessors.
        """

        # Copy the sequences so that add_preprocessor / add_postprocessor never
        # mutate a list owned by the caller (or shared with another agent).
        self.preprocessors = list(preprocessors) if preprocessors is not None else []
        self.postprocessors = list(postprocessors) if postprocessors is not None else []

        self.environment_info = environment_info
        self.mode = "train"
        self.print = False  # Can be used for debugging
        # When False, draw_action prepends a batch axis to every observation.
        self.receive_batch_dim = False

    @abstractmethod
    def draw_action_(self, observation):

        """
        Compute the action for an already preprocessed observation.

        Must be overridden by subclasses. The class does not use `ABCMeta`, so
        the `abstractmethod` decorator alone does not prevent instantiation;
        the explicit error below makes a missing override fail loudly instead
        of silently passing `None` on to the postprocessors.

        Args:
            observation: The observation after batch-dim handling and preprocessing.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """

        raise NotImplementedError("This agent does not have a draw_action_ method implemented.")

    def draw_action(self, observation):

        """
        Run the full action pipeline for one observation.

        Steps: add a batch dimension (unless `receive_batch_dim` is set), apply
        each preprocessor in order, call `draw_action_`, then apply each
        postprocessor in order.

        Args:
            observation: The raw observation.

        Returns:
            The action after all postprocessors have been applied.
        """

        observation = self.add_batch_dim(observation)

        for preprocessor in self.preprocessors:
            observation = preprocessor(observation)

        action = self.draw_action_(observation)

        for postprocessor in self.postprocessors:
            action = postprocessor(action)
        return action

    def add_preprocessor(self, preprocessor):
        """Append a callable to the end of the preprocessor chain."""
        self.preprocessors.append(preprocessor)

    def add_postprocessor(self, postprocessor):
        """Append a callable to the end of the postprocessor chain."""
        self.postprocessors.append(postprocessor)

    def train(self):
        """Set `self.mode` to "train"."""
        self.mode = "train"

    def eval(self):
        """Set `self.mode` to "eval"."""
        self.mode = "eval"

    def add_batch_dim(self, input: np.ndarray) -> np.ndarray:

        """
        Add a leading batch dimension to the input unless one is already expected.

        The decision is driven only by `self.receive_batch_dim`; the shape of
        the input itself is not inspected.

        Args:
            input (np.ndarray): The input array that may need a batch dimension added.

        Returns:
            np.ndarray: `input` unchanged if `self.receive_batch_dim` is truthy,
            otherwise `input` with a new axis of length 1 at position 0.
        """

        if self.receive_batch_dim:
            # If the batch dimension is expected, return the input as is
            return input

        # Add a batch dimension by expanding the dimensions of the input
        return np.expand_dims(input, axis=0)

    def flatten_X(self, X):

        """
        Flatten the time-dimension of the feature matrix for agents that
        require a 2D input.

        Args:
            X (np.ndarray): The input data to be flattened.

        Returns:
            np.ndarray: If `X` has exactly three dimensions, an array of shape
            `(X.shape[0], -1)` with all trailing axes merged; otherwise `X`
            itself, unchanged.
        """

        if X.ndim == 3:
            return X.reshape(X.shape[0], -1)
        return X

    def save(self):
        """Persist the agent. Not implemented in the base class."""
        raise NotImplementedError("This agent does not have a save method implemented.")

    def load(self):
        """Restore the agent. Not implemented in the base class."""
        raise NotImplementedError("This agent does not have a load method implemented.")
        
        

In [None]:
# #| export

# def check_cu_co(cu, co, n_outputs):
#     """Validate under- and overage costs.

#     Parameters
#     ----------
#     cu : {ndarray, Number or None}, shape (n_outputs,)
#        The underage costs per unit. Passing cu=None will output an array of ones.
#     co : {ndarray, Number or None}, shape (n_outputs,)
#        The overage costs per unit. Passing co=None will output an array of ones.
#     n_outputs : int
#        The number of outputs.
#     Returns
#     -------
#     cu : ndarray, shape (n_outputs,)
#        Validated underage costs. It is guaranteed to be "C" contiguous.
#     co : ndarray, shape (n_outputs,)
#        Validated overage costs. It is guaranteed to be "C" contiguous.
#     """
#     costs = [[cu, "cu"], [co, "co"]]
#     costs_validated = []
#     for c in costs:
#         if c[0] is None:
#             cost = np.ones(n_outputs, dtype=np.float64)
#         elif isinstance(c[0], numbers.Number):
#             cost = np.full(n_outputs, c[0], dtype=np.float64)
#         else:
#             cost = check_array(
#                 c[0], accept_sparse=False, ensure_2d=False, dtype=np.float64,
#                 order="C"
#             )
#             if cost.ndim != 1:
#                 raise ValueError(c[1], "must be 1D array or scalar")

#             if cost.shape != (n_outputs,):
#                 raise ValueError("{}.shape == {}, expected {}!"
#                                  .format(c[1], cost.shape, (n_outputs,)))
#         costs_validated.append(cost)
#     cu = costs_validated[0]
#     co = costs_validated[1]
#     return cu, co

# class NewsvendorSAAagentOLD(BaseAgent):

#     def __init__(self, environment_info, cu, co):
#         self.cu = cu
#         self.co = co

#         self.sl = cu / (cu + co)

#         self.quantiles = np.array([0.0])

#         super().__init__(environment_info)

#         self.fitted = False

#     def _calc_weights(self):
#         weights = np.full(self.n_samples_, 1 / self.n_samples_)
#         return weights

#     def fit(self, Y, X, mask=None):

#         demand = Y
#         features = X

#         self.mask=mask
#         y=demand

#         if mask is not None:
#             if demand.shape != mask.shape:
#                 if demand.shape[1]==1 & len(mask.shape)==1:
#                     mask = mask.reshape((-1,1))
#                     self.mask = mask
#                 if demand.shape != mask.shape:
#                     raise ValueError("Shapes of mask and demand do not match")
#                 # check if 1 either in mask.shape or demand.shape, if yes squeeze
#             demand=demand*mask
       
#         y = check_array(y, ensure_2d=False, accept_sparse='csr')

#         if y.ndim == 1:
#             y = np.reshape(y, (-1, 1))

#         # Training data
#         self.y_ = y
#         self.n_samples_ = y.shape[0]

#         # Determine output settings
#         self.n_outputs_ = y.shape[1]

#         # Check and format under- and overage costs
#         self.cu_, self.co_ = check_cu_co(self.cu, self.co, self.n_outputs_)

#         self.q_star = np.array(self._findQ(self._calc_weights()))

#         self.fitted=True

#         return self

#     def _findQ(self, weights):
#         """Calculate the optimal order quantity q"""

#         y = self.y_
#         q = []

#         for k in range(self.n_outputs_):
#             alpha = self.cu_[k] / (self.cu_[k] + self.co_[k])
#             y_product = y[:,k]
#             if self.mask is not None:
#                 mask_product = self.mask[:,k]
#                 y_product = y_product[mask_product.astype(bool)]
#             # print(y_product.shape)
#             q.append(np.quantile(y_product, alpha, interpolation="higher"))
#         print("found optimal q:", q)
#         return q

#     def draw_action(self, *args, **kwargs):

#         if self.fitted:
#             pred = self.q_star
            
#         else:
#             pred = np.random.rand(1)  
#         return pred

In [None]:
#| hide
# Run nbdev's export step from inside the notebook. Presumably this writes the
# `#| export` cells to the module named by `#| default_exp` — confirm against
# the nbdev documentation.
import nbdev; nbdev.nbdev_export()