# Final Results

Recalculate the metrics and plot them for many more stations per event date. Essentially, this is a more robust version of [002](./002_calc_some_metrics.ipynb).

In [None]:
import pandas as pd
from pathlib import Path

from typing import TypedDict, Optional

from random import sample
from utils.constants import WINDOW_SIZE, Events, NAN_THRESHOLD, EWM_ALPHA, METRICS
from utils import (
    load_data,
    calc_metrics,
    plot_metrics,
    read_metrics_file,
    plot_metrics_one,
)

import multiprocessing as mp

import warnings


warnings.filterwarnings("ignore")

In [10]:
# A [start, end] pair of datetime strings, e.g.
# ["2023-04-23 23:00:00", "2023-04-24 06:00:00"].
# `list` takes a single type parameter, so the element type is given once.
DatetimeBounds = list[str]


class DateEventsInfo(TypedDict):
    """
    Information about the event date.

    Attributes:
        bounds: DatetimeBounds. The [start, end] datetime bounds of the event.
                They apply to every station that has no bounds of its own.
        freq: str. The frequency of the data for the event date.
        stations: dict[str, Optional[DatetimeBounds]]. The relevant stations
                  for the event date, each mapped to its own [start, end]
                  bounds or to None when the date-level bounds apply.
    """

    # A single [start, end] pair (not a list of pairs): it is unpacked directly
    # into (min_datetime, max_datetime) by the code that consumes it.
    bounds: DatetimeBounds
    freq: str
    stations: dict[str, Optional[DatetimeBounds]]


class StationsToChoose(TypedDict):
    """
    Candidate pool and sample size used to draw random stations.

    Attributes:
        stations: list[str]. Names of the stations that may be drawn.
        num_sample: int. How many stations to draw from that pool.
    """

    stations: list[str]
    num_sample: int


In [5]:
# Event under study.
EVENT: Events = "Forbush Decrease"
# Samples (stations) per date.
MAX_SAMPLES: int = 10
# When True, stations that were already calculated are processed again.
REPETITION: bool = True
# When True, the EWM metrics are calculated.
EWM: bool = True

# Event name without spaces; used below to build the ./data/<event>/ paths.
event_replace: str = EVENT.replace(" ", "")

# Relevant dates for the event, keyed by date string.
# The station lists are only examples where the event was clear;
# modify them as needed.
# A station mapped to None has no bounds of its own; the code below then falls
# back to the date-level "bounds".
# TODO: Fix datetime event for each station
datetimes: dict[str, DateEventsInfo] = {
    "2023-04-23": {
        "bounds": ["2023-04-23 23:00:00", "2023-04-24 06:00:00"],
        "freq": "1h",
        "stations": {
            "AATB": None,
            "APTY": None,
            "IRK2": None,
            "LMKS": None,
            "NEWK": None,
            "NAIN": None,
            "SOPO": None,
        },
    },
    "2024-03-24": {
        "bounds": ["2024-03-24 14:00:00", "2024-03-25 04:30:00"],
        "freq": "90min",
        "stations": {
            "APTY": None,
            "DOMC": None,
            "INVK": None,
            "JUNG1": None,
            "KIEL2": None,
            "LMKS": None,
            "MWSN": None,
            "NEWK": None,
            "MXCO": None,
            "OULU": None,
            # TXBY, YKTK
        },
    },
    "2024-05-10": {
        "bounds": ["2024-05-10 18:00:00", "2024-05-11 01:00:00"],
        "freq": "1h",
        "stations": {
            "APTY": None,
            "DOMB": None,
            "DOMC": None,
            "INVK": None,
            "IRK3": None,
            "JBGO": None,
            "KERG": None,
            "KIEL2": None,
            "LMKS": None,
            "MWSN": None,
            # SOPB, PWNK, SOPO, TERA, THUL, TXBY, YKTK
        },
    },
}

## Calculate metrics

