In [None]:
# !pip install statsmodels
# !pip install openpyxl
# !pip install forestplot
# !pip install --upgrade seaborn

In [None]:
import numpy as np
import pandas as pd
import json
from scipy.stats import pearsonr, ttest_ind
import re
import itertools as it
from functools import reduce
import statsmodels.api as sm
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 7, 'axes.linewidth': 1, 'xtick.major.width': 1, 'xtick.major.size': 5, 'ytick.major.width': 1, 'ytick.major.size': 5})
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
from matplotlib.backends.backend_pdf import PdfPages
import os


def save_pdf(save_file, fig):
    """Write a matplotlib figure to a single-page PDF.

    Parameters
    ----------
    save_file : str
        Destination path. Missing parent directories are created.
    fig : matplotlib.figure.Figure
        Figure to save (tight bounding box, 300 dpi).

    Returns
    -------
    None
    """
    out_dir = os.path.dirname(save_file)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create directories when there is one to create.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    # The context manager closes the PDF even if savefig raises.
    with PdfPages(save_file) as pdf:
        pdf.savefig(fig, bbox_inches='tight', dpi=300)
    return

In [None]:

def create_gene_burden_table_helper(burden_df, annotations, maf, lf_samples_df):
    """Build a per-gene carrier table for one annotation mask.

    Rows of ``burden_df`` are kept when their ``annotation`` is in
    ``annotations`` and their ``maf`` is at most ``maf``. The comma-separated
    ``samples`` strings of the kept rows are merged into one set per gene.
    ``lf_samples_df`` is appended unchanged below the per-gene rows.

    Returns a DataFrame with the columns ``gene`` and ``samples`` (plus any
    columns contributed by ``lf_samples_df``).
    """
    def _merge_sample_strings(sample_strings):
        # Join first, then split, so every comma-separated ID ends up in one set.
        return set(",".join(sample_strings).split(","))

    keep_row = burden_df.annotation.isin(annotations) & (burden_df.maf <= maf)
    per_gene = burden_df.loc[keep_row].groupby("gene").agg({"samples": _merge_sample_strings})
    per_gene = per_gene.reset_index()
    return pd.concat([per_gene, lf_samples_df])

def create_gene_burden_tables(burden_df, maf, lf_samples_df):
    """Build the per-gene carrier table for each of the three variant masks.

    Returns a dict keyed by mask name ("PTV", "PTV_Missense_strict",
    "PTV_Missense_lenient"). Each value is the table produced by
    ``create_gene_burden_table_helper`` for that mask's annotation terms.
    """
    # Each mask is a superset of the previous one's annotation terms.
    mask_annotations = {
        "PTV": ["lof"],
        "PTV_Missense_strict": ["lof", "missense_strict"],
        "PTV_Missense_lenient": ["lof", "missense_strict", "missense_lenient"],
    }
    return {
        mask: create_gene_burden_table_helper(burden_df, terms, maf, lf_samples_df)
        for mask, terms in mask_annotations.items()
    }


def get_samples_helper(combos, genotype_df, cohort_samples):
    """Return the cohort samples that carry every entry of ``combos``.

    Parameters
    ----------
    combos : iterable of str
        Values to look up in ``genotype_df.gene``.
    genotype_df : pandas.DataFrame
        Table with a ``gene`` column and a ``samples`` column holding an
        iterable of sample IDs per row.
    cohort_samples : iterable of str
        Sample IDs the result is restricted to.

    Returns
    -------
    set
        Samples present in every matching row and in ``cohort_samples``.
        An empty set is returned when ``combos`` is empty or when any entry
        is missing from ``genotype_df``. The result is always a ``set``, so
        callers can use set operations such as ``.union`` on it.
    """
    required = set(combos)
    # Compare as sets so a repeated entry in `combos` is not mistaken for a
    # missing gene.
    if not required or not required.issubset(set(genotype_df.gene.values)):
        return set()

    samples_per_gene = genotype_df.loc[genotype_df.gene.isin(required), "samples"]
    shared = set.intersection(*(set(samples) for samples in samples_per_gene))
    return set(cohort_samples).intersection(shared)


