# Project 3 Codebase
Setup — install the Azure AI client packages into the kernel's environment (run once):
- `%pip install azure-ai-inference`
- `%pip install azure-core`



## Plot Molecules

### Plot a list of SMILES

In [None]:
# NOTE(review): neither `Chem` nor `smile` is defined when this cell runs on a
# fresh kernel (rdkit is first imported in the next cell and no top-level
# `smile` is assigned above), so Restart & Run All raises NameError here.
# TODO: pass a concrete SMILES string or delete this cell.
Chem.MolFromSmiles(smile)

In [None]:
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from rdkit.Chem import DataStructs

# Define the two SMILES strings to compare
smiles1 = 'Cc1ccc2c(c1C)N=Nc1c(F)c(F)nc(F)c1C2'
smiles2 = 'Cc1ccc2c(c1C)N=NC(=N)c1c(F)c(F)nc(F)c1C2'

# Convert SMILES to RDKit molecules. MolFromSmiles returns None (it does not
# raise) for unparsable input; without this check the failure would only
# surface later, as an argument error from the first fingerprint call.
mol1 = Chem.MolFromSmiles(smiles1)
mol2 = Chem.MolFromSmiles(smiles2)
for _name, _smi, _mol in (("smiles1", smiles1, mol1), ("smiles2", smiles2, mol2)):
    if _mol is None:
        raise ValueError(f"RDKit could not parse {_name}: {_smi!r}")

# Calculate different types of fingerprints for comparison
# Morgan fingerprints, radius 2, 2048 bits (radius 2 corresponds to ECFP4).
# NOTE(review): newer RDKit releases steer users toward rdFingerprintGenerator
# for Morgan fingerprints; confirm against the installed version.
morgan_fp1 = AllChem.GetMorganFingerprintAsBitVect(mol1, 2, nBits=2048)
morgan_fp2 = AllChem.GetMorganFingerprintAsBitVect(mol2, 2, nBits=2048)
morgan_similarity = DataStructs.TanimotoSimilarity(morgan_fp1, morgan_fp2)

# Topological fingerprints (similar to Daylight)
topo_fp1 = AllChem.RDKFingerprint(mol1)
topo_fp2 = AllChem.RDKFingerprint(mol2)
topo_similarity = DataStructs.TanimotoSimilarity(topo_fp1, topo_fp2)

# MACCS keys
maccs_fp1 = AllChem.GetMACCSKeysFingerprint(mol1)
maccs_fp2 = AllChem.GetMACCSKeysFingerprint(mol2)
maccs_similarity = DataStructs.TanimotoSimilarity(maccs_fp1, maccs_fp2)

# AtomPair fingerprints
atom_pair_fp1 = AllChem.GetAtomPairFingerprint(mol1)
atom_pair_fp2 = AllChem.GetAtomPairFingerprint(mol2)
atom_pair_similarity = DataStructs.TanimotoSimilarity(atom_pair_fp1, atom_pair_fp2)

# Print results
print(f"Compound 1: {smiles1}")
print(f"Compound 2: {smiles2}")
print("\nTanimoto Similarity Scores:")
print(f"Morgan (ECFP4) Similarity: {morgan_similarity:.4f}")
print(f"Topological Similarity: {topo_similarity:.4f}")
print(f"MACCS Keys Similarity: {maccs_similarity:.4f}")
print(f"Atom Pair Similarity: {atom_pair_similarity:.4f}")

# Unweighted mean of the four Tanimoto scores above
avg_similarity = (morgan_similarity + topo_similarity + maccs_similarity + atom_pair_similarity) / 4
print(f"\nAverage Tanimoto Similarity: {avg_similarity:.4f}")

In [None]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt

# List of SMILES strings to plot (50 entries, drawn in list order)
smiles_list = [             "CCN(CC)CCOC(=O)c1ccccc1N",
          "CCN(C)CCCOC(=O)c1ccccc1N",
          "CCN(CC)CCOC(=O)c1ccc(N)cc1",
          "CCN(C)C(C)COC(=O)c1ccccc1N",
          "CCN(CC)CCOC(=O)c1cccc(N)c1",
          "CCN(CC)CCOC(=O)Nc1ccccc1",
          "CCC(CNC)COC(=O)c1ccccc1N",
          "CCC(CO)N(CC)C(=O)c1ccccc1N",
          "CCN(CCC#N)CCOC(=O)c1cc[nH]n1",
          "CCN(CC)CC(O)c1ccccc1C(N)=O",
          "CCN(C)CCCOC(=O)c1ccc(N)cc1",
          "CCOC(=O)c1cccc(CN(CC)CC)n1",
          "CCN(C)CCCOC(=O)c1cccc(N)c1",
          "CNCCCOC(=O)c1cccc(C=O)c1N",
          "CC(=O)N(CCOC(=O)c1ccc[nH]1)C1CC1",
          "CCN(CC=O)CC(CO)c1ccccc1N",
          "CCN(CCOC(=O)C1=CC=NC1=O)C1CC1",
          "CCN(C)C(C)COC(=O)c1ccc(N)cc1",
          "CCN(C)C(C)COC(=O)c1cccc(N)c1",
          "CC(=O)CN(CCO)C(C)c1ccccc1N",
          "CCC(CO)CN=C(CO)c1ccccc1N",
          "CCN(C)CC(C)NC(=O)c1ccccc1O",
          "CCN(CC)CC=CCOC(=O)c1cc[nH]c1",
          "CCNC(=O)CCOC(=O)c1cccc(N)c1",
          "CCN(CC)CC(=O)c1cccc([N+](N)=O)c1",
          "CCOC(=O)CN(CC)c1ccccc1NC",
          "CCN(CC)CCOC1C=CC(C=O)=CC1=N",
          "CCC(=O)CNCC(CO)c1ccccc1N",
          "CCN(CCO)C(C=O)c1cc(C)ccc1N",
          "CNCCC(C)OC(=O)c1ccccc1CN",
          "C=CCN(CC)C(=O)COC(=O)c1cc[nH]c1",
          "CCN(CC)CC=C(C(=O)CO)c1cc[nH]c1",
          "CCN(C)CCC=CCOC(=O)c1cc[nH]c1",
          "CCCCC(=O)COC(=O)c1cccc(N)n1",
          "CCNCC(CC)OC(=O)c1cccc(N)c1",
          "CC(COC(=O)c1cccc(N)c1)CN(C)C",
          "CCN(C)CC(=C[N+])COC(=O)c1cc[nH]n1",
          "CCC(CNC)COC(=O)c1cccc(N)c1",
          "CCC(C)COC(=O)c1nccc(C=O)c1N",
          "C=CCCOC(=O)c1cc(C(=O)NCC)c[nH]1",
          "Nc1ccc(C(=O)N(CCO)CCC=O)cc1",
          "CCC(OCCNC)c1cccc(C=O)c1N",
          "CNCc1cccc(C(=O)OC(C)CNC)c1",
          "CC[N+]OCCCNN=Cc1ccccc1O",
          "CC1C(CO)CCN1C(=O)c1coccc1=N",
          "CCC(CO)N(C)CC=CN=c1cc[nH]cn1",
          "CC(N)CNC(=O)c1ccc(C(=O)CO)cc1",
          "CCC(=CCN(C)CC)c1c[nH]c(C(=O)O)c1",
          "CCC(C)N(O)CC=C(O)c1ccc(C)nc1",
          "C=CC[N+](=O)c1ccc(CN(C)CCC=O)[nH]1"]
        


# Convert SMILES to molecules, remembering each one's 1-based position in
# smiles_list so titles keep matching the list even if an entry fails to parse.
# MolFromSmiles returns None for unparsable input, which Compute2DCoords and
# MolToImage cannot handle, so such entries are reported and skipped.
indexed_mols = [(idx, Chem.MolFromSmiles(smile)) for idx, smile in enumerate(smiles_list, 1)]
bad_positions = [idx for idx, mol in indexed_mols if mol is None]
if bad_positions:
    print(f"Skipping unparsable SMILES at positions: {bad_positions}")
indexed_mols = [(idx, mol) for idx, mol in indexed_mols if mol is not None]
mols = [mol for _, mol in indexed_mols]

# Generate 2D coordinates for each molecule
for mol in mols:
    AllChem.Compute2DCoords(mol)

# Size the grid to the data: 4 molecules per row, ceil(n / 4) rows. A fixed
# 80x4 grid leaves most of the canvas blank for short lists, and add_subplot
# rejects indices above 320 for longer ones.
n_cols = 4
n_rows = max(1, (len(indexed_mols) + n_cols - 1) // n_cols)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also blanks any unused trailing cells
for ax, (idx, mol) in zip(axes.flat, indexed_mols):
    ax.imshow(Draw.MolToImage(mol))
    ax.set_title(f'Molecule {idx}')

plt.tight_layout()
plt.show()

# Optional: Save the figure
# fig.savefig('molecules.png', dpi=300, bbox_inches='tight')

# Optional helper: draw a list of SMILES as one RDKit grid image, with any
# strings (e.g. the SMILES themselves) as legends.
def display_mol_grid(smiles_list, legends=None, molsPerRow=5):
    """Build an RDKit grid image for `smiles_list`.

    Args:
        smiles_list: SMILES strings, drawn in order.
        legends: One caption per molecule; when None, captions are
            "Molecule 1", "Molecule 2", ...
        molsPerRow: Number of molecules per grid row.

    Returns:
        Whatever ``Draw.MolsToGridImage(..., returnPNG=False)`` returns. Nothing
        is displayed here; the caller decides whether to display or save it.
    """
    molecules = list(map(Chem.MolFromSmiles, smiles_list))
    if legends is None:
        legends = [f'Molecule {i+1}' for i in range(len(molecules))]
    return Draw.MolsToGridImage(
        molecules,
        molsPerRow=molsPerRow,
        subImgSize=(300, 300),
        legends=legends,
        returnPNG=False,
    )

# Display molecules in a grid with SMILES as legends
img = display_mol_grid(smiles_list, legends=smiles_list)
display(img)  # the cell would otherwise end on an assignment and render nothing
# img.save('molecules_grid.png')  # Uncomment to save the grid image

In [None]:
from rdkit import Chem

# Define the SMILES string of the molecule
smiles = "Cc1ccc2c(c1F)N=Nc1c(F)c(F)nc(F)c1C2"

# Convert the SMILES string to an RDKit molecule. MolFromSmiles returns None
# (rather than raising) for unparsable input, so fail with a clear message.
mol = Chem.MolFromSmiles(smiles)
if mol is None:
    raise ValueError(f"RDKit could not parse SMILES: {smiles!r}")

# Generate randomized (non-canonical) SMILES strings. No RDKit seed is set
# here, so the output may differ between runs and may contain duplicates.
non_canonical_smiles = [Chem.MolToSmiles(mol, doRandom=True) for _ in range(5)]

# Display the generated SMILES strings
for smi in non_canonical_smiles:
    print(smi)


### SMILES Comparisons

In [None]:
import pandas as pd
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
from rdkit.Chem import Draw

# Set RDKit to not display warnings. RDLogger is imported explicitly because
# attribute access on the bare package (rdkit.RDLogger) only works if some
# other import has already loaded that submodule.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Load the CSV file
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/test_data/test_smiles_with_nmr.csv'
df = pd.read_csv(file_path)

# Normalize the SMILES column name to 'smiles', accepting any capitalization
if 'smiles' not in df.columns:
    # str() tolerates non-string column labels (e.g. integer headers)
    smiles_col = next((col for col in df.columns if str(col).lower() == 'smiles'), None)
    if smiles_col is not None:
        df = df.rename(columns={smiles_col: 'smiles'})
    else:
        print(f"No 'smiles' column found. Available columns are: {list(df.columns)}")
        raise ValueError("SMILES column not found in the CSV file")

# Function to canonicalize SMILES
def canonicalize_smiles(smiles):
    """Return RDKit's canonical isomeric SMILES for `smiles`.

    Args:
        smiles: SMILES string. Non-string values (e.g. NaN from an empty CSV
            cell) are treated as errors.

    Returns:
        str: The canonical SMILES; "Invalid SMILES" if RDKit cannot parse the
        string; "Error parsing SMILES" for non-string input or if RDKit raises.
    """
    # Handle missing values explicitly instead of relying on RDKit raising.
    if not isinstance(smiles, str):
        return "Error parsing SMILES"
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES"
        return Chem.MolToSmiles(mol, isomericSmiles=True)
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
        return "Error parsing SMILES"

# Canonicalize every input SMILES into a new 'canonical_smiles' column
df['canonical_smiles'] = df['smiles'].apply(canonicalize_smiles)

# Side-by-side preview of the first rows
print("Original vs. Canonicalized SMILES:")
preview_columns = ['smiles', 'canonical_smiles']
display(df[preview_columns].head(10))

# Optional: report rows that RDKit could not canonicalize
failure_markers = ["Invalid SMILES", "Error parsing SMILES"]
failed_rows = df.loc[df['canonical_smiles'].isin(failure_markers)]
invalid_count = len(failed_rows)
if invalid_count > 0:
    print(f"Note: {invalid_count} invalid SMILES strings were found.")

# Optional: Visualize some molecules
def visualize_molecules(df, num_mols=5):
    """Display up to `num_mols` molecules from `df` in a one-column RDKit grid.

    Rows are scanned in order. Rows whose 'smiles' value is not a string, or
    that RDKit cannot parse, are skipped. Each legend shows the original SMILES
    and the row's 'canonical_smiles' value.

    Args:
        df: DataFrame with 'smiles' and 'canonical_smiles' columns.
        num_mols: Maximum number of molecules to draw.
    """
    valid_mols = []
    valid_legends = []

    for _, row in df.iterrows():
        if len(valid_mols) >= num_mols:
            break

        smiles = row['smiles']
        # Missing CSV cells arrive as NaN floats, which MolFromSmiles cannot accept.
        if not isinstance(smiles, str):
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            valid_mols.append(mol)
            valid_legends.append(f"Original: {smiles}\nCanonical: {row['canonical_smiles']}")

    if valid_mols:
        # The grid is an image rendered directly by display(); no matplotlib
        # figure is involved, so none is created.
        img = Draw.MolsToGridImage(valid_mols, molsPerRow=1, subImgSize=(600, 200), legends=valid_legends)
        display(img)
    else:
        print("No valid molecules to display")

# Uncomment to visualize molecules
# visualize_molecules(df)

In [None]:
import pandas as pd
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem
import matplotlib.pyplot as plt
from rdkit.Chem import Draw
import requests
import time

# Set RDKit to not display warnings. RDLogger is imported explicitly because
# attribute access on the bare package (rdkit.RDLogger) only works if some
# other import has already loaded that submodule.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Load the CSV file
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/test_data/test_smiles_with_nmr.csv'
df = pd.read_csv(file_path)

# Normalize the SMILES column name to 'smiles', accepting any capitalization
if 'smiles' not in df.columns:
    # str() tolerates non-string column labels (e.g. integer headers)
    smiles_col = next((col for col in df.columns if str(col).lower() == 'smiles'), None)
    if smiles_col is not None:
        df = df.rename(columns={smiles_col: 'smiles'})
    else:
        print(f"No 'smiles' column found. Available columns are: {list(df.columns)}")
        raise ValueError("SMILES column not found in the CSV file")

# Function to canonicalize SMILES
def canonicalize_smiles(smiles):
    """Return RDKit's canonical isomeric SMILES for `smiles`.

    Args:
        smiles: SMILES string. Non-string values (e.g. NaN from an empty CSV
            cell) are treated as errors.

    Returns:
        str: The canonical SMILES; "Invalid SMILES" if RDKit cannot parse the
        string; "Error parsing SMILES" for non-string input or if RDKit raises.
    """
    # Handle missing values explicitly instead of relying on RDKit raising.
    if not isinstance(smiles, str):
        return "Error parsing SMILES"
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES"
        return Chem.MolToSmiles(mol, isomericSmiles=True)
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
        return "Error parsing SMILES"

# Function standing in for a "SMILES.com" canonicalizer (simulated locally)
def get_smiles_com_canonical(smiles, delay=0.1):
    """Return a canonical SMILES computed locally with RDKit (simulated API).

    No external service is contacted. The string comes from
    ``Chem.MolToSmiles(mol, canonical=True, allHsExplicit=False, doRandom=False)``.
    Those flag values appear to be MolToSmiles' defaults, so for valid input
    this is expected to match ``Chem.MolToSmiles(mol, isomericSmiles=True)``.
    Verify before treating the two columns as independent.

    Args:
        smiles: SMILES string to canonicalize.
        delay: Seconds to sleep after a successful conversion to mimic API
            latency (default 0.1). Pass 0 to skip the sleep; this matters when
            the function is applied to every row of a large DataFrame.

    Returns:
        str: The canonical SMILES; "Invalid SMILES" if RDKit cannot parse the
        input; "Error: <message>" if an exception is raised.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return "Invalid SMILES"

        canonical = Chem.MolToSmiles(mol, canonical=True, allHsExplicit=False, doRandom=False)

        # Simulated network latency; 0 disables it.
        if delay > 0:
            time.sleep(delay)

        return canonical
    except Exception as e:
        return f"Error: {str(e)}"

# RDKit canonical form of every input SMILES
df['canonical_smiles'] = df['smiles'].apply(canonicalize_smiles)

# Simulated "SMILES.com" canonical form (can be slow on large frames)
print("Getting SMILES.com canonical forms (this may take a moment)...")
df['smiles_com_canonical'] = df['smiles'].apply(get_smiles_com_canonical)

# Side-by-side preview of the original and both canonical forms
print("Original vs. Canonicalized SMILES:")
preview_columns = ['smiles', 'canonical_smiles', 'smiles_com_canonical']
display(df[preview_columns].head(10))

# Optional: report rows that RDKit could not canonicalize
failed = df['canonical_smiles'].isin(["Invalid SMILES", "Error parsing SMILES"])
invalid_count = len(df.loc[failed])
if invalid_count > 0:
    print(f"Note: {invalid_count} invalid SMILES strings were found.")

# Optional: Visualize some molecules
def visualize_molecules(df, num_mols=5):
    """Display up to `num_mols` molecules from `df` in a one-column RDKit grid.

    Rows are scanned in order. Rows whose 'smiles' value is not a string, or
    that RDKit cannot parse, are skipped. Each legend shows the original SMILES,
    the 'canonical_smiles' value and the 'smiles_com_canonical' value.

    Args:
        df: DataFrame with 'smiles', 'canonical_smiles' and
            'smiles_com_canonical' columns.
        num_mols: Maximum number of molecules to draw.
    """
    valid_mols = []
    valid_legends = []

    for _, row in df.iterrows():
        if len(valid_mols) >= num_mols:
            break

        smiles = row['smiles']
        # Missing CSV cells arrive as NaN floats, which MolFromSmiles cannot accept.
        if not isinstance(smiles, str):
            continue
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            valid_mols.append(mol)
            # The simulated column is named 'smiles_com_canonical' (see the
            # apply() that creates it); 'smiles_canonical' does not exist.
            valid_legends.append(
                f"Original: {smiles}\nCanonical: {row['canonical_smiles']}\nSMILES.com: {row['smiles_com_canonical']}"
            )

    if valid_mols:
        # The grid is an image rendered directly by display(); no matplotlib
        # figure is involved, so none is created.
        img = Draw.MolsToGridImage(valid_mols, molsPerRow=1, subImgSize=(600, 200), legends=valid_legends)
        display(img)
    else:
        print("No valid molecules to display")

# Uncomment to visualize molecules
# visualize_molecules(df)

# Save the updated DataFrame back to the original file.
# NOTE(review): this overwrites the input CSV in place, adding the
# canonical_smiles / smiles_com_canonical columns.
import os

try:
    print(f"Saving updated data to {file_path}...")
    df.to_csv(file_path, index=False)
    print("File saved successfully!")
except Exception as e:
    print(f"Error saving file: {str(e)}")

    # Alternative: save next to the original if it cannot be written.
    # os.path.splitext swaps only the final extension. str.replace('.csv', ...)
    # would rewrite every '.csv' in the path, and for a path without '.csv' it
    # would return the same path, so the retry would hit the same failure.
    root, ext = os.path.splitext(file_path)
    alternative_path = f"{root}_updated{ext or '.csv'}"
    print(f"Attempting to save to alternative location: {alternative_path}")
    df.to_csv(alternative_path, index=False)
    print(f"File saved to alternative location: {alternative_path}")

In [None]:
import pandas as pd
import rdkit
from rdkit import Chem
import os

# Set RDKit to not display warnings. RDLogger is imported explicitly because
# attribute access on the bare package (rdkit.RDLogger) only works if some
# other import has already loaded that submodule.
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

# Load the CSV file
file_path = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/53_Lukas_real_data/cleaned_data_aug_CLEAN.csv'
df = pd.read_csv(file_path)

# Normalize the SMILES column name to 'smiles', accepting any capitalization
if 'smiles' not in df.columns:
    # str() tolerates non-string column labels (e.g. integer headers)
    smiles_col = next((col for col in df.columns if str(col).lower() == 'smiles'), None)
    if smiles_col is not None:
        print(f"Found SMILES column as '{smiles_col}', will be renamed to 'smiles'")
        df = df.rename(columns={smiles_col: 'smiles'})
    else:
        print(f"No 'smiles' column found. Available columns are: {list(df.columns)}")
        raise ValueError("SMILES column not found in the CSV file")

# Function to canonicalize SMILES, falling back to the input on failure
def canonicalize_smiles(smiles):
    """Return RDKit's canonical isomeric SMILES, or `smiles` itself on failure.

    Prints a warning when RDKit cannot parse the string and an error message
    when RDKit raises. In both cases the input is returned unchanged, so the
    column keeps its original value for that row.
    """
    try:
        parsed = Chem.MolFromSmiles(smiles)
        if parsed is None:
            print(f"Warning: Invalid SMILES: {smiles}")
            return smiles  # keep the original when it cannot be parsed
        return Chem.MolToSmiles(parsed, isomericSmiles=True)
    except Exception as exc:
        print(f"Error parsing SMILES {smiles}: {str(exc)}")
        return smiles  # keep the original when RDKit raises

# Make a backup of the original file before it is overwritten later in this cell.
# Copy the file on disk rather than writing `df` back out: `df` may already
# have had its SMILES column renamed above, and a pandas read/write round-trip
# can change formatting, so df.to_csv would not reproduce the original file.
import shutil

backup_path = file_path + '.backup'
if not os.path.exists(backup_path):
    print(f"Creating backup of original file at: {backup_path}")
    try:
        shutil.copy2(file_path, backup_path)
        print("Backup created successfully!")
    except Exception as e:
        print(f"Warning: Could not create backup: {str(e)}")

# Store original SMILES for comparison
original_smiles = df['smiles'].copy()

# Replace 'smiles' column with canonicalized form
print("Canonicalizing SMILES...")
df['smiles'] = df['smiles'].apply(canonicalize_smiles)

# Show a comparison of the first few rows
comparison_df = pd.DataFrame({
    'Original SMILES': original_smiles,
    'Canonicalized SMILES': df['smiles']
})
print("\nComparison of original vs. canonicalized SMILES:")
display(comparison_df.head(10))

# Count modifications. NaN != NaN, so rows that are missing both before and
# after (canonicalize_smiles returns its input unchanged on failure) would
# otherwise be counted as "modified".
both_missing = original_smiles.isna() & df['smiles'].isna()
modified_count = ((original_smiles != df['smiles']) & ~both_missing).sum()
print(f"\n{modified_count} out of {len(df)} SMILES strings were modified by canonicalization")

# Save the modified DataFrame back to the original file
try:
    print(f"\nSaving file with canonicalized SMILES to: {file_path}")
    df.to_csv(file_path, index=False)
    print("File saved successfully!")
except Exception as e:
    print(f"Error saving file: {str(e)}")

    # Alternative: save next to the original if it cannot be written.
    # os.path.splitext swaps only the final extension. str.replace('.csv', ...)
    # would rewrite every '.csv' in the path, and for a path without '.csv' it
    # would return the same path, so the retry would hit the same failure.
    root, ext = os.path.splitext(file_path)
    alternative_path = f"{root}_canonicalized{ext or '.csv'}"
    print(f"Attempting to save to alternative location: {alternative_path}")
    df.to_csv(alternative_path, index=False)
    print(f"File saved to alternative location: {alternative_path}")

In [None]:
    # smiles_variation_generator.py
    import random
    from rdkit import Chem
    from rdkit.Chem import AllChem
    import pandas as pd
    import numpy as np

    # Disable RDKit logging. RDLogger is imported explicitly because attribute
    # access on the bare `rdkit` package only works if some other import has
    # already loaded that submodule.
    import rdkit
    from rdkit import RDLogger
    RDLogger.DisableLog('rdApp.*')

    class SmilesVariationGenerator:
        """
        A class to generate non-canonical SMILES variations for molecules.

        This can be used to augment training data for the MST model by providing
        multiple non-canonical representations of the same molecule.

        NOTE(review): `seed` only seeds Python's `random` module, which drives the
        atom shuffles and start-atom choices below. ``MolToSmiles(..., doRandom=True)``
        randomizes inside RDKit's own code, so those outputs are presumably not
        reproducible from `seed` alone. Confirm this before relying on exact
        strings (``Chem.MolToRandomSmilesVect`` accepts an explicit ``randomSeed``).
        """

        # Sentinel returned by generate_variations when RDKit cannot parse the input.
        INVALID_SMILES_MARKER = "Invalid SMILES input"

        def __init__(self, seed=42):
            """
            Initialize the SMILES variation generator.

            Args:
                seed (int): Base seed for Python's `random` module; the methods
                    below re-seed from it with per-attempt offsets.
            """
            self.seed = seed
            random.seed(seed)

        def generate_variations(self, smiles, num_variations=20, max_attempts=1000):
            """
            Generate different non-canonical SMILES representations of the same molecule.

            Args:
                smiles (str): Input SMILES string
                num_variations (int): Maximum number of SMILES strings to return;
                    the canonical form is always included and counts toward it
                max_attempts (int): Attempt budget; a quarter of it
                    (max_attempts // 4) is spent in the final randomized fallback

            Returns:
                list: Up to `num_variations` distinct SMILES strings, canonical form
                first. Returns ["Invalid SMILES input"] if RDKit cannot parse
                `smiles`, and only the canonical form for a molecule with no atoms.
            """
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return [self.INVALID_SMILES_MARKER]

            # Get the canonical form for reference
            canonical_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)

            # An atom-less molecule (e.g. parsed from "") has a single representation,
            # and Method 3 would call random.randint(0, -1), which raises ValueError.
            num_atoms = mol.GetNumAtoms()
            if num_atoms == 0:
                return [canonical_smiles][:num_variations]

            # Insertion-ordered set (dict keys). A plain set's iteration order for
            # strings changes between interpreter runs (hash randomization), which
            # would make the returned order and the final truncation unstable.
            variations = {canonical_smiles: None}

            # Method 1: RDKit random SMILES (see class note on seeding)
            for i in range(num_variations * 3):  # Try more times than needed as some might be duplicates
                if len(variations) >= num_variations:
                    break
                random.seed(self.seed + i)  # Change seed for each attempt
                random_smiles = Chem.MolToSmiles(mol, doRandom=True, canonical=False, allBondsExplicit=False)
                variations[random_smiles] = None

            # Method 2: write the molecule with a shuffled atom order
            for i in range(min(5, num_variations)):
                if len(variations) >= num_variations:
                    break
                random.seed(self.seed + i + 100)  # Different seed range
                atoms = list(range(num_atoms))
                random.shuffle(atoms)
                random_mol = Chem.RenumberAtoms(mol, atoms)
                variations[Chem.MolToSmiles(random_mol, canonical=False, allBondsExplicit=False)] = None

            # Method 3: root the SMILES at a random atom; alternate explicit bonds
            for i in range(min(5, num_variations)):
                if len(variations) >= num_variations:
                    break
                random.seed(self.seed + i + 200)  # Different seed range
                start_atom = random.randint(0, num_atoms - 1)
                rooted = Chem.MolToSmiles(mol, rootedAtAtom=start_atom, canonical=False,
                                          allBondsExplicit=bool(i % 2))
                variations[rooted] = None

            # Method 4: same atom order, different SMILES writer flags
            # (these are SMILES output options, not SMARTS)
            format_options = [
                # (kekuleSmiles, allBondsExplicit, allHsExplicit, isomericSmiles)
                (True, True, True, True),
                (False, True, False, True),
                (True, False, True, True),
                (False, False, False, True),
            ]
            for kekule, all_bonds, all_hs, isomeric in format_options:
                if len(variations) >= num_variations:
                    break
                try:
                    variant = Chem.MolToSmiles(mol, kekuleSmiles=kekule, allBondsExplicit=all_bonds,
                                               allHsExplicit=all_hs, isomericSmiles=isomeric, canonical=False)
                except Exception:
                    continue  # skip a format RDKit cannot write for this molecule
                variations[variant] = None

            result = list(variations)

            # Fallback: shuffle atom order and emit random SMILES until enough
            # variations exist or the attempt budget is spent.
            attempts = 0
            max_remaining_attempts = max_attempts // 4  # Reserve 1/4 of max attempts for this method

            while len(result) < num_variations and attempts < max_remaining_attempts:
                attempts += 1
                try:
                    random.seed(len(result) + self.seed + 500 + attempts)  # Different seed
                    atoms = list(range(num_atoms))
                    random.shuffle(atoms)
                    random_mol = Chem.RenumberAtoms(mol, atoms)
                    new_smiles = Chem.MolToSmiles(random_mol, doRandom=True, canonical=False,
                                                  allBondsExplicit=bool(random.randint(0, 1)))
                except Exception:
                    continue  # If we encounter an error, just continue to the next attempt
                if new_smiles not in result:
                    result.append(new_smiles)

            return result[:num_variations]

        def augment_dataset(self, df, smiles_column='SMILES', num_variations=20, expand=True):
            """
            Augment a dataset by generating non-canonical SMILES variations.

            Args:
                df (pd.DataFrame): Input dataframe containing SMILES
                smiles_column (str): Name of the column containing SMILES strings
                num_variations (int): Number of variations to generate per molecule
                expand (bool): If True, expands the dataset with all variations
                               If False, adds variations as a new column

            Returns:
                pd.DataFrame: Augmented dataframe (the input is not modified). Rows
                whose SMILES cannot be parsed keep their original SMILES (one row
                when expanding, a one-element list otherwise) instead of having it
                replaced by the "Invalid SMILES input" marker.
            """
            def variations_for(smiles):
                # Map the invalid-input sentinel back to the original string so
                # the marker text never overwrites real data.
                found = self.generate_variations(smiles, num_variations)
                return [smiles] if found == [self.INVALID_SMILES_MARKER] else found

            if expand:
                # One output row per variation, all other columns copied
                expanded_rows = []
                for _, row in df.iterrows():
                    for var in variations_for(row[smiles_column]):
                        new_row = row.copy()
                        new_row[smiles_column] = var
                        expanded_rows.append(new_row)
                # columns= keeps the schema when no rows are produced (e.g. empty df);
                # pd.DataFrame([]) alone would have no columns at all.
                return pd.DataFrame(expanded_rows, columns=df.columns)

            # Add variations as a new list-valued column
            df_copy = df.copy()
            df_copy['smiles_variations'] = [variations_for(s) for s in df[smiles_column]]
            return df_copy

    # Example usage. In a Jupyter/IPython kernel __name__ is "__main__", so this
    # demo runs every time the cell is executed.
    if __name__ == "__main__":
        # Example SMILES - aspirin
        test_smiles = "CC(=O)OC1=CC=CC=C1C(=O)O"

        # Create generator
        generator = SmilesVariationGenerator(seed=42)

        # Generate variations
        variations = generator.generate_variations(test_smiles, num_variations=30)

        # Display results
        print(f"Original SMILES: {test_smiles}")
        print(f"Generated {len(variations)} different SMILES representations:")
        for i, var in enumerate(variations):
            print(f"{i+1}. {var}")

        # Create a sample dataframe. A dedicated name is used because rebinding
        # `df` here would clobber the notebook-level `df` loaded from CSV in
        # earlier cells.
        example_df = pd.DataFrame({
            'SMILES': [test_smiles, "CCO", "c1ccccc1"],
            'Name': ['Aspirin', 'Ethanol', 'Benzene']
        })

        # Augment the dataframe
        augmented_df = generator.augment_dataset(example_df, num_variations=5)

        # Display the augmented dataframe
        print("\nAugmented DataFrame:")
        print(f"Original size: {len(example_df)}, Augmented size: {len(augmented_df)}")
        print(augmented_df.to_string())

In [None]:

# Import the SMILES variation generator from the utils_MMT directory.
# NOTE(review): this rebinds the name SmilesVariationGenerator defined in the
# previous cell; presumably utils_MMT/smiles_variation_generator.py holds the
# same class. Confirm the two copies have not diverged.
# NOTE(review): `Path` is imported but not used here, and re-running the cell
# appends the same directory to sys.path again.
import sys
from pathlib import Path
sys.path.append("/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/utils_MMT")
from smiles_variation_generator import SmilesVariationGenerator


### Plot Target + Starting Material

In [None]:
import json
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from rdkit import Chem
from rdkit.Chem import Draw

# --- Step 1: Load your JSON data ---
# Replace with your actual file path.
json_filepath = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/data/molecular_data/archive/molecular_data_20250206_124830.json"
with open(json_filepath, "r") as f:
    data = json.load(f)

# --- Step 2: Loop over each sample and visualize ---
for sample_id, sample in data.items():
    molecule_data = sample["molecule_data"]

    # Target molecule. MolFromSmiles returns None for unparsable SMILES, which
    # Draw.MolToImage cannot render, so that case is drawn as a note below.
    target_smiles = molecule_data["smiles"]
    target_mol = Chem.MolFromSmiles(target_smiles) if isinstance(target_smiles, str) else None

    # Starting materials; `or []` also covers a key that is present but null.
    starting_smiles_list = molecule_data.get("starting_smiles") or []

    # --- Step 2a: Canonicalize and remove duplicates (first occurrence wins) ---
    unique_starting_smiles = []
    unique_starting_mols = []
    seen = set()
    for smi in starting_smiles_list:
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
        if mol is None:
            continue  # skip invalid / non-string entries
        # Canonical SMILES standardizes atom ordering, so it identifies duplicates.
        canon_smi = Chem.MolToSmiles(mol, canonical=True)
        if canon_smi not in seen:
            seen.add(canon_smi)
            unique_starting_smiles.append(canon_smi)
            unique_starting_mols.append(mol)
    n_start = len(unique_starting_mols)

    # --- Step 3: Set up the Matplotlib grid ---
    # Top row: target molecule spanning all columns.
    # Bottom row: each unique starting material in its own subplot.
    n_cols = max(n_start, 1)
    fig = plt.figure(figsize=(3 * n_cols, 6))
    gs = gridspec.GridSpec(2, n_cols, height_ratios=[1, 1])

    ax_target = fig.add_subplot(gs[0, :])
    if target_mol is None:
        ax_target.text(0.5, 0.5, "Invalid target SMILES", ha='center', va='center')
    else:
        # Widen the image in proportion to the number of columns it spans.
        ax_target.imshow(Draw.MolToImage(target_mol, size=(300 * n_cols, 300)))
    ax_target.set_title(f"Target: {target_smiles}", fontsize=12)
    ax_target.axis("off")

    if n_start == 0:
        ax = fig.add_subplot(gs[1, 0])
        ax.text(0.5, 0.5, "No starting materials", ha='center', va='center')
        ax.axis("off")
    for i, (canon_smi, mol) in enumerate(zip(unique_starting_smiles, unique_starting_mols)):
        ax = fig.add_subplot(gs[1, i])
        ax.imshow(Draw.MolToImage(mol, size=(300, 300)))
        # Canonical SMILES in the title shows that duplicates have been removed.
        ax.set_title(f"Starting {i+1}:\n{canon_smi}", fontsize=8)
        ax.axis("off")

    # Add an overall title with the sample ID.
    fig.suptitle(f"Sample {sample_id}", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.95])
    plt.show()


In [None]:
import json
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from rdkit import Chem
from rdkit.Chem import Draw
import logging

# Suppress PIL debug logs
logging.getLogger("PIL.PngImagePlugin").setLevel(logging.ERROR)


def _dedupe_parsed_smiles(smiles_iterable):
    """Parse SMILES, dropping invalid entries and duplicates.

    Duplicates are detected by RDKit canonical SMILES; the first occurrence
    wins and input order is preserved.

    Returns:
        tuple[list[str], list]: (canonical SMILES strings, RDKit molecules),
        aligned index by index.
    """
    canon_smiles, mols, seen = [], [], set()
    for smi in smiles_iterable:
        mol = Chem.MolFromSmiles(smi) if isinstance(smi, str) else None
        if mol is None:
            continue  # skip invalid / non-string entries
        canon_smi = Chem.MolToSmiles(mol, canonical=True)
        if canon_smi not in seen:
            seen.add(canon_smi)
            canon_smiles.append(canon_smi)
            mols.append(mol)
    return canon_smiles, mols


def _plot_molecule_row(fig, gs, row, canon_smiles, mols, label, empty_message):
    """Draw one molecule per column in grid row `row`, or `empty_message` if there are none."""
    if not mols:
        ax = fig.add_subplot(gs[row, 0])
        ax.text(0.5, 0.5, empty_message, ha='center', va='center')
        ax.axis("off")
        return
    for i, (smi, mol) in enumerate(zip(canon_smiles, mols)):
        ax = fig.add_subplot(gs[row, i])
        ax.imshow(Draw.MolToImage(mol, size=(300, 300)))
        ax.set_title(f"{label} {i+1}:\n{smi}", fontsize=8)
        ax.axis("off")


# --- Step 1: Load your JSON data ---
json_filepath = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/data/molecular_data/molecular_data.json"
with open(json_filepath, "r") as f:
    data = json.load(f)

# --- Step 2: Loop over each sample and visualize ---
for sample_id, sample in data.items():
    molecule_data = sample["molecule_data"]

    # Target molecule (None if RDKit cannot parse it; drawn as a note below)
    target_smiles = molecule_data["smiles"]
    target_mol = Chem.MolFromSmiles(target_smiles) if isinstance(target_smiles, str) else None

    # Starting materials and forward-prediction products, each de-duplicated.
    # `or []` also covers keys that are present but null.
    unique_starting_smiles, unique_starting_mols = _dedupe_parsed_smiles(
        molecule_data.get("starting_smiles") or []
    )
    all_forward_smiles, all_forward_mols = _dedupe_parsed_smiles(
        pred_smi
        for prediction in (molecule_data.get("forward_predictions") or [])
        for pred_smi in (prediction.get("all_predictions") or [])
    )
    n_start = len(unique_starting_mols)
    n_forward = len(all_forward_mols)

    # Grid with 3 rows: row 0 = target (spans all columns), row 1 = starting
    # materials, row 2 = forward reaction products.
    n_cols = max(n_start, n_forward, 1)  # Ensure at least one column exists.
    fig = plt.figure(figsize=(3 * n_cols, 9))
    gs = gridspec.GridSpec(3, n_cols, height_ratios=[1, 1, 1])

    ax_target = fig.add_subplot(gs[0, :])
    if target_mol is None:
        ax_target.text(0.5, 0.5, "Invalid target SMILES", ha='center', va='center')
    else:
        # Widen the image in proportion to the number of columns it spans.
        ax_target.imshow(Draw.MolToImage(target_mol, size=(300 * n_cols, 300)))
    ax_target.set_title(f"Target:\n{target_smiles}", fontsize=12)
    ax_target.axis("off")

    _plot_molecule_row(fig, gs, 1, unique_starting_smiles, unique_starting_mols,
                       "Starting", "No starting materials")
    _plot_molecule_row(fig, gs, 2, all_forward_smiles, all_forward_mols,
                       "Forward", "No forward predictions")

    fig.suptitle(f"Sample {sample_id}", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.93])
    plt.show()


## Data Manipulation

### Merge augmented SMILES into the master CSV files

In [None]:
import pandas as pd
from rdkit import Chem

def remove_stereochemistry(smiles):
    """
    Remove stereochemical information from a SMILES string using RDKit.

    Args:
        smiles: SMILES string. Non-string values (e.g. NaN read from an empty
            CSV cell, or None) and empty strings are treated as missing.

    Returns:
        str: Canonical SMILES written without stereochemistry, or '' if the
        input is missing, cannot be parsed, or RDKit raises while processing it.
    """
    # Handle missing values explicitly rather than relying on RDKit raising.
    if not isinstance(smiles, str) or not smiles:
        return ''
    try:
        # Convert SMILES to mol object
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return ''

        # Remove all stereochemistry information
        Chem.RemoveStereochemistry(mol)

        # Convert back to canonical SMILES, also without isomeric information
        return Chem.MolToSmiles(mol, isomericSmiles=False)
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
        return ''

def add_regio_isomers(nmr_file_path, regio_file_path, output_path):
    """
    Add regio isomers to NMR data file while maintaining exact order of sample IDs,
    removing stereochemistry from SMILES, and placing the SMILES_regio_isomers
    column right after the SMILES column.

    Parameters:
    nmr_file_path (str): Path to the NMR CSV (needs 'sample-id' and 'SMILES' columns)
    regio_file_path (str): Path to CSV with 'sample-id' and 'SMILES_regio_isomers' columns
    output_path (str): Path for saving the new CSV file

    Any existing 'SMILES_regio_isomers' column in the NMR file is overwritten.
    Samples without a (parseable) regio isomer get ''.
    """
    # Read the files
    nmr_df = pd.read_csv(nmr_file_path)
    regio_df = pd.read_csv(regio_file_path)

    # Get list of sample IDs from NMR file (maintaining order)
    nmr_sample_ids = nmr_df['sample-id'].tolist()

    # Lookup table sample-id -> regio isomer SMILES (if an ID repeats, the last row wins)
    regio_dict = dict(zip(regio_df['sample-id'], regio_df['SMILES_regio_isomers']))

    # Matched regio isomers (without stereochemistry), in NMR sample order
    regio_isomers_list = []
    for sample_id in nmr_sample_ids:
        regio_isomer = regio_dict.get(sample_id, '')
        # Empty CSV cells come back as float NaN; only real, non-empty strings are processed
        if isinstance(regio_isomer, str) and regio_isomer:
            regio_isomer = remove_stereochemistry(regio_isomer)
        else:
            regio_isomer = ''
        regio_isomers_list.append(regio_isomer)

    # Add the new column to NMR dataframe
    nmr_df['SMILES_regio_isomers'] = regio_isomers_list

    # Remove the column from its current position *before* locating 'SMILES', so the
    # insert index is correct even if the column already existed ahead of 'SMILES'.
    columns = [col for col in nmr_df.columns if col != 'SMILES_regio_isomers']
    columns.insert(columns.index('SMILES') + 1, 'SMILES_regio_isomers')
    nmr_df = nmr_df[columns]

    # Print statistics
    total_entries = len(nmr_sample_ids)
    total_matches = sum(1 for x in regio_isomers_list if x != '')
    # Guard against an empty NMR file (would otherwise raise ZeroDivisionError)
    match_rate = (total_matches / total_entries) * 100 if total_entries else 0.0
    print(f"Total NMR entries: {total_entries}")
    print(f"Matched regio isomers: {total_matches}")
    print(f"Match rate: {match_rate:.2f}%")

    # Print first few entries to verify stereochemistry removal
    print("\nFirst 5 entries with stereochemistry removed:")
    for i in range(min(5, total_entries)):
        if regio_isomers_list[i]:
            print(f"Sample ID: {nmr_sample_ids[i]}")
            print(f"Original SMILES: {regio_dict[nmr_sample_ids[i]]}")
            print(f"Without stereochemistry: {regio_isomers_list[i]}\n")

    # Save the new file
    nmr_df.to_csv(output_path, index=False)
    print(f"\nSaved new file to: {output_path}")

# Example usage: attach regio isomers to the combined ACD NMR table
if __name__ == "__main__":
    nmr_file = (
        "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
        "52_project_3.3_data/ACD_data/combined_acd_nmr_data_no_stereo.csv"
    )
    regio_file = (
        "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
        "47_Anna_paper_data/Sim_ACD_Exp_aug/ACD_1H_with_SN_filtered_v3_regio_aug.csv"
    )
    output_file = (
        "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
        "52_project_3.3_data/ACD_data/combined_acd_nmr_data_no_stereo_aug_added.csv"
    )

    add_regio_isomers(
        nmr_file_path=nmr_file,
        regio_file_path=regio_file,
        output_path=output_file,
    )

In [None]:
# NOTE(review): `merge_regio_isomers` is not defined in the visible part of this notebook;
# the helper defined in the previous cell is
# `add_regio_isomers(nmr_file_path, regio_file_path, output_path)`, which takes the NMR
# file FIRST and returns None. The arguments below are ordered
# (regio file, NMR file, output file) — confirm which function/signature is intended
# before running this cell on a fresh kernel.
merged_df = merge_regio_isomers(
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/47_Anna_paper_data/Sim_ACD_Exp_aug/ACD_1H_with_SN_filtered_v3_regio_aug.csv",
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv",
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo_aug_added.csv"
)

### Combine all code to one file
- For AI support

In [None]:
import os

def _is_excluded(path, root_dir, exclude_dirs):
    """Return True if ``path`` equals, or lies beneath, any directory in ``exclude_dirs``.

    Entries are resolved with ``os.path.join(root_dir, entry)``: relative entries are taken
    relative to ``root_dir``, absolute entries are used unchanged. Everything is passed
    through ``os.path.abspath`` so trailing separators or a relative ``root_dir`` do not
    break the prefix comparison.
    """
    path = os.path.abspath(path)
    for exclude_dir in exclude_dirs:
        excluded = os.path.abspath(os.path.join(root_dir, exclude_dir))
        if os.path.commonpath([path, excluded]) == excluded:
            return True
    return False


def _walk_included(root_dir, exclude_dirs):
    """Yield ``(root, files)`` from ``os.walk(root_dir)``, skipping excluded directories.

    Excluded sub-directories are pruned from ``dirs`` in place so ``os.walk`` never
    descends into them (e.g. large archive/log trees are not traversed at all).
    """
    for root, dirs, files in os.walk(root_dir):
        if _is_excluded(root, root_dir, exclude_dirs):
            dirs[:] = []
            continue
        dirs[:] = [d for d in dirs
                   if not _is_excluded(os.path.join(root, d), root_dir, exclude_dirs)]
        yield root, files


def _tree_level(root, root_dir):
    """Depth of ``root`` below ``root_dir`` (0 for ``root_dir`` itself).

    Uses ``os.path.relpath`` so a trailing separator on ``root_dir`` does not shift levels.
    """
    rel = os.path.relpath(root, root_dir)
    return 0 if rel == os.curdir else rel.count(os.sep) + 1


def generate_directory_structure(root_dir, exclude_dirs):
    """Return an indented text tree of ``root_dir`` (4 spaces per level).

    Directories are listed as ``name/`` followed by their files; ``exclude_dirs`` and
    everything beneath them are omitted.
    """
    directory_structure = []
    for root, files in _walk_included(root_dir, exclude_dirs):
        level = _tree_level(root, root_dir)
        indent = ' ' * 4 * level
        # normpath drops a trailing separator, otherwise basename() would be ''
        directory_structure.append(f"{indent}{os.path.basename(os.path.normpath(root))}/")
        subindent = ' ' * 4 * (level + 1)
        directory_structure.extend(f"{subindent}{f}" for f in files)
    return '\n'.join(directory_structure)


def combine_code_files(root_dir, output_file, exclude_dirs, extensions=('.py',)):
    """Write the directory tree of ``root_dir`` followed by every matching source file.

    Args:
        root_dir: Repository root to scan.
        output_file: Text file to (over)write. If it lies inside ``root_dir`` it is
            skipped during concatenation so it is never copied into itself.
        exclude_dirs: Directories (absolute, or relative to ``root_dir``) to skip entirely.
        extensions: File suffix or tuple of suffixes to include (default: ``('.py',)``).
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    extensions = tuple(extensions)
    output_abs = os.path.abspath(output_file)

    with open(output_file, 'w', encoding='utf-8') as outfile:
        # Write the directory structure at the beginning
        directory_structure = generate_directory_structure(root_dir, exclude_dirs)
        outfile.write(f"Repository Structure:\n{directory_structure}\n\n")

        # Write the contents of each code file
        for root, files in _walk_included(root_dir, exclude_dirs):
            for file in files:
                if not file.endswith(extensions):
                    continue
                file_path = os.path.join(root, file)
                if os.path.abspath(file_path) == output_abs:
                    continue  # never concatenate the (partially written) output into itself
                outfile.write(f'--- {file_path} ---\n')
                # errors='replace': one non-UTF-8 file must not abort the whole dump
                with open(file_path, 'r', encoding='utf-8', errors='replace') as infile:
                    outfile.write(infile.read())
                outfile.write('\n\n')

if __name__ == '__main__':
    # Repository to flatten into one text file (used as context for AI assistants)
    repository_path = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator'  # Replace with your repository path
    output_filename = f'{repository_path}/combined_code_with_structure.py'

    # Sub-directories skipped entirely, together with everything beneath them
    exclude_dirs = [
        f'{repository_path}/{subdir}'
        for subdir in (
            '_temp_folder',
            'analysis_files',
            'data/molecular_data/archive',
            'logs',
            'uploads',
            # Add more directories as needed
        )
    ]

    combine_code_files(repository_path, output_filename, exclude_dirs)


### Fix KIMI LLM result

In [None]:
import json
import re

def extract_kimi_json(raw_text: str) -> dict:
    """
    Extract JSON content from Kimi model output.

    Parsing is attempted in this order, stopping at the first success:
      1. the whole text exactly as given (valid JSON is never rewritten),
      2. the whole text after ``clean_json_string`` repairs,
      3. the first balanced ``{...}`` span, as given and then repaired.

    Args:
        raw_text: Raw text output from the Kimi model

    Returns:
        Dictionary containing:
            - json_content: The parsed JSON value; the raw ``{...}`` substring if a span
              was found but could not be parsed; ``{}`` if no complete span was found
            - reasoning_content: The text preceding the JSON span ('' if the whole text
              parsed; the whole stripped text if no span was found)
    """
    if not raw_text:
        return {'json_content': {}, 'reasoning_content': ''}

    def clean_json_string(json_str: str) -> str:
        """Best-effort repair of almost-JSON produced by the model."""
        # Replace runs of ASCII control characters (raw newlines, tabs, ...) with a space
        cleaned = re.sub(r'[\x00-\x1f]+', " ", json_str)
        # Un-escape quotes for replies where the whole object was emitted escaped ({\"k\": ...})
        cleaned = cleaned.replace('\\"', '"')

        # Normalize Python literals to JSON; word boundaries keep e.g. "Nonetheless" intact
        cleaned = re.sub(r'\bTrue\b', 'true', cleaned)
        cleaned = re.sub(r'\bFalse\b', 'false', cleaned)
        cleaned = re.sub(r'\bNone\b', 'null', cleaned)

        # Remove trailing commas before closing braces/brackets
        cleaned = re.sub(r',(\s*[}\]])', r'\1', cleaned)

        return cleaned.strip()

    def try_parse(text: str):
        """Return (True, value) for the first of (text, cleaned text) that parses, else (False, None)."""
        for candidate in (text, clean_json_string(text)):
            try:
                return True, json.loads(candidate)
            except json.JSONDecodeError:
                continue
        return False, None

    def find_json_boundaries(text: str) -> tuple[int, int]:
        """Return (start, end) of the first balanced {...} span; end is -1 if unbalanced.

        Braces inside string literals are counted too (no string awareness).
        """
        start = text.find('{')
        if start == -1:
            return -1, -1

        brace_count = 0
        end = -1

        for i, char in enumerate(text[start:], start):
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                if brace_count == 0:
                    end = i + 1
                    break

        return start, end

    try:
        # First try the whole response
        ok, value = try_parse(raw_text)
        if ok:
            return {'json_content': value, 'reasoning_content': ''}

        # Otherwise look for the first JSON object embedded in the response
        start, end = find_json_boundaries(raw_text)
        if start != -1 and end != -1:
            span = raw_text[start:end]
            ok, value = try_parse(span)
            return {
                'json_content': value if ok else span,
                'reasoning_content': raw_text[:start].strip(),
            }
        return {'json_content': {}, 'reasoning_content': raw_text.strip()}

    except Exception:
        return {'json_content': {}, 'reasoning_content': raw_text.strip()}

def clean_content(content: str) -> str:
    """
    Undo literal escape sequences in a model reply before JSON extraction.

    Steps (no unicode-escape decoding is performed):
    1. Two passes of ``\\"`` -> ``"``; the second pass also resolves doubly escaped
       quotes (``\\\\"`` becomes ``\\"`` after pass one and ``"`` after pass two).
    2. The two-character sequences ``\\n``, ``\\t`` and ``\\r`` are replaced by a space.

    Args:
        content: Raw content string from the message

    Returns:
        Cleaned content string
    """
    for _ in range(2):
        content = content.replace('\\"', '"')

    for escape_sequence in ('\\n', '\\t', '\\r'):
        content = content.replace(escape_sequence, ' ')

    return content


In [None]:
# Intermediate-result JSON files to process in the next cell
json_files = [
    "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
    "LLM_Structure_Elucidator/_temp_folder/intermediate_results/"
    "AZ12129293_exp_d1_aug_intermediate.json",
]

In [None]:
#Load input:

import os
import openai

# Moonshot's Kimi API is OpenAI-compatible; this cell uses the legacy (openai<1.0)
# module-level client (`openai.api_base`, `openai.ChatCompletion`).
# The key is read from the environment rather than hard-coded in the notebook.
# NOTE: any key previously embedded in this cell must be treated as leaked and revoked.
openai.api_key = os.environ.get("MOONSHOT_API_KEY")
openai.api_base = "https://api.moonshot.ai/v1"
if not openai.api_key:
    print("WARNING: MOONSHOT_API_KEY is not set; the API calls below will fail.")

for json_path in json_files:
    # Re-send the stored Kimi prompt and overwrite the file with freshly parsed results
    print(json_path)
    with open(json_path) as f:
        data = json.load(f)
    kimi_entry = data["analysis_results"]["final_analysis"]["llm_responses"]["kimi"]
    test_content = kimi_entry["analysis_prompt"]

    params = {
        "model": "kimi-k1.5-preview",
        "messages": [{"role": "user", "content": test_content}],
        "temperature": 0.3,
        "max_tokens": 8000,
        "stream": False,
    }

    try:
        response = openai.ChatCompletion.create(**params)
    except Exception as e:
        print(f"Error during API call: {str(e)}")
        # Skip this file: `response` would otherwise be undefined (NameError) or still
        # hold the previous file's reply, which would then be written into this file.
        continue

    raw_text = response['choices'][0]['message']['content']
    kimi_entry["parsed_results"] = extract_kimi_json(clean_content(raw_text))['json_content']

    with open(json_path, 'w') as f:
        json.dump(data, f, indent=2)

In [None]:
# Inspect the metadata keys of `data` — the JSON loaded in the last iteration of the
# previous cell's loop (this cell depends on that cell having been run first).
data["analysis_results"]['final_analysis']['metadata'].keys()

### Test Claude 3.7
- Run on Windows environment to add it
- pip install anthropic==0.47.0

In [None]:
# Local Windows folder holding test JSON files
folder_path = r"C:\windsurf_repo\data_json\test_folder"

In [None]:
import os
import json
import anthropic
import re
from typing import Dict, Any, Tuple

def extract_json_content(raw_text: str) -> Dict[str, Any]:
    """Extract the JSON object from Claude's response text.

    Search order:
      1. the first ``{...}`` after the marker ``JSON_RESULT =``;
      2. otherwise the first ``{...}`` after the first case-insensitive "json"
         (e.g. a ```json code fence);
      3. if neither marker yields a complete object, the first ``{...}`` anywhere.

    Returns:
        The parsed object; ``{"raw_json": <cleaned text>}`` if an object span was found
        but could not be parsed; ``{}`` if no complete object exists or on error.
    """
    def find_json_boundaries(text: str) -> Tuple[int, int]:
        """Return (start, end) of the first balanced {...} span; end is -1 if unbalanced."""
        start = text.find('{')
        if start == -1:
            return -1, -1

        brace_count = 0
        end = -1

        for i, char in enumerate(text[start:], start):
            if char == '{':
                brace_count += 1
            elif char == '}':
                brace_count -= 1
                if brace_count == 0:
                    end = i + 1
                    break
        return start, end

    def clean_json_string(json_str: str) -> str:
        """Best-effort repair of almost-JSON (control chars, Python literals, trailing commas)."""
        # Replace runs of control characters with a space
        cleaned = re.sub(r'[\x00-\x1f]+', " ", json_str)

        # Normalize Python literals; word boundaries keep words like "Nonetheless" intact
        cleaned = re.sub(r'\bTrue\b', 'true', cleaned)
        cleaned = re.sub(r'\bFalse\b', 'false', cleaned)
        cleaned = re.sub(r'\bNone\b', 'null', cleaned)

        # Fix trailing commas
        cleaned = re.sub(r',(\s*[}\]])', r'\1', cleaned)

        return cleaned.strip()

    def parse_first_object(text: str):
        """Parse the first {...} span in ``text``; None if there is no complete span."""
        start, end = find_json_boundaries(text)
        if start == -1 or end == -1:
            return None
        json_str = clean_json_string(text[start:end])
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            return {"raw_json": json_str}

    if not raw_text:
        return {}

    try:
        # Locate a marker: the explicit one first, then any case-insensitive "json".
        # re.search on the original text keeps indices aligned (str.lower() can change
        # the length of some non-ASCII strings).
        json_marker = "JSON_RESULT ="
        marker_index = raw_text.find(json_marker)
        if marker_index != -1:
            after_marker = raw_text[marker_index + len(json_marker):]
        else:
            match = re.search(r'json', raw_text, re.IGNORECASE)
            after_marker = raw_text[match.end():] if match else None

        if after_marker is not None:
            result = parse_first_object(after_marker)
            if result is not None:
                return result

        # No marker, or no complete object after it (e.g. "json" only mentioned in
        # trailing prose): fall back to scanning the whole response.
        result = parse_first_object(raw_text)
        return result if result is not None else {}
    except Exception as e:
        print(f"Error extracting JSON: {str(e)}")
        return {}

def process_json_files(folder_path: str, claude_api_key: str):
    """
    Re-run the stored analysis prompt of every ``*.json`` file in ``folder_path``
    (non-recursive) through Claude 3.7 Sonnet with extended thinking, and write the
    reply back into the same file.

    The prompt is read from ``final_analysis.llm_responses.kimi.analysis_prompt`` or,
    if that is missing, from ``final_analysis.metadata.analysis_prompt``. The result is
    stored under ``final_analysis.llm_responses["claude3-7"]``. Errors are reported per
    file and processing continues with the next file.

    Args:
        folder_path: Folder containing the intermediate-result JSON files.
        claude_api_key: Anthropic API key.
    """
    model_name = "claude-3-7-sonnet-20250219"
    # Sent with each request as well as recorded in "config", so the stored config
    # reflects what the model actually received.
    system_prompt = (
        "You are an expert chemist specializing in structure elucidation and spectral analysis. "
        "Analyze molecular candidates based on all available evidence and provide detailed "
        "scientific assessments."
    )

    client = anthropic.Anthropic(api_key=claude_api_key)

    # sorted(): os.listdir order is arbitrary; a fixed order makes runs reproducible
    json_files = sorted(
        os.path.join(folder_path, f) for f in os.listdir(folder_path) if f.endswith('.json')
    )

    for json_path in json_files:
        print(f"Processing: {json_path}")

        try:
            # Load inside the try so one corrupt file does not abort the whole folder
            with open(json_path) as f:
                data = json.load(f)
            final_analysis = data["analysis_results"]["final_analysis"]

            try:
                analysis_prompt = final_analysis["llm_responses"]["kimi"]["analysis_prompt"]
            except (KeyError, TypeError):
                analysis_prompt = final_analysis["metadata"]["analysis_prompt"]

            response = client.messages.create(
                model=model_name,
                max_tokens=20000,
                system=system_prompt,
                thinking={
                    "type": "enabled",
                    "budget_tokens": 10000
                },
                messages=[{
                    "role": "user",
                    "content": analysis_prompt
                }]
            )

            # response.content is a list of blocks; keep every thinking/text block
            # instead of only the last one of each kind.
            thinking_parts, text_parts = [], []
            for content_block in response.content:
                if getattr(content_block, 'thinking', None):
                    thinking_parts.append(content_block.thinking)
                elif getattr(content_block, 'text', None):
                    text_parts.append(content_block.text)
            thinking = "\n".join(thinking_parts)
            normal_response = "".join(text_parts)

            # Reasoning = prose before the first '{' of the visible answer
            json_start = normal_response.find("{")
            reasoning_content = normal_response[:json_start].strip() if json_start > 0 else ""

            claude_results = {
                "raw_response": normal_response,
                "parsed_results": extract_json_content(normal_response),
                "reasoning_content": reasoning_content,
                "thinking": thinking,
                "analysis_prompt": analysis_prompt,
                "config": {
                    "model": model_name,
                    "system": system_prompt
                }
            }

            # setdefault: files whose prompt came from metadata may lack llm_responses
            final_analysis.setdefault("llm_responses", {})["claude3-7"] = claude_results

            with open(json_path, 'w') as f:
                json.dump(data, f, indent=2)

            print(f"Successfully updated {json_path} with Claude 3.7 results")

        except KeyError as e:
            print(f"Error: Could not find expected key in {json_path}: {e}")
        except Exception as e:
            print(f"Error processing {json_path}: {e}")

import os

# Call the function with your folder path and API key
folder = "C:\\windsurf_repo\\data_json\\_run_6.0_exp_d1_aug_finished"

# The Anthropic key is read from the environment rather than hard-coded in the notebook.
# NOTE: any key previously embedded in this cell must be treated as leaked and revoked.
api_key = os.environ.get("ANTHROPIC_API_KEY")
if not api_key:
    raise RuntimeError("Set the ANTHROPIC_API_KEY environment variable before running this cell.")

process_json_files(folder, api_key)

### Remove redundant data from JSON files (JSON clean-up)

In [None]:
import json
import os

def remove_unwanted_keys(d):
    """Strip bulky analysis fields from a parsed JSON structure, in place.

    A dict key is deleted when its lower-cased form contains any of:
    'nmr_data', 'peak_matches', 'matched_peaks', 'generatedsmilesprobabilities',
    'all_log_likelihoods', 'spectra_matching', 'peak_matching_results'.
    Nested dicts and lists are processed recursively; other values are left alone.
    Returns None.
    """
    fragments = (
        'nmr_data',
        'peak_matches',
        'matched_peaks',
        'generatedSmilesProbabilities'.lower(),
        'all_log_likelihoods',
        'spectra_matching',
        'peak_matching_results',
    )
    if isinstance(d, list):
        for element in d:
            remove_unwanted_keys(element)
        return
    if not isinstance(d, dict):
        return

    # Collect first, delete afterwards: a dict must not change size while iterated
    doomed = [k for k in d if any(fragment in k.lower() for fragment in fragments)]
    for k in doomed:
        del d[k]
    for child in d.values():
        remove_unwanted_keys(child)

def remove_unwanted_data_and_save(input_file_path, output_file_path=None):
    """Strip bulky keys from a JSON file (via ``remove_unwanted_keys``) and save the result.

    Args:
        input_file_path: JSON file to read; it is not modified (unless it is also the output).
        output_file_path: Destination path. When omitted, the result is written next to
            the input as ``<name>_small<ext>``. Missing parent directories are created.

    Returns:
        The path the cleaned JSON was written to.
    """
    # Load the original JSON data
    with open(input_file_path, 'r') as file:
        data = json.load(file)

    # Remove unwanted keys recursively (see remove_unwanted_keys for the key patterns)
    remove_unwanted_keys(data)

    if output_file_path is None:
        # Default: "<name>_small<ext>" next to the input file
        base_name, ext = os.path.splitext(input_file_path)
        output_file_path = f"{base_name}_small{ext}"

    # dirname() is '' for a bare file name, and os.makedirs('') raises FileNotFoundError
    parent_dir = os.path.dirname(output_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Save the modified data to a new file
    with open(output_file_path, 'w') as new_file:
        json.dump(data, new_file, indent=4)

    print(f"Saved the modified file as: {output_file_path}")
    return output_file_path


# Example usage: list the JSON files to shrink; each is saved as "<name>_small<ext>"
file_paths = []

for file_path in file_paths:
    remove_unwanted_data_and_save(file_path)


In [None]:
import json
import os
import shutil

def remove_unwanted_keys(d):
    """Strip bulky analysis fields from a parsed JSON structure, in place.

    A dict key is deleted when its lower-cased form contains any of:
    'nmr_data', 'peak_matches', 'matched_peaks', 'generatedsmilesprobabilities',
    'all_log_likelihoods', 'spectra_matching', 'peak_matching_results'.
    Nested dicts and lists are processed recursively; other values are left alone.
    Returns None.
    """
    fragments = (
        'nmr_data',
        'peak_matches',
        'matched_peaks',
        'generatedSmilesProbabilities'.lower(),
        'all_log_likelihoods',
        'spectra_matching',
        'peak_matching_results',
    )
    if isinstance(d, list):
        for element in d:
            remove_unwanted_keys(element)
        return
    if not isinstance(d, dict):
        return

    # Collect first, delete afterwards: a dict must not change size while iterated
    doomed = [k for k in d if any(fragment in k.lower() for fragment in fragments)]
    for k in doomed:
        del d[k]
    for child in d.values():
        remove_unwanted_keys(child)

def remove_unwanted_data_and_save(input_file_path, output_file_path=None):
    """Strip bulky keys from a JSON file (via ``remove_unwanted_keys``) and save the result.

    Args:
        input_file_path: JSON file to read; it is not modified (unless it is also the output).
        output_file_path: Destination path. When omitted, the result is written next to
            the input as ``<name>_small<ext>``. Missing parent directories are created.

    Returns:
        The path the cleaned JSON was written to.
    """
    # Load the original JSON data
    with open(input_file_path, 'r') as file:
        data = json.load(file)

    # Remove unwanted keys recursively (see remove_unwanted_keys for the key patterns)
    remove_unwanted_keys(data)

    if output_file_path is None:
        # Default: "<name>_small<ext>" next to the input file
        base_name, ext = os.path.splitext(input_file_path)
        output_file_path = f"{base_name}_small{ext}"

    # dirname() is '' for a bare file name, and os.makedirs('') raises FileNotFoundError
    parent_dir = os.path.dirname(output_file_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    # Save the modified data to a new file
    with open(output_file_path, 'w') as new_file:
        json.dump(data, new_file, indent=4)

    print(f"Saved the modified file as: {output_file_path}")
    return output_file_path

def process_all_json_files(input_folder):
    """Clean every ``*.json`` file under ``input_folder`` (recursively).

    Results go to a sibling folder named ``<input_folder>_clean`` with the same relative
    layout, e.g. ``.../run_noise_3.0_finished`` -> ``.../run_noise_3.0_finished_clean``.

    Returns:
        List of input paths that failed to process (each is also printed with its error).
    """
    # normpath strips a trailing separator; otherwise basename() is '' and the output
    # folder would be created *inside* the input folder as "<input_folder>/_clean".
    input_folder = os.path.normpath(input_folder)
    folder_name = os.path.basename(input_folder)
    output_folder = os.path.join(os.path.dirname(input_folder), f"{folder_name}_clean")

    failed = []
    for root, dirs, files in os.walk(input_folder):
        for file in files:
            if not file.endswith('.json'):
                continue
            input_file_path = os.path.join(root, file)
            try:
                # Mirror the file's position relative to input_folder inside output_folder
                relative_path = os.path.relpath(input_file_path, input_folder)
                output_file_path = os.path.join(output_folder, relative_path)
                remove_unwanted_data_and_save(input_file_path, output_file_path)
            except Exception as e:
                # Keep going on bad files, but say why they failed
                print("--------------------------------")
                print(input_file_path)
                print(f"  {type(e).__name__}: {e}")
                failed.append(input_file_path)
    return failed

# Example usage: clean every JSON file in one run folder;
# the output lands in a sibling "<folder>_clean" directory.
input_folder = (
    "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
    "LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6_exp_d4_finished"
)
process_all_json_files(input_folder)


In [None]:
    '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished',  # Add paths to your files here
    '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished',  # Add paths to your files here
    '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished',  # Add paths to your files here
    '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_4_exp_d1_finished',  # Add paths to your files here
    '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_5_exp_d1_aug_finished',  # Add paths to your files here
    '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6_exp_d4_finished',  # Add paths to your files here
    # Add more file paths if needed

### ACD/Labs data preparation

In [None]:
# Echo the ZINC-250k ACD HSQC training-set path (the next cell assigns the same path to `file_path`).
"/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/9_ZINC_250k/ZINC250_ACD_HSQC_train.csv"

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ast
from rdkit import Chem
from rdkit.Chem import Descriptors
import seaborn as sns

# Path to your CSV file
file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/9_ZINC_250k/ZINC250_ACD_HSQC_train.csv"

# Load the data
df = pd.read_csv(file_path)

# Display basic information about the dataset
print(f"Dataset shape: {df.shape}")
print("Column names:", df.columns.tolist())

# 1. Calculate molecular weights from SMILES
def calculate_mol_weight(smiles):
    """Return RDKit's ``Descriptors.MolWt`` for a SMILES string, or None.

    None is returned for missing / non-string values (pandas reads empty cells as float
    NaN, which RDKit rejects with an exception), for blank strings (RDKit would parse
    them as a zero-atom molecule of weight 0), and for SMILES RDKit cannot parse.
    """
    if not isinstance(smiles, str) or not smiles.strip():
        return None
    mol = Chem.MolFromSmiles(smiles)
    return Descriptors.MolWt(mol) if mol is not None else None

# Apply to the 'SMILES' column; None becomes NaN and is dropped before plotting
df['molecular_weight'] = df['SMILES'].apply(calculate_mol_weight)

# 2. Calculate number of peaks from shifts column
# Function to count the number of peaks in shifts data
def count_peaks(shifts_str):
    """Return the number of peaks in a stringified peak list, e.g. "[[7.2, 128.1], [3.4, 55.0]]".

    The string is parsed with ``ast.literal_eval`` (literals only, no code execution) and
    the length of the resulting sequence is returned. Missing or malformed values (NaN,
    unparseable text, non-sequence literals) count as 0 peaks.
    """
    try:
        return len(ast.literal_eval(shifts_str))
    except Exception:
        # Any parse/len failure -> 0; KeyboardInterrupt/SystemExit still propagate
        return 0

# Apply function to shifts column
df['peak_count'] = df['shifts'].apply(count_peaks)


def _add_stats_box(ax, values):
    """Annotate ``ax`` with mean/median/std of ``values`` (top-left); return (mean, median, std)."""
    mean, median, std = values.mean(), values.median(), values.std()
    stats_text = f"Mean: {mean:.2f}\nMedian: {median:.2f}\nStd Dev: {std:.2f}"
    ax.annotate(stats_text, xy=(0.05, 0.85), xycoords='axes fraction',
                bbox=dict(boxstyle="round,pad=0.3", fc="white", ec="gray", alpha=0.8))
    return mean, median, std


# Two stacked panels: molecular weight (top) and peak count (bottom)
fig, (ax_mw, ax_pc) = plt.subplots(2, 1, figsize=(12, 10))

# Molecular weight histogram; NaN rows (unparseable SMILES) are excluded
mw_values = df['molecular_weight'].dropna()
sns.histplot(mw_values, bins=30, kde=True, ax=ax_mw)
ax_mw.set_title('Distribution of Molecular Weights')
ax_mw.set_xlabel('Molecular Weight (Da)')
ax_mw.set_ylabel('Count')
ax_mw.grid(True, alpha=0.3)
mw_mean, mw_median, mw_std = _add_stats_box(ax_mw, mw_values)
ax_mw.axvline(mw_mean, color='red', linestyle='--', label=f'Mean: {mw_mean:.2f}')
ax_mw.axvline(mw_median, color='green', linestyle='--', label=f'Median: {mw_median:.2f}')
ax_mw.legend()

# Peak count histogram. discrete=True already gives one unit-wide bin per integer value
# (seaborn ignores `bins` when discrete=True), so no explicit bins are passed; the former
# bins=range(min(...), max(...) + 2) also raised ValueError on an empty column.
sns.histplot(df['peak_count'], discrete=True, kde=False, ax=ax_pc)
ax_pc.set_title('Distribution of NMR Peak Counts')
ax_pc.set_xlabel('Number of Peaks')
ax_pc.set_ylabel('Count')
ax_pc.grid(True, alpha=0.3)
pc_mean, pc_median, pc_std = _add_stats_box(ax_pc, df['peak_count'])

fig.tight_layout()
fig.savefig('nmr_analysis.png', dpi=300)
plt.show()

# Print summary statistics
print("\nMolecular Weight Statistics:")
print(df['molecular_weight'].describe())

print("\nPeak Count Statistics:")
print(df['peak_count'].describe())

In [None]:
import pandas as pd
import numpy as np
import ast
import os
from rdkit import Chem
from rdkit.Chem import Descriptors
import random

# Input: ZINC-250k molecules with ACD HSQC shift lists
file_path = (
    "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
    "9_ZINC_250k/ZINC250_ACD_HSQC_train.csv"
)

# Subset CSVs and the summary report are written here (relative to the working directory)
output_dir = "molecule_subsets"
os.makedirs(output_dir, exist_ok=True)

# Load the data and report its layout
print("Loading data...")
df = pd.read_csv(file_path)
print(f"Dataset shape: {df.shape}")
print(f"Column names: {df.columns.tolist()}")

# Function to count the number of peaks in shifts data
def count_peaks(shifts_str):
    """Return the number of entries in a stringified peak list such as "[[7.2, 128.1], ...]".

    The string is parsed with ``ast.literal_eval``; on any parse or ``len`` failure the
    error is printed and 0 is returned.
    """
    try:
        return len(ast.literal_eval(shifts_str))
    except Exception as err:
        print(f"Error parsing shifts: {err}")
        return 0

# Function to canonicalize SMILES without stereochemistry
def canonicalize_smiles_no_stereo(smiles):
    """Return RDKit's canonical, non-isomeric SMILES for ``smiles``.

    If RDKit cannot parse the input, a warning is printed and the input is returned
    unchanged; any exception is printed and the input is likewise returned unchanged.
    """
    try:
        mol = Chem.MolFromSmiles(smiles)
        if not mol:
            print(f"Warning: Could not parse SMILES: {smiles}")
            return smiles
        return Chem.MolToSmiles(mol, isomericSmiles=False, canonical=True)
    except Exception as err:
        print(f"Error canonicalizing SMILES {smiles}: {err}")
        return smiles

# Calculate number of peaks for each molecule
print("Calculating peak counts...")
df['peak_count'] = df['shifts'].apply(count_peaks)

# Canonicalize SMILES strings without stereochemistry (overwrites the SMILES column)
print("Canonicalizing SMILES strings without stereochemistry...")
df['SMILES'] = df['SMILES'].apply(canonicalize_smiles_no_stereo)

# Maximum molecules kept per size class, and the seed that makes the sampling reproducible
SUBSET_SIZE = 100
RANDOM_SEED = 42

# Define the three peak count ranges (bounds are inclusive)
peak_ranges = [
    (3, 7, "small"),    # Small: 3-7 peaks
    (8, 16, "medium"),  # Medium: 8-16 peaks
    (17, 25, "large")   # Large: 17-25 peaks
]

# Peak counts of each saved subset, reused for the summary report
subset_peak_counts = {}

# Process and save each subset
for min_peaks, max_peaks, size_label in peak_ranges:
    print(f"\nProcessing {size_label} molecules (peaks: {min_peaks}-{max_peaks})...")

    # Filter molecules whose peak count lies in [min_peaks, max_peaks]
    subset = df[df['peak_count'].between(min_peaks, max_peaks)]
    print(f"Found {len(subset)} molecules with {min_peaks}-{max_peaks} peaks")

    if len(subset) > SUBSET_SIZE:
        subset = subset.sample(n=SUBSET_SIZE, random_state=RANDOM_SEED)
        print(f"Randomly selected {SUBSET_SIZE} molecules from this subset")
    elif len(subset) < SUBSET_SIZE:
        print(f"WARNING: Only {len(subset)} molecules available in this range (less than {SUBSET_SIZE})")

    # Check the distribution of peak counts in the selected subset
    peak_dist = subset['peak_count'].value_counts().sort_index()
    print(f"Peak count distribution in selected subset:\n{peak_dist}")

    # Rename 'shifts' to 'HSQC' and add empty placeholder columns for the other spectra
    subset = subset.rename(columns={'shifts': 'HSQC'}).assign(
        **{'1H_NMR': np.nan, '13C_NMR': np.nan, 'COSY': np.nan}
    )

    # Save to CSV
    output_file = os.path.join(output_dir, f"molecules_{size_label}_peaks.csv")
    subset.to_csv(output_file, index=False)
    print(f"Saved {len(subset)} molecules to {output_file}")
    subset_peak_counts[size_label] = subset['peak_count']

print("\nDone! Created the following files:")
for _, _, size_label in peak_ranges:
    print(f"- molecules_{size_label}_peaks.csv")

# Summary report, built from the in-memory subsets (no need to re-read the CSVs)
summary_file = os.path.join(output_dir, "subset_summary.txt")
with open(summary_file, 'w') as f:
    f.write("Molecule Subset Summary\n")
    f.write("======================\n\n")

    for min_peaks, max_peaks, size_label in peak_ranges:
        counts = subset_peak_counts[size_label]
        f.write(f"{size_label.capitalize()} Peak Subset ({min_peaks}-{max_peaks} peaks)\n")
        f.write(f"- Number of molecules: {len(counts)}\n")
        f.write(f"- Peak count statistics: min={counts.min()}, max={counts.max()}, avg={counts.mean():.2f}\n")
        f.write("\n")

    f.write(f"Original dataset: {df.shape[0]} molecules\n")

print(f"\nSummary report saved to {summary_file}")

# Print column names to verify the structure of each saved subset file
for _, _, size_label in peak_ranges:
    output_file = os.path.join(output_dir, f"molecules_{size_label}_peaks.csv")
    # nrows=0 reads only the header, which is all that is needed to list the columns
    test_df = pd.read_csv(output_file, nrows=0)
    print(f"\nColumns in {os.path.basename(output_file)}:")
    print(test_df.columns.tolist())

In [None]:
# Show the current working directory (IPython automagic for %pwd). Relative outputs
# such as "molecule_subsets/" and "nmr_analysis.png" are written here.
pwd

### Check for missing files

In [None]:
# NOTE(review): `results_folder` is assigned in a later cell ("Define paths"); on a fresh
# kernel run that cell first, otherwise this raises NameError.
results_folder

In [None]:
# List the result files. NOTE(review): depends on `results_folder` (assigned in the later
# "Define paths" cell) and on `os` already being imported.
os.listdir(results_folder)

In [None]:
import pandas as pd
import os
import glob

# Define paths
# Reference CSV whose sample-ID column defines the set of expected samples.
csv_file_path = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/53_Lukas_real_data/cleaned_data_CLEAN.csv"
# Folder of result files; a sample counts as present if any entry name starts with its ID.
results_folder = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/Lukas_aug_3_try_finishes"

# Function to check if a sample ID has corresponding files
def check_sample_files(sample_id, folder_path):
    """Return True if at least one entry in ``folder_path`` starts with ``sample_id``.

    Both the folder and the ID are escaped before being used in the glob pattern.
    IDs or paths containing glob metacharacters (``*``, ``?``, ``[``) are therefore
    matched literally instead of being treated as wildcards. Non-string IDs (e.g.
    integers read by pandas) are converted with ``str()`` first.

    Args:
        sample_id: Sample identifier to look for as a file-name prefix.
        folder_path: Directory to search (not recursive).

    Returns:
        bool: True if any matching file or directory exists, False otherwise.
    """
    # NOTE(review): this is a prefix match, so ID "AZ1" would also match files of "AZ10".
    pattern = os.path.join(glob.escape(folder_path), glob.escape(str(sample_id)) + "*")
    return len(glob.glob(pattern)) > 0

# Main function
def find_missing_samples(csv_path=None, folder_path=None):
    """Report sample IDs from the reference CSV that have no file in the results folder.

    Args:
        csv_path: Reference CSV to read. Defaults to the module-level ``csv_file_path``.
        folder_path: Results folder to search. Defaults to the module-level ``results_folder``.

    Side effects:
        Prints a report. If any IDs are missing, they are also written one per line
        to ``missing_samples.txt`` in the current working directory. Errors are
        printed rather than raised, so the notebook keeps running.
    """
    # Resolve defaults at call time so reassigning the globals in a later cell takes effect.
    csv_path = csv_file_path if csv_path is None else csv_path
    folder_path = results_folder if folder_path is None else folder_path

    print(f"Loading CSV file: {csv_path}")

    try:
        df = pd.read_csv(csv_path)

        # Prefer the expected 'sample-id' column. Otherwise fall back to the first
        # column whose name mentions 'id' or 'sample'.
        if 'sample-id' in df.columns:
            id_column = 'sample-id'
        else:
            possible_cols = [
                col for col in df.columns
                if 'id' in str(col).lower() or 'sample' in str(col).lower()
            ]
            if not possible_cols:
                print(f"Column 'sample-id' not found. Available columns: {df.columns.tolist()}")
                return
            print(f"Column 'sample-id' not found. Possible ID columns: {possible_cols}")
            id_column = possible_cols[0]
            print(f"Using '{id_column}' as the sample ID column")

        # Drop empty IDs: NaN would otherwise be searched for as the file prefix "nan".
        all_sample_ids = df[id_column].dropna().unique()
        total = len(all_sample_ids)
        print(f"Found {total} unique sample IDs in the CSV file")

        if not os.path.exists(folder_path):
            # Report the folder that is actually missing, not its parent directory.
            print(f"Warning: Results folder path does not exist: {folder_path}")
            return

        all_files = os.listdir(folder_path)
        print(f"Found {len(all_files)} files in the results folder")

        missing_samples = [
            sample_id for sample_id in all_sample_ids
            if not check_sample_files(sample_id, folder_path)
        ]

        if missing_samples:
            print(f"\nFound {len(missing_samples)} missing sample IDs:")
            for sample_id in missing_samples:
                print(f"- {sample_id}")

            # Save missing sample IDs to file (relative to the current working directory)
            output_file = "missing_samples.txt"
            with open(output_file, 'w') as f:
                f.writelines(f"{sample_id}\n" for sample_id in missing_samples)
            print(f"\nMissing sample IDs saved to {output_file}")
        else:
            print("\nAll sample IDs from the CSV file have corresponding files in the folder.")

        present_count = total - len(missing_samples)

        def _pct(count):
            # Avoid ZeroDivisionError when the CSV contains no usable IDs.
            return count / total * 100 if total else 0.0

        print("\nSummary:")
        print(f"- Total unique sample IDs in CSV: {total}")
        print(f"- Sample IDs with files present: {present_count} ({_pct(present_count):.1f}%)")
        print(f"- Sample IDs missing files: {len(missing_samples)} ({_pct(len(missing_samples)):.1f}%)")

    except Exception as e:
        # Best-effort notebook helper: report and continue instead of raising.
        print(f"Error: {e}")

# Run the function.
# In a notebook cell `__name__` is "__main__", so this guard always runs the check when the cell executes.
if __name__ == "__main__":
    find_missing_samples()

In [None]:
import os
import json
from pathlib import Path

# Define the paths to the two folders
# Folder 1: intermediate results of the 'Lukas_Target_incomplete' run.
folder1 = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/Lukas_Target_incomplete"
# Folder 2: intermediate results of the 'Lukas_aug_3_try_finishes' run.
folder2 = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/Lukas_aug_3_try_finishes"

def get_sample_ids(folder_path):
    """Collect the sample IDs encoded in the ``*.json`` file names of a folder.

    A sample ID is the part of a file name before its first underscore. JSON files
    whose names contain no underscore are ignored. If the folder does not exist,
    an error is printed and an empty set is returned.
    """
    if not os.path.exists(folder_path):
        print(f"Error: Folder {folder_path} does not exist.")
        return set()

    found_ids = set()
    for entry_name in os.listdir(folder_path):
        if not entry_name.endswith('.json'):
            continue
        prefix, separator, _remainder = entry_name.partition('_')
        # Only names that actually contain an underscore carry an ID.
        if separator:
            found_ids.add(prefix)
    return found_ids

# Gather the sample IDs found in each folder
folder1_ids = get_sample_ids(folder1)
folder2_ids = get_sample_ids(folder2)

# Report the size of each ID set
print(f"Folder 1 (Lukas_Target_incomplete) has {len(folder1_ids)} sample IDs")
print(f"Folder 2 (Lukas_aug_3_try_finishes) has {len(folder2_ids)} sample IDs")

# Compare the two sets in both directions
if folder1_ids != folder2_ids:
    print("\nMismatches found:")

    # IDs that only folder 1 has
    missing_in_folder2 = folder1_ids - folder2_ids
    if missing_in_folder2:
        print(f"\nSample IDs in Folder 1 but missing in Folder 2 ({len(missing_in_folder2)}):")
        for sample_id in sorted(missing_in_folder2):
            print(f"  - {sample_id}")

    # IDs that only folder 2 has
    missing_in_folder1 = folder2_ids - folder1_ids
    if missing_in_folder1:
        print(f"\nSample IDs in Folder 2 but missing in Folder 1 ({len(missing_in_folder1)}):")
        for sample_id in sorted(missing_in_folder1):
            print(f"  - {sample_id}")
else:
    print("Both folders have the same sample IDs.")

# Brief summary of which folder holds more IDs (nothing is printed when the counts are equal)
count_gap = len(folder1_ids) - len(folder2_ids)
if count_gap > 0:
    print(f"\nFolder 1 has {count_gap} more sample IDs than Folder 2")
elif count_gap < 0:
    print(f"\nFolder 2 has {-count_gap} more sample IDs than Folder 1")

# Analysis Data

## Look at all molecules and sort them

In [None]:
import json
import pandas as pd
from pathlib import Path
from typing import Dict, List, Any

def process_single_json(json_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Process a single JSON file's candidate analysis and combine all molecules.

    Candidates are collected from the 'forward_synthesis', 'mol2mol' and 'mmst'
    sections of ``json_data["molecule_data"]["candidate_analysis"]``. Missing or
    null nested sections (e.g. ``"nmr_analysis": null``) are treated as empty.
    The corresponding scores then become None instead of raising AttributeError.

    Args:
        json_data: Dictionary containing the JSON data with molecule_data

    Returns:
        List of dictionaries containing combined and processed molecule data,
        sorted by HSQC score ascending; molecules without an HSQC score go last.
        Molecules lacking a 'smiles' key are skipped with a warning.

    Raises:
        KeyError: if 'molecule_data' or 'candidate_analysis' is missing.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except KeyError as e:
        raise KeyError(f"Missing required key in JSON structure: {e}") from e

    # A null candidate_analysis is treated like an empty one.
    candidate_analysis = candidate_analysis or {}
    all_molecules = []

    # Analysis types to process
    analysis_types = ['forward_synthesis', 'mol2mol', 'mmst']

    for analysis_type in analysis_types:
        section = candidate_analysis.get(analysis_type) or {}
        for mol in section.get('molecules') or []:
            try:
                smiles = mol['smiles']  # Required field
            except KeyError as e:
                print(f"Warning: Skipping molecule due to missing required field: {e}")
                continue

            # `or {}` also covers keys that are present but null in the JSON.
            nmr = mol.get('nmr_analysis') or {}
            scores = nmr.get('matching_scores') or {}
            by_spectrum = scores.get('by_spectrum') or {}
            gen_info = mol.get('generation_info') or {}

            all_molecules.append({
                'smiles': smiles,
                'analysis_type': analysis_type,
                'hsqc_score': by_spectrum.get('HSQC'),
                'overall_score': scores.get('overall'),
                'h_nmr_score': by_spectrum.get('1H'),
                'c_nmr_score': by_spectrum.get('13C'),
                'cosy_score': by_spectrum.get('COSY'),
                'source': gen_info.get('source', ''),
                'parent_smiles': gen_info.get('parent_smiles', ''),
                'starting_material': gen_info.get('starting_material', ''),
                'log_likelihood': gen_info.get('log_likelihood'),
            })

    # Sort by HSQC score in ascending order (low to high); missing scores sort last.
    all_molecules.sort(key=lambda x: x['hsqc_score'] if x['hsqc_score'] is not None else float('inf'))

    return all_molecules

def analyze_json_file(file_path: str) -> pd.DataFrame:
    """
    Load one intermediate-results JSON file and tabulate its candidate molecules.

    The reference structure is read from ``molecule_data['smiles']`` in the JSON
    itself. Each candidate's ``is_match`` is exact SMILES string equality with that
    reference; no canonicalisation is applied.

    Args:
        file_path: Path to the JSON file

    Returns:
        DataFrame with one row per candidate, in the order produced by
        ``process_single_json``. An empty DataFrame is returned (after printing a
        message) when the file cannot be processed or yields no candidates.
    """
    column_order = [
        'sample_id', 'true_smiles', 'smiles', 'is_match',
        'analysis_type', 'hsqc_score', 'overall_score',
        'h_nmr_score', 'c_nmr_score', 'cosy_score',
        'source', 'parent_smiles', 'starting_material', 'log_likelihood'
    ]

    try:
        with open(file_path, 'r') as handle:
            payload = json.load(handle)

        mol_info = payload.get('molecule_data', {})
        true_smiles = mol_info.get('smiles')
        sample_id = mol_info.get('sample_id')

        # Both fields are required to label the candidates.
        if not (true_smiles and sample_id):
            raise ValueError("Missing required molecule_data fields (smiles or sample_id)")

        candidates = process_single_json(payload)
        if not candidates:
            print(f"Warning: No valid molecules found in {file_path}")
            return pd.DataFrame()

        table = pd.DataFrame(candidates).assign(
            true_smiles=true_smiles,
            sample_id=sample_id,
        )
        table['is_match'] = table['smiles'] == true_smiles
        return table[column_order]

    except Exception as e:
        print(f"Error processing file {file_path}: {str(e)}")
        return pd.DataFrame()

In [None]:

# Example usage for a single file
if __name__ == "__main__":
    # Replace with your JSON file path
    file_path = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6.0_exp_d1_aug_finished/AZ10421813_exp_d1_aug_intermediate.json"
    # Number of best-scoring candidates to display (rows come back sorted by HSQC score).
    top_n = 10

    try:
        # Process single file
        results_df = analyze_json_file(file_path)

        if results_df.empty:
            # analyze_json_file has already printed why (read error or no candidates).
            print("\nNo candidates to analyze; nothing saved.")
        else:
            # Display summary
            print("\nAnalysis Summary:")
            print(f"Total candidates: {len(results_df)}")
            print(f"\nTop {top_n} candidates by HSQC score:")
            print(results_df[['smiles', 'analysis_type', 'hsqc_score', 'overall_score', 'is_match']].head(top_n))

            # Check if correct structure is in top candidates
            if results_df['is_match'].any():
                # argmax on a boolean column gives the first True position; +1 makes it a 1-based rank.
                match_rank = results_df['is_match'].argmax() + 1
                print(f"\nTrue structure found at rank {match_rank}")
            else:
                print("\nTrue structure not found in candidates")

            # Save results
            output_file = Path(file_path).stem + "_analysis.csv"
            results_df.to_csv(output_file, index=False)
            print(f"\nDetailed results saved to: {output_file}")

    except Exception as e:
        print(f"Error processing file: {e}")

## Check the rank of the target molecule for correct and incorrect starting guesses

In [None]:
import json
import pandas as pd
import numpy as np
from pathlib import Path
from typing import Dict, List, Any
from collections import defaultdict
import glob
import os

def load_reference_data(csv_path: str) -> Dict[str, str]:
    """
    Load a reference CSV mapping each sample ID to its correct SMILES.

    Args:
        csv_path: Path to a CSV file with columns 'sample-id' and 'SMILES'
            (any extra columns are ignored).

    Returns:
        dict mapping 'sample-id' -> 'SMILES'. If a sample ID occurs more than once,
        the last row wins.

    Raises:
        Exception: wrapping any read or validation error (including missing
            columns); the original exception is chained as ``__cause__``.
    """
    required_columns = ['sample-id', 'SMILES']
    try:
        ref_df = pd.read_csv(csv_path)

        missing = [col for col in required_columns if col not in ref_df.columns]
        if missing:
            raise ValueError(
                f"Reference CSV must contain columns: {required_columns} (missing: {missing})"
            )

        # Dictionary gives O(1) lookups by sample ID.
        return ref_df.set_index('sample-id')['SMILES'].to_dict()

    except Exception as e:
        raise Exception(f"Error loading reference data: {str(e)}") from e

def process_single_json(json_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Process a single JSON file's candidate analysis and combine all molecules.

    Candidates are collected from the 'forward_synthesis', 'mol2mol' and 'mmst'
    sections of ``json_data["molecule_data"]["candidate_analysis"]``. Missing or
    null nested sections (e.g. ``"nmr_analysis": null``) are treated as empty.
    The corresponding scores then become None instead of raising AttributeError.

    Args:
        json_data: Dictionary containing the JSON data with molecule_data

    Returns:
        List of dictionaries containing combined and processed molecule data,
        sorted by HSQC score ascending; molecules without an HSQC score go last.
        Molecules lacking a 'smiles' key are skipped with a warning.

    Raises:
        KeyError: if 'molecule_data' or 'candidate_analysis' is missing.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except KeyError as e:
        raise KeyError(f"Missing required key in JSON structure: {e}") from e

    # A null candidate_analysis is treated like an empty one.
    candidate_analysis = candidate_analysis or {}
    all_molecules = []
    analysis_types = ['forward_synthesis', 'mol2mol', 'mmst']

    for analysis_type in analysis_types:
        section = candidate_analysis.get(analysis_type) or {}
        for mol in section.get('molecules') or []:
            try:
                smiles = mol['smiles']  # Required field
            except KeyError as e:
                print(f"Warning: Skipping molecule due to missing required field: {e}")
                continue

            # `or {}` also covers keys that are present but null in the JSON.
            nmr = mol.get('nmr_analysis') or {}
            scores = nmr.get('matching_scores') or {}
            by_spectrum = scores.get('by_spectrum') or {}
            gen_info = mol.get('generation_info') or {}

            all_molecules.append({
                'smiles': smiles,
                'analysis_type': analysis_type,
                'hsqc_score': by_spectrum.get('HSQC'),
                'overall_score': scores.get('overall'),
                'h_nmr_score': by_spectrum.get('1H'),
                'c_nmr_score': by_spectrum.get('13C'),
                'cosy_score': by_spectrum.get('COSY'),
                'source': gen_info.get('source', ''),
                'parent_smiles': gen_info.get('parent_smiles', ''),
                'starting_material': gen_info.get('starting_material', ''),
                'log_likelihood': gen_info.get('log_likelihood'),
            })

    # Sort by HSQC score in ascending order (low to high); missing scores sort last.
    all_molecules.sort(key=lambda x: x['hsqc_score'] if x['hsqc_score'] is not None else float('inf'))

    return all_molecules

def get_base_sample_id(sample_id: str) -> str:
    """
    Return the portion of ``sample_id`` that precedes its first underscore.

    The whole ID is returned when it contains no underscore; empty or None input
    yields ''.
    """
    if not sample_id:
        return ''
    head, _sep, _tail = sample_id.partition('_')
    return head

def analyze_json_file(file_path: str, reference_data: Dict[str, str]) -> tuple[pd.DataFrame, dict]:
    """
    Rank the candidates of one JSON result file against the reference SMILES.

    The reference SMILES is looked up by the base sample ID (the part before the
    first '_'). ``is_match`` is exact string equality with that reference.

    Args:
        file_path: Path to the JSON result file.
        reference_data: Mapping of base sample ID -> correct SMILES.

    Returns:
        (candidates DataFrame, per-file stats dict). When the file fails to load,
        has no reference entry, or yields no candidates, a message is printed and
        ``(pd.DataFrame(), {})`` is returned.
    """
    try:
        with open(file_path, 'r') as handle:
            payload = json.load(handle)

        mol_info = payload.get('molecule_data', {})
        sample_id = mol_info.get('sample_id')
        if not sample_id:
            raise ValueError("Missing required sample_id in JSON file")

        base_id = get_base_sample_id(sample_id)
        true_smiles = reference_data.get(base_id)
        if true_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_id} (original: {sample_id})")
            return pd.DataFrame(), {}

        candidates = process_single_json(payload)
        if not candidates:
            print(f"Warning: No valid molecules found in {file_path}")
            return pd.DataFrame(), {}

        table = pd.DataFrame(candidates)
        table['true_smiles'] = true_smiles
        table['sample_id'] = sample_id
        table['target_smiles'] = mol_info.get('smiles', '')  # target SMILES recorded in the JSON
        table['is_match'] = table['smiles'] == true_smiles

        matches = table['is_match']
        found = matches.any()
        stats = {
            'sample_id': sample_id,
            'total_candidates': len(table),
            'match_found': found,
            # Rows are in HSQC order, so the first True position is the 1-based rank minus one.
            'match_rank': matches.argmax() + 1 if found else None,
            'in_top_5': matches.iloc[:5].any(),
            'hsqc_score_of_true': table.loc[matches, 'hsqc_score'].iloc[0] if found else None,
            'target_smiles': table['target_smiles'].iloc[0],
            'true_smiles': true_smiles,
        }
        return table, stats

    except Exception as e:
        print(f"Error processing file {file_path}: {str(e)}")
        return pd.DataFrame(), {}

def analyze_directory(directory_path: str, reference_csv: str) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Analyze all JSON files in a directory using reference data from CSV.

    Files are processed in sorted file-name order so repeated runs produce
    identically ordered outputs; ``glob`` itself returns an arbitrary order.

    Args:
        directory_path: Directory containing the ``*.json`` result files.
        reference_csv: CSV with 'sample-id' and 'SMILES' columns.

    Returns:
        (combined_results, stats_df): all candidate rows concatenated, and one
        stats row per file that produced candidates. Both are empty DataFrames
        when nothing usable was found.

    Raises:
        Exception: propagated from ``load_reference_data`` if the CSV is unusable.
    """
    print(f"Loading reference data from: {reference_csv}")
    reference_data = load_reference_data(reference_csv)

    # Escape the directory so glob metacharacters in the path are taken literally.
    json_files = sorted(glob.glob(os.path.join(glob.escape(directory_path), "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    all_results = []
    all_stats = []

    for file_path in json_files:
        df, stats = analyze_json_file(file_path, reference_data)
        # Files that failed or had no candidates come back as an empty DataFrame.
        if not df.empty:
            all_results.append(df)
            all_stats.append(stats)

    combined_results = pd.concat(all_results, ignore_index=True) if all_results else pd.DataFrame()
    stats_df = pd.DataFrame(all_stats) if all_stats else pd.DataFrame()

    return combined_results, stats_df

def generate_ranking_statistics(stats_df: pd.DataFrame) -> Dict[str, Any]:
    """
    Generate summary statistics from the per-file analysis results.

    Args:
        stats_df: One row per analysed file with at least the columns
            'in_top_5', 'match_found' (bool) and 'match_rank' (1-based rank,
            None/NaN when the true molecule was not found).

    Returns:
        dict with 'total_analyzed', 'found_in_top_5', 'found_in_top_5_percent',
        'total_found', 'total_found_percent' and 'rank_distribution'. The latter
        maps bucket label -> count in bucket order, and is {} when nothing was found.
        When at least one rank exists it also contains 'rank_min', 'rank_max',
        'rank_mean', 'rank_median' and 'rank_std'.
    """
    if stats_df.empty:
        return {
            'total_analyzed': 0,
            'found_in_top_5': 0,
            'found_in_top_5_percent': 0,
            'total_found': 0,
            'total_found_percent': 0,
            'rank_distribution': {}
        }

    total_files = len(stats_df)
    in_top_5 = stats_df['in_top_5'].sum()
    total_found = stats_df['match_found'].sum()

    stats = {
        'total_analyzed': total_files,
        'found_in_top_5': in_top_5,
        'found_in_top_5_percent': (in_top_5 / total_files) * 100,
        'total_found': total_found,
        'total_found_percent': (total_found / total_files) * 100,
    }

    # Ranks of the files where the true molecule was found
    found_ranks = stats_df.loc[stats_df['match_found'].astype(bool), 'match_rank'].dropna()
    if not found_ranks.empty:
        stats.update({
            'rank_min': found_ranks.min(),
            'rank_max': found_ranks.max(),
            'rank_mean': found_ranks.mean(),
            'rank_median': found_ranks.median(),
            'rank_std': found_ranks.std()
        })

        # Ranks are 1-based, so use right-closed bins: (0,1] -> '1st', (1,5] -> '2-5', and so on.
        # With right=False the bins would be [1,5), [5,10), ..., which files rank 1 under '2-5'.
        rank_bins = [0, 1, 5, 10, 20, 50, float('inf')]
        rank_labels = ['1st', '2-5', '6-10', '11-20', '21-50', '51+']
        rank_dist = (
            pd.cut(found_ranks, bins=rank_bins, labels=rank_labels, right=True)
            .value_counts()
            .reindex(rank_labels, fill_value=0)  # keep bucket order, include empty buckets
        )
        stats['rank_distribution'] = rank_dist.to_dict()

    # Same schema as the empty-input case even when nothing was found.
    stats.setdefault('rank_distribution', {})
    return stats

if __name__ == "__main__":
    # Replace these paths with your actual paths
    json_directory = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6.0_exp_d1_aug_finished"
    reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"

    print("Starting analysis...")
    print(f"JSON directory: {json_directory}")
    print(f"Reference CSV: {reference_csv}")

    # Analyze all files
    combined_results, stats_df = analyze_directory(json_directory, reference_csv)

    if combined_results.empty:
        # Stop here without calling exit(): in a Jupyter kernel exit() terminates
        # the kernel instead of just ending this cell.
        print("No results were generated. Please check your input files and paths.")
    else:
        # Generate statistics
        summary_stats = generate_ranking_statistics(stats_df)

        # Print summary
        print("\nAnalysis Summary:")
        print(f"Total files analyzed: {summary_stats['total_analyzed']}")
        print(f"Molecules found in top 5: {summary_stats['found_in_top_5']} ({summary_stats['found_in_top_5_percent']:.1f}%)")
        print(f"Total molecules found: {summary_stats['total_found']} ({summary_stats['total_found_percent']:.1f}%)")

        if summary_stats['total_found'] > 0:
            print("\nRank Statistics for Found Molecules:")
            print(f"Min rank: {summary_stats['rank_min']}")
            print(f"Max rank: {summary_stats['rank_max']}")
            print(f"Mean rank: {summary_stats['rank_mean']:.1f}")
            print(f"Median rank: {summary_stats['rank_median']}")

            print("\nRank Distribution:")
            for rank_range, count in summary_stats.get('rank_distribution', {}).items():
                print(f"{rank_range}: {count}")

        # Save detailed results next to the input JSON files
        output_dir = Path(json_directory) / "analysis_results"
        output_dir.mkdir(exist_ok=True)

        # Save combined results
        combined_results.to_csv(output_dir / "all_molecules.csv", index=False)

        # Save statistics
        stats_df.to_csv(output_dir / "file_statistics.csv", index=False)

        # Save summary statistics
        pd.DataFrame([summary_stats]).to_csv(output_dir / "summary_statistics.csv", index=False)

        print(f"\nDetailed results saved to: {output_dir}")

## Plot histograms

### V1 - OLD

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os

def load_reference_data(csv_path):
    """Read the reference CSV and return a ``{sample-id: SMILES}`` dictionary (last duplicate wins)."""
    reference_df = pd.read_csv(csv_path)
    return dict(zip(reference_df['sample-id'], reference_df['SMILES']))

def get_base_sample_id(sample_id):
    """Return everything before the first '_' in ``sample_id`` ('' for empty/None input)."""
    if sample_id:
        return sample_id.split('_', 1)[0]
    return ''

def process_single_json(json_data):
    """Process a single JSON file and return its candidate molecules sorted by HSQC score.

    Returns a list of ``{'smiles', 'hsqc_score'}`` dicts drawn from the
    'forward_synthesis', 'mol2mol' and 'mmst' sections. The list is sorted by HSQC
    score ascending, with missing scores last. Returns [] when 'molecule_data' or
    'candidate_analysis' is absent. Null nested sections (e.g. ``"nmr_analysis": null``)
    count as empty, so such molecules get ``hsqc_score=None`` instead of raising
    AttributeError and aborting the whole file. Molecules without 'smiles' are skipped.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except KeyError:
        return []
    # A null candidate_analysis is treated like an empty one.
    candidate_analysis = candidate_analysis or {}

    all_molecules = []
    analysis_types = ['forward_synthesis', 'mol2mol', 'mmst']

    for analysis_type in analysis_types:
        section = candidate_analysis.get(analysis_type) or {}
        for mol in section.get('molecules') or []:
            try:
                smiles = mol['smiles']
            except KeyError:
                continue
            # `or {}` also covers keys that are present but null in the JSON.
            nmr = mol.get('nmr_analysis') or {}
            by_spectrum = (nmr.get('matching_scores') or {}).get('by_spectrum') or {}
            all_molecules.append({'smiles': smiles, 'hsqc_score': by_spectrum.get('HSQC')})

    # Sort by HSQC score (ascending); missing scores sort last
    all_molecules.sort(key=lambda x: x['hsqc_score'] if x['hsqc_score'] is not None else float('inf'))
    return all_molecules

def find_molecule_rank(molecules, true_smiles):
    """Return the 1-based position of the first molecule whose SMILES equals ``true_smiles``, else None."""
    return next(
        (position for position, candidate in enumerate(molecules, start=1)
         if candidate['smiles'] == true_smiles),
        None,
    )

def analyze_directory(json_dir, reference_csv):
    """Analyze all JSON files in ``json_dir`` and return the ranks of the correct molecules.

    Args:
        json_dir: Directory containing the per-sample ``*.json`` result files.
        reference_csv: CSV with 'sample-id' and 'SMILES' columns.

    Returns:
        list of 1-based ranks, one per file in which the reference SMILES was found
        among the candidates. A file is skipped if it:
            - has no sample_id,
            - has no reference SMILES,
            - contains no candidates,
            - does not contain the true molecule, or
            - fails to load.
        The list can therefore be shorter than the number of files.
    """
    # Load reference data
    reference_data = load_reference_data(reference_csv)

    # Escape the directory so glob metacharacters in it are taken literally, and
    # sort so files are always processed in the same order.
    json_files = sorted(glob.glob(os.path.join(glob.escape(json_dir), "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    rankings = []

    for file_path in json_files:
        try:
            with open(file_path, 'r') as f:
                json_data = json.load(f)

            molecule_data = json_data.get('molecule_data', {})
            sample_id = molecule_data.get('sample_id')
            if not sample_id:
                continue

            # Reference data is looked up by the base ID (part before the first '_').
            base_sample_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            # Process molecules and find rank
            molecules = process_single_json(json_data)
            if not molecules:
                continue

            rank = find_molecule_rank(molecules, true_smiles)
            if rank is not None:
                rankings.append(rank)

        except Exception as e:
            # Best effort: report the broken file and keep going.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    return rankings

def plot_ranking_histogram(rankings, experiment_label="", max_rank=10, figsize=(12, 6)):
    """
    Create histogram of molecule rankings.

    Only ranks ``1..max_rank`` are drawn as bars. The statistics box and the printed
    distribution use all ``rankings``.

    Args:
        rankings: List of 1-based rankings for correct molecules
        experiment_label: Label for the experiment to be shown in title
        max_rank: Maximum rank to show in histogram (default: 10); also the cut-off
            for the second "Found in top N" line of the statistics box
        figsize: Size of the figure (default: (12, 6))

    Returns:
        The matplotlib Figure. An empty ``rankings`` list gives an empty plot with
        0.0% statistics instead of raising ZeroDivisionError.
    """
    # One bin per integer rank, centred on the integer.
    bins = np.arange(1, max_rank + 2) - 0.5

    # Create figure with white background
    fig, ax = plt.subplots(figsize=figsize, facecolor='white')
    ax.set_facecolor('white')

    # Plot histogram
    counts, _edges, _patches = ax.hist(
        [r for r in rankings if r <= max_rank],
        bins=bins,
        edgecolor='black',
        alpha=0.7,
        color='#4169E1'  # Royal Blue
    )

    # Customize the plot
    title = 'Distribution of Correct Molecule Rankings'
    if experiment_label:
        title += f' - {experiment_label}'
    ax.set_title(title, fontsize=14, pad=20)
    ax.set_xlabel('Rank', fontsize=12)
    ax.set_ylabel('Number of Molecules', fontsize=12)

    # Set x-axis ticks
    ax.set_xticks(range(1, max_rank + 1))

    # Add grid with light gray color
    ax.grid(True, alpha=0.3, color='gray', linestyle='--')

    # Add counts above bars (bar i is centred on rank i + 1)
    for i, count in enumerate(counts):
        if count > 0:
            ax.text(i + 1, count, f'{int(count)}',
                    ha='center', va='bottom')

    # Calculate statistics
    total_molecules = len(rankings)
    in_top_5 = sum(1 for r in rankings if r <= 5)
    in_top_max = sum(1 for r in rankings if r <= max_rank)

    def _pct(count):
        # Guard against an empty rankings list.
        return count / total_molecules * 100 if total_molecules else 0.0

    stats_text = (
        f'Total molecules found: {total_molecules}\n'
        f'Found in top 5: {in_top_5} ({_pct(in_top_5):.1f}%)\n'
        f'Found in top {max_rank}: {in_top_max} ({_pct(in_top_max):.1f}%)'
    )

    # Add statistics text box
    ax.text(0.95, 0.95, stats_text,
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Adjust layout
    plt.tight_layout()

    # Print detailed distribution
    print(f"\nDetailed rank distribution for {experiment_label if experiment_label else 'experiment'}:")
    rank_counts = pd.Series(rankings).value_counts().sort_index()
    for rank, count in rank_counts.items():
        print(f"Rank {rank}: {count} molecules")

    return fig

def main():
    """Plot the rank histogram for each configured experiment.

    Each experiment label maps to its results directory and the reference CSV.
    Experiments that yield no rankings are reported and skipped.
    """
    # All runs live under the same results root and share one reference CSV.
    results_root = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results"
    reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"

    run_dirs = {
        "Simulated Data": "_run_1_sim_finished_clean",
        "Simulated Data with Wrong Guess": "_run_2_sim_aug_finished_clean",
        "Simulated Data with Noise": "_run_3_sim+noise_finished_clean",
        "Experimental Data": "_run_4_exp_d1_finished_clean",
        "Experimental Data with Wrong Guess": "_run_5_exp_d1_aug_finished_clean",
        "Experimental Data d4": "_run_6_exp_d4_finished_clean",
    }
    experiments = {
        label: {
            "json_directory": os.path.join(results_root, run_dir),
            "reference_csv": reference_csv,
        }
        for label, run_dir in run_dirs.items()
    }

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")

        # Get rankings
        rankings = analyze_directory(paths["json_directory"], paths["reference_csv"])

        if not rankings:
            print(f"No valid rankings found for {exp_label}. Please check your input files.")
            continue

        # Create and show plot
        fig = plot_ranking_histogram(rankings, experiment_label=exp_label)
        plt.show()

        # Optionally save the plot
        # fig.savefig(f'ranking_histogram_{exp_label}.png', dpi=300, bbox_inches='tight')

        # Release the figure so looping over experiments does not accumulate open figures.
        plt.close(fig)

# Entry point: in a notebook cell `__name__` is "__main__", so main() runs whenever this cell executes.
if __name__ == "__main__":
    main()

### V1.2


In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os

def load_reference_data(csv_path):
    """
    Read the reference CSV and map each 'sample-id' to its 'SMILES' string.

    When a sample-id occurs more than once, the last row wins.
    """
    reference_frame = pd.read_csv(csv_path)
    id_column = reference_frame['sample-id']
    smiles_column = reference_frame['SMILES']
    return dict(zip(id_column, smiles_column))

def get_base_sample_id(sample_id):
    """Return the part of ``sample_id`` before the first underscore ('' for empty/None)."""
    if not sample_id:
        return ''
    head, _sep, _tail = sample_id.partition('_')
    return head

def process_single_json(json_data):
    """
    Collect candidate molecules from one result JSON and order them by HSQC score.

    Candidates are read from ``json_data["molecule_data"]["candidate_analysis"]``
    under the keys 'forward_synthesis', 'mol2mol' and 'mmst' (in that order).
    Each candidate becomes ``{'smiles': ..., 'hsqc_score': ...}`` with the score
    taken from ``nmr_analysis.matching_scores.by_spectrum.HSQC``.

    Sorting is ascending by score. Candidates without a usable score (missing,
    null, or NaN) go last. Python's sort is stable, so ties keep the source
    order listed above.

    Missing or null sections at any nesting level are treated as empty instead
    of raising. Previously a single candidate with ``"nmr_analysis": null``
    raised AttributeError and caused the caller to discard the whole file.

    Args:
        json_data: Parsed JSON object for one sample.

    Returns:
        List of candidate dicts sorted as described; empty if there are none.
    """
    molecule_data = json_data.get('molecule_data') if isinstance(json_data, dict) else None
    candidate_analysis = molecule_data.get('candidate_analysis') if isinstance(molecule_data, dict) else None
    if not isinstance(candidate_analysis, dict):
        return []

    def _nested_get(obj, *keys):
        # Walk nested dicts, returning None as soon as a level is missing or not a dict.
        for key in keys:
            if not isinstance(obj, dict):
                return None
            obj = obj.get(key)
        return obj

    all_molecules = []
    for analysis_type in ('forward_synthesis', 'mol2mol', 'mmst'):
        section = candidate_analysis.get(analysis_type)
        if not isinstance(section, dict):
            continue
        for mol in section.get('molecules') or []:
            # Skip malformed entries rather than failing the whole file.
            if not isinstance(mol, dict) or 'smiles' not in mol:
                continue
            all_molecules.append({
                'smiles': mol['smiles'],
                'hsqc_score': _nested_get(mol, 'nmr_analysis', 'matching_scores', 'by_spectrum', 'HSQC'),
            })

    def _sort_key(entry):
        score = entry['hsqc_score']
        # None and NaN (score != score) sort last; NaN would otherwise scramble the ordering.
        if score is None or score != score:
            return float('inf')
        return score

    all_molecules.sort(key=_sort_key)
    return all_molecules

def find_molecule_rank(molecules, true_smiles):
    """Return the 1-based position of ``true_smiles`` in ``molecules``.

    Matching is exact string equality on each entry's 'smiles' value, and the
    first matching entry wins. Returns None when no entry matches.
    """
    matches = (position for position, entry in enumerate(molecules, start=1)
               if entry['smiles'] == true_smiles)
    return next(matches, None)

def analyze_directory(json_dir, reference_csv):
    """
    Rank the correct molecule for every sample JSON in ``json_dir``.

    For each ``*.json`` file, the sample's base ID (text before the first "_")
    is looked up in the reference CSV. The candidates are then ordered by
    ``process_single_json``, and the 1-based position of the reference SMILES
    is recorded.

    Files are processed in sorted filename order, so console output and the
    order of the returned list are reproducible across runs and filesystems.

    Args:
        json_dir: Directory containing the per-sample JSON result files.
        reference_csv: CSV with ``sample-id`` and ``SMILES`` columns.

    Returns:
        List of integer ranks, one per sample whose correct molecule was found
        among its candidates. The following samples are skipped:
        - samples with no sample_id or no reference SMILES;
        - samples with no candidates, or whose correct molecule is not among
          them (the last two cases are counted in a printed summary).
        Unreadable files are reported and skipped.
    """
    reference_data = load_reference_data(reference_csv)

    # glob returns files in arbitrary order; sort for reproducibility.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    rankings = []
    skipped = {'no_candidates': 0, 'not_in_candidates': 0}

    for file_path in json_files:
        try:
            with open(file_path, 'r') as f:
                json_data = json.load(f)

            # `or {}` also covers an explicit null "molecule_data".
            molecule_data = json_data.get('molecule_data') or {}
            sample_id = molecule_data.get('sample_id')
            if not sample_id:
                continue

            # Reference rows are keyed by the base sample ID
            base_sample_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            molecules = process_single_json(json_data)
            if not molecules:
                skipped['no_candidates'] += 1
                continue

            rank = find_molecule_rank(molecules, true_smiles)
            if rank is None:
                skipped['not_in_candidates'] += 1
            else:
                rankings.append(rank)

        except Exception as e:
            # Best-effort batch: report the bad file and keep going.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    # Make the silently-excluded samples visible; they are not part of the returned denominator.
    if any(skipped.values()):
        print(f"Skipped {skipped['no_candidates']} samples with no candidates and "
              f"{skipped['not_in_candidates']} samples whose correct molecule was not among the candidates")

    return rankings

def plot_ranking_histogram(rankings, experiment_label="", max_rank=5, color="#4169E1", figsize=(5.5, 5.5)):
    """
    Create histogram of molecule rankings with an extra bin for ranks beyond max_rank.

    Ranks 1..max_rank each get their own bar. Every rank above max_rank is
    pooled into a final bar labelled "<max_rank + 1>+" (e.g. "6+" for the
    default max_rank=5). A text box summarises the top-1, top-3 and
    top-max_rank hit rates, and the per-rank counts are printed to stdout.

    Args:
        rankings: List of 1-based rankings of the correct molecule (may be empty)
        experiment_label: Label for the experiment to be shown in title
        max_rank: Maximum individual rank to show in histogram (default: 5)
        color: Color for histogram bars
        figsize: Size of the figure (default: (5.5, 5.5) for square plot)

    Returns:
        The matplotlib Figure containing the plot.
    """
    overflow_label = f"{max_rank + 1}+"
    total_molecules = len(rankings)

    def _pct(count):
        # Guard against an empty rankings list (previously ZeroDivisionError).
        return count / total_molecules * 100 if total_molecules else 0.0

    # Create figure with white background
    fig, ax = plt.subplots(figsize=figsize, facecolor='white')
    ax.set_facecolor('white')

    # Bar heights: one per rank 1..max_rank, then the pooled overflow bin.
    positions = list(range(1, max_rank + 2))
    counts = [sum(1 for rank in rankings if rank == r) for r in range(1, max_rank + 1)]
    counts.append(sum(1 for rank in rankings if rank > max_rank))

    bars = ax.bar(
        positions,
        counts,
        width=0.8,
        edgecolor='black',
        alpha=0.7,
        color=color
    )

    # Two-line title: fixed heading plus the experiment label
    title_line1 = 'Distribution of Correct Molecule Rankings'
    title_line2 = experiment_label if experiment_label else ""
    ax.set_title(f"{title_line1}\n{title_line2}", fontsize=20, pad=10)
    ax.set_xlabel('Rank', fontsize=16)
    ax.set_ylabel('Number of Molecules', fontsize=16)

    # Tick labels follow max_rank instead of a hard-coded "6+"
    ax.set_xticks(positions)
    ax.set_xticklabels([str(i) for i in range(1, max_rank + 1)] + [overflow_label], fontsize=14)
    ax.tick_params(axis='y', labelsize=14)

    # Add grid with light gray color
    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

    # Count labels above non-empty bars
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height + 0.1,
                f'{int(height)}',
                ha='center',
                va='bottom',
                fontsize=14
            )

    # Summary statistics (the top-N cutoff follows max_rank)
    in_top_1 = sum(1 for r in rankings if r == 1)
    in_top_3 = sum(1 for r in rankings if r <= 3)
    in_top_max = sum(1 for r in rankings if r <= max_rank)
    after_top_max = sum(1 for r in rankings if r > max_rank)

    stats_text = (
        f'Total molecules found: {total_molecules}\n'
        f'Found in top 1: {in_top_1} ({_pct(in_top_1):.1f}%)\n'
        f'Found in top 3: {in_top_3} ({_pct(in_top_3):.1f}%)\n'
        f'Found in top {max_rank}: {in_top_max} ({_pct(in_top_max):.1f}%)\n'
        f'Found after top {max_rank}: {after_top_max} ({_pct(after_top_max):.1f}%)'
    )

    ax.text(0.95, 0.95, stats_text,
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            fontsize=14,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Lay out this figure explicitly rather than whatever pyplot considers "current"
    fig.tight_layout()

    # Print detailed distribution
    print(f"\nDetailed rank distribution for {experiment_label if experiment_label else 'experiment'}:")
    if rankings:
        for rank, count in pd.Series(rankings).value_counts().sort_index().items():
            print(f"Rank {rank}: {count} molecules")

    return fig
def _mmst_summary_row(exp_label, rankings):
    """Return one summary-table row (top-1/3/5/10 hit counts and percentages) for an experiment."""
    total_molecules = len(rankings)
    row = {'Experiment': exp_label, 'Tool': 'MMST', 'Total': total_molecules}
    # Insertion order gives the CSV columns: Top k, Top k %, for k = 1, 3, 5, 10
    for k in (1, 3, 5, 10):
        hits = sum(1 for r in rankings if r <= k)
        row[f'Top {k}'] = hits
        row[f'Top {k} %'] = hits / total_molecules * 100 if total_molecules > 0 else 0
    return row


def main():
    """
    Plot and save MMST ranking histograms for each configured experiment run.

    For every experiment it also writes one row of top-1/3/5/10 hit counts and
    percentages to ``mmst_summary_statistics.csv``. All files go to
    ``output_dir`` below. Experiments whose JSON directory yields no rankings
    are reported and skipped.
    """
    # Define output directory for figures
    output_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures"

    # Create output directory if it doesn't exist (os is imported at the top of this cell)
    os.makedirs(output_dir, exist_ok=True)

    # Example usage with different experiments
    experiments = {
        "Experimental Data with Wrong Guess all": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_10_exp_d1_MMST_all_new",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "Experimental Data with very Wrong Guess": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_9_exp_d1_aug_st",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "ACD Data with Wrong Guess HSQC": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_8_ACD_d1_MMST_HSQC",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "Experimental Data with Wrong Guess HSQC": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_7_exp_d1_MMST_HSQC_new",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "Lukas Data": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/Lukas_Target_incomplete",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/53_Lukas_real_data/cleaned_data_CLEAN.csv"
        },
        "Lukas Aug Data": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/Lukas_aug_3_try_",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/53_Lukas_real_data/cleaned_data_CLEAN.csv"
        },
        "Simulated Data": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "Simulated Data with Wrong Guess": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "Simulated Data with Noise": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "Experimental Data": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_4_exp_d1_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        "Experimental Data with Wrong Guess": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_5_exp_d1_aug_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        # NOTE(review): the trailing space only keeps this d4 run's key distinct from
        # "Experimental Data"; both render with the same plot title. Consider a clearer label.
        "Experimental Data ": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6_exp_d4_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        }
    }

    # Bar color per experiment (palette reused from the model-comparison figures).
    # Keys must match the experiment labels exactly; unknown labels fall back to default_color.
    default_color = "#4169E1"
    colors = {
        "Experimental Data with Wrong Guess all": "#6366F1",
        # Previously keyed "... very Wrong Guess HSQC", which raised KeyError for this experiment.
        "Experimental Data with very Wrong Guess": "#6366F1",
        "ACD Data with Wrong Guess HSQC": "#6366F1",
        "Experimental Data with Wrong Guess HSQC": "#6366F1",
        "Lukas Data": "#6366F1",                       # Claude 3.5 Sonnet
        "Lukas Aug Data": "#6366F1",                   # Claude 3.5 Sonnet
        "Simulated Data": "#6366F1",                   # Claude 3.5 Sonnet
        "Simulated Data with Wrong Guess": "#3B82F6",  # Claude 3.7 Sonnet-Thinking
        "Simulated Data with Noise": "#10B981",        # DeepSeek-R1
        "Experimental Data": "#F59E0B",                # Gemini-Thinking
        "Experimental Data with Wrong Guess": "#EC4899",  # o3-mini
        "Experimental Data ": "#F59E0B"                # Kimi 1.5
    }

    # Collect summary statistics
    summary_data = []

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")

        rankings = analyze_directory(paths["json_directory"], paths["reference_csv"])

        if not rankings:
            print(f"No valid rankings found for {exp_label}. Please check your input files.")
            continue

        # Individual plot with an extra bin for ranks 6+
        fig = plot_ranking_histogram(
            rankings,
            experiment_label=exp_label,
            max_rank=5,
            color=colors.get(exp_label, default_color),
            figsize=(5.5, 5.5)
        )

        # Save before showing: some backends clear or close the figure on show().
        output_path = os.path.join(output_dir, f'ranking_histogram_{exp_label.replace(" ", "_")}_MMST.png')
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")
        plt.show()
        # Release the figure; this loop creates one per experiment.
        plt.close(fig)

        summary_data.append(_mmst_summary_row(exp_label, rankings))

    # Save summary statistics to CSV (best-effort: report failures without aborting)
    try:
        if summary_data:
            summary_df = pd.DataFrame(summary_data)
            summary_csv_path = os.path.join(output_dir, 'mmst_summary_statistics.csv')
            summary_df.to_csv(summary_csv_path, index=False)
            print(f"\nSaved summary statistics to: {summary_csv_path}")

            print("\n===== SUMMARY STATISTICS =====")
            for row in summary_data:
                print(f"\n{row['Experiment']}:")
                print(f"Total: {row['Total']}")
                print(f"Top 1: {row['Top 1']} ({row['Top 1 %']:.1f}%)")
                print(f"Top 3: {row['Top 3']} ({row['Top 3 %']:.1f}%)")
                print(f"Top 5: {row['Top 5']} ({row['Top 5 %']:.1f}%)")
                print(f"Top 10: {row['Top 10']} ({row['Top 10 %']:.1f}%)")
    except Exception as e:
        print(f"Error saving summary statistics: {e}")
        

# In a Jupyter kernel __name__ is "__main__", so this guard is always true and
# the V1.2 main() defined in this cell runs whenever the cell is executed.
if __name__ == "__main__":
    main()

In [None]:
pwd

In [None]:
def find_unranked_molecules(json_dir, reference_csv):
    """
    Identify samples whose correct molecule did not receive a rank.

    A sample counts as unranked when either:
    - its JSON file yields no candidate molecules, or
    - the reference SMILES is absent from its own candidate list.
    These are the two cases in which ``analyze_directory`` drops a sample.

    The decision is made per file. Previously a sample was silently excluded
    whenever its true SMILES had appeared as a candidate in any earlier file,
    which made the result depend on the order in which ``glob`` returned files.

    Also prints the reference SMILES that never appear as a candidate in any
    processed file, together with their sample IDs.

    Args:
        json_dir: Directory containing JSON files
        reference_csv: Path to reference CSV file (columns ``sample-id``, ``SMILES``)

    Returns:
        List of dicts with these keys:
        - ``sample_id``: the base sample ID;
        - ``smiles``: the reference SMILES;
        - ``reason``: ``"no_candidates"`` or ``"not_in_candidates"``.
    """
    reference_data = load_reference_data(reference_csv)

    # Sorted for reproducible processing and output order
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))

    # Union of all candidate SMILES seen in any file (used for the global report below)
    processed_smiles = set()
    unranked_molecules = []

    for file_path in json_files:
        try:
            with open(file_path, 'r') as f:
                json_data = json.load(f)

            molecule_data = json_data.get('molecule_data') or {}
            sample_id = molecule_data.get('sample_id')
            if not sample_id:
                continue

            base_sample_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            molecules = process_single_json(json_data)
            candidate_smiles = {mol['smiles'] for mol in molecules}
            processed_smiles.update(candidate_smiles)

            if not molecules:
                reason = 'no_candidates'
            elif true_smiles not in candidate_smiles:
                reason = 'not_in_candidates'
            else:
                continue  # ranked: nothing to report

            unranked_molecules.append({
                'sample_id': base_sample_id,
                'smiles': true_smiles,
                'reason': reason
            })

        except Exception as e:
            # Best-effort batch: report the bad file and keep going.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    # Reference molecules that never showed up as a candidate anywhere.
    # Note: this covers every CSV row, including samples with no JSON file in json_dir.
    missing_total = set(reference_data.values()) - processed_smiles
    if missing_total:
        # Build the SMILES -> sample IDs index once instead of rescanning per SMILES
        ids_by_smiles = {}
        for ref_id, ref_smiles in reference_data.items():
            ids_by_smiles.setdefault(ref_smiles, []).append(ref_id)

        print("\nMolecules missing from ALL processed files:")
        for smiles in sorted(missing_total, key=str):
            print(f"SMILES: {smiles}, Sample IDs: {ids_by_smiles.get(smiles, [])}")

    return unranked_molecules

def main_find_unranked():
    """
    Run ``find_unranked_molecules`` on each configured experiment and print
    every sample whose correct molecule was not ranked.
    """
    sim_reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    experiments = {
        "Simulated Data with Wrong Guess": {
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
            "reference_csv": sim_reference_csv,
        }
    }

    for label, config in experiments.items():
        print(f"\nAnalyzing {label}...")
        missing = find_unranked_molecules(config["json_directory"], config["reference_csv"])

        if not missing:
            print("No unranked molecules found.")
            continue

        print("\nUnranked Molecules:")
        for entry in missing:
            print(f"Sample ID: {entry['sample_id']}, SMILES: {entry['smiles']}")

# In a Jupyter kernel __name__ is "__main__", so this guard is always true and
# the unranked-molecule report runs whenever this cell is executed.
if __name__ == "__main__":
    main_find_unranked()

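### V1.3 (new) — same ranking pipeline, with histograms drawn on a fixed y-axis scale for side-by-side comparison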

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os

def load_reference_data(csv_path):
    """
    Read the reference CSV and map each 'sample-id' to its 'SMILES' string.

    When a sample-id occurs more than once, the last row wins.
    """
    reference_frame = pd.read_csv(csv_path)
    id_column = reference_frame['sample-id']
    smiles_column = reference_frame['SMILES']
    return dict(zip(id_column, smiles_column))

def get_base_sample_id(sample_id):
    """Return the part of ``sample_id`` before the first underscore ('' for empty/None)."""
    if not sample_id:
        return ''
    head, _sep, _tail = sample_id.partition('_')
    return head

def process_single_json(json_data):
    """
    Collect candidate molecules from one result JSON and order them by HSQC score.

    Candidates are read from ``json_data["molecule_data"]["candidate_analysis"]``
    under the keys 'forward_synthesis', 'mol2mol' and 'mmst' (in that order).
    Each candidate becomes ``{'smiles': ..., 'hsqc_score': ...}`` with the score
    taken from ``nmr_analysis.matching_scores.by_spectrum.HSQC``.

    Sorting is ascending by score. Candidates without a usable score (missing,
    null, or NaN) go last. Python's sort is stable, so ties keep the source
    order listed above.

    Missing or null sections at any nesting level are treated as empty instead
    of raising. Previously a single candidate with ``"nmr_analysis": null``
    raised AttributeError and caused the caller to discard the whole file.

    Args:
        json_data: Parsed JSON object for one sample.

    Returns:
        List of candidate dicts sorted as described; empty if there are none.
    """
    molecule_data = json_data.get('molecule_data') if isinstance(json_data, dict) else None
    candidate_analysis = molecule_data.get('candidate_analysis') if isinstance(molecule_data, dict) else None
    if not isinstance(candidate_analysis, dict):
        return []

    def _nested_get(obj, *keys):
        # Walk nested dicts, returning None as soon as a level is missing or not a dict.
        for key in keys:
            if not isinstance(obj, dict):
                return None
            obj = obj.get(key)
        return obj

    all_molecules = []
    for analysis_type in ('forward_synthesis', 'mol2mol', 'mmst'):
        section = candidate_analysis.get(analysis_type)
        if not isinstance(section, dict):
            continue
        for mol in section.get('molecules') or []:
            # Skip malformed entries rather than failing the whole file.
            if not isinstance(mol, dict) or 'smiles' not in mol:
                continue
            all_molecules.append({
                'smiles': mol['smiles'],
                'hsqc_score': _nested_get(mol, 'nmr_analysis', 'matching_scores', 'by_spectrum', 'HSQC'),
            })

    def _sort_key(entry):
        score = entry['hsqc_score']
        # None and NaN (score != score) sort last; NaN would otherwise scramble the ordering.
        if score is None or score != score:
            return float('inf')
        return score

    all_molecules.sort(key=_sort_key)
    return all_molecules

def find_molecule_rank(molecules, true_smiles):
    """Return the 1-based position of ``true_smiles`` in ``molecules``.

    Matching is exact string equality on each entry's 'smiles' value, and the
    first matching entry wins. Returns None when no entry matches.
    """
    matches = (position for position, entry in enumerate(molecules, start=1)
               if entry['smiles'] == true_smiles)
    return next(matches, None)

def analyze_directory(json_dir, reference_csv):
    """
    Rank the correct molecule for every sample JSON in ``json_dir``.

    For each ``*.json`` file, the sample's base ID (text before the first "_")
    is looked up in the reference CSV. The candidates are then ordered by
    ``process_single_json``, and the 1-based position of the reference SMILES
    is recorded.

    Files are processed in sorted filename order, so console output and the
    order of the returned list are reproducible across runs and filesystems.

    Args:
        json_dir: Directory containing the per-sample JSON result files.
        reference_csv: CSV with ``sample-id`` and ``SMILES`` columns.

    Returns:
        List of integer ranks, one per sample whose correct molecule was found
        among its candidates. The following samples are skipped:
        - samples with no sample_id or no reference SMILES;
        - samples with no candidates, or whose correct molecule is not among
          them (the last two cases are counted in a printed summary).
        Unreadable files are reported and skipped.
    """
    reference_data = load_reference_data(reference_csv)

    # glob returns files in arbitrary order; sort for reproducibility.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    rankings = []
    skipped = {'no_candidates': 0, 'not_in_candidates': 0}

    for file_path in json_files:
        try:
            with open(file_path, 'r') as f:
                json_data = json.load(f)

            # `or {}` also covers an explicit null "molecule_data".
            molecule_data = json_data.get('molecule_data') or {}
            sample_id = molecule_data.get('sample_id')
            if not sample_id:
                continue

            # Reference rows are keyed by the base sample ID
            base_sample_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            molecules = process_single_json(json_data)
            if not molecules:
                skipped['no_candidates'] += 1
                continue

            rank = find_molecule_rank(molecules, true_smiles)
            if rank is None:
                skipped['not_in_candidates'] += 1
            else:
                rankings.append(rank)

        except Exception as e:
            # Best-effort batch: report the bad file and keep going.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    # Make the silently-excluded samples visible; they are not part of the returned denominator.
    if any(skipped.values()):
        print(f"Skipped {skipped['no_candidates']} samples with no candidates and "
              f"{skipped['not_in_candidates']} samples whose correct molecule was not among the candidates")

    return rankings
def plot_ranking_histogram(rankings, experiment_label="", max_rank=5, color="#4169E1", figsize=(5.5, 5.5), y_max=35):
    """
    Create histogram of molecule rankings with an extra bin for ranks beyond max_rank.

    Ranks 1..max_rank each get their own bar. Every rank above max_rank is
    pooled into a final bar labelled "<max_rank + 1>+" (e.g. "6+" for the
    default max_rank=5). The y-axis uses a fixed upper limit (``y_max``), so
    figures from different experiments share a scale. The limit grows only
    when a bar and its count label would otherwise be clipped. A text box
    summarises the top-1, top-3 and top-max_rank hit rates, and the per-rank
    counts are printed to stdout.

    Args:
        rankings: List of 1-based rankings of the correct molecule (may be empty)
        experiment_label: Label for the experiment to be shown in title
        max_rank: Maximum individual rank to show in histogram (default: 5)
        color: Color for histogram bars
        figsize: Size of the figure (default: (5.5, 5.5) for square plot)
        y_max: Minimum upper y-axis limit shared across figures (default: 35)

    Returns:
        The matplotlib Figure containing the plot.
    """
    overflow_label = f"{max_rank + 1}+"
    total_molecules = len(rankings)

    def _pct(count):
        # Guard against an empty rankings list (previously ZeroDivisionError).
        return count / total_molecules * 100 if total_molecules else 0.0

    # Create figure with white background
    fig, ax = plt.subplots(figsize=figsize, facecolor='white')
    ax.set_facecolor('white')

    # Bar heights: one per rank 1..max_rank, then the pooled overflow bin.
    positions = list(range(1, max_rank + 2))
    counts = [sum(1 for rank in rankings if rank == r) for r in range(1, max_rank + 1)]
    counts.append(sum(1 for rank in rankings if rank > max_rank))

    bars = ax.bar(
        positions,
        counts,
        width=0.8,
        edgecolor='black',
        alpha=0.7,
        color=color
    )

    # Two-line title: fixed heading plus the experiment label
    title_line1 = 'Distribution of Correct Molecule Rankings'
    title_line2 = experiment_label if experiment_label else ""
    ax.set_title(f"{title_line1}\n{title_line2}", fontsize=20, pad=10)
    ax.set_xlabel('Rank', fontsize=16)
    ax.set_ylabel('Number of Molecules', fontsize=16)

    # Tick labels follow max_rank instead of a hard-coded "6+"
    ax.set_xticks(positions)
    ax.set_xticklabels([str(i) for i in range(1, max_rank + 1)] + [overflow_label], fontsize=14)
    ax.tick_params(axis='y', labelsize=14)

    # Add grid with light gray color
    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

    # Shared fixed scale, expanded only if the tallest bar plus its label would be clipped
    # (a hard limit of 35 previously cut off any bar of 35 or more).
    ax.set_ylim(0, max(y_max, max(counts) * 1.1 + 1))

    # Count labels above non-empty bars, with a fixed offset for consistent positioning
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height + 0.7,
                f'{int(height)}',
                ha='center',
                va='bottom',
                fontsize=14
            )

    # Summary statistics (the top-N cutoff follows max_rank)
    in_top_1 = sum(1 for r in rankings if r == 1)
    in_top_3 = sum(1 for r in rankings if r <= 3)
    in_top_max = sum(1 for r in rankings if r <= max_rank)
    after_top_max = sum(1 for r in rankings if r > max_rank)

    stats_text = (
        f'Total molecules found: {total_molecules}\n'
        f'Found in top 1: {in_top_1} ({_pct(in_top_1):.1f}%)\n'
        f'Found in top 3: {in_top_3} ({_pct(in_top_3):.1f}%)\n'
        f'Found in top {max_rank}: {in_top_max} ({_pct(in_top_max):.1f}%)\n'
        f'Found after top {max_rank}: {after_top_max} ({_pct(after_top_max):.1f}%)'
    )

    ax.text(0.95, 0.95, stats_text,
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            fontsize=14,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Lay out this figure explicitly rather than whatever pyplot considers "current"
    fig.tight_layout()

    # Print detailed distribution
    print(f"\nDetailed rank distribution for {experiment_label if experiment_label else 'experiment'}:")
    if rankings:
        for rank, count in pd.Series(rankings).value_counts().sort_index().items():
            print(f"Rank {rank}: {count} molecules")

    return fig

def main():
    """Evaluate MMST candidate rankings for every configured experiment run.

    For each experiment this
      1. collects the rank of the correct molecule per sample via ``analyze_directory``;
      2. draws a rank histogram via ``plot_ranking_histogram`` (``max_rank=5``),
         saves it as a 300-dpi PNG in ``output_dir`` and shows it;
      3. records Top-1/3/5/10 counts and percentages.
    The collected statistics are written to ``mmst_summary_statistics.csv`` in
    ``output_dir`` and printed.
    """
    import os

    # Define output directory for figures
    output_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures"
    os.makedirs(output_dir, exist_ok=True)

    # Common path prefixes; the runs below only differ in their directory name.
    results_root = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results"
    sim_reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    lukas_reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/53_Lukas_real_data/cleaned_data_CLEAN.csv"

    def run_paths(run_dir, reference_csv):
        """Build the {json_directory, reference_csv} entry for one run directory."""
        return {
            "json_directory": os.path.join(results_root, run_dir),
            "reference_csv": reference_csv,
        }

    # Dict keys are unique: a repeated label silently replaces the earlier entry,
    # so every run needs its own label (and its own entry in `colors`).
    experiments = {
        # Disabled runs (uncomment to include; add a color for new labels):
        # "ACD Data HSQC with Wrong Guess": run_paths("_run_8_ACD_d1_MMST_HSQC", sim_reference_csv),
        # "Experimental Data ": run_paths("_run_6_exp_d4_finished_clean", sim_reference_csv),
        # _run_5 shared the label of _run_10 below and was therefore never analyzed:
        # "Experimental Data ALL with Wrong Guess (run 5)": run_paths("_run_5_exp_d1_aug_finished_clean", sim_reference_csv),
        "Additional Data HSQC aug": run_paths("_run_15_Lukas_aug_finished", lukas_reference_csv),
        "Additional Data HSQC": run_paths("_run_14_Lukas_target_finished", lukas_reference_csv),
        "Simulated Data ALL with Wrong Guess": run_paths("_run_2_sim_aug_finished_clean", sim_reference_csv),
        "Simulated Data ALL with Noise": run_paths("_run_3_sim+noise_finished_clean", sim_reference_csv),
        "Simulated Data ALL": run_paths("_run_1_sim_finished_clean", sim_reference_csv),
        "Experimental Data HSQC with Wrong Guess": run_paths("_run_12_exp_d1_aug_MMST_HSQC_finished", sim_reference_csv),
        "Experimental Data HSQC": run_paths("_run_13_exp_d1_MMST_HSQC_finished", sim_reference_csv),
        "Experimental Data ALL with Wrong Guess": run_paths("_run_10_exp_d1_aug_MMST_all_new", sim_reference_csv),
        "Experimental Data ALL": run_paths("_run_4_exp_d1_finished_clean", sim_reference_csv),
    }

    # Bar color per experiment; labels missing here fall back to `default_color`.
    default_color = "#4169E1"
    colors = {
        "Additional Data HSQC aug": "#14B8A6",                # Teal-500
        "Additional Data HSQC": "#0EA5E9",                    # Sky-500
        "Simulated Data ALL with Wrong Guess": "#3B82F6",
        "Simulated Data ALL with Noise": "#10B981",
        "Simulated Data ALL": "#6366F1",
        "Experimental Data HSQC with Wrong Guess": "#F97316", # Orange-500
        "Experimental Data HSQC": "#84CC16",                  # Lime-500
        "Experimental Data ALL with Wrong Guess": "#EC4899",
        "Experimental Data ALL": "#8B5CF6",
    }

    top_keys = ('Top 1', 'Top 3', 'Top 5', 'Top 10')
    summary_data = []

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")

        rankings = analyze_directory(paths["json_directory"], paths["reference_csv"])
        if not rankings:
            print(f"No valid rankings found for {exp_label}. Please check your input files.")
            continue

        fig = plot_ranking_histogram(
            rankings,
            experiment_label=exp_label,
            max_rank=5,  # Show individual ranks 1-5
            color=colors.get(exp_label, default_color),
            figsize=(5.5, 5.5),  # Smaller square figure
        )

        # Save before showing: some backends block in or tear the figure down after plt.show().
        output_path = os.path.join(output_dir, f'ranking_histogram_{exp_label.replace(" ", "_")}_MMST.png')
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across experiments

        total_molecules = len(rankings)
        hits = {
            'Top 1': sum(1 for r in rankings if r == 1),
            'Top 3': sum(1 for r in rankings if r <= 3),
            'Top 5': sum(1 for r in rankings if r <= 5),
            'Top 10': sum(1 for r in rankings if r <= 10),
        }
        row = {'Experiment': exp_label, 'Tool': 'MMST', 'Total': total_molecules}
        for key, count in hits.items():
            row[key] = count
            row[f'{key} %'] = count / total_molecules * 100 if total_molecules > 0 else 0
        summary_data.append(row)

    # Save summary statistics to CSV (best-effort: failures are reported, not raised)
    try:
        if summary_data:
            summary_df = pd.DataFrame(summary_data)
            summary_csv_path = os.path.join(output_dir, 'mmst_summary_statistics.csv')
            summary_df.to_csv(summary_csv_path, index=False)
            print(f"\nSaved summary statistics to: {summary_csv_path}")

            print("\n===== SUMMARY STATISTICS =====")
            for row in summary_data:
                print(f"\n{row['Experiment']}:")
                print(f"Total: {row['Total']}")
                for key in top_keys:
                    share = row[f'{key} %']
                    print(f"{key}: {row[key]} ({share:.1f}%)")
    except Exception as e:
        print(f"Error saving summary statistics: {e}")


if __name__ == "__main__":
    main()

In [None]:
# Show the kernel's current working directory (bare `pwd` relies on IPython automagic, i.e. `%pwd`).
pwd

### V2 - OLD


In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os

def load_reference_data(csv_path):
    """Read the reference CSV and return a ``{sample-id: SMILES}`` dictionary.

    Repeated sample IDs keep the SMILES of their last row.
    """
    reference_frame = pd.read_csv(csv_path)
    return dict(zip(reference_frame['sample-id'], reference_frame['SMILES']))

def get_base_sample_id(sample_id):
    """Return the part of ``sample_id`` before the first underscore ('' for empty/None)."""
    if not sample_id:
        return ''
    head, _separator, _suffix = sample_id.partition('_')
    return head

def process_single_json(json_data):
    """Collect all candidate molecules of one result file, ordered by HSQC score.

    Candidates are gathered from the ``forward_synthesis``, ``mol2mol`` and
    ``mmst`` sections of ``molecule_data.candidate_analysis`` (in that order).
    Entries without a ``smiles`` key are skipped. The list is sorted ascending by
    HSQC matching score (presumably lower = better match; TODO confirm), with
    candidates lacking a score placed last.

    Returns:
        List of ``{'smiles': str, 'hsqc_score': float | None}`` dicts, or ``[]``
        when the file has no candidate analysis.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except KeyError:
        return []

    def _hsqc_score(mol):
        # Walk molecule -> nmr_analysis -> matching_scores -> by_spectrum -> HSQC.
        return mol.get('nmr_analysis', {}).get('matching_scores', {}).get('by_spectrum', {}).get('HSQC', None)

    collected = []
    for source in ('forward_synthesis', 'mol2mol', 'mmst'):
        if source not in candidate_analysis:
            continue
        for mol in candidate_analysis[source].get('molecules', []):
            try:
                collected.append({'smiles': mol['smiles'], 'hsqc_score': _hsqc_score(mol)})
            except KeyError:
                continue

    collected.sort(key=lambda entry: float('inf') if entry['hsqc_score'] is None else entry['hsqc_score'])
    return collected

def find_molecule_rank(molecules, true_smiles):
    """Return the 1-based position of the first molecule whose SMILES equals ``true_smiles``.

    Comparison is an exact string match; ``None`` is returned when there is no match.
    """
    matches = (
        position
        for position, candidate in enumerate(molecules, start=1)
        if candidate['smiles'] == true_smiles
    )
    return next(matches, None)

def analyze_directory(json_dir, reference_csv):
    """Return the rank of the true molecule for every ``*.json`` file in ``json_dir``.

    Files without a sample ID, without candidates, or where the true SMILES is not
    among the candidates contribute nothing. Samples missing from the reference CSV
    are reported. Any error in a file is printed and the file is skipped.
    """
    reference_data = load_reference_data(reference_csv)
    json_files = glob.glob(os.path.join(json_dir, "*.json"))
    print(f"Found {len(json_files)} JSON files to analyze")

    rankings = []
    for path in json_files:
        try:
            with open(path, 'r') as handle:
                payload = json.load(handle)

            sample_id = payload.get('molecule_data', {}).get('sample_id')
            if not sample_id:
                continue

            base_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_id}")
                continue

            candidates = process_single_json(payload)
            rank = find_molecule_rank(candidates, true_smiles) if candidates else None
            if rank is not None:
                rankings.append(rank)
        except Exception as err:
            # Best-effort: report the broken file and keep going.
            print(f"Error processing {path}: {str(err)}")

    return rankings

def plot_ranking_histogram(rankings, experiment_label="", max_rank=5, figsize=(5, 5), color='#4169E1'):
    """
    Create histogram of molecule rankings.

    Only ranks 1..max_rank are drawn (one bar per rank), and the printed
    rank distribution is limited to the same range.

    Args:
        rankings: List of rankings for correct molecules (1 = best)
        experiment_label: Label for the experiment to be shown in title
        max_rank: Maximum rank to show in histogram (default: 5)
        figsize: Size of the figure (default: (5, 5))
        color: Bar fill color (default: '#4169E1', royal blue)

    Returns:
        The matplotlib Figure containing the histogram.
    """
    # Edges at k - 0.5 give every integer rank 1..max_rank its own centred bar.
    bins = np.arange(1, max_rank + 2) - 0.5

    fig, ax = plt.subplots(figsize=figsize, facecolor='white')
    ax.set_facecolor('white')

    counts, _, _ = ax.hist(
        [r for r in rankings if r <= max_rank],
        bins=bins,
        edgecolor='black',
        alpha=0.7,
        color=color,
    )

    title = 'Distribution of Correct Molecule Rankings'
    if experiment_label:
        title += f'\n{experiment_label}'
    ax.set_title(title, fontsize=16, pad=20)
    ax.set_xlabel('Rank', fontsize=14)
    ax.set_ylabel('Number of Molecules', fontsize=14)

    # Set x-axis ticks with larger font
    ax.set_xticks(range(1, max_rank + 1))
    ax.tick_params(axis='both', which='major', labelsize=12)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--')

    # Count label above each non-empty bar (bar i is centred on rank i + 1).
    for rank, count in enumerate(counts, start=1):
        if count > 0:
            ax.text(rank, count, f'{int(count)}', ha='center', va='bottom', fontsize=12)

    plt.tight_layout()

    # Print detailed distribution
    print(f"\nDetailed rank distribution for {experiment_label if experiment_label else 'experiment'}:")
    rank_counts = pd.Series(rankings).value_counts().sort_index()
    for rank, count in rank_counts.items():
        if rank <= max_rank:
            print(f"Rank {rank}: {count} molecules")

    return fig

def main():
    """Plot the rank of the correct molecule for each configured run.

    Each experiment label maps to a directory of intermediate-result JSON files
    plus the reference CSV with the true SMILES. Runs that yield no rankings are
    reported and skipped; all others get a histogram shown inline.
    """
    results_root = ("/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
                    "LLM_Structure_Elucidator/_temp_folder/intermediate_results")
    reference_csv = ("/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
                     "52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv")
    run_directories = (
        ("Simulated Data", "_run_1_sim_finished_clean"),
        ("Simulated Data with Wrong Guess", "_run_2_sim_aug_finished_clean"),
        ("Simulated Data with Noise", "_run_3_sim+noise_finished_clean"),
        ("Experimental Data", "_run_4_exp_d1_finished_clean"),
        ("Experimental Data with Wrong Guess", "_run_5_exp_d1_aug_finished_clean"),
        ("Experimental Data d4", "_run_6_exp_d4_finished_clean"),
    )
    experiments = {
        label: {"json_directory": f"{results_root}/{run_dir}", "reference_csv": reference_csv}
        for label, run_dir in run_directories
    }

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")
        rankings = analyze_directory(paths["json_directory"], paths["reference_csv"])
        if rankings:
            plot_ranking_histogram(rankings, experiment_label=exp_label)
            plt.show()
        else:
            print(f"No valid rankings found for {exp_label}. Please check your input files.")

if __name__ == "__main__":
    main()

## Compare with LLM-Ranked Results

In [None]:
import json
import pandas as pd
from pathlib import Path
from typing import Dict, List, Any

def analyze_o3_predictions(json_file: str, reference_csv: str, llm_name: str = "o3") -> Dict[str, Any]:
    """
    Analyze an LLM's (by default O3's) molecular predictions against ground truth.

    Args:
        json_file: Path to JSON file with the LLM's analysis
        reference_csv: Path to CSV file with ground truth data
            (``sample-id`` and ``SMILES`` columns)
        llm_name: Key under ``analysis_results.final_analysis.llm_responses``
            whose ``parsed_results.candidates`` are evaluated (default ``"o3"``)

    Returns:
        Dictionary with analysis results: the top prediction, whether it is
        correct, the 1-based position of the correct molecule (None if absent),
        and every candidate sorted by descending confidence.

    Raises:
        ValueError: if there is no reference SMILES for the sample, or the model
            returned no candidates.
        KeyError: if the model's candidate list is missing from the JSON, or a
            candidate lacks ``smiles``/``confidence_score``.
    """
    # Load reference data
    ref_df = pd.read_csv(reference_csv)
    ref_dict = ref_df.set_index('sample-id')['SMILES'].to_dict()

    with open(json_file, 'r') as f:
        data = json.load(f)

    # Get sample ID and true SMILES
    sample_id = data["molecule_data"]["sample_id"]
    base_sample_id = sample_id.split('_')[0]  # Remove suffix after underscore
    true_smiles = ref_dict.get(base_sample_id)

    if true_smiles is None:
        raise ValueError(f"No reference SMILES found for sample_id {base_sample_id}")

    # Only the lookup of the candidate list is reported as "results not found";
    # problems inside individual candidates surface with their own KeyError.
    try:
        llm_results = data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]
        candidates = llm_results["candidates"]
    except KeyError as e:
        raise KeyError(
            f"Could not find {llm_name.upper()} results in expected JSON structure: {e}"
        ) from e

    if not candidates:
        raise ValueError(f"No {llm_name} candidates found for sample_id {sample_id}")

    sorted_candidates = sorted(candidates, key=lambda x: x["confidence_score"], reverse=True)

    correct_position = next(
        (i for i, cand in enumerate(sorted_candidates, 1) if cand["smiles"] == true_smiles),
        None,
    )

    top = sorted_candidates[0]
    return {
        "sample_id": sample_id,
        "true_smiles": true_smiles,
        "top_prediction": {
            "smiles": top["smiles"],
            "confidence": top["confidence_score"],
        },
        "is_top_correct": top["smiles"] == true_smiles,
        "correct_molecule_position": correct_position,
        "total_candidates": len(sorted_candidates),
        "all_predictions": [
            {
                "position": i,
                "smiles": cand["smiles"],
                "confidence": cand["confidence_score"],
                "is_correct": cand["smiles"] == true_smiles,
            }
            for i, cand in enumerate(sorted_candidates, 1)
        ],
    }

def print_analysis_results(results: Dict[str, Any]):
    """Pretty-print one prediction-analysis result dictionary to stdout.

    Shows the ground-truth SMILES, the top prediction, where (if anywhere) the
    correct molecule was ranked, and every candidate with a check mark on the
    correct one.
    """
    divider = "-" * 50
    print(f"\nAnalysis Results for {results['sample_id']}:")
    print(divider)
    print(f"True SMILES: {results['true_smiles']}")

    print("\nTop Prediction:")
    best = results['top_prediction']
    print(f"SMILES: {best['smiles']}")
    print(f"Confidence: {best['confidence']:.2f}")
    print(f"Is Correct: {results['is_top_correct']}")

    rank = results['correct_molecule_position']
    print(
        f"\nCorrect molecule found at position: {rank}"
        if rank
        else "\nCorrect molecule not found in candidates"
    )

    print(f"\nAll predictions (total: {results['total_candidates']}):")
    for entry in results['all_predictions']:
        tick = "✓" if entry['is_correct'] else " "
        print(f"{tick} {entry['position']}. Confidence: {entry['confidence']:.2f} - SMILES: {entry['smiles']}")

if __name__ == "__main__":
    # Single-file check: rank O3's candidates for one intermediate-results JSON,
    # print a summary and save it next to the input file.
    # `json_file` / `reference_csv` become notebook-level globals (this guard is
    # true inside a notebook cell); a later cell reuses `json_file`.
    # Replace with your file paths

    json_file= "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_aug_2.0_finished_clean/AZ10930573_aug_intermediate.json"
    reference_csv= "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    
    try:
        results = analyze_o3_predictions(json_file, reference_csv)
        print_analysis_results(results)
        
        # Optionally save results
        # (written as o3_analysis_results.json in the input JSON's directory)
        output_file = Path(json_file).parent / "o3_analysis_results.json"
        with open(output_file, 'w') as f:
            json.dump(results, f, indent=2)
            
    except Exception as e:
        # Best-effort: any failure (missing file, no reference SMILES, unexpected JSON layout) is reported, not raised.
        print(f"Error analyzing predictions: {str(e)}")

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os
from typing import Dict, List, Any

def load_reference_data(csv_path: str) -> Dict[str, str]:
    """Map every ``sample-id`` in the reference CSV at ``csv_path`` to its ``SMILES``.

    If a sample ID occurs more than once, the last row wins.
    """
    table = pd.read_csv(csv_path)
    return {sample: smiles for sample, smiles in zip(table['sample-id'], table['SMILES'])}

def analyze_single_json(json_file: str, reference_data: Dict[str, str], llm_name: str = "kimi") -> Dict[str, Any]:
    """
    Analyze a single JSON file's LLM predictions against ground truth.

    Candidates are read from
    ``analysis_results.final_analysis.llm_responses[llm_name].parsed_results.candidates``
    (``llm_name`` defaults to ``"kimi"``) and ranked by ``confidence_score``,
    highest first. Candidates whose score is missing or None are ranked after all
    scored candidates, in their original order, so one bad entry no longer drops
    the whole sample.

    Args:
        json_file: Path to JSON file with the LLM's analysis
        reference_data: Dictionary mapping sample IDs to true SMILES
        llm_name: Key of the LLM response to evaluate (default "kimi")

    Returns:
        Dictionary with ``sample_id``, 1-based ``correct_position`` (None if the
        true SMILES is not among the candidates), ``total_candidates`` and the
        ``is_top_1/3/5/10`` flags; None if no reference SMILES exists for the
        sample or the file cannot be processed (the error is printed).
    """
    try:
        with open(json_file, 'r') as f:
            data = json.load(f)

        sample_id = data["molecule_data"]["sample_id"]
        base_sample_id = sample_id.split('_')[0]  # Remove suffix after underscore
        true_smiles = reference_data.get(base_sample_id)

        if true_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_sample_id}")
            return None

        llm_results = data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]
        candidates = llm_results["candidates"]

        scored = [c for c in candidates if c.get("confidence_score") is not None]
        unscored = [c for c in candidates if c.get("confidence_score") is None]
        # sorted() is stable, so equal scores keep their original order.
        sorted_candidates = sorted(scored, key=lambda x: x["confidence_score"], reverse=True) + unscored

        correct_position = next(
            (i for i, cand in enumerate(sorted_candidates, 1) if cand.get("smiles") == true_smiles),
            None,
        )

        return {
            "sample_id": sample_id,
            "correct_position": correct_position,
            "total_candidates": len(sorted_candidates),
            "is_top_1": correct_position == 1,
            "is_top_3": correct_position is not None and correct_position <= 3,
            "is_top_5": correct_position is not None and correct_position <= 5,
            "is_top_10": correct_position is not None and correct_position <= 10,
        }

    except Exception as e:
        # Best-effort per file: report and let the caller skip it.
        print(f"Error processing {json_file}: {str(e)}")
        return None

def analyze_directory(json_dir: str, reference_csv: str) -> List[Dict[str, Any]]:
    """
    Run ``analyze_single_json`` on every ``*.json`` file directly inside ``json_dir``.

    Args:
        json_dir: Directory containing the intermediate-results JSON files
        reference_csv: CSV with ``sample-id`` and ``SMILES`` columns (ground truth)

    Returns:
        The per-file result dicts that were produced (files yielding None are
        dropped), in ``glob`` order.
    """
    reference_data = load_reference_data(reference_csv)

    json_files = glob.glob(os.path.join(json_dir, "*.json"))
    print(f"Found {len(json_files)} JSON files to analyze")

    per_file = (analyze_single_json(path, reference_data) for path in json_files)
    return [outcome for outcome in per_file if outcome]

def plot_ranking_histogram(results: List[Dict[str, Any]], experiment_label: str = "", max_rank: int = 10,
                           model_name: str = "O3", figsize=(12, 6)):
    """
    Create histogram of correct molecule rankings.

    Args:
        results: List of per-sample analysis dicts; each must provide
            ``correct_position`` (1-based rank or None) and the boolean
            ``is_top_1``/``is_top_3``/``is_top_5``/``is_top_10`` flags
        experiment_label: Label for the experiment (appended to the title)
        max_rank: Maximum rank to show in histogram
        model_name: Model name used in the title (default "O3")
        figsize: Figure size passed to ``plt.subplots`` (default (12, 6))

    Returns:
        The matplotlib Figure.
    """
    rankings = [r["correct_position"] for r in results if r["correct_position"] is not None]

    # Edges at k - 0.5 give every integer rank its own centred bar.
    bins = np.arange(1, max_rank + 2) - 0.5

    fig, ax = plt.subplots(figsize=figsize, facecolor='white')
    ax.set_facecolor('white')

    counts, _, _ = ax.hist(
        [r for r in rankings if r <= max_rank],
        bins=bins,
        edgecolor='black',
        alpha=0.7,
        color='#4169E1',
    )

    title = f'Distribution of {model_name} Correct Molecule Rankings'
    if experiment_label:
        title += f' - {experiment_label}'
    ax.set_title(title, fontsize=14, pad=20)
    ax.set_xlabel('Rank', fontsize=12)
    ax.set_ylabel('Number of Molecules', fontsize=12)
    ax.set_xticks(range(1, max_rank + 1))
    ax.grid(True, alpha=0.3, color='gray', linestyle='--')

    # Count label above each non-empty bar (bar i is centred on rank i + 1).
    for rank, count in enumerate(counts, start=1):
        if count > 0:
            ax.text(rank, count, f'{int(count)}', ha='center', va='bottom')

    total_analyzed = len(results)
    total_found = len(rankings)
    in_top = {k: sum(1 for r in results if r[f"is_top_{k}"]) for k in (1, 3, 5, 10)}

    def pct(count):
        # An empty ``results`` list would otherwise divide by zero.
        return count / total_analyzed * 100 if total_analyzed else 0.0

    stats_lines = [f'Total analyzed: {total_analyzed}', f'Total found: {total_found}']
    stats_lines += [f'Top {k}: {n} ({pct(n):.1f}%)' for k, n in in_top.items()]
    stats_text = '\n'.join(stats_lines)

    ax.text(0.95, 0.95, stats_text,
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()

    # Print detailed distribution
    print(f"\nDetailed rank distribution for {experiment_label if experiment_label else 'experiment'}:")
    rank_counts = pd.Series(rankings).value_counts().sort_index()
    for rank, count in rank_counts.items():
        print(f"Rank {rank}: {count} molecules")

    return fig

def main():
    """Analyze each configured run, plot its rank histogram and save per-sample results.

    Per-sample results are written to ``<json_directory>/analysis_results/o3_analysis_results.csv``.
    Runs without any usable result are reported and skipped.
    """
    experiments = {
        "O3 Experiment": {
            "json_directory": (
                "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
                "LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_expd1_5.0_finished_clean"
            ),
            "reference_csv": (
                "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
                "52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
            ),
        },
    }

    for label, cfg in experiments.items():
        print(f"\nAnalyzing {label}...")
        results = analyze_directory(cfg["json_directory"], cfg["reference_csv"])
        if not results:
            print(f"No valid results found for {label}. Please check your input files.")
            continue

        plot_ranking_histogram(results, experiment_label=label)
        plt.show()

        save_dir = Path(cfg["json_directory"]) / "analysis_results"
        save_dir.mkdir(exist_ok=True)
        pd.DataFrame(results).to_csv(save_dir / "o3_analysis_results.csv", index=False)

if __name__ == "__main__":
    main()

In [None]:
def analyze_single_json(json_file: str, reference_data: Dict[str, str], llm_name: str = "kimi",
                        verbose: bool = True) -> Dict[str, Any]:
    """
    Debug version: rank one JSON file's LLM candidates against ground truth, with tracing.

    Candidates are read from
    ``analysis_results.final_analysis.llm_responses[llm_name].parsed_results.candidates``
    (default ``"kimi"``). Those with a missing/None ``confidence_score`` are ranked
    after all scored candidates, in their original order.

    Args:
        json_file: Path to the intermediate-results JSON file
        reference_data: Mapping from base sample ID (text before the first '_') to true SMILES
        llm_name: Which LLM response to evaluate (default "kimi")
        verbose: If True (default), print step-by-step debug output; warnings and
            errors are printed regardless

    Returns:
        Dict with 1-based ``correct_position`` (None if absent), total/valid/invalid
        candidate counts and top-1/3/5/10 flags; None if there is no reference
        SMILES or processing fails (error and traceback are printed).
    """
    log = print if verbose else (lambda *args, **kwargs: None)
    try:
        log(f"\nProcessing file: {json_file}")

        with open(json_file, 'r') as f:
            data = json.load(f)

        sample_id = data["molecule_data"]["sample_id"]
        log(f"Sample ID: {sample_id}")

        base_sample_id = sample_id.split('_')[0]
        log(f"Base Sample ID: {base_sample_id}")

        true_smiles = reference_data.get(base_sample_id)
        log(f"True SMILES found: {true_smiles is not None}")

        if true_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_sample_id}")
            return None

        llm_results = data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]
        candidates = llm_results["candidates"]
        log(f"Number of candidates found: {len(candidates)}")

        log("\nCandidate confidence scores before sorting:")
        for i, cand in enumerate(candidates):
            log(f"Candidate {i}: confidence={cand.get('confidence_score')}, SMILES={cand.get('smiles')}")

        # Split off candidates whose score cannot be sorted.
        valid_candidates = []
        invalid_candidates = []
        for cand in candidates:
            if cand.get("confidence_score") is None:
                print(f"Warning: Invalid confidence score for candidate: {cand}")
                invalid_candidates.append(cand)
            else:
                valid_candidates.append(cand)

        sorted_candidates = sorted(valid_candidates, key=lambda x: x["confidence_score"], reverse=True)
        sorted_candidates.extend(invalid_candidates)  # unscored candidates rank last

        correct_position = next(
            (i for i, cand in enumerate(sorted_candidates, 1) if cand.get("smiles") == true_smiles),
            None,
        )
        if correct_position is None:
            log("Correct SMILES not found in candidates")
        else:
            log(f"Found correct SMILES at position {correct_position}")

        return {
            "sample_id": sample_id,
            "correct_position": correct_position,
            "total_candidates": len(sorted_candidates),
            "total_valid_candidates": len(valid_candidates),
            "total_invalid_candidates": len(invalid_candidates),
            "is_top_1": correct_position == 1,
            "is_top_3": correct_position is not None and correct_position <= 3,
            "is_top_5": correct_position is not None and correct_position <= 5,
            "is_top_10": correct_position is not None and correct_position <= 10,
        }

    except Exception as e:
        print(f"\nDetailed error processing {json_file}:")
        print(f"Error type: {type(e)}")
        print(f"Error message: {str(e)}")
        import traceback
        print("Traceback:")
        print(traceback.format_exc())
        return None

In [None]:
# Debug one file. `reference_data` is never assigned at notebook level in these
# cells (only inside functions), so build it explicitly from `reference_csv`,
# which is set together with `json_file` in the single-file analysis cell.
analyze_single_json(json_file, load_reference_data(reference_csv))

### V1 - OLD

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os
from typing import Dict, List, Any

def load_reference_data(csv_path: str) -> Dict[str, str]:
    """Read the reference CSV and map every ``sample-id`` to its ``SMILES`` string.

    Args:
        csv_path: Path (or file-like object) of a CSV with ``sample-id`` and
            ``SMILES`` columns.

    Returns:
        Dictionary ``{sample_id: smiles}``.
    """
    reference_table = pd.read_csv(csv_path)
    smiles_by_id = reference_table.set_index('sample-id')['SMILES']
    return smiles_by_id.to_dict()

def analyze_llm_predictions(json_data: Dict, true_smiles: str, llm_name: str) -> Dict[str, Any]:
    """
    Analyze predictions from a specific LLM model.

    Candidates are ranked by descending ``confidence_score``; equal scores keep
    their original order. Candidates whose score is missing, non-numeric or NaN
    are kept but ranked after every scored candidate. Previously such a score
    made ``sorted`` raise ``TypeError``, which escaped the ``KeyError`` handler
    and caused the caller to discard the results of *all* LLMs for the file.

    Args:
        json_data: Loaded JSON data
        true_smiles: True SMILES string to compare against (exact string
            comparison; no canonicalization is applied)
        llm_name: Name of the LLM model to analyze

    Returns:
        Dictionary with analysis results, or None if this LLM has no parsed
        candidate list in the file.
    """
    import math

    try:
        llm_results = json_data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]
        candidates = llm_results["candidates"]
    except (KeyError, TypeError):
        # This LLM might not have results in this file (or a section is null)
        return None
    if not isinstance(candidates, list):
        return None

    def _confidence(cand):
        """Candidate confidence as a float, or None when it cannot be used for ranking."""
        if not isinstance(cand, dict):
            return None
        try:
            value = float(cand.get("confidence_score"))
        except (TypeError, ValueError):
            return None
        return None if math.isnan(value) else value

    scored, unscored = [], []
    for cand in candidates:
        score = _confidence(cand)
        if score is None:
            unscored.append(cand)
        else:
            scored.append((score, cand))
    # list.sort is stable even with reverse=True, so ties keep input order
    scored.sort(key=lambda pair: pair[0], reverse=True)
    ranked = [cand for _, cand in scored] + unscored

    # 1-based position of the first candidate whose SMILES matches exactly
    correct_position = next(
        (i for i, cand in enumerate(ranked, 1)
         if isinstance(cand, dict) and cand.get("smiles") == true_smiles),
        None,
    )

    return {
        "llm_model": llm_name,
        "correct_position": correct_position,
        "total_candidates": len(ranked),
        "is_top_1": correct_position == 1,
        "is_top_3": correct_position is not None and correct_position <= 3,
        "is_top_5": correct_position is not None and correct_position <= 5,
        "is_top_10": correct_position is not None and correct_position <= 10
    }

def analyze_single_json(json_file: str, reference_data: Dict[str, str]) -> List[Dict[str, Any]]:
    """
    Evaluate every known LLM's candidate list stored in one result JSON file.

    Args:
        json_file: Path to JSON file
        reference_data: Dictionary mapping base sample IDs (text before the
            first underscore) to true SMILES

    Returns:
        One result dict per LLM that produced an analysis, each tagged with the
        file's full ``sample_id``. Empty when no reference SMILES exists for
        the sample or when any error occurs (the error is printed).
    """
    try:
        with open(json_file, 'r') as handle:
            payload = json.load(handle)

        full_sample_id = payload["molecule_data"]["sample_id"]
        base_sample_id = full_sample_id.split('_')[0]
        target_smiles = reference_data.get(base_sample_id)
        if target_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_sample_id}")
            return []

        collected = []
        for model in ("claude", "claude3-7", "o3", "kimi", "gemini", "deepseek"):
            outcome = analyze_llm_predictions(payload, target_smiles, model)
            if not outcome:
                continue
            outcome["sample_id"] = full_sample_id
            collected.append(outcome)
        return collected

    except Exception as e:
        print(f"Error processing {json_file}: {str(e)}")
        return []

def analyze_directory(json_dir: str, reference_csv: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Analyze all JSON files in a directory for all LLM models.

    Files are processed in sorted path order. ``glob.glob`` returns matches in
    arbitrary, filesystem-dependent order, so without sorting the per-LLM
    result lists (and the CSVs written from them) differ between runs.

    Args:
        json_dir: Directory containing the per-sample ``*.json`` result files.
        reference_csv: CSV with ``sample-id`` and ``SMILES`` columns.

    Returns:
        Mapping from LLM name to the list of per-sample result dicts.
    """
    reference_data = load_reference_data(reference_csv)
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    # Initialize results dictionary for each LLM
    results_by_llm = {
        "claude": [], "claude3-7": [], "o3": [], "kimi": [], "gemini": [], "deepseek": []
    }

    for file_path in json_files:
        for result in analyze_single_json(file_path, reference_data):
            # setdefault keeps an unexpected model name from raising KeyError
            results_by_llm.setdefault(result["llm_model"], []).append(result)

    return results_by_llm

def plot_ranking_histograms(results_by_llm: Dict[str, List[Dict[str, Any]]], 
                          experiment_label: str = "", 
                          max_rank: int = 10):
    """
    Create one rank histogram, with a top-k statistics box, per LLM model.

    Only models that have at least one result get a panel. Previously a model
    with no results left a blank, framed axes in the middle of the grid,
    because only the trailing, never-used axes were removed.

    Args:
        results_by_llm: Mapping from LLM name to its per-sample result dicts
            (the ``correct_position`` and ``is_top_1/3/5`` keys are read).
        experiment_label: Optional second title line for every panel.
        max_rank: Highest rank plotted. Correct molecules ranked beyond this
            still count towards "Found" but get no bar.

    Returns:
        The matplotlib Figure holding all panels.
    """
    populated = {name: res for name, res in results_by_llm.items() if res}

    n_cols = 2
    # At least one row so plt.subplots never receives nrows=0
    n_rows = max(1, (len(populated) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows),
                             facecolor='white', squeeze=False)
    axes = axes.flatten()

    # Color scheme for different LLMs
    colors = {
        "claude": "#4169E1",     # Royal Blue
        "claude3-7": "#1E90FF",  # Dodger Blue, distinct from "claude"
        "o3": "#2E8B57",         # Sea Green
        "kimi": "#8B4513",       # Saddle Brown
        "gemini": "#4B0082",     # Indigo
        "deepseek": "#CD853F"    # Peru
    }

    # Bin edges centred on the integer ranks 1..max_rank
    edges = np.arange(1, max_rank + 2) - 0.5

    for ax, (llm_name, results) in zip(axes, populated.items()):
        ax.set_facecolor('white')

        rankings = [r["correct_position"] for r in results if r["correct_position"] is not None]

        n, _, _ = ax.hist(
            [r for r in rankings if r <= max_rank],
            bins=edges,
            edgecolor='black',
            alpha=0.7,
            color=colors.get(llm_name, '#4169E1')
        )

        title = f'{llm_name.upper()} Model Rankings'
        if experiment_label:
            title += f'\n{experiment_label}'
        ax.set_title(title, fontsize=12, pad=20)
        ax.set_xlabel('Rank', fontsize=10)
        ax.set_ylabel('Number of Molecules', fontsize=10)
        ax.set_xticks(range(1, max_rank + 1))
        ax.grid(True, alpha=0.3, color='gray', linestyle='--')

        # Add counts above bars (bar i is centred on rank i + 1)
        for i, count in enumerate(n):
            if count > 0:
                ax.text(i + 1, count, f'{int(count)}', ha='center', va='bottom')

        # results is non-empty here, so the percentages cannot divide by zero
        total_analyzed = len(results)
        total_found = len(rankings)
        in_top_1 = sum(1 for r in results if r["is_top_1"])
        in_top_3 = sum(1 for r in results if r["is_top_3"])
        in_top_5 = sum(1 for r in results if r["is_top_5"])

        stats_text = (
            f'Total: {total_analyzed}\n'
            f'Found: {total_found}\n'
            f'Top 1: {in_top_1} ({in_top_1/total_analyzed*100:.1f}%)\n'
            f'Top 3: {in_top_3} ({in_top_3/total_analyzed*100:.1f}%)\n'
            f'Top 5: {in_top_5} ({in_top_5/total_analyzed*100:.1f}%)'
        )

        ax.text(0.95, 0.95, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                fontsize=9)

    # Remove every axes that did not receive a model
    for ax in axes[len(populated):]:
        fig.delaxes(ax)

    plt.tight_layout()
    return fig

def main(results_root: str = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results",
         reference_csv: str = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"):
    """
    Run the V1 ranking analysis for every experiment.

    For each experiment this plots the per-LLM rank histograms, saves one CSV
    per LLM under ``<json_directory>/analysis_results`` and prints top-1/3/5
    summaries. All experiments share one results root and one reference CSV,
    so both are parameters (defaulting to the original hard-coded paths)
    instead of being repeated six times.

    NOTE(review): the experimental runs also use the sim_data reference CSV;
    confirm this is intended.

    Args:
        results_root: Directory containing the ``_run_*`` result folders.
        reference_csv: CSV with reference SMILES for all experiments.
    """
    run_dirs = {
        "Simulated Data": "_run_1_sim_finished_clean",
        "Simulated Data with Wrong Guess": "_run_2_sim_aug_finished_clean",
        "Simulated Data with Noise": "_run_3_sim+noise_finished_clean",
        "Experimental Data": "_run_4_exp_d1_finished_clean",
        "Experimental Data with Wrong Guess": "_run_5_exp_d1_aug_finished_clean",
        "Experimental Data d4": "_run_6_exp_d4_finished_clean",
    }
    experiments = {
        label: {
            "json_directory": os.path.join(results_root, run_dir),
            "reference_csv": reference_csv,
        }
        for label, run_dir in run_dirs.items()
    }

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")

        results_by_llm = analyze_directory(paths["json_directory"], paths["reference_csv"])

        if not any(results_by_llm.values()):
            print(f"No valid results found for {exp_label}. Please check your input files.")
            continue

        fig = plot_ranking_histograms(results_by_llm, experiment_label=exp_label)
        plt.show()
        # One figure is created per experiment; release it explicitly so
        # non-inline backends do not accumulate open figures.
        plt.close(fig)

        output_dir = Path(paths["json_directory"]) / "analysis_results"
        output_dir.mkdir(exist_ok=True)

        for llm_name, results in results_by_llm.items():
            if not results:
                continue
            pd.DataFrame(results).to_csv(output_dir / f"{llm_name}_analysis_results.csv", index=False)

            # results is non-empty here, so the percentages are safe
            total = len(results)
            top_1 = sum(1 for r in results if r["is_top_1"])
            top_3 = sum(1 for r in results if r["is_top_3"])
            top_5 = sum(1 for r in results if r["is_top_5"])
            print(f"\nSummary for {llm_name.upper()}:")
            print(f"Total analyzed: {total}")
            print(f"Top 1: {top_1} ({top_1/total*100:.1f}%)")
            print(f"Top 3: {top_3} ({top_3/total*100:.1f}%)")
            print(f"Top 5: {top_5} ({top_5/total*100:.1f}%)")

# Inside a Jupyter kernel __name__ is "__main__", so executing this cell
# immediately runs the V1 analysis defined above (reads every experiment
# directory and writes analysis_results CSVs).
if __name__ == "__main__":
    main()

### V2 - OLD

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os
from typing import Dict, List, Any

def analyze_single_json(json_file: str, reference_data: Dict[str, str]) -> List[Dict[str, Any]]:
    """
    Score every supported LLM's candidates stored in one result JSON file.

    NOTE(review): analyze_llm_predictions is not defined in this cell; it must
    come from an earlier cell.

    Args:
        json_file: Path to JSON file
        reference_data: Dictionary mapping base sample IDs (text before the
            first underscore) to true SMILES

    Returns:
        One result dict per LLM that produced an analysis, each carrying the
        file's full ``sample_id``. Empty when the reference SMILES is missing
        or any error occurs (the error is printed).
    """
    models = ("claude", "claude3-7", "o3", "kimi", "gemini", "deepseek")
    try:
        with open(json_file, 'r') as fh:
            data = json.load(fh)

        sample_id = data["molecule_data"]["sample_id"]
        base_sample_id = sample_id.split('_')[0]
        true_smiles = reference_data.get(base_sample_id)

        if true_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_sample_id}")
            return []

        per_model = (analyze_llm_predictions(data, true_smiles, model) for model in models)
        return [dict(result, sample_id=sample_id) for result in per_model if result]

    except Exception as e:
        print(f"Error processing {json_file}: {str(e)}")
        return []

    
    
def analyze_directory(json_dir: str, reference_csv: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Analyze all JSON files in a directory for all LLM models.

    Files are processed in sorted path order, because ``glob.glob`` returns
    matches in arbitrary, filesystem-dependent order.

    NOTE(review): load_reference_data is not defined in this cell; it must
    come from an earlier cell (hidden-state dependency on Restart & Run All).

    Args:
        json_dir: Directory containing the per-sample ``*.json`` result files.
        reference_csv: CSV with the reference SMILES.

    Returns:
        Mapping from LLM name to the list of per-sample result dicts.
    """
    reference_data = load_reference_data(reference_csv)
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    # Initialize results dictionary for each LLM
    results_by_llm = {
        "claude": [], "claude3-7": [], "o3": [], "kimi": [], "gemini": [], "deepseek": []
    }

    for file_path in json_files:
        for result in analyze_single_json(file_path, reference_data):
            # setdefault keeps an unexpected model name from raising KeyError
            results_by_llm.setdefault(result["llm_model"], []).append(result)

    return results_by_llm

def plot_ranking_histograms(results_by_llm: Dict[str, List[Dict[str, Any]]], 
                          experiment_label: str = "",
                          max_rank: int = 5,
                          show_legend: bool = False):
    """
    Create compact rank histograms (ranks 1..max_rank) for all LLM models.

    Each panel registers two legend handles: the model name and a
    "Top k: found/total (pct)" entry. They are only drawn when ``show_legend``
    is True. Previously they were always built while the ``ax.legend`` call
    was commented out, so the computed top-k rate was never visible. Models
    without results get no panel, rather than leaving a blank axes mid-grid.

    Args:
        results_by_llm: Mapping from LLM name to per-sample result dicts
            (``correct_position`` is read).
        experiment_label: Optional first title line for every panel.
        max_rank: Highest rank plotted (default 5, the previous fixed value).
        show_legend: Draw the per-panel legend with the top-k rate.

    Returns:
        The matplotlib Figure holding all panels.
    """
    populated = {name: res for name, res in results_by_llm.items() if res}

    n_cols = 2
    # At least one row so plt.subplots never receives nrows=0
    n_rows = max(1, (len(populated) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows),
                             facecolor='white', squeeze=False)
    axes = axes.flatten()

    colors = {
        "claude": "#4169E1",      # Royal Blue
        "claude3-7": "#1E90FF",   # Dodger Blue for Claude 3.7
        "o3": "#2E8B57",          # Sea Green
        "kimi": "#8B4513",        # Saddle Brown
        "gemini": "#4B0082",      # Indigo
        "deepseek": "#CD853F"     # Peru
    }

    # Bin edges centred on the integer ranks 1..max_rank
    edges = np.arange(1, max_rank + 2) - 0.5

    for ax, (llm_name, results) in zip(axes, populated.items()):
        ax.set_facecolor('white')
        color = colors.get(llm_name, '#4169E1')

        rankings = [r["correct_position"] for r in results if r["correct_position"] is not None]

        n, _, _ = ax.hist(
            [r for r in rankings if r <= max_rank],
            bins=edges,
            edgecolor='black',
            alpha=0.7,
            color=color,
            label=llm_name.upper()
        )

        # results is non-empty here, so the percentage cannot divide by zero
        total_analyzed = len(results)
        in_top_k = sum(1 for r in rankings if r <= max_rank)
        top_k_percent = in_top_k / total_analyzed * 100

        # Invisible artist that only contributes a legend entry
        ax.plot([], [], color=color, alpha=0.7,
                label=f'Top {max_rank}: {in_top_k}/{total_analyzed} ({top_k_percent:.1f}%)')

        title = llm_name.upper()
        if experiment_label:
            title = f'{experiment_label}\n{title}'
        ax.set_title(title, fontsize=16, pad=20)

        ax.set_xlabel('Rank', fontsize=14)
        ax.set_ylabel('Number of Molecules', fontsize=14)
        ax.set_xticks(range(1, max_rank + 1))
        ax.tick_params(axis='both', which='major', labelsize=12)
        ax.grid(True, alpha=0.3, color='gray', linestyle='--')

        # Add counts above bars (bar i is centred on rank i + 1)
        for i, count in enumerate(n):
            if count > 0:
                ax.text(i + 1, count, f'{int(count)}',
                        ha='center', va='bottom', fontsize=12)

        if show_legend:
            ax.legend(fontsize=12, loc='upper right')

    # Remove every axes that did not receive a model
    for ax in axes[len(populated):]:
        fig.delaxes(ax)

    plt.tight_layout()
    return fig

def main(results_root: str = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results",
         reference_csv: str = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"):
    """
    Run the V2 ranking analysis: one histogram figure and a Top-5 summary per experiment.

    All experiments share one results root and one reference CSV, so both are
    parameters (defaulting to the original hard-coded paths) instead of being
    repeated six times.

    NOTE(review): the experimental runs also use the sim_data reference CSV;
    confirm this is intended.

    Args:
        results_root: Directory containing the ``_run_*`` result folders.
        reference_csv: CSV with reference SMILES for all experiments.
    """
    run_dirs = {
        "Simulated Data": "_run_1_sim_finished_clean",
        "Simulated Data with Wrong Guess": "_run_2_sim_aug_finished_clean",
        "Simulated Data with Noise": "_run_3_sim+noise_finished_clean",
        "Experimental Data": "_run_4_exp_d1_finished_clean",
        "Experimental Data with Wrong Guess": "_run_5_exp_d1_aug_finished_clean",
        "Experimental Data d4": "_run_6_exp_d4_finished_clean",
    }
    experiments = {
        label: {
            "json_directory": os.path.join(results_root, run_dir),
            "reference_csv": reference_csv,
        }
        for label, run_dir in run_dirs.items()
    }

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")
        results_by_llm = analyze_directory(paths["json_directory"], paths["reference_csv"])

        if not any(results_by_llm.values()):
            print(f"No valid results found for {exp_label}. Please check your input files.")
            continue

        fig = plot_ranking_histograms(results_by_llm, experiment_label=exp_label)
        plt.show()
        # One figure is created per experiment; release it explicitly so
        # non-inline backends do not accumulate open figures.
        plt.close(fig)

        # Print summary statistics (results is non-empty inside the guard)
        for llm_name, results in results_by_llm.items():
            if not results:
                continue
            total = len(results)
            in_top_5 = sum(1 for r in results if r["is_top_5"])
            print(f"\n{llm_name.upper()}:")
            print(f"Top 5: {in_top_5}/{total} ({in_top_5/total*100:.1f}%)")

# Inside a Jupyter kernel __name__ is "__main__", so executing this cell
# immediately runs the V2 analysis defined above for every experiment.
if __name__ == "__main__":
    main()

### V2.2

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os
from typing import Dict, List, Any

def load_reference_data(csv_path: str) -> Dict[str, str]:
    """Build a ``sample-id`` -> ``SMILES`` lookup dict from the reference CSV.

    Args:
        csv_path: Path (or file-like object) of a CSV with ``sample-id`` and
            ``SMILES`` columns.

    Returns:
        Dictionary ``{sample_id: smiles}`` for constant-time lookups.
    """
    return (
        pd.read_csv(csv_path)
        .set_index('sample-id')['SMILES']
        .to_dict()
    )

def get_base_sample_id(sample_id: str) -> str:
    """Return the portion of ``sample_id`` before the first underscore.

    Empty or None input yields an empty string.
    """
    if not sample_id:
        return ''
    base, _separator, _suffix = sample_id.partition('_')
    return base

def analyze_llm_predictions(json_data: Dict, true_smiles: str, llm_name: str) -> Dict[str, Any]:
    """
    Analyze predictions from a specific LLM model.

    Candidates are ranked by descending ``confidence_score``; equal scores keep
    their original order. Candidates whose score is missing, non-numeric or NaN
    are kept but ranked after every scored candidate. Previously such a score
    made ``sorted`` raise ``TypeError``, which escaped the ``KeyError`` handler
    and caused the caller to discard the results of *all* LLMs for the file.

    Args:
        json_data: Loaded JSON data
        true_smiles: True SMILES string to compare against (exact string
            comparison; no canonicalization is applied)
        llm_name: Name of the LLM model to analyze

    Returns:
        Dictionary with analysis results, or None if this LLM has no parsed
        candidate list in the file.
    """
    import math

    try:
        llm_results = json_data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]
        candidates = llm_results["candidates"]
    except (KeyError, TypeError):
        # This LLM might not have results in this file (or a section is null)
        return None
    if not isinstance(candidates, list):
        return None

    def _confidence(cand):
        """Candidate confidence as a float, or None when it cannot be used for ranking."""
        if not isinstance(cand, dict):
            return None
        try:
            value = float(cand.get("confidence_score"))
        except (TypeError, ValueError):
            return None
        return None if math.isnan(value) else value

    scored, unscored = [], []
    for cand in candidates:
        score = _confidence(cand)
        if score is None:
            unscored.append(cand)
        else:
            scored.append((score, cand))
    # list.sort is stable even with reverse=True, so ties keep input order
    scored.sort(key=lambda pair: pair[0], reverse=True)
    ranked = [cand for _, cand in scored] + unscored

    # 1-based position of the first candidate whose SMILES matches exactly
    correct_position = next(
        (i for i, cand in enumerate(ranked, 1)
         if isinstance(cand, dict) and cand.get("smiles") == true_smiles),
        None,
    )

    return {
        "llm_model": llm_name,
        "correct_position": correct_position,
        "total_candidates": len(ranked),
        "is_top_1": correct_position == 1,
        "is_top_5": correct_position is not None and correct_position <= 5,
        "is_top_10": correct_position is not None and correct_position <= 10,
        "is_after_top_10": correct_position is not None and correct_position > 10
    }

def analyze_single_json(json_file: str, reference_data: Dict[str, str]) -> List[Dict[str, Any]]:
    """
    Score every supported LLM's candidates stored in one result JSON file.

    Args:
        json_file: Path to JSON file
        reference_data: Dictionary mapping base sample IDs to true SMILES

    Returns:
        One result dict per LLM that produced an analysis, each carrying the
        file's full ``sample_id``. Empty when the file has no sample id, no
        reference SMILES matches its base id, or reading/processing fails
        (the error is printed).
    """
    models = ("claude", "claude3-7", "o3", "kimi", "gemini", "deepseek")
    try:
        with open(json_file, 'r') as fh:
            data = json.load(fh)

        sample_id = data.get("molecule_data", {}).get("sample_id")
        if not sample_id:
            return []

        base_sample_id = get_base_sample_id(sample_id)
        true_smiles = reference_data.get(base_sample_id)
        if true_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_sample_id}")
            return []

        per_model = (analyze_llm_predictions(data, true_smiles, model) for model in models)
        return [dict(result, sample_id=sample_id) for result in per_model if result]

    except Exception as e:
        print(f"Error processing {json_file}: {str(e)}")
        return []

def analyze_directory(json_dir: str, reference_csv: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Analyze all JSON files in a directory for all LLM models.

    Files are processed in sorted path order. ``glob.glob`` returns matches in
    arbitrary, filesystem-dependent order, so without sorting the per-LLM
    result lists differ between runs.

    Args:
        json_dir: Directory containing the per-sample ``*.json`` result files.
        reference_csv: CSV with ``sample-id`` and ``SMILES`` columns.

    Returns:
        Mapping from LLM name to the list of per-sample result dicts.
    """
    reference_data = load_reference_data(reference_csv)
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    # Initialize results dictionary for each LLM
    results_by_llm = {
        "claude": [], "claude3-7": [], "o3": [], "kimi": [], "gemini": [], "deepseek": []
    }

    for file_path in json_files:
        for result in analyze_single_json(file_path, reference_data):
            # setdefault keeps an unexpected model name from raising KeyError
            results_by_llm.setdefault(result["llm_model"], []).append(result)

    return results_by_llm

def get_display_name(llm_name):
    """Return the human-readable label for an internal LLM key.

    Keys without a known label are returned unchanged.
    """
    display_names = dict([
        ("claude", "Claude 3.5"),
        ("claude3-7", "Claude 3.7 Thinking"),
        ("o3", "o3 mini high"),
        ("kimi", "Kimi 1.5"),
        ("gemini", "Gemini 2.0 Thinking"),
        ("deepseek", "DeepSeek R1"),
    ])
    return display_names.get(llm_name, llm_name)

def plot_llm_ranking_histogram(rankings: List[int], llm_name: str, experiment_label: str, 
                             max_rank: int = 5, color: str = "#4169E1", figsize: tuple = (5.5, 5.5)):
    """
    Create histogram of molecule rankings for an LLM model with an extra bin for ranks beyond max_rank.
    
    Args:
        rankings: List of rankings for correct molecules
        llm_name: Name of the LLM model
        experiment_label: Label for the experiment
        max_rank: Maximum individual rank to show in histogram (default: 5)
        color: Color for histogram bars
        figsize: Size of the figure
    """
    # Create figure with white background
    fig, ax = plt.subplots(figsize=figsize, facecolor='white')
    ax.set_facecolor('white')
    
    # Prepare data for histogram
    rank_counts = {}
    
    # Count ranks 1-5 individually
    for r in range(1, max_rank + 1):
        rank_counts[r] = sum(1 for rank in rankings if rank == r)
    
    # Count all ranks > max_rank together
    rank_counts['6+'] = sum(1 for rank in rankings if rank > max_rank)
    
    # Plot the histogram
    positions = list(range(1, max_rank + 1)) + [max_rank + 1]
    counts = [rank_counts[r] if r <= max_rank else rank_counts['6+'] for r in positions]
    
    bars = ax.bar(
        positions,
        counts,
        width=0.8,
        edgecolor='black',
        alpha=0.7,
        color=color
    )
    
    # Get display-friendly model name
    display_name = get_display_name(llm_name)
    
    # Create title with model name and experiment label
    title = f"{display_name} - {experiment_label}"
    ax.set_title(title, fontsize=16, pad=10)
    
    ax.set_xlabel('Rank', fontsize=14)
    ax.set_ylabel('Number of Molecules', fontsize=14)
    
    # Set x-axis ticks and labels
    ax.set_xticks(positions)
    x_labels = [str(i) for i in range(1, max_rank + 1)] + ['6+']
    ax.set_xticklabels(x_labels, fontsize=12)
    ax.tick_params(axis='y', labelsize=12)
    
    # Add grid with light gray color
    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')
    
    # Calculate the maximum y value and add padding (20%)
    max_count = max(counts) if counts else 0
    y_max = 35 # Add 20% padding above the highest bar
    ax.set_ylim(0, y_max)
    
    # Add counts above bars (not bold)
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax.text(
                bar.get_x() + bar.get_width()/2, 
                height + (y_max * 0.01),  # Adjust text position based on y-axis range
                f'{int(height)}',
                ha='center', 
                va='bottom',
                fontsize=12
            )
    
    # Calculate statistics
    total_molecules = len(rankings)
    in_top_1 = sum(1 for r in rankings if r == 1)
    in_top_3 = sum(1 for r in rankings if r <= 3)
    in_top_5 = sum(1 for r in rankings if r <= 5)
    after_top_5 = sum(1 for r in rankings if r > 5)
    
    stats_text = (
        f'Total molecules found: {total_molecules}\n'
        f'Found in top 1: {in_top_1} ({in_top_1/total_molecules*100:.1f}%)\n'
        f'Found in top 3: {in_top_3} ({in_top_3/total_molecules*100:.1f}%)\n'
        f'Found in top 5: {in_top_5} ({in_top_5/total_molecules*100:.1f}%)\n'
        f'Found after top 5: {after_top_5} ({after_top_5/total_molecules*100:.1f}%)'
    )
    
    # Add statistics text box
    ax.text(0.95, 0.95, stats_text,
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            fontsize=12,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
    
    # Adjust layout
    plt.tight_layout()
    
    return fig

def analyze_and_plot_llm_results(llm_models_to_analyze=('deepseek',)):
    """
    Analyze and plot per-model rank histograms for the specified LLM models.

    For each experiment, ``analyze_directory`` loads the results of every LLM. For each
    selected model with at least one found molecule, a histogram is drawn with
    ``plot_llm_ranking_histogram``, saved as a PNG in the current working directory and
    shown. Its top-1/5/10 statistics are recorded and printed as a table at the end.

    Args:
        llm_models_to_analyze: Model keys to analyze (e.g. ``['deepseek', 'claude3-7']``)
            or a single key string. If it contains ``'all'``, all six known models are used.

    Returns:
        Nested dict ``{experiment_label: {llm_name: stats}}``. ``stats`` holds the number
        of found molecules and the counts/percentages ranked in the top 1, 5 and 10.
    """
    # Accept a bare string such as 'deepseek' instead of iterating over its characters.
    if isinstance(llm_models_to_analyze, str):
        llm_models_to_analyze = [llm_models_to_analyze]
    if 'all' in llm_models_to_analyze:
        llm_models_to_analyze = ['claude', 'claude3-7', 'o3', 'kimi', 'gemini', 'deepseek']

    # Input locations: every run lives under the same results root.
    results_root = (
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
        "LLM_Structure_Elucidator/_temp_folder/intermediate_results"
    )
    # NOTE(review): the experimental-data runs also use the simulated-data reference CSV;
    # confirm this is intended.
    reference_csv = (
        "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
        "52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    )
    run_dirs = {
        "Sim Data": "_run_1_sim_finished_clean",
        "Sim Data + Wrong Guess": "_run_2_sim_aug_finished_clean",
        "Sim Data + Noise": "_run_3_sim+noise_finished_clean",
        "Exp Data": "_run_4_exp_d1_finished_clean",
        "Exp Data + Wrong Guess": "_run_5_exp_d1_aug_finished_clean",
        "Exp Data d4": "_run_6_exp_d4_finished_clean",
    }
    experiments = {
        label: {"json_directory": f"{results_root}/{run_dir}", "reference_csv": reference_csv}
        for label, run_dir in run_dirs.items()
    }

    # Bar colour per experiment (every model of one experiment shares the colour).
    experiment_colors = {
        "Sim Data": "#6366F1",
        "Sim Data + Wrong Guess": "#3B82F6",
        "Sim Data + Noise": "#10B981",
        "Exp Data": "#F59E0B",
        "Exp Data + Wrong Guess": "#EC4899",
        "Exp Data d4": "#8B5CF6",
    }

    summary_stats = {}

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")
        results_by_llm = analyze_directory(paths["json_directory"], paths["reference_csv"])

        for llm_name in llm_models_to_analyze:
            if llm_name not in results_by_llm or not results_by_llm[llm_name]:
                print(f"No valid results found for {llm_name} in {exp_label}.")
                continue

            # Keep only molecules whose correct structure was found among the candidates.
            rankings = [r["correct_position"] for r in results_by_llm[llm_name]
                        if r["correct_position"] is not None]
            if not rankings:
                print(f"No rankings found for {llm_name} in {exp_label}.")
                continue

            fig = plot_llm_ranking_histogram(
                rankings,
                llm_name,
                exp_label,
                color=experiment_colors[exp_label],  # colour by experiment
                figsize=(6, 6),
            )

            # Save BEFORE showing: on non-inline backends show() may leave the figure
            # closed, so a later savefig can write an empty image.
            display_name = get_display_name(llm_name)
            file_stem = f'ranking_histogram_{exp_label}_{display_name}'.replace(" ", "_")
            fig.savefig(f'{file_stem}.png', dpi=300, bbox_inches='tight')
            plt.show()
            plt.close(fig)  # many figures are created in this loop; release each one

            total = len(rankings)
            in_top_1 = sum(1 for r in rankings if r == 1)
            in_top_5 = sum(1 for r in rankings if r <= 5)
            in_top_10 = sum(1 for r in rankings if r <= 10)

            # total > 0 is guaranteed by the empty-rankings check above.
            summary_stats.setdefault(exp_label, {})[llm_name] = {
                "total": total,
                "in_top_1": in_top_1,
                "in_top_5": in_top_5,
                "in_top_10": in_top_10,
                "percent_top_1": in_top_1 / total * 100,
                "percent_top_5": in_top_5 / total * 100,
                "percent_top_10": in_top_10 / total * 100,
            }

    # Summary table: each cell is padded to its header width so the columns line up.
    print("\n===== SUMMARY STATISTICS =====")
    for exp_label, llm_stats in summary_stats.items():
        print(f"\n{exp_label}:")
        print(f"{'Model':<25} {'Total':<8} {'Top 1':<12} {'Top 5':<12} {'Top 10':<12}")
        print("-" * 73)

        for llm_name, stats in llm_stats.items():
            display_name = get_display_name(llm_name)
            top_1 = f"{stats['in_top_1']} ({stats['percent_top_1']:.1f}%)"
            top_5 = f"{stats['in_top_5']} ({stats['percent_top_5']:.1f}%)"
            top_10 = f"{stats['in_top_10']} ({stats['percent_top_10']:.1f}%)"
            print(f"{display_name:<25} {stats['total']:<8} "
                  f"{top_1:<12} {top_5:<12} {top_10:<12}")

    return summary_stats

# Example calls (run in a Jupyter cell):
#
# Analyze only deepseek (the default selection):
# analyze_and_plot_llm_results()
#
# Analyze specific LLMs:
# analyze_and_plot_llm_results(['deepseek', 'claude3-7'])
#
# Analyze all six LLMs. This call uses the per-model-histogram definition in this cell;
# later cells redefine analyze_and_plot_llm_results with grid-plot versions.
analyze_and_plot_llm_results(['deepseek', 'claude', 'claude3-7', 'o3', 'kimi', 'gemini'])

### V3 Grid: one histogram panel per model for each experiment

In [None]:
def plot_llm_grid(experiment_label, results_by_llm, llm_models_to_analyze, experiment_colors):
    """
    Draw one rank histogram per LLM model for a single experiment, two panels per row.

    Ranks 1-5 each get their own bar; every rank above 5 is pooled into a '6+' bar. Each
    panel shows the model's display name, the bar counts and a text box with the total
    and the top-1/3/5 counts and percentages.

    Args:
        experiment_label: Key of the current experiment in ``experiment_colors``.
        results_by_llm: Mapping from model key to a list of result dicts, each carrying a
            "correct_position" entry (None when the correct molecule was not found).
        llm_models_to_analyze: Model keys to plot, in panel order.
        experiment_colors: Mapping from experiment label to bar colour.

    Returns:
        Matplotlib Figure with the grid: 3 rows x 2 columns for up to six models, with
        extra rows added (instead of silently dropping models) when more are requested.
    """
    max_rank = 5
    n_models = len(llm_models_to_analyze)
    n_rows = max(3, (n_models + 1) // 2)
    # Height scales with the row count; 3 rows gives the original 12x16 figure.
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 16 * n_rows / 3), facecolor='white')
    axes = axes.flatten()  # flat indexing: one axis per model

    color = experiment_colors[experiment_label]
    positions = list(range(1, max_rank + 2))  # bars 1..5 plus the pooled bar at 6
    x_labels = [str(p) for p in range(1, max_rank + 1)] + [f'{max_rank + 1}+']

    for ax, llm_name in zip(axes, llm_models_to_analyze):
        display_name = get_display_name(llm_name)

        if llm_name not in results_by_llm or not results_by_llm[llm_name]:
            # No data: keep the panel slot but show a message instead of axes.
            ax.text(0.5, 0.5, f"No data for {display_name}",
                    ha='center', va='center', fontsize=14)
            ax.axis('off')
            continue

        rankings = [r["correct_position"] for r in results_by_llm[llm_name]
                    if r["correct_position"] is not None]
        if not rankings:
            ax.text(0.5, 0.5, f"No rankings for {display_name}",
                    ha='center', va='center', fontsize=14)
            ax.axis('off')
            continue

        ax.set_facecolor('white')

        # One count per individual rank, then all ranks beyond max_rank pooled together.
        counts = [sum(1 for rank in rankings if rank == p) for p in range(1, max_rank + 1)]
        counts.append(sum(1 for rank in rankings if rank > max_rank))

        bars = ax.bar(positions, counts, width=0.8, edgecolor='black', alpha=0.7, color=color)

        # Panel title is the model name only; the experiment is implied by the figure.
        ax.set_title(f"{display_name}", fontsize=20, pad=10)
        ax.set_xlabel('Rank', fontsize=16)
        ax.set_ylabel('Number of Molecules', fontsize=16)
        ax.set_xticks(positions)
        ax.set_xticklabels(x_labels, fontsize=14)
        ax.tick_params(axis='y', labelsize=14)
        ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

        # Count labels just above each non-empty bar.
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width() / 2, height + 0.1, f'{int(height)}',
                        ha='center', va='bottom', fontsize=14)

        total_molecules = len(rankings)
        in_top_1 = sum(1 for r in rankings if r == 1)
        in_top_3 = sum(1 for r in rankings if r <= 3)
        in_top_5 = sum(1 for r in rankings if r <= 5)
        stats_text = (
            f'Total: {total_molecules}\n'
            f'Top 1: {in_top_1} ({in_top_1/total_molecules*100:.1f}%)\n'
            f'Top 3: {in_top_3} ({in_top_3/total_molecules*100:.1f}%)\n'
            f'Top 5: {in_top_5} ({in_top_5/total_molecules*100:.1f}%)'
        )
        ax.text(0.95, 0.95, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                fontsize=14,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Panels without an assigned model would otherwise render as empty frames.
    for ax in axes[n_models:]:
        ax.axis('off')

    fig.tight_layout()
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    return fig

In [None]:
def analyze_and_plot_llm_results(llm_models_to_analyze=('all',)):
    """
    Analyze every experiment and save a grid of per-model rank histograms for each one.

    For each experiment, ``analyze_directory`` loads the results of every LLM. A grid
    figure (one panel per model, via ``plot_llm_grid``) is saved to ``output_dir`` and
    shown. An individual histogram per model (``plot_llm_ranking_histogram``) is also
    saved there without being shown. Top-1/5/10 statistics are printed as a table and
    written to ``summary_statistics.csv`` in the same directory.

    Args:
        llm_models_to_analyze: Model keys to analyze, or a single key string. If it
            contains ``'all'`` (the default), all six known models are used.

    Returns:
        Nested dict ``{experiment_label: {llm_name: stats}}``. ``stats`` holds the number
        of found molecules and the counts/percentages ranked in the top 1, 5 and 10.
    """
    import os

    # Accept a bare string such as 'deepseek' instead of iterating over its characters.
    if isinstance(llm_models_to_analyze, str):
        llm_models_to_analyze = [llm_models_to_analyze]
    if 'all' in llm_models_to_analyze:
        llm_models_to_analyze = ['claude', 'claude3-7', 'o3', 'kimi', 'gemini', 'deepseek']

    # Output directory for all figures and the summary CSV.
    output_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures"
    os.makedirs(output_dir, exist_ok=True)

    # Input locations: every run lives under the same results root.
    results_root = (
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
        "LLM_Structure_Elucidator/_temp_folder/intermediate_results"
    )
    # NOTE(review): the experimental-data runs also use the simulated-data reference CSV;
    # confirm this is intended.
    reference_csv = (
        "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
        "52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    )
    run_dirs = {
        "Sim Data": "_run_1_sim_finished_clean",
        "Sim Data + Wrong Guess": "_run_2_sim_aug_finished_clean",
        "Sim Data + Noise": "_run_3_sim+noise_finished_clean",
        "Exp Data": "_run_4_exp_d1_finished_clean",
        "Exp Data + Wrong Guess": "_run_5_exp_d1_aug_finished_clean",
        "Exp Data d4": "_run_6_exp_d4_finished_clean",
    }
    experiments = {
        label: {"json_directory": f"{results_root}/{run_dir}", "reference_csv": reference_csv}
        for label, run_dir in run_dirs.items()
    }

    # Bar colour per experiment (every model of one experiment shares the colour).
    experiment_colors = {
        "Sim Data": "#6366F1",
        "Sim Data + Wrong Guess": "#3B82F6",
        "Sim Data + Noise": "#10B981",
        "Exp Data": "#F59E0B",
        "Exp Data + Wrong Guess": "#EC4899",
        "Exp Data d4": "#8B5CF6",
    }

    summary_stats = {}

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")
        results_by_llm = analyze_directory(paths["json_directory"], paths["reference_csv"])
        exp_slug = exp_label.replace(" ", "_")

        grid_fig = plot_llm_grid(exp_label, results_by_llm, llm_models_to_analyze, experiment_colors)

        # Save BEFORE showing: on non-inline backends show() may leave the figure
        # closed, so a later savefig can write an empty image.
        grid_fig_path = os.path.join(output_dir, f'grid_histogram_{exp_slug}.png')
        grid_fig.savefig(grid_fig_path, dpi=300, bbox_inches='tight')
        print(f"Saved grid figure to: {grid_fig_path}")
        plt.figure(grid_fig.number)
        plt.show()
        plt.close(grid_fig)  # one grid per experiment; release it once saved and shown

        exp_stats = summary_stats.setdefault(exp_label, {})

        for llm_name in llm_models_to_analyze:
            if llm_name not in results_by_llm or not results_by_llm[llm_name]:
                continue

            # Keep only molecules whose correct structure was found among the candidates.
            rankings = [r["correct_position"] for r in results_by_llm[llm_name]
                        if r["correct_position"] is not None]
            if not rankings:
                continue

            # Individual per-model figure: saved only, never shown.
            fig = plot_llm_ranking_histogram(
                rankings,
                llm_name,
                exp_label,
                color=experiment_colors[exp_label],
                figsize=(6, 6),
            )
            display_name = get_display_name(llm_name)
            individual_fig_path = os.path.join(
                output_dir,
                f'ranking_histogram_{exp_slug}_{display_name.replace(" ", "_")}.png'
            )
            fig.savefig(individual_fig_path, dpi=300, bbox_inches='tight')
            print(f"Saved individual figure to: {individual_fig_path}")
            plt.close(fig)

            total = len(rankings)
            in_top_1 = sum(1 for r in rankings if r == 1)
            in_top_5 = sum(1 for r in rankings if r <= 5)
            in_top_10 = sum(1 for r in rankings if r <= 10)

            # total > 0 is guaranteed by the empty-rankings check above.
            exp_stats[llm_name] = {
                "total": total,
                "in_top_1": in_top_1,
                "in_top_5": in_top_5,
                "in_top_10": in_top_10,
                "percent_top_1": in_top_1 / total * 100,
                "percent_top_5": in_top_5 / total * 100,
                "percent_top_10": in_top_10 / total * 100,
            }

    # Summary table: each cell is padded to its header width so the columns line up.
    print("\n===== SUMMARY STATISTICS =====")
    for exp_label, llm_stats in summary_stats.items():
        print(f"\n{exp_label}:")
        print(f"{'Model':<25} {'Total':<8} {'Top 1':<12} {'Top 5':<12} {'Top 10':<12}")
        print("-" * 73)

        for llm_name, stats in llm_stats.items():
            display_name = get_display_name(llm_name)
            top_1 = f"{stats['in_top_1']} ({stats['percent_top_1']:.1f}%)"
            top_5 = f"{stats['in_top_5']} ({stats['percent_top_5']:.1f}%)"
            top_10 = f"{stats['in_top_10']} ({stats['percent_top_10']:.1f}%)"
            print(f"{display_name:<25} {stats['total']:<8} "
                  f"{top_1:<12} {top_5:<12} {top_10:<12}")

    # Best-effort CSV export: failures are reported but do not abort the analysis.
    try:
        import pandas as pd

        summary_data = [
            {
                'Experiment': exp_label,
                'Model': get_display_name(llm_name),
                'Total': stats['total'],
                'Top 1': stats['in_top_1'],
                'Top 1 %': stats['percent_top_1'],
                'Top 5': stats['in_top_5'],
                'Top 5 %': stats['percent_top_5'],
                'Top 10': stats['in_top_10'],
                'Top 10 %': stats['percent_top_10'],
            }
            for exp_label, llm_stats in summary_stats.items()
            for llm_name, stats in llm_stats.items()
        ]

        if summary_data:
            summary_df = pd.DataFrame(summary_data)
            summary_csv_path = os.path.join(output_dir, 'summary_statistics.csv')
            summary_df.to_csv(summary_csv_path, index=False)
            print(f"\nSaved summary statistics to: {summary_csv_path}")
    except ImportError:
        print("pandas not available. Summary statistics not saved to CSV.")
    except Exception as e:
        print(f"Error saving summary statistics: {e}")

    return summary_stats

# Run the V3 grid analysis for every model (the default selection expands to all six).
# Grid/individual figures and summary_statistics.csv are written to the Figures directory.
analyze_and_plot_llm_results()

### V3.1 Grid: shared y-axis range and experiment title per figure

In [None]:
def plot_llm_grid(experiment_label, results_by_llm, llm_models_to_analyze, experiment_colors, y_max=35):
    """
    Draw one rank histogram per LLM model for a single experiment, two panels per row,
    under a figure title naming the experiment.

    Ranks 1-5 each get their own bar; every rank above 5 is pooled into a '6+' bar. All
    panels share the y-range ``[0, y_max]`` so they can be compared directly. Each panel
    shows the model's display name, the bar counts and a text box with the total and the
    top-1/3/5 counts and percentages.

    Args:
        experiment_label: Key of the current experiment in ``experiment_colors``; also
            used as the figure title.
        results_by_llm: Mapping from model key to a list of result dicts, each carrying a
            "correct_position" entry (None when the correct molecule was not found).
        llm_models_to_analyze: Model keys to plot, in panel order.
        experiment_colors: Mapping from experiment label to bar colour.
        y_max: Shared upper y-limit (default 35). A panel whose tallest bar reaches this
            limit gets extra head-room instead of having the bar clipped.

    Returns:
        Matplotlib Figure with the grid: 3 rows x 2 columns for up to six models, with
        extra rows added (instead of silently dropping models) when more are requested.
    """
    max_rank = 5
    n_models = len(llm_models_to_analyze)
    n_rows = max(3, (n_models + 1) // 2)
    # Height scales with the row count; 3 rows gives the original 12x16 figure.
    fig, axes = plt.subplots(n_rows, 2, figsize=(12, 16 * n_rows / 3), facecolor='white')
    axes = axes.flatten()  # flat indexing: one axis per model

    color = experiment_colors[experiment_label]
    positions = list(range(1, max_rank + 2))  # bars 1..5 plus the pooled bar at 6
    x_labels = [str(p) for p in range(1, max_rank + 1)] + [f'{max_rank + 1}+']

    for ax, llm_name in zip(axes, llm_models_to_analyze):
        display_name = get_display_name(llm_name)

        if llm_name not in results_by_llm or not results_by_llm[llm_name]:
            # No data: keep the panel slot but show a message instead of axes.
            ax.text(0.5, 0.5, f"No data for {display_name}",
                    ha='center', va='center', fontsize=14)
            ax.axis('off')
            continue

        rankings = [r["correct_position"] for r in results_by_llm[llm_name]
                    if r["correct_position"] is not None]
        if not rankings:
            ax.text(0.5, 0.5, f"No rankings for {display_name}",
                    ha='center', va='center', fontsize=14)
            ax.axis('off')
            continue

        ax.set_facecolor('white')

        # One count per individual rank, then all ranks beyond max_rank pooled together.
        counts = [sum(1 for rank in rankings if rank == p) for p in range(1, max_rank + 1)]
        counts.append(sum(1 for rank in rankings if rank > max_rank))

        bars = ax.bar(positions, counts, width=0.8, edgecolor='black', alpha=0.7, color=color)

        # Panel title is the model name only; the experiment is the figure title.
        ax.set_title(f"{display_name}", fontsize=20, pad=10)
        ax.set_xlabel('Rank', fontsize=16)
        ax.set_ylabel('Number of Molecules', fontsize=16)
        ax.set_xticks(positions)
        ax.set_xticklabels(x_labels, fontsize=14)
        ax.tick_params(axis='y', labelsize=14)
        ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

        # Shared y-range for comparability; stretch it only when a bar would be clipped.
        tallest = max(counts)
        ax.set_ylim(0, y_max if tallest < y_max else tallest * 1.15)

        # Count labels with a fixed offset so they sit consistently above the bars.
        for bar in bars:
            height = bar.get_height()
            if height > 0:
                ax.text(bar.get_x() + bar.get_width() / 2, height + 0.7, f'{int(height)}',
                        ha='center', va='bottom', fontsize=14)

        total_molecules = len(rankings)
        in_top_1 = sum(1 for r in rankings if r == 1)
        in_top_3 = sum(1 for r in rankings if r <= 3)
        in_top_5 = sum(1 for r in rankings if r <= 5)
        stats_text = (
            f'Total: {total_molecules}\n'
            f'Top 1: {in_top_1} ({in_top_1/total_molecules*100:.1f}%)\n'
            f'Top 3: {in_top_3} ({in_top_3/total_molecules*100:.1f}%)\n'
            f'Top 5: {in_top_5} ({in_top_5/total_molecules*100:.1f}%)'
        )
        ax.text(0.95, 0.95, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                fontsize=14,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Panels without an assigned model would otherwise render as empty frames.
    for ax in axes[n_models:]:
        ax.axis('off')

    fig.suptitle(f"{experiment_label}", fontsize=24, y=0.98)

    fig.tight_layout()
    # Lower `top` leaves room for the suptitle above the first row of panels.
    fig.subplots_adjust(hspace=0.3, wspace=0.3, top=0.95)

    return fig

In [None]:
def analyze_and_plot_llm_results(llm_models_to_analyze=('all',)):
    """
    Analyze every experiment and save a grid of per-model rank histograms for each one.

    For each experiment, ``analyze_directory`` loads the results of every LLM. A grid
    figure (one panel per model, via ``plot_llm_grid``) is saved to ``output_dir`` and
    shown. An individual histogram per model (``plot_llm_ranking_histogram``) is also
    saved there without being shown. Top-1/5/10 statistics are printed as a table and
    written to ``summary_statistics.csv`` in the same directory.

    Args:
        llm_models_to_analyze: Model keys to analyze, or a single key string. If it
            contains ``'all'`` (the default), all six known models are used.

    Returns:
        Nested dict ``{experiment_label: {llm_name: stats}}``. ``stats`` holds the number
        of found molecules and the counts/percentages ranked in the top 1, 5 and 10.
    """
    import os

    # Accept a bare string such as 'deepseek' instead of iterating over its characters.
    if isinstance(llm_models_to_analyze, str):
        llm_models_to_analyze = [llm_models_to_analyze]
    if 'all' in llm_models_to_analyze:
        llm_models_to_analyze = ['claude', 'claude3-7', 'o3', 'kimi', 'gemini', 'deepseek']

    # Output directory for all figures and the summary CSV.
    output_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures"
    os.makedirs(output_dir, exist_ok=True)

    # Input locations: every run lives under the same results root.
    results_root = (
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
        "LLM_Structure_Elucidator/_temp_folder/intermediate_results"
    )
    # NOTE(review): the experimental-data runs also use the simulated-data reference CSV;
    # confirm this is intended.
    reference_csv = (
        "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
        "52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    )
    run_dirs = {
        "Sim Data": "_run_1_sim_finished_clean",
        "Sim Data + Wrong Guess": "_run_2_sim_aug_finished_clean",
        "Sim Data + Noise": "_run_3_sim+noise_finished_clean",
        "Exp Data": "_run_4_exp_d1_finished_clean",
        "Exp Data + Wrong Guess": "_run_5_exp_d1_aug_finished_clean",
        "Exp Data d4": "_run_6_exp_d4_finished_clean",
    }
    experiments = {
        label: {"json_directory": f"{results_root}/{run_dir}", "reference_csv": reference_csv}
        for label, run_dir in run_dirs.items()
    }

    # Bar colour per experiment, matched to the colours used for the same data sets in
    # the other figures. (The earlier, immediately-overwritten palette was dead code.)
    experiment_colors = {
        "Sim Data": "#6366F1",                   # Same as "Simulated Data ALL"
        "Sim Data + Wrong Guess": "#3B82F6",     # Same as "Simulated Data ALL with Wrong Guess"
        "Sim Data + Noise": "#10B981",           # Same as "Simulated Data ALL with Noise"
        "Exp Data": "#84CC16",                   # Matched with "Experimental Data HSQC" (lime)
        "Exp Data + Wrong Guess": "#F97316",     # Matched with "Experimental Data HSQC with Wrong Guess" (orange)
        "Exp Data d4": "#8B5CF6",                # Same as "Experimental Data ALL"
    }

    summary_stats = {}

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")
        results_by_llm = analyze_directory(paths["json_directory"], paths["reference_csv"])
        exp_slug = exp_label.replace(" ", "_")

        grid_fig = plot_llm_grid(exp_label, results_by_llm, llm_models_to_analyze, experiment_colors)

        # Save BEFORE showing: on non-inline backends show() may leave the figure
        # closed, so a later savefig can write an empty image.
        grid_fig_path = os.path.join(output_dir, f'grid_histogram_{exp_slug}.png')
        grid_fig.savefig(grid_fig_path, dpi=300, bbox_inches='tight')
        print(f"Saved grid figure to: {grid_fig_path}")
        plt.figure(grid_fig.number)
        plt.show()
        plt.close(grid_fig)  # one grid per experiment; release it once saved and shown

        exp_stats = summary_stats.setdefault(exp_label, {})

        for llm_name in llm_models_to_analyze:
            if llm_name not in results_by_llm or not results_by_llm[llm_name]:
                continue

            # Keep only molecules whose correct structure was found among the candidates.
            rankings = [r["correct_position"] for r in results_by_llm[llm_name]
                        if r["correct_position"] is not None]
            if not rankings:
                continue

            # Individual per-model figure: saved only, never shown.
            fig = plot_llm_ranking_histogram(
                rankings,
                llm_name,
                exp_label,
                color=experiment_colors[exp_label],
                figsize=(6, 6),
            )
            display_name = get_display_name(llm_name)
            individual_fig_path = os.path.join(
                output_dir,
                f'ranking_histogram_{exp_slug}_{display_name.replace(" ", "_")}.png'
            )
            fig.savefig(individual_fig_path, dpi=300, bbox_inches='tight')
            print(f"Saved individual figure to: {individual_fig_path}")
            plt.close(fig)

            total = len(rankings)
            in_top_1 = sum(1 for r in rankings if r == 1)
            in_top_5 = sum(1 for r in rankings if r <= 5)
            in_top_10 = sum(1 for r in rankings if r <= 10)

            # total > 0 is guaranteed by the empty-rankings check above.
            exp_stats[llm_name] = {
                "total": total,
                "in_top_1": in_top_1,
                "in_top_5": in_top_5,
                "in_top_10": in_top_10,
                "percent_top_1": in_top_1 / total * 100,
                "percent_top_5": in_top_5 / total * 100,
                "percent_top_10": in_top_10 / total * 100,
            }

    # Summary table: each cell is padded to its header width so the columns line up.
    print("\n===== SUMMARY STATISTICS =====")
    for exp_label, llm_stats in summary_stats.items():
        print(f"\n{exp_label}:")
        print(f"{'Model':<25} {'Total':<8} {'Top 1':<12} {'Top 5':<12} {'Top 10':<12}")
        print("-" * 73)

        for llm_name, stats in llm_stats.items():
            display_name = get_display_name(llm_name)
            top_1 = f"{stats['in_top_1']} ({stats['percent_top_1']:.1f}%)"
            top_5 = f"{stats['in_top_5']} ({stats['percent_top_5']:.1f}%)"
            top_10 = f"{stats['in_top_10']} ({stats['percent_top_10']:.1f}%)"
            print(f"{display_name:<25} {stats['total']:<8} "
                  f"{top_1:<12} {top_5:<12} {top_10:<12}")

    # Best-effort CSV export: failures are reported but do not abort the analysis.
    try:
        import pandas as pd

        summary_data = [
            {
                'Experiment': exp_label,
                'Model': get_display_name(llm_name),
                'Total': stats['total'],
                'Top 1': stats['in_top_1'],
                'Top 1 %': stats['percent_top_1'],
                'Top 5': stats['in_top_5'],
                'Top 5 %': stats['percent_top_5'],
                'Top 10': stats['in_top_10'],
                'Top 10 %': stats['percent_top_10'],
            }
            for exp_label, llm_stats in summary_stats.items()
            for llm_name, stats in llm_stats.items()
        ]

        if summary_data:
            summary_df = pd.DataFrame(summary_data)
            summary_csv_path = os.path.join(output_dir, 'summary_statistics.csv')
            summary_df.to_csv(summary_csv_path, index=False)
            print(f"\nSaved summary statistics to: {summary_csv_path}")
    except ImportError:
        print("pandas not available. Summary statistics not saved to CSV.")
    except Exception as e:
        print(f"Error saving summary statistics: {e}")

    return summary_stats

# Run the V3.1 grid analysis for every model (the default selection expands to all six).
# NOTE: this writes to the same Figures directory and file names as the V3 run above,
# so its outputs overwrite those files.
analyze_and_plot_llm_results()

### V4: DeepSeek only

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os
from typing import Dict, List, Any

def load_reference_data(csv_path: str) -> Dict[str, str]:
    """Read the reference CSV and map each ``sample-id`` to its ``SMILES`` string.

    Args:
        csv_path: Path to a CSV with at least ``sample-id`` and ``SMILES`` columns.

    Returns:
        Dict keyed by sample id, with the reference SMILES as value (a plain dict,
        so lookups per JSON file are O(1)).
    """
    reference_frame = pd.read_csv(csv_path)
    smiles_by_id = reference_frame.set_index('sample-id')['SMILES']
    return smiles_by_id.to_dict()

def get_base_sample_id(sample_id: str) -> str:
    """Return the part of ``sample_id`` before the first underscore.

    Falsy input (``None`` or ``''``) yields ``''``; an id without an underscore is
    returned unchanged.
    """
    if not sample_id:
        return ''
    base, _separator, _suffix = sample_id.partition('_')
    return base

def analyze_llm_predictions(json_data: Dict, true_smiles: str, llm_name: str) -> Dict[str, Any]:
    """
    Locate the true molecule in one LLM's confidence-ranked candidate list.

    Candidates are read from
    ``json_data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]["candidates"]``
    and sorted by ``confidence_score`` (highest first; ``sorted`` is stable, so ties
    keep their original order). A candidate matches when its ``smiles`` string is
    exactly equal to ``true_smiles``. No canonicalisation is done, so equivalent
    SMILES that are written differently will not match.

    Args:
        json_data: Loaded JSON data
        true_smiles: True SMILES string to compare against
        llm_name: Name of the LLM model to analyze

    Returns:
        Dict with ``llm_model``, ``correct_position`` (1-based, or None when the true
        molecule is not among the candidates), ``total_candidates`` and the booleans
        ``is_top_1``, ``is_top_5``, ``is_top_10`` and ``is_after_top_10``.
        Returns None when this LLM has no usable results in the file: either a key is
        missing, or a JSON ``null`` appears where a mapping/list is expected.
    """
    try:
        llm_results = json_data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]
        candidates = llm_results["candidates"]

        sorted_candidates = sorted(candidates,
                                   key=lambda x: x["confidence_score"],
                                   reverse=True)

        # 1-based rank of the first exact SMILES match, if any.
        correct_position = None
        for i, cand in enumerate(sorted_candidates, 1):
            if cand["smiles"] == true_smiles:
                correct_position = i
                break

        return {
            "llm_model": llm_name,
            "correct_position": correct_position,
            "total_candidates": len(sorted_candidates),
            "is_top_1": correct_position == 1,
            "is_top_5": correct_position is not None and correct_position <= 5,
            "is_top_10": correct_position is not None and correct_position <= 10,
            "is_after_top_10": correct_position is not None and correct_position > 10
        }

    except (KeyError, TypeError):
        # KeyError: this LLM (or a field) is absent from the file.
        # TypeError: a section is JSON null (e.g. "parsed_results": null), or scores
        # are not comparable. Both mean "no usable results"; this matches the
        # handling in the V1 variant of this function.
        return None

def analyze_single_json(json_file: str, reference_data: Dict[str, str]) -> List[Dict[str, Any]]:
    """
    Score DeepSeek's candidate ranking for one result JSON file.

    The file's ``molecule_data.sample_id`` is reduced to its base id (text before
    the first underscore) and looked up in ``reference_data`` to get the true SMILES.

    Args:
        json_file: Path to JSON file
        reference_data: Dictionary mapping sample IDs to true SMILES

    Returns:
        A one-element list with DeepSeek's analysis dict (extended with
        ``sample_id``), or an empty list if the file has no sample id, no reference
        SMILES, no DeepSeek results, or cannot be read/parsed.
    """
    try:
        with open(json_file, 'r') as handle:
            payload = json.load(handle)

        full_id = payload.get("molecule_data", {}).get("sample_id")
        if not full_id:
            return []

        reference_smiles = reference_data.get(get_base_sample_id(full_id))
        if reference_smiles is None:
            # No reference structure for this sample, so it cannot be scored.
            return []

        outcome = analyze_llm_predictions(payload, reference_smiles, "deepseek")
        if not outcome:
            return []

        outcome["sample_id"] = full_id
        return [outcome]

    except Exception:
        # Best effort: unreadable or malformed files are skipped silently.
        return []

def analyze_directory(json_dir: str, reference_csv: str) -> Dict[str, List[Dict[str, Any]]]:
    """
    Run ``analyze_single_json`` over every ``*.json`` file in ``json_dir``.

    Returns:
        ``{"deepseek": [...]}``, holding one analysis dict per file that produced
        DeepSeek results.
    """
    reference_data = load_reference_data(reference_csv)
    json_files = glob.glob(os.path.join(json_dir, "*.json"))
    print(f"Found {len(json_files)} JSON files to analyze")

    deepseek_hits: List[Dict[str, Any]] = []
    for json_path in json_files:
        deepseek_hits.extend(analyze_single_json(json_path, reference_data))

    return {"deepseek": deepseek_hits}

def get_display_name(llm_name):
    """Return the plot/table label for an internal model key; unknown keys pass through unchanged."""
    friendly_names = {"deepseek": "DeepSeek R1"}
    return friendly_names[llm_name] if llm_name in friendly_names else llm_name

def plot_llm_ranking_histogram(rankings: List[int], llm_name: str, experiment_label: str, 
                             max_rank: int = 5, color: str = "#4169E1", figsize: tuple = (5.5, 5.5),
                             y_limit: float = 35):
    """
    Create a histogram of molecule rankings for an LLM model. The last bar pools
    all ranks beyond ``max_rank``.

    Args:
        rankings: 1-based ranks at which the correct molecule was found.
        llm_name: Internal name of the LLM model (mapped through ``get_display_name``).
        experiment_label: Label for the experiment, shown in the title.
        max_rank: Largest rank drawn as its own bar. Ranks above it are pooled into
            a final bar labelled ``f"{max_rank + 1}+"`` (default: 5).
        color: Color for histogram bars.
        figsize: Size of the figure.
        y_limit: Minimum upper y-axis limit (default: 35). It is fixed so histograms
            from different experiments share a scale, and is raised automatically
            (20% headroom) only when a bar would otherwise be clipped.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=figsize, facecolor='white')
    ax.set_facecolor('white')

    # One bar per rank 1..max_rank plus an overflow bar for every rank above it.
    # The overflow label follows max_rank instead of being hard-coded to '6+'.
    overflow_label = f'{max_rank + 1}+'
    positions = list(range(1, max_rank + 2))
    counts = [sum(1 for rank in rankings if rank == r) for r in range(1, max_rank + 1)]
    counts.append(sum(1 for rank in rankings if rank > max_rank))

    bars = ax.bar(
        positions,
        counts,
        width=0.8,
        edgecolor='black',
        alpha=0.7,
        color=color
    )

    display_name = get_display_name(llm_name)
    ax.set_title(f"{display_name} - {experiment_label}", fontsize=16, pad=10)
    ax.set_xlabel('Rank', fontsize=14)
    ax.set_ylabel('Number of Molecules', fontsize=14)

    ax.set_xticks(positions)
    ax.set_xticklabels([str(i) for i in range(1, max_rank + 1)] + [overflow_label], fontsize=12)
    ax.tick_params(axis='y', labelsize=12)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

    # Keep the shared fixed scale, but never clip the tallest bar (20% headroom).
    y_max = max(y_limit, max(counts) * 1.2)
    ax.set_ylim(0, y_max)

    # Count labels just above each non-empty bar; the offset scales with the axis.
    for bar in bars:
        height = bar.get_height()
        if height > 0:
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                height + (y_max * 0.01),
                f'{int(height)}',
                ha='center',
                va='bottom',
                fontsize=12
            )

    total_molecules = len(rankings)

    def stat_line(label, count):
        # An empty ranking list previously raised ZeroDivisionError here.
        pct = count / total_molecules * 100 if total_molecules else 0.0
        return f'{label}: {count} ({pct:.1f}%)'

    stats_text = '\n'.join([
        f'Total molecules found: {total_molecules}',
        stat_line('Found in top 1', sum(1 for r in rankings if r == 1)),
        stat_line('Found in top 3', sum(1 for r in rankings if r <= 3)),
        stat_line(f'Found in top {max_rank}', sum(1 for r in rankings if r <= max_rank)),
        stat_line(f'Found after top {max_rank}', sum(1 for r in rankings if r > max_rank)),
    ])

    ax.text(0.95, 0.95, stats_text,
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='right',
            fontsize=12,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()

    return fig

def _deepseek_rank_stats(rankings):
    """Summarise 1-based rankings into top-1/5/10 counts and percentages of ``len(rankings)``."""
    total = len(rankings)

    def percent(count):
        return count / total * 100 if total > 0 else 0

    in_top_1 = sum(1 for r in rankings if r == 1)
    in_top_5 = sum(1 for r in rankings if r <= 5)
    in_top_10 = sum(1 for r in rankings if r <= 10)
    return {
        "total": total,
        "in_top_1": in_top_1,
        "in_top_5": in_top_5,
        "in_top_10": in_top_10,
        "percent_top_1": percent(in_top_1),
        "percent_top_5": percent(in_top_5),
        "percent_top_10": percent(in_top_10)
    }


def _plot_deepseek_comparison(summary_stats, experiment_colors, default_color="#4169E1"):
    """Draw a bar chart of DeepSeek top-1 accuracy (%) per summarised experiment and return the Figure."""
    labels = list(summary_stats.keys())
    top_1_values = [summary_stats[label]['percent_top_1'] for label in labels]
    colors = [experiment_colors.get(label, default_color) for label in labels]
    bar_positions = np.arange(len(labels))

    fig, ax = plt.subplots(figsize=(12, 8))
    bars = ax.bar(
        bar_positions,
        top_1_values,
        0.8,
        color=colors,
        edgecolor='black',
        alpha=0.8,
        label='Top-1 Accuracy'
    )

    # Percentage label just above each bar.
    for bar, value in zip(bars, top_1_values):
        ax.text(
            bar.get_x() + bar.get_width() / 2.,
            bar.get_height() + 1,
            f"{value:.1f}%",
            ha='center',
            va='bottom',
            fontsize=9
        )

    ax.set_xlabel('Experiment')
    ax.set_ylabel('Top-1 Accuracy (%)')
    ax.set_title('DeepSeek R1 Performance Across Experiments (Top-1 Accuracy)')
    ax.set_xticks(bar_positions)
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.grid(axis='y', linestyle='--', alpha=0.7)
    fig.tight_layout()
    return fig


def _print_deepseek_summary(summary_stats):
    """Print a fixed-width table of the per-experiment DeepSeek statistics."""
    print("\n===== DEEPSEEK PERFORMANCE SUMMARY =====")
    print(f"{'Experiment':<40} {'Total':<8} {'Top 1':<12} {'Top 5':<12} {'Top 10':<12}")
    print("-" * 80)

    for exp_label, stats in summary_stats.items():
        print(f"{exp_label:<40} {stats['total']:<8} "
              f"{stats['in_top_1']} ({stats['percent_top_1']:.1f}%) "
              f"{stats['in_top_5']} ({stats['percent_top_5']:.1f}%) "
              f"{stats['in_top_10']} ({stats['percent_top_10']:.1f}%)")


def _save_deepseek_summary_csv(summary_stats, output_dir):
    """Write per-experiment statistics to ``deepseek_summary_statistics.csv`` (best effort: errors are printed)."""
    try:
        summary_data = [
            {
                'Experiment': exp_label,
                'Model': 'DeepSeek R1',
                'Total': stats['total'],
                'Top 1': stats['in_top_1'],
                'Top 1 %': stats['percent_top_1'],
                'Top 5': stats['in_top_5'],
                'Top 5 %': stats['percent_top_5'],
                'Top 10': stats['in_top_10'],
                'Top 10 %': stats['percent_top_10']
            }
            for exp_label, stats in summary_stats.items()
        ]

        if summary_data:
            summary_df = pd.DataFrame(summary_data)
            summary_csv_path = os.path.join(output_dir, 'deepseek_summary_statistics.csv')
            summary_df.to_csv(summary_csv_path, index=False)
            print(f"\nSaved summary statistics to: {summary_csv_path}")
    except Exception as e:
        print(f"Error saving summary statistics: {e}")


def analyze_deepseek_results(output_dir="./deepseek_results"):
    """
    Analyze DeepSeek results across the configured experiments, using one color per experiment.

    For each experiment, DeepSeek's candidates are ranked against the reference
    SMILES, a ranking histogram is saved, and top-1/5/10 statistics are collected.
    The statistics are relative to the number of samples in which DeepSeek listed
    the true molecule. Afterwards a top-1 comparison chart, a printed summary table
    and a CSV are written.

    Args:
        output_dir: Directory for figures and the summary CSV. It is created if it
            is missing (default: ``"./deepseek_results"``).

    Returns:
        Dict mapping experiment label -> statistics dict (``total``, ``in_top_k`` and
        ``percent_top_k`` for k in 1, 5, 10). Experiments without usable rankings
        are omitted.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Shared path prefixes (absolute cluster paths; relocate the data by editing these).
    runs_root = ("/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
                 "LLM_Structure_Elucidator/_temp_folder/intermediate_results")
    data_root = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ"
    sim_csv = f"{data_root}/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    lukas_csv = f"{data_root}/53_Lukas_real_data/cleaned_data_CLEAN.csv"

    def run(run_dir, reference_csv):
        return {"json_directory": f"{runs_root}/{run_dir}", "reference_csv": reference_csv}

    experiments = {
        "Additional Data HSQC aug": run("_run_15_Lukas_aug_finished", lukas_csv),
        "Additional Data HSQC": run("_run_14_Lukas_target_finished", lukas_csv),
        "Simulated Data ALL with Wrong Guess": run("_run_2_sim_aug_finished_clean", sim_csv),
        "Simulated Data ALL with Noise": run("_run_3_sim+noise_finished_clean", sim_csv),
        "Simulated Data ALL": run("_run_1_sim_finished_clean", sim_csv),
        "Experimental Data HSQC with Wrong Guess": run("_run_12_exp_d1_aug_MMST_HSQC_finished", sim_csv),
        "Experimental Data HSQC": run("_run_13_exp_d1_MMST_HSQC_finished", sim_csv),
        "Experimental Data ALL with Wrong Guess": run("_run_10_exp_d1_aug_MMST_all_new", sim_csv),
        "Experimental Data ALL": run("_run_4_exp_d1_finished_clean", sim_csv),
    }

    default_color = "#4169E1"
    experiment_colors = {
        "Additional Data HSQC aug": "#14B8A6",                # Teal-500
        "Additional Data HSQC": "#0EA5E9",                    # Sky-500
        "Simulated Data ALL with Wrong Guess": "#3B82F6",     # Blue
        "Simulated Data ALL with Noise": "#10B981",           # Green
        "Simulated Data ALL": "#6366F1",                      # Indigo
        "Experimental Data HSQC with Wrong Guess": "#F97316", # Orange
        "Experimental Data HSQC": "#84CC16",                  # Lime
        "Experimental Data ALL with Wrong Guess": "#EC4899",  # Pink
        "Experimental Data ALL": "#8B5CF6"                    # Purple
    }

    summary_stats = {}

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")

        results_by_llm = analyze_directory(paths["json_directory"], paths["reference_csv"])
        deepseek_results = results_by_llm.get("deepseek", [])
        if not deepseek_results:
            print(f"No DeepSeek results found for {exp_label}")
            continue

        # Samples where DeepSeek did not list the true molecule have no rank and are excluded.
        rankings = [r["correct_position"] for r in deepseek_results if r["correct_position"] is not None]
        if not rankings:
            print(f"No valid rankings found for DeepSeek in {exp_label}")
            continue

        print(f"Found {len(rankings)} valid rankings for DeepSeek in {exp_label}")

        fig = plot_llm_ranking_histogram(
            rankings,
            'deepseek',
            exp_label,
            color=experiment_colors.get(exp_label, default_color),
            figsize=(10, 6)
        )
        fig_path = os.path.join(output_dir, f'deepseek_{exp_label.replace(" ", "_")}.png')
        fig.savefig(fig_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {fig_path}")

        summary_stats[exp_label] = _deepseek_rank_stats(rankings)

    if not summary_stats:
        # Nothing to compare: avoid writing an empty chart and an empty CSV.
        print("\nNo DeepSeek rankings in any experiment; skipping comparison figure and summary.")
        return summary_stats

    comparison_fig = _plot_deepseek_comparison(summary_stats, experiment_colors, default_color)
    comparison_fig_path = os.path.join(output_dir, 'deepseek_comparison.png')
    comparison_fig.savefig(comparison_fig_path, dpi=300, bbox_inches='tight')
    print(f"Saved comparison figure to: {comparison_fig_path}")

    _print_deepseek_summary(summary_stats)
    _save_deepseek_summary_csv(summary_stats, output_dir)

    return summary_stats

# Run the analysis. An IPython kernel normally executes cells with
# __name__ == "__main__", so this runs whenever the cell runs; the guard only
# matters if this code is exported to a .py module and imported elsewhere.
if __name__ == "__main__":
    analyze_deepseek_results()

## Comparison: HSQC ranking vs. HSQC + LLM (DeepSeek) ranking

### V1

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os
from typing import Dict, List, Tuple, Any

def load_reference_data(csv_path):
    """Build a ``{sample-id: SMILES}`` dict from the reference CSV at ``csv_path``.

    If a sample id appears more than once, its last row wins.
    """
    frame = pd.read_csv(csv_path)
    return dict(zip(frame['sample-id'], frame['SMILES']))

def get_base_sample_id(sample_id):
    """Strip everything from the first underscore onward; falsy ids map to ''."""
    return '' if not sample_id else sample_id.split('_', 1)[0]

def process_single_json_hsqc(json_data):
    """
    Collect every candidate molecule from a result JSON and rank them by HSQC score.

    Candidates are gathered from the ``forward_synthesis``, ``mol2mol`` and ``mmst``
    sections of ``json_data["molecule_data"]["candidate_analysis"]``, in that order.
    Lower HSQC scores rank first and candidates without a score go last. The sort
    is stable, so ties keep their collection order.

    Args:
        json_data: Parsed result JSON (dict).

    Returns:
        List of ``{'smiles': ..., 'hsqc_score': float or None}`` dicts, best first.
        Returns an empty list when the JSON has no usable candidate analysis.
        Candidates without a ``smiles`` entry are skipped. JSON ``null`` sections are
        treated as empty; previously they raised AttributeError, which discarded
        the whole file.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except (KeyError, TypeError):
        return []
    if not isinstance(candidate_analysis, dict):
        return []

    all_molecules = []
    for analysis_type in ('forward_synthesis', 'mol2mol', 'mmst'):
        section = candidate_analysis.get(analysis_type)
        if not isinstance(section, dict):
            continue
        for mol in section.get('molecules') or []:
            if not isinstance(mol, dict) or 'smiles' not in mol:
                continue
            # JSON nulls load as None, so `or {}` guards each nested level.
            nmr_analysis = mol.get('nmr_analysis') or {}
            matching_scores = nmr_analysis.get('matching_scores') or {}
            by_spectrum = matching_scores.get('by_spectrum') or {}
            all_molecules.append({
                'smiles': mol['smiles'],
                'hsqc_score': by_spectrum.get('HSQC')
            })

    # Ascending HSQC score; unscored candidates sort last.
    all_molecules.sort(key=lambda m: m['hsqc_score'] if m['hsqc_score'] is not None else float('inf'))
    return all_molecules

def analyze_llm_predictions(json_data, true_smiles, llm_name="deepseek"):
    """
    Find where the true molecule sits in an LLM's confidence-ranked candidate list.

    Candidates come from
    ``json_data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]["candidates"]``
    and are sorted by ``confidence_score`` (highest first). Matching is an exact
    string comparison of ``smiles`` against ``true_smiles`` (no canonicalisation).

    Args:
        json_data: Loaded JSON data
        true_smiles: True SMILES string to compare against
        llm_name: Name of the LLM model to analyze (default: deepseek)

    Returns:
        Dict with ``correct_position`` (1-based, or None when the true molecule is not
        among the candidates), ``total_candidates``, ``is_top_1`` and ``is_top_5``.
        Returns None if this LLM has no usable results in the file.
    """
    try:
        llm_results = json_data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]["parsed_results"]
        candidates = llm_results["candidates"]

        sorted_candidates = sorted(candidates,
                                   key=lambda x: x["confidence_score"],
                                   reverse=True)

        # 1-based rank of the first exact SMILES match, if any.
        correct_position = None
        for i, cand in enumerate(sorted_candidates, 1):
            if cand["smiles"] == true_smiles:
                correct_position = i
                break

        return {
            "correct_position": correct_position,
            "total_candidates": len(sorted_candidates),
            "is_top_1": correct_position == 1,
            # Top-5 means rank 1..5; the previous `<= 6` wrongly counted rank 6.
            "is_top_5": correct_position is not None and correct_position <= 5
        }

    except (KeyError, TypeError):
        # Missing keys or JSON nulls: this LLM has no usable results in this file.
        return None

def find_molecule_rank(molecules, true_smiles):
    """Return the 1-based position of the first molecule whose ``smiles`` equals ``true_smiles``, else None."""
    return next(
        (rank for rank, candidate in enumerate(molecules, start=1) if candidate['smiles'] == true_smiles),
        None,
    )

# NOTE(review): dead code. The same cell redefines
# `analyze_directory_both_methods` further down (that version also reports HSQC
# top-5 statistics), and the later definition shadows this one before it can be
# called. Consider deleting this copy.
def analyze_directory_both_methods(json_dir, reference_csv, total_possible=34):
    """
    Analyze all JSON files using both HSQC and DeepSeek.
    Returns accuracy statistics for both methods.

    A sample is counted only if its base sample id has a reference SMILES, the HSQC
    ranking contains the true molecule, and DeepSeek ranks it somewhere. Percentages
    are computed against ``total_possible``, not against the number of counted samples.
    
    Args:
        json_dir: Directory containing JSON files
        reference_csv: Path to reference CSV file
        total_possible: Total number of possible molecules to identify (default: 34)

    Returns:
        Dict with ``total_samples``, HSQC/DeepSeek top-1 counts and percentages, and
        ``sample_ids``. ``total_possible`` is included only when at least one sample
        was counted (the empty-case dict omits it).
    """
    # Load reference data
    reference_data = load_reference_data(reference_csv)
    
    # Get all JSON files
    json_files = glob.glob(os.path.join(json_dir, "*.json"))
    print(f"Found {len(json_files)} JSON files to analyze")
    
    hsqc_rankings = []
    deepseek_rankings = []
    
    sample_ids = []  # Keep track of which samples were analyzed
    
    for file_path in json_files:
        try:
            with open(file_path, 'r') as f:
                json_data = json.load(f)
            
            molecule_data = json_data.get('molecule_data', {})
            sample_id = molecule_data.get('sample_id')
            
            if not sample_id:
                continue
                
            # Get base sample ID for reference matching
            base_sample_id = get_base_sample_id(sample_id)
            
            # Get correct SMILES
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue
            
            # Process HSQC ranking (lower HSQC score ranks first)
            molecules = process_single_json_hsqc(json_data)
            if not molecules:
                continue
                
            hsqc_rank = find_molecule_rank(molecules, true_smiles)
            
            # Process DeepSeek ranking (confidence-sorted candidates)
            deepseek_result = analyze_llm_predictions(json_data, true_smiles)
            
            # Only include in analysis if both methods have results
            if hsqc_rank is not None and deepseek_result is not None and deepseek_result.get("correct_position") is not None:
                hsqc_rankings.append(hsqc_rank)
                deepseek_rankings.append(deepseek_result["correct_position"])
                sample_ids.append(sample_id)
                
        except Exception as e:
            # Per-file failures are reported and skipped so one bad file does not abort the run.
            print(f"Error processing {file_path}: {str(e)}")
            continue
    
    # Calculate statistics
    total_samples = len(hsqc_rankings)
    
    if total_samples == 0:
        return {
            "total_samples": 0,
            "hsqc_top_1": 0,
            "hsqc_top_1_percent": 0,
            "deepseek_top_1": 0,
            "deepseek_top_1_percent": 0,
            "sample_ids": []
        }
    
    # Percentages are relative to total_possible (all molecules), not total_samples.
    hsqc_top_1 = sum(1 for r in hsqc_rankings if r == 1)
    hsqc_top_1_percent = (hsqc_top_1 / total_possible) * 100
    
    deepseek_top_1 = sum(1 for r in deepseek_rankings if r == 1)
    deepseek_top_1_percent = (deepseek_top_1 / total_possible) * 100
    
    return {
        "total_samples": total_samples,
        "total_possible": total_possible,
        "hsqc_top_1": hsqc_top_1,
        "hsqc_top_1_percent": hsqc_top_1_percent,
        "deepseek_top_1": deepseek_top_1,
        "deepseek_top_1_percent": deepseek_top_1_percent,
        "sample_ids": sample_ids
    }

def analyze_directory_both_methods(json_dir, reference_csv, total_possible=34):
    """
    Compare HSQC-score ranking with DeepSeek ranking over a directory of result JSONs.

    A sample contributes only when all three conditions hold:
    (a) its base sample id has a reference SMILES;
    (b) the HSQC ranking contains the true molecule;
    (c) DeepSeek ranks the true molecule somewhere in its candidate list.

    Percentages are relative to ``total_possible``, not to the number of
    contributing samples. Samples dropped by (b) or (c) therefore count as misses
    for both methods.

    Args:
        json_dir: Directory containing result ``*.json`` files.
        reference_csv: Path to the reference CSV (``sample-id`` / ``SMILES`` columns).
        total_possible: Denominator for percentages, i.e. the total number of
            molecules that could be identified (default: 34). A value of 0 yields 0%.

    Returns:
        Dict with ``total_samples``, ``total_possible``, the ``hsqc_top_1``,
        ``hsqc_top_5`` and ``deepseek_top_1`` counts, their ``*_percent``
        counterparts, and the contributing ``sample_ids``. The same keys are present
        even when no sample contributed.
    """
    reference_data = load_reference_data(reference_csv)

    json_files = glob.glob(os.path.join(json_dir, "*.json"))
    print(f"Found {len(json_files)} JSON files to analyze")

    hsqc_rankings = []
    deepseek_rankings = []
    sample_ids = []  # Samples that contributed to the statistics

    for file_path in json_files:
        try:
            with open(file_path, 'r') as f:
                json_data = json.load(f)

            sample_id = json_data.get('molecule_data', {}).get('sample_id')
            if not sample_id:
                continue

            # Reference SMILES are keyed by the base sample ID (text before '_').
            base_sample_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            # HSQC ranking: candidates sorted by ascending HSQC score.
            molecules = process_single_json_hsqc(json_data)
            if not molecules:
                continue
            hsqc_rank = find_molecule_rank(molecules, true_smiles)

            # DeepSeek ranking: candidates sorted by descending confidence.
            deepseek_result = analyze_llm_predictions(json_data, true_smiles)
            deepseek_rank = deepseek_result.get("correct_position") if deepseek_result else None

            # NOTE(review): requiring both ranks means a sample where HSQC ranks the
            # true molecule first but DeepSeek never lists it is excluded for HSQC too.
            # Confirm this paired-sample criterion is intended.
            if hsqc_rank is not None and deepseek_rank is not None:
                hsqc_rankings.append(hsqc_rank)
                deepseek_rankings.append(deepseek_rank)
                sample_ids.append(sample_id)

        except Exception as e:
            # Report and skip unreadable/malformed files instead of aborting the run.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    def percent(count):
        return (count / total_possible) * 100 if total_possible else 0.0

    hsqc_top_1 = sum(1 for r in hsqc_rankings if r == 1)
    hsqc_top_5 = sum(1 for r in hsqc_rankings if r <= 5)
    deepseek_top_1 = sum(1 for r in deepseek_rankings if r == 1)

    return {
        "total_samples": len(hsqc_rankings),
        "total_possible": total_possible,
        "hsqc_top_1": hsqc_top_1,
        "hsqc_top_1_percent": percent(hsqc_top_1),
        "deepseek_top_1": deepseek_top_1,
        "deepseek_top_1_percent": percent(deepseek_top_1),
        "hsqc_top_5": hsqc_top_5,
        "hsqc_top_5_percent": percent(hsqc_top_5),
        "sample_ids": sample_ids
    }

def _short_experiment_name(exp):
    """Map a long experiment label to a compact x-axis tick label (e.g. 'Sim+WG', 'Exp')."""
    if "Simulated Data" in exp:
        if "Wrong Guess" in exp:
            return "Sim+WG"
        if "Noise" in exp:
            return "Sim+Noise"
        return "Sim"
    if "Experimental Data" in exp:
        if "Wrong Guess" in exp:
            return "Exp+WG"
        if "d4" in exp:
            return "Exp d4"
        return "Exp"
    return exp


def plot_comparison_chart(experiment_results, total_possible=34):
    """
    Create a grouped bar chart comparing HSQC Top-5, HSQC Top-1 and DeepSeek Top-1 results.

    Bars show counts of correct predictions. Tall bars carry the count inside and the
    percentage above. Short bars (height <= 5) carry the count just above the bar and
    the percentage stacked above that; previously the two labels overlapped.

    Args:
        experiment_results: Dict mapping experiment name -> results dict with keys
            ``hsqc_top_5``, ``hsqc_top_1``, ``deepseek_top_1`` and the matching
            ``*_percent`` keys.
        total_possible: Total number of molecules that could be identified
            (default: 34). The y-axis extends 2 units beyond it (or beyond the
            tallest bar, if larger) to leave room for labels.

    Returns:
        The matplotlib Figure.
    """
    experiments = list(experiment_results.keys())
    results = list(experiment_results.values())

    hsqc_top_5_counts = [r["hsqc_top_5"] for r in results]
    hsqc_top_1_counts = [r["hsqc_top_1"] for r in results]
    deepseek_top_1_counts = [r["deepseek_top_1"] for r in results]

    hsqc_top_5_percentages = [r["hsqc_top_5_percent"] for r in results]
    hsqc_top_1_percentages = [r["hsqc_top_1_percent"] for r in results]
    deepseek_top_1_percentages = [r["deepseek_top_1_percent"] for r in results]

    x = np.arange(len(experiments))
    width = 0.25  # Three bars per experiment group

    fig, ax = plt.subplots(figsize=(16, 10), facecolor='white')
    ax.set_facecolor('white')

    rects1 = ax.bar(x - width, hsqc_top_5_counts, width, label='HSQC Top-5',
                    color='#2ca02c', alpha=0.7, edgecolor='black')
    rects2 = ax.bar(x, hsqc_top_1_counts, width, label='HSQC Top-1',
                    color='#1f77b4', alpha=0.7, edgecolor='black')
    rects3 = ax.bar(x + width, deepseek_top_1_counts, width, label='DeepSeek Top-1',
                    color='#ff7f0e', alpha=0.7, edgecolor='black')

    ax.set_xlabel('Experiment', fontsize=20)
    ax.set_ylabel('Number of Correct Predictions', fontsize=20)

    # Headroom above the largest achievable count (34 possible -> 36, as before).
    tallest = max(hsqc_top_5_counts + hsqc_top_1_counts + deepseek_top_1_counts, default=0)
    ax.set_ylim(0, max(total_possible, tallest) + 2)

    ax.set_xticks(x)
    ax.set_xticklabels([_short_experiment_name(exp) for exp in experiments], ha='center', fontsize=18)
    ax.tick_params(axis='y', labelsize=18)

    ax.legend(fontsize=18, loc='upper right')
    ax.grid(True, linestyle='--', alpha=0.3, axis='y')

    def add_labels(rects, counts, percentages):
        for rect, count, pct in zip(rects, counts, percentages):
            height = rect.get_height()
            center_x = rect.get_x() + rect.get_width() / 2

            if height > 5:
                # Tall bar: count centred inside, percentage just above the bar.
                ax.text(center_x, height / 2, str(count),
                        ha='center', va='center',
                        fontsize=16, color='white')
                pct_y = height + 0.5
            else:
                # Short bar: count just above, percentage stacked above the count
                # (about one 16-pt text line at this figure size).
                ax.text(center_x, height + 0.5, str(count),
                        ha='center', va='bottom',
                        fontsize=16)
                pct_y = height + 2.5

            ax.text(center_x, pct_y, f'{pct:.1f}%',
                    ha='center', va='bottom',
                    fontsize=16)

    add_labels(rects1, hsqc_top_5_counts, hsqc_top_5_percentages)
    add_labels(rects2, hsqc_top_1_counts, hsqc_top_1_percentages)
    add_labels(rects3, deepseek_top_1_counts, deepseek_top_1_percentages)

    fig.tight_layout()

    return fig

def main():
    """
    Compare HSQC-error ranking against DeepSeek's ranking for each configured run.

    For every experiment in ``experiments`` this calls
    ``analyze_directory_both_methods`` (defined in an earlier cell), prints the
    Top-1 hit counts/percentages of both methods and their signed difference,
    then draws a combined chart with ``plot_comparison_chart`` and saves it to
    ``top1_accuracy_comparison.png`` in the working directory.

    Experiments whose analysis reports ``total_samples == 0`` are skipped.
    """
    # Shared locations: every run directory lives under results_root and all runs
    # are scored against the same reference CSV.
    results_root = ("/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
                    "LLM_Structure_Elucidator/_temp_folder/intermediate_results")
    reference_csv = ("/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/"
                     "52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv")

    # Experiment label -> input locations. plot_comparison_chart derives its short
    # axis labels from substrings of these labels ("Wrong Guess", "Noise", "d4").
    experiments = {
        "Simulated Data": {
            "json_directory": f"{results_root}/_run_1_sim_finished_clean",
            "reference_csv": reference_csv,
        },
        "Simulated Data with Wrong Guess": {
            "json_directory": f"{results_root}/_run_2_sim_aug_finished_clean",
            "reference_csv": reference_csv,
        },
        "Simulated Data with Noise": {
            "json_directory": f"{results_root}/_run_3_sim+noise_finished_clean",
            "reference_csv": reference_csv,
        },
        "Experimental Data": {
            "json_directory": f"{results_root}/_run_4_exp_d1_finished_clean",
            "reference_csv": reference_csv,
        },
        "Experimental Data with Wrong Guess": {
            "json_directory": f"{results_root}/_run_5_exp_d1_aug_finished_clean",
            "reference_csv": reference_csv,
        },
        # Disabled run; uncomment to include the d4 experimental data.
        # "Experimental Data d4": {
        #     "json_directory": f"{results_root}/_run_6_exp_d4_finished_clean",
        #     "reference_csv": reference_csv,
        # },
    }

    # Total possible number of molecules: the denominator in the X/TOTAL_POSSIBLE
    # figures printed below.
    TOTAL_POSSIBLE = 34

    experiment_results = {}

    for exp_label, paths in experiments.items():
        print(f"\nAnalyzing {exp_label}...")

        # Rankings for both methods (HSQC error and DeepSeek confidence)
        results = analyze_directory_both_methods(paths["json_directory"], paths["reference_csv"], TOTAL_POSSIBLE)

        if results["total_samples"] == 0:
            print(f"No valid results found for {exp_label}. Please check your input files.")
            continue

        experiment_results[exp_label] = results

        hit_delta = results['deepseek_top_1'] - results['hsqc_top_1']
        pct_delta = results['deepseek_top_1_percent'] - results['hsqc_top_1_percent']

        print(f"Total samples analyzed: {results['total_samples']}")
        print(f"HSQC Top-1: {results['hsqc_top_1']}/{TOTAL_POSSIBLE} ({results['hsqc_top_1_percent']:.1f}%)")
        print(f"DeepSeek Top-1: {results['deepseek_top_1']}/{TOTAL_POSSIBLE} ({results['deepseek_top_1_percent']:.1f}%)")
        # Signed format specs: a regression prints "-2", not "+-2".
        print(f"Improvement: {hit_delta:+} hits ({pct_delta:+.1f}%)")

    if not experiment_results:
        print("No results to display. Please check your input files.")
        return

    fig = plot_comparison_chart(experiment_results, TOTAL_POSSIBLE)
    # Save before showing: with GUI backends plt.show() blocks until the window
    # is closed, and some backends close the figure afterwards.
    fig.savefig('top1_accuracy_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

if __name__ == "__main__":
    main()

In [None]:
def plot_comparison_chart(experiment_results, total_possible=34,
                          save_path="/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures/hsqc_deepseek_comparison.png"):
    """
    Create a grouped bar chart comparing HSQC Top-5, HSQC Top-1 and DeepSeek Top-1 hits.

    Each experiment gets three bars showing the raw count of correct predictions.
    The percentage is printed above every bar. The count is printed inside tall
    bars (height > 5); for short bars it is stacked above the percentage so the
    two labels do not overlap.

    Args:
        experiment_results: Dict mapping experiment name -> results dict with keys
            "hsqc_top_5", "hsqc_top_1", "deepseek_top_1" and the matching
            "*_percent" keys.
        total_possible: Total number of possible molecules to identify (default: 34).
            The y-axis runs to total_possible + 2, or higher if a bar is taller.
        save_path: PNG path the figure is written to at 300 dpi. Its parent
            directory is created if missing. Pass None or "" to skip saving.

    Returns:
        The matplotlib Figure.
    """
    import os

    experiments = list(experiment_results.keys())
    runs = list(experiment_results.values())

    # Raw hit counts (bar heights) and the matching percentages (labels on top)
    hsqc_top_5_counts = [r["hsqc_top_5"] for r in runs]
    hsqc_top_1_counts = [r["hsqc_top_1"] for r in runs]
    deepseek_top_1_counts = [r["deepseek_top_1"] for r in runs]

    hsqc_top_5_percentages = [r["hsqc_top_5_percent"] for r in runs]
    hsqc_top_1_percentages = [r["hsqc_top_1_percent"] for r in runs]
    deepseek_top_1_percentages = [r["deepseek_top_1_percent"] for r in runs]

    x = np.arange(len(experiments))
    width = 0.25  # three bars per experiment group

    fig, ax = plt.subplots(figsize=(16, 10), facecolor='white')
    ax.set_facecolor('white')

    rects1 = ax.bar(x - width, hsqc_top_5_counts, width, label='HSQC Top-5',
                    color='#2ca02c', alpha=0.7, edgecolor='black')
    rects2 = ax.bar(x, hsqc_top_1_counts, width, label='HSQC Top-1',
                    color='#1f77b4', alpha=0.7, edgecolor='black')
    rects3 = ax.bar(x + width, deepseek_top_1_counts, width, label='DeepSeek Top-1',
                    color='#ff7f0e', alpha=0.7, edgecolor='black')

    ax.set_ylabel('Number of Correct Predictions', fontsize=20)

    # Two units of headroom above the theoretical maximum (36 for the default 34).
    # The limit is raised if any count exceeds total_possible, so no bar is clipped.
    tallest = max([total_possible, *hsqc_top_5_counts, *hsqc_top_1_counts, *deepseek_top_1_counts])
    ax.set_ylim(0, tallest + 2)

    def short_name(exp):
        """Map a long experiment label to its Target/Analogue axis label."""
        if "Simulated Data" in exp:
            if "Wrong Guess" in exp:
                return "Sim Analogue"
            if "Noise" in exp:
                return "Sim Target+Noise"
            return "Sim Target"
        if "Experimental Data" in exp:
            if "Wrong Guess" in exp:
                return "Exp Analogue"
            if "d4" in exp:
                return "Exp d4"
            return "Exp Target"
        return exp

    ax.set_xticks(x)
    ax.set_xticklabels([short_name(exp) for exp in experiments], ha='center', fontsize=18)
    ax.tick_params(axis='y', labelsize=18)
    ax.legend(fontsize=18, loc='upper right')
    ax.grid(True, linestyle='--', alpha=0.3, axis='y')

    def add_labels(rects, counts, percentages):
        """Put the percentage above each bar; put the count inside tall bars or above short ones."""
        for rect, count, pct in zip(rects, counts, percentages):
            height = rect.get_height()
            center_x = rect.get_x() + rect.get_width() / 2
            # Offsets are in points, so label spacing does not depend on the y-scale.
            ax.annotate(f'{pct:.1f}%', xy=(center_x, height), xytext=(0, 6),
                        textcoords='offset points', ha='center', va='bottom', fontsize=16)
            if height > 5:
                ax.text(center_x, height / 2, str(count),
                        ha='center', va='center', fontsize=16, color='white')
            else:
                # Short bar: stack the count above the percentage so they don't overlap.
                ax.annotate(str(count), xy=(center_x, height), xytext=(0, 28),
                            textcoords='offset points', ha='center', va='bottom', fontsize=16)

    add_labels(rects1, hsqc_top_5_counts, hsqc_top_5_percentages)
    add_labels(rects2, hsqc_top_1_counts, hsqc_top_1_percentages)
    add_labels(rects3, deepseek_top_1_counts, deepseek_top_1_percentages)

    fig.tight_layout()

    if save_path:
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved at: {save_path}")

    return fig

In [None]:

# Re-run main(). It looks up plot_comparison_chart by name when it runs, so this
# picks up the redefinition from the previous cell.
if __name__ == "__main__":
    main()

#### Noise Example case study

In [None]:
import json
import glob
import os
from typing import Dict, List, Tuple, Any
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from io import BytesIO
import base64

def load_reference_data(csv_path):
    """
    Read the reference CSV and return a plain dict mapping 'sample-id' -> 'SMILES'.

    If a sample-id appears more than once, the last row wins.
    """
    reference_frame = pd.read_csv(csv_path)
    # A dict gives cheap per-sample lookups in the analysis loops.
    smiles_series = pd.Series(reference_frame['SMILES'].to_numpy(),
                              index=reference_frame['sample-id'])
    return smiles_series.to_dict()

def get_base_sample_id(sample_id):
    """Return the part of sample_id before its first underscore, or '' for an empty/None id."""
    if not sample_id:
        return ''
    base, _sep, _suffix = sample_id.partition('_')
    return base

def process_single_json_hsqc(json_data):
    """
    Collect the candidate molecules from one result JSON and rank them by HSQC score.

    Candidates are read from ``molecule_data.candidate_analysis`` under the
    'forward_synthesis' and 'mmst' sections, in that order. Each becomes
    ``{'smiles': ..., 'hsqc_score': ...}``. The score comes from
    ``nmr_analysis.matching_scores.by_spectrum.HSQC``; it is None if any level is
    missing or null.

    Molecules are sorted ascending by score (lower is better); molecules without
    a score go last. Entries without a 'smiles' key are skipped. The same SMILES
    appearing in both sections is kept as two entries.

    Returns:
        List of dicts, or [] if the JSON has no usable candidate_analysis section.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except (KeyError, TypeError):
        # TypeError covers a null / non-dict "molecule_data".
        return []
    if not isinstance(candidate_analysis, dict):
        return []

    all_molecules = []
    for analysis_type in ('forward_synthesis', 'mmst'):
        # `or {}` / `or []` tolerate sections that are present but null in the JSON.
        section = candidate_analysis.get(analysis_type) or {}
        for mol in section.get('molecules') or []:
            try:
                smiles = mol['smiles']
            except (KeyError, TypeError):
                continue
            nmr_analysis = mol.get('nmr_analysis') or {}
            matching_scores = nmr_analysis.get('matching_scores') or {}
            by_spectrum = matching_scores.get('by_spectrum') or {}
            all_molecules.append({
                'smiles': smiles,
                'hsqc_score': by_spectrum.get('HSQC', None),
            })

    # Ascending HSQC error; unscored molecules (None) sort after all scored ones.
    all_molecules.sort(key=lambda x: x['hsqc_score'] if x['hsqc_score'] is not None else float('inf'))
    return all_molecules

def analyze_llm_predictions(json_data, true_smiles, llm_name="deepseek"):
    """
    Locate the true structure in an LLM's confidence-ranked candidate list.

    Reads the candidates at
    ``analysis_results.final_analysis.llm_responses[llm_name].parsed_results.candidates``,
    orders them by 'confidence_score' (highest first), and finds the 1-based
    position of the first candidate whose 'smiles' string equals ``true_smiles``
    exactly.

    Returns:
        dict with 'correct_position' (int or None), 'total_candidates',
        'is_top_1' and 'is_top_5'. Returns None if the LLM's results are missing
        or malformed (any KeyError/TypeError while reading them).
    """
    try:
        response = json_data["analysis_results"]["final_analysis"]["llm_responses"][llm_name]
        ranked = sorted(response["parsed_results"]["candidates"],
                        key=lambda cand: cand["confidence_score"],
                        reverse=True)

        position = next(
            (rank for rank, cand in enumerate(ranked, start=1) if cand["smiles"] == true_smiles),
            None,
        )
        found = position is not None
        return {
            "correct_position": position,
            "total_candidates": len(ranked),
            "is_top_1": found and position == 1,
            "is_top_5": found and position <= 5,
        }
    except (KeyError, TypeError):
        # This LLM has no (usable) results in this file.
        return None

def find_molecule_rank(molecules, true_smiles):
    """Return the 1-based position of the first entry whose 'smiles' equals true_smiles, else None."""
    return next(
        (position for position, entry in enumerate(molecules, start=1)
         if entry['smiles'] == true_smiles),
        None,
    )

def find_llm_corrected_molecules(json_dir, reference_csv):
    """
    Find molecules where the LLM corrected the HSQC ranking in the "Sim+Noise" condition.

    A sample counts as corrected when both conditions hold:
      * its true SMILES appears in the HSQC-sorted candidate list at a rank other than 1;
      * DeepSeek ranks it 1.
    The true SMILES is looked up using the part of ``sample_id`` before the first
    underscore.

    Args:
        json_dir: Directory containing JSON files for Sim+Noise condition
        reference_csv: Path to reference CSV file

    Returns:
        List of dicts with 'sample_id', 'base_sample_id', 'hsqc_rank' and
        'deepseek_rank', in file-name order. Files that fail to load or parse are
        reported and skipped.
    """
    reference_data = load_reference_data(reference_csv)

    # glob returns files in arbitrary (filesystem) order. Sorting makes the result
    # order reproducible, and with it the downstream figure batches and CSV rows.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    corrected_molecules = []

    for file_path in json_files:
        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the platform default encoding.
            with open(file_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            molecule_data = json_data.get('molecule_data', {})
            sample_id = molecule_data.get('sample_id')
            if not sample_id:
                continue

            base_sample_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            molecules = process_single_json_hsqc(json_data)
            if not molecules:
                continue

            hsqc_rank = find_molecule_rank(molecules, true_smiles)
            deepseek_result = analyze_llm_predictions(json_data, true_smiles)
            if hsqc_rank is None or deepseek_result is None:
                continue

            deepseek_rank = deepseek_result.get("correct_position")
            # LLM correction: HSQC missed the top spot but DeepSeek ranked it first.
            if deepseek_rank == 1 and hsqc_rank != 1:
                corrected_molecules.append({
                    "sample_id": sample_id,
                    "base_sample_id": base_sample_id,
                    "hsqc_rank": hsqc_rank,
                    "deepseek_rank": deepseek_rank,
                })

        except Exception as e:
            # Best effort: report and skip unreadable/malformed files.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    return corrected_molecules

def generate_molecule_image(smiles):
    """
    Build an RDKit Mol from a SMILES string. Despite the name, no image is drawn.

    Returns the Mol. If RDKit raises while parsing, the error is printed and None
    is returned. Chem.MolFromSmiles can also return None for an unparsable SMILES
    without raising, and that None is passed through.
    """
    parsed = None
    try:
        parsed = Chem.MolFromSmiles(smiles)
    except Exception as err:
        print(f"Error generating molecule from SMILES: {str(err)}")
    return parsed

def visualize_corrected_molecules_matplotlib(corrected_molecules, reference_data, max_molecules_per_fig=6):
    """
    Visualize corrected molecules using Matplotlib, in batches, and save a CSV summary.

    Args:
        corrected_molecules: List of dicts with 'base_sample_id' and 'hsqc_rank'
            (and optionally 'deepseek_rank', default 1), e.g. as returned by
            find_llm_corrected_molecules.
        reference_data: Dictionary mapping sample IDs to SMILES
        max_molecules_per_fig: Maximum number of molecules per figure (values < 1 are treated as 1)

    Returns:
        DataFrame of corrected_molecules plus a 'smiles' column ('' when the ID is
        missing from reference_data). An empty DataFrame is returned when there is
        no input.

    Side effects:
        Prints details and a summary. Writes corrected_molecules_set_<k>.png per
        figure and llm_corrected_molecules_sim_noise.csv to the working directory.
    """
    total_molecules = len(corrected_molecules)
    if total_molecules == 0:
        print("No corrected molecules found.")
        return pd.DataFrame()

    print(f"\nFound {total_molecules} molecules corrected by LLM:")

    df = pd.DataFrame(corrected_molecules)
    df['smiles'] = df['base_sample_id'].apply(lambda x: reference_data.get(x, ''))

    # (mol, label, title) for every entry whose SMILES exists and parses
    panels = []
    for i, mol_info in enumerate(corrected_molecules):
        base_sample_id = mol_info['base_sample_id']
        hsqc_rank = mol_info['hsqc_rank']
        deepseek_rank = mol_info.get('deepseek_rank', 1)

        smiles = reference_data.get(base_sample_id)
        if not smiles:
            print(f"No SMILES found for {base_sample_id}")
            continue

        mol = generate_molecule_image(smiles)
        if mol is None:
            print(f"Could not generate molecule for {base_sample_id}")
            continue

        panels.append((mol, f"{base_sample_id}", f"HSQC Rank: {hsqc_rank} → DeepSeek: {deepseek_rank}"))

        print(f"{i+1}. Sample ID: {base_sample_id}")
        print(f"   SMILES: {smiles}")
        print(f"   HSQC Rank: {hsqc_rank} → DeepSeek Rank: {deepseek_rank}")
        print()

    # Batch over the molecules that were actually built, not the raw input count.
    # Otherwise skipped entries can yield empty trailing figures and a wrong "Set k/N".
    per_fig = max(1, int(max_molecules_per_fig))
    num_figures = (len(panels) + per_fig - 1) // per_fig

    for fig_num in range(num_figures):
        batch = panels[fig_num * per_fig:(fig_num + 1) * per_fig]

        # Up to three panels per row
        n_cols = min(3, len(batch))
        n_rows = (len(batch) + 2) // 3

        fig = plt.figure(figsize=(n_cols * 5, n_rows * 5))
        for j, (mol, label, title) in enumerate(batch, start=1):
            ax = fig.add_subplot(n_rows, n_cols, j)
            ax.imshow(Draw.MolToImage(mol, size=(400, 300)))
            ax.set_title(f"{label}\n{title}", fontsize=12)
            ax.axis('off')

        fig.tight_layout()
        fig.suptitle(f"Molecules Corrected by LLM (Simulated Data with Noise) - Set {fig_num+1}/{num_figures}",
                     fontsize=16, y=1.02)
        fig.savefig(f"corrected_molecules_set_{fig_num+1}.png", dpi=300, bbox_inches='tight')
        plt.show()

    print("\nSummary:")
    print(f"Total molecules corrected: {len(df)}")
    print("\nHSQC original rankings of corrected molecules:")
    print(df['hsqc_rank'].value_counts().sort_index())

    df.to_csv('llm_corrected_molecules_sim_noise.csv', index=False)
    print("Saved results to 'llm_corrected_molecules_sim_noise.csv'")

    return df

def main():
    """
    Noise case study: list the molecules DeepSeek corrected in the "Sim+Noise" run.

    Uses find_llm_corrected_molecules to find samples whose true structure
    DeepSeek ranked first while the HSQC-error ranking did not. Then plots them
    and writes the CSV summary via visualize_corrected_molecules_matplotlib.

    NOTE: this redefines the ``main`` from an earlier cell.
    """
    reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
    sim_noise_json_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished_clean"

    smiles_by_sample = load_reference_data(reference_csv)
    corrected = find_llm_corrected_molecules(sim_noise_json_dir, reference_csv)
    visualize_corrected_molecules_matplotlib(corrected, smiles_by_sample, max_molecules_per_fig=6)

if __name__ == "__main__":
    main()

In [None]:
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, AllChem
import numpy as np
import io
from PIL import Image, ImageDraw, ImageFont
import pandas as pd

def create_confidence_bar(width, height, confidence):
    """
    Draw a horizontal confidence gauge as a PIL RGB image.

    The filled part is ``confidence * width`` pixels wide. Its colour follows a
    continuous ramp: red (255, 0, 0) at 0, yellow (255, 255, 0) at 0.5 and
    green (0, 204, 0) at 1. A 2-px grey border frames the whole bar.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        confidence: Score, expected in [0, 1]; values outside are clamped.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    # Clamp so out-of-range scores cannot give negative colour channels or a
    # rectangle whose right edge lies left of its origin.
    confidence = min(max(float(confidence), 0.0), 1.0)

    bar_img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(bar_img)

    if confidence <= 0.5:
        # Red (255,0,0) -> yellow (255,255,0)
        r = 255
        g = int(confidence * 2 * 255)
    else:
        # Yellow (255,255,0) -> green (0,204,0). Both channels are interpolated,
        # so the colour is continuous at 0.5.
        t = (confidence - 0.5) * 2
        r = int((1 - t) * 255)
        g = int(255 - t * (255 - 204))
    b = 0

    bar_width = int(confidence * width)
    if bar_width > 0:
        draw.rectangle([(0, 0), (bar_width, height)], fill=(r, g, b))

    draw.rectangle([(0, 0), (width-1, height-1)], outline=(100, 100, 100), width=2)

    return bar_img

def generate_molecule_card(smiles, confidence, hsqc_error, hsqc_rank, is_correct=False, 
                          mol_size=(450, 350), card_width=520, card_height=750):
    """
    Render a single candidate molecule as a PIL "card" image.

    Layout, top to bottom:
      * a 55-px title band reading "HSQC Rank <n>", with " (CORRECT)" and a green
        tint when ``is_correct``;
      * the RDKit 2D depiction, pasted at y=60;
      * a confidence gauge from create_confidence_bar at ``mol_size[1] + 100``;
      * centred text lines for confidence, HSQC error and molecular weight at
        ``mol_size[1] + 150 / 200 / 250``.
    A green (correct) or light grey border frames the card.

    Args:
        smiles: SMILES string of the candidate.
        confidence: LLM confidence, printed with 2 decimals and drawn as a gauge.
        hsqc_error: HSQC error, printed with 3 decimals.
        hsqc_rank: Rank shown in the title.
        is_correct: Highlight the card as the correct structure.
        mol_size: (width, height) of the molecule depiction in pixels.
        card_width: Card width in pixels.
        card_height: Card height in pixels.

    Returns:
        PIL RGB image. Returns None if the SMILES cannot be parsed or any step
        raises (the error is printed).
    """
    # (title font path, body font path) pairs, tried in order before PIL's default
    font_choices = (
        ("arial.ttf", "arial.ttf"),
        ("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
         "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"),
    )

    def load_fonts():
        """Return (title 28pt, body 26pt, small 24pt) fonts from the first family that loads."""
        for title_path, body_path in font_choices:
            try:
                return (ImageFont.truetype(title_path, 28),
                        ImageFont.truetype(body_path, 26),
                        ImageFont.truetype(body_path, 24))
            except IOError:
                pass
        return ImageFont.load_default(), ImageFont.load_default(), ImageFont.load_default()

    def measure(pen, text, face):
        """Pixel width of text; uses font.getsize on PIL versions without textlength."""
        try:
            return pen.textlength(text, font=face)
        except AttributeError:
            return face.getsize(text)[0]

    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        mol_weight = Descriptors.MolWt(mol)

        mol_w, mol_h = mol_size[0], mol_size[1]
        canvas = Draw.rdMolDraw2D.MolDraw2DCairo(mol_w, mol_h)
        canvas.SetFontSize(1.4)  # larger atom labels
        canvas.DrawMolecule(mol)
        canvas.FinishDrawing()
        depiction = Image.open(io.BytesIO(canvas.GetDrawingText()))

        card = Image.new('RGB', (card_width, card_height), (252, 252, 252))
        card.paste(depiction, ((card_width - mol_w) // 2, 60))
        card.paste(create_confidence_bar(card_width - 80, 36, confidence), (40, mol_h + 100))

        draw = ImageDraw.Draw(card)
        title_font, font, _small_font = load_fonts()  # small font is loaded but not used

        band_fill, title_ink, border_ink, border_px = (
            ((220, 245, 220), (0, 100, 0), (0, 150, 0), 5) if is_correct
            else ((240, 240, 245), (50, 50, 100), (200, 200, 220), 3)
        )

        title = f"HSQC Rank {hsqc_rank}" + (" (CORRECT)" if is_correct else "")
        title_x = (card_width - measure(draw, title, title_font)) // 2
        draw.rectangle([(0, 0), (card_width, 55)], fill=band_fill)
        draw.text((title_x, 15), title, fill=title_ink, font=title_font)

        info_lines = (
            (f"Confidence: {confidence:.2f}", 150),
            (f"HSQC Error: {hsqc_error:.3f}", 200),
            (f"MW: {mol_weight:.2f}", 250),
        )
        for text, y_offset in info_lines:
            text_x = (card_width - measure(draw, text, font)) // 2
            draw.text((text_x, mol_h + y_offset), text, fill=(0, 0, 0), font=font)

        draw.rectangle([(0, 0), (card_width-1, card_height-1)], outline=border_ink, width=border_px)

        return card

    except Exception as e:
        print(f"Error generating molecule card: {e}")
        return None

def visualize_candidates_single_row(candidates_data, title=None, 
                                   figsize=(24, 7), correct_hsqc_rank=5, filename=None):
    """
    Draw one molecule card per candidate, side by side in a single row ordered by HSQC rank.

    Args:
        candidates_data: Iterable of dicts with 'smiles', 'confidence_score',
            'hsqc_error' and 'hsqc_rank'.
        title: Accepted for compatibility but currently unused (no title is drawn).
        figsize: Matplotlib figure size.
        correct_hsqc_rank: HSQC rank whose card is highlighted as correct.
        filename: Optional path; if given, the figure is saved at 300 dpi.

    Returns:
        The Figure, or None when no card could be generated.
    """
    ordered = sorted(candidates_data, key=lambda c: c['hsqc_rank'])
    rendered = (
        generate_molecule_card(
            c['smiles'],
            c['confidence_score'],
            c['hsqc_error'],
            c['hsqc_rank'],
            is_correct=(c['hsqc_rank'] == correct_hsqc_rank),
            mol_size=(450, 350),
            card_width=520,
            card_height=650,
        )
        for c in ordered
    )
    # generate_molecule_card returns None on failure; keep only real cards
    cards = [card for card in rendered if card]

    if not cards:
        print("No valid molecule cards generated")
        return None

    fig = plt.figure(figsize=figsize)
    # NOTE(review): tight_layout() below recomputes subplot spacing, which likely
    # overrides this wspace setting.
    plt.subplots_adjust(wspace=0.4)

    for slot, card in enumerate(cards, start=1):
        ax = fig.add_subplot(1, len(cards), slot)
        ax.imshow(card)
        ax.axis('off')

    plt.tight_layout()

    if filename:
        plt.savefig(filename, dpi=300, bbox_inches='tight')

    return fig

def visualize_with_grid_image(candidates_data, correct_hsqc_rank=5, filename=None):
    """
    Alternative visualization: RDKit's MolsToGridImage in one row, sorted by HSQC rank.

    Each panel's legend shows the HSQC rank, confidence, HSQC error and molecular
    weight. The rank equal to ``correct_hsqc_rank`` gets " (CORRECT)" in its
    legend, and all its atoms are highlighted in green. Candidates whose SMILES
    do not parse are skipped.

    Args:
        candidates_data: Iterable of dicts with 'smiles', 'confidence_score',
            'hsqc_error' and 'hsqc_rank'.
        correct_hsqc_rank: HSQC rank of the correct structure.
        filename: Optional path; if given, the figure is saved at 300 dpi.

    Returns:
        The matplotlib Figure, or None if no SMILES could be parsed.
    """
    sorted_candidates = sorted(candidates_data, key=lambda x: x['hsqc_rank'])

    mols = []
    legends = []
    # Index of the correct molecule *within mols*. Tracking it while building mols
    # keeps it aligned even when earlier candidates are skipped for bad SMILES.
    correct_idx = None

    for candidate in sorted_candidates:
        mol = Chem.MolFromSmiles(candidate['smiles'])
        if not mol:
            continue

        mol_weight = Descriptors.MolWt(mol)
        AllChem.Compute2DCoords(mol)

        is_correct = candidate['hsqc_rank'] == correct_hsqc_rank
        if is_correct and correct_idx is None:
            correct_idx = len(mols)
        mols.append(mol)

        legend = f"Rank {candidate['hsqc_rank']}"
        if is_correct:
            legend += " (CORRECT)"
        legend += f"\nConf: {candidate['confidence_score']:.2f}"
        legend += f"\nHSQC Err: {candidate['hsqc_error']:.3f}"
        legend += f"\nMW: {mol_weight:.2f}"
        legends.append(legend)

    if not mols:
        print("No valid molecules")
        return None

    highlight_atom_lists = [[] for _ in mols]
    # DrawMolecules (which MolsToGridImage forwards to) expects one
    # {atom_index: (r, g, b)} dict per molecule, not a list of colours.
    highlight_atom_colors = [{} for _ in mols]
    if correct_idx is not None:
        atoms = list(range(mols[correct_idx].GetNumAtoms()))
        highlight_atom_lists[correct_idx] = atoms
        highlight_atom_colors[correct_idx] = {idx: (0.0, 0.7, 0.0) for idx in atoms}  # green

    grid_img = Draw.MolsToGridImage(
        mols,
        molsPerRow=len(mols),
        subImgSize=(450, 400),
        legends=legends,
        highlightAtomLists=highlight_atom_lists,
        highlightAtomColors=highlight_atom_colors,
        useSVG=False,
        legendFontSize=20,
        maxMols=len(mols)
    )

    # PIL image -> numpy array for matplotlib
    grid_array = np.array(grid_img)

    fig = plt.figure(figsize=(24, 7))
    plt.imshow(grid_array)
    plt.axis('off')
    plt.tight_layout()

    if filename:
        plt.savefig(filename, dpi=300, bbox_inches='tight')

    return fig

def analyze_case_study():
    """
    Visualize the hard-coded case-study candidates in two styles and display them.

    The five candidates are listed in HSQC-error order. The correct structure is
    at HSQC rank 5 but has the highest LLM confidence (0.85). Produces:
      * a row of molecule cards (visualize_candidates_single_row), saved to
        improved_case_study_molecules.png;
      * an RDKit grid image (visualize_with_grid_image), saved to
        improved_case_study_grid.png.

    Returns:
        (fig, fig2): the two figures; either may be None if nothing could be drawn.
    """
    # NOTE(review): ranks 2 and 3 carry identical SMILES, presumably the same
    # structure proposed by two generators (process_single_json_hsqc does not
    # de-duplicate). Confirm.
    candidate_data = [
        {
            "hsqc_rank": 1,
            "smiles": "C=C1c2c([nH]c(C)c2CN2CCOCC2)CCC1CO",
            "confidence_score": 0.45,
            "hsqc_error": 3.625
        },
        {
            "hsqc_rank": 2,
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOC1)CCC2",
            "confidence_score": 0.35,
            "hsqc_error": 3.933
        },
        {
            "hsqc_rank": 3,
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOC1)CCC2",
            "confidence_score": 0.32,
            "hsqc_error": 4.070
        },
        {
            "hsqc_rank": 4,
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCCOC1)CC2",
            "confidence_score": 0.50,
            "hsqc_error": 4.188
        },
        {
            "hsqc_rank": 5,  # This is the correct molecule
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOCC1)CC2",
            "confidence_score": 0.85,
            "hsqc_error": 4.547
        }
    ]

    correct_hsqc_rank = 5

    fig = visualize_candidates_single_row(
        candidate_data,
        figsize=(24, 7),
        correct_hsqc_rank=correct_hsqc_rank,
        filename="improved_case_study_molecules.png"
    )
    plt.show()

    # Use the same correct rank for both views so they cannot disagree.
    fig2 = visualize_with_grid_image(
        candidate_data,
        correct_hsqc_rank=correct_hsqc_rank,
        filename="improved_case_study_grid.png"
    )
    # visualize_with_grid_image returns None when no molecule could be drawn.
    if fig2 is not None:
        plt.figure(fig2.number)
        plt.show()

    return fig, fig2

if __name__ == "__main__":
    # Run the improved visualizations
    analyze_case_study()

In [None]:
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, AllChem, rdFMCS
import numpy as np
import io
from PIL import Image, ImageDraw, ImageFont
import pandas as pd

def create_confidence_bar(width, height, confidence):
    """
    Create a horizontal confidence bar as a PIL RGB image.

    The filled part starts at the left edge and is ``int(confidence * width)``
    pixels wide. Its colour runs from red (confidence 0) through yellow (0.5)
    to green RGB (0, 204, 0) at 1.0. The rest of the image stays white, and a
    grey 2-px border is drawn around the whole image.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        confidence: Score expected in [0, 1]; values outside are clamped.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    # Create image with white background
    bar_img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(bar_img)

    # Clamp so out-of-range scores cannot produce negative colour channels or
    # a negative rectangle width (which Pillow rejects).
    confidence = min(max(float(confidence), 0.0), 1.0)

    if confidence <= 0.5:
        # Red (255, 0, 0) -> yellow (255, 255, 0)
        r = 255
        g = int(confidence * 2 * 255)
    else:
        # Yellow (255, 255, 0) -> green (0, 204, 0). Green is interpolated as
        # well; a fixed 204 made the colour jump from yellow to amber at 0.5.
        r = int((2 - confidence * 2) * 255)
        g = int(255 - (confidence * 2 - 1) * (255 - 204))
    b = 0

    # Draw the coloured part only when it has a non-zero width
    bar_width = int(confidence * width)
    if bar_width > 0:
        draw.rectangle([(0, 0), (bar_width, height)], fill=(r, g, b))

    # Add a border
    draw.rectangle([(0, 0), (width - 1, height - 1)], outline=(100, 100, 100), width=2)

    return bar_img

from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, AllChem, rdFMCS, TemplateAlign

def align_molecules_to_template(mols, mcs_timeout=10):
    """
    Align the 2D depictions of molecules to the first molecule in the list.

    The first molecule is the template. For each other molecule, the maximum
    common substructure (MCS) with the template is found. The molecule is then
    re-depicted so that the atoms it shares with the template take the
    template's 2D coordinates.

    Args:
        mols: List of RDKit molecules. Modified in place (2D coordinates are
            (re)computed) and also returned.
        mcs_timeout: Seconds allowed for each MCS search.

    Returns:
        The same list ``mols``. Lists with fewer than two molecules are
        returned unchanged.
    """
    if not mols or len(mols) < 2:
        return mols

    # Use the first molecule as template
    template_mol = mols[0]
    AllChem.Compute2DCoords(template_mol)

    for i, mol in enumerate(mols[1:], start=1):
        # Start from a standard depiction so a failed alignment still leaves
        # usable coordinates.
        AllChem.Compute2DCoords(mol)
        try:
            # Aligning on the whole template only works when the template is a
            # substructure of `mol`, which is not true for different
            # candidates. Align on the shared core (MCS) instead.
            mcs = rdFMCS.FindMCS(
                [template_mol, mol],
                ringMatchesRingOnly=True,
                completeRingsOnly=True,
                timeout=mcs_timeout,
            )
            if mcs.numAtoms == 0:
                print(f"Alignment failed for molecule {i}, using original coordinates")
                continue

            core = Chem.MolFromSmarts(mcs.smartsString)
            AllChem.GenerateDepictionMatching2DStructure(mol, template_mol, refPatt=core)
            print(f"Successfully aligned molecule {i} to template")

        except Exception as e:
            print(f"Error in alignment: {e}. Using basic coordinates.")

    return mols

def generate_molecule_card(smiles, confidence, hsqc_error, hsqc_rank, is_correct=False, 
                          mol_size=(450, 350), card_width=520, card_height=750,
                          aligned_mol=None):
    """
    Render a single candidate as a PIL image "card".

    Layout, top to bottom:
      - a 55-px title banner "HSQC Rank <n>", with " (CORRECT)" appended when
        ``is_correct``;
      - the RDKit structure drawing;
      - a confidence bar from ``create_confidence_bar``;
      - centred lines for confidence, HSQC error and molecular weight.
    The card border is green for the correct candidate and light grey-blue
    otherwise.

    Args:
        smiles: SMILES string; parsed only when ``aligned_mol`` is None.
        confidence: Confidence score (shown with 2 decimals and as a bar).
        hsqc_error: HSQC error, shown with 3 decimals.
        hsqc_rank: Rank shown in the title banner.
        is_correct: Highlight the card as the correct structure.
        mol_size: (width, height) in pixels of the structure drawing.
        card_width: Card width in pixels.
        card_height: Card height in pixels.
        aligned_mol: Optional RDKit molecule (e.g. with pre-aligned 2D
            coordinates) to draw instead of parsing ``smiles``.

    Returns:
        PIL RGB image, or None if the SMILES does not parse or any step raises
        (the error is printed).
    """

    def _text_width(painter, text, typeface):
        # Newer Pillow has ImageDraw.textlength(); older releases only offer
        # font.getsize().
        try:
            return painter.textlength(text, font=typeface)
        except AttributeError:
            return typeface.getsize(text)[0]

    try:
        # Choose the molecule to draw
        if aligned_mol is not None:
            mol = aligned_mol
        else:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return None
            AllChem.Compute2DCoords(mol)

        mol_weight = Descriptors.MolWt(mol)

        # Render the structure to PNG bytes with RDKit's Cairo drawer
        renderer = Draw.rdMolDraw2D.MolDraw2DCairo(mol_size[0], mol_size[1])
        renderer.SetFontSize(1.4)  # atom-label font size handed to the drawer
        renderer.DrawMolecule(mol)
        renderer.FinishDrawing()
        structure_img = Image.open(io.BytesIO(renderer.GetDrawingText()))

        # Assemble the card: structure under the banner, bar under the structure
        card = Image.new('RGB', (card_width, card_height), (252, 252, 252))
        card.paste(structure_img, ((card_width - mol_size[0]) // 2, 60))
        card.paste(create_confidence_bar(card_width - 80, 36, confidence),
                   (40, mol_size[1] + 100))

        draw = ImageDraw.Draw(card)

        # The first font family whose files all load wins; otherwise fall back
        # to Pillow's built-in font.
        font_families = (
            ("arial.ttf", "arial.ttf"),
            ("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
             "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf"),
        )
        for bold_path, regular_path in font_families:
            try:
                title_font = ImageFont.truetype(bold_path, 28)
                font = ImageFont.truetype(regular_path, 26)
                small_font = ImageFont.truetype(regular_path, 24)
                break
            except IOError:
                continue
        else:
            title_font = ImageFont.load_default()
            font = ImageFont.load_default()
            small_font = ImageFont.load_default()

        # Title banner
        title = f"HSQC Rank {hsqc_rank}"
        if is_correct:
            title += " (CORRECT)"
        title_x = (card_width - _text_width(draw, title, title_font)) // 2

        if is_correct:
            banner_fill, title_fill = (220, 245, 220), (0, 100, 0)
        else:
            banner_fill, title_fill = (240, 240, 245), (50, 50, 100)
        draw.rectangle([(0, 0), (card_width, 55)], fill=banner_fill)
        draw.text((title_x, 15), title, fill=title_fill, font=title_font)

        # Centred info lines; each text is formatted just before it is drawn
        info_lines = (
            ("Confidence: ", confidence, ".2f", 150),
            ("HSQC Error: ", hsqc_error, ".3f", 200),
            ("MW: ", mol_weight, ".2f", 250),
        )
        for prefix, value, spec, y_offset in info_lines:
            text = prefix + format(value, spec)
            text_x = (card_width - _text_width(draw, text, font)) // 2
            draw.text((text_x, mol_size[1] + y_offset), text, fill=(0, 0, 0), font=font)

        # Card border: thick green for the correct structure, subtle otherwise
        if is_correct:
            border_colour, border_px = (0, 150, 0), 5
        else:
            border_colour, border_px = (200, 200, 220), 3
        draw.rectangle([(0, 0), (card_width - 1, card_height - 1)], outline=border_colour, width=border_px)

        return card

    except Exception as e:
        print(f"Error generating molecule card: {e}")
        return None

def visualize_candidates_single_row(candidates_data, title=None, 
                                   figsize=(24, 7), correct_hsqc_rank=5, filename=None,
                                   template_idx=None):
    """
    Visualize all candidate molecules in a single row with aligned cores.

    Args:
        candidates_data: List of dicts with keys 'hsqc_rank', 'smiles',
            'confidence_score' and 'hsqc_error'.
        title: Accepted for backward compatibility; ignored (the figure is
            intentionally drawn without a title).
        figsize: Matplotlib figure size in inches.
        correct_hsqc_rank: Candidates with this HSQC rank are marked correct.
        filename: If given, the figure is saved there at 300 dpi.
        template_idx: Optional index of the 2D-alignment template within the
            HSQC-sorted list of candidates whose SMILES parse. That candidate
            is moved to the first panel.

    Returns:
        The matplotlib Figure, or None if no card could be generated.
    """
    # Sort by HSQC rank
    sorted_candidates = sorted(candidates_data, key=lambda x: x['hsqc_rank'])

    # Keep each candidate paired with its parsed molecule. Filtering only the
    # molecules would shift every later molecule onto the wrong candidate
    # whenever a SMILES fails to parse.
    pairs = []
    for candidate in sorted_candidates:
        mol = Chem.MolFromSmiles(candidate['smiles'])
        if mol is None:
            print(f"Skipping unparseable SMILES: {candidate['smiles']}")
            continue
        pairs.append((candidate, mol))

    # align_molecules_to_template uses the first molecule as the template
    if template_idx is not None and 0 <= template_idx < len(pairs):
        pairs.insert(0, pairs.pop(template_idx))

    aligned_mols = align_molecules_to_template([mol for _, mol in pairs])

    # Generate one card per (candidate, aligned molecule)
    cards = []
    for (candidate, _), aligned_mol in zip(pairs, aligned_mols):
        card = generate_molecule_card(
            candidate['smiles'],
            candidate['confidence_score'],
            candidate['hsqc_error'],
            candidate['hsqc_rank'],
            is_correct=(candidate['hsqc_rank'] == correct_hsqc_rank),
            mol_size=(450, 350),
            card_width=520,
            card_height=650,
            aligned_mol=aligned_mol
        )
        if card:
            cards.append(card)

    if not cards:
        print("No valid molecule cards generated")
        return None

    # One row, one axis per card
    fig = plt.figure(figsize=figsize)
    fig.subplots_adjust(wspace=0.4)
    n_cols = len(cards)
    for i, card in enumerate(cards):
        ax = fig.add_subplot(1, n_cols, i + 1)
        ax.imshow(card)
        ax.axis('off')

    # No title, subtitle or footer as requested
    fig.tight_layout()

    # Save if filename provided
    if filename:
        fig.savefig(filename, dpi=300, bbox_inches='tight')

    return fig

def visualize_with_grid_image(candidates_data, correct_hsqc_rank=5, filename=None, 
                             template_idx=None, figsize=(24, 7)):
    """
    Alternative visualization using RDKit's MolsToGridImage with improved style and aligned core structures.

    Args:
        candidates_data: List of dicts with keys 'hsqc_rank', 'smiles',
            'confidence_score' and 'hsqc_error'.
        correct_hsqc_rank: Candidates with this HSQC rank get "(CORRECT)" in
            their legend and all of their atoms highlighted in green.
        filename: If given, the figure is saved there at 300 dpi.
        template_idx: Optional index of the 2D-alignment template within the
            HSQC-sorted list of candidates whose SMILES parse. It is moved to
            the first grid cell.
        figsize: Matplotlib figure size in inches.

    Returns:
        The matplotlib Figure, or None if no SMILES could be parsed.
    """
    # Sort by HSQC rank
    sorted_candidates = sorted(candidates_data, key=lambda x: x['hsqc_rank'])

    # Keep candidate and molecule paired so that a bad SMILES cannot shift
    # legends onto the wrong structure.
    pairs = []
    for candidate in sorted_candidates:
        mol = Chem.MolFromSmiles(candidate['smiles'])
        if mol is None:
            print(f"Skipping unparseable SMILES: {candidate['smiles']}")
            continue
        pairs.append((candidate, mol))

    if not pairs:
        print("No valid molecules")
        return None

    # align_molecules_to_template uses the first molecule as the template
    if template_idx is not None and 0 <= template_idx < len(pairs):
        pairs.insert(0, pairs.pop(template_idx))

    candidates = [candidate for candidate, _ in pairs]
    aligned_mols = align_molecules_to_template([mol for _, mol in pairs])

    # Legends and per-molecule highlighting
    legends = []
    highlight_atom_lists = []
    highlight_atom_colors = []
    for candidate, mol in zip(candidates, aligned_mols):
        is_correct = candidate['hsqc_rank'] == correct_hsqc_rank

        legend = f"Rank {candidate['hsqc_rank']}"
        if is_correct:
            legend += " (CORRECT)"
        legend += f"\nConf: {candidate['confidence_score']:.2f}"
        legend += f"\nHSQC Err: {candidate['hsqc_error']:.3f}"
        legend += f"\nMW: {Descriptors.MolWt(mol):.2f}"
        legends.append(legend)

        if is_correct:
            atoms = list(range(mol.GetNumAtoms()))
            highlight_atom_lists.append(atoms)
            # RDKit expects ONE {atom_index: (r, g, b)} dict per molecule,
            # not a list of colours.
            highlight_atom_colors.append({idx: (0.0, 0.7, 0.0) for idx in atoms})
        else:
            highlight_atom_lists.append([])
            highlight_atom_colors.append({})

    grid_img = Draw.MolsToGridImage(
        aligned_mols,
        molsPerRow=len(aligned_mols),
        subImgSize=(450, 400),
        legends=legends,
        highlightAtomLists=highlight_atom_lists,
        highlightAtomColors=highlight_atom_colors,
        useSVG=False,
        # Request a PIL image explicitly; a patched MolsToGridImage (e.g. after
        # importing IPythonConsole) may otherwise return PNG data.
        returnPNG=False,
        legendFontSize=20,
        maxMols=len(aligned_mols)
    )

    # Display with matplotlib without title
    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(1, 1, 1)
    ax.imshow(np.array(grid_img))
    ax.axis('off')
    fig.tight_layout()

    # Save if filename provided
    if filename:
        fig.savefig(filename, dpi=300, bbox_inches='tight')

    return fig

def analyze_case_study():
    """
    Create improved visualization using standard RDKit drawing with aligned core structures.

    Draws the five case-study candidates as a single row of cards and as an
    RDKit grid image. Both views use the correct candidate as the 2D-alignment
    template and are saved to PNG files in the working directory and shown.

    Returns:
        Tuple ``(fig, fig2)`` of matplotlib Figures; either may be None if the
        corresponding helper could not produce a figure.
    """
    # Data for the candidate molecules
    # NOTE(review): ranks 2 and 3 share the same SMILES -- confirm this is intended.
    candidate_data = [
        {
            "hsqc_rank": 1,
            "smiles": "C=C1c2c([nH]c(C)c2CN2CCOCC2)CCC1CO",
            "confidence_score": 0.45,
            "hsqc_error": 3.625
        },
        {
            "hsqc_rank": 2,
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOC1)CCC2",
            "confidence_score": 0.35,
            "hsqc_error": 3.933
        },
        {
            "hsqc_rank": 3,
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOC1)CCC2",
            "confidence_score": 0.32,
            "hsqc_error": 4.070
        },
        {
            "hsqc_rank": 4,
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCCOC1)CC2",
            "confidence_score": 0.50,
            "hsqc_error": 4.188
        },
        {
            "hsqc_rank": 5,  # This is the correct molecule
            "smiles": "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOCC1)CC2",
            "confidence_score": 0.85,
            "hsqc_error": 4.547
        }
    ]

    # HSQC rank of the known correct structure; used by both figures.
    correct_hsqc_rank = 5

    # Use the correct molecule as the alignment template. Its index is its
    # position in the HSQC-rank-sorted list, derived from the data rather than
    # hard-coded so that editing the data cannot silently pick the wrong one.
    sorted_ranks = sorted(candidate["hsqc_rank"] for candidate in candidate_data)
    template_idx = (sorted_ranks.index(correct_hsqc_rank)
                    if correct_hsqc_rank in sorted_ranks else None)

    # Card view - no title
    fig = visualize_candidates_single_row(
        candidate_data,
        figsize=(24, 7),
        correct_hsqc_rank=correct_hsqc_rank,
        filename="improved_case_study_molecules.png",
        template_idx=template_idx
    )
    plt.show()

    # Grid-image view - no title
    fig2 = visualize_with_grid_image(
        candidate_data,
        correct_hsqc_rank=correct_hsqc_rank,
        filename="improved_case_study_grid.png",
        template_idx=template_idx
    )

    # Guard: `.number` on a None result would raise AttributeError.
    if fig2 is not None:
        plt.figure(fig2.number)
        plt.show()

    return fig, fig2


if __name__ == "__main__":
    # Run the improved visualizations
    analyze_case_study()

In [None]:
import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import IPythonConsole

# Candidate SMILES for the case study
smiles_list = [
    "C=C1c2c([nH]c(C)c2CN2CCOCC2)CCC1CO",
    "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOC1)CCC2",
    "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOC1)CCC2",
    "CCc1c(C)[nH]c2c1C(=O)C(CN1CCCOC1)CC2",
    "CCc1c(C)[nH]c2c1C(=O)C(CN1CCOCC1)CC2"
]

# Choose the depiction engine BEFORE any 2D coordinates are generated;
# setting it after the coordinates exist does not affect them.
# Note: this is a global RDKit setting and also affects later cells.
Draw.rdDepictor.SetPreferCoordGen(True)

# Core template: the pyrrole ring that every candidate contains.
# The former template "c1c([nH]c2c1CCCC2)CC" carries the ethyl group on the
# carbon next to N-H and a six-membered carbocycle. The candidates have a
# methyl next to N-H (ethyl, when present, sits on the other carbon), and
# molecules 2-3 have a seven-membered ring, so that template matched none of
# them and nothing was aligned.
# Because pyrrole is symmetric, a molecule may still be drawn mirrored
# relative to the others.
template_smiles = "c1cc[nH]c1"
template = Chem.MolFromSmiles(template_smiles)
AllChem.Compute2DCoords(template)

# Convert SMILES to molecules and align them on the shared core
mols = []
legends = []
for mol_number, smiles in enumerate(smiles_list, start=1):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        print(f"Skipping invalid SMILES for molecule {mol_number}: {smiles}")
        continue

    if mol.HasSubstructMatch(template):
        try:
            AllChem.GenerateDepictionMatching2DStructure(mol, template)
        except Exception as err:
            print(f"Alignment failed for molecule {mol_number} ({err}); using default 2D coordinates")
            AllChem.Compute2DCoords(mol)
    else:
        print(f"Molecule {mol_number} lacks the template core; using default 2D coordinates")
        AllChem.Compute2DCoords(mol)

    mols.append(mol)
    legends.append(f'Molecule {mol_number}')

# Drawing options
drawing_options = Draw.MolDrawOptions()
drawing_options.legendFontSize = 12
drawing_options.bondLineWidth = 2

# Create a grid of molecules with custom options (PIL image, not PNG bytes)
img = Draw.MolsToGridImage(
    mols,
    molsPerRow=3,
    subImgSize=(300, 300),
    legends=legends,
    returnPNG=False,
    drawOptions=drawing_options
)

# Save and display the grid
img.save("molecule_grid_aligned.png")
img

In [None]:
import json

# NOTE(review): process_single_json_hsqc is defined in the "Experimental Case
# Study" cell further down (or possibly in an earlier cell); confirm that this
# cell still runs after Restart & Run All.
file_path = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished_clean/AZ10011150_noise_intermediate.json"

# JSON is UTF-8 by specification; do not rely on the platform default encoding.
with open(file_path, 'r', encoding='utf-8') as f:
    json_data = json.load(f)
process_single_json_hsqc(json_data)

#### Experimental Case Study

In [None]:
import json
import glob
import os
from typing import Dict, List, Tuple, Any
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.Draw import IPythonConsole
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from io import BytesIO
import base64

def load_reference_data(csv_path):
    """
    Read the reference CSV and map each 'sample-id' to its 'SMILES'.

    Args:
        csv_path: Anything ``pd.read_csv`` accepts (path or file-like object).

    Returns:
        dict {sample-id: SMILES}, used for fast lookups by sample ID.
    """
    smiles_by_id = pd.read_csv(csv_path).set_index('sample-id')['SMILES']
    return smiles_by_id.to_dict()

def get_base_sample_id(sample_id):
    """Return the part of ``sample_id`` before the first underscore ('' for empty/None input)."""
    if not sample_id:
        return ''
    return sample_id.partition('_')[0]

def process_single_json_hsqc(json_data):
    """
    Collect candidate molecules from one results JSON and sort them by HSQC score.

    Candidates are read from ``json_data["molecule_data"]["candidate_analysis"]``
    under the keys 'forward_synthesis', 'mol2mol' and 'mmst' (in that order).
    Each score comes from ``nmr_analysis.matching_scores.by_spectrum.HSQC``.

    Args:
        json_data: Parsed JSON for one sample.

    Returns:
        List of ``{'smiles': ..., 'hsqc_score': ...}`` dicts sorted by
        ascending HSQC score. Molecules without a score come last, in input
        order. Returns ``[]`` when the candidate-analysis section is missing or
        malformed.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except (KeyError, TypeError):
        # Missing section, or json_data / molecule_data is null or not a dict
        return []
    if not isinstance(candidate_analysis, dict):
        return []

    all_molecules = []
    for analysis_type in ('forward_synthesis', 'mol2mol', 'mmst'):
        # `or {}` / `or []` also cover keys that are present but null in the
        # JSON, which previously raised AttributeError/TypeError.
        section = candidate_analysis.get(analysis_type) or {}
        for mol in section.get('molecules') or []:
            if not isinstance(mol, dict) or 'smiles' not in mol:
                continue
            nmr_analysis = mol.get('nmr_analysis') or {}
            matching_scores = nmr_analysis.get('matching_scores') or {}
            by_spectrum = matching_scores.get('by_spectrum') or {}
            all_molecules.append({
                'smiles': mol['smiles'],
                'hsqc_score': by_spectrum.get('HSQC'),
            })

    # Lower HSQC score is better; unscored molecules go last
    all_molecules.sort(key=lambda x: x['hsqc_score'] if x['hsqc_score'] is not None else float('inf'))
    return all_molecules

def analyze_llm_predictions(json_data, true_smiles, llm_name="deepseek"):
    """
    Locate the true structure in one LLM's confidence-ranked candidate list.

    Args:
        json_data: Parsed results JSON for one sample.
        true_smiles: Reference SMILES, compared by exact string equality.
        llm_name: Key under ``analysis_results.final_analysis.llm_responses``.

    Returns:
        Dict with 'correct_position' (1-based rank after sorting by descending
        'confidence_score', or None if absent), 'total_candidates',
        'is_top_1' and 'is_top_5'. Returns None when the expected structure is
        missing or malformed (KeyError/TypeError).
    """
    try:
        responses = json_data["analysis_results"]["final_analysis"]["llm_responses"]
        parsed = responses[llm_name]["parsed_results"]
        ranked = sorted(parsed["candidates"],
                        key=lambda cand: cand["confidence_score"],
                        reverse=True)

        correct_position = next(
            (rank for rank, cand in enumerate(ranked, 1) if cand["smiles"] == true_smiles),
            None,
        )
        found = correct_position is not None

        return {
            "correct_position": correct_position,
            "total_candidates": len(ranked),
            "is_top_1": found and correct_position == 1,
            "is_top_5": found and correct_position <= 5,
        }

    except (KeyError, TypeError):
        # This LLM might not have results in this file
        return None

def find_molecule_rank(molecules, true_smiles):
    """Return the 1-based position of the first entry whose 'smiles' equals ``true_smiles``, or None."""
    return next(
        (position for position, entry in enumerate(molecules, start=1)
         if entry['smiles'] == true_smiles),
        None,
    )

def find_llm_corrected_molecules(json_dir, reference_csv):
    """
    Find molecules where the LLM corrected the HSQC ranking for one run/condition.

    A molecule counts as "corrected" when both of these hold:
      - the reference SMILES appears in the HSQC-sorted candidate list at a
        rank other than 1;
      - the DeepSeek candidates, sorted by confidence, put it at rank 1.
    SMILES are compared by exact string equality.

    Args:
        json_dir: Directory containing the per-sample JSON result files
        reference_csv: Path to reference CSV file

    Returns:
        List of dicts with 'sample_id', 'base_sample_id', 'hsqc_rank' and
        'deepseek_rank'. Files are processed in sorted path order so the
        result order is reproducible.
    """
    # Load reference data
    reference_data = load_reference_data(reference_csv)

    # glob() order is filesystem-dependent; sort for reproducible output
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    corrected_molecules = []

    for file_path in json_files:
        try:
            # JSON is UTF-8 by specification; don't rely on the locale default
            with open(file_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            molecule_data = json_data.get('molecule_data', {})
            sample_id = molecule_data.get('sample_id')
            if not sample_id:
                continue

            # Get base sample ID for reference matching
            base_sample_id = get_base_sample_id(sample_id)

            # Get correct SMILES
            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            # Process HSQC ranking
            molecules = process_single_json_hsqc(json_data)
            if not molecules:
                continue
            hsqc_rank = find_molecule_rank(molecules, true_smiles)

            # Process DeepSeek ranking
            deepseek_result = analyze_llm_predictions(json_data, true_smiles)

            # Both methods must have located the true structure
            if (hsqc_rank is not None and
                deepseek_result is not None and
                deepseek_result.get("correct_position") is not None):

                # LLM corrected the ranking: HSQC not top-1, LLM top-1
                if hsqc_rank != 1 and deepseek_result["correct_position"] == 1:
                    corrected_molecules.append({
                        "sample_id": sample_id,
                        "base_sample_id": base_sample_id,
                        "hsqc_rank": hsqc_rank,
                        "deepseek_rank": deepseek_result["correct_position"]
                    })

        except Exception as e:
            # Best-effort scan: report the bad file and keep going
            print(f"Error processing {file_path}: {str(e)}")
            continue

    return corrected_molecules

def generate_molecule_image(smiles):
    """
    Parse a SMILES string into an RDKit molecule.

    Despite its name, this function does not draw an image.

    Returns:
        The RDKit Mol from ``Chem.MolFromSmiles``, or None. For an invalid
        SMILES string, MolFromSmiles itself returns None. If parsing raises
        (e.g. for a non-string argument), the error is printed and None is
        returned.
    """
    try:
        return Chem.MolFromSmiles(smiles)
    except Exception as err:
        print(f"Error generating molecule from SMILES: {str(err)}")
        return None

def visualize_corrected_molecules_matplotlib(corrected_molecules, reference_data, max_molecules_per_fig=6,
                                             condition_label="Simulated Data with Noise",
                                             output_csv='llm_corrected_molecules_sim_noise.csv'):
    """
    Visualize corrected molecules using Matplotlib.

    Steps:
      - print a summary line per molecule;
      - draw the reference structures in figures of at most
        ``max_molecules_per_fig`` panels (3 per row);
      - save each figure as ``corrected_molecules_set_<k>.png``;
      - write the table to ``output_csv``.

    Args:
        corrected_molecules: List of dicts with at least 'base_sample_id' and
            'hsqc_rank', and optionally 'deepseek_rank' (default 1).
        reference_data: Dictionary mapping base sample IDs to SMILES.
        max_molecules_per_fig: Maximum number of molecules per figure (>= 1).
        condition_label: Data condition named in each figure's super-title.
        output_csv: Path of the CSV written with the corrected-molecule table.

    Returns:
        DataFrame with the corrected-molecule data plus a 'smiles' column, or
        an empty DataFrame if ``corrected_molecules`` is empty.

    Raises:
        ValueError: If ``max_molecules_per_fig`` < 1 and there is data to plot.
    """
    total_molecules = len(corrected_molecules)
    if total_molecules == 0:
        print("No corrected molecules found.")
        return pd.DataFrame()
    if max_molecules_per_fig < 1:
        raise ValueError("max_molecules_per_fig must be at least 1")

    # Print summary first
    print(f"\nFound {total_molecules} molecules corrected by LLM:")

    # Create DataFrame for easier analysis
    df = pd.DataFrame(corrected_molecules)
    df['smiles'] = df['base_sample_id'].apply(lambda x: reference_data.get(x, ''))

    # Collect drawable molecules with their panel labels
    mols, labels, titles = [], [], []
    for i, mol_info in enumerate(corrected_molecules):
        base_sample_id = mol_info['base_sample_id']
        hsqc_rank = mol_info['hsqc_rank']
        # find_llm_corrected_molecules only records LLM top-1 hits, so 1 is the
        # fallback; prefer the recorded value when present.
        llm_rank = mol_info.get('deepseek_rank', 1)

        smiles = reference_data.get(base_sample_id)
        if not smiles:
            print(f"No SMILES found for {base_sample_id}")
            continue

        mol = generate_molecule_image(smiles)
        if mol is None:
            print(f"Could not generate molecule for {base_sample_id}")
            continue

        mols.append(mol)
        labels.append(f"{base_sample_id}")
        titles.append(f"HSQC Rank: {hsqc_rank} → DeepSeek: {llm_rank}")

        # Print to console
        print(f"{i+1}. Sample ID: {base_sample_id}")
        print(f"   SMILES: {smiles}")
        print(f"   HSQC Rank: {hsqc_rank} → DeepSeek Rank: {llm_rank}")
        print()

    # Batch over the molecules actually drawn. Batching over total_molecules
    # produced empty (zero-width) trailing figures when some were skipped.
    num_figures = (len(mols) + max_molecules_per_fig - 1) // max_molecules_per_fig
    for fig_num in range(num_figures):
        start_idx = fig_num * max_molecules_per_fig
        end_idx = start_idx + max_molecules_per_fig

        fig_mols = mols[start_idx:end_idx]
        fig_labels = labels[start_idx:end_idx]
        fig_titles = titles[start_idx:end_idx]

        # Up to 3 panels per row
        n_cols = min(3, len(fig_mols))
        n_rows = (len(fig_mols) + 2) // 3

        fig = plt.figure(figsize=(n_cols * 5, n_rows * 5))
        for j, (mol, label, title) in enumerate(zip(fig_mols, fig_labels, fig_titles)):
            ax = fig.add_subplot(n_rows, n_cols, j + 1)
            ax.imshow(Draw.MolToImage(mol, size=(400, 300)))
            ax.set_title(f"{label}\n{title}", fontsize=12)
            ax.axis('off')

        fig.tight_layout()
        fig.suptitle(f"Molecules Corrected by LLM ({condition_label}) - Set {fig_num+1}/{num_figures}",
                     fontsize=16, y=1.02)
        fig.savefig(f"corrected_molecules_set_{fig_num+1}.png", dpi=300, bbox_inches='tight')
        plt.show()
        plt.close(fig)  # free memory when many figures are produced

    # Print summary statistics
    print("\nSummary:")
    print(f"Total molecules corrected: {len(df)}")
    print("\nHSQC original rankings of corrected molecules:")
    print(df['hsqc_rank'].value_counts().sort_index())

    # Save to CSV
    df.to_csv(output_csv, index=False)
    print(f"Saved results to '{output_csv}'")

    return df

def main():
    """
    Find molecules whose ranking the LLM corrected to top-1, then plot and save them.

    Returns:
        DataFrame of corrected molecules as returned by
        ``visualize_corrected_molecules_matplotlib``.
    """
    # NOTE(review): this directory is the experimental "_run_4_exp_d1" run,
    # while the downstream plot title and output CSV are labelled
    # "Simulated Data with Noise" / "sim_noise" -- confirm the intended condition.
    json_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_4_exp_d1_finished"
    reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"

    # Load reference data
    reference_data = load_reference_data(reference_csv)

    # Find molecules corrected by LLM
    corrected_molecules = find_llm_corrected_molecules(json_dir, reference_csv)

    # Visualize corrected molecules and hand the table back to the caller
    return visualize_corrected_molecules_matplotlib(corrected_molecules, reference_data, max_molecules_per_fig=6)


if __name__ == "__main__":
    main()

In [None]:
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw, Descriptors, AllChem
import numpy as np
import io
from PIL import Image, ImageDraw, ImageFont
import pandas as pd

def create_confidence_bar(width, height, confidence):
    """
    Create a horizontal confidence bar as a PIL RGB image.

    The filled part starts at the left edge and is ``int(confidence * width)``
    pixels wide. Its colour runs from red (confidence 0) through yellow (0.5)
    to green RGB (0, 204, 0) at 1.0. The rest of the image stays white, and a
    grey 2-px border is drawn around the whole image.

    Args:
        width: Image width in pixels.
        height: Image height in pixels.
        confidence: Score expected in [0, 1]; values outside are clamped.

    Returns:
        PIL.Image.Image in RGB mode.
    """
    # Create image with white background
    bar_img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(bar_img)

    # Clamp so out-of-range scores cannot produce negative colour channels or
    # a negative rectangle width (which Pillow rejects).
    confidence = min(max(float(confidence), 0.0), 1.0)

    if confidence <= 0.5:
        # Red (255, 0, 0) -> yellow (255, 255, 0)
        r = 255
        g = int(confidence * 2 * 255)
    else:
        # Yellow (255, 255, 0) -> green (0, 204, 0). Green is interpolated as
        # well; a fixed 204 made the colour jump from yellow to amber at 0.5.
        r = int((2 - confidence * 2) * 255)
        g = int(255 - (confidence * 2 - 1) * (255 - 204))
    b = 0

    # Draw the coloured part only when it has a non-zero width
    bar_width = int(confidence * width)
    if bar_width > 0:
        draw.rectangle([(0, 0), (bar_width, height)], fill=(r, g, b))

    # Add a border
    draw.rectangle([(0, 0), (width - 1, height - 1)], outline=(100, 100, 100), width=2)

    return bar_img

def _card_text_width(draw, text, font):
    """Return the rendered width of ``text`` in pixels, using new or old Pillow APIs."""
    try:
        return draw.textlength(text, font=font)
    except AttributeError:  # For older PIL versions without ImageDraw.textlength
        return font.getsize(text)[0]


def _draw_centered_text(draw, text, font, y, card_width, fill):
    """Draw ``text`` horizontally centred on a card ``card_width`` px wide, at row ``y``."""
    # textlength() returns a float; floor it to an integer pixel column
    x = int((card_width - _card_text_width(draw, text, font)) // 2)
    draw.text((x, y), text, fill=fill, font=font)


def _load_card_fonts():
    """
    Return ``(title_font, body_font, small_font)`` for molecule cards.

    Tries Arial first, then DejaVu Sans (bold for the title), and finally
    Pillow's built-in default font.
    """
    try:
        return (ImageFont.truetype("arial.ttf", 28),
                ImageFont.truetype("arial.ttf", 26),
                ImageFont.truetype("arial.ttf", 24))
    except IOError:
        pass
    try:
        return (ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 28),
                ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 26),
                ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 24))
    except IOError:
        default_font = ImageFont.load_default()
        return default_font, default_font, default_font


def generate_molecule_card(smiles, confidence, hsqc_error, hsqc_rank, is_correct=False, 
                          mol_size=(450, 350), card_width=520, card_height=750):
    """
    Generate a card with molecule image using standard RDKit drawing.

    Layout from top to bottom:
      - title banner with the HSQC rank;
      - the 2D structure;
      - a confidence bar;
      - centred text rows for confidence, HSQC error and molecular weight.

    Args:
        smiles: SMILES string of the molecule.
        confidence: Confidence score (drawn as a bar and printed with 2 decimals).
        hsqc_error: HSQC error value (printed with 3 decimals).
        hsqc_rank: Rank shown in the title.
        is_correct: If True, the title banner and border are green and the title
            gets a " (CORRECT)" suffix.
        mol_size: (width, height) in pixels of the structure drawing.
        card_width: Card width in pixels.
        card_height: Card height in pixels.

    Returns:
        A ``PIL.Image.Image`` (RGB). Returns None if the SMILES cannot be parsed or
        any drawing step raises; in the latter case the error is printed.
    """
    try:
        # Parse SMILES and prepare molecule
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None

        mol_weight = Descriptors.MolWt(mol)

        # Draw the molecule using standard RDKit drawing
        drawer = Draw.rdMolDraw2D.MolDraw2DCairo(mol_size[0], mol_size[1])
        drawer.SetFontSize(1.4)  # Further increase the font size for atom labels
        drawer.DrawMolecule(mol)
        drawer.FinishDrawing()
        mol_img = Image.open(io.BytesIO(drawer.GetDrawingText()))

        # Card canvas with the structure horizontally centred below the title banner
        card = Image.new('RGB', (card_width, card_height), (252, 252, 252))
        card.paste(mol_img, ((card_width - mol_size[0]) // 2, 60))

        # Confidence bar under the structure
        conf_bar = create_confidence_bar(card_width - 80, 36, confidence)
        card.paste(conf_bar, (40, mol_size[1] + 100))

        draw = ImageDraw.Draw(card)
        title_font, font, _small_font = _load_card_fonts()

        # Title banner: green for the correct structure, neutral otherwise
        title = f"HSQC Rank {hsqc_rank}"
        if is_correct:
            title += " (CORRECT)"
            banner_fill, title_fill = (220, 245, 220), (0, 100, 0)
        else:
            banner_fill, title_fill = (240, 240, 245), (50, 50, 100)
        draw.rectangle([(0, 0), (card_width, 55)], fill=banner_fill)
        _draw_centered_text(draw, title, title_font, 15, card_width, title_fill)

        # Text rows below the confidence bar, 50 px apart
        text_rows = [
            (f"Confidence: {confidence:.2f}", mol_size[1] + 150),
            (f"HSQC Error: {hsqc_error:.3f}", mol_size[1] + 200),
            (f"MW: {mol_weight:.2f}", mol_size[1] + 250),
        ]
        for text, y in text_rows:
            _draw_centered_text(draw, text, font, y, card_width, (0, 0, 0))

        # Border: thick green for the correct structure, subtle grey otherwise
        if is_correct:
            draw.rectangle([(0, 0), (card_width-1, card_height-1)], outline=(0, 150, 0), width=5)
        else:
            draw.rectangle([(0, 0), (card_width-1, card_height-1)], outline=(200, 200, 220), width=3)

        return card

    except Exception as e:
        # Best-effort: callers skip cards that come back as None
        print(f"Error generating molecule card: {e}")
        return None

def visualize_candidates_single_row(candidates_data, title=None, 
                                   figsize=(24, 7), correct_hsqc_rank=3, filename=None,
                                   wspace=0.4):
    """
    Visualize all candidate molecules in a single row, ordered by HSQC rank.

    Args:
        candidates_data: Iterable of dicts with the keys 'smiles',
            'confidence_score', 'hsqc_error' and 'hsqc_rank'.
        title: Optional figure title (``fig.suptitle``). It used to be accepted and
            ignored; the default ``None`` still draws no title.
        figsize: Matplotlib figure size in inches.
        correct_hsqc_rank: HSQC rank of the candidate highlighted as correct.
        filename: If given, the figure is saved there at 300 dpi.
        wspace: Horizontal spacing between cards (fraction of the average axes
            width). It is applied *after* ``tight_layout``. The earlier code set it
            before, and ``tight_layout`` recomputed the spacing, so it had no effect.
            Pass None to keep the ``tight_layout`` spacing.

    Returns:
        The matplotlib Figure, or None if no molecule card could be generated.
    """
    # One card per candidate, sorted by HSQC rank; unparsable SMILES are skipped
    cards = []
    for candidate in sorted(candidates_data, key=lambda c: c['hsqc_rank']):
        card = generate_molecule_card(
            candidate['smiles'],
            candidate['confidence_score'],
            candidate['hsqc_error'],
            candidate['hsqc_rank'],
            is_correct=(candidate['hsqc_rank'] == correct_hsqc_rank),
            mol_size=(450, 350),
            card_width=520,
            card_height=650,
        )
        if card is not None:
            cards.append(card)

    if not cards:
        print("No valid molecule cards generated")
        return None

    # Single row of axes; squeeze=False keeps a 2D array even for one card
    fig, axes = plt.subplots(1, len(cards), figsize=figsize, squeeze=False)
    for ax, card in zip(axes[0], cards):
        ax.imshow(card)
        ax.axis('off')

    if title:
        fig.suptitle(title)

    fig.tight_layout()
    if wspace is not None:
        fig.subplots_adjust(wspace=wspace)

    # Save if filename provided
    if filename:
        fig.savefig(filename, dpi=300, bbox_inches='tight')

    return fig

def experimental_case_study():
    """
    Create visualization for the experimental data case study.

    The five candidates are drawn in one row and saved to
    ``experimental_case_study.png``. The highlighted ("correct") candidate is the
    one with the highest confidence score. It is derived from the data rather than
    hard-coded, so it stays right if the scores are edited.

    Returns:
        The matplotlib Figure, or None if no molecule card could be drawn.
    """
    # NOTE: HSQC errors were not provided for this case study. These are
    # placeholder values that simply increase with rank; replace them with real
    # errors before drawing quantitative conclusions from the figure.
    hsqc_errors = [3.245, 3.478, 3.657, 3.802, 3.934]

    # Data for the experimental molecules
    candidate_data = [
        {
            "hsqc_rank": 1,
            "smiles": "COc1c(C)c2c(c(O)c1CC(C)=CCCC(=O)O)C(=O)OC2",
            "confidence_score": 0.75,
            "hsqc_error": hsqc_errors[0]
        },
        {
            "hsqc_rank": 2,
            "smiles": "COc1c(C)c(O)c2c(c1CC=C(C)CCC(=O)O)COC2=O",
            "confidence_score": 0.6,
            "hsqc_error": hsqc_errors[1]
        },
        {
            "hsqc_rank": 3,
            "smiles": "COc1c(C)c2c(c(O)c1CC=C(C)CCC(=O)O)C(=O)OC2",
            "confidence_score": 0.85,
            "hsqc_error": hsqc_errors[2]
        },
        {
            "hsqc_rank": 4,
            "smiles": "COc1c(C)c(O)c2c(c1CC=C(C)CCC(=O)O)C(=O)OC2",
            "confidence_score": 0.55,
            "hsqc_error": hsqc_errors[3]
        },
        {
            "hsqc_rank": 5,
            "smiles": "COc1c(C)c2c(c(CC=C(C)CCC(=O)O)c1O)C(=O)OC2",
            "confidence_score": 0.5,
            "hsqc_error": hsqc_errors[4]
        }
    ]

    # The "correct" candidate is the one with the highest confidence (rank 3 here)
    correct_hsqc_rank = max(candidate_data, key=lambda c: c["confidence_score"])["hsqc_rank"]

    fig = visualize_candidates_single_row(
        candidate_data,
        figsize=(24, 7),
        correct_hsqc_rank=correct_hsqc_rank,
        filename="experimental_case_study.png"
    )

    # Only show when a figure was produced; otherwise plt.show() could re-display stale figures
    if fig is not None:
        plt.show()

    return fig

if __name__ == "__main__":
    # Run the visualization for experimental data
    experimental_case_study()

## Plot Confidence

### V1: Dummy Data

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
import random
from scipy import stats

# Global plot styling shared by the V1 confidence figure below
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("deep")
plt.rcParams.update({'figure.figsize': [12, 8], 'font.size': 12})

# Models compared in the dummy V1 figure. Each entry holds:
#   name             - display name
#   accuracy         - fraction of simulated predictions generated as correct
#   confidence_range - [low, high] anchors for the simulated confidence scores
#   color            - hex colour used for the model's bars and points
models = [
    dict(zip(("name", "accuracy", "confidence_range", "color"), spec))
    for spec in (
        ("Claude 3.5 Sonnet", 0.79, [0.65, 0.90], "#6366F1"),
        ("Claude 3.7 Sonnet-Thinking", 0.89, [0.70, 0.95], "#3B82F6"),
        ("DeepSeek-R1", 0.87, [0.65, 0.92], "#10B981"),
        ("Gemini-Thinking", 0.86, [0.60, 0.88], "#F59E0B"),
        ("o3-mini", 0.85, [0.65, 0.90], "#EC4899"),
        ("Kimi 1.5", 0.84, [0.60, 0.85], "#8B5CF6"),
    )
]

# Function to generate dummy data for each model
def generate_model_data(model, n_samples=150, seed=None):
    """
    Generate dummy confidence scores for one model's correct and incorrect predictions.

    Args:
        model: Dict with the keys 'name', 'accuracy' (fraction of samples generated
            as correct), 'confidence_range' ([low, high]) and 'color'.
        n_samples: Total number of simulated predictions.
        seed: Seed for NumPy's global RNG. If None (the default), the seed is
            ``42 + models.index(model)``. That is the original behaviour and needs
            ``model`` to be an entry of the module-level ``models`` list. Pass an
            explicit seed to use the function with any model dict.

    Returns:
        (DataFrame, float):
          - The DataFrame has one row per simulated prediction, with columns
            model / confidence / correct / color. Correct rows come first, and the
            index is a fresh 0..n-1.
          - The float is the Pearson correlation between confidence and
            correctness. It is NaN if every row has the same label.
    """
    if seed is None:
        seed = 42 + models.index(model)  # Different seed for each model but reproducible
    # Seed the *global* RNG (not a local Generator) so the generated values, and
    # any later draws from np.random, stay exactly as before.
    np.random.seed(seed)

    # Calculate number of correct and incorrect predictions based on model accuracy
    n_correct = int(n_samples * model["accuracy"])
    n_incorrect = n_samples - n_correct

    # Correct predictions: mean near the high end of the range
    correct_confidence = np.clip(
        np.random.normal(loc=model["confidence_range"][1] - 0.05, scale=0.12, size=n_correct),
        0.0, 1.0  # Clip to valid range
    )

    # Incorrect predictions: mean near the low end, with more variance
    incorrect_confidence = np.clip(
        np.random.normal(loc=model["confidence_range"][0] + 0.10, scale=0.15, size=n_incorrect),
        0.0, 1.0  # Clip to valid range
    )

    frames = [
        pd.DataFrame({
            'model': model["name"],
            'confidence': scores,
            'correct': is_correct,
            'color': model["color"]
        })
        for scores, is_correct in ((correct_confidence, True), (incorrect_confidence, False))
    ]
    # ignore_index avoids duplicate row labels (0..n_correct-1 appeared twice before)
    combined_df = pd.concat(frames, ignore_index=True)

    # Correlation coefficient between confidence and correctness
    correlation = np.corrcoef(combined_df['confidence'], combined_df['correct'].astype(int))[0, 1]

    return combined_df, correlation

# Generate data for all models (generate_model_data seeds NumPy per model)
all_data = []
correlations = []
for model in models:
    model_data, correlation = generate_model_data(model)
    all_data.append(model_data)
    correlations.append({'model': model["name"], 'correlation': correlation, 'color': model["color"]})

# Combine all the model data; a fresh index avoids duplicate row labels
full_dataset = pd.concat(all_data, ignore_index=True)
correlations_df = pd.DataFrame(correlations).sort_values('correlation', ascending=False)
correlation_by_model = dict(zip(correlations_df['model'], correlations_df['correlation']))

# Per-model mean/std of confidence, split by whether the prediction was correct
grouped_stats = []
for model in models:
    model_data = full_dataset[full_dataset['model'] == model["name"]]
    correct_conf = model_data.loc[model_data['correct'], 'confidence']
    incorrect_conf = model_data.loc[~model_data['correct'], 'confidence']

    grouped_stats.append({
        'model': model["name"],
        'correct_mean': correct_conf.mean(),
        'correct_std': correct_conf.std(),
        'incorrect_mean': incorrect_conf.mean(),
        'incorrect_std': incorrect_conf.std(),
        'color': model["color"],
        'correlation': correlation_by_model[model["name"]]
    })

# Create DataFrame for the stats, strongest correlation first
stats_df = pd.DataFrame(grouped_stats).sort_values('correlation', ascending=False)

# Take the display order from stats_df itself. The plotting code draws bar heights
# row by row from stats_df and labels them from model_order, so deriving one from
# the other keeps them aligned by construction instead of relying on two separate
# sorts producing the same order.
model_order = stats_df['model'].tolist()

# Two panels: per-model confidence bars (left, wider) and correlation bars (right)
fig, (ax1, ax2) = plt.subplots(
    nrows=1, ncols=2, figsize=(18, 8), gridspec_kw={'width_ratios': [1.5, 1]}
)

bar_width = 0.35
r1 = np.arange(len(model_order))      # x positions of the "correct" bars
r2 = [x + bar_width for x in r1]      # "incorrect" bars sit just to the right

# One colour per model, looked up in display order
bar_colors = [stats_df[stats_df['model'] == name]['color'].values[0] for name in model_order]

# Mean confidence with +/- 1 std error bars: solid = correct, hatched = incorrect
correct_bars = ax1.bar(
    r1, stats_df['correct_mean'], width=bar_width,
    yerr=stats_df['correct_std'], capsize=5,
    color=bar_colors, alpha=0.7, label='Correct Predictions',
)
incorrect_bars = ax1.bar(
    r2, stats_df['incorrect_mean'], width=bar_width,
    yerr=stats_df['incorrect_std'], capsize=5,
    color=bar_colors, alpha=0.3, hatch='///', label='Incorrect Predictions',
)

# Overlay every individual prediction with a small horizontal jitter
for x_ok, x_bad, name in zip(r1, r2, model_order):
    rows = full_dataset[full_dataset['model'] == name]

    hits = rows[rows['correct']]
    ax1.scatter(x_ok + np.random.normal(0, 0.05, size=len(hits)), hits['confidence'],
                color=hits['color'].iloc[0], alpha=0.4, s=20)

    misses = rows[~rows['correct']]
    ax1.scatter(x_bad + np.random.normal(0, 0.05, size=len(misses)), misses['confidence'],
                color=misses['color'].iloc[0], alpha=0.4, s=20, marker='x')

# Axis labels, ticks (centred between each bar pair) and legend
label_style = dict(fontsize=12, fontweight='bold')
ax1.set_xlabel('Model', **label_style)
ax1.set_ylabel('Confidence Score', **label_style)
ax1.set_title('LLM Confidence Scores for Correct vs. Incorrect Predictions', fontsize=14, fontweight='bold')
group_centres = [pos + bar_width / 2 for pos in range(len(model_order))]
ax1.set_xticks(group_centres)
ax1.set_xticklabels(['\n'.join(name.split(' ')) for name in model_order], rotation=0)
ax1.set_ylim([0, 1.1])
ax1.legend(loc='upper right')

# Correlation coefficient printed above each model's bar pair
for centre, name in zip(group_centres, model_order):
    corr = stats_df[stats_df['model'] == name]['correlation'].values[0]
    ax1.text(centre, 1.03, f'r = {corr:.2f}', ha='center', va='bottom', fontweight='bold')

# Right panel: confidence/correctness correlation per model (same order as model_order)
model_positions = range(len(model_order))
correlation_bars = ax2.bar(model_positions,
                          stats_df['correlation'],
                          color=stats_df['color'])

# Error bars are seeded random placeholders, NOT real confidence intervals.
# With real data, compute actual intervals (e.g. bootstrap) instead.
np.random.seed(42)
corr_errors = np.random.uniform(0.03, 0.08, size=len(model_order))
ax2.errorbar(model_positions, stats_df['correlation'], yerr=corr_errors,
             fmt='none', color='black', capsize=5)

# Customize correlation chart
ax2.set_xlabel('Model', fontsize=12, fontweight='bold')
ax2.set_ylabel('Correlation Coefficient (r)', fontsize=12, fontweight='bold')
ax2.set_title('Confidence-Accuracy Correlation by Model', fontsize=14, fontweight='bold')
ax2.set_xticks(model_positions)
# Full model names: keeping only the first word labelled both
# "Claude 3.5 Sonnet" and "Claude 3.7 Sonnet-Thinking" as "Claude".
ax2.set_xticklabels(model_order, rotation=45, ha='right')
ax2.set_ylim([0, 1.0])

# Add values above bars
for i, v in enumerate(stats_df['correlation']):
    ax2.text(i, v + 0.03, f'{v:.2f}', ha='center', fontsize=10, fontweight='bold')

# Add horizontal lines to mark correlation strength thresholds
ax2.axhline(y=0.7, color='green', linestyle='--', alpha=0.5, linewidth=1)
ax2.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5, linewidth=1)
ax2.text(len(model_order)-1, 0.72, 'Strong', ha='right', va='bottom', color='green', fontsize=10)
ax2.text(len(model_order)-1, 0.52, 'Moderate', ha='right', va='bottom', color='orange', fontsize=10)

# Add annotation for reasoning models vs. standard
ax2.annotate('Reasoning-specialized models', xy=(1.5, 0.85), xytext=(2, 0.9),
            arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8),
            ha='center', va='center', fontsize=11, fontweight='bold')

# Point "Standard LLM" at that model's actual bar instead of a hard-coded (5, 0.61):
# bar order and heights depend on the generated correlations.
# NOTE(review): assumes "Claude 3.5 Sonnet" is the model meant by "Standard LLM" - confirm.
standard_model = "Claude 3.5 Sonnet"
if standard_model in model_order:
    std_x = model_order.index(standard_model)
    std_r = stats_df.loc[stats_df['model'] == standard_model, 'correlation'].values[0]
    ax2.annotate('Standard LLM', xy=(std_x, std_r),
                 xytext=(std_x - 0.5 if std_x > 0 else std_x + 0.5, max(std_r - 0.16, 0.05)),
                 arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8),
                 ha='center', va='center', fontsize=11, fontweight='bold')

# Lay out both panels, then reserve a strip at the bottom for the caption
fig.tight_layout()
fig.subplots_adjust(bottom=0.15)

caption = "Figure 7: Comparison of LLM confidence scores across different models. Left: Mean confidence scores for correct (solid) and incorrect (hatched) predictions, with individual predictions shown as scattered points. Right: Correlation coefficients between confidence scores and actual prediction accuracy, showing reasoning-specialized models (Claude 3.7 Sonnet-Thinking, DeepSeek-R1, etc.) demonstrate stronger confidence-accuracy correlation than standard LLMs."
fig.text(0.5, 0.01, caption, ha='center', fontsize=10, style='italic', wrap=True)

# Export at print resolution, then display
fig.savefig('llm_confidence_correlation.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Patch
import matplotlib.gridspec as gridspec

# Shared styling for the DeepSeek-R1 per-condition figure
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("deep")
plt.rcParams.update({'font.size': 12})

# Experimental conditions in display order. Each entry holds the full name (used
# for lookups and the bottom panel), a hex colour, and the short tick label used
# on the top panels.
experimental_conditions = [
    {"name": name, "color": color, "short_name": short}
    for name, color, short in (
        ("Simulated Data", "#10B981", "Sim"),
        ("Simulated Data + Wrong Guess", "#3B82F6", "Sim+WG"),
        ("Simulated Data + Noise", "#6366F1", "Sim+Noise"),
        ("Experimental Data", "#F59E0B", "Exp"),
        ("Experimental Data + Wrong Guess", "#EC4899", "Exp+WG"),
        ("Experimental Data d4", "#8B5CF6", "Exp-d4"),
    )
]

# Function to generate dummy confidence data for DeepSeek-R1 across different conditions
def generate_dummy_data(conditions, n_samples=30, seed=42):
    """
    Generate dummy DeepSeek-R1 confidence scores for each experimental condition.

    NOTE: the returned 'correlation' values are fixed placeholders taken from
    ``condition_params`` below. They are NOT computed from the generated scores.

    Args:
        conditions: Iterable of dicts with the keys 'name', 'short_name' and
            'color'. Each 'name' must be a key of ``condition_params``.
        n_samples: Number of simulated predictions per condition.
        seed: Seed for NumPy's global RNG (default 42, as before).

    Returns:
        (combined_df, correlations_df):
          - combined_df has one row per simulated prediction, with columns
            condition / short_name / confidence / correct / color and a fresh
            0..n-1 index.
          - correlations_df has one row per condition, with columns
            condition / short_name / correlation / color / accuracy.
        Both frames are empty (with those columns) if ``conditions`` is empty.

    Raises:
        KeyError: if a condition name has no dummy parameters defined.
    """
    # Per-condition dummy parameters: accuracy, placeholder correlation, and the
    # (mean, std) of the normal distributions used for the confidence scores
    condition_params = {
        "Simulated Data": {"acc": 0.94, "corr": 0.82, "correct_conf": (0.80, 0.10), "incorrect_conf": (0.45, 0.15)},
        "Simulated Data + Wrong Guess": {"acc": 0.85, "corr": 0.76, "correct_conf": (0.75, 0.12), "incorrect_conf": (0.50, 0.15)},
        "Simulated Data + Noise": {"acc": 0.85, "corr": 0.74, "correct_conf": (0.70, 0.15), "incorrect_conf": (0.48, 0.18)},
        "Experimental Data": {"acc": 0.68, "corr": 0.69, "correct_conf": (0.68, 0.15), "incorrect_conf": (0.42, 0.18)},
        "Experimental Data + Wrong Guess": {"acc": 0.24, "corr": 0.58, "correct_conf": (0.60, 0.18), "incorrect_conf": (0.45, 0.20)},
        "Experimental Data d4": {"acc": 0.65, "corr": 0.66, "correct_conf": (0.65, 0.18), "incorrect_conf": (0.40, 0.20)}
    }
    data_columns = ['condition', 'short_name', 'confidence', 'correct', 'color']
    corr_columns = ['condition', 'short_name', 'correlation', 'color', 'accuracy']

    # Global seed, so the later jitter drawn from np.random stays reproducible too
    np.random.seed(seed)

    all_data = []
    correlations = []
    for condition in conditions:
        name = condition["name"]
        if name not in condition_params:
            raise KeyError(f"No dummy parameters defined for condition {name!r}")
        params = condition_params[name]

        # Calculate number of correct and incorrect predictions based on accuracy
        n_correct = int(n_samples * params["acc"])
        n_incorrect = n_samples - n_correct

        # Confidence scores, clipped to the valid [0, 1] range
        correct_conf_mean, correct_conf_std = params["correct_conf"]
        correct_confidence = np.clip(
            np.random.normal(loc=correct_conf_mean, scale=correct_conf_std, size=n_correct),
            0.0, 1.0
        )
        incorrect_conf_mean, incorrect_conf_std = params["incorrect_conf"]
        incorrect_confidence = np.clip(
            np.random.normal(loc=incorrect_conf_mean, scale=incorrect_conf_std, size=n_incorrect),
            0.0, 1.0
        )

        # Correct rows first, then incorrect rows
        for scores, is_correct in ((correct_confidence, True), (incorrect_confidence, False)):
            all_data.append(pd.DataFrame({
                'condition': name,
                'short_name': condition["short_name"],
                'confidence': scores,
                'correct': is_correct,
                'color': condition["color"]
            }))

        correlations.append({
            'condition': name,
            'short_name': condition["short_name"],
            'correlation': params["corr"],  # placeholder, see docstring
            'color': condition["color"],
            'accuracy': params["acc"]
        })

    if not all_data:
        return pd.DataFrame(columns=data_columns), pd.DataFrame(columns=corr_columns)

    # ignore_index avoids duplicate row labels across the concatenated pieces
    combined_df = pd.concat(all_data, ignore_index=True)
    correlations_df = pd.DataFrame(correlations)

    return combined_df, correlations_df

# Generate dummy data
full_dataset, correlations_df = generate_dummy_data(experimental_conditions)

# Summary per condition: mean/std of confidence for correct and incorrect
# predictions (0 when a group is empty), plus the condition's correlation and accuracy
grouped_stats = []
for condition in experimental_conditions:
    name = condition["name"]
    subset = full_dataset[full_dataset['condition'] == name]
    meta = correlations_df[correlations_df['condition'] == name]

    entry = {'condition': name, 'short_name': condition["short_name"]}
    for prefix, group in (('correct', subset[subset['correct']]),
                          ('incorrect', subset[~subset['correct']])):
        scores = group['confidence']
        entry[f'{prefix}_mean'] = scores.mean() if len(group) > 0 else 0
        entry[f'{prefix}_std'] = scores.std() if len(group) > 0 else 0
    entry['color'] = condition["color"]
    entry['correlation'] = meta['correlation'].values[0]
    entry['accuracy'] = meta['accuracy'].values[0]
    grouped_stats.append(entry)

stats_df = pd.DataFrame(grouped_stats)

# Figure grid: two panels on top (2:1 widths) above a shorter accuracy strip (4:1 heights)
fig = plt.figure(figsize=(16, 10))
gs = gridspec.GridSpec(nrows=2, ncols=2, width_ratios=[2, 1], height_ratios=[4, 1])
ax1 = fig.add_subplot(gs[0, 0])  # main confidence panel, top left

bar_width = 0.35
r1 = np.arange(len(experimental_conditions))   # x of the correct-prediction bars
r2 = [x + bar_width for x in r1]               # incorrect-prediction bars, offset right
shared_bar_kwargs = dict(width=bar_width, capsize=5, color=stats_df['color'])

# Mean confidence with +/- 1 std error bars: solid = correct, hatched = incorrect
correct_bars = ax1.bar(r1, stats_df['correct_mean'], yerr=stats_df['correct_std'],
                       alpha=0.7, label='Correct Predictions', **shared_bar_kwargs)
incorrect_bars = ax1.bar(r2, stats_df['incorrect_mean'], yerr=stats_df['incorrect_std'],
                         alpha=0.3, hatch='///', label='Incorrect Predictions', **shared_bar_kwargs)

# Overlay individual predictions with a little horizontal jitter
for x_ok, x_bad, condition in zip(r1, r2, experimental_conditions):
    rows = full_dataset[full_dataset['condition'] == condition["name"]]

    hits = rows[rows['correct']]
    if len(hits):
        ax1.scatter(x_ok + np.random.normal(0, 0.05, size=len(hits)), hits['confidence'],
                    color=hits['color'].iloc[0], alpha=0.4, s=30)

    misses = rows[~rows['correct']]
    if len(misses):
        ax1.scatter(x_bad + np.random.normal(0, 0.05, size=len(misses)), misses['confidence'],
                    color=misses['color'].iloc[0], alpha=0.4, s=30, marker='x')

# Labels, ticks centred between each bar pair, and legend
axis_label_style = dict(fontsize=13, fontweight='bold')
ax1.set_xlabel('Experimental Condition', **axis_label_style)
ax1.set_ylabel('DeepSeek-R1 Confidence Score', **axis_label_style)
ax1.set_title('DeepSeek-R1 Confidence Scores Across Experimental Conditions', fontsize=16, fontweight='bold')
group_centres = [pos + bar_width / 2 for pos in range(len(experimental_conditions))]
ax1.set_xticks(group_centres)
ax1.set_xticklabels([c["short_name"] for c in experimental_conditions], rotation=0, fontsize=12)
ax1.set_ylim([0, 1.05])
ax1.legend(loc='upper right', fontsize=12)

# Correlation coefficient printed above each condition's bar pair
for centre, condition in zip(group_centres, experimental_conditions):
    corr = stats_df[stats_df['condition'] == condition["name"]]['correlation'].values[0]
    ax1.text(centre, 1.0, f'r = {corr:.2f}', ha='center', va='bottom', fontweight='bold')

# Correlation panel (top right)
ax2 = fig.add_subplot(gs[0, 1])
condition_positions = range(len(experimental_conditions))
correlation_bars = ax2.bar(condition_positions, stats_df['correlation'], color=stats_df['color'])

# Placeholder error bars: seeded random half-widths, NOT real confidence intervals
np.random.seed(42)
corr_errors = np.random.uniform(0.03, 0.08, size=len(experimental_conditions))
ax2.errorbar(condition_positions, stats_df['correlation'], yerr=corr_errors,
             fmt='none', color='black', capsize=5)

ax2.set_xlabel('Condition', fontsize=13, fontweight='bold')
ax2.set_ylabel('Confidence-Accuracy\nCorrelation (r)', fontsize=13, fontweight='bold')
ax2.set_title('DeepSeek-R1 Confidence-Accuracy\nCorrelation by Condition', fontsize=16, fontweight='bold')
ax2.set_xticks(condition_positions)
ax2.set_xticklabels([c["short_name"] for c in experimental_conditions], rotation=45, ha='right', fontsize=12)
ax2.set_ylim([0, 1.0])

# Value printed just above each bar
for pos, value in zip(condition_positions, stats_df['correlation']):
    ax2.text(pos, value + 0.03, f'{value:.2f}', ha='center', fontsize=10, fontweight='bold')

# Dashed reference lines for "strong" (0.7) and "moderate" (0.5) correlation
last_pos = len(experimental_conditions) - 1
ax2.axhline(y=0.7, color='green', linestyle='--', alpha=0.5, linewidth=1)
ax2.axhline(y=0.5, color='orange', linestyle='--', alpha=0.5, linewidth=1)
ax2.text(last_pos, 0.72, 'Strong', ha='right', va='bottom', color='green', fontsize=10)
ax2.text(last_pos, 0.52, 'Moderate', ha='right', va='bottom', color='orange', fontsize=10)

# Accuracy strip spanning the full bottom row, shown in percent
ax3 = fig.add_subplot(gs[1, :])
accuracy_bars = ax3.bar(range(len(experimental_conditions)),
                        stats_df['accuracy'] * 100,
                        color=stats_df['color'], alpha=0.8)

title_kwargs = dict(fontsize=14, fontweight='bold')
ax3.set_xlabel('Experimental Condition', fontsize=13, fontweight='bold')
ax3.set_ylabel('Accuracy (%)', fontsize=13, fontweight='bold')
ax3.set_title('DeepSeek-R1 Structure Prediction Accuracy by Condition', **title_kwargs)
ax3.set_xticks(range(len(experimental_conditions)))
ax3.set_xticklabels([c["name"] for c in experimental_conditions], rotation=45, ha='right', fontsize=11)
ax3.set_ylim([0, 100])

# Percentage label a few points above each bar
for pos, frac in enumerate(stats_df['accuracy']):
    ax3.text(pos, frac * 100 + 3, f'{frac*100:.1f}%', ha='center', fontsize=10, fontweight='bold')

# Dashed 50% reference line with its label at the left edge
ax3.axhline(y=50, color='red', linestyle='--', alpha=0.5, linewidth=1)
ax3.text(0, 52, 'Random Guess (50%)', ha='left', va='bottom', color='red', fontsize=10)

# Final layout with extra padding between panels
fig.tight_layout(h_pad=2, w_pad=3)

caption = "Figure X: Analysis of DeepSeek-R1's confidence scoring across experimental conditions. Top left: Mean confidence scores for correct (solid) and incorrect (hatched) predictions, with individual predictions shown as scattered points and correlation coefficients (r) displayed above each condition. Top right: Confidence-accuracy correlation coefficients with error bars. Bottom: Structure prediction accuracy across conditions, showing the model's performance degradation from simulated to experimental data, particularly with wrong initial guesses."
fig.text(0.5, 0.01, caption, ha='center', fontsize=11, style='italic', wrap=True)

# Export at print resolution, then display
fig.savefig('deepseek_confidence_analysis.png', dpi=300, bbox_inches='tight')
plt.show()

### V2 Dummy

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import glob
import os
from pathlib import Path

# Reference structures, keyed by sample id
def load_reference_data(csv_path):
    """
    Read the reference CSV and map each 'sample-id' to its 'SMILES' string.

    Args:
        csv_path: Path to a CSV file containing 'sample-id' and 'SMILES' columns.

    Returns:
        dict {sample_id: SMILES}. If a sample id appears more than once, the last
        row wins.
    """
    reference_table = pd.read_csv(csv_path)
    sample_ids = reference_table['sample-id']
    smiles = reference_table['SMILES']
    return dict(zip(sample_ids, smiles))

# Base sample id = everything before the first underscore
def get_base_sample_id(sample_id):
    """
    Return the part of ``sample_id`` before its first underscore.

    The whole id is returned when it has no underscore; an empty string is
    returned for an empty or falsy id (e.g. None).
    """
    if not sample_id:
        return ''
    base, _separator, _rest = sample_id.partition('_')
    return base

# Function to analyze a single JSON file
def classify_experiment_dir(dir_name):
    """
    Map a results-directory name to an experimental-condition label.

    Case-insensitive substring match against an ordered rule list. The more
    specific patterns ("sim+noise", "sim_aug", "exp_d1_aug", ...) are checked
    before the generic "sim" / "exp" fallbacks. Returns "Unknown" when no rule
    matches.
    """
    name = dir_name.lower()
    rules = (
        ("sim+noise", "Sim Data + Noise"),
        ("sim_aug", "Sim Data + Wrong Guess"),
        ("sim_d1_aug", "Sim Data + Wrong Guess"),
        ("sim", "Sim Data"),
        ("exp_d1_aug", "Exp Data + Wrong Guess"),
        ("exp_d4", "Exp Data d4"),
        ("exp", "Exp Data"),
    )
    return next((label for pattern, label in rules if pattern in name), "Unknown")


def analyze_json_file(file_path, reference_data):
    """
    Score DeepSeek-R1's candidate ranking for one per-sample JSON result file.

    ``molecule_data.sample_id`` is reduced to its base ID (the text before the
    first underscore) and looked up in ``reference_data`` to get the true
    SMILES. The candidates under
    ``final_analysis.llm_responses.deepseek.parsed_results.candidates`` are
    ranked by ``confidence_score`` (descending; a missing score counts as 0)
    and compared with the true SMILES.

    Args:
        file_path: Path to the JSON file. The name of its parent directory
            selects the experimental condition (see ``classify_experiment_dir``).
        reference_data: Dict mapping base sample IDs to true SMILES.

    Returns:
        Dict with keys ``sample_id``, ``true_smiles``, ``experiment``,
        ``top_candidate_smiles``, ``top_confidence``, ``is_correct``,
        ``correct_position`` (1-based rank of the true structure, or None),
        ``correct_confidence`` (or None) and ``all_confidences`` (one dict per
        ranked candidate with ``position``, ``confidence``, ``is_true``,
        ``smiles``).

        Returns None in these cases:
          - the file has no sample ID;
          - there is no usable reference SMILES;
          - there are no DeepSeek candidates;
          - any exception is raised while reading or analyzing the file (the
            error is printed).
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        sample_id = data.get("molecule_data", {}).get("sample_id")
        if not sample_id:
            return None

        true_smiles = reference_data.get(get_base_sample_id(sample_id))
        # A blank SMILES cell in the reference CSV arrives from pandas as NaN,
        # not None. Without this check the sample would be silently scored as
        # an incorrect prediction.
        if true_smiles is None or pd.isna(true_smiles):
            return None

        deepseek_results = data.get("final_analysis", {}).get("llm_responses", {}).get("deepseek", {})
        if not deepseek_results or "parsed_results" not in deepseek_results:
            return None

        candidates = deepseek_results["parsed_results"].get("candidates", [])
        if not candidates:
            return None

        # Highest confidence first. sorted() is stable even with reverse=True,
        # so tied candidates keep the LLM's original order.
        ranked = sorted(candidates, key=lambda cand: cand.get("confidence_score", 0), reverse=True)

        # NOTE(review): correctness is exact string equality of SMILES. An
        # equivalent structure written differently (non-canonical SMILES)
        # counts as wrong.
        all_confidences = [
            {
                "position": rank,
                "confidence": cand.get("confidence_score", 0),
                "is_true": cand.get("smiles") == true_smiles,
                "smiles": cand.get("smiles"),
            }
            for rank, cand in enumerate(ranked, 1)
        ]
        top = all_confidences[0]
        # First (i.e. highest-ranked) candidate matching the true structure, if any.
        match = next((entry for entry in all_confidences if entry["is_true"]), None)

        return {
            "sample_id": sample_id,
            "true_smiles": true_smiles,
            "experiment": classify_experiment_dir(os.path.basename(os.path.dirname(file_path))),
            "top_candidate_smiles": top["smiles"],
            "top_confidence": top["confidence"],
            "is_correct": top["is_true"],
            "correct_position": match["position"] if match else None,
            "correct_confidence": match["confidence"] if match else None,
            "all_confidences": all_confidences,
        }

    except Exception as e:
        # Best effort: report and skip files that cannot be read or analyzed.
        print(f"Error processing {file_path}: {str(e)}")
        return None

# Function to process a directory of JSON files
def process_directory(json_dir, reference_csv):
    """
    Analyze every ``*.json`` file directly inside ``json_dir`` (not recursive).

    Args:
        json_dir: Directory containing the per-sample JSON result files.
        reference_csv: Path to the reference CSV, read with ``load_reference_data``.

    Returns:
        Tuple ``(results_df, confidences_df)``:
          - ``results_df``: one row per successfully analyzed file (the dicts
            returned by ``analyze_json_file``).
          - ``confidences_df``: one row per ranked candidate of those files,
            tagged with its ``sample_id`` and ``experiment``.
        Both are empty DataFrames when no file could be analyzed.
    """
    reference_data = load_reference_data(reference_csv)

    # glob order is filesystem-dependent; sort so row order is reproducible.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    results = []
    candidate_rows = []
    for file_path in json_files:
        result = analyze_json_file(file_path, reference_data)
        if not result:
            continue
        results.append(result)

        # Flatten this file's per-candidate scores into long-format rows.
        candidate_rows.extend(
            {
                "sample_id": result["sample_id"],
                "experiment": result["experiment"],
                "position": conf["position"],
                "confidence": conf["confidence"],
                "is_true": conf["is_true"],
                "smiles": conf["smiles"],
            }
            for conf in result["all_confidences"]
        )

    return pd.DataFrame(results), pd.DataFrame(candidate_rows)

# Function to create violin plots of confidence distributions
def plot_confidence_distributions(results_df, output_file=None):
    """
    Split violin plot of the top candidate's confidence per experimental condition.

    Correct and incorrect top picks are drawn on opposite halves of each
    violin, with the individual predictions overlaid as points. Per-condition
    counts are printed near the bottom of the axes.

    Args:
        results_df: DataFrame with ``experiment``, ``is_correct`` and
            ``top_confidence`` columns (as produced by ``process_directory``).
        output_file: Path to save the figure to (PNG at 300 dpi). If None, the
            figure is shown with ``plt.show()`` instead.
    """
    if results_df.empty:
        print("No predictions to plot.")
        return

    # Keep only rows with a definite True/False verdict, labelled for seaborn.
    known = results_df[results_df["is_correct"].isin([True, False])]
    plot_df = pd.DataFrame({
        "experiment": known["experiment"],
        "confidence": known["top_confidence"],
        "prediction": np.where(known["is_correct"].astype(bool), "Correct", "Incorrect"),
    })
    if plot_df.empty:
        print("No predictions to plot.")
        return

    # An explicit category order shared by both plots and the text loop below
    # keeps each x position, its violin and its count label aligned.
    experiments = list(plot_df["experiment"].unique())
    hue_order = ["Correct", "Incorrect"]

    fig, ax = plt.subplots(figsize=(14, 8))

    sns.violinplot(x="experiment", y="confidence", hue="prediction", data=plot_df,
                   order=experiments, hue_order=hue_order, split=True, inner="quartile",
                   palette={"Correct": "mediumseagreen", "Incorrect": "tomato"}, ax=ax)
    sns.stripplot(x="experiment", y="confidence", hue="prediction", data=plot_df,
                  order=experiments, hue_order=hue_order, dodge=True,
                  alpha=0.3, size=4, linewidth=1,
                  palette={"Correct": "darkgreen", "Incorrect": "darkred"}, ax=ax)

    ax.set_title("DeepSeek-R1 Confidence Scores by Prediction Correctness", fontsize=16)
    ax.set_xlabel("Experimental Condition", fontsize=14)
    ax.set_ylabel("Confidence Score", fontsize=14)
    ax.set_ylim(0, 1.05)
    plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Both plots add legend entries; keep only the violin pair.
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[:2], labels[:2], title="Prediction", fontsize=12)

    # Counts per condition (total >= 1 because each condition comes from plot_df).
    for x_idx, exp in enumerate(experiments):
        preds = plot_df.loc[plot_df["experiment"] == exp, "prediction"]
        n_correct = int((preds == "Correct").sum())
        n_incorrect = int((preds == "Incorrect").sum())
        total = n_correct + n_incorrect
        text = f"n={total}\nCorrect: {n_correct} ({n_correct/total*100:.1f}%)\nIncorrect: {n_incorrect} ({n_incorrect/total*100:.1f}%)"
        ax.text(x_idx, 0.05, text, ha='center', fontsize=9)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    else:
        plt.show()

# Function to create a histogram comparing confidence scores
def plot_confidence_histogram(results_df, output_file=None):
    """
    Overlaid histograms of top-candidate confidence for correct vs. incorrect picks.

    A text box reports both means and their difference (a mean of 0 is shown
    for an empty group).

    Args:
        results_df: DataFrame with ``is_correct`` and ``top_confidence`` columns.
        output_file: Where to save the figure (PNG, 300 dpi). When None, the
            figure is displayed with ``plt.show()`` instead.
    """
    fig, ax = plt.subplots(figsize=(14, 8))

    correct_conf = results_df.loc[results_df["is_correct"] == True, "top_confidence"]
    incorrect_conf = results_df.loc[results_df["is_correct"] == False, "top_confidence"]

    # 21 edges -> 20 equal-width bins covering [0, 1]
    edges = np.linspace(0, 1, 21)
    ax.hist(correct_conf, bins=edges, alpha=0.7, color='mediumseagreen',
            label=f'Correct Predictions (n={len(correct_conf)})')
    ax.hist(incorrect_conf, bins=edges, alpha=0.7, color='tomato',
            label=f'Incorrect Predictions (n={len(incorrect_conf)})')

    ax.set_title("Distribution of DeepSeek-R1 Confidence Scores by Prediction Correctness", fontsize=16)
    ax.set_xlabel("Confidence Score", fontsize=14)
    ax.set_ylabel("Frequency", fontsize=14)
    ax.grid(linestyle='--', alpha=0.7)
    ax.legend(fontsize=12)

    mean_correct = 0 if correct_conf.empty else correct_conf.mean()
    mean_incorrect = 0 if incorrect_conf.empty else incorrect_conf.mean()
    summary = (
        f"Mean confidence when correct: {mean_correct:.3f}\n"
        f"Mean confidence when incorrect: {mean_incorrect:.3f}\n"
        f"Difference: {mean_correct - mean_incorrect:.3f}"
    )
    ax.text(0.05, 0.95, summary, transform=ax.transAxes,
            verticalalignment='top', fontsize=12,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()

    if not output_file:
        plt.show()
    else:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')

# Function to create a boxplot showing confidence by position

def plot_confidence_by_position(confidences_df, output_file=None, max_position=5):
    """
    Box-and-strip plot of candidate confidence scores by rank position.

    Candidates are split by whether they are the true structure, showing how
    confidence is distributed over the top ``max_position`` ranks.

    Args:
        confidences_df: DataFrame with ``position`` (1-based rank),
            ``confidence`` and ``is_true`` columns, as produced by
            ``process_directory``.
        output_file: Path to save the figure to (PNG at 300 dpi). If None, the
            figure is shown with ``plt.show()`` instead.
        max_position: Highest rank position to include (default 5).
    """
    positions = list(range(1, max_position + 1))
    # Fixed hue order: without it seaborn orders hue levels by first
    # appearance, and when the first row is an incorrect candidate the
    # hard-coded legend labels below end up swapped.
    hue_order = ["True", "False"]

    df = confidences_df[confidences_df["position"] <= max_position].copy()
    # Convert boolean is_true to string to avoid palette issues
    df['is_true'] = df['is_true'].astype(str)

    fig, ax = plt.subplots(figsize=(12, 7))

    # `order` pins x index k to positions[k], even when a rank has no rows,
    # so the count labels below line up with their boxes.
    sns.boxplot(x="position", y="confidence", hue="is_true", data=df,
                order=positions, hue_order=hue_order,
                palette={"True": "mediumseagreen", "False": "tomato"}, ax=ax)
    sns.stripplot(x="position", y="confidence", hue="is_true", data=df,
                  order=positions, hue_order=hue_order,
                  dodge=True, alpha=0.3, size=4, linewidth=1,
                  palette={"True": "darkgreen", "False": "darkred"}, ax=ax)

    ax.set_title("DeepSeek-R1 Confidence Scores by Candidate Position", fontsize=16)
    ax.set_xlabel("Candidate Position (by confidence ranking)", fontsize=14)
    ax.set_ylabel("Confidence Score", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Keep only the boxplot's legend pair; it follows hue_order ("True", "False").
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles[:2], ["Correct Structure", "Incorrect Structure"], title="Structure", fontsize=12)

    for x_idx, pos in enumerate(positions):
        flags = df.loc[df["position"] == pos, "is_true"]
        n_true = int((flags == "True").sum())
        n_false = int((flags == "False").sum())
        text = f"n={n_true + n_false}\nCorrect: {n_true}\nIncorrect: {n_false}"
        ax.text(x_idx, 0.05, text, ha='center', fontsize=9)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    else:
        plt.show()

# Function to calculate and plot ROC curve for different confidence thresholds
def mann_whitney_auc(scores, is_positive):
    """
    Exact ROC AUC of ``scores`` for the binary labels ``is_positive``.

    Uses the Mann-Whitney U statistic: the probability that a randomly chosen
    positive scores higher than a randomly chosen negative, with ties counted
    as one half. The result does not depend on any threshold grid.

    Returns:
        AUC in [0, 1], or NaN when either class is empty (AUC is undefined).
    """
    scores = np.asarray(scores, dtype=float)
    is_positive = np.asarray(is_positive, dtype=bool)
    n_pos = int(is_positive.sum())
    n_neg = int(is_positive.size - n_pos)
    if n_pos == 0 or n_neg == 0:
        return float("nan")
    # Average ranks give tied scores half credit.
    ranks = pd.Series(scores).rank(method="average").to_numpy()
    u_stat = ranks[is_positive].sum() - n_pos * (n_pos + 1) / 2
    return float(u_stat / (n_pos * n_neg))


def plot_roc_curve(results_df, output_file=None):
    """
    ROC curve of top-candidate confidence as a predictor of a correct answer.

    A prediction is "accepted" at threshold t when its top confidence is >= t.
    The curve is traced on the thresholds 0.00, 0.01, ..., 1.00.

    The AUC in the title is the exact rank-based value (see
    ``mann_whitney_auc``); it is shown as "n/a" when all predictions are
    correct or all are incorrect. Thresholds 0.3, 0.5, 0.7 and 0.9 are marked
    on the curve and summarized in a text box (counts above the threshold and
    precision).

    Args:
        results_df: DataFrame with ``is_correct`` and ``top_confidence`` columns.
        output_file: Path to save the figure to (PNG at 300 dpi). If None, the
            figure is shown with ``plt.show()`` instead.
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    scores = results_df["top_confidence"].to_numpy(dtype=float)
    correct = (results_df["is_correct"] == True).to_numpy()
    incorrect = (results_df["is_correct"] == False).to_numpy()
    positives = int(correct.sum())
    negatives = int(incorrect.sum())

    thresholds = np.linspace(0, 1, 101)  # 101 points from 0 to 1
    # above[i, j]: prediction j has confidence >= thresholds[i] (NaN never does).
    above = scores[np.newaxis, :] >= thresholds[:, np.newaxis]
    tp = (above & correct).sum(axis=1)
    fp = (above & incorrect).sum(axis=1)
    tpr = tp / positives if positives > 0 else np.zeros(len(thresholds))  # sensitivity
    fpr = fp / negatives if negatives > 0 else np.zeros(len(thresholds))  # 1 - specificity

    ax.plot(fpr, tpr, 'b-', linewidth=2)
    ax.plot([0, 1], [0, 1], 'r--', linewidth=1.5)  # random-classifier diagonal

    # Integrating the grid curve in threshold order would walk over *decreasing*
    # FPR (negative area) and miss the segment to (0, 0). Use the exact rank
    # statistic instead. NaN confidences never pass a threshold above, so they
    # rank lowest here too.
    labelled = correct | incorrect
    auc_scores = np.where(np.isnan(scores), -np.inf, scores)
    auc = mann_whitney_auc(auc_scores[labelled], correct[labelled])
    auc_text = f"{auc:.3f}" if np.isfinite(auc) else "n/a"

    ax.set_title(f"ROC Curve for DeepSeek-R1 Confidence Scores\nAUC = {auc_text}", fontsize=16)
    ax.set_xlabel("False Positive Rate (1 - Specificity)", fontsize=14)
    ax.set_ylabel("True Positive Rate (Sensitivity)", fontsize=14)
    ax.grid(linestyle='--', alpha=0.7)

    summary_lines = ["Threshold statistics:"]
    for t in [0.3, 0.5, 0.7, 0.9]:
        # Nearest grid index; int(t * 100) can truncate (e.g. 0.29 * 100 -> 28).
        idx = int(np.argmin(np.abs(thresholds - t)))
        ax.plot(fpr[idx], tpr[idx], 'ko', markersize=6)
        ax.text(fpr[idx] + 0.02, tpr[idx] - 0.02, f"t={t}", fontsize=10)

        passed = scores >= t
        correct_above = int((passed & correct).sum())
        incorrect_above = int((passed & incorrect).sum())
        total_above = correct_above + incorrect_above
        precision = correct_above / total_above if total_above > 0 else 0
        # Plain Python ints, so counts print as "12" rather than "12.0".
        summary_lines.append(
            f"t={t:.1f}: {correct_above} correct, {incorrect_above} incorrect above threshold "
            f"(precision: {precision:.2f})"
        )

    ax.text(0.05, 0.05, "\n".join(summary_lines), transform=ax.transAxes,
            verticalalignment='bottom', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    else:
        plt.show()

# Function to create an experiment-specific confidence visualization
def plot_experiment_confidence(results_df, output_file=None):
    """
    One histogram panel per experimental condition, comparing top-candidate
    confidence for correct vs. incorrect predictions.

    Each panel also shows:
      - a "suggested threshold": the midpoint of the two mean confidences, or
        0.5 if either group is empty;
      - the accuracy of using that threshold to separate correct from
        incorrect predictions;
      - a Cohen's-d-style separation score, when both groups are present and
        their pooled spread is positive.

    Panels are laid out three per row, with at least two rows. Conditions are
    no longer dropped when there are more than six.

    Args:
        results_df: DataFrame with ``experiment``, ``is_correct`` and
            ``top_confidence`` columns.
        output_file: Path to save the figure to (PNG at 300 dpi). If None, the
            figure is shown with ``plt.show()`` instead.
    """
    experiment_colors = {
        "Sim Data": "#10B981",
        "Sim Data + Wrong Guess": "#3B82F6",
        "Sim Data + Noise": "#6366F1",
        "Exp Data": "#F59E0B",
        "Exp Data + Wrong Guess": "#EC4899",
        "Exp Data d4": "#8B5CF6"
    }
    # Neutral grey for conditions without a dedicated color (e.g. "Unknown");
    # indexing the dict directly would raise KeyError for them.
    fallback_color = "#6B7280"

    experiments = list(results_df["experiment"].unique())
    ncols = 3
    nrows = max(2, -(-len(experiments) // ncols))  # ceil division; 2x3 for <= 6 conditions
    fig, axes = plt.subplots(nrows, ncols, figsize=(18, 6 * nrows), sharex=True, sharey=True)
    axes = axes.flatten()

    bins = np.linspace(0, 1, 21)  # 20 equal-width bins over [0, 1]

    for ax, exp in zip(axes, experiments):
        exp_data = results_df[results_df["experiment"] == exp]
        correct_preds = exp_data.loc[exp_data["is_correct"] == True, "top_confidence"]
        incorrect_preds = exp_data.loc[exp_data["is_correct"] == False, "top_confidence"]
        color = experiment_colors.get(exp, fallback_color)

        ax.hist(correct_preds, bins=bins, alpha=0.7, color=color,
                label=f'Correct (n={len(correct_preds)})')
        ax.hist(incorrect_preds, bins=bins, alpha=0.5, hatch='///', color=color,
                label=f'Incorrect (n={len(incorrect_preds)})')

        mean_correct = correct_preds.mean() if len(correct_preds) > 0 else 0
        mean_incorrect = incorrect_preds.mean() if len(incorrect_preds) > 0 else 0

        # Simple heuristic: midpoint of the two means (0.5 if a group is empty).
        if len(correct_preds) > 0 and len(incorrect_preds) > 0:
            suggested_threshold = (mean_correct + mean_incorrect) / 2
        else:
            suggested_threshold = 0.5
        ax.axvline(x=suggested_threshold, color='red', linestyle='--', linewidth=1.5)

        # Accuracy of "confidence >= threshold means correct" as a classifier.
        correct_above = int((correct_preds >= suggested_threshold).sum())
        incorrect_below = int((incorrect_preds < suggested_threshold).sum())
        total = len(exp_data)
        accuracy = (correct_above + incorrect_below) / total if total > 0 else 0

        stats_text = (
            f"Mean conf. (correct): {mean_correct:.2f}\n"
            f"Mean conf. (incorrect): {mean_incorrect:.2f}\n"
            f"Suggested threshold: {suggested_threshold:.2f}\n"
            f"Acc @ threshold: {accuracy:.2f}"
        )
        ax.text(0.05, 0.95, stats_text, transform=ax.transAxes,
                verticalalignment='top', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        ax.set_title(exp, fontsize=14)
        ax.grid(linestyle='--', alpha=0.7)
        ax.legend(fontsize=10)

        if len(correct_preds) > 0 and len(incorrect_preds) > 0:
            # |mean difference| over the root-mean of the two sample variances
            # (unweighted pooling). A single-element group has NaN std, which
            # the `> 0` check skips.
            pooled_std = np.sqrt((correct_preds.std() ** 2 + incorrect_preds.std() ** 2) / 2)
            if pooled_std > 0:
                cohens_d = abs(mean_correct - mean_incorrect) / pooled_std
                ax.text(0.95, 0.05, f"Separation score: {cohens_d:.2f}", transform=ax.transAxes,
                        horizontalalignment='right', verticalalignment='bottom', fontsize=10,
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Common axis labels for the whole grid
    fig.text(0.5, 0.02, 'Confidence Score', ha='center', fontsize=14)
    fig.text(0.02, 0.5, 'Frequency', va='center', rotation='vertical', fontsize=14)

    fig.suptitle("DeepSeek-R1 Confidence Score Distributions by Experimental Condition", fontsize=18)

    fig.tight_layout(rect=[0.03, 0.03, 1, 0.97])

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    else:
        plt.show()

# Main function to run the analysis

# Default inputs for the real-data run; main() accepts overrides for both.
RESULTS_ROOT = (
    "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/"
    "LLM_Structure_Elucidator/_temp_folder/intermediate_results"
)
DEFAULT_JSON_DIRS = tuple(
    f"{RESULTS_ROOT}/{run_dir}"
    for run_dir in (
        "_run_1_sim_finished_clean",
        "_run_2_sim_aug_finished_clean",
        "_run_3_sim+noise_finished_clean",
        "_run_4_exp_d1_finished_clean",
        "_run_5_exp_d1_aug_finished_clean",
        "_run_6_exp_d4_finished_clean",
    )
)
DEFAULT_REFERENCE_CSV = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"


def main(use_dummy_data=False, json_dirs=None, reference_csv=None):
    """
    Run the complete DeepSeek-R1 confidence analysis and save all figures.

    Args:
        use_dummy_data: If True, use ``create_dummy_data()`` instead of
            reading JSON result files.
        json_dirs: Iterable of result directories to analyze. Defaults to
            ``DEFAULT_JSON_DIRS``. Ignored when ``use_dummy_data`` is True.
        reference_csv: Reference CSV with the true SMILES. Defaults to
            ``DEFAULT_REFERENCE_CSV``. Ignored when ``use_dummy_data`` is True.

    The figures are written as PNG files with fixed names relative to the
    current working directory. Prints a message and returns early when there
    are no valid results.
    """
    if use_dummy_data:
        print("Generating dummy data for visualization testing...")
        combined_results, combined_confidences = create_dummy_data()
    else:
        json_dirs = DEFAULT_JSON_DIRS if json_dirs is None else json_dirs
        reference_csv = DEFAULT_REFERENCE_CSV if reference_csv is None else reference_csv

        all_results = []
        all_confidences = []
        for json_dir in json_dirs:
            print(f"Processing {json_dir}...")
            results_df, confidences_df = process_directory(json_dir, reference_csv)
            if not results_df.empty:
                all_results.append(results_df)
                all_confidences.append(confidences_df)

        # ignore_index: each directory's frame starts at 0, so avoid duplicate row labels.
        combined_results = pd.concat(all_results, ignore_index=True) if all_results else pd.DataFrame()
        combined_confidences = pd.concat(all_confidences, ignore_index=True) if all_confidences else pd.DataFrame()

    if combined_results.empty:
        print("No valid results found.")
        return

    print(f"Analyzing {len(combined_results)} items...")

    print("Generating confidence distribution plots...")
    plot_confidence_distributions(combined_results, "deepseek_confidence_distributions.png")

    print("Generating confidence histogram...")
    plot_confidence_histogram(combined_results, "deepseek_confidence_histogram.png")

    print("Generating confidence by position plot...")
    plot_confidence_by_position(combined_confidences, "deepseek_confidence_by_position.png")

    print("Generating ROC curve...")
    plot_roc_curve(combined_results, "deepseek_confidence_roc.png")

    print("Generating experimental condition analysis...")
    plot_experiment_confidence(combined_results, "deepseek_experiment_confidence.png")

    print("Analysis complete. All plots have been saved.")

# Function to create dummy data for testing visualizations
def create_dummy_data(seed=42, n_samples=50):
    """
    Create synthetic results for exercising the plotting functions without JSON files.

    For each of six experimental conditions, ``n_samples`` predictions are
    generated with a fixed accuracy. Confidences for correct and incorrect top
    picks are normally distributed and clipped to [0, 1]. Every sample gets
    five ranked candidates.

    Args:
        seed: Seed for a private ``np.random.RandomState``; the global NumPy
            RNG is left untouched. The draws equal those of the global sampler
            after ``np.random.seed(seed)``.
        n_samples: Number of samples per experimental condition.

    Returns:
        ``(results_df, confidences_df)`` with the columns ``process_directory``
        produces, except that ``results_df`` has no ``all_confidences`` column.
    """
    rng = np.random.RandomState(seed)

    experiments = [
        "Sim Data",
        "Sim Data + Wrong Guess",
        "Sim Data + Noise",
        "Exp Data",
        "Exp Data + Wrong Guess",
        "Exp Data d4"
    ]

    # Accuracy plus (mean, std) of the confidence for correct / incorrect top picks
    condition_params = {
        "Sim Data": {"acc": 0.94, "correct_conf": (0.80, 0.10), "incorrect_conf": (0.45, 0.15)},
        "Sim Data + Wrong Guess": {"acc": 0.85, "correct_conf": (0.75, 0.12), "incorrect_conf": (0.50, 0.15)},
        "Sim Data + Noise": {"acc": 0.85, "correct_conf": (0.70, 0.15), "incorrect_conf": (0.48, 0.18)},
        "Exp Data": {"acc": 0.68, "correct_conf": (0.68, 0.15), "incorrect_conf": (0.42, 0.18)},
        "Exp Data + Wrong Guess": {"acc": 0.24, "correct_conf": (0.60, 0.18), "incorrect_conf": (0.45, 0.20)},
        "Exp Data d4": {"acc": 0.65, "correct_conf": (0.65, 0.18), "incorrect_conf": (0.40, 0.20)}
    }

    results = []
    all_confidences = []

    for exp in experiments:
        params = condition_params[exp]
        n_correct = int(n_samples * params["acc"])
        n_incorrect = n_samples - n_correct

        correct_conf_mean, correct_conf_std = params["correct_conf"]
        correct_confidences = np.clip(
            rng.normal(loc=correct_conf_mean, scale=correct_conf_std, size=n_correct),
            0.0, 1.0
        )

        incorrect_conf_mean, incorrect_conf_std = params["incorrect_conf"]
        incorrect_confidences = np.clip(
            rng.normal(loc=incorrect_conf_mean, scale=incorrect_conf_std, size=n_incorrect),
            0.0, 1.0
        )

        base_id = exp.replace(" ", "_").lower()

        # Correct predictions: the true structure is ranked first.
        for i in range(n_correct):
            sample_id = f"{base_id}_{i+1}"
            true_smiles = f"C1CC{i}CC1"  # dummy SMILES
            pred_smiles = true_smiles

            results.append({
                "sample_id": sample_id,
                "true_smiles": true_smiles,
                "experiment": exp,
                "top_candidate_smiles": pred_smiles,
                "top_confidence": correct_confidences[i],
                "is_correct": True,
                "correct_position": 1,
                "correct_confidence": correct_confidences[i]
            })

            all_confidences.append({
                "sample_id": sample_id,
                "experiment": exp,
                "position": 1,
                "confidence": correct_confidences[i],
                "is_true": True,
                "smiles": true_smiles
            })

            # Positions 2-5 are incorrect alternatives with shrinking confidence.
            for pos in range(2, 6):
                conf = np.clip(correct_confidences[i] * (0.9 - 0.1 * pos), 0.05, 0.95)
                all_confidences.append({
                    "sample_id": sample_id,
                    "experiment": exp,
                    "position": pos,
                    "confidence": conf,
                    "is_true": False,
                    "smiles": f"C1CC{i}C{pos}C1"
                })

        # Incorrect predictions: the true structure is at rank 2-5 or absent.
        for i in range(n_incorrect):
            sample_id = f"{base_id}_{n_correct+i+1}"
            true_smiles = f"C1CC{n_correct+i}CC1"
            pred_smiles = f"C1CC{n_correct+i}NC1"  # differs from true

            correct_pos = rng.choice([2, 3, 4, 5, None], p=[0.3, 0.2, 0.1, 0.1, 0.3])

            results.append({
                "sample_id": sample_id,
                "true_smiles": true_smiles,
                "experiment": exp,
                "top_candidate_smiles": pred_smiles,
                "top_confidence": incorrect_confidences[i],
                "is_correct": False,
                "correct_position": correct_pos,
                "correct_confidence": 0.4 if correct_pos else None
            })

            all_confidences.append({
                "sample_id": sample_id,
                "experiment": exp,
                "position": 1,
                "confidence": incorrect_confidences[i],
                "is_true": False,
                "smiles": pred_smiles
            })

            # NOTE(review): the true candidate gets a fixed 0.4, which can exceed
            # the rank-1 confidence, so ranks are not strictly confidence-ordered.
            for pos in range(2, 6):
                is_true = (pos == correct_pos)
                conf = 0.4 if is_true else np.clip(incorrect_confidences[i] * (0.8 - 0.1 * pos), 0.05, 0.95)
                all_confidences.append({
                    "sample_id": sample_id,
                    "experiment": exp,
                    "position": pos,
                    "confidence": conf,
                    "is_true": is_true,
                    "smiles": true_smiles if is_true else f"C1CC{n_correct+i}N{pos}C1"
                })

    results_df = pd.DataFrame(results)
    confidences_df = pd.DataFrame(all_confidences)

    return results_df, confidences_df


# Example usage: runs after create_dummy_data is defined, so both modes work
# on a fresh kernel. Switch to main(use_dummy_data=True) to test the plots with
# synthetic data instead of the JSON results.
if __name__ == "__main__":
    main(use_dummy_data=False)

In [None]:
import json
import os
import glob

def _describe_candidate_list(candidates, found_suffix, not_list_prefix):
    """Print the length of a candidate list and a summary of its first entry."""
    if not isinstance(candidates, list):
        print(f"{not_list_prefix}: {type(candidates)}")
        return
    print(f"Found {len(candidates)} candidates{found_suffix}")
    if not candidates:
        return
    first_candidate = candidates[0]
    if not isinstance(first_candidate, dict):
        print(f"First candidate is not a dictionary: {type(first_candidate)}")
        return
    print(f"First candidate keys: {list(first_candidate.keys())}")
    # Check for SMILES and confidence
    if 'smiles' in first_candidate:
        print(f"First candidate SMILES: {first_candidate['smiles']}")
    if 'confidence_score' in first_candidate:
        print(f"First candidate confidence: {first_candidate['confidence_score']}")


def _describe_deepseek_value(value):
    """Summarise a value stored under a key whose name contains 'deepseek'."""
    if not isinstance(value, dict):
        print(f"Value is of type: {type(value)}")
        return
    print(f"Dictionary with keys: {list(value.keys())}")
    # Look for candidates or parsed_results (only the first matching form is reported)
    if 'candidates' in value:
        candidates = value['candidates']
        if isinstance(candidates, list):
            print(f"Found {len(candidates)} candidates")
            if candidates and isinstance(candidates[0], dict):
                print(f"First candidate keys: {list(candidates[0].keys())}")
    elif 'parsed_results' in value:
        parsed_results = value['parsed_results']
        if isinstance(parsed_results, dict) and 'candidates' in parsed_results:
            candidates = parsed_results['candidates']
            if isinstance(candidates, list):
                print(f"Found {len(candidates)} candidates in parsed_results")
    elif 'raw_response' in value:
        print("Contains raw_response which might include candidate information")


def inspect_json_structure(file_path):
    """
    Deeply inspect a JSON file to locate DeepSeek results and candidate information.

    Walks the whole JSON tree and records every location that looks like DeepSeek
    output:
    - keys whose name contains "deepseek";
    - ``candidates`` lists whose first entry is a dict with a ``smiles`` key;
    - ``parsed_results`` dicts that contain ``candidates``.

    A short summary of each location is printed to stdout. Nothing is returned, and
    any error while reading or inspecting the file is printed instead of raised.

    Args:
        file_path: Path to the JSON file
    """
    print(f"\nInspecting file: {os.path.basename(file_path)}")

    try:
        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Check sample ID (tolerate a non-dict top level or molecule_data section)
        molecule_data = data.get("molecule_data", {}) if isinstance(data, dict) else {}
        sample_id = molecule_data.get("sample_id") if isinstance(molecule_data, dict) else None
        if sample_id:
            print(f"Sample ID: {sample_id}")
        else:
            print("No sample_id found at expected path")

        print("Searching for DeepSeek data...")

        # (dotted path, value) pairs that look like DeepSeek output
        deepseek_paths = []

        def search_deepseek(obj, path=""):
            """Recursively collect locations that look like DeepSeek output."""
            if isinstance(obj, dict):
                for k, v in obj.items():
                    new_path = f"{path}.{k}" if path else k

                    # Key name mentions deepseek
                    if 'deepseek' in str(k).lower():
                        deepseek_paths.append((new_path, v))

                    # Candidate list: the isinstance check is required because
                    # `'smiles' in v[0]` does a substring test on a str and raises
                    # TypeError on numbers/None, which aborted the whole inspection.
                    if (k == 'candidates' and isinstance(v, list) and len(v) > 0
                            and isinstance(v[0], dict) and 'smiles' in v[0]):
                        deepseek_paths.append((new_path, v))

                    # Parsed result holding candidates
                    if k == 'parsed_results' and isinstance(v, dict) and 'candidates' in v:
                        deepseek_paths.append((new_path, v))

                    search_deepseek(v, new_path)

            elif isinstance(obj, list):
                for i, item in enumerate(obj):
                    search_deepseek(item, f"{path}[{i}]")

        search_deepseek(data)

        if not deepseek_paths:
            print("No DeepSeek data found")

            # Check if there are any occurrences of 'deepseek' string in the JSON
            json_str = json.dumps(data)
            if 'deepseek' in json_str.lower():
                print("However, 'deepseek' string was found in the file")

                # Show some context around the first occurrence
                idx = json_str.lower().find('deepseek')
                context = json_str[max(0, idx-50):min(len(json_str), idx+150)]
                print(f"Context: ...{context}...")
            return

        print(f"Found {len(deepseek_paths)} potential DeepSeek data locations:")

        for i, (path, value) in enumerate(deepseek_paths):
            print(f"\nPotential DeepSeek data {i+1} at path: {path}")

            if path.endswith('candidates'):
                _describe_candidate_list(value, "", "Candidates is not a list")
            elif path.endswith('parsed_results'):
                if isinstance(value, dict) and 'candidates' in value:
                    _describe_candidate_list(value['candidates'], " in parsed_results",
                                             "candidates in parsed_results is not a list")
                else:
                    print("parsed_results does not contain candidates or is not a dictionary")
            else:
                # Reached only via a key that mentions deepseek
                _describe_deepseek_value(value)

    except Exception as e:
        print(f"Error inspecting file: {str(e)}")

# Spot-check the JSON layout of a few files from the `_run_1_sim` results directory.
json_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean"
json_files = glob.glob(os.path.join(json_dir, "*.json"))

# Only the first five files (in glob's filesystem order) are inspected, to keep the output short.
preview_files = json_files[:5]
for file_path in preview_files:
    inspect_json_structure(file_path)

In [None]:
# Run the analysis pipeline on generated dummy data.
# NOTE(review): `main(use_dummy_data=False)` is defined in a later cell ("V3" section),
# so on a fresh kernel this call relies on an earlier definition of `main` that is not
# visible here; once the later cell runs, that definition shadows it. Verify with
# Restart Kernel -> Run All.
main(use_dummy_data=True)

### V3: DeepSeek LLM confidence analysis on real data

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import glob
import os
from pathlib import Path

# Function to load reference data
def load_reference_data(csv_path):
    """
    Read the reference CSV and map each ``sample-id`` to its ``SMILES`` string.

    A plain dict is returned so per-sample lookups are constant time. If a
    sample ID occurs on several rows, the last row wins.
    """
    reference_table = pd.read_csv(csv_path)
    sample_ids = reference_table['sample-id']
    smiles = reference_table['SMILES']
    return dict(zip(sample_ids, smiles))

# Function to access nested dictionary keys safely
def get_nested(data, keys, default=None):
    """
    Follow a sequence of dictionary keys into a nested structure without raising.

    Args:
        data: Root object, normally a dict loaded from JSON
        keys: Iterable of keys, applied in order
        default: Returned as soon as a level is not a dict or lacks the next key

    Returns:
        The object reached after consuming every key, or ``default``
    """
    missing = object()  # sentinel: distinguishes "absent" from a stored None
    node = data
    for key in keys:
        node = node.get(key, missing) if isinstance(node, dict) else missing
        if node is missing:
            return default
    return node

# Function to analyze a single JSON file
def _confidence_of(candidate):
    """Return a candidate's ``confidence_score`` as a float (0.0 if missing, None, non-numeric or NaN)."""
    try:
        score = float(candidate.get("confidence_score", 0))
    except (TypeError, ValueError):
        return 0.0
    # score != score is True only for NaN, which would make the ranking order undefined
    return 0.0 if score != score else score


def _experiment_label(dir_name):
    """
    Map a results-directory name to the experimental-condition label used in plots.

    Checks go from most to least specific because the generic substrings
    ("sim", "exp") also occur in the more specific directory names.
    """
    name = dir_name.lower()
    if "sim+noise" in name:
        return "Sim Data + Noise"
    if "sim_aug" in name or "sim_d1_aug" in name:
        return "Sim Data + Wrong Guess"
    if "sim" in name:
        return "Sim Data"
    if "exp_d1_aug" in name:
        return "Exp Data + Wrong Guess"
    if "exp_d4" in name:
        return "Exp Data d4"
    if "exp" in name:
        return "Exp Data"
    return "Unknown"


def analyze_json_file(file_path, reference_data):
    """
    Analyze confidence scores from a single JSON file.

    Reads the DeepSeek candidates stored at
    ``analysis_results.final_analysis.llm_responses.deepseek.parsed_results.candidates``,
    ranks them by ``confidence_score`` (highest first; ties keep file order), and
    compares their SMILES with the reference SMILES for the sample.

    Args:
        file_path: Path to the JSON file
        reference_data: Dictionary mapping sample IDs to true SMILES

    Returns:
        Dictionary containing analysis results, or None when the file has no
        sample ID, no reference SMILES, no usable candidates, or cannot be
        processed. Problems are reported with ``print`` rather than raised.

    Notes:
        A candidate counts as correct only if its SMILES string is exactly equal
        to the reference string. NOTE(review): different SMILES spellings of the
        same molecule count as mismatches; confirm both sides are canonicalised
        the same way.
    """
    try:
        # JSON is UTF-8 by specification; don't depend on the platform default encoding.
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Extract sample ID and get true SMILES
        sample_id = get_nested(data, ["molecule_data", "sample_id"])
        if not sample_id:
            print(f"Warning: No sample_id found in {file_path}")
            return None

        # Base sample ID is the part before the first underscore (whole ID if none)
        base_sample_id = sample_id.split('_')[0]
        true_smiles = reference_data.get(base_sample_id)

        if true_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_sample_id}")
            return None

        candidates = get_nested(data, [
            "analysis_results",
            "final_analysis",
            "llm_responses",
            "deepseek",
            "parsed_results",
            "candidates"
        ])
        # Only dict entries can carry a SMILES/confidence pair; anything else
        # previously crashed the sort and discarded the whole file.
        if isinstance(candidates, list):
            candidates = [c for c in candidates if isinstance(c, dict)]
        else:
            candidates = []

        if not candidates:
            print(f"Warning: No DeepSeek candidates found in {file_path}")
            return None

        # Rank by confidence (highest first). The coercing key prevents a single
        # None/non-numeric score from raising TypeError and dropping the file.
        sorted_candidates = sorted(candidates, key=_confidence_of, reverse=True)

        # One record per ranked candidate (1-based position)
        all_confidences = [
            {
                "position": position,
                "confidence": _confidence_of(candidate),
                "is_true": candidate.get("smiles") == true_smiles,
                "smiles": candidate.get("smiles"),
            }
            for position, candidate in enumerate(sorted_candidates, 1)
        ]

        top = all_confidences[0]
        # Highest-ranked candidate matching the reference, if any
        match = next((entry for entry in all_confidences if entry["is_true"]), None)

        return {
            "sample_id": sample_id,
            "true_smiles": true_smiles,
            "experiment": _experiment_label(os.path.basename(os.path.dirname(file_path))),
            "top_candidate_smiles": top["smiles"],
            "top_confidence": top["confidence"],
            "is_correct": top["is_true"],
            "correct_position": match["position"] if match else None,
            "correct_confidence": match["confidence"] if match else None,
            "all_confidences": all_confidences
        }

    except Exception as e:
        # Best-effort batch processing: report and skip this file.
        print(f"Error processing {file_path}: {str(e)}")
        return None

# Function to process a directory of JSON files
def process_directory(json_dir, reference_csv):
    """
    Process all JSON files in a directory.

    Args:
        json_dir: Directory containing JSON files (only ``*.json`` directly
            inside it is read; no recursion)
        reference_csv: Path to the reference CSV file, read with
            ``load_reference_data``. A dict that already maps sample IDs to
            SMILES is also accepted, so callers that process several
            directories can load the CSV once.

    Returns:
        Tuple ``(results_df, confidences_df)``:
        - results_df: one row per successfully analysed file
        - confidences_df: one row per ranked candidate, tagged with the
          file's ``sample_id`` and ``experiment``
        Both are empty DataFrames when no file could be analysed.
    """
    if isinstance(reference_csv, dict):
        reference_data = reference_csv
    else:
        reference_data = load_reference_data(reference_csv)

    # glob returns files in arbitrary (filesystem) order; sort for reproducible rows.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    results = []
    all_confidences = []

    for file_path in json_files:
        result = analyze_json_file(file_path, reference_data)
        if not result:
            continue
        results.append(result)

        # Flatten per-candidate scores, tagging each with its sample and experiment
        all_confidences.extend(
            {
                "sample_id": result["sample_id"],
                "experiment": result["experiment"],
                "position": conf_data["position"],
                "confidence": conf_data["confidence"],
                "is_true": conf_data["is_true"],
                "smiles": conf_data["smiles"]
            }
            for conf_data in result["all_confidences"]
        )

    results_df = pd.DataFrame(results) if results else pd.DataFrame()
    confidences_df = pd.DataFrame(all_confidences) if all_confidences else pd.DataFrame()

    print(f"Successfully analyzed {len(results)} files with valid DeepSeek results")

    return results_df, confidences_df

# Function to plot confidence distributions
def plot_confidence_distributions(results_df, output_file=None):
    """
    Create violin plots showing confidence score distributions.

    Draws one split violin per experimental condition, comparing the top
    candidate's confidence for correct vs. incorrect predictions. Individual
    points are overlaid, and each condition is annotated with its counts.

    Args:
        results_df: DataFrame with analysis results; uses the ``experiment``,
            ``top_confidence`` and ``is_correct`` columns
        output_file: Path to save the figure (if None, display it)

    Returns:
        None. Prints a message and returns early when there is nothing to plot.
    """
    required = {"experiment", "top_confidence", "is_correct"}
    if results_df is None or results_df.empty or not required.issubset(results_df.columns):
        print("No results to plot.")
        return

    correct_mask = results_df["is_correct"] == True
    incorrect_mask = results_df["is_correct"] == False
    labelled = results_df[correct_mask | incorrect_mask]
    plot_df = pd.DataFrame({
        "experiment": labelled["experiment"].to_numpy(),
        "confidence": labelled["top_confidence"].to_numpy(),
        "prediction": np.where(labelled["is_correct"] == True, "Correct", "Incorrect"),
    })
    # Correct rows first; the stable sort keeps file order within each group.
    # Seaborn orders x categories and hue levels by first appearance, so this
    # also fixes the experiment order and the legend order.
    plot_df = plot_df.sort_values("prediction", kind="stable", ignore_index=True)
    if plot_df.empty:
        print("No results to plot.")
        return

    fig, ax = plt.subplots(figsize=(14, 8))

    sns.violinplot(x="experiment", y="confidence", hue="prediction",
                   data=plot_df, split=True, inner="quartile",
                   palette={"Correct": "mediumseagreen", "Incorrect": "tomato"}, ax=ax)

    # Add individual data points
    sns.stripplot(x="experiment", y="confidence", hue="prediction",
                  data=plot_df, dodge=True, alpha=0.3, size=4, linewidth=1,
                  palette={"Correct": "darkgreen", "Incorrect": "darkred"}, ax=ax)

    ax.set_title("DeepSeek-R1 Confidence Scores by Prediction Correctness", fontsize=16)
    ax.set_xlabel("Experimental Condition", fontsize=14)
    ax.set_ylabel("Confidence Score", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.tick_params(axis="x", labelrotation=45)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # The stripplot repeats the hue entries; keep only the violin ones
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles[:2], labels[:2], title="Prediction", fontsize=12)

    # Count annotations, placed at each category's x index
    for i, exp in enumerate(plot_df["experiment"].unique()):
        exp_predictions = plot_df.loc[plot_df["experiment"] == exp, "prediction"]
        n_correct = int((exp_predictions == "Correct").sum())
        n_incorrect = int((exp_predictions == "Incorrect").sum())
        total = n_correct + n_incorrect  # > 0 because exp appears in plot_df

        text = (f"n={total}\nCorrect: {n_correct} ({n_correct/total*100:.1f}%)\n"
                f"Incorrect: {n_incorrect} ({n_incorrect/total*100:.1f}%)")
        ax.text(i, 0.05, text, ha='center', fontsize=9)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)  # otherwise every saved figure stays open in the kernel
    else:
        plt.show()

# Function to create a histogram comparing confidence scores
def plot_confidence_histogram(results_df, output_file=None):
    """
    Create a histogram showing confidence score distributions.

    Overlays histograms (20 equal bins on [0, 1]) of the top candidate's
    confidence for correct and incorrect predictions. A text box shows the mean
    confidence of each group and their difference.

    Args:
        results_df: DataFrame with analysis results; uses the ``is_correct`` and
            ``top_confidence`` columns
        output_file: Path to save the figure (if None, display it)

    Returns:
        None. Prints a message and returns early when there is nothing to plot.
    """
    required = {"is_correct", "top_confidence"}
    if results_df is None or results_df.empty or not required.issubset(results_df.columns):
        print("No results to plot.")
        return

    correct_preds = results_df.loc[results_df["is_correct"] == True, "top_confidence"]
    incorrect_preds = results_df.loc[results_df["is_correct"] == False, "top_confidence"]

    fig, ax = plt.subplots(figsize=(14, 8))

    bins = np.linspace(0, 1, 21)  # 20 equal-width bins from 0 to 1
    ax.hist(correct_preds, bins=bins, alpha=0.7, color='mediumseagreen',
            label=f'Correct Predictions (n={len(correct_preds)})')
    ax.hist(incorrect_preds, bins=bins, alpha=0.7, color='tomato',
            label=f'Incorrect Predictions (n={len(incorrect_preds)})')

    ax.set_title("Distribution of DeepSeek-R1 Confidence Scores by Prediction Correctness", fontsize=16)
    ax.set_xlabel("Confidence Score", fontsize=14)
    ax.set_ylabel("Frequency", fontsize=14)
    ax.grid(linestyle='--', alpha=0.7)
    ax.legend(fontsize=12)

    # The mean of an empty group is NaN. Report it as "n/a" rather than a
    # misleading 0.000, which also distorted the "Difference" line.
    mean_correct = correct_preds.mean()
    mean_incorrect = incorrect_preds.mean()

    def _fmt(value):
        return "n/a" if pd.isna(value) else f"{value:.3f}"

    stats_text = (
        f"Mean confidence when correct: {_fmt(mean_correct)}\n"
        f"Mean confidence when incorrect: {_fmt(mean_incorrect)}\n"
        f"Difference: {_fmt(mean_correct - mean_incorrect)}"
    )

    ax.text(0.05, 0.95, stats_text, transform=ax.transAxes,
            verticalalignment='top', fontsize=12,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)  # otherwise every saved figure stays open in the kernel
    else:
        plt.show()

# Function to create a boxplot showing confidence by position
def plot_confidence_by_position(confidences_df, output_file=None):
    """
    Create a boxplot showing confidence scores by position.

    Only the first five ranked positions are shown. Each position is split into
    candidates that match the reference structure and those that do not, and is
    annotated with the counts of each.

    Args:
        confidences_df: DataFrame with confidence scores; uses the ``position``,
            ``confidence`` and ``is_true`` columns
        output_file: Path to save the figure (if None, display it)

    Returns:
        None. Prints a message and returns early when there is nothing to plot.
    """
    required = {"position", "confidence", "is_true"}
    if confidences_df is None or confidences_df.empty or not required.issubset(confidences_df.columns):
        print("No candidate confidences to plot.")
        return

    # Limit to first 5 positions
    df = confidences_df[confidences_df["position"] <= 5].copy()
    if df.empty:
        print("No candidate confidences to plot.")
        return

    # String hue values keep the palette dict keys unambiguous
    df['is_true'] = df['is_true'].astype(str)

    # Fixed hue order, so box and strip dodge slots and colours are deterministic
    # instead of depending on which value happens to appear first in the data.
    hue_order = ["True", "False"]

    fig, ax = plt.subplots(figsize=(12, 7))

    sns.boxplot(x="position", y="confidence", hue="is_true", hue_order=hue_order,
                data=df, palette={"True": "mediumseagreen", "False": "tomato"}, ax=ax)

    sns.stripplot(x="position", y="confidence", hue="is_true", hue_order=hue_order,
                  data=df, dodge=True, alpha=0.3, size=4, linewidth=1,
                  palette={"True": "darkgreen", "False": "darkred"}, ax=ax)

    ax.set_title("DeepSeek-R1 Confidence Scores by Candidate Position", fontsize=16)
    ax.set_xlabel("Candidate Position (by confidence ranking)", fontsize=14)
    ax.set_ylabel("Confidence Score", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Relabel legend entries by their label text, not by list position. Positional
    # relabelling put "Correct Structure" on the red patch whenever the first
    # data row was an incorrect candidate.
    readable = {"True": "Correct Structure", "False": "Incorrect Structure"}
    handles, labels = ax.get_legend_handles_labels()
    first_handle = {}
    for handle, label in zip(handles, labels):
        if label in readable and label not in first_handle:
            first_handle[label] = handle
    legend_keys = [key for key in hue_order if key in first_handle]
    ax.legend([first_handle[key] for key in legend_keys],
              [readable[key] for key in legend_keys], title="Structure", fontsize=12)

    # The x categories are the sorted positions actually present, so annotate at each
    # category's index instead of assuming positions 1..5 all exist.
    for x_index, position in enumerate(sorted(df["position"].unique())):
        flags = df.loc[df["position"] == position, "is_true"]
        n_true = int((flags == "True").sum())
        n_false = int((flags == "False").sum())
        total = n_true + n_false

        text = f"n={total}\nCorrect: {n_true}\nIncorrect: {n_false}"
        ax.text(x_index, 0.05, text, ha='center', fontsize=9)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)  # otherwise every saved figure stays open in the kernel
    else:
        plt.show()

# Function to calculate and plot ROC curve for different confidence thresholds
def plot_roc_curve(results_df, output_file=None):
    """
    Create an ROC curve for different confidence thresholds.

    The top candidate's confidence is used as a score for "the top prediction is
    correct". For each threshold t in 0.00, 0.01, ..., 1.00:
    - TPR is the share of correct predictions with confidence >= t;
    - FPR is the share of incorrect predictions with confidence >= t.

    The AUC is the trapezoidal area under this curve, closed at (0, 0) and (1, 1).
    It is NaN when either class is absent.

    Args:
        results_df: DataFrame with analysis results; uses the ``is_correct`` and
            ``top_confidence`` columns
        output_file: Path to save the figure (if None, display it)

    Returns:
        None. Prints a message and returns early when there is nothing to plot.
    """
    required = {"is_correct", "top_confidence"}
    if results_df is None or results_df.empty or not required.issubset(results_df.columns):
        print("No results to plot.")
        return

    is_correct = (results_df["is_correct"] == True).to_numpy(dtype=bool, na_value=False)
    is_incorrect = (results_df["is_correct"] == False).to_numpy(dtype=bool, na_value=False)
    confidence = pd.to_numeric(results_df["top_confidence"], errors="coerce").to_numpy(
        dtype=float, na_value=np.nan)

    # Class sizes do not depend on the threshold; compute them once
    positives = int(is_correct.sum())
    negatives = int(is_incorrect.sum())

    thresholds = np.linspace(0, 1, 101)  # 0.00, 0.01, ..., 1.00
    # at_or_above[k, j]: sample j's confidence >= thresholds[k] (NaN never counts)
    at_or_above = confidence[np.newaxis, :] >= thresholds[:, np.newaxis]
    tp = (at_or_above & is_correct).sum(axis=1)
    fp = (at_or_above & is_incorrect).sum(axis=1)
    tpr = tp / positives if positives > 0 else np.zeros(len(thresholds))  # Sensitivity
    fpr = fp / negatives if negatives > 0 else np.zeros(len(thresholds))  # 1 - Specificity

    fig, ax = plt.subplots(figsize=(10, 8))

    ax.plot(fpr, tpr, 'b-', linewidth=2)
    # Diagonal reference line (random classifier)
    ax.plot([0, 1], [0, 1], 'r--', linewidth=1.5)

    if positives > 0 and negatives > 0:
        # Thresholds ascend, so FPR/TPR descend. Reverse the points to integrate
        # left to right, and close the curve at (0, 0) and (1, 1). Summing in the
        # original (descending) order produced a negative AUC.
        fpr_curve = np.concatenate(([0.0], fpr[::-1], [1.0]))
        tpr_curve = np.concatenate(([0.0], tpr[::-1], [1.0]))
        auc = float(np.sum(np.diff(fpr_curve) * (tpr_curve[1:] + tpr_curve[:-1]) / 2))
    else:
        auc = float("nan")  # ROC is undefined without both correct and incorrect cases

    ax.set_title(f"ROC Curve for DeepSeek-R1 Confidence Scores\nAUC = {auc:.3f}", fontsize=16)
    ax.set_xlabel("False Positive Rate (1 - Specificity)", fontsize=14)
    ax.set_ylabel("True Positive Rate (Sensitivity)", fontsize=14)
    ax.grid(linestyle='--', alpha=0.7)

    marker_thresholds = [0.3, 0.5, 0.7, 0.9]

    # Mark notable thresholds; nearest grid index avoids float-truncation surprises
    for t in marker_thresholds:
        idx = int(np.argmin(np.abs(thresholds - t)))
        ax.plot(fpr[idx], tpr[idx], 'ko', markersize=6)
        ax.text(fpr[idx] + 0.02, tpr[idx] - 0.02, f"t={t}", fontsize=10)

    # Build the table from plain ints. Going through DataFrame.iterrows upcast
    # the counts to float ("3.0 correct").
    table_lines = ["Threshold statistics:"]
    for t in marker_thresholds:
        selected = confidence >= t
        correct_above = int((selected & is_correct).sum())
        incorrect_above = int((selected & is_incorrect).sum())
        total_above = correct_above + incorrect_above
        precision = correct_above / total_above if total_above > 0 else 0
        table_lines.append(
            f"t={t:.1f}: {correct_above} correct, {incorrect_above} incorrect above threshold "
            f"(precision: {precision:.2f})"
        )
    table_text = "\n".join(table_lines) + "\n"

    ax.text(0.05, 0.05, table_text, transform=ax.transAxes,
            verticalalignment='bottom', fontsize=10,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)  # otherwise every saved figure stays open in the kernel
    else:
        plt.show()

# Function to create an experiment-specific confidence visualization
def plot_experiment_confidence(results_df, output_file=None):
    """
    Create a detailed visualization of confidence scores by experiment.

    Draws one histogram panel per experimental condition, in 3 columns with as
    many rows as needed (at least 2). Each panel overlays the top candidate's
    confidence for correct (solid) and incorrect (hatched) predictions and shows:
    - a suggested threshold: the midpoint of the two group means, or 0.5 when a
      group is empty;
    - the accuracy of the rule "confidence >= threshold means the top pick is
      correct";
    - Cohen's d between the two groups, as a separation score.

    Args:
        results_df: DataFrame with analysis results; uses the ``experiment``,
            ``is_correct`` and ``top_confidence`` columns
        output_file: Path to save the figure (if None, display it)

    Returns:
        None. Prints a message and returns early when there is nothing to plot.
    """
    required = {"experiment", "is_correct", "top_confidence"}
    if results_df is None or results_df.empty or not required.issubset(results_df.columns):
        print("No results to plot.")
        return

    experiment_colors = {
        "Sim Data": "#10B981",
        "Sim Data + Wrong Guess": "#3B82F6",
        "Sim Data + Noise": "#6366F1",
        "Exp Data": "#F59E0B",
        "Exp Data + Wrong Guess": "#EC4899",
        "Exp Data d4": "#8B5CF6"
    }

    # Known conditions first, in a fixed order (stable panel layout across runs),
    # then any other labels in order of appearance.
    present = list(results_df["experiment"].unique())
    experiments = [exp for exp in experiment_colors if exp in present]
    experiments += [exp for exp in present if exp not in experiment_colors]

    # Grow the grid instead of silently dropping experiments beyond the sixth;
    # keep at least the original 2x3 layout (ceiling division via -(-a // b)).
    n_cols = 3
    n_rows = max(2, -(-len(experiments) // n_cols))
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(18, 6 * n_rows), sharex=True, sharey=True)
    axes = axes.flatten()

    bins = np.linspace(0, 1, 21)  # 20 equal-width bins from 0 to 1

    for ax, exp in zip(axes, experiments):
        exp_data = results_df[results_df["experiment"] == exp]
        correct_mask = exp_data["is_correct"] == True
        incorrect_mask = exp_data["is_correct"] == False
        correct_preds = exp_data.loc[correct_mask, "top_confidence"]
        incorrect_preds = exp_data.loc[incorrect_mask, "top_confidence"]
        color = experiment_colors.get(exp, "#333333")

        ax.hist(correct_preds, bins=bins, alpha=0.7, color=color,
                label=f'Correct (n={len(correct_preds)})')
        ax.hist(incorrect_preds, bins=bins, alpha=0.5, hatch='///', color=color,
                label=f'Incorrect (n={len(incorrect_preds)})')

        mean_correct = correct_preds.mean() if len(correct_preds) > 0 else 0
        mean_incorrect = incorrect_preds.mean() if len(incorrect_preds) > 0 else 0

        # Simple heuristic: midpoint between the two group means
        if len(correct_preds) > 0 and len(incorrect_preds) > 0:
            suggested_threshold = (mean_correct + mean_incorrect) / 2
        else:
            suggested_threshold = 0.5

        ax.axvline(x=suggested_threshold, color='red', linestyle='--', linewidth=1.5)

        # Accuracy of "confidence >= threshold => top pick is correct"
        confidence = exp_data["top_confidence"]
        correct_above = int((correct_mask & (confidence >= suggested_threshold)).sum())
        incorrect_below = int((incorrect_mask & (confidence < suggested_threshold)).sum())
        total = len(exp_data)
        accuracy = (correct_above + incorrect_below) / total if total > 0 else 0

        stats_text = (
            f"Mean conf. (correct): {mean_correct:.2f}\n"
            f"Mean conf. (incorrect): {mean_incorrect:.2f}\n"
            f"Suggested threshold: {suggested_threshold:.2f}\n"
            f"Acc @ threshold: {accuracy:.2f}"
        )

        ax.text(0.05, 0.95, stats_text, transform=ax.transAxes,
                verticalalignment='top', fontsize=10,
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

        ax.set_title(exp, fontsize=14)
        ax.grid(linestyle='--', alpha=0.7)
        ax.legend(fontsize=10)

        if len(correct_preds) > 0 and len(incorrect_preds) > 0:
            # Separation score = Cohen's d: |difference of means| divided by the
            # root-mean of the two sample variances (unweighted pooling). A group
            # with a single point has NaN std, so the score is skipped.
            pooled_std = np.sqrt((correct_preds.std() ** 2 + incorrect_preds.std() ** 2) / 2)

            if pooled_std > 0:
                cohens_d = abs(mean_correct - mean_incorrect) / pooled_std
                ax.text(0.95, 0.05, f"Separation score: {cohens_d:.2f}", transform=ax.transAxes,
                        horizontalalignment='right', verticalalignment='bottom', fontsize=10,
                        bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    # Hide unused panels. With sharex, the panel directly above a hidden one
    # would otherwise have no x tick labels.
    for idx in range(len(experiments), len(axes)):
        axes[idx].set_visible(False)
        above = idx - n_cols
        if above >= 0:
            axes[above].tick_params(axis="x", labelbottom=True)

    fig.text(0.5, 0.02, 'Confidence Score', ha='center', fontsize=14)
    fig.text(0.02, 0.5, 'Frequency', va='center', rotation='vertical', fontsize=14)

    fig.suptitle("DeepSeek-R1 Confidence Score Distributions by Experimental Condition", fontsize=18)

    fig.tight_layout(rect=[0.03, 0.03, 1, 0.97])

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)  # otherwise every saved figure stays open in the kernel
    else:
        plt.show()

# Main function to run the analysis
def main(use_dummy_data=False):
    """
    Run the complete DeepSeek-R1 confidence analysis.

    Processes every JSON result directory, combines the per-sample and
    per-candidate tables, saves five diagnostic PNG plots in the working
    directory and prints summary statistics (overall and per-experiment
    accuracy, mean confidence when correct/incorrect, and a midpoint
    confidence threshold with its accuracy).

    Args:
        use_dummy_data: Kept for backward compatibility. Dummy-data generation
            is not implemented in this function, so the real JSON files are
            always processed; passing True only prints a warning.

    Returns:
        Tuple (combined_results, combined_confidences) of DataFrames, or None
        if no directory produced any valid result.
    """
    if use_dummy_data:
        print("Warning: use_dummy_data=True is not implemented; processing the real JSON files instead.")

    # Define input and output paths
    json_dirs = [
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_4_exp_d1_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_5_exp_d1_aug_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6_exp_d4_finished_clean"
    ]

    reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"

    # Process each directory and combine results
    all_results = []
    all_confidences = []
    for json_dir in json_dirs:
        print(f"Processing {json_dir}...")
        results_df, confidences_df = process_directory(json_dir, reference_csv)
        if not results_df.empty:
            all_results.append(results_df)
            all_confidences.append(confidences_df)

    # ignore_index: each directory's frame is numbered from 0, so a plain concat
    # yields duplicate index labels and ambiguous label-based alignment downstream.
    combined_results = pd.concat(all_results, ignore_index=True) if all_results else pd.DataFrame()
    combined_confidences = pd.concat(all_confidences, ignore_index=True) if all_confidences else pd.DataFrame()

    if combined_results.empty:
        print("No valid results found.")
        return None

    print(f"Analyzing {len(combined_results)} total items...")

    # Generate and save the plots
    print("Generating confidence distribution plots...")
    plot_confidence_distributions(combined_results, "deepseek_confidence_distributions.png")

    print("Generating confidence histogram...")
    plot_confidence_histogram(combined_results, "deepseek_confidence_histogram.png")

    print("Generating confidence by position plot...")
    plot_confidence_by_position(combined_confidences, "deepseek_confidence_by_position.png")

    print("Generating ROC curve...")
    plot_roc_curve(combined_results, "deepseek_confidence_roc.png")

    print("Generating experimental condition analysis...")
    plot_experiment_confidence(combined_results, "deepseek_experiment_confidence.png")

    print("Analysis complete. All plots have been saved.")

    # Summary statistics (combined_results is known to be non-empty here)
    is_correct = combined_results["is_correct"].astype(bool)
    confidence = combined_results["top_confidence"]
    total = len(combined_results)
    correct = is_correct.sum()
    incorrect = total - correct

    print("\nSUMMARY STATISTICS:")
    print(f"Total samples analyzed: {total}")
    print(f"Correct predictions: {correct} ({correct/total*100:.1f}%)")
    print(f"Incorrect predictions: {incorrect} ({incorrect/total*100:.1f}%)")

    # By experiment
    print("\nBreakdown by experiment:")
    exp_groups = combined_results.groupby("experiment")
    exp_counts = exp_groups.size()
    exp_correct = exp_groups["is_correct"].sum()
    exp_accuracy = exp_groups["is_correct"].mean() * 100
    for exp in exp_counts.index:
        print(f"  {exp}: {exp_counts[exp]} samples, {exp_correct[exp]} correct ({exp_accuracy[exp]:.1f}%)")

    # Confidence statistics (NaN when one of the two groups is empty)
    print("\nConfidence score statistics:")
    mean_conf_correct = confidence[is_correct].mean()
    mean_conf_incorrect = confidence[~is_correct].mean()

    print(f"  Mean confidence when correct: {mean_conf_correct:.3f}")
    print(f"  Mean confidence when incorrect: {mean_conf_incorrect:.3f}")
    print(f"  Difference: {mean_conf_correct - mean_conf_incorrect:.3f}")

    # Suggest a threshold: the midpoint between the two class means
    if not np.isnan(mean_conf_correct) and not np.isnan(mean_conf_incorrect):
        suggested_threshold = (mean_conf_correct + mean_conf_incorrect) / 2
        print(f"\nSuggested confidence threshold: {suggested_threshold:.2f}")

        # Treat the threshold as a classifier: count correct predictions at or
        # above it plus incorrect predictions below it, over all samples
        correct_above = (is_correct & (confidence >= suggested_threshold)).sum()
        incorrect_below = (~is_correct & (confidence < suggested_threshold)).sum()

        threshold_accuracy = (correct_above + incorrect_below) / total
        print(f"Accuracy at suggested threshold: {threshold_accuracy:.2f}")

    return combined_results, combined_confidences

In [None]:

# Entry point: run the full V3 analysis on the real JSON result files.
if __name__ == "__main__":
    main(use_dummy_data=False)

### V3.1 Real plotting

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import glob
import os
from pathlib import Path

# Function to load reference data
def load_reference_data(csv_path):
    """
    Load the reference SMILES table as a ``{sample_id: SMILES}`` dict.

    Args:
        csv_path: Path (or file-like object) of a CSV with 'sample-id' and
            'SMILES' columns.

    Returns:
        Dict mapping each sample ID (always as ``str``) to its SMILES. If an ID
        appears more than once, the last row wins.
    """
    # Read IDs as text. Callers look them up with string IDs parsed from the
    # JSON files. Numeric-looking IDs such as "007" would otherwise be parsed
    # as ints (7) and never match.
    ref_df = pd.read_csv(csv_path, dtype={'sample-id': str})
    return ref_df.set_index('sample-id')['SMILES'].to_dict()

# Function to access nested dictionary keys safely
def get_nested(data, keys, default=None):
    """
    Walk ``data`` along ``keys``, descending one dictionary level per key.

    Returns the value at the end of the path, or ``default`` as soon as the
    current level is not a dict or does not contain the next key.
    """
    node = data
    for key in keys:
        if not isinstance(node, dict) or key not in node:
            return default
        node = node[key]
    return node

# Function to analyze a single JSON file
def _confidence_of(candidate):
    """Return a candidate's ``confidence_score`` as a float (0.0 when missing, null or non-numeric)."""
    score = candidate.get("confidence_score")
    if score is None:
        return 0.0
    try:
        return float(score)
    except (TypeError, ValueError):
        return 0.0


def analyze_json_file(file_path, reference_data):
    """
    Extract DeepSeek-R1 candidate confidences from one result JSON and score them.

    Args:
        file_path: Path to a per-sample JSON result file. The experiment label
            is derived from the name of the directory that contains it.
        reference_data: Dict mapping base sample IDs to the true SMILES.

    Returns:
        A dict with these entries, or None if the file cannot be used (a
        warning or error is printed in that case):
        - the sample id, true SMILES and experiment label;
        - the top candidate's SMILES, confidence and correctness;
        - the 1-based rank and confidence of the true structure (None if it
          is absent);
        - ``all_confidences``, one entry per candidate in ranked order.

    Notes:
        Confidence scores that are missing, null or non-numeric are treated
        as 0. This way one malformed candidate does not discard the whole
        sample.
        NOTE(review): correctness is exact SMILES string equality, so
        equivalent but differently written SMILES count as incorrect. Confirm
        that both sides are canonicalised the same way.
    """
    try:
        with open(file_path, 'r') as f:
            data = json.load(f)

        sample_id = get_nested(data, ["molecule_data", "sample_id"])
        if not sample_id:
            print(f"Warning: No sample_id found in {file_path}")
            return None

        # IDs with a '_' suffix share the reference entry of the text before the first '_'
        base_sample_id = str(sample_id).split('_')[0]
        true_smiles = reference_data.get(base_sample_id)
        if true_smiles is None:
            print(f"Warning: No reference SMILES found for sample_id {base_sample_id}")
            return None

        candidates = get_nested(data, [
            "analysis_results",
            "final_analysis",
            "llm_responses",
            "deepseek",
            "parsed_results",
            "candidates"
        ])
        if not candidates:
            print(f"Warning: No DeepSeek candidates found in {file_path}")
            return None

        # Highest confidence first. sorted() is stable even with reverse=True,
        # so ties keep their order from the JSON list.
        sorted_candidates = sorted(candidates, key=_confidence_of, reverse=True)

        all_confidences = [
            {
                "position": rank,
                "confidence": _confidence_of(candidate),
                "is_true": candidate.get("smiles") == true_smiles,
                "smiles": candidate.get("smiles"),
            }
            for rank, candidate in enumerate(sorted_candidates, 1)
        ]

        top = all_confidences[0]
        # First (best-ranked) occurrence of the true structure, if DeepSeek proposed it
        true_entry = next((entry for entry in all_confidences if entry["is_true"]), None)

        # Map the directory name to an experiment label. Order matters: the
        # specific patterns must be tested before the generic "sim"/"exp".
        dir_name = os.path.basename(os.path.dirname(file_path)).lower()
        if "sim+noise" in dir_name:
            experiment = "Sim Data + Noise"
        elif "sim_aug" in dir_name or "sim_d1_aug" in dir_name:
            experiment = "Sim Data + Wrong Guess"
        elif "sim" in dir_name:
            experiment = "Sim Data"
        elif "exp_d1_aug" in dir_name:
            experiment = "Exp Data + Wrong Guess"
        elif "exp_d4" in dir_name:
            experiment = "Exp Data d4"
        elif "exp" in dir_name:
            experiment = "Exp Data"
        else:
            experiment = "Unknown"

        return {
            "sample_id": sample_id,
            "true_smiles": true_smiles,
            "experiment": experiment,
            "top_candidate_smiles": top["smiles"],
            "top_confidence": top["confidence"],
            "is_correct": top["is_true"],
            "correct_position": true_entry["position"] if true_entry else None,
            "correct_confidence": true_entry["confidence"] if true_entry else None,
            "all_confidences": all_confidences
        }

    except Exception as e:
        # Best-effort: report and skip unreadable/malformed files
        print(f"Error processing {file_path}: {str(e)}")
        return None

# Function to process a directory of JSON files
def process_directory(json_dir, reference_csv):
    """
    Analyse every ``*.json`` file in one results directory.

    Args:
        json_dir: Directory containing the per-sample JSON result files.
        reference_csv: Path to the reference CSV (see load_reference_data).

    Returns:
        Tuple (results_df, confidences_df). ``results_df`` has one row per
        successfully analysed file. ``confidences_df`` has one row per DeepSeek
        candidate of those files, tagged with sample_id and experiment. Both are
        empty DataFrames when nothing could be analysed.
    """
    reference_data = load_reference_data(reference_csv)

    if not os.path.isdir(json_dir):
        print(f"Warning: directory not found: {json_dir}")

    # glob order depends on the filesystem; sort so row order is reproducible
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    results = []
    all_confidences = []
    for file_path in json_files:
        result = analyze_json_file(file_path, reference_data)
        if not result:
            continue
        results.append(result)

        # Flatten the per-candidate scores into one row per candidate
        all_confidences.extend(
            {
                "sample_id": result["sample_id"],
                "experiment": result["experiment"],
                "position": conf_data["position"],
                "confidence": conf_data["confidence"],
                "is_true": conf_data["is_true"],
                "smiles": conf_data["smiles"]
            }
            for conf_data in result["all_confidences"]
        )

    results_df = pd.DataFrame(results) if results else pd.DataFrame()
    confidences_df = pd.DataFrame(all_confidences) if all_confidences else pd.DataFrame()

    print(f"Successfully analyzed {len(results)} files with valid DeepSeek results")

    return results_df, confidences_df

def plot_confidence_by_position_with_stats(confidences_df, output_file=None):
    """
    Plot DeepSeek-R1 confidence scores by candidate rank and print per-rank statistics.

    Args:
        confidences_df: Per-candidate DataFrame with "position" (1-based rank),
            "confidence" and "is_true" (bool) columns, as built by
            process_directory.
        output_file: PNG path to save the figure to; if None the figure is shown.

    Returns:
        List of five dicts (positions 1-5). Each holds the counts, % correct,
        and mean/median confidence for correct and incorrect candidates, plus
        the median over all candidates at that position.
    """
    positions = list(range(1, 6))
    hue_levels = ["True", "False"]

    # Limit to first 5 positions
    df = confidences_df[confidences_df["position"] <= 5].copy()
    # Convert boolean is_true to string so the palette dict keys match
    df['is_true'] = df['is_true'].astype(str)

    # Per-position statistics (reused below for the plot annotations)
    position_stats = []
    for pos in positions:
        pos_data = df[df["position"] == pos]
        conf_true = pos_data.loc[pos_data["is_true"] == "True", "confidence"]
        conf_false = pos_data.loc[pos_data["is_true"] == "False", "confidence"]
        n_true = len(conf_true)
        n_false = len(conf_false)
        total = n_true + n_false

        position_stats.append({
            "position": pos,
            "total": total,
            "n_true": n_true,
            "n_false": n_false,
            "correct_percentage": (n_true / total * 100) if total > 0 else 0,
            "mean_true": conf_true.mean() if n_true > 0 else 0,
            "mean_false": conf_false.mean() if n_false > 0 else 0,
            "median_true": conf_true.median() if n_true > 0 else 0,
            "median_false": conf_false.median() if n_false > 0 else 0,
            "median_all": pos_data["confidence"].median()
        })

    # Print stats table; row widths match the header columns
    print("\nDeepSeek-R1 Confidence Score Statistics by Position:")
    print("-" * 100)
    print(f"{'Position':<10} {'Total':<8} {'Correct':<10} {'Incorrect':<10} {'% Correct':<10} {'Mean (Correct)':<15} {'Mean (Incorrect)':<15} {'Median (All)':<15}")
    print("-" * 100)
    for stat in position_stats:
        pct_text = f"{stat['correct_percentage']:.1f}%"
        print(f"{stat['position']:<10} {stat['total']:<8} {stat['n_true']:<10} {stat['n_false']:<10} "
              f"{pct_text:<10} {stat['mean_true']:<15.3f} {stat['mean_false']:<16.3f} {stat['median_all']:<15.2f}")
    print("-" * 100)

    # Print specific values mentioned in the text
    print("\nKey Values for Text:")
    print(f"Position 1 correct percentage: {position_stats[0]['correct_percentage']:.1f}% ({position_stats[0]['n_true']}/{position_stats[0]['total']})")
    print(f"Position 1 mean confidence for correct structures: {position_stats[0]['mean_true']:.3f}")
    print(f"Position 1 mean confidence for incorrect structures: {position_stats[0]['mean_false']:.3f}")
    print(f"Position 1 median confidence score: ~{position_stats[0]['median_all']:.2f}")
    print(f"Position 5 median confidence score: ~{position_stats[4]['median_all']:.2f}")
    print(f"Position 5 correct percentage: {position_stats[4]['correct_percentage']:.1f}% ({position_stats[4]['n_true']}/{position_stats[4]['total']})")

    fig, ax = plt.subplots(figsize=(12, 7))

    # Fix both the x categories and the hue order. Otherwise seaborn infers hue
    # levels from first appearance, which can swap the hard-coded legend labels
    # below, and a missing position would shift the per-position annotations.
    sns.boxplot(x="position", y="confidence", hue="is_true", data=df,
                order=positions, hue_order=hue_levels,
                palette={"True": "mediumseagreen", "False": "tomato"}, ax=ax)
    sns.stripplot(x="position", y="confidence", hue="is_true", data=df,
                  order=positions, hue_order=hue_levels,
                  dodge=True, alpha=0.3, size=4, linewidth=1,
                  palette={"True": "darkgreen", "False": "darkred"}, ax=ax)

    ax.set_title("DeepSeek-R1 Confidence Scores by Candidate Position", fontsize=16)
    ax.set_xlabel("Candidate Position (by confidence ranking)", fontsize=14)
    ax.set_ylabel("Confidence Score", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Keep only the boxplot handles (first two, in hue_order) to avoid duplicates
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles[:2], ["Correct Structure", "Incorrect Structure"], title="Structure", fontsize=12)

    # Count annotations; category index = position - 1 because `order` is fixed
    for idx, stat in enumerate(position_stats):
        text = f"n={stat['total']}\nCorrect: {stat['n_true']}\nIncorrect: {stat['n_false']}"
        ax.text(idx, 0.05, text, ha='center', fontsize=9)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    else:
        plt.show()

    return position_stats

def main():
    """
    Run the V3.1 DeepSeek-R1 confidence-by-position analysis.

    Processes every JSON result directory, combines the per-candidate
    confidence tables, saves ``deepseek_confidence_by_position.png`` and
    prints per-position statistics.

    Returns:
        Tuple (combined_results, combined_confidences, stats), where ``stats``
        is the list returned by plot_confidence_by_position_with_stats; or
        None if no directory produced any valid result.
    """
    # Define input paths
    json_dirs = [
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_4_exp_d1_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_5_exp_d1_aug_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6_exp_d4_finished_clean"
    ]

    reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"

    # Process each directory and combine results
    all_results = []
    all_confidences = []
    for json_dir in json_dirs:
        print(f"Processing {json_dir}...")
        results_df, confidences_df = process_directory(json_dir, reference_csv)
        if not results_df.empty:
            all_results.append(results_df)
            all_confidences.append(confidences_df)

    # ignore_index: each directory's frame is numbered from 0, so a plain concat
    # yields duplicate index labels and ambiguous label-based alignment downstream.
    combined_results = pd.concat(all_results, ignore_index=True) if all_results else pd.DataFrame()
    combined_confidences = pd.concat(all_confidences, ignore_index=True) if all_confidences else pd.DataFrame()

    if combined_results.empty:
        print("No valid results found.")
        return None

    print(f"Analyzing {len(combined_results)} total items...")

    # Generate the plot and statistics
    stats = plot_confidence_by_position_with_stats(combined_confidences, "deepseek_confidence_by_position.png")

    return combined_results, combined_confidences, stats


if __name__ == "__main__":
    main()

### V4 HSQC - not done

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import glob
import os
from pathlib import Path

# Function to load reference data
def load_reference_data(csv_path):
    """
    Load the reference SMILES table as a ``{sample_id: SMILES}`` dict.

    Args:
        csv_path: Path (or file-like object) of a CSV with 'sample-id' and
            'SMILES' columns.

    Returns:
        Dict mapping each sample ID (always as ``str``) to its SMILES. If an ID
        appears more than once, the last row wins.
    """
    # IDs are looked up with strings from get_base_sample_id. Force text so
    # numeric-looking IDs (e.g. "007") are not parsed as ints and missed.
    ref_df = pd.read_csv(csv_path, dtype={'sample-id': str})
    return ref_df.set_index('sample-id')['SMILES'].to_dict()

# Function to get base sample ID (before underscore)
def get_base_sample_id(sample_id):
    """
    Return the base sample ID: the text before the first underscore.

    Non-string IDs (e.g. integers read from JSON) are converted with ``str``
    first instead of raising AttributeError. A falsy ``sample_id`` (None, '',
    0) yields ''.
    """
    if not sample_id:
        return ''
    # partition('_')[0] is the whole string when no '_' is present
    return str(sample_id).partition('_')[0]

# Function to safely navigate nested dictionary keys
def get_nested(data, keys, default=None):
    """Follow the key path ``keys`` through nested dicts in ``data``; return ``default`` if any step is missing or not a dict."""
    current = data
    for step in keys:
        reachable = isinstance(current, dict) and step in current
        if not reachable:
            return default
        current = current[step]
    return current

# Function to process a single JSON file and extract HSQC rankings
def process_json_file(file_path, reference_data):
    """
    Rank one sample's candidate molecules by HSQC matching score and flag the true structure.

    Candidates come from the 'forward_synthesis', 'mol2mol' and 'mmst'
    sections of ``molecule_data.candidate_analysis``. Only candidates with an
    HSQC score are kept. They are sorted by HSQC score ascending (lower is
    better), and the top 5 are returned.

    Args:
        file_path: Path to the JSON file. The experiment label is taken from
            the name of its parent directory.
        reference_data: Dictionary mapping base sample IDs to true SMILES.

    Returns:
        List of up to 5 dicts with keys sample_id, position (1-based),
        hsqc_score, is_true, smiles and experiment. The list is empty if:
        - the file has no sample ID;
        - there is no reference SMILES for the sample;
        - no candidate has an HSQC score; or
        - the file cannot be processed (an error is printed in this case).

    NOTE(review): correctness is exact SMILES string equality. A SMILES
    proposed by several generators is ranked once per occurrence, so it can
    occupy several positions. Confirm whether canonicalisation/de-duplication
    is wanted.
    """
    try:
        with open(file_path, 'r') as f:
            data = json.load(f)

        sample_id = get_nested(data, ["molecule_data", "sample_id"])
        if not sample_id:
            return []

        true_smiles = reference_data.get(get_base_sample_id(sample_id))
        if true_smiles is None:
            return []

        # Map the directory name to an experiment label. Order matters: the
        # specific patterns must be tested before the generic "sim"/"exp".
        dir_name = os.path.basename(os.path.dirname(file_path)).lower()
        if "sim+noise" in dir_name:
            experiment = "Sim Data + Noise"
        elif "sim_aug" in dir_name or "sim_d1_aug" in dir_name:
            experiment = "Sim Data + Wrong Guess"
        elif "sim" in dir_name:
            experiment = "Sim Data"
        elif "exp_d1_aug" in dir_name:
            experiment = "Exp Data + Wrong Guess"
        elif "exp_d4" in dir_name:
            experiment = "Exp Data d4"
        elif "exp" in dir_name:
            experiment = "Exp Data"
        else:
            experiment = "Unknown"

        # `or {}` / `or []` also cover sections that are present but null in the
        # JSON. Those would otherwise raise TypeError and discard the whole sample.
        candidate_analysis = get_nested(data, ["molecule_data", "candidate_analysis"]) or {}
        all_molecules = []
        for analysis_type in ('forward_synthesis', 'mol2mol', 'mmst'):
            molecules = get_nested(candidate_analysis, [analysis_type, "molecules"]) or []
            for mol in molecules:
                hsqc_score = get_nested(mol, ["nmr_analysis", "matching_scores", "by_spectrum", "HSQC"])
                # A non-None score implies `mol` is a dict (get_nested checks each level)
                if hsqc_score is None or 'smiles' not in mol:
                    continue
                all_molecules.append({
                    'smiles': mol['smiles'],
                    'hsqc_score': hsqc_score,
                    'is_true': (mol['smiles'] == true_smiles)
                })

        # Sort by HSQC score (lower is better); all scores here are non-None
        all_molecules.sort(key=lambda m: m['hsqc_score'])

        return [
            {
                'sample_id': sample_id,
                'position': position,
                'hsqc_score': mol['hsqc_score'],
                'is_true': mol['is_true'],
                'smiles': mol['smiles'],
                'experiment': experiment
            }
            for position, mol in enumerate(all_molecules[:5], 1)
        ]

    except Exception as e:
        # Best-effort: report and skip unreadable/malformed files
        print(f"Error processing {file_path}: {str(e)}")
        return []

# Function to process a directory of JSON files
def process_directory(json_dir, reference_csv):
    """
    Collect HSQC-ranked candidates from every ``*.json`` file in one directory.

    Args:
        json_dir: Directory containing per-sample JSON result files.
        reference_csv: Path to the reference CSV (see load_reference_data).

    Returns:
        DataFrame with one row per (sample, rank) as produced by
        process_json_file. It is an empty DataFrame with no columns when
        nothing could be processed.
    """
    reference_data = load_reference_data(reference_csv)

    if not os.path.isdir(json_dir):
        print(f"Warning: directory not found: {json_dir}")

    # glob order depends on the filesystem; sort so row order is reproducible
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    all_results = [
        row
        for file_path in json_files
        for row in process_json_file(file_path, reference_data)
    ]

    return pd.DataFrame(all_results)

# Function to create a boxplot showing HSQC scores by position
def plot_hsqc_by_position(hsqc_df, output_file=None, max_error=5.0):
    """
    Boxplot of normalised HSQC matching scores for candidate ranks 1-5.

    The raw HSQC score is an error (lower is better). It is mapped to
    ``1 - hsqc_score / max_error`` and floored at 0, so higher is better and a
    perfect match scores 1.

    Args:
        hsqc_df: DataFrame with "position", "hsqc_score" and "is_true"
            columns, as produced by process_directory.
        output_file: PNG path to save the figure; if None the figure is shown.
        max_error: HSQC error that maps to a normalised score of 0. The
            default of 5.0 is the original hard-coded scale, chosen because
            HSQC errors are typically in the range 0-5.
            NOTE(review): confirm this range for the dataset.

    Returns:
        The matplotlib Figure that was drawn. It is captured before
        ``plt.show()``; calling ``plt.gcf()`` afterwards may return a new,
        empty figure.
    """
    positions = list(range(1, 6))
    hue_levels = ["True", "False"]

    # Limit to first 5 positions
    df = hsqc_df[hsqc_df["position"] <= 5].copy()
    # Convert boolean is_true to string so the palette dict keys match
    df['is_true'] = df['is_true'].astype(str)
    # Invert and rescale the error; floor at 0 for very large errors
    df['normalized_score'] = (1 - df['hsqc_score'] / max_error).clip(lower=0)

    fig, ax = plt.subplots(figsize=(12, 7))

    # Fix both the x categories and the hue order. Otherwise seaborn infers hue
    # levels from first appearance, which can swap the hard-coded legend labels
    # below, and a missing position would shift the per-position annotations.
    sns.boxplot(x="position", y="normalized_score", hue="is_true", data=df,
                order=positions, hue_order=hue_levels,
                palette={"True": "mediumseagreen", "False": "tomato"}, ax=ax)
    sns.stripplot(x="position", y="normalized_score", hue="is_true", data=df,
                  order=positions, hue_order=hue_levels,
                  dodge=True, alpha=0.3, size=4, linewidth=1,
                  palette={"True": "darkgreen", "False": "darkred"}, ax=ax)

    ax.set_title("HSQC Score Performance by Candidate Position", fontsize=16)
    ax.set_xlabel("Candidate Position (by HSQC ranking)", fontsize=14)
    ax.set_ylabel("Normalized Score (higher is better)", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Keep only the boxplot handles (first two, in hue_order) to avoid duplicates
    handles, _ = ax.get_legend_handles_labels()
    ax.legend(handles[:2], ["Correct Structure", "Incorrect Structure"], title="Structure", fontsize=12, loc='upper right')

    # Count annotations; category index = position - 1 because `order` is fixed
    for idx, pos in enumerate(positions):
        labels = df.loc[df["position"] == pos, "is_true"]
        n_true = int((labels == "True").sum())
        n_false = int((labels == "False").sum())
        text = f"n={n_true + n_false}\nCorrect: {n_true}\nIncorrect: {n_false}"
        ax.text(idx, 0.05, text, ha='center', fontsize=9)

    fig.tight_layout()

    if output_file:
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
    else:
        plt.show()

    return fig

# Main function to run the analysis
def main():
    """
    Run the V4 HSQC ranking analysis.

    Processes every JSON result directory and combines the top-5
    HSQC-ranked candidates. It saves ``hsqc_score_by_position.png`` and
    prints, for each rank, how often that candidate is the true structure.

    Returns:
        Combined DataFrame of ranked candidates across all directories, or
        None if no directory produced any valid result.
    """
    # Define input paths
    json_dirs = [
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_4_exp_d1_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_5_exp_d1_aug_finished_clean",
        "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_6_exp_d4_finished_clean"
    ]

    reference_csv = "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"

    # Process each directory and combine results
    all_results = []
    for json_dir in json_dirs:
        print(f"Processing {json_dir}...")
        results_df = process_directory(json_dir, reference_csv)
        if not results_df.empty:
            all_results.append(results_df)

    # ignore_index: each directory's frame is numbered from 0, so a plain concat
    # yields duplicate index labels and ambiguous label-based alignment downstream.
    combined_results = pd.concat(all_results, ignore_index=True) if all_results else pd.DataFrame()

    if combined_results.empty:
        print("No valid results found.")
        return None

    print(f"Analyzing {len(combined_results)} total items...")

    # Generate comparison plot for HSQC
    plot_hsqc_by_position(combined_results, "hsqc_score_by_position.png")

    print("\nSUMMARY STATISTICS:")

    # For each rank: how many samples have the true structure at that rank
    for position in range(1, 6):
        pos_data = combined_results[combined_results["position"] == position]
        if pos_data.empty:
            continue
        correct = pos_data["is_true"].sum()
        total = len(pos_data)
        print(f"Position {position}: {correct}/{total} correct ({correct/total*100:.1f}%)")

    return combined_results


if __name__ == "__main__":
    results = main()

In [None]:
# NOTE(review): this cell relies on hidden kernel state. `json_path` and
# `raw_text__` are not defined in this cell and must already exist from an
# earlier-executed cell; on a fresh kernel it raises NameError.
# Load one result JSON into `data` (this rebinds the notebook-global `data`).
with open(json_path) as f:
    data = json.load(f)
# Replace Gemini's parsed_results in memory with `raw_text__`.
# NOTE(review): nothing here writes `data` back to `json_path`, so the file on
# disk is unchanged. Confirm whether a json.dump was intended.
data["analysis_results"]["final_analysis"]["llm_responses"]["gemini"]["parsed_results"] = raw_text__

### Baseline vs LLM

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import to_rgba

# Set the global plotting style for this section.
# The 'seaborn-v0_8-*' style names require matplotlib >= 3.6.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("deep")
plt.rcParams['figure.figsize'] = [15, 10]
plt.rcParams['font.size'] = 11

# Define experimental conditions (8 in total).
# Index-aligned with exp_labels and with baseline_accuracy below: the same
# index j is used to pair them when the records are built.
experiment_conditions = [
    "Exp-1: Simulated / Broad MW / Correct",
    "Exp-2: Simulated / Narrow MW / Correct",
    "Exp-3: Simulated / Broad MW / ROM",
    "Exp-4: Simulated / Narrow MW / ROM",
    "Exp-5: Experimental / Broad MW / Correct",
    "Exp-6: Experimental / Narrow MW / Correct", 
    "Exp-7: Experimental / Broad MW / ROM",
    "Exp-8: Simulated+Noise / Broad MW / Correct"
]

# Create shorter labels for the plot (same order as experiment_conditions)
exp_labels = [
    "Sim/Broad/Correct",
    "Sim/Narrow/Correct",
    "Sim/Broad/ROM",
    "Sim/Narrow/ROM",
    "Exp/Broad/Correct",
    "Exp/Narrow/Correct", 
    "Exp/Broad/ROM",
    "Sim+Noise/Broad/Correct"
]

# Define models. Each model's list index is passed to generate_model_accuracies
# as model_index. is_baseline=True marks the non-LLM reference (HSQC matching),
# whose accuracies are taken directly from baseline_accuracy.
models = [
    {"name": "HSQC Matching", "color": "#999999", "is_baseline": True},
    {"name": "Claude 3.5 Sonnet", "color": "#6366F1", "is_baseline": False},
    {"name": "Claude 3.7 Sonnet-Thinking", "color": "#3B82F6", "is_baseline": False},
    {"name": "DeepSeek-R1", "color": "#10B981", "is_baseline": False},
    {"name": "Gemini-Thinking", "color": "#F59E0B", "is_baseline": False},
    {"name": "o3-mini", "color": "#EC4899", "is_baseline": False},
    {"name": "Kimi 1.5", "color": "#8B5CF6", "is_baseline": False}
]

# Create dummy data based on the description
# NOTE(review): these are hand-entered placeholder values, not measured
# results. Replace them before using the resulting figure.
# Baseline accuracy for each experiment (index-aligned with experiment_conditions)
baseline_accuracy = np.array([0.85, 0.88, 0.62, 0.70, 0.60, 0.68, 0.52, 0.50])

# Create random data with the described patterns for LLM models.
# Seeds numpy's legacy global RNG, which generate_model_accuracies draws from.
np.random.seed(42)  # For reproducibility

# Function to generate accuracies for a specific model following the described patterns
def generate_model_accuracies(baseline, model_index):
    # Starting point for each model (slight variations)
    base_improvement = np.array([
        0.03,  # Sim/Broad/Correct - small improvement
        0.02,  # Sim/Narrow/Correct - small improvement
        0.10,  # Sim/Broad/ROM - moderate improvement
        0.08,  # Sim/Narrow/ROM - moderate improvement
        0.15,  # Exp/Broad/Correct - larger improvement
        0.12,  # Exp/Narrow/Correct - larger improvement
        0.20,  # Exp/Broad/ROM - largest improvement
        0.29   # Sim+Noise/Broad/Correct - dramatic improvement
    ])
    
    # Add some model-specific variation (+/- up to 5%)
    model_variation = (np.random.random(len(baseline)) - 0.5) * 0.10
    
    # More variation for reasoning models (improve more on challenging cases)
    if model_index >= 2:  # Reasoning models
        # Enhance improvement for challenging cases (noise and experimental)
        base_improvement[4:] += 0.03 * (model_index - 1)  # Progressive boost
    
    # Calculate accuracies, ensuring they don't exceed 1.0
    accuracies = np.minimum(baseline + base_improvement + model_variation, 0.95)
    
    return accuracies

# Generate data for all models
all_accuracies = []
for i, model in enumerate(models):
    if model["is_baseline"]:
        accuracies = baseline_accuracy
    else:
        accuracies = generate_model_accuracies(baseline_accuracy, i)
    
    # Create records for each experiment
    for j, exp in enumerate(experiment_conditions):
        all_accuracies.append({
            "experiment": exp,
            "experiment_short": exp_labels[j],
            "model": model["name"],
            "accuracy": accuracies[j],
            "color": model["color"],
            "is_baseline": model["is_baseline"]
        })

# Convert to DataFrame
df = pd.DataFrame(all_accuracies)

# NOTE(review): `df` is built from the dummy-data generator above, so this
# figure shows placeholder numbers, not measured results.

# One panel per experimental condition. The x-axis must NOT be shared: every
# panel orders its bars by accuracy, so each needs its own tick labels (with
# sharex=True the tick formatter is shared and the last panel's labels would
# be shown under every panel).
fig, axes = plt.subplots(2, 4, figsize=(20, 10), sharex=False)
axes = axes.flatten()

for i, exp in enumerate(experiment_conditions):
    ax = axes[i]

    # Rows for this condition, ranked best-first
    exp_data = df[df['experiment'] == exp].sort_values('accuracy', ascending=False)
    baseline_value = exp_data.loc[exp_data['is_baseline'], 'accuracy'].iloc[0]
    model_data = exp_data[~exp_data['is_baseline']]
    model_acc = model_data['accuracy'].to_numpy()
    bar_x = np.arange(len(model_data))

    # Baseline as a dashed reference line with an inline label
    ax.axhline(y=baseline_value, color='black', linestyle='--',
               alpha=0.7, label='_nolegend_')
    ax.text(0.02, baseline_value + 0.01, f"Baseline: {baseline_value:.2f}",
            transform=ax.get_yaxis_transform(), ha='left', va='bottom',
            fontsize=9, fontstyle='italic')

    bars = ax.bar(bar_x, model_acc, color=model_data['color'].tolist())

    # Annotate only gains over the baseline
    for pos, acc in zip(bar_x, model_acc):
        improvement = acc - baseline_value
        if improvement > 0:
            ax.text(pos, acc + 0.01, f"+{improvement:.2f}",
                    ha='center', va='bottom', fontsize=9, fontweight='bold')

    ax.set_title(exp_labels[i], fontsize=11, fontweight='bold')
    ax.set_ylim(0.45, 1.0)
    ax.set_xticks(bar_x)
    # Full model names: the first word alone is ambiguous ("Claude" appears twice)
    ax.set_xticklabels(model_data['model'].tolist(), rotation=45, ha='right', fontsize=8)

    # Outline the best-performing model of this condition
    best_pos = int(np.argmax(model_acc))
    bars[best_pos].set_edgecolor('black')
    bars[best_pos].set_linewidth(2)

    ax.grid(axis='y', linestyle='--', alpha=0.3)

    # Background shade: the first four conditions use simulated data
    ax.patch.set_facecolor(to_rgba('#f8f9fa' if i < 4 else '#e9ecef', 0.2))

# Common y-label and title
fig.text(0.01, 0.5, 'Top-1 Accuracy', va='center', rotation='vertical', fontsize=14, fontweight='bold')
fig.suptitle('Comparison of HSQC Matching vs. LLM-Enhanced Structure Elucidation',
             fontsize=16, fontweight='bold', y=0.98)

# Shared legend for the LLM bars (the baseline is the dashed line, so it is left out)
legend_elements = [
    plt.Rectangle((0, 0), 1, 1, color=m["color"], label=m["name"])
    for m in models if not m["is_baseline"]
]
# Anchored by its lower edge so it sits above the caption at y=0.01 instead of on top of it
fig.legend(handles=legend_elements, loc='lower center',
           bbox_to_anchor=(0.5, 0.04), ncol=6, fontsize=11)

# Caption
fig.text(0.5, 0.01,
         "Figure 5: Comparison of baseline HSQC matching vs. LLM-enhanced accuracy across experimental conditions. " +
         "Each subplot represents a different experimental condition, with the baseline HSQC matching accuracy shown as a dashed line. " +
         "Bars indicate Top-1 accuracy for different LLM models, with improvements over baseline annotated above each bar. " +
         "Note the larger improvements in challenging conditions (experimental data, noise, and ROM starting structures).",
         ha='center', fontsize=10, style='italic', wrap=True)

# Leave the bottom 8% for legend + caption and the top for the suptitle
fig.tight_layout(rect=[0.02, 0.08, 0.98, 0.95])

fig.savefig('performance_comparison.png', dpi=300, bbox_inches='tight')
plt.show()


# Summary figure: HSQC baseline vs. the best LLM for every condition.
# NOTE(review): also based on the dummy data in `df`.

# Baseline rows, explicitly ordered like experiment_conditions so bars match tick labels
baseline_df = (df[df['is_baseline']]
               .set_index('experiment')
               .loc[experiment_conditions]
               .reset_index())

# Highest-accuracy LLM row per condition (first row wins on ties), same order
llm_df = df[~df['is_baseline']]
best_llm_df = (llm_df.loc[llm_df.groupby('experiment')['accuracy'].idxmax()]
               .set_index('experiment')
               .loc[experiment_conditions]
               .reset_index())

# Grouped bar chart: baseline on the left, best LLM on the right of each tick
x = np.arange(len(experiment_conditions))
width = 0.35

# Only one figure is created here; an extra plt.figure() beforehand would
# leave an empty figure that plt.show() also renders.
fig, ax = plt.subplots(figsize=(14, 7))
baseline_bars = ax.bar(x - width / 2, baseline_df['accuracy'], width,
                       color='#999999', label='HSQC Matching Baseline')
llm_bars = ax.bar(x + width / 2, best_llm_df['accuracy'], width,
                  color='#3B82F6', label='Best LLM-Enhanced')

# Annotate only improvements larger than 0.05
for i, (base, best) in enumerate(zip(baseline_df['accuracy'], best_llm_df['accuracy'])):
    improvement = best - base
    if improvement > 0.05:
        ax.annotate(f'+{improvement:.2f}',
                    xy=(i + width / 2, best + 0.01),
                    ha='center', va='bottom',
                    fontweight='bold', color='#1e40af')

ax.set_ylabel('Top-1 Accuracy', fontsize=12, fontweight='bold')
ax.set_title('Baseline vs. Best LLM-Enhanced Performance Across Experimental Conditions',
             fontsize=14, fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(exp_labels, rotation=45, ha='right')
ax.legend()

ax.grid(axis='y', linestyle='--', alpha=0.3)
ax.set_ylim(0.4, 1.0)

# Divider between the first four (simulated) and last four conditions
ax.axvline(x=3.5, color='black', linestyle='--', alpha=0.5)
ax.text(1.5, 0.42, 'Simulated Data', ha='center', fontsize=10, fontweight='bold')
ax.text(5.5, 0.42, 'Experimental/Noise Data', ha='center', fontsize=10, fontweight='bold')

# Caption
fig.text(0.5, 0.01,
         "Figure 4: Comparison of baseline HSQC matching accuracy vs. best LLM-enhanced performance across all experimental conditions. " +
         "Note the substantial improvements in challenging scenarios (right side) with experimental data, noise, or incorrect starting structures.",
         ha='center', fontsize=10, style='italic', wrap=True)

fig.tight_layout(rect=[0, 0.05, 1, 0.98])
fig.savefig('baseline_vs_best_llm.png', dpi=300, bbox_inches='tight')
plt.show()

## Plot Molecules

### Richard Molecules

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import Descriptors
import math

def plot_molecules_from_csv(csv_path, molecules_per_row=5):
    """
    Plot all molecules from a CSV file with their sample ID and mass.

    The SMILES column is the first column whose name contains "SMILES", and the
    sample-ID column is the first whose name contains both "SAMPLE" and "ID"
    (both case-insensitive). Rows whose SMILES is missing or cannot be parsed
    by RDKit are shown as a labelled empty cell instead of aborting the plot.

    Parameters:
    -----------
    csv_path : str
        Path to the CSV file
    molecules_per_row : int, optional
        Number of molecules to plot in each row (default is 5)

    Returns:
    --------
    matplotlib figure with molecule grid

    Raises:
    -------
    ValueError
        If no SMILES or sample-ID column is found, the CSV has no rows, or
        molecules_per_row is smaller than 1.
    """
    df = pd.read_csv(csv_path)

    # str() so non-string column labels do not break .upper()
    smiles_candidates = [col for col in df.columns if 'SMILES' in str(col).upper()]
    sample_id_candidates = [col for col in df.columns
                            if 'SAMPLE' in str(col).upper() and 'ID' in str(col).upper()]
    if not smiles_candidates:
        raise ValueError(f"No SMILES column found in {csv_path}; columns: {list(df.columns)}")
    if not sample_id_candidates:
        raise ValueError(f"No sample-ID column found in {csv_path}; columns: {list(df.columns)}")
    if df.empty:
        raise ValueError(f"No molecules to plot in {csv_path}")
    if molecules_per_row < 1:
        raise ValueError("molecules_per_row must be at least 1")
    smiles_column = smiles_candidates[0]
    sample_id_column = sample_id_candidates[0]

    num_rows = math.ceil(len(df) / molecules_per_row)

    # squeeze=False always returns a 2-D array of Axes, so ravel() works for
    # every grid shape, including a single row or a single cell.
    fig, axes = plt.subplots(num_rows, molecules_per_row,
                             figsize=(4 * molecules_per_row, 4 * num_rows),
                             squeeze=False)
    axes = axes.ravel()

    for ax, smiles, sample_id in zip(axes, df[smiles_column], df[sample_id_column]):
        # Blank cells arrive as NaN (not a str); RDKit returns None for unparsable SMILES
        mol = Chem.MolFromSmiles(smiles) if isinstance(smiles, str) else None
        if mol is None:
            ax.set_title(f'Sample ID: {sample_id}\nInvalid SMILES',
                         fontsize=14, fontweight='bold', wrap=True)
            continue

        # NOTE(review): ExactMolWt is RDKit's exact (monoisotopic) mass, while the label
        # says "g/mol" (average MW, Descriptors.MolWt) - confirm which is intended.
        mol_weight = round(Descriptors.ExactMolWt(mol), 2)
        ax.imshow(Draw.MolToImage(mol, size=(300, 300)))
        ax.set_title(f'Sample ID: {sample_id}\nMW: {mol_weight} g/mol',
                     fontsize=14, fontweight='bold', wrap=True)

    # No axes anywhere: molecule cells show only the image, unused cells stay blank
    for ax in axes:
        ax.axis('off')

    fig.tight_layout()
    return fig

# Example usage: simulated-data molecules, displayed inline.
# Set SAVE_SIM_FIGURE = True to also write the PNG to the Figures folder.
SAVE_SIM_FIGURE = False

fig = plot_molecules_from_csv('/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv')
if SAVE_SIM_FIGURE:
    fig.savefig('/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures/molecule_visualization.png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    # Only report a save when one actually happened
    print("Molecule visualization has been saved as 'molecule_visualization.png'")

In [None]:

# Example usage: real-data molecules (augmented set), saved to the Figures folder
fig = plot_molecules_from_csv('/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/real_data/combined_real_nmr_data_no_stereo_aug.csv')
fig.savefig('/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures/molecule_visualization_WG.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# Report the file name that was actually written
print("Molecule visualization has been saved as 'molecule_visualization_WG.png'")

### Lukas Molecules

In [None]:

# Lukas dataset (augmented): render the grid, write it to disk, then free the figure
lukas_aug_csv = '/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/53_Lukas_real_data/cleaned_data_aug_CLEAN.csv'
lukas_aug_png = '/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures/molecule_visualization_aug_Lukas_WG.png'

fig = plot_molecules_from_csv(lukas_aug_csv)
fig.savefig(lukas_aug_png, dpi=300, bbox_inches='tight')
plt.close(fig)

print("Molecule visualization has been saved as 'molecule_visualization_aug_Lukas_WG.png'")

In [None]:

# Example usage: Lukas dataset (non-augmented), saved to the Figures folder
fig = plot_molecules_from_csv('/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/53_Lukas_real_data/cleaned_data_CLEAN.csv')
fig.savefig('/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/Figures/molecule_visualization_Lukas_WG.png', dpi=300, bbox_inches='tight')
plt.close(fig)

# Report the file name that was actually written
print("Molecule visualization has been saved as 'molecule_visualization_Lukas_WG.png'")

## Plot successes for each method

### V1 Used

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os

def load_reference_data(csv_path):
    """Map each 'sample-id' in the reference CSV to its 'SMILES' string.

    Returns a plain dict for fast lookups; if a sample ID occurs more than
    once, the last row wins.
    """
    reference = pd.read_csv(csv_path)
    sample_ids = reference['sample-id'].tolist()
    smiles = reference['SMILES'].tolist()
    return dict(zip(sample_ids, smiles))

def get_base_sample_id(sample_id):
    """Return the part of ``sample_id`` before the first underscore ('' for an empty/None ID)."""
    if not sample_id:
        return ''
    head, _, _ = sample_id.partition('_')
    return head

def analyze_approaches_by_json(json_data, true_smiles):
    """
    Report, per generation approach, whether ``true_smiles`` is among its candidates.

    Reads ``json_data["molecule_data"]["candidate_analysis"]``. For each of
    'forward_synthesis', 'mol2mol' and 'mmst' present there, it scans that
    entry's 'molecules' list for an item whose 'smiles' equals ``true_smiles``.
    Items without a 'smiles' key are ignored.

    NOTE(review): matching is an exact string comparison, so an equivalent but
    differently written SMILES of the same molecule is not counted.

    Returns:
        dict with boolean flags 'forward_synthesis', 'mol2mol', 'mmst' and
        'any' (True if at least one approach matched), or None when the JSON
        lacks the candidate_analysis section.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except KeyError:
        return None

    def _is_match(candidate):
        # Entries lacking a 'smiles' key simply do not match
        try:
            return candidate['smiles'] == true_smiles
        except KeyError:
            return False

    hits = {}
    for approach in ('forward_synthesis', 'mol2mol', 'mmst'):
        if approach not in candidate_analysis:
            hits[approach] = False
            continue
        molecules = candidate_analysis[approach].get('molecules', [])
        hits[approach] = any(_is_match(candidate) for candidate in molecules)

    hits['any'] = hits['forward_synthesis'] or hits['mol2mol'] or hits['mmst']
    return hits

def analyze_directory_by_approach(json_dir, reference_csv):
    """
    Scan every ``*.json`` result file in ``json_dir`` and count, per generation
    approach, how often the reference molecule appears among its candidates.

    Args:
        json_dir: Directory containing the per-sample JSON result files.
        reference_csv: CSV with 'sample-id' and 'SMILES' columns giving the
            correct structure for each base sample ID.

    Returns:
        (approach_counts, molecule_data):
        approach_counts maps 'forward_synthesis', 'mol2mol', 'mmst' and 'any'
        to the number of samples where the reference SMILES was found, plus
        'total' (number of samples that could be evaluated). molecule_data is
        a list of per-sample dicts with the individual hit flags, in file-name
        order.

    Files without a sample ID, without a usable reference SMILES, or without a
    'candidate_analysis' section are skipped. Files that fail to load or parse
    are reported and skipped (best effort).
    """
    reference_data = load_reference_data(reference_csv)

    # Sorted so results (and the per-molecule CSV) are in a reproducible order;
    # raw glob order depends on the filesystem.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    approach_counts = {
        'forward_synthesis': 0,
        'mol2mol': 0,
        'mmst': 0,
        'any': 0,
        'total': 0
    }
    molecule_data = []

    for file_path in json_files:
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            molecule_data_json = json_data.get('molecule_data', {})
            sample_id = molecule_data_json.get('sample_id')
            if not sample_id:
                continue

            # Reference CSV is keyed by the base ID (text before the first '_')
            base_sample_id = get_base_sample_id(sample_id)
            true_smiles = reference_data.get(base_sample_id)
            # A blank SMILES cell is read by pandas as NaN; such samples can never
            # match and would only deflate the percentages.
            if true_smiles is None or pd.isna(true_smiles):
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            results = analyze_approaches_by_json(json_data, true_smiles)
            if results is None:
                continue

            approach_counts['total'] += 1
            for approach, found in results.items():
                if found:
                    approach_counts[approach] += 1

            molecule_data.append({
                'sample_id': sample_id,
                'true_smiles': true_smiles,
                'found_in_forward_synthesis': results['forward_synthesis'],
                'found_in_mol2mol': results['mol2mol'],
                'found_in_mmst': results['mmst'],
                'found_anywhere': results['any']
            })

        except Exception as e:
            # Deliberate best effort: one corrupt file must not abort the whole analysis
            print(f"Error processing {file_path}: {str(e)}")
            continue

    return approach_counts, molecule_data

def create_enhanced_comparison_plot(experiments_data, output_dir):
    """
    Create a single grouped bar chart comparing, across experiments, how many
    reference molecules each structure generation approach recovered, and save
    it as PNG and SVG.

    Args:
        experiments_data: List of dicts, each with a 'label' (str) and 'counts'
            (dict with integer 'forward_synthesis', 'mol2mol', 'mmst' and
            'total' entries, as returned by analyze_directory_by_approach).
        output_dir: Directory for "enhanced_approach_comparison.png" and
            ".svg" (created if it does not exist).

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(12, 8), facecolor='white')
    ax.set_facecolor('white')

    approaches = ['forward_synthesis', 'mol2mol', 'mmst']
    friendly_names = {
        'forward_synthesis': 'Retrosynthesis-Forward',
        'mol2mol': 'Mol2Mol Analogues',
        'mmst': 'MMST-Driven Generation'
    }
    colors = {
        'forward_synthesis': '#529356',  # Green
        'mol2mol': '#FFB047',           # Orange
        'mmst': '#9F45B7'               # Purple
    }

    # Experiment groups sit at x = 0, 1, 2, ...; each approach's bar is shifted by bar_width
    bar_width = 0.275
    group_x = np.arange(len(experiments_data))
    positions = {approach: group_x + i * bar_width for i, approach in enumerate(approaches)}

    max_count = max((exp['counts'][a] for exp in experiments_data for a in approaches), default=0)
    # Vertical scale for limits and label offsets. It is at least 1 so that
    # all-zero or empty input does not give the degenerate y-range (0, 0).
    y_scale = max(max_count, 1)

    for approach in approaches:
        counts = [exp['counts'][approach] for exp in experiments_data]
        percentages = [
            exp['counts'][approach] / exp['counts']['total'] * 100 if exp['counts']['total'] > 0 else 0
            for exp in experiments_data
        ]
        bars = ax.bar(
            positions[approach],
            counts,
            width=bar_width,
            color=colors[approach],
            edgecolor='black',
            alpha=0.9,
            label=friendly_names[approach]
        )

        for bar, count, percentage in zip(bars, counts, percentages):
            x_center = bar.get_x() + bar.get_width() / 2
            height = bar.get_height()
            if height <= 0:
                # Empty bar: still report the zero, just above the x-axis
                ax.text(x_center, y_scale * 0.015, '0\n(0.0%)',
                        ha='center', va='bottom', fontsize=14, color='black')
                continue
            if height < y_scale * 0.15:
                # Short bar (< 15% of the tallest): black label above the bar
                y_pos, text_color, va = height + y_scale * 0.02, 'black', 'bottom'
            else:
                # Tall bar: white label centred inside the bar
                y_pos, text_color, va = height / 2, 'white', 'center'
            ax.text(x_center, y_pos, f'{count}\n({percentage:.1f}%)',
                    ha='center', va=va, fontsize=14, color=text_color)

    ax.set_title('Comparison of Structure Generation Approaches', fontsize=20, fontweight='bold')
    ax.set_ylabel('Number of Molecules', fontsize=16)
    ax.tick_params(axis='both', labelsize=14)

    # One x tick per experiment, centred under its group of bars, single-line labels
    middle_positions = (positions['forward_synthesis'] + positions['mmst']) / 2
    ax.set_xticks(middle_positions)
    ax.set_xticklabels([exp['label'].replace('\n', ' ') for exp in experiments_data], fontsize=14)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')
    ax.legend(fontsize=14, loc='upper center', bbox_to_anchor=(0.5, -0.12),
              frameon=True, fancybox=True, shadow=True, ncol=3)

    # Number of evaluated molecules, written just below the axis for each experiment
    for x_mid, exp in zip(middle_positions, experiments_data):
        ax.text(x_mid, -y_scale * 0.05, f'Total: {exp["counts"]["total"]}',
                ha='center', va='top', fontsize=14, fontweight='bold')

    ax.set_ylim(0, y_scale * 1.2)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig.tight_layout()
    fig.subplots_adjust(bottom=0.18)  # room for the legend below the axes

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "enhanced_approach_comparison.png")
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved enhanced comparison figure to: {output_path}")

    svg_output_path = os.path.join(output_dir, "enhanced_approach_comparison.svg")
    fig.savefig(svg_output_path, format='svg', bbox_inches='tight')
    print(f"Saved enhanced comparison figure as SVG to: {svg_output_path}")

    return fig

def plot_approach_comparison(approach_counts, experiment_label=""):
    """
    Bar chart of how many reference molecules each generation approach
    recovered for a single experiment.

    Args:
        approach_counts: dict with integer counts under 'forward_synthesis',
            'mol2mol', 'mmst' and 'total' (as returned by
            analyze_directory_by_approach).
        experiment_label: Text shown as the second title line.

    Returns:
        (fig, summary_df) where summary_df has one row per approach with the
        columns 'Approach', 'Count' and 'Percentage' (percent of 'total').
    """
    fig, ax = plt.subplots(figsize=(10, 6), facecolor='white')
    ax.set_facecolor('white')

    approaches = ['forward_synthesis', 'mol2mol', 'mmst']
    friendly_names = {
        'forward_synthesis': 'Retrosynthesis',
        'mol2mol': 'Mol2Mol',
        'mmst': 'MMST'
    }
    # Same palette as the cross-experiment comparison figure, so each approach
    # keeps one colour in all figures (green / orange / purple)
    colors = {'forward_synthesis': '#529356', 'mol2mol': '#FFB047', 'mmst': '#9F45B7'}

    total = approach_counts['total']
    counts = [approach_counts[a] for a in approaches]
    percentages = [count / total * 100 if total > 0 else 0 for count in counts]
    labels = [friendly_names[a] for a in approaches]

    bars = ax.bar(
        labels,
        counts,
        color=[colors[a] for a in approaches],
        edgecolor='black',
        alpha=0.7
    )

    ax.set_title(f'Correct Molecules Found by Approach\n{experiment_label}', fontsize=16)
    ax.set_xlabel('Approach', fontsize=14)
    ax.set_ylabel('Number of Molecules', fontsize=14)

    # Headroom above the tallest bar; at least 1 so an all-zero experiment
    # does not give the degenerate y-range (0, 0)
    y_scale = max(max(counts), 1)
    ax.set_ylim(0, y_scale * 1.1)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

    for bar, count, percentage in zip(bars, counts, percentages):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 0.1,
            f'{count} ({percentage:.1f}%)',
            ha='center',
            va='bottom',
            fontsize=12
        )

    ax.tick_params(axis='both', labelsize=12)

    ax.text(0.02, 0.98, f'Total molecules: {total}',
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='left',
            fontsize=12,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.tight_layout()

    return fig, pd.DataFrame({
        'Approach': labels,
        'Count': counts,
        'Percentage': percentages
    })

def _label_to_filename(label):
    """Return ``label`` with spaces and newlines replaced by underscores, for use in file names."""
    return label.replace(' ', '_').replace('\n', '_')


def main():
    """
    Run the approach analysis for every configured experiment and write the results.

    For each experiment it:
    - counts where the reference molecule was found (analyze_directory_by_approach);
    - saves a per-experiment bar chart (PNG) and a per-approach summary CSV into ./output;
    - saves a per-molecule CSV when any molecules were analyzed.

    Afterwards it writes the cross-experiment comparison figure and a Markdown
    figure description.
    """
    experiments = [
        {
            "label": "Simulated Data",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        {
            "label": "Simulated Data with Wrong Guess",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        }
    ]

    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)

    experiments_data = []

    for experiment in experiments:
        label = experiment['label']
        print(f"\nAnalyzing {label}...")

        approach_counts, molecule_data = analyze_directory_by_approach(
            experiment["json_directory"],
            experiment["reference_csv"]
        )

        fig, summary_df = plot_approach_comparison(approach_counts, label)

        # Build the file-name fragment outside the f-strings: a backslash inside an
        # f-string expression is a SyntaxError before Python 3.12.
        file_tag = _label_to_filename(label)

        output_path = os.path.join(output_dir, f"approach_comparison_{file_tag}.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

        summary_path = os.path.join(output_dir, f"approach_summary_{file_tag}.csv")
        summary_df.to_csv(summary_path, index=False)
        print(f"Saved summary to: {summary_path}")

        if molecule_data:
            molecule_df = pd.DataFrame(molecule_data)
            molecule_path = os.path.join(output_dir, f"molecule_data_{file_tag}.csv")
            molecule_df.to_csv(molecule_path, index=False)
            print(f"Saved molecule data to: {molecule_path}")

        experiments_data.append({
            'label': label,
            'counts': approach_counts
        })

        print("\nApproach counts:")
        for approach, count in approach_counts.items():
            if approach not in ['total', 'any']:
                percentage = count / approach_counts['total'] * 100 if approach_counts['total'] > 0 else 0
                print(f"  {approach}: {count} ({percentage:.1f}%)")
        print(f"  Total molecules: {approach_counts['total']}")

    create_enhanced_comparison_plot(experiments_data, output_dir)

    # Totals come from the data just analyzed rather than a hard-coded number
    totals_text = "; ".join(f"{exp['label']}: {exp['counts']['total']}" for exp in experiments_data)

    figure_description = f"""
# Figure X: Comparison of Structure Generation Approaches

Comparison of the three complementary structure generation approaches (Retrosynthesis-Forward in green, Mol2Mol Analogues in orange, and MMST-Driven Generation in purple) across two experimental conditions. Left: Performance with correct initial structure. Right: Performance with regioisomeric initial structure (wrong starting guess). Numbers indicate the count of correct molecules found by each approach, with percentages in parentheses. Total molecules analyzed per dataset: {totals_text}.

## Key Findings

**With Correct Initial Structure:**
The Retrosynthesis-Forward approach demonstrates exceptional performance when provided with the correct starting structure. This is expected behavior as this method effectively explores the local chemical space around the target molecule through plausible synthetic transformations.

**With Incorrect Initial Structure (Regioisomeric Guess):**
The performance pattern changes dramatically when an incorrect regioisomeric structure serves as the initial guess. The Retrosynthesis-Forward approach drops significantly in effectiveness, as it cannot easily reconstruct the correct connectivity patterns from the wrong starting point.

In contrast, the Mol2Mol approach demonstrates remarkable resilience to incorrect starting structures, maintaining high performance even with regioisomeric initial guesses. This highlights Mol2Mol's strength in structural modification—it efficiently explores diverse molecular variations while preserving core scaffold features, enabling discovery of the correct structure despite the initial error.

The MMST-Driven Generation approach shows consistent performance across both scenarios. Its data-driven nature, enhanced by the improvement cycle with on-the-fly fine-tuning, allows it to generate candidates based primarily on spectral patterns rather than relying heavily on the initial structure.
"""

    # Explicit UTF-8: the text contains non-ASCII characters (em dash)
    description_path = os.path.join(output_dir, "figure_description.md")
    with open(description_path, 'w', encoding='utf-8') as f:
        f.write(figure_description)
    print(f"Saved figure description to: {description_path}")

if __name__ == "__main__":
    main()

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os

def load_reference_data(csv_path):
    """Map each 'sample-id' in the reference CSV to its 'SMILES' string.

    Returns a plain dict for fast lookups; if a sample ID occurs more than
    once, the last row wins.

    NOTE(review): this cell redefines the helper of the same name from the
    "V1 Used" cell above; keep both copies in sync or drop one.
    """
    reference = pd.read_csv(csv_path)
    sample_ids = reference['sample-id'].tolist()
    smiles = reference['SMILES'].tolist()
    return dict(zip(sample_ids, smiles))

def get_base_sample_id(sample_id):
    """Return the part of ``sample_id`` before the first underscore ('' for an empty/None ID).

    NOTE(review): redefinition of the helper from the "V1 Used" cell above.
    """
    if not sample_id:
        return ''
    head, _, _ = sample_id.partition('_')
    return head

def analyze_approaches_by_json(json_data, true_smiles):
    """
    Report, per generation approach, whether ``true_smiles`` is among its candidates.

    Reads ``json_data["molecule_data"]["candidate_analysis"]``. For each of
    'forward_synthesis', 'mol2mol' and 'mmst' present there, it scans that
    entry's 'molecules' list for an item whose 'smiles' equals ``true_smiles``.
    Items without a 'smiles' key are ignored.

    NOTE(review): matching is an exact string comparison, so an equivalent but
    differently written SMILES of the same molecule is not counted. This cell
    also redefines the same helper from the "V1 Used" cell above.

    Returns:
        dict with boolean flags 'forward_synthesis', 'mol2mol', 'mmst' and
        'any' (True if at least one approach matched), or None when the JSON
        lacks the candidate_analysis section.
    """
    try:
        candidate_analysis = json_data["molecule_data"]['candidate_analysis']
    except KeyError:
        return None

    def _is_match(candidate):
        # Entries lacking a 'smiles' key simply do not match
        try:
            return candidate['smiles'] == true_smiles
        except KeyError:
            return False

    hits = {}
    for approach in ('forward_synthesis', 'mol2mol', 'mmst'):
        if approach not in candidate_analysis:
            hits[approach] = False
            continue
        molecules = candidate_analysis[approach].get('molecules', [])
        hits[approach] = any(_is_match(candidate) for candidate in molecules)

    hits['any'] = hits['forward_synthesis'] or hits['mol2mol'] or hits['mmst']
    return hits

def analyze_directory_by_approach(json_dir, reference_csv):
    """
    Analyze all JSON files in ``json_dir`` and count where the correct molecule was found.

    Args:
        json_dir: Directory containing ``*.json`` result files.
        reference_csv: CSV with ``sample-id`` and ``SMILES`` columns (see load_reference_data).

    Returns:
        (approach_counts, molecule_data):
            approach_counts -- dict of hit counts for 'forward_synthesis', 'mol2mol',
                'mmst', 'any', plus 'total' (files successfully matched and analysed).
            molecule_data -- list of per-sample dicts, one per analysed file, in
                sorted file-path order.
    """
    reference_data = load_reference_data(reference_csv)

    # glob returns files in arbitrary filesystem order; sort so the per-molecule
    # output (and any CSV written from it) is reproducible between runs.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    approach_counts = {
        'forward_synthesis': 0,
        'mol2mol': 0,
        'mmst': 0,
        'any': 0,
        'total': 0
    }

    # One record per analysed molecule
    molecule_data = []

    for file_path in json_files:
        try:
            # JSON is UTF-8 by specification; don't depend on the platform default encoding.
            with open(file_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            # `or {}` also tolerates an explicit "molecule_data": null
            molecule_data_json = json_data.get('molecule_data') or {}
            sample_id = molecule_data_json.get('sample_id')

            if not sample_id:
                continue

            # Reference CSV is keyed by the base ID (part before the first underscore)
            base_sample_id = get_base_sample_id(sample_id)

            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            results = analyze_approaches_by_json(json_data, true_smiles)
            if results is None:
                continue

            approach_counts['total'] += 1
            for approach, found in results.items():
                if found:
                    approach_counts[approach] += 1

            molecule_data.append({
                'sample_id': sample_id,
                'true_smiles': true_smiles,
                'found_in_forward_synthesis': results['forward_synthesis'],
                'found_in_mol2mol': results['mol2mol'],
                'found_in_mmst': results['mmst'],
                'found_anywhere': results['any']
            })

        except Exception as e:
            # Best-effort batch: report and skip unreadable/malformed files instead of aborting.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    return approach_counts, molecule_data

def create_enhanced_comparison_plot(experiments_data, output_dir):
    """
    Create a grouped bar chart comparing approach effectiveness across experiments
    (green / orange / purple colour scheme) and save it as PNG and SVG.

    Args:
        experiments_data: List of dicts, each with 'label' (x-tick text; newlines are
            flattened to spaces) and 'counts' (dict with 'forward_synthesis', 'mol2mol',
            'mmst' hit counts and 'total' molecules analysed).
        output_dir: Directory where narrow_approach_comparison.png/.svg are written.

    Returns:
        The matplotlib Figure.
    """
    # Create figure with white background - narrower width
    fig, ax = plt.subplots(figsize=(8, 8), facecolor='white')
    ax.set_facecolor('white')

    approaches = ['forward_synthesis', 'mol2mol', 'mmst']
    friendly_names = {
        'forward_synthesis': 'Retrosynthesis-Forward',
        'mol2mol': 'Mol2Mol Analogues',
        'mmst': 'MMST-Driven Generation'
    }
    colors = {
        'forward_synthesis': '#529356',  # Green
        'mol2mol': '#FFB047',           # Orange
        'mmst': '#9F45B7'               # Purple
    }

    bar_width = 0.25  # Slightly narrower bars for the smaller plot

    # One group per experiment at x = 0, 1, 2, ...; approaches are offset within a group.
    group_starts = np.arange(len(experiments_data))
    positions = {approach: group_starts + i * bar_width for i, approach in enumerate(approaches)}

    # The tallest single bar drives the y-limit and all label offsets. Clamp it to 1 so an
    # all-zero (or empty) dataset still gets a usable axis instead of a singular ylim(0, 0).
    max_count = max(
        (exp['counts'][approach] for exp in experiments_data for approach in approaches),
        default=0,
    ) or 1

    for approach in approaches:
        counts = [exp['counts'][approach] for exp in experiments_data]
        percentages = [
            exp['counts'][approach] / exp['counts']['total'] * 100 if exp['counts']['total'] > 0 else 0
            for exp in experiments_data
        ]

        bar_container = ax.bar(
            positions[approach],
            counts,
            width=bar_width,
            color=colors[approach],
            edgecolor='black',
            alpha=0.9,
            label=friendly_names[approach]
        )

        # Every bar is labelled, including zero-height ones. Short bars (< 15% of the
        # tallest) get a black label just above; taller bars get a white label centred inside.
        for bar, count, percentage in zip(bar_container, counts, percentages):
            height = bar.get_height()
            if height < max_count * 0.15:
                y_pos, color, va = height + max_count * 0.02, 'black', 'bottom'
            else:
                y_pos, color, va = height / 2, 'white', 'center'

            ax.text(
                bar.get_x() + bar.get_width() / 2,
                y_pos,
                f'{count}\n({percentage:.1f}%)',
                ha='center',
                va=va,
                fontsize=12,
                color=color
            )

    ax.set_title('Comparison of Structure Generation Approaches', fontsize=16, fontweight='bold')
    ax.set_ylabel('Number of Molecules', fontsize=14)
    ax.tick_params(axis='both', labelsize=12)

    # Centre each x tick on its experiment group (between the first and last bar)
    middle_positions = [(positions['forward_synthesis'][i] + positions['mmst'][i]) / 2 for i in range(len(experiments_data))]
    ax.set_xticks(middle_positions)

    # Use single-line labels without rotation
    flat_labels = [exp['label'].replace('\n', ' ') for exp in experiments_data]
    ax.set_xticklabels(flat_labels, fontsize=12)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

    # Legend placed below the axes, in a single row
    ax.legend(fontsize=12, loc='upper center', bbox_to_anchor=(0.5, -0.14),
              frameon=True, fancybox=True, shadow=True, ncol=3)

    # Total molecules per experiment, drawn just below the x-axis (text is not clipped)
    for middle, exp in zip(middle_positions, experiments_data):
        ax.text(
            middle,
            -max_count * 0.05,
            f'Total: {exp["counts"]["total"]}',
            ha='center',
            va='top',
            fontsize=12,
            fontweight='bold'
        )

    # Headroom above the tallest bar for its label
    ax.set_ylim(0, max_count * 1.2)

    # Remove top and right spines for cleaner look
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    plt.tight_layout()
    # Extra bottom margin for the legend below the axes
    fig.subplots_adjust(bottom=0.2)

    # Save the plot with high resolution - use a different filename to avoid overwriting
    output_path = os.path.join(output_dir, "narrow_approach_comparison.png")
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved narrow comparison figure to: {output_path}")

    # Also save as SVG for vector graphics
    svg_output_path = os.path.join(output_dir, "narrow_approach_comparison.svg")
    fig.savefig(svg_output_path, format='svg', bbox_inches='tight')
    print(f"Saved narrow comparison figure as SVG to: {svg_output_path}")

    return fig

def plot_approach_comparison(approach_counts, experiment_label=""):
    """
    Create a bar chart of how many correct molecules each approach found in one experiment.

    Args:
        approach_counts: dict with 'forward_synthesis', 'mol2mol', 'mmst' hit counts
            and 'total' molecules analysed.
        experiment_label: Text appended as a second title line.

    Returns:
        (fig, summary_df): the matplotlib Figure, and a DataFrame with columns
        Approach / Count / Percentage (percentage of 'total', 0 when total is 0).
    """
    fig, ax = plt.subplots(figsize=(10, 6), facecolor='white')
    ax.set_facecolor('white')

    approaches = ['forward_synthesis', 'mol2mol', 'mmst']
    friendly_names = {
        'forward_synthesis': 'Retrosynthesis',
        'mol2mol': 'Mol2Mol',
        'mmst': 'MMST'
    }

    total = approach_counts['total']
    percentages = [approach_counts[a] / total * 100 if total > 0 else 0 for a in approaches]
    counts = [approach_counts[a] for a in approaches]

    colors = {'forward_synthesis': '#6366F1', 'mol2mol': '#3B82F6', 'mmst': '#10B981'}

    bars = ax.bar(
        [friendly_names[a] for a in approaches],
        counts,
        color=[colors[a] for a in approaches],
        edgecolor='black',
        alpha=0.7
    )

    title = f'Correct Molecules Found by Approach\n{experiment_label}'
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Approach', fontsize=14)
    ax.set_ylabel('Number of Molecules', fontsize=14)

    # Extend the y-axis slightly above the tallest bar. When no approach found anything,
    # fall back to 1 so the limits are not the singular (0, 0).
    max_count = max(counts) or 1
    ax.set_ylim(0, max_count * 1.1)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

    # Count and percentage just above each bar
    for bar, count, percentage in zip(bars, counts, percentages):
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width()/2,
            height + 0.1,
            f'{count} ({percentage:.1f}%)',
            ha='center',
            va='bottom',
            fontsize=12
        )

    ax.tick_params(axis='both', labelsize=12)

    # Total molecules analysed, in the upper-left corner (axes coordinates)
    ax.text(0.02, 0.98, f'Total molecules: {total}',
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='left',
            fontsize=12,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()

    return fig, pd.DataFrame({
        'Approach': [friendly_names[a] for a in approaches],
        'Count': counts,
        'Percentage': percentages
    })

def main():
    """
    Run the approach analysis for every configured experiment and write the results.

    For each experiment, this analyses the JSON result directory against its reference
    CSV. It saves a per-experiment bar chart (PNG) and an approach summary CSV, plus a
    molecule-level CSV when any molecules were analysed. It then writes the combined
    comparison figure and a markdown figure description. All outputs go under ./output.
    """
    # Define the experiments
    experiments = [
        {
            "label": "Simulated Data",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        {
            "label": "Simulated Data with Wrong Guess",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        }
    ]

    # Define output directory for figures and data
    output_dir = "./output"
    # exist_ok makes this idempotent and avoids the exists()/makedirs() race.
    os.makedirs(output_dir, exist_ok=True)

    # Per-experiment counts, collected for the combined comparison plot
    experiments_data = []

    for experiment in experiments:
        label = experiment['label']
        print(f"\nAnalyzing {label}...")

        approach_counts, molecule_data = analyze_directory_by_approach(
            experiment["json_directory"],
            experiment["reference_csv"]
        )

        fig, summary_df = plot_approach_comparison(approach_counts, label)

        # Filename-safe label, built outside the f-strings: a backslash inside an
        # f-string replacement field is a SyntaxError before Python 3.12. '\n' is a
        # real newline here, so multi-line labels are flattened as intended.
        file_label = label.replace(' ', '_').replace('\n', '_')

        output_path = os.path.join(output_dir, f"approach_comparison_{file_label}.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

        summary_path = os.path.join(output_dir, f"approach_summary_{file_label}.csv")
        summary_df.to_csv(summary_path, index=False)
        print(f"Saved summary to: {summary_path}")

        if molecule_data:
            molecule_df = pd.DataFrame(molecule_data)
            molecule_path = os.path.join(output_dir, f"molecule_data_{file_label}.csv")
            molecule_df.to_csv(molecule_path, index=False)
            print(f"Saved molecule data to: {molecule_path}")

        experiments_data.append({
            'label': label,
            'counts': approach_counts
        })

        # Print per-approach counts ('any' and 'total' are aggregates, reported separately)
        total = approach_counts['total']
        print("\nApproach counts:")
        for approach, count in approach_counts.items():
            if approach not in ['total', 'any']:
                percentage = count / total * 100 if total > 0 else 0
                print(f"  {approach}: {count} ({percentage:.1f}%)")
        print(f"  Total molecules: {total}")

    create_enhanced_comparison_plot(experiments_data, output_dir)

    # Markdown description accompanying the combined figure.
    # NOTE(review): "Total molecules in both datasets: 34" is hard-coded; confirm it
    # matches the 'total' counts printed above.
    figure_description = """
# Figure X: Comparison of Structure Generation Approaches

Comparison of the three complementary structure generation approaches (Retrosynthesis-Forward in green, Mol2Mol Analogues in orange, and MMST-Driven Generation in purple) across two experimental conditions. Left: Performance with correct initial structure. Right: Performance with regioisomeric initial structure (wrong starting guess). Numbers indicate the count of correct molecules found by each approach, with percentages in parentheses. Total molecules in both datasets: 34.

## Key Findings

**With Correct Initial Structure:**
The Retrosynthesis-Forward approach demonstrates exceptional performance when provided with the correct starting structure. This is expected behavior as this method effectively explores the local chemical space around the target molecule through plausible synthetic transformations.

**With Incorrect Initial Structure (Regioisomeric Guess):**
The performance pattern changes dramatically when an incorrect regioisomeric structure serves as the initial guess. The Retrosynthesis-Forward approach drops significantly in effectiveness, as it cannot easily reconstruct the correct connectivity patterns from the wrong starting point.

In contrast, the Mol2Mol approach demonstrates remarkable resilience to incorrect starting structures, maintaining high performance even with regioisomeric initial guesses. This highlights Mol2Mol's strength in structural modification—it efficiently explores diverse molecular variations while preserving core scaffold features, enabling discovery of the correct structure despite the initial error.

The MMST-Driven Generation approach shows consistent performance across both scenarios. Its data-driven nature, enhanced by the improvement cycle with on-the-fly fine-tuning, allows it to generate candidates based primarily on spectral patterns rather than relying heavily on the initial structure.
"""

    # Explicit UTF-8: the text contains non-ASCII characters (em dash).
    description_path = os.path.join(output_dir, "figure_description.md")
    with open(description_path, 'w', encoding='utf-8') as f:
        f.write(figure_description)
    print(f"Saved figure description to: {description_path}")

# Entry point: run the main() defined above in this cell when executed as the
# top-level module (a cell run in a Jupyter kernel also has __name__ == "__main__").
if __name__ == "__main__":
    main()

### V2 ALL: full analysis with approach-overlap (stacked) chart and summary tables (not used)

In [None]:
import json
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import glob
import os

def load_reference_data(csv_path):
    """Build a ``{sample-id: SMILES}`` lookup from the reference CSV at ``csv_path``.

    The CSV must contain ``sample-id`` and ``SMILES`` columns. Returning a plain
    dict keeps per-sample lookups O(1).
    """
    smiles_series = pd.read_csv(csv_path).set_index('sample-id')['SMILES']
    return smiles_series.to_dict()

def get_base_sample_id(sample_id):
    """Strip everything from the first underscore onward ('' for None/empty input)."""
    return sample_id.partition('_')[0] if sample_id else ''

def analyze_approaches_by_json(json_data, true_smiles):
    """
    Flag, per generation approach, whether one result JSON contains the true molecule.

    Args:
        json_data: Parsed JSON dict with candidates under
            ``["molecule_data"]["candidate_analysis"][<approach>]["molecules"]``.
        true_smiles: Reference SMILES, matched by exact string equality.

    Returns:
        dict of booleans keyed 'forward_synthesis', 'mol2mol', 'mmst', 'any',
        or None if candidate_analysis is absent.
    """
    try:
        candidates_by_approach = json_data["molecule_data"]['candidate_analysis']
    except KeyError:
        return None

    outcome = {
        'forward_synthesis': False,
        'mol2mol': False,
        'mmst': False,
        'any': False  # Was the molecule found in any approach?
    }

    for name in ('forward_synthesis', 'mol2mol', 'mmst'):
        if name not in candidates_by_approach:
            continue
        for entry in candidates_by_approach[name].get('molecules', []):
            try:
                matched = entry['smiles'] == true_smiles
            except KeyError:
                # Entries without a 'smiles' field are skipped.
                continue
            if matched:
                outcome[name] = outcome['any'] = True
                break

    return outcome

def analyze_directory_by_approach(json_dir, reference_csv):
    """
    Analyze all JSON files in ``json_dir`` and record where the correct molecule was found.

    Args:
        json_dir: Directory containing ``*.json`` result files.
        reference_csv: CSV with ``sample-id`` and ``SMILES`` columns (see load_reference_data).

    Returns:
        (approach_counts, molecule_data, retro_molecules, mol2mol_molecules, mmst_molecules):
            approach_counts -- hit counts per approach plus 'any' and 'total'.
            molecule_data -- per-sample dicts, one per analysed file (sorted file order).
            retro_molecules / mol2mol_molecules / mmst_molecules -- sample_ids in which
                forward_synthesis / mol2mol / mmst found the true molecule, used for the
                overlap (stacked) chart.
    """
    reference_data = load_reference_data(reference_csv)

    # glob returns files in arbitrary filesystem order; sort for reproducible output.
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files to analyze")

    approach_counts = {
        'forward_synthesis': 0,
        'mol2mol': 0,
        'mmst': 0,
        'any': 0,
        'total': 0
    }

    molecule_data = []

    # sample_ids per approach for the stacked bar chart ('any' is not tracked here)
    hits_by_approach = {
        'forward_synthesis': [],
        'mol2mol': [],
        'mmst': [],
    }

    for file_path in json_files:
        try:
            # JSON is UTF-8 by specification; don't depend on the platform default encoding.
            with open(file_path, 'r', encoding='utf-8') as f:
                json_data = json.load(f)

            # `or {}` also tolerates an explicit "molecule_data": null
            molecule_data_json = json_data.get('molecule_data') or {}
            sample_id = molecule_data_json.get('sample_id')

            if not sample_id:
                continue

            # Reference CSV is keyed by the base ID (part before the first underscore)
            base_sample_id = get_base_sample_id(sample_id)

            true_smiles = reference_data.get(base_sample_id)
            if true_smiles is None:
                print(f"No reference SMILES found for {base_sample_id}")
                continue

            results = analyze_approaches_by_json(json_data, true_smiles)
            if results is None:
                continue

            approach_counts['total'] += 1
            for approach, found in results.items():
                if found:
                    approach_counts[approach] += 1
                    if approach in hits_by_approach:
                        hits_by_approach[approach].append(sample_id)

            molecule_data.append({
                'sample_id': sample_id,
                'true_smiles': true_smiles,
                'found_in_forward_synthesis': results['forward_synthesis'],
                'found_in_mol2mol': results['mol2mol'],
                'found_in_mmst': results['mmst'],
                'found_anywhere': results['any']
            })

        except Exception as e:
            # Best-effort batch: report and skip unreadable/malformed files instead of aborting.
            print(f"Error processing {file_path}: {str(e)}")
            continue

    return (
        approach_counts,
        molecule_data,
        hits_by_approach['forward_synthesis'],
        hits_by_approach['mol2mol'],
        hits_by_approach['mmst'],
    )

def plot_approach_comparison(approach_counts, experiment_label=""):
    """
    Create a bar chart of how many correct molecules each approach found in one experiment.

    Args:
        approach_counts: dict with 'forward_synthesis', 'mol2mol', 'mmst' hit counts
            and 'total' molecules analysed.
        experiment_label: Text appended as a second title line.

    Returns:
        (fig, summary_df): the matplotlib Figure, and a DataFrame with columns
        Approach / Count / Percentage (percentage of 'total', 0 when total is 0).
    """
    fig, ax = plt.subplots(figsize=(10, 6), facecolor='white')
    ax.set_facecolor('white')

    approaches = ['forward_synthesis', 'mol2mol', 'mmst']
    friendly_names = {
        'forward_synthesis': 'Retrosynthesis',
        'mol2mol': 'Mol2Mol',
        'mmst': 'MMST'
    }

    total = approach_counts['total']
    percentages = [approach_counts[a] / total * 100 if total > 0 else 0 for a in approaches]
    counts = [approach_counts[a] for a in approaches]

    colors = {'forward_synthesis': '#6366F1', 'mol2mol': '#3B82F6', 'mmst': '#10B981'}

    bars = ax.bar(
        [friendly_names[a] for a in approaches],
        counts,
        color=[colors[a] for a in approaches],
        edgecolor='black',
        alpha=0.7
    )

    title = f'Correct Molecules Found by Approach\n{experiment_label}'
    ax.set_title(title, fontsize=16)
    ax.set_xlabel('Approach', fontsize=14)
    ax.set_ylabel('Number of Molecules', fontsize=14)

    # Extend the y-axis slightly above the tallest bar. When no approach found anything,
    # fall back to 1 so the limits are not the singular (0, 0).
    max_count = max(counts) or 1
    ax.set_ylim(0, max_count * 1.1)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')

    # Count and percentage just above each bar
    for bar, count, percentage in zip(bars, counts, percentages):
        height = bar.get_height()
        ax.text(
            bar.get_x() + bar.get_width()/2,
            height + 0.1,
            f'{count} ({percentage:.1f}%)',
            ha='center',
            va='bottom',
            fontsize=12
        )

    ax.tick_params(axis='both', labelsize=12)

    # Total molecules analysed, in the upper-left corner (axes coordinates)
    ax.text(0.02, 0.98, f'Total molecules: {total}',
            transform=ax.transAxes,
            verticalalignment='top',
            horizontalalignment='left',
            fontsize=12,
            bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    plt.tight_layout()

    return fig, pd.DataFrame({
        'Approach': [friendly_names[a] for a in approaches],
        'Count': counts,
        'Percentage': percentages
    })

def compare_approaches_plot(experiments_data, output_dir):
    """
    Create a grouped bar chart comparing approach effectiveness across experiments
    and save it as approach_comparison_across_experiments.png.

    Args:
        experiments_data: List of dicts, each with 'label' (x-tick text) and 'counts'
            (dict with 'forward_synthesis', 'mol2mol', 'mmst' hit counts and 'total').
        output_dir: Directory where the PNG is written.

    Returns:
        The matplotlib Figure.
    """
    fig, ax = plt.subplots(figsize=(14, 8), facecolor='white')
    ax.set_facecolor('white')

    approaches = ['forward_synthesis', 'mol2mol', 'mmst']
    friendly_names = {
        'forward_synthesis': 'Retrosynthesis',
        'mol2mol': 'Mol2Mol',
        'mmst': 'MMST'
    }

    bar_width = 0.25

    # One group per experiment at x = 0, 1, 2, ...; approaches are offset within a group.
    group_starts = np.arange(len(experiments_data))
    positions = {approach: group_starts + i * bar_width for i, approach in enumerate(approaches)}

    # Same colours as plot_approach_comparison
    colors = {'forward_synthesis': '#6366F1', 'mol2mol': '#3B82F6', 'mmst': '#10B981'}

    # The tallest single bar drives the y-limit and label offsets. Clamp it to 1 so an
    # all-zero (or empty) dataset still gets a usable axis instead of a singular ylim(0, 0).
    max_count = max(
        (exp['counts'][approach] for exp in experiments_data for approach in approaches),
        default=0,
    ) or 1

    for approach in approaches:
        counts = [exp['counts'][approach] for exp in experiments_data]
        percentages = [
            exp['counts'][approach] / exp['counts']['total'] * 100 if exp['counts']['total'] > 0 else 0
            for exp in experiments_data
        ]

        bar_container = ax.bar(
            positions[approach],
            counts,
            width=bar_width,
            color=colors[approach],
            edgecolor='black',
            alpha=0.8,
            label=friendly_names[approach]
        )

        # Label non-zero bars only. Short bars (< 15% of the tallest) get a black label
        # above; taller bars get a white label centred inside.
        for bar, count, percentage in zip(bar_container, counts, percentages):
            height = bar.get_height()
            if height <= 0:
                continue
            if height < max_count * 0.15:
                y_pos, color, va = height + max_count * 0.02, 'black', 'bottom'
            else:
                y_pos, color, va = height / 2, 'white', 'center'

            ax.text(
                bar.get_x() + bar.get_width() / 2,
                y_pos,
                f'{count}\n({percentage:.1f}%)',
                ha='center',
                va=va,
                fontsize=10,
                color=color,
                fontweight='bold'
            )

    ax.set_title('Comparison of Approach Effectiveness Across Experiments', fontsize=16)
    ax.set_ylabel('Number of Molecules', fontsize=14)

    # Centre each x tick on its experiment group
    middle_positions = [(positions['forward_synthesis'][i] + positions['mmst'][i]) / 2 for i in range(len(experiments_data))]
    ax.set_xticks(middle_positions)
    ax.set_xticklabels([exp['label'] for exp in experiments_data], fontsize=12)

    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')
    ax.legend(fontsize=12)

    # Total molecules per experiment, drawn just below the x-axis (text is not clipped)
    for middle, exp in zip(middle_positions, experiments_data):
        ax.text(
            middle,
            -max_count * 0.05,
            f'Total: {exp["counts"]["total"]}',
            ha='center',
            va='top',
            fontsize=10
        )

    # Headroom above the tallest bar
    ax.set_ylim(0, max_count * 1.15)

    plt.tight_layout()

    output_path = os.path.join(output_dir, "approach_comparison_across_experiments.png")
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved comparison figure to: {output_path}")

    return fig

def create_stacked_comparison(experiments_data, output_dir):
    """Create a stacked bar chart showing the distribution of molecules across approaches.

    Each experiment becomes one bar. The bar is split into mutually exclusive segments
    (found by all three approaches, by each pair only, by each single approach only,
    or not found at all), computed from per-approach sets of sample ids.

    Args:
        experiments_data: List of dicts with 'label', 'counts' (must contain 'total') and
            optionally 'retro_molecules', 'mol2mol_molecules', 'mmst_molecules'. These
            are presumably the sample_id lists returned by the V2
            analyze_directory_by_approach (TODO confirm against the caller). Missing
            lists are treated as empty, so every molecule counts as "Not Found".
        output_dir: Directory where approach_stacked_comparison.png is written.

    Returns:
        The matplotlib Figure.
    """
    
    # Create figure with white background
    fig, ax = plt.subplots(figsize=(14, 8), facecolor='white')
    ax.set_facecolor('white')
    
    # Data for the x-axis (experiment labels)
    labels = [exp['label'] for exp in experiments_data]
    
    # Set up colors (the keys double as the segment names used below)
    colors = {
        'All Three': '#4B0082',  # Deep purple for molecules found by all approaches
        'Retro+Mol2Mol': '#6A5ACD',  # Slate blue for retro + mol2mol only
        'Retro+MMST': '#9370DB',  # Medium purple for retro + mmst only
        'Mol2Mol+MMST': '#8A2BE2',  # Blue violet for mol2mol + mmst only
        'Retrosynthesis Only': '#6366F1',  # Same as in previous plots
        'Mol2Mol Only': '#3B82F6',  # Same as in previous plots
        'MMST Only': '#10B981',  # Same as in previous plots
        'Not Found': '#D3D3D3'  # Light gray for molecules not found by any method
    }
    
    # Calculate counts for each experiment (one dict of segment -> count per experiment)
    data = []
    
    for exp in experiments_data:
        counts = exp['counts']
        total = counts['total']
        
        # Skip if no data (all segments zero so the bar has no height)
        if total == 0:
            data.append({k: 0 for k in colors.keys()})
            continue
        
        # Convert molecules data to a list of sets for each approach
        retro_set = set(exp.get('retro_molecules', []))
        mol2mol_set = set(exp.get('mol2mol_molecules', []))
        mmst_set = set(exp.get('mmst_molecules', []))
        
        # Calculate molecules in each segment; the set differences make the
        # segments mutually exclusive (each id is counted in exactly one)
        all_three = len(retro_set & mol2mol_set & mmst_set)
        retro_mol2mol = len((retro_set & mol2mol_set) - mmst_set)
        retro_mmst = len((retro_set & mmst_set) - mol2mol_set)
        mol2mol_mmst = len((mol2mol_set & mmst_set) - retro_set)
        retro_only = len(retro_set - mol2mol_set - mmst_set)
        mol2mol_only = len(mol2mol_set - retro_set - mmst_set)
        mmst_only = len(mmst_set - retro_set - mol2mol_set)
        
        # Calculate not found: analysed molecules minus those found by any approach
        found = len(retro_set | mol2mol_set | mmst_set)
        not_found = total - found
        
        # Store data
        data.append({
            'All Three': all_three,
            'Retro+Mol2Mol': retro_mol2mol,
            'Retro+MMST': retro_mmst,
            'Mol2Mol+MMST': mol2mol_mmst,
            'Retrosynthesis Only': retro_only,
            'Mol2Mol Only': mol2mol_only,
            'MMST Only': mmst_only,
            'Not Found': not_found
        })
    
    # Get the order of segments (bottom to top)
    segments = ['Not Found', 'MMST Only', 'Mol2Mol Only', 'Retrosynthesis Only', 
                'Mol2Mol+MMST', 'Retro+MMST', 'Retro+Mol2Mol', 'All Three']
    
    # Calculate positions for the bars
    x = np.arange(len(labels))
    width = 0.6
    
    # Create the stacked bars; `bottom` holds the running top of each stack
    bottom = np.zeros(len(labels))
    
    for segment in segments:
        values = [d[segment] for d in data]
        percentages = [values[i] / experiments_data[i]['counts']['total'] * 100 
                      if experiments_data[i]['counts']['total'] > 0 else 0 
                      for i in range(len(values))]
        
        rects = ax.bar(x, values, width, label=segment, bottom=bottom, 
                      color=colors[segment], edgecolor='black')
        
        # Add labels inside the bars for non-zero segments
        for i, rect in enumerate(rects):
            height = rect.get_height()
            if height > 0:
                # Calculate the position for the text (vertical centre of the segment)
                value = values[i]
                percentage = percentages[i]
                y_pos = bottom[i] + height/2
                
                # Use black text for light colors, white for dark colors
                if segment == 'Not Found':
                    text_color = 'black'
                else:
                    text_color = 'white'
                
                ax.text(rect.get_x() + rect.get_width()/2, y_pos,
                       f'{value}\n({percentage:.1f}%)',
                       ha='center', va='center', color=text_color, fontweight='bold')
        
        # Advance the stack only after this segment's labels have been placed
        bottom += values
    
    # Configure the plot
    ax.set_title('Distribution of Molecules Across Approaches', fontsize=16)
    ax.set_ylabel('Number of Molecules', fontsize=14)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, fontsize=12)
    
    # Add a legend at the bottom of the plot
    ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.15), 
             fancybox=True, shadow=True, ncol=4, fontsize=12)
    
    # Add grid lines
    ax.grid(True, alpha=0.3, color='gray', linestyle='--', axis='y')
    
    # Adjust layout to make room for the legend
    fig.tight_layout(rect=[0, 0.1, 1, 0.95])
    
    # Save the plot
    output_path = os.path.join(output_dir, "approach_stacked_comparison.png")
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    print(f"Saved stacked comparison figure to: {output_path}")
    
    return fig

def generate_summary_tables(experiments_data, output_dir, approaches=('forward_synthesis', 'mol2mol', 'mmst')):
    """Generate summary tables comparing approaches across experiments.

    Args:
        experiments_data: List of dicts, each with a ``'label'`` (may contain newlines,
            which are replaced by spaces in the tables) and a ``'counts'`` dict holding
            ``'total'`` plus one count per entry of ``approaches``.
        output_dir: Existing directory the CSV files are written into.
        approaches: Keys of ``counts`` to report, in display order.

    Returns:
        ``(comparison_df, pivot_df)``:
        - ``comparison_df``: long format, one row per (experiment, approach) with columns
          Experiment, Approach, Count, Percentage, Total. Experiments whose total is not
          positive are skipped.
        - ``pivot_df``: Count/Percentage per approach (MultiIndex columns) indexed by
          Experiment, plus a ``('Total', '')`` column. When no experiment has molecules
          it is an empty DataFrame and the pivot CSV is not written.

    Side effects:
        Writes ``approach_comparison_summary.csv`` (always) and
        ``approach_comparison_pivot.csv`` (when there is data) into ``output_dir``, and
        prints a plain-text summary.
    """
    columns = ['Experiment', 'Approach', 'Count', 'Percentage', 'Total']

    # Prepare data for the comparison table
    comparison_data = []
    for exp in experiments_data:
        total = exp['counts']['total']
        if total <= 0:
            continue  # nothing to compare; also avoids dividing by zero below

        experiment_name = exp['label'].replace('\n', ' ')
        for approach in approaches:
            count = exp['counts'][approach]
            comparison_data.append({
                'Experiment': experiment_name,
                'Approach': approach,
                'Count': count,
                'Percentage': count / total * 100,
                'Total': total,
            })

    # Explicit columns keep the CSV header (and later column lookups) valid even when
    # every experiment was skipped.
    comparison_df = pd.DataFrame(comparison_data, columns=columns)

    # Save the comparison table
    comparison_path = os.path.join(output_dir, "approach_comparison_summary.csv")
    comparison_df.to_csv(comparison_path, index=False)
    print(f"Saved comparison summary to: {comparison_path}")

    if comparison_df.empty:
        # No rows to pivot; return an empty pivot instead of failing on the pivot step.
        print("No experiments with molecules; skipping pivot table.")
        pivot_df = pd.DataFrame()
    else:
        # Create a pivot table for easier comparison
        pivot_df = comparison_df.pivot_table(
            index='Experiment',
            columns='Approach',
            values=['Count', 'Percentage'],
            aggfunc='sum'
        )

        # Add a total column; it aligns with the pivot on the shared 'Experiment' index
        pivot_df[('Total', '')] = comparison_df.groupby('Experiment')['Total'].first()

        # Save the pivot table
        pivot_path = os.path.join(output_dir, "approach_comparison_pivot.csv")
        pivot_df.to_csv(pivot_path)
        print(f"Saved pivot table summary to: {pivot_path}")

    # Print a summary table
    print("\n===== APPROACH EFFECTIVENESS SUMMARY =====")
    print("\nExperiment | Approach | Count | Percentage | Total")
    print("-" * 60)

    for exp in experiments_data:
        total = exp['counts']['total']
        if total <= 0:
            continue

        print(f"\n{exp['label'].replace(chr(10), ' ')}:")

        for approach in approaches:
            count = exp['counts'][approach]
            percentage = count / total * 100
            print(f"  {approach}: {count} ({percentage:.1f}%) of {total}")

    return comparison_df, pivot_df

def main():
    """Run the approach analysis for every configured experiment and write all outputs.

    For each experiment this calls ``analyze_directory_by_approach`` on its JSON directory
    and reference CSV. It then saves, into ``./output``:
    - a per-experiment figure (from ``plot_approach_comparison``);
    - a per-experiment summary CSV;
    - a molecule-level CSV (when molecule data exists).

    It then builds the cross-experiment comparison plot, the summary tables and the
    stacked comparison chart.
    """
    # Define the experiments
    # NOTE(review): every experiment, including the experimental-data runs, uses the same
    # simulated-data reference CSV -- confirm this is intended.
    experiments = [
        {
            "label": "Simulated Data",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        {
            "label": "Simulated Data\nwith Wrong Guess",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_2_sim_aug_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        {
            "label": "Simulated Data\nwith Noise",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_3_sim+noise_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        {
            "label": "Experimental Data",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_4_exp_d1_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        },
        {
            "label": "Experimental Data\nwith Wrong Guess",
            "json_directory": "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_5_exp_d1_aug_finished_clean",
            "reference_csv": "/projects/cc/se_users/knlr326/1_NMR_project/1_NMR_data_AZ/52_project_3.3_data/sim_data/combined_sim_nmr_data_no_stereo.csv"
        }
    ]

    # Define output directory for figures and data (created if missing)
    output_dir = "./output"
    os.makedirs(output_dir, exist_ok=True)

    # Store results for each experiment
    experiments_data = []

    # Process each experiment
    for experiment in experiments:
        label = experiment['label']
        print(f"\nAnalyzing {label}...")

        # Labels contain real newline characters (used for multi-line tick labels).
        # Build a single-line, underscore-separated version for filenames. This is done
        # outside the f-strings because a backslash inside an f-string expression is a
        # SyntaxError before Python 3.12, and '\\n' would only match a literal backslash-n.
        file_label = label.replace(' ', '_').replace('\n', '_')

        # Get approach counts and molecule data
        approach_counts, molecule_data, retro_molecules, mol2mol_molecules, mmst_molecules = analyze_directory_by_approach(
            experiment["json_directory"],
            experiment["reference_csv"]
        )

        # Create and show individual plot
        fig, summary_df = plot_approach_comparison(approach_counts, label)

        # Save the plot
        output_path = os.path.join(output_dir, f"approach_comparison_{file_label}.png")
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved figure to: {output_path}")

        # Save the summary data
        summary_path = os.path.join(output_dir, f"approach_summary_{file_label}.csv")
        summary_df.to_csv(summary_path, index=False)
        print(f"Saved summary to: {summary_path}")

        # Save the molecule-level data
        if molecule_data:
            molecule_df = pd.DataFrame(molecule_data)
            molecule_path = os.path.join(output_dir, f"molecule_data_{file_label}.csv")
            molecule_df.to_csv(molecule_path, index=False)
            print(f"Saved molecule data to: {molecule_path}")

        # Store data for comparison plot
        experiments_data.append({
            'label': label,
            'counts': approach_counts,
            'retro_molecules': retro_molecules,
            'mol2mol_molecules': mol2mol_molecules,
            'mmst_molecules': mmst_molecules
        })

        # Print overall counts ('total' and 'any' are aggregates, not approaches)
        print("\nApproach counts:")
        for approach, count in approach_counts.items():
            if approach not in ['total', 'any']:
                percentage = count / approach_counts['total'] * 100 if approach_counts['total'] > 0 else 0
                print(f"  {approach}: {count} ({percentage:.1f}%)")
        print(f"  Total molecules: {approach_counts['total']}")

    # Create comparison plot across experiments
    comparison_fig = compare_approaches_plot(experiments_data, output_dir)

    # Generate summary tables
    comparison_df, pivot_df = generate_summary_tables(experiments_data, output_dir)

    # Create a stacked bar chart to visualize overlaps
    create_stacked_comparison(experiments_data, output_dir)

if __name__ == "__main__":
    main()

## Plot molecules for each approach

### AZ10282497

In [None]:
import json
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
import numpy as np
import os

# Intermediate-results JSON for sample AZ10282497 from the first simulated-data run
# (directory of the run + the per-sample file name).
json_filepath = os.path.join(
    "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished",
    "AZ10282497_intermediate.json",
)

# Function to safely get molecules and deduplicate them
def get_unique_molecules(smiles_list):
    """Parse SMILES and keep one molecule per distinct canonical SMILES.

    Args:
        smiles_list: Iterable of SMILES strings; duplicates and invalid entries allowed.

    Returns:
        ``(unique_mols, unique_smiles)``: parallel lists in first-seen order, where
        ``unique_smiles[i]`` is the canonical SMILES of ``unique_mols[i]``.

    Entries are skipped when RDKit cannot parse them or when they parse to a molecule
    with no atoms (such as ``""``). Any other per-entry error is printed and the entry
    is skipped.
    """
    unique_smiles = []
    unique_mols = []
    seen = set()  # constant-time membership instead of scanning unique_smiles

    for smi in smiles_list:
        try:
            mol = Chem.MolFromSmiles(smi)
            # None means unparsable. An empty string parses to a zero-atom molecule,
            # which would otherwise be drawn as a blank "analogue" panel.
            if mol is None or mol.GetNumAtoms() == 0:
                continue
            # Canonicalize the SMILES so different spellings of one molecule collapse
            canon_smi = Chem.MolToSmiles(mol, canonical=True)
            if canon_smi not in seen:
                seen.add(canon_smi)
                unique_smiles.append(canon_smi)
                unique_mols.append(mol)
        except Exception as e:
            print(f"Error processing SMILES {smi}: {e}")

    return unique_mols, unique_smiles

# Sample visualized in this cell. visualize_molecules below also reads this global for
# its figure title and output filename.
target_sample_id = "AZ10282497"

# Load the JSON data for the specific sample. On any failure, report and re-raise so the
# cell stops here instead of continuing with an undefined (or stale, left over from
# another cell) `sample`. Raising also avoids exit(), which would end the whole session
# rather than just this cell.
try:
    with open(json_filepath, "r") as f:
        sample = json.load(f)
    print(f"Successfully loaded data for {target_sample_id}")
except Exception as e:
    print(f"Error loading JSON file: {e}")
    # Give a more specific hint when the file is simply missing
    if not os.path.exists(json_filepath):
        print(f"File does not exist: {json_filepath}")
    raise

# This file contains data for a single sample
print(f"Processing sample: {target_sample_id}")

# Extract and parse the target SMILES. Every visualization below draws target_mol, so
# stop with a clear message if RDKit cannot parse it.
target_smiles = sample["molecule_data"]["smiles"]
target_mol = Chem.MolFromSmiles(target_smiles)
print(f"Target SMILES: {target_smiles}")
if target_mol is None:
    raise ValueError(f"Could not parse target SMILES: {target_smiles}")

# Collect candidate SMILES per generation method. Each pair is
# (key inside candidate_analysis, wording used in the progress message).
candidate_smiles_by_method = {}
if "candidate_analysis" in sample["molecule_data"]:
    candidate_analysis = sample["molecule_data"]["candidate_analysis"]
    print("Available analysis methods:", list(candidate_analysis.keys()))

    for method_key, method_label in (
        ("forward_synthesis", "forward synthesis"),
        ("mol2mol", "mol2mol"),
        ("mmst", "MMST"),
    ):
        if method_key in candidate_analysis and "molecules" in candidate_analysis[method_key]:
            method_smiles = [
                entry["smiles"]
                for entry in candidate_analysis[method_key]["molecules"]
                if "smiles" in entry
            ]
            candidate_smiles_by_method[method_key] = method_smiles
            print(f"Found {len(method_smiles)} {method_label} molecules")
else:
    print("No candidate_analysis found in the data")

forward_synthesis_molecules = candidate_smiles_by_method.get("forward_synthesis", [])
mol2mol_molecules = candidate_smiles_by_method.get("mol2mol", [])
mmst_molecules = candidate_smiles_by_method.get("mmst", [])

# Get unique molecules for each method
unique_forward_mols, unique_forward_smiles = get_unique_molecules(forward_synthesis_molecules)
unique_mol2mol_mols, unique_mol2mol_smiles = get_unique_molecules(mol2mol_molecules)
unique_mmst_mols, unique_mmst_smiles = get_unique_molecules(mmst_molecules)

print(f"Unique forward synthesis molecules: {len(unique_forward_mols)}")
print(f"Unique mol2mol molecules: {len(unique_mol2mol_mols)}")
print(f"Unique MMST molecules: {len(unique_mmst_mols)}")

# Define a function to create and display the visualization for each method
def visualize_molecules(method_name, target_mol, target_smiles, unique_mols, unique_smiles, sample_id=None):
    """Draw the target molecule above a grid of one method's analogues; save and show it.

    Args:
        method_name: Display name of the generation method. It is also used in the output
            filename, with spaces replaced by "_".
        target_mol: RDKit molecule of the ground-truth structure. 2D coordinates are
            computed in place if it has no conformer.
        target_smiles: Target SMILES, shown in the top panel title.
        unique_mols: Molecules to draw, one panel each, at most 4 per row.
        unique_smiles: SMILES paired with ``unique_mols`` (panels stop at the shorter of
            the two). Titles shorten them to at most 40 characters.
        sample_id: Sample identifier for the figure title and filename. When None, the
            notebook-global ``target_sample_id`` is used.

    Returns:
        None. Writes ``{sample_id}_{method}_analogues.png`` to the working directory and
        shows the figure. If ``unique_mols`` is empty, it only prints a message and returns.
    """
    if not unique_mols:
        print(f"No unique molecules found for {method_name}")
        return

    if sample_id is None:
        sample_id = target_sample_id

    # Grid: one full-width row for the target, then up to 4 analogues per row
    n_molecules = len(unique_mols)
    n_cols = min(4, n_molecules)
    n_rows = (n_molecules + n_cols - 1) // n_cols + 1  # +1 for target molecule row

    fig = plt.figure(figsize=(5 * n_cols, 4 * n_rows))
    gs = gridspec.GridSpec(n_rows, n_cols)

    # Plot target molecule spanning all columns
    ax_target = fig.add_subplot(gs[0, :])
    if not target_mol.GetNumConformers():
        AllChem.Compute2DCoords(target_mol)
    ax_target.imshow(Draw.MolToImage(target_mol, size=(400, 300)))
    ax_target.set_title(f"Target: {target_smiles}", fontsize=12)
    ax_target.axis("off")

    # Plot unique molecules
    for i, (mol, smi) in enumerate(zip(unique_mols, unique_smiles)):
        ax = fig.add_subplot(gs[1 + i // n_cols, i % n_cols])
        # Generate 2D coordinates if they don't exist
        if not mol.GetNumConformers():
            AllChem.Compute2DCoords(mol)
        mol_img = Draw.MolToImage(mol, size=(400, 300))

        # Truncate long SMILES for display
        display_smi = smi if len(smi) < 40 else smi[:37] + "..."
        ax.set_title(f"Analogue {i+1}:\n{display_smi}", fontsize=10)
        ax.imshow(mol_img)
        ax.axis("off")

    fig.suptitle(f"{method_name} Analogues for Sample {sample_id}", fontsize=16)
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # Save the figure
    output_filename = f"{sample_id}_{method_name.replace(' ', '_')}_analogues.png"
    fig.savefig(output_filename, dpi=300, bbox_inches='tight')
    print(f"Saved figure to {output_filename}")

    # Show the plot
    plt.show()

# Render one analogue figure per generation method, then report completion.
print("\nGenerating visualizations...")
for approach_name, approach_mols, approach_smiles in (
    ("Forward Synthesis", unique_forward_mols, unique_forward_smiles),
    ("Mol2Mol", unique_mol2mol_mols, unique_mol2mol_smiles),
    ("MMST", unique_mmst_mols, unique_mmst_smiles),
):
    visualize_molecules(approach_name, target_mol, target_smiles, approach_mols, approach_smiles)

print("Visualization complete!")

### AZ11034953

In [None]:
import json
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
import os
import glob

# Path to the directory containing the simulated data
base_dir = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/_run_1_sim_finished/"

# Find the target file for AZ11034953: try the conventional name first, then any file
# whose name starts with the sample ID.
target_file = os.path.join(base_dir, "AZ11034953_intermediate.json")
if not os.path.exists(target_file):
    # Sorted so the pick is deterministic when several files match (glob order is
    # filesystem-dependent).
    file_pattern = os.path.join(base_dir, "AZ11034953*.json")
    matching_files = sorted(glob.glob(file_pattern))
    if matching_files:
        target_file = matching_files[0]
        print(f"Found file: {target_file}")
    else:
        print("No file found for AZ11034953. Checking for any available files...")
        # List a few files in the directory to help troubleshoot
        all_files = sorted(glob.glob(os.path.join(base_dir, "*.json")))
        if all_files:
            print("Available files (first 5):")
            for i, file_path in enumerate(all_files[:5]):
                print(f"  {i+1}. {os.path.basename(file_path)}")
        else:
            print(f"No JSON files found in {base_dir}")
        # Stop this cell with an error rather than calling exit(), which would end the
        # whole interpreter session instead of just this cell.
        raise FileNotFoundError(f"No JSON file found for AZ11034953 in {base_dir}")

# Function to get unique molecules from SMILES list
def get_unique_molecules(smiles_list):
    """Parse SMILES and keep one molecule per distinct canonical SMILES.

    Args:
        smiles_list: Iterable of SMILES strings; duplicates and invalid entries allowed.

    Returns:
        ``(unique_mols, unique_smiles)``: parallel lists in first-seen order, where
        ``unique_smiles[i]`` is the canonical SMILES of ``unique_mols[i]``.

    Entries are skipped when RDKit cannot parse them or when they parse to a molecule
    with no atoms (such as ``""``). Any other per-entry error is printed and the entry
    is skipped.
    """
    unique_smiles = []
    unique_mols = []
    seen = set()  # constant-time membership instead of scanning unique_smiles

    for smi in smiles_list:
        try:
            mol = Chem.MolFromSmiles(smi)
            # None means unparsable. An empty string parses to a zero-atom molecule,
            # which would otherwise be drawn as a blank "analogue" panel.
            if mol is None or mol.GetNumAtoms() == 0:
                continue
            # Canonicalize the SMILES so different spellings of one molecule collapse
            canon_smi = Chem.MolToSmiles(mol, canonical=True)
            if canon_smi not in seen:
                seen.add(canon_smi)
                unique_smiles.append(canon_smi)
                unique_mols.append(mol)
        except Exception as e:
            print(f"Error processing SMILES {smi}: {e}")

    return unique_mols, unique_smiles

# Load the JSON data. On failure, report and re-raise so the cell stops here without
# calling exit(), which would end the whole interpreter session.
try:
    with open(target_file, 'r') as f:
        sample = json.load(f)
    print("Successfully loaded data for AZ11034953")
except Exception as e:
    print(f"Error loading JSON file: {e}")
    raise

# Extract and parse the target SMILES. Every visualization below draws target_mol, so
# stop with a clear message if RDKit cannot parse it.
target_smiles = sample["molecule_data"]["smiles"]
target_mol = Chem.MolFromSmiles(target_smiles)
print(f"Target SMILES: {target_smiles}")
if target_mol is None:
    raise ValueError(f"Could not parse target SMILES: {target_smiles}")

# Collect candidate SMILES per generation method. Each pair is
# (key inside candidate_analysis, wording used in the progress message).
candidate_smiles_by_method = {}
if "candidate_analysis" in sample["molecule_data"]:
    candidate_analysis = sample["molecule_data"]["candidate_analysis"]
    print("Available analysis methods:", list(candidate_analysis.keys()))

    for method_key, method_label in (
        ("forward_synthesis", "forward synthesis"),
        ("mol2mol", "mol2mol"),
        ("mmst", "MMST"),
    ):
        if method_key in candidate_analysis and "molecules" in candidate_analysis[method_key]:
            method_smiles = [
                entry["smiles"]
                for entry in candidate_analysis[method_key]["molecules"]
                if "smiles" in entry
            ]
            candidate_smiles_by_method[method_key] = method_smiles
            print(f"Found {len(method_smiles)} {method_label} molecules")
else:
    print("No candidate_analysis found")

forward_molecules = candidate_smiles_by_method.get("forward_synthesis", [])
mol2mol_molecules = candidate_smiles_by_method.get("mol2mol", [])
mmst_molecules = candidate_smiles_by_method.get("mmst", [])

# Get unique molecules
unique_forward_mols, unique_forward_smiles = get_unique_molecules(forward_molecules)
unique_mol2mol_mols, unique_mol2mol_smiles = get_unique_molecules(mol2mol_molecules)
unique_mmst_mols, unique_mmst_smiles = get_unique_molecules(mmst_molecules)

print(f"Unique forward synthesis molecules: {len(unique_forward_mols)}")
print(f"Unique mol2mol molecules: {len(unique_mol2mol_mols)}")
print(f"Unique MMST molecules: {len(unique_mmst_mols)}")

# Visualize molecules for each method
def visualize_molecules(method_name, target_mol, target_smiles, unique_mols, unique_smiles,
                        sample_id="AZ11034953", molecules_per_figure=30):
    """Draw the target above paged grids of one method's analogues; save and show each page.

    Molecules are split into pages of ``molecules_per_figure`` panels, at most 6 per row.
    Panel titles show only the running analogue number. A panel whose canonical SMILES
    equals the target's is tagged "(TARGET MATCH)".

    Args:
        method_name: Display name of the method. It is also used in filenames, with
            spaces replaced by "_".
        target_mol: RDKit molecule of the ground truth. 2D coordinates are computed in
            place if it has no conformer.
        target_smiles: Target SMILES, shown in the top panel title.
        unique_mols: Molecules to draw, one panel each.
        unique_smiles: SMILES paired with ``unique_mols`` (panels stop at the shorter of
            the two slices).
        sample_id: Sample identifier used in the figure title and filenames.
        molecules_per_figure: Maximum number of analogue panels per page. Must be >= 1.

    Returns:
        None. Writes one ``{sample_id}_{method}_analogues_page{k}.png`` per page to the
        working directory and shows it. If ``unique_mols`` is empty, it only prints a
        message and returns.

    Raises:
        ValueError: If ``molecules_per_figure`` is less than 1.
    """
    if not unique_mols:
        print(f"No unique molecules found for {method_name}")
        return
    if molecules_per_figure < 1:
        raise ValueError(f"molecules_per_figure must be >= 1, got {molecules_per_figure}")

    num_figures = (len(unique_mols) + molecules_per_figure - 1) // molecules_per_figure

    # The target panel image and its canonical SMILES are the same on every page and for
    # every analogue, so compute them once.
    if not target_mol.GetNumConformers():
        AllChem.Compute2DCoords(target_mol)
    target_img = Draw.MolToImage(target_mol, size=(300, 250))
    target_canonical = Chem.MolToSmiles(target_mol, canonical=True)

    for fig_idx in range(num_figures):
        start_idx = fig_idx * molecules_per_figure
        end_idx = min(start_idx + molecules_per_figure, len(unique_mols))

        display_mols = unique_mols[start_idx:end_idx]
        display_smiles = unique_smiles[start_idx:end_idx]

        # Grid: one full-width target row, then up to 6 analogues per row
        n_molecules = len(display_mols)
        n_cols = min(6, n_molecules)
        n_rows = (n_molecules + n_cols - 1) // n_cols + 1  # +1 for target molecule row

        fig = plt.figure(figsize=(4 * n_cols, 3 * n_rows))
        gs = gridspec.GridSpec(n_rows, n_cols)

        ax_target = fig.add_subplot(gs[0, :])
        ax_target.imshow(target_img)
        ax_target.set_title(f"Target: {target_smiles}", fontsize=10)
        ax_target.axis("off")

        for i, (mol, _smi) in enumerate(zip(display_mols, display_smiles)):
            ax = fig.add_subplot(gs[1 + i // n_cols, i % n_cols])
            # Generate 2D coordinates if they don't exist
            if not mol.GetNumConformers():
                AllChem.Compute2DCoords(mol)
            mol_img = Draw.MolToImage(mol, size=(300, 250))

            # Title without SMILES; flag analogues identical to the target
            title = f"Analogue {start_idx + i + 1}"
            if Chem.MolToSmiles(mol, canonical=True) == target_canonical:
                title += " (TARGET MATCH)"
            ax.set_title(title, fontsize=10)
            ax.imshow(mol_img)
            ax.axis("off")

        fig.suptitle(f"{method_name} Analogues for {sample_id} (Page {fig_idx+1}/{num_figures})", fontsize=16)
        fig.tight_layout(rect=[0, 0, 1, 0.96])

        # Save the figure
        output_filename = f"{sample_id}_{method_name.replace(' ', '_')}_analogues_page{fig_idx+1}.png"
        fig.savefig(output_filename, dpi=300, bbox_inches='tight')
        print(f"Saved figure to {output_filename}")

        # Show the plot
        plt.show()

# Also create a file with SMILES and indices for reference
def save_smiles_list(method_name, smiles_list):
    """Write the target SMILES and a numbered list of analogue SMILES to a text file.

    Reads two notebook globals:
    - ``target_smiles``: written as the header line.
    - ``target_mol``: its canonical SMILES marks analogues equal to the target with
      "(TARGET MATCH)".

    The comparison is an exact string match, so ``smiles_list`` should hold canonical
    SMILES, as returned by ``get_unique_molecules``.

    Output: ``AZ11034953_{method_name}_smiles.txt`` in the working directory, with spaces
    in ``method_name`` replaced by "_".
    """
    filename = f"AZ11034953_{method_name.replace(' ', '_')}_smiles.txt"
    # Canonicalize the target once instead of once per analogue.
    target_canonical = Chem.MolToSmiles(target_mol, canonical=True) if smiles_list else None
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(f"Target SMILES: {target_smiles}\n\n")
        for i, smi in enumerate(smiles_list, start=1):
            marker = ' (TARGET MATCH)' if smi == target_canonical else ''
            f.write(f"Analogue {i}{marker}: {smi}\n")
    print(f"Saved SMILES list to {filename}")

# (display name, unique molecules, unique canonical SMILES) for each generation method
method_outputs = (
    ("Forward Synthesis", unique_forward_mols, unique_forward_smiles),
    ("Mol2Mol", unique_mol2mol_mols, unique_mol2mol_smiles),
    ("MMST", unique_mmst_mols, unique_mmst_smiles),
)

# Draw every method's figures first...
print("\nGenerating visualizations...")
for approach_name, approach_mols, approach_smiles in method_outputs:
    visualize_molecules(approach_name, target_mol, target_smiles, approach_mols, approach_smiles)

# ...then write the SMILES reference lists
for approach_name, _, approach_smiles in method_outputs:
    save_smiles_list(approach_name, approach_smiles)

print("Visualization complete!")

# Test ACD data

In [None]:
# %% Import necessary libraries
import json
import pandas as pd
import numpy as np
from pathlib import Path
import glob
import os
from typing import Dict, List, Any, Optional, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import to_rgba

# --- Configuration ---
# Directory holding one intermediate-results JSON file per ZINC sample.
ZINC_DATA_DIR = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/ZINC_large"
# Analysis results go into a sub-folder of the data directory. It is created if missing;
# without parents=True, the data directory itself must already exist.
OUTPUT_DIR = Path(ZINC_DATA_DIR, "analysis_results_zinc")
OUTPUT_DIR.mkdir(exist_ok=True)

# LLM models whose responses are read from each JSON file (adjust if different for ZINC)
LLM_MODELS = ["claude", "claude3-7", "o3", "kimi", "gemini", "deepseek"]

# Plot colors: one per entry of LLM_MODELS, in the same order, plus the HSQC-only baseline.
LLM_COLORS = dict(zip(
    LLM_MODELS,
    [
        "#4169E1",  # claude    - Royal Blue
        "#1E90FF",  # claude3-7 - Dodger Blue
        "#2E8B57",  # o3        - Sea Green
        "#8B4513",  # kimi      - Saddle Brown
        "#4B0082",  # gemini    - Indigo
        "#CD853F",  # deepseek  - Peru
    ],
))
LLM_COLORS["HSQC Baseline"] = "#999999"  # Gray for baseline

# --- Helper Functions ---

def _hsqc_score_of(candidate: Any) -> Optional[float]:
    """Return a candidate's numeric HSQC matching score, or None if it is absent or malformed."""
    node = candidate
    # Walk nmr_analysis -> matching_scores -> by_spectrum -> HSQC, tolerating non-dict levels
    for key in ('nmr_analysis', 'matching_scores', 'by_spectrum', 'HSQC'):
        if not isinstance(node, dict):
            return None
        node = node.get(key)
    # Reject bools, strings and NaN (node != node): they cannot be ranked meaningfully
    if isinstance(node, bool) or not isinstance(node, (int, float)) or node != node:
        return None
    return node


def process_single_json_candidates(json_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Process a single JSON file's candidate analysis and combine all molecules.
    Extracts SMILES and HSQC scores.

    Candidates are read from ``molecule_data.candidate_analysis``, under the
    'forward_synthesis', 'mol2mol' and 'mmst' sections. Entries are skipped when they
    lack a SMILES or a numeric HSQC score, or when they are malformed (non-dict entries
    or missing/None nested sections). Skipping them keeps one bad record from discarding
    the whole sample.

    Args:
        json_data: Dictionary containing the JSON data with molecule_data.

    Returns:
        List of ``{'smiles', 'hsqc_score'}`` dicts, one per distinct SMILES string. For
        each SMILES, the lowest (best) score across all sections is kept. The list is
        sorted by score ascending; ties keep first-seen order. Returns ``[]`` when the
        candidate analysis is missing.
    """
    molecule_data = json_data.get("molecule_data") if isinstance(json_data, dict) else None
    candidate_analysis = molecule_data.get('candidate_analysis') if isinstance(molecule_data, dict) else None
    if not isinstance(candidate_analysis, dict):
        return []  # unexpected structure: nothing to rank

    # SMILES -> best entry; dict insertion order keeps first-seen order for stable ties
    best_by_smiles: Dict[str, Dict[str, Any]] = {}
    for analysis_type in ('forward_synthesis', 'mol2mol', 'mmst'):
        section = candidate_analysis.get(analysis_type)
        if not isinstance(section, dict):
            continue
        molecules = section.get('molecules', [])
        if not isinstance(molecules, list):
            continue

        for mol in molecules:
            hsqc_score = _hsqc_score_of(mol)
            # Only include molecules with a valid HSQC score and a SMILES string
            if hsqc_score is None or 'smiles' not in mol:
                continue
            smiles = mol['smiles']
            current = best_by_smiles.get(smiles)
            if current is None or hsqc_score < current['hsqc_score']:
                best_by_smiles[smiles] = {'smiles': smiles, 'hsqc_score': hsqc_score}

    return sorted(best_by_smiles.values(), key=lambda m: m['hsqc_score'])

def find_molecule_rank(sorted_molecules: List[Dict[str, Any]], true_smiles: str) -> Optional[int]:
    """
    Return the 1-based position of ``true_smiles`` in an HSQC-ordered candidate list.

    Matching is an exact string comparison on each candidate's ``'smiles'`` value, so
    both sides must use the same SMILES spelling; no canonicalization happens here.

    Args:
        sorted_molecules: Candidate dicts, best first.
        true_smiles: The ground truth SMILES string.

    Returns:
        The rank of the first exact match, or None if the SMILES does not occur.
    """
    matches = (
        rank
        for rank, candidate in enumerate(sorted_molecules, start=1)
        if candidate['smiles'] == true_smiles
    )
    return next(matches, None)

def analyze_llm_predictions(json_data: Dict, true_smiles: str, llm_name: str) -> Optional[Dict[str, Any]]:
    """
    Rank one LLM's candidate structures by confidence and locate the true SMILES.

    Reads ``analysis_results.final_analysis.llm_responses[llm_name].parsed_results``.
    That value may be a dict or a JSON string encoding one, and must contain a
    ``candidates`` list. Only dict candidates with a numeric ``confidence_score`` are
    ranked, highest first. The true SMILES is matched by exact string equality.

    Args:
        json_data: Loaded JSON data for a single sample.
        true_smiles: True SMILES string to compare against.
        llm_name: Name of the LLM model to analyze (e.g., 'claude', 'o3').

    Returns:
        Dict with the model name, the 1-based position of the true SMILES (or None),
        the number of ranked candidates and top-1/3/5 flags. Returns None when the
        response is missing or malformed, or has no scorable candidates. All errors
        are swallowed deliberately so one bad response does not abort the batch.
    """
    try:
        responses = (
            json_data.get("analysis_results", {})
            .get("final_analysis", {})
            .get("llm_responses", {})
        )
        parsed = responses.get(llm_name, {}).get("parsed_results", {})

        # Some runs store the parsed block as a JSON string; decode it in that case.
        if not isinstance(parsed, dict) or "candidates" not in parsed:
            if not isinstance(parsed, str):
                return None
            try:
                parsed = json.loads(parsed)
            except json.JSONDecodeError:
                return None
            if not isinstance(parsed, dict) or "candidates" not in parsed:
                return None

        candidates = parsed["candidates"]
        if not isinstance(candidates, list):
            return None

        scored = [
            cand for cand in candidates
            if isinstance(cand, dict)
            and "confidence_score" in cand
            and isinstance(cand["confidence_score"], (int, float))
        ]
        if not scored:
            return None

        ranked = sorted(scored, key=lambda cand: cand["confidence_score"], reverse=True)
        position = next(
            (pos for pos, cand in enumerate(ranked, 1) if cand.get("smiles") == true_smiles),
            None,
        )

        return {
            "llm_model": llm_name,
            "correct_position_llm": position,
            "total_candidates_llm": len(ranked),
            "is_top_1_llm": position == 1,
            "is_top_3_llm": position is not None and position <= 3,
            "is_top_5_llm": position is not None and position <= 5,
        }
    except Exception:
        return None

def analyze_single_zinc_json(json_file: str) -> Optional[Dict[str, Any]]:
    """
    Analyze a single ZINC JSON file for HSQC baseline ranking and all LLM rankings.

    Args:
        json_file: Path to the JSON file, as ``str`` (e.g. from ``glob.glob``) or
            ``pathlib.Path``.

    Returns:
        Dictionary with the sample's HSQC baseline fields and an ``"llm_results"``
        dict. The baseline fields are sample_id, true_smiles, hsqc_rank,
        total_candidates_hsqc and the is_top_1/3/5_hsqc flags. ``"llm_results"`` is
        keyed by LLM name and contains only models whose analysis succeeded.

        Returns None when the file cannot be parsed or processed, or lacks
        ``molecule_data.sample_id`` / ``molecule_data.smiles``. Problems are reported
        with ``print`` rather than raised.
    """
    # json_file is usually a plain str path, which has no .name attribute; Path works
    # for both str and Path inputs.
    file_name = Path(json_file).name
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # --- Extract Basic Info ---
        molecule_data = data.get("molecule_data", {})
        sample_id = molecule_data.get("sample_id")
        true_smiles = molecule_data.get("smiles")  # ZINC data should have true SMILES here

        if not sample_id or not true_smiles:
            print(f"Warning: Skipping {file_name}. Missing 'sample_id' or 'smiles' in 'molecule_data'.")
            return None

        # --- HSQC Baseline Analysis ---
        hsqc_candidates = process_single_json_candidates(data)
        hsqc_rank = find_molecule_rank(hsqc_candidates, true_smiles)

        combined_results = {
            "sample_id": sample_id,
            "true_smiles": true_smiles,
            "hsqc_rank": hsqc_rank,
            "total_candidates_hsqc": len(hsqc_candidates),
            "is_top_1_hsqc": hsqc_rank == 1,
            "is_top_3_hsqc": hsqc_rank is not None and hsqc_rank <= 3,
            "is_top_5_hsqc": hsqc_rank is not None and hsqc_rank <= 5,
        }

        # --- LLM Analysis ---
        # Results stay nested under the model name; callers flatten them if needed.
        llm_analysis_results = {}
        for llm_name in LLM_MODELS:
            llm_result = analyze_llm_predictions(data, true_smiles, llm_name)
            if llm_result:
                llm_analysis_results[llm_name] = llm_result

        combined_results["llm_results"] = llm_analysis_results
        return combined_results

    except json.JSONDecodeError:
        print(f"Error: Could not parse JSON file: {file_name}")
        return None
    except Exception as e:
        # Best-effort batch processing: report and skip this file
        print(f"Error processing file {file_name}: {type(e).__name__} - {e}")
        return None

def analyze_zinc_directory(json_dir: str) -> pd.DataFrame:
    """
    Analyze all JSON files in the ZINC directory.

    Files are processed in sorted path order. ``glob`` returns matches in an
    order that depends on the filesystem; sorting keeps the row order of the
    returned DataFrame (and any CSV written from it) reproducible.

    Args:
        json_dir: Directory containing the ZINC JSON files.

    Returns:
        Pandas DataFrame with one row per successfully processed file. Each row
        holds the top-level fields returned by ``analyze_single_zinc_json``
        (except ``llm_results``), plus each LLM's result fields prefixed with
        ``"<llm_name>_"`` (the ``llm_model`` field is dropped). The DataFrame is
        empty if no file could be processed.
    """
    json_files = sorted(glob.glob(os.path.join(json_dir, "*.json")))
    print(f"Found {len(json_files)} JSON files in {json_dir}")

    all_results = []
    for file_path in json_files:
        result = analyze_single_zinc_json(file_path)
        if result:  # a falsy result (None) means the file was skipped
            all_results.append(result)

    if not all_results:
        print("Warning: No files were successfully processed.")
        return pd.DataFrame()

    # Flatten: top-level fields stay as columns; nested per-LLM results become
    # "<llm_name>_<field>" columns.
    flat_results = []
    for sample_result in all_results:
        row = {k: v for k, v in sample_result.items() if k != 'llm_results'}
        for llm_name, llm_data in sample_result.get('llm_results', {}).items():
            row.update({f"{llm_name}_{k}": v for k, v in llm_data.items() if k != 'llm_model'})
        flat_results.append(row)

    return pd.DataFrame(flat_results)

# --- Plotting Functions ---

def plot_ranking_histogram(rankings: List[Optional[int]], title: str, color: str, max_rank: int = 5, ax: Optional[plt.Axes] = None) -> plt.Axes:
    """
    Plot a histogram of rankings for a specific method (HSQC or LLM).

    A missing rank may be ``None`` or ``NaN``. ``NaN`` is what ``.tolist()``
    returns for gaps in a float64 pandas column. Both count as "not found" in
    the histogram and in the statistics box.

    Args:
        rankings: List of ranks (int/float, or None/NaN when not found).
        title: Title for the subplot.
        color: Color for the histogram bars.
        max_rank: Maximum rank to show on the x-axis.
        ax: Matplotlib Axes object to plot on. If None, creates a new one.

    Returns:
        Matplotlib Axes object with the plot.
    """
    if ax is None:
        fig, ax = plt.subplots(figsize=(6, 5), facecolor='white')
        ax.set_facecolor('white')

    # Drop None AND NaN. `r is not None` alone lets NaN through, which counted
    # every missing rank as "Found".
    valid_rankings = [r for r in rankings if r is not None and pd.notna(r)]
    total_analyzed = len(rankings)
    total_found = len(valid_rankings)

    # Integer-centred bins: edges at 0.5, 1.5, ..., max_rank + 0.5
    bins = np.arange(1, max_rank + 2) - 0.5

    n, _, _ = ax.hist(
        [r for r in valid_rankings if r <= max_rank],
        bins=bins,
        edgecolor='black',
        alpha=0.75,
        color=color,
    )

    ax.set_title(title, fontsize=14, pad=15)
    ax.set_xlabel('Rank', fontsize=12)
    ax.set_ylabel('Number of Molecules', fontsize=12)
    ax.set_xticks(range(1, max_rank + 1))
    ax.tick_params(axis='both', which='major', labelsize=10)
    ax.grid(True, alpha=0.3, color='gray', linestyle='--')

    # Add counts above bars (bin i is centred on rank i + 1)
    for i, count in enumerate(n):
        if count > 0:
            ax.text(i + 1, count + 0.01 * total_analyzed, f'{int(count)}',
                   ha='center', va='bottom', fontsize=10)

    # Calculate statistics for annotation. Percentages are relative to all
    # analyzed samples, not only those that were found. The Top-N cutoffs are
    # fixed at 1/3/5 regardless of max_rank.
    if total_analyzed > 0:
        in_top_1 = sum(1 for r in valid_rankings if r == 1)
        in_top_3 = sum(1 for r in valid_rankings if r <= 3)
        in_top_5 = sum(1 for r in valid_rankings if r <= 5)
        top_1_percent = (in_top_1 / total_analyzed) * 100
        top_3_percent = (in_top_3 / total_analyzed) * 100
        top_5_percent = (in_top_5 / total_analyzed) * 100

        stats_text = (
            f'Found: {total_found} ({total_found/total_analyzed*100:.1f}%)\n'
            f'Top 1: {in_top_1} ({top_1_percent:.1f}%)\n'
            f'Top 3: {in_top_3} ({top_3_percent:.1f}%)\n'
            f'Top 5: {in_top_5} ({top_5_percent:.1f}%)'
        )

        ax.text(0.97, 0.97, stats_text,
                transform=ax.transAxes,
                verticalalignment='top',
                horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8),
                fontsize=9)
    else:
        ax.text(0.5, 0.5, "No data", ha='center', va='center', fontsize=12)

    ax.set_ylim(bottom=0) # Ensure y-axis starts at 0

    return ax


def create_comparison_plots(results_df: pd.DataFrame, output_dir: Path):
    """
    Creates and saves comparison histogram plots for HSQC baseline and each LLM.

    The figure has one panel for the HSQC baseline, followed by one panel per
    entry of ``LLM_MODELS``. An LLM without a rank column gets a placeholder
    panel. The figure is saved as
    ``<output_dir>/ranking_comparison_histograms_zinc.png`` and then closed.

    Args:
        results_df: DataFrame containing the analysis results. It must have a
            ``hsqc_rank`` column. LLM ranks are read from
            ``<llm_name>_correct_position_llm`` columns when present.
        output_dir: Directory (Path or str) to save plots in; created if missing.
    """
    output_dir = Path(output_dir)
    num_plots = len(LLM_MODELS) + 1  # +1 for HSQC baseline
    n_cols = 3
    n_rows = (num_plots + n_cols - 1) // n_cols  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(6 * n_cols, 5 * n_rows), facecolor='white', squeeze=False)
    axes = axes.flatten()

    # 1. Plot HSQC Baseline
    hsqc_ranks = results_df['hsqc_rank'].tolist()
    plot_ranking_histogram(hsqc_ranks, "HSQC Score Ranking (Baseline)", LLM_COLORS["HSQC Baseline"], max_rank=5, ax=axes[0])

    # 2. Plot each LLM; every LLM occupies a panel, with or without data
    for plot_idx, llm_name in enumerate(LLM_MODELS, start=1):
        ax = axes[plot_idx]
        panel_title = f"{llm_name.upper()} Confidence Ranking"
        rank_col = f"{llm_name}_correct_position_llm"
        if rank_col in results_df.columns:
            llm_ranks = results_df[rank_col].tolist()
            plot_ranking_histogram(llm_ranks, panel_title, LLM_COLORS[llm_name], max_rank=5, ax=ax)
        else:
            print(f"Info: No ranking data found for LLM: {llm_name}")
            ax.text(0.5, 0.5, f"No data for\n{llm_name.upper()}", ha='center', va='center', fontsize=12)
            ax.set_title(panel_title, fontsize=14, pad=15)
            ax.set_xticks([])
            ax.set_yticks([])

    # Hide unused subplots in the last grid row
    for unused_ax in axes[num_plots:]:
        fig.delaxes(unused_ax)

    fig.suptitle("Ranking Performance Comparison (ZINC Dataset)", fontsize=18, y=1.02)
    fig.tight_layout(rect=[0, 0, 1, 0.98])  # leave room for the suptitle

    # savefig does not create missing directories
    output_dir.mkdir(parents=True, exist_ok=True)
    plot_filename = output_dir / "ranking_comparison_histograms_zinc.png"
    fig.savefig(plot_filename, dpi=300, bbox_inches='tight')
    print(f"Saved comparison histogram plot to: {plot_filename}")
    plt.close(fig)  # free the figure; it is only written to disk


# --- Main Execution ---
if __name__ == "__main__":
    print("--- Starting ZINC Data Analysis ---")
    print(f"Input data directory: {ZINC_DATA_DIR}")
    print(f"Output directory: {OUTPUT_DIR}")

    # Run the per-file analysis over the whole directory.
    analysis_df = analyze_zinc_directory(ZINC_DATA_DIR)

    if analysis_df.empty:
        print("\nAnalysis finished, but no data was generated. Check input files and logs.")
    else:
        # Save the detailed per-sample table.
        detailed_results_file = OUTPUT_DIR / "detailed_analysis_results_zinc.csv"
        analysis_df.to_csv(detailed_results_file, index=False)
        print(f"\nSaved detailed analysis results to: {detailed_results_file}")

        total_samples = len(analysis_df)

        def _count_with_share(count):
            """Render a hit count as 'N (P%)' relative to all processed samples."""
            return f"{count} ({count/total_samples*100:.1f}%)"

        def _accuracy_block(top1_col, top3_col, top5_col, rank_col):
            """Summed Top-1/3/5 flags plus the number of rows with any rank."""
            return {
                "Top-1": _count_with_share(analysis_df[top1_col].sum()),
                "Top-3": _count_with_share(analysis_df[top3_col].sum()),
                "Top-5": _count_with_share(analysis_df[top5_col].sum()),
                "Found": _count_with_share(analysis_df[rank_col].notna().sum()),
            }

        # --- Generate Summary Statistics ---
        summary = {
            "Total Samples Processed": total_samples,
            "HSQC Baseline": _accuracy_block('is_top_1_hsqc', 'is_top_3_hsqc', 'is_top_5_hsqc', 'hsqc_rank'),
        }

        for llm_name in LLM_MODELS:
            prefix = f"{llm_name}_"
            if f"{prefix}is_top_1_llm" in analysis_df.columns:
                llm_summary = _accuracy_block(
                    f"{prefix}is_top_1_llm",
                    f"{prefix}is_top_3_llm",
                    f"{prefix}is_top_5_llm",
                    f"{prefix}correct_position_llm",
                )
            else:
                llm_summary = {"Status": "No data found"}
            summary[f"{llm_name.upper()} LLM"] = llm_summary

        # Persist the summary and echo it to stdout.
        summary_file = OUTPUT_DIR / "summary_statistics_zinc.json"
        with open(summary_file, 'w') as f:
            json.dump(summary, f, indent=2)
        print(f"Saved summary statistics to: {summary_file}")
        print("\nSummary Statistics:")
        print(json.dumps(summary, indent=2))

        # --- Create Plots ---
        print("\nGenerating comparison plots...")
        create_comparison_plots(analysis_df, OUTPUT_DIR)

    print("\n--- ZINC Data Analysis Complete ---")

In [None]:
# %% Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import json
from pathlib import Path

# --- Configuration ---
# Directory where the analysis results (including summary) are stored
ANALYSIS_DIR = Path("/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/ZINC_large/analysis_results_zinc")
SUMMARY_FILE = ANALYSIS_DIR / "summary_statistics_zinc.json"
OUTPUT_PLOT_FILE = ANALYSIS_DIR / "top1_accuracy_improvement_zinc.png"

# --- Load Data ---
# On failure, stop with SystemExit instead of exit(). Inside a Jupyter kernel,
# exit() asks the kernel itself to shut down (losing all notebook state),
# whereas SystemExit only aborts this cell. A plain script exits either way.
try:
    with open(SUMMARY_FILE, 'r') as f:
        summary_data = json.load(f)
except FileNotFoundError:
    print(f"Error: Summary file not found at {SUMMARY_FILE}")
    raise SystemExit(1)
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {SUMMARY_FILE}")
    raise SystemExit(1)

# --- Extract Required Accuracy Data ---
# Helper function to parse percentage string like "31 (12.3%)" -> 12.3
def parse_accuracy_percent(stat_entry: str) -> float:
    """Extract the percentage value from a summary string like 'N (P%)'.

    Args:
        stat_entry: Summary entry such as "31 (12.3%)".

    Returns:
        The percentage as a float (e.g. 12.3). Returns 0.0, after printing a
        warning, if the entry is not a string of that form.
    """
    try:
        # Take the text between the first '(' and the following '%'.
        percent_str = stat_entry.split('(')[1].split('%')[0]
        return float(percent_str)
    except (IndexError, ValueError, TypeError, AttributeError):
        # AttributeError covers non-string entries (e.g. None from a JSON null),
        # which have no .split() and previously escaped this handler.
        print(f"Warning: Could not parse accuracy from entry: {stat_entry}. Returning 0.0")
        return 0.0 # Return 0.0 if parsing fails

# Extract Top-1 accuracies, providing default value if keys are missing.
# Errors stop the cell with SystemExit; exit() would shut down a Jupyter kernel.
try:
    baseline_stats = summary_data.get("HSQC Baseline", {})
    deepseek_stats = summary_data.get("DEEPSEEK LLM", {})

    baseline_acc_str = baseline_stats.get("Top-1", "0 (0.0%)")
    deepseek_acc_str = deepseek_stats.get("Top-1", "0 (0.0%)") # Default if DeepSeek data missing

    if not deepseek_stats:
        print("Warning: DEEPSEEK LLM data not found in summary. Using 0% accuracy.")

    baseline_acc = parse_accuracy_percent(baseline_acc_str)
    deepseek_acc = parse_accuracy_percent(deepseek_acc_str)

except KeyError as e:
    print(f"Error: Missing expected key in summary data: {e}")
    raise SystemExit(1)
except Exception as e:
    print(f"An unexpected error occurred while processing summary data: {e}")
    raise SystemExit(1)


# --- Plotting ---
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(8, 7)) # Adjusted size for better annotation spacing

# Data for plotting
methods = ['HSQC Baseline', 'DeepSeek LLM']
accuracies = [baseline_acc, deepseek_acc]
colors = ['#999999', '#CD853F'] # Consistent colors: Gray for baseline, Peru for DeepSeek

# Create bars
bars = ax.bar(methods, accuracies, color=colors, alpha=0.85, width=0.6)

# Add labels and title
ax.set_ylabel('Top-1 Accuracy (%)', fontsize=13, fontweight='bold')
ax.set_title('Top-1 Accuracy Improvement: HSQC Baseline vs. DeepSeek LLM\n(ZINC Dataset)', fontsize=15, fontweight='bold', pad=20)
# Leave 20% headroom above the tallest bar for labels and the annotation.
# When every accuracy is 0 this would be set_ylim(0, 0), a singular axis,
# so fall back to a fixed 0-10 range.
y_max = max(accuracies)
ax.set_ylim(0, y_max * 1.20 if y_max > 0 else 10)
ax.tick_params(axis='x', labelsize=12, rotation=0) # Keep x-labels horizontal
ax.tick_params(axis='y', labelsize=11)

# Add percentage values on top of bars
for bar in bars:
    yval = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2.0, yval + 0.8, f'{yval:.1f}%', # Increased vertical offset
           va='bottom', ha='center', fontsize=11, fontweight='medium')

# Add improvement annotation
improvement = deepseek_acc - baseline_acc
if deepseek_acc > 0: # Only show annotation if there's LLM data
    # Position annotation between the bars
    mid_x = (bars[0].get_x() + bars[0].get_width() / 2 + bars[1].get_x() + bars[1].get_width() / 2) / 2
    annotation_y_pos = y_max * 0.65 # Adjust vertical position

    # Show the delta with its real sign. A hard-coded '+' rendered "+-x.x"
    # when the LLM underperformed the baseline.
    delta_label = 'Improvement' if improvement >= 0 else 'Change'

    # Arrow pointing from baseline area towards LLM bar
    ax.annotate(f'{delta_label}\n{improvement:+.1f}% pts',
                xy=(bars[1].get_x() + bars[1].get_width() * 0.5, deepseek_acc * 0.95), # Arrow points slightly below top of LLM bar
                xytext=(mid_x, annotation_y_pos),
                arrowprops=dict(arrowstyle="->", color='black',
                                connectionstyle="arc3,rad=0.2", # Curved arrow
                                lw=1.5),
                ha='center', va='center', fontsize=12, fontweight='bold',
                bbox=dict(boxstyle="round,pad=0.4", fc="#FFFACD", ec="grey", lw=1, alpha=0.9)) # Lemon Chiffon background

# Add grid
ax.yaxis.grid(True, linestyle='--', alpha=0.6)
ax.set_axisbelow(True) # Grid behind bars

plt.tight_layout()

# Save the figure
try:
    OUTPUT_PLOT_FILE.parent.mkdir(parents=True, exist_ok=True) # Ensure directory exists
    plt.savefig(OUTPUT_PLOT_FILE, dpi=300, bbox_inches='tight')
    print(f"\nSaved plot to: {OUTPUT_PLOT_FILE}")
except Exception as e:
    print(f"Error saving plot: {e}")

plt.show()

In [None]:
# %% Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import json
from pathlib import Path
import seaborn as sns # Using seaborn styles

# --- Configuration ---
# Directory where the analysis results (including summary) are stored
ANALYSIS_DIR = Path("/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/ZINC_large/analysis_results_zinc")
SUMMARY_FILE = ANALYSIS_DIR / "summary_statistics_zinc.json"
OUTPUT_PLOT_FILE_COMPARE = ANALYSIS_DIR / "top_n_accuracy_comparison_zinc.png"

# Define colors
BASELINE_COLOR = '#999999' # Gray
DEEPSEEK_COLOR = '#CD853F' # Peru/Orange

# --- Load Data ---
# On failure, stop with SystemExit instead of exit(). Inside a Jupyter kernel,
# exit() asks the kernel itself to shut down (losing all notebook state),
# whereas SystemExit only aborts this cell.
try:
    with open(SUMMARY_FILE, 'r') as f:
        summary_data = json.load(f)
except FileNotFoundError:
    print(f"Error: Summary file not found at {SUMMARY_FILE}")
    raise SystemExit(1)
except json.JSONDecodeError:
    print(f"Error: Could not decode JSON from {SUMMARY_FILE}")
    raise SystemExit(1)

# --- Extract Required Accuracy Data ---
# Helper function to parse percentage string like "31 (12.3%)" -> 12.3
def parse_accuracy_percent(stat_entry: str) -> float:
    """Extract the percentage value from a summary string like 'N (P%)'.

    Args:
        stat_entry: Summary entry such as "31 (12.3%)".

    Returns:
        The percentage as a float (e.g. 12.3). Returns 0.0, after printing a
        warning, if the entry is not a string of that form.
    """
    try:
        # Take the text between the first '(' and the following '%'.
        percent_str = stat_entry.split('(')[1].split('%')[0]
        return float(percent_str)
    except (IndexError, ValueError, TypeError, AttributeError):
        # AttributeError covers non-string entries (e.g. None from a JSON null),
        # which have no .split() and previously escaped this handler.
        print(f"Warning: Could not parse accuracy from entry: {stat_entry}. Returning 0.0")
        return 0.0

# Extract accuracies for Baseline and DeepSeek.
# Missing methods/levels default to "0 (0.0%)". Errors stop the cell with
# SystemExit; exit() would shut down a Jupyter kernel.
baseline_accuracies = {}
deepseek_accuracies = {}
accuracy_levels = ["Top-1", "Top-3", "Top-5"]

try:
    baseline_stats = summary_data.get("HSQC Baseline", {})
    deepseek_stats = summary_data.get("DEEPSEEK LLM", {})

    if not deepseek_stats:
        print("Warning: DEEPSEEK LLM data not found in summary. Using 0% accuracy.")

    for level in accuracy_levels:
        baseline_acc_str = baseline_stats.get(level, "0 (0.0%)")
        deepseek_acc_str = deepseek_stats.get(level, "0 (0.0%)") # Default if missing

        baseline_accuracies[level] = parse_accuracy_percent(baseline_acc_str)
        deepseek_accuracies[level] = parse_accuracy_percent(deepseek_acc_str)

except KeyError as e:
    print(f"Error: Missing expected key in summary data: {e}")
    raise SystemExit(1)
except Exception as e:
    print(f"An unexpected error occurred while processing summary data: {e}")
    raise SystemExit(1)

# --- Plotting ---
plt.style.use('seaborn-v0_8-whitegrid')  # clean seaborn-flavoured matplotlib style
# Two side-by-side panels sharing the y-axis so the percentages compare directly.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
ax1, ax2 = axes

categories = list(baseline_accuracies.keys())
values_baseline = list(baseline_accuracies.values())
values_deepseek = list(deepseek_accuracies.values())


def _draw_accuracy_panel(panel_ax, labels, heights, bar_color, panel_title):
    """Draw one Top-N accuracy bar panel (0-100%) with value labels; return its bars."""
    panel_bars = panel_ax.bar(labels, heights, color=bar_color, alpha=0.85, width=0.6)
    panel_ax.set_title(panel_title, fontsize=14, fontweight='bold', pad=15)
    panel_ax.set_ylim(0, 100)
    panel_ax.tick_params(axis='x', labelsize=12)
    panel_ax.tick_params(axis='y', labelsize=11)
    panel_ax.yaxis.grid(True, linestyle='--', alpha=0.6)
    panel_ax.set_axisbelow(True)
    # Percentage label just above each bar
    for rect in panel_bars:
        height = rect.get_height()
        panel_ax.text(rect.get_x() + rect.get_width() / 2.0, height + 1.5, f'{height:.1f}%',
                      va='bottom', ha='center', fontsize=11, fontweight='medium')
    return panel_bars


# --- Plot 1: HSQC Baseline (left panel carries the shared y-label) ---
bars1 = _draw_accuracy_panel(ax1, categories, values_baseline, BASELINE_COLOR,
                             'HSQC Score Ranking Accuracy (Baseline)')
ax1.set_ylabel('Accuracy (%)', fontsize=13, fontweight='bold')

# --- Plot 2: DeepSeek LLM ---
bars2 = _draw_accuracy_panel(ax2, categories, values_deepseek, DEEPSEEK_COLOR,
                             'DeepSeek LLM Confidence Ranking Accuracy')

# Add overall title
fig.suptitle('Top-N Accuracy Comparison: Baseline vs. DeepSeek LLM (ZINC Dataset)', fontsize=16, fontweight='bold', y=1.03)

plt.tight_layout(rect=[0, 0, 1, 0.97])  # keep a strip free for the suptitle

# Save the figure
try:
    OUTPUT_PLOT_FILE_COMPARE.parent.mkdir(parents=True, exist_ok=True)  # Ensure directory exists
    plt.savefig(OUTPUT_PLOT_FILE_COMPARE, dpi=300, bbox_inches='tight')
    print(f"\nSaved comparison plot to: {OUTPUT_PLOT_FILE_COMPARE}")
except Exception as e:
    print(f"Error saving plot: {e}")

plt.show()

In [None]:
# %% Import necessary libraries
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import seaborn as sns # Using seaborn styles

# --- Configuration ---
# Directory where the analysis results are stored
ANALYSIS_DIR = Path("/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/ZINC_large/analysis_results_zinc")
DETAILED_RESULTS_FILE = ANALYSIS_DIR / "detailed_analysis_results_zinc.csv"
OUTPUT_PLOT_FILE_HSQC_COUNTS_CENTERED = ANALYSIS_DIR / "hsqc_top10_counts_centered_accuracy_zinc.png"

# Define color for baseline
BASELINE_COLOR = '#999999' # Gray
TEXT_COLOR_INSIDE = 'white' # Color for text inside bars

# --- Load Data ---
# Failures stop the cell with SystemExit instead of exit(). Inside a Jupyter
# kernel, exit() shuts the kernel down and discards all notebook state.
try:
    results_df = pd.read_csv(DETAILED_RESULTS_FILE)
except FileNotFoundError:
    print(f"Error: Detailed results file not found at {DETAILED_RESULTS_FILE}")
    raise SystemExit(1)
except Exception as e:
    print(f"Error loading detailed results: {e}")
    raise SystemExit(1)

# --- Validate Data ---
if 'hsqc_rank' not in results_df.columns:
    print(f"Error: Column 'hsqc_rank' not found in {DETAILED_RESULTS_FILE}")
    raise SystemExit(1)

if results_df.empty:
    print("Error: The detailed results file is empty.")
    raise SystemExit(1)

# --- Calculate Cumulative Accuracies and Counts for HSQC Baseline ---
total_samples = len(results_df)
print(f"Total samples analyzed: {total_samples}")

max_n = 10  # cumulative cutoffs Top-1 ... Top-10

# Coerce ranks to numbers (unparseable -> NaN). Rows left without a rank are
# samples whose molecule was not found, so they are excluded from the counts.
results_df['hsqc_rank'] = pd.to_numeric(results_df['hsqc_rank'], errors='coerce')
valid_ranks_df = results_df.dropna(subset=['hsqc_rank'])
count_found_any_rank = len(valid_ranks_df)
print(f"Samples with valid HSQC rank (molecule found): {count_found_any_rank}")


def _as_percent(count):
    """Share of all samples in percent (0 when there are no samples)."""
    return (count / total_samples) * 100 if total_samples > 0 else 0


# Absolute counts first (Top-N cumulative, then 'Found' at any rank); the
# percentages are derived from them in the same key order.
found_ranks = valid_ranks_df['hsqc_rank']
top_n_counts = {f"Top-{n}": int((found_ranks <= n).sum()) for n in range(1, max_n + 1)}
top_n_counts["Found (Any Rank)"] = count_found_any_rank
top_n_accuracies = {level: _as_percent(hits) for level, hits in top_n_counts.items()}

print("\nCalculated Cumulative Accuracies and Counts (HSQC Baseline):")
for level, acc in top_n_accuracies.items():
    print(f"  {level}: {acc:.1f}% ({top_n_counts[level]}/{total_samples})")

# --- Plotting ---
plt.style.use('seaborn-v0_8-whitegrid')
fig, ax = plt.subplots(figsize=(14, 8))

# One gray bar per cutoff (Top-1 ... Top-10, then "Found (Any Rank)").
categories = list(top_n_accuracies.keys())
values = [top_n_accuracies[level] for level in categories]
plot_colors = [BASELINE_COLOR for _ in categories]
bars = ax.bar(categories, values, color=plot_colors, alpha=0.85, width=0.7)

# Axis labels, title, limits and grid
ax.set_ylabel('Cumulative Accuracy (%)', fontsize=14, fontweight='bold')
ax.set_xlabel('Rank Cutoff', fontsize=14, fontweight='bold')
ax.set_title('HSQC Baseline Cumulative Top-N Accuracy', fontsize=16, fontweight='bold', pad=20)
ax.set_ylim(0, 105)  # a little headroom above 100% for the value labels
for axis_name in ('x', 'y'):
    ax.tick_params(axis=axis_name, labelsize=12)
ax.yaxis.grid(True, linestyle='--', alpha=0.6)
ax.set_axisbelow(True)

# Percentage ABOVE every bar. The absolute count goes INSIDE (vertically
# centred) only for bars taller than this many percent, so the label fits.
MIN_HEIGHT_FOR_INSIDE_LABEL = 10

for bar, level in zip(bars, categories):
    yval = bar.get_height()
    x_mid = bar.get_x() + bar.get_width() / 2.0
    ax.text(x_mid, yval + 1.5, f'{yval:.1f}%',
            va='bottom', ha='center', fontsize=10, fontweight='medium')
    if yval > MIN_HEIGHT_FOR_INSIDE_LABEL:
        ax.text(x_mid, yval / 2, str(top_n_counts[level]),
                va='center', ha='center',
                color=TEXT_COLOR_INSIDE, fontsize=10, fontweight='bold')


# Total-samples info box in the top-left corner (axes coordinates)
ax.text(0.02, 0.95, f'Total Samples Analyzed: {total_samples}',
        transform=ax.transAxes, ha='left', va='top',
        fontsize=11, style='italic', bbox=dict(boxstyle='round,pad=0.3', fc='white', ec='grey', lw=0.5))

plt.tight_layout()

# Save the figure
try:
    OUTPUT_PLOT_FILE_HSQC_COUNTS_CENTERED.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(OUTPUT_PLOT_FILE_HSQC_COUNTS_CENTERED, dpi=300, bbox_inches='tight')
    print(f"\nSaved HSQC plot with centered counts inside bars to: {OUTPUT_PLOT_FILE_HSQC_COUNTS_CENTERED}")
except Exception as e:
    print(f"Error saving plot: {e}")

plt.show()

# Combine JSON files into one file (for re-running parts of the pipeline)

In [None]:
from pathlib import Path
import json

def combine_json_files(input_folder, output_file):
    """
    Combines multiple JSON files from a folder into a single JSON file.

    Entries are keyed by ``molecule_data['sample_id']``. Each entry holds:
    - ``sample_id``;
    - ``smiles`` (the value of ``molecule_data['smiles']``, or None if absent);
    - every top-level key of the source JSON.

    Files are processed in sorted filename order. When several files share a
    sample_id, the later file's values win deterministically.

    Parameters:
    -----------
    input_folder : str or Path
        Path to the folder containing JSON files
    output_file : str or Path
        Path where the combined JSON file should be saved

    Returns:
    --------
    dict
        The combined data. It is returned even if writing the output file fails.
    """
    input_path = Path(input_folder)
    output_path = Path(output_file)

    combined_data = {}

    # Count processed files for reporting
    processed_files = 0
    skipped_files = 0

    # Sorted so the processing (and overwrite) order is reproducible.
    for json_file in sorted(input_path.glob('*.json')):
        try:
            # JSON is UTF-8 by spec; do not depend on the platform locale.
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # A missing or non-dict 'molecule_data' is a normal skip, rather
            # than an opaque KeyError reported as a processing error.
            molecule_data = data.get("molecule_data") if isinstance(data, dict) else None
            if not isinstance(molecule_data, dict) or 'sample_id' not in molecule_data:
                print(f"Warning: {json_file.name} does not contain 'sample_id'. Skipping...")
                skipped_files += 1
                continue

            sample_id = molecule_data['sample_id']

            # The placeholder and the stored value use the same lowercase
            # 'smiles' key. Previously a stale 'SMILES': None placeholder
            # remained next to the real 'smiles' value.
            entry = combined_data.setdefault(sample_id, {'sample_id': sample_id, 'smiles': None})
            if 'smiles' in molecule_data:
                entry['smiles'] = molecule_data['smiles']

            # Copy every top-level key of the source file into the entry
            entry.update(data)

            processed_files += 1
        except json.JSONDecodeError:
            print(f"Error: Could not parse JSON file: {json_file.name}")
            skipped_files += 1
        except Exception as e:
            print(f"Error processing file {json_file.name}: {str(e)}")
            skipped_files += 1

    # Create output directory if it doesn't exist
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Write the combined data to the output file
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(combined_data, f, indent=2)
        print(f"Successfully combined {processed_files} files into {len(combined_data)} entries in {output_path}")
        if skipped_files > 0:
            print(f"Skipped {skipped_files} files due to errors or missing required fields")
    except Exception as e:
        print(f"Error writing output file: {str(e)}")

    return combined_data

# Example usage: merge the per-sample JSON files into one file, then show one entry.
if __name__ == "__main__":
    input_folder = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/do_json"
    output_file = "/projects/cc/se_users/knlr326/1_NMR_project/2_Notebooks/MMT_explainability/LLM_Structure_Elucidator/_temp_folder/intermediate_results/molecular_data.json"

    combined_data = combine_json_files(input_folder, output_file)

    # Sanity check: dump the first combined entry (insertion order).
    if combined_data:
        first_key, first_entry = next(iter(combined_data.items()))
        print("\nStructure of first entry:")
        print(json.dumps({first_key: first_entry}, indent=2))

In [None]:
!pip list


# Test MCP

In [None]:
!pip install mcp[cli]

In [None]:
# Run the `mcp` CLI in dev mode on simple_mcp_server.py.
# NOTE(review): this presumably starts a long-running dev server. A `!` shell
# call blocks this cell until the process exits; confirm that is intended.
!mcp dev simple_mcp_server.py

In [None]:
# Imports for the MCP client experiment: asyncio/os utilities, the Anthropic SDK client,
# and the MCP client session plus a transport helper.
import asyncio
import os
from anthropic import Anthropic
from mcp import ClientSession
# NOTE(review): `mcp.client.http` / `http_client` do not appear in the published MCP
# Python SDK docs, which expose `mcp.client.stdio.stdio_client`,
# `mcp.client.sse.sse_client` and `mcp.client.streamable_http.streamablehttp_client`.
# This import likely raises ImportError. Confirm against the installed `mcp` version.
from mcp.client.http import http_client


In [None]:
# Use the explicit line magic. Bare `pwd` relies on automagic: it stops working if
# `%automagic` is off, and it is shadowed by any Python variable named `pwd`.
%pwd

# OLD: scratch cells inspecting the `kimi` LLM response (require `data` from an earlier cell)

In [None]:
# Display the `parsed_results` item of the "kimi" entry under `llm_responses`.
# Relies on `data` having been defined by an earlier cell.
(
    data["analysis_results"]["final_analysis"]
    ["llm_responses"]["kimi"]["parsed_results"]
)

In [None]:
# List the fields available on the "kimi" entry under `llm_responses`.
# Relies on `data` having been defined by an earlier cell.
(
    data["analysis_results"]["final_analysis"]
    ["llm_responses"]["kimi"]
    .keys()
)

In [None]:
# NOTE(review): this repeats the `parsed_results` lookup performed two cells earlier;
# one of the two cells can likely be removed. Relies on `data` from an earlier cell.
data["analysis_results"]["final_analysis"]["llm_responses"]["kimi"]["parsed_results"]

In [None]:
          "thinking": "Alright, let's dive into this intriguing puzzle of determining the most likely molecular structure among the given candidates based on the spectral and analytical evidence. This is going to be a meticulous process, dissecting each piece of information to ensure we arrive at the correct conclusion.\n\n**Understanding the Task:**\n\nWe have a target molecule with a specific molecular weight (259.168462292) but an unknown formula. Five candidate structures are provided, each with their SMILES notation, molecular weight, and formula. Importantly, four of these candidates share the same molecular weight and formula (C15H21N3O), while the fifth has a slightly different molecular weight and formula (C14H20N4O). Our goal is to analyze all available spectral data, particularly the HSQC data, to determine which candidate best matches the experimental evidence.\n\n**Approach:**\n\n1. **Initial Screening:** First, I'll perform a quick comparison of the molecular weights and formulas. This will immediately help in shortlisting the candidates that match the target's molecular weight.\n\n2. **Detailed Spectral Analysis:** For each candidate, I'll analyze the available HSQC data, comparing the predicted and experimental chemical shifts. I'll pay close attention to the nature and magnitude of any deviations.\n\n3. **Structural Feature Evaluation:** I'll examine the structural features of each candidate, considering how well they align with the observed spectral data. This includes the types of functional groups present, their expected chemical shifts, and any characteristic spectral patterns they might produce.\n\n4. **Data Quality Assessment:** It's crucial to evaluate the quality of the spectral data itself. Any inconsistencies or potential errors in the data could impact our conclusions.\n\n5. 
**Comparative Analysis and Final Recommendation:** After analyzing each candidate, I'll compare the findings to identify the candidate with the strongest overall support from the data. I'll then provide a final recommendation, including a confidence score and a detailed explanation of the reasoning.\n\n**Step 1: Initial Screening Based on Molecular Weight and Formula**\n\nLet's first compare the molecular weights and formulas of the candidates with the target:\n\n- **Target:** Molecular Weight = 259.168462292, Formula = Unknown\n- **Candidate 1:** Molecular Weight = 259.168462292, Formula = C15H21N3O\n- **Candidate 2:** Molecular Weight = 259.168462292, Formula = C15H21N3O\n- **Candidate 3:** Molecular Weight = 259.168462292, Formula = C15H21N3O\n- **Candidate 4:** Molecular Weight = 260.16371126, Formula = C14H20N4O\n- **Candidate 5:** Molecular Weight = 259.168462292, Formula = C15H21N3O\n\n**Observation:** Candidates 1, 2, 3, and 5 perfectly match the target's molecular weight and likely its formula as well. Candidate 4 has a slightly different molecular weight and formula, making it an outlier at this stage. Unless there's a compelling reason from the spectral data to consider Candidate 4, our primary focus will be on the other four. However, I won't completely disregard Candidate 4 yet, as there might be a justification for the discrepancy, such as a different isotopic composition or a measurement error.\n\n**Step 2: Detailed Spectral Analysis (HSQC) for Each Candidate**\n\nNow, let's delve into the HSQC data for each candidate. I'll analyze the predicted vs. experimental chemical shifts, focusing on the nature and magnitude of any deviations.\n\n**Candidate 1: COc1ccnc2cccc(NC(C)CCCN)c12**\n\n- **HSQC Error:** 2.1222373347931835\n- **Detailed Analysis:**\n  - **Aliphatic Region (Atoms 0-6):** Errors are minimal (0.011-0.079), indicating a strong match. 
This suggests that the pentane-1,4-diamine chain is correctly positioned and substituted.\n  - **Methoxy Group (Atom 5):** Error of 0.029, which is acceptable. This confirms the presence and likely position of the methoxy group.\n  - **Aromatic Protons (Atoms 7-11):** Most show good correlation with errors below 0.1, except for atoms 8 and 10, which have significant deviations (0.900 and 0.701, respectively). This could indicate differences in the electronic environment of these carbons, possibly due to tautomerism or solvent effects.\n  - **Amine-Adjacent Methylene (Atom 6):** Error of 0.114, which is moderate.\n\n**Conclusion for Candidate 1:** The overall structure is largely supported, especially the aliphatic chain and methoxy substitution. However, the significant deviations in the aromatic region require further investigation. It's possible that these deviations are due to electronic effects rather than a fundamental structural error.\n\n**Candidate 2: COc1ccc(NC(C)CCCN)c2cccnc12**\n\n- **HSQC Error:** 2.3838264192736895\n- **Detailed Analysis:**\n  - **Aliphatic Region:** Excellent matches with errors ranging from 0.001 to 0.043. This strongly supports the structure of the aminopentyl chain.\n  - **Methoxy Group:** Good agreement, although the specific error isn't mentioned here.\n  - **Aromatic Region:** Significant deviations are observed, particularly at atom 8 (error of 1.510). This suggests a mismatch in the electronic environment of this carbon, which could be due to the position of the methoxy group or other factors like tautomerism.\n  - **Other Aromatic Atoms:** Moderate deviations (e.g., atom 10 with an error of 0.267) are also noted.\n\n**Conclusion for Candidate 2:** Similar to Candidate 1, the aliphatic portions show excellent agreement, providing confidence in that part of the structure. The aromatic region, however, has substantial deviations that need to be resolved. 
The position of the methoxy group might be a factor here.\n\n**Candidate 3: COc1ccc(NC(C)CCCN)c2ncccc12**\n\n- **HSQC Error:** 2.614620683844665\n- **Detailed Analysis:**\n  - **Aliphatic Region:** Very good matches with errors between 0.028 and 0.071. This confirms the structure of the aliphatic chain.\n  - **Aromatic Protons:** Atom 9 shows a good match (error 0.099), but atom 8 has a large deviation (1.718). This indicates a significant difference in the chemical environment of this carbon.\n  - **N-CH Proton (Atom 6):** Excellent match (error 0.028), supporting the connectivity at this point.\n  - **Quinoline CH (Atom 10):** Moderate deviation (0.292).\n\n**Conclusion for Candidate 3:** The aliphatic chain is well-supported by the data. The large deviation in the aromatic region (atom 8) is concerning and suggests a possible misassignment or electronic effect that isn't accounted for in the prediction.\n\n**Candidate 4: COc1cc(NC(C)CCCN)c2ncccc2n1**\n\n- **HSQC Error:** 2.64268761259165\n- **Detailed Analysis:** No detailed HSQC analysis is provided for Candidate 4. However, its molecular weight (260.16371126) and formula (C14H20N4O) differ from the target. Unless there's a specific reason to consider it (e.g., a possible isotopic variant or a measurement error), it's less likely to be the correct structure based on the molecular weight alone.\n\n**Candidate 5: COc1ccnc2c(NC(C)CCCN)cccc12**\n\n- **HSQC Error:** 2.7837488088212536\n- **Detailed Analysis:** No detailed HSQC analysis is provided for Candidate 5. Therefore, it's challenging to assess its suitability based solely on the given information.\n\n**Step 3: Evaluating Structural Features and Compatibility with Data**\n\nLet's now consider the structural features of each candidate and how well they align with the available data:\n\n- **Candidates 1, 2, and 3:** All feature a quinoline core with a methoxy substituent and an aminopentyl chain. 
The primary differences lie in the position of the methoxy group and the specific connectivity of the aminopentyl chain.\n  - **Methoxy Group:** The presence of a methoxy group is confirmed by the HSQC data in all three candidates, with reasonable errors.\n  - **Aminopentyl Chain:** The aliphatic regions show excellent agreement, indicating that the chain's structure is likely correct in all candidates.\n  - **Aromatic System:** The discrepancies in the aromatic region could be due to the position of the methoxy group affecting the electron density, leading to shifts in the chemical shifts of the adjacent carbons.\n\n- **Candidate 4:** Features a 1,5-naphthyridine core instead of a quinoline. The different core structure could explain the discrepancy in molecular weight and formula. However, without detailed spectral analysis, it's hard to justify its selection over the other candidates that perfectly match the molecular weight.\n\n- **Candidate 5:** Similar to Candidate 1, but with a different connectivity of the aminopentyl chain. The lack of detailed spectral analysis prevents a thorough evaluation.\n\n**Step 4: Assessing Data Quality and Potential Issues**\n\nFrom the provided information, the primary data quality issue is the lack of detailed HSQC analysis for Candidates 4 and 5. This makes it challenging to fully evaluate their suitability. For Candidates 1-3, the HSQC data is available, but there are deviations in the aromatic region that need to be addressed. 
These deviations could be due to:\n\n- **Electronic Effects:** The position of the methoxy group and the nitrogen atoms in the quinoline ring can significantly influence the electron density, leading to shifts in chemical shifts.\n- **Tautomerism:** The presence of multiple nitrogen atoms could lead to tautomeric forms, resulting in different chemical environments for some atoms.\n- **Solvent Effects:** The solvent used in the NMR experiment can impact chemical shifts, particularly in aromatic systems.\n- **Prediction Model Limitations:** The errors could also result from limitations in the prediction model used to generate the theoretical chemical shifts.\n\n**Step 5: Comparative Analysis and Final Recommendation**\n\nBased on the analysis so far, let's compare the candidates:\n\n- **Molecular Weight Match:** Candidates 1, 2, 3, and 5 perfectly match the target's molecular weight. Candidate 4 does not, making it less likely unless there's a specific justification.\n- **Spectral Data Availability:** Candidates 1-3 have detailed HSQC analyses, while 4 and 5 do not. This makes it easier to evaluate Candidates 1-3.\n- **Aliphatic Region Agreement:** All three candidates show excellent agreement in the aliphatic region, confirming the structure of the aminopentyl chain.\n- **Aromatic Region Discrepancies:** All three candidates exhibit deviations in the aromatic region. However, these deviations are not deal-breakers and can be attributed to electronic effects or other factors mentioned earlier.\n\n**Decisive Factors:**\n\n- **Molecular Weight:** Perfect match with Candidates 1, 2, 3, and 5.\n- **Aliphatic Region:** Excellent agreement in all evaluated candidates.\n- **Aromatic Region:** Deviations are present but are consistent with electronic effects due to substitution.\n\n**Addressing Uncertainties:**\n\nThe main uncertainty lies in the aromatic region's chemical shifts. 
To resolve this, additional 2D NMR experiments (such as COSY, HSQC, and HMBC) would be beneficial. These experiments can provide more detailed information about the connectivity and electronic environment of the atoms in the aromatic system, helping to confirm the exact position of the methoxy group and the overall structure.\n\n**Final Recommendation:**\n\nBased on the available data, **Candidate 1** appears to be the most likely correct structure. It has the highest overall agreement in the aliphatic region and a reasonable explanation for the deviations observed in the aromatic region. However, it's crucial to conduct further NMR experiments to resolve the uncertainties in the aromatic system.\n\n**JSON Result:**\n\n```json\nJSON_RESULT = {\n    \"candidates\": [\n        {\n            \"smiles\": \"COc1ccnc2cccc(NC(C)CCCN)c12\",\n            \"confidence_score\": 0.8,\n            \"molecular_weight\": 259.168462292,\n            \"reasoning\": \"The aliphatic region shows excellent agreement with minimal errors (0.011-0.079), confirming the aminopentyl chain structure. The methoxy group also matches well with an error of 0.029. Deviations in the aromatic region (errors up to 0.900) are noted but can be attributed to electronic effects due to substitution. The overall structure is largely supported, with high confidence in the aliphatic portions.\",\n            \"data_quality_issues\": {\n                \"title\": \"Aromatic Region Deviations\",\n                \"description\": \"Significant deviations in the chemical shifts of some aromatic carbons (atoms 8 and 10) are observed. 
These deviations likely result from electronic effects due to the position of the methoxy group and the nitrogen atoms in the quinoline ring.\",\n                \"impact\": \"medium\",\n                \"atom_index\": 8\n            }\n        },\n        {\n            \"smiles\": \"COc1ccc(NC(C)CCCN)c2cccnc12\",\n            \"confidence_score\": 0.75,\n            \"molecular_weight\": 259.168462292,\n            \"reasoning\": \"The aliphatic chain shows excellent correlation with errors ranging from 0.001 to 0.043, strongly supporting that part of the structure. The methoxy group is present and matches reasonably. However, significant deviations in the aromatic region, particularly at atom 8 (error 1.510), indicate potential issues with the electronic environment that need further investigation.\",\n            \"data_quality_issues\": {\n                \"title\": \"Aromatic Region Mismatch\",\n                \"description\": \"Large deviation in the chemical shift of atom 8 suggests a mismatch in the electronic environment, possibly due to the position of the methoxy group or tautomerism.\",\n                \"impact\": \"medium\",\n                \"atom_index\": 8\n            }\n        },\n        {\n            \"smiles\": \"COc1ccc(NC(C)CCCN)c2ncccc12\",\n            \"confidence_score\": 0.7,\n            \"molecular_weight\": 259.168462292,\n            \"reasoning\": \"The aliphatic chain protons and carbons show excellent correlation, confirming that part of the structure. The N-CH proton matches perfectly. 
However, significant deviations in the aromatic region, particularly at atom 8 (error 1.718), suggest potential issues with the electronic environment or connectivity.\",\n            \"data_quality_issues\": {\n                \"title\": \"Large Aromatic Deviation\",\n                \"description\": \"Atom 8 shows a large deviation in both 1H and 13C chemical shifts, indicating a significant mismatch in the electronic environment that needs to be resolved.\",\n                \"impact\": \"high\",\n                \"atom_index\": 8\n            }\n        },\n        {\n            \"smiles\": \"COc1cc(NC(C)CCCN)c2ncccc2n1\",\n            \"confidence_score\": 0.5,\n            \"molecular_weight\": 260.16371126,\n            \"reasoning\": \"The molecular weight and formula differ from the target, making it less likely unless there's a specific justification (e.g., isotopic variant). No detailed HSQC analysis is available to support or refute the structure.\",\n            \"data_quality_issues\": {\n                \"title\": \"Molecular Weight Discrepancy\",\n                \"description\": \"The molecular weight and formula do not match the target, which is a significant concern unless there's a valid explanation.\",\n                \"impact\": \"high\",\n                \"atom_index\": 0\n            }\n        },\n        {\n            \"smiles\": \"COc1ccnc2c(NC(C)CCCN)cccc12\",\n            \"confidence_score\": 0.6,\n            \"molecular_weight\": 259.168462292,\n            \"reasoning\": \"No detailed HSQC analysis is provided, making it difficult to fully evaluate the structure. 
The molecular weight matches, but without spectral data, the confidence is lower compared to other candidates.\",\n            \"data_quality_issues\": {\n                \"title\": \"Lack of Spectral Data\",\n                \"description\": \"No detailed HSQC analysis is available, preventing a thorough evaluation of the structure's compatibility with the experimental data.\",\n                \"impact\": \"medium\",\n                \"atom_index\": 0\n            }\n        }\n    ],\n    \"final_recommendation\": {\n        \"best_smiles\": \"COc1ccnc2cccc(NC(C)CCCN)c12\",\n        \"overall_confidence\": 0.8,\n        \"molecular_weight_match\": true,\n        \"explanation\": \"Candidate 1 is recommended as the most likely correct structure based on the strong agreement in the aliphatic region and a reasonable explanation for the deviations observed in the aromatic region. The molecular weight perfectly matches the target. While Candidates 2 and 3 also show promise, Candidate 1 has the highest overall agreement. Further 2D NMR experiments are recommended to resolve the uncertainties in the aromatic system and confirm the exact position of the methoxy group.\"\n    }\n}\n```",
