In [None]:
import numpy as np

def branin(x):
    """Evaluate the Branin-Hoo test function.

    Parameters
    ----------
    x : numpy.ndarray
        Array whose last axis holds the two coordinates ``(x1, x2)``;
        any leading axes are evaluated element-wise.

    Returns
    -------
    numpy.ndarray or numpy scalar
        Function value(s), shaped like ``x`` without its last axis.
    """
    # Standard Branin constants.
    a, r, s = 1.0, 6.0, 10.0
    b = 5.1 / (4.0 * np.pi**2)
    c = 5.0 / np.pi
    t = 1.0 / (8.0 * np.pi)

    first, second = x[..., 0], x[..., 1]
    quadratic = a * (second - b * first**2 + c * first - r)**2
    periodic = s * (1.0 - t) * np.cos(first)
    return quadratic + periodic + s


def make_branin_surface(x=(-5, 15), y=(-5, 15), resolution=10):
    """Sample the Branin function on a regular grid for surface plotting.

    Parameters
    ----------
    x, y : sequence of float
        ``(low, high)`` bounds (inclusive) for the first and second coordinate.
    resolution : int
        Number of samples along each axis.

    Returns
    -------
    tuple of numpy.ndarray
        ``(X, Y, Z)`` meshgrid arrays of shape ``(resolution, resolution)``.
    """
    axis_x = np.linspace(x[0], x[1], resolution)
    axis_y = np.linspace(y[0], y[1], resolution)
    grid_x, grid_y = np.meshgrid(axis_x, axis_y)
    values = branin(np.stack([grid_x, grid_y], axis=-1))
    return grid_x, grid_y, values


In [2]:
import plotly.graph_objects as go

def plot_branin_surface(fig, surface, position):
    """Plot the Branin function surface and mark its three global minima.

    Parameters
    ----------
    fig : plotly figure
        Figure built with ``make_subplots``; the subplot at ``position``
        should be a 3-D (``'surface'``) subplot.
    surface : tuple of numpy.ndarray
        ``(X, Y, Z)`` grid arrays, e.g. as returned by ``make_branin_surface``.
    position : tuple of int
        1-based ``(row, col)`` of the target subplot.
    """
    X, Y, Z = surface
    row, col = position
    fig.add_trace(go.Surface(x=X, y=Y, z=Z,
        colorscale='Viridis',
        name='Branin Function',
        hovertemplate='<b>x₁:</b> %{x:.3f}<br><b>x₂:</b> %{y:.3f}<br><b>f(x):</b> %{z:.3f}<extra></extra>',
        showlegend=False,
        showscale=False
    ), row=row, col=col)

    # Global minimisers of Branin: (-pi, 12.275), (pi, 2.275), (3*pi, 2.475).
    # Evaluate the function there instead of hard-coding z=0.4 (and 9.42478
    # for 3*pi) so the 6-decimal hover text reports the actual minimum value.
    minima = np.array([[-np.pi, 12.275], [np.pi, 2.275], [3.0 * np.pi, 2.475]])
    fig.add_trace(go.Scatter3d(
        x=minima[:, 0],
        y=minima[:, 1],
        z=branin(minima),
        mode='markers',
        marker=dict(size=5, color='red', symbol='diamond', line=dict(width=1, color='white')),
        name='Global Minima',
        hovertemplate='<b>Global Minimum</b><br>x₁: %{x:.5f}<br>x₂: %{y:.5f}<br>f(x): %{z:.6f}<extra></extra>'
    ), row=row, col=col)

In [3]:
from matplotlib.pyplot import scatter
import plotly.graph_objects as go