def get_samples(ser, gene_burden_dict, pop_samples):
    """Look up the carriers for one result row.

    Parameters
    ----------
    ser : pandas.Series
        Row with an ``ID`` of the form ``"<gene>.<mask>.0.001"`` where
        ``<mask>`` starts with ``"PTV"``. If the row has an ``lf`` entry,
        carriers must also appear in the burden-table row for that value.
    gene_burden_dict : dict
        Maps mask name to a table with ``gene`` and ``samples`` columns.
    pop_samples : set
        Sample IDs the result is restricted to.

    Returns
    -------
    tuple
        ``(gene, samples)``, where ``samples`` is whatever
        ``get_samples_helper`` returns for the gene (and ``lf``, if present).

    Raises
    ------
    ValueError
        If ``ser.ID`` does not match the expected layout.
    KeyError
        If the parsed mask is not a key of ``gene_burden_dict``.
    """
    # Raw string: "\." is an invalid escape sequence in a normal string literal.
    m = re.match(r"(.+)\.(PTV.*)\.0\.001", ser.ID)
    if m is None:
        # Fail with the offending ID rather than an AttributeError on None.
        raise ValueError(f"Cannot parse gene and mask from ID {ser.ID!r}")
    gene, mask = m.group(1), m.group(2)
    gene_samples_df = gene_burden_dict[mask]

    combos = [gene]
    if "lf" in ser.index:
        combos.append(ser.lf)

    samples = get_samples_helper(combos, gene_samples_df, pop_samples)
    return gene, samples

def get_bmi_pgs_info(ser, gene_burden_dict, pop_samples, pheno_df):
    """Collect the (bmi, bmi_prs) pairs of the carriers for one result row.

    Returns a Series with the row's ``ID`` and ``beta``, the parsed ``gene``,
    and ``bmi_pgs``: a list of ``(bmi, bmi_prs)`` tuples, one per row of
    ``pheno_df`` whose ``sample_names`` is among the carriers.
    """
    gene, sample_names = get_samples(ser, gene_burden_dict, pop_samples)
    is_carrier = pheno_df.sample_names.isin(sample_names)
    carrier_bmi = pheno_df.loc[is_carrier, "bmi"].values
    carrier_pgs = pheno_df.loc[is_carrier, "bmi_prs"].values
    record = {
        "ID": ser.ID,
        "gene": gene,
        "beta": ser.beta,
        "bmi_pgs": list(zip(carrier_bmi, carrier_pgs)),
    }
    return pd.Series(record)
    

In [None]:
# Result tables. Later cells read `ID` and `beta` from monogenic_meta_df and
# `ID`, `gene`, `beta` and `p_value` from monogenic_meta_pgs_df.
monogenic_meta_df = pd.read_excel("./monogenic_meta.xlsx")
monogenic_meta_pgs_df = pd.read_excel("./monogenic_pgs_int_meta.xlsx")
# Burden table (columns used: gene, annotation, maf, samples) and phenotype
# table (columns used: sample_names, bmi, bmi_prs). sample_names is read as str.
# NOTE(review): absolute paths tie the notebook to one environment.
gene_burden_df = pd.read_csv("/mnt/project/notebooks/regenie/data/gene_burden.csv.gz")
pheno_df = pd.read_csv("/mnt/project/notebooks/regenie/data/pheno.csv.gz", dtype={"sample_names": str})

# MAF cut-off 0.001 matches the ".0.001" ID suffix that get_samples parses.
# The empty DataFrame means no extra rows are appended to the per-gene tables.
gene_burden_dict = create_gene_burden_tables(gene_burden_df, 0.001, pd.DataFrame())
# The plotting functions defined later in this notebook rebuild this set
# themselves from pheno_df.
pop_samples = set(pheno_df.sample_names.astype(str))


