# Import

In [None]:
import os
import numpy as np
import pandas as pd
from operator import itemgetter
import copy
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns

from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn import metrics

from ase.units import Hartree, mol, kcal, kJ

from IPython.core.display import display
from IPython.display import SVG
from rdkit.Chem.Draw import IPythonConsole
# Notebook rendering options for RDKit molecule drawings.
IPythonConsole.ipython_useSVG = True  # Change output to SVG
# Do not annotate the drawn atoms with their atom indices.
IPythonConsole.drawOptions.addAtomIndices = False

In [None]:
# config
# Default figure size and base font size for the Matplotlib figures below.
plt.rcParams["figure.figsize"] = (8, 8)
plt.rcParams.update({'font.size': 22})

# Change Matplotlib font to Helvetica
import matplotlib as mpl
import matplotlib.font_manager as fm
from matplotlib.legend_handler import HandlerTuple

mpl.rcParams['font.family'] = 'Helvetica'
# Look the font up right away; as the last expression of the cell, the
# resolved font path is what the notebook echoes.
# `fontext` is documented to accept 'ttf' or 'afm'; the former value 'ttx'
# looks like a typo for 'ttf'.
fm.findfont("Helvetica", fontext="ttf", rebuild_if_missing=False)

In [None]:
import py3Dmol
def draw3d(
    mols,
    width=400,
    height=400,
    Hs=True,
    confId=-1,
    multipleConfs=False,
    atomlabel=False,
):
    """Build a py3Dmol view containing one or more molecules.

    Args:
        mols: a single item or a list of items. Each item is either an RDKit
            molecule or a ``str`` path to an ``.xyz`` file. String items with
            any other extension are silently skipped.
        width: viewer width passed to ``py3Dmol.view``.
        height: viewer height passed to ``py3Dmol.view``.
        Hs: not used by the function; kept so existing calls keep working.
        confId: conformer id handed to ``Chem.MolToMolBlock`` when
            ``multipleConfs`` is False.
        multipleConfs: when True, every conformer of every molecule is added
            as its own model (the items must then be RDKit molecules).
        atomlabel: when True, label the atoms with their ``index`` property.

    Returns:
        None. Any failure is reported on stdout instead of being raised.
    """
    try:
        viewer = py3Dmol.view(width=width, height=height)
        if not isinstance(mols, list):
            mols = [mols]
        for item in mols:
            if multipleConfs:
                for conf in item.GetConformers():
                    block = Chem.MolToMolBlock(item, confId=conf.GetId())
                    viewer.addModel(block, "sdf")
            elif isinstance(item, str):
                if os.path.splitext(item)[-1] == ".xyz":
                    # Context manager closes the file even if reading fails.
                    with open(item) as xyz_file:
                        viewer.addModel(xyz_file.read(), "xyz")
            else:
                block = Chem.MolToMolBlock(item, confId=confId)
                viewer.addModel(block, "sdf")
        viewer.setStyle({"sphere": {"radius": 0.4}, "stick": {}})
        if atomlabel:
            viewer.addPropertyLabels("index")  # ,{'elem':'H'}
        viewer.zoomTo()
        viewer.update()
        # viewer.show()
    except Exception as exc:
        # Still best-effort, but no longer swallows KeyboardInterrupt/SystemExit
        # and shows the actual cause next to the generic hint.
        print("py3Dmol, RDKit, and IPython are required for this feature.")
        print(f"draw3d failed with: {exc!r}")

In [None]:
def cm_analysis(y_true, y_pred, labels, ymap=None, figsize=(6,6), filename=None, ticklabels=None):
    """
    Plot a confusion matrix as a heatmap with percentage/count annotations.

    Each cell shows the row-wise percentage and the count; diagonal cells
    additionally show the row total. Empty off-diagonal cells are left blank.
    The figure is shown, and saved to disk only when `filename` is given.
    args:
      y_true:     true label of the data, with shape (nsamples,)
      y_pred:     prediction of the data, with shape (nsamples,)
      labels:     string array, name the order of class labels in the confusion matrix.
                  use `clf.classes_` if using scikit-learn models.
                  with shape (nclass,).
      ymap:       dict: any -> string, length == nclass.
                  if not None, map the labels & ys to more understandable strings.
                  Caution: original y_true, y_pred and labels must align.
      figsize:    the size of the figure plotted.
      filename:   filename of figure file to save; nothing is saved when falsy.
      ticklabels: axis tick labels, one per class. When None, the historical
                  alpha / beta / Mix labels are used for three classes and
                  `labels` itself for any other number of classes.
    """
    if ymap is not None:
        y_pred = [ymap[yi] for yi in y_pred]
        y_true = [ymap[yi] for yi in y_true]
        labels = [ymap[yi] for yi in labels]
    cm = metrics.confusion_matrix(y_true, y_pred, labels=labels)
    cm_sum = np.sum(cm, axis=1, keepdims=True)
    # A class without any true sample has a row total of 0; divide by 1 there
    # so the row reads 0.0% instead of NaN (and no runtime warning is raised).
    safe_sum = np.where(cm_sum == 0, 1, cm_sum)
    cm_perc = cm / safe_sum.astype(float) * 100

    nrows, ncols = cm.shape
    annot_rows = []
    for i in range(nrows):
        # Plain int: formatting a 1-element array with %d is deprecated.
        row_total = int(cm_sum[i, 0])
        row_annot = []
        for j in range(ncols):
            c = cm[i, j]
            p = cm_perc[i, j]
            if i == j:
                row_annot.append('%.1f%%\n%d/%d' % (p, c, row_total))
            elif c == 0:
                row_annot.append('')
            else:
                row_annot.append('%.1f%%\n%d' % (p, c))
        annot_rows.append(row_annot)
    # Built from the finished strings so the string width fits the longest
    # annotation; the old int-derived width could truncate labels.
    annot = np.array(annot_rows, dtype=str).reshape(cm.shape)

    if ticklabels is None:
        if len(labels) == 3:
            ticklabels = [r'$\alpha$', r'$\beta$', 'Mix']
        else:
            ticklabels = list(labels)

    cm = pd.DataFrame(cm, index=labels, columns=labels)
    cm.index.name = 'Actual'
    cm.columns.name = 'Predicted'
    fig, ax = plt.subplots(figsize=figsize)
    sns.heatmap(cm, annot=annot, fmt='', linewidth=1, ax=ax, cmap=plt.cm.Blues, cbar=False, xticklabels=ticklabels, yticklabels=ticklabels)
    plt.yticks(rotation=0)
    plt.tight_layout()
    if filename:
        plt.savefig(filename)
    plt.show()

