# Bayes Classifier Training

Train and evaluate a Naive Bayes classifier with optimized feature engineering:

### Base Features (29 of the 52 total):
- **Color**: Mean & Std HSV (6 features)
- **Texture**: LBP histogram + Haralick (19 features)
- **Shape**: Aspect ratio + Hu moments (4 features)

### Material-Specific Features (7 total - optimized):
- **Specular (Metal)**: Highlight contrast, gradient concentration (2)
- **Metal**: Reflection directionality (1)
- **Glass**: Brightness gradient smoothness, high-freq FFT energy, saturation uniformity (3)
- **Trash**: Texture chaos (1)

### Additional Texture Features (16):
- **Single-scale LBP**: Radius=2 only (16 bins)

### Optimization Pipeline:
1. **Correlation Removal**: 52 → ~43 features (threshold=0.85)
2. **Scaling**: StandardScaler normalization
3. **PCA**: ~43 → 15 components (optimized)

In [1]:
import sys
sys.path.append('..')

from src.config import load_config
from src.load_data import load_data
from src.models.bayes import BayesClassifier
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from pathlib import Path


In [None]:
# Load the project configuration and echo the settings relevant to this notebook.
config = load_config()

print("Configuration:")
data_cfg = config['data']
print(f"  Dataset: {data_cfg['dataset_name']}")
print(f"  Classes: {', '.join(data_cfg['classes'])}")
print(f"  Image size: {data_cfg['image_size']}x{data_cfg['image_size']}")

print("\nFeature extraction:")
bayes_cfg = config['bayes']
print(f"  LBP bins: {bayes_cfg['lbp_bins']}")
print(f"  Use balanced priors: {bayes_cfg['use_balanced_priors']}")
print(f"  Use OVR ensemble: {bayes_cfg['use_ovr']}")

print("\nFeature optimization:")
drop_correlated = bayes_cfg['remove_correlated']
print(f"  Remove correlated: {drop_correlated}")
if drop_correlated:
    # The threshold is only looked up when correlation removal is switched on.
    print(f"  Correlation threshold: {bayes_cfg['correlation_threshold']}")
reduce_with_pca = bayes_cfg['apply_pca']
print(f"  Use PCA: {reduce_with_pca}")
if reduce_with_pca:
    print(f"  PCA components: {bayes_cfg['pca_components']}")

# Static per-group feature counts (hard-coded here, not read from the extractor).
feature_breakdown = (
    ("Color (HSV mean + std)", 6),
    ("Texture (LBP + Haralick)", 19),
    ("Shape (aspect + Hu)", 4),
    ("Specular (optimized)", 2),
    ("Metal (optimized)", 1),
    ("Glass (optimized)", 3),
    ("Trash (optimized)", 1),
    ("Single-scale LBP (r=2)", 16),
)
print("\n  Optimized feature breakdown:")
for group_label, group_size in feature_breakdown:
    print(f"    {group_label}: {group_size}")
print("    -----------------------------")
print("    Base total: 52 features")
if drop_correlated:
    print("    After correlation removal: ~43 features")
if reduce_with_pca:
    print(f"    After PCA: {bayes_cfg['pca_components']} features")


In [None]:
# Load the three dataset splits and report how many images each one holds.
print("Loading datasets...")
train_dataset, val_dataset, test_dataset = load_data(split_data=True)

print("\nDataset sizes:")
for split_label, split in (
    ("Training:", train_dataset),
    ("Validation:", val_dataset),
    ("Test:", test_dataset),
):
    # Left-pad the label to 11 columns so the counts line up.
    print(f"  {split_label:<11} {len(split)} images")


## Feature Extraction Visualization Sample

In [None]:
from src.models.bayes import BayesFeatureExtractor
import cv2


def _image_panel(ax, pixels, title, **imshow_kwargs):
    """Show `pixels` on `ax` as an image with a title and the axes hidden."""
    ax.imshow(pixels, **imshow_kwargs)
    ax.set_title(title)
    ax.axis('off')


