# Linear Policy

In [2]:
from typing import Callable, List
import matplotlib.pyplot as plt
import numpy as np
import scipy.stats

In [7]:
class LinearPolicy:
    """Deterministic linear policy: action = parameters^T . features(state)."""

    def __init__(self,
                 parameters: np.ndarray,
                 features: Callable[[np.ndarray], np.ndarray]):
        """
        Linear Policy Constructor.

        Args:
            parameters (np.ndarray): policy parameters as np.ndarray. Its
                leading dimension (``parameters.shape[0]``) must equal the
                leading dimension of the feature array produced by
                ``features``.
            features (Callable[[np.ndarray], np.ndarray]): function used to
                extract features from the state representation.
        """
        self._parameters = parameters
        self._features = features

    def __call__(self, state: np.ndarray) -> np.ndarray:
        """
        Call method of the Policy.

        Args:
            state (np.ndarray): environment state.

        Returns:
            np.ndarray: the resulting action, ``parameters.T`` dotted with
            ``features(state)``.

        Raises:
            ValueError: if the leading dimension of the state features does
                not match the leading dimension of the parameters.
        """
        # calculate state features
        state_features = self._features(state)

        # The parameters' leading dimension must match the features' leading
        # dimension for the product below to be defined. Raise explicitly
        # rather than ``assert`` so the check is not stripped under ``-O``.
        if state_features.shape[0] != self._parameters.shape[0]:
            raise ValueError(
                "state features have leading dimension "
                f"{state_features.shape[0]}, but parameters have leading "
                f"dimension {self._parameters.shape[0]}"
            )

        # dot product between parameters and state features
        return np.dot(self._parameters.T, state_features)

In [8]:
# Seeded generator so the sampled parameters/state (and the printed action)
# are reproducible from run to run.
rng = np.random.default_rng(0)

# sample a random set of parameters
parameters = rng.random((5, 1))


# define the state features as the identity function
def features(x: np.ndarray) -> np.ndarray:
    """Identity feature map: use the raw state as the feature vector."""
    return x


# define the policy
pi: LinearPolicy = LinearPolicy(parameters, features)

# sample a state
state = rng.random((5, 1))

# call the policy, obtaining the action
action = pi(state)

print(action)

[[2.23827488]]