In [None]:
def get_bmi_categories(bmi):
    """Map a BMI value to its weight category.

    under 18.5 - "underweight"
    18.5 up to (not including) 25 - "normal" (the 'healthy range')
    25 up to (not including) 30 - "overweight"
    30 up to (not including) 40 - "obese"
    40 or over - "severely obese"

    Parameters
    ----------
    bmi : float
        Body-mass index. Missing values (NaN/None) are allowed.

    Returns
    -------
    str or float
        The category name, or NaN when ``bmi`` is missing.
    """
    # Every comparison with NaN is False, so a missing BMI would otherwise
    # fall through to "severely obese".
    if pd.isna(bmi):
        return np.nan
    # Upper bounds are the next category's lower bound, so values such as
    # 24.9 or 24.95 stay "normal" instead of spilling into "overweight".
    if bmi < 18.5:
        return "underweight"
    if bmi < 25:
        return "normal"
    if bmi < 30:
        return "overweight"
    if bmi < 40:
        return "obese"
    return "severely obese"

In [None]:
# Add a `bmi_cat` column (in place) holding each sample's BMI category.
pheno_df["bmi_cat"] = pheno_df.bmi.apply(get_bmi_categories)

In [None]:
# Split bmi_prs into five quantile bins; labels=False stores the integer bin
# code (0 = lowest bin, 4 = highest) instead of interval labels.
pheno_df["prs_cat"] = pd.qcut(pheno_df.bmi_prs, 5, labels=False)

# Interaction with PGS distribution plots

In [None]:
# Attach each ID's `beta` from monogenic_meta_df as `bmi_beta`.
# IDs that are absent from that table get NaN.
monogenic_meta_pgs_df["bmi_beta"] = monogenic_meta_pgs_df.ID.map(monogenic_meta_df.set_index("ID").beta.to_dict())

In [None]:
def get_bmi_pgs_cat(ser):
    """Build the category label "<gene>_<interaction>_<significance>" for a row.

    * gene part: "risk" when ``ser.bmi_beta`` > 0, otherwise "protective"
    * interaction part: "risk" when ``ser.beta`` > 0, otherwise "protective"
    * significance part: "significant" when ``ser.p_value`` < 0.05, otherwise
      "nonsignificant"

    Zero and NaN values fail the comparisons and take the "otherwise" branch.
    """
    gene_type = "risk" if ser.bmi_beta > 0 else "protective"
    interaction_type = "risk" if ser.beta > 0 else "protective"
    significance = "significant" if ser.p_value < 5e-2 else "nonsignificant"
    return "_".join((gene_type, interaction_type, significance))


In [None]:
# Label every row with its "<gene>_<interaction>_<significance>" category.
monogenic_meta_pgs_df["category"] = monogenic_meta_pgs_df.apply(get_bmi_pgs_cat, axis=1)


In [None]:
# Keep one row per gene: the row with the smallest p_value.
monogenic_meta_pgs_min_pval_df = monogenic_meta_pgs_df.groupby('gene', group_keys=False).apply(lambda x: x.loc[x.p_value.idxmin()]).reset_index(drop=True)

In [None]:
# Keep the genes whose best row has p_value < 0.05 (the same cut-off that
# get_bmi_pgs_cat uses for "significant").
monogenic_meta_pgs_sig_df = monogenic_meta_pgs_min_pval_df.loc[monogenic_meta_pgs_min_pval_df.p_value<0.05]

In [None]:
# Show which categories occur among the significant genes, in order of first
# appearance after sorting by bmi_beta.
monogenic_meta_pgs_sig_df.sort_values(["bmi_beta"]).category.unique()

