# GWAS integration: TWAS and MR

## Introduction

This module implements transcriptome-wide association analysis (TWAS) and selects variants to provide sparse signals for causal TWAS (cTWAS) analysis, following the multi-group cTWAS method described in Qian et al (2024+). It also performs Mendelian randomization (MR) using fine-mapped instrumental variables (IVs), as described in Zhang et al 2020, for "causal" effect estimation and model validation; the unit of analysis is a single gene-trait pair.

This procedure is a continuation of the SuSiE-TWAS workflow: it assumes that xQTL fine-mapping has been performed and that molecular trait prediction weights (to be used for TWAS) have been pre-computed. Cross-validation of the TWAS weights is optional but highly recommended.

The required GWAS data are GWAS summary statistics and an LD matrix for the region of interest.

### Step 1: TWAS 

1. Extract GWAS z-scores for the region of interest and the corresponding LD matrix.
2. (Optional) Perform allele-matching QC between the LD matrix and the summary statistics.
3. Process weights: for methods such as LASSO, Elastic Net and mr.ash, take the weights as is for QTL variants that overlap with GWAS variants. SuSiE weights can be adjusted to exactly match the GWAS variants.
4. Perform the TWAS test for multiple sets of weights.
5. For each gene, filter TWAS results by keeping the best model selected by CV. Drop the genes that don't show good evidence of TWAS prediction weights.
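
The TWAS test in step 4 can be illustrated with the widely used summary-statistics form of the test, z = (w'z_gwas) / sqrt(w'Rw), where w are the prediction weights, z_gwas the GWAS z-scores and R the LD matrix. The sketch below is a minimal pure-Python illustration under that assumption; the function names `twas_z` and `two_sided_pvalue` are hypothetical, and the `twas_analysis` function used later in this notebook may differ in its details.

```python
import math

def twas_z(weights, gwas_z, ld):
    """Summary-statistics TWAS z-score: (w' z) / sqrt(w' R w).

    Assumes weights, gwas_z and the rows/columns of ld refer to the same
    variants in the same order, with alleles already harmonized.
    """
    n = len(weights)
    if len(gwas_z) != n or len(ld) != n or any(len(row) != n for row in ld):
        raise ValueError("weights, gwas_z and ld must have matching dimensions")
    numerator = sum(w * z for w, z in zip(weights, gwas_z))
    variance = sum(weights[i] * ld[i][j] * weights[j]
                   for i in range(n) for j in range(n))
    if variance <= 0:
        # All-zero weights (or a degenerate LD matrix) give no usable test
        return float("nan")
    return numerator / math.sqrt(variance)

def two_sided_pvalue(z):
    """Two-sided p-value of a standard normal z-score."""
    return math.erfc(abs(z) / math.sqrt(2))

# Toy inputs for three variants (made-up numbers, for illustration only)
weights = [0.5, -0.2, 0.0]
gwas_z = [3.0, -1.0, 0.5]
ld = [[1.0, 0.3, 0.1],
      [0.3, 1.0, 0.2],
      [0.1, 0.2, 1.0]]

z = twas_z(weights, gwas_z, ld)
p = two_sided_pvalue(z)
print(f"TWAS z = {z:.3f}, p = {p:.3g}")
```

Running the same function once per weight column (one column per method) gives the "multiple sets of weights" results that step 5 then filters.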

### Step 2: Variant Selection for Imputable Genes via the Best Prediction Methods
1. Determine whether the gene is imputable in each context based on `twas_cv` performance: adjusted $r^2$ (>= 0.01) and p-value (< 0.05).
2. Each imputable gene-context pair goes through the variant selection step: at most 10 variants with the highest PIP are selected from either the `top_loci` table or the SuSiE credible sets (CS).
3. Harmonize weights against the LD reference and update the SuSiE weights.
4. Extract weights from the best model for the context, restricted to the variants selected in the previous step.
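
The imputability filter and variant selection above can be sketched as follows. This is an illustration only: the helper names and the input layout (a dictionary of per-method cross-validation results, and a dictionary of PIPs) are hypothetical, and picking the "best" method as the one with the highest adjusted $r^2$ among those passing the thresholds is an assumption; the actual logic lives in `twas_top_signals`.

```python
def is_imputable(adj_rsq, pval, rsq_threshold=0.01, p_value_cutoff=0.05):
    """A method passes if adjusted r^2 >= threshold and p-value < cutoff."""
    return adj_rsq >= rsq_threshold and pval < p_value_cutoff

def select_best_method(cv_performance, rsq_threshold=0.01, p_value_cutoff=0.05):
    """cv_performance maps method name -> (adj_rsq, pval).

    Returns the passing method with the highest adjusted r^2 (assumed
    criterion), or None if the context is not imputable by any method.
    """
    passing = {m: v for m, v in cv_performance.items()
               if is_imputable(v[0], v[1], rsq_threshold, p_value_cutoff)}
    if not passing:
        return None
    return max(passing, key=lambda m: passing[m][0])

def select_top_variants(pip_by_variant, max_var_select=10):
    """Keep at most max_var_select variants, ranked by decreasing PIP."""
    ranked = sorted(pip_by_variant.items(), key=lambda kv: kv[1], reverse=True)
    return [variant for variant, _ in ranked[:max_var_select]]

# Toy inputs (made-up numbers, for illustration only)
cv = {"susie": (0.12, 1e-8), "lasso": (0.15, 3e-9), "enet": (0.005, 0.2)}
pips = {"chr1:100:A:G": 0.9, "chr1:200:C:T": 0.2, "chr1:300:G:A": 0.6}

best = select_best_method(cv)
top = select_top_variants(pips, max_var_select=2)
print(best, top)
```

The default thresholds mirror the `rsq_threshold`, `p_value_cutoff` and `max_var_select` parameters of the `[twas_1]` step.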

### Step 3: cTWAS analysis

**FIXME: add more documentation here**

### Step 4: MR for candidate genes

1. Limit MR to genes that show some evidence of cTWAS significance AND have strong instrumental variables (by fine-mapping PIP or CS).
2. Use fine-mapped xQTL with GWAS data to perform MR.
3. For multiple IVs, aggregate individual IV estimates using a fixed-effect meta-analysis procedure.
4. Identify and exclude results with severe violations of the exclusion restriction (ER) assumption.
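
Steps 2 to 4 can be illustrated with a minimal sketch. It assumes a per-IV Wald ratio estimate with a first-order delta-method standard error, inverse-variance weights for the fixed-effect meta-analysis, and Cochran's Q as one possible heterogeneity diagnostic. The text above does not specify how ER violations are detected, so the Q statistic here is an assumption, and all function names are hypothetical.

```python
import math

def wald_ratio(beta_gwas, se_gwas, beta_qtl, se_qtl):
    """Per-IV causal effect estimate and its delta-method standard error.

    Assumes the GWAS and xQTL effect estimates are independent.
    """
    estimate = beta_gwas / beta_qtl
    se = math.sqrt(se_gwas ** 2 / beta_qtl ** 2
                   + beta_gwas ** 2 * se_qtl ** 2 / beta_qtl ** 4)
    return estimate, se

def fixed_effect_meta(estimates, ses):
    """Inverse-variance weighted fixed-effect pooled estimate and SE."""
    weights = [1.0 / se ** 2 for se in ses]
    pooled = sum(w * b for w, b in zip(weights, estimates)) / sum(weights)
    pooled_se = math.sqrt(1.0 / sum(weights))
    return pooled, pooled_se

def cochran_q(estimates, ses, pooled):
    """Cochran's Q: large values suggest heterogeneous IV estimates."""
    return sum((b - pooled) ** 2 / se ** 2 for b, se in zip(estimates, ses))

# Toy IVs as (beta_gwas, se_gwas, beta_qtl, se_qtl); made-up numbers
ivs = [(0.10, 0.02, 0.50, 0.05), (0.08, 0.02, 0.40, 0.05)]
per_iv = [wald_ratio(*iv) for iv in ivs]
estimates = [est for est, _ in per_iv]
ses = [se for _, se in per_iv]

pooled, pooled_se = fixed_effect_meta(estimates, ses)
q_stat = cochran_q(estimates, ses, pooled)
print(f"pooled = {pooled:.3f} (SE {pooled_se:.3f}), Q = {q_stat:.3g}")
```

In this toy example both IVs imply the same effect, so Q is essentially zero; IVs that disagree strongly would inflate Q and flag the gene-trait pair for closer inspection.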

## Input

### GWAS Data Input Interface (Similar to `susie_rss`)

I. **GWAS Summary Statistics Files**
- **Input**: Vector of files for one or more GWAS studies.
- **Format**: 
  - Tab-delimited files.
  - First 4 columns: `chr`, `pos`, `a0`, `a1`
  - Additional columns can be loaded using a column mapping file (see below).
- **Column Mapping files (optional)**:
  - Optional YAML file for custom column mapping.
  - Required columns: `chr`, `pos`, `a0`, `a1`, either `z` or (`betahat` and `sebetahat`).
  - Optional columns: `n`, `var_y` (relevant to fine-mapping).
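
As an illustration of the column requirements above, the sketch below renames the columns of one summary-statistics record and derives `z` from `betahat` and `sebetahat` when `z` is absent. The direction of the mapping (standard name to column name in the file) and the raw column names such as `CHR` and `BP` are assumptions made for this example; they are not taken from the actual loader.

```python
REQUIRED = ("chr", "pos", "a0", "a1")

def standardize_row(row, column_mapping):
    """Rename one record to standard column names and make sure `z` exists.

    row: dict of raw column name -> value.
    column_mapping: dict of standard name -> raw column name (assumed layout).
    """
    out = {std: row[raw] for std, raw in column_mapping.items() if raw in row}
    missing = [col for col in REQUIRED if col not in out]
    if missing:
        raise ValueError(f"Missing required columns: {', '.join(missing)}")
    if "z" in out:
        out["z"] = float(out["z"])
    elif "betahat" in out and "sebetahat" in out:
        # z-score is the effect estimate divided by its standard error
        out["z"] = float(out["betahat"]) / float(out["sebetahat"])
    else:
        raise ValueError("Need either `z` or both `betahat` and `sebetahat`")
    return out

# Hypothetical mapping and record, for illustration only
mapping = {"chr": "CHR", "pos": "BP", "a0": "REF", "a1": "ALT",
           "betahat": "BETA", "sebetahat": "SE"}
raw = {"CHR": "1", "BP": "12345", "REF": "A", "ALT": "G",
       "BETA": "0.06", "SE": "0.02"}

std = standardize_row(raw, mapping)
print(std["chr"], std["pos"], round(std["z"], 6))
```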

II. **GWAS Summary Statistics Meta-File**: optional, and helpful when many GWAS datasets are processed via the same command.
- **Columns**: `study_id`, chromosome number, path to summary statistics file, optional path to column mapping file.
- **Note**: Chromosome number `0` indicates a genome-wide file.

eg: `gwas_meta.tsv`

```
study_id    chrom    file_path                 column_mapping_file
study1      1        gwas1.tsv.gz         column_mapping.yml
study1      2        gwas2.tsv.gz         column_mapping.yml
study2      0        gwas3.tsv.gz         column_mapping.yml
```

If both summary statistics files (I) and a metadata file (II) are specified, we take the union of the two.


III. **LD Reference Metadata File**
- **Format**: Single TSV file.
- **Contents**:
  - Columns: `chr`, `start`, `end`, path to the LD matrix, genomic build.
  - LD matrix path format: comma-separated, first entry is the LD matrix, second is the bim file.
- **Documentation**: Refer to [our LD reference preparation document](https://cumc.github.io/xqtl-protocol/code/reference_data/ld_reference_generation.html) for detailed information.
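
The comma-separated LD matrix path field described above can be split as in this minimal sketch. The function name and the example file names are hypothetical; the actual parsing is done by the LD loader in pecotmr.

```python
def parse_ld_path(path_field):
    """Split the LD path field into the LD matrix file and the bim file.

    Expects exactly two comma-separated entries: LD matrix first, bim second.
    """
    parts = [part.strip() for part in path_field.split(",")]
    if len(parts) != 2 or not all(parts):
        raise ValueError(f"Expected '<ld_matrix>,<bim>' but got: {path_field!r}")
    return {"ld_matrix": parts[0], "bim": parts[1]}

# Hypothetical file names, for illustration only
ld_files = parse_ld_path("chr1_1000_5000.ld.gz, chr1_1000_5000.bim")
print(ld_files["ld_matrix"], ld_files["bim"])
```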

### Output of Fine-Mapping & TWAS Pipeline

**xQTL Weight Database Metadata File**: 
- **Essential columns**: `#chr`, `start`, `end`, `region_id`, `TSS`, `original_data`, `contexts`
- **Structure of the weight database**: 
  - RDS format.
  - Organized hierarchically: region → context → weight matrix.
  - Each column represents a different method.

eg: `xqtl_meta.tsv`

```
#chr start end region_id TSS original_data combined_data combined_data_sumstats contexts contexts_top_loci
chr1 0 6480000 ENSG00000008128 1724356 "KNIGHT_pQTL.ENSG00000008128.univariate_susie_twas_weights.rds, MiGA_eQTL.ENSG00000008128.univariate_susie_twas_weights.rds, MSBB_eQTL.ENSG00000008128.univariate_susie_twas_weights.rds, ROSMAP_Bennett_Klein_pQTL.ENSG00000008128.univariate_susie_twas_weights.rds, ROSMAP_DeJager_eQTL.ENSG00000008128.univariate_susie_twas_weights.rds, ROSMAP_Kellis_eQTL.ENSG00000008128.univariate_susie_twas_weights.rds, ROSMAP_mega_eQTL.ENSG00000008128.univariate_susie_twas_weights.rds, STARNET_eQTL.ENSG00000008128.univariate_susie_twas_weights.rds" Fungen_xQTL.ENSG00000008128.cis_results_db.export.rds Fungen_xQTL.ENSG00000008128.cis_results_db.export_sumstats.rds Knight_eQTL_brain,MiGA_GFM_eQTL,MiGA_GTS_eQTL,MiGA_SVZ_eQTL,MiGA_THA_eQTL,BM_10_MSBB_eQTL,BM_22_MSBB_eQTL,BM_36_MSBB_eQTL,BM_44_MSBB_eQTL,monocyte_ROSMAP_eQTL,Mic_DeJager_eQTL,Ast_DeJager_eQTL,Oli_DeJager_eQTL,Exc_DeJager_eQTL,Inh_DeJager_eQTL,DLPFC_DeJager_eQTL,PCC_DeJager_eQTL,AC_DeJager_eQTL,Mic_Kellis_eQTL,Ast_Kellis_eQTL,Oli_Kellis_eQTL,OPC_Kellis_eQTL,Exc_Kellis_eQTL,Inh_Kellis_eQTL,Ast_mega_eQTL,Exc_mega_eQTL,Inh_mega_eQTL,Oli_mega_eQTL,STARNET_eQTL_Mac Knight_eQTL_brain,MiGA_GFM_eQTL,MiGA_GTS_eQTL,MiGA_SVZ_eQTL,MiGA_THA_eQTL,BM_10_MSBB_eQTL,BM_22_MSBB_eQTL,BM_36_MSBB_eQTL,BM_44_MSBB_eQTL,monocyte_ROSMAP_eQTL,Mic_DeJager_eQTL,Ast_DeJager_eQTL,Oli_DeJager_eQTL,Exc_DeJager_eQTL,Inh_DeJager_eQTL,DLPFC_DeJager_eQTL,PCC_DeJager_eQTL,AC_DeJager_eQTL,Mic_Kellis_eQTL,Ast_Kellis_eQTL,Oli_Kellis_eQTL,OPC_Kellis_eQTL,Exc_Kellis_eQTL,Inh_Kellis_eQTL,Ast_mega_eQTL,Exc_mega_eQTL,Inh_mega_eQTL,Oli_mega_eQTL,STARNET_eQTL_Mac
```

This file is automatically generated as part of the FunGen-xQTL protocol, although only the essential columns are relevant to our application here.


### TWAS region information

This is required for cTWAS analysis, where multiple TWAS results and SNP data within each region are combined for joint inference. The goal is to select the variables, either genes or SNPs, that are likely to be directly associated with the phenotype of interest, rather than associated through correlation with the true causal variables.

```
chrom    start    end    block_id  
1        1000     5000   block1    
2        2000     6000   block2
3        3000     7000   block3
```
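
In this notebook, `extract_regional_data` assigns a gene to a region when the gene's TSS falls within the region boundaries on the same chromosome. A minimal sketch of that rule, using plain tuples instead of the metadata tables (the function name and toy records are hypothetical):

```python
def genes_in_block(block, genes):
    """Return IDs of genes whose TSS lies inside the block (inclusive bounds).

    block: (chrom, start, end)
    genes: iterable of (gene_id, chrom, tss)
    """
    chrom, start, end = block
    return [gene_id for gene_id, gene_chrom, tss in genes
            if gene_chrom == chrom and start <= tss <= end]

# Toy records, for illustration only
blocks = {"block1": ("chr1", 1000, 5000), "block2": ("chr2", 2000, 6000)}
genes = [("geneA", "chr1", 1500), ("geneB", "chr1", 7000), ("geneC", "chr2", 6000)]

assignment = {block_id: genes_in_block(block, genes)
              for block_id, block in blocks.items()}
print(assignment)
```

Blocks that end up with no genes correspond to the regions that the `[twas_1]` step skips because no overlapping xQTL weights are found.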

## Output

I. A table with the following contents

```
gwas_study, chrom, block, gene, context, method, rsq_adj_cv, pval_cv, is_selected_method, twas_z
```

where

- if `twas_z` is `NA` it means the context is not imputable for the method of choice

II. A list of `refined_twas_db` per block, in RDS format, with this structure:

```
$ region_id
    $ context
        $ selected_method
        $ selected_method_weights
        $ selected_top_variants
```

This will only contain imputable contexts. It should come with a metadata file like this:

```
chrom    start    end    block_id  refined_twas_db
1        1000     5000   block1    block1.rds
2        2000     6000   block2    block2.rds
3        3000     7000   block3    block3.rds
```

III. cTWAS and MR results

TBD

## Example
```
sos run /home/cl4215/githubrepo/xqtl-protocol/code/pecotmr_integration/twas.ipynb twas \
   --cwd /mnt/vast/hpc/csg/cl4215/mrmash/workflow/ \
   --gwas_meta_data /mnt/vast/hpc/csg/cl4215/mrmash/workflow/twas_ctwas/gwas/gwas_meta.tsv \
   --ld_meta_data /mnt/vast/hpc/csg/data_public/20240409_ADSP_LD_matrix/ld_meta_file.tsv \
   --regions /mnt/vast/hpc/csg/cl4215/mrmash/workflow/twas_mr/pipeline/EUR_LD_blocks_CLU.bed \
   --xqtl_meta_data /mnt/vast/hpc/csg/cl4215/mrmash/workflow/susie_twas/Fungen_xQTL.cis_results_db_TSS.test.tsv \
   --max_var_select 10 --p_value_cutoff 0.05 --rsq_threshold 0.01 -s force
```

In [None]:
[global]
parameter: cwd = path("output/")
parameter: gwas_meta_data = path()
parameter: xqtl_meta_data = path()
parameter: ld_meta_data = path()
parameter: gwas_name = []
parameter: gwas_data = []
parameter: column_mapping = []
parameter: regions = path()
parameter: name = f"{xqtl_meta_data:bn}.{gwas_meta_data:bn}"
parameter: container = ''
import re
parameter: entrypoint= ('micromamba run -a "" -n' + ' ' + re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])) if container else ""
parameter: job_size = 100
parameter: walltime = "5m"
parameter: mem = "8G"
parameter: numThreads = 1

import os
import pandas as pd

def adapt_file_path(file_path, reference_file):
    """
    Adapt a single file path based on its existence and a reference file's path.

    Args:
    - file_path (str): The file path to adapt.
    - reference_file (str): File path to use as a reference for adaptation.

    Returns:
    - str: Adapted file path.

    Raises:
    - FileNotFoundError: If no valid file path is found.
    """
    reference_path = os.path.dirname(reference_file)

    # Check if the file exists
    if os.path.isfile(file_path):
        return file_path

    # Check file name without path
    file_name = os.path.basename(file_path)
    if os.path.isfile(file_name):
        return file_name

    # Check file name in reference file's directory
    file_in_ref_dir = os.path.join(reference_path, file_name)
    if os.path.isfile(file_in_ref_dir):
        return file_in_ref_dir

    # Check original file path prefixed with reference file's directory
    file_prefixed = os.path.join(reference_path, file_path)
    if os.path.isfile(file_prefixed):
        return file_prefixed

    # If all checks fail, raise an error
    raise FileNotFoundError(f"No valid path found for file: {file_path}")

def group_by_region(lst, partition):
    # from itertools import accumulate
    # partition = [len(x) for x in partition]
    # Compute the cumulative sums once
    # cumsum_vector = list(accumulate(partition))
    # Use slicing based on the cumulative sums
    # return [lst[(cumsum_vector[i-1] if i > 0 else 0):cumsum_vector[i]] for i in range(len(partition))]
    return partition

In [None]:
[get_analysis_regions: shared = "regional_data"]
from collections import OrderedDict

def check_required_columns(df, required_columns):
    """Check if the required columns are present in the dataframe."""
    missing_columns = [col for col in required_columns if col not in list(df.columns)]
    if missing_columns:
        raise ValueError(f"Missing required columns: {', '.join(missing_columns)}")

def extract_regional_data(gwas_meta_data, xqtl_meta_data, regions, gwas_name, gwas_data, column_mapping):
    """
    Extracts data from GWAS and xQTL metadata files and additional GWAS data provided. 

    Args:
    - gwas_meta_data (str): File path to the GWAS metadata file.
    - xqtl_meta_data (str): File path to the xQTL weight metadata file.
    - regions (str): File path to the TWAS region (LD block) file.
    - gwas_name (list): vector of GWAS study names.
    - gwas_data (list): vector of GWAS data.
    - column_mapping (list, optional): vector of column mapping files.

    Returns:
    - Tuple of three dictionaries:
        - GWAS Dictionary: Maps study IDs to a dictionary keyed by chromosome, whose values are
          lists containing the GWAS file path and optional column mapping file path.
        - xQTL Dictionary: Nested dictionary with region IDs as keys.
        - Regions Dictionary: Nested dictionary with LD region IDs as keys.

    Raises:
    - FileNotFoundError: If any specified file path does not exist.
    - ValueError: If required columns are missing in the input files or vector lengths mismatch.
    """
    # Check vector lengths
    if len(gwas_name) != len(gwas_data):
        raise ValueError("gwas_name and gwas_data must be of equal length")
    
    if len(column_mapping)>0 and len(column_mapping) != len(gwas_name):
        raise ValueError("If column_mapping is provided, it must be of the same length as gwas_name and gwas_data")

    # Required columns for each file type
    required_gwas_columns = ['study_id', 'chrom', 'file_path']
    required_xqtl_columns = ['region_id', '#chr', 'start', 'end', "TSS", 'original_data'] #region_id here is gene name
    required_ld_columns = ['chr', 'start', 'stop']
    
    # Reading the GWAS metadata file
    gwas_df = pd.read_csv(gwas_meta_data, sep="\t")
    check_required_columns(gwas_df, required_gwas_columns)
    gwas_dict = OrderedDict()
    
    # Reading LD regions info
    regions_df = pd.read_csv(regions, sep="\t", skipinitialspace=True)
    regions_df.columns = [col.strip() for col in regions_df.columns]  # Strip spaces from column names before checking them
    check_required_columns(regions_df, required_ld_columns)  # Validate before accessing the 'chr' column
    regions_df['chr'] = regions_df['chr'].str.strip()  # Remove stray spaces in the 'chr' column
    #regions_df['chr'] = regions_df['chr'].str.strip().str.replace('chr', '').astype(int) # remove spaces in the 'chr' column and convert to integer
    regions_dict = OrderedDict()
    
    # Reading the xQTL weight metadata file
    xqtl_df = pd.read_csv(xqtl_meta_data, sep=" ")
    check_required_columns(xqtl_df, required_xqtl_columns)
    xqtl_dict = OrderedDict()

    # Process additional GWAS data from R vectors
    for name, data, mapping in zip(gwas_name, gwas_data, column_mapping or [None]*len(gwas_name)):
        gwas_dict[name] = {0: [data, mapping]}

    for _, row in gwas_df.iterrows():
        file_path = row['file_path']
        mapping_file = row.get('column_mapping_file')
        
        # Adjust paths if necessary
        file_path = adapt_file_path(file_path, gwas_meta_data)
        # The column mapping file is optional and may be empty (NaN) for some rows
        if mapping_file is not None and pd.notna(mapping_file):
            mapping_file = adapt_file_path(mapping_file, gwas_meta_data)

        # Create or update the entry for the study_id
        if row['study_id'] not in gwas_dict:
            gwas_dict[row['study_id']] = {}

        # Expand chrom 0 to chrom 1-22 or use the specified chrom
        chrom_range = range(1, 23) if row['chrom'] == 0 else [row['chrom']]
        for chrom in chrom_range:
            # Entries are keyed as 'chr<N>', so the duplicate check must use the same key
            chrom_key = f'chr{chrom}'
            if chrom_key in gwas_dict[row['study_id']]:
                existing_entry = gwas_dict[row['study_id']][chrom_key]
                raise ValueError(f"Duplicate chromosome specification for study_id {row['study_id']}, chrom {chrom}. "
                                 f"Conflicting entries: {existing_entry} and {[file_path, mapping_file]}")
            gwas_dict[row['study_id']][chrom_key] = [file_path, mapping_file]
            
    for _, row in regions_df.iterrows():
        LD_region_id = f"{row['chr']}_{row['start']}_{row['stop']}"
        overlapping_xqtls = xqtl_df[(xqtl_df['#chr'] == row['chr']) & 
                                     (xqtl_df['TSS'] <= row['stop']) & 
                                     (xqtl_df['TSS'] >= (row['start']))]
        file_paths = []
        mapped_genes = []
        # Collect file paths for xQTLs overlapping this region
        for _, xqtl_row in overlapping_xqtls.iterrows():
            original_data = xqtl_row['original_data']
            file_list = original_data.split(',') if ',' in original_data else [original_data]
            file_paths.extend([adapt_file_path(fp.strip(), xqtl_meta_data) for fp in file_list])
            mapped_genes.append(xqtl_row['region_id'])

        # Store metadata and files in the dictionary
        regions_dict[LD_region_id] = {
            "meta_info": [row['chr'], row['start'], row['stop'], LD_region_id, mapped_genes],
            "files": file_paths
        }
        
    for _, row in xqtl_df.iterrows():
        file_paths = [adapt_file_path(fp.strip(), xqtl_meta_data) for fp in row['original_data'].split(',')]  # Splitting and stripping file paths
        xqtl_dict[row['region_id']] = {"meta_info": [row['#chr'], row['start'], row['end'], row['region_id'], row['contexts']],
                                       "files": file_paths}
    return gwas_dict, xqtl_dict, regions_dict


gwas_dict, xqtl_dict, regions_dict = extract_regional_data(gwas_meta_data, xqtl_meta_data,regions,gwas_name, gwas_data, column_mapping)
regional_data = dict([("GWAS", gwas_dict), ("xQTL", xqtl_dict), ("Regions", regions_dict)])

**FIXME: please add documentation for each parameter in this format:**

```
# documentation for this parameter
parameter: p_value_cutoff
```

**FIXME** what we need to do for this pipeline:

1. Make it work for the input and output formats I designed. Follow them 100%, but if anything feels awkward, reach out to me to discuss rather than making decisions on your own.
2. Change all "context" to "context" (unify the terminology).
3. for the new `[twas]` pipeline, we are adding an extra layer of loop to the logic below, which now should look like this:

```
for each block:
-- refined_twas_weights_data = list()
-- for each gene:
---- imputable_contexts = list() # $context: method1, method2, etc
---- for each context:
------ update_imputable_contexts()
---- if len(imputable_contexts) > 0:
------ twas_weights_data = load_twas_weights( ... ) # this is pecotmr function, which is part of my original design in twas_mr pipeline
------ update_refined_twas_weights() # see section "Output" above for what I expect of the refined database for TWAS, involving extracting information from the best method, and also select top variants for cTWAS
------ for each gwas_study:
-------- load_and_handle_gwas_data()
-------- twas_analysis_for_all_contexts_and_methods(raw_twas_weights_data) # For contexts that are not imputable simply put twas_z to NA
```

4. for the new `[ctwas]` step, we combine it with MR as `[ctwas_mr]`

```
---- for each gwas_study:
------ ctwas_wrapper(refined_twas_weights_meta_data, ...)
------ get causal gene context from cTWAS output
------ perform MR on these causal gene in respective contexts
```

In [None]:
[twas_1]
depends: sos_variable("regional_data")
parameter: allele_qc = True
parameter: coverage = "cs_coverage_0.95"
parameter: max_var_select = 10
parameter: p_value_cutoff = 0.05
parameter: rsq_threshold = 0.01

region_info = [x["meta_info"] for x in regional_data['Regions'].values()]
regional_xqtl_files = [x["files"] for x in regional_data['Regions'].values()]

# Filter out empty xQTL file paths
filtered_region_info = []
filtered_regional_xqtl_files = []
skipped_regions =[]

for region, files in zip(region_info, regional_xqtl_files):
    if files:
        filtered_region_info.append(region)
        filtered_regional_xqtl_files.append(files)
    else:
        skipped_regions.append(region)

print(f"Skipping {len(skipped_regions)} out of {len(regional_xqtl_files)} regions, no overlapping xQTL weights found. ")

input: filtered_regional_xqtl_files, group_by = lambda x: group_by_region(x, filtered_regional_xqtl_files), group_with = "filtered_region_info"
output: f"{cwd:a}/{step_name}/{name}.{_filtered_region_info[3]}.twas.rds"#, f"{cwd:a}/{step_name}/{name}.{_filtered_region_info[3]}.twas_table.tsv"
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
R: expand = '${ }', stdout = f"{_output:n}.stdout", stderr = f"{_output:n}.stderr", container = container, entrypoint = entrypoint

    #library(pecotmr)
    library(dplyr)
    devtools::load_all("/home/cl4215/githubrepo/pecotmr/R") #edit on file_utiles.R and ctwas_wrapper.R
    
    LD_meta_file_path = "${ld_meta_data}"
    region_block <- unlist(strsplit("${_filtered_region_info[3]}", "_\\s*"))
    chrom = as.integer(readr::parse_number(region_block[1]))
    start = as.integer(region_block[2])
    end = as.integer(region_block[3])
    region_of_interest <-  data.frame(chrom = chrom, start = start, end = end)

    # get xQTL weight information 
    xqtl_meta_df <- data.table::fread("${xqtl_meta_data}") # Get related gene information from the xqtl_meta data table
    gwas_studies = c(${paths(regional_data["GWAS"].keys()):r,})
    gwas_files = c(${paths([v[_filtered_region_info[0]] for k, v in regional_data["GWAS"].items()]):r,})
    gene_list <- c(${', '.join([f"'{gene}'" for gene in _filtered_region_info[4]])}) 
  
    # Expand region_of_interest for included genes to load corresponding GWAS regions
    region_xqtl_meta <- xqtl_meta_df[xqtl_meta_df$region_id %in% gene_list, ]
    min_start <- min(region_xqtl_meta$start)
    max_end <- max(region_xqtl_meta$end)
    expanded_region <- data.frame(chrom = chrom, start = min_start, end = max_end)
  
    # Process TWAS results
    refined_twas_weights_data = list() #data hierarchy: refined_twas_weights_data[[region_id]][[gene]][[context]]
    gwas_data = list()
    twas_weights_results <- list() # store imputability status information for all weight db 
    twas_results <- list()

    # Step 1: Assess Imputability
    # load weight db files, we have potentially multiple weight db RDS files for each region of interest
    weight_db_list <- c(${_input:r,})
    # remove weight db files that only contain messages: "No SNPs found in the specified region ..."
    weight_db_list_update <- do.call(c, lapply(weight_db_list, function(file) if (file.size(file) > 200) file))
    if(length(weight_db_list_update)<=0) stop(paste0("No weight information available from ", paste0(weight_db_list, collapse=","), ". "))
    for (weight_db in weight_db_list_update){
        # Determine imputability and select top variants for this weight db
        twas_weights_results[[weight_db]] = twas_top_signals(weight_db, contexts=NULL, variable_name_obj=c("preset_variants_result", "variant_names"), twas_weights_table = "twas_weights", 
                                                            max_var_selection=${max_var_select}, min_rsq_threshold = ${rsq_threshold}, p_val_cutoff = ${p_value_cutoff})
    }
  
    # Step 2: load and handle GWAS data for each GWAS study
    for (s in seq_along(gwas_studies)){
        # load gwas data file for this particular chrom
        gwas_sumstats <- load_rss_data(gwas_files[s], gwas_files[length(gwas_files)])$sumstats
        # Load LD list containing LD matrix and corresponding variants
        gwas_LD_list <- load_LD_matrix(LD_meta_file_path, expanded_region, gwas_sumstats)
        # remove duplicate variants
        dup_idx <- which(duplicated(gwas_LD_list$combined_LD_variants))
        if (length(dup_idx)>=1){
          gwas_LD_list$combined_LD_variants <- gwas_LD_list$combined_LD_variants[-dup_idx] 
          gwas_LD_list$combined_LD_matrix <- gwas_LD_list$combined_LD_matrix[-dup_idx, -dup_idx] 
          gwas_LD_list$ref_panel <- gwas_LD_list$ref_panel[-dup_idx,]
        }
        # Allele flip
        gwas_allele_flip <- allele_qc(gwas_sumstats[, c("chrom", "pos", "A2", "A1")], gwas_LD_list$combined_LD_variants, gwas_sumstats, c("beta", "se", "z"))
        # Load LD matrix and sumstats
        gwas_data[[gwas_studies[s]]][["LD"]] <- gwas_LD_list$combined_LD_matrix
        gwas_data[[gwas_studies[s]]][["sumstats"]] <- gwas_allele_flip$target_data_qced
        rm(gwas_sumstats, gwas_allele_flip)# free up memory
        gc()
        
        for (weight_db in weight_db_list_update) {
            gene <- twas_weights_results[[weight_db]]$gene
            xqtl_contexts <- names(twas_weights_results[[weight_db]]$model_selection)# per weight_db level contexts
            imputable_contexts <- do.call(c, lapply(xqtl_contexts, function(context) if (twas_weights_results[[weight_db]]$model_selection[[context]]$imputable) context))
  
            if(length(imputable_contexts)<=0){
               print(paste0("All contexts for gene ", gene, " at weight db file ", weight_db, " are not imputable. "))
            } else {
               # extract corresponding gene-level region post-QC GWAS data 
               all_variants_gene <- gsub("chr", "", unique(do.call(c, lapply(xqtl_contexts, function(context) rownames(twas_weights_results[[weight_db]]$weights[[context]])))))
               gwas_indx <- na.omit(match(all_variants_gene, gwas_data[[gwas_studies[s]]][["sumstats"]]$variant_id))
               gwas_data_sumstat <- gwas_data[[gwas_studies[s]]][["sumstats"]][gwas_indx, ]
               gwas_data_LD <- gwas_data[[gwas_studies[s]]][["LD"]][gwas_indx, gwas_indx]
  
               for (context in xqtl_contexts) {
                  refined_twas_weights_data[["${_filtered_region_info[3]}"]][[gene]][[context]] <- list(selected_model=twas_weights_results[[weight_db]]$model_selection[[context]]$selected_model)
                  refined_twas_weights_data[["${_filtered_region_info[3]}"]][[gene]][[context]][["is_imputable"]] <- twas_weights_results[[weight_db]]$model_selection[[context]]$imputable

                  # Step 3: Intersect with gwas summary statistics and adjust susie weights 
                  adjusted_susie_weights <- adjust_susie_weights(twas_weights_results[[weight_db]], context,
                    keep_variants = gwas_data_sumstat$variant_id,
                    allele_qc = ${"TRUE" if allele_qc else "FALSE"}
                  )
                  # Step 4: Overlap weights of other methods with the variants name of adjusted_susie_weights, 
                  # then combine with adjusted susie weights to obtain the subsetted weight matrix
                  weights_matrix <- get_nested_element(twas_weights_results[[weight_db]], c("weights", context))
                  weights_matrix_subset <- cbind(
                    susie_weights = adjusted_susie_weights$adjusted_susie_weights,
                    weights_matrix[adjusted_susie_weights$remained_variants_ids, !colnames(weights_matrix) %in% "susie_weights"]
                  )

                  if(${"TRUE" if allele_qc else "FALSE"}){
                    weights_matrix_qced <- allele_qc(rownames(weights_matrix_subset), gwas_LD_list$combined_LD_variants,
                      weights_matrix_subset, 1:ncol(weights_matrix),
                      target_gwas = FALSE
                    )
                    weights_matrix_subset <- weights_matrix_qced$target_data_qced[, !colnames(weights_matrix_qced$target_data_qced) %in% c("chrom", "pos", "A2", "A1", "variant_id")]
                    rownames(weights_matrix_subset) <- get_nested_element(weights_matrix_qced, c("target_data_qced", "variant_id"))
                  }
                  # Step 4: TWAS analysis
                  twas_result <- twas_analysis(
                    weights_matrix_subset, gwas_data_sumstat, gwas_data_LD,
                    rownames(weights_matrix_subset)
                  )
                  twas_results[[gene]][[context]][[gwas_studies[s]]]<- twas_result

                  # Step 5: Output selected twas weights from best model and selected variants
                  model <- get_nested_element(twas_weights_results[[weight_db]], c("model_selection", context))$selected_model
                  if(!is.null(model)) {
                      variants_picked <- twas_weights_results[[weight_db]][["variant_selection"]][[context]]
                      selec_indx <- match(gsub("chr", "",variants_picked), rownames(weights_matrix_subset))
                      ctwas_weights <- weights_matrix_subset[selec_indx, paste0(model, "_weights")]
                      names(ctwas_weights) <- variants_picked
                  } else {
                      ctwas_weights <- rep(NA, ${max_var_select})
                  }

                  # Add selected variants and weights information to the weights 
                  refined_twas_weights_data[["${_filtered_region_info[3]}"]][[gene]][[context]][["selected_top_variants"]] <- get_nested_element(twas_weights_results[[weight_db]], 
                                                                                                                                              c("variant_selection", context))
                  refined_twas_weights_data[["${_filtered_region_info[3]}"]][[gene]][[context]][["selected_model_weights"]] <-  ctwas_weights
               }
                
            }
        }
    }
    saveRDS(refined_twas_weights_data, ${_output:r}, compress='xz')
    
    # Generate TWAS Table based on refined_twas_weights_data
    genes <- names(refined_twas_weights_data[["${_filtered_region_info[3]}"]])
    twas_table <- do.call(rbind, lapply(genes, function(gene){
        contexts <- names(refined_twas_weights_data[["${_filtered_region_info[3]}"]][[gene]])
        # merge twas_cv information for same gene across all weight db files
        cv_data <- do.call(c, lapply(unname(weight_db_list_update), function(file){ if(twas_weights_results[[file]]$gene == gene) twas_weights_results[[file]]$cv_performance}))
        TSS <- xqtl_meta_df$TSS[xqtl_meta_df$region_id==gene]
        start <- xqtl_meta_df$start[xqtl_meta_df$region_id==gene]
        end <- xqtl_meta_df$end[xqtl_meta_df$region_id==gene]
  
        twas_table_gene <- data.frame()
        for (context in contexts){
          context_table <- data.frame()
          methods <- gsub( "_.*$", "", names(cv_data[[context]]))
          is_imputable = refined_twas_weights_data[["${_filtered_region_info[3]}"]][[gene]][[context]][["is_imputable"]]
          selected_method=refined_twas_weights_data[["${_filtered_region_info[3]}"]][[gene]][[context]]$selected_model
          if(is.null(selected_method)) selected_method <- NA 

          context_table <- do.call(rbind, lapply(gwas_studies, function(study){
              twas_zs <- c();twas_pvals <-c(); cv_rsqs<-c();cv_pvals <-c();is_selected_method<- c()
              for (method in methods){
                 cv_rsqs <- c(cv_rsqs, cv_data[[context]][[paste0(method, "_performance")]][, "adj_rsq"])
                 cv_pvals <- c(cv_pvals, cv_data[[context]][[paste0(method, "_performance")]][, "adj_rsq_pval"])
                 is_selected_method <- c(is_selected_method, ifelse(method==selected_method, TRUE, FALSE))
                 if (!is.null(names(twas_results[[gene]][[context]]))){
                     twas_zs <- c(twas_zs, twas_results[[gene]][[context]][[study]][[paste0(method, "_weights")]]$z)
                     twas_pvals <- c(twas_pvals, twas_results[[gene]][[context]][[study]][[paste0(method, "_weights")]]$pval)
                 } else {
                     twas_zs <- twas_pvals <- NA
                 }
              }
              data.frame(study=study, method=methods, rsq_adj_cv=cv_rsqs, pval_cv=cv_pvals, twas_z = twas_zs, twas_pval=twas_pvals, 
                          is_selected_method= is_selected_method)
           }))
           context_table$context=context
           context_table$is_imputable=is_imputable
           twas_table_gene <- rbind(twas_table_gene, context_table)
        }
        twas_table_gene$gene=gene
        twas_table_gene$chr=chrom
        twas_table_gene$start=start
        twas_table_gene$end=end
        twas_table_gene$TSS=TSS
        twas_table_gene$block="${_filtered_region_info[3]}"
        return(twas_table_gene)
     }))
    colname_ordered <- c("chr", "start", "end", "gene","TSS", "context", "study", "method", "is_imputable", "is_selected_method", "rsq_adj_cv", "pval_cv", 
                      "twas_z", "twas_pval", "block")
    write.table(twas_table[, colname_ordered], file="${_output:ann}.twas_table.tsv", sep="\t", row.names=FALSE, col.names=TRUE, quote=FALSE)

In [None]:
# [ctwas_mr]
# Format extracted ctwas weights, apply LD loader for ctwas, perform ctwas multigroup analysis. 
# depends: sos_variable("regional_data")
# # update records on the same gene_study on the summary table
# # add new record for new gene_study
# meta_info = [x["meta_info"] for x in regional_data['xQTL'].values()]
# xqtl_files = [x["files"] for x in regional_data['xQTL'].values()]
# xqtl_name = os.path.splitext(name)[0].rsplit('_', 1)[0]
# step_name = step_name.rsplit('_', 1)[0]
# input: f"{cwd:a}/{step_name}/{xqtl_name}.*.meta_table", group_by='all'
# output: f"{cwd:a}/{step_name}/{xqtl_name}.summary_table.tsv" 
# task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
# R: expand = '${ }', stdout = f"{cwd:a}/{step_name}/{xqtl_name}.summary_table.stdout", stderr = f"{cwd:a}/{step_name}/{xqtl_name}.summary_table.stderr", container = container, entrypoint = entrypoint


#     library(stringr)
#     file_paths <- c(${_input:r,})
#     summary_list <- lapply(file_paths, function(file)read.table(file, sep="\t", header=TRUE))
#     summary_table <- data.table::rbindlist(summary_list, use.names=TRUE)
#     write.table(summary_table, ${_output:r}, sep="\t", col.names=TRUE, row.names=FALSE, quote=FALSE)
#     data type adapt to various type of context format spelling 
#     matches <- str_extract_all(contexts, paste("${data_type_list}", collapse = "|"), simplify = TRUE)
#     data_type <- unique(tolower(matches)[1,])