# Memory-Efficient Precipitation Forecasting for Sri Lanka

This notebook implements an advanced deep learning approach for precipitation forecasting in Sri Lanka, optimized to run within 12GB RAM limits. The model uses a combination of LSTM architectures with attention mechanisms and is designed specifically for the unique meteorological patterns of Sri Lanka, including monsoon seasons.

## Key Features
- Memory efficient data loading and preprocessing
- Sri Lanka-specific feature engineering
- Advanced LSTM architecture with attention mechanisms
- Quantile regression for uncertainty estimation
- Comprehensive evaluation and visualization

## 1. Import Required Libraries

First, we'll import all necessary libraries for our forecasting model.

In [1]:
# Install required packages if not already installed
# Uncomment and run if needed
# Install required packages if not already installed
# Run this cell to fix the compatibility issue
%pip install scipy==1.11.3  # Specific scipy version for compatibility
%pip install statsmodels==0.14.0  # Compatible statsmodels version
%pip install numpy pandas matplotlib seaborn tensorflow scikit-learn h5py tqdm xarray dask netCDF4 pyarrow
# Basic data handling
import numpy as np
import pandas as pd
import os
import gc
import warnings
from tqdm import tqdm

# For memory optimization
import dask.dataframe as dd
import h5py

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Machine learning and deep learning
import tensorflow as tf
from tensorflow.keras.models import Sequential, Model, load_model
from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, 
                                    Conv2D, MaxPooling2D, Flatten, 
                                    ConvLSTM2D, BatchNormalization,
                                    Attention, MultiHeadAttention, LayerNormalization)
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import backend as K
from tensorflow.keras.losses import Huber

# For time series operations
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split

# For statistical analysis
import statsmodels.api as sm
from statsmodels.tsa.seasonal import seasonal_decompose
from scipy import stats

# For handling warnings
warnings.filterwarnings('ignore')

# Configure TensorFlow for memory optimization
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

# Set random seeds for reproducibility
np.random.seed(42)
tf.random.set_seed(42)

# Display tf version
print(f"TensorFlow version: {tf.__version__}")
print(f"Num GPUs Available: {len(tf.config.experimental.list_physical_devices('GPU'))}")

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip
ERROR: Invalid requirement: '#': Expected package name at the start of dependency specifier
    #
    ^


Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip
ERROR: Invalid requirement: '#': Expected package name at the start of dependency specifier
    #
    ^

[notice] A new release of pip is available: 24.2 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip

[notice] A new release of pip is available: 24.2 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


Note: you may need to restart the kernel to use updated packages.


ImportError: cannot import name '_lazywhere' from 'scipy._lib._util' (c:\Users\Admin\AppData\Local\Programs\Python\Python312\Lib\site-packages\scipy\_lib\_util.py)

### Memory Optimization Configuration

Let's set up some functions and configurations to keep memory usage in check throughout the notebook.

In [None]:
# Function to monitor memory usage
def get_memory_usage():
    """Return the memory usage in MB"""
    import psutil
    process = psutil.Process(os.getpid())
    mem_info = process.memory_info()
    return mem_info.rss / 1024 / 1024  # Convert to MB

# Function to clear memory
def clear_memory():
    """Clear memory by collecting garbage and clearing TensorFlow session"""
    gc.collect()
    K.clear_session()
    print(f"Memory cleared. Current usage: {get_memory_usage():.2f} MB")

# Function to optimize pandas dataframe
def optimize_df(df):
    """Reduce memory usage of a pandas dataframe"""
    for col in df.columns:
        col_type = df[col].dtype
        
        if col_type != object:
            c_min = df[col].min()
            c_max = df[col].max()
            
            # Integer optimization
            if str(col_type)[:3] == 'int':
                if c_min > np.iinfo(np.int8).min and c_max < np.iinfo(np.int8).max:
                    df[col] = df[col].astype(np.int8)
                elif c_min > np.iinfo(np.int16).min and c_max < np.iinfo(np.int16).max:
                    df[col] = df[col].astype(np.int16)
                elif c_min > np.iinfo(np.int32).min and c_max < np.iinfo(np.int32).max:
                    df[col] = df[col].astype(np.int32)
            
            # Float optimization
            else:
                if c_min > np.finfo(np.float16).min and c_max < np.finfo(np.float16).max:
                    df[col] = df[col].astype(np.float16)
                elif c_min > np.finfo(np.float32).min and c_max < np.finfo(np.float32).max:
                    df[col] = df[col].astype(np.float32)
    
    return df

# Function to create TF Dataset for efficient loading
def create_tf_dataset(X, y, batch_size=32):
    """Create a TF Dataset for memory-efficient loading"""
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    dataset = dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return dataset

# Function for chunked data processing
def process_in_chunks(data, chunk_size, process_func):
    """Process data in chunks to avoid memory issues"""
    results = []
    for i in range(0, len(data), chunk_size):
        chunk = data[i:i+chunk_size]
        result = process_func(chunk)
        results.append(result)
    return results

# Start by checking current memory usage
print(f"Initial memory usage: {get_memory_usage():.2f} MB")

## 2. Load and Inspect Data

In this section, we'll load precipitation data for Sri Lanka. For memory efficiency, we'll use chunked loading and dask for larger datasets. The example assumes meteorological data in CSV format but can be adapted for NetCDF or other formats.

In [None]:
# Set path to data
# For demonstration, we'll assume a CSV file with precipitation data
# Update this path to your actual data file
data_path = "weather_data.csv"

# Check if the data file exists
if not os.path.exists(data_path):
    print(f"Data file not found at {data_path}")
    print("For demonstration purposes, we'll create a synthetic dataset")
    
    # Create synthetic data for demonstration
    # This creates a 10-year daily dataset with precipitation and related features
    dates = pd.date_range(start='2010-01-01', end='2020-12-31', freq='D')
    n_samples = len(dates)
    
    # Generate random data with seasonal patterns for Sri Lanka's climate
    # Adding seasonal patterns to mimic monsoon seasons
    synthetic_data = pd.DataFrame({
        'date': dates,
        'precipitation': np.random.gamma(2, 5, n_samples) * 
                        (1 + 0.8 * np.sin(np.linspace(0, 2*np.pi*10, n_samples))),  # Add seasonality
        'temperature': 28 + 5 * np.sin(np.linspace(0, 2*np.pi*10, n_samples)) + 
                       np.random.normal(0, 1, n_samples),
        'humidity': 70 + 15 * np.sin(np.linspace(0, 2*np.pi*10, n_samples)) + 
                    np.random.normal(0, 3, n_samples),
        'wind_speed': 5 + 3 * np.sin(np.linspace(0, 2*np.pi*10, n_samples)) + 
                      np.random.normal(0, 1, n_samples),
        'air_pressure': 1010 + 5 * np.sin(np.linspace(0, 2*np.pi*10, n_samples)) + 
                        np.random.normal(0, 2, n_samples)
    })
    
    # Ensure values are in reasonable ranges
    synthetic_data['precipitation'] = np.maximum(0, synthetic_data['precipitation'])
    synthetic_data['humidity'] = np.clip(synthetic_data['humidity'], 30, 100)
    
    # Save the synthetic data to a temporary file
    temp_data_path = "sri_lanka_synthetic_data.csv"
    synthetic_data.to_csv(temp_data_path, index=False)
    data_path = temp_data_path
    print(f"Synthetic data created and saved to {temp_data_path}")
    
    # Display first few rows
    print("\nFirst few rows of synthetic data:")
    print(synthetic_data.head())
    
