In [1]:
import torch
import numpy as np
import plotly.graph_objects as go

from sgpr.gp_regression import GP, SoR, DTC, FITC

# Use float64 for all newly created floating-point tensors (make_data builds its
# inputs with torch.randn / torch.rand, which follow this default).
torch.set_default_dtype(torch.float64)

# Seed both RNGs used in this notebook (torch in make_data / Latent_Function for
# tensor inputs, NumPy in Latent_Function for array inputs) so that
# "Restart Kernel -> Run All" reproduces the same data, fits and tables.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# 実験（1）

* 目的: 作成したコード（`GP`, `SoR`, `DTC`, `FITC`）が正常に動作するかを確認する


## 準備

In [2]:
class Latent_Function():
    """Ground-truth latent function used to generate synthetic regression data.

    Parameters
    ----------
    noise_level : float
        Standard deviation of the Gaussian observation noise added by
        ``f(X, observed=True)``.
    """

    def __init__(self, noise_level=0.1):
        self.noise_level = noise_level
        # Kept for backward compatibility; not read anywhere in this class.
        self.noise = None

    def _f(self, X):
        """Noise-free latent function. Customize the expression here freely."""
        return np.sin(X) + np.cos(X)

    def f(self, X, observed=False):
        """Evaluate the latent function, optionally adding observation noise.

        Parameters
        ----------
        X : torch.Tensor, np.ndarray, or array-like
            Input locations. Tensors take the torch path and come back as
            tensors; anything else is converted with ``np.asarray`` and comes
            back as a NumPy array (or NumPy scalar for scalar input).
        observed : bool
            If truthy, add element-wise Gaussian noise with mean equal to the
            latent value and standard deviation ``self.noise_level``.

        Returns
        -------
        torch.Tensor or np.ndarray
            Same shape as ``X``.
        """
        if isinstance(X, torch.Tensor):
            latent = self._f(X)
            if observed:
                return torch.normal(latent, self.noise_level)
            return latent

        # Any other input (ndarray, Python float, list, ...) goes through NumPy.
        # Previously non-ndarray inputs fell through every branch and the
        # method silently returned None.
        X = np.asarray(X)
        latent = np.asarray(self._f(X))
        if observed:
            return np.random.normal(loc=latent, scale=self.noise_level)
        return latent


def make_data(f, X=None, X_pred=None, n_normal=50, n_uniform=50,
              grid_step=0.1, margin=0.2):
    """Build a synthetic 1-D regression dataset and a prediction grid.

    Parameters
    ----------
    f : callable
        Called as ``f(X, observed=True)`` to produce noisy targets.
    X : torch.Tensor, optional
        Training inputs. If None, ``n_normal + n_uniform`` points are sampled:
        a Gaussian cluster (mean 4, std 1.5, clipped to [0, 10]) plus a
        uniform band on [2, 8], sorted and shaped ``(n, 1)``.
    X_pred : torch.Tensor, optional
        Prediction inputs. If None, an evenly spaced ``(m, 1)`` grid is built
        covering the data range extended by ``margin`` times that range on
        each side, with spacing at most ``grid_step``.
    n_normal, n_uniform : int
        Number of points in the Gaussian cluster / uniform band.
    grid_step : float
        Maximum spacing of the generated prediction grid.
    margin : float
        Fraction of the data range added on each side of the prediction grid.

    Returns
    -------
    tuple
        ``(X, y, X_pred)``.
    """
    if X is None:
        # Uneven input density: a clipped Gaussian cluster plus a uniform band.
        # (RNG call order is the same as before: randn first, then rand.)
        X_normal = torch.clip(torch.randn(n_normal) * 1.5 + 4, 0, 10)
        X_uniform = torch.rand(n_uniform) * 6 + 2
        X, _ = torch.sort(torch.cat([X_normal, X_uniform]))
        X = X.reshape(-1, 1)

    y = f(X, observed=True)

    if X_pred is None:
        X_min, X_max = X.min(), X.max()
        pad = (X_max - X_min) * margin
        X_pred_min, X_pred_max = X_min - pad, X_max + pad
        # Size the grid from the padded span (not just the data range) so the
        # spacing really is at most `grid_step`; keep at least 2 points so a
        # degenerate X (all values equal) does not produce an empty grid.
        n_steps = int(torch.ceil((X_pred_max - X_pred_min) / grid_step))
        size = max(n_steps + 1, 2)
        X_pred = torch.linspace(X_pred_min, X_pred_max, size).reshape(-1, 1)

    return X, y, X_pred


