# 4-4A - Cloud and cloud-shadow masks pipeline for CBERS4A with four spectral bands
Python notebook with the proposed pipeline for generating cloud and cloud shadow masks with the bands Blue, Green, Red and Nir. <br>
Developed for Geospatial Data Science course, INPE (Instituto Nacional de Pesquisas Espaciais) <br>
Author: JÃºlia Ascencio Cansado

---

## INTRODUCTION
The notebook works by creating a local database, storing the informations / metadata needed to fetch the important bits!
If you process multiple scenes, I would recommend filtering the database afterwards - that way you can get the images with the least amount of cloud / cloud shadow coverage, instead of all masks.

---

### 1.DEFINE PARAMETERS
This notebook processes images from CB4A-WPM-L4-DN-1 collection using BDC's (INPE) STAC service. <br>
You can change these parameters, if you wish, but other modifications may be necessary.

In [16]:
# CHANGE FREELY
minx, miny, maxx, maxy = -53.248562,-19.498364,-45.907155,-12.394973
bbox = f"{minx}, {miny}, {maxx}, {maxy}"
datetime = '2025-01-01/2025-01-31'

# THINK BEFORE YOU CHANGE
stac = "https://data.inpe.br/bdc/stac/v1/"
collection_id = "CB4A-WPM-L4-DN-1"

---

### 2. AUXILIARY FUNCTIONS
These functions are called inside the main function, but you can edit it! <br>
Remove and add indexes and operations as desired.

In [2]:
# Load libraries
import os
import io
from io import BytesIO
import ast
import zlib
import json
import requests
import pystac_client
import sqlite3
import rasterio 
import warnings
import itertools
import numpy as np
import xml.etree.ElementTree as ET
from rasterio.transform import Affine
from concurrent.futures import ProcessPoolExecutor

In [3]:
# PRE-PROCESSING FUNCTIONS 
def get_coefficient(assets):
    NS = '{http://www.gisplan.com.br/xmlsat}'
    response = requests.get(assets['BAND1_xml'].href)
    response.raise_for_status()
    root = ET.fromstring(response.content)

    image_element = root.find(f"{NS}image")
    if image_element is not None:
        calibration_element = image_element.find(f"{NS}absoluteCalibrationCoefficient")
        if calibration_element is not None:
            calibration_data = {}
            for band_data in calibration_element:
                clean_tag = band_data.tag.replace(NS, '')
                
                band_id = band_data.attrib.get('name', 'N/A')
                band_value = band_data.text.strip() if band_data.text else 'N/A'
                
                key = f"{clean_tag} {band_id}".strip() 
                
                calibration_data[key] = float(band_value)
        else:
            calibration_data = {'band 0': 0.184471, 'band 1': 0.29107, 'band 2': 0.297832, 'band 3': 0.232504, 'band 4': 0.178993}

    return calibration_data

def normalize(array):
    """Normalizes numpy arrays into scale 0.0 - 1.0"""
    array_min, array_max = array.min(), array.max()
    return ((array - array_min)/(array_max - array_min))
# --------------------------------------------------------------------------------------

# STATISTICS AND INDEXES FUNCTIONS
def estatisticas(array):
    dados_flat = array.flatten()
    dados_validos = dados_flat[dados_flat != 0]
    dados_validos = dados_validos[~np.isnan(dados_validos)]
    media = np.mean(dados_validos)
    mediana = np.median(dados_validos)
    minimo = np.min(dados_validos)
    maximo = np.max(dados_validos)
    desvio_padrao = np.std(dados_validos)
    
    q25 = np.percentile(dados_validos, 25)
    q75 = np.percentile(dados_validos, 75)
    
    stats = {'media': media,
             'mediana': mediana,
             'minimo': minimo,
             'maximo': maximo,
             'std': desvio_padrao,
             'q25': q25,
             'q75': q75}
    return stats

