In [9]:
import plotly.express as px
import plotly.graph_objects as go

## Data

In [10]:
# Mean and standard deviation of ROC-AUC and PR-AUC for each experiment setting.
# Within each dict the lists are parallel: entry i of every metric list belongs
# to setting i of "setting".
data_exp = {
    "setting": ["FV0_BS0_SEP0", "FV0_BS0_SEP1", "FV0_BS1_SEP0", "FV0_BS1_SEP1", "FV1_BS0_SEP0", "FV1_BS0_SEP1", "FV1_BS1_SEP0", "FV1_BS1_SEP1"],
    "roc_auc_mean": [0.92232, 0.93690, 0.94124, 0.94496, 0.91700, 0.92321, 0.92467, 0.94721],
    "roc_auc_std": [0.01331, 0.00482, 0.00358, 0.00275, 0.01902, 0.02542, 0.01182, 0.00480],
    "pr_auc_mean": [0.54481, 0.56390, 0.57731, 0.59218, 0.54345, 0.57156, 0.58224, 0.62291],
    "pr_auc_std": [0.01991, 0.01602, 0.01480, 0.01698, 0.01427, 0.01813, 0.00634, 0.01843],
}

model_exp = {
    "setting": ["CE0T1", "CE1T0", "CE1T1", "CE0T0"],
    "roc_auc_mean": [0.91701, 0.97576, 0.97666, 0.94721],
    "roc_auc_std": [0.00596, 0.00109, 0.00147, 0.00480],
    "pr_auc_mean": [0.51997, 0.78891, 0.79207, 0.62291],
    "pr_auc_std": [0.00933, 0.00761, 0.00972, 0.01843]
}

# The "18"/"18MRR" settings are recorded here but not plotted (see training_order).
training_exp = {
    "setting": ["10MRR", "10", "15", "18MRR", "18", "20MRR", "20", "25MRR", "25", "30MRR", "30", "15MRR"],
    "roc_auc_mean": [0.97775, 0.97017, 0.97272, 0.97513, 0.97425, 0.97768, 0.97867, 0.97651, 0.97314, 0.97513, 0.97576, 0.97666],
    "roc_auc_std": [0.00191, 0.00252, 0.00333, 0.00101, 0.00161, 0.00139, 0.00182, 0.00153, 0.00311, 0.00270, 0.00280, 0.00147],
    "pr_auc_mean": [0.79378, 0.77004, 0.77254, 0.78494, 0.77882, 0.80206, 0.80108, 0.78837, 0.78018, 0.78903, 0.78529, 0.79207],
    "pr_auc_std": [0.00807, 0.01266, 0.01103, 0.00834, 0.01153, 0.01119, 0.01065, 0.00897, 0.01235, 0.01277, 0.01531, 0.00972]
}


def setting_positions(exp, chart_order):
    """Return the index in exp["setting"] of every setting name in `chart_order`.

    Also checks that every list in `exp` has one entry per setting, so that
    indexing a metric list with the returned positions is meaningful.

    Raises:
        ValueError: if a list in `exp` is misaligned with "setting", or a name
            in `chart_order` is not a known setting.
    """
    n_settings = len(exp["setting"])
    misaligned = [key for key, values in exp.items() if len(values) != n_settings]
    if misaligned:
        raise ValueError(f"lists {misaligned} do not have {n_settings} entries")
    return [exp["setting"].index(name) for name in chart_order]


def metric_in_order(exp, key, positions):
    """Return exp[key] re-ordered by `positions` (as built by setting_positions)."""
    return [exp[key][i] for i in positions]


def gains_over_baseline(values, baseline_idx):
    """Return each value minus values[baseline_idx]; the baseline entry becomes 0."""
    baseline = values[baseline_idx]
    return [v - baseline for v in values]


# Chart orders: the top-to-bottom order of the bars in the figure.
data_chart_order = [
    "FV0_BS0_SEP1",
    "FV0_BS0_SEP0",
    "FV1_BS0_SEP1",
    "FV1_BS0_SEP0",
    "FV0_BS1_SEP1",
    "FV0_BS1_SEP0",
    "FV1_BS1_SEP1",
    "FV1_BS1_SEP0"
]
data_setting_order = setting_positions(data_exp, data_chart_order)
data_rocs = metric_in_order(data_exp, "roc_auc_mean", data_setting_order)
data_prs = metric_in_order(data_exp, "pr_auc_mean", data_setting_order)
# Gains are measured against the first bar (FV0_BS0_SEP1).
data_diff_roc = gains_over_baseline(data_rocs, 0)
data_diff_pr = gains_over_baseline(data_prs, 0)

