# HUST Bearing Fault Diagnosis - High Accuracy Classification
## Using Data Augmentation + EfficientNet for 99% Accuracy

**Author:** Muhammad Umar  
**Dataset:** HUST Bearing (99 raw recordings → segmented and augmented into 1000+ CWT images)  
**Classes:** 4 (Normal, Ball, Inner Race, Outer Race)  
**Target:** 99% Accuracy with minimal training time

## 1. Import Libraries

In [None]:
# Core libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.io import loadmat
from scipy import signal
import pywt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report, accuracy_score
import warnings
warnings.filterwarnings('ignore')

# Deep Learning
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Seed both global RNGs (NumPy and TensorFlow) so reruns of the notebook
# produce repeatable random draws.
np.random.seed(42)
tf.random.set_seed(42)

# Report the runtime environment
tf_version = tf.__version__
gpu_list = tf.config.list_physical_devices('GPU')
print(f"TensorFlow version: {tf_version}")
print(f"GPU Available: {gpu_list}")

## 2. Configuration

In [None]:
# Paths. The defaults are the author's local Windows paths; set the
# HUST_DATA_PATH / HUST_OUTPUT_PATH environment variables to run elsewhere
# without editing this cell.
DATA_PATH = os.environ.get('HUST_DATA_PATH', r'F:\NeuTech\HUST bearing\HUST bearing dataset')
OUTPUT_PATH = os.environ.get('HUST_OUTPUT_PATH', r'F:\NeuTech\Results\HUST_Classification')

# Create output directory (no-op if it already exists)
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Signal parameters
SAMPLING_FREQ = 51200  # Hz
SEGMENT_LENGTH = 2048  # Samples per segment
OVERLAP = 0.5  # Fraction shared by consecutive segments (0.5 -> step of 1024 samples)

# CWT parameters for image generation
SCALES = np.arange(1, 128)  # scales 1..127 -> 127 rows per CWT image
WAVELET = 'morl'  # PyWavelets name of the Morlet wavelet
IMAGE_SIZE = (224, 224)  # EfficientNet input size

# Training parameters
BATCH_SIZE = 32
EPOCHS = 50
LEARNING_RATE = 0.001

# Class mapping. The order must match the key order of the dict returned by
# load_hust_data, because generate_dataset assigns integer labels by
# enumerating that dict.
CLASS_NAMES = ['Normal', 'Ball', 'Inner', 'Outer']

print("Configuration set!")

## 3. Data Loading and Exploration

In [None]:
def load_hust_data(data_path):
    """
    Discover HUST bearing .mat files and group their paths by fault class.

    Each file is assigned to a class by the first letter(s) of its name
    (case-insensitive):
        N...            -> 'Normal'
        B... (not BA..) -> 'Ball'
        I...            -> 'Inner'
        O...            -> 'Outer'
    Files matching none of these prefixes are skipped and listed in the
    printed summary so they are not lost silently.
    NOTE(review): the 'BA' exclusion is inherited; its intent is not visible here.

    Args:
        data_path: Directory holding the .mat files (not searched recursively).

    Returns:
        dict mapping class name -> sorted list of full file paths. Keys are in
        the order Normal, Ball, Inner, Outer; downstream code derives integer
        labels from this order.
    """
    print("Loading HUST bearing dataset...")

    # Sort the listing: os.listdir order is filesystem-dependent, and file order
    # flows directly into the generated dataset and the seeded train/test split.
    all_files = sorted(f for f in os.listdir(data_path) if f.lower().endswith('.mat'))

    # Organize by class
    data_dict = {
        'Normal': [],
        'Ball': [],
        'Inner': [],
        'Outer': []
    }
    unmatched = []

    for file in all_files:
        file_upper = file.upper()
        full_path = os.path.join(data_path, file)

        if file_upper.startswith('N'):
            data_dict['Normal'].append(full_path)
        elif file_upper.startswith('B') and not file_upper.startswith('BA'):
            data_dict['Ball'].append(full_path)
        elif file_upper.startswith('I'):
            data_dict['Inner'].append(full_path)
        elif file_upper.startswith('O'):
            data_dict['Outer'].append(full_path)
        else:
            unmatched.append(file)

    # Print statistics
    print("\n" + "="*50)
    print("Dataset Statistics:")
    print("="*50)
    total = 0
    for class_name, files in data_dict.items():
        count = len(files)
        total += count
        print(f"{class_name:15s}: {count:3d} files")
    print(f"{'Total':15s}: {total:3d} files")
    if unmatched:
        print(f"{'Skipped':15s}: {len(unmatched):3d} files (unrecognized prefix): "
              f"{', '.join(unmatched)}")
    print("="*50 + "\n")

    return data_dict