In [None]:
def remove_identical_atoms(rdkit_mol, patt):
    """Return matched atom indices with rank-equivalent duplicates removed.

    All substructure matches of `patt` in `rdkit_mol` are flattened in match
    order. Atoms are ranked with ``Chem.CanonicalRankAtoms(breakTies=False)``,
    so atoms sharing a rank are treated as identical and only the first atom
    seen for each rank is kept.

    Args:
        rdkit_mol: RDKit molecule to search.
        patt: RDKit query molecule (e.g. from ``Chem.MolFromSmarts``).

    Returns:
        list[int]: atom indices, one per distinct canonical rank, in order of
        first appearance. Empty when `patt` does not match (the previous
        implementation raised on ``np.concatenate`` in that case).
    """
    matches = rdkit_mol.GetSubstructMatches(patt)
    if not matches:
        return []

    atom_rank = list(Chem.CanonicalRankAtoms(rdkit_mol, breakTies=False))
    seen_ranks = set()
    unique_atoms = []
    for match in matches:
        for atom_idx in match:
            rank = atom_rank[atom_idx]
            if rank in seen_ranks:
                continue
            seen_ranks.add(rank)
            unique_atoms.append(int(atom_idx))

    return unique_atoms

# Load data

In [None]:
# Load the reaction table used by every cell below. Later cells read, among
# others, the columns 'Reaction ID', 'Catalyst', 'Reagent', 'Reactant(s)',
# 'Product' and 'RXN type'.
# NOTE(security): unpickling can execute arbitrary code - only load a
# results.pkl that you produced yourself or otherwise trust.
df = pd.read_pickle('results.pkl')

In [None]:
# Catalyst names assigned to the "mono" group (all are triphenylphosphine systems).
mono_list = [
    'tetrakis(triphenylphosphine) palladium(0)',
    'bis-triphenylphosphine-palladium(II) bromide',
    'bis-triphenylphosphine-palladium(II) chloride',
    'palladium diacetate and triphenylphosphine',
    'palladium dibromide and triphenylphosphine',
    'palladium dichloride and triphenylphosphine',
    'bis-(dibenzylideneacetone)-palladium(0) and triphenylphosphine',
    'tris-(dibenzylideneacetone)dipalladium(0) and triphenylphosphine',
    'tris-(dibenzylideneacetone)dipalladium(0) chloroform complex and triphenylphosphine',
    'dichloro bis(acetonitrile) palladium(II) and triphenylphosphine',
]
# Catalyst names assigned to the "bi" group (dppe / dppp systems).
bi_list = [
    'bis[1,2-bis(diphenylphosphino)ethane]palladium(0)',
    '(1,2-bis(diphenylphosphino)ethane)palladium(II) chloride',
    'palladium diacetate and dppe',
    'palladium dichloride and dppe',
    'bis-(dibenzylideneacetone)-palladium(0) and dppe',
    'tris-(dibenzylideneacetone)dipalladium(0) and dppe',
    'tris-(dibenzylideneacetone)dipalladium(0) chloroform complex and dppe',
    'dichloro bis(acetonitrile) palladium(II) and dppe',
    'palladium diacetate and dppp',
    'palladium dichloride and dppp',
    'bis-(dibenzylideneacetone)-palladium(0) and dppp',
    'tris-(dibenzylideneacetone)dipalladium(0) and dppp',
    'tris-(dibenzylideneacetone)dipalladium(0) chloroform complex and dppp',
    'dichloro bis(acetonitrile) palladium(II) and dppp',
]

# Split the reactions by the catalyst group they were run with.
df_mono = df.loc[df['Catalyst'].isin(mono_list)]
df_bi = df.loc[df['Catalyst'].isin(bi_list)]

print('Mono RXNS', len(df_mono))
print('Bi RXNS', len(df_bi))

In [None]:
# The SMARTS queries do not depend on the row, so compile them once instead
# of twice per reaction.
_triflate_group_patt = Chem.MolFromSmarts("[$(OS(=O)(=O)C(F)(F)F)]")
_iodide_patt = Chem.MolFromSmarts("[#53][#6;$([#6]=[*]),$(c:*)]")

