# Investigating GSEA results

Before this notebook is run, the bash script `03_run_all_gsea.sh` needs to have been run. You can verify this by checking for a folder named `gsea_res/` inside of the `intermediate_files` folder.

# TL;DR

This notebook looked for a fancy way to filter the GSEA results, but didn't find one. So in the end we just ranked the pathways by the sum of their signed (not absolute value) NES across all days and took the top n from both the positive and negative ends of that ranking. The filtered and unfiltered heatmaps are in the `output_charts` directory.

## Check distribution of normalized enrichment scores (NES)

This will help us decide which pathways to show in our heatmaps.

In [1]:
import altair as alt
import glob
import pandas as pd

# Lift Altair's default cap on the number of rows a chart's data may contain.
alt.data_transformers.disable_max_rows()
# Folder the GSEA reports are read from (one sub-folder per gene-set GMT file).
gsea_res_dir = "intermediate_files/gsea_res"
# Folder the heatmap HTML files are saved into.
output_dir = "output_charts"

In [2]:
def get_set_name(set_abbr):
    """Translate a gene-set abbreviation into its GMT file name.

    The known abbreviations are "cp", "go" and "h"; anything else raises
    ``KeyError``.
    """
    gmt_by_abbr = dict(
        cp="c2.cp.v2023.1.Hs.symbols.gmt",
        go="c5.go.v2023.1.Hs.symbols.gmt",
        h="h.all.v2023.1.Hs.symbols.gmt",
    )
    gmt_name = gmt_by_abbr[set_abbr]
    return gmt_name

    
def nes_dist(set_abbr, alpha):
    """Plot a histogram of NES values pooled over all GSEA reports of a gene set.

    Parameters
    ----------
    set_abbr : str
        Gene-set abbreviation accepted by ``get_set_name``.
    alpha : float
        Rows are kept only when their "FDR q-val" is at or below this value.

    Returns
    -------
    alt.Chart
        Bar chart counting rows per NES bin (bin width 0.05).
    """

    # Get the GMT name from the abbreviation
    set_name = get_set_name(set_abbr)

    # Load GSEA output: every sub-folder of the gene set's result folder may
    # contain report files.
    reports = []
    for path in glob.glob(f"{gsea_res_dir}/{set_name}/*"):
        reports.extend(glob.glob(path + "/gsea_report_for_na_*.tsv"))

    # Day and direction are parsed from the underscore-separated report path;
    # they are only used here to process the reports in a fixed order.
    repdf = pd.DataFrame({
        "report": reports,
        "day": [int(report.split("_")[3]) for report in reports],
        "posneg": [report.split("_")[-2] for report in reports],
    }).sort_values(by=["posneg", "day"])

    ness = []
    for report in repdf.report.values:
        df = pd.read_csv(report, sep="\t")
        df_sel = df[df["FDR q-val"] <= alpha]
        if df_sel.NES.dtype != float:
            # The column was not read as float, so it holds text entries: keep
            # only those containing a digit. na=False treats a missing NES as
            # "no match" instead of producing a NaN mask that cannot be used
            # for boolean indexing.
            row_mask = df_sel.NES.str.contains(r"[0-9]", na=False)
            ness.extend(df_sel.NES[row_mask].astype(float).values)
        else:
            ness.extend(df_sel.NES.values)

    neschart = alt.Chart(pd.DataFrame({"nes": ness})).mark_bar().encode(
        x=alt.X(
            "nes",
            bin=alt.Bin(step=0.05),
        ),
        y="count()",
    ).properties(
        width=1000,
        height=300,
    )

    return neschart

In [3]:
# Curated pathways ("cp") with the loosest cutoff: rows with FDR q-val <= 1.
nes_dist("cp", alpha=1)

In [4]:
# Curated pathways ("cp"), keeping only rows with FDR q-val <= 0.05.
nes_dist("cp", alpha=0.05)

In [5]:
# GO gene sets ("go") with the loosest cutoff: rows with FDR q-val <= 1.
nes_dist("go", alpha=1)

In [6]:
# GO gene sets ("go"), keeping only rows with FDR q-val <= 0.05.
nes_dist("go", alpha=0.05)

In [7]:
# Hallmark gene sets ("h") with the loosest cutoff: rows with FDR q-val <= 1.
nes_dist("h", alpha=1)

In [8]:
# Hallmark gene sets ("h"), keeping only rows with FDR q-val <= 0.05.
nes_dist("h", alpha=0.05)

