In [19]:
import dash
from dash import dcc
from dash import html
from dash.dependencies import Input, Output
import plotly.express as px
import numpy as np
import torch
import pandas as pd
import os
import pathlib
import matplotlib.pyplot as plt
import seaborn as sns

In [20]:
%reload_ext autoreload
%autoreload 2

In [21]:
import sys
# Make the local `branching_model` sources importable for the next cell.
# Guarded so that re-running this cell does not keep appending duplicate
# entries to sys.path (the bare append was not idempotent).
# NOTE(review): path is relative to the notebook's working directory; an older
# layout used "../branching_model/" instead — adjust if the notebook moves.
BRANCHING_MODEL_DIR = "../../branching_model/"
if BRANCHING_MODEL_DIR not in sys.path:
    sys.path.append(BRANCHING_MODEL_DIR)


In [24]:
# from branching_model.Agent import * 
from branching_model.Agent import Agent, N_TREATMENTS 
from branching_model.test_plasticity import test_adaptation, test_treatment

In [28]:
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import torch
import numpy as np
import plotly.graph_objs as go

app = dash.Dash(__name__)

# Page layout: a heading, one button per simulation run, then the output graph.
# Each button's component id is wired as an Input of the `update_plot` callback,
# so ids here must stay in sync with that callback's Input list.
app.layout = html.Div(
    [html.H1("Cell Plasticity Dashboard")]
    + [
        html.Button(label, id=component_id)
        for label, component_id in (
            ("Run No Treatment", "btn-no-treatment"),
            ("Run Treatment 1", "btn-treatment-1"),
            ("Run Treatment 2", "btn-treatment-2"),
            ("Run Dual Treatments", "btn-dual-treatments"),
            ("Run Sequential Treatment", "btn-sequential-treatment"),
        )
    ]
    + [dcc.Graph(id='plot')]
)

@app.callback(Output('plot', 'figure'),
              Input('btn-no-treatment', 'n_clicks'),
              Input('btn-treatment-1', 'n_clicks'),
              Input('btn-treatment-2', 'n_clicks'),
              Input('btn-dual-treatments', 'n_clicks'),
              Input('btn-sequential-treatment', 'n_clicks'))
def update_plot(btn_no_treatment, btn_treatment_1, btn_treatment_2, btn_dual_treatments, btn_sequential_treatment):
    """Redraw the ``plot`` graph for whichever run button fired the callback.

    The five arguments are the buttons' ``n_clicks`` values supplied by Dash;
    only the identity of the triggering button is used, read from
    ``dash.callback_context``. When nothing has been triggered (initial page
    load) the no-treatment run is shown.

    Returns:
        A figure dict with a flat ``'data'`` list of traces and a ``'layout'``
        carrying the run's title, suitable for ``dcc.Graph.figure``.
    """
    ctx = dash.callback_context
    # Dash may report a placeholder trigger whose prop_id is '.', which splits to
    # an empty id; treat that exactly like "nothing triggered".
    button_id = ctx.triggered[0]['prop_id'].split('.')[0] if ctx.triggered else ''
    button_id = button_id or 'btn-no-treatment'

    # button id -> (plot title, indices of treatments applied at full dose 1.0)
    single_runs = {
        'btn-no-treatment': ("No treatment", ()),
        'btn-treatment-1': ("Treatment 1", (0,)),
        'btn-treatment-2': ("Treatment 2", (1,)),
        'btn-dual-treatments': ("Dual treatments", (0, 1)),
    }

    plot_title = ""
    figures = []

    if button_id in single_runs:
        plot_title, active_treatments = single_runs[button_id]
        # A single (1, N_TREATMENTS) row of doses, so every run matches the
        # treatment count used by the model instead of a hard-coded length 2.
        doses = torch.zeros(1, N_TREATMENTS)
        for treatment_idx in active_treatments:
            doses[0, treatment_idx] = 1.0
        figures.append(generate_plot(doses, plot_title))
    elif button_id == 'btn-sequential-treatment':
        plot_title = "Sequential Treatment"
        # TODO: implement the sequential-treatment simulation and append its figure.

    # generate_plot returns {'data': [...], 'layout': ...}; subscript access also
    # works on go.Figure. Flatten every figure's traces into one list, since a
    # nested list of trace lists is not a valid figure 'data' value.
    traces = [trace for fig in figures for trace in fig['data']]
    return {'data': traces, 'layout': {'title': plot_title}}

def generate_plot(doses, title):
    """Build a placeholder figure dict with one line trace per trait.

    The traces are, in order: 'Example', 'S', then 'R1'..'R<N_TREATMENTS>'.

    NOTE(review): ``doses`` is not used yet. Every trace plots the same
    hard-coded sample points. TODO: replace with the simulated trait values
    for the given doses.

    Returns:
        dict with ``'data'`` (list of ``go.Scatter`` line traces) and
        ``'layout'`` (``go.Layout`` with the given ``title``).
    """
    sample_x, sample_y = [1, 2, 3], [4, 1, 2]
    trace_names = ['Example', 'S'] + [f"R{k}" for k in range(1, N_TREATMENTS + 1)]
    traces = [
        go.Scatter(x=sample_x, y=sample_y, mode='lines', name=trace_name)
        for trace_name in trace_names
    ]
    return {'data': traces, 'layout': go.Layout(title=title)}