def visualize(X, y, X_pred, f, mean, cov, p_inputs=None, title=""):
    X = X.ravel().detach().numpy()
    y = y.ravel().detach().numpy()
    X_pred = X_pred.ravel().detach().numpy()
    mean = mean.ravel().detach().numpy()
    var = torch.diagonal(cov).detach().numpy()

    if p_inputs is not None:
        p_inputs = p_inputs.ravel().detach().numpy()

    credible_interval = 1.96 * np.sqrt(var) # 95%

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=X_pred, y=f(X_pred), line_color="#00CC00", name="Latent Function"))
    fig.add_trace(go.Scatter(x=X_pred, y=mean, line_color="#FE73FF", name="Mean"))
    fig.add_trace(go.Scatter(x=X_pred, y=mean-credible_interval, mode='lines', line=dict(color='lightgray'), showlegend=False))
    fig.add_trace(go.Scatter(x=X_pred, y=mean+credible_interval, mode='lines', line=dict(color='lightgray'), fill='tonexty', showlegend=False))
    
    if p_inputs is not None:
        fig.add_trace(go.Scatter(x=p_inputs, y=f(p_inputs), mode='markers', marker=dict(size=6, color="#FF0000", opacity=1), name="Pseudo-inputs"))

    fig.add_trace(go.Scatter(x=X, y=y, mode='markers', marker=dict(size=4, color="#0000FF", opacity=0.3), name="Observations"))

    fig.update_layout(title=title, xaxis_title="X", yaxis_title="f*")
    fig.show()

In [3]:
# Ground truth with observation-noise std 0.2. The bound method `f` is reused by
# every plot below to draw the latent curve.
latent_fn = Latent_Function(noise_level=0.2)
f = latent_fn.f
X, y, X_pred = make_data(f)

## GP regression

In [4]:
# Exact GP, evaluated with its initial parameters (optimize() is only called in the next cell).
model1 = GP(X, y)
# A copy of the grid is passed, presumably so predict() cannot alter X_pred in place — TODO confirm in sgpr.
mean1, cov1 = model1.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean1, cov1, title="GP_regression")

In [5]:
# Optimize model1's parameters (100 steps, learning rate 0.01), then predict again on the same grid.
# NOTE(review): optimize() presumably updates model1 in place, so re-running this cell would
# continue from the current state rather than from the initial values — TODO confirm.
# NOTE(review): optimize() prints one "opt_iter" line per step, flooding the saved output.
model1.optimize(iteration=100, learning_rate=0.01)
mean1_opt, cov1_opt = model1.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean1_opt, cov1_opt, title="Optimized GP_regression")

