# Fine-mapping with SuSiE RSS model

This notebook takes a list of LD reference files and a list of sumstat files from various association studies ...

## Input

1. **FIXME we need to make input as a bed file with chrom, start and end** A tab-delimited table listing the path where the LD matrix for each region is stored. It can be generated using the `ld_per_region_plink` step of the genotype processing module.

```
#id     dir
chr17_60570445_65149278 /mnt/vast/hpc/csg/molecular_phenotype_calling/LD/output_npz_2/1300_hg38_EUR_LD_blocks_npz_files/ROSMAP_NIA_WGS.leftnorm.filtered.filtered.chr17_60570445_65149278.flt16.npz
```

2. A tab-delimited table listing the path where the summary statistics for each chromosome are stored. It can be generated using the `yml_generator` module before the QCed summary statistics are generated. **FIXME: If the chrom name is zero that means the data is genome-wide**
```
hs3163@csglogin:/mnt/vast/hpc/csg/xqtl_workflow_testing/susie_rss$ cat /mnt/vast/hpc/csg/xqtl_workflow_testing/ADGWAS/data_intergration/ADGWAS2022/qced_sumstat_list.txt
#chr    ADGWAS_Bellenguez_2022
1       /mnt/vast/hpc/csg/xqtl_workflow_testing/ADGWAS/data_intergration/ADGWAS2022/ADGWAS_Bellenguez_2022.1/ADGWAS2022.chr1.sumstat.tsv
2       /mnt/vast/hpc/csg/xqtl_workflow_testing/ADGWAS/data_intergration/ADGWAS2022/ADGWAS_Bellenguez_2022.2/ADGWAS2022.chr2.sumstat.tsv
3       /mnt/vast/hpc/csg/xqtl_workflow_testing/ADGWAS/data_intergration/ADGWAS2022/ADGWAS_Bellenguez_2022.3/ADGWAS2022.chr3.sumstat.tsv
4       /mnt/vast/hpc/csg/xqtl_workflow_testing/ADGWAS/data_intergration/ADGWAS2022/ADGWAS_Bellenguez_2022.4/ADGWAS2022.chr4.sumstat.tsv
5       /mnt/vast/hpc/csg/xqtl_workflow_testing/ADGWAS/data_intergration/ADGWAS2022/ADGWAS_Bellenguez_2022.5/ADGWAS2022.chr5.sumstat.tsv
```
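As an illustration of how such a list might be consumed, here is a minimal sketch using only the Python standard library. The helper names `load_sumstat_paths` and `paths_for_chrom`, the study name `study_A`, and the example paths are hypothetical. The fallback to a row labelled `0` follows the FIXME note above and is an assumption, not confirmed pipeline behavior.

```python
import csv
import io


def load_sumstat_paths(handle):
    """Map chromosome label -> {study: path} from a tab-delimited list with a '#chr' header."""
    reader = csv.reader(handle, delimiter="\t")
    header = next(reader)
    studies = header[1:]  # every column after '#chr' is one study
    table = {}
    for row in reader:
        if not row:
            continue  # skip blank lines
        chrom = row[0].replace("chr", "")
        table[chrom] = dict(zip(studies, row[1:]))
    return table


def paths_for_chrom(table, chrom):
    """Return the per-study paths for a chromosome.

    Assumption: a row labelled '0' holds genome-wide data and is used as a fallback.
    """
    chrom = str(chrom).replace("chr", "")
    return table.get(chrom, table.get("0"))


# Hypothetical two-row list: one per-chromosome file and one genome-wide file
example = io.StringIO(
    "#chr\tstudy_A\n"
    "1\t/path/to/study_A.chr1.sumstat.tsv\n"
    "0\t/path/to/study_A.genome_wide.sumstat.tsv\n"
)
table = load_sumstat_paths(example)
print(paths_for_chrom(table, "chr1"))
print(paths_for_chrom(table, 17))
```

Stripping the `chr` prefix on both sides mirrors what the `get_analysis_regions` step does before merging the two lists.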

3. Regions we want to analyze, in the format `chr:start-end`. Multiple regions can be given. If none are specified, the regions in the LD data list are used.
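The inputs above use two region notations: the LD list identifies blocks as `chr17_60570445_65149278`, while regions of interest are written as `chr:start-end`. The sketch below shows one way to normalize both into a `(chrom, start, end)` tuple. The function name `parse_region` and the regular expression are illustrative assumptions and are not part of the pipeline.

```python
import re

# Accepts 'chr17:60570445-65149278' as well as 'chr17_60570445_65149278';
# the 'chr' prefix is optional and is dropped from the returned chromosome label.
REGION_PATTERN = re.compile(r"^(?:chr)?([0-9A-Za-z]+)[:_](\d+)[-_](\d+)$")


def parse_region(region):
    """Split a region string into (chrom, start, end) with integer coordinates."""
    match = REGION_PATTERN.match(region.strip())
    if match is None:
        raise ValueError(f"Cannot parse region: {region!r}")
    chrom = match.group(1)
    start, end = int(match.group(2)), int(match.group(3))
    if start > end:
        raise ValueError(f"Region start is after end: {region!r}")
    return chrom, start, end


print(parse_region("chr17:60570445-65149278"))
print(parse_region("chr17_60570445_65149278"))
```

Returning the chromosome without the `chr` prefix keeps it comparable with the `#chr` column of the sumstat list.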

## Output

1. An RDS file containing the output SuSiE object, the names of all variants that went through the analysis, the z-scores, and the LD matrix used for the analysis.
2. A sumstat file with an additional column containing the SLALOM results.

## MWE

In [None]:
sos run pipeline/SuSiE_RSS.ipynb SuSiE_RSS \
    --ld-data test.ld.list \
    --sumstats /mnt/vast/hpc/csg/xqtl_workflow_testing/ADGWAS/data_intergration/ADGWAS2022/qced_sumstat_list.txt \
    --container containers/stephenslab.sif --impute --cwd output_impute_2

In [None]:
[global]
parameter: cwd = path("output")
# getting the overlapped input
parameter: ld_data = path
parameter: sumstats = paths
import pandas as pd
parameter: container = ""
import re
parameter: entrypoint= ('micromamba run -a "" -n' + ' ' + re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])) if container else ""
# For cluster jobs, number of commands to run per job
parameter: job_size = 1
# Wall clock time expected
parameter: walltime = "5h"
# Memory expected
parameter: mem = "16G"
# Number of threads
parameter: numThreads = 3

parameter: lead_idx_choice = "pvalue"
parameter: abf_prior_variance = 0.4
parameter: nlog10p_dentist_s_threshold = 4
parameter: r2_threshold = 0.6
parameter: n = 0
parameter: max_iter = 1000
parameter: impute = True # Whether to impute the sumstat for all the snp in LD but not in sumstat.

In [None]:
[get_analysis_regions: shared = "regional_data"]
# This will pair the LD matrix blocks with each of the input summary stats

LD_list = pd.read_csv(ld_data,sep="\t")
sumstat_list = pd.read_csv(sumstats,sep="\t")
LD_list["#chr"] = [x[0].replace("chr", "") for x in  LD_list["#id"].str.split("_") ]
sumstat_list["#chr"] = [str(x).replace("chr", "") for x in  sumstat_list["#chr"] ]
input_inv = LD_list.merge(sumstat_list)
input_list = input_inv.iloc[:,[1,3]].values.tolist()

In [None]:
[SuSiE_RSS_1]
parameter: L = 10
parameter: max_L = 1000

depends: sos_variable("regional_data")

meta_info = regional_data['meta_info']
input: regional_data["data"], group_by = 2, group_with = "meta_info"
# name = f'{_input[0]:b}'.split(".")[-3]
output: f'{cwd:a}/{_input[1]:bn}.{name}.unisusie_rss.fit.rds',
        f'{cwd:a}/{_input[1]:bn}.{name}.unisusie_rss.ss_qced.tsv'    
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output[0]:bn}'
R: expand = '${ }', stdout = f"{_output[0]:nn}.stdout", stderr = f"{_output[0]:nn}.stderr", container = container, entrypoint = entrypoint
  
    ## Step 1: Load summary stats and LD data for a region, and match them, using the function in pecotmr::LD.R

    ## Step 2: basic QC between LD and summary stats --- to correct allele flipping mainly in pecotmr 
  
    ## Step 3: Perform SuSiE RSS with QC using my prototype
  
    ## Output are 1) RDS file of fine-mapping results and 2) summary stats file for the region after allele flipping QC as well as the SuSiE RSS based QC
  
    ## After that we repeat Step 1 and Step 3 with RSS QC (susie_rss as is). 

In [None]:
[SuSiE_RSS_2]
output: pip_plot = f"{cwd}/{_input:bn}.png"
task: trunk_workers = 1, trunk_size = job_size, walltime = '12h', mem = '20G', cores = numThreads, tags = f'{step_name}_{_output:bn}'
R: container=container, expand = "${ }", stderr = f'{_output[0]:n}.stderr', stdout = f'{_output[0]:n}.stdout', entrypoint = entrypoint
    res = readRDS(${_input:r})
    png(${_output[0]:r}, width = 14, height=6, unit='in', res=300)
    par(mfrow=c(1,2))
    susieR::susie_plot(res, y= "PIP", pos=list(attr='pos',start=res$pos[1],end=res$pos[length(res$pos)]), add_legend=T, xlab="position")
    susieR::susie_plot(res, y= "z", pos=list(attr='pos',start=res$pos[1],end=res$pos[length(res$pos)]), add_legend=T, xlab="position", ylab="-log10(p)")
    dev.off()

In [None]:
[SuSiE_RSS_3]
sep = "" #'\n\n---\n'
input: group_by = 'all'
output: analysis_summary = f'{cwd}/{sumstats_path:bnn}.analysis_summary.md', variants_csv = f'{cwd}/{sumstats_path:bnn}.variants.csv'
python: container=container, expand = "${ }", entrypoint = entrypoint

    theme = '''---
    theme: base-theme
    style: |
     p {
       font-size: 24px;
       height: 900px;
       margin-top:1cm;
      }
      img {
        height: 70%;
        display: block;
        margin-left: auto;
        margin-right: auto;
      }
      body {
       margin-top: auto;
       margin-bottom: auto;
       font-family: verdana;
      }
    ---    
    '''
    import numpy as np
    import pandas as pd
    
    # will load the rds file outputted in a previous step
    def load_rds(filename, types=None):
        import os
        import pandas as pd, numpy as np
        import rpy2.robjects as RO
        import rpy2.robjects.vectors as RV
        import rpy2.rinterface as RI
        from rpy2.robjects import numpy2ri
        numpy2ri.activate()
        from rpy2.robjects import pandas2ri
        pandas2ri.activate()
        def load(data, types, rpy2_version=3):
            if types is not None and not isinstance(data, types):
                return np.array([])
            # FIXME: I'm not sure if I should keep two versions here
            # rpy2_version 2.9.X is more tedious but it handles BoolVector better
            # rpy2 version 3.0.1 converts bool to integer directly without dealing with
            # NA properly. It gives something like (0,1,-234235).
            # Possibly the best thing to do is to open an issue for it to the developers.
            if rpy2_version == 2:
                # below works for rpy2 version 2.9.X
                if isinstance(data, RI.RNULLType):
                    res = None
                elif isinstance(data, RV.BoolVector):
                    data = RO.r['as.integer'](data)
                    res = np.array(data, dtype=int)
                    # Handle c(NA, NA) situation
                    if np.sum(np.logical_and(res != 0, res != 1)):
                        res = res.astype(float)
                        res[res < 0] = np.nan
                        res[res > 1] = np.nan
                elif isinstance(data, RV.FactorVector):
                    data = RO.r['as.character'](data)
                    res = np.array(data, dtype=str)
                elif isinstance(data, RV.IntVector):
                    res = np.array(data, dtype=int)
                elif isinstance(data, RV.FloatVector):
                    res = np.array(data, dtype=float)
                elif isinstance(data, RV.StrVector):
                    res = np.array(data, dtype=str)
                elif isinstance(data, RV.DataFrame):
                    res = pd.DataFrame(data)
                elif isinstance(data, RV.Matrix):
                    res = np.matrix(data)
                elif isinstance(data, RV.Array):
                    res = np.array(data)
                else:
                    # I do not know what to do for this
                    # But I do not want to throw an error either
                    res = str(data)
            else:
                if isinstance(data, RI.NULLType):
                    res = None
                else:
                    res = data
            if isinstance(res, np.ndarray) and res.shape == (1, ):
                res = res[0]
            return res
        def load_dict(res, data, types):
            '''load data to res'''
            names = data.names if not isinstance(data.names, RI.NULLType) else [
                i + 1 for i in range(len(data))
            ]
            for name, value in zip(names, list(data)):
                if isinstance(value, RV.ListVector):
                    res[name] = {}
                    res[name] = load_dict(res[name], value, types)
                else:
                    res[name] = load(value, types)
            return res
        #
        if not os.path.isfile(filename):
            raise IOError('Cannot find file ``{}``!'.format(filename))
        rds = RO.r['readRDS'](filename)
        if isinstance(rds, RV.ListVector):
            res = load_dict({}, rds, types)
        else:
            res = load(rds, types)
        return res
    
    def f7(seq):
        seen = set()
        seen_add = seen.add
        return [x for x in seq if not (x in seen or seen_add(x))]



    text = ""
    sep = '\n\n---\n'
    
    inp = "${_input:r}".split(" ")
    for i, each in enumerate(inp):
        inp[i] = ".".join(each.split(".")[:-1])

    r = f7("${_input:bn}".split(" "))
    
    num_csets = []
    region_info = []
    
    # this will be a 2d array that stores information about each variant of interest in the phenotype
    # this includes all the variants in a cs and all the variants past the cutoff
    variant_info = []

    for reg_i, each in enumerate(f7(inp)):
    
        rid = r[reg_i].split('.')[0]
        
        text_temp = ""
        text_temp += "#\n\n SuSiE RSS {region} \n".format(region=r[reg_i])
        text_temp += "![]({region}.png){sep} \n \n".format(region=r[reg_i], sep=sep)

        rd = load_rds(each[1:]+".rds")
        
        # find the number of cs in the current region
        if rd["sets"]["cs"] is None:
            num_csets.append(0)
        else:
            num_csets.append(len(rd["sets"]["cs"]))
        print(num_csets)
        
        # this will store the indices of all variants that cross the threshold
        ind_p = []

        pval = ${pip_cutoff}

        for i, each in enumerate(rd["pip"]):
            if each >= pval:
                ind_p.append(i)
        sumvars = 0
        
        # if we have at least one cs in the current region
        if num_csets[reg_i] > 0:
            tbl_header = "| chr number | pos at highest pip | ref | alt | region id | cs | highest pip |  \n"
            tbl_header += "| --- | --- | --- | --- | --- | --- | --- |  \n"

            table = ""
            
            sumpips = 0
            
            for cset in rd["sets"]["cs"].keys():
                print(cset)
                
                # if we have many variants in the cs
                if isinstance(rd["sets"]["cs"][cset], np.ndarray):
                    highestpip = 0
                    poswhighestpip = -1
                    for i in rd["sets"]["cs"][cset]:
                        i = i.item() - 1
                        
                        # we make sure that ind_p only stores the variants that aren't in any cs
                        if i in ind_p: ind_p.remove(i) 
                        
                        # append variant info
                        variant_info.append( [rd["chr"][i], rd["pos"][i], rd["ref"][i], rd["alt"][i], rid, cset, rd["pip"][i]] )
                        
                        if rd["pip"][i] > highestpip:
                            highestpip = rd["pip"][i]
                            poswhighestpip = i
                            
                        sumpips += rd["pip"][i]
                        sumvars += 1
                        
                    if poswhighestpip > -1:
                        i = poswhighestpip
                        table += "| {chr} | {pos} | {ref} | {alt} | {rid} | {cs} | {pip:.2f} |  \n".format(chr=rd["chr"][i], pos=rd["pos"][i], ref=rd["ref"][i], alt=rd["alt"][i], rid=rid, cs=cset, pip=rd["pip"][i])
                
                else: # if we have only one variant in the cs
                    i =  rd["sets"]["cs"][cset]
                    i = i.item() - 1
                    
                    # we make sure that ind_p only stores the variants that aren't in any cs
                    if i in ind_p: ind_p.remove(i)
                    
                    # append variant info
                    variant_info.append( [rd["chr"][i], rd["pos"][i], rd["ref"][i], rd["alt"][i], rid, cset, rd["pip"][i]] )
                    
                    table += "| {chr} | {pos} | {ref} | {alt} | {rid} | {cs} | {pip:.2f} |  \n".format(chr=rd["chr"][i], pos=rd["pos"][i], ref=rd["ref"][i], alt=rd["alt"][i], rid=rid, cs=cset, pip=rd["pip"][i])
                    
                    sumpips += rd["pip"][i]
                    sumvars += 1
            

            text_temp += "- Total number of variants: {}\n".format(len(rd["pip"]))
            text_temp += "- Expected number of causal variants: {:.2f}\n".format(sumpips)
            text_temp += "- Number of variants with PIP > {} and not in any CS: {}\n\n".format(pval, len(ind_p))
            text_temp += tbl_header + table + sep
            
            if num_csets[reg_i] > 1:
                text_temp += "#### CORR: Correlation between CS | OLAP: Overlap between CS\n"
                
                cs = list(rd["sets"]["cs"].keys())

                corrheader = "|  |"
                corrbreak = "| --- |"

                for i in cs:
                    corrheader += " CORR {} |".format(i)
                    corrbreak += " --- |"
                    
                corrheader += "  |"
                corrbreak += " --- |"
                    
                for i in cs:
                    corrheader += " OLAP {} |".format(i)
                    corrbreak += " --- |"

                corrheader += "\n"
                corrbreak += "\n"

                body = ""

                for en, i in enumerate(cs):
                    body += "| {} |".format(i)
                    for j in rd["cscorr"][en]:
                        body += " {:.2f} |".format(j)
                    body += "  |"
                    for j in rd["sets"]["cs"]:
                        body += " {} |".format(len(np.intersect1d(rd["sets"]["cs"][i], rd["sets"]["cs"][j])))
                    body += "\n"
                
                text_temp += corrheader + corrbreak + body + sep
            
        # record variants above the PIP cutoff that are not in any CS;
        # done per region, since ind_p, rd and rid are reset on every iteration
        for i in ind_p:
            variant_info.append( [rd["chr"][i], rd["pos"][i], rd["ref"][i], rd["alt"][i], rid, "None", rd["pip"][i]] )

        region_info.append(text_temp)

    # order regions by number of credible sets, largest first
    cset_order = np.argsort(num_csets).tolist()
    cset_order.reverse()
    for c in cset_order:
        text += region_info[c]

    with open(${_output["analysis_summary"]:r}, "w") as f:
        f.write(theme + text)
        
    df = pd.DataFrame(variant_info, columns=["chr", "pos", "ref", "alt", "rid", "cs", "pip"])
    df.to_csv(${_output["variants_csv"]:r}, sep = "\t", header = True, index = True)

In [None]:
# Generate analysis report: HTML file, and optionally PPTX file
[SuSiE_RSS_4]
output: f"{_input['analysis_summary']:n}.html"
sh: container=container_marp, expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', entrypoint = entrypoint
    node /opt/marp/.cli/marp-cli.js ${_input['analysis_summary']} -o ${_output:a} \
        --title '${region_file:bnn} fine mapping analysis' \
        --allow-local-files
    node /opt/marp/.cli/marp-cli.js ${_input['analysis_summary']} -o ${_output:an}.pptx \
        --title '${region_file:bnn} fine mapping analysis' \
        --allow-local-files