In [None]:
#!pip install "pandas<2.0.0"
#!pip install "pytorch-forecasting[mqf2]<1.0.0"
#!pip install numpy matplotlib pyarrow

In [None]:
import os
import numpy as np
import pandas as pd
import pytorch_lightning as pl

In [None]:
import warnings

warnings.filterwarnings('ignore')

## Setup

In [None]:
# Project root. Set the GWL_BASE_PATH environment variable to run the notebook
# elsewhere instead of editing this hard-coded default.
BASE_PATH = os.environ.get('GWL_BASE_PATH', '/home/carl/projects/gwl_neu')

DATA_PATH = os.path.join(BASE_PATH, 'data')
MODEL_PATH = os.path.join(BASE_PATH, 'models')
RESULT_PATH = os.path.join(BASE_PATH, 'results')

LAG = 52  # weeks; used as the encoder length of the TimeSeriesDataSet
LEAD = 8  # weeks; used as the prediction length of the TimeSeriesDataSet
TRAIN_PERIOD = (pd.Timestamp(1990, 1, 1), pd.Timestamp(2012, 1, 1))
TEST_PERIOD = (pd.Timestamp(2012, 1, 1), pd.Timestamp(2016, 1, 1))

# Weekly (Sunday-anchored) calendar spanning train + test periods, paired with a
# dense integer index that serves as `time_idx`. Both endpoints are included
# (pandas default); the former `closed=None` argument was removed in pandas 2.0.
_weekly_dates = pd.date_range(TRAIN_PERIOD[0], TEST_PERIOD[1], freq='W-SUN', name='time')
TIME_IDX = pd.DataFrame({'time_idx': range(len(_weekly_dates)), 'time': _weekly_dates})

## Data

### Load data

In [None]:
static_df = pd.read_feather(os.path.join(DATA_PATH, 'static.feather'))

# Temporal records enriched with the integer week index and the static well
# attributes. The y/x coordinates are not copied here; they are looked up from
# static_df when neighbours are searched.
df = (
    pd.read_feather(os.path.join(DATA_PATH, 'temporal.feather'))
    .merge(TIME_IDX, on='time', how='left')
    .merge(static_df.drop(columns=['y', 'x']), on='proj_id', how='left')
)

# Day of year encoded as a point on the unit circle, so the end of one year and
# the start of the next are close to each other.
day_angle = 2*np.pi / 365. * df['time'].dt.dayofyear
df['day_sin'] = np.sin(day_angle).astype(np.float32)
df['day_cos'] = np.cos(day_angle).astype(np.float32)

df

In [None]:
N_NEIGHBORS = 6
MAX_DIST = 120.


def merge_random_neighbor_wells(df, n_neighbors, max_dist, n_samplings=1, random_state=None, static_data=None):
    """Attach the features of randomly sampled neighbouring wells to every row of `df`.

    Each well's series is split into "periods" (runs of consecutive `time_idx`).
    For every period, `n_neighbors` other wells are drawn without replacement
    from those that have a period fully covering it and lie closer than
    `max_dist` (Euclidean distance on the static `x`/`y` coordinates); nearer
    wells get larger sampling weights. The chosen ids are stored in
    `neighbor_{i}` and the neighbours' feature columns (taken from `df` at the
    same `time_idx`) are merged in with the suffix `_n{i}`.

    Args:
        df: frame with 'proj_id', 'time_idx' and the neighbour feature columns
            listed in `neighbor_features` below; it is also the source of the
            neighbours' features.
        n_neighbors: number of neighbours drawn per period.
        max_dist: strict upper bound on the x/y distance to a neighbour.
        n_samplings: number of independent draws. The result stacks one merged
            copy of `df` per draw, tagged with `sampling` = 1..n_samplings.
        random_state: optional int seed for reproducible draws; None keeps using
            numpy's global random state (previous behaviour).
        static_data: frame with 'proj_id', 'y', 'x'; defaults to the notebook's
            global `static_df`.

    Returns:
        The stacked frame with a fresh RangeIndex.

    Raises:
        ValueError: if some period has fewer than `n_neighbors` eligible neighbours.
    """
    neighbor_features = [
        'gwl', 'humidity', 'temperature', 'precipitation', 'lai', 'land_cover', 'rock_type', 'geochemical_rock_type',
        'cavity_type', 'permeability', 'elevation', 'gw_recharge', 'percolation',
        'lat', 'lon', 'time_idx', 'proj_id',
    ]
    coords = static_df if static_data is None else static_data
    rng = None if random_state is None else np.random.RandomState(random_state)

    # A new period starts wherever time_idx does not advance by exactly 1.
    # Sort first: the per-well difference is only meaningful in time order.
    keys = df[['proj_id', 'time_idx']].sort_values(['proj_id', 'time_idx'], kind='mergesort')
    keys['gap'] = (keys.groupby('proj_id')['time_idx'].diff().fillna(1.) != 1.).astype(int)
    keys['period'] = keys.groupby('proj_id')['gap'].cumsum()
    periods = (
        keys.groupby(['proj_id', 'period'])['time_idx'].agg(start='min', end='max')
        .reset_index()
        .merge(coords[['proj_id', 'y', 'x']], how='left', on='proj_id')
    )
    neighbor_source = df[neighbor_features]

    samplings = []
    for sampling in range(1, n_samplings + 1):
        expanded = []
        for period in periods.itertuples(index=False):
            dist = np.sqrt((periods['x'] - period.x) ** 2 + (periods['y'] - period.y) ** 2)
            eligible = (
                (periods['start'] <= period.start)
                & (periods['end'] >= period.end)
                & (periods['proj_id'] != period.proj_id)
                & (dist < max_dist)
            )
            candidates = periods[eligible]
            if len(candidates) < n_neighbors:
                raise ValueError(
                    f"well {period.proj_id!r} (time_idx {period.start}-{period.end}) has only "
                    f"{len(candidates)} eligible neighbours within max_dist={max_dist}, "
                    f"but n_neighbors={n_neighbors}"
                )
            # weight = (farthest candidate distance - own distance) + 0.001:
            # nearer wells are favoured, the farthest one stays selectable.
            weights = -dist[eligible]
            weights -= weights.min() - 0.001
            chosen = candidates.sample(n_neighbors, weights=weights, replace=False, random_state=rng)
            expanded.append(pd.DataFrame({
                # int(): bounds may be floats if time_idx was upcast upstream.
                'time_idx': range(int(period.start), int(period.end) + 1),
                'proj_id': period.proj_id,
                'sampling': sampling,
                **{f'neighbor_{i}': pid for i, pid in enumerate(chosen['proj_id'])},
            }))

        merged = df.merge(pd.concat(expanded), on=['proj_id', 'time_idx'], how='left')
        for i in range(n_neighbors):
            merged = merged.merge(
                neighbor_source,
                left_on=['time_idx', f'neighbor_{i}'],
                right_on=['time_idx', 'proj_id'],
                how='left',
                suffixes=('', f'_n{i}'),
            )
        samplings.append(merged)
    return pd.concat(samplings, axis=0).reset_index(drop=True)

### Cross Validation

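Spatio-temporal cross-validation: the wells in `test_wells` are excluded from training entirely, and evaluation uses the later `TEST_PERIOD`, while training uses `TRAIN_PERIOD`. Note that the training cell directly below needs `test_wells`, which is only defined in the test-split cells after it. On a fresh kernel, run those cells first.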

In [None]:
# Spatial hold-out: training rows come from TRAIN_PERIOD only and exclude every
# test well, so those wells' targets are never seen in training. Neighbours are
# therefore also drawn from training wells only.
# NOTE: `test_wells` is created by the test-split cells further down the
# notebook; on a fresh kernel run those first.
if 'test_wells' not in globals():
    raise NameError(
        "test_wells is not defined yet: run the test-well selection cells before building train_df"
    )
train_df = df[df['time'].between(*TRAIN_PERIOD)]
train_df = train_df[~train_df['proj_id'].isin(test_wells['proj_id'])].reset_index(drop=True)
train_df = merge_random_neighbor_wells(train_df, N_NEIGHBORS, MAX_DIST, n_samplings=2)
train_df

In [None]:
# Candidate test wells: those with the maximum number of observations in the
# test period (i.e. complete coverage), then a 5 % sample drawn per 'hyraum'
# group (presumably the hydrogeological region) with a fixed seed.
test_samples = df[df['time'].between(*TEST_PERIOD)].groupby('proj_id')['time'].count()
test_samples = test_samples[test_samples == test_samples.max()]
test_wells = (
    static_df[static_df['proj_id'].isin(test_samples.index)]
    .groupby('hyraum', group_keys=False)
    .apply(lambda wells: wells.sample(frac=0.05, random_state=42))
)

In [None]:
# Fixed list of held-out test wells (overrides the sampled selection above so
# the split is stable across runs).
test_wells =  ['BB_28401185', 'BB_29400520', 'BB_31419861', 'BB_31464654', 'BB_31471808', 'BB_32392310', 
               'BB_33392320', 'BB_33442430', 'BB_33470881', 'BB_34426025', 'BB_34442481', 'BB_34522486', 
               'BB_36422930', 'BB_36441936', 'BB_36441951', 'BB_36441990', 'BB_38441747', 'BE_7214', 
               'BW_101-763-1', 'BW_103-112-9', 'BW_105-065-1', 'BW_105-116-3', 'BW_108-114-3', 'BW_111-509-7', 
               'BW_112-069-6', 'BW_113-115-8', 'BW_122-113-5', 'BW_126-762-3', 'BW_134-770-6', 'BW_137-113-3', 
               'BW_138-771-6', 'BW_146-114-6', 'BW_146-115-8', 'BW_149-020-7', 'BW_156-068-4', 'BW_159-066-7', 
               'BW_160-770-4', 'BW_161-771-0', 'BW_176-772-0', 'BW_177-770-1', 'BW_178-258-5', 'BW_2003-569-2', 
               'BW_227-020-3', 'BW_228-258-4', 'BW_263-259-5', 'BW_274-162-3', 'BW_5008-606-9', 'BY_11002', 
               'BY_13143', 'BY_16278', 'BY_17188', 'BY_2148', 'BY_3129', 'BY_6160', 'BY_9248', 'BY_9282', 
               'HE_10072', 'HE_11747', 'HE_12512', 'HE_12930', 'HE_16458', 'HE_6253', 'HE_6972', 'HE_8496', 
               'NI_100000489', 'NI_100000646', 'NI_200000876', 'NI_200000894', 'NI_40000501', 'NI_400060391', 
               'NI_400080061', 'NI_400081051', 'NI_500000058', 'NI_500000263', 'NI_500000367', 'NI_500000526', 
               'NI_500000594', 'NI_9610477', 'NI_9610849', 'NI_9610883', 'NI_9700010', 'NI_9700085', 'NI_9700178', 
               'NI_9700191', 'NI_9700192', 'NI_9700200', 'NI_9700203', 'NI_9700274', 'NI_9700291', 'NI_9850220', 
               'NI_9850831', 'NI_9852864', 'NW_100135020', 'NW_10203680', 'NW_110060090', 'NW_21180301', 
               'NW_40306021', 'NW_59130453', 'NW_60080280', 'NW_60090315', 'NW_60230113', 'NW_60240222', 
               'NW_60240325', 'NW_60240430', 'NW_70201018', 'NW_80000125', 'NW_80302695', 'NW_91133002', 
               'NW_91141709', 'NW_91168806', 'NW_91168909', 'RP_2375109100', 'RP_2379177700', 'RP_2393163500', 
               'SH_10L03003002', 'SH_10L51049015', 'SH_10L51120002', 'SH_10L54091003', 'SH_10L56031004', 'SH_10L57068003', 
               'SH_10L58026002', 'SH_10L58028005', 'SH_10L58123003', 'SH_10L59035004', 'SH_10L62020008', 'SN_45503444', 
               'SN_46410441', 'SN_46421125', 'SN_47440188', 'SN_47500596', 'SN_4840B5000', 'SN_48431031', 'SN_49486604', 
               'SN_50496167', 'SN_51410936', 'SN_51416002', 'ST_32360068', 'ST_33340002', 'ST_34320014', 'ST_34360055', 
               'ST_38385181', 'ST_41360080', 'ST_42320029', 'ST_42438270', 'ST_43360009', 'ST_43409272', 'ST_44330402', 
               'ST_44380030', 'TH_4531230790', 'TH_4729230702', 'TH_4731230724', 'TH_4734901150', 'TH_5034210608', 
               'TH_5227240535', 'TH_5429240534', 'TH_5430240547', 'TH_5633900114']
# Every id here has the form '<STATE>_<number>' with exactly one underscore. A
# missing comma would silently glue two adjacent ids together (dropping both
# wells from the test set), so reject that early.
assert all(well_id.count('_') == 1 for well_id in test_wells), 'malformed well id in test_wells (missing comma?)'
test_wells = static_df[static_df['proj_id'].isin(test_wells)]

In [None]:
# Neighbours are sampled among *all* wells over the test period; only afterwards
# are the rows restricted to the held-out test wells.
test_df = merge_random_neighbor_wells(
    df[df['time'].between(*TEST_PERIOD)], N_NEIGHBORS, MAX_DIST, n_samplings=9
)
test_df = test_df.loc[test_df['proj_id'].isin(test_wells['proj_id'])].reset_index(drop=True)

test_set_file = os.path.join(RESULT_PATH, 'predictions', 'tft_local_interpolation_test_set.feather')
test_df.to_feather(test_set_file)
test_df

In [None]:
# Reload the cached test set written by the previous cell.
cached_test_set = os.path.join(RESULT_PATH, 'predictions', 'tft_local_interpolation_test_set.feather')
test_df = pd.read_feather(cached_test_set)
test_df

### Time Series Data Set

In [None]:
from pytorch_forecasting import TimeSeriesDataSet

STATIC_REALS = ["elevation", "gw_recharge", "percolation", "lat", "lon"]
STATIC_CATEGORICALS = ["land_cover", "rock_type", "geochemical_rock_type", "cavity_type", "permeability"]
TIME_VARYING_KNOWN_REALS = ['humidity', 'precipitation', 'temperature', 'lai', 'day_sin', 'day_cos']

# Each static categorical of the target well is grouped with the same variable
# of its neighbours ('<var>_n{i}') into a single 'g_<var>' variable group.
categorical_groups = {
    f'g_{var}': [var] + [f'{var}_n{i}' for i in range(N_NEIGHBORS)]
    for var in STATIC_CATEGORICALS
}
neighbor_static_reals = [f'{var}_n{i}' for i in range(N_NEIGHBORS) for var in STATIC_REALS]
# Only the neighbours' gwl ('gwl_n{i}') is an input; the target well's own gwl
# history is not listed (presumably deliberate for interpolation at held-out wells).
neighbor_dynamic_reals = [
    f'{var}_n{i}'
    for i in range(N_NEIGHBORS)
    for var in ['humidity', 'precipitation', 'temperature', 'lai', 'gwl']
]

train_ds = TimeSeriesDataSet(
    train_df,
    time_idx="time_idx",
    target="gwl",
    group_ids=["proj_id", "sampling"],
    min_encoder_length=LAG,
    max_encoder_length=LAG,
    min_prediction_length=LEAD,
    max_prediction_length=LEAD,
    static_reals=neighbor_static_reals + STATIC_REALS,
    static_categoricals=list(categorical_groups),
    time_varying_known_reals=TIME_VARYING_KNOWN_REALS,
    time_varying_unknown_reals=neighbor_dynamic_reals,
    add_target_scales=False,
    allow_missing_timesteps=True,
    variable_groups=categorical_groups,
)

train_ds.save(os.path.join(RESULT_PATH, 'preprocessing', 'train_tft_local_interpolation.pt'))

In [None]:
from pytorch_forecasting import TimeSeriesDataSet

# Reload the preprocessed training dataset saved by the previous cell.
train_ds = TimeSeriesDataSet.load(os.path.join(RESULT_PATH, 'preprocessing', 'train_tft_local_interpolation.pt'))

### Data Loader

In [None]:
# Training batches (train=True) of 2048 windows, loaded by two worker processes.
train_dataloader = train_ds.to_dataloader(
    train=True,
    batch_size=2048,
    num_workers=2,
)

## Model

Train a new model:

In [None]:
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer

SEED = 42
# Seed before the model is built so that weight initialisation and batch order
# are reproducible between runs.
pl.seed_everything(SEED)

model = TemporalFusionTransformer.from_dataset(train_ds)

# No validation dataloader is passed: training simply runs for a fixed 3 epochs.
trainer = pl.Trainer(
    max_epochs=3,
    accelerator='gpu', 
    devices=1,
    enable_model_summary=True,
)
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
)

…or load an existing model checkpoint instead:

In [None]:
from pytorch_forecasting.models.temporal_fusion_transformer import TemporalFusionTransformer


MODEL_NAME = 'tft_local_interpolation.ckpt'

# Alternative to training: restore the trained model from MODEL_PATH.
checkpoint_path = os.path.join(MODEL_PATH, MODEL_NAME)
model = TemporalFusionTransformer.load_from_checkpoint(checkpoint_path)

### Evaluation

#### Predict test data

In [None]:
# Test windows built with the same dataset parameters as train_ds.
test_ds = TimeSeriesDataSet.from_dataset(train_ds, test_df)
test_dataloader = test_ds.to_dataloader(train=False, batch_size=2048, num_workers=2)

raw_predictions, index = model.predict(
    test_dataloader,
    mode="raw",
    return_index=True,
    show_progress_bar=True,
)
q_predictions = raw_predictions['prediction'].numpy()

# Cache the raw quantile output and its index so later cells can run without the model.
prediction_dir = os.path.join(RESULT_PATH, 'predictions')
np.save(os.path.join(prediction_dir, 'tft_local_interpolation_raw_predictions.npy'), q_predictions)
index.to_feather(os.path.join(prediction_dir, 'tft_local_interpolation_prediction_index.feather'))

In [None]:
# Restore the cached raw predictions and their index (raw_predictions itself is not restored).
prediction_dir = os.path.join(RESULT_PATH, 'predictions')
q_predictions = np.load(os.path.join(prediction_dir, 'tft_local_interpolation_raw_predictions.npy'))
index = pd.read_feather(os.path.join(prediction_dir, 'tft_local_interpolation_prediction_index.feather'))

In [None]:
from utils import predictions_to_df


# Move the quantile axis first: slot 3 becomes the 'forecast' column, and the
# other slots become 'forecast_q<XX>' columns below.
quantiles_first = np.transpose(q_predictions, (2, 1, 0))


def _mean_over_samplings(q_idx):
    """Forecast frame for quantile slot `q_idx`, averaged over index level 1.

    Index levels 0, 2 and 3 are kept; level 1 is presumably 'sampling'
    (the neighbour draw) - TODO confirm against predictions_to_df.
    """
    forecast = predictions_to_df(index, quantiles_first[q_idx], ['proj_id', 'sampling'], TIME_IDX, LEAD)
    return forecast.groupby(axis=0, level=[0, 2, 3]).mean()


# Every sampling carries the same observed gwl, so sampling 1 is enough.
observed_gwl = test_df.loc[test_df['sampling'] == 1, ['proj_id', 'time', 'gwl']]
predictions_df = (
    _mean_over_samplings(3)
    .reset_index()
    .merge(observed_gwl, on=['proj_id', 'time'], how='left')
    .set_index(['proj_id', 'time', 'horizon'])
)
for q_idx, q_name in [(0, '02'), (1, '10'), (2, '25'), (4, '75'), (5, '90'), (6, '98')]:
    quantile_df = _mean_over_samplings(q_idx).rename(columns={'forecast': f'forecast_q{q_name}'})
    predictions_df = predictions_df.merge(quantile_df, left_index=True, right_index=True)

predictions_df.reset_index().to_feather(os.path.join(RESULT_PATH, 'predictions', 'tft_local_interpolation_predictions.feather'))
predictions_df

…or load previously saved predictions:

In [None]:
# Reload the aggregated predictions written above and restore their index.
predictions_file = os.path.join(RESULT_PATH, 'predictions', 'tft_local_interpolation_predictions.feather')
predictions_df = pd.read_feather(predictions_file).set_index(['proj_id', 'time', 'horizon'])
predictions_df

In [None]:
from utils import plot_predictions

# Visual check on one held-out well (first id of the test-well list), all horizons.
example_well = 'BB_28401185'
plot_predictions(predictions_df, example_well, horizon='all')

### Metrics

In [None]:
from utils import get_metrics

# Score only rows without missing values (e.g. weeks with no observed gwl).
complete_rows = predictions_df.dropna()
metrics_df = get_metrics(complete_rows)

metrics_file = os.path.join(RESULT_PATH, 'metrics', 'tft_local_interpolation_metrics.feather')
metrics_df.reset_index().to_feather(metrics_file)
metrics_df

In [None]:
# Reload the cached metrics table with its (proj_id, horizon) index.
metrics_file = os.path.join(RESULT_PATH, 'metrics', 'tft_local_interpolation_metrics.feather')
metrics_df = pd.read_feather(metrics_file).set_index(['proj_id', 'horizon'])
metrics_df

### Error Analysis

In [None]:
N_NEIGHBORS = 6

def haversine(lon1, lat1, lon2, lat2):
    """Great-circle distance in kilometres between points given in degrees.

    Accepts scalars or equally long array-likes; always returns a pandas Series
    (length 1 for scalar input) with a default RangeIndex.
    """
    earth_radius_km = 6371
    lam1, phi1, lam2, phi2 = np.radians([lon1, lat1, lon2, lat2])
    half_dphi = (phi2 - phi1) / 2
    half_dlam = (lam2 - lam1) / 2
    a = np.sin(half_dphi)**2 + np.cos(phi1) * np.cos(phi2) * np.sin(half_dlam)**2
    return pd.Series(2 * earth_radius_km * np.arcsin(np.sqrt(a)))

_static_df = static_df.set_index('proj_id')
neighbor_id_cols = [f'neighbor_{n}' for n in range(N_NEIGHBORS)]

stats_records = []
for proj_id, group in test_df.groupby('proj_id'):
    # Mean great-circle distance to the sampled neighbours, over all rows and samplings.
    per_neighbor_km = [
        haversine(group['lon'], group['lat'], group[f'lon_n{i}'], group[f'lat_n{i}']).values
        for i in range(N_NEIGHBORS)
    ]
    # Share of sampled neighbours lying in the same 'hyraum' as the well itself.
    neighbor_hyraums = _static_df.loc[group[neighbor_id_cols].values.flatten(), 'hyraum'].values
    same_hyraum = neighbor_hyraums == _static_df.loc[proj_id, 'hyraum']
    stats_records.append({
        'proj_id': proj_id,
        'mean_neighbor_dist': np.mean(np.concatenate(per_neighbor_km)),
        'hyraum_homogenity': np.sum(same_hyraum) / len(neighbor_hyraums),
    })

test_df_stats = pd.DataFrame.from_records(stats_records)
test_df_stats.to_feather(os.path.join(RESULT_PATH, 'predictions', 'tft_local_interpolation_test_df_stats.feather'))
test_df_stats

### Model Interpretation

In [None]:
# Summarise the model's interpretation outputs over all predicted test windows.
# Needs `raw_predictions` from the prediction cell; the cached-load cell does not restore it.
interpretation = model.interpret_output(
    raw_predictions,
    reduction="sum",
)

In [None]:
import json

# Normalise each importance vector so it sums to 1 and pair it with the model's
# variable names. The dict keys are also the names of the matching model attributes.
variable_importance = {}
for group_name in ('static_variables', 'encoder_variables', 'decoder_variables'):
    importance = interpretation[group_name].numpy()
    variable_importance[group_name] = dict(
        zip(getattr(model, group_name), (importance / np.sum(importance)).tolist())
    )

# RESULT_PATH (the defined constant), and make sure the output directory exists.
interpretation_dir = os.path.join(RESULT_PATH, 'interpretation')
os.makedirs(interpretation_dir, exist_ok=True)
with open(os.path.join(interpretation_dir, 'tft_local_interpolation_variable_importance.json'), 'w') as f:
    json.dump(variable_importance, f)