# Open-sourced dataset and model snapshot for precipitation nowcasting, accompanying the paper *Skillful Precipitation Nowcasting using Deep Generative Models of Radar* (Ravuri et al., 2021)

This colab contains:
* Code to read the dataset using [Tensorflow 2](https://www.tensorflow.org/), with documentation of the available splits, variants and fields
* Example plots and animations of the data using [matplotlib](https://matplotlib.org/) and [cartopy](https://scitools.org.uk/cartopy/docs/latest/)
* A [TF-Hub](https://www.tensorflow.org/hub) snapshot of the model from the paper
* Example code to load this model and use it to make predictions.

It has been tested in a public Google colab kernel.

## How to run this notebook

All sections with the exception of 'Making predictions on a row from the full-frame test set (1536x1280)' can be evaluated on a free public Colab kernel. The final section requires more RAM than is available with a free kernel. To evaluate these cells you can either run your own local kernel (with >= 24GB of RAM), or upgrade to Colab Pro.

To launch a local colab kernel, please follow these [instructions](https://research.google.com/colaboratory/local-runtimes.html).

## License and attribution

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

[https://www.apache.org/licenses/LICENSE-2.0](https://www.apache.org/licenses/LICENSE-2.0)

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

The datasets and the model snapshots associated with this colab are made available for use under the terms of the
[Creative Commons Attribution 4.0 International License](http://creativecommons.org/licenses/by/4.0/).

This colab and the associated model snapshots are Copyright 2021 DeepMind Technologies Limited.

The associated datasets contain public sector information licensed by the [Met Office](https://www.metoffice.gov.uk/) under the
[UK Open Government Licence v3.0](http://www.nationalarchives.gov.uk/doc/open-government-licence/version/3).


## Library dependency installs and imports

The following libraries are required. You can skip these `pip install` cells if your kernel already has them installed.

In [None]:
!pip -q install tensorflow~=2.5.0 "numpy>=1.21" matplotlib~=3.2.2 tensorflow_hub~=0.12.0 cartopy~=0.19.0 folium==0.2.1 imgaug===0.2.5

In [None]:
# Workaround for cartopy crashes due to the shapely installed by default in
# google colab kernel (https://github.com/anitagraser/movingpandas/issues/81):
!pip uninstall -y shapely
!pip install shapely --no-binary shapely

## Imports:

In [None]:
import datetime
import os

import cartopy
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import shapely.geometry as sgeom
import tensorflow as tf
import tensorflow_hub

from google.colab import auth

## Dataset location

In [None]:
# This Google Cloud Storage (GCS) bucket is free to access and contains an
# example subset of the full dataset (just the first shard of each
# split/variant):
EXAMPLE_DATASET_BUCKET_PATH = "gs://dm-nowcasting-example-data/datasets/nowcasting_open_source_osgb/nimrod_osgb_1000m_yearly_splits/radar/20200718"

# This bucket is requester-pays and will require authentication. It contains the
# full dataset. We recommend downloading a local copy first and updating
# DATASET_ROOT_DIR below to the local path. This should save on transfer costs
# and speed up training.
FULL_DATASET_BUCKET_PATH = "gs://dm-nowcasting/datasets/nowcasting_open_source_osgb/nimrod_osgb_1000m_yearly_splits/radar/20200718"

# Update this as required:
DATASET_ROOT_DIR = EXAMPLE_DATASET_BUCKET_PATH

Use this to authenticate as required for access to GCS buckets:

In [None]:
auth.authenticate_user()

## Dataset reader code



In [None]:
_FEATURES = {name: tf.io.FixedLenFeature([], dtype)
             for name, dtype in [
               ("radar", tf.string), ("sample_prob", tf.float32),
               ("osgb_extent_top", tf.int64), ("osgb_extent_left", tf.int64),
               ("osgb_extent_bottom", tf.int64), ("osgb_extent_right", tf.int64),
               ("end_time_timestamp", tf.int64),
             ]}

_SHAPE_BY_SPLIT_VARIANT = {
    ("train", "random_crops_256"): (24, 256, 256, 1),
    ("valid", "subsampled_tiles_256_20min_stride"): (24, 256, 256, 1),
    ("test", "full_frame_20min_stride"): (24, 1536, 1280, 1),
    ("test", "subsampled_overlapping_padded_tiles_512_20min_stride"): (24, 512, 512, 1),
}

_MM_PER_HOUR_INCREMENT = 1/32.
_MAX_MM_PER_HOUR = 128.
_INT16_MASK_VALUE = -1


def parse_and_preprocess_row(row, split, variant):
  result = tf.io.parse_example(row, _FEATURES)
  shape = _SHAPE_BY_SPLIT_VARIANT[(split, variant)]
  radar_bytes = result.pop("radar")
  radar_int16 = tf.reshape(tf.io.decode_raw(radar_bytes, tf.int16), shape)
  mask = tf.not_equal(radar_int16, _INT16_MASK_VALUE)
  radar = tf.cast(radar_int16, tf.float32) * _MM_PER_HOUR_INCREMENT
  radar = tf.clip_by_value(
      radar, _INT16_MASK_VALUE * _MM_PER_HOUR_INCREMENT, _MAX_MM_PER_HOUR)
  result["radar_frames"] = radar
  result["radar_mask"] = mask
  return result


def reader(split="train", variant="random_crops_256", shuffle_files=False):
  """Reader for open-source nowcasting datasets.
  
  Args:
    split: Which yearly split of the dataset to use:
      "train": Data from 2016 - 2018, excluding the first day of each month.
      "valid": Data from 2016 - 2018, only the first day of the month.
      "test": Data from 2019.
    variant: Which variant to use. The available variants depend on the split:
      "random_crops_256": Available for the training split. 24x256x256 pixel
        crops, sampled with a bias towards crops containing rainfall. Crops at
        all spatial and temporal offsets were able to be sampled, some crops may
        overlap.
      "subsampled_tiles_256_20min_stride": Available for the validation set.
        Non-spatially-overlapping 24x256x256 pixel crops, subsampled from a
        regular spatial grid with stride 256x256 pixels, and a temporal stride
        of 20mins (4 timesteps at 5 minute resolution). Sampling favours crops
        containing rainfall.
      "subsampled_overlapping_padded_tiles_512_20min_stride": Available for the
        test set. Overlapping 24x512x512 pixel crops, subsampled from a
        regular spatial grid with stride 64x64 pixels, and a temporal stride
        of 20mins (4 timesteps at 5 minute resolution). Subsampling favours
        crops containing rainfall.
        These crops include extra spatial context for a fairer evaluation of
        the PySTEPS baseline, which benefits from this extra context. Our other
        models only use the central 256x256 pixels of these crops.
      "full_frame_20min_stride": Available for the test set. Includes full
        frames at 24x1536x1280 pixels, every 20 minutes with no additional
        subsampling.
    shuffle_files: Whether to shuffle the shard files of the dataset
      non-deterministically before interleaving them. Recommended for the
      training set to improve mixing and read performance (since
      non-deterministic parallel interleave is then enabled).

  Returns:
    A tf.data.Dataset whose rows are dicts with the following keys:

    "radar_frames": Shape TxHxWx1, float32. Radar-based estimates of
      ground-level precipitation, in units of mm/hr. Pixels which are masked
      will take on a value of -1/32 and should be excluded from use as
      evaluation targets. The coordinate reference system used is OSGB36, with
      a spatial resolution of 1000 OSGB36 coordinate units (approximately equal
      to 1km). The temporal resolution is 5 minutes.
    "radar_mask": Shape TxHxWx1, bool. A binary mask which is False
      for pixels that are unobserved / unable to be inferred from radar
      measurements (e.g. due to being too far from a radar site). This mask
      is usually static over time, but occasionally a whole radar site will
      drop in or out resulting in large changes to the mask, and more localised
      changes can happen too. 
    "sample_prob": Scalar float. The probability with which the row was
      sampled from the overall pool available for sampling, as described above
      under 'variants'. We use importance weights proportional to 1/sample_prob
      when computing metrics on the validation and test set, to reduce bias due
      to the subsampling.
    "end_time_timestamp": Scalar int64. A timestamp for the final frame in
      the example, in seconds since the UNIX epoch (1970-01-01 00:00:00 UTC).
    "osgb_extent_left", "osgb_extent_right", "osgb_extent_top",
    "osgb_extent_bottom":
      Scalar int64s. Spatial extent for the crop in the OSGB36 coordinate
      reference system.
  """
  shards_glob = os.path.join(DATASET_ROOT_DIR, split, variant, "*.tfrecord.gz")
  shard_paths = tf.io.gfile.glob(shards_glob)
  shards_dataset = tf.data.Dataset.from_tensor_slices(shard_paths)
  if shuffle_files:
    shards_dataset = shards_dataset.shuffle(buffer_size=len(shard_paths))
  return (
      shards_dataset
      .interleave(lambda x: tf.data.TFRecordDataset(x, compression_type="GZIP"),
                  num_parallel_calls=tf.data.AUTOTUNE,
                  deterministic=not shuffle_files)
      .map(lambda row: parse_and_preprocess_row(row, split, variant),
           num_parallel_calls=tf.data.AUTOTUNE)
      # Do your own subsequent repeat, shuffle, batch, prefetch etc as required.
  )
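
The decoding performed by `parse_and_preprocess_row` can be illustrated without TensorFlow. The following is a minimal NumPy sketch, not the reader itself: the function name `decode_radar` and the toy input are hypothetical, and the little-endian `"<i2"` dtype is chosen on the assumption that it matches the default byte order of `tf.io.decode_raw`.

```python
import numpy as np

MM_PER_HOUR_INCREMENT = 1 / 32.0  # Same quantisation step as the reader above.
MAX_MM_PER_HOUR = 128.0
INT16_MASK_VALUE = -1


def decode_radar(raw_bytes, shape):
  """Hypothetical NumPy analogue of the decoding in parse_and_preprocess_row."""
  # Assumption: the stored int16 values are little-endian.
  radar_int16 = np.frombuffer(raw_bytes, dtype="<i2").reshape(shape)
  mask = radar_int16 != INT16_MASK_VALUE  # False where a pixel is unobserved.
  radar = radar_int16.astype(np.float32) * MM_PER_HOUR_INCREMENT
  radar = np.clip(
      radar, INT16_MASK_VALUE * MM_PER_HOUR_INCREMENT, MAX_MM_PER_HOUR)
  return radar, mask


# Toy 1x2x2x1 example: one masked pixel, one dry pixel, 1 mm/hr, and one value
# above the 128 mm/hr cap.
toy = np.array([-1, 0, 32, 5000], dtype="<i2")
radar, mask = decode_radar(toy.tobytes(), (1, 2, 2, 1))
print(radar.ravel(), mask.ravel())
```

Masked pixels keep the value -1/32 in `radar`, which is why `radar_mask` should be used to exclude them from evaluation.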

## Dataset reader documentation

In [None]:
help(reader)


## Reading a row from the training set and inspecting types/shapes/values

In [None]:
row = next(iter(reader(split="train", variant="random_crops_256")))

In [None]:
{k: (v.dtype, v.shape) for k, v in row.items()}

Values for scalar features:

In [None]:
{k: v.numpy() for k, v in row.items() if v.shape.ndims == 0}

{'end_time_timestamp': 1514725200,
 'osgb_extent_bottom': 555000,
 'osgb_extent_left': -9000,
 'osgb_extent_right': 247000,
 'osgb_extent_top': 811000,
 'sample_prob': 9.889281e-06}
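
The `sample_prob` value above is what the reader documentation suggests using for importance weighting, and masked pixels should be excluded from evaluation targets. A rough NumPy sketch of both ideas follows. The function names and toy numbers are hypothetical, and the self-normalised weighted average is one plausible reading of "weights proportional to 1/sample_prob", not necessarily the exact procedure used in the paper.

```python
import numpy as np


def masked_mse(prediction, target, mask):
  """Mean squared error over observed pixels only (where mask is True)."""
  squared_error = np.where(mask, (prediction - target) ** 2, 0.0)
  return squared_error.sum() / max(mask.sum(), 1)


def importance_weighted_mean(per_row_values, sample_probs):
  """Averages per-row values with weights proportional to 1/sample_prob."""
  weights = 1.0 / np.asarray(sample_probs, dtype=np.float64)
  values = np.asarray(per_row_values, dtype=np.float64)
  return float((weights * values).sum() / weights.sum())


# Toy per-pixel example: the third pixel is masked and so is ignored.
mse = masked_mse(np.array([1.0, 2.0, 3.0]),
                 np.array([1.0, 0.0, 0.0]),
                 np.array([True, True, False]))

# Toy per-row scores: the row that was less likely to be sampled counts more.
weighted = importance_weighted_mean([2.0, 4.0], sample_probs=[0.5, 0.25])
print(mse, weighted)
```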

Decoding the end_time_timestamp:

In [None]:
datetime.datetime.utcfromtimestamp(row["end_time_timestamp"]).isoformat()

'2017-12-31T13:00:00'
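
Since `end_time_timestamp` refers to the final frame and the temporal resolution is 5 minutes, the time of every frame in a row can be derived from it. A small sketch, where `frame_times` is a hypothetical helper and the default of 24 frames matches the shapes listed in the reader:

```python
import datetime


def frame_times(end_time_timestamp, num_frames=24, minutes_per_frame=5):
  """Returns a UTC datetime per frame, given the final frame's timestamp."""
  end = datetime.datetime.fromtimestamp(
      int(end_time_timestamp), tz=datetime.timezone.utc)
  step = datetime.timedelta(minutes=minutes_per_frame)
  return [end - step * (num_frames - 1 - i) for i in range(num_frames)]


times = frame_times(1514725200)  # The timestamp printed for the row above.
print(times[0].isoformat(), "->", times[-1].isoformat())
```

For the row above, the first frame is 115 minutes (23 steps of 5 minutes) before the last one.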

## Visualization helpers

In [None]:
matplotlib.rc('animation', html='jshtml')


def plot_animation(field, figsize=None,
                   vmin=0, vmax=10, cmap="jet", **imshow_args):
  fig = plt.figure(figsize=figsize)
  ax = plt.axes()
  ax.set_axis_off()
  plt.close() # Prevents extra axes being plotted below animation
  img = ax.imshow(field[0, ..., 0], vmin=vmin, vmax=vmax, cmap=cmap, **imshow_args)

  def animate(i):
    img.set_data(field[i, ..., 0])
    return (img,)

  return animation.FuncAnimation(
      fig, animate, frames=field.shape[0], interval=24, blit=False)


class ExtendedOSGB(cartopy.crs.OSGB):
  """Met Office radar data uses OSGB36 with an extended bounding box."""

  def __init__(self):
    super().__init__(approx=False)

  @property
  def x_limits(self):
    return (-405000, 1320000)

  @property
  def y_limits(self):
    return (-625000, 1550000)

  @property
  def boundary(self):
    x0, x1 = self.x_limits
    y0, y1 = self.y_limits
    return sgeom.LinearRing([(x0, y0), (x0, y1), (x1, y1), (x1, y0), (x0, y0)])


def plot_rows_on_map(rows, field_name="radar_frames", timestep=0, num_rows=None,
                     cbar_label=None, **imshow_kwargs):
  fig = plt.figure(figsize=(10, 10))
  axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())
  if num_rows is None:
    num_rows = next(iter(rows.values())).shape[0]
  for b in range(num_rows):
    extent = (rows["osgb_extent_left"][b].numpy(),
              rows["osgb_extent_right"][b].numpy(),
              rows["osgb_extent_bottom"][b].numpy(),
              rows["osgb_extent_top"][b].numpy())
    im = axes.imshow(rows[field_name][b, timestep, ..., 0].numpy(),
                extent=extent, **imshow_kwargs)

  axes.set_xlim(*axes.projection.x_limits)
  axes.set_ylim(*axes.projection.y_limits)
  axes.set_facecolor("black")
  axes.gridlines(alpha=0.5)
  axes.coastlines(resolution="50m", color="white")
  if cbar_label:
    cbar = fig.colorbar(im)
    cbar.set_label(cbar_label)
  return fig


def plot_animation_on_map(row):
  fig = plt.figure(figsize=(10, 10))
  axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())
  plt.close() # Prevents extra axes being plotted below animation

  axes.gridlines(alpha=0.5)
  axes.coastlines(resolution="50m", color="white")

  extent = (row["osgb_extent_left"].numpy(),
            row["osgb_extent_right"].numpy(),
            row["osgb_extent_bottom"].numpy(),
            row["osgb_extent_top"].numpy())

  img = axes.imshow(
      row["radar_frames"][0, ..., 0].numpy(),
      extent=extent, vmin=0, vmax=15., cmap="jet")

  cbar = fig.colorbar(img)
  cbar.set_label("Precipitation, mm/hr")

  def animate(i):
    # set_data returns None, so update the image first and return the artist.
    img.set_data(row["radar_frames"][i, ..., 0].numpy())
    return (img,)

  return animation.FuncAnimation(
      fig, animate, frames=row["radar_frames"].shape[0],
      interval=24, blit=False)


def plot_mask_on_map(row):
  fig = plt.figure(figsize=(10, 10))
  axes = fig.add_subplot(1, 1, 1, projection=ExtendedOSGB())
  axes.gridlines(alpha=0.5)
  axes.coastlines(resolution="50m", color="black")

  extent = (row["osgb_extent_left"].numpy(),
            row["osgb_extent_right"].numpy(),
            row["osgb_extent_bottom"].numpy(),
            row["osgb_extent_top"].numpy())

  img = axes.imshow(
      row["radar_mask"][0, ..., 0].numpy(),
      extent=extent, vmin=0, vmax=1, cmap="viridis")

## Visualizing rows

Animation of a single row from the random_crops_256 training set (sequence of 24 frames at 256x256)

In [None]:
plot_animation(row["radar_frames"].numpy())

And its mask. This is not always interesting, since it is sometimes all ones. Only the first frame is plotted, as the mask is usually static over time.

In [None]:
plt.imshow(row["radar_mask"][0, ..., 0].numpy(), vmin=0, vmax=1);
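
To check whether a mask really is all ones, or static over time, without plotting it, something like the following could be used. This is a sketch on a toy array; `summarise_mask` is a hypothetical helper, and with a real row you would pass `row["radar_mask"].numpy()`.

```python
import numpy as np


def summarise_mask(mask):
  """Returns (observed fraction per frame, whether the mask is static)."""
  mask = np.asarray(mask, dtype=bool)
  # Fraction of observed pixels in each of the T frames.
  observed_fraction = mask.reshape(mask.shape[0], -1).mean(axis=1)
  # Static means every frame equals the first frame.
  is_static = bool(np.all(mask == mask[:1]))
  return observed_fraction, is_static


# Toy TxHxWx1 mask in which one pixel drops out in the last frame.
toy_mask = np.ones((3, 2, 2, 1), dtype=bool)
toy_mask[2, 0, 0, 0] = False
fractions, is_static = summarise_mask(toy_mask)
print(fractions, is_static)
```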

Plotting an animation of a row from the full-frame test set

In [None]:
dataset = reader(split="test", variant="full_frame_20min_stride")
full_frame_test_set_row = next(iter(dataset))

In [None]:
plot_animation_on_map(full_frame_test_set_row)

And just its mask:

In [None]:
plot_mask_on_map(full_frame_test_set_row)

Plotting a few different crops from the training set on the same map, using their OSGB extents. Note these will have been sampled at different timestamps so won't be consistent with each other. 

In [None]:
BATCH_SIZE = 60
dataset = reader(split="train", variant="random_crops_256")
rows = next(iter(dataset.batch(BATCH_SIZE)))

In [None]:
plot_rows_on_map(rows, field_name="radar_frames", num_rows=10, vmin=0, vmax=15.,
                 cmap="jet", cbar_label="Precipitation, mm/hr");

And plotting their masks, which will be more consistent with each other since they change less frequently.

In [None]:
plot_rows_on_map(rows, field_name="radar_mask", vmin=0, vmax=1, alpha=0.5, cmap="spring");
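
The OSGB extents used to place each crop on the map are consistent with the crop size in pixels, given the 1000-unit spatial resolution described in the reader documentation. A quick arithmetic check, assuming the extents mark the outer edges of the crop (`extent_size_in_pixels` is a hypothetical helper):

```python
def extent_size_in_pixels(left, right, bottom, top, units_per_pixel=1000):
  """Returns the (height, width) in pixels implied by an OSGB36 extent."""
  width = (right - left) // units_per_pixel
  height = (top - bottom) // units_per_pixel
  return height, width


# Extent values printed earlier for a row from random_crops_256.
size = extent_size_in_pixels(
    left=-9000, right=247000, bottom=555000, top=811000)
print(size)
```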

## Making predictions using model loaded from TF-Hub snapshots

Location of snapshots:

In [None]:
TFHUB_BASE_PATH = "gs://dm-nowcasting-example-data/tfhub_snapshots"

### Helper code for loading snapshots and making predictions with them

In [None]:
def load_module(input_height, input_width):
  """Load a TF-Hub snapshot of the 'Generative Method' model."""
  hub_module = tensorflow_hub.load(
      os.path.join(TFHUB_BASE_PATH, f"{input_height}x{input_width}"))
  # Note this has loaded a legacy TF1 model for running under TF2 eager mode.
  # This means we need to access the module via the "signatures" attribute. See
  # https://github.com/tensorflow/hub/blob/master/docs/migration_tf2.md#using-lower-level-apis
  # for more information.
  return hub_module.signatures['default']


def predict(module, input_frames, num_samples=1,
            include_input_frames_in_result=False):
  """Make predictions from a TF-Hub snapshot of the 'Generative Method' model.

  Args:
    module: One of the raw TF-Hub modules returned by load_module above.
    input_frames: Shape (T_in,H,W,C), where T_in = 4. Input frames to condition
      the predictions on.
    num_samples: The number of different samples to draw.
    include_input_frames_in_result: If True, will return a total of 22 frames
      along the time axis, the 4 input frames followed by 18 predicted frames.
      Otherwise will only return the 18 predicted frames.

  Returns:
    A tensor of shape (num_samples,T_out,H,W,C), where T_out is either 18 or 22
    as described above.
  """
  input_frames = tf.math.maximum(input_frames, 0.)
  # Add a batch dimension and tile along it to create a copy of the input for
  # each sample:
  input_frames = tf.expand_dims(input_frames, 0)
  input_frames = tf.tile(input_frames, multiples=[num_samples, 1, 1, 1, 1])

  # Sample the latent vector z for each sample:
  _, input_signature = module.structured_input_signature
  z_size = input_signature['z'].shape[1]
  z_samples = tf.random.normal(shape=(num_samples, z_size))

  inputs = {
      "z": z_samples,
      "labels$onehot" : tf.ones(shape=(num_samples, 1)),
      "labels$cond_frames" : input_frames
  }
  samples = module(**inputs)['default']
  if not include_input_frames_in_result:
    # The module returns the input frames alongside its sampled predictions, we
    # slice out just the predictions:
    samples = samples[:, NUM_INPUT_FRAMES:, ...]

  # Take positive values of rainfall only.
  samples = tf.math.maximum(samples, 0.)
  return samples


# Fixed values supported by the snapshotted model.
NUM_INPUT_FRAMES = 4
NUM_TARGET_FRAMES = 18


def extract_input_and_target_frames(radar_frames):
  """Extract input and target frames from a dataset row's radar_frames."""
  # We align our targets to the end of the window, and inputs precede targets.
  input_frames = radar_frames[-NUM_TARGET_FRAMES-NUM_INPUT_FRAMES : -NUM_TARGET_FRAMES]
  target_frames = radar_frames[-NUM_TARGET_FRAMES : ]
  return input_frames, target_frames


def horizontally_concatenate_batch(samples):
  n, t, h, w, c = samples.shape
  # N,T,H,W,C => T,H,N,W,C => T,H,N*W,C
  return tf.reshape(tf.transpose(samples, [1, 2, 0, 3, 4]), [t, h, n*w, c])
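
To make the slicing in `extract_input_and_target_frames` concrete, the sketch below applies the same negative-index slices to an array of frame indices for a 24-frame row. The constants repeat the values defined above; the variable names are illustrative only.

```python
import numpy as np

NUM_INPUT_FRAMES = 4    # Same values as defined above.
NUM_TARGET_FRAMES = 18

frame_indices = np.arange(24)  # Stand-in for the time axis of radar_frames.
input_idx = frame_indices[
    -NUM_TARGET_FRAMES - NUM_INPUT_FRAMES:-NUM_TARGET_FRAMES]
target_idx = frame_indices[-NUM_TARGET_FRAMES:]
print(input_idx, target_idx)
```

With 24 frames per row, the inputs are frames 2 to 5 and the targets are frames 6 to 23, so the first two frames of each row are not used by these helpers.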

### Making predictions for a row from the validation set (256x256 crops)

In [None]:
module = load_module(256, 256)
row = next(iter(reader(split="valid", variant="subsampled_tiles_256_20min_stride")))

In [None]:
num_samples = 5
input_frames, target_frames = extract_input_and_target_frames(row["radar_frames"])
samples = predict(module, input_frames,
                  num_samples=num_samples, include_input_frames_in_result=True)

We will plot an animation of 5 different samples, including the input frames first (so all 5 will start the same). You can see they end up in different places.

In [None]:
plot_animation(horizontally_concatenate_batch(samples), figsize=(4*num_samples, 4))
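
The layout produced by `horizontally_concatenate_batch` can be checked with a NumPy analogue on toy data. In this sketch (the `_np` function name and the toy shapes are hypothetical) each sample is filled with its own index, so that sample `i` should occupy columns `i*W` to `(i+1)*W` of the result.

```python
import numpy as np


def horizontally_concatenate_batch_np(samples):
  """NumPy analogue: N,T,H,W,C => T,H,N,W,C => T,H,N*W,C."""
  n, t, h, w, c = samples.shape
  return np.transpose(samples, (1, 2, 0, 3, 4)).reshape(t, h, n * w, c)


# Three toy "samples" of shape T=2, H=4, W=5, C=1, each filled with its index.
samples_np = np.zeros((3, 2, 4, 5, 1), dtype=np.float32)
for i in range(3):
  samples_np[i] = i
wide = horizontally_concatenate_batch_np(samples_np)
print(wide.shape, wide[0, 0, :, 0])
```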

### Making predictions on a row from the full-frame test set (1536x1280)

Warning: this will require more RAM than is available in a free public colab kernel, even if you reduce num_samples to 1.

In [None]:
# This is the same model with same parameters as above; we have had to export
# separate copies of the graph for each input size as the input size is
# unfortunately hardcoded into the graph as static shapes.
module = load_module(1536, 1280)

full_frame_test_set_row = next(iter(
    reader(split="test", variant="full_frame_20min_stride")))

In [None]:
num_samples = 2
input_frames, target_frames = extract_input_and_target_frames(
    full_frame_test_set_row["radar_frames"])
samples = predict(module, input_frames,
                  num_samples=num_samples, include_input_frames_in_result=True)

Plotting two different predicted samples following on from the input frames. The first sample:

In [None]:
row_with_predictions = full_frame_test_set_row.copy()
row_with_predictions["radar_frames"] = samples[0]
plot_animation_on_map(row_with_predictions)

And the second sample:

In [None]:
row_with_predictions["radar_frames"] = samples[1]
plot_animation_on_map(row_with_predictions)

The ground truth, for comparison, was plotted earlier as an example row from the test set.