# Developer Guide

This notebook serves as an interactive example of the [developer's guide](https://orbithunter.readthedocs.io/en/latest/guide.html). Its goal is to show how easy it is to implement your own equations,
and how relatively few functions and methods are required to gain access to all of the tools that orbithunter provides. To do so, we shall use a simple equation as a toy example.

Let's create an `Orbit` subclass whose states solve the equation
\begin{equation}
    x^2 - \log(x) - a = 0
\end{equation}

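where $x$ and $a$ are vectors of arbitrary (but identical) length. Note that the code cell below (`OrbitLZ`) actually implements the Lorenz system, $\dot{x} = \sigma (y - x)$, $\dot{y} = x(\rho - z) - y$, $\dot{z} = xy - bz$, with parameters $\sigma$, $\rho$, $b$ and period $t$.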

We can now move on to the methods which actually evaluate the functions. I will
be working under the assumption that the default cost function, $\frac{1}{2}\|F\|^2$, is to be used.
We must implement the transforms first. Because the parameters are real-valued, it will benefit us to keep the equations real-valued as well.


In [None]:
class OrbitLZ(orb.Orbit):

    def transform(self, to=None):
        """
        Return this orbit expressed in the requested basis.

        Parameters
        ----------
        to : str
            Target basis, ``"xyz"`` or ``"modes"`` (the labels returned by ``bases_labels``).

        Returns
        -------
        OrbitLZ :
            ``self`` when no conversion is needed. Otherwise the result of
            ``self._inv_time_transform()`` (modes -> xyz) or ``self._time_transform()``
            (xyz -> modes). Both helpers are not defined in this cell and are presumably
            provided by the base class.

        Raises
        ------
        ValueError :
            If the state is unpopulated, the current basis is unknown, or ``to`` is not a
            recognized basis.
        """
        if self.state is None:
            raise ValueError(
                "Trying to transform an unpopulated {} instance.".format(str(self))
            )
        elif self.basis is None:
            # The message needs a placeholder; without one .format() silently drops the
            # instance description.
            raise ValueError(
                "Trying to transform state with unknown basis: {}".format(str(self))
            )

        if to == "xyz":
            if self.basis == "modes":
                return self._inv_time_transform()
            # Any other current basis is returned unchanged.
            return self
        elif to == "modes":
            if self.basis == "xyz":
                return self._time_transform()
            return self
        else:
            raise ValueError(
                "Trying to transform to unrecognized basis: {!r}".format(to)
            )


    def dimensions(self):
        """
        Continuous dimensions of the orbit.

        Returns
        -------
        tuple :
            One-element tuple ``(t,)`` containing the period. The coordinate axis is discrete
            (see ``continuous_dimensions``), so it contributes no entry.

        Notes
        -----
        This is a subset of ``self.parameters``, which is why it is a plain method rather
        than a property. It exists for readability wherever only the dimensions are needed.
        """
        period = self.t
        return (period,)
    
    @staticmethod
    def bases_labels():
        """Names of the two bases understood by ``transform``: ``"xyz"`` and ``"modes"``."""
        labels = ("xyz", "modes")
        return labels
    
    @staticmethod
    def parameter_labels():
        """Names of the orbit parameters: period ``t`` and Lorenz coefficients ``sigma``, ``rho``, ``b``."""
        labels = ("t", "sigma", "rho", "b")
        return labels
    
    @staticmethod
    def discretization_labels():
        """
        Labels of the two discretization axes.

        ``'n'`` counts the time points (rows of the state). ``'ijk'`` is the coordinate axis;
        space is treated as a single vector of three coordinates.
        """
        labels = ("n", "ijk")
        return labels
    
    @staticmethod
    def dimension_labels():
        """Labels pairing each state axis with its dimension: time ``'t'`` and coordinates ``'ijk'``."""
        labels = ("t", "ijk")
        return labels
    
    @staticmethod
    def minimal_shapes():
        """Smallest allowed discretization: 2 time points by 3 coordinate components."""
        min_time_points, min_components = 2, 3
        return (min_time_points, min_components)
    
    @staticmethod
    def minimal_shape_increments():
        """
        Increment per axis when changing the discretization.

        Time uses 1. The coordinate axis uses 0, presumably meaning it is never resized
        (TODO confirm against the base class).
        """
        time_step, component_step = 1, 0
        return (time_step, component_step)
    
    @staticmethod
    def continuous_dimensions():
        """Flags per axis: time is a continuous dimension, the coordinate axis is not."""
        time_is_continuous, coords_are_continuous = True, False
        return (time_is_continuous, coords_are_continuous)

    @staticmethod
    def _default_shape():
        """Default discretization: 4 time points by 3 coordinate components."""
        default_time_points, default_components = 4, 3
        return (default_time_points, default_components)

    @staticmethod
    def _default_parameter_ranges():
        """
        Default parameter values and ranges.

        The period ``t`` gets the interval ``(0, 100)``, presumably a sampling range. The
        Lorenz coefficients are fixed at sigma=10, rho=28 and b=8/3.
        """
        return dict(t=(0, 100), sigma=10, rho=28, b=8.0 / 3.0)
    
    @staticmethod
    def _default_constraints():
        """Default constraint flags: only the period ``t`` is left unconstrained."""
        return dict(t=False, sigma=True, rho=True, b=True)
    
    @classmethod
    def _dimension_indexing_order(cls):
        """
        Per-axis indexing-order flags for the (time, coordinate) axes.

        NOTE(review): the meaning of these flags is defined by the base class, which is not
        visible here — confirm there.
        """
        time_flag, coordinate_flag = False, True
        return (time_flag, coordinate_flag)
    
    def orbit_vector(self):
        """
        Column vector of all degrees of freedom.

        Returns
        -------
        ndarray :
            The flattened state followed by the parameters, shaped as a single column
            ``(N, 1)``.
        """
        flat = np.concatenate((self.state.ravel(), self.parameters))
        return flat[:, np.newaxis]
        
    def plot(self):
        """
        Draw the trajectory in physical (x, y, z) space and show the figure.

        The orbit is transformed to the ``"xyz"`` basis first, so this works from either basis.
        """
        # Importing mplot3d registers the '3d' projection on older Matplotlib versions.
        from mpl_toolkits import mplot3d  # noqa: F401
        import matplotlib.pyplot as plt

        fig = plt.figure()
        # Build the 3-D axes on the figure created above instead of relying on pyplot's
        # implicit "current figure" state.
        ax = fig.add_subplot(111, projection='3d')
        xyz = self.transform(to="xyz").state
        ax.plot3D(xyz[:, 0], xyz[:, 1], xyz[:, 2], 'gray')
        ax.set_xlabel('x')
        ax.set_ylabel('y')
        ax.set_zlabel('z')
        plt.show()
        
        
    def eqn(self, **kwargs):
        """
        Residual of the Lorenz equations, F = dv/dt - f(v), evaluated in the modes basis.

        With v = (x, y, z) the three residual components are

            F_x = x_t - sigma * y + sigma * x
            F_y = y_t - rho * x + y + x * z
            F_z = z_t - x * y + b * z

        Returns
        -------
        OrbitLZ :
            New instance in the ``'modes'`` basis. Its state holds the columns
            (F_x, F_y, F_z), and ``self.parameters`` is passed through.
        """
        # The coordinates are differentiated independently, so a column-wise ``.dt()``
        # (as in OrbitKS) can be reused unchanged.
        spectral = self.transform(to="modes").copy()
        physical = self.transform(to="xyz").copy()
        derivative = spectral.dt()

        # Form the nonlinear products x*y and x*z pointwise in physical space by overwriting
        # the y and z columns of the copy. Then take them back to the modes basis.
        x_phys = physical.state[:, 0]
        physical.state[:, 1] = x_phys * physical.state[:, 1]
        physical.state[:, 2] = x_phys * physical.state[:, 2]
        products = physical.transform(to='modes').state
        xy_modes = products[:, 1]
        xz_modes = products[:, 2]

        x_modes = spectral.state[:, 0]
        y_modes = spectral.state[:, 1]
        z_modes = spectral.state[:, 2]
        x_t, y_t, z_t = derivative.state.T

        f_x = x_t - self.sigma * y_modes + self.sigma * x_modes
        f_y = y_t - self.rho * x_modes + y_modes + xz_modes
        f_z = z_t - xy_modes + self.b * z_modes

        residual = np.stack((f_x, f_y, f_z), axis=1)
        return self.__class__(state=residual, basis='modes', parameters=self.parameters)
        
    def matvec(self, other, **kwargs):
        """
        Jacobian-vector product ``J @ other`` of the Lorenz residual (not yet implemented).

        Parameters
        ----------
        other : OrbitLZ
            Perturbation to which the Jacobian would be applied.
        kwargs :
            Unused.

        Raises
        ------
        NotImplementedError :
            Always. The previous body returned the undefined name ``matvec_orbit`` and so
            crashed with ``NameError``; this makes the missing implementation explicit.
        """
        # TODO: implement. For a perturbation (dx, dy, dz) about (x, y, z), linearizing the
        # residual in ``eqn`` gives the state part
        #   dx_t - sigma*dy + sigma*dx
        #   dy_t - rho*dx + dy + x*dz + z*dx
        #   dz_t - x*dy - y*dx + b*dz
        # plus contributions from any unconstrained parameters (e.g. the period t).
        raise NotImplementedError(f"matvec is not yet defined for {str(self)}")
        
    def rmatvec(self, other, **kwargs):
        """
        Adjoint product ``J^T @ other`` of the Lorenz residual (not yet implemented).

        Parameters
        ----------
        other : OrbitLZ
            Vector (typically the residual from ``eqn``) to which ``J^T`` would be applied.
            The earlier draft required both ``other`` and ``self`` to be in the ``'modes'``
            basis.
        kwargs :
            Unused.

        Raises
        ------
        NotImplementedError :
            Always. The previous body returned the undefined name ``rmatvec_orbit`` and so
            crashed with ``NameError``; this makes the missing implementation explicit.
        """
        # TODO: implement as the transpose of the linearization sketched for ``matvec``.
        # Validate bases with an explicit ValueError rather than ``assert`` (asserts are
        # stripped under ``python -O``).
        raise NotImplementedError(f"rmatvec is not yet defined for {str(self)}")
        
    def jacobian(self, **kwargs):
        """
        Placeholder for the explicit Jacobian matrix.

        Raises
        ------
        AttributeError :
            Always; no Jacobian is implemented for this class.
        """
        message = "Jacobian is not yet defined for {}".format(str(self))
        raise AttributeError(message)
        
    def preconditioning_parameters(self):
        """Return the period ``t``, presumably consumed by the base class's preconditioning."""
        period = self.t
        return period
    
    def costgrad(self, *args, **kwargs):
        """
        Gradient of the cost function $1/2 |F|^2$.

        Parameters
        ----------
        *args :
            Optionally a precomputed residual as the first positional argument (an instance
            returned by ``eqn``). When omitted, ``self.eqn()`` is evaluated.
        kwargs :
            Forwarded to ``rmatvec``. When ``kwargs['preconditioning']`` is truthy, they are
            also forwarded to ``precondition``.

        Returns
        -------
        gradient :
            Instance whose state contains $(dF/dv)^T * F = J^T F$, optionally preconditioned.

        Notes
        -----
        "Preconditioning" here is a numerical rescaling of the gradient used by descent
        methods. ``self.parameters`` is supplied as ``pmult`` unless ``kwargs`` already
        contains a ``pmult`` entry, which then takes precedence.
        """
        residual = args[0] if args else self.eqn()
        gradient = self.rmatvec(residual, **kwargs)
        if not kwargs.get('preconditioning'):
            return gradient
        precondition_kwargs = {"pmult": self.parameters, **kwargs}
        return gradient.precondition(**precondition_kwargs)


    def _parse_state(self, state, basis, **kwargs):
        """
        Store ``state`` and derive ``basis`` and ``discretization`` from it.

        Parameters
        ----------
        state : ndarray or None
            Two-dimensional array of shape ``(n, ijk)``: ``n`` rows (time points) by ``ijk``
            coordinate columns. Any value that is not an ndarray is replaced by an empty
            ``(0, 0)`` float array.
        basis : str or None
            Basis label for ``state``. It is stored only when the state is non-empty;
            otherwise ``self.basis`` is set to None.
        kwargs :
            Unused.

        Raises
        ------
        ValueError :
            If ``state`` is an ndarray that is not two-dimensional.

        Warns
        -----
        RuntimeWarning :
            If either axis is smaller than the value given by ``minimal_shapes()``.
        """
        import warnings

        if isinstance(state, np.ndarray):
            if state.ndim != 2:
                raise ValueError('"state" array must be two-dimensional')
            self.state = state
        else:
            # NOTE(review): non-ndarray input (e.g. a list) is discarded rather than converted.
            self.state = np.array([], dtype=float).reshape(0, 0)

        if self.state.size > 0:
            # The discretization is read directly from the array's shape; the method
            # providing the minimum sizes is ``minimal_shapes``.
            n, ijk = self.state.shape
            min_n, min_ijk = self.minimal_shapes()
            if n < min_n or ijk < min_ijk:
                warn_str = "\nminimum discretization requirements not met; methods may not work as intended."
                warnings.warn(warn_str, RuntimeWarning)
            self.basis = basis
            self.discretization = n, ijk
        else:
            self.discretization = None
            self.basis = None