# Final Results

Recalculate the metrics and plot them, this time for many more stations per datetime. This is essentially a more robust version of [002](./002_calc_some_metrics.ipynb).

In [2]:
import multiprocessing as mp
import warnings
from pathlib import Path
from random import sample
from typing import Optional, TypedDict

import pandas as pd

from utils import (
    calc_metrics,
    load_data,
    plot_metrics,
    plot_metrics_one,
    read_metrics_file,
)
from utils.constants import (
    EWM_ALPHA,
    NAN_THRESHOLD,
    WINDOW_SIZE,
    Events,
    datetimes,
)

warnings.filterwarnings("ignore")

In [2]:
class StationsToChoose(TypedDict):
    """
    Candidate stations for one date, plus how many of them to draw.

    Attributes:
        stations: list[str]. Names of the stations that can be drawn at
            random.
        num_sample: int. How many stations to draw from ``stations``.
    """

    stations: list[str]
    num_sample: int

In [3]:
# Run configuration for this notebook.
EVENT: Events = "Forbush Decrease"  # Event whose data is processed
MAX_SAMPLES: int = 10  # Samples per date
REPETITION: bool = False  # If True, it will repeat stations already calculated
EWM: bool = False  # If True, it will calculate EWM metrics

# EVENT with its spaces removed; used below as the directory name in ./data/
event_replace: str = EVENT.replace(" ", "")

## Calculate metrics

In [5]:
def get_valid_stations(
    df: pd.DataFrame, threshold: Optional[float] = None
) -> list[str]:
    """
    Return the stations (columns) whose ratio of NaN values is acceptable.

    The DataFrame is the one with the data of the stations (all.txt): a
    "datetime" column plus one column per station.

    A station is rejected only when it has at least one NaN and its NaN ratio
    (NaN count / number of rows) is greater than or equal to the threshold.
    Stations without any NaN are always kept, which also means every station
    of an empty DataFrame is kept.

    Args:
        df (pd.DataFrame): DataFrame with the stations data.
        threshold (Optional[float]): Ratio of NaN values from which a station
            is considered invalid. If None, the global NAN_THRESHOLD is used.

    Returns:
        list[str]: Valid station names, in their original column order. The
            "datetime" column is never part of the result, wherever it sits
            in the DataFrame.

    Raises:
        KeyError: If the DataFrame has no "datetime" column.
    """
    if threshold is None:
        threshold = NAN_THRESHOLD

    # Exclude "datetime" by name instead of by position, so the result does
    # not depend on "datetime" being the first column.
    nan_counts = df.drop(columns="datetime").isna().sum()
    total = len(df)

    # `count == 0` is checked first: it keeps NaN-free stations regardless of
    # the threshold and avoids dividing by zero on an empty DataFrame.
    return [
        station
        for station, count in nan_counts.items()
        if count == 0 or count / total < threshold
    ]

In [6]:
# For every date, the stations of that date's all.txt that pass the NaN filter
stations: dict[str, list[str]] = {}
for date in datetimes:
    stations[date] = get_valid_stations(
        load_data(f"./data/{event_replace}/{date}/all.txt"),
        threshold=NAN_THRESHOLD,
    )

# Stations that already have a CSV file in each date's folder. The station
# name is the upper-cased part of the file name before the first underscore.
chosen_stations: dict[str, list[str]] = {
    date: [
        csv_path.name.strip().split("_", 1)[0].upper()
        for csv_path in Path(f"./data/{event_replace}/{date}").glob("*.csv")
    ]
    for date in datetimes
}

stations_to_choose: dict[str, StationsToChoose] = {}
for date in datetimes:
    # Names left out of the random draw: the date itself and the stations
    # that are fixed to be calculated for this date.
    # NOTE(review): `date` is part of this set, so it also uses up one of the
    # MAX_SAMPLES slots below - confirm this is intended.
    excluded = {date} | set(datetimes[date]["stations"].keys())

    # Slots still free once the excluded names are counted
    num_samples = MAX_SAMPLES - len(excluded)

    stations_to_choose[date] = {
        "stations": list(set(stations[date]) - excluded),
        # Final number of samples to choose (never negative)
        "num_sample": max(num_samples, 0),
    }

