Licensed under the Apache License, Version 2.0

# Example code for reprojecting GOES-16 images
This notebook demonstrates how to reproduce the reprojected images in the OpenContrails dataset.

In [None]:
!pip install pyresample
!pip install gcsfs
!pip install xarray

In [None]:
import datetime
import pprint
import sys

import gcsfs
import matplotlib.pyplot as plt
import numpy as np
from osgeo import osr
import pyresample
import pyresample.bilinear
import tensorflow as tf
import xarray

# Only when the notebook is running inside Google Colab (detected through the
# already-imported `google.colab` module), ask the user to authenticate.
# NOTE(review): presumably this authorizes the Google Cloud Storage reads done
# in later cells - confirm.
if 'google.colab' in sys.modules:
  from google.colab import auth
  auth.authenticate_user()

In [None]:
# Load a single record from the TFRecords and display the 12um band.

def parse_example(serialized_example: bytes) -> dict[str, tf.Tensor]:
  """Parses one serialized record of the contrails TFRecords.

  The image-like features are stored as serialized tensors inside string
  features, so they are decoded with `tf.io.parse_tensor` after the record
  itself has been parsed.

  Args:
    serialized_example: A single serialized example read from a TFRecord file.

  Returns:
    A dict keyed by feature name. 'cloud_top', 'data_10um', 'data_11um' and
    'data_12um' are decoded as `tf.double` tensors and 'human_pixel_masks' as
    a `tf.int32` tensor. The remaining features are returned as parsed;
    'satellite_scan_starts' is variable-length (later densified with
    `tf.sparse.to_dense`).
  """
  scalar_string = tf.io.FixedLenFeature([], tf.string)
  scalar_int64 = tf.io.FixedLenFeature([], tf.int64)
  scalar_float32 = tf.io.FixedLenFeature([], tf.float32)

  # Features holding serialized double-precision tensors.
  double_tensor_keys = ('cloud_top', 'data_10um', 'data_11um', 'data_12um')

  feature_spec = {key: scalar_string for key in double_tensor_keys}
  feature_spec.update({
      'human_pixel_masks': scalar_string,
      'n_times_before': scalar_int64,
      'n_times_after': scalar_int64,
      # Projection params
      'projection_wkt': scalar_string,
      'col_min': scalar_float32,
      'row_min': scalar_float32,
      'col_size': scalar_float32,
      'row_size': scalar_float32,
      # Timestamp
      'timestamp': scalar_int64,  # approximate timestamp
      'satellite_scan_starts': tf.io.VarLenFeature(tf.int64),  # timestamp from original file
  })

  features = tf.io.parse_single_example(serialized_example, feature_spec)

  # Replace the serialized strings with the tensors they encode.
  decoded = {
      key: tf.io.parse_tensor(features[key], tf.double)
      for key in double_tensor_keys
  }
  decoded['human_pixel_masks'] = tf.io.parse_tensor(
      features['human_pixel_masks'], tf.int32
  )
  features.update(decoded)
  return features

# The TFRecord pipeline gets its own name: `dataset` is assigned the xarray
# NetCDF dataset in a later cell, and sharing the name would silently clobber
# it (and break the cells that read from it) if this cell were re-run.
tfrecord_paths = tf.io.gfile.glob('gs://goes_contrails_dataset/20230419/tfrecords/train.tfrecords-*')
assert tfrecord_paths, 'No TFRecord files matched the glob pattern.'

tfrecord_dataset = tf.data.TFRecordDataset(tfrecord_paths).map(parse_example)

# Only the first record is needed for this demonstration.
features = tfrecord_dataset.take(1).get_single_element()

# The last axis of the image tensors is indexed by time step; show the frame
# at index `n_times_before`.
n_times_before = features['n_times_before']
plt.figure(figsize=(12, 6))
plt.imshow(features['data_12um'][:, :, n_times_before])
plt.title('data_12um from the TFRecord')
plt.show()

Here we load the original NetCDF file, which is publicly available from Google Cloud Storage. We convert the raw radiance to brightness temperature, and then compute the `AreaDefinition` of the original GOES-16 full-disk image from the parameters in the NetCDF file.

