In [1]:
# Quick hack: when the kernel starts next to this notebook, switch to the
# pipeline folder (assumed to have been created as per 01.cli_demonstration.ipynb)
# so that the relative configuration paths used below resolve.
import os

if os.path.exists("pytorch_example.ipynb"):
    os.chdir("../../notebook-pipeline")

print(f"Running in {os.getcwd()}")

%matplotlib inline

Running in /data/hpcdata/users/rychan/notebooks/notebook-pipeline


In [2]:
import numpy as np
import pandas as pd
import torch
import logging

from icenet.data.loaders import IceNetDataLoaderFactory
from icenet.data.dataset import IceNetDataSet
from icenet_pytorch_dataset import IceNetDataSetPyTorch

from train_icenet_unet import train_icenet_unet
from test_icenet_unet import test_icenet_unet

# We also set the logging level so that we get some feedback from the API
import logging
logging.basicConfig(level=logging.INFO)

2023-08-30 17:07:17.448588: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-08-30 17:07:17.511983: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Report the PyTorch build and whether a CUDA device can be used.
print('A', torch.__version__)
print('B', torch.cuda.is_available())
print('C', torch.backends.cudnn.enabled)
# Fall back to the CPU instead of raising when no CUDA device is present,
# so the notebook still runs top to bottom on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print('D', torch.cuda.get_device_properties(device))
else:
    print('D', 'no CUDA device available, using CPU')

A 2.0.1+cu117
B True
C True
D _CudaDeviceProperties(name='NVIDIA A2', major=8, minor=6, total_memory=14938MB, multi_processor_count=10)


In [4]:
!nvidia-smi

Wed Aug 30 17:08:13 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A2           On   | 00000000:98:00.0 Off |                    0 |
|  0%   32C    P8     5W /  60W |      2MiB / 15356MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Dataset creation

This section assumes that [03.library_usage](03.library_usage.ipynb) has already been run, so that the `loader.notebook_api_data.json` file exists in the current directory.

In [5]:
# Settings for the IceNet data loader.
# loader_config is the loader configuration file that the text above assumes
# was produced by the 03.library_usage notebook; dataset_name appears in the
# name of the dataset configuration file written further down.
implementation = "dask"
loader_config = "loader.notebook_api_data.json"
dataset_name = "pytorch_notebook"
lag = 1

# Build a southern-hemisphere loader (north=False, south=True) with a 7-day
# forecast horizon. The four positional arguments are passed in the order the
# factory expects them - NOTE(review): confirm against create_data_loader's
# signature before reordering or converting them to keywords.
dl = IceNetDataLoaderFactory().create_data_loader(
    implementation,
    loader_config,
    dataset_name,
    lag,
    n_forecast_days=7,
    north=False,
    south=True,
    output_batch_size=4,
    generate_workers=8)

INFO:root:Loading configuration loader.notebook_api_data.json


In [6]:
dl._n_forecast_days

7

In [7]:
dl._config