opt_iter: 1/100
opt_iter: 2/100
opt_iter: 3/100
opt_iter: 4/100
opt_iter: 5/100
opt_iter: 6/100
opt_iter: 7/100
opt_iter: 8/100
opt_iter: 9/100
opt_iter: 10/100
opt_iter: 11/100
opt_iter: 12/100
opt_iter: 13/100
opt_iter: 14/100
opt_iter: 15/100
opt_iter: 16/100
opt_iter: 17/100
opt_iter: 18/100
opt_iter: 19/100
opt_iter: 20/100
opt_iter: 21/100
opt_iter: 22/100
opt_iter: 23/100
opt_iter: 24/100
opt_iter: 25/100
opt_iter: 26/100
opt_iter: 27/100
opt_iter: 28/100
opt_iter: 29/100
opt_iter: 30/100
opt_iter: 31/100
opt_iter: 32/100
opt_iter: 33/100
opt_iter: 34/100
opt_iter: 35/100
opt_iter: 36/100
opt_iter: 37/100
opt_iter: 38/100
opt_iter: 39/100
opt_iter: 40/100
opt_iter: 41/100
opt_iter: 42/100
opt_iter: 43/100
opt_iter: 44/100
opt_iter: 45/100
opt_iter: 46/100
opt_iter: 47/100
opt_iter: 48/100
opt_iter: 49/100
opt_iter: 50/100
opt_iter: 51/100
opt_iter: 52/100
opt_iter: 53/100
opt_iter: 54/100
opt_iter: 55/100
opt_iter: 56/100
opt_iter: 57/100
opt_iter: 58/100
opt_iter: 59/100
opt_it

In [6]:
# Per-iteration parameter history of model1 (one row per optimization step).
# The values look like log-parameters (lengthscale starts at -0.693 ~ log 0.5) — TODO confirm in sgpr.
# NOTE(review): noise and variance are still changing steadily at the last row, so 100 steps
# have likely not converged.
model1.make_params_df()

Unnamed: 0,noise,variance,lengthscale
0,0.000000,0.000000,-0.693147
1,-0.010000,0.010000,-0.683147
2,-0.020003,0.019997,-0.673154
3,-0.030009,0.029989,-0.663171
4,-0.040022,0.039973,-0.653203
...,...,...,...
96,-1.126159,0.830016,-0.083513
97,-1.140628,0.837299,-0.081209
98,-1.155167,0.844561,-0.078985
99,-1.169775,0.851804,-0.076844


## SoR (Subset of Regressors)

In [7]:
# Subset-of-Regressors approximation, shown before optimization.
# With p_optimized=True the pseudo-input columns u_1..u_10 change during optimize()
# (see this model's params table), so they are presumably trainable — TODO confirm.
model2 = SoR(X, y, p_optimized=True)
mean2, cov2 = model2.predict(X_pred.clone())

# Pseudo-inputs are drawn as red markers on the latent curve.
visualize(X, y, X_pred, f, mean2, cov2, p_inputs=model2.pseudo_inputs, title="SoR")

In [8]:
# Optimize model2 for 300 steps (learning rate 0.01), then predict again on the same grid.
# NOTE(review): presumably mutates model2 in place; re-running continues from the current state — TODO confirm.
model2.optimize(iteration=300, learning_rate=0.01)
mean2_opt, cov2_opt = model2.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean2_opt, cov2_opt, p_inputs=model2.pseudo_inputs, title="Optimized SoR")

opt_iter: 1/300
opt_iter: 2/300
opt_iter: 3/300
opt_iter: 4/300
opt_iter: 5/300
opt_iter: 6/300
opt_iter: 7/300
opt_iter: 8/300
opt_iter: 9/300
opt_iter: 10/300
opt_iter: 11/300
opt_iter: 12/300
opt_iter: 13/300
opt_iter: 14/300
opt_iter: 15/300
opt_iter: 16/300
opt_iter: 17/300
opt_iter: 18/300
opt_iter: 19/300
opt_iter: 20/300
opt_iter: 21/300
opt_iter: 22/300
opt_iter: 23/300
opt_iter: 24/300
opt_iter: 25/300
opt_iter: 26/300
opt_iter: 27/300
opt_iter: 28/300
opt_iter: 29/300
opt_iter: 30/300
opt_iter: 31/300
opt_iter: 32/300
opt_iter: 33/300
opt_iter: 34/300
opt_iter: 35/300
opt_iter: 36/300
opt_iter: 37/300
opt_iter: 38/300
opt_iter: 39/300
opt_iter: 40/300
opt_iter: 41/300
opt_iter: 42/300
opt_iter: 43/300
opt_iter: 44/300
opt_iter: 45/300
opt_iter: 46/300
opt_iter: 47/300
opt_iter: 48/300
opt_iter: 49/300
opt_iter: 50/300
opt_iter: 51/300
opt_iter: 52/300
opt_iter: 53/300
opt_iter: 54/300
opt_iter: 55/300
opt_iter: 56/300
opt_iter: 57/300
opt_iter: 58/300
opt_iter: 59/300
opt_it