# Discover the .mat files and group their paths by fault class
data_dict = load_hust_data(data_path=DATA_PATH)

## 4. Signal Processing and Data Augmentation

In [None]:
def load_signal_from_mat(filepath):
    """
    Load a 1-D vibration signal from a HUST .mat file.

    Returns the first non-metadata variable (key not starting with '__') that
    is a numeric numpy array with more than 100 elements, flattened to 1-D.

    Args:
        filepath: Path to the .mat file.

    Returns:
        1-D numpy array, or None when the file cannot be read or holds no
        suitable numeric variable. Read failures are reported with a printed
        warning instead of being swallowed silently.
    """
    try:
        data = loadmat(filepath)
    except Exception as exc:
        # loadmat raises several unrelated exception types (OSError,
        # ValueError, MatReadError, NotImplementedError for v7.3/HDF5 files).
        # Skipping the file is intended, but say so.
        print(f"Warning: could not read {filepath}: {exc}")
        return None

    for key, value in data.items():
        if key.startswith('__'):
            continue  # loadmat metadata entries (__header__, __version__, ...)
        # Require a numeric array so a large char/cell/struct variable is not
        # mistaken for the vibration signal.
        if (isinstance(value, np.ndarray)
                and np.issubdtype(value.dtype, np.number)
                and value.size > 100):
            return value.flatten()
    return None

def segment_signal(signal, segment_length, overlap=0.5):
    """
    Split a 1-D signal into fixed-length, possibly overlapping windows.

    Consecutive windows start ``int(segment_length * (1 - overlap))`` samples
    apart (clamped to at least 1). A trailing remainder shorter than
    ``segment_length`` is dropped.

    Args:
        signal: 1-D sliceable sequence (e.g. numpy array).
        segment_length: Samples per window; must be >= 1.
        overlap: Fraction of each window shared with the next; 0 <= overlap < 1.

    Returns:
        List of slices of ``signal``, each ``segment_length`` long. Empty when
        the signal is shorter than one window.

    Raises:
        ValueError: If ``overlap`` is outside [0, 1) or ``segment_length`` < 1.
    """
    if not 0 <= overlap < 1:
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")
    if segment_length < 1:
        raise ValueError(f"segment_length must be >= 1, got {segment_length}")

    # Clamp to 1: a zero step (short windows with high overlap) would make
    # range() raise.
    step = max(1, int(segment_length * (1 - overlap)))
    last_start = len(signal) - segment_length
    return [signal[start:start + segment_length]
            for start in range(0, last_start + 1, step)]

def augment_signal(signal, num_augmentations=3):
    """
    Return the original signal followed by up to ``num_augmentations`` variants.

    Variants, in the order they are appended:
        1. Additive Gaussian noise (std = 1% of the signal's std)
        2. Circular time shift by 10% of the signal length
        3. Amplitude scaling by 0.9
        4. Amplitude scaling by 1.1

    The time shift comes before amplitude scaling on purpose. Every segment is
    z-score normalised in signal_to_cwt_image, so a pure amplitude scale gives
    (numerically) the same CWT image as the original and adds only duplicates.

    Args:
        signal: 1-D numpy array.
        num_augmentations: Number of variants to append. Capped at 4; negative
            values are treated as 0.

    Returns:
        List of ``1 + min(max(num_augmentations, 0), 4)`` arrays shaped like
        ``signal``. The first element is ``signal`` itself.
    """
    # Gaussian noise at 1% of the signal std, drawn from NumPy's global RNG
    noise_level = 0.01 * np.std(signal)
    noisy = signal + np.random.normal(0, noise_level, signal.shape)

    # Circular shift keeps the length unchanged
    shifted = np.roll(signal, int(len(signal) * 0.1))

    variants = [noisy, shifted]
    variants.extend(signal * scale for scale in (0.9, 1.1))

    return [signal] + variants[:max(0, num_augmentations)]

