In [1]:
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
import pandas as pd
import sys
from datetime import datetime, timedelta
import polars as pl
from sqlalchemy.orm import Session
from tqdm import tqdm
import geopandas as gpd
from pyproj import Transformer
from pprint import pprint

from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.collections import LineCollection
from matplotlib import colors, cm, gridspec, ticker, patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib import dates as mdates

import matplotlib.colors as mcolors

import string
alphabet = string.ascii_lowercase

%reload_ext autoreload
%autoreload 2
%matplotlib inline

sys.path.append("..")
sys.path.append("../..")

from utils import graph_utils
from utils import plot_utils, data_utils
from database import get_engine
from utils.config_utils import load_config
from utils.plot_utils import _get_colormap, COLORBAR_TITLES, QTY_FORMATS
import database.cell_queries as cell_queries

# matplotlib use style
plt.style.use("../stylefiles/article_casefigs.mplstyle")

# Configurations & data loading


In [2]:
# Qualitative colour palette (hex codes). Previously commented-out alternatives:
# "#e9a3c9", "#a1d76a", "#66c2a5", "#fc8d62", "#8da0cb".
qualitative_colors = ["#e7298a", "#66a61e", "#76CD74", "#8F1A65", "#386cb0"]

# Keyword presets for line plots; the `errorbar` key follows seaborn.lineplot's API.
LINEPLOT_KWS_CI_95 = dict(errorbar=("ci", 95))  # draw a 95 % confidence-interval band
LINEPLOT_KWS_NO_CI = {"errorbar": None, "lw": 1.5, "alpha": 1.0}  # plain solid line, no band

# Style of the dashed black reference line labelled "Trajectory count".
count_linekwargs = dict(ls="--", lw=1, color="black", label="Trajectory count")

In [3]:
# Cell trajectories derived from the cell graphs, stored as parquet files matching `fileglob`.
# storagepath = Path("/data/jenna/rain-cell-stats-analysis/cell_graph_data_debug")
storagepath = Path("/data/jenna/rain-cell-stats-analysis/cell_graph_data_v20250827")
fileglob = "cell_trajectories_*_*.parquet"

# Output directory for all case figures produced by this notebook.
outpath = Path("/data/jenna/rain-cell-stats-analysis/figures/case_figures_v20250919")
outpath.mkdir(parents=True, exist_ok=True)

# Load the data (polars expands the glob pattern in the path)
TRAJECTORIES = pl.read_parquet(storagepath / fileglob)

# List the available trajectory types. `unique()` does not keep a stable order,
# so sort to get the same listing after every kernel restart.
pprint(TRAJECTORIES.get_column("type").unique().sort().to_list())

TRAJECTORIES

['merged',
 'split',
 'split-merge',
 'et45ml_unique_d1.median:1000:control',
 'merged_prunet0_prunedup',
 'split_prunet0_prunedup',
 'split_prunedup',
 'zdrcol_custom_filt_unique_d1.median:1000',
 'merged_prunedup',
 'zdrcol_custom_filt_unique_d1.median:1000:control',
 'split-merge_prunet0_prunedup',
 'split-merge_prunedup',
 'split_prunet0',
 'merged_prunet0',
 'et45ml_unique_d1.median:1000',
 'split-merge_prunet0']


method,type,identifier,t0_node,timestamp,level,area,event,num_cells_at_level,t0_time
str,str,i64,str,datetime[μs],i32,f64,str,u32,datetime[μs]
"""opencv_vil_1.0:minArea_10:clus…","""et45ml_unique_d1.median:1000:c…",1,"""2021-05-01T11:45:00_1""",2021-05-01 11:45:00,0,12.503103,"""t0_node.""",1,2021-05-01 11:45:00
"""opencv_vil_1.0:minArea_10:clus…","""et45ml_unique_d1.median:1000:c…",1,"""2021-05-01T11:45:00_1""",2021-05-01 11:50:00,1,29.507324,"""simple""",1,2021-05-01 11:45:00
"""opencv_vil_1.0:minArea_10:clus…","""et45ml_unique_d1.median:1000:c…",1,"""2021-05-01T11:45:00_1""",2021-05-01 11:55:00,2,46.511544,"""simple""",1,2021-05-01 11:45:00
"""opencv_vil_1.0:minArea_10:clus…","""et45ml_unique_d1.median:1000:c…",1,"""2021-05-01T11:45:00_1""",2021-05-01 12:00:00,3,68.517006,"""simple""",1,2021-05-01 11:45:00
"""opencv_vil_1.0:minArea_10:clus…","""et45ml_unique_d1.median:1000:c…",1,"""2021-05-01T11:45:00_1""",2021-05-01 12:05:00,4,42.510551,"""simple""",1,2021-05-01 11:45:00
…,…,…,…,…,…,…,…,…,…
"""opencv_vil_1.0:minArea_10:clus…","""merged_prunet0_prunedup""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:05:00,2,100.52495,"""simple""",1,2023-09-23 18:55:00
"""opencv_vil_1.0:minArea_10:clus…","""merged_prunet0_prunedup""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:10:00,3,142.035253,"""simple""",1,2023-09-23 18:55:00
"""opencv_vil_1.0:minArea_10:clus…","""merged_prunet0_prunedup""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:15:00,4,165.541087,"""simple""",1,2023-09-23 18:55:00
"""opencv_vil_1.0:minArea_10:clus…","""merged_prunet0_prunedup""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:20:00,5,153.538108,"""simple""",1,2023-09-23 18:55:00


In [4]:
# Trajectory types, sorted so the listing is identical between runs
# (plain `unique()` returns them in an arbitrary order).
pprint(TRAJECTORIES.get_column("type").unique().sort().to_list())

['split_prunedup',
 'zdrcol_custom_filt_unique_d1.median:1000:control',
 'split_prunet0',
 'split-merge_prunet0',
 'split-merge_prunet0_prunedup',
 'et45ml_unique_d1.median:1000:control',
 'zdrcol_custom_filt_unique_d1.median:1000',
 'split-merge_prunedup',
 'merged_prunedup',
 'split_prunet0_prunedup',
 'merged_prunet0',
 'split',
 'merged_prunet0_prunedup',
 'merged',
 'split-merge',
 'et45ml_unique_d1.median:1000']


In [5]:
# Database connection: settings are read from
# <notebook parent dir>/config/database/database.yaml and turned into an engine.
dbconf_file = Path(".").resolve().parent.joinpath("config", "database", "database.yaml")
dbconf = load_config(dbconf_file)
engine = get_engine(dbconf)

In [6]:
# Raster statistics to fetch for every storm cell, and the time window of the query.
FETCH_QUANTITIES = ["vil", "rate", "zdrcol_custom_filt_unique_d1", "et45ml_unique_d1", "dist_from_radars"]
START_DATE = datetime(2021, 5, 1)
END_DATE = datetime(2023, 10, 1)

# Long-format query: one row per (cell, quantity, statistic) with the cell area in km^2.
# Placeholders are filled with str.format from the notebook constants above (trusted values).
SQL_QUERY = """

 SELECT cell."timestamp",
    cell.identifier,
    cell.method,
    -- cell.geometry,
    ST_Area(cell.geometry) / 1e6 AS area_km2,
    stats.quantity,
    stats.statistic,
    stats.value
   FROM raincells.stormcells cell
     JOIN raincells.stormcell_rasterstats stats ON cell.identifier = stats.identifier AND cell."timestamp" = stats."timestamp" AND cell.method::text = stats.method::text

    WHERE 1 = 1
    AND cell."timestamp" >= '{start_date}'
    AND cell."timestamp" <= '{end_date}'
    AND stats.quantity IN ({quantities})

"""

with Session(engine) as session:
    query = SQL_QUERY.format(
        start_date=START_DATE.strftime("%Y-%m-%d %H:%M"),
        end_date=END_DATE.strftime("%Y-%m-%d %H:%M"),
        quantities=",".join(f"'{q}'" for q in FETCH_QUANTITIES),
    )
    DATA = pl.read_database(query=query, connection=session.bind)

# Pivot to wide format: one column per "<quantity>.<statistic>" (e.g. "rate.sum").
# The index is an ordered list rather than a set: set iteration order of strings depends
# on Python's per-process hash seed, which made DATA's column order change between runs.
DATA = DATA.with_columns(pl.format("{}.{}", "quantity", "statistic").alias("on")).pivot(
    on="on",
    index=[c for c in DATA.columns if c not in ("quantity", "statistic", "value")],
    values=["value"],
)

