## UKB Multivariate fine-mapping workflow

1. Sufficient statistics input XtX, XtY, YtY and n. We assume covariates C have been removed from X and Y. We provide a procedure to implement this.
2. GWAS summary statistics input z and R. We assume z scores have been computed after removal of covariates C.

In [None]:
[global]
import glob
# Single-column text file: each non-empty line that does not start with `#` names one analysis unit (region)
parameter: analysis_units = path
# Path to data directory
parameter: data_dir = path
# Data file suffix (without the leading dot)
parameter: data_suffix = str
# Path to the work directory where output is written
parameter: wd = path("./output")
# Path to prior data file: an RDS file with `U` and `w` for prior matrices and weights
parameter: prior = path('.')
# Path to residual cor/cov data file
parameter: resid_cor = path('.')
# Only analyze `cis` variants -- cis = N means using N variants around the center column of X matrix
parameter: cis = 'NULL'
# Region names. `read_text()` opens and closes the file itself, so no file handle is left
# dangling; each line is stripped once, then blank lines and `#` comment lines are dropped.
regions = [x for x in (line.strip() for line in analysis_units.read_text().splitlines()) if x and not x.startswith('#')]
# Per-region input files `<data_dir>/<region>.<data_suffix>`; regions whose file does not exist are skipped.
genes = [f"{data_dir:a}/{x}.{data_suffix}" for x in regions if path(f"{data_dir:a}/{x}.{data_suffix}").exists()]

In [None]:
# Convert LD store file to RDS format
[ldstore_to_rds]
# An identifier for your run of analysis
parameter: name = str
# Directory searched for `<name>*.matrix` LD files
parameter: ld_dir = path
ld_files = glob.glob(f"{ld_dir:a}/{name}*.matrix")
input: ld_files, group_by = 1
output: f"{wd:a}/{_input:bn}.ld.rds"
task: trunk_workers = 1, trunk_size = 1, walltime = '12h', mem = '20G', cores = 2, tags = f'{step_name}_{_output:bn}'
R: expand = "${ }"
    # Read the LD table with data.table, convert it to a plain matrix, and store it as RDS
    ld_table <- data.table::fread(${_input:r})
    ld_matrix <- as.matrix(ld_table)
    saveRDS(ld_matrix, ${_output:r})

In [None]:
[sufficient_summary_stats_preprocessing]
# Phenotype table; must contain `FID` and `IID` columns, all other columns are used as traits
parameter: phenoFile = path
# Covariate table; must contain `FID` and `IID` columns, all other columns are used as covariates
parameter: covarFile = path
# Directory holding the per-region z score files, read as `<z_dir>/<region>.<z_suffix>`
parameter: z_dir = path()
parameter: z_suffix = str
# Directory holding the per-region LD files, read as `<ld_dir>/<region>.<ld_suffix>`
parameter: ld_dir = path()
parameter: ld_suffix = str
input: genes, group_by = 1
output: suffstats = f"{wd:a}/{_input:bn}.sufficient_stats.rds", 
        sumstats =  f"{wd:a}/{_input:bn}.summary_stats.rds"
