In [None]:
def visualize(X, y, X_pred, f, mean, cov, p_inputs=None, title=""):
    """Plot a predictive distribution against the latent function and the data.

    Draws, in order: the latent function ``f`` on ``X_pred``, the predictive
    mean, a shaded band of mean +/- 1.96 * predictive std (95%), optional
    pseudo-inputs (as markers on ``f``), and the observations.

    Args:
        X: Training inputs (tensor); flattened to 1-D for plotting.
        y: Training targets (tensor); flattened to 1-D for plotting.
        X_pred: Inputs at which ``mean``/``cov`` were predicted (tensor).
        f: Latent function; called on the NumPy versions of ``X_pred`` and
            ``p_inputs``.
        mean: Predictive mean at ``X_pred`` (tensor); flattened to 1-D.
        cov: Predictive covariance at ``X_pred`` (tensor); only its diagonal
            (the marginal variances) is used.
        p_inputs: Optional pseudo-inputs (tensor) to mark on the plot.
        title: Figure title.
    """
    def to_numpy(t):
        # .numpy() refuses tensors that require grad or live on a GPU, so drop
        # autograd history and move to host memory first.
        return t.detach().cpu().numpy().ravel()

    X = to_numpy(X)
    y = to_numpy(y)
    X_pred = to_numpy(X_pred)
    mean = to_numpy(mean)
    var = to_numpy(torch.diagonal(cov))

    if p_inputs is not None:
        p_inputs = to_numpy(p_inputs)

    # Floating-point round-off can leave marginal variances slightly negative;
    # clip them so sqrt does not punch NaN holes into the band.
    credible_interval = 1.96 * np.sqrt(np.clip(var, 0.0, None))  # 95%

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=X_pred, y=f(X_pred), line_color="#00CC00", name="Latent Function"))
    fig.add_trace(go.Scatter(x=X_pred, y=mean, line_color="#FE73FF", name="Mean"))
    fig.add_trace(go.Scatter(x=X_pred, y=mean-credible_interval, mode='lines', line=dict(color='lightgray'), showlegend=False))
    # fill='tonexty' shades down to the previously added trace (the lower bound).
    fig.add_trace(go.Scatter(x=X_pred, y=mean+credible_interval, mode='lines', line=dict(color='lightgray'), fill='tonexty', showlegend=False))

    if p_inputs is not None:
        fig.add_trace(go.Scatter(x=p_inputs, y=f(p_inputs), mode='markers', marker=dict(size=6, color="#FF0000", opacity=1), name="Pseudo-inputs"))

    fig.add_trace(go.Scatter(x=X, y=y, mode='markers', marker=dict(size=4, color="#0000FF", opacity=0.3), name="Observations"))

    fig.update_layout(title=title, xaxis_title="X", yaxis_title="f*")
    fig.show()

# Fit each model on the observations (X, y), predict at X_pred, and plot the
# result with `visualize` (latent function, mean, 95% band, data).
# NOTE(review): X_pred.clone() is presumably passed so that predict() cannot
# modify the shared X_pred tensor in place -- confirm against the model code.

# GP baseline: called without p_inputs, so no pseudo-inputs are drawn.
model1 = GP(X, y)
mean1, cov1 = model1.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean1, cov1, title="GP_regression")

# The remaining models expose `pseudo_inputs`, which are drawn as markers on f.
# Presumably these are the Subset-of-Regressors, DTC and FITC sparse
# approximations -- confirm against their definitions.
model2 = SoR(X, y)
mean2, cov2 = model2.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean2, cov2, p_inputs=model2.pseudo_inputs, title="SoR")

model3 = DTC(X, y)
mean3, cov3 = model3.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean3, cov3, p_inputs=model3.pseudo_inputs, title="DTC")

model4 = FITC(X, y)
mean4, cov4 = model4.predict(X_pred.clone())

visualize(X, y, X_pred, f, mean4, cov4, p_inputs=model4.pseudo_inputs, title="FITC")

In [None]:
from plotly.subplots import make_subplots

# Reuses the model classes (GP, SoR, DTC, FITC) and the data (X, y, X_pred, f)
# defined in earlier cells unchanged.

model1 = GP(X, y)
mean1, cov1 = model1.predict(X_pred.clone())

