In [None]:
#!pip install "pandas<2.0.0"
#!pip install "pytorch-forecasting[mqf2]<1.0.0"
#!pip install numpy matplotlib pyarrow
#!pip install dtaidistance

In [None]:
import os
import numpy as np
import pandas as pd
import pytorch_lightning as pl

In [None]:
import warnings

# Globally suppress every warning category for the remainder of the kernel session.
# NOTE(review): this also hides deprecation notices from pandas / pytorch-forecasting;
# consider narrowing the filter to specific categories once the notebook is stable.
warnings.filterwarnings('ignore')

## Setup

In [None]:
# Project root. Can be overridden through the GWL_BASE_PATH environment variable so the
# notebook also runs on machines other than the original author's; the default is unchanged.
BASE_PATH = os.environ.get('GWL_BASE_PATH', '/home/carl/projects/gwl_neu')

DATA_PATH = os.path.join(BASE_PATH, 'data')
MODEL_PATH = os.path.join(BASE_PATH, 'models')
RESULT_PATH = os.path.join(BASE_PATH, 'results')

LAG = 52  # weeks of history fed to the encoder
LEAD = 8  # weeks predicted by the decoder
# (start, end) timestamps; both bounds are later used with the inclusive `Series.between`.
TRAIN_PERIOD = (pd.Timestamp(1990, 1, 1), pd.Timestamp(2012, 1, 1))
TEST_PERIOD = (pd.Timestamp(2012, 1, 1), pd.Timestamp(2016, 1, 1))

# Lookup table mapping every weekly (Sunday-anchored) timestamp of the study period to a
# consecutive integer `time_idx`. The deprecated `closed=None` argument was dropped: it is
# the default (both end points included) and was removed from `pd.date_range` in pandas 2.x.
TIME_IDX = pd.date_range(TRAIN_PERIOD[0], TEST_PERIOD[1], freq='W-SUN', name='time').to_frame().reset_index(drop=True)
TIME_IDX.index.name = 'time_idx'
TIME_IDX = TIME_IDX.reset_index()

## Data

### load data

In [None]:
# Static well attributes (one row per well) and the temporal observations.
static_df = pd.read_feather(os.path.join(DATA_PATH, 'static.feather'))
temporal_path = os.path.join(DATA_PATH, 'temporal.feather')

# Attach the integer time index and the static attributes (without the 'y'/'x' columns)
# to every temporal row; left joins keep all temporal rows.
df = (
    pd.read_feather(temporal_path)
    .merge(TIME_IDX, on='time', how='left')
    .merge(static_df.drop(columns=['y', 'x']), on='proj_id', how='left')
)

# encode day of the year as circular feature
day_angle = 2*np.pi / 365. * df['time'].dt.dayofyear
df['day_sin'] = np.sin(day_angle).astype(np.float32)
df['day_cos'] = np.cos(day_angle).astype(np.float32)

df

In [None]:
def _sample_five_percent(wells):
    """Draw a reproducible 5% sample of the wells in one 'hyraum' group."""
    return wells.sample(frac=0.05, random_state=42)


# Number of observations per well inside the test period ...
in_test_period = df['time'].between(*TEST_PERIOD)
test_samples = df[in_test_period].groupby('proj_id')['time'].count()
# ... keeping only the wells that reach the maximum observed count.
test_samples = test_samples[test_samples == test_samples.max()]

# Stratified sample: 5% of the remaining wells from every 'hyraum' group.
candidate_wells = static_df[static_df['proj_id'].isin(test_samples.index)]
test_wells = candidate_wells.groupby('hyraum', group_keys=False).apply(_sample_five_percent)
test_wells

In [None]:
# Frozen list of held-out test wells (kept explicit so the split does not depend on
# re-running the random sampling).
# Fix: a comma was missing between 'BB_29400520' and 'BB_31419861'; Python silently
# concatenated the two literals into one non-existent id, so both wells were dropped
# from the test set.
# NOTE(review): artefacts produced with the previous list treated these two wells as
# non-test wells; re-create them if an exact split is required.
test_well_ids = ['BB_28401185', 'BB_29400520', 'BB_31419861', 'BB_31464654', 'BB_31471808', 'BB_32392310',
                 'BB_33392320', 'BB_33442430', 'BB_33470881', 'BB_34426025', 'BB_34442481', 'BB_34522486',
                 'BB_36422930', 'BB_36441936', 'BB_36441951', 'BB_36441990', 'BB_38441747', 'BE_7214',
                 'BW_101-763-1', 'BW_103-112-9', 'BW_105-065-1', 'BW_105-116-3', 'BW_108-114-3', 'BW_111-509-7',
                 'BW_112-069-6', 'BW_113-115-8', 'BW_122-113-5', 'BW_126-762-3', 'BW_134-770-6', 'BW_137-113-3',
                 'BW_138-771-6', 'BW_146-114-6', 'BW_146-115-8', 'BW_149-020-7', 'BW_156-068-4', 'BW_159-066-7',
                 'BW_160-770-4', 'BW_161-771-0', 'BW_176-772-0', 'BW_177-770-1', 'BW_178-258-5', 'BW_2003-569-2',
                 'BW_227-020-3', 'BW_228-258-4', 'BW_263-259-5', 'BW_274-162-3', 'BW_5008-606-9', 'BY_11002',
                 'BY_13143', 'BY_16278', 'BY_17188', 'BY_2148', 'BY_3129', 'BY_6160', 'BY_9248', 'BY_9282',
                 'HE_10072', 'HE_11747', 'HE_12512', 'HE_12930', 'HE_16458', 'HE_6253', 'HE_6972', 'HE_8496',
                 'NI_100000489', 'NI_100000646', 'NI_200000876', 'NI_200000894', 'NI_40000501', 'NI_400060391',
                 'NI_400080061', 'NI_400081051', 'NI_500000058', 'NI_500000263', 'NI_500000367', 'NI_500000526',
                 'NI_500000594', 'NI_9610477', 'NI_9610849', 'NI_9610883', 'NI_9700010', 'NI_9700085', 'NI_9700178',
                 'NI_9700191', 'NI_9700192', 'NI_9700200', 'NI_9700203', 'NI_9700274', 'NI_9700291', 'NI_9850220',
                 'NI_9850831', 'NI_9852864', 'NW_100135020', 'NW_10203680', 'NW_110060090', 'NW_21180301',
                 'NW_40306021', 'NW_59130453', 'NW_60080280', 'NW_60090315', 'NW_60230113', 'NW_60240222',
                 'NW_60240325', 'NW_60240430', 'NW_70201018', 'NW_80000125', 'NW_80302695', 'NW_91133002',
                 'NW_91141709', 'NW_91168806', 'NW_91168909', 'RP_2375109100', 'RP_2379177700', 'RP_2393163500',
                 'SH_10L03003002', 'SH_10L51049015', 'SH_10L51120002', 'SH_10L54091003', 'SH_10L56031004', 'SH_10L57068003',
                 'SH_10L58026002', 'SH_10L58028005', 'SH_10L58123003', 'SH_10L59035004', 'SH_10L62020008', 'SN_45503444',
                 'SN_46410441', 'SN_46421125', 'SN_47440188', 'SN_47500596', 'SN_4840B5000', 'SN_48431031', 'SN_49486604',
                 'SN_50496167', 'SN_51410936', 'SN_51416002', 'ST_32360068', 'ST_33340002', 'ST_34320014', 'ST_34360055',
                 'ST_38385181', 'ST_41360080', 'ST_42320029', 'ST_42438270', 'ST_43360009', 'ST_43409272', 'ST_44330402',
                 'ST_44380030', 'TH_4531230790', 'TH_4729230702', 'TH_4731230724', 'TH_4734901150', 'TH_5034210608',
                 'TH_5227240535', 'TH_5429240534', 'TH_5430240547', 'TH_5633900114']
# Static attribute rows of the held-out wells; later cells read test_wells['proj_id'].
test_wells = static_df[static_df['proj_id'].isin(test_well_ids)]

### Reference Wells

1. Keep only the non-test wells whose time series is complete over the whole study period.

In [None]:
# Collect the groundwater-level series of every candidate reference well that covers
# the full study period.
full_ts = []
full_ts_idx = []
# Fix: the expected length previously came from an undefined `DATE_RANGE` (NameError);
# TIME_IDX holds exactly one row per weekly time step of the study period.
min_ts, max_ts, len_ts = df['time_idx'].min(), df['time_idx'].max(), len(TIME_IDX)

# Candidates: wells that have a 'cluster' value and are not part of the test set.
candidates = df[(df['cluster'].notnull()) & (~df['proj_id'].isin(test_wells['proj_id']))]
for proj_id, group in candidates.groupby('proj_id'):
    spans_full_range = group['time_idx'].min() == min_ts and group['time_idx'].max() == max_ts
    # NOTE(review): the `- 1` is kept from the original; presumably the temporal data
    # lacks one step of TIME_IDX - confirm against the data.
    if spans_full_range and len(group) == len_ts-1:
        full_ts.append(group.set_index('time_idx')['gwl'].values)
        full_ts_idx.append(proj_id)
len(full_ts)

2. Cluster the complete time series into 100 groups with K-Medoids (DTW distance); the medoids become the reference wells.

In [None]:
from dtaidistance import clustering, dtw
from dtaidistance.preprocessing import differencing


N_REF_WELLS = 100

# Stack the complete series into one 2-D array and pre-process it with dtaidistance's
# `differencing` before clustering.
series = differencing(np.stack(full_ts), smooth=0.1)

# K-Medoids on pairwise DTW distances.
# NOTE(review): no random seed is set here, so the selected medoids may differ between runs.
model = clustering.KMedoids(
    dtw.distance_matrix_fast,
    dists_options={"window": 52},
    k=N_REF_WELLS,
)
cluster_idx = model.fit(series)

# The keys of the fitted mapping are used as row positions into `full_ts_idx`.
ref_proj_ids = np.array(full_ts_idx)[list(cluster_idx.keys())]
np.sort(ref_proj_ids)

In [None]:
# Frozen ids of the reference wells, grouped by id prefix and kept in sorted order.
ref_proj_ids = [
    # BB
    'BB_33522338', 'BB_34402050', 'BB_34426100', 'BB_36441970', 'BB_39431451',
    # BW
    'BW_100-517-0', 'BW_100-813-7', 'BW_101-713-8', 'BW_101-812-0', 'BW_103-714-0',
    'BW_103-763-0', 'BW_104-112-1', 'BW_107-309-4', 'BW_107-517-2', 'BW_107-666-2',
    'BW_109-812-6', 'BW_110-116-6', 'BW_111-568-6', 'BW_111-813-7', 'BW_115-113-3',
    'BW_115-114-5', 'BW_116-721-2', 'BW_119-765-9', 'BW_119-771-0', 'BW_119-813-3',
    'BW_122-021-6', 'BW_125-257-2', 'BW_131-115-0', 'BW_132-721-5', 'BW_135-064-6',
    'BW_135-769-9', 'BW_139-119-9', 'BW_145-771-8', 'BW_154-772-0', 'BW_156-770-6',
    'BW_158-767-0', 'BW_160-768-0', 'BW_164-772-6', 'BW_170-772-3', 'BW_172-772-2',
    'BW_177-772-5', 'BW_188-258-0', 'BW_193-769-2', 'BW_2010-813-1', 'BW_4-812-8',
    'BW_59-568-8',
    # BY
    'BY_11148', 'BY_83614', 'BY_9182',
    # HE
    'HE_11738', 'HE_12447', 'HE_13622', 'HE_5754', 'HE_5798', 'HE_6336',
    'HE_6615', 'HE_7095', 'HE_7945', 'HE_8106', 'HE_8126', 'HE_8999',
    'HE_9534', 'HE_9595', 'HE_9620', 'HE_9692',
    # NI
    'NI_100000467', 'NI_100000644', 'NI_100000670', 'NI_100000730', 'NI_100000732',
    'NI_100000914', 'NI_200000660', 'NI_200001410', 'NI_40501911', 'NI_40502371',
    'NI_40507101', 'NI_40507140', 'NI_9700168', 'NI_9700201',
    # NW
    'NW_110320037', 'NW_129660334', 'NW_59620286', 'NW_60100205', 'NW_70195213',
    'NW_70195316', 'NW_70276018', 'NW_80100247', 'NW_80301680', 'NW_91122405',
    'NW_91130104', 'NW_91141102', 'NW_91167309', 'NW_91173607',
    # SH / SN / ST
    'SH_10L56010001',
    'SN_49420761', 'SN_49430964', 'SN_53403678', 'SN_54403689',
    'ST_31380006', 'ST_41300022',
]
# Number of reference wells; used to name the `ref_well_<i>` feature columns.
N_REF_WELLS = 100

In [None]:
# Wide table of reference-well groundwater levels: one row per timestamp, one column per
# reference well.
ref_df = (
    df.loc[df['proj_id'].isin(ref_proj_ids)]
    .set_index(['time', 'proj_id'])['gwl']
    .unstack()
)
# Replace the well ids with positional feature names.
ref_df.columns = [f'ref_well_{i}' for i in range(N_REF_WELLS)]

# Broadcast the reference levels onto every row of the long table by timestamp.
# NOTE(review): this reassigns `df`, so re-running the cell merges the columns a second time.
df = df.merge(ref_df, left_on='time', right_index=True, how='left')
df

### Cross Validation

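Spatio-temporal cross-validation: the model is trained on the training period with the reference and test wells removed, and evaluated on the held-out test wells during the test period.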

In [None]:
# Training rows: inside the training period, excluding reference wells and held-out test wells.
in_train_period = df['time'].between(*TRAIN_PERIOD)
is_ref_well = df['proj_id'].isin(ref_proj_ids)
is_test_well = df['proj_id'].isin(test_wells['proj_id'])
train_df = df[in_train_period & ~is_ref_well & ~is_test_well]
train_df

In [None]:
# Test rows: held-out wells, restricted to the test period.
in_test_period = df['time'].between(*TEST_PERIOD)
is_test_well = df['proj_id'].isin(test_wells['proj_id'])
test_df = df[in_test_period & is_test_well]
test_df

### Time Series Data Set

In [None]:
from pytorch_forecasting import TimeSeriesDataSet
# Fix: NaNLabelEncoder was used below without ever being imported (NameError).
from pytorch_forecasting.data import NaNLabelEncoder

STATIC_REALS = ["elevation", "gw_recharge", "percolation", "lat", "lon"]
STATIC_CATEGORICALS = ["land_cover", "rock_type", "geochemical_rock_type", "cavity_type", "permeability"]
TIME_VARYING_KNOWN_REALS = ['humidity', 'precipitation', 'temperature', 'lai', 'day_sin', 'day_cos']

# Fixed-length windows: LAG encoder steps followed by LEAD prediction steps per sample.
# The reference-well levels are the only time-varying *unknown* inputs; the target's own
# history is not listed as a feature.
train_ds = TimeSeriesDataSet(
    train_df,
    group_ids=["proj_id"],
    target="gwl",
    time_idx="time_idx",
    min_encoder_length=LAG,
    max_encoder_length=LAG,
    min_prediction_length=LEAD,
    max_prediction_length=LEAD,
    static_reals=STATIC_REALS,
    static_categoricals=STATIC_CATEGORICALS,
    time_varying_unknown_reals=ref_df.columns.to_list(),
    time_varying_known_reals=TIME_VARYING_KNOWN_REALS,
    add_target_scales=False,
    allow_missing_timesteps=True,
    # One encoder per static categorical, each created with add_nan=True.
    categorical_encoders={
        column: NaNLabelEncoder(add_nan=True) for column in STATIC_CATEGORICALS
    },
)

# Persist the dataset definition so it can be reloaded without rebuilding it.
train_ds.save(os.path.join(RESULT_PATH, 'preprocessing', 'train_tft_global_interpolation.pt'))

In [None]:
from pytorch_forecasting import TimeSeriesDataSet

# Reload the dataset persisted by the dataset-construction cell.
# Fix: the file name now matches the one written by `train_ds.save(...)`
# ('train_tft_global_interpolation.pt'); the previous name 'train_tft_glob_cluster.pt'
# is not written anywhere in this notebook.
train_ds = TimeSeriesDataSet.load(os.path.join(RESULT_PATH, 'preprocessing', 'train_tft_global_interpolation.pt'))

### Data Loader

In [None]:
# Training data loader built from the dataset in training mode (train=True), with 2048
# samples per batch and two worker processes.
train_dataloader = train_ds.to_dataloader(train=True, batch_size=2048, num_workers=2)

## Model

Train a new model …

In [None]:
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer

# Build the network from the dataset definition, using the library's default hyper-parameters.
model = TemporalFusionTransformer.from_dataset(train_ds)

# Short training run on a single GPU; no validation loader is passed.
trainer = pl.Trainer(
    accelerator='gpu',
    devices=1,
    max_epochs=3,
    enable_model_summary=True,
)
trainer.fit(model, train_dataloaders=train_dataloader)

… or load an existing one from a checkpoint.

In [None]:
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer


# Checkpoint file expected inside MODEL_PATH.
MODEL_NAME = 'tft_global_interpolation.ckpt'

checkpoint_path = os.path.join(MODEL_PATH, MODEL_NAME)
model = TemporalFusionTransformer.load_from_checkpoint(checkpoint_path)

### Evaluation

#### predict test data

In [None]:
# Build the test dataset from the training dataset's definition and the held-out test frame.
test_ds = TimeSeriesDataSet.from_dataset(train_ds, test_df)
test_dataloader = test_ds.to_dataloader(train=False, batch_size=2048, num_workers=2)

# Raw network output plus the index frame describing each predicted window.
raw_predictions, index = model.predict(
    test_dataloader,
    mode="raw",
    return_index=True,
    show_progress_bar=True,
)
q_predictions = raw_predictions['prediction'].numpy()

# Persist both artefacts so the evaluation can be repeated without re-running the model.
prediction_dir = os.path.join(RESULT_PATH, 'predictions')
np.save(os.path.join(prediction_dir, 'tft_global_interpolation_raw_predictions.npy'), q_predictions)
index.to_feather(os.path.join(prediction_dir, 'tft_global_interpolation_prediction_index.feather'))

In [None]:
# Restore the raw prediction array and its index frame from disk.
prediction_dir = os.path.join(RESULT_PATH, 'predictions')
q_predictions = np.load(os.path.join(prediction_dir, 'tft_global_interpolation_raw_predictions.npy'))
index = pd.read_feather(os.path.join(prediction_dir, 'tft_global_interpolation_prediction_index.feather'))

In [None]:
from utils import predictions_to_df

# Reverse the axis order once so the first axis selects one prediction slice.
# NOTE(review): assumes q_predictions is (sample, horizon, quantile) with 7 quantiles and
# index 3 being the median - confirm against the model's loss configuration.
by_quantile = np.transpose(q_predictions, (2, 1, 0))

# Slice 3 provides the base 'forecast' column; observed levels are joined in for comparison.
predictions_df = predictions_to_df(index, by_quantile[3], ['proj_id'], TIME_IDX, LEAD)
predictions_df = (
    predictions_df.reset_index()
    .merge(test_df[['proj_id', 'time', 'gwl']], on=['proj_id', 'time'], how='left')
    .set_index(['proj_id', 'time', 'horizon'])
)

# Remaining slices become additional `forecast_q<name>` columns.
for q_idx, q_name in [(0, '02'), (1, '10'), (2, '25'), (4, '75'), (5, '90'), (6, '98')]:
    q_df = predictions_to_df(index, by_quantile[q_idx], ['proj_id'], TIME_IDX, LEAD)
    q_df = q_df.rename(columns={'forecast': f'forecast_q{q_name}'})
    predictions_df = predictions_df.merge(q_df, left_index=True, right_index=True)

predictions_path = os.path.join(RESULT_PATH, 'predictions', 'tft_global_interpolation_predictions.feather')
predictions_df.reset_index().to_feather(predictions_path)
predictions_df

… or restore previously saved predictions.

In [None]:
# Reload the saved prediction table and restore its (proj_id, time, horizon) index.
predictions_path = os.path.join(RESULT_PATH, 'predictions', 'tft_global_interpolation_predictions.feather')
predictions_df = pd.read_feather(predictions_path).set_index(['proj_id', 'time', 'horizon'])
predictions_df

In [None]:
from utils import plot_predictions

# Plot the predictions of one test well ('TH_5633900114') with horizon='all'.
# NOTE(review): the exact rendering is defined in the project-local `utils` module.
plot_predictions(predictions_df, 'TH_5633900114', horizon='all')

### Metrics

In [None]:
from utils import get_metrics

# Compute the metrics on rows without missing values, then persist them.
complete_predictions = predictions_df.dropna()
metrics_df = get_metrics(complete_predictions)

metrics_path = os.path.join(RESULT_PATH, 'metrics', 'tft_global_interpolation_metrics.feather')
metrics_df.reset_index().to_feather(metrics_path)
metrics_df

In [None]:
# Reload the saved metrics and restore the (proj_id, horizon) index.
metrics_path = os.path.join(RESULT_PATH, 'metrics', 'tft_global_interpolation_metrics.feather')
metrics_df = pd.read_feather(metrics_path).set_index(['proj_id', 'horizon'])
metrics_df

### Error Analysis

In [None]:
def haversine(lon1, lat1, lon2, lat2):
    """Great-circle distance in kilometres between two sets of coordinates.

    Parameters
    ----------
    lon1, lat1, lon2, lat2 : float or array-like
        Longitudes and latitudes in decimal degrees. All four arguments must have the
        same shape (all scalars, or equally long sequences) because they are stacked
        into a single array before conversion.

    Returns
    -------
    pd.Series
        Distances in km on a sphere of radius 6371 km, with a fresh integer index
        (labels of Series inputs are not carried over).
    """
    lon1, lat1, lon2, lat2 = np.radians([lon1, lat1, lon2, lat2])
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    haver_formula = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
    # Rounding can push the term marginally outside [0, 1] (e.g. near-antipodal points),
    # which would make arcsin return NaN; clamp it to the valid domain.
    haver_formula = np.clip(haver_formula, 0.0, 1.0)
    r = 6371  # sphere radius in km
    dist = 2 * r * np.arcsin(np.sqrt(haver_formula))
    return pd.Series(dist)

# Coordinates of the reference wells and of the held-out test wells, indexed by well id.
_static_df = static_df.set_index('proj_id').loc[ref_proj_ids, ['lat', 'lon']]
test_coords = test_wells.set_index('proj_id')[['lat', 'lon']]

# For every test well: distance (km) to its nearest reference well.
stat_records = []
for proj_id, row in test_coords.iterrows():
    # Repeat the test-well coordinates on every reference-well row so that all four
    # haversine arguments have the same length.
    _df = _static_df.assign(t_lat=row['lat'], t_lon=row['lon'])
    nearest_km = haversine(_df['t_lon'], _df['t_lat'], _df['lon'], _df['lat']).min()
    stat_records.append({'proj_id': proj_id, 'min_ref_dist': nearest_km})

test_df_stats = pd.DataFrame.from_records(stat_records)
test_df_stats.to_feather(os.path.join(RESULT_PATH, 'predictions', 'tft_global_interpolation_test_df_stats.feather'))
test_df_stats

### Interpretation

In [None]:
# Aggregate the model's interpretation outputs over all raw test predictions (reduction="sum").
# NOTE(review): `raw_predictions` is only created by the cell that runs `model.predict`;
# the cell that restores saved predictions from disk does not define it.
interpretation = model.interpret_output(raw_predictions, reduction="sum")

In [None]:
import json


def _normalized_importance(names, weights):
    """Map each variable name to its share of the total importance weight."""
    values = weights.numpy()
    return dict(zip(names, (values / np.sum(values)).tolist()))


# Relative importance per variable group; the shares within each group sum to 1.
variable_importance = {
    'static_variables': _normalized_importance(model.static_variables, interpretation['static_variables']),
    'encoder_variables': _normalized_importance(model.encoder_variables, interpretation['encoder_variables']),
    'decoder_variables': _normalized_importance(model.decoder_variables, interpretation['decoder_variables']),
}

# Fix: the path previously used the undefined name `RESULTS_PATH` (NameError); the
# notebook's constant is `RESULT_PATH`.
# NOTE(review): the directory name 'interpreation' looks like a typo of 'interpretation';
# it is left unchanged in case the directory already exists on disk under that name.
with open(os.path.join(RESULT_PATH, 'interpreation', 'tft_global_interpolation_variable_importance.json'), 'w') as f:
    json.dump(variable_importance, f)