# Multivariate prediction workflow

This notebook fits `mr.mash` models to multi-response data and uses them to predict held-out samples.

## Input

Each input is an RDS file containing a list of objects, in which case you can specify the names of the objects corresponding to the quantities `X`, `Y`, etc. Optionally, univariate summary statistics can be provided to compute the scaling factors of the prior (computing summary statistics on the fly when they are not provided is not yet implemented; in that case a prior grid file must be supplied).

**FIXME: need to document the input data structure**
**Also: should prior files be stored in the same way as the data, with one `fold_??` table per fold?**

## Output

For each analysis unit we output:

1. Analysis results in RDS format
2. Default visualization plots

**FIXME: at this point we don't have any output figures yet**

## Analysis examples

```
sos run /project/mstephens/fmorgante/bioworkflows/multivariate-prediction/mrmash.ipynb mr_mash \
    --analysis-units ../data/gtex-v8-manifest-2ormore-tissues-nopath-nosuffix-test.txt \
    --data-dir ../data/cis_eqtl_analysis_ready  \
    --data-suffix GTEx_V8.rds \
    --name fold_1 \
    --wd ../output/gtex_mr_mash_analysis \
    --prior-grid ../output/gtex_mr_mash_analysis/grid/fold_1_grid.rds \
    --prior-matrices ../output/gtex_mr_mash_analysis/data_driven_matrices/output/fold_1.ted_unconstrained.rds \
    --sample-partition ../data/gtex-v8-ids-folds.txt \
    --fold 1 \
    -c midway2.yml -q midway2

sos run /project/mstephens/fmorgante/bioworkflows/multivariate-prediction/mrmash.ipynb joint_weights_update \
    --data-dir ../output/gtex_mr_mash_analysis/prediction/fold_1  \
    --data-suffix GTEx_V8_fold_1_mrmash.first_pass.rds \
    --name fold_1 \
    --wd ../output/gtex_mr_mash_analysis \
    -c midway2.yml -q midway2

sos run /project/mstephens/fmorgante/bioworkflows/multivariate-prediction/mrmash.ipynb mr_mash \
    --analysis-units ../data/gtex-v8-manifest-2ormore-tissues-nopath-nosuffix-test.txt \
    --data-dir ../data/cis_eqtl_analysis_ready  \
    --data-suffix GTEx_V8.rds \
    --name fold_1 \
    --wd ../output/gtex_mr_mash_analysis \
    --prior-grid ../output/gtex_mr_mash_analysis/grid/fold_1_grid.rds \
    --prior-matrices ../output/gtex_mr_mash_analysis/data_driven_matrices/output/fold_1.ted_unconstrained.rds \
    --prior-weights ../output/gtex_mr_mash_analysis/weights/fold_1_updated_weights.rds \
    --sample-partition ../data/gtex-v8-ids-folds.txt \
    --fold 1 \
    --save-model TRUE \
    -c midway2.yml -q midway2
```

In [1]:
[global]
# Path to the data directory that holds the input files of a step
parameter: data_dir = path
# Suffix shared by the data files in `data_dir` (everything after the first `.` that follows the unit name)
parameter: data_suffix = str
# Path to the work directory under which output files are written
parameter: wd = path("./output")
# An identifier for your run of analysis; it is embedded in output file names
parameter: name = str