In [5]:
def get_valid_stations(
    df: pd.DataFrame, threshold: Optional[float] = None
) -> list[str]:
    """
    Return the stations (columns) whose ratio of NaN values is acceptable.
    The DataFrame is the one with the data of the stations (all.txt).

    A station is discarded when it has at least one NaN and its NaN ratio
    (NaN count / number of rows) is greater than or equal to `threshold`.
    Stations without any NaN are always kept.

    Args:
        df (pd.DataFrame): DataFrame with the stations data. It must contain a
            "datetime" column, which is never reported as a station.
        threshold (Optional[float]): Ratio of NaN values to consider a station invalid.
            If None, it will use the global NAN_THRESHOLD.

    Returns:
        list[str]: List of valid station names (columns), in their original order.

    Raises:
        KeyError: If `df` has no "datetime" column.
    """
    if threshold is None:
        threshold = NAN_THRESHOLD

    # Exclude "datetime" by name so the result does not depend on the position
    # of that column inside the frame.
    nan_counts = df.drop(columns="datetime").isna().sum()
    total = len(df)

    # `count == 0` is checked first: it keeps NaN-free stations regardless of
    # the threshold and avoids dividing by zero on an empty frame.
    return [
        station
        for station, count in nan_counts.items()
        if count == 0 or count / total < threshold
    ]

In [6]:
# Stations with an acceptable NaN ratio, per date.
stations: dict[str, list[str]] = {
    date: get_valid_stations(
        load_data(f"./data/{event_replace}/{date}/all.txt"),
        threshold=NAN_THRESHOLD,
    )
    for date in datetimes
}

# Stations that already have a CSV in the date folder. The station name is the
# upper-cased part of the file name before the first underscore. Names are
# deduplicated so a station with several CSV files is counted only once.
choosen_stations: dict[str, list[str]] = {
    date: sorted(
        {
            csv_path.name.strip().split("_", 1)[0].upper()
            for csv_path in Path(f"./data/{event_replace}/{date}").glob("*.csv")
        }
    )
    for date in datetimes
}

stations_to_choose: dict[str, StationsToChoose] = {}
for date, event_info in datetimes.items():
    # Stations already calculated plus those fixed for this date. The union
    # avoids counting twice a fixed station that also has a CSV on disk.
    taken = set(choosen_stations[date]) | set(event_info["stations"])
    candidates = sorted(set(stations[date]) - taken)

    stations_to_choose[date] = {
        "stations": candidates,
        # Final number of samples to choose: never negative and never larger
        # than the candidate pool (random.sample raises ValueError otherwise).
        "num_sample": max(0, min(MAX_SAMPLES - len(taken), len(candidates))),
    }

# Without repetition of stations already calculated
plot_stations = {
    date: sample(items["stations"], k=items["num_sample"])
    + list(datetimes[date]["stations"])
    for date, items in stations_to_choose.items()
}

# Here is added repetition of stations already calculated (if needed)
if REPETITION:
    for date in plot_stations:
        # The set union also drops duplicates
        plot_stations[date] = list(
            set(plot_stations[date]) | set(choosen_stations[date])
        )

In [7]:
# Calculating all the metrics takes approximately 4 minutes on my PC.
# The output is tagged as EWM only when both EWM_ALPHA and EWM are set, which is
# the condition under which the data is smoothed and the metrics files are read.
suffix = f"ewm_alpha_{EWM_ALPHA}" if EWM_ALPHA and EWM else ""


def prepare_df(path: str) -> pd.DataFrame:
    """
    Load the stations data at `path`, indexed by its "datetime" column.

    When both EWM_ALPHA and EWM are truthy, the exponentially weighted mean
    (alpha=EWM_ALPHA) of the data is returned; otherwise the data is returned
    as loaded.
    """
    frame = load_data(path).set_index("datetime")
    if not (EWM_ALPHA and EWM):
        return frame
    return frame.ewm(alpha=EWM_ALPHA).mean()


