# A multivariate EBNM approach for estimating a mixture of multivariate distributions
Workflow steps to generate a mixture prior for multivariate SuSiE analysis.

## Overview of approach

1. A workflow step is provided to merge PLINK univariate association analysis results into RDS files, from which samples of effect estimates are extracted
2. The estimated effects are analyzed by FLASH and PCA to extract patterns of sharing
3. The mixture weights of the patterns extracted in the previous step are estimated

## Minimal working example

To see the input requirements and output data formats for the analysis steps, please [download a minimal working example here](https://www.synapse.org/#!Synapse:syn26071492) and run the following commands within the unzipped folder. The container used is `gaow/twas`:


nohup sos run mixture_prior.ipynb extract_effects \
        --analysis-units ./analysis_units_mwe.txt \
        --datadir ./RDS/ \
        --name 'geneTpmResidualsAgeGenderAdj_rename'  \
        --container gaow/twas \
        --cwd ./ &

nohup sos run mixture_prior.ipynb flash \
        --analysis-units ./analysis_units_mwe.txt \
        --datadir ./RDS/ \
        --name 'geneTpmResidualsAgeGenderAdj_rename'  \
        --container  gaow/twas \
        --cwd ./ &

## Global parameters

In [None]:
[global]
# Work directory & output directory
parameter: cwd = path('./output')
# The filename prefix for output data
parameter: name = str
# Names of the covariance pattern steps whose `{cwd}/{name}.{component}.rds` outputs
# are listed as inputs of the `ud` and `ed` steps. The `flash` and `flash_nonneg`
# steps also set `remove_singleton` to TRUE when 'canonical' is in this list.
parameter: mixture_components = ['flash', 'flash_nonneg', 'pca', 'canonical']
# Container passed as `container` to the R actions of the steps below
parameter: container = 'gaow/twas'
# Passed as `trunk_size` to the task of the `merge` step
parameter: job_size = 1

## Merge univariate association summary statistic to RDS format

In [None]:
[merge]
# Table of conditions; its `#molc_pheno` column is combined with the analysis unit
# name to build per-condition file names.
parameter: molecular_pheno = path
# Analysis units file. For RDS files it can be generated by `ls *.rds | sed 's/\.rds//g' > analysis_units.txt`
parameter: analysis_units = path
regions = [x.strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
input:  molecular_pheno, for_each = "regions"
output: f'{cwd:a}/RDS/{_regions[0]}'

task: trunk_workers = 1, trunk_size = job_size, walltime = '4h',  mem = '6G', tags = f'{step_name}_{_output:bn}'  

R: expand = "$[ ]", stderr = f'{_output}.stderr', stdout = f'{_output}.stdout',container = container
    # Merge the per-condition `bhat` / `sbhat` tables of one analysis unit into two
    # matrices (one column per condition) and save them as list(bhat, sbhat).
    library("dplyr")
    library("tibble")
    library("purrr")
    library("readr")
    molecular_pheno = read_delim($[molecular_pheno:r], delim = "\t")
    # Per-condition file name: the `#molc_pheno` entry followed by the analysis unit name
    molecular_pheno = molecular_pheno%>%mutate(dir = map_chr(`#molc_pheno`,~paste(c(`.x`,"$[_regions[0]]"),collapse = "")))
    n = nrow(molecular_pheno)
    if (n < 1) stop("No condition is listed in the molecular phenotype table")
    # For every condition read the RDS file once and keep its bhat and sbhat.
    # NOTE(review): files are taken from column 2 of the table, which is `dir` only
    # when the input table has a single column -- confirm against the input format.
    sumstats = map(seq_len(n), ~readRDS(molecular_pheno[[.x,2]]))
    # Turn a matrix into a data frame carrying its row names in a `rowname` key column
    keyed = function(m) m%>%as.data.frame%>%rownames_to_column
    # Full-join all conditions on `rowname`. The key column is dropped and the result
    # converted to a matrix only once, after the last join, so this works for any
    # number of conditions (including one or two).
    join_conditions = function(tables) {
        reduce(tables, full_join, by = "rowname")%>%select(-rowname)%>%as.matrix
    }
    genos_join_bhat = join_conditions(map(sumstats, ~keyed(.x$bhat)))
    genos_join_sbhat = join_conditions(map(sumstats, ~keyed(.x$sbhat)))

    # Condition name: the third-from-last "/"-separated field of `#molc_pheno`
    name = molecular_pheno%>%mutate(name = map(`#molc_pheno`, ~read.table(text = .x,sep = "/")),
                                    name = map_chr(name, ~.x[,ncol(.x)-2]%>%as.character) )%>%pull(name)
    colnames(genos_join_bhat) = name
    colnames(genos_join_sbhat) = name
    
    # save the rds file
    saveRDS(file = "$[_output]", list(bhat=genos_join_bhat, sbhat=genos_join_sbhat))

## Get top, random and null effects per analysis unit

In [None]:
# extract data for MASH from summary stats
[extract_effects_1]
# Directory holding one `<analysis unit>.rds` file per analysis unit
parameter: datadir = path
# Seed of the R random number generator used for sampling rows
parameter: seed = 999
# Maximum number of random rows drawn per analysis unit
parameter: n_random = 4
# Maximum number of null rows drawn per analysis unit
parameter: n_null = 4
# When True only z-scores are saved; otherwise bhat and sbhat are saved as well
parameter: z_only = True
# Analysis units file. For RDS files it can be generated by `ls *.rds | sed 's/\.rds//g' > analysis_units.txt`
parameter: analysis_units = path
# handle N = per_chunk data-set in one job
parameter: per_chunk = 1000
regions = [x.strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
input: [f'{datadir}/{x[0]}.rds' for x in regions], group_by = per_chunk
output: f"{cwd}/{name}/cache/{name}_{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container
    set.seed(${seed})
    # Row/column index of the largest entry of a matrix (NA entries are ignored);
    # a zero-row result means the matrix has no non-missing entry.
    matxMax <- function(mtx) {
      return(arrayInd(which.max(mtx), dim(mtx)))
    }
    remove_rownames = function(x) {
        for (name in names(x)) rownames(x[[name]]) = NULL
        return(x)
    }
    # Replace NaN effects by 0 and NaN / infinite standard errors by 1E3
    handle_nan_etc = function(x) {
      if (length(x) == 0) return(x)
      x$bhat[which(is.nan(x$bhat))] = 0
      x$sbhat[which(is.nan(x$sbhat) | is.infinite(x$sbhat))] = 1E3
      return(x)
    }
    # Draw up to `size` elements of `pool` without replacement. Indexing through
    # sample.int() avoids sample()'s special case where a single number k is
    # treated as the population 1:k.
    sample_rows = function(pool, size) {
      pool[sample.int(length(pool), min(size, length(pool)))]
    }
    # Extract the strongest row, random rows and null rows from one input file.
    # Returns NULL (with a warning) when the file cannot be used.
    extract_one_data = function(infile, n_random, n_null) {
        # If cannot read the input for some reason then we just skip it, assuming we have other enough data-sets to use.
        dat = tryCatch(readRDS(infile), error = function(e) {
            warning(paste("Skipping unreadable input file", infile, ":", conditionMessage(e)))
            return(NULL)
        })
        if (is.null(dat)) return(NULL)
        z = abs(dat$bhat/dat$sbhat)
        max_idx = matxMax(z)
        if (nrow(max_idx) == 0) {
          warning(paste("Skipping input file without any non-missing z-score:", infile))
          return(NULL)
        }
        # strong effect samples
        strong = list(bhat = dat$bhat[max_idx[1],,drop=F], sbhat = dat$sbhat[max_idx[1],,drop=F])
        # random samples excluding the top one; empty when the file has a single row
        sample_idx = setdiff(seq_len(nrow(z)), max_idx[1])
        random_idx = sample_rows(sample_idx, n_random)
        random = list(bhat = dat$bhat[random_idx,,drop=F], sbhat = dat$sbhat[random_idx,,drop=F])
        # null samples defined as |z| < 2 (rows with a missing z are never null)
        null.id = which(apply(z, 1, max) < 2)
        if (length(null.id) == 0) {
          warning(paste("Null data is empty for input file", infile))
          null = list()
        } else {
          null_idx = sample_rows(null.id, n_null)
          null = list(bhat = dat$bhat[null_idx,,drop=F], sbhat = dat$sbhat[null_idx,,drop=F])
        }
        dat = (list(random = remove_rownames(random), null = remove_rownames(null), strong = remove_rownames(strong)))
        dat$random = handle_nan_etc(dat$random)
        dat$null = handle_nan_etc(dat$null)
        dat$strong = handle_nan_etc(dat$strong)
        return(dat)
    }
    reformat_data = function(dat, z_only = TRUE) {
        # make output consistent in format with 
        # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb      
        if (is.null(dat)) return(NULL)
        res = list(random.z = dat$random$bhat/dat$random$sbhat, 
                  strong.z = dat$strong$bhat/dat$strong$sbhat,  
                  null.z = dat$null$bhat/dat$null$sbhat)
        if (!z_only) {
          res = c(res, list(random.b = dat$random$bhat,
           strong.b = dat$strong$bhat,
           null.b = dat$null$bhat,
           null.s = dat$null$sbhat,
           random.s = dat$random$sbhat,
           strong.s = dat$strong$sbhat))
      }
      return(res)
    }
    # Row-bind every element of `one_data` onto the matching element of `res`.
    # Zero-length elements are skipped: rbind() of two zero-length vectors yields a
    # 2 x 0 matrix that can no longer be combined with real data.
    merge_data = function(res, one_data) {
      if (is.null(one_data)) return(res)
      if (length(res) == 0) return(one_data)
      for (d in names(one_data)) {
        if (length(one_data[[d]]) == 0) next
        if (length(res[[d]]) == 0) {
          res[[d]] = one_data[[d]]
        } else {
          res[[d]] = rbind(res[[d]], one_data[[d]])
        }
      }
      return(res)
    }
    res = list()
    for (f in c(${_input:r,})) {
      res = merge_data(res, reformat_data(extract_one_data(f, ${n_random}, ${n_null}), ${"TRUE" if z_only else "FALSE"}))
    }
    saveRDS(res, ${_output:r})

In [None]:
[extract_effects_2]
input: group_by = "all"
output: f"{cwd}/{name}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", container = container
    # Combine the per-chunk extracts of the previous step into one data-set and add
    # the null correlation matrix and the empirical covariance of the strong z-scores.

    # Row-bind every element of `one_data` onto the matching element of `res`.
    # Zero-length elements (e.g. a chunk without null samples) are skipped: rbind()
    # of two zero-length vectors yields a 2 x 0 matrix that can no longer be
    # combined with real data.
    merge_data = function(res, one_data) {
      if (length(res) == 0) return(one_data)
      for (d in names(one_data)) {
        if (length(one_data[[d]]) == 0) next
        if (length(res[[d]]) == 0) {
          res[[d]] = one_data[[d]]
        } else {
          res[[d]] = rbind(res[[d]], one_data[[d]])
        }
      }
      return(res)
    }
    dat = list()
    for (f in c(${_input:r,})) {
      dat = merge_data(dat, readRDS(f))
    }
    if (length(dat$null.z) == 0) stop("No null samples were extracted; cannot compute the null correlation matrix")
    if (length(dat$strong.z) == 0) stop("No strong samples were extracted; cannot compute the empirical covariance")
    # compute null correlation matrix
    dat$null.cor = cor(dat$null.z)
    # compute empirical covariance XtX (missing z-scores count as zero)
    X = dat$strong.z
    X[is.na(X)] = 0
    dat$XtX = t(as.matrix(X)) %*% as.matrix(X) / nrow(X)
    saveRDS(dat, ${_output:r})

## Factor analyses

In [None]:
[flash]
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.flash.rds"
task: trunk_workers = 1, walltime = '6h', trunk_size = 1, mem = '8G', cores = 2, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout',container = container
    # Covariance patterns from FLASH (factors = "default") on the strong z-scores.
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # z-scores are analysed with standard error 1; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_input = mashr::mash_set_data(strong_z, unit_se)
    flash_cov = mashr::cov_flash(mash_input, factors="default", remove_singleton=${"TRUE" if "canonical" in mixture_components else "FALSE"}, output_model="${_output:n}.model.rds")
    saveRDS(flash_cov, ${_output:r})

In [None]:
[flash_nonneg]
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.flash_nonneg.rds"
task: trunk_workers = 1, walltime = '6h', trunk_size = 1, mem = '8G', cores = 2, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout',container = container
    # Covariance patterns from FLASH (factors = "nonneg") on the strong z-scores.
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # z-scores are analysed with standard error 1; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_input = mashr::mash_set_data(strong_z, unit_se)
    nonneg_cov = mashr::cov_flash(mash_input, factors="nonneg", remove_singleton=${"TRUE" if "canonical" in mixture_components else "FALSE"}, output_model="${_output:n}.model.rds")
    saveRDS(nonneg_cov, ${_output:r})

In [None]:
[pca]
# Number of principal components passed to mashr::cov_pca
parameter: npc = 3
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.pca.rds"
task: trunk_workers = 1, walltime = '2h', trunk_size = 1, mem = '8G', cores = 2, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout',container = container
    # Covariance patterns from PCA on the strong z-scores.
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # z-scores are analysed with standard error 1; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_input = mashr::mash_set_data(strong_z, unit_se)
    pca_cov = mashr::cov_pca(mash_input, ${npc})
    saveRDS(pca_cov, ${_output:r})

In [None]:
[canonical]
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.canonical.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '8G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', container = container
    # Canonical covariance matrices for the conditions of the strong z-scores.
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # z-scores are analysed with standard error 1; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_input = mashr::mash_set_data(strong_z, unit_se)
    canonical_cov = mashr::cov_canonical(mash_input)
    saveRDS(canonical_cov, ${_output:r})

## Fit mixture model

In [None]:
# Installed commit d6d4c0e
[ud]
# Method is `ed` or `teem`
parameter: ud_method = "ed"
input: [f"{cwd}/{name}.rds"] + [f"{cwd}/{name}.{m}.rds" for m in mixture_components]
output: f"{cwd}/{name}.{ud_method}.rds"
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '10G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout",container = container
    # Fit the mixture weights and covariances with the udr package, starting from
    # the empirical covariance XtX plus every covariance list given as input.
    rds_files = c(${_input:r,})
    # The first input is the extracted data; the remaining ones are covariance lists
    dat = readRDS(rds_files[1])
    U = list(XtX = dat$XtX)
    # rds_files[-1] is empty when no mixture component is requested, whereas
    # 2:length(rds_files) would count down to c(2, 1) and read the wrong files
    for (f in rds_files[-1]) U = c(U, readRDS(f))
    # Fit mixture model using udr package
    library(udr)
    message(paste("Running ${ud_method.upper()} via udr package for", length(U), "mixture components"))
    f0 = ud_init(X = as.matrix(dat$strong.z), V = dat$null.cor, U_scaled = list(), U_unconstrained = U, n_rank1=0)
    res = ud_fit(f0,X = na.omit(f0$X), control = list(unconstrained.update = "${ud_method}", resid.update = 'none', maxiter=5000, tol = 1e-06), verbose=TRUE)
    saveRDS(list(U=res$U, w=res$w, loglik=res$loglik), ${_output:r})

In [None]:
[ed]
input: [f"{cwd}/{name}.rds"] + [f"{cwd}/{name}.{m}.rds" for m in mixture_components]
output: f"{cwd}/{name}.ed_bovy.rds"
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '10G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout",container = container
    # Fit the mixture weights and covariances with the extreme deconvolution code
    # wrapped by mashr, starting from XtX plus every covariance list given as input.
    rds_files = c(${_input:r,})
    # The first input is the extracted data; the remaining ones are covariance lists
    dat = readRDS(rds_files[1])
    U = list(XtX = dat$XtX)
    # rds_files[-1] is empty when no mixture component is requested, whereas
    # 2:length(rds_files) would count down to c(2, 1) and read the wrong files
    for (f in rds_files[-1]) U = c(U, readRDS(f))
    # Fit mixture model using ED code by J. Bovy
    mash_data = mashr::mash_set_data(dat$strong.z, V=dat$null.cor)
    message(paste("Running ED via J. Bovy's code for", length(U), "mixture components"))
    res = mashr:::bovy_wrapper(mash_data, U, logfile=${_output:nr}, tol = 1e-06)
    # The log-likelihood trace is read back from the log file next to the output
    saveRDS(list(U=res$Ulist, w=res$pi, loglik=scan("${_output:n}_loglike.log")), ${_output:r})

## Plot patterns of sharing

This is a simple utility step that takes the output from the pipeline above and makes heatmaps to show the major patterns of multivariate effects. The plots are ordered by their mixture weights.

In [None]:
[plot_U]
# RDS file holding a list with covariance matrices `U` and mixture weights `w`
parameter: model_data = path
# number of components to show (values below 1 show every component above `tol`)
parameter: max_comp = -1
# whether or not to convert to correlation
parameter: to_cor = False
# only components with weight greater than this value are plotted
parameter: tol = "0.05"
# whether or not to drop row / column labels from the heatmaps
parameter: remove_label = False
input: model_data
output: f'{cwd:a}/{_input:bn}{("_" + name.replace("$", "_")) if name != "" else ""}.pdf'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout',container = container
    # Heatmap of the upper triangle of one covariance matrix, either converted to a
    # correlation matrix or scaled by its largest diagonal entry.
    plot_sharing = function(X, to_cor=FALSE, title="", remove_names=FALSE) {
        clrs <- colorRampPalette(rev(c("#D73027","#FC8D59","#FEE090","#FFFFBF",
                                       "#E0F3F8","#91BFDB","#4575B4")))(128)
        if (to_cor) lat <- cov2cor(X)
        else lat = X/max(diag(X))
        lat[lower.tri(lat)] <- NA
        n <- nrow(lat)
        if (remove_names) {
          colnames(lat) = NULL
          rownames(lat) = NULL
        }
        return(lattice::levelplot(lat[n:1,],col.regions = clrs,
                                xlab = "",ylab = "", main= list(label= title,side=1,line=0.5, cex= 2),
                                colorkey = TRUE,at = seq(-1,1,length.out = 128),
                                scales = list(cex = 1.5 )))
    }
  
    dat = readRDS(${_input:r})
    name = "${name}"
    if (name != "") {
      if (is.null(dat[[name]])) stop("Cannot find data ${name} in ${_input}")
      dat = dat[[name]]
    }
    if (is.null(names(dat$U))) names(dat$U) = paste0("Comp_", seq_along(dat$U))
    meta = data.frame(names(dat$U), dat$w, stringsAsFactors=FALSE)
    colnames(meta) = c("U", "w")
    tol = ${tol}
    n_pass = sum(dat$w > tol, na.rm = TRUE)
    message(paste(n_pass, "components out of", length(dat$w), "total components have weight greater than", tol))
    # Sort by decreasing weight and keep at most `max_comp` components
    meta = head(meta[order(meta[,2], decreasing = TRUE),], ${max_comp if max_comp > 0 else "nrow(meta)"})
    # Plot the components above `tol`, but never more than the rows kept in `meta`
    n_comp = min(n_pass, nrow(meta))
    if (n_comp == 0) stop(paste("No component has weight greater than", tol, "- nothing to plot"))
    res = list()
    for (i in seq_len(n_comp)) {
        title = paste("Factor", i , "w =", round(meta$w[i], 6))
        res[[i]] = plot_sharing(dat$U[[meta$U[i]]], to_cor = ${"TRUE" if to_cor else "FALSE"}, title=title, remove_names = ${"TRUE" if remove_label else "FALSE"})
    }
    # Page layout: two plots per row, 5 inches per plot
    unit = 5
    n_col = 2
    n_row = ceiling(n_comp / n_col)
    pdf(${_output:r}, width = unit * n_col, height = unit * n_row)
    do.call(gridExtra::grid.arrange, c(res, list(ncol = n_col, nrow = n_row, bottom = "Data source: readRDS(${_input:br})${('$'+name) if name else ''}")))
    dev.off()