**Tensor product spline implementation**

In [3]:
# convert jupyter notebook to python script
# !jupyter nbconvert --to script TensorProductSplines.ipynb

[NbConvertApp] Converting notebook TensorProductSplines.ipynb to script
[NbConvertApp] Writing 2084 bytes to TensorProductSplines.py


In [102]:
import numpy as np
import plotly.graph_objects as go

from scipy.sparse import kron

from ClassBSplines import BSpline
from PenaltyMatrices import PenaltyMatrices


class TensorProductSpline(BSpline):
    """Implementation of the tensor product spline according to Simon Wood, 2006"""
    
    def __init__(self, x1, x2):
        self.x1 = x1
        self.x2 = x2
        self.basis = None
        
    def tensor_product_spline_2d(self, k1=5, k2=5, print_shapes=False):
        """Calculate the TPS from two 1d B-splines.
        
        Parameters:
        -------------
        k1 : integer   - Number of knots for the first B-spline.
        k2 : integer   - Number of knots for the second B-Spline.
        print_shape : bool - prints the dimensions of the basis matrices.
        
        """
        self.k1 = k1
        self.k2 = k2
        BSpline_x1 = BSpline(self.x1)
        BSpline_x2 = BSpline(self.x2)
        BSpline_x1.b_spline_basis(k=self.k1)
        BSpline_x2.b_spline_basis(k=self.k2)
        self.X1 = BSpline_x1.basis
        self.X2 = BSpline_x2.basis
        self.basis = kron(self.X1, self.X2).toarray()

        if print_shapes:
            print("Shape of the first basis: ", self.X1.shape)
            print("Shape of the second basis: ", self.X2.shape)
            print("Shape of the tensor product basis: ", self.basis.shape)
        return
        
    def plot_tensor_product_spline_basis(self):
        """Plot the tensor product spline basis matrix for a 2d TPS."""
        fig = go.Figure()
        x1g, x2g = np.meshgrid(self.x1, self.x2)
        print("x1g: ", x1g.shape)
        print("x2g: ", x2g.shape)
        B = self.basis.reshape((-1, self.X2.shape[0], self.X1.shape[0]))
        print("Shape B: ", B.shape)
        for i in range(self.basis.shape[1]):
            fig.add_trace(
                go.Surface(
                    x=x1g, y=x2g, 
                    z=self.basis[:,i].reshape((self.X2.shape[0], self.X1.shape[0])),
                    name=f"TPS Basis {i+1}")
            )
            if i == 0: 
                print("\t Shape z: ", self.basis[:,i].reshape((self.X2.shape[0], self.X1.shape[0])).shape)
                print("\t Shape x1g: ", x1g.shape)
                print("\t Shape x2g: ", x2g.shape)
                
        fig.update_layout(
            scene=dict(
                xaxis_title="x1",
                yaxis_title="x2",
                zaxis_title=""
            ),
            title="Tensor product spline basis", 
        )
        fig.show()
        return


In [104]:
import numpy as np
x1 = np.linspace(0,10,25)
x2 = np.linspace(-3,-1,25)

TPS = TensorProductSpline(x1, x2)
TPS.tensor_product_spline_2d(print_shapes=True)
TPS.plot_tensor_product_spline_basis()

Shape of the first basis:  (25, 5)
Shape of the second basis:  (25, 5)
Shape of the tensor product basis:  (625, 25)
x1g:  (25, 25)
x2g:  (25, 25)
Shape B:  (25, 25, 25)
	 Shape z:  (25, 25)
	 Shape x1g:  (25, 25)
	 Shape x2g:  (25, 25)


In [101]:
import pandas as pd
z_data = pd.read_csv('https://raw.githubusercontent.com/plotly/datasets/master/api_docs/mt_bruno_elevation.csv')
z = z_data.values
sh_0, sh_1 = z.shape
x, y = np.linspace(0, 1, sh_0), np.linspace(0, 1, sh_1)
print(x.shape)
print(y.shape)
print(z.shape)
print("+++++++++++++")
print(x1.shape)
print(x2.shape)
K = TPS.basis.reshape((-1, len(x1), len(x2)))
fig = go.Figure()
for i in range(K.shape[0]):
    fig.add_trace(go.Surface(z=K[i]))
fig.show()

(25,)
(25,)
(25, 25)
+++++++++++++
(50,)
(25,)


In [88]:
fig.update_layout(title='Mt Bruno Elevation', autosize=False,
                  width=500, height=500,
                  margin=dict(l=65, r=50, b=65, t=90))
fig.show()

In [79]:
K = TPS.basis
KK = K.reshape((-1, 50, 25))
KK[1].shape

(50, 25)

In [41]:
x1g, x2g = np.meshgrid(x1, x2)
fig = go.Figure()
for i in range(TPS.basis.shape[1]):
    fig.add_trace(
        go.Surface(
            x=x1g, y=x2g, 
            z=TPS.basis[:,i].reshape((TPS.X1.shape[0], TPS.X2.shape[0])),
            name=f"TPS Basis {i+1}")
    )

print(x1g.shape)
TPS.basis[:,0].reshape((TPS.X1.shape[0], TPS.X2.shape[0])).shape

Shape of the first basis:  (100, 12)
Shape of the second basis:  (25, 3)
Shape of the tensor product basis:  (2500, 36)
(25, 100)


(100, 25)