# Everything is on the cloud now and I feel numb
TODO: Replace CTE notation with someone mildly useable that makes sense for the project you're working on. 

TODO: Send politely worded email to Guillaume, get and configure for PGF data

TODO: Indexes?

In [1]:
from google.cloud import bigquery
from google.oauth2 import service_account
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import time
import matplotlib.colors as mcolors

In [2]:
# helper function for plotting
def color_to_rgba(color_name, value, alpha):
    # Get the RGB components of the color
    rgb = mcolors.to_rgb(color_name)
    
    # Calculate the gray value based on the input value
    gray = (1 - value) * 0.8  # This adjusts how gray the color will be

    # Compute the final RGBA values
    r = rgb[0] * value + gray
    g = rgb[1] * value + gray
    b = rgb[2] * value + gray
    
    # Return the RGBA string
#     print(f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {alpha})')
    return f'rgba({int(r * 255)}, {int(g * 255)}, {int(b * 255)}, {alpha})'

In [3]:
# the location of our credentials json and name of bigquery project
credentials = service_account.Credentials.from_service_account_file('C:\\Users\\elija\\Documents\\24f-coop\\credentials.json')
project = 'net-data-viz-handbook'

# Initialize a GCS client
client = bigquery.Client(credentials=credentials, project=project)

In [31]:
# this is the query I'm working with for testing porpoises
query_pivot = """
    SELECT 
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM `net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily`
    WHERE country_id IN (218)
      AND run_id BETWEEN 1 AND 100
    GROUP BY date, country_id, run_id
    ORDER BY run_id, date
"""
df = client.query(query_pivot).result().to_dataframe()  # Execute the query to create the table
print("Data pulled successfully.")


BigQuery Storage module not found, fetch data with the REST endpoint instead.



Data pulled successfully.


## KMeans and a prayer (for creating grouped curve-based statistics)
- Uses the first num_features columns to simplify the KMeans algorithms, which is O(n<sup>2</sup>)?
    - Might just not need PCA or similar
- Procedural once distance matrix is computed (eg have to query main table manually)

In [32]:
# Step 1: Create the MSE matrix table
query_mse_matrix = """
CREATE OR REPLACE TABLE sri_data.mse_matrix AS
WITH infectious_data AS (
    SELECT
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM
        net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily
    WHERE
        country_id IN (218) AND run_id BETWEEN 1 AND 100
    GROUP BY
        date,
        country_id,
        run_id
)
SELECT
    a.run_id AS run_id_a,
    b.run_id AS run_id_b,
    AVG(POW(a.total_infectious - b.total_infectious, 2)) AS mse
FROM
    infectious_data a
JOIN
    infectious_data b
ON
    a.date = b.date
GROUP BY
    run_id_a, run_id_b
"""
client.query(query_mse_matrix).result()  # Execute the query to create the table
print("MSE matrix table created successfully.")

# Step 2: Create the ABC matrix table
query_abc_matrix = """
CREATE OR REPLACE TABLE sri_data.abc_matrix AS
WITH infectious_data AS (
    SELECT
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM
        net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily
    WHERE
        country_id IN (218) AND run_id BETWEEN 1 AND 100
    GROUP BY
        date,
        country_id,
        run_id
)
SELECT
    a.run_id AS run_id_a,
    b.run_id AS run_id_b,
    SUM(ABS(a.total_infectious - b.total_infectious)) AS abc
FROM
    infectious_data a
JOIN
    infectious_data b
ON
    a.date = b.date
GROUP BY
    run_id_a, run_id_b
"""
client.query(query_abc_matrix).result()  # Execute the query to create the table
print("ABC matrix table created successfully.")

MSE matrix table created successfully.
ABC matrix table created successfully.


In [33]:
# Step 3: Create the distance matrix
method = 'mse' # 'mse' or 'abc'

query_distance_matrix = f"""
CREATE OR REPLACE TABLE sri_data.distance_matrix AS
SELECT
    run_id_a AS run_id,
        ARRAY_AGG(STRUCT(run_id_b, {method}) ORDER BY run_id_b ASC) AS distances
FROM
    sri_data.{method}_matrix
GROUP BY
    run_id_a;
"""
client.query(query_distance_matrix).result()  # Execute the query to create the table
print("Distance matrix table created successfully.")

