Masks are usually binary images where each pixel corresponds to one of two classes: 1 is for the object we are interested in and 0 is for background (including other objects we don't need). For multiclass segmentation masks with more values are used.

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [14]:
! pip install gdal

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [16]:
import numpy as np

from osgeo import ogr, gdal_array
import gdal

%matplotlib inline
import matplotlib.pyplot as plt

In [17]:
np.random.seed(100)

# Image mask

In [18]:
## Open shape file
dataset = ogr.Open('Mask_river.shp')

layer = dataset.GetLayerByIndex(0)

## Open the image
raster_ds = gdal.Open('South_coast.tif', gdal.GA_ReadOnly)

## Fetch number of rows and columns & projection and extent
ncol = raster_ds.RasterXSize
nrow = raster_ds.RasterYSize
proj = raster_ds.GetProjectionRef()
ext = raster_ds.GetGeoTransform()

raster_ds = None

## Create the raster mask
driver = gdal.GetDriverByName('GTiff')

out_raster_ds = driver.Create('river_mask.gtif', ncol, nrow, 1, gdal.GDT_Byte)


## Set the set projection and extents
out_raster_ds.SetProjection(proj)
out_raster_ds.SetGeoTransform(ext)

## Fill our output band with the 0 
nodata = out_raster_ds.GetRasterBand(1)
nodata.Fill(0)

gdal.RasterizeLayer(out_raster_ds,[1],layer,None, None, [0], ['ALL_TOUCHED=TRUE', 'ATTRIBUTE=ID'])

## Close dataset
out_raster_ds = None

AttributeError: ignored

In [None]:
# Tell GDAL to throw Python exceptions, and register all drivers
gdal.UseExceptions()
gdal.AllRegister()

## Let's open the created mask
mask_ds = gdal.Open(directory + 'river_mask.gtif', gdal.GA_ReadOnly)

mask = mask_ds.GetRasterBand(1).ReadAsArray().astype(np.uint8)

## show the data
plt.imshow(mask)
plt.show()

In [None]:
## How does the river look?
img_ds = gdal.Open(directory + 'South_coast.tif', gdal.GA_ReadOnly)
img = img_ds.GetRasterBand(1).ReadAsArray().astype(np.uint8)

mask_ds = gdal.Open(directory + 'river_mask.gtif', gdal.GA_ReadOnly)
mask = mask_ds.GetRasterBand(1).ReadAsArray().astype(np.uint8)

## multiply image and mask to combine them
new_data = mask * img

## show the data
plt.imshow(new_data)
plt.show()

# Clip image

What if we need not only get the object but also crop the image to the object boundary?

In [None]:
!pip install rasterio
!pip install geopandas

In [None]:
import rasterio
import rasterio.mask as msk
from rasterio import plot
import fiona

In [None]:
def cliptoshp(Rast, shape_file):
    """
    This function will clip a given raster to the shape file
    
    Parameters:
    Rast is the input raster that you want to clip
    shape_file store the shape for clipping
    
    Returns the clipped raster
    """

    # open up the shape file
    with fiona.open(shape_file,'r') as shapefile:
        shapes = [feature["geometry"] for feature in shapefile]
    # open up our new raster and crop it to the shape file
    with rasterio.open(Rast) as src:
        out_image, out_transform = msk.mask(src, shapes, crop=True)
        out_meta = src.meta
    out_meta.update({"driver": "GTiff",
                 "height": out_image.shape[1],
                 "width": out_image.shape[2],
                 "transform": out_transform},
                  nodata=float('nan'))
    # save our clipped raster as clipped
    clippedfile = str(Rast[:-4] + '_clipped.tif')
    with rasterio.open(clippedfile, "w", **out_meta) as dest:
        dest.write(out_image)


In [None]:
raster = directory + 'South_coast.tif'
shape_file = directory + 'Mask_river.shp'

## clip the image
cliptoshp(raster, shape_file)

## show the result
clipped = rasterio.open(directory + 'South_coast_clipped.tif')
plot.show(clipped)