# Summary statistics formatting
This notebook takes more than one collection of summary-statistics (sumstat) text files and produces a collection of per-gene merged RDS files that can serve as the input to both MASH and MVSuSiE analyses.

## Input
1. A sumstat list with the columns "#chr", theme1, theme2, theme3; each cell outside the "#chr" column holds the path to one sumstat file (generated by the yml generator).
2. region_list: a table with the columns chr, start, end, gene_ID, used for partitioning.
## Output
1. 23 merged sumstat files in txt format, one per chromosome
2. Merged sumstat files in rds format, one per gene
3. Two files documenting outputs 1 and 2

### Example

In [None]:
sos run ~/codes/xqtl-pipeline/pipeline/sumstat_processing.ipynb processing \
    --name protocol_example_protein \
    --asso_files /mnt/vast/hpc/csg/rf2872/Work/test/mash_test/test_pQTL_asso_list \
               /mnt/vast/hpc/csg/rf2872/Work/test/mash_test/test_pQTL_asso_list \
    --region_file test.region


In [None]:
[global]
import glob
# Prefix used in every output file name and in the "<name>_cache" directory
parameter: name = str
# One or more list files; each line is scanned for a path containing ".<chromosome>."
parameter: asso_files = paths
# Tab-separated region table: a header line, then four columns per row
# (chromosome, start, end, gene id)
parameter: region_file = path
# Working directory where the output files are written
parameter: cwd = path("./output")
# Container that provides the necessary packages
parameter: container = ''
# For cluster jobs, number of commands to run per job
parameter: job_size = 1
# Wall clock time expected
parameter: walltime = "5h"
# Memory expected
parameter: mem = "16G"
# Number of threads (not referenced by the steps shown in this part of the notebook)
parameter: numThreads = 8
# Number of per-gene RDS files grouped into one batch by the second processing step
parameter: per_chunk =100
# Columns: "#chr", sumstat(merged.vcf.gz)
# NOTE(review): the line above does not seem to describe table_name - confirm where it belongs.
# Name of the list element to take from each per-gene RDS file; empty means use the whole object
parameter: table_name = ""
# Names of the list elements that hold the effect estimates and their standard errors
parameter: bhat = "bhat"
parameter: sbhat = "sbhat"
# When greater than 0, files whose matrices do not have exactly this many columns are skipped
parameter: expected_ncondition = 0
# Column indices of conditions to drop before the analysis; keep the default
# empty list when nothing has to be excluded
parameter: exclude_condition = []
# Not referenced by the steps shown in this part of the notebook
parameter: datadir = ""
# Seed passed to set.seed() before SNPs are sampled
parameter: seed = 999
# Maximum number of randomly chosen SNPs kept per gene
parameter: n_random = 4
# Maximum number of null SNPs (|z| < 2 in every condition) kept per gene
parameter: n_null = 4
# When True only z-scores are written; otherwise effects and standard errors are kept as well
parameter: z_only = False
# Inserted verbatim into the R script as a logical literal: drop SNPs with missing values
parameter: na_remove = "TRUE"

## Get the random and null effects per analysis unit

In [None]:
[processing_1]
import re
import pandas as pd
def find_matching_files_for_region(chr_id):
    """Collect the summary-statistics paths that belong to one chromosome.

    Every line of every file listed in ``asso_files`` is searched for the
    chromosome label surrounded by dots (for example ``.1.`` for ``chr1``).

    Parameters
    ----------
    chr_id : str
        Chromosome identifier from the region table, with or without a
        leading ``chr`` (``"chr1"`` and ``"1"`` are treated the same).

    Returns
    -------
    str
        Matching lines, stripped and joined by commas; an empty string when
        nothing matches.
    """
    # Only drop the prefix when it is really there; slicing blindly would
    # turn "1" into "" and make the pattern match "..".
    chr_number = chr_id[3:] if chr_id.startswith("chr") else chr_id
    # Escape the label so it is always matched literally.
    pattern = re.compile(r"\." + re.escape(chr_number) + r"\.")
    paths = []
    for asso_file in asso_files:
        with open(asso_file, 'r') as af:
            paths.extend(aline.strip() for aline in af if pattern.search(aline))
    return ",".join(paths)

# Annotate every row of the region table with the comma-separated list of
# matching summary-statistics files and a "chr:start-end" region string, then
# turn the result into one dictionary per gene for the `for_each` loop.
# The header of the region file must contain a column named "gene_id",
# because that column is selected by name below.
updated_regions = []
# Paths depend only on the chromosome, so look them up once per chromosome
# instead of re-reading the association lists for every region.
paths_by_chr = {}
with open(region_file, 'r') as regions:
    header = regions.readline().strip()
    updated_regions.append(header + "\tpath\tregion")
    for line in regions:
        line = line.strip()
        # A blank line (typically at the end of the file) carries no region
        # and would otherwise break the four-way unpacking below.
        if not line:
            continue
        chr_id, start, end, gene_id = line.split("\t")
        if chr_id not in paths_by_chr:
            paths_by_chr[chr_id] = find_matching_files_for_region(chr_id)
        paths = paths_by_chr[chr_id]
        updated_regions.append(f"{chr_id}\t{start}\t{end}\t{gene_id}\t{paths}\t{chr_id}:{start}-{end}")