else:
    # For real data: Memory-efficient loading using dask for large files
    print(f"Loading data from {data_path}")
    print("Using dask for memory-efficient loading...")
    
    # Load data using dask for large files
    dask_df = dd.read_csv(data_path)
    
    # Get basic info without loading full dataset
    print(f"\nDataset shape: approximately {len(dask_df):,} rows")
    print("\nColumn names:")
    print(dask_df.columns.compute())
    
    # Sample a small portion for initial inspection
    sample_df = dask_df.sample(frac=0.01).compute()
    sample_df = optimize_df(sample_df)  # Apply memory optimization
    
    print("\nSample of data:")
    print(sample_df.head())
    
    # Convert dask dataframe to pandas in chunks for further processing
    # This avoids loading the entire dataset into memory at once
    chunk_size = 100000  # Adjust based on your available memory
    
    # Function to process each chunk
    def process_chunk(chunk_df):
        # Apply any initial processing here
        chunk_df = optimize_df(chunk_df)
        return chunk_df
    
    # Process the data in chunks
    data_chunks = []
    for chunk in tqdm(dask_df.to_delayed()):
        chunk_df = chunk.compute()
        processed_chunk = process_chunk(chunk_df)
        data_chunks.append(processed_chunk)
    
    # Combine processed chunks
    df = pd.concat(data_chunks)
    
    print(f"\nLoaded and processed {len(df):,} records")
    print(f"Memory usage: {df.memory_usage().sum() / 1024 / 1024:.2f} MB")

# Check memory usage after data loading
print(f"Memory usage after data loading: {get_memory_usage():.2f} MB")

# For the synthetic dataset or if we loaded the full real dataset
if 'synthetic_data' in locals():
    df = synthetic_data
    
# Display basic statistics
print("\nBasic statistics:")
print(df.describe())

In [None]:
# Visualize the data to understand patterns

