**Tensor product spline implementation**

In [10]:
# convert jupyter notebook to python script
!jupyter nbconvert --to script TensorProductSplines.ipynb

[NbConvertApp] Converting notebook TensorProductSplines.ipynb to script
[NbConvertApp] Writing 2532 bytes to TensorProductSplines.py


In [1]:
import numpy as np
import plotly.graph_objects as go

from plotly.subplots import make_subplots
from scipy.sparse import kron

from ClassBSplines import BSpline
from PenaltyMatrices import PenaltyMatrix


class TensorProductSpline(BSpline):
    """Implementation of the tensor product spline according to Simon Wood, 2006."""
    
    def __init__(self, x1=None, x2=None):
        """It is important that len(x1) == len(x2)."""
        self.x1 = x1
        self.x2 = x2
        self.basis = None
        
    def tensor_product_spline_2d(self, k1=5, k2=5, print_shapes=False):
        """Calculate the TPS from two 1d B-splines.
        
        Parameters:
        -------------
        k1 : integer   - Number of knots for the first B-spline.
        k2 : integer   - Number of knots for the second B-Spline.
        print_shape : bool - prints the dimensions of the basis matrices.
        
        """
        self.k1 = k1
        self.k2 = k2
        BSpline_x1 = BSpline(self.x1)
        BSpline_x2 = BSpline(self.x2)
        BSpline_x1.b_spline_basis(k=self.k1)
        BSpline_x2.b_spline_basis(k=self.k2)
        
        BSpline_x1.plot_basis("1st B-Spline basis")
        BSpline_x2.plot_basis("2nd B-Spline basis")
        
        
        self.X1 = BSpline_x1.basis
        self.X2 = BSpline_x2.basis
        self.basis = kron(self.X1, self.X2).toarray()

        if print_shapes:
            print("Shape of the first basis: ", self.X1.shape)
            print("Shape of the second basis: ", self.X2.shape)
            print("Shape of the tensor product basis: ", self.basis.shape)
        return
        
    def plot_basis(self):
        """Plot the tensor product spline basis matrix for a 2d TPS."""
        fig = go.Figure()
        x1g, x2g = np.meshgrid(self.x1, self.x2)
        #print("x1g: ", x1g.shape)
        #print("x2g: ", x2g.shape)
        for i in range(self.basis.shape[1]):
            fig.add_trace(
                go.Surface(
                    x=x1g, y=x2g,
                    z=self.basis[:,i].reshape((self.X2.shape[0], self.X1.shape[0])),
                    name=f"TPS Basis {i+1}",
                    showscale=False
                )
            )
                
        fig.update_layout(
            scene=dict(
                xaxis_title="x1",
                yaxis_title="x2",
                zaxis_title=""
            ),
            title="Tensor product spline basis", 
        )
        fig.show()
        return

    
    def plot_basis_individuel(self):

        dim = self.basis.shape
        dim_resh_1 = int(np.sqrt(dim[0]))
        dim_resh_2 = int(np.sqrt(dim[1]))

        
        fig = make_subplots(rows=dim_resh_2, cols=dim_resh_2)

        for i in range(dim[1]):
            data = self.basis[:,i].reshape((dim_resh_1, dim_resh_1))
            if i < dim_resh_2:
                fig.add_trace(go.Heatmap(z=data), row=1, col=i+1)
            elif i >= dim_resh_2 and i < 2*dim_resh_2:
                fig.add_trace(go.Heatmap(z=data), row=2, col=i+1-dim_resh_2)
            elif i >= 2*dim_resh_2 and i < 3*dim_resh_2:
                fig.add_trace(go.Heatmap(z=data), row=3, col=i+1-2*dim_resh_2)
            elif i >= 3*dim_resh_2 and i < 4*dim_resh_2:
                fig.add_trace(go.Heatmap(z=data), row=4, col=i+1-3*dim_resh_2)
            elif i >= 4*dim_resh_2:
                fig.add_trace(go.Heatmap(z=data), row=5, col=i+1-4*dim_resh_2)

        fig.update_traces(showscale=False)
        fig.show()


In [3]:
x1 = np.linspace(0,3,6)
x2 = np.linspace(0,3,6)
x1g, x2g = np.meshgrid(x1, x2)

y = np.sin(np.cos(x1) * x2g)

import plotly.express as px

px.scatter_3d(x=x1g.ravel(), y=x2g.ravel(), z=y.ravel()).show()


In [4]:
T = TensorProductSpline(x1g.ravel(), x2g.ravel())
T.tensor_product_spline_2d(k1=5, k2=5)
T.plot_basis()
#T.plot_basis_individuel()

'x' from initialization is used for the spline basis!
'x' from initialization is used for the spline basis!


In [18]:
from numpy.linalg import lstsq

fit = lstsq(a=T.basis, b=y.ravel())

LinAlgError: Incompatible dimensions

In [19]:
T.basis.shape

(1296, 25)

In [22]:
y.ravel().shape

(36,)

In [24]:
px.imshow(T.basis[:,0].reshape((36,36)))
