# SusieR benchmark plot


This is a continuation of `20180515_SusieR_Benchmark.ipynb`. Here I write a workflow to generate all plots from that benchmark, and make a table out of them to browse them easily.

In [1]:
%cd ~/GIT/github/mvarbvs/dsc
# The workflow below uses paths relative to this directory (the default `outdir` is ./benchmark).

/home/gaow/GIT/github/mvarbvs/dsc

In [None]:
[global]
# Directory of the DSC benchmark; every file this workflow creates is written under it.
parameter: outdir = path('./benchmark')
[1]
# Query the DSC benchmark for the targets listed below and cache the resulting table as result.RDS.
# `target` is pasted verbatim into the R call as a single space-separated string.
# NOTE(review): the `:br` format on outdir presumably renders its quoted base name - confirm against the SoS docs.
target = "liter_data.dataset simple_lm.pve simple_lm.n_signal fit_susie.estimate_residual_variance fit_susie.prior_var fit_susie eval_susie"
output: f'{outdir}/result.RDS'
R: expand = '${ }'
    # Run the query through dscrutils and save the returned object for the later steps.
    out = dscrutils::dscquery(${outdir:br}, target = "${target}")
    saveRDS(out, ${_output:r})
  
[2]
# Combine the benchmark results of each parameter setting (pve, n, est_res, prior) into one RDS file.
# Values looped over by `for_each`; they are compared against the query table columns in the R script.
pve = [0.05, 0.1, 0.2, 0.4]
n = [1,2,3,4,5]
est_res = ['TRUE', 'FALSE']
prior = [0.05, 0.1, 0.2, 0.4]
ld_col = 1 # LD_Min
# Sets with lfsr below this value count towards `total_captures` in the R script.
lfsr_cutoff = 0.05
# One numbered output file per parameter combination (product of the four list lengths).
combos = len(pve) * len(n) * len(est_res) * len(prior)
output_files = [f'{outdir}/{x+1}.rds' for x in range(combos)]
input: for_each = ['pve', 'n', 'est_res', 'prior'], concurrent = True
output: output_files[_index]
R: expand = '${ }', stdout = f'{_output:n}.log'

    # Combine the SuSiE fits and evaluation results listed in `sub` into one list.
    #
    # Arguments:
    #   sub     - table with the columns "fit_susie.output.file" and
    #             "eval_susie.output.file"; one fit/evaluation file pair is read per row.
    #   dirname - prefix pasted in front of every file name ('.rds' is appended).
    #   ld_col  - column of the purity table to extract.
    #
    # Returns a list with:
    #   purity, lfsr, size, captures - one column appended per row of `sub`
    #   total_captures               - column sums of the signal table restricted to the
    #                                  sets whose lfsr is below the cutoff filled in by
    #                                  the workflow, added up over all rows of `sub`
    # Every element stays NULL when `sub` has no rows.
    get_combined = function(sub, dirname, ld_col) {
        # Append `col` as a new column; the first value is stored as-is.
        # deparse.level = 0 keeps cbind from labelling columns after the argument names.
        append_column = function(acc, col) if (is.null(acc)) col else cbind(acc, col, deparse.level = 0)
        out_files = sub[,c("fit_susie.output.file", "eval_susie.output.file")]
        combined = list(purity = NULL, lfsr = NULL, size = NULL,
                        captures = NULL, total_captures = NULL)
        # seq_len() gives an empty loop for zero rows; 1:nrow() would iterate over c(1, 0)
        # and try to read files that do not exist.
        for (i in seq_len(nrow(out_files))) {
            fit = readRDS(paste0(dirname, out_files[i,1], '.rds'))$posterior
            purity = readRDS(paste0(dirname, out_files[i,2], '.rds'))
            combined$purity = append_column(combined$purity, purity$purity$V1[,ld_col])
            combined$size = append_column(combined$size, fit$n_in_CI[,1])
            combined$lfsr = append_column(combined$lfsr, fit$set_lfsr[,1])
            combined$captures = append_column(combined$captures, rowSums(purity$signal$V1))
            # Column sums over the rows passing the lfsr cutoff; sum() is kept (rather than
            # colSums) so the result type of the original computation is unchanged.
            detected = apply(t(purity$signal$V1[which(fit$set_lfsr[,1] < ${lfsr_cutoff}),,drop=FALSE]), 1, sum)
            if (is.null(combined$total_captures)) combined$total_captures = detected
            else combined$total_captures = combined$total_captures + detected
        }
        return(combined)
    }
    # Load the cached query table and keep only the rows that match this parameter combination.
    out = readRDS(${_input:r})
    sub = out[which(out$simple_lm.pve == ${_pve} & out$simple_lm.n_signal == ${_n} & out$fit_susie.estimate_residual_variance == ${_est_res} & out$fit_susie.prior_var == ${_prior}),]
    combined = get_combined(sub, "${outdir}/", ${ld_col})
    # Print the parameter values next to the PNG name derived from this output;
    # the step header redirects stdout to a log file.
    write(paste(${_pve}, ${_n}, ${_prior}, ${_est_res}, "${_output:n}.png"), stdout())
    saveRDS(combined, ${_output:r})
  
