# Top signal extraction
This notebook extracts strong, random, and null signals from the summary-statistics RDS file of each gene so that they can be fed into flashr (to build the mixture prior) and mashr.

At the moment:

1. SNPs with NA summary statistics are removed.

2. NaN effect sizes (bhat) are set to 0.

3. NaN and infinite standard errors (sbhat) are set to 1E3.
      
4. The NaN handling in items 2 and 3 follows the original design in bioworkflow/mixture_prior.

## Global parameters

In [None]:
[global]
import os
# Work directory & output directory
parameter: cwd = path('./output')
# The filename prefix for output data
parameter: name = str
parameter: job_size = 1
parameter: container = ''
# Name of the element inside each input RDS that holds the summary statistics;
# "" means the RDS object itself is used
parameter: table_name = ""
# Names of the effect-size and standard-error matrices inside the summary statistics
parameter: bhat = "bhat"
parameter: sbhat = "sbhat"
# If > 0, input files whose bhat/sbhat do not have exactly this many columns are skipped
parameter: expected_ncondition = 0
# 1-based column indices of conditions to exclude from the analysis.
# To exclude nothing, pass 0 (a first element <= 0 disables the exclusion).
# NOTE(review): the default below excludes conditions 1 and 3 (not 0, as an earlier
# comment claimed) -- confirm this is the intended default.
parameter: exclude_condition = ["1","3"]
parameter: datadir = ""
# Seed for the R random number generator used to sample random and null rows
parameter: seed = 999
# Maximum number of random / null rows sampled per input file
parameter: n_random = 4
parameter: n_null = 4
# If True only z-scores are saved; otherwise bhat and sbhat are saved as well
parameter: z_only = False
# Analysis units file; its whitespace-separated fields are used as the input RDS files of extract_effects_1
import pandas as pd
parameter: analysis_units = path
# handle N = per_chunk data-set in one job
parameter: per_chunk = 1000
# One entry per non-empty, non-comment line: its whitespace-separated fields with double quotes removed.
# read_text() closes the file immediately instead of leaving the handle to the garbage collector.
regions = [x.replace("\"","").strip().split() for x in path(analysis_units).read_text().splitlines() if x.strip() and not x.strip().startswith('#')]

## Get top, random and null effects per analysis unit

