In [41]:
using Statistics
using MAT

In [42]:
# Load data
"""
    load_data()

Read `train_data_play.mat` from the current working directory with `matread` and
return the entries stored under the keys `"U"`, `"X"`, `"measurement_series"` and
`"dt"`, in that order. Prints a progress line and the sizes of `U` and `X`.
"""
function load_data()
    println("Loading data...")
    contents = matread("train_data_play.mat")
    # Pull the four variables out in a fixed order; a missing key raises KeyError.
    U, X, measurement_series, dt =
        map(key -> contents[key], ("U", "X", "measurement_series", "dt"))
    println("Data loaded: $(size(U)) inputs, $(size(X)) states")
    return U, X, measurement_series, dt
end

load_data (generic function with 1 method)

In [43]:
include("utils.jl")
# Data preparation
"""
    prepare_sequences(U_raw, X_raw, measurement_series, series_indices; max_length=8000)

Split the raw input/state matrices into one dictionary per measurement series.

For every `series_idx` in `series_indices`, the rows belonging to that series are
obtained with `extract` (defined in `utils.jl`), converted to `Float32`, and
truncated to at most `max_length` rows. Each returned dictionary has the keys:

- `"U"`: first `T` input rows
- `"X"`: first `T` state rows
- `"x0"`: first state row (initial condition)
- `"T"`: number of rows kept
- `"series_idx"`: the series identifier

`T` is bounded by the row counts of both the input and the state series, so a
length mismatch between them cannot cause out-of-bounds slicing.

Throws `ArgumentError` if `max_length < 1` or if a selected series has no rows.
"""
function prepare_sequences(U_raw, X_raw, measurement_series, series_indices; max_length=8000)
    max_length >= 1 ||
        throw(ArgumentError("max_length must be at least 1, got $max_length"))

    # Concrete element type instead of an untyped `[]` (Vector{Any}).
    sequences = Dict{String,Any}[]

    for series_idx in series_indices
        u_series = Float32.(extract(U_raw, measurement_series, series_idx))
        x_series = Float32.(extract(X_raw, measurement_series, series_idx))

        # Limit the length, never exceeding what either the inputs or the states provide.
        T = min(size(x_series, 1), size(u_series, 1), max_length)
        T >= 1 || throw(ArgumentError("series $series_idx contains no samples"))

        sequence = Dict{String,Any}(
            "U" => u_series[1:T, :],     # T rows of inputs
            "X" => x_series[1:T, :],     # T rows of ground-truth states
            "x0" => x_series[1, :],      # initial state
            "T" => T,
            "series_idx" => series_idx,
        )

        push!(sequences, sequence)
        println("Series $series_idx: $T timesteps")
    end

    return sequences
end


prepare_sequences (generic function with 1 method)

In [None]:
"""
    clean_data(U, X; smoothing_window=3, drop_outliers=false, drop_duplicates=false)

Run the cleaning pipeline on inputs `U` and states `X` and return `(U_clean, X_clean)`.
The arguments themselves are not modified.

Steps, in order:
1. If any row of `U` or `X` contains `NaN`, fill the gaps with `interpolate_missing`.
2. Optionally (`drop_outliers=true`), drop rows flagged by `remove_outliers_iqr`.
3. Smooth `X` with `apply_smoothing` using `smoothing_window`.
4. Clamp values to the bounds in `fix_unrealistic_values`.
5. Optionally (`drop_duplicates=true`), drop rows flagged by `remove_duplicates`.
6. Print a report via `check_data_consistency`.

Both optional steps are off by default. They delete rows, so any per-row labels kept
alongside the data (e.g. `measurement_series`) would no longer line up with the result.
"""
function clean_data(U, X; smoothing_window=3, drop_outliers=false, drop_duplicates=false)
    println("Starting data cleaning:")
    println("Original data shapes: U=$(size(U)), X=$(size(X))")

    U_clean = copy(U)
    X_clean = copy(X)

    # A row is incomplete when any of its input or state columns is NaN.
    nan_rows = vec(any(isnan, U_clean; dims=2)) .| vec(any(isnan, X_clean; dims=2))
    n_nan_rows = count(nan_rows)
    if n_nan_rows > 0
        println("Found $n_nan_rows rows with NaN values")
        U_clean = interpolate_missing(U_clean)
        X_clean = interpolate_missing(X_clean)
    end

    if drop_outliers
        U_clean, X_clean = remove_outliers_iqr(U_clean, X_clean)
    end

    # Smooth noisy temperature signals.
    X_clean = apply_smoothing(X_clean, smoothing_window)

    # Clamp physically unrealistic values.
    U_clean, X_clean = fix_unrealistic_values(U_clean, X_clean)

    if drop_duplicates
        U_clean, X_clean = remove_duplicates(U_clean, X_clean)
    end

    check_data_consistency(U_clean, X_clean)

    println("Data cleaning completed.")
    println("Cleaned data shapes: U=$(size(U_clean)), X=$(size(X_clean))")

    return U_clean, X_clean
