In [3]:
import torch
#import torch.nn as nn
#import torch.nn.functional as F
#from torch.utils.data import Dataset,DataLoader
#import torchmetrics
#from torchviz import make_dot
#import mlflow
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LinearSegmentedColormap
import xarray as xr
import netCDF4
import math
import numpy as np
import os
import time
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from datetime import datetime, timedelta
from datasets import concatenate_datasets
import glob
import sys
sys.path.append('../scripts')
from unet import UNet

%alias_magic t time

Created `%t` as an alias for `%time`.
Created `%%t` as an alias for `%%time`.


In [4]:
# Load ICON forecast data (already normalized)
FCPATH      = "/hpc/uwork/ffundel/PREDICT_FRONTS/input/"
DATE        = "2025050500"
FILES       = sorted(glob.glob(FCPATH+DATE+'/'+'*.nc', recursive=True))
datasets    = [xr.open_dataset(f) for f in FILES]
combined_ds = xr.concat(datasets, dim='time')
Input       = torch.tensor(combined_ds['ICON'].values, dtype=torch.float)

In [5]:
# Load Model and Weightsm put in inference mode and run inference for an entire ICON forecast run
best_model = UNet(in_channels=6, out_channels=4, init_features=64)
best_model.load_state_dict(torch.load('/hpc/uhome/ffundel/repos/ai-fronts/model/best-model-parameters.pt'))
best_model.eval()
inference  = best_model(Input)

In [12]:
[print(param) for param in best_model.parameters()]

