In [44]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import sys
from datetime import datetime
import polars as pl
from sqlalchemy.orm import Session

from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import colors as mcolors, ticker, patches
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from pprint import pprint

import string
alphabet = string.ascii_lowercase

%reload_ext autoreload
%autoreload 2
%matplotlib inline

sys.path.append("..")
sys.path.append("../..")

from utils import graph_utils
from database import get_engine
from utils.config_utils import load_config
from utils import plot_utils

from utils import plots
from utils import quantity_definitions as QD


# Load and pre-process data


## Configurations


In [45]:
# matplotlib use style
plt.style.use("../stylefiles/article.mplstyle")

In [46]:
OUTPATH = Path("/data/jenna/rain-cell-stats-analysis/figures/publish_data_v20251114")
OUTPATH.mkdir(parents=True, exist_ok=True)

In [47]:
dbconf_file = Path(".").resolve().parent / "config/database/database.yaml"
dbconf = load_config(dbconf_file)
engine = get_engine(dbconf)

### Color & plotting definitions


In [48]:
qualitative_colors = [
    "#e7298a",
    "#66a61e",
    "#76CD74",
    "#8F1A65",
    "#386cb0",
    # "#e9a3c9",
    # "#a1d76a",
    # "#66c2a5",
    # "#fc8d62",
    # "#8da0cb",
]

LINEPLOT_KWS_CI_95 = {
    "errorbar": ("ci", 95),
}
LINEPLOT_KWS_NO_CI = dict(errorbar=None, lw=1.5, alpha=1.0)

count_linekwargs = {
    "ls": "--",
    "lw": 1,
    "color": "black",
    "label": "Subgraph count",
}

## Cell subgraphs from file


In [49]:
# storagepath = Path("/data/jenna/rain-cell-stats-analysis/cell_graph_data_v20250827")
storagepath = Path("/data/jenna/rain-cell-stats-analysis/publish_data_v20251105_v2")

fileglob = "cell_trajectories_*_*.parquet"

# Load the data
TRAJECTORIES = pl.read_parquet(storagepath / fileglob)

TRAJECTORIES

method,type,identifier,t0_node,timestamp,level,area,event,num_cells_at_level,t0_time
str,str,i64,str,datetime[μs],i32,f64,str,u32,datetime[μs]
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2021-05-01T15:50:00_1""",2021-05-01 15:45:00,-1,20.004965,"""simple.""",2,2021-05-01 15:50:00
"""opencv_vil_1.0:minArea_10:clus…","""merged""",2,"""2021-05-01T15:50:00_1""",2021-05-01 15:45:00,-1,15.503848,"""born.""",2,2021-05-01 15:50:00
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2021-05-01T15:50:00_1""",2021-05-01 15:50:00,0,48.512041,"""merged_midnode.""",1,2021-05-01 15:50:00
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2021-05-01T15:50:00_1""",2021-05-01 15:55:00,1,47.011668,"""simple""",1,2021-05-01 15:50:00
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2021-05-01T15:50:00_1""",2021-05-01 16:00:00,2,12.503103,"""died.""",1,2021-05-01 15:50:00
…,…,…,…,…,…,…,…,…,…
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:05:00,2,100.52495,"""simple""",1,2023-09-23 18:55:00
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:10:00,3,142.035253,"""simple""",1,2023-09-23 18:55:00
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:15:00,4,165.541087,"""simple""",1,2023-09-23 18:55:00
"""opencv_vil_1.0:minArea_10:clus…","""merged""",1,"""2023-09-23T18:55:00_1""",2023-09-23 19:20:00,5,153.538108,"""simple""",1,2023-09-23 18:55:00


In [50]:
pprint(TRAJECTORIES.select("type").unique().to_series().sort().to_list())

['merged', 'split', 'split-merge']


In [51]:
SPLIT_MERGE_TRAJECTORIES = TRAJECTORIES.filter(pl.col("type").is_in(["split", "merged", "split-merge"]))

T0_TYPES = SPLIT_MERGE_TRAJECTORIES.select("type").unique().to_series().sort().to_list()


split_t0_nodes = (
    SPLIT_MERGE_TRAJECTORIES.filter(pl.col("type") == "split").select("t0_node").unique().to_series().sort().to_list()
)
merge_t0_nodes = (
    SPLIT_MERGE_TRAJECTORIES.filter(pl.col("type") == "merged").select("t0_node").unique().to_series().sort().to_list()
)
split_merge_t0_nodes = (
    SPLIT_MERGE_TRAJECTORIES.filter(pl.col("type") == "split-merge")
    .select("t0_node")
    .unique()
    .to_series()
    .sort()
    .to_list()
)

T0_NODES = sorted(set(split_t0_nodes + merge_t0_nodes + split_merge_t0_nodes))
print(f"Number of split t0 nodes: {len(split_t0_nodes)}")
print(f"Number of merge t0 nodes: {len(merge_t0_nodes)}")
print(f"Number of split-merge t0 nodes: {len(split_merge_t0_nodes)}")
print(f"Total number of t0 nodes: {len(T0_NODES)}")

Number of split t0 nodes: 21951
Number of merge t0 nodes: 23391
Number of split-merge t0 nodes: 7717
Total number of t0 nodes: 53059


## Cell feature values from database


In [52]:
START_DATE = datetime(2021, 5, 1)
END_DATE = datetime(2023, 10, 1)

FETCH_QUANTITIES = [
    "vil",
    "rate",
    "zdrcol_custom_filt_unique_d1",
    "dist_from_radars",
]

FETCH_STATS = [
    "mean",
    "median",
    "min",
    "count",
    "sum",
    "max",
]

SQL_QUERY = """

 SELECT cell."timestamp",
    cell.identifier,
    cell.method,
    -- cell.geometry,
    ST_Area(cell.geometry) / 1e6 AS area,
    stats.quantity,
    stats.statistic,
    stats.value
   FROM raincells.stormcells cell
     JOIN raincells.stormcell_rasterstats stats ON
        cell.identifier = stats.identifier
        AND cell."timestamp" = stats."timestamp"
        AND cell.method::text = stats.method::text

    WHERE 1 = 1
    AND cell."timestamp" >= '{start_date}'
    AND cell."timestamp" <= '{end_date}'
    AND stats.quantity IN ({quantities})
    AND stats.statistic IN ({stats})

"""

with Session(engine) as session:
    query = SQL_QUERY.format(
        start_date=START_DATE.strftime("%Y-%m-%d %H:%M"),
        end_date=END_DATE.strftime("%Y-%m-%d %H:%M"),
        quantities=",".join([f"'{q}'" for q in FETCH_QUANTITIES]),
        stats=",".join([f"'{s}'" for s in FETCH_STATS]),
    )
    DATA = pl.read_database(
        query=query,
        connection=session.bind,
    )

    DATA = DATA.with_columns(pl.format("{}.{}", "quantity", "statistic").alias("on")).pivot(
        on="on",
        index=set(DATA.columns) - set(["quantity", "statistic", "value"]),
        values=["value"],
    )

In [53]:
# Add columns for local hour and month
DATA = DATA.with_columns(
    local_hour=pl.col("timestamp").dt.replace_time_zone("UTC").dt.convert_time_zone("Europe/Zurich").dt.hour(),
    month=pl.col("timestamp").dt.month(),
)

In [54]:
pprint(sorted(DATA.columns))

['area',
 'dist_from_radars.min',
 'identifier',
 'local_hour',
 'method',
 'month',
 'rate.count',
 'rate.max',
 'rate.mean',
 'rate.median',
 'rate.sum',
 'timestamp',
 'vil.count',
 'vil.max',
 'vil.mean',
 'vil.median',
 'vil.sum',
 'zdrcol_custom_filt_unique_d1.count',
 'zdrcol_custom_filt_unique_d1.max',
 'zdrcol_custom_filt_unique_d1.mean',
 'zdrcol_custom_filt_unique_d1.median',
 'zdrcol_custom_filt_unique_d1.min',
 'zdrcol_custom_filt_unique_d1.sum']


In [55]:
# Get all cells that are a part of a cell track
SQL_QUERY = """

    SELECT
        timestamp,
        identifier,
        method

    FROM raincells.track_graphs

    WHERE 1 = 1
        AND timestamp >= '{start_date}'
        AND timestamp <= '{end_date}'