# Collect reactions whose second reactant carries a triflate group but no
# iodine on a vinylic/aromatic carbon.
triflate_rxns = []
for idx, row in df.iterrows():

    reactants = row['Reactant(s)']
    if len(reactants) < 2:
        # No second reactant recorded: nothing to classify.
        continue

    rxn_id = row['Reaction ID']
    halogen_smi = reactants[1]
    halogen_mol = Chem.AddHs(Chem.MolFromSmiles(halogen_smi))

    if halogen_mol.HasSubstructMatch(_triflate_group_patt) and not halogen_mol.HasSubstructMatch(_iodide_patt):
        triflate_rxns.append(rxn_id)

# "Cationic conditions" within the bi-catalyst subset: a triflate reaction or
# a reagent entry mentioning silver or thallium. `na=False` makes rows with a
# missing 'Reagent' count as "no match" explicitly.
_is_triflate_rxn = df_bi['Reaction ID'].isin(triflate_rxns)
_mentions_silver = df_bi['Reagent'].str.contains('silver', na=False)
_mentions_thallium = df_bi['Reagent'].str.contains('thallium', na=False)
df_cationic = df_bi[_is_triflate_rxn | _mentions_silver | _mentions_thallium]

print('Cationic conditions RXNS', df_cationic.shape[0])

# Analyse data

In [None]:
"""Heck reactions prefer I >> OTf > Br >> Cl"""
# Leaving-group queries, in the priority order I, OTf, Br, Cl. Every query
# requires the leaving group to sit on a carbon that is either double-bonded
# to another atom or aromatic; each match spans two atoms.
# Iodine
pattern_iodine = Chem.MolFromSmarts("[#53][#6;$([#6]=[*]),$(c:*)]")

# Triflate
pattern_triflate = Chem.MolFromSmarts("[$(OS(=O)(=O)C(F)(F)F)][#6;$([#6]=[*]),$(c:*)]")

# Bromine
pattern_bromine = Chem.MolFromSmarts("[#35][#6;$([#6]=[*]),$(c:*)]")

# Chlorine
pattern_chlorine = Chem.MolFromSmarts("[#17][#6;$([#6]=[*]),$(c:*)]")

# Search patterns for alkene
# alkene_patt = Chem.MolFromSmarts('[*:3][#6:2]([*:5])=&!@[#6:1]([#1,#6:6])[*:4]')
# Active query: a non-ring C=C whose first carbon carries one non-hydrogen
# neighbour and one H, and whose second carbon carries two H. It matches six
# atoms and needs explicit hydrogens (the reactants are built with AddHs).
alkene_patt = Chem.MolFromSmarts('[*;!#1:3][#6:1]([#1:4])=&!@[#6:2]([#1:5])[#1:6]')


# Running totals / reaction-ID buckets indexed by leaving group [I, OTf, Br, Cl].
rxn_type_all = np.zeros(4)
rxn_reacting_halide = [[] for x in range(4)]

# Reaction-ID buckets indexed by the running number of the 'RXN type' group.
# NOTE(review): three slots each, so more than three distinct 'RXN type'
# values would raise IndexError - confirm against the data.
rxn_disobeying_rule = [[] for x in range(3)]
rxn_site_selec_halogen = [[] for x in range(3)]
rxn_site_selec_alkene = [[] for x in range(3)]
rxn_not_identified = [[] for x in range(3)]
rxn_site_multiple_halogen = [[] for x in range(3)]

