# K-fold CV data setup for eQTL / TWAS

This notebook preprocesses association analysis data for use in K-fold cross-validation.

## Input

A list of RDS files, each containing `X` and `Y` for regression; a sample partition file assigning every sample to one of the K folds; and parameters for some initial filtering of the `X` matrix.

## Output

For each analysis unit, an RDS file holding, for every cross-validation fold, the univariate (GWAS-style) summary statistics $\hat{\beta}$ and $se(\hat{\beta})$ computed on that fold's training samples.

## Analysis examples

```
sos run cv_preprocessing.ipynb \
    --analysis-units data/27_brain_non_brain_genes_v8.txt \
    --sample-partition data/samples_folds.txt \
    --data-dir /project2/compbio/GTEx_eQTL/cis_eqtl_analysis_ready \
    --data-suffix GTEx_V8.rds \
    --name 20210512_CV \
    --wd mr_mash_analysis_ready \
    -c midway2.yml -q midway2
```

In [None]:
[global]
import glob
# Text file listing one analysis unit (data file stem) per line;
# blank lines and lines starting with '#' are skipped
parameter: analysis_units = path
# Sample partition file (tab-separated, with header) assigning samples to folds
parameter: sample_partition = path
# Path to data directory
parameter: data_dir = path
# Data file suffix; data files are <data_dir>/<analysis unit>.<data_suffix>
parameter: data_suffix = str
# Path to work directory where output locates
parameter: wd = path("./output")
# An identifier for your run of analysis
parameter: name = str
# Only analyze `cis` variants -- cis = N keeps about N variants around the center column of the (filtered) X matrix
# Default to NULL to analyze all available variants; the value is passed verbatim to R
parameter: cis = 'NULL'
# Read the analysis units, stripping each line once, and close the file promptly
with open(analysis_units) as _units_file:
    regions = [x for x in (line.strip() for line in _units_file)
               if x and not x.startswith('#')]
del _units_file
# Keep only units whose data file exists; units without a data file are skipped silently
genes = [f"{data_dir:a}/{x}.{data_suffix}" for x in regions]
genes = [x for x in genes if path(x).exists()]

In [None]:
[process]
# Filter X and compute per-fold univariate summary statistics for each data file