def _bar_panel(ax, values, tick_labels, title, tick_fontsize=7, **bar_kwargs):
    """Draw one bar per entry of `values` on `ax`, labelled by `tick_labels`."""
    positions = list(range(len(tick_labels)))
    ax.bar(positions, values, alpha=0.7, **bar_kwargs)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, fontsize=tick_fontsize)
    ax.set_title(title)
    ax.set_ylabel('Value')
    ax.grid(alpha=0.3)


feature_extractor = BayesFeatureExtractor(config)

# Use the first training sample as the illustration.
sample = train_dataset[0]
sample_image, sample_label = sample['image'], sample['label']
class_name = config['data']['classes'][sample_label]

target_side = config['data']['image_size']
sample_image_resized = sample_image.resize((target_side, target_side))
sample_array = np.array(sample_image_resized)

# Per-group feature vectors, plus the full concatenated vector.
color_feat = feature_extractor.extract_color_features(sample_image_resized)
texture_feat = feature_extractor.extract_texture_features(sample_image_resized)
shape_feat = feature_extractor.extract_shape_features(sample_image_resized)
specular_feat = feature_extractor.extract_specular_features(sample_image_resized)
metal_feat = feature_extractor.extract_metal_features(sample_image_resized)
glass_feat = feature_extractor.extract_glass_features(sample_image_resized)
trash_feat = feature_extractor.extract_trash_features(sample_image_resized)
all_features = feature_extractor.extract_features(sample_image_resized)

# Image-derived maps shown in the top row of the figure.
hsv = cv2.cvtColor(sample_array, cv2.COLOR_RGB2HSV)
gray = cv2.cvtColor(sample_array, cv2.COLOR_RGB2GRAY)
sobel_x = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
sobel_y = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
grad_magnitude = np.sqrt(sobel_x**2 + sobel_y**2)
centered_spectrum = np.fft.fftshift(np.fft.fft2(gray))
magnitude = np.log(np.abs(centered_spectrum) + 1)  # +1 keeps log finite at zero

fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(2, 4, hspace=0.35, wspace=0.3)

# Top row: the image itself and three derived maps.
_image_panel(fig.add_subplot(gs[0, 0]), sample_image_resized,
             f'Original Image\nClass: {class_name}')
_image_panel(fig.add_subplot(gs[0, 1]), hsv[:, :, 0], 'Hue Channel', cmap='hsv')
_image_panel(fig.add_subplot(gs[0, 2]), grad_magnitude,
             'Gradient Magnitude\n(Metal Detection)', cmap='hot')
_image_panel(fig.add_subplot(gs[0, 3]), magnitude,
             'FFT Spectrum\n(Glass Detection)', cmap='viridis')

# Bottom row: bar charts of the extracted feature groups.
_bar_panel(fig.add_subplot(gs[1, 0]), color_feat,
           ['H_mu', 'S_mu', 'V_mu', 'H_sigma', 'S_sigma', 'V_sigma'],
           'Color (6)',
           color=['red', 'green', 'blue', 'red', 'green', 'blue'])

metal_combined = np.concatenate([specular_feat, metal_feat])
_bar_panel(fig.add_subplot(gs[1, 1]), metal_combined,
           ['HighContrast', 'GradConc', 'RefDir'],
           'Metal Features (3)\nOptimized from 8',
           color='gold')

_bar_panel(fig.add_subplot(gs[1, 2]), glass_feat,
           ['BrGrad', 'HighFreq', 'SatUnif'],
           'Glass Features (3)\nOptimized from 4',
           color='lightblue')

trash_ax = fig.add_subplot(gs[1, 3])
_bar_panel(trash_ax, trash_feat, ['TexChaos'],
           'Trash Feature (1)\nOptimized from 4',
           tick_fontsize=8, color='brown', width=0.5)
trash_ax.set_xlim(-0.5, 0.5)

