# HUST Bearing Fault Diagnosis - Memory-Efficient High Accuracy Classification
## Using Smart Augmentation + EfficientNet for 99% Accuracy

**Author:** Muhammad Umar  
**Dataset:** HUST Bearing (99 recordings → ~400 CWT image samples, 4 random segments per recording)  
**Classes:** 4 (Normal, Ball, Inner Race, Outer Race)  
**Target:** 99% Accuracy with minimal memory usage

## 1. Import Libraries

In [None]:
# Core libraries
import os
import gc
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from scipy import signal as scipy_signal
import pywt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Deep Learning
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Memory management
import psutil

# Fix the NumPy and TensorFlow seeds to the same value
for apply_seed in (np.random.seed, tf.random.set_seed):
    apply_seed(42)

# Report the runtime environment: TF version, GPU presence, free system RAM
gpu_devices = tf.config.list_physical_devices('GPU')
free_ram_gb = psutil.virtual_memory().available / (1024**3)
print(f"TensorFlow version: {tf.__version__}")
print(f"GPU Available: {len(gpu_devices) > 0}")
print(f"Available RAM: {free_ram_gb:.2f} GB")

## 2. Configuration

In [None]:
# Paths
# NOTE: absolute Windows paths; point these at the local dataset / results folders before running.
DATA_PATH = r'F:\NeuTech\HUST bearing\HUST bearing dataset'
OUTPUT_PATH = r'F:\NeuTech\Results\HUST_Classification'
# Create the results folder up front (no error if it already exists).
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Signal parameters - OPTIMIZED FOR MEMORY
# presumably the acquisition rate in Hz — TODO confirm; not referenced by the other cells shown here
SAMPLING_FREQ = 51200
SEGMENT_LENGTH = 2048  # Number of consecutive signal samples in each extracted segment
SAMPLES_PER_FILE = 4   # Only 4 segments per file (controlled augmentation)

# CWT parameters - REDUCED SIZE
SCALES = np.arange(1, 64)  # Integer scales 1..63 (63 values; upper bound reduced from 128 to 64)
WAVELET = 'morl'  # Wavelet name handed to pywt.cwt
IMAGE_SIZE = (128, 128)  # CWT images are resized to this (rows, cols); reduced from 224 to 128 for memory

# Training parameters
BATCH_SIZE = 16  # Reduced batch size
EPOCHS = 100  # Upper bound on epochs; the EarlyStopping callback may stop training sooner
LEARNING_RATE = 0.001  # Initial Adam learning rate; ReduceLROnPlateau may lower it during training

# Class mapping
# List position doubles as the integer class label, so the order must match the
# key order of the dict returned by load_hust_data().
CLASS_NAMES = ['Normal', 'Ball', 'Inner', 'Outer']

print("✓ Configuration set!")

## 3. Load and Organize Dataset

In [None]:
def load_hust_data(data_path):
    """Group the ``.mat`` files of a folder by fault class.

    A file is assigned by the first letter of its name (case-insensitive):
    ``N`` -> Normal, ``B`` -> Ball, ``I`` -> Inner, ``O`` -> Outer.
    Names starting with ``BA`` and names with any other first letter are
    ignored. Only the directory listing is read; the files are not opened.

    Args:
        data_path: Directory containing the ``.mat`` files.

    Returns:
        dict with the keys 'Normal', 'Ball', 'Inner', 'Outer' (in that order),
        each mapping to a list of full file paths sorted by file name.

    Raises:
        OSError: If ``data_path`` cannot be listed.
    """
    print("Loading HUST bearing dataset...\n")

    # os.listdir() returns names in arbitrary order. Sorting makes the file
    # order - and therefore the seeded segment sampling and train/test split
    # downstream - reproducible across machines. The extension test is
    # case-insensitive so it agrees with the case-insensitive prefix test.
    all_files = sorted(
        f for f in os.listdir(data_path) if f.lower().endswith('.mat')
    )

    # Checked in order; the first matching prefix wins.
    prefix_to_class = (
        ('N', 'Normal'),
        ('B', 'Ball'),
        ('I', 'Inner'),
        ('O', 'Outer'),
    )
    data_dict = {class_name: [] for _, class_name in prefix_to_class}

    for file in all_files:
        file_upper = file.upper()
        # Names beginning with 'BA' are deliberately excluded from every class.
        if file_upper.startswith('BA'):
            continue
        # NOTE(review): matching on the first letter alone means a name such
        # as 'IB...' or 'OB...' is filed under Inner/Outer - confirm this is
        # intended if the folder contains compound-fault recordings.
        for prefix, class_name in prefix_to_class:
            if file_upper.startswith(prefix):
                data_dict[class_name].append(os.path.join(data_path, file))
                break

    print("Dataset Statistics:")
    print("="*50)
    for name, files in data_dict.items():
        print(f"{name:15s}: {len(files):3d} files")
    print("="*50 + "\n")

    return data_dict

