In [None]:
# Figure 5: Accuracy of functional group prediction
import json
import matplotlib.pyplot as plt
from rdkit import Chem

def plot_functional_group_accuracies(model_path):
    """Plot how often predictions keep the functional group of the original.

    The JSON file at ``model_path`` is read as an iterable of entries, each
    with an ``"original"`` SMILES string and a ``"predicted"`` sequence whose
    first element is the predicted SMILES string.

    Every original molecule is assigned to the *first* functional group (in
    the order of the table below) whose SMARTS pattern it matches; molecules
    matching no pattern are ignored.  A prediction counts as correct when it
    parses and contains that same functional group.  The per-group accuracy
    (in percent, 0 when the group never occurs) is drawn as a lollipop chart
    and shown with ``plt.show()``.

    Args:
        model_path: Path to the JSON file with the model outputs.

    Returns:
        None.  The figure is displayed as a side effect.
    """
    # Functional groups as SMARTS.  Order matters: classification stops at
    # the first pattern that matches.
    functional_groups = {
        "Alcohol": "[OX2H][CX4;!$(C([OX2H])[O,S,#7,#15])]",
        "Carboxylic Acid": "[CX3](=O)[OX2H1]",
        "Ether": "[OD2]([#6])[#6]",
        "Alkene": "[CX3]=[CX3]",
        "Benzene": "c1ccccc1",
        "Primary Amine": "[NX3;H2;!$(NC=[!#6]);!$(NC#[!#6])][#6]",
        "Amide": "[NX3][CX3](=[OX1])[#6]",
        "Sulfide": "[#16X2H0]"
    }

    # Compile every SMARTS once instead of once per molecule and match.
    patterns = {
        fg: Chem.MolFromSmarts(smarts)
        for fg, smarts in functional_groups.items()
    }

    # Per-group counters: molecules seen / predictions keeping the group.
    counts_original = dict.fromkeys(functional_groups, 0)
    counts_correct = dict.fromkeys(functional_groups, 0)

    def classify(mol):
        """Return the first functional group found in ``mol``, else None."""
        if mol is None:
            return None
        for fg, pattern in patterns.items():
            if mol.HasSubstructMatch(pattern):
                return fg
        return None

    # Only the load needs the file handle; close it before processing.
    with open(model_path, "r", encoding="utf-8") as file:
        smiles_data = json.load(file)

    for entry in smiles_data:
        original = entry["original"]
        predictions = entry["predicted"]

        # Parse the original once, then test it against every pattern.
        fg_original = classify(Chem.MolFromSmiles(original))
        if fg_original is None:
            continue
        counts_original[fg_original] += 1

        # An entry without any prediction simply counts as incorrect.
        if not predictions:
            continue

        try:
            mol_predicted = Chem.MolFromSmiles(predictions[0])
            if mol_predicted is not None and mol_predicted.HasSubstructMatch(
                patterns[fg_original]
            ):
                counts_correct[fg_original] += 1
        except Exception:
            # Deliberate best effort: a prediction that cannot be evaluated
            # (e.g. not a string) counts as incorrect.  Unlike a bare
            # ``except`` this still lets KeyboardInterrupt/SystemExit through.
            continue

    labels = list(functional_groups)
    accuracies = [
        counts_correct[fg] / counts_original[fg] * 100 if counts_original[fg] else 0
        for fg in labels
    ]

    # Horizontal position of each group on the x axis.
    spacing = 0.4
    x = [i * spacing for i in range(len(labels))]

    # Colors palette
    colors = {
        "Alcohol": "red",
        "Carboxylic Acid": "orange",
        "Ether": "green",
        "Alkene": "blue",
        "Benzene": "purple",
        "Primary Amine": "brown",
        "Amide": "pink",
        "Sulfide": "cyan",
    }

    fig, ax = plt.subplots(figsize=(10, 6))
    color_list = [colors[label] for label in labels]

    # Gray dotted stems from the baseline up to each accuracy value.
    for x_val, accuracy in zip(x, accuracies):
        ax.plot([x_val, x_val], [0, accuracy], color='gray', linestyle=':', linewidth=1.5)

    # Hollow circles at the accuracy values, outlined in the group color.
    ax.scatter(x, accuracies, s=300, facecolors='white', edgecolors=color_list, linewidth=2.5)

    ax.set_ylabel("Accuracy (%)", fontsize=24)
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha="right", fontsize=18)
    ax.tick_params(axis='y', labelsize=18)
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_ylim(0, 100)

    # Numeric label slightly above each circle.
    for x_val, yval in zip(x, accuracies):
        ax.text(x_val, yval + 5, f"{yval:.1f}%", ha="center", va="bottom", fontsize=16)

    fig.tight_layout()
    plt.show()

# Draw the figure from the test outputs stored under the contrastive0.5 model directory.
plot_functional_group_accuracies(
    model_path="../models/contrastive0.5/outputs/test_outputs_1_attempts.json",
)