In [None]:
# pd.read_excel needs the openpyxl package to read .xlsx files; uncomment to install it once:
# !pip install openpyxl

In [None]:
import numpy as np
import pandas as pd
import json
from scipy.stats import pearsonr
import re
from functools import reduce
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
matplotlib.rcParams.update({'font.size': 7, 'axes.linewidth': 1, 'xtick.major.width': 1, 'xtick.major.size': 5, 'ytick.major.width': 1, 'ytick.major.size': 5})
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
from matplotlib.backends.backend_pdf import PdfPages
import os

In [None]:

def create_gene_burden_table_helper(burden_df, annotations, maf, lf_samples_df):
    """Collapse variant-level burden rows into one row per gene.

    Rows of ``burden_df`` are kept when their ``annotation`` is one of
    ``annotations`` and their ``maf`` is at most ``maf``. For every remaining
    gene the comma-separated ``samples`` strings are merged into a single set
    of sample IDs. The rows of ``lf_samples_df`` are appended unchanged.

    Returns a DataFrame with the columns ``gene`` and ``samples`` (plus any
    columns contributed by ``lf_samples_df``).
    """
    def _merge_sample_strings(sample_strings):
        # "a,b" + "b,c" -> {"a", "b", "c"}
        return set(",".join(sample_strings).split(","))

    keep_row = burden_df.annotation.isin(annotations) & (burden_df.maf <= maf)
    per_gene_df = (
        burden_df.loc[keep_row]
        .groupby("gene")
        .agg({"samples": _merge_sample_strings})
        .reset_index()
    )
    return pd.concat([per_gene_df, lf_samples_df])

def create_gene_burden_tables(burden_df, maf, lf_samples_df):
    """Build one per-gene burden table for each variant mask.

    Returns a dict keyed by mask name ("PTV", "PTV_Missense_strict",
    "PTV_Missense_lenient"); each value is the table produced by
    ``create_gene_burden_table_helper`` for that mask's annotation terms.
    """
    # Each mask is cumulative: it contains the annotation terms of the previous one.
    annotations_by_mask = {
        "PTV": ["lof"],
        "PTV_Missense_strict": ["lof", "missense_strict"],
        "PTV_Missense_lenient": ["lof", "missense_strict", "missense_lenient"],
    }
    return {
        mask: create_gene_burden_table_helper(burden_df, terms, maf, lf_samples_df)
        for mask, terms in annotations_by_mask.items()
    }


def get_samples_helper(combos, genotype_df, cohort_samples):
    """Return the cohort samples that appear under every entry of ``combos``.

    Parameters
    ----------
    combos : iterable of str
        Values looked up in ``genotype_df.gene``. Duplicates are ignored.
    genotype_df : pandas.DataFrame
        Needs a ``gene`` column and a ``samples`` column holding an iterable
        of sample IDs per row.
    cohort_samples : iterable of str
        Sample IDs the result is restricted to.

    Returns
    -------
    set
        Samples present in every matching row and in ``cohort_samples``.
        An empty set is returned when ``combos`` is empty or when any entry
        is missing from ``genotype_df`` (always a set, so callers get one
        return type on every path).
    """
    required = set(combos)
    # Compare as sets: a length comparison against `combos` wrongly fails when
    # the same entry is listed twice.
    if not required or not required.issubset(set(genotype_df.gene.values)):
        return set()

    samples_per_row = genotype_df.loc[genotype_df.gene.isin(combos)].samples.values
    # NOTE(review): rows are intersected one by one, so an entry that occupies
    # several rows is intersected with itself rather than unioned first -
    # confirm every entry has exactly one row.
    shared = reduce(lambda a, b: set(a).intersection(b), samples_per_row)
    return set(cohort_samples).intersection(shared)


def get_samples(ser, gene_burden_dict, pop_samples):
    """Resolve the gene and carrier samples for one association-result row.

    ``ser.ID`` must look like ``<gene>.<mask>.0.001`` where ``<mask>`` starts
    with "PTV"; the mask selects the table in ``gene_burden_dict``. When the
    row has an ``lf`` entry, its value is added as a second required entry so
    that only samples listed under both are returned.

    Returns
    -------
    tuple
        ``(gene, samples)`` where ``samples`` is whatever
        ``get_samples_helper`` returns for the requested entries.

    Raises
    ------
    ValueError
        If ``ser.ID`` does not match the expected pattern.
    KeyError
        If the parsed mask is not a key of ``gene_burden_dict``.
    """
    # Raw string: "\." in a normal string literal is an invalid escape sequence.
    pattern = re.compile(r"(.+)\.(PTV.*)\.0\.001")
    m = pattern.match(ser.ID)
    if m is None:
        # Fail with the offending ID instead of an AttributeError on None.
        raise ValueError(
            f"Cannot parse gene and mask from ID {ser.ID!r}; "
            "expected '<gene>.<PTV mask>.0.001'"
        )
    gene, mask = m.group(1), m.group(2)
    gene_samples_df = gene_burden_dict[mask]

    combos = [gene]
    if "lf" in ser.index:
        combos.append(ser.lf)

    samples = get_samples_helper(combos, gene_samples_df, pop_samples)
    return gene, samples

