# Finemapping benchmark

Methods evaluated:

- Variational methods:
    - spike-slab, mixture normal, sum of single effects, m&m
- Popular fine-mapping methods:
    - DAP, FINEMAP, CAVIAR
    
[PAINTOR](https://github.com/gkichaev/PAINTOR_V3.0) is not included because [FINEMAP is recommended over PAINTOR when used without annotation](https://github.com/gkichaev/PAINTOR_V3.0/issues/11#issuecomment-303135031).

## DSC run

### `mnm.dsc`

Master DSC script.

* `debug_mnm_2`: this shows increase on ELBO.

In [1]:
%save -f mnm.dsc
#!/usr/bin/env dsc

%include modules/setup
%include modules/simulate
%include modules/fit
%include modules/evaluate

# Benchmark definition: module groups (`define`), pipelines (`run`) and global settings.
DSC:
  # Named groups of alternative modules. `get_data` is not referenced by any
  # pipeline under `run` below; `get_Y` and `fit` are.
  define:
    get_data: full_data, lite_data, liter_data, two_effect
    get_Y: original_Y
    fit: (init_mnm * fit_mnm * plot_sse), 
        fit_susie, fit_varbvs, 
        (fit_finemap * plot_finemap), 
        (fit_dap * plot_dap)
  # Pipelines: `benchmark` runs every method in `fit` on the full data; the
  # debug_mnm_* pipelines run fit_mnm_debug on the 2k / 1k column subsets;
  # `caviar` is kept as its own pipeline on the 2k subset.
  run:
    setup: liter_data * summarize_ld
    benchmark: full_data * summarize_ld * get_Y * get_sumstats * fit
    debug_mnm_1: lite_data * summarize_ld * get_Y * get_sumstats * init_mnm * fit_mnm_debug
    debug_mnm_2: liter_data * summarize_ld * get_Y * get_sumstats * init_mnm * fit_mnm_debug   
    caviar: lite_data * summarize_ld * get_Y * get_sumstats * (fit_caviar * plot_caviar)
  # NOTE(review): `output` is presumably the DSC output folder name and
  # `exec_path` the directory searched for module scripts -- confirm in the DSC docs.
  output: benchmark
  exec_path: modules
  # Global variables; `data_file` is referenced as ${data_file} by full_data.
  global:
    data_file: ~/Documents/GTExV8/Thyroid.Lung.FMO2.filled.rds

## DSC modules

### `setup.dsc`

Data generators. See `20171103_MNMASH_Data.ipynb` for GTEx multitissue data preparation, if more real data are needed.

In [39]:
%save -f modules/setup.dsc

# Modules to provide data
# Real or simulated

# Module output
# =============
# $data: full data
# $sumstats: summary statistics

# Load the data set stored in ${data_file}; if end > start, restrict X to
# columns start:end (the default (0, 0) leaves X untouched).
# Two LD summaries are written from cor(X): the signed squared correlation
# (r^2 * sign(r)) as RDS to $ld_mat, and the plain correlation matrix (the
# variable is named `r2` but holds r) as a text table to $ld_file.
# $top_idx is set to NULL here; summarize_ld provides a real value.
full_data: R(data =readRDS(${data_file});
            if (end>start) data$X = as.matrix(data$X[,start:end]);
            r2 = cor(data$X);
            saveRDS(r2 ^ 2 * sign(r2), ld_mat);
            write.table(r2,ld_file,quote=F,col.names=F,row.names=F))
  tag: full
  start, end: (0, 0)
  $data: data
  $top_idx: raw(NULL)
  $ld_file: file(ld)
  $ld_mat: file(rds)
        
# Same as full_data, but X is restricted to columns 2500:4500 (tag `2k`).
lite_data(full_data):
  tag: 2k
  start, end: (2500, 4500)
             
# Same as full_data, but X is restricted to columns 3000:4000 (tag `1k`).
liter_data(full_data):
  tag: 1k
  start, end: (3000, 4000)           
            
# Same as full_data, but X is restricted to the two columns 3500:3501 (tag `two`).
two_effect(full_data):
  tag: two
  start, end: (3500, 3501)

# Univariate summary statistics for every (column of X, column of Y) pair,
# computed by mm_regression() from regression.R and exposed as $sumstats
# (array with [1,,] = betahat and [2,,] = standard errors, per regression.R).
get_sumstats: regression.R + R(res = mm_regression(as.matrix(data$X), 
                                                   as.matrix(data$Y)))
  @CONF: R_libs = abind
  data: $data
  $sumstats: res
                                                   
# Calls summarize_LD() on X with the LD matrix file and a PNG output path.
# summarize_LD is not defined in the part of the notebook shown here
# (presumably it lives in one of the two Python files listed -- confirm).
# Its return value is exported as $top_idx.
summarize_ld: lib_regression_simulator.py + \
                regression_simulator.py + \
                Python(res = summarize_LD(data['X'], ld_mat, ld_plot))
  data: $data
  ld_mat: $ld_mat
  $ld_plot: file(png)
  $top_idx: res

### `simulate.dsc`

In [41]:
%save -f modules/simulate.dsc

# base_sim:
# - A base simulator of 2 independent multivariate effects
# - using MultivariateMixture
# original_Y:
# - do not simulate data, just use original

# Base simulator: passes $data and a `conf` dictionary to simulate_main()
# (defined outside the part of the notebook shown here) and exports the
# returned data plus two derived quantities:
#   $V: sample covariance of the columns of Y
#   $N: number of rows of Y
# NOTE(review): `conf` is built by @ALIAS from the module parameters,
# presumably all of them except `data` and `eff_mode` -- confirm the `!name`
# semantics in the DSC docs, since `eff_mode` is what original_Y overrides.
base_sim: lib_regression_simulator.py + \
                regression_simulator.py + \
                Python(data = simulate_main(data, conf, conf['cache']))
  @CONF: python_modules = (seaborn, matplotlib, pprint)
  data: $data
  top_idx: $top_idx
  n_signal: 3
  n_traits: 2
  eff_mode: mash_low_het
  residual_mode: identity
  swap_eff: raw(True)
  keep_ld: raw(True)
  center_data: raw(True)
  cache: file(sim)
  tag: sim1
  @ALIAS: conf = Dict(!data, !eff_mode)
  $data: data
  $V: np.cov(data['Y'], rowvar = False)
  $N: data['Y'].shape[0]

# Same as base_sim with eff_mode set to `original`; per the header comment of
# this file, that mode uses the original Y instead of simulating one.
original_Y(base_sim):
  eff_mode: original

### `fit.dsc`

Fine mapping methods.

In [4]:
%save -f modules/fit.dsc
# workhorse(s) for finemapping

# Module input
# ============
# $data: full data; or
# $sumstats: summary statistics; or / and
# $ld: LD information

# Module output
# =============
# $fitted: for diagnostics
# $posterior: for inference

# Build the prior (`model`) used by the M&M fit from summary statistics; see init_mnm.R.
# Parameter combinations: (U, grid, p) is either all `auto`-style defaults or
# the explicit grid vector listed below.
init_mnm: init_mnm.R
  # mashr comes from `dev` branch on github
  @CONF: R_libs = mashr
  V: $V
  reg: $sumstats
  # FIXME: these quantities are to be computed separately and globally using mashr procedure
  # See http://stephenslab.github.io/gtex-eqtls/analysis/20171002_MASH_V8.html
  Sigma: empirical
  (U, grid, p): (auto, (0.9,0.01,0.01,0.01,0.01,0.01,0.01,0.02,0.02), auto)
  $model: model
  $V: V

# M&M fit with ELBO tracking switched on (get_elbo = T): up to maxL effects,
# maxI outer iterations. Exports the per-iteration history as $fitted and the
# final summaries as $posterior (see fit_mnm.R).
fit_mnm_debug: regression.R + elbo_mnm.R + fit_mnm.R
  @CONF: R_libs = mashr
  maxL: 5
  maxI: 20
  get_elbo: raw(T)
  data: $data
  model: $model
  V: $V
  $fitted: fitted_track
  $posterior: posterior

# Production variant of fit_mnm_debug: 10 iterations and no ELBO computation.
fit_mnm(fit_mnm_debug):
  maxI: 10
  get_elbo: raw(F)

# Per-response SuSiE fit (see fit_susie.R) with at most maxL effects and maxI iterations.
fit_susie: fit_susie.R
  # NOTE(review): no prior-variance parameter is set in this module; the
  # package spec below presumably points at the stephenslab/susieR GitHub repo.
  @CONF: R_libs = susieR@stephenslab/susieR
  maxL: 5
  maxI: 50
  data: $data
  $posterior: posterior
  $fitted: fitted

# varbvs fit derived from fit_susie (maxL, maxI, data and outputs are
# presumably inherited); runs setup_varbvs.R followed by fit_varbvs.R.
# `sa` is passed to varbvs::varbvsnorm as its `sa` argument in fit_varbvs.R.
fit_varbvs(fit_susie): setup_varbvs.R + fit_varbvs.R
  @CONF: R_libs = varbvs@pcarbo/varbvs/varbvs-R
  sa: 1

# CAVIAR on z-scores (betahat / standard error from $sumstats) using the text
# LD file written by the data module. `args` has two alternatives that are
# passed verbatim to the CAVIAR command line (see run_caviar in fit_caviar.R).
fit_caviar: fit_caviar.R + \
             R(posterior = finemap_mcaviar(sumstats[1,,]/sumstats[2,,], 
                                            ld, args, prefix=cache))
  @CONF: R_libs = (dplyr, magrittr)
  sumstats: $sumstats
  ld: $ld_file
  args: -c 1, -c 2
  cache: file(CAVIAR)
  $posterior: posterior

# FINEMAP on the same z-scores, derived from fit_caviar (sumstats, ld,
# $posterior and @CONF are presumably inherited). `k` has two alternatives that
# are written to FINEMAP's `.k` file by write_finemap_sumstats; `N` comes from the
# simulator's $N; `args` is appended to the finemap command line.
fit_finemap(fit_caviar): fit_finemap.R + \
             R(posterior = finemap_mvar(sumstats[1,,]/sumstats[2,,], 
                                        ld, N, k,
                                        args, prefix=cache))
  N: $N
  k: R(rep(1/5,5)), (0.6,0.25,0.1,0.05)
  args: --regions 1 --prior-std 0.4 --n-causal-max 5
  cache: file(FM)

# DAP-G on individual-level data, one run per column of Y (see dap_batch in
# fit_dap.py). `args` is appended to the dap-g command line; intermediate
# files are written under the `cache` prefix.
fit_dap: fit_dap.py + Python(posterior = dap_batch(data['X'], data['Y'], cache, args))
  data: $data
  args: -ld_control 0.25
  cache: file(DAP)
  $posterior: posterior

# fit_dap_mv(fit_dap): fit_dap.py + Python(res = dap_mv())

# fit_dap_ss(fit_dap): fit_dap.py + Python(res = dap_batch_ss())
#   data: $sumstats

# fit_dap_mv_ss(fit_dap): fit_dap.py + Python(res = dap_mv_ss())

### `evaluate.dsc`

Methods evaluation / visualization.

In [43]:
%save -f modules/evaluate.dsc
# Modules to evaluate various methods
# for finemapping-m

# Module input
# ============
# $posterior: inference results, see fit.dsc
# $data, $ld_mat: only used by plot_sse, see setup.dsc / simulate.dsc

# Module output
# =============
# ? an object or plot for diagnosis

# Plot the top `top_rank` entries of each FINEMAP result in $posterior into a
# single PDF (see plot_finemap.R).
plot_finemap: plot_finemap.R
  @CONF: R_libs = (dplyr, ggplot2, cowplot)
  result: $posterior
  top_rank: 10
  $plot_file: file(pdf)

# Same interface as plot_finemap, but running the CAVIAR plotting script.
plot_caviar(plot_finemap): plot_caviar.R
# Same interface as plot_finemap, but running the DAP plotting script.
plot_dap(plot_finemap): plot_dap.R

# Plot posterior mean effects per response for the M&M fit: calls plot_sse()
# from plot_sse.py with the posterior means, the `true_coef` entry of $data,
# the `in_CI` entry of $posterior, the LD matrix file and an output prefix.
plot_sse: lib_regression_simulator.py + \
            plot_sse.py + \
            Python(plot_sse(result['PosteriorMean'], data['true_coef'],
                            result['in_CI'], ld_mat, plot_file))
  @CONF: python_modules = seaborn
  data: $data
  result: $posterior
  ld_mat: $ld_mat
  $plot_file: file(SSE)

## Workhorses

### `regression.R`

In [6]:
%save -f modules/regression.R
## Simple linear regression (with intercept) of a single response y on each
## column of X, one column at a time.
##   X: numeric matrix, one predictor per column
##   y: response vector
##   Z: optional covariate matrix; when given, y is replaced by the residuals
##      of regressing y on Z before the per-column fits
##   return_residue: if TRUE, also return the (possibly residualized) y
## Returns a list with `betahat` and `sebetahat` (one entry per column of X)
## and, optionally, `residuals`.
univariate_regression = function(X, y, Z=NULL, return_residue=FALSE) {
  if (!is.null(Z)) y = .lm.fit(Z, y)$residuals

  ## Slope and its standard error for column j of X
  fit_one_column = function(j) {
    design = cbind(1, X[,j])
    fit = .lm.fit(design, y)
    # residual variance with n - 2 degrees of freedom, times (X'X)^-1
    sigma2 = sum(fit$residuals^2) / (nrow(design) - 2)
    stderr = sqrt(diag(sigma2 * chol2inv(chol(t(design) %*% design))))
    c(coef(fit)[2], stderr[2])
  }

  estimates = do.call(rbind, lapply(1:ncol(X), fit_one_column))
  result = list(betahat = estimates[,1], sebetahat = estimates[,2])
  if (return_residue) result$residuals = y
  return(result)
}

library(abind)
## Univariate regression summary statistics for every pair
## (column of X, column of Y).
##   X: N x J predictor matrix
##   Y: N x R response matrix
##   Z: optional covariate matrix, forwarded to univariate_regression()
##      (it used to be accepted but silently ignored)
## Returns a 2 x J x R array: out[1,,] is betahat, out[2,,] is shat.
mm_regression = function(X, Y, Z=NULL) {
  reg = lapply(seq_len(ncol(Y)), function (i) {
    # cbind keeps a J x 2 matrix even when X has a single column, where
    # simplify2array would collapse the result to a plain vector
    do.call(cbind, univariate_regression(X, Y[,i], Z = Z))
  })
  # stack the R matrices along a new first dimension: R x J x 2
  reg = do.call(abind, c(reg, list(along=0)))
  # return array: out[1,,] is betahat, out[2,,] is shat
  return(aperm(reg, c(3,2,1)))
}

### `setup_varbvs.R`

In [7]:
%save -f modules/setup_varbvs.R

# Prepare inputs for varbvs from the DSC-provided `data` and `maxL`.

# Design matrix stored as doubles and column-centered (not scaled)
X <- data$X
storage.mode(X) <- "double"
n <- nrow(X)
p <- ncol(X)
X <- scale(X,center = TRUE,scale = FALSE)
# Random starting values: uniform draws normalized to sum to one, and
# standard-normal draws. The order of the two draws determines the values
# obtained under a fixed seed.
alpha0  <- runif(p)
alpha0  <- alpha0/sum(alpha0)
mu0     <- rnorm(p)
# Prior inclusion probability maxL/p for every variable, transformed with
# varbvs's internal logit()
pp      <- rep(maxL/p, p)
logodds <- varbvs:::logit(pp)
# Center every response column
Y <- data$Y
for (r in 1:ncol(Y)) {
  Y[,r] <- Y[,r] - mean(Y[,r])
}
storage.mode(Y) <- "double"

### `fit_varbvs.R`

In [8]:
%save -f modules/fit_varbvs.R
# Fit varbvs separately to each response column, using the X, Y, logodds,
# alpha0, mu0 and p objects prepared by setup_varbvs.R plus the module
# parameters `sa` and `maxI`.
fitted <- list()
for (r in 1:ncol(Y)) {
  # residual variance fixed at the sample variance of the response
  # (update.sigma and update.sa are both switched off below)
  sigma <- var(Y[,r])
  fitted[[r]] <- varbvs::varbvsnorm(X,Y[,r],sigma,sa,logodds,alpha0,mu0,update.order = 1:p,
                                    update.sigma = FALSE,update.sa = FALSE,tol = 1e-6,
                                    verbose = FALSE, maxiter=maxI)
}

# One column per response: posterior mean = alpha * mu, and 1 - alpha reported as `lfdr`
post_mean <- do.call(cbind, lapply(1:length(fitted), function(i) fitted[[i]]$alpha * fitted[[i]]$mu))
lfdr <- do.call(cbind, lapply(1:length(fitted), function(i) 1 - fitted[[i]]$alpha))
posterior <- list(PosteriorMean=post_mean, lfdr=lfdr)

### `init_mnm.R`

In [9]:
%save -f modules/init_mnm.R
# Initialize model data: priors and init values
# Inputs (provided by the DSC module): reg (2 x J x R summary statistics),
# V, Sigma, U, grid, p. Output: `model`, a list with element `fitted_g`.

# Any Sigma other than 'empirical' replaces V by an identity matrix of the same size
if (Sigma != 'empirical') {
  # FIXME data$V has to be changed
  V = diag(nrow(V))
}
mash_data = mashr::mash_set_data(reg[1,,], Shat = reg[2,,], V = as.matrix(V))
# NOTE(review): both branches currently build the same canonical covariances
if (U == 'auto') {
  U = mashr::cov_canonical(mash_data)
} else {
  ## FIXME: add other methods to get U
  U = mashr::cov_canonical(mash_data)
}
model = list()
# p == 'auto': estimate the mixture with mash; otherwise assemble fitted_g by
# hand from the supplied weights `p` and `grid`
if (p == 'auto') {
  model$fitted_g = mashr::mash(mash_data, Ulist=U, outputlevel=1, usepointmass=TRUE)$fitted_g
} else {
  ## FIXME: need to use pre-fitted pi on larger data from mash procedure
  model$fitted_g = list(pi=p, Ulist=U, grid=grid, usepointmass=TRUE)
}

### `elbo_mnm.R`

Here we compute the ELBO along the lines of the FLASH paper: justify that it is the multivariate normal-means problem (MASH), then compute the ELBO mostly using MASH updates.

In [10]:
%save -f modules/elbo_mnm.R
#' @title Residual covariance for a M&M fit
#' X is the N x J design matrix and Y the N x R response matrix
#' B is J X R matrix of M&M output (posterior mean effects)
#' XtX (precomputed t(X) %*% X) and SM (posterior second-moment output) are
#' accepted for interface compatibility but not used by the current estimate
compute_mnm_residual_covariance = function(X, Y, XtX, B, SM) {
    # The exact update would be
    #   t(Y) %*% Y - 2 * t(B) %*% t(X) %*% Y + E[B^TX^TXB]
    # but E[B^TX^TXB] is not easy to compute properly, so the plug-in MLE
    # based on the posterior mean is used for now.
    resid = Y - X %*% B
    return(t(resid) %*% resid / nrow(X))
}
                                            
#' @title expected loglikelihood for a M&M fit
# https://gaow.github.io/mvarbvs/writeup/20171215_MNMModel_Finemap.html
#' X is the N x J design matrix, Y the N x R response matrix
#' S is simply XtX pre-computed
#' Sigma is current estimate of residual variance (R x R)
#' B is J X R matrix of M&M output
#' SM is an R X R X J array of M&M output (indexed as SM[,,j]); its slices are
#' contracted element-wise with the inverse of Sigma
#' Returns a single number.
compute_mnm_Eloglik = function(X,Y,S,Sigma,B,SM){
    inv_Sigma = solve(Sigma)
    # log|Sigma| on the log scale: det() itself under/overflows for very
    # small or large variances. A non-positive determinant yields NaN, as
    # log(det(Sigma)) did.
    ld = determinant(Sigma, logarithm = TRUE)
    logdet_Sigma = if (ld$sign > 0) as.numeric(ld$modulus) else NaN
    N = nrow(Y)
    R = ncol(Y)
    # per-variable term S[j,j] * sum(inv_Sigma * SM[,,j])
    t0 = vapply(seq_len(ncol(X)),
                function(j) S[j,j] * sum(inv_Sigma * SM[,,j]),
                numeric(1))
    # tr(Y Sigma^-1 B' X') and tr(Y Sigma^-1 Y') are computed as element-wise
    # sums, which avoids forming N x N intermediate matrices
    Y_invSigma = Y %*% inv_Sigma
    t1 = sum(diag(inv_Sigma %*% t(B) %*% S %*% B)) + sum(t0) -
            2 * sum(Y_invSigma * (X %*% B)) +
            sum(Y_invSigma * Y)
    out = -0.5 * N * R * log(2 * pi) - 0.5 * N * logdet_Sigma - 0.5 * t1
    return(out)
}

#' @title posterior expected loglikelihood for a MASH problem
## E[log(\hat{B}|B, Shat)]
## Need posterior mean and posterior second moment from MASH
## do not use any computational trick here because this is 
## just for sanity check
##   betahat: length-R vector of observed effects
##   Shat:    R x R covariance of betahat
##   b:       posterior mean, a length-R vector or an R x 1 matrix
##   b2:      R x R posterior second moment
## Returns a 1 x 1 matrix holding the expected log-likelihood.
compute_mash_Eloglik = function(betahat, Shat, b, b2) {
    inv_Shat = solve(Shat)
    det_Shat = det(Shat)
    # NROW() gives the dimension R for both plain vectors and R x 1 matrices.
    # nrow() returns NULL for the plain vectors passed by compute_sse_Eloglik,
    # which made the whole expression zero-length.
    res = NROW(b) * log(2*pi) + log(det_Shat) +
            (t(betahat) %*% inv_Shat %*% betahat -
            2 * t(betahat) %*% inv_Shat %*% b +
            sum(diag(inv_Shat %*% b2)))
    return(-0.5 * res)
}

#' @title sum of MASH posterior expected loglikelihood
#' Bhat is J x R matrix of MASH input
#' SDhat is J X R matrix to be expanded with V, turning into J X R X R
#' V is R X R matrix of MASH input
#' Sigma is residual variance (accepted but not used in the body below)
#' alpha is a J vector of weights
#' B is J X R matrix of MASH output
#' SM is the MASH second-moment output; it is indexed as SM[,,j] below, i.e.
#' one R X R slice per variable along the third dimension
#' Returns the sum over the J variables of compute_mash_Eloglik().
compute_sse_Eloglik = function(Bhat, SDhat, V, Sigma, alpha, B, SM) {
    ## FIXME: I think it is wrong here because it is not single effect model
    ## where J effects should NOT be factorized.
    ## But otherwise isn't it a matrix normal density with both row and column covariances?
    res = vector()
    for (j in 1:nrow(Bhat)) {
        ## Is R X R
        Shat = SDhat[j,] * t(V * SDhat[j,]) # faster than diag(SDhat[j,]) %*% V %*% diag(SDhat[j,])
        ## Is R X 1
        ## (a plain length-R vector: posterior mean weighted by alpha[j])
        B_j = B[j,] * alpha[j]
        ## 2nd moment, R X R
        B2_j = (B[j,] %*% t(B[j,]) + SM[,,j]) * alpha[j]
        res[j] = compute_mash_Eloglik(Bhat[j,], Shat, B_j, B2_j)
    }
    return(sum(res))
}

### `fit_mnm.R`

In [42]:
%save -f modules/fit_mnm.R
## M&M ash module core update
## One sweep over the L single effects: each effect is removed from the
## running fit, re-estimated with mash on the residuals, and added back.
##   X, Y:     design (N x J) and response (N x R) matrices
##   V:        residual correlation/covariance passed to mash_set_data
##   fitted_g: fixed mash prior (used with fixg = TRUE)
##   fitted:   current state (see below); the updated state is returned
##   get_kl:   if TRUE, also store a per-effect KL term in fitted$kl
mnm_update_model <- function(X, Y, V, fitted_g, fitted, get_kl = FALSE) {
  ## "fitted" include p_alpha, alpha, mu and Xr
  maxL = ncol(fitted$alpha)
  for (l in 1:maxL) {
    ## remove the lth effect
    fitted$Xr <- fitted$Xr - X %*% (fitted$alpha[,l] * fitted$mu[[l]])
    ## update mash model
    reg <- mm_regression(X, Y - fitted$Xr)
    mash_data <- mashr::mash_set_data(reg[1,,], Shat = reg[2,,], V = V)
    mout <- mashr::mash(mash_data, g = fitted_g, fixg = TRUE, outputlevel=3)
    ## update fitted values
    fitted$mu[[l]] <- mout$result$PosteriorMean
    fitted$s[[l]] <- mout$result$PosteriorCov
    fitted$lfsr[[l]] <- mout$result$lfsr
    fitted$neg[[l]] <- mout$result$NegativeProb
    l10bf <- mashr::get_log10bf(mout)
    ## FIXME: mashr issue 35
    ## infinite log10 BFs are replaced by 100 times the largest finite one
    l10bf[is.infinite(l10bf)] <- range(l10bf, finite=TRUE)[2] * 100
    ## posterior weights: BF (shifted by its maximum before exponentiating)
    ## times the prior weight, normalized to sum to one
    alpha_post <- exp((l10bf - max(l10bf)) * log(10)) * fitted$p_alpha
    fitted$alpha[,l] <- alpha_post / sum(alpha_post)
    ## add back the updated lth effect
    fitted$Xr <- fitted$Xr + X %*% (fitted$alpha[,l] * fitted$mu[[l]])
    if (get_kl) {
        # Justified by A.46 of FLASH paper
        # Here KL is denoted as (13.28) of BDA 3
        fitted$kl[l] <- -1 * mout$loglik + compute_sse_Eloglik(reg[1,,], reg[2,,], V,
                                                               fitted$Sigma, 
                                                               fitted$alpha[,l],
                                                               mout$result$PosteriorMean,
                                                               mout$result$PosteriorCov)
    }
  }
  return(fitted)
}

## Combine the L single-effect posteriors into overall posterior moments.
##   fitted: list with `mu` (L matrices, J x R), `alpha` (J x L weights) and
##           `s` (L arrays, R x R x J)
##   J, R, L: number of variables, responses and single effects
## Returns list(PosteriorMean = J x R matrix, PosteriorCov = R x R x J array).
mnm_compute_posterior_matrices = function(fitted, J, R, L) {
    effects <- 1:L
    # posterior mean: alpha-weighted sum of the per-effect means
    post_mean <- Reduce(function(acc, l) acc + fitted$mu[[l]] * fitted$alpha[,l],
                        effects, matrix(0, J, R))
    post_cov <- array(0, dim=c(R, R, J))
    for (j in 1:J) {
      # alpha-weighted sum of per-effect second moments for variable j
      second_moment <- matrix(0, R, R)
      for (l in effects) {
        mu_jl <- fitted$mu[[l]][j,]
        second_moment <- second_moment + (mu_jl %*% t(mu_jl) + fitted$s[[l]][,,j]) * fitted$alpha[j,l]
      }
      # covariance = second moment minus outer product of the mean
      mean_j <- post_mean[j,]
      post_cov[,,j] <- second_moment - mean_j %*% t(mean_j)
    }
    return(list(PosteriorMean = post_mean, PosteriorCov = post_cov))
}

## Initialize storage for results
## Inputs provided by the DSC module: data, model, V, maxL, maxI, get_elbo.
data$X <- as.matrix(data$X)
data$Y <- as.matrix(data$Y)
## never fit more single effects than there are variables
maxL <- min(maxL, ncol(data$X))
## uniform prior weights over variables
p_alpha <- rep(1, ncol(data$X)) / ncol(data$X)
alpha <- matrix(0, ncol(data$X), maxL)
mu <- lapply(1:maxL, function(i) matrix(0, ncol(data$X), ncol(data$Y)))
## running fitted values, N x R
Xr <- matrix(0, nrow(data$Y), ncol(data$Y))
fitted <- list(p_alpha=p_alpha, alpha=alpha, mu=mu, s=list(), Xr=Xr, kl=vector(), lfsr=list(), neg=list(), Sigma=V)
fitted_track <- list()
## the updates below use the correlation form of V
Vcorr <- cov2cor(V)
## For ELBO
XtX <- t(data$X) %*% data$X
## Fit m&m model
for (i in 1:maxI) {
  fitted <- mnm_update_model(data$X, data$Y, Vcorr, model$fitted_g, fitted, get_elbo)
  if (get_elbo) {
      ## refresh the residual covariance, then ELBO = expected loglik - sum of KL terms
      post_mat = mnm_compute_posterior_matrices(fitted, ncol(data$X), ncol(data$Y), maxL)
      fitted$Sigma = compute_mnm_residual_covariance(data$X, data$Y, XtX,
                                                     post_mat$PosteriorMean, 
                                                     post_mat$PosteriorCov)
      fitted$post_loglik = compute_mnm_Eloglik(data$X, data$Y, 
                                          XtX, fitted$Sigma,
                                          post_mat$PosteriorMean, 
                                          post_mat$PosteriorCov)
      fitted$elbo = fitted$post_loglik - sum(fitted$kl)
  }
  ## keep the full state after every iteration (exported as $fitted)
  fitted_track[[i]] <- fitted
}

post_mat = mnm_compute_posterior_matrices(fitted, ncol(data$X), ncol(data$Y), maxL)

## Compute lfsr
## one row per single effect: alpha-weighted column sums of that effect's lfsr
lfsr <- do.call(rbind, lapply(1:maxL, function(l) colSums(fitted$alpha[,l] * fitted$lfsr[[l]])))
## NOTE(review): n_in_CI / in_CI are internal (:::) susieR functions, and susieR
## is not listed in this module's R_libs -- confirm it is installed.
posterior <- list(PosteriorMean=post_mat$PosteriorMean,
                  PosteriorCov=post_mat$PosteriorCov,
                  alpha = fitted$alpha,
                  lfsr=lfsr,
                  n_in_CI=susieR:::n_in_CI(t(fitted$alpha)),
                  in_CI=susieR:::in_CI(t(fitted$alpha))
                  )

### `plot_sse.py`

In [None]:
%save -f modules/plot_sse.py

def plot_sse(coef, true_coef, in_set, ld, plot_prefix):
    """Write one effect-size plot per response column of `coef`.

    coef        -- posterior mean effects, converted with np.array; one
                   column per response
    true_coef   -- converted with np.array when not None; not used further here
    in_set      -- summed along axis 0 after np.array conversion; not used
                   further here
    ld          -- passed to RegressionData.set_xcorr
    plot_prefix -- output files are named '<plot_prefix>.<k>.pdf', k = 1, 2, ...
    """
    regression = RegressionData()
    regression.set_xcorr(ld)
    in_set = np.sum(np.array(in_set), axis = 0)
    coef = np.array(coef)
    if true_coef is not None:
        true_coef = np.array(true_coef)
    n_responses = coef.shape[1]
    for j in range(n_responses):
        plot_file = f'{plot_prefix}.{j+1}.pdf'
        plot_conf = {'title': f'Response {j+1}',
                     'ylabel': 'effect size estimate',
                     'zlabel': 'In 95 CI set'}
        regression.plot_property_vector(coef[:,j], None,
                                        xz_cutoff = (0, 0.8),
                                        out = plot_file,
                                        conf = plot_conf)

### `fit_susie.R`

In [12]:
%save -f modules/fit_susie.R
# Fit SuSiE separately to each response column of the DSC-provided `data`,
# with at most `maxL` effects and `maxI` iterations.
fitted <- list()
for (r in 1:ncol(data$Y)) {
  fitted[[r]] <- susieR::susie(data$X,data$Y[,r],L=maxL,max_iter=maxI)
  # attach summaries computed by susieR internals (::: = unexported functions)
  fitted[[r]]$lfsr <- susieR:::lfsr_fromfit(fitted[[r]])
  fitted[[r]]$n_in_CI <- susieR:::n_in_CI(fitted[[r]])
  fitted[[r]]$in_CI <- susieR:::in_CI(fitted[[r]])
}

# Column-bind the per-response summaries in response order
posterior <- list(PosteriorMean=do.call(cbind, lapply(1:length(fitted), function(i) susieR:::coef.susie(fitted[[i]]))),
                  lfsr=do.call(cbind, lapply(1:length(fitted), function(i) fitted[[i]]$lfsr)),
                  alpha=do.call(cbind, lapply(1:length(fitted), function(i) fitted[[i]]$alpha)),
                  n_in_CI=do.call(cbind, lapply(1:length(fitted), function(i) fitted[[i]]$n_in_CI)),
                  in_CI= do.call(cbind, lapply(1:length(fitted), function(i) fitted[[i]]$in_CI))
                  )

### `fit_dap.py`

DAP version 1 was published as Wen et al. 2016 (AJHG). William has since polished the software as `dap-g`, accompanied by another manuscript that describes an improved algorithm and support for summary statistics. This benchmark uses DAP version 2. Below is an example of the output that I parse and save.

```
Posterior expected model size: 0.500 (sd = 0.500)
LogNC = -0.30685 ( Log10NC = -0.133 )
Posterior inclusion probability

((1))              7492 6.68581e-05       0.000 1
((2))              7490 6.68581e-05       0.000 1
((3))              7484 6.68581e-05       0.000 1
((4))              7486 6.68581e-05       0.000 1
((5))              7481 6.68581e-05       0.000 1
((6))              7476 6.68581e-05       0.000 1
((7))              7479 6.68581e-05       0.000 1
((8))              7491 6.68046e-05       0.000 2
((9))              7483 6.68046e-05       0.000 2
((10))             7485 6.68046e-05       0.000 2
((11))             7488 6.68046e-05       0.000 2
((12))             7474 6.68046e-05       0.000 2
((13))             7475 6.68046e-05       0.000 2
((14))             7478 6.68046e-05       0.000 2
((15))             7465 6.68046e-05       0.000 2
((16))             7473 6.68046e-05       0.000 2
((17))             7470 6.68046e-05       0.000 2
((18))             7467 6.68046e-05       0.000 2
((19))             7461 6.68046e-05       0.000 2
((20))             7459 6.68046e-05       0.000 2
((21))             7482 6.67422e-05       0.000 -1
((22))             7489 6.67422e-05       0.000 -1
((23))             7487 6.67422e-05       0.000 -1
((24))             7477 6.67422e-05       0.000 -1
((25))             7480 6.67422e-05       0.000 -1
((26))             7463 6.67422e-05       0.000 -1
...
Independent association signal clusters

     cluster         member_snp      cluster_pip      average_r2
       {1}              7            4.680e-04          0.951                 0.951   0.037
       {2}             13            8.685e-04          0.623                 0.037   0.623

```

In [13]:
%save -f modules/fit_dap.py
import subprocess
import pandas as pd
import numpy as np

def dap_single(x, y, prefix, r, args):
    """Run dap-g on one response and parse its result file.

    x      -- genotype matrix, samples in rows and variants in columns
    y      -- response values (flattened with ravel())
    prefix -- path prefix for the '.data', '.grid' and '.result' files
    r      -- label of the response, written into the data file
    args   -- iterable of strings with extra dap-g command line arguments;
              the strings are joined and split on whitespace

    Returns {'snp': DataFrame, 'set': DataFrame}: per-variant rows (variants
    with a negative cluster id are skipped) and per-cluster rows with the
    member variants joined by commas.

    Raises FileNotFoundError, including dap-g's stderr, when no result file
    was produced.
    """
    import os
    names = np.array([('geno', i+1, str(r)) for i in range(x.shape[1])])
    with open(f'{prefix}.data', 'w') as f:
        print(*(['pheno', 'pheno', str(r)] + list(y.ravel())), file=f)
        np.savetxt(f, np.hstack((names, x.T)), fmt = '%s', delimiter = ' ')
    # effect-size grid handed to dap-g, one "<col1>  <col2>" pair per line
    grid = '''         
        0.0000  0.1000
        0.0000  0.2000
        0.0000  0.4000
        0.0000  0.8000
        0.0000  1.6000
        '''
    grid = '\n'.join([row.strip() for row in grid.strip().split('\n')])
    with open(f'{prefix}.grid', 'w') as f:
        print(grid, file=f)
    result_file = f'{prefix}.result'
    cmd = ['dap-g', '-d', f'{prefix}.data', '-g', f'{prefix}.grid', '-o', result_file, '--all'] + ' '.join(args).split()
    _, err = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE).communicate()
    if not os.path.isfile(result_file):
        # surface the tool's own diagnostics instead of a bare missing-file error
        detail = err.decode(errors='replace').strip()
        raise FileNotFoundError(f'dap-g did not produce {result_file}: {detail}')
    # read through a context manager so the handle is closed promptly
    with open(result_file) as f:
        out = [line.strip().split() for line in f]
    pips = []
    clusters = []
    still_pip = True
    for line in out:
        if len(line) == 0:
            continue
        # the header of the cluster table ends the per-variant section
        if len(line) > 2 and line[2] == 'cluster_pip':
            still_pip = False
            continue
        # in the per-variant section keep only '((k))' rows with a cluster id >= 0
        if still_pip and (not line[0].startswith('((') or int(line[-1]) < 0):
            continue
        if still_pip:
            pips.append([line[1], float(line[2]), float(line[3]), int(line[4])])
        else:
            clusters.append([len(clusters) + 1, float(line[2]), float(line[3])])
    pips = pd.DataFrame(pips, columns = ['snp', 'snp_prob', 'snp_log10bf', 'cluster'])
    clusters = pd.DataFrame(clusters, columns = ['cluster', 'cluster_prob', 'cluster_avg_r2'])
    clusters = pd.merge(clusters, pips.groupby(['cluster'])['snp'].apply(','.join).reset_index(), on = 'cluster')
    return {'snp': pips, 'set': clusters}

