In [1]:
import os
import glob
import pandas as pd
import numpy as np
from typing import Literal
from obspy import Trace
from collections.abc import Callable
from datetime import datetime, timedelta
from scipy.stats import norm
from scipy import stats
from eruption_forecast.tremor import shanon_entropy
from eruption_forecast.utils.array import chunk_daily_data, remove_maximum_outlier, remove_outliers
from eruption_forecast.plots.tremor_plots import plot_tremor
from eruption_forecast.sources import SDS

ModuleNotFoundError: No module named 'hypothesis'

In [None]:
# Analysis period, both ends inclusive; one entry in `dates` per calendar day.
start_date, end_date = "2025-01-05", "2025-01-07"
dates = pd.date_range(start_date, end_date, freq="D")

In [None]:
# Waveform archive reader for station OJN, channel EHZ.
# The archive location can be overridden with the SDS_DIR environment variable
# so the notebook is not tied to one machine; the original local path is kept
# as the default to preserve existing behaviour.
sds = SDS(
    sds_dir=os.environ.get("SDS_DIR", r"D:\Data\OJN"),
    station="OJN",
    channel="EHZ",
    interpolate=False,
    verbose=True,
)

In [None]:
def calculate_window_metrics(
    trace: Trace,
    window_duration_minutes: int = 10,
    metric_function: Callable[[np.ndarray], float] = np.mean,
    remove_outlier_method: Literal["maximum", "all"] | None = None,
    mask_zero_value: bool = False,
    minimum_completion_ratio: float = 0.3,
    absolute_value: bool = False,
    value_multiplier: float = 1.0,
) -> pd.Series:
    """Reduce a trace to one metric value per fixed-length time window.

    The trace samples are split into windows by ``chunk_daily_data`` and
    ``metric_function`` is evaluated on each window. Windows that cannot be
    evaluated yield ``np.nan`` so the output always has one row per window.

    Args:
        trace: Trace providing ``data``, ``stats.starttime`` and
            ``stats.sampling_rate``.
        window_duration_minutes: Window length in minutes; also the spacing
            of the timestamps in the returned index.
        metric_function: Callable reducing one window of samples to a scalar.
        remove_outlier_method: ``"maximum"`` applies ``remove_maximum_outlier``
            to each window, any other non-``None`` value applies
            ``remove_outliers``; ``None`` leaves the window untouched.
        mask_zero_value: Forwarded to ``chunk_daily_data`` (presumably masks
            zero-valued samples -- confirm against that helper).
        minimum_completion_ratio: For masked windows, the minimum fraction of
            unmasked samples required for the window to be evaluated.
        absolute_value: If True, the metric is computed on ``abs(trace.data)``.
        value_multiplier: Scale factor applied to every computed (non-NaN)
            metric value, regardless of which branch produced it.

    Returns:
        Float ``pd.Series`` named ``"entropy"`` indexed by window start time
        (trace start time plus ``index * window_duration_minutes`` minutes).
    """
    start_datetime = trace.stats.starttime.datetime
    sampling_rate = trace.stats.sampling_rate

    data = np.abs(trace.data) if absolute_value else trace.data

    windows = chunk_daily_data(
        data=data,
        window_min=window_duration_minutes,
        sampling_rate=sampling_rate,
        mask_zero_value=mask_zero_value,
    )

    indices = []
    data_points = []
    for index, window_data in enumerate(windows):
        # Default for every window that cannot be evaluated.
        metric_value = np.nan

        # Only masked windows carry completeness information; plain arrays
        # are treated as complete.
        enough_samples = True
        if isinstance(window_data, np.ma.MaskedArray):
            total_samples = len(window_data)
            # An empty window has no valid samples; avoid dividing by zero.
            valid_ratio = window_data.count() / total_samples if total_samples else 0.0
            enough_samples = valid_ratio >= minimum_completion_ratio

        if enough_samples:
            if len(window_data) == 1:
                # A single sample is its own metric value.
                metric_value = window_data[0]
                if metric_value is np.ma.masked:
                    metric_value = np.nan
            else:
                if remove_outlier_method is not None:
                    window_data = (
                        remove_maximum_outlier(window_data)
                        if remove_outlier_method == "maximum"
                        else remove_outliers(window_data)
                    )

                # Outlier removal may shrink the window; skip empty windows
                # instead of handing them to the metric function.
                if len(window_data) > 0:
                    metric_value = metric_function(window_data)

        # Scale uniformly: previously the multiplier was only honoured when
        # an outlier-removal method was selected.
        if value_multiplier != 1.0 and not np.isnan(metric_value):
            metric_value *= value_multiplier

        indices.append(start_datetime + timedelta(minutes=index * window_duration_minutes))
        data_points.append(metric_value)

    return pd.Series(data=data_points, index=indices, name="entropy", dtype=float)

