# Multivariate prediction workflow

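This notebook applies mtlasso (sparse multi-task lasso) to fit multivariate prediction models on training samples and predict on held-out test samples.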

## Input

Each analysis unit is an RDS file storing a list of objects. The workflow below reads the element `X` and the element `y_res` (used as `Y`) from that list.

**FIXME: need to document the input data structure**
**Also, for prior files, should they be stored similarly to the data, with one `fold_??` table per fold?**

## Output

For each analysis unit we output:

1. Analysis results in RDS format
2. Default visualization plots

**FIXME: at this point we don't have output figures yet**

## Analysis examples

```
sos run /project/mstephens/fmorgante/bioworkflows/multivariate-prediction/mtlasso.ipynb mtlasso \
    --analysis-units ../data/gtex-v8-manifest-2ormore-tissues-nopath-nosuffix-test.txt \
    --data-dir ../data/cis_eqtl_analysis_ready  \
    --data-suffix GTEx_V8.rds \
    --name fold_1 \
    --wd ../output/gtex_mr_mash_analysis \
    --sample-partition ../data/gtex-v8-ids-folds.txt \
    --fold 1 \
    --mtlasso-script /project/mstephens/fmorgante/mr_mash_test/code/fit_mtlasso_missing_Y.py \
    --conda-env py38dsc \
    -c midway2.yml -q midway2
```

In [1]:
[global]
# Path to the directory holding the input data files
parameter: data_dir = path
# Suffix of the data files; a file is looked up as <data_dir>/<unit>.<data_suffix>
parameter: data_suffix = str
# Path to the work directory where output is written
parameter: wd = path("./output")
# An identifier for this run of the analysis; it becomes part of the output file name
parameter: name = str

In [None]:
[mtlasso]
# Single-column text file; each non-empty line that does not start with `#` names one data file
# (without directory and suffix)
parameter: analysis_units = path
# Tab-delimited sample partition file with a header line; the R code uses its `id` and `fold`
# columns to decide which samples are held out
parameter: sample_partition = path
# Conda environment passed to reticulate's use_condaenv()
parameter: conda_env = str
# Python script loaded with reticulate's source_python(); the R code then calls
# fit_sparse_multi_task_lasso_missing_Y(), presumably defined by this script
parameter: mtlasso_script = path
# Fold whose samples (per the sample partition file) form the test set
parameter: fold = 1
# Columns of X with a missing rate above this value are removed
parameter: imiss = 0.05
# Columns of X with a minor allele frequency below this value are removed
parameter: maf = 0.05
# Columns of X whose variance (after mean imputation) is below this value are removed
parameter: var_cutoff = 0.05
# Columns of Y with fewer non-missing values than this are removed
parameter: n_nonmiss_Y = 100
# The next settings are forwarded as-is to fit_sparse_multi_task_lasso_missing_Y();
# string defaults such as "TRUE" / "NULL" are pasted into the R code as R literals
parameter: standardize = "TRUE"
parameter: max_iter = 5000
parameter: verbose = "FALSE"
parameter: nfolds = 5
parameter: B_init = "NULL"
parameter: grid_limits = "NULL"
parameter: grid_length = 10
# Only analyze `cis` variants -- cis = N means using N variants around the center column of X matrix  
# NOTE(review): `cis` is not referenced by the R code visible in this step -- confirm whether it is still needed
parameter: cis = 'NULL'
# Seed passed to set.seed() at the start of the R script
parameter: seed = 999
# Read the analysis units, then keep only those whose data file exists under data_dir
regions = [x.strip() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
genes = [f"{data_dir:a}/{x}.{data_suffix}" for x in regions if path(f"{data_dir:a}/{x}.{data_suffix}").exists()]
# One substep per data file
input: genes, group_by = 1
output: f'{wd:a}/mtlasso/fold_{fold}/{_input:bn}_{name}_mtlasso.rds'
task: trunk_workers = 4, trunk_size = 8, walltime = '16h', mem = '8G', cores = 1, tags = f'{step_name}_{_output[0]:bn}'
R: expand = '${ }', stdout = f"{_output[0]:n}.stdout", stderr = f"{_output[0]:n}.stderr"
    # Keep strings as character vectors when building data frames
    options(stringsAsFactors = FALSE)
    # Seed R's random number generator with the workflow's seed parameter
    set.seed(${seed})
    # reticulate provides use_condaenv() and source_python(), both used further below
    library(reticulate)

    ###Set some parameter variables (These should be set in the SoS script)
    # Each value below is substituted by SoS from the step parameter of the same
    # or similar name before R runs this script.
    fold <- ${fold}
    # Thresholds later handed to filter_X()
    missing_rate_cutoff <- ${imiss}
    maf_cutoff <- ${maf}
    var_cutoff <- ${var_cutoff}
    # Threshold later handed to filter_Y()
    n_nonmiss_Y <- ${n_nonmiss_Y}
    # Settings later forwarded to fit_sparse_multi_task_lasso_missing_Y()
    standardize <- ${standardize}
    max_iter <- ${max_iter}
    verbose <- ${verbose}
    nfolds <- ${nfolds}
    B_init <- ${B_init}
    grid_limits <- ${grid_limits}
    grid_length <- ${grid_length}


    ###
    # Utility functions
    ###

    ###Functions to compute MAF, missing genotype rate, impute missing, and filter X accordingly 
    # Minor allele frequency of one variant.
    # geno: numeric vector; NA entries are ignored.
    # NOTE(review): the division by 2 presumably assumes 0/1/2 dosage coding -- confirm.
    # Returns the smaller of the allele frequency and its complement.
    compute_maf <- function(geno){
      allele_freq <- mean(geno, na.rm = TRUE) / 2
      min(allele_freq, 1 - allele_freq)
    }

    # Fraction of entries of geno that are NA.
    compute_missing <- function(geno){
      n_missing <- sum(is.na(geno))
      n_missing / length(geno)
    }

    # Number of entries of y that are not NA (integer count).
    compute_non_missing_y <- function(y){
      observed <- which(!is.na(y))
      length(observed)
    }

    # TRUE when y has no observed (non-NA) entry, FALSE otherwise.
    compute_all_missing_y <- function(y){
      has_observed <- any(!is.na(y))
      !has_observed
    }

    # Replace the NA entries of every column of geno by that column's mean.
    # geno: two-dimensional numeric object (columns are imputed independently).
    # Returns geno with the same dimensions. A column that is entirely NA has
    # mean NaN and is filled with NaN.
    mean_impute <- function(geno){
      col_means <- apply(geno, 2, function(x) mean(x, na.rm = TRUE))
      # seq_along() makes the loop a no-op for a zero-column input;
      # 1:length() would iterate over c(1, 0) and fail on the subscript.
      for (i in seq_along(col_means)) {
        geno[is.na(geno[, i]), i] <- col_means[i]
      }
      return(geno)
    }

    # Filter the columns of X and impute the remaining missing entries.
    # Steps, in order:
    #   1. drop columns whose missing rate exceeds missing_rate_thresh
    #   2. drop columns whose minor allele frequency is below maf_thresh
    #   3. mean-impute the remaining NA entries
    #   4. drop columns whose variance is below var_thresh
    # Returns a matrix; drop = FALSE keeps it a matrix even when a single
    # column survives, so the later apply() calls and row subsetting still work.
    filter_X <- function(X, missing_rate_thresh, maf_thresh, var_thresh) {
      rm_col <- which(apply(X, 2, compute_missing) > missing_rate_thresh)
      if (length(rm_col) > 0) X <- X[, -rm_col, drop = FALSE]
      rm_col <- which(apply(X, 2, compute_maf) < maf_thresh)
      if (length(rm_col) > 0) X <- X[, -rm_col, drop = FALSE]
      # Nothing to impute or to filter on variance when no column is left
      if (ncol(X) > 0) {
        X <- mean_impute(X)
        rm_col <- which(matrixStats::colVars(X) < var_thresh)
        if (length(rm_col) > 0) X <- X[, -rm_col, drop = FALSE]
      }
      return(X)
    }

    # Filter the columns and rows of Y.
    #   1. drop columns with fewer than n_nonmiss non-missing values
    #   2. drop rows in which every remaining column is missing
    # Returns a matrix when two or more columns survive. When exactly one
    # column survives, the column subsetting collapses Y to a plain vector; that
    # vector is returned with its NA entries removed. The caller uses
    # is.matrix() on the result to decide whether the gene can be analyzed.
    filter_Y <- function(Y, n_nonmiss){
      rm_col <- which(apply(Y, 2, compute_non_missing_y) < n_nonmiss)
      if (length(rm_col) > 0) Y <- Y[, -rm_col]
      if (length(dim(Y)) == 2L) {
        rm_rows <- which(apply(Y, 1, compute_all_missing_y))
        if (length(rm_rows) > 0) Y <- Y[-rm_rows, ]
      } else {
        # Y is now a vector; apply() over rows would stop with
        # "dim(X) must have a positive length", so handle it directly.
        Y <- Y[!is.na(Y)]
      }
      return(Y)
    }
  
    ###Split the data in training and test
    # Split X and Y into training and test sets by sample id.
    # X, Y: matrices with sample ids as row names.
    # gtex_ids_folds: data frame with columns `id` and `fold`.
    # fold: rows of gtex_ids_folds whose `fold` equals this value give the test ids.
    # Returns list(Xtrain, Ytrain, Xtest, Ytest). drop = FALSE keeps every
    # element a matrix even when a split has a single row or column.
    split_data <- function(X, Y, gtex_ids_folds, fold){
      test_ids <- gtex_ids_folds[which(gtex_ids_folds$fold == fold), "id"]
      x_in_test <- rownames(X) %in% test_ids
      y_in_test <- rownames(Y) %in% test_ids
      Xtrain <- X[!x_in_test, , drop = FALSE]
      Ytrain <- Y[!y_in_test, , drop = FALSE]
      Xtest <- X[x_in_test, , drop = FALSE]
      Ytest <- Y[y_in_test, , drop = FALSE]

      return(list(Xtrain=Xtrain, Ytrain=Ytrain, Xtest=Xtest, Ytest=Ytest))
    }
  
    # Predictions newx %*% B with intercept added to the columns via addtocols().
    # A matrix-shaped intercept is first reduced with drop().
    predict_general <- function(B, intercept, newx){
      offset <- if(is.matrix(intercept)) drop(intercept) else intercept
      linear_part <- newx %*% B
      addtocols(linear_part, offset)
    }
  
    # Add b[j] to every entry of column j of A.
    # Works on the transpose so that b is recycled along the columns of A.
    addtocols <- function (A, b){
      shifted <- t(A) + b
      t(shifted)
    }
      
    ###
    # mtlasso code
    ###
  
    ###Read in the data
    ###Read in the data
    dat <- readRDS(${_input:r})
    gtex_ids_folds <- read.table(${sample_partition:r}, header=TRUE, sep="\t")

    ###Extract and filter Y and X. NB a gene whose filtered Y has < 2 tissues is not analyzed;
    ###NULL is saved for it instead.
    Y <- filter_Y(dat$y_res, n_nonmiss_Y)

    ###Result to save; stays NULL when the gene is skipped or the fit fails
    resu <- NULL

    ###A matrix with zero columns also passes is.matrix(), hence the explicit column count
    if(is.matrix(Y) && ncol(Y) >= 2){
      X <- filter_X(dat$X, missing_rate_cutoff, maf_cutoff, var_cutoff)
      ###Align the rows of X to the rows kept in Y; drop=FALSE keeps X a matrix
      X <- X[rownames(Y), , drop=FALSE]

      ###Split the data in training and test sets
      dat_split <- split_data(X, Y, gtex_ids_folds, fold)
      Xtrain <- dat_split$Xtrain
      Ytrain <- dat_split$Ytrain
      Xtest <- dat_split$Xtest
      Ytest <- dat_split$Ytest
      rm(dat_split)

      ###NB the timer also covers activating the conda environment and sourcing the Python script
      time1 <- proc.time()

      ###Fit mtlasso
      use_condaenv("${conda_env}")
      source_python(${mtlasso_script:r})

      ###NOTE(review): a warning raised during the fit discards the fit just like an error does
      ###(NULL is returned by the handler) -- confirm this is intended.
      fit_mtlasso <- tryCatch({fit_sparse_multi_task_lasso_missing_Y(X=Xtrain, Y=Ytrain, standardize=standardize,
                                                                     max_iter=as.integer(max_iter), verbose=verbose,
                                                                     nfolds=as.integer(nfolds), grid_length=as.integer(grid_length),
                                                                     B_init=B_init, grid_limits=grid_limits)
                              },
                             error=function(e) {
                                  message("Original mtlasso error message:")
                                  message(conditionMessage(e))
                                  return(NULL)
                              },
                             warning=function(w) {
                                  message("Original mtlasso warning message:")
                                  message(conditionMessage(w))
                                  return(NULL)
                              })

      if(!is.null(fit_mtlasso)){
        time2 <- proc.time()
        elapsed_time <- time2["elapsed"] - time1["elapsed"]

        ###Make predictions
        ###NOTE(review): elements 1 and 2 of the fit are used as coefficients and intercept -- confirm
        ###against the Python script.
        Yhat_test <- predict_general(B=fit_mtlasso[[1]], intercept=fit_mtlasso[[2]], newx=Xtest)

        resu <- list(Ytest=Ytest, Yhat_test=Yhat_test, elapsed_time=elapsed_time, model=fit_mtlasso)
      }
    } else {
      message("Filtered Y has fewer than 2 tissues. Gene will not be analyzed.")
    }

    ###Save results (NULL when no model was fitted) so that the output file always exists
    saveRDS(resu, ${_output[0]:r})