# Nguyen Results

This notebook compares the results of Rüdisser et al. (2025) with those of Nguyen et al. (2025).

### Importing packages

In [None]:
# Enable autoreload so that edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
from pathlib import Path

# The project root is taken to be two directories above the current working
# directory; it is appended to sys.path so the `src.` and `scripts.` imports
# used further down can be resolved.
# NOTE(review): this assumes the kernel's working directory is the notebook's own folder - confirm.
project_root = Path().resolve().parent.parent
sys.path.append(str(project_root))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import datetime
import yaml

from loguru import logger

# hide all logs except for errors:
# remove the handlers loguru currently has, then add a single stderr handler at ERROR level
logger.remove()
logger.add(sys.stderr, level="ERROR")

from tqdm import tqdm

import matplotlib.font_manager as fm
import seaborn as sns

# Set up matplotlib style
sns.set_context("talk")
# NOTE(review): two style presets are applied back to back; the later "ticks"
# call presumably takes precedence over "whitegrid" - confirm which one is wanted.
sns.set_style("whitegrid")
sns.set_style("ticks")

from scripts.data.visualise.insitu_geosphere import (
    prop
)

# Colours used by the plots in this notebook.
# NOTE(review): the variable names do not match the colour values assigned to
# them (e.g. geo_lime holds "gold") - presumably the names are kept for
# compatibility with other project code; confirm.
geo_cornflowerblue = "dodgerblue"
geo_lime = "gold"
geo_magenta = "firebrick"

from src.arcane2.data.data_utils.event import (
    EventCatalog
)

from src.arcane2.data.realtime.realtime_insitu_dataset import RealtimeInsituDataset
from src.arcane2.data.catalogs.nguyen_dataset import Nguyen_Dataset
from src.arcane2.data.abstract.multi_signal_dataset import MultiSignalDataset

### Loading config

In [None]:
# we load the config file used during training

# The file is opened in a context manager so the handle is closed as soon as
# parsing is done instead of being left open until garbage collection.
config_path = project_root / "config/base_dataset/curated_realtime_dataset_lowres.yaml"
with open(config_path) as config_file:
    config = yaml.safe_load(config_file)

### Loading results

In [None]:
# we prepare the filepaths for the results that were generated during inference

run_names = ["train_arcane_rtsw_new_bounds_new_drops"]

# One results pickle is looked up per run inside that run's cache folder;
# runs whose file does not exist are left out of the list.
result_paths = []
for run_name in run_names:
    cache_path = project_root / f"cache/{run_name}"
    path = cache_path / "all_results_curated_realtime_dataset_lowres_tminus_all.pkl"
    if not path.exists():
        continue
    result_paths.append(path)

In [None]:
# we load the results and concatenate them into a single dataframe

# Fail early with a clear message: without this guard an empty `result_paths`
# leaves `all_results` undefined and only surfaces later as a NameError.
if not result_paths:
    raise FileNotFoundError(
        "No result files were found; run inference first or check `run_names` and the cache folder."
    )

for i, path in enumerate(tqdm(result_paths)):
    # NOTE: unpickling can execute arbitrary code - only load result files you produced yourself.
    if i == 0:
        # the first file initialises the frame
        all_results = pd.read_pickle(path)
    else:
        # every further file is appended, then rows sharing an index value are
        # collapsed to the first non-missing entry per column
        loaded = pd.read_pickle(path)
        all_results = pd.concat([all_results, loaded], axis=0).sort_index()
        all_results = all_results.combine_first(loaded)
        all_results = all_results.groupby(all_results.index).first()


In [None]:
# we drop the missing values
# (every row that contains at least one NaN is removed)

all_results = all_results.dropna(axis="index", how="any")

In [None]:
# check the percentage of missing values in all_results
# (timestamps of a gap-free 30-minute grid between the first and last index
# entry that do not occur in the index)
first_timestamp = all_results.index.min()
last_timestamp = all_results.index.max()
expected_range = pd.date_range(start=first_timestamp, end=last_timestamp, freq="30min")

