# Fine-mapping with SuSiE model

This notebook takes a list of regions of interest, one genotype file and a list of phenotype files, and performs fine-mapping on individual-level data.

## Input

1. A list of regions to be analyzed. If the file has more than 3 columns, the last column is treated as the region name.
2. Either a list of per chromosome genotype files, or one file for genotype data of the entire genome. Genotype data has to be in PLINK `bed` format. 
3. Phenotype file: first column must be sample ID and other columns each for a phenotype ID. It should have a header.
4. Covariate file: first column must be sample ID and other columns are covariates. It should have a header. **All covariates in this file will be analyzed as adjustment in linear regression.**
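As a rough illustration of the phenotype file layout described above, the sketch below reads a whitespace-delimited header, takes the first column as the sample ID and the remaining columns as phenotype IDs, and joins them into the comma-separated string the workflow later passes along as conditions. The helper name `phenotype_ids` and the example content are hypothetical; the workflow itself uses pandas for this step.

```python
import io

def phenotype_ids(handle):
    """Return (sample_id_column, [phenotype IDs]) from a whitespace-delimited header line."""
    header = handle.readline().split()
    if len(header) < 2:
        raise ValueError("Expected a sample ID column plus at least one phenotype column.")
    return header[0], header[1:]

# Hypothetical file content, for illustration only
example = io.StringIO("sample_id trait_A trait_B\nS1 0.12 1.5\nS2 -0.40 0.9\n")
sample_col, traits = phenotype_ids(example)
conditions = ",".join(traits)  # every non-ID column is treated as a phenotype
print(sample_col, conditions)
```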

### Example genotype data

```
#chr        path
chr21 /mnt/mfs/statgen/xqtl_workflow_testing/protocol_example.genotype.chr21.bed
chr22 /mnt/mfs/statgen/xqtl_workflow_testing/protocol_example.genotype.chr22.bed
```

Alternatively, simply use `protocol_example.genotype.chr21_22.bed` if all chromosomes are in the same file.
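The two genotype input modes can be sketched as follows, assuming the behavior of the workflow code further down: a single `.bed` file is mapped to every supported chromosome (chr1 to chr22 and chrX), while a list file is read as a header line followed by chromosome and path pairs, with a `chr` prefix added when missing. The function name `genotype_map` and the short file names are hypothetical, and this sketch splits on any whitespace whereas the workflow reads the list as tab-separated.

```python
CHROMS = [f"chr{i}" for i in range(1, 23)] + ["chrX"]

def genotype_map(geno_file, list_text=None):
    """Map each chromosome to a genotype file path."""
    if geno_file.endswith(".bed"):
        # One file for the whole genome: every chromosome points to it
        return {chrom: geno_file for chrom in CHROMS}
    lines = [line for line in list_text.splitlines() if line.strip()]
    mapping = {}
    for line in lines[1:]:  # the first non-empty line is the header
        chrom, path = line.split()[:2]
        if not chrom.startswith("chr"):
            chrom = f"chr{chrom}"
        if chrom not in CHROMS:
            raise ValueError(f"Invalid chromosome value: {chrom}")
        mapping[chrom] = path
    return mapping

single = genotype_map("protocol_example.genotype.chr21_22.bed")
per_chrom = genotype_map("geno_list.txt", "#chr\tpath\n21\ta.chr21.bed\nchr22\ta.chr22.bed\n")
print(len(single), per_chrom)
```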

### Example region list file

It should have 3 or 4 columns, with the header given as a commented-out line:

```
#chr    start    end    region_id
chr10   0    6480000    ENSG00000008128
chr1    0    6480000    ENSG00000008130
chr1    0    6480000    ENSG00000067606
chr1    0    7101193    ENSG00000069424
chr1    0    7960000    ENSG00000069812
chr1    0    6480000    ENSG00000078369
chr1    0    6480000    ENSG00000078808
```

Or, simply

```
#chr    start    end    
chr10   0    6480000    
chr1    0    6480000    
chr1    0    7101193    
chr1    0    7960000    
```

In either case, each row must be unique.
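A minimal sketch of how region IDs can be derived from either layout, following the rule stated in the Input section: with more than 3 columns the last column is the ID, and with exactly 3 columns the ID is the three values joined by underscores. The helper names `parse_region_list` and `region_ids` are hypothetical, and the explicit duplicate check is an illustration of the uniqueness requirement rather than the workflow's exact implementation.

```python
def parse_region_list(text):
    """Split non-empty, non-comment lines of a region list into fields."""
    return [line.split() for line in text.splitlines()
            if line.strip() and not line.startswith("#")]

def region_ids(rows):
    """Derive one ID per row and reject duplicates."""
    ids = []
    for fields in rows:
        if len(fields) < 3:
            raise ValueError("Region list has fewer than 3 columns.")
        # More than 3 columns: last column is the ID; otherwise join chr, start, end
        ids.append(fields[-1] if len(fields) > 3 else "_".join(fields[:3]))
    duplicates = sorted({i for i in ids if ids.count(i) > 1})
    if duplicates:
        raise ValueError(f"Duplicated region IDs: {duplicates}")
    return ids

named = parse_region_list("#chr start end region_id\nchr10 0 6480000 ENSG00000008128\nchr1 0 6480000 ENSG00000008130\n")
unnamed = parse_region_list("#chr start end\nchr10 0 6480000\nchr1 0 7101193\n")
print(region_ids(named))
print(region_ids(unnamed))
```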

### About indels

The `--no-indel` option removes indels from the analysis.

## Output

1. An RDS file containing the SuSiE output object.
2. Some visualizations.
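The name of each RDS file follows the output pattern used in the workflow step below: the working directory, the workflow name (the step name without its `_1` suffix), and then the analysis name and region ID. The sketch below mirrors that pattern; `susie_output_path` is a hypothetical helper, and unlike the workflow it keeps the path relative rather than resolving it to an absolute path.

```python
from pathlib import Path

def susie_output_path(cwd, step_name, name, region_id):
    """Build the per-region output path, e.g. output/SuSiE/<name>.<region_id>.susie.rds."""
    workflow = step_name[:-2]  # "SuSiE_1" -> "SuSiE"
    return Path(cwd) / workflow / f"{name}.{region_id}.susie.rds"

out = susie_output_path("output", "SuSiE_1", "dmd_analysis", "ENSG00000008128")
print(out.as_posix())
```

One file is written per region, so a region list with N unique rows yields N such files.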

## Minimal Working Example

In [None]:
sos run pipeline/trans_finemapping.ipynb SuSiE  \
    --name dmd_analysis  \
    --genoFile protocol_example.genotype.chr21_22.bed   \
    --phenoFile study_1_phenotypes.tsv \
    --covFile study_1_covariates.tsv \
    --region-list regions_of_interest.tsv \
    --container oras://ghcr.io/cumc/pecotmr_apptainer:latest

## Workflow implementation

In [None]:
[global]
parameter: cwd = path("output")
# A list of file paths for genotype data, or the genotype data itself. 
parameter: genoFile = path
# One or multiple lists of file paths for phenotype data.
parameter: phenoFile = path
# Covariate file path
parameter: covFile = path
parameter: region_list = path
# Only focus on a subset of samples
parameter: keep_samples = path()
# It is required to input the name of the analysis
parameter: name = str
parameter: container = ""
import re
parameter: entrypoint= ('micromamba run -a "" -n' + ' ' + re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])) if container else ""
# For cluster jobs, number of commands to run per job
parameter: job_size = 50
# Wall clock time expected
parameter: walltime = "10m"
# Memory expected
parameter: mem = "20G"
# Number of threads
parameter: numThreads = 1

def group_by_region(lst, partition):
    # from itertools import accumulate
    # partition = [len(x) for x in partition]
    # Compute the cumulative sums once
    # cumsum_vector = list(accumulate(partition))
    # Use slicing based on the cumulative sums
    # return [lst[(cumsum_vector[i-1] if i > 0 else 0):cumsum_vector[i]] for i in range(len(partition))]
    return partition

import os
import pandas as pd

def adapt_file_path(file_path, reference_file):
    """
    Adapt a single file path based on its existence and a reference file's path.

    Args:
    - file_path (str): The file path to adapt.
    - reference_file (str): File path to use as a reference for adaptation.

    Returns:
    - str: Adapted file path.

    Raises:
    - FileNotFoundError: If no valid file path is found.
    """
    reference_path = os.path.dirname(reference_file)

    # Check if the file exists
    if os.path.isfile(file_path):
        return file_path

    # Check file name without path
    file_name = os.path.basename(file_path)
    if os.path.isfile(file_name):
        return file_name

    # Check file name in reference file's directory
    file_in_ref_dir = os.path.join(reference_path, file_name)
    if os.path.isfile(file_in_ref_dir):
        return file_in_ref_dir

    # Check original file path prefixed with reference file's directory
    file_prefixed = os.path.join(reference_path, file_path)
    if os.path.isfile(file_prefixed):
        return file_prefixed

    # If all checks fail, raise an error
    raise FileNotFoundError(f"No valid path found for file: {file_path}")

def adapt_file_path_all(df, column_name, reference_file):
    return df[column_name].apply(lambda x: adapt_file_path(x, reference_file))

In [None]:
[get_analysis_regions: shared = "regional_data"]
# Input is genoFile, phenoFile, covFile and region_list. Only the regions contained in region_list are analyzed.
# regional_data should be a dictionary like:
#{'data': [("genotype_1", "phenotype_1", "covariate_1"), ("genotype_1", "phenotype_1", "covariate_1"), ... ],
# 'meta_info': [("chr12:752578-752579", "region_id_1", "trait_1", "trait_2"), ("chr13:852580-852581", "region_id_2", "trait_1", "trait_2") ... ]}

def make_meta_data(region_list, region_ids, geno_meta_data, pheno_file, cov_file):
    '''
    Example output:
    #chr    start       end           ID  path     cov_path             cond             geno_path
    chr12   652578   852579  chr12_652578_852579  phenotype.tsv  covar.tsv  trait_A,trait_B   protocol_example.genotype.bed
    '''
    if len(region_ids) != len(region_list):
        raise ValueError("Length of region_ids does not match the number of rows in region_list.")
    if not pheno_file.is_file():
        raise ValueError(f"Phenotype file ``{pheno_file}`` does not exist")
    pheno_file = str(pheno_file)
    if not cov_file.is_file():
        raise ValueError(f"Covariate file ``{cov_file}`` does not exist")
    cov_file = str(cov_file)
    # Read phenotype file
    pheno_df = pd.read_csv(pheno_file, sep=r'\s+', header=0)
    # Create a comma-separated string of phenotype columns, excluding the first column
    # FIXME HERE!
    phenotypes = ','.join(pheno_df.columns[1:])

    # Create the output DataFrame
    output_df = pd.DataFrame({
        '#chr': region_list.iloc[:, 0],
        'start': region_list.iloc[:, 1],
        'end': region_list.iloc[:, 2],
        'ID': region_ids,
        'path': pheno_file,
        'cov_path': cov_file,
        'cond': phenotypes
    })

    def find_genotype_path(chrom):
        match = geno_meta_data[geno_meta_data['#chr'] == chrom]
        if not match.empty:
            return match.iloc[0]['geno_path']
        else:
            raise ValueError(f"Chromosome {chrom} not found in geno_meta_data.")

    # Apply the function to find the genotype path for each row
    output_df['geno_path'] = output_df['#chr'].apply(find_genotype_path)
    return output_df    

# Load genotype meta data
if f"{genoFile:x}" == ".bed":
    geno_meta_data = pd.DataFrame([("chr"+str(x), f"{genoFile:a}") for x in range(1,23)] + [("chrX", f"{genoFile:a}")], columns=['#chr', 'geno_path'])
else:
    geno_meta_data = pd.read_csv(f"{genoFile:a}", sep = "\t", header=0)
    geno_meta_data.iloc[:, 1] = adapt_file_path_all(geno_meta_data, geno_meta_data.columns[1], f"{genoFile:a}")
    geno_meta_data.columns = ['#chr', 'geno_path']
    geno_meta_data['#chr'] = geno_meta_data['#chr'].apply(lambda x: str(x) if str(x).startswith('chr') else f'chr{x}')

# Checking the DataFrame
valid_chr_values = [f'chr{x}' for x in range(1, 23)] + ['chrX']
if not all(value in valid_chr_values for value in geno_meta_data['#chr']):
    raise ValueError("Invalid chromosome values found. Allowed values are chr1 to chr22 and chrX.")

region_ids = []
if region_list.is_file():
    region_list_df = pd.read_csv(region_list, sep=r'\s+', header=None, comment = "#")
    if region_list_df.shape[1] > 3:
        # More than 3 columns, extract the last column as ID
        region_ids = region_list_df.iloc[:, -1].unique()
    elif region_list_df.shape[1] == 3:
        # Exactly 3 columns, concatenate the values from the first 3 columns to form the ID
        region_ids = region_list_df.astype(str).apply(lambda x: '_'.join(x), axis=1).unique()
    else:
        raise ValueError(f"Region file ``{region_list}`` has fewer than 3 columns.")
else:
    raise ValueError(f"Region file ``{region_list}`` is not found!")
    
meta_data = make_meta_data(region_list_df, region_ids, geno_meta_data, phenoFile, covFile)

# Create the final dictionary
regional_data = {
    'data': [(row['geno_path'], *row['path'].split(','), *row['cov_path'].split(',')) for _, row in meta_data.iterrows()],
    'meta_info': [(f"{row['#chr']}:{row['start']}-{row['end']}", # this is the phenotype region
                   row['ID'], *row['cond'].split(',')) for _, row in meta_data.iterrows()]
}

In [None]:
[SuSiE_1]
# initial number of single effects for SuSiE
parameter: init_L = 8
# maximum number of single effects to use for SuSiE
parameter: max_L = 30
# remove a variant if it has more than imiss missing individual level data
parameter: imiss = 1.0
# MAF cutoff
parameter: maf = 0.005
# MAC cutoff, on top of MAF cutoff
parameter: mac = 5
# Remove indels if indel = False
parameter: indel = True
parameter: pip_cutoff = 0.1
parameter: coverage = [0.95, 0.7, 0.5]
depends: sos_variable("regional_data")

meta_info = regional_data['meta_info']
input: regional_data["data"], group_by = lambda x: group_by_region(x, regional_data["data"]), group_with = "meta_info"
output: f'{cwd:a}/{step_name[:-2]}/{name}.{_meta_info[1]}.susie.rds'
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
R: expand = '${ }', stdout = f"{_output:n}.stdout", stderr = f"{_output:n}.stderr", container = container, entrypoint = entrypoint
    # extract subset of samples
    keep_samples = NULL
    if (${"TRUE" if keep_samples.is_file() else "FALSE"}) {
      keep_samples = unlist(strsplit(readLines(${keep_samples:ar}), "\\s+"))
      message(paste(length(keep_samples), "samples are selected to be loaded for analysis"))
    }
    print(${_input[0]:anr}) # this is genotype
    print(${_input[1]:ar}) # phenotype file
    print(${_input[2]:ar}) # covariate
    print("${_meta_info[0]}") # region coordinate
    print("${_meta_info[1]}") # region ID
    print (c(${",".join(['"%s"' % x for x in _meta_info[2:]])})) # phenotypes to analyze
    stop("stop here for now")
    
    library(pecotmr)
    # Load genotype data
    tryCatch({
    fdat = load_regional_univariate_data(genotype = ${_input[0]:anr},
                                          phenotype = ${_input[1]:ar},
                                          covariate = ${_input[2]:ar},
                                          region = NULL,
                                          cis_window = "${_meta_info[0]}",
                                          conditions = c(${",".join(['"%s"' % x for x in _meta_info[2:]])}),
                                          maf_cutoff = ${maf},
                                          mac_cutoff = ${mac},
                                          imiss_cutoff = ${imiss},
                                          keep_indel = ${"TRUE" if indel else "FALSE"},
                                          keep_samples = keep_samples,
                                          scale_residuals = FALSE)
    }, NoSNPsError = function(e) {
        message("Error: ", e$message)
        #saveRDS(NULL, ${_output:ar})
        saveRDS(list("${_meta_info[1]}" = e$message), ${_output:ar}, compress='xz')
        quit(save="no")
    })
    # Univariate analysis suite
    fitted = list()
    for (r in seq_along(fdat$residual_Y)) {
      st = proc.time()
      fitted[[r]] = susie_wrapper(fdat$residual_X[[r]], fdat$residual_Y[[r]], ${init_L}, ${max_L}, ${coverage[0]})
      fitted[[r]] = susie_post_processor(fitted[[r]], fdat$residual_X[[r]], fdat$residual_Y[[r]], fdat$residual_X_scalar[[r]], fdat$residual_Y_scalar[[r]], 
                                       fdat$maf[[r]], secondary_coverage = c(${",".join([str(x) for x in coverage[1:]])}), signal_cutoff = ${pip_cutoff}, 
                                       other_quantities = list(dropped_samples = fdat$dropped_sample[[r]]))
      fitted[[r]]$total_time_elapsed = proc.time() - st
    }
    names(fitted) <- names(fdat$residual_Y)
    saveRDS(list("${_meta_info[1]}" = fitted), ${_output:ar}, compress='xz')