rxn_number = -1
for k, reaction in df.groupby(['RXN type']):
    rxn_number += 1
    for idx, row in reaction.iterrows():

        rxn_id = row['Reaction ID']
        prod_mols = [Chem.MolFromSmiles(prod_smi) for prod_smi in row['Product']]
        # Only the first listed product is used for the leaving-group bookkeeping.
        prod_mol = prod_mols[0]

        reactants = row['Reactant(s)']
        alkene_smi = reactants[0]
        alkene_mol = Chem.AddHs(Chem.MolFromSmiles(alkene_smi))
        if len(reactants) > 1:
            halogen_smi = reactants[1]
            halogen_mol = Chem.AddHs(Chem.MolFromSmiles(halogen_smi))
        else:
            # Rows without a second reactant are skipped entirely.
            halogen_mol = None
            continue

        # Find reactions with more than one of the same halogen
        # Leaving-group counts on the alkene partner ...
        reac_alkene_matches = np.zeros(4)
        reac_alkene_matches[0] += len(alkene_mol.GetSubstructMatches(pattern_iodine))
        reac_alkene_matches[1] += len(alkene_mol.GetSubstructMatches(pattern_triflate))
        reac_alkene_matches[2] += len(alkene_mol.GetSubstructMatches(pattern_bromine))
        reac_alkene_matches[3] += len(alkene_mol.GetSubstructMatches(pattern_chlorine))
        # ... copied and topped up with the second reactant's counts, so this
        # temporarily holds the totals over both reactants.
        reac_halogen_matches = copy.deepcopy(reac_alkene_matches)

        reac_halogen_matches[0] += len(halogen_mol.GetSubstructMatches(pattern_iodine))
        reac_halogen_matches[1] += len(halogen_mol.GetSubstructMatches(pattern_triflate))
        reac_halogen_matches[2] += len(halogen_mol.GetSubstructMatches(pattern_bromine))
        reac_halogen_matches[3] += len(halogen_mol.GetSubstructMatches(pattern_chlorine))

        # Leaving-group counts remaining in the product.
        prod_halogen_matches = np.zeros(4)
        prod_halogen_matches[0] += len(prod_mol.GetSubstructMatches(pattern_iodine))
        prod_halogen_matches[1] += len(prod_mol.GetSubstructMatches(pattern_triflate))
        prod_halogen_matches[2] += len(prod_mol.GetSubstructMatches(pattern_bromine))
        prod_halogen_matches[3] += len(prod_mol.GetSubstructMatches(pattern_chlorine))

        # Per leaving group: how many were consumed going from reactants to product.
        rxn_type = reac_halogen_matches - prod_halogen_matches
        if sum(rxn_type) > 1:
            # More than one leaving group disappeared: cannot attribute the
            # reaction to a single leaving group.
            rxn_not_identified[rxn_number].append(rxn_id)
            print(rxn_id, rxn_type)
            # rxn_smiles = df[df['Reaction ID'] == rxn_id]['Reaction'].values[0]
            # display(Chem.rdChemReactions.ReactionFromSmarts(rxn_smiles, useSmiles=True))
            # display(Chem.Draw.MolsToGridImage([alkene_mol, halogen_mol, prod_mol]))
        else:
            rxn_type_all += rxn_type

        # Back to the counts of the second reactant only.
        reac_halogen_matches = reac_halogen_matches - reac_alkene_matches

        # Compare the highest-priority leaving group present on the second
        # reactant (lowest index) with the highest-priority one that reacted.
        # NOTE(review): min() raises ValueError if either array has no entry
        # >= 1 (e.g. nothing was consumed) - confirm this cannot occur.
        if min(np.where(reac_halogen_matches >= 1)[0]) != min(np.where(rxn_type >= 1)[0]):
            # print('Disobeying Heck reaction rule (I >> OTf > Br >> Cl)')
            rxn_disobeying_rule[rxn_number].append(rxn_id)

        if reac_halogen_matches[min(np.where(rxn_type >= 1)[0])] >= 2:
            # print('More than one of the same reacting halogen atom')
            patts = [pattern_iodine, pattern_triflate, pattern_bromine, pattern_chlorine]
            patt = patts[min(np.where(rxn_type >= 1)[0])]
            # A single match contributes two atoms, so three or more
            # rank-distinct atoms means at least two non-equivalent sites.
            if len(remove_identical_atoms(halogen_mol, patt)) >= 3:
                rxn_site_selec_halogen[rxn_number].append(rxn_id)

        if sum(reac_halogen_matches) >= 2:
            # print('More than one halogen atom')
            patts = [pattern_iodine, pattern_triflate, pattern_bromine, pattern_chlorine]
            # Keep only the queries that actually match the second reactant.
            patts = [patts[i] for i in range(len(patts)) if i in np.where(reac_halogen_matches >= 1)[0]]

            number_of_unique_halogens = 0
            for i, patt in enumerate(patts):
                number_of_unique_halogens += len(remove_identical_atoms(halogen_mol, patt))

            if number_of_unique_halogens >= 3:
                rxn_site_multiple_halogen[rxn_number].append(rxn_id)
                # Remember which leaving group reacted; a later cell looks
                # the reaction ID up in this list.
                rxn_reacting_halide[min(np.where(rxn_type >= 1)[0])].append(rxn_id)


        # Find reactions with more than one reacting double bond
        if len(alkene_mol.GetSubstructMatches(alkene_patt)) >= 2:
            # print('More than one reacting double bond')
            # One alkene match spans six atoms; more than six rank-distinct
            # atoms means the matched double bonds are not all equivalent.
            if len(remove_identical_atoms(alkene_mol, alkene_patt)) > 6:
                rxn_site_selec_alkene[rxn_number].append(rxn_id)


# Summary of the classification above.
print('Reaction types [I, OTf, Br, Cl]:', rxn_type_all)
print('Total number of identified reactions:', int(sum(rxn_type_all)), f'(should be {df.shape[0]})')
print('Total number of NOT identified reactions:', df.shape[0] - int(sum(rxn_type_all)) )
print('Number of reactions disobeying Heck reaction rule:', len(np.concatenate(rxn_disobeying_rule)))
print('Number of reactions with more than one of the same halogen:', len(np.concatenate(rxn_site_selec_halogen)))
print('Number of reactions with more than one reacting double bond:', len(np.concatenate(rxn_site_selec_alkene)))
print('Number of reactions with more than one halogen:', len(np.concatenate(rxn_site_multiple_halogen)))

# Reactions with an additional site-selectivity question (alkene or halide
# side) are excluded from the regioselectivity evaluation further down.
rxns_to_leave_out = np.concatenate([np.concatenate(rxn_site_selec_alkene), np.concatenate(rxn_site_selec_halogen)])

