# Truck Failure Prediction ML Pipeline
# =================================
# This notebook demonstrates an end-to-end ML pipeline for predicting truck failures using Snowflake's ML capabilities. We'll predict which trucks are likely to fail in the next 12 hours based on sensor data.

# --------------------------------
# Config - Set all parameters here
# --------------------------------

In [None]:
# Database passed to the FeatureStore constructors
DATABASE_NAME = "SUMMIT_25"
# Schema name; passed as the feature store `name`
SCHEMA_NAME = "ASSET_HEALTH"
# Default warehouse handed to the feature store
WAREHOUSE_NAME = "MEDIUM"
# Name under which the model is logged to and loaded from the model registry
MODEL_NAME = "ASSET_HEALTH_12HOUR_FAILURE_PREDICTION"
# Base model version label. NOTE(review): confirm where this is consumed
MODEL_VERSION = "XGB_V1"
# Historical sensor table used for exploration, feature engineering and labels
RAW_DATA_TABLE = "TURBO_HISTORY_DATA"
# Table read by the production feature SQL
PRODUCTION_DATA_TABLE = "TURBO_DATA_PRODUCTION"
# Table the inference results are written to (written in overwrite mode)
PREDICTION_OUTPUT_TABLE = "TURBO_DATA_PREDICTIONS_NEW"
# Rows with a timestamp before this value form the training set; the rest form the test set
TRAIN_TEST_SPLIT_DATE = "2025-03-20 00:00:00"

# Import required packages
import streamlit as st
import pandas as pd
import numpy as np
import shap
from snowflake.snowpark.functions import col, lag, avg, stddev, min as sf_min, max as sf_max, hour
from snowflake.snowpark.window import Window
from snowflake.ml.feature_store import FeatureStore, FeatureView, Entity, CreationMode
from snowflake.ml.modeling.xgboost import XGBClassifier
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.metrics import accuracy_score, confusion_matrix, f1_score, precision_score, recall_score
from snowflake.ml.registry import Registry
from snowflake.snowpark import Session, DataFrame, Window, WindowSpec
import snowflake.snowpark.functions as F

# Get active session
# `session` is the handle used for every table read and SQL statement in this notebook.
from snowflake.snowpark.context import get_active_session
session = get_active_session()


# --------------------
# 1. Data Exploration
# --------------------

In [None]:
# Take a first look at the raw sensor history before building features
print("Exploring raw sensor data...")

# Read the history table and normalise the key column types in one chain
df = (
    session.table(RAW_DATA_TABLE)
    .with_column("timestamp", col("timestamp").cast("timestamp"))
    .with_column("truck_id", col("truck_id").cast("integer"))
)

print("Raw data sample:")
df.show(5)

# List every column available in the raw table
print("Data columns:")
for column in df.columns:
    print(f"- {column}")

# ---------------------------
# 2. Feature Engineering
# ---------------------------

In [None]:
# Derive lag, rolling-window, delta and hour-of-day features from the sensor data to help predict failures
print("\nEngineering features from sensor data...")

# Helper function for feature engineering
def create_lag_features(df, columns, window):
    """Add a ``{column}_lag1`` column holding the previous row's value.

    Args:
        df: DataFrame to extend.
        columns: Names of the columns to lag.
        window: Window specification the lag is evaluated over.

    Returns:
        ``df`` with one ``_lag1`` column added per entry in ``columns``.
    """
    lagged = df
    for name in columns:
        previous_value = lag(col(name)).over(window)
        lagged = lagged.with_column(f"{name}_lag1", previous_value)
    return lagged

