In [1]:
import pandas as pd
import altair as alt
import numpy as np

In [4]:
# Load the per-split test metrics of every dataset for one experiment, keep
# them per dataset/split in `all_data`, and stack them into `combined_data`.
all_data = {}

experiment = "ood_selection"

datasets = ["dataset1"]
splits = ["id", "ood"]

# Columns kept from each results CSV; any other column in the file is dropped.
metric_columns = ["coverage", "rejection rate", "accuracy", "precision", "recall"]

for dataset in datasets:
    all_data[dataset] = {}
    # Iterate the shared `splits` list (not a second hard-coded copy) so the
    # frames loaded here always match the keys `combined_data` looks up below.
    for split in splits:
        csv_path = f"./chem-selectivenet/logs/train/{dataset}/{experiment}/test_{split}.csv"
        # `.assign` returns a new frame with the tag columns appended, instead
        # of setting columns on a column-subset of the frame that was just read.
        data = pd.read_csv(csv_path)[metric_columns].assign(dataset=dataset, split=split)
        all_data[dataset][split] = data

# One flat concat, dataset-major then split order. The original row index of
# each CSV is kept, so index values repeat across splits.
combined_data = pd.concat(
    [all_data[dataset][split] for dataset in datasets for split in splits],
    axis=0,
)

In [5]:
# Write `combined_data` to CSV. No `index` argument is passed, so pandas'
# default applies and the row index is written as the first column.
combined_data.to_csv(
    path_or_buf="saved_results/2023.10.17_drugood_selectnet_results.csv"
)

In [6]:
# Pick the dataset to plot and pull out its two per-split frames.
dataset = "dataset1"
id_data, ood_data = (all_data[dataset][split_name] for split_name in ("id", "ood"))

In [11]:
# Precision and accuracy against rejection rate for the ID and OOD frames,
# one line per selected coverage value. Clicking a point drives a per-split
# selection that dims the non-selected coverage lines.
coverage_filters = [ 0.4, 0.6, 0.8, 0.9, 1.0]

show_tooltip = True
interactive = False

ood_selection = alt.selection_point(fields=['Coverage (OOD)'], on="click")
id_selection = alt.selection_point(fields=['Coverage (ID)'], on="click")


def _coverage_metric_chart(data, coverage_label, metric, scheme, selection,
                           coverages=coverage_filters, with_tooltip=show_tooltip):
    """Build one line chart of ``metric`` against "rejection rate".

    Parameters
    ----------
    data : pandas.DataFrame
        Frame with "coverage", "rejection rate" and ``metric`` columns.
    coverage_label : str
        Name the "coverage" column is renamed to. It is used as the colour
        field and must match the ``fields`` entry of ``selection`` so that
        the selection can match rows.
    metric : str
        Column plotted on the y axis.
    scheme : str
        Colour scheme name passed to the colour scale.
    selection : Altair selection parameter
        Drives the opacity condition (1 when selected, 0.1 otherwise) and is
        attached to the returned chart.
    coverages : list of float
        Only rows whose "coverage" is in this list are plotted.
    with_tooltip : bool
        When false, an empty tooltip list is passed.

    Returns
    -------
    alt.Chart
    """
    chart_data = data[data["coverage"].isin(coverages)].rename(columns={"coverage": coverage_label})
    tooltip_fields = [coverage_label, "rejection rate", metric] if with_tooltip else []
    return alt.Chart(chart_data).mark_line(point=True).encode(
        x=alt.X("rejection rate", scale=alt.Scale(domain=[0.0, 1.0])),
        y=alt.Y(metric, scale=alt.Scale(domain=[0.0, 1.0])),
        color=alt.Color(coverage_label, type="ordinal", scale={"scheme": scheme}),
        opacity=alt.condition(selection, alt.value(1), alt.value(0.1)),
        tooltip=tooltip_fields,
        # `add_params` replaces `add_selection`, which Altair 5 deprecates.
    ).add_params(selection)


prec_chart_ood = _coverage_metric_chart(ood_data, "Coverage (OOD)", "precision", "reds", ood_selection)
prec_chart_id = _coverage_metric_chart(id_data, "Coverage (ID)", "precision", "blues", id_selection)
acc_chart_ood = _coverage_metric_chart(ood_data, "Coverage (OOD)", "accuracy", "reds", ood_selection)
acc_chart_id = _coverage_metric_chart(id_data, "Coverage (ID)", "accuracy", "blues", id_selection)

# Layer OOD and ID per metric; independent colour scales keep the red (OOD)
# and blue (ID) schemes from being merged into a single legend.
prec_chart = alt.layer(
    prec_chart_ood,
    prec_chart_id,
).resolve_scale(color="independent").properties(width=400, height=400)
acc_chart = alt.layer(
    acc_chart_ood,
    acc_chart_id,
).resolve_scale(color="independent").properties(width=400, height=400)

if interactive:
    prec_chart = prec_chart.interactive()
    acc_chart = acc_chart.interactive()

prec_chart | acc_chart



In [5]:
# ID-only view: precision and accuracy against rejection rate, side by side,
# for a second set of coverage values.
coverage_filters = [0.85, 1.0, 0.5, 0.7]


def _id_metric_chart(metric):
    """Return a 400x400 line chart of ``metric`` vs. "rejection rate" for ``id_data``.

    Rows are limited to the coverages in ``coverage_filters``, lines are
    coloured by "coverage" as an ordinal field, and ``.interactive()`` is
    applied to the result.
    """
    selected_rows = id_data[id_data["coverage"].isin(coverage_filters)]
    base = alt.Chart(selected_rows).mark_line(point=True)
    encoded = base.encode(
        x=alt.X("rejection rate", scale=alt.Scale(domain=[0.0,1.0])),
        y=alt.Y(metric, scale=alt.Scale(domain=[0.0,1.0])),
        color=alt.Color("coverage", type="ordinal"),
    )
    return encoded.properties(width=400, height=400).interactive()


prec_chart_id = _id_metric_chart("precision")
acc_chart_id = _id_metric_chart("accuracy")

prec_chart_id | acc_chart_id