# Usage
Ideally, the only things that need to be changed are:
- The start/end date of the data
- The data ingestion portion

Everything else should work without further changes, assuming there are no bugs.

# Data parameters

In [None]:
# Geographic bounding box of the study area.
lat_bottom = 33.9
lat_top = 34.2
lon_bottom = -118.4
lon_top = -118.0
# Ordered (left, right, bottom, top): longitude bounds first, then latitude.
extent = (lon_bottom, lon_top, lat_bottom, lat_top)

# Shape of one model input sample: `frames_per_sample` frames, each a
# `dim` x `dim` grid.
dim = 200
frames_per_sample = 5

# Date range of the data; the strings appear to be "YYYY-MM-DD-HH".
start_date = "2024-12-01-00"
end_date = "2024-12-31-23"

# Data ingestion and preprocessing

In [None]:
# Add the parent directory to sys.path so sibling packages (e.g. `libs`) can be imported.
import sys
sys.path.append("..")

import numpy as np
# split data
def train_test_split(X, train_size=0.75):
    split_idx = int(X.shape[0] * train_size)
    X_train, X_test = X[:split_idx], X[split_idx:]
    
    return X_train, X_test

# scale training data, then scale test data based on training data stats
from sklearn.preprocessing import StandardScaler
def std_scale(X_train, X_test):
    scaler = StandardScaler()
    scaled_train = scaler.fit_transform(X_train.reshape(-1, 1)).reshape(X_train.shape)
    scaled_test = scaler.transform(X_test.reshape(-1, 1)).reshape(X_test.shape)

    return scaled_train, scaled_test

In [None]:
# ingest data here; maiac, airnow, hrrr, etc.
X_1 = ... # replace X_1 with dataset; e.g. X_hrrr = HD.data

# train-test split
X_1_train, X_1_test = train_test_split(X_1, train_size=0.75)
X_airnow_train, X_airnow_test = train_test_split(X_airnow, train_size=0.75)
y_train, y_test = train_test_split(Y, train_size=0.75)

# scale dataset
X_1_train, X_1_test = std_scale(X_1_train, X_1_test)
...

# merge datasets into a 5D tensor
X_train = np.concatenate([...], axis=-1)
X_test = np.concatenate([...], axis=-1)

print(X_train.shape, X_test.shape)
print(y_train.shape, y_test.shape)

# Data visualization

In [None]:
import matplotlib.pyplot as plt

# Construct a figure on which we will visualize the images.
n_channels = X_train.shape[4]
fig, axes = plt.subplots(n_channels, 5, figsize=(10, 12))

# plot channels of a random data sample
np.random.seed(42)
rand_sample = np.random.choice(range(len(X_train)), size=1)[0]
for c in range(n_channels):
    for idx, ax in enumerate(axes[c]):
        ax.imshow(np.squeeze(X_train[rand_sample, idx, :, :, c]))
        ax.set_title(f"Frame {idx + 1}")
        ax.axis("off")

# Print information and display the figure.
print(f"Displaying frames for example {rand_sample}.")
plt.show()

print("Target: ", y_train[rand_sample])

# Model

In [None]:
import tensorflow as tf
import keras
from keras.models import Sequential
from keras.models import Model
from keras.layers import Conv3D
from keras.layers import ConvLSTM2D
from keras.layers import BatchNormalization
from keras.layers import Convolution2D, MaxPooling3D, Flatten, Reshape
from keras.layers import TimeDistributed
from keras.layers import Dropout
from keras.layers import Dense
from keras.layers import InputLayer

tf.keras.backend.set_image_data_format('channels_last')

seq = Sequential()

seq.add(
    InputLayer(shape=(5, 200, 200, 6))
)

seq.add(
    ConvLSTM2D(
            filters=15, 
            kernel_size=(3, 3),
            padding='same', 
            return_sequences=True
    )
)

seq.add(
    ConvLSTM2D(
        filters=30, 
        kernel_size=(3, 3),
        padding='same', 
        return_sequences=True
    )
)

seq.add(
    Conv3D(
        filters=15, 
        kernel_size=(3, 3, 3),
        activation='relu',
        padding='same'    
    )
)

seq.add(
    Conv3D(
        filters=1, 
        kernel_size=(3, 3, 3),
        activation='relu',
        padding='same'
    )
)

seq.add(Flatten())

seq.add(Dense(3,activation='relu'))

seq.compile(loss='mean_absolute_error', optimizer='adam')
seq.summary()

In [None]:
seq.fit(X_train, y_train, batch_size=4, epochs=150)

# Evaluate

In [None]:
y_pred = seq.predict(X_test, verbose=0)

In [None]:
from libs.plotting import (
    plot_prediction_comparison,
    plot_scatter_comparison,
    plot_error_by_sensor,
    plot_time_series_comparison,
    plot_input_frames,
    print_metrics
)

sensor_names = ["North Holywood", "Los Angeles - N. Main Street", "Compton"]

print("\n1. Plotting prediction comparison...")
plot_prediction_comparison(y_pred, y_test, sensor_names, sample_idx=12)

print("\n2. Plotting scatter comparison...")
plot_scatter_comparison(y_pred, y_test)

print("\n3. Plotting error by sensor...")
plot_error_by_sensor(y_pred, y_test, sensor_names)

print("\n4. Plotting time series comparison...")
plot_time_series_comparison(y_pred, y_test, sensor_names)
    
print("\n5. Plotting time series with shifted predictions...")
plot_time_series_comparison(y_pred, y_test, sensor_names, shift_pred=1)

print("\n6. Printing metrics...")
print_metrics(y_pred, y_test, sensor_names)