Distance matrix table created successfully.


In [53]:
# Set parameters for the K-means model
num_clusters = 2  # Example number of clusters
num_features = 15   # Set the number of features to select

In [54]:
# Step 4: Create the K-means model by selecting the first num_features features based on actual distances
query_create_model = f"""
CREATE OR REPLACE MODEL sri_data.kmeans_model
OPTIONS(model_type='kmeans', num_clusters={num_clusters}) AS
SELECT
    run_id,
    ARRAY(
        SELECT distance.{method} 
        FROM UNNEST(distances) AS distance 
        WHERE distance.run_id_b <= {num_features}  -- Select only the first num_features
    ) AS features
FROM
    sri_data.distance_matrix;
"""
s = time.time()
client.query(query_create_model).result()  # Execute the model creation
print(f"K-means model created successfully in {round(time.time() - s, 3)} seconds.")

# Step 5: Apply K-means clustering and save results in a table
query_kmeans = f"""
CREATE OR REPLACE TABLE sri_data.kmeans_results AS
SELECT
    *
FROM
    ML.PREDICT(MODEL sri_data.kmeans_model,
        (SELECT
            run_id,
            ARRAY(
                SELECT distance.{method} 
                FROM UNNEST(distances) AS distance 
                WHERE distance.run_id_b <= {num_features}  
            ) AS features
         FROM
            sri_data.distance_matrix)
    ) AS predictions
"""
s = time.time()
client.query(query_kmeans).result()  # Execute the model creation
print(f"K-means clustering results saved successfully in {round(time.time() - s, 3)} seconds.")

K-means model created successfully in 29.224 seconds.
K-means clustering results saved successfully in 2.873 seconds.


In [55]:
# Step 6: Sum distances within each cluster, second variation
# 
percentile = 1 # not implemented currently, will let you select only most central subset of data

query_sum_distances = f"""
CREATE OR REPLACE TABLE `sri_data.total_distances_table` AS
WITH a AS (
    SELECT 
        kr.CENTROID_ID,
        kr.run_id, 
        run_id_b, 
        {method}
    FROM 
        `sri_data.kmeans_results` AS kr
    JOIN 
        `sri_data.distance_matrix` AS dm
    ON 
        kr.run_id = dm.run_id
    CROSS JOIN 
        UNNEST(dm.distances) AS dm_dist  -- Unnest the distances array here 
),
b AS (
    SELECT
        run_id AS run_id_b, 
        CENTROID_ID AS CENTROID_ID_B
    FROM
        `sri_data.kmeans_results`
)
SELECT 
    a.run_id,
    a.CENTROID_ID,
    AVG({method}) as total_distance
FROM 
    a
JOIN 
    b
ON 
    a.run_id_b = b.run_id_b
WHERE
    a.CENTROID_ID = b.CENTROID_ID_B
GROUP BY
    a.CENTROID_ID,
    a.run_id;

"""
total_distances = client.query(query_sum_distances).to_dataframe()  # Execute and fetch results
print("Total distances summed within each cluster successfully.")

Total distances summed within each cluster successfully.


In [56]:
# revised step 6
save_sum_distances = f"""CREATE OR REPLACE TABLE `sri_data.total_distances_table` AS
    WITH a AS (
        SELECT 
            kr.CENTROID_ID,
            kr.run_id, 
            run_id_b, 
            {method}  
        FROM 
            `sri_data.kmeans_results` AS kr
        JOIN 
            `sri_data.distance_matrix` AS dm
        ON 
            kr.run_id = dm.run_id
        CROSS JOIN 
            UNNEST(dm.distances) AS dm_dist  -- Unnest the distances array here 
    ),
    b AS (
        SELECT
            run_id AS run_id_b, 
            CENTROID_ID AS CENTROID_ID_B
        FROM
            `sri_data.kmeans_results`
    )
    SELECT 
        a.run_id,
        a.CENTROID_ID,
        AVG({method}) AS total_distance  
    FROM 
        a
    JOIN 
        b
    ON 
        a.run_id_b = b.run_id_b
    WHERE
        a.CENTROID_ID = b.CENTROID_ID_B
    GROUP BY
        a.CENTROID_ID,
        a.run_id;
    """