[3]
# Convert each combined RDS file into a pickle file with the dsc-io command-line tool.
input: group_by = 1, concurrent = True
output: f"{_input:n}.pkl"
python: expand = True
  import subprocess
  # check_call raises on a non-zero exit status, so a failed conversion stops the step
  # instead of being silently ignored; the argument list also avoids shell quoting issues.
  subprocess.check_call(['dsc-io', '{_input}', '{_output}'])
  
[4]
# Draw one PNG per pickle file from the previous step, named after the input with a .png extension.
input: group_by = 1, concurrent = True
output: f"{_input:n}.png"
python: expand = '${ }'
    import numpy as np
    import matplotlib.pyplot as plt
    # Fixed palette; the integer colour codes used by the plotting code index into this list.
    COLORS = ['#348ABD', '#7A68A6', '#A60628', '#467821', '#FF0000', '#188487', '#E2A233',
                  '#A9A9A9', '#000000', '#FF00FF', '#FFD700', '#ADFF2F', '#00FFFF']
    # Build the index -> hex lookup once rather than once per element visited by the mapper.
    _COLOR_BY_INDEX = dict(enumerate(COLORS))
    # Element-wise translation of integer colour codes into hex colour strings.
    color_mapper = np.vectorize(lambda x: _COLOR_BY_INDEX.get(x))

    def plot_purity(data, output, lfsr_cutoff = 0.05):
        """Scatter-plot set size against purity, one panel per row of the inputs.

        Parameters
        ----------
        data : mapping
            Must provide 'purity', 'lfsr', 'size' and 'captures' (one row per
            panel, one column per combined result) and 'total_captures' (a
            scalar or a sequence with one count per signal).
        output : str
            Path the figure is saved to.
        lfsr_cutoff : float
            Points with lfsr <= cutoff are drawn as stars, the rest as crosses.

        Points whose capture count is zero are drawn in COLORS[4], all others
        in COLORS[0].
        """
        def as_matrix(values):
            # A 1-D input is promoted to a single column so that row-wise iteration
            # and shape[1] keep working.
            # NOTE(review): assumes a 1-D vector means one combined result - confirm
            # against what the upstream conversion emits.
            arr = np.array(values, ndmin=1)
            return arr[:, np.newaxis] if arr.ndim == 1 else arr

        purity = as_matrix(data['purity'])
        lfsr = as_matrix(data['lfsr'])
        size = as_matrix(data['size'])
        capture = as_matrix(data['captures'])
        n_results = purity.shape[1]
        # atleast_1d accepts a bare scalar of any numeric type as well as a sequence.
        capture_summary = [f"Signal {idx+1} captured {item}/{n_results} times"
                           for idx, item in enumerate(np.atleast_1d(data['total_captures']))]
        n_panels = purity.shape[0]
        cols = 3
        # Ceiling division: just enough grid rows to hold every panel.
        rows = (n_panels + cols - 1) // cols
        fig = plt.figure(figsize=(12, 8))
        for idx, (x, y, z, c) in enumerate(zip(size, purity, lfsr, capture), start=1):
            z_sig = [i for i, zz in enumerate(z) if zz <= lfsr_cutoff]
            z_nsig = [i for i, zz in enumerate(z) if zz > lfsr_cutoff]
            # Colour code 4 marks points that captured no signal, code 0 all others.
            colors = [4 if i == 0 else 0 for i in c]
            plt.subplot(rows, cols, idx)
            if len(z_sig):
                plt.scatter(np.take(x, z_sig),
                            np.take(y, z_sig),
                            c = color_mapper(np.take(colors, z_sig)),
                            label = f'{idx}: lfsr<={lfsr_cutoff}', marker = '*')
            if len(z_nsig):
                plt.scatter(np.take(x, z_nsig),
                            np.take(y, z_nsig),
                            c = color_mapper(np.take(colors, z_nsig)),
                            label = f'{idx}: lfsr>{lfsr_cutoff}', marker = 'x')
            plt.legend(bbox_to_anchor=(0,1.02,1,0.2), loc="lower left",
                    mode="expand", borderaxespad=0, ncol=2, handletextpad=0.1)
        plt.subplots_adjust(hspace=0.3, wspace = 0.3)
        plt.suptitle(f"95% CI set sizes vs min(abs(LD))\n{'; '.join(capture_summary)}")
        plt.savefig(output, dpi=500, bbox_inches='tight')
        # Release the figure once it is on disk (the original trailing gca() call did nothing).
        plt.close(fig)
        
    import pickle
    # The pickle file comes from the previous workflow step; only load files you trust,
    # because unpickling can execute arbitrary code.
    # The with-block closes the file handle as soon as the data is loaded.
    with open('${_input}', 'rb') as handle:
        data = pickle.load(handle)
    plot_purity(data, '${_output}')