Machine learning instrumental variable estimators for Mendelian randomization
The goal of ml-mr is to provide a toolkit to facilitate the use of machine learning estimators developed for instrumental variable causal effect estimation in Mendelian randomization studies.
There are 4 core modules:
- Simulation: Utilities to simulate data and pre-implemented simulation models from published articles.
- Estimation: Unified interface (programmatic and command-line) to different estimators. This module makes it easy to compare algorithms and to train the neural networks on arbitrary hardware (e.g. GPUs) without having to worry about logging, early stopping, model checkpointing, etc.
- Evaluation: Facilitates plotting and computing metrics when the true causal function is known (e.g. in simulation models). It also makes it easy to extract stored model statistics or metadata.
- Sweep: Many of the implemented algorithms rely on hyperparameters. For example, we need to pre-specify the number of quantiles to use in Quantile IV, or the dimensionality of the learned features in the Deep Feature IV algorithm. Furthermore, neural network optimization also relies on hyperparameters such as the learning rate. To make it easy to conduct grid or random hyperparameter searches, we implemented a sweep module that uses multiprocessing to fit models in parallel.
To help understand the different estimands available in ml-mr, we give a short clarifying description. Throughout the table,
MREstimator method or standard terminology | Mathematical object | Definition | Comments |
---|---|---|---|
iv_reg_function |
Estimates the structural function between the exposure and the outcome conditional on specified covariate values. It is often the case that this estimate will be biased. | ||
avg_iv_reg_function |
In practice, we marginalize the covariates from the iv_reg_function by using the observed covariate values. In other words, we use the sample estimate of |
||
ate (Average treatment effect) |
Scalar |
|
This can be expressed as the difference of avg_iv_reg_function calls at reference values (e.g. |
cate (Conditional Average Treatment Effect) |
Scalar at fixed value of C or |
|
This quantity is also defined with respect to a pair of X values that are compared and fixed levels of the observed covariates. |
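To make the relationship between these estimands more concrete, here is a minimal sketch in plain Python. It does not call ml-mr: the toy structural function `h` and the helper names are hypothetical and only mirror the rows of the table, assuming a single covariate.

```python
from statistics import fmean


def h(x, c):
    """Toy structural function: the effect of x depends on the covariate c."""
    return (1.0 + c) * x


def avg_structural_function(x, covariates):
    """Marginalize the covariate by averaging over its observed values."""
    return fmean(h(x, c) for c in covariates)


def ate(x0, x1, covariates):
    """Difference of the marginalized function at two reference exposures."""
    return avg_structural_function(x1, covariates) - avg_structural_function(x0, covariates)


def cate(x0, x1, c):
    """Same contrast, but at a fixed value of the covariate."""
    return h(x1, c) - h(x0, c)


observed_c = [0.0, 1.0, 2.0]
print(ate(0.0, 1.0, observed_c))  # contrast averaged over the observed covariates
print(cate(0.0, 1.0, c=2.0))      # contrast at the fixed covariate value c = 2
```

In this toy model the two quantities differ because the covariate modifies the effect of the exposure.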
We presented ml-mr at the Great Lakes Bioinformatics 2023 conference. The slides are available here.
# Clone repository.
git clone git@github.com:legaultmarc/ml-mr.git
pip install ml-mr
We recommend setting up Weights and Biases. It is integrated with our estimator module and will help track training metrics to ensure convergence and diagnose problems.
A simple simulation model from Burgess S, et al. Epidemiology (2014) can be found here.
A more complex example can be found here. In this model, the (linear) causal effect of the exposure on the outcome varies nonlinearly as a function of an effect modifier variable. It is a good example of where machine learning models may perform better than conventional instrumental variable estimators.
Assuming a TSV file with columns for the instruments, covariables, exposure and outcome, we can easily fit machine learning instrumental variable estimators using the command-line interface.
Here's an example command that would train a Quantile IV model.
ml-mr estimation \
quantile_iv \
--n-quantiles 10 \
--data filename.tsv.gz \
--exposure x --outcome y --instruments v{1..20} \
--wandb-project my_fit
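To try the command above without real data, a file with the expected columns can be generated with a few lines of Python. This is only a sketch: the data-generating model, the effect sizes and the helper names are arbitrary assumptions, and the column names (v1 to v20, x, y) are chosen to match the example command.

```python
import csv
import gzip
import random


def simulate_rows(n=1000, n_instruments=20, seed=0):
    """Simulate a toy dataset: instruments v1..vK, exposure x and outcome y."""
    rng = random.Random(seed)
    header = [f"v{j}" for j in range(1, n_instruments + 1)] + ["x", "y"]
    rows = []
    for _ in range(n):
        # Instruments coded as allele counts (0, 1 or 2).
        v = [rng.randint(0, 1) + rng.randint(0, 1) for _ in range(n_instruments)]
        u = rng.gauss(0, 1)  # unobserved confounder
        x = 0.1 * sum(v) + u + rng.gauss(0, 1)
        y = 0.5 * x + u + rng.gauss(0, 1)  # arbitrary linear causal effect
        rows.append(v + [x, y])
    return header, rows


def write_tsv_gz(path, header, rows):
    """Write a gzip-compressed, tab-separated file with a header line."""
    with gzip.open(path, "wt", newline="") as f:
        writer = csv.writer(f, delimiter="\t")
        writer.writerow(header)
        writer.writerows(rows)


# Usage:
# header, rows = simulate_rows()
# write_tsv_gz("filename.tsv.gz", header, rows)
```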
Algorithm | Reference | Comments |
---|---|---|
Deep IV | Hartford J, et al. ICLR (2017) | We implement three different density estimators for the 1st stage: ridge regression, mixture density network (as proposed in the original paper) and a gaussian network. We recommend using the gaussian network which performed best in simulation studies. This can be done using the --exposure-network-type gaussian_net option. |
Quantile IV | Legault MA, et al. medRxiv (2024) | We derived a new estimator we call Quantile IV. It is similar to Deep IV but uses quantile regression to learn quantiles of the conditional exposure distribution. Instead of sampling, it then uses quantile midpoints to train the outcome network allowing for simple averaging. |
Deep feature IV | Xu L, et al. ICLR (2021) | This method also has two stages, but it is quite different from Deep IV and others. It learns feature mappings for the instruments and for the exposure and estimates the causal effect using penalized two stage least squares regression on these learned features. |
In the future, we hope to include algorithms that rely on deep generalized method of moments and on kernel instrumental variable methods. Other baseline methods (e.g. LACE estimator, DeLIVR, 2SLS) are implemented, but they are meant to be used for simple benchmarking and comparisons.
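For intuition about the Quantile IV row above, the following sketch shows how averaging a function over a few quantile midpoints can stand in for sampling from a distribution. It uses a standard normal and Python's standard library only. The probability levels (k + 0.5) / n are an assumption made for this illustration, and the exact construction used in ml-mr may differ.

```python
import random
from statistics import NormalDist, fmean


def quantile_midpoints(dist, n_quantiles):
    """Values of dist at the probability levels (k + 0.5) / n_quantiles."""
    return [dist.inv_cdf((k + 0.5) / n_quantiles) for k in range(n_quantiles)]


def expectation_by_midpoints(f, dist, n_quantiles):
    """Approximate E[f(X)] by a plain average over the quantile midpoints."""
    return fmean(f(x) for x in quantile_midpoints(dist, n_quantiles))


def expectation_by_sampling(f, dist, n_samples, seed=0):
    """Monte Carlo approximation of E[f(X)], for comparison."""
    rng = random.Random(seed)
    return fmean(f(rng.gauss(dist.mean, dist.stdev)) for _ in range(n_samples))


def square(x):
    """Stand-in for an outcome function; E[X^2] = 1 for a standard normal."""
    return x ** 2


exposure_dist = NormalDist(mu=0.0, sigma=1.0)
print(expectation_by_midpoints(square, exposure_dist, n_quantiles=10))
print(expectation_by_sampling(square, exposure_dist, n_samples=10_000))
```

The midpoint average is deterministic and cheap, but it truncates the tails, so the approximation improves as the number of quantiles grows.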
The integrated help message of the command-line interface details the full array of options for every estimator. For example, the help message from Deep IV is shown below:
usage: ml-mr estimation deep_iv [-h] [--n-gaussians N_GAUSSIANS] [--exposure-network-type {mixture_density_net,gaussian_net,ridge}]
[--output-dir OUTPUT_DIR] [--no-plot] [--alpha ALPHA] [--outcome-type {continuous,binary}]
[--validation-proportion VALIDATION_PROPORTION] [--accelerator ACCELERATOR]
[--wandb-project WANDB_PROJECT] [--exposure-hidden [EXPOSURE_HIDDEN ...]]
[--exposure-max-epochs EXPOSURE_MAX_EPOCHS] [--exposure-batch-size EXPOSURE_BATCH_SIZE]
[--exposure-optimizer EXPOSURE_OPTIMIZER] [--exposure-learning-rate EXPOSURE_LEARNING_RATE]
[--exposure-weight-decay EXPOSURE_WEIGHT_DECAY] [--exposure-add-input-batchnorm]
[--outcome-hidden [OUTCOME_HIDDEN ...]] [--outcome-max-epochs OUTCOME_MAX_EPOCHS]
[--outcome-batch-size OUTCOME_BATCH_SIZE] [--outcome-optimizer OUTCOME_OPTIMIZER]
[--outcome-learning-rate OUTCOME_LEARNING_RATE] [--outcome-weight-decay OUTCOME_WEIGHT_DECAY]
[--outcome-add-input-batchnorm] --data DATA [--sep SEP] [--instruments [INSTRUMENTS ...]]
[--covariables [COVARIABLES ...]] --exposure EXPOSURE --outcome OUTCOME
[--genotypes-backend GENOTYPES_BACKEND] [--genotypes-backend-type GENOTYPES_BACKEND_TYPE]
[--sample-id-col SAMPLE_ID_COL]
optional arguments:
-h, --help show this help message and exit
--n-gaussians N_GAUSSIANS
Number of gaussians used for the mixture density network.
--exposure-network-type {mixture_density_net,gaussian_net,ridge}
Density model for the exposure network.
--output-dir OUTPUT_DIR
--no-plot Disable plotting of diagnostics.
--alpha ALPHA Miscoverage level for the prediction interval.
--outcome-type {continuous,binary}
Variable type for the outcome (binary vs continuous).
--validation-proportion VALIDATION_PROPORTION
--accelerator ACCELERATOR
Accelerator (e.g. gpu, cpu, mps) use to train the model. This will be passed to Pytorch Lightning.
--wandb-project WANDB_PROJECT
Activates the Weights and Biases logger using the provided project name. Patterns such as project:run_name
are also allowed.
--data DATA, -d DATA Path to a data file.
--sep SEP Separator (column delimiter) for the data file.
--instruments [INSTRUMENTS ...], -z [INSTRUMENTS ...]
The instrument (Z or G) in the case where we're not using genotypes provided through --genotypes. Multiple
values can be provided for multiple instruments. This should be column(s) in the data file.
--covariables [COVARIABLES ...]
Variables which will be included in both stages.This should be column(s) in the data file.
--exposure EXPOSURE, -x EXPOSURE
The exposure (X). This should be a column name in the data file.
--outcome OUTCOME, -y OUTCOME
The outcome (Y). This should be a column name in the data file.
--genotypes-backend GENOTYPES_BACKEND
Pickle containing a pytorch-genotypes backend. This can be created from various genetic data formats using
the 'pt-geno-create-backend' command line utility provided by pytorch genotypes.
--genotypes-backend-type GENOTYPES_BACKEND_TYPE
Pickle containing a pytorch-genotypes backend. This can be created from various genetic data formats using
the 'pt-geno-create-backend' command line utility provided by pytorch genotypes.
--sample-id-col SAMPLE_ID_COL
Column that contains the sample id. This is mandatory if genotypes are provided to enable joining.
Exposure Model Parameters:
--exposure-hidden [EXPOSURE_HIDDEN ...]
--exposure-max-epochs EXPOSURE_MAX_EPOCHS
--exposure-batch-size EXPOSURE_BATCH_SIZE
--exposure-optimizer EXPOSURE_OPTIMIZER
--exposure-learning-rate EXPOSURE_LEARNING_RATE
--exposure-weight-decay EXPOSURE_WEIGHT_DECAY
--exposure-add-input-batchnorm
Outcome Model Parameters:
--outcome-hidden [OUTCOME_HIDDEN ...]
--outcome-max-epochs OUTCOME_MAX_EPOCHS
--outcome-batch-size OUTCOME_BATCH_SIZE
--outcome-optimizer OUTCOME_OPTIMIZER
--outcome-learning-rate OUTCOME_LEARNING_RATE
--outcome-weight-decay OUTCOME_WEIGHT_DECAY
--outcome-add-input-batchnorm
To estimate confidence intervals, we often rely on bagging. Here we describe how to achieve this for the Quantile IV estimator. The other estimators may not support bootstrapping yet. We rely on GNU Parallel to fit the bootstrap resamples on an arbitrary number of devices.
First, we fit 50 Quantile IV models with bootstrap resampling (i.e. resampling the
# Use -j to control the number of parallel jobs.
# Here, we use 20 jobs and consequently 20 CPUs.
#
# We recommend using --halt-on-error 2 so that parallel stops if one of the
# fits crashes with an error. This most often happens when there is a file not
# found error or other problems.
#
# For Quantile IV, the --fast flag avoids saving plots of the causal effect
# and other statistics that are irrelevant when we want to do bagging.
#
# We use the 'seq' command to generate an increasing list of integers
# corresponding to the bootstrap iteration index.
parallel --halt-on-error 2 -j 20 ml-mr estimation quantile_iv \
--data $filename \
--sep ',' \
--exposure exposure \
--outcome outcome \
--instruments v1 v2 v3 v4 v5 v6 v7 v8 v9 v10 \
--fast \
--resample \
--output "bs_"{} ::: $(seq 1 50)
# After fitting, the ensembling can be done using the API or, for example
# using the ml-mr evaluation module which will make plotting easy.
ml-mr evaluation \
--true-function='lambda x: 0' \
--plot --plot-filename my_ensemble_estimator.png \
--input bs_* \
--ensemble
To get an ensemble estimator instance using the API, the procedure looks like:
from ml_mr.estimation.core import EnsembleMREstimator
from ml_mr.estimation.quantile_iv import QuantileIVEstimator


def get_ensemble_estimator(template, n, cls=QuantileIVEstimator):
    # template is a format string such as "bs_{i}". Indices run from 0 to
    # n - 1 and missing results are skipped, so n must exceed the largest
    # index used (e.g. n=51 for bs_1 to bs_50).
    estimators = []
    for i in range(n):
        try:
            filename = template.format(i=i)
            estimator = cls.from_results(filename)
            estimators.append(estimator)
        except FileNotFoundError:
            pass

    return EnsembleMREstimator(*estimators)


# The returned object will support all of the methods (e.g. ate, cate,
# iv_reg_function) from MREstimator described above.
There is also a script in the ml-mr/scripts folder to facilitate plotting of ensemble estimators.
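As a conceptual illustration of bagging, the sketch below refits a trivial model on bootstrap resamples, averages the fits and uses empirical percentiles as an interval. It is not an instrumental variable estimator and does not reproduce how EnsembleMREstimator aggregates models. The least-squares slope, the helper names and the interval construction are all assumptions made for this example.

```python
import random
from statistics import fmean, quantiles


def fit_slope(xs, ys):
    """Least-squares slope through the origin; stands in for a fitted model."""
    return sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)


def bootstrap_slopes(xs, ys, n_bootstrap=50, seed=0):
    """Refit the model on datasets resampled with replacement."""
    rng = random.Random(seed)
    n = len(xs)
    slopes = []
    for _ in range(n_bootstrap):
        idx = [rng.randrange(n) for _ in range(n)]
        slopes.append(fit_slope([xs[i] for i in idx], [ys[i] for i in idx]))
    return slopes


def bagged_prediction(slopes, x, alpha=0.1):
    """Average the resampled fits; empirical percentiles give an interval."""
    preds = [s * x for s in slopes]
    cuts = quantiles(preds, n=100, method="inclusive")  # 1st to 99th percentiles
    lower = cuts[round(100 * alpha / 2) - 1]
    upper = cuts[round(100 * (1 - alpha / 2)) - 1]
    return fmean(preds), (lower, upper)


# Toy data with a true slope of 0.5.
rng = random.Random(1)
xs = [rng.gauss(0, 1) for _ in range(200)]
ys = [0.5 * x + rng.gauss(0, 0.1) for x in xs]

estimate, (lower, upper) = bagged_prediction(bootstrap_slopes(xs, ys), x=1.0)
print(estimate, lower, upper)
```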
When the true causal function is known, the evaluation module can be used to compare different instrumental variable estimates.
The format of the command is:
ml-mr \
evaluation \
--input dfiv_estimate \
--true-function 'lambda x: (0.3 * (x + 2) * (x - 1) + 1)' \
--plot
Instead of a lambda, the --true-function argument can be a function in a Python file:
ml-mr \
evaluation \
--input dfiv_estimate \
--true-function 'my_file.py:my_causal_function' \
--plot
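For reference, the file passed to --true-function in the command above could look like the following. The file and function names simply match the example, and we assume the function receives exposure values and returns the corresponding values of the causal function. Because it only uses arithmetic, it works for floats as well as for array-like inputs.

```python
# my_file.py (hypothetical file, named to match the command above)


def my_causal_function(x):
    """True causal function evaluated at the exposure value(s) x."""
    return 0.3 * (x + 2) * (x - 1) + 1
```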
It is also possible to glob multiple estimates (which is especially useful in a sweep):
ml-mr \
evaluation \
--input my_sweep/*/estimate* \
--true-function 'my_file.py:my_causal_function' \
--plot
A common practice is to focus on the central 95% of the exposure range. This can be done automatically using the --domain-95 flag. If the user wishes to ensemble the provided estimators (e.g. if they were fit using bootstrap resamples and the user wants a simple bagging estimator), the --ensemble flag can be used.
Calling the evaluation module also computes metrics that are printed in CSV format. By default, the mean squared error is computed over a linear space spanning the observed data domain (i.e. from min(exposure) to max(exposure)). This means that we take evenly spaced exposure values between the observed minimum and maximum values, evaluate the true and estimated functions at these values, and take the mean of their squared differences.
The prediction interval width is computed similarly, but taking the mean of the width of the prediction interval over the domain.
If one is interested in the performance on a different space (e.g. 1st to 99th percentiles), it is currently necessary to use the MREstimator API directly.
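The two metrics described above can be sketched in a few lines of plain Python. This is an illustration of the computation only, not the ml-mr implementation: the helper names and the number of grid points are assumptions, and the functions take ordinary Python callables rather than MREstimator instances.

```python
def linspace(start, stop, num):
    """Evenly spaced values from start to stop, both included."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def mse_over_domain(true_fn, estimated_fn, domain, num=1000):
    """Mean squared difference between two functions over a linear grid."""
    grid = linspace(domain[0], domain[1], num)
    return sum((true_fn(x) - estimated_fn(x)) ** 2 for x in grid) / num


def mean_interval_width(lower_fn, upper_fn, domain, num=1000):
    """Mean width of the interval [lower_fn(x), upper_fn(x)] over the grid."""
    grid = linspace(domain[0], domain[1], num)
    return sum(upper_fn(x) - lower_fn(x) for x in grid) / num


def true_fn(x):
    return 0.3 * (x + 2) * (x - 1) + 1


def estimated_fn(x):
    # Pretend estimate that is off by a constant 0.1 everywhere.
    return true_fn(x) + 0.1


domain = (-2.0, 2.0)  # e.g. (min(exposure), max(exposure)) or chosen percentiles
mse = mse_over_domain(true_fn, estimated_fn, domain)
width = mean_interval_width(
    lambda x: estimated_fn(x) - 0.5, lambda x: estimated_fn(x) + 0.5, domain
)
print(mse, width)
```

Changing `domain` to a pair of exposure percentiles gives the metric on a restricted range.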
Estimators store a lot of information in JSON format in a file called meta.json. This file is useful to allow reloading the estimator from fitted models, but it also keeps track of validation metrics and other hyperparameters.
For example, here's a meta.json file from a fitted Quantile IV model:
{
"q": 20,
"output_dir": "estimate_run_46",
"validation_proportion": 0.2,
"fast": 1,
"sqr": true,
"exposure_hidden": [
128,
64
],
"exposure_learning_rate": 0.0009258416040756616,
"exposure_weight_decay": 0.0001,
"exposure_batch_size": 10000,
"exposure_max_epochs": 1000,
"exposure_add_input_batchnorm": false,
"outcome_hidden": [
64,
32
],
"outcome_learning_rate": 0.000875839535511659,
"outcome_weight_decay": 0.0001,
"outcome_batch_size": 10000,
"outcome_max_epochs": 1000,
"outcome_add_input_batchnorm": false,
"accelerator": "cpu",
"wandb_project": null,
"model": "quantile_iv",
"domain": [
-4.534079074859619,
4.357717037200928
],
"exposure_95_percentile": [
-1.9370369940996168,
1.9681992024183272
],
"exposure_val_loss": 4.551474571228027,
"outcome_val_loss": 0.7311554551124573,
"q_hat": 7.812013626098633
}
To avoid having to parse this manually, we can specify fields to extract in the ml-mr evaluation command. For example, to include the q and outcome_val_loss fields in the CSV output, use the --meta-keys q outcome_val_loss argument.
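If you do need the values in Python, reading the file directly is also straightforward. The sketch below assumes that meta.json sits inside the estimator's output directory. The helper name is hypothetical, and the small demonstration writes a stand-in file to a temporary directory.

```python
import json
import tempfile
from pathlib import Path


def extract_meta_keys(output_dir, keys):
    """Read meta.json from an output directory and keep the requested fields."""
    with open(Path(output_dir) / "meta.json") as f:
        meta = json.load(f)
    # Missing keys are reported as None rather than raising an error.
    return {key: meta.get(key) for key in keys}


# Demonstration with a stand-in meta.json in a temporary directory.
with tempfile.TemporaryDirectory() as tmp:
    stand_in = {"q": 20, "outcome_val_loss": 0.73, "model": "quantile_iv"}
    (Path(tmp) / "meta.json").write_text(json.dumps(stand_in))
    result = extract_meta_keys(tmp, ["q", "outcome_val_loss"])

print(result)
```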
Hyperparameter sweeps work by specifying a JSON configuration file and calling
ml-mr sweep my_config.json
The configuration should define how to read a dataset, what algorithm to fit as well as configuration parameter samplers.
To control the number of workers fitting models in parallel, use the --n-workers argument.
Simple examples can be found in this repository's "examples" directory.
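For intuition about the difference between grid and random searches, here is a generic sketch of both sampling strategies. It does not use the ml-mr configuration format (see the "examples" directory for that). The helper names are hypothetical and the hyperparameter names are borrowed from the meta.json example above.

```python
import itertools
import math
import random


def grid_search(space):
    """Every combination of the listed values (Cartesian product)."""
    names = list(space)
    for values in itertools.product(*(space[name] for name in names)):
        yield dict(zip(names, values))


def random_search(samplers, n_runs, seed=0):
    """Draw each hyperparameter independently for every run."""
    rng = random.Random(seed)
    for _ in range(n_runs):
        yield {name: sampler(rng) for name, sampler in samplers.items()}


def log_uniform(low, high):
    """Sampler that is uniform on the log scale, often used for learning rates."""
    return lambda rng: math.exp(rng.uniform(math.log(low), math.log(high)))


grid = list(grid_search({"q": [5, 10, 20], "outcome_weight_decay": [1e-4, 1e-3]}))
runs = list(random_search(
    {
        "q": lambda rng: rng.choice([5, 10, 20]),
        "outcome_learning_rate": log_uniform(1e-4, 1e-2),
    },
    n_runs=5,
))
print(len(grid), runs[0])
```

Each resulting dictionary corresponds to one model fit, which is the kind of unit of work that can be distributed across parallel workers.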
To ensure that sampling is done as expected, it's possible to use the --create-db-only option. This will create an ml_mr_sweep_runs.db sqlite3 database without starting worker processes. You can then use sqlite3 to list all sampled parameter values and runs:
-- Connect using sqlite3 ml_mr_sweep_runs.db
-- You may also want to use .header on and .mode columns to make it easier
-- to read.
select * from run_parameters;
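The same query can be run from Python with the standard sqlite3 module. In this sketch, only the run_parameters table name comes from the query above. The helper name is hypothetical, and the column names are read from the cursor because we make no assumption about the table's schema.

```python
import sqlite3


def list_run_parameters(conn):
    """Return the column names and all rows of the run_parameters table."""
    cursor = conn.execute("select * from run_parameters")
    columns = [col[0] for col in cursor.description]
    return columns, cursor.fetchall()


# Usage, assuming the database was created in the current directory:
# conn = sqlite3.connect("ml_mr_sweep_runs.db")
# columns, rows = list_run_parameters(conn)
# print(columns)
# for row in rows:
#     print(row)
# conn.close()
```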
You can then start or resume the sweep using ml-mr sweep ml_mr_sweep_runs.db --n-workers N.