## Hybrid classifier - uncertainty + expert system

In [1]:
import torch
import joblib
from pathlib import Path
import numpy as np
import pandas as pd
# from rocks_evaluation import SimpleCNN1D, UncertaintyAwareCNN1D, IntegratedRockClassifier, UncertaintyIntegratedRockClassifier, RockType, MineralGroups, save_analysis, plot_mineral_analysis
from rocks_evaluation import UncertaintyAwareCNN1D, UncertaintyHierarchicalRockClassifier, save_analysis

In [2]:
# Device configuration: prefer the GPU when one is available, otherwise run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model configuration and loading
best_model_path = '../weights/best_model.pth'

# For uncertainty-aware classification
# Instantiate the network and move it to the selected device.
# NOTE(review): for a torch.nn.Module, .to() returns the module itself, so
# `uncertainty_model` and `best_model` presumably alias the same object — confirm.
uncertainty_model = UncertaintyAwareCNN1D()
best_model = uncertainty_model.to(device)

# map_location remaps the stored tensors onto `device`; without it a checkpoint
# saved on a CUDA machine cannot be deserialized on the CPU fallback chosen above.
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects — only load
# checkpoints from a trusted source.
best_checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
best_model.load_state_dict(best_checkpoint['model_state_dict'])
best_model.eval()  # inference mode for the loaded network

# Load label encoder
# SECURITY NOTE: joblib.load is pickle-based — only load trusted files.
label_encoder = joblib.load('../mineral_label_encoder.joblib')


# Initialize rock analyzer with the loaded model, its label encoder and the device
rock_analyzer = UncertaintyHierarchicalRockClassifier(
    model=best_model,
    label_encoder=label_encoder,
    device=device,
)

In [3]:
import matplotlib.pyplot as plt
from typing import Dict, List

# Helper function for uncertainty visualization (single definition; the notebook
# previously defined this function twice with identical bodies).
def plot_uncertainty_analysis(uncertainties, weights, rock_num, filename, threshold=None):
    """
    Create a two-panel visualization of uncertainty analysis for a rock sample
    and save it to disk.

    Left panel: per-measurement uncertainty with the threshold as a dashed line.
    Right panel: bar chart of rock type weights, green when the weight exceeds
    the threshold and red otherwise, with the same dashed threshold line.

    Args:
        uncertainties: List of uncertainty values for each measurement
        weights: Dictionary of final rock type weights (may be empty)
        rock_num: Rock sample number (used in the plot title)
        filename: Output filename for the plot
        threshold: Value drawn as the horizontal line and used for bar colouring.
            Defaults to ``rock_analyzer.uncertainty_threshold`` (the notebook-level
            analyzer) when omitted, which matches the previous behaviour.
    """
    if threshold is None:
        # Fall back to the analyzer defined in the setup cell.
        threshold = rock_analyzer.uncertainty_threshold

    fig, (ax_uncertainty, ax_weights) = plt.subplots(1, 2, figsize=(12, 6))
    try:
        # Plot 1: Uncertainty over measurements
        ax_uncertainty.plot(uncertainties, marker='o')
        ax_uncertainty.axhline(y=threshold, color='r', linestyle='--',
                               label='Threshold')
        ax_uncertainty.set_title(f'Measurement Uncertainties\nRock {rock_num}')
        ax_uncertainty.set_xlabel('Measurement Number')
        ax_uncertainty.set_ylabel('Uncertainty')
        ax_uncertainty.legend()

        # Plot 2: Final rock type weights with uncertainty
        # NOTE(review): weights are coloured by comparison with the *uncertainty*
        # threshold, as in the original code — confirm this is the intended cutoff.
        rock_types = list(weights.keys())
        weight_values = list(weights.values())
        colors = ['g' if w > threshold else 'r' for w in weight_values]

        if rock_types:  # nothing to draw when no weights were produced
            ax_weights.bar(rock_types, weight_values, color=colors)
        ax_weights.axhline(y=threshold, color='r', linestyle='--',
                           label='Uncertainty Threshold')
        ax_weights.set_title('Rock Type Weights vs. Uncertainty Threshold')
        ax_weights.tick_params(axis='x', labelrotation=45)
        ax_weights.set_ylabel('Weight')
        ax_weights.legend()

        fig.tight_layout()
        fig.savefig(filename)
    finally:
        # Always release the figure, even if drawing or saving failed, so that
        # looping over many rocks does not accumulate open figures.
        plt.close(fig)