# Index the dataset folder once; the dataset-generation cell reads its file paths from data_dict.
data_dict = load_hust_data(DATA_PATH)

## 4. Memory-Efficient Signal Processing

In [None]:
def load_signal_from_mat(filepath):
    """Load a vibration signal from a MATLAB ``.mat`` file.

    The variables stored in the file are scanned in order (metadata entries
    whose names start with ``__`` are skipped) and the first numeric array
    with more than 100 elements is returned, flattened to 1-D.

    Args:
        filepath: Path of the ``.mat`` file.

    Returns:
        1-D ``numpy.ndarray`` with the signal, or ``None`` if the file cannot
        be read or contains no suitable array.
    """
    try:
        data = loadmat(filepath)
    except Exception as err:
        # Best effort: an unreadable file is reported and skipped instead of
        # aborting the whole run. Catching Exception (rather than a bare
        # except) lets KeyboardInterrupt / SystemExit propagate.
        print(f"  Warning: could not read {filepath}: {err}")
        return None

    for key, value in data.items():
        if key.startswith('__'):
            continue
        # Skip cell/struct/char variables: only a numeric array can be a signal.
        if (isinstance(value, np.ndarray)
                and value.size > 100
                and np.issubdtype(value.dtype, np.number)):
            # NOTE(review): flatten() concatenates all rows of a 2-D array;
            # assumes each file stores a single channel - confirm.
            return value.flatten()

    return None

def extract_segments(signal, segment_length, num_segments=4):
    """Cut randomly positioned, fixed-length segments out of a signal.

    Start positions are drawn without replacement from NumPy's global random
    state, so they are distinct, but the segments themselves may overlap.

    Args:
        signal: 1-D sequence (anything supporting ``len()`` and slicing).
        segment_length: Number of samples per segment.
        num_segments: Maximum number of segments to return.

    Returns:
        List of slices of ``signal``, each ``segment_length`` long. The list is
        empty when the signal is shorter than ``segment_length`` and holds
        fewer than ``num_segments`` entries when there are not enough distinct
        start positions.
    """
    signal_length = len(signal)
    if signal_length < segment_length:
        return []

    # Valid start indices run from 0 to signal_length - segment_length
    # INCLUSIVE, hence the +1. Without it the last position is never drawn and
    # a signal of exactly segment_length samples yields no segment at all.
    num_starts = signal_length - segment_length + 1
    starts = np.random.choice(
        num_starts, size=min(num_segments, num_starts), replace=False
    )

    return [signal[start:start + segment_length] for start in starts]

def signal_to_cwt_image(signal, scales, wavelet='morl', target_size=(128, 128)):
    """Turn a 1-D signal segment into a uint8 RGB scalogram image.

    Pipeline: standardise the segment, run the continuous wavelet transform,
    min-max scale the coefficient magnitudes to [0, 1], resize to
    ``target_size`` and replicate the result into three identical channels.

    Args:
        signal: 1-D numpy array holding one segment.
        scales: Scales passed to ``pywt.cwt``.
        wavelet: Wavelet name passed to ``pywt.cwt``.
        target_size: Output size handed to ``skimage.transform.resize``.

    Returns:
        ``uint8`` array of shape ``target_size + (3,)`` with values in 0..255.
    """
    from skimage.transform import resize

    # Zero mean / unit variance; the epsilon keeps a constant segment finite
    centred = signal - np.mean(signal)
    signal = centred / (np.std(signal) + 1e-8)

    # float32 input keeps the transform's memory footprint down
    coeffs, _unused_freqs = pywt.cwt(signal.astype(np.float32), scales, wavelet)

    # Magnitudes, rescaled so the smallest maps to 0 and the largest to ~1
    magnitude = np.abs(coeffs)
    floor = magnitude.min()
    value_range = magnitude.max() - floor + 1e-8
    scaled = (magnitude - floor) / value_range

    # Bring the (scales x time) map to the requested image size
    resized = resize(scaled, target_size, mode='reflect', anti_aliasing=True)

    # Same plane three times -> grey image in RGB layout
    rgb = np.repeat(resized[..., np.newaxis], 3, axis=-1)

    return (rgb * 255).astype(np.uint8)