{'sources': {'era5': {'name': 'notebook_api_data',
   'implementation': 'IceNetERA5PreProcessor',
   'anom': ['tas', 'zg500', 'zg250'],
   'abs': ['uas', 'vas'],
   'dates': {'train': ['2020_01_01',
     '2020_01_02',
     '2020_01_03',
     '2020_01_04',
     '2020_01_05',
     '2020_01_06',
     '2020_01_07',
     '2020_01_08',
     '2020_01_09',
     '2020_01_10',
     '2020_01_11',
     '2020_01_12',
     '2020_01_13',
     '2020_01_14',
     '2020_01_15',
     '2020_01_16',
     '2020_01_17',
     '2020_01_18',
     '2020_01_19',
     '2020_01_20',
     '2020_01_21',
     '2020_01_22',
     '2020_01_23',
     '2020_01_24',
     '2020_01_25',
     '2020_01_26',
     '2020_01_27',
     '2020_01_28',
     '2020_01_29',
     '2020_01_30',
     '2020_01_31',
     '2020_02_01',
     '2020_02_02',
     '2020_02_03',
     '2020_02_04',
     '2020_02_05',
     '2020_02_06',
     '2020_02_07',
     '2020_02_08',
     '2020_02_09',
     '2020_02_10',
     '2020_02_11',
     '2020_02_12',
   

We generate a config-only dataset (no cached data is generated), which is saved to `dataset_config.pytorch_notebook.json`.

In [8]:
dl.write_dataset_config_only()

INFO:root:Writing dataset configuration without data generation
INFO:root:91 train dates in total, NOT generating cache data.
INFO:root:21 val dates in total, NOT generating cache data.
INFO:root:2 test dates in total, NOT generating cache data.
INFO:root:Writing configuration to ./dataset_config.pytorch_notebook.json


We can now create the IceNetDataSet object:

In [9]:
dataset_config = "dataset_config.pytorch_notebook.json"

In [10]:
dataset = IceNetDataSet(dataset_config, batch_size=4)

INFO:root:Loading configuration dataset_config.pytorch_notebook.json


In [11]:
dataset._config

{'identifier': 'pytorch_notebook',
 'implementation': 'DaskMultiWorkerLoader',
 'channels': ['uas_abs_1',
  'vas_abs_1',
  'siconca_abs_1',
  'tas_anom_1',
  'zg250_anom_1',
  'zg500_anom_1',
  'cos_1',
  'land_1',
  'sin_1'],
 'counts': {'train': 91, 'val': 21, 'test': 2},
 'dtype': 'float32',
 'loader_config': '/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json',
 'missing_dates': [],
 'n_forecast_days': 7,
 'north': False,
 'num_channels': 9,
 'shape': [432, 432],
 'south': True,
 'dataset_path': False,
 'loss_weight_days': True,
 'output_batch_size': 4,
 'var_lag': 1,
 'var_lag_override': {}}

In [12]:
dataset._config["n_forecast_days"]

7

In [13]:
dataset.loader_config

'/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json'

In [14]:
dataloader_from_dataset = dataset.get_data_loader()

INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json


In [15]:
dataloader_from_dataset._config.keys()

dict_keys(['sources', 'dtype', 'shape', 'missing_dates'])

In [16]:
dataloader_from_dataset._n_forecast_days

7

In [17]:
dataloader_from_dataset

<icenet.data.loaders.dask.DaskMultiWorkerLoader at 0x7f4c65ee6f70>

## Custom PyTorch Dataset

In [18]:
ds_torch = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="train")

INFO:root:Loading configuration dataset_config.pytorch_notebook.json
INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json


In [19]:
ds_torch.__len__()

91

In [20]:
ds_torch._dates[0]

'2020-01-01'

In [23]:
first_item = ds_torch.__getitem__(0)

n_forecast_days: 7


In [24]:
print(f"first_item is a {type(first_item)} of length {len(first_item)}")

first_item is a <class 'tuple'> of length 3


In [25]:
# Show the type and shape of every element of the sample tuple.
# The index is interpolated into the message so each line identifies
# which element it describes (previously the literal text "[i]" was printed).
for i, element in enumerate(first_item):
    print(f"first_item[{i}] is a {type(element)} of shape {element.shape}")

first_item[i] is a <class 'numpy.ndarray'> of shape (432, 432, 9)
first_item[i] is a <class 'numpy.ndarray'> of shape (432, 432, 7, 1)
first_item[i] is a <class 'numpy.ndarray'> of shape (432, 432, 7, 1)


In [23]:
ds_torch._ds._config

{'identifier': 'pytorch_notebook',
 'implementation': 'DaskMultiWorkerLoader',
 'channels': ['uas_abs_1',
  'vas_abs_1',
  'siconca_abs_1',
  'tas_anom_1',
  'zg250_anom_1',
  'zg500_anom_1',
  'cos_1',
  'land_1',
  'sin_1'],
 'counts': {'train': 91, 'val': 21, 'test': 2},
 'dtype': 'float32',
 'loader_config': '/data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json',
 'missing_dates': [],
 'n_forecast_days': 7,
 'north': False,
 'num_channels': 9,
 'shape': [432, 432],
 'south': True,
 'dataset_path': False,
 'loss_weight_days': True,
 'output_batch_size': 4,
 'var_lag': 1,
 'var_lag_override': {}}

In [None]:
ds_torch._ds

In [22]:
ds_torch.__getitem__(80)

n_forecast_days: 7


(array([[[ 0.5070619 ,  0.507128  ,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         [ 0.5067984 ,  0.5119969 ,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         [ 0.50709283,  0.5130819 ,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         ...,
         [ 0.49436113,  0.51067567,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         [ 0.49760532,  0.5123943 ,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         [ 0.5012724 ,  0.5121377 ,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ]],
 
        [[ 0.5080182 ,  0.5084777 ,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         [ 0.50926554,  0.51501405,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         [ 0.50827605,  0.5093369 ,  0.        , ..., -0.18561055,
           1.        , -0.9826234 ],
         ...,
         [ 0.4894217 ,  0.5110176

## Generating PyTorch DataLoaders

In [27]:
train_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="train")
val_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="val")
test_dataset = IceNetDataSetPyTorch(configuration_path=dataset_config, mode="test")

INFO:root:Loading configuration dataset_config.pytorch_notebook.json
INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json
INFO:root:Loading configuration dataset_config.pytorch_notebook.json
INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json
INFO:root:Loading configuration dataset_config.pytorch_notebook.json
INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json


In [28]:
from torch.utils.data import DataLoader

batch_size = 4
shuffle = False # set to False for now
persistent_workers = False
num_workers = 0

# Every split uses exactly the same DataLoader options, so collect them once.
_loader_kwargs = dict(batch_size=batch_size,
                      shuffle=shuffle,
                      persistent_workers=persistent_workers,
                      num_workers=num_workers)

# One DataLoader per split, built in train / val / test order.
train_dataloader, val_dataloader, test_dataloader = (
    DataLoader(split_dataset, **_loader_kwargs)
    for split_dataset in (train_dataset, val_dataset, test_dataset)
)

## Iterating through DataLoaders

In [29]:
len(train_dataloader)

23

In [30]:
train_features, train_labels, sample_weights = next(iter(train_dataloader))

n_forecast_days: 7
n_forecast_days: 7
n_forecast_days: 7
n_forecast_days: 7


In [31]:
train_features.shape

torch.Size([4, 432, 432, 9])

In [32]:
train_labels.shape

torch.Size([4, 432, 432, 7, 1])

In [33]:
sample_weights.shape

torch.Size([4, 432, 432, 7, 1])

In [34]:
from icenet_unet_small import UNet

unet = UNet(input_channels=9)

In [35]:
y_hat = unet(train_features)



In [36]:
y_hat.shape

torch.Size([4, 432, 432, 3, 6])

## IceNet UNet model

As a first attempt to implement a PyTorch example, we adapt code from https://github.com/ampersandmcd/icenet-gan/.

Below is a PyTorch implementation of the UNet architecture.

In [41]:
import torch
from torch import nn
import torch.nn.functional as F

class UNet(nn.Module):
    """
    A small, two-level UNet for pixelwise classification.

    The network has two encoder stages (conv1, conv2) and a single decoder
    stage (conv9) with a skip connection from the first encoder stage. The
    deeper stages (conv3 to conv8) of the architecture this was adapted from
    are not part of this reduced variant, so the decoder upsamples directly
    from the output of the second encoder stage.

    :param input_channels: Number of channels in the input tensor
    :param filter_size: Kernel size used by every convolution
    :param n_filters_factor: Multiplier applied to the base filter counts (128 and 256)
    :param n_forecast_days: Number of forecast leadtimes predicted per pixel
    :param n_output_classes: Number of classes predicted per pixel and leadtime
    :param kwargs: Accepted for call compatibility and otherwise ignored
    """
    
    def __init__(self,
                 input_channels, 
                 filter_size=3, 
                 n_filters_factor=1, 
                 n_forecast_days=6, 
                 n_output_classes=3,
                **kwargs):
        super(UNet, self).__init__()

        self.input_channels = input_channels
        self.filter_size = filter_size
        self.n_filters_factor = n_filters_factor
        self.n_forecast_days = n_forecast_days
        self.n_output_classes = n_output_classes

        # channel widths of the two encoder stages
        n_enc1 = int(128*n_filters_factor)
        n_enc2 = int(256*n_filters_factor)

        # encoder stage 1: input -> n_enc1 -> n_enc1
        self.conv1a = nn.Conv2d(in_channels=input_channels,
                                out_channels=n_enc1,
                                kernel_size=filter_size,
                                padding="same")
        self.conv1b = nn.Conv2d(in_channels=n_enc1,
                                out_channels=n_enc1,
                                kernel_size=filter_size,
                                padding="same")
        self.bn1 = nn.BatchNorm2d(num_features=n_enc1)

        # encoder stage 2: n_enc1 -> n_enc2 -> n_enc2
        self.conv2a = nn.Conv2d(in_channels=n_enc1,
                                out_channels=n_enc2,
                                kernel_size=filter_size,
                                padding="same")
        self.conv2b = nn.Conv2d(in_channels=n_enc2,
                                out_channels=n_enc2,
                                kernel_size=filter_size,
                                padding="same")
        self.bn2 = nn.BatchNorm2d(num_features=n_enc2)

        # decoder stage: upsampled n_enc2 -> n_enc1, then merged with the skip connection
        self.conv9a = nn.Conv2d(in_channels=n_enc2,
                                out_channels=n_enc1,
                                kernel_size=filter_size,
                                padding="same")
        # the merge concatenates two n_enc1-channel maps (bn1 and up9), so the
        # input width is derived from n_enc1 rather than recomputed from the factor
        self.conv9b = nn.Conv2d(in_channels=2*n_enc1,
                                out_channels=n_enc1,
                                kernel_size=filter_size,
                                padding="same")
        self.conv9c = nn.Conv2d(in_channels=n_enc1,
                                out_channels=n_enc1,
                                kernel_size=filter_size,
                                padding="same")  # no batch norm on last layer

        # one output channel per (class, leadtime) pair
        self.final_conv = nn.Conv2d(in_channels=n_enc1,
                                    out_channels=n_output_classes*n_forecast_days,
                                    kernel_size=filter_size,
                                    padding="same")
        
    def forward(self, x):
        """
        Run a forward pass.
        :param x: Input of shape (b, h, w, c). h and w need to be even so that
            the upsampled feature map lines up with the skip connection.
        :return: Class probabilities of shape (b, h, w, n_output_classes, n_forecast_days)
        """

        # transpose from shape (b, h, w, c) to (b, c, h, w) for pytorch conv2d layers
        x = torch.movedim(x, -1, 1)  # move c from last to second dim

        # encoder stage 1 (full resolution)
        conv1 = self.conv1a(x)  # input to 128
        conv1 = F.relu(conv1)
        conv1 = self.conv1b(conv1)  # 128 to 128
        conv1 = F.relu(conv1)
        bn1 = self.bn1(conv1)
        pool1 = F.max_pool2d(bn1, kernel_size=(2, 2))

        # encoder stage 2 (half resolution); this is the bottleneck of the small network
        conv2 = self.conv2a(pool1)  # 128 to 256
        conv2 = F.relu(conv2)
        conv2 = self.conv2b(conv2)  # 256 to 256
        conv2 = F.relu(conv2)
        bn2 = self.bn2(conv2)

        # decoder stage: upsample the bottleneck back to full resolution.
        # F.interpolate replaces the deprecated F.upsample.
        up9 = F.interpolate(bn2, scale_factor=2, mode="nearest")
        up9 = self.conv9a(up9)  # 256 to 128
        up9 = F.relu(up9)
        merge9 = torch.cat([bn1, up9], dim=1) # 128 and 128 to 256 along c dimension
        conv9 = self.conv9b(merge9)  # 256 to 128
        conv9 = F.relu(conv9)
        conv9 = self.conv9c(conv9)  # 128 to 128
        conv9 = F.relu(conv9)  # no batch norm on last layer
 
        final_layer_logits = self.final_conv(conv9)

        # transpose from shape (b, c, h, w) back to (b, h, w, c) to align with training data
        final_layer_logits = torch.movedim(final_layer_logits, 1, -1)  # move c from second to final dim
        b, h, w, c = final_layer_logits.shape

        # unpack c=classes*leadtimes dimension into classes, leadtimes as separate dimensions
        final_layer_logits = final_layer_logits.reshape((b, h, w, self.n_output_classes, self.n_forecast_days))

        output = F.softmax(final_layer_logits, dim=-2)  # apply over n_output_classes dimension
        
        return output  # shape (b, h, w, c, t)


Some metrics for evaluating IceNet performance:

In [42]:
from torchmetrics import Metric

class IceNetAccuracy(Metric):
    """
    Sample-weighted binary accuracy, restricted to a chosen set of leadtimes.
    """

    # torchmetrics class-level properties
    is_differentiable: bool = False
    higher_is_better: bool = True
    full_state_update: bool = True

    def __init__(self, leadtimes_to_evaluate: list):
        """
        Set up the accumulators for a multi-leadtime binary accuracy.
        :param leadtimes_to_evaluate: Indices along the leadtime (t) axis to score
            e.g., [0, 1, 2, 3, 4, 5] to score the first six leadtimes together or
            e.g., [0] to score only the first leadtime
            e.g., [5] to score only the sixth leadtime
        """
        super().__init__()
        self.leadtimes_to_evaluate = leadtimes_to_evaluate
        # both accumulators start at zero and are summed across processes
        for state_name in ("weighted_score", "possible_score"):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):
        """
        Accumulate the weighted number of matching pixels for one batch.
        :param preds: Predictions, shape (b, h, w, t)
        :param target: Ground truth, shape (b, h, w, t)
        :param sample_weight: Per-pixel weights, indexed the same way as preds
        """
        leadtimes = self.leadtimes_to_evaluate
        # any value above zero counts as ice (marginal and full ice are merged)
        ice_predicted = (preds[:, :, :, leadtimes] > 0).long()
        ice_observed = (target[:, :, :, leadtimes] > 0).long()
        weights = sample_weight[:, :, :, leadtimes]
        agreement = ice_predicted == ice_observed
        self.weighted_score += (agreement * weights).sum()
        self.possible_score += weights.sum()

    def compute(self):
        """Return the achieved weighted score as a fraction of the achievable one."""
        achieved = self.weighted_score.float()
        return achieved / self.possible_score


class SIEError(Metric):
    """
    Signed Sea Ice Extent error (in km^2) over a chosen set of leadtimes.
    """

    # torchmetrics class-level properties
    is_differentiable: bool = False
    higher_is_better: bool = False
    full_state_update: bool = True

    def __init__(self, leadtimes_to_evaluate: list):
        """
        Set up the accumulators for a multi-leadtime SIE error (in km^2).
        :param leadtimes_to_evaluate: Indices along the leadtime (t) axis to include
            e.g., [0, 1, 2, 3, 4, 5] to include the first six leadtimes or
            e.g., [0] to include only the first leadtime
            e.g., [5] to include only the sixth leadtime
        """
        super().__init__()
        self.leadtimes_to_evaluate = leadtimes_to_evaluate
        # both pixel counters start at zero and are summed across processes
        for state_name in ("pred_sie", "true_sie"):
            self.add_state(state_name, default=torch.tensor(0.), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor, sample_weight: torch.Tensor):
        """
        Accumulate predicted and observed ice pixel counts for one batch.
        :param preds: Predictions, shape (b, h, w, t)
        :param target: Ground truth, shape (b, h, w, t)
        :param sample_weight: Not used by this metric; kept so both metrics share a call signature
        """
        leadtimes = self.leadtimes_to_evaluate
        # any value above zero counts as ice (marginal and full ice are merged)
        ice_predicted = (preds[:, :, :, leadtimes] > 0).long()
        ice_observed = (target[:, :, :, leadtimes] > 0).long()
        self.pred_sie += ice_predicted.sum()
        self.true_sie += ice_observed.sum()

    def compute(self):
        """Return predicted minus observed extent, converted from pixels to km^2."""
        pixel_difference = self.pred_sie - self.true_sie
        return pixel_difference * 25**2 # each pixel is 25x25 km

A _LightningModule_ wrapper for the UNet model.

In [43]:
import lightning.pytorch as pl
from torchmetrics import MetricCollection

class LitUNet(pl.LightningModule):
    """
    A LightningModule wrapping the UNet implementation of IceNet.

    Tensors are expected in the layout (b, h, w, c, t): the class dimension is
    second to last and the leadtime dimension is last.
    NOTE(review): the batches printed earlier in this notebook have label and
    weight shape (b, h, w, 7, 1); confirm the dataset layout matches the
    (b, h, w, c, t) convention assumed here before training.
    """
    def __init__(self,
                 model: nn.Module,
                 criterion: callable,
                 learning_rate: float):
        """
        Construct a UNet LightningModule.
        Note that we keep hyperparameters separate from dataloaders to prevent data leakage at test time.
        :param model: PyTorch model
        :param criterion: PyTorch loss function for training instantiated with reduction="none"
        :param learning_rate: Float learning rate for our optimiser
        """
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.learning_rate = learning_rate
        self.n_output_classes = model.n_output_classes  # this should be a property of the network

        # separate collections so validation and test statistics never mix
        self.metrics = self._build_metrics("val")
        self.test_metrics = self._build_metrics("test")

        self.save_hyperparameters(ignore=["model", "criterion"])

    def _build_metrics(self, prefix):
        """Create accuracy and SIE-error metrics over all leadtimes and per leadtime, keyed by prefix."""
        n_leadtimes = self.model.n_forecast_days
        metrics = {
            f"{prefix}_accuracy": IceNetAccuracy(leadtimes_to_evaluate=list(range(n_leadtimes))),
            f"{prefix}_sieerror": SIEError(leadtimes_to_evaluate=list(range(n_leadtimes)))
        }
        for i in range(n_leadtimes):
            metrics[f"{prefix}_accuracy_{i}"] = IceNetAccuracy(leadtimes_to_evaluate=[i])
            metrics[f"{prefix}_sieerror_{i}"] = SIEError(leadtimes_to_evaluate=[i])
        return MetricCollection(metrics)

    def _weighted_loss(self, y_hat, y, sample_weight):
        """Apply the criterion and reduce it to a scalar with per-pixel weighting."""
        # y and y_hat are shape (b, h, w, c, t) but loss expects (b, c, h, w, t)
        # note that criterion needs reduction="none" for weighting to work
        if isinstance(self.criterion, nn.CrossEntropyLoss):  # requires int class encoding
            # class-index targets drop the class dim, so the loss is (b, h, w, t);
            # the weights must have the same rank, otherwise broadcasting against a
            # (b, 1, h, w, t) tensor would pair every sample's loss with every
            # sample's weights.
            # assumes sample_weight is (b, h, w, 1, t), as the metric updates also do
            loss = self.criterion(y_hat.movedim(-2, 1), y.argmax(-2).long())
            weight = sample_weight.squeeze(dim=-2)
        else:  # requires one-hot encoding
            # the loss keeps the class dim: (b, c, h, w, t) against weights (b, 1, h, w, t)
            loss = self.criterion(y_hat.movedim(-2, 1), y.movedim(-2, 1))
            weight = sample_weight.movedim(-2, 1)
        return torch.mean(loss * weight)

    def _evaluation_step(self, batch, loss_name, metrics):
        """Shared body of validation_step and test_step: log the loss and update the metrics."""
        x, y, sample_weight = batch
        y_hat = self.model(x)
        loss = self._weighted_loss(y_hat, y, sample_weight)
        self.log(loss_name, loss, on_step=False, on_epoch=True, sync_dist=True)  # epoch-level loss
        y_hat_pred = y_hat.argmax(dim=-2).long()  # argmax over c where shape is (b, h, w, c, t)
        metrics.update(y_hat_pred, y.argmax(dim=-2).long(), sample_weight.squeeze(dim=-2))  # shape (b, h, w, t)
        return loss

    def forward(self, x):
        """
        Implement forward function.
        :param x: Inputs to model.
        :return: Outputs of model.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx=None):
        """
        Perform a pass through a batch of training data.
        Apply pixel-weighted loss by manually reducing.
        See e.g. https://discuss.pytorch.org/t/unet-pixel-wise-weighted-loss-function/46689/5.
        :param batch: Batch of input, output, weight triplets
        :param batch_idx: Index of batch; accepted to match Lightning's (batch, batch_idx) hook signature, unused
        :return: Loss from this batch of data for use in backprop
        """
        x, y, sample_weight = batch
        y_hat = self.model(x)
        loss = self._weighted_loss(y_hat, y, sample_weight)
        self.log("train_loss", loss, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx=None):
        """
        Compute the weighted loss for a validation batch and update the validation metrics.
        :param batch: Batch of input, output, weight triplets
        :param batch_idx: Index of batch; accepted to match Lightning's hook signature, unused
        :return: Loss from this batch of data
        """
        return self._evaluation_step(batch, "val_loss", self.metrics)

    def on_validation_epoch_end(self):
        """Log the accumulated validation metrics and reset them for the next epoch."""
        self.log_dict(self.metrics.compute(), on_step=False, on_epoch=True, sync_dist=True)  # epoch-level metrics
        self.metrics.reset()

    def test_step(self, batch, batch_idx=None):
        """
        Compute the weighted loss for a test batch and update the test metrics.
        :param batch: Batch of input, output, weight triplets
        :param batch_idx: Index of batch; accepted to match Lightning's hook signature, unused
        :return: Loss from this batch of data
        """
        return self._evaluation_step(batch, "test_loss", self.test_metrics)

    def on_test_epoch_end(self):
        """Log the accumulated test metrics and reset them."""
        self.log_dict(self.test_metrics.compute(),on_step=False, on_epoch=True, sync_dist=True)  # epoch-level metrics
        self.test_metrics.reset()

    def configure_optimizers(self):
        """Use Adam over all parameters with the configured learning rate."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return {
            "optimizer": optimizer
        }

A function for training the UNet model using PyTorch Lightning.

In [44]:
from lightning.pytorch.callbacks import ModelCheckpoint

def train_icenet(configuration_path,
                 learning_rate,
                 max_epochs,
                 batch_size,
                 n_workers,
                 filter_size,
                 n_filters_factor,
                 seed):
    """
    Train the IceNet UNet with PyTorch Lightning.
    :param configuration_path: Path to the IceNet dataset configuration file
    :param learning_rate: Float learning rate for the optimiser
    :param max_epochs: Maximum number of training epochs
    :param batch_size: Number of samples per batch for both dataloaders
    :param n_workers: Number of DataLoader worker processes (0 loads in the main process)
    :param filter_size: Kernel size of the UNet convolutions
    :param n_filters_factor: Multiplier applied to the UNet filter counts
    :param seed: Seed passed to Lightning's seed_everything
    """
    # init
    pl.seed_everything(seed)
    
    # configure datasets and dataloaders
    train_dataset = IceNetDataSetPyTorch(configuration_path, mode="train")
    val_dataset = IceNetDataSetPyTorch(configuration_path, mode="val")
    # DataLoader rejects persistent_workers=True when there are no worker
    # processes, so only keep workers alive when some are requested
    persistent_workers = n_workers > 0
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers,
                                  persistent_workers=persistent_workers, shuffle=False)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers,
                                persistent_workers=persistent_workers, shuffle=False)

    # construct unet, sized from the dataset configuration
    model = UNet(
        input_channels=len(train_dataset._ds._config["channels"]),
        filter_size=filter_size,
        n_filters_factor=n_filters_factor,
        n_forecast_days=train_dataset._ds._config["n_forecast_days"]
    )
    
    # reduction="none" keeps the per-pixel losses so LitUNet can weight them
    criterion = nn.CrossEntropyLoss(reduction="none")
    
    # configure PyTorch Lightning module
    lit_module = LitUNet(
        model=model,
        criterion=criterion,
        learning_rate=learning_rate
    )

    # set up trainer configuration
    trainer = pl.Trainer(
        accelerator="auto",
        devices=1,
        log_every_n_steps=10,
        max_epochs=max_epochs,
        num_sanity_val_steps=1,
    )
    # checkpoint on the highest validation accuracy
    trainer.callbacks.append(ModelCheckpoint(monitor="val_accuracy", mode="max"))

    # train model
    print(f"Training {len(train_dataset)} examples / {len(train_dataloader)} batches (batch size {batch_size}).")
    print(f"Validating {len(val_dataset)} examples / {len(val_dataloader)} batches (batch size {batch_size}).")
    trainer.fit(lit_module, train_dataloader, val_dataloader)

In [45]:
# Train the UNet on the dataset configuration written earlier in this notebook.
# n_workers is forwarded to both the training and validation DataLoaders.
seed = 45
train_icenet(configuration_path=dataset_config,
             learning_rate=1e-4,
             max_epochs=10,
             batch_size=4,
             n_workers=12,
             filter_size=3,
             n_filters_factor=1,
             seed=seed)

INFO: Global seed set to 45
INFO:lightning.fabric.utilities.seed:Global seed set to 45
INFO:root:Loading configuration dataset_config.pytorch_notebook.json
INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json
INFO:root:Loading configuration dataset_config.pytorch_notebook.json
INFO:root:Loading configuration /data/hpcdata/users/rychan/notebooks/notebook-pipeline/loader.notebook_api_data.json
INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICE

Training 91 examples / 23 batches (batch size 4).
Validating 21 examples / 6 batches (batch size 4).
Sanity Checking: 0it [00:00, ?it/s]

In [34]:
len(train_dataset._ds._config["channels"])

9

In [37]:
train_dataset._ds._config["n_forecast_days"]

7