# Extract genome-wide data for multivariate analysis

## Description

This notebook prepares input data for Utimate Decomposition to generate mixture prior (for mvSuSiE) or to use for MASH analysis. It outputs 3 sets of data: $Z_s$, $Z_n$ and $Z_r$ (strong, null and random)

* $Z_s$: **this is now extracted from genome-wide cis analysis fine-mapping results**. We extract the top loci data frame of each condition, where the CS threshold is set to be 0.7. Then we merge the z-scores of them into one data frame.
* $Z_n$: (null $Z$-scores): we first extract up to $M$ candidate SNPs from each region which satisify $|z| \le 2$, then overlap it with the list of independent SNPs to keep only independent variants, then finally take the union of the extracted.
* $Z_r$: we randomly extract variants based on input independent list of variants.

**FIXME: We need to apply the independent list of variants Anqi developed and use it here to filter and get $Z_n$ and $Z_r$. This logic shoud be added to `processing_1`. Also, it might be a good idea we take some of these utility functions into pecotmr package for better maintenance. For example `processing_1` the function to load regional summary stats from tensorQTL into a matrix should be packed into pecotmr; plus this one function `handle_nan_etc`. processing_2 can stay as is; the `susie_signal` step can also go into `pecotmr` as a way for users to summarize signals from SuSiE for other purposes**

## Input
1. **Marginal summary statistics files**: Bgzipped summary statistics for chromosomes 1-22, generated by tensorQTL cis-analysis and indexed by `tabix`.
2. **Fine-mapping results file index**: Path to lists of fine-mapped RDS files from finemapping output.
2. **Genome region partition** (optional): Defines genomic regions for each gene as enhanced cis regions where we should extract $Z_n$ and $Z_r$ from. This list is used for fine-mapping, so if the complete list of fine-mapping RDS (rather than a handful of it) is already avaiable (#2 above) then there is no need to provide this file. Otherwise, it's going to be limited to only certaion regions, which is also good for testing purpose.

## Output
A list of 10 elements:

```
List of 10
 $ random.z: num [1:36, 1:2] -0.785 -0.785 -0.785 -0.785 -0.785 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:36] "1:97960:A:G" "1:138565:G:A" "1:15112:C:T" "1:189947:G:A" ...
  .. ..$ : chr [1:2] "A" "B"
 $ null.z  : num [1:36, 1:2] -0.785 -0.785 -0.785 -0.785 -0.785 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:36] "1:93692:C:T" "1:273645:A:G" "1:10442:CCTA:." "1:198942:A:C" ...
  .. ..$ : chr [1:2] "A" "B"
 $ random.b: num [1:36, 1:2] -0.123 -0.123 -0.123 -0.123 -0.123 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:36] "1:97960:A:G" "1:138565:G:A" "1:15112:C:T" "1:189947:G:A" ...
  .. ..$ : chr [1:2] "A" "B"
 $ null.b  : num [1:36, 1:2] -0.123 -0.123 -0.123 -0.123 -0.123 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:36] "1:93692:C:T" "1:273645:A:G" "1:10442:CCTA:." "1:198942:A:C" ...
  .. ..$ : chr [1:2] "A" "B"
 $ null.s  : num [1:36, 1:2] 0.157 0.157 0.157 0.157 0.157 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:36] "1:93692:C:T" "1:273645:A:G" "1:10442:CCTA:." "1:198942:A:C" ...
  .. ..$ : chr [1:2] "A" "B"
 $ random.s: num [1:36, 1:2] 0.157 0.157 0.157 0.157 0.157 ...
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:36] "1:97960:A:G" "1:138565:G:A" "1:15112:C:T" "1:189947:G:A" ...
  .. ..$ : chr [1:2] "A" "B"
 $ strong.b:Classes ‘data.table’ and 'data.frame':	1 obs. of  2 variables:
  ..$ A: num -0.217
  ..$ B: num -0.217
  ..- attr(*, ".internal.selfref")=<externalptr> 
 $ strong.s:Classes ‘data.table’ and 'data.frame':	1 obs. of  2 variables:
  ..$ A: num 0.0481
  ..$ B: num 0.0481
  ..- attr(*, ".internal.selfref")=<externalptr> 
 $ strong.z:Classes ‘data.table’ and 'data.frame':	1 obs. of  2 variables:
  ..$ A: num -4.5
  ..$ B: num -4.5
  ..- attr(*, ".internal.selfref")=<externalptr> 
 $ XtX     : num [1:2, 1:2] 20.3 20.3 20.3 20.3
  ..- attr(*, "dimnames")=List of 2
  .. ..$ : chr [1:2] "A" "B"
  .. ..$ : chr [1:2] "A" "B"
  ```

### Example

In [None]:
# generate random and null only
sos run pipeline/mash_preprocessing.ipynb processing \
    --name protocol_example_protein \
    --sum_files test_pQTL_asso_list \
               test_pQTL_asso_list \
    --region_file test.region \
    --pheno_names A B 


In [None]:
# generate strong only
sos run pipeline/mash_preprocessing.ipynb susie_signal \
    --name protocol_example_protein \
    --susie_list protocol_example_protein.susie_output.txt \
    --pheno_names A B 


In [None]:
# generate mashr input directly
sos run pipeline/mash_preprocessing.ipynb mash_input \
    --name protocol_example_protein \
    --sum_files test_pQTL_asso_list \
               test_pQTL_asso_list \
    --region_file test.region \
    --susie_list protocol_example_protein.susie_output.txt \
    --pheno_names A B


In [1]:
[global]
import glob
parameter: name = str
parameter: pheno_names = paths
# Path to work directory where output locates
parameter: cwd = path("./output")
# Containers that contains the necessary packages
parameter: container = ""
import re
parameter: entrypoint= ('micromamba run -a "" -n' + ' ' + re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])) if container else ""
# For cluster jobs, number commands to run per job
parameter: job_size = 1
# Wall clock time expected
parameter: walltime = "5h"
# Memory expected
parameter: mem = "16G"
# Number of threads
parameter: numThreads = 8