task: trunk_workers = 1, trunk_size = 1, walltime = '4h', mem = '200G', cores = 1, tags = f'{step_name}_{_output[0]:bn}'
R: expand = '${ }', stdout = f"{_output[0]:nn}.stdout", stderr = f"{_output[0]:nn}.stderr"
    # Build, for one region, (a) sufficient statistics XtX, XtY, YtY, N and (b) z scores plus
    # a correlation matrix, both with the covariates projected out of the genotypes/phenotypes.
    # FIXME: the original note here was left unfinished ("in practice we might need to ...")
    library(data.table)
    library(dplyr)

    geno_file = ${_input:nr}
    z.file = "${z_dir:a}/${_input:bn}.${z_suffix}"
    ld.file = "${ld_dir:a}/${_input:bn}.${ld_suffix}"

    # Genotype table: the six sample-level columns are set aside in `map`,
    # the remaining columns (one per variant) become the numeric matrix X.
    X <- fread(paste0(geno_file, '.raw.gz'),sep = "\t",header = TRUE,stringsAsFactors = FALSE)
    map <- X[,1:6]
    X = X[, c('FID','IID','PAT','MAT','SEX', 'PHENOTYPE') := NULL]
    X <- as.matrix(X)

    # Variant table used further down to validate the variant annotation of the z score file
    X.info = fread(paste0(geno_file, '.pvar'),sep = "\t",header = TRUE,stringsAsFactors = FALSE)

    # Read phenotype data
    cat("Reading phenotype data.\n")
    pheno <- suppressMessages(fread(${phenoFile:r}))

    cat("Reading covariate file.\n")
    Z = suppressMessages(fread(${covarFile:r}))

    # Put phenotype and covariate rows in the sample order of the genotype matrix.
    # An unmatched sample would become an all-NA row and silently turn the statistics
    # below into NA, so fail early with a clear message instead.
    match.idx = match(map$IID, pheno$IID)
    if (anyNA(match.idx)) {
        stop("Some genotyped samples are missing from the phenotype file.")
    }
    pheno = pheno[match.idx,]
    match.idx = match(map$IID, Z$IID)
    if (anyNA(match.idx)) {
        stop("Some genotyped samples are missing from the covariate file.")
    }
    Z = Z[match.idx,]

    Y = pheno %>% select(-FID, -IID) %>% as.matrix
    Z = Z %>% select(-FID, -IID) %>% as.matrix

    # centering
    Y = sweep(Y, 2, colMeans(Y), '-')
    Z = sweep(Z, 2, colMeans(Z), '-')

    A   <- crossprod(Z) # Z'Z
    # chol decomposition for (Z'Z)^(-1)
    R.chol = chol(solve(A)) # R'R = (Z'Z)^(-1)
    W = R.chol %*% crossprod(Z, X) # RZ'X
    S = R.chol %*% crossprod(Z, Y) # RZ'Y

    SNPnames = colnames(X)
    # The genotype and covariate matrices are no longer needed; release them
    rm(X)
    rm(Z)

    zscores = readRDS(z.file)

    # Load LD matrix from raw genotype
    ld = readRDS(ld.file)
    # Rescale the LD matrix by sqrt(XtXD) on both sides, then subtract the covariate part
    XtX = sqrt(zscores$XtXD) * t(ld*sqrt(zscores$XtXD)) - crossprod(W) # W'W = X'ZR'RZ'X = X'Z(Z'Z)^{-1}Z'X
    XtX = as.matrix(XtX)
    rownames(XtX) = colnames(XtX) = SNPnames
    # Correlation matrix implied by the covariate-adjusted XtX
    LD.adj = cov2cor(XtX)

    # X'Y
    ## flip sign because X flip the REF, ALT
    XtY = -as.matrix(zscores$XtY - crossprod(W, S)) # W'S = X'ZR'RZ'y = X'Z(Z'Z)^{-1}Z'y

    # YtY
    YtY = as.matrix(crossprod(Y) - crossprod(S))

    Z = as.matrix(zscores$Z)
    rownames(Z) = SNPnames

    # all.equal() returns TRUE or a character vector describing the differences, so it
    # must be wrapped in isTRUE(); negating the character vector would raise an unrelated
    # "invalid argument type" error instead of the intended message.
    meta = zscores$pos[,1:5]
    if(!isTRUE(all.equal(meta, X.info, check.attributes = FALSE))){
        stop("ALLELE doesn't match.")
    }

    saveRDS(list(XtX = XtX, XtY = XtY, YtY = YtY, N = nrow(Y), meta = zscores$pos), ${_output["suffstats"]:r})
    saveRDS(list(Z = Z, LD = LD.adj, meta = zscores$pos, ld.file = ld.file), ${_output["sumstats"]:r})

In [None]:
[univariate_analysis]
# Maximum number of single effects (`L`) per SuSiE fit
parameter: max_L = 10
# GWAS sample size passed to `susie_rss` as `n`
parameter: sample_size = 248980
input: genes, group_by = 1
output: rss_rem_covariates = f"{wd:a}/{_input:bnn}/{_input:bnn}.susierss_rem_covariates.rds", 
        rss_notrem_covariates =  f"{wd:a}/{_input:bnn}/{_input:bnn}.susierss_notrem_covariates.rds"
