# Multivariate prediction workflow

This notebook applies `mr.mash` (via the `mr.mash.alpha` R package) to data analysis.

## Input

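The input data are RDS files, each containing a list of objects, in which case you can specify the names of the objects corresponding to the quantities `X`, `Y`, etc. Optionally, univariate summary statistics can be provided to compute the scaling factor of the prior. (The intent is for summary statistics to be computed on the fly if not provided, but the workflow below does not implement this yet and stops with an error when the summary statistics file cannot be read, so for now they must be provided.)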

**FIXME: need to document the input data structure**
**Also, for prior files: should they be stored similarly, with the data for each fold as `fold_??` tables?**

## Output

For each analysis unit we output:

1. Analysis results in RDS format
2. Default visualization plots

**FIXME: at this point we don't have output figures yet**

## Analysis examples

**FIXME: please adjust the command arguments to use mr-mash data**

```
sos run mrmash.ipynb complete_data_analysis \
    --analysis-units data/27_brain_non_brain_genes_v8.txt \
    --data-dir /project2/compbio/GTEx_eQTL/cis_eqtl_analysis_ready \
    --data-suffix GTEx_V8.rds \
    --name 20210409 \
    --wd /project2/compbio/GTEx_eQTL/mvSuSiE_output/cis_results \
    --prior /project2/compbio/GTEx_eQTL/mvSuSiE_output/GTEx_V8_strong_z.teem.rds \
    --sample-partition /project/mstephens/fmorgante/mr_mash_test/data/gtex-v8-ids-folds.txt \
    -c midway2.yml -q midway2
```

In [1]:
[global]
import glob
# Single-column text file listing the analysis units, one name per line.
# Blank lines and lines starting with '#' are ignored. Each name is turned into
# the data file path <data_dir>/<name>.<data_suffix>
parameter: analysis_units = path
# Path to data directory
parameter: data_dir = path
# Data file suffix, appended to each analysis unit name after a '.'
parameter: data_suffix = str
# Path to work directory where output locates
parameter: wd = path("./output")
# Path to prior data file: an RDS file with `U` and `w` for prior matrices and weights
parameter: prior = path('.')
# Path to residual cor/cov data file
# NOTE(review): not referenced by the analysis step visible in this notebook -- confirm whether it is still needed
parameter: resid_cor = path('.')
# Path to summary statistics directory. The analysis step reads
# <sumstats_dir>/<data file base name>_sumstats_cv.rds for each analysis unit
parameter: sumstats_dir = path
# An identifier for your run of analysis; it becomes part of every output file name
parameter: name = str
# Only analyze `cis` variants -- cis = N means using N variants around the center column of X matrix
# NOTE(review): not referenced by the analysis step visible in this notebook -- confirm whether it is implemented
parameter: cis = 'NULL'
# Analysis unit names: stripped, non-empty, non-comment lines of `analysis_units`
regions = [x.strip() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
# Data files to analyze. Units whose data file does not exist are skipped without any message
genes = [f"{data_dir:a}/{x}.{data_suffix}" for x in regions if path(f"{data_dir:a}/{x}.{data_suffix}").exists()]

In [None]:
[complete_data_analysis_1]
# Tab-delimited file with a header row; the R script uses its `id` and `fold` columns to choose the test samples
parameter: sample_partition = path
# Fold to hold out: samples assigned to this fold form the test set, the `fold_<fold>` entry of the
# summary statistics is used, and results are written under the `fold_<fold>` sub-directory of `wd`
parameter: fold = 1
# remove a variant if it has more than imiss missing individual data
parameter: imiss = 0.05
# remove a variant if its minor allele frequency is below maf
parameter: maf = 0.05
# remove a variant if its variance (computed after mean imputation) is below var_cutoff
parameter: var_cutoff = 0.05
# number of cores requested for the task; also passed to mr.mash, and values above 1 turn on parallel cv.glmnet
parameter: nthreads = 1
# drop a column of Y if it has fewer than n_nonmiss_Y non-missing values
parameter: n_nonmiss_Y = 100
# The "TRUE" / "FALSE" / "NULL" strings below are pasted verbatim into the R script, where they become R literals
# add canonical covariance matrices to the data-driven ones in the prior
parameter: canonical_mats = "FALSE"
# passed as `standardize` to the glmnet fits and to mr.mash
parameter: standardize = "TRUE"
# initial prior weights; with "NULL" they are derived from the sparsity of the initial glmnet coefficients
parameter: w0_init = "NULL"
# passed to mr.mash as `update_w0`
parameter: update_w0 = "TRUE"
# passed to mr.mash as `w0_threshold`
parameter: w0_threshold = 0.0
# passed to mr.mash as `update_V`
parameter: update_V = "TRUE"
# passed to mr.mash as `update_V_method`
parameter: update_V_method = "full"
# how initial coefficients are obtained: "enet" (one elastic-net fit per response) or "glasso" (multi-response glmnet fit)
parameter: B_init_method = "enet"
# passed to mr.mash as `max_iter`
parameter: max_iter = 5000
# passed to mr.mash as `tol`
parameter: tol = 1e-2
# passed to mr.mash as `verbose`
parameter: verbose = "FALSE"
# store the fitted mr.mash model in the output file
parameter: save_model = "TRUE"
# also store the glmnet predictions for the test samples in the output file
parameter: glmnet_pred = "TRUE"
# NOTE(review): declared but not referenced by the R script below, which reads summary statistics
# from `sumstats_dir` instead -- confirm whether this parameter is still needed
parameter: sumstats = path(".")
# one substep per data file
input: genes, group_by = 1
output: f'{wd:a}/fold_{fold}/{_input:bn}_{name}.rds'
task: trunk_workers = 1, trunk_size = 18, walltime = '2h', mem = '10G', cores = nthreads, tags = f'{step_name}_{_output[0]:bn}'
R: expand = '${ }', stdout = f"{_output[0]:nn}.stdout", stderr = f"{_output[0]:nn}.stderr"

    # Keep strings as character vectors rather than factors
    options(stringsAsFactors = FALSE)

    # Fixed seed so that repeated runs give the same results -- presumably this matters for the
    # random cross-validation folds inside cv.glmnet (TODO confirm)
    set.seed(1)

    ###Set some parameter variables (These should be set in the SoS script)
    # Every right-hand side below is a placeholder that SoS replaces with the value of the
    # corresponding workflow parameter before this script is run
    fold <- ${fold}
    nthreads <- ${nthreads}
    # Thresholds used by filter_X and filter_Y
    missing_rate_cutoff <- ${imiss}
    maf_cutoff <- ${maf}
    var_cutoff <- ${var_cutoff}
    n_nonmiss_Y <- ${n_nonmiss_Y}
    # Unquoted placeholders: the workflow supplies the text TRUE / FALSE / NULL (or a number),
    # which is read here as an R literal
    canonical_mats <- ${canonical_mats}
    standardize <- ${standardize}
    w0_init <- ${w0_init}
    update_w0 <- ${update_w0}
    w0_threshold <- ${w0_threshold}
    update_V <- ${update_V}
    # Quoted placeholders: these two options are character strings
    update_V_method <- "${update_V_method}"
    B_init_method <- "${B_init_method}"
    max_iter <- ${max_iter}
    tol <- ${tol}
    verbose <- ${verbose}
    save_model <- ${save_model}
    glmnet_pred <- ${glmnet_pred}

    ###Functions to compute MAF, missing genotype rate, impute missing, and filter X accordingly 
    compute_maf <- function(geno){
      # Minor allele frequency of one variant.
      #   geno: numeric vector of genotypes; missing values are ignored
      #         (presumably 0/1/2 allele dosages, given the division by 2 -- TODO confirm).
      # Returns the smaller of the allele frequency and its complement.
      allele_freq <- mean(geno, na.rm = TRUE) / 2
      minor_freq <- min(allele_freq, 1 - allele_freq)
      minor_freq
    }

    compute_missing <- function(geno){
      # Fraction of entries of `geno` that are missing (NA).
      n_missing <- sum(is.na(geno))
      n_total <- length(geno)
      n_missing / n_total
    }

    compute_non_missing_y <- function(y){
      # Number of entries of `y` that are observed (not NA).
      observed <- !is.na(y)
      sum(observed)
    }

    mean_impute <- function(geno){
      # Replace every missing entry of `geno` by the mean of the non-missing
      # entries of the same column.
      #   geno: matrix with one column per variant.
      # Returns `geno` with the missing entries filled in; dimensions are unchanged.
      col_means <- apply(geno, 2, function(x) mean(x, na.rm = TRUE))
      # seq_along() rather than 1:length(): for a matrix without columns,
      # 1:0 would loop over c(1, 0) and index a column that does not exist.
      for (j in seq_along(col_means)) {
        geno[is.na(geno[, j]), j] <- col_means[j]
      }
      return(geno)
    }

    filter_X <- function(X, missing_rate_thresh, maf_thresh, var_thresh) {
      # Filter the columns (variants) of the genotype matrix X and impute what is left.
      # Steps, in order:
      #   1. drop columns whose missing rate is above missing_rate_thresh
      #   2. drop columns whose minor allele frequency is below maf_thresh
      #   3. mean-impute the remaining missing entries
      #   4. drop columns whose variance is below var_thresh
      # Returns the filtered, imputed matrix.
      # drop = FALSE keeps X a matrix even when a single column survives a step;
      # without it X collapses to a vector and the next apply()/colVars() call fails.
      rm_col <- which(apply(X, 2, compute_missing) > missing_rate_thresh)
      if (length(rm_col)) X <- X[, -rm_col, drop = FALSE]
      rm_col <- which(apply(X, 2, compute_maf) < maf_thresh)
      if (length(rm_col)) X <- X[, -rm_col, drop = FALSE]
      X <- mean_impute(X)
      rm_col <- which(matrixStats::colVars(X) < var_thresh)
      if (length(rm_col)) X <- X[, -rm_col, drop = FALSE]
      return(X)
    }

    filter_Y <- function(Y, n_nonmiss){
      # Drop the columns of Y that have fewer than n_nonmiss observed (non-NA) values.
      #   Y: matrix with one column per response.
      # Returns Y restricted to the retained columns, always as a matrix.
      n_observed <- colSums(!is.na(Y))
      rm_col <- which(n_observed < n_nonmiss)
      # drop = FALSE: if only one column is retained the result must stay a
      # matrix, otherwise its column name (used later to subset the prior and
      # the summary statistics) is lost.
      if (length(rm_col)) Y <- Y[, -rm_col, drop = FALSE]
      return(Y)
    }

    ###Function to compute the grid
    autoselect_mixsd <- function(data, mult=2){
      # Build a grid of values from effect estimates and their standard errors.
      #   data: list with elements Bhat and Shat
      #   mult: ratio between successive grid values; 0 requests the two-value grid c(0, max/2)
      # Entries with a zero or non-finite Shat, or a missing Bhat, are left out.
      usable <- data$Shat != 0 & is.finite(data$Shat) & !is.na(data$Bhat)
      bhat <- data$Bhat[usable]
      shat <- data$Shat[usable]
      upper <- grid_max(bhat, shat)
      lower <- grid_min(bhat, shat)
      if (mult == 0) {
        return(c(0, upper/2))
      }
      # Number of multiplicative steps needed to get from the upper end down to the lower end
      n_steps <- ceiling(log2(upper/lower)/log2(mult))
      exponents <- (-n_steps):0
      return(mult^exponents * upper)
    }

    ###Compute the minimum value for the grid
    grid_min <- function(Bhat, Shat){
      # Lower end of the grid: the smallest standard error. Bhat is not used.
      smallest_se <- min(Shat)
      smallest_se
    }

    ###Compute the maximum value for the grid
    grid_max <- function(Bhat, Shat){
      # Upper end of the grid.
      # When some squared estimate exceeds its squared standard error, the upper
      # end is twice the square root of the largest such excess.
      if (!all(Bhat^2 <= Shat^2)) {
        return(2 * sqrt(max(Bhat^2 - Shat^2)))
      }
      # the unusual case where we don't need much grid: 8 times the lower end
      # of the grid, i.e. 8 times the smallest standard error
      8 * min(Shat)
    }

    ###Function to compute initial estimates of the coefficients from group-lasso
    compute_coefficients_glasso <- function(X, Y, standardize, nthreads, Xnew=NULL, version=c("Rcpp", "R")){
      # Initial coefficient estimates from a single multi-response glmnet fit
      # (family "mgaussian", alpha = 1) with cross-validated penalty.
      #   X:           predictor matrix (rows = samples, columns = variables)
      #   Y:           response matrix (rows = samples, columns = responses); may contain NAs
      #   standardize: passed to cv.glmnet
      #   nthreads:    values above 1 register a doMC backend and run cv.glmnet in parallel
      #   Xnew:        optional predictor matrix for which predictions are returned
      #   version:     "Rcpp" or "R"; only forwarded to the mr.mash.alpha imputation helper
      # Returns a list with
      #   Bhat:     ncol(X) x ncol(Y) matrix of coefficients at lambda.min, without the intercept
      #   Ytrain:   the response matrix actually used for the fit (imputed when Y had NAs)
      #   Yhat_new: predictions for Xnew at lambda.min (only when Xnew is given)

      version <- match.arg(version)

      n <- nrow(X)
      p <- ncol(X)
      r <- ncol(Y)
      Y_has_missing <- any(is.na(Y))
      tissue_names <- colnames(Y)

      # The multi-response fit needs a complete Y, so missing entries are imputed first.
      # The helpers are unexported mr.mash.alpha internals; what they do exactly is
      # not visible here -- the comments below only describe how they are called.
      if(Y_has_missing){
        ###Extract per-individual Y missingness patterns
        Y_miss_patterns <- mr.mash.alpha:::extract_missing_Y_pattern(Y)

        ###Compute V and its inverse
        # V is computed from the data with an all-zero p x r coefficient matrix,
        # before any missing entry of Y is filled in
        V <- mr.mash.alpha:::compute_V_init(X, Y, matrix(0, p, r), method="flash")
        Vinv <- chol2inv(chol(V))

        ###Initialize missing Ys
        # Column means of the observed values serve as starting values
        muy <- colMeans(Y, na.rm=TRUE)
        for(l in 1:r){
          Y[is.na(Y[, l]), l] <- muy[l]
        }

        ###Compute expected Y (assuming B=0)
        # Every row of mu equals the vector of column means
        mu <- matrix(rep(muy, each=n), n, r)

        ###Impute missing Ys
        Y <- mr.mash.alpha:::impute_missing_Y(Y=Y, mu=mu, Vinv=Vinv, miss=Y_miss_patterns$miss, non_miss=Y_miss_patterns$non_miss, 
                                              version=version)$Y
      }

      ##Fit group-lasso
      if(nthreads>1){
        doMC::registerDoMC(nthreads)
        paral <- TRUE
      } else {
        paral <- FALSE
      }

      cvfit_glmnet <- glmnet::cv.glmnet(x=X, y=Y, family="mgaussian", alpha=1, standardize=standardize, parallel=paral)
      coeff_glmnet <- coef(cvfit_glmnet, s="lambda.min")

      ##Build matrix of initial estimates for mr.mash
      B <- matrix(as.numeric(NA), nrow=p, ncol=r)

      # coeff_glmnet has one element per response; the first entry of each
      # (the intercept) is dropped so that column i of B holds the p slopes
      for(i in 1:length(coeff_glmnet)){
        B[, i] <- as.vector(coeff_glmnet[[i]])[-1]
      }

      ##Make predictions if requested
      if(!is.null(Xnew)){
        # drop() removes the trailing length-one dimension of the prediction array.
        # NOTE(review): with a single-row Xnew, drop() would presumably also remove the
        # row dimension and the colnames<- below would then fail -- confirm
        Yhat_glmnet <- drop(predict(cvfit_glmnet, newx=Xnew, s="lambda.min"))
        colnames(Yhat_glmnet) <- tissue_names
        res <- list(Bhat=B, Ytrain=Y, Yhat_new=Yhat_glmnet)
      } else {
        res <- list(Bhat=B, Ytrain=Y)
      }

      return(res)
    }

    compute_coefficients_univ_glmnet <- function(X, Y, alpha, standardize, nthreads, Xnew=NULL){
      # Initial coefficient estimates from one cross-validated glmnet fit per response.
      #   X:           predictor matrix (rows = samples, columns = variables)
      #   Y:           response matrix (rows = samples, columns = responses); NAs allowed
      #   alpha:       elastic-net mixing parameter passed to cv.glmnet
      #   standardize: passed to cv.glmnet
      #   nthreads:    values above 1 register a doMC backend and run cv.glmnet in parallel
      #   Xnew:        optional predictor matrix for which predictions are returned
      # Returns a list with
      #   Bhat:      ncol(X) x ncol(Y) matrix of coefficients at lambda.min, without intercepts
      #   intercept: vector of the ncol(Y) intercepts
      #   Yhat_new:  nrow(Xnew) x ncol(Y) matrix of predictions (only when Xnew is given)

      r <- ncol(Y)

      # Register the parallel backend once; it does not depend on the response
      # being fitted, so there is no need to repeat it for every column of Y.
      if(nthreads>1){
        doMC::registerDoMC(nthreads)
        paral <- TRUE
      } else {
        paral <- FALSE
      }

      # Fit response i on the samples for which it is observed
      linreg <- function(i, X, Y, alpha, standardize, paral, Xnew){
        samples_kept <- which(!is.na(Y[, i]))
        Ynomiss <- Y[samples_kept, i]
        Xnomiss <- X[samples_kept, , drop=FALSE]

        cvfit <- glmnet::cv.glmnet(x=Xnomiss, y=Ynomiss, family="gaussian", alpha=alpha, standardize=standardize, parallel=paral)
        # First entry is the intercept, followed by one coefficient per column of X
        coeffic <- as.vector(coef(cvfit, s="lambda.min"))
        res <- list(bhat=coeffic, lambda_seq=cvfit$lambda)

        ##Make predictions if requested
        if(!is.null(Xnew)){
          res$yhat_new <- drop(predict(cvfit, newx=Xnew, s="lambda.min"))
        }

        return(res)
      }

      out <- lapply(seq_len(r), linreg, X, Y, alpha, standardize, paral, Xnew)

      # cbind() always yields one column per response. sapply() would return a
      # plain vector when each fit contributes a single value (e.g. one new sample).
      Bhat <- do.call(cbind, lapply(out, "[[", "bhat"))

      # drop=FALSE keeps Bhat a matrix when Y has a single column
      results <- list(Bhat=Bhat[-1, , drop=FALSE], intercept=Bhat[1, ])

      if(!is.null(Xnew)){
        Yhat_new <- do.call(cbind, lapply(out, "[[", "yhat_new"))
        colnames(Yhat_new) <- colnames(Y)
        results$Yhat_new <- Yhat_new
      }

      return(results)
    }



    ###Read in the data
    dat <- readRDS(${_input:r})
    ###Summary statistics of this analysis unit. Failing to read them is fatal; the underlying
    ###read error is appended to the message so that a wrong path can be told apart from an unreadable file
    sumstats <- tryCatch(
      readRDS("${sumstats_dir}/${_input:bn}_sumstats_cv.rds"),
      error = function(e) {
        # FIXME: we can implement it and provide a warning instead
        stop("Computing summary stats on the fly is not yet implemented. Please provide proper summary stats path",
             " (", conditionMessage(e), ")", call. = FALSE)
      })
    ###Data-driven prior matrices. Same error policy as for the summary statistics
    datadriven_mats <- tryCatch(
      readRDS(${prior:r}),
      error = function(e) {
        # FIXME: we can implement it and provide a warning instead
        stop("Default prior is not yet implemented. Please provide a prior to use",
             " (", conditionMessage(e), ")", call. = FALSE)
      })
    ###Sample partition table: tab-delimited with a header; its `id` and `fold` columns are used below
    gtex_ids_folds <- read.table(${sample_partition:r}, header=TRUE, sep="\t")

    ###Extract sumstats and only for specified fold
    fold_name <- paste0("fold_", fold)
    # Fail early with a clear message: selecting a name that is absent would give an
    # empty element and only surface much later as an obscure error in the grid computation
    if(!(fold_name %in% names(sumstats))){
      stop("No summary statistics found for ", fold_name, call. = FALSE)
    }
    sumstats <- sumstats[fold_name]

    ###Extract and filter Y and X
    Y <- filter_Y(dat$y_res, n_nonmiss_Y)
    X <- filter_X(dat$X, missing_rate_cutoff, maf_cutoff, var_cutoff)

    ###Drop tissues with < n_nonmiss_Y in data-driven matrices and sumstats
    # drop=FALSE keeps every subset a matrix even if a single tissue is retained
    tissues_to_keep <- colnames(Y)
    #Handle different data structure between udr and Bovy's ed
    if(!is.list(datadriven_mats$U[[1]])){
      # Elements of U are the matrices themselves
      S0_data <- lapply(datadriven_mats$U, function(x, to_keep){x[to_keep, to_keep, drop=FALSE]}, tissues_to_keep)
    } else {
      # Elements of U are lists holding the matrix in `mat`
      S0_data <- lapply(datadriven_mats$U, function(x, to_keep){x$mat[to_keep, to_keep, drop=FALSE]}, tissues_to_keep)
    }
    # Each element of the fold's summary statistics is restricted to the retained tissues (columns)
    sumstats <- lapply(sumstats[[1]], function(x, to_keep){x[, to_keep, drop=FALSE]}, tissues_to_keep)

    ###Split the data in training and test sets
    # Ids assigned to the held-out fold. as.character() guarantees that rows are
    # selected by name below, never by position
    fold_ids <- as.character(gtex_ids_folds[which(gtex_ids_folds$fold == fold), "id"])
    # Training set: every sample that is not assigned to the held-out fold
    Xtrain <- X[!(rownames(X) %in% fold_ids), , drop=FALSE]
    Ytrain <- Y[!(rownames(Y) %in% fold_ids), , drop=FALSE]
    # Test set: held-out ids that are present in both X and Y. Ids without data are
    # skipped with a warning instead of aborting with "subscript out of bounds",
    # in line with the training split, which tolerates such ids
    has_data <- (fold_ids %in% rownames(X)) & (fold_ids %in% rownames(Y))
    if(!all(has_data)){
      warning(sum(!has_data), " sample id(s) assigned to fold ", fold, " are absent from X or Y and were skipped")
    }
    test_ids <- fold_ids[has_data]
    # drop=FALSE keeps the test sets matrices even with a single test sample
    Xtest <- X[test_ids, , drop=FALSE]
    Ytest <- Y[test_ids, , drop=FALSE]

    ###Assemble the prior matrices: the data-driven ones, preceded by canonical ones if requested
    S0_raw <- S0_data
    if(canonical_mats){
      S0_can <- mr.mash.alpha::compute_canonical_covs(ncol(Ytrain), singletons=TRUE, hetgrid=c(0, 0.25, 0.5, 0.75, 1))
      S0_raw <- c(S0_can, S0_raw)
    }

    ###Compute prior covariance
    # The grid returned by autoselect_mixsd is squared before it is handed to expand_covs
    grid_sd <- autoselect_mixsd(sumstats, mult=sqrt(2))
    grid <- grid_sd^2
    S0 <- mr.mash.alpha::expand_covs(S0_raw, grid, zeromat=TRUE)

    # Start of the timed section (initial estimates plus the mr.mash fit)
    time1 <- proc.time()

    ###Compute initial estimates of regression coefficients and prior weights (if not provided)
    # glmnet predictions for the test samples are produced only when requested
    Xnew <- if(glmnet_pred) Xtest else NULL

    if(B_init_method == "enet"){
      out <- compute_coefficients_univ_glmnet(Xtrain, Ytrain, alpha=0.5, standardize=standardize, nthreads=nthreads, Xnew=Xnew)
    } else if(B_init_method == "glasso"){
      out <- compute_coefficients_glasso(Xtrain, Ytrain, standardize=standardize, nthreads=nthreads, Xnew=Xnew)
    } else {
      # Without this branch `out` would stay undefined and the script would fail
      # further down with an unrelated "object 'out' not found" error
      stop("Unknown B_init_method '", B_init_method, "'; expected 'enet' or 'glasso'", call. = FALSE)
    }

    if(is.null(w0_init)){
      # Share of variables with a non-zero initial coefficient for at least one response
      prop_nonzero_glmnet <- sum(rowSums(abs(out$Bhat))>0)/ncol(Xtrain)
      # The first component of S0 gets the complementary weight; the share of
      # non-zero variables is split evenly among the remaining components
      w0 <- c((1-prop_nonzero_glmnet), rep(prop_nonzero_glmnet/(length(S0)-1), (length(S0)-1)))
    } else {
      w0 <- w0_init
    }

    ###Fit mr.mash on the training samples, starting from the initial coefficient estimates
    fit_mrmash <- mr.mash.alpha::mr.mash(
      X=Xtrain,
      Y=Ytrain,
      S0=S0,
      w0=w0,
      mu1_init=out$Bhat,
      update_w0=update_w0,
      w0_threshold=w0_threshold,
      update_V=update_V,
      update_V_method=update_V_method,
      standardize=standardize,
      convergence_criterion="ELBO",
      compute_ELBO=TRUE,
      tol=tol,
      max_iter=max_iter,
      nthreads=nthreads,
      verbose=verbose
    )

    ###Elapsed wall-clock time since time1 was recorded
    time2 <- proc.time()
    elapsed_time <- time2["elapsed"] - time1["elapsed"]

    ###Column sums of the posterior assignment probabilities, one per prior component
    w1_colsums <- colSums(fit_mrmash$w1)
    if(w0_threshold > 0){
      # Report a value for every component of S0: components that do not appear
      # in the fitted w1 get 0, the others receive the fitted sums in order
      padded <- rep(0, length(S0))
      names(padded) <- names(S0)
      in_fit <- names(padded) %in% names(w1_colsums)
      padded[in_fit] <- w1_colsums
      w1_colsums <- padded
    }

    ###Predict the test samples with the fitted model
    Yhat_test <- predict(fit_mrmash, Xtest)

    ###Collect the results; the fitted model and the glmnet predictions are optional extras
    resu <- list(w1_colsums=w1_colsums, Ytest=Ytest, Yhat_test=Yhat_test, elapsed_time=elapsed_time)
    if(save_model) resu[["model"]] <- fit_mrmash
    if(glmnet_pred) resu[["Yhat_test_glmnet"]] <- out$Yhat_new

    ###Write everything to the output file of this substep
    saveRDS(resu, ${_output[0]:r})