# Data munging for multi-variant summary stats

## Minimal working example

To see the input requirements and output data formats, please [download a minimal working example here](https://drive.google.com/file/d/1838xUOQuWTszQ0WJGXNiJMszY05cw3RS/view?usp=sharing), and run the following:

### Merge univariate results

```
sos run mixture_prior.ipynb merge \
    --analysis-units <FIXME> \
    --plink-sumstats <FIXME> \
    --name gtex_mixture
```

### Select and merge univariate effects

```
m=/path/to/data
cd $m && ls *.rds | sed 's/\.rds//g' > analysis_units.txt && cd -
sos run mixture_prior.ipynb extract_effects \
        --analysis-units $m/analysis_units.txt \
        --datadir $m --name `basename $m`
```

Note that for production use, each `sos run` command should be submitted to the cluster as a separate job.

## Global parameters

In [None]:
[global]
# Parameters shared by every step of this workflow.
import os
# Work directory & output directory
parameter: cwd = path('./output')
# The filename prefix for output data (required; no default)
parameter: name = str
# Names of mixture components
# NOTE(review): not referenced by any step visible in this part of the notebook -- confirm where it is consumed
parameter: mixture_components = ['flash', 'flash_nonneg', 'pca', 'canonical']
# Passed as `trunk_size` to the task of the `merge` step
parameter: job_size = 1
# Residual correlation file; the default path(".") means no file is provided
parameter: resid_cor = path(".")
# Abort early when a residual correlation file was given but does not exist
fail_if(not (resid_cor.is_file() or resid_cor == path('.')), msg = f'Cannot find ``{resid_cor}``')

## Merge PLINK univariate association summary statistics to RDS format

In [None]:
[merge]
# Tab-delimited table with a `#molc_pheno` column: one path prefix per condition
parameter: molecular_pheno = path
# Analysis units file. For RDS files it can be generated by `ls *.rds | sed 's/\.rds//g' > analysis_units.txt`
parameter: analysis_units = path
# One entry per non-empty, non-comment line; the first field is the analysis unit name
regions = [x.strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
input:  molecular_pheno, for_each = "regions"
output: f'{cwd:a}/RDS/{_regions[0]}'

task: trunk_workers = 1, trunk_size = job_size, walltime = '4h',  mem = '6G', tags = f'{step_name}_{_output:bn}'

R: expand = "$[ ]", stderr = f'{_output}.stderr', stdout = f'{_output}.stdout'
    library("dplyr")
    library("tibble")
    library("purrr")
    library("readr")
    # One row per condition; `#molc_pheno` holds the path prefix of that condition's results
    molecular_pheno = read_delim($[molecular_pheno:r], delim = "\t")
    # Append the analysis unit name to every prefix to get the per-condition RDS path
    molecular_pheno = molecular_pheno%>%mutate(dir = map_chr(`#molc_pheno`,~paste(c(`.x`,"$[_regions[0]]"),collapse = "")))
    n = nrow(molecular_pheno)
    if (n < 1) stop("No condition is listed in the molecular phenotype table")

    # For every condition read the RDS file once and extract bhat and sbhat.
    # Row names are moved into a `rowname` column that serves as the join key.
    # NOTE(review): the path is taken from column 2, which is `dir` only when the
    # input table has a single column -- confirm against the expected input format.
    sumstats = map(seq_len(n), function(i) {
        dat = readRDS(molecular_pheno[[i,2]])
        list(bhat = dat$bhat%>%as.data.frame%>%rownames_to_column,
             sbhat = dat$sbhat%>%as.data.frame%>%rownames_to_column)
    })

    # Full-join all conditions on `rowname`, then drop the key and convert to a matrix.
    # The conversion must happen once, after the last join: doing it inside the join loop
    # breaks every subsequent join and is skipped entirely when there are only two conditions.
    # NOTE(review): dropping the key means the saved matrices carry no row identifiers.
    join_conditions = function(tables) {
        reduce(tables, ~full_join(.x, .y, by = "rowname"))%>%select(-rowname)%>%as.matrix
    }
    genos_join_bhat = join_conditions(map(sumstats, "bhat"))
    genos_join_sbhat = join_conditions(map(sumstats, "sbhat"))

    # Condition names: the third-from-last "/"-separated field of each path prefix
    name = molecular_pheno%>%mutate(name = map(`#molc_pheno`, ~read.table(text = .x,sep = "/")),
                                    name = map_chr(name, ~.x[,ncol(.x)-2]%>%as.character) )%>%pull(name)
    colnames(genos_join_bhat) = name
    colnames(genos_join_sbhat) = name

    # save the rds file
    saveRDS(file = "$[_output]", list(bhat=genos_join_bhat, sbhat=genos_join_sbhat))

## Get top, random and null effects per analysis unit

In [None]:
# extract data for MASH from summary stats
[extract_effects_1]
# If non-empty, the element of this name is taken from each loaded RDS object
parameter: table_name = ""
# Names of the elements holding effect estimates and their standard errors
parameter: bhat = "bhat"
parameter: sbhat = "sbhat"
# If positive, input files whose number of columns differs from this value are skipped
parameter: expected_ncondition = 0
# Directory containing one `<analysis unit>.rds` file per analysis unit
parameter: datadir = path
parameter: seed = 999
# Maximum number of random / null rows to draw per analysis unit
parameter: n_random = 4
parameter: n_null = 4
# When True only z-scores are saved; otherwise bhat and sbhat are saved as well
parameter: z_only = True
# Analysis units file. For RDS files it can be generated by `ls *.rds | sed 's/\.rds//g' > analysis_units.txt`
parameter: analysis_units = path
# handle N = per_chunk data-set in one job
parameter: per_chunk = 1000
regions = [x.strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
input: [f'{datadir}/{x[0]}.rds' for x in regions], group_by = per_chunk
output: f"{cwd}/{name}/cache/{name}_{_index+1}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    set.seed(${seed})
    # Row / column position of the largest entry of a matrix (which.max skips NA entries)
    matxMax <- function(mtx) {
      return(arrayInd(which.max(mtx), dim(mtx)))
    }
    # Strip row names from every element of a list
    remove_rownames = function(x) {
        for (name in names(x)) rownames(x[[name]]) = NULL
        return(x)
    }
    # Replace NaN effects by 0 and NaN / infinite standard errors by a large value
    handle_nan_etc = function(x) {
      x$bhat[which(is.nan(x$bhat))] = 0
      x$sbhat[which(is.nan(x$sbhat) | is.infinite(x$sbhat))] = 1E3
      return(x)
    }
    # Draw at most `size` elements of `pool` without replacement.
    # sample(pool, k) must not be used directly: when `pool` has length one it samples
    # from 1:pool instead of returning `pool`, silently selecting unintended rows.
    # For longer pools this consumes the random stream exactly as sample() does.
    sample_from = function(pool, size) {
        pool[sample.int(length(pool), min(size, length(pool)))]
    }
    # Select the strongest row, up to n_random other rows and up to n_null rows
    # with all |z| < 2 from one data-set. Returns NULL when nothing can be extracted.
    extract_one_data = function(dat, n_random, n_null, infile) {
        if (is.null(dat)) return(NULL)
        z = abs(dat$${bhat}/dat$${sbhat})
        max_idx = matxMax(z)
        if (nrow(max_idx) == 0) {
          # no rows at all, or every z-score is NA
          warning(paste("No valid z-score found in input file", infile))
          return(NULL)
        }
        top_row = max_idx[1]
        # strong effect samples
        strong = list(bhat = dat$${bhat}[top_row,,drop=FALSE], sbhat = dat$${sbhat}[top_row,,drop=FALSE])
        # random samples excluding the top one; empty when the data has a single row
        sample_idx = setdiff(seq_len(nrow(z)), top_row)
        random_idx = sample_from(sample_idx, n_random)
        random = list(bhat = dat$${bhat}[random_idx,,drop=FALSE], sbhat = dat$${sbhat}[random_idx,,drop=FALSE])
        # null samples defined as |z| < 2
        null.id = which(apply(z, 1, max) < 2)
        if (length(null.id) == 0) {
          warning(paste("Null data is empty for input file", infile))
          null = list()
        } else {
          null_idx = sample_from(null.id, n_null)
          null = list(bhat = dat$${bhat}[null_idx,,drop=FALSE], sbhat = dat$${sbhat}[null_idx,,drop=FALSE])
        }
        dat = (list(random = remove_rownames(random), null = remove_rownames(null), strong = remove_rownames(strong)))
        dat$random = handle_nan_etc(dat$random)
        dat$null = handle_nan_etc(dat$null)
        dat$strong = handle_nan_etc(dat$strong)
        return(dat)
    }
    reformat_data = function(dat, z_only = TRUE) {
        # make output consistent in format with
        # https://github.com/stephenslab/gtexresults/blob/master/workflows/mashr_flashr_workflow.ipynb
        res = list(random.z = dat$random$bhat/dat$random$sbhat,
                  strong.z = dat$strong$bhat/dat$strong$sbhat,
                  null.z = dat$null$bhat/dat$null$sbhat)
        if (!z_only) {
          res = c(res, list(random.b = dat$random$bhat,
           strong.b = dat$strong$bhat,
           null.b = dat$null$bhat,
           null.s = dat$null$sbhat,
           random.s = dat$random$sbhat,
           strong.s = dat$strong$sbhat))
      }
      return(res)
    }
    # Row-bind every element of one_data onto the element of the same name in res
    merge_data = function(res, one_data) {
      if (length(res) == 0) {
          return(one_data)
      } else if (is.null(one_data)) {
          return(res)
      } else {
          for (d in names(one_data)) {
            if (is.null(one_data[[d]])) {
              next
            } else {
                res[[d]] = rbind(res[[d]], one_data[[d]])
            }
          }
          return(res)
      }
    }
    res = list()
    for (f in c(${_input:r,})) {
      # If cannot read the input for some reason then we just skip it, assuming we have other enough data-sets to use.
      dat = tryCatch(readRDS(f), error = function(e) return(NULL))${("$"+table_name) if table_name != "" else ""}
      if (is.null(dat)) {
          message(paste("Skip loading file", f, "due to load failure."))
          next
      }
      if (${expected_ncondition} > 0 && (ncol(dat$${bhat}) != ${expected_ncondition} || ncol(dat$${sbhat}) != ${expected_ncondition})) {
          message(paste("Skip loading file", f, "because it has", ncol(dat$${bhat}), "columns different from required", ${expected_ncondition}))
          next
      }
      one_data = extract_one_data(dat, ${n_random}, ${n_null}, f)
      if (is.null(one_data)) {
          message(paste("Skip file", f, "because no effect could be extracted from it."))
          next
      }
      res = merge_data(res, reformat_data(one_data, ${"TRUE" if z_only else "FALSE"}))
    }
    saveRDS(res, ${_output:r})

In [None]:
[extract_effects_2]
input: group_by = "all"
output: f"{cwd}/{name}.rds"
task: trunk_workers = 1, walltime = '1h', trunk_size = 1, mem = '16G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    # Row-bind each element of `chunk` onto the element of the same name in `acc`;
    # an empty accumulator is simply replaced by the first chunk.
    append_chunk = function(acc, chunk) {
      if (length(acc) == 0) return(chunk)
      for (key in names(chunk)) {
        acc[[key]] = rbind(acc[[key]], chunk[[key]])
      }
      acc
    }
    # Combine the cached results of all chunks produced by the previous step
    cache_files = c(${_input:r,})
    dat = list()
    for (cache_file in cache_files) {
      dat = append_chunk(dat, readRDS(cache_file))
    }
    # Empirical covariance of the strong z-scores, with missing values treated as zero
    strong_z = dat$strong.z
    strong_z[is.na(strong_z)] = 0
    n_strong = nrow(strong_z)
    strong_z = as.matrix(strong_z)
    dat$XtX = t(strong_z) %*% strong_z / n_strong
    saveRDS(dat, ${_output:r})