# Load (and optionally smooth) the data of each date once, instead of
# repeating the same work for every station of that date.
prepared_data = {
    date: prepare_df(f"./data/{event_replace}/{date}/all.txt")
    for date, date_stations in plot_stations.items()
    if date_stations  # nothing to load for a date without stations
}

# One task per (date, station): (data, station, date, suffix).
arguments = [
    (prepared_data[date], station, date, suffix)
    for date, date_stations in plot_stations.items()
    for station in date_stations
]

with mp.Pool(processes=mp.cpu_count()) as pool:
    results = pool.starmap(calc_metrics, arguments)


## Plotting

In [5]:
# Unique stations that have at least one CSV in the date folder. The station
# name is the upper-cased part of the file name before the first underscore.
plot_stations: dict[str, list[str]] = {
    date: list(
        {
            csv_file.name.strip().split("_", 1)[0].upper()
            for csv_file in Path(f"./data/{event_replace}/{date}").glob("*.csv")
        }
    )
    for date in datetimes
}

In [6]:
# Number of stations found per date (this is the expected output)
{date: len(names) for date, names in plot_stations.items()}

{'2023-04-23': 10, '2024-03-24': 11, '2024-05-10': 11}

### Two different plots

In [36]:
def plot_metrics_wrapper(args_tuple: tuple[str, str, int]) -> None:
    """
    Read the metrics file of one station and save one of its two plots.

    Args:
        args_tuple: (date, station, suffix). An odd `suffix` plots every metric
            except "lepel_ziv"; any other value plots only "lepel_ziv". The
            suffix is also forwarded, as a string, to `plot_metrics`.
    """
    date, station, suffix = args_tuple
    file_suffix = f"-ewm_alpha_{EWM_ALPHA}" if EWM_ALPHA and EWM else ""

    df = read_metrics_file(
        event=event_replace,
        date=date,
        station=station,
        window_size=WINDOW_SIZE,
        datetime_cols={"datetime": ""},
        suffix=file_suffix,
    )

    plot_all_but_lepel_ziv = suffix % 2 == 1
    if plot_all_but_lepel_ziv:
        df = df.drop(columns=["lepel_ziv"], errors="ignore")
        relevant_metrics = ["*"]
    else:
        relevant_metrics = ["lepel_ziv"]

    # Station-specific bounds take precedence; otherwise use the date bounds.
    event_info = datetimes[date]
    station_bounds = event_info["stations"].get(station)
    min_datetime, max_datetime = station_bounds or event_info["bounds"]

    plot_metrics(
        window_size=WINDOW_SIZE,
        relevant_metrics=relevant_metrics,
        df=df,
        event=event_replace,
        date=date,
        station=station,
        min_datetime=min_datetime,
        max_datetime=max_datetime,
        freq_date_range=event_info["freq"],
        save_format="pdf",
        suffix=str(suffix),
        show=False,
    )


# Two tasks per (date, station), one per plot kind. Kinds 1 and 2 are shifted to
# 3 and 4 when EWM is enabled; the shift keeps their parity.
arguments_plot = [
    (date, station, plot_kind + (2 if EWM else 0))
    for date, date_stations in plot_stations.items()
    for station in date_stations
    for plot_kind in (1, 2)
]

with mp.Pool(processes=mp.cpu_count()) as pool:
    pool.map(plot_metrics_wrapper, arguments_plot)

### One plot

