# Computation of a grid of scaling factors for *mr.mash*

This notebook computes a grid of scaling factors (i.e., variances) to scale the covariance matrices for *mr.mash*.

## Inputs

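A file listing the analysis units (e.g., genes) and a directory containing one RDS file per unit. Each RDS file stores a list with two elements (the matrices Bhat and Shat from univariate regression), optionally nested under a named table such as the desired cross-validation fold. Additional parameters control how the input files are filtered.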

## Outputs

A vector of scaling factors (i.e., variances).

## Minimal working example
```
 sos run prior_grid.ipynb compute_grid \
        --analysis-units analysis_units.txt \
        --datadir ../summary_stats --name fold_1 --table_name fold_1 \
        --bhat Bhat --sbhat Shat --expected-ncondition 49 \
        -c midway2.yml -q midway2
```

## Global parameters

In [None]:
[global]
import os
# Work directory & output directory
parameter: cwd = path('.')
# The filename prefix for output data (required; no default value)
parameter: name = str
# Job size setting.
# NOTE(review): not referenced by the compute_grid step in this notebook,
# whose task line hard-codes trunk_size = 1 -- confirm whether it should be used.
parameter: job_size = 1

## Compute grid from univariate summary statistics

In [None]:
[compute_grid]
# Name of the element to extract from each loaded RDS object.
# When empty, the loaded object is used as is.
parameter: table_name = ""
# Name of the Bhat matrix element (from univariate regression) in each input object
parameter: bhat = "bhat"
# Name of the Shat matrix element (from univariate regression) in each input object
parameter: sbhat = "sbhat"
# Required number of columns of the bhat and sbhat matrices; input files that
# do not match are skipped. A value of 0 (or less) disables the check.
parameter: expected_ncondition = 0
# Directory containing the input files, one `<unit>.rds` per analysis unit (required)
parameter: datadir = path
# Seed passed to set.seed() at the start of the R script
parameter: seed = 999
# Analysis units file. For RDS files it can be generated by `ls *.rds | sed 's/\.rds//g' > analysis_units.txt`
# Blank lines and lines starting with '#' are ignored; only the first
# whitespace-separated field of each line is used to build the input file names.
parameter: analysis_units = path
regions = [x.strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
# One RDS file per analysis unit, all handled together by a single R script
input: [f'{datadir}/{x[0]}.rds' for x in regions]
# A single RDS file holding the grid
output: f"{cwd}/{name}_grid.rds"
task: trunk_workers = 1, walltime = '6h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }"
    # Workflow values are substituted into this script wherever the dollar-brace
    # sigil declared by `expand` appears, before R runs it.
    options(stringsAsFactors=FALSE)
    # NOTE(review): no random number generation is visible in this script, so
    # the seed does not appear to influence the result -- confirm.
    set.seed(${seed})
  
    ###Build a geometric grid of values that ends at gmax
    # Arguments:
    #   gmin, gmax: lower and upper targets for the grid.
    #   mult: ratio between consecutive grid values (default 2).
    # Returns: the numeric vector mult^k * gmax for k = -npoint, ..., 0, where
    #   npoint = ceiling(log2(gmax/gmin)/log2(mult)). For mult > 1 and
    #   gmax >= gmin the smallest value is therefore at or below gmin.
    #   When mult is 0, the two-point vector c(0, gmax/2) is returned instead.
    autoselect_mixsd <- function(gmin, gmax, mult=2){
      # mult == 0 is a special case that bypasses the geometric grid.
      if (mult == 0) {
        return(c(0, gmax/2))
      }
      # Number of multiplicative steps needed to go from gmax down to gmin.
      n_steps <- ceiling(log2(gmax/gmin)/log2(mult))
      exponents <- (-n_steps):0
      mult^exponents * gmax
    }
  
    ###Compute the grid endpoints for one data set
    # data: list with elements Bhat and Shat.
    # Entries are dropped when Shat is zero or not finite, or when Bhat is NA;
    # the remaining entries are handed to grid_max and grid_min.
    # Returns: list(gmin=..., gmax=...).
    compute_grid_endpoints <- function(data){
      usable <- data$Shat != 0 & is.finite(data$Shat) & !is.na(data$Bhat)
      bhat_kept <- data$Bhat[usable]
      shat_kept <- data$Shat[usable]
      upper <- grid_max(bhat_kept, shat_kept)
      lower <- grid_min(bhat_kept, shat_kept)

      list(gmin=lower, gmax=upper)
    }


    ###Lower endpoint of the grid
    # Returns the smallest value in Shat. Bhat is accepted so that the
    # signature parallels grid_max, but it is not used.
    grid_min <- function(Bhat, Shat){
      lowest <- min(Shat)
      lowest
    }

    ###Upper endpoint of the grid
    # If any squared Bhat exceeds the matching squared Shat, returns
    # 2 * sqrt(max(Bhat^2 - Shat^2)). Otherwise (the unusual case where not
    # much grid is needed) returns 8 times the value given by grid_min.
    grid_max <- function(Bhat, Shat){
      if (any(Bhat^2 > Shat^2)) {
        return(2 * sqrt(max(Bhat^2 - Shat^2)))
      }
      8 * grid_min(Bhat, Shat)
    }

    ###Collect the grid endpoints of every usable input file, then build the grid
    grid_mins <- c()
    grid_maxs <- c()

    for (f in c(${_input:r,})) {
      # A file that cannot be read is skipped rather than aborting the run, on
      # the assumption that enough other data sets remain.
      dat <- tryCatch(readRDS(f), error = function(e) NULL)${("$"+table_name) if table_name != "" else ""}
      if (is.null(dat)) {
          message(paste("Skip loading file", f, "due to load failure."))
          next
      }
      # Pull out the two matrices under their configured element names once, so
      # that every later step works on the same objects.
      bhat_mat <- dat$${bhat}
      sbhat_mat <- dat$${sbhat}
      if (is.null(bhat_mat) || is.null(sbhat_mat)) {
          message(paste("Skip loading file", f, "because a required summary statistics element is missing."))
          next
      }
      if (${expected_ncondition} > 0 && (ncol(bhat_mat) != ${expected_ncondition} || ncol(sbhat_mat) != ${expected_ncondition})) {
          message(paste("Skip loading file", f, "because it has", ncol(bhat_mat), "columns different from required", ${expected_ncondition}))
          next
      }
      # compute_grid_endpoints reads elements named Bhat and Shat, so the
      # configured elements are passed under those names.
      endpoints <- compute_grid_endpoints(list(Bhat = bhat_mat, Shat = sbhat_mat))
      # Non-finite endpoints (e.g. when no entry survives the filtering) would
      # make the overall grid undefined, so such files are left out.
      if (!is.finite(endpoints$gmin) || !is.finite(endpoints$gmax)) {
          message(paste("Skip file", f, "because its grid endpoints are not finite."))
          next
      }
      grid_mins <- c(grid_mins, endpoints$gmin)
      grid_maxs <- c(grid_maxs, endpoints$gmax)
    }

    if (length(grid_mins) == 0) {
      stop("No input file yielded usable grid endpoints; cannot compute the grid.")
    }

    # Overall range across all data sets that were kept.
    gmin_tot <- min(grid_mins)
    gmax_tot <- max(grid_maxs)
    # The grid is squared before saving, presumably to turn standard deviations
    # into the variances this notebook describes as its output.
    grid <- autoselect_mixsd(gmin_tot, gmax_tot, mult=sqrt(2))^2

    saveRDS(grid, ${_output:r})