# Curve Boxplots

TODO: 

- <s>Roll `query_base` so my eyes work again.</s>

- <s>Consolidate step 1.</s>

- <s>Separate grouping and centrality metrics.</s> 

- <s>Band depth/MBD implementation: O(n<sup>3</sup>)? Centrality only, not distance.</s> 

- <s>Allow for efficiency without grouping.</s>

- Optimization: do the math on runtime reductions from clustering/partitioning. Partitioning by date might be useless? `PERCENTILE_CONT` &#8594; `APPROX_QUANTILES`?

- Write graphing package to layer graphs.

- Grouping by derivative? Is this stupid?

In [1]:
from google.cloud import bigquery
from google.oauth2 import service_account
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import matplotlib.colors as mcolors
from bin_builder import build_country_query

ModuleNotFoundError: No module named 'bin_builder'

In [None]:
# Set parameters for grouping
# These values are interpolated into the SQL strings built in the cells below.
num_clusters = 2  # Number of K-means clusters; 1 skips K-means and puts every run in centroid 1
num_features = 5   # Only distances to runs with run_id_b <= num_features are used as K-means features
grouping_method = 'mse' # 'mse' or 'abc': which curve_distances column feeds the clustering
centrality_method = 'mbd' # 'mse' for mean squared error,'abc' for area between the curves or 'mbd' for modified band depth

In [None]:
# helper function for plotting
def color_to_rgba(color_name, value, alpha):
    """Blend a named color toward light gray and format it as an ``rgba(...)`` string.

    Args:
        color_name: Color specification passed to ``mcolors.to_rgb``.
        value: Blend weight. Each RGB channel is multiplied by ``value`` and
            a gray offset of ``(1 - value) * 0.8`` is added, so ``1`` keeps
            the pure color and ``0`` yields a uniform light gray.
        alpha: Opacity, written verbatim as the fourth component.

    Returns:
        A string ``'rgba(R, G, B, alpha)'`` whose channels are scaled by 255
        and truncated with ``int``.
    """
    gray_offset = (1 - value) * 0.8
    red, green, blue = (
        int((channel * value + gray_offset) * 255)
        for channel in mcolors.to_rgb(color_name)
    )
    return f'rgba({red}, {green}, {blue}, {alpha})'

In [None]:
# the location of our credentials json and name of bigquery project
# NOTE(review): the key path is an absolute, machine-specific Windows path; this cell
# will fail on any other machine until the path is changed.
credentials = service_account.Credentials.from_service_account_file('C:\\Users\\elija\\Documents\\24f-coop\\credentials.json')
project = 'net-data-viz-handbook'

# Initialize a BigQuery client (every query below is submitted through `client`)
client = bigquery.Client(credentials=credentials, project=project)

## Procedurally generating queries

In [None]:
# Inputs for the procedurally generated query; build_country_query comes from the
# project-local bin_builder module.
# NOTE(review): `q` is only displayed here and is not used by the later cells in view,
# which rely on the hand-written query_base instead — confirm which one is intended.
table = 'net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily'
country_ids = [215] 
run_ids = 'all'
min_age = 0
max_age = 17
categories = ['Infectious']
q = build_country_query(table, country_ids, run_ids, min_age, max_age, categories, grouped=True)
q

In [None]:
# Builds the working table sri_data.infectious_data: one row per (date, country_id, run_id)
# holding a trailing 7-row average of the summed Infectious_13_17 + Infectious_18_23 columns,
# restricted to country 218 and runs 1-100.
# NOTE(review): the ROWS window is a 7-day window only if each run has exactly one row
# per date with no gaps — confirm against the source table.
query_base =  """
CREATE OR REPLACE TABLE sri_data.infectious_data
PARTITION BY date
CLUSTER BY country_id, run_id AS
SELECT 
    date,
    country_id,
    run_id,
    -- Calculate 7-day rolling average of total_infectious
    AVG(SUM(Infectious_13_17 + Infectious_18_23)) OVER (
        PARTITION BY country_id, run_id 
        ORDER BY date 
        ROWS BETWEEN 6 PRECEDING AND CURRENT ROW
    ) AS total_infectious
FROM `net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily`
WHERE country_id IN (218)
  AND run_id BETWEEN 1 AND 100
GROUP BY date, country_id, run_id;
"""

# .result() waits for the CREATE TABLE job to finish; `df` is the job's result object,
# not a pandas DataFrame.
df = client.query(query_base).result()  # Execute the query to create the table
print("Data sliced successfully.")

## KMeans and a prayer (for creating grouped curve-based statistics)
- Uses the first `num_features` columns of the distance matrix to simplify the KMeans algorithm, which is O(n<sup>2</sup>)?
    - Might just not need PCA or similar
- Procedural once the distance matrix is computed (e.g., the main table has to be queried manually)

In [None]:
# Step 1: Create the curve distance table
# Self-joins infectious_data on date, so every ordered pair of runs (including a run
# with itself) gets one row with two distances over the shared dates:
#   mse = mean of squared differences, abc = sum of absolute differences.
# The join key is date only, so this assumes infectious_data holds a single country_id
# (query_base filters to one country).
query_distances = """
CREATE OR REPLACE TABLE sri_data.curve_distances AS
SELECT
    a.run_id AS run_id_a,
    b.run_id AS run_id_b,
    AVG(POW(a.total_infectious - b.total_infectious, 2)) AS mse,
    SUM(ABS(a.total_infectious - b.total_infectious)) AS abc
FROM
    sri_data.infectious_data a
JOIN
    sri_data.infectious_data b
ON
    a.date = b.date
GROUP BY
    run_id_a, run_id_b
"""
client.query(query_distances).result()  # Execute the query to create the table
print("Curve distance table created successfully.")

