In [69]:
import importlib

import torch
import numpy as np
import plotly.graph_objects as go

import sgpr.gp_regression as tmp
importlib.reload(tmp)
from sgpr.gp_regression import GP, SoR, DTC, FITC

torch.set_default_dtype(torch.float64)

# 実験（1）

* 作成したコードが正常に動作するか


## 準備

In [70]:
class Latent_Function():
    def __init__(self, noise_level=0.1):
        self.noise_level = noise_level
        self.noise = None

    def _f(self, X):
        # ここを好きにカスタマイズ
        tmp = np.sin(X) + np.cos(X)
        return tmp
    
    def f(self, X, observed=False):
        tmp = self._f(X)

        if observed is True:
            if isinstance(X, np.ndarray):
                noise = np.random.normal(loc=tmp, scale=self.noise_level)
                return noise
            elif isinstance(X, torch.Tensor):
                noise = torch.normal(tmp, self.noise_level)
                return noise
        
        else:
            if isinstance(X, np.ndarray):
                return np.array(tmp)
            elif isinstance(X, torch.Tensor):
                return tmp


def make_data(f, X=None, X_pred=None):
    if X is None:
        X_normal = torch.randn(50) * 1.5 + 4 
        X_normal = torch.clip(X_normal, 0, 10) 
        X_uniform = torch.rand(50) * 6 + 2  
        X_combined = torch.cat([X_normal, X_uniform])
        X, _ = torch.sort(X_combined)
        X = X.reshape(-1, 1)

    y = f(X, observed=True)

    if X_pred is None:
        X_max, X_min = X.max(), X.min()
        interval = torch.abs(X_max - X_min)
        size = int(torch.ceil(interval / 0.1))
        X_pred_max, X_pred_min = X_max + interval * 0.2, X_min - interval * 0.2
        X_pred = torch.linspace(X_pred_min, X_pred_max, size).reshape(-1, 1)

    return X, y, X_pred


def visualize(X, y, X_pred, f, mean, cov, p_inputs=None, title=""):
    X = X.ravel().detach().numpy()
    y = y.ravel().detach().numpy()
    X_pred = X_pred.ravel().detach().numpy()
    mean = mean.ravel().detach().numpy()
    var = torch.diagonal(cov).detach().numpy()

    if p_inputs is not None:
        p_inputs = p_inputs.ravel().detach().numpy()

    credible_interval = 1.96 * np.sqrt(var) # 95%

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=X_pred, y=f(X_pred), line_color="#00CC00", name="Latent Function"))
    fig.add_trace(go.Scatter(x=X_pred, y=mean, line_color="#FE73FF", name="Mean"))
    fig.add_trace(go.Scatter(x=X_pred, y=mean-credible_interval, mode='lines', line=dict(color='lightgray'), showlegend=False))
    fig.add_trace(go.Scatter(x=X_pred, y=mean+credible_interval, mode='lines', line=dict(color='lightgray'), fill='tonexty', showlegend=False))
    
    if p_inputs is not None:
        fig.add_trace(go.Scatter(x=p_inputs, y=f(p_inputs), mode='markers', marker=dict(size=6, color="#FF0000", opacity=1), name="Pseudo-inputs"))

    fig.add_trace(go.Scatter(x=X, y=y, mode='markers', marker=dict(size=4, color="#0000FF", opacity=0.3), name="Observations"))

    fig.update_layout(title=title, xaxis_title="X", yaxis_title="f*")
    fig.show()

In [71]:
f = Latent_Function(noise_level=0.2).f
X, y, X_pred = make_data(f)

## GP regression

In [72]:
model1 = GP(X, y)
mean1, cov1 = model1.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean1, cov1, title="GP_regression")

In [73]:
model1.optimize(iteration=1000, learning_rate=0.01)
mean1_opt, cov1_opt = model1.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean1_opt, cov1_opt, title="Optimized GP_regression")