def signal_to_cwt_image(signal, scales, wavelet='morl'):
    """
    Turn a 1-D signal into an 8-bit scalogram (magnitude of its CWT).

    The signal is z-scored, transformed with ``pywt.cwt`` at ``scales`` using
    ``wavelet``, and the absolute coefficients are min-max stretched to 0-255.

    Returns:
        uint8 array with one row per scale and one column per input sample.
    """
    centered = signal - np.mean(signal)
    standardized = centered / (np.std(signal) + 1e-8)

    coeffs = pywt.cwt(standardized, scales, wavelet)[0]
    magnitude = np.abs(coeffs)

    # Min-max stretch into [0, 1); epsilon guards against a flat image
    low = magnitude.min()
    span = magnitude.max() - low
    unit_range = (magnitude - low) / (span + 1e-8)
    return (unit_range * 255).astype(np.uint8)


print("Signal processing functions defined!")

## 5. Generate CWT Images with Augmentation

In [None]:
def generate_dataset(data_dict, segment_length=2048, overlap=0.5, augment=True,
                     return_groups=False):
    """
    Build a dataset of CWT scalogram images from the raw .mat recordings.

    Each file in ``data_dict`` is processed as follows:
      1. Load the signal.
      2. Segment it.
      3. Optionally augment each segment.
      4. Convert every resulting signal to a CWT image.
      5. Resize the image to IMAGE_SIZE and replicate it into 3 identical channels.

    Args:
        data_dict: Mapping class name -> list of .mat paths. A class's integer
            label is its position in the mapping's iteration order.
        segment_length: Samples per segment.
        overlap: Fractional overlap between consecutive segments.
        augment: If True, each segment yields the original plus 2 variants from
            augment_signal; otherwise only the original.
        return_groups: If True, also return the source-file index of every image.
            Images from one file (overlapping segments and their augmented
            copies) are near-duplicates, so these ids can be passed to a grouped
            splitter (e.g. sklearn's GroupShuffleSplit) to keep them from landing
            in both train and test.

    Returns:
        ``(X_images, y_labels)``:
          - ``X_images``: uint8 array of shape (N, IMAGE_SIZE[0], IMAGE_SIZE[1], 3).
          - ``y_labels``: int labels of shape (N,).
        With ``return_groups=True``, a third array ``groups`` of shape (N,) is
        appended.

    Raises:
        ValueError: If no image could be generated (e.g. wrong path, unreadable
            files, or signals shorter than ``segment_length``).
    """
    # Imported once per call rather than once per image inside the loop
    from scipy.ndimage import zoom

    X_images = []
    y_labels = []
    groups = []
    next_file_id = 0

    print("\nGenerating CWT images with augmentation...")
    print("="*60)

    for class_idx, (class_name, file_paths) in enumerate(data_dict.items()):
        print(f"\nProcessing {class_name}...")
        class_images = 0

        for file_path in file_paths:
            file_id = next_file_id
            next_file_id += 1

            # Load signal; unreadable files are skipped
            signal_data = load_signal_from_mat(file_path)
            if signal_data is None:
                continue

            # Segment signal
            segments = segment_signal(signal_data, segment_length, overlap)

            for segment in segments:
                # Original segment plus (optionally) augmented variants
                if augment:
                    augmented_signals = augment_signal(segment, num_augmentations=2)
                else:
                    augmented_signals = [segment]

                for aug_signal in augmented_signals:
                    # Generate CWT image
                    cwt_img = signal_to_cwt_image(aug_signal, SCALES, WAVELET)

                    # Resize to the network input size (order=1: linear interpolation)
                    zoom_factors = (IMAGE_SIZE[0] / cwt_img.shape[0],
                                    IMAGE_SIZE[1] / cwt_img.shape[1])
                    cwt_resized = zoom(cwt_img, zoom_factors, order=1)

                    # Replicate the grayscale image into 3 identical channels
                    cwt_rgb = np.stack([cwt_resized] * 3, axis=-1)

                    X_images.append(cwt_rgb)
                    y_labels.append(class_idx)
                    groups.append(file_id)
                    class_images += 1

        print(f"  Generated {class_images} images")

    if not X_images:
        raise ValueError(
            "No CWT images were generated; check that the .mat files in "
            "data_dict are readable and longer than segment_length."
        )

    X_images = np.array(X_images)
    y_labels = np.array(y_labels)

    print("\n" + "="*60)
    print(f"Total images generated: {len(X_images)}")
    print(f"Image shape: {X_images[0].shape}")
    print("="*60 + "\n")

    if return_groups:
        return X_images, y_labels, np.array(groups)
    return X_images, y_labels