task: trunk_workers = 1, trunk_size = 1, walltime = '36h', mem = '150G', cores = 1, tags = f'{step_name}_{_output[0]:bnn}'
R: expand = '${ }', stdout = f"{_output[0]:nn}.stdout", stderr = f"{_output[0]:nn}.stderr"
    # Fit SuSiE-RSS to every trait (column of Z) twice: once with the LD matrix stored in
    # the input (`dat_rss$LD`) and once with the LD matrix read from `dat_rss$ld.file`.
    library(susieR)
    dat_rss = readRDS(${_input:r})
    # The LD path recorded in the input may point at the old /project2 location; remap it
    ld.file = gsub('/project2/mstephens/yuxin/ukb-bloodcells/LD/', 
                   '/project/mstephens/yuxin/ukb-bloodcells/LD/', 
                   dat_rss$ld.file)
    R = readRDS(ld.file)
    rownames(R) = colnames(R) = rownames(dat_rss$LD)

    # seq_len() yields an empty loop for a Z matrix without columns (1:0 would not)
    n_traits = ncol(dat_rss$Z)
    fitted_rss_rem_covariates = vector("list", n_traits)
    fitted_rss_notrem_covariates = vector("list", n_traits)
    for (r in seq_len(n_traits)) {
        ## sufficient stats
        st = proc.time()
        fitted_rss_rem_covariates[[r]] <- susieR::susie_rss(z = dat_rss$Z[,r],
                                                            R = dat_rss$LD,
                                                            n = ${sample_size},
                                                            L=${max_L},
                                                            max_iter=1000,
                                                            estimate_residual_variance=FALSE,
                                                            estimate_prior_variance=TRUE,
                                                            refine=TRUE)
        fitted_rss_rem_covariates[[r]]$time = proc.time() - st
        fitted_rss_rem_covariates[[r]]$cs_corr = susieR:::get_cs_correlation(fitted_rss_rem_covariates[[r]], Xcorr=dat_rss$LD)

        ## rss, LD not correct for covariates
        st = proc.time()
        fitted_rss_notrem_covariates[[r]] <- susieR::susie_rss(z = dat_rss$Z[,r],
                                                               R = R,
                                                               n = ${sample_size},
                                                               L=${max_L},
                                                               max_iter=1000,
                                                               estimate_prior_variance=TRUE,
                                                               refine=TRUE)
        fitted_rss_notrem_covariates[[r]]$time = proc.time() - st
        fitted_rss_notrem_covariates[[r]]$cs_corr = susieR:::get_cs_correlation(fitted_rss_notrem_covariates[[r]], 
                                                                                Xcorr=R)
    }

    # Label each fit with its trait name
    names(fitted_rss_rem_covariates) = colnames(dat_rss$Z)
    names(fitted_rss_notrem_covariates) = colnames(dat_rss$Z)

    saveRDS(fitted_rss_rem_covariates, ${_output["rss_rem_covariates"]:r})
    saveRDS(fitted_rss_notrem_covariates, ${_output["rss_notrem_covariates"]:r})

In [None]:
[mvsusie_analysis]
# Maximum number of single effects (`L`)
parameter: max_L = 10
# LD matrix to use: 'original' (read from the file recorded in the input) or
# 'remove_cov' (the `LD` matrix stored in the input, symmetrized)
parameter: ld_type = 'original'
# GWAS sample size passed to `mvsusie_rss` as `N`
parameter: sample_size = 248980
input: genes, group_by = 1
output: f'{wd:a}/{_input:bnn}/{_input:bnn}.LD{ld_type}{resid_cor:bnx}.mvsusierss.rds'
task: trunk_workers = 1, trunk_size = 1, walltime = '12h', mem = '150G', cores = 1, tags = f'{step_name}_{_output:bn}'
R: expand = '${ }', stdout = f"{_output:n}.stdout", stderr = f"{_output:n}.stderr"
    # Map the columns of Z onto the columns of a prior matrix U.
    # Returns NULL when either has no column names or the names already agree;
    # otherwise the positions of Z's column names within U's column names.
    get_prior_indices <- function(Z, U) {
      # make sure the prior col/rows match the colnames of the Y matrix
      z_names = colnames(Z)
      u_names = colnames(U)
      if (is.null(z_names) || is.null(u_names)) {
          return(NULL)
      } else if (identical(z_names, u_names)) {
          return(NULL)
      } else {
          return(match(z_names, u_names))
      }
    }

    library(mvsusieR)
    dat = readRDS(${_input:r})
    V = readRDS(${resid_cor:r})
    prior = readRDS(${prior:r})
    print(paste("Number of components in the mixture prior:", length(prior$U)))
    prior = mvsusieR::create_mixture_prior(mixture_prior=list(weights=prior$w, matrices=prior$U), 
                                        include_indices = get_prior_indices(dat$Z, prior$U[[1]]), 
                                        max_mixture_len=-1)
    if("${ld_type}" == 'original'){
        # The LD path recorded in the input may point at the old /project2 location; remap it
        ld.file = gsub('/project2/mstephens/yuxin/ukb-bloodcells/LD/', 
                   '/project/mstephens/yuxin/ukb-bloodcells/LD/', 
                   dat$ld.file)
        R = readRDS(ld.file)
    }else if("${ld_type}" == 'remove_cov'){
        R = dat$LD
        # Average with the transpose so the matrix is exactly symmetric
        R = (R+t(R))/2
    }else{
        # Without this branch R stays undefined and the fit fails later with an unrelated error
        stop("Unknown ld_type '${ld_type}': expected 'original' or 'remove_cov'.")
    }
    st = proc.time()
    mv_res = mvsusieR::mvsusie_rss(dat$Z, R, L=${max_L}, N = ${sample_size},
                                   prior_variance=prior, residual_variance=V, 
                                   precompute_covariances=TRUE, compute_objective=TRUE, 
                                   estimate_prior_variance=TRUE, estimate_prior_method='EM',
                                   max_iter = 1000, n_thread=1)
    mv_res$time = proc.time() - st
    # Do not save a fit that is not flagged as converged
    if(!isTRUE(mv_res$convergence$converged)){
        stop('Fail to converge.')
    }
    mv_res$cs_corr = susieR:::get_cs_correlation(mv_res, Xcorr=R)
    saveRDS(mv_res, ${_output:r})
            

