Skip to content

Commit

Permalink
updated code to use pmap on multi-device TPU
Browse files Browse the repository at this point in the history
  • Loading branch information
ericjang committed Nov 29, 2019
1 parent cbcb62a commit 626c668
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 136 deletions.
69 changes: 66 additions & 3 deletions jaxpt.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,22 @@
"rd = normalize(p[:,0,None]*u + p[:,1,None]*v + d*w)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"render_fn = lambda rng_key, ro, rd : trace(rng_key, ro, rd, 0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## non-pmap version"
]
},
{
"cell_type": "code",
"execution_count": 0,
Expand All @@ -529,16 +545,63 @@
"outputs": [],
"source": [
"def render_multisample(rng_key, num_samples):\n",
" img = trace(rng_key, eye, rd, 0)\n",
" img = render_fn(rng_key, eye, rd)\n",
" for i in range(2, num_samples+1):\n",
" rng_key, _ = random.split(rng_key)\n",
" sample = trace(rng_key, eye, rd, 0)\n",
" sample = render_fn(rng_key, eye, rd)\n",
" if i % 10 == 0:\n",
" print('Sample %d' % i)\n",
" img = (img + sample)\n",
" return np.fliplr(np.flipud(img.reshape((N,N,3))))/num_samples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# pmap version"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# in order to pmap this successfully, we divide reshape inputs to have leading axis 8.\n",
"B = eye.shape[0]\n",
"M = jax.local_device_count()\n",
"eye = np.reshape(eye, (M, B//M, 3))\n",
"rd = np.reshape(rd, (M, B//M, 3))\n",
"RNG_KEY = random.split(RNG_KEY, M)\n",
"render_fn = jax.soft_pmap(render_fn)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# PMAP version\n",
"def render_multisample(rng_key, num_samples):\n",
" img = render_fn(rng_key, eye, rd)\n",
" for i in range(2, num_samples+1):\n",
" rng_key = random.split(rng_key[0], M)\n",
" sample = render_fn(rng_key, eye, rd)\n",
" if i % 10 == 0:\n",
" print('Sample %d' % i)\n",
" img = (img + sample)\n",
" return np.fliplr(np.flipud(img.reshape((N,N,3))))/num_samples"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Render"
]
},
{
"cell_type": "code",
"execution_count": 17,
Expand Down Expand Up @@ -574,7 +637,7 @@
],
"source": [
"%%time\n",
"img = trace(RNG_KEY, eye, rd, 0)\n",
"img = render_fn(RNG_KEY, eye, rd)\n",
"print('done')"
]
},
Expand Down
Loading

0 comments on commit 626c668

Please sign in to comment.