# In [4]:
from __future__ import annotations

'''An interface for different kinds of function approximations
(tabular, linear, DNN... etc), with several implementations.'''
import sys
sys.path.append("../../RL-book")


from abc import ABC, abstractmethod
from dataclasses import dataclass, replace, field
import itertools
import numpy as np
from operator import itemgetter
from scipy.interpolate import splrep, BSpline
from typing import (Callable, Dict, Generic, Iterator, Iterable, List,
                    Mapping, Optional, Sequence, Tuple, TypeVar)

import rl.iterate as iterate

X = TypeVar('X')
SMALL_NUM = 1e-6


class FunctionApprox(ABC, Generic[X]):
    '''Interface for function approximations.
    An object of this class approximates some function X ↦ ℝ in a way
    that can be evaluated at specific points in X and updated with
    additional (X, ℝ) points.
    '''

    @abstractmethod
    def representational_gradient(self, x_value: X) -> FunctionApprox[X]:
        '''Computes the gradient of the self FunctionApprox with respect
        to the parameters in the internal representation of the
        FunctionApprox, i.e., computes Gradient with respect to internal
        parameters of expected value of y for the input x, where the
        expectation is with respect to the FunctionApprox's model of
        the probability distribution of y|x. The gradient is output
        in the form of a FunctionApprox whose internal parameters are
        equal to the gradient values.
        '''

    @abstractmethod
    def evaluate(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Computes expected value of y for each x in
        x_values_seq (with the probability distribution
        function of y|x estimated as FunctionApprox)
        '''

    def __call__(self, x_value: X) -> float:
        '''Evaluate the approximation at a single point and return the
        result as a plain scalar.
        '''
        return self.evaluate([x_value]).item()

    @abstractmethod
    def update(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> FunctionApprox[X]:

        '''Update the internal parameters of the FunctionApprox
        based on incremental data provided in the form of (x,y)
        pairs as a xy_vals_seq data structure
        '''

    @abstractmethod
    def solve(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]],
        error_tolerance: Optional[float] = None
    ) -> FunctionApprox[X]:
        '''Assuming the entire data set of (x,y) pairs is available
        in the form of the given input xy_vals_seq data structure,
        solve for the internal parameters of the FunctionApprox
        such that the internal parameters are fitted to xy_vals_seq.
        Since this is a best-fit, the internal parameters are fitted
        to within the input error_tolerance (where applicable, since
        some methods involve a direct solve for the fit that don't
        require an error_tolerance)
        '''

    @abstractmethod
    def within(self, other: FunctionApprox[X], tolerance: float) -> bool:
        '''Is this function approximation within a given tolerance of
        another function approximation of the same type?
        '''

    def argmax(self, xs: Iterable[X]) -> X:
        '''Return the input X that maximizes the function being approximated.
        Arguments:
          xs -- inputs to evaluate and maximize, cannot be empty; any
                iterable is accepted, including one-shot iterators
        Returns the X that maximizes the function this approximates.
        '''
        # Materialize once: xs is needed both for evaluation and for
        # indexing, and a one-shot iterator would otherwise already be
        # exhausted by the time it is evaluated.
        candidates: Sequence[X] = list(xs)
        return candidates[np.argmax(self.evaluate(candidates))]

    def rmse(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> float:
        '''The Root-Mean-Squared-Error between FunctionApprox's
        predictions (from evaluate) and the associated (supervisory)
        y values
        '''
        x_seq, y_seq = zip(*xy_vals_seq)
        errors: np.ndarray = self.evaluate(x_seq) - np.array(y_seq)
        return np.sqrt(np.mean(errors * errors))

    def iterate_updates(
        self,
        xy_seq_stream: Iterator[Iterable[Tuple[X, float]]]
    ) -> Iterator[FunctionApprox[X]]:
        '''Given a stream (Iterator) of data sets of (x,y) pairs,
        perform a series of incremental updates to the internal
        parameters (using update method), with each internal
        parameter update done for each data set of (x,y) pairs in the
        input stream of xy_seq_stream
        '''
        return iterate.accumulate(
            xy_seq_stream,
            lambda fa, xy: fa.update(xy),
            initial=self
        )

    def representational_gradient_stream(
        self,
        x_values_seq: Iterable[X]
    ) -> Iterator[FunctionApprox[X]]:
        '''Lazily yield representational_gradient(x) for each x in
        x_values_seq, in order.
        '''
        for x_val in x_values_seq:
            yield self.representational_gradient(x_val)


@dataclass(frozen=True)
class Dynamic(FunctionApprox[X]):
    '''A FunctionApprox that works exactly the same as exact dynamic
    programming. Each update for a value in X replaces the previous
    value at X altogether.

    Fields:
    values_map -- mapping from X to its approximated value
    '''

    values_map: Mapping[X, float]

    def representational_gradient(self, x_value: X) -> Dynamic[X]:
        '''Return a Dynamic whose only entry maps x_value to 1.0.'''
        return Dynamic({x_value: 1.0})

    def evaluate(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Evaluate the function approximation by looking up the value in the
        mapping for each state.

        Will raise an error if an X value has not been seen before and
        was not initialized.

        '''
        return np.array([self.values_map[x] for x in x_values_seq])

    def update(self, xy_vals_seq: Iterable[Tuple[X, float]]) -> Dynamic[X]:
        '''Update each X value by replacing its saved Y with a new one. Pairs
        later in the list take precedence over pairs earlier in the
        list.

        '''
        # Copy first so the (frozen) original is left untouched.
        new_map: Dict[X, float] = dict(self.values_map)
        new_map.update(xy_vals_seq)
        return replace(self, values_map=new_map)

    def solve(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]],
        error_tolerance: Optional[float] = None
    ) -> Dynamic[X]:
        '''Return a Dynamic whose table consists solely of the given
        (x, y) pairs; previously stored values are discarded.
        error_tolerance is accepted for interface compatibility and
        is not used.
        '''
        # The field is called `values_map`; passing any other keyword to
        # dataclasses.replace raises a TypeError.
        return replace(self, values_map=dict(xy_vals_seq))

    def within(self, other: FunctionApprox[X], tolerance: float) -> bool:
        '''This approximation is within a tolerance of another if the value
        for each X in both approximations is within the given
        tolerance.

        Raises an error if the other approximation is missing states
        that this approximation has.

        '''
        if not isinstance(other, Dynamic):
            return False

        return all(abs(self.values_map[s] - other.values_map[s]) <= tolerance
                   for s in self.values_map)