def dap_batch(X, Y, prefix, *args):
    """Run dap_single() once per column of Y.

    Returns a dict keyed by the 0-based column index. Files and labels for
    column r use the 1-based index r+1; the extra positional arguments are
    forwarded to dap_single as one tuple.
    """
    results = {}
    for r in range(Y.shape[1]):
        results[r] = dap_single(X, Y[:,r], f'{prefix}_condition_{r+1}', r+1, args)
    return results

### `fit_finemap.R`

In [14]:
%save -f modules/fit_finemap.R
#' FINEMAP I/O
#' Write the z-score file, the `k` file and the master file for one FINEMAP
#' run, and return the list of all file paths involved.
#'   z: z-scores (written with row names, without header)
#'   LD_file: path of an existing LD file (not written here)
#'   n: value for the `n-ind` column of the master file
#'   k: vector written as a single row to the `.k` file
#'   prefix: path prefix of all generated files
write_finemap_sumstats <- function(z, LD_file, n, k, prefix) {
  with_ext = function(ext) paste0(prefix, ext)
  cfg = list(z=with_ext(".z"),
             ld=LD_file,
             snp=with_ext(".snp"),
             config=with_ext(".config"),
             k=with_ext(".k"),
             log=with_ext(".log"),
             meta=with_ext(".master"))
  write.table(z, file=cfg$z, quote=FALSE, col.names=FALSE)
  write.table(t(k), file=cfg$k, quote=FALSE, col.names=FALSE, row.names=FALSE)
  # master file: a header line followed by one line of ';'-separated entries
  header = "z;ld;snp;config;k;log;n-ind"
  entry = paste(cfg$z, cfg$ld, cfg$snp, cfg$config, cfg$k, cfg$log, n, sep=";")
  writeLines(c(header, entry), cfg$meta)
  return(cfg)
}