# This is in principle required; but in practice it can be optional if we are not exactly stringent about getting independent SNPs
parameter: independent_variant_list = path
ran_null = file_target(f"{cwd}/{name}.random.null.rds")
strong = file_target(f"{cwd}/{name}.strong.rds")

## Get the random and null effects per analysis unit

In [None]:
[extract_tensorqtl_1]
parameter: sum_files = paths
parameter: region_file = path
import re
import pandas as pd
def find_matching_files_for_region(chr_id):
    chr_number = chr_id[3:]  # subset 1 from chr1
    pattern_str = r"\.{chr_number}\."
    pattern = re.compile(pattern_str.format(chr_number=chr_number))
    paths = []
    for sum_file in sum_files:
        with open(sum_file, 'r') as af:
            for aline in af:
                if pattern.search(aline):
                    paths.append(aline.strip())
    return ",".join(paths)

updated_regions = []
with open(region_file, 'r') as regions:
    header = regions.readline().strip()
    updated_regions.append(header + "\tpath\tregion")
    for line in regions:
        parts = line.strip().split("\t")
        chr_id, start, end, gene_id = parts
        paths = find_matching_files_for_region(chr_id)
        updated_regions.append(f"{chr_id}\t{start}\t{end}\t{gene_id}\t{paths}\t{chr_id}:{start}-{end}")

meta_df = pd.DataFrame([line.split("\t") for line in updated_regions[1:]], columns=updated_regions[0].split("\t"))
meta = meta_df[['gene_id', 'path', 'region']].to_dict(orient='records')

input: for_each='meta'
output: f'{cwd:a}/{name}_cache/{name}.{_meta["gene_id"]}.rds'
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime,  mem = mem, tags = f'{step_name}_{_output:bn}'  
R: expand = "${ }", stderr = f'{_output}.stderr', stdout = f'{_output}.stdout', container = container, entrypoint=entrypoint
    region <- "${_meta['region']}"
    # FIXME I am sure there is a more elegant way to put together the path, via SoS
    phenotype_path <- unlist(strsplit("${_meta['path']}", ","))
    dat <- tryCatch(
      {
        # Try to run the function
         pecotmr::load_multitrait_tensorqtl_sumstat(phenotype_path = phenotype_path, region = region, 
          trait_names = c(${pheno_names:r,}), filter_file = NULL, remove_any_missing = TRUE, max_rows_selected = 300)
      },
      error = function(e) {
        warning("Attempt remove chr in region ID to load the data.")
        # If an error occurs, modify the region and try again
        pecotmr::load_multitrait_tensorqtl_sumstat(phenotype_path = phenotype_path, region =  gsub("chr", "", region), 
          trait_names = c(${pheno_names:r,}), filter_file = NULL, remove_any_missing = TRUE, max_rows_selected = 300)
      }
    )
    saveRDS(dat, ${_output:r}, compress="xz")

In [None]:
# extract data for MASH from summary stats
[extract_tensorqtl_2]
parameter: seed = 999
parameter: n_random = 50
parameter: n_null = 50
parameter: z_only = False
# Columns: "#chr", sumstat(merged.vcf.gz)
parameter: table_name = ""
parameter: bhat = "bhat"
parameter: sbhat = "sbhat"
parameter: expected_ncondition = 0
parameter: per_chunk =100

##  conditions can be excluded if needs arise. If nothing to exclude keep the default 0
parameter: exclude_condition = []
parameter: datadir = ""
parameter: na_remove = "TRUE"

