In [None]:
import os
import h5py
import numpy as np
import onnxruntime

# Define the data loading and preprocessing functions
def load_data_from_h5(file_path):
    """Read the whole ``vibration_data`` dataset from an HDF5 file.

    Args:
        file_path: path to the ``.h5`` file to open (read-only).

    Returns:
        The full contents of the ``vibration_data`` dataset, read into memory
        before the file is closed.
    """
    with h5py.File(file_path, "r") as h5_file:
        # Slicing with [:] materialises the dataset so it outlives the file handle.
        return h5_file["vibration_data"][:]

def normalize_data(data, min_value=-2500.0, max_value=2500.0):
    """Linearly rescale ``data`` from ``[min_value, max_value]`` to ``[-1, 1]``.

    Values outside the range are not clipped, so they map outside ``[-1, 1]``.

    Args:
        data: numpy array (or anything supporting arithmetic with floats).
        min_value: raw value mapped to -1. Defaults to -2500.0, the fixed
            range this notebook assumes for the vibration signal.
        max_value: raw value mapped to +1. Defaults to 2500.0.

    Returns:
        The rescaled values, same shape as ``data``.

    Raises:
        ValueError: if ``max_value <= min_value`` (the range would be empty
            or inverted, and the division below would be meaningless).
    """
    if max_value <= min_value:
        raise ValueError(
            f"max_value ({max_value}) must be greater than min_value ({min_value})"
        )
    normalized_data = 2 * ((data - min_value) / (max_value - min_value)) - 1
    return normalized_data

# Load one known-good and one known-bad recording from M01 / OP07 as a quick sanity check.
# NOTE(review): good_data_normalized / bad_data_normalized are not used by any visible
# cell; the per-file loop further down re-reads every .h5 file in these directories.
good_data, bad_data = (
    load_data_from_h5(sample_path)
    for sample_path in (
        "../data/M01/OP07/good/M01_Aug_2019_OP07_000.h5",
        "../data/M01/OP07/bad/M01_Aug_2019_OP07_000.h5",
    )
)

# Rescale both samples with the same normalisation applied to every file later on.
good_data_normalized, bad_data_normalized = map(normalize_data, (good_data, bad_data))

# Open the ONNX reconstruction model once; infer_and_compute_loss reads
# ort_session and input_name as notebook-level globals.
onnx_file_path = "onnx_models/97e1a403699d408c89c536d7e4a7c09b.onnx"
ort_session = onnxruntime.InferenceSession(onnx_file_path)
input_name = ort_session.get_inputs()[0].name  # name of the model's first input tensor

def infer_and_compute_loss(data_normalized, session=None, seq_length=4000):
    """Mean reconstruction MSE of an ONNX model over fixed-length windows.

    ``data_normalized`` is cut into consecutive, non-overlapping windows of
    ``seq_length`` rows; any trailing partial window is dropped. Each window is
    fed to the model as a batch of one, shaped ``(1, seq_length, channels)``,
    and the MSE between the model's first output and that window is recorded.

    Args:
        data_normalized: array whose first axis is time. Presumably
            ``(n_samples, 3)`` for this notebook's model (TODO confirm); a 1-D
            array is treated as a single channel.
        session: ONNX Runtime inference session (anything exposing
            ``get_inputs()`` and ``run()``). Defaults to the notebook-level
            ``ort_session`` / ``input_name``.
        seq_length: window length in rows. Defaults to 4000.

    Returns:
        Mean of the per-window MSEs. If ``data_normalized`` has fewer than
        ``seq_length`` rows there are no windows and the result is NaN
        (numpy's mean of an empty list, which also emits a RuntimeWarning).

    Raises:
        ValueError: if ``seq_length`` is not positive.
    """
    if seq_length <= 0:
        raise ValueError(f"seq_length must be positive, got {seq_length}")

    if session is None:
        session, feed_name = ort_session, input_name
    else:
        feed_name = session.get_inputs()[0].name

    data = np.asarray(data_normalized)
    num_chunks = len(data) // seq_length  # trailing partial window is ignored

    losses = []
    for i in range(num_chunks):
        chunk = data[i * seq_length:(i + 1) * seq_length]
        # Channel count comes from the data instead of a hard-coded 3.
        chunk = chunk.reshape(1, seq_length, -1).astype(np.float32)
        # Take the first output as-is (same shape as `chunk`). Stacking all outputs
        # and squeezing could drop a size-1 axis and mis-broadcast the difference.
        reconstruction = session.run(None, {feed_name: chunk})[0]
        losses.append(np.mean((reconstruction - chunk) ** 2))

    return np.mean(losses)

