In [None]:
import random

import dash_bootstrap_components as dbc
import numpy as np
import pandas as pd
import plotly.express as px

from rubicon_ml import Rubicon
from rubicon_ml.viz.base import VizBase
from rubicon_ml.viz.dataframe_plot import DataframePlot, _register_callbacks as _df_register_callbacks
from rubicon_ml.viz.metric_lists_comparison import CompareMetricLists, _register_callbacks as _m_register_callbacks

In [None]:
# Seed the stdlib RNG so a fresh "Restart & Run All" regenerates the same example data.
random.seed(42)

rubicon = Rubicon(persistence="memory", auto_git_enabled=True)
# NOTE(review): if the in-memory store is shared across Rubicon instances within one
# kernel, re-running this cell appends another batch of experiments to the same
# project rather than starting fresh — confirm before relying on experiment counts.
project = rubicon.get_or_create_project("list metric comparison")

num_experiments_to_log = 5
num_list_values = 25          # length of each logged list metric (intercept + 24 variables)
max_open_accounts = 15000     # upper bound for the random start/end of each series
list_metric_names = ["coefficients", "p-values", "stderr"]

# Random (start, stop) endpoints for each experiment's linear "open accounts" series.
data_ranges = [
    (random.randint(0, max_open_accounts), random.randint(0, max_open_accounts))
    for _ in range(num_experiments_to_log)
]
# Month-start dates from Jan 2010 through Dec 2020.
dates = pd.date_range(start="1/1/2010", end="12/1/2020", freq="MS")

for start, stop in data_ranges:
    experiment = project.log_experiment()

    # Each list metric is a vector of uniform random values in [0, 1).
    for metric_name in list_metric_names:
        experiment.log_metric(
            name=metric_name,
            value=[random.random() for _ in range(num_list_values)],
        )

    # Build the frame column-wise so "calendar month" stays datetime64 and
    # "open accounts" stays float64 (stacking them via np.array forces object dtype).
    data_df = pd.DataFrame(
        {
            "calendar month": dates,
            "open accounts": np.linspace(start, stop, len(dates)),
        }
    )
    experiment.log_dataframe(data_df, tags=["open accounts"])

In [None]:
dashboard_container = VizBase(dash_title="rubicon-ml dashboard")

# Fetch the logged experiments once so both panels are built from the same snapshot.
experiments = project.experiments()

# One label per logged coefficient: the first is the intercept, the rest are
# var_001, var_002, ... The count is read from the logged metric rather than
# hard-coded, so the labels cannot drift out of sync with the data-logging cell.
# NOTE(review): assumes every experiment logged a "coefficients" list of the same
# length as the first experiment's — confirm if experiments can differ.
num_coefficients = len(experiments[0].metric(name="coefficients").value)
coefficient_names = ["intercept"] + [f"var_{i:03}" for i in range(1, num_coefficients)]

compare_metric_lists = CompareMetricLists(
    experiments,
    selected_metric="coefficients",
    column_names=coefficient_names,
)
_m_register_callbacks(dashboard_container.app)

# Positional args presumably: plotting function, dataframe tags to select,
# x column, y column — TODO confirm against the DataframePlot signature.
dataframe_plot = DataframePlot(
    experiments,
    px.line,
    ["open accounts"],
    "calendar month",
    "open accounts",
)
_df_register_callbacks(dashboard_container.app)

In [None]:
# NOTE(review): the previous cell registered callbacks on the container's original
# `app`; this line replaces it with `compare_metric_lists.app`. If those are
# different objects, the earlier registrations may not carry over — confirm
# against the VizBase API.
dashboard_container.app = compare_metric_lists.app

# Two half-width columns side by side: the dataframe plot on the left and the
# metric-list comparison on the right.
side_by_side_columns = [
    dbc.Col(dataframe_plot._build_layout(), width=6),
    dbc.Col(compare_metric_lists._build_layout(), width=6),
]
dashboard_frame = dashboard_container._build_frame(side_by_side_columns)
dashboard_container.app.layout = dashboard_frame

dashboard_container.run_server_inline(i_frame_kwargs={"height": "600px"})