def get_NDVI(band3, band4, footprint):
    ndvi = (((band4 - band3)/(band4 + band3)) * (footprint))
    ndvi_stats = estatisticas(ndvi)
    ndvi_binary = (((ndvi_stats['minimo'] + ndvi_stats['std']) < ndvi) & (ndvi < (ndvi_stats['mediana'] - ndvi_stats['std']))) * footprint
    return ndvi_binary, ndvi_stats, ndvi

def get_WI(band1, band2, band3, footprint):
    m = (0.25 * band1) + (0.375 * band2) + (0.375 * band3)
    wi = abs((band1 - m) / m) + abs((band2 - m) / m) + abs((band3 - m) / m)
    wi_stats = estatisticas(wi)
    wi_binary = (wi < wi_stats['q25']) * footprint
    return wi_binary, wi_stats

def get_HOT(band1, band3, footprint):
    hot = band1 - (0.45 * band3) - 0.08
    hot_stats = estatisticas(hot)
    hot_binary = (hot > hot_stats['q75']) * footprint
    return hot_binary, hot_stats

def get_CI(band1, band2, band3, band4, footprint):
    # cloud index
    ci = ((3 * band4) / (band1 + band2 + band3)) * footprint
    ci_stats = estatisticas(ci)
    ci_binary = ci < ci_stats['q25'] 
    return ci_binary, ci_stats

def get_NSCD(band1, band3, band4, footprint):
    nscd = band1 - (0.45 * band3) - (0.16 * band4) * footprint
    nscd_stats = estatisticas(nscd)
    nscd_binary = nscd > nscd_stats['q75'] 
    return nscd_binary, nscd_stats

def get_D(band2, band4, footprint):
    band2_stats = estatisticas(band2)
    band4_stats = estatisticas(band4)
    D_binary = ((band2 < band2_stats['q25'] + band2_stats['std']) & (band4 < band4_stats['minimo'] + band4_stats['std'])) * footprint
    return D_binary, band2_stats, band4_stats

def get_W(ndvi, band4, footprint):
    ndvi_stats = estatisticas(ndvi)
    NDVI_clean =  ndvi_stats['minimo'] + (1.5*ndvi_stats['std'])
    band4_clean = ndvi_stats['q75'] + ndvi_stats['std'] 
    NDVI_turbid =   ndvi_stats['minimo'] + (2 * ndvi_stats['std'])
    band4_turbid = ndvi_stats['q25'] + ndvi_stats['std'] 
    w_clean = (ndvi < NDVI_clean) & (band4 < band4_clean)
    w_turbid = (ndvi < NDVI_turbid) & (band4 < band4_turbid)
    W_binary = (w_clean | w_turbid) * footprint
    return W_binary
#--------------------------------------------------------------------------------------

# METADATA RELATED FUNCTIONS
def compress_array_to_blob(arr):
    buffer = io.BytesIO()
    np.save(buffer, arr)
    uncompressed_blob = buffer.getvalue()

    compressed_blob = zlib.compress(uncompressed_blob)
    return compressed_blob
#--------------------------------------------------------------------------------------

---

### 3. MAIN PROCESSING FUNCTION
Here is the main function, it reads the necessary bands, calculates indexes, image footprint, cloud cover and etc., updating the image's metadata.

