## Hybrid classifier - baseline + expert system

In [9]:
import torch
import joblib
from pathlib import Path
import numpy as np
import pandas as pd
from rocks_evaluation import SimpleCNN1D, UncertaintyAwareCNN1D, IntegratedRockClassifier, UncertaintyIntegratedRockClassifier, RockType, MineralGroups, save_analysis, plot_mineral_analysis


In [5]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

baseline_model = SimpleCNN1D()  
# uncertainty_model = UncertaintyAwareCNN1D()

# For standard classification


# Model configuration and loading
best_repeat = 0
best_fold = 0
best_overall_accuracy = 0.9927
best_model_path = f'./weights/model_repeat{best_repeat}_fold{best_fold}_acc{best_overall_accuracy:.4f}.pth'

# Initialize model and load weights
best_model = baseline_model.to(device)  # Make sure to import your model class
best_checkpoint = torch.load(best_model_path, weights_only=False)
best_model.load_state_dict(best_checkpoint['model_state_dict'])
best_model.eval()

# Load label encoder
label_encoder = joblib.load('mineral_label_encoder.joblib')

# Initialize rock analyzer
rock_analyzer = IntegratedRockClassifier(
    model=best_model,
    label_encoder=label_encoder,
    device=device,
)



In [10]:
# Load and process data
base_path = Path("validation_rocks")
results = []

for rock_num in range(1, 31):
    rock_analyzer.analysis_history = []  # Reset analysis history

    # Process rock
    final_classification = None
    final_accuracy = None

    print(f"\nProcessing Rock {rock_num}")
    rock_folder = f"rock_{rock_num:02d}"  # This will give "rock_01", "rock_02", etc.
    rock_path = base_path / rock_folder
    wavelengths = np.load(f"{rock_path}_wavelengths.npy")
    intensities = np.load(f"{rock_path}_intensities.npy")
    
    # Load true composition
    true_compositions = []
    composition_file = rock_path.with_name(f"{rock_path.stem}_composition.txt")
    with open(composition_file, 'r') as f:
        next(f)  # Skip header
        true_compositions = [line.strip().split('\t')[1] for line in f]
    
    # Process spectra and collect results
    for i, spectrum in enumerate(intensities):
        # Prepare spectrum tensor
        spectrum_tensor = torch.from_numpy(spectrum).float()
        
        # Process spectrum
        result = rock_analyzer.process_spectrum(spectrum_tensor, true_compositions[i])
        
        # Print analysis if we have a full window
        if 'accuracy_rule' in result['rock_analysis']:
            final_classification = result['rock_analysis']['classification']
            final_accuracy = result['rock_analysis']['accuracy_rule']['accuracy']

            print(f"\nAnalysis for measurements 1-10:")
            print(f"Classification: {result['rock_analysis']['classification']}")
            print(f"Accuracy: {result['rock_analysis']['accuracy_rule']['accuracy']:.1%}")
            print("Mineral Assemblage:")
            for rule, satisfied in result['rock_analysis']['assemblage_rules']['details'].items():
                print(f"- {rule}: {'✓' if satisfied else '✗'}")
        
    # Append only once per rock
    if final_classification and final_accuracy:
        results.append({
            'rock_num': rock_num,
            'classification': final_classification, 
            'accuracy': final_accuracy
        })
            
    
    save_analysis(result, 'rock'+str(rock_num)+'_analysis_results.txt')  # Saves to 'rock_analysis_results.txt'


    # Create final visualization
    rock_analyzer.plot_analysis(rock_num, 'rock_analysis'+str(rock_num)+'.png')

# Save results after each rock
pd.DataFrame(results).to_csv('rock_classifications.csv', index=False)



Processing Rock 1

Analysis for measurements 1-10:
Classification: limestone
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: limestone
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: other
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: other
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurements 1-10:
Classification: granite
Accuracy: 100.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓

Analysis for measurem

In [7]:


def main():
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Model configuration and loading
    best_repeat = 0
    best_fold = 0
    best_overall_accuracy = 0.9927
    best_model_path = f'./models/model_repeat{best_repeat}_fold{best_fold}_acc{best_overall_accuracy:.4f}.pth'
    
    # Initialize model and load weights
    best_model = SimpleCNN1D().to(device)  # Make sure to import your model class
    best_checkpoint = torch.load(best_model_path)
    best_model.load_state_dict(best_checkpoint['model_state_dict'])
    best_model.eval()
    
    # Load label encoder
    label_encoder = joblib.load('mineral_label_encoder.joblib')
    
    # Initialize rock analyzer
    rock_analyzer = IntegratedRockClassifier(
        model=best_model,
        label_encoder=label_encoder,
        device=device,
    )
    
    # Load and process data
    # rock_num = 1
    base_path = Path("/home/iyeszin/Desktop/my git/minerals-rocks-relationship/all about rruff/validation_rocks")
    results = []
    
    for rock_num in range(1, 31):
        rock_analyzer.analysis_history = []  # Reset analysis history

        # Process rock
        final_classification = None
        final_accuracy = None

        print(f"\nProcessing Rock {rock_num}")
        rock_folder = f"rock_{rock_num:02d}"  # This will give "rock_01", "rock_02", etc.
        rock_path = base_path / rock_folder
        wavelengths = np.load(f"{rock_path}_wavelengths.npy")
        intensities = np.load(f"{rock_path}_intensities.npy")
        
        # Load true composition
        true_compositions = []
        composition_file = rock_path.with_name(f"{rock_path.stem}_composition.txt")
        with open(composition_file, 'r') as f:
            next(f)  # Skip header
            true_compositions = [line.strip().split('\t')[1] for line in f]
        
        # Process spectra and collect results
        for i, spectrum in enumerate(intensities):
            # Prepare spectrum tensor
            spectrum_tensor = torch.from_numpy(spectrum).float()
            
            # Process spectrum
            result = rock_analyzer.process_spectrum(spectrum_tensor, true_compositions[i])
            
            # Print analysis if we have a full window
            if 'accuracy_rule' in result['rock_analysis']:
                final_classification = result['rock_analysis']['classification']
                final_accuracy = result['rock_analysis']['accuracy_rule']['accuracy']

                print(f"\nAnalysis for measurements 1-10:")
                print(f"Classification: {result['rock_analysis']['classification']}")
                print(f"Accuracy: {result['rock_analysis']['accuracy_rule']['accuracy']:.1%}")
                print("Mineral Assemblage:")
                for rule, satisfied in result['rock_analysis']['assemblage_rules']['details'].items():
                    print(f"- {rule}: {'✓' if satisfied else '✗'}")
            
        # Append only once per rock
        if final_classification and final_accuracy:
            results.append({
                'rock_num': rock_num,
                'classification': final_classification, 
                'accuracy': final_accuracy
            })
                
        
        save_analysis(result, 'rock'+str(rock_num)+'_analysis_results.txt')  # Saves to 'rock_analysis_results.txt'


        # Create final visualization
        rock_analyzer.plot_analysis(rock_num, 'rock_analysis'+str(rock_num)+'.png')

    # Save results after each rock
    pd.DataFrame(results).to_csv('rock_classifications.csv', index=False)

if __name__ == "__main__":
    main()

  best_checkpoint = torch.load(best_model_path)



Processing Rock 1

Analysis for measurements 1-10:
Classification: granite
Accuracy: 80.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓
ground-truth ['Albite', 'Anorthite', 'Quartz', 'Quartz', 'Annite', 'Muscovite', 'Quartz', 'Albite', 'Annite', 'Orthoclase']
predictions ['Albite', 'Anorthite', 'Quartz', 'Quartz', 'Annite', 'Muscovite', 'Quartz', 'Anorthite', 'Annite', 'Sanidine']
  agg_filter: a filter function, which takes a (m, n, 3) float array and a dpi value, and returns a (m, n, 3) array and two offsets from the bottom left corner of the image
  alpha: scalar or None
  animated: bool
  antialiased: bool
  backgroundcolor: color
  bbox: dict with properties for `.patches.FancyBboxPatch`
  clip_box: `~matplotlib.transforms.BboxBase` or None
  clip_on: bool
  clip_path: Patch or (Path, Transform) or None
  color or c: color
  figure: `~matplotlib.figure.Figure`
  fontfamily or family or fontname: {FONTNAME, 'serif', 'sans-serif', 'cursive', 'fantasy', 'monospace'

In [8]:
def analyze_results():
    import pandas as pd
    from sklearn.metrics import confusion_matrix, classification_report
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # Load saved classifications
    results_df = pd.read_csv('rock_classifications.csv')
    
    # Ground truth list
    # Borderline is consider the rock type
    # Not the rock type is consider other
    ground_truths = [
        'granite', 'granite', 'granite', 'granite', 'granite', 
        'granite', 'granite', 'granite', 'other', 'other',
        'sandstone', 'sandstone', 'sandstone', 'sandstone', 'sandstone',
        'sandstone', 'other', 'sandstone', 'sandstone', 'sandstone',
        'limestone', 'other', 'limestone', 'other', 'limestone',
        'limestone', 'limestone', 'limestone', 'limestone', 'limestone'
    ]
    
    # Create ground truth dataframe
    ground_truth_df = pd.DataFrame({
        'rock_num': range(1, 31),
        'ground_truth': ground_truths
    })
    
    # Merge and analyze
    final_df = pd.merge(results_df, ground_truth_df, on='rock_num')
    
    cm = confusion_matrix(final_df['ground_truth'], final_df['classification'])
    labels = ['Granite', 'Limestone', 'Sandstone', 'Other']
    print("\nConfusion Matrix:")
    print(cm)

    # Plot confusion matrix
    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=labels, 
            yticklabels=labels)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True Label')
    plt.savefig('confusion_matrix.png')
    plt.close()
    
    print("\nClassification Report:")
    print(classification_report(final_df['ground_truth'], final_df['classification']))


    final_df.to_csv('rock_classification_analysis.csv', index=False)

if __name__ == "__main__":
    analyze_results()


Confusion Matrix:
[[8 0 0 0]
 [0 6 2 0]
 [3 1 1 0]
 [9 0 0 0]]

Classification Report:
              precision    recall  f1-score   support

     granite       0.40      1.00      0.57         8
   limestone       0.86      0.75      0.80         8
       other       0.33      0.20      0.25         5
   sandstone       0.00      0.00      0.00         9

    accuracy                           0.50        30
   macro avg       0.40      0.49      0.41        30
weighted avg       0.39      0.50      0.41        30



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