# Random draw (without repetition) for each date, followed by the stations
# that are fixed for that date in `datetimes`.
plot_stations = {
    date: sample(
        items["stations"],
        # random.sample raises ValueError when k is larger than the
        # population, so never ask for more stations than are available.
        k=min(items["num_sample"], len(items["stations"])),
    )
    + list(datetimes[date]["stations"].keys())
    for date, items in stations_to_choose.items()
}

if REPETITION:
    # Also include the stations already calculated; the set union drops
    # duplicates.
    plot_stations = {
        date: list(set(names) | set(chosen_stations[date]))
        for date, names in plot_stations.items()
    }
else:
    # Remove stations already calculated if repetition is False
    plot_stations = {
        date: list(set(names) - set(chosen_stations[date]))
        for date, names in plot_stations.items()
    }

In [7]:
# Restrict this run to a single date
plot_stations = {
    date: names
    for date, names in plot_stations.items()
    if date == "2024-05-10"
}

In [18]:
# Calculating all the metrics takes roughly 4 minutes on the author's pc.
# Suffix handed to calc_metrics: only set when EWM is enabled with an alpha.
if EWM_ALPHA and EWM:
    suffix = f"ewm_alpha_{EWM_ALPHA}"
else:
    suffix = ""


def prepare_df(path: str) -> pd.DataFrame:
    """
    Load the data file at ``path`` and index it by its "datetime" column.

    Args:
        path (str): Path of the file handed to ``load_data``.

    Returns:
        pd.DataFrame: The indexed data. When both EWM_ALPHA and EWM are
            truthy, its exponentially weighted mean (alpha=EWM_ALPHA) is
            returned instead.
    """
    indexed = load_data(path).set_index("datetime")

    if not (EWM_ALPHA and EWM):
        return indexed

    return indexed.ewm(alpha=EWM_ALPHA).mean()


# One (DataFrame, station, date, suffix) task per station. The data of a date
# is loaded (and optionally smoothed) once, instead of re-reading the same
# all.txt for every station of that date.
arguments = []
for date, date_stations in plot_stations.items():
    if not date_stations:
        continue  # Nothing to calculate for this date, so skip the load

    date_df = prepare_df(f"./data/{event_replace}/{date}/all.txt")

    # Every task gets its own copy, so tasks stay as independent from each
    # other as when each one loaded the file by itself.
    arguments.extend(
        (date_df.copy(), station, date, suffix) for station in date_stations
    )

# As many worker processes as CPUs reported; each tuple in `arguments` is
# unpacked into the positional arguments of calc_metrics.
with mp.Pool(processes=mp.cpu_count()) as pool:
    results = pool.starmap(
        calc_metrics,
        arguments,
    )

## Plotting

In [31]:
# Unique station names per date, taken from the CSV files in the date folder
plot_stations: dict[str, list[str]] = {
    date: list(
        {
            # Station name: upper-cased file name up to the first underscore
            csv_path.name.strip().split("_", 1)[0].upper()
            for csv_path in Path(f"./data/{event_replace}/{date}").glob("*.csv")
        }
    )
    for date in datetimes
}

In [32]:
# Number of stations per date - nice, this is the expected output
{date: len(names) for date, names in plot_stations.items()}

{'2023-04-23': 10, '2024-03-24': 11, '2024-05-10': 11}

In [5]:
def _mpl_worker_init() -> None:
    """
    Point MPLCONFIGDIR at a freshly created temporary directory.

    Used as the ``initializer`` of the plotting pools, so it runs once in
    every worker process and each worker gets a directory of its own
    (prefixed with "mplcache-").
    """
    from os import environ
    from tempfile import mkdtemp

    environ["MPLCONFIGDIR"] = mkdtemp(prefix="mplcache-")

In [11]:
# Turn EWM on from here on: the plotting wrappers below only build their
# "-ewm_alpha_..." file suffix when EWM (and EWM_ALPHA) is truthy.
EWM = True

In [16]:
# As before, keep a single date only
plot_stations = {
    date: names
    for date, names in plot_stations.items()
    if date == "2024-03-24"
}

### Two different plots

