# Uncertainty Calibration of Passive Microwave Brightness Temperatures Predicted by Bayesian Deep Learning Models

### Pedro Ortiz<sup>a</sup>, Eleanor Casas<sup>a,*</sup>, Marko Orescanin<sup>a</sup>, Scott W. Powell<sup>a</sup>, Veljko Petkovic<sup>b</sup>, Micky Hall<sup>a</sup>

<sup>a</sup> Naval Postgraduate School, Monterey, CA, 93943 USA

<sup>b</sup> ESSIC, CISESS, University of Maryland, College Park, MD 20740 USA


## Journal Article Summary:

This notebook is a companion to the study ["Uncertainty Calibration of Passive Microwave Brightness Temperatures Predicted by Bayesian Deep Learning Models" (Ortiz et al. 2023)]() (note: add DOI link when available), which is published in the journal [Artificial Intelligence for the Earth Systems (AIES)](https://www.ametsoc.org/index.cfm/ams/publications/journals/artificial-intelligence-for-the-earth-systems/) and has a corresponding poster that was presented at both the [Dec 2022 AGU](https://agu.confex.com/agu/fm22/meetingapp.cgi/Paper/1155040) and [Jan 2023 AMS](https://ams.confex.com/ams/103ANNUAL/meetingapp.cgi/Paper/411304) Annual Meetings by Eleanor Casas. This study sought to use artificial intelligence/machine learning (AI/ML) to produce synthetic, ocean-only, full-disk Global Precipitation Measurement (GPM) Microwave Imager (GMI) microwave brightness temperatures and uncertainties from GOES-16 Advanced Baseline Imager (ABI) infrared brightness temperatures to "fill in the gaps" resulting from GMI's low-Earth orbit. It had the following research objectives and findings:

1. **Quantify errors in predicted synthetic passive microwave brightness temperatures using a deterministic model trained on a dataset of limited size**
    1. With models trained on a minimal dataset covering just approximately 10% of January 2020, results indicated that synthetic brightness temperatures at higher GMI frequencies generally have lower mean error than those at lower GMI frequencies, with deterministic models achieving mean absolute error (MAE) as low as 1.72 K for predictions of the highest-frequency GMI channel, centered at 183±3 GHz.

2. **Ascertain whether predictive skill is sacrificed compared to a deterministic model when using Bayesian Deep Learning to quantify variance (a metric of uncertainty)**
    1. This study found that model predictive skill differences between [Deterministic](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2D), [Bayesian Monte Carlo (MC) Dropout](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Dropout), [Bayesian Reparameterization](https://www.tensorflow.org/probability/api_docs/python/tfp/layers/Convolution2DReparameterization), and [Bayesian Flipout](https://www.tensorflow.org/probability/api_docs/python/tfp/layers/Convolution2DFlipout) models were negligible, indicating that for this regression task, model skill is not sacrificed when increasing complexity from predicting only microwave brightness temperature to predicting both brightness temperature and total variance at each pixel.

3. **Explore how the choice of Bayesian architecture impacts predictive skill and interpretation by focusing on the calibration of predictive error and uncertainty**
    1. While differences in predictive skill between model architectures were negligible in terms of error metrics, this study finds that the "calibration" between model-predicted variance and error differs between Bayesian architectures. "Well-calibrated uncertainty" is defined as a positive, monotonic relationship between the mean absolute error and the percent of predictions retained, which is an important relationship to have when using predicted uncertainty as a proxy for error in downstream remote sensing applications. This study finds that the Flipout model architecture has the most robust calibration between error and uncertainty for predictions across all GMI channels.


<sup>*</sup> Corresponding notebook author address: eleanor.casas@millersville.edu

## Notebook Objectives:

This notebook demonstrates how to:

1. Load and use a pre-trained MC Dropout model architecture,
2. Generate predictions of synthetic 183±3 GHz microwave brightness temperatures and their corresponding uncertainties, 
3. Plot the spatial output as in the first row of Fig. 7 in the manuscript, and
4. Plot a sample calibration curve from the output. 

Before beginning, we suggest that you download the provided model checkpoint and data, and then place the paths to your downloads in the lines that contain the phrase "#Input Path Here".


### 1. Package settings and user specifications

In [None]:
#import all packages/colab settings
import sys
import os
import numpy as np

import tensorflow as tf
from tensorflow.keras.models import load_model
import tensorflow_probability as tfp #required for Flipout and Reparameterization models

import matplotlib as mpl
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid
import mpl_toolkits.axes_grid1.axes_size as Size
from mpl_toolkits.axes_grid1.axes_divider import HBoxDivider

import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

from sklearn.metrics import mean_absolute_error #only metric used in this notebook


In [None]:
#Confirm that settings are adequate
#The tensorflow version should be at least version 2.4
#GPUs are technically not required to reproduce results, but will greatly speed up predictions

print("Tensorflow Version: " + tf.__version__)
print("GPUs available: " + str(len(tf.config.list_physical_devices('GPU'))))
print("CPUs available: " + str(os.cpu_count()))

# grab number of gpus automatically
gpus = len(tf.config.list_physical_devices('GPU'))

In [None]:
#User specifications of prediction hyperparameters

#change batch size in accordance with your computer's memory capacity; 
#larger batches should process a little faster
batch_size=1024

#Number of predictions used to compute variance (see note in section 4); 
#not suggested to go below 30 so that you have a sufficiently large sample size
num_predictions=30

### 2. Importing Data

The provided data is the pre-colocated ABI/GMI data used to produce Fig. 7 in the manuscript. The important variables in each file are:

1. "patchlo": 39x39 patches of ABI near-infrared and infrared brightness temperatures centered on the corresponding GMI pixel
2. "gmi_pixels": GMI pixels (at each GMI frequency) that are collocated with the central pixel of ABI patches
3. Additional metadata that are not used for generating predictions but are used for making plots (lat, lon, other unused fields)
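For readers unfamiliar with patch-based inputs, the following is a minimal sketch of what "a 39x39 patch centered on a pixel" means for a multi-channel image. It is not the collocation procedure used to build the provided data: `extract_patch` is a hypothetical helper, and the scene is random numbers standing in for 10 ABI channels (10 is the depth that this notebook's parser assumes).

```python
import numpy as np

def extract_patch(image, row, col, size=39):
    """Return the size x size patch of `image` centered on (row, col).

    `image` has shape (height, width, channels). Returns None when the
    patch would extend past the image edge.
    """
    half = size // 2
    top, left = row - half, col - half
    bottom, right = row + half + 1, col + half + 1
    if top < 0 or left < 0 or bottom > image.shape[0] or right > image.shape[1]:
        return None
    return image[top:bottom, left:right, :]

# Toy stand-in for a 10-channel scene (random values, not real ABI data)
rng = np.random.default_rng(0)
scene = rng.normal(size=(100, 120, 10)).astype(np.float32)

patch = extract_patch(scene, row=50, col=60)
print(patch.shape)  # (39, 39, 10)
```

The central element of the patch, `patch[19, 19]`, is the pixel at which the patch was requested, which is the role the collocated GMI pixel plays in the provided data.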

In [None]:
#definitions used for making the dataset

AUTOTUNE = tf.data.experimental.AUTOTUNE

class Data(object):
    def __init__(self, batch_size, epochs, ch_name=""):
        #store the settings used when building datasets
        self.batch_size = batch_size
        self.epochs = epochs
        self.ch_name = ch_name

    def _parse_batch_predictions(self, record_batch):
        #parses one serialized example (the map is applied before batching)
        # Create a description of the features
        feature_description = {
            'lon': tf.io.FixedLenFeature([], tf.float32),
            'lat': tf.io.FixedLenFeature([], tf.float32),
            'gmi_pixel': tf.io.FixedLenFeature([], tf.float32),
            'abi_patch': tf.io.FixedLenFeature([], tf.string)
        }

        example = tf.io.parse_single_example(record_batch, feature_description)

        #Note: These dimensions are normally stored within the tfrecords so that patches of arbitrary (or forgotten) dimensions can be used, but they are hard-coded here because of storage limits
        depth = 10
        height = 39
        width = 39

        label = example['gmi_pixel']

        #the ABI patch is stored as a serialized tensor, so decode and reshape it
        features = example['abi_patch']
        features = tf.io.parse_tensor(features, out_type=tf.float32)
        features = tf.reshape(features, shape=[height, width, depth])

        latlo = example['lat']

        lonlo = example['lon']

        return (features, label, latlo, lonlo)


    def make_data(self, pattern, shuffle=True):
        files_ds = tf.data.Dataset.list_files(pattern, shuffle=shuffle)
        ds = tf.data.TFRecordDataset(files_ds, num_parallel_reads=AUTOTUNE)
        ds = ds.map(self._parse_batch_predictions, num_parallel_calls=AUTOTUNE)
        ds = ds.batch(self.batch_size)

        return ds.prefetch(buffer_size=AUTOTUNE)


In [None]:
#Use the above functions to make the dataset

#Provided data contains sample 39x39 patches from approximately 1440 UTC 01 Feb 2020
#Note that you should not specify "*.tfrecord" here
data_path = '' #Input Path Here

#ch_name is optional and is not needed for generating predictions, so it is left at its default
data = Data(batch_size, 512)
test_data = data.make_data(os.path.join(data_path, "*.tfrecord"),shuffle=False)
print("Created all datasets")

### 3. Loading the model

In [None]:
#Note that you should point to the unzipped folder "/.../29-15.75/"
model_path = '' #Input Path Here
model = load_model(model_path, custom_objects={"tf": tf})
print("Model loaded", flush=True)

### 4. Generating predictions from the data

#### General note regarding variance predictions:

- In Deterministic models, the same set of input features will always produce the same GMI brightness temperature prediction value, no matter how many predictions the user requests (i.e., there is zero variance of predictions).
- In this study's Bayesian probabilistic models (like the example MC Dropout model presented herein), the user specifies a certain number of predictions, and variance is computed over the resulting distribution of predicted brightness temperatures for each pixel. 
- In ongoing and future work, more complex Bayesian models are being created that will directly learn and predict the variance at each pixel, which will allow for uncertainty decomposition in future studies.
    - For those interested, preliminary results of these models have been presented so far at [NeurIPS Climate Change AI 2022](https://www.climatechange.ai/papers/neurips2022/22), [IGARSS 2023 (Ortiz et al.: CS focus)](https://2023.ieeeigarss.org/view_paper.php?PaperNum=1881), and [IGARSS 2023 (Casas et al.: MET/Tropical Cyclone Application focus)](https://2023.ieeeigarss.org/view_paper.php?PaperNum=4460). Stay tuned for additional upcoming AMS/AGU conferences and publications!
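To illustrate the difference between the first two bullets, here is a minimal NumPy sketch of sampling-based variance. It does not use the notebook's trained model: `stochastic_predict` is a hypothetical toy (a linear model with dropout left active at prediction time), and the inputs and weights are random numbers. The point is only to show how repeated stochastic forward passes yield a per-pixel mean and variance, while repeated deterministic passes do not.

```python
import numpy as np

rng = np.random.default_rng(42)

def stochastic_predict(x, weights, drop_rate=0.2):
    """Toy stochastic forward pass: a linear model with dropout left on."""
    keep = rng.random(weights.shape) >= drop_rate
    # Inverted dropout: rescale kept weights so the expected output is unchanged
    return x @ (weights * keep) / (1.0 - drop_rate)

x = rng.normal(size=(5, 8))   # 5 hypothetical pixels, 8 input features each
weights = rng.normal(size=8)
num_predictions = 30

# One column per stochastic forward pass, one row per pixel
samples = np.column_stack(
    [stochastic_predict(x, weights) for _ in range(num_predictions)]
)
mean = samples.mean(axis=1)      # predictive mean per pixel
variance = samples.var(axis=1)   # predictive variance per pixel

# A deterministic model returns the same value on every pass
deterministic = np.column_stack([x @ weights for _ in range(num_predictions)])

print(samples.shape, variance.round(3))
print(np.allclose(deterministic.var(axis=1), 0.0))
```

The layout of `samples` (one column per prediction) mirrors how this notebook stacks repeated predictions before taking the mean and variance across columns.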

In [None]:
# build the results array
results = None
batches = 1
for batch in test_data:
    # record the results for a single batch
    
    #parse the tfrecords for relevant analysis variables
    y_true = batch[1].numpy()
    lat = batch[2].numpy()
    lon = batch[3].numpy()

    batch_results = np.column_stack((lat, lon, y_true))

    # need to make N predictions (set at beginning of notebook)
    for i in range(num_predictions):
        prediction = model(batch[0])

        if isinstance(prediction, tfp.distributions.Distribution):
            batch_results = np.column_stack((batch_results, prediction.mean()))
            batch_results = np.column_stack((batch_results, prediction.variance()))
        else:
            batch_results = np.column_stack((batch_results, prediction))

    # add batch results to existing results
    if results is None:
        results = batch_results
    else:
        results = np.vstack((results, batch_results)) #np.row_stack is a deprecated alias of np.vstack
    print(f"Batch {batches}: ", batch_results.shape, flush=True)
    batches += 1

print(f"All Predictions: ", results.shape)


In [None]:
#Optional Step:
#write results as csv files with headers

#file_path = #Input Path Here
#np.savetxt(file_path, results, delimiter=',', header='lat, lon, label, prediction')
#print(f"Predictions saved to {file_path}", flush=True)

### 5. Creating plots from predictions

In [None]:
#If reading from a .csv file of model output:

#results = np.genfromtxt(file_path, comments='#', delimiter=",", skip_header=0)

In [None]:
#data fields needed for plots

lat = results[:, 0]
lon = results[:, 1]
label = results[:, 2]
mean = np.mean(results[:, 3:], axis=1)
variance = np.var(results[:, 3:], axis=1)
error = mean - label

#### 5a. Example Spatial Plot (Exactly Equivalent to Top Row of Fig. 7)

In [None]:
#general plot settings
marker_size = .1
marker_style = ","
all_fonts_size = 8
plt.rcParams.update({'font.size': all_fonts_size,
                     'axes.titlesize': all_fonts_size,
                     'xtick.labelsize': all_fonts_size,
                     'ytick.labelsize': all_fonts_size,
                     'figure.dpi': 300})

projection = ccrs.PlateCarree()
axes_class = (GeoAxes, dict(map_projection=projection))

#Specifying Longitude labels/graph bounds
x_labels = np.linspace(-105, -15, 7)
x_bounds = (-95, -45)

#Specifying Latitude labels/graph bounds
y_labels = np.linspace(-75, 75, 11)
y_bounds = (0, 62.5)

lon_formatter = LongitudeFormatter(zero_direction_label=True)
lat_formatter = LatitudeFormatter()

#Specifying GMI TB min/max for colorbars
temp_min = 250
temp_max = 280
temp_labels = [i for i in range(temp_min, temp_max + 1, 5)]


In [None]:
#Plotting the 4-panel spatial plots as in Fig. 7 

fig = plt.figure(figsize=(8.5, 4))
grid = AxesGrid(fig, 111,  # as in plt.subplot(111)
                nrows_ncols=(1, 4),
                axes_pad=(.5, .05),
                share_all=False,
                cbar_mode="each",
                cbar_location="right",
                cbar_pad=.05,
                axes_class=axes_class,
                label_mode='')  # note the empty label_mode

#Plotting the observed 183 +/- 3 GHz GMI Brightness Temperatures
cmap = mpl.cm.RdYlBu_r
bounds = [i for i in np.arange(temp_min, temp_max + 1, 2.5)]
norm = mpl.colors.BoundaryNorm(bounds, cmap.N, extend='both')
grid[0].set_title(f'a) GMI\nObservations')
scat = grid[0].scatter(lon, lat, c=label, cmap=cmap, norm=norm, marker=marker_style, s=marker_size, linewidths=0) 
#Note: alternatives to scatter are tricontourf or interpolating predictions to a standard grid and using pcolormesh
grid[0].coastlines()
gl = grid[0].gridlines(crs=ccrs.PlateCarree(), xlocs=x_labels, ylocs=y_labels, color='gray', alpha=0.5)
grid[0].set_rasterized(True)
grid[0].set_xticks(x_labels, crs=ccrs.PlateCarree())
grid[0].tick_params(axis="x", labelrotation=45)
grid[0].xaxis.set_major_formatter(lon_formatter)
grid[0].set_yticks(y_labels, crs=ccrs.PlateCarree())
grid[0].yaxis.set_major_formatter(lat_formatter)
grid.cbar_axes[0].remove()  # remove the first colorbar since it is shared with the next plot
grid[0].set_xlim(x_bounds[0], x_bounds[1])
grid[0].set_ylim(y_bounds[0], y_bounds[1])

#Plotting the model-predicted 183 +/- 3 GHz GMI Brightness Temperatures
grid[1].set_title(f'b) Predictions')
scat = grid[1].scatter(lon, lat, c=mean, cmap=cmap, norm=norm, marker=marker_style, s=marker_size, linewidths=0)
grid[1].coastlines()
gl = grid[1].gridlines(crs=ccrs.PlateCarree(), xlocs=x_labels, ylocs=y_labels, color='gray', alpha=0.5)
grid[1].set_rasterized(True)
grid[1].set_xticks(x_labels, crs=ccrs.PlateCarree())
grid[1].tick_params(axis="x", labelrotation=45)
grid[1].xaxis.set_major_formatter(lon_formatter)
grid.cbar_axes[1].colorbar(scat, format='%d', ticks=temp_labels, label="T$^{mw}_b$ [K]", extend="both")
grid[1].set_xlim(x_bounds[0], x_bounds[1])
grid[1].set_ylim(y_bounds[0], y_bounds[1])

#Plotting the error of predictions (prediction - label)
error_max = 5  # 183 +/- 3 GHz
cmap = mpl.cm.seismic
bounds = [i for i in range(-error_max, error_max + 1)]
norm = mpl.colors.BoundaryNorm(bounds, cmap.N, extend='both')
grid[2].set_title(r'c) Prediction Error')
scat = grid[2].scatter(lon, lat, c=error, cmap=cmap, norm=norm, marker=marker_style, s=marker_size, linewidths=0)
grid[2].coastlines()
gl = grid[2].gridlines(crs=ccrs.PlateCarree(), xlocs=x_labels, ylocs=y_labels, color='gray', alpha=0.5)
grid[2].set_rasterized(True)
grid[2].set_xticks(x_labels, crs=ccrs.PlateCarree())
grid[2].tick_params(axis="x", labelrotation=45)
grid[2].xaxis.set_major_formatter(lon_formatter)
grid.cbar_axes[2].colorbar(scat, format='%d', ticks=bounds,  label="Error [K]", extend="both")
grid[2].set_xlim(x_bounds[0], x_bounds[1])
grid[2].set_ylim(y_bounds[0], y_bounds[1]) 

#Plotting the uncertainty of predictions (shown as standard deviation, the square root of the variance)
variance_max  = 1.5 # 183 +/- 3 GHz
cmap = mpl.cm.plasma
bounds = [i for i in np.arange(0.0, variance_max + 0.1, .125)]
norm = mpl.colors.BoundaryNorm(bounds, cmap.N, extend='max')
grid[3].set_title(r'd) Uncertainty')
scat = grid[3].scatter(lon, lat, c=np.sqrt(variance), cmap=cmap, norm=norm, marker=marker_style, s=marker_size, linewidths=0)
grid[3].coastlines()
gl = grid[3].gridlines(crs=ccrs.PlateCarree(), xlocs=x_labels, ylocs=y_labels, color='gray', alpha=0.5)
grid[3].set_rasterized(True)
grid[3].set_xticks(x_labels, crs=ccrs.PlateCarree())
grid[3].tick_params(axis="x", labelrotation=45)
grid[3].xaxis.set_major_formatter(lon_formatter)
grid.cbar_axes[3].colorbar(scat, format='%.2f', ticks=np.arange(0, variance_max+.01, 0.25),  label="Standard Deviation [K]", extend="max")
grid[3].set_xlim(x_bounds[0], x_bounds[1])
grid[3].set_ylim(y_bounds[0], y_bounds[1])

#Adjust axes to account for the extra space resulting from the shared colorbar between (a) and (b)
ax1 = grid[0]
ax2 = grid[1]
cb = grid.cbar_axes[1]
divider = HBoxDivider(
    fig, 121,
    horizontal=[Size.Fixed(.24), Size.AxesX(ax1), Size.Fixed(.2), Size.AxesX(ax2), Size.Fixed(.05), Size.Fixed(0.08), Size.Fixed(0.07)],
    vertical=[Size.Scaled(1), Size.AxesY(ax1), Size.Scaled(1), Size.AxesY(ax2), Size.Scaled(1), Size.Scaled(1), Size.Scaled(1)]
    , anchor="E"
)
grid[0].set_axes_locator(divider.new_locator(1))
grid[1].set_axes_locator(divider.new_locator(3))
grid.cbar_axes[1].set_axes_locator(divider.new_locator(5))


plt.show()

#### 5b. Example Calibration Curve (Computed over different data than what Fig. 6a shows)

This code computes the calibration curve by first computing the mean absolute error (MAE) over 100% of predictions, then omitting the 1% of predictions with the highest uncertainty and recomputing the MAE over the remaining 99%, and so on down to the 1% of predictions with the lowest uncertainty.

In [None]:
#Computing the calibration curve

scale = 100 #percent
xs = []
ys = []
sd = np.sqrt(variance) #standard deviation
for i in range(scale, 0, -1):
    mask = sd <= np.quantile(sd, i / float(scale))

    mean_mask = mean[mask]
    label_mask = label[mask]

    MAE = mean_absolute_error(label_mask, mean_mask)
    # print(i / float(scale), np.quantile(sd, i / float(scale)), MAE)
    xs += [int(100 * (i / float(scale)))]
    ys += [MAE]


In [None]:
#Plotting the calibration curve

all_fonts_size = 10
plt.figure(figsize=(3, 3), constrained_layout=True)
plt.rcParams.update({'font.size': all_fonts_size,
                     'axes.titlesize': all_fonts_size,
                     'xtick.labelsize': all_fonts_size,
                     'ytick.labelsize': all_fonts_size,
                     'figure.dpi': 300})

#calibration curve
plt.plot(xs,ys, c='m',label=r'183$\pm$3 GHz MC Dropout')

#arrow with standard deviation label
vert_offset = .125
mae_at_80 = ys[xs.index(80)] #MAE with 80% of predictions retained (xs runs from 100 down to 1)
plt.annotate(
    f"SD = {np.quantile(sd, .8):.2f}",
    xy=(80, mae_at_80),
    arrowprops=dict(facecolor='black', width=1, headwidth=7.5),
    textcoords='data',
    xytext=(60, mae_at_80 + vert_offset),
    horizontalalignment='right',
    verticalalignment='center'
)

#vertical reference line at 80% of predictions retained
plt.axvline(80, color="black", linestyle=":")

plt.legend(loc="best")

plt.xlim(1, 100)
plt.xticks(ticks=[1, 20, 40, 60, 80, 100])

plt.yticks(ticks=np.arange(0.5, 2.5, .5))

plt.xlabel('Percent of Predictions')
plt.ylabel('Mean Absolute Error')
plt.title(r'Channel 183$\pm$3 GHz V')

plt.show()

### 6. Conclusions

If the code is working correctly, you should see that higher-error predictions are associated with higher predicted uncertainty in two ways. In the spatial plot, you should be able to reproduce Fig. 7 (top row) and see that regions with clouds/precipitation are consistently associated with higher error and uncertainty than regions with clear air. In the calibration curve, you should see that the MAE of 100% of model predictions is highest, and that the MAE of the 1% of predictions with the lowest uncertainty is approximately the lowest. The curve should also be mostly monotonically increasing from left to right, which makes this calibration curve an example of a fairly well-calibrated model. Again, note that the corresponding journal article figure (Fig. 6a) is computed over January test data from multiple days, whereas the example curve produced here is computed only over the data contained within the spatial plot.
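If you would like a quick numerical check of "mostly monotonically increasing", one option is to count the fraction of consecutive steps along the curve where the MAE does not decrease. The sketch below is an illustration only, not a metric from the manuscript: `fraction_monotonic` is a hypothetical helper, and the data are synthetic errors whose spread grows with the predicted standard deviation.

```python
import numpy as np

def fraction_monotonic(percent_retained, mae):
    """Fraction of consecutive steps where MAE does not decrease
    as the percent of predictions retained increases."""
    order = np.argsort(percent_retained)
    steps = np.diff(np.asarray(mae, dtype=float)[order])
    return float(np.mean(steps >= 0))

# Synthetic data in which the error scale grows with predicted uncertainty
rng = np.random.default_rng(0)
sd = rng.uniform(0.1, 1.5, size=20_000)
abs_error = np.abs(rng.normal(scale=sd))

# Same retention scheme as the calibration cell: keep the lowest-uncertainty p%
percents = np.arange(100, 0, -1)
maes = [abs_error[sd <= np.quantile(sd, p / 100)].mean() for p in percents]

print(f"{fraction_monotonic(percents, maes):.2f}")
```

With the notebook's own output, the same helper could be called as `fraction_monotonic(xs, ys)`; a value close to 1 is consistent with the fairly well-calibrated behavior described above.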