In [11]:
def plot_metrics_one_wrapper(args_tuple: tuple[str, str]) -> None:
    """
    Read the metrics file of one station and save all its metrics in one plot.

    Args:
        args_tuple: (date, station) pair to plot.
    """
    date, station = args_tuple
    file_suffix = f"-ewm_alpha_{EWM_ALPHA}" if EWM_ALPHA and EWM else ""

    df = read_metrics_file(
        event=event_replace,
        date=date,
        station=station,
        window_size=WINDOW_SIZE,
        datetime_cols={"datetime": ""},
        suffix=file_suffix,
    )

    # Station-specific bounds take precedence; otherwise use the date bounds.
    event_info = datetimes[date]
    station_bounds = event_info["stations"].get(station)
    min_datetime, max_datetime = station_bounds or event_info["bounds"]

    # NOTE(review): `display` is neither defined nor imported here; presumably
    # the IPython built-in. Confirm it is available inside pool workers.
    display(df.head())

    plot_metrics_one(
        window_size=WINDOW_SIZE,
        relevant_metrics=None,
        df=df,
        event=event_replace,
        date=date,
        station=station,
        min_datetime=min_datetime,
        max_datetime=max_datetime,
        freq_date_range=event_info["freq"],
        rotation_xticks=60,
        save_format="pdf",
        show=False,
    )


# One task per (date, station).
arguments_plot = [
    (date, station)
    for date, date_stations in plot_stations.items()
    for station in date_stations
]

with mp.Pool(processes=mp.cpu_count()) as pool:
    pool.map(plot_metrics_one_wrapper, arguments_plot)

## Prediction with derivatives

In [27]:
class MetricsSummary(TypedDict):
    """
    One row of the derivatives summary.

    Attributes:
        event (str): Name of the event.
        date (str): Date of the event.
        station (str): Name of the station.
        metric (str): Name of the metric the row refers to.
        index (str): Index label selected for that metric, stored as a string.
    """

    event: str
    date: str
    station: str
    metric: str
    index: str

In [18]:
# Unique station names per date, taken from the CSV files of the date folder
# (upper-cased part of the file name before the first underscore).
plot_stations: dict[str, list[str]] = {
    date: list(
        {
            metrics_file.name.strip().split("_", 1)[0].upper()
            for metrics_file in Path(f"./data/{event_replace}/{date}").glob("*.csv")
        }
    )
    for date in datetimes
}

In [77]:
def valid_interval(
    event: Events,
    date: str,
    station: str,
    data: pd.DataFrame = None,
) -> pd.DataFrame:
    """
    Keep the rows of a station's metrics that are usable up to the event end.

    A row is kept when its "window_shape" equals WINDOW_SIZE and its index is
    not later than the upper bound of the event (the station's own bound when
    it has one, otherwise the date-level bound).

    Args:
        event: Event name, forwarded to `read_metrics_file`.
        date: Event date; key of `datetimes`.
        station: Station name.
        data: Metrics indexed by datetime. When None, the metrics file of the
            station is read and indexed by its "datetime" column.

    Returns:
        pd.DataFrame: The filtered rows of `data`.
    """
    if data is None:
        data = read_metrics_file(
            event,
            date,
            station,
            WINDOW_SIZE,
            datetime_cols={"datetime": None},
            suffix=f"-ewm_alpha_{EWM_ALPHA}" if EWM_ALPHA and EWM else "",
        ).set_index("datetime")

    # Station-specific bounds take precedence over the date-level ones.
    event_info = datetimes[date]
    station_bounds = event_info["stations"].get(station)
    max_datetime = (station_bounds or event_info["bounds"])[1]

    has_full_window = data["window_shape"] == WINDOW_SIZE
    not_after_event = data.index <= max_datetime
    return data[has_full_window & not_after_event]