In [None]:
def determine_pred(energies, names, cutoff):
    """Classify the predicted outcome from the candidates' relative energies.

    Every candidate whose energy lies within `cutoff` of the lowest energy is
    treated as predicted. Each predicted name is reduced to an identifier: the
    third ``_``-separated token when it is one of the special ids below,
    otherwise the second token. The identifiers are then looked up in the
    alpha and beta id sets.

    Args:
        energies: sequence of energies, aligned with `names`.
        names: sequence of ``_``-separated candidate names.
        cutoff: energy window above the minimum, in the units of `energies`.

    Returns:
        tuple: ``(pred_prod_names, pred_prod)`` - the list of identifiers
        within the window and one of ``'alpha'``, ``'beta'``, ``'mixture'``.

    Raises:
        ValueError: if none of the selected identifiers is a known alpha or
            beta id (previously this surfaced as an UnboundLocalError), or if
            `energies` is empty.
    """
    special_ids = {'31', '29', '37', '35', '17', '09'}
    alpha_names = {'4', '2', '31', '37', '5', '17'}  # RXN names for alpha
    beta_names = {'3', '1', '29', '35', '6', '09'}  # RXN names for beta

    e_vals = np.array(energies) - min(energies)
    selected = [name for name, e_val in zip(names, e_vals) if e_val <= cutoff]

    pred_prod_names = []
    for name in selected:
        parts = name.split('_')
        # Two-token names have no third field; fall back to the second token
        # instead of raising IndexError.
        if len(parts) > 2 and parts[2] in special_ids:
            pred_prod_names.append(parts[2])
        else:
            pred_prod_names.append(parts[1])

    predicted = set(pred_prod_names)
    has_alpha = bool(predicted & alpha_names)
    has_beta = bool(predicted & beta_names)
    if has_alpha and has_beta:
        pred_prod = 'mixture'
    elif has_alpha:
        pred_prod = 'alpha'
    elif has_beta:
        pred_prod = 'beta'
    else:
        raise ValueError(
            f'no alpha/beta product among the predicted candidates: '
            f'relative energies {e_vals}, names {names}'
        )

    return pred_prod_names, pred_prod


### SELECT CUTOFF (0 kcal/mol: TOP-1) ###
# Energy window handed to determine_pred: every candidate within `cutoff` of
# the lowest energy counts as predicted.
# NOTE(review): the conversion note below suggests the energies are in kJ/mol - confirm.
cutoff = 12.6 #1 kcal/mol = 4.2 kJ/mol, 3 kcal/mol = 12.6 kJ/mol

# RXNs with/without errors
rxns_with_errors = []
rxns_with_multiple_selectivity = []
rxns_without_errors = []

# Actual product
# Filled in lock-step with the four pred_* lists below (same index = same reaction).
actual_product = []

# Neutral lists
pred_xtb_neu = []
pred_all_neu = []

correct_neu_xtb = []
correct_neu_all = []
failed_neu_xtb = []
failed_neu_all = []

# Cationic lists
pred_xtb_cat = []
pred_all_cat = []

correct_cat_xtb = []
correct_cat_all = []
failed_cat_xtb = []
failed_cat_all = []