missing = expected_range.difference(all_results.index)
missing_fraction = len(missing) / len(expected_range)
missing_percentage = missing_fraction * 100

print(f"Missing values in all_results: {len(missing)} ({missing_percentage:.2f}%)")

In [None]:
# If the model trained correctly, the minimum value should be close to 0

# The column name is bound to a variable first: re-using double quotes inside a
# double-quoted f-string is a SyntaxError on Python versions before 3.12.
min_check_column = "predicted_value_train_arcane_rtsw_new_bounds_new_drops_0_tminus1"
print(f"Minimum value: {all_results[min_check_column].min()}")

### Creating Catalog Dataset

In [None]:
# We create the catalog dataset from the Nguyen Catalogs

# Settings for the catalog are read from the first entry of
# `single_signal_datasets` in the config, with fallbacks for missing keys.
catalog_config = config["dataset"]["single_signal_datasets"][0]

event_types = catalog_config.get("event_types", "ICME")
filters = catalog_config.get("filters")
cap = catalog_config.get("cap")
resample_freq = catalog_config.get("resample_freq", "10min")

# Both catalog files (ICME and sheath) are handed to the dataset.
catalog_paths = [
    project_root / "data/dataverse_files/ICME_catalog_OMNI.csv",
    project_root / "data/dataverse_files/Sheath_catalog_OMNI.csv",
]

catalog_dataset = Nguyen_Dataset(
    folder_paths=catalog_paths,
    resample_freq=resample_freq,
    event_types=event_types,
    filters=filters,
    cap=cap,
)

### Creating In Situ Dataset

In [None]:
# We create the insitu dataset from the NOAA archive.
# All processing options (resampling, padding, interpolation, scaling) are read
# from the second entry of `single_signal_datasets` in the config.

insitu_config = config["dataset"]["single_signal_datasets"][1]

folder_path = project_root / "data/noaa_archive_gsm.p"

# Options without a fallback: a missing key yields None.
components, resample, resample_method, resample_freq, padding, lin_interpol = (
    insitu_config.get(option)
    for option in (
        "components",
        "resample",
        "resample_method",
        "resample_freq",
        "padding",
        "lin_interpol",
    )
)
# NOTE(review): the fallback here is the string "None", not the None object - confirm this is intended.
scaling = insitu_config.get("scaling", "None")

insitu_dataset = RealtimeInsituDataset(
    folder_path=folder_path,
    components=components,
    resample=resample,
    resample_freq=resample_freq,
    resample_method=resample_method,
    padding=padding,
    lin_interpol=lin_interpol,
    scaling=scaling,
)

### Creating MultiSignalDataset

In [None]:
# The two datasets are combined into a MultiSignalDataset

# Position of the catalog dataset inside the list handed over below.
catalog_idx = 0

multi_signal_dataset = MultiSignalDataset(
    catalog_idx=catalog_idx,
    single_signal_datasets=[catalog_dataset, insitu_dataset],
)

## Preprocess results

In [None]:
from src.arcane2.data.utils import (
    compare_catalogs_for_results,
    merge_columns_by_mean,
    shift_columns,
)

# We merge the columns and shift them by the time shift
prediction_prefix = "predicted_value_train_arcane_rtsw_new_bounds_new_drops_"
df_merged = merge_columns_by_mean(all_results, prefix=prediction_prefix)
df_shifted_and_merged = shift_columns(df_merged)

In [None]:
# To test the validity of our approach, we generate a catalog from the created ground truth time series

# This line only fetches the reference events stored on the catalog dataset;
# the catalog regenerated from the time series is built in a later cell and
# compared against the detectable subset of these events.
original_catalog = catalog_dataset.catalog.event_cat

In [None]:
# drop nan values
# (rows with a NaN in any column are discarded before events are checked for coverage)
df_shifted_and_merged_processed_precat = df_shifted_and_merged.dropna(axis="index", how="any")

