From c8d51578adc4fa23d827c9b8ec4e87d0b3aa4b80 Mon Sep 17 00:00:00 2001 From: Qiusheng Wu Date: Thu, 10 Nov 2022 14:53:25 -0500 Subject: [PATCH] Added add_mask_to_image() function (#306) --- leafmap/common.py | 45 ++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/leafmap/common.py b/leafmap/common.py index d3d1cdf64..97305e312 100644 --- a/leafmap/common.py +++ b/leafmap/common.py @@ -4696,7 +4696,7 @@ def numpy_to_cog( """Converts a numpy array to a COG file. Args: - np_array (np.array): A numpy array representing the image. + np_array (np.array): A numpy array representing an image or an HTTP URL to an image. out_cog (str): The output COG file path. bounds (tuple, optional): The bounds of the image in the format of (minx, miny, maxx, maxy). Defaults to None. profile (str | dict, optional): File path to an existing COG file or a dictionary representing the profile. Defaults to None. @@ -4714,6 +4714,12 @@ def numpy_to_cog( from rio_cogeo.profiles import cog_profiles warnings.filterwarnings("ignore") + + if isinstance(np_array, str): + + with rasterio.open(np_array, "r") as ds: + np_array = ds.read() + if not isinstance(np_array, np.ndarray): raise TypeError("The input array must be a numpy array.") @@ -7127,3 +7133,40 @@ def get_overlap(img1, img2, overlap, out_img1=None, out_img2=None, to_cog=True): clip_image(img2, overlap, out_img2, to_cog=to_cog) return overlap + + +def add_mask_to_image(image, mask, output, color="red"): + """Overlay a binary mask (e.g., roads, building footprints, etc) on an image. Credits to Xingjian Shi for the sample code. + + Args: + image (str): A local path or HTTP URL to an image. + mask (str): A local path or HTTP URL to a binary mask. + output (str): A local path to the output image. + color (str, optional): Color of the mask. Defaults to 'red'. + + Raises: + ImportError: If rasterio and detectron2 are not installed. + """ + try: + import rasterio + from detectron2.utils.visualizer import Visualizer + from PIL import Image + except ImportError: + raise ImportError( + "Please install rasterio and detectron2 to use this function. See https://detectron2.readthedocs.io/en/latest/tutorials/install.html" + ) + + ds = rasterio.open(image) + image_arr = ds.read() + + mask_arr = rasterio.open(mask).read() + + vis = Visualizer(image_arr.transpose((1, 2, 0))) + vis.draw_binary_mask(mask_arr[0] > 0, color=color) + + out_arr = Image.fromarray(vis.get_output().get_image()) + + out_arr.save(output) + + if ds.crs is not None: + numpy_to_cog(output, output, profile=image)