# Reformat the top signals from SuSiE
This notebook extracts strong, random, and null effects from the signals reported by SuSiE so that they can be fed into flashr (for the mixture prior) and mashr.

At the moment:
1. Rows containing NA are removed by default (`na_remove = TRUE`); pass `na_remove = FALSE` to keep them.
2. NaN beta values are set to 0.

In [None]:
# Launch the extraction step in the background and log everything to
# mash_extr_effects.with.na.log.
# To keep rows that contain NA, add `--na_remove FALSE` to the option list below.
# (Do not leave a commented-out option between the continued lines: the shell
# treats everything after `#` as a comment, so the trailing backslash is ignored,
# the command ends there, and the remaining options run as a separate command.)
nohup sos run ~/xqtl-pipeline-changed/pipeline/Signal_Extraction_from_SuSiE.V2.ipynb extract_effects_fromsusie    \
--name  Ast_Exc_Inh_Mic_OPC_Oli.with.na  \
--sum_stat '/mnt/vast/hpc/csg/rf2872/Work/MASH_test_csg/output/ALL_Ast_End_Exc_Inh_Mic_OPC_Oli.merged_rds.list'   \
--susie '/mnt/vast/hpc/csg/rf2872/Work/SuSiE_MASH/Extra_signals/SuSiE_output_signals_table.new'  \
--gene_ref '/mnt/vast/hpc/csg/molecular_phenotype_calling/reference_data/Homo_sapiens.GRCh38.103.chr.reformatted.collapse_only.gene.region_list'  \
--exclude_condition 1,3    \
-s force -J 20 -q csg2 -l t_pri -c ~/test/csg2.yml   &> mash_extr_effects.with.na.log &


## Global parameters

In [None]:
[global]
import os
# Work directory & output directory
parameter: cwd = path('./output')
# The filename prefix for output data
parameter: name = str
# NOTE(review): job_size is not referenced by the step shown in this notebook - confirm before removing
parameter: job_size = 1
# Container passed to the R action of the extraction step (empty string by default)
parameter: container = ''
# If non-empty, `$<table_name>` is appended to the readRDS() call, i.e. that
# element of the loaded object is used as the data
parameter: table_name = ""
# Names of the effect-size and standard-error elements used for the column-count
# check and for the z-score that defines null effects.
# NOTE(review): other parts of the R script access `bhat` / `sbhat` literally,
# so values other than the defaults are unlikely to work - confirm
parameter: bhat = "bhat"
parameter: sbhat = "sbhat"
# If > 0, input files whose bhat/sbhat do not have exactly this many columns are skipped
parameter: expected_ncondition = 0
## Column indices of conditions to drop from bhat, sbhat and Z.
## Exclusion is applied only when the first value is > 0, so pass 0 to exclude nothing.
## Note that the default below excludes columns 1 and 3.
parameter: exclude_condition = ["1","3"]
# NOTE(review): datadir is not referenced by the step shown in this notebook
parameter: datadir = ""
# Seed passed to set.seed() before sampling random and null effects
parameter: seed = 999
# Maximum number of random (non-signal) rows sampled per input file
parameter: n_random = 4
# Maximum number of null (|z| < 2) rows sampled per input file
parameter: n_null = 4
# If True only z-scores are saved; otherwise effects and standard errors are saved too
parameter: z_only = False
# Inserted verbatim into the R script as the na_remove argument ("TRUE" or "FALSE")
parameter: na_remove = "TRUE"
# Analysis units file with 1 column being the cell types
import pandas as pd
#parameter: analysis_units = path
# handle N = per_chunk data-set in one job
# NOTE(review): per_chunk is not referenced by the step shown in this notebook
parameter: per_chunk = 1000
#regions = [x.replace("\"","").strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
# SuSiE signal table; read with a header and merged on its `gene` column,
# the `signals` column is matched against row names of bhat
parameter: susie = path
# Header-less two-column list: gene (or gene symbol, see gene_ref_ncondition) and path to the RDS file
parameter: sum_stat = path
# Gene reference table, read only when gene_ref_ncondition > 0: its column V5 is
# matched to the first column of sum_stat and column V4 is used as the gene ID
parameter: gene_ref = ''
# Set > 0 to translate the first column of sum_stat through gene_ref
parameter: gene_ref_ncondition = 0