def save_uncertainty_analysis(result: Dict, filename: str):
    """
    Write a plain-text report of one analysis result, uncertainty included.

    The report always contains the predicted mineral, the per-mineral
    probabilities, the measurement uncertainty and the rock classification.
    Further sections (weights, confidence score, mean uncertainty, confidence
    flag, mineral counts, uncertainty history) are written only when the
    corresponding key exists in ``result['rock_analysis']``.

    Args:
        result: Dictionary containing analysis results
        filename: Output filename (overwritten if it exists)
    """
    with open(filename, 'w') as f:
        emit = f.write

        # Mineral prediction
        emit("Mineral Analysis Results:\n")
        emit(f"Predicted Mineral: {result['mineral_prediction']}\n")

        # Per-mineral probabilities
        emit("\nMineral Probabilities:\n")
        f.writelines(
            f"{mineral}: {prob:.4f}\n"
            for mineral, prob in result['mineral_probabilities'].items()
        )

        # Uncertainty of this measurement
        emit("\nUncertainty Information:\n")
        emit(f"Measurement Uncertainty: {result['uncertainty']:.4f}\n")

        # Rock-level results
        emit("\nRock Analysis Results:\n")
        emit(f"Classification: {result['rock_analysis']['classification']}\n")

        rock_analysis = result['rock_analysis']

        if 'weights' in rock_analysis:
            emit("\nRock Type Weights:\n")
            f.writelines(
                f"{rock_type.capitalize()}: {weight:.4f}\n"
                for rock_type, weight in rock_analysis['weights'].items()
            )

        if 'confidence_score' in rock_analysis:
            emit(f"\nConfidence Score: {result['rock_analysis']['confidence_score']:.4f}\n")

        if 'mean_uncertainty' in rock_analysis:
            emit(f"Mean Uncertainty: {result['rock_analysis']['mean_uncertainty']:.4f}\n")

        if 'is_confident' in rock_analysis:
            emit(f"Is Confident: {result['rock_analysis']['is_confident']}\n")

        if 'mineral_counts' in rock_analysis:
            emit("\nMineral Counts:\n")
            f.writelines(
                f"{mineral.capitalize()}: {count:.2f}\n"
                for mineral, count in rock_analysis['mineral_counts'].items()
            )

        # Per-measurement uncertainty history, numbered from 1
        if 'uncertainty_history' in rock_analysis:
            emit("\nUncertainty History:\n")
            f.writelines(
                f"Measurement {i+1}: {uncertainty:.4f}\n"
                for i, uncertainty in enumerate(rock_analysis['uncertainty_history'])
            )

In [4]:
# Load and process data
base_path = Path("../validation_rocks")
NUM_ROCKS = 30    # validation rocks are stored as rock_01 ... rock_30
MIN_WINDOW = 10   # per-measurement reports are printed from the 10th spectrum onwards
results = []


def _fmt_metric(value, spec=".3f"):
    """Format a metric with ``spec``; values that cannot take a numeric format
    (e.g. the 'N/A' placeholder or None) are returned as plain text instead of
    raising, which the previous inline ``{...:.3f}`` formatting did."""
    try:
        return format(value, spec)
    except (TypeError, ValueError):
        return str(value)


for rock_num in range(1, NUM_ROCKS + 1):
    # Reset histories at start of each rock so results do not leak between rocks
    rock_analyzer.analysis_history = []
    rock_analyzer.prediction_history = []
    rock_analyzer.uncertainty_history = []  # Added for uncertainty tracking
    rock_analyzer.ground_truth_history = []

    print(f"\nProcessing Rock {rock_num}")
    # Files are named <base_path>/rock_NN_<suffix>; rock_folder is the shared prefix.
    rock_folder = f"rock_{rock_num:02d}"
    rock_path = base_path / rock_folder
    # Loaded for parity with the original cell; not used further in this cell.
    wavelengths = np.load(base_path / f"{rock_folder}_wavelengths.npy")
    intensities = np.load(base_path / f"{rock_folder}_intensities.npy")

    # Load true composition: tab-separated file, one header line, label in column 2
    composition_file = base_path / f"{rock_folder}_composition.txt"
    with open(composition_file, 'r') as f:
        next(f, None)  # Skip header (tolerate an empty file)
        # Blank lines (e.g. a trailing newline) are skipped instead of raising IndexError
        true_compositions = [line.strip().split('\t')[1] for line in f if line.strip()]

    if len(true_compositions) < len(intensities):
        # Fail early with a clear message rather than an IndexError mid-loop
        raise ValueError(
            f"{composition_file} has {len(true_compositions)} labels "
            f"but {len(intensities)} spectra were loaded"
        )

    final_result = None
    measurement_uncertainties = []  # Track uncertainties for each measurement

    # Process spectra and collect results
    for i, spectrum in enumerate(intensities):
        spectrum_tensor = torch.from_numpy(spectrum).float()
        result = rock_analyzer.process_spectrum(spectrum_tensor, true_compositions[i])
        final_result = result

        # Track uncertainty for this measurement
        measurement_uncertainties.append(result['uncertainty'])

        # Print analysis if we have a full window
        if i + 1 >= MIN_WINDOW:
            rock_analysis = result['rock_analysis']
            print(f"\nAnalysis Results:")
            print(f"Classification: {rock_analysis['classification']}")
            print(f"Mean Uncertainty: {_fmt_metric(rock_analysis.get('mean_uncertainty', 'N/A'))}")

            if 'weights' in rock_analysis:
                print("\nRock Type Weights:")
                for rock_type, weight in rock_analysis['weights'].items():
                    print(f"- {rock_type.capitalize()}: {weight:.3f}")

                print("\nConfidence Metrics:")
                print(f"Confidence Threshold: {_fmt_metric(rock_analysis.get('confidence_threshold', 'N/A'))}")
                print(f"Is Confident: {rock_analysis.get('is_confident', 'N/A')}")

    # Save results after all spectra are processed (final_result stays None
    # when the rock had no spectra)
    if final_result:
        final_analysis = final_result['rock_analysis']
        weights = final_analysis.get('weights', {})
        # NOTE(review): the highest rock-type weight is stored under the name
        # 'accuracy'; it is not compared against ground truth here. The column
        # name is kept because downstream code may read it — confirm.
        max_weight = max(weights.values()) if weights else 0.0
        mean_uncertainty = final_analysis.get('mean_uncertainty', 0.0)
        # Highest weight discounted by the mean uncertainty
        confidence_score = max_weight * (1 - mean_uncertainty)

        results.append({
            'rock_num': rock_num,
            'classification': final_analysis['classification'],
            'accuracy': max_weight,
            'uncertainty': mean_uncertainty,
            'confidence_score': confidence_score,
            'weights': weights,
            'is_confident': final_analysis.get('is_confident', False)
        })

        # Update the analysis results before saving; this overwrites any
        # 'confidence_score' already present in the analysis dict.
        final_analysis.update({
            'accuracy': max_weight,
            'confidence_score': confidence_score,
            'uncertainty_history': measurement_uncertainties
        })

        # Save detailed analysis results
        save_uncertainty_analysis(final_result, f'rock{rock_num}_uncertainty_analysis.txt')

        # Create visualizations
        rock_analyzer.plot_analysis(rock_num, f'rock_analysis{rock_num}.png')

        # Additional uncertainty visualization; use the `weights` fallback from
        # above so a missing 'weights' key does not raise KeyError here.
        plot_uncertainty_analysis(
            measurement_uncertainties,
            weights,
            rock_num,
            f'uncertainty_analysis{rock_num}.png'
        )