In [None]:
# Step 3: Create the distance matrix
# Collapses curve_distances to one row per run: `distances` is an array of
# (run_id_b, <grouping_method>) structs aggregated in ascending run_id_b order.
# grouping_method is interpolated as a column name, so it must be 'mse' or 'abc'.
query_distance_matrix = f"""
CREATE OR REPLACE TABLE sri_data.distance_matrix
CLUSTER BY run_id AS --optimizations
SELECT
    run_id_a AS run_id,
        ARRAY_AGG(STRUCT(run_id_b, {grouping_method}) ORDER BY run_id_b ASC) AS distances
FROM
    sri_data.curve_distances
GROUP BY
    run_id_a;
"""
client.query(query_distance_matrix).result()  # Execute the query to create the table
print("Distance matrix table created successfully.")

In [None]:
# Step 4: Assign every run to a cluster and store the result in sri_data.kmeans_results
# (columns used downstream: run_id, CENTROID_ID).
if num_clusters == 1:
    # Degenerate case: no model is needed, every run goes to centroid 1.
    query_assign_all_to_one_cluster = """
    CREATE OR REPLACE TABLE sri_data.kmeans_results
    CLUSTER BY CENTROID_ID, run_id AS
    SELECT DISTINCT
        run_id,
        1 AS centroid_id  -- Assign all runs to centroid 1
    FROM 
        sri_data.distance_matrix
    """
    s = time.time()
    client.query(query_assign_all_to_one_cluster).result()  # Execute the query to assign all runs to centroid 1
    print(f"All runs assigned to centroid 1 successfully in {round(time.time() - s, 3)} seconds.")

else:
    # Feature vector shared by training and prediction so the two can never drift apart:
    # the distances from each run to the runs with run_id_b <= num_features.
    # ORDER BY is required because UNNEST does not guarantee element order; without it
    # the same array position could refer to different reference runs on different rows.
    features_expr = f"""ARRAY(
            SELECT distance.{grouping_method}
            FROM UNNEST(distances) AS distance
            WHERE distance.run_id_b <= {num_features}  -- Select only the first num_features
            ORDER BY distance.run_id_b
        ) AS features"""

    # Step 4a: Create the K-means model.
    # run_id is deliberately NOT selected here: every column of a BigQuery ML training
    # query is treated as a feature, and clustering on an arbitrary identifier would
    # bias the groups.
    query_create_model = f"""
    CREATE OR REPLACE MODEL sri_data.kmeans_model
    OPTIONS(model_type='kmeans', num_clusters={num_clusters}) AS
    SELECT
        {features_expr}
    FROM
        sri_data.distance_matrix;
    """
    s = time.time()
    client.query(query_create_model).result()  # Execute the model creation
    print(f"K-means model created successfully in {round(time.time() - s, 3)} seconds.")

    # Step 5: Apply K-means clustering and save results in a table.
    # run_id is not a model feature, so ML.PREDICT passes it through unchanged next to
    # the predicted CENTROID_ID.
    query_kmeans = f"""
    CREATE OR REPLACE TABLE sri_data.kmeans_results
    CLUSTER BY CENTROID_ID, run_id AS
    SELECT
        *
    FROM
        ML.PREDICT(MODEL sri_data.kmeans_model,
            (SELECT
                run_id,
                {features_expr}
             FROM
                sri_data.distance_matrix)
        ) AS predictions
    """
    s = time.time()
    client.query(query_kmeans).result()  # Execute the prediction and write the results table
    print(f"K-means clustering results saved successfully in {round(time.time() - s, 3)} seconds.")

In [None]:
# revised step 6, gets summed distances for abc and mse
# For each run, averages the chosen distance column over all runs assigned to the same
# centroid and stores it as total_distance (lower = more central).
# The f-string is always built, but for centrality_method == 'mbd' it would reference a
# column that curve_distances does not have; the guard below keeps it from running then.
s = time.time()
save_sum_distances = f"""CREATE OR REPLACE TABLE `sri_data.total_distances_table`
    CLUSTER BY CENTROID_ID, run_id AS
    WITH a AS (
        SELECT 
            kr.CENTROID_ID,
            kr.run_id, 
            run_id_b, 
            {centrality_method}  
        FROM 
            `sri_data.kmeans_results` AS kr
        JOIN 
            `sri_data.curve_distances` AS dm
        ON 
            kr.run_id = dm.run_id_a
--        CROSS JOIN 
--            UNNEST(dm.distances) AS dm_dist  -- Unnest the distances array here 
    ),
    b AS (
        SELECT
            run_id AS run_id_b, 
            CENTROID_ID AS CENTROID_ID_B
        FROM
            `sri_data.kmeans_results`
    )
    SELECT 
        a.run_id,
        a.CENTROID_ID,
        AVG({centrality_method}) AS total_distance  
    FROM 
        a
    JOIN 
        b
    ON 
        a.run_id_b = b.run_id_b
    WHERE
        a.CENTROID_ID = b.CENTROID_ID_B
    GROUP BY
        a.CENTROID_ID,
        a.run_id;
    """
if centrality_method in ['abc', 'mse']:
    client.query(save_sum_distances).result()  # Execute the query to (re)create total_distances_table
    print(f"Distance sum results using {centrality_method.upper()} saved successfully in {round(time.time()-s, 3)}.")

