# Fine-mapping with the SuSiE model

This notebook takes a list of regions of interest, one genotype file, and a list of phenotype files, and performs fine-mapping on individual-level data.

## Input

1. A list of regions to be analyzed. If the file has more than 3 columns, the last column is treated as the region name.
2. Either a list of per-chromosome genotype files, or a single file with genotype data for the entire genome. Genotype data must be in PLINK `bed` format.
3. Phenotype file: the first column must be the sample ID, and each remaining column holds one phenotype. The file must have a header.
4. Covariate file: the first column must be the sample ID, and the remaining columns are covariates. The file must have a header. **All covariates in this file are adjusted for in the linear regression.**
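
The layout described in items 3 and 4 can be checked before the workflow is run. The following is a minimal sketch and not part of the pipeline: the helper name `check_sample_table` and the toy tables are hypothetical, the files are assumed to be tab-separated, and the uniqueness check on sample IDs is an added assumption rather than a documented requirement.

```python
import io
import pandas as pd

def check_sample_table(handle, min_value_columns=1):
    """Read a tab-separated table whose first column holds sample IDs.

    Hypothetical helper for illustration only; not part of the workflow.
    """
    df = pd.read_csv(handle, sep="\t", header=0)
    if df.shape[1] < 1 + min_value_columns:
        raise ValueError("Expected a sample ID column plus at least one data column.")
    sample_col = df.columns[0]
    # Assumption: each sample appears once
    if df[sample_col].duplicated().any():
        raise ValueError("Sample IDs in the first column must be unique.")
    return df.set_index(sample_col)

# Toy phenotype and covariate tables (made-up values)
pheno = check_sample_table(io.StringIO("sample_id\ttrait_A\ttrait_B\nS1\t0.1\t1.2\nS2\t-0.3\t0.8\n"))
covar = check_sample_table(io.StringIO("sample_id\tage\tsex\nS1\t61\t1\nS2\t72\t0\n"))

# Samples present in both tables
shared_samples = pheno.index.intersection(covar.index)
print(list(pheno.columns), list(shared_samples))
```

Every column after the first is read as a phenotype (or covariate), which is why columns that should not enter the regression need to be removed from the covariate file beforehand.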

### Example genotype data

```
#chr        path
chr21 /mnt/mfs/statgen/xqtl_workflow_testing/protocol_example.genotype.chr21.bed
chr22 /mnt/mfs/statgen/xqtl_workflow_testing/protocol_example.genotype.chr22.bed
```

Alternatively, simply use `protocol_example.genotype.chr21_22.bed` if all chromosomes are in the same file.
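
To illustrate how a per-chromosome list like the one above is interpreted, here is a small sketch that mirrors the checks in the `get_analysis_regions` step: chromosome labels get a `chr` prefix if they lack one, and only `chr1` to `chr22` and `chrX` are accepted. The file content and paths are placeholders, and the list is assumed to be tab-separated.

```python
import io
import pandas as pd

# Toy genotype list with the two-column layout shown above (paths are placeholders)
listing = io.StringIO("#chr\tpath\n21\tgeno.chr21.bed\nchr22\tgeno.chr22.bed\n")
geno_meta = pd.read_csv(listing, sep="\t", header=0)
geno_meta.columns = ["#chr", "geno_path"]

# Add the "chr" prefix where it is missing, so "21" and "chr21" are treated alike
geno_meta["#chr"] = geno_meta["#chr"].astype(str).map(
    lambda x: x if x.startswith("chr") else f"chr{x}"
)

# Only autosomes and chrX pass the workflow's check
valid = {f"chr{i}" for i in range(1, 23)} | {"chrX"}
unknown = sorted(set(geno_meta["#chr"]) - valid)
print(geno_meta["#chr"].tolist(), unknown)
```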

### Example region list file

The file should have 3 or 4 columns, with the header given as a commented-out line:

```
#chr    start    end    region_id
chr10   0    6480000    ENSG00000008128
chr1    0    6480000    ENSG00000008130
chr1    0    6480000    ENSG00000067606
chr1    0    7101193    ENSG00000069424
chr1    0    7960000    ENSG00000069812
chr1    0    6480000    ENSG00000078369
chr1    0    6480000    ENSG00000078808
```

Or, without region IDs:

```
#chr    start    end    
chr10   0    6480000    
chr1    0    6480000    
chr1    0    6480000    
chr1    0    7101193    
chr1    0    7960000    
chr1    0    6480000    
chr1    0    6480000    
```
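
The two layouts lead to different region IDs. The sketch below follows the logic of the `get_analysis_regions` step: with 4 or more columns the last column is the ID, and with exactly 3 columns the ID is built by joining `chr`, `start` and `end` with underscores. The helper name `region_ids_from_table` is hypothetical and the input is assumed to be tab-separated.

```python
import io
import pandas as pd

def region_ids_from_table(handle):
    """Return one ID per row of a tab-separated region list (hypothetical helper)."""
    # Lines starting with "#" (the header) are skipped
    df = pd.read_csv(handle, sep="\t", header=None, comment="#")
    if df.shape[1] > 3:
        return df.iloc[:, -1].astype(str).tolist()
    if df.shape[1] == 3:
        return df.astype(str).apply(lambda row: "_".join(row), axis=1).tolist()
    raise ValueError("Region list has fewer than 3 columns.")

four_col = "#chr\tstart\tend\tregion_id\nchr10\t0\t6480000\tENSG00000008128\n"
three_col = "#chr\tstart\tend\nchr1\t0\t7101193\n"
ids_four = region_ids_from_table(io.StringIO(four_col))
ids_three = region_ids_from_table(io.StringIO(three_col))
print(ids_four, ids_three)
```

Because the 3-column form derives the ID from the coordinates, repeated rows such as the `chr1 0 6480000` lines in the example above map to the same ID.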

### About indels

The `--no-indel` option removes indels from the analysis.
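
This notebook does not show how the option is implemented. As a rough illustration of the idea only, the sketch below flags a variant as an indel when either allele is longer than one base, using the six-column layout of a PLINK `.bim` file. The rule and the toy variants are assumptions, not the pipeline's actual code.

```python
import io
import pandas as pd

# Columns of a PLINK .bim file: chromosome, variant ID, genetic distance, position, allele 1, allele 2
bim_columns = ["chrom", "variant_id", "cm", "pos", "a1", "a2"]
toy_bim = io.StringIO(
    "21\tchr21:1000:A:G\t0\t1000\tG\tA\n"
    "21\tchr21:2000:AT:A\t0\t2000\tA\tAT\n"
    "21\tchr21:3000:C:CTT\t0\t3000\tCTT\tC\n"
)
bim = pd.read_csv(toy_bim, sep="\t", header=None, names=bim_columns)

# Assumption: a variant is an indel when either allele is longer than one base
is_indel = (bim["a1"].str.len() > 1) | (bim["a2"].str.len() > 1)
snvs_only = bim[~is_indel]
print(snvs_only["variant_id"].tolist())
```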

## Output

1. An RDS file containing the SuSiE output object.
2. Some visualizations.

## Minimal Working Example

In [None]:
sos run pipeline/trans_finemapping.ipynb SuSiE  \
    --name dmd_analysis  \
    --genoFile protocol_example.genotype.chr21_22.bed   \
    --phenoFile study_1_phenotypes.tsv \
    --covFile study_1_covariates.tsv \
    --region-list regions_of_interest.tsv \
    --container oras://ghcr.io/cumc/pecotmr_apptainer:latest
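
The `--container` value is also used to derive the `entrypoint` parameter defined in the `[global]` section. The sketch below repeats that expression in a standalone function (the name `derive_entrypoint` is introduced here for illustration) to show which environment name results from the image used above.

```python
import re

def derive_entrypoint(container):
    """Mirror the expression used for the `entrypoint` parameter in `[global]`."""
    if not container:
        return ""
    # Keep only the image name, e.g. "pecotmr_apptainer:latest"
    image = container.split("/")[-1]
    # Strip the known suffixes to get the micromamba environment name
    env_name = re.sub(r"(_apptainer:latest|_docker:latest|\.sif)$", "", image)
    return 'micromamba run -a "" -n' + " " + env_name

entrypoint = derive_entrypoint("oras://ghcr.io/cumc/pecotmr_apptainer:latest")
print(entrypoint)
```

An image name that does not end in one of the three suffixes is passed through unchanged, so the environment name then equals the full image name.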

## Workflow implementation

In [None]:
[global]
parameter: cwd = path("output")
# A list of file paths for genotype data, or the genotype data itself. 
parameter: genoFile = path
# One or multiple lists of file paths for phenotype data.
parameter: phenoFile = path
# Covariate file path
parameter: covFile = path
parameter: region_list = path
# Only focus on a subset of samples
parameter: keep_samples = path()
# It is required to input the name of the analysis
parameter: name = str
parameter: container = ""
import re
parameter: entrypoint= ('micromamba run -a "" -n' + ' ' + re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])) if container else ""
# For cluster jobs, number commands to run per job
parameter: job_size = 50
# Wall clock time expected
parameter: walltime = "10m"
# Memory expected
parameter: mem = "20G"
# Number of threads
parameter: numThreads = 1

def group_by_region(lst, partition):
    # from itertools import accumulate
    # partition = [len(x) for x in partition]
    # Compute the cumulative sums once
    # cumsum_vector = list(accumulate(partition))
    # Use slicing based on the cumulative sums
    # return [lst[(cumsum_vector[i-1] if i > 0 else 0):cumsum_vector[i]] for i in range(len(partition))]
    return partition

import os
import pandas as pd

def adapt_file_path(file_path, reference_file):
    """
    Adapt a single file path based on its existence and a reference file's path.

    Args:
    - file_path (str): The file path to adapt.
    - reference_file (str): File path to use as a reference for adaptation.

    Returns:
    - str: Adapted file path.

    Raises:
    - FileNotFoundError: If no valid file path is found.
    """
    reference_path = os.path.dirname(reference_file)

    # Check if the file exists
    if os.path.isfile(file_path):
        return file_path

    # Check file name without path
    file_name = os.path.basename(file_path)
    if os.path.isfile(file_name):
        return file_name

    # Check file name in reference file's directory
    file_in_ref_dir = os.path.join(reference_path, file_name)
    if os.path.isfile(file_in_ref_dir):
        return file_in_ref_dir

    # Check original file path prefixed with reference file's directory
    file_prefixed = os.path.join(reference_path, file_path)
    if os.path.isfile(file_prefixed):
        return file_prefixed

    # If all checks fail, raise an error
    raise FileNotFoundError(f"No valid path found for file: {file_path}")

def adapt_file_path_all(df, column_name, reference_file):
    return df[column_name].apply(lambda x: adapt_file_path(x, reference_file))
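
The lookup order of `adapt_file_path` matters when a genotype list was written on another machine and its paths are stale. The following standalone sketch lists the four candidate locations in the order they are tried and shows a file being found next to the list that refers to it. The helper `candidate_paths` and the file names are hypothetical.

```python
import os
import tempfile

def candidate_paths(file_path, reference_file):
    """List the locations tried, in order, by the lookup described above."""
    ref_dir = os.path.dirname(reference_file)
    name = os.path.basename(file_path)
    return [file_path, name, os.path.join(ref_dir, name), os.path.join(ref_dir, file_path)]

with tempfile.TemporaryDirectory() as tmp:
    # Create a genotype list and a genotype file side by side (empty toy files)
    listing = os.path.join(tmp, "genotype_list.tsv")
    geno = os.path.join(tmp, "toy.chr21.bed")
    for f in (listing, geno):
        open(f, "w").close()

    # The list refers to the genotype file through a directory that no longer exists
    stale = os.path.join("old_location", "toy.chr21.bed")
    resolved = next(p for p in candidate_paths(stale, listing) if os.path.isfile(p))
    found_next_to_listing = resolved == geno
    print(found_next_to_listing)
```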

In [None]:
[get_analysis_regions: shared = "regional_data"]
# Input is genoFile, phenoFile, covFile and region_list. Only the regions contained in region_list are analyzed.
# regional_data should be a dictionary like:
#{'data': [("genotype_1", "phenotype_1", "covariate_1"), ("genotype_1", "phenotype_1", "covariate_1"), ... ],
# 'meta_info': [("chr12:752578-752579", "region_id_1", "trait_1", "trait_2"), ("chr13:852580-852581", "region_id_2", "trait_1", "trait_2") ... ]}

def make_meta_data(region_list, region_ids, geno_meta_data, pheno_file, cov_file):
    '''
    Example output:
    #chr    start       end           ID  path     cov_path             cond             geno_path
    chr12   652578   852579  chr12_652578_852579  phenotype.tsv  covar.tsv  trait_A,trait_B   protocol_example.genotype.bed
    '''
    if len(region_ids) != len(region_list):
        raise ValueError("Length of region_ids does not match the number of rows in region_list.")

    # Read phenotype file
    pheno_df = pd.read_csv(pheno_file, sep='\t', header=0)
    # Create a comma-separated string of phenotype columns, excluding the first column
    phenotypes = ','.join(pheno_df.columns[1:])

    # Create the output DataFrame
    output_df = pd.DataFrame({
        '#chr': region_list['chr'],
        'start': region_list['start'],
        'end': region_list['end'],
        'ID': region_ids,
        'path': pheno_file,
        'cov_path': cov_file,
        'cond': phenotypes
    })

    def find_genotype_path(chrom):
        match = geno_meta_data[geno_meta_data['#chr'] == chrom]
        if not match.empty:
            return match.iloc[0]['geno_path']
        else:
            raise ValueError(f"Chromosome {chrom} not found in geno_meta_data.")

    # Apply the function to find the genotype path for each row
    output_df['geno_path'] = output_df['#chr'].apply(find_genotype_path)
    return output_df    

# Check the phenotype and covariate files
# (phenoFile and covFile are each a single path, so there is no list length to compare)
for input_file in (phenoFile, covFile):
    if not input_file.is_file():
        raise ValueError(f"Input file ``{input_file}`` is not found!")

# Load genotype meta data
if f"{genoFile:x}" == ".bed":
    geno_meta_data = pd.DataFrame([("chr"+str(x), f"{genoFile:a}") for x in range(1,23)] + [("chrX", f"{genoFile:a}")], columns=['#chr', 'geno_path'])
else:
    geno_meta_data = pd.read_csv(f"{genoFile:a}", sep = "\t", header=0)
    geno_meta_data.iloc[:, 1] = adapt_file_path_all(geno_meta_data, geno_meta_data.columns[1], f"{genoFile:a}")
    geno_meta_data.columns = ['#chr', 'geno_path']
    geno_meta_data['#chr'] = geno_meta_data['#chr'].apply(lambda x: str(x) if str(x).startswith('chr') else f'chr{x}')

# Checking the DataFrame
valid_chr_values = [f'chr{x}' for x in range(1, 23)] + ['chrX']
if not all(value in valid_chr_values for value in geno_meta_data['#chr']):
    raise ValueError("Invalid chromosome values found. Allowed values are chr1 to chr22 and chrX.")

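region_ids = []
if region_list.is_file():
    region_list_df = pd.read_csv(region_list, sep = "\t", header=None, comment = "#")
    if region_list_df.shape[1] > 3:
        # More than 3 columns, use the last column as ID
        ids = region_list_df.iloc[:, -1].astype(str)
    elif region_list_df.shape[1] == 3:
        # Exactly 3 columns, concatenate the values from the 3 columns to form the ID
        ids = region_list_df.astype(str).apply(lambda x: '_'.join(x), axis=1)
    else:
        raise ValueError(f"Region file ``{region_list}`` has fewer than 3 columns.")
    # Keep the first occurrence of each ID, so that regions and IDs stay aligned row by row
    keep = ~ids.duplicated()
    region_ids = ids[keep].tolist()
    region_list_df = region_list_df.loc[keep, region_list_df.columns[:3]].reset_index(drop=True)
    region_list_df.columns = ['chr', 'start', 'end']
else:
    raise ValueError(f"Region file ``{region_list}`` is not found!")

print(region_list)
print(region_ids)

# Build one row of meta data per region; this defines the `meta_data` used below
meta_data = make_meta_data(region_list_df, region_ids, geno_meta_data, f"{phenoFile:a}", f"{covFile:a}")
if len(meta_data.index) == 0:
    raise ValueError("No region overlap between genotype and any of the phenotypes")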

# Create the final dictionary
regional_data = {
    'data': [(row['geno_path'], *row['path'].split(','), *row['cov_path'].split(',')) for _, row in meta_data.iterrows()],
    'meta_info': [(f"{row['#chr']}:{row['start']}-{row['end']}", # this is the region to analyze
                   row['ID'], *row['cond'].split(',')) for _, row in meta_data.iterrows()]
}
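
To make the shape of `regional_data` concrete, here is a self-contained sketch that builds the dictionary from a toy table. It follows the layout given in the comment at the top of this step (coordinates, region ID, then trait names in `meta_info`). The table values are made up and the column names are those listed in the `make_meta_data` docstring.

```python
import pandas as pd

# Toy meta data with the columns described in make_meta_data (values are made up)
meta_data = pd.DataFrame({
    "#chr": ["chr21", "chr22"],
    "start": [0, 1000],
    "end": [6480000, 7960000],
    "ID": ["region_1", "region_2"],
    "path": ["phenotype.tsv", "phenotype.tsv"],
    "cov_path": ["covar.tsv", "covar.tsv"],
    "cond": ["trait_A,trait_B", "trait_A,trait_B"],
    "geno_path": ["geno.chr21.bed", "geno.chr22.bed"],
})

regional_data = {
    # One tuple of input files per region: genotype, phenotype(s), covariate(s)
    "data": [
        (row["geno_path"], *row["path"].split(","), *row["cov_path"].split(","))
        for _, row in meta_data.iterrows()
    ],
    # One tuple of labels per region: coordinates, region ID, then trait names
    "meta_info": [
        (f"{row['#chr']}:{row['start']}-{row['end']}", row["ID"], *row["cond"].split(","))
        for _, row in meta_data.iterrows()
    ],
}
print(regional_data["data"][0])
print(regional_data["meta_info"][0])
```

The two lists are parallel: entry `i` of `data` holds the files for the region described by entry `i` of `meta_info`.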

In [None]:
[SuSiE_RSS_1]
parameter: L = 10
parameter: max_L = 100
# If available the column that indicates sample size within the sumstats
parameter: sample_size_col = []
# Sample size used to generate the sumstats
parameter: sample_size = 0
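# filtering threshold for raiss imputation
parameter: rcond = 0.01
parameter: R2_threshold = 0.6
# The R code below refers to ld_meta_data, QC and impute, so they are declared here.
# The defaults for QC and impute are assumptions.
# LD meta data file with columns: chrom, start, end, path
parameter: ld_meta_data = path
# Whether to correct discrepancies between z-scores and LD (correct_zR_discrepancy = TRUE)
parameter: QC = False
# Whether to impute outlier z-scores with raiss; only used when QC is enabled
parameter: impute = False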
depends: sos_variable("regional_data")
regions = list(regional_data['regions'].keys())
studies = list(regional_data["GWAS"].keys())
input: for_each = ["regions", "studies"]
output: f'{cwd:a}/{step_name[:-2]}/{_studies}.{_regions.replace(":", "_")}.susie_rss.rds'
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
R: expand = '${ }', stdout = f"{_output:n}.stdout", stderr = f"{_output:n}.stderr", container = container, entrypoint = entrypoint
    library(pecotmr)
    library(dplyr)
    library(susieR)
    library(data.table)
    sumstats=fread("${regional_data['GWAS'][_studies][regional_data['regions'][_regions][0]][0]}")
  
    # rename the columns by yml file -- make the column names consistent
    column_file_path = "${regional_data['GWAS'][_studies][regional_data['regions'][_regions][0]][1]}"
    column_data <- read.table(column_file_path, header = FALSE, sep = ":", stringsAsFactors = FALSE)
    colnames(column_data) = c("standard", "original")
    count = 1
    for (name in colnames(sumstats)){
        if(name %in% column_data$original){
            index = which(column_data$original == name)
            colnames(sumstats)[count] = column_data$standard[index]
        }
        count = count + 1
    }
  
    ## if the data don't have z scores, derive by beta/se, so that allele flip function can run
    if(length(sumstats$z) == 0){
          sumstats$z = sumstats$beta / sumstats$se
    }
  
    ## if the data don't have beta, derive it by making beta = z and se =1, so that allele flip function can run
    if(length(sumstats$beta) == 0){
          sumstats$beta = sumstats$z
          sumstats$se = 1
    }
    
    ## load region information
    region=data.frame(chrom = ${regional_data['regions'][_regions][0]},start = ${regional_data['regions'][_regions][1]},end = ${regional_data['regions'][_regions][2]})
    LD_meta_file=read.table("${ld_meta_data}", sep=" ", header = FALSE, col.names = c("chrom", "start", "end", "path"))
    ## Step 1: Load summary stats and LD data for a region, and match them, using the function in pecotmr::LD.R
    LD_data = load_LD_matrix(LD_meta_file, region, sumstats)
    ## Step 2: basic QC between LD and summary stats --- to correct allele flipping mainly in pecotmr
    allele_flip = allele_qc(sumstats, LD_data[[1]]$variants_df, match.min.prop=0.2, remove_dups=FALSE, flip=TRUE, remove=TRUE)
    allele_flip = allele_flip %>% mutate(variant_allele_flip = paste(chrom,pos,A1.sumstats,A2.sumstats,sep=":"))
    LD_extract = LD_data[[1]]$LD[allele_flip$variant_allele_flip,allele_flip$variant_allele_flip]
    ## Step 3: Perform SuSiE RSS with QC using Gao's prototype
    cols_sample_size=c(${','.join(['"%s"' % x for x in sample_size_col if x is not None])})
    sample_size = ${sample_size}
    L = ${L}
    sample_size_col = c(${','.join(['"%s"' % x for x in sample_size_col if x is not None])})
    ## Get the sample size. An explicit sample_size is preferred. Otherwise n is the median, across variants,
    ## of the sum of the columns named in sample_size_col. If neither is given, n = 0 and susie_rss
    ## runs without n (this does not mean n is treated as 0).
    if(sample_size > 0){
      n = sample_size
    }else if(length(cols_sample_size) >= 1){
      # Sum whichever sample size columns were given, so one column works as well as two
      n_col_sum <- rowSums(as.data.frame(allele_flip)[, cols_sample_size, drop = FALSE])
      n = median(n_col_sum)
    }else{
      n = 0
    }
  
    # if include QC step, then correct_zR_discrepancy = TRUE
    if(${"TRUE" if QC else "FALSE"}){

      if( n > 0){
      susie_rss_result = susie_rss(bhat = allele_flip$beta, shat = allele_flip$se,
                              R = LD_extract, n = n, L = L,
                              correct_zR_discrepancy = TRUE, track_fit = FALSE)
      }else{
      # run without n
      susie_rss_result = susie_rss(bhat = allele_flip$beta, shat = allele_flip$se,
                              R = LD_extract, L = L,
                              correct_zR_discrepancy = TRUE, track_fit = FALSE)
      }

      if(${"TRUE" if impute else "FALSE"}){
        outlier = susie_rss_result$zR_outliers
        if(length(outlier) == 0){
            # no outliers, so no imputation is needed; report the fit directly
            result = susie_rss_result
        }else{
            # with outliers, raiss imputation
            ref_panel = allele_flip %>% select("chrom", "pos", "variant_allele_flip", "A1.ref", "A2.ref")
            colnames(ref_panel) = c("chr", "pos", "variant_id", "A0", "A1") 
            known_zscore =  allele_flip %>% select("chrom", "pos", "variant_allele_flip", "A1.ref", "A2.ref", "z")
            colnames(known_zscore) = c("chr", "pos", "variant_id", "A0", "A1", "Z")
            known_zscores = known_zscore[-outlier, ] %>% arrange(pos)
            imputation_result = raiss(ref_panel, known_zscores, LD_extract, rcond = ${rcond}, R2_threshold = ${R2_threshold})
            filtered_out_variant = setdiff(allele_flip$variant_allele_flip, imputation_result$variant_id)
            filtered_out_id = which(allele_flip$variant_allele_flip %in% filtered_out_variant)
            if(length(filtered_out_id) != 0){
                LD_extract_filtered = as.matrix(LD_extract)[-filtered_out_id,-filtered_out_id]
            }else{
                LD_extract_filtered = as.matrix(LD_extract)

            }
            ## repeat step: get same sample size, if n = 0, run without n parameter
            if(n > 0){
            impute_rss_fit = susie_rss(z = imputation_result$Z, R = LD_extract_filtered, 
                               n = n,
                               L = L, correct_zR_discrepancy = FALSE,
                               track_fit = FALSE)
            }else{
            impute_rss_fit = susie_rss(z = imputation_result$Z, R = LD_extract_filtered, 
                               L = L, correct_zR_discrepancy = FALSE,
                               track_fit = FALSE)        
            }
            result = impute_rss_fit
            result$z = imputation_result$Z
        }
      }else{
        ## no imputation
        result = susie_rss_result
      }
    }else{
      ## no QC
      if(n > 0){
        result = susie_rss(bhat = allele_flip$beta, shat = allele_flip$se,
                           R = LD_extract, n = n, L = L,
                           correct_zR_discrepancy = FALSE, track_fit = FALSE)
      }else{
        # run without n
        result = susie_rss(bhat = allele_flip$beta, shat = allele_flip$se,
                           R = LD_extract, L = L,
                           correct_zR_discrepancy = FALSE, track_fit = FALSE)
      }
    }

    saveRDS(result, file = "${_output}")
    #write.table(allele_flip, "${_output:n}.sumstats_qced", sep = "\t", col.names=TRUE, row.names=FALSE, quote=FALSE)
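
Two pieces of bookkeeping in the R step above are easy to get wrong: filling in whichever of `z` or (`beta`, `se`) is missing, and choosing the sample size `n`. The sketch below restates that logic in Python with pandas so it can be tried on a toy table. The function names and the column names `n_case` and `n_control` are hypothetical, and the sketch does not call `susie_rss`.

```python
import pandas as pd

def fill_effect_columns(sumstats):
    """Derive whichever of z or (beta, se) is missing, as in the R step above."""
    out = sumstats.copy()
    if "z" not in out.columns:
        out["z"] = out["beta"] / out["se"]
    if "beta" not in out.columns:
        # Without effect sizes, use beta = z and se = 1 so that z is preserved
        out["beta"] = out["z"]
        out["se"] = 1.0
    return out

def choose_sample_size(sumstats, sample_size=0, sample_size_cols=()):
    """Prefer an explicit sample size, else the median of summed per-variant counts, else 0."""
    if sample_size > 0:
        return sample_size
    if len(sample_size_cols) >= 1:
        return float(sumstats[list(sample_size_cols)].sum(axis=1).median())
    # 0 signals that the fit should be run without n
    return 0

toy = pd.DataFrame({"beta": [0.2, -0.1], "se": [0.1, 0.05], "n_case": [100, 120], "n_control": [300, 320]})
filled = fill_effect_columns(toy)
n = choose_sample_size(filled, sample_size_cols=("n_case", "n_control"))
print(filled["z"].tolist(), n)
```

Returning 0 when no sample size information is available matches the convention in the R code, where `n = 0` means that `susie_rss` is called without the `n` argument.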