In [None]:
# extract data for MASH from summary stats
# Each substep handles per_chunk input RDS files: from every file it takes the strongest row,
# up to n_random random rows and up to n_null null rows (|z| < 2 in every condition), and
# saves the merged result as one cache RDS per chunk.
[extract_effects_1]
input: regions , group_by = per_chunk
output: f"{cwd}/{name}/cache/{name}_{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }",stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', container = container
    # Fix the RNG so the random and null row sampling below is reproducible for a given seed
    set.seed(${seed})
    # Position (row, column) of the largest entry of mtx, returned as a 1 x 2 matrix.
    # which.max() skips NA/NaN, so an all-NA input gives a 0-row matrix.
    matxMax <- function(mtx) {
      flat_pos <- which.max(mtx)
      arrayInd(flat_pos, .dim = dim(mtx))
    }
    # Strip the row names of every named element of list x (column names are kept),
    # so that later rbind() calls produce matrices without row names.
    remove_rownames <- function(x) {
        for (elem_name in names(x)) {
            rownames(x[[elem_name]]) <- NULL
        }
        x
    }
    # Replace NaN effect sizes with 0 and NaN or infinite standard errors with 1E3.
    # Plain NA values are left untouched (is.nan() is FALSE for NA).
    handle_nan_etc <- function(x) {
      nan_beta <- is.nan(x$bhat)
      x$bhat[nan_beta] <- 0
      bad_se <- is.nan(x$sbhat) | is.infinite(x$sbhat)
      x$sbhat[bad_se] <- 1E3
      x
    }
    # Draw up to `size` elements of x at random, without replacement.
    # Unlike sample(x, size), a length-1 x is never expanded to 1:x, so only elements
    # of x can be returned (an empty x gives an empty result). For length(x) > 1 the
    # draw and RNG consumption match sample(x, size, replace = FALSE).
    sample_from <- function(x, size) {
        x[sample.int(length(x), min(size, length(x)), replace = FALSE)]
    }
    # Extract the strongest row, up to n_random random rows and up to n_null null rows
    # from one input's summary statistics.
    #   dat      : list whose bhat-parameter element holds effect sizes and whose
    #              sbhat-parameter element holds standard errors (rows = SNPs,
    #              columns = conditions); NULL is passed through as NULL.
    #   n_random : maximum number of random rows; the strongest row is never drawn.
    #   n_null   : maximum number of null rows (|z| < 2 in every condition).
    #   infile   : input file name, used only in the "null data is empty" warning.
    # Returns list(random, null, strong), each a list(bhat, sbhat) with row names removed
    # and NaN/infinite values replaced by handle_nan_etc(); NULL when no finite z exists.
    # Signals an error when no SNP survives NA filtering.
    extract_one_data <- function(dat, n_random, n_null, infile) {
        if (is.null(dat)) return(NULL)
        na.info <- list()
        na.info$n_bhat_ori <- nrow(dat$${bhat})
        na.info$n_sbhat_ori <- nrow(dat$${sbhat})
        # One shared row mask keeps bhat and sbhat aligned; removing NA rows from each
        # matrix separately could drop different SNPs from the two matrices.
        keep <- complete.cases(dat$${bhat}, dat$${sbhat})
        dat$${bhat} <- dat$${bhat}[keep, , drop = FALSE]
        dat$${sbhat} <- dat$${sbhat}[keep, , drop = FALSE]
        na.info$n_bhat <- nrow(dat$${bhat})
        na.info$n_sbhat <- nrow(dat$${sbhat})
        msg <- paste(c("Out of ", na.info$n_bhat_ori, " SNP, ", na.info$n_bhat, " was retained for analysis"), collapse = "")
        message(msg)
        if (na.info$n_bhat == 0) {
          stop("None of the SNP was retained for analysis, skipping genes")
        }
        z <- abs(dat$${bhat} / dat$${sbhat})
        max_idx <- matxMax(z)
        # which.max() skips NaN (e.g. 0/0), so an all-NaN z yields an empty index, not NULL
        if (length(max_idx) == 0) return(NULL)
        top_row <- max_idx[1]
        # strong effect samples
        strong <- list(bhat = dat$${bhat}[top_row, , drop = FALSE],
                       sbhat = dat$${sbhat}[top_row, , drop = FALSE])
        # random samples excluding the top one (setdiff also handles a single-row z)
        random_idx <- sample_from(setdiff(seq_len(nrow(z)), top_row), n_random)
        random <- list(bhat = dat$${bhat}[random_idx, , drop = FALSE],
                       sbhat = dat$${sbhat}[random_idx, , drop = FALSE])
        # null samples defined as |z| < 2 in every condition
        null.id <- which(apply(z, 1, max) < 2)
        if (length(null.id) == 0) {
          warning(paste("Null data is empty for input file", infile))
          null <- list()
        } else {
          null_idx <- sample_from(null.id, n_null)
          null <- list(bhat = dat$${bhat}[null_idx, , drop = FALSE],
                       sbhat = dat$${sbhat}[null_idx, , drop = FALSE])
        }
        parts <- list(random = random, null = null, strong = strong)
        lapply(parts, function(part) handle_nan_etc(remove_rownames(part)))
    }
    # Convert the extracted random/strong/null parts into the layout used by
    # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb
    # Always returns random.z, strong.z, null.z (bhat / sbhat); when z_only is FALSE it also
    # appends random.b, strong.b, null.b, null.s, random.s, strong.s, in that order.
    reformat_data <- function(dat, z_only = TRUE) {
        z_of <- function(part) part$bhat / part$sbhat
        out <- list(random.z = z_of(dat$random),
                    strong.z = z_of(dat$strong),
                    null.z = z_of(dat$null))
        if (z_only) return(out)
        extras <- list(random.b = dat$random$bhat,
                       strong.b = dat$strong$bhat,
                       null.b = dat$null$bhat,
                       null.s = dat$null$sbhat,
                       random.s = dat$random$sbhat,
                       strong.s = dat$strong$sbhat)
        c(out, extras)
    }
    # Row-bind each non-NULL element of one_data onto the element of res with the same name.
    # An empty res is replaced by one_data; a NULL one_data leaves res unchanged.
    merge_data <- function(res, one_data) {
      if (length(res) == 0) return(one_data)
      if (is.null(one_data)) return(res)
      for (key in names(one_data)) {
        if (!is.null(one_data[[key]])) {
          res[[key]] <- rbind(res[[key]], one_data[[key]])
        }
      }
      res
    }
    # Accumulate reformatted effects of every input file of this chunk into res.
    res <- list()
    for (f in c(${_input:r,})) {
      # If cannot read the input for some reason then we just skip it, assuming we have other enough data-sets to use.
      dat <- tryCatch(readRDS(f), error = function(e) return(NULL))${("$"+table_name) if table_name != "" else ""}
      if (is.null(dat)) {
        message(paste("Skip loading file", f, "due to load failure."))
        next
      }
      if (${expected_ncondition} > 0 && (ncol(dat$${bhat}) != ${expected_ncondition} || ncol(dat$${sbhat}) != ${expected_ncondition})) {
        message(paste("Skip loading file", f, "because it has", ncol(dat$${bhat}), "columns different from required", ${expected_ncondition}))
        next
      }
      # A first element <= 0 in exclude_condition disables the exclusion
      if (c(${",".join(exclude_condition)})[1] > 0) {
        message(paste("Excluding condition ${exclude_condition} from the analysis"))
        drop_cols <- -c(${",".join(exclude_condition)})
        # Subset the configured bhat/sbhat elements (the ones analysed below); drop = FALSE
        # keeps a matrix even when only one condition remains.
        dat$${bhat} <- dat$${bhat}[, drop_cols, drop = FALSE]
        dat$${sbhat} <- dat$${sbhat}[, drop_cols, drop = FALSE]
        # Z is optional: subsetting a missing (NULL) element would abort the whole job
        if (!is.null(dat$Z)) dat$Z <- dat$Z[, drop_cols, drop = FALSE]
      }
      one_data <- tryCatch(
        extract_one_data(dat, ${n_random}, ${n_null}, f),
        error = function(e) {
          message(paste("Skipping file", f, "-", conditionMessage(e)))
          NULL
        })
      if (is.null(one_data)) next
      # On failure keep the results merged so far; returning message()'s NULL here
      # would silently discard every previously merged file.
      res <- tryCatch(
        merge_data(res, reformat_data(one_data, ${"TRUE" if z_only else "FALSE"})),
        error = function(e) {
          message(paste("Skipping gene due to lack of SNPs:", conditionMessage(e)))
          res
        })
    }

    saveRDS(res, ${_output:r})

