# Multivariate EBNM-based prior for M&M

Here, for the simulation benchmark, we prepare a mixture prior based on a multivariate Empirical Bayes Normal Means (EBNM) model (previously we used Extreme Deconvolution for this task).

## Approach

Here is the analysis plan:

1. Identify up to 20K genes with complete phenotype data, so that FLASH can give a good (realistic) residual variance estimate.
2. Simulate 20K data sets under my phenotype models (the latest DSC benchmark setting) and generate summary statistics (`bhat` and `sbhat`) for them.
3. For each data set, take the strongest gene-SNP pair as the strong set.
4. From each data set, also select about 4 "random" gene-SNP pairs.
5. Then run your method to estimate Vhat first, and then run Yunqi / Peter's ED.

In GTEx we have more than 35K genes. We want to try 20K because 20K genes seem to carry enough information for learning the pattern of sharing between conditions.

But we "cheat" a bit: we simulate with identity residual variance for all genes and fit EBNM assuming the residual variance is also identity (or just estimate a single global residual variance). This makes the problem easier, because in practice the residual variance can differ (though perhaps only slightly) across genes.

So the simplified plan is to do only steps 2 to 5, with step 2 using just the identity matrix for the residual variance.
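To make the identity-residual setting concrete, here is a minimal Python sketch (Python because the SoS workflow layer is Python; the analysis steps themselves run in R). It is a hypothetical toy simulator, not the DSC benchmark's phenotype model: genotypes are drawn as Binomial(2, 0.3), one SNP is causal, and residuals are independent N(0, 1) in every condition, i.e. identity residual covariance. Per-SNP, per-condition simple linear regression then gives `bhat` and `sbhat` laid out as SNPs by conditions, which is how the extraction step indexes them.

```python
import numpy as np

def simulate_sumstats(n=500, p=50, r=5, n_causal=1, effect_sd=0.5, seed=999):
    """Hypothetical toy simulator with identity residual covariance across r conditions.

    Y = X B + E, rows of E ~ N(0, I_r); marginal regression of each condition on
    each SNP gives bhat and sbhat, both of shape (p, r).
    """
    rng = np.random.default_rng(seed)
    X = rng.binomial(2, 0.3, size=(n, p)).astype(float)   # genotypes coded 0/1/2
    X -= X.mean(axis=0)                                    # center each SNP (removes intercept)
    B = np.zeros((p, r))
    causal = rng.choice(p, size=n_causal, replace=False)
    B[causal] = rng.normal(0.0, effect_sd, size=(n_causal, r))
    E = rng.standard_normal((n, r))                        # identity residual covariance
    Y = X @ B + E
    Y -= Y.mean(axis=0)
    xtx = (X ** 2).sum(axis=0)                             # x_j' x_j, shape (p,)
    bhat = (X.T @ Y) / xtx[:, None]                        # marginal OLS slopes, shape (p, r)
    # residual sum of squares of each marginal fit: y'y - bhat^2 * x'x
    resid_ss = (Y ** 2).sum(axis=0)[None, :] - bhat ** 2 * xtx[:, None]
    sbhat = np.sqrt(resid_ss / (n - 2) / xtx[:, None])     # n - 2 degrees of freedom
    return bhat, sbhat, causal
```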

## Workflow

In [None]:
[global]
parameter: cwd = path('/project2/mstephens/gaow/mvarbvs/dsc/mnm_prototype/mnm_sumstats')
parameter: model = 'artificial_mixture_identity' # 'gtex_mixture_identity'
parameter: per_chunk = 200
import glob

In [1]:
%cd /project2/mstephens/gaow/mvarbvs/dsc/mnm_prototype/mnm_sumstats

### Get top gene-SNP and random gene-SNP pairs per gene
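The selection for one data set can be sketched in Python as follows; the function name is hypothetical and indices are 0-based, unlike R. The detail worth getting right is converting the column-major linear index returned by `which.max` back to (row, column): for a maximum at row 3, column 2 of a 3 x 2 matrix, `which.max` returns 6, and naive `6 %% 3` gives row 0 rather than 3. `np.unravel_index(..., order="F")`, like `arrayInd` in R, does this conversion safely. Also note that `rng.choice` samples from the candidate array itself even when only one row remains, whereas R's `sample(x, ...)` with a single number `x` samples from `1:x`.

```python
import numpy as np

def strongest_and_random(bhat, sbhat, n_random=4, seed=999):
    """Sketch: strongest row by |bhat/sbhat| plus n_random other rows (with replacement)."""
    z = np.abs(bhat / sbhat)
    # Emulate R's which.max on a column-major (Fortran-order) flattening, then
    # map the linear index back to (row, col).
    lin = np.argmax(z.ravel(order="F"))
    row, col = np.unravel_index(lin, z.shape, order="F")
    others = np.delete(np.arange(z.shape[0]), row)   # every row except the strongest
    rng = np.random.default_rng(seed)
    random_idx = rng.choice(others, size=n_random, replace=True)
    strong = {"bhat": bhat[[row]], "sbhat": sbhat[[row]]}   # keep a 2-D (1 x r) shape
    random = {"bhat": bhat[random_idx], "sbhat": sbhat[random_idx]}
    return strong, random, (int(row), int(col))
```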

In [31]:
# extract data for MASH from summary stats
[extract_1]
parameter: seed = 999
parameter: n_random = 4
input: glob.glob(f'{cwd}/{model}/*.rds'), group_by = per_chunk
output: f"{cwd}/{model}/cache/{model}_{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    set.seed(${seed})
    matxMax <- function(mtx) {
        # which.max returns a column-major linear index; arrayInd maps it to a 1 x 2 (row, col) matrix
        return(arrayInd(which.max(mtx), dim(mtx)))
    }
    remove_rownames = function(x) {
        for (name in names(x)) rownames(x[[name]]) = NULL
        return(x)
    }
    extract_one_data = function(infile, n_random) {
        # If the input cannot be read for some reason, skip it; losing one data set is acceptable.
        dat = tryCatch(readRDS(infile)$sumstats, error = function(e) return(NULL))
        if (is.null(dat)) return(NULL)
        z = abs(dat$bhat/dat$sbhat)
        max_idx = matxMax(z)
        strong = list(bhat = dat$bhat[max_idx[1],,drop=F], sbhat = dat$sbhat[max_idx[1],,drop=F])
        # candidate rows for the "random" set: every row except the strongest one
        sample_idx = seq_len(nrow(z))[-max_idx[1]]
        if (length(sample_idx) == 0) return(NULL)
        # index via sample.int so that a single candidate row is not expanded to 1:x by sample()
        random_idx = sample_idx[sample.int(length(sample_idx), n_random, replace = TRUE)]
        random = list(bhat = dat$bhat[random_idx,,drop=F], sbhat = dat$sbhat[random_idx,,drop=F])
        return(list(random = remove_rownames(random),  strong = remove_rownames(strong)))
    }
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else if (is.null(one_data)) {
          return(res)
      } else {
          for (d in names(one_data)) {
              for (s in names(one_data[[d]])) {
                  res[[d]][[s]] = rbind(res[[d]][[s]], one_data[[d]][[s]])
              }
          }
          return(res)
      }
    }
    res = list()
    for (f in c(${_input:r,})) {
      res = merge_data(res, extract_one_data(f, ${n_random}))
    }
    saveRDS(res, ${_output:r})
  
[extract_2]
input: group_by = "all"
output: f"{cwd}/{model}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else {
          for (d in names(one_data)) {
              for (s in names(one_data[[d]])) {
                  res[[d]][[s]] = rbind(res[[d]][[s]], one_data[[d]][[s]])
              }
          }
          return(res)
      }
    }
    dat = list()
    for (f in c(${_input:r,})) {
      dat = merge_data(dat, readRDS(f))
    }
    # make output consistent in format with 
    # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb
    saveRDS(
          list(random.z = dat$random$bhat/dat$random$sbhat,
           strong.z = dat$strong$bhat/dat$strong$sbhat, 
           random.b = dat$random$bhat,
           strong.b = dat$strong$bhat,
           random.s = dat$random$sbhat,
           strong.s = dat$strong$sbhat),
          ${_output:r})

To run it:

```
for m in artificial_mixture_identity gtex_mixture_identity; do 
    sos run analysis/20200502_Prepare_ED_prior.ipynb extract --model $m -c midway2.yml -q midway2
done
```
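The `extract_2` merge and its output layout can be mirrored in Python as a rough sketch (hypothetical names; dictionaries stand in for R lists): row-bind the per-chunk matrices, skip empty chunks, then write `bhat`, `sbhat`, and their ratio `z` for both sets under the `random.*` / `strong.*` keys used by the gtexresults format.

```python
import numpy as np

def merge_chunks(chunks):
    """Row-bind per-chunk {'strong': {...}, 'random': {...}} dicts and build
    the random.z / strong.z / *.b / *.s layout."""
    merged = {}
    for chunk in chunks:
        if not chunk:  # an empty chunk (e.g. every file unreadable) contributes nothing
            continue
        for d, mats in chunk.items():
            for s, m in mats.items():
                merged.setdefault(d, {}).setdefault(s, []).append(m)
    stacked = {d: {s: np.vstack(ms) for s, ms in mats.items()} for d, mats in merged.items()}
    out = {}
    for d in ("random", "strong"):
        out[f"{d}.b"] = stacked[d]["bhat"]
        out[f"{d}.s"] = stacked[d]["sbhat"]
        out[f"{d}.z"] = stacked[d]["bhat"] / stacked[d]["sbhat"]
    return out
```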

## Run extreme deconvolution using `mashr`

Before this, we need to run the following to generate the FLASH mixture components:

```
for m in artificial_mixture_identity gtex_mixture_identity; do
sos run ~/GIT/gtexresults/workflows/mashr_flashr_workflow.ipynb flash \
    --cwd /project2/mstephens/gaow/mvarbvs/dsc/mnm_prototype/mnm_sumstats/ \
    --data /project2/mstephens/gaow/mvarbvs/dsc/mnm_prototype/mnm_sumstats/$m.rds \
    --effect-model EE
done
```

We use the `simple` method (`estimate_null_correlation_simple` in `mashr`) to estimate the residual correlation, as implemented in the pipeline below.
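For reference, `estimate_null_correlation_simple` is, as far as we know, a threshold-based estimator: keep rows whose largest |z| across conditions is below `z_thresh` (default 2), take the sample covariance of those z-scores, and convert it to a correlation. A Python sketch of that idea follows; treat it as an approximation of the `mashr` code, not a copy of it.

```python
import numpy as np

def null_correlation_simple(bhat, sbhat, z_thresh=2.0, est_cor=True):
    """Sketch of a 'simple' null correlation estimate from null-like rows."""
    z = bhat / sbhat
    nullish = np.abs(z).max(axis=1) < z_thresh   # rows with no strong signal in any condition
    if nullish.sum() < z.shape[1]:
        raise ValueError("not enough null-like rows to estimate the null correlation")
    V = np.cov(z[nullish], rowvar=False)          # sample covariance (n - 1 denominator)
    if est_cor:
        d = np.sqrt(np.diag(V))
        V = V / np.outer(d, d)                    # same as R's cov2cor()
    return V
```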

In [None]:
[prior]
depends: R_library("mashr")
parameter: npc = 3
input: f"{cwd}/{model}.rds", f"{cwd}/{model}.EE.flash.rds"
output: f"{cwd}/{model}.FLASH_PC{npc}_ED.rds"
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '4G', cores = 4, tags = f'{_output:bn}'
R: expand = "${ }", workdir = cwd, stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout"
    library(mashr)
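    dat = readRDS(${_input[0]:r})
    # estimate the null (residual) correlation from the random set and pass it to the strong set
    vhat = estimate_null_correlation_simple(mash_set_data(dat$random.b, Shat=dat$random.s, alpha=0))
    mash_data = mash_set_data(dat$strong.b, Shat=dat$strong.s, V=vhat, alpha=0)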
    # FLASH matrices
    U.flash = readRDS(${_input[1]:r})
    # SVD matrices
    U.pca = ${"cov_pca(mash_data, %s)" % npc if npc > 0 else "list()"}
    # Empirical covariance matrix
    X.center = apply(mash_data$Bhat, 2, function(x) x - mean(x))
    # Denoised data-driven matrices
    res = bovy_wrapper(mash_data, c(U.flash, U.pca, list("XX" = t(X.center) %*% X.center / nrow(X.center))), logfile=${_output:nr})
    saveRDS(res, ${_output:r})

In [None]:
```
sos run analysis/20200502_Prepare_ED_prior.ipynb prior --model artificial_mixture_identity -c midway2.yml -q midway2
sos run analysis/20200502_Prepare_ED_prior.ipynb prior --model gtex_mixture_identity -c midway2.yml -q midway2
```
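The `prior` step initializes ED from the FLASH, PCA, and empirical `XX` matrices. As a hedged illustration of what one iteration does, here is a textbook zero-mean extreme deconvolution EM update (Bovy et al. 2011) in Python. Each observed row gets error covariance `S_j V S_j` with `S_j = diag(Shat_j)`, and `V` plays the role of `vhat`; whether `bovy_wrapper` builds the per-row error covariance exactly this way is an assumption, and the function names are hypothetical.

```python
import numpy as np

def _logmvn0(x, cov):
    """log N(x; 0, cov) via a Cholesky factor."""
    L = np.linalg.cholesky(cov)
    sol = np.linalg.solve(L, x)
    return -0.5 * (sol @ sol) - np.log(np.diag(L)).sum() - 0.5 * len(x) * np.log(2 * np.pi)

def ed_em_step(Bhat, Shat, V, Ulist, pi):
    """One EM update for b_j ~ sum_k pi_k N(0, U_k), Bhat_j | b_j ~ N(b_j, S_j V S_j).

    Returns updated (Ulist, pi) and the log-likelihood at the *input* parameters.
    """
    J, R = Bhat.shape
    K = len(Ulist)
    logr = np.empty((J, K))
    post_b = np.empty((J, K, R))
    post_B = np.empty((J, K, R, R))
    for j in range(J):
        Vj = Shat[j][:, None] * V * Shat[j][None, :]   # S_j V S_j
        for k, U in enumerate(Ulist):
            T = U + Vj                                   # marginal covariance of Bhat_j under k
            logr[j, k] = np.log(pi[k]) + _logmvn0(Bhat[j], T)
            UTinv = np.linalg.solve(T, U).T              # U T^{-1} (U and T are symmetric)
            post_b[j, k] = UTinv @ Bhat[j]               # posterior mean of b_j given k
            post_B[j, k] = U - UTinv @ U                 # posterior covariance of b_j given k
    m = logr.max(axis=1, keepdims=True)
    loglik = float((m[:, 0] + np.log(np.exp(logr - m).sum(axis=1))).sum())
    r = np.exp(logr - m)
    r /= r.sum(axis=1, keepdims=True)                    # responsibilities
    new_pi = r.mean(axis=0)
    new_U = []
    for k in range(K):
        w = r[:, k]
        second_moment = np.einsum("j,ja,jb->ab", w, post_b[:, k], post_b[:, k])
        second_moment += np.einsum("j,jab->ab", w, post_B[:, k])
        new_U.append(second_moment / w.sum())
    return new_U, new_pi, loglik
```

A usage sketch: start from `XX = Xc.T @ Xc / n` (the same centered cross-product as `X.center` in the R step) plus an identity matrix, iterate, and check that the log-likelihood never decreases, which EM guarantees.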