opt_iter: 1/1000
opt_iter: 2/1000
opt_iter: 3/1000
opt_iter: 4/1000
opt_iter: 5/1000
opt_iter: 6/1000
opt_iter: 7/1000
opt_iter: 8/1000
opt_iter: 9/1000
opt_iter: 10/1000
opt_iter: 11/1000
opt_iter: 12/1000
opt_iter: 13/1000
opt_iter: 14/1000
opt_iter: 15/1000
opt_iter: 16/1000
opt_iter: 17/1000
opt_iter: 18/1000
opt_iter: 19/1000
opt_iter: 20/1000
opt_iter: 21/1000
opt_iter: 22/1000
opt_iter: 23/1000
opt_iter: 24/1000
opt_iter: 25/1000
opt_iter: 26/1000
opt_iter: 27/1000
opt_iter: 28/1000
opt_iter: 29/1000
opt_iter: 30/1000
opt_iter: 31/1000
opt_iter: 32/1000
opt_iter: 33/1000
opt_iter: 34/1000
opt_iter: 35/1000
opt_iter: 36/1000
opt_iter: 37/1000
opt_iter: 38/1000
opt_iter: 39/1000
opt_iter: 40/1000
opt_iter: 41/1000
opt_iter: 42/1000
opt_iter: 43/1000
opt_iter: 44/1000
opt_iter: 45/1000
opt_iter: 46/1000
opt_iter: 47/1000
opt_iter: 48/1000
opt_iter: 49/1000
opt_iter: 50/1000
opt_iter: 51/1000
opt_iter: 52/1000
opt_iter: 53/1000
opt_iter: 54/1000
opt_iter: 55/1000
opt_iter: 56/1000
o

In [74]:
model1.make_params_df()

Unnamed: 0,noise,variance,lengthscale
0,1.000000,1.000000,0.500000
1,0.990050,1.010050,0.505025
2,0.980196,1.020198,0.510098
3,0.970437,1.030443,0.515217
4,0.960768,1.040783,0.520381
...,...,...,...
238,0.018550,6.564816,0.182836
239,0.018070,6.661681,0.177874
240,0.017601,6.760986,0.173020
241,0.017143,6.862980,0.168270


In [75]:
model1.params

[tensor(-4.0661, requires_grad=True),
 tensor(1.9261, requires_grad=True),
 tensor(-1.7822, requires_grad=True)]

## SoR

In [76]:
model2 = SoR(X, y, p_optimized=True)
mean2, cov2 = model2.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean2, cov2, p_inputs=model2.pseudo_inputs, title="SoR")

In [77]:
model2.optimize(iteration=1000, learning_rate=0.01)
mean2_opt, cov2_opt = model2.predict(X_pred.clone())


visualize(X, y, X_pred, f, mean2_opt, cov2_opt, p_inputs=model2.pseudo_inputs, title="Optimized SoR")

opt_iter: 1/1000
opt_iter: 2/1000
opt_iter: 3/1000
opt_iter: 4/1000
opt_iter: 5/1000
opt_iter: 6/1000
opt_iter: 7/1000
opt_iter: 8/1000
opt_iter: 9/1000
opt_iter: 10/1000
opt_iter: 11/1000
opt_iter: 12/1000
opt_iter: 13/1000
opt_iter: 14/1000
opt_iter: 15/1000
opt_iter: 16/1000
opt_iter: 17/1000
opt_iter: 18/1000
opt_iter: 19/1000
opt_iter: 20/1000
opt_iter: 21/1000
opt_iter: 22/1000
opt_iter: 23/1000
opt_iter: 24/1000
opt_iter: 25/1000
opt_iter: 26/1000
opt_iter: 27/1000
opt_iter: 28/1000
opt_iter: 29/1000
opt_iter: 30/1000
opt_iter: 31/1000
opt_iter: 32/1000
opt_iter: 33/1000
opt_iter: 34/1000
opt_iter: 35/1000
opt_iter: 36/1000
opt_iter: 37/1000
opt_iter: 38/1000
opt_iter: 39/1000
opt_iter: 40/1000
opt_iter: 41/1000
opt_iter: 42/1000
opt_iter: 43/1000
opt_iter: 44/1000
opt_iter: 45/1000
opt_iter: 46/1000
opt_iter: 47/1000


