# Multivariate EBNM-based prior for M&M

Here, for the simulation benchmark, we prepare a mixture prior based on a multivariate Empirical Bayes Normal Means (EBNM) model (previously we used Extreme Deconvolution for this task).

## Approach

Here is the analysis plan:

1. Identify up to 20K genes with complete phenotype data, to obtain a good, realistic residual variance estimate via FLASH.
2. Simulate 20K data-sets under my phenotypic models (the latest DSC benchmark setting) and generate summary statistics for them: `bhat` and `sbhat`.
3. For each data-set, take the strongest gene-SNP pair as the strong set.
4. Also select from each data-set perhaps 4 "random" gene-SNP pairs.
5. Run your procedure for estimating Vhat first to obtain Vhat, then run Yunqi / Peter's ED.

In GTEx we have >35K genes. We want to try 20K because 20K genes seem to carry enough information to learn the pattern of sharing between conditions.

But we "cheat" a bit by simulating under identity residual variance for all genes, and by fitting EBNM assuming the residual variance is also identity (or by just estimating a global residual variance). This makes the problem easier, because in practice the residual variance can differ (though it may be similar!) between genes.

So the simplified plan is to do only steps 2 to 5, with step 2 using just the identity matrix for residual variance.
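
For reference, one common way to write a multivariate normal means model with a mixture prior is sketched below. The notation ($R$ conditions, $K$ mixture components, weights $\pi_k$, prior covariances $U_k$) and the mash-style noise covariance $S_j V S_j$ are assumptions made for illustration; they are not taken from the workflow itself.

$$
\hat{b}_j \mid b_j \sim N_R\!\left(b_j,\; S_j V S_j\right), \qquad
b_j \sim \sum_{k=1}^{K} \pi_k \, N_R\!\left(0,\; U_k\right), \qquad
S_j = \mathrm{diag}(\hat{s}_j)
$$

Here $\hat{b}_j$ and $\hat{s}_j$ would correspond to one row of `bhat` and `sbhat` for gene-SNP pair $j$. Under the simplification above, $V = I$, so the noise covariance reduces to the diagonal matrix $S_j^2$ and what remains to be estimated is the mixture prior ($\pi_k$, $U_k$).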

## Workflow

In [None]:
[global]
parameter: cwd = path('/project2/mstephens/gaow/mvarbvs/dsc/mnm_prototype/mnm_sumstats')
parameter: model = 'artificial_mixture_identity' # 'gtex_mixture_identity'
parameter: per_chunk = 200
import glob

In [1]:
%cd /project2/mstephens/gaow/mvarbvs/dsc/mnm_prototype/mnm_sumstats

/project2/mstephens/gaow/mvarbvs/dsc/mnm_prototype/mnm_sumstats

### Get top gene-SNP and random gene-SNP pairs per gene

In [31]:
# extract data for MASH from summary stats
[1]
parameter: seed = 999
parameter: n_random = 4
input: glob.glob(f'{cwd}/{model}/*.rds'), group_by = per_chunk
output: f"{cwd}/{model}/cache/{model}_{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    set.seed(${seed})
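    matxMax <- function(mtx) {
        # which.max returns a 1-based, column-major linear index;
        # shift by one before dividing so the last row maps to nrow(mtx), not 0
        max_idx <- which.max(mtx)
        colmn <- (max_idx - 1) %/% nrow(mtx) + 1
        row <- (max_idx - 1) %% nrow(mtx) + 1
        return(matrix(c(row, colmn), 1))
    }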
    remove_rownames = function(x) {
        for (name in names(x)) rownames(x[[name]]) = NULL
        return(x)
    }
    extract_one_data = function(infile, n_random) {
        # If the input cannot be read for some reason, skip it; losing one data-set is acceptable.
        dat = tryCatch(readRDS(infile)$sumstats, error = function(e) return(NULL))
        if (is.null(dat)) return(NULL)
        z = abs(dat$bhat/dat$sbhat)
        max_idx = matxMax(z)
        strong = list(bhat = dat$bhat[max_idx[1],,drop=F], sbhat = dat$sbhat[max_idx[1],,drop=F])
        if (max_idx[1] == 1) {
            sample_idx = 2:nrow(z)
        } else if (max_idx[1] == nrow(z)) {
            sample_idx = 1:(max_idx[1]-1)
        } else {
            sample_idx = c(1:(max_idx[1]-1), (max_idx[1]+1):nrow(z))
        }
        random_idx = sample(sample_idx, n_random, replace = T)
        random = list(bhat = dat$bhat[random_idx,,drop=F], sbhat = dat$sbhat[random_idx,,drop=F])
        return(list(random = remove_rownames(random),  strong = remove_rownames(strong)))
    }
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else if (is.null(one_data)) {
          return(res)
      } else {
          for (d in names(one_data)) {
              for (s in names(one_data[[d]])) {
                  res[[d]][[s]] = rbind(res[[d]][[s]], one_data[[d]][[s]])
              }
          }
          return(res)
      }
    }
    res = list()
    for (f in c(${_input:r,})) {
      res = merge_data(res, extract_one_data(f, ${n_random}))
    }
    saveRDS(res, ${_output:r})
  
[2]
input: group_by = "all"
output: f"{cwd}/{model}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else {
          for (d in names(one_data)) {
              for (s in names(one_data[[d]])) {
                  res[[d]][[s]] = rbind(res[[d]][[s]], one_data[[d]][[s]])
              }
          }
          return(res)
      }
    }
    res = list()
    for (f in c(${_input:r,})) {
      res = merge_data(res, readRDS(f))
    }
    saveRDS(res, ${_output:r})
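
The selection logic of the first step can be sketched outside R as follows. This is a minimal illustration in plain Python, not the workflow's code: `linear_to_row_col`, `select_pairs` and the toy data are hypothetical names and values, and the sketch assumes each data-set has at least two SNPs so that there is something left to sample once the strongest row is removed.

```python
import random

def linear_to_row_col(k, n_rows):
    """Map a 1-based, column-major linear index (the convention of R's
    which.max on a matrix) to a 1-based (row, col) pair."""
    return (k - 1) % n_rows + 1, (k - 1) // n_rows + 1

def select_pairs(bhat, sbhat, n_random, rng):
    """Pick the strongest row and n_random other rows (0-based indices).

    bhat, sbhat: list of rows (SNPs), each a list over conditions.
    """
    # Largest |z| = |bhat / sbhat| within each row
    row_max = [max(abs(b / s) for b, s in zip(br, sr))
               for br, sr in zip(bhat, sbhat)]
    strong = row_max.index(max(row_max))
    # Candidate rows exclude the strongest one; sampling is with replacement
    others = [i for i in range(len(bhat)) if i != strong]
    return strong, [rng.choice(others) for _ in range(n_random)]

# Hypothetical toy data: 3 SNPs by 2 conditions
bhat = [[0.1, 0.2], [2.0, -3.0], [0.5, 0.1]]
sbhat = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
strong, random_rows = select_pairs(bhat, sbhat, 4, random.Random(999))
```

The index helper shows why the linear index has to be shifted by one before taking the remainder: without the shift, a maximum in the last row of a column would map to row 0.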

To run it:

```
for m in artificial_mixture_identity gtex_mixture_identity; do 
    sos run analysis/20200502_Prepare_ED_prior.ipynb --model $m -c midway2.yml -q midway2
done
```
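
For orientation, the sketch below mimics how the first step would split its inputs into chunks and name one cache file per chunk. It assumes that `group_by = per_chunk` forms consecutive groups of `per_chunk` files and that `_index` counts those groups from 0; `chunk_outputs`, the file names and the `/tmp/work` directory are hypothetical and only serve the illustration.

```python
def chunk_outputs(files, per_chunk, cwd, model):
    """Split files into consecutive groups of per_chunk and name one
    cache file per group, numbered from 1."""
    groups = [files[i:i + per_chunk] for i in range(0, len(files), per_chunk)]
    return [(f"{cwd}/{model}/cache/{model}_{index + 1}.rds", group)
            for index, group in enumerate(groups)]

# Hypothetical file names, for illustration only
files = [f"sim_{i}.rds" for i in range(1, 6)]
plan = chunk_outputs(files, 2, "/tmp/work", "artificial_mixture_identity")
```

Under these assumptions, 20K input files with `per_chunk = 200` would give 100 chunk-level tasks per model, whose outputs the second step then merges into a single file.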