def create_rolling_stats(df, columns, window, stat_type="avg"):
    """Add one rolling-window statistic column per input column.

    Args:
        df: DataFrame to extend.
        columns: Names of the columns to aggregate.
        window: Window specification the aggregate is evaluated over.
        stat_type: Which aggregate to compute: "avg", "min" or "max".

    Returns:
        ``df`` with a ``{column}_{stat_type}_1hr`` column added for each
        entry in ``columns``.

    Raises:
        ValueError: If ``stat_type`` is not one of the supported values.
    """
    # Fail loudly on a typo instead of silently returning the frame unchanged.
    if stat_type not in ("avg", "min", "max"):
        raise ValueError(
            f"Unsupported stat_type {stat_type!r}; expected 'avg', 'min' or 'max'"
        )

    result_df = df
    for column in columns:
        values = col(column)
        if stat_type == "avg":
            stat = avg(values)
        elif stat_type == "min":
            stat = sf_min(values)
        else:
            stat = sf_max(values)
        result_df = result_df.with_column(f"{column}_{stat_type}_1hr", stat.over(window))
    return result_df

def create_delta_features(df, columns):
    """Add a ``delta_{column}`` column equal to current minus lagged value.

    Each column in ``columns`` must already have a matching
    ``{column}_lag1`` column on ``df``.

    Args:
        df: DataFrame to extend.
        columns: Names of the columns to difference.

    Returns:
        ``df`` with one ``delta_`` column added per entry in ``columns``.
    """
    with_deltas = df
    for name in columns:
        change = col(name) - col(f"{name}_lag1")
        with_deltas = with_deltas.with_column(f"delta_{name}", change)
    return with_deltas

# Per-truck, time-ordered window used by every lag and rolling feature
w = Window.partition_by("truck_id").order_by("timestamp")
# Rolling frame: the 12 preceding rows plus the current row (13 rows in total).
# The derived columns are named "_1hr" — NOTE(review): confirm the sampling
# interval so the frame size and the "1hr" label actually agree.
rolling_window = w.rows_between(-12, 0)

# Sensors that get lag, rolling-average and delta features
sensor_columns = [
    "exhaust_gas_temp",
    "oil_pressure",
    "boost_pressure",
    "oil_contamination",
    "engine_boost_ratio",
]
# Sensors that additionally get rolling min/max features
temp_pressure_columns = ["exhaust_gas_temp", "oil_pressure"]

# 1. Previous-row values
df = create_lag_features(df, sensor_columns, w)

# 2. Rolling averages
df = create_rolling_stats(df, sensor_columns, rolling_window, "avg")

# 3. Rolling min, then rolling max, for the temperature/pressure sensors
for stat in ("min", "max"):
    df = create_rolling_stats(df, temp_pressure_columns, rolling_window, stat)

# 4. Rate of change (current value minus previous-row value)
df = create_delta_features(df, sensor_columns)

# 5. Hour-of-day feature, and 6. strip the label columns from the feature set
df = (
    df.with_column("hour_of_day", hour(col("timestamp")))
    .drop(["PART_FAILURE_NEXT_12HR", "ACTUAL_FAILURE_EVENT"])
)

print("Engineered features sample:")
df.show(5)

# -----------------------------------
# 3. Feature Store Creation & Registration
# -----------------------------------

In [None]:
# We'll register our features in Snowflake's Feature Store for reuse
print("\nRegistering features in Feature Store...")

# Open the feature store, creating it if it does not exist yet
fs = FeatureStore(
    session=session,
    database=DATABASE_NAME,
    name=SCHEMA_NAME,
    default_warehouse=WAREHOUSE_NAME,
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST
)

# Create or retrieve entity definition
try:
    # Try to retrieve existing entity
    truck_entity = fs.get_entity('Sensor_Data')
    print('Retrieved existing entity')
except Exception:
    # The lookup failed, so treat the entity as not registered yet.
    # `Exception` (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
    truck_entity = Entity(
        name="Sensor_Data",
        join_keys=["truck_id"]
    )

    fs.register_entity(truck_entity)
    print("Registered new entity")