In [None]:
# check the percentage of missing values in df_shifted_and_merged_processed_precat
# (timestamps of a gap-free 30-minute grid between the first and last index
# entry that do not occur in the index)
precat_index = df_shifted_and_merged_processed_precat.index
expected_range = pd.date_range(start=precat_index.min(), end=precat_index.max(), freq="30min")

missing = expected_range.difference(precat_index)
missing_fraction = len(missing) / len(expected_range)
missing_percentage = missing_fraction * 100

print(
    f"Missing values in df_shifted_and_merged_processed_precat: {len(missing)} ({missing_percentage:.2f}%)"
)

In [None]:
# Split the original catalog into events that the prediction time series covers
# (kept as "detectable") and events with too many missing samples (their time
# range is blanked out so it is dropped from the evaluation afterwards).

# Work on an explicit copy: the frame is the result of a dropna() call and is
# written to below, which could otherwise trigger pandas' chained-assignment warning.
df_shifted_and_merged_processed_precat = df_shifted_and_merged_processed_precat.copy()

# The index itself is not changed by the masking below, so its bounds are read once.
series_start = df_shifted_and_merged_processed_precat.index[0]
series_end = df_shifted_and_merged_processed_precat.index[-1]

detectable_original_catalog = []

for event in original_catalog:
    # only events that begin strictly inside the covered period are considered
    if series_start < event.begin < series_end:
        # calculate the number of datapoints in the time range at a resolution of 30 min
        expected_nr_datapoints = int((event.end - event.begin).total_seconds() / 60 / 30)
        # Count only rows that still hold data. Rows blanked out for an earlier
        # (overlapping) event, or by a previous run of this cell, remain in the
        # frame as NaN and must not be counted as available samples.
        actual_nr_datapoints = (
            df_shifted_and_merged_processed_precat.loc[event.begin : event.end].dropna().shape[0]
        )
        if actual_nr_datapoints > expected_nr_datapoints * 0.99:
            detectable_original_catalog.append(event)
        else:
            # mark the whole event range as missing; the NaN rows are removed in the next step
            df_shifted_and_merged_processed_precat.loc[event.begin : event.end] = np.nan

In [None]:
# drop nan values
# (this removes the time ranges that were blanked out for insufficiently covered events)
df_shifted_and_merged_processed = df_shifted_and_merged_processed_precat.dropna(
    axis="index", how="any"
)

### Analysing Results

In [None]:
# Build an event catalog from the "NGUYEN_catalog-ICME" column of the cleaned
# time series and keep its list of events.
true_catalog = EventCatalog(
    dataframe=df_shifted_and_merged_processed,
    key="NGUYEN_catalog-ICME",
    event_types="CME",
    catalog_name="True Catalog",
    spacecraft="Wind",
    resample_freq="30min",
    creep_delta=30,
)
extracted_catalog = true_catalog.event_cat

In [None]:
# Number of events in the catalog extracted from the time series (displayed as the cell output).
len(extracted_catalog)

In [None]:
# We create a dataframe with the number of events per year

# One row per year from 1998 to 2023, labelled with the last day of that year.
# The index is built explicitly rather than through the "1Y" frequency alias,
# which newer pandas versions deprecate.
years = list(range(1998, 2024))
dateindex = pd.to_datetime([f"{year}-12-31" for year in years])


def count_events_per_year(catalog):
    """Return, for each entry of `years`, how many events of `catalog` begin in that year."""
    # One pass over the catalog instead of rescanning it for every year.
    start_years = pd.Series([event.begin.year for event in catalog], dtype="int64")
    return start_years.value_counts().reindex(years, fill_value=0).to_numpy()


# Integer columns (instead of object columns filled cell by cell) keep the
# frame directly usable for plotting and arithmetic.
event_numbers = pd.DataFrame(
    {
        "Detectable": count_events_per_year(detectable_original_catalog),
        "Extracted": count_events_per_year(extracted_catalog),
    },
    index=dateindex,
)

In [None]:
# We plot the number of events per year

fig, axs = plt.subplots(1, 1, figsize=(10, 5))

