In [18]:
import torch
import torch.nn as nn
import numpy as np

from collections import deque
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple
import joblib

## Mineral classifier 

In [19]:

n_class = 14

class UncertaintyAwareCNN1D(nn.Module):
    def __init__(self, n_class=n_class, dropout_rate=0.3):
        super(UncertaintyAwareCNN1D, self).__init__()
        
        # Core layers - similar to your original model
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        
        # Add dropout layers for uncertainty estimation
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.dropout3 = nn.Dropout(dropout_rate)
        
        self.fc1 = nn.Linear(32 * (1000 // 4), 100)
        self.fc2 = nn.Linear(100, n_class)
        
    def forward(self, x, enable_dropout=False):
        if enable_dropout:
            self.train()  # Enable dropout even during evaluation
        else:
            self.eval()
            
        x = x.unsqueeze(1)
        
        x = self.dropout1(torch.relu(self.conv1(x)))
        x = torch.max_pool1d(x, kernel_size=2)
        
        x = self.dropout2(torch.relu(self.conv2(x)))
        x = torch.max_pool1d(x, kernel_size=2)
        
        x = x.view(x.size(0), -1)
        x = self.dropout3(torch.relu(self.fc1(x)))
        x = self.fc2(x)
        
        return x
    


## Rock Expert system 3
- Adapt into new inference rules

In [20]:
import torch
import numpy as np
import joblib
import matplotlib.pyplot as plt
from pathlib import Path
import pandas as pd
from typing import Dict, List, Tuple
from collections import deque
from enum import Enum
from dataclasses import dataclass

class RockType(Enum):
    GRANITE = 'granite'
    LIMESTONE = 'limestone'
    SANDSTONE = 'sandstone'
    OTHER = 'other'


@dataclass
class MineralGroups:
    """Mineral groups definition following Γ notation"""
    def __init__(self):
        self.feldspars = {'albite', 'anorthite', 'orthoclase', 'sanidine'}  # Γ_f
        self.quartz = {'quartz'}  # Γ_q
        self.micas = {'annite', 'eastonite', 'margarite', 'muscovite', 'phlogopite'}  # Γ_m
        self.calcite = {'calcite'}  # Γ_c
        self.pyrite = {'pyrite'}  # Γ_p
        self.rutile = {'rutile'}  # Γ_r
        self.tourmaline = {'tourmaline'}  # Γ_t


class IntegratedRockClassifier:
    def __init__(self, model, label_encoder, device, 
                 entropy_threshold: float =1.3,  # Add appropriate default
                 variance_threshold: float = 0.07,  # Add appropriate default
                 window_size: int = 10):
        """
        Initialize integrated classifier with uncertainty-aware mineral classification
        
        Args:
            model: Trained UncertaintyAwareCNN1D model
            label_encoder: Fitted LabelEncoder for mineral labels
            device: PyTorch device (cuda/cpu)
            entropy_threshold: Threshold for entropy-based uncertainty
            variance_threshold: Threshold for variance-based uncertainty
            window_size: Number of measurements to consider (fixed at 10 per rules)
        """
        self.model = model
        self.label_encoder = label_encoder
        self.device = device
        self.window_size = 10  # Fixed per rules
        self.entropy_threshold = entropy_threshold
        self.variance_threshold = variance_threshold
        self.mineral_groups = MineralGroups()
        
        # History tracking
        self.measurements = []
        self.prediction_history = []
        self.uncertainty_history = []
        self.ground_truth_history = []
        self.analysis_history = []


    def predict_mineral(self, spectrum: torch.Tensor) -> Tuple[str, np.ndarray, float, float]:
        """
        Predict mineral from spectrum using uncertainty-aware model
        
        Returns:
            Tuple containing:
            - predicted_mineral: str
            - probabilities: np.ndarray
            - entropy: float
            - variance: float
        """
        self.model.eval()
        predictions = []
        
        # Perform MC Dropout sampling
        for _ in range(20):  # reduced from 30
            with torch.no_grad():
                if len(spectrum.shape) == 1:
                    spectrum = spectrum.unsqueeze(0)
                outputs = self.model(spectrum.to(self.device), enable_dropout=True)
                probs = torch.softmax(outputs, dim=1)
                predictions.append(probs)
        
        predictions = torch.stack(predictions)
        mean_probs = predictions.mean(dim=0)
        variance = predictions.var(dim=0).mean(dim=1)
        entropy = -torch.sum(mean_probs * torch.log(mean_probs + 1e-10), dim=1)

        # Debug print
        print(f"Entropy: {entropy.item():.4f}, Variance: {variance.item():.4f}")
        print(f"Max probability: {mean_probs.max().item():.4f}")
        
        # Determine if prediction is unknown
        # Only mark as unknown if BOTH entropy and variance are high
        is_unknown = (entropy > self.entropy_threshold) & (variance > self.variance_threshold)
        # OR if the max probability is very low
        is_unknown = is_unknown | (mean_probs.max() < 0.30)
        
        if is_unknown.any():
            return "unknown", mean_probs.cpu().numpy()[0], entropy.item(), variance.item()
        else:
            predicted_idx = mean_probs.argmax(dim=1).item()
            predicted_label = self.label_encoder.inverse_transform([predicted_idx])[0]
            return predicted_label, mean_probs.cpu().numpy()[0], entropy.item(), variance.item()

    def check_accuracy_rule(self, predictions: List[str], ground_truth: List[str]) -> Dict:
        """
        Modified accuracy rule to handle unknown predictions
        """
        if len(predictions) != 10 or len(ground_truth) != 10:
            return {
                'satisfied': False,
                'accuracy': 0.0,
                'correct_predictions': 0,
                'unknown_count': 0
            }
        
        unknown_count = sum(1 for pred in predictions if pred == "unknown")
        valid_predictions = [(p, t) for p, t in zip(predictions, ground_truth) if p != "unknown"]
        
        if not valid_predictions:
            return {
                'satisfied': False,
                'accuracy': 0.0,
                'correct_predictions': 0,
                'unknown_count': unknown_count
            }
        
        correct_predictions = sum(
            1 for pred, truth in valid_predictions
            if pred.lower() == truth.lower()
        )
        
        accuracy = correct_predictions / (len(valid_predictions) if valid_predictions else 1)
        
        return {
            'satisfied': accuracy >= 0.6 and unknown_count <= 3,  # Allow some unknowns
            'accuracy': accuracy,
            'correct_predictions': correct_predictions,
            'unknown_count': unknown_count
        }

    def check_mineral_assemblage_rules(self, predictions: List[str]) -> Dict:
        """
        Modified mineral assemblage rules to handle unknown predictions
        """
        # Filter out unknown predictions
        valid_predictions = [p for p in predictions if p != "unknown"]
        
        if len(valid_predictions) < 7:  # Require at least 7 valid predictions
            return {
                'satisfied': False,
                'rock_types': {
                    'granite': False,
                    'limestone': False,
                    'sandstone': False
                },
                'valid_predictions': len(valid_predictions)
            }
        
        predictions_lower = {pred.lower() for pred in valid_predictions}
        
        granite_rules = {
            'feldspar_present': bool(predictions_lower & self.mineral_groups.feldspars),
            'quartz_present': bool(predictions_lower & self.mineral_groups.quartz),
            'mica_present': bool(predictions_lower & self.mineral_groups.micas)
        }
        
        limestone_rules = {
            'calcite_present': bool(predictions_lower & self.mineral_groups.calcite),
            'low_silicates': len(predictions_lower & (self.mineral_groups.quartz | 
                                                    self.mineral_groups.feldspars)) <= 1
        }
        
        sandstone_rules = {
            'quartz_present': bool(predictions_lower & self.mineral_groups.quartz),
            'feldspar_present': bool(predictions_lower & self.mineral_groups.feldspars),
            'accessory_present': bool(predictions_lower & (self.mineral_groups.micas | 
                                                         self.mineral_groups.rutile | 
                                                         self.mineral_groups.tourmaline))
        }
        
        mineral_counts = {
            'feldspars': sum(1 for p in predictions if p.lower() in self.mineral_groups.feldspars),
            'quartz': sum(1 for p in predictions if p.lower() in self.mineral_groups.quartz),
            'micas': sum(1 for p in predictions if p.lower() in self.mineral_groups.micas),
            'calcite': sum(1 for p in predictions if p.lower() in self.mineral_groups.calcite),
            'pyrite': sum(1 for p in predictions if p.lower() in self.mineral_groups.pyrite),
            'rutile': sum(1 for p in predictions if p.lower() in self.mineral_groups.rutile),
            'tourmaline': sum(1 for p in predictions if p.lower() in self.mineral_groups.tourmaline)
        }

        is_granite = all(granite_rules.values())
        is_limestone = all(limestone_rules.values())
        is_sandstone = all(sandstone_rules.values())
        
        return {
            'satisfied': any([is_granite, is_limestone, is_sandstone]),
            'details': {
                'granite': granite_rules,
                'limestone': limestone_rules,
                'sandstone': sandstone_rules
            },
            'counts': mineral_counts,
            'rock_types': {
                'granite': is_granite,
                'limestone': is_limestone,
                'sandstone': is_sandstone
            },
            'valid_predictions': len(valid_predictions)
        }

    
    def process_spectrum(self, spectrum: torch.Tensor, true_mineral: str = None) -> Dict:
        """
        Process spectrum with uncertainty-aware classification
        
        Args:
            spectrum: Input spectrum tensor. Can be:
                     - 1D: [sequence_length]
                     - 2D: [batch_size, sequence_length]
                     - 3D: [batch_size, sequence_length, channels]
            true_mineral: Optional ground truth mineral label
        """
        # # Now spectrum should be in shape [batch, channels, length]
        # print(f"Spectrum shape after processing: {spectrum.shape}")
        
        # Get mineral prediction with uncertainty
        predicted_mineral, probabilities, entropy, variance = self.predict_mineral(spectrum)
        
        # Store prediction and uncertainty metrics
        self.prediction_history.append(predicted_mineral)
        self.uncertainty_history.append({'entropy': entropy, 'variance': variance})
        if true_mineral is not None:
            self.ground_truth_history.append(true_mineral)
        
        # Keep only the last 10 measurements
        if len(self.prediction_history) > 10:
            self.prediction_history.pop(0)
            self.uncertainty_history.pop(0)
            if self.ground_truth_history:
                self.ground_truth_history.pop(0)
        
        # Perform rock analysis
        if len(self.prediction_history) == 10:
            accuracy_analysis = self.check_accuracy_rule(
                self.prediction_history, 
                self.ground_truth_history if self.ground_truth_history else self.prediction_history
            )
            assemblage_analysis = self.check_mineral_assemblage_rules(
                self.prediction_history
            )
            
            # Determine final classification
            if accuracy_analysis['satisfied'] and assemblage_analysis['satisfied']:
                if assemblage_analysis['rock_types']['granite']:
                    classification = RockType.GRANITE
                elif assemblage_analysis['rock_types']['limestone']:
                    classification = RockType.LIMESTONE
                elif assemblage_analysis['rock_types']['sandstone']:
                    classification = RockType.SANDSTONE
                else:
                    classification = RockType.OTHER
            else:
                classification = RockType.OTHER
            
            rock_analysis = {
                'classification': classification.value,
                'accuracy_rule': accuracy_analysis,
                'assemblage_rules': assemblage_analysis,
                'uncertainty_metrics': {
                    'mean_entropy': np.mean([u['entropy'] for u in self.uncertainty_history]),
                    'mean_variance': np.mean([u['variance'] for u in self.uncertainty_history]),
                    'unknown_predictions': accuracy_analysis['unknown_count']
                }
            }
        else:
            rock_analysis = {
                'classification': RockType.OTHER.value,
                'status': 'Insufficient measurements',
                'current_count': len(self.prediction_history)
            }
        
        # Store analysis
        self.analysis_history.append({
            'mineral_prediction': predicted_mineral,
            'true_mineral': true_mineral,
            'uncertainty': {'entropy': entropy, 'variance': variance},
            'rock_analysis': rock_analysis,
            'measurement_number': len(self.prediction_history)
        })
        
        return {
            'mineral_prediction': predicted_mineral,
            'mineral_probabilities': probabilities,
            'uncertainty': {
                'entropy': entropy,
                'variance': variance
            },
            'rock_analysis': rock_analysis
        }
    
    
    def plot_analysis(self, rock_num, save_path: str = None):
        """
        Create comprehensive visualization of the analysis results
        """
        if not self.analysis_history:
            print("No data to plot")
            return
            
        plt.style.use('ggplot')
        # fig = plt.figure(figsize=(15, 12))
        fig = plt.figure(figsize=(10, 6), dpi=300)
        gs = plt.GridSpec(3, 1, height_ratios=[2, 1, 1])
        
        # Plot 1: Mineral Predictions vs Ground Truth
        ax1 = fig.add_subplot(gs[0])
        measurements = range(1, len(self.analysis_history) + 1)
        
        # Plot ground truth
        true_minerals = [a['true_mineral'] for a in self.analysis_history]
        predicted_minerals = [a['mineral_prediction'] for a in self.analysis_history]

        print("ground-truth",true_minerals)
        print("predictions",predicted_minerals)
        
        ax1.scatter(measurements, true_minerals, label='Ground Truth', 
                marker='o', s=100, alpha=0.6)
        ax1.scatter(measurements, predicted_minerals, label='Predicted',
                marker='x', s=100, alpha=0.8)
        
        ax1.set_title(f'Mineral Predictions vs Ground Truth of Test Sample {rock_num}', fontsize=14, fontweight='bold', pad=20)
        ax1.set_xlabel('Measurement Number', fontsize=12, fontweight='bold')
        ax1.set_ylabel('Mineral', fontsize=12, fontweight='bold')
        ax1.legend(loc='upper left')
        plt.setp(ax1.get_xticklabels())
        
        # # Plot 2: Classification Results per Window
        # # this need to modify to adapt for more than one rock type
        # ax2 = fig.add_subplot(gs[1])
        # window_results = []
        # window_start = 0
        
        # for i in range(len(self.analysis_history)):
        #     if i >= 9:  # We have a full window of 10 measurements
        #         window_true = true_minerals[window_start:i+1]
        #         window_pred = predicted_minerals[window_start:i+1]
                
        #         # Calculate accuracy
        #         correct_count = sum(1 for t, p in zip(window_true, window_pred) if t == p)
        #         accuracy = correct_count / 10
                
        #         # Check mineral assemblage
        #         assemblage = self.check_mineral_assemblage_rules(window_pred)
                
        #         # Store window results
        #         window_results.append({
        #             'window': f'{window_start+1}-{i+1}',
        #             'accuracy': accuracy,
        #             'accuracy_satisfied': accuracy >= 0.6,
        #             'assemblage_satisfied': assemblage['satisfied'],
        #             'is_granite': accuracy >= 0.6 and assemblage['satisfied']
        #         })
        #         window_start += 1
        
        # if window_results:
        #     # Create classification result visualization
        #     windows = [w['window'] for w in window_results]
        #     x = np.arange(len(windows))
        #     bar_width = 0.35
            
        #     # Plot accuracy bars
        #     accuracy_bars = ax2.bar(x - bar_width/2, 
        #                         [w['accuracy'] for w in window_results],
        #                         bar_width, label='Accuracy', color='lightblue')
            
        #     # Add accuracy threshold line
        #     ax2.axhline(y=0.6, color='red', linestyle='--', label='Required Accuracy (60%)')
            
        #     # Add granite classification markers
        #     for i, result in enumerate(window_results):
        #         if result['is_granite']:
        #             ax2.text(i, 1.05, '✓ Granite', ha='center', va='bottom', color='green')
        #         elif not result['is_granite']:
        #             ax2.text(i, 1.05, '✗ Granite', ha='center', va='bottom', color='red')
        #         if result['is_granite']:
        #             ax2.text(i, 1.05, '✓ Granite', ha='center', va='bottom', color='green')
        #         elif not result['is_granite']:
        #             ax2.text(i, 1.05, '✗ Granite', ha='center', va='bottom', color='red')

        #         else:
        #             ax2.text(i, 1.05, 'Other', ha='center', va='bottom', color='black')
            
        #     # Add accuracy labels
        #     for i, rect in enumerate(accuracy_bars):
        #         height = rect.get_height()
        #         ax2.text(rect.get_x() + rect.get_width()/2., height,
        #                 f'{height:.0%}',
        #                 ha='center', va='bottom')
            
        #     ax2.set_title('Classification Results per Window')
        #     ax2.set_xlabel('Measurement Windows')
        #     ax2.set_ylabel('Accuracy')
        #     ax2.set_xticks(x)
        #     ax2.set_xticklabels(windows, rotation=45)
        #     ax2.set_ylim(0, 1.2)  # Make room for granite labels
        #     ax2.legend()
        
        # # Plot 3: Mineral Assemblage Analysis (Bar Chart)
        # ax3 = fig.add_subplot(gs[2])
        
        # # Get mineral counts for each window
        # window_data = []
        # window_start = 0
        
        # for i in range(len(self.analysis_history)):
        #     if i >= 9:  # We have a full window
        #         window_pred = predicted_minerals[window_start:i+1]
        #         assemblage = self.check_mineral_assemblage_rules(window_pred)
        #         counts = assemblage['counts']
        #         window_data.append({
        #             'window': f'{window_start+1}-{i+1}',
        #             'counts': counts
        #         })
        #         window_start += 1
        
        # if window_data:
        #     # Prepare data for stacked bar chart
        #     windows = [d['window'] for d in window_data]
        #     feldspar_counts = [d['counts']['feldspars'] for d in window_data]
        #     quartz_counts = [d['counts']['quartz'] for d in window_data]
        #     mica_counts = [d['counts']['micas'] for d in window_data]
            
        #     # Create stacked bar chart
        #     bar_width = 0.8
            
        #     # Plot bars with different colors
        #     bars1 = ax3.bar(range(len(windows)), feldspar_counts, bar_width,
        #                 label='Feldspars', color='#1f77b4', alpha=0.7)
        #     bars2 = ax3.bar(range(len(windows)), quartz_counts, bar_width,
        #                 bottom=feldspar_counts, label='Quartz', color='#2ca02c', alpha=0.7)
            
        #     # Calculate the total height for the mica bars
        #     total_height = [f + q for f, q in zip(feldspar_counts, quartz_counts)]
        #     bars3 = ax3.bar(range(len(windows)), mica_counts, bar_width,
        #                 bottom=total_height, label='Micas', color='#ff7f0e', alpha=0.7)
            
        #     # Customize the plot
        #     ax3.set_title('Mineral Group Counts per Window')
        #     ax3.set_xlabel('Measurement Windows')
        #     ax3.set_ylabel('Count')
        #     ax3.set_xticks(range(len(windows)))
        #     ax3.set_xticklabels(windows, rotation=45)
        #     ax3.legend()
            
        #     # Add count labels on the bars
        #     def add_labels(bars, bottom=None):
        #         for bar in bars:
        #             height = bar.get_height()
        #             if height > 0:  # Only add label if there's a count
        #                 y_pos = bar.get_y() + height/2.
        #                 ax3.text(bar.get_x() + bar.get_width()/2., y_pos,
        #                         f'{int(height)}',
        #                         ha='center', va='center')
            
        #     add_labels(bars1)
        #     add_labels(bars2)
        #     add_labels(bars3)
            
        #     # Add horizontal line for minimum requirements
        #     ax3.axhline(y=1, color='r', linestyle='--', alpha=0.5, 
        #                 label='Minimum Required (1)')
            
        # plt.tight_layout()
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        else:
            plt.show()
        
        plt.close()


In [21]:
def save_analysis(result, filename="rock_analysis_results.txt"):
    """
    Save rock analysis results to a text file
    
    Args:
        result (dict): Analysis result dictionary
        filename (str): Name of output file
    """
    with open(filename, 'w') as f:
        # Print analysis if we have a full window
        if 'accuracy_rule' in result['rock_analysis']:
            f.write(f"\nAnalysis Results:\n")
            f.write(f"Classification: {result['rock_analysis']['classification']}\n")
            f.write(f"Accuracy: {result['rock_analysis']['accuracy_rule']['accuracy']:.1%}\n")
            f.write("\nMineral Assemblage:\n")
            for rule, satisfied in result['rock_analysis']['assemblage_rules']['details'].items():
                f.write(f"- {rule}: {'✓' if satisfied else '✗'}\n")


In [22]:


def main():
    # Device configuration
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Model configuration and loading
    best_model_path = f'./models/best_model.pth'
    
    # Initialize model and load weights
    best_model = UncertaintyAwareCNN1D().to(device)  # Make sure to import your model class
    best_checkpoint = torch.load(best_model_path)
    best_model.load_state_dict(best_checkpoint['model_state_dict'])
    best_model.eval()
    
    # Load label encoder
    label_encoder = joblib.load('mineral_label_encoder.joblib')
    
    
    # Initialize rock analyzer
    rock_analyzer = IntegratedRockClassifier(
        model=best_model,
        label_encoder=label_encoder,
        device=device,
    )
    
    # Load and process data
    base_path = Path("/home/iyeszin/Desktop/my git/minerals-rocks-relationship/all about rruff/validation_rocks")
    results = []
    
    for rock_num in range(1, 31):
        rock_analyzer.analysis_history = []  # Reset analysis history

        # Process rock
        final_classification = None
        final_accuracy = None

        print(f"\nProcessing Rock {rock_num}")
        rock_folder = f"rock_{rock_num:02d}"  # This will give "rock_01", "rock_02", etc.
        rock_path = base_path / rock_folder
        wavelengths = np.load(f"{rock_path}_wavelengths.npy")
        intensities = np.load(f"{rock_path}_intensities.npy")
        
        # Load true composition
        true_compositions = []
        composition_file = rock_path.with_name(f"{rock_path.stem}_composition.txt")
        with open(composition_file, 'r') as f:
            next(f)  # Skip header
            true_compositions = [line.strip().split('\t')[1] for line in f]
        
        # Process spectra and collect results
        for i, spectrum in enumerate(intensities):
            # Prepare spectrum tensor
            spectrum_tensor = torch.from_numpy(spectrum).float()
            
            # Process spectrum
            result = rock_analyzer.process_spectrum(spectrum_tensor, true_compositions[i])
            
            # Print analysis if we have a full window
            if 'accuracy_rule' in result['rock_analysis']:
                final_classification = result['rock_analysis']['classification']
                final_accuracy = result['rock_analysis']['accuracy_rule']['accuracy']

                print(f"\nAnalysis for measurements 1-10:")
                print(f"Classification: {result['rock_analysis']['classification']}")
                print(f"Accuracy: {result['rock_analysis']['accuracy_rule']['accuracy']:.1%}")
                print("Mineral Assemblage:")
                for rule, satisfied in result['rock_analysis']['assemblage_rules']['details'].items():
                    print(f"- {rule}: {'✓' if satisfied else '✗'}")
            
        # Append only once per rock
        if final_classification and final_accuracy:
            results.append({
                'rock_num': rock_num,
                'classification': final_classification, 
                'accuracy': final_accuracy
            })
                
        
        save_analysis(result, 'rock'+str(rock_num)+'_analysis_results.txt')  # Saves to 'rock_analysis_results.txt'


        # Create final visualization
        rock_analyzer.plot_analysis(rock_num, 'rock_analysis'+str(rock_num)+'.png')

    # Save results after each rock
    pd.DataFrame(results).to_csv('rock_classifications.csv', index=False)

if __name__ == "__main__":
    main()

  best_checkpoint = torch.load(best_model_path)



Processing Rock 1
Entropy: 0.9809, Variance: 0.0393
Max probability: 0.6504
Entropy: 0.5416, Variance: 0.0168
Max probability: 0.8652
Entropy: 0.3764, Variance: 0.0123
Max probability: 0.9061
Entropy: 0.1985, Variance: 0.0071
Max probability: 0.9500
Entropy: 1.5855, Variance: 0.0318
Max probability: 0.4479
Entropy: 1.1943, Variance: 0.0316
Max probability: 0.6454
Entropy: 0.0000, Variance: 0.0000
Max probability: 1.0000
Entropy: 1.5295, Variance: 0.0560
Max probability: 0.3468
Entropy: 1.5610, Variance: 0.0337
Max probability: 0.4802
Entropy: 1.5275, Variance: 0.0508
Max probability: 0.4592

Analysis for measurements 1-10:
Classification: granite
Accuracy: 90.0%
Mineral Assemblage:
- granite: ✓
- limestone: ✓
- sandstone: ✓
ground-truth ['Albite', 'Anorthite', 'Quartz', 'Quartz', 'Annite', 'Muscovite', 'Quartz', 'Albite', 'Annite', 'Orthoclase']
predictions ['Albite', 'Anorthite', 'Quartz', 'Quartz', 'Annite', 'Muscovite', 'Quartz', 'Pyrite', 'Annite', 'Orthoclase']
  agg_filter: a fi

In [23]:
def analyze_results():
    import pandas as pd
    from sklearn.metrics import confusion_matrix, classification_report
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    # Load saved classifications
    results_df = pd.read_csv('rock_classifications.csv')
    
    # Ground truth list
    # Borderline is consider the rock type
    # Not the rock type is consider other
    ground_truths = [
        'granite', 'granite', 'granite', 'granite', 'granite', 
        'granite', 'granite', 'granite', 'other', 'other',
        'sandstone', 'sandstone', 'sandstone', 'sandstone', 'sandstone',
        'sandstone', 'other', 'sandstone', 'sandstone', 'sandstone',
        'limestone', 'other', 'limestone', 'other', 'limestone',
        'limestone', 'limestone', 'limestone', 'limestone', 'limestone'
    ]
    
    # Create ground truth dataframe
    ground_truth_df = pd.DataFrame({
        'rock_num': range(1, 31),
        'ground_truth': ground_truths
    })
    
    # Merge and analyze
    final_df = pd.merge(results_df, ground_truth_df, on='rock_num')
    
    cm = confusion_matrix(final_df['ground_truth'], final_df['classification'])
    labels = ['Granite', 'Limestone', 'Sandstone', 'Other']
    print("\nConfusion Matrix:")
    print(cm)

    # Plot confusion matrix
    plt.figure(figsize=(10,8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=labels, 
            yticklabels=labels)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True Label')
    plt.savefig('confusion_matrix.png')
    plt.close()
    
    print("\nClassification Report:")
    print(classification_report(final_df['ground_truth'], final_df['classification']))


    final_df.to_csv('rock_classification_analysis.csv', index=False)

if __name__ == "__main__":
    analyze_results()


Confusion Matrix:
[[8 0 0 0]
 [0 5 3 0]
 [3 1 1 0]
 [8 0 0 1]]

Classification Report:
              precision    recall  f1-score   support

     granite       0.42      1.00      0.59         8
   limestone       0.83      0.62      0.71         8
       other       0.25      0.20      0.22         5
   sandstone       1.00      0.11      0.20         9

    accuracy                           0.50        30
   macro avg       0.63      0.48      0.43        30
weighted avg       0.68      0.50      0.45        30