In [None]:
[CS_report]
# Result files `<wd>/<region>/<region>.<data_suffix>`; regions without a result file are skipped
mvsusie_res = [f'{wd:a}/{x}/{x}.{data_suffix}' for x in regions if path(f"{wd:a}/{x}/{x}.{data_suffix}").exists()]
input: mvsusie_res, group_by = 1
output: text_summary = f"{_input:n}.summary.md"
R: expand = '${ }'
    # Write a one-line summary "<region> <number of credible sets>" for one result file
    res = readRDS(${_input:r})
    num_cs = length(res$sets$cs)
    # fixed = TRUE removes the literal ".<data_suffix>"; as a regular expression the dot
    # (and any dot inside the suffix) would match arbitrary characters of the file name.
    regionname = gsub(".${data_suffix}", "", ${_input:br}, fixed = TRUE)
    write(paste(regionname, num_cs), ${_output["text_summary"]:r})

In [None]:
# Per-region report of a multivariate fit: table of reported variants plus credible set (CS) diagnostics
[analysis_report]
# lfsr threshold: a trait is listed for a CS when its single-effect lfsr is below this value
parameter: CS_lfsr = 0.01
# Minimum absolute correlation for a CS; any value other than 0.5 makes the step recompute the CS
parameter: CS_purity = 0.5
# One summary-statistics file (under data_dir) and one fitted result (under wd) per region,
# processed pairwise in the order of `regions`
sumstat = [f'{data_dir:a}/{x}.summary_stats.rds' for x in regions]
mvsusie_res = [f'{wd:a}/{x}/{x}.{data_suffix}' for x in regions]
input: sumstat, mvsusie_res, group_by = 'pairs'
output: textfile = f"{_input[1]:n}.CS_purity{CS_purity}.CS_lfsr{CS_lfsr}.summary.rds"
R: expand = '${ }', stdout = f"{_output['textfile']:nn}.stdout", stderr = f"{_output['textfile']:nn}.stderr"
    library(mvsusieR)
    library(ggplot2)
    # Count shared variables between every pair of credible sets.
    #   cs: list of index vectors, one per credible set (names are used as dimnames).
    # Returns 0 for an empty list; otherwise a symmetric matrix whose diagonal holds the
    # set sizes and whose off-diagonal entries hold the pairwise intersection sizes.
    check_overlap = function(cs) {
        n_sets = length(cs)
        if (n_sets == 0) {
            return(0)
        }
        overlaps_cs = matrix(NA, n_sets, n_sets)
        rownames(overlaps_cs) = colnames(overlaps_cs) = names(cs)
        # Fill the lower triangle (including the diagonal) only
        for (row in seq_len(n_sets)) {
            for (col in seq_len(row)) {
                overlaps_cs[row, col] = if (row == col) {
                    length(cs[[row]])
                } else {
                    length(intersect(cs[[row]], cs[[col]]))
                }
            }
        }
        # Mirror the lower triangle into the upper one
        mirrored = Matrix::forceSymmetric(overlaps_cs, uplo = "L")
        return(as.matrix(mirrored))
    }

    # Inputs of this pair: summary statistics (dat) and the fitted model (res)
    dat = readRDS(${_input[0]:r})
    res = readRDS(${_input[1]:r})
    regionname = "${regions[_index]}"
    trait_names = res$condition_names

    # Use the LD matrix that matches the fit: the file name says which one was used
    if(grepl('LDoriginal', ${_input[1]:r})){
        ld.file = gsub('/project2/mstephens/yuxin/ukb-bloodcells/LD/', 
                        '/project/mstephens/yuxin/ukb-bloodcells/LD/', 
                        dat$ld.file)
        ld = readRDS(ld.file)
    }else{
        ld = dat$LD
        ld = (ld+t(ld))/2
    }
    # Recompute the credible sets only when a non-default purity is requested
    if(${CS_purity} != 0.5){
        res$sets = susieR::susie_get_cs(res, Xcorr = ld, min_abs_corr = ${CS_purity})
    }

    # Variants to report: PIP above 0.05 or member of any credible set
    snps = sort(union(which(res$pip > 0.05), unlist(res$sets$cs)))
    res$variable_names = paste(dat$meta$CHR, dat$meta$POS, sep = '.')
    if(length(snps) > 0){
        # PIP
        tb = data.frame('Region' = regionname, dat$meta[snps,], 'PIP' = res$pip[snps])

        # CS index, purity and traits of each reported variant (NA outside credible sets)
        tb$CS = NA
        tb$purity = NA
        tb$CS_trait = NA
        snps_cs = unlist(res$sets$cs)
        if(length(snps_cs) > 0){
            # Rows of tb follow `snps`, so look the CS members up in `snps` directly.
            # Going through rownames(tb) only works when dat$meta keeps the original
            # row indices as row names, which a data.table does not.
            snps_cs_match = match(snps_cs, snps)
            cs_sizes = lengths(res$sets$cs)

            # CS
            tb$CS[snps_cs_match] = rep(res$sets$cs_index, times = cs_sizes)

            # purity
            tb$purity[snps_cs_match] = rep(res$sets$purity[,1], times = cs_sizes)

            # trait CS: traits whose single-effect lfsr is below the threshold, per CS
            cs_traits = vapply(res$sets$cs_index,
                               function(i) paste(trait_names[which((res$single_effect_lfsr < ${CS_lfsr})[i,])],
                                                 collapse = ' | '),
                               character(1))
            tb$CS_trait[snps_cs_match] = rep(cs_traits, times = cs_sizes)
        }

        write.csv(tb,"${_output['textfile']:n}.csv", row.names = FALSE, quote = FALSE)

        # Plots for this fit are produced by the separate `analysis_plot` step.

        cs_corr = susieR:::get_cs_correlation(res, Xcorr=ld)
        # Label the correlation matrix only when it contains no NA
        if(all(!is.na(cs_corr))){
            rownames(cs_corr) = colnames(cs_corr) = names(res$sets$cs)
        }

        saveRDS(list(total_snps = nrow(dat$Z), summary_tb = tb, 
                     expect_causal = sum(res$pip), 
                     num_pip_not_CS = sum(is.na(tb$CS)),
                     cs_corr = cs_corr,
                     cs_overlap = check_overlap(res$sets$cs)), ${_output['textfile']:r})
    }else{
        # Nothing to report: leave empty placeholder files so the declared output exists
        system("touch ${_output['textfile']:n}.csv")
        system("touch ${_output['textfile']}")
    }