meta_df = pd.DataFrame([line.split("\t") for line in updated_regions[1:]], columns=updated_regions[0].split("\t"))
meta = meta_df[['gene_id', 'path', 'region']].to_dict(orient='records')

input: for_each='meta'
output: f'{cwd:a}/{name}_cache/{name}.{_meta["gene_id"]}.rds'
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime,  mem = mem, tags = f'{step_name}_{_output:bn}'  
R: expand = "${ }", stderr = f'{_output}.stderr', stdout = f'{_output}.stdout', container = container
    # Read one region from one summary-statistics file and return it as a
    # character matrix with one row per unique variant.
    #
    # path:   file handed to tabix_region()
    # region: region string handed to tabix_region()
    #
    # A "variant" key (CHROM:POS:REF:ALT) is appended, columns 3 and 6-9 are
    # dropped by position (presumably ID, QUAL, FILTER, INFO and FORMAT of a
    # VCF-style table - confirm), and repeated variants keep their first row.
    extract_data <- function(path, region) {
        tabix_region(path, region) %>%
            mutate(variant = paste(`#CHROM`, POS, REF, ALT, sep = ":")) %>%
            select(-c(3, 6:9)) %>%
            distinct(variant, .keep_all = TRUE) %>%
            # as.matrix() on a mixed-type table runs format() on every numeric
            # column, which pads values to a common width per file (" 999" next
            # to "1000"). Those padded strings later serve as merge keys, so
            # convert to plain character first to keep keys comparable across files.
            mutate(across(everything(), as.character)) %>%
            as.matrix
    }
    # Pull one colon-separated field out of every cell of the per-condition
    # columns (column 6 onwards) and return the values as a numeric matrix.
    #
    # df:              table whose first five columns are keys and whose
    #                  remaining columns hold strings such as "bhat:sbhat"
    # component_index: position of the wanted field inside each cell
    #                  (1 = first field, 2 = second field, ...)
    #
    # Returns a numeric matrix with one row per row of df and one column per
    # condition; cells without the requested field become NA.
    # Stops when df has no rows: the previous implementation failed on empty
    # input as well, and the calling script retries with another region name
    # on any error.
    extract_component <- function(df, component_index) {
        if (nrow(df) == 0) {
            stop("No variants available to extract components from")
        }
        raw <- as.matrix(as.data.frame(df)[, 6:ncol(df), drop = FALSE])
        # Split every cell on its own; taking only the first split result
        # would copy the first variant's value into every row.
        fields <- strsplit(as.character(raw), ":", fixed = TRUE)
        values <- vapply(fields, function(parts) as.numeric(parts[component_index]), numeric(1))
        matrix(values, nrow = nrow(raw), ncol = ncol(raw),
               dimnames = list(NULL, colnames(raw)))
    }

    # Load the same region from every summary-statistics file and combine the
    # files into one effect matrix and one standard-error matrix.
    #
    # phenotype_path: character vector of files, one per condition
    # region:         region string passed on to extract_data()
    #
    # Returns list(bhat = <matrix>, sbhat = <matrix>) with one row per variant
    # shared by all files (row names are the variant keys). bhat is built from
    # the first and sbhat from the second colon-separated field of each cell.
    load_combined_matrix_data <- function(phenotype_path, region) {
        library(dplyr)   
        key_cols <- c("variant", "#CHROM", "POS", "REF", "ALT")
        Y <- lapply(phenotype_path, function(path) {
            as.data.frame(extract_data(path, region), stringsAsFactors = FALSE)
        })
        if (length(Y) == 0) {
            stop("No summary statistics file was supplied for region ", region)
        }

        # Combine the per-file tables on the variant keys
        combined_matrix <- Reduce(function(x, y) merge(x, y, by = key_cols), Y)

        # merge() already places the key columns first. With a single input
        # file no merge happens, so move the keys to the front explicitly;
        # extract_component() expects the conditions to start at column 6.
        if (!identical(colnames(combined_matrix)[seq_along(key_cols)], key_cols)) {
            key_idx <- match(key_cols, colnames(combined_matrix))
            other_idx <- setdiff(seq_len(ncol(combined_matrix)), key_idx)
            combined_matrix <- combined_matrix[, c(key_idx, other_idx), drop = FALSE]
        }
        combined_matrix <- distinct(combined_matrix, variant, .keep_all = TRUE)

        dat <- list(
            bhat = extract_component(combined_matrix, 1),
            sbhat = extract_component(combined_matrix, 2)
        )

        rownames(dat$bhat) <- rownames(dat$sbhat) <- combined_matrix$variant

        return(dat)
    }
  
    # Run `tabix -h <file> <region>` and return the parsed output as a tibble.
    tabix_region <- function(file, region){
        tabix_cmd <- paste0("tabix -h ", file, " ", region)
        queried <- data.table::fread(cmd = tabix_cmd)
        as_tibble(queried)
    }
  
    # Region and file list of the current gene, filled in by the workflow
    region <- "${_meta['region']}"
    phenotype_path <- unlist(strsplit("${_meta['path']}", ","))

    # First try the region exactly as given. On any failure retry with "chr"
    # removed from the region string.
    dat <- tryCatch(
      {
         load_combined_matrix_data(phenotype_path = phenotype_path, region = region)
      },
      error = function(e) {
        message("gsub chr in region id...")
        # Keep the reason for the first failure in the log instead of dropping it
        message("First attempt failed with: ", conditionMessage(e))
         load_combined_matrix_data(phenotype_path = phenotype_path, region =  gsub("chr", "", region))
      }
    )
    saveRDS(dat, ${_output:r})