# Feature view over the engineered feature DataFrame, keyed by the truck entity
truck_feature_view = FeatureView(
    name="truck_sensors_feature_view_train",
    entities=[truck_entity],
    feature_df=df,
    timestamp_col="timestamp",
    refresh_freq="5 minutes",
    desc="Lagged sensor features for truck failure predictions"
)

# Register (or overwrite) version "v1" of the feature view
registered_fv = fs.register_feature_view(
    feature_view=truck_feature_view,
    version="v1",
    block=True,
    overwrite=True
)

# The version string already starts with "v", so do not prefix another one
print(f"Registered feature view: {registered_fv.name} ({registered_fv.version})")

# -----------------------------------
# 4. Generate Training Dataset
# -----------------------------------

In [None]:
# Build the training dataset by joining registered features onto labelled rows
print("\nGenerating training dataset...")

# Spine: truck id, timestamp and label taken from the raw history table
entity_df = (
    session.table(RAW_DATA_TABLE)
    .select("truck_id", "timestamp", "PART_FAILURE_NEXT_12HR")
    .with_column("timestamp", col("timestamp").cast("timestamp"))
)

# Ask the feature store to attach the feature view's columns to the spine
training_df = fs.generate_training_set(
    spine_df=entity_df,
    features=[registered_fv],
    spine_timestamp_col="timestamp",
    spine_label_cols=["PART_FAILURE_NEXT_12HR"],
)

print("Training dataset sample:")
training_df.show(5)

# ----------------------------------
# 5. Create Test/Train Split
# ----------------------------------

In [None]:
print(f"\nSplitting data into train/test sets at {TRAIN_TEST_SPLIT_DATE}...")

# Label column predicted by the model
label_col = "PART_FAILURE_NEXT_12HR"

# Cast the label to integer once, before splitting, so the train and test
# sets carry the same label type (previously only the train split was cast).
labeled_df = training_df.with_column(label_col, col(label_col).cast("integer"))

# Split based on timestamp; rows containing nulls are dropped from both splits
train_df = labeled_df.filter(col("timestamp") < TRAIN_TEST_SPLIT_DATE).dropna()
test_df = labeled_df.filter(col("timestamp") >= TRAIN_TEST_SPLIT_DATE).dropna()

print(f"Train set size: {train_df.count()} rows")
print(f"Test set size: {test_df.count()} rows")

# Model input columns: raw sensor readings plus the engineered features
feature_cols = [
    "EXHAUST_GAS_TEMP", "OIL_PRESSURE", "BOOST_PRESSURE", "OIL_CONTAMINATION",
    "ENGINE_BOOST_RATIO", "AMBIENT_TEMP", "RPM", "SPEED", "GPS_LAT", "GPS_LON",
    "EXHAUST_GAS_TEMP_LAG1", "OIL_PRESSURE_LAG1", "BOOST_PRESSURE_LAG1",
    "OIL_CONTAMINATION_LAG1", "ENGINE_BOOST_RATIO_LAG1",
    "EXHAUST_GAS_TEMP_AVG_1HR", "OIL_PRESSURE_AVG_1HR", "BOOST_PRESSURE_AVG_1HR",
    "OIL_CONTAMINATION_AVG_1HR", "ENGINE_BOOST_RATIO_AVG_1HR",
    "EXHAUST_GAS_TEMP_MIN_1HR", "EXHAUST_GAS_TEMP_MAX_1HR", "OIL_PRESSURE_MIN_1HR",
    "OIL_PRESSURE_MAX_1HR",
    "DELTA_EXHAUST_GAS_TEMP", "DELTA_OIL_PRESSURE", "DELTA_BOOST_PRESSURE",
    "DELTA_OIL_CONTAMINATION", "DELTA_ENGINE_BOOST_RATIO",
    "HOUR_OF_DAY"
]