In [None]:
# Per-region plots of a multivariate fit: PIP plot with gene track and two bubble plots
[analysis_plot]
# lfsr threshold forwarded to `mvsusie_plot` as `cslfsr_threshold`
parameter: CS_lfsr = 0.01
# One summary-statistics file (under data_dir) and one fitted result (under wd) per region,
# processed pairwise in the order of `regions`
sumstat = [f'{data_dir:a}/{x}.summary_stats.rds' for x in regions]
mvsusie_res = [f'{wd:a}/{x}/{x}.{data_suffix}' for x in regions]
input: sumstat, mvsusie_res, group_by = 'pairs'
output: pip_plot = f"{_input[1]:n}.manhattan.pdf", 
        mv_post_plot = f"{_input[1]:n}.bubble_finemap.pdf", 
        mv_z_plot = f"{_input[1]:n}.bubble_original.pdf"
R: expand = '${ }', stdout = f"{_output['pip_plot']:nn}.stdout", stderr = f"{_output['pip_plot']:nn}.stderr"
    # readr supplies read_delim(); dplyr supplies %>%, select(), mutate() and filter() used below
    library(mvsusieR)
    library(ggplot2)
    library(readr)
    library(data.table)
    library(dplyr)

    # Draw the gene track shown under the PIP plot.
    #   dat:    data frame with columns `geneName`, `start`, `end`, one row per gene
    #   xrange: c(min, max) of the x axis; gene bars are clipped to this window
    #   chr:    chromosome label used in the x-axis title
    # Returns the ggplot object (invisibly, as the value of the final assignment).
    plot_geneName = function(dat, xrange, chr){
        # Greedy row assignment: keep placing, on the current row, the first remaining
        # gene that starts more than 0.02 after the end of the gene placed last; when no
        # remaining gene fits, open a new row with the first remaining gene.
        current_line = 1
        dat$lines = NA
        dat$lines[1] = current_line
        last_end = dat[1, 'end']
        remaining = if (nrow(dat) > 1) 2:nrow(dat) else integer(0)
        while(length(remaining) > 0){
            pick = which(dat[remaining, 'start'] > last_end + 0.02)[1]
            if(is.na(pick)){
                current_line = current_line + 1
                pick = 1
            }
            chosen = remaining[pick]
            dat$lines[chosen] = current_line
            last_end = dat[chosen, 'end']
            remaining = remaining[-pick]
        }

        # Clip the bars to the plotting window and centre the labels on the clipped bars
        dat$start = pmax(dat$start, xrange[1])
        dat$end = pmin(dat$end, xrange[2])
        dat$mean = rowMeans(dat[, c('start', 'end')])

        # Rows are drawn at negative y so that row 1 sits at the top of the panel
        track_theme = theme(axis.text.x=element_blank(),
            axis.ticks = element_blank(),
            axis.text.y=element_blank(),
            axis.title = element_text(size=15),
            plot.title=element_text(size=11),
            panel.grid.major = element_blank(),
            panel.grid.minor = element_blank())
        gene_track = ggplot(dat, aes(xmin = xrange[1], xmax = xrange[2])) +
            xlim(xrange[1], xrange[2]) +
            ylim(min(-dat$lines-0.6), -0.8) +
            geom_rect(aes(xmin = start, xmax = end, ymin = -lines-0.05, ymax = -lines+0.05), fill='blue') +
            geom_text(aes(x = mean, y=-lines-0.4, label=geneName), size=4, fontface = "italic") +
            xlab(paste0('base-pair position (Mb) on chromosome ', chr)) +
            ylab('Gene') +
            theme_bw() +
            track_theme
    }
    # Plot PIP against position, ring the members of each credible set (CS) in a
    # CS-specific colour, and optionally stack a gene track underneath.
    #   model:           fitted object providing `pip` and `sets` (`cs`, `cs_index`, `purity`)
    #   pos:             x coordinate of each variant, in the order of `model$pip`
    #   chr:             chromosome label forwarded to the gene track
    #   gene.pos.map:    NULL, or the gene table passed on to plot_geneName()
    #   title:           plot title; title.size is its font size
    #   y.lim, y.susie:  accepted for compatibility, not used in the body
    #   xrange:          x-axis limits; defaults to the range of `pos`
    #   legend.position: 'top' or 'right'
    # Returns a ggplot object, or the egg::ggarrange() result when a gene track is added.
    mvsusie_pip_plot = function(model, pos, chr, gene.pos.map = NULL,
        title = NULL, title.size = 10, 
        y.lim=NULL, y.susie='PIP', xrange=NULL,
        legend.position = 'top'){
        if(is.null(xrange)){
            xrange = c(min(pos), max(pos))
        }
        pip = model$pip
        tmp = data.frame(POS = pos, PIP = pip)
        pl_susie = ggplot(tmp, aes(x = POS, y = PIP)) + geom_point(show.legend = FALSE, size=3) + 
            xlim(xrange[1], xrange[2]) + 
            theme_bw() + 
            theme(axis.title.x=element_blank(),
                axis.text = element_text(size=12),
                axis.title.y = element_text(size=15))
        pl_susie = pl_susie + ggtitle(title) + theme(plot.title = element_text(size=title.size))

        model.cs = model$sets$cs
        # length() is 0 for NULL and for an empty list, so both skip the CS overlay
        if(length(model.cs) > 0){
            colors = c(
                "#FF7F00", # orange
                "skyblue2", 
                "green1",
                "#6A3D9A", # purple
                "#FB9A99", # lt pink
                "dodgerblue2",
                "green4",
                "gold1",
                "palegreen2",
                "#CAB2D6", # lt purple
                "#FDBF6F", # lt orange
                "gray70", "khaki2",
                "maroon", "orchid1", "deeppink1", "blue1", "steelblue4",
                "darkturquoise", "yellow4", "yellow3",
                "darkorange4", "brown", 'cyan',
                'cyan3', 'aliceblue', 'darkolivegreen4', 'darksalmon',
                'tomato2', 'tan'
            )
            tmp$CS = numeric(length(pip))
            tmp$color = numeric(length(pip))
            tmp.cs = c()
            for(i in seq_along(model.cs)){
                members = model.cs[[i]]
                cs_idx = model$sets$cs_index[i]
                # Legend label: CS name, number of members (C) and purity (R)
                tmp$CS[members] = paste0(names(model.cs)[i], ": C=", length(members),
                    "/R=", round(model$sets$purity[i,1],3))
                # 1-based wrap-around: a plain `cs_idx %% length(colors)` is 0 whenever
                # cs_idx is a multiple of the palette size, and colors[0] is empty, which
                # aborts the assignment. Indices below the palette size keep their colour.
                tmp$color[members] = colors[((cs_idx - 1) %% length(colors)) + 1]
                tmp.cs = rbind(tmp.cs, tmp[members,])
            }
            tmp.cs$CS = factor(tmp.cs$CS)
            pl_susie = pl_susie + geom_point(data=tmp.cs, aes(x=POS, y=PIP, color=CS), 
                size=3, shape=1, stroke = 2) + 
                scale_color_manual(breaks = tmp.cs$CS[order(tmp.cs$CS)], values=tmp.cs$color[order(tmp.cs$CS)]) + 
                guides(colour=guide_legend(override.aes=list(shape=1, size=0.1)))
            if(legend.position == 'top'){
                pl_susie = pl_susie + theme(legend.justification = c(1, 1), legend.position = 'top',
                    legend.title = element_blank(), legend.text = element_text(size=10))
            }else if(legend.position == 'right'){
                pl_susie = pl_susie + theme(legend.title = element_blank(), legend.text = element_text(size=10))
            }
        }

        if(!is.null(gene.pos.map)){
            # Stack the PIP panel on top of the gene track
            pl_gene = plot_geneName(gene.pos.map, xrange = xrange, chr=chr)
            g = egg::ggarrange(pl_susie, pl_gene, nrow=2, heights = c(5.5,1.5), draw=FALSE)
        }else{
            g = pl_susie
        }
        g
    }

    # Inputs of this pair: summary statistics (dat) and the fitted model (res)
    dat = readRDS(${_input[0]:r})
    res = readRDS(${_input[1]:r})
    regionname = "${regions[_index]}"
    trait_names = res$condition_names

    # Attach the z scores to the fit, with |z| capped at 37.5
    Z = dat$Z
    extreme = which(abs(Z) > 37.5)
    Z[extreme] = sign(Z[extreme]) * 37.5
    res$z = Z
    res$variable_names = paste('chr', dat$meta$CHR, dat$meta$POS, sep = '.')

    ## PIP
    # Gene annotation table, restricted to GENE records of the GRCh37.p5 primary assembly
    genes <- read_delim("~/GitHub/finemap-uk-biobank/data/seq_gene.md.gz",delim = "\t",quote = "")
    class(genes) <- "data.frame"
    genes <- subset(genes,
        group_label == "GRCh37.p5-Primary Assembly" &
        feature_type == "GENE")

    # Genes on the region's chromosome that start inside the region, or that end after
    # the region start while starting before the region end
    start.pos <- min(dat$meta$POS)
    stop.pos  <- max(dat$meta$POS)
    plot.genes <- subset(genes,
        chromosome == unique(dat$meta$CHR) &
        ((chr_start > start.pos & chr_start < stop.pos) |
        (chr_stop > start.pos & chr_start < stop.pos)) & feature_type == 'GENE')

    # Name, start and end (converted to Mb) of each gene, without the 'LOC' entries
    gene.pos.map = plot.genes %>%
        select(geneName = feature_name, start = chr_start, end = chr_stop) %>%
        as.data.frame() %>%
        mutate(start = start/1e6, end = end/1e6) %>%
        filter(!grepl('LOC', geneName))

    # No gene left: draw the PIP panel without a gene track
    if(nrow(gene.pos.map) == 0){
        gene.pos.map = NULL
    }
    pip_plot = mvsusie_pip_plot(model = res, 
        pos = dat$meta$POS/1e6, 
        chr = unique(dat$meta$CHR), 
        gene.pos.map = gene.pos.map)
    pdf(${_output['pip_plot']:r}, width=18, height=14)
    print(pip_plot)
    dev.off()

    if(is.null(res$sets$cs)){
        # No credible set: leave empty placeholder files so the declared outputs exist
        system("touch ${_output['mv_post_plot']:r}")
        system("touch ${_output['mv_z_plot']:r}")
    }else{
        # posterior
        post_plot = mvsusie_plot(res, cslfsr_threshold = ${CS_lfsr})
        pdf(${_output['mv_post_plot']:r}, width = post_plot$width, height = post_plot$height)
        print(post_plot$plot)
        dev.off()

        # z
        z_plot = mvsusie_plot(res, plot_z = TRUE, cslfsr_threshold = ${CS_lfsr})
        pdf(${_output['mv_z_plot']:r}, width = z_plot$width, height = z_plot$height)
        print(z_plot$plot)
        dev.off()
    }

