# Final Results

Recalculate the metrics and plot them, this time for many more stations per event date. Essentially, a more robust version of [002](./002_calc_some_metrics.ipynb).

In [2]:
from pathlib import Path
import pandas as pd

from typing import TypedDict, Optional

from random import sample
from utils.constants import WINDOW_SIZE, Events, NAN_THRESHOLD
from utils import load_data, calc_metrics, plot_metrics, read_metrics_file

import multiprocessing as mp

import warnings


warnings.filterwarnings("ignore")

In [3]:
class DateEventsInfo(TypedDict):
    """
    Per-date configuration of an event.

    Attributes:
        bounds: list[str] Start and end bounds of the event, in that order.
        freq: str Frequency of the data for this date.
        stations: list[str] Names of the stations that are relevant for this date.
    """

    bounds: list[str]
    freq: str
    stations: list[str]


In [7]:
# Event to analyze; its name without spaces is used to build the data paths
EVENT: Events = "Forbush Decrease"
MAX_SAMPLES: int = 10  # Target number of stations per date
REPETITION: bool = True  # If True, stations already calculated are calculated again

event_replace: str = EVENT.replace(" ", "")

# Relevant dates for the event.
# The station lists are only examples where the event was clear;
# modify them as needed.
datetimes: dict[str, DateEventsInfo] = {
    "2023-04-23": {
        "bounds": ["2023-04-23 23:00:00", "2023-04-24 06:00:00"],
        "freq": "1h",
        "stations": [
            "AATB",
            "APTY",
            "IRK2",
            "LMKS",
            "NEWK",
            "NAIN",
            "SOPO",
        ],
    },
    "2024-03-24": {
        "bounds": ["2024-03-24 14:00:00", "2024-03-25 04:30:00"],
        "freq": "90min",
        "stations": [
            "APTY",
            "DOMC",
            "INVK",
            "JUNG1",
            "KIEL2",
            "LMKS",
            "MWSN",
            "NEWK",
            "MXCO",
            "OULU",
            # TXBY, YKTK
        ],
    },
    "2024-05-10": {
        "bounds": ["2024-05-10 18:00:00", "2024-05-11 01:00:00"],
        "freq": "1h",
        "stations": [
            "APTY",
            "DOMB",
            "DOMC",
            "INVK",
            "IRK3",
            "JBGO",
            "KERG",
            "KIEL2",
            "LMKS",
            "MWSN",
            # SOPB, PWNK, SOPO, TERA, THUL, TXBY, YKTK
        ],
    },
}

## Calculate metrics

In [25]:
def get_valid_stations(
    df: pd.DataFrame, threshold: Optional[float] = None
) -> list[str]:
    """
    Return the stations (columns) whose ratio of NaN values is below a threshold.
    The DataFrame is the one with the data of the stations (all.txt).

    The "datetime" column is excluded by name, so it may sit at any position
    in the DataFrame. Stations without any NaN value are always valid.

    Args:
        df (pd.DataFrame): DataFrame with a "datetime" column plus one column
            per station.
        threshold (Optional[float]): Ratio of NaN values (NaN count / number of
            rows) from which a station is considered invalid.
            If None, it will use the global NAN_THRESHOLD.

    Returns:
        list[str]: List of valid station names (columns), in the DataFrame's
            column order.
    """
    if threshold is None:
        threshold = NAN_THRESHOLD

    nan_counts = df.drop(columns="datetime").isna().sum()
    total = len(df)

    # A station is invalid only when it has at least one NaN and its NaN ratio
    # reaches the threshold. The first condition also prevents an empty
    # DataFrame (total == 0) from discarding every station.
    invalid = (nan_counts > 0) & (nan_counts / total >= threshold)

    return nan_counts[~invalid].index.tolist()

In [31]:
class StationsToChoose(TypedDict):
    """
    Pool of candidate stations for a date and how many to draw from it.

    Attributes:
        stations: list[str] Candidate station names.
        num_sample: int Number of stations to sample from the candidates.
    """

    stations: list[str]
    num_sample: int


# Stations with an acceptable ratio of NaN values, per date
stations: dict[str, list[str]] = {
    date: get_valid_stations(
        load_data(f"./data/{event_replace}/{date}/all.txt"),
        threshold=NAN_THRESHOLD,
    )
    for date in datetimes
}

# Stations already calculated, per date: taken from the *.csv files found in
# the date's folder (station name = filename part before the first "_").
# dict.fromkeys removes duplicates while keeping order, so a station with
# several files is listed (and counted) only once.
choosen_stations: dict[str, list[str]] = {
    date: list(
        dict.fromkeys(
            csv_file.name.strip().split("_", 1)[0].upper()
            for csv_file in Path(f"./data/{event_replace}/{date}").glob("*.csv")
        )
    )
    for date in datetimes
}

