## Hybrid classifier: baseline 1-D CNN + rule-based expert system

In [1]:
import torch
import joblib
from pathlib import Path
import numpy as np
import pandas as pd
# from rocks_evaluation import SimpleCNN1D, UncertaintyAwareCNN1D, IntegratedRockClassifier, UncertaintyIntegratedRockClassifier, RockType, MineralGroups, save_analysis, plot_mineral_analysis
from rocks_evaluation import SimpleCNN1D, HierarchicalRockClassifier, save_analysis

In [2]:
# Restore the fitted mineral LabelEncoder and list the class names it maps to.
encoder_file = '../mineral_label_encoder.joblib'
label_encoder = joblib.load(encoder_file)
print(label_encoder.classes_)

['Albite' 'Almandine' 'Andalusite' 'Annite' 'Anorthite' 'Calcite'
 'Dolomite' 'Epidote' 'Glaucophane' 'Jadeite' 'Kyanite' 'Muscovite'
 'Omphacite' 'Orthoclase' 'Phlogopite' 'Pyrite' 'Pyrope' 'Quartz'
 'Rhodochrosite' 'Rutile' 'Sanidine' 'Staurolite' 'Tourmaline']


In [3]:
# Device configuration: use the GPU when present, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Checkpoint of the best cross-validation run. The file name encodes
# repeat / fold / accuracy, so these three values must match the saved file.
best_repeat = 2
best_fold = 1
best_overall_accuracy = 0.9836
best_model_path = f'../weights/model_repeat{best_repeat}_fold{best_fold}_acc{best_overall_accuracy:.4f}.pth'

# Baseline 1-D CNN; its trained weights come from the checkpoint below.
baseline_model = SimpleCNN1D()
best_model = baseline_model.to(device)

# map_location lets a checkpoint saved on a GPU be loaded on the CPU fallback above.
# weights_only=False unpickles arbitrary objects from the file: only load trusted checkpoints.
best_checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
best_model.load_state_dict(best_checkpoint['model_state_dict'])
best_model.eval()

# Rock-level analyzer wrapping the trained CNN. `label_encoder` is the one
# loaded in the label-encoder cell above, so it is not re-loaded here.
rock_analyzer = HierarchicalRockClassifier(
    model=best_model,
    label_encoder=label_encoder,
    device=device,
)



LabelEncoder()


In [4]:
# Load and process data: run every validation rock through the analyzer and
# record one (classification, accuracy) row per rock.
base_path = Path("../validation_rocks")
n_rocks = 30  # validation rocks are stored as rock_01 ... rock_30
results = []

for rock_num in range(1, n_rocks + 1):
    rock_analyzer.analysis_history = []  # Reset analysis history so no state carries over between rocks

    # Last full-window outcome for this rock (stays None if no full window was reached).
    final_classification = None
    final_accuracy = None
    result = None  # guards against reusing the previous rock's result if this rock has no spectra

    print(f"\nProcessing Rock {rock_num}")
    rock_path = base_path / f"rock_{rock_num:02d}"  # e.g. ../validation_rocks/rock_01
    wavelengths = np.load(f"{rock_path}_wavelengths.npy")
    intensities = np.load(f"{rock_path}_intensities.npy")

    # True mineral per spectrum: second tab-separated column, after one header row.
    # Blank lines (e.g. a trailing newline) are skipped instead of crashing on [1].
    with open(f"{rock_path}_composition.txt", 'r') as f:
        next(f)  # Skip header
        true_compositions = [line.strip().split('\t')[1] for line in f if line.strip()]
    assert len(true_compositions) >= len(intensities), (
        f"Rock {rock_num}: {len(true_compositions)} composition labels "
        f"for {len(intensities)} spectra"
    )

    # Feed the spectra one at a time, in measurement order.
    for i, spectrum in enumerate(intensities):
        spectrum_tensor = torch.from_numpy(spectrum).float()
        result = rock_analyzer.process_spectrum(spectrum_tensor, true_compositions[i])

        # Rock-level rules are only reported once the analyzer has a full window.
        rock_analysis = result['rock_analysis']
        if 'accuracy_rule' in rock_analysis:
            final_classification = rock_analysis['classification']
            final_accuracy = rock_analysis['accuracy_rule']['accuracy']

            print(f"\nAnalysis after measurement {i + 1}:")
            print(f"Classification: {final_classification}")
            print(f"Accuracy: {final_accuracy:.1%}")
            print("Mineral Assemblage:")
            for rule, satisfied in rock_analysis['assemblage_rules']['details'].items():
                print(f"- {rule}: {'✓' if satisfied else '✗'}")

    # Append only once per rock. Compare against None: an accuracy of 0.0 is a
    # valid (bad) result and must not silently drop the rock from the evaluation.
    if final_classification is not None and final_accuracy is not None:
        results.append({
            'rock_num': rock_num,
            'classification': final_classification,
            'accuracy': final_accuracy
        })

    # Save the analyzer output for this rock's last spectrum, e.g. 'rock1_analysis_results.txt'.
    if result is not None:
        save_analysis(result, f'rock{rock_num}_analysis_results.txt')

    # Per-rock visualization, e.g. 'rock_analysis1.png'.
    rock_analyzer.plot_analysis(rock_num, f'rock_analysis{rock_num}.png')

    # Save results after each rock so an interrupted run keeps the rocks already processed.
    pd.DataFrame(results).to_csv('rock_classifications.csv', index=False)



Processing Rock 1

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Processing Rock 2

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

An

In [5]:
# Evaluate the rock-level classifications produced above (prints a confusion
# matrix and a classification report).
# NOTE(review): presumably reads 'rock_classifications.csv' written by the
# processing cell; confirm against analyze_results.py.
# The UndefinedMetricWarning in the output appears because no rock was
# predicted as 'sandstone' (its confusion-matrix column is all zeros), so its
# precision is undefined.
from analyze_results import analyze_results

analyze_results()


Confusion Matrix:
[[7 0 1 0]
 [0 6 2 0]
 [3 1 1 0]
 [9 0 0 0]]

Classification Report:
              precision    recall  f1-score   support

     granite       0.37      0.88      0.52         8
   limestone       0.86      0.75      0.80         8
       other       0.25      0.20      0.22         5
   sandstone       0.00      0.00      0.00         9

    accuracy                           0.47        30
   macro avg       0.37      0.46      0.39        30
weighted avg       0.37      0.47      0.39        30



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