def plot_model_mean_surface(fig, position):
    surface = go.Surface(x=[], y=[], z=[],
        colorscale='Viridis',
        name='model-mean',
        hovertemplate='<b>x₁:</b> %{x:.3f}<br><b>x₂:</b> %{y:.3f}<br><b>f(x):</b> %{z:.3f}<extra></extra>',
        showlegend=False,
        showscale=False
    )
    scatter = go.Scatter3d(x=[], y=[], z=[],
        mode='markers',
        marker=dict(size=5, color='red', symbol='circle', line=dict(width=1, color='white')),
        name='model-mean-point',
        hovertemplate='<b>New Point</b><br>x₁: %{x:.5f}<br>x₂: %{y:.5f}<br>f(x): %{z:.6f}<extra></extra>'
    )
    fig.add_trace(surface, row=position[0], col=position[1])
    fig.add_trace(scatter, row=position[0], col=position[1])
    def update_traces(entry):
        nx, ny = len(np.unique(entry['m_x'])), len(np.unique(entry['m_y']))
        m_x = np.array(entry['m_x']).reshape((nx, ny))
        m_y = np.array(entry['m_y']).reshape((nx, ny))
        m_z = np.array(entry['m_z']).reshape((nx, ny))
        s_x = entry['s_x']
        s_y = entry['s_y']
        s_z = entry['s_z']
        for trace in fig.select_traces(selector=lambda t: t.name == "model-mean"):
            trace.x = m_x
            trace.y = m_y
            trace.z = m_z
        for scatter in fig.select_traces(selector=lambda t: t.name == "model-mean-point"):
            scatter.x = [s_x]
            scatter.y = [s_y]
            scatter.z = [s_z]
    return update_traces

def plot_model_sigma_surface(fig, position):
    surface = go.Surface(x=[], y=[], z=[],
        colorscale='Viridis',
        name='model-sigma',
        hovertemplate='<b>x₁:</b> %{x:.3f}<br><b>x₂:</b> %{y:.3f}<br><b>f(x):</b> %{z:.3f}<extra></extra>',
        showlegend=False,
        showscale=False
    )
    scatter = go.Scatter3d(x=[], y=[], z=[],
        mode='markers',
        marker=dict(size=5, color='red', symbol='circle', line=dict(width=1, color='white')),
        name='model-sigma-point',
        hovertemplate='<b>New Point</b><br>x₁: %{x:.5f}<br>x₂: %{y:.5f}<br>f(x): %{z:.6f}<extra></extra>'
    )
    fig.add_trace(surface, row=position[0], col=position[1])
    fig.add_trace(scatter, row=position[0], col=position[1])
    def update_traces(entry):
        nx, ny = len(np.unique(entry['m_x'])), len(np.unique(entry['m_y']))
        m_x = np.array(entry['m_x']).reshape((nx, ny))
        m_y = np.array(entry['m_y']).reshape((nx, ny))
        m_z = np.sqrt(np.array(entry['m_s']).reshape((nx, ny)))
        s_x = entry['s_x']
        s_y = entry['s_y']
        for trace in fig.select_traces(selector=lambda t: t.name == "model-sigma"):
            trace.x = m_x
            trace.y = m_y
            trace.z = m_z
        for scatter in fig.select_traces(selector=lambda t: t.name == "model-sigma-point"):
            scatter.x = [s_x]
            scatter.y = [s_y]
            scatter.z = [0]
    return update_traces

def plot_model_acquisition_surface(fig, position):
    surface = go.Surface(x=[], y=[], z=[],
        colorscale='Viridis',
        name='model-acquisition',
        hovertemplate='<b>x₁:</b> %{x:.3f}<br><b>x₂:</b> %{y:.3f}<br><b>f(x):</b> %{z:.3f}<extra></extra>',
        showlegend=False,
        showscale=False
    )
    scatter = go.Scatter3d(x=[], y=[], z=[],
        mode='markers',
        marker=dict(size=5, color='red', symbol='circle', line=dict(width=1, color='white')),
        name='model-acquisition-point',
        hovertemplate='<b>New Point</b><br>x₁: %{x:.5f}<br>x₂: %{y:.5f}<br>f(x): %{z:.6f}<extra></extra>'
    )
    fig.add_trace(surface, row=position[0], col=position[1])
    fig.add_trace(scatter, row=position[0], col=position[1])
    def update_traces(entry):
        ax = entry.get('a_x', [])
        ay = entry.get('a_y', [])
        az = entry.get('a_z', [])
        nx, ny = len(np.unique(ax)), len(np.unique(ay))
        ax = np.array(ax).reshape((nx, ny))
        ay = np.array(ay).reshape((nx, ny))
        az = np.array(az).reshape((nx, ny))
        sx = entry['s_x']
        sy = entry['s_y']
        for trace in fig.select_traces(selector=lambda t: t.name == "model-acquisition"):
            trace.x = ax
            trace.y = ay
            trace.z = az
        for scatter in fig.select_traces(selector=lambda t: t.name == "model-acquisition-point"):
            scatter.x = [sx]
            scatter.y = [sy]
            scatter.z = [0]
    return update_traces