# Fix for Decimal Type Conversion Warnings
def explicitly_cast_decimal_columns(df):
    """Cast every Decimal-typed column of ``df`` to Double.

    Doing the cast explicitly avoids the automatic conversion warnings
    that are otherwise emitted for Decimal columns.

    Args:
        df: Snowpark DataFrame whose schema is inspected.

    Returns:
        A DataFrame in which each top-level Decimal column has been
        replaced by a Double column of the same name; all other columns
        are left as they were.
    """
    from snowflake.snowpark.types import DecimalType, DoubleType

    # The schema is read once up front; the loop then rebinds `df` per column.
    for field in df.schema.fields:
        # isinstance instead of a substring match on the type name: the
        # substring test also matched container types that merely contain a
        # decimal (e.g. an array of decimals) and tried to cast them to Double.
        if isinstance(field.datatype, DecimalType):
            df = df.with_column(field.name, df[field.name].cast(DoubleType()))

    return df

# Run both splits through the Decimal -> Double cast to avoid conversion warnings
train_df, test_df = [
    explicitly_cast_decimal_columns(split) for split in (train_df, test_df)
]

# ----------------------------------
# 6. Train Model
# ----------------------------------

In [None]:
print("\nTraining XGBoost classifier model...")

# Needed for the scoped warning filter around fit()
import warnings

# Classifier reading feature_cols, learning label_col and writing PREDICTION
model = XGBClassifier(
    input_cols=feature_cols,
    label_cols=[label_col],
    output_cols=["PREDICTION"],
)

# Silence only the signature-truncation warning, and only while fitting
truncation_warning = ".*truncation happened before inferring signature.*"
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=truncation_warning)
    model.fit(train_df)

print("Model training complete on full dataset!")

# ----------------------------------
# 7. Evaluate Model
# ----------------------------------

In [None]:
print("\nEvaluating model on test data...")

# Score the held-out split; the model writes its output to PREDICTION
pred_df = model.predict(test_df)
pred_df.show(5)

# Every metric compares the same label column against the same prediction column
metric_args = {
    "df": pred_df,
    "y_true_col_names": [label_col],
    "y_pred_col_names": ["PREDICTION"],
}
accuracy_score_test = accuracy_score(**metric_args)
f1_score_test = f1_score(**metric_args)
recall_score_test = recall_score(**metric_args)
precision_score_test = precision_score(**metric_args)

# Report each metric to four decimal places
print("Test Metrics:")
for metric_label, metric_value in (
    ("Accuracy", accuracy_score_test),
    ("F1 Score", f1_score_test),
    ("Recall", recall_score_test),
    ("Precision", precision_score_test),
):
    print(f"- {metric_label}: {metric_value:.4f}")

# ----------------------------------
# 8. Register Model in Model Registry
# ----------------------------------

In [None]:
print("\nRegistering model in Snowflake Model Registry...")

# Standard-library helpers used in this cell
import datetime
import warnings

# Create the registry object (Registry is imported at the top of the file)
registry = Registry(session=session)

# Small feature-only sample (ids, timestamp and label removed) for registration
registration_sample = train_df.drop(["TRUCK_ID", "TIMESTAMP", label_col]).limit(100)

# Build a unique version name from the configured base version plus the
# current time, so re-running the cell never collides with an earlier version.
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
new_version = f"{MODEL_VERSION}_{timestamp}"

# Log model to registry under the new version name
print(f"Logging model: {MODEL_NAME} version: {new_version}")
with warnings.catch_warnings():
    # Suppress only the known, benign warnings raised while logging the model
    warnings.filterwarnings("ignore", message=".*truncation happened before inferring signature.*")
    warnings.filterwarnings("ignore", message=".*Providing model signature for Snowpark ML Modeling model is not required.*")
    warnings.filterwarnings("ignore", message=".*`relax_version` is not set and therefore defaulted to True.*")
    mv_base = registry.log_model(
        model_name=MODEL_NAME,
        model=model,
        version_name=new_version,  # Use new unique version name
        sample_input_data=registration_sample,
        comment="XGBoost Forecasting Model"
    )

