LSTMIX: Multivariate multistep RNN predictive framework for multidimensional mixing performance metrics.
This Python repository provides a framework for preprocessing and augmenting multivariate, multidimensional mixing performance metrics and statistics, and for using them to train Recurrent Neural Networks (RNNs) in PyTorch. The RNNs can use different cell types (LSTM and GRU) and network architectures (Fully-Connected and Encoder-Decoder). The framework includes scripts for data generation (reading and preprocessing), model training, hyperparameter tuning, prediction generation via sequence rollout, and model uncertainty quantification via ensemble-based perturbations.
The repository is organized into several key components:
data_gen.py: This script generates and preprocesses the datasets for training, validation, and testing. It relies on two additional components:
- input.py: Handles the import, organization, scaling, and smoothing of raw data from DNS simulations stored in CSV files.
- Windowing class from modeltrain_LSTM.py: Augments the data by windowing. It then packages the windows, together with labels and other relevant data, into pickle (pkl) files for later use.
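The sketch below illustrates the kind of scaling, smoothing, and pickle packaging described above. It assumes per-metric min-max scaling to [0, 1] and a simple moving-average smoother. The actual scaler and smoother in input.py may differ. The function names and dictionary keys are hypothetical, and the data is synthetic.

```python
import pickle

import numpy as np


def minmax_scale(data):
    """Scale each column (metric) of a (time, features) array to [0, 1]."""
    col_min = data.min(axis=0)
    col_range = data.max(axis=0) - col_min
    col_range[col_range == 0] = 1.0  # avoid division by zero for constant columns
    return (data - col_min) / col_range, col_min, col_range


def moving_average(data, window=3):
    """Moving average along the time axis; 'valid' mode drops the first window-1 steps."""
    kernel = np.ones(window) / window
    return np.stack(
        [np.convolve(data[:, j], kernel, mode="valid") for j in range(data.shape[1])],
        axis=1,
    )


def package(data, labels, scale_min, scale_range):
    """Bundle the processed array with its labels and scaling parameters as pickle bytes."""
    return pickle.dumps(
        {"data": data, "labels": labels, "min": scale_min, "range": scale_range}
    )


# Synthetic example: 4 time steps of 2 hypothetical metrics.
raw = np.array([[0.0, 10.0], [1.0, 20.0], [2.0, 30.0], [4.0, 40.0]])
scaled, lo, rng = minmax_scale(raw)
smooth = moving_average(scaled, window=2)
blob = package(smooth, ["metric_a", "metric_b"], lo, rng)
```

The sketch stores the scaling parameters alongside the data so that model predictions can later be mapped back to physical units.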
modeltrain_LSTM.py: This script contains the core classes and functionality for model training. It includes:
- Windowing class: Implements data windowing for augmentation.
- RNN abstract classes and cell-specific child classes: Provides implementations for both fully connected and encoder-decoder architectures using either GRU or LSTM cells.
- Model training logic: Runs the main training process and saves the model states, the datasets used for training, and the hyperparameters.
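Windowing augments the data by turning one long trajectory into many overlapping training samples. With stride 1, a series of T steps yields T - (input_len + output_len) + 1 input/target pairs for multistep prediction. The minimal sketch below shows the idea. make_windows and its parameters are hypothetical and do not reproduce the repository's Windowing class API.

```python
import numpy as np


def make_windows(series, input_len, output_len, stride=1):
    """Slice a (time, features) array into overlapping input/target pairs.

    Returns X with shape (n_windows, input_len, features) and
    Y with shape (n_windows, output_len, features).
    """
    total = input_len + output_len
    if series.shape[0] < total:
        raise ValueError("series is shorter than one input+output window")
    starts = range(0, series.shape[0] - total + 1, stride)
    X = np.stack([series[s : s + input_len] for s in starts])
    Y = np.stack([series[s + input_len : s + total] for s in starts])
    return X, Y


# 10 time steps of 2 features; row i is [2i, 2i + 1].
series = np.arange(20, dtype=float).reshape(10, 2)
X, Y = make_windows(series, input_len=4, output_len=2)
```

In this example, 10 time steps produce 5 samples, each pairing 4 input steps with the 2 steps that follow them.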
hyperparam_tuning.py: This script uses Ray Tune to tune the hyperparameters of the RNN architectures defined in modeltrain_LSTM.py. It optimizes model performance by systematically exploring hyperparameter combinations.
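The core idea of systematically exploring hyperparameter combinations is shown below as a plain-Python grid search. This is not Ray Tune. Ray Tune automates the same loop and adds parallel trial execution, pluggable search algorithms, and trial schedulers that can stop unpromising trials early. The search space and the toy objective are hypothetical stand-ins for the repository's actual parameters and for training a model and returning its validation loss.

```python
import itertools

# Hypothetical search space; the repository's parameters and ranges may differ.
SEARCH_SPACE = {
    "hidden_size": [32, 64, 128],
    "num_layers": [1, 2],
    "learning_rate": [1e-3, 1e-4],
}


def grid(space):
    """Yield every combination of the values in the search space as a dict."""
    keys = list(space)
    for values in itertools.product(*(space[k] for k in keys)):
        yield dict(zip(keys, values))


def tune(objective, space):
    """Evaluate objective(config) for every config; return the lowest-scoring one."""
    best_cfg, best_score = None, float("inf")
    for cfg in grid(space):
        score = objective(cfg)
        if score < best_score:
            best_cfg, best_score = cfg, score
    return best_cfg, best_score


def toy_validation_loss(cfg):
    """Stand-in for training a model with cfg and returning its validation loss."""
    return abs(cfg["hidden_size"] - 64) / 64 + 0.01 * cfg["num_layers"] + cfg["learning_rate"]


best_cfg, best_score = tune(toy_validation_loss, SEARCH_SPACE)
```

Exhaustive grids grow multiplicatively with each added parameter. This is one reason tools such as Ray Tune also offer random and adaptive search strategies.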
rollout_prediction.py: This script evaluates the trained model. It can:
- Plot the training and validation datasets.
- Run a rollout to predict values over the test set and compare them against the ground truth.
- Plot evaluation metrics, such as y=x (predicted vs. true) plots and Wasserstein distance and Kullback-Leibler (K-L) divergence plots, among others.
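A rollout predicts autoregressively: each predicted step is appended to the input window and the oldest step is dropped, so errors can accumulate over the horizon. The sketch below illustrates this together with simple versions of the two distribution metrics. It is not the code in rollout_prediction.py. step_fn is a hypothetical stand-in for a trained model that maps one input window to the next time step. The Wasserstein function assumes equal-size 1-D samples; scipy.stats.wasserstein_distance handles the general case.

```python
import numpy as np


def rollout(step_fn, history, n_steps):
    """Autoregressive rollout: feed each prediction back as the newest input.

    step_fn maps an (input_len, features) window to the next (features,) step.
    """
    window = np.array(history, dtype=float)
    preds = []
    for _ in range(n_steps):
        nxt = np.asarray(step_fn(window), dtype=float)
        preds.append(nxt)
        window = np.vstack([window[1:], nxt])  # drop the oldest step, append the prediction
    return np.stack(preds)


def wasserstein_1d(a, b):
    """1-D Wasserstein-1 distance between two equal-size empirical samples."""
    a, b = np.sort(np.asarray(a, dtype=float)), np.sort(np.asarray(b, dtype=float))
    if a.shape != b.shape:
        raise ValueError("this sketch assumes equal sample sizes")
    return float(np.mean(np.abs(a - b)))


def kl_divergence(p, q, eps=1e-12):
    """Discrete K-L divergence D_KL(p || q) between two (unnormalized) histograms."""
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    p, q = p / p.sum(), q / q.sum()
    return float(np.sum(p * np.log((p + eps) / (q + eps))))


def linear_extrapolation(window):
    """Hypothetical stand-in model: extend the last observed trend by one step."""
    return window[-1] + (window[-1] - window[-2])


preds = rollout(linear_extrapolation, history=[[0.0], [1.0]], n_steps=3)
```

Note that K-L divergence is asymmetric, so D_KL(p || q) generally differs from D_KL(q || p). The eps term only guards against empty histogram bins.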
The repository is laid out as follows:
LSTMIX/
│
├── Clean_CSV.py
│
├── config/
│ ├── config_paths.ini
│ ├── config_sm.ini
│ └── config_sv.ini
│
├── data_gen.py
│
├── figs/
│ ├── input_data/
│ ├── performance_logs/
│ ├── perturbations/
│ ├── rollouts/
│ ├── split_data/
│ ├── temporal_dist/
│ ├── temporal_EMD/
│ └── windowed/
│
├── hyperparam_tuning.py
│
├── input_data/
│ └── inputdata.pkl
│
├── input.py
├── Load_Clean_DF.py
├── modeltrain_LSTM.py
├── perturbation.py
├── rollout_prediction.py
├── README.md
├── requirements.txt
├── tools_modeltraining.py
│
├── trained_models/
│ ├── data_sets_GRU_ED/
│ ├── data_sets_GRU_FC/
│ ├── data_sets_LSTM_ED/
│ ├── data_sets_LSTM_FC/
│ ├── GRU_ED_logs/
│ ├── GRU_ED_trained_model.pt
│ ├── GRU_FC_logs/
│ ├── GRU_FC_trained_model.pt
│ ├── hyperparams_GRU_ED.txt
│ ├── hyperparams_GRU_FC.txt
│ ├── hyperparams_LSTM_ED.txt
│ ├── hyperparams_LSTM_FC.txt
│ ├── LSTM_ED_logs/
│ ├── LSTM_ED_trained_model.pt
│ ├── LSTM_FC_logs/
│ └── LSTM_FC_trained_model.pt
│
├── tuning/
│ ├── best_models/
│ ├── GRU_ED/
│ ├── GRU_FC/
│ ├── LSTM_ED/
│ └── LSTM_FC/
│
└── RawData/ # Not part of the repository, user data
To use this framework, follow these steps:
- Clone the repository:
git clone https://github.com/your-username/your-repository.git
- Install the required dependencies:
pip install -r requirements.txt
- Execute the data generation script:
python data_gen.py
- Train the RNN model (LSTM or GRU cells):
python modeltrain_LSTM.py
- Tune hyperparameters (optional):
python hyperparam_tuning.py
- Evaluate model performance:
python rollout_prediction.py
The main dependencies are:
- PyTorch
- Ray Tune
- NumPy
- Matplotlib
- Other dependencies specified in requirements.txt
Acknowledgments:
- The developers and contributors of PyTorch, Ray Tune, and the other open-source libraries used in this framework.