model2 = SoR(X, y)
mean2, cov2 = model2.predict(X_pred.clone())

model3 = DTC(X, y)
mean3, cov3 = model3.predict(X_pred.clone())

model4 = FITC(X, y)
mean4, cov4 = model4.predict(X_pred.clone())


def _tensor_to_numpy(t):
    """Return tensor ``t`` as a flat NumPy array, detached from autograd and on CPU."""
    return t.detach().cpu().numpy().ravel()


def _add_prediction_panel(fig, row, latent_fn, X_np, y_np, X_pred_np, mean, cov, p_inputs, legend_seen):
    """Add one model's prediction traces to subplot ``row`` (column 1) of ``fig``.

    Draws the same traces as ``visualize``: latent function, predictive mean,
    shaded 95% band (mean +/- 1.96 * std), optional pseudo-inputs and the
    observations. Each legend entry is shown only the first time its name is
    added (tracked in the mutable set ``legend_seen``); traces share a
    ``legendgroup`` per name so clicking an entry toggles it in every panel.
    """
    mean_np = _tensor_to_numpy(mean)
    var_np = _tensor_to_numpy(torch.diagonal(cov))
    # Clip round-off negatives so sqrt does not produce NaN gaps in the band.
    band = 1.96 * np.sqrt(np.clip(var_np, 0.0, None))  # 95%

    def legend_kwargs(name):
        first_time = name not in legend_seen
        legend_seen.add(name)
        return dict(name=name, legendgroup=name, showlegend=first_time)

    def add(trace):
        fig.add_trace(trace, row=row, col=1)

    add(go.Scatter(x=X_pred_np, y=latent_fn(X_pred_np), line_color="#00CC00", **legend_kwargs("Latent Function")))
    add(go.Scatter(x=X_pred_np, y=mean_np, line_color="#FE73FF", **legend_kwargs("Mean")))
    add(go.Scatter(x=X_pred_np, y=mean_np - band, mode="lines", line=dict(color="lightgray"), showlegend=False))
    # fill="tonexty" shades down to the previously added trace (the lower bound).
    add(go.Scatter(x=X_pred_np, y=mean_np + band, mode="lines", line=dict(color="lightgray"), fill="tonexty", showlegend=False))
    if p_inputs is not None:
        p_np = _tensor_to_numpy(p_inputs)
        add(go.Scatter(x=p_np, y=latent_fn(p_np), mode="markers", marker=dict(size=6, color="#FF0000", opacity=1), **legend_kwargs("Pseudo-inputs")))
    add(go.Scatter(x=X_np, y=y_np, mode="markers", marker=dict(size=4, color="#0000FF", opacity=0.3), **legend_kwargs("Observations")))


def _plot_gp_comparison(panels, X, y, X_pred, latent_fn):
    """Build a figure with one vertically stacked subplot per model.

    ``panels`` is a sequence of ``(title, mean, cov, pseudo_inputs_or_None)``
    tuples; all panels share the x-axis. Returns the plotly figure.
    """
    fig = make_subplots(
        rows=len(panels),
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.05,
        subplot_titles=[title for title, _, _, _ in panels],
    )

    X_np, y_np, X_pred_np = (_tensor_to_numpy(t) for t in (X, y, X_pred))
    legend_seen = set()
    for row, (_, mean, cov, p_inputs) in enumerate(panels, start=1):
        _add_prediction_panel(fig, row, latent_fn, X_np, y_np, X_pred_np, mean, cov, p_inputs, legend_seen)

    # With shared x-axes only the bottom subplot shows tick labels, so put the
    # x title there; label every y-axis.
    fig.update_xaxes(title_text="X", row=len(panels), col=1)
    fig.update_yaxes(title_text="f*")
    fig.update_layout(title="GP Regression Visualizations", height=300 * len(panels))
    return fig


fig = _plot_gp_comparison(
    (
        ("GP_regression", mean1, cov1, None),
        ("SoR", mean2, cov2, model2.pseudo_inputs),
        ("DTC", mean3, cov3, model3.pseudo_inputs),
        ("FITC", mean4, cov4, model4.pseudo_inputs),
    ),
    X, y, X_pred, f,
)
fig.show()