plt.suptitle('Optimized Feature Extraction (52 features -> ~43 -> 15 via PCA)',
             fontsize=14, fontweight='bold')
plt.show()

# Textual summary: measured vector lengths for each feature group.
rule = '=' * 70
print(f"\n{rule}")
print("OPTIMIZED Feature Breakdown:")
print(rule)
group_report = (
    ("Color (HSV)", color_feat, ""),
    ("Texture (LBP + Haralick)", texture_feat, ""),
    ("Shape (aspect + Hu)", shape_feat, ""),
    ("Specular (optimized)", specular_feat, " (was 6)"),
    ("Metal (optimized)", metal_feat, " (was 2)"),
    ("Glass (optimized)", glass_feat, " (was 4)"),
    ("Trash (optimized)", trash_feat, " (was 4)"),
)
for group_label, group_values, history_note in group_report:
    print(f"  {group_label}: {len(group_values)} features{history_note}")
print("  ------------------------------")
print(f"  TOTAL EXTRACTED: {len(all_features)} features")
print(f"  After PCA: {config['bayes']['pca_components']} features")
print(rule)

print(f"\nKey optimized feature values for '{class_name}':")
key_values = (
    ("Metal - Highlight contrast", specular_feat[0]),
    ("Metal - Gradient concentration", specular_feat[1]),
    ("Metal - Reflection directionality", metal_feat[0]),
    ("Glass - Brightness gradient", glass_feat[0]),
    ("Glass - High-freq energy", glass_feat[1]),
    ("Glass - Saturation uniformity", glass_feat[2]),
    ("Trash - Texture chaos", trash_feat[0]),
)
for value_label, value in key_values:
    print(f"  {value_label}: {value:.4f}")


## Train Bayes Classifier

In [None]:
# Train the Bayes classifier on the training split, optionally with augmentation.
use_balanced = config['bayes']['use_balanced_priors']
use_augmentation = config['augmentation']['enabled']
# NOTE(review): presumably the number of augmented copies added per original
# image (the size estimate below uses 1 + augment_factor) -- confirm in fit().
augment_factor = 2

print("Initializing Bayes Classifier...")
print(f"  Balanced priors: {use_balanced}")
print(f"  Data augmentation: {use_augmentation}")
if use_augmentation:
    # Derive the growth from augment_factor instead of hard-coding "triple",
    # so the message stays correct when the factor is changed.
    growth = 1 + augment_factor
    original_size = len(train_dataset)
    print(f"  Augmentation factor: {augment_factor}")
    print(f"  This will grow the training data {growth}x: {original_size} -> {original_size * growth}\n")

classifier = BayesClassifier(config, use_balanced_priors=use_balanced)

# Trailing semicolon stops the notebook from echoing fit()'s return value
# (the classifier repr) as cell output.
classifier.fit(train_dataset, verbose=True, use_augmentation=use_augmentation, augment_factor=augment_factor);


Initializing Bayes Classifier...
  Balanced priors: True
  Data augmentation: False
Extracting features from training data...


Training features:   0%|          | 0/3537 [00:00<?, ?it/s]

Training features: 100%|██████████| 3537/3537 [09:26<00:00,  6.24it/s]

Feature shape: (3537, 52)
Removing correlated features (threshold=0.85)...
  Removing zero-variance features before correlation analysis...
  Removed 16 correlated features
  Features: 52 → 29 (removed 23 total)
Standardizing features...
Applying PCA (reducing to 15 components)...
  Variance explained: 93.41%
Training GaussianNB classifier...
  Final feature shape: (3537, 15)
Training complete!
Training accuracy: 0.5824





<src.models.bayes.BayesClassifier at 0x196035dda10>

## Evaluate on Validation Set

In [None]:
from pathlib import Path

# Make sure the log directory exists, then evaluate on the validation split
# while writing per-sample predictions to a text file inside it.
log_dir = Path(config["paths"]["logs_dir"])
log_dir.mkdir(parents=True, exist_ok=True)
val_log_file = log_dir.joinpath("validation_predictions.txt")