# NOTE(review): three colours are supplied although event_numbers has two
# columns - presumably only the first two are used; confirm.
event_numbers.plot(
    ax=axs, kind ="bar", color=[geo_magenta, geo_lime, geo_cornflowerblue], width=0.8
)

axs.set_ylabel("Number of Events")
axs.set_xlabel("Date")
plt.grid(True, which="both", linestyle="--", linewidth=0.5, alpha=0.7)

# Format x-axis to show only the year
axs.set_xticks(range(0, len(event_numbers.index), 4))  # one tick on every fourth bar, i.e. every fourth year
axs.set_xticklabels(event_numbers.index[::4].year, ha="center", rotation=0)

plt.tight_layout()

print("Event numbers:")
plt.show()

In [None]:
# Compare the catalog extracted from the time series with the detectable
# original events; the two ignored positions of the result are not used here.
comparison_results = compare_catalogs_for_results(
    extracted_catalog, detectable_original_catalog
)
TP, FP, FN, _, found_already, detected, _, ious = comparison_results

In [None]:
# We expect fewer events in the extracted catalog than in the original catalog. The results should be perfect scores.

print("####################################")
print(" RESULTS FOR GENERATED CATALOG")
print("####################################")
print("")
print(f"original: {len(detectable_original_catalog)}")
print(f"predicted: {len(extracted_catalog)}")
print("")
print("")

n_tp, n_fp, n_fn = len(TP), len(FP), len(FN)
print(f"TP: {n_tp}")
print(f"FP: {n_fp}")
print(f"FN: {n_fn}")

predicted = len(extracted_catalog)

# Every ratio is guarded against an empty denominator; the fallback values
# follow the conventions of the threshold sweep further down in this notebook
# (precision 0, recall 1, F1 0 when the respective denominator is empty).
ratio = predicted / n_tp if n_tp else 0
precision = n_tp / (n_tp + n_fp) if (n_tp + n_fp) else 0
recall = n_tp / (n_tp + n_fn) if (n_tp + n_fn) else 1
f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) else 0

print(f"ratio: {ratio}")
print("")
# The printed precision and recall are the same values that enter F1, so the
# three numbers are always mutually consistent.
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1: {f1}")
print("")
print(f"mean iou: {np.mean(ious)}")

## Threshold comparison

In [None]:
# We create the threshold classifier baseline

from scipy.constants import k, proton_mass, pi

df = multi_signal_dataset.df.copy()

# Thresholds of the baseline (units as printed below).
b_threshold = 8
beta_threshold = 0.3
v_threshold = 30  # km/s

# The speed threshold is not applied to a velocity column; it is only turned
# into a temperature threshold via T = pi * m_p * v^2 / (8 * k), which is the
# mean-speed relation v = sqrt(8 k T / (pi m_p)) solved for T.
# The SI value is derived from the km/s value so the two cannot drift apart.
v_threshold_si = v_threshold * 1e3  # m/s
T_threshold = v_threshold_si**2 * proton_mass * pi / (8 * k)

# round to the nearest 1000 K
T_threshold = np.round(T_threshold, -3)

print(
    f"Thresholds: T = {T_threshold} K, B = {b_threshold} nT, beta = {beta_threshold}, V = {v_threshold} km/s"
)

cols = ["true_value", "predicted_value_threshold"]

result_df = pd.DataFrame(columns=cols, index=df.index)

# start with "no event" everywhere; the ground truth is copied from the catalog column
result_df["predicted_value_threshold"] = 0
result_df["true_value"] = df["NGUYEN_catalog-ICME"]

# Set predicted_value_threshold to 1 only when all three conditions are true
is_above_thresholds = (
    (df["NOAA Realtime Archive_insitu-bt"] >= b_threshold)
    & (df["NOAA Realtime Archive_insitu-beta"] <= beta_threshold)
    & (df["NOAA Realtime Archive_insitu-tp"] <= T_threshold)
)
result_df.loc[is_above_thresholds, "predicted_value_threshold"] = 1


### Preprocess results

In [None]:
# We start with the eventwise comparison of the threshold classifier