### NES distribution conclusion

It looks like filtering by FDR q-value basically just drops results with NES close to zero---in other words, it looks like we'd get similar results if we just filtered by absolute NES. Furthermore, there don't appear to be any consistent breaks in the distribution. So should we just arbitrarily pick some cutoff that gives us a manageable number of pathways on the heatmap? Hmm.

### Check distribution of the sum and mean of normalized enrichment scores (NES) across multiple days

Maybe this will be more informative.

In [9]:
def nes_sum_mean_dist(set_abbr, abs, alpha):
    """Plot histograms of the per-pathway sum and mean of NES across reports.

    Parameters
    ----------
    set_abbr : str
        Gene-set abbreviation accepted by ``get_set_name``.
    abs : bool
        When true, the absolute value of each NES is taken before summing and
        averaging. (The name shadows the builtin ``abs`` inside this function;
        it is kept because callers pass it by keyword.)
    alpha : float
        Rows are kept only when their "FDR q-val" is at or below this value.

    Returns
    -------
    alt.HConcatChart
        Two side-by-side bar charts, titled "sum" (bin width 1) and "mean"
        (bin width 0.05), each counting pathways per bin.
    """

    # Get the GMT name from the abbreviation
    set_name = get_set_name(set_abbr)

    # Every sub-folder of the gene set's result folder may contain report files.
    reports = []
    for path in glob.glob(f"{gsea_res_dir}/{set_name}/*"):
        reports.extend(glob.glob(path + "/gsea_report_for_na_*.tsv"))

    # Day and direction are parsed from the underscore-separated report path.
    repdf = pd.DataFrame({
        "report": reports,
        "day": [int(report.split("_")[3]) for report in reports],
        "posneg": [report.split("_")[-2] for report in reports],
    }).sort_values(by=["posneg", "day"])

    names = []
    ness = []
    days = []
    for report, day in zip(repdf["report"], repdf["day"]):
        df = pd.read_csv(report, sep="\t")
        df_sel = df[df["FDR q-val"] <= alpha]
        if df_sel.NES.dtype != float:
            # The column was not read as float, so it holds text entries: keep
            # only those containing a digit. na=False treats a missing NES as
            # "no match" instead of producing a NaN mask that cannot be used
            # for boolean indexing.
            row_mask = df_sel.NES.str.contains(r"[0-9]", na=False)
            df_sel = df_sel[row_mask]

        names.extend(df_sel.NAME.values)
        ness.extend(df_sel.NES.astype(float).values)
        days.extend([day] * df_sel.shape[0])

    nesdf = pd.DataFrame({
        "name": names,
        "nes": ness,
        "day": days,
    })

    if abs:
        nesdf = nesdf.assign(nes=nesdf.nes.abs())

    # One row per pathway name, aggregated over every report it appears in.
    nes_by_name = nesdf.groupby(["name"])["nes"]
    nes_sums = nes_by_name.sum().rename("sum")
    nes_means = nes_by_name.mean().rename("mean")

    df = pd.concat([nes_sums, nes_means], axis=1)

    chart = alt.hconcat(*[
        alt.Chart(df).mark_bar().encode(
            x=alt.X(
                col,
                # Sums span a wider range than means, hence the coarser bins.
                bin=alt.Bin(step=1 if col == "sum" else 0.05),
            ),
            y="count()",
        ).properties(
            title=col,
            width=700,
            height=300,
        )
        for col in df.columns])

    return chart

### Curated pathways:

In [10]:
# Signed NES (abs=False), rows with FDR q-val <= 0.05.
nes_sum_mean_dist("cp", abs=False, alpha=0.05)

In [11]:
# Absolute NES (abs=True), rows with FDR q-val <= 0.05.
nes_sum_mean_dist("cp", abs=True, alpha=0.05)

### GO:

In [12]:
# Signed NES (abs=False), rows with FDR q-val <= 0.05.
nes_sum_mean_dist("go", abs=False, alpha=0.05)

In [13]:
# Absolute NES (abs=True), rows with FDR q-val <= 0.05.
nes_sum_mean_dist("go", abs=True, alpha=0.05)

### Hallmarks:

In [14]:
# Signed NES (abs=False), rows with FDR q-val <= 0.05.
nes_sum_mean_dist("h", abs=False, alpha=0.05)

In [15]:
# Absolute NES (abs=True), rows with FDR q-val <= 0.05.
nes_sum_mean_dist("h", abs=True, alpha=0.05)

