In [None]:
import re
import numpy as np
import pandas as pd
from scipy import stats
from functools import reduce

In [None]:
# Load the three inputs used below:
#   - monogenic_meta_df: one row per result; get_samples reads its `ID` column
#     (and an optional `lf` column).
#   - gene_burden_df: per-variant burden rows; the helpers read its `gene`,
#     `annotation`, `maf` and comma-separated `samples` columns.
#   - pheno_df: per-participant phenotypes; `sample_names` is read as str so the
#     IDs compare equal to the strings split out of the burden `samples` column.
# NOTE(review): mixes a relative and absolute paths — consider a single DATA_DIR.
monogenic_meta_df = pd.read_excel("./monogenic_meta.xlsx")
gene_burden_df = pd.read_csv("/mnt/project/notebooks/regenie/data/gene_burden.csv.gz")
pheno_df = pd.read_csv("/mnt/project/notebooks/regenie/data/pheno.csv.gz", dtype={"sample_names": str})


In [None]:
def create_gene_burden_table_helper(burden_df, annotations, maf, lf_samples_df):
    """Collapse per-variant burden rows into one carrier set per gene.

    Rows are kept when their ``annotation`` is one of ``annotations`` and their
    ``maf`` is at most ``maf``.  For each gene, the comma-separated ``samples``
    strings of all kept rows are unioned into a single Python ``set``.

    Parameters
    ----------
    burden_df : pandas.DataFrame
        Needs ``gene``, ``annotation``, ``maf`` and ``samples`` columns.
    annotations : list of str
        Annotation labels that qualify for this mask.
    maf : float
        Inclusive upper bound on ``maf``.
    lf_samples_df : pandas.DataFrame
        Extra rows appended unchanged after the per-gene table.

    Returns
    -------
    pandas.DataFrame
        ``gene`` / ``samples`` (set) rows followed by ``lf_samples_df``.
    """
    def _union_of_sample_lists(sample_strings):
        # Join every comma-separated list for this gene and split once, so each
        # sample ID becomes one element of the resulting set.
        return set(",".join(sample_strings).split(","))

    keep = burden_df.annotation.isin(annotations) & (burden_df.maf <= maf)
    per_gene = (
        burden_df.loc[keep]
        .groupby("gene")
        .agg({"samples": _union_of_sample_lists})
        .reset_index()
    )
    return pd.concat([per_gene, lf_samples_df])

def create_gene_burden_tables(burden_df, maf, lf_samples_df):
    """Build one gene -> carrier-set table per nested variant mask.

    The masks are cumulative: PTV uses ``lof`` only, PTV_Missense_strict adds
    ``missense_strict``, and PTV_Missense_lenient further adds
    ``missense_lenient``.  Each table comes from
    ``create_gene_burden_table_helper`` with the same ``maf`` threshold and
    ``lf_samples_df``.

    Returns
    -------
    dict[str, pandas.DataFrame]
        Keyed by mask name, in the order listed below.
    """
    mask_annotations = {
        "PTV": ["lof"],
        "PTV_Missense_strict": ["lof", "missense_strict"],
        "PTV_Missense_lenient": ["lof", "missense_strict", "missense_lenient"],
    }
    return {
        mask_name: create_gene_burden_table_helper(burden_df, annots, maf, lf_samples_df)
        for mask_name, annots in mask_annotations.items()
    }


def get_samples_helper(combos, genotype_df, cohort_samples):
    """Return cohort samples carrying a qualifying variant in *every* gene of ``combos``.

    Parameters
    ----------
    combos : list of str
        Gene (or pseudo-gene) names that must all be carried.
    genotype_df : pandas.DataFrame
        Table with a ``gene`` column and a ``samples`` column whose entries are
        collections of sample IDs.
    cohort_samples : set
        Samples to restrict the result to.

    Returns
    -------
    set
        Samples that appear for every gene in ``combos`` and in
        ``cohort_samples``.  Returns an empty ``set`` (never a list) when
        ``combos`` is empty or any of its genes is absent from ``genotype_df``.
        Callers use ``.intersection``/``.difference`` on the result, so it must
        always be a set.
    """
    wanted = set(combos)
    # Compare as sets so a repeated name (e.g. gene == lf) does not make the
    # "all genes present" check fail.
    if not wanted or not wanted.issubset(genotype_df["gene"].values):
        return set()
    per_gene = genotype_df.loc[genotype_df["gene"].isin(wanted), "samples"]
    carriers = set.intersection(*(set(s) for s in per_gene))
    return cohort_samples.intersection(carriers)