client.query(save_sum_distances).result()  # Execute the model creation
print(f"Distance sum results saved successfully.")

Distance sum results saved successfully.


In [57]:
# Step 7
query_non_outliers = """
CREATE OR REPLACE TABLE `sri_data.non_outliers_table` AS
WITH grouped_data AS (
  SELECT 
    CENTROID_ID,
    total_distance,
    ROW_NUMBER() OVER (PARTITION BY CENTROID_ID ORDER BY total_distance) AS rn,
    COUNT(*) OVER (PARTITION BY CENTROID_ID) AS total_count
  FROM `sri_data.total_distances_table`
),
iqr_calculation AS (
  SELECT 
    CENTROID_ID,
    MAX(CASE WHEN rn = CAST(total_count * 0.25 AS INT64) THEN total_distance END) AS lower_quartile,
    MAX(CASE WHEN rn = CAST(total_count * 0.75 AS INT64) THEN total_distance END) AS upper_quartile
  FROM grouped_data
  GROUP BY CENTROID_ID
),
outlier_bounds AS (
  SELECT 
    CENTROID_ID,
    lower_quartile - (1.5 * (upper_quartile - lower_quartile)) AS lower_bound,
    upper_quartile + (1.5 * (upper_quartile - lower_quartile)) AS upper_bound
  FROM iqr_calculation
),
outliers AS (
  SELECT 
    d.CENTROID_ID,
    d.run_id
  FROM `sri_data.total_distances_table` d
  JOIN outlier_bounds b
    ON d.CENTROID_ID = b.CENTROID_ID
  WHERE d.total_distance < b.lower_bound OR d.total_distance > b.upper_bound
),
non_outliers AS (
  SELECT 
    d.CENTROID_ID,
    d.run_id
  FROM `sri_data.total_distances_table` d
  LEFT JOIN outliers o
    ON d.run_id = o.run_id
  WHERE o.run_id IS NULL  -- Exclude outliers
)

SELECT 
    * 
FROM non_outliers;  -- Select all from the non_outliers CTE
    """
client.query(query_non_outliers).result()  # Execute the model creation
print(f"Non-outliers saved successfully.")

Non-outliers saved successfully.


In [58]:
query_middle_curves = """
CREATE OR REPLACE TABLE `sri_data.middle_curves` AS
WITH grouped_data AS (
  SELECT 
    CENTROID_ID,
    run_id,
    total_distance,
    ROW_NUMBER() OVER (PARTITION BY CENTROID_ID ORDER BY total_distance) AS rn,
    COUNT(*) OVER (PARTITION BY CENTROID_ID) AS total_count
  FROM `sri_data.total_distances_table`
),
top_half AS (
  SELECT 
    CENTROID_ID,
    run_id,
    total_distance,
    rn
  FROM grouped_data
  WHERE rn <= CAST(total_count * 0.5 AS INT64)  -- Select the top 50% based on total_distance
)

SELECT 
    CENTROID_ID,
    run_id,
FROM top_half;  -- Select the required fields for the middle_curves table

"""
client.query(query_middle_curves).result()  # Execute the model creation
print(f"Middle curves saved successfully.")

Middle curves saved successfully.


In [59]:
save_median = """
CREATE OR REPLACE TABLE `sri_data.median_curves` AS
WITH ranked_data AS (
  SELECT 
    CENTROID_ID,
    run_id,
    total_distance,
    ROW_NUMBER() OVER (PARTITION BY CENTROID_ID ORDER BY total_distance) AS rn
  FROM `sri_data.total_distances_table`
)

SELECT 
    CENTROID_ID,
    run_id,
FROM ranked_data
WHERE rn = 1;  -- Select the run_id with the lowest total_distance for each CENTROID_ID
"""

