# In [None]:
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import numpy as np
np.random.seed(42)

import numpy as np
import pandas as pd
from numpy.linalg import lstsq, cholesky
from scipy.linalg import sqrtm, schur, block_diag
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from ClassBSplines import BSpline
from TensorProductSplines import TensorProductSpline
from PenaltyMatrices import PenaltyMatrix
from Smooth import Smooths as s

from ClassBSplines import BSpline as b
from TensorProductSplines import TensorProductSpline as t

class Model():
    """Additive regression model built from (shape-constrained) B-spline smooths.

    Each entry of the model description is one smooth term. The model basis is
    the column-wise concatenation of the individual smooth bases, and the
    penalty is the corresponding block-diagonal matrix of D^T V D terms.
    """

    # Constraint name -> difference-matrix factory (not read by the methods below).
    possible_penalties = { "smooth": PenaltyMatrix().D2_difference_matrix, 
                           "inc": PenaltyMatrix().D1_difference_matrix,
                           "dec": PenaltyMatrix().D1_difference_matrix,
                           "conc": PenaltyMatrix().D2_difference_matrix, 
                           "conv": PenaltyMatrix().D2_difference_matrix,
                           "peak": None }
    
    def __init__(self, descr, n_param=20):
        """
        descr : tuple - every entry describes one part of the model with the
                        scheme (smooth term, constraint, number of parameters), e.g.
                        descr = ( ("s(1)", "smooth", 10),
                                  ("s(2)", "inc", 10),
                                )
                        The number inside s(...) is the 1-based column of X used
                        by the term. The constraint must be a name understood by
                        check_constraint ("smooth", "inc", "dec", "conv", "conc", "no").
        n_param : int - currently unused; kept for backward compatibility.
        
        !!! currently only smooths with BSpline basis !!!
        
        TODO:
            [ ] incorporate tensor product splines
        """
        self.description_str = descr
        self.coef_ = None
        # term string -> (constraint, number of parameters).
        # NOTE: repeated term strings overwrite each other here.
        self.description_dict = {term: (penalty, n) for term, penalty, n in self.description_str}
        self.smooths = None

    @staticmethod
    def _column_index(term, n_columns):
        """Return the 0-based column of X addressed by a 1-based term such as "s(3)"."""
        # Parse everything between the parentheses so multi-digit indices work.
        inner = term[term.index("(") + 1:term.index(")")]
        col = int(inner) - 1
        # Without this check "s(0)" would silently index column -1 (the last one).
        if not 0 <= col < n_columns:
            raise ValueError(
                f"Term {term!r} addresses column {col + 1} (1-based), "
                f"but X has {n_columns} columns."
            )
        return col

    def create_basis(self, X):
        """Create the unpenalized BSpline basis for the data X.

        Builds one Smooths object per model term (stored in self.smooths) and
        concatenates their bases column-wise into self.basis_without_penalty.
        
        Parameters:
        ------------
        X : np.ndarray or pd.DataFrame - data, one column per smooth term
        """
        X = np.asarray(X)
        assert (len(self.description_str) == X.shape[1]),"Nr of smooths must match Nr of predictors!"

        # NOTE(review): the original referenced an undefined `X_norm`, which
        # suggests normalization was planned; X is used as given here.
        self.smooths = [
            s(x_data=X[:, self._column_index(term, X.shape[1])],
              n_param=int(n_param),
              penalty=penalty)
            for term, (penalty, n_param) in self.description_dict.items()
        ]
        self.basis_without_penalty = np.concatenate(
            [smooth.basis for smooth in self.smooths], axis=1)
        return
    
    def create_penalty_block_matrix(self, X, beta_test=None):
        """Create the penalty block matrix specified in self.description_str.
        
        Looks like: ------------
                    |p1 0  0  0|  
                    |0 p2  0  0|
                    |0  0 p3  0|
                    |0  0  0 p4|
                    ------------
        where p_i = D_i^T V_i D_i, with D_i the smooth's penalty matrix and V_i
        the constraint weights from check_constraint evaluated on the part of
        beta_test that belongs to smooth i.

        Parameters:
        ---------------
        X : np.ndarray  - data (unused; kept for interface compatibility)
        beta_test : np.ndarray or None - current coefficients; zeros if None

        Sets self.penalty_matrix_list (the p_i) and the block-diagonal matrix
        as both self.complete_penalty_matrix and the legacy attribute
        self.complete_penalty_matrix_with_weigths.
        """
        assert (self.smooths is not None), "Run Model.create_basis() first!"
        
        if beta_test is None:
            beta_test = np.zeros(self.basis_without_penalty.shape[1])
        beta_test = np.ravel(beta_test)

        blocks = []
        start = 0
        for smooth in self.smooths:
            n_coef = smooth.basis.shape[1]
            beta_part = beta_test[start:start + n_coef]

            D = smooth.penalty_matrix
            V = check_constraint(beta=beta_part, constraint=smooth.penalty)
            if V is None:
                raise ValueError(f"Unsupported constraint {smooth.penalty!r}")

            blocks.append(D.T @ V @ D)
            start += n_coef

        self.penalty_matrix_list = blocks
        self.complete_penalty_matrix_with_weigths = block_diag(*blocks)
        # fit() reads this name; keep the misspelled attribute for old callers.
        self.complete_penalty_matrix = self.complete_penalty_matrix_with_weigths
        return
        
    def fit(self, X, y, plot_=True, max_iter=3, tol=1e-3):
        """Fit the model by iterated penalized least squares using Smooths.

        An unpenalized least-squares fit gives the start value beta_0. Each
        iteration rebuilds the constraint penalty P for the current beta and solves
            (B^T B + P) beta_new = B^T y,
        with B the unpenalized basis. It stops when sum((beta - beta_new)**2) < tol
        or after max_iter iterations.
        
        Parameters:
        -------------
        X : pd.DataFrame or np.ndarray - predictors, one column per smooth
        y : pd.DataFrame or np.array   - response (flattened to 1-D)
        plot_ : boolean                - plot data and fit over the first column of X
        max_iter : int                 - maximum number of penalized iterations
        tol : float                    - convergence threshold on the squared change in beta

        Returns:
        -------------
        tuple : (coef, residuals, rank, singular_values, fitted_values)
            coef is the final penalized estimate (also stored in self.coef_).
            residuals, rank and singular_values come from numpy.linalg.lstsq for
            the initial unpenalized fit. fitted_values is basis @ coef.
        """
        X = np.asarray(X)
        # Flatten y so beta stays 1-D; a (n, 1) y would make beta - beta_new
        # broadcast to a (p, p) matrix.
        y = np.asarray(y, dtype=float).reshape(-1)

        # create the basis for the initial fit without penalties
        self.create_basis(X)    
        print("Initial Least Squares fit to get beta_0!")
        fitting = lstsq(a=self.basis_without_penalty, b=y, rcond=None)
        beta = fitting[0].ravel()

        B = self.basis_without_penalty
        BB = B.T @ B
        By = B.T @ y
        for i in range(max_iter):
            print("Create basis with penalty and weight")
            self.create_penalty_block_matrix(X, beta_test=beta)
            
            print("Least squares fit iteration ", i)
            # The block-diagonal matrix already holds D^T V D per smooth, so it
            # is added directly (not squared again).
            P = self.complete_penalty_matrix
            beta_new = np.linalg.solve(BB + P, By)
            
            sse_beta = np.sum((beta - beta_new)**2)
            print("SSE_beta = ", np.round(sse_beta, 8))
            beta = beta_new
            if sse_beta < tol:
                print("Iteration converged!")
                break
            else:
                print("Iteration continous!")

        self.coef_ = beta
        fitted = B @ self.coef_
        fitting = (self.coef_, *fitting[1:], fitted)
    
        self.mse = mean_squared_error(y, fitted)
        print(f"Mean squared error on the data: {np.round(self.mse, 4)}")
        
        if plot_:
            fig = self.plot_xy(x=X[:, 0], y=y, name="Data")
            fig.add_trace(go.Scatter(x=X[:, 0], y=fitted, name="Fit", mode="markers"))
            fig.show()
        return fitting
    
    # not trusted
    def predict(self, X, y=None, plot_=False):
        """Prediction of the trained model on the data in X.

        Returns basis(X) @ self.coef_, or None (after printing a message) if the
        model has not been fitted.

        NOTE(review): the basis is rebuilt from X itself via create_basis, which
        also replaces self.smooths. Whether Smooths places its knots from the
        given data (so prediction knots would differ from the training knots)
        cannot be seen here. Verify this before trusting predictions on new data.
        """
        if self.coef_ is None:
            print("Model untrained!")
            return

        X = np.asarray(X)
        self.create_basis(X)
        
        print("Shape prediction basis: ", self.basis_without_penalty.shape)
        print("Shape coef_: ", self.coef_.shape)
        pred = self.basis_without_penalty @ self.coef_
        if plot_:
            fig = self.plot_xy(x=X[:, 0], y=pred, name="Prediction")
            if y is not None:
                y = np.asarray(y)
                print("shape y: ", y.shape)
                fig.add_trace(go.Scatter(x=X[:, 0], y=y.reshape((-1,)), name="Data", mode="markers"))
            fig.show()
        return pred
    
    def plot_xy(self, x, y, title="Titel", name="Data", xlabel="xlabel", ylabel="ylabel"):
        """Basic scatter plot of y over x; returns the plotly Figure."""
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=x, y=y, name=name, mode="markers"))
        fig.update_layout(title=title)
        fig.update_xaxes(title=xlabel)
        fig.update_yaxes(title=ylabel)
        return fig
    
    def plot_basis(self, matrix):
        """Plot a numeric 2-D matrix (e.g. a basis or penalty matrix) as a heatmap."""
        # go.Image expects a colour (3/4 numbers) per pixel, so a plain numeric
        # matrix is drawn as a heatmap. Reverse y so row 0 is on top, as when printed.
        fig = go.Figure(go.Heatmap(z=matrix))
        fig.update_yaxes(autorange="reversed")
        fig.show()
        return
                        
    
