In [002_calc_some_stats.ipynb](https://github.com/dosquisd/NMDB-FD-PredictorWithGraphs/blob/5952726527e9855ce448bb8a9d8ad7a7f33317ba/notebooks/002_calc_some_stats.ipynb) there was a lot of code that mixed dataset creation with plotting, so I created this notebook as a dedicated place that only creates the datasets.

In [1]:
%%capture
%load_ext autoreload
%autoreload 2

In [3]:
import itertools
import logging
from typing import Any, Callable, Dict, List, Optional, Tuple

import networkx as nx
import nolds
import numpy as np
import pandas as pd
from scipy.stats import entropy

from utils import (
    MIN_VALUE_THRESHOLD,
    ROOTDIR,
    AdjacencyMethod,
    DistanceTransformation,
    EventData,
    GraphEvent,
    Normalizer,
    encode_variables_to_filename,
    get_events,
    graph_fractal_dimension,
    logger,
    read_dataset,
)

# Let the logger imported from utils emit DEBUG-level records in this notebook.
logger.setLevel(logging.DEBUG)

# Cutoff-rigidity boundaries used by calc_graph_metrics to split nodes into groups:
#   low:    rigidity <  LOW__CUTOFF_RIGIDITY
#   medium: LOW__CUTOFF_RIGIDITY <= rigidity < HIGH__CUTOFF_RIGIDITY
#   high:   rigidity >= HIGH__CUTOFF_RIGIDITY
# NOTE(review): the unit is not stated anywhere in this notebook (presumably GV) — confirm.
LOW__CUTOFF_RIGIDITY = 3.0
HIGH__CUTOFF_RIGIDITY = 6.0

## Create dataset

In [8]:
def calc_graph_metrics(
    events: Dict[str, EventData],
    *,
    threshold: float = 0.0,
    valid_methods: Optional[List[AdjacencyMethod]] = None,
) -> Tuple[pd.DataFrame, List[Dict[str, Any]]]:
    """Compute MST-based graph metrics for every event and processing setup.

    For each combination of event, ``DistanceTransformation`` member,
    ``Normalizer`` member and adjacency method, the event's raw data is
    transformed, normalized, converted to a graph with
    ``GraphEvent.get_graph_networkx`` and reduced to its minimum spanning
    tree. Global metrics and averaged node centralities of that tree form
    one output row.

    Side effects:
        ``events[event_date]["graphs"][adj_method]`` is overwritten with the
        spanning tree on every iteration. The key does not include the
        transformation or normalization, so after the call it holds the tree
        of the last combination processed for that adjacency method.

    Args:
        events: Mapping from event date to its data. Each entry must provide
            ``"raw"``, ``"cutoff_rigidity"``, ``"altitude"`` and ``"graphs"``;
            ``"drop"``, ``"intensity"`` and ``"dst"`` are optional and default
            to ``0.0``, ``"Unknown"`` and ``0.0``.
        threshold: Forwarded unchanged to ``GraphEvent.get_graph_networkx``.
        valid_methods: Adjacency methods to evaluate. Defaults to
            ``[AdjacencyMethod.MANHATTAN, AdjacencyMethod.MINKOWSKI]``.

    Returns:
        ``(dataset_df, dataset)``: the rows as a ``DataFrame`` and the same
        rows as a list of dicts (each including the ``"graph"`` object).

    Raises:
        Exception: Any error raised while building the graph is logged,
            diagnostic snapshots are displayed, and the error is re-raised.
    """

    def calculate_avg_node_metric(
        metrics: Dict[str, float],
        fn: Optional[Callable[[np.ndarray], float]] = None,
        *,
        value_per_node: Optional[Dict[str, float]] = None,
    ) -> float:
        """Reduce per-node metric values to one number.

        Each node's value is multiplied by ``value_per_node.get(node, 1.0)``
        before ``fn`` (the arithmetic mean by default) is applied, so nodes
        missing from ``value_per_node`` keep their full value.
        """
        if fn is None:
            fn = lambda nums: float(nums.mean())  # noqa: E731

        if value_per_node is None:
            value_per_node = {}

        values = np.array(
            [value * value_per_node.get(node, 1.0) for node, value in metrics.items()]
        )
        return fn(values)

    if valid_methods is None:
        valid_methods = [AdjacencyMethod.MANHATTAN, AdjacencyMethod.MINKOWSKI]

    # A single flat loop over the cartesian product instead of four nested loops.
    dataset = []
    combinations = itertools.product(
        events.keys(), DistanceTransformation, Normalizer, valid_methods
    )
    for event_date, transform_method, normalization, adj_method in combinations:
        data = events[event_date]
        # reset_index returns a new frame, so the clamping below never touches
        # the frame stored in ``events``.
        raw_df = data["raw"].reset_index(drop=True)
        logger.info(
            f"Processing event: {event_date} with transformation: {transform_method.value}, "
            f"normalization: {normalization.value}, adjacency method: {adj_method.value}, "
            f"raw shape: {raw_df.shape}"
        )

        # Keep near-zero magnitudes away from the log transformation.
        if transform_method == DistanceTransformation.LOG:
            raw_df[raw_df.abs() < MIN_VALUE_THRESHOLD] = MIN_VALUE_THRESHOLD

        # Transform raw_df data before any normalization
        transformed_data = transform_method.transform(raw_df.to_numpy())
        transformed_df = pd.DataFrame(transformed_data, columns=raw_df.columns)
        transformed_df[transformed_df.abs() < MIN_VALUE_THRESHOLD] = MIN_VALUE_THRESHOLD

        # Normalize data and drop every column that ended up containing NaN
        normalized_data = normalization.normalize(transformed_df.to_numpy())
        normalized_df = pd.DataFrame(normalized_data, columns=transformed_df.columns)
        nan_sum = normalized_df.isna().sum()
        nan_columns = nan_sum[nan_sum > 0].index.tolist()
        if nan_columns:
            logger.warning(f"NAN COLUMNS DROPPED: {nan_columns}")
            normalized_df = normalized_df.drop(columns=nan_columns)

        # Event-level metadata, shared by the graph event and the output row
        event_metadata = {
            "drop": data.get("drop", 0.0),
            "intensity": data.get("intensity", "Unknown"),
            "dst": data.get("dst", 0.0),
        }

        # Create graph event with the normalized data and metadata
        graph_event = GraphEvent(
            data=normalized_df,
            metadata={
                # Event metadata
                **event_metadata,
                # Station metadata
                "cutoff_rigidity": data["cutoff_rigidity"],
                "altitude": data["altitude"],
            },
        )

        try:
            graph = graph_event.get_graph_networkx(adj_method, threshold=threshold)
        except Exception as e:
            logger.error(
                f"Error processing event {event_date} with method {adj_method.value}: {e}"
            )

            # Show where NaNs appeared along the pipeline before failing.
            print("\nraw_df snapshot:")
            nan_sum = raw_df.isna().sum()
            display(nan_sum[nan_sum > 0])

            print("transformed_df snapshot:")
            nan_sum = transformed_df.isna().sum()
            display(transformed_df)
            display(nan_sum[nan_sum > 0])

            print("normalized_df snapshot:")
            nan_sum = normalized_df.isna().sum()
            display(normalized_df)
            display(nan_sum[nan_sum > 0])

            # Bare raise re-raises the active exception with its traceback intact.
            raise

        # Calculate MST and store it
        graph = nx.minimum_spanning_tree(graph)
        events[event_date]["graphs"][adj_method] = graph

        # Weight distribution in order to calculate other metrics
        weights = np.array([attrs["weight"] for _, _, attrs in graph.edges(data=True)])

        # Node values for node-based (local) metrics, keyed by output column name
        node_metrics = {
            "avg_katz": nx.katz_centrality(graph),
            "avg_closeness": nx.closeness_centrality(graph),
            "avg_betweenness": nx.betweenness_centrality(graph),
            "avg_laplacian": nx.laplacian_centrality(graph),
        }

        # 0/1 membership weights per cutoff-rigidity group. They depend only on
        # the event, so they are built once here and reused for every metric.
        cutoff_rigidity = graph_event.metadata["cutoff_rigidity"]
        group_weights = {
            "low": {
                node: 1.0 if rigidity < LOW__CUTOFF_RIGIDITY else 0.0
                for node, rigidity in cutoff_rigidity.items()
            },
            "medium": {
                node: 1.0
                if LOW__CUTOFF_RIGIDITY <= rigidity < HIGH__CUTOFF_RIGIDITY
                else 0.0
                for node, rigidity in cutoff_rigidity.items()
            },
            "high": {
                node: 1.0 if rigidity >= HIGH__CUTOFF_RIGIDITY else 0.0
                for node, rigidity in cutoff_rigidity.items()
            },
        }

        # Average node metrics per group (low, medium, high cutoff rigidity).
        # The mean runs over ALL nodes: nodes outside the group contribute 0
        # instead of being excluded, and a node without a rigidity entry keeps
        # weight 1.0 and therefore counts toward every group.
        avg_per_group = {}
        for metric_name, centrality in node_metrics.items():
            for interval, node_weights in group_weights.items():
                avg_per_group[metric_name + "_" + interval] = calculate_avg_node_metric(
                    centrality, value_per_node=node_weights
                )

        # Global graph metrics
        dataset.append(
            {
                "event_date": event_date,
                # Event metadata
                **event_metadata,
                # Metric metadata
                "transformation": transform_method.value,
                "normalization": normalization.value,
                "adjacency_method": adj_method.value,
                "graph": graph,
                # Graph global metrics
                "global_efficiency": nx.global_efficiency(graph),
                "estrada_index": nx.estrada_index(graph),
                "entropy": entropy(weights),
                "fractal": graph_fractal_dimension(graph, seed=37)[0],
                "hurst_rs": nolds.hurst_rs(weights, fit="poly"),
                "modularity": nx.algorithms.community.modularity(
                    graph,
                    list(nx.algorithms.community.greedy_modularity_communities(graph)),
                ),
                "assortativity": nx.degree_assortativity_coefficient(graph),
                # Average node metrics over the whole graph
                **{
                    metric_name: calculate_avg_node_metric(centrality)
                    for metric_name, centrality in node_metrics.items()
                },
                # Average node metrics per group
                **avg_per_group,
            }
        )

    dataset_df = pd.DataFrame(dataset)
    return dataset_df, dataset

### Save datasets

In [None]:
# Adjacency methods evaluated for every generated dataset.
valid_adj_methods = [AdjacencyMethod.MANHATTAN, AdjacencyMethod.MINKOWSKI]

# Input files and flag values to sweep over.
filename_options = ["all.txt", "all.original.txt", "all.imp.txt"]
imput_data_options = [False, True]
use_threshold_options = [False, True]

options = itertools.product(filename_options, imput_data_options, use_threshold_options)
for filename, imput_data, use_threshold in options:
    # Only runs where both flags are on, or both are off, are processed.
    flags_agree = use_threshold == imput_data
    if not flags_agree:
        logger.info(
            f"Skipping file: {filename} with imput_data={imput_data} and "
            f"use_threshold={use_threshold} (invalid combination)"
        )
        continue

    logger.info(f"Processing file: {filename} with imput_data={imput_data}")
    events = get_events(
        filename=filename,
        imput_data=imput_data,
        use_threshold=use_threshold,
    )
    dataset_df, dataset = calc_graph_metrics(events, valid_methods=valid_adj_methods)

    logger.info(f"Saving dataset for file: {filename} with imput_data={imput_data}")
    # The "graph" column is left out of the CSV written to disk.
    csv_ready_df = dataset_df.drop(columns=["graph"])
    output_name = encode_variables_to_filename(
        event_filename=filename,
        imput_data=imput_data,
        use_threshold=use_threshold,
    )
    output_path = ROOTDIR / "data" / output_name
    csv_ready_df.to_csv(output_path, index=False)

## Read the datasets to ensure they were created properly

In [10]:
# Sanity check: load one CSV from ROOTDIR / "data" and display the result.
# NOTE(review): assumes this file name is what encode_variables_to_filename
# produces for "all.txt" with both flags False — confirm.
read_dataset(ROOTDIR / "data" / "dataset_all_imput-False_threshold-False.csv")

Unnamed: 0,event_date,drop,intensity,dst,transformation,normalization,adjacency_method,global_efficiency,estrada_index,entropy,...,avg_closeness_low,avg_closeness_medium,avg_closeness_high,avg_betweenness_low,avg_betweenness_medium,avg_betweenness_high,avg_laplacian_low,avg_laplacian_medium,avg_laplacian_high,graph
0,2023-04-23,6.57,G2,-213.0,none,min_max,manhattan,0.260470,78.118375,3.425651,...,0.122219,0.043008,0.017454,0.132820,0.018023,0.001894,0.048578,0.018893,0.004967,"(ATHN, MXCO, NANM, ROME, AATB, BKSN, JUNG, JUN..."
1,2023-04-23,6.57,G2,-213.0,none,min_max,minkowski,0.262560,79.558149,3.432043,...,0.117957,0.040433,0.016817,0.139235,0.017717,0.001894,0.048802,0.019755,0.005055,"(ATHN, MXCO, NANM, ROME, AATB, BKSN, JUNG, JUN..."
2,2023-04-23,6.57,G2,-213.0,none,z_score,manhattan,0.269012,80.006590,3.433672,...,0.120643,0.046529,0.020508,0.113453,0.023766,0.009531,0.041691,0.024559,0.008078,"(ATHN, MXCO, NANM, ROME, AATB, BKSN, JUNG, JUN..."
3,2023-04-23,6.57,G2,-213.0,none,z_score,minkowski,0.275242,81.061524,3.406625,...,0.119887,0.053726,0.021017,0.101417,0.026149,0.012586,0.040507,0.022641,0.010740,"(ATHN, MXCO, NANM, ROME, AATB, BKSN, JUNG, JUN..."
4,2023-04-23,6.57,G2,-213.0,none,robust,manhattan,0.300785,84.485612,3.420624,...,0.141403,0.059335,0.027612,0.077224,0.031097,0.005437,0.045756,0.022627,0.007684,"(ATHN, MXCO, NANM, ROME, AATB, BKSN, JUNG, JUN..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1015,2005-09-11,12.25,G3,-139.0,exponential,robust,minkowski,0.370329,33.360243,2.587530,...,0.177041,0.029660,0.061841,0.142125,0.009524,0.071795,0.090309,0.020465,0.044618,"(ESOI, MXCO, ROME, AATB, LMKS, MOSC, NEWK, CAL..."
1016,2005-09-11,12.25,G3,-139.0,exponential,decimal_scaling,manhattan,0.376799,33.796712,2.626843,...,0.190431,0.027422,0.053406,0.169231,0.000000,0.051282,0.112813,0.010838,0.034985,"(ESOI, MXCO, ROME, AATB, LMKS, MOSC, NEWK, CAL..."
1017,2005-09-11,12.25,G3,-139.0,exponential,decimal_scaling,minkowski,0.385665,34.457961,2.618426,...,0.193768,0.031328,0.054469,0.160440,0.000000,0.051282,0.112444,0.014234,0.033603,"(ESOI, MXCO, ROME, AATB, LMKS, MOSC, NEWK, CAL..."
1018,2005-09-11,12.25,G3,-139.0,exponential,none,manhattan,0.376799,33.796712,2.626843,...,0.190431,0.027422,0.053406,0.169231,0.000000,0.051282,0.112813,0.010838,0.034985,"(ESOI, MXCO, ROME, AATB, LMKS, MOSC, NEWK, CAL..."