In [None]:
[mr_mash]
# Single-column file; each line is the name of an analysis unit, i.e. a data file name without directory and suffix
parameter: analysis_units = path
# Path to prior data file: an RDS file with `U` for prior matrices
parameter: prior_matrices = path('.')
# Path to prior grid data file: an RDS file with scaling factors
parameter: prior_grid = path('.')
# Path to prior weights data file: an RDS file with prior weights. When this file exists the step runs as "second_pass"
parameter: prior_weights = path('.')
# Path to residual cor/cov data file
# NOTE(review): this parameter is not referenced anywhere else in this step -- confirm whether it is still needed
parameter: resid_cor = path('.')
# Path to summary statistics directory; the step looks for `<unit file base name>_sumstats_cv.rds` inside it
parameter: sumstats_dir = path('.')
# Path to the sample partition file: tab-delimited with a header line, read for its `id` and `fold` columns
parameter: sample_partition = path
# Fold whose samples are held out as the test set
parameter: fold = 1
# Variants with a missing genotype rate above this value are removed
parameter: imiss = 0.05
# Variants with a minor allele frequency below this value are removed
parameter: maf = 0.05
# Variants whose variance after mean imputation is below this value are removed
parameter: var_cutoff = 0.05
# Number of threads; also the number of cores requested for each task
parameter: nthreads = 1
# Response columns with fewer non-missing values than this are removed
parameter: n_nonmiss_Y = 100
# "TRUE"/"FALSE": add canonical covariance matrices to the data-driven prior matrices
parameter: canonical_mats = "FALSE"
# "TRUE"/"FALSE": forwarded as `standardize` to glmnet and mr.mash
parameter: standardize = "TRUE"
# "TRUE"/"FALSE": forwarded as `update_w0` to mr.mash
parameter: update_w0 = "TRUE"
# Forwarded as `w0_threshold` to mr.mash
parameter: w0_threshold = 0.0
# "TRUE"/"FALSE": forwarded as `update_V` to mr.mash
parameter: update_V = "TRUE"
# Forwarded as `update_V_method` to mr.mash
parameter: update_V_method = "full"
# Method for the initial coefficient estimates: "enet" or "glasso"
parameter: B_init_method = "enet"
# Forwarded as `max_iter` to mr.mash
parameter: max_iter = 5000
# Forwarded as `tol` to mr.mash
parameter: tol = 1e-2
# "TRUE"/"FALSE": forwarded as `verbose` to mr.mash
parameter: verbose = "FALSE"
# "TRUE"/"FALSE": store the fitted mr.mash object in the output file
parameter: save_model = "FALSE"
# "TRUE"/"FALSE": also store test-set predictions from the glmnet fit used for initialization
parameter: glmnet_pred = "FALSE"
# Only analyze `cis` variants -- cis = N means using N variants around the center column of X matrix
# NOTE(review): this parameter is not referenced by the R script of this step -- confirm whether it is implemented
parameter: cis = 'NULL'
# Seed passed to `set.seed()` at the start of the R script
parameter: seed = 999
# Analysis units: the non-empty lines of `analysis_units` that do not start with `#`
regions = [x.strip() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
# Keep only the units whose data file exists in `data_dir`; missing files are skipped silently
genes = [f"{data_dir:a}/{x}.{data_suffix}" for x in regions if path(f"{data_dir:a}/{x}.{data_suffix}").exists()]
# The stage is decided solely by whether a prior weights file was supplied
analysis_stage = "first_pass" if not prior_weights.is_file() else "second_pass"
# One substep (and one output file) per data file
input: genes, group_by = 1
output: f'{wd:a}/prediction/fold_{fold}/{_input:bn}_{name}_mrmash.{analysis_stage}.rds'
task: trunk_workers = 2, trunk_size = 36, walltime = '2h', mem = '10G', cores = nthreads, tags = f'{step_name}_{_output[0]:bn}'
R: expand = '${ }', stdout = f"{_output[0]:n}.stdout", stderr = f"{_output[0]:n}.stderr"
    # Keep character columns as strings when tables are read
    options(stringsAsFactors = FALSE)
    # Seed the random number generator with the SoS `seed` parameter
    set.seed(${seed})

    ###Copy the SoS parameters into R variables. Every `${ }` expression is
    ###substituted by SoS before the script is handed to R, so the logical
    ###parameters arrive as the bare R literals TRUE/FALSE.
    fold <- ${fold}
    nthreads <- ${nthreads}
    missing_rate_cutoff <- ${imiss}
    maf_cutoff <- ${maf}
    var_cutoff <- ${var_cutoff}
    n_nonmiss_Y <- ${n_nonmiss_Y}
    canonical_mats <- ${canonical_mats}
    standardize <- ${standardize}
    update_w0 <- ${update_w0}
    w0_threshold <- ${w0_threshold}
    update_V <- ${update_V}
    # String-valued parameters are quoted here so that R sees character values
    update_V_method <- "${update_V_method}"
    B_init_method <- "${B_init_method}"
    max_iter <- ${max_iter}
    tol <- ${tol}
    verbose <- ${verbose}
    save_model <- ${save_model}
    glmnet_pred <- ${glmnet_pred}
    # "first_pass" or "second_pass", as decided in the SoS step header
    analysis_stage <- "${analysis_stage}"
  
    ###
    # Utility functions
    ###

    ###Functions to compute MAF, missing genotype rate, impute missing, and filter X accordingly 
    compute_maf <- function(geno){
      # Minor allele frequency of one variant: half the mean of the non-missing
      # values, folded so that the result never exceeds 0.5.
      # Assumes genotypes are coded as 0/1/2 allele counts -- TODO confirm.
      allele_freq <- 0.5 * mean(geno, na.rm = TRUE)
      min(allele_freq, 1 - allele_freq)
    }

    compute_missing <- function(geno){
      # Fraction of entries of `geno` that are NA.
      n_missing <- sum(is.na(geno))
      n_missing / length(geno)
    }

    compute_non_missing_y <- function(y){
      # Number of entries of `y` that are not NA.
      length(y) - sum(is.na(y))
    }
  
    compute_all_missing_y <- function(y){
      # TRUE when `y` has no observed (non-NA) entry at all.
      any_observed <- any(!is.na(y))
      !any_observed
    }

    mean_impute <- function(geno){
      # Replace every NA in `geno` by the mean of the non-missing values of its
      # column and return the completed matrix.
      col_means <- apply(geno, 2, function(x) mean(x, na.rm = TRUE))
      # seq_along() yields an empty loop for a matrix without columns, whereas
      # 1:length() would iterate over c(1, 0) and index a non-existent column.
      for (j in seq_along(col_means)) {
        geno[is.na(geno[, j]), j] <- col_means[j]
      }
      return(geno)
    }

    filter_X <- function(X, missing_rate_thresh, maf_thresh, var_thresh) {
      # Filter the columns of X in three stages and return the imputed matrix:
      #   1. drop columns whose missing rate exceeds missing_rate_thresh;
      #   2. drop columns whose minor allele frequency is below maf_thresh;
      #   3. mean-impute the remaining missing entries, then drop columns whose
      #      variance is below var_thresh.
      # drop = FALSE keeps X a matrix even when a single column survives a
      # stage; without it X collapses to a vector and the next stage fails.
      rm_col <- which(apply(X, 2, compute_missing) > missing_rate_thresh)
      if (length(rm_col)) X <- X[, -rm_col, drop = FALSE]
      rm_col <- which(apply(X, 2, compute_maf) < maf_thresh)
      if (length(rm_col)) X <- X[, -rm_col, drop = FALSE]
      X <- mean_impute(X)
      rm_col <- which(matrixStats::colVars(X) < var_thresh)
      if (length(rm_col)) X <- X[, -rm_col, drop = FALSE]
      return(X)
    }

    filter_Y <- function(Y, n_nonmiss){
      # Drop the columns of Y with fewer than n_nonmiss observed values, then
      # drop what is left without any observed value.
      # The column subset deliberately lets R collapse Y to a vector when a
      # single column survives: the caller uses is.matrix() on the result to
      # decide whether enough responses remain.
      n_observed <- apply(Y, 2, compute_non_missing_y)
      sparse_cols <- which(n_observed < n_nonmiss)
      if (length(sparse_cols)) Y <- Y[, -sparse_cols]

      # Single surviving column: keep only its observed entries
      if (!is.matrix(Y)) {
        return(Y[which(!is.na(Y))])
      }

      # Several surviving columns: remove the rows that are entirely missing
      empty_rows <- which(apply(Y, 1, compute_all_missing_y))
      if (length(empty_rows)) Y <- Y[-empty_rows, ]
      return(Y)
    }

    ###Build the grid of mixture standard deviations from summary statistics
    autoselect_mixsd <- function(data, mult=2){
      # `data` provides effect estimates (Bhat) and their standard errors (Shat).
      # Entries with a zero or non-finite Shat, or a missing Bhat, are ignored.
      unusable <- data$Shat==0 | !is.finite(data$Shat) | is.na(data$Bhat)
      bhat <- data$Bhat[!unusable]
      shat <- data$Shat[!unusable]
      gmax <- grid_max(bhat, shat)
      gmin <- grid_min(bhat, shat)

      # mult == 0 requests the two-point grid
      if (mult == 0) {
        return(c(0, gmax/2))
      }

      # Otherwise go down from gmax in steps of `mult` until gmin is reached
      npoint <- ceiling(log2(gmax/gmin)/log2(mult))
      mult^((-npoint):0) * gmax
    }

    ###Smallest grid value: the smallest standard error (Bhat is not used)
    grid_min <- function(Bhat,Shat){
      return(min(Shat))
    }

    ###Largest grid value
    grid_max <- function(Bhat,Shat){
      # When no squared estimate exceeds its squared standard error, little
      # grid is needed and 8 times the grid minimum is used; otherwise the
      # value is twice the square root of the largest excess.
      within_noise <- all(Bhat^2 <= Shat^2)
      if (!within_noise) {
        return(2 * sqrt(max(Bhat^2 - Shat^2)))
      }
      8 * grid_min(Bhat,Shat)
    }

    ###Function to compute initial estimates of the coefficients from group-lasso
    ###Arguments:
    ###  X           predictor matrix used for fitting
    ###  Y           response matrix; missing values are imputed before fitting
    ###  standardize forwarded to glmnet::cv.glmnet
    ###  nthreads    a parallel backend is registered when this is > 1
    ###  Xnew        optional predictor matrix to predict for
    ###  version     "Rcpp" or "R", forwarded to the mr.mash.alpha imputation helper
    ###Returns a list with Bhat (p x r coefficients), Ytrain (the possibly
    ###imputed Y) and, when Xnew is given, Yhat_new (predictions for Xnew).
    compute_coefficients_glasso <- function(X, Y, standardize, nthreads, Xnew=NULL, version=c("Rcpp", "R")){

      version <- match.arg(version)

      n <- nrow(X)
      p <- ncol(X)
      r <- ncol(Y)
      Y_has_missing <- any(is.na(Y))
      tissue_names <- colnames(Y)

      if(Y_has_missing){
        ###Extract per-individual Y missingness patterns
        Y_miss_patterns <- mr.mash.alpha:::extract_missing_Y_pattern(Y)

        ###Compute V and its inverse
        ###(the all-zero p x r matrix is the coefficient matrix handed to compute_V_init)
        V <- mr.mash.alpha:::compute_V_init(X, Y, matrix(0, p, r), method="flash")
        Vinv <- chol2inv(chol(V))

        ###Initialize missing Ys
        ###(each missing entry starts at the mean of the observed values of its column)
        muy <- colMeans(Y, na.rm=TRUE)
        for(l in 1:r){
          Y[is.na(Y[, l]), l] <- muy[l]
        }

        ###Compute expected Y (assuming B=0)
        mu <- matrix(rep(muy, each=n), n, r)

        ###Impute missing Ys 
        Y <- mr.mash.alpha:::impute_missing_Y(Y=Y, mu=mu, Vinv=Vinv, miss=Y_miss_patterns$miss, non_miss=Y_miss_patterns$non_miss, 
                                              version=version)$Y
      }

      ##Fit group-lasso
      ##(cv.glmnet only runs in parallel when a backend has been registered)
      if(nthreads>1){
        doMC::registerDoMC(nthreads)
        paral <- TRUE
      } else {
        paral <- FALSE
      }

      cvfit_glmnet <- glmnet::cv.glmnet(x=X, y=Y, family="mgaussian", alpha=1, standardize=standardize, parallel=paral)
      coeff_glmnet <- coef(cvfit_glmnet, s="lambda.min")

      ##Build matrix of initial estimates for mr.mash
      B <- matrix(as.numeric(NA), nrow=p, ncol=r)

      ##coeff_glmnet holds one coefficient vector per response; its first entry
      ##is dropped (presumably the intercept -- confirm against the glmnet docs)
      for(i in 1:length(coeff_glmnet)){
        B[, i] <- as.vector(coeff_glmnet[[i]])[-1]
      }

      ##Make predictions if requested
      if(!is.null(Xnew)){
        Yhat_glmnet <- drop(predict(cvfit_glmnet, newx=Xnew, s="lambda.min"))
        colnames(Yhat_glmnet) <- tissue_names
        res <- list(Bhat=B, Ytrain=Y, Yhat_new=Yhat_glmnet)
      } else {
        res <- list(Bhat=B, Ytrain=Y)
      }

      return(res)
    }

    compute_coefficients_univ_glmnet <- function(X, Y, alpha, standardize, nthreads, Xnew=NULL){
      ###Fit one cross-validated glmnet regression per column of Y.
      ###Arguments:
      ###  X           predictor matrix
      ###  Y           response matrix; each column is fitted on the samples
      ###              where it is observed
      ###  alpha       forwarded to glmnet::cv.glmnet
      ###  standardize forwarded to glmnet::cv.glmnet
      ###  nthreads    a parallel backend is registered when this is > 1
      ###  Xnew        optional predictor matrix to predict for
      ###Returns a list with Bhat (coefficients without their first row, one
      ###column per response), intercept (that first row) and, when Xnew is
      ###given, Yhat_new (predictions for Xnew, one column per response).

      r <- ncol(Y)

      ##Register the parallel backend once instead of once per response
      if(nthreads>1){
        doMC::registerDoMC(nthreads)
        paral <- TRUE
      } else {
        paral <- FALSE
      }

      ##Fit response i at the lambda minimising the cross-validation error
      linreg <- function(i){
        samples_kept <- which(!is.na(Y[, i]))
        Ynomiss <- Y[samples_kept, i]
        Xnomiss <- X[samples_kept, , drop=FALSE]

        cvfit <- glmnet::cv.glmnet(x=Xnomiss, y=Ynomiss, family="gaussian", alpha=alpha, standardize=standardize, parallel=paral)
        res <- list(bhat=as.vector(coef(cvfit, s="lambda.min")), lambda_seq=cvfit$lambda)

        ##Make predictions if requested
        if(!is.null(Xnew)){
          res$yhat_new <- drop(predict(cvfit, newx=Xnew, s="lambda.min"))
        }

        return(res)
      }

      out <- lapply(seq_len(r), linreg)

      ##vapply guarantees a (p + 1) x r matrix whatever the number of responses
      Bhat <- vapply(out, function(fit) fit$bhat, numeric(ncol(X) + 1))

      results <- list(Bhat=Bhat[-1, , drop=FALSE], intercept=Bhat[1, ])

      if(!is.null(Xnew)){
        ##cbind keeps a matrix even when Xnew has a single row, where sapply
        ##would return a plain vector and setting the column names would fail
        Yhat_new <- do.call(cbind, lapply(out, "[[", "yhat_new"))
        if(is.null(rownames(Yhat_new))){
          rownames(Yhat_new) <- rownames(Xnew)
        }
        colnames(Yhat_new) <- colnames(Y)
        results$Yhat_new <- Yhat_new
      }

      return(results)
    }

    load_prior_grid <- function(prior_grid, sumstats=NULL) {
      ###Read the prior grid from the RDS file `prior_grid`. Any error or
      ###warning raised while reading is treated as "no grid available".
      no_grid <- function(cond) NULL
      res <- tryCatch(readRDS(prior_grid), error = no_grid, warning = no_grid)

      ###Nothing to fall back on, or nothing to fall back from
      if(!is.null(res) || is.null(sumstats)){
        return(res)
      }

      ###Compute prior covariance
      autoselect_mixsd(sumstats, mult=sqrt(2))^2
    }
  
    ###Filter S0 and w0: keep only the mixture components whose weight is
    ###larger than `thresh` (machine epsilon by default)
    filter_S0_w0 <- function(S0, w0, thresh=.Machine$double.eps){
      kept <- which(w0 > thresh)
      list(S0=S0[kept], w0=w0[kept])
    }

    ###Restrict the data-driven matrices and the summary statistics to the
    ###tissues (columns) that are present in Y
    filter_datadriven_mats_and_sumstats <- function(Y, datadriven_mats, sumstats){
      keep <- colnames(Y)

      #The elements of U are either plain matrices or lists that carry the
      #matrix in `mat` (different data structure between udr and Bovy's ed)
      wrapped <- is.list(datadriven_mats$U[[1]])
      subset_cov <- if(wrapped){
        function(u) u$mat[keep, keep]
      } else {
        function(u) u[keep, keep]
      }
      mats <- lapply(datadriven_mats$U, subset_cov)

      #Only the first element of sumstats is used; each table in it is
      #subset by column. A NULL sumstats is passed through unchanged.
      stats <- if(is.null(sumstats)){
        sumstats
      } else {
        lapply(sumstats[[1]], function(tab) tab[, keep])
      }

      return(list(datadriven_mats_filt=mats, sumstats_filt=stats))
    }
  
    ###Split the data in training and test
    ###Rows of X and Y whose name is listed under `id` for the requested fold in
    ###gtex_ids_folds form the test set; all other rows form the training set.
    ###Returns a list with Xtrain, Ytrain, Xtest and Ytest.
    split_data <- function(X, Y, gtex_ids_folds, fold){
      test_ids <- gtex_ids_folds[which(gtex_ids_folds$fold == fold), "id"]
      X_in_test <- rownames(X) %in% test_ids
      Y_in_test <- rownames(Y) %in% test_ids

      #drop = FALSE keeps every piece a matrix, even when a partition holds a
      #single sample; a collapsed vector would break the downstream calls
      Xtrain <- X[!X_in_test, , drop = FALSE]
      Ytrain <- Y[!Y_in_test, , drop = FALSE]
      Xtest <- X[X_in_test, , drop = FALSE]
      Ytest <- Y[Y_in_test, , drop = FALSE]

      return(list(Xtrain=Xtrain, Ytrain=Ytrain, Xtest=Xtest, Ytest=Ytest))
    }

    ###Compute prior weights from coefficients estimates
    compute_w0 <- function(Bhat, ncomps){
      #Proportion of rows of Bhat with at least one non-zero coefficient
      n_active <- sum(rowSums(abs(Bhat))>0)
      prop_nonzero <- n_active/nrow(Bhat)

      #The first component gets the remaining mass; the active proportion is
      #spread evenly over the other components
      n_other <- ncomps-1
      w0 <- c((1-prop_nonzero), rep(prop_nonzero/n_other, n_other))

      #At least two non-zero weights: use them as they are
      if(sum(w0 != 0)>=2){
        return(w0)
      }

      #Otherwise fall back to equal weights
      rep(1/ncomps, ncomps)
    }

    ###Compute column sums of the posterior assignment probabilities
    compute_posterior_weight_colsum <- function(w, S0_labels){
      w1_colsums <- colSums(w)

      #Without labels the raw column sums are returned
      if(is.null(S0_labels)){
        return(w1_colsums)
      }

      #With labels, return one entry per label: labels that also name a column
      #of w receive the column sums in order, every other label stays at 0
      padded <- rep(0, length(S0_labels))
      names(padded) <- S0_labels
      present <- which(S0_labels %in% names(w1_colsums))
      padded[present] <- w1_colsums
      return(padded)
    }
  
    ###
    # mr.mash code
    ###
  
    ###Read in the data
    dat <- readRDS(${_input:r})   
    ###Summary statistics are optional: any error or warning while reading
    ###them leaves sumstats as NULL
    sumstats <- tryCatch(readRDS("${sumstats_dir}/${_input:bn}_sumstats_cv.rds"), 
                     error = function(e) {
                       return(NULL)
                     },
                     warning = function(w) {
                       return(NULL)
                     }
    )
    ###The prior matrices are mandatory: the script stops if they cannot be read
    tryCatch({
        datadriven_mats <- readRDS(${prior_matrices:r})
      }, error = function(e) {
        # FIXME: we can implement it and provide a warning instead
        stop("Default prior is not yet implemented. Please provide a prior to use.")
    })
    ###Prior weights are optional: when they cannot be read, w0_init is NULL and
    ###the weights are computed later from the initial coefficient estimates
    w0_init <- tryCatch(readRDS(${prior_weights:r}), 
                    error = function(e) {
                      message("Prior weights not provided. Computing them from initial estimates of the coefficients.")
                      return(NULL)
                    },
                    warning=function(w){
                      message("Prior weights not provided. Computing them from initial estimates of the coefficients.")
                      return(NULL)
                    }
    )
    ###Sample partition table; its `id` and `fold` columns are used to split the data
    gtex_ids_folds <- read.table(${sample_partition:r}, header=TRUE, sep="\t")
  
    ###In the second pass the coefficients estimated in the first pass (read
    ###from the matching `.first_pass.rds` file) initialize mr.mash; a NULL
    ###first-pass result leaves B_init as NULL
    if(analysis_stage == "second_pass"){
      first_pass_out <- readRDS("${_output:nn}.first_pass.rds")
      if(!is.null(first_pass_out)){
        B_init <- first_pass_out$Bhat
      } else {
        B_init <- NULL
      }
    }
  
    ###Extract sumstats and only for specified fold
    ###(single-bracket subsetting keeps a one-element list, or NULL when
    ###sumstats is NULL)
    fold_name <- paste0("fold_", fold)
    sumstats <- sumstats[fold_name]

    ###Extract and filter Y and X. NB the analysis is skipped (a NULL result is
    ###saved) if filtered Y has < 2 tissues, i.e. is no longer a matrix.
    Y <- filter_Y(dat$y_res, n_nonmiss_Y)
    if(is.matrix(Y)){
      X <- filter_X(dat$X, missing_rate_cutoff, maf_cutoff, var_cutoff)
      ###Keep the rows of X that correspond to the rows retained in Y
      X <- X[rownames(Y), ]

      ###Drop tissues with < n_nonmiss_Y in data-driven matrices and sumstats
      datadriven_mats_sumstats_filt <- filter_datadriven_mats_and_sumstats(Y, datadriven_mats, sumstats)
      S0_data <- datadriven_mats_sumstats_filt$datadriven_mats_filt
      sumstats <- datadriven_mats_sumstats_filt$sumstats_filt
      rm(datadriven_mats_sumstats_filt)
  
      ###Read the prior grid, or derive it from the summary statistics
      prior_grid = load_prior_grid(${prior_grid:r}, sumstats)
      if(is.null(sumstats) && is.null(prior_grid)){
        # FIXME: we can implement it and provide a warning instead
        stop("Computing summary stats and grid on the fly is not yet implemented. Please provide either proper summary stats path or prior grid file.")
      }
  
      ###Split the data in training and test sets
      dat_split <- split_data(X, Y, gtex_ids_folds, fold)
      Xtrain <- dat_split$Xtrain
      Ytrain <- dat_split$Ytrain
      Xtest <- dat_split$Xtest
      Ytest <- dat_split$Ytest
      rm(dat_split)
  
      ###Compute canonical matrices, if requested
      if(canonical_mats){
        S0_can <- mr.mash.alpha::compute_canonical_covs(ncol(Ytrain), singletons=TRUE, hetgrid=c(0, 0.25, 0.5, 0.75, 1))
        S0_raw <- c(S0_can, S0_data)
      } else {
        S0_raw <- S0_data
      }
    
      ###Compute prior covariance
      S0 <- mr.mash.alpha::expand_covs(S0_raw, prior_grid, zeromat=TRUE)
    
      time1 <- proc.time()

      ###Compute initial estimates of regression coefficients and prior weights (if not provided)
      if(glmnet_pred){
        Xnew <- Xtest
      } else {
        Xnew <- NULL
      }
    
      ###`out` stays NULL when the initial estimates come from the first pass.
      ###Without this initialization the glmnet_pred branch below failed with
      ###"object 'out' not found" after the model had already been fitted.
      out <- NULL

      if(analysis_stage == "first_pass" || is.null(B_init)){
        if(B_init_method == "enet"){
          out <- compute_coefficients_univ_glmnet(Xtrain, Ytrain, alpha=0.5, standardize=standardize, nthreads=nthreads, Xnew=Xnew)
        } else if(B_init_method == "glasso"){
          out <- compute_coefficients_glasso(Xtrain, Ytrain, standardize=standardize, nthreads=nthreads, Xnew=Xnew)
        } else {
          ###An unknown method was an error before as well; now it says why
          stop("Unknown B_init_method '", B_init_method, "'. Please use either 'enet' or 'glasso'.")
        }
      
        B_init <- out$Bhat
      }

      ###Use the external prior weights when available, otherwise derive them
      ###from the initial coefficient estimates
      if(is.null(w0_init)){
        external_w0 <- FALSE
        w0 <- compute_w0(B_init, length(S0))
      } else {
        external_w0 <- TRUE
        w0 <- w0_init
      }
  
      ###Filter prior components based on weights
      comps_filtered <- filter_S0_w0(S0=S0, w0=w0)
      S0 <- comps_filtered$S0
      w0 <- comps_filtered$w0
      rm(comps_filtered)

      ###Fit mr.mash. A failed fit is reported on stderr and yields NULL, so
      ###that an output file is still written for this analysis unit.
      fit_mrmash <- tryCatch({mr.mash.alpha::mr.mash(X=Xtrain, Y=Ytrain, S0=S0, w0=w0, update_w0=update_w0, tol=tol,
                                                     max_iter=max_iter, convergence_criterion="ELBO", compute_ELBO=TRUE,  
                                                     standardize=standardize, verbose=verbose, update_V=update_V, update_V_method=update_V_method,
                                                     w0_threshold=w0_threshold, nthreads=nthreads, mu1_init=B_init)
                              },
                             error=function(e) {
                                  message("Original mr.mash error message:")
                                  message(e)
                                  return(NULL)
                              })

      if(!is.null(fit_mrmash)){
        time2 <- proc.time()
        elapsed_time <- time2["elapsed"] - time1["elapsed"]
        ###Make predictions
        Yhat_test <- predict(fit_mrmash, Xtest)
        ###Save results
        if(external_w0) {
          resu <- list(Ytest=Ytest, Yhat_test=Yhat_test, elapsed_time=elapsed_time)
        } else {
          ###Labels are only passed when thresholding is active; the column
          ###sums are then padded to one entry per prior component
          if(w0_threshold > 0){
            S0_labels = names(S0)
          } else {
            S0_labels = NULL
          } 
          w1_colsums <- compute_posterior_weight_colsum(fit_mrmash$w1, S0_labels)     
          resu <- list(w1_colsums=w1_colsums, Bhat=fit_mrmash$mu1, Ytest=Ytest, Yhat_test=Yhat_test, Ybar_train=colMeans(Ytrain, na.rm=TRUE), elapsed_time=elapsed_time)
        }
  
        if(save_model){
          resu$model <-  fit_mrmash
        }

        ###glmnet predictions only exist when glmnet was run in this pass;
        ###assigning NULL leaves `resu` without that element
        if(glmnet_pred){
          resu$Yhat_test_glmnet <- out$Yhat_new
        }
        saveRDS(resu, ${_output[0]:r})
      } else {
        saveRDS(NULL, ${_output[0]:r})
      }
    } else {
      saveRDS(NULL, ${_output[0]:r})
      message("Filtered Y has fewer than 2 tissues. Gene will not be analyzed.")
    }

In [None]:
[joint_weights_update]
# Sum the `w1_colsums` stored in every first-pass result found in `data_dir`
# and save the sums, normalized to add up to 1, as the updated prior weights.
from glob import glob
input: glob(f"{data_dir:a}/*.{data_suffix}"), group_by = "all"
output: f"{wd:a}/weights/{name}_updated_weights.rds"
task: trunk_workers = 1, walltime = '6h', trunk_size = 1, mem = '2G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    options(stringsAsFactors=FALSE)

    # Running sum of the weights; NULL until the first valid file is read.
    # A file counter must not be used to decide when to initialize the sum:
    # the first file may hold no weights, which left the sum undefined.
    weights <- NULL

    for (f in c(${_input:r,})) {
      # Files of analysis units that were skipped or failed hold NULL
      dat <- readRDS(f)$w1_colsums
      if (is.null(dat)) {
          message(paste("Dataset", f, "has no valid w1_colsums quantity"))
          next
      }
      if (is.null(weights)) {
        weights <- dat
      } else {
        weights <- weights + dat
      }
    }

    if (is.null(weights)) {
      stop("None of the input datasets has a valid w1_colsums quantity")
    }

    weights <- weights/sum(weights)

    saveRDS(weights, ${_output:r})