In [32]:
import importlib
import os
import numpy as np
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import matplotlib.pyplot as plt
from matplotlib.ticker import ScalarFormatter
import pandas as pd
from tqdm import tqdm
tqdm.pandas()
import tensorflow as tf
import tensorflow_probability as tfp # TFP is a necessary import for the output distribution layer of the pre-trained BNNs
# Imports specific to this repository:
import performance_metrics as perf_met
import uncertainty_calibration_metrics as unc_met
from data_pre_processing import owt_flagging
from data_pre_processing import normalise_input
from model_functions import NLL
from model_functions import estimate_chla
from model_functions import calculate_uncertainty
import xarray as xr
from typing import Union
import math
from math import pi
import pandas as pd
import numpy as np
import pickle
import warnings
import sensor_meta_info as meta
import xarray as xr


In [61]:
#####
# 2. Load model and data
#####
# If you run into loading problems: make sure the correct Python version is used.
# Python version that corresponds to the .h5 Keras models: 3.8.15

# The working directory is assumed to be the repository root (same folder structure as in the repository).
# os.getcwd() already returns an absolute path in the platform's native format, so it is used as-is.
# Splitting it on "/" and re-joining the parts with os.path.join() drops the leading "/" on POSIX
# systems, which turns the absolute path into a relative one.
cwd = os.getcwd()
cwd_system_wide = cwd
print(cwd_system_wide)

# define sensor to apply the BNNs to (case sensitive):
sensor_name = 'OLCI_all' # any of: 'OLCI_all', 'OLCI_polymer', 'MSI_s2a', 'MSI_s2b'.

# 'OLCI_all' bands, see sensor_meta_info.py:
# '413', '443', '490','510', '560', '620', '665', '673', '681', '708', '753','778'

# Load the model. NLL is the custom loss the .h5 file refers to, so it has to be passed explicitly.
bnn_model_path = os.path.join(cwd_system_wide, 'bnns', 'BNN_' + sensor_name + '.h5')
bnn_sensor_model = tf.keras.models.load_model(bnn_model_path, custom_objects = {'NLL': NLL})
print(sensor_name + ' BNN model loaded.')

c:\github_repos\eawag\BNN_2022_satellite
OLCI_all BNN model loaded.


In [62]:
# Load example IN SITU data:
try:
    df_input = pd.read_csv(cwd_system_wide+'/data/example_data.csv')
    print('Example data loaded.')
except (OSError, ValueError) as err:
    # OSError covers a missing/unreadable file, ValueError covers pandas parsing and decoding errors.
    # Anything else (e.g. KeyboardInterrupt) is no longer swallowed.
    print('Error: Failed to load the example data csv file!')
    print(f'Reason: {err!r}')


Example data loaded.


In [63]:
# Load optical water types of Spyrakos et al. (2018):
try:
    owts_iw = pd.read_csv(cwd_system_wide+'/data/spyrakos_owts_inland_waters_standardised.csv')
    print('OWTs inland water loaded.')
except (OSError, ValueError) as err:
    # OSError covers a missing/unreadable file, ValueError covers pandas parsing and decoding errors.
    # Anything else (e.g. KeyboardInterrupt) is no longer swallowed.
    print('Error: Failed to load the OWT inland water dataset!')
    print(f'Reason: {err!r}')

####
# 3. OWT flagging - generates OWT flag (0 = inside application scope, 1 = outside application scope)
####

# Re-import the repository modules so that edits made to them on disk are picked up without
# restarting the kernel.
import data_pre_processing
import sensor_meta_info
importlib.reload(data_pre_processing)
importlib.reload(sensor_meta_info)

OWTs inland water loaded.


<module 'sensor_meta_info' from 'c:\\github_repos\\eawag\\BNN_2022_satellite\\sensor_meta_info.py'>