## Get top, random and null effects per analysis unit

In [None]:
# Extract data for MASH from the signals reported by SuSiE.
# Reads the SuSiE signal table and the summary-statistics file list, and writes
# the combined random/strong/null effects to {cwd}/{name}.rds plus the
# per-gene signal and outlier lists to {cwd}/{name}.sum.rds.
[extract_effects_fromsusie]
input: susie,sum_stat
output: f"{cwd}/{name}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }",stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', container = container
    # NOTE(review): no function from org.Hs.eg.db is called in the script below - confirm it is still needed
    library(org.Hs.eg.db)
    # dplyr provides %>% and filter() used when selecting the signals of one file
    library(dplyr)
    # NOTE(review): stringr is only referenced by a commented-out str_split() call below
    library(stringr)
    # Fix the RNG state so the sampled random and null rows are reproducible
    set.seed(${seed})
    # Locate the largest entry of a matrix/array.
    # Returns a one-row index matrix (one column per dimension of `mtx`) that
    # points at the first maximum found by which.max().
    matxMax <- function(mtx) {
      flat_pos <- which.max(mtx)
      arrayInd(flat_pos, dim(mtx))
    }
    # Drop the row names of every named element of the list `x` and return
    # the modified list. Elements are visited by name, so an empty or
    # unnamed list is returned unchanged.
    remove_rownames <- function(x) {
      element_names <- names(x)
      for (nm in element_names) {
        rownames(x[[nm]]) <- NULL
      }
      return(x)
    }
    # Replace values that would break downstream arithmetic:
    #   - NaN effects in x$bhat become 0
    #   - NaN or infinite standard errors in x$sbhat become 1E3
    # Plain NA values are left untouched. Returns the modified list.
    handle_nan_etc <- function(x) {
      nan_effects <- which(is.nan(x$bhat))
      x$bhat[nan_effects] <- 0
      unusable_se <- which(is.nan(x$sbhat) | is.infinite(x$sbhat))
      x$sbhat[unusable_se] <- 1E3
      x
    }
    
    # Flatten the random/strong/null groups of `dat` into one named list.
    # The layout follows
    # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb
    #
    # dat     list with elements random, strong and null, each holding
    #         `bhat` and `sbhat`
    # z_only  when TRUE only the z-scores (bhat / sbhat) are returned; when
    #         FALSE the effects (*.b) and standard errors (*.s) are appended
    reformat_data <- function(dat, z_only = TRUE) {
      z_score <- function(group) group$bhat / group$sbhat
      out <- list(
        random.z = z_score(dat$random),
        strong.z = z_score(dat$strong),
        null.z = z_score(dat$null)
      )
      if (!z_only) {
        effects_and_se <- list(
          random.b = dat$random$bhat,
          strong.b = dat$strong$bhat,
          null.b = dat$null$bhat,
          null.s = dat$null$sbhat,
          random.s = dat$random$sbhat,
          strong.s = dat$strong$sbhat
        )
        out <- c(out, effects_and_se)
      }
      out
    }
    # Append the tables of `one_data` to the accumulated result `res`.
    #   - empty `res`      -> `one_data` is returned as is
    #   - NULL `one_data`  -> `res` is returned unchanged
    #   - otherwise every non-NULL element of `one_data` is row-bound onto the
    #     element of `res` with the same name
    merge_data <- function(res, one_data) {
      if (length(res) == 0) {
        return(one_data)
      }
      if (is.null(one_data)) {
        return(res)
      }
      for (field in names(one_data)) {
        addition <- one_data[[field]]
        if (!is.null(addition)) {
          res[[field]] <- rbind(res[[field]], addition)
        }
      }
      res
    }
  
    # Pull the strong, random and null effects of one input file.
    #
    # Arguments:
    #   dat        list with the effect matrix `bhat` and the standard-error
    #              matrix `sbhat`; their row names are matched against the
    #              SuSiE signals
    #   sus.sum    data frame with (at least) the columns `path` and `signals`
    #   n_random   maximum number of non-signal rows sampled as "random"
    #   n_null     maximum number of rows with max |z| < 2 sampled as "null"
    #   infile     path `dat` was read from; selects the rows of `sus.sum`
    #   na_remove  if TRUE, rows with an NA in either matrix are dropped
    #
    # Returns NULL when `dat` is NULL, otherwise a list with
    #   dat      list(random, null, strong); each holds bhat/sbhat without row
    #            names and after handle_nan_etc() (null is empty when no row
    #            qualifies)
    #   signals  SuSiE signals of this file that are present in bhat
    #   outlier  SuSiE signals of this file that are absent from bhat
    extract_susiesignal <- function(dat, sus.sum, n_random, n_null, infile, na_remove = TRUE) {
      if (is.null(dat)) return(NULL)
      if (na_remove == TRUE) {
        # Drop the same rows from both matrices. Filtering them independently
        # can leave bhat and sbhat with different row sets, which breaks the
        # row-name lookups and the bhat / sbhat division below.
        if (NROW(dat$bhat) != NROW(dat$sbhat)) {
          stop("bhat and sbhat have different numbers of rows in ", infile)
        }
        complete_rows <- complete.cases(dat$bhat) & complete.cases(dat$sbhat)
        dat$bhat <- dat$bhat[complete_rows, , drop = FALSE]
        dat$sbhat <- dat$sbhat[complete_rows, , drop = FALSE]
      }
      signals.su <- sus.sum %>% filter(path %in% infile) %>% .[, "signals"]
      signals <- intersect(signals.su, rownames(dat$bhat))
      outlier <- setdiff(signals.su, rownames(dat$bhat))
      strong <- list(bhat = dat$bhat[signals, , drop = FALSE], sbhat = dat$sbhat[signals, , drop = FALSE])
      # Random rows are drawn from everything that is not a signal.
      sample_idx <- setdiff(rownames(dat$bhat), signals)
      random_idx <- sample(sample_idx, min(n_random, length(sample_idx)), replace = FALSE)
      random <- list(bhat = dat$bhat[random_idx, , drop = FALSE], sbhat = dat$sbhat[random_idx, , drop = FALSE])
      # NOTE(review): only this line honours the configurable element names; the
      # rest of the function reads dat$bhat / dat$sbhat literally.
      z <- abs(dat$${bhat}/dat$${sbhat})
      # null samples defined as |z| < 2
      null.id <- which(apply(z, 1, max) < 2)
      if (length(null.id) == 0) {
        warning(paste("Null data is empty for input file", infile))
        null <- list()
      } else {
        # Sample positions within null.id instead of calling sample(null.id, ...):
        # for a single numeric value k, sample() draws from 1:k rather than
        # returning k, which would select an arbitrary non-null row.
        picked <- sample.int(length(null.id), min(n_null, length(null.id)), replace = FALSE)
        null_idx <- null.id[picked]
        null <- list(bhat = dat$bhat[null_idx, , drop = FALSE], sbhat = dat$sbhat[null_idx, , drop = FALSE])
      }
      dat <- list(random = remove_rownames(random), null = remove_rownames(null), strong = remove_rownames(strong))
      dat$random <- handle_nan_etc(dat$random)
      dat$null <- handle_nan_etc(dat$null)
      dat$strong <- handle_nan_etc(dat$strong)
      com.dat <- list(dat = dat, signals = signals, outlier = outlier)
      return(com.dat)
    }
    
    # Intersect the output from SuSiE with the summary-statistics file list.
    # sus.out: SuSiE signal table (with header); sum.out: header-less two-column list.
    sus.out <- read.table(${_input[0]:r,}, header = TRUE)
    sum.out <- read.table(${_input[1]:r,}, header = FALSE)

    if (${gene_ref_ncondition} > 0) {
      # The first column of the list holds gene symbols: look each one up in
      # column V5 of the reference and take the gene ID from column V4.
      colnames(sum.out) <- c("gene_symbol", "path")
      gene.ref <- read.table("${gene_ref}")
      sum.out$gene <- gene.ref[match(sum.out$gene_symbol, gene.ref$V5), "V4"]
    } else {
      colnames(sum.out) <- c("gene", "path")
    }

    sus.sum <- merge(sus.out, sum.out, by = "gene")
    files <- unique(sus.sum$path)
    #sus.sum<-sus.sum[-which(duplicated(sus.sum$path)),] 
    #files<-sus.sum$path
    res <- list()
    signals.df <- NULL
    outlier.df <- NULL
    # Column of sum.out that names the gene in the per-gene summary tables.
    gene_col <- if (${gene_ref_ncondition} > 0) "gene_symbol" else "gene"
    for (f in files) {
      # If cannot read the input for some reason then we just skip it, assuming we have other enough data-sets to use.
      dat <- tryCatch(readRDS(f), error = function(e) return(NULL))${("$"+table_name) if table_name != "" else ""}
      if (is.null(dat)) {
          message(paste("Skip loading file", f, "due to load failure."))
          next
      }
      if (${expected_ncondition} > 0 && (ncol(dat$${bhat}) != ${expected_ncondition} || ncol(dat$${sbhat}) != ${expected_ncondition})) {
          message(paste("Skip loading file", f, "because it has", ncol(dat$${bhat}), "columns different from required", ${expected_ncondition}))
          next
      }
      if (c(${",".join(exclude_condition)})[1] > 0) {
        message(paste("Excluding condition ${exclude_condition} from the analysis"))
        # drop = FALSE keeps the matrix shape (and its row names) even when a
        # single condition is left after the exclusion.
        dat$bhat <- dat$bhat[, -c(${",".join(exclude_condition)}), drop = FALSE]
        dat$sbhat <- dat$sbhat[, -c(${",".join(exclude_condition)}), drop = FALSE]
        if (!is.null(dat$Z)) {
          dat$Z <- dat$Z[, -c(${",".join(exclude_condition)}), drop = FALSE]
        }
      }

      com.dat <- tryCatch(
        extract_susiesignal(dat, sus.sum, ${n_random}, ${n_null}, f, ${na_remove}),
        error = function(e) {
          # Report why the file is dropped instead of discarding the error silently.
          message(paste("Skip file", f, "because signal extraction failed:", conditionMessage(e)))
          NULL
        }
      )
      if (is.null(com.dat)) {
        next
      }
      res <- tryCatch(
        merge_data(res, reformat_data(com.dat$dat, ${"TRUE" if z_only else "FALSE"})),
        error = function(e) {
          message("Skipping gene due to lack of SNPs")
          # Hand back what has been accumulated so far; returning the (NULL)
          # value of message() would wipe the results of all previous files.
          res
        }
      )

      #genename<-str_split(f,"[.]")%>%unlist%>%.[(length(.)-1)]
      genename <- sum.out[sum.out$path == f, gene_col]
      signals.df <- rbind(signals.df, data.frame(gene = rep(genename, length(com.dat$signals)), signals = com.dat$signals))
      outlier.df <- rbind(outlier.df, data.frame(gene = rep(genename, length(com.dat$outlier)), signals = com.dat$outlier))
    }
    # compute empirical covariance XtX
    X <- res$strong.z
    X[is.na(X)] <- 0
    res$XtX <- t(as.matrix(X)) %*% as.matrix(X) / nrow(X)

    # Write the combined effects once, plus the per-gene signal/outlier tables.
    res.sum <- list(signals = signals.df, outlier = outlier.df)
    saveRDS(res, ${_output:r})
    saveRDS(res.sum, "${_output:n}.sum.rds")