"""

with Session(engine) as session:
    query = SQL_QUERY.format(
        start_date=START_DATE.strftime("%Y-%m-%d %H:%M"),
        end_date=END_DATE.strftime("%Y-%m-%d %H:%M"),
    )
    cells_in_tracks = pl.read_database(
        query=query,
        connection=session.bind,
    )

In [56]:
# Only pick cells that are part of a track
DATA = DATA.join(cells_in_tracks, on=["timestamp", "identifier", "method"], how="inner", suffix="_in_track")

## Subgraph development


In [57]:
data_ = DATA.with_columns((pl.col("area") * pl.col("rate.mean")).alias("area_w_rate.mean"))

sum_vars = [
    "area",
    "rate.sum",
]
max_vars = [
    "vil.max",
    "rate.max",
]
trajectory_diffs = graph_utils.calculate_diff_variables(
    trajectories=TRAJECTORIES,
    data=data_,
    sum_vars=sum_vars,
    max_vars=max_vars,
)

TRAJECTORY_DEVELOPMENT = trajectory_diffs.with_columns(
    pl.col("level").min().over("t0_node", "type").alias("min_level_available"),
    pl.col("level").max().over("t0_node", "type").alias("max_level_available"),
    pl.col("vil.max").max().over("t0_node", "type").alias("vil.max_in_trajectory"),
    pl.col("rate.max").max().over("t0_node", "type").alias("rate.max_in_trajectory"),
    ((pl.col("rate.sum") / pl.col("area")) / (pl.col("t0_rate.sum") / pl.col("t0_area")) - 1).alias(
        "t0_reldiff_mean:rate.sum"
    ),
    pl.col("t0_node").str.split("_").list.first().str.strptime(pl.Datetime, "%Y-%m-%dT%H:%M:%S").alias("t0_timestamp"),
).with_columns(
    (pl.col("t0_reldiff_mean:rate.sum") * 100).alias("t0_reldiff_mean:rate.sum:pct"),
    pl.col("t0_timestamp")
    .dt.replace_time_zone("UTC")
    .dt.convert_time_zone("Europe/Zurich")
    .dt.hour()
    .alias("t0_local_hour"),
    pl.col("t0_timestamp").dt.month().alias("t0_month"),
    pl.col("t0_timestamp")
    .dt.replace_time_zone("UTC")
    .dt.convert_time_zone("Europe/Zurich")
    .dt.hour()
    .alias("local_hour"),
    pl.col("t0_timestamp").dt.month().alias("month"),
)

TRAJECTORY_DEVELOPMENT

type,timestamp,t0_node,level,num_cells_at_level,area,rate.sum,vil.max,rate.max,t0_reldiff:area,t0_reldiff:rate.sum,t0_reldiff:vil.max,t0_reldiff:rate.max,t0_absdiff:area,t0_absdiff:rate.sum,t0_absdiff:vil.max,t0_absdiff:rate.max,t0_area,t0_rate.sum,t0_vil.max,t0_rate.max,t0_reldiff:area:pct,t0_reldiff:rate.sum:pct,t0_reldiff:vil.max:pct,t0_reldiff:rate.max:pct,min_level_available,max_level_available,vil.max_in_trajectory,rate.max_in_trajectory,t0_reldiff_mean:rate.sum,t0_timestamp,t0_reldiff_mean:rate.sum:pct,t0_local_hour,t0_month,local_hour,month
str,datetime[μs],str,i32,u32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i32,i32,f64,f64,f64,datetime[μs],f64,i8,i8,i8,i8
"""split""",2022-05-16 11:20:00,"""2022-05-16T11:20:00_4""",0,1,31.50782,739.330505,3.0,37.054626,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,31.50782,739.330505,3.0,37.054626,0.0,0.0,0.0,0.0,-3,6,7.0,110.430473,0.0,2022-05-16 11:20:00,0.0,13,5,13,5
"""split""",2022-05-25 13:10:00,"""2022-05-25T13:15:00_4""",-1,1,191.047418,12717.0,21.5,118.428223,-0.094787,0.195017,0.194444,0.0,-20.004965,2075.314453,3.5,0.0,211.052383,10641.685547,18.0,118.428223,-9.478673,19.501746,19.444444,0.0,-3,6,23.0,118.428223,0.32015,2022-05-25 13:15:00,32.015017,15,5,15,5
"""split""",2021-06-11 02:50:00,"""2021-06-11T03:20:00_1""",-6,1,29.0072,862.493958,4.0,56.680031,-0.625806,-0.537519,-0.384615,-0.409619,-48.512041,-1002.435608,-2.5,-39.325829,77.51924,1864.929565,6.5,96.005859,-62.580645,-53.751929,-38.461538,-40.961905,-6,4,6.5,96.005859,0.23594,2021-06-11 03:20:00,23.593984,5,6,5,6
"""split""",2023-09-17 16:45:00,"""2023-09-17T16:40:00_2""",1,2,378.093843,14240.469727,23.0,120.0,0.037037,0.080256,0.095238,0.0,13.503352,1057.969727,2.0,0.0,364.590492,13182.5,21.0,120.0,3.703704,8.025562,9.52381,0.0,-6,5,26.5,120.0,0.041675,2023-09-17 16:40:00,4.167506,18,9,18,9
"""merged""",2021-06-29 09:10:00,"""2021-06-29T09:10:00_13""",0,1,1902.972319,49392.515625,24.5,118.428223,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1902.972319,49392.515625,24.5,118.428223,0.0,0.0,0.0,0.0,-1,2,24.5,118.428223,0.0,2021-06-29 09:10:00,0.0,11,6,11,6
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
"""split""",2023-08-25 08:20:00,"""2023-08-25T08:20:00_10""",0,1,173.543074,898.140015,1.5,7.66,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,173.543074,898.140015,1.5,7.66,0.0,0.0,0.0,0.0,-2,4,2.5,8.91,0.0,2023-08-25 08:20:00,0.0,10,8,10,8
"""merged""",2023-06-14 00:45:00,"""2023-06-14T00:25:00_3""",4,1,696.672915,14072.789062,5.5,62.68,0.260633,0.275947,-0.153846,-0.073466,144.03575,3043.499023,-1.0,-4.970001,552.637165,11029.290039,6.5,67.650002,26.063348,27.594696,-15.384615,-7.346639,-3,4,6.5,68.629997,0.012147,2023-06-14 00:25:00,1.214744,2,6,2,6
"""split-merge""",2022-09-14 12:30:00,"""2022-09-14T12:20:00_8""",2,2,87.521723,4710.276489,12.0,118.428223,0.258993,0.975582,2.428571,0.947195,18.004469,2326.028442,8.5,57.608299,69.517254,2384.248047,3.5,60.819923,25.899281,97.558157,242.857143,94.719454,-5,6,18.5,118.428223,0.569176,2022-09-14 12:20:00,56.917621,14,9,14,9
"""merged""",2023-07-24 10:05:00,"""2023-07-24T10:30:00_19""",-5,2,327.581306,5470.479858,5.0,51.310001,-0.217443,-0.30721,-0.166667,-0.158439,-91.022592,-2425.819946,-1.0,-9.66,418.603898,7896.299805,6.0,60.970001,-21.744325,-30.720971,-16.666667,-15.843857,-5,6,7.0,86.150002,-0.114709,2023-07-24 10:30:00,-11.47092,12,7,12,7


In [58]:
del data_

In [59]:
pprint(sorted(TRAJECTORY_DEVELOPMENT.columns))

['area',
 'level',
 'local_hour',
 'max_level_available',
 'min_level_available',
 'month',
 'num_cells_at_level',
 'rate.max',
 'rate.max_in_trajectory',
 'rate.sum',
 't0_absdiff:area',
 't0_absdiff:rate.max',
 't0_absdiff:rate.sum',
 't0_absdiff:vil.max',
 't0_area',
 't0_local_hour',
 't0_month',
 't0_node',
 't0_rate.max',
 't0_rate.sum',
 't0_reldiff:area',
 't0_reldiff:area:pct',
 't0_reldiff:rate.max',
 't0_reldiff:rate.max:pct',
 't0_reldiff:rate.sum',
 't0_reldiff:rate.sum:pct',
 't0_reldiff:vil.max',
 't0_reldiff:vil.max:pct',
 't0_reldiff_mean:rate.sum',
 't0_reldiff_mean:rate.sum:pct',
 't0_timestamp',
 't0_vil.max',
 'timestamp',
 'type',
 'vil.max',
 'vil.max_in_trajectory']


## Split and merge events from database


In [60]:
# All split merge cells
SQL_QUERY = """

 SELECT "timestamp", identifier, method, prev_identifiers, next_identifiers, cur_area, prev_areas, next_areas
	FROM raincells.cells_with_parents_children_with_areas

    WHERE 1 = 1
    AND "timestamp" >= '{start_date}'
    AND "timestamp" <= '{end_date}'
