In [None]:
from IPython.display import IFrame # Display YouTube videos

<div style="background-color: #ccffcc; padding: 10px;">
    <h1> Tutorial X </h1>
    <h2> Image Segmentation with U-Net and fastai </h2>
</div>

# Overview

This Jupyter notebook demonstrates how artificial neural networks (ANNs) can be applied to image segmentation problems. Segmentation in this context refers to the task of assigning discrete labels to individual pixels or regions of an image. We can use segmentation models to identify and locate features of interest within an image. This notebook contains a simple application to self-driving cars, where we train a segmentation model to identify important features in dashcam footage, as well as a more complicated example, based on the work of [Coney et al. (2023)](https://doi.org/10.1002/qj.4592), identifying and characterising trapped lee waves over the UK.

## Recommended reading

* [Fastai: A Layered API for Deep Learning](https://doi.org/10.3390/info11020108)
* [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://doi.org/10.1007/978-3-319-24574-4_28)
* [Identifying and characterising trapped lee waves using deep learning techniques](https://doi.org/10.1002/qj.4592)

<hr>

<div style="background-color: #e6ccff; padding: 10px;">

<h1> Machine Learning Theory </h1>

# Image Segmentation

## The problem

Image segmentation models are designed to tackle the problem of partitioning an image into meaningful segments or regions, each corresponding to different objects or parts of objects within the image. This process is crucial in various applications such as medical imaging, where it helps in identifying and isolating different anatomical structures (e.g. organs or tumours), or in autonomous driving, where it can aid in recognising and distinguishing between pedestrians, vehicles, and road signs. By accurately segmenting images, these models enable more precise analysis and interpretation, facilitating tasks like object detection, scene understanding, and image editing. Essentially, image segmentation transforms raw visual data into structured information, making it easier for machines to understand and interact with the visual world. More recently, segmentation models have been applied to weather and climate forecasting, where their ability to identify structures in image data makes them well suited to the task.

## Popular models for image segmentation

* U-Net: Its architecture has become a standard in medical image segmentation due to its ability to perform well with limited training data and its precise localization capabilities.
* Mask R-CNN: This model is highly significant for instance segmentation, as it not only detects objects but also provides pixel-level masks, making it versatile for various applications, including autonomous driving and video analysis.
* DeepLab: Known for its high accuracy in semantic segmentation, DeepLab’s use of atrous convolution allows it to capture multi-scale context, making it a powerful tool for tasks requiring detailed scene understanding.

## The U-Net model architecture

The [U-Net](https://doi.org/10.1007/978-3-319-24574-4_28) model architecture is a type of convolutional neural network (CNN) originally designed for biomedical image segmentation. Introduced by Olaf Ronneberger, Philipp Fischer, and Thomas Brox in 2015, U-Net is known for its distinctive U-shaped structure. This architecture consists of a contracting path (the encoder) to capture context and a symmetric expanding path (the decoder) that enables precise localization. The contracting path follows the typical architecture of a convolutional network, with repeated application of convolutions, each followed by a rectified linear unit (ReLU) and a max-pooling operation. The expanding path, on the other hand, involves upsampling the feature maps and performing convolutions, which helps in reconstructing the image with high resolution. U-Net's ability to work with very few training images and its efficient use of data augmentation make it particularly effective for tasks where annotated data is scarce.

![Schematic of U Net](https://rmets.onlinelibrary.wiley.com/cms/asset/63d9263f-f5c7-48dc-8ba6-60f19fb6e5a7/qj4592-fig-0004-m.jpg)

The video in the cell below gives a 10-minute introduction to the U-Net.

</div>

In [None]:
IFrame("https://www.youtube.com/embed/NhdzGfB1q74?si=p8ti5ydxXvqJuABi","560", "315" )
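
The encoder/decoder structure described above can be traced with a few lines of plain Python. This is only a bookkeeping sketch, not a neural network: the helper `unet_shapes` is hypothetical, and it assumes square inputs, padded convolutions (so only pooling and upsampling change the resolution) and channel doubling at each level. The original U-Net paper uses unpadded convolutions, so its feature maps shrink slightly more than shown here.

```python
def unet_shapes(size, depth=4, base_channels=64):
    """Trace (channels, width) through a U-Net-style network (hypothetical helper)."""
    encoder = []
    channels = base_channels
    for _ in range(depth):
        encoder.append((channels, size))  # feature map kept for the skip connection
        size //= 2                        # 2x2 max-pooling halves the resolution
        channels *= 2                     # deeper levels use more channels
    bottleneck = (channels, size)

    decoder = []
    for skip_channels, skip_size in reversed(encoder):
        size *= 2                         # upsampling doubles the resolution
        channels //= 2                    # convolution after upsampling halves the channels
        # Shapes match the stored encoder output, so the two can be concatenated
        assert (channels, size) == (skip_channels, skip_size)
        decoder.append((channels, size))
    return encoder, bottleneck, decoder


encoder, bottleneck, decoder = unet_shapes(size=256)
print("encoder:   ", encoder)
print("bottleneck:", bottleneck)
print("decoder:   ", decoder)
```

The matching shapes on the way back up are what allow the skip connections to pass fine spatial detail from the contracting path to the expanding path.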

<div style="background-color: #cce5ff; padding: 10px;">

# Python

## [PyTorch](https://pytorch.org/)

PyTorch is an open-source machine learning library developed by Facebook's AI Research lab. It is widely used for applications such as natural language processing and computer vision. PyTorch is known for its flexibility and ease of use, particularly due to its dynamic computation graph, which allows for more intuitive model building and debugging. However, it is considered quite low-level compared to some other frameworks (e.g. [Keras](https://keras.io/)), meaning that defining complex models like a U-Net can require a significant amount of verbose code. This verbosity can make the development process more cumbersome, especially for those who are new to deep learning.

## [fastai](https://www.fast.ai/)

fastai is a high-level library built on top of PyTorch that simplifies the process of training deep learning models. It provides a range of pre-built functions and classes that allow users to leverage the powerful capabilities of PyTorch without needing to write extensive amounts of code. With fastai, you can define and train complex models, such as U-Nets, in just a few lines of code. This makes it an excellent choice for both beginners and experienced practitioners who want to quickly prototype and experiment with different models while still benefitting from the flexibility and performance of PyTorch under the hood.

## Further reading

If you want to run this notebook locally or on a remote service:

* [running Jupyter notebooks](https://jupyter.readthedocs.io/en/latest/running.html)
* [installing the required Python environments](https://github.com/cemac/LIFD_ENV_ML_NOTEBOOKS/blob/main/howtorun.md)
* [running the Jupyter notebooks locally](https://github.com/cemac/LIFD_ENV_ML_NOTEBOOKS/blob/main/jupyter_notebooks.md)

</div>

<hr>

<div style="background-color: #ffffcc; padding: 10px;">
    
<h1> Requirements </h1>

These notebooks should run with the following requirements satisfied.

<h2> Python Packages: </h2>

* fastai
* pytorch
* numpy
* xarray
* dask
* netCDF4
* bottleneck
* matplotlib
* cartopy
* notebook

<h2> Data Requirements</h2>

This notebook refers to some external datasets and learner objects which are downloaded via Python scripts within the notebook.

</div>

**Contents:**

1. [Overview and machine-learning theory](#Overview)
2. [Application to self-driving cars](#Application-to-self-driving-cars)
3. [Application to detection of lee waves](#Application-to-detection-of-lee-waves)

<div style="background-color: #cce5ff; padding: 10px;">

## Import modules

These are all the modules needed during this tutorial.

</div>

In [None]:
from fastai.vision.all import *
from matplotlib.colors import ListedColormap
import xarray as xr
import matplotlib.pyplot as plt
import urllib.request
import os
import zipfile
import pickle
import fastai


<div style="background-color: #cce5ff; padding: 10px;">

### Note on CUDA
If you have a GPU then you should enable CUDA by commenting out the cell below. Otherwise, all code will run on the CPU by default (much slower).

</div>

In [None]:
fastai.torch_core.default_device(use=False) # comment out if you have a GPU!

<div style="background-color: #ffffcc; padding: 10px;">

### Note for Windows users

The fastai library was designed with Linux filesystems in mind, so it can raise errors on Windows machines. The following cell is a workaround to allow the rest of the notebook to run on Windows.

</div>

In [None]:
if os.name == 'nt':
    import pathlib
    temp = pathlib.PosixPath
    pathlib.PosixPath = pathlib.WindowsPath

<div style="background-color: #ffffcc; padding: 10px;">

### Note for Colab users

Google Colab doesn't come with cartopy installed by default. Uncomment the cell below to rectify this.

</div>

In [None]:
# pip install cartopy

<div style="background-color: #e6ccff; padding: 10px;">

# Application to self-driving cars

This example is a quick introduction to the `unet_learner` object in fastai and is based on the official package documentation.

We will start by downloading a small version of the [CamVid](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/) dataset. The Cambridge-driving Labeled Video Database (CamVid) is the first collection of videos with object class semantic labels, complete with metadata. The database provides ground truth labels that associate each pixel with one of 32 semantic classes. The image below uses colour to show how pixels are labelled according to the object present.

![CamVid image](http://mi.eng.cam.ac.uk/research/projects/VideoRec/CamVid/pr/DBOverview1_1_huff_0000964.jpg)

fastai includes many helper functions, such as `untar_data`, which simplify downloading of example datasets.

</div>

In [None]:
# Download subset of the CamVid segmentation dataset
path = untar_data(URLs.CAMVID_TINY)
path.ls()

<div style="background-color: #ccffcc; padding: 10px;">

The above cell has printed the names of directories where the images and label masks are stored. The label assigned to each pixel is stored as an integer. Class labels corresponding to each integer are stored in `codes.txt`.

</div>

In [None]:
# Import class labels
codes = np.loadtxt(path/'codes.txt', dtype=str)
len(codes), codes

<div style="background-color: #ccffcc; padding: 10px;">

We can see that there are 32 unique classes with corresponding labels printed above.

</div>
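
To make the link between integer labels and class names concrete, here is a small sketch using a made-up mask. The three class names, their order and the 3x4 mask are illustrative assumptions, not values read from `codes.txt` or from the CamVid data.

```python
from collections import Counter

# Illustrative class names in an arbitrary order (not the real contents of codes.txt)
example_codes = ["Sky", "Road", "Car"]

# Made-up 3x4 label mask: each pixel stores the integer index of its class
example_mask = [
    [0, 0, 0, 0],
    [1, 1, 2, 2],
    [1, 1, 1, 1],
]

# Count how many pixels carry each integer label, then look up the class name
pixel_counts = Counter(value for row in example_mask for value in row)
named_counts = {example_codes[label]: count for label, count in sorted(pixel_counts.items())}
print(named_counts)
```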

In [None]:
# Get filenames of input images
fnames = get_image_files(path/'images')
fnames[0]

In [None]:
# Look at example of label file
(path/'labels').ls()[0]

<div style="background-color: #ccffcc; padding: 10px;">

Note that the image file containing the label mask has the same filename as the original image but with `_P` appended just before the `.png` file extension.

We can now write a function called `label_func` to associate each image file with its corresponding label mask.

</div>

In [None]:
# Function to locate file containing class labels
def label_func(fn):
    return path/'labels'/f'{fn.stem}_P{fn.suffix}'
label_func(fnames[0])
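
The filename convention used by `label_func` can be checked without downloading anything. The sketch below uses `pathlib` from the standard library with hypothetical directory and file names.

```python
from pathlib import PurePosixPath

# Hypothetical paths, used only to illustrate the naming convention
labels_dir = PurePosixPath("camvid_tiny/labels")
image_file = PurePosixPath("camvid_tiny/images/example_frame.png")

# stem is the filename without its extension; suffix is the extension including the dot
mask_file = labels_dir / f"{image_file.stem}_P{image_file.suffix}"
print(mask_file)
```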

<div style="background-color: #ccffcc; padding: 10px;">

The `DataBlock` object in fastai is a powerful and flexible tool designed to simplify the process of creating datasets for machine learning. It provides a high-level API that allows you to define the structure of your data pipeline in a clear and concise manner. With `DataBlock`, you can specify how to get your data, how to split it into training and validation sets, how to label it, and how to apply transformations.

</div>

In [None]:
# Build fastai DataBlock
camvid = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes)), # (input, output)
    get_items=get_image_files, # function for retrieving input files
    get_y=label_func, # function to locate label mask
    splitter=RandomSplitter(), # how to perform train/validation split
    batch_tfms=aug_transforms(size=(120,160)) # data augmentation transforms including output size
)
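
For readers curious about what a random train/validation split involves, the following is a framework-free sketch. The helper `random_split` and the filenames are hypothetical, and this is not fastai's implementation; it only mirrors the idea of holding out a fraction of the items (fastai's `RandomSplitter` holds out 20% by default).

```python
import random

def random_split(items, valid_pct=0.2, seed=None):
    """Shuffle item indices and hold out a fraction for validation (hypothetical helper)."""
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    n_valid = int(valid_pct * len(items))
    return indices[n_valid:], indices[:n_valid]  # (train indices, validation indices)

files = [f"image_{i:03d}.png" for i in range(100)]  # hypothetical filenames
train_idx, valid_idx = random_split(files, valid_pct=0.2, seed=42)
print(len(train_idx), len(valid_idx))
```

Fixing the seed makes the split reproducible, which helps when comparing models trained in separate runs.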

<div style="background-color: #ccffcc; padding: 10px;">

Now that we have a `DataBlock`, we can create a `DataLoaders` object. This is similar to a `DataLoader` in PyTorch, but more high-level in that it manages more aspects of the training process. As well as our `DataBlock`, we need to specify the directory where our training data are located and the batch size to use in training.

We are choosing a batch size of eight to ensure that VRAM usage doesn't exceed 4 GB. Depending on the size of your GPU memory, you may wish to increase the batch size to speed up training.

</div>

In [None]:
# Construct a DataLoaders object from the DataBlock (not the same as DataLoader in PyTorch)
dls = camvid.dataloaders(path/'images', path=path, bs=8)

<div style="background-color: #ccffcc; padding: 10px;">

To verify that our `DataLoaders` object is working correctly, we can display some images from a sample batch with their class labels overlaid as colours.

</div>

In [None]:
# Show segmented images from sample batch
dls.show_batch(max_n=6)

<div style="background-color: #ccffcc; padding: 10px;">

In the fastai library, a learner object is a central component designed to streamline the process of training and evaluating models. It encapsulates the model, the data, and the training loop, providing a high-level API that simplifies complex tasks. With a learner, you can easily fine-tune hyperparameters, apply callbacks, and leverage built-in functionalities like learning rate scheduling and mixed-precision training.

For our purposes here, we will instantiate a `unet_learner` object. This will automatically set up the neural network for the dimensions of our problem, including the number of unique pixel classes. In addition to automating the training loop, the `unet_learner` uses a pre-trained neural network as its encoder model. Here we specify ResNet-34 as the encoder. This is a residual neural network with 34 layers and has been pre-trained on the ImageNet database of more than one million images!

</div>

In [None]:
# Download fastai U-Net learner with ResNet-34 as the encoder
learn = unet_learner(dls, resnet34)

<div style="background-color: #ccffcc; padding: 10px;">

fastai encourages users to leverage the concept of transfer learning, rather than training new models from scratch. In transfer learning, we typically download a model (here ResNet-34) that has been pre-trained on a large generic dataset to perform a standard task, e.g. image recognition. This pre-trained model, sometimes called a foundation model, contains a great deal of prior information for recognising features in images. For a specific task, such as identifying objects in dashcam footage, we need only fine-tune this foundation model (iteratively update the weights and biases) for a few epochs on our training data.

Depending on your hardware, running the next cell may take a few minutes, so be warned! The next few lines of code are therefore commented out by default.

</div>

In [None]:
# Fine-tune the model until the validation loss stops improving
# learn.fine_tune(100, cbs=EarlyStoppingCallback(monitor='valid_loss', patience=5))
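
The early-stopping rule requested above (`patience=5`) can be illustrated in plain Python. This is a simplified sketch, not fastai's code: the function `epochs_until_stop` is hypothetical, the losses are made up, and the real callback has further options such as a `min_delta` tolerance.

```python
def epochs_until_stop(valid_losses, patience=5):
    """Return the number of epochs run before patience-based early stopping triggers."""
    best_loss = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(valid_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            return epoch  # stop: no improvement for `patience` consecutive epochs
    return len(valid_losses)  # ran out of epochs without triggering the rule

# Made-up validation losses: improvement stalls after the fourth epoch
losses = [0.90, 0.70, 0.60, 0.55, 0.56, 0.57, 0.55, 0.58, 0.60, 0.61, 0.62]
print(epochs_until_stop(losses, patience=5))
```

With these numbers the best loss is reached at epoch 4, so training stops five epochs later even though up to 100 epochs were requested.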

<div style="background-color: #ccffcc; padding: 10px;">

After a few epochs of fine-tuning, our validation loss has stopped improving and the model is trained. We can now plot our model predictions superimposed on the original images, and compare with the true label masks.

</div>

In [None]:
# Compare model output with ground truth
# learn.show_results(max_n=6)

<div style="background-color: #ccffcc; padding: 10px;">

fastai includes helper functions such as `plot_top_losses`, which shows which examples the model had the hardest time classifying. Looking at the highest losses can be a good way to find outliers and errors in the training data (unlikely here, as this dataset is very clean).

</div>

In [None]:
# Look at the images where the validation loss was highest
# interp = SegmentationInterpretation.from_learner(learn)
# interp.plot_top_losses(k=3)

<div style="background-color: #ccffcc; padding: 10px;">

In this example, we have seen how, in only a few lines of code, we can train a segmentation model with good performance and with a relatively short training time. The model predictions, although not perfect, are remarkable given the small amount of training data and the high number of object classes (32).

</div>

<hr>

<div style="background-color: #ccffcc; padding: 10px;">

# Application to detection of lee waves

This example is based on the work of [Coney et al. (2023)](https://doi.org/10.1002/qj.4592), who used neural networks to identify and characterise trapped lee waves over the UK. A lee wave is a type of gravity wave created when air in the atmosphere flows over mountainous terrain.

## Gravity waves

A gravity wave is a vertical wave, for example a ripple, in the atmosphere. Gravity waves can be formed when air is forced upwards by topography (e.g. wind blowing over a mountain). This creates turbulence that can be felt throughout the column of air above a mountain. Gravity waves are of interest for improving our understanding and forecasting capability, e.g. for aviation. If you'd like to learn more, NOAA have a useful information page all about gravity waves in the atmosphere [here](https://www.weather.gov/source/zhu/ZHU_Training_Page/Miscellaneous/gravity_wave/gravity_wave.html).

![diagram of gravity waves](https://www.weather.gov/source/zhu/ZHU_Training_Page/Miscellaneous/gravity_wave/radarscope2.png)

(taken from https://www.weather.gov/source/zhu/ZHU_Training_Page/Miscellaneous/gravity_wave/gravity_wave.html)

## Lee waves

Lee waves can often be observed by eye because clouds form on the wave crests, e.g. as stripes of cloud or lenticular clouds like the one in the image below, where a mountain has forced a wave to form in the air. These clouds can be spotted in photos and satellite images. For more information, the Met Office have a basic overview [here](https://www.metoffice.gov.uk/weather/learn-about/weather/types-of-weather/wind/lee-waves).

![Lenticular cloud over mountains image](https://www.metoffice.gov.uk/binaries/content/gallery/metofficegovuk/images/weather/learn-about/weather/lenticular-cloud.jpg)

(taken from https://www.metoffice.gov.uk/weather/learn-about/weather/types-of-weather/wind/lee-waves)

## NWP data

Lee waves can be identified in Numerical Weather Prediction (NWP) model output in a range of fields, such as vertical wind velocity just above topography. Below is an image of model output where lee waves are resolved, showing a characteristic stripey vertical velocity pattern. These patterns are easily picked up by eye, but not so easily detected automatically. To detect these patterns, typically spectral analysis is employed using idealised representations of waves.

![Example UKV data showing stripey lee waves in the vertical velocity output](https://rmets.onlinelibrary.wiley.com/cms/asset/10a1023d-3e98-4f26-9100-224ac84ea3d1/qj4592-fig-0001-m.jpg)

(Figure 1 from [Coney et al. (2023)](https://doi.org/10.1002/qj.4592))

</div>

<div style="background-color: #cce5ff; padding: 10px;">

## Downloading the training data

To start with, we need to download and extract the training data using the following two Python scripts.

</div>

In [None]:
url = "https://zenodo.org/records/10230764/files/data.zip"
filename = "data.zip"

if not os.path.isfile(filename):
    # If the file doesn't exist, download it using urllib
    urllib.request.urlretrieve(url, filename)
    print("File downloaded successfully.")
else:
    print("File already exists.")

In [None]:
# Check if the 'data' directory exists
if not os.path.isdir('data'):
    print("Directory 'data' does not exist. Creating it now...")
    
    # Create the 'data' directory
    os.mkdir('data')
    
    print("Extracting contents of 'data.zip' into current directory...")
    
    # Extract the contents of 'data.zip' into the current directory
    with zipfile.ZipFile('data.zip', 'r') as zip_ref:
        # Extract only the 'data' directory from within the zip file
        for member in zip_ref.namelist():
            if member.startswith('data/'):
                zip_ref.extract(member, '.')
    
    print("Extraction complete.")
else:
    print("Directory 'data' already exists.")

<div style="background-color: #ccffcc; padding: 10px;">

Once the data have been downloaded and extracted, re-running the above scripts won't do anything.

</div>

<div style="background-color: #ccffcc; padding: 10px;">

## Training the segmentation model

Setting our root path to the directory containing the training data will save us some typing.

</div>

In [None]:
# Set root path to directory containing training data
root = Path('data/train')


<div style="background-color: #ccffcc; padding: 10px;">

Unlike in the previous example, where the class label definitions were provided in a file that came with the data, this time we need to define the class labels ourselves. We do this by creating a Python dictionary called `codes`.

</div>

In [None]:
# Define binary class labels
codes = {0:'no wave', 255:'lee wave'}

<div style="background-color: #ccffcc; padding: 10px;">

Each pixel in our segmentation mask is therefore encoded as an 8-bit integer, with white pixels (value 255) indicating a lee wave is present and black pixels (value 0) indicating no wave.

Again, we need to define a function to match input images to their respective segmentation mask (take a peek inside the data directory to verify that this function will do what we intend...).

</div>

In [None]:
# Function to retrieve label mask for a given input file
def label_func(fn): 
    string = str(fn.stem)[:49] + 'mask.png' # mask files have .png suffix
    return root/'masks_png'/string

<div style="background-color: #ccffcc; padding: 10px;">

The NWP data that we will be using as input to our network are stored in NetCDF format, as is common for weather and climate datasets. fastai does not understand this file format natively, so we need to define a function `open_xarray` to read in the data. The function uses the xarray library to open a NetCDF file and returns a NumPy array of the vertical velocities.

</div>

In [None]:
# Function to extract vertical velocity array from NetCDF file
def open_xarray(fname):
    x = xr.open_dataarray(fname)
    array = x.values # return values as NumPy array
    return array

<div style="background-color: #ccffcc; padding: 10px;">

One of fastai's more powerful features is the way that it automates the process of data augmentation. Data augmentation is a technique used to increase the diversity of data available for training machine learning models by applying various transformations, such as rotations or flips, to existing data, thereby improving the model's performance and robustness.

Here we apply several transformations to our data:

* z-score normalisation centres and scales the pixel values about zero. This improves performance of the optimisation algorithms used for training the model.
* random flipping
* random zooming in
* random rotation

Augmenting the data using these transformations effectively increases the size of our training dataset and greatly reduces the chance of overfitting.

</div>
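
Two of the transformations listed above can be sketched in plain Python to show what they do to the numbers. The helpers `z_score` and `flip_horizontal` and the tiny arrays are hypothetical; fastai applies its own implementations to whole batches, so this is an illustration only.

```python
import statistics

def z_score(pixels):
    """Centre a flat list of pixel values on zero and scale to unit standard deviation."""
    mean = statistics.fmean(pixels)
    std = statistics.pstdev(pixels)
    return [(value - mean) / std for value in pixels]

def flip_horizontal(image):
    """Mirror a 2D image (list of rows) left to right."""
    return [row[::-1] for row in image]

# Made-up 2x3 image and its label mask
image = [[1.0, 2.0, 3.0],
         [4.0, 5.0, 6.0]]
mask = [[0, 0, 255],
        [0, 255, 255]]

# The same geometric transform must be applied to the image and its mask
flipped_image, flipped_mask = flip_horizontal(image), flip_horizontal(mask)

# Normalisation is applied to the input values only, never to the labels
normalised = z_score([value for row in image for value in row])
print(flipped_image, flipped_mask)
```

Geometric transforms such as flips, zooms and rotations have to move the mask together with the image, otherwise the labels would no longer line up with the features they describe.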

In [None]:
# Data augmentation transformations to apply on GPU
tfms = [
    Normalize.from_stats([0,0,0], [1,1,1]), # mean zero and std dev. one for all channels
    Flip(), # random flip images with probability 0.5
    Zoom(max_zoom=20, p=0.5), # apply up to a 20x zoom with probability 0.5
    Rotate(max_deg=360, p=0.9) # apply a random rotation with probability 0.9
]

<div style="background-color: #ccffcc; padding: 10px;">

We now have everything we need to build a fastai `DataBlock` object. Note that we are using the `open_xarray` and `label_func` functions that we just defined, as well as our dictionary `codes`.

</div>

In [None]:
# Build fastai DataBlock
waves_ds = DataBlock(
    blocks=(ImageBlock, MaskBlock(codes)), # (input, output)
    get_items=get_files, # function for retrieving input files (not images this time!)
    get_x=open_xarray, # function to extract input array from NetCDF file
    get_y=label_func, # function to locate label mask
    splitter=RandomSplitter(), # how to perform train/validation split
    batch_tfms=tfms, # data augmentation transforms to be applied
)

<div style="background-color: #ccffcc; padding: 10px;">

A `DataLoaders` object can now be created as before. This time we set the batch size to two: these images are higher resolution, so take up more VRAM per image.

</div>

In [None]:
# Construct a DataLoaders object from the DataBlock
dls = waves_ds.dataloaders(root/'vertical_velocities', path=root, bs=2) # batch size of two to conserve GPU memory

<div style="background-color: #ccffcc; padding: 10px;">

With the `show_batch` function we can check that the `DataLoaders` is working correctly. Note that some of the images in the training set do not contain lee waves.

</div>

In [None]:
# Show segmented images from sample batch (batch size is just two)
dls.show_batch()

<div style="background-color: #ccffcc; padding: 10px;">

The models in this example will take much longer to train than the model in the first example. We will therefore be saving trained models to disc. The following Python script creates a directory to store the trained models, if such a directory doesn't yet exist.

</div>

In [None]:
# Check if the 'models_out' directory exists
if not os.path.isdir('models_out'):
    print("Directory 'models_out' does not exist. Creating it now...")
    
    # Create the 'models_out' directory
    os.mkdir('models_out')
else:
    print("Directory 'models_out' already exists.")

<div style="background-color: #e6ccff; padding: 10px;">

We will use the same U-Net / ResNet-34 architecture as before. During training we will also monitor the [$F_1$ score](https://en.wikipedia.org/wiki/F-score) on the validation set, computed here with fastai's `DiceMulti` metric (the Dice coefficient is equivalent to $F_1$, and `DiceMulti` averages it over the classes). The $F_1$ score is the harmonic mean of the [precision and recall](https://en.wikipedia.org/wiki/Precision_and_recall) of a binary classifier. Values range from zero to one, with one indicating a perfect score.

</div>
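
The definition of the $F_1$ score can be made concrete with a short, framework-free sketch. The function `f1_score` and the two label sequences are hypothetical, the pixel values are simplified to 1 (lee wave) and 0 (no wave), and the function assumes at least one predicted and one true positive so that no division by zero occurs.

```python
def f1_score(predicted, target):
    """F1 score for two flat binary sequences (1 = lee wave, 0 = no wave)."""
    true_pos = sum(p == 1 and t == 1 for p, t in zip(predicted, target))
    false_pos = sum(p == 1 and t == 0 for p, t in zip(predicted, target))
    false_neg = sum(p == 0 and t == 1 for p, t in zip(predicted, target))
    precision = true_pos / (true_pos + false_pos)  # fraction of predicted waves that are real
    recall = true_pos / (true_pos + false_neg)     # fraction of real waves that were found
    return 2 * precision * recall / (precision + recall)  # harmonic mean

# Made-up flattened masks for eight pixels
predicted = [1, 1, 1, 0, 0, 0, 1, 0]
target = [1, 1, 0, 0, 0, 1, 1, 0]
print(f1_score(predicted, target))
```

Here three pixels are true positives, one is a false positive and one is a false negative, so precision and recall are both 0.75 and so is their harmonic mean.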

In [None]:
# Download fastai U-Net learner with ResNet-34 as the encoder and compute F1 score
learn2 = unet_learner(dls, resnet34, metrics=DiceMulti)

<div style="background-color: #ffffcc; padding: 10px;">

The learner object is now ready for training. Note that, unless you are using a very powerful GPU, the following code may take tens of minutes and possibly hours to run! The code is therefore commented out by default.

</div>

In [None]:
%%time
# Fine-tune the model until the validation loss stops improving
# learn2.fine_tune(100, cbs=EarlyStoppingCallback(monitor='valid_loss', patience=5))

<div style="background-color: #ccffcc; padding: 10px;">

Each fastai learner stores a filesystem path as one of its attributes. We must first change this path to our `models_out` directory before exporting `learn2` as a pickle file.

</div>

In [None]:
# Export trained model to 'models_out' directory
# learn2.path = Path('models_out')
# learn2.export('segmodel.pkl')

<div style="background-color: #ccffcc; padding: 10px;">

A quick comparison of the model predictions with ground truth should indicate whether training has been successful.

</div>

In [None]:
# Compare model output with ground truth
# learn2.show_results()

<div style="background-color: #ccffcc; padding: 10px;">

## Training alternative model heads

Using the segmentation model, we can identify where lee waves are present from the vertical velocity field. However, we might also be interested in the physical characteristics of those waves. For example:

* amplitude
* wavelength
* orientation

We could start from scratch and train a new model to predict each quantity, but a more efficient approach is to use transfer learning. Our segmentation model has already been trained to identify features related to lee waves. It therefore makes a perfect starting point to develop further models to predict wave characteristics. We can do this by training what are called alternative model heads, where head refers to the last few layers of the neural network. In general, the earlier layers of the network serve to extract relevant features from the image, while the head uses these features to predict a quantity of interest. New models can be cheaply obtained by simply training new heads on an old model.

Unfortunately, the NWP data that we used to train the segmentation model does not include measurements of any of the wave characteristics we are interested in. To get around this problem, [Coney et al. (2023)](https://doi.org/10.1002/qj.4592) generated their own synthetic lee wave data using Leif Denby's [synthetic-gravity-waves](https://doi.org/10.5281/zenodo.7576811) package for Python. The following two Python scripts download and extract these synthetic data.

</div>

In [None]:
url = "https://huggingface.co/datasets/CEMAC/synthetic_lee_waves/resolve/main/synthetic_data.zip"
filename = "synthetic_data.zip"

# Check if the file exists
if not os.path.isfile(filename):
    print(f"File '{filename}' does not exist. Downloading it now...")
    
    # Download the file using urllib
    urllib.request.urlretrieve(url, filename)
    
    print("File downloaded successfully.")
else:
    print(f"File '{filename}' already exists.")

In [None]:
# Check if the 'var_amp_synthetic' directory exists
if not os.path.isdir('var_amp_synthetic'):
    print("Directory 'var_amp_synthetic' does not exist. Creating it now...")
    
    # Create the 'var_amp_synthetic' directory
    os.mkdir('var_amp_synthetic')
    
    print("Extracting contents of 'synthetic_data.zip' into 'var_amp_synthetic' directory...")
    
    # Extract only the 'var_amp_synthetic' contents from the zip file
    with zipfile.ZipFile('synthetic_data.zip', 'r') as zip_ref:
        for member in zip_ref.namelist():
            if member.startswith('var_amp_synthetic/'):
                zip_ref.extract(member, '.')
    
    print("Extraction complete.")
else:
    print("Directory 'var_amp_synthetic' already exists.")

<div style="background-color: #ccffcc; padding: 10px;">

First, let's rebase our root path to the synthetic training data directory.

</div>

In [None]:
# Set root path to directory containing training data
root = Path('var_amp_synthetic/train')

<div style="background-color: #ccffcc; padding: 10px;">

Although we will not be using label masks for training our alternative model heads, we still need the labelling function to be defined in our global namespace. This is a peculiarity of the fastai library and is necessary for it to function properly.

</div>

In [None]:
# Function to retrieve label mask for a given input file
def label_func(fn): 
    string = str(fn.stem)[:49] + 'mask.png'
    return root/'masks_png'/string


<div style="background-color: #ccffcc; padding: 10px;">

The function to read input files is the same as before. 

</div>

In [None]:
# Function to extract vertical velocity array from NetCDF file
def open_xarray(fname):
    x = xr.open_dataarray(fname)
    array = x.values
    return array

<div style="background-color: #ccffcc; padding: 10px;">

The synthetic lee wave data are stored in NumPy format. We therefore need to define a new function to read these files and return them in the format expected by our model. We also add some Gaussian noise to the data each time they are read in, as a form of data augmentation.

</div>

In [None]:
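# Function to read in NumPy array file
# NOTE: `threshold` (the noise scale) is not an argument of this function;
# it is looked up as a global variable, so it must be defined before the data are loaded
def open_np(fname):
    x = np.load(fname)
    noise = np.random.normal(size=(512, 512)) # assumes 512 x 512 input arrays
    x = x + threshold*noise # add Gaussian noise
    x2 = np.array([x, x, x]) # copy data to three channels for input into ResNet-34
    return torch.Tensor(x2)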

<div style="background-color: #ccffcc; padding: 10px;">

The files containing the different wave characteristics are stored in appropriately named directories. We need to define a separate retrieval function for each.

</div>

In [None]:
# Function to retrieve corresponding wavelength labels
def label_func_wl(fn):
    string = str(fn.stem)[:49] + '.npy'
    lbl = np.load(root/'wavelength'/string).astype('float')
    return lbl/1000

In [None]:
# Function to retrieve corresponding amplitude labels
def label_func_amp(fn):
    string = str(fn.stem)[:49] + '.npy'
    lbl = np.load(root/'amplitude'/string).astype('float')
    return lbl

In [None]:
# Function to retrieve corresponding orientation labels
def label_func_or(fn):
    string = str(fn.stem)[:49] + '.npy'
    lbl = np.load(root/'orientation'/string).astype('float')
    lbl_rad = lbl*np.pi/180 # convert degrees to radians
    return np.array([np.sin(lbl_rad),np.cos(lbl_rad)])
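
<div style="background-color: #ccffcc; padding: 10px;">

To illustrate the sine and cosine encoding used for the orientation labels, here is a small round-trip sketch. The helper names `encode_orientation` and `decode_orientation` and the example angles are hypothetical. The decoding step mirrors the arctangent of the ratio that is applied to the model output later in this notebook. Because the ratio is unchanged when the angle is shifted by 180 degrees, the decoded value always lies between -90 and 90 degrees.

</div>

```python
import numpy as np

def encode_orientation(angle_deg):
    """Return [sin, cos] of an angle given in degrees."""
    angle_rad = np.deg2rad(angle_deg)
    return np.array([np.sin(angle_rad), np.cos(angle_rad)])

def decode_orientation(sin_cos):
    """Recover an angle in degrees, between -90 and 90, from [sin, cos]."""
    return np.rad2deg(np.arctan(sin_cos[0] / sin_cos[1]))

angles = np.array([-60.0, 0.0, 30.0, 75.0])    # hypothetical orientations in degrees
encoded = encode_orientation(angles)           # first row sines, second row cosines
decoded = decode_orientation(encoded)          # recovers the original angles

# An angle shifted by 180 degrees decodes to the same orientation
shifted = decode_orientation(encode_orientation(210.0))
```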

<div style="background-color: #ccffcc; padding: 10px;">

The procedure for training an alternative model head is sufficiently complicated to warrant its own function. To summarise briefly what this `train` function does:

1. Our pre-trained segmentation model learner is loaded from disc.
2. The `model` attribute is extracted (this contains the neural network itself).
3. The orientation model will return sines and cosines of the orientation angle, so the dimensions of the output layer can remain unchanged (for segmentation the network was returning two probabilities).
4. The amplitude and wavelength models will return a single number, so we need to redefine the last layers of the network to return one output.
5. We are now predicting real numbers, not probabilities, so the loss function is changed to Mean Squared Error (MSE).
6. A new `Learner` is constructed from the `DataLoaders`, the modified model and the new loss function.
7. Different learning rates are used for different layers.
8. Weights and biases are frozen in the earlier layers of the model, so that only the head is updated during training.
9. The trained model is saved to disc.

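We will call this function three times, once for each wave characteristic. Note that the amount of Gaussian noise added to the inputs is set by the global `threshold` variable read by `open_np` (this is a hyperparameter to tune); the `threshold` argument of `train` is only used to name the saved model file.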

Note that these models take even longer to train than the segmentation model, so be warned! The offending lines of code are commented out by default.

</div>

In [None]:
# Function to train alternative model to predict wave characteristic
def train(waves, characteristic, threshold, epochs=100):
    dls = waves.dataloaders(root/'data', path=root, bs=2)
    learn3 = load_learner('models_out/segmodel.pkl')
    model = learn3.model # extract segmentation model from learner object
    if characteristic != 'orientation':
        model.layers[-2] = nn.Sequential(
            torch.nn.Conv2d(99, 50, kernel_size=(1, 1), stride=(1, 1)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(50, 1, kernel_size=(1, 1), stride=(1, 1))
        ) # redefine last layer of model head to give one-dimensional output for amplitude or wavelength
    loss_func = MSELossFlat() # MSE loss function for real output
    learn3 = Learner(dls, model, loss_func=loss_func) # build new learner object
    base_lr = 1e-4
    print('lr', base_lr)
    lr_mult = 10 # multiplier for differential learning rates across layers
    learn3.unfreeze()
    learn3.freeze_to(-3) # only train model head (freeze parameters in earlier layers)
    learn3.fit_one_cycle(epochs, slice(base_lr/lr_mult, base_lr), cbs=EarlyStoppingCallback(monitor='valid_loss', patience=5))
    learn3.path = Path('models_out')
    learn3.export(characteristic + '_' + str(threshold) + '.pkl')
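
<div style="background-color: #ccffcc; padding: 10px;">

The head replacement and freezing steps can also be sketched in plain PyTorch, without fastai. The miniature two-part network below is hypothetical and far smaller than the U-Net used in this tutorial, and the Adam optimiser is an assumption made for the sketch. It only illustrates the idea: swap the final layers for ones with a single output channel, stop gradients flowing into the earlier layers, and train against an MSE loss.

</div>

```python
import torch
from torch import nn

# Hypothetical miniature network: a "body" of early layers followed by a two-channel "head"
body = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
head = nn.Conv2d(8, 2, kernel_size=1)
model = nn.Sequential(body, head)

# Swap the head for one that returns a single channel (e.g. amplitude or wavelength)
model[1] = nn.Sequential(
    nn.Conv2d(8, 4, kernel_size=1),
    nn.ReLU(),
    nn.Conv2d(4, 1, kernel_size=1),
)

# Freeze the body so that only the new head is updated during training
for param in model[0].parameters():
    param.requires_grad_(False)

loss_func = nn.MSELoss()  # regression loss for real-valued targets
optimiser = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-4
)

x = torch.randn(2, 3, 16, 16)        # batch of two three-channel inputs
target = torch.randn(2, 1, 16, 16)   # per-pixel regression targets
loss = loss_func(model(x), target)
loss.backward()                      # gradients are computed for the head only
optimiser.step()
```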

In [None]:
# Train amplitude model
threshold = 0.0625
waves = DataBlock(
    blocks=(DataBlock, DataBlock), # use DataBlocks for arrays
    get_items=get_files, # function for retrieving input files
    get_x=open_np, # function to read NumPy array input file
    get_y=label_func_amp, # function to locate amplitude labels
    splitter=RandomSplitter(), # how to perform train/validation split
    batch_tfms=[Normalize.from_stats(*imagenet_stats)], # normalize using mean and std devs from ImageNet dataset used to train ResNet-34
)
# train(waves, 'amplitude', threshold)

In [None]:
# Train wavelength model
threshold = 0.125
waves = DataBlock(
    blocks=(DataBlock, DataBlock), # use DataBlocks for arrays
    get_items=get_files, # function for retrieving input files
    get_x=open_np, # function to read NumPy array input file
    get_y=label_func_wl, # function to locate wavelength labels
    splitter=RandomSplitter(), # how to perform train/validation split
    batch_tfms=[Normalize.from_stats(*imagenet_stats)], # normalize using mean and std devs from ImageNet dataset used to train ResNet-34
)
# train(waves, 'wavelength', threshold)

In [None]:
# Train orientation model
threshold = 0.25
waves = DataBlock(
    blocks=(DataBlock, DataBlock), # use DataBlocks for arrays
    get_items=get_files, # function for retrieving input files
    get_x=open_np, # function to read NumPy array input file
    get_y=label_func_or, # function to locate orientation labels
    splitter=RandomSplitter(), # how to perform train/validation split
    batch_tfms=[Normalize.from_stats(*imagenet_stats)], # normalize using mean and std devs from ImageNet dataset used to train ResNet-34
)
# train(waves, 'orientation', threshold)

<div style="background-color: #ccffcc; padding: 10px;">

## Plotting the results

We have trained our U-Net segmentation model to identify trapped lee waves, as well as three alternative heads to predict wave characteristics. It is now time to visualise the results and see how our model performs on test data.

</div>

<div style="background-color: #cce5ff; padding: 10px;">

Since training these models is very time-consuming, it is recommended to download the pre-trained models using the Python script below.

</div>

In [None]:
# Create output directory if it doesn't exist
os.makedirs('models_out', exist_ok=True)

urls = [
    "https://huggingface.co/CEMAC/LeeWaveNet/resolve/main/segmodel.pkl",
    "https://huggingface.co/CEMAC/LeeWaveNet/resolve/main/amplitude_0.0625.pkl",
    "https://huggingface.co/CEMAC/LeeWaveNet/resolve/main/wavelength_0.125.pkl",
    "https://huggingface.co/CEMAC/LeeWaveNet/resolve/main/orientation_0.25.pkl"
]

for url in urls:
    filename = os.path.basename(url)
    filepath = os.path.join('models_out', filename)

    if not os.path.isfile(filepath):
        # If the file doesn't exist, download it
        urllib.request.urlretrieve(url, filepath)
        print(f"File {filename} downloaded successfully to models_out.")
    else:
        print(f"File {filename} already exists in models_out.")

<div style="background-color: #ccffcc; padding: 10px;">

It will be easier if we keep all our trained models in one place. Let's define a function to read them from disc and store them in a dictionary.

</div>

In [None]:
# Function to load the trained models
def load_models():
    learn2 = load_learner('models_out/segmodel.pkl')
    wavelength_model = load_learner('models_out/wavelength_0.125.pkl')
    orientation_model = load_learner('models_out/orientation_0.25.pkl')
    amplitude_model = load_learner('models_out/amplitude_0.0625.pkl')
    models_dict = {
        'segmentation': learn2,
        'wavelength': wavelength_model,
        'orientation': orientation_model,
        'amplitude': amplitude_model
    }
    return models_dict

<div style="background-color: #ccffcc; padding: 10px;">

Making predictions from the trained models is slightly tricky. Due to fastai peculiarities, the vertical velocity input data need to be supplied in slightly different formats to the various models: the segmentation model takes the two-dimensional array directly, whereas the other three models take a tensor with the array copied into three channels. We store each prediction as a `DataArray` in the input xarray `Dataset`, which is modified in place and returned. Attaching the coordinates as metadata will make plotting easier later on.

</div>

In [None]:
# Function to make predictions with trained models
def predict(models_dict, ds, xcoord='projection_x_coordinate', ycoord='projection_y_coordinate', mask_nonwaves=True):
    arr = ds['upward_air_velocity'].values # extract input array of air velocities
    
    segmentation = models_dict['segmentation'].predict(arr)[0].numpy()
    ds['segmentation'] = ((ycoord, xcoord), segmentation)
    
    wavelength = models_dict['wavelength'].predict(torch.Tensor(np.array([arr, arr, arr])))[0][0].numpy()
    ds['wavelength'] = ((ycoord, xcoord), wavelength)
    
    orient = models_dict['orientation'].predict(torch.Tensor(np.array([arr, arr, arr])))[0]
    orient = 180/np.pi * np.arctan(orient[0]/orient[1]) # convert sines and cosines into an angle in degrees
    ds['orientation'] = ((ycoord, xcoord), orient)
    
    amplitude = models_dict['amplitude'].predict(torch.Tensor(np.array([arr, arr, arr])))[0][0]
    ds['amplitude'] = ((ycoord, xcoord), amplitude)
    
    if mask_nonwaves:
        for char in ['amplitude','orientation','wavelength']:
            ds[char] = ds[char].where(ds['segmentation']==1) # remove characteristic values where no waves
    return ds
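
<div style="background-color: #ccffcc; padding: 10px;">

The masking step at the end of `predict` keeps a characteristic only where the segmentation mask equals 1 and fills every other pixel with NaN, which is the default behaviour of `where` in xarray. The NumPy sketch below reproduces the same idea on small hypothetical arrays.

</div>

```python
import numpy as np

segmentation = np.array([[0, 1],
                         [1, 0]])              # hypothetical mask: 1 = wave, 0 = no wave
amplitude = np.array([[0.4, 1.2],
                      [2.5, 0.1]])             # hypothetical per-pixel predictions

masked = np.where(segmentation == 1, amplitude, np.nan)  # NaN wherever no wave was detected
print(masked)
```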

<div style="background-color: #ccffcc; padding: 10px;">

The cell below defines some custom matplotlib colour maps that we will use for plotting our results.

</div>

In [None]:
# Define colour map for vertical velocities
vv_cmap = ListedColormap([(0/255,56/255,116/255),(0/255,101/255,206/255),(0/255,128/255,189/255),
                         (0/255,151/255,154/255),(0/255,174/255,119/255),(0/255,197/255,85/255),
                         (0/255,220/255,50/255),(0/255,244/255,16/255),(255/255,253/255,255/255),
                         (238/255,212/255,0/255),(225/255,182/255,0/255),(213/255,151/255,0/255),
                         (201/255,121/255,0/255),(189/255,91/255,0/255),(177/255,60/255,0/255),
                         (165/255,30/255,0/255),(153/255,0/255,0/255),])

# Define colour map for wave amplitudes
amp_cmap = ListedColormap([(255/255,253/255,255/255),(238/255,212/255,0/255),(238/255,212/255,0/255),
                         (225/255,182/255,0/255), (225/255,182/255,0/255),
                         (213/255,151/255,0/255), (213/255,151/255,0/255),
                         (201/255,121/255,0/255), (201/255,121/255,0/255),
                         (189/255,91/255,0/255),(189/255,91/255,0/255),(177/255,60/255,0/255),(177/255,60/255,0/255),
                         (165/255,30/255,0/255),(165/255,30/255,0/255),(153/255,0/255,0/255),(153/255,0/255,0/255),])
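
<div style="background-color: #ccffcc; padding: 10px;">

Each colour map above contains 17 colours. With a linear colour scale, the range between `vmin` and `vmax` is divided into 17 bins of equal width. The sketch below computes the bin edges for the limits that the plotting function uses later; the helper name `colour_bin_edges` is hypothetical. For vertical velocity the bins are 0.5 m/s wide and the white middle colour is centred on zero. For amplitude the bins are 0.25 m/s wide, so each repeated pair of colours covers 0.5 m/s.

</div>

```python
import numpy as np

def colour_bin_edges(vmin, vmax, n_colours):
    """Edges of the equal-width bins that a discrete colour map spans between vmin and vmax."""
    return np.linspace(vmin, vmax, n_colours + 1)

vv_edges = colour_bin_edges(-4.25, 4.25, 17)   # limits used later for vertical velocity
amp_edges = colour_bin_edges(0, 4.25, 17)      # limits used later for amplitude

print(vv_edges[8], vv_edges[9])                # edges of the middle (white) bin
print(amp_edges[:3])                           # first few amplitude bin edges
```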

<div style="background-color: #ccffcc; padding: 10px;">

The orientation model head outputs the sine and cosine of the wave orientation angle. In order to visualise these orientations as vectors, we will use the `quiver` method in matplotlib. This takes vector components as input rather than the angle, so we need to write a function to compute these vector components.

</div>

In [None]:
# Compute horizontal and vertical components of orientation vectors
def quiver_orient(dataset, sep=32, xcoord='projection_x_coordinate', ycoord='projection_y_coordinate'):
    angle_rad = dataset['orientation'].values*(np.pi/180) # convert degrees to radians
    
    new_x = np.zeros(int(512/sep)) # define coarser coordinate grid
    new_y = np.zeros(int(512/sep))
    angle_rad2 = np.zeros((int(512/sep), int(512/sep))) # initialize coarse matrix of orientation angles

    # Populate matrix by iterating over it and computing angles
    i = 0
    while i < len(angle_rad2):
        j = 0
        while j < len(angle_rad2[i]):
            new_angle_rad = np.pi/2 - angle_rad[sep*i][sep*j] # write orientation angles between 0 and pi
            angle_rad2[i][j] = new_angle_rad # store orientation angle in angle_rad2
            new_y[j] = dataset[ycoord][j*sep] # store y coordinate
            j = j + 1
        new_x[i] = dataset[xcoord][i*sep] # store x coordinate
        i = i + 1

    # Store coarse array in xarray Dataset
    alt_dataframe = xr.Dataset(
        data_vars={'angle_rad': ((ycoord, xcoord), angle_rad2)},
        coords = {xcoord: new_x, ycoord: new_y}
    )
    sf = 2 # scale factor to make arrows larger on plot

    # Compute horizontal and vertical components of scaled orientation vectors
    alt_dataframe['orient_u'] = ((ycoord, xcoord), sf*np.cos(angle_rad2))
    alt_dataframe['orient_v'] = ((ycoord, xcoord), sf*np.sin(angle_rad2))
    alt_dataframe['-orient_u'] = ((ycoord, xcoord), sf*-np.cos(angle_rad2))
    alt_dataframe['-orient_v'] = ((ycoord, xcoord), sf*-np.sin(angle_rad2))
    return alt_dataframe
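
<div style="background-color: #ccffcc; padding: 10px;">

The subsampling and angle conversion inside `quiver_orient` can be sketched more compactly with NumPy slicing. This is an illustration on a hypothetical constant field, assuming the grid size is a multiple of `sep`; it is not a drop-in replacement for the function above. Note that the cosine of (pi/2 - angle) is the sine of the angle, so the horizontal component is the sine of the orientation and the vertical component is its cosine.

</div>

```python
import numpy as np

sep = 32                                        # keep every 32nd grid point
angle_deg = np.full((512, 512), 30.0)           # hypothetical field of orientation angles
coarse = angle_deg[::sep, ::sep]                # strided slicing gives a 16 x 16 grid

theta = np.pi / 2 - np.deg2rad(coarse)          # same angle convention as quiver_orient
u, v = np.cos(theta), np.sin(theta)             # unit-vector components for quiver
print(coarse.shape)
```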

<div style="background-color: #ccffcc; padding: 10px;">

We are now ready to define our main plotting function. This takes an xarray `Dataset` containing the input and output of our models, and produces a four-panel figure showing the model predictions. Collecting the plotting code into a function like this reduces pollution of our global namespace.
</div>

In [None]:
# Function to plot the model predictions
def plot(ds, data='ukv'):
    if data == 'ukv':
        with open('data/projection/crs.pkl', 'rb') as projfile:
            proj = pickle.load(projfile) # import coordinate reference system for Met Office data
        xcoord = 'projection_x_coordinate'
        ycoord = 'projection_y_coordinate'
    if data == 'synthetic':
        proj = None # synthetic data on regular x-y grid
        xcoord = 'x'
        ycoord = 'y'
    
    fig = plt.figure(figsize=(13, 10), layout='constrained')
    ax1 = fig.add_subplot(221, projection=proj)
    ax2 = fig.add_subplot(222, projection=proj)
    ax3 = fig.add_subplot(223, projection=proj)
    ax4 = fig.add_subplot(224, projection=proj)
    
    ds['upward_air_velocity'].plot.pcolormesh(
        cmap=vv_cmap,
        robust=False,
        rasterized=True,
        ax=ax1,
        vmin=-4.25,
        vmax=4.25,
        add_colorbar=True,
        cbar_kwargs={
            'label':'Upward Air Velocity (m s $^{-1}$)',
            'shrink':0.6,
            'ticks':np.arange(-4,5,1),
            'extend':'neither'
        }
    )
    
    ds['segmentation'].plot.contour(
        cmap=ListedColormap(['black']),
        alpha=1,
        add_colorbar=False,
        ax=ax1
    )
    
    ds['wavelength'].plot.pcolormesh(
        cmap='viridis',
        alpha=1,
        add_colorbar=True,
        rasterized=True,
        ax=ax2,
        cbar_kwargs={
            'label':'Wavelength (km)',
            'shrink':0.6,
            'extend':'neither'
        }
    )
    
    ds['amplitude'].plot.pcolormesh(
        cmap=amp_cmap,
        vmin=0,
        vmax=4.25,
        robust=False,
        rasterized=True,
        ax=ax3,
        add_colorbar=True,
        cbar_kwargs={
            'label':'Amplitude Prediction (m s $^{-1}$)',
            'shrink':0.6,
            'ticks':np.arange(0,4.5,0.5),
            'extend':'neither'
        }
    )
            
    ds['upward_air_velocity'].plot.pcolormesh(
        cmap=vv_cmap,
        robust=False,
        rasterized=True,
        ax=ax4,
        vmin=-4.25,
        vmax=4.25,
        add_colorbar=True,
        alpha=1,
        cbar_kwargs={
            'label':'Upward Air Velocity (m s $^{-1}$)',
            'shrink':0.6,
            'ticks':np.arange(-4,5,1),
            'extend':'neither'
        }
    )
    
    headlength = 3
    headaxislength = 2 # set headlength and headaxislength to 0 to draw lines without arrowheads

    width = 0.004 # arrow width
    ds2 = quiver_orient(ds, sep=16, xcoord=xcoord, ycoord=ycoord) # dataset with horizontal and vertical components
    
    ds2.plot.quiver(
        xcoord, ycoord, 'orient_u', 'orient_v',
        ax=ax4,
        transform=proj,
        width=width,
        pivot='tail',
        headlength=headlength,
        headaxislength=headaxislength,
        add_guide=False
    )

    ds2.plot.quiver(
        xcoord, ycoord, '-orient_u', '-orient_v',
        ax=ax4,
        transform=proj,
        width=width,
        pivot='tail',
        headlength=headlength,
        headaxislength=headaxislength,
        add_guide=False
    ) # add arrows pointing in opposite direction
    
    ax1.set_title('700 hPa Vertical Velocity and segmentation mask')
    ax2.set_title('Wavelength')
    ax3.set_title('Amplitude')
    ax4.set_title('Orientation (perpendicular to wave fronts)')
    
    if proj is not None:
        for ax in [ax1, ax2, ax3, ax4]:
            ax.coastlines('10m', alpha=0.5)
    if data == 'ukv':
        forecast_time = str(ds['forecast_reference_time'].values)[:-10] + 'Z'
        fig.suptitle('Lee Wave Data: Characteristics Prediction ' + forecast_time)
    if data == 'synthetic':
        fig.suptitle('Synthetic Wave Characteristic Prediction', y=.93)
            
    plt.savefig('model_predictions.pdf', bbox_inches='tight')
    plt.show()

<div style="background-color: #ccffcc; padding: 10px;">

Finally, we are ready to plot our results!

We start by loading in test data from over the UK on the 14th of February 2021.

</div>

In [None]:
# Read in test data for plotting
leewaves = xr.open_dataset('data/test_feb/vertical_velocities/20210214T0900Z-PT0000H00M-wind_vertical_velocity_at_700hPa.nc')

<div style="background-color: #ccffcc; padding: 10px;">

Next, we read in our trained models and make our predictions.

</div>

In [None]:
# Load trained models
models = load_models()

In [None]:
# Make predictions with trained models
output = predict(models, leewaves)

<div style="background-color: #ccffcc; padding: 10px;">

Calling our plotting function displays the results here and also saves the figure to disc as a PDF.

</div>

In [None]:
# Plot the model predictions and save output to disc
plot(output, data='ukv')