In [7]:
# Case studies, identified by their t0 node ("<ISO timestamp>_<cell identifier>").
CASE_T0_NODES = [
    "2022-06-07T17:35:00_3",
    "2021-07-13T09:35:00_12",
]

# Trajectory types kept for the case studies. The merge type is stored as "merged"
# (see the type listing above); the previous value "merge" matched nothing and
# silently dropped the merged trajectories.
acceptable_types = ["split", "merged", "split-merge"]
unknown_types = set(acceptable_types) - set(TRAJECTORIES.get_column("type").unique().to_list())
assert not unknown_types, f"Unknown trajectory types: {sorted(unknown_types)}"
pprint(CASE_T0_NODES)

CASE_TRAJECTORIES = TRAJECTORIES.filter(
    pl.col("t0_node").is_in(CASE_T0_NODES) & pl.col("type").is_in(acceptable_types)
)

# Calendar days covered by the case trajectories.
data_days = CASE_TRAJECTORIES.select(pl.col("timestamp").dt.date().alias("date")).unique().to_series().to_list()

['2022-06-07T17:35:00_3', '2021-07-13T09:35:00_12']


In [8]:
# Split-merge trajectory of the 2022-06-07 case, ordered by level and cell identifier.
(
    TRAJECTORIES.filter(
        pl.col("t0_node") == "2022-06-07T17:35:00_3",
        pl.col("type") == "split-merge",
    )
    .unique()
    .sort("level", "identifier")
)

