In [None]:
%load_ext autoreload
%autoreload 2
%aimport utils_1_0

import altair as alt
import pandas as pd

from utils_1_0 import get_visualization_subtitle
from web import for_website

# Data
### Input File Format
The input file format is identical to "Options 1: File" in UpSetR-shiny (https://github.com/hms-dbmi/UpSetR-shiny)

- Columns are `attribute 1, attribute 2, ..., attribute N, set 1, set 2, ..., set M`, where each `set` column contains either `1` or `0`; `1` means the row belongs to that set (the '⬤' glyph in UpSet)

#### Demo Data: COVID Symptom Tracker April 7 (via https://www.nature.com/articles/d41586-020-00154-w)

<img src="https://media.nature.com/lw800/magazine-assets/d41586-020-00154-w/d41586-020-00154-w_17880786.jpg" alt="demo_diagram" style="width:300px;"/>

In [None]:
# Load the demo table used by every example below.
# NOTE(review): presumably one 0/1 column per symptom, per the input format described above; confirm against the CSV.
df = pd.read_csv("../data/covid_symptoms_table.csv")

# Show the loaded frame as the cell output.
df

# Visualization
#### The UpSetAltair visualizations contain three main views: 

(1) **vertical bar chart** on the top showing the cardinality of each intersecting set;

(2) **matrix view** on the bottom-left showing which sets make up each intersection;

(3) **horizontal bar chart** on the bottom-right showing the cardinality of each set.

#### Options:
1. Specify sets of interest (e.g., `["Fever", "Cough", "Fatigue"]`)
2. ~~Show empty intersections or not~~ (not yet supported)
3. Sorting type: `Frequency` or `Degree`
4. Sorting order: `descending` or `ascending`

In [None]:
# Top-level altair configuration
def upsetaltair_top_level_configuration(
    base,
    legend_orient="top-left",
    legend_symbol_size=30
):
    """Apply the notebook's shared top-level styling to a chart.

    Parameters
    ----------
    base
        Chart-like object exposing ``configure_view``, ``configure_title``,
        ``configure_axis``, ``configure_legend`` and ``configure_concat``.
    legend_orient : str
        Forwarded to the legend configuration as ``orient``.
    legend_symbol_size : number
        Forwarded to the legend configuration as ``symbolSize``.

    Returns
    -------
    Whatever the final ``configure_concat`` call returns, i.e. ``base`` with
    view, title, axis, legend and concat configuration applied in that order.
    """
    title_style = {
        "fontSize": 18,
        "fontWeight": 400,
        "anchor": "start",
    }
    axis_style = {
        "labelFontSize": 14,
        "labelFontWeight": 300,
        "titleFontSize": 16,
        "titleFontWeight": 400,
        "titlePadding": 10,
    }
    legend_style = {
        "titleFontSize": 16,
        "titleFontWeight": 400,
        "labelFontSize": 14,
        "labelFontWeight": 300,
        "padding": 10,
        "orient": legend_orient,
        "symbolType": "circle",
        "symbolSize": legend_symbol_size,
    }

    # Each configure_* call returns a chart, so apply them one step at a time.
    styled = base.configure_view(stroke=None)  # no border around each view
    styled = styled.configure_title(**title_style)
    styled = styled.configure_axis(**axis_style)
    styled = styled.configure_legend(**legend_style)
    return styled.configure_concat(spacing=0)

In [None]:
def UpSetAltair(
    data=None,
    sets=None, # This reflects the order of sets to be shown in the plots as well
    abbre=None,
    sort_by="Frequency",
    sort_order="ascending",
    width=1200,
    height=700,
    height_ratio=0.6,
    horizontal_bar_chart_width=300,
    title="Symptoms reported by users of the COVID Symptom Tracker app"
):
    """Build an interactive UpSet plot (bar chart + matrix + set-size bars) with Altair.

    Parameters
    ----------
    data : pandas.DataFrame
        Table holding one column per entry of ``sets``. The caller's frame is
        not modified. NOTE(review): values are assumed to be 0/1 membership
        flags, as described in the notebook's input-format section.
    sets : list of str
        Column names to use as sets; their order is the order shown in the plot.
    abbre : list of str, optional
        One short label per set, drawn inside the set-label circles. When
        omitted, or when its length differs from ``sets``, the full set names
        are used instead.
    sort_by : str
        ``"Frequency"`` orders intersections by ``count``; any other value
        orders them by ``degree``.
    sort_order : str
        Passed to ``alt.Sort`` as ``order`` (``"ascending"`` / ``"descending"``).
    width, height : number
        Overall size; ``width`` is split between the matrix and the
        horizontal bar chart, ``height`` between the vertical bars and the matrix.
    height_ratio : float
        Fraction of ``height`` given to the vertical bar chart. Values outside
        ``[0, 1]`` are replaced with ``0.5``.
    horizontal_bar_chart_width : number
        Width of the set-size bar chart.
    title : str
        Chart title.

    Returns
    -------
    The configured, concatenated Altair chart, or ``None`` (after printing a
    message) when ``data`` or ``sets`` is missing/empty.
    """
    if (data is None) or (sets is None) or (len(sets) == 0):
        print("No data and/or a list of sets are provided")
        return
    sets = list(sets)  # accept any sequence; list concatenation/indexing is used below
    if (height_ratio < 0) or (1 < height_ratio):
        print("height_ratio set to 0.5")
        height_ratio = 0.5
    if (abbre is not None) and (len(abbre) != len(sets)):
        print("The number of abbreviations does not match the number of sets; using set names instead")
        abbre = None
    if abbre is None:
        abbre = sets

    # --- Data preprocessing ---
    # Work on a fresh frame (column selection + assign) so the caller's
    # DataFrame never gains a "count" column.
    data = data[sets].assign(count=0)
    # One row per distinct combination of set memberships; "count" = group size.
    data = data.groupby(sets).count().reset_index()

    data["intersection_id"] = data.index
    data["degree"] = data[sets].sum(axis=1)
    data = data.sort_values(by=["count"], ascending=(sort_order == "ascending"))

    data = pd.melt(data, id_vars=[
        "intersection_id", "count", "degree" # wide to long
    ])
    data = data.rename(columns={"variable": "set", "value": "is_intersect"})

    # Lookup tables joined back in after the Altair-side pivot/fold.
    set_to_abbre = pd.DataFrame({"set": sets, "set_abbre": list(abbre)})
    set_to_order = pd.DataFrame({"set": sets, "set_order": list(range(1, len(sets) + 1))})

    # Expression summing the 0/1 flags of all (still defined) set columns;
    # a column is undefined once its set is filtered out through the legend.
    degree_calculation = "+".join(
        f"(isDefined(datum['{s}']) ? datum['{s}'] : 0)" for s in sets
    )

    # --- Styles ---
    vertical_bar_chart_height = height * height_ratio
    matrix_height = height - vertical_bar_chart_height
    matrix_width = width - horizontal_bar_chart_width

    glyph_size = 200
    bar_padding = 20
    # Guard against zero intersections (division by zero) and against a
    # negative bar size when there are many intersections.
    n_intersections = max(1, data["intersection_id"].nunique())
    vertical_bar_size = max(1, min(30, width / n_intersections - bar_padding))
    horizontal_bar_size = 20
    set_label_bg_size = 1000
    line_connection_size = 2

    color_range = ["#55A8DB", "#3070B5", "#30363F", "#F1AD60", "#DF6234", "#BDC6CA"]
    main_color = "#3A3A3A"

    x_sort = alt.Sort(
        field="count" if sort_by == "Frequency" else "degree",
        order=sort_order
    )

    # Selections
    # NOTE(review): selection_multi / add_selection are deprecated in newer
    # Altair releases (selection_point / add_params); confirm the installed
    # version before migrating.
    legend_selection = alt.selection_multi(fields=["set"], bind="legend")

    # --- Plots ---
    # We are transforming data w/ Altair's core functions because
    # we want to use interactive legends.
    base = alt.Chart(data).transform_filter(
        legend_selection
    ).transform_pivot(
        # Right before this operation, columns should be:
        # count, set, is_intersect, (intersection_id, degree, set_order, set_abbre)
        # where (fields with brackets) should be dropped and recalculated later.
        "set",
        op="max",
        groupby=["intersection_id", "count"],
        value="is_intersect"
    ).transform_aggregate(
        # count, set1, set2, ...
        count="sum(count)",
        groupby=sets
    ).transform_calculate(
        # count, set1, set2, ...
        degree=degree_calculation
    ).transform_filter(
        # count, set1, set2, ..., degree
        alt.datum["degree"] != 0
    ).transform_window(
        # count, set1, set2, ..., degree
        intersection_id="row_number()",
        frame=[None, None]
    ).transform_fold(
        # count, set1, set2, ..., degree, intersection_id
        sets, as_=["set", "is_intersect"]
    ).transform_lookup(
        # count, set, is_intersect, degree, intersection_id
        lookup="set",
        from_=alt.LookupData(set_to_abbre, "set", ["set_abbre"])
    ).transform_lookup(
        # count, set, is_intersect, degree, intersection_id, set_abbre
        lookup="set",
        from_=alt.LookupData(set_to_order, "set", ["set_order"])
    ).transform_filter(
        # Make sure to remove the filtered sets.
        legend_selection
    ).transform_window(
        # count, set, is_intersect, degree, intersection_id, set_abbre
        set_order="distinct(set)",
        frame=[None, 0],
        sort=[{"field": "set_order"}]
    )
    # Now, we have data in the following format:
    # count, set, is_intersect, degree, intersection_id, set_abbre

    # Cardinality by intersecting sets (vertical bar chart)
    vertical_bar = base.mark_bar(color=main_color, size=vertical_bar_size).encode(
        x=alt.X(
            "intersection_id:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=True),
            sort=x_sort,
            title=None
        ),
        y=alt.Y(
            "max(count):Q",
            axis=alt.Axis(grid=False, tickCount=3, orient='right'),
            title="Intersection Size"
        )
    ).properties(
        width=matrix_width,
        height=vertical_bar_chart_height
    )
    vertical_bar_text = vertical_bar.mark_text(
        color=main_color,
        dy=-10,
        size=16
    ).encode(
        text=alt.Text("count:Q", format=".0f")
    )
    vertical_bar_chart = (vertical_bar + vertical_bar_text)

    # UpSet glyph view (matrix view)
    circle_bg = vertical_bar.mark_circle(size=glyph_size, opacity=1).encode(
        x=alt.X(
            "intersection_id:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            sort=x_sort,
            title=None
        ),
        y=alt.Y(
            "set_order:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            title=None
        ),
        color=alt.value("#E6E6E6")
    ).properties(
        height=matrix_height
    )

    # Light stripe behind every other set row.
    rect_bg = circle_bg.mark_rect().transform_filter(
        alt.datum["set_order"] % 2 == 1
    ).encode(
        color=alt.value("#F7F7F7")
    )

    # Dark circles mark the sets that take part in an intersection.
    circle = circle_bg.transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        color=alt.value(main_color)
    )

    # Thin bar connecting the first and last participating set of an intersection.
    line_connection = vertical_bar.mark_bar(size=line_connection_size, color=main_color).transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        y=alt.Y("min(set_order):N"),
        y2=alt.Y2("max(set_order):N")
    )

    matrix_view = (rect_bg + circle_bg + line_connection  + circle)

    # Cardinality by sets (horizontal bar chart)
    horizontal_bar_label_bg = base.mark_circle(size=set_label_bg_size).encode(
        y=alt.Y(
            "set_order:N",
            axis=alt.Axis(grid=False, labels=False, ticks=False, domain=False),
            title=None,
        ),
        color=alt.Color(
            "set:N",
            scale=alt.Scale(domain=sets, range=color_range),
            title=None
        ),
        opacity=alt.value(1)
    )
    horizontal_bar_label = horizontal_bar_label_bg.mark_text().encode(
        text=alt.Text("set_abbre:N"),
        color=alt.value("white")
    )

    horizontal_bar = horizontal_bar_label_bg.mark_bar(
        size=horizontal_bar_size
    ).transform_filter(
        alt.datum["is_intersect"] == 1
    ).encode(
        x=alt.X(
            "sum(count):Q",
            axis=alt.Axis(grid=False, tickCount=3),
            title="Set Size"
        )
    ).properties(
        width=horizontal_bar_chart_width
    )

    # Concat Plots
    upsetaltair = alt.vconcat(
        vertical_bar_chart,
        alt.hconcat(
            matrix_view,
            (horizontal_bar_label_bg + horizontal_bar_label), horizontal_bar, # horizontal bar chart
            spacing=5
        ).resolve_scale(
            y="shared"
        ),
        spacing=20
    ).add_selection(
        legend_selection
    )

    # Apply top-level configuration
    upsetaltair = upsetaltair_top_level_configuration(
            upsetaltair,
            legend_orient="top",
            legend_symbol_size=set_label_bg_size / 2.0
        ).properties(
            title=title
        )

    return upsetaltair