opt_iter: 48/1000
opt_iter: 49/1000
opt_iter: 50/1000
opt_iter: 51/1000
opt_iter: 52/1000
opt_iter: 53/1000
opt_iter: 54/1000
opt_iter: 55/1000
opt_iter: 56/1000
opt_iter: 57/1000
opt_iter: 58/1000
opt_iter: 59/1000
opt_iter: 60/1000
opt_iter: 61/1000
opt_iter: 62/1000
opt_iter: 63/1000
opt_iter: 64/1000
opt_iter: 65/1000
opt_iter: 66/1000
opt_iter: 67/1000
opt_iter: 68/1000
opt_iter: 69/1000
opt_iter: 70/1000
opt_iter: 71/1000
opt_iter: 72/1000
opt_iter: 73/1000
opt_iter: 74/1000
opt_iter: 75/1000
opt_iter: 76/1000
opt_iter: 77/1000
opt_iter: 78/1000
opt_iter: 79/1000
opt_iter: 80/1000
opt_iter: 81/1000
opt_iter: 82/1000
opt_iter: 83/1000
opt_iter: 84/1000
opt_iter: 85/1000
opt_iter: 86/1000
opt_iter: 87/1000
opt_iter: 88/1000
opt_iter: 89/1000
opt_iter: 90/1000
opt_iter: 91/1000
opt_iter: 92/1000
opt_iter: 93/1000
opt_iter: 94/1000
opt_iter: 95/1000
opt_iter: 96/1000
opt_iter: 97/1000
opt_iter: 98/1000
opt_iter: 99/1000
opt_iter: 100/1000
opt_iter: 101/1000
opt_iter: 102/1000
opt_ite

In [78]:
model2.params

[tensor(-7.1179, requires_grad=True),
 tensor(3.1351, requires_grad=True),
 tensor(-1.0654, requires_grad=True),
 tensor([[0.7712],
         [1.5672],
         [2.9611],
         [2.9271],
         [3.7998],
         [4.4902],
         [5.4263],
         [5.4235],
         [6.5440],
         [7.3820]], requires_grad=True)]

In [79]:
model2.make_params_df()

Unnamed: 0,noise,variance,lengthscale,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,u_10
0,1.000000,1.000000,0.500000,0.818755,1.582818,2.346880,3.110942,3.875005,4.639067,5.403129,6.167192,6.931254,7.695316
1,0.990050,1.010050,0.505025,0.828755,1.592818,2.356880,3.120942,3.885005,4.629067,5.393129,6.157192,6.941254,7.685316
2,0.980194,1.020200,0.510064,0.838715,1.602806,2.366863,3.130932,3.894252,4.620280,5.383150,6.147219,6.950966,7.675382
3,0.970429,1.030451,0.515085,0.848597,1.612775,2.376815,3.140826,3.902194,4.613304,5.373218,6.137302,6.960023,7.665560
4,0.960750,1.040803,0.520051,0.858358,1.622721,2.386720,3.150376,3.908707,4.607579,5.363372,6.127480,6.967881,7.655897
...,...,...,...,...,...,...,...,...,...,...,...,...,...
332,0.000894,22.392234,0.345236,0.773005,1.566168,2.948921,2.948651,3.797673,4.492220,5.424622,5.421851,6.545869,7.379924
333,0.000865,22.590323,0.344282,0.770176,1.569285,2.944621,2.944590,3.802064,4.488048,5.428102,5.425364,6.545845,7.379025
334,0.000838,22.790165,0.345412,0.773349,1.565963,2.954808,2.936934,3.798755,4.491602,5.424490,5.421653,6.547589,7.378133
335,0.000810,22.991769,0.344590,0.771216,1.567224,2.961078,2.927108,3.799819,4.490247,5.426346,5.423545,6.543966,7.381963


## DTC

In [80]:
model3 = DTC(X, y, p_optimized=True)
mean3, cov3 = model3.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean3, cov3, p_inputs=model3.pseudo_inputs, title="DTC")