model_chart_order = [
    "CE0T0",
    "CE0T1",
    "CE1T0",
    "CE1T1",
]
model_setting_order = setting_positions(model_exp, model_chart_order)
model_rocs = metric_in_order(model_exp, "roc_auc_mean", model_setting_order)
model_prs = metric_in_order(model_exp, "pr_auc_mean", model_setting_order)
# Gains are measured against the first bar (CE0T0).
model_diff_roc = gains_over_baseline(model_rocs, 0)
model_diff_pr = gains_over_baseline(model_prs, 0)

training_order = ["10MRR", "15MRR", "20MRR", "25MRR", "30MRR", "10", "15", "20", "25", "30"]
# The training baseline is 15MRR, the *second* bar, not the first.
TRAINING_BASELINE_IDX = training_order.index("15MRR")
training_setting_order = setting_positions(training_exp, training_order)
training_rocs = metric_in_order(training_exp, "roc_auc_mean", training_setting_order)
training_prs = metric_in_order(training_exp, "pr_auc_mean", training_setting_order)
training_diff_roc = gains_over_baseline(training_rocs, TRAINING_BASELINE_IDX)
training_diff_pr = gains_over_baseline(training_prs, TRAINING_BASELINE_IDX)




## Functions

In [15]:
def tree_diagram(combinations, y_offset, x0=None):
    """Build a left-to-right bracket ("tree") diagram as a plotly figure.

    Every label except the last becomes one level of square brackets: level
    `i` draws 2**i brackets with their vertical stroke at x = x0 + i, each
    reaching +/- h around its centre, where h starts at
    2**(len(combinations) - 2) // 2 and halves after every level. The root
    bracket, and every second bracket of the deeper levels, carries its
    level's label. The last label is written 2**(len(combinations) - 2)
    times (once per last-level bracket), 0.2 to the left of the final x.

    Args:
        combinations: labels, root first; presumably at least two entries.
        y_offset: vertical shift applied to the whole diagram.
        x0: x of the root bracket; defaults to 2 - len(combinations).

    Returns:
        A go.Figure holding only black line traces and white-background
        annotations, meant to be merged into a larger figure.
    """
    depth = len(combinations)
    if x0 is None:
        x0 = 2 - depth
    level_y = 0
    half_height = 2 ** (depth - 2) // 2
    figure = go.Figure()

    for level, label in enumerate(combinations[:-1]):
        branch_y = level_y + y_offset
        for branch in range(2 ** level):
            bottom = branch_y - half_height
            top = branch_y + half_height
            figure.add_trace(go.Scatter(
                x=[x0 + 1, x0, x0, x0, x0 + 1],
                y=[bottom, bottom, branch_y, top, top],
            ))
            # Root is always labelled; deeper levels label every other bracket.
            if level == 0 or branch % 2 == 1:
                figure.add_annotation(x=x0, y=branch_y, text=label)
            branch_y -= half_height * 4
        x0 += 1
        level_y += half_height
        half_height /= 2

    leaf_x = x0 - 0.2
    leaf_y = level_y - half_height * 4
    for _ in range(2 ** (depth - 1) // 2):
        figure.add_annotation(x=leaf_x, y=leaf_y + y_offset, text=combinations[-1])
        leaf_y -= half_height * 8

    figure.update_annotations(showarrow=False, bgcolor="white", font=dict(size=15))
    figure.update_traces(line=dict(color="black"), mode="lines")
    return figure

def baseline_line(fig, base, ys, star_y=None, label_offset=0.0013):
    """Mark a bar group's baseline value on the bar axis ("x2") of `fig`.

    Adds three things to `fig`, in place:
    - a dotted vertical line at x=`base` from ys[0] + 0.5 to ys[-1] - 0.5;
    - a diamond marker at (`base`, `star_y`);
    - a right-anchored "baseline" label just left of that marker.

    Args:
        fig: plotly figure to draw on.
        base: baseline metric value, in x2 data units.
        ys: row positions of the group's bars. The padding assumes they are
            descending, as in this notebook's plot cell.
        star_y: row at which to place the marker and label. Defaults to ys[0].
        label_offset: gap, in x2 data units, between the marker and the
            label's right edge.
    """
    if star_y is None:
        # A None star_y would give the marker and the label no y position.
        star_y = ys[0]
    ink = "#1E1E1E"

    fig.add_trace(
        go.Scatter(
            x=[base] * 2,
            y=[ys[0] + 0.5, ys[-1] - 0.5],
            xaxis="x2",
            mode="lines",
            line=dict(color=ink, dash="dot"),
        )
    )
    fig.add_trace(
        go.Scatter(
            x=[base],
            y=[star_y],
            xaxis="x2",
            mode="markers",
            marker=dict(color=ink, size=16, symbol="diamond"),
        )
    )
    fig.add_annotation(x=base - label_offset, y=star_y, text="baseline", xref="x2", showarrow=False, font=dict(size=15), xanchor="right", bgcolor="white")

def diverging_bar(fig, diffs, ys, base, colors=("rgb(220, 38, 127)", "rgb(100, 143, 255)"), label_best=True):
    """Draw horizontal bars of length `diffs` starting at `base` on axis "x2" of `fig`.

    Args:
        fig: plotly figure to draw on (modified in place).
        diffs: bar lengths (gain over the baseline), one per entry of `ys`.
        ys: row positions of the bars.
        base: a single start value shared by every bar, or a list of
            per-bar start values.
        colors: (non-positive colour, positive colour).
            - Scalar `base`: a bar uses colors[1] when its diff is > 0.
            - List `base`: a bar uses colors[1] when its own base exceeds
              base[0].
        label_best: if True (and `diffs` is non-empty), write
            "<diff> gain" just right of the end of the largest bar.
    """
    per_bar_base = isinstance(base, list)
    if per_bar_base:
        marker_colors = [colors[int(b > base[0])] for b in base]
    else:
        marker_colors = [colors[int(d > 0)] for d in diffs]
    fig.add_trace(
        go.Bar(
            x=diffs, y=ys,
            xaxis="x2",
            orientation="h",
            base=base,
            marker=dict(
                color=marker_colors,
                line=dict(
                    width=0,
                ),
            ),
        )
    )

    if label_best and len(diffs) > 0:
        # First index of the maximum (same tie-break as list.index(max(...))),
        # written so tuples and arrays work too.
        best = max(range(len(diffs)), key=lambda i: diffs[i])
        # With per-bar bases the label must follow that bar's own start;
        # adding a list to a float used to raise TypeError here.
        best_start = base[best] if per_bar_base else base
        fig.add_annotation(
            x=best_start + diffs[best] + 0.001, y=ys[best], xref="x2", xanchor="left", bgcolor="white", showarrow=False,
            text=f"{diffs[best]:.3f} gain", font=dict(size=14, family="Arial Black")
        )

def custom_tree(combinations, y_offset):
    """Build a two-level bracket diagram whose leaf list appears twice.

    `combinations` is [root_label, branch_label, leaves].
    - The root bracket has its vertical stroke at x=-1, carries `root_label`
      at its centre, and splits into an upper and a lower arm ending at x=0.
    - `branch_label` is written on the lower arm.
    - Each arm ends in a comb of len(leaves) rows, one y-unit apart, drawn
      between x=0 and x=1, with the leaf labels at x=0.9.
    - The root arms reach +/- 0.5 * len(leaves), so each comb is centred on
      its arm.
    - The two combs together form 2 * len(leaves) evenly spaced rows centred
      on `y_offset`. Rows are listed for the lower comb first (top to
      bottom), then the upper comb.

    `leaves` must be non-empty.

    Returns:
        A go.Figure holding only black line traces and white-background
        annotations.
    """
    leaves = combinations[-1]
    n = len(leaves)
    row_half = 0.5          # half the distance between adjacent leaf rows
    # The original hard-coded 2.5 here (and a (n + 1) / 2 comb offset below),
    # which only lines up for exactly 5 leaves.
    arm = row_half * n

    tree = go.Figure()
    x0 = -1
    root_y = y_offset
    tree.add_trace(go.Scatter(
        x=[x0 + 1, x0, x0, x0, x0 + 1],
        y=[root_y - arm, root_y - arm, root_y, root_y + arm, root_y + arm],
    ))
    tree.add_annotation(x=x0, y=root_y, text=combinations[0])
    tree.add_annotation(x=x0 + 1, y=root_y - arm, text=combinations[1])

    x0 += 1
    for sign in [-1, 1]:
        # temp_y is the midpoint of the current comb bracket, one row_half
        # below the leaf it labels; the first (top) leaf sits at
        # centre + row_half * (n - 1).
        temp_y = root_y + sign * arm
        temp_y += row_half * (n - 2)

        for leaf in leaves[:-1]:
            tree.add_trace(go.Scatter(
                x=[x0 + 1, x0, x0, x0 + 1],
                y=[temp_y - row_half, temp_y - row_half, temp_y + row_half, temp_y + row_half],
            ))
            tree.add_annotation(x=x0 + 0.9, y=temp_y + row_half, text=leaf)
            temp_y -= row_half * 2
        tree.add_annotation(x=x0 + 0.9, y=temp_y + row_half, text=leaves[-1])

    tree.update_annotations(showarrow=False, bgcolor="white", font=dict(size=15))
    tree.update_traces(line=dict(color="black"), mode="lines")

    return tree


## Plot

In [16]:
fig = go.Figure()
layout = go.Layout(
    xaxis=dict(
        domain=[0, 0.2],
        visible=False,
    ),
    xaxis2=dict(
        domain=[0.2, 1],
        range=(0.9099, 0.99011),
        gridcolor="silver",
        griddash="solid",
        showline=True,
        linecolor="black",
        linewidth=3,
        tickfont=dict(size=15),
        title=dict(
            text="AUROC",
            font=dict(
                size=25,
            ),
        ),
    ),
    yaxis=dict(
        visible=False,
        range=(-20, 5),

    ),
    height=800,
    width=1100,
    plot_bgcolor="white",
    showlegend=False,
    margin=dict(l=50, r=50, b=50, t=50),
)

# Bracket diagrams for the left-hand panel (xaxis), one per experiment group.
tree = tree_diagram(["BEHRT", "+BS", "+FV", "-SEP"], 0.5)
tree2 = tree_diagram(["+Data", "+CE", "+T"], -6.5)
tree3 = custom_tree(["+Embedding", "+FM", [10, 15, 20, 25, 30]], y_offset=-14.5)

# One bar group per experiment:
# (AUROC gains, y of the first row, baseline AUROC, y of the baseline marker).
# Rows run downwards, one unit apart, starting at the first row.
# To overlay PR-AUC gains as well, call e.g.
#   diverging_bar(fig, data_diff_pr, ys, base, label_best=False, colors=["rgb(195, 63, 128)", "rgb(100, 143, 255)"])
# inside the loop, before the AUROC bars.
bar_groups = [
    (data_diff_roc, 4, data_rocs[0], 4),
    (model_diff_roc, -5, model_rocs[0], -5),
    (training_diff_roc, -10, training_rocs[1], -11),  # baseline is the second row (15MRR)
]
for diffs, first_y, base, star_y in bar_groups:
    ys = list(range(first_y, first_y - len(diffs), -1))
    diverging_bar(fig, diffs, ys, base)
    baseline_line(fig, base, ys, star_y=star_y)

for t in tree.data + tree2.data + tree3.data:
    fig.add_trace(t)
for t in tree.layout.annotations + tree2.layout.annotations + tree3.layout.annotations:
    fig.add_annotation(t)


fig.update_layout(layout)

# Solid vertical line at x=0.91 on the bar axis, standing in for the hidden
# y axis; it spans the full y range (5 to -20).
fig.add_shape(go.layout.Shape(type="line", xref="x2", x0=0.91, x1=0.91, y0=5, y1=-20, line=dict(color="black", width=3)))

# Rotated group titles to the left of the bracket diagrams.
for group_title, group_y in [("Data", 0.5), ("Embedding", -6.5), ("Masking", -15)]:
    fig.add_annotation(x=-3.3, y=group_y, text=group_title, textangle=-90, showarrow=False, font=dict(size=25))

fig.show()