In [7]:
# Run the repository implementation of the OWT flagging on the in situ example data.
sensor_name = 'OLCI_all'
input_data_owts = data_pre_processing.owt_flagging(
    owt_rrs=owts_iw,
    input_dataset=df_input,
    sensor=sensor_name,
    data_type='in_situ',
)

OWT flagging complete. New OWT columns "owt_class" and "owt_flag" added to input dataframe.


In [74]:
# Sensor band configurations. Currently the BNNs are available for OLCI, MSI S2A/B and at some
# point will be Landsat-8 OLI.
#
# All OLCI/MSI entries are derived from the wavelength lists below, so the in situ names, the
# atmospheric-correction specific names and the scaler names cannot drift apart.
_olci_wavelengths = ['413', '443', '490', '510', '560', '620', '665', '673', '681', '708', '753', '778']
_olci_wavelengths_polymer = [wl for wl in _olci_wavelengths if wl != '673']  # does not include 673 nm
_olci_c2rcc_band_numbers = [2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 16]
_msi_s2a_wavelengths = ['443', '492', '560', '665', '704', '740', '783']
_msi_s2b_wavelengths = ['443', '492', '560', '665', '704', '739', '780']

# Band (column / variable) names per sensor and atmospheric correction.
sensor_bands = {
    'OLCI_all': list(_olci_wavelengths),
    'OLCI_c2rcc': ['rhow_' + str(number) for number in _olci_c2rcc_band_numbers],
    'OLCI_c2rcc_rrs': ['Rrs_' + str(number) for number in _olci_c2rcc_band_numbers],
    'OLCI_polymer_insitu': list(_olci_wavelengths_polymer),
    'OLCI_polymer_rw': ['Rw' + wl for wl in _olci_wavelengths_polymer],
    'MSI_s2a': list(_msi_s2a_wavelengths),
    'MSI_s2b': list(_msi_s2b_wavelengths),
    'L8_oli': ['443', '482', '561', '655'],
}

# Column names used by the input scalers, per sensor.
scaler_bands = {
    'OLCI_all': [wl + '_res_s3b' for wl in _olci_wavelengths],
    'OLCI_polymer': [wl + '_res_s3b' for wl in _olci_wavelengths_polymer],
    'MSI_s2a': [wl + '_res_s2a' for wl in _msi_s2a_wavelengths],
    'MSI_s2b': [wl + '_res_s2b' for wl in _msi_s2b_wavelengths],
}

def get_sensor_config_ipynb(sensor_name):
    """
    Look up the band names configured for a sensor in the notebook-level `sensor_bands` dictionary.

    Args:
        sensor_name (str): A key of `sensor_bands` (case sensitive), e.g. 'OLCI_all', 'OLCI_c2rcc',
            'OLCI_c2rcc_rrs', 'OLCI_polymer_rw', 'OLCI_polymer_insitu', 'MSI_s2a', 'MSI_s2b'.

    Returns:
        list: The band names of the sensor. This is the list object stored in `sensor_bands`,
        not a copy, so callers should not modify it in place.

    Raises:
        KeyError: If sensor_name is not a key of `sensor_bands`.
    """
    try:
        sensor_config = sensor_bands[sensor_name]
    except KeyError:
        # Same exception type as the plain dictionary lookup, but with the valid options listed.
        raise KeyError(
            f"Unknown sensor_name {sensor_name!r}. Expected one of: {sorted(sensor_bands)}"
        ) from None

    return sensor_config