def process_derivatives(
    event: Events, date: str, station: str, percentil: float = 0.95
) -> list[MetricsSummary]:
    """
    Find, for each metric of a station, the index where its derivative peaks.

    The metrics file of the station is read, each metric column (plus "value")
    is differenced row by row, and the differences are restricted to the rows
    returned by `valid_interval`. For every column, the rows at or above the
    `percentil` quantile are selected and the index of their maximum is kept.

    Args:
        event: Event name, forwarded to `read_metrics_file`.
        date: Event date; key of `datetimes`.
        station: Station name.
        percentil: Quantile (exclusive range 0-1) used to select the points.

    Returns:
        list[MetricsSummary]: One entry per column that has at least one point
        at or above its quantile.

    Raises:
        ValueError: If `percentil` is not strictly between 0 and 1.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not 0 < percentil < 1.0:
        raise ValueError("Percentil must be between 0.0 and 1.0")

    suffix = f"-ewm_alpha_{EWM_ALPHA}" if EWM_ALPHA and EWM else ""
    data = read_metrics_file(
        event,
        date,
        station,
        WINDOW_SIZE,
        datetime_cols={"datetime": None},
        suffix=suffix,
    ).set_index("datetime")

    metrics_columns = [col for col in data.columns if col in METRICS]
    metrics_columns.append("value")

    valid_indexes = valid_interval(event, date, station, data).index
    # Differences are computed on the full series and filtered afterwards, so
    # the first valid row is still differenced against its real predecessor.
    diff = data[metrics_columns].diff()
    interest_df = diff[diff.index.isin(valid_indexes)]
    quantiles = interest_df.quantile(percentil)

    results: list[MetricsSummary] = []
    for col in metrics_columns:
        points = interest_df.loc[interest_df[col] >= quantiles[col], col]
        if points.empty:
            # No row reaches the quantile (e.g. the column is all NaN);
            # idxmax() would raise on an empty selection.
            continue

        interest_index = points.idxmax()  # Maybe this operation can be changed
        results.append(
            {
                "event": event,
                "date": date,
                "station": station,
                "metric": col,
                "index": str(interest_index),
            }
        )

    return results

In [83]:
percentil: float = 0.9

# One task per (date, station): (event, date, station, percentil).
# A plain comprehension also yields an empty task list when there are no
# stations, and EVENT keeps the event name in sync with the configuration.
arguments: list[tuple[Events, str, str, float]] = [
    (EVENT, date, station, percentil)
    for date, date_stations in plot_stations.items()
    for station in date_stations
]

with mp.Pool(processes=mp.cpu_count()) as pool:
    results = pool.starmap(
        process_derivatives,
        arguments,
    )

In [88]:
# Sanity check: (number of result lists, entries in the first list).
# The resulting DataFrame should have 480 rows.
(len(results), len(results[0]))

(32, 15)

In [91]:
# Flatten the per-station result lists and build the frame in one step rather
# than growing it with pd.concat inside a loop. "event" is listed explicitly so
# the column set does not depend on the result rows.
df = pd.DataFrame(
    [row for station_rows in results for row in station_rows],
    columns=["date", "station", "metric", "index", "event"],
)

df["date"] = pd.to_datetime(df["date"])
df["index"] = pd.to_datetime(df["index"])

df.to_csv(f"./data/{event_replace}/summary_derivatives.csv", index=False)

df

Unnamed: 0,date,station,metric,index,event
0,2023-04-23,NEWK,entropy,2023-04-23 01:05:00,Forbush Decrease
1,2023-04-23,NEWK,sampen,2023-04-23 11:44:00,Forbush Decrease
2,2023-04-23,NEWK,permutation_entropy,2023-04-24 02:54:00,Forbush Decrease
3,2023-04-23,NEWK,shannon_entropy,2023-04-23 18:37:00,Forbush Decrease
4,2023-04-23,NEWK,spectral_entropy,2023-04-24 02:59:00,Forbush Decrease
...,...,...,...,...,...
475,2024-05-10,CALG,katz_fd,2024-05-10 02:35:00,Forbush Decrease
476,2024-05-10,CALG,petrosian_fd,2024-05-10 19:45:00,Forbush Decrease
477,2024-05-10,CALG,lepel_ziv,2024-05-10 02:59:00,Forbush Decrease
478,2024-05-10,CALG,corr_dim,2024-05-10 05:47:00,Forbush Decrease
