In [1]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torch.nn.functional as F

from sklearn.model_selection import train_test_split, KFold

import matplotlib.pyplot as plt

import plotly.graph_objects as go
import plotly.express as px

In [2]:
from ipywidgets import interact, interactive, fixed, interact_manual, interactive_output
import ipywidgets as widgets
from IPython.display import display

In [3]:
debug = False

# Pick the compute device once for the whole notebook.
# Preference order: Apple MPS, then CUDA, then CPU.
# `torch.backends.mps` does not exist on older torch builds, so look it up
# defensively instead of assuming the attribute is present.
_mps_backend = getattr(torch.backends, 'mps', None)

if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')  # Mac Metal Performance Shaders
    _device_kind = 'Mac GPU'
elif torch.cuda.is_available():
    device = torch.device('cuda')
    _device_kind = 'NVIDIA GPU'
else:
    device = torch.device('cpu')
    _device_kind = 'CPU'

# Diagnostics are printed once (the original cell repeated the version/MPS
# messages twice when debug was enabled).
if debug:
    print(f"Training on device: {device} ({_device_kind})")
    print(f"PyTorch version: {torch.__version__}")
    if device.type == 'mps':
        print("Using Mac Metal Performance Shaders for GPU acceleration")

In [4]:

class AttentionBlock(nn.Module):
    """One attention head that mixes token features and a side-channel tensor.

    The head layer-normalises ``x``, projects it with three bias-free linear
    maps of width ``dim // num_heads`` and builds a softmax attention matrix.
    That same matrix is applied both to the value projection of ``x`` and to
    the side-channel ``c``.

    Args:
        vocab_size: stored on the instance; not used by the computation here.
        dim: size of the last axis of ``x``.
        num_heads: number of heads ``dim`` is split across; it also sets the
            ``1 / num_heads`` weight on the residual terms.
    """
    def __init__(self,vocab_size=115,dim=6,num_heads=2):
        super().__init__()

        assert dim % num_heads == 0, "model dimension must be divisible by number of heads"

        self.vocab_size = vocab_size
        self.d_model = dim

        # Submodules are created in the same order as before so that parameter
        # names (checkpoint keys) and random initialisation order are unchanged.
        self.layer_norm = nn.LayerNorm(dim)
        self.num_heads = num_heads
        self.k = dim // num_heads  # width of each projection

        head_width = self.k
        self.W_k = nn.Linear(dim, head_width, bias=False)
        self.W_q = nn.Linear(dim, head_width, bias=False)
        self.W_v = nn.Linear(dim, head_width, bias=False)
        self.W_o = nn.Linear(head_width, dim, bias=False)

    def scaled_dot_product_attention(self,keys,queries,values):
        """Return ``(attention @ values, attention)``.

        ``attention[b, i, j]`` is the softmax over ``j`` of
        ``keys[b, i] . queries[b, j] / sqrt(d_model)``. The scale uses the full
        model dimension (``self.d_model``), not the per-head width.
        """
        scale = np.sqrt(self.d_model)
        scores = torch.einsum('bnd,bmd->bnm', keys, queries) / scale
        weights = torch.softmax(scores, dim=-1)
        mixed_values = torch.einsum('bnm,bmd->bnd', weights, values)

        return mixed_values, weights

    def forward(self,x,c):
        """Apply the head to ``x`` and propagate ``c`` with the same weights.

        Both inputs are rank-3 tensors sharing their first two axes.

        Returns:
            ``(output_x, output_c)``: the output-projected attention result
            plus ``x / num_heads``, and ``(attention @ c + c) / num_heads``.
        """
        normed = self.layer_norm(x)

        # Naming quirk kept on purpose: the W_k projection is passed as the
        # first ("keys") argument and therefore indexes the rows of the
        # attention matrix, while the W_q projection indexes the columns the
        # softmax normalises over.
        row_proj = self.W_k(normed)
        col_proj = self.W_q(normed)
        value_proj = self.W_v(normed)

        attended_x, weights = self.scaled_dot_product_attention(row_proj, col_proj, value_proj)

        # Reuse the attention weights to mix the side channel across tokens.
        attended_c = torch.einsum('bnm,bmk->bnk', weights, c)

        head_share = 1/self.num_heads
        output_x = self.W_o(attended_x) + x*head_share
        output_c = (attended_c + c)*head_share

        return output_x, output_c