def calculate_rrs_ipynb(dataset, sensor_config):
    """
    Calculate remote sensing reflectance if necessary, e.g. for C2RCC or POLYMER.

    Every band listed in sensor_config is divided by pi and stored in a copy of the dataset
    under the name 'Rrs_<id>'. The <id> is the part after the first underscore of the band
    name ('rhow_2' -> 'Rrs_2'). Band names without an underscore, such as the POLYMER names
    ('Rw413'), have their leading letters stripped instead ('Rw413' -> 'Rrs_413').

    Args:
        dataset (xr.Dataset or pd.DataFrame): The dataset to modify. It is not changed in place.
        sensor_config (list): The band names to convert.

    Returns:
        xr.Dataset or pd.DataFrame: A copy of the dataset with the additional 'Rrs_<id>' bands.
    """
    import re  # function-scope import: `re` is not imported at notebook level

    modified_dataset = dataset.copy()
    for band in sensor_config:
        if '_' in band:
            band_id = band.split('_')[1]
        else:
            # No underscore to split on (e.g. 'Rw413'): keep what follows the leading letters,
            # or the whole name if it consists of letters only.
            band_id = re.sub(r'^[A-Za-z]+', '', band) or band
        new_band_name = 'Rrs_' + band_id
        modified_dataset[new_band_name] = dataset[band] / np.pi
        print('Band ' + band + ' was divided by pi and stored as ' + new_band_name + '.')

    return modified_dataset

In [126]:
def standardise_rrs_ipynb(rrs_spectrum, data_type: str):
    """
    Standardisation as in Spyrakos et al. (2018) used in OWT flagging.

    The spectrum is divided by its "area", i.e. the sum of its values across all bands.

    Args:
        rrs_spectrum (pandas.Series or xarray.DataArray): For 'in_situ', a one-dimensional pandas
            Series holding one spectrum. For 'satellite', an xarray DataArray with a 'band'
            dimension (any further dimensions, e.g. a 'pixel' dimension, are kept).
        data_type (str): A string specifying the type of rrs_spectrum. Either 'in_situ' or 'satellite'.

    Returns:
        pandas.Series or xarray.DataArray: The standardised Rrs spectrum. If rrs_spectrum contains
        any NaN value it is returned unchanged.

    Raises:
        ValueError: If data_type is not one of 'in_situ' or 'satellite'.
    """

    if data_type not in ['in_situ', 'satellite']:
        raise ValueError("Invalid data_type. Expected one of: 'in_situ', 'satellite'")

    # Spectra with missing bands cannot be standardised; hand them back untouched.
    if np.isnan(rrs_spectrum).any():
        return rrs_spectrum

    if data_type == 'in_situ':
        # Area of the Rrs spectrum: sum over all bands of the one-dimensional spectrum
        rrs_area = sum(rrs_spectrum)
    else:
        # Area of the Rrs spectrum: sum over the spectral 'band' dimension.
        # Summing over 'pixel' instead does not give the area of a spectrum: for a group that
        # holds a single pixel it just reproduces every band value, so that each standardised
        # band would become 1.
        rrs_area = rrs_spectrum.sum(dim='band')

    # Standardise the Rrs spectrum by dividing each element by the Rrs area
    standardised_rrs = rrs_spectrum / rrs_area

    return standardised_rrs


In [78]:
def cal_sa_ipynb(owt_rrs: np.ndarray, rrs_spectrum, data_type: str) -> float:
    """
    Calculate the spectral angle (SA) between an Rrs spectrum and an OWT.

    The result is 1 - alpha / pi, where alpha is the angle between the two spectra. It is 1 for
    spectra pointing in the same direction, 0.5 for orthogonal and 0 for opposite spectra.

    Args:
        owt_rrs (np.ndarray): A one-dimensional array containing the Rrs spectrum of one OWT.
        rrs_spectrum (pandas.Series or np.ndarray): A one-dimensional spectrum with the same number
            of bands, in the same order, as owt_rrs.
        data_type (str): A string specifying the type of rrs_spectrum. Either 'in_situ' or 'satellite'.
            Both types are computed identically; the argument is only validated.

    Returns:
        float: The spectral angle between the given Rrs spectra. NaN if either input contains NaN
        or has zero length (norm).

    Raises:
        ValueError: If data_type is not one of 'in_situ' or 'satellite'.
    """

    if data_type not in ['in_situ', 'satellite']:
        raise ValueError("Invalid data_type. Expected one of: 'in_situ', 'satellite'")

    spectrum = np.asarray(rrs_spectrum, dtype=float)
    owt = np.asarray(owt_rrs, dtype=float)

    cos_alfa = np.dot(spectrum, owt) / (np.linalg.norm(spectrum) * np.linalg.norm(owt))
    # Floating point rounding can push the cosine of (near-)parallel spectra marginally outside
    # [-1, 1], for which arccos returns NaN. Clipping keeps such spectra at an angle of 0 (or pi).
    # NaN inputs stay NaN.
    alfa = np.arccos(np.clip(cos_alfa, -1.0, 1.0))

    sa_owt = 1 - alfa / np.pi

    return sa_owt