for idx, row in df.iterrows():

    rxn_id = row['Reaction ID']
    # Reactions flagged by the site-selectivity analysis are not evaluated.
    if rxn_id in rxns_to_leave_out:
        rxns_with_errors.append(rxn_id)
        continue

    # Last '_'-separated token of 'RXN type' is used as the true label.
    rxn_type = row['RXN type'].split('_')[-1]
    # Canonical SMILES without stereo information, so they can be compared as
    # strings with the candidate products below.
    prod_smis = [Chem.MolToSmiles(Chem.MolFromSmiles(prod_smi), isomericSmiles=False) for prod_smi in row['Product']]
    reactants = row['Reactant(s)']
    alkene_smi = reactants[0]
    alkene_mol = Chem.AddHs(Chem.MolFromSmiles(alkene_smi))
    if len(reactants) > 1:
        halogen_smi = reactants[1]
        halogen_mol = Chem.AddHs(Chem.MolFromSmiles(halogen_smi))
    else:
        # Rows without a second reactant are skipped (and not counted as errors).
        halogen_mol = None
        continue

    # Candidate names, product SMILES and the two energy lists, neutral set.
    names_neu = row['xtb_one_all_names_neu']
    products_neu = [Chem.MolToSmiles(m, isomericSmiles=False) for m in row['xtb_one_all_products_neu']]
    xtb_energies_neu = row['xtb_energies_neu']
    all_energies_neu = row['all_energies_neu']

    # Same for the cationic set.
    names_cat = row['xtb_one_all_names_cat']
    products_cat = [Chem.MolToSmiles(m, isomericSmiles=False) for m in row['xtb_one_all_products_cat']]
    xtb_energies_cat = row['xtb_energies_cat']
    all_energies_cat = row['all_energies_cat']

    # More candidates than two (cationic) / four (neutral): keep only the
    # candidates in which the reacting leaving group was actually consumed.
    if len(names_cat) > 2 or len(names_neu) > 4:
        rxns_with_multiple_selectivity.append(rxn_id)

        patts = np.array([pattern_iodine, pattern_triflate, pattern_bromine, pattern_chlorine])
        # Query of the leaving group recorded for this reaction in rxn_reacting_halide.
        # NOTE(review): raises IndexError when rxn_id is in none of the four
        # rxn_reacting_halide lists - confirm every such reaction was recorded.
        patt = patts[[i for i in range(4) if np.any(np.array(rxn_reacting_halide[i]) == rxn_id)]][0]

        names_neu = []
        products_neu = []
        xtb_energies_neu = []
        all_energies_neu = []
        for i, pred_prod_mol in enumerate(row['xtb_one_all_products_neu']):
            # Keep the candidate if it has fewer matches of `patt` than the two reactants combined.
            if len(halogen_mol.GetSubstructMatches(patt)) + len(alkene_mol.GetSubstructMatches(patt)) - len(pred_prod_mol.GetSubstructMatches(patt)) > 0:
                names_neu.append(row['xtb_one_all_names_neu'][i])
                products_neu.append(Chem.MolToSmiles(pred_prod_mol, isomericSmiles=False))
                xtb_energies_neu.append(row['xtb_energies_neu'][i])
                all_energies_neu.append(row['all_energies_neu'][i])

        names_cat = []
        products_cat = []
        xtb_energies_cat = []
        all_energies_cat = []
        for i, pred_prod_mol in enumerate(row['xtb_one_all_products_cat']):
            if len(halogen_mol.GetSubstructMatches(patt)) + len(alkene_mol.GetSubstructMatches(patt)) - len(pred_prod_mol.GetSubstructMatches(patt)) > 0:
                names_cat.append(row['xtb_one_all_names_cat'][i])
                products_cat.append(Chem.MolToSmiles(pred_prod_mol, isomericSmiles=False))
                xtb_energies_cat.append(row['xtb_energies_cat'][i])
                all_energies_cat.append(row['all_energies_cat'][i])


    # Sanity checks on all four energy lists before scoring:
    #   * `x == x` is False only when the cell holds a scalar NaN (missing entry),
    #   * 60000.0 is excluded - presumably a sentinel for a failed calculation (TODO confirm),
    #   * no NaN inside the lists.
    if (all_energies_neu == all_energies_neu) and (xtb_energies_neu == xtb_energies_neu) and \
        (60000.0 not in all_energies_neu) and (60000.0 not in xtb_energies_neu) and \
        (not np.any(np.isnan(all_energies_neu))) and (not np.any(np.isnan(xtb_energies_neu))) and \
        (all_energies_cat == all_energies_cat) and (xtb_energies_cat == xtb_energies_cat) and \
        (60000.0 not in all_energies_cat) and (60000.0 not in xtb_energies_cat) and \
        (not np.any(np.isnan(all_energies_cat))) and (not np.any(np.isnan(xtb_energies_cat))):

        rxns_without_errors.append(rxn_id)
        actual_product.append(rxn_type)

        # Class-level prediction (alpha / beta / mixture) per energy source and pathway.
        pred_prod_names_xtb_neu, pred_prod_xtb_neu = determine_pred(xtb_energies_neu, names_neu, cutoff)
        pred_xtb_neu.append(pred_prod_xtb_neu)

        pred_prod_names_all_neu, pred_prod_all_neu = determine_pred(all_energies_neu, names_neu, cutoff)
        pred_all_neu.append(pred_prod_all_neu)

        pred_prod_names_xtb_cat, pred_prod_xtb_cat = determine_pred(xtb_energies_cat, names_cat, cutoff)
        pred_xtb_cat.append(pred_prod_xtb_cat)

        pred_prod_names_all_cat, pred_prod_all_cat = determine_pred(all_energies_cat, names_cat, cutoff)
        pred_all_cat.append(pred_prod_all_cat)


        # Structure-level check: "correct" means every recorded product SMILES
        # is among the candidates lying within `cutoff` of the lowest energy.
        if len(set(prod_smis) - set([products_neu[i] for i in range(len(products_neu)) if xtb_energies_neu[i]-min(xtb_energies_neu) <= cutoff])) == 0:
            correct_neu_xtb.append(rxn_id)
        else:
            failed_neu_xtb.append(rxn_id)

        if len(set(prod_smis) - set([products_neu[i] for i in range(len(products_neu)) if all_energies_neu[i]-min(all_energies_neu) <= cutoff])) == 0:
            correct_neu_all.append(rxn_id)
        else:
            failed_neu_all.append(rxn_id)


        if len(set(prod_smis) - set([products_cat[i] for i in range(len(products_cat)) if xtb_energies_cat[i]-min(xtb_energies_cat) <= cutoff])) == 0:
            correct_cat_xtb.append(rxn_id)
        else:
            failed_cat_xtb.append(rxn_id)

        if len(set(prod_smis) - set([products_cat[i] for i in range(len(products_cat)) if all_energies_cat[i]-min(all_energies_cat) <= cutoff])) == 0:
            correct_cat_all.append(rxn_id)
        else:
            failed_cat_all.append(rxn_id)

    else:
        # At least one energy list is missing, contains the sentinel or a NaN.
        rxns_with_errors.append(rxn_id)

In [None]:
# Number of evaluated reactions, then number of reactions left out.
print(len(actual_product), len(rxns_with_errors), sep='\n')

In [None]:
# Confusion matrix: recorded outcome vs. the `pred_all_neu` predictions.
cm_analysis(
    y_true=actual_product,
    y_pred=pred_all_neu,
    labels=['alpha', 'beta', 'mixture'],
    ymap=None,
    figsize=(6, 6),
    filename=None,
)

In [None]:
# Confusion matrix: recorded outcome vs. the `pred_all_cat` predictions.
cm_analysis(
    y_true=actual_product,
    y_pred=pred_all_cat,
    labels=['alpha', 'beta', 'mixture'],
    ymap=None,
    figsize=(6, 6),
    filename=None,
)