In [12]:
from ipywidgets import interact
import json
import os
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Location of the Bayesian-optimization debug dump. Set the
# BAYESIAN_DEBUG_JSON environment variable to point elsewhere instead of
# editing this absolute path in the notebook.
DEBUG_JSON_PATH = os.environ.get(
    'BAYESIAN_DEBUG_JSON', '/workspaces/library-b2o/bayesian_debug.json'
)

# Load and parse the Bayesian optimization debug JSON (decoded explicitly as
# UTF-8 rather than with the platform's default encoding).
with open(DEBUG_JSON_PATH, 'r', encoding='utf-8') as f:
    data = json.load(f)

# 2x2 grid of 3-D panels: row 1 = Branin, Model Mean; row 2 = Acquisition, Model Sigma.
fig = go.FigureWidget(make_subplots(rows=2, cols=2,
    specs=[[{'type': 'surface'}, {'type': 'surface'}],
           [{'type': 'surface'}, {'type': 'surface'}]],
    subplot_titles=['Branin', 'Model Mean', 'Acquisition', 'Model Sigma'],
    horizontal_spacing=0.05,
    vertical_spacing=0.05))
fig.update_layout(
    width=1800,
    height=900,
    margin=dict(l=10, r=10, t=50, b=10),
)
# `scene=dict(...)` in update_layout only configures the first subplot's scene;
# update_scenes applies the setting to all four 3-D subplots.
fig.update_scenes(aspectmode='auto')
plot_branin_surface(fig, make_branin_surface(), position=(1, 1))
model_mean_updater = plot_model_mean_surface(fig, position=(1, 2))
model_sigma_updater = plot_model_sigma_surface(fig, position=(2, 2))
model_acquisition_updater = plot_model_acquisition_surface(fig, position=(2, 1))
fig

FigureWidget({
    'data': [{'colorscale': [[0.0, '#440154'], [0.1111111111111111, '#482878'],
                             [0.2222222222222222, '#3e4989'], [0.3333333333333333,
                             '#31688e'], [0.4444444444444444, '#26828e'],
                             [0.5555555555555556, '#1f9e89'], [0.6666666666666666,
                             '#35b779'], [0.7777777777777778, '#6ece58'],
                             [0.8888888888888888, '#b5de2b'], [1.0, '#fde725']],
              'hovertemplate': ('<b>x₁:</b> %{x:.3f}<br><b>x₂:<' ... '):</b> %{z:.3f}<extra></extra>'),
              'name': 'Branin Function',
              'scene': 'scene',
              'showlegend': False,
              'showscale': False,
              'type': 'surface',
              'uid': '0361dee3-0dc1-48f7-9b53-dbecb6ccae2b',
              'x': {'bdata': ('AAAAAAAAFMCO4ziO4zgGwHAcx3Ecx+' ... 'zHcRwlQOQ4juM4jilAAAAAAAAALkA='),
                    'dtype': 'f8',
                    'shape': '10,

In [13]:
@interact(i=(0, len(data)-1))
def update(i):
    """Redraw the model-mean, acquisition and sigma panels for step ``i`` of ``data``."""
    entry = data[i]
    # Batch all three panel updates so the FigureWidget receives a single
    # update per slider move instead of one message per trace property.
    with fig.batch_update():
        model_mean_updater(entry)
        model_acquisition_updater(entry)
        model_sigma_updater(entry)


interactive(children=(IntSlider(value=40, description='i', max=80), Output()), _dom_classes=('widget-interact'…