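## Hybrid rock classifier: baseline CNN + rule-based expert system

A baseline 1D CNN (`SimpleCNN1D`) predicts the mineral in each spectrum. `HierarchicalRockClassifier` then:

- turns the accumulated mineral predictions into mineral-group ratios;
- checks rule-based criteria (e.g. granite ranges) against those ratios;
- classifies each of the 30 validation rocks.

In the saved run, rock-level accuracy is 0.37, and the `other` class is never predicted (see the classification report at the end).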

In [1]:
import torch
import joblib
from pathlib import Path
import numpy as np
import pandas as pd
# from rocks_evaluation import SimpleCNN1D, UncertaintyAwareCNN1D, IntegratedRockClassifier, UncertaintyIntegratedRockClassifier, RockType, MineralGroups, save_analysis, plot_mineral_analysis
from rocks_evaluation import SimpleCNN1D, HierarchicalRockClassifier, save_analysis

In [2]:
# Mineral label encoder (presumably a fitted sklearn LabelEncoder); its
# `classes_` attribute lists the 23 mineral names printed below.
label_encoder = joblib.load(
    '../mineral_label_encoder.joblib',
)
print(label_encoder.classes_)

['Albite' 'Almandine' 'Andalusite' 'Annite' 'Anorthite' 'Calcite'
 'Dolomite' 'Epidote' 'Glaucophane' 'Jadeite' 'Kyanite' 'Muscovite'
 'Omphacite' 'Orthoclase' 'Phlogopite' 'Pyrite' 'Pyrope' 'Quartz'
 'Rhodochrosite' 'Rutile' 'Sanidine' 'Staurolite' 'Tourmaline']


In [3]:
# Device configuration: use the GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Identify which saved checkpoint to load; these values are encoded in its filename.
best_repeat = 2
best_fold = 1
best_overall_accuracy = 0.9836
best_model_path = f'../weights/model_repeat{best_repeat}_fold{best_fold}_acc{best_overall_accuracy:.4f}.pth'

# Baseline mineral classifier. nn.Module.to() moves the module in place and returns it,
# so `best_model` and `baseline_model` are the same object.
baseline_model = SimpleCNN1D()
best_model = baseline_model.to(device)

# map_location remaps tensors that were saved on a GPU onto `device`; without it,
# torch.load raises on a CPU-only machine even though `device` fell back to CPU above.
# NOTE: weights_only=False unpickles arbitrary Python objects — only load trusted checkpoints.
best_checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
best_model.load_state_dict(best_checkpoint['model_state_dict'])
best_model.eval()  # inference mode (disables dropout / uses running batch-norm stats)

# `label_encoder` was loaded from '../mineral_label_encoder.joblib' in the previous cell;
# it is not re-read from disk here.
rock_analyzer = HierarchicalRockClassifier(
    model=best_model,
    label_encoder=label_encoder,
    device=device,
)



In [4]:
# Load and process data.
# For each validation rock `rock_XX`, three sibling files are read from `base_path`:
#   rock_XX_wavelengths.npy, rock_XX_intensities.npy, rock_XX_composition.txt
base_path = Path("../validation_rocks")
N_ROCKS = 30          # validation rocks are numbered 1..N_ROCKS
ANALYSIS_WINDOW = 10  # start printing the rock analysis once this many spectra were processed
results = []


def _load_true_compositions(composition_file):
    """Return the ground-truth mineral label for each spectrum of one rock.

    The file is tab-separated with one header row; the label is taken from the
    second column. Blank lines (e.g. a trailing empty line at EOF) are skipped
    instead of raising IndexError.
    """
    with open(composition_file, 'r') as f:
        next(f, None)  # Skip header (tolerates an empty file)
        return [line.strip().split('\t')[1] for line in f if line.strip()]


def _print_window_analysis(result):
    """Print the rock classification and, if present, the per-rock-type weights of one result."""
    rock_analysis = result['rock_analysis']
    print("\nAnalysis Results:")
    print(f"Classification: {rock_analysis['classification']}")
    if 'weights' in rock_analysis:
        print("Rock Type Weights:")
        for rock_type, weight in rock_analysis['weights'].items():
            print(f"- {rock_type.capitalize()}: {weight:.3f}")


for rock_num in range(1, N_ROCKS + 1):
    # Reset the analyzer's per-rock histories so one rock's spectra never leak into the next.
    rock_analyzer.analysis_history = []
    rock_analyzer.prediction_history = []
    rock_analyzer.ground_truth_history = []

    print(f"\nProcessing Rock {rock_num}")
    rock_folder = f"rock_{rock_num:02d}"
    rock_path = base_path / rock_folder
    # NOTE(review): `wavelengths` is unused in this cell; kept in case later cells rely on it.
    wavelengths = np.load(f"{rock_path}_wavelengths.npy")
    intensities = np.load(f"{rock_path}_intensities.npy")

    # Ground-truth mineral label for each spectrum (row of `intensities`).
    composition_file = rock_path.with_name(f"{rock_path.stem}_composition.txt")
    true_compositions = _load_true_compositions(composition_file)
    if len(true_compositions) < len(intensities):
        raise ValueError(
            f"Rock {rock_num}: {len(intensities)} spectra but only "
            f"{len(true_compositions)} composition labels in {composition_file}"
        )

    final_result = None  # stays None if the rock has no spectra

    # Feed spectra one at a time; the analyzer accumulates them internally.
    for i, spectrum in enumerate(intensities):
        spectrum_tensor = torch.from_numpy(spectrum).float()
        result = rock_analyzer.process_spectrum(spectrum_tensor, true_compositions[i])
        final_result = result

        if i >= ANALYSIS_WINDOW - 1:
            _print_window_analysis(result)

    # Save results after all spectra are processed.
    if final_result:
        weights = final_result['rock_analysis'].get('weights', {})
        # NOTE: "accuracy" is the highest rock-type weight (the score of the top class),
        # not agreement with a ground-truth rock type — none is compared here.
        max_weight = max(weights.values()) if weights else 0.0

        results.append({
            'rock_num': rock_num,
            'classification': final_result['rock_analysis']['classification'],
            'accuracy': max_weight,
            'weights': weights
        })

        # Store the same score on the result so save_analysis writes it out.
        final_result['rock_analysis']['accuracy'] = max_weight

        save_analysis(final_result, f'rock{rock_num}_analysis_results.txt')
        rock_analyzer.plot_analysis(rock_num, f'rock_analysis{rock_num}.png')

# Save one row per rock. Explicit columns keep the CSV header (and the column
# lookups below) valid even when no rock produced a result.
df = pd.DataFrame(results, columns=['rock_num', 'classification', 'accuracy', 'weights'])
df.to_csv('rock_classifications.csv', index=False)

# Print summary statistics
print("\nSummary Statistics:")
print(f"Total rocks processed: {len(df)}")
if df.empty:
    print("No rocks produced a result; skipping distribution and accuracy statistics.")
else:
    print("\nClassification Distribution:")
    print(df['classification'].value_counts())
    # "accuracy" holds each rock's highest rock-type weight, not agreement with ground truth.
    print("\nAccuracy Statistics:")
    print(f"Mean accuracy: {df['accuracy'].mean():.2%}")
    print(f"Median accuracy: {df['accuracy'].median():.2%}")


Processing Rock 1

DEBUG: Predictions received: ['Albite', 'Anorthite', 'Quartz', 'Quartz', 'Annite', 'Muscovite', 'Quartz', 'Albite', 'Annite', 'Orthoclase']

DEBUG: Mineral counts:
quartz: 3
feldspars: 4
micas: 3
calcite: 0
pyrite: 0
rutile: 0
tourmaline: 0

DEBUG: Mineral ratios:
Quartz ratio: 0.30
Feldspar ratio: 0.40
Mica ratio: 0.30
Calcite ratio: 0.00

DEBUG: Checking granite criteria:
0.2 <= quartz_ratio <= 0.4: True
0.45 <= feldspar_ratio <= 0.8: False
0.0 <= mica_ratio <= 0.15: False

Analysis Results:
Classification: granite
Rock Type Weights:
- Granite: 1.000
- Sandstone: 0.200
- Limestone: 0.000

Processing Rock 2

DEBUG: Predictions received: ['Albite', 'Quartz', 'Annite', 'Annite', 'Muscovite', 'Orthoclase', 'Sanidine', 'Quartz', 'Quartz', 'Anorthite']

DEBUG: Mineral counts:
quartz: 3
feldspars: 3
micas: 3
calcite: 0
pyrite: 0
rutile: 0
tourmaline: 0

DEBUG: Mineral ratios:
Quartz ratio: 0.30
Feldspar ratio: 0.30
Mica ratio: 0.30
Calcite ratio: 0.00

DEBUG: Checking gr

In [5]:
from analyze_results import (
    analyze_results,
)

# Prints the confusion matrix and classification report shown below.
# NOTE: the 'other' class is never predicted in the saved run (its confusion-matrix
# column is all zeros), which is why the report emits undefined-precision warnings.
analyze_results()


Confusion Matrix:
[[5 0 0 0]
 [2 5 0 0]
 [6 3 0 2]
 [6 0 0 1]]

Classification Report:
              precision    recall  f1-score   support

     granite       0.26      1.00      0.42         5
   limestone       0.62      0.71      0.67         7
       other       0.00      0.00      0.00        11
   sandstone       0.33      0.14      0.20         7

    accuracy                           0.37        30
   macro avg       0.31      0.46      0.32        30
weighted avg       0.27      0.37      0.27        30



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