#' Run FINEMAP.
#' http://www.christianbenner.com
## FIXME: read the finemapr implementation for data sanity check.
## Can be useful as a general data sanity checker (in previous modules)
##
## Writes the input files via write_finemap_sumstats(), runs
## `finemap --sss --log --in-files <master> <args>` and reads back the
## results.
##   z, LD_file, n, k, prefix: see write_finemap_sumstats()
##   args: extra command line options appended verbatim
## Returns list(snp = ranked per-SNP table, set = configuration table,
## ncausal = number-of-causal table parsed from the log file).

run_finemap <- function(z, LD_file, n, k, args = "", prefix="data")
{
  cfg = write_finemap_sumstats(z, LD_file, n, k, prefix)
  cmd = paste("finemap --sss --log", "--in-files", cfg$meta, args)
  dscrutils::run_cmd(cmd)

  # read output tables
  snp = read.table(cfg$snp,header=TRUE,sep=" ")
  snp$snp = as.character(snp$snp)

  # sort by posterior probability and add rank / cumulative probability
  snp = rank_snp(snp)
  config = read.table(cfg$config,header=TRUE,sep=" ")

  # extract number of causal
  ncausal = finemap_extract_ncausal(cfg$log)
  return(list(snp=snp, set=config, ncausal=ncausal))
}