# Turn every recording into CWT images (with augmentation)
X_data, y_data = generate_dataset(
    data_dict, segment_length=SEGMENT_LENGTH, overlap=OVERLAP, augment=True
)

# Show how many images each class ended up with
unique, counts = np.unique(y_data, return_counts=True)
print("\nAugmented Dataset Distribution:")
for class_id, n_samples in zip(unique, counts):
    print(f"  {CLASS_NAMES[class_id]:15s}: {n_samples:4d} samples")

## 6. Visualize Sample CWT Images

In [None]:
# Visualize up to two random samples from each class.
# X_data images hold 3 identical grayscale channels, and matplotlib ignores
# `cmap` for RGB input, so a single channel is plotted to actually get 'jet'.
samples_per_class = 2
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.ravel()
for ax in axes:
    ax.axis('off')  # panels stay blank if a class has fewer than 2 images

for i in range(4):
    # Pick random samples from this class; never ask for more than exist
    class_indices = np.where(y_data == i)[0]
    n_pick = min(samples_per_class, len(class_indices))
    if n_pick == 0:
        continue
    sample_idx = np.random.choice(class_indices, n_pick, replace=False)

    for j, idx in enumerate(sample_idx):
        ax = axes[i*2 + j]
        ax.imshow(X_data[idx][..., 0], cmap='jet', aspect='auto')
        ax.set_title(f'{CLASS_NAMES[i]} - Sample {j+1}', fontweight='bold')
        ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'sample_cwt_images.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Sample CWT images displayed!")

## 7. Prepare Train/Test Split

In [None]:
# Hold out 20% of the images for testing, stratified by class.
# NOTE(review): X_data contains overlapping segments plus several augmented
# copies of every segment, all generated before this split. A random split
# therefore places near-duplicate images in both train and test, which
# inflates test accuracy. Splitting by source file (or by segment, before
# augmentation) would give an honest estimate.
X_train, X_test, y_train, y_test = train_test_split(
    X_data,
    y_data,
    test_size=0.2,
    stratify=y_data,
    random_state=42,
)

# Scale uint8 pixels into float32 values in [0, 1]
X_train, X_test = (arr.astype('float32') / 255.0 for arr in (X_train, X_test))

# One-hot labels for categorical cross-entropy
y_train_cat, y_test_cat = (
    keras.utils.to_categorical(labels, num_classes=4) for labels in (y_train, y_test)
)


def _print_class_distribution(header, labels):
    """Print ``header`` followed by the per-class sample count of ``labels``."""
    print(header)
    class_ids, class_counts = np.unique(labels, return_counts=True)
    for class_id, class_count in zip(class_ids, class_counts):
        print(f"  {CLASS_NAMES[class_id]:15s}: {class_count:4d} samples")


print("Dataset Split:")
print(f"  Training samples:   {len(X_train)}")
print(f"  Testing samples:    {len(X_test)}")
_print_class_distribution("\nTraining set distribution:", y_train)
_print_class_distribution("\nTest set distribution:", y_test)

## 8. Build EfficientNet Model