In [None]:
# Modified band depth (used when centrality_method == 'mbd').
# For every run `a` and every unordered pair of other runs (b, c) in the same cluster,
# count the dates on which a's value lies between b's and c's values. The stored
# total_distance is the reciprocal of that count, so lower = more central, matching the
# convention of the 'abc'/'mse' table.
s = time.time()
mbd = """
CREATE OR REPLACE TABLE `sri_data.total_distances_table`
CLUSTER BY CENTROID_ID, run_id AS
WITH curve_data AS (
    -- GROUP BY already yields one row per (date, run, band); no DISTINCT or ORDER BY is
    -- needed inside the CTE.
    SELECT
        a.date AS date,
        a.run_id AS run_id,
        kra.CENTROID_ID AS CENTROID_ID,
        b.run_id AS boundary_1_id,
        c.run_id AS boundary_2_id,
        MAX(a.total_infectious) AS curve,
        MAX(b.total_infectious) AS boundary_1,
        MAX(c.total_infectious) AS boundary_2
    FROM
        sri_data.infectious_data AS a
    JOIN
        sri_data.infectious_data AS b ON a.date = b.date
    JOIN
        sri_data.infectious_data AS c ON a.date = c.date
    JOIN
        sri_data.kmeans_results AS kra ON a.run_id = kra.run_id
    JOIN
        sri_data.kmeans_results AS krb ON b.run_id = krb.run_id
    JOIN
        sri_data.kmeans_results AS krc ON c.run_id = krc.run_id

    WHERE
        b.run_id < c.run_id
        AND a.run_id != b.run_id
        AND a.run_id != c.run_id
        AND kra.CENTROID_ID = krb.CENTROID_ID
        AND kra.CENTROID_ID = krc.CENTROID_ID

    GROUP BY
      a.date, a.run_id, kra.CENTROID_ID, b.run_id, c.run_id
)

SELECT
    run_id,
    CENTROID_ID,
    -- COUNTIF instead of WHERE + COUNT(*): a run that is never inside any band must stay
    -- in the table. IEEE_DIVIDE returns +inf for a zero count, so such a run sorts last
    -- and falls outside the IQR fences instead of silently disappearing.
    IEEE_DIVIDE(1, COUNTIF(
        (boundary_1 <= curve AND curve <= boundary_2)
        OR (boundary_2 <= curve AND curve <= boundary_1)
    )) AS total_distance
FROM curve_data
GROUP BY run_id, CENTROID_ID
"""
# NOTE(review): a cluster with fewer than three runs has no (b, c) pair at all, so its
# runs still produce no rows here — confirm whether that case needs handling.
if centrality_method == 'mbd':
    df = client.query(mbd).result().to_dataframe()  # Execute the query to (re)create total_distances_table
    print(f"MBD calculated in {round(time.time()-s, 3)} seconds.")

In [None]:
# Step 7
# Keeps, per CENTROID_ID, the runs whose total_distance lies within the 1.5 * IQR fences
# around the approximate lower/upper quartiles (offsets 25 and 75 of 100 quantile
# buckets), and writes them to sri_data.non_outliers_table.
query_non_outliers = """
CREATE OR REPLACE TABLE `sri_data.non_outliers_table` AS
WITH iqr_bounds AS (
  SELECT 
    CENTROID_ID,
    --Using approximate quantiles to save some time and space
    APPROX_QUANTILES(total_distance, 100)[OFFSET(25)] AS lower_quartile,
    APPROX_QUANTILES(total_distance, 100)[OFFSET(75)] AS upper_quartile
  FROM `sri_data.total_distances_table`
  GROUP BY CENTROID_ID
),
non_outliers AS (
  SELECT 
    d.CENTROID_ID,
    d.run_id
  FROM `sri_data.total_distances_table` d
  JOIN iqr_bounds b
    ON d.CENTROID_ID = b.CENTROID_ID
  WHERE d.total_distance BETWEEN 
        (b.lower_quartile - 1.5 * (b.upper_quartile - b.lower_quartile)) 
        AND (b.upper_quartile + 1.5 * (b.upper_quartile - b.lower_quartile))
)

SELECT * FROM non_outliers;

    """
client.query(query_non_outliers).result()  # Execute the query to (re)create non_outliers_table
print(f"Non-outliers saved successfully.")

In [None]:
# Selects, per CENTROID_ID, the half of the runs with the smallest total_distance (the
# most central ones) and writes them to sri_data.middle_curves; the cut-off is
# total_count * 0.5 cast to INT64.
query_middle_curves = """
CREATE OR REPLACE TABLE `sri_data.middle_curves` AS
WITH grouped_data AS (
  SELECT 
    CENTROID_ID,
    run_id,
    total_distance,
    ROW_NUMBER() OVER (PARTITION BY CENTROID_ID ORDER BY total_distance) AS rn,
    COUNT(*) OVER (PARTITION BY CENTROID_ID) AS total_count
  FROM `sri_data.total_distances_table`
),
top_half AS (
  SELECT 
    CENTROID_ID,
    run_id,
    total_distance,
    rn
  FROM grouped_data
  WHERE rn <= CAST(total_count * 0.5 AS INT64)  -- Select the top 50% based on total_distance
)

SELECT 
    CENTROID_ID,
    run_id,
FROM top_half;  -- Select the required fields for the middle_curves table

"""
client.query(query_middle_curves).result()  # Execute the query to (re)create middle_curves
print(f"Middle curves saved successfully.")

In [None]:
# Picks, per CENTROID_ID, the single run with the smallest total_distance (rn = 1) as
# that cluster's median curve and writes it to sri_data.median_curves.
# Ties are broken arbitrarily because ROW_NUMBER orders by total_distance only.
save_median = """
CREATE OR REPLACE TABLE `sri_data.median_curves` AS
WITH ranked_data AS (
  SELECT 
    CENTROID_ID,
    run_id,
    total_distance,
    ROW_NUMBER() OVER (PARTITION BY CENTROID_ID ORDER BY total_distance) AS rn
  FROM `sri_data.total_distances_table`
)

SELECT 
    CENTROID_ID,
    run_id,
FROM ranked_data
WHERE rn = 1;  -- Select the run_id with the lowest total_distance for each CENTROID_ID
"""

client.query(save_median).result()  # Execute the query to (re)create median_curves
print(f"Median curves saved successfully.")