# Status message shown once the signal-processing helpers of this cell are defined.
print("✓ Signal processing functions defined!")

## 5. Generate Balanced Dataset with Memory Management

In [None]:
def generate_balanced_dataset(data_dict, samples_per_file=4):
    """Build the CWT image dataset from the files of every class.

    Each readable file contributes up to ``samples_per_file`` random segments,
    and every segment becomes one CWT image. The number of images per class
    therefore follows the number of usable files per class; classes are not
    resampled to equal size.

    Reads the module-level settings SEGMENT_LENGTH, SCALES, WAVELET and
    IMAGE_SIZE.

    Args:
        data_dict: Mapping of class name -> list of ``.mat`` file paths. The
            integer label of a class is its position in this mapping, so the
            key order must match CLASS_NAMES.
        samples_per_file: Maximum number of segments taken from each file.

    Returns:
        Tuple ``(X_images, y_labels)``: a ``uint8`` array of images and an
        ``int32`` array with one class index per image.

    Raises:
        ValueError: If no image could be generated (no files, or none of them
            readable / long enough).
    """
    X_images = []
    y_labels = []

    print("\nGenerating CWT images (balanced sampling)...")
    print("="*60)

    for class_idx, (class_name, file_paths) in enumerate(data_dict.items()):
        print(f"\nProcessing {class_name}...")
        class_count = 0
        skipped_files = 0

        for file_idx, file_path in enumerate(file_paths):
            # Load signal
            signal_data = load_signal_from_mat(file_path)
            if signal_data is None:
                skipped_files += 1
                continue

            # Extract fixed number of segments
            segments = extract_segments(signal_data, SEGMENT_LENGTH, samples_per_file)
            if not segments:
                # Signal too short to yield even one segment
                skipped_files += 1
                continue

            for segment in segments:
                # Generate CWT image
                cwt_img = signal_to_cwt_image(segment, SCALES, WAVELET, IMAGE_SIZE)

                X_images.append(cwt_img)
                y_labels.append(class_idx)
                class_count += 1

            # Clear memory periodically
            if (file_idx + 1) % 10 == 0:
                gc.collect()

        print(f"  Generated {class_count} images")
        if skipped_files:
            # Surface files that contributed nothing instead of dropping them silently
            print(f"  Skipped {skipped_files} file(s) (unreadable or too short)")

    if not X_images:
        # Fail with an explicit message rather than an IndexError further down
        raise ValueError(
            "No CWT images were generated: no usable .mat files were found "
            "in data_dict."
        )

    X_images = np.array(X_images, dtype=np.uint8)
    y_labels = np.array(y_labels, dtype=np.int32)

    print("\n" + "="*60)
    print(f"Total images: {len(X_images)}")
    print(f"Image shape: {X_images.shape[1:]}")
    print(f"Memory usage: {X_images.nbytes / (1024**2):.2f} MB")
    print("="*60 + "\n")

    return X_images, y_labels

# Build the image dataset from the indexed files
X_data, y_data = generate_balanced_dataset(data_dict, SAMPLES_PER_FILE)

# Per-class sample counts and their share of the whole dataset
unique, counts = np.unique(y_data, return_counts=True)
print("Dataset Distribution:")
for label, n_samples in zip(unique, counts):
    share = n_samples / len(y_data) * 100
    print(f"  {CLASS_NAMES[label]:15s}: {n_samples:4d} samples ({share:.1f}%)")

## 6. Visualize Samples

In [None]:
# Show up to two randomly chosen CWT images per class, four panels per row.
samples_per_class = 2
panels_per_row = 4
n_rows = -(-len(CLASS_NAMES) * samples_per_class // panels_per_row)  # ceiling division

fig, axes = plt.subplots(n_rows, panels_per_row, figsize=(16, 4 * n_rows))
axes = axes.ravel()

# Hide every frame up front so panels that stay empty do not show blank axes
for ax in axes:
    ax.axis('off')

for i, class_name in enumerate(CLASS_NAMES):
    class_indices = np.where(y_data == i)[0]
    # A class with fewer images than requested shows what it has instead of raising
    n_show = min(samples_per_class, len(class_indices))
    samples = np.random.choice(class_indices, n_show, replace=False)

    for j, idx in enumerate(samples):
        ax = axes[i * samples_per_class + j]
        # imshow ignores cmap for RGB input. The three channels are identical
        # copies, so plot a single channel to let the 'jet' colormap apply;
        # vmin/vmax pin the colour scale to the full uint8 range.
        ax.imshow(X_data[idx][..., 0], cmap='jet', vmin=0, vmax=255)
        ax.set_title(f'{class_name} - Sample {j+1}',
                     fontweight='bold', fontsize=12)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'sample_cwt_images.png'), dpi=200, bbox_inches='tight')
