In [None]:
import pandas as pd
import altair as alt
import httpimport

# Lift Altair's default cap on the number of rows a chart's dataset may contain,
# so the full data frames used below can be plotted.
# Binding the return value to `_` keeps it from being echoed as the cell's output.
_ = alt.data_transformers.disable_max_rows()


In [None]:
# Import a custom Altair theme from a remote GitHub repository using the httpimport module.
def import_theme_new():
    """Fetch `main_theme` from GitHub and register it as the active Altair theme.

    The module is imported inside the `httpimport.github_repo` context, then a
    theme named "custom_theme" is registered with `enable=True`. The registered
    theme function returns whatever `main_theme.main_theme()` returns.

    Requires network access every time it runs.
    """
    # NOTE(review): the third argument ("main") is presumably the git ref/branch;
    # confirm against the installed httpimport version. If so, the code executed
    # here is not pinned to a commit and can change between runs.
    with httpimport.github_repo("bblarsen-sci", "altair_themes", "main"):
        # Must stay inside the `with` block: the import relies on the context above.
        import main_theme

        # The decorator performs the registration (and enabling) as a side effect.
        @alt.theme.register("custom_theme", enable=True)
        def custom_theme():
            return main_theme.main_theme()


# Apply the theme once, before any chart in this notebook is built.
import_theme_new()


In [None]:
def _read_mean_accessibility(csv_path, state_column):
    """Read one accessibility CSV and keep only its mean column, renamed.

    The `accessibility_A`, `accessibility_B` and `accessibility_C` columns are
    dropped and `mean_accessibility` is renamed to `state_column`; every other
    column is left untouched.
    """
    table = pd.read_csv(csv_path)
    table = table.drop(columns=["accessibility_A", "accessibility_B", "accessibility_C"])
    return table.rename(columns={"mean_accessibility": state_column})


# Functional effects table (joined to the accessibility tables on `site` below).
mean_effect_df = pd.read_csv(
    "../../results/filtered_data/cell_entry/Nipah_F_func_effects_filtered_mean.csv"
)

# Accessibility tables for the two structures, each reduced to a single
# state-specific accessibility column.
pre_df = _read_mean_accessibility(
    "../../results/residue_accessibility/5evm_accessibility.csv",
    "accessibility_pre",
)
post_df = _read_mean_accessibility(
    "../../results/residue_accessibility/NiV_F_postfusion_accessibility.csv",
    "accessibility_post",
)

In [None]:
# Inner join: keep only sites present in both accessibility tables.
merged_df = pre_df.merge(post_df, on="site", how="inner")

# Inner join again: keep only sites that also have an entry in the effects table.
merged_acc_effect_df = merged_df.merge(mean_effect_df, on="site", how="inner")
display(merged_acc_effect_df)

In [None]:

# Sites left out of the candidate list by hand; the reason for each exclusion
# is not recorded in this notebook.
excluded_sites = [332, 313, 349, 364, 71, 259]

# Candidate sites: accessibility_pre below 10 and accessibility_post above 20,
# minus the excluded sites, ordered from lowest to highest effect.
# `not in` replaces the previous chain of `site != N` terms, which mixed `&`
# and `and`.
tmp_df = (
    merged_acc_effect_df
    .query(
        "accessibility_post > 20 and accessibility_pre < 10 "
        "and site not in @excluded_sites"
    )
    .sort_values("effect")
)

# tmp_df is already sorted by effect, so no second sort is needed: the first
# ten rows are the ten lowest-effect candidate sites.
print(tmp_df.head(10)["site"].tolist())



In [None]:
# Keep sites with accessibility_pre below 5 and accessibility_post above 30,
# then reshape to long form: one row per (site, state) with the accessibility
# value for that state.
melted_df = merged_acc_effect_df.query(
    "accessibility_pre < 5 and accessibility_post > 30"
).melt(
    id_vars=["site", "effect"],
    value_vars=["accessibility_pre", "accessibility_post"],
    var_name="state",
    value_name="accessibility",
)
display(melted_df.sort_values("effect").head(20))
print(list(melted_df["site"].unique()))

# One line per site, running from its pre value to its post value and
# coloured by effect.
line_chart = (
    alt.Chart(melted_df)
    .mark_line(point=True)
    .encode(
        x=alt.X("state", title=None, sort=["accessibility_pre", "accessibility_post"]),
        y=alt.Y("accessibility:Q", title=None),
        color=alt.Color(
            "effect:Q",
            legend=None,
            scale=alt.Scale(scheme="redblue", domainMid=0, domain=[-3.5, 1]),
        ),
        # Group the line mark by site explicitly. Without this, lines are
        # grouped only by the colour field, so two sites that share the same
        # effect value would be drawn as a single path.
        detail=alt.Detail("site:N"),
        tooltip=["site", "state", "accessibility", "effect"],
    )
    .properties(
        width=alt.Step(100),
        height=200,
    )
)

display(line_chart)


In [None]:
# Preview the first five rows of the merged table (head() defaults to 5 rows).
display(merged_acc_effect_df.head())

In [None]:
# Long form of the merged table: one row per (site, state).
melted_df = pd.melt(
    merged_acc_effect_df,
    id_vars=["site", "effect"],
    value_vars=["accessibility_pre", "accessibility_post"],
    var_name="state",
    value_name="accessibility",
)
display(melted_df)

# Find matched pairs: pick the sites whose "accessibility_pre" row is below 20,
# then keep every row (both states) belonging to those sites.
is_pre_row = melted_df["state"] == "accessibility_pre"
# Boolean Series indexed by the pre rows only: True where accessibility < 20.
valid_sites = melted_df.loc[is_pre_row, "accessibility"].lt(20)
# Row labels of the qualifying pre rows, used to look up their site ids.
qualifying_rows = valid_sites.loc[valid_sites].index
valid_site_ids = melted_df.loc[qualifying_rows, "site"]
print(list(valid_site_ids))
filtered_df_test = melted_df.loc[melted_df["site"].isin(valid_site_ids)]


In [None]:
# Minimum accessibility (exclusive) a row must exceed to be kept, per state.
min_accessibility = {"accessibility_pre": 30, "accessibility_post": 30}
state_order = list(min_accessibility)

# Long form of the merged table: one row per (site, state).
melted_df = pd.melt(
    merged_acc_effect_df,
    id_vars=["site", "effect"],
    value_vars=state_order,
    var_name="state",
    value_name="accessibility",
)

# Keep each row only if its accessibility exceeds the threshold for its own state.
above_threshold = melted_df["accessibility"] > melted_df["state"].map(min_accessibility)
filtered_df = melted_df[above_threshold]

display(filtered_df)

# Distribution of effect among the retained rows, one box per state.
state_axis = alt.X("state", title=None, sort=state_order)
effect_axis = alt.Y("effect:Q", title=None)

boxplot_chart = (
    alt.Chart(filtered_df)
    .mark_boxplot(extent="min-max", opacity=1)
    .encode(x=state_axis, y=effect_axis)
    .properties(width=alt.Step(25), height=200)
)

display(boxplot_chart)
boxplot_chart.save('../../niv_pre_post_effects.svg')