## Examples w/ Different Options

In [None]:
# Example 1: order the intersections by their size ("Frequency" sorts on count), ascending.
# A copy of df is passed so the loaded frame stays untouched.
UpSetAltair(
    data=df.copy(), 
    sets=["Shortness of Breath", "Diarrhea", "Fever", "Cough", "Anosmia", "Fatigue"],
    abbre=["B", "D", "Fe", "C", "A", "Fa"],
    sort_by="Frequency",
    sort_order="ascending"
)

In [None]:
# Example 2: same sets, but order the intersections by degree
# (number of sets taking part in each intersection), ascending.
UpSetAltair(
    data=df.copy(), 
    sets=["Shortness of Breath", "Diarrhea", "Fever", "Cough", "Anosmia", "Fatigue"],
    abbre=["B", "D", "Fe", "C", "A", "Fa"],
    sort_by="Degree",
    sort_order="ascending"
)

In [None]:
# Example 3: explicit layout options. Only height differs from the defaults
# (500 instead of 700); width, height_ratio and the bar-chart width are
# spelled out at their default values.
UpSetAltair(
    data=df.copy(), 
    sets=["Shortness of Breath", "Diarrhea", "Fever", "Cough", "Anosmia", "Fatigue"],
    abbre=["B", "D", "Fe", "C", "A", "Fa"],
    sort_by="Frequency",
    sort_order="ascending",
    width=1200,
    height=500,
    height_ratio=0.6,
    horizontal_bar_chart_width=300
)