In [None]:
def create_distribution_plot_per_id(monogenic_meta_pgs_df, gene_ids, gene_burden_dict, pheno_df):
    """Draw a grid of BMI-by-PGS-quintile box plots, one panel per selected row.

    Each panel shows BMI per ``prs_cat`` bin for non-carriers (light grey) and
    carriers (blue). Above every bin it annotates a two-sided t-test of
    carrier versus non-carrier BMI: "P<0.01", "P=<value>" for p < 0.05, or
    "ns" otherwise.

    Parameters
    ----------
    monogenic_meta_pgs_df : pandas.DataFrame
        Result table; rows are selected by position with ``.iloc``.
    gene_ids : 2-D array-like of int
        Row positions. Its shape sets the panel grid (the notebook passes 2x2).
    gene_burden_dict : dict
        Maps mask name to carrier table, as consumed by ``get_samples``.
    pheno_df : pandas.DataFrame
        Needs the columns ``sample_names``, ``bmi`` and ``prs_cat``. It is not
        modified.

    Returns
    -------
    matplotlib.figure.Figure
    """
    gene_ids = np.asarray(gene_ids)
    n_rows, n_cols = gene_ids.shape
    # squeeze=False keeps `ax` two-dimensional even for a single row or column.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(3.5, 3), sharex=True, sharey=True, squeeze=False)

    # The cohort is the same for every panel, so build the set once.
    pop_samples = set(pheno_df.sample_names.astype(str))

    for iax in range(n_rows):
        for jax in range(n_cols):
            panel = ax[iax][jax]
            ser = monogenic_meta_pgs_df.iloc[gene_ids[iax, jax]]

            gene, samples = get_samples(ser, gene_burden_dict, pop_samples)
            # Work on a copy so the caller's frame does not gain a column.
            gene_pheno_df = pheno_df.copy()
            gene_pheno_df["gene_carrier"] = gene_pheno_df.sample_names.isin(samples)

            sns.boxplot(
                data=gene_pheno_df, x="prs_cat", y="bmi", hue="gene_carrier",
                hue_order=[False, True],
                palette=["lightgrey", "royalblue"],
                dodge=True, width=0.75, linewidth=0.5, fliersize=0, capprops={'color': 'none'},
                boxprops={'edgecolor': 'k'},
                whiskerprops={'color': 'k'}, medianprops={'color': 'k'}, ax=panel
            )

            # The bin code doubles as the x position of its box pair.
            for res_cat in range(5):
                psd = gene_pheno_df.loc[gene_pheno_df.prs_cat == res_cat]
                ttest_res = ttest_ind(
                    psd.loc[psd.gene_carrier == True, "bmi"].dropna(),
                    psd.loc[psd.gene_carrier == False, "bmi"].dropna(),
                    alternative="two-sided",
                )
                ttest_pval = ttest_res.pvalue
                if ttest_pval < 0.01:
                    pval_text = "P<0.01"
                elif ttest_pval < 0.05:
                    pval_text = f"P={round(ttest_pval, 2)}"
                else:
                    # A NaN p-value fails both comparisons and also lands here.
                    pval_text = "ns"

                # The "-[" arrow style draws a bracket spanning the box pair.
                panel.annotate(
                    pval_text, xy=(res_cat, 47.5), xytext=(res_cat, 49), ha="center", va="bottom", fontsize=5,
                    arrowprops=dict(arrowstyle='-[, widthB=0.55, lengthB=0.25', lw=0.5, color='k')
                )

            panel.set_ylim(10, 50)
            panel.set_title(gene, fontsize=7)
            panel.get_legend().remove()
            panel.spines[["right", "top"]].set_visible(False)
            panel.set_xlabel("PGS quintiles")

    fig.tight_layout()
    return fig

    

In [None]:
# Row positions in monogenic_meta_pgs_df (looked up with .iloc) for the four
# example panels, laid out as a 2x2 grid.
# NOTE(review): positional indices silently select different rows if the
# spreadsheet's row order changes; consider selecting by ID instead.
gene_ids = np.array([13,3,221,225]).reshape(2,2)

In [None]:
# Draw the example panels of BMI by PGS quintile for carriers and non-carriers.
f= create_distribution_plot_per_id(monogenic_meta_pgs_df, gene_ids, gene_burden_dict, pheno_df)

In [None]:
# Write the example-panel figure to a PDF in the working directory.
save_pdf("./pgs_int_example.pdf", f)