In [None]:
# extract data for MASH from summary stats
[processing_2]
input:  group_by = per_chunk
output: f"{cwd}/{name}_cache/{name}_batch{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }",stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', container = container
    library(dplyr)
    library(stringr)
    set.seed(${seed})
    #library(huiiy)
    # Row/column position of the largest entry of a matrix, as returned by arrayInd()
    matxMax <- function(mtx) {
      flat_pos <- which.max(mtx)
      arrayInd(flat_pos, dim(mtx))
    }
    # Strip the row names from every named element of a list and return the list
    remove_rownames <- function(x) {
        for (element in names(x)) {
            stripped <- x[[element]]
            rownames(stripped) <- NULL
            x[[element]] <- stripped
        }
        x
    }
    # Replace unusable numbers: NaN effects become 0, and NaN or infinite
    # standard errors become 1000
    handle_nan_etc <- function(x) {
      nan_effects <- which(is.nan(x$bhat))
      x$bhat[nan_effects] <- 0
      unusable_se <- which(is.nan(x$sbhat) | is.infinite(x$sbhat))
      x$sbhat[unusable_se] <- 1E3
      x
    }
    # Pick a random set and a null set of SNPs from one gene.
    #
    # dat:       list holding the effect and standard-error matrices
    #            (one row per SNP, one column per condition)
    # n_random:  maximum number of randomly chosen SNPs
    # n_null:    maximum number of null SNPs (|z| < 2 in every condition)
    # infile:    file name, only used in the warning message
    # na_remove: drop SNPs that have a missing value in either matrix
    #
    # Returns list(random = list(bhat, sbhat), null = list(bhat, sbhat)) without
    # row names, or NULL when dat is NULL. Stops when no SNP is left.
    extract_one_data = function(dat, n_random, n_null, infile, na_remove = TRUE) {
        if (is.null(dat)) return(NULL)
        bhat <- dat$${bhat}
        sbhat <- dat$${sbhat}
        if (isTRUE(as.logical(na_remove))) {
          n_bhat_ori <- nrow(bhat)
          # Filter both matrices with one shared mask so their rows stay aligned
          keep <- complete.cases(bhat) & complete.cases(sbhat)
          bhat <- bhat[keep, , drop = FALSE]
          sbhat <- sbhat[keep, , drop = FALSE]
          msg = paste(c("Out of ", n_bhat_ori, " SNP, ", nrow(bhat), " was retained for analysis"), collapse = "")
          message(msg)
        }
        if (nrow(bhat) == 0) {
          stop("None of the SNP was retained for analysis, skipping genes")
        }
        z <- abs(bhat / sbhat)
        # random samples can include the real signals
        n_snp <- nrow(z)
        random_idx <- sample.int(n_snp, min(n_random, n_snp))
        random <- list(bhat = bhat[random_idx, , drop = FALSE], sbhat = sbhat[random_idx, , drop = FALSE])
        # null samples defined as |z| < 2
        null.id <- which(apply(z, 1, max) < 2)
        if (length(null.id) == 0) {
          warning(paste("Null data is empty for input file", infile))
          null <- list()
        } else {
          # Sample positions inside null.id: sample() on a single number k
          # would draw from 1:k and could return a SNP that is not null.
          null_idx <- null.id[sample.int(length(null.id), min(n_null, length(null.id)))]
          null <- list(bhat = bhat[null_idx, , drop = FALSE], sbhat = sbhat[null_idx, , drop = FALSE])
        }
        dat <- list(random = remove_rownames(random), null = remove_rownames(null))
        dat$random <- handle_nan_etc(dat$random)
        dat$null <- handle_nan_etc(dat$null)
        return(dat)
    }
    # Turn the sampled effects into the layout used by
    # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb
    # z-scores are always returned; effects and standard errors are appended
    # when z_only is FALSE.
    reformat_data <- function(dat, z_only = TRUE) {
        z_scores <- list(
            random.z = dat$random$bhat / dat$random$sbhat,
            null.z = dat$null$bhat / dat$null$sbhat
        )
        if (z_only) {
            return(z_scores)
        }
        effects <- list(
            random.b = dat$random$bhat,
            null.b = dat$null$bhat,
            null.s = dat$null$sbhat,
            random.s = dat$random$sbhat
        )
        c(z_scores, effects)
    }
    # Append the rows of one gene to the accumulated results.
    #
    # res:      accumulated list of matrices (may be empty)
    # one_data: list of matrices of the same names for one gene, or NULL
    #
    # Returns one_data when nothing has been accumulated yet, res when
    # one_data is NULL, and otherwise res with every non-NULL element of
    # one_data row-bound below the element of the same name.
    # The former middle branch tested is.null() of a logical value, which is
    # never TRUE, so it could not run and has been removed.
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      }
      if (is.null(one_data)) {
          return(res)
      }
      for (d in names(one_data)) {
        if (is.null(one_data[[d]])) {
          next
        }
        res[[d]] = as.matrix(rbind(res[[d]], as.data.frame(one_data[[d]])))
      }
      return(res)
    }
    # Accumulated random/null samples over all genes of this batch
    res = list()
    # Column indices of the conditions to drop (NULL when nothing is excluded)
    exclude_idx <- c(${",".join([str(x) for x in exclude_condition])})

    for (f in c(${_input:r,})) {
      # If cannot read the input for some reason then we just skip it, assuming we have other enough data-sets to use.
      dat = tryCatch(readRDS(f), error = function(e) return(NULL))${("$"+table_name) if table_name != "" else ""}
      if (is.null(dat)) {
          message(paste("Skip loading file", f, "due to load failure."))
          next
      }
      if (${expected_ncondition} > 0 && (ncol(dat$${bhat}) != ${expected_ncondition} || ncol(dat$${sbhat}) != ${expected_ncondition})) {
          message(paste("Skip loading file", f, "because it has", ncol(dat$${bhat}), "columns different from required", ${expected_ncondition}))
          next
      }
      if (length(exclude_idx) > 0) {
        message(paste("Excluding condition ${exclude_condition} from the analysis"))
        # drop = FALSE keeps a matrix even when a single condition remains
        dat$${bhat} = dat$${bhat}[, -exclude_idx, drop = FALSE]
        dat$${sbhat} = dat$${sbhat}[, -exclude_idx, drop = FALSE]
        # Z is optional; indexing a missing element would raise an error
        if (!is.null(dat$Z)) {
          dat$Z = dat$Z[, -exclude_idx, drop = FALSE]
        }
      }

      dat <- tryCatch(
        extract_one_data(dat, ${n_random}, ${n_null}, f, ${na_remove}),
        error = function(e) {
          message(paste("Skip file", f, "because no data could be extracted:", conditionMessage(e)))
          NULL
        }
      )
      if (is.null(dat)) {
        next
      }
      res <- tryCatch(
        merge_data(res, reformat_data(dat, ${"TRUE" if z_only else "FALSE"})),
        error = function(e) {
          message("Skipping gene due to lack of SNPs")
          # Keep what has been accumulated so far; returning the value of
          # message() here would reset the results to NULL.
          res
        }
      )
    }
    # Write the batch once, after every input file has been processed
    saveRDS(res, ${_output:r})


In [None]:
[processing_3]
input: group_by = "all"
output: f"{cwd}/{name}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container,stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', volumes = [f'{cwd:ad}:{cwd:ad}']
    # Row-bind every element of one batch below the element of the same name
    # in the accumulated list; the first batch is taken over unchanged.
    merge_data <- function(res, one_data) {
      if (length(res) == 0) {
        return(one_data)
      }
      for (element in names(one_data)) {
        stacked <- rbind(res[[element]], one_data[[element]])
        res[[element]] <- stacked
      }
      res
    }
    # Fold all batch files into one list, starting from an empty list, and save it
    dat <- Reduce(
      function(merged, batch_file) merge_data(merged, readRDS(batch_file)),
      c(${_input:r,}),
      list()
    )
    saveRDS(dat, ${_output:r})
 
bash: expand = "${ }", container = container,stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', volumes = [f'{cwd:ad}:{cwd:ad}']
    # Delete the intermediate cache directory. The path is quoted so that a
    # working directory or name containing spaces cannot split the argument
    # and make rm -rf act on an unintended path.
    rm -rf "${cwd}/${name}_cache/"