In [None]:
# Step 8
# Pulls the time series of each cluster's median run into a DataFrame with columns
# date, CENTROID_ID, median. median_curves holds one run per centroid, so MAX() here
# only satisfies the GROUP BY rather than combining several runs.
# (The leading SQL comment inside the string is a leftover and does not describe this query.)
get_median_curves = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    SELECT
        data.date,
        mc.CENTROID_ID,
        MAX(data.total_infectious) as median
    FROM
        sri_data.infectious_data as data
    JOIN
        `sri_data.median_curves` as mc
    ON
        data.run_id = mc.run_id
    GROUP BY
        date, 
        CENTROID_ID
    ORDER BY
        CENTROID_ID, 
        date;
    """
plt_median = client.query(get_median_curves).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")

In [None]:
# Step 8
# Pointwise envelope of the non-outlier runs: for each (date, CENTROID_ID) the maximum
# (curve_100) and minimum (curve_0) of total_infectious over the runs listed in
# non_outliers_table.
get_curves = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    
    SELECT
        data.date,
        nout.CENTROID_ID,
        MAX(data.total_infectious) as curve_100,
        MIN(data.total_infectious) as curve_0
    FROM
        sri_data.infectious_data as data
    JOIN
        `sri_data.non_outliers_table` as nout
    ON
        data.run_id = nout.run_id
    GROUP BY
        date, 
        CENTROID_ID
    ORDER BY
        CENTROID_ID, 
        date;
    """
plt_curves = client.query(get_curves).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")

In [None]:
# Step 8
# Pointwise envelope of the central half of the runs: for each (date, CENTROID_ID) the
# maximum (curve_75) and minimum (curve_25) of total_infectious over the runs listed in
# middle_curves. Despite the column names these are envelope bounds, not percentiles.
get_mid_curves = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    SELECT
        data.date,
        mc.CENTROID_ID,
        MAX(data.total_infectious) as curve_75,
        MIN(data.total_infectious) as curve_25
    FROM
        sri_data.infectious_data as data
    JOIN
        `sri_data.middle_curves` as mc
    ON
        data.run_id = mc.run_id
    GROUP BY
        date, 
        CENTROID_ID
    ORDER BY
        CENTROID_ID, 
        date;
    """
plt_middle = client.query(get_mid_curves).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")

In [None]:
# Pulls the full time series of every outlier run: runs present in total_distances_table
# but absent from non_outliers_table. Result columns, in order: date, CENTROID_ID,
# run_id, total_infectious.
# (The leading SQL comment inside the string is a leftover and does not describe this query.)
get_outliers = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    WITH outliers AS (
        SELECT 
            tdt.run_id, 
            tdt.CENTROID_ID
        FROM 
            `sri_data.total_distances_table` as tdt
        WHERE 
            run_id NOT IN (SELECT run_id FROM `sri_data.non_outliers_table`)
    )
    
    
    SELECT
        data.date,
        outliers.CENTROID_ID,
        outliers.run_id,
        data.total_infectious
    FROM
        outliers
    JOIN
        sri_data.infectious_data as data
    ON
        data.run_id = outliers.run_id
    ORDER BY
        run_id, 
        date;
    """
plt_outliers = client.query(get_outliers).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")

## Visualizing

In [None]:
import plotly.express as px
import plotly.graph_objects as go
# here are all of our outliers
# One line per run (line_group='run_id'), colored by the cluster it belongs to.
px.line(plt_outliers, x='date', y='total_infectious', color='CENTROID_ID', line_group='run_id')

In [None]:
# merging plt_curves and plt_middle into one dataframe because I can
# Inner join on (date, CENTROID_ID), so each row carries curve_100, curve_0, curve_75
# and curve_25 for one cluster on one date.
merged_curves = pd.merge(plt_curves, plt_middle, on=['date', 'CENTROID_ID'], how='inner')
merged_curves

In [None]:
# create and lay out graph
# Functional boxplot: per cluster, a shaded envelope of all non-outlier runs, a shaded
# envelope of the central half, the median run, and every outlier run as a thin line.
fig = go.Figure()
fig.update_layout(
    title={
        'text': f"<b>Functional Boxplot</b><br><span style='font-size: 12px;'>As seen in \
<a href=https://www.tandfonline.com/doi/pdf/10.1198/jcgs.2011.09224?casa_token=ID3IjHflKz4AAAAA:4i-zhPbXhDzg\
8pDuowEPWoNiUFzHFcADAHHsqPonc6ac4dIzuQ40g5VA_n4BlUU7v1JsW7OD7Hf2>Sun & Genton (2011)</a>.  \
Clusters: {num_clusters}, features: {num_features}, grouping: {grouping_method}, centrality: {centrality_method}.</span>",
        'x': 0.5,
        'y': 0.9,
    },
    xaxis_title="Date",
    yaxis_title="Incidence",

    # colors
    xaxis=dict(gridcolor='#EFEFEF'),  # Change x-axis grid color
    yaxis=dict(gridcolor='#EFEFEF'),   # Change y-axis grid color
    plot_bgcolor='#FFFFFF',
    paper_bgcolor='#FFFFFF',
)
# fig.update_xaxes(range=[pd.Timestamp("2009-09-01"), pd.Timestamp("2010-02-17")])
# fig.update_yaxes(range=[0, 35000])

# One base color per cluster; CENTROID_ID is 1-based and the palette is cycled so that
# more clusters than colors does not raise an IndexError.
colors = 'maroon', 'navy', 'green', 'purple'