Parameter containing:
tensor([[[[-2.6396e-02, -1.2886e-01, -9.2089e-02],
          [ 3.7986e-02,  3.1866e-02,  1.2026e-01],
          [ 7.7582e-02,  4.8818e-02, -1.0995e-01]],

         [[-7.8685e-02,  5.4547e-02, -2.2077e-02],
          [ 7.4425e-02, -8.8641e-02,  1.2165e-01],
          [-8.0474e-02,  4.1449e-02,  1.5613e-03]],

         [[ 7.1563e-02,  4.1507e-02, -3.4370e-02],
          [-9.5409e-02,  7.1339e-02,  5.9933e-02],
          [ 7.5047e-02, -5.4454e-02, -8.9461e-02]],

         [[ 2.4930e-02, -5.8007e-02,  1.3190e-01],
          [ 1.2361e-01, -1.7684e-02, -1.1533e-01],
          [-4.4301e-02, -8.3635e-02,  4.0580e-02]],

         [[-9.5313e-02,  1.0993e-01, -5.8860e-02],
          [-1.0092e-01,  4.8323e-02, -1.3179e-01],
          [-9.0563e-02, -9.5444e-02, -1.3051e-01]],

         [[ 1.1341e-01,  1.1207e-01,  1.3479e-02],
          [-4.9974e-02, -3.9003e-02, -1.2391e-01],
          [-2.5611e-02, -1.1390e-01, -1.2838e-01]]],


        [[[ 6.4139e-02,  2.7462e-02,  2.2773e-

[None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None,
 None]

In [15]:
# Squeeze forecast to index (class) field with highest probability
fc          = np.array(inference.detach().numpy())
fc[:,3,:,:] = np.where(fc[:,3,:,:]>.95,fc[:,3,:,:],np.nan)
fc          = np.nanargmax(fc,axis=1)
fc          = np.where(fc==3, np.nan, fc)
fc[fc==0]   = inference.detach().numpy()[:,0,:,:][fc==0]
fc[fc==1]   = inference.detach().numpy()[:,1,:,:][fc==1]+1
fc[fc==2]   = inference.detach().numpy()[:,2,:,:][fc==2]+2

# Create a xarry with coordinated for fronts and the surface pressure
coords          = {'type':['fc','ps'],'case':range(57),'lat': combined_ds['lat'].values,'lon': combined_ds['lon'].values}
arr             = xr.DataArray([fc,combined_ds["ICON"].values[:,0,:,:]],coords)

In [17]:
# Plot loop over all forecast lead-times
for case in range(arr.shape[1]):

    dt = datetime.strptime(os.path.basename(os.path.dirname(FILES[case])), '%Y%m%d%H')
    vt = dt + timedelta(hours=case*3)
    print(vt)
    
    # Front data
    data         = arr[0,case,:,:]

    # MSLP data
    contour_data = (arr[1,case,:,:]*11.454)+1012.807 # see mslp normalization values in prepare_input.py

    # 2D lat lon grid for plotting
    Lon, Lat = np.meshgrid(arr['lon'].values, arr['lat'].values)

    # Colorscale
    colors = []

    # 10 (probability) classes of red
    white_to_red = LinearSegmentedColormap.from_list('white_to_red', ['white', 'red'])
    red_colors = white_to_red(np.linspace(0, 1, 10))
    colors.extend(red_colors)

    # 10 (probability) classes of blue
    white_to_blue = LinearSegmentedColormap.from_list('white_to_blue', ['white', 'blue'])
    blue_colors = white_to_blue(np.linspace(0, 1, 10))
    colors.extend(blue_colors)

    # 10 (probability) classes of pink
    white_to_pink = LinearSegmentedColormap.from_list('white_to_pink', ['white', 'blueviolet'])
    pink_colors =  white_to_pink(np.linspace(0, 1, 10))
    colors.extend(pink_colors)

    # Colormap
    custom_cmap = mcolors.ListedColormap(colors)

    # 31 class borders for 30 classes
    bounds = np.linspace(0, 3, 31)
    norm = mcolors.BoundaryNorm(bounds, custom_cmap.N)

    # Create plot
    fig = plt.figure(figsize=(10, 8))
    ax = plt.axes(projection=ccrs.Orthographic(0, 35))

    # Filled land masses
    land_feature = cfeature.NaturalEarthFeature(
        category='physical',
        name='land',
        scale='110m',
        facecolor='lightgray'
    )
    ax.add_feature(land_feature,zorder=0)
    
    # Coastlines and Grids
    ax.coastlines(zorder=0,color='gray')
    ax.gridlines()
    
    # Plot front data
    mesh = ax.pcolormesh(Lon, Lat, data, cmap=custom_cmap, norm=norm, shading='auto', transform=ccrs.PlateCarree())
    
    # Add colorbar
    cbar = plt.colorbar(mesh, boundaries=bounds, ticks=np.linspace(0.5, 2.5, 3))
    cbar.ax.set_yticklabels(['warm', 'cold', 'occ.'])
    cbar.set_label('Klassen')
    cbar.ax.tick_params(length=0)
    
    # Add MSLP contour lines with labels
    contour_levels = np.arange(900, 1100, 5)
    contour_lines  = ax.contour(Lon, Lat, contour_data, levels=contour_levels, colors='black', linewidths=1, transform=ccrs.PlateCarree(), zorder=3)
    ax.clabel(contour_lines, inline=True, fontsize=8)
    
    # Labels plot
    plt.title('ICON Fronts (AI)',loc="left")
    ax.set_title('',loc="center")
    ax.set_title( "Ini: "+dt.strftime('%Y-%m-%d %H UTC')+" + "+str(case*3).zfill(3)+"h\n"
                  "Val: "+vt.strftime('%a %Y-%m-%d %H UTC'),loc="right")
    
    # Save and close plot
    plt.savefig('/hpc/uhome/ffundel/repos/ai-fronts/plots/plot_'+str(case).zfill(3)+'.png')
    plt.close(fig)



2025-05-05 00:00:00
2025-05-05 03:00:00
2025-05-05 06:00:00
2025-05-05 09:00:00
2025-05-05 12:00:00
2025-05-05 15:00:00
2025-05-05 18:00:00
2025-05-05 21:00:00
2025-05-06 00:00:00
2025-05-06 03:00:00
2025-05-06 06:00:00
2025-05-06 09:00:00
2025-05-06 12:00:00
2025-05-06 15:00:00
2025-05-06 18:00:00
2025-05-06 21:00:00
2025-05-07 00:00:00
2025-05-07 03:00:00
2025-05-07 06:00:00
2025-05-07 09:00:00
2025-05-07 12:00:00
2025-05-07 15:00:00
2025-05-07 18:00:00
2025-05-07 21:00:00
2025-05-08 00:00:00
2025-05-08 03:00:00
2025-05-08 06:00:00
2025-05-08 09:00:00
2025-05-08 12:00:00
2025-05-08 15:00:00
2025-05-08 18:00:00
2025-05-08 21:00:00
2025-05-09 00:00:00
2025-05-09 03:00:00
2025-05-09 06:00:00
2025-05-09 09:00:00
2025-05-09 12:00:00
2025-05-09 15:00:00
2025-05-09 18:00:00
2025-05-09 21:00:00
2025-05-10 00:00:00
2025-05-10 03:00:00
2025-05-10 06:00:00
2025-05-10 09:00:00
2025-05-10 12:00:00
2025-05-10 15:00:00
2025-05-10 18:00:00
2025-05-10 21:00:00
2025-05-11 00:00:00
2025-05-11 03:00:00