# ----------------------------------
# 9. Analyze Feature Importance
# ----------------------------------

In [None]:
print("\nAnalyzing feature importance with SHAP values...")

# Pull the test split into pandas and take a reproducible sample of at most 2500 rows
test_pd = test_df.to_pandas()
test_pd_sample = test_pd.sample(n=min(2500, len(test_pd)), random_state=100).reset_index(drop=True)

# Compute Shapley values by calling the registered model version's "explain" function
base_shap_pd = mv_base.run(test_pd_sample, function_name="explain")

# Convert to proper format for visualization
shap_values_np = np.array(base_shap_pd.astype(float))
# Feature names are taken from the sample's own columns, minus label and keys.
# NOTE(review): this assumes the explain output has exactly one column per
# remaining sample column, in the same order — confirm against feature_cols,
# otherwise importances are attributed to the wrong features.
feature_names = test_pd_sample.drop(columns=[label_col, "TIMESTAMP", "TRUCK_ID"]).columns

# Create SHAP values object
# NOTE(review): `shap._explanation` is a private module path — consider the
# public `shap.Explanation` alias; confirm it is available in the pinned version.
shap_values_obj = shap._explanation.Explanation(
    values=shap_values_np,
    feature_names=feature_names,
    data=test_pd_sample[feature_names].values
)

# Rank features by mean absolute SHAP value across the sampled rows
print("Top 5 most important features:")
importance_vals = np.abs(shap_values_np).mean(axis=0)
importance_df = pd.DataFrame({
    'feature': feature_names,
    'importance': importance_vals
})
importance_df = importance_df.sort_values('importance', ascending=False)
for i, row in importance_df.head(5).iterrows():
    print(f"- {row['feature']}: {row['importance']:.4f}")

# ----------------------------------
# 10. Production Pipeline Setup
# ----------------------------------

In [None]:
print("\nSetting up production pipeline...")

# Create production feature store
# (same database, name and warehouse settings as the feature store opened earlier)
production_fs = FeatureStore(
    session=session,
    database=DATABASE_NAME,
    default_warehouse=WAREHOUSE_NAME,
    name=SCHEMA_NAME,
    creation_mode=CreationMode.CREATE_IF_NOT_EXIST
)

# Fetch version "v1" of the training feature view registered earlier
experiment_fv = production_fs.get_feature_view(
    name="truck_sensors_feature_view_train",
    version="v1"
)

# Extract SQL logic for feature engineering (first query of the feature DataFrame)
# NOTE(review): `sql_logic` is not referenced again in the visible source; the
# production SQL that follows is written out by hand rather than derived from it.
sql_logic = experiment_fv.feature_df.queries['queries'][0]
print("Extracted feature engineering SQL logic")