In [None]:
# Show the rows of `df` that belong to the reactions that were evaluated.
_evaluated_rows = df['Reaction ID'].isin(rxns_without_errors)
df[_evaluated_rows]

# More analysis

In [None]:
# For each label list: the (class, count) pairs followed by the class shares in percent.
for _title, _outcomes in (
    ('Actual Product:', actual_product),
    ('xtb -  neutral:', pred_xtb_neu),
    ('dft -  neutral:', pred_all_neu),
    ('xtb - cationic:', pred_xtb_cat),
    ('dft - cationic:', pred_all_cat),
):
    _classes, _counts = np.unique(_outcomes, return_counts=True)
    print(_title, list(zip(_classes, _counts)), _counts / len(_outcomes) * 100)

In [None]:
def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or NaN when ``whole`` is 0."""
    return part / whole * 100 if whole else float('nan')


def _clear_winner_counts(predictions):
    """Count 'alpha'/'beta' predictions and how many of them match `actual_product`."""
    n_clear = sum(pred in ('alpha', 'beta') for pred in predictions)
    n_correct = sum(
        pred in ('alpha', 'beta') and pred == actual
        for pred, actual in zip(predictions, actual_product)
    )
    return n_clear, n_correct


# Accuracy restricted to the reactions where a method predicts a single class
# (alpha or beta). A method that never predicts a clear winner now reports
# "nan (0 / 0)" instead of aborting the cell with ZeroDivisionError.
pred_xtb_neu_clear_winner, pred_xtb_neu_clear_winner_correct = _clear_winner_counts(pred_xtb_neu)
print('xtb - neutral pathway correct when it predict a clear winner:', f'{_percent(pred_xtb_neu_clear_winner_correct, pred_xtb_neu_clear_winner)} ({pred_xtb_neu_clear_winner_correct} / {pred_xtb_neu_clear_winner})')

pred_all_neu_clear_winner, pred_all_neu_clear_winner_correct = _clear_winner_counts(pred_all_neu)
print('dft - neutral pathway correct when it predict a clear winner:', f'{_percent(pred_all_neu_clear_winner_correct, pred_all_neu_clear_winner)} ({pred_all_neu_clear_winner_correct} / {pred_all_neu_clear_winner})')

pred_xtb_cat_clear_winner, pred_xtb_cat_clear_winner_correct = _clear_winner_counts(pred_xtb_cat)
print('xtb - cationic pathway correct when it predict a clear winner:', f'{_percent(pred_xtb_cat_clear_winner_correct, pred_xtb_cat_clear_winner)} ({pred_xtb_cat_clear_winner_correct} / {pred_xtb_cat_clear_winner})')

pred_all_cat_clear_winner, pred_all_cat_clear_winner_correct = _clear_winner_counts(pred_all_cat)
print('dft - cationic pathway correct when it predict a clear winner:', f'{_percent(pred_all_cat_clear_winner_correct, pred_all_cat_clear_winner)} ({pred_all_cat_clear_winner_correct} / {pred_all_cat_clear_winner})')

In [None]:
def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or NaN when ``whole`` is 0."""
    return part / whole * 100 if whole else float('nan')


def _mixture_counts(predictions):
    """Count 'mixture' predictions and how many of them match `actual_product`."""
    n_mixture = sum(pred == 'mixture' for pred in predictions)
    n_correct = sum(
        pred == 'mixture' and actual == 'mixture'
        for pred, actual in zip(predictions, actual_product)
    )
    return n_mixture, n_correct


# Accuracy restricted to the reactions where a method predicts a mixture.
# With a small cutoff (e.g. 0) a method may never predict 'mixture'; that case
# now reports "nan (0 / 0)" instead of raising ZeroDivisionError.
pred_xtb_neu_mixture, pred_xtb_neu_mixture_correct = _mixture_counts(pred_xtb_neu)
print('xtb - neutral pathway correct when it predict a mixture:', f'{_percent(pred_xtb_neu_mixture_correct, pred_xtb_neu_mixture)} ({pred_xtb_neu_mixture_correct} / {pred_xtb_neu_mixture})')

pred_all_neu_mixture, pred_all_neu_mixture_correct = _mixture_counts(pred_all_neu)
print('dft - neutral pathway correct when it predict a mixture:', f'{_percent(pred_all_neu_mixture_correct, pred_all_neu_mixture)} ({pred_all_neu_mixture_correct} / {pred_all_neu_mixture})')

pred_xtb_cat_mixture, pred_xtb_cat_mixture_correct = _mixture_counts(pred_xtb_cat)
print('xtb - cationic pathway correct when it predict a mixture:', f'{_percent(pred_xtb_cat_mixture_correct, pred_xtb_cat_mixture)} ({pred_xtb_cat_mixture_correct} / {pred_xtb_cat_mixture})')

pred_all_cat_mixture, pred_all_cat_mixture_correct = _mixture_counts(pred_all_cat)
print('dft - cationic pathway correct when it predict a mixture:', f'{_percent(pred_all_cat_mixture_correct, pred_all_cat_mixture)} ({pred_all_cat_mixture_correct} / {pred_all_cat_mixture})')