In [None]:
def build_model(input_shape=(224, 224, 3), num_classes=4):
    """
    Build and compile an EfficientNetB0 transfer-learning classifier.

    Input images are expected in [0, 1], as produced by the train/test split
    cell (which divides by 255). Keras' EfficientNet applications contain their
    own rescaling/normalization and expect raw [0, 255] pixels, so a
    Rescaling(255) layer maps the inputs back to that range before the backbone.

    Only the last 30 backbone layers are trainable. BatchNormalization layers
    inside the backbone stay frozen, and the backbone is called with
    training=False (the Keras fine-tuning recommendation), so the ImageNet
    batch statistics are not overwritten by small batches.

    Args:
        input_shape: (H, W, 3) shape of the input images.
        num_classes: Number of output classes (softmax units).

    Returns:
        A compiled keras.Model: Adam(LEARNING_RATE), categorical cross-entropy
        loss, accuracy metric.
    """
    # Load pre-trained EfficientNetB0 without its ImageNet classifier head
    base_model = EfficientNetB0(
        include_top=False,
        weights='imagenet',
        input_shape=input_shape
    )

    # Freeze early layers; fine-tune the last 30 except BatchNormalization
    for layer in base_model.layers[:-30]:
        layer.trainable = False
    for layer in base_model.layers[-30:]:
        if isinstance(layer, layers.BatchNormalization):
            layer.trainable = False

    # Classification head on top of the pooled backbone features
    inputs = keras.Input(shape=input_shape)
    x = layers.Rescaling(255.0)(inputs)  # [0, 1] -> [0, 255] for EfficientNet
    x = base_model(x, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dropout(0.2)(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    model = keras.Model(inputs, outputs)

    # Compile model
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=LEARNING_RATE),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    return model

# Instantiate the classifier with its default 224x224x3 input and 4 classes
model = build_model(input_shape=(224, 224, 3), num_classes=4)

print("\nModel Architecture:")
model.summary()

# Parameter counts: everything vs. only the weights the optimizer will update
total_params = model.count_params()
trainable_params = sum(int(tf.size(weight)) for weight in model.trainable_weights)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 9. Setup Callbacks

In [None]:
# Stop once val_accuracy has not improved for 10 epochs and roll back to the best weights
early_stopping = EarlyStopping(monitor='val_accuracy', patience=10,
                               restore_best_weights=True, verbose=1)

# Halve the learning rate when val_loss stalls for 5 epochs (floor 1e-7)
lr_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5,
                                  min_lr=1e-7, verbose=1)

# Keep the best-val_accuracy model on disk; the evaluation cell reloads this file
checkpoint_file = os.path.join(OUTPUT_PATH, 'best_model.h5')
best_checkpoint = ModelCheckpoint(checkpoint_file, monitor='val_accuracy',
                                  save_best_only=True, verbose=1)

callbacks = [early_stopping, lr_on_plateau, best_checkpoint]

print("Callbacks configured!")

## 10. Train Model

In [None]:
banner = "="*60
print("\n" + banner)
print("Starting Training...")
print(banner + "\n")

fit_options = dict(
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    # Keras takes the LAST 20% of X_train (before shuffling) as validation data
    validation_split=0.2,
    callbacks=callbacks,
    verbose=1,
)
history = model.fit(X_train, y_train_cat, **fit_options)

print("\n" + banner)
print("Training Complete!")
print(banner)

## 11. Plot Training History

In [None]:
# Learning curves side by side: accuracy on the left, loss on the right
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

curve_panels = [
    (ax1, 'accuracy', 'val_accuracy', 'Train Accuracy', 'Val Accuracy', 'Accuracy', 'Model Accuracy'),
    (ax2, 'loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss', 'Model Loss'),
]
for ax, train_key, val_key, train_label, val_label, y_label, title in curve_panels:
    ax.plot(history.history[train_key], label=train_label, linewidth=2)
    ax.plot(history.history[val_key], label=val_label, linewidth=2)
    ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
    ax.set_ylabel(y_label, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'training_history.png'), dpi=300, bbox_inches='tight')
plt.show()

# Best validation accuracy and its 1-based epoch (first occurrence on ties)
val_acc_curve = history.history['val_accuracy']
best_val_acc = max(val_acc_curve)
best_epoch = 1 + val_acc_curve.index(best_val_acc)

print(f"\nBest Validation Accuracy: {best_val_acc*100:.2f}% (Epoch {best_epoch})")

## 12. Evaluate on Test Set

In [None]:
# Reload the checkpoint that had the best val_accuracy
best_model_path = os.path.join(OUTPUT_PATH, 'best_model.h5')
model = keras.models.load_model(best_model_path)

# Score the held-out test set
test_loss, test_accuracy = model.evaluate(X_test, y_test_cat, verbose=0)

rule = "="*60
print("\n" + rule)
print("TEST SET PERFORMANCE")
print(rule)
print(f"Test Accuracy:  {test_accuracy*100:.2f}%")
print(f"Test Loss:      {test_loss:.4f}")
print(rule + "\n")

