# Whole genome prior computation

This notebook computes genome-wide, data-driven prior covariance matrices for a TWAS analysis workflow based on multivariate SuSiE.

## Aim

TBD

## Overview (TBD)

__Objective__: 
    To compute the association between gene expression and SNPs, in preparation for TWAS analysis, using genotypes and multiple molecular phenotypes.

__Background__:
    SNPs can modulate functional phenotypes both directly and through the expression levels of genes. Therefore, integrating expression measurements with larger-scale GWAS summary association statistics can help identify genes associated with the targeted complex traits.

__Significance__:
    By applying this method, new candidate genes whose expression levels are significantly associated with complex traits can be identified and used in prediction without the expensive step of measuring gene expression in every individual. A relatively small data set with both gene expression and genotypes can be used to impute expression for a much larger set of phenotyped individuals from their SNP genotypes.

__Method__:
    The imputed expression can then be viewed as a linear model of genotypes with _weights based on the correlation between SNPs and gene expression_ in the training data, while accounting for linkage disequilibrium (LD) among SNPs. We then correlate the imputed gene expression with the trait to perform a transcriptome-wide association study (TWAS) and identify significant expression-trait associations.
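As a minimal illustration of the linear model above (not the pipeline's implementation), the imputed expression of one gene is a weighted sum of genotype dosages. The association step then correlates that imputed expression with the trait. The genotypes, weights and trait values below are made-up toy numbers. A plain Pearson correlation stands in for the actual TWAS test statistic.

```python
def impute_expression(genotypes, weights):
    """Predicted expression per individual: sum over SNPs of dosage * weight."""
    return [sum(g * w for g, w in zip(row, weights)) for row in genotypes]


def pearson(x, y):
    """Pearson correlation between two equal-length sequences."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / (vx * vy) ** 0.5


# Toy data: 4 individuals x 3 SNPs (dosages 0/1/2) and hypothetical SNP weights
genotypes = [[0, 1, 2], [1, 1, 0], [2, 0, 1], [0, 2, 2]]
weights = [0.5, -0.2, 0.3]
trait = [1.1, 0.4, 1.6, 0.9]

expr = impute_expression(genotypes, weights)
r = pearson(expr, trait)
print(expr, round(r, 3))
```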
 
The weights are computed with multivariate SuSiE. Their accuracy is assessed by five-fold cross-validation, repeated 100 times by default.
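The repeated cross-validation scheme can be sketched as follows. In each repeat the individuals are shuffled and split into five folds, and each fold is held out once as the test set. This is a generic sketch that assumes random, non-stratified folds. How the univariate pipeline actually assigns folds is not specified here, and the model fitting on each split is left out.

```python
import random


def repeated_kfold(n_samples, k=5, n_repeats=100, seed=1):
    """Yield (repeat, fold, train_idx, test_idx) for repeated k-fold CV."""
    rng = random.Random(seed)
    for r in range(n_repeats):
        idx = list(range(n_samples))
        rng.shuffle(idx)
        # Fold f takes every k-th shuffled index, so fold sizes differ by at most one
        folds = [sorted(idx[f::k]) for f in range(k)]
        for f, test in enumerate(folds):
            test_set = set(test)
            train = [i for i in range(n_samples) if i not in test_set]
            yield r, f, train, test


splits = list(repeated_kfold(n_samples=12, k=5, n_repeats=2))
print(len(splits))  # 10 splits: 2 repeats x 5 folds
```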


## Pre-requisites

We provide a container image `docker://gaow/twas` that contains all the software needed to run the pipeline. If you would like to configure your own environment, please install the following software before running this notebook:
- tidyverse
- PLINK
- R package mashr
- R package mmbr
- R packages flashr, mixsqp and udr (loaded by the FLASH and ED steps)
- Output from the following univariate analysis pipeline: twas_fusion_susie.ipynb

# Input and Output (TBD)
## Input


# Command interface (TBD)

# Working example 
A minimal working example (MWE) dataset can be downloaded from the following link (access required): https://drive.google.com/drive/u/0/folders/1N3PbH9hfv5eAikGHI58sXjVsssFkekWj

To test the pipeline, download the mwe folder and run the first of the following commands from within the mwe/data folder.

The MWE should take around 5 minutes to run.

In [3]:
# Test the pipeline with MWE


nohup sos run /home/hs3163/GIT/neuro-twas/Workflow/wg_prior_genome.ipynb mm_prior \
--molecular_pheno_dir ./molc_dir/    \
--rds_list ./rds_list_mwe  \
--wd   ./ \
--name "geneTpmResidualsAgeGenderAdj_rename" \
--constraint "Both" \
--container /mnt/mfs/statgen/containers/twas_latest.sif &





## Full mm_prior workflow on the test RDS list.

nohup sos run ~/GIT/neuro-twas/Workflow/wg_prior_genome.ipynb mm_prior \
--molecular_pheno_dir /home/hs3163/Project/Genome_prior/data/molc_dir    \
--rds_list /home/hs3163/Project/Genome_prior/data/rds_list_test  \
--wd   /home/hs3163/Project/Genome_prior/test \
--name "geneTpmResidualsAgeGenderAdj_rename" \
--container /mnt/mfs/statgen/containers/twas_latest.sif \
--constraint "Both" -s build &



## FLASH step only, on the full RDS list, submitting jobs to the csg cluster queue.

nohup sos run ~/GIT/neuro-twas/Workflow/wg_prior_genome.ipynb flash \
--molecular_pheno_dir /home/hs3163/Project/Genome_prior/data/molc_dir    \
--rds_list /home/hs3163/Project/Genome_prior/data/rds_list  \
--wd   /home/hs3163/Project/Genome_prior/merge \
--name "geneTpmResidualsAgeGenderAdj_rename" \
--container /mnt/mfs/statgen/containers/twas_latest.sif \
--constraint "Both" \
-J 10 -q csg -c /mnt/mfs/statgen/pbs_template/csg.yml -s build &


## FLASH step only, on the full RDS list, running locally.

nohup sos run ~/GIT/neuro-twas/Workflow/wg_prior_genome.ipynb flash \
--molecular_pheno_dir /home/hs3163/Project/Genome_prior/data/molc_dir    \
--rds_list /home/hs3163/Project/Genome_prior/data/rds_list  \
--wd   /home/hs3163/Project/Genome_prior/merge \
--name "geneTpmResidualsAgeGenderAdj_rename" \
--container /mnt/mfs/statgen/containers/twas_latest.sif \
--constraint "Both" \
-s build &






# Global parameter settings
This section outlines the parameters that can be set from the command line.

In [5]:
[global]
# Path to a tab-delimited file whose `#molc_pheno` column lists the directory of RDS files for each molecular phenotype (condition)
parameter: molecular_pheno_dir = path

# Path to a file listing the RDS files (one region per line) to be combined and analyzed
parameter: rds_list = path
# Path to the work directory of this pipeline, where the output will be stored
parameter: wd = path
# Number of tasks bundled into one job (passed to trunk_size)
parameter: job_size = 2
# Container option for software to run the analysis: docker or singularity
parameter: container = 'gaow/twas'

# Input data directory 
parameter: datadir = path('{wd:a}/input')
# Work directory & output directory
parameter: cwd = path('{wd:a}/output')
# The filename prefix for output data
parameter: name = str
# Number of region RDS files handled together in one job of the extraction step
parameter: per_chunk = 1000


# Whether to run FLASH with the constraint, without it, or both: 'Constraint', 'No_Constraint' or 'Both' (default)
parameter: constraint = 'Both'


import glob
# Read the region RDS list and the molecular phenotype list, skipping blank lines and lines starting with '#'
regions = [x.strip().split() for x in open(rds_list).readlines() if x.strip() and not x.strip().startswith('#')]
molecular_pheno = [x.strip().split() for x in open(molecular_pheno_dir).readlines() if x.strip() and not x.strip().startswith('#')]

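The `[global]` section parses both list files with the same rule. Blank lines and lines starting with `#` are skipped, and each remaining line is split on whitespace. Only the first field of each `rds_list` line (`_regions[0]`) is used later as the RDS file name. The standalone version below uses hypothetical file contents. One consequence: a header line such as `#molc_pheno` is dropped by this rule. The R code in `mm_prior_1`, however, reads that same line as the column name.

```python
def read_list(text):
    """Parse a list file the way the [global] section does."""
    return [line.strip().split() for line in text.splitlines()
            if line.strip() and not line.strip().startswith('#')]


# Hypothetical rds_list content; the file names are illustrative only
rds_list_text = """# region RDS files
chr1_ENSG000001.rds

chr2_ENSG000002.rds extra_field
"""
regions = read_list(rds_list_text)
print([r[0] for r in regions])
```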
In [1]:
[mm_prior_1,merge_1]
input:  molecular_pheno_dir, for_each = "regions"
output: f'{wd:a}/input/{_regions[0]}'

task: trunk_workers = 1, trunk_size = job_size, walltime = '4h',  mem = '6G', tags = f'{step_name}_{_output[0]:bn}'  

R: expand = "$[ ]", stderr = f'{_output}.stderr', stdout = f'{_output}.stdout',container = container
    library("dplyr")
    library("tibble")
    library("purrr")
    library("readr")
    molecular_pheno = read_delim("$[molecular_pheno_dir]", delim = "\t")
    # Full path of this region's RDS file under each condition's directory
    molecular_pheno = molecular_pheno %>% mutate(dir = map_chr(`#molc_pheno`, ~paste0(.x, "$[_regions[0]]")))
    n = nrow(molecular_pheno)
    # For every condition, read the RDS once and extract bhat and sbhat as data frames keyed by SNP ID.
    # Value columns are prefixed with the condition index so names stay unique after joining.
    read_condition = function(i) {
      rds = readRDS(molecular_pheno$dir[i])
      to_df = function(x) x %>% as.data.frame %>% rownames_to_column %>%
        rename_with(~paste0("cond", i, "_", .x), -rowname)
      list(bhat = to_df(rds$bhat), sbhat = to_df(rds$sbhat))
    }
    genos = map(seq_len(n), read_condition)
    # Full-join all conditions by SNP ID, then convert to a SNP x condition matrix
    join_all = function(dfs) {
      reduce(dfs, full_join, by = "rowname") %>% column_to_rownames("rowname") %>% as.matrix
    }
    sumstats = list(bhat = join_all(map(genos, "bhat")), sbhat = join_all(map(genos, "sbhat")))
    # save the rds file
    saveRDS(file = "$[_output]", list(sumstats = sumstats))



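Here is what `mm_prior_1` does for one region, in miniature. Each condition contributes a SNP-to-effect mapping, and a full outer join on SNP ID yields a SNP x condition matrix. Entries are missing wherever a SNP was not tested in a condition (R's `NA`, `None` here). The SNP IDs and values are invented for illustration. The row order mirrors `full_join`: SNPs from the first table first, then unmatched SNPs in the order they appear.

```python
def full_join(conditions):
    """Full outer join of {snp: value} dicts into (snp_ids, rows)."""
    snps, seen = [], set()
    for cond in conditions:
        for snp in cond:
            if snp not in seen:
                seen.add(snp)
                snps.append(snp)
    # One row per SNP, one column per condition; None marks a missing value
    rows = [[cond.get(snp) for cond in conditions] for snp in snps]
    return snps, rows


bhat_cond1 = {"rs1": 0.10, "rs2": -0.30}
bhat_cond2 = {"rs2": -0.25, "rs3": 0.05}
bhat_cond3 = {"rs1": 0.12, "rs3": 0.02}
snps, bhat = full_join([bhat_cond1, bhat_cond2, bhat_cond3])
for snp, row in zip(snps, bhat):
    print(snp, row)
```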
In [None]:
[mm_prior_2,extract_1]

parameter: seed = 999
parameter: n_random = 4
parameter: n_null = 4

input: glob.glob(f'{wd:a}/input/*.rds'), group_by = per_chunk
output: f'{wd:a}/cache/{name}_{_index+1}.rds'

task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container
    set.seed(${seed})
    matxMax <- function(mtx) {
      return(arrayInd(which.max(mtx), dim(mtx)))
    }
    remove_rownames = function(x) {
        for (name in names(x)) rownames(x[[name]]) = NULL
        return(x)
    }
    extract_one_data = function(infile, n_random, n_null) {
        # If cannot read the input for some reason then we just skip it, assuming we have enough other data-sets to use.
        dat = tryCatch(readRDS(infile)$sumstats, error = function(e) return(NULL))
        if (is.null(dat)) return(NULL)
        z = abs(dat$bhat/dat$sbhat)
        max_idx = matxMax(z)
        # strong effect samples
        strong = list(bhat = dat$bhat[max_idx[1],,drop=F], sbhat = dat$sbhat[max_idx[1],,drop=F])
        # random samples excluding the top one
        sample_idx = setdiff(seq_len(nrow(z)), max_idx[1])
        # index via sample.int() so a single candidate is not expanded to 1:n by sample()
        random_idx = sample_idx[sample.int(length(sample_idx), n_random, replace = TRUE)]
        random = list(bhat = dat$bhat[random_idx,,drop=F], sbhat = dat$sbhat[random_idx,,drop=F])
        # null samples defined as |z| < 2 in every condition
        null.id = which(apply(z, 1, max) < 2)
        null_idx = null.id[sample.int(length(null.id), min(n_null, length(null.id)))]
        null = list(bhat = dat$bhat[null_idx,,drop=F], sbhat = dat$sbhat[null_idx,,drop=F])
        return(list(file = infile,random = remove_rownames(random), null = remove_rownames(null), strong = remove_rownames(strong)))
    }
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else if (is.null(one_data)) {
          return(res)
      } else {
          for (d in names(one_data)) {
              for (s in names(one_data[[d]])) {
                  res[[d]][[s]] = rbind(res[[d]][[s]], one_data[[d]][[s]])
              }
          }
          return(res)
      }
    }
    res = list()
    for (f in c(${_input:r,})) {
      res = merge_data(res, extract_one_data(f, ${n_random}, ${n_null}))
    }
    saveRDS(res, ${_output:r})

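The selection logic of `extract_one_data` can be sketched in Python on a toy z-score matrix:

- The row holding the largest |z| is the "strong" effect.
- `n_random` rows are drawn with replacement from the remaining rows.
- `n_null` rows are drawn without replacement from rows whose largest |z| is below 2.

Capping the null draw at the number of qualifying rows keeps the sketch from failing when few SNPs qualify. Python's generator will not reproduce R's draws for the same seed.

```python
import random


def extract_one(z, n_random, n_null, rng):
    """Return (strong, random, null) row indices from a SNP x condition z matrix."""
    abs_max = [max(abs(v) for v in row) for row in z]
    # Strong: the row holding the single largest |z| anywhere in the matrix
    strong = max(range(len(z)), key=lambda i: abs_max[i])
    # Random: sampled with replacement from every other row
    others = [i for i in range(len(z)) if i != strong]
    random_rows = [rng.choice(others) for _ in range(n_random)]
    # Null: rows whose largest |z| is below 2, sampled without replacement
    null_pool = [i for i in range(len(z)) if abs_max[i] < 2]
    null_rows = rng.sample(null_pool, min(n_null, len(null_pool)))
    return strong, random_rows, null_rows


z = [[0.5, -1.2], [4.1, 3.0], [-0.3, 0.8], [2.5, -0.1], [1.9, 1.1]]
strong, random_rows, null_rows = extract_one(z, n_random=4, n_null=2, rng=random.Random(999))
print(strong, random_rows, null_rows)
```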
In [None]:
[mm_prior_3,extract_2]
input: group_by = "all"
output: f'{wd:a}/output/{name}.rds'
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '6G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else {
          for (d in names(one_data)) {
              for (s in names(one_data[[d]])) {
                  res[[d]][[s]] = rbind(res[[d]][[s]], one_data[[d]][[s]])
              }
          }
          return(res)
      }
    }
    dat = list()
    for (f in c(${_input:r,})) {
      dat = merge_data(dat, readRDS(f))
    }
    # make output consistent in format with 
    # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb
    saveRDS(
          list(random.z = dat$random$bhat/dat$random$sbhat,
           strong.z = dat$strong$bhat/dat$strong$sbhat,
           null.z = dat$null$bhat/dat$null$sbhat,
           random.b = dat$random$bhat,
           strong.b = dat$strong$bhat,
           null.b = dat$null$bhat,
           null.s = dat$null$sbhat,
           random.s = dat$random$sbhat,
           strong.s = dat$strong$sbhat,
           file = c(${_input:r,})
              ),
          ${_output:r})

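Steps `mm_prior_2` and `mm_prior_3` follow a split/merge pattern:

1. The per-region RDS files are processed in groups of `per_chunk` (`group_by = per_chunk`).
2. Each group's extracted rows are written to `{name}_{i}.rds`.
3. The final step stacks the groups' rows and divides `bhat` by `sbhat` to get z-scores.

The pure-Python sketch below uses lists of rows as stand-ins for R matrices. The file names are hypothetical.

```python
def chunk(items, size):
    """Split items into consecutive groups of at most `size` (like SoS group_by=N)."""
    return [items[i:i + size] for i in range(0, len(items), size)]


def stack(parts):
    """Row-bind a list of {'bhat': rows, 'sbhat': rows} dicts, like repeated rbind()."""
    merged = {"bhat": [], "sbhat": []}
    for part in parts:
        for key in merged:
            merged[key].extend(part[key])
    return merged


def z_scores(bhat, sbhat):
    """Element-wise bhat / sbhat."""
    return [[b / s for b, s in zip(brow, srow)] for brow, srow in zip(bhat, sbhat)]


files = [f"region_{i}.rds" for i in range(1, 6)]              # hypothetical inputs
groups = chunk(files, 2)
outputs = [f"prior_{i + 1}.rds" for i in range(len(groups))]  # {name}_{_index+1}.rds

parts = [{"bhat": [[0.2, -0.4]], "sbhat": [[0.1, 0.2]]},
         {"bhat": [[0.3, 0.0]], "sbhat": [[0.15, 0.5]]}]
merged = stack(parts)
z = z_scores(merged["bhat"], merged["sbhat"])
print(groups, outputs, z)
```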
In [None]:
# Perform FLASH analysis (time estimate: 20min)
[mm_prior_4,flash]
# default method for convex optimization
parameter: optmethod = "mixSQP"
parameter: flash_optmethod = "mixSQP"

input: f'{wd:a}/output/{name}.rds'
output: f'{wd:a}/output/{name}.{constraint}.flash.rds'

task: trunk_workers = 1, walltime = '2h', trunk_size = 1, mem = '8G', cores = 2, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output}.stderr', stdout = f'{_output}.stdout',container = container
    library(flashr)
    library(mixsqp)
    library(mashr)
    my_init_fn <- function(Y, K = 1) {
      ret = flashr:::udv_si(Y, K)
      pos_sum = sum(ret$v[ret$v > 0])
      neg_sum = -sum(ret$v[ret$v < 0])
      if (neg_sum > pos_sum) {
        return(list(u = -ret$u, d = ret$d, v = -ret$v))
      } else
      return(ret)
    }

    flash_pipeline = function(data,init_fn = "my_init_fn", ...) {
      ## Non-negative FLASH suggested by Jason Willwerscheid for when the non-negative factor assumption is reasonable
      ## cf: discussion section of
      ## https://willwerscheid.github.io/MASHvFLASH/MASHvFLASHnn2.html
      ebnm_fn = "ebnm_ash"
      ebnm_param = list(l = list(mixcompdist = "normal",
                               optmethod = "${flash_optmethod}"),
                        f = list(mixcompdist = "+uniform",
                               optmethod = "${flash_optmethod}"))
      ##
      fl_g <- flashr:::flash_greedy_workhorse(data,
                    var_type = "constant",
                    ebnm_fn = ebnm_fn,
                    ebnm_param = ebnm_param,
                    init_fn = init_fn,
                    stopping_rule = "factors",
                    tol = 1e-3,
                    verbose_output = "odF")
      fl_b <- flashr:::flash_backfit_workhorse(data,
                    f_init = fl_g,
                    var_type = "constant",
                    ebnm_fn = ebnm_fn,
                    ebnm_param = ebnm_param,
                    stopping_rule = "factors",
                    tol = 1e-3,
                    verbose_output = "odF")
      return(fl_b)
    }

    r1cov=function(x){x %*% t(x)}
    cov_from_factors = function(f, name){
       Ulist = list()
       for(i in 1:nrow(f)){
         Ulist = c(Ulist,list(r1cov(f[i,])))
       }
       names(Ulist) = paste0(name,"_",(1:nrow(f)))
       return(Ulist)
     }

    ## HS: is this function consistent with the ez model?
    cov_flash = function(data, subset = NULL, non_singleton = FALSE, save_model = NULL) {
      #if(is.null(subset)) subset = 1:mashr:::n_effects(data)
      #b.center = apply(data$Bhat[subset,], 2, function(x) x - mean(x))
      b.center = data$Bhat
      ## Only keep factors with at least two values greater than 1 / sqrt(n)
      find_nonunique_effects <- function(fl) {
        thresh <- 1/sqrt(ncol(fl$fitted_values))
        vals_above_avg <- colSums(fl$ldf$f > thresh)
        nonuniq_effects <- which(vals_above_avg > 1)
        return(fl$ldf$f[, nonuniq_effects, drop = FALSE])
      }

      fmodel = flash_pipeline(b.center)
      if (non_singleton)
          flash_f = find_nonunique_effects(fmodel)
      else 
          flash_f = fmodel$ldf$f
      ## row.names(flash_f) = colnames(b)
      if (!is.null(save_model)) saveRDS(list(model=fmodel, factors=flash_f), save_model)
      if(ncol(flash_f) == 0){
        U.flash = list("tFLASH" = t(fmodel$fitted_values) %*% fmodel$fitted_values / nrow(fmodel$fitted_values))
      } else{
        U.flash = c(cov_from_factors(t(as.matrix(flash_f)), "FLASH"),
                    list("tFLASH" = t(fmodel$fitted_values) %*% fmodel$fitted_values / nrow(fmodel$fitted_values)))
      }
      return(U.flash)
    }
    ##
    dat = readRDS("${_input}")
    constraint = "${constraint}"
    # FLASH_NC: FLASH on the strong z-scores with the udv_si initialization
    if (constraint %in% c("Both", "No_Constraint")) {
      f.d = flash_set_data(as.matrix(dat$strong.z))
      f = flash_pipeline(f.d, init_fn = "udv_si")
      res_nc = c(cov_from_factors(t(as.matrix(f$ldf$f)), "FLASH_NC"),
                 list("tFLASH_NC" = t(f$fitted_values) %*% f$fitted_values / nrow(f$fitted_values)))
    }
    # FLASH: non-singleton factors of the strong effects via cov_flash()
    if (constraint %in% c("Both", "Constraint")) {
      mash_dat = mash_set_data(dat$strong.b, dat$strong.s, alpha = 1, zero_Bhat_Shat_reset = 1E3)
      res_c = cov_flash(mash_dat, non_singleton = TRUE, save_model = "${_output}.model.rds")
    }
    res = switch(constraint,
                 Both = c(res_c, res_nc),
                 Constraint = res_c,
                 No_Constraint = res_nc,
                 stop("constraint must be 'Constraint', 'No_Constraint' or 'Both'"))
    saveRDS(res, "${_output}")

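Three pieces of the FLASH step are simple linear algebra and can be illustrated without flashr:

- Each retained factor f becomes a rank-one prior covariance f fᵀ (`r1cov` / `cov_from_factors`).
- With `non_singleton = TRUE`, only factors with more than one loading above 1/sqrt(R) are kept, where R is the number of conditions (`find_nonunique_effects`).
- `my_init_fn` flips the sign of the initial factor when its negative mass outweighs its positive mass.

The factor values below are made up. This sketch does not fit FLASH itself.

```python
import math


def rank1_cov(f):
    """Outer product f f^T, as r1cov() does in the R code."""
    return [[a * b for b in f] for a in f]


def keep_non_singleton(factors, n_conditions):
    """Keep factors with more than one loading above 1/sqrt(n_conditions)."""
    thresh = 1 / math.sqrt(n_conditions)
    return [f for f in factors if sum(v > thresh for v in f) > 1]


def orient(v):
    """Flip v if its negative mass exceeds its positive mass (my_init_fn also flips u)."""
    pos = sum(x for x in v if x > 0)
    neg = -sum(x for x in v if x < 0)
    return [-x for x in v] if neg > pos else list(v)


# Hypothetical factors over R = 4 conditions (one list per factor)
factors = [[0.7, 0.6, 0.1, 0.0],   # shared by conditions 1-2: kept
           [0.0, 0.0, 0.9, 0.1],   # effectively condition-specific: dropped
           [0.6, 0.6, 0.6, 0.6]]   # shared by all conditions: kept
kept = keep_non_singleton(factors, n_conditions=4)
U = {f"FLASH_{i + 1}": rank1_cov(f) for i, f in enumerate(kept)}
print(sorted(U))
```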
In [None]:
[mm_prior_5,mash_ed_1, udr_ed_1,teem_ed_1,ed_1]
parameter: npc = 3
input:  f'{wd:a}/output/{name}.rds', f'{wd:a}/output/{name}.{constraint}.flash.rds'
output: f'{_input[1]:n}.FL_PC{npc}.rds'

task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '8G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout",container = container
    library(mashr)
    dat = readRDS(${_input[0]:r})
    V = cor(dat$null.z)
    X = dat$strong.z
    X[is.na(X)] = 0
    mash_data = mash_set_data(dat$strong.b, Shat=dat$strong.s, alpha=1, V=V, zero_Bhat_Shat_reset = 1E3)
    # FLASH matrices
    U.flash = readRDS(${_input[1]:r})
    # SVD matrices
    U.pca = ${"cov_pca(mash_data, %s)" % npc if npc > 0 else "list()"}
    # Empirical cov matrix
    X.center = apply(X, 2, function(x) x - mean(x))
    Ulist = c(U.flash, U.pca, list("XX" = t(X.center) %*% X.center / nrow(X.center)))
    saveRDS(list(mash_data = mash_data, Ulist = Ulist,X = X, V = V), ${_output:r})

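Apart from the FLASH and PCA matrices, the inputs assembled in `mm_prior_5` are plain sample statistics:

- `V` is the correlation of the null z-scores across conditions.
- `XX` is `t(X.center) %*% X.center / nrow(X.center)`, computed on the column-centred strong z-scores after missing values are set to 0. It divides by n, not n - 1.

A dependency-free sketch on toy numbers:

```python
def center_columns(X):
    """Subtract each column's mean."""
    n = len(X)
    means = [sum(col) / n for col in zip(*X)]
    return [[x - m for x, m in zip(row, means)] for row in X]


def empirical_cov(X):
    """t(Xc) %*% Xc / n for column-centred X."""
    Xc = center_columns(X)
    n, p = len(Xc), len(Xc[0])
    return [[sum(Xc[k][i] * Xc[k][j] for k in range(n)) / n for j in range(p)]
            for i in range(p)]


def correlation(X):
    """Column correlation matrix, like R's cor(X); the 1/n factors cancel."""
    C = empirical_cov(X)
    p = len(C)
    return [[C[i][j] / (C[i][i] * C[j][j]) ** 0.5 for j in range(p)] for i in range(p)]


strong_z = [[3.0, None], [4.0, 2.0], [5.0, 6.0]]                  # None plays the role of NA
X = [[0.0 if v is None else v for v in row] for row in strong_z]  # X[is.na(X)] = 0
XX = empirical_cov(X)
null_z = [[0.5, 0.4], [-1.0, -0.8], [0.2, 0.3], [1.1, 0.9]]
V = correlation(null_z)
print(XX, V)
```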
In [None]:
[mash_ed_2]
input:  output_from('mash_ed_1')
output: f"{_input:n}.ED.rds"
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '8G', cores = 14, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout",container = container
    library(mashr)
    dat = readRDS(${_input:r})
    # Denoised data-driven matrices
    res = mashr:::bovy_wrapper(dat$mash_data, dat$Ulist, logfile=${_output:nr}, tol = 1e-06)
    # format to input for simulation with DSC (current pipeline)
    saveRDS(list(U=res$Ulist, w=res$pi, loglik=scan("${_output:nn}.ED_loglike.log")), ${_output:r}) 

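`bovy_wrapper` runs extreme deconvolution (ED). ED is an EM algorithm that refines the initial covariances in `Ulist` while accounting for each observation's known noise. The scalar, single-condition sketch below shows the zero-mean EM updates on toy data. It assumes independent noise with known variances `s2`, so it is not mashr's implementation, which works with full R x R matrices and the residual correlation `V`. The data and starting values are made up.

```python
import math


def normal_pdf(x, var):
    return math.exp(-0.5 * x * x / var) / math.sqrt(2 * math.pi * var)


def ed_1d(x, s2, U, pi, n_iter=50):
    """Scalar ED: x_j = b_j + e_j, b_j ~ sum_k pi_k N(0, U_k), e_j ~ N(0, s2_j)."""
    U, pi = list(U), list(pi)
    loglik = []
    for _ in range(n_iter):
        # E-step: responsibilities q[j][k] and log-likelihood under the current (U, pi)
        q, ll = [], 0.0
        for xj, sj in zip(x, s2):
            dens = [p * normal_pdf(xj, u + sj) for p, u in zip(pi, U)]
            total = sum(dens)
            ll += math.log(total)
            q.append([d / total for d in dens])
        loglik.append(ll)
        # M-step: each U_k is the responsibility-weighted posterior second moment of b_j
        for k in range(len(U)):
            qk = [row[k] for row in q]
            nk = sum(qk)
            second_moment = sum(
                w * ((U[k] / (U[k] + sj) * xj) ** 2 + U[k] * sj / (U[k] + sj))
                for w, xj, sj in zip(qk, x, s2))
            U[k] = second_moment / nk
            pi[k] = nk / len(x)
    return U, pi, loglik


x = [0.1, -0.3, 0.2, 2.5, -3.1, 0.05, 2.8, -0.2]
s2 = [0.2] * len(x)
U, pi, loglik = ed_1d(x, s2, U=[0.1, 4.0], pi=[0.5, 0.5])
print(U, pi, loglik[-1])
```

As with any exact EM, the recorded log-likelihood should never decrease from one iteration to the next. This makes a useful sanity check, similar in spirit to the `ED_loglike.log` file that this step reads back.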
In [None]:
[udr_ed_2]
input:  output_from('udr_ed_1')
output: f"{_input:n}.UD_ED.rds"
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '8G', cores = 14, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout",container = container
    library(udr) # udr commit 5265079 with changes to set lower bound on the eigenvalues
    dat = readRDS(${_input:r})
    # Denoised data-driven matrices
    f0 = ud_init(X = dat$X, V = dat$V, U_scaled = list(), U_unconstrained = dat$Ulist, n_rank1=0)
    res = ud_fit(f0, control = list(unconstrained.update = "ed", resid.update = 'none', maxiter=5000),
    verbose=FALSE)
    # format to input for simulation with DSC (current pipeline)
    saveRDS(list(U=res$U, w=res$w, loglik=res$loglik), ${_output:r})

In [None]:
[mm_prior_6,ed_2]
method = ["ed","teem"]
input: for_each = "method"
output: f"{_input:n}.{_method}.UD_ED.rds"
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '8G', cores = 14, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout",container = container
    library(udr) # udr commit 5265079 with changes to set lower bound on the eigenvalues
    dat = readRDS(${_input:r})
    # Denoised data-driven matrices
    f0 = ud_init(X = dat$X, V = dat$V, U_scaled = list(), U_unconstrained = dat$Ulist, n_rank1=0)
    res = ud_fit(f0, control = list(unconstrained.update = "${_method}", resid.update = 'none', maxiter=5000),
    verbose=FALSE)
    # format to input for simulation with DSC (current pipeline)
    saveRDS(list(U=res$U, w=res$w, loglik=res$loglik), ${_output:r})
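Output file names are derived step by step from `name`, `constraint` and `npc` using SoS's `:n` format (the path without its final extension). The sketch below reproduces that chain for the default `constraint = 'Both'` and `npc = 3`, taking `name` from the MWE command. Two assumptions: `strip_ext` is a stand-in for `:n`, taken to be equivalent to `os.path.splitext(path)[0]`, and `/work/prior` is a hypothetical `wd`.

```python
import os


def strip_ext(path):
    """Stand-in for SoS's ':n' format: drop the last extension."""
    return os.path.splitext(path)[0]


def output_names(wd, name, constraint="Both", npc=3, methods=("ed", "teem")):
    merged = f"{wd}/output/{name}.rds"                     # mm_prior_3
    flash = f"{wd}/output/{name}.{constraint}.flash.rds"   # mm_prior_4
    ulist = f"{strip_ext(flash)}.FL_PC{npc}.rds"           # mm_prior_5
    names = {"merged": merged, "flash": flash, "ulist": ulist,
             "mash_ed": f"{strip_ext(ulist)}.ED.rds",       # mash_ed_2
             "udr_ed": f"{strip_ext(ulist)}.UD_ED.rds"}     # udr_ed_2
    for m in methods:                                      # mm_prior_6 / ed_2
        names[m] = f"{strip_ext(ulist)}.{m}.UD_ED.rds"
    return names


for key, path in output_names("/work/prior", "geneTpmResidualsAgeGenderAdj_rename").items():
    print(key, path)
```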