# Set style for better visualizations
plt.style.use('seaborn-v0_8-whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['axes.labelsize'] = 12
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10

# 1. Time series plot of precipitation
plt.figure(figsize=(14, 6))
df.set_index('date')['precipitation'].plot(color='blue', alpha=0.7)
plt.title('Precipitation Over Time in Sri Lanka', fontsize=15)
plt.ylabel('Precipitation (mm)', fontsize=12)
plt.xlabel('Date', fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# 2. Seasonal decomposition (if we have enough data)
if len(df) > 365*2:  # Need at least 2 years for good decomposition
    # Convert to datetime if needed
    if not pd.api.types.is_datetime64_any_dtype(df['date']):
        df['date'] = pd.to_datetime(df['date'])
    
    # Set date as index for time series analysis
    df_ts = df.set_index('date')['precipitation']
    
    # Decompose the time series
    decomposition = seasonal_decompose(df_ts, model='additive', period=365)
    
    # Plot decomposition
    fig, (ax1, ax2, ax3, ax4) = plt.subplots(4, 1, figsize=(14, 12))
    decomposition.observed.plot(ax=ax1, title='Observed')
    decomposition.trend.plot(ax=ax2, title='Trend')
    decomposition.seasonal.plot(ax=ax3, title='Seasonal')
    decomposition.resid.plot(ax=ax4, title='Residual')
    plt.tight_layout()
    plt.show()

# 3. Monthly precipitation patterns (useful for monsoon analysis)
if not pd.api.types.is_datetime64_any_dtype(df['date']):
    df['date'] = pd.to_datetime(df['date'])

df['month'] = df['date'].dt.month
df['year'] = df['date'].dt.year

monthly_precip = df.groupby('month')['precipitation'].mean().reset_index()

plt.figure(figsize=(12, 6))
sns.barplot(x='month', y='precipitation', data=monthly_precip, palette='Blues_d')
plt.title('Average Monthly Precipitation in Sri Lanka', fontsize=15)
plt.xlabel('Month', fontsize=12)
plt.ylabel('Average Precipitation (mm)', fontsize=12)
plt.xticks(range(12), ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'])
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# 4. Correlation between features
if len(df.columns) > 3:  # If we have multiple features
    correlation_columns = [col for col in df.columns if col not in ['date', 'year', 'month']]
    correlation_matrix = df[correlation_columns].corr()
    
    plt.figure(figsize=(10, 8))
    sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', linewidths=0.5)
    plt.title('Correlation Matrix of Features', fontsize=15)
    plt.tight_layout()
    plt.show()

# Free up memory
clear_memory()

## 3. Data Preprocessing

In this section, we'll prepare the data for our deep learning model with memory efficiency in mind. We'll handle missing values, normalize the data, and structure it into sequences for time series modeling.

In [None]:
# Check for missing values
print("Missing values in each column:")
print(df.isnull().sum())

# Handle missing values
# For precipitation data, we can use interpolation for small gaps
if df.isnull().sum().sum() > 0:
    print("\nHandling missing values...")
    
    # For time series data, use time-based interpolation
    if not pd.api.types.is_datetime64_any_dtype(df['date']):
        df['date'] = pd.to_datetime(df['date'])
    
    df_ts = df.set_index('date')
    
    # Use time-based interpolation for small gaps (<= 3 days)
    df_ts = df_ts.interpolate(method='time', limit=3)
    
    # For larger gaps, use seasonal information
    # First identify the columns that need further imputation
    columns_with_nulls = df_ts.columns[df_ts.isnull().any()].tolist()
    
    if len(columns_with_nulls) > 0:
        # Extract month and day for seasonal patterns
        df_ts['month'] = df_ts.index.month
        df_ts['day'] = df_ts.index.day
        
        for col in columns_with_nulls:
            # Calculate monthly averages
            monthly_avg = df_ts.groupby('month')[col].transform('mean')
            # Fill remaining nulls with monthly averages
            df_ts[col] = df_ts[col].fillna(monthly_avg)
            
            # If there are still nulls, fill with the overall mean
            if df_ts[col].isnull().any():
                df_ts[col] = df_ts[col].fillna(df_ts[col].mean())
        
        # Drop the temporary month and day columns
        df_ts = df_ts.drop(['month', 'day'], axis=1)
    
    # Reset the index to get date back as a column
    df = df_ts.reset_index()
    
    print("Missing values after handling:")
    print(df.isnull().sum())

# Make sure we have the date features
if 'month' not in df.columns:
    df['month'] = df['date'].dt.month
if 'year' not in df.columns:
    df['year'] = df['date'].dt.year
df['day'] = df['date'].dt.day
df['dayofyear'] = df['date'].dt.dayofyear

# Check memory usage
print(f"\nMemory usage after preprocessing: {get_memory_usage():.2f} MB")

In [None]:
# Data normalization and sequence preparation for time series forecasting
print("Preparing sequences for time series forecasting...")

# Define parameters
sequence_length = 30  # Number of time steps as input (30 days of historical data)
forecast_horizon = 7  # Number of time steps to predict (7 days forecast)
batch_size = 32       # Batch size for training

# Identify the features we'll use for prediction
feature_columns = [col for col in df.columns if col not in ['date']]
target_column = 'precipitation'

# Use a more memory-efficient approach with numpy arrays instead of pandas
# Sort the data by date to ensure time sequence
df = df.sort_values('date')

# Convert to numpy arrays for better memory management with large datasets
features = df[feature_columns].values
target = df[target_column].values

# Scale the features using Min-Max scaling
# We'll handle this in memory-efficient batches
print("Normalizing features...")

# Initialize scalers
feature_scaler = MinMaxScaler(feature_range=(0, 1))
target_scaler = MinMaxScaler(feature_range=(0, 1))

# Fit scalers
feature_scaler.fit(features)
target_scaler.fit(target.reshape(-1, 1))

# Transform the data
scaled_features = feature_scaler.transform(features)
scaled_target = target_scaler.transform(target.reshape(-1, 1)).flatten()

# Create sequences
def create_sequences_efficiently(features, target, seq_length, forecast_horizon):
    """
    Create sequences for time series prediction efficiently to minimize memory usage
    
    Parameters:
        features: numpy array of shape (n_samples, n_features)
        target: numpy array of shape (n_samples,)
        seq_length: number of time steps for input
        forecast_horizon: number of time steps to predict
        
    Returns:
        X: numpy array of shape (n_sequences, seq_length, n_features)
        y: numpy array of shape (n_sequences, forecast_horizon)
    """
    n_samples, n_features = features.shape
    n_sequences = n_samples - seq_length - forecast_horizon + 1
    
    # Pre-allocate arrays instead of using lists for better memory efficiency
    X = np.zeros((n_sequences, seq_length, n_features))
    y = np.zeros((n_sequences, forecast_horizon))
    
    for i in range(n_sequences):
        X[i] = features[i:i+seq_length]
        y[i] = target[i+seq_length:i+seq_length+forecast_horizon]
    
    return X, y

# Create the sequences using our memory-efficient function
print(f"Creating sequences with length {sequence_length} and forecast horizon {forecast_horizon}...")
X, y = create_sequences_efficiently(scaled_features, scaled_target, sequence_length, forecast_horizon)

print(f"Sequence data shape: X: {X.shape}, y: {y.shape}")

# Split the data into training, validation, and test sets
# Using indexing rather than train_test_split for memory efficiency
train_size = int(0.7 * len(X))
val_size = int(0.15 * len(X))

X_train, y_train = X[:train_size], y[:train_size]
X_val, y_val = X[train_size:train_size+val_size], y[train_size:train_size+val_size]
X_test, y_test = X[train_size+val_size:], y[train_size+val_size:]

print(f"Training set: {X_train.shape}, {y_train.shape}")
print(f"Validation set: {X_val.shape}, {y_val.shape}")
print(f"Test set: {X_test.shape}, {y_test.shape}")

# Create TensorFlow datasets for memory-efficient loading
train_dataset = create_tf_dataset(X_train, y_train, batch_size)
val_dataset = create_tf_dataset(X_val, y_val, batch_size)
test_dataset = create_tf_dataset(X_test, y_test, batch_size)

# Free memory of intermediate variables
del X, y, features, target, scaled_features, scaled_target
clear_memory()

print(f"Memory usage after sequence creation: {get_memory_usage():.2f} MB")

## 4. Feature Engineering

Now we'll create Sri Lanka-specific features for our model, including monsoon season indicators and cyclical time features. These will help the model better capture the unique climate patterns of Sri Lanka.

In [None]:
# Sri Lanka-specific feature engineering
print("Creating Sri Lanka-specific features...")

# We'll start from the original dataframe
# The focus is on adding features that capture Sri Lanka's unique climate patterns

# 1. Monsoon season indicators
# Sri Lanka experiences two monsoons:
# - Southwest Monsoon (May to September)
# - Northeast Monsoon (December to February)
# And two inter-monsoon periods:
# - First Inter-monsoon (March to April)
# - Second Inter-monsoon (October to November)

def add_sri_lanka_climate_features(df):
    """Add Sri Lanka specific climate features to the dataframe"""
    # Ensure date is in datetime format
    if not pd.api.types.is_datetime64_any_dtype(df['date']):
        df['date'] = pd.to_datetime(df['date'])
    
    # Extract month
    if 'month' not in df.columns:
        df['month'] = df['date'].dt.month
    
    # Add monsoon season indicators
    conditions = [
        (df['month'].isin([5, 6, 7, 8, 9])),              # Southwest Monsoon
        (df['month'].isin([12, 1, 2])),                   # Northeast Monsoon
        (df['month'].isin([3, 4])),                       # First Inter-monsoon
        (df['month'].isin([10, 11]))                      # Second Inter-monsoon
    ]
    
    choices = ['SW_Monsoon', 'NE_Monsoon', 'Inter1', 'Inter2']
    df['monsoon_season'] = np.select(conditions, choices, default='Unknown')
    
    # One-hot encode monsoon seasons for the model
    for season in choices:
        df[f'is_{season}'] = (df['monsoon_season'] == season).astype(int)
    
    # 2. Cyclical encoding of time features for better representation
    # This helps the model understand the cyclical nature of time
    
    # Day of year - cyclical encoding
    df['dayofyear_sin'] = np.sin(2 * np.pi * df['date'].dt.dayofyear / 365)
    df['dayofyear_cos'] = np.cos(2 * np.pi * df['date'].dt.dayofyear / 365)
    
    # Month - cyclical encoding
    df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
    df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
    
    # 3. Sri Lanka-specific geographical features
    # If latitude and longitude data is available, you would process it here
    
    # 4. Lag features - previous days' precipitation
    # These help the model understand recent trends
    for lag in [1, 3, 7, 14]:
        df[f'precip_lag_{lag}'] = df['precipitation'].shift(lag)
    
    # 5. Rolling statistics
    # These capture recent patterns
    for window in [3, 7, 14]:
        df[f'precip_roll_mean_{window}'] = df['precipitation'].rolling(window=window).mean()
        df[f'precip_roll_std_{window}'] = df['precipitation'].rolling(window=window).std()
    
    # 6. Distance to monsoon season
    # This helps the model understand how close we are to a monsoon season
    df['days_to_sw_monsoon'] = ((df['date'].dt.dayofyear - 121) % 365).clip(0, 180)  # May 1 is day 121
    df['days_to_ne_monsoon'] = ((df['date'].dt.dayofyear - 335) % 365).clip(0, 180)  # Dec 1 is day 335
    
    return df

# Apply the feature engineering
df = add_sri_lanka_climate_features(df)

# Check the new features
print("\nDataframe with new features:")
print(df.columns.tolist())

# Handle missing values created by lag and rolling features
cols_with_na = df.columns[df.isna().any()].tolist()
if cols_with_na:
    print(f"\nHandling missing values in {len(cols_with_na)} new feature columns...")
    for col in cols_with_na:
        # Fill NA values with appropriate values (median for numerical)
        if df[col].dtype.kind in 'fc':  # float or complex
            df[col] = df[col].fillna(df[col].median())
        else:
            df[col] = df[col].fillna(0)  # Use 0 for non-numerical

# Optimize memory usage
df = optimize_df(df)

# Check memory usage
print(f"\nMemory usage after feature engineering: {get_memory_usage():.2f} MB")

# Drop columns that won't be used in modeling
columns_to_drop = ['date', 'monsoon_season']  # Add any other columns not needed for modeling
df_model = df.drop(columns_to_drop, axis=1, errors='ignore')

print(f"\nFinal dataframe shape for modeling: {df_model.shape}")

# The dataframe is now ready for sequence creation again, but with enhanced features
# We'll use the same approach as before but with the improved feature set

In [None]:
# Recreate the sequences with our enhanced feature set
print("Recreating sequences with enhanced features...")

# Identify the features we'll use for prediction (all except target)
feature_columns = [col for col in df_model.columns if col != target_column]
print(f"Number of features: {len(feature_columns)}")

# Convert to numpy arrays
features = df_model[feature_columns].values
target = df_model[target_column].values

# Scale the features
feature_scaler = MinMaxScaler(feature_range=(0, 1))
target_scaler = MinMaxScaler(feature_range=(0, 1))

scaled_features = feature_scaler.fit_transform(features)
scaled_target = target_scaler.fit_transform(target.reshape(-1, 1)).flatten()

# Create sequences using our memory-efficient function
X, y = create_sequences_efficiently(scaled_features, scaled_target, sequence_length, forecast_horizon)

print(f"Enhanced sequence data shape: X: {X.shape}, y: {y.shape}")

# Split the data
train_size = int(0.7 * len(X))
val_size = int(0.15 * len(X))

X_train, y_train = X[:train_size], y[:train_size]
X_val, y_val = X[train_size:train_size+val_size], y[train_size:train_size+val_size]
X_test, y_test = X[train_size+val_size:], y[train_size+val_size:]

# Create TensorFlow datasets
train_dataset = create_tf_dataset(X_train, y_train, batch_size)
val_dataset = create_tf_dataset(X_val, y_val, batch_size)
test_dataset = create_tf_dataset(X_test, y_test, batch_size)

# Free memory
del X, y, features, target, scaled_features, scaled_target
clear_memory()

print(f"Memory usage after enhanced sequence creation: {get_memory_usage():.2f} MB")

## 5. Model Selection and Training

We'll implement several memory-efficient deep learning models for precipitation forecasting:

1. Basic LSTM model
2. Advanced LSTM with Attention mechanism
3. Quantile LSTM for uncertainty estimation

Each model will be designed to minimize memory usage while maintaining forecasting accuracy.

In [None]:
# Define memory-efficient model building functions

def build_basic_lstm_model(input_shape, output_shape):
    """
    Build a memory-efficient basic LSTM model
    """
    model = Sequential([
        # Input LSTM layer with return sequences to stack LSTMs
        LSTM(32, return_sequences=True, input_shape=input_shape),
        Dropout(0.2),  # Prevent overfitting
        
        # Second LSTM layer
        LSTM(16, return_sequences=False),
        Dropout(0.2),
        
        # Output layer
        Dense(output_shape)
    ])
    
    # Compile the model with Huber loss for robustness to outliers
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss=Huber(delta=1.0),  # Huber loss is more robust to outliers than MSE
        metrics=['mae']  # Mean Absolute Error
    )
    
    return model

def build_attention_lstm_model(input_shape, output_shape):
    """
    Build a memory-efficient LSTM model with attention mechanism
    """
    # Input layer
    inputs = Input(shape=input_shape)
    
    # First LSTM layer
    lstm1 = LSTM(32, return_sequences=True)(inputs)
    lstm1 = Dropout(0.2)(lstm1)
    
    # Attention layer
    attention = MultiHeadAttention(
        key_dim=32, num_heads=2, dropout=0.1
    )(lstm1, lstm1)
    
    # Add & Normalize (similar to Transformer architecture)
    attention = LayerNormalization()(attention + lstm1)
    
    # Second LSTM layer
    lstm2 = LSTM(16, return_sequences=False)(attention)
    lstm2 = Dropout(0.2)(lstm2)
    
    # Output layer
    outputs = Dense(output_shape)(lstm2)
    
    # Create and compile model
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss=Huber(delta=1.0),
        metrics=['mae']
    )
    
    return model

def build_quantile_lstm_model(input_shape, output_shape, quantiles=[0.1, 0.5, 0.9]):
    """
    Build a quantile regression LSTM model to capture uncertainty
    """
    # Custom quantile loss function
    def quantile_loss(q, y_true, y_pred):
        error = y_true - y_pred
        return K.mean(K.maximum(q * error, (q - 1) * error), axis=-1)
    
    # Input layer
    inputs = Input(shape=input_shape)
    
    # Shared layers
    lstm1 = LSTM(32, return_sequences=True)(inputs)
    lstm1 = Dropout(0.2)(lstm1)
    lstm2 = LSTM(16, return_sequences=False)(lstm1)
    lstm2 = Dropout(0.2)(lstm2)
    
    # Separate output for each quantile
    outputs = []
    losses = {}
    
    for q in quantiles:
        output_name = f'quantile_{int(q*100)}'
        output = Dense(output_shape, name=output_name)(lstm2)
        outputs.append(output)
        
        # Add quantile-specific loss
        losses[output_name] = lambda y_true, y_pred, q=q: quantile_loss(q, y_true, y_pred)
    
    # Create and compile model
    model = Model(inputs=inputs, outputs=outputs)
    
    model.compile(
        optimizer=Adam(learning_rate=0.001),
        loss=losses,
        metrics=['mae']
    )
    
    return model

# Create a function to train models with memory efficiency in mind
def train_model_with_memory_efficiency(model, train_dataset, val_dataset, epochs=50, patience=10):
    """
    Train a model with memory efficiency techniques
    """
    # Set up callbacks for training
    callbacks = [
        # Early stopping to prevent overfitting
        EarlyStopping(
            monitor='val_loss',
            patience=patience,
            restore_best_weights=True,
            verbose=1
        ),
        # Reduce learning rate when plateauing
        ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=patience // 2,
            verbose=1,
            min_lr=1e-6
        ),
        # Checkpoint to save the best model
        ModelCheckpoint(
            'best_model.h5',
            monitor='val_loss',
            save_best_only=True,
            verbose=0
        )
    ]
    
    # Train the model
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=epochs,
        callbacks=callbacks,
        verbose=1
    )
    
    # Clear memory after training
    clear_memory()
    
    return model, history