In [12]:
def plot_metrics_wrapper(args_tuple: tuple[str, str, int]) -> None:
    """
    Read the metrics of one station and plot one of its two metric groups.

    Takes a single tuple so it can be used directly with ``Pool.map``.

    Args:
        args_tuple (tuple[str, str, int]): ``(date, station, suffix)``. An odd
            ``suffix`` plots every metric except "lepel_ziv"; an even one
            plots only "lepel_ziv". The number is also forwarded, as a
            string, as the ``suffix`` of ``plot_metrics``.
    """
    date, station, suffix = args_tuple

    # The metrics file carries an extra suffix when EWM is enabled
    file_suffix = f"-ewm_alpha_{EWM_ALPHA}" if EWM_ALPHA and EWM else ""
    df = read_metrics_file(
        event=event_replace,
        date=date,
        station=station,
        window_size=WINDOW_SIZE,
        datetime_cols={"datetime": ""},
        suffix=file_suffix,
    )

    is_odd = suffix % 2 == 1
    # Odd suffix: all metrics except "lepel_ziv"; even suffix: only "lepel_ziv"
    relevant_metrics = ["*"] if is_odd else ["lepel_ziv"]
    if is_odd:
        df = df.drop(columns=["lepel_ziv"], errors="ignore")

    # Prefer the station's own datetime range when a non-empty one is
    # configured; otherwise use the bounds of the whole date.
    per_station = datetimes[date]["stations"]
    has_own_range = station in per_station and per_station[station]
    min_datetime, max_datetime = (
        per_station[station] if has_own_range else datetimes[date]["bounds"]
    )

    plot_metrics(
        df=df,
        event=event_replace,
        date=date,
        station=station,
        window_size=WINDOW_SIZE,
        relevant_metrics=relevant_metrics,
        min_datetime=min_datetime,
        max_datetime=max_datetime,
        freq_hours=2,
        save_format="pdf",
        suffix=str(suffix),
        show=False,
    )


# Two plot jobs per (date, station): ids 1 and 2, shifted to 3 and 4 when EWM
# is on (the shift keeps the odd/even meaning used by plot_metrics_wrapper).
arguments_plot = [
    (date, station, plot_id + 2 if EWM else plot_id)
    for date, date_stations in plot_stations.items()
    for station in date_stations
    for plot_id in (1, 2)
]

with mp.Pool(processes=mp.cpu_count(), initializer=_mpl_worker_init) as pool:
    pool.map(plot_metrics_wrapper, arguments_plot)

### One plot

TODO: Change color palette

In [13]:
def plot_metrics_one_wrapper(args_tuple: tuple[str, str]) -> None:
    """
    Read the metrics of one station and draw them with ``plot_metrics_one``.

    Takes a single tuple so it can be used directly with ``Pool.map``.

    Args:
        args_tuple (tuple[str, str]): ``(date, station)`` to plot.
    """
    date, station = args_tuple

    # The metrics file carries an extra suffix when EWM is enabled
    file_suffix = f"-ewm_alpha_{EWM_ALPHA}" if EWM_ALPHA and EWM else ""
    df = read_metrics_file(
        event=event_replace,
        date=date,
        station=station,
        window_size=WINDOW_SIZE,
        datetime_cols={"datetime": ""},
        suffix=file_suffix,
    )

    # Prefer the station's own datetime range when a non-empty one is
    # configured; otherwise use the bounds of the whole date.
    per_station = datetimes[date]["stations"]
    has_own_range = station in per_station and per_station[station]
    min_datetime, max_datetime = (
        per_station[station] if has_own_range else datetimes[date]["bounds"]
    )

    plot_metrics_one(
        df=df,
        event=event_replace,
        date=date,
        station=station,
        window_size=WINDOW_SIZE,
        relevant_metrics=None,
        min_datetime=min_datetime,
        max_datetime=max_datetime,
        freq_hours=2,
        save_format="pdf",
        show=False,
    )


# One plot job per (date, station) pair
arguments_plot = [
    (date, station)
    for date, date_stations in plot_stations.items()
    for station in date_stations
]

with mp.Pool(processes=mp.cpu_count(), initializer=_mpl_worker_init) as pool:
    pool.map(plot_metrics_one_wrapper, arguments_plot)