## Order the FINEMAP per-SNP table by decreasing `snp_prob` and add `rank`
## (1 = highest probability) and `snp_prob_cumsum`, the cumulative share of
## the total probability. Only the columns selected below are kept.
rank_snp <- function(snp) {
  by_prob <- arrange(snp, -snp_prob)
  ranked <- mutate(by_prob,
                   rank = seq(1, n()),
                   snp_prob_cumsum = cumsum(snp_prob) / sum(snp_prob))
  select(ranked, rank, snp, snp_prob, snp_prob_cumsum, snp_log10bf)
}

## Parse the number-of-causal-SNPs table out of a FINEMAP log file.
## Every line containing "->" is used: parentheses and ">" are removed, the
## line is split on whitespace, and fields 2 and 4 are read as the number of
## causal SNPs and its probability. The first occurrence of a given number is
## labelled "prior", later occurrences "post".
finemap_extract_ncausal <- function(logfile)
{
  arrow_lines <- grep("->", readLines(logfile), value = TRUE)
  cleaned <- gsub("\\(|\\)|>", "", arrow_lines)
  fields <- strsplit(cleaned, "\\s+")
  field_as <- function(idx, convert) sapply(fields, function(x) convert(x[idx]))
  tab <- data.frame(
    ncausal_num = field_as(2, as.integer),
    ncausal_prob = field_as(4, as.double))
  mutate(tab, type = ifelse(duplicated(ncausal_num), "post", "prior"))
}