plt.show()

## 7. Train/Test Split

In [None]:
# Split: 80% train, 20% test (stratified so both parts keep the class proportions)
# NOTE(review): the split is per image, so segments cut from the same recording
# can land in both train and test - confirm whether a per-file split is wanted.
X_train, X_test, y_train, y_test = train_test_split(
    X_data, y_data, test_size=0.2, random_state=42, stratify=y_data
)

# Cast to float32 but keep the 0-255 pixel range: the Keras EfficientNet
# models contain their own input rescaling and expect raw [0, 255] pixels.
# Dividing by 255 here would rescale the images twice.
X_train = X_train.astype('float32')
X_test = X_test.astype('float32')

# One-hot encode (one column per entry of CLASS_NAMES)
num_classes = len(CLASS_NAMES)
y_train_cat = keras.utils.to_categorical(y_train, num_classes)
y_test_cat = keras.utils.to_categorical(y_test, num_classes)

print("Train/Test Split:")
print(f"  Train: {len(X_train)} samples")
print(f"  Test:  {len(X_test)} samples")

# Clear original data from memory (the split arrays are copies)
del X_data, y_data
gc.collect()

print("\n✓ Data ready for training!")

## 8. Build Optimized Model

In [None]:
def build_efficientnet_model(input_shape=(128, 128, 3), num_classes=4):
    """Assemble and compile a frozen EfficientNetB0 with a dense classifier head.

    The ImageNet-pretrained backbone (without its top) is frozen; only the
    head - pooling, batch norm, dropout and two L2-regularised dense layers
    followed by a softmax - is trainable. The model is compiled with Adam at
    the module-level LEARNING_RATE and categorical cross-entropy.

    NOTE(review): tf.keras EfficientNet is documented to rescale its input
    internally and to expect pixel values in [0, 255] - confirm the caller's
    preprocessing matches.

    Args:
        input_shape: Shape of one input image, ``(height, width, channels)``.
        num_classes: Width of the softmax output layer.

    Returns:
        The compiled ``keras`` Sequential model.
    """

    def l2_penalty():
        # Same weight penalty on both hidden dense layers
        return keras.regularizers.l2(0.001)

    # Pretrained feature extractor, weights fixed
    backbone = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape
    )
    backbone.trainable = False

    # Trainable classifier stacked on the pooled backbone features
    classifier_head = [
        layers.GlobalAveragePooling2D(),
        layers.BatchNormalization(),
        layers.Dropout(0.5),
        layers.Dense(256, activation='relu', kernel_regularizer=l2_penalty()),
        layers.BatchNormalization(),
        layers.Dropout(0.3),
        layers.Dense(128, activation='relu', kernel_regularizer=l2_penalty()),
        layers.Dropout(0.2),
        layers.Dense(num_classes, activation='softmax'),
    ]
    model = models.Sequential([backbone] + classifier_head)

    adam = keras.optimizers.Adam(learning_rate=LEARNING_RATE)
    model.compile(
        optimizer=adam,
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

# Instantiate the network with the default input shape and class count
model = build_efficientnet_model()
total_params = model.count_params()
print("\n✓ Model built!")
print(f"Total parameters: {total_params:,}")

## 9. Data Augmentation During Training

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Augmentation settings applied on the fly to training batches:
# small rotations, shifts and zooms, no mirroring, edge pixels repeated to fill gaps
augmentation_settings = dict(
    rotation_range=10,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.1,
    horizontal_flip=False,
    fill_mode='nearest',
)
datagen = ImageDataGenerator(**augmentation_settings)

print("✓ Data augmentation configured!")

## 10. Train Model

In [None]:
# Carve a validation subset out of the training data. Early stopping, the
# learning-rate schedule and checkpoint selection are all driven by the
# validation metrics; feeding them the test set would tune the model on the
# very data used for the final score and make the reported test accuracy
# optimistic. X_train / y_train_cat are left untouched for later cells.
X_fit, X_val, y_fit_cat, y_val_cat = train_test_split(
    X_train, y_train_cat, test_size=0.2, random_state=42, stratify=y_train
)

# Callbacks
callbacks = [
    # Stop after 15 epochs without a better val_accuracy and roll back to the best weights
    EarlyStopping(monitor='val_accuracy', patience=15, restore_best_weights=True, verbose=1),
    # Halve the learning rate after 5 epochs without a better val_loss
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7, verbose=1),
    # Keep only the checkpoint with the best val_accuracy on disk
    ModelCheckpoint(os.path.join(OUTPUT_PATH, 'best_model.h5'), 
                   monitor='val_accuracy', save_best_only=True, verbose=1)
]