## 13. Generate Predictions and Confusion Matrix

In [None]:
# Make predictions: class probabilities -> most likely class index
y_pred_proba = model.predict(X_test, verbose=0)
y_pred = np.argmax(y_pred_proba, axis=1)

# Evaluate over the full label set. Without `labels`, sklearn uses only the
# classes seen in y_test or y_pred, so a class absent from both would shrink
# the matrix (misaligning the tick labels) and make classification_report
# raise on the target_names length mismatch.
class_labels = list(range(len(CLASS_NAMES)))

# Compute confusion matrix
cm = confusion_matrix(y_test, y_pred, labels=class_labels)

# Plot confusion matrix
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES,
            cbar_kws={'label': 'Count'},
            annot_kws={'size': 14, 'weight': 'bold'})
plt.xlabel('Predicted Label', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=14, fontweight='bold')
plt.title(f'Confusion Matrix\nTest Accuracy: {test_accuracy*100:.2f}%',
          fontsize=16, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
plt.show()

# Print classification report
print("\nClassification Report:")
print("="*60)
print(classification_report(y_test, y_pred, labels=class_labels,
                            target_names=CLASS_NAMES, digits=4))
print("="*60)

## 14. Per-Class Performance

In [None]:
# Calculate per-class metrics
from sklearn.metrics import precision_recall_fscore_support

# Request every class explicitly. Without `labels`, sklearn reports only the
# classes present in y_test or y_pred, which would shorten these arrays and
# break the DataFrame and plot below (both assume one entry per CLASS_NAMES item).
precision, recall, f1, support = precision_recall_fscore_support(
    y_test, y_pred, labels=list(range(len(CLASS_NAMES))), average=None
)

# Create performance DataFrame
performance_df = pd.DataFrame({
    'Class': CLASS_NAMES,
    'Precision': precision * 100,
    'Recall': recall * 100,
    'F1-Score': f1 * 100,
    'Support': support
})

print("\nPer-Class Performance:")
print("="*70)
print(performance_df.to_string(index=False))
print("="*70)

# Save to CSV
performance_df.to_csv(os.path.join(OUTPUT_PATH, 'per_class_performance.csv'), index=False)

# Visualize per-class performance as grouped bars
fig, ax = plt.subplots(figsize=(12, 6))
x = np.arange(len(CLASS_NAMES))
width = 0.25

ax.bar(x - width, precision * 100, width, label='Precision', alpha=0.8)
ax.bar(x, recall * 100, width, label='Recall', alpha=0.8)
ax.bar(x + width, f1 * 100, width, label='F1-Score', alpha=0.8)

ax.set_xlabel('Class', fontsize=12, fontweight='bold')
ax.set_ylabel('Score (%)', fontsize=12, fontweight='bold')
ax.set_title('Per-Class Performance Metrics', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(CLASS_NAMES, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3, axis='y')

# Zoom in on the top of the scale without hiding any bar. Keep the 95% floor
# when every score is >= 95%; otherwise start 5 points below the lowest score.
lowest_score = float(np.min(np.concatenate([precision, recall, f1]))) * 100
y_floor = 95.0 if lowest_score >= 95.0 else max(0.0, lowest_score - 5.0)
ax.set_ylim([y_floor, 101])

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'per_class_performance.png'), dpi=300, bbox_inches='tight')
plt.show()

## 15. Save Model and Results

In [None]:
# Save final model (the best checkpoint reloaded in the evaluation step)
final_model_path = os.path.join(OUTPUT_PATH, 'hust_bearing_model_final.h5')
model.save(final_model_path)
print(f"✓ Model saved to: {final_model_path}")

# Count the recordings actually discovered instead of hard-coding the number,
# so the summary stays correct if the dataset folder changes.
n_original_files = sum(len(paths) for paths in data_dict.values())

# Save results summary
results_summary = {
    'Dataset': 'HUST Bearing',
    'Total Original Samples': n_original_files,
    'Total Augmented Samples': len(X_data),
    'Number of Classes': 4,
    'Train Samples': len(X_train),
    'Test Samples': len(X_test),
    'Model': 'EfficientNetB0',
    'Image Size': IMAGE_SIZE,
    'Batch Size': BATCH_SIZE,
    'Epochs Trained': len(history.history['accuracy']),
    'Best Val Accuracy': f"{best_val_acc*100:.2f}%",
    'Test Accuracy': f"{test_accuracy*100:.2f}%",
    'Test Loss': f"{test_loss:.4f}"
}