In [9]:
# Parameter history of model2 (SoR), including the pseudo-input locations u_1..u_10.
# NOTE(review): the noise column is still decreasing by ~0.03 per step at the final rows
# (not converged; possible noise underestimation), and u_2 / u_3 end up almost on top of
# each other (~1.98 vs ~1.97), so one pseudo-input is effectively wasted.
model2.make_params_df()

Unnamed: 0,noise,variance,lengthscale,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,u_10
0,0.000000,0.000000,-0.693147,0.000000e+00,0.851586,1.703172,2.554758,3.406344,4.257930,5.109516,5.961102,6.812688,7.664274
1,-0.010000,0.010000,-0.683147,9.999997e-03,0.861586,1.713172,2.564758,3.396344,4.267930,5.119516,5.971102,6.802688,7.654274
2,-0.020004,0.019999,-0.673191,1.695120e-02,0.871585,1.723143,2.574692,3.387594,4.265366,5.129219,5.980265,6.792956,7.644296
3,-0.030016,0.029996,-0.663308,1.810992e-02,0.881583,1.733074,2.584548,3.380446,4.260127,5.132165,5.985834,6.783956,7.634355
4,-0.040037,0.039991,-0.653531,1.572532e-02,0.891585,1.742951,2.594315,3.375200,4.254304,5.129548,5.986758,6.775886,7.624471
...,...,...,...,...,...,...,...,...,...,...,...,...,...
296,-5.847041,2.831814,-1.137715,-1.728423e-06,1.975925,1.975162,3.060448,3.702352,4.353528,5.022202,6.042645,6.706739,7.429472
297,-5.877528,2.840764,-1.135817,-5.344069e-07,1.979935,1.974665,3.057996,3.704273,4.351815,5.023568,6.041097,6.709103,7.426957
298,-5.908079,2.849712,-1.138548,8.161305e-07,1.978854,1.969215,3.061061,3.701090,4.354846,5.021352,6.042436,6.706938,7.429245
299,-5.938695,2.858660,-1.137478,-2.448398e-06,1.982240,1.969535,3.058544,3.703184,4.352940,5.023110,6.042219,6.707899,7.428036


## DTC (Deterministic Training Conditional)

In [10]:
# Deterministic Training Conditional approximation with the same settings as the SoR cell
# (p_optimized=True), shown before optimization.
model3 = DTC(X, y, p_optimized=True)
mean3, cov3 = model3.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean3, cov3, p_inputs=model3.pseudo_inputs, title="DTC")

In [11]:
# Optimize model3 for 300 steps (learning rate 0.01), then predict again on the same grid.
# NOTE(review): presumably mutates model3 in place; re-running continues from the current state — TODO confirm.
model3.optimize(iteration=300, learning_rate=0.01)
mean3_opt, cov3_opt = model3.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean3_opt, cov3_opt, p_inputs=model3.pseudo_inputs, title="Optimized DTC")