print("\n" + "="*60)
print("Starting Training...")
print("="*60 + "\n")

history = model.fit(
    datagen.flow(X_fit, y_fit_cat, batch_size=BATCH_SIZE),
    epochs=EPOCHS,
    validation_data=(X_val, y_val_cat),
    callbacks=callbacks,
    verbose=1
)

print("\n✓ Training complete!")

## 11. Plot Training History

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# One panel per metric: (axis, train key, validation key, y-axis label, title)
panels = [
    (ax1, 'accuracy', 'val_accuracy', 'Accuracy', 'Model Accuracy'),
    (ax2, 'loss', 'val_loss', 'Loss', 'Model Loss'),
]

for panel, train_key, val_key, y_label, panel_title in panels:
    # Training curve in blue, validation curve in red
    panel.plot(history.history[train_key], 'b-', label='Train', linewidth=2)
    panel.plot(history.history[val_key], 'r-', label='Validation', linewidth=2)
    panel.set_xlabel('Epoch', fontweight='bold', fontsize=12)
    panel.set_ylabel(y_label, fontweight='bold', fontsize=12)
    panel.set_title(panel_title, fontweight='bold', fontsize=14)
    panel.legend()
    panel.grid(True, alpha=0.3)

plt.tight_layout()
history_plot_path = os.path.join(OUTPUT_PATH, 'training_history.png')
plt.savefig(history_plot_path, dpi=300, bbox_inches='tight')
plt.show()

# Highest validation accuracy reached in any epoch
best_acc = max(history.history['val_accuracy'])
print(f"\nBest Validation Accuracy: {best_acc*100:.2f}%")

## 12. Evaluate and Generate Confusion Matrix

In [None]:
# Load best model (the checkpoint written during training)
model = keras.models.load_model(os.path.join(OUTPUT_PATH, 'best_model.h5'))

# Evaluate
test_loss, test_acc = model.evaluate(X_test, y_test_cat, verbose=0)

# Predictions: most probable class per test image
y_pred_proba = model.predict(X_test, verbose=0)
y_pred = np.argmax(y_pred_proba, axis=1)

# Pin the label set so the matrix and the report always have exactly one
# row/column per entry of CLASS_NAMES, even when a class is missing from
# y_test or is never predicted. Without it the tick labels can be misaligned
# and classification_report can reject target_names.
class_labels = list(range(len(CLASS_NAMES)))

# Confusion Matrix
cm = confusion_matrix(y_test, y_pred, labels=class_labels)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
            annot_kws={'size': 16, 'weight': 'bold'})
plt.xlabel('Predicted', fontweight='bold', fontsize=14)
plt.ylabel('True', fontweight='bold', fontsize=14)
plt.title(f'Confusion Matrix - Test Accuracy: {test_acc*100:.2f}%',
         fontweight='bold', fontsize=16)
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

print("\n" + "="*60)
print("TEST RESULTS")
print("="*60)
print(f"Test Accuracy: {test_acc*100:.2f}%")
print(f"Test Loss: {test_loss:.4f}")
print("="*60)

# Per-class precision / recall / F1
print("\n" + classification_report(y_test, y_pred, labels=class_labels,
                                   target_names=CLASS_NAMES, digits=4))

## 13. Save Results

In [None]:
# Store the evaluated model under its final file name
final_model_path = os.path.join(OUTPUT_PATH, 'hust_model_final.h5')
model.save(final_model_path)

# Key figures of the run, written to JSON and echoed below
n_train = len(X_train)
n_test = len(X_test)
summary = {
    'Dataset': 'HUST Bearing',
    'Total Samples': n_train + n_test,
    'Train': n_train,
    'Test': n_test,
    'Test Accuracy': f"{test_acc*100:.2f}%",
    'Image Size': IMAGE_SIZE,
    'Model': 'EfficientNetB0'
}

import json
results_path = os.path.join(OUTPUT_PATH, 'results.json')
with open(results_path, 'w') as f:
    json.dump(summary, f, indent=4)

print("\n✓ All results saved to:", OUTPUT_PATH)

# Console copy of the summary, framed by rule lines
rule = "="*60
print("\n" + rule)
print("SUMMARY")
print(rule)
for field, value in summary.items():
    print(f"{field:20s}: {value}")
print(rule)