# Result IDs look like "<gene>.<PTV mask>.0.001"; compiled once instead of per row.
# The ".0.001" suffix must match the MAF used to build gene_burden_dict.
_MASK_ID_PATTERN = re.compile(r"(.+)\.(PTV.*)\.0\.001")


def get_samples(ser, gene_burden_dict, pop_samples):
    """Return the cohort carriers for the gene (and optional ``lf``) of one result row.

    Parameters
    ----------
    ser : pandas.Series
        A row with an ``ID`` of the form ``<gene>.<mask>.0.001`` and, optionally,
        an ``lf`` entry naming a second (pseudo-)gene that must also be carried.
    gene_burden_dict : dict[str, pandas.DataFrame]
        Mask name -> gene/samples table (see ``create_gene_burden_tables``).
    pop_samples : set
        Cohort samples to restrict carriers to.

    Returns
    -------
    set
        Output of ``get_samples_helper`` for ``[gene]`` or ``[gene, lf]``.

    Raises
    ------
    ValueError
        If ``ser.ID`` does not match the expected ID format.
    KeyError
        If the parsed mask is not a key of ``gene_burden_dict``.
    """
    match = _MASK_ID_PATTERN.match(ser.ID)
    if match is None:
        raise ValueError(
            f"Unrecognised result ID {ser.ID!r}; expected '<gene>.<PTV mask>.0.001'"
        )
    gene, mask = match.group(1), match.group(2)
    gene_samples_df = gene_burden_dict[mask]

    combos = [gene]
    # A missing lf value can never match a gene name, so treat it as
    # "no partner gene" rather than forcing an empty carrier set.
    # NOTE(review): confirm rows with blank lf are meant to be gene-only tests.
    if "lf" in ser.index and pd.notna(ser["lf"]):
        combos.append(ser["lf"])

    return get_samples_helper(combos, gene_samples_df, pop_samples)

def get_bmi_categories(bmi):
    """Map a single BMI value to a weight-status label.

    Bands are half-open so every value falls into exactly one:
        bmi < 18.5          -> "underweight"
        18.5 <= bmi < 25    -> "normal"          (the 'healthy range', 18.5-24.9)
        25 <= bmi < 30      -> "overweight"      (25-29.9)
        30 <= bmi < 40      -> "obese"           (30-39.9)
        bmi >= 40           -> "severely obese"  (40 or over)

    Missing values (NaN/None) return None.  Every ``<`` comparison with NaN is
    False, so without the explicit check a missing BMI would fall through to
    "severely obese".
    """
    if pd.isna(bmi):
        return None
    if bmi < 18.5:
        return "underweight"
    if bmi < 25:
        return "normal"
    if bmi < 30:
        return "overweight"
    if bmi < 40:
        return "obese"
    return "severely obese"


def get_table_cat(gene_samples, nongene_samples, cat_samples, field):
    """Build a 2x2 count table of group membership versus a sample category.

    Rows are "Combo" (``gene_samples``) and "Non Combo" (``nongene_samples``).
    Columns are ``field`` (count inside ``cat_samples``) and ``"No <field>"``
    (count outside it).

    Returns
    -------
    pandas.DataFrame
        Integer counts with the row/column labels above.
    """
    counts = [
        [len(group.intersection(cat_samples)), len(group.difference(cat_samples))]
        for group in (gene_samples, nongene_samples)
    ]
    return pd.DataFrame(
        counts,
        index=["Combo", "Non Combo"],
        columns=[f"{field}", f"No {field}"],
    )