In [None]:
def shanon_entropy_legacy(data: np.ndarray) -> float:
    """Entropy-like measure of a sample window under a fitted normal density.

    A normal distribution is fitted to ``data`` (mean and standard deviation
    of the samples), its pdf ``p`` is evaluated at every sample, and
    ``-sum(p * log2(p))`` is returned.

    Args:
        data: Samples of one window. Masked entries of a ``MaskedArray`` are
            excluded from the calculation.

    Returns:
        The entropy value, or ``np.nan`` when the signal energy (sum of
        squared samples) is below 1 or the samples have no spread.
    """
    # Drop masked samples explicitly: the pdf evaluation converts its input to
    # a plain array, which would otherwise bring masked values back in.
    if isinstance(data, np.ma.MaskedArray):
        data = data.compressed()
    # Float conversion keeps np.square from overflowing on integer counts.
    data = np.asarray(data, dtype=float)

    energy = np.sum(np.square(data))
    if energy < 1.:
        return np.nan

    spread = np.std(data)
    # A zero (or NaN) standard deviation gives no usable density.
    if not spread > 0:
        return np.nan

    density = norm.pdf(data, loc=np.mean(data), scale=spread)

    # Densities that underflow to exactly zero contribute nothing
    # (0 * log2(0) is taken as 0). Mark them NaN and skip them in the sum;
    # a plain np.sum would turn the whole result into NaN.
    density = np.where(density == 0, np.nan, density)

    return float(-np.nansum(density * np.log2(density)))

In [None]:
# Per-day entropy series produced by the most recent call to main().
results = []


def main():
    """Compute the windowed entropy for every day in ``dates`` and save it.

    For each date the trace is fetched from ``sds``, band-pass filtered
    (8-16 Hz, 4 corners), and reduced with ``calculate_window_metrics`` using
    ``shanon_entropy_legacy``. Each daily series is written to
    ``output/entropy/<YYYY-MM-DD>.csv`` under the current working directory
    and appended to the module-level ``results`` list. Dates for which
    ``sds.get_trace`` returns ``None`` are skipped.
    """
    output_dir = os.path.join(os.getcwd(), "output", "entropy")
    # to_csv does not create missing parent directories.
    os.makedirs(output_dir, exist_ok=True)

    # Re-running this function must not accumulate duplicate days.
    results.clear()

    for date in dates:
        trace: Trace = sds.get_trace(date)
        if trace is None:
            continue

        trace = trace.filter(
            "bandpass", freqmin=8., freqmax=16., corners=4
        )

        result = calculate_window_metrics(
            trace=trace,
            metric_function=shanon_entropy_legacy,
            remove_outlier_method="maximum",
            absolute_value=True,
            mask_zero_value=True
        )

        result.to_csv(os.path.join(output_dir, f"{date.strftime('%Y-%m-%d')}.csv"))

        results.append(result)

In [None]:
# Run the daily entropy computation; writes one CSV per processed date.
main()

In [None]:
# glob returns paths in no guaranteed order; sort them so the daily files
# (named YYYY-MM-DD.csv) are processed chronologically.
files = sorted(glob.glob(os.path.join(os.getcwd(), "output", "entropy", "*.csv")))

In [None]:
# One frame per daily CSV; the first column is parsed as the datetime index.
dfs = [
    pd.read_csv(csv_path, parse_dates=True, index_col=0)
    for csv_path in files
]

In [None]:
# Stack the daily frames into one time series.
# `sort=True` in concat only orders the columns, so the rows are put in
# chronological order explicitly with sort_index().
# The rename is chained rather than done in place so re-running the cell
# always starts from the freshly concatenated frame.
concat_df = (
    pd.concat(dfs, ignore_index=False, sort=True)
    .rename(columns={"entropy": "shannon_entropy"})
    .sort_index()
)
concat_df.to_csv(os.path.join(os.getcwd(), "output", "entropy.csv"))

In [None]:
# Preview the first rows of the combined entropy table.
concat_df.head()