## Run FINEMAP independently on every column of `zscore`, in parallel on at
## most 8 cores. Files for column r use the prefix '<prefix>_condition_<r>'.
## Returns a list with one run_finemap() result per column.
finemap_mvar <- function(zscore, LD_file, n, k, args, prefix) {
  n_conditions <- ncol(zscore)
  fit_condition <- function(r) {
    run_finemap(zscore[,r], LD_file, n, k, args,
                paste0(prefix, '_condition_', r))
  }
  parallel::mclapply(1:n_conditions, fit_condition,
                     mc.cores = min(8, n_conditions))
}

### `fit_caviar.R`

`CAVIAR` output file (`*_post`): 
- column #1 is the variant name;
- column #2 is the [posterior prob. that the variant is causal](https://github.com/fhormoz/caviar/issues/1#issuecomment-286521771);
- column #3 is the amount that this variant contributes to 95%-causal credible set.

In [15]:
%save -f modules/fit_caviar.R
#' CAVIAR I/O
#' Write the z-score file for one CAVIAR run (row names kept, no header) and
#' return the paths of the input file and of the expected output files.
write_caviar_sumstats <- function(z, prefix) {
  with_suffix = function(suffix) paste0(prefix, suffix)
  cfg = list(z=with_suffix(".z"),
             set=with_suffix("_set"),
             post=with_suffix("_post"),
             log=with_suffix(".log"))
  write.table(z, file=cfg$z, quote=FALSE, col.names=FALSE)
  return(cfg)
}

#' Run CAVIAR
#' https://github.com/fhormoz/caviar
#'
#' Writes the z-score file, runs `CAVIAR -z <z> -l <LD_file> -o <prefix> <args>`
#' and reads back the `_post` and `_set` outputs.
#' Returns list(snp = ranked per-SNP table, set = SNP names of the reported
#' set ordered by rank). Stops if any of the post/set/log files is missing.

run_caviar <- function(z, LD_file, args = "", prefix="data")
{
  cfg = write_caviar_sumstats(z, prefix)
  cmd = paste("CAVIAR", "-z", cfg$z, "-l", LD_file, "-o", prefix, args)
  dscrutils::run_cmd(cmd)
  if(!all(file.exists(cfg$post, cfg$set, cfg$log))) {
      stop("Cannot find one of the post, set, and log files")
  }
  
  # the log is read but not used further in this function
  log <- readLines(cfg$log)

  # read output tables
  snp <- read.delim(cfg$post)  
  stopifnot(ncol(snp) == 3)
  # NOTE(review): column 2 is named snp_prob_set and column 3 snp_prob here;
  # confirm against the CAVIAR output header, since the notebook text above
  # describes the two columns in the opposite order.
  names(snp) <- c("snp", "snp_prob_set", "snp_prob")
  snp$snp <- as.character(snp$snp)
  snp <- rank_snp(snp)

  # `set` of snps
  set <- readLines(cfg$set)
  # attach the ranks to the set members and return their names in rank order
  set_ordered <- left_join(data_frame(snp = set), snp, by = "snp") %>% 
    arrange(rank) %$% snp
  return(list(snp=snp, set=set_ordered))
}

## Order the CAVIAR per-SNP table by decreasing `snp_prob` and add `rank`
## (1 = highest probability) and `snp_prob_cumsum`, the cumulative share of
## the total probability. Only the columns selected below are kept.
rank_snp <- function(snp) {
  by_prob <- arrange(snp, -snp_prob)
  ranked <- mutate(by_prob,
                   rank = seq(1, n()),
                   snp_prob_cumsum = cumsum(snp_prob) / sum(snp_prob))
  select(ranked, rank, snp, snp_prob, snp_prob_cumsum, snp_prob_set)
}

## Run CAVIAR independently on every column of `zscore`, in parallel on at
## most 8 cores. Files for column r use the prefix '<prefix>_condition_<r>'.
## Returns a list with one run_caviar() result per column.
finemap_mcaviar <- function(zscore, LD_file, args, prefix) {
  n_conditions <- ncol(zscore)
  fit_condition <- function(r) {
    run_caviar(zscore[,r], LD_file, args,
               paste0(prefix, '_condition_', r))
  }
  parallel::mclapply(1:n_conditions, fit_condition,
                     mc.cores = min(8, n_conditions))
}

## Visualization

### `plot_finemap.R`

In [16]:
%save -f modules/plot_finemap.R

## Three-panel summary of one FINEMAP result: number of causal SNPs (A),
## top configurations (B) and top SNPs (C), arranged with cowplot::plot_grid.
##   x: list with elements ncausal, set and snp (as returned by run_finemap)
##   grid_nrow, grid_ncol: layout passed to plot_grid
##   label_size, top_rank, lim_prob: shared by the panels that use them
##   ...: forwarded to all three panel functions
plot_finemap <- function(x,
                         grid_nrow = NULL, 
                         grid_ncol = NULL, 
                         label_size = 2,
                         top_rank = 5,
                         lim_prob = c(0, 1.2),
                         ...)
{
  panel_ncausal <- plot_ncausal(x, lim_prob = lim_prob, ...)
  panel_config <- plot_set(x,
    top_rank = top_rank,
    label_size = label_size,
    lim_prob = lim_prob, ...)
  panel_snp <- plot_snp(x,
    top_rank = top_rank,
    label_size = label_size,
    lim_prob = lim_prob, ...)

  plot_grid(panel_ncausal, panel_config, panel_snp,
            labels = "AUTO", nrow = grid_nrow, ncol = grid_ncol)
}


## Bar chart of prior vs. posterior probability of the number of causal SNPs.
##   x: list whose `ncausal` element has columns ncausal_num, ncausal_prob and
##      type ("prior"/"post"), as built by finemap_extract_ncausal()
##   lim_prob: limits of the probability axis
## The row(s) for zero causal SNPs are dropped only when they carry no
## probability mass.
plot_ncausal <- function(x, lim_prob, ...)
{
  ptab <- x$ncausal
  
  # The probability column is `ncausal_prob`. The previous lookup of "prob"
  # returned NULL, whose sum is 0, so the zero-causal rows were always removed.
  sum_prop_zero <- filter(ptab, ncausal_num == 0)[["ncausal_prob"]]  %>% sum
  if(sum_prop_zero == 0) {
    ptab <- filter(ptab, ncausal_num != 0)
  }
  
  # largest number first in the factor levels; prior listed before post
  ptab <- mutate(ptab, 
    ncausal_num = factor(ncausal_num, levels = sort(unique(ncausal_num), 
                                                    decreasing = TRUE)),
    type = factor(type, levels = c("prior", "post")))
    
  p <- ggplot(ptab, aes(ncausal_num, ncausal_prob, fill = type)) + 
    geom_hline(yintercept = 1, linetype = 3) + 
    geom_bar(stat = "identity", position = "dodge") + 
    coord_flip() + theme(legend.position = "top") + 
    scale_fill_manual(values = c("grey50", "orange")) +
    ylim(lim_prob)
    
  return(p)
}

## Lollipop plot of the first `top_rank` rows of the FINEMAP configuration
## table `x$set`: configuration probability on the x axis, rank on a
## reversed y axis, each point labelled with the configuration, its
## probability and log10 Bayes factor.
plot_set <- function(x, lim_prob, label_size, top_rank, ...)
{
  top_configs <- mutate(head(x$set, top_rank),
    label = paste0(config, "\n", 
      "P = ", round(config_prob, 2),
      "; ", "log10(BF) = ", round(config_log10bf, 2)))

  y_limits <- c(top_rank + 0.5, 0.5)
  lollipop <- ggplot(top_configs, aes(config_prob, rank)) + 
    geom_vline(xintercept = 1, linetype = 3) + 
    geom_point() + 
    geom_segment(aes(xend = config_prob, yend = rank, x = 0)) + 
    geom_text(aes(label = label), hjust = 0, nudge_x = 0.025, size = label_size)
  lollipop + 
    xlim(lim_prob) + 
    scale_y_continuous(limits  = y_limits, trans = "reverse")
}


## Lollipop plot of the first `top_rank` rows of the FINEMAP per-SNP table
## `x$snp`: the rows are re-numbered 1..n, with SNP probability on the x axis
## and rank on a reversed y axis; each point is labelled with the SNP name,
## its probability and log10 Bayes factor.
plot_snp <- function(x, lim_prob, label_size, top_rank, ...)
{
  top_snps <- mutate(head(x$snp, top_rank),
    rank = seq(1, n()), 
    label = paste0(snp, "\n", 
      "P = ", round(snp_prob, 2),
      "; ", "log10(BF) = ", round(snp_log10bf, 2)))

  y_limits <- c(top_rank + 0.5, 0.5)
  lollipop <- ggplot(top_snps, aes(snp_prob, rank)) +
    geom_vline(xintercept = 1, linetype = 3) + 
    geom_point() + 
    geom_segment(aes(xend = snp_prob, yend = rank, x = 0)) + 
    geom_text(aes(label = label), hjust = 0, nudge_x = 0.025, size = label_size)
  lollipop + 
    xlim(lim_prob) + 
    scale_y_continuous(limits  = y_limits, trans = "reverse")
}

# Driver: one PDF page per element of the DSC-provided `result` list,
# written to `plot_file`.
# NOTE(review): 1:length(result) iterates over c(1, 0) when `result` is empty.
pdf(plot_file)
for (r in 1:length(result)) {
    print(plot_finemap(result[[r]], top_rank = top_rank))
}
dev.off()

