# Prediction score evaluations

This notebook implements prediction score evaluations. The script is customized for a particular project and contains a number of hard-coded structures (for example, five cross-validation folds and fixed file-name patterns). Other users are advised to read and customize the script for their own applications.

## Input

Model fits and the original input data, both in RDS format. The analysis units (gene names) to evaluate are listed in a text file, one per line (see `--analysis-units` in the example analysis).

## Output

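One RDS file per gene containing the prediction score metrics ($r^2$ and scaled RMSE): their mean and standard error across folds, and approximate confidence intervals (mean ± 2 standard errors) for each method.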

## Analysis examples

```
sos run prediction_score.ipynb \
    --analysis-units ../data/genes.txt \
    --data-dir ../data/cis_eqtl_analysis_ready  \
    --mrmash-model ../output/gtex_mr_mash_analysis/prediction \
    --data-suffix GTEx_V8 \
    --wd ../output/gtex_pred_score \
    -c midway2.yml -q midway2
```

To include `mt_lasso` results, add `--mtlasso-model /path/to/mtlasso/results` to the command above.

In [1]:
[global]
# Single-column text file, one gene name per line; blank lines and lines starting with '#' are ignored
parameter: analysis_units = path
# Path to data directory
parameter: data_dir = path
# Data file suffix
parameter: data_suffix = str
# Path to work directory where output is written
parameter: wd = path("./output")
# Keep only the genes whose input file <data_dir>/<name>.<data_suffix>.rds exists.
# Each line is stripped once and the stripped name is used for BOTH the stored
# value and the file-existence check; the raw line still carries its trailing
# newline, which would otherwise end up inside the path being tested.
genes = [name for name in (line.strip() for line in open(analysis_units).readlines()) if name and not name.startswith('#') and path(f"{data_dir:a}/{name}.{data_suffix}.rds").exists()]
fail_if(len(genes) == 0, msg = f"Cannot find valid input data with pattern {data_dir:a}/<name>.{data_suffix}.rds")

In [None]:
[pred_score]
# Path to mr.mash model fit directory
parameter: mrmash_model = path
# Path to mtlasso model fit directory, default to a placeholder path called "NULL"
parameter: mtlasso_model = path('NULL')
# Sample size threshold: when a fold has no model fit, the number of tissues with more than `thresh` non-missing samples is used as the length of the NA placeholder for that fold
parameter: thresh = 100
# Run one substep per entry of `genes`; the current gene is available as `_genes`
input: for_each = "genes"
output: f"{wd:a}/{_genes}.{data_suffix}_score.rds"
task:  trunk_workers = 4, trunk_size = 8, walltime = '1h', mem = '5G', cores = 1, tags = f'{step_name}_{_output[0]:bn}'
R: expand = '${ }', stdout = f"{_output[0]:n}.stdout", stderr = f"{_output[0]:n}.stderr"

    ###Function to compute accuracy
    # For every column of Y, regress the observed values on the predicted
    # values (Y ~ Yhat) using only rows where both are non-missing.
    #
    # Arguments:
    #   Y    - matrix of observed values, one column per response.
    #   Yhat - matrix of predicted values with the same layout as Y.
    # Returns a list of three numeric vectors named after colnames(Y):
    #   bias        - slope of the regression of Y on Yhat.
    #   r2          - R-squared of that regression.
    #   scaled_rmse - RMSE of the predictions divided by sd(Y).
    # Columns with fewer than two complete pairs keep NA for all three metrics
    # (lm() would fail on zero rows and sd() is undefined for a single value).
    compute_accuracy <- function(Y, Yhat) {
      n_resp <- ncol(Y)
      blank <- rep(NA_real_, n_resp)
      names(blank) <- colnames(Y)
      bias <- blank
      r2 <- blank
      scaled_rmse <- blank

      # seq_len() iterates zero times when Y has no columns (1:0 would not).
      for (i in seq_len(n_resp)) {
        dat <- na.omit(data.frame(Y = Y[, i], Yhat = Yhat[, i]))
        if (nrow(dat) < 2) {
          next
        }

        fit <- lm(Y ~ Yhat, data = dat)
        bias[i] <- coef(fit)[2]
        r2[i] <- summary(fit)$r.squared
        scaled_rmse[i] <- sqrt(mean((dat$Y - dat$Yhat)^2)) / sd(dat$Y)
      }

      list(bias = bias, r2 = r2, scaled_rmse = scaled_rmse)
    }

    # Main evaluation script: load the per-fold fits of one gene, compute the
    # per-fold accuracy of each method (enet, mr.mash and, optionally, mtlasso),
    # then summarise across folds and save the result.
    thresh <- ${thresh}
    # Number of cross-validation folds; fits are read from fold_1 ... fold_<n_folds>.
    n_folds <- 5
    # TRUE when a real mtlasso directory was supplied instead of the "NULL"
    # placeholder. This flag is tested in several places below and must be
    # defined up front (it previously was not, which stopped every run).
    mtlasso_pred <- ${mtlasso_model:r} != "NULL"

    ###Load the data
    # Read the fit of one method for one fold of the current gene.
    read_fold_fit <- function(model_dir, fold, method) {
        readRDS(paste0(model_dir, "/fold_", fold, "/${_genes}.${data_suffix}_fold_", fold,
                       "_batch_1_", method, ".rds"))
    }

    fold_ids <- seq_len(n_folds)
    dat_first <- lapply(fold_ids, read_fold_fit, model_dir = "${mrmash_model:a}", method = "mrmash.first_pass")
    dat_sec <- lapply(fold_ids, read_fold_fit, model_dir = "${mrmash_model:a}", method = "mrmash.second_pass")
    if (mtlasso_pred) {
        dat_mtlasso <- lapply(fold_ids, read_fold_fit, model_dir = "${mtlasso_model:a}", method = "mtlasso")
    }

    dat_input <- readRDS("${data_dir:a}/${_genes}.${data_suffix}.rds")

    # Number of non-missing samples per column (tissue) of the input responses.
    sample_size <- apply(dat_input$y_res, 2, function(x){sum(!is.na(x))})
    sample_size <- data.frame(tissue=names(sample_size), sample_size)
    # Placeholder length used only when no fold at all has a usable fit.
    n_expected <- sum(sample_size[, 2] > thresh)

    # A stored fit may be NULL; such folds contribute NA columns.
    has_first <- !vapply(dat_first, is.null, logical(1))
    has_sec <- !vapply(dat_sec, is.null, logical(1))

    # Compute r2 and scaled RMSE for every usable fold of one method.
    #   fits      - list of per-fold fits.
    #   usable    - logical vector, TRUE for folds to evaluate.
    #   yhat_name - name of the element of each fit holding the test predictions.
    # Returns two matrices (rows = tissues, columns = folds); unusable folds
    # are NA columns whose length matches the evaluated folds, so cbind()
    # never has to recycle.
    fold_accuracy <- function(fits, usable, yhat_name) {
        r2 <- vector("list", length(fits))
        scaled_rmse <- vector("list", length(fits))
        for (k in which(usable)) {
            acc <- compute_accuracy(fits[[k]]$Ytest, fits[[k]][[yhat_name]])
            r2[[k]] <- acc$r2
            scaled_rmse[[k]] <- acc$scaled_rmse
        }

        template <- Find(Negate(is.null), r2)
        n_rows <- if (is.null(template)) n_expected else length(template)
        filler <- rep(as.numeric(NA), n_rows)
        bind_folds <- function(cols) {
            do.call(cbind, lapply(cols, function(v) if (is.null(v)) filler else v))
        }

        list(r2 = bind_folds(r2), scaled_rmse = bind_folds(scaled_rmse))
    }

    ###Per-fold accuracy of each method
    acc_by_method <- list(
        # enet predictions are stored in the first-pass fit
        enet = fold_accuracy(dat_first, has_first, "Yhat_test_glmnet"),
        # mr.mash needs both passes to have succeeded
        mrmash = fold_accuracy(dat_sec, has_first & has_sec, "Yhat_test")
    )
    if (mtlasso_pred) {
        # Also require the mtlasso fit itself to be present, since it is the
        # object being dereferenced.
        has_mtlasso <- !vapply(dat_mtlasso, is.null, logical(1))
        acc_by_method$mtlasso <- fold_accuracy(dat_mtlasso, has_first & has_mtlasso, "Yhat_test")
    }

    ###Combined accuracy
    # Summarise one metric ("r2" or "scaled_rmse") across folds for all methods.
    # Returns NULL when no method has a finite mean for any tissue; otherwise
    # a list with mean_<metric>, se_<metric> (both merged with the sample
    # sizes) and one ci_mean_<metric>_<method> matrix per method.
    summarise_metric <- function(metric) {
        per_method <- lapply(acc_by_method, function(acc) {
            mat <- acc[[metric]]
            n_ok <- rowSums(is.finite(mat))
            # Standard error of the mean: sd / sqrt(n), not sd / n.
            list(mean = rowMeans(mat, na.rm=TRUE),
                 se = matrixStats::rowSds(mat, na.rm=TRUE) / sqrt(n_ok))
        })
        mean_mat <- do.call(cbind, lapply(per_method, function(s) s$mean))
        se_mat <- do.call(cbind, lapply(per_method, function(s) s$se))

        if (all(is.nan(mean_mat))) {
            return(NULL)
        }

        mean_df <- merge(data.frame(tissue=rownames(mean_mat), mean_mat), sample_size,
                         by="tissue", sort=FALSE, all.x=TRUE)
        se_df <- merge(data.frame(tissue=rownames(se_mat), se_mat), sample_size,
                       by="tissue", sort=FALSE, all.x=TRUE)

        # Approximate confidence interval: mean +/- 2 standard errors.
        ci <- lapply(names(acc_by_method), function(m) {
            cbind(lower=mean_df[[m]] - 2*se_df[[m]],
                  upper=mean_df[[m]] + 2*se_df[[m]])
        })
        names(ci) <- paste0("ci_mean_", metric, "_", names(acc_by_method))

        out <- list(mean_df, se_df)
        names(out) <- paste0(c("mean_", "se_"), metric)
        c(out, ci)
    }

    # c() of two NULLs stays NULL, so a gene without any usable fit is saved as NULL.
    res <- c(summarise_metric("r2"), summarise_metric("scaled_rmse"))
    saveRDS(res, ${_output:ar})