print("####################################")
print(" RESULTS FOR EVENTWISE THRESHOLD CLASSIFIER")
print("####################################")
print("")


# Turn the pointwise threshold prediction into a list of events.
threshold_catalog = EventCatalog(
    event_types="CME",
    catalog_name="Threshold Catalog",
    spacecraft="OMNI",
    dataframe=result_df,
    key="predicted_value_threshold",
    creep_delta=30,
).event_cat

# discard events that do not last longer than 30 minutes
threshold_catalog = [
    event for event in threshold_catalog if event.duration > datetime.timedelta(minutes=30)
]

# NOTE(review): the reference here is `extracted_catalog`, whereas the model
# sweep below is scored against `detectable_original_catalog` - confirm that
# the two baselines are meant to use different references.
(
    TP_threshold,
    FP_threshold,
    FN_threshold,
    threshold_delays,
    found_already_threshold,
    detected_threshold,
    threshold_durations,
    ious_threshold,
) = compare_catalogs_for_results(threshold_catalog, extracted_catalog)

n_tp_threshold = len(TP_threshold)
n_fp_threshold = len(FP_threshold)
n_fn_threshold = len(FN_threshold)

print(f"TP: {n_tp_threshold}")
print(f"FP: {n_fp_threshold}")
print(f"FN: {n_fn_threshold}")


predicted_threshold = len(threshold_catalog)

# Each ratio is guarded against an empty denominator (e.g. no threshold events
# at all) so the cell reports a score instead of raising ZeroDivisionError.
# The fallback values follow the conventions of the threshold sweep below.
if n_tp_threshold + n_fp_threshold == 0:
    precision_threshold = 0
else:
    precision_threshold = n_tp_threshold / (n_tp_threshold + n_fp_threshold)

if n_tp_threshold + n_fn_threshold == 0:
    recall_threshold = 1
else:
    recall_threshold = n_tp_threshold / (n_tp_threshold + n_fn_threshold)

if precision_threshold + recall_threshold == 0:
    f1_threshold = 0
else:
    f1_threshold = (
        2
        * (precision_threshold * recall_threshold)
        / (precision_threshold + recall_threshold)
    )

print(f"Precision: {precision_threshold}")
print(f"Recall: {recall_threshold}")
print(f"F1: {f1_threshold}")
print(f"mean iou: {np.mean(ious_threshold)}")

In [None]:
# Precision-recall curves of the model for every lead time t, obtained by
# sweeping a cut-off on the mean predicted probability of each event. For every
# t the operating point with the highest F1 is stored in `resultsdict`.

fig, axs = plt.subplots(1, 1, figsize=(10, 5))
colors = plt.cm.plasma(np.linspace(0, 1, 26))
sm = plt.cm.ScalarMappable(cmap=plt.cm.plasma, norm=plt.Normalize(vmin=-26, vmax=-1))

# Probability cut-offs that are scanned; identical for every t, so built once.
thresholds = np.arange(0, 1, 0.1)

resultsdict = {}