## Conclusion after looking at distributions of summary statistics (sum and mean) of NES across days for each pathway

Looks like there might be something interesting if we keep only the pathways whose sum across all days of the absolute value of NES is > 9, so let's try that. 

## Make filtered time series heatmaps of selected pathways

Not sure of the best way to filter. Use the `daily_sum_abs_nes_cutoff` parameter to keep only the pathways whose sum of the absolute value of NES across all days exceeds the cutoff; use `top_n_nes_sum_posneg` to just take the top n pathways from both the positive and negative ends of the ranking by the sum of signed (not absolute value) NES across all days. 

In the end, we just decided to take the top n from the top and bottom.

In [16]:
def make_heatmap(set_abbr, alpha, daily_sum_abs_nes_cutoff=False, top_n_nes_sum_posneg=False):
    """Build a searchable pathway-by-day heatmap of NES for one gene set.

    Parameters
    ----------
    set_abbr : str
        Gene-set abbreviation accepted by ``get_set_name``.
    alpha : float
        Rows are kept only when their "FDR q-val" is at or below this value.
    daily_sum_abs_nes_cutoff : number or False
        If truthy, keep only pathways whose sum over all days of ``|NES|`` is
        greater than this value.
    top_n_nes_sum_posneg : int or False
        If truthy, keep only the n pathways with the lowest and the n pathways
        with the highest sum of signed NES over all days. Applied after the
        cutoff filter when both are given.

    Returns
    -------
    alt.Chart
        Heatmap with day on x, pathway on y and NES as colour on a reversed
        "redblue" scale symmetric around zero. A search box dims the rows whose
        pathway name does not match the entered regex.

    Raises
    ------
    ValueError
        If the direction parsed from a report path is neither "pos" nor "neg".
    """

    # Get the GMT name from the abbreviation
    set_name = get_set_name(set_abbr) 
    
    paths = glob.glob(f"{gsea_res_dir}/{set_name}/*")
    reports = []
    for path in paths:
        reports.extend(glob.glob(path + "/gsea_report_for_na_*.tsv"))

    # One wide table per direction: a "pathway" column plus one column per day.
    pathways_pos = pd.DataFrame(columns=["pathway"])
    pathways_neg = pd.DataFrame(columns=["pathway"])

    for report in reports:
        # Day and direction are parsed from the underscore-separated report path.
        day = int(report.split("_")[3])
        posneg = report.split("_")[-2]

        df = pd.read_csv(report, sep="\t")
        df_sel = df[df["FDR q-val"] <= alpha]

        if df_sel.NES.dtype != float:
            # Same handling as nes_dist / nes_sum_mean_dist: when the column
            # holds text, drop entries without a digit (and missing ones) so
            # the float conversion below cannot fail on them.
            df_sel = df_sel[df_sel.NES.str.contains(r"[0-9]", na=False)]

        # Renamed from `bin`, which shadowed the builtin.
        day_nes = pd.DataFrame({
            "pathway": df_sel.NAME.values,
            day: df_sel.NES.astype(float).values,
        })

        if posneg == "pos":
            pathways_pos = pathways_pos.merge(
                day_nes,
                on="pathway",
                how="outer",
            )
        elif posneg == "neg":
            pathways_neg = pathways_neg.merge(
                day_nes,
                on="pathway",
                how="outer",
            )
        else:
            raise ValueError("posneg parsed incorrectly")

    pathways_pos = pathways_pos.set_index("pathway").sort_index(axis="index").sort_index(axis="columns")
    pathways_neg = pathways_neg.set_index("pathway").sort_index(axis="index").sort_index(axis="columns")

    # Transposed so that days are the index: the outer join then aligns the two
    # directions by day, and a pathway present in both gets _POS/_NEG suffixes.
    pathways = pathways_pos.\
    T.\
    join(
        pathways_neg.T,
        how="outer",
        lsuffix="_POS",
        rsuffix="_NEG",
    )

    # Collapse each suffixed pair back into a single column per pathway.
    # NOTE(review): a pathway whose own name ends in _POS or _NEG would also
    # match this pattern - confirm none exist in the gene sets used.
    dups = pathways.columns[pathways.columns.str.match(r".*_((NEG)|(POS))$")].to_series().str.rsplit("_", n=1, expand=True)
    if dups.size:    
        dups = dups[0].unique()
        for dup in dups:
            dup_pos = dup + "_POS"
            dup_neg = dup + "_NEG"
            # NOTE(review): this only trips when a positive NES exactly equals
            # the magnitude of the negative NES on the same day; it does not
            # detect every case of a pathway having both values on one day.
            assert not pathways.loc[:, dup_pos].eq(pathways.loc[:, dup_neg].abs()).any() # Make sure nothing was marked up and down at same time
            pathways = pathways.assign(**{dup: pathways.loc[:, dup_pos].fillna(0) + pathways.loc[:, dup_neg].fillna(0)})
            pathways = pathways.drop(columns=[dup_pos, dup_neg])

    # Back to pathways as rows; days without a value for a pathway become 0.
    pathways = pathways.\
    fillna(0).\
    T.\
    sort_index()

    # Row order for the y axis: descending sum of NES, ties broken by name.
    # Computed before filtering, so it may list pathways that are dropped below.
    sort = pathways.assign(sort=pathways.sum(axis="columns")).\
    sort_values(by=["sort", "pathway"], ascending=[False, True]).\
    index.\
    tolist()

    # Long format: one row per (pathway, day) pair.
    pathways = pathways.\
    reset_index(drop=False).\
    melt(
        id_vars="pathway",
        var_name="day",
        value_name="NES",
    )

    if daily_sum_abs_nes_cutoff:
        # Only keep pathways where the sum across all days of the absolute value of the NES is greater than our cutoff
        filter_df = pathways.assign(NES=pathways.NES.abs()).groupby("pathway")["NES"].sum().to_frame().reset_index(drop=False)
        keep_pathways = filter_df[filter_df.NES > daily_sum_abs_nes_cutoff].pathway
        pathways = pathways[pathways.pathway.isin(keep_pathways.values)]

    if top_n_nes_sum_posneg:
        # Keep the top n pathways with the greatest positive and negative NES sums
        filter_df = pathways.\
        groupby("pathway")["NES"].\
        sum().\
        to_frame().\
        reset_index(drop=False).\
        sort_values(by="NES")

        keep_pathways = pd.concat([filter_df.pathway.iloc[0:top_n_nes_sum_posneg], filter_df.pathway.iloc[-top_n_nes_sum_posneg:]], axis=0, ignore_index=True)
        pathways = pathways[pathways.pathway.isin(keep_pathways.values)]

    # Symmetric colour domain so that an NES of 0 sits at the scale's midpoint.
    nes_abs_max = pathways.NES.abs().max()

    search_input = alt.param(
        value="",
        bind=alt.binding(
            input="search",
            placeholder="pathway name (regex ok)",
            name="Search: ",
        )
    )

    heatmap = alt.Chart(pathways).mark_rect().encode(
        x="day:O",
        y=alt.Y(
            "pathway:O",
            sort=sort,
            axis=alt.Axis(
                title=None,
            ),
        ),
        color=alt.Color(
            "NES:Q",
            scale=alt.Scale(
                scheme="redblue",
                reverse=True,
                domain=[-nes_abs_max, nes_abs_max]
            )
        ),
        # Rows matching the case-insensitive search regex stay opaque; the rest are dimmed.
        opacity=alt.condition(
            alt.expr.test(alt.expr.regexp(search_input, "i"), alt.datum.pathway),
            alt.value(1),
            alt.value(0.2),
        ),
    ).add_params(
        search_input,
    ).configure_axis(
        labelLimit=400,
    )

    return heatmap

In [17]:
# For each gene set, save two heatmaps built from rows with FDR q-val <= 0.05:
# "filtered" keeps the 20 lowest and 20 highest pathways by summed NES,
# "unfiltered" keeps every pathway.
make_heatmap("cp", alpha=0.05, top_n_nes_sum_posneg=20).save(f"{output_dir}/heatmap_cp_filtered.html")
make_heatmap("cp", alpha=0.05).save(f"{output_dir}/heatmap_cp_unfiltered.html")

make_heatmap("go", alpha=0.05, top_n_nes_sum_posneg=20).save(f"{output_dir}/heatmap_go_filtered.html")
make_heatmap("go", alpha=0.05).save(f"{output_dir}/heatmap_go_unfiltered.html")

make_heatmap("h", alpha=0.05, top_n_nes_sum_posneg=20).save(f"{output_dir}/heatmap_h_filtered.html")
make_heatmap("h", alpha=0.05).save(f"{output_dir}/heatmap_h_unfiltered.html")