print("\nEvaluating on validation set...")
print(f"Logging predictions to {val_log_file}")
val_results = classifier.evaluate(
    val_dataset,
    verbose=True,
    log_predictions=True,
    log_file=val_log_file,
)


## Evaluate on Test Set

In [None]:
# Evaluate on the held-out test split, logging predictions next to the
# validation log (reuses log_dir and val_log_file from the validation cell).
test_log_file = log_dir.joinpath("test_predictions.txt")

print("\nEvaluating on test set...")
print(f"Logging predictions to {test_log_file}")
test_results = classifier.evaluate(
    test_dataset,
    verbose=True,
    log_predictions=True,
    log_file=test_log_file,
)

print(f"\n All prediction logs saved to {log_dir}")
for written_log in (val_log_file, test_log_file):
    print(f"  - {written_log.name}")


## Confusion Matrix

In [None]:
# Plot the test-set confusion matrix and print each class's accuracy.
cm_test = test_results['confusion_matrix']
class_labels = config['data']['classes']

plt.figure(figsize=(10, 8))
heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap='Greens',
    xticklabels=class_labels,
    yticklabels=class_labels,
)
sns.heatmap(cm_test, **heatmap_options)
plt.title('Confusion Matrix - Test Set', fontsize=14, fontweight='bold')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.show()

print("\nPer-class accuracy:")
for row_idx, label_name in enumerate(class_labels):
    # Diagonal entry over the row total; rows that sum to zero report 0.
    row_total = cm_test[row_idx].sum()
    if row_total > 0:
        class_acc = cm_test[row_idx, row_idx] / row_total
    else:
        class_acc = 0
    print(f"  {label_name:<12}: {class_acc:.4f}")


## Compare Validation and Test Performance

In [None]:
# Grouped bar chart and text table comparing validation and test metrics.
metrics = ['accuracy', 'precision', 'recall', 'f1']

fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(len(metrics))
width = 0.35

# (legend label, results dict, bar colour, side of the tick: -1 left / +1 right)
series = (
    ('Validation', val_results, 'steelblue', -1),
    ('Test', test_results, 'seagreen', 1),
)
bar_groups = []
for series_label, results, bar_color, side in series:
    scores = [results[name] for name in metrics]
    bar_groups.append(
        ax.bar(x + side * width / 2, scores, width,
               label=series_label, color=bar_color, alpha=0.8)
    )

