Copyright 2022 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
#@title License
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MetNet-2 Model Skeleton

This colab provides the model code for [MetNet-2](https://ai.googleblog.com/2021/11/metnet-2-deep-learning-for-12-hour.html) as well as general code for preprocessing data.

In [None]:
#@title Install Packages
# Colab shell commands. NOTE(review): both installs are unpinned, so the newest
# released versions of flax and ml_collections are used.
!pip install flax
!pip install ml_collections

In [None]:
#@title Imports

from typing import Optional, Tuple, Callable, Any, Iterable, Union, Dict
Array = Any
PRNGKey = Any
Shape = Iterable[int]
Dtype = Any
ModuleDef = Any

import datetime
import functools

from flax import linen as nn
from flax.linen import initializers
from flax.training import common_utils
import jax
from jax import lax
import jax.numpy as jnp
import numpy as onp
import scipy
import ml_collections

In [None]:
#@title Auxiliary Layers and Functions

def flatten_spatial_dim(image, factor):
  """Splits the two spatial axes of `image` into `factor`-sized tiles.

  Args:
    image: Array of shape (..., h, w, c).
    factor: Tile edge length.

  Returns:
    `image` reshaped to (..., h // factor, factor, w // factor, factor, c).
  """
  *leading, height, width, channels = image.shape
  tiled_shape = (*leading, height // factor, factor, width // factor, factor,
                 channels)
  return image.reshape(tiled_shape)


def downsample_nanmean(image, factor):
  """Downsamples the spatial axes by `factor` using a NaN-ignoring mean.

  Args:
    image: Array of shape (..., h, w, c).
    factor: Edge length of the square tiles that are averaged.

  Returns:
    Array of shape (..., h // factor, w // factor, c).
  """
  tiles = flatten_spatial_dim(image, factor)
  # Axes -4 and -2 are the two within-tile axes created by the reshape.
  return onp.nanmean(tiles, axis=(-4, -2))


def mrms_normalize(a):
  """Squashes raw MRMS values into the range [-1, 1].

  NaN, +/-inf and negative entries map to -1; every other value `v` maps to
  `tanh(log(v + 1) / 4)`. For example

     [NaN, inf, -inf, -50, -.5, 0., .2, 1., 2., 10.]
     ->
     [-1, -1, -1, -1, -1, 0, .046, .172, .268, .537]

  Args:
    a: NumPy array of raw values.

  Returns:
    NumPy array with the same shape as `a`.
  """
  # Invalid entries are deliberately sent through log(0) == -inf so that tanh
  # yields -1; silence the floating-point warnings NumPy would emit for that.
  with onp.errstate(divide='ignore', invalid='ignore'):
    shifted = onp.where(a < 0, 0, a + 1)
    shifted = onp.nan_to_num(shifted, posinf=0, neginf=0)
    return onp.tanh(onp.log(shifted) / 4)

def onehot_range(labels, num_classes, on_value=1.0, off_value=0.0):
  """Onehot but instead of a single 1, multiple 1's in a range is returned.

  Similar to common_utils.onehot but instead of 0...010...0, it returns
  0...01...10...0 where the provided range is inclusive for both beginning and
  end.

  Args:
    labels: ndarray-like, shape=(..., 2), holding the [first, last] class index
      of each range.
    num_classes: Number of classes.
    on_value: The value to use in the range.
    off_value: The value to use outside the range.

  Returns:
    float32 ndarray, shape=(..., num_classes).
  """
  labels = jnp.asarray(labels)
  if jnp.issubdtype(labels.dtype, jnp.floating):
    # Comparing e.g. bfloat16 labels against an integer arange would cast the
    # class indices to bfloat16 as well, which cannot represent every integer
    # above 256 and turns a single class into several neighbouring ones.
    labels = labels.astype(jnp.promote_types(labels.dtype, jnp.float32))

  class_ids = jnp.arange(num_classes)[None]
  at_or_after_start = labels[..., 0, None] <= class_ids
  at_or_before_end = labels[..., 1, None] >= class_ids
  in_range = at_or_after_start & at_or_before_end
  values = lax.select(in_range,
                      jnp.full(in_range.shape, on_value),
                      jnp.full(in_range.shape, off_value))
  return values.astype(jnp.float32)


# Kernel initializer used by the conditioning layer in `cond_func` and by
# `DenseStack`: uniform variance scaling over fan-out with scale 1/3.
DENSE_INIT = initializers.variance_scaling(
    scale=1.0 / 3, mode='fan_out', distribution='uniform')


def cond_func(x, cond_input, name):
  """Modulates `x` with a bias and a scale computed from `cond_input`.

  Args:
    x: Array to condition.
    cond_input: Conditioning array, or None to leave `x` untouched.
    name: Name given to the dense layer that produces the modulation.

  Returns:
    `x` unchanged when `cond_input` is None, otherwise `(x + bias) * scale`.
  """
  if cond_input is None:
    return x

  # A single bias-free projection yields both halves of the modulation.
  projection = nn.Dense(
      features=2 * x.shape[-1],
      use_bias=False,
      kernel_init=DENSE_INIT,
      name=name)
  scale, bias = jnp.split(projection(cond_input), 2, axis=-1)
  return (x + bias) * scale


class ConvLSTMCell(nn.Module):
  """Convolutional LSTM cell, scanned over axis 1 of its inputs.

  `__call__` is wrapped in `nn.scan` with `in_axes=1` and `out_axes=1`, so a
  single call consumes the whole sequence: one slice along axis 1 of `inputs`
  is processed per step with shared parameters, and the per-step hidden states
  are stacked back along axis 1.
  """
  # Spatial shape of both convolution kernels.
  kernel_size: Tuple[int, ...]
  # Non-linearity of the input, forget and output gates.
  gate_fn: Callable[[Array], Array] = nn.sigmoid
  # Non-linearity of the candidate values and of the cell output.
  activation_fn: Callable[[Array], Array] = nn.activation.tanh
  # Initializer of the input-to-hidden convolution kernel.
  kernel_init: Callable[[PRNGKey, Shape, Dtype],
                        Array] = initializers.lecun_normal()
  # Initializer of the hidden-to-hidden convolution kernel.
  recurrent_kernel_init: Callable[[PRNGKey, Shape, Dtype],
                                  Array] = initializers.orthogonal()
  # Initializer of the (single) bias, which lives in the recurrent convolution.
  bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = nn.initializers.zeros
  # If True, `inputs` already hold the four gate pre-activations and the input
  # convolution is not applied.
  precomputed_inputs: bool = False
  # Computation dtype of the convolutions.
  dtype: Dtype = jnp.float32

  @functools.partial(
      nn.scan,
      variable_broadcast='params',
      in_axes=1,
      out_axes=1,
      split_rngs={'params': False})
  @nn.compact
  def __call__(self, carry, inputs):
    """LSTM cell but with conv projections from the input and hidden state.

    Args:
      carry: Tuple `(c, h)` of cell and hidden state; the last axis of `h`
        defines the number of hidden features.
      inputs: Input of the current step, with channels on axis 3.

    Returns:
      `((new_c, new_h), new_h)`: the updated carry and the step output.
    """
    assert self.kernel_size is not None
    c, h = carry
    hidden_features = h.shape[-1]

    # input and recurrent layers are summed so only one needs a bias.
    dense_h = nn.Conv(
        features=4 * hidden_features,
        use_bias=True,
        kernel_size=self.kernel_size,
        kernel_init=self.recurrent_kernel_init,
        bias_init=self.bias_init,
        padding='SAME',
        dtype=self.dtype,
        name='h_all')

    dense_i = nn.Conv(
        features=4 * hidden_features,
        use_bias=False,
        padding='SAME',
        kernel_init=self.kernel_init,
        kernel_size=self.kernel_size,
        dtype=self.dtype,
        name='i_all')

    # One convolution produces all four gates; split them along the channels.
    res = dense_h(h)
    h_i, h_f, h_g, h_o = jnp.split(res, 4, axis=3)

    if self.precomputed_inputs:
      i_i, i_f, i_g, i_o = jnp.split(inputs, 4, axis=3)
    else:
      res = dense_i(inputs)
      i_i, i_f, i_g, i_o = jnp.split(res, 4, axis=3)

    i = self.gate_fn(i_i + h_i)
    f = self.gate_fn(i_f + h_f)
    g = self.activation_fn(i_g + h_g)
    o = self.gate_fn(i_o + h_o)
    new_c = f * c + i * g
    new_h = o * self.activation_fn(new_c)
    return (new_c, new_h), new_h

  @staticmethod
  def initialize_carry(batch_dims, hidden_size, init_fn=initializers.zeros):
    """Creates the initial `(c, h)` carry.

    The carry is built here directly instead of delegating to
    `nn.LSTMCell.initialize_carry`: that method was a static method taking
    `(rng, batch_dims, size, init_fn)` in older Flax releases but is an
    instance method with a different signature in current ones, so delegating
    breaks with an unpinned `pip install flax`.

    Args:
      batch_dims: Leading dimensions of the state (everything but features).
      hidden_size: Number of hidden features (size of the last axis).
      init_fn: Initializer called as `init_fn(key, shape)` for each state.

    Returns:
      Tuple `(c, h)` of arrays of shape `batch_dims + (hidden_size,)`.
    """
    # use dummy key since default state init fn is just zeros.
    key_c, key_h = jax.random.split(jax.random.PRNGKey(0))
    mem_shape = tuple(batch_dims) + (hidden_size,)
    return init_fn(key_c, mem_shape), init_fn(key_h, mem_shape)


class ResidualBlock(nn.Module):
  """Residual block of two dilated 3x3 convolutions with optional conditioning.

  The shortcut is projected with a 1x1 convolution whenever the number of
  input channels differs from `filters`.
  """
  # Number of output channels.
  filters: int
  # Scale the residual branch by a learned scalar initialised to zero.
  rezero: bool = False
  dtype: Dtype = jnp.float32
  # Number of GroupNorm groups; normalization is skipped when falsy.
  groupnorm: Optional[int] = None
  activation: Callable[[Array], Array] = nn.relu
  # Dilation applied along both spatial axes of every convolution.
  kernel_dilation: int = 1
  # Use `filters // 2` channels for the first convolution.
  half_channels: bool = False
  # Conditioning array handed to `cond_func`; None disables conditioning.
  cond_input: Optional[Array] = None
  # The three options below are not supported and must stay falsy.
  channel_dropout_rate: Optional[float] = None
  bias_scale: Optional[float] = False
  train: Optional[bool] = None

  @nn.compact
  def __call__(self, x):
    assert not self.channel_dropout_rate
    assert not self.train
    assert not self.bias_scale

    dilation = (self.kernel_dilation, self.kernel_dilation)
    make_conv = functools.partial(
        nn.Conv, use_bias=False, dtype=self.dtype, kernel_dilation=dilation)

    def maybe_norm(tensor):
      if not self.groupnorm:
        return tensor
      return nn.GroupNorm(num_groups=self.groupnorm, dtype=self.dtype)(tensor)

    shortcut = x
    if x.shape[-1] != self.filters:
      shortcut = make_conv(self.filters, (1, 1), name='proj_conv')(shortcut)

    first_filters = self.filters // 2 if self.half_channels else self.filters
    out = make_conv(first_filters, (3, 3), name='conv1')(x)
    out = maybe_norm(out)
    out = cond_func(out, self.cond_input, name='embed1')
    out = self.activation(out)

    out = make_conv(self.filters, (3, 3), name='conv2')(out)
    if self.rezero:
      alpha = self.param('alpha', nn.initializers.zeros, (1,), self.dtype)
      out = out * alpha
    out = maybe_norm(out)
    out = cond_func(out, self.cond_input, name='embed2')

    # Where the shortcut is added depends on conditioning so that the block
    # stays compatible with trained checkpoints.
    if self.cond_input is not None:
      return self.activation(out) + shortcut
    return self.activation(out + shortcut)


class ResidualStack(nn.Module):
  """Chain of residual blocks whose outputs are summed through dense skips."""
  num_blocks: int
  filters: int
  groupnorm: Optional[int] = None
  # Dilations are cycled through, one per block.
  kernel_dilations: Tuple[int, ...] = (1, 2, 4, 8, 16, 32)
  dtype: Dtype = jnp.float32
  # Everything below is forwarded to the ResidualBlock constructor.
  half_channels: Optional[bool] = None
  cond_input: Optional[Array] = None
  channel_dropout_rate: Optional[float] = None
  bias_scale: Optional[bool] = False
  train: Optional[bool] = None
  extra_resnet_block_kwargs: Optional[Dict[str, Any]] = None

  @nn.compact
  def __call__(self, x):
    """Runs the blocks sequentially and accumulates dense skip projections.

    Args:
      x: The input array.

    Returns:
      relu of the sum of a dense projection of the input and of every block
      output.
    """
    shared_block_kwargs = dict(
        filters=self.filters,
        groupnorm=self.groupnorm,
        dtype=self.dtype,
        half_channels=self.half_channels,
        cond_input=self.cond_input,
        channel_dropout_rate=self.channel_dropout_rate,
        bias_scale=self.bias_scale,
        train=self.train)
    make_block = functools.partial(
        ResidualBlock,
        **shared_block_kwargs,
        **(self.extra_resnet_block_kwargs or {}))
    make_skip = functools.partial(
        nn.Dense, features=self.filters, dtype=self.dtype)

    num_dilations = len(self.kernel_dilations)
    skip_sum = make_skip(name='skip_init')(x)
    for index in range(self.num_blocks):
      block = make_block(
          name=f'block{index}',
          kernel_dilation=self.kernel_dilations[index % num_dilations])
      x = block(x)
      skip_sum = skip_sum + make_skip(name=f'skip{index}')(x)

    return nn.relu(skip_sum)


class DenseStack(nn.Module):
  """Bias-free MLP: `num_layers` relu layers followed by a linear projection."""
  # Width of every hidden layer.
  features: int = 4096
  # Width of the final linear projection.
  output_channels: int = 512
  # Number of hidden (dense + relu) layers.
  num_layers: int = 1
  dtype: Dtype = jnp.float32

  @nn.compact
  def __call__(self, inputs):
    """Returns `(projection, last_hidden_activation)` for `inputs`."""
    make_dense = functools.partial(
        nn.Dense, use_bias=False, kernel_init=DENSE_INIT, dtype=self.dtype)

    hidden = inputs
    for _ in range(self.num_layers):
      hidden = nn.relu(make_dense(features=self.features)(hidden))

    projected = make_dense(features=self.output_channels)(hidden)
    return projected, hidden

In [None]:
#@title Main Model

def metnet_encoder(inputs,
                   target_index,
                   model_target_index,
                   hps,
                   is_initializing):
  """MetNet encoder architecture.

  Conditions the inputs on the requested lead time, aggregates them over time
  with a ConvLSTM, encodes the result with two residual stacks separated by a
  center crop, then crops to the target area and upsamples by repetition.

  Args:
    inputs: Sequence of 5-D arrays; entries that are None or 1-D are dropped.
      Axis 1 is treated as time (shorter inputs are zero-padded at the front)
      and axis 2 as the spatial size, which must be the same for all inputs.
      Presumably laid out as [batch, time, height, width, channels] - confirm.
    target_index: Only its leading dimension is read, as the batch size.
    model_target_index: Array of shape (..., 2) holding the inclusive range of
      lead-time classes passed to `onehot_range`.
    hps: Config providing the model hyperparameters read below.
    is_initializing: Not used by this function.

  Returns:
    Tuple `(repeated_encoded_input, target_embed)`: the encoded, cropped and
    4x repeat-upsampled features, and the lead-time embedding used for
    conditioning.
  """
  lstm_channels = hps.lstm_channels
  encoder_channels = hps.encoder_channels
  encoder_num_blocks = hps.encoder_num_blocks
  num_time_classes = len(hps.mrms_target_tds)

  if hps.shift_target_index:
    # For predicting cumulative, we use an alternative one hot encoding which
    # has 30 ones instead of a single one, each representing 2 min. I.e. a full
    # 30 ones is the prediction of a full cumulative hour. By adding 29 to the
    # target index, there will also be 30 ones for predicting 2 min to 58 min.
    # This is especially important for our mixed models so the model can
    # distinguish between predicting rate and cumulative precipitation.
    num_time_classes = num_time_classes + 29
    model_target_index = model_target_index + 29

  # configurations for conditioning target index.
  target_cond = hps.target_cond
  target_features = hps.target_features
  target_n_layers = hps.target_n_layers
  cond_per_layer = hps.cond_per_layer

  groupnorm = hps.groupnorm
  target_size = hps.target_size

  dtype = jnp.bfloat16 if hps.dtype == 'bfloat16' else jnp.float32

  # Remove disabled inputs (constant zeros).
  # Flax (non-Linen) doesn't support None inputs.
  inputs = [input_ for input_ in inputs
            if input_ is not None and input_.ndim != 1]

  batch_size = target_index.shape[0]
  num_steps = max([input_.shape[1] for input_ in inputs])
  # `.item()` only succeeds when every input has the same size along axis 2.
  input_size = onp.unique([input_.shape[2] for input_ in inputs]).item()

  def target_index_block(model_target_index):
    # Range one-hot of the lead time, broadcast over time and space.
    target_index = onehot_range(
        model_target_index, num_classes=num_time_classes)
    target_index = jnp.reshape(target_index, (-1, 1, 1, 1, num_time_classes))
    return jnp.broadcast_to(
        target_index,
        (batch_size, num_steps, input_size, input_size, num_time_classes))

  def concatenate_block(*inputs):
    # Zero-pad every input at the front of axis 1 up to `num_steps`, then
    # concatenate all of them along the last axis.
    padded_inputs = []
    for input_ in inputs:
      pad_width = [(0, 0), (num_steps - input_.shape[1], 0), (0, 0), (0, 0),
                   (0, 0)]
      padded_inputs.append(jnp.pad(input_, pad_width, mode='constant'))
    return jnp.concatenate(padded_inputs, axis=-1)

  assert target_cond in ['concat', 'dense_scale', 'dense_bias']

  # For the paper (see appendix) we tried multiple types of conditioning.
  # `dense_scale` is what ended up being used in the paper.
  # `dense_bias` is "Add, No Mult" in the paper.
  # `concat` is "No Add, No Mult" in the paper.
  if target_cond == 'concat':
    target_index = target_index_block(model_target_index)
    inputs = concatenate_block(*(inputs + [target_index]))
    target_embed = target_index
  elif 'dense' in target_cond:
    inputs = concatenate_block(*inputs)
    target_index = onehot_range(
        model_target_index, num_classes=num_time_classes)
    target_index = jnp.reshape(target_index, (-1, 1, 1, 1, num_time_classes))

    # `dense_scale` needs twice the channels: one half is added to the inputs,
    # the other half multiplies them.
    output_channels = inputs.shape[-1]
    if target_cond == 'dense_scale':
      output_channels = 2 * inputs.shape[-1]
    out, target_embed = DenseStack(
        output_channels=output_channels,
        features=target_features,
        num_layers=target_n_layers)(target_index)
    if target_cond == 'dense_scale':
      target_bias, target_scale = jnp.split(out, 2, axis=-1)
      inputs += target_bias
      inputs *= target_scale
    elif target_cond == 'dense_bias':
      inputs += out
    # Drop the singleton axis 1 introduced by the reshape above.
    target_embed = jnp.squeeze(target_embed, axis=1)

  # Layers definition
  init_carry = ConvLSTMCell.initialize_carry(
      (batch_size, input_size, input_size), lstm_channels)
  lstm_cell = ConvLSTMCell(
      kernel_size=(3, 3), dtype=dtype, name='conv_lstm0')

  def make_encoder(i):
    return ResidualStack(
        num_blocks=encoder_num_blocks[i],
        filters=encoder_channels[i],
        kernel_dilations=(1, 2, 4, 8, 16, 32, 64, 128),
        half_channels=False,
        groupnorm=groupnorm,
        cond_input=target_embed if cond_per_layer else None,
        name=f'encoder{i}')

  def crop_to_target(x, target_size, downsampling_ratio):
    # Center crop of both spatial axes to `target_size // downsampling_ratio`.
    # NOTE(review): if `x` is smaller than that, `x_start` is negative and the
    # slice wraps around instead of failing, silently yielding a smaller crop.
    ds_ts = target_size // downsampling_ratio
    x_start = (x.shape[1] - ds_ts) // 2
    return x[:, x_start:x_start + ds_ts, x_start:x_start + ds_ts, :]

  def upsample_by_repeat(x, times):
    # X = [B, H, W, F]
    b, h, w, f = x.shape
    x = x.reshape((b, h, 1, w, 1, f))
    x = jnp.broadcast_to(x, (b, h, times, w, times, f))
    x = x.reshape((b, h * times, w * times, f))
    return x

  # The final carry (c, h) of the scan summarises the whole input sequence.
  carry, _ = lstm_cell(init_carry, inputs)
  time_encoded_input = jnp.concatenate(carry, axis=-1)

  # First stage
  encoded_input = make_encoder(0)(time_encoded_input)

  # Second stage after first crop
  start_crop = input_size // 4
  size_crop = input_size // 2
  encoded_input = encoded_input[:, start_crop:start_crop + size_crop,
                                start_crop:start_crop + size_crop, :]

  if encoder_num_blocks[1] > 0:
    encoded_input = make_encoder(1)(encoded_input)

  resolution = 4
  encoded_input = crop_to_target(encoded_input, target_size, resolution)
  repeated_encoded_input = upsample_by_repeat(encoded_input, resolution)

  return repeated_encoded_input, target_embed

class MetNet2(nn.Module):
  """MetNet-2: encoder, residual upsampler and a two-layer dense head."""
  # Names of the keyword arguments `__call__` expects (besides `train`).
  input_keys = [
      'mrms',
      'mrms_cumulative',
      'goes',
      'hrrr',
      'target_index',
      'model_target_index',
  ]
  # Size of the last axis of the returned logits.
  num_output_channels: int = 512
  # Model hyperparameters.
  hps: Optional[ml_collections.ConfigDict] = None

  @nn.compact
  def __call__(self, mrms, mrms_cumulative, goes, hrrr,
               target_index, model_target_index, train):
    """Returns logits with `num_output_channels` entries on the last axis."""
    config = self.hps
    # Detect if we're initializing by absence of params.
    is_initializing = not self.has_variable('params', 'upsampler')

    if config.dtype == 'bfloat16':
      compute_dtype = jnp.bfloat16
    else:
      compute_dtype = jnp.float32

    encoder_inputs = [mrms, mrms_cumulative, goes, hrrr]
    print(f'Dimensions before encoding: {[a.shape for a in encoder_inputs]}')
    encoded, target_embed = metnet_encoder(
        encoder_inputs, target_index, model_target_index, config,
        is_initializing)
    print(f'Dimensions after encoding: {encoded.shape}')

    # Upsampler stage
    upsampler_cond = target_embed if config.upsampler_cond_per_layer else None
    upsampler = ResidualStack(
        num_blocks=config.upsampler_num_blocks,
        kernel_dilations=(1,),
        filters=config.upsampler_channels,
        half_channels=False,
        cond_input=upsampler_cond,
        groupnorm=config.groupnorm,
        name='upsampler')
    upsampled = upsampler(encoded)

    # Output head; the final layer is always computed in float32.
    pre_final_dense = nn.Dense(
        features=config.pre_final_size_1km_resolution,
        name='prefinal',
        dtype=compute_dtype)
    final_dense = nn.Dense(
        features=self.num_output_channels,
        name='final',
        dtype=jnp.float32)
    outputs = final_dense(nn.relu(pre_final_dense(upsampled)))

    print(f'Final output shape: {outputs.shape}')

    return outputs

In [None]:
#@title Hyperparameters/config

hps = ml_collections.ConfigDict()
# Any value other than 'bfloat16' makes the model compute in float32.
hps.dtype = 'bfloat16'

# Lead times the model can be asked for: every 2 minutes up to 12 hours.
hps.mrms_target_tds = list(range(2, 12 * 60 + 1, 2))
# 513 evenly spaced bin edges from 0 to 102.4, i.e. 512 output classes.
hps.bins = onp.linspace(0, 102.4, 513, dtype=onp.float32)

# Model parameters.
hps.lstm_channels = 128
hps.encoder_channels = [384, 384]
hps.encoder_num_blocks = [16, 8]
hps.upsampler_channels = 512
hps.upsampler_num_blocks = 2
hps.remat = False
hps.pre_final_size_1km_resolution = 4096

# Lead-time conditioning; see `metnet_encoder` for how these are used.
hps.shift_target_index = True
hps.target_cond = 'dense_scale'
hps.target_n_layers = 2
hps.target_features = 2048
hps.target_size = 512
hps.cond_per_layer = True
hps.upsampler_cond_per_layer = True
hps.groupnorm = None

# Training parameters.
# None of these are read by the code in this notebook.
hps.train_steps = 500000
hps.batch_size = 16
hps.sampling_priority_exp = 2
hps.train_mask = 'roi'
hps.learning_rate = 2e-5
hps.optimizer = 'adam'
hps.lr_schedule = 'none'
hps.optimizer_beta = 0.9
# NOTE(review): this is a string while the other step counts are ints -
# confirm whether that is intended.
hps.optimizer_decay_steps = '100000'
hps.polyak_decay = 0.9999
hps.weight_decay = 1e-1

# Freeze the config so it cannot be modified from here on.
hps = ml_collections.FrozenConfigDict(hps)

In [None]:
#@title Input and Preprocessing

# All the spatial inputs are expected to be on an evenly spaced 1km grid.

batch_size = 1 # @param {type: 'integer'}
spatial_dim = 512 # @param {type: 'integer'}
#@markdown ##### in km. NOTE: 2048 for the full mode, 512 is for testing


def _zero_placeholder(*shape):
  """Returns an all-zero float32 array of `shape` standing in for real data."""
  return onp.zeros(shape, onp.float32)


# Time of the sample.
sample_time = [datetime.datetime(2020, 3, 1, 12, 0, 0)] * batch_size

# Timedelta to predict, in minutes.
td = onp.array([60] * batch_size)

# NOTE: Every array below is a placeholder and should be replaced with real
# data.

# Elevation.
elevation = _zero_placeholder(batch_size, 1, spatial_dim, spatial_dim, 1)

# Longitude/latitude.
lon_lat = _zero_placeholder(batch_size, 1, spatial_dim, spatial_dim, 2)

# MRMS precipitation rate input with 11 time slices:
# [-90 min, -75 min, -60 min, -45 min, -30 min, -25 min, -20 min, -15 min,
#  -10 min, -5 min, 0 min]
mrms_rate = _zero_placeholder(batch_size, 11, spatial_dim, spatial_dim, 1)

# MRMS cumulative precipitation input with 2 time slices: [-60 min, 0 min]
mrms_cumulative = _zero_placeholder(
    batch_size, 2, spatial_dim, spatial_dim, 1)

# GOES satellite input with 3 time slices: [-30 min, -15 min, 0 min]
goes = _zero_placeholder(batch_size, 3, spatial_dim, spatial_dim, 16)

# HRRR assimilated input.
# NOTE(review): the array has 3 time slices, whereas this input was described
# as having 2 ([-60 min, 0 min]) - confirm which one is intended.
hrrr = _zero_placeholder(batch_size, 3, spatial_dim, spatial_dim, 612)

# Example of a target with precipitation (no time axis).
target_precipitation = _zero_placeholder(
    batch_size, spatial_dim, spatial_dim, 1)

print(f'Sample for {td} min prediction from time {sample_time}')
print()
print('Raw sample dimensions:')
for _label, _array in (
    ('Elevation', elevation),
    ('Longitude/lattitude', lon_lat),
    ('MRMS precipitation rate', mrms_rate),
    ('MRMS cumulative rate', mrms_cumulative),
    ('GOES', goes),
    ('HRRR', hrrr),
    ('Target', target_precipitation),
):
  print(f'{_label}: {_array.shape}')


# Preprocessing the data.
elevation = elevation / 2000.

lon_lat = lon_lat / 1000.

# Encode the sample time as (hour, day, month) and repeat it over the grid.
sample_time = onp.array([[t.hour, t.day, t.month] for t in sample_time],
                        dtype=jnp.bfloat16)
sample_time = onp.tile(
    sample_time.reshape(batch_size, 1, 1, 1, 3),
    [1, 1, spatial_dim, spatial_dim, 1])

# Static features shared by all inputs, downsampled by 4 like the inputs.
rest = downsample_nanmean(
    onp.concatenate([elevation, lon_lat, sample_time], axis=-1), 4)


def _append_static_features(frames):
  """Appends `rest` to every time slice of `frames` and casts to bfloat16."""
  # Tile along axis 1 to fit the time dimension of `frames`.
  tiled_rest = onp.tile(rest, [1, frames.shape[1], 1, 1, 1])
  return onp.concatenate([frames, tiled_rest], axis=-1).astype(jnp.bfloat16)


mrms_rate = _append_static_features(
    mrms_normalize(downsample_nanmean(mrms_rate, 4)))

mrms_cumulative = _append_static_features(
    mrms_normalize(downsample_nanmean(mrms_cumulative, 4)))

# GOES and HRRR are only downsampled here; the data should be standardized.
goes = _append_static_features(downsample_nanmean(goes, 4))

hrrr = _append_static_features(downsample_nanmean(hrrr, 4))

# Target needs to be logits based on one hot encoding of the precipitation
# bins: the log of the one-hot vector is 0 for the true bin and -inf elsewhere.
num_bins = len(hps.bins) - 1
bin_index = onp.digitize(target_precipitation[..., 0], hps.bins) - 1
bin_index = onp.clip(bin_index, 0, num_bins - 1)
target_precipitation = onp.log(common_utils.onehot(bin_index, num_bins))

# Target index. The time delta the model has to predict, as the index of the
# 2-minute lead-time class (2 min -> 0, 4 min -> 1, ...).
is_cumulative = False  # True, if the model should predict 60 min cumulative precipitation.
target_index = td // 2 - 1

if is_cumulative:
  # For 1hr cumulative precipitation the model gets a range of 30 1's as input
  # instead of a single 1 (onehot encoding) to distinguish between the two.
  model_target_index = onp.stack([target_index - 29, target_index], axis=-1)
else:
  model_target_index = onp.stack([target_index, target_index], axis=-1)

# Keep the indices as integers. bfloat16 cannot represent every integer above
# 256, so casting to it would shift the encoded lead time for long forecasts.
target_index = target_index.astype(onp.int32)
model_target_index = model_target_index.astype(onp.int32)

print()
print('Preprocessed sample dimensions:')
print(f'MRMS precipitation rate: {mrms_rate.shape}')
print(f'MRMS cumulative rate: {mrms_cumulative.shape}')
print(f'GOES: {goes.shape}')
print(f'HRRR: {hrrr.shape}')
print(f'Target: {target_precipitation.shape}')

# Keyword arguments of `MetNet2.__call__` (everything except `train`).
inputs = {
    'mrms': mrms_rate,
    'mrms_cumulative': mrms_cumulative,
    'goes': goes,
    'hrrr': hrrr,
    'target_index': target_index,
    'model_target_index': model_target_index,
}

In [None]:
#@title Inference

# One output class per precipitation bin.
num_output_channels = len(hps.bins) - 1

module = MetNet2(hps=hps, num_output_channels=num_output_channels)

# `train` is passed as a static argument to the jitted init/apply functions.
rng = jax.random.PRNGKey(0)
jit_init = jax.jit(module.init, static_argnames='train')
variables = jit_init(rng, **inputs, train=False)
params = variables['params']

num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f'Number of trainable parameters: {num_params}')

apply = jax.jit(module.apply, static_argnames='train')
result = apply({'params': params}, **inputs, train=False)

In [None]:
#@title Probabilities from output

# Minimum millimeters of precipitation one wants the probability of.
# Must be one of the bin edges in `hps.bins`.
mm = 2.

# The bin edges are float32, so match with a tolerance rather than with `==`,
# and fail with a clear message when `mm` is not a bin edge.
matches = onp.flatnonzero(onp.isclose(hps.bins, mm))
if matches.size != 1:
  raise ValueError(f'mm={mm} is not one of the bin edges in hps.bins.')
i = matches.item()

probs = nn.softmax(result, axis=-1)
# Reversed cumulative sum: entry i is the total probability of bins i and up.
mm_probs = jnp.cumsum(probs[..., ::-1], axis=-1)[..., ::-1][..., i]

print(f'Shape of probabilities for at least {mm} mm of precipitation rate/1hr accumulated: {mm_probs.shape}')