In [None]:
def calc_max_owt_ipynb(rrs_spectrum, owt_array: np.ndarray, data_type: str) -> int:
    """
    Determine the OWT class with the highest spectral angle (SA) score for an Rrs spectrum.

    Args:
        rrs_spectrum (pandas.Series, np.ndarray or xarray.DataArray): For 'in_situ', a single
            one-dimensional spectrum. For 'satellite', an xarray DataArray with a 'band' dimension;
            every other dimension (e.g. 'lat', 'lon') is looped over.
        owt_array (np.ndarray): A two-dimensional array with one OWT spectrum per row.
        data_type (str): A string specifying the type of rrs_spectrum. Either 'in_situ' or 'satellite'.

    Returns:
        int or xarray.DataArray: The 1-based OWT class (row of owt_array + 1) with the highest SA.
        A single int for 'in_situ'; an integer DataArray without the 'band' dimension for
        'satellite'. 0 is returned where no SA could be calculated (all SA values are NaN, e.g.
        for a masked pixel); without that guard argmax would report such spectra as class 1.

    Raises:
        ValueError: If data_type is not one of 'in_situ' or 'satellite'.
    """

    if data_type not in ['in_situ', 'satellite']:
        raise ValueError("Invalid data_type. Expected one of: 'in_situ', 'satellite'")

    def _best_matching_owt(spectrum):
        # SA between the spectrum and every OWT (one OWT per row of owt_array)
        all_sa = np.apply_along_axis(cal_sa_ipynb, 1, owt_array, rrs_spectrum=spectrum, data_type=data_type)
        if np.isnan(all_sa).all():
            return 0
        # + 1 because python starts counting at 0
        return np.argmax(all_sa) + 1

    if data_type == 'in_situ':
        return _best_matching_owt(rrs_spectrum)

    # 'satellite': a DataArray has no .apply(); apply_ufunc hands the 1-D 'band' vector of every
    # pixel to the function and collects the results in a DataArray of the remaining dimensions.
    owt_class = xr.apply_ufunc(
        _best_matching_owt,
        rrs_spectrum,
        input_core_dims=[['band']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[int],
    )

    return owt_class

In [33]:
def _owt_flagging_in_situ_ipynb(owt_rrs, input_dataset, sensor):
    """In situ part of owt_flagging_ipynb: one Rrs spectrum per row of a pandas DataFrame."""
    # These are the valid OWTs of the manuscript the BNNs were designed for:
    bnn_owts = [2, 3, 4, 5, 9]

    band_names = get_sensor_config_ipynb(sensor)
    spectra = input_dataset[band_names]

    # There should be no NaN values in any Rrs band in an in situ data file
    if spectra.isnull().values.any():
        raise ValueError('The input dataset contains NaN values in one or more of the sensor columns specified in sensor_meta_info.py. Please treat/remove the corresponding observation.')

    # Standardise the values in each row. This is only done for the OWT calculation, the
    # columns of input_dataset keep their original values.
    standardised = spectra.apply(lambda row: standardise_rrs_ipynb(row, 'in_situ'), axis=1)

    # OWT spectra for the same bands (the leading 'wl' column is dropped again)
    owt_array = owt_rrs[['wl'] + band_names].iloc[:, 1:].values

    if len(standardised) == 0:
        # np.apply_along_axis cannot iterate over zero rows
        max_owt_classes = np.empty(0, dtype=int)
    else:
        max_owt_classes = np.apply_along_axis(
            lambda spectrum: calc_max_owt_ipynb(spectrum, owt_array=owt_array, data_type='in_situ'),
            1,
            standardised.values,
        )

    # 0 = inside application scope, 1 = outside application scope
    owt_flag = 1 - np.isin(max_owt_classes, bnn_owts).astype(int)

    input_dataset['owt_class'] = max_owt_classes
    input_dataset['owt_flag'] = owt_flag

    print('OWT flagging complete. New OWT columns "owt_class" and "owt_flag" added to input dataframe.')
    return input_dataset


def _owt_flagging_satellite_ipynb(owt_rrs, input_dataset, sensor):
    """Satellite part of owt_flagging_ipynb: one Rrs spectrum per pixel of an xarray Dataset."""
    # These are the valid OWTs of the manuscript the BNNs were designed for:
    bnn_owts = [2, 3, 4, 5, 9]

    band_names = get_sensor_config_ipynb(sensor)

    # Combine the band variables into one array with a new leading 'band' dimension
    spectra = xr.concat([input_dataset[band] for band in band_names], dim='band')

    # Standardise every pixel by the sum across its bands. skipna=False turns the sum, and with
    # it the whole standardised spectrum, into NaN as soon as one band of the pixel is NaN, so
    # invalid pixels are carried along (not dropped) and can be identified below.
    standardised = spectra / spectra.sum(dim='band', skipna=False)
    print('Standardising all valid Rrs spectra.')
    print('Standardisation complete.')

    # The OWT columns are named by wavelength and are independent of the atmospheric correction,
    # so map the sensor to the configuration that holds the matching wavelength names.
    if 'polymer' in sensor:
        owt_sensor = 'OLCI_polymer_insitu'  # POLYMER has no 673 nm band
    elif 'OLCI' in sensor:
        owt_sensor = 'OLCI_all'
    else:
        owt_sensor = sensor
    owt_band_names = get_sensor_config_ipynb(owt_sensor)
    if len(owt_band_names) != len(band_names):
        raise ValueError(
            f"Sensor {sensor!r} has {len(band_names)} bands but the OWT configuration "
            f"{owt_sensor!r} has {len(owt_band_names)}; the spectra cannot be compared."
        )

    # The OWT dataframe is already standardised (the leading 'wl' column is dropped again)
    owt_array = owt_rrs[['wl'] + owt_band_names].iloc[:, 1:].values

    # A DataArray has no .apply(); apply_ufunc passes the 1-D 'band' vector of every pixel to
    # calc_max_owt_ipynb. data_type='in_situ' selects its single-spectrum code path.
    max_owt_classes = xr.apply_ufunc(
        calc_max_owt_ipynb,
        standardised,
        kwargs={'owt_array': owt_array, 'data_type': 'in_situ'},
        input_core_dims=[['band']],
        vectorize=True,
        dask='parallelized',
        output_dtypes=[int],
    )

    # Pixels without a complete, finite spectrum cannot be classified: class 0 instead of
    # whatever argmax returns for an all-NaN score vector.
    valid_pixels = np.isfinite(standardised).all(dim='band')
    max_owt_classes = max_owt_classes.where(valid_pixels, 0).astype(int)

    # 0 = inside application scope, 1 = outside application scope.
    # DataArray.isin keeps dimensions and coordinates (np.isin would return a bare numpy array).
    owt_flag = xr.where(max_owt_classes.isin(bnn_owts), 0, 1)

    input_dataset['owt_class'] = max_owt_classes
    input_dataset['owt_flag'] = owt_flag

    print('OWT flagging complete. New OWT variables "owt_class" and "owt_flag" added to input dataset.')
    return input_dataset


def owt_flagging_ipynb(owt_rrs: pd.DataFrame, input_dataset, sensor: str, data_type: str) -> Union[pd.DataFrame, xr.Dataset]:
    """
    Calculates an OWT per observation and the OWT quality flag (inside (0) and outside (1) of the application scope, respectively).

    Args:
        owt_rrs (pandas.DataFrame): A data frame containing the mean standardised OWTs of Spyrakos et al. (2018),
            with a 'wl' column and one column per wavelength.
        input_dataset (pandas.DataFrame or xr.Dataset): A data frame or xarray dataset to produce the OWTs for.
            'owt_class' and 'owt_flag' are added to this object in place.
        sensor (str): A sensor configuration, i.e. a key of the notebook-level `sensor_bands` dictionary.
        data_type (str): A string specifying the type of input_dataset. Either 'in_situ' or 'satellite'.

    Returns:
        pandas.DataFrame or xr.Dataset: If 'in_situ' data, returns the input pandas DataFrame with 'owt_class' and
        'owt_flag' columns. If 'satellite' data, returns the input xarray Dataset with 'owt_class' and 'owt_flag'
        variables; pixels without a complete spectrum get 'owt_class' 0 and 'owt_flag' 1. The Rrs bands of the
        return are not standardised, this is just undertaken for the calculation of the OWTs.

    Raises:
        ValueError: If 'in_situ' data contains NaN values in the input sensor columns.
        ValueError: If the number of sensor bands does not match the number of OWT bands ('satellite').
        ValueError: If data_type is not one of 'in_situ' or 'satellite'.
        KeyError: If sensor is not a key of `sensor_bands`.

    """

    if data_type not in ['in_situ', 'satellite']:
        raise ValueError("Invalid data_type. Expected one of: 'in_situ', 'satellite'")

    if data_type == 'in_situ':
        return _owt_flagging_in_situ_ipynb(owt_rrs, input_dataset, sensor)

    return _owt_flagging_satellite_ipynb(owt_rrs, input_dataset, sensor)

The cells below run the satellite branch of the OWT flagging one step at a time in order to debug it.

In [129]:
# Read the example satellite scene with xarray
sat_file = cwd_system_wide+'/data/subset_0_of_Mosaic_L2C2RCC_NA_S3A_OL_1_EFR_20200806__NT.nc'
df_sat = xr.open_dataset(sat_file)

# Keep only the pixels whose IDEPIX pixel_classif_flags value equals -32768
# (the BNN chl can also be processed for all pixels or for a pixel set of your own choice).
# Pixels that fail the mask are set to NaN; drop=True trims the coordinates that hold no selected pixel.
mask = df_sat['pixel_classif_flags'] == -32768
filtered_sat = df_sat.where(mask, drop=True)

# Original sensor + atmospheric correction of the scene
sensor_name = 'OLCI_c2rcc'

# Convert the bands of that configuration to Rrs; the result is a modified dataset with new Rrs bands
olci_c2rcc_bands = sensor_meta_info.get_sensor_config(sensor_name)
filtered_sat_rrs = sensor_meta_info.calculate_rrs(filtered_sat, olci_c2rcc_bands)

Band rhow_2 was divided by pi and stored as Rrs_2.
Band rhow_3 was divided by pi and stored as Rrs_3.
Band rhow_4 was divided by pi and stored as Rrs_4.
Band rhow_5 was divided by pi and stored as Rrs_5.
Band rhow_6 was divided by pi and stored as Rrs_6.
Band rhow_7 was divided by pi and stored as Rrs_7.
Band rhow_8 was divided by pi and stored as Rrs_8.
Band rhow_9 was divided by pi and stored as Rrs_9.
Band rhow_10 was divided by pi and stored as Rrs_10.
Band rhow_11 was divided by pi and stored as Rrs_11.
Band rhow_12 was divided by pi and stored as Rrs_12.
Band rhow_16 was divided by pi and stored as Rrs_16.


In [130]:
# The dataset processed in the following cells is the Rrs version of the filtered scene
input_dataset=filtered_sat_rrs

# Band names of the Rrs variables that were added to that dataset
sensor_rrs ='OLCI_c2rcc_rrs'
current_sensor_bands = get_sensor_config_ipynb(sensor_rrs)

In [132]:
# Re-read the example satellite scene and rebuild the filtered dataset
sat_file = cwd_system_wide+'/data/subset_0_of_Mosaic_L2C2RCC_NA_S3A_OL_1_EFR_20200806__NT.nc'
df_sat = xr.open_dataset(sat_file)

# Keep only the pixels whose IDEPIX pixel_classif_flags value equals -32768
# (the BNN chl can also be processed for all pixels or for a pixel set of your own choice)
mask = df_sat['pixel_classif_flags'] == -32768
filtered_sat = df_sat.where(mask, drop=True)


In [138]:
# Share of all pixels in the unfiltered scene that carry the selected flag value
classif_flags = df_sat['pixel_classif_flags']
total_pixels = classif_flags.size
flag_matches = classif_flags.values == -32768
flagged_pixels = flag_matches.sum()
percentage_flagged = (flagged_pixels / total_pixels) * 100
print(f"Percentage of pixels with flag -32768: {percentage_flagged}%")


Percentage of pixels with flag -32768: 11.764000659383477%


In [139]:
# count() only counts non-NaN entries, i.e. the pixels that passed the mask
filtered_flag_count = filtered_sat['pixel_classif_flags'].count()
filtered_pixels = filtered_flag_count.values
print(f"Number of pixels in filtered_sat: {filtered_pixels}")


Number of pixels in filtered_sat: 7850


In [136]:
# Compare the grid size of one band before and after filtering
original_shape = df_sat['rhow_10'].shape
filtered_shape = filtered_sat['rhow_10'].shape
print("Original shape:", original_shape)
print("Filtered shape:", filtered_shape)


Original shape: (177, 377)
Filtered shape: (150, 290)


In [68]:
# Create a list of data arrays for each band in the sensor configuration
list_of_data_arrays = [input_dataset[band] for band in current_sensor_bands]
# Combine the list of data arrays into a new data array with a leading 'band' dimension
data_array_combined = xr.concat(list_of_data_arrays, dim='band')
# Standardise every pixel by the sum across its bands, for all pixels at once instead of one
# Python call per pixel. skipna=False turns the sum, and with it the whole standardised
# spectrum, into NaN as soon as one band of the pixel is NaN, so masked pixels stay NaN and
# the (band, lat, lon) structure is preserved.
rrs_area = data_array_combined.sum(dim='band', skipna=False)
input_dataset_standardised = data_array_combined / rrs_area
print('Standardising all valid Rrs spectra.')

Standardising all valid Rrs spectra.


In [69]:
# Inspect the shape of the standardised array (the 'band' dimension comes first)
input_dataset_standardised.shape

(12, 150, 290)

In [70]:
# Store every standardised band in the dataset under 'standardised_<band>'; the original
# bands keep their values. These variables are used later on in the OWT step.
for band_index, band in enumerate(current_sensor_bands):
    input_dataset['standardised_' + band] = input_dataset_standardised.isel(band=band_index)

'Standardisation complete. New standardized bands added.'

'Standardisation complete. New standardized bands added.'

In [71]:
# Display the dataset to check the variables it currently holds
input_dataset

In [72]:
# Load optical water types of Spyrakos et al. (2018):
try:
    owts_rrs = pd.read_csv(cwd_system_wide+'/data/spyrakos_owts_inland_waters_standardised.csv')
    print('OWTs inland water loaded.')
except (OSError, ValueError) as err:
    # OSError covers a missing/unreadable file, ValueError covers pandas parsing and decoding errors.
    # Anything else (e.g. KeyboardInterrupt) is no longer swallowed.
    print('Error: Failed to load the OWT inland water dataset!')
    print(f'Reason: {err!r}')

####
# 2. OWT flagging - generates OWT flag (0 = inside application scope, 1 = outside application scope)
####

OWTs inland water loaded.


In [75]:
# The OWT columns are named by wavelength and do not depend on the atmospheric correction:
# every OLCI configuration (e.g. OLCI_c2rcc) uses the 'OLCI_all' columns, any other sensor its
# own configuration. Assigning `sensor` in both cases avoids a NameError for non-OLCI sensors.
sensor = 'OLCI_all' if 'OLCI' in sensor_rrs else sensor_rrs

# Fetch the sensor configuration based on the updated 'sensor'.
# Note: from here on current_sensor_bands holds OWT wavelength names, no longer the Rrs band names.
current_sensor_bands = get_sensor_config_ipynb(sensor)

owt_cols = ['wl'] + current_sensor_bands
# Select the specified columns from the owt_rrs data frame
owt_sel = owts_rrs[owt_cols]
# Drop the leading 'wl' column again: one OWT spectrum per row
owt_array = owt_sel.iloc[:, 1:].values

In [80]:
# Display the dataset to check the variables it currently holds
input_dataset

In [92]:
# Display the standardised spectra that are passed to the OWT classification below
input_dataset_standardised

In [81]:
def calc_max_owt_ufunc(rrs_spectrum, owt_array, data_type):
    """
    Return the 1-based OWT class with the highest spectral angle for one 1-D spectrum.

    Written for use with xr.apply_ufunc, which passes the 'band' vector of a single pixel.

    Args:
        rrs_spectrum (np.ndarray): One-dimensional spectrum of a single pixel.
        owt_array (np.ndarray): Two-dimensional array with one OWT spectrum per row.
        data_type (str): Passed on to cal_sa_ipynb, which validates it ('in_situ' or 'satellite').

    Returns:
        int: Row of owt_array with the highest spectral angle, + 1.
    """
    # No branch on data_type here: the calculation is the same for every valid value, and
    # returning None for anything but 'satellite' would only fail later inside apply_ufunc.
    all_sa = np.apply_along_axis(cal_sa_ipynb, 1, owt_array, rrs_spectrum=rrs_spectrum, data_type=data_type)
    owt_class = np.argmax(all_sa) + 1  # +1 because Python starts counting at 0
    return owt_class

# Keyword arguments that are passed unchanged to every call of calc_max_owt_ufunc
owt_ufunc_kwargs = {'owt_array': owt_array, 'data_type': 'satellite'}

# Classify every pixel: the function receives the 'band' vector of one pixel at a time.
# The input has to be input_dataset_standardised, because it carries the 'band' dimension
# with the 12 Rrs bands.
result = xr.apply_ufunc(
    calc_max_owt_ufunc,
    input_dataset_standardised,
    kwargs=owt_ufunc_kwargs,
    input_core_dims=[['band']],  # dimension handed to the function
    vectorize=True,  # loop over all remaining dimensions ('lat', 'lon')
    dask='parallelized',  # only relevant for Dask-backed arrays
    output_dtypes=[int],
)

In [90]:
# These are the valid OWTs of the manuscript the BNNs were designed for:
bnn_owts = [2, 3, 4, 5, 9]

# For each maximum OWT membership, determine the corresponding OWT flag
# (0 = inside application scope, 1 = outside application scope).
# DataArray.isin keeps the dimensions of 'result'; np.isin would return a bare numpy array.
owt_flag = xr.where(result.isin(bnn_owts), 0, 1)

# Take the dimension names from 'result' itself, so that names and order always match the
# data that is assigned (the dataset may list its dimensions in a different order).
dims = list(result.dims)

# Use .data to get the underlying array from 'result' and 'owt_flag'
input_dataset['owt_class'] = (dims, result.data)
input_dataset['owt_flag'] = (dims, owt_flag.data)

print('OWT flagging complete. New OWT variables "owt_class" and "owt_flag" added to input dataset.')



OWT flagging complete. New OWT variables "owt_class" and "owt_flag" added to input dataset.