opt_iter: 1/300
opt_iter: 2/300
opt_iter: 3/300
opt_iter: 4/300
opt_iter: 5/300
opt_iter: 6/300
opt_iter: 7/300
opt_iter: 8/300
opt_iter: 9/300
opt_iter: 10/300
opt_iter: 11/300
opt_iter: 12/300
opt_iter: 13/300
opt_iter: 14/300
opt_iter: 15/300
opt_iter: 16/300
opt_iter: 17/300
opt_iter: 18/300
opt_iter: 19/300
opt_iter: 20/300
opt_iter: 21/300
opt_iter: 22/300
opt_iter: 23/300
opt_iter: 24/300
opt_iter: 25/300
opt_iter: 26/300
opt_iter: 27/300
opt_iter: 28/300
opt_iter: 29/300
opt_iter: 30/300
opt_iter: 31/300
opt_iter: 32/300
opt_iter: 33/300
opt_iter: 34/300
opt_iter: 35/300
opt_iter: 36/300
opt_iter: 37/300
opt_iter: 38/300
opt_iter: 39/300
opt_iter: 40/300
opt_iter: 41/300
opt_iter: 42/300
opt_iter: 43/300
opt_iter: 44/300
opt_iter: 45/300
opt_iter: 46/300
opt_iter: 47/300
opt_iter: 48/300
opt_iter: 49/300
opt_iter: 50/300
opt_iter: 51/300
opt_iter: 52/300
opt_iter: 53/300
opt_iter: 54/300
opt_iter: 55/300
opt_iter: 56/300
opt_iter: 57/300
opt_iter: 58/300
opt_iter: 59/300
opt_it

In [12]:
# Parameter history of model3 (DTC).
# NOTE(review): this table is identical to the SoR table. That is plausible if this DTC uses
# the same approximate marginal likelihood as SoR (the usual formulation, where the two differ
# only in predictive variance) and starts from the same initial values on the same data —
# TODO confirm in sgpr.gp_regression.
model3.make_params_df()

Unnamed: 0,noise,variance,lengthscale,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,u_10
0,0.000000,0.000000,-0.693147,0.000000e+00,0.851586,1.703172,2.554758,3.406344,4.257930,5.109516,5.961102,6.812688,7.664274
1,-0.010000,0.010000,-0.683147,9.999997e-03,0.861586,1.713172,2.564758,3.396344,4.267930,5.119516,5.971102,6.802688,7.654274
2,-0.020004,0.019999,-0.673191,1.695120e-02,0.871585,1.723143,2.574692,3.387594,4.265366,5.129219,5.980265,6.792956,7.644296
3,-0.030016,0.029996,-0.663308,1.810992e-02,0.881583,1.733074,2.584548,3.380446,4.260127,5.132165,5.985834,6.783956,7.634355
4,-0.040037,0.039991,-0.653531,1.572532e-02,0.891585,1.742951,2.594315,3.375200,4.254304,5.129548,5.986758,6.775886,7.624471
...,...,...,...,...,...,...,...,...,...,...,...,...,...
296,-5.847041,2.831814,-1.137715,-1.728423e-06,1.975925,1.975162,3.060448,3.702352,4.353528,5.022202,6.042645,6.706739,7.429472
297,-5.877528,2.840764,-1.135817,-5.344069e-07,1.979935,1.974665,3.057996,3.704273,4.351815,5.023568,6.041097,6.709103,7.426957
298,-5.908079,2.849712,-1.138548,8.161305e-07,1.978854,1.969215,3.061061,3.701090,4.354846,5.021352,6.042436,6.706938,7.429245
299,-5.938695,2.858660,-1.137478,-2.448398e-06,1.982240,1.969535,3.058544,3.703184,4.352940,5.023110,6.042219,6.707899,7.428036


## FITC (Fully Independent Training Conditional)

In [13]:
# Fully Independent Training Conditional approximation with the same settings as the SoR/DTC
# cells (p_optimized=True), shown before optimization.
model4 = FITC(X, y, p_optimized=True)
mean4, cov4 = model4.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean4, cov4, p_inputs=model4.pseudo_inputs, title="FITC")

In [14]:
# Optimize model4 for 300 steps (learning rate 0.01), then predict again on the same grid.
# NOTE(review): presumably mutates model4 in place; re-running continues from the current state — TODO confirm.
model4.optimize(iteration=300, learning_rate=0.01)
mean4_opt, cov4_opt = model4.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean4_opt, cov4_opt, p_inputs=model4.pseudo_inputs, title="Optimized FITC")