See more in the [example notebook](https://github.com/google-research/google-research/blob/master/contrails/demos/load_goes_data.ipynb) for loading and visualizing GOES images.

In [None]:
# Pick the scan start time of the frame at index `n_times_before` and turn it
# into a timezone-aware UTC datetime.
scan_starts_dense = tf.sparse.to_dense(features['satellite_scan_starts'])
satellite_scan_starts = scan_starts_dense.numpy()
timestamp = satellite_scan_starts[features['n_times_before']]
dt = datetime.datetime.fromtimestamp(timestamp, tz=datetime.timezone.utc)

fs = gcsfs.GCSFileSystem(project='gcp-public-data-goes-16')

# data_12um corresponds to channel 15 (i.e. M6C15)
# The files are laid out by year / day-of-year / hour, and the file name
# embeds the scan start down to the minute.
paths = fs.glob(f'gcp-public-data-goes-16/ABI-L1b-RadF/{dt.year}/{dt.timetuple().tm_yday:03d}/{dt.hour:02d}/OR_ABI-L1b-RadF-M6C15_G16_s{dt.year}{dt.timetuple().tm_yday:03d}{dt.hour:02d}{dt.minute:02d}*')
assert len(paths) == 1, 'There should be exactly one NetCDF file for a band at the timestamp.'

# Load everything into memory before the remote file handle is closed.
with fs.open(paths[0], 'rb') as f:
  dataset = xarray.open_dataset(f).load()

# Convert the raw radiance to brightness temperature using the Planck
# coefficients stored in the NetCDF file.
radiance = dataset.Rad.data
planck_fk1 = dataset.planck_fk1.data
planck_fk2 = dataset.planck_fk2.data
planck_bc1 = dataset.planck_bc1.data
planck_bc2 = dataset.planck_bc2.data
log_term = np.log((planck_fk1 / radiance) + 1)
brightness_temperature = (planck_fk2 / log_term - planck_bc1) / planck_bc2

In [None]:

# All projection parameters are read from the `goes_imager_projection`
# variable of the NetCDF file.
goes_projection = dataset.goes_imager_projection
h0 = goes_projection.perspective_point_height

goes_proj4 = {  # proj4 dict
    'proj': 'geos',  # Stands for 'geostationary'
    'units': 'm',
    'h': str(h0),
    'lon_0': str(goes_projection.longitude_of_projection_origin),
    'a': str(goes_projection.semi_major_axis),
    'b': str(goes_projection.semi_minor_axis),
    'sweep': goes_projection.sweep_angle_axis,
}

# The image bounds are scaled by the perspective point height to obtain the
# extent in projection units.
# NOTE(review): presumably the bounds are scan angles in radians - confirm.
x_bounds = dataset['x_image_bounds'].data
y_bounds = dataset['y_image_bounds'].data
goes_area_extent = [
    x_bounds[0] * h0,
    y_bounds[1] * h0,
    x_bounds[1] * h0,
    y_bounds[0] * h0,
]

goes_area_def = pyresample.geometry.AreaDefinition(
    area_id='all_goes_16',  # Used only for pyresample logging
    proj_id='deprecated',  # Deprecated but required by pyresample
    description='all_goes_16',  # Used only for pyresample logging
    projection=goes_proj4,
    width=dataset['x'].shape[0],
    height=dataset['y'].shape[0],
    area_extent=goes_area_extent,
)


Here we construct the `AreaDefinition` for the GOES scene from the projection parameters provided in the TFRecords.

In [None]:
# The first two axes of the image give the size of the target grid.
rows, cols = features['data_12um'].shape[:2]

# Corners of the scene, derived from the origin and per-pixel step stored in
# the record.
col_start = features['col_min']
row_start = features['row_min']
col_end = col_start + features['col_size'] * cols
row_end = row_start + features['row_size'] * rows
area_extent = [col_start, row_end, col_end, row_start]

# The record stores the projection as WKT; pyresample is given the equivalent
# proj4 string.
projection_wkt = features['projection_wkt'].numpy().decode()
target_proj4 = osr.SpatialReference(wkt=projection_wkt).ExportToProj4()

target_area_def = pyresample.AreaDefinition(
    area_id='n/a',
    description='n/a',
    proj_id='n/a',
    projection=target_proj4,
    width=cols,
    height=rows,
    area_extent=area_extent,
)

Once we have the `AreaDefinition` of the original GOES full-disk image and of the target scene, we can use bilinear resampling to obtain the image that corresponds to the local scene.

In [None]:
# Compute the bilinear interpolation parameters that map the full-disk grid
# onto the target scene.
bilinear_info = pyresample.bilinear.get_bil_info(goes_area_def, target_area_def)
t_params, s_params, input_idxs, idx_ref = bilinear_info

# The resampler takes the source image as a flat array and returns the result
# in the shape of the target area.
flat_brightness_temperature = brightness_temperature.flatten()
resampled = pyresample.bilinear.get_sample_from_bil_info(
    flat_brightness_temperature,
    t_params,
    s_params,
    input_idxs,
    idx_ref,
    output_shape=target_area_def.shape,
)

The resampled image closely reproduces the one provided with the dataset.

In [None]:
# Show the freshly resampled image next to the one shipped in the TFRecord.
# Explicit axes let each panel carry a title, and `plt.show()` keeps the
# `AxesImage` repr out of the cell output.
fig, (ax_resampled, ax_tfrecord) = plt.subplots(1, 2, figsize=(12, 6))

ax_resampled.imshow(resampled)
ax_resampled.set_title('Resampled from the GOES-16 NetCDF file')

ax_tfrecord.imshow(features['data_12um'][..., features['n_times_before']])
ax_tfrecord.set_title('data_12um from the TFRecord')

plt.show()