In [None]:
# Optional dependency installation — uncomment to run. Prefer `%pip` over `!pip` so the
# packages are installed into the environment of the running kernel.
#%pip install "pandas<2.0.0"
#%pip install "pytorch-forecasting[mqf2]<1.0.0"
#%pip install numpy matplotlib pyarrow

In [None]:
import os
import numpy as np
import pandas as pd
import pytorch_lightning as pl

In [None]:
import warnings

warnings.filterwarnings('ignore')

## Setup

In [None]:
# Project root. Override it with the GWL_BASE_PATH environment variable so the notebook
# also runs outside the original author's machine; the old hard-coded path is the fallback.
BASE_PATH = os.environ.get('GWL_BASE_PATH', '/home/carl/projects/gwl_neu')

DATA_PATH = os.path.join(BASE_PATH, 'data')
MODEL_PATH = os.path.join(BASE_PATH, 'models')
RESULT_PATH = os.path.join(BASE_PATH, 'results')

LAG = 52  # weeks of history fed to the encoder (max_encoder_length)
LEAD = 8  # weeks to forecast (max_prediction_length)
# (start, end) of each period; both share the 2012-01-01 boundary.
TRAIN_PERIOD = (pd.Timestamp(1990, 1, 1), pd.Timestamp(2012, 1, 1))
TEST_PERIOD = (pd.Timestamp(2012, 1, 1), pd.Timestamp(2016, 1, 1))

# Weekly, Sunday-anchored calendar spanning train + test, with an integer `time_idx`
# column (0, 1, 2, ...) next to the `time` column.
# The former `closed=None` argument was dropped: it is the default, deprecated since
# pandas 1.4 and removed in pandas 2.0.
TIME_IDX = (
    pd.date_range(TRAIN_PERIOD[0], TEST_PERIOD[1], freq='W-SUN', name='time')
    .to_frame(index=False)
    .rename_axis('time_idx')
    .reset_index()
)

## Data

### Load data

In [None]:
# Static (per proj_id) attributes. Columns 'y'/'x' are dropped before merging
# (NOTE(review): presumably projected coordinates, since 'lat'/'lon' are kept).
static_df = pd.read_feather(os.path.join(DATA_PATH, 'static.feather'))

# Temporal observations enriched with the integer time index and the static attributes,
# plus the day of year encoded as a point on the unit circle (so Dec 31 and Jan 1 are
# neighbours rather than 364 days apart).
df = (
    pd.read_feather(os.path.join(DATA_PATH, 'temporal.feather'))
    .merge(TIME_IDX, on='time', how='left')
    .merge(static_df.drop(columns=['y', 'x']), on='proj_id', how='left')
    .assign(
        day_sin=lambda d: np.sin(2*np.pi / 365. * d['time'].dt.dayofyear).astype(np.float32),
        day_cos=lambda d: np.cos(2*np.pi / 365. * d['time'].dt.dayofyear).astype(np.float32),
    )
)

df

### Train/test split (temporal hold-out)

In [None]:
# Training rows: half-open interval [start, end).
# `Series.between` is inclusive on both ends, and 2012-01-01 (a Sunday, so it lies on the
# weekly grid) is both the end of TRAIN_PERIOD and the start of TEST_PERIOD. With
# `between`, that week therefore ended up in both the training and the test split.
train_df = df[(df['time'] >= TRAIN_PERIOD[0]) & (df['time'] < TRAIN_PERIOD[1])]
train_df

In [None]:
# Test rows: everything inside TEST_PERIOD (inclusive on both ends), restricted to
# proj_ids with at least 104 weekly rows (two years) in that period.
test_df = df[df['time'].between(*TEST_PERIOD)]
test_df = test_df[
    test_df['proj_id'].isin(
        test_df['proj_id'].value_counts().loc[lambda counts: counts >= 104].index
    )
]
test_df

In [None]:
# proj_ids excluded from the test set because they have fewer than 104 weekly rows in
# TEST_PERIOD. Counted on the *unfiltered* test period: the filtered `test_df` only
# contains proj_ids with >= 104 rows, so counting on it always produced an empty list.
short_wells = (
    df.loc[df['time'].between(*TEST_PERIOD), 'proj_id']
    .value_counts()
    .loc[lambda counts: counts < 104]
    .index.tolist()
)
short_wells