client.query(save_median).result()  # Execute the model creation
print(f"Median curves saved successfully.")

Median curves saved successfully.


In [60]:
# Step 8
get_median_curves = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    WITH c AS (
    SELECT 
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM `net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily`
    WHERE country_id IN (218)
      AND run_id BETWEEN 1 AND 100
    GROUP BY date, country_id, run_id
    )
    
    SELECT
        c.date,
        mc.CENTROID_ID,
        MAX(c.total_infectious) as median
    FROM
        c
    JOIN
        `sri_data.median_curves` as mc
    ON
        c.run_id = mc.run_id
    GROUP BY
        date, 
        CENTROID_ID
    ORDER BY
        CENTROID_ID, 
        date;
    """
plt_median = client.query(get_median_curves).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")


BigQuery Storage module not found, fetch data with the REST endpoint instead.



Curves extracted successfully.


In [61]:
# Step 8
get_curves = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    WITH c AS (
    SELECT 
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM `net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily`
    WHERE country_id IN (218)
      AND run_id BETWEEN 1 AND 100
    GROUP BY date, country_id, run_id
    )
    
    SELECT
        c.date,
        nout.CENTROID_ID,
        MAX(c.total_infectious) as curve_100,
        MIN(c.total_infectious) as curve_0
    FROM
        c
    JOIN
        `sri_data.non_outliers_table` as nout
    ON
        c.run_id = nout.run_id
    GROUP BY
        date, 
        CENTROID_ID
    ORDER BY
        CENTROID_ID, 
        date;
    """
plt_curves = client.query(get_curves).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")

Curves extracted successfully.



BigQuery Storage module not found, fetch data with the REST endpoint instead.