ax.set_ylabel('Score', fontsize=12)
ax.set_title('Validation vs Test Performance', fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels([name.capitalize() for name in metrics])
ax.legend()
ax.grid(axis='y', alpha=0.3)
ax.set_ylim([0, 1])

# Write each bar's value just above its top edge.
for group in bar_groups:
    for rect in group:
        top = rect.get_height()
        ax.text(rect.get_x() + rect.get_width()/2., top,
                f'{top:.3f}',
                ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

banner = '=' * 60
column_rule = '-' * 12
print(f"\n{banner}")
print("FINAL RESULTS SUMMARY")
print(banner)
print(f"\n{'Metric':<12} | {'Validation':>12} | {'Test':>12}")
print("-+-".join([column_rule] * 3))
for name in metrics:
    print(f"{name.capitalize():<12} | {val_results[name]:>12.4f} | {test_results[name]:>12.4f}")
print(banner)


## Save Model

In [None]:
# Persist the trained classifier under the configured models directory,
# creating the directory first if it does not exist.
models_dir = Path(config['paths']['models_dir'])
models_dir.mkdir(parents=True, exist_ok=True)

model_path = models_dir.joinpath('bayes_classifier.pkl')
classifier.save(model_path)

print("\nModel saved successfully!", f"Path: {model_path}", sep="\n")


## Feature Distribution Analysis

In [None]:
from src.models.bayes import BayesFeatureExtractor

# Sample up to `sample_size_per_class` training images per class, extract their
# full feature vectors, then plot and tabulate selected feature groups by class.
feature_extractor = BayesFeatureExtractor(config)

class_names = config["data"]["classes"]
image_size = config["data"]["image_size"]

sample_size_per_class = 100
sampled_features = []
sampled_labels = []

# Single pass over the training set with one counter per class. Restarting the
# iteration for every class would revisit the whole dataset once per class.
# Rows end up in dataset order rather than grouped by class; everything below
# selects rows through boolean masks, so the ordering does not matter.
samples_taken = [0] * len(class_names)
classes_pending = len(class_names) if sample_size_per_class > 0 else 0

print("Extracting optimized features from sample data...")
if classes_pending:
    for item in train_dataset:
        label = int(item["label"])
        if not 0 <= label < len(class_names):
            continue  # label does not correspond to any configured class
        if samples_taken[label] >= sample_size_per_class:
            continue  # this class already has enough samples
        image = item["image"].resize((image_size, image_size))
        sampled_features.append(feature_extractor.extract_features(image))
        sampled_labels.append(label)
        samples_taken[label] += 1
        if samples_taken[label] == sample_size_per_class:
            classes_pending -= 1
            if classes_pending == 0:
                break  # every class is full; stop reading the dataset

for class_idx, class_name in enumerate(class_names):
    print(f"  {class_name}: {samples_taken[class_idx]} samples")

sampled_features = np.array(sampled_features)
sampled_labels = np.array(sampled_labels)

print(f"\nExtracted feature shape: {sampled_features.shape}")

# Column indices into the extracted feature vector.
# NOTE(review): assumes extract_features lays out color (6), texture (19) and
# shape (4) first, so the material-specific features start at column 29 --
# verify against BayesFeatureExtractor. Metal (column 31) is not plotted.
feature_groups = {
    "Color (HSV)": list(range(6)),
    "Specular (Metal) - Optimized": [29, 30],
    "Glass - Optimized": [32, 33, 34],
    "Trash - Optimized": [35]
}

fig, axes = plt.subplots(2, 2, figsize=(16, 12))
axes = axes.flatten()

for ax_idx, (group_name, feature_indices) in enumerate(feature_groups.items()):
    ax = axes[ax_idx]

    for class_idx, class_name in enumerate(class_names):
        class_mask = sampled_labels == class_idx
        if not class_mask.any():
            continue  # nothing sampled for this class; skip the empty histogram
        class_features = sampled_features[class_mask][:, feature_indices]
        # One value per sample: the mean across the group's feature columns.
        mean_values = np.mean(class_features, axis=1)
        ax.hist(mean_values, bins=30, alpha=0.5, label=class_name, density=True)

    ax.set_title(f"{group_name} Distribution", fontsize=12, fontweight="bold")
    ax.set_xlabel("Mean Feature Value")
    ax.set_ylabel("Density")
    ax.legend(loc="best", fontsize=9)
    ax.grid(alpha=0.3)

plt.suptitle("Optimized Feature Distributions Across Waste Classes", fontsize=14, fontweight="bold")
plt.tight_layout()
plt.show()

print(f"\n{'='*70}")
print("Per-Class Feature Statistics (Mean +/- Std) - OPTIMIZED FEATURES")
print(f"{'='*70}")
for feat_name, feat_indices in feature_groups.items():
    print(f"\n{feat_name}:")
    for class_idx, class_name in enumerate(class_names):
        class_mask = sampled_labels == class_idx
        if not class_mask.any():
            # Avoid NaN statistics (and NumPy warnings) for an unsampled class.
            print(f"  {class_name:<12}: no samples")
            continue
        class_features = sampled_features[class_mask][:, feat_indices]
        # Mean and std pooled over all samples and all columns of the group.
        mean_val = np.mean(class_features)
        std_val = np.std(class_features)
        print(f"  {class_name:<12}: {mean_val:>8.4f} +/- {std_val:.4f}")