### Time Series Data Set

In [None]:
from pytorch_forecasting import TimeSeriesDataSet
# NaNLabelEncoder is used below; it was previously never imported (NameError on a fresh kernel).
from pytorch_forecasting.data.encoders import NaNLabelEncoder

STATIC_CATEGORICALS = ["land_cover", "rock_type", "geochemical_rock_type", "cavity_type", "permeability"]

# Each sample: exactly LAG weeks of encoder history followed by exactly LEAD weeks to
# forecast, per proj_id group.
train_ds = TimeSeriesDataSet(
    train_df,
    group_ids=["proj_id"],
    target="gwl",
    time_idx="time_idx",
    min_encoder_length=LAG,
    max_encoder_length=LAG,
    min_prediction_length=LEAD,
    max_prediction_length=LEAD,
    static_reals=["elevation", "gw_recharge", "percolation", "lat", "lon"],
    static_categoricals=STATIC_CATEGORICALS,
    time_varying_unknown_reals=['gwl'],
    # NOTE(review): "known" reals are fed for the forecast window as well, i.e. observed
    # future weather is assumed available at prediction time — confirm this is intended.
    time_varying_known_reals=['humidity', 'precipitation', 'temperature', 'lai', 'day_sin', 'day_cos'],
    add_target_scales=True,
    allow_missing_timesteps=True,
    # add_nan=True reserves a code for NaN / unknown categories (see NaNLabelEncoder docs);
    # one independent encoder per static categorical column.
    categorical_encoders={col: NaNLabelEncoder(add_nan=True) for col in STATIC_CATEGORICALS},
)

# Cache the dataset definition; create the folder first so save() works on a fresh checkout.
os.makedirs(os.path.join(RESULT_PATH, 'preprocessing'), exist_ok=True)
train_ds.save(os.path.join(RESULT_PATH, 'preprocessing', 'train_ds_nhits.pt'))

In [None]:
from pytorch_forecasting import TimeSeriesDataSet

# Restore the dataset definition saved by the previous cell, so the notebook can resume
# here without rebuilding it from train_df.
train_ds = TimeSeriesDataSet.load(
    os.path.join(RESULT_PATH, 'preprocessing', 'train_ds_nhits.pt')
)

### Data Loader

In [None]:
# Training loader (train=True shuffles and drops the last incomplete batch, per
# pytorch-forecasting docs), large batches, 8 loader worker processes.
train_dataloader = train_ds.to_dataloader(
    train=True,
    batch_size=4096,
    num_workers=8,
)

## Model

In [None]:
from pytorch_forecasting.models.nhits import NHiTS
from pytorch_forecasting.metrics.distributions import MQF2DistributionLoss

# Seed python / numpy / torch (and dataloader workers) so training runs are repeatable.
pl.seed_everything(42, workers=True)

# N-HiTS configured from the dataset, trained with an MQF2 distribution loss over all
# LEAD forecast steps.
model = NHiTS.from_dataset(
    train_ds,
    loss=MQF2DistributionLoss(prediction_length=LEAD),
)

trainer = pl.Trainer(
    max_epochs=10,
    accelerator='gpu',
    devices=1,
    enable_model_summary=True,
)
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
)

# Persist the trained model at the path the checkpoint-loading cell reads
# (MODEL_PATH/nhits.ckpt). Lightning's automatic checkpoints go to its log directory
# (lightning_logs/...), so without this the loaded checkpoint would not be this run.
os.makedirs(MODEL_PATH, exist_ok=True)
trainer.save_checkpoint(os.path.join(MODEL_PATH, 'nhits.ckpt'))

In [None]:
from pytorch_forecasting.models.nhits import NHiTS
from pytorch_forecasting.metrics.distributions import MQF2DistributionLoss  # NOTE(review): unused here; possibly kept so the checkpoint's loss class is importable — confirm

# Restore trained N-HiTS weights from MODEL_PATH.
MODEL_NAME = 'nhits.ckpt'
model = NHiTS.load_from_checkpoint(
    os.path.join(MODEL_PATH, MODEL_NAME)
)

### Evaluation

#### Predict test data

In [None]:
# Build the test dataset from the training dataset's definition (from_dataset), so the
# same encoders and scaling parameters are applied to test_df.
test_ds = TimeSeriesDataSet.from_dataset(train_ds, test_df)
test_dataloader = test_ds.to_dataloader(train=False, batch_size=4096, num_workers=2)

