# Utils

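> Utility classes shared across the library: `Parameter`, for values that may change over time and are shared by the environment, agent and dataloaders, and `MDPInfo`, for the static description of an environment.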

In [None]:
#| default_exp utils

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export

import numpy as np

TODO:
 - write descriptions
 - add type hints
 - document parameter types, return values and variable meanings

In [None]:
#| export
class Parameter:
    """
    Simple container for a parameter of the environment whose value may change over time.

    A single instance can be shared by multiple objects, such as the environment,
    agent or dataloaders. Because they all hold the same object, a call to
    `set_value` is visible to every holder through `get_value` / `__call__`.

    The value is always stored as a numpy array. If `min_value` and/or
    `max_value` are given, every newly set value is clipped element-wise to that range.
    """

    def __init__(self, value, min_value=None, max_value=None, shape=(1,)):
        """
        Args:
            value (int, float, numpy scalar, list, tuple, numpy array): the initial value.
            min_value (optional): element-wise lower bound applied on every set; None disables it.
            max_value (optional): element-wise upper bound applied on every set; None disables it.
            shape (tuple or int, (1,)): expected shape of the stored array.
        """
        self._min_value = min_value
        self._max_value = max_value

        self.set_value(value, shape)

    def __call__(self):
        """
        Return the current value of the parameter (shorthand for `get_value`).

        Returns:
            numpy.ndarray: the current value.
        """
        return self.get_value()

    def get_value(self):
        """
        Return the current value of the parameter.

        Returns:
            numpy.ndarray: the stored array itself (not a copy).
        """
        return self._value

    def set_value(self, value, shape=(1,)):
        """
        Set the value of the parameter.

        Scalars are wrapped in a one-element array and reshaped to `shape`.
        Lists, tuples and numpy arrays must already have exactly `shape`.
        The result is then clipped to [min_value, max_value] where those bounds were given.

        Args:
            value (int, float, numpy scalar, list, tuple, numpy array): the new value.
            shape (tuple or int, (1,)): expected shape of the stored array.

        Raises:
            ValueError: if `value` is not a scalar, list, tuple or numpy array,
                or if a scalar cannot be reshaped to `shape`.
            AssertionError: if a list/tuple/array does not have shape `shape`.
        """
        # Accept `shape=5` or `shape=[5]` as well as `(5,)`; the shape
        # comparison below is a tuple comparison, so normalise first.
        shape = (shape,) if isinstance(shape, (int, np.integer)) else tuple(shape)

        if isinstance(value, (int, float, np.integer, np.floating)):
            # ndarray.reshape returns a new array, so the result must be kept.
            # Raises ValueError if `shape` does not hold exactly one element.
            new_value = np.array([value]).reshape(shape)
        else:
            if isinstance(value, (list, tuple)):
                new_value = np.array(value)
            elif isinstance(value, np.ndarray):
                # Stored as-is (no copy), as before.
                new_value = value
            else:
                raise ValueError("Value must be a scalar or numpy array")
            assert new_value.shape == shape, "Shape of value must be the same as the shape of the parameter"

        # Only commit once validation has passed, so a failed set leaves the old value intact.
        self._value = new_value

        if self._min_value is not None:
            self._value = np.maximum(self._value, self._min_value)
        if self._max_value is not None:
            self._value = np.minimum(self._value, self._max_value)

    @property
    def shape(self):
        """
        Returns:
            tuple: the shape of the stored value array.
        """
        return self._value.shape

    @property
    def size(self):
        """
        Returns:
            int: the number of elements in the stored value array.
        """
        return self._value.size

In [None]:
overage_cost = Parameter(1)  # int scalar
underage_cost = Parameter(2.13)  # float scalar
ordering_cost = Parameter([2] * 5, shape=(5,))  # list
holding_cost = Parameter(np.array([0, 1]), shape=(2,))  # numpy array

# Print each parameter's stored array, one per line, in declaration order.
for demo_param in (overage_cost, underage_cost, ordering_cost, holding_cost):
    print(demo_param.get_value())

[1]
[2.13]
[2 2 2 2 2]
[0 1]


In [None]:
#| export

class MDPInfo:
    """
    Container for the description of an environment: its spaces, discounting,
    horizon, control timestep and data backend.

    Modelled on the equivalent class in MushroomRL (https://github.com/MushroomRL).
    """

    def __init__(self, observation_space, action_space, gamma, horizon, dt=1e-1, backend='numpy'):
        """
        Store the environment description as plain attributes.

        Args:
            observation_space ([Box, Discrete]): the state space.
            action_space ([Box, Discrete]): the action space.
            gamma (float): the discount factor.
            horizon (int): the horizon.
            dt (float, 1e-1): the control timestep of the environment.
            backend (str, 'numpy'): the data library used to generate states and actions.
        """
        # Spaces
        self.observation_space, self.action_space = observation_space, action_space
        # Discounting and episode length
        self.gamma, self.horizon = gamma, horizon
        # Simulation settings
        self.dt, self.backend = dt, backend

    @property
    def size(self):
        """
        Returns:
            `observation_space.size + action_space.size`, i.e. the number of
            discrete states plus the number of discrete actions. Only works
            for discrete spaces.
        """
        n_states = self.observation_space.size
        n_actions = self.action_space.size
        return n_states + n_actions

    @property
    def shape(self):
        """
        Returns:
            The state-space shape tuple followed by the action-space shape tuple.
        """
        state_shape = self.observation_space.shape
        action_shape = self.action_space.shape
        return state_shape + action_shape

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()