input:  group_by = per_chunk
output: f"{cwd}/{name}_cache/{name}_batch{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }",stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', container = container, entrypoint=entrypoint
    library(dplyr)
    library(stringr)
    set.seed(${seed})
    #library(huiiy)

    remove_rownames = function(x) {
        for (name in names(x)) rownames(x[[name]]) = NULL
        return(x)
    }
    extract_one_data = function(dat, n_random, n_null, filename) {
        if (is.null(dat)) return(NULL)
        abs_z = abs(dat$${bhat}/dat$${sbhat})
        # random samples can include the real signals 
        sample_idx = 1:nrow(abs_z)
        random_idx = sample(sample_idx, min(n_random, length(sample_idx)), replace = F)
        random = list(bhat = dat$${bhat}[random_idx,,drop=F], sbhat = dat$${sbhat}[random_idx,,drop=F])
        # null samples defined as |z| < 2
        null.id = which(apply(abs_z, 1, max) < 2)
        if (length(null.id) == 0) {
          warning(paste("Null data is empty for input file", filename))
          null = list()
        } else {
          null_idx = sample(null.id, min(n_null, length(null.id)), replace = F)
          null = list(bhat = dat$${bhat}[null_idx,,drop=F], sbhat = dat$${sbhat}[null_idx,,drop=F])
        }
        #dat = (list(random = remove_rownames(random), null = remove_rownames(null)))
        dat = (list(random = random, null = null))
        return(dat)
    }
    reformat_data = function(dat, z_only = FALSE) {
        # make output consistent in format with 
        # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb      
        res = list(random.z = dat$random$bhat/dat$random$sbhat, 
                  null.z = dat$null$bhat/dat$null$sbhat)
        if (!z_only) {
          res = c(res, list(random.b = dat$random$bhat,
           null.b = dat$null$bhat,
           null.s = dat$null$sbhat,
           random.s = dat$random$sbhat))
      }
      return(res)
    }
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else if (is.null(is.null(res$random.b)|is.null(res$null.b))) {
          return(one_data)
      } else if (is.null(one_data)) {
          return(res)
      } else {
          for (d in names(one_data)) {
            if (is.null(one_data[[d]])) {
              next
            } else {
                res[[d]] = as.matrix(rbind(res[[d]],as.data.frame(one_data[[d]])))
            }
          }
          return(res)
      }
    }
    res = list()
    for (f in c(${_input:r,})) {
      # If cannot read the input for some reason then we just skip it, assuming we have other enough data-sets to use.
      dat = tryCatch(readRDS(f), error = function(e) return(NULL))${("$"+table_name) if table_name != "" else ""}
      if (is.null(dat)) {
          message(paste("Skip loading file", f, "due to load failure."))
          next
      }
      if (${expected_ncondition} > 0 && (ncol(dat$${bhat}) != ${expected_ncondition} || ncol(dat$${sbhat}) != ${expected_ncondition})) {
          message(paste("Skip loading file", f, "because it has", ncol(dat$${bhat}), "columns different from required", ${expected_ncondition}))
          next
      }
      if(length(c(${",".join([repr(x) for x in exclude_condition])})) > 0 ){
          message(paste("Excluding condition ${exclude_condition} from the analysis"))
          dat$bhat = dat$bhat[,-c(${",".join(exclude_condition)})]
          dat$sbhat = dat$sbhat[,-c(${",".join(exclude_condition)})]
          dat$Z = dat$Z[,-c(${",".join(exclude_condition)})]
      }

      dat<-tryCatch(extract_one_data(dat, ${n_random}, ${n_null}, f), error = function(e) return(NULL))
      res<-tryCatch(merge_data(res, reformat_data(dat , ${"TRUE" if z_only else "FALSE"})), error = function(e) message("Skipping gene due to lack of SNPs"))     
  }
  saveRDS(res, ${_output:r}, compress="xz")

In [None]:
[extract_tensorqtl_3]
input: group_by = "all"
output: random_null = f"{cwd}/{name}.random_null.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container, stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', entrypoint=entrypoint
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else {
          for (d in names(one_data)) {
            res[[d]] = rbind(res[[d]], one_data[[d]])
          }
          return(res)
      }
    }
    dat = list()
    for (f in c(${_input:r,})) {
      dat = merge_data(dat, readRDS(f))
    }
    saveRDS(dat, ${_output:r}, compress="xz")
 
bash: expand = "${ }", container = container, stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', entrypoint=entrypoint
    rm -rf ${cwd}/${name}_cache/

In [None]:
[extract_susie_top_loci]
parameter: susie_list = path
input: susie_list
output: strong = f"{cwd}/{name}.strong.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container,stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', volumes = [f'{cwd:ad}:{cwd:ad}'], entrypoint=entrypoint
    out <- pecotmr::load_susie_top_loci(read.table("${_input}")$V1, c(${pheno_names:r,}))
    saveRDS(out, ${_output:r}, compress="xz")

In [None]:
[mash_input]
input: random_null, strong
output: f"{cwd}/{name}.mashr_input.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container,stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', volumes = [f'{cwd:ad}:{cwd:ad}'], entrypoint=entrypoint
    out <- readRDS(${_input[0]:r})
    strong <- readRDS(${_input[1]:r})
    out$strong.b <- strong$bhat
    out$strong.s <- strong$sbhat
    X <- out$strong.z <- strong$z
    X[is.na(X)] = 0
    out$ZtZ = t(as.matrix(X)) %*% as.matrix(X) / nrow(X)
    saveRDS(out, ${_output:r}, compress="xz")