def check_constraint(beta, constraint, print_idx=False):
    """Build the diagonal weight matrix V marking where beta violates a constraint.

    Parameters
    ----------
    beta : array-like
        Coefficients of one smooth.
    constraint : str
        One of "inc", "dec", "conv", "conc", "no", "smooth".
    print_idx : bool
        If True, print the indices whose weight is 1.

    Returns
    -------
    np.ndarray or None
        Diagonal matrix V:
          - "inc"/"dec": len(beta) - 1 rows, one per first difference.
          - "conv"/"conc"/"smooth": len(beta) - 2 rows, one per second difference.
          - "no": len(beta) rows of zeros.
        A diagonal entry is 1 where the constraint is violated; for "smooth"
        every entry is 1. For unknown constraints a message is printed and
        None is returned.
    """
    b_diff = np.diff(beta)
    b_diff_diff = np.diff(b_diff)
    # Compare with ``==``: ``is`` tests object identity and fails for equal
    # strings that are not the same object (e.g. strings built at runtime).
    if constraint == "inc":
        v = [0 if d > 0 else 1 for d in b_diff]
    elif constraint == "dec":
        v = [0 if d < 0 else 1 for d in b_diff]
    elif constraint == "conv":
        v = [0 if d > 0 else 1 for d in b_diff_diff]
    elif constraint == "conc":
        v = [0 if d < 0 else 1 for d in b_diff_diff]
    elif constraint == "no":
        v = np.zeros(len(beta))
    elif constraint == "smooth":
        # One weight per second difference, so V matches the (n-2) x n
        # second-order difference matrix used for "smooth" (as for conv/conc).
        v = np.ones(len(b_diff_diff))
    else:
        print(f"Constraint [{constraint}] not implemented!")
        return
    
    V = np.diag(v)
    if print_idx:
        print("Constraint violated at the followin indices: ")
        print([idx for idx, n in enumerate(v) if n == 1])
    return V
    
    
    