# Fine-mapping with SuSiE RSS model

This notebook takes a list of LD files and a list of summary-statistics files, then runs SLALOM QC and SuSiE RSS fine-mapping for each LD block that has matching summary statistics.

In [None]:
[global]

# Container image handed to the `python` action of the analysis step
parameter: container = ""
# For cluster jobs, number commands to run per job
parameter: job_size = 1
# Wall clock time expected
parameter: walltime = "5h"
# Memory expected
parameter: mem = "16G"
# Number of threads
parameter: numThreads = 20

# Getting the overlapped input: two tab-separated tables, one listing the LD
# files and one listing the summary-statistics files
parameter: LD_list = path
parameter: sumstat_list = path
import pandas as pd
# `sep` is passed by keyword: since pandas 2.0 every read_csv argument after
# the file path is keyword-only, so the positional form raises a TypeError.
LD_list = pd.read_csv(LD_list, sep="\t")
sumstat_list = pd.read_csv(sumstat_list, sep="\t")
# Chromosome label without the "chr" prefix; for the LD table it is the first
# "_"-separated field of the "#id" column
LD_list["#chr"] = [x[0].replace("chr", "") for x in  LD_list["#id"].str.split("_") ]
sumstat_list["#chr"] = [str(x).replace("chr", "") for x in  sumstat_list["#chr"] ]
# Inner merge on every column the two tables share, keeping only entries
# present in both lists
input_inv = LD_list.merge(sumstat_list)
# NOTE(review): columns are picked by position; the step below treats the first
# as the LD file and the second as the summary-statistics file -- confirm that
# positions 1 and 3 of the merged table are those two path columns.
input_list = input_inv.iloc[:,[1,3]].values.tolist()
# How slalom() picks the lead variant: "pvalue" selects the smallest p-value,
# any other value selects the largest ABF posterior probability
parameter: lead_idx_choice = "pvalue"
# Prior variance W forwarded to abf()
# NOTE(review): abf() itself defaults to W=0.04 -- confirm 0.4 is intended
parameter: abf_prior_variance = 0.4
# A variant is flagged as an outlier when its r2 with the lead variant exceeds
# r2_threshold AND its -log10 p of the DENTIST-S statistic exceeds this value
parameter: nlog10p_dentist_s_threshold = 4
parameter: r2_threshold = 0.6
# Sample size forwarded to susie_rss; only passed on when greater than 2
parameter: n = 0


