# GWAS integration: TWAS and MR

## Introduction

This module provides software implementations for transcriptome-wide association studies (TWAS) and Mendelian randomization (MR) using fine-mapped instrumental variables (IVs). The MR procedure follows Zhang et al. 2020 for "causal" effect estimation and model validation, with a single gene-trait pair as the unit of analysis.

This procedure builds on the SuSiE-TWAS workflow. It assumes that xQTL fine-mapping has already been performed (used for both TWAS and MR) and that molecular trait prediction weights have been pre-computed (used for TWAS). Cross-validation of the TWAS weights is optional but highly recommended.

The required GWAS data are summary statistics and an LD matrix for the region of interest.

### Step 1: TWAS 

1. Extract GWAS z-scores for the region of interest and the corresponding LD matrix.
2. (Optional) Perform allele-matching QC between the LD matrix and the summary statistics.
3. Process weights: LASSO, Elastic Net and mr.ash weights can only be used as-is, restricted to the QTL variants that overlap GWAS variants. SuSiE weights, in contrast, can be adjusted to match the GWAS variants exactly.
4. Perform the TWAS test for each set of weights.
5. For each gene, keep the TWAS result from the best model selected by cross-validation (CV). Drop genes without good evidence of predictive TWAS weights.
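Steps 2 to 4 can be illustrated with the standard TWAS z-score $z_{\text{TWAS}} = w^\top z / \sqrt{w^\top R w}$, where $w$ are the prediction weights, $z$ the GWAS z-scores and $R$ the LD correlation matrix. The sketch below is a minimal pure-Python illustration, not this pipeline's implementation. The function names, the `chr:pos` variant-ID format and the simple allele-flip rule (strand-ambiguous SNPs are not handled) are assumptions.

```python
import math

def align_weights(weights, qtl_alleles, gwas_alleles):
    """Keep weights for variants shared with the GWAS, flipping sign on swapped alleles.

    weights:      {variant_id: weight}
    qtl_alleles:  {variant_id: (a0, a1)} alleles the weights refer to
    gwas_alleles: {variant_id: (a0, a1)} alleles the GWAS z-scores refer to
    (variant_id format and allele tuples are illustrative assumptions)
    """
    aligned = {}
    for vid, w in weights.items():
        if vid not in gwas_alleles or vid not in qtl_alleles:
            continue  # not shared with the GWAS: weight is dropped
        q, g = qtl_alleles[vid], gwas_alleles[vid]
        if q == g:
            aligned[vid] = w
        elif q == (g[1], g[0]):
            aligned[vid] = -w  # effect allele swapped: flip the weight's sign
        # any other allele combination is incompatible and dropped
    return aligned

def twas_z(weights, z, ld, variant_ids):
    """TWAS z = w'z / sqrt(w'Rw); z and ld are indexed in the order of variant_ids."""
    idx = [i for i, v in enumerate(variant_ids) if v in weights]
    w = [weights[variant_ids[i]] for i in idx]
    numerator = sum(wi * z[i] for wi, i in zip(w, idx))
    variance = sum(
        w[a] * ld[idx[a]][idx[b]] * w[b]
        for a in range(len(idx))
        for b in range(len(idx))
    )
    if variance <= 0:
        raise ValueError("weights have zero variance under the LD matrix")
    return numerator / math.sqrt(variance)
```

The sign flip matters because a weight and a z-score must refer to the same effect allele. Otherwise the numerator $w^\top z$ mixes opposite directions of effect.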

### Step 2: MR for candidate genes

1. Restrict MR to genes that show some evidence of TWAS significance AND have strong instrumental variables (high fine-mapping PIP or membership in a credible set (CS)).
2. Use the fine-mapped xQTL together with the GWAS data to perform MR.
3. When there are multiple IVs, aggregate the individual IV estimates using a fixed-effect meta-analysis.
4. Identify and exclude results with severe violations of the exclusion restriction (ER) assumption.
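Items 2 and 3 can be sketched with the Wald ratio estimator per instrument and an inverse-variance-weighted fixed-effect meta-analysis across instruments. The sketch assumes a ratio standard error from the delta method that accounts for uncertainty in both the xQTL and the GWAS effect, treating the two as independent. The exact estimators used by this pipeline may differ. Cochran's $Q$ and $I^2$ quantify heterogeneity across instruments. Large values can point to ER violations such as horizontal pleiotropy (item 4), but the exclusion threshold is not specified here.

```python
import math

def wald_ratio(beta_qtl, se_qtl, beta_gwas, se_gwas):
    """Single-IV causal estimate beta_gwas / beta_qtl with a delta-method SE."""
    estimate = beta_gwas / beta_qtl
    se = math.sqrt(se_gwas**2 / beta_qtl**2 + beta_gwas**2 * se_qtl**2 / beta_qtl**4)
    return estimate, se

def fixed_effect_meta(estimates, ses):
    """Inverse-variance-weighted fixed-effect meta-analysis with Cochran's Q and I^2."""
    weights = [1.0 / s**2 for s in ses]
    total = sum(weights)
    meta = sum(w * b for w, b in zip(weights, estimates)) / total
    se_meta = 1.0 / math.sqrt(total)
    q = sum(w * (b - meta) ** 2 for w, b in zip(weights, estimates))
    df = len(estimates) - 1
    # I^2 = (Q - df) / Q, truncated at 0; defined as 0 when there is no heterogeneity
    i2 = max(0.0, (q - df) / q) if q > 0 else 0.0
    return meta, se_meta, q, i2
```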

### Input

#### GWAS Data Input Interface (Similar to `susie_rss`)

**GWAS Summary Statistics Files**
- **Input**: Vector of files for one or more GWAS studies.
- **Format**: 
  - Tab-delimited files.
  - First 4 columns: `chr`, `pos`, `a0`, `a1`
  - Additional columns can be loaded using a column mapping file (see below).
- **Column mapping file (optional)**:
  - YAML file for custom column mapping.
  - Required columns: `chr`, `pos`, `a0`, `a1`, either `z` or (`betahat` and `sebetahat`).
  - Optional columns: `n`, `var_y`.
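A loader for this interface has to rename columns according to the mapping and derive `z` when only `betahat` and `sebetahat` are given. The sketch below assumes the file has a header row, and it assumes the parsed YAML mapping takes the form `{standard_name: column_name_in_file}`. Both the mapping layout and the function name `load_sumstats` are hypothetical. It uses only the standard library and is not the pipeline's actual reader.

```python
import csv
import io

REQUIRED = ["chr", "pos", "a0", "a1"]

def load_sumstats(handle, column_mapping=None):
    """Read tab-delimited GWAS summary statistics and ensure every row has a float `z`.

    column_mapping: optional {standard_name: column_name_in_file} (hypothetical layout
    for the parsed YAML mapping file).
    """
    mapping = column_mapping or {}
    rows = []
    for raw in csv.DictReader(handle, delimiter="\t"):
        row = dict(raw)
        for standard, column in mapping.items():
            row[standard] = raw[column]  # expose the column under its standard name
        missing = [c for c in REQUIRED if c not in row]
        if missing:
            raise ValueError(f"missing required columns: {missing}")
        if "z" in row:
            row["z"] = float(row["z"])
        elif "betahat" in row and "sebetahat" in row:
            row["z"] = float(row["betahat"]) / float(row["sebetahat"])
        else:
            raise ValueError("need either z or both betahat and sebetahat")
        rows.append(row)
    return rows

# Usage: a file whose allele and effect columns use non-standard names
example = "chr\tpos\tA1\tA2\tBETA\tSE\n1\t12345\tA\tG\t0.2\t0.1\n"
rows = load_sumstats(io.StringIO(example),
                     {"a0": "A1", "a1": "A2", "betahat": "BETA", "sebetahat": "SE"})
```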

**GWAS Summary Statistics Meta-File** (optional): useful when many GWAS datasets are processed with the same command.
- **Columns**: `study_id`, `chrom` (chromosome number), `file_path` (path to the summary statistics file) and, optionally, `column_mapping_file` (path to the column mapping file).
- **Note**: Chromosome number `0` indicates a genome-wide file.

Example: `gwas_meta.tsv`

```
study_id    chrom    file_path       column_mapping_file
study1      1        gwas1.tsv.gz    column_mapping.yml
study2      2        gwas2.tsv.gz    column_mapping.yml
```
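Because chromosome `0` marks a genome-wide file, selecting the GWAS inputs for one chromosome means taking both the matching chromosome-specific rows and the genome-wide rows. The sketch below illustrates this with a hypothetical helper, `gwas_files_for_chrom`, that works on `(study_id, chrom, file_path)` tuples. Preferring a chromosome-specific file over a genome-wide file for the same study is an assumption, not documented behavior.

```python
def gwas_files_for_chrom(meta_rows, chrom):
    """Return {study_id: file_path} for the studies that cover `chrom`.

    meta_rows: iterable of (study_id, chrom, file_path); chrom 0 means genome-wide.
    A chromosome-specific file wins over a genome-wide one (assumed policy).
    """
    selected = {}
    for study_id, file_chrom, path in meta_rows:
        file_chrom = int(file_chrom)
        if file_chrom == chrom:
            selected[study_id] = path  # exact chromosome match always takes precedence
        elif file_chrom == 0 and study_id not in selected:
            selected[study_id] = path  # genome-wide file used as a fallback
    return selected
```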

**LD Reference Metadata File**
- **Format**: Single TSV file.
- **Contents**:
  - Columns: `chr`, `start`, `end`, path to the LD matrix, genomic build.
  - LD matrix path format: two comma-separated entries, the LD matrix file followed by its `.bim` file.
- **Documentation**: Refer to our LD reference preparation document for detailed information (Tosin pending update).
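To use this metadata, a pipeline must find the LD blocks that overlap the region of interest and split the path column into the LD matrix file and the `.bim` file. The sketch below shows this lookup. The tuple layout `(chr, start, end, path, build)`, the closed-interval overlap test and the function name are assumptions. Chromosome labels are compared as strings, so `chr1` and `1` would not match.

```python
def ld_blocks_for_region(ld_meta_rows, chrom, start, end):
    """Return [(ld_matrix_path, bim_path), ...] for LD blocks overlapping [start, end].

    ld_meta_rows: iterable of (chr, start, end, path, build) where path is "matrix,bim".
    Intervals are treated as closed on both ends (assumption).
    """
    hits = []
    for block_chrom, block_start, block_end, path, _build in ld_meta_rows:
        overlaps = int(block_start) <= end and int(block_end) >= start
        if str(block_chrom) == str(chrom) and overlaps:
            matrix, bim = [p.strip() for p in path.split(",")]
            hits.append((matrix, bim))
    return hits
```

A region that spans a block boundary returns several blocks, which then have to be combined into one LD matrix for the region.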

#### Output of Fine-Mapping & TWAS Pipeline

**xQTL Weight Database files**
- Paths to one or more weight database files, comma-delimited.

**xQTL Weight Database Metadata File** (optional): useful when TWAS is run genome-wide on many regions with the same command.
- **Types**: Gene-based or TAD-based.
- **Structure**: 
  - RDS format.
  - Organized hierarchically: region → condition → weight matrix.
  - Each column represents a different method.
- **Columns**: `chrom`, `start`, `end`, `region_id`, `condition` (e.g., tissue type, QTL type) and `file_path` (comma-delimited paths to the weight database files).

Example: `xqtl_meta.tsv`

```
chrom    start    end    region_id    condition                 file_path
1        1000     5000   region1      cohort1:tissue1:eQTL      weight1.rds, weight2.rds
2        2000     6000   region2      cohort1:tissue1:eQTL      weight3.rds
3        3000     7000   region3      cohort1:tissue1:eQTL      weight4.rds, weight5.rds
```
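The weight databases are organized as region → condition → weight matrix, with one column per method. Step 1.5 picks the best method by CV for each gene. The sketch below mirrors that hierarchy with nested Python dictionaries and selects the method column with the highest CV $r^2$. The dictionary keys (`variants`, `methods`, `matrix`), the `cv_r2` input and the `min_r2` cut-off are all hypothetical stand-ins for whatever the RDS files and CV output actually contain.

```python
def select_best_weights(weight_db, region, condition, cv_r2, min_r2=0.01):
    """Return (method, {variant: weight}) for the best CV method, or None.

    weight_db: {region: {condition: {"variants": [...], "methods": [...],
                                     "matrix": rows = variants, columns = methods}}}
    cv_r2:     {method: cross-validated r^2}  (hypothetical input)
    min_r2:    illustrative threshold below which the gene is dropped
    """
    entry = weight_db[region][condition]
    best = max(entry["methods"], key=lambda m: cv_r2.get(m, float("-inf")))
    if cv_r2.get(best, float("-inf")) < min_r2:
        return None  # no method predicts well enough: drop this gene
    column = entry["methods"].index(best)
    return best, {v: row[column] for v, row in zip(entry["variants"], entry["matrix"])}
```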

### Output

**TWAS** (FIXME: the description below is currently incorrect.)

- Each row corresponds to a single SNP inferred as a member of a signal cluster, with columns including:
   - `snp`: SNP name.
   - `beta_eQTL`: eQTL effect.
   - `se_eQTL`: Standard error of estimated eQTL effect.
   - `beta_GWAS`: GWAS effect.
   - `se_GWAS`: Standard error of GWAS effect.
   - `cluster`: Signal cluster ID (credible set index).
   - `pip`: SNP posterior inclusion probability (PIP).
   - `gene_id`: Gene name.


**MR**

- The output includes the following columns for each gene:
    - `gene_id`: Gene name.
    - `num_cluster`: Number of credible sets for the gene.
    - `num_instruments`: Number of instruments included for the gene.
    - `spip`: Sum of PIPs across the gene's credible sets.
    - `grp_beta`: Signal-level estimates, combining SNP-level estimates from member SNPs weighted by their PIPs.
    - `grp_se`: Standard error of signal-level estimates.
    - `meta`: Gene-level estimate of the causal effect, aggregating signal-level estimates using a fixed-effect meta-analysis model.
    - `se_meta`: Standard error of the gene-level estimate of the causal effect.
    - `Q`: Cochran’s Q statistic.
    - `I2`: $I^2$ statistic, the proportion of total variation in the signal-level estimates attributable to heterogeneity rather than chance.
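One way to form `grp_beta` and `grp_se` from the member SNPs of a credible set is to treat the SNP-level ratio estimates as a mixture with weights proportional to their PIPs. The signal-level estimate is then the PIP-weighted mean, and its variance is the mixture variance (within-SNP variance plus spread of the SNP estimates around the mean). This aggregation rule is an assumption, not a confirmed description of this pipeline's code. The resulting signal-level estimates are what the gene-level fixed-effect meta-analysis of Step 2 combines.

```python
import math

def signal_estimate(estimates, ses, pips):
    """PIP-weighted mixture of SNP-level estimates within one credible set (assumed rule).

    Returns (grp_beta, grp_se). Weights are the PIPs normalized to sum to one.
    """
    total_pip = sum(pips)
    weights = [p / total_pip for p in pips]
    mean = sum(w * b for w, b in zip(weights, estimates))
    # Mixture variance: E[se^2] + Var(estimates), computed as one non-negative sum
    variance = sum(w * (s**2 + (b - mean) ** 2) for w, b, s in zip(weights, estimates, ses))
    return mean, math.sqrt(variance)
```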

In [None]:
[global]
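# Working directory for output files
parameter: cwd = path("output")
# Metadata files for the GWAS summary statistics and the xQTL weight databases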
parameter: gwas_meta_file = path()
parameter: xqtl_meta_file = path()
parameter: container = ''
import re
parameter: entrypoint= ('micromamba run -a "" -n' + ' ' + re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])) if container else ""
parameter: job_size = 100
parameter: walltime = "1h"
parameter: mem = "16G"
parameter: numThreads = 2

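def group_by_region(lst, partition):
    """Group the step input by region: one group per region's list of weight files.

    `partition` is already the list of per-region file lists, so it is returned as the
    groups. The commented-out code is an alternative that slices the flat input `lst`
    by the cumulative sizes of `partition`.
    """
    # from itertools import accumulate
    # partition = [len(x) for x in partition]
    # Compute the cumulative sums once
    # cumsum_vector = list(accumulate(partition))
    # Use slicing based on the cumulative sums
    # return [lst[(cumsum_vector[i-1] if i > 0 else 0):cumsum_vector[i]] for i in range(len(partition))]
    return partition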

In [None]:
[get_analysis_regions: shared = "regional_data"]
import pandas as pd
import os

def file_exists(file_path, relative_path=None):
    """Check if a file exists at the given path or relative to a specified path."""
    if os.path.exists(file_path) and os.path.isfile(file_path):
        return True
    elif relative_path:
        # Constructing the relative path
        relative_file_path = os.path.join(relative_path, file_path)
        return os.path.exists(relative_file_path) and os.path.isfile(relative_file_path)
    return False

def check_required_columns(df, required_columns):
    """Check if the required columns are present in the dataframe."""
    missing_columns = [col for col in required_columns if col not in list(df.columns)]
    if missing_columns:
        raise ValueError(f"Missing required columns: {', '.join(missing_columns)}")

def extract_regional_data(gwas_meta_file, xqtl_meta_file):
    """
    Extracts data from GWAS and xQTL metadata files. Checks if files exist at specified paths,
    or relative to the metadata file's directory if not found initially. Also checks for required columns in the input files.

    Args:
    - gwas_meta_file (str): File path to the GWAS metadata file.
    - xqtl_meta_file (str): File path to the xQTL weight metadata file.

    Returns:
    - Tuple of two dictionaries:
        - GWAS Dictionary: Maps study IDs to a list containing chromosome number, 
          GWAS file path, and optional column mapping file path.
        - xQTL Dictionary: Nested dictionary with region IDs as keys. Each key maps to a 
          dictionary containing 'meta_info' (list of chromosome, start, end, region ID, condition)
          and 'files' (list of file paths).

    Raises:
    - FileNotFoundError: If any specified file path does not exist, even when considered
      relative to the metadata file's directory.
    - ValueError: If required columns are missing in the input files.
    """
    from collections import OrderedDict
    # Required columns for each file type
    required_gwas_columns = ['study_id', 'chrom', 'file_path']
    required_xqtl_columns = ['region_id', 'chrom', 'start', 'end', 'condition', 'file_path']

    # Base directory of the metadata files
    gwas_base_dir = os.path.dirname(gwas_meta_file)
    xqtl_base_dir = os.path.dirname(xqtl_meta_file)
    
    # Reading the GWAS metadata file
    gwas_df = pd.read_csv(gwas_meta_file, sep="\t")
    check_required_columns(gwas_df, required_gwas_columns)
    gwas_dict = OrderedDict()
    for _, row in gwas_df.iterrows():
        file_path = row['file_path']
        mapping_file = row.get('column_mapping_file')
        # Treat a missing column or an empty cell (NaN) as "no mapping file"
        if pd.isna(mapping_file):
            mapping_file = None

        # Check that the summary statistics file and the optional mapping file exist
        if not file_exists(file_path, gwas_base_dir):
            raise FileNotFoundError(f"File {file_path} not found for {row['study_id']}")
        if mapping_file and not file_exists(mapping_file, gwas_base_dir):
            raise FileNotFoundError(f"Column mapping file {mapping_file} not found for {row['study_id']}")

        # Resolve paths relative to the metadata file's directory if needed
        file_path = file_path if file_exists(file_path) else os.path.join(gwas_base_dir, file_path)
        if mapping_file:
            mapping_file = mapping_file if file_exists(mapping_file) else os.path.join(gwas_base_dir, mapping_file)

        gwas_dict[row['study_id']] = [row['chrom'], file_path, mapping_file]

    # Reading the xQTL weight metadata file
    xqtl_df = pd.read_csv(xqtl_meta_file, sep="\t")
    check_required_columns(xqtl_df, required_xqtl_columns)
    xqtl_dict = OrderedDict()
    for _, row in xqtl_df.iterrows():
        file_paths = [fp.strip() for fp in row['file_path'].split(',')]  # Splitting and stripping file paths

        # Check if each file path exists
        for i, fp in enumerate(file_paths):
            if not file_exists(fp, xqtl_base_dir):
                raise FileNotFoundError(f"File {fp} not found in {row['region_id']}")
            file_paths[i] = fp if file_exists(fp) else os.path.join(xqtl_base_dir, fp)

        xqtl_dict[row['region_id']] = {"meta_info": [row['chrom'], row['start'], row['end'], row['region_id'], row['condition']],
                                       "files": file_paths}

    return gwas_dict, xqtl_dict

# Parse both metadata files and share the result with downstream steps as `regional_data`
gwas_dict, xqtl_dict = extract_regional_data(gwas_meta_file, xqtl_meta_file)
regional_data = dict([("GWAS", gwas_dict), ("xQTL", xqtl_dict)])

In [None]:
[twas_scan_1]
depends: sos_variable("regional_data")
regions = regional_data['xQTL'].keys()
meta_info = [x["meta_info"] for x in regional_data['xQTL'].values()]
xqtl_files = [x["files"] for x in regional_data['xQTL'].values()]
input: xqtl_files, group_by = lambda x: group_by_region(x, xqtl_files), group_with = "meta_info"
#output: f'{cwd:a}/{step_name[:-2]}/{_meta_info[4].replace(":", "_")}.{_meta_info[3]}.twas.rds'
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
R: expand = '${ }', stdout = f"{_output:n}.stdout", stderr = f"{_output:n}.stderr", container = container, entrypoint = entrypoint
    print(c(${_input:r,}))