In [4]:
def process_item(id):
    warnings.filterwarnings('ignore', category=RuntimeWarning)
    print(f"Processing item: {id}")
    try:
        item = next(service.get_items(id))
        assets = item.assets
    except Exception as e:
        print(f"ERROR: Service access failed for item {id}: {e}")
        return {'id': id, 'corrupt': True, 'error_type': 'ServiceAccessError'}

    bands_to_read = ['BAND1', 'BAND2', 'BAND3', 'BAND4']
    item_bands = {}
    original_crs = None
    original_transform = None
    
    # READ BANDS
    print(f"Opening {id} bands")
    try:
        for name, asset in assets.items():
            if name in bands_to_read:
                with rasterio.open(assets[name].href) as src:
                    item_bands[name] = src.read(1)

                    if original_crs is None:
                        original_crs = src.crs.to_wkt()                
                        original_transform = src.transform.to_gdal()
        
        if len(item_bands) != len(bands_to_read):
            raise IOError(f"Missing one or more required bands: {bands_to_read}")

    except (rasterio.RasterioIOError, IOError, ValueError) as e:
        print(f"FATAL ERROR: Band reading failed for item {id}. Details: {e}")
        return {
            'id': id, 
            'collection_id': item.collection_id, 
            'corrupt': True, 
            'error_type': 'BandReadError',
            'error_detail': str(e)
        }
        
    print(f"Finished opening {id} bands")

    # CALIBRATE / NORMALIZE BANDS
    calibration_data = get_coefficient(assets)

    band1 = normalize(item_bands['BAND1'] * calibration_data['band 1'])
    band2 = normalize(item_bands['BAND2'] * calibration_data['band 2'])
    band3 = normalize(item_bands['BAND3'] * calibration_data['band 3'])
    band4 = normalize(item_bands['BAND4'] * calibration_data['band 4'])
    
    # GET IMAGE BORDERS / PIXELS
    band_mismatch = np.sum(((band1 != 0)^ (band2 != 0)) | ((band2 != 0) ^ (band3 != 0)) | ((band1 != 0) ^ (band3 != 0))) > 100
    border = (band1 == 0) | (band2 == 0) | (band3 == 0) | (band4 == 0) 
    footprint = ~ border
    total_pixels = np.sum(~ border)
    
    # CLOUD MASK
    print(f"Start {id} cloud mask")

    # ndvi
    ndvi_binary, ndvi_stats, ndvi = get_NDVI(band3, band4, footprint)
    #whiteness index
    wi_binary, wi_stats = get_WI(band1, band2, band3, footprint)
    # hot
    hot_binary, hot_stats = get_HOT(band1, band3, footprint)
    # cloud index
    ci_binary, ci_stats = get_CI(band1, band2, band3, band4, footprint)
    # nscd
    nscd_binary, nscd_stats = get_NSCD(band1, band3, band4, footprint)

    cloud =  ndvi_binary & wi_binary & hot_binary & ci_binary & nscd_binary
    cloud_percentage = (np.sum(cloud) / total_pixels) * 100

    # CLOUD SHADOW MASK
    print(f"Start {id} cloud shadow mask")
    # D
    D_binary, band2_stats, band4_stats = get_D(band2, band4, footprint)
    # W
    W_binary = get_W(ndvi, band4, footprint)
    W_percentage = (np.sum(W_binary) / total_pixels) * 100

    cloud_shadow = D_binary & (~ W_binary)
    cloud_shadow_percentage = (np.sum(cloud_shadow) / total_pixels) * 100

    # MERGE MASKS AND SAVE TO BLOB
    print(f"Merging masks and saving to blob")
    merged_mask = np.zeros_like(cloud, dtype=np.uint16)
    merged_mask = np.where(cloud == 1, 200, merged_mask)
    merged_mask = np.where(cloud_shadow == 1, 100, merged_mask)

    raster_blob = compress_array_to_blob(merged_mask)

    dictionary = item.properties
    dictionary['raster_blob'] = raster_blob

    # METADATA DICTIONARY UPDATE
    print(f"Updating {id} dictionary")
    dictionary['id'] = item.id
    dictionary['collection_id'] = item.collection_id
    del dictionary['eo:cloud_cover']
    dictionary['path'] = 'path'
    del dictionary['path']
    dictionary['row'] = 'row'
    del dictionary['row']
    dictionary['band_mismatch'] = band_mismatch.item()
    dictionary['cloud_percentage'] = cloud_percentage.item()
    dictionary['water_percentage'] = W_percentage.item()
    dictionary['cloud_shadow_percentage'] = cloud_shadow_percentage.item()
    transform_string = json.dumps(original_transform)
    dictionary['crs_wkt'] =  original_crs
    dictionary['transform_gdal'] = transform_string
    dictionary['corrupt'] = False

    print(f"Finished processing item: {id}")

    return dictionary