@dataclass(frozen=True)
class Tabular(FunctionApprox[X]):
    '''Approximates a function with a discrete domain (`X'), without any
    interpolation. The value for each `X' is maintained as a weighted
    mean of observations by recency (managed by
    `count_to_weight_func').

    In practice, this means you can use this to approximate a function
    with a learning rate α(n) specified by count_to_weight_func.

    If `count_to_weight_func' always returns 1, this behaves the same
    way as `Dynamic'.

    Fields:
    values_map -- mapping from X to its approximated value
    counts_map -- how many times a given X has been updated
    count_to_weight_func -- function for how much to weigh an update
      to X based on the number of times that X has been updated

    '''

    values_map: Mapping[X, float] = field(default_factory=lambda: {})
    counts_map: Mapping[X, int] = field(default_factory=lambda: {})
    count_to_weight_func: Callable[[int], float] = \
        field(default_factory=lambda: lambda n: 1.0 / n)

    def representational_gradient(self, x_value: X) -> Tabular[X]:
        '''Return a Tabular whose only entry maps x_value to 1.0.'''
        return Tabular({x_value: 1.0})

    def evaluate(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Evaluate the approximation at each given X.

        If an X has not been seen before, will return 0.0.
        '''
        return np.array([self.values_map.get(x, 0.) for x in x_values_seq])

    def update(self, xy_vals_seq: Iterable[Tuple[X, float]]) -> Tabular[X]:
        '''Update the approximation with the given points.

        Each X keeps a count n of how many times it was updated, and
        each subsequent update is discounted by
        count_to_weight_func(n), which defines our learning rate.

        '''
        values_map: Dict[X, float] = dict(self.values_map)
        counts_map: Dict[X, int] = dict(self.counts_map)

        for x, y in xy_vals_seq:
            count: int = counts_map.get(x, 0) + 1
            counts_map[x] = count
            # The weight is looked up with the count *after* this update.
            weight: float = self.count_to_weight_func(count)
            values_map[x] = weight * y + (1 - weight) * values_map.get(x, 0.)

        return replace(
            self,
            values_map=values_map,
            counts_map=counts_map
        )

    def solve(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]],
        error_tolerance: Optional[float] = None
    ) -> Tabular[X]:
        '''Fit the table to the given (x, y) pairs from scratch.

        Previously stored values and counts are discarded, then the
        pairs are folded in exactly as `update' would do.
        error_tolerance is accepted for interface compatibility and
        is not used.
        '''
        empty_table: Tabular[X] = replace(self, values_map={}, counts_map={})
        return empty_table.update(xy_vals_seq)

    def within(self, other: FunctionApprox[X], tolerance: float) -> bool:
        '''Is every value stored in this table within `tolerance' of the
        corresponding value in `other'?

        An X that `other' has never seen is compared against 0.0, the
        same value `evaluate' reports for unseen inputs, instead of
        raising a KeyError. Returns False if `other' is not a Tabular.
        '''
        if not isinstance(other, Tabular):
            return False

        return all(
            abs(self.values_map[s] - other.values_map.get(s, 0.)) <= tolerance
            for s in self.values_map
        )


@dataclass(frozen=True)
class BSplineApprox(FunctionApprox[X]):
    '''Approximates a function with a univariate B-spline fitted to a
    scalar feature of X.

    Fields:
    feature_function -- maps an X to the float at which the spline is
      fitted / evaluated
    degree -- spline degree, passed as `k' to splrep and BSpline
    knots -- knot vector of the fitted spline (empty until fitted)
    coeffs -- spline coefficients (empty until fitted)
    '''

    feature_function: Callable[[X], float]
    degree: int
    knots: np.ndarray = field(default_factory=lambda: np.array([]))
    coeffs: np.ndarray = field(default_factory=lambda: np.array([]))

    def get_feature_values(self, x_values_seq: Iterable[X]) -> Sequence[float]:
        '''Apply feature_function to every x, preserving order.'''
        return [self.feature_function(x) for x in x_values_seq]

    def representational_gradient(self, x_value: X) -> BSplineApprox[X]:
        '''Gradient of the spline value at x_value with respect to each
        coefficient, returned as a BSplineApprox whose coeffs hold the
        gradient.

        A B-spline is linear in its coefficients, so the derivative
        with respect to coefficient i is the spline obtained with the
        i-th unit vector as coefficients, evaluated at the feature
        value. This is computed directly rather than by finite
        differences, which avoids truncation / cancellation error.
        '''
        feature_val: float = self.feature_function(x_value)
        num_coeffs: int = len(self.coeffs)
        if num_coeffs == 0:
            # Nothing fitted yet: there are no parameters to differentiate.
            return replace(self, coeffs=np.array([]))

        # Column i of the identity is the i-th unit coefficient vector;
        # BSpline evaluates all trailing coefficient columns at once and
        # returns one value per column for a scalar argument.
        basis_values: np.ndarray = BSpline(
            self.knots,
            np.eye(num_coeffs),
            self.degree
        )(feature_val)
        return replace(self, coeffs=np.asarray(basis_values, dtype=float))

    def evaluate(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Evaluate the fitted spline at the feature value of each x.'''
        spline_func: Callable[[Sequence[float]], np.ndarray] = \
            BSpline(self.knots, self.coeffs, self.degree)
        return spline_func(self.get_feature_values(x_values_seq))

    def update(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> BSplineApprox[X]:
        '''Refit the spline to the given (x, y) pairs with splrep.

        The current knots and coeffs are not used; they are replaced
        by the result of the new fit.
        '''
        x_vals, y_vals = zip(*xy_vals_seq)
        feature_vals: Sequence[float] = self.get_feature_values(x_vals)
        # The fit is given the points in ascending order of feature value.
        sorted_pairs: Sequence[Tuple[float, float]] = \
            sorted(zip(feature_vals, y_vals), key=itemgetter(0))
        new_knots, new_coeffs, _ = splrep(
            [f for f, _ in sorted_pairs],
            [y for _, y in sorted_pairs],
            k=self.degree
        )
        return replace(
            self,
            knots=new_knots,
            coeffs=new_coeffs
        )

    def solve(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]],
        error_tolerance: Optional[float] = None
    ) -> BSplineApprox[X]:
        '''Same as update: the spline is fitted to the full data set.
        error_tolerance is not used.
        '''
        return self.update(xy_vals_seq)

    def within(self, other: FunctionApprox[X], tolerance: float) -> bool:
        '''Are all knots and all coefficients of the two splines within
        `tolerance' of each other?

        Returns False if `other' is not a BSplineApprox, or if the knot
        or coefficient arrays differ in shape (e.g. an unfitted spline
        compared with a fitted one), since an element-wise comparison
        is not defined in that case.
        '''
        if not isinstance(other, BSplineApprox):
            return False

        if self.knots.shape != other.knots.shape or \
                self.coeffs.shape != other.coeffs.shape:
            return False

        return \
            np.all(np.abs(self.knots - other.knots) <= tolerance).item() \
            and \
            np.all(np.abs(self.coeffs - other.coeffs) <= tolerance).item()


@dataclass(frozen=True)
class AdamGradient:
    '''Immutable bundle of the three hyper-parameters used by the
    Adam-style weight update in this module.

    Fields:
    learning_rate -- multiplier applied to each weight step
    decay1 -- decay factor for the running average of gradients
    decay2 -- decay factor for the running average of squared gradients
    '''

    learning_rate: float
    decay1: float
    decay2: float

    @staticmethod
    def default_settings() -> AdamGradient:
        '''Return the stock configuration: learning rate 0.001, decay
        factors 0.9 and 0.999.
        '''
        defaults = {
            'learning_rate': 0.001,
            'decay1': 0.9,
            'decay2': 0.999,
        }
        return AdamGradient(**defaults)


@dataclass(frozen=True)
class Weights:
    '''An immutable weight array together with the state needed to
    update it with Adam-style gradient steps.

    Fields:
    adam_gradient -- hyper-parameters of the update rule
    time -- number of updates applied so far
    weights -- the current weights
    adam_cache1 -- running average of gradients
    adam_cache2 -- running average of squared gradients
    '''

    adam_gradient: AdamGradient
    time: int
    weights: np.ndarray
    adam_cache1: np.ndarray
    adam_cache2: np.ndarray

    @staticmethod
    def create(
        weights: np.ndarray,
        adam_gradient: AdamGradient = AdamGradient.default_settings(),
        adam_cache1: Optional[np.ndarray] = None,
        adam_cache2: Optional[np.ndarray] = None
    ) -> Weights:
        '''Build a Weights at time 0. Any cache that is not supplied
        starts as zeros shaped like `weights'.
        '''
        if adam_cache1 is None:
            adam_cache1 = np.zeros_like(weights)
        if adam_cache2 is None:
            adam_cache2 = np.zeros_like(weights)

        return Weights(
            adam_gradient=adam_gradient,
            time=0,
            weights=weights,
            adam_cache1=adam_cache1,
            adam_cache2=adam_cache2
        )

    def update(self, gradient: np.ndarray) -> Weights:
        '''Return a new Weights after one Adam step along `gradient'.

        Both running averages are refreshed, divided by
        (1 - decay ** time) to correct their start-up bias, and the
        weights move against the corrected first average scaled by
        the root of the corrected second average.
        '''
        settings: AdamGradient = self.adam_gradient
        step: int = self.time + 1

        first_moment: np.ndarray = settings.decay1 * self.adam_cache1 \
            + (1 - settings.decay1) * gradient
        second_moment: np.ndarray = settings.decay2 * self.adam_cache2 \
            + (1 - settings.decay2) * gradient ** 2

        first_moment_hat: np.ndarray = \
            first_moment / (1 - settings.decay1 ** step)
        second_moment_hat: np.ndarray = \
            second_moment / (1 - settings.decay2 ** step)

        # SMALL_NUM keeps the denominator away from zero.
        adjustment: np.ndarray = settings.learning_rate * first_moment_hat / \
            (np.sqrt(second_moment_hat) + SMALL_NUM)

        return replace(
            self,
            time=step,
            weights=self.weights - adjustment,
            adam_cache1=first_moment,
            adam_cache2=second_moment,
        )

    def within(self, other: Weights, tolerance: float) -> bool:
        '''Do all weights differ from those of `other' by at most
        `tolerance'? Only the weight arrays are compared.
        '''
        gaps: np.ndarray = np.abs(self.weights - other.weights)
        return np.all(gaps <= tolerance).item()


@dataclass(frozen=True)
class LinearFunctionApprox(FunctionApprox[X]):
    '''Approximates a function as a weighted sum of feature values.

    Fields:
    feature_functions -- one callable per feature, each mapping X to float
    regularization_coeff -- coefficient of the L2 penalty on the weights
    weights -- current weights together with their Adam state
    direct_solve -- if True, `solve' uses the closed-form regularized
      least-squares solution; otherwise it iterates gradient updates
      until convergence
    '''

    feature_functions: Sequence[Callable[[X], float]]
    regularization_coeff: float
    weights: Weights
    direct_solve: bool

    @staticmethod
    def create(
        feature_functions: Sequence[Callable[[X], float]],
        adam_gradient: AdamGradient = AdamGradient.default_settings(),
        regularization_coeff: float = 0.,
        weights: Optional[Weights] = None,
        direct_solve: bool = True
    ) -> LinearFunctionApprox[X]:
        '''Build a LinearFunctionApprox. When `weights' is not given,
        the weights start at zero (one per feature function) with the
        given Adam settings.
        '''
        return LinearFunctionApprox(
            feature_functions=feature_functions,
            regularization_coeff=regularization_coeff,
            weights=Weights.create(
                adam_gradient=adam_gradient,
                weights=np.zeros(len(feature_functions))
            ) if weights is None else weights,
            direct_solve=direct_solve
        )

    def get_feature_values(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Return the feature matrix: one row per x, one column per
        feature function.
        '''
        return np.array(
            [[f(x) for f in self.feature_functions] for x in x_values_seq]
        )

    def representational_gradient(self, x_value: X) -> LinearFunctionApprox[X]:
        '''Gradient of the prediction at x_value with respect to the
        weights, i.e. the feature vector of x_value, stored as the
        weights of the returned approximation.
        '''
        return replace(
            self,
            weights=replace(
                self.weights,
                weights=np.array([f(x_value) for f in self.feature_functions])
            )
        )

    def evaluate(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Prediction for each x: feature matrix times weight vector.'''
        return np.dot(
            self.get_feature_values(x_values_seq),
            self.weights.weights
        )

    def within(self, other: FunctionApprox[X], tolerance: float) -> bool:
        '''Are the weights of both approximations within `tolerance'?
        Returns False if `other' is not a LinearFunctionApprox.
        '''
        if isinstance(other, LinearFunctionApprox):
            return self.weights.within(other.weights, tolerance)

        return False

    def regularized_loss_gradient(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> np.ndarray:
        '''Gradient with respect to the weights of the loss over the
        given (x, y) pairs: features^T . (prediction - y) averaged
        over the points, plus regularization_coeff times the weights.
        '''
        x_vals, y_vals = zip(*xy_vals_seq)
        feature_vals: np.ndarray = self.get_feature_values(x_vals)
        diff: np.ndarray = np.dot(feature_vals, self.weights.weights) \
            - np.array(y_vals)
        return np.dot(feature_vals.T, diff) / len(diff) \
            + self.regularization_coeff * self.weights.weights

    def update(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> LinearFunctionApprox[X]:
        '''Apply one Adam step along the regularized loss gradient of
        the given (x, y) pairs.
        '''
        gradient: np.ndarray = self.regularized_loss_gradient(xy_vals_seq)
        new_weights: Weights = self.weights.update(gradient)
        return replace(self, weights=new_weights)

    def solve(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]],
        error_tolerance: Optional[float] = None
    ) -> LinearFunctionApprox[X]:
        '''Fit the weights to the full data set.

        With direct_solve, the regularized normal equations
        (Phi^T Phi + n * lambda * I) w = Phi^T y are solved and the
        Adam state is reset; error_tolerance is not used. Otherwise
        `update' is applied repeatedly to the same data until two
        successive approximations are within error_tolerance
        (default 1e-6) of each other.
        '''
        if self.direct_solve:
            x_vals, y_vals = zip(*xy_vals_seq)
            feature_vals: np.ndarray = self.get_feature_values(x_vals)
            feature_vals_T: np.ndarray = feature_vals.T
            left: np.ndarray = np.dot(feature_vals_T, feature_vals) \
                + feature_vals.shape[0] * self.regularization_coeff * \
                np.eye(len(self.weights.weights))
            right: np.ndarray = np.dot(feature_vals_T, y_vals)
            # Solve the linear system directly rather than forming the
            # explicit inverse: cheaper and numerically more accurate.
            ret = replace(
                self,
                weights=Weights.create(
                    adam_gradient=self.weights.adam_gradient,
                    weights=np.linalg.solve(left, right)
                )
            )
        else:
            tol: float = 1e-6 if error_tolerance is None else error_tolerance

            def done(
                a: LinearFunctionApprox[X],
                b: LinearFunctionApprox[X],
                tol: float = tol
            ) -> bool:
                return a.within(b, tol)

            # The same data set is traversed on every iteration, so a
            # one-shot iterable must be materialized first.
            xy_vals: Sequence[Tuple[X, float]] = list(xy_vals_seq)
            ret = iterate.converged(
                self.iterate_updates(itertools.repeat(xy_vals)),
                done=done
            )

        return ret


@dataclass(frozen=True)
class DNNSpec:
    '''Immutable description of a feed-forward network's architecture,
    as consumed by DNNApprox.
    '''
    # Number of neurons in each hidden layer, in order; DNNApprox.create
    # appends a single output neuron after these.
    neurons: Sequence[int]
    # If True, DNNApprox prepends a constant-1 column to the output of
    # every hidden layer before feeding it to the next layer.
    bias: bool
    # Activation applied to the linear output of every hidden layer.
    hidden_activation: Callable[[np.ndarray], np.ndarray]
    # Derivative of hidden_activation; DNNApprox calls it on the layer's
    # (activated) output rather than on the pre-activation value.
    hidden_activation_deriv: Callable[[np.ndarray], np.ndarray]
    # Activation applied to the linear output of the final layer.
    output_activation: Callable[[np.ndarray], np.ndarray]
    # Derivative of output_activation; DNNApprox calls it on the network
    # output.
    output_activation_deriv: Callable[[np.ndarray], np.ndarray]


@dataclass(frozen=True)
class DNNApprox(FunctionApprox[X]):
    '''Approximates a function with a feed-forward neural network
    applied to feature values of X.

    Fields:
    feature_functions -- one callable per input feature
    dnn_spec -- layer sizes, bias flag and activation functions
    regularization_coeff -- coefficient of the L2 penalty on the weights
    weights -- one Weights per layer (hidden layers, then output layer),
      each holding an |O_l| x |I_l| array
    '''

    feature_functions: Sequence[Callable[[X], float]]
    dnn_spec: DNNSpec
    regularization_coeff: float
    weights: Sequence[Weights]

    @staticmethod
    def create(
        feature_functions: Sequence[Callable[[X], float]],
        dnn_spec: DNNSpec,
        adam_gradient: AdamGradient = AdamGradient.default_settings(),
        regularization_coeff: float = 0.,
        weights: Optional[Sequence[Weights]] = None
    ) -> DNNApprox[X]:
        '''Build a DNNApprox. When `weights' is not given, every layer
        is initialized with standard-normal draws scaled by
        1 / sqrt(number of inputs to the layer).
        '''
        if weights is None:
            # Input width of each layer: the features for the first
            # layer, then the previous layer's neurons (plus one for
            # the bias column when enabled).
            inputs: Sequence[int] = [len(feature_functions)] + \
                [n + (1 if dnn_spec.bias else 0)
                 for n in dnn_spec.neurons]
            outputs: Sequence[int] = list(dnn_spec.neurons) + [1]
            wts = [Weights.create(
                weights=np.random.randn(output, inp) / np.sqrt(inp),
                adam_gradient=adam_gradient
            ) for inp, output in zip(inputs, outputs)]
        else:
            wts = weights

        return DNNApprox(
            feature_functions=feature_functions,
            dnn_spec=dnn_spec,
            regularization_coeff=regularization_coeff,
            weights=wts
        )

    def get_feature_values(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Return the feature matrix: one row per x, one column per
        feature function.
        '''
        return np.array(
            [[f(x) for f in self.feature_functions] for x in x_values_seq]
        )

    def forward_propagation(
        self,
        x_values_seq: Iterable[X]
    ) -> Sequence[np.ndarray]:
        """
        :param x_values_seq: a n-length-sequence of input points
        :return: list of length (L+2) where the first (L+1) values
                 each represent the 2-D input arrays (of size n x |I_l|),
                 for each of the (L+1) layers (L of which are hidden layers),
                 and the last value represents the output of the DNN (as a
                 1-D array of length n)
        """
        inp: np.ndarray = self.get_feature_values(x_values_seq)
        ret: List[np.ndarray] = [inp]
        for w in self.weights[:-1]:
            out: np.ndarray = self.dnn_spec.hidden_activation(
                np.dot(inp, w.weights.T)
            )
            if self.dnn_spec.bias:
                # Prepend the constant-1 bias column.
                inp = np.insert(out, 0, 1., axis=1)
            else:
                inp = out
            ret.append(inp)
        ret.append(
            self.dnn_spec.output_activation(
                np.dot(inp, self.weights[-1].weights.T)
            )[:, 0]
        )
        return ret

    def evaluate(self, x_values_seq: Iterable[X]) -> np.ndarray:
        '''Network output for each x (last entry of forward_propagation).'''
        return self.forward_propagation(x_values_seq)[-1]

    def within(self, other: FunctionApprox[X], tolerance: float) -> bool:
        '''Are the weights of corresponding layers within `tolerance'?
        Returns False if `other' is not a DNNApprox.
        '''
        if isinstance(other, DNNApprox):
            return all(w1.within(w2, tolerance)
                       for w1, w2 in zip(self.weights, other.weights))
        else:
            return False

    def backward_propagation(
        self,
        fwd_prop: Sequence[np.ndarray],
        objective_derivative_output: Callable[[np.ndarray], np.ndarray]
    ) -> Sequence[np.ndarray]:
        """
        :param
        fwd_prop represents the result of forward propagation, a sequence
        of (L+1) 2-D np.ndarrays, followed by a 1-D np.ndarray for the output
        of the DNN.
        objective_derivative_output represents the derivative of the objective
        function with respect to the output of the DNN

        :return: list (of length L+1) of |O_l| x |I_l| 2-D arrays,
                 i.e., same as the type of self.weights.weights
        This function computes the gradient (with respect to weights) of
        cross-entropy loss where the output layer activation function
        is the canonical link function of the conditional distribution of y|x
        """
        layer_inputs: Sequence[np.ndarray] = fwd_prop[:-1]
        deriv: np.ndarray = objective_derivative_output(fwd_prop[-1]) * \
            self.dnn_spec.output_activation_deriv(fwd_prop[-1])
        deriv = deriv.reshape(1, -1)
        back_prop: List[np.ndarray] = [np.dot(deriv, layer_inputs[-1]) /
                                       deriv.shape[1]]
        # L is the number of hidden layers, n is the number of points
        # layer l deriv represents dObj/dS_l where S_l = I_l . weights_l
        # (S_l is the result of applying layer l without the activation func)
        for i in reversed(range(len(self.weights) - 1)):
            # deriv_l is a 2-D array of dimension |O_l| x n
            # The recursive formulation of deriv is as follows:
            # deriv_{l-1} = (weights_l^T inner deriv_l) hadamard g'(S_{l-1}),
            # which is ((|I_l| x |O_l|) inner (|O_l| x n)) hadamard
            # (|I_l| x n), which is (|I_l| x n) = (|O_{l-1}| x n)
            # Note: g'(S_{l-1}) is expressed as hidden layer activation
            # derivative as a function of O_{l-1} (=I_l).
            deriv = np.dot(self.weights[i + 1].weights.T, deriv) * \
                self.dnn_spec.hidden_activation_deriv(layer_inputs[i + 1].T)
            # If self.dnn_spec.bias is True, then I_l = O_{l-1} + 1, in which
            # case the first row of the calculated deriv is removed to yield
            # a 2-D array of dimension |O_{l-1}| x n.
            if self.dnn_spec.bias:
                deriv = deriv[1:]
            # layer l gradient is deriv_l inner layer_inputs_l, which is
            # of dimension (|O_l| x n) inner (n x (|I_l|) = |O_l| x |I_l|
            back_prop.append(np.dot(deriv, layer_inputs[i]) / deriv.shape[1])
        return back_prop[::-1]

    def representational_gradient(self, x_value: X) -> DNNApprox[X]:
        '''Gradient of the network output at x_value with respect to
        every layer's weights, returned as a DNNApprox whose per-layer
        weight arrays hold the gradient values.
        '''
        layer_gradients: Sequence[np.ndarray] = self.backward_propagation(
            fwd_prop=self.forward_propagation([x_value]),
            # d(output)/d(output) = 1, so back-propagation yields the
            # gradient of the output itself.
            objective_derivative_output=lambda arr: np.ones_like(arr)
        )
        # self.weights is a sequence of Weights, not a dataclass, so
        # dataclasses.replace must be applied to each layer separately.
        return replace(
            self,
            weights=[replace(w, weights=g)
                     for w, g in zip(self.weights, layer_gradients)]
        )

    def regularized_loss_gradient(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> Sequence[np.ndarray]:
        """
        :param xy_vals_seq: list of n (x, y) pairs
        :return: list (of length L+1) of |O_l| x |I_l| 2-D array,
                 i.e., same as the type of self.weights.weights
        This function computes the regularized gradient (with respect to
        weights) of cross-entropy loss where the output layer activation
        function is the canonical link function of the conditional
        distribution of y|x
        """
        x_vals, y_vals = zip(*xy_vals_seq)
        fwd_prop: Sequence[np.ndarray] = self.forward_propagation(x_vals)

        def obj_deriv_output(out: np.ndarray) -> np.ndarray:
            # backward_propagation multiplies by output_activation_deriv,
            # so dividing by it here leaves (out - y) at the output layer.
            return (out - np.array(y_vals)) / \
                self.dnn_spec.output_activation_deriv(out)

        return [x + self.regularization_coeff * self.weights[i].weights
                for i, x in enumerate(self.backward_propagation(
                    fwd_prop=fwd_prop,
                    objective_derivative_output=obj_deriv_output
                ))]

    def update(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]]
    ) -> DNNApprox[X]:
        '''Apply one Adam step to every layer along the regularized
        loss gradient of the given (x, y) pairs.
        '''
        return replace(
            self,
            weights=[w.update(g) for w, g in zip(
                self.weights,
                self.regularized_loss_gradient(xy_vals_seq)
            )]
        )

    def solve(
        self,
        xy_vals_seq: Iterable[Tuple[X, float]],
        error_tolerance: Optional[float] = None
    ) -> DNNApprox[X]:
        '''Apply `update' repeatedly to the same data set until two
        successive approximations are within error_tolerance
        (default 1e-6) of each other.
        '''
        tol: float = 1e-6 if error_tolerance is None else error_tolerance

        def done(
            a: DNNApprox[X],
            b: DNNApprox[X],
            tol: float = tol
        ) -> bool:
            return a.within(b, tol)

        # The same data set is traversed on every iteration, so a
        # one-shot iterable must be materialized first.
        xy_vals: Sequence[Tuple[X, float]] = list(xy_vals_seq)
        return iterate.converged(
            self.iterate_updates(itertools.repeat(xy_vals)),
            done=done
        )


def learning_rate_schedule(
    initial_learning_rate: float,
    half_life: float,
    exponent: float
) -> Callable[[int], float]:
    '''Build a learning-rate function of the update count n:

        initial_learning_rate * (1 + (n - 1) / half_life) ** -exponent

    The returned function yields initial_learning_rate at n = 1.
    '''
    def lr_func(n: int) -> float:
        elapsed: float = (n - 1) / half_life
        decay: float = (1 + elapsed) ** -exponent
        return initial_learning_rate * decay
    return lr_func


if __name__ == '__main__':
    # Demo: fit noisy data generated from a known linear model with
    # (1) a direct linear solve, (2) gradient updates on a linear
    # approximation and (3) gradient updates on a small DNN, printing the
    # weights and the mean squared error along the way.

    from scipy.stats import norm
    from pprint import pprint

    # Coefficients of the model that generates the data:
    #   y = alpha + beta_1 * x_1 + beta_2 * x_2 + beta_3 * x_3 + noise
    alpha = 2.0
    beta_1 = 10.0
    beta_2 = 4.0
    beta_3 = -6.0
    beta = (beta_1, beta_2, beta_3)

    # Inputs are all points of a regular 3-D grid (not random samples):
    # each axis runs from -10.0 in steps of 0.5, stopping before 10.5.
    x_pts = np.arange(-10.0, 10.5, 0.5)
    y_pts = np.arange(-10.0, 10.5, 0.5)
    z_pts = np.arange(-10.0, 10.5, 0.5)
    pts: Sequence[Tuple[float, float, float]] = \
        [(x, y, z) for x in x_pts for y in y_pts for z in z_pts]

    def superv_func(pt):
        '''Noise-free value of the generating model at the point pt.'''
        return alpha + np.dot(beta, pt)

    # Normally distributed noise (loc 0, scale 2) added to every target.
    # No seed is set, so the data differs from run to run.
    n = norm(loc=0., scale=2.)

    xy_vals_seq: Sequence[Tuple[Tuple[float, float, float], float]] = \
        [(x, superv_func(x) + n.rvs(size=1)[0]) for x in pts]

    # The noisy targets never change, so build this array once instead of
    # on every error evaluation in the loops below.
    targets: np.ndarray = np.array([y for _, y in xy_vals_seq])

    ag = AdamGradient(
        learning_rate=0.5,
        decay1=0.9,
        decay2=0.999
    )

    # Features: a constant 1 followed by the three raw coordinates.
    ffs = [
        lambda _: 1.,
        lambda x: x[0],
        lambda x: x[1],
        lambda x: x[2]
    ]

    lfa = LinearFunctionApprox.create(
         feature_functions=ffs,
         adam_gradient=ag,
         regularization_coeff=0.001,
         direct_solve=True
    )

    lfa_ds = lfa.solve(xy_vals_seq)
    print("Direct Solve")
    pprint(lfa_ds.weights)
    errors: np.ndarray = lfa_ds.evaluate(pts) - targets
    print("Mean Squared Error")
    pprint(np.mean(errors * errors))
    print()

    # Same linear approximation, now fitted by 100 gradient updates; the
    # state before each update is printed.
    print("Linear Gradient Solve")
    for _ in range(100):
        print("Weights")
        pprint(lfa.weights)
        errors = lfa.evaluate(pts) - targets
        print("Mean Squared Error")
        pprint(np.mean(errors * errors))
        lfa = lfa.update(xy_vals_seq)
        print()

    # All activations are the identity (with derivative 1 everywhere).
    # neurons=[2] presumably requests a single hidden layer of 2 units --
    # confirm against DNNSpec.
    ds = DNNSpec(
        neurons=[2],
        bias=True,
        hidden_activation=lambda x: x,
        hidden_activation_deriv=lambda x: np.ones_like(x),
        output_activation=lambda x: x,
        output_activation_deriv=lambda x: np.ones_like(x)
    )

    dnna = DNNApprox.create(
        feature_functions=ffs,
        dnn_spec=ds,
        adam_gradient=ag,
        regularization_coeff=0.01
    )
    print("DNN Gradient Solve")
    for _ in range(100):
        print("Weights")
        pprint(dnna.weights)
        errors = dnna.evaluate(pts) - targets
        print("Mean Squared Error")
        pprint(np.mean(errors * errors))
        dnna = dnna.update(xy_vals_seq)
        print()


Direct Solve
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=0, weights=array([ 1.99359097, 10.00293406,  3.99871943, -6.00280085]), adam_cache1=array([0., 0., 0., 0.]), adam_cache2=array([0., 0., 0., 0.]))
Mean Squared Error
3.950023558895854

Linear Gradient Solve
Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=0, weights=array([0., 0., 0., 0.]), adam_cache1=array([0., 0., 0., 0.]), adam_cache2=array([0., 0., 0., 0.]))
Mean Squared Error
5331.1088139477515

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=1, weights=array([ 0.49999975,  0.5       ,  0.5       , -0.5       ]), adam_cache1=array([ -0.19955846, -35.01126951, -13.99591788,  21.01040324]), adam_cache2=array([3.98235776e-03, 1.22578899e+02, 1.95885717e+01, 4.41437044e+01]))
Mean Squared Error
4655.437326573884

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999

Mean Squared Error
116.95852749746362

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=22, weights=array([ 1.55084209,  9.49072976,  4.67313681, -7.40946521]), adam_cache1=array([-5.39414687e-02, -1.06455621e+02,  1.66414800e+01,  5.58745119e+00]), adam_cache2=array([1.26881917e-02, 8.96019991e+02, 7.82036456e+01, 2.09711530e+02]))
Mean Squared Error
98.49216481543766

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=23, weights=array([ 1.61857358,  9.76062592,  4.51139963, -7.41006341]), adam_cache1=array([-9.28664850e-02, -9.76028250e+01,  1.73378603e+01,  1.05240133e-01]), adam_cache2=array([1.28719224e-02, 8.95445372e+02, 7.86826514e+01, 2.11925870e+02]))
Mean Squared Error
84.643983959659

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=24, weights=array([ 1.70747836, 10.00879198,  4.34749685, -7.38243484]), adam_cache1=array([ -0.12111908, -88.6


Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=45, weights=array([ 1.87921018, 10.81076661,  4.14459769, -5.52916564]), adam_cache1=array([-0.03138377, 23.45650463, -4.49184692,  0.06324589]), adam_cache2=array([1.33802267e-02, 9.03344752e+02, 8.02687426e+01, 2.23142884e+02]))
Mean Squared Error
35.38906796932144

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=46, weights=array([ 1.91588867, 10.72563377,  4.18674869, -5.54143294]), adam_cache1=array([-0.03969491, 23.93834887, -3.53207374,  1.7146919 ]), adam_cache2=array([1.33799556e-02, 9.03240880e+02, 8.02145439e+01, 2.23194562e+02]))
Mean Squared Error
30.913805530712473

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=47, weights=array([ 1.95648974, 10.63917009,  4.21713133, -5.56424501]), adam_cache1=array([-0.04350342, 24.07403522, -2.52074515,  3.15805652]), adam_cache2=array([1.33726253e-0


Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=68, weights=array([ 2.00758157,  9.72260172,  3.92794172, -6.14282719]), adam_cache1=array([-0.01374219, -1.93035067, -1.18286649, -1.81736957]), adam_cache2=array([1.31602275e-02, 8.86481059e+02, 7.90611988e+01, 2.19790471e+02]))
Mean Squared Error
7.56676466796805

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=69, weights=array([ 2.01994226,  9.73440609,  3.94702303, -6.12428964]), adam_cache1=array([-0.01096752, -2.71850683, -1.31230892, -2.12573883]), adam_cache2=array([1.31472634e-02, 8.85690852e+02, 7.89882745e+01, 2.19594700e+02]))
Mean Squared Error
7.088816003756412

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=70, weights=array([ 2.02815421,  9.74921932,  3.96697339, -6.10374752]), adam_cache1=array([-0.007233  , -3.3865309 , -1.36202061, -2.33838787]), adam_cache2=array([1.31348119e-02,

Mean Squared Error
4.190994294404467

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=91, weights=array([ 2.00472924, 10.08051436,  3.9692904 , -5.99437433]), adam_cache1=array([ 0.00234931, -0.3536547 ,  0.01538927,  0.92212156]), adam_cache2=array([1.28706916e-02, 8.66885034e+02, 7.73029304e+01, 2.14909689e+02]))
Mean Squared Error
4.1923398450044544

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=92, weights=array([ 2.00050638, 10.08074992,  3.97079472, -6.00307131]), adam_cache1=array([ 0.00322932, -0.04675044, -0.0891542 ,  0.85940304]), adam_cache2=array([1.28579452e-02, 8.66025523e+02, 7.72266884e+01, 2.14694866e+02]))
Mean Squared Error
4.187938072014309

Weights
Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=93, weights=array([ 1.99577393, 10.079583  ,  3.97381479, -6.01093333]), adam_cache1=array([ 0.00359862,  0.23028789, -0.17797805,  0.7725161


Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=6, weights=array([[ 0.66626954,  2.63156117,  1.18306754, -1.78302354],
       [ 0.57123334,  3.50963383,  1.77852027, -2.25594051]]), adam_cache1=array([[  0.73675805, -30.640608  ,  27.19126693,  26.72570997],
       [  1.54404913, -47.44511613,  55.87071503,  43.09606749]]), adam_cache2=array([[4.49418008e-02, 1.61946373e+02, 8.91366377e+01, 8.85899581e+01],
       [1.30104765e-01, 3.65496818e+02, 2.50984014e+02, 2.08906410e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=6, weights=array([[0.36686089, 1.51736461, 2.82937523]]), adam_cache1=array([[  0.42586469,  14.21803441, -30.1406272 ]]), adam_cache2=array([[3.26392354e-02, 5.63648918e+02, 1.47855681e+03]]))]
Mean Squared Error
1155.9325999279995

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=7, weights=array([[ 0.36586329,  2.66932997,  0.


Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=16, weights=array([[ 0.15046192,  2.51487728, -0.18689232, -1.5670811 ],
       [ 0.71953202,  3.6807606 ,  1.13897471, -2.17118561]]), adam_cache1=array([[ -0.10316561,  -5.00821259,  -0.1469685 ,   1.18223791],
       [ -1.01288284, -54.29275874, -63.86731577,  28.14595063]]), adam_cache2=array([[5.80520430e-02, 2.30385509e+02, 1.21979063e+02, 1.33756031e+02],
       [2.61807236e-01, 7.71049254e+02, 5.80671295e+02, 4.39449550e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=16, weights=array([[1.11759627, 0.48093015, 2.43925746]]), adam_cache1=array([[  -0.56123646,  -51.26563325, -115.85415914]]), adam_cache2=array([[5.78939004e-02, 1.70121840e+03, 3.71369743e+03]]))]
Mean Squared Error
66.22657680828276

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=17, weights=array([[ 0.16565461,  2.53642107


Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=26, weights=array([[-0.42605345,  2.20675394, -0.30015195, -1.28554294],
       [ 0.05961392,  3.5196025 ,  1.99465137, -2.05983121]]), adam_cache1=array([[ 0.15346605,  5.26695108,  4.82793563, -2.85127252],
       [ 0.31001933, -1.37670865, 14.36300661,  6.87004021]]), adam_cache2=array([[6.10398602e-02, 2.34386913e+02, 1.22841175e+02, 1.34540863e+02],
       [3.41921324e-01, 9.25300201e+02, 6.44032266e+02, 4.94837737e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=26, weights=array([[0.88160862, 0.1693132 , 2.25063681]]), adam_cache1=array([[ 0.03991885, -2.9988356 , 17.17959795]]), adam_cache2=array([[6.95965462e-02, 1.95665185e+03, 4.40713378e+03]]))]
Mean Squared Error
160.19040746281476

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=27, weights=array([[-0.46811645,  2.18563188, -0.33618052

26.76742462802575

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=36, weights=array([[-0.5532824 ,  2.0048976 , -0.53415031, -1.07366735],
       [ 0.58890029,  3.87099601,  1.57343007, -2.3455921 ]]), adam_cache1=array([[ 7.55217107e-02,  6.93847920e+00, -3.44223233e-01,
        -5.93819743e+00],
       [-4.23756253e-03,  1.34353215e+01, -1.59519753e+00,
        -1.44467678e+01]]), adam_cache2=array([[6.10572336e-02, 2.33827386e+02, 1.21860194e+02, 1.34399669e+02],
       [3.66541749e-01, 9.83483343e+02, 6.44221712e+02, 5.27097265e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=36, weights=array([[1.56983312, 0.49837515, 2.33565629]]), adam_cache1=array([[-3.48218527e-02,  1.86097173e+01,  4.24392259e+01]]), adam_cache2=array([[7.40401665e-02, 2.05999019e+03, 4.65806493e+03]]))]
Mean Squared Error
16.64239023911288

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, 

Mean Squared Error
26.55798509491582

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=46, weights=array([[-0.86768786,  1.85263927, -0.36866233, -0.92927432],
       [ 0.30537196,  3.92551371,  1.92899567, -2.40443135]]), adam_cache1=array([[-0.00958856,  1.88415457,  0.7090228 , -0.70185318],
       [-0.16247214,  1.22521272,  3.65501579,  2.60733958]]), adam_cache2=array([[6.08299365e-02, 2.31941157e+02, 1.21169372e+02, 1.33360927e+02],
       [3.72299799e-01, 9.86026853e+02, 6.51057113e+02, 5.30393678e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=46, weights=array([[1.34051738, 0.48882   , 2.43069388]]), adam_cache1=array([[-0.06697795, -0.38310837,  4.85721502]]), adam_cache2=array([[7.50275517e-02, 2.05292102e+03, 4.72792992e+03]]))]
Mean Squared Error
23.12205281940091

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=47, weights=array([[-

Mean Squared Error
6.0579297989330945

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=56, weights=array([[-0.84798731,  1.78728002, -0.45573136, -0.86657715],
       [ 0.43328097,  3.94001007,  1.73665341, -2.42288156]]), adam_cache1=array([[ 0.0307192 , -0.5086754 , -0.91266157, -0.11359302],
       [ 0.13881216, -5.4565451 , -4.2505153 ,  1.63486818]]), adam_cache2=array([[6.03285962e-02, 2.29860729e+02, 1.20196452e+02, 1.32105357e+02],
       [3.71304656e-01, 9.81909758e+02, 6.50266214e+02, 5.26898870e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=56, weights=array([[1.42626962, 0.51164186, 2.3707078 ]]), adam_cache1=array([[  0.0708615 ,  -3.8046287 , -13.12154662]]), adam_cache2=array([[7.48018530e-02, 2.03648463e+03, 4.72826421e+03]]))]
Mean Squared Error
7.821997152085809

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=57, weights=array

Mean Squared Error
5.686863805735459

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=66, weights=array([[-0.81988911,  1.73455909, -0.40905758, -0.81691543],
       [ 0.46062353,  3.89678945,  1.82540017, -2.38214748]]), adam_cache1=array([[-0.01715576, -0.30018916,  0.17527708,  0.46614288],
       [-0.05722672, -2.96393307,  1.05298874,  3.38282942]]), adam_cache2=array([[5.97683819e-02, 2.27722831e+02, 1.19013395e+02, 1.30865538e+02],
       [3.68363007e-01, 9.75589032e+02, 6.44115341e+02, 5.23434582e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=66, weights=array([[1.39791418, 0.4692962 , 2.34109751]]), adam_cache1=array([[-0.01228589, -3.48255541, -7.31510943]]), adam_cache2=array([[7.41753584e-02, 2.01985288e+03, 4.69885239e+03]]))]
Mean Squared Error
4.416684889369899

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=67, weights=array([[-

Mean Squared Error
4.949193833981243

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=76, weights=array([[-0.79687425,  1.72144589, -0.42950109, -0.80066991],
       [ 0.43203805,  3.9052922 ,  1.77472165, -2.38254623]]), adam_cache1=array([[-0.01286655, -0.0458387 , -0.42704991, -0.12716204],
       [-0.02464359, -0.81292903, -1.95525275, -0.17843618]]), adam_cache2=array([[5.91860112e-02, 2.25483340e+02, 1.17848262e+02, 1.29585784e+02],
       [3.65013397e-01, 9.66513612e+02, 6.38164780e+02, 5.18754078e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=76, weights=array([[1.31353218, 0.47816897, 2.33058326]]), adam_cache1=array([[ 1.01335835e-03, -1.51992696e-01, -2.61854784e+00]]), adam_cache2=array([[7.35003229e-02, 2.00036828e+03, 4.65744294e+03]]))]
Mean Squared Error
4.489832121476954

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=77, weigh

 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=85, weights=array([[1.28028221, 0.47540734, 2.33546619]]), adam_cache1=array([[0.01915799, 0.20306715, 2.002706  ]]), adam_cache2=array([[7.28573376e-02, 1.98256628e+03, 4.61754374e+03]]))]
Mean Squared Error
4.223265904635736

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=86, weights=array([[-0.74114816,  1.71299629, -0.42148129, -0.7979203 ],
       [ 0.4490433 ,  3.90187943,  1.78655301, -2.39215816]]), adam_cache1=array([[-0.00535482,  0.03878116,  0.21443735,  0.04408874],
       [ 0.01430674, -0.05519202,  1.11611676,  0.38403436]]), adam_cache2=array([[5.85994067e-02, 2.23249712e+02, 1.16684492e+02, 1.28299244e+02],
       [3.61446210e-01, 9.57155016e+02, 6.32037616e+02, 5.13673474e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=86, weights=array([[1.27108718, 0.47662984, 2.33475842]]), adam_cache1

Mean Squared Error
4.01616522554949

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=96, weights=array([[-0.679121  ,  1.71192653, -0.41600091, -0.79221496],
       [ 0.47165248,  3.9075333 ,  1.79446   , -2.38566438]]), adam_cache1=array([[-0.00740731,  0.09731927, -0.094291  , -0.07723793],
       [ 0.0035282 ,  0.34667835, -0.40547788, -0.29624813]]), adam_cache2=array([[5.80189526e-02, 2.21032759e+02, 1.15524296e+02, 1.27024910e+02],
       [3.57885704e-01, 9.47756315e+02, 6.25777747e+02, 5.08627463e+02]])),
 Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=96, weights=array([[1.23247475, 0.47794662, 2.3403175 ]]), adam_cache1=array([[0.01226263, 0.407236  , 0.50767598]]), adam_cache2=array([[7.20661545e-02, 1.96101660e+03, 4.56797854e+03]]))]
Mean Squared Error
4.0659963719708765

Weights
[Weights(adam_gradient=AdamGradient(learning_rate=0.5, decay1=0.9, decay2=0.999), time=97, weights=array([[-0.6