## PLS as a fault detection model

In [1]:
from fermfaultdetect.data.utils import load_batchset, dataloader
from fermfaultdetect.utils import get_simulation_dir
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import seaborn as sns
from fermfaultdetect.utils import get_models_dir
from fermfaultdetect.visualizations import visualize
from fermfaultdetect.fault_detect_models.ml_models import pls_fdm
from fermfaultdetect import model_evaluation as eval
from fermfaultdetect.model_evaluation import plot_example_set
from datetime import datetime
import pickle
import json

### Load training and test data

In [2]:
# Seed handed to both dataloaders below
seed = 42

sim_dir = get_simulation_dir() # get directory of simulation data

##############################################
# User settings: replace the placeholders before running the notebook
model_name = "FILL_IN_MODEL_NAME" # set the name of model (e.g. date or specific name)
train_set_name = "FILL_IN_TRAINING_SET_NAME"
val_set_name = "FILL_IN_VALIDATION_SET_NAME"
##############################################

train_path = os.path.join(sim_dir, train_set_name)
val_path = os.path.join(sim_dir, val_set_name)

# Directory where the model, metrics and figures of this run are stored.
# exist_ok=True makes re-running the cell safe and avoids the
# check-then-create race of a separate os.path.exists() test.
model_dir = os.path.join(get_models_dir(), model_name)
os.makedirs(model_dir, exist_ok=True)

train_set = load_batchset(train_path)
val_set = load_batchset(val_path)

# Load train data into dataloader and standardize (target columns are excluded)
target_cols = ['defect_steambarrier', 'steam_in_feed', 'blocked_spargers', 'airflow_OOC', 'OUR_OOC', 'no_fault'] # set target columns
train_dl = dataloader(batchset = train_set[:], seed=seed)
train_dl.shuffle_batches()
train_dl.standardize_data(exclude_cols=target_cols)

# Load validation data into dataloader and standardize it with the
# standardization imported from the training dataloader
val_dl = dataloader(batchset = val_set[:], seed=seed)
val_dl.import_standardization(train_dl)
val_dl.standardize_data(exclude_cols=target_cols)

# Retrieve data from dataloader with separate and fused target columns
train_X, train_Y = train_dl.get_data(split_batches=False, target_cols=target_cols, separate_target_matrix=True, fuse_target_cols=True)
val_X, val_Y = val_dl.get_data(split_batches=False, target_cols=target_cols, separate_target_matrix=True, fuse_target_cols=True)
# Unfused targets of the validation set are kept for the metrics table
_, val_Y_unfused = val_dl.get_data(split_batches=False, target_cols=target_cols, separate_target_matrix=True, fuse_target_cols=False)

# Keep only the fused "fault" column as the training/validation target
train_Y = train_Y[["fault"]]
val_Y = val_Y[["fault"]]

## Optimize the number of components, detection threshold tau and moving time window through grid search

### Run gridsearch

In [None]:
# Hyperparameter grid: detection thresholds, moving time windows and
# numbers of PLS components
th_values = np.round(np.linspace(0.2, 0.8, 13), decimals=2)
mw_values = np.linspace(1, 30, 6, dtype=int)
pc_values = np.arange(1, 8)

# One record per evaluated grid point, plus the best record seen so far
results = []
best_model = None

# Exhaustive search over every (pc, mw, th) combination
for pc in pc_values:
    for mw in mw_values:
        for th in th_values:
            pls_model = pls_fdm(n_components=pc, threshold=th, mw=mw)
            pls_model.train(train_X, train_Y)
            accuracy = pls_model.prediction_accuracy(val_X, val_Y)

            record = {'th': th, 'mw': mw, 'pc': pc, 'accuracy': accuracy}
            results.append(record)

            # Only a strictly higher accuracy replaces the current best,
            # so the first grid point wins on ties
            if best_model is not None and not (accuracy > best_model['accuracy']):
                continue
            best_model = dict(record)
    print(f"Finished with {pc} components")


# Collect all grid points in a DataFrame
results_df = pd.DataFrame(results)