### `plot_caviar.R`

In [17]:
%save -f modules/plot_caviar.R
## Plot the top SNPs of one CAVIAR result (a single panel).
##   x: list with element `snp`, as returned by run_caviar
##   grid_nrow, grid_ncol: accepted for signature parity with the other
##     plotting entry points; not used here
##   label_size, top_rank, lim_prob, ...: forwarded to plot_snp
plot_caviar <- function(x,
                        grid_nrow = NULL, 
                        grid_ncol = NULL, 
                        label_size = 2,
                        top_rank = 5,
                        lim_prob = c(0, 1.5),
                        ...)
{
  plot_snp(x,
           label_size = label_size,
           top_rank = top_rank,
           lim_prob = lim_prob,
           ...)
}

## Lollipop plot of the first `top_rank` rows of the CAVIAR per-SNP table
## `x$snp`: SNP probability on the x axis and the existing `rank` column on a
## reversed y axis; each point is labelled with the SNP name, `snp_prob` and
## `snp_prob_set`.
plot_snp <- function(x, label_size, top_rank, lim_prob, ...)
{
  top_snps <- mutate(head(x$snp, top_rank),
    label = paste0(snp, "\n", 
      "P = ", round(snp_prob, 2),
      "; ", "P(set) = ", round(snp_prob_set, 2)))

  y_limits <- c(top_rank + 0.5, 0.5)
  lollipop <- ggplot(top_snps, aes(snp_prob, rank)) +
    geom_vline(xintercept = 1, linetype = 3) + 
    geom_point() + 
    geom_segment(aes(xend = snp_prob, yend = rank, x = 0)) + 
    geom_text(aes(label = label), hjust = 0, nudge_x = 0.025, size = label_size)
  lollipop + 
    xlim(lim_prob) + 
    scale_y_continuous(limits  = y_limits, trans = "reverse")
}

# Driver: one PDF page per element of the DSC-provided `result` list,
# written to `plot_file`.
# NOTE(review): 1:length(result) iterates over c(1, 0) when `result` is empty.
pdf(plot_file)
for (r in 1:length(result)) {
    print(plot_caviar(result[[r]], top_rank = top_rank))
}
dev.off()

### `plot_dap.R`

In [18]:
%save -f modules/plot_dap.R


## Two-panel summary of one DAP result: signal clusters (A) and top SNPs (B),
## arranged with cowplot::plot_grid (2 rows x 1 column by default).
##   x: list with elements `set` and `snp`, as returned by dap_single
##   label_size, top_rank, lim_prob: shared by both panels
##   ...: forwarded to both panel functions
plot_dap <- function(x,
                     grid_nrow = 2, 
                     grid_ncol = 1, 
                     label_size = 2,
                     top_rank = 5,
                     lim_prob = c(0, 1.2),
                     ...)
{
  panel_cluster <- plot_set(x,
    top_rank = top_rank,
    label_size = label_size,
    lim_prob = lim_prob, ...)
  panel_snp <- plot_snp(x,
    top_rank = top_rank,
    label_size = label_size,
    lim_prob = lim_prob, ...)

  plot_grid(panel_cluster, panel_snp,
            labels = "AUTO", nrow = grid_nrow, ncol = grid_ncol)
}


## Lollipop plot of the first `top_rank` rows of the DAP cluster table
## `x$set`: cluster probability on the x axis, cluster id on a reversed y
## axis whose upper limit is the smaller of `top_rank` and the number of rows
## shown; each point is labelled with the member SNPs, the cluster
## probability and the average r^2.
plot_set <- function(x, lim_prob, label_size, top_rank, ...)
{
  top_clusters <- mutate(head(x$set, top_rank),
    label = paste0(snp, "\n", 
      "P = ", round(cluster_prob, 2),
      "; ", "avg(r^2) = ", round(cluster_avg_r2, 2)))

  y_limits <- c(min(top_rank, nrow(top_clusters)) + 0.5, 0.5)
  lollipop <- ggplot(top_clusters, aes(cluster_prob, cluster)) + 
    geom_vline(xintercept = 1, linetype = 3) + 
    geom_point() + 
    geom_segment(aes(xend = cluster_prob, yend = cluster, x = 0)) + 
    geom_text(aes(label = label), hjust = 0, nudge_x = 0.025, size = label_size)
  lollipop + 
    xlim(lim_prob) + 
    scale_y_continuous(limits  = y_limits, trans = "reverse")
}


## Lollipop plot of the first `top_rank` rows of the DAP per-SNP table
## `x$snp`: the rows are numbered 1..n in their existing order, with SNP
## probability on the x axis and rank on a reversed y axis; each point is
## labelled with the SNP name, its probability and log10 Bayes factor.
plot_snp <- function(x, lim_prob, label_size, top_rank, ...)
{
  top_snps <- mutate(head(x$snp, top_rank),
    rank = seq(1, n()), 
    label = paste0(snp, "\n", 
      "P = ", round(snp_prob, 2),
      "; ", "log10(BF) = ", round(snp_log10bf, 2)))

  y_limits <- c(top_rank + 0.5, 0.5)
  lollipop <- ggplot(top_snps, aes(snp_prob, rank)) +
    geom_vline(xintercept = 1, linetype = 3) + 
    geom_point() + 
    geom_segment(aes(xend = snp_prob, yend = rank, x = 0)) + 
    geom_text(aes(label = label), hjust = 0, nudge_x = 0.025, size = label_size)
  lollipop + 
    xlim(lim_prob) + 
    scale_y_continuous(limits  = y_limits, trans = "reverse")
}

# Driver: one PDF page per element of the DSC-provided `result` list,
# written to `plot_file`.
# NOTE(review): 1:length(result) iterates over c(1, 0) when `result` is empty.
pdf(plot_file)
for (r in 1:length(result)) {
    print(plot_dap(result[[r]], top_rank = top_rank))
}
dev.off()

## Simulation under regression model

### `lib_regression_simulator.py`

- `RegressionData`: Stores multivariate $Y$ and multiple feature $X$ data.
- `UnivariateMixture`: Simulating univariate effects with mixture distribution of effects: $\beta$ are sampled from normal mixtures as described in Stephens 2017 the ASH paper.
- `MultivariateMixture`: Multivariate mixture of Urbut 2017 the MASH paper.

In [45]:
%save -f modules/lib_regression_simulator.py
import numpy as np
import os, copy
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
from pprint import pformat
from collections import OrderedDict

class dotdict(dict):
    """dict whose items can also be read, set and deleted as attributes.

    ``d.key`` is ``d['key']``; a missing key raises ``AttributeError`` on
    attribute read so that ``getattr``/``hasattr`` behave as usual.
    Attribute assignment and deletion map directly to item assignment and
    deletion.
    """
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    def __getattr__(self, item):
        # Only reached when normal attribute lookup fails.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(item) from None

    def __deepcopy__(self, memo):
        """Deep-copy the items into a new instance of the same class.

        The clone is registered in ``memo`` before its items are copied so
        that self-referencing structures terminate, and it is created with
        ``__new__`` so subclasses keep their type without their ``__init__``
        being re-run.
        """
        clone = self.__class__.__new__(self.__class__)
        memo[id(self)] = clone
        for key, value in self.items():
            clone[copy.deepcopy(key, memo)] = copy.deepcopy(value, memo)
        return clone
    