In [5]:

class MultiHeadAttentionBlock(nn.Module):
    """Sum of ``num_heads`` independent ``AttentionBlock`` heads.

    Every head receives the same ``(x, c)`` pair; the per-head outputs are
    added together (not concatenated) to form the block output.
    """
    def __init__(self,vocab_size=115,dim=6,num_heads=2):
        super().__init__()

        heads = [AttentionBlock(vocab_size, dim, num_heads) for _ in range(num_heads)]
        self.attention_blocks = nn.ModuleList(heads)

    def forward(self,x,c):
        """Return ``(sum of head x-outputs, sum of head c-outputs)``."""
        # Start from zero tensors matching the inputs so the result keeps the
        # shape, dtype and device of ``x`` and ``c``.
        total_x = torch.zeros_like(x)
        total_c = torch.zeros_like(c)

        for head in self.attention_blocks:
            head_x, head_c = head(x, c)
            total_x += head_x
            total_c += head_c

        return total_x, total_c

        

In [6]:
class FeedForwardBlock(nn.Module):
    """Position-wise feed-forward layer over the concatenation of ``x`` and ``c``.

    The layer works on vectors of width ``d_model + 1``; the extra unit
    presumably accounts for a single-column ``c`` appended to ``x`` (confirm
    against callers). Note that ``self.dropout`` is constructed but is not
    applied in ``forward``.

    Args:
        d_model: width of ``x`` along its last axis.
        inner_dim_multiplier: hidden width as a multiple of ``d_model + 1``.
        dropout: probability used to build the (currently unused) dropout layer.
    """

    def __init__(self,d_model,inner_dim_multiplier,dropout=0.1):
        super().__init__()

        width = d_model + 1
        self.d_model = width
        self.inner_dim = inner_dim_multiplier * width

        self.layer_norm = nn.LayerNorm(width)

        self.linear_layer_1 = nn.Linear(width, self.inner_dim)
        self.linear_layer_2 = nn.Linear(self.inner_dim, width)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,c):
        """Return ``MLP(LayerNorm([x, c])) + [x, c]``, concatenating on axis 2."""
        joined = torch.cat((x, c), dim=2)

        hidden = self.linear_layer_1(self.layer_norm(joined))
        projected = self.linear_layer_2(F.relu(hidden))

        # Residual connection around the normalised MLP.
        return projected + joined

In [7]:

class TransformerBlock(nn.Module):
    """Embedding layer + Attention Block + FeedForward Layer.

    Each position ``0 .. vocab_size-1`` gets a learned embedding; the input
    ``c`` is carried alongside those embeddings through one multi-head
    attention block and one feed-forward block.

    Args:
        vocab_size: number of embedding rows (positions).
        dim: embedding width.
        num_heads: number of attention heads.
        inner_dim_multiplier: hidden-width multiplier of the feed-forward block.
    """
    def __init__(self,vocab_size=115,dim=6,num_heads=2,inner_dim_multiplier=5):
        super(TransformerBlock, self).__init__()

        self.d_model = dim
        self.vocab_size = vocab_size

        self.inp_embedding = nn.Embedding(vocab_size,dim)

        self.attention_block = MultiHeadAttentionBlock(vocab_size,dim,num_heads)

        self.feedforward_block = FeedForwardBlock(dim,inner_dim_multiplier)

        # Not used in forward(); kept so the module's parameter set (and any
        # saved state_dict) stays unchanged.
        self.linear_layer_1 = nn.Linear(vocab_size,vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.32^2); zero the biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.32)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.32)

    def forward(self,c):
        """Run the block on ``c`` and return the last output channel.

        Args:
            c: rank-3 tensor ``(batch, positions, channels)``. ``positions``
                is used to index the embedding table, and the feed-forward
                block expects ``channels == 1``.

        Returns:
            Tensor of shape ``(batch, positions, 1)``: the last channel of the
            feed-forward output.
        """
        batch_size, vocab_size, _ = c.size()

        # Build position ids on the device that holds the embedding table.
        # Using the module's own device (instead of a notebook-level global)
        # keeps the model usable wherever it has been moved with .to(...).
        y = torch.arange(vocab_size, device=self.inp_embedding.weight.device)
        y = y.unsqueeze(0).expand(batch_size, -1)

        x = self.inp_embedding(y)

        output_x, output_c = self.attention_block(x,c)

        output_y = self.feedforward_block(output_x,output_c)

        return output_y[:,:,-1].unsqueeze(-1)


