Skip to content

Commit

Permalink
Try mapping label and convert to masks
Browse files Browse the repository at this point in the history
Also make the chunks concrete (as they are undefined) to allow storage.
  • Loading branch information
jakirkham committed Apr 3, 2018
1 parent 192d030 commit 8ce5237
Showing 1 changed file with 85 additions and 21 deletions.
106 changes: 85 additions & 21 deletions nanshe_ipython.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1162,6 +1162,60 @@
"* `noise_threshold` (`float`): number of units of \"noise\" above which something needs to be to be significant"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def np_label(a):\n",
" return np.array(\n",
" scipy.ndimage.label(a),\n",
" dtype=[\n",
" (\"label\", np.int32, da_imgs_thrd.shape),\n",
" (\"num\", int, ()),\n",
" ]\n",
" )\n",
"\n",
"def label_chunk(a):\n",
" return np.stack([np_label(e) for e in a])\n",
"\n",
"def label(d):\n",
" d_lbld = d.map_blocks(\n",
" label_chunk,\n",
" dtype=[\n",
" (\"label\", np.int32, d.shape[1:]),\n",
" (\"num\", int, ()),\n",
" ],\n",
" drop_axis=tuple(irange(1, d.ndim))\n",
" )\n",
"\n",
" return d_lbld[\"label\"], d_lbld[\"num\"]\n",
"\n",
"def np_labels_to_masks_chunk(a, num):\n",
" r = np.empty((0,) + a.shape[1:], dtype=bool)\n",
" if num:\n",
" r = np.concatenate([a == i for i in irange(1, 1 + num)])\n",
" return r\n",
"\n",
"def labels_to_masks_chunk(a, num):\n",
" r = np.empty((0,) + a.shape[2:], dtype=bool)\n",
" if len(num):\n",
" r = np.concatenate([np_labels_to_masks_chunk(e0, e1) for e0, e1 in zip(a, num)])\n",
" return r\n",
"\n",
"def labels_to_masks(d, nums):\n",
" out = da.atop(\n",
" labels_to_masks_chunk, tuple(irange(d.ndim)),\n",
" d, tuple(irange(d.ndim)),\n",
" nums, tuple(irange(nums.ndim)),\n",
" dtype=bool\n",
" )\n",
" out._chunks = (len(out.chunks[0]) * (np.nan,),) + out.chunks[1:]\n",
"\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -1180,43 +1234,53 @@
"da_imgs = dask_store[subgroup_dict]\n",
"da_imgs = da_imgs.rechunk(((1,) + da_imgs.shape[1:]))\n",
"\n",
"da_imgs = da_imgs[0]\n",
"\n",
"da_imgs_thrd = (da_imgs - noise_threshold * (da_imgs - significance_threshold * da_imgs.std()).std()) > 0\n",
"\n",
"da_lbl_img, da_num_lbls = dask_ndmeasure.label(da_imgs_thrd)\n",
"da_lbl_img, da_num_lbls = client.persist([da_lbl_img, da_num_lbls])\n",
"da_lbl_img, da_num_lbls = label(da_imgs_thrd)\n",
"\n",
"da_result = []\n",
"for i in irange(1, 1 + int(da_num_lbls)):\n",
" da_result.append(da_lbl_img == i)\n",
"da_result = da.stack(da_result)\n",
"da_result = labels_to_masks(da_lbl_img, da_num_lbls)\n",
"da_result = da_result.astype(np.uint8)\n",
"\n",
"dask_store[subgroup_post_mask] = da_result\n",
"da_result = client.persist(da_result)\n",
"\n",
"dask.distributed.progress(dask_store[subgroup_post_mask], notebook=False)\n",
"print(\"\")\n",
"dask.distributed.progress(da_result, notebook=False)\n",
"\n",
"\n",
"# View results\n",
"imgs_min, imgs_max = 0, 100\n",
"# Make chunks concrete\n",
"\n",
"da_imgs = dask_store[subgroup_post_mask]\n",
"da_imgs = da_imgs.astype(np.uint8)\n",
"da_result_chunks_0 = tuple(\n",
" da_result[:, 0, 0].map_blocks(lambda e: np.atleast_1d(np.ones_like(e).astype(int).sum())).compute()\n",
")\n",
"\n",
"da_imgs_min, da_imgs_max = da_imgs.min(), da_imgs.max()\n",
"da_result_keys_0, da_result_chunks_0 = list(zip(*[[k, c] for k, c in zip(dask.core.flatten(da_result.__dask_keys__()), da_result_chunks_0) if c]))\n",
"\n",
"status = client.compute([da_imgs_min, da_imgs_max])\n",
"dask.distributed.progress(status, notebook=False)\n",
"da_result_chunks = (\n",
" (da_result_chunks_0,) + da_result.chunks[1:]\n",
")\n",
"\n",
"# (cls, dask, name, chunks, dtype, shape=None):\n",
"da_result_2 = da.Array(\n",
" dask.sharedict.merge(dask.optimization.cull(da_result.__dask_graph__(), list(da_result_keys_0))[0]),\n",
" da_result.name,\n",
" da_result_chunks,\n",
" da_result.dtype\n",
")\n",
"\n",
"dask_store[subgroup_post_mask] = da_result_2\n",
"\n",
"dask.distributed.progress(dask_store[subgroup_post_mask], notebook=False)\n",
"print(\"\")\n",
"\n",
"imgs_min, imgs_max = [s.result() for s in status]\n",
"\n",
"# View results\n",
"da_imgs = dask_store[subgroup_post_mask]\n",
"da_imgs = da_imgs.astype(np.uint8)\n",
"\n",
"mplsv = plt.figure(FigureClass=MPLViewer)\n",
"mplsv.set_images(\n",
" da_imgs,\n",
" vmin=imgs_min,\n",
" vmax=imgs_max\n",
" vmin=0,\n",
" vmax=1\n",
")"
]
},
Expand Down

0 comments on commit 8ce5237

Please sign in to comment.