In [62]:
# Step 8
get_mid_curves = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    WITH c AS (
    SELECT 
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM `net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily`
    WHERE country_id IN (218)
      AND run_id BETWEEN 1 AND 100
    GROUP BY date, country_id, run_id
    )
    
    SELECT
        c.date,
        mc.CENTROID_ID,
        MAX(c.total_infectious) as curve_75,
        MIN(c.total_infectious) as curve_25
    FROM
        c
    JOIN
        `sri_data.middle_curves` as mc
    ON
        c.run_id = mc.run_id
    GROUP BY
        date, 
        CENTROID_ID
    ORDER BY
        CENTROID_ID, 
        date;
    """
plt_middle = client.query(get_mid_curves).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")

Curves extracted successfully.



BigQuery Storage module not found, fetch data with the REST endpoint instead.



In [63]:
get_outliers = """-- Step 3: Calculate min and max values at each time step using the non-outliers table
    WITH c AS (
    SELECT 
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM `net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily`
    WHERE country_id IN (218)
      AND run_id BETWEEN 1 AND 100
    GROUP BY date, country_id, run_id
    ),
    outliers AS (
        SELECT 
            tdt.run_id, 
            tdt.CENTROID_ID
        FROM 
            `sri_data.total_distances_table` as tdt
        WHERE 
            run_id NOT IN (SELECT run_id FROM `sri_data.non_outliers_table`)
    )
    
    
    SELECT
        c.date,
        outliers.CENTROID_ID,
        outliers.run_id,
        c.total_infectious
    FROM
        outliers
    JOIN
        c
    ON
        c.run_id = outliers.run_id
    ORDER BY
        run_id, 
        date;
    """
plt_outliers = client.query(get_outliers).to_dataframe()  # Execute and fetch results
print("Curves extracted successfully.")


BigQuery Storage module not found, fetch data with the REST endpoint instead.



Curves extracted successfully.


## Visualizing

In [64]:
import plotly.express as px
import plotly.graph_objects as go
# here are all of our outliers
px.line(plt_outliers, x='date', y='total_infectious', color='CENTROID_ID', line_group='run_id')

In [65]:
# merging plt_curves and middle into one dataframe because I can
merged_curves = pd.merge(plt_curves, plt_middle, on=['date', 'CENTROID_ID'], how='inner')
merged_curves

Unnamed: 0,date,CENTROID_ID,curve_100,curve_0,curve_75,curve_25
0,2009-02-17,1,0,0,0,0
1,2009-02-18,1,0,0,0,0
2,2009-02-19,1,0,0,0,0
3,2009-02-20,1,0,0,0,0
4,2009-02-21,1,0,0,0,0
...,...,...,...,...,...,...
727,2010-02-13,2,0,0,0,0
728,2010-02-14,2,0,0,0,0
729,2010-02-15,2,0,0,0,0
730,2010-02-16,2,0,0,0,0


In [66]:
# create and lay out graph
fig = go.Figure()
fig.update_layout(
    title={
        'text': f"<b>Functional Boxplot</b><br><span style='font-size: 12px;'>As seen in \
<a href=https://www.tandfonline.com/doi/pdf/10.1198/jcgs.2011.09224?casa_token=ID3IjHflKz4AAAAA:4i-zhPbXhDzg\
8pDuowEPWoNiUFzHFcADAHHsqPonc6ac4dIzuQ40g5VA_n4BlUU7v1JsW7OD7Hf2>Sun & Genton (2011)</a>.  \
Clusters: {num_clusters}, features: {num_features}, method: {method}.</span>",
        'x': 0.5,  
        'y': 0.9,  
    },
    xaxis_title="Date",
    yaxis_title="Incidence",
    
    # colors
    xaxis=dict(gridcolor='#EFEFEF'),  # Change x-axis grid color
    yaxis=dict(gridcolor='#EFEFEF'),   # Change y-axis grid color
    plot_bgcolor='#FFFFFF',
    paper_bgcolor='#FFFFFF',
)
# fig.update_xaxes(range=[pd.Timestamp("2009-09-01"), pd.Timestamp("2010-02-17")])
# fig.update_yaxes(range=[0, 35000])
colors = 'maroon', 'navy', 'green', 'purple'

# plot outliers
for run in plt_outliers['run_id'].unique():
        df_run = plt_outliers[plt_outliers['run_id'] == run]
        gr = df_run.iloc[0, 1] # careful not to change table formatting
    
        fig.add_trace(go.Scatter(
        name=f'Group {gr} Outlier',
        x=df_run['date'],
        y=df_run['total_infectious'],
        marker=dict(color=color_to_rgba(colors[gr-1], .5, alpha=.3)),
        line=dict(width=1, dash='solid'),
        mode='lines',
        showlegend=False,
        legendgroup=str(gr)  # Assign to legend group
    ))

for group in plt_median['CENTROID_ID'].unique():
    print(group)
    # actually graph
    # Lower
    fig.add_trace(go.Scatter(
        name=f'Group {group} Lower Bound',
        x=merged_curves[merged_curves['CENTROID_ID'] == group]['date'],
        y=merged_curves[merged_curves['CENTROID_ID'] == group]['curve_0'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
    # Upper
    fig.add_trace(go.Scatter(
        name=f'Group {group} Upper Bound',
        x=merged_curves[merged_curves['CENTROID_ID'] == group]['date'],
        y=merged_curves[merged_curves['CENTROID_ID'] == group]['curve_100'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(colors[group-1], 1, .3),
        fill='tonexty',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
        
    # Lower
    fig.add_trace(go.Scatter(
        name=f'Group {group} Lower Quartile',
        x=merged_curves[merged_curves['CENTROID_ID'] == group]['date'],
        y=merged_curves[merged_curves['CENTROID_ID'] == group]['curve_25'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
    # Upper
    fig.add_trace(go.Scatter(
        name=f'Group {group}',
        x=merged_curves[merged_curves['CENTROID_ID'] == group]['date'],
        y=merged_curves[merged_curves['CENTROID_ID'] == group]['curve_75'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(colors[group-1], 1, alpha=.3),
        fill='tonexty',
        showlegend=True,
        legendgroup=str(group)  # Assign to legend group
    ))
    
    
    fig.add_trace(go.Scatter(
        name=f'Group {group} Median',
        x=plt_median[plt_median['CENTROID_ID'] == group]['date'],
        y=plt_median[plt_median['CENTROID_ID'] == group]['median'],
        marker=dict(color=color_to_rgba(colors[group-1], 1, alpha=1)),
        line=dict(width=1),
        mode='lines',
        showlegend=False,
        legendgroup=str(group)  # Assign to legend group
    ))
    
fig.show()

1
2


## Fixed time quantiles

In [67]:
fixed_time_quantiles = """
WITH c AS (
    SELECT 
        date,
        country_id,
        run_id,
        SUM(Infectious_13_17 + Infectious_18_23) AS total_infectious
    FROM `net-data-viz-handbook.sri_data.SIR_0_countries_incidence_daily`
    WHERE country_id IN (218)
      AND run_id BETWEEN 1 AND 100
    GROUP BY date, country_id, run_id
    ORDER BY run_id, date
),