---

### 4. SETTING UP LOCAL DATABASE
Functions to create table / insert data / check if a scene has already been processed.

In [5]:
def create_table(db_name, table_name):
    conn = sqlite3.connect(db_name)
    cursor = conn.cursor()
    cursor.execute(f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            id TEXT PRIMARY KEY,
            collection_id TEXT,
            created TEXT,
            start_datetime TEXT,
            end_datetime TEXT,
            datetime TEXT,
            updated TEXT,
            crs_wkt TEXT,
            transform_gdal TEXT,
            band_mismatch BOOLEAN,
            cloud_percentage FLOAT,
            cloud_shadow_percentage FLOAT,
            water_percentage FLOAT,
            raster_blob BLOB,
            row INT,
            corrupt BOOLEAN
        );
    """)
    conn.commit()
    conn.close()

def insert_image_data(db_name: str, table_name: str, data_dict: dict):
    conn = None
    try:
        conn = sqlite3.connect(db_name)
        cursor = conn.cursor()

        quoted_columns = ', '.join(f'"{key}"' for key in data_dict.keys())
        
        placeholders = ', '.join(['?'] * len(data_dict))
        values = tuple(data_dict.values())

        sql_insert = f"INSERT INTO {table_name} ({quoted_columns}) VALUES ({placeholders})"
  
        cursor.execute(sql_insert, values)
        conn.commit()
        print(f"Successfully inserted record with ID: {data_dict.get('id')}")
    except sqlite3.Error as e:
        print(f"SQLite error occurred: {e}")
    finally:
        if conn:
            conn.close()


def check_db_for_id(db_name: str, table_name: str, image_id: str):
    """
    Check if id exists in db.
    
    Args:
        db_name: The database file name.
        table_name: The table name.
        image_id: The string ID of the record to retrieve.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_name)
        cursor = conn.cursor()

        cursor.execute(f"""
            SELECT collection_id
            FROM {table_name}
            WHERE id = ?
        """, (image_id,)) # <-- Pass the ID as a single-element tuple for safety
        
        result = cursor.fetchone()

        if not result:
            return False

        return True
    except sqlite3.Error as e:
        print(f"Database error during check: {e}")
        return False # Return False on error
    finally:
        if conn:
            conn.close()

Create table if it does not exist:

In [6]:
DB_FILE = 'cd_geo_github.db'
TABLE_NAME = 'cloud_processing_github'

create_table(DB_FILE, TABLE_NAME)

Code snippet to drop table, if you wish to start from scratch (uncomment it):

In [None]:
# connection = sqlite3.connect(DB_FILE)
# connection.execute(f'DROP TABLE {TABLE_NAME}')

<sqlite3.Cursor at 0x7fb9ad853a40>

---

### 5. RUN CODE
Fetch items that match your parameters.

In [17]:
service = pystac_client.Client.open(stac)
item_search = service.search(bbox= bbox,
                             datetime=datetime,
                             collections=[collection_id])
item_search.matched()

8

If you wish to run one by one (or you do not have a powerful computer), uncomment the code snippet below. <br>
This method also has the advantage of inserting the data after processing each item - if it fails along the way, you will still have data in you database!

In [32]:
MAX_ITEMS_TO_PROCESS = 4
count = 0
processed_data_dicts = []

for item in item_search.items():
    item_id = item.id
    count += 1

    if count >= MAX_ITEMS_TO_PROCESS:
        break 
    if check_db_for_id(DB_FILE, TABLE_NAME, item_id):
        print(f"Item {item_id} already processed. Skipping.")
        print('--------------------------------------------------------------------------')
        continue
    else:
        # print(f"Processing new item: {item_id}")
        data_dict = process_item(item_id) 
        insert_image_data(DB_FILE, TABLE_NAME, data_dict)
        print('--------------------------------------------------------------------------')