root_dir = "../data"
machine_names = ["M01", "M02", "M03"]
process_name = "OP07"
labels = ["good", "bad"]

# Reconstruction losses collected three ways: per machine, per (machine, label),
# and per label pooled across machines.
loss_stats_machine = {machine: [] for machine in machine_names}
loss_stats_machine_label = {machine: {label: [] for label in labels} for machine in machine_names}
loss_stats_label = {label: [] for label in labels}

# Score every .h5 recording under <root_dir>/<machine>/<process>/<label>/.
for machine_name in machine_names:
    for label in labels:
        data_dir = os.path.join(root_dir, machine_name, process_name, label)
        # os.listdir returns entries in arbitrary, filesystem-dependent order; sort so
        # the per-file printout and list order are reproducible across runs/machines.
        file_list = sorted(file for file in os.listdir(data_dir) if file.endswith(".h5"))

        for file_name in file_list:
            file_path = os.path.join(data_dir, file_name)
            data = load_data_from_h5(file_path)
            data_normalized = normalize_data(data)

            # NOTE(review): a recording shorter than one inference window yields NaN,
            # which then distorts the min/max statistics and the threshold below.
            mse_loss = infer_and_compute_loss(data_normalized)

            print(f"MSE Loss for {file_name} in machine {machine_name} with label {label}: {mse_loss}")

            loss_stats_machine[machine_name].append(mse_loss)
            loss_stats_machine_label[machine_name][label].append(mse_loss)
            loss_stats_label[label].append(mse_loss)


In [None]:
def _format_loss_stats(losses):
    """Return an 'Avg: .., Min: .., Max: ..' summary of a list of losses.

    Returns a placeholder instead of raising when ``losses`` is empty
    (e.g. a machine/label directory contained no .h5 files).
    """
    if not losses:
        return "no losses recorded"
    return f"Avg: {sum(losses)/len(losses):.5f}, Min: {min(losses):.5f}, Max: {max(losses):.5f}"


print("\nStatistics:")

print("\nAverage, Min, Max Loss Per Machine:")
for machine, losses in loss_stats_machine.items():
    print(f"{machine}: {_format_loss_stats(losses)}")

print("\nAverage, Min, Max Loss Per Machine Per Label:")
for machine, label_losses in loss_stats_machine_label.items():
    for label, losses in label_losses.items():
        print(f"{machine} - {label}: {_format_loss_stats(losses)}")

print("\nAverage, Min, Max Loss Per Label:")
for label, losses in loss_stats_label.items():
    print(f"{label}: {_format_loss_stats(losses)}")

In [None]:
# Determine the threshold for classification: a loss above it is flagged "bad" (positive).
# NOTE(review): the threshold is derived from the same 'bad' losses it is then evaluated
# on. Whenever the smallest 'bad' loss is positive, every 'bad' loss exceeds
# min * THRESHOLD_MARGIN, so FN is 0 by construction. These counts are therefore an
# optimistic, in-sample estimate; a held-out split is needed for an honest error rate.
THRESHOLD_MARGIN = 0.95  # fraction of the smallest 'bad' loss used as the cut-off

if not loss_stats_label['bad']:
    raise ValueError("No 'bad' losses recorded; cannot derive a classification threshold.")
threshold = min(loss_stats_label['bad']) * THRESHOLD_MARGIN

good_losses = loss_stats_label['good']
bad_losses = loss_stats_label['bad']

# 'good' files above the threshold are false alarms; 'bad' files above it are hits.
# (A NaN loss compares False, so it is counted as negative, as before.)
false_positive = sum(1 for loss in good_losses if loss > threshold)
true_negative = len(good_losses) - false_positive
true_positive = sum(1 for loss in bad_losses if loss > threshold)
false_negative = len(bad_losses) - true_positive

print(f"\nClassification Statistics based on threshold of {threshold:.5f}:")
print(f"True Positives (TP): {true_positive}")
print(f"True Negatives (TN): {true_negative}")
print(f"False Positives (FP): {false_positive}")
print(f"False Negatives (FN): {false_negative}")

print("\nTotals:")
print(f"Total 'good' labeled data: {len(loss_stats_label['good'])}")
print(f"Total 'bad' labeled data: {len(loss_stats_label['bad'])}")