end


clean_data (generic function with 1 method)

In [None]:

"""
    remove_outliers_iqr(U, X, factor=3)

Drop every row in which at least one column of `U` or `X` lies outside that column's
fence `[q1 - factor*iqr, q3 + factor*iqr]`. Here `q1` and `q3` are the column's 25th
and 75th percentiles and `iqr = q3 - q1`.

Returns `(U, X)` unchanged (the same objects) when no row is flagged. Otherwise it
prints how many rows were removed and returns row-filtered copies.

Throws `DimensionMismatch` if `U` and `X` have different numbers of rows.
"""
function remove_outliers_iqr(U, X, factor=3)
    n_rows = size(U, 1)
    size(X, 1) == n_rows ||
        throw(DimensionMismatch("U has $n_rows rows but X has $(size(X, 1))"))

    outlier_mask = falses(n_rows)

    # Inputs and states get the same per-column test, so scan both in one loop.
    for data in (U, X), column in eachcol(data)
        # One call sorts the column once for both quartiles.
        q1, q3 = quantile(column, [0.25, 0.75])
        iqr = q3 - q1
        lower_bound = q1 - factor * iqr
        upper_bound = q3 + factor * iqr
        outlier_mask .|= (column .< lower_bound) .| (column .> upper_bound)
    end

    outliers_removed = count(outlier_mask)
    if outliers_removed > 0
        println("Removed $outliers_removed outlier points")
        return U[.!outlier_mask, :], X[.!outlier_mask, :]
    end

    return U, X
end


In [None]:
"""
    apply_smoothing(X, window_size=3)

Return a copy of `X` in which each column is replaced by its moving average over
exactly `window_size` consecutive rows. For an odd `window_size` the window is
centred on the current row. For an even size it extends one row further after the
current row than before it.

Rows too close to either end to fit a full window are copied unchanged. `X` itself
is not modified. Prints the window size used.

Throws `ArgumentError` if `window_size < 1`.

NOTE(review): `X` is smoothed as one long signal. If it concatenates several
measurement series, windows near a series boundary mix samples from neighbouring
series. Confirm this is acceptable.
"""
function apply_smoothing(X, window_size=3)
    window_size >= 1 ||
        throw(ArgumentError("window_size must be at least 1, got $window_size"))

    # Split the window so it spans exactly `window_size` rows. A symmetric
    # ±window_size÷2 split would span window_size + 1 rows for even sizes.
    rows_before = (window_size - 1) ÷ 2
    rows_after = window_size ÷ 2

    X_smooth = copy(X)
    n_rows = size(X, 1)

    for i in 1:size(X, 2)
        for j in (rows_before + 1):(n_rows - rows_after)
            # Read from the unsmoothed input so earlier updates do not feed back.
            X_smooth[j, i] = mean(@view X[(j - rows_before):(j + rows_after), i])
        end
    end

    println("Applied smoothing with window size $window_size")
    return X_smooth
end

apply_smoothing (generic function with 2 methods)

In [None]:

"""
    fix_unrealistic_values(U, X)

Return copies of `U` and `X` with selected columns clamped to fixed bounds; all
other columns pass through unchanged. The inputs are not modified. Prints a
confirmation line.

State columns 1 and 2 (stator and rotor temperature) are clamped to [-40, 200].
Input columns are clamped according to the table below.

NOTE(review): the statistics printed earlier in this notebook show i_d (column 4)
reaching about -594 A. The ±500 A current bound therefore clips recorded values;
confirm this is intended.
"""
function fix_unrealistic_values(U, X)
    inputs = copy(U)
    states = copy(X)

    for temp_col in (1, 2)  # stator, rotor temperature
        clamp!(view(states, :, temp_col), -40.0, 200.0)
    end

    # (column, lower bound, upper bound)
    input_limits = (
        (1, 0.0, 20000.0),     # RPM
        (8, -600.0, 600.0),    # u_d voltage
        (10, -600.0, 600.0),   # u_q voltage
        (13, 0.0, 1000.0),     # DC link voltage
        (4, -500.0, 500.0),    # i_d
        (5, -500.0, 500.0),    # i_q
        (7, -500.0, 500.0),    # RMS current
        (16, 0.0, 100.0),      # rotor oil flow
        (17, 0.0, 100.0),      # stator oil flow
        (18, -20.0, 150.0),    # oil temperatures
        (19, -20.0, 150.0),
        (20, -20.0, 150.0),
        (21, -20.0, 150.0),
    )
    for (col, lo, hi) in input_limits
        clamp!(view(inputs, :, col), lo, hi)
    end

    println("Applied realistic value constraints")
    return inputs, states
end


fix_unrealistic_values (generic function with 1 method)

In [None]:

"""
    remove_duplicates(U, X, tolerance=1e-6)

Drop every row whose inputs AND states both differ from the immediately preceding
row by less than `tolerance`, measured as the largest absolute column-wise
difference. Row 1 is always kept. Each row is compared with the preceding row of
the input, even if that row was itself dropped.

Returns `(U, X)` unchanged (the same objects) when nothing is dropped. Otherwise it
prints the number of dropped rows and returns row-filtered copies.
"""
function remove_duplicates(U, X, tolerance=1e-6)
    # Largest absolute change between row `r` and row `r - 1` of `A`.
    step_size(A, r) = maximum(abs.(view(A, r, :) .- view(A, r - 1, :)))

    function is_repeat(r)
        u_step = step_size(U, r)
        x_step = step_size(X, r)
        return u_step < tolerance && x_step < tolerance
    end

    keep_mask = [r == 1 || !is_repeat(r) for r in 1:size(U, 1)]

    duplicates_removed = count(!, keep_mask)
    duplicates_removed == 0 && return U, X

    println("Removed $duplicates_removed duplicate consecutive measurements")
    return U[keep_mask, :], X[keep_mask, :]
end


In [47]:

"""
    check_data_consistency(U, X)

Print a consistency report for inputs `U` and states `X`. The report contains:

- a warning for every row where the stator (column 1) or rotor (column 2)
  temperature changes by more than 20 from the previous row;
- the correlations RPM vs. torque (`U` columns 1 and 2) and stator vs. rotor
  temperature;
- a warning for every `U` column whose variance is below `1e-10`.

Returns `nothing`.
"""
function check_data_consistency(U, X)
    println("\n=== Data Consistency Check ===")

    # Large sample-to-sample temperature jumps may indicate sensor issues.
    jump_limit = 20.0
    for row in 2:size(X, 1)
        stator_step = abs(X[row, 1] - X[row - 1, 1])
        rotor_step = abs(X[row, 2] - X[row - 1, 2])

        stator_step > jump_limit &&
            println("Warning: Large stator temperature jump at index $row: $(stator_step)°C")
        rotor_step > jump_limit &&
            println("Warning: Large rotor temperature jump at index $row: $(rotor_step)°C")
    end

    rpm_vs_torque = cor(U[:, 1], U[:, 2])
    println("RPM-Torque correlation: $(round(rpm_vs_torque, digits=3))")

    stator_vs_rotor = cor(X[:, 1], X[:, 2])
    println("Stator-Rotor temperature correlation: $(round(stator_vs_rotor, digits=3))")

    # Near-constant inputs carry no information for a model.
    flat_features = findall(j -> var(U[:, j]) < 1e-10, 1:size(U, 2))
    foreach(j -> println("Warning: Feature $j (U) has very low variance"), flat_features)

    println("=== End Consistency Check ===\n")
    return nothing
end


check_data_consistency (generic function with 1 method)

In [48]:

"""
    interpolate_missing(data)

Return a copy of the matrix `data` with `NaN` entries filled column by column.
`data` itself is not modified.

- A `NaN` with a valid sample on both sides is linearly interpolated between the
  nearest valid sample before it and the nearest valid sample after it, weighted
  by row distance.
- A `NaN` before the first, or after the last, valid sample takes that nearest
  valid value. This includes columns with a single valid sample.
- Columns with no valid sample are left unchanged.
"""
function interpolate_missing(data)
    data_clean = copy(data)

    for col in 1:size(data, 2)
        column = view(data, :, col)
        nan_indices = findall(isnan, column)
        isempty(nan_indices) && continue

        # Sorted row positions of the usable samples in this column.
        valid_indices = findall(!isnan, column)
        isempty(valid_indices) && continue  # nothing to fill from

        n_valid = length(valid_indices)
        for nan_idx in nan_indices
            # Binary search instead of a linear scan per NaN. `nan_idx` is never in
            # `valid_indices`, so these give the nearest valid neighbours strictly
            # before/after it (0 or n_valid + 1 when no such neighbour exists).
            before_idx = searchsortedlast(valid_indices, nan_idx)
            after_idx = searchsortedfirst(valid_indices, nan_idx)

            if before_idx >= 1 && after_idx <= n_valid
                before_pos = valid_indices[before_idx]
                after_pos = valid_indices[after_idx]
                weight = (nan_idx - before_pos) / (after_pos - before_pos)
                data_clean[nan_idx, col] = (1 - weight) * column[before_pos] +
                                           weight * column[after_pos]
            elseif before_idx >= 1
                # Trailing gap: carry the last valid value forward.
                data_clean[nan_idx, col] = column[valid_indices[before_idx]]
            else
                # Leading gap: carry the first valid value backward.
                data_clean[nan_idx, col] = column[valid_indices[after_idx]]
            end
        end
    end

    return data_clean
end

interpolate_missing (generic function with 1 method)

In [55]:
"""
    main()

Load the training data with `load_data`, print its statistics with `check_stats`,
clean it with `clean_data`, and print min/max/mean of the cleaned stator (column 1)
and rotor (column 2) temperatures.

Returns `(U_clean, X_clean, measurement_series, dt)`.
"""
function main()
    U, X, measurement_series, dt = load_data()

    println("Temperature statistics before cleaning:")
    check_stats(U, X)

    U_clean, X_clean = clean_data(U, X)

    println("Temperature statistics after cleaning:")
    stator = X_clean[:, 1]
    println("Stator temp - Min: $(minimum(stator)), Max: $(maximum(stator)), Mean: $(round(mean(stator), digits=2))")
    rotor = X_clean[:, 2]
    println("Rotor temp - Min: $(minimum(rotor)), Max: $(maximum(rotor)), Mean: $(round(mean(rotor), digits=2))")

    return U_clean, X_clean, measurement_series, dt
end

main (generic function with 1 method)

In [56]:
# Run the load -> statistics -> clean pipeline and bind its results in global scope.
U_clean, X_clean, measurement_series, dt = main()

Loading data...
Data loaded: (594885, 21) inputs, (594885, 2) states
Temperature statistics before cleaning:
=== DATA STATISTICS ===
Dataset size: 594885 samples

--- TEMPERATURE STATISTICS ---
Stator temp - Min: 32.27°C, Max: 187.66°C, Mean: 90.88°C
Rotor temp  - Min: 32.56°C, Max: 133.19°C, Mean: 80.53°C

--- KEY INPUT STATISTICS ---
RPM         - Min: 191.0, Max: 10513.0, Mean: 3513.0
Torque      - Min: -512.28, Max: 513.99, Mean: 0.06
u_d voltage - Min: -404.83V, Max: 395.32V, Mean: -8.41V
u_q voltage - Min: -56.99V, Max: 434.18V, Mean: 142.29V
DC link V   - Min: 396.04V, Max: 755.49V, Mean: 584.44V
i_d current - Min: -594.03A, Max: 1.5A, Mean: -74.44A
i_q current - Min: -422.53A, Max: 420.83A, Mean: 1.16A

--- COOLING SYSTEM STATISTICS ---
Rotor oil flow   - Min: 0.83, Max: 5.9, Mean: 2.8
Stator oil flow  - Min: 4.93, Max: 5.12, Mean: 5.0
Oil temp entry (rotor)  - Min: 29.39°C, Max: 84.23°C
Oil temp entry (stator) - Min: 29.39°C, Max: 82.29°C
Oil temp exit A (rotor) - Min: 27.53°C