stations_to_choose: dict[str, StationsToChoose] = {
    date: {
        # Remove stations already calculated
        # and those that are fixed to be calculated.
        # Sorted so the candidate order does not depend on set ordering.
        "stations": sorted(
            set(stations[date])
            - set(choosen_stations[date])
            - set(datetimes[date]["stations"])
        ),
        # Final number of samples to choose (never negative). A station that
        # is both fixed and already calculated is counted only once.
        "num_sample": max(
            0,
            MAX_SAMPLES
            - len(set(choosen_stations[date]) | set(datetimes[date]["stations"])),
        ),
    }
    for date in datetimes
}

# Without repetition of stations already calculated.
# k is capped at the number of candidates because random.sample raises
# ValueError when asked for more items than the population holds.
plot_stations = {
    date: sample(
        items["stations"],
        k=min(items["num_sample"], len(items["stations"])),
    )
    + datetimes[date]["stations"]
    for date, items in stations_to_choose.items()
}

# Here is added repetition of stations already calculated (if needed).
# Duplicates are dropped (order preserved) so that no station is scheduled
# twice for the same date.
if REPETITION:
    for date in plot_stations:
        plot_stations[date] = list(
            dict.fromkeys(plot_stations[date] + choosen_stations[date])
        )


In [32]:
# Takes approximately 4 minutes to calculate all the metrics on the author's PC.
# Each date's data file is read and indexed once, then shared by all the
# stations of that date instead of being reloaded for every station.
data_by_date = {
    date: load_data(f"./data/{event_replace}/{date}/all.txt").set_index("datetime")
    for date, date_stations in plot_stations.items()
    if date_stations  # nothing to load for a date without stations
}

# One (data, station, date) task per station to calculate
arguments = [
    (data_by_date[date], station, date)
    for date, date_stations in plot_stations.items()
    for station in date_stations
]

with mp.Pool(processes=mp.cpu_count()) as pool:
    results = pool.starmap(
        calc_metrics,
        arguments,
    )


## Plotting

In [8]:
# Stations to plot, per date: one entry for every *.csv file in the date's
# folder, named after the filename part before the first underscore
plot_stations: dict[str, list[str]] = {
    date: [
        csv_path.name.strip().split("_", 1)[0].upper()
        for csv_path in Path(f"./data/{event_replace}/{date}").glob("*.csv")
    ]
    for date in datetimes
}

In [9]:
def plot_metrics_wrapper(args_tuple) -> None:
    """
    Read the metrics file of one station and plot it with ``plot_metrics``.

    Takes a single tuple so it can be used directly with ``Pool.map``.

    Args:
        args_tuple: Tuple ``(date, station, suffix)``. With ``suffix == 1`` the
            "lepel_ziv" column is dropped (if present) and the relevant metrics
            are ``["*"]``; with any other suffix the relevant metrics are
            ``["lepel_ziv"]`` only.
    """
    date, station, suffix = args_tuple

    metrics_df = read_metrics_file(
        event=event_replace,
        date=date,
        station=station,
        window_size=WINDOW_SIZE,
        datetime_cols={"datetime": ""},
    )

    if suffix != 1:
        relevant_metrics = ["lepel_ziv"]
    else:
        # "lepel_ziv" goes in its own plot (the other suffix)
        metrics_df = metrics_df.drop(columns=["lepel_ziv"], errors="ignore")
        relevant_metrics = ["*"]

    event_info = datetimes[date]
    freq_date_range = event_info["freq"]
    bounds = event_info["bounds"]

    plot_metrics(
        df=metrics_df,
        relevant_metrics=relevant_metrics,
        window_size=WINDOW_SIZE,
        event=event_replace,
        date=date,
        station=station,
        min_datetime=bounds[0],
        max_datetime=bounds[1],
        freq_date_range=freq_date_range,
        suffix=str(suffix),
        save_format="pdf",
        show=False,
    )


# One task per (date, station, suffix); suffixes 1 and 2 select the two
# plots that plot_metrics_wrapper produces for each station
arguments_plot = [
    (event_date, station_name, plot_suffix)
    for event_date, station_names in plot_stations.items()
    for station_name in station_names
    for plot_suffix in (1, 2)
]

with mp.Pool(processes=mp.cpu_count()) as pool:
    pool.map(plot_metrics_wrapper, arguments_plot)