Processing item: CBERS_4A_WPM_20250129_215_128_L4
Opening CBERS_4A_WPM_20250129_215_128_L4 bands
Finished opening CBERS_4A_WPM_20250129_215_128_L4 bands
Start CBERS_4A_WPM_20250129_215_128_L4 cloud mask
Start CBERS_4A_WPM_20250129_215_128_L4 cloud shadow mask
Merging masks and saving to blob
Updating CBERS_4A_WPM_20250129_215_128_L4 dictionary
Finished processing item: CBERS_4A_WPM_20250129_215_128_L4
Successfully inserted record with ID: CBERS_4A_WPM_20250129_215_128_L4
--------------------------------------------------------------------------
--------------------------------------------------------------------------
Processing item: CBERS_4A_WPM_20250122_204_136_L4
Opening CBERS_4A_WPM_20250122_204_136_L4 bands
Finished opening CBERS_4A_WPM_20250122_204_136_L4 bands
Start CBERS_4A_WPM_20250122_204_136_L4 cloud mask
Start CBERS_4A_WPM_20250122_204_136_L4 cloud shadow mask
Merging masks and saving to blob
Updating CBERS_4A_WPM_20250122_204_136_L4 dictionary
Finished processing item: CB

If you have a better computer, you can try running multiple images in parallel, as the cell below. <br>
This is much faster, but the insertion happens only after everything is processed - if something happens along the way, you might have to reprocess everything :(

In [18]:
MAX_ITEMS_TO_PROCESS = 8 
MAX_WORKERS = 3 

def process_items(item):
    """Same function structure, but executed concurrently by threads."""
    item_id = item.id
    if check_db_for_id(DB_FILE, TABLE_NAME, item_id):
        print(f"Item {item_id} already processed. Skipping.")
        print('--------------------------------------------------------------------------')
        return None
    else:
        data_dict = process_item(item_id) 
        return data_dict
        print('--------------------------------------------------------------------------')

processed_data_dicts = []
with ProcessPoolExecutor(max_workers=MAX_WORKERS) as executor:
    limited_items = itertools.islice(item_search.items(), MAX_ITEMS_TO_PROCESS)
    results = executor.map(process_items,limited_items)
    processed_data_dicts = [d for d in results if d is not None]

print(f"--- {len(processed_data_dicts)} items to process. ---")

for data_dict in processed_data_dicts:
    insert_image_data(DB_FILE, TABLE_NAME, data_dict)

Item CBERS_4A_WPM_20250129_215_128_L4 already processed. Skipping.Item CBERS_4A_WPM_20250122_204_136_L4 already processed. Skipping.Item CBERS_4A_WPM_20250122_204_135_L4 already processed. Skipping.


------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------


Processing item: CBERS_4A_WPM_20250118_211_128_L4Processing item: CBERS_4A_WPM_20250122_204_134_L4Processing item: CBERS_4A_WPM_20250112_206_137_L4


ERROR: Service access failed for item CBERS_4A_WPM_20250122_204_134_L4: ('Connection aborted.', RemoteDisconnected('Remote end closed connection without response'))
Processing item: CBERS_4A_WPM_20250107_207_130_L4
Opening CBERS_4A_WPM_20250118_211_128_L4 bandsOpening CBERS_4A_WPM_20250107_207_130_L4 bands

Opening CBERS_4A_WPM_20250112_206_137_L4 bands
Finished opening CBERS_4A_WPM_20250112_206_137_L4 bands
Start CBERS_

---

### 6. RETRIEVE RASTER FROM DB
After processing the desired images, you can filter the ids you want by the metadata you generated! And then retrieve and transform the masks into .tiff files to apply it in QGIS or other platforms.

In [32]:
MAX_CLOUD_PERCENTAGE = 15 #Filtering by cloud percentage

conn = sqlite3.connect(DB_FILE)
cursor = conn.cursor()

query = f"SELECT * FROM cloud_processing_github WHERE cloud_percentage < ?"

cursor.execute(query, (MAX_CLOUD_PERCENTAGE,))
rows = cursor.fetchall()
id_list = [row[0] for row in rows]

conn.close()

print(len(id_list))

5


Define function to get raster from the DB

In [33]:
def get_raster_by_id(db_name: str, table_name: str, image_id: str):
    """
    Retrieves the raster array and metadata for a specific image_id.
    
    Args:
        db_name: The database file name.
        table_name: The table name.
        image_id: The string ID of the record to retrieve.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_name)
        cursor = conn.cursor()
        clean_id = image_id.strip()
        cursor.execute(f"""
            SELECT raster_blob, crs_wkt, transform_gdal 
            FROM {table_name}
            WHERE id = ?
        """, (clean_id,)) 
        
        result = cursor.fetchone()

        if not result:
            print(f"Error: No data found for ID: {image_id}.")
            return None, None, None

        raster_blob, crs_wkt, transform_gdal_str = result
        retrieved_array = np.load(BytesIO(zlib.decompress(raster_blob)))
        transform_tuple = ast.literal_eval(transform_gdal_str) 
        retrieved_transform = Affine.from_gdal(*transform_tuple) 

        return retrieved_array, crs_wkt, retrieved_transform

    except sqlite3.Error as e:
        print(f"SQLite error occurred: {e}")
        return None, None, None
    except Exception as e:
        print(f"Error during reconstruction: {e}")
        return None, None, None
    finally:
        if conn:
            conn.close()


Save masks into .tiff

In [36]:
for item in id_list:
    target_id = item
    OUTPUT_FILENAME =  os.path.join('./masks_github',f'{target_id}_mask.tif')
    raster_data, crs_wkt, transform = get_raster_by_id(DB_FILE, TABLE_NAME, target_id)

    if raster_data is not None:
        print(f"Successfully retrieved array for ID: {target_id}")
        height, width = raster_data.shape
        
        profile = {
            'driver': 'GTiff',
            'dtype': raster_data.dtype,
            'count': 1,           
            'height': height,
            'width': width,
            'crs': crs_wkt,           
            'transform': transform,  
            'nodata': 0,         
            'compress': 'lzw'    
        }

        # Open the output file for writing
        with rasterio.open(OUTPUT_FILENAME, 'w', **profile) as dst:
            dst.write(raster_data, 1)

        print(f"\nSaved georeferenced raster to {OUTPUT_FILENAME}")
        print("------------------------------------------------------------")

Successfully retrieved array for ID: CBERS_4A_WPM_20250122_204_136_L4

Saved georeferenced raster to ./masks_github/CBERS_4A_WPM_20250122_204_136_L4_mask.tif
------------------------------------------------------------
Successfully retrieved array for ID: CBERS_4A_WPM_20250122_204_135_L4

Saved georeferenced raster to ./masks_github/CBERS_4A_WPM_20250122_204_135_L4_mask.tif
------------------------------------------------------------
Successfully retrieved array for ID: CBERS_4A_WPM_20250118_211_128_L4

Saved georeferenced raster to ./masks_github/CBERS_4A_WPM_20250118_211_128_L4_mask.tif
------------------------------------------------------------
Successfully retrieved array for ID: CBERS_4A_WPM_20250112_206_137_L4

Saved georeferenced raster to ./masks_github/CBERS_4A_WPM_20250112_206_137_L4_mask.tif
------------------------------------------------------------
Successfully retrieved array for ID: CBERS_4A_WPM_20250107_207_130_L4

Saved georeferenced raster to ./masks_github/CBERS_4A

Thank you for reading!