In [None]:
def coastal_buffer(tileId, land_mask, res=20, original_crs="EPSG:4326", metric_crs="EPSG:32630"):
    """Rasterize a tile's buffered coastline onto the land-mask grid, excluding land.

    Reads ``coastline_buffer/Buffered_Coastline_{tileId}.wkt`` under ``DATA_DIR``,
    reprojects the geometry from ``original_crs`` to ``metric_crs``, burns it into
    a ``uint8`` raster with the same shape as ``land_mask`` and then clears every
    pixel that ``land_mask`` flags.

    Args:
        tileId: Tile identifier interpolated into the WKT file name.
        land_mask: Template array providing ``shape``, ``dims``, ``coords`` and
            2-D ``lat``/``lon`` coordinates (indexed ``[0, 0]`` for the grid
            origin). Presumably an ``xr.DataArray`` of boolean or 0/1 integer
            values where truthy means land -- the bitwise ``~`` used for masking
            relies on that; TODO confirm against callers.
        res: Pixel size in ``metric_crs`` units, applied to both axes.
        original_crs: CRS of the WKT geometry and of the ``lat``/``lon``
            coordinates.
        metric_crs: Projected CRS in which the output grid is defined.

    Returns:
        A DataArray named ``'coastal_buffer'`` with ``land_mask``'s coords and
        dims: pixels inside the buffer and not on land carry ``rasterize``'s
        burn value, all others are 0.

    Raises:
        OSError: If the WKT file cannot be opened.
    """
    buffer_wkt_fp = os.path.join(DATA_DIR, f'coastline_buffer/Buffered_Coastline_{tileId}.wkt')
    with open(buffer_wkt_fp, 'r') as f:
        buffer_geom = from_wkt(f.read())

    # Only multi-part geometries (MultiPolygon, GeometryCollection) expose
    # ``.geoms``; a file holding a single POLYGON would otherwise raise
    # AttributeError, so wrap a single-part geometry in a one-element list.
    parts = getattr(buffer_geom, 'geoms', None)
    buffer_parts = [buffer_geom] if parts is None else list(parts)

    # Get transform and shape from template raster
    reproj = pyproj.Transformer.from_crs(original_crs, metric_crs, always_xy=True).transform
    lat_origin = land_mask['lat'].values[0, 0]
    lon_origin = land_mask['lon'].values[0, 0]
    x_origin, y_origin = reproj(lon_origin, lat_origin)
    # NOTE(review): from_origin takes the outer corner of the first pixel; if
    # lat/lon hold pixel centres the grid is shifted by half a pixel -- confirm.
    # Also assumes the grid is north-up and regular in metric_crs -- confirm.
    transform = from_origin(x_origin, y_origin, res, res)
    out_shape = land_mask.shape

    # Rasterize the reprojected buffer parts; pixels outside them stay 0.
    buffer_raster = rasterize(
        [shapely_transform(reproj, geom) for geom in buffer_parts],
        out_shape=out_shape,
        transform=transform,
        fill=0,
        dtype='uint8'
    )

    # Mask land: keep buffer pixels only where land_mask is not set.
    masked_raster = buffer_raster & ~land_mask

    # Convert to DataArray on the same grid as the template.
    buffer_da = xr.DataArray(
        masked_raster,
        coords=land_mask.coords,
        dims=land_mask.dims,
        name='coastal_buffer'
    )

    return buffer_da