daily_data AS (
    SELECT 
        date, 
        total_infectious,
        run_id,
        ROW_NUMBER() OVER (PARTITION BY date ORDER BY total_infectious) AS row_num,
        COUNT(*) OVER (PARTITION BY date) AS total_rows
    FROM c  
),

-- Joining with kmeans_results to attach CENTROID_ID
centroid_data AS (
    SELECT 
        d.date,
        d.total_infectious,
        d.run_id,
        k.CENTROID_ID
    FROM daily_data d
    JOIN `sri_data.kmeans_results` k ON d.run_id = k.run_id -- Adjust the join condition if necessary
)

SELECT 
    CENTROID_ID,
    date,
    PERCENTILE_CONT(total_infectious, 0) OVER (PARTITION BY CENTROID_ID, date) AS Min,
    PERCENTILE_CONT(total_infectious, 0.05) OVER (PARTITION BY CENTROID_ID, date) AS Perc5,
    PERCENTILE_CONT(total_infectious, 0.25) OVER (PARTITION BY CENTROID_ID, date) AS Q1,
    PERCENTILE_CONT(total_infectious, 0.50) OVER (PARTITION BY CENTROID_ID, date) AS Median,
    PERCENTILE_CONT(total_infectious, 0.75) OVER (PARTITION BY CENTROID_ID, date) AS Q3,
    PERCENTILE_CONT(total_infectious, 0.95) OVER (PARTITION BY CENTROID_ID, date) AS Perc95,
    PERCENTILE_CONT(total_infectious, 1) OVER (PARTITION BY CENTROID_ID, date) AS Max
FROM centroid_data
GROUP BY CENTROID_ID, date, total_infectious
ORDER BY CENTROID_ID, date;