In [8]:
class TransformersSeries(nn.Module):
    """Stack of ``num_transformers`` attention + feed-forward layers.

    Position embeddings are looked up once; each layer then maps the pair
    ``(x, c)`` to a new pair by splitting the feed-forward output into its
    leading channels (next ``x``) and its last channel (next ``c``).

    Args:
        vocab_size: number of embedding rows (positions).
        dim: embedding width.
        num_heads: attention heads per layer.
        inner_dim_multiplier: hidden-width multiplier of the feed-forward blocks.
        num_transformers: number of stacked layers.
    """
    def __init__(self,vocab_size=115,dim=6,num_heads=2,inner_dim_multiplier=5,num_transformers=2):
        super(TransformersSeries, self).__init__()

        self.d_model = dim
        self.vocab_size = vocab_size
        self.num_transformers = num_transformers

        self.inp_embedding = nn.Embedding(vocab_size,dim)

        self.attention_blocks = nn.ModuleList([MultiHeadAttentionBlock(vocab_size,dim,num_heads) for _ in range(num_transformers)])

        self.feedforward_blocks = nn.ModuleList([FeedForwardBlock(dim,inner_dim_multiplier) for _ in range(num_transformers)])

        # Not used in forward(); kept so checkpoints that contain its weights
        # still load with strict key matching.
        self.linear_layer_1 = nn.Linear(vocab_size,vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise Linear/Embedding weights from N(0, 0.02^2); zero the biases."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self,c):
        """Run all layers on ``c`` and return the final side-channel.

        Args:
            c: rank-3 tensor ``(batch, positions, channels)``; the feed-forward
                blocks expect ``channels == 1``.

        Returns:
            Tensor of shape ``(batch, positions, 1)``. With
            ``num_transformers == 0`` the input ``c`` is returned unchanged
            (previously this case raised because no layer output existed).
        """
        batch_size, vocab_size, _ = c.size()

        # Build position ids on the device that holds the embedding table so
        # the model does not depend on a notebook-level `device` global.
        y = torch.arange(vocab_size, device=self.inp_embedding.weight.device)
        y = y.unsqueeze(0).expand(batch_size, -1)

        x = self.inp_embedding(y)

        # The series of transformer layers
        for attention_block, feedforward_block in zip(self.attention_blocks, self.feedforward_blocks):

            output_x, output_c = attention_block(x,c)

            output_y = feedforward_block(output_x,output_c)

            # Split the joint output back into features and side channel.
            x = output_y[:,:,:-1]
            c = output_y[:,:,-1].unsqueeze(-1)

        # After the last layer `c` is exactly output_y[:, :, -1].unsqueeze(-1).
        return c



In [9]:
# Load the held-out inputs/targets straight onto the selected device.
# map_location avoids a failure (or a CPU/GPU mismatch with the models) when
# the tensors were saved on a different backend than the one available here.
# NOTE: torch.load unpickles the file; only load files from a trusted source.
inps_test_ = torch.load('inputs_test.pt', map_location=device)
outs_test_ = torch.load('outputs_test.pt', map_location=device)


In [10]:
# Names of the model input channels (exchange identifiers), in tensor order.
input_cols = [
    'EX_glc__D_e', 'EX_fru_e', 'EX_lac__D_e', 'EX_pyr_e',
    'EX_ac_e', 'EX_akg_e', 'EX_succ_e', 'EX_fum_e',
    'EX_mal__L_e', 'EX_etoh_e', 'EX_acald_e', 'EX_for_e',
    'EX_gln__L_e', 'EX_glu__L_e', 'EX_co2_e', 'EX_h_e',
    'EX_h2o_e', 'EX_nh4_e', 'EX_o2_e', 'EX_pi_e',
]

# Names of the predicted flux channels, in tensor order.
output_cols = [
    'ACALD_flux', 'ACALDt_flux', 'ACKr_flux', 'ACONTa_flux', 'ACONTb_flux',
    'ACt2r_flux', 'ADK1_flux', 'AKGDH_flux', 'AKGt2r_flux', 'ALCD2x_flux',
    'ATPM_flux', 'ATPS4r_flux', 'Biomass_Ecoli_core_flux', 'CO2t_flux', 'CS_flux',
    'CYTBD_flux', 'D_LACt2_flux', 'ENO_flux', 'ETOHt2r_flux', 'EX_ac_e_flux',
    'EX_acald_e_flux', 'EX_akg_e_flux', 'EX_co2_e_flux', 'EX_etoh_e_flux', 'EX_for_e_flux',
    'EX_fru_e_flux', 'EX_fum_e_flux', 'EX_glc__D_e_flux', 'EX_gln__L_e_flux', 'EX_glu__L_e_flux',
    'EX_h_e_flux', 'EX_h2o_e_flux', 'EX_lac__D_e_flux', 'EX_mal__L_e_flux', 'EX_nh4_e_flux',
    'EX_o2_e_flux', 'EX_pi_e_flux', 'EX_pyr_e_flux', 'EX_succ_e_flux', 'FBA_flux',
    'FBP_flux', 'FORt2_flux', 'FORti_flux', 'FRD7_flux', 'FRUpts2_flux',
    'FUM_flux', 'FUMt2_2_flux', 'G6PDH2r_flux', 'GAPD_flux', 'GLCpts_flux',
    'GLNS_flux', 'GLNabc_flux', 'GLUDy_flux', 'GLUN_flux', 'GLUSy_flux',
    'GLUt2r_flux', 'GND_flux', 'H2Ot_flux', 'ICDHyr_flux', 'ICL_flux',
    'LDH_D_flux', 'MALS_flux', 'MALt2_2_flux', 'MDH_flux', 'ME1_flux',
    'ME2_flux', 'NADH16_flux', 'NADTRHD_flux', 'NH4t_flux', 'O2t_flux',
    'PDH_flux', 'PFK_flux', 'PFL_flux', 'PGI_flux', 'PGK_flux',
    'PGL_flux', 'PGM_flux', 'PIt2r_flux', 'PPC_flux', 'PPCK_flux',
    'PPS_flux', 'PTAr_flux', 'PYK_flux', 'PYRt2_flux', 'RPE_flux',
    'RPI_flux', 'SUCCt2_2_flux', 'SUCCt3_flux', 'SUCDi_flux', 'SUCOAS_flux',
    'TALA_flux', 'THD2_flux', 'TKT1_flux', 'TKT2_flux', 'TPI_flux',
]

# Inputs first, then outputs: used as hover labels for the plotted index axis.
all_data = [*input_cols, *output_cols]

In [11]:
# Rebuild the default architecture, then restore its trained weights.
trained_model = TransformersSeries()  # Must recreate model structure
# map_location lets a checkpoint saved on another backend load on this machine.
trained_model.load_state_dict(torch.load('trained_model.pth', map_location=device))
# The notebook only runs inference, so switch the model to evaluation mode.
trained_model.to(device).eval();

In [12]:
# Larger variant: vocab_size=115, dim=12, num_heads=3, inner_dim_multiplier=5,
# num_transformers=2 (must match the architecture the checkpoint was saved from).
trained_model_12_3_5_2 = TransformersSeries(115,12,3,5,2)
# map_location lets a checkpoint saved on another backend load on this machine.
trained_model_12_3_5_2.load_state_dict(torch.load('trained_model_12_3_5_2.pth', map_location=device))

# The notebook only runs inference, so switch the model to evaluation mode.
trained_model_12_3_5_2.to(device).eval();


In [13]:
# Parallel lists: models[i] is described by model_configs[i]. The plotting
# widgets select a model by its index into these lists.
models = [
    trained_model,
    trained_model_12_3_5_2,
]
model_configs = [
    'dmodel=6,num_heads=2,inner_dim_multiplier=5,num_transformers=2',
    'dmodel=12,num_heads=3,inner_dim_multiplier=5,num_transformers=2',
]

In [14]:
def plot_data(simulation,model):
    """Plot input, prediction and target for one test sample with matplotlib.

    Args:
        simulation: index of the sample along the first axis of ``inps_test_``
            and ``outs_test_``.
        model: index into ``models`` / ``model_configs``.

    Draws channel 0 of the input as a green stem plot and overlays the
    model prediction (red) and the target (black) on a fixed y-range.
    """
    sample = inps_test_[simulation,:,:]

    # Inference only: disable autograd so no graph is built for every redraw.
    with torch.no_grad():
        predicted = models[model](sample.unsqueeze(0))[0,:,0]

    plt.figure(figsize=(6*2, 4*2))
    plt.stem(sample[:,0].detach().cpu().numpy(),'g',label='Input Concentrations')
    plt.plot(predicted.cpu().numpy(),color='red',label='Predicted Concentrations')
    plt.plot(outs_test_[simulation,:,0].detach().cpu().numpy(),color='black',label='Actual Concentrations')
    plt.ylim([-50,100])
    plt.legend()
    plt.title(f"Simulation {simulation}:{model_configs[model]}")
    plt.grid()

In [15]:
# Slider-driven view of plot_data: one slider per argument, with the
# simulation slider spanning every sample in the test set.
interactive_plot = interactive(
    plot_data,
    simulation=(0, inps_test_.shape[0] - 1),
    model=(0,1),
)
# Last child of an `interactive` container is its output area.
output = interactive_plot.children[-1]
interactive_plot

interactive(children=(IntSlider(value=14932, description='simulation', max=29864), IntSlider(value=0, descript…

In [16]:


# def plot_data_interact(simulation,model):
#     j = simulation 
#     xx = np.array([i for i in range(20)])


#     # Create plotly figure
#     fig_plotly = go.Figure()
#     inputs = inps_test_[j,:20,:].unsqueeze(0)
#     inp = inputs[0,:,0].cpu().detach().numpy()

#     pred_ops = models[model](inps_test_[j,:,:].unsqueeze(0))
#     target_ops = outs_test_[j,:,:].unsqueeze(0)

#     fig_plotly.add_trace(go.Scatter(
#         x=xx,
#         y=inp,
#         mode='markers',
#         marker=dict(
#             size=12,
#             color='green',
#         ),
#         text=[f'{all_data[i]}' 
#                 for i in xx],
#         hovertemplate='%{text}<extra></extra>',
#         name='Input Concentrations'
#     ))

#     xxx = np.array([i for i in range(115)])

#     fig_plotly.add_trace(go.Scatter(
#         x=xxx, 
#         y = pred_ops[0,:,0].cpu().detach().numpy(), 
#         mode='lines+markers', 
#         text = [f'{all_data[i]}' 
#                 for i in xxx],
#         hovertemplate='%{text}<extra></extra>',
#         name='Predicted Concentrations'
#         ))

#     fig_plotly.add_trace(go.Scatter(
#         x=xxx, 
#         y = target_ops[0,:,0].cpu().detach().numpy(), 
#         mode='lines+markers', 
#         text = [f'{all_data[i]}' 
#                 for i in xxx],
#         hovertemplate='%{text}<extra></extra>',
#         name='Predicted Concentrations'
#         ))

#     fig_plotly.show()


In [17]:
# interactive_plot = interactive(plot_data_interact, simulation=(0, inps_test_.shape[0] - 1),model=(0,1))
# output = interactive_plot.children[-1]
# display(interactive_plot)

In [None]:
# import ipywidgets as widgets
# import plotly.graph_objects as go
# import numpy as np

# Empty, wide plotly FigureWidget; traces are filled in by the update callback.
fig_widget = go.FigureWidget(
    layout=go.Layout(
        title='Interactive Plot',
        width=1700,
        height=800,
        # Fixed y-range so samples stay visually comparable.
        yaxis={'range': [-50, 100]},
        # One tick every 10 index positions, starting at 0.
        xaxis={'tickmode': 'linear', 'tick0': 0, 'dtick': 10},
        # Horizontal legend, centred just above the plotting area.
        legend={
            'orientation': "h",
            'x': 0.5,
            'xanchor': "center",
            'y': 1.02,
            'yanchor': "bottom",
        },
    )
)

def update_plot(simulation, model):
    """Redraw ``fig_widget`` for one test sample and one model.

    Args:
        simulation: index of the sample along the first axis of ``inps_test_``
            and ``outs_test_``.
        model: index into ``models``.

    The x-axis is the position of each variable in ``all_data`` (inputs first,
    then outputs); the variable name is shown on hover.
    """
    sample = inps_test_[simulation, :, :]

    # Derive the sizes from the data instead of hard-coding 20 / 115.
    n_inputs = len(input_cols)
    n_vars = sample.shape[0]

    # Inference only: disable autograd so no graph is built on every update.
    with torch.no_grad():
        predicted = models[model](sample.unsqueeze(0))[0, :, 0].cpu().numpy()
    inp = sample[:n_inputs, 0].detach().cpu().numpy()
    target = outs_test_[simulation, :, 0].detach().cpu().numpy()

    input_x = np.arange(n_inputs)
    all_x = np.arange(n_vars)
    input_names = [all_data[i] for i in range(n_inputs)]
    all_names = [all_data[i] for i in range(n_vars)]

    # Clear existing traces
    fig_widget.data = []

    # Add traces
    fig_widget.add_trace(go.Scatter(
        x=input_x,
        y=inp,
        mode='markers',
        marker=dict(size=12, color='purple'),
        text=input_names,
        hovertemplate='%{text}<extra></extra>',
        name='Input Concentrations'
    ))

    fig_widget.add_trace(go.Scatter(
        x=all_x,
        y=predicted,
        mode='lines+markers',
        text=all_names,
        hovertemplate='%{text}<extra></extra>',
        name='Predicted Concentrations'
    ))

    fig_widget.add_trace(go.Scatter(
        x=all_x,
        y=target,
        mode='lines+markers',
        text=all_names,
        hovertemplate='%{text}<extra></extra>',
        name='Target Concentrations'
    ))

    # Update layout. The x values are indices into all_data, not time steps,
    # so label the axis accordingly.
    fig_widget.update_layout(
        title=f'Simulation {simulation} - Model {model}',
        xaxis_title='Variable index',
        yaxis_title='Concentrations'
    )

# Controls for update_plot: a slider over every test sample and a dropdown
# whose values are indices into `models`.
interactive_plot = widgets.interactive(
    update_plot,
    simulation=widgets.IntSlider(
        value=0,
        min=0,
        max=len(inps_test_) - 1,
        step=1,
        description='Simulation:',
    ),
    model=widgets.Dropdown(
        options=[
            ('d_model: 6, num_heads: 2, inner_dim_multiplier: 5, num_transformers: 2', 0),
            ('d_model: 12, num_heads: 3, inner_dim_multiplier: 5, num_transformers: 2', 1),
        ],
        value=0,
        description='Model:',
    ),
)

# Show the controls first and the figure they drive underneath.
display(interactive_plot, fig_widget)

interactive(children=(IntSlider(value=0, description='Simulation:', max=29864), Dropdown(description='Model:',…

FigureWidget({
    'data': [{'hovertemplate': '%{text}<extra></extra>',
              'marker': {'color': 'purple', 'size': 12},
              'mode': 'markers',
              'name': 'Input Concentrations',
              'text': [EX_glc__D_e, EX_fru_e, EX_lac__D_e, EX_pyr_e, EX_ac_e,
                       EX_akg_e, EX_succ_e, EX_fum_e, EX_mal__L_e, EX_etoh_e,
                       EX_acald_e, EX_for_e, EX_gln__L_e, EX_glu__L_e, EX_co2_e,
                       EX_h_e, EX_h2o_e, EX_nh4_e, EX_o2_e, EX_pi_e],
              'type': 'scatter',
              'uid': 'a9a48613-af52-4c76-8f70-c4fb287b2a40',
              'x': {'bdata': 'AAECAwQFBgcICQoLDA0ODxAREhM=', 'dtype': 'i1'},
              'y': {'bdata': ('AAAAAAAAAAAAAAAAAAAAAAAAAAAAAA' ... 'hCAABIQgAASEK4HiVBZmYxQkjh6kA='),
                    'dtype': 'f4'}},
             {'hovertemplate': '%{text}<extra></extra>',
              'mode': 'lines+markers',
              'name': 'Predicted Concentrations',
              'text': [EX_glc

In [19]:
# pred_outs_ = models[0](inps_test_)
# pred_outs_.size()

In [20]:
# outs_test_.size()

In [21]:
def plot_pred_actual(metabolite,simulation,model):
    """Scatter predicted vs. actual values of one variable over the test set.

    Args:
        metabolite: index along the second axis of the prediction/target tensors.
        simulation: index of the sample to highlight in black.
        model: index into ``models``.

    Predictions for the whole test set are computed once per model object and
    reused on later calls, because only the slicing changes when a slider
    moves. The cache is invalidated when ``models[model]`` or ``inps_test_``
    is rebound to a new object; after changing a model's weights in place,
    call ``plot_pred_actual._pred_cache.clear()``.
    """
    cache = plot_pred_actual.__dict__.setdefault('_pred_cache', {})
    net = models[model]
    cached = cache.get(net)
    if cached is None or cached[0] is not inps_test_:
        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            cached = (inps_test_, net(inps_test_).cpu().numpy())
        cache[net] = cached
    pred_outs_ = cached[1]

    actual_all = outs_test_[:,metabolite,:].detach().cpu().numpy()

    plt.scatter(pred_outs_[:,metabolite,:],actual_all,color='r',alpha=0.1)
    plt.scatter(pred_outs_[simulation,metabolite,:],actual_all[simulation],color='k',alpha=0.9)
    plt.xlim([-50,100])
    plt.ylim([-50,100])
    plt.grid()
    plt.xlabel('Predicted Output')
    plt.ylabel('Actual Output')
    plt.title('Actual Concentrations vs Predicted Concentrations')
    plt.show()

In [22]:
# Slider-driven view of plot_pred_actual; slider ranges follow the shape of
# the test inputs (axis 1 for the variable, axis 0 for the sample).
interactive_plot = interactive(
    plot_pred_actual,
    metabolite=(0, inps_test_.shape[1] - 1),
    simulation=(0, inps_test_.shape[0] - 1),
    model=(0,1),
)
# Last child of an `interactive` container is its output area.
output = interactive_plot.children[-1]
interactive_plot

interactive(children=(IntSlider(value=57, description='metabolite', max=114), IntSlider(value=14932, descripti…

In [23]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import ipywidgets as widgets
from IPython.display import display

def create_plot_data(metabolite, simulation, model):
    """Collect the arrays needed for the predicted-vs-actual scatter plot.

    Args:
        metabolite: index along the second axis of the prediction/target tensors.
        simulation: index of the sample to highlight.
        model: index into ``models``.

    Returns:
        ``(pred_all, actual_all, pred_sim, actual_sim, sim_indices_all,
        sim_indices_sim)``: numpy arrays for every sample and for the
        highlighted sample, plus lists of sample indices (one per plotted
        point) used as hover data.

    Predictions for the whole test set are computed once per model object and
    reused, since only the slicing depends on the widget values. The cache is
    invalidated when ``models[model]`` or ``inps_test_`` is rebound; after
    changing a model's weights in place, call
    ``create_plot_data._pred_cache.clear()``. The returned prediction arrays
    are views into the cached array and should not be modified in place.
    """
    cache = create_plot_data.__dict__.setdefault('_pred_cache', {})
    net = models[model]
    cached = cache.get(net)
    if cached is None or cached[0] is not inps_test_:
        # Inference only: no autograd graph is needed for plotting.
        with torch.no_grad():
            cached = (inps_test_, net(inps_test_).cpu().numpy())
        cache[net] = cached
    pred_outs_ = cached[1]

    # Slice out the requested variable
    pred_all = pred_outs_[:, metabolite, :]
    actual_all = outs_test_[:, metabolite, :].detach().cpu().numpy()
    pred_sim = pred_outs_[simulation, metabolite, :]
    actual_sim = actual_all[simulation]

    # Sample index of every flattened point, for hover text
    sim_indices_all = np.repeat(np.arange(pred_all.shape[0]), pred_all.shape[1]).tolist()

    sim_indices_sim = [simulation] * pred_sim.size

    return pred_all, actual_all, pred_sim, actual_sim, sim_indices_all, sim_indices_sim

# Build the scatter FigureWidget once; later widget changes only overwrite the
# data of its two traces.
fig = go.FigureWidget()

# Seed the figure with variable 0, sample 0, model 0.
pred_all, actual_all, pred_sim, actual_sim, sim_indices_all, sim_indices_sim = create_plot_data(0, 0, 0)

# Trace 0: every test sample (red, mostly transparent).
fig.add_trace(go.Scatter(
    name='Metabolite 0',
    mode='markers',
    x=pred_all.flatten(),
    y=actual_all.flatten(),
    customdata=sim_indices_all,
    marker={'color': 'red', 'opacity': 0.1, 'size': 6},
    hovertemplate=(
        '<b>Metabolite 0</b><br>'
        'Simulation: %{customdata}<br>'
        'Predicted: %{x:.2f}<br>'
        'Actual: %{y:.2f}<br>'
        '<extra></extra>'
    ),
))

# Trace 1: the highlighted sample (black, nearly opaque).
fig.add_trace(go.Scatter(
    name='Simulation 0',
    mode='markers',
    x=pred_sim.flatten(),
    y=actual_sim.flatten(),
    customdata=sim_indices_sim,
    marker={'color': 'black', 'opacity': 0.9, 'size': 8},
    hovertemplate=(
        '<b>Simulation 0</b><br>'
        'Simulation: %{customdata}<br>'
        'Predicted: %{x:.2f}<br>'
        'Actual: %{y:.2f}<br>'
        '<extra></extra>'
    ),
))

# Square-ish canvas with identical fixed ranges on both axes.
fig.update_layout(
    title='Input Output Characteristics',
    width=800,
    height=700,
    showlegend=True,
    hovermode='closest',
    xaxis_title='Predicted Output',
    yaxis_title='Actual Output',
    xaxis={'range': [-50, 100], 'showgrid': True},
    yaxis={'range': [-50, 100], 'showgrid': True},
)

# Controls for the scatter figure. Slider ranges follow the test-input shape:
# axis 1 for the variable index, axis 0 for the sample index.
metabolite_slider = widgets.IntSlider(
    description='Metabolite:',
    value=0,
    min=0,
    max=inps_test_.shape[1] - 1,
    step=1,
)

simulation_slider = widgets.IntSlider(
    description='Simulation:',
    value=0,
    min=0,
    max=inps_test_.shape[0] - 1,
    step=1,
)

# Dropdown values are indices into `models`.
model_dropdown = widgets.Dropdown(
    options=[
        ('d_model: 6, num_heads: 2, inner_dim_multiplier: 5, num_transformers: 2', 0),
        ('d_model: 12, num_heads: 3, inner_dim_multiplier: 5, num_transformers: 2', 1),
    ],
    value=0,
    description='Model:',
)

def update_plot(*args):
    """Observer callback: refresh both traces of ``fig`` from the widget values.

    The positional arguments (the change notification passed by ``observe``)
    are ignored; the current slider/dropdown values are read directly.
    """
    metabolite = metabolite_slider.value
    simulation = simulation_slider.value

    pred_all, actual_all, pred_sim, actual_sim, sim_indices_all, sim_indices_sim = create_plot_data(
        metabolite,
        simulation,
        model_dropdown.value
    )

    # Shared tail of both hover templates.
    hover_tail = (
        'Simulation: %{customdata}<br>'
        'Predicted: %{x:.2f}<br>'
        'Actual: %{y:.2f}<br>'
        '<extra></extra>'
    )

    # Overwrite the existing traces inside one batch so the widget receives a
    # single combined update.
    with fig.batch_update():
        all_points = fig.data[0]
        all_points.x = pred_all.flatten()
        all_points.y = actual_all.flatten()
        all_points.customdata = sim_indices_all
        all_points.name = f'Metabolite {metabolite}'
        all_points.hovertemplate = f'<b>Metabolite {metabolite}</b><br>' + hover_tail

        highlighted = fig.data[1]
        highlighted.x = pred_sim.flatten()
        highlighted.y = actual_sim.flatten()
        highlighted.customdata = sim_indices_sim
        highlighted.name = f'Simulation {simulation}'
        highlighted.hovertemplate = f'<b>Simulation {simulation}</b><br>' + hover_tail

# Re-run update_plot whenever any control's value changes.
for control in (metabolite_slider, simulation_slider, model_dropdown):
    control.observe(update_plot, names='value')

# Stack the controls above the figure and show them as one widget.
interactive_plot = widgets.VBox([metabolite_slider, simulation_slider, model_dropdown, fig])

display(interactive_plot)

VBox(children=(IntSlider(value=0, description='Metabolite:', max=114), IntSlider(value=0, description='Simulat…