In [None]:
def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or NaN when ``whole`` is 0."""
    return part / whole * 100 if whole else float('nan')


# How often the neutral and the cationic pathway both predict 'alpha', and how
# often that joint prediction matches the recorded outcome. Zero denominators
# are reported as nan instead of raising ZeroDivisionError.
_n_rxns = len(actual_product)
for _agree_msg, _success_msg, _neu_preds, _cat_preds in (
    ('xtb - both pathways agrees on alpha:', 'xtb - success rate when both pathways agrees on alpha:', pred_xtb_neu, pred_xtb_cat),
    ('dft - both pathways agrees on alpha:', 'dft - success rate when both pathways agrees on alpha:', pred_all_neu, pred_all_cat),
):
    _n_agree = sum(
        neu == 'alpha' and cat == 'alpha'
        for neu, cat in zip(_neu_preds, _cat_preds)
    )
    _n_hit = sum(
        neu == 'alpha' and cat == 'alpha' and actual == 'alpha'
        for neu, cat, actual in zip(_neu_preds, _cat_preds, actual_product)
    )
    print(_agree_msg, _percent(_n_agree, _n_rxns), f"({_n_agree} / {_n_rxns})")
    print(_success_msg, _percent(_n_hit, _n_agree), f"({_n_hit} / {_n_agree})")

In [None]:
def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or NaN when ``whole`` is 0."""
    return part / whole * 100 if whole else float('nan')


# How often the neutral and the cationic pathway both predict 'beta', and how
# often that joint prediction matches the recorded outcome. Zero denominators
# are reported as nan instead of raising ZeroDivisionError.
_n_rxns = len(actual_product)
for _agree_msg, _success_msg, _neu_preds, _cat_preds in (
    ('xtb - both pathways agrees on beta:', 'xtb - success rate when both pathways agrees on beta:', pred_xtb_neu, pred_xtb_cat),
    ('dft - both pathways agrees on beta:', 'dft - success rate when both pathways agrees on beta:', pred_all_neu, pred_all_cat),
):
    _n_agree = sum(
        neu == 'beta' and cat == 'beta'
        for neu, cat in zip(_neu_preds, _cat_preds)
    )
    _n_hit = sum(
        neu == 'beta' and cat == 'beta' and actual == 'beta'
        for neu, cat, actual in zip(_neu_preds, _cat_preds, actual_product)
    )
    print(_agree_msg, _percent(_n_agree, _n_rxns), f"({_n_agree} / {_n_rxns})")
    print(_success_msg, _percent(_n_hit, _n_agree), f"({_n_hit} / {_n_agree})")

In [None]:
def _percent(part, whole):
    """Return ``part / whole`` as a percentage, or NaN when ``whole`` is 0."""
    return part / whole * 100 if whole else float('nan')


# How often the neutral and the cationic pathway both predict 'mixture', and
# how often that joint prediction matches the recorded outcome. The pathways
# may never agree on a mixture (e.g. with a cutoff of 0); that case is
# reported as nan instead of raising ZeroDivisionError.
_n_rxns = len(actual_product)
for _agree_msg, _success_msg, _neu_preds, _cat_preds in (
    ('xtb - both pathways agrees on mixture:', 'xtb - success rate when both pathways agrees on mixture:', pred_xtb_neu, pred_xtb_cat),
    ('dft - both pathways agrees on mixture:', 'dft - success rate when both pathways agrees on mixture:', pred_all_neu, pred_all_cat),
):
    _n_agree = sum(
        neu == 'mixture' and cat == 'mixture'
        for neu, cat in zip(_neu_preds, _cat_preds)
    )
    _n_hit = sum(
        neu == 'mixture' and cat == 'mixture' and actual == 'mixture'
        for neu, cat, actual in zip(_neu_preds, _cat_preds, actual_product)
    )
    print(_agree_msg, _percent(_n_agree, _n_rxns), f"({_n_agree} / {_n_rxns})")
    print(_success_msg, _percent(_n_hit, _n_agree), f"({_n_hit} / {_n_agree})")

In [None]:
# Index lists (positions in `actual_product`) describing how the `pred_all_neu`
# and `pred_all_cat` predictions relate to each other and to the recorded outcome.
pathways_agree = []
pathways_agree_correct = []
pathways_disagree = []
pathways_disagree_neu_correct = []
pathways_disagree_cat_correct = []
for i, _observed in enumerate(actual_product):
    _neutral = pred_all_neu[i]
    _cationic = pred_all_cat[i]

    if _neutral == _cationic:
        # Both pathways give the same class.
        pathways_agree.append(i)
        if _neutral == _observed:
            pathways_agree_correct.append(i)
        continue

    # The pathways give different classes; at most one of them can be right.
    pathways_disagree.append(i)
    if _neutral == _observed:
        pathways_disagree_neu_correct.append(i)
    elif _cationic == _observed:
        pathways_disagree_cat_correct.append(i)

print('dft - neutral and cationic pathways agree:', len(pathways_agree))
print('dft - neutral and cationic pathways agree and the prediction is correct:', len(pathways_agree_correct))
print('')
print('dft - neutral and cationic pathways disagree:', len(pathways_disagree))
print('dft - neutral and cationic pathways disagree and neutral pathway is correct:', len(pathways_disagree_neu_correct))
print('dft - neutral and cationic pathways disagree and cationic pathway is correct:', len(pathways_disagree_cat_correct))