for t in tqdm(range(1, 26, 1)):
    # NOTE(review): t*2 presumably expresses t hours in 30-minute steps - confirm against shift_columns.
    key = f"predicted_value_train_arcane_rtsw_new_bounds_new_drops_tminus{t*2}"

    # Operating points that pass the filter below. The cut-off is stored next to
    # each point so that all five lists share one index; indexing `thresholds`
    # directly would be off whenever a cut-off has been skipped.
    kept_thresholds = []
    precisions = []
    recalls = []
    f1s = []
    mean_ious = []

    cat_temp = EventCatalog(
        event_types="CME",
        spacecraft="Wind",
        dataframe=df_shifted_and_merged_processed,
        key=key,
        creep_delta=30,
        thresh=0.1,
    )

    # attach to every event the mean predicted value over its duration
    for event in cat_temp.event_cat:
        event.proba = df_shifted_and_merged_processed.loc[event.begin : event.end, key].mean()

    for thresh in tqdm(thresholds):
        icmes_in_proba = [event for event in cat_temp.event_cat if event.proba > thresh]

        # The per-event IoUs get their own name: unpacking them into the list
        # that collects the per-threshold means would overwrite that list.
        TP, FP, FN, _, found_already, detected, _, event_ious = compare_catalogs_for_results(
            icmes_in_proba, detectable_original_catalog
        )

        predicted = len(icmes_in_proba)

        if predicted == 0 or len(TP) + len(FP) == 0:
            precision = 0
        else:
            precision = len(TP) / (len(TP) + len(FP))

        if len(TP) + len(FN) == 0:
            recall = 1
        else:
            recall = len(TP) / (len(TP) + len(FN))

        if precision + recall == 0:
            f1 = 0
        else:
            f1 = 2 * (precision * recall) / (precision + recall)

        # degenerate points (zero precision or zero recall) are left out of the curve
        if precision > 0 and recall > 0:
            kept_thresholds.append(thresh)
            precisions.append(precision)
            recalls.append(recall)
            f1s.append(f1)
            mean_ious.append(np.mean(event_ious))

    axs.plot(recalls, precisions, color=colors[t], label="")

    if f1s:
        best = int(np.argmax(f1s))
        resultsdict[t] = {
            "threshold": kept_thresholds[best],
            "recall": recalls[best],
            "precision": precisions[best],
            "f1": f1s[best],
            "iou": mean_ious[best],
        }
    else:
        # No usable operating point for this lead time: record NaN instead of
        # failing on an empty argmax, so every t keeps an entry.
        resultsdict[t] = {
            name: np.nan for name in ("threshold", "recall", "precision", "f1", "iou")
        }

# axis setup does not depend on t, so it is done once after the loop
axs.set_xlim(0.1, 1)
axs.set_ylim(0.1, 1)
axs.set_xlabel("Recall")
axs.set_ylabel("Precision")

cb1 = fig.colorbar(sm, ax=axs, orientation="vertical")

# The curve for lead time t uses colors[t] = plasma(t / 25), while the colorbar
# maps a value v to plasma((v + 26) / 25). The colour of lead time t therefore
# sits at v = t - 26, which is why the labels are written in reverse order
# (the match is exact at both ends and approximate for the ticks in between).
ticks = [-1, -5, -10, -15, -20, -25]
tick_labels = [f"{-t} h" for t in ticks[::-1]]

cb1.set_ticks(ticks)
cb1.set_ticklabels(tick_labels)

# the threshold classifier is a single operating point, drawn as a corner
axs.plot(
    [0, recall_threshold, recall_threshold],
    [precision_threshold, precision_threshold, 0],
    color=geo_cornflowerblue,
    label="Threshold Classifier",
)

# show legend
axs.legend(loc="lower left")

fig.tight_layout()
plt.show()

In [None]:
# Collect the best operating point of every lead time and plot F1 against it.
t = np.arange(1, 26, 1)
precisions = [resultsdict[i]["precision"] for i in t]
recalls = [resultsdict[i]["recall"] for i in t]
f1s = [resultsdict[i]["f1"] for i in t]
thresholds = [resultsdict[i]["threshold"] for i in t]
ious = [resultsdict[i]["iou"] for i in t]

fig, axs = plt.subplots(1, 1, figsize=(10, 5))

axs.plot(t, f1s, label="F1", color=geo_cornflowerblue)

# Raw string: "\d" is not a valid escape sequence, so the plain literal emits a
# warning on recent Python versions. The rendered label is unchanged.
axs.set_xlabel(r"$\delta$ [hours]")
axs.set_ylabel("F1-Score")

In [None]:
# Report the lead time with the highest F1 and the scores reached there.
best_index = np.argmax(f1s)
best_f1 = max(f1s)

print(f"Maximum F1: {best_f1} at {t[best_index]} hours")
print(f"Precision at maximum F1: {precisions[best_index]}")
print(f"Recall at maximum F1: {recalls[best_index]}")
print(f"IOU at maximum F1: {ious[best_index]}")