# A multivariate EBNM approach for estimating a mixture of multivariate distributions

This is a truncated version of the bioworkflow/mixture_prior analysis. The steps that generate the mixture components have been removed to avoid duplicating those in the mashr_flashr_workflow.


An earlier version of the approach is outlined in Urbut et al 2019. This workflow implements a few improvements including using additional EBMF methods as well as the new `udr` package to fit the mixture model.

## Overview of approach

1. A workflow step is provided to merge PLINK univariate association analysis results to RDS files for extracting effect estimate samples
2. Estimated effects are analyzed by FLASH and PCA to extract patterns of sharing
3. Estimate the weights for the patterns extracted in the previous step

## Minimal working example

To see the input requirements and output data formats, please [download a minimal working example here](https://drive.google.com/file/d/1838xUOQuWTszQ0WJGXNiJMszY05cw3RS/view?usp=sharing), and run the following:

### Merge univariate results

```
sos run mixture_prior.ipynb merge \
    --analysis-units <FIXME> \
    --plink-sumstats <FIXME> \
    --name gtex_mixture
```

### Select and merge univariate effects

```
m=/path/to/data
cd $m && ls *.rds | sed 's/\.rds//g' > analysis_units.txt && cd -
sos run mixture_prior.ipynb extract_effects \
        --analysis-units $m/analysis_units.txt \
        --datadir $m --name `basename $m`
```

### Perform mixture model fitting

```
sos run mixture_prior.ipynb ud \
    --datadir $m --name `basename $m` &> ed_`basename $m`.log
sos run mixture_prior.ipynb ud --ud-method ted \
    --datadir $m --name `basename $m` &> ted_`basename $m`.log
sos run mixture_prior.ipynb ed_bovy \
    --datadir $m --name `basename $m` &> bovy_`basename $m`.log
```

### Plot results

```
sos run mixture_prior.ipynb plot_U --model-data /path/to/mixture_model.rds --cwd output
```

Note that for production use, each `sos run` command should be submitted to the cluster as a job.

## Global parameters

In [None]:
[global]
import os
# Work directory & output directory
parameter: cwd = path('./output')
# Directory holding the mixture component files, named `<name>.<component>.rds`
parameter: mixture_components_dir = path
# The filename prefix shared by the mixture component inputs and the output data
parameter: name = str
# RDS file with the effect estimates; it is the first input of the fitting steps
parameter: data = path
# Names of the mixture components to load from `mixture_components_dir`
parameter: mixture_components = ['flash', 'flash_nonneg', 'pca', 'canonical']
# Effect model label; part of the residual correlation file name
parameter: effect_model = "EZ"
# Residual correlation label; part of the residual correlation and output file names
parameter: vhat = "simple"
# NOTE(review): not referenced by any step visible in this notebook; presumably a job batching size -- confirm
parameter: job_size = 1
# Residual correlation (vhat) file
parameter: resid_cor = file_target(f"{mixture_components_dir:a}/{name}.{effect_model}.V_{vhat}.rds")
# Container image in which the R scripts are executed
parameter: container = 'gaow/twas'
# `resid_cor` must either exist or be set to `.`; when it is not a file the
# fitting steps fall back to computing the correlation from the null z-scores
fail_if(not (resid_cor.is_file() or resid_cor == path('.')), msg = f'Cannot find ``{resid_cor}``')

## Fit mixture model

In [None]:
# Installed commit d6d4c0e
[ud]
# Fit the mixture model with the `udr` package and save the estimated
# covariances (U), mixture weights (w) and log-likelihood trace to an RDS file.
#
# Update method for the unconstrained covariances: `ed` or `ted`
parameter: ud_method = "ed"
# A typical choice is to estimate scales only for canonical components
parameter: scale_only = []
# Tolerance for change in likelihood
parameter: ud_tol_lik = 1e-3
fail_if(ud_method not in ("ed", "ted"), msg = f'``--ud-method`` must be "ed" or "ted", got ``{ud_method}``')
input: [f'{data}'] + [f"{mixture_components_dir}/{name}.{m}.rds" for m in mixture_components]
output: f'{cwd}/{name}.ud.V_{vhat}.rds'
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '10G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout", container = container
    # The first input is the data set; every following input is one mixture component file
    rds_files = c(${_input:r,})
    dat = readRDS(rds_files[1])
    U = list(XtX = dat$XtX)
    U_scaled = list()
    mixture_components =  c(${paths(mixture_components):r,})
    scale_only =  c(${paths(scale_only):r,})
    scale_idx = which(mixture_components %in% scale_only)
    # seq_along(...)[-1] is empty when only the data file is given,
    # whereas 2:length(rds_files) would count down to 1
    for (f in seq_along(rds_files)[-1]) {
        if ((f - 1) %in% scale_idx) {
          # components listed in scale_only: only their scale is estimated
          U_scaled = c(U_scaled, readRDS(rds_files[f]))
        } else {
          U = c(U, readRDS(rds_files[f]))
        }
    }
    # Residual correlation: use the pre-computed file when it exists,
    # otherwise estimate it from the null z-scores
    if (${"TRUE" if resid_cor.is_file() else "FALSE"}) {
      V = readRDS(${resid_cor:r})
    } else {
      V = cor(dat$null.z)
    }
    # Fit mixture model using udr package
    library(udr)
    message(paste("Running ${ud_method.upper()} via udr package for", length(U), "mixture components"))
    f0 = ud_init(X = as.matrix(dat$strong.z), V = V, U_scaled = U_scaled, U_unconstrained = U, n_rank1=0)
    # A single fit using the requested update method for the unconstrained covariances
    res = ud_fit(f0, X = na.omit(f0$X), control = list(unconstrained.update = "${ud_method}", resid.update = 'none', scaled.update = "fa", maxiter=5000, tol.lik = ${ud_tol_lik}), verbose=TRUE)

    saveRDS(list(U=res$U, w=res$w, loglik=res$loglik), ${_output:r})

In [None]:
[ted]
# Fit the mixture model with the `udr` package using the `ted` update by
# default, and save the estimated covariances (U), mixture weights (w) and
# log-likelihood trace to an RDS file.
#
# Update method for the unconstrained covariances: `ed` or `ted`.
# Defaults to `ted` so that the result matches the `.ted.` output file name.
parameter: ud_method = "ted"
# A typical choice is to estimate scales only for canonical components
parameter: scale_only = []
# Tolerance for change in likelihood
parameter: ud_tol_lik = 1e-3
fail_if(ud_method not in ("ed", "ted"), msg = f'``--ud-method`` must be "ed" or "ted", got ``{ud_method}``')
input: [f'{data}'] + [f"{mixture_components_dir}/{name}.{m}.rds" for m in mixture_components]
output: f'{cwd}/{name}.ted.V_{vhat}.rds'
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '10G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout", container = container
    # The first input is the data set; every following input is one mixture component file
    rds_files = c(${_input:r,})
    dat = readRDS(rds_files[1])
    U = list(XtX = dat$XtX)
    U_scaled = list()
    mixture_components =  c(${paths(mixture_components):r,})
    scale_only =  c(${paths(scale_only):r,})
    scale_idx = which(mixture_components %in% scale_only)
    # seq_along(...)[-1] is empty when only the data file is given,
    # whereas 2:length(rds_files) would count down to 1
    for (f in seq_along(rds_files)[-1]) {
        if ((f - 1) %in% scale_idx) {
          # components listed in scale_only: only their scale is estimated
          U_scaled = c(U_scaled, readRDS(rds_files[f]))
        } else {
          U = c(U, readRDS(rds_files[f]))
        }
    }
    # Residual correlation: use the pre-computed file when it exists,
    # otherwise estimate it from the null z-scores
    if (${"TRUE" if resid_cor.is_file() else "FALSE"}) {
      V = readRDS(${resid_cor:r})
    } else {
      V = cor(dat$null.z)
    }
    # Fit mixture model using udr package
    library(udr)
    message(paste("Running ${ud_method.upper()} via udr package for", length(U), "mixture components"))
    f0 = ud_init(X = as.matrix(dat$strong.z), V = V, U_scaled = U_scaled, U_unconstrained = U, n_rank1=0)
    # A single fit using the requested update method for the unconstrained covariances
    res = ud_fit(f0, X = na.omit(f0$X), control = list(unconstrained.update = "${ud_method}", resid.update = 'none', scaled.update = "fa", maxiter=5000, tol.lik = ${ud_tol_lik}), verbose=TRUE)

    saveRDS(list(U=res$U, w=res$w, loglik=res$loglik), ${_output:r})

In [None]:
[ed_bovy]
# Fit the mixture model with the ED code by J. Bovy (via `mashr`) and save the
# estimated covariances (U), mixture weights (w) and log-likelihood trace.
#
# Tolerance passed to the ED routine as `tol`
parameter: ed_tol = 1e-6
input: [f'{data}'] + [f"{mixture_components_dir}/{name}.{m}.rds" for m in mixture_components]
output: f'{cwd}/{name}.ed_bovy.V_{vhat}.rds'
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '40G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout", container = container
    # The first input is the data set; every following input is one mixture component file
    rds_files = c(${_input:r,})
    dat = readRDS(rds_files[1])
    U = list(XtX = dat$XtX)
    # rds_files[-1] is empty when only the data file is given, whereas
    # rds_files[2:length(rds_files)] would yield c(NA, <data file>)
    for (f in rds_files[-1]) {
      U = c(U, readRDS(f))
    }
    # Residual correlation: use the pre-computed file when it exists,
    # otherwise estimate it from the null z-scores
    if (${"TRUE" if resid_cor.is_file() else "FALSE"}) {
      V = readRDS(${resid_cor:r})
    } else {
      V = cor(dat$null.z)
    }
    # Fit mixture model using ED code by J. Bovy
    mash_data = mashr::mash_set_data(dat$strong.z, V=V)
    message(paste("Running ED via J. Bovy's code for", length(U), "mixture components"))
    res = mashr:::bovy_wrapper(mash_data, U, logfile=${_output:nr}, tol = ${ed_tol})
    # The log-likelihood trace is read back from the "<logfile>_loglike.log" file
    saveRDS(list(U=res$Ulist, w=res$pi, loglik=scan("${_output:n}_loglike.log")), ${_output:r})

## Plot patterns of sharing

This is a simple utility step that takes the output of the pipeline above and makes heatmaps showing the major patterns of multivariate effects. The plots are ordered by their mixture weights.

In [None]:
[plot_U]
# Draw one heatmap per mixture component, ordered by decreasing mixture weight,
# and arrange them on a single PDF page.
#
# RDS file holding a list with elements `U` (covariances) and `w` (weights)
parameter: model_data = path
# number of components to show; zero or a negative value shows every component above `tol`
parameter: max_comp = -1
# whether or not to convert to correlation
parameter: to_cor = False
# Only components whose weight is greater than this value are plotted
parameter: tol = "1E-6"
# Replace the row and column labels by generic ones
parameter: remove_label = False
# Optional name of the element of the RDS object to plot; empty means the object itself
parameter: name = ""
input: model_data
output: f'{cwd:a}/{_input:bn}{("_" + name.replace("$", "_")) if name != "" else ""}.pdf'
R: expand = "${ }", stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout', container = container
    library(reshape2)
    library(ggplot2)
    # Heatmap of the upper triangle (diagonal included) of the square matrix X.
    #   X            matrix to display
    #   col          colour(s) of the axis labels
    #   to_cor       TRUE: show cov2cor(X); FALSE: show X divided by its largest diagonal entry
    #   title        plot title
    #   remove_names TRUE: replace the dimnames by generic labels and number the axes
    # Returns a ggplot object.
    plot_sharing = function(X, col = 'black', to_cor=FALSE, title="", remove_names=FALSE) {
        clrs <- colorRampPalette(rev(c("#D73027","#FC8D59","#FEE090","#FFFFBF",
            "#E0F3F8","#91BFDB","#4575B4")))(128)
        if (to_cor) {
            lat <- cov2cor(X)
        } else {
            lat <- X/max(diag(X))
        }
        # blank the strictly lower triangle so that each pair is drawn once
        lat[lower.tri(lat)] <- NA
        n <- nrow(lat)
        if (remove_names) {
            colnames(lat) = paste0('t', seq_len(n))
            rownames(lat) = paste0('t', seq_len(n))
        }
        # rows are reversed so that the first row ends up at the top of the plot;
        # drop = FALSE keeps a 1 x 1 input a matrix
        melted_cormat <- melt(lat[rev(seq_len(n)), , drop = FALSE], na.rm = TRUE)
        p = ggplot(data = melted_cormat, aes(Var2, Var1, fill = value))+
            geom_tile(color = "white")+ggtitle(title) + 
            scale_fill_gradientn(colors = clrs, limit = c(-1,1), space = "Lab") +
            theme_minimal()+ 
            coord_fixed() +
            theme(axis.title.x = element_blank(),
                  axis.title.y = element_blank(),
                  axis.text.x = element_text(color=col, size=8,angle=45,hjust=1),
                  axis.text.y = element_text(color=rev(col), size=8),
                  title =element_text(size=10),
                  # panel.grid.major = element_blank(),
                  panel.border = element_blank(),
                  panel.background = element_blank(),
                  axis.ticks = element_blank(),
                  legend.justification = c(1, 0),
                  legend.position = c(0.6, 0),
                  legend.direction = "horizontal")+
            guides(fill = guide_colorbar(title="", barwidth = 7, barheight = 1,
                   title.position = "top", title.hjust = 0.5))
        if (remove_names) {
            p = p + scale_x_discrete(labels = seq_len(n)) + scale_y_discrete(labels = rev(seq_len(n)))
        }
        return(p)
    }
  
    dat = readRDS(${_input:r})
    name = "${name}"
    if (name != "") {
      if (is.null(dat[[name]])) stop("Cannot find data ${name} in ${_input}")
      dat = dat[[name]]
    }
    if (is.null(names(dat$U))) names(dat$U) = paste0("Comp_", seq_along(dat$U))
    meta = data.frame(names(dat$U), dat$w, stringsAsFactors=FALSE)
    colnames(meta) = c("U", "w")
    tol = ${tol}
    n_pass = sum(meta$w > tol, na.rm = TRUE)
    message(paste(n_pass, "components out of", length(dat$w), "total components have weight greater than", tol))
    if (n_pass == 0) stop("No component has weight greater than ", tol, "; nothing to plot")
    # Sort by decreasing weight, so the first n_pass rows are exactly those above tol
    meta = meta[order(meta$w, decreasing = TRUE), ]
    # Never index past the rows kept: show at most max_comp of the components above tol
    max_comp = ${max_comp}
    n_show = if (max_comp > 0) min(n_pass, max_comp) else n_pass
    meta = head(meta, n_show)
    res = list()
    for (i in seq_len(n_show)) {
        title = paste(meta$U[i], "w =", round(meta$w[i], 6))
        u = dat$U[[meta$U[i]]]
        ##Handle updated udr data structure, where the matrix is stored in element `mat`
        if (is.list(u)) {
          mat = u$mat
        } else if (is.matrix(u)) {
          mat = u
        } else {
          # skip instead of leaving a NULL hole in the list of plots
          warning("Skipping component ", meta$U[i], ": neither a list nor a matrix", call. = FALSE)
          next
        }
        res[[length(res) + 1]] = plot_sharing(mat, to_cor = ${"TRUE" if to_cor else "FALSE"}, title=title, remove_names = ${"TRUE" if remove_label else "FALSE"})
    }
    if (length(res) == 0) stop("None of the selected components could be plotted")
    unit = 4
    n_col = 5
    n_row = ceiling(length(res) / n_col)
    pdf(${_output:r}, width = unit * n_col, height = unit * n_row)
    do.call(gridExtra::grid.arrange, c(res, list(ncol = n_col, nrow = n_row, bottom = "Data source: readRDS(${_input:br})${('$'+name) if name else ''}")))
    dev.off()