# Simplified production SQL that targets production data table
# This hand-written query repeats the Snowpark feature engineering (lag, rolling
# avg/min/max, deltas, hour of day) against PRODUCTION_DATA_TABLE.
# NOTE(review): it must be kept in sync with the training features by hand; the
# frame "12 PRECEDING AND CURRENT ROW" spans 13 rows, matching rows_between(-12, 0).
production_sql = f"""
SELECT 
    "EXHAUST_GAS_TEMP", "OIL_PRESSURE", "BOOST_PRESSURE", "OIL_CONTAMINATION", 
    "ENGINE_BOOST_RATIO", "AMBIENT_TEMP", "RPM", "SPEED", "GPS_LAT", "GPS_LON", 
    "TIMESTAMP", "TRUCK_ID", 
    
    -- Lag features
    "EXHAUST_GAS_TEMP_LAG1", "OIL_PRESSURE_LAG1", "BOOST_PRESSURE_LAG1", 
    "OIL_CONTAMINATION_LAG1", "ENGINE_BOOST_RATIO_LAG1", 
    
    -- Rolling averages
    "EXHAUST_GAS_TEMP_AVG_1HR", "OIL_PRESSURE_AVG_1HR", "BOOST_PRESSURE_AVG_1HR", 
    "OIL_CONTAMINATION_AVG_1HR", "ENGINE_BOOST_RATIO_AVG_1HR", 
    
    -- Min/Max
    "EXHAUST_GAS_TEMP_MIN_1HR", "EXHAUST_GAS_TEMP_MAX_1HR", 
    "OIL_PRESSURE_MIN_1HR", "OIL_PRESSURE_MAX_1HR", 
    
    -- Delta calculations
    ("EXHAUST_GAS_TEMP" - "EXHAUST_GAS_TEMP_LAG1") AS "DELTA_EXHAUST_GAS_TEMP", 
    ("OIL_PRESSURE" - "OIL_PRESSURE_LAG1") AS "DELTA_OIL_PRESSURE", 
    ("BOOST_PRESSURE" - "BOOST_PRESSURE_LAG1") AS "DELTA_BOOST_PRESSURE", 
    ("OIL_CONTAMINATION" - "OIL_CONTAMINATION_LAG1") AS "DELTA_OIL_CONTAMINATION", 
    ("ENGINE_BOOST_RATIO" - "ENGINE_BOOST_RATIO_LAG1") AS "DELTA_ENGINE_BOOST_RATIO", 
    
    -- Time features
    hour("TIMESTAMP") AS "HOUR_OF_DAY" 
FROM (
    SELECT 
        "EXHAUST_GAS_TEMP", "OIL_PRESSURE", "BOOST_PRESSURE", "OIL_CONTAMINATION", 
        "ENGINE_BOOST_RATIO", "AMBIENT_TEMP", "RPM", "SPEED", "GPS_LAT", "GPS_LON", 
        "TIMESTAMP", "TRUCK_ID", 
        
        -- Lag calculations
        LAG("EXHAUST_GAS_TEMP", 1, NULL) OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP") AS "EXHAUST_GAS_TEMP_LAG1", 
        LAG("OIL_PRESSURE", 1, NULL) OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP") AS "OIL_PRESSURE_LAG1", 
        LAG("BOOST_PRESSURE", 1, NULL) OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP") AS "BOOST_PRESSURE_LAG1", 
        LAG("OIL_CONTAMINATION", 1, NULL) OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP") AS "OIL_CONTAMINATION_LAG1", 
        LAG("ENGINE_BOOST_RATIO", 1, NULL) OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP") AS "ENGINE_BOOST_RATIO_LAG1", 
        
        -- Rolling averages (12 rows = ~1 hour of data)
        AVG("EXHAUST_GAS_TEMP") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "EXHAUST_GAS_TEMP_AVG_1HR", 
        AVG("OIL_PRESSURE") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "OIL_PRESSURE_AVG_1HR", 
        AVG("BOOST_PRESSURE") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "BOOST_PRESSURE_AVG_1HR", 
        AVG("OIL_CONTAMINATION") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "OIL_CONTAMINATION_AVG_1HR", 
        AVG("ENGINE_BOOST_RATIO") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "ENGINE_BOOST_RATIO_AVG_1HR", 
        
        -- Min/Max values
        MIN("EXHAUST_GAS_TEMP") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "EXHAUST_GAS_TEMP_MIN_1HR", 
        MAX("EXHAUST_GAS_TEMP") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "EXHAUST_GAS_TEMP_MAX_1HR", 
        MIN("OIL_PRESSURE") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "OIL_PRESSURE_MIN_1HR", 
        MAX("OIL_PRESSURE") OVER (PARTITION BY "TRUCK_ID" ORDER BY "TIMESTAMP" ROWS BETWEEN 12 PRECEDING AND CURRENT ROW) AS "OIL_PRESSURE_MAX_1HR" 
    FROM (
        SELECT 
            "EXHAUST_GAS_TEMP", "OIL_PRESSURE", "BOOST_PRESSURE", "OIL_CONTAMINATION", 
            "ENGINE_BOOST_RATIO", "AMBIENT_TEMP", "RPM", "SPEED", "GPS_LAT", "GPS_LON", 
            CAST("TIMESTAMP" AS TIMESTAMP) AS "TIMESTAMP", 
            CAST("TRUCK_ID" AS INT) AS "TRUCK_ID" 
        FROM {PRODUCTION_DATA_TABLE}
    )
)
"""