# Save final results with uncertainty information
df = pd.DataFrame(results)
df.to_csv('rock_classifications_with_uncertainty.csv', index=False)

# Print summary statistics
print("\nSummary Statistics:")
print(f"Total rocks processed: {len(df)}")

if df.empty:
    # An empty frame has no columns, so the statistics below would raise KeyError
    print("No rocks were processed; no statistics available.")
else:
    print("\nClassification Distribution:")
    print(df['classification'].value_counts())

    print("\nAccuracy and Uncertainty Statistics:")
    print(f"Mean accuracy: {df['accuracy'].mean():.2%}")
    print(f"Median accuracy: {df['accuracy'].median():.2%}")
    print(f"Mean uncertainty: {df['uncertainty'].mean():.2%}")
    print(f"Median uncertainty: {df['uncertainty'].median():.2%}")
    print(f"Mean confidence score: {df['confidence_score'].mean():.2%}")

    print("\nConfidence Distribution:")
    confident_count = df['is_confident'].sum()
    print(f"Confident classifications: {confident_count}")
    print(f"Uncertain classifications: {len(df) - confident_count}")



Processing Rock 1

Analysis Results:
Classification: granite
Mean Uncertainty: 0.040

Rock Type Weights:
- Granite: 1.000
- Sandstone: 0.200
- Limestone: 0.000

Confidence Metrics:
Confidence Threshold: 0.872
Is Confident: True

Processing Rock 2

Analysis Results:
Classification: granite
Mean Uncertainty: 0.173

Rock Type Weights:
- Granite: 1.000
- Sandstone: 0.333
- Limestone: 0.000

Confidence Metrics:
Confidence Threshold: 0.779
Is Confident: True

Processing Rock 3

Analysis Results:
Classification: sandstone
Mean Uncertainty: 0.128

Rock Type Weights:
- Granite: 0.400
- Sandstone: 1.000
- Limestone: 0.000

Confidence Metrics:
Confidence Threshold: 0.811
Is Confident: True

Processing Rock 4

Analysis Results:
Classification: granite
Mean Uncertainty: 0.063

Rock Type Weights:
- Granite: 1.000
- Sandstone: 0.500
- Limestone: 0.000

Confidence Metrics:
Confidence Threshold: 0.856
Is Confident: True

Processing Rock 5

Analysis Results:
Classification: sandstone
Mean Uncertainty: 

In [2]:
from analyze_results import analyze_results

# Run the project's analyze_results on the per-rock classification CSV.
classification_csv = './uncertainty-classifier/rock_classifications_with_uncertainty.csv'
analyze_results(classification_csv)


Confusion Matrix:
[[4 0 1 0]
 [2 4 1 0]
 [5 2 2 2]
 [6 0 1 0]]