[SuSiE_RSS_1]
# Run SLALOM QC followed by SuSiE RSS for every (LD file, summary-statistics file)
# pair: input_list is consumed two files at a time, and the script reads
# _input[0] as the LD matrix and _input[1] as the summary statistics.
# Two outputs are produced per pair: the SuSiE fit (.rds) and the QC-annotated
# summary statistics (.ss_qced.tsv). The embedded Python script is expanded
# with ${ } before it runs; its stdout/stderr are redirected to files.
# NOTE(review): `cwd` is used in the output paths but is not defined in the
# visible [global] section -- confirm it is declared elsewhere.
input: input_list, group_by = 2
output: f'{cwd:a}/{_input[1]:bn}.{_input[0].split(".")[-3]}.unisusie_rss.fit.rds',f'{cwd:a}/{_input[1]:bn}.{_input[0].split(".")[-3]}.unisusie_rss.ss_qced.tsv'
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output[0]:bn}'
python: expand = '${ }', stdout = f"{_output:nn}.stdout", stderr = f"{_output:nn}.stderr", container = container
    import pandas as pd
    import numpy as np
    import rpy2
    from backports import zoneinfo
    import rpy2.robjects.numpy2ri as numpy2ri
    from rpy2.robjects import pandas2ri
    from rpy2.robjects.packages import SignatureTranslatedAnonymousPackage
    import rpy2.robjects as ro
    numpy2ri.activate()
    pandas2ri.activate()
    def load_npz_ld(path):
        """Load an LD matrix and its variant IDs from an ``.npz`` archive.

        The archive is read by position: ``arr_0`` is used as the matrix and
        ``arr_1`` as the variant IDs. The stored matrix ``M`` is mirrored as
        ``M + M.T`` with the diagonal halved, which yields a full symmetric
        matrix when only one triangle (plus the diagonal) is stored.

        Args:
            path: Location of the ``.npz`` file.

        Returns:
            Tuple ``(ld, snp_id)``: the mirrored matrix and a list of variant
            IDs in which every ":" has been replaced by "_".

        Raises:
            KeyError: If ``arr_0`` or ``arr_1`` is missing from the archive.
        """
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # archive; only load LD files from trusted sources.
        # The context manager closes the underlying zip handle once both
        # arrays have been read into memory.
        with np.load(path, allow_pickle=True) as npz:
            snp_id = [x.replace(":", "_") for x in npz["arr_1"]]
            stored = npz["arr_0"]
        ld = stored + stored.T
        # The diagonal was added to itself above, so halve it back
        np.fill_diagonal(ld, np.diag(ld) / 2)
        return ld, snp_id

    def get_bcor_meta(bcor_obj):
        """Return the metadata table of a bcor object with short column names.

        The table obtained from ``bcor_obj.getMeta()`` is renamed in place
        (rsid -> SNP, position -> BP, chromosome -> CHR, allele1 -> A1,
        allele2 -> A2) and its ``BP`` column is cast to 64-bit integers.

        Args:
            bcor_obj: Object exposing ``getMeta()``, which returns a DataFrame.

        Returns:
            The same DataFrame object that ``getMeta()`` returned, modified.

        Raises:
            KeyError: If any of the five source columns is absent.
        """
        column_map = {
            'rsid': 'SNP',
            'position': 'BP',
            'chromosome': 'CHR',
            'allele1': 'A1',
            'allele2': 'A2',
        }
        meta = bcor_obj.getMeta()
        # errors='raise' makes a missing source column fail instead of being skipped
        meta.rename(columns=column_map, errors='raise', inplace=True)
        # CHR is intentionally left untouched; only BP is converted
        meta['BP'] = meta['BP'].astype(np.int64)
        return meta

    def load_ld_bcor(ld_prefix):
        """Load an LD matrix and its variant metadata from ``<ld_prefix>.bcor``.

        Args:
            ld_prefix: File path without the ``.bcor`` extension.

        Returns:
            Tuple ``(ld_arr, df_ld_snps)``: the matrix returned by
            ``readCorr([])`` and the metadata table from ``get_bcor_meta``
            (a DataFrame, not a plain list of IDs).

        Raises:
            IOError: If the ``.bcor`` file does not exist.
            ValueError: If the matrix contains NaN entries.
        """
        import os
        bcor_file = ld_prefix+'.bcor'
        # Check the path before importing ldstore so that a missing file is
        # always reported as such
        if not os.path.exists(bcor_file):
            raise IOError('%s not found'%(bcor_file))
        from ldstore.bcor import bcor
        bcor_obj = bcor(bcor_file)
        df_ld_snps = get_bcor_meta(bcor_obj)
        ld_arr = bcor_obj.readCorr([])
        # Explicit check rather than `assert`, which is stripped under `python -O`
        if np.any(np.isnan(ld_arr)):
            raise ValueError('%s contains missing (NaN) LD values'%(bcor_file))
        return ld_arr, df_ld_snps

    def abf(beta, se, W=0.04):
        """Compute per-variant log approximate Bayes factors and posteriors.

        With ``z = beta / se``, ``V = se**2`` and ``r = W / (W + V)``, the log
        Bayes factor is ``0.5 * (log(1 - r) + r * z**2)``. Posterior
        probabilities are the log Bayes factors normalised with a
        log-sum-exp, so they sum to one across the inputs.

        Args:
            beta: Effect-size estimates (array-like supporting arithmetic).
            se: Standard errors, same length as ``beta``.
            W: Prior variance of the effect size.

        Returns:
            Tuple ``(lbf, prob)``.
        """
        from scipy import special
        variance = se ** 2
        shrinkage = W / (W + variance)
        z_score = beta / se
        log_bf = 0.5 * (np.log(1 - shrinkage) + (shrinkage * z_score ** 2))
        # Normalise in log space to avoid overflow for large Bayes factors
        log_total = special.logsumexp(log_bf)
        posterior = np.exp(log_bf - log_total)
        return log_bf, posterior
    
    def get_cs(variant, prob, coverage=0.95):
        """Return the smallest top-ranked set of variants exceeding ``coverage``.

        Variants are ranked by decreasing ``prob``; the set ends at the first
        rank where the cumulative probability is strictly greater than
        ``coverage``.

        Args:
            variant: Variant identifiers, indexable by an integer array.
            prob: Probabilities aligned with ``variant``.
            coverage: Cumulative probability the set has to exceed.

        Returns:
            The selected entries of ``variant``, most probable first.

        Raises:
            IndexError: If the cumulative probability never exceeds ``coverage``.
        """
        descending = np.argsort(prob)[::-1]
        cumulative = np.cumsum(prob[descending])
        # Rank of the first variant that pushes the running total past coverage
        cutoff = np.where(cumulative > coverage)[0][0]
        return variant[descending][: (cutoff + 1)]
    def slalom(df,LD,abf_prior_variance = 0.4 ,nlog10p_dentist_s_threshold = 4, r2_threshold = 0.6,
               lead_idx_choice = "${lead_idx_choice}"):
        """Annotate summary statistics with ABF fine-mapping and DENTIST-S outlier flags.

        Args:
            df: Summary statistics with columns ``variant``, ``beta``, ``se``
                and, when the lead variant is chosen by p-value, ``pvalue``.
                Its rows must be in the same order as the rows/columns of
                ``LD``. The frame is modified in place. The caller passes a
                frame with a freshly reset index.
            LD: Square LD (correlation) matrix aligned with ``df``.
            abf_prior_variance: Prior variance forwarded to ``abf``.
            nlog10p_dentist_s_threshold: -log10 p cut-off of the DENTIST-S
                statistic above which a variant can be flagged.
            r2_threshold: r2 with the lead variant above which a variant can
                be flagged.
            lead_idx_choice: "pvalue" takes the variant with the smallest
                p-value as lead; any other value takes the variant with the
                largest ABF posterior. The default is a ``${ }`` placeholder
                that SoS replaces with the workflow parameter of the same
                name when the script is generated.

        Returns:
            Tuple ``(df, df_summary)``. ``df`` gains the columns ``lbf``,
            ``prob``, ``cs``, ``cs_99``, ``lead_variant``, ``r``,
            ``t_dentist_s``, ``nlog10p_dentist_s``, ``r2`` and ``outliers``.
            ``df_summary`` is a one-row frame with the top-posterior variant,
            the variant counts, the outlier fraction and the maximum posterior.
        """
        from scipy import stats
        lbf, prob = abf(df.beta, df.se, W=abf_prior_variance)
        cs = get_cs(df.variant, prob, coverage=0.95)
        cs_99 = get_cs(df.variant, prob, coverage=0.99)
        df["lbf"] = lbf
        df["prob"] = prob
        df["cs"] = df.variant.isin(cs)
        df["cs_99"] = df.variant.isin(cs_99)

        # The choice is compared as a string; the unquoted placeholder used
        # previously expanded to a bare name and raised NameError.
        if lead_idx_choice == "pvalue":
            lead_label = df.pvalue.idxmin()
        else:
            lead_label = df.prob.idxmax()
        # idxmin/idxmax return index labels; LD needs the row position
        lead_pos = df.index.get_loc(lead_label)

        # Single-step .loc assignment: chained `df[col].iloc[i] = ...` is a
        # silent no-op under pandas copy-on-write.
        df["lead_variant"] = False
        df.loc[lead_label, "lead_variant"] = True
        # annotate LD
        ## Correlation of every variant with the lead variant: with df and LD
        ## aligned this is simply the lead variant's row of LD
        df["r"] = np.asarray(LD)[lead_pos, :]
        z = df.beta / df.se
        lead_z = z.iloc[lead_pos]
        df["t_dentist_s"] = (z - df.r * lead_z) ** 2 / (1 - df.r ** 2)
        # A negative statistic can only arise when r**2 > 1; treat it as infinite
        df["t_dentist_s"] = np.where(df["t_dentist_s"] < 0, np.inf, df["t_dentist_s"])
        # The lead variant is never tested against itself
        df.loc[lead_label, "t_dentist_s"] = np.nan
        # -log10 of the chi-square (1 df) survival function, computed in log space
        df["nlog10p_dentist_s"] = stats.chi2.logsf(df["t_dentist_s"], df=1) / -np.log(10)
        df["r2"] = df.r ** 2
        df["outliers"] = (df.r2 > r2_threshold) & (df.nlog10p_dentist_s > nlog10p_dentist_s_threshold)
        n_r2 = np.sum(df.r2 > r2_threshold)
        n_dentist_s_outlier = np.sum(df["outliers"])
        max_pip_label = df.prob.idxmax()
        df_summary = pd.DataFrame(
            {
                "lead_pip_variant": [df.variant.loc[max_pip_label]],
                "n_total": [len(df.index)],
                "n_r2": [n_r2],
                "n_dentist_s_outlier": [n_dentist_s_outlier],
                "fraction": [n_dentist_s_outlier / n_r2 if n_r2 > 0 else 0],
                "max_pip": [np.max(df.prob)]
            }
            )
        return df, df_summary
    
    ## Load LD
    if "${_input[0]}".endswith("npz"):
        LD,snp_id = load_npz_ld(${_input[0]:r}) 
    if "${_input[0]}".endswith("bcor"):
        LD,snp_id = load_ld_bcor(${_input[0]:nr}) 
        
    sumstat = pd.read_csv(${_input[1]:r}, "\t")
    ## Get only intersect snp
    intersct = np.intersect1d(sumstat.variant.to_numpy(),snp_id)
    sumstat = sumstat.query("variant in @intersct").reset_index()
    indice = np.where(np.in1d(snp_id, intersct))
    LD = LD[np.ix_(indice[0].tolist(), indice[0].tolist())]    
    ## slalom
    ss_qc,ss_qc_sum = slalom(sumstat,LD,${abf_prior_variance},${nlog10p_dentist_s_threshold},${r2_threshold})
    print(ss_qc_sum)
    ss_qc.to_csv(${_output[1]:r})
    ## Filter out outlier
    LD = LD[np.ix_(ss_qc[~ss_qc.outliers].index,ss_qc[~ss_qc.outliers].index)]  
    ss_qc = ss_qc[~ss_qc.outliers]
    ss_qc = ss_qc.assign(z = ss_qc.beta/ss_qc.se)
    ## SuSiERSS
    string="""
    susie_rss_analysis=function(ss_df, R, n, var_y, z_ld_weight = 0, estimate_residual_variance = FALSE, 
    prior_variance = 50, check_prior = TRUE, output_path ){
    res = susieR:::susie_rss(as.matrix(ss_df$z),as.matrix(R),n,var_y, z_ld_weight = 0, estimate_residual_variance = FALSE, 
    prior_variance = 50, check_prior = TRUE)
    res$variants = ss_df$variant
    res$z = ss_df$z
    res$LD = R
    res$corr = susieR:::get_cs_correlation(res, X = NULL, Xcorr = R, max = FALSE)
    rownames(res$corr) <- names(res$cs)
    colnames(res$corr) <- names(res$cs)
    if (length(res$sets$cs) > 1) {
        index_combos = expand.grid(1:length(res$sets$cs),1:length(res$sets$cs))
        in_common = apply(index_combos, 1, function(x) intersect(res$sets$cs[[x[1]]], res$sets$cs[[x[2]]]))
        counts = unlist(lapply(in_common, length))
        ovlp_mat = matrix(counts, ncol = length(res$sets$cs), byrow = T)
        ovlp_mat[lower.tri(ovlp_mat)] = NA
        rownames(ovlp_mat) = names(res$sets$cs)
        colnames(ovlp_mat) = names(res$sets$cs)
        print(ovlp_mat)
        res$sets[["ovlp_mat"]] = ovlp_mat
    }
    saveRDS(res,output_path)
    return(res)
    }
    """
    susie_analysis = SignatureTranslatedAnonymousPackage(string, "susie_analysis")
    analysis_result = susie_analysis.susie_rss_analysis(ss_df = ss_qc,R = LD,output_path=${_output[0]:r} ${f', n ={n}' if n > 2 else ""})
    