([500.010009765625 0.38555998653173446 … 59.459999084472656 61.18000030517578; 500.0199890136719 0.3923849925398827 … 59.47999954223633 61.15999984741211; … ; 500.0199890136719 0.37611001431941987 … 76.72000122070312 79.12999725341797; 499.989990234375 0.3973199978470803 … 76.72000122070312 79.12999725341797], [65.40201568603516 67.427001953125; 65.37940724690755 67.42171986897786; … ; 81.68876139322917 86.22265625; 81.68314361572266 86.21672821044922], [1; 1; … ; 10; 10;;], 0.5)

In [52]:
"""
    check_stats(U, X; show_all_features=true)

Print summary statistics and simple sanity checks for input and state variables.

# Arguments
- `U`: input matrix, one row per sample. The column indices used here follow the
  21-feature layout listed in `feature_names` below. Extra columns are reported
  as `feature_<i>`.
- `X`: state matrix; column 1 is stator temperature, column 2 is rotor temperature.

# Keywords
- `show_all_features=true`: also print min/max/mean for every column of `U`.

Returns `nothing`; all output goes to `stdout`.
"""
function check_stats(U, X; show_all_features=true)
    println("=== DATA STATISTICS ===")
    println("Dataset size: $(size(U, 1)) samples")

    # Temperature statistics (always shown)
    println("\n--- TEMPERATURE STATISTICS ---")
    println("Stator temp - Min: $(round(minimum(X[:, 1]), digits=2))°C, Max: $(round(maximum(X[:, 1]), digits=2))°C, Mean: $(round(mean(X[:, 1]), digits=2))°C")
    println("Rotor temp  - Min: $(round(minimum(X[:, 2]), digits=2))°C, Max: $(round(maximum(X[:, 2]), digits=2))°C, Mean: $(round(mean(X[:, 2]), digits=2))°C")

    # Key input features (always shown)
    println("\n--- KEY INPUT STATISTICS ---")
    println("RPM         - Min: $(round(minimum(U[:, 1]), digits=0)), Max: $(round(maximum(U[:, 1]), digits=0)), Mean: $(round(mean(U[:, 1]), digits=0))")
    println("Torque      - Min: $(round(minimum(U[:, 2]), digits=2)), Max: $(round(maximum(U[:, 2]), digits=2)), Mean: $(round(mean(U[:, 2]), digits=2))")
    println("u_d voltage - Min: $(round(minimum(U[:, 8]), digits=2))V, Max: $(round(maximum(U[:, 8]), digits=2))V, Mean: $(round(mean(U[:, 8]), digits=2))V")
    println("u_q voltage - Min: $(round(minimum(U[:, 10]), digits=2))V, Max: $(round(maximum(U[:, 10]), digits=2))V, Mean: $(round(mean(U[:, 10]), digits=2))V")
    println("DC link V   - Min: $(round(minimum(U[:, 13]), digits=2))V, Max: $(round(maximum(U[:, 13]), digits=2))V, Mean: $(round(mean(U[:, 13]), digits=2))V")
    println("i_d current - Min: $(round(minimum(U[:, 4]), digits=2))A, Max: $(round(maximum(U[:, 4]), digits=2))A, Mean: $(round(mean(U[:, 4]), digits=2))A")
    println("i_q current - Min: $(round(minimum(U[:, 5]), digits=2))A, Max: $(round(maximum(U[:, 5]), digits=2))A, Mean: $(round(mean(U[:, 5]), digits=2))A")

    # Oil system statistics
    println("\n--- COOLING SYSTEM STATISTICS ---")
    println("Rotor oil flow   - Min: $(round(minimum(U[:, 16]), digits=2)), Max: $(round(maximum(U[:, 16]), digits=2)), Mean: $(round(mean(U[:, 16]), digits=2))")
    println("Stator oil flow  - Min: $(round(minimum(U[:, 17]), digits=2)), Max: $(round(maximum(U[:, 17]), digits=2)), Mean: $(round(mean(U[:, 17]), digits=2))")
    println("Oil temp entry (rotor)  - Min: $(round(minimum(U[:, 18]), digits=2))°C, Max: $(round(maximum(U[:, 18]), digits=2))°C")
    println("Oil temp entry (stator) - Min: $(round(minimum(U[:, 19]), digits=2))°C, Max: $(round(maximum(U[:, 19]), digits=2))°C")
    println("Oil temp exit A (rotor) - Min: $(round(minimum(U[:, 20]), digits=2))°C, Max: $(round(maximum(U[:, 20]), digits=2))°C")
    println("Oil temp exit B (rotor) - Min: $(round(minimum(U[:, 21]), digits=2))°C, Max: $(round(maximum(U[:, 21]), digits=2))°C")

    # Data quality checks
    println("\n--- DATA QUALITY CHECKS ---")
    nan_count_U = sum(any(isnan.(U), dims=2))
    nan_count_X = sum(any(isnan.(X), dims=2))
    println("Rows with NaN in U: $nan_count_U")
    println("Rows with NaN in X: $nan_count_X")

    # Check for potential issues
    println("\n--- POTENTIAL ISSUES ---")
    if maximum(U[:, 1]) > 15000
        println("⚠️  RPM exceeds typical automotive range (>15,000)")
    end
    if maximum(X[:, 1]) > 150 || maximum(X[:, 2]) > 150
        println("⚠️  High temperature detected (>150°C)")
    end
    if minimum(X[:, 1]) < 0 || minimum(X[:, 2]) < 0
        println("⚠️  Negative temperature detected")
    end
    # Use the largest magnitude, so large negative voltages are flagged too.
    if maximum(abs, U[:, 8]) > 500 || maximum(abs, U[:, 10]) > 500
        println("⚠️  Voltage exceeds typical automotive range (>±500V)")
    end

    # Correlations
    println("\n--- KEY CORRELATIONS ---")
    println("RPM-Torque correlation: $(round(cor(U[:, 1], U[:, 2]), digits=3))")
    println("Stator-Rotor temp correlation: $(round(cor(X[:, 1], X[:, 2]), digits=3))")
    println("RPM-Power correlation: $(round(cor(U[:, 1], U[:, 15]), digits=3))")

    # All features (optional)
    if show_all_features
        println("\n--- ALL INPUT FEATURES ---")
        feature_names = ["RPM", "Torque", "|Torque|", "i_d", "i_q", "|i_q|", "RMS_current",
                        "u_d", "|u_d|", "u_q", "neutral_V", "switch_freq", "DC_link_V",
                        "mod_index", "power", "oil_flow_rotor", "oil_flow_stator",
                        "oil_temp_entry_rotor", "oil_temp_entry_stator",
                        "oil_temp_exit_A", "oil_temp_exit_B"]

        for i in 1:size(U, 2)
            # Fall back to a generic label instead of a BoundsError for extra columns.
            name = get(feature_names, i, "feature_$i")
            println("$(name) - Min: $(round(minimum(U[:, i]), digits=2)), Max: $(round(maximum(U[:, i]), digits=2)), Mean: $(round(mean(U[:, i]), digits=2))")
        end
    end

    println("\n=== END STATISTICS ===")
    return nothing