def get_bmi_cat_data_helper(gene_samples, pop_samples, pheno_df, category, category_values):
    """Count carriers and non-carriers inside and outside one BMI grouping.

    Parameters
    ----------
    gene_samples : iterable
        Carrier sample IDs (a subset of ``pop_samples``).
    pop_samples : set
        All cohort sample IDs; non-carriers are ``pop_samples - gene_samples``.
    pheno_df : pandas.DataFrame
        Needs ``sample_names`` and ``bmi_category`` columns.
    category : str
        Short name of the grouping, used to build the output keys.
    category_values : list of str
        ``bmi_category`` labels that count as "in" the grouping.

    Returns
    -------
    dict
        Keys ``gene_<category>``, ``gene_non<category>``,
        ``nongene_<category>`` and ``nongene_non<category>`` mapped to counts.
    """
    # Accept any iterable of carriers (e.g. an empty list) without breaking the
    # set operations in get_table_cat.
    gene_samples = set(gene_samples)
    has_category = pheno_df["bmi_category"].notna()
    # Only samples with an assigned BMI category can be classed as in/out of a
    # grouping. Samples with missing BMI are left out of both rows rather than
    # being counted as "non<category>".
    observed_samples = set(pheno_df.loc[has_category, "sample_names"])
    cat_samples = set(
        pheno_df.loc[pheno_df["bmi_category"].isin(category_values), "sample_names"]
    )
    gene_observed = gene_samples.intersection(observed_samples)
    nongene_observed = pop_samples.difference(gene_samples).intersection(observed_samples)

    table = get_table_cat(gene_observed, nongene_observed, cat_samples, category)
    return {
        f"gene_{category}": table.iloc[0, 0],
        f"gene_non{category}": table.iloc[0, 1],
        f"nongene_{category}": table.iloc[1, 0],
        f"nongene_non{category}": table.iloc[1, 1],
    }


def get_bmi_cat_data(ser, gene_burden_dict, pop_samples, category_dict, phenotypes=None):
    """Compute carrier-by-BMI-grouping counts for one result row.

    Parameters
    ----------
    ser : pandas.Series
        Result row; passed to ``get_samples`` and must carry an ``ID``.
    gene_burden_dict : dict[str, pandas.DataFrame]
        Mask name -> gene/samples table.
    pop_samples : set
        Cohort sample IDs.
    category_dict : dict[str, list]
        Grouping name -> spec; only ``spec[0]`` (the "in group" labels) is used.
    phenotypes : pandas.DataFrame, optional
        Phenotype table with ``sample_names`` and ``bmi_category``. Defaults to
        the notebook-level ``pheno_df`` so existing callers keep working; pass
        it explicitly to avoid depending on hidden global state.

    Returns
    -------
    pandas.Series
        Four counts per grouping (see ``get_bmi_cat_data_helper``) plus ``ID``.
    """
    if phenotypes is None:
        phenotypes = pheno_df
    gene_samples = get_samples(ser, gene_burden_dict, pop_samples)

    row = {}
    for cat_name, cat_spec in category_dict.items():
        in_group_labels = cat_spec[0]
        row.update(
            get_bmi_cat_data_helper(gene_samples, pop_samples, phenotypes, cat_name, in_group_labels)
        )
    row["ID"] = ser.ID
    return pd.Series(row)

In [None]:
# Carrier tables for each mask at MAF <= 0.001 (no extra pseudo-gene rows).
# get_samples only recognises IDs ending in ".0.001", so keep the two in sync.
gene_burden_dict = create_gene_burden_tables(
    burden_df=gene_burden_df,
    maf=0.001,
    lf_samples_df=pd.DataFrame(),
)
# The cohort: every phenotyped sample ID, as strings.
pop_samples = set(pheno_df["sample_names"].astype(str))


In [None]:
# Label each participant with a BMI band; (re)writes the column on pheno_df.
pheno_df["bmi_category"] = pheno_df["bmi"].apply(get_bmi_categories)

In [None]:
# BMI groupings to test against carrier status.
# Each value is [labels counted as "in" the grouping, the remaining labels];
# get_bmi_cat_data only reads the first list.
categories = {
    "nu": [["normal", "underweight"], ["overweight", "obese", "severely obese"]],  # normal or underweight
    "ob": [["obese", "severely obese"], ["overweight", "normal", "underweight"]],  # obese or severely obese
    "sob": [["severely obese"], ["obese", "overweight", "normal", "underweight"]],  # severely obese only
    "ovw": [["overweight"], ["severely obese", "obese", "normal", "underweight"]]  # overweight only
}

In [None]:
# One output row per result ID: 2x2 carrier-by-grouping counts for every
# grouping in `categories`, plus the ID itself.
monogenic_ukb_obesity_cat_df = monogenic_meta_df.apply(
    lambda row: get_bmi_cat_data(row, gene_burden_dict, pop_samples, categories),
    axis=1,
)

In [None]:
# Persist the per-ID counts indexed by result ID; pandas infers gzip
# compression from the ".gz" suffix.
monogenic_ukb_obesity_cat_df.set_index("ID").to_csv("./monogenic_ukb_bmi_cat.csv.gz")