In [None]:
# Merge the per-chunk cache RDS files from the previous step into one list (rbind per element)
# and add XtX = t(X) %*% X / nrow(X), where X is strong.z with NA entries set to 0.
[extract_effects_2]
input: group_by = "all"
output: f"{cwd}/{name}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container,stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', volumes = [f'{cwd:ad}:{cwd:ad}']
    # Row-bind every element of one_data onto the element of res with the same name;
    # an empty res is simply replaced by one_data.
    merge_data <- function(res, one_data) {
      if (length(res) == 0) return(one_data)
      for (key in names(one_data)) {
        res[[key]] <- rbind(res[[key]], one_data[[key]])
      }
      res
    }
    # Merge all per-chunk caches into one list of matrices.
    dat <- list()
    for (f in c(${_input:r,})) {
      dat <- merge_data(dat, readRDS(f))
    }
    # If every chunk failed there is no strong.z; computing XtX on it would silently
    # save a NaN matrix, so fail with an explicit error instead.
    if (length(dat$strong.z) == 0) {
      stop("No strong effects were extracted from any input; cannot compute XtX")
    }
    # compute empirical covariance XtX (uncentered): crossprod(X) equals t(X) %*% X
    X <- as.matrix(dat$strong.z)
    X[is.na(X)] <- 0
    dat$XtX <- crossprod(X) / nrow(X)
    saveRDS(dat, ${_output:r})