# Prediction score evaluations

This notebook implements prediction score evaluations. The script is customized for a particular project and contains a number of hard-coded structures. Other users are advised to read and customize the script for their own applications.

## Input

RDS format of model fits as well as the original input data, listed in a text file (see `--analysis-units` in the example analysis).

## Output

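Prediction score metrics: one RDS file per gene, `<wd>/<gene>.<data-suffix>_score.rds`, holding the across-fold mean, standard error and confidence interval of $r^2$, scaled RMSE and scaled MSE for each method.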

## Analysis examples

```
sos run prediction_score.ipynb \
    --analysis-units ../data/genes.txt \
    --data-dir ../data/cis_eqtl_analysis_ready  \
    --mrmash-model ../output/gtex_mr_mash_analysis/prediction \
    --analysis-stage first_pass \
    --nfolds 5 \
    --data-suffix GTEx_V8 \
    --wd ../output/gtex_pred_score \
    -c midway2.yml -q midway2
```

To include `mt_lasso` results, add `--mtlasso-model /path/to/mtlasso/results` to the command above.

In [1]:
[global]
# Single-column text file, one gene name per line
# (blank lines and lines starting with '#' are ignored)
parameter: analysis_units = path
# Path to data directory
parameter: data_dir = path
# data file suffix
parameter: data_suffix = str
# Path to work directory where output locates
parameter: wd = path("./output")
# Keep only the listed genes whose input file <data_dir>/<gene>.<data_suffix>.rds exists.
# A path object is always truthy, so the existence test has to be explicit (is_file()).
genes = [x.strip() for x in open(analysis_units).readlines() if x.strip() and (not x.strip().startswith('#')) and path(f"{data_dir:a}/{x.strip()}.{data_suffix}.rds").is_file()]

In [None]:
[pred_score]
# Path to mr.mash model fit directory
parameter: mrmash_model = path
# Analysis stage: with "second_pass" the second-pass mr.mash fits are loaded and scored in addition to the first-pass fits
parameter: analysis_stage = "first_pass"
# Path to mtlasso model fit directory, default to a placeholder path called "NULL"
parameter: mtlasso_model = path('NULL')
# Sample size threshold: folds without a fit are filled with one NA per tissue whose sample size exceeds this value
parameter: thresh = 100
# Number of folds whose fits are loaded (fold_1, ..., fold_<nfolds>)
parameter: nfolds = 5
# One substep per entry of `genes`; the current gene is available as `_genes`
input: for_each = "genes"
# One score file per gene
output: f"{wd:a}/{_genes}.{data_suffix}_score.rds"
task:  trunk_workers = 2, trunk_size = 100, walltime = '30m', mem = '2G', cores = 1, tags = f'{step_name}_{_output[0]:bn}'
# R script below; stdout/stderr are written next to the output file
R: expand = '${ }', stdout = f"{_output[0]:n}.stdout", stderr = f"{_output[0]:n}.stderr"

    ###Function to compute accuracy
    ## Column-wise accuracy of the predictions Yhat against the observations Y.
    ##   Y, Yhat: matrices with the same dimensions (rows = samples, columns = responses).
    ## Rows with a missing value in either Y or Yhat are dropped per column.
    ## Returns a list of four vectors named by colnames(Y):
    ##   bias - slope of the regression of Y on Yhat
    ##   r2   - R-squared of that regression
    ##   mse  - mean squared prediction error
    ##   rmse - square root of mse
    ## A column with no complete (Y, Yhat) pair keeps NA in all four vectors.
    compute_accuracy <- function(Y, Yhat) {
      n_resp <- ncol(Y)
      blank <- rep(as.numeric(NA), n_resp)
      names(blank) <- colnames(Y)
      bias <- blank
      r2 <- blank
      mse <- blank
      rmse <- blank

      for(i in seq_len(n_resp)){
        dat <- na.omit(data.frame(Y[, i], Yhat[, i]))
        colnames(dat) <- c("Y", "Yhat")

        ## lm() stops with an error when there are no complete cases;
        ## leave the NA placeholders for this column instead
        if(nrow(dat) == 0){
          next
        }

        fit <- lm(Y ~ Yhat, data=dat)
        bias[i] <- coef(fit)[2]
        r2[i] <- summary(fit)$r.squared
        mse[i] <- mean((dat$Y - dat$Yhat)^2)
        rmse[i] <- sqrt(mse[i])
      }

      list(bias=bias, r2=r2, mse=mse, rmse=rmse)
    }
  
    ###Function to load the data
    ## Reads one RDS file per fold, located at
    ##   <path>fold_<i>/<gene>.<data_suffix>_fold_<i>_<model_suffix>.rds
    ## (path is used as given, so it must already end with a separator).
    ## Returns a list with one element per fold; a file whose content is NULL
    ## is represented by NA.
    load_data <- function(path, nfolds, gene, data_suffix, model_suffix){
      read_fold <- function(fold){
        fold_file <- paste0(path, "fold_", fold, "/", gene, ".", data_suffix, "_fold_", fold, "_", model_suffix, ".rds")
        fit <- readRDS(fold_file)
        if(is.null(fit)) NA else fit
      }
      lapply(1:nfolds, read_fold)
    }
  
    ###Function to compute accuracy of glmnet for all the folds
    ##   dat:         list with one element per fold; each element is either a fit
    ##                object providing Ytest, Yhat_test_glmnet and Ybar_train, or a
    ##                single NA marking a fold without a fit.
    ##   sample_size: data frame whose second column holds the per-tissue sample size.
    ##   thresh:      sample size threshold; a fold without a fit is filled with one
    ##                NA per tissue whose sample size exceeds it.
    ## Returns a list (r2, scaled_mse, scaled_rmse) of per-fold lists. The scaled
    ## errors are the errors of the glmnet predictions divided by the errors obtained
    ## when predicting every test sample with Ybar_train.
    compute_accuracy_glmnet <- function(dat, sample_size, thresh){
      n_folds <- length(dat)
      r2 <- vector("list", n_folds)
      scaled_mse <- vector("list", n_folds)
      scaled_rmse <- vector("list", n_folds)

      ## seq_along() so that an empty list of folds yields empty results
      ## rather than indexing a non-existent first fold
      for(i in seq_along(dat)){
        fold <- dat[[i]]
        if(length(fold) == 1 && is.na(fold)){
          na_fill <- rep(as.numeric(NA), sum(sample_size[, 2] > thresh))
          r2[[i]] <- na_fill
          scaled_mse[[i]] <- na_fill
          scaled_rmse[[i]] <- na_fill
        } else {
          ## Baseline prediction: Ybar_train repeated for every test sample
          baseline <- matrix(fold$Ybar_train, nrow=nrow(fold$Ytest), ncol=ncol(fold$Ytest), byrow=TRUE)
          acc_Yhat_test <- compute_accuracy(fold$Ytest, fold$Yhat_test_glmnet)
          acc_Ybar_train <- compute_accuracy(fold$Ytest, baseline)
          r2[[i]] <- acc_Yhat_test$r2
          scaled_mse[[i]] <- acc_Yhat_test$mse/acc_Ybar_train$mse
          scaled_rmse[[i]] <- acc_Yhat_test$rmse/acc_Ybar_train$rmse
        }
      }
      list(r2=r2, scaled_mse=scaled_mse, scaled_rmse=scaled_rmse)
    }

    ###Function to compute accuracy of mr.mash/mtlasso for all the folds
    ##   dat:         list with one element per fold; each element is either a fit
    ##                object providing Ytest, Yhat_test and Ybar_train, or a single
    ##                NA marking a fold without a fit.
    ##   sample_size: data frame whose second column holds the per-tissue sample size.
    ##   thresh:      sample size threshold; a fold without a fit is filled with one
    ##                NA per tissue whose sample size exceeds it.
    ## Returns a list (r2, scaled_mse, scaled_rmse) of per-fold lists. The scaled
    ## errors are the errors of Yhat_test divided by the errors obtained when
    ## predicting every test sample with Ybar_train.
    compute_accuracy_general <- function(dat, sample_size, thresh){
      n_folds <- length(dat)
      r2 <- vector("list", n_folds)
      scaled_mse <- vector("list", n_folds)
      scaled_rmse <- vector("list", n_folds)

      ## seq_along() so that an empty list of folds yields empty results
      ## rather than indexing a non-existent first fold
      for(i in seq_along(dat)){
        fold <- dat[[i]]
        if(length(fold) == 1 && is.na(fold)){
          na_fill <- rep(as.numeric(NA), sum(sample_size[, 2] > thresh))
          r2[[i]] <- na_fill
          scaled_mse[[i]] <- na_fill
          scaled_rmse[[i]] <- na_fill
        } else {
          ## Baseline prediction: Ybar_train repeated for every test sample
          baseline <- matrix(fold$Ybar_train, nrow=nrow(fold$Ytest), ncol=ncol(fold$Ytest), byrow=TRUE)
          acc_Yhat_test <- compute_accuracy(fold$Ytest, fold$Yhat_test)
          acc_Ybar_train <- compute_accuracy(fold$Ytest, baseline)
          r2[[i]] <- acc_Yhat_test$r2
          scaled_mse[[i]] <- acc_Yhat_test$mse/acc_Ybar_train$mse
          scaled_rmse[[i]] <- acc_Yhat_test$rmse/acc_Ybar_train$rmse
        }
      }
      list(r2=r2, scaled_mse=scaled_mse, scaled_rmse=scaled_rmse)
    }

    thresh <- ${thresh}
    ###Load the data
    ##Per-fold first-pass fits; the same objects are also scored for enet below
    dat_first <- load_data("${mrmash_model:a}/", ${nfolds}, "${_genes}", "${data_suffix}", "batch_1_mrmash.first_pass")

    if("${analysis_stage}" == "second_pass"){
      dat_sec <- load_data("${mrmash_model:a}/", ${nfolds}, "${_genes}", "${data_suffix}", "batch_1_mrmash.second_pass")
    }

    if(${mtlasso_model:r} != "NULL"){
        dat_mtlasso <- load_data("${mtlasso_model:a}/", ${nfolds}, "${_genes}", "${data_suffix}", "mtlasso")
    }

    dat_input <- readRDS("${data_dir:a}/${_genes}.${data_suffix}.rds")

    ###Extract sample size
    ##Number of non-missing observations in each column of y_res
    sample_size <- apply(dat_input$y_res, 2, function(x){sum(!is.na(x))})
    sample_size <- data.frame(tissue=names(sample_size), sample_size)

    ###enet accuracy
    ##Each metric becomes a matrix with one column per fold
    accuracy_enet <- compute_accuracy_glmnet(dat_first, sample_size, thresh)
    r2_enet <- do.call(cbind, accuracy_enet$r2)
    scaled_mse_enet <- do.call(cbind, accuracy_enet$scaled_mse)
    scaled_rmse_enet <- do.call(cbind, accuracy_enet$scaled_rmse)

    ###mr.mash accuracy
    ##First pass
    accuracy_mrmash_first <- compute_accuracy_general(dat_first, sample_size, thresh)
    r2_mrmash_first <- do.call(cbind, accuracy_mrmash_first$r2)
    scaled_mse_mrmash_first <- do.call(cbind, accuracy_mrmash_first$scaled_mse)
    scaled_rmse_mrmash_first <- do.call(cbind, accuracy_mrmash_first$scaled_rmse)
    ##Second pass
    if("${analysis_stage}" == "second_pass"){
      accuracy_mrmash_sec <- compute_accuracy_general(dat_sec, sample_size, thresh)
      r2_mrmash_sec <- do.call(cbind, accuracy_mrmash_sec$r2)
      ##Must be built from the second-pass results: the summaries further down
      ##read scaled_mse_mrmash_sec
      scaled_mse_mrmash_sec <- do.call(cbind, accuracy_mrmash_sec$scaled_mse)
      scaled_rmse_mrmash_sec <- do.call(cbind, accuracy_mrmash_sec$scaled_rmse)
    }

    ###mtlasso accuracy
    if(${mtlasso_model:r} != "NULL"){
      ##A fold without a fit is stored as a single NA
      is_missing_fold <- function(x){length(x) == 1 && is.na(x)}

      ##Add Ybar_train to mtlasso object
      for(i in seq_along(dat_mtlasso)){
        if(is_missing_fold(dat_mtlasso[[i]])){
          ##Keep the NA placeholder; adding a field would hide it from the NA check
          next
        }
        if(is_missing_fold(dat_first[[i]])){
          ##No Ybar_train baseline is available for this fold, so it cannot be
          ##scaled; treat it as a fold without a fit
          dat_mtlasso[[i]] <- NA
          next
        }
        dat_mtlasso[[i]]$Ybar_train <- dat_first[[i]]$Ybar_train
      }

      accuracy_mtlasso <- compute_accuracy_general(dat_mtlasso, sample_size, thresh)
      r2_mtlasso <- do.call(cbind, accuracy_mtlasso$r2)
      scaled_mse_mtlasso <- do.call(cbind, accuracy_mtlasso$scaled_mse)
      scaled_rmse_mtlasso <- do.call(cbind, accuracy_mtlasso$scaled_rmse)
    }

    ###Combined accuracy
    ##Each input is a matrix with one row per tissue and one column per fold.
    ##The mean and its standard error are taken across folds, per tissue.

    ##Standard error of the across-fold mean: sd / sqrt(number of finite values)
    row_se <- function(x){
      n_finite <- apply(x, 1, function(v){sum(is.finite(v))})
      matrixStats::rowSds(x, na.rm=TRUE)/sqrt(n_finite)
    }

    mean_r2 <- cbind(enet=rowMeans(r2_enet, na.rm=TRUE), mrmash_first=rowMeans(r2_mrmash_first, na.rm=TRUE))
    se_r2 <- cbind(enet=row_se(r2_enet), mrmash_first=row_se(r2_mrmash_first))

    mean_scaled_rmse <- cbind(enet=rowMeans(scaled_rmse_enet, na.rm=TRUE), mrmash_first=rowMeans(scaled_rmse_mrmash_first, na.rm=TRUE))
    se_scaled_rmse <- cbind(enet=row_se(scaled_rmse_enet), mrmash_first=row_se(scaled_rmse_mrmash_first))

    mean_scaled_mse <- cbind(enet=rowMeans(scaled_mse_enet, na.rm=TRUE), mrmash_first=rowMeans(scaled_mse_mrmash_first, na.rm=TRUE))
    se_scaled_mse <- cbind(enet=row_se(scaled_mse_enet), mrmash_first=row_se(scaled_mse_mrmash_first))

    if("${analysis_stage}" == "second_pass"){
      mean_r2 <- cbind(mean_r2, mrmash_second=rowMeans(r2_mrmash_sec, na.rm=TRUE))
      se_r2 <- cbind(se_r2, mrmash_second=row_se(r2_mrmash_sec))

      mean_scaled_rmse <- cbind(mean_scaled_rmse, mrmash_second=rowMeans(scaled_rmse_mrmash_sec, na.rm=TRUE))
      se_scaled_rmse <- cbind(se_scaled_rmse, mrmash_second=row_se(scaled_rmse_mrmash_sec))

      mean_scaled_mse <- cbind(mean_scaled_mse, mrmash_second=rowMeans(scaled_mse_mrmash_sec, na.rm=TRUE))
      se_scaled_mse <- cbind(se_scaled_mse, mrmash_second=row_se(scaled_mse_mrmash_sec))
    }

    if(${mtlasso_model:r} != "NULL"){
      mean_r2 <- cbind(mean_r2, mtlasso=rowMeans(r2_mtlasso, na.rm=TRUE))
      se_r2 <- cbind(se_r2, mtlasso=row_se(r2_mtlasso))

      mean_scaled_rmse <- cbind(mean_scaled_rmse, mtlasso=rowMeans(scaled_rmse_mtlasso, na.rm=TRUE))
      se_scaled_rmse <- cbind(se_scaled_rmse, mtlasso=row_se(scaled_rmse_mtlasso))

      mean_scaled_mse <- cbind(mean_scaled_mse, mtlasso=rowMeans(scaled_mse_mtlasso, na.rm=TRUE))
      se_scaled_mse <- cbind(se_scaled_mse, mtlasso=row_se(scaled_mse_mtlasso))
    }

    ###Summary tables, confidence intervals and output
    ##Methods to report: name = suffix used in the result list,
    ##value = column name in the combined mean/se tables
    method_cols <- c(enet="enet", mrmash_first="mrmash_first")
    if("${analysis_stage}" == "second_pass"){
      ##The combined tables name this column mrmash_second; look it up by its
      ##exact name rather than relying on partial matching of mrmash_sec
      method_cols <- c(method_cols, mrmash_sec="mrmash_second")
    }
    if(${mtlasso_model:r} != "NULL"){
      method_cols <- c(method_cols, mtlasso="mtlasso")
    }

    ##Turn a tissue x method matrix into a data frame with a tissue column
    ##and the per-tissue sample size merged in
    with_sample_size <- function(x){
      merge(data.frame(tissue=rownames(x), x), sample_size, by="tissue", sort=FALSE, all.x=TRUE)
    }

    ##Add the tables for one metric to res: mean_<metric>, se_<metric> and one
    ##ci_mean_<metric>_<method> matrix (mean -/+ 2 * se) per method.
    ##res is returned unchanged when every mean of the metric is NaN.
    add_metric <- function(res, metric, mean_mat, se_mat){
      if(all(is.nan(mean_mat))){
        return(res)
      }

      mean_tab <- with_sample_size(mean_mat)
      se_tab <- with_sample_size(se_mat)

      res[[paste0("mean_", metric)]] <- mean_tab
      res[[paste0("se_", metric)]] <- se_tab

      for(key in names(method_cols)){
        column <- method_cols[[key]]
        res[[paste0("ci_mean_", metric, "_", key)]] <- cbind(lower=mean_tab[[column]]-2*se_tab[[column]],
                                                            upper=mean_tab[[column]]+2*se_tab[[column]])
      }

      res
    }

    res <- list()
    res <- add_metric(res, "r2", mean_r2, se_r2)
    res <- add_metric(res, "scaled_rmse", mean_scaled_rmse, se_scaled_rmse)
    res <- add_metric(res, "scaled_mse", mean_scaled_mse, se_scaled_mse)

    ##Nothing could be computed for any metric: save NULL
    if(length(res) == 0){
      res <- NULL
    }

    saveRDS(res, ${_output:ar})