class RegressionData(dotdict):
    '''
    Container for regression inputs and derived quantities.

    Items (all reachable as attributes through dotdict):
    - X, Y, Z: feature, response and covariate arrays (or None)
    - xcorr: matrix set by `set_xcorr`
    - Bhat, Shat: set by `get_summary_stats`
    - x_centered, y_centered, z_centered: flags maintained by `center_data`
    - debug: free-form dotdict
    '''
    def __init__(self, X = None, Y = None, Z = None):
        # FIXME: check if inputs are indeed numpy arrays
        self.debug = dotdict()
        self.x_centered = self.y_centered = self.z_centered = False
        # Store the constructor arguments (they used to be silently dropped).
        self.X = X
        self.Y = Y
        self.Z = Z
        self.xcorr = None

    def get_summary_stats(self):
        '''
        Compute univariate regression for every X_j (N by 1) and Y_r (N by 1)
        Bhat: J by R matrix of estimated effects
        Shat: J by R matrix of SE of Bhat
        If Z is set, it is regressed out of Y first (see `remove_covariates`).
        No intercept is fitted here.
        '''
        if self.Z is not None:
            self.remove_covariates()
        # Compute betahat
        XtX_vec = np.einsum('ji,ji->i', self.X, self.X)
        self.Bhat = (self.X.T @ self.Y) / XtX_vec[:,np.newaxis]
        # Compute se(betahat)
        # Residuals of each single-feature fit, shape (J, N, R); must use the
        # estimates just computed (self.Bhat).
        Xr = self.Y - np.einsum('ij,jk->jik', self.X, self.Bhat)
        Re = np.einsum('ijk,ijk->ik', Xr, Xr)
        self.Shat = np.sqrt(Re / XtX_vec[:,np.newaxis] / (self.X.shape[0] - 2))

    def remove_covariates(self):
        '''
        Replace Y (in place) by its residuals after least-squares regression
        on Z, then drop Z so the adjustment is not applied twice.
        '''
        if self.Z is not None:
            self.Y -= self.Z @ (np.linalg.inv(self.Z.T @ self.Z) @ self.Z.T @ self.Y)
            self.Z = None

    def center_data(self):
        '''
        Subtract column means from X, Y and Z in place. Each matrix is
        centered at most once; the *_centered flags record what was done.
        '''
        # for np.array: np.mean(Z, axis=0, keepdims=True)
        # for np.matrix, no keepdims argument
        if self.X is not None and not self.x_centered:
            self.X -= np.mean(self.X, axis=0)
            self.x_centered = True
        if self.Y is not None and not self.y_centered:
            self.Y -= np.mean(self.Y, axis=0)
            self.y_centered = True
        if self.Z is not None and not self.z_centered:
            self.Z -= np.mean(self.Z, axis=0)
            self.z_centered = True

    def set_xcorr(self, xcorr):
        '''
        Set self.xcorr. A given `xcorr` is stored as an array unchanged;
        otherwise the column correlation of X is computed and stored as
        signed squared correlation, r^2 * sign(r), in float16.
        '''
        if xcorr is not None:
            self.xcorr = np.array(xcorr)
        else:
            self.xcorr = np.corrcoef(self.X, rowvar = False)
            self.xcorr = (np.square(self.xcorr) * np.sign(self.xcorr)).astype(np.float16)

    def plot_xcorr(self, out, limit = 5000):
        '''
        Save a heatmap of the leading `limit` by `limit` block of xcorr to
        `out`. Raises ValueError when `out` ends with "pdf".
        '''
        # Validate before creating any figure.
        if out.endswith('pdf'):
            raise ValueError('Please use png extension for output file.')
        use_abs = np.sum(self.xcorr < 0) == 0
        limit = min(self.xcorr.shape[0], limit)
        fig, ax = plt.subplots()
        print(f'Plotting figure {out} for {limit} markers (default limit set to 5000) ...')
        cmap = sns.cubehelix_palette(50, hue=0.05, rot=0, light=1, dark=0, as_cmap=True)
        # Start at index 0: a 1-based slice would drop the first marker.
        sns.heatmap(self.xcorr[:limit,:limit], ax = ax, cmap = cmap, vmin=-1 if not use_abs else 0,
                    vmax=1, square=True, xticklabels = False, yticklabels = False)
        print(f'Saving figure {out} ...')
        fig.savefig(out, dpi = 500)
        # Release the figure; repeated calls would otherwise accumulate figures.
        plt.close(fig)

    def permute_X_columns(self):
        '''
        Permute X columns, i.e. break blocked correlation structure
        X is modified in place. NOTE: self.xcorr is not updated here; call
        `set_xcorr` again if it is needed after permuting.
        '''
        # np.random.shuffle(self.X) would shuffle rows (axis 0), so shuffle a
        # column index and reorder the columns with it instead.
        columns = np.arange(self.X.shape[1])
        np.random.shuffle(columns)
        self.X[:, :] = self.X[:, columns]

    def plot_property_vector(self, yaxis, zaxis, xz_cutoff = None, out = '/tmp/1.pdf',
                            conf = {'title': '', 'ylabel': '', 'zlabel': ''}):
        r'''
        Scatter plot of `yaxis` against position (1, 2, ...), saved to `out`.
        - yaxis can be eg $\beta$ or log10BF or -log10Prob
        - zaxis can be some other quantity whose value will be
        reflected by color shade
        - xz_cutoff: (c1, c2). Positions whose zaxis value exceeds c2 are
        circled (skipped when more than 100 qualify); for each of them,
        positions whose xcorr with it lies strictly between c1 and 1 are
        marked with "+". Requires self.xcorr to be set.
        - conf: 'title', 'ylabel' and 'zlabel' strings; missing keys default
        to empty. The default dict is never modified.
        '''
        xaxis = [x+1 for x in range(len(yaxis))]
        cmap = sns.cubehelix_palette(start=2.8, rot=.1, as_cmap=True)
        f, ax = plt.subplots(figsize=(18,5))
        if zaxis is not None:
            points = ax.scatter(xaxis, yaxis, c=zaxis, cmap=cmap)
            f.colorbar(points, label=conf.get('zlabel', ''))
        else:
            # No color values, so a colormap would be ignored anyway.
            ax.scatter(xaxis, yaxis)
        if xz_cutoff is not None and zaxis is not None:
            c1, c2 = xz_cutoff
            if sum(1 for i in zaxis if i > c2) > 100:
                print('Too many to highlight!')
            else:
                for idx, item in enumerate(zaxis):
                    if item > c2:
                        ax.scatter(xaxis[idx], yaxis[idx], s=80,
                                   facecolors='none', edgecolors='r')
                        for ii, xx in enumerate(self.xcorr[idx,:]):
                            if xx > c1 and xx < 1.0:
                                ax.scatter(xaxis[ii], yaxis[ii],
                                           color='y', marker='+')
        ax.set_title(conf.get('title', ''))
        ax.set_ylabel(conf.get('ylabel', ''))
        f.savefig(out, dpi = 500)
        # Release the figure; this method is called once per response.
        plt.close(f)

    def get_representative_features(self, block_r2 = 0.8, block_size = 10, max_indep_r2 = 0.02):
        '''
        Based on xcorr matrix, select "most representative features".
        That is, these features are potentially most convoluted by other features (have stronger xcorr)
        yet are independent among each other.
        - block_r2: only xcorr entries with absolute value > block_r2 contribute to a feature's score
        - block_size: a feature is kept as candidate only if its score is > block_size
        - max_indep_r2: candidates are greedily thinned, in decreasing score order,
        so that selected features have abs(xcorr) <= max_indep_r2 with each other
        Computes xcorr from X when it has not been set. Returns a list of column indices.
        '''
        if self.xcorr is None:
            self.set_xcorr(None)
        # get r2 summary
        r2 = pd.DataFrame(self.xcorr)
        # NOTE(review): despite its name this is the SUM of the signed strong
        # entries per column, not the number of them -- confirm that comparing
        # it with `block_size` is intended.
        strong_r2_count = ((np.absolute(r2) > block_r2) * r2).sum(axis = 0).sort_values(ascending = False)
        strong_r2_count = strong_r2_count[strong_r2_count > block_size]
        # filter by r2
        exclude = set()
        for x in strong_r2_count.index:
            if x in exclude:
                continue
            for y in strong_r2_count.index:
                if y in exclude or y == x:
                    continue
                if np.absolute(r2[x][y]) > max_indep_r2:
                    exclude.add(y)
        return [x for x in strong_r2_count.index if x not in exclude]

    def __str__(self):
        # Data live in the dict items (dotdict routes attribute writes there);
        # the instance __dict__ is always empty.
        return pformat(dict(self), indent = 4)
    
class ResidualVariance:
    '''Builds a residual covariance matrix according to `mode`.

    Supported modes: 'identity'.
    '''
    def __init__(self, mode):
        self.mode = mode

    def apply(self, eff_obj):
        '''
        Return the residual covariance for `eff_obj`, sized by `eff_obj.R`.
        Raises ValueError for a mode that is not implemented, instead of
        silently returning None.
        '''
        if self.mode == 'identity':
            return np.identity(eff_obj.R)
        raise ValueError(f'Residual variance mode {self.mode} is not implemented.')
    
class UnivariateMixture:
    '''Simulated distributions of Stephens 2017 (ASH paper)

    Effects are drawn from pi0 * delta_0 + (1 - pi0) * sum_i pis[i] * N(mus[i], sigmas[i]^2);
    `sigmas` are standard deviations. Call one of the set_* methods to pick
    a scenario, then `get_effects` to fill `coef` (length `dim`).
    '''
    def __init__(self, dim):
        self.size = dim
        self.pi0 = 0
        self.pis = []
        self.mus = []
        self.sigmas = []
        self.coef = []

    def set_pi0(self, pi0):
        '''Set the probability that an effect is exactly zero.'''
        self.pi0 = pi0

    def set_spiky(self):
        self.pis = [0.4,0.2,0.2,0.2]
        self.mus = [0,0,0,0]
        self.sigmas = [0.25,0.5,1,2]

    def set_near_normal(self):
        self.pis = [2/3,1/3]
        self.mus = [0,0]
        self.sigmas = [1,2]

    def set_flat_top(self):
        self.pis = [1/7] * 7
        self.mus = [-1.5, -1, -.5 , 0, .5, 1, 1.5]
        self.sigmas = [0.5] * 7

    def set_skew(self):
        self.pis = [1/4,1/4,1/3,1/6]
        self.mus = [-2,-1,0,1]
        self.sigmas = [2,1.5,1,1]

    def set_big_normal(self):
        self.pis = [1]
        self.mus = [0]
        self.sigmas = [4]

    def set_bimodal(self):
        self.pis = [0.5, 0.5]
        self.mus = [-2, 2]
        self.sigmas = [1, 1]

    def get_effects(self):
        r'''
        beta ~ \pi_0\delta_0 + \sum \pi_i N(mu_i, sigma_i^2)
        Fills self.coef with self.size draws.
        '''
        # multivariate_normal takes a covariance, so the standard deviations
        # must be squared (as printed by __str__).
        cov = np.diag(np.square(self.sigmas))
        assert (len(self.pis), len(self.pis)) == cov.shape
        # One-hot component membership per effect.
        masks = np.random.multinomial(1, self.pis, size = self.size)
        # A draw from every component per effect; the mask keeps one of them.
        mix = np.random.multivariate_normal(self.mus, cov, self.size)
        self.coef = np.sum(mix * masks, axis = 1) * np.random.binomial(1, 1 - self.pi0, self.size)

    def swap_top_effects(self, given_index):
        '''Set top effects to given indices
        One can specify index, or use the "top_index"
        generated by RegressionData.get_representative_features()
        The largest effects (by absolute value) go to `given_index` in
        order; the remaining effects are shuffled over the other positions.
        '''
        given_index = np.array(given_index, dtype=int)
        nb = np.zeros(len(self.coef))
        beta = sorted(self.coef, key=abs, reverse=True)
        for idx in given_index:
            nb[idx] = beta.pop(0)
        # Use numpy's generator so np.random.seed controls the whole simulation.
        np.random.shuffle(beta)
        pinned = set(given_index.tolist())
        for idx in range(len(nb)):
            if idx not in pinned:
                nb[idx] = beta.pop(0)
        # Every effect must have been placed exactly once.
        assert len(beta) == 0
        self.coef = nb

    def sparsify_effects(self, num_non_zero):
        '''
        only keep top `num_non_zero` effects
        (by absolute value); all other entries of self.coef are set to 0.
        '''
        big_beta_index = [i[0] for i in sorted(enumerate(self.coef), key = lambda x: np.absolute(x[1]), reverse = True)]
        selected_index = set(big_beta_index[:min(len(big_beta_index), num_non_zero)])
        for j in range(len(self.coef)):
            if j not in selected_index:
                self.coef[j] = 0

    def get_y(self, regression_data, sigma):
        '''
        Simulate y = X coef + e with e ~ N(0, sigma^2) independently per
        sample. Returns an N by 1 array.
        '''
        y = np.dot(regression_data.X, self.coef.T) + np.random.normal(0, sigma, regression_data.X.shape[0])
        # reshape returns a new array; its result has to be returned.
        return y.reshape(len(y), 1)

    def __str__(self):
        params = ' + '.join(["{} N({}, {}^2)".format(x,y,z) for x, y, z in zip(self.pis, self.mus, self.sigmas)])
        return '{:.3f} \\delta_0 + {:.3f} [{}]'.format(self.pi0, 1 - self.pi0, params)
    