# plot outliers
for run in plt_outliers['run_id'].unique():
    df_run = plt_outliers[plt_outliers['run_id'] == run]
    # Look the cluster up by column name so a change in column order cannot break it.
    gr = df_run['CENTROID_ID'].iloc[0]
    outlier_color = colors[(gr - 1) % len(colors)]

    fig.add_trace(go.Scatter(
        name=f'Group {gr} Outlier',
        x=df_run['date'],
        y=df_run['total_infectious'],
        marker=dict(color=color_to_rgba(outlier_color, .5, alpha=.3)),
        line=dict(width=1, dash='solid'),
        mode='lines',
        showlegend=False,
        legendgroup=str(gr)  # Assign to legend group
    ))

for group in plt_median['CENTROID_ID'].unique():
    # Filter once per cluster instead of once per trace argument.
    group_curves = merged_curves[merged_curves['CENTROID_ID'] == group]
    group_median = plt_median[plt_median['CENTROID_ID'] == group]
    group_color = colors[(group - 1) % len(colors)]

    # Envelope of all non-outlier runs. The invisible lower trace must be added first
    # because fill='tonexty' on the upper trace fills down to the previous trace.
    fig.add_trace(go.Scatter(
        name=f'Group {group} Lower Bound',
        x=group_curves['date'],
        y=group_curves['curve_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
    fig.add_trace(go.Scatter(
        name=f'Group {group} Upper Bound',
        x=group_curves['date'],
        y=group_curves['curve_100'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(group_color, 1, .3),
        fill='tonexty',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))

    # Envelope of the central half of the runs; this pair carries the legend entry.
    fig.add_trace(go.Scatter(
        name=f'Group {group} Lower Quartile',
        x=group_curves['date'],
        y=group_curves['curve_25'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
    fig.add_trace(go.Scatter(
        name=f'Group {group}',
        x=group_curves['date'],
        y=group_curves['curve_75'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(group_color, 1, alpha=.3),
        fill='tonexty',
        showlegend=True,
        legendgroup=str(group)  # Assign to legend group
    ))

    # Median run of the cluster.
    fig.add_trace(go.Scatter(
        name=f'Group {group} Median',
        x=group_median['date'],
        y=group_median['median'],
        marker=dict(color=color_to_rgba(group_color, 1, alpha=1)),
        line=dict(width=1),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))

fig.show()

## Fixed time quantiles

In [None]:
# Fixed-time (pointwise) quantiles: for every (CENTROID_ID, date), the min, 5th, 25th,
# 50th, 75th, 95th percentile and max of total_infectious across the cluster's runs.
# The result has exactly one row per (CENTROID_ID, date).
fixed_time_quantiles = """
-- Attach each run's CENTROID_ID to its daily values
WITH centroid_data AS (
    SELECT
        d.date,
        d.total_infectious,
        d.run_id,
        k.CENTROID_ID
    FROM sri_data.infectious_data d
    JOIN `sri_data.kmeans_results` k ON d.run_id = k.run_id
)

-- No GROUP BY here: grouping by total_infectious would run before the window functions
-- and collapse runs that share a value, so the percentiles would be taken over distinct
-- values instead of over all runs. DISTINCT removes the per-run duplicates afterwards.
SELECT DISTINCT
    CENTROID_ID,
    date,
    PERCENTILE_CONT(total_infectious, 0) OVER (PARTITION BY CENTROID_ID, date) AS Min,
    PERCENTILE_CONT(total_infectious, 0.05) OVER (PARTITION BY CENTROID_ID, date) AS Perc5,
    PERCENTILE_CONT(total_infectious, 0.25) OVER (PARTITION BY CENTROID_ID, date) AS Q1,
    PERCENTILE_CONT(total_infectious, 0.50) OVER (PARTITION BY CENTROID_ID, date) AS Median,
    PERCENTILE_CONT(total_infectious, 0.75) OVER (PARTITION BY CENTROID_ID, date) AS Q3,
    PERCENTILE_CONT(total_infectious, 0.95) OVER (PARTITION BY CENTROID_ID, date) AS Perc95,
    PERCENTILE_CONT(total_infectious, 1) OVER (PARTITION BY CENTROID_ID, date) AS Max
FROM centroid_data
ORDER BY CENTROID_ID, date;
"""
plt_ftq = client.query(fixed_time_quantiles).result().to_dataframe()  # Execute the query and fetch the quantiles
print("Data pulled successfully.")

In [None]:
# Display the fixed-time quantile table for a quick visual check
plt_ftq

In [None]:
# a monstrosity of a query
# Returns every distinct (CENTROID_ID, date, total_infectious) point that lies below the
# 5th or above the 95th pointwise percentile of its cluster on that date.
get_outlying_points = """
WITH centroid_data AS (
    SELECT
        d.date,
        d.total_infectious,
        d.run_id,
        k.CENTROID_ID
    FROM sri_data.infectious_data d
    JOIN `sri_data.kmeans_results` k ON d.run_id = k.run_id
),

-- One row per (CENTROID_ID, date). No GROUP BY on total_infectious: it would run before
-- the window functions, so the percentiles would be computed over distinct values
-- instead of over all runs.
percentile_data AS (
    SELECT DISTINCT
        CENTROID_ID,
        date,
        PERCENTILE_CONT(total_infectious, 0.05) OVER (PARTITION BY CENTROID_ID, date) AS Perc5,
        PERCENTILE_CONT(total_infectious, 0.95) OVER (PARTITION BY CENTROID_ID, date) AS Perc95
    FROM centroid_data
)

-- Main query to filter points outside the 90% interval
SELECT DISTINCT
    cd.CENTROID_ID,
    cd.date,
    cd.total_infectious
FROM centroid_data cd
JOIN percentile_data pd
  ON cd.CENTROID_ID = pd.CENTROID_ID
  AND cd.date = pd.date
WHERE cd.total_infectious < pd.Perc5  -- Below 5th percentile
   OR cd.total_infectious > pd.Perc95  -- Above 95th percentile
ORDER BY cd.CENTROID_ID, cd.date;
"""
plt_outlying_points = client.query(get_outlying_points).result().to_dataframe()  # Execute the query and fetch the points
print("Data pulled successfully.")

In [None]:
# Scatter of the outlying points; CENTROID_ID is cast to str so it is passed to Plotly as
# strings rather than numbers.
px.scatter(plt_outlying_points, x='date', y='total_infectious', color=plt_outlying_points['CENTROID_ID'].astype(str))

In [None]:
# Traditional (fixed-time quantile) boxplot: per cluster, shaded bands for the middle
# 90% and middle 50%, the pointwise median and, optionally, the full range and the
# individual outlying points.
full_range = False       # also shade the min-max band
outlying_points = True   # also draw points outside the 5th-95th percentile band

fig = go.Figure()
fig.update_layout(
    title={
        'text': "<b>Traditional Boxplot</b><br><span style='font-size: 12px;'>Uses fixed-time quantiles.</span>",
        'x': 0.5,
        'y': 0.9,
    },
    xaxis_title="Date",
    yaxis_title="Incidence",

    # colors
    xaxis=dict(gridcolor='#EFEFEF'),  # Change x-axis grid color
    yaxis=dict(gridcolor='#EFEFEF'),   # Change y-axis grid color
    plot_bgcolor='#FFFFFF',
    paper_bgcolor='#FFFFFF',
)
# fig.update_xaxes(range=[pd.Timestamp("2009-09-01"), pd.Timestamp("2010-02-17")])
# fig.update_yaxes(range=[0, 35000])

# Index by date only the first time this cell runs; on a re-run 'date' is already the
# index. An explicit check replaces a blanket try/except that would hide real errors.
if 'date' in plt_ftq.columns:
    plt_ftq.set_index('date', inplace=True)

for group in plt_ftq['CENTROID_ID'].unique():
    df_group = plt_ftq[plt_ftq['CENTROID_ID'] == group]
    # CENTROID_ID is 1-based; cycle the palette so extra clusters do not raise IndexError.
    group_color = colors[(group - 1) % len(colors)]

    # Each band is an invisible lower trace followed by an upper trace with
    # fill='tonexty', which fills down to the trace added just before it.
    if full_range:
        # FULL RANGE
        fig.add_trace(go.Scatter(
            name='Minimum',
            x=df_group.index,
            y=df_group['Min'],
            marker=dict(color="#444"),
            line=dict(width=0),
            mode='lines',
            showlegend=False,
            legendgroup=str(group)  # Assign to legend group
        ))
        fig.add_trace(go.Scatter(
            name='Full Range',
            x=df_group.index,
            y=df_group['Max'],
            mode='lines',
            marker=dict(color="#444"),
            line=dict(width=0),
            fillcolor=color_to_rgba(group_color, .4, .2),
            fill='tonexty',
            showlegend=False,
            legendgroup=str(group)  # Assign to legend group
        ))

    # MIDDLE 90%
    fig.add_trace(go.Scatter(
        name='Minimum',
        x=df_group.index,
        y=df_group['Perc5'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
    fig.add_trace(go.Scatter(
        name='Middle 90%',
        x=df_group.index,
        y=df_group['Perc95'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(group_color, .6, .2),
        fill='tonexty',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))

    # MIDDLE 50% (this pair carries the legend entry for the cluster)
    fig.add_trace(go.Scatter(
        name='Minimum',
        x=df_group.index,
        y=df_group['Q1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
    fig.add_trace(go.Scatter(
        name=f'Group {group}',
        x=df_group.index,
        y=df_group['Q3'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(group_color, .8, .3),
        fill='tonexty',
        showlegend=True,
        legendgroup=str(group)  # Assign to legend group
    ))

    # Pointwise median
    fig.add_trace(go.Scatter(
        name='Median',
        x=df_group.index,
        y=df_group['Median'],
        marker=dict(color=color_to_rgba(group_color, 1, 1)),
        line=dict(width=1),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)
    ))

    if outlying_points:
        group_points = plt_outlying_points[plt_outlying_points['CENTROID_ID'] == group]
        fig.add_trace(go.Scatter(
            name='Outlying Points',
            x=group_points['date'],
            y=group_points['total_infectious'],
            mode='markers',
            marker=dict(color=color_to_rgba(group_color, .4, .1)),
            showlegend=False,
            legendgroup=str(group)  # Assign to legend group
        ))

fig.show()

In [None]:
import plotly.graph_objects as go

# Start from an empty figure; the functional and traditional boxplots are
# both layered onto it by the statements that follow.
fig = go.Figure()

# Styling shared by every functional-outlier curve: a thin solid line in a
# translucent maroon, kept out of the legend but toggled with the
# 'Functional' legend group.
outlier_style = dict(
    mode='lines',
    marker=dict(color=color_to_rgba('maroon', .5, alpha=.3)),
    line=dict(width=1, dash='solid'),
    showlegend=False,
    legendgroup='Functional',
)

# Draw one line per outlier run.
run_ids = plt_outliers['run_id']
for run in run_ids.unique():
    df_run = plt_outliers[run_ids == run]
    # Positional lookup (first row, second column) used as the group label in
    # the trace name, so it depends on plt_outliers' column order.
    # NOTE(review): presumably column 1 is the cluster id -- confirm.
    gr = df_run.iloc[0, 1]

    outlier_trace = go.Scatter(
        name=f'Group {gr} Outlier (Functional)',
        x=df_run['date'],
        y=df_run['total_infectious'],
        **outlier_style,
    )
    fig.add_trace(outlier_trace)

# Functional boxplot bands and median, one set of traces per cluster.
#
# Each shaded band is a pair of stacked traces: an invisible lower edge
# followed by an upper edge drawn with fill='tonexty', which shades the area
# between the upper edge and the trace added immediately before it.
functional_edge = dict(
    mode='lines',
    marker=dict(color="#444"),
    line=dict(width=0),        # edges are invisible; only the fill is drawn
    legendgroup='Functional',  # one legend click toggles every functional trace
)
functional_fill = color_to_rgba('maroon', 1, alpha=.3)

for position, group in enumerate(plt_median['CENTROID_ID'].unique()):
    # Filter each table once per cluster instead of once per x/y argument.
    group_curves = merged_curves[merged_curves['CENTROID_ID'] == group]
    group_median = plt_median[plt_median['CENTROID_ID'] == group]
    curve_dates = group_curves['date']

    # Outer band: 'curve_0' (lower bound) up to 'curve_100' (upper bound).
    fig.add_trace(go.Scatter(
        name=f'Group {group} Lower Bound (Functional)',
        x=curve_dates,
        y=group_curves['curve_0'],
        showlegend=False,
        **functional_edge,
    ))
    fig.add_trace(go.Scatter(
        name='Functional Boxplot',
        x=curve_dates,
        y=group_curves['curve_100'],
        fillcolor=functional_fill,
        fill='tonexty',
        # Every cluster shares the 'Functional' legend group, so only the
        # first cluster contributes a legend entry; showing it for each one
        # would repeat the identical 'Functional Boxplot' label.
        showlegend=position == 0,
        **functional_edge,
    ))

    # Inner band: 'curve_25' (lower quartile) up to 'curve_75' (upper quartile).
    fig.add_trace(go.Scatter(
        name=f'Group {group} Lower Quartile (Functional)',
        x=curve_dates,
        y=group_curves['curve_25'],
        showlegend=False,
        **functional_edge,
    ))
    fig.add_trace(go.Scatter(
        name=f'Group {group} Upper Quartile (Functional)',
        x=curve_dates,
        y=group_curves['curve_75'],
        fillcolor=functional_fill,
        fill='tonexty',
        showlegend=False,
        **functional_edge,
    ))

    # Median curve, drawn last so it sits on top of both bands.
    fig.add_trace(go.Scatter(
        name=f'Group {group} Median (Functional)',
        x=group_median['date'],
        y=group_median['median'],
        marker=dict(color=color_to_rgba('maroon', 1, alpha=1)),
        line=dict(width=1),
        mode='lines',
        showlegend=False,
        legendgroup='Functional',
    ))

# Add traces from the traditional boxplot, one set per cluster.
#
# Each shaded band is a pair of stacked traces: an invisible lower edge
# followed by an upper edge drawn with fill='tonexty', which shades the area
# between the upper edge and the trace added immediately before it.
traditional_edge = dict(
    mode='lines',
    marker=dict(color="#444"),
    line=dict(width=0),         # edges are invisible; only the fill is drawn
    legendgroup='Traditional',  # one legend click toggles every traditional trace
)

for position, group in enumerate(plt_ftq['CENTROID_ID'].unique()):
    df_group = plt_ftq[plt_ftq['CENTROID_ID'] == group]
    x_values = df_group.index

    # Full Range (disabled; uncomment to shade the 'Min'..'Max' band as well)
#     fig.add_trace(go.Scatter(
#         name=f'Traditional',
#         x=df_group.index,
#         y=df_group['Min'],
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         mode='lines',
#         showlegend=False,
#         legendgroup='Traditional'  # Assign to traditional legend group
#     ))
#     fig.add_trace(go.Scatter(
#         name=f'Traditional Boxplot',
#         x=df_group.index,
#         y=df_group['Max'],
#         mode='lines',
#         marker=dict(color="#444"),
#         line=dict(width=0),
#         fillcolor=color_to_rgba('navy', .4, .2),
#         fill='tonexty',
#         showlegend=True,
#         legendgroup='Traditional'  # Assign to traditional legend group
#     ))

    # Middle 90%: band between the 'Perc5' and 'Perc95' columns.
    fig.add_trace(go.Scatter(
        name='Traditional Boxplot',
        x=x_values,
        y=df_group['Perc5'],
        showlegend=False,
        **traditional_edge,
    ))
    fig.add_trace(go.Scatter(
        name='Traditional Boxplot',
        x=x_values,
        y=df_group['Perc95'],
        fillcolor=color_to_rgba('navy', .6, .2),
        fill='tonexty',
        # Every cluster shares the 'Traditional' legend group, so only the
        # first cluster contributes a legend entry; showing it for each one
        # would repeat the identical 'Traditional Boxplot' label.
        showlegend=position == 0,
        **traditional_edge,
    ))

    # Middle 50%: band between the 'Q1' and 'Q3' columns.
    fig.add_trace(go.Scatter(
        # This trace plots Q1, so it is labelled as the lower quartile (the
        # previous 'Minimum' label was left over from the full-range band).
        name='Lower Quartile (Traditional)',
        x=x_values,
        y=df_group['Q1'],
        showlegend=False,
        **traditional_edge,
    ))
    fig.add_trace(go.Scatter(
        name='Middle 50% (Traditional)',
        x=x_values,
        y=df_group['Q3'],
        fillcolor=color_to_rgba('navy', .8, .3),
        fill='tonexty',
        showlegend=False,
        **traditional_edge,
    ))

    # Median
    fig.add_trace(go.Scatter(
        name='Median (Traditional)',
        x=x_values,
        y=df_group['Median'],
        marker=dict(color=color_to_rgba('navy', 1, 1)),
        line=dict(width=1),
        mode='lines',
        showlegend=False,
        legendgroup='Traditional',
    ))

    # Outlying Points (only when the notebook-level flag is set)
    if outlying_points:
        # Filter once and reuse the result for both axes.
        group_points = plt_outlying_points[plt_outlying_points['CENTROID_ID'] == group]
        fig.add_trace(go.Scatter(
            name='Outlying Points (Traditional)',
            x=group_points['date'],
            y=group_points['total_infectious'],
            mode='markers',
            marker=dict(color=color_to_rgba('navy', .4, .1)),
            showlegend=False,
            legendgroup='Traditional',
        ))

# Figure-level styling, gathered in one mapping and applied in a single call:
# a two-line HTML title positioned at x=0.5 / y=0.9, axis labels, light grey
# grid lines, and white plot and paper backgrounds.
grid_color = '#EFEFEF'
background_color = '#FFFFFF'
layout_options = {
    'title': dict(
        text="<b>Combined Boxplots</b><br><span style='font-size: 12px;'>Functional and Traditional Boxplots.</span>",
        x=0.5,
        y=0.9,
    ),
    'xaxis_title': "Date",
    'yaxis_title': "Incidence",
    'xaxis': dict(gridcolor=grid_color),
    'yaxis': dict(gridcolor=grid_color),
    'plot_bgcolor': background_color,
    'paper_bgcolor': background_color,
}
fig.update_layout(**layout_options)

# Render the overlaid functional + traditional figure.
fig.show()

## Runtime testing

In [None]:
# import time
# import numpy as np
# import pandas as pd
# from tqdm.notebook import tqdm

# # Set the number of times to test each query
# k = 10
# num_clusters = 1
# num_features = 5
# centrality_method = 'mbd'
# grouping_method = 'mse'

# # List of queries to test with meaningful names
# queries = {
#     "Create Infectious Data Table": query_base,
#     "Create Curve Distances Table": query_distances,
#     "Create Distance Matrix Table": query_distance_matrix,
#     "Assign All Runs to One Cluster" if num_clusters == 1 else "Create KMeans Model": query_assign_all_to_one_cluster if num_clusters == 1 else query_create_model,
#     "KMeans Clustering" if num_clusters > 1 else None: query_kmeans if num_clusters > 1 else None,  # Only add KMeans clustering if num_clusters > 1
#     "Save Sum Distances (ABC or MSE)" if centrality_method in ['abc', 'mse'] else "Calculate MBD": save_sum_distances if centrality_method in ['abc', 'mse'] else mbd,
#     "Identify Non-Outliers": query_non_outliers,
#     "Save Middle Curves": query_middle_curves,
#     "Save Median Curves": save_median,
#     "Get Median Curves": get_median_curves,
#     "Get Overall Curves": get_curves,
#     "Get Middle Curves": get_mid_curves,
#     "Get Outliers": get_outliers
# }

In [None]:
# # Filter out any None queries (from conditional logic above)
# queries = {name: q for name, q in queries.items() if q is not None}

# # Dictionary to store the execution times
# execution_times = {name: [] for name in queries.keys()}

# # Test each query k times and record execution times
# for query_name, query in queries.items():
# #     print(f"\nTesting '{query_name}'...")

#     # Create a progress bar for the runs of the current query
#     for j in tqdm(range(k), desc=f"{query_name} runs", leave=False):
#         start_time = time.time()
#         client.query(query).result()  # Execute the query
#         elapsed_time = time.time() - start_time
        
#         # Store the elapsed time
#         execution_times[query_name].append(elapsed_time)
# #         print(f"Run {j + 1}/{k} for '{query_name}' took {round(elapsed_time, 3)} seconds.")

# #     print(f"Completed '{query_name}'")

# # Convert execution times to a DataFrame for easier analysis
# df_times = pd.DataFrame.from_dict(execution_times, orient='index').T

# # Summary statistics
# summary = df_times.describe().T[['mean', 'min', 'max']]

In [None]:
# import numpy as np
# import pandas as pd
# import plotly.graph_objects as go
# from scipy.stats import gaussian_kde
# # Melt the DataFrame to long format for Plotly
# df_melted = df_times.reset_index().melt(id_vars='index', var_name='Query', value_name='Execution Time')
# df_melted.rename(columns={'index': 'Run'}, inplace=True)

# # Calculate total execution time per run
# total_time_per_run = df_melted.groupby('Run')['Execution Time'].sum().reset_index()

# # Create the figure
# fig = go.Figure()

# # Add a KDE plot for each query
# for query in df_melted['Query'].unique():
#     query_data = df_melted[df_melted['Query'] == query]['Execution Time']
    
#     # Compute the KDE for each query
#     kde = gaussian_kde(query_data, bw_method='scott')  # Adjust the bandwidth method
#     x_values = np.linspace(min(query_data), max(query_data), 100)
#     kde_values = kde(x_values)

#     # Add the KDE trace with grouped legend
#     fig.add_trace(go.Scatter(
#         x=x_values,
#         y=kde_values,
#         mode='lines',
#         name=query,
#         legendgroup=query,  # Use the category map
#         line=dict(width=2),  # Adjust line width
#     ))

# # Add a trace for the total execution time per run
# kde_total = gaussian_kde(total_time_per_run['Execution Time'], bw_method='scott')  # KDE on total times
# total_x_values = np.linspace(min(total_time_per_run['Execution Time']), max(total_time_per_run['Execution Time']), 100)
# total_kde_values = kde_total(total_x_values)

# # # Add the total time KDE trace
# # fig.add_trace(go.Scatter(
# #     x=total_x_values,
# #     y=total_kde_values,
# #     mode='lines',
# #     name='Total Execution Time',
# #     legendgroup='Total Time',  # Grouped under a specific legend group
# #     line=dict(width=3, color='red'),  # Adjust color and width for distinction
# # ))

# # Update layout
# fig.update_layout(
#     title='Smoothed Probability Estimates of Query Execution Times',
#     xaxis_title='Execution Time (seconds)',
#     yaxis_title='Density',
#     legend_title='Query Types',
#     template='plotly_white',
# )

# # Show the plot
# fig.show()