def get_bmi_pgs_info(ser, gene_burden_dict, pop_samples, pheno_df):
    """Collect the (bmi, bmi_prs) pairs of the carriers behind one result row.

    Carriers come from ``get_samples``; their phenotypes are read from the
    ``bmi`` and ``bmi_prs`` columns of ``pheno_df`` by matching
    ``pheno_df.sample_names``.

    Returns a Series with ``ID``, ``gene``, ``beta`` and ``bmi_pgs`` (a list
    of ``(bmi, bmi_prs)`` tuples, one per matching phenotype row).
    """
    gene, sample_names = get_samples(ser, gene_burden_dict, pop_samples)

    # Select the carriers once and read both phenotype columns from that slice.
    carrier_rows = pheno_df.loc[pheno_df.sample_names.isin(sample_names), ["bmi", "bmi_prs"]]
    bmi_pgs = list(zip(carrier_rows["bmi"].values, carrier_rows["bmi_prs"].values))

    record = {"ID": ser.ID, "gene": gene, "beta": ser.beta, "bmi_pgs": bmi_pgs}
    return pd.Series(record)
    

In [None]:
# Input tables (the column lists reflect what later cells read from each frame):
#   monogenic_meta_df - each row's `ID` and `beta` are used.
#   gene_burden_df    - `gene`, `annotation`, `maf` and `samples` (comma-separated sample IDs).
#   pheno_df          - `sample_names`, `bmi`, `bmi_prs`; sample IDs are read as str so they
#                       compare equal to the IDs split out of `gene_burden_df.samples`.
# NOTE(review): the two CSV paths are absolute and machine-specific - consider a configurable data dir.
monogenic_meta_df = pd.read_excel("./monogenic_meta.xlsx")
gene_burden_df = pd.read_csv("/mnt/project/notebooks/regenie/data/gene_burden.csv.gz")
pheno_df = pd.read_csv("/mnt/project/notebooks/regenie/data/pheno.csv.gz", dtype={"sample_names": str})

In [None]:
# Per-mask gene -> carrier tables at MAF <= 0.001; the empty frame means no
# extra `lf` sample rows are appended to the tables.
gene_burden_dict = create_gene_burden_tables(
    burden_df=gene_burden_df,
    maf=0.001,
    lf_samples_df=pd.DataFrame(),
)

# Every sample ID that has a phenotype row, as a set of strings.
pop_samples = set(pheno_df["sample_names"].astype(str))


In [None]:
# One row per result row -> one row per carrier: explode the list of
# (bmi, pgs) tuples, then drop carriers repeated under the same gene and beta.
# NOTE(review): de-duplication is by (bmi, pgs) value, not by sample ID, so two
# distinct carriers with identical values collapse into one - confirm this is acceptable.
# NOTE(review): a row with no carriers explodes to a missing `bmi_pgs` value - verify
# the tuple split below handles that case as intended.
bmi_pgs_df = (
    monogenic_meta_df
    .apply(get_bmi_pgs_info, axis=1, args=(gene_burden_dict, pop_samples, pheno_df))
    .explode("bmi_pgs")
    .reset_index(drop=True)
    .drop_duplicates(["gene", "beta", "bmi_pgs"])
)

# Split each (bmi, pgs) tuple into two columns, aligned on the current index.
bmi_pgs_pairs = bmi_pgs_df["bmi_pgs"].tolist()
bmi_pgs_df[["bmi", "pgs"]] = pd.DataFrame(bmi_pgs_pairs, index=bmi_pgs_df.index)

In [None]:
# Keep only carriers of results with |beta| > 1.5 (between() is inclusive, so
# betas of exactly +/-1.5 are excluded). The explicit copy makes `plot_df` an
# independent frame, so adding columns to it later neither triggers
# SettingWithCopyWarning nor depends on view/copy semantics of the slice.
plot_df = bmi_pgs_df.loc[~bmi_pgs_df.beta.between(-1.5, 1.5)].copy()