"""
plt_ftq = client.query(fixed_time_quantiles).result().to_dataframe()  # Execute the query to create the table
print("Data pulled successfully.")


BigQuery Storage module not found, fetch data with the REST endpoint instead.



Data pulled successfully.


In [68]:
plt_ftq

Unnamed: 0,CENTROID_ID,date,Min,Perc5,Q1,Median,Q3,Perc95,Max
0,1,2009-02-17,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,1,2009-02-18,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,1,2009-02-19,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,1,2009-02-20,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,1,2009-02-21,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...
19940,2,2010-02-13,0.0,0.0,0.0,0.0,0.0,0.0,0.0
19941,2,2010-02-14,0.0,0.0,0.0,0.0,0.0,0.0,0.0
19942,2,2010-02-15,0.0,0.0,0.0,0.0,0.0,0.0,0.0
19943,2,2010-02-16,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [69]:
fig = go.Figure()
fig.update_layout(
    title={
        'text': f"<b>Traditional Boxplot</b><br><span style='font-size: 12px;'>Uses fixed-time quantiles.</span>",
        'x': 0.5,  
        'y': 0.9,  
    },
    xaxis_title="Date",
    yaxis_title="Incidence",
    
    # colors
    xaxis=dict(gridcolor='#EFEFEF'),  # Change x-axis grid color
    yaxis=dict(gridcolor='#EFEFEF'),   # Change y-axis grid color
    plot_bgcolor='#FFFFFF',
    paper_bgcolor='#FFFFFF',
)
# fig.update_xaxes(range=[pd.Timestamp("2009-09-01"), pd.Timestamp("2010-02-17")])
# fig.update_yaxes(range=[0, 35000])

try:
    plt_ftq.set_index('date', inplace=True)
except Exception:
    pass

for group in plt_ftq['CENTROID_ID'].unique():
    df_group = plt_ftq[plt_ftq['CENTROID_ID'] == group]

    # FULL RANGE
    fig.add_trace(go.Scatter(
        name=f'Minimum',
        x=df_group.index,
        y=df_group['Min'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup='Full Range'  # Assign to legend group
    ))
    fig.add_trace(go.Scatter(
        name=f'Full Range',
        x=df_group.index,
        y=df_group['Max'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(colors[group-1], .4, .1),
        fill='tonexty',
        showlegend=True,
        legendgroup='Full Range'  # Assign to legend group
    ))
    
    
    # MIDDLE 90%
    fig.add_trace(go.Scatter(
        name=f'Minimum',
        x=df_group.index,
        y=df_group['Perc5'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup='Middle 90%'  # Assign to legend group
    ))
    fig.add_trace(go.Scatter(
        name=f'Middle 90%',
        x=df_group.index,
        y=df_group['Perc95'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(colors[group-1], .6, .2),
        fill='tonexty',
        showlegend=True,
        legendgroup='Middle 90%'  # Assign to legend group
    ))
    
    # MIDDLE 50%
    fig.add_trace(go.Scatter(
        name=f'Minimum',
        x=df_group.index,
        y=df_group['Q1'],
        marker=dict(color="#444"),
        line=dict(width=0),
        mode='lines',
        showlegend=False,
        legendgroup='Middle 50%'  # Assign to legend group
    ))
    fig.add_trace(go.Scatter(
        name=f'Middle 50%',
        x=df_group.index,
        y=df_group['Q3'],
        mode='lines',
        marker=dict(color="#444"),
        line=dict(width=0),
        fillcolor=color_to_rgba(colors[group-1], .8, .3),
        fill='tonexty',
        showlegend=True,
        legendgroup='Middle 50%'  # Assign to legend group
    ))
    
    
    fig.add_trace(go.Scatter(
        name=f'Median',
        x=df_group.index,
        y=df_group['Median'],
        marker=dict(color=color_to_rgba(colors[group-1], 1, 1)),
        line=dict(width=1),
        mode='lines',
        showlegend=True,
        legendgroup='Median'
    ))
    
fig.show()

## KMeans runtime testing

In [None]:
# import time
# import plotly.express as px
# import pandas as pd
# from tqdm.notebook import tqdm

# # Sample variables
# ls_num_clusters = [1, 10, 50]
# ls_num_features = [1, 10, 100]
# i = 5  # Number of iterations

# # Dictionary to store runtimes
# runtime_data = {}

# # Loop over the number of iterations
# for num_clusters in ls_num_clusters:
#     for num_features in ls_num_features:
#         for _ in tqdm(range(i), desc=f'Clusters: {num_clusters} | Features: {num_features}', leave=False):
#             start_time = time.time()
#             # Simulate the query execution (replace this with your actual query)
#             client.query(query_create_model).result()  # Replace with actual query
#             end_time = time.time()
            
#             runtime = end_time - start_time
            
#             # Store runtime in the dictionary
#             key = (num_clusters, num_features)
#             if key in runtime_data:
#                 runtime_data[key].append(runtime)
#             else:
#                 runtime_data[key] = [runtime]

# # Calculate average runtimes
# average_runtimes = {
#     'num_clusters': [],
#     'num_features': [],
#     'average_runtime': []
# }

# for (num_clusters, num_features), runtimes in runtime_data.items():
#     average_runtimes['num_clusters'].append(num_clusters)
#     average_runtimes['num_features'].append(num_features)
#     average_runtimes['average_runtime'].append(sum(runtimes) / len(runtimes))

In [None]:
# px.imshow(pd.DataFrame(average_runtimes).pivot(index='num_clusters', columns='num_features', values='average_runtime'))