Commit
Showing 1 changed file with 291 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,291 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%matplotlib inline\n", | ||
"%config InlineBackend.figure_format = \"retina\"\n", | ||
"\n", | ||
"from __future__ import print_function\n", | ||
"\n", | ||
"from matplotlib import rcParams\n", | ||
"rcParams[\"savefig.dpi\"] = 100\n", | ||
"rcParams[\"figure.dpi\"] = 100\n", | ||
"rcParams[\"font.size\"] = 20" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"# Parallelization\n", | |
"\n", | |
"With emcee, it's easy to make use of multiple CPUs to speed up slow sampling.\n", | |
"Parallelization always introduces some computational overhead, so it is only beneficial when the model is expensive to evaluate, but this is often true for real research problems.\n", | |
"All parallelization techniques are accessed using the `pool` keyword argument of the `EnsembleSampler` class but, depending on your system and your model, there are a few pool options that you can choose from.\n", | |
"In general, a `pool` is any Python object with a `map` method that can be used to apply a function to a list of numpy arrays.\n", | ||
"Below, we will discuss a few options.\n", | ||
"\n", | ||
"This tutorial was executed with the following version of emcee:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"3.0.0.dev0\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import emcee\n", | ||
"print(emcee.__version__)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In all of the following examples, we'll test the code with the following convoluted model:" | ||
] | ||
}, | ||
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": true | |
}, | |
"outputs": [], | |
"source": [ | |
"import time\n", | |
"import numpy as np\n", | |
"\n", | |
"def log_prob(theta):\n", | |
"    # Sleep for 5-8 ms to mimic an expensive model evaluation\n", | |
"    time.sleep(np.random.uniform(0.005, 0.008))\n", | |
"    return -0.5*np.sum(theta**2)" | |
] | ||
}, | ||
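{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"This probability function sleeps for a random fraction of a second every time it is called, to emulate a model that is computationally expensive to evaluate. Sampling in serial gives the following baseline:" | |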
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|██████████| 100/100 [00:24<00:00, 4.14it/s]" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Serial took 24.8 seconds\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"np.random.seed(42)\n", | ||
"initial = np.random.randn(32, 5)\n", | ||
"nwalkers, ndim = initial.shape\n", | ||
"nsteps = 100\n", | ||
"\n", | ||
"sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob)\n", | ||
"start = time.time()\n", | ||
"sampler.run_mcmc(initial, nsteps, progress=True)\n", | ||
"end = time.time()\n", | ||
"serial_time = end - start\n", | ||
"print(\"Serial took {0:.1f} seconds\".format(serial_time))" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## Multiprocessing\n", | |
"\n", | |
"The simplest method of parallelizing emcee is to use the [multiprocessing module from the standard library](https://docs.python.org/3/library/multiprocessing.html).\n", | |
"To parallelize the above sampling, create a `multiprocessing.Pool` and pass it to the sampler using the `pool` argument:" | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"100%|██████████| 100/100 [00:06<00:00, 15.71it/s]\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Multiprocessing took 6.4 seconds\n", | ||
"3.9 times faster than serial\n" | ||
] | ||
} | ||
], | |
"source": [ | |
"from multiprocessing import Pool\n", | |
"\n", | |
"with Pool() as pool:\n", | |
"    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)\n", | |
"    start = time.time()\n", | |
"    sampler.run_mcmc(initial, nsteps, progress=True)\n", | |
"    end = time.time()\n", | |
"    multi_time = end - start\n", | |
"    print(\"Multiprocessing took {0:.1f} seconds\".format(multi_time))\n", | |
"    print(\"{0:.1f} times faster than serial\".format(serial_time / multi_time))" | |
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"4 CPUs\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from multiprocessing import cpu_count\n", | ||
"ncpu = cpu_count()\n", | ||
"print(\"{0} CPUs\".format(ncpu))" | ||
] | ||
}, | ||
{ | |
"cell_type": "markdown", | |
"metadata": {}, | |
"source": [ | |
"## MPI\n", | |
"\n", | |
"Multiprocessing can only be used for distributing calculations across processors on one machine.\n", | |
"If you want to take advantage of a bigger cluster, you'll need to use MPI.\n", | |
"In that case, you need to execute the code using the `mpiexec` executable, so this demo writes the sampling code to a script file and runs that script with `mpiexec`, using the `MPIPool` from schwimmbad:" | |
] | ||
}, | ||
{ | |
"cell_type": "code", | |
"execution_count": 21, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"with open(\"script.py\", \"w\") as f:\n", | |
"    f.write(\"\"\"\n", | |
"import sys\n", | |
"import time\n", | |
"import emcee\n", | |
"import numpy as np\n", | |
"from schwimmbad import MPIPool\n", | |
"\n", | |
"def log_prob(theta):\n", | |
"    time.sleep(np.random.uniform(0.005, 0.008))\n", | |
"    return -0.5*np.sum(theta**2)\n", | |
"\n", | |
"with MPIPool() as pool:\n", | |
"    # Worker processes wait for tasks from the master, then exit\n", | |
"    if not pool.is_master():\n", | |
"        pool.wait()\n", | |
"        sys.exit(0)\n", | |
"\n", | |
"    # Only the master process runs the sampler\n", | |
"    np.random.seed(42)\n", | |
"    initial = np.random.randn(32, 5)\n", | |
"    nwalkers, ndim = initial.shape\n", | |
"    nsteps = 100\n", | |
"\n", | |
"    sampler = emcee.EnsembleSampler(nwalkers, ndim, log_prob, pool=pool)\n", | |
"    start = time.time()\n", | |
"    sampler.run_mcmc(initial, nsteps)\n", | |
"    end = time.time()\n", | |
"    print(end - start)\n", | |
"\"\"\")\n", | ||
"\n", | ||
"mpi_time = !mpiexec -n {ncpu} python script.py\n", | ||
"mpi_time = float(mpi_time[0])\n", | ||
"print(\"MPI took {0:.1f} seconds\".format(mpi_time))\n", | ||
"print(\"{0:.1f} times faster than serial\".format(serial_time / mpi_time))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 22, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"MPI took 9.9 seconds\n", | ||
"2.5 times faster than serial\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\"MPI took {0:.1f} seconds\".format(mpi_time))\n", | ||
"print(\"{0:.1f} times faster than serial\".format(serial_time / mpi_time))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"collapsed": true | ||
}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.1" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
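
The notebook says that a `pool` is any Python object with a `map` method that applies a function to a list of numpy arrays. The sketch below illustrates that statement without depending on emcee. `evaluate_walkers` is a hypothetical helper written for this example and is not how emcee is implemented internally. `ThreadPoolExecutor` is also an assumption that does not appear in the notebook; it is used only because it exposes a compatible `map` method.

```python
import time
from concurrent.futures import ThreadPoolExecutor

import numpy as np


def log_prob(theta):
    # Same toy model as in the notebook, but with a fixed sleep
    # so that the example is deterministic.
    time.sleep(0.005)
    return -0.5 * np.sum(theta**2)


def evaluate_walkers(coords, pool=None):
    """Hypothetical helper: evaluate log_prob once per walker.

    Uses the built-in map when no pool is given, otherwise pool.map.
    """
    map_func = map if pool is None else pool.map
    return np.array(list(map_func(log_prob, coords)))


rng = np.random.default_rng(42)
coords = rng.standard_normal((32, 5))  # 32 walkers, 5 dimensions

# Serial evaluation
serial = evaluate_walkers(coords)

# Evaluation through an object that provides a map method
with ThreadPoolExecutor(max_workers=4) as pool:
    threaded = evaluate_walkers(coords, pool=pool)

print(np.allclose(serial, threaded))
```

Threads help in this sketch only because `time.sleep` releases the global interpreter lock. For a model that is CPU-bound in pure Python, a process-based pool such as `multiprocessing.Pool`, as used in the notebook, is likely the better choice.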