In [None]:
# x-axis order for the plots: genes with beta > 1.5 from largest beta down,
# then genes with beta < -1.5 from most negative up.
# The gene is extracted with the same "<gene>.<PTV mask>.0.001" rule that
# get_samples uses, so the names here match the `gene` column even when a
# gene symbol itself contains a "." (splitting on the first "." would cut it).
gene_from_id = r"^(.+)\.PTV.*\.0\.001"

risk_gene_order = (
    monogenic_meta_df.loc[monogenic_meta_df.beta > 1.5]
    .sort_values("beta", ascending=False)
    .ID.str.extract(gene_from_id, expand=False)
    .unique()
)

protective_gene_order = (
    monogenic_meta_df.loc[monogenic_meta_df.beta < -1.5]
    .sort_values("beta", ascending=True)
    .ID.str.extract(gene_from_id, expand=False)
    .unique()
)

In [None]:
# Box colour by effect direction: negative beta -> "lightgreen", otherwise "indianred".
# assign() builds a new frame instead of writing into a slice of `bmi_pgs_df`,
# so the cell is idempotent and free of SettingWithCopyWarning.
plot_df = plot_df.assign(color=np.where(plot_df.beta < 0, "lightgreen", "indianred"))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3.5, 2))

# Shade four horizontal BMI bands behind the boxes
# (presumably the usual BMI categories - TODO confirm the cut-offs).
cs = ["skyblue", "gold", "darkorange", "red"]
band_edges = [12, 25, 30, 40, 60]
for lower, upper, shade in zip(band_edges[:-1], band_edges[1:], cs):
    ax.axhspan(lower, upper, facecolor=shade, alpha=0.2)

# One box per gene: risk genes first, then protective genes, coloured by the
# `color` column (when a gene has several colours the last row wins in the dict).
gene_palette = dict(zip(plot_df.gene, plot_df.color))
gene_order = np.concatenate((risk_gene_order, protective_gene_order))
sns.boxplot(
    data=plot_df,
    x="gene",
    y="bmi",
    order=gene_order,
    palette=gene_palette,
    fliersize=1,
    linewidth=0.75,
    ax=ax,
)

# Axis cosmetics: fixed BMI range, vertical italic gene labels, no x-axis title.
ax.set_ylim(12, 60)
ax.set_xticklabels(ax.get_xticklabels(), rotation=90, fontstyle="italic", fontsize=7);
ax.set_ylabel("BMI")
ax.set_xlabel("")
for side in ("right", "top"):
    ax.spines[side].set_visible(False)

In [None]:
# Re-fit the subplot padding so the axes decorations fit inside the figure area.
fig.tight_layout()

In [None]:
def save_pdf(save_file, fig):
    """Write ``fig`` to ``save_file`` as a single-page PDF.

    Parameters
    ----------
    save_file : str
        Destination path. Missing parent directories are created; a bare
        file name (no directory part) is written to the current directory.
    fig : matplotlib.figure.Figure
        Figure to export; saved with a tight bounding box at 300 dpi.

    Returns
    -------
    None
    """
    out_dir = os.path.dirname(save_file)
    # dirname() is "" for a bare file name and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The context manager closes the PDF even if savefig raises.
    with PdfPages(save_file) as pdf:
        pdf.savefig(fig, bbox_inches='tight', dpi=300)
    return

In [None]:
# Export the current `fig` as PDF.
# NOTE(review): `fig` is a shared name that a later plotting cell rebinds, so this
# saves the BMI figure only when the cells are run top to bottom.
save_pdf("./bmi_dist_ukb_v2.pdf", fig)

In [None]:
# Also export `fig` as PNG, using savefig defaults (no dpi / bbox_inches override,
# unlike the PDF export through save_pdf).
# NOTE(review): the file name lacks the "_v2" suffix used for the PDF - confirm this is intended.
fig.savefig("./bmi_dist_ukb.png")

In [None]:
# Per-gene distribution of the `pgs` column for the same carriers, in a wide
# figure so every gene label fits. Rebinds `fig` and `ax`.
fig, ax = plt.subplots(1, 1, figsize=(20, 4))

gene_palette = dict(zip(plot_df.gene, plot_df.color))
sns.boxplot(
    data=plot_df,
    x="gene",
    y="pgs",
    palette=gene_palette,
    ax=ax,
)

ax.set_xticklabels(ax.get_xticklabels(), rotation=90);