# mode="quantiles": one value per quantile for every forecast step
# (NOTE(review): presumably shaped (n_samples, LEAD, n_quantiles) — confirm).
# return_index=True also returns a DataFrame identifying each sample.
raw_predictions, index = model.predict(test_dataloader, mode="quantiles", return_index=True, show_progress_bar=True)
# .cpu() is a no-op for CPU tensors and makes .numpy() safe if predictions are on the GPU.
q_predictions = raw_predictions.cpu().numpy()

# Cache predictions + index; create the folder so the writes work on a fresh checkout.
os.makedirs(os.path.join(RESULT_PATH, 'predictions'), exist_ok=True)
np.save(os.path.join(RESULT_PATH, 'predictions', 'nhits_raw_predictions.npy'), q_predictions)
index.to_feather(os.path.join(RESULT_PATH, 'predictions', 'nhits_prediction_index.feather'))

In [None]:
# Reload the cached quantile array and its sample index written by the prediction cell.
q_predictions, index = (
    np.load(os.path.join(RESULT_PATH, 'predictions', 'nhits_raw_predictions.npy')),
    pd.read_feather(os.path.join(RESULT_PATH, 'predictions', 'nhits_prediction_index.feather')),
)

In [None]:
from utils import predictions_to_df

# Move the quantile axis to the front: q_by_quantile[i] is the slab for quantile i, in
# the layout predictions_to_df is called with here.
# NOTE(review): assumes q_predictions is (n_samples, LEAD, n_quantiles).
q_by_quantile = np.transpose(q_predictions, (2, 1, 0))

# Position 3 is used as the point forecast ('forecast'); the other positions become
# forecast_qXX columns. The names assume quantiles [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
# (presumably MQF2DistributionLoss's defaults) — confirm if the loss is reconfigured.
predictions_df = predictions_to_df(index, q_by_quantile[3], ['proj_id'], TIME_IDX, LEAD)
for q_idx, q_name in [(0, '02'), (1, '10'), (2, '25'), (4, '75'), (5, '90'), (6, '98')]:
    q_df = predictions_to_df(index, q_by_quantile[q_idx], ['proj_id'], TIME_IDX, LEAD)
    predictions_df[f'forecast_q{q_name}'] = q_df['forecast'].values

# Attach the observed target `gwl`; the left join keeps every forecast row.
predictions_df = (
    predictions_df.reset_index()
    .merge(test_df[['proj_id', 'time', 'gwl']], on=['proj_id', 'time'], how='left')
    .set_index(['proj_id', 'time', 'horizon'])
)
# Saved under the N-HiTS name. Previously written to 'tft_predictions.feather', which
# overwrote the TFT results and never produced the file the load cell reads.
predictions_df.reset_index().to_feather(os.path.join(RESULT_PATH, 'predictions', 'nhits_predictions.feather'))
predictions_df

Or load previously saved predictions:

In [None]:
# Reload the saved prediction table and restore its (proj_id, time, horizon) MultiIndex.
predictions_df = (
    pd.read_feather(os.path.join(RESULT_PATH, 'predictions', 'nhits_predictions.feather'))
    .set_index(['proj_id', 'time', 'horizon'])
)
predictions_df

In [None]:
from utils import plot_predictions

# Inspect one example proj_id at horizon 8 (= LEAD), passing the q10/q90 forecast
# columns as the confidence band.
plot_predictions(
    predictions_df,
    'BB_26471092',
    horizon=8,
    confidence=('forecast_q10', 'forecast_q90'),
)

### Metrics

In [None]:
from utils import get_metrics

# Score only rows with no missing values (every forecast column and the observed `gwl`),
# then cache the resulting metrics table.
metrics_df = get_metrics(predictions_df.dropna())
(
    metrics_df
    .reset_index()
    .to_feather(os.path.join(RESULT_PATH, 'metrics', 'nhits_metrics.feather'))
)
metrics_df

In [None]:
# Reload the cached metrics and restore their (proj_id, horizon) index.
metrics_df = (
    pd.read_feather(os.path.join(RESULT_PATH, 'metrics', 'nhits_metrics.feather'))
    .set_index(['proj_id', 'horizon'])
)
metrics_df