end

# Convenience function for before/after cleaning comparison
"""
    check_stats_comparison(U_before, X_before, U_after, X_after)

Print the `check_stats` reports for the data before and after cleaning. Then
summarise the change in row count and the absolute change in mean stator
(column 1) and rotor (column 2) temperature.

Returns `nothing`.
"""
function check_stats_comparison(U_before, X_before, U_after, X_after)
    reports = (
        ("BEFORE CLEANING:", "================", U_before, X_before),
        ("\n\nAFTER CLEANING:", "===============", U_after, X_after),
    )
    for (title, underline, inputs, states) in reports
        println(title)
        println(underline)
        check_stats(inputs, states)
    end

    n_before = size(U_before, 1)
    n_after = size(U_after, 1)
    println("\n\nCHANGES SUMMARY:")
    println("================")
    println("Data points: $(n_before) → $(n_after) ($(n_before - n_after) removed)")

    # Magnitude only: the direction of the shift is not reported.
    mean_shift(col) = abs(mean(X_after[:, col]) - mean(X_before[:, col]))
    stator_shift = mean_shift(1)
    rotor_shift = mean_shift(2)

    println("Mean temperature change:")
    println("  Stator: $(round(stator_shift, digits=3))°C")
    println("  Rotor:  $(round(rotor_shift, digits=3))°C")
    return nothing
end

check_stats_comparison (generic function with 1 method)