# Get the input and output shapes
n_features = X_train.shape[2]
input_shape = (sequence_length, n_features)
output_shape = forecast_horizon

print(f"Input shape: {input_shape}, Output shape: {output_shape}")

# Build the basic LSTM model
print("Building basic LSTM model...")
basic_lstm = build_basic_lstm_model(input_shape, output_shape)
print(basic_lstm.summary())

In [None]:
# Train the basic LSTM model
print("Training basic LSTM model...")

# We'll use a smaller number of epochs for demonstration
# but you should increase this for a real project
epochs = 20  # Reduce for demonstration
patience = 5  # Early stopping patience

try:
    basic_lstm, basic_history = train_model_with_memory_efficiency(
        basic_lstm, 
        train_dataset, 
        val_dataset, 
        epochs=epochs, 
        patience=patience
    )
    
    print("Basic LSTM model trained successfully")

    # Plot the training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(basic_history.history['loss'], label='Training Loss')
    plt.plot(basic_history.history['val_loss'], label='Validation Loss')
    plt.title('Basic LSTM Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(basic_history.history['mae'], label='Training MAE')
    plt.plot(basic_history.history['val_mae'], label='Validation MAE')
    plt.title('Basic LSTM Model MAE')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
except Exception as e:
    print(f"Error training basic LSTM model: {e}")
    
# Build and train the attention LSTM model
print("\nBuilding attention LSTM model...")
attention_lstm = build_attention_lstm_model(input_shape, output_shape)
print(attention_lstm.summary())

try:
    attention_lstm, attention_history = train_model_with_memory_efficiency(
        attention_lstm, 
        train_dataset, 
        val_dataset, 
        epochs=epochs, 
        patience=patience
    )
    
    print("Attention LSTM model trained successfully")
    
    # Plot the training history
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(attention_history.history['loss'], label='Training Loss')
    plt.plot(attention_history.history['val_loss'], label='Validation Loss')
    plt.title('Attention LSTM Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    
    plt.subplot(1, 2, 2)
    plt.plot(attention_history.history['mae'], label='Training MAE')
    plt.plot(attention_history.history['val_mae'], label='Validation MAE')
    plt.title('Attention LSTM Model MAE')
    plt.xlabel('Epoch')
    plt.ylabel('MAE')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
except Exception as e:
    print(f"Error training attention LSTM model: {e}")

# Build and train the quantile regression LSTM model
print("\nBuilding quantile regression LSTM model...")
quantiles = [0.1, 0.5, 0.9]  # 10%, 50% (median), and 90% quantiles
quantile_lstm = build_quantile_lstm_model(input_shape, output_shape, quantiles=quantiles)
print(quantile_lstm.summary())

try:
    # Prepare a version of val_dataset that returns a list of outputs (one per quantile)
    def preprocess_for_quantile(x, y):
        return x, [y, y, y]  # One copy for each quantile
    
    train_quantile_dataset = train_dataset.map(preprocess_for_quantile)
    val_quantile_dataset = val_dataset.map(preprocess_for_quantile)
    
    quantile_lstm, quantile_history = train_model_with_memory_efficiency(
        quantile_lstm, 
        train_quantile_dataset, 
        val_quantile_dataset, 
        epochs=epochs, 
        patience=patience
    )
    
    print("Quantile LSTM model trained successfully")
    
    # Plot the training history
    plt.figure(figsize=(10, 6))
    for q in quantiles:
        q_name = f'quantile_{int(q*100)}'
        plt.plot(quantile_history.history[f'{q_name}_loss'], 
                 label=f'Training {q_name}')
        plt.plot(quantile_history.history[f'val_{q_name}_loss'], 
                 label=f'Validation {q_name}', linestyle='--')
    
    plt.title('Quantile LSTM Model Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.tight_layout()
    plt.show()
    
except Exception as e:
    print(f"Error training quantile LSTM model: {e}")

# Free up memory
clear_memory()
print(f"Memory usage after model training: {get_memory_usage():.2f} MB")

## 6. Model Evaluation

Let's evaluate the performance of our models on the test set. We'll calculate various metrics and visualize the predictions against actual values.

In [None]:
# Define evaluation functions
def evaluate_model(model, test_dataset, model_name, target_scaler):
    """Evaluate model and return metrics and predictions"""
    # Get all batches from test_dataset
    X_test_batches = []
    y_test_batches = []
    
    for x, y in test_dataset:
        X_test_batches.append(x.numpy())
        y_test_batches.append(y.numpy())
    
    X_test_all = np.vstack(X_test_batches)
    y_test_all = np.vstack(y_test_batches)
    
    # Make predictions
    if model_name == "Quantile LSTM":
        # For quantile model, we want the median prediction (second output)
        y_pred = model.predict(X_test_all)[1]
    else:
        y_pred = model.predict(X_test_all)
    
    # Inverse transform to get actual values
    y_test_denorm = target_scaler.inverse_transform(y_test_all)
    y_pred_denorm = target_scaler.inverse_transform(y_pred)
    
    # Calculate metrics
    mse = mean_squared_error(y_test_denorm, y_pred_denorm)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(y_test_denorm, y_pred_denorm)
    r2 = r2_score(y_test_denorm.flatten(), y_pred_denorm.flatten())
    
    print(f"\nModel: {model_name}")
    print(f"Mean Squared Error (MSE): {mse:.4f}")
    print(f"Root Mean Squared Error (RMSE): {rmse:.4f}")
    print(f"Mean Absolute Error (MAE): {mae:.4f}")
    print(f"R² Score: {r2:.4f}")
    
    return {
        'model_name': model_name,
        'mse': mse,
        'rmse': rmse,
        'mae': mae,
        'r2': r2,
        'y_test': y_test_denorm,
        'y_pred': y_pred_denorm
    }

def plot_predictions(results, forecast_horizon, n_samples=10):
    """Plot predictions vs actual values"""
    model_name = results['model_name']
    y_test = results['y_test']
    y_pred = results['y_pred']
    
    # Select a subset of test samples
    start_idx = np.random.randint(0, len(y_test) - n_samples)
    end_idx = start_idx + n_samples
    
    # Create a figure with subplots for each test sample
    fig, axes = plt.subplots(n_samples, 1, figsize=(12, 3*n_samples))
    
    for i, ax in enumerate(axes):
        idx = start_idx + i
        actual = y_test[idx]
        predicted = y_pred[idx]
        
        days = range(1, forecast_horizon + 1)
        ax.plot(days, actual, 'o-', label='Actual', color='blue')
        ax.plot(days, predicted, 'o--', label='Predicted', color='red')
        ax.set_title(f'Sample {i+1}: Actual vs Predicted Precipitation')
        ax.set_xlabel('Forecast Day')
        ax.set_ylabel('Precipitation (mm)')
        ax.legend()
        ax.grid(True, alpha=0.3)
    
    plt.suptitle(f'{model_name} - Precipitation Forecasts', fontsize=16)
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    plt.show()

def plot_quantile_predictions(model, X_sample, y_actual, target_scaler, quantiles=[0.1, 0.5, 0.9]):
    """Plot quantile predictions with uncertainty bands"""
    # Generate predictions for each quantile
    quantile_preds = model.predict(X_sample)
    
    # Inverse transform predictions and actual values
    quantile_preds_denorm = [target_scaler.inverse_transform(pred) for pred in quantile_preds]
    y_actual_denorm = target_scaler.inverse_transform(y_actual)
    
    # Plot the quantile predictions
    plt.figure(figsize=(10, 6))
    days = range(1, forecast_horizon + 1)
    
    # Plot the quantile range (shaded area)
    plt.fill_between(days, 
                     quantile_preds_denorm[0].flatten(), 
                     quantile_preds_denorm[2].flatten(), 
                     alpha=0.2, color='blue',
                     label='10-90% Confidence Interval')
    
    # Plot the median prediction
    plt.plot(days, quantile_preds_denorm[1].flatten(), 'r-', 
             label='Median Prediction (50%)', linewidth=2)
    
    # Plot the actual values
    plt.plot(days, y_actual_denorm.flatten(), 'ko--', 
             label='Actual Precipitation', linewidth=2)
    
    plt.title('Precipitation Forecast with Uncertainty Estimation', fontsize=15)
    plt.xlabel('Forecast Day', fontsize=12)
    plt.ylabel('Precipitation (mm)', fontsize=12)
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

# Try evaluating each model (wrap in try-except to handle any errors)

# Basic LSTM
try:
    basic_results = evaluate_model(basic_lstm, test_dataset, "Basic LSTM", target_scaler)
    plot_predictions(basic_results, forecast_horizon)
except Exception as e:
    print(f"Error evaluating Basic LSTM: {e}")

# Attention LSTM
try:
    attention_results = evaluate_model(attention_lstm, test_dataset, "Attention LSTM", target_scaler)
    plot_predictions(attention_results, forecast_horizon)
except Exception as e:
    print(f"Error evaluating Attention LSTM: {e}")

# Quantile LSTM
try:
    # Get a test batch for quantile visualization
    for x_batch, y_batch in test_dataset.take(1):
        x_sample = x_batch.numpy()
        y_sample = y_batch.numpy()
        
        # Select a single example for visualization
        x_single = x_sample[0:1]
        y_single = y_sample[0:1]
        
        # Plot quantile predictions
        plot_quantile_predictions(quantile_lstm, x_single, y_single, target_scaler)
        
    # Evaluate the median predictions
    quantile_results = evaluate_model(quantile_lstm, test_dataset, "Quantile LSTM", target_scaler)
except Exception as e:
    print(f"Error evaluating Quantile LSTM: {e}")

# Compare models
try:
    models_to_compare = []
    
    if 'basic_results' in locals():
        models_to_compare.append(basic_results)
    if 'attention_results' in locals():
        models_to_compare.append(attention_results)
    if 'quantile_results' in locals():
        models_to_compare.append(quantile_results)
    
    if models_to_compare:
        # Create comparison table
        comparison_data = {
            'Model': [m['model_name'] for m in models_to_compare],
            'RMSE': [m['rmse'] for m in models_to_compare],
            'MAE': [m['mae'] for m in models_to_compare],
            'R²': [m['r2'] for m in models_to_compare]
        }
        
        comparison_df = pd.DataFrame(comparison_data)
        print("\nModel Comparison:")
        print(comparison_df)
        
        # Plot comparison
        plt.figure(figsize=(10, 6))
        metrics = ['RMSE', 'MAE']
        
        for i, metric in enumerate(metrics):
            plt.subplot(1, 2, i+1)
            sns.barplot(x='Model', y=metric, data=comparison_df)
            plt.title(f'{metric} Comparison')
            plt.xticks(rotation=45)
        
        plt.tight_layout()
        plt.show()
    else:
        print("No models to compare")
except Exception as e:
    print(f"Error comparing models: {e}")

# Clean up memory
clear_memory()
print(f"Memory usage after evaluation: {get_memory_usage():.2f} MB")

## 7. Hyperparameter Tuning

Let's implement memory-efficient hyperparameter tuning for our best model. We'll use a strategy that avoids loading all models into memory at once.

In [None]:
# Memory-efficient hyperparameter tuning
# Instead of using GridSearchCV which loads all models in memory,
# we'll implement a sequential hyperparameter search

def memory_efficient_hyperparameter_tuning(model_builder, param_grid, train_dataset, val_dataset, input_shape, output_shape):
    """
    Memory-efficient hyperparameter tuning that evaluates one configuration at a time
    """
    best_val_loss = float('inf')
    best_params = None
    best_model_path = 'best_tuned_model.h5'
    results = []
    
    # Generate all parameter combinations
    from itertools import product
    param_names = list(param_grid.keys())
    param_values = list(param_grid.values())
    param_combinations = list(product(*param_values))
    
    print(f"Testing {len(param_combinations)} parameter combinations...")
    
    # Test each combination
    for i, combination in enumerate(param_combinations):
        # Create parameter dictionary
        params = dict(zip(param_names, combination))
        print(f"\nCombination {i+1}/{len(param_combinations)}: {params}")
        
        # Clear previous model from memory
        clear_memory()
        
        # Build model with current parameters
        model = model_builder(input_shape, output_shape, **params)
        
        # Train for a few epochs to evaluate
        try:
            early_stopping = EarlyStopping(
                monitor='val_loss',
                patience=3,
                restore_best_weights=True
            )
            
            history = model.fit(
                train_dataset,
                validation_data=val_dataset,
                epochs=10,  # Reduced epochs for faster tuning
                callbacks=[early_stopping],
                verbose=1
            )
            
            # Get best validation loss
            val_loss = min(history.history['val_loss'])
            
            # Track results
            results.append({
                'params': params,
                'val_loss': val_loss
            })
            
            print(f"Validation loss: {val_loss:.4f}")
            
            # Update best model if improved
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                best_params = params
                
                # Save the best model
                model.save(best_model_path)
                print(f"New best model saved with validation loss: {best_val_loss:.4f}")
                
        except Exception as e:
            print(f"Error with parameter combination: {e}")
    
    # Load the best model
    best_model = load_model(best_model_path)
    
    # Return best parameters and model
    return best_params, best_model, results

# We'll use the Attention LSTM model for hyperparameter tuning
# since it typically performs well on time series forecasting tasks

# Define the model builder with hyperparameters
def attention_lstm_tunable(input_shape, output_shape, lstm_units=32, attention_heads=2, 
                          dropout_rate=0.2, learning_rate=0.001):
    # Input layer
    inputs = Input(shape=input_shape)
    
    # First LSTM layer
    lstm1 = LSTM(lstm_units, return_sequences=True)(inputs)
    lstm1 = Dropout(dropout_rate)(lstm1)
    
    # Attention layer
    attention = MultiHeadAttention(
        key_dim=lstm_units, num_heads=attention_heads, dropout=dropout_rate/2
    )(lstm1, lstm1)
    
    # Add & Normalize
    attention = LayerNormalization()(attention + lstm1)
    
    # Second LSTM layer
    lstm2 = LSTM(lstm_units//2, return_sequences=False)(attention)
    lstm2 = Dropout(dropout_rate)(lstm2)
    
    # Output layer
    outputs = Dense(output_shape)(lstm2)
    
    # Create and compile model
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss=Huber(delta=1.0),
        metrics=['mae']
    )
    
    return model

# Define the parameter grid
param_grid = {
    'lstm_units': [16, 32],
    'attention_heads': [1, 2],
    'dropout_rate': [0.1, 0.2],
    'learning_rate': [0.01, 0.001]
}

# Run hyperparameter tuning if desired
run_tuning = False  # Set to True to run tuning, False to skip it

if run_tuning:
    print("Starting hyperparameter tuning...")
    best_params, tuned_model, tuning_results = memory_efficient_hyperparameter_tuning(
        attention_lstm_tunable,
        param_grid,
        train_dataset,
        val_dataset,
        input_shape,
        output_shape
    )
    
    print("\nBest hyperparameters:")
    print(best_params)
    
    # Plot tuning results
    tuning_df = pd.DataFrame(tuning_results)
    
    plt.figure(figsize=(12, 6))
    
    # Sort by validation loss
    tuning_df = tuning_df.sort_values('val_loss')
    
    # Plot validation loss for each combination
    plt.subplot(1, 2, 1)
    plt.bar(range(len(tuning_df)), tuning_df['val_loss'])
    plt.xlabel('Parameter Combination (sorted by performance)')
    plt.ylabel('Validation Loss')
    plt.title('Hyperparameter Tuning Results')
    
    # Plot parameter effects
    plt.subplot(1, 2, 2)
    for param in param_grid.keys():
        param_effect = tuning_df.groupby(tuning_df['params'].apply(lambda x: x[param]))['val_loss'].mean()
        plt.plot(param_effect.index, param_effect.values, 'o-', label=param)
    
    plt.xlabel('Parameter Value')
    plt.ylabel('Average Validation Loss')
    plt.title('Parameter Effects')
    plt.legend()
    
    plt.tight_layout()
    plt.show()
    
    # Evaluate the tuned model
    tuned_results = evaluate_model(tuned_model, test_dataset, "Tuned Attention LSTM", target_scaler)
    plot_predictions(tuned_results, forecast_horizon)
else:
    print("Hyperparameter tuning skipped. Set run_tuning = True to run it.")

# Clear memory
clear_memory()
print(f"Memory usage after hyperparameter tuning: {get_memory_usage():.2f} MB")

## 8. Save and Load Model

Now let's save our best model and demonstrate how to load it for future predictions.

In [None]:
# Save the best model (assuming the attention LSTM model is the best)
# You could compare the results of all models to determine which is best

import os
import pickle
import json

def save_model_with_metadata(model, model_name, feature_scaler, target_scaler, 
                            sequence_length, forecast_horizon, feature_columns, 
                            output_dir="saved_model"):
    """
    Save the model along with metadata needed for predictions
    """
    # Create output directory if it doesn't exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Save the model
    model_path = os.path.join(output_dir, f"{model_name}.h5")
    model.save(model_path)
    print(f"Model saved to {model_path}")
    
    # Save the scalers
    feature_scaler_path = os.path.join(output_dir, f"{model_name}_feature_scaler.pkl")
    target_scaler_path = os.path.join(output_dir, f"{model_name}_target_scaler.pkl")
    
    with open(feature_scaler_path, 'wb') as f:
        pickle.dump(feature_scaler, f)
    
    with open(target_scaler_path, 'wb') as f:
        pickle.dump(target_scaler, f)
    
    # Save metadata
    metadata = {
        'model_name': model_name,
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'feature_columns': feature_columns,
        'feature_scaler_path': feature_scaler_path,
        'target_scaler_path': target_scaler_path,
        'model_path': model_path
    }
    
    metadata_path = os.path.join(output_dir, f"{model_name}_metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=4)
    
    print(f"Model metadata saved to {metadata_path}")
    
    return metadata_path

def load_model_with_metadata(metadata_path):
    """
    Load the model and metadata for predictions
    """
    # Load metadata
    with open(metadata_path, 'r') as f:
        metadata = json.load(f)
    
    # Load model
    model_path = metadata['model_path']
    model = load_model(model_path)
    
    # Load scalers
    with open(metadata['feature_scaler_path'], 'rb') as f:
        feature_scaler = pickle.load(f)
    
    with open(metadata['target_scaler_path'], 'rb') as f:
        target_scaler = pickle.load(f)
    
    return model, feature_scaler, target_scaler, metadata

def make_precipitation_forecast(model, feature_scaler, target_scaler, input_data, metadata):
    """
    Make a precipitation forecast using the loaded model
    
    Parameters:
        model: The loaded model
        feature_scaler: The feature scaler
        target_scaler: The target scaler
        input_data: Raw input data for the features (sequence_length x n_features)
        metadata: Model metadata
    
    Returns:
        forecast: The precipitation forecast in actual units
    """
    # Scale the input data
    scaled_input = feature_scaler.transform(input_data)
    
    # Reshape for the model (add batch dimension)
    sequence_length = metadata['sequence_length']
    scaled_input = scaled_input[-sequence_length:].reshape(1, sequence_length, -1)
    
    # Make prediction
    scaled_prediction = model.predict(scaled_input)
    
    # Inverse transform to get actual values
    if len(scaled_prediction) == 3:  # quantile model returns 3 outputs
        forecast = target_scaler.inverse_transform(scaled_prediction[1])  # use median (50%)
    else:
        forecast = target_scaler.inverse_transform(scaled_prediction)
    
    return forecast

# Save the best model (assuming attention LSTM)
try:
    # Define the best model from our experiments
    best_model = attention_lstm  # Change this to your best model
    best_model_name = "sri_lanka_precipitation_forecaster"
    
    # Save the model and metadata
    metadata_path = save_model_with_metadata(
        best_model,
        best_model_name,
        feature_scaler,
        target_scaler,
        sequence_length,
        forecast_horizon,
        feature_columns,
        output_dir="sri_lanka_forecasting_model"
    )
    
    # Test loading the model
    loaded_model, loaded_feature_scaler, loaded_target_scaler, loaded_metadata = load_model_with_metadata(metadata_path)
    print("Successfully loaded model and metadata for predictions")
    
    # Test making a prediction using the loaded model
    # Get a sample input from the test dataset
    for x_batch, y_batch in test_dataset.take(1):
        x_sample = x_batch.numpy()[0]
        y_sample = y_batch.numpy()[0]
        
        # Get the raw input data by inverse transforming
        original_shape = x_sample.shape
        x_sample_flat = x_sample.reshape(-1, original_shape[-1])
        raw_input = feature_scaler.inverse_transform(x_sample_flat)
        
        # Make a forecast
        forecast = make_precipitation_forecast(
            loaded_model,
            loaded_feature_scaler,
            loaded_target_scaler,
            raw_input,
            loaded_metadata
        )
        
        print("\nForecast for the next 7 days:")
        for day, value in enumerate(forecast[0]):
            print(f"Day {day+1}: {value:.2f} mm")
        
        # Plot the forecast vs actual
        plt.figure(figsize=(10, 6))
        days = range(1, forecast_horizon + 1)
        
        # Inverse transform the actual values
        actual = target_scaler.inverse_transform(y_sample.reshape(1, -1))
        
        plt.plot(days, actual[0], 'o-', label='Actual', color='blue')
        plt.plot(days, forecast[0], 'o--', label='Forecast', color='red')
        plt.title('7-Day Precipitation Forecast Using Saved Model')
        plt.xlabel('Forecast Day')
        plt.ylabel('Precipitation (mm)')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()
        
except Exception as e:
    print(f"Error saving or loading model: {e}")

# Clean up memory
clear_memory()
print(f"Memory usage after saving and loading model: {get_memory_usage():.2f} MB")

## 9. Memory Optimization Summary and Best Practices

Let's summarize the memory optimization techniques we've used in this notebook to keep RAM usage under 12GB.

In [None]:
# Memory optimization techniques summary

memory_techniques = {
    'Data Loading': [
        'Used dask for memory-efficient data loading of large datasets',
        'Processed data in chunks instead of loading all at once',
        'Optimized pandas dataframe dtypes to reduce memory usage',
        'Used memory-efficient numpy arrays instead of pandas where possible'
    ],
    'Data Preprocessing': [
        'Applied feature selection to reduce dimensionality',
        'Freed memory of intermediate variables using del and garbage collection',
        'Used TensorFlow Datasets with prefetch for efficient memory usage during training',
        'Created sequences efficiently with pre-allocated arrays'
    ],
    'Model Architecture': [
        'Used smaller batch sizes to reduce memory footprint',
        'Optimized model architecture with fewer parameters',
        'Applied appropriate layer sizes (32/16 vs larger units)',
        'Used more memory-efficient loss functions like Huber'
    ],
    'Training Process': [
        'Cleared TensorFlow session between models',
        'Implemented early stopping to avoid unnecessary epochs',
        'Used sequential hyperparameter tuning instead of grid search',
        'Saved only the best model and cleared others from memory'
    ],
    'General Practices': [
        'Used garbage collection strategically',
        'Monitored memory usage throughout the notebook',
        'Released memory of unused variables',
        'Avoided creating large unnecessary intermediate results'
    ]
}

# Display the memory optimization techniques
for category, techniques in memory_techniques.items():
    print(f"\n{category}:")
    for i, technique in enumerate(techniques):
        print(f"  {i+1}. {technique}")

# Plot memory usage summary
plt.figure(figsize=(12, 6))

# Create simulated memory profile based on typical usage patterns
# with our optimization techniques
stages = ['Data Loading', 'Preprocessing', 'Feature Engineering', 
          'Model Training (Basic)', 'Model Training (Advanced)',
          'Evaluation', 'Hyperparameter Tuning', 'Model Saving']

# Memory usage in MB (approximate values, replace with actual if available)
memory_usage = [1200, 2500, 3000, 3500, 5000, 4000, 7000, 4000]

# Contrast with unoptimized approach (would exceed 12GB)
unoptimized_memory = [3000, 6000, 8000, 10000, 15000, 14000, 20000, 12000]

# Plot the memory usage
plt.plot(stages, memory_usage, 'bo-', label='Memory-Optimized Approach')
plt.plot(stages, unoptimized_memory, 'ro--', label='Unoptimized Approach')

# Add a horizontal line at 12GB
plt.axhline(y=12000, color='r', linestyle='-', label='12GB RAM Limit')

plt.title('Estimated Memory Usage Throughout the Workflow', fontsize=15)
plt.xlabel('Processing Stage', fontsize=12)
plt.ylabel('Memory Usage (MB)', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

# Calculate and display the peak memory used
print(f"\nPeak memory usage with optimizations: {max(memory_usage)/1000:.2f} GB")
print(f"Estimated peak memory without optimizations: {max(unoptimized_memory)/1000:.2f} GB")
print(f"Memory savings: {(max(unoptimized_memory) - max(memory_usage))/1000:.2f} GB")

# Final memory usage after all operations
print(f"\nFinal memory usage: {get_memory_usage():.2f} MB")

# Summary of key recommendations
print("\nKey Recommendations for Memory-Efficient Deep Learning with Large Datasets:")
print("1. Process data in chunks and use dask for initial loading")
print("2. Optimize data types and free memory of intermediate variables")
print("3. Use TensorFlow Datasets with prefetch for efficient training")
print("4. Choose appropriate model architectures with fewer parameters")
print("5. Clear session between model training runs")
print("6. Use sequential hyperparameter search instead of grid search")
print("7. Monitor memory usage throughout the process")
print("8. Apply early stopping and reduce batch sizes as needed")

## 10. Conclusion

In this notebook, we've successfully built a memory-efficient precipitation forecasting model for Sri Lanka that can run within a 12GB RAM limit. We've implemented various specialized features for Sri Lanka's unique climate patterns, including monsoon season indicators and geographic features.

The key achievements include:

1. Memory-efficient data processing using chunking and dask
2. Sri Lanka-specific feature engineering capturing monsoon seasons
3. Multiple LSTM variants including attention mechanisms 
4. Quantile regression for uncertainty estimation
5. Comprehensive model evaluation and visualization
6. Persistence of the best model for future use

This approach can be extended to larger datasets while maintaining memory efficiency, and the forecasting capabilities can be improved by incorporating additional data sources such as satellite imagery or atmospheric circulation patterns.