In [None]:
def create_distribution_plot_helper(df, gene_burden_dict, pop_samples):
    """Return the union of carrier samples over all rows of ``df``.

    Parameters
    ----------
    df : pandas.DataFrame
        Result rows; each row is passed to ``get_samples``.
    gene_burden_dict : dict
        Maps mask name to carrier table, as consumed by ``get_samples``.
    pop_samples : set
        Sample IDs the carriers are restricted to.

    Returns
    -------
    set
        Every sample that is a carrier for at least one row. An empty ``df``
        gives an empty set.
    """
    # Start from an explicit empty set so an empty frame, or a row whose
    # carrier collection is a list, does not break the union.
    all_category_samples = set()
    for _, row in df.iterrows():
        all_category_samples.update(get_samples(row, gene_burden_dict, pop_samples)[1])
    return all_category_samples


def create_factor_plot_helper(df, gene_burden_dict, pop_samples):
    """Same contract as ``create_distribution_plot_helper``.

    Kept under its own name for existing callers. It delegates so the two
    cannot drift apart.
    """
    return create_distribution_plot_helper(df, gene_burden_dict, pop_samples)

def create_factor_plot(monogenic_meta_pgs_sig_df, gene_burden_dict, pheno_df):
    """Plot mean BMI of carriers versus non-carriers, split by PGS group.

    One panel is drawn per interaction category. In each panel the x axis is
    carrier status (False/True) for the genes of that category, and the two
    lines are the "lower" (prs_cat 0-2) and "higher" (prs_cat 3-4) PGS groups,
    with 95% confidence intervals.

    Parameters
    ----------
    monogenic_meta_pgs_sig_df : pandas.DataFrame
        Result rows with a ``category`` column, consumed row-wise by
        ``get_samples``.
    gene_burden_dict : dict
        Maps mask name to carrier table, as consumed by ``get_samples``.
    pheno_df : pandas.DataFrame
        Needs the columns ``sample_names``, ``bmi`` and ``prs_cat``. It is not
        modified.

    Returns
    -------
    matplotlib.figure.Figure
    """
    categories = ['protective_protective_significant', 'protective_risk_significant', 'risk_protective_significant', 'risk_risk_significant']
    pop_samples = set(pheno_df.sample_names.astype(str))

    # Work on a copy; one boolean carrier column is added per category.
    cat_pheno_df = pheno_df.copy()
    for cat in categories:
        df = monogenic_meta_pgs_sig_df.loc[monogenic_meta_pgs_sig_df.category == cat]
        # A category may have no genes; treat that as "no carriers".
        cat_samples = create_distribution_plot_helper(df, gene_burden_dict, pop_samples) if len(df) else set()
        cat_pheno_df[cat] = cat_pheno_df.sample_names.isin(cat_samples)

    # Collapse the five PGS bins into two groups.
    cat_pheno_df["prs_cat"] = cat_pheno_df.prs_cat.map({0: "lower", 1: "lower", 2: "lower", 3: "higher", 4: "higher"})

    fig, ax = plt.subplots(1, len(categories), figsize=(3.5, 1), sharey=True)

    ms = 5
    lw = 1
    palette = ["navy", "darkorange"]
    # Same plot for every category; only the carrier column and axis differ.
    for panel, cat in zip(ax, categories):
        sns.pointplot(
            data=cat_pheno_df, x=cat, y="bmi",
            hue="prs_cat", hue_order=["lower", "higher"],
            palette=palette,
            linestyles=["-", "--"],
            errorbar=("ci", 95), dodge=False, linewidth=lw, markersize=ms,
            markers=["o", "d"], capsize=0.1, ax=panel
        )
        panel.spines[["right", "top"]].set_visible(False)
        panel.get_legend().remove()
        panel.set_xlabel("")

    # The y axis is shared, so setting the limits on one panel sets them all.
    ax[0].set_ylim(24, 32)

    return fig
    

In [None]:
# Draw the four-panel carrier-by-PGS-group plot for the significant genes.
# Note that this rebinds `f`, which held the example-panel figure.
f = create_factor_plot(monogenic_meta_pgs_sig_df, gene_burden_dict, pheno_df)

In [None]:
# Write the interaction-mode figure to a PDF in the working directory.
save_pdf("./gene_pgs_interaction_modes.pdf", f)