In [81]:
model3.optimize(iteration=300, learning_rate=0.01)
mean3_opt, cov3_opt = model3.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean3_opt, cov3_opt, p_inputs=model3.pseudo_inputs, title="Optimized DTC")

opt_iter: 1/300
opt_iter: 2/300
opt_iter: 3/300
opt_iter: 4/300
opt_iter: 5/300
opt_iter: 6/300
opt_iter: 7/300
opt_iter: 8/300
opt_iter: 9/300
opt_iter: 10/300
opt_iter: 11/300
opt_iter: 12/300
opt_iter: 13/300
opt_iter: 14/300
opt_iter: 15/300
opt_iter: 16/300
opt_iter: 17/300
opt_iter: 18/300
opt_iter: 19/300
opt_iter: 20/300
opt_iter: 21/300
opt_iter: 22/300
opt_iter: 23/300
opt_iter: 24/300
opt_iter: 25/300
opt_iter: 26/300
opt_iter: 27/300
opt_iter: 28/300
opt_iter: 29/300
opt_iter: 30/300
opt_iter: 31/300
opt_iter: 32/300
opt_iter: 33/300
opt_iter: 34/300
opt_iter: 35/300
opt_iter: 36/300
opt_iter: 37/300
opt_iter: 38/300
opt_iter: 39/300
opt_iter: 40/300
opt_iter: 41/300
opt_iter: 42/300
opt_iter: 43/300
opt_iter: 44/300
opt_iter: 45/300
opt_iter: 46/300
opt_iter: 47/300
opt_iter: 48/300
opt_iter: 49/300
opt_iter: 50/300
opt_iter: 51/300
opt_iter: 52/300
opt_iter: 53/300
opt_iter: 54/300
opt_iter: 55/300
opt_iter: 56/300
opt_iter: 57/300
opt_iter: 58/300
opt_iter: 59/300
opt_it

In [82]:
model3.make_params_df()

Unnamed: 0,noise,variance,lengthscale,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,u_10
0,1.000000,1.000000,0.500000,0.818755,1.582818,2.346880,3.110942,3.875005,4.639067,5.403129,6.167192,6.931254,7.695316
1,0.990050,1.010050,0.505025,0.828755,1.592818,2.356880,3.120942,3.885005,4.629067,5.393129,6.157192,6.941254,7.685316
2,0.980194,1.020200,0.510064,0.838715,1.602806,2.366863,3.130932,3.894252,4.620280,5.383150,6.147219,6.950966,7.675382
3,0.970429,1.030451,0.515085,0.848597,1.612775,2.376815,3.140826,3.902194,4.613304,5.373218,6.137302,6.960023,7.665560
4,0.960750,1.040803,0.520051,0.858358,1.622721,2.386720,3.150376,3.908707,4.607579,5.363372,6.127480,6.967881,7.655897
...,...,...,...,...,...,...,...,...,...,...,...,...,...
296,0.002788,16.302633,0.344490,0.775805,1.568645,2.943915,2.949024,3.800768,4.489482,5.393464,5.459475,6.544704,7.380591
297,0.002704,16.447229,0.345445,0.775098,1.569569,2.944007,2.950267,3.799202,4.491412,5.393869,5.454468,6.548051,7.376969
298,0.002623,16.593086,0.344437,0.774994,1.569772,2.941356,2.949180,3.801435,4.488987,5.398846,5.454960,6.544146,7.380928
299,0.002544,16.740220,0.345524,0.777534,1.566911,2.944170,2.951994,3.798213,4.492244,5.397947,5.449288,6.546427,7.379111


## FITC

In [83]:
model4 = FITC(X, y, p_optimized=True)
mean4, cov4 = model4.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean4, cov4, p_inputs=model4.pseudo_inputs, title="FITC")

In [84]:
model4.optimize(iteration=1000, learning_rate=0.01)
mean4_opt, cov4_opt = model4.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean4_opt, cov4_opt, p_inputs=model4.pseudo_inputs, title="Optimized FITC")

