# TFT Hydrological Forecasting - Exploratory Analysis

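This notebook walks through an end-to-end example of using the Temporal Fusion Transformer (TFT) prediction framework for hydrological forecasting: configuration, data loading, visualization, training, and evaluation.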

In [None]:
# Import required libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path

# Import our framework
from src.config.settings import Settings
from src.data.loaders import prepare_model_data, split_time_series
from src.models.tft_model import TFTModelWrapper
from src.evaluation.metrics import create_metrics_dataframe
from src.utils.helpers import set_random_seed

## Configuration Setup

First, let's set up our configuration for the TFT model:

In [None]:
# Load configuration: use the YAML file when it exists, otherwise the defaults
config_path = Path("configs/single_gauge.yaml")

if not config_path.exists():
    settings = Settings()
    print("Using default configuration")
else:
    settings = Settings.from_yaml(config_path)
    print(f"Loaded configuration from {config_path}")

# Update static parameters based on target
settings.update_static_parameters()

# Seed the random number generators up front so reruns are reproducible
set_random_seed(settings.model.random_state)

# Echo the settings that the remaining cells rely on
summary_lines = (
    f"Target variable: {settings.data.hydro_target}",
    f"Input meteorological variables: {settings.data.meteo_input}",
    f"Model configuration: {settings.model.input_chunk_length} -> {settings.model.output_chunk_length}",
)
print("\n".join(summary_lines))

## Data Loading and Exploration

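Let's load some sample data and explore its characteristics. If the data files are not available, the cell falls back to a synthetic one-year daily series so that the rest of the notebook can still be run: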

In [None]:
# Example gauge IDs (replace with your actual gauge IDs)
sample_gauge_ids = ["gauge_001", "gauge_002"]

# Length of the synthetic fallback record, in daily time steps
N_SYNTHETIC_DAYS = 365

try:
    # Load data for sample gauges
    target_series, covariate_series = prepare_model_data(
        settings,
        area_filter=sample_gauge_ids
    )

    n_gauges = len(target_series) if isinstance(target_series, list) else 1
    print(f"Loaded data for {n_gauges} gauge(s)")

    # The cells below plot, split and forecast ONE series, so when the loader
    # hands back a list keep only the first gauge. An empty list raises
    # IndexError here and is handled by the synthetic fallback below.
    if isinstance(target_series, list):
        target_series = target_series[0]
        if n_gauges > 1:
            print("Continuing with the first gauge only (single-gauge walkthrough)")
    if isinstance(covariate_series, list):
        # NOTE(review): assumes the covariate list is ordered like the target
        # list - confirm against prepare_model_data.
        covariate_series = covariate_series[0] if covariate_series else None

    # Display basic information about the first series
    first_series = target_series
    print("\nFirst series:")
    print(f"Length: {len(first_series)} time steps")
    print(f"Date range: {first_series.start_time()} to {first_series.end_time()}")
    print(f"Variables: {first_series.columns.tolist()}")

    if first_series.static_covariates is not None:
        print(f"Static covariates: {first_series.static_covariates.columns.tolist()}")

except Exception as e:
    # Deliberate best-effort: the walkthrough must stay runnable without the
    # real data files, so any loading problem is reported and replaced by
    # synthetic data instead of aborting the notebook.
    print(f"Error loading data: {e}")
    print("This is expected if you don't have the actual data files.")
    print("The framework is ready to use once you provide the correct data paths.")

    # Create synthetic data for demonstration
    print("\nCreating synthetic data for demonstration...")

    # Only the fallback path builds TimeSeries objects directly
    from darts import TimeSeries

    dates = pd.date_range('2020-01-01', periods=N_SYNTHETIC_DAYS, freq='D')

    # One full sine cycle over the record, shared by discharge and temperature
    seasonal_cycle = np.sin(2 * np.pi * np.arange(N_SYNTHETIC_DAYS) / N_SYNTHETIC_DAYS)

    # Synthetic discharge data with seasonal pattern.
    # The random draws keep their original order so a seeded run yields the
    # same values as before.
    discharge = 10 + 5 * seasonal_cycle + np.random.normal(0, 1, N_SYNTHETIC_DAYS)

    # Synthetic meteorological data (exponential draws are already
    # non-negative, so no clipping at zero is needed)
    precipitation = np.random.exponential(2, N_SYNTHETIC_DAYS)
    temp_max = 15 + 10 * seasonal_cycle + np.random.normal(0, 2, N_SYNTHETIC_DAYS)
    temp_min = temp_max - 5 - np.random.exponential(2, N_SYNTHETIC_DAYS)

    # Create synthetic TimeSeries
    target_series = TimeSeries.from_times_and_values(
        times=dates,
        values=discharge,
        columns=["discharge"]
    )

    covariate_series = TimeSeries.from_times_and_values(
        times=dates,
        values=np.column_stack([precipitation, temp_max, temp_min]),
        columns=["precipitation", "temp_max", "temp_min"]
    )

    print(f"Created synthetic data with {len(target_series)} time steps")

## Data Visualization

Let's visualize the time series data:

In [None]:
# Plot time series data
def _series_to_frame(series):
    """Return a series as a pandas DataFrame.

    Uses `to_dataframe()` when the series object provides it and otherwise
    falls back to `pd_dataframe()`.
    NOTE(review): newer Darts releases are believed to deprecate
    `pd_dataframe()` in favour of `to_dataframe()` - confirm against the
    installed version.
    """
    convert = getattr(series, "to_dataframe", None) or series.pd_dataframe
    return convert()


# The loader may return a list of series (one per gauge); a list has no
# .plot(), so visualise the first gauge in that case.
plot_target = target_series[0] if isinstance(target_series, list) else target_series
plot_covariates = covariate_series
if isinstance(plot_covariates, list):
    plot_covariates = plot_covariates[0] if plot_covariates else None

fig, axes = plt.subplots(2, 1, figsize=(12, 8))

# Plot target variable
plot_target.plot(ax=axes[0])
axes[0].set_title(f"Target Variable: {settings.data.hydro_target}")
axes[0].set_ylabel("Discharge")

# Plot meteorological variables
if plot_covariates is not None:
    plot_covariates.plot(ax=axes[1])
    axes[1].set_title("Meteorological Variables")
    axes[1].set_ylabel("Values")
else:
    # Nothing to draw: hide the panel rather than show an empty frame
    axes[1].set_visible(False)

plt.tight_layout()
plt.show()

# Display basic statistics
print("\nTarget Series Statistics:")
target_df = _series_to_frame(plot_target)
print(target_df.describe())

if plot_covariates is not None:
    print("\nCovariate Series Statistics:")
    cov_df = _series_to_frame(plot_covariates)
    print(cov_df.describe())

## Data Splitting

Split the data into train, validation, and test sets:

In [None]:
# Split the data; the same fractions are applied to targets and covariates
split_kwargs = {
    "train_split": settings.training.train_split,
    "val_split": settings.training.val_split,
}

train_target, val_target, test_target = split_time_series(target_series, **split_kwargs)

# Covariates are optional: the three parts stay None when none were loaded
train_cov = val_cov = test_cov = None
if covariate_series is not None:
    train_cov, val_cov, test_cov = split_time_series(covariate_series, **split_kwargs)

# Each target segment paired with its display label, in the order returned
target_segments = (
    ("Train", train_target),
    ("Validation", val_target),
    ("Test", test_target),
)

for segment_label, segment in target_segments:
    print(f"{segment_label} set length: {len(segment)}")

# Visualize the split: all three segments drawn on one shared axis
fig, ax = plt.subplots(figsize=(12, 6))

for segment_label, segment in target_segments:
    segment.plot(ax=ax, label=segment_label)

ax.set_title("Data Split Visualization")
ax.set_ylabel("Discharge")
ax.legend()
plt.show()

## Model Training

Now let's train a TFT model on the data:

In [None]:
# Shrink the model and the schedule so training finishes quickly in a notebook
settings.model.n_epochs = 10  # Reduced for demonstration
settings.model.batch_size = 8
settings.model.hidden_size = 32

print("Training TFT model...")
print(f"Configuration: {settings.model.input_chunk_length} -> {settings.model.output_chunk_length}")
print(f"Epochs: {settings.model.n_epochs}")

# Build the wrapper from the adjusted settings
model_wrapper = TFTModelWrapper(settings)

# Fit on the training split; the validation split is passed alongside it.
# Failures are reported instead of raised so the notebook keeps running.
try:
    model_wrapper.train(
        target_series=train_target,
        val_series=val_target,
        covariates=train_cov,
        val_covariates=val_cov,
    )

    print("\nTraining completed successfully!")

    # Loss curves recorded by the wrapper during fitting
    history = model_wrapper.get_training_history()

    if history["train"] and history["val"]:
        train_losses = history["train"]
        val_losses = history["val"]

        # Draw both curves on one explicit axes object
        loss_fig, loss_ax = plt.subplots(figsize=(10, 6))
        loss_ax.plot(train_losses, label="Training Loss")
        loss_ax.plot(val_losses, label="Validation Loss")
        loss_ax.set_xlabel("Epoch")
        loss_ax.set_ylabel("Loss")
        loss_ax.set_title("Training History")
        loss_ax.legend()
        plt.show()

        print(f"Final training loss: {train_losses[-1]:.4f}")
        print(f"Final validation loss: {val_losses[-1]:.4f}")

except Exception as e:
    print(f"Training failed: {e}")
    print("This might be due to insufficient data or resource constraints.")

## Model Evaluation

Let's make predictions and evaluate the model:

In [None]:
def _concat_segments(*segments):
    """Join consecutive series segments with `.append`, skipping None/empty ones.

    Returns None when every segment is missing or empty.
    NOTE(review): assumes the segments are contiguous in time, as produced by
    split_time_series - confirm.
    """
    joined = None
    for segment in segments:
        if segment is None or len(segment) == 0:
            continue
        joined = segment if joined is None else joined.append(segment)
    return joined


try:
    # Make predictions on test set
    print("Making predictions...")

    # A forecast starts right after the last observation it is conditioned on.
    # Conditioning on the training split alone would place the forecast in the
    # validation period while it is compared against test_target, so the
    # history must run through the end of the validation split.
    history_target = _concat_segments(train_target, val_target)

    # Covariates are extended through the test period so they cover the whole
    # forecast horizon.
    # NOTE(review): assumes model_wrapper.predict slices the covariates it
    # needs, as Darts models do - confirm against TFTModelWrapper.
    history_cov = _concat_segments(train_cov, val_cov, test_cov)

    predictions = model_wrapper.predict(
        series=history_target,
        covariates=history_cov,
        n=len(test_target)
    )

    print(f"Generated {len(predictions)} predictions")

    # Visualize predictions vs actual
    fig, ax = plt.subplots(figsize=(12, 6))

    test_target.plot(ax=ax, label="Actual", color="blue")
    predictions.plot(ax=ax, label="Predicted", color="red", alpha=0.7)

    ax.set_title("Predictions vs Actual")
    ax.set_ylabel("Discharge")
    ax.legend()
    plt.show()

    # Calculate evaluation metrics
    print("\nCalculating evaluation metrics...")
    metrics_df = model_wrapper.evaluate([predictions], [test_target])

    print("\nEvaluation Results:")
    print("=" * 30)
    # Headline metrics; columns that are absent or NaN are skipped
    for col in ["NSE", "KGE", "RMSE", "correlation"]:
        if col in metrics_df.columns:
            value = metrics_df.iloc[0][col]
            if not pd.isna(value):
                print(f"{col}: {value:.4f}")

    # Show full metrics table
    print("\nFull Metrics Table:")
    print(metrics_df.round(4))

except Exception as e:
    # Best-effort demo cell: report the failure instead of raising
    print(f"Prediction/evaluation failed: {e}")

## Conclusion

This notebook demonstrated the basic workflow for using the TFT predictions framework:

1. **Configuration**: Setting up model and data parameters
2. **Data Loading**: Loading and preparing hydrological time series data
3. **Visualization**: Exploring the data characteristics
4. **Training**: Training the TFT model
5. **Evaluation**: Making predictions and calculating performance metrics

### Next Steps

- **Scale up**: Use the full dataset with more gauges and longer time series
- **Hyperparameter tuning**: Optimize model parameters for better performance
- **Multi-gauge training**: Train on multiple gauges simultaneously
- **Advanced evaluation**: Include confidence intervals and uncertainty quantification

### Scripts Available

For production use, consider using the command-line scripts:

```bash
# Train single gauge model
python scripts/train_single_gauge.py --gauge-id GAUGE_001 --config configs/single_gauge.yaml

# Train multi-gauge model  
python scripts/train_multi_gauge.py --gauge-list gauge_list.txt --config configs/multi_gauge.yaml

# Make predictions
python scripts/predict.py --model-path models/tft_model.pkl --gauge-ids GAUGE_001 GAUGE_002

# Evaluate model
python scripts/evaluate.py --model-path models/tft_model.pkl --create-plots
```