JAX/Flax UNet implementation using the IceNet library for data download and post-processing of sea ice forecasts.

This notebook has been designed to be independent of other notebooks.

### Highlights
The key features of this notebook are:
* [1. Download](#1.-Download) 
* [2. Data Processing](#2.-Data-Processing)
* [3. Train](#3.-Train)
* [4. Prediction](#4.-Prediction)
* [5. Outputs and Plotting](#5.-Outputs-and-Plotting)

It currently uses a development version of the IceNet library (v0.2.8_dev).

To install, use the conda environment file `icenet-notebooks/pytorch/environment.yml` on a Linux system. It sets up the necessary PyTorch, TensorFlow, CUDA and other modules, which can be a tricky mix to get working manually:

```bash
conda env create -f environment.yml
```

### Contributions
#### PyTorch implementation of IceNet

Andrew McDonald ([icenet-gan](https://github.com/ampersandmcd/icenet-gan))

Bryn Noel Ubald (Refactor, updates for daily predictions and matching icenet library)

#### Notebook
Bryn Noel Ubald (author)

#### PyTorch Integration
Bryn Noel Ubald

Ryan Chan

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import shutil
import numpy as np
import pandas as pd

import torch
from torch.utils.data import DataLoader

# We also set the logging level so that we get some feedback from the API
import logging
logging.basicConfig(level=logging.INFO)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.set_float32_matmul_precision('medium')

JAX/FLAX imports

In [None]:
from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint

## 1. Download

In [None]:
from icenet.data.sic.mask import Masks
from icenet.data.interfaces.cds import ERA5Downloader
from icenet.data.sic.osisaf import SICDownloader

### Mask data

Create masks for masking data.

In [None]:
masks = Masks(north=False, south=True)
masks.generate(save_polarhole_masks=False)

### Climate and Sea Ice data

Download climate variables from ERA5 and sea ice concentration from OSI-SAF.

In [None]:
era5 = ERA5Downloader(
    var_names=["tas", "zg", "uas", "vas"],
    levels=[None, [250, 500], None, None],
    dates=[pd.to_datetime(date).date() for date in
           pd.date_range("2020-01-01", "2020-04-30", freq="D")],
    delete_tempfiles=False,
    max_threads=64,
    north=False,
    south=True,
    # NOTE: there appears to be a bug with the toolbox API at present (icenet#54)
    use_toolbox=False
)

era5.download()

In [None]:
sic = SICDownloader(
    dates=[pd.to_datetime(date).date() for date in
           pd.date_range("2020-01-01", "2020-04-30", freq="D")],
    delete_tempfiles=False,
    north=False,
    south=True,
    parallel_opens=False,
)

sic.download()

Re-grid ERA5 reanalysis data, and rotate wind vector data from ERA5 to align with EASE2 projection.

In [None]:
era5.regrid()
era5.rotate_wind_data()

## 2. Data Processing

Process downloaded datasets.

To make life easier, we first set up the train, validation and test dates.

In [None]:
processing_dates = dict(
    train=[pd.to_datetime(el) for el in pd.date_range("2020-01-01", "2020-03-31")],
    val=[pd.to_datetime(el) for el in pd.date_range("2020-04-03", "2020-04-23")],
    test=[pd.to_datetime(el) for el in pd.date_range("2020-04-01", "2020-04-02")],
)
processed_name = "notebook_api_jax_data"

In [None]:
processing_dates["test"]
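
Since the three date ranges are typed by hand, a quick check that they do not overlap can catch mistakes early. The following is a minimal sketch using only the standard library; `daily_range` and `check_disjoint` are hypothetical helpers written for illustration and are not part of IceNet.

```python
from datetime import date, timedelta


def daily_range(start, end):
    """Inclusive list of daily dates between start and end."""
    n_days = (end - start).days + 1
    return [start + timedelta(days=i) for i in range(n_days)]


# Same date boundaries as `processing_dates` above
splits = {
    "train": daily_range(date(2020, 1, 1), date(2020, 3, 31)),
    "val": daily_range(date(2020, 4, 3), date(2020, 4, 23)),
    "test": daily_range(date(2020, 4, 1), date(2020, 4, 2)),
}


def check_disjoint(splits):
    """Raise if any date appears in more than one split; return split sizes."""
    owner = {}
    for name, dates in splits.items():
        for day in dates:
            if day in owner:
                raise ValueError(f"{day} is in both {owner[day]} and {name}")
            owner[day] = name
    return {name: len(dates) for name, dates in splits.items()}


split_sizes = check_disjoint(splits)
print(split_sizes)  # {'train': 91, 'val': 21, 'test': 2}
```

Note that 2020 is a leap year, so January to March covers 91 days.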

Next, we create the data processors and configure them for the dataset we want to create.

In [None]:
from icenet.data.processors.era5 import IceNetERA5PreProcessor
from icenet.data.processors.meta import IceNetMetaPreProcessor
from icenet.data.processors.osi import IceNetOSIPreProcessor

pp = IceNetERA5PreProcessor(
    ["uas", "vas"],
    ["tas", "zg500", "zg250"],
    processed_name,
    processing_dates["train"],
    processing_dates["val"],
    processing_dates["test"],
    linear_trends=tuple(),
    north=False,
    south=True
)

osi = IceNetOSIPreProcessor(
    ["siconca"],
    [],
    processed_name,
    processing_dates["train"],
    processing_dates["val"],
    processing_dates["test"],
    linear_trends=tuple(),
    north=False,
    south=True
)

meta = IceNetMetaPreProcessor(
    processed_name,
    north=False,
    south=True
)

Next, we initialise the data processors using `init_source_data`, which scans the data source directories to work out what data is available for processing based on the parameters. Since we named the processed data `"notebook_api_jax_data"` above, it will create a data loader config file, `loader.notebook_api_jax_data.json`, in the current directory.

In [None]:
pp.init_source_data(
    lag_days=1,
)
pp.process()

osi.init_source_data(
    lag_days=1,
)
osi.process()

meta.process()

At this point the preprocessed data is ready to convert or create a configuration for the network dataset.

### Dataset creation

As with the `icenet_dataset_create` command, we can create a dataset configuration for training the network. As before, this can include cached data for the network in the form of a TFRecordDataset-compatible set of tfrecords. To achieve this we create the `IceNetDataLoader`, which can generate both `IceNetDataSet` configurations (which provide the necessary functionality for training and prediction) and individual data samples for direct usage.

In [None]:
from icenet.data.loaders import IceNetDataLoaderFactory

implementation = "dask"
loader_config = "loader.notebook_api_jax_data.json"
dataset_name = "notebook_api_jax_data"
lag = 1

dl = IceNetDataLoaderFactory().create_data_loader(
    implementation,
    loader_config,
    dataset_name,
    lag,
    n_forecast_days=7,
    north=False,
    south=True,
    output_batch_size=1,
    generate_workers=4)
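
The UNet later in this notebook is initialised with a 9-channel template input. One way to see where that number could come from is to count the variables configured in the preprocessors. The sketch below assumes that each time-varying variable contributes one channel per lag day and that the meta preprocessor adds three static channels (named `sin`, `cos` and `land` here). This layout and these names are assumptions for illustration, not taken from the IceNet source.

```python
def count_input_channels(abs_vars, anom_vars, sic_vars, meta_vars, lag_days):
    """Count input channels under the assumed layout: one channel per lag
    day for each time-varying variable, plus one per meta field."""
    time_varying = len(abs_vars) + len(anom_vars) + len(sic_vars)
    return time_varying * lag_days + len(meta_vars)


n_channels = count_input_channels(
    abs_vars=["uas", "vas"],
    anom_vars=["tas", "zg500", "zg250"],
    sic_vars=["siconca"],
    meta_vars=["sin", "cos", "land"],  # assumed meta channel names
    lag_days=1,
)
print(n_channels)  # 9
```

Under the same assumption, increasing `lag` to 2 would give 15 channels, so the template input shape used to initialise the model would need to change with it.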

At this point we can either use `generate` or `write_dataset_config_only` to produce a ready-to-go `IceNetDataSet` configuration. Both of these will generate a dataset config, `dataset_config.notebook_api_jax_data.json` (recall we set the dataset name to `notebook_api_jax_data` above).

In this case, for JAX, we read the data in directly through the PyTorch dataset interface, rather than using cached tfrecord inputs.

In [None]:
dl.write_dataset_config_only()

We can now create the IceNetDataSet object:

In [None]:
from icenet.data.dataset import IceNetDataSetPyTorch
dataset_config = f"dataset_config.{dataset_name}.json"

In [None]:
batch_size = 4
shuffle = False
persistent_workers=True
num_workers = 4

Create a data loader interface for JAX from PyTorch.

In [None]:
import jax.numpy as jnp
from jax.tree_util import tree_map
from torch.utils import data

def numpy_collate(batch):
  return tree_map(np.asarray, data.default_collate(batch))

class NumpyDataLoader(data.DataLoader):
  def __init__(self, dataset, batch_size=1,
                shuffle=False, sampler=None,
                batch_sampler=None, num_workers=0,
                pin_memory=False, drop_last=False,
                timeout=0, worker_init_fn=None):
    super().__init__(dataset,  # zero-argument super() stays correct if this class is subclassed
        batch_size=batch_size,
        shuffle=shuffle,
        sampler=sampler,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=numpy_collate,
        pin_memory=pin_memory,
        drop_last=drop_last,
        timeout=timeout,
        worker_init_fn=worker_init_fn)

class FlattenAndCast(object):
  def __call__(self, pic):
    return np.ravel(np.array(pic, dtype=jnp.float32))
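
To illustrate what collation does to a list of samples, here is a NumPy-only sketch. It assumes each sample is an `(inputs, targets, sample_weights)` tuple, as unpacked in the training step later on. The array shapes are small hypothetical stand-ins, and `stack_samples` is an illustrative helper rather than the `numpy_collate` function above.

```python
import numpy as np


def stack_samples(samples):
    """Collate a list of (inputs, targets, sample_weights) tuples into a
    tuple of batched NumPy arrays, one array per field."""
    return tuple(np.stack(field) for field in zip(*samples))


# Hypothetical small shapes for illustration
sample = (
    np.zeros((8, 8, 9), dtype=np.float32),     # inputs: (H, W, channels)
    np.zeros((8, 8, 7, 1), dtype=np.float32),  # targets: (H, W, days, 1)
    np.ones((8, 8, 7, 1), dtype=np.float32),   # sample weights
)
batch = stack_samples([sample] * 4)
batch_shapes = [array.shape for array in batch]
print(batch_shapes)  # [(4, 8, 8, 9), (4, 8, 8, 7, 1), (4, 8, 8, 7, 1)]

# Drop a trailing singleton axis, as the training loop below does
squeezed = [a[..., 0] if a.shape[-1] == 1 else a for a in batch]
squeezed_shapes = [array.shape for array in squeezed]
print(squeezed_shapes)  # [(4, 8, 8, 9), (4, 8, 8, 7), (4, 8, 8, 7)]
```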

## 3. Train

We implement a custom Flax module and a JAX training loop for training.

## Create dataloader

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [None]:
def get_datasets(batch_size, configuration_path, n_workers=1):
    # configure datasets and dataloaders
    train_dataset = IceNetDataSetPyTorch(configuration_path, mode="train")
    val_dataset = IceNetDataSetPyTorch(configuration_path, mode="val")
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers,
                                  persistent_workers=persistent_workers, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers,
                                persistent_workers=persistent_workers, shuffle=False)
    
    # Use the function arguments rather than notebook-level globals
    test_dataset = IceNetDataSetPyTorch(configuration_path, mode="test")
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=n_workers,
                                 persistent_workers=persistent_workers, shuffle=False)
    return train_dataset, train_dataloader, val_dataset, val_dataloader, test_dataset, test_dataloader

# IceNet UNet model

In [None]:
import jax
from flax import linen as nn
from jax.image import resize

class UNet(nn.Module):
    """UNet implementation for binary classification for IceNet"""

    padding = "SAME"
    filter_size = (3, 3)
    n_filters_factor = 0.1
    n_forecast_days = 7
    # kernel_init = nn.initializers.he_normal()
    # kernel_init(jax.random.key(42), (3, 3), jnp.float32)

    @nn.compact
    def __call__(self, x, train=True):
        """Forward pass"""
        filter_size = self.filter_size
        n_filters_factor = self.n_filters_factor
        n_forecast_days = self.n_forecast_days

        start_out_channels = 64
        reduced_channels = int(start_out_channels * n_filters_factor)
        channels = {
            start_out_channels * 2**pow: reduced_channels * 2**pow
            for pow in range(4)
        }
        
        conv1 = nn.Conv(channels[64], kernel_size=filter_size, padding=self.padding)(x)
        conv1 = nn.relu(conv1)
        conv1 = nn.Conv(channels[64], kernel_size=filter_size, padding=self.padding)(conv1)
        conv1 = nn.relu(conv1)
        bn1 = nn.BatchNorm(use_running_average=not train)(conv1)
        pool1 = nn.max_pool(bn1, window_shape=(2, 2), strides=(2, 2))
        
        conv2 = nn.Conv(channels[128], kernel_size=filter_size, padding=self.padding,)(pool1)
        conv2 = nn.relu(conv2)
        conv2 = nn.Conv(channels[128], kernel_size=filter_size, padding=self.padding,)(conv2)
        conv2 = nn.relu(conv2)
        bn2 = nn.BatchNorm(use_running_average=not train)(conv2)
        pool2 = nn.max_pool(bn2, window_shape=(2, 2), strides=(2, 2))

        conv3 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(pool2)
        conv3 = nn.relu(conv3)
        conv3 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(conv3)
        conv3 = nn.relu(conv3)
        bn3 = nn.BatchNorm(use_running_average=not train)(conv3)
        pool3 = nn.max_pool(bn3, window_shape=(2, 2), strides=(2, 2))

        conv4 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(pool3)
        conv4 = nn.relu(conv4)
        conv4 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(conv4)
        conv4 = nn.relu(conv4)
        bn4 = nn.BatchNorm(use_running_average=not train)(conv4)
        pool4 = nn.max_pool(bn4, window_shape=(2, 2), strides=(2, 2))

        conv5 = nn.Conv(channels[512], kernel_size=filter_size, padding=self.padding,)(pool4)
        conv5 = nn.relu(conv5)
        conv5 = nn.Conv(channels[512], kernel_size=filter_size, padding=self.padding,)(conv5)
        conv5 = nn.relu(conv5)
        bn5 = nn.BatchNorm(use_running_average=not train)(conv5)

        bn5 = resize(bn5, shape=(bn5.shape[0], bn5.shape[1]*2, bn5.shape[2]*2, bn5.shape[3]), method="nearest")
        up6 = nn.Conv(channels[256], kernel_size=(2, 2), padding=self.padding,)(bn5)
        merge6 = jnp.concatenate([bn4, up6], axis=-1)
        conv6 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(merge6)
        conv6 = nn.relu(conv6)
        conv6 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(conv6)
        conv6 = nn.relu(conv6)
        bn6 = nn.BatchNorm(use_running_average=not train)(conv6)

        bn6 = resize(bn6, shape=(bn6.shape[0], bn6.shape[1]*2, bn6.shape[2]*2, bn6.shape[3]), method="nearest")
        up7 = nn.Conv(channels[256], kernel_size=(2, 2), padding=self.padding,)(bn6)
        merge7 = jnp.concatenate([bn3, up7], axis=-1)
        conv7 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(merge7)
        conv7 = nn.relu(conv7)
        conv7 = nn.Conv(channels[256], kernel_size=filter_size, padding=self.padding,)(conv7)
        conv7 = nn.relu(conv7)
        bn7 = nn.BatchNorm(use_running_average=not train)(conv7)

        bn7 = resize(bn7, shape=(bn7.shape[0], bn7.shape[1]*2, bn7.shape[2]*2, bn7.shape[3]), method="nearest")
        up8 = nn.Conv(channels[128], kernel_size=(2, 2), padding=self.padding,)(bn7)
        merge8 = jnp.concatenate([bn2, up8], axis=-1)
        conv8 = nn.Conv(channels[128], kernel_size=filter_size, padding=self.padding,)(merge8)
        conv8 = nn.relu(conv8)
        conv8 = nn.Conv(channels[128], kernel_size=filter_size, padding=self.padding,)(conv8)
        conv8 = nn.relu(conv8)
        bn8 = nn.BatchNorm(use_running_average=not train)(conv8)

        bn8 = resize(bn8, shape=(bn8.shape[0], bn8.shape[1]*2, bn8.shape[2]*2, bn8.shape[3]), method="nearest")
        up9 = nn.Conv(channels[64], kernel_size=(2, 2), padding=self.padding,)(bn8)
        merge9 = jnp.concatenate([bn1, up9], axis=-1)

        conv9 = nn.Conv(channels[64], kernel_size=filter_size, padding=self.padding,)(merge9)
        conv9 = nn.Conv(channels[64], kernel_size=filter_size, padding=self.padding,)(conv9)
        conv9 = nn.Conv(channels[64], kernel_size=filter_size, padding=self.padding,)(conv9)

        final_layer = nn.Conv(n_forecast_days, kernel_size=(1, 1),)(conv9)

        # if not train:
        #     return nn.sigmoid(final_layer)
        return final_layer
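
The `channels` dictionary in the model maps the nominal UNet widths (64 to 512) to the reduced widths actually used. The sketch below reproduces that arithmetic in plain Python for the class defaults, and also tracks the spatial size through the four pooling layers, assuming a 432 x 432 input as in the template array used elsewhere in this notebook.

```python
start_out_channels = 64
n_filters_factor = 0.1

# int() truncates 6.4 down to 6
reduced_channels = int(start_out_channels * n_filters_factor)
channels = {
    start_out_channels * 2**level: reduced_channels * 2**level
    for level in range(4)
}
print(channels)  # {64: 6, 128: 12, 256: 24, 512: 48}

# Spatial size after each of the four 2x2 max-pooling layers
size = 432
sizes = [size]
for _ in range(4):
    size //= 2
    sizes.append(size)
print(sizes)  # [432, 216, 108, 54, 27]
```

Each nearest-neighbour resize in the decoder doubles the size again (27 to 54, and so on), which is what lets the skip connections concatenate with the matching encoder outputs.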

### View model layer

In [None]:
import jax
import jax.numpy as jnp  # JAX NumPy

unet = UNet()
print(unet.tabulate(jax.random.key(0), jnp.ones((1, 432, 432, 9)),
                   train=True))

# Newer flax version (>0.8.0)
# print(unet.tabulate(jax.random.key(0), jnp.ones((1, 432, 432, 9)),
#                    train=True,
#                    compute_flops=True, compute_vjp_flops=True))

# 3. Train

In [None]:
from clu import metrics
from flax.training import train_state  # Useful dataclass to keep train state
from flax import struct                # Flax dataclasses
import optax                           # Common loss functions and optimizers

### Define custom metrics

In [None]:
@struct.dataclass
class BinaryAccuracy(metrics.Metric):
    weighted_score: jnp.ndarray
    possible_score: jnp.ndarray

    @classmethod
    def empty(cls):
        return cls(weighted_score=jnp.array(0, jnp.float32), possible_score=jnp.array(0, jnp.int32))

    @classmethod
    def from_model_output(cls, *, predictions, targets, sample_weights, **kwargs):
        predictions = predictions > 0.15
        targets = targets > 0.15
        base_score = predictions == targets
        return cls(
            weighted_score = jnp.sum(base_score*sample_weights),
            possible_score = jnp.sum(sample_weights)
        )

    def compute(self):
        binary_accuracy = self.weighted_score / self.possible_score * 100
        return binary_accuracy

@struct.dataclass
class RMSE(metrics.Metric):
    rmse: float

    @classmethod
    def empty(cls):
        return cls(rmse=0.)

    @classmethod
    def from_model_output(cls, *, predictions, targets, sample_weights, **kwargs):
        predictions = 100*(predictions > 0.15)*sample_weights
        targets = 100*(targets > 0.15)*sample_weights
        rmse = jnp.sqrt(jnp.mean(jnp.square(predictions - targets)))
        return cls(
            rmse = rmse
        )

    def compute(self):
        return self.rmse
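
As a worked example of the weighted binary accuracy above, the following NumPy sketch applies the same 0.15 threshold to four hypothetical grid cells, one of which has zero weight (for example, a masked cell). The numbers are made up for illustration.

```python
import numpy as np


def weighted_binary_accuracy(predictions, targets, sample_weights, threshold=0.15):
    """Percentage of weighted cells where prediction and target agree on
    whether the value exceeds the threshold."""
    pred_ice = predictions > threshold
    true_ice = targets > threshold
    correct = pred_ice == true_ice
    return float(100.0 * np.sum(correct * sample_weights) / np.sum(sample_weights))


predictions = np.array([0.9, 0.1, 0.2, 0.05])
targets = np.array([1.0, 0.0, 0.0, 0.0])
sample_weights = np.array([1.0, 1.0, 1.0, 0.0])  # last cell is ignored

accuracy = weighted_binary_accuracy(predictions, targets, sample_weights)
print(round(accuracy, 2))  # 66.67
```

Two of the three weighted cells agree, so the score is 2/3. The fourth cell is also correct but carries no weight, so it does not count.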

In [None]:
@struct.dataclass
class Metrics(metrics.Collection):
  # accuracy: metrics.Accuracy
  accuracy: BinaryAccuracy
  rmse : RMSE
  loss: metrics.Average.from_output('loss')

## Define training state

In [None]:
from typing import Any

class TrainState(train_state.TrainState):
  # Ref: https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html#training-and-evaluation
  batch_stats: Any
  metrics: Metrics

def create_train_state(module, rng, learning_rate, momentum):
  """Creates an initial `TrainState`."""
  variables = module.init(rng, jnp.ones([1, 432, 432, 9]), train=False) # initialise variables by passing a template input
  params = variables['params'] # trainable parameters
  batch_stats = variables['batch_stats']
  tx = optax.sgd(learning_rate, momentum)
  return TrainState.create(
      apply_fn=module.apply,
      params=params,
      batch_stats=batch_stats,
      tx=tx,
      metrics=Metrics.empty()
      )

## Define training step

In [None]:
@jax.jit
def train_step(state, batch):
  """Train for a single step."""
  inputs, targets, sample_weights = batch
  def loss_fn(params):
    # Ref: https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html#training-and-evaluation
    logits, updates = state.apply_fn(
      {'params': params, 'batch_stats': state.batch_stats},
      x=inputs,
      train=True,
      mutable=['batch_stats']
      )
    batch_stats = updates['batch_stats']
    # loss = optax.softmax_cross_entropy_with_integer_labels(
    #     logits=logits, labels=batch['label']).mean()
    # print("Checking shapes:", logits.shape, targets.shape, sample_weights.shape)
    loss = optax.l2_loss(100*nn.sigmoid(logits)*sample_weights, 100*targets*sample_weights).mean()
    return loss, (logits, updates)
  grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
  (loss, (logits, updates)), grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  state = state.replace(batch_stats=updates['batch_stats'])

  # preds = nn.sigmoid(logits) > 0.15
  # targets = targets > 0.15
  # base_score = preds == targets
  # weighted_score = jnp.sum(base_score*sample_weights)
  # possible_score = jnp.sum(sample_weights)
  # metrics = {
  #   'loss': loss,
  #   'binary_accuracy': weighted_score / possible_score * 100
  # }
  return state#, metrics
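
The loss used in the training step can be checked by hand on a tiny example. The sketch below is a NumPy re-implementation for illustration only. It relies on `optax.l2_loss` computing `0.5 * (predictions - targets)**2` element-wise, and uses two made-up cells, the second of which has zero weight.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def weighted_l2_loss(logits, targets, sample_weights):
    """Mean of 0.5 * squared error between weighted percentages."""
    predictions = 100 * sigmoid(logits) * sample_weights
    scaled_targets = 100 * targets * sample_weights
    return float(np.mean(0.5 * (predictions - scaled_targets) ** 2))


logits = np.array([0.0, 0.0])          # sigmoid(0) = 0.5
targets = np.array([1.0, 0.0])
sample_weights = np.array([1.0, 0.0])  # second cell is masked out

loss = weighted_l2_loss(logits, targets, sample_weights)
print(loss)  # 625.0
```

The first cell contributes 0.5 * (50 - 100)^2 = 1250 and the masked cell contributes 0, but the mean is still taken over both cells, giving 625. Zero-weight cells therefore lower the average rather than being excluded from it.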

## Metric computation

In [None]:
@jax.jit
def compute_metrics(*, state, batch):
  inputs, targets, sample_weights = batch
  # Ref: https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html#training-and-evaluation
  logits, updates = state.apply_fn(
    {'params': state.params, 'batch_stats': state.batch_stats},
    x=inputs,
    train=True,
    mutable=['batch_stats']
    )
  batch_stats = updates['batch_stats']
#   loss = optax.softmax_cross_entropy_with_integer_labels(
#         logits=logits, labels=batch['label']).mean()
  loss = optax.l2_loss(
    predictions=100*nn.sigmoid(logits)*sample_weights,
    targets=100*targets*sample_weights
    ).mean()
  
  metric_updates = state.metrics.single_from_model_output(
    predictions=nn.sigmoid(logits)*sample_weights, targets=targets*sample_weights, sample_weights=sample_weights, loss=loss)
  state = state.replace(metrics=metric_updates)
  # metrics = state.metrics.merge(metric_updates)
  # state = state.replace(metrics=metrics)


  # preds = nn.sigmoid(logits) > 0.15
  # targets = targets > 0.15
  # base_score = preds == targets
  # weighted_score = jnp.sum(base_score*sample_weights)
  # possible_score = jnp.sum(sample_weights)
  # metrics = {
  #   'loss': loss,
  #   'accuracy': weighted_score / possible_score * 100
  # }
  # state = state.replace(metrics=metrics)
  return state

## Define data

In [None]:
num_epochs = 10
# batch_size = 4 # Defined earlier...

train_dataset, train_ds, val_dataset, val_ds, test_dataset, test_ds = get_datasets(configuration_path=dataset_config, batch_size=batch_size)

## Set random seeds

In [None]:
import tensorflow as tf

tf.random.set_seed(0)
init_rng = jax.random.key(0)

## Initialize the `TrainState`

In [None]:
learning_rate = 1e-4
momentum = 0.9

In [None]:
state = create_train_state(unet, init_rng, learning_rate, momentum)
del init_rng  # Must not be used anymore.

# Train and evaluate

In [None]:
import math
num_steps_per_epoch = math.ceil(len(processing_dates["train"]) / batch_size)
num_steps_per_epoch
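
For the dates used here, the arithmetic works out as follows. This sketch assumes the training dataset yields exactly one sample per training date (91 in total), which may not hold if some dates are dropped during dataset creation.

```python
import math

n_train_samples = 91  # assumed: one sample per training date
batch_size = 4

num_steps_per_epoch = math.ceil(n_train_samples / batch_size)
last_batch_size = n_train_samples - (num_steps_per_epoch - 1) * batch_size

print(num_steps_per_epoch)  # 23
print(last_batch_size)      # 3
```

The end-of-epoch check in the training loop compares the step counter against this value, so if the data loader yields fewer batches than expected, the validation and checkpointing block would not run.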

Initialise metrics history and checkpointing

In [None]:
metrics_history = {'train_loss': [],
                   'train_accuracy': [],
                   'train_rmse': [],
                   'val_loss': [],
                   'val_accuracy': [],
                   'val_rmse': []
                   }

In [None]:
ckpt_dir = '/tmp/flax_ckpt'

if os.path.exists(ckpt_dir):
    shutil.rmtree(ckpt_dir)  # Remove any existing checkpoints from the last notebook run.

In [None]:
from flax.training import orbax_utils

# orbax_checkpointer = orbax.checkpoint.AsyncCheckpointer(orbax.checkpoint.PyTreeCheckpointHandler())
# # options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True, best_fn=lambda metrics: metrics["val_rmse"], best_mode="min")
# options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=4, create=True)
# out_path = os.path.abspath('managed-checkpoint')
# checkpoint_manager = orbax.checkpoint.CheckpointManager(out_path, options=options)

options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=1, create=True, best_fn=lambda metrics: metrics, best_mode='min')
checkpoint_manager = orbax.checkpoint.CheckpointManager('/tmp/flax_ckpt/orbax/managed', options=options)


for epoch in range(1, num_epochs+1):
  for step, batch in enumerate(train_ds):

    # Convert to numpy
    batch = [element.numpy()[..., 0] if element.shape[-1] == 1 else element.numpy() for element in batch]
    # batch[0] = jnp.expand_dims(batch[0], axis=-1)
    # print("Init shapes", batch[0].shape, batch[1].shape, batch[2].shape)

    # Run optimization steps over training batches and compute batch metrics
    state = train_step(state, batch) # get updated train state (which contains the updated parameters)
    state = compute_metrics(state=state, batch=batch) # aggregate batch metrics

    # print(f"Step: {step+1}, num_steps_per_epoch: {num_steps_per_epoch}, check: {(step+1) % num_steps_per_epoch}")
    if (step+1) % num_steps_per_epoch == 0: # one training epoch has passed
      print("\tRunning validation set")
      for metric,value in state.metrics.compute().items(): # compute metrics
        metrics_history[f'train_{metric}'].append(value) # record metrics
      # state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch

      # Compute metrics on the validation set after each training epoch
      val_state = state
      for val_batch in val_ds:
        # print("Val shapes", val_batch[0].shape, val_batch[1].shape, val_batch[2].shape)
        val_batch = [element.numpy()[..., 0] if element.shape[-1] == 1 else element.numpy() for element in val_batch]
        # print("Val shapes", val_batch[0].shape, val_batch[1].shape, val_batch[2].shape)
        val_state = compute_metrics(state=val_state, batch=val_batch)

      for metric,value in val_state.metrics.compute().items():
        metrics_history[f'val_{metric}'].append(value)

      print(f"train epoch: {(epoch)}, "
            f"loss: {metrics_history['train_loss'][-1]}, "
            f"accuracy: {metrics_history['train_accuracy'][-1]}, "
            f"rmse: {metrics_history['train_rmse'][-1]}"
            )
      print(f"val epoch: {(epoch)}, "
            f"loss: {metrics_history['val_loss'][-1]}, "
            f"accuracy: {metrics_history['val_accuracy'][-1]}, "
            f"rmse: {metrics_history['val_rmse'][-1]}"
            )

      print("Checkpointing...")
      # Bundle everything together.
      ckpt = {'model': state}
      # save_args = orbax_utils.save_args_from_target(ckpt)
      save_args = orbax.checkpoint.args.StandardSave(state)
      # orbax_checkpointer.save('/tmp/flax_ckpt/orbax/single_save', ckpt, save_args=save_args, force=True)
      val_rmse = val_state.metrics.rmse.rmse.item()  # use the validation metrics, not the training ones
      # checkpoint_manager.save(epoch, save_kwargs={'save_args': save_args}, metrics=val_rmse)
      checkpoint_manager.save(epoch, args=save_args, metrics=val_rmse)
      
      state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch








#       # Bundle everything together.
#       ckpt = {'model': state}
#       save_args = orbax_utils.save_args_from_target(ckpt)
#       checkpoint_manager.save(step, ckpt, save_kwargs={'save_args': save_args}, force=True)






#       # print(type(state))
#       # # state['step'] = step
#       # state_save_args = jax.tree_map(lambda _: orbax.checkpoint.SaveArgs(), state)
#       # # print(state)
#       # # checkpoint_manager.save(step, state, metrics=state.metrics)
#       # checkpoint_manager.save(
#       #   step,
#       #   # # {
#       #   # #   "model_state": state,
#       #   # # }
#       #   # # metrics=state.metrics
#       #   # items={
#       #   #     'state': state,
#       #   # },
#       #   # # save_kwargs must be a dict with the same keys as items.
#       #   # # not all keys in items have to be provided, in which case default kwargs
#       #   # # are used each value must be a dict with keyword args passed to the
#       #   # # underlying CheckpointHandler for that item (see CheckpointManager
#       #   # # object construction)
#       #   # save_kwargs={'state': {
#       #   #     'save_args': state_save_args
#       #   # }},
#       #   )

#       state = state.replace(metrics=state.metrics.empty()) # reset train_metrics for next training epoch

checkpoint_manager.wait_until_finished()
print(f'Checkpointed epochs: {checkpoint_manager.all_steps()}')

In [None]:
os.listdir('/tmp/flax_ckpt/orbax/managed')  # Because max_to_keep=1, only one checkpoint (the best by the saved metric) is retained
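
To make the retention behaviour concrete, here is a simplified, pure-Python emulation of a keep-the-best policy with `best_mode='min'`. It illustrates the idea only and is not Orbax's implementation. The RMSE values are hypothetical.

```python
def retained_epochs(metric_by_epoch, max_to_keep=1):
    """Return the epochs that survive when only the `max_to_keep`
    lowest-metric checkpoints are kept."""
    ranked = sorted(metric_by_epoch, key=metric_by_epoch.get)
    return sorted(ranked[:max_to_keep])


# Hypothetical per-epoch RMSE values
hypothetical_rmse = {1: 30.0, 2: 12.5, 3: 18.0}

kept = retained_epochs(hypothetical_rmse)
print(kept)  # [2]
```

With these numbers, epoch 2 would be the surviving checkpoint even though epoch 3 was saved later.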

## 4. Prediction

Load latest saved checkpoint

In [None]:
epoch = checkpoint_manager.latest_step()
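# The checkpoint was saved with `StandardSave(state)`, so restore it with the
# matching `StandardRestore`, using the current `state` as the target structure.
state_restored = checkpoint_manager.restore(epoch, args=orbax.checkpoint.args.StandardRestore(state))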

In [None]:
@jax.jit
def pred_step(state, batch):
  inputs, targets, sample_weights = batch
  logits, updates = state.apply_fn(
    {'params': state.params, 'batch_stats': state.batch_stats},
    x=inputs,
    train=False,
    mutable=['batch_stats']
    )
  return nn.sigmoid(logits)

predictions = []
for step, test_batch in enumerate(test_ds):
  print(f"Batch: {step}")
  # Convert to numpy
  test_batch = [element.numpy()[..., 0] if element.shape[-1] == 1 else element.numpy() for element in test_batch]
  pred = pred_step(state_restored, test_batch)
  predictions.append(pred)

In [None]:
predictions[0].shape

## 5. Outputs and Plotting

Create prediction output directory

In [None]:
# dataset = "pytorch_notebook"
network_name = "api_jax_dataset"
output_name = "example_jax_forecast"
output_folder = os.path.join(".", "results", "predict", output_name,
                                "{}.{}".format(network_name, 42))
os.makedirs(output_folder, exist_ok=True)

Convert and output predictions to numpy files

In [None]:
idx = 0
for workers, prediction in enumerate(predictions):
    for batch in range(prediction.shape[0]):
        date = pd.Timestamp(test_dataset.dates[idx].replace('_', '-'))
        output_path = os.path.join(output_folder, date.strftime("%Y_%m_%d.npy"))
        forecast = prediction[batch, :, :, :]
        # # forecast_np = forecast.detach().cpu().numpy()
        np.save(output_path, forecast)
        idx += 1
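
The loop above walks through each batch and writes one `.npy` file per forecast date. The following self-contained sketch shows the same pattern with dummy data written to a temporary directory. The array shape and the `save_forecasts` helper are hypothetical, and the date strings are assumed to use underscores, as implied by the `replace('_', '-')` call.

```python
import os
import tempfile

import numpy as np


def save_forecasts(predictions, date_strings, output_folder):
    """Save every sample in a list of batched arrays to `<date>.npy`."""
    paths = []
    idx = 0
    for prediction in predictions:
        for sample in prediction:
            path = os.path.join(output_folder, f"{date_strings[idx]}.npy")
            np.save(path, np.asarray(sample))
            paths.append(path)
            idx += 1
    return paths


# Hypothetical: one batch holding two forecasts of shape (H, W, days)
predictions = [np.zeros((2, 8, 8, 7), dtype=np.float32)]
date_strings = ["2020_04_01", "2020_04_02"]

with tempfile.TemporaryDirectory() as tmp_dir:
    paths = save_forecasts(predictions, date_strings, tmp_dir)
    saved_names = sorted(os.listdir(tmp_dir))
    loaded_shape = np.load(paths[0]).shape

print(saved_names)   # ['2020_04_01.npy', '2020_04_02.npy']
print(loaded_shape)  # (8, 8, 7)
```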

Create a CSV file listing all the test dates we have predicted for; `icenet_output` uses it to generate the final netCDF output.

In [None]:
!printf "2020-04-01\n2020-04-02" | tee testdates.csv

In [None]:
!icenet_output -m -o results/predict example_jax_forecast notebook_api_jax_data testdates.csv

Plotting the forecast

In [None]:
import xarray as xr
import datetime as dt
from IPython.display import HTML

In [None]:
from icenet.plotting.video import xarray_to_video as xvid
from icenet.data.sic.mask import Masks

ds = xr.open_dataset("results/predict/example_jax_forecast.nc")
land_mask = Masks(south=True, north=False).get_land_mask()
ds.info()

In [None]:
forecast_date = ds.time.values[0]
fc = ds.sic_mean.isel(time=0).drop_vars("time").rename(dict(leadtime="time"))
fc['time'] = [pd.to_datetime(forecast_date) \
              + dt.timedelta(days=int(e)) for e in fc.time.values]

anim = xvid(fc, 15, figsize=4, mask=land_mask)
HTML(anim.to_jshtml())
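
The cell above converts the `leadtime` coordinate into actual dates by adding each lead time, in days, to the forecast initialisation date. A minimal standard-library sketch of that conversion follows, assuming lead times run from 1 to 7 days to match `n_forecast_days=7`. The exact coordinate values in the netCDF file are an assumption here.

```python
from datetime import date, timedelta

forecast_date = date(2020, 4, 1)
leadtimes = range(1, 8)  # assumed: lead times of 1..7 days

valid_dates = [forecast_date + timedelta(days=lead) for lead in leadtimes]
print(valid_dates[0], valid_dates[-1])  # 2020-04-02 2020-04-08
```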

___

#### Load original input dataset

This is the original input dataset (pre-normalisation) for comparison.

In [None]:
# Load original input dataset (domain not normalised)
xr.plot.contourf(xr.open_dataset("data/osisaf/south/siconca/2020.nc").isel(time=92).ice_conc, levels=50)

## Version
- IceNet Codebase: v0.2.8_dev