# A multivariate EBNM approach for mixture multivariate distribution estimation

An earlier version of the approach is outlined in Urbut et al. (2019). This workflow implements a few improvements, including additional EBMF methods and the use of the new `udr` package to fit the mixture model.

## Overview of approach

1. A workflow step is provided to merge PLINK univariate association analysis results to RDS files for extracting effect estimate samples
2. Estimated effects are analyzed by FLASH and PCA to extract patterns of sharing
3. Estimate the weights for the patterns extracted in the previous step

## Minimal working example

To see the input requirements and output data formats, please [download a minimal working example here](https://drive.google.com/file/d/1838xUOQuWTszQ0WJGXNiJMszY05cw3RS/view?usp=sharing), and run the following:

### Merge univariate results

```
sos run mixture_prior.ipynb merge \
    --analysis-units <FIXME> \
    --molecular-pheno <FIXME> \
    --name gtex_mixture
```

### Select and merge univariate effects

```
m=/path/to/data
cd $m && ls *.rds | sed 's/\.rds//g' > analysis_units.txt && cd -
sos run mixture_prior.ipynb extract_effects \
        --analysis-units $m/analysis_units.txt \
        --datadir $m --name `basename $m`
```

### Perform mixture model fitting

```
sos run mixture_prior.ipynb ud \
    --datadir $m --name `basename $m` &> ed_`basename $m`.log
sos run mixture_prior.ipynb ud --ud-method ted \
    --datadir $m --name `basename $m` &> ted_`basename $m`.log
sos run mixture_prior.ipynb ed \
    --datadir $m --name `basename $m` &> bovy_`basename $m`.log
```

### Plot results

```
sos run mixture_prior.ipynb plot_U --model-data /path/to/mixture_model.rds --cwd output
```

Note that for production use, each `sos run` command should be submitted to the cluster as a job.

## Global parameters

In [None]:
[global]
import os
# Work directory & output directory
parameter: cwd = path('./output')
# The filename prefix for output data
parameter: name = str
# Sources of covariance patterns for the mixture; the fitting steps read `{cwd}/{name}.<component>.rds` for each entry
parameter: mixture_components = ['flash', 'flash_nonneg', 'pca', 'canonical']
# Number of substeps grouped into one task (used as `trunk_size` by the merge step)
parameter: job_size = 1
# Residual correlation file. When left at the default ".", the fitting steps use cor(null.z) instead
parameter: resid_cor = path(".")
# Stop early if a residual correlation file was requested but does not exist
fail_if(not (resid_cor.is_file() or resid_cor == path('.')), msg = f'Cannot find ``{resid_cor}``')

## Merge PLINK univariate association summary statistic to RDS format

In [None]:
[merge]
# Tab-delimited table with a `#molc_pheno` column; each entry is concatenated with the analysis unit name to form the path of a per-condition RDS file
parameter: molecular_pheno = path
# Analysis units file. For RDS files it can be generated by `ls *.rds | sed 's/\.rds//g' > analysis_units.txt`
parameter: analysis_units = path
regions = [x.strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
input:  molecular_pheno, for_each = "regions"
output: f'{cwd:a}/RDS/{_regions[0]}'

task: trunk_workers = 1, trunk_size = job_size, walltime = '4h',  mem = '6G', tags = f'{step_name}_{_output:bn}'

R: expand = "$[ ]", stderr = f'{_output}.stderr', stdout = f'{_output}.stdout'
    library("dplyr")
    library("tibble")
    library("purrr")
    library("readr")
    molecular_pheno = read_delim($[molecular_pheno:r], delim = "\t")
    # Path of each condition's RDS file: the `#molc_pheno` entry with the analysis unit name appended
    molecular_pheno = molecular_pheno%>%mutate(dir = map_chr(`#molc_pheno`,~paste(c(`.x`,"$[_regions[0]]"),collapse = "")))
    n = nrow(molecular_pheno)
    if (n == 0) stop("The molecular phenotype table is empty; there is nothing to merge")
    # For every condition read the RDS file once and extract bhat and sbhat,
    # moving the row names into a "rowname" column so that conditions can be joined on it.
    # NOTE(review): column 2 is presumably the `dir` column added above (i.e. the input table has a single column) -- confirm.
    per_condition = map(seq_len(n), function(i) {
        dat = readRDS(molecular_pheno[[i,2]])
        list(bhat = dat$bhat%>%as.data.frame%>%rownames_to_column,
             sbhat = dat$sbhat%>%as.data.frame%>%rownames_to_column)
    })

    # Outer-join all conditions on "rowname". Reduce() handles any number of conditions
    # (one, two or more) uniformly. The key column is dropped and the result converted to a
    # matrix only once, after the last join: doing it inside the join loop breaks the next
    # join, and skipping it for two conditions leaves an extra column.
    join_conditions = function(tables) {
        joined = Reduce(function(x, y) full_join(x, y, by = "rowname"), tables)
        joined%>%select(-rowname)%>%as.matrix
    }
    genos_join_bhat = join_conditions(map(per_condition, "bhat"))
    genos_join_sbhat = join_conditions(map(per_condition, "sbhat"))

    # Condition names are taken from the third-from-last "/"-separated field of each `#molc_pheno` entry
    name = molecular_pheno%>%mutate(name = map(`#molc_pheno`, ~read.table(text = .x,sep = "/")),
                                    name = map_chr(name, ~.x[,ncol(.x)-2]%>%as.character) )%>%pull(name)
    colnames(genos_join_bhat) = name
    colnames(genos_join_sbhat) = name


    # save the rds file
    saveRDS(file = "$[_output]", list(bhat=genos_join_bhat, sbhat=genos_join_sbhat))

## Get top, random and null effects per analysis unit

In [None]:
# extract data for MASH from summary stats
[extract_effects_1]
# If non-empty, the element of this name inside each RDS file is used as the data
parameter: table_name = ""
# Names of the elements whose ratio (bhat / sbhat) gives the z-scores
parameter: bhat = "bhat"
parameter: sbhat = "sbhat"
# If positive, files whose bhat/sbhat column count differs from this number are skipped
parameter: expected_ncondition = 0
parameter: datadir = path
# Seed for the random / null row sampling
parameter: seed = 999
# Maximum number of random rows and of null rows sampled per analysis unit
parameter: n_random = 4
parameter: n_null = 4
# When True only z-scores are saved; otherwise bhat and sbhat are saved as well
parameter: z_only = True
# Analysis units file. For RDS files it can be generated by `ls *.rds | sed 's/\.rds//g' > analysis_units.txt`
parameter: analysis_units = path
# handle N = per_chunk data-set in one job
parameter: per_chunk = 1000
regions = [x.strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
input: [f'{datadir}/{x[0]}.rds' for x in regions], group_by = per_chunk
output: f"{cwd}/{name}/cache/{name}_{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    set.seed(${seed})
    # (row, column) position of the largest entry of a matrix; NA entries are ignored by which.max
    matxMax <- function(mtx) {
      return(arrayInd(which.max(mtx), dim(mtx)))
    }
    # Drop the row names of every element of a list of matrices
    remove_rownames = function(x) {
        for (name in names(x)) rownames(x[[name]]) = NULL
        return(x)
    }
    # Replace NaN effects by 0 and NaN / infinite standard errors by a large value (1E3)
    handle_nan_etc = function(x) {
      x$bhat[which(is.nan(x$bhat))] = 0
      x$sbhat[which(is.nan(x$sbhat) | is.infinite(x$sbhat))] = 1E3
      return(x)
    }
    # Draw up to `size` elements of `x` without replacement.
    # Indexing through sample.int() avoids base::sample() treating a single number x as 1:x,
    # which would pick rows that are not in `x` at all. For length(x) > 1 the draw is the
    # same as sample(x, size) under the same seed.
    sample_from = function(x, size) {
        x[sample.int(length(x), min(size, length(x)), replace = FALSE)]
    }
    # For one analysis unit, pick the strongest row (largest |z|), up to n_random other rows,
    # and up to n_null rows whose |z| is below 2 in every condition.
    # Returns list(random, null, strong), each with bhat and sbhat matrices, or NULL if nothing can be extracted.
    extract_one_data = function(dat, n_random, n_null, infile) {
        if (is.null(dat)) return(NULL)
        z = abs(dat$${bhat}/dat$${sbhat})
        max_idx = matxMax(z)
        if (length(max_idx) == 0) {
            # every z-score is missing, so there is no strongest row to anchor on
            warning(paste("No non-missing z-score found for input file", infile))
            return(NULL)
        }
        # strong effect samples
        strong = list(bhat = dat$${bhat}[max_idx[1],,drop=FALSE], sbhat = dat$${sbhat}[max_idx[1],,drop=FALSE])
        # random samples excluding the top one (empty when the top row is the only row)
        sample_idx = setdiff(seq_len(nrow(z)), max_idx[1])
        random_idx = sample_from(sample_idx, n_random)
        random = list(bhat = dat$${bhat}[random_idx,,drop=FALSE], sbhat = dat$${sbhat}[random_idx,,drop=FALSE])
        # null samples defined as |z| < 2 (z already holds absolute values)
        null.id = which(apply(z, 1, max) < 2)
        if (length(null.id) == 0) {
          warning(paste("Null data is empty for input file", infile))
          null = list()
        } else {
          null_idx = sample_from(null.id, n_null)
          null = list(bhat = dat$${bhat}[null_idx,,drop=FALSE], sbhat = dat$${sbhat}[null_idx,,drop=FALSE])
        }
        dat = (list(random = remove_rownames(random), null = remove_rownames(null), strong = remove_rownames(strong)))
        dat$random = handle_nan_etc(dat$random)
        dat$null = handle_nan_etc(dat$null)
        dat$strong = handle_nan_etc(dat$strong)
        return(dat)
    }
    # Turn the random / null / strong lists into z-score matrices (plus bhat and sbhat when z_only is FALSE)
    reformat_data = function(dat, z_only = TRUE) {
        # make output consistent in format with 
        # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb      
        res = list(random.z = dat$random$bhat/dat$random$sbhat, 
                  strong.z = dat$strong$bhat/dat$strong$sbhat,  
                  null.z = dat$null$bhat/dat$null$sbhat)
        if (!z_only) {
          res = c(res, list(random.b = dat$random$bhat,
           strong.b = dat$strong$bhat,
           null.b = dat$null$bhat,
           null.s = dat$null$sbhat,
           random.s = dat$random$sbhat,
           strong.s = dat$strong$sbhat))
      }
      return(res)
    }
    # Row-bind the matrices of one analysis unit onto the accumulated result, skipping missing pieces
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else if (is.null(one_data)) {
          return(res)
      } else {
          for (d in names(one_data)) {
            if (is.null(one_data[[d]])) {
              next
            } else {
                res[[d]] = rbind(res[[d]], one_data[[d]])
            }
          }
          return(res)
      }
    }
    res = list()
    for (f in c(${_input:r,})) {
      # If cannot read the input for some reason then we just skip it, assuming we have other enough data-sets to use.
      dat = tryCatch(readRDS(f), error = function(e) return(NULL))${("$"+table_name) if table_name != "" else ""}
      if (is.null(dat)) {
          message(paste("Skip loading file", f, "due to load failure."))
          next
      }
      if (${expected_ncondition} > 0 && (ncol(dat$${bhat}) != ${expected_ncondition} || ncol(dat$${sbhat}) != ${expected_ncondition})) {
          message(paste("Skip loading file", f, "because it has", ncol(dat$${bhat}), "columns different from required", ${expected_ncondition}))
          next
      }
      res = merge_data(res, reformat_data(extract_one_data(dat, ${n_random}, ${n_null}, f), ${"TRUE" if z_only else "FALSE"}))
    }
    saveRDS(res, ${_output:r})

In [None]:
[extract_effects_2]
input: group_by = "all"
output: f"{cwd}/{name}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    # Row-bind every named element of `chunk` onto the matching element of `acc`;
    # the first chunk is taken as is.
    stack_chunk = function(acc, chunk) {
      if (length(acc) == 0) return(chunk)
      for (key in names(chunk)) {
        acc[[key]] = rbind(acc[[key]], chunk[[key]])
      }
      acc
    }
    # Fold all per-chunk cache files into a single list of matrices
    dat = Reduce(function(acc, f) stack_chunk(acc, readRDS(f)), c(${_input:r,}), list())
    # compute empirical covariance XtX of the strong z-scores, with missing values set to zero
    strong_z = replace(dat$strong.z, is.na(dat$strong.z), 0)
    strong_mat = as.matrix(strong_z)
    dat$XtX = t(strong_mat) %*% strong_mat / nrow(strong_z)
    saveRDS(dat, ${_output:r})

## Factor analyses

In [None]:
[flash]
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.flash.rds"
task: trunk_workers = 1, walltime = '6h', trunk_size = 1, mem = '8G', cores = 2, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout'
    # Covariance matrices from mashr::cov_flash (default factors) on the strong z-scores
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # Standard errors are 1 wherever a z-score is observed; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_dat = mashr::mash_set_data(strong_z, unit_se)
    # remove_singleton is switched on only when "canonical" is among the mixture components
    cov_list = mashr::cov_flash(mash_dat, factors="default", remove_singleton=${"TRUE" if "canonical" in mixture_components else "FALSE"}, output_model="${_output:n}.model.rds")
    saveRDS(cov_list, ${_output:r})

In [None]:
[flash_nonneg]
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.flash_nonneg.rds"
task: trunk_workers = 1, walltime = '6h', trunk_size = 1, mem = '8G', cores = 2, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout'
    # Covariance matrices from mashr::cov_flash with factors = "nonneg" on the strong z-scores
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # Standard errors are 1 wherever a z-score is observed; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_dat = mashr::mash_set_data(strong_z, unit_se)
    # remove_singleton is switched on only when "canonical" is among the mixture components
    cov_list = mashr::cov_flash(mash_dat, factors="nonneg", remove_singleton=${"TRUE" if "canonical" in mixture_components else "FALSE"}, output_model="${_output:n}.model.rds")
    saveRDS(cov_list, ${_output:r})

In [None]:
[pca]
# Number of principal components passed to mashr::cov_pca
parameter: npc = 3
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.pca.rds"
task: trunk_workers = 1, walltime = '2h', trunk_size = 1, mem = '8G', cores = 2, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout'
    # Covariance matrices from mashr::cov_pca on the strong z-scores
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # Standard errors are 1 wherever a z-score is observed; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_dat = mashr::mash_set_data(strong_z, unit_se)
    cov_list = mashr::cov_pca(mash_dat, ${npc})
    saveRDS(cov_list, ${_output:r})

In [None]:
[canonical]
input: f"{cwd}/{name}.rds"
output: f"{cwd}/{name}.canonical.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '8G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout'
    # Covariance matrices from mashr::cov_canonical, sized by the strong z-score data
    library("mashr")
    strong_z = readRDS(${_input:r})$strong.z
    # Standard errors are 1 wherever a z-score is observed; missing entries stay missing
    unit_se = replace(strong_z, !is.na(strong_z), 1)
    mash_dat = mashr::mash_set_data(strong_z, unit_se)
    cov_list = mashr::cov_canonical(mash_dat)
    saveRDS(cov_list, ${_output:r})

## Fit mixture model

In [None]:
# Installed commit d6d4c0e
[ud]
# Method is `ed` or `ted`
parameter: ud_method = "ed"
# A list of models where we only update the scales and not the matrices
# A typical choice is to estimate scales only for canonical components
parameter: scale_only = []
# Tolerance for change in likelihood
parameter: ud_tol_lik = 1e-3
input: [f"{cwd}/{name}.rds"] + [f"{cwd}/{name}.{m}.rds" for m in mixture_components]
output: f'{cwd}/{name}.{ud_method}{"_unconstrained" if len(scale_only) == 0 else ""}{("." + os.path.basename(resid_cor)[:-4]) if resid_cor.is_file() else ""}.rds'
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '10G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout"
    library(stringr)
    # rds_files[1] is the extracted data; rds_files[k + 1] holds the matrices of mixture_components[k]
    rds_files = c(${_input:r,})
    dat = readRDS(rds_files[1])
    # The empirical covariance of the strong z-scores is always an unconstrained component
    U = list(XtX = dat$XtX)
    U_scaled = list()
    mixture_components =  c(${paths(mixture_components):r,})
    scale_only =  c(${paths(scale_only):r,})
    scale_idx = which(mixture_components %in% scale_only )
    # seq_along(...)[-1] is empty when no mixture component file is given,
    # whereas 2:length(...) would count backwards and try to read a missing file
    for (f in seq_along(rds_files)[-1]) {
        if ((f - 1) %in% scale_idx ) {
          U_scaled = c(U_scaled, readRDS(rds_files[f]))
        } else {
          U = c(U, readRDS(rds_files[f]))
        }
    }
    # Residual correlation: user-provided file if given, otherwise estimated from the null z-scores
    if (${"TRUE" if resid_cor.is_file() else "FALSE"}) {
      V = readRDS(${resid_cor:r})
    } else {
      V = cor(dat$null.z)
    }
    # Fit mixture model using udr package
    library(udr)
    # Report both unconstrained and scale-only components
    message(paste("Running ${ud_method.upper()} via udr package for", length(U) + length(U_scaled), "mixture components"))
    f0 = ud_init(X = as.matrix(dat$strong.z), V = V, U_scaled = U_scaled, U_unconstrained = U, n_rank1=0)
    res = ud_fit(f0, X = na.omit(f0$X), control = list(unconstrained.update = "${ud_method}", resid.update = 'none', scaled.update = "fa", maxiter=5000, tol.lik = ${ud_tol_lik}), verbose=TRUE)
    saveRDS(list(U=res$U, w=res$w, loglik=res$loglik), ${_output:r})

In [None]:
[ed]
# Convergence tolerance passed to the ED routine
parameter: ed_tol = 1e-6
input: [f"{cwd}/{name}.rds"] + [f"{cwd}/{name}.{m}.rds" for m in mixture_components]
output: f'{cwd}/{name}.ed_bovy{("." + os.path.basename(resid_cor)[:-4]) if resid_cor.is_file() else ""}.rds'
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '10G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout"
    # rds_files[1] is the extracted data; the remaining files hold the candidate covariance matrices
    rds_files = c(${_input:r,})
    dat = readRDS(rds_files[1])
    U = list(XtX = dat$XtX)
    # rds_files[-1] is empty when no mixture component file is given,
    # whereas rds_files[2:length(rds_files)] would then index a missing file
    for (f in rds_files[-1]) U = c(U, readRDS(f))
    # Residual correlation: user-provided file if given, otherwise estimated from the null z-scores
    if (${"TRUE" if resid_cor.is_file() else "FALSE"}) {
      V = readRDS(${resid_cor:r})
    } else {
      V = cor(dat$null.z)
    }
    # Fit mixture model using ED code by J. Bovy
    mash_data = mashr::mash_set_data(dat$strong.z, V=V)
    message(paste("Running ED via J. Bovy's code for", length(U), "mixture components"))
    res = mashr:::bovy_wrapper(mash_data, U, logfile=${_output:nr}, tol = ${ed_tol})
    # The log-likelihood trace is read back from the "<output prefix>_loglike.log" file
    saveRDS(list(U=res$Ulist, w=res$pi, loglik=scan("${_output:n}_loglike.log")), ${_output:r})

## Plot patterns of sharing

This is a simple utility function that takes the output from the pipeline above and makes heatmaps to show the major patterns of multivariate effects. The plots are ordered by their mixture weights.

In [None]:
[plot_U]
# RDS file with elements `U` (list of matrices) and `w` (mixture weights)
parameter: model_data = path
# number of components to show; zero or negative means no limit
parameter: max_comp = -1
# whether or not to convert to correlation
parameter: to_cor = False
# Only components with weight greater than this value are plotted
parameter: tol = "1E-6"
# Replace row / column labels by running numbers
parameter: remove_label = False
# If non-empty, the element of this name inside the RDS file is used as the model data
parameter: name = ""
input: model_data
output: f'{cwd:a}/{_input:bn}{("_" + name.replace("$", "_")) if name != "" else ""}.pdf'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout'
    library(reshape2)
    library(ggplot2)
    # Heatmap of the upper triangle of a square matrix X.
    #   X            matrix to plot; shown as a correlation matrix when to_cor is TRUE,
    #                otherwise divided by its largest diagonal entry
    #   col          colour(s) of the axis labels
    #   title        plot title
    #   remove_names replace the row / column names by running numbers
    # Returns a ggplot object.
    plot_sharing = function(X, col = 'black', to_cor=FALSE, title="", remove_names=FALSE) {
        clrs <- colorRampPalette(rev(c("#D73027","#FC8D59","#FEE090","#FFFFBF",
            "#E0F3F8","#91BFDB","#4575B4")))(128)
        if (to_cor) lat <- cov2cor(X)
        else lat = X/max(diag(X))
        # keep only the upper triangle (including the diagonal)
        lat[lower.tri(lat)] <- NA
        n <- nrow(lat)
        if (remove_names) {
            colnames(lat) = paste('t',1:n, sep = '')
            rownames(lat) = paste('t',1:n, sep = '')
        }
        # rows are reversed so that the first row is drawn at the top
        melted_cormat <- melt(lat[n:1,], na.rm = TRUE)
        p = ggplot(data = melted_cormat, aes(Var2, Var1, fill = value))+
            geom_tile(color = "white")+ggtitle(title) + 
            scale_fill_gradientn(colors = clrs, limit = c(-1,1), space = "Lab") +
            theme_minimal()+ 
            coord_fixed() +
            theme(axis.title.x = element_blank(),
                  axis.title.y = element_blank(),
                  axis.text.x = element_text(color=col, size=8,angle=45,hjust=1),
                  axis.text.y = element_text(color=rev(col), size=8),
                  title =element_text(size=10),
                  # panel.grid.major = element_blank(),
                  panel.border = element_blank(),
                  panel.background = element_blank(),
                  axis.ticks = element_blank(),
                  legend.justification = c(1, 0),
                  legend.position = c(0.6, 0),
                  legend.direction = "horizontal")+
            guides(fill = guide_colorbar(title="", barwidth = 7, barheight = 1,
                   title.position = "top", title.hjust = 0.5))
        if(remove_names){
            p = p + scale_x_discrete(labels= 1:n) + scale_y_discrete(labels= n:1)
        }
        return(p)
    }
  
    dat = readRDS(${_input:r})
    name = "${name}"
    if (name != "") {
      if (is.null(dat[[name]])) stop("Cannot find data ${name} in ${_input}")
      dat = dat[[name]]
    }
    if (is.null(names(dat$U))) names(dat$U) = paste0("Comp_", seq_along(dat$U))
    meta = data.frame(names(dat$U), dat$w, stringsAsFactors=FALSE)
    colnames(meta) = c("U", "w")
    tol = ${tol}
    # number of components whose weight exceeds the tolerance
    n_pass = length(which(meta$w > tol))
    # order by decreasing weight and keep at most max_comp rows (all rows when max_comp is not positive)
    meta = head(meta[order(meta[,2], decreasing = TRUE),], ${max_comp if max_comp > 0 else "nrow(meta)"})
    message(paste(n_pass, "components out of", length(dat$w), "total components have weight greater than", tol))
    # never plot more components than were kept after applying max_comp
    n_comp = min(n_pass, nrow(meta))
    if (n_comp == 0) stop(paste("No component has weight greater than", tol, "- nothing to plot"))
    res = list()
    for (i in seq_len(n_comp)) {
        title = paste(meta$U[i], "w =", round(meta$w[i], 6))
        comp = dat$U[[meta$U[i]]]
        ##Handle updated udr data structure
        if(is.list(comp)){
          res[[i]] = plot_sharing(comp$mat, to_cor = ${"TRUE" if to_cor else "FALSE"}, title=title, remove_names = ${"TRUE" if remove_label else "FALSE"})
        } else if(is.matrix(comp)){
          res[[i]] = plot_sharing(comp, to_cor = ${"TRUE" if to_cor else "FALSE"}, title=title, remove_names = ${"TRUE" if remove_label else "FALSE"})
        }
    }
    # page layout: 5 panels per row, 4 inches per panel
    unit = 4
    n_col = 5
    n_row = ceiling(n_comp / n_col)
    pdf(${_output:r}, width = unit * n_col, height = unit * n_row)
    do.call(gridExtra::grid.arrange, c(res, list(ncol = n_col, nrow = n_row, bottom = "Data source: readRDS(${_input:br})${('$'+name) if name else ''}")))
    dev.off()