# Restrict to the rows that share the best number of components
has_best_pc = results_df['pc'] == best_model['pc']
best_pc_results = results_df.loc[has_best_pc]

# Threshold x moving-window accuracy table used for the heatmap
pivot_table = best_pc_results.pivot(index='th', columns='mw', values='accuracy')

# Persist the table as CSV (can be passed to metrics_table)
heatmap_path = os.path.join(model_dir, f"pls_gridsearch_heatmap_{model_name}.csv")
pivot_table.to_csv(heatmap_path)

### Visualize gridsearch

In [None]:
# Print optimal hyperparameters: the row with the highest accuracy
# (idxmax returns the first such row on ties)
best_row = results_df.loc[results_df['accuracy'].idxmax()]
# The row is a float Series, so the integer hyperparameters are cast back
# to int for display; n_components is reported alongside the other two.
print(
    f"Optimal parameters: accuracy = {best_row['accuracy']:.3f}, "
    f"n_components = {int(best_row['pc'])}, "
    f"threshold = {best_row['th']:.3f}, "
    f"moving time window = {int(best_row['mw'])}"
)

# Plotting the results using seaborn heatmap
# (pivot_table only contains the grid points with the best n_components)
plt.figure(figsize=(10, 8))
visualize.set_plot_params(high_res=True)
ax = sns.heatmap(pivot_table, annot=False, cmap=visualize.get_hotcold_colormap(), fmt=".3f")
#plt.title('PLS Model Accuracy Heatmap')
ax.set_yticklabels(ax.get_yticklabels(), rotation=0)
plt.xlabel(r'Moving time window $\it{n}$ [-]')
# The label previously contained a mis-decoded tau; mathtext keeps it
# independent of the file encoding and consistent with the x-label.
plt.ylabel(r'Threshold $\tau$ [-]')
# Save the heatmap
plt.savefig(os.path.join(model_dir, "pls_gridsearch_heatmap_"+model_name+".png"), dpi=300)
plt.show()

### Analyse and save optimal model

In [None]:
# Rebuild and retrain the model at the best grid point; the component count
# and moving window are cast to int before being handed to the model
best_pc = best_row['pc'].astype(int)
best_th = best_row['th']
best_mw = best_row['mw'].astype(int)
pls_model_best = pls_fdm(n_components=best_pc, threshold=best_th, mw=best_mw)
pls_model_best.train(train_X, train_Y)

# Predict on the validation set and evaluate against the unfused targets
predictions_best = pls_model_best.predict(val_X)
metrics_path = os.path.join(model_dir, f"pls_metrics_opt_{model_name}.csv")
metrics = eval.metrics_table_oneclass(
    val_Y_unfused,
    predictions_best["fault"],
    save_path=metrics_path,
)
eval.visualize_metrics(metrics)

In [7]:
# Save PLS model
filename = 'model.pkl' # file name of the pickled model
save_path = os.path.join(model_dir, filename)
pls_model_best.clear_large_attributes() # clear large attributes before saving
with open(save_path, 'wb') as file:
    pickle.dump(pls_model_best, file)


# Create and save config
# best_row is a float Series (integer and float columns are upcast when a
# single row is selected), so the values are converted to native Python
# types: integer hyperparameters are written as 3 instead of 3.0.
config_model = {
    "model": "PLS-DA",
    "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    "train_set": train_set_name,
    "name": model_name,
    "n_components": int(best_row['pc']),
    "threshold": float(best_row['th']),
    "moving time window": int(best_row['mw'])
}

# Save the model config as a json file
config_name = "config.json"
config_path = os.path.join(model_dir, config_name)
with open(config_path, 'w') as json_file:
    json.dump(config_model, json_file, indent=4)

### Show performance with an exemplary validation set

In [None]:
# Name of the dataset to plot; replace the placeholder before running
example_setname = "FILL_IN_EXAMPLE_SET_NAME"

# Plot the optimized model on the example set, showing the "weight" parameter
# in separate figures
plot_example_set(
    model=pls_model_best,
    dataset_name=example_setname,
    parameter_plotted="weight",
    combined_figure=False,
)