"""

with Session(engine) as session:
    query = SQL_QUERY.format(
        start_date=START_DATE.strftime("%Y-%m-%d %H:%M"),
        end_date=END_DATE.strftime("%Y-%m-%d %H:%M"),
    )
    SPLIT_MERGES = pl.read_database(
        query=query,
        connection=session.bind,
    ).unique()

In [61]:
SPLIT_MERGES = SPLIT_MERGES.with_columns(
    node=pl.format("{}_{}", pl.col("timestamp").dt.strftime("%Y-%m-%dT%H:%M:%S"), pl.col("identifier")),
    num_successors=pl.col("next_identifiers").list.len(),
    num_predecessors=pl.col("prev_identifiers").list.len(),
).filter(pl.col("node").is_in(T0_NODES))

SPLIT_CELLS = SPLIT_MERGES.filter((pl.col("num_successors") > 1) & (pl.col("num_predecessors") < 2)).with_columns(
    type=pl.lit("split")
)
MERGE_CELLS = SPLIT_MERGES.filter((pl.col("num_predecessors") > 1) & (pl.col("num_successors") < 2)).with_columns(
    type=pl.lit("merged")
)
SPLIT_MERGE_CELLS = SPLIT_MERGES.filter((pl.col("num_successors") > 1) & (pl.col("num_predecessors") > 1)).with_columns(
    type=pl.lit("split-merge")
)

SPLIT_MERGES_CELLS = pl.concat([SPLIT_CELLS, MERGE_CELLS, SPLIT_MERGE_CELLS]).with_columns(
    local_hour=pl.col("timestamp").dt.replace_time_zone("UTC").dt.convert_time_zone("Europe/Zurich").dt.hour(),
    month=pl.col("timestamp").dt.month(),
)
SPLIT_MERGES_CELLS = SPLIT_MERGES_CELLS.join(
    DATA, on=["timestamp", "identifier", "method"], how="inner", suffix="_data"
)

In [62]:
SPLIT_MERGES_CELLS.join(DATA, on=["timestamp", "identifier", "method"], how="inner", suffix="_data")

timestamp,identifier,method,prev_identifiers,next_identifiers,cur_area,prev_areas,next_areas,node,num_successors,num_predecessors,type,local_hour,month,area,zdrcol_custom_filt_unique_d1.count,rate.sum,rate.count,rate.mean,vil.median,vil.sum,vil.count,rate.max,vil.mean,vil.max,rate.median,dist_from_radars.min,zdrcol_custom_filt_unique_d1.sum,zdrcol_custom_filt_unique_d1.min,zdrcol_custom_filt_unique_d1.mean,zdrcol_custom_filt_unique_d1.max,zdrcol_custom_filt_unique_d1.median,local_hour_data,month_data,area_data,zdrcol_custom_filt_unique_d1.count_data,rate.sum_data,rate.count_data,rate.mean_data,vil.median_data,vil.sum_data,vil.count_data,rate.max_data,vil.mean_data,vil.max_data,rate.median_data,dist_from_radars.min_data,zdrcol_custom_filt_unique_d1.sum_data,zdrcol_custom_filt_unique_d1.min_data,zdrcol_custom_filt_unique_d1.mean_data,zdrcol_custom_filt_unique_d1.max_data,zdrcol_custom_filt_unique_d1.median_data,local_hour_data,month_data
datetime[μs],i64,str,list[i64],list[i64],f64,list[f64],list[f64],str,u32,u32,str,i8,i8,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i8,i8,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i8.1,i8.1
2022-05-01 20:15:00,1,"""opencv_vil_1.0:minArea_10:clus…","[1, 2]",[1],32.508069,"[11.502855, 16.003972]",[32.007944],"""2022-05-01T20:15:00_1""",1,2,"""merged""",22,5,32.508069,,591.512329,49.0,12.07168,1.5,64.0,45.0,27.840015,1.422222,3.0,11.996038,44.342797,,,,,,22,5,32.508069,,591.512329,49.0,12.07168,1.5,64.0,45.0,27.840015,1.422222,3.0,11.996038,44.342797,,,,,,22,5
2022-05-03 14:55:00,1,"""opencv_vil_1.0:minArea_10:clus…",[1],"[1, 2]",103.525695,[85.021102],"[39.009682, 84.020854]","""2022-05-03T14:55:00_1""",2,1,"""split""",16,5,103.525695,,6320.095703,142.0,44.507716,2.0,320.5,111.0,118.428223,2.887387,11.0,33.296749,220.651189,,,,,,16,5,103.525695,,6320.095703,142.0,44.507716,2.0,320.5,111.0,118.428223,2.887387,11.0,33.296749,220.651189,,,,,,16,5
2022-05-03 15:05:00,1,"""opencv_vil_1.0:minArea_10:clus…","[1, 2]",[1],152.037736,"[39.009682, 84.020854]",[106.026316],"""2022-05-03T15:05:00_1""",1,2,"""merged""",17,5,152.037736,,9534.526367,192.0,49.658991,2.5,494.0,157.0,118.428223,3.146497,10.0,34.506222,219.274874,,,,,,17,5,152.037736,,9534.526367,192.0,49.658991,2.5,494.0,157.0,118.428223,3.146497,10.0,34.506222,219.274874,,,,,,17,5
2022-05-03 15:30:00,1,"""opencv_vil_1.0:minArea_10:clus…","[1, 2]",[1],142.035253,"[16.504096, 105.026068]",[94.023337],"""2022-05-03T15:30:00_1""",1,2,"""merged""",17,5,142.035253,,6340.787109,178.0,35.622399,2.0,317.5,149.0,118.428223,2.130872,6.0,28.857056,221.461817,,,,,,17,5,142.035253,,6340.787109,178.0,35.622399,2.0,317.5,149.0,118.428223,2.130872,6.0,28.857056,221.461817,,,,,,17,5
2022-05-04 19:20:00,1,"""opencv_vil_1.0:minArea_10:clus…",[3],"[1, 2]",57.514275,[51.512786],"[15.503848, 22.505586]","""2022-05-04T19:20:00_1""",2,1,"""split""",21,5,57.514275,,1308.713867,81.0,16.156961,1.0,66.0,66.0,22.425371,1.0,1.0,16.753111,22.49327,,,,,,21,5,57.514275,,1308.713867,81.0,16.156961,1.0,66.0,66.0,22.425371,1.0,1.0,16.753111,22.49327,,,,,,21,5
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
2021-06-09 14:00:00,76,"""opencv_vil_1.0:minArea_10:clus…",[69],"[76, 77]",283.070258,[285.570879],"[28.507075, 238.059086]","""2021-06-09T14:00:00_76""",2,1,"""split""",16,6,283.070258,20.0,6825.06543,338.0,20.192501,1.5,970.0,309.0,118.428223,3.139159,21.5,11.125732,86.721541,27800.0,800.0,1390.0,2800.0,1400.0,16,6,283.070258,20.0,6825.06543,338.0,20.192501,1.5,970.0,309.0,118.428223,3.139159,21.5,11.125732,86.721541,27800.0,800.0,1390.0,2800.0,1400.0,16,6
2021-06-09 13:50:00,80,"""opencv_vil_1.0:minArea_10:clus…","[73, 74]",[81],96.523957,"[19.504841, 65.516261]",[101.025075],"""2021-06-09T13:50:00_80""",1,2,"""merged""",15,6,96.523957,,4860.381836,133.0,36.544224,2.0,377.0,121.0,118.428223,3.115702,12.0,20.112127,184.65724,,,,,,15,6,96.523957,,4860.381836,133.0,36.544224,2.0,377.0,121.0,118.428223,3.115702,12.0,20.112127,184.65724,,,,,,15,6
2021-06-09 14:00:00,80,"""opencv_vil_1.0:minArea_10:clus…","[74, 75]",[80],92.02284,"[16.003972, 47.511792]",[118.529419],"""2021-06-09T14:00:00_80""",1,2,"""merged""",16,6,92.02284,11.0,4146.621094,128.0,32.395477,2.5,502.5,121.0,118.428223,4.152893,18.0,17.087845,122.016704,14600.0,800.0,1327.272727,2200.0,800.0,16,6,92.02284,11.0,4146.621094,128.0,32.395477,2.5,502.5,121.0,118.428223,4.152893,18.0,17.087845,122.016704,14600.0,800.0,1327.272727,2200.0,800.0,16,6
2021-06-09 14:45:00,80,"""opencv_vil_1.0:minArea_10:clus…","[84, 87]",[78],241.059831,"[11.502855, 157.038977]",[281.069762],"""2021-06-09T14:45:00_80""",1,2,"""merged""",16,6,241.059831,,9023.541016,297.0,30.382293,2.0,853.5,263.0,118.428223,3.245247,13.5,17.379173,141.138239,,,,,,16,6,241.059831,,9023.541016,297.0,30.382293,2.0,853.5,263.0,118.428223,3.245247,13.5,17.379173,141.138239,,,,,,16,6


In [63]:
SPLIT_MERGES_CELLS.select("num_successors", "num_predecessors").describe()

statistic,num_successors,num_predecessors
str,f64,f64
"""count""",53059.0,53059.0
"""null_count""",0.0,0.0
"""mean""",1.629526,1.665523
"""std""",0.639057,0.646453
"""min""",1.0,1.0
"""25%""",1.0,1.0
"""50%""",2.0,2.0
"""75%""",2.0,2.0
"""max""",11.0,11.0


# Number of cells and tracks


In [64]:
# Fraction of split events from all cells
N_all = DATA.select(pl.struct("timestamp", "identifier", "method")).n_unique()
N_split = (
    SPLIT_MERGES_CELLS.filter(pl.col("type") == "split")
    .select(pl.struct("timestamp", "identifier", "method"))
    .n_unique()
)
N_merged = (
    SPLIT_MERGES_CELLS.filter(pl.col("type") == "merged")
    .select(pl.struct("timestamp", "identifier", "method"))
    .n_unique()
)
N_split_merge = (
    SPLIT_MERGES_CELLS.filter(pl.col("type") == "split-merge")
    .select(pl.struct("timestamp", "identifier", "method"))
    .n_unique()
)

pprint(f"Total number of unique cells: {N_all:,}")
pprint(f"Number of unique split cells: {N_split:,} ({N_split / N_all * 100:.1f} %)")
pprint(f"Number of unique merged cells: {N_merged:,} ({N_merged / N_all * 100:.1f} %)")
pprint(f"Number of unique split-merge cells: {N_split_merge:,} ({N_split_merge / N_all * 100:.1f} %)")
pprint(
    f"Total number of unique split/merged cells: {N_split + N_merged + N_split_merge:,} ({(N_split + N_merged + N_split_merge) / N_all * 100:.1f} %)"
)

'Total number of unique cells: 735,163'
'Number of unique split cells: 21,951 (3.0 %)'
'Number of unique merged cells: 23,391 (3.2 %)'
'Number of unique split-merge cells: 7,717 (1.0 %)'
'Total number of unique split/merged cells: 53,059 (7.2 %)'


In [65]:
cells_with_zdrcol = DATA.filter(pl.col("zdrcol_custom_filt_unique_d1.median") > 0)
cells_with_zdrcol_and_split_merge = cells_with_zdrcol.join(
    SPLIT_MERGES_CELLS, on=["timestamp", "identifier", "method"], how="inner", suffix="_in_split_merge"
)

N_cells_with_zdrcol = cells_with_zdrcol.select(pl.struct("timestamp", "identifier", "method")).n_unique()
N_cells_with_zdrcol_and_split_merge = cells_with_zdrcol_and_split_merge.select(
    pl.struct("timestamp", "identifier", "method")
).n_unique()

pprint(f"Total number of unique cells with zdrcol > 0: {N_cells_with_zdrcol:,}")
pprint(
    f"Number of unique split cells with zdrcol > 0: {N_cells_with_zdrcol_and_split_merge:,} ({N_cells_with_zdrcol_and_split_merge / N_cells_with_zdrcol * 100:.1f} %)"
)

'Total number of unique cells with zdrcol > 0: 181,835'
'Number of unique split cells with zdrcol > 0: 21,364 (11.7 %)'


In [66]:
vil_thr = 20
cells_with_high_vil = DATA.filter(pl.col("vil.max") >= vil_thr)
cells_with_high_vil_and_split_merge = cells_with_high_vil.join(
    SPLIT_MERGES_CELLS, on=["timestamp", "identifier", "method"], how="inner", suffix="_in_split_merge"
)

N_cells_with_high_vil = cells_with_high_vil.select(pl.struct("timestamp", "identifier", "method")).n_unique()
N_cells_with_high_vil_and_split_merge = cells_with_high_vil_and_split_merge.select(
    pl.struct("timestamp", "identifier", "method")
).n_unique()

pprint(f"Total number of unique cells with VIL > {vil_thr}: {N_cells_with_high_vil:,}")
pprint(
    f"Number of unique split cells with VIL > {vil_thr}: {N_cells_with_high_vil_and_split_merge:,} ({N_cells_with_high_vil_and_split_merge / N_cells_with_high_vil * 100:.1f} %)"
)

'Total number of unique cells with VIL > 20: 37,786'
'Number of unique split cells with VIL > 20: 6,770 (17.9 %)'


# Figures


## Number of participating cells in split and merge events


In [67]:
# Plot a histogram of the number of predecessor and successor cells

nrows = 2
ncols = 1

fig = plt.figure(constrained_layout=True, figsize=(4.5 * ncols, 4.5 * nrows))
widths = [
    1,
]
heights = [1, 1.2]
axs = fig.subplot_mosaic(
    [
        ["split"],
        ["split_merge"],
    ],
    gridspec_kw={"width_ratios": widths, "height_ratios": heights, "wspace": 0.1, "hspace": 0.05},
)

hist_lw = 0.8
hist_alpha = 0.6

# In 1st row, plot the number of successor cells for splits
df = pd.concat(
    [
        SPLIT_CELLS.with_columns(num_cells=pl.col("num_successors")).to_pandas(),
        MERGE_CELLS.with_columns(num_cells=pl.col("num_predecessors")).to_pandas(),
    ]
)
labels = [f"Split events (N = {SPLIT_CELLS.height:,d})", f"Merge events (N = {MERGE_CELLS.height:,d})"]
g1 = sns.histplot(
    data=df,
    x="num_cells",
    hue="type",
    hue_order=["split", "merged"],
    palette={lab: c for lab, c in zip(["split", "merged"], qualitative_colors[:2])},
    ax=axs["split"],
    stat="percent",
    edgecolor="k",
    element="bars",
    linewidth=hist_lw,
    alpha=hist_alpha,
    bins=6,
    binrange=(2, 8),
    discrete=True,
    legend=False,
    rasterized=True,
    multiple="dodge",
    shrink=0.9,
    common_norm=False,
)
# Set bar edgecolor to the same as facecolor
for patch in g1.patches:
    facecolor = list(patch.get_facecolor())
    facecolor[-1] = 1
    patch.set_edgecolor(facecolor)
    patch.set_linewidth(hist_lw)

handles = []
for label, color in zip(labels, qualitative_colors[:2]):
    rect = patches.Rectangle((0, 0), 1, 1, color=color, linewidth=hist_lw)
    rect.set_alpha(hist_alpha)
    facecolor = list(rect.get_facecolor())
    facecolor[-1] = 1
    rect.set_edgecolor(facecolor)
    rect.set_linewidth(hist_lw)
    handles.append(rect)

axs["split"].legend(
    handles=handles,
    labels=labels,
    loc="upper right",
    fontsize="small",
    framealpha=0.8,
    edgecolor="white",
    facecolor="white",
    frameon=True,
    fancybox=True,
)

axs["split"].set_title(f"(a) Split and merge events")
axs["split"].set_xlabel("Number of split/merging cells")
axs["split"].set_ylabel("Fraction of events [%]")


# Num of preds and successors for split-merge cells
max_num_cells = SPLIT_MERGE_CELLS.select(pl.max_horizontal("num_successors", "num_predecessors")).to_numpy().max()
print(f"Max number of cells in merge-split events: {max_num_cells}")

hist, xedges, yedges = np.histogram2d(
    SPLIT_MERGE_CELLS.to_pandas()["num_successors"],
    SPLIT_MERGE_CELLS.to_pandas()["num_predecessors"],
    bins=np.arange(1.5, max_num_cells + 0.6, 1),
    density=False,
)

hist_rel = hist / hist.sum() * 100  # relative to t0 count

# cmap = "cmc.davos_r"
cmap = "cmc.hawaii_r"
norm = mcolors.LogNorm(vmin=0.1, vmax=100)

row = 1
col = 1
sns.heatmap(
    data=np.flipud(hist_rel),
    ax=axs["split_merge"],
    cmap=cmap,
    norm=norm,
    vmin=None,
    vmax=None,
    linecolor="black",
    linewidths=0.5,
    linestyle="--",
    annot=True,
    annot_kws=dict(fontsize="x-small"),
    fmt=".3f",
    cbar=False,
    square=True,
    # xticklabels=xedges,
    # yticklabels=yedges[::-1],  # reverse y-axis labels
)
axs["split_merge"].set_title(f"(b) Merge-split events (N = {SPLIT_MERGE_CELLS.height:,d})")
axs["split_merge"].set_ylabel("Number of split cells")
axs["split_merge"].set_xlabel("Number of merging cells")


for ax in [
    axs["split"],
]:
    ax.axes.xaxis.set_major_locator(ticker.MultipleLocator(1))
    ax.axes.xaxis.set_minor_locator(ticker.MultipleLocator(1))
    ax.axes.yaxis.set_major_locator(ticker.MultipleLocator(20))
    ax.axes.yaxis.set_minor_locator(ticker.MultipleLocator(5))

    ax.set_ylim(0, 100)

    for cont in ax.containers:
        # cont.set_edgecolor("black")
        # cont.set_linewidth(hist_lw)
        ax.bar_label(
            cont,
            fmt="%.2f",
            label_type="edge",
            rotation=35,
            fontsize="x-small",
            rotation_mode="default",
            # xytext=(0, 3),  # offset text slightly above the bar
            # textcoords="offset points",
        )
    ax.grid(axis="y", which="major")
    ax.grid(axis="y", which="minor", alpha=0.5)

axs["split_merge"].set_xticklabels(
    [f"{int(x)}" for x in np.arange(2, max_num_cells + 0.1, 1)],
    rotation=0,
    fontsize="small",
)
axs["split_merge"].set_yticklabels(
    [f"{int(x)}" for x in np.arange(2, max_num_cells + 0.1, 1)][::-1],  # reverse y-axis labels
    rotation=0,
    fontsize="small",
)
cax = inset_axes(
    axs["split_merge"],
    width="3%",
    height="100%",
    loc="center left",
    borderpad=0.0,
    bbox_to_anchor=(1.01, 0.0, 1, 1),
    bbox_transform=axs["split_merge"].transAxes,
)
cbar = fig.colorbar(
    axs["split_merge"].collections[0],
    cax=cax,
    orientation="vertical",
    format=plt.FuncFormatter(lambda x, _: f"{x:,.0f}"),
    label="Fraction of events [%]",
)

outname = "split_merge_num_cells"
plot_utils.save_figs(
    fig=fig,
    outpath=OUTPATH,
    name=outname,
    extensions=["svg", "png", "pdf"],
    delete_fig=True,
    convert_colors=True,
    savefig_kwargs=dict(dpi=300, transparent=False),
)

Max number of cells in merge-split events: 11


In [68]:
# Save to file
# 1d histogram of num_successors vs num_predecessors
df_ = (df.groupby(["type", "num_cells"]).count() / df.groupby(["type"]).count())["identifier"]
df_.name = "fraction_of_events"
df_.to_csv(OUTPATH / "fig7a_split_merge_num_cells_histogram.csv")

# 2d histogram of num_successors vs num_predecessors
num_merging_cells = [x for x in np.arange(2, max_num_cells + 0.1, 1)]
num_split_cells = [x for x in np.arange(2, max_num_cells + 0.1, 1)]

hist_df = pd.DataFrame(
    hist_rel,
    index=num_split_cells,
    columns=num_merging_cells,
)
hist_df.index.name = "num_split_cells"
hist_df.columns.name = "num_merging_cells"

hist_df.to_csv(OUTPATH / "fig7b_split_merge_num_cells_histogram.csv")

## Ratio of smallest to largest cell in split and merge events


In [69]:
# Distribution of largest vs smallest cell area for split and merge cells

nrows = 2
ncols = 1

fig = plt.figure(constrained_layout=True, figsize=(5.5 * ncols, 3.5 * nrows))
widths = [
    1,
]
heights = [1, 1]
axs = fig.subplot_mosaic(
    [
        ["split"],
        ["merged"],
    ],
    gridspec_kw={"width_ratios": widths, "height_ratios": heights, "wspace": 0.1, "hspace": 0.1},
)

df = pd.concat(
    [
        SPLIT_CELLS.with_columns(
            num_cells=pl.col("num_successors"),
            largest_area=pl.col("next_areas").list.max(),
            smallest_area=pl.col("next_areas").list.min(),
        ).to_pandas(),
        MERGE_CELLS.with_columns(
            num_cells=pl.col("num_predecessors"),
            largest_area=pl.col("prev_areas").list.max(),
            smallest_area=pl.col("prev_areas").list.min(),
        ).to_pandas(),
    ]
)
df["area_ratio"] = df["smallest_area"] / df["largest_area"]

area_intervals = [
    10,
    100,
    250,
    500,
    1000,
    2000,
    5000,
]

area_norm = mcolors.BoundaryNorm(boundaries=area_intervals, ncolors=len(area_intervals), extend="max")
area_cmap = plt.get_cmap("cmc.hawaii_r", len(area_intervals))

ratio_bins = np.arange(0, 1, 0.01)
alpha = 0.4

points = {}

for ttype in ["split", "merged"]:
    points[ttype] = {}

    groups = df.groupby(pd.cut(df["largest_area"], area_intervals, right=True))
    legend_handles = []
    legend_labels = []
    N_all_ttype = df[df["type"] == ttype].shape[0]
    print(f"\nTotal number of {ttype} events: N = {N_all_ttype:,} \n")

    for name, group in groups:
        color = area_cmap(area_intervals.index(name.right))
        histg = sns.histplot(
            data=group[group["type"] == ttype],
            x="area_ratio",
            ax=axs[ttype],
            stat="percent",
            hue_norm=area_norm,
            color=color,
            bins=ratio_bins,
            edgecolor=color,
            element="step",
            linewidth=hist_lw,
            alpha=alpha,
            zorder=len(area_intervals) - area_intervals.index(name.right),
        )
        rect = patches.Rectangle((0, 0), 1, 1, color=color, linewidth=hist_lw)
        rect.set_alpha(alpha)
        facecolor = list(rect.get_facecolor())
        facecolor[-1] = 1
        rect.set_edgecolor(facecolor)
        rect.set_linewidth(hist_lw)
        legend_handles.append(rect)

        N_cat = group[group["type"] == ttype].shape[0]
        print(f"Category {name}: N = {N_cat:,} ({N_cat / N_all_ttype * 100:.1f} %)")

        legend_labels.append(f"{name.left:,.0f} - {name.right:,.0f} km$^2$ (N = {N_cat:,d})")

        # save values
        points_ = histg.get_children()[0].get_paths()[0].vertices
        num_points = points_.shape[0]
        xs = points_[num_points // 2 :, 0][::-1]
        ys = points_[num_points // 2 :, 1][::-1]
        points[ttype][str(name)] = {"x": xs.tolist(), "y": ys.tolist()}

    axs[ttype].legend(
        handles=legend_handles,
        labels=legend_labels,
        loc="upper right",
        fontsize="small",
        framealpha=1.0,
        edgecolor="white",
        facecolor="white",
        frameon=True,
        fancybox=True,
        title="Largest cell area",
        title_fontsize="small",
        labelspacing=0.2,
        handlelength=2.0,
        handletextpad=0.5,
        borderpad=0.2,
    )

    axs[ttype].set_xlim(0, 1)
    axs[ttype].set_xlabel("Ratio of smallest to largest cell area")
    axs[ttype].set_ylabel("Fraction of events [%]")
    axs[ttype].xaxis.set_major_locator(ticker.MultipleLocator(0.1))
    axs[ttype].xaxis.set_minor_locator(ticker.MultipleLocator(0.05))
    axs[ttype].yaxis.set_major_locator(ticker.MultipleLocator(5))
    axs[ttype].yaxis.set_minor_locator(ticker.MultipleLocator(2.5))
    axs[ttype].grid(axis="both", which="both", linestyle="--", linewidth=0.5, alpha=0.5, color="k")

axs["split"].set_title(f"(a) Split events (N = {SPLIT_CELLS.height:,d})")
axs["merged"].set_title(f"(b) Merge events (N = {MERGE_CELLS.height:,d})")

outname = "split_merge_areas_min_max_ratio"
plot_utils.save_figs(
    fig=fig,
    outpath=OUTPATH,
    name=outname,
    extensions=["svg", "png", "pdf"],
    convert_colors=True,
    delete_fig=True,
    savefig_kwargs=dict(dpi=300, transparent=False),
)

  groups = df.groupby(pd.cut(df["largest_area"], area_intervals, right=True))
  groups = df.groupby(pd.cut(df["largest_area"], area_intervals, right=True))



Total number of split events: N = 21,951 

Category (10, 100]: N = 9,210 (42.0 %)
Category (100, 250]: N = 5,903 (26.9 %)
Category (250, 500]: N = 3,208 (14.6 %)
Category (500, 1000]: N = 2,152 (9.8 %)
Category (1000, 2000]: N = 1,032 (4.7 %)
Category (2000, 5000]: N = 398 (1.8 %)

Total number of merged events: N = 23,391 

Category (10, 100]: N = 9,432 (40.3 %)
Category (100, 250]: N = 6,199 (26.5 %)
Category (250, 500]: N = 3,528 (15.1 %)
Category (500, 1000]: N = 2,454 (10.5 %)
Category (1000, 2000]: N = 1,189 (5.1 %)
Category (2000, 5000]: N = 543 (2.3 %)


In [70]:
import json

# Save points to json
with open(OUTPATH / "fig9_split_merge_area_ratio_histograms.json", "w") as f:
    json.dump(points, f)

## Distributions of cell areas


In [71]:
quantities = [
    "area",
    # "rate.mean",
]

t0_types = ["split", "merged", "split-merge", "splitted", "merging"]

t0_titles = ["Split events", "Merge events", "Merge-split events", "Split cells", "Merging cells"]

ncols = 2  # len(quantities)
nrows = 3  # len(t0_types)

fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(4.0 * ncols, 2.5 * nrows),
    sharex=True,
    sharey=True,
    squeeze=False,
)
points = {}
for i, t0_type in enumerate(t0_types):
    ax = np.array([axs.flatten(order="F")[i]])

    if t0_type == "splitted":
        cond = ((pl.col("type") == "split") | (pl.col("type") == "split-merge")) & (pl.col("level") == 1)
    elif t0_type == "merging":
        cond = ((pl.col("type") == "merged") | (pl.col("type") == "split-merge")) & (pl.col("level") == -1)
    else:
        cond = (pl.col("type") == t0_type) & (pl.col("level") == 0)

    t0_nodes = (
        TRAJECTORIES.filter(cond)
        .unique()
        .select(
            [
                "identifier",
                "timestamp",
                "method",
                "t0_node",
                "type",
                "level",
                "event",
            ]
        )
    )
    # t0_nodes
    t0_cells = t0_nodes.join(DATA, on=["identifier", "timestamp", "method"], how="inner")

    dataframes = [
        DATA,
        t0_cells,
    ]
    labels = [
        "All cells",
        t0_titles[i],
    ]

    ax, points_ = plots.plot_1d_histograms(
        dataframes=dataframes,
        quantities=quantities,
        labels=labels,
        colors=qualitative_colors,
        axs=ax,
        histogram_limits=QD.HISTOGRAM_LIMITS,
        histogram_nbins=QD.HISTOGRAM_NBINS,
        histogram_discrete=QD.HISTOGRAM_DISCRETE,
        qty_titles=QD.TITLES,
        histogram_ax_limits=QD.HISTOGRAM_AX_LIMITS,
        qty_formats=QD.QTY_FORMATS,
        hist_alpha=0.5,
        rasterized=False,
    )
    points[t0_type] = points_
for j, ax_ in enumerate(axs.flatten(order="F")[:-1]):
    ax_.set_title(rf"({string.ascii_lowercase[j]}) {t0_titles[j]}")
    ax_.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))
    ax_.yaxis.set_major_locator(ticker.MultipleLocator(2))
    # ax_.grid(axis="y", which="major", linestyle="--", linewidth=0.5, alpha=0.7, zorder=0)
    # ax_.grid(axis="y", which="minor", linestyle="--", linewidth=0.3, alpha=0.5, zorder=0)
    ax_.grid(axis="both", which="both", linestyle="--", linewidth=0.5, alpha=0.5, color="k", zorder=-1)
    ax_.set_ylabel("Fraction of events [%]", fontsize="small")
    ax_.xaxis.set_tick_params(labelbottom=True)
    ax_.xaxis.get_label().set_visible(True)
    ax_.set_xlabel("Cell area [km$^2$]", fontsize="small")

# axs.flatten(order="F")[-2].xaxis.set_tick_params(labelbottom=True)
axs.flatten(order="F")[-1].axis("off")

outname = f"cell_area_histograms_split_merge_cells"
plot_utils.save_figs(
    fig=fig,
    outpath=OUTPATH,
    name=outname,
    extensions=["svg", "png", "pdf"],
    delete_fig=True,
    convert_colors=True,
    savefig_kwargs=dict(
        dpi=300,
        transparent=False,
    ),
)

In [72]:
import json

# Save points to json
with open(OUTPATH / "fig8_split_merge_cell_area_histograms.json", "w") as f:
    json.dump(points, f)

In [73]:
quantities = [
    # "area",
    "vil.max"
    # "rate.mean",
]

t0_types = ["split", "merged", "split-merge", "splitted", "merging"]

t0_titles = ["Split events", "Merge events", "Merge-split events", "Split cells", "Merging cells"]

ncols = 2  # len(quantities)
nrows = 3  # len(t0_types)

fig, axs = plt.subplots(
    nrows=nrows,
    ncols=ncols,
    figsize=(4.0 * ncols, 2.5 * nrows),
    sharex=True,
    sharey=True,
    squeeze=False,
)
for i, t0_type in enumerate(t0_types):
    ax = np.array([axs.flatten(order="F")[i]])

    if t0_type == "splitted":
        cond = ((pl.col("type") == "split") | (pl.col("type") == "split-merge")) & (pl.col("level") == 1)
    elif t0_type == "merging":
        cond = ((pl.col("type") == "merged") | (pl.col("type") == "split-merge")) & (pl.col("level") == -1)
    else:
        cond = (pl.col("type") == t0_type) & (pl.col("level") == 0)

    t0_nodes = (
        TRAJECTORIES.filter(cond)
        .unique()
        .select(
            [
                "identifier",
                "timestamp",
                "method",
                "t0_node",
                "type",
                "level",
                "event",
            ]
        )
    )
    # t0_nodes
    t0_cells = t0_nodes.join(DATA, on=["identifier", "timestamp", "method"], how="inner")

    dataframes = [
        DATA,
        t0_cells,
    ]
    labels = [
        "All cells",
        t0_titles[i],
    ]

    plots.plot_1d_histograms(
        dataframes=dataframes,
        quantities=quantities,
        labels=labels,
        colors=qualitative_colors,
        axs=ax,
        histogram_limits=QD.HISTOGRAM_LIMITS,
        histogram_nbins=QD.HISTOGRAM_NBINS,
        histogram_discrete=QD.HISTOGRAM_DISCRETE,
        qty_titles=QD.TITLES,
        histogram_ax_limits=QD.HISTOGRAM_AX_LIMITS,
        qty_formats=QD.QTY_FORMATS,
        hist_alpha=0.5,
        rasterized=False,
    )
for j, ax_ in enumerate(axs.flatten(order="F")[:-1]):
    ax_.set_title(rf"({string.ascii_lowercase[j]}) {t0_titles[j]}")
    ax_.yaxis.set_minor_locator(ticker.MultipleLocator(0.5))
    ax_.yaxis.set_major_locator(ticker.MultipleLocator(2))
    # ax_.grid(axis="y", which="major", linestyle="--", linewidth=0.5, alpha=0.7, zorder=0)
    # ax_.grid(axis="y", which="minor", linestyle="--", linewidth=0.3, alpha=0.5, zorder=0)
    ax_.grid(axis="both", which="both", linestyle="--", linewidth=0.5, alpha=0.5, color="k", zorder=-1)
    ax_.set_ylabel("Fraction of events [%]", fontsize="small")
    ax_.xaxis.set_tick_params(labelbottom=True)
    ax_.xaxis.get_label().set_visible(True)
    # ax_.set_xlabel("Cell area [km$^2$]", fontsize="small")

# axs.flatten(order="F")[-2].xaxis.set_tick_params(labelbottom=True)
axs.flatten(order="F")[-1].axis("off")

outname = f"cell_vil.max_histograms_split_merge_cells"
plot_utils.save_figs(
    fig=fig,
    outpath=OUTPATH,
    name=outname,
    extensions=["svg", "png", "pdf"],
    delete_fig=True,
    convert_colors=True,
    savefig_kwargs=dict(
        dpi=300,
        transparent=False,
    ),
)

## Relative change in volume rain rate & cell area during split and merge events


In [74]:
t0_types = ["split", "merged", "split-merge"]

t0_titles = [
    lambda x: "(a) Split events",
    lambda x: "(b) Merge events",
    lambda x: "(c) Merge-split events",
]
variables = [
    "t0_reldiff:rate.sum:pct",
    "t0_reldiff:area:pct",
    # "t0_reldiff_mean:rate.sum:pct",
]
variable_labels = [
    "Total volume rain rate",
    "Total cell area",
    # "Mean rain rate",
]
linestyles = [
    "solid",
    "dashdot",
]

fig, axs, twin_axs = plots.plot_trajectory_development_figure(
    trajectories=TRAJECTORY_DEVELOPMENT,
    t0_types=t0_types,
    variables=variables,
    variable_labels=variable_labels,
    row_title_funcs=t0_titles,
    colors=qualitative_colors,
    outfilename="trajectory_development_split_merge",
    lineplot_kwargs=LINEPLOT_KWS_CI_95,
    count_linekwargs=count_linekwargs,
    include_control=False,
    ylim=(-40, 40),
    count_ylim=(0, 24),
    count_ytick_multiples=(3, 1.5),
    xtick_multiples=(2, 1),
    ytick_multiples=(10, 5),
    legend_ncols=1,
    extensions=["png", "pdf", "svg"],
    outpath=OUTPATH,
    savefig_kwargs=dict(dpi=300, transparent=False),
    row_height=3,
    col_width=5,
    linestyles=linestyles,
    return_fig=True,
)

values = {}
for i, ax in enumerate(axs.flatten()):
    title = ax.get_title()
    values[title] = {}
    c = 0
    for j, line in enumerate(ax.lines):
        label = line.get_label()
        if "child" not in label:
            # Mean values
            xs = line.get_xdata()
            ys = line.get_ydata()
            values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

            # Get confidence interval shading
            points_ = ax.collections[c].get_paths()[0].vertices
            num_points = points_.shape[0]
            bottom_xs = points_[1 : num_points // 2, 0]
            bottom_ys = points_[1 : num_points // 2, 1]
            top_xs = points_[num_points // 2 + 1 : -1, 0][::-1]
            top_ys = points_[num_points // 2 + 1 : -1, 1][::-1]

            values[title][f"{label}_ci"] = {
                "bottom_x": bottom_xs.tolist(),
                "bottom_y": bottom_ys.tolist(),
                "top_x": top_xs.tolist(),
                "top_y": top_ys.tolist(),
            }

            c += 1
    twin_ax = twin_axs[i]
    label = twin_ax.get_ylabel()
    xs = twin_ax.lines[0].get_xdata()
    ys = twin_ax.lines[0].get_ydata()
    values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

with open(OUTPATH / "figC1_trajectory_development_split_merge.json", "w") as f:
    json.dump(values, f)

plt.close(fig)

In [75]:
# Filter with maximum vil during trajectory
t0_types = ["split", "merged", "split-merge"]


variables = [
    "t0_reldiff:rate.sum:pct",
    "t0_reldiff:area:pct",
    # "t0_reldiff_mean:rate.sum:pct",
]
variable_labels = [
    "Total volume rain rate",
    "Total cell area",
    # "Mean rain rate",
]
linestyles = [
    "solid",
    "dashdot",
]

vil_thrs = [5, 10, 20, 40]  # kg/m^2
max_vil_thr = 10  # kg/m^2

for max_vil_thr in vil_thrs:

    df = TRAJECTORY_DEVELOPMENT.filter((pl.col("vil.max_in_trajectory") >= max_vil_thr))  # at least 20 kg/m^2 max vil
    # print(f"Min level: {min_level}, max level: {max_level}, max number of trajectories: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}")

    t0_titles = [
        lambda x: f"\n(a) Split events, max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
        lambda x: f"\n(b) Merge events, max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
        lambda x: f"\n(c) Merge-split events, max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
    ]

    fig, axs, twin_axs = plots.plot_trajectory_development_figure(
        trajectories=df,
        t0_types=t0_types,
        variables=variables,
        variable_labels=variable_labels,
        row_title_funcs=t0_titles,
        colors=qualitative_colors,
        outfilename=f"trajectory_development_split_merge_vil_thr_{max_vil_thr}",
        lineplot_kwargs=LINEPLOT_KWS_CI_95,
        count_linekwargs=count_linekwargs,
        include_control=False,
        ylim=(-40, 40),
        count_ylim=(0, 24),
        count_ytick_multiples=(3, 1.5),
        xtick_multiples=(2, 1),
        ytick_multiples=(10, 5),
        legend_ncols=1,
        extensions=["png", "pdf", "svg"],
        outpath=OUTPATH,
        savefig_kwargs=dict(dpi=300, transparent=False),
        row_height=3,
        col_width=5,
        linestyles=linestyles,
        return_fig=True,
    )

    values = {}
    for i, ax in enumerate(axs.flatten()):
        title = ax.get_title()
        values[title] = {}
        c = 0
        for j, line in enumerate(ax.lines):
            label = line.get_label()
            if "child" not in label:
                # Mean values
                xs = line.get_xdata()
                ys = line.get_ydata()
                values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

                # Get confidence interval shading
                points_ = ax.collections[c].get_paths()[0].vertices
                num_points = points_.shape[0]
                bottom_xs = points_[1 : num_points // 2, 0]
                bottom_ys = points_[1 : num_points // 2, 1]
                top_xs = points_[num_points // 2 + 1 : -1, 0][::-1]
                top_ys = points_[num_points // 2 + 1 : -1, 1][::-1]

                values[title][f"{label}_ci"] = {
                    "bottom_x": bottom_xs.tolist(),
                    "bottom_y": bottom_ys.tolist(),
                    "top_x": top_xs.tolist(),
                    "top_y": top_ys.tolist(),
                }

                c += 1
        twin_ax = twin_axs[i]
        label = twin_ax.get_ylabel()
        xs = twin_ax.lines[0].get_xdata()
        ys = twin_ax.lines[0].get_ydata()
        values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

    # Save to json
    with open(OUTPATH / f"fig10_trajectory_development_split_merge_vil_thr_{max_vil_thr}.json", "w") as f:
        json.dump(values, f)

    plt.close(fig)

  lambda x: f"\n(a) Split events, max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
  lambda x: f"\n(b) Merge events, max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
  lambda x: f"\n(c) Merge-split events, max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",


In [76]:
values

{'\n(a) Split events, max VIL in trajectory $\\geq$ 40 kg m$^{-2}$': {'Total volume rain rate': {'x': [-6.0,
    -5.0,
    -4.0,
    -3.0,
    -2.0,
    -1.0,
    0.0,
    1.0,
    2.0,
    3.0,
    4.0,
    5.0,
    6.0],
   'y': [6.23763373319517,
    7.54412881581775,
    7.941486758036472,
    6.540817407130094,
    4.468339205496304,
    2.870584227781361,
    0.0,
    -3.2970942057947554,
    -4.912350199628906,
    -6.604394753207269,
    -8.461896303066277,
    -12.509850610244538,
    -15.61767098057984]},
  'Total volume rain rate_ci': {'bottom_x': [-6.0,
    -5.0,
    -4.0,
    -3.0,
    -2.0,
    -1.0,
    0.0,
    1.0,
    2.0,
    3.0,
    4.0,
    5.0,
    6.0],
   'bottom_y': [-2.660823610288944,
    0.7400111304279987,
    2.6678921158399462,
    2.9184469361379524,
    2.338640508248853,
    1.8936466710615982,
    0.0,
    -4.301525233421863,
    -6.763341473648324,
    -9.329453149861237,
    -11.886947537521207,
    -16.7907476141717,
    -20.690851347855418],
   '

In [77]:
t0_types = ["split", "merged", "split-merge"]

t0_titles = [
    lambda x: "(a) Split events",
    lambda x: "(b) Merge events",
    lambda x: "(c) Merge-split events",
]
variables = [
    "t0_reldiff:rate.sum:pct",
    "t0_reldiff:area:pct",
    # "t0_reldiff_mean:rate.sum:pct",
]
variable_labels = [
    "Total volume rain rate",
    "Total cell area",
    # "Mean rain rate",
]
linestyles = [
    "solid",
    "dashdot",
]

min_max_limit_combos = [
    (0, 2),
    (0, 3),
    (0, 4),
    (0, 6),
    (-2, 0),
    (-3, 0),
    (-4, 0),
    (-6, 0),
    (-2, 2),
    (-3, 3),
    (-4, 4),
    (-6, 6),
    (-2, 6),
    (-3, 6),
    (-4, 6),
    (-2, 4),
    (-3, 4),
]

for min_level, max_level in min_max_limit_combos:

    df = TRAJECTORY_DEVELOPMENT.filter(
        (pl.col("min_level_available") <= min_level) & (pl.col("max_level_available") >= max_level)
    )  # .filter(pl.col("level").is_between(min_level, max_level, closed="both"))
    print(
        f"Min level: {min_level}, max level: {max_level}, max number of subgraphs: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}"
    )

    fig, axs, twin_axs = plots.plot_trajectory_development_figure(
        trajectories=df,
        t0_types=t0_types,
        variables=variables,
        variable_labels=variable_labels,
        row_title_funcs=t0_titles,
        colors=qualitative_colors,
        outfilename=f"trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}",
        lineplot_kwargs=LINEPLOT_KWS_CI_95,
        count_linekwargs=count_linekwargs,
        include_control=False,
        ylim=(-40, 40),
        count_ylim=(0, 24),
        count_ytick_multiples=(3, 1.5),
        xtick_multiples=(2, 1),
        ytick_multiples=(10, 5),
        legend_ncols=1,
        extensions=["png", "pdf", "svg"],
        outpath=OUTPATH,
        savefig_kwargs=dict(dpi=300, transparent=False),
        col_width=6,
        linestyles=linestyles,
        return_fig=True,
    )

    values = {}
    for i, ax in enumerate(axs.flatten()):
        title = ax.get_title()
        values[title] = {}
        c = 0
        for j, line in enumerate(ax.lines):
            label = line.get_label()
            if "child" not in label:
                # Mean values
                xs = line.get_xdata()
                ys = line.get_ydata()
                values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

                # Get confidence interval shading
                points_ = ax.collections[c].get_paths()[0].vertices
                num_points = points_.shape[0]
                bottom_xs = points_[1 : num_points // 2, 0]
                bottom_ys = points_[1 : num_points // 2, 1]
                top_xs = points_[num_points // 2 + 1 : -1, 0][::-1]
                top_ys = points_[num_points // 2 + 1 : -1, 1][::-1]

                values[title][f"{label}_ci"] = {
                    "bottom_x": bottom_xs.tolist(),
                    "bottom_y": bottom_ys.tolist(),
                    "top_x": top_xs.tolist(),
                    "top_y": top_ys.tolist(),
                }

                c += 1
        twin_ax = twin_axs[i]
        label = twin_ax.get_ylabel()
        xs = twin_ax.lines[0].get_xdata()
        ys = twin_ax.lines[0].get_ydata()
        values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

    # Save to json
    with open(
        OUTPATH / f"trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}.json", "w"
    ) as f:
        json.dump(values, f)

    plt.close(fig)

# for min_level, max_level in min_max_limit_combos:

#     df = TRAJECTORY_DEVELOPMENT.filter(
#         (pl.col("min_level_available") == min_level) & (pl.col("max_level_available") == max_level)
#     ) #.filter(pl.col("level").is_between(min_level, max_level, closed="both"))
#     print(f"Min level: {min_level}, max level: {max_level}, max number of trajectories: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}")

#     plots.plot_trajectory_development_figure(
#         trajectories=df,
#         t0_types=t0_types,
#         variables=variables,
#         variable_labels=variable_labels,
#         row_title_funcs=t0_titles,
#         colors=qualitative_colors,
#         outfilename=f"trajectory_development_split_merge_all_available_stricly_between_min_{min_level}_max_{max_level}",
#         lineplot_kwargs=LINEPLOT_KWS_CI_95,
#         count_linekwargs=count_linekwargs,
#         include_control=False,
#         ylim=(-40, 40),
#         count_ylim=(0, 24),
#         count_ytick_multiples=(3, 1.5),
#         xtick_multiples=(2, 1),
#         ytick_multiples=(10, 5),
#         legend_ncols=1,
#         extensions=["png", "pdf", "svg"],
#         outpath=OUTPATH,
#         savefig_kwargs=dict(dpi=300, transparent=False),
#         col_width=6,
#         linestyles=linestyles,
#     )

Min level: 0, max level: 2, max number of subgraphs: 38123
Min level: 0, max level: 3, max number of subgraphs: 31783
Min level: 0, max level: 4, max number of subgraphs: 26655
Min level: 0, max level: 6, max number of subgraphs: 18900
Min level: -2, max level: 0, max number of subgraphs: 30959
Min level: -3, max level: 0, max number of subgraphs: 22710
Min level: -4, max level: 0, max number of subgraphs: 17184
Min level: -6, max level: 0, max number of subgraphs: 10112
Min level: -2, max level: 2, max number of subgraphs: 23020
Min level: -3, max level: 3, max number of subgraphs: 14837
Min level: -4, max level: 4, max number of subgraphs: 9824
Min level: -6, max level: 6, max number of subgraphs: 4484
Min level: -2, max level: 6, max number of subgraphs: 12077
Min level: -3, max level: 6, max number of subgraphs: 9328
Min level: -4, max level: 6, max number of subgraphs: 7243
Min level: -2, max level: 4, max number of subgraphs: 16664
Min level: -3, max level: 4, max number of subgr

In [78]:
t0_types = ["split", "merged", "split-merge"]

variables = [
    "t0_reldiff:rate.sum:pct",
    "t0_reldiff:area:pct",
    # "t0_reldiff_mean:rate.sum:pct",
]
variable_labels = [
    "Total volume rain rate",
    "Total cell area",
    # "Mean rain rate",
]
linestyles = [
    "solid",
    "dashdot",
]

min_max_limit_combos = [
    # (0, 2),
    # (0, 3),
    # (0, 4),
    # (0, 6),
    # (-2, 0),
    # (-4, 0),
    # (-6, 0),
    # (-2, 2),
    # (-3, 3),
    # (-4, 4),
    # (-6, 6),
    # (-2, 6),
    (-3, 6),
    (-3, 0),
    # (-4, 6),
    # (-2, 4),
    # (-3, 4),
]

min_alphabet_index = 2

for min_level, max_level in min_max_limit_combos:

    df = TRAJECTORY_DEVELOPMENT.filter(
        (pl.col("min_level_available") <= min_level) & (pl.col("max_level_available") >= max_level)
    )  # .filter(pl.col("level").is_between(min_level, max_level, closed="both"))
    print(
        f"Min level: {min_level}, max level: {max_level}, max number of subgraphs: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}"
    )

    time_str = rf", subgraph exists at ${min_level*5:d}\ldots{max_level*5:d}$ mins"
    t0_titles = [
        lambda x: f"({alphabet[min_alphabet_index + 1]}) Split events{time_str}",
        lambda x: f"({alphabet[min_alphabet_index + 2]}) Merge events{time_str}",
        lambda x: f"({alphabet[min_alphabet_index + 3]}) Merge-split events{time_str}",
    ]

    fig, axs, twin_axs = plots.plot_trajectory_development_figure(
        trajectories=df,
        t0_types=t0_types,
        variables=variables,
        variable_labels=variable_labels,
        row_title_funcs=t0_titles,
        colors=qualitative_colors,
        outfilename=f"trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}",
        lineplot_kwargs=LINEPLOT_KWS_CI_95,
        count_linekwargs=count_linekwargs,
        include_control=False,
        ylim=(-40, 40),
        count_ylim=(0, 24),
        count_ytick_multiples=(3, 1.5),
        xtick_multiples=(2, 1),
        ytick_multiples=(10, 5),
        legend_ncols=1,
        extensions=["png", "pdf", "svg"],
        outpath=OUTPATH,
        savefig_kwargs=dict(dpi=300, transparent=False),
        linestyles=linestyles,
        row_height=3,
        col_width=5,
        return_fig=True,
    )
    min_alphabet_index += 3

    values = {}
    for i, ax in enumerate(axs.flatten()):
        title = ax.get_title()
        values[title] = {}
        c = 0
        for j, line in enumerate(ax.lines):
            label = line.get_label()
            if "child" not in label:
                # Mean values
                xs = line.get_xdata()
                ys = line.get_ydata()
                values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

                # Get confidence interval shading
                points_ = ax.collections[c].get_paths()[0].vertices
                num_points = points_.shape[0]
                bottom_xs = points_[1 : num_points // 2, 0]
                bottom_ys = points_[1 : num_points // 2, 1]
                top_xs = points_[num_points // 2 + 1 : -1, 0][::-1]
                top_ys = points_[num_points // 2 + 1 : -1, 1][::-1]

                values[title][f"{label}_ci"] = {
                    "bottom_x": bottom_xs.tolist(),
                    "bottom_y": bottom_ys.tolist(),
                    "top_x": top_xs.tolist(),
                    "top_y": top_ys.tolist(),
                }

                c += 1
        twin_ax = twin_axs[i]
        label = twin_ax.get_ylabel()
        xs = twin_ax.lines[0].get_xdata()
        ys = twin_ax.lines[0].get_ydata()
        values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

    # Save to json
    with open(
        OUTPATH
        / f"figC1_trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}.json",
        "w",
    ) as f:
        json.dump(values, f)

    plt.close(fig)

Min level: -3, max level: 6, max number of subgraphs: 9328
Min level: -3, max level: 0, max number of subgraphs: 22710


In [79]:
t0_types = ["split", "merged", "split-merge"]

variables = [
    "t0_reldiff:rate.sum:pct",
    "t0_reldiff:area:pct",
    # "t0_reldiff_mean:rate.sum:pct",
]
variable_labels = [
    "Total volume rain rate",
    "Total cell area",
    # "Mean rain rate",
]
linestyles = [
    "solid",
    "dashdot",
]

min_max_limit_combos = [
    # (0, 2),
    # (0, 3),
    # (0, 4),
    # (0, 6),
    # (-2, 0),
    # (-4, 0),
    # (-6, 0),
    # (-2, 2),
    # (-3, 3),
    # (-4, 4),
    # (-6, 6),
    # (-2, 6),
    (-3, 6),
    (-3, 0),
    # (-4, 6),
    # (-2, 4),
    # (-3, 4),
]

vil_thrs = [
    20,
]  # kg/m^2
# vil_thrs = [5, 10, 20, 40]  # kg/m^2

for max_vil_thr in vil_thrs:
    min_alphabet_index = -1
    # min_alphabet_index = 2
    for min_level, max_level in min_max_limit_combos:

        # df = TRAJECTORY_DEVELOPMENT.filter((pl.col("vil.max_in_trajectory") >= max_vil_thr))  # at least 20 kg/m^2 max vil
        # print(f"Min level: {min_level}, max level: {max_level}, max number of trajectories: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}")

        t0_titles = [
            lambda x: rf"(a) Split events, max VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: rf"(b) Merge events, max VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: rf"(c) Merge-split events, max VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
        ]

        df = TRAJECTORY_DEVELOPMENT.filter(
            (pl.col("vil.max_in_trajectory") >= max_vil_thr)
            & (pl.col("min_level_available") <= min_level)
            & (pl.col("max_level_available") >= max_level)
        )  # .filter(pl.col("level").is_between(min_level, max_level, closed="both"))
        print(
            f"VIL thr {max_vil_thr}, Min level: {min_level}, max level: {max_level}, max number of subgraphs: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}"
        )

        time_str = rf", subgraph exists at {min_level*5:d} $\ldots ${max_level*5:d} mins"
        t0_titles = [
            lambda x: f"({alphabet[min_alphabet_index + 1]}) Split events{time_str}\nmax VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: f"({alphabet[min_alphabet_index + 2]}) Merge events{time_str}\nmax VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: f"({alphabet[min_alphabet_index + 3]}) Merge-split events{time_str}\nmax VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
        ]

        plots.plot_trajectory_development_figure(
            trajectories=df,
            t0_types=t0_types,
            variables=variables,
            variable_labels=variable_labels,
            row_title_funcs=t0_titles,
            colors=qualitative_colors,
            outfilename=f"trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}_max_vil_thr_{max_vil_thr}",
            lineplot_kwargs=LINEPLOT_KWS_CI_95,
            count_linekwargs=count_linekwargs,
            include_control=False,
            ylim=(-40, 40),
            count_ylim=(0, 24),
            count_ytick_multiples=(3, 1.5),
            xtick_multiples=(2, 1),
            ytick_multiples=(10, 5),
            legend_ncols=1,
            extensions=["png", "pdf", "svg"],
            outpath=OUTPATH,
            savefig_kwargs=dict(dpi=300, transparent=False),
            linestyles=linestyles,
            # row_height=3,
            # col_width=5,
            # figure_direction="vertical",
            row_height=4,
            col_width=5,
            figure_direction="horizontal",
        )
        min_alphabet_index += 3

        values = {}
        for i, ax in enumerate(axs.flatten()):
            title = ax.get_title()
            values[title] = {}
            c = 0
            for j, line in enumerate(ax.lines):
                label = line.get_label()
                if "child" not in label:
                    # Mean values
                    xs = line.get_xdata()
                    ys = line.get_ydata()
                    values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

                    # Get confidence interval shading
                    points_ = ax.collections[c].get_paths()[0].vertices
                    num_points = points_.shape[0]
                    bottom_xs = points_[1 : num_points // 2, 0]
                    bottom_ys = points_[1 : num_points // 2, 1]
                    top_xs = points_[num_points // 2 + 1 : -1, 0][::-1]
                    top_ys = points_[num_points // 2 + 1 : -1, 1][::-1]

                    values[title][f"{label}_ci"] = {
                        "bottom_x": bottom_xs.tolist(),
                        "bottom_y": bottom_ys.tolist(),
                        "top_x": top_xs.tolist(),
                        "top_y": top_ys.tolist(),
                    }

                    c += 1
            twin_ax = twin_axs[i]
            label = twin_ax.get_ylabel()
            xs = twin_ax.lines[0].get_xdata()
            ys = twin_ax.lines[0].get_ydata()
            values[title][label] = {"x": xs.tolist(), "y": ys.tolist()}

        # Save to json
        with open(
            OUTPATH
            / f"fig10_trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}_max_vil_thr_{max_vil_thr}.json",
            "w",
        ) as f:
            json.dump(values, f)

        plt.close(fig)

  lambda x: f"({alphabet[min_alphabet_index + 1]}) Split events{time_str}\nmax VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
  lambda x: f"({alphabet[min_alphabet_index + 2]}) Merge events{time_str}\nmax VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
  lambda x: f"({alphabet[min_alphabet_index + 3]}) Merge-split events{time_str}\nmax VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",


VIL thr 20, Min level: -3, max level: 6, max number of subgraphs: 2967
VIL thr 20, Min level: -3, max level: 0, max number of subgraphs: 6210


In [80]:
plt.style.use("../stylefiles/thesis.mplstyle")
plt.rcParams["mathtext.fontset"] = "dejavuserif"

In [81]:
# Figure to thesis

from matplotlib.lines import Line2D
import matplotlib.patches as mpatches
from matplotlib.legend_handler import HandlerTuple

t0_types = ["split", "merged", "split-merge"]

variables = [
    "t0_reldiff:rate.sum:pct",
    "t0_reldiff:area:pct",
    # "t0_reldiff_mean:rate.sum:pct",
]
variable_labels = [
    "Total volume rain rate",
    "Total cell area",
    # "Mean rain rate",
]
linestyles = [
    "solid",
    "dashdot",
]

min_max_limit_combos = [
    # (0, 2),
    # (0, 3),
    # (0, 4),
    # (0, 6),
    # (-2, 0),
    # (-4, 0),
    # (-6, 0),
    # (-2, 2),
    # (-3, 3),
    # (-4, 4),
    # (-6, 6),
    # (-2, 6),
    (-3, 6),
    (-3, 0),
    # (-4, 6),
    # (-2, 4),
    # (-3, 4),
]

vil_thrs = [
    20,
]  # kg/m^2
# vil_thrs = [5, 10, 20, 40]  # kg/m^2

for max_vil_thr in vil_thrs:
    min_alphabet_index = -1
    # min_alphabet_index = 2
    for figc, (min_level, max_level) in enumerate(min_max_limit_combos):

        # df = TRAJECTORY_DEVELOPMENT.filter((pl.col("vil.max_in_trajectory") >= max_vil_thr))  # at least 20 kg/m^2 max vil
        # print(f"Min level: {min_level}, max level: {max_level}, max number of trajectories: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}")

        t0_titles = [
            lambda x: rf"(a) Split events, max VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: rf"(b) Merge events, max VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: rf"(c) Merge-split events, max VIL in subgraph $\geq$ {max_vil_thr} kg m$^{{-2}}$",
        ]

        df = TRAJECTORY_DEVELOPMENT.filter(
            (pl.col("vil.max_in_trajectory") >= max_vil_thr)
            & (pl.col("min_level_available") <= min_level)
            & (pl.col("max_level_available") >= max_level)
        )  # .filter(pl.col("level").is_between(min_level, max_level, closed="both"))
        print(
            f"VIL thr {max_vil_thr}, Min level: {min_level}, max level: {max_level}, max number of subgraphs: {df.select(pl.col('t0_node').over("type").n_unique().max()).item()}"
        )

        time_str = rf"subgraph exists at {min_level*5:d} $\ldots ${max_level*5:d} mins"
        t0_titles = [
            lambda x: f"({alphabet[min_alphabet_index + 1]}) Split events\n       {time_str}",
            # \n       max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: f"({alphabet[min_alphabet_index + 2]}) Merge events\n       {time_str}",
            # \n       max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
            lambda x: f"({alphabet[min_alphabet_index + 3]}) Merge-split events\n      {time_str}",
            # \n       max VIL in trajectory $\geq$ {max_vil_thr} kg m$^{{-2}}$",
        ]

        fig, axs, twin_axs = plots.plot_trajectory_development_figure(
            trajectories=df,
            t0_types=t0_types,
            variables=variables,
            variable_labels=variable_labels,
            row_title_funcs=t0_titles,
            colors=qualitative_colors,
            outfilename=f"trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}_max_vil_thr_{max_vil_thr}",
            lineplot_kwargs=LINEPLOT_KWS_CI_95,
            count_linekwargs=count_linekwargs,
            include_control=False,
            ylim=(-40, 40),
            count_ylim=(0, 24),
            count_ytick_multiples=(3, 1.5),
            xtick_multiples=(2, 1),
            ytick_multiples=(10, 5),
            legend_ncols=1,
            extensions=["png", "pdf", "svg"],
            outpath=OUTPATH,
            savefig_kwargs=dict(dpi=300, transparent=False),
            linestyles=linestyles,
            # row_height=3,
            # col_width=5,
            # figure_direction="vertical",
            row_height=3.5,
            col_width=4,
            figure_direction="horizontal",
            return_fig=True,
        )

        handles, labels = axs[0, 0].get_legend_handles_labels()

        col = qualitative_colors[: len(variables)]
        labels.append("95% CI of mean")
        handles.append([mpatches.Patch(facecolor=c, label="cat", alpha=0.6) for c in col])

        labels.append("Subgraph count")
        handles.append(Line2D([0], [0], color="k", ls="--", label="Subgraph count"))

        for ax in axs.flatten():
            ax.legend_.remove()

            title_txt = ax.get_title()
            ax.set_title("")
            ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda x, p: f"{x*5:.0f}"))
            ax.set_xlabel("Time since event [min]")

            ax.axvline(x=0, color="k", linestyle=":", linewidth=1.0, zorder=4)

            ax.text(
                x=0.03,
                y=0.90,
                s=title_txt,
                fontsize="medium",
                va="center",
                transform=ax.transAxes,
                zorder=5,
                bbox=dict(boxstyle="Square,pad=0.3", fc="white", ec="none", alpha=1.0, lw=2),
            )

        if figc == 0:
            fig.legend(
                handles=handles,
                labels=labels,
                bbox_to_anchor=(0.5, 1.1),
                loc="upper center",
                ncols=len(handles),
                fontsize="medium",
                title_fontsize="medium",
                frameon=False,
                handler_map={list: HandlerTuple(None)},
            )

        plot_utils.save_figs(
            fig=fig,
            delete_fig=True,
            outpath=OUTPATH,
            name=f"trajectory_development_split_merge_all_available_between_min_{min_level}_max_{max_level}_max_vil_thr_{max_vil_thr}",
            extensions=["png", "pdf", "svg"],
            savefig_kwargs=dict(dpi=300, transparent=False),
        )

        min_alphabet_index += 3

VIL thr 20, Min level: -3, max level: 6, max number of subgraphs: 2967
VIL thr 20, Min level: -3, max level: 0, max number of subgraphs: 6210