In [None]:
[univariate_analysis_report]
# Minimum absolute correlation for a credible set (CS); any value other than 0.5 makes the step recompute the CS
parameter: CS_purity = 0.5
# One summary-statistics file (under data_dir) and one fitted result (under wd) per region,
# processed pairwise in the order of `regions`
sumstat = [f'{data_dir:a}/{x}.summary_stats.rds' for x in regions]
susie_res = [f'{wd:a}/{x}/{x}.{data_suffix}' for x in regions]
input: sumstat, susie_res, group_by = 'pairs'
output: pip_plot = f"{_input[1]:n}.manhattan.pdf", 
        textfile = f"{_input[1]:n}.CS_purity{CS_purity}.summary.rds"
task: trunk_workers = 1, trunk_size = 1, walltime = '2h', mem = '55G', cores = 1, tags = f'{step_name}_{_output[1]:bn}'
R: expand = '${ }'
    # Per-region report of the univariate fits: one PIP plot per trait and a combined table
    # of the reported variants (PIP above 0.05 or member of a credible set).
    library(susieR)
    library(ggplot2)
    dat = readRDS(${_input[0]:r})
    res = readRDS(${_input[1]:r})
    regionname = "${regions[_index]}"

    # LD is only needed when the credible sets are recomputed at a non-default purity;
    # the result file name says which LD matrix the fit used.
    if(${CS_purity} != 0.5){
        if(grepl('susierss_notrem_covariates', ${_input[1]:r})){
            ld.file = gsub('/project2/mstephens/yuxin/ukb-bloodcells/LD/', 
                           '/project/mstephens/yuxin/ukb-bloodcells/LD/', 
                           dat$ld.file)
            ld = readRDS(ld.file)
        }else{
            ld = dat$LD
            ld = (ld+t(ld))/2
        }
    }

    trait_names = names(res)
    tb.all = c()
    pdf(${_output['pip_plot']:r})
    par(mfcol = c(2,2))
    for (name in trait_names){
        res_trait = res[[name]]
        if(${CS_purity} != 0.5){
            res_trait$sets = susieR::susie_get_cs(res_trait, Xcorr = ld, 
                                                  min_abs_corr = ${CS_purity})
        }
        susieR::susie_plot(res_trait, y='PIP', main = name, 
                           xlab = 'SNP positions', add_legend = FALSE, max_cs = 3000)

        snps = sort(union(which(res_trait$pip > 0.05), unlist(res_trait$sets$cs)))
        if(length(snps) > 0){
            # PIP
            tb = data.frame('Region' = regionname, dat$meta[snps,], 'trait' = name, 'PIP' = res_trait$pip[snps])
            # CS index and purity of each reported variant (NA outside credible sets)
            tb$CS = NA
            tb$purity = NA
            snps_cs = unlist(res_trait$sets$cs)
            if(length(snps_cs) > 0){
                # Rows of tb follow `snps`, so look the CS members up in `snps` directly.
                # Going through rownames(tb) only works when dat$meta keeps the original
                # row indices as row names, which a data.table does not.
                snps_cs_match = match(snps_cs, snps)
                cs_sizes = lengths(res_trait$sets$cs)
                tb$CS[snps_cs_match] = rep(res_trait$sets$cs_index, times = cs_sizes)
                tb$purity[snps_cs_match] = rep(res_trait$sets$purity[,1], times = cs_sizes)
            }

            tb.all = rbind(tb.all, tb)
        }
    }
    dev.off()
    # tb.all is still NULL when no trait had a reportable variant; nrow(NULL) is NULL and
    # `if (NULL > 0)` is an error, so test for NULL first.
    if(!is.null(tb.all) && nrow(tb.all) > 0){
        write.csv(tb.all,"${_output['textfile']:n}.csv", 
                  row.names = FALSE, quote = FALSE)
        saveRDS(list(total_snps = nrow(dat$Z), summary_tb = tb.all), 
                ${_output['textfile']:r})
    }else{
        # Nothing to report: leave empty placeholder files so the declared output exists
        system("touch ${_output['textfile']:n}.csv")
        system("touch ${_output['textfile']:r}")
    }
    