# Materialise the production feature query as a Snowpark DataFrame
feature_df = session.sql(production_sql)
print("Created production feature dataframe")

# Make sure each entity used by the experiment feature view is registered
for entity in experiment_fv.entities:
    production_fs.register_entity(entity)

# Production feature view: same entities as the experiment view, fed by the
# production feature DataFrame
production_fv = FeatureView(
    name="PRODUCTION_FEATURE_VIEW",
    feature_df=feature_df,
    entities=experiment_fv.entities,
    timestamp_col="TIMESTAMP",
    refresh_freq="5 minutes",
    desc="Production feature view for truck telemetry",
)

# Register (or overwrite) version "1" of the production feature view
production_fs.register_feature_view(
    feature_view=production_fv,
    version="1",
    overwrite=True,
    block=True,
)

print("Registered production feature view")

# ----------------------------------
# 11. Apply Model for Predictions
# ----------------------------------

In [None]:
print("\nApplying model to production data...")

# Import warnings to suppress specific warnings
import warnings

# Load model from registry
registry = Registry(session=session)
model = registry.get_model(MODEL_NAME)

# Pick the model version to serve. Prefer the version name generated when the
# model was logged earlier in this session; the pinned name is only a fallback
# for runs that skip registration, and must then name a version that exists.
try:
    model_version_name = new_version
except NameError:
    model_version_name = "XGB_V1_20250515_155142"  # Use your actual version name from registration
model_version = model.version(model_version_name)

# Inference helper
def predict_failure_probability(feature_vector, model):
    """Score ``feature_vector`` using the model's ``predict_proba`` function.

    Args:
        feature_vector: Input data handed to ``model.run``.
        model: Object exposing ``run(data, function_name=...)``.

    Returns:
        Whatever ``model.run`` returns for ``function_name="predict_proba"``.
    """
    probabilities = model.run(feature_vector, function_name="predict_proba")
    return probabilities

# Load the rows of version "1" of the production feature view
feature_view = production_fs.get_feature_view("PRODUCTION_FEATURE_VIEW", version="1")
inference_input_sdf = production_fs.read_feature_view(feature_view)

# Cast Decimal columns to Double up front, as was done for the training data
print("Explicitly casting Decimal columns in inference data to Double...")
inference_input_sdf = explicitly_cast_decimal_columns(inference_input_sdf)

# Score the production features, muting only the Decimal conversion warning
print("Generating predictions...")
decimal_warning = ".*Type DecimalType.*is being automatically converted to DOUBLE.*"
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", message=decimal_warning)
    inference_result_sdf = predict_failure_probability(inference_input_sdf, model_version)

print("Generated predictions on production data")

# Preview: highest truck ids first, then by timestamp
preview_order = (F.col('TRUCK_ID').desc(), F.col('TIMESTAMP'))
inference_result_sdf.sort(*preview_order).show(5)

# ----------------------------------
# 12. Save Predictions to Table
# ----------------------------------

In [None]:
print(f"\nSaving predictions to {PREDICTION_OUTPUT_TABLE}...")
# Overwrite mode: the output table is replaced by this run's predictions
overwrite_writer = inference_result_sdf.write.mode("overwrite")
overwrite_writer.save_as_table(PREDICTION_OUTPUT_TABLE)
print("Predictions saved successfully!")