opt_iter: 1/1000
opt_iter: 2/1000
opt_iter: 3/1000
opt_iter: 4/1000
opt_iter: 5/1000
opt_iter: 6/1000
opt_iter: 7/1000
opt_iter: 8/1000
opt_iter: 9/1000
opt_iter: 10/1000
opt_iter: 11/1000
opt_iter: 12/1000


opt_iter: 13/1000
opt_iter: 14/1000
opt_iter: 15/1000
opt_iter: 16/1000
opt_iter: 17/1000
opt_iter: 18/1000
opt_iter: 19/1000
opt_iter: 20/1000
opt_iter: 21/1000
opt_iter: 22/1000
opt_iter: 23/1000
opt_iter: 24/1000
opt_iter: 25/1000
opt_iter: 26/1000
opt_iter: 27/1000
opt_iter: 28/1000
opt_iter: 29/1000
opt_iter: 30/1000
opt_iter: 31/1000
opt_iter: 32/1000
opt_iter: 33/1000
opt_iter: 34/1000
opt_iter: 35/1000
opt_iter: 36/1000
opt_iter: 37/1000
opt_iter: 38/1000
opt_iter: 39/1000
opt_iter: 40/1000
opt_iter: 41/1000
opt_iter: 42/1000
opt_iter: 43/1000
opt_iter: 44/1000
opt_iter: 45/1000
opt_iter: 46/1000
opt_iter: 47/1000
opt_iter: 48/1000
opt_iter: 49/1000
opt_iter: 50/1000
opt_iter: 51/1000
opt_iter: 52/1000
opt_iter: 53/1000
opt_iter: 54/1000
opt_iter: 55/1000
opt_iter: 56/1000
opt_iter: 57/1000
opt_iter: 58/1000
opt_iter: 59/1000
opt_iter: 60/1000
opt_iter: 61/1000
opt_iter: 62/1000
opt_iter: 63/1000
opt_iter: 64/1000
opt_iter: 65/1000
opt_iter: 66/1000
opt_iter: 67/1000
opt_iter: 

In [85]:
model4.params

[tensor(-7.1179, requires_grad=True),
 tensor(3.1351, requires_grad=True),
 tensor(-1.0654, requires_grad=True),
 tensor([[0.7712],
         [1.5672],
         [2.9611],
         [2.9271],
         [3.7998],
         [4.4902],
         [5.4263],
         [5.4235],
         [6.5440],
         [7.3820]], requires_grad=True)]

In [86]:
model4.make_params_df()

Unnamed: 0,noise,variance,lengthscale,u_1,u_2,u_3,u_4,u_5,u_6,u_7,u_8,u_9,u_10
0,1.000000,1.000000,0.500000,0.818755,1.582818,2.346880,3.110942,3.875005,4.639067,5.403129,6.167192,6.931254,7.695316
1,0.990050,1.010050,0.505025,0.828755,1.592818,2.356880,3.120942,3.885005,4.629067,5.393129,6.157192,6.941254,7.685316
2,0.980194,1.020200,0.510064,0.838715,1.602806,2.366863,3.130932,3.894252,4.620280,5.383150,6.147219,6.950966,7.675382
3,0.970429,1.030451,0.515085,0.848597,1.612775,2.376815,3.140826,3.902194,4.613304,5.373218,6.137302,6.960023,7.665560
4,0.960750,1.040803,0.520051,0.858358,1.622721,2.386720,3.150376,3.908707,4.607579,5.363372,6.127480,6.967881,7.655897
...,...,...,...,...,...,...,...,...,...,...,...,...,...
332,0.000894,22.392234,0.345236,0.773005,1.566168,2.948921,2.948651,3.797673,4.492220,5.424622,5.421851,6.545869,7.379924
333,0.000865,22.590323,0.344282,0.770176,1.569285,2.944621,2.944590,3.802064,4.488048,5.428102,5.425364,6.545845,7.379025
334,0.000838,22.790165,0.345412,0.773349,1.565963,2.954808,2.936934,3.798755,4.491602,5.424490,5.421653,6.547589,7.378133
335,0.000810,22.991769,0.344590,0.771216,1.567224,2.961078,2.927108,3.799819,4.490247,5.426346,5.423545,6.543966,7.381963