opt_iter: 1/300
opt_iter: 2/300
opt_iter: 3/300
opt_iter: 4/300
opt_iter: 5/300
opt_iter: 6/300
opt_iter: 7/300
opt_iter: 8/300
opt_iter: 9/300
opt_iter: 10/300
opt_iter: 11/300
opt_iter: 12/300
opt_iter: 13/300
opt_iter: 14/300
opt_iter: 15/300
opt_iter: 16/300
opt_iter: 17/300
opt_iter: 18/300
opt_iter: 19/300
opt_iter: 20/300
opt_iter: 21/300
opt_iter: 22/300
opt_iter: 23/300
opt_iter: 24/300
opt_iter: 25/300
opt_iter: 26/300
opt_iter: 27/300
opt_iter: 28/300
opt_iter: 29/300
opt_iter: 30/300
opt_iter: 31/300
opt_iter: 32/300
opt_iter: 33/300
opt_iter: 34/300
opt_iter: 35/300
opt_iter: 36/300
opt_iter: 37/300
opt_iter: 38/300
opt_iter: 39/300
opt_iter: 40/300
opt_iter: 41/300
opt_iter: 42/300
opt_iter: 43/300
opt_iter: 44/300
opt_iter: 45/300
opt_iter: 46/300
opt_iter: 47/300
opt_iter: 48/300
opt_iter: 49/300
opt_iter: 50/300
opt_iter: 51/300
opt_iter: 52/300
opt_iter: 53/300
opt_iter: 54/300
opt_iter: 55/300
opt_iter: 56/300
opt_iter: 57/300
opt_iter: 58/300
opt_iter: 59/300
opt_it

In [15]:
# Parameter history of model4 (FITC).
# FIXME(review): this table is byte-identical to the SoR and DTC tables. FITC's objective
# normally adds the diagonal correction diag(K_ff - Q_ff) to the covariance, so an identical
# optimization trajectory suggests FITC.optimize may be using the SoR/DTC objective, or that
# this saved output is stale — verify in sgpr.gp_regression.
model4.make_params_df()

Unnamed: 0,noise,variance,lengthscale,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,u_10
0,0.000000,0.000000,-0.693147,0.000000e+00,0.851586,1.703172,2.554758,3.406344,4.257930,5.109516,5.961102,6.812688,7.664274
1,-0.010000,0.010000,-0.683147,9.999997e-03,0.861586,1.713172,2.564758,3.396344,4.267930,5.119516,5.971102,6.802688,7.654274
2,-0.020004,0.019999,-0.673191,1.695120e-02,0.871585,1.723143,2.574692,3.387594,4.265366,5.129219,5.980265,6.792956,7.644296
3,-0.030016,0.029996,-0.663308,1.810992e-02,0.881583,1.733074,2.584548,3.380446,4.260127,5.132165,5.985834,6.783956,7.634355
4,-0.040037,0.039991,-0.653531,1.572532e-02,0.891585,1.742951,2.594315,3.375200,4.254304,5.129548,5.986758,6.775886,7.624471
...,...,...,...,...,...,...,...,...,...,...,...,...,...
296,-5.847041,2.831814,-1.137715,-1.728423e-06,1.975925,1.975162,3.060448,3.702352,4.353528,5.022202,6.042645,6.706739,7.429472
297,-5.877528,2.840764,-1.135817,-5.344069e-07,1.979935,1.974665,3.057996,3.704273,4.351815,5.023568,6.041097,6.709103,7.426957
298,-5.908079,2.849712,-1.138548,8.161305e-07,1.978854,1.969215,3.061061,3.701090,4.354846,5.021352,6.042436,6.706938,7.429245
299,-5.938695,2.858660,-1.137478,-2.448398e-06,1.982240,1.969535,3.058544,3.703184,4.352940,5.023110,6.042219,6.707899,7.428036