# Standardize X before computing summary statistics? Must be an R logical
# literal ("TRUE" or "FALSE"): the value is pasted verbatim into the R script
parameter: standardize_X = "TRUE"
# Standardize Y before computing summary statistics? R logical literal, as above
parameter: standardize_Y = "FALSE"
# Number of cores for the task and for the R summary-statistic computation
parameter: nthreads = 1
# Drop X columns whose fraction of missing entries exceeds this value
parameter: missing_rate_cutoff = 0.05
# Drop X columns whose minor allele frequency is below this value
parameter: maf_cutoff = 0.05
# Drop X columns whose variance (after mean imputation) is below this value
parameter: var_cutoff = 0.05
# Name of the element of each RDS object that holds the X matrix
parameter: x_table = 'X'
# Name of the element of each RDS object that holds the Y matrix
parameter: y_table = 'y_res'
# One substep per data file
input: genes, group_by = 1
# <wd>/<input basename>[_cis_<N>]_<name>.rds
output: f'{wd:a}/{_input:bn}{("_cis_%s" % cis) if cis != "NULL" else ""}_{name}.rds'
# Each substep runs as a SoS task; resource requests and batching (trunk_*) options for the task queue
task: trunk_workers = 3, trunk_size = 150, walltime = '25m', mem = '6G', cores = nthreads, tags = f'{step_name}_{_output:bn}'
R: expand = '${ }', stdout = f"{_output[0]:n}.stdout", stderr = f"{_output[0]:n}.stderr"
    
    options(stringsAsFactors = FALSE)

    ### Run settings; SoS substitutes each placeholder below with the
    ### matching [process] step parameter before this script runs.
    # Column-filtering thresholds applied to X
    missing_rate_cutoff <- ${missing_rate_cutoff}
    maf_cutoff <- ${maf_cutoff}
    var_cutoff <- ${var_cutoff}
    # Options forwarded to mr.mash.alpha::compute_univariate_sumstats
    standardize <- ${standardize_X}
    standardize_response <- ${standardize_Y}
    nthreads <- ${nthreads}
    ###Functions to compute MAF, missing genotype rate, impute missing, and filter X accordingly 
    compute_maf <- function(geno) {
      ## Minor allele frequency of one genotype column, ignoring NAs.
      ## NOTE(review): halving the mean assumes 0/1/2 allele-count coding.
      allele_freq <- mean(geno, na.rm = TRUE) / 2
      ## Fold onto the minor allele; an all-NA column yields NaN
      pmin(allele_freq, 1 - allele_freq)
    }

    compute_missing <- function(geno) {
      ## Fraction of entries of geno that are NA
      n_na <- sum(is.na(geno))
      n_na / length(geno)
    }

    mean_impute <- function(geno) {
      ## Replace every NA in each column of the matrix geno with the mean of
      ## that column's observed entries. A column with no observed entries
      ## has mean NaN, so its NAs become NaN.
      col_means <- apply(geno, 2, function(x) mean(x, na.rm = TRUE))
      ## seq_along() makes a zero-column matrix a no-op; 1:length() would
      ## iterate over c(1, 0) and index out of bounds
      for (j in seq_along(col_means)) {
        geno[is.na(geno[, j]), j] <- col_means[j]
      }
      geno
    }

    filter_X <- function(X, missing_rate_thresh, maf_thresh, var_thresh) {
      ## Filter the columns of genotype matrix X in three passes:
      ##   1. drop columns whose missing rate exceeds missing_rate_thresh;
      ##   2. drop columns whose MAF is below maf_thresh;
      ##   3. mean-impute remaining NAs, then drop columns whose variance
      ##      is below var_thresh.
      ## Columns whose statistic is NA/NaN are kept (which() ignores NA).
      ## Returns the filtered, imputed matrix.
      drop_cols <- function(M, idx) {
        ## drop = FALSE keeps M a matrix (with its row names) even when only
        ## one column survives, so later passes and callers keep working
        if (length(idx) > 0) M[, -idx, drop = FALSE] else M
      }
      miss_rate <- apply(X, 2, compute_missing)
      X <- drop_cols(X, which(miss_rate > missing_rate_thresh))
      maf <- apply(X, 2, compute_maf)
      X <- drop_cols(X, which(maf < maf_thresh))
      X <- mean_impute(X)
      X <- drop_cols(X, which(matrixStats::colVars(X) < var_thresh))
      X
    }
  
    get_center <- function(k, n) {
      ## Return the indices of k consecutive columns centred on the middle
      ## of 1..n. Returns all of 1..n when k is NULL or k >= n, and no
      ## indices when k <= 0. When k and n differ in parity the window sits
      ## half a column below the exact centre.
      ## (The previous floor(n/2 -/+ k/2) bounds returned k + 1 indices.)
      if (is.null(k) || k >= n) {
        return(seq_len(n))
      }
      k <- as.integer(k)
      if (k <= 0L) {
        return(integer(0))
      }
      start <- as.integer((n - k) %/% 2) + 1L
      start:(start + k - 1L)
    }
  
    ### Read in the data for this analysis unit
    dat <- readRDS(${_input:r})

    ### Sample-to-fold assignment: the first column holds sample IDs and the
    ### `fold` column holds each sample's fold label
    gtex_ids_folds <- read.table(${sample_partition:r}, header = TRUE, sep = "\t")

    ### Fold labels, in sorted order
    folds <- sort(unique(gtex_ids_folds$fold))

    ### One result slot per fold, named fold_<label>
    res_all_folds <- vector("list", length(folds))
    names(res_all_folds) <- paste0("fold_", folds)

    ### Extract Y and (filter) X. drop = FALSE keeps X a matrix even when a
    ### single variant is selected, so its row names survive for the fold
    ### split below
    Y <- dat$${y_table}
    X <- filter_X(dat$${x_table}, missing_rate_cutoff, maf_cutoff, var_cutoff)
    X <- X[, get_center(${cis}, ncol(X)), drop = FALSE]

    for (i in folds) {
      ## Samples assigned to fold i are held out; the rest form the training set
      test_ids <- gtex_ids_folds[which(gtex_ids_folds$fold == i), 1]
      Xtrain <- X[!(rownames(X) %in% test_ids), , drop = FALSE]
      Ytrain <- Y[!(rownames(Y) %in% test_ids), , drop = FALSE]

      univ_sumstats <- mr.mash.alpha::compute_univariate_sumstats(
        X = Xtrain, Y = Ytrain, standardize = standardize,
        standardize.response = standardize_response, mc.cores = nthreads
      )
      ## Store by name: fold labels need not be 1..K, so indexing by the
      ## label value would misplace results (or fail for label 0)
      res_all_folds[[paste0("fold_", i)]] <- univ_sumstats
    }

    saveRDS(res_all_folds, ${_output:r})