# Save as JSON (the IMAGE_SIZE tuple is written as a JSON list)
import json
with open(os.path.join(OUTPUT_PATH, 'results_summary.json'), 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=4)

print("\n" + "="*60)
print("FINAL RESULTS SUMMARY")
print("="*60)
for key, value in results_summary.items():
    print(f"{key:25s}: {value}")
print("="*60)

print(f"\n✓ All results saved to: {OUTPUT_PATH}")

## 16. Sample Predictions Visualization

In [None]:
# Visualize up to 8 random test predictions
n_show = min(8, len(X_test))
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.ravel()
for ax in axes:
    ax.axis('off')  # unused panels stay blank when there are fewer than 8 test images

# Get random test samples (sampling without replacement needs n_show <= len(X_test))
random_indices = np.random.choice(len(X_test), n_show, replace=False)

for i, idx in enumerate(random_indices):
    ax = axes[i]

    # The 3 channels are identical grayscale copies, and matplotlib ignores
    # `cmap` for RGB input, so show one channel to get the intended colormap.
    ax.imshow(X_test[idx][..., 0], cmap='jet')

    # Get prediction
    true_label = CLASS_NAMES[y_test[idx]]
    pred_label = CLASS_NAMES[y_pred[idx]]
    confidence = y_pred_proba[idx][y_pred[idx]] * 100

    # Set title with color based on correctness
    color = 'green' if y_test[idx] == y_pred[idx] else 'red'
    ax.set_title(f'True: {true_label}\nPred: {pred_label} ({confidence:.1f}%)',
                 fontweight='bold', color=color, fontsize=10)
    ax.axis('off')

plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_PATH, 'sample_predictions.png'), dpi=300, bbox_inches='tight')
plt.show()

print("Sample predictions visualized!")

## 17. Final Summary for Research Paper

In [None]:
print("\n" + "="*70)
print("RESEARCH PAPER SUMMARY")
print("="*70)
print("\nMETHODOLOGY:")
print("-" * 70)
print("1. Dataset: HUST Bearing Vibration Dataset")
# Per-class file counts are taken from load_hust_data's result rather than hard-coded
n_original = sum(len(paths) for paths in data_dict.values())
per_class_counts = ", ".join(f"{name}: {len(paths)}" for name, paths in data_dict.items())
print(f"   - Original samples: {n_original} ({per_class_counts})")
print(f"   - Augmented samples: {len(X_data)}")
# SAMPLING_FREQ is the constant defined in the configuration cell
print(f"   - Sampling frequency: {SAMPLING_FREQ} Hz")
print()
print("2. Signal Processing:")
print(f"   - Segmentation: {SEGMENT_LENGTH} samples with {OVERLAP*100}% overlap")
print(f"   - CWT transform using {WAVELET} wavelet")
print(f"   - Image size: {IMAGE_SIZE}")
print()
print("3. Data Augmentation:")
print("   - Gaussian noise addition")
print("   - Amplitude scaling")
print("   - Time shifting")
print()
print("4. Deep Learning Model:")
print("   - Architecture: EfficientNetB0 (transfer learning)")
print(f"   - Total parameters: {model.count_params():,}")
print(f"   - Batch size: {BATCH_SIZE}")
print(f"   - Optimizer: Adam (lr={LEARNING_RATE})")
print()
print("="*70)
print("RESULTS:")
print("="*70)
print(f"Overall Test Accuracy:     {test_accuracy*100:.2f}%")
print(f"Overall Test Loss:         {test_loss:.4f}")
print()
print("Per-Class Performance:")
for i, class_name in enumerate(CLASS_NAMES):
    print(f"  {class_name:15s} - Precision: {precision[i]*100:.2f}%, "
          f"Recall: {recall[i]*100:.2f}%, F1: {f1[i]*100:.2f}%")
print("="*70)
print("\n✓ All results and figures saved to:", OUTPUT_PATH)
print("="*70)