class MultivariateMixture:
    '''FIXME: ideally implement Urbut 2017 simulated covs

    Simulates a J by R effect matrix whose rows are drawn from a mixture of
    zero-mean multivariate normals with covariances `Us` and weights `pis`.
    `dim` is the pair (J, R). Typical use: pick weights with one of the
    set_* methods, optionally `apply_grid`, then `get_effects`.
    '''
    def __init__(self, dim):
        self.J, self.R = dim
        self.pis = OrderedDict([('null', 0)])
        self.Us = OrderedDict([('null', np.zeros((self.R, self.R)))])
        self.mus = dict([('zeros', np.zeros(self.R))])
        self.coef = []
        self.grid = [0.1,0.5,1,2]
        self._init_canonical()

    def set_pi0(self, pi0):
        '''Set the weight of the null (all-zero) component.'''
        self.pis['null'] = pi0

    def set_grid(self, grid):
        '''Set the scaling grid used by `apply_grid`.'''
        self.grid = grid

    def _init_canonical(self):
        '''
        U is a dict of
        - "identity" for the identity (effects are independent among conditions);
        - "singletons" for the set of matrices with just one non-zero entry x_{jj} = 1 (j=1,...,R); (effect specific to condition j);
        - "equal_effects" for the matrix of all 1s (effects are equal among conditions);
        - "simple_het" for a set of matrices with 1s on the diagonal and all off-diagonal elements equal to rho; (effects are correlated among conditions).
        '''
        rho = [0.25, 0.5, 0.75]
        self.Us['identity'] = np.identity(self.R)
        for i in range(self.R):
            self.Us[f'singleton_{i+1}'] = np.diagflat([1 if idx == i else 0 for idx in range(self.R)])
        self.Us['equal_effects'] = np.ones((self.R, self.R))
        for idx, item in enumerate(sorted(rho)):
            self.Us[f'simple_het_{idx+1}'] = np.ones((self.R, self.R)) * item
            np.fill_diagonal(self.Us[f'simple_het_{idx+1}'], 1)

    def _zero_missing_weights(self):
        '''Give weight 0 to every covariance that has no weight yet.'''
        for k in self.Us:
            if k not in self.pis:
                self.pis[k] = 0

    def set_shared(self):
        '''
        All weights are on equal effects
        '''
        self.pis['equal_effects'] = 1 - self.pis['null']
        self._zero_missing_weights()

    def set_low_het(self):
        '''
        All weights are on small het effects
        '''
        self.pis['simple_het_1'] = 1 - self.pis['null']
        self._zero_missing_weights()

    def set_indep(self):
        '''
        All weights are on identity effects
        '''
        self.pis['identity'] = 1 - self.pis['null']
        self._zero_missing_weights()

    def set_singleton(self, index):
        '''
        All weights evenly set to given index of singleton effects
        `index` holds 1-based condition numbers; values outside 1..R are
        ignored. Raises ValueError when no valid index remains.
        '''
        # Singletons are named singleton_1 .. singleton_R, so 1 is valid.
        index = [int(x) for x in index if 1 <= x <= self.R]
        if not index:
            raise ValueError(f'No valid singleton index given; expected values between 1 and {self.R}')
        weight = (1 - self.pis['null']) / len(index)
        for item in index:
            self.pis[f'singleton_{item}'] = weight
        self._zero_missing_weights()

    def apply_grid(self):
        '''
        Expand every non-null covariance U into one copy per grid value g,
        scaled as U * g^2 and named "<name>_<k>" (k = 1, 2, ...); the weight
        of each component is split evenly over its scaled copies. The null
        component is left unchanged.
        '''
        scales = np.square(self.grid)
        scaled = OrderedDict()
        for name, U in self.Us.items():
            if name == 'null':
                continue
            for i, s in enumerate(scales):
                scaled[f"{name}_{i+1}"] = U * s
        scaled['null'] = self.Us['null']
        self.Us = scaled
        nG = len(self.grid)
        # Iterate over a snapshot of the keys: entries are added and removed.
        for k in list(self.pis.keys()):
            if k == 'null':
                continue
            for g in range(nG):
                self.pis[f'{k}_{g+1}'] = self.pis[k] / nG
            del self.pis[k]

    def get_effects(self):
        r'''
        Generate B under multivariate normal mixture
        beta ~ \pi_0\delta_0 + \sum \pi_i N(0, U_i)
        Fills self.coef (J by R), one independent draw per row.
        '''
        names = list(self.pis.keys())
        weights = list(self.pis.values())
        self.coef = np.zeros((self.J, self.R))
        for j in range(self.J):
            # sample distribution
            dist_index = np.random.multinomial(1, weights, size = 1).tolist()[0].index(1)
            self.coef[j,:] = np.random.multivariate_normal(self.mus['zeros'], self.Us[names[dist_index]], 1)

    def sparsify_effects(self, num_non_zero):
        '''
        only keep top `num_non_zero` effects
        Rows are ranked by their largest absolute entry; all other rows are
        set to zero.
        '''
        beta_max = np.amax(np.absolute(self.coef), axis = 1)
        big_beta_index = [i[0] for i in sorted(enumerate(beta_max), key = lambda x: x[1], reverse = True)]
        selected_index = set(big_beta_index[:min(len(big_beta_index), num_non_zero)])
        for j in range(self.J):
            if j not in selected_index:
                self.coef[j,:] = self.mus['zeros']

    def swap_top_effects(self, given_index):
        '''Set top effects to given indices
        One can specify index, or use the "top_index"
        generated by RegressionData.get_representative_features()
        Rows are ranked by their largest absolute entry; the top rows go to
        `given_index` in order and the remaining rows fill the other
        positions in decreasing rank order (they are not shuffled).
        '''
        given_index = np.array(given_index, dtype=int)
        nb = np.zeros(self.coef.shape)
        beta_max = np.amax(np.absolute(self.coef), axis = 1)
        big_beta_index = [i[0] for i in sorted(enumerate(beta_max), key = lambda x: x[1], reverse = True)]
        for idx in given_index:
            nb[idx,:] = self.coef[big_beta_index.pop(0),:]
        pinned = set(given_index.tolist())
        for idx in range(nb.shape[0]):
            if idx not in pinned:
                nb[idx,:] = self.coef[big_beta_index.pop(0),:]
        self.coef = nb

    def get_y(self, regression_data, sigma_mat):
        '''
        Simulate Y = X coef + E, where each row of E is an independent draw
        from N(0, sigma_mat). Returns an N by R array.
        '''
        n_samples = regression_data.X.shape[0]
        # One noise vector per sample; without `size` a single vector would
        # be broadcast to every row.
        noise = np.random.multivariate_normal(np.zeros(self.R), sigma_mat, size = n_samples)
        return regression_data.X @ self.coef + noise

### `regression_simulator.py`

Simulator workhorses.

In [44]:
%save -f modules/regression_simulator.py
def summarize_LD(X, ld_input, ld_plot):
    '''
    Wrap `X` in a RegressionData, set its xcorr from `ld_input`, plot the
    xcorr matrix to `ld_plot`, and return the representative features
    selected with the default thresholds.
    '''
    ld_summary = RegressionData()
    ld_summary.X = X
    ld_summary.set_xcorr(ld_input)
    ld_summary.plot_xcorr(ld_plot)
    representative = ld_summary.get_representative_features()
    return representative

def simulate_main(data, c, plot_prefix):
    '''
    data: $data
    top_idx: $top_eff
    n_signal: 3
    n_traits: 2
    eff_mode: mash_low_het
    swap_eff: raw(True)
    keep_ld: raw(True)
    tag: sim1
    @ALIAS: conf = Dict(!data, !eff_mode)
    $data: data
    '''
    # Entry point of the simulation: fills data['true_coef'], data['X'] and
    # data['Y'] according to `eff_mode`, then writes one effect-size plot per
    # response to '<plot_prefix>.<j>.pdf' when true effects exist.
    # `eff_mode` is read as a module-level name rather than from `c`.
    sim = RegressionData()
    sim.X = data['X']
    if c['swap_eff'] and c['top_idx'] is None:
        raise ValueError(f'"top_idx" variable is not set by an upstream module')
    if eff_mode == 'original':
        data['true_coef'] = original_y(data, sim, c)
    elif eff_mode == 'mash_low_het':
        if c['n_traits'] < 2:
            raise ValueError(f'Cannot simulate {c["n_traits"]} under mode {eff_mode}')
        data['true_coef'] = mash_low_het(data, sim, c)
    else:
        raise ValueError(f'Mode {eff_mode} is not implemented.')
    if c['center_data']:
        sim.center_data()
    data['X'] = sim.X
    data['Y'] = sim.Y
    # Nothing to plot when the responses were taken from the data as-is.
    if data['true_coef'] is None:
        return data
    n_responses = data['true_coef'].shape[1]
    for j in range(n_responses):
        labels = {'title': f'Response {j+1}',
                  'ylabel': 'effect size', 'zlabel': ''}
        # Color each position by whether its true effect is non-zero.
        is_signal = [np.absolute(x)>0 for x in data['true_coef'][:,j]]
        sim.plot_property_vector(data['true_coef'][:,j], is_signal,
                                 xz_cutoff = None, out = f'{plot_prefix}.{j+1}.pdf',
                                 conf = labels)
    return data
        
def original_y(data, reg, c):
    '''
    Use the responses already present in `data` instead of simulating them.

    data['Y'] is a mapping whose values are stacked as the columns of
    reg.Y. Returns None because no true effects are known; `c` is unused.
    '''
    # np.vstack needs a real sequence; a dict view is not one.
    reg.Y = np.vstack(list(data['Y'].values())).T
    return None
    
def mash_low_het(data, reg, c):
    '''
    Simulate responses under the "simple_het_1" mixture component.

    Reads from `c`: keep_ld, n_traits, swap_eff, top_idx, n_signal and
    residual_mode. Sets reg.Y (and data['X'] when columns are permuted) and
    returns the simulated effect matrix.
    '''
    if not c['keep_ld']:
        # The method is named permute_X_columns.
        reg.permute_X_columns()
        data['X'] = reg.X
    eff = MultivariateMixture((data['X'].shape[1], c['n_traits']))
    eff.set_low_het()
    eff.apply_grid()
    eff.get_effects()
    if c['swap_eff']:
        # Move the largest effects onto the requested feature indices.
        eff.swap_top_effects(c['top_idx'])
    eff.sparsify_effects(c['n_signal'])
    reg.Y = eff.get_y(reg, ResidualVariance(c['residual_mode']).apply(eff))
    return eff.coef