method,type,identifier,t0_node,timestamp,level,area,event,num_cells_at_level,t0_time
str,str,i64,str,datetime[μs],i32,f64,str,u32,datetime[μs]
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",3,"""2022-06-07T17:35:00_3""",2022-06-07 17:30:00,-1,14.003476,"""born.""",2,2022-06-07 17:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",4,"""2022-06-07T17:35:00_3""",2022-06-07 17:30:00,-1,119.029543,"""simple.""",2,2022-06-07 17:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",3,"""2022-06-07T17:35:00_3""",2022-06-07 17:35:00,0,169.041956,"""merged_from_2.split-merge_midn…",1,2022-06-07 17:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",3,"""2022-06-07T17:35:00_3""",2022-06-07 17:40:00,1,111.527681,"""simple""",2,2022-06-07 17:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",4,"""2022-06-07T17:35:00_3""",2022-06-07 17:40:00,1,101.525199,"""simple""",2,2022-06-07 17:35:00
…,…,…,…,…,…,…,…,…,…
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",2,"""2022-06-07T17:35:00_3""",2022-06-07 17:55:00,4,10.002483,"""died.""",3,2022-06-07 17:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",3,"""2022-06-07T17:35:00_3""",2022-06-07 17:55:00,4,31.007696,"""died.""",3,2022-06-07 17:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",4,"""2022-06-07T17:35:00_3""",2022-06-07 17:55:00,4,160.039722,"""simple""",3,2022-06-07 17:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",3,"""2022-06-07T17:35:00_3""",2022-06-07 18:00:00,5,177.544067,"""simple""",1,2022-06-07 17:35:00


In [9]:
# Split-merge trajectory of the 2021-07-13 case, ordered by level and cell identifier.
(
    TRAJECTORIES.filter(
        pl.col("t0_node") == "2021-07-13T09:35:00_12",
        pl.col("type") == "split-merge",
    )
    .unique()
    .sort("level", "identifier")
)

method,type,identifier,t0_node,timestamp,level,area,event,num_cells_at_level,t0_time
str,str,i64,str,datetime[μs],i32,f64,str,u32,datetime[μs]
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",5,"""2021-07-13T09:35:00_12""",2021-07-13 09:15:00,-4,120.029791,"""simple.""",2,2021-07-13 09:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",6,"""2021-07-13T09:35:00_12""",2021-07-13 09:15:00,-4,10.002483,"""born.""",2,2021-07-13 09:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",7,"""2021-07-13T09:35:00_12""",2021-07-13 09:20:00,-3,176.543818,"""simple.""",2,2021-07-13 09:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",8,"""2021-07-13T09:35:00_12""",2021-07-13 09:20:00,-3,43.010675,"""simple.""",2,2021-07-13 09:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",8,"""2021-07-13T09:35:00_12""",2021-07-13 09:25:00,-2,223.055363,"""simple.""",2,2021-07-13 09:35:00
…,…,…,…,…,…,…,…,…,…
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",22,"""2021-07-13T09:35:00_12""",2021-07-13 10:00:00,5,69.01713,"""simple""",4,2021-07-13 09:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",15,"""2021-07-13T09:35:00_12""",2021-07-13 10:05:00,6,163.540591,"""simple""",4,2021-07-13 09:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",17,"""2021-07-13T09:35:00_12""",2021-07-13 10:05:00,6,11.502855,"""simple""",4,2021-07-13 09:35:00
"""opencv_vil_1.0:minArea_10:clus…","""split-merge""",18,"""2021-07-13T09:35:00_12""",2021-07-13 10:05:00,6,18.504593,"""died.""",4,2021-07-13 09:35:00


In [10]:
# Differences of summed area and volume rain rate relative to the t0 cell(s).
trajectory_diffs = graph_utils.calculate_diff_variables(
    trajectories=TRAJECTORIES,
    data=DATA,
    sum_vars=["area", "rate.sum"],
)

pprint(trajectory_diffs.columns)

# Relative change of the area-averaged rain rate (rate.sum / area) with respect to t0,
# as a fraction and in percent.
trajectory_diffs = (
    trajectory_diffs
    .with_columns(
        (
            (pl.col("rate.sum") / pl.col("area"))
            / (pl.col("t0_rate.sum") / pl.col("t0_area"))
            - 1
        ).alias("t0_reldiff_mean:rate.sum")
    )
    .with_columns(
        pl.col("t0_reldiff_mean:rate.sum").mul(100).alias("t0_reldiff_mean:rate.sum:pct")
    )
)
trajectory_diffs

['type',
 'timestamp',
 't0_node',
 'level',
 'num_cells_at_level',
 'rate.sum',
 'area',
 't0_reldiff:rate.sum',
 't0_reldiff:area',
 't0_absdiff:rate.sum',
 't0_absdiff:area',
 't0_rate.sum',
 't0_area',
 't0_reldiff:rate.sum:pct',
 't0_reldiff:area:pct']


type,timestamp,t0_node,level,num_cells_at_level,rate.sum,area,t0_reldiff:rate.sum,t0_reldiff:area,t0_absdiff:rate.sum,t0_absdiff:area,t0_rate.sum,t0_area,t0_reldiff:rate.sum:pct,t0_reldiff:area:pct,t0_reldiff_mean:rate.sum,t0_reldiff_mean:rate.sum:pct
str,datetime[μs],str,i32,u32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""split-merge_prunet0""",2023-08-15 16:35:00,"""2023-08-15T16:45:00_17""",-2,3,10265.239868,291.572369,-0.013215,-0.08189,-137.470093,-26.006455,10402.709961,317.578823,-1.321483,-8.188976,0.0748,7.480031
"""zdrcol_custom_filt_unique_d1.m…",2021-05-12 18:35:00,"""2021-05-12T18:05:00_7""",6,1,655.388428,25.006207,1.419326,0.5625,384.491333,9.002234,270.897095,16.003972,141.932616,56.25,0.548369,54.836874
"""et45ml_unique_d1.median:1000""",2022-09-09 03:05:00,"""2022-09-09T03:00:00_6""",1,1,667.970154,23.505834,-0.114886,-0.096154,-86.700745,-2.500621,754.670898,26.006455,-11.48855,-9.615385,-0.020724,-2.072438
"""et45ml_unique_d1.median:1000""",2022-06-05 14:55:00,"""2022-06-05T14:45:00_19""",2,1,2241.599121,50.512537,-0.079616,0.030612,-193.903809,1.500372,2435.50293,49.012165,-7.961551,3.061224,-0.106954,-10.695366
"""merged_prunedup""",2022-08-17 04:15:00,"""2022-08-17T04:10:00_4""",1,1,5198.642578,531.631952,0.133494,0.256501,612.253418,108.526936,4586.38916,423.105015,13.349356,25.650118,-0.097897,-9.789694
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""split_prunet0""",2022-06-27 16:35:00,"""2022-06-27T16:25:00_7""",2,2,2191.763306,118.029295,-0.273612,-0.208054,-825.580933,-31.007696,3017.344238,149.036991,-27.361178,-20.805369,-0.082781,-8.278098
"""et45ml_unique_d1.median:1000""",2021-08-15 22:20:00,"""2021-08-15T21:50:00_4""",6,1,2480.366699,40.510055,0.654109,0.76087,980.848389,17.504345,1499.518311,23.00571,65.410898,76.086957,-0.060629,-6.062947
"""zdrcol_custom_filt_unique_d1.m…",2023-06-04 20:40:00,"""2023-06-04T21:10:00_11""",-6,1,1223.079956,34.008441,-0.341836,-0.552632,-635.240112,-42.010427,1858.320068,76.018868,-34.183568,-55.263158,0.471191,47.119082
"""merged_prunet0_prunedup""",2023-06-12 14:45:00,"""2023-06-12T14:30:00_28""",3,1,13384.720703,331.582299,-0.221263,-0.311526,-3803.0,-150.037239,17187.720703,481.619538,-22.126261,-31.152648,0.131107,13.110724


# Update display IDs for case studies


In [11]:
# For each case, map every node of its saved subgraph to the node's cell identifier.
# This listing is the reference for the manual display-ID table below.
nodedict = {}
storage_path = Path("/data/jenna/rain-cell-stats-analysis/cell_graph_data_v20250827")
for t0_node in CASE_T0_NODES:
    # Subgraphs are stored per day: <storage_path>/subgraphs/<YYYYMMDD>/subgraph_<t0_node>.gml
    t0_time = pd.to_datetime(t0_node.split("_")[0])
    subpath = t0_time.strftime("%Y%m%d")
    graph_path = storage_path / "subgraphs" / subpath

    graph = nx.read_gml(graph_path / f"subgraph_{t0_node}.gml", destringizer=graph_utils.string_decoder)

    nodes = sorted(graph.nodes)
    nodedict[t0_node] = {node: graph.nodes[node]["identifier"] for node in nodes}
pprint(nodedict)

{'2021-07-13T09:35:00_12': {'2021-07-13T09:15:00_5': 5,
                            '2021-07-13T09:15:00_6': 6,
                            '2021-07-13T09:20:00_7': 7,
                            '2021-07-13T09:20:00_8': 8,
                            '2021-07-13T09:25:00_11': 11,
                            '2021-07-13T09:25:00_8': 8,
                            '2021-07-13T09:30:00_7': 7,
                            '2021-07-13T09:30:00_8': 8,
                            '2021-07-13T09:35:00_12': 12,
                            '2021-07-13T09:40:00_14': 14,
                            '2021-07-13T09:40:00_15': 15,
                            '2021-07-13T09:45:00_16': 16,
                            '2021-07-13T09:45:00_18': 18,
                            '2021-07-13T09:50:00_15': 15,
                            '2021-07-13T09:50:00_17': 17,
                            '2021-07-13T09:55:00_17': 17,
                            '2021-07-13T09:55:00_18': 18,
                            

In [12]:
# Manually assigned display IDs for the case-study figures:
#   {case t0 node: {graph node "<ISO timestamp>_<cell identifier>": ID shown on the plot}}
# The plotting loop copies these onto the graph nodes and onto the cell data as
# "display_id" and uses that key to annotate the cells. Graph nodes not listed here
# get display_id None.
# NOTE(review): IDs are reused across timesteps, presumably so the same number follows
# a cell (or its continuation) through time; confirm against the figures.
display_ids = {
    "2021-07-13T09:35:00_12": {
        "2021-07-13T09:15:00_5": 1,
        "2021-07-13T09:15:00_6": 2,
        "2021-07-13T09:20:00_7": 1,
        "2021-07-13T09:20:00_8": 2,
        "2021-07-13T09:25:00_8": 1,
        "2021-07-13T09:25:00_11": 2,
        "2021-07-13T09:30:00_7": 1,
        "2021-07-13T09:30:00_8": 2,
        "2021-07-13T09:35:00_12": 1,
        "2021-07-13T09:40:00_14": 1,
        "2021-07-13T09:40:00_15": 2,
        "2021-07-13T09:45:00_16": 1,
        "2021-07-13T09:45:00_18": 2,
        "2021-07-13T09:50:00_15": 1,
        "2021-07-13T09:50:00_17": 2,
        "2021-07-13T09:55:00_17": 3,
        "2021-07-13T09:55:00_18": 1,
        "2021-07-13T09:55:00_21": 2,
        "2021-07-13T09:55:00_23": 4,
        "2021-07-13T10:00:00_18": 3,
        "2021-07-13T10:00:00_20": 1,
        "2021-07-13T10:00:00_21": 2,
        "2021-07-13T10:00:00_22": 4,
        "2021-07-13T10:05:00_15": 3,
        "2021-07-13T10:05:00_17": 1,
        "2021-07-13T10:05:00_18": 2,
        "2021-07-13T10:05:00_19": 4,
    },
    "2022-06-07T17:35:00_3": {
        "2022-06-07T17:30:00_3": 1,
        "2022-06-07T17:30:00_4": 2,
        "2022-06-07T17:35:00_3": 1,
        "2022-06-07T17:40:00_3": 1,
        "2022-06-07T17:40:00_4": 2,
        "2022-06-07T17:45:00_2": 1,
        "2022-06-07T17:45:00_3": 2,
        "2022-06-07T17:50:00_3": 1,
        "2022-06-07T17:50:00_4": 3,
        "2022-06-07T17:50:00_5": 2,
        "2022-06-07T17:55:00_2": 1,
        "2022-06-07T17:55:00_3": 3,
        "2022-06-07T17:55:00_4": 2,
        "2022-06-07T18:00:00_3": 2,
        "2022-06-07T18:05:00_2": 2,
    },
}

# Plot case figures


In [13]:
# Path to saved subgraphs
storage_path = Path("/data/jenna/rain-cell-stats-analysis/cell_graph_data_v20250827/")

# Cell statistics ("<quantity>.<statistic>") fetched for the case-study cells.
fetch_variables = [
    "vil.sum",
    "vil.median",
    "et45ml_unique_d1.max",
    "et45ml_unique_d1.median",
    "rate.sum",
    "rate.mean",
    "zdrcol_custom_filt_unique_d1.max",
    "zdrcol_custom_filt_unique_d1.median",
]

# Statistic quantity -> colormap quantity name.
cmap_qtys = dict(
    area="area",
    rate="RATE",
    et45ml_unique_d1="ET45ML",
    vil="VIL",
    zdrcol_custom_filt_unique_d1="zdrcol_custom_filt",
)

# Variables shown as time series, and their axis titles.
plot_vars_timeseries = ["area", "rate.mean", "rate.sum", "zdrcol_custom_filt_unique_d1.median"]
plot_vars_titles = {
    "area": "Cell area [km$^2$]",
    "rate.mean": "Mean rain rate [mm h$^{-1}$]",
    "rate.sum": "Volume rain rate [10$^6$ m$^3$ h$^{-1}$]",
    "zdrcol_custom_filt_unique_d1.median": "Median ZDR column height [km]",
}

# Distinct quantities and statistics needed to build `fetch_variables`.
fetch_quantities = {name.split(".")[0] for name in fetch_variables}
fetch_stats = {name.split(".")[1] for name in fetch_variables}

# Rain-rate colormap preset.
qty = "R_log_high"
cmap, norm, ticks = _get_colormap(qty)
color_var = "rate.mean"

# Plot configuration (map layers, input data, panel sizes); figures are written to `outpath`.
plot_conf_file = Path(".").resolve().parent / "config/plots/swiss-data/plot_data.yml"
plot_conf = load_config(plot_conf_file)
plot_conf["outdir"] = outpath

# Swiss oblique Mercator projection of the radar grid; `transformer` maps it to lon/lat.
# NOTE(review): x_0=2600000 / y_0=1200000 are the CH1903+/LV95 (EPSG:2056) false origins,
# while `epsg` below is 21781 (CH1903/LV03, false origin 600000/200000); confirm which is intended.
data_proj4 = "+proj=somerc +lat_0=46.95240555555556 +lon_0=7.439583333333333 +k_0=1 +x_0=2600000 +y_0=1200000 +ellps=bessel +towgs84=674.374,15.056,405.346,0,0,0,0 +units=m +no_defs"
epsg = 21781
transformer = Transformer.from_crs(data_proj4, "EPSG:4326", always_xy=True)

# Time window (floor/ceil around t0) for the all-cells overview map.
plot_cells_overview_hours = "2h"

# Zoom margins in km around each case's cells: (left, right, bottom, top).
# ("bumber" is kept as the name because the plotting code passes it as `bumber_km`.)
bumber_kms = {
    "2021-07-13T09:35:00_12": (20, 20, 5, 5),
    "2022-06-07T17:35:00_3": (20, 20, 20, 20),
}

In [None]:
# Input fields drawn as the columns of the snapshot figure (one row per timestep).
data_variables = ["VIL", "RATE", "zdrcol_custom_filt"]
ncols = len(data_variables)

# Per-case `sparse_factor` passed to graph_utils.plot_graph_snapshots.
# TODO(review): document what the factor thins out; it is defined in graph_utils.
sparse_factors = dict()
sparse_factors["2021-07-13T09:35:00_12"] = 6
sparse_factors["2022-06-07T17:35:00_3"] = 10

for graph_name in CASE_T0_NODES[::-1]:
    outpath_ = outpath / graph_name
    outpath_.mkdir(exist_ok=True, parents=True)

    t0_time = pd.to_datetime(graph_name.split("_")[0])
    subpath = t0_time.strftime("%Y%m%d")
    graph_path = storage_path / "subgraphs" / subpath

    graph = nx.read_gml(graph_path / f"subgraph_{graph_name}.gml", destringizer=graph_utils.string_decoder)

    timestamps = [n["timestamp"] for _, n in graph.nodes(data=True)]
    timestamps = pd.to_datetime(timestamps)
    timestamps = pd.Series(timestamps).sort_values().drop_duplicates()

    xmin = timestamps.min()
    xmax = timestamps.max()

    num_timesteps = len(timestamps)
    num_time_resolutions = int((xmax - xmin) / pd.Timedelta("5min"))
    len_history = int((t0_time - xmin) / pd.Timedelta("5min"))
    len_future = int((xmax - t0_time) / pd.Timedelta("5min"))

    # # Calculate zoom bbox
    bumber_km = bumber_kms[graph_name]
    graph_cells = graph_utils.get_cells_in_graph(graph, engine, dbconf)
    minx, miny, maxx, maxy = graph_cells.total_bounds
    xy_ratio = (maxx - minx) / (maxy - miny)
    yx_ratio = (maxy - miny) / (maxx - minx)

    figheight = 3.0 * num_timesteps + 1.5
    figwidth = 3.5 * xy_ratio * len(data_variables)

    print(xy_ratio, yx_ratio, figheight, figwidth)
    # continue
    cell_data = graph_utils.get_cell_data(
        graph,
        quantities=fetch_quantities,
        statistics=fetch_stats,
        engine=engine,
    )
    # Change volume rain rate (rate.sum) to 10^6 m^3/h
    cell_data = cell_data.with_columns((pl.col("rate.sum") * 1e-3).alias("rate.sum"))
    cell_data = cell_data.with_columns(display_id=pl.lit(None).cast(pl.Int64))

    # Set display IDs
    if graph_name in display_ids:
        for n in graph.nodes:
            if n in display_ids[graph_name]:
                graph.nodes[n]["display_id"] = display_ids[graph_name][n]

            else:
                graph.nodes[n]["display_id"] = None
        for n in display_ids[graph_name]:
            timestamp = pd.to_datetime(n.split("_")[0])
            identifier = int(n.split("_")[1])
            cell_data = cell_data.with_columns(
                pl.when((pl.col("timestamp") == timestamp) & (pl.col("identifier") == identifier))
                .then(display_ids[graph_name][n])
                .otherwise(pl.col("display_id"))
                .alias("display_id")
            )
    else:
        for n in graph.nodes:
            graph.nodes[n]["display_id"] = None

    nrows = num_timesteps

    snapshot_subfig = plt.figure(figsize=(figwidth, figheight))
    snapshot_axs = snapshot_subfig.subplots(
        nrows=num_timesteps,
        ncols=len(data_variables),
        sharex=True,
        sharey=True,
        gridspec_kw={"hspace": 0.01, "wspace": 0.0},
    )

    for plot_var in data_variables:
        print(f"Plotting {plot_var}")
        var_axs = snapshot_axs[:, data_variables.index(plot_var)]

        cell_plot_kwargs = {
            "facecolor": "none",
            "edgecolor": "slategrey",
            "linewidth": 1.5,
        }
        track_color = "tab:orange"
        snapshot_subfig, var_axs, cbar = graph_utils.plot_graph_snapshots(
            graph,
            t0_time,
            snapshot_subfig,
            var_axs,
            cell_data,
            plot_conf,
            engine,
            dbconf,
            plot_times=timestamps.tolist(),
            bumber_km=bumber_km,
            plot_var=plot_var,
            cell_plot_kwargs=cell_plot_kwargs,
            track_color=track_color,
            plot_cbar=False,
            time_title=False,
            annotate_key="display_id",
            annotate_cells="graph",
            plot_motion_from_rate=False,
            sparse_factor=sparse_factors[graph_name],
        )

        if data_variables.index(plot_var) == 0:
            for i, ax in enumerate(var_axs.flatten()):
                t = timestamps.iloc[i]
                ax.text(
                    0.025,
                    1.015,
                    f"{t.strftime('%Y-%m-%d %H:%M')}",
                    transform=ax.transAxes,
                    fontsize="small",
                    ha="left",
                    va="bottom",
                    bbox=dict(facecolor="white", edgecolor="black", pad=3),
                    zorder=100,
                )
                # ax.set_ylabel(t.strftime("%Y-%m-%d %H:%M"), fontsize="x-small")
                # ax.label_outer()

        last_ax = var_axs.flatten()[-1]
        cbar = plot_utils.plot_colorbar(
            last_ax,
            plot_conf.input_data[plot_var].cmap_qty,
            orientation="horizontal",
            extend="max",
            cbar_ax_kws={
                "width": "95%",
                "height": "5%",
                "loc": "lower left",
                "bbox_to_anchor": (-0.02, -0.1, 1.0, 0.99),
                # "borderpad": 0,
            },
        )
        cbar.ax.tick_params(labelsize="xx-small", rotation=45, pad=0.05)
        cbar.set_label(label=cbar.ax.xaxis.get_label_text(), fontsize="small")

    print("Saving snapshot figure")
    plot_utils.save_figs(
        snapshot_subfig,
        outpath_,
        f"case_data_{graph_name}",
        extensions=["svg", "pdf", "png"],
        # extensions=["svg"],
        delete_fig=True,
        convert_colors=True,
        savefig_kwargs=dict(dpi=600, bbox_inches="tight"),
    )
    print("Done")
    # continue

    #######################################################################################
    # Plot overview figure with all cells during the day

    starttime = t0_time.floor(plot_cells_overview_hours)
    endtime = t0_time.ceil(plot_cells_overview_hours)
    all_cells_during_day = cell_queries.load_stormcells_between_times(
        starttime,
        endtime,
        dbconf,
    )
    all_cells_during_day["hour"] = all_cells_during_day["timestamp"].dt.hour

    hour_bounds = np.arange(0, 24, 1)
    hour_norm = mcolors.BoundaryNorm(boundaries=hour_bounds, ncolors=len(hour_bounds))
    hour_cmap = plt.get_cmap("cmc.hawaii_r", len(hour_bounds))

    var_name = "RATE"
    input_data_conf = {k: v for k, v in plot_conf.input_data.items() if k == var_name}

    dataset = data_utils.load_data(
        input_data_conf,
        t0_time,
        [t0_time],
        1,
        None,
    )

    # # Calculate zoom bbox
    graph_cells = graph_utils.get_cells_in_graph(graph, engine, dbconf)
    minx, miny, maxx, maxy = graph_cells.total_bounds
    min_col = np.round(((minx - dataset.x.values.min()) / 1000 - bumber_km[0]) / dataset.x.values.size, 3).item()
    max_col = np.round(((maxx - dataset.x.values.min()) / 1000 + bumber_km[1]) / dataset.x.values.size, 3).item()
    min_row = np.round(((miny - dataset.y.values.min()) / 1000 - bumber_km[2]) / dataset.y.values.size, 3).item()
    max_row = np.round(((maxy - dataset.y.values.min()) / 1000 + bumber_km[3]) / dataset.y.values.size, 3).item()

    bbox = [min_col, max_col, min_row, max_row]
    im_width = dataset.x.values.max() - dataset.x.values.min()
    im_height = dataset.y.values.max() - dataset.y.values.min()
    bbox_x = [
        dataset.x.values.min() + im_width * bbox[0],
        dataset.x.values.min() + im_width * bbox[1],
    ]
    bbox_y = [
        dataset.y.values.min() + im_height * bbox[2],
        dataset.y.values.min() + im_height * bbox[3],
    ]

    # overview_fig = whole_fig.add_subfigure(whole_gridspec[0, 0])
    # axs = overview_fig.add_subplot(111)
    fig, axs = plt.subplots(
        nrows=1,
        ncols=1,
        figsize=(plot_conf.col_width + 1, plot_conf.row_height),
        sharex="col",
        sharey="row",
        squeeze=True,
        constrained_layout=True,
    )
    # Borders
    border = gpd.read_file(plot_conf.map_params.border_shapefile)
    border_proj = border.to_crs(plot_conf.map_params.proj)

    segments = [np.array(linestring.coords)[:, :2] for linestring in border_proj["geometry"]]
    border_collection = LineCollection(segments, zorder=0, **plot_conf.map_params.border_plot_kwargs)

    # Radar locations
    if plot_conf.map_params.radar_shapefile is not None:
        radar_locations = gpd.read_file(plot_conf.map_params.radar_shapefile)
        radar_locations_proj = radar_locations.to_crs(plot_conf.map_params.proj)
        xy = radar_locations_proj["geometry"].map(lambda point: point.xy)
        radar_locations_proj = list(zip(*xy))
    else:
        radar_locations_proj = None

    axs.add_collection(border_collection)

    if radar_locations_proj is not None:
        axs.scatter(
            *radar_locations_proj,
            zorder=5,
            **plot_conf.map_params.radar_plot_kwargs,
        )

    for hour in hour_bounds:
        cells_to_plot = all_cells_during_day[all_cells_during_day["hour"] == hour]
        color = hour_cmap(hour)
        cells_to_plot.plot(ax=axs, facecolor="none", edgecolor=color, zorder=10, linewidth=0.1)

    cbar_ax_kws = {
        "width": "3%",  # width = 5% of parent_bbox width
        "height": "100%",
        "loc": "lower left",
        "bbox_to_anchor": (1.05, 0.0, 2, 1),
        "borderpad": 0,
    }
    cbar = plot_utils.plot_colorbar(
        axs,
        norm=hour_norm,
        cmap=hour_cmap,
        ticks=hour_bounds[::2],
        orientation="vertical",
        cbar_ax_kws=cbar_ax_kws,
    )
    cbar.set_label("Local hour [UTC + 2]", fontsize="small")
    cbar.ax.tick_params(labelsize="small", rotation=0, pad=0.2)

    axs.set_xticks(
        np.arange(
            dataset.x.values.min(),
            dataset.x.values.max(),
            plot_conf.tick_spacing * 1e3,
        )
    )
    axs.set_yticks(
        np.arange(
            dataset.y.values.min(),
            dataset.y.values.max(),
            plot_conf.tick_spacing * 1e3,
        )
    )
    axs.set_aspect(1)

    axs.set_xlim((dataset.x.values.min(), dataset.x.values.max()))
    axs.set_ylim((dataset.y.values.min(), dataset.y.values.max()))

    xticks = axs.get_xticks()
    yticks = axs.get_yticks()

    lonxticks, latxticks = transformer.transform(xticks, [yticks[0]] * len(xticks))
    lonyticks, latyticks = transformer.transform([xticks[0]] * len(yticks), yticks)

    axs.set_xticklabels([f"{l:.1f}° E" for l in lonxticks], fontsize="xx-small", rotation=25)
    axs.set_yticklabels([f"{l:.1f}° N" for l in latyticks], fontsize="xx-small")

    # axs.xaxis.set_major_formatter(plt.NullFormatter())
    # axs.yaxis.set_major_formatter(plt.NullFormatter())

    axs.grid(lw=0.8, color="tab:gray", ls=":", zorder=11)

    for tick in axs.xaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)
        # tick.label1.set_visible(False)
        # tick.label2.set_visible(False)
    for tick in axs.yaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)
        # tick.label1.set_visible(False)
        # tick.label2.set_visible(False)

    axs.plot(
        [bbox_x[0], bbox_x[-1], bbox_x[-1], bbox_x[0], bbox_x[0]],
        [bbox_y[0], bbox_y[0], bbox_y[-1], bbox_y[-1], bbox_y[0]],
        color="tab:red",
        lw="1.5",
        zorder=15,
    )

    print(f"Saving overview figure to {outpath_}")
    # Save the cell overview figure `fig` (built earlier in this loop iteration) as
    # svg/pdf/png at 600 dpi, passing delete_fig=True to plot_utils.save_figs.
    plot_utils.save_figs(
        fig,
        outpath_,
        f"case_overview_cells_{graph_name}",
        extensions=["svg", "pdf", "png"],
        delete_fig=True,
        convert_colors=True,
        savefig_kwargs=dict(dpi=600, bbox_inches="tight"),
    )
    print(f"Overview figure saved to {outpath_}")

    ########################################################################################
    #######################################################################################
    # Plot overview figure with RATE
    # Full-domain map of the RATE field at t0_time, with the zoom box around this case's
    # cells outlined in red.

    var_name = "RATE"
    # Keep only the RATE entry of the input-data configuration.
    input_data_conf = {k: v for k, v in plot_conf.input_data.items() if k == var_name}

    # NOTE(review): the positional arguments follow data_utils.load_data's signature
    # (presumably reference time, list of times to load, and a count of 1) — confirm there.
    dataset = data_utils.load_data(
        input_data_conf,
        t0_time,
        [t0_time],
        1,
        None,
    )

    # # Calculate zoom bbox
    # Bounds of all cells in the graph, padded by bumber_km (index 0/1 pad min/max x,
    # index 2/3 pad min/max y), expressed as rounded fractions of the domain extent.
    # NOTE(review): (offset / 1000) converts metres to km, which is then divided by the
    # number of grid points; this is a domain fraction only if the grid spacing is ~1 km.
    graph_cells = graph_utils.get_cells_in_graph(graph, engine, dbconf)
    minx, miny, maxx, maxy = graph_cells.total_bounds
    min_col = np.round(((minx - dataset.x.values.min()) / 1000 - bumber_km[0]) / dataset.x.values.size, 3).item()
    max_col = np.round(((maxx - dataset.x.values.min()) / 1000 + bumber_km[1]) / dataset.x.values.size, 3).item()
    min_row = np.round(((miny - dataset.y.values.min()) / 1000 - bumber_km[2]) / dataset.y.values.size, 3).item()
    max_row = np.round(((maxy - dataset.y.values.min()) / 1000 + bumber_km[3]) / dataset.y.values.size, 3).item()

    bbox = [min_col, max_col, min_row, max_row]
    print(f"For case {graph_name}, using bbox {bbox}")
    # Map the fractional box back to projected coordinates for drawing the outline.
    im_width = dataset.x.values.max() - dataset.x.values.min()
    im_height = dataset.y.values.max() - dataset.y.values.min()
    bbox_x = [
        dataset.x.values.min() + im_width * bbox[0],
        dataset.x.values.min() + im_width * bbox[1],
    ]
    bbox_y = [
        dataset.y.values.min() + im_height * bbox[2],
        dataset.y.values.min() + im_height * bbox[3],
    ]

    # Single-panel figure; squeeze=True makes `axs` a single Axes.
    fig, axs = plt.subplots(
        nrows=1,
        ncols=1,
        figsize=(5 + 1, 3.5),
        sharex="col",
        sharey="row",
        squeeze=True,
        constrained_layout=True,
    )
    # Borders
    # Reproject the configured border shapefile and draw it as line segments (x, y only).
    # NOTE(review): `linestring.coords` assumes LineString geometries; shapely Polygons
    # raise NotImplementedError on `.coords` — confirm the shapefile's geometry type.
    border = gpd.read_file(plot_conf.map_params.border_shapefile)
    border_proj = border.to_crs(plot_conf.map_params.proj)

    segments = [np.array(linestring.coords)[:, :2] for linestring in border_proj["geometry"]]
    border_collection = LineCollection(segments, zorder=0, **plot_conf.map_params.border_plot_kwargs)

    # Radar locations
    # When configured, `radar_locations_proj` becomes a pair (x values, y values) that is
    # unpacked into `scatter` below.
    if plot_conf.map_params.radar_shapefile is not None:
        radar_locations = gpd.read_file(plot_conf.map_params.radar_shapefile)
        radar_locations_proj = radar_locations.to_crs(plot_conf.map_params.proj)
        xy = radar_locations_proj["geometry"].map(lambda point: point.xy)
        radar_locations_proj = list(zip(*xy))
    else:
        radar_locations_proj = None

    axs.add_collection(border_collection)

    if radar_locations_proj is not None:
        axs.scatter(
            *radar_locations_proj,
            zorder=5,
            **plot_conf.map_params.radar_plot_kwargs,
        )

    # Colorbar placement to the right of the map. NOTE(review): presumably forwarded to
    # inset_axes by plot_utils.plot_colorbar, where "width"/"height" are fractions of the
    # `bbox_to_anchor` box (not of the parent axes) — confirm.
    cbar_ax_kws = {
        "width": "3%",  # colorbar width: 3% of the anchor box width
        "height": "100%",
        "loc": "lower left",
        "bbox_to_anchor": (1.05, 0.0, 2, 1),
        "borderpad": 0,
    }
    cbar = plot_utils.plot_colorbar(
        axs,
        qty=input_data_conf[var_name].cmap_qty,
        orientation="vertical",
        cbar_ax_kws=cbar_ax_kws,
    )
    # cbar.set_label("Rain rate [mm h$^{-1}$]", fontsize="small")
    cbar.ax.tick_params(labelsize="small", rotation=0, pad=0.2)
    # Draw the field for each time step (only t0_time was loaded above).
    for i, (time, arr) in tqdm(enumerate(dataset.groupby("time", squeeze=False))):
        im = arr[var_name].to_numpy().squeeze()

        # Set zero-rain and missing (nan-mask) pixels to NaN so they are not colored
        # (presumably rendered transparent by plot_utils.plot_array).
        nan_mask = arr[f"{var_name}_nan_mask"].values.squeeze()
        zero_mask = np.isclose(im, 0)

        im[zero_mask] = np.nan
        im[nan_mask] = np.nan

        cbar = plot_utils.plot_array(
            axs,
            im.copy(),
            x=dataset.x.values,
            y=dataset.y.values,
            qty=input_data_conf[var_name].cmap_qty,
            colorbar=False,
            zorder=1,
        )

    # Grid lines every plot_conf.tick_spacing km (x1e3 assumes projected units of metres).
    axs.set_xticks(
        np.arange(
            dataset.x.values.min(),
            dataset.x.values.max(),
            plot_conf.tick_spacing * 1e3,
        )
    )
    axs.set_yticks(
        np.arange(
            dataset.y.values.min(),
            dataset.y.values.max(),
            plot_conf.tick_spacing * 1e3,
        )
    )
    axs.set_aspect(1)

    axs.set_xlim((dataset.x.values.min(), dataset.x.values.max()))
    axs.set_ylim((dataset.y.values.min(), dataset.y.values.max()))

    xticks = axs.get_xticks()
    yticks = axs.get_yticks()

    # Label ticks in degrees: x labels are evaluated along the first y tick and y labels
    # along the first x tick. `transformer` is defined outside this chunk (presumably
    # projected -> lon/lat).
    lonxticks, latxticks = transformer.transform(xticks, [yticks[0]] * len(xticks))
    lonyticks, latyticks = transformer.transform([xticks[0]] * len(yticks), yticks)

    axs.set_xticklabels([f"{l:.1f}° E" for l in lonxticks], fontsize="xx-small", rotation=25)
    axs.set_yticklabels([f"{l:.1f}° N" for l in latyticks], fontsize="xx-small")

    axs.grid(lw=0.8, color="tab:gray", ls=":", zorder=11)

    # Hide the tick marks but keep the labels.
    for tick in axs.xaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)
        # tick.label1.set_visible(False)
        # tick.label2.set_visible(False)
    for tick in axs.yaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)
        # tick.label1.set_visible(False)
        # tick.label2.set_visible(False)

    # Outline the zoom box (closed rectangle) on top of everything else.
    axs.plot(
        [bbox_x[0], bbox_x[-1], bbox_x[-1], bbox_x[0], bbox_x[0]],
        [bbox_y[0], bbox_y[0], bbox_y[-1], bbox_y[-1], bbox_y[0]],
        color="tab:red",
        lw="1.5",
        zorder=15,
    )

    print(f"Saving RATE overview figure to {outpath_}")
    plot_utils.save_figs(
        fig,
        outpath_,
        f"case_overview_RATE_{graph_name}",
        extensions=["svg", "pdf", "png"],
        delete_fig=True,
        convert_colors=True,
        savefig_kwargs=dict(dpi=600, bbox_inches="tight"),
    )
    print(f"Overview figure saved to {outpath_}")

    ########################################################################################
    # Graph figure: the case's cell graph drawn with one column per timestamp.

    # NOTE(review): nrows/hratios only feed the commented-out gridspec below and ncols is
    # unused here — candidates for removal.
    nrows = 6
    ncols = 1

    graph_fig = plt.figure(figsize=(5, 3), layout="constrained")
    hratios = [0.8] + [1.0] * (nrows - 1)  # Last row takes more space
    # spec = graph_fig.add_gridspec(ncols=ncols, nrows=nrows, wspace=0.1, hspace=0.1, height_ratios=hratios)

    # graph_fig = whole_fig.add_subfigure(whole_gridspec[0, 0])
    graph_ax = graph_fig.add_subplot(111)

    # Layer nodes by their "timestamp" attribute, then reorder each layer so that nodes
    # are stacked by `sort_by`: the descending sort is paired with ascending y values,
    # so the largest display_id ends up at the bottom.
    pos = nx.multipartite_layout(graph, subset_key="timestamp")
    sort_by = "display_id"
    # NOTE(review): `timestamps` is defined outside this chunk; `tt` must compare equal to
    # the raw node "timestamp" attribute (same type), otherwise `nodes` is empty and the
    # layer is silently left unsorted — confirm.
    for tt in timestamps:
        nodes = [n for n, d in graph.nodes(data=True) if d["timestamp"] == tt]
        ycoords = sorted(list(set([pos[n][1] for n in nodes])))
        num_nodes = len(nodes)  # NOTE(review): unused
        node_order = sorted(nodes, key=lambda n: graph.nodes[n][sort_by])[::-1]  # Sort nodes by identifier
        # print(node_order, ycoords)
        for i, n in enumerate(node_order):
            pos[n] = np.array([pos[n][0], ycoords[i]])  # Assign y-coordinates based on node order and ycoords

    # Plot graph
    # Draw into the prepared axes (savefig=False), highlighting the t0 node; the figure is
    # saved explicitly below.
    graph_fig, graph_ax = graph_utils.plot_graph(
        graph,
        f"{graph_name}_{color_var}",
        pos=pos,
        time_resolution="5min",
        outpath=outpath,
        ext="png",
        highlight_nodes=[graph_name],
        highlight_color="tab:orange",
        highlight_node_edgewidth=3,
        axs=graph_ax,
        fig=graph_fig,
        savefig=False,
        fill_color="tab:gray",
        write_title=False,
        xaxis_pad=0.1,
        label_key="display_id",
    )

    print(f"Saving graph figure to {outpath_}")
    plot_utils.save_figs(
        graph_fig,
        outpath_,
        f"case_graph_{graph_name}",
        extensions=["svg", "pdf", "png"],
        delete_fig=True,
        convert_colors=True,
        savefig_kwargs=dict(dpi=600, bbox_inches="tight"),
    )
    print(f"Graph figure saved to {outpath_}")

    ########################################################################################
    # Timeseries figure: one panel per variable in `plot_vars_timeseries`, plus a bottom
    # panel with the relative change from t0 of total rain rate and total cell area.
    timeseries_fig = plt.figure(figsize=(6, (len(plot_vars_timeseries) + 1) * 2.5), layout="constrained")
    # timeseries_fig = whole_fig.add_subfigure(whole_gridspec[1, 0])
    # squeeze=False + ravel() always gives a 1-D array of Axes, so `timeseries_axs[i]` is a
    # single Axes regardless of how many variables are plotted (the previous
    # squeeze=True/.flatten() failed with zero variables).
    timeseries_axs = timeseries_fig.subplots(
        ncols=1, nrows=len(plot_vars_timeseries) + 1, sharex=True, squeeze=False
    ).ravel()

    # Bottom panel: percentage change relative to t0 for this case's t0 node
    diff_ax = timeseries_axs[-1]
    df_ = trajectory_diffs.filter((pl.col("t0_node") == graph_name)).sort("level").to_pandas()
    df_.plot(
        x="timestamp", y="t0_reldiff:rate.sum:pct", lw=2, color="#e7298a", label="Total volume rain rate", ax=diff_ax
    )
    df_.plot(
        x="timestamp", y="t0_reldiff:area:pct", lw=2, color="#66a61e", ls="dashdot", label="Total cell area", ax=diff_ax
    )
    # df_.plot(x="timestamp", y="t0_reldiff_mean:rate.sum:pct", color="C2", label="Mean rainrate", ax=diff_ax)

    diff_ax.set_ylabel("Change from $t_0$ [%]", fontsize="small")
    diff_ax.set_xlabel("Time step [5 min]")
    diff_ax.axhline(0, color="black", linestyle="--", linewidth=1.2)
    diff_ax.grid(True)

    # Upper panels: one per variable. Always index the Axes array; the previous
    # `... else timeseries_axs` fallback passed the whole array when only one variable
    # was plotted (the figure always has at least two rows).
    for i, var in enumerate(plot_vars_timeseries):
        ax = timeseries_axs[i]
        graph_utils.plot_graph_timeseries(
            graph,
            cell_data,
            variable=var,
            ax=ax,
            linecolor="k",
            marker="o",
            markersize=20,
            annotate=True,
            majority_variable="area",
            linewidth=1.5,
            majority_linewidth=2,
            annotate_key="display_id",
        )
        ax.set_ylabel(plot_vars_titles[var], fontsize="small")

        # The y tick format is looked up from the quantity name (text before the first ".")
        qty = var.split(".")[0]
        cmap_qty = cmap_qtys[qty]

        ax.yaxis.set_major_formatter(ticker.FuncFormatter(QTY_FORMATS[plot_conf.input_data[cmap_qty]["cmap_qty"]]))

    # Thin out the major time ticks for longer series
    minute_interval = 1 if num_timesteps < 10 else 2

    for ax in timeseries_axs:
        # ax.set_xlabel("Timestamp")
        ax.xaxis.set_major_locator(mdates.MinuteLocator(byminute=range(0, 60, 5), interval=minute_interval))
        ax.xaxis.set_minor_locator(mdates.MinuteLocator(byminute=range(0, 60, 5), interval=1))
        ax.xaxis.set_major_formatter(mdates.DateFormatter("%H:%M"))
        ax.xaxis.set_tick_params(rotation=25, labelsize="small")
        # ax.label_outer()
        ax.grid(which="major", linestyle="-", linewidth=0.5, color="lightgray")
        # Mark the case's reference time t0
        ax.axvline(t0_time, color="tab:orange", linestyle="--", linewidth=1.2)

    print(f"Saving timeseries figure to {outpath_}")
    plot_utils.save_figs(
        timeseries_fig,
        outpath_,
        f"case_timeseries_{graph_name}",
        extensions=["svg", "pdf", "png"],
        delete_fig=True,
        convert_colors=True,
        savefig_kwargs=dict(dpi=600, bbox_inches="tight"),
    )
    print(f"Timeseries figure saved to {outpath_}")

0.4781936888953552 2.091202839397643 34.5 5.02103373340123
Plotting VIL
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_decoder DBG verbl=0
py_decoder DBG verb=0
py_dec



Done


In [38]:
# Scratch check of the RATE grid: load the field at t0_time and inspect every 10th
# x coordinate. The captured output shows ~10 014.6 m per 10 grid points, i.e. a grid
# spacing of ~1 km, which the zoom-box calculations (km offset / number of grid points)
# implicitly rely on.
# NOTE(review): this cell reassigns the notebook-level var_name, input_data_conf and dataset.
var_name = "RATE"
input_data_conf = {k: v for k, v in plot_conf.input_data.items() if k == var_name}

dataset = data_utils.load_data(
    input_data_conf,
    t0_time,
    [t0_time],
    1,
    None,
)

dataset.x.values[10:500:10]

py_decoder DBG verb=0
py_decoder DBG verbl=0


array([2265038.91524988, 2275053.55916223, 2285068.20307458,
       2295082.84698694, 2305097.49089929, 2315112.13481164,
       2325126.77872399, 2335141.42263634, 2345156.0665487 ,
       2355170.71046105, 2365185.3543734 , 2375199.99828575,
       2385214.6421981 , 2395229.28611046, 2405243.93002281,
       2415258.57393516, 2425273.21784751, 2435287.86175987,
       2445302.50567222, 2455317.14958457, 2465331.79349692,
       2475346.43740928, 2485361.08132163, 2495375.72523398,
       2505390.36914633, 2515405.01305868, 2525419.65697104,
       2535434.30088339, 2545448.94479574, 2555463.58870809,
       2565478.23262045, 2575492.8765328 , 2585507.52044515,
       2595522.1643575 , 2605536.80826985, 2615551.45218221,
       2625566.09609456, 2635580.74000691, 2645595.38391926,
       2655610.02783162, 2665624.67174397, 2675639.31565632,
       2685653.95956867, 2695668.60348102, 2705683.24739338,
       2715697.89130573, 2725712.53521808, 2735727.17913043,
       2745741.82304279]

In [None]:
# Plot the RATE overview maps of all cases side by side in one figure.
# Each panel shows the rain-rate field at the case's t0 together with country borders,
# lakes/seas, radar sites and the outlined zoom box around the case's cells.
from matplotlib.patches import Polygon
import shapely


def _iter_polygons(geometries):
    """Yield every non-empty single-part polygon contained in ``geometries``.

    Multi-part geometries (anything exposing ``.geoms``, e.g. MultiPolygon or a
    GeometryCollection produced by clipping) are unpacked recursively. Non-polygon
    parts are skipped, because only polygon exteriors are drawn below.
    """
    for geom in geometries:
        if hasattr(geom, "geoms"):
            yield from _iter_polygons(geom.geoms)
        elif geom.geom_type == "Polygon" and not geom.is_empty:
            yield geom


ncols = len(CASE_T0_NODES)
nrows = 1
figwidth = (3.0 * ncols + 1) * 2
figheight = (4.0 * nrows) * 2
fig = plt.figure(figsize=(figwidth, figheight))
# squeeze=False + ravel() always yields a 1-D array of Axes, so `axs[order]`, the
# formatting loop and `axs[-1]` below also work when there is only a single case.
axs = fig.subplots(
    nrows=nrows,
    ncols=ncols,
    sharex=True,
    sharey=True,
    squeeze=False,
    gridspec_kw={"hspace": 0.01, "wspace": 0.0},
).ravel()

var_name = "RATE"
input_data_conf = {k: v for k, v in plot_conf.input_data.items() if k == var_name}

# Water bodies are drawn as filled, semi-transparent polygons
riverlake_alpha = 0.5
riverlake_color = "#386cb0"

# Static map layers are identical for every case: read and reproject them once instead
# of re-reading every shapefile in each loop iteration.
border = gpd.read_file("../shapefiles/Europe_merged.shp")
border = border[border["GID_0"].isin(["CHE", "DEU", "ITA", "FRA", "LUX"])]
border_proj = border.to_crs(plot_conf.map_params.proj)

# Detailed Swiss lakes, Natural Earth lakes (used where no Swiss lake overlaps) and seas
swiss_lakes = gpd.read_file("../shapefiles/k4seenyyyymmdd11_ch2007Poly.shp").to_crs(plot_conf.map_params.proj)
ne_lakes = gpd.read_file("../shapefiles/ne_10m_lakes_europe.shp").to_crs(plot_conf.map_params.proj)
seas = gpd.read_file("../shapefiles/ne_10m_geography_marine_polys.shp").to_crs(plot_conf.map_params.proj)
seas = seas[seas["name"].isin(["Mediterranean Sea", "Tyrrhenian Sea", "Ligurian Sea"])].dissolve().explode()

# Radar locations as an (x values, y values) pair for `ax.scatter`
if plot_conf.map_params.radar_shapefile is not None:
    radar_locations = gpd.read_file(plot_conf.map_params.radar_shapefile)
    radar_locations_proj = radar_locations.to_crs(plot_conf.map_params.proj)
    xy = radar_locations_proj["geometry"].map(lambda point: point.xy)
    radar_locations_proj = list(zip(*xy))
else:
    radar_locations_proj = None

for order, graph_name in enumerate(CASE_T0_NODES):
    # Select this case's panel *before* drawing anything. Previously the lake/sea patches
    # were added to `ax` before it was reassigned, i.e. to the previous case's panel.
    ax = axs[order]

    outpath_ = outpath / graph_name
    outpath_.mkdir(exist_ok=True, parents=True)

    t0_time = pd.to_datetime(graph_name.split("_")[0])
    subpath = t0_time.strftime("%Y%m%d")
    # NOTE(review): the config cell defines `storagepath`; confirm `storage_path` is
    # defined elsewhere in the notebook.
    graph_path = storage_path / "subgraphs" / subpath

    graph = nx.read_gml(graph_path / f"subgraph_{graph_name}.gml", destringizer=graph_utils.string_decoder)

    # Padding (km) of the zoom box: (left, right, bottom, top)
    bumber_km = bumber_kms[graph_name]

    dataset = data_utils.load_data(
        input_data_conf,
        t0_time,
        [t0_time],
        1,
        None,
    )

    # Zoom box around all cells of the graph, as rounded fractions of the domain extent.
    # NOTE(review): dividing km offsets by the number of grid points assumes ~1 km spacing.
    graph_cells = graph_utils.get_cells_in_graph(graph, engine, dbconf)
    minx, miny, maxx, maxy = graph_cells.total_bounds
    min_col = np.round(((minx - dataset.x.values.min()) / 1000 - bumber_km[0]) / dataset.x.values.size, 3).item()
    max_col = np.round(((maxx - dataset.x.values.min()) / 1000 + bumber_km[1]) / dataset.x.values.size, 3).item()
    min_row = np.round(((miny - dataset.y.values.min()) / 1000 - bumber_km[2]) / dataset.y.values.size, 3).item()
    max_row = np.round(((maxy - dataset.y.values.min()) / 1000 + bumber_km[3]) / dataset.y.values.size, 3).item()

    bbox = [min_col, max_col, min_row, max_row]
    im_width = dataset.x.values.max() - dataset.x.values.min()
    im_height = dataset.y.values.max() - dataset.y.values.min()
    bbox_x = [
        dataset.x.values.min() + im_width * bbox[0],
        dataset.x.values.min() + im_width * bbox[1],
    ]
    bbox_y = [
        dataset.y.values.min() + im_height * bbox[2],
        dataset.y.values.min() + im_height * bbox[3],
    ]

    # Country borders (polygon exteriors). An artist can belong to only one Axes, so the
    # collection is rebuilt for every panel.
    segments = [np.array(poly.exterior.coords) for poly in _iter_polygons(border_proj["geometry"])]
    border_collection = LineCollection(segments, zorder=0, **plot_conf.map_params.border_plot_kwargs)
    ax.add_collection(border_collection)

    # Data domain, used to restrict the Natural Earth lakes
    xmin, ymin, xmax, ymax = (
        dataset.x.values.min(),
        dataset.y.values.min(),
        dataset.x.values.max(),
        dataset.y.values.max(),
    )
    clip_poly = shapely.geometry.box(xmin, ymin, xmax, ymax)

    # Water polygons: all Swiss lakes, then in-domain Natural Earth lakes that do not
    # intersect an already collected lake, then the seas.
    # (The former river layer was built from the border polygons via `Polygon.coords`,
    # which raises in shapely, and was never drawn; it has been removed.)
    polys = list(_iter_polygons(swiss_lakes["geometry"]))
    for multipoly in ne_lakes.clip(clip_poly)["geometry"]:
        if any(multipoly.intersects(p) for p in polys):
            continue
        polys.extend(_iter_polygons([multipoly]))
    polys.extend(_iter_polygons(seas["geometry"]))

    for poly in polys:
        ax.add_patch(
            Polygon(
                np.array(poly.exterior.coords),
                zorder=0,
                edgecolor=None,
                color=riverlake_color,
                lw=0.0,
                alpha=riverlake_alpha,
            )
        )

    if radar_locations_proj is not None:
        ax.scatter(
            *radar_locations_proj,
            zorder=5,
            **plot_conf.map_params.radar_plot_kwargs,
        )

    # Rain-rate field; zero-rain and missing pixels are set to NaN so they are not colored
    for i, (time, arr) in tqdm(enumerate(dataset.groupby("time", squeeze=False))):
        im = arr[var_name].to_numpy().squeeze()

        nan_mask = arr[f"{var_name}_nan_mask"].values.squeeze()
        zero_mask = np.isclose(im, 0)

        im[zero_mask] = np.nan
        im[nan_mask] = np.nan

        cbar = plot_utils.plot_array(
            ax,
            im.copy(),
            x=dataset.x.values,
            y=dataset.y.values,
            qty=input_data_conf[var_name].cmap_qty,
            colorbar=False,
            zorder=1,
        )
        tt = pd.to_datetime(time)
        ax.set_title(f"({alphabet[order]}) {tt:%Y-%m-%d %H:%M} UTC", y=1.0, fontsize="large")

    # Outline of the zoom box
    ax.plot(
        [bbox_x[0], bbox_x[-1], bbox_x[-1], bbox_x[0], bbox_x[0]],
        [bbox_y[0], bbox_y[0], bbox_y[-1], bbox_y[-1], bbox_y[0]],
        color=qualitative_colors[0],
        lw=3,
        zorder=15,
    )

# Shared axis formatting: grid every plot_conf.tick_spacing km, labelled in degrees.
# All cases use the same grid, so the last loaded `dataset` defines the extent.
for ax in axs:
    ax.set_xticks(
        np.arange(
            dataset.x.values.min(),
            dataset.x.values.max(),
            plot_conf.tick_spacing * 1e3,
        )
    )
    ax.set_yticks(
        np.arange(
            dataset.y.values.min(),
            dataset.y.values.max(),
            plot_conf.tick_spacing * 1e3,
        )
    )
    ax.set_aspect(1)

    ax.set_xlim((dataset.x.values.min(), dataset.x.values.max()))
    ax.set_ylim((dataset.y.values.min(), dataset.y.values.max()))

    xticks = ax.get_xticks()
    yticks = ax.get_yticks()

    lonxticks, latxticks = transformer.transform(xticks, [yticks[0]] * len(xticks))
    lonyticks, latyticks = transformer.transform([xticks[0]] * len(yticks), yticks)

    ax.set_xticklabels([f"{l:.1f}° E" for l in lonxticks], fontsize="small")
    ax.xaxis.set_tick_params(rotation=25, pad=1)
    ax.set_yticklabels([f"{l:.1f}° N" for l in latyticks], fontsize="small")

    ax.grid(lw=0.8, color="tab:gray", ls=":", zorder=11)

    # Hide tick marks, keep labels
    for tick in ax.xaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)
    for tick in ax.yaxis.get_major_ticks():
        tick.tick1line.set_visible(False)
        tick.tick2line.set_visible(False)
    ax.label_outer()

# Colorbar to the right of the last panel; width/height are relative to the anchor box
cbar_ax_kws = {
    "width": "3%",
    "height": "100%",
    "loc": "lower left",
    "bbox_to_anchor": (1.02, 0.0, 1.2, 1),
    "borderpad": 0,
}
cbar = plot_utils.plot_colorbar(
    axs[-1],
    qty=input_data_conf[var_name].cmap_qty,
    orientation="vertical",
    cbar_ax_kws=cbar_ax_kws,
)
cbar.set_label("Rain rate [mm h$^{-1}$]", fontsize="large")
cbar.ax.tick_params(labelsize="small", rotation=0, pad=5)

fig.subplots_adjust(wspace=0.05, hspace=0.01)

plot_utils.save_figs(
    fig,
    outpath,
    "case_overview_RATE_all",
    extensions=["svg", "pdf", "png"],
    delete_fig=True,
    convert_colors=True,
    savefig_kwargs=dict(dpi=600, bbox_inches="tight"),
)

py_decoder DBG verb=0
py_decoder DBG verbl=0


0it [00:00, ?it/s]

1it [00:00, 14.02it/s]


py_decoder DBG verb=0
py_decoder DBG verbl=0


1it [00:00, 14.95it/s]
