In [None]:
import json
import os

import umap # 0.5.3
import pandas as pd # 1.5.0
import numpy as np # 1.23.5

import matplotlib.pyplot as plt # 3.6.2
import seaborn as sns # 0.12.0

import plotly # 5.10.0
import plotly.express as px
import plotly.graph_objects as go

from rdkit import Chem # 2023.03.3
from rdkit.Chem import Draw
from scipy import stats # 1.9.1
from sklearn.metrics import mean_absolute_error # 1.2.2
from prettytable import *

# 1) Load data

In [None]:
# Dictionary to format and order central atom classes.
#
# The values are the display labels. Insertion order is significant: later cells
# iterate `formatted_ca.values()` to order plot categories, annotations and legends.
# Within every group the entries run from the lightest to the heaviest element,
# so the pentavalent pnictogens are listed P, As, Sb, Bi like the trivalent ones.

formatted_ca = {
    "B_3": "B(III)",
    "Al_3": "Al(III)",
    "Ga_3": "Ga(III)",
    "In_3": "In(III)",
    "Si_2": "Si(II)",
    "Ge_2": "Ge(II)",
    "Sn_2": "Sn(II)",
    "Pb_2": "Pb(II)",
    "Si_4": "Si(IV)",
    "Ge_4": "Ge(IV)",
    "Sn_4": "Sn(IV)",
    "Pb_4": "Pb(IV)",
    "P_3": "P(III)",
    "As_3": "As(III)",
    "Sb_3": "Sb(III)",
    "Bi_3": "Bi(III)",
    "P_5": "P(V)",
    "As_5": "As(V)",
    "Sb_5": "Sb(V)",
    "Bi_5": "Bi(V)",
    "Te_4": "Te(IV)"
}

In [None]:
# Load predicted FIA values.

# The file is read from the "data" folder located next to the current working directory.
parent_dir = os.path.split(os.getcwd())[0]
predictions_path = os.path.join(parent_dir, "data", "FIA49k_predictions.csv.gz")
fia_predictions = pd.read_csv(predictions_path)
fia_predictions.shape

In [None]:
# Add absolute errors to the data frame.

# One "abs_error_<target>" column per FIA type: |reference - prediction|.
for target in ("fia_gas-DSDBLYP", "fia_solv-DSDBLYP"):
    residual = fia_predictions[target] - fia_predictions[f"pred_{target}"]
    fia_predictions[f"abs_error_{target}"] = residual.abs()

fia_predictions.shape

In [None]:
# Load general data file.

general_data_path = os.path.join(os.path.split(os.getcwd())[0], "data", "FIA49k.csv.gz")
general_data = pd.read_csv(general_data_path)

# Keep only the compounds for which a prediction row exists.
has_prediction = general_data["Compound"].isin(fia_predictions["Compound"])
general_data = general_data[has_prediction]
general_data.shape

In [None]:
# Formatting.
# Attach the class labels and the Lewis acid SMILES from `general_data` to the predictions.

annotation_columns = ["Compound", "ca_class", "denticity_class", "la_smiles"]
fia_predictions = fia_predictions.merge(general_data[annotation_columns], on="Compound")

fia_predictions.shape

# 2) Analyze FIA predictions

In [None]:
# These are some functions to analyze the data.

def get_r2(x, y):
    """
    Squared Pearson correlation coefficient of `x` and `y`, rounded to 4 decimals.

    Fewer than two samples cannot be correlated; 0 is returned in that case.
    """
    if len(x) < 2:
        return 0
    pearson_r = stats.pearsonr(x, y)[0]
    return round(pearson_r**2, 4)


def get_mae(x, y):
    """
    Get the mean absolute error.

    Arguments
    ---------
    x: array-like
        Reference values.
    y: array-like
        Predicted values, same length as `x`.

    Returns
    -------
    The mean of |x - y| rounded to 4 decimals. A single sample is a valid input
    (its MAE is its absolute error). Only empty input returns the placeholder 0,
    because the mean of no values is undefined.
    """
    if len(x) < 1:
        return 0
    # Subtract plain arrays so values are paired by position, not by pandas index label.
    abs_errors = np.abs(np.asarray(x, dtype=float) - np.asarray(y, dtype=float))
    return round(float(np.mean(abs_errors)), 4)


def get_spearman_rank(x, y):
    """
    Squared Spearman rank correlation coefficient of `x` and `y`, rounded to 4 decimals.

    Fewer than two samples cannot be ranked against each other; 0 is returned in that case.
    """
    if len(x) < 2:
        return 0
    spearman_rho = stats.spearmanr(x, y)[0]
    return round(spearman_rho**2, 4)

def make_plot(df, fia_type, set_assignment, hover, hue=None, add_line=False):
    """
    This function makes an interactive parity plot between true and predicted FIA values.

    Arguments
    ---------
    df: pd.DataFrame
        Data to be plotted.
    fia_type: str
        FIA type to be plotted, must be either "fia_gas-DSDBLYP" or "fia_solv-DSDBLYP".
        The predictions are read from the column f"pred_{fia_type}".
    set_assignment: str
        Only rows whose "set_assignment" column equals this value are plotted.
    hover: iterable of str
        Additional columns shown in the hover tooltip ("Compound" is always shown).
    hue: str
        This allows to color-code the markers in the scatter plot.
    add_line: bool
        Whether or not to add a linear regression line to each hue class.
        If a hue is given as well, the r2 of every regression line is printed.

    """
    df = df.loc[df["set_assignment"] == set_assignment]

    hover_data = ["Compound", *hover]

    trendline = "ols" if add_line else None
    trendline_color_override = "red" if add_line else None

    fig = px.scatter(
        df,
        x = fia_type,
        y = f"pred_{fia_type}",
        color = hue,
        hover_data = hover_data,
        trendline = trendline,
        trendline_color_override = trendline_color_override,
        color_discrete_sequence = px.colors.qualitative.Alphabet
    )

    if add_line and hue is not None:
        print()
        print(f"hue was '{hue}'.")
        print()
        print(f"class\t\tr_2")
        print("-----------------------------")
        results = px.get_trendline_results(fig)
        # The results table has a column named after the hue only when the hue
        # split the data into discrete classes. Checking for that column replaces
        # the former bare `except`, which also hid every unrelated error.
        if hue in results.columns:
            for _, fit in results.iterrows():
                print(f"{fit[hue]}\t\t{round((fit['px_fit_results'].rsquared), 4)}")
        else:
            print("provided hue is continuous")

    # Diagonal reference line spanning the range of the true values with a margin of 20.
    lower = df[fia_type].min() - 20
    upper = df[fia_type].max() + 20
    fig.add_trace(
        go.Scatter(
            x=[lower, upper],
            y=[lower, upper],
            name="perfection",
            line_shape='linear')
    )

    fig.show()


def get_final_results(df, set_assignment=None, hue=None, hue_sort_by=None):
    """
    Print the final metrics (MAE, r2, squared Spearman coefficient) for gas and solvent FIA.

    Arguments
    ---------
    df: pd.DataFrame
        Data to be analyzed.
    set_assignment: str
        If given, only rows whose "set_assignment" column equals this value are analyzed.
    hue: str
        Column whose distinct values define subgroups that are analyzed separately
        and reported as a table.
    hue_sort_by: str
        Table column by which the subgroup table is sorted (defaults to "Subclass").
    """
    targets = ("fia_gas-DSDBLYP", "fia_solv-DSDBLYP")

    def metrics(frame, target):
        # (MAE, r2, squared Spearman coefficient) of one FIA type within `frame`.
        true, pred = frame[target], frame[f"pred_{target}"]
        return get_mae(true, pred), get_r2(true, pred), get_spearman_rank(true, pred)

    for banner_line in ("#################", "# Final results #", "#################"):
        print(banner_line)
    print()

    if set_assignment is not None:
        df = df.loc[df["set_assignment"] == set_assignment]

    for heading_line in ("-------", "Overall", "-------"):
        print(heading_line)
    for label, target in zip(("FIA_gas", "FIA_solv"), targets):
        mae, r2, sm_rank = metrics(df, target)
        print(f"> {label}\tMAE: {mae}\tr2: {r2}\tSpearman coeff **2: {sm_rank}")
    print()

    for heading_line in ("-----", "Hue's", "-----"):
        print(heading_line)
    if hue is None:
        print("None")
        return

    if hue_sort_by is None:
        hue_sort_by = "Subclass"

    column_names = (
        "Subclass", "Member count",
        "MAE_gas", "r2_gas", "Sm_rank_gas",
        "MAE_solv", "r2_solv", "Sm_rank_solv",
    )
    table_data = {column: [] for column in column_names}

    # One table row per distinct hue value.
    for hue_class in list(set(df[hue])):
        hue_df = df.loc[df[hue] == hue_class]
        row = (hue_class, len(hue_df), *metrics(hue_df, targets[0]), *metrics(hue_df, targets[1]))
        for column, value in zip(column_names, row):
            table_data[column].append(value)

    T = PrettyTable()
    for column in column_names:
        T.add_column(column, table_data[column])
    T.sortby = hue_sort_by
    T.align["Subclass"] = "l"
    print(T)

### 2a) Overall

In [None]:
get_final_results(
    df=fia_predictions,  
    set_assignment=None, 
    hue="set_assignment", 
)

### 2b) Test set of the FIA44k data set (set_assignment == "test")

In [None]:
# Get the data of the subset.

SET = "test"
df_test = fia_predictions.loc[fia_predictions["set_assignment"] == SET]
df_test.shape

In [None]:
get_final_results(
    df=df_test,  
    set_assignment=SET, 
    hue="ca_class",
    hue_sort_by="MAE_gas"
)

In [None]:
make_plot(
    df=df_test, 
    fia_type="fia_gas-DSDBLYP", 
    set_assignment=SET, 
    hover=["ca_class"], 
    hue="denticity_class", 
    add_line=True
)

##### Plots for publication

In [None]:
# Mean absolute gas-phase error per central atom class, as a two-column table.
mean_abs_error_by_ca = df_test.groupby(["ca_class"])["abs_error_fia_gas-DSDBLYP"].mean()
ca_maes = dict(mean_abs_error_by_ca)
ca_maes_df = pd.DataFrame({
    "ca_class": list(ca_maes.keys()),
    "MAE": list(ca_maes.values()),
})
ca_maes_df.shape

In [None]:
# r2 per central atom class, formatted as strings for the bar plot annotations.

# Group by the plain column name rather than a one-element list: with a list,
# pandas >= 2.0 yields 1-tuples as group keys and the label lookup below fails.
r2_by_ca = {}
for ca_class, data in df_test.groupby("ca_class"):
    r2_by_ca[ca_class] = get_r2(data["fia_gas-DSDBLYP"], data["pred_fia_gas-DSDBLYP"])

# Order the values like `formatted_ca` (the order of the bars in the plot),
# then format them with three decimals.
ca_r2s = ["%.3f" % r2_by_ca[ca_class] for ca_class in formatted_ca.values()]

In [None]:
plt.style.use('default')
# A figure takes its dpi from rcParams when it is created, so the dpi must be
# set before plt.figure() to apply to this figure.
plt.rcParams['figure.dpi'] = 300
plt.figure(figsize=(2.5, 8))

# Horizontal bars: MAE per central atom class, in the order of `formatted_ca`.
ax = sns.barplot(
    data=ca_maes_df,
    x="MAE",
    y="ca_class",
    color=sns.color_palette("inferno")[1],
    edgecolor="black",
    order=list(formatted_ca.values()),
)

# Write the r2 string of each class next to the end of its bar. This relies on
# `ca_r2s` being in the same `formatted_ca` order as the bars.
for idx, p in enumerate(ax.patches):
    ax.annotate(
        ca_r2s[idx], 
        (p.get_width()+1, p.get_y() + p.get_height() / 2.),
        ha='center', 
        va='center', 
        fontsize=12, 
        xytext=(18, 0),
        color=(0, 128/255, 128/255),
        textcoords='offset points', 
               )

plt.xlabel("MAE$_{gas}$ / kJ mol$^{-1}$", fontsize=14)
plt.ylabel("Central atom class", fontsize=14)

plt.xticks(fontsize=11)
plt.yticks(fontsize=11)

# Dashed reference line at a hard-coded value.
# NOTE(review): presumably the overall MAE_gas of the test set as printed by
# get_final_results - confirm and keep in sync with the data.
ax.axvline(x=12.1119, linestyle="--", color=sns.color_palette("inferno")[4], linewidth=3)

plt.xlim(0, 28)

ax

***

In [None]:
# Mean absolute gas-phase error per denticity class, as a two-column table.
mean_abs_error_by_denticity = df_test.groupby(["denticity_class"])["abs_error_fia_gas-DSDBLYP"].mean()
dent_maes = dict(mean_abs_error_by_denticity)
dent_maes_df = pd.DataFrame({
    "denticity_class": list(dent_maes.keys()),
    "MAE": list(dent_maes.values()),
})
dent_maes_df.shape

In [None]:
# r2 per denticity class, formatted as strings for the bar plot annotations.

# Group by the plain column name rather than a one-element list: with a list,
# pandas >= 2.0 yields 1-tuples as group keys and the label lookup below fails.
r2_by_denticity = {}
for denticity_class, data in df_test.groupby("denticity_class"):
    r2_by_denticity[denticity_class] = get_r2(data["fia_gas-DSDBLYP"], data["pred_fia_gas-DSDBLYP"])

# Fixed order mono -> bi -> tri (the order of the bars in the plot), three decimals.
dent_r2s = ["%.3f" % r2_by_denticity[denticity_class] for denticity_class in ["mono", "bi", "tri"]]

In [None]:
plt.style.use('default')
# A figure takes its dpi from rcParams when it is created, so the dpi must be
# set before plt.figure() to apply to this figure.
plt.rcParams['figure.dpi'] = 300
plt.figure(figsize=(2.1, 1.6))

# Horizontal bars: MAE per denticity class in the fixed order mono, bi, tri.
ax = sns.barplot(
    data=dent_maes_df,
    x="MAE",
    y="denticity_class",
    color=sns.color_palette("inferno")[1],
    edgecolor="black",
    order=["mono", "bi", "tri"]
)

# Write the r2 string of each class next to the end of its bar. This relies on
# `dent_r2s` being in the same mono/bi/tri order as the bars.
for idx, p in enumerate(ax.patches):
    ax.annotate(dent_r2s[idx], (p.get_width()+1, p.get_y() + p.get_height() / 2.),
                ha='center', va='center', fontsize=12, xytext=(18, 0),
                color=(0, 128/255, 128/255),
                textcoords='offset points')

plt.xlabel("MAE$_{gas}$ / kJ mol$^{-1}$", fontsize=14)
plt.ylabel("Denticity\nclass", fontsize=14)

plt.xticks(fontsize=11)
plt.yticks(fontsize=11)

# Dashed reference line at a hard-coded value.
# NOTE(review): presumably the overall MAE_gas of the test set as printed by
# get_final_results - confirm and keep in sync with the data.
ax.axvline(x=12.1119, linestyle="--", color=sns.color_palette("inferno")[4], linewidth=3)

plt.xlim(0, 25)

ax

### 2b) Test set of the FIA2k-CSD data set (set_assignment == "test_2")

In [None]:
# Get the data of the subset.

SET = "test_2"
df_test_2 = fia_predictions.loc[fia_predictions["set_assignment"] == SET]
df_test_2.shape

In [None]:
get_final_results(
    df=df_test_2,  
    set_assignment=SET, 
    hue="ca_class", 
    hue_sort_by="MAE_gas"
)

In [None]:
make_plot(
    df=df_test_2, 
    fia_type="fia_solv-DSDBLYP", 
    set_assignment=SET, 
    hover=["ca_class"], 
    hue="denticity_class", 
    add_line=True
)

##### Plot for publication

In [None]:
plt.style.use('default')
# A figure takes its dpi from rcParams when it is created, so the dpi must be
# set before plt.figure() to apply to this figure (and not to the next one).
plt.rcParams['figure.dpi'] = 450
plt.figure(figsize=(4, 4))

# Parity scatter: reference gas-phase FIA (x) against the prediction (y).
ax = sns.scatterplot(
    data=df_test_2,
    x="fia_gas-DSDBLYP",
    y="pred_fia_gas-DSDBLYP",
    s=100,
    alpha=0.7,
    color=sns.color_palette("inferno")[1]
)

# Dashed identity line (prediction == reference).
ax = sns.lineplot(
    x=[0,575],
    y=[0,575],
    alpha=0.9,
    color=sns.color_palette("inferno")[4],
    linestyle='--',
    linewidth = 3
)

# The numbers in the annotations are hard-coded text.
# NOTE(review): keep them in sync with the output of get_final_results for this subset.
ax.annotate("FIA2k-CSD (test)", (0, 550), fontsize=14, weight="bold", style="italic", color=sns.color_palette("inferno")[1])
ax.annotate("1,200 data points", (0, 500), fontsize=12, style="italic")
ax.annotate("MAE: 14.4 kJ mol$^{-1}$", (210, 50), fontsize=14)
ax.annotate("r$^2$: 0.905", (210, 0), fontsize=14)

plt.xlabel("FIA$_{gas}$ (DFT) / kJ mol$^{-1}$", fontsize=14)
plt.ylabel("FIA$_{gas}$ (FIA-GNN) / kJ mol$^{-1}$", fontsize=14)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

ax

##### Specific examples

In [None]:
picks = [
    "SAXRUI",
    "YEHFUQ",
    "XEKPEM",
    "WEDXAI", 
    "SENGIF",
    "POSZUS",
    "DIVXOX"
]

# Collect, in data frame order, the rows whose compound name ends (after the
# last "-") in one of the picked identifiers.
picked_rows = [
    (name, smiles, true, pred)
    for name, smiles, true, pred in zip(
        df_test_2["Compound"],
        df_test_2["la_smiles"],
        df_test_2["fia_gas-DSDBLYP"],
        df_test_2["pred_fia_gas-DSDBLYP"],
    )
    if name.split("-")[-1] in picks
]

# Molecule objects and matching legends (predicted vs. reference value) for the grid image.
mols = [Chem.MolFromSmiles(smiles) for _, smiles, _, _ in picked_rows]

labels = [
    f"{name}\n\n\n\nML: {round(pred, 2)}\tDFT: {round(true, 2)}"
    for name, _, true, pred in picked_rows
]

In [None]:
# Use the explicitly imported Draw submodule: `Chem.Draw` is only available as an
# attribute if something else has already imported rdkit.Chem.Draw.
Draw.MolsToGridImage(mols, legends=labels, molsPerRow=3, subImgSize=(400, 400))

##### Analyze outliers

In [None]:
# The 20 rows with the largest absolute gas-phase error, reduced to the relevant columns.
outlier_columns = [
    "Compound",
    "fia_gas-DSDBLYP",
    "pred_fia_gas-DSDBLYP",
    "abs_error_fia_gas-DSDBLYP",
]
outlier_df = (
    df_test_2
    .sort_values(by="abs_error_fia_gas-DSDBLYP", ascending=False)
    .loc[:, outlier_columns]
    .head(20)
)

In [None]:
outlier_df

### 2c) Test set FIA763-bimacro (set_assignment == "test_3")

In [None]:
# Get the data of the subset.

SET = "test_3"
df_test_3 = fia_predictions.loc[fia_predictions["set_assignment"] == SET]
df_test_3.shape

In [None]:
get_final_results(
    df=df_test_3,  
    set_assignment=SET, 
    hue="ca_class", 
    hue_sort_by="MAE_gas"
)

In [None]:
make_plot(
    df=df_test_3, 
    fia_type="fia_gas-DSDBLYP", 
    set_assignment=SET, 
    hover=["ca_class"], 
    hue="ca_class", 
    add_line=False
)

##### Plot for publication

In [None]:
plt.style.use('default')
# A figure takes its dpi from rcParams when it is created, so the dpi must be
# set before plt.figure() to apply to this figure.
plt.rcParams['figure.dpi'] = 300
plt.figure(figsize=(4, 4.5))

# Parity scatter: reference gas-phase FIA (x) against the prediction (y).
ax = sns.scatterplot(
    data=df_test_3,
    x="fia_gas-DSDBLYP",
    y="pred_fia_gas-DSDBLYP",
    s=100,
    alpha=0.7,
    color=sns.color_palette("inferno")[1]
)

# Dashed identity line (prediction == reference).
ax = sns.lineplot(
    x=[40,550],
    y=[40,550],
    alpha=0.9,
    color=sns.color_palette("inferno")[4],
    linestyle='--',
    linewidth = 3
)

# The numbers in the annotations are hard-coded text.
# NOTE(review): keep them in sync with the output of get_final_results for this subset.
ax.annotate("FIA763-bimacro", (40, 525), fontsize=14, weight="bold", style="italic", color=sns.color_palette("inferno")[1])
ax.annotate("MAE: 24.1 kJ mol$^{-1}$", (40, 465), fontsize=14)
ax.annotate("r$^2$: 0.847", (40, 415), fontsize=14)

plt.xlabel("FIA$_{gas}$ (DFT) / kJ mol$^{-1}$", fontsize=14)
plt.ylabel("FIA$_{gas}$ (FIA-GNN) / kJ mol$^{-1}$", fontsize=14)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

ax

### 2d) Test set FIA911-ring4 (set_assignment == "test_4")

In [None]:
# Get the data of the subset.

SET = "test_4"
df_test_4 = fia_predictions.loc[fia_predictions["set_assignment"] == SET]
df_test_4.shape

In [None]:
get_final_results(
    df=df_test_4,  
    set_assignment=SET, 
    hue="ca_class", 
    hue_sort_by="MAE_gas"
)

In [None]:
make_plot(
    df=df_test_4, 
    fia_type="fia_gas-DSDBLYP", 
    set_assignment=SET, 
    hover=["ca_class"], 
    hue="ca_class", 
    add_line=False
)

##### Plot for publication

In [None]:
plt.style.use('default')
# A figure takes its dpi from rcParams when it is created, so the dpi must be
# set before plt.figure() to apply to this figure.
plt.rcParams['figure.dpi'] = 300
plt.figure(figsize=(4, 4.5))

# Parity scatter: reference gas-phase FIA (x) against the prediction (y).
ax = sns.scatterplot(
    data=df_test_4,
    x="fia_gas-DSDBLYP",
    y="pred_fia_gas-DSDBLYP",
    s=100,
    alpha=0.7,
    color=sns.color_palette("inferno")[1]
)

# Dashed identity line (prediction == reference).
ax = sns.lineplot(
    x=[40,530],
    y=[40,530],
    alpha=0.9,
    color=sns.color_palette("inferno")[4],
    linestyle='--',
    linewidth = 3
)

# The numbers in the annotations are hard-coded text.
# NOTE(review): keep them in sync with the output of get_final_results for this subset.
ax.annotate("FIA911-ring4", (40, 508), fontsize=14, weight="bold", style="italic", color=sns.color_palette("inferno")[1])
ax.annotate("MAE: 17.9 kJ mol$^{-1}$", (40, 448), fontsize=14)
ax.annotate("r$^2$: 0.897", (40, 398), fontsize=14)

plt.xlabel("FIA$_{gas}$ (DFT) / kJ mol$^{-1}$", fontsize=14)
plt.ylabel("FIA$_{gas}$ (FIA-GNN) / kJ mol$^{-1}$", fontsize=14)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

ax

### 2e) Test set FIA15-PTcat (set_assignment == "test_5")

In [None]:
# Get the data of the subset.

SET = "test_5"
df_test_5 = fia_predictions.loc[fia_predictions["set_assignment"] == SET]
df_test_5.shape

In [None]:
get_final_results(
    df=df_test_5,  
    set_assignment=SET
)

In [None]:
make_plot(
    df=df_test_5, 
    fia_type="fia_solv-DSDBLYP", 
    set_assignment=SET, 
    hover=["ca_class"], 
    hue="denticity_class", 
    add_line=False
)

##### Plot for publication

In [None]:
plt.style.use('default')
# A figure takes its dpi from rcParams when it is created, so the dpi must be
# set before plt.figure() to apply to this figure.
plt.rcParams['figure.dpi'] = 300
plt.figure(figsize=(5, 4.5))

# Parity scatter: reference solvent-phase FIA (x) against the prediction (y).
ax = sns.scatterplot(
    data=df_test_5,
    x="fia_solv-DSDBLYP",
    y="pred_fia_solv-DSDBLYP",
    s=100,
    alpha=0.7,
    color=sns.color_palette("inferno")[1],
    zorder=1
    
)

# Dashed identity line (prediction == reference), drawn above the markers.
ax = sns.lineplot(
    x=[45,260],
    y=[45,260],
    alpha=0.7,
    color=sns.color_palette("inferno")[4],
    linestyle='--',
    linewidth = 3,
    zorder=2
)

# Shaded bands between 95 and 120 on both axes, drawn behind everything else;
# they are labelled by the "Most active phase transfer catalysts" annotation below.
plt.axvspan(95, 120, color=(209/255, 226/255, 238/255), zorder=0)
plt.axhspan(95, 120, color=(209/255, 226/255, 238/255), zorder=0)

# The numbers in the annotations are hard-coded text.
# NOTE(review): keep them in sync with the output of get_final_results for this subset.
ax.annotate("FIA15-PTCat", (42, 252), fontsize=14, weight="bold", style="italic", color=sns.color_palette("inferno")[1])
ax.annotate("MAE: 5.6 kJ mol$^{-1}$", (160,65), fontsize=14)
ax.annotate("r$^2$: 0.993", (160,45), fontsize=14)
ax.annotate("Most active phase\ntransfer catalysts", (110,190), fontsize=14, 
            ha="center", style="italic", color=(0/255, 93/255, 126/255))

plt.xlabel("FIA$_{solv}$ (DFT) / kJ mol$^{-1}$", fontsize=14)
plt.ylabel("FIA$_{solv}$ (FIA-GNN) / kJ mol$^{-1}$", fontsize=14)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)

ax

### 2f) Test set FIA31-cat (set_assignment == "test_6")

In [None]:
# Get the data of the subset.

SET = "test_6"
df_test_6 = fia_predictions.loc[fia_predictions["set_assignment"] == SET]
df_test_6 = df_test_6.reset_index(drop=True)
df_test_6.shape

In [None]:
# Classify the subset into cases with perfluoro- and perchlorocatecholato ligand

def _classify_catechol_ligand(la_smiles):
    """
    Label a SMILES string by the halogen substituent it contains.

    "(F)" is checked before "(Cl)". A string containing neither raises a
    ValueError instead of being skipped, since a skipped row would leave the
    label list shorter than the data frame.
    """
    if "(F)" in la_smiles:
        return "perfluorocatechol"
    if "(Cl)" in la_smiles:
        return "perchlorocatechol"
    raise ValueError(f"Neither '(F)' nor '(Cl)' found in SMILES: {la_smiles}")


ligand_classifications = [_classify_catechol_ligand(smiles) for smiles in df_test_6["la_smiles"]]

df_test_6["ligand_classification"] = ligand_classifications
df_test_6.shape

In [None]:
get_final_results(
    df=df_test_6,  
    set_assignment=SET, 
    hue="ligand_classification", 
    hue_sort_by="MAE_gas"
)

In [None]:
make_plot(
    df=df_test_6, 
    fia_type="fia_gas-DSDBLYP", 
    set_assignment=SET, 
    hover=["ca_class"], 
    hue="ligand_classification", 
    add_line=True
)

In [None]:
# Reference gas-phase FIA per central atom class, colored by ligand type.
# Only the classes that occur in this subset are shown, in `formatted_ca` order.
present_ca_classes = set(df_test_6.ca_class)
ca_order = [label for label in formatted_ca.values() if label in present_ca_classes]

fig = px.scatter(
    df_test_6,
    x="ca_class",
    y="fia_gas-DSDBLYP",
    color="ligand_classification",
    hover_data=["Compound", "ca_class", "ligand_classification"],
    category_orders={"ca_class": ca_order},
)
fig

In [None]:
# Differences with respect to the ligands

# For every central atom class with exactly two entries, print the difference of
# their reference gas-phase FIAs and collect its absolute value.
# Grouping by the plain column name keeps the group key a plain label; with a
# one-element list, pandas >= 2.0 would yield (and print) a 1-tuple instead.
differences = []
for ca_class, fia_values in df_test_6.groupby("ca_class")["fia_gas-DSDBLYP"]:
    print(f"Central atom class:   {ca_class}")
    values = list(fia_values)
    print(f"FIAs:                 {values}")
    if len(values) == 2:
        dif = values[0]-values[1]
        print(f"Difference:           {dif}")
        differences.append(abs(dif))
    print()

np.mean(differences), np.std(differences)

##### Plot for publication

In [None]:
# Combine the true and predicted data in one data frame.

# Two independent copies of the subset, each tagged with the origin of its FIA values.
aux_df_dft = df_test_6.copy()
aux_df_dft["origin"] = "DFT"

aux_df_ml = df_test_6.copy()
aux_df_ml["origin"] = "FIA-GNN"

In [None]:
# Prefix the compound names with their origin. The names are derived from the
# untouched `df_test_6`, so re-running this cell does not stack prefixes.
aux_df_dft["Compound"] = "DFT__" + df_test_6["Compound"].astype(str)
aux_df_ml["Compound"] = "ML__" + df_test_6["Compound"].astype(str)

# Common "FIA" column: reference values for the DFT frame, predictions for the
# ML frame (read from the ML frame itself rather than from the DFT copy).
aux_df_dft["FIA"] = aux_df_dft["fia_gas-DSDBLYP"]
aux_df_ml["FIA"] = aux_df_ml["pred_fia_gas-DSDBLYP"]

In [None]:
aux_df = pd.concat([aux_df_dft, aux_df_ml])
aux_df.shape

In [None]:
aux_df = aux_df[aux_df["ligand_classification"].isin(["perfluorocatechol"])]

In [None]:
plt.style.use('default')
# A figure takes its dpi from rcParams when it is created, so the dpi must be
# set before plt.figure() to apply to this figure.
plt.rcParams['figure.dpi'] = 300
plt.figure(figsize=(9, 4.5))

# One marker per central atom class and origin (DFT vs. FIA-GNN), no connecting lines.
# Only classes present in `aux_df` are shown, in `formatted_ca` order.
present_ca_classes = set(aux_df.ca_class)
ax = sns.pointplot(
    data=aux_df,
    x = "ca_class",
    y = "FIA",
    hue="origin",
    order=[x for x in formatted_ca.values() if x in present_ca_classes],
    linestyles="",
    markers=['o', 'D'],
    palette=[sns.color_palette("inferno")[1], sns.color_palette("inferno")[4]]
)

# Restyle the drawn markers: semi-transparent, no edge, larger.
plt.setp(
    ax.collections, 
    alpha=0.5,
    edgecolor="none",
    sizes=[200]
)

plt.xlabel("Central atom class", fontsize=14)
plt.ylabel("FIA$_{gas}$ / kJ mol$^{-1}$", fontsize=14)

plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)

sns.move_legend(ax, "upper right", ncol=2, title=None)
plt.setp(ax.get_legend().get_texts(), fontsize='14')
ax.legend_.set_title("$^{F}cat$ ligand")
plt.setp(ax.get_legend().get_title(), fontsize='16')

# Matplotlib 3.7 renamed `legendHandles` to `legend_handles`; support both names.
legend_handles = getattr(ax.legend_, "legend_handles", None)
if legend_handles is None:
    legend_handles = ax.legend_.legendHandles

# Legend markers stay opaque, unlike the semi-transparent markers in the plot.
for lh in legend_handles: 
    lh.set_alpha(1)
    lh.set_edgecolors(None)
    lh.set_sizes([200])

# The numbers in the annotations are hard-coded text.
# NOTE(review): keep them in sync with the metrics computed for this subset.
ax.annotate("MAE: 8.5 kJ mol$^{-1}$", (-0.2,315), fontsize=14)
ax.annotate("r$^2$: 0.963", (-0.2,295), fontsize=14)

ax

# 3) Dimensionality reduction of the learned molecular representations

In [None]:
# Load general data file.

# Read from the "data" folder next to the working directory, like the loading
# cells in section 1, instead of from an absolute path on one particular machine.
general_data = pd.read_csv(os.path.join(os.path.split(os.getcwd())[0], "data", "FIA49k.csv.gz"))

# Only the train and test sets are used for the dimensionality reduction.
general_data = general_data[general_data["set_assignment"].isin(["train", "test"])]
general_data.shape

In [None]:
# Read in FIA-GNN embeddings

# Read from the "data" folder next to the working directory instead of from an
# absolute path on one particular machine.
# NOTE(review): assumes the embeddings file sits in the same data folder as FIA49k.csv.gz - confirm.
learned_mol_reps = pd.read_csv(
    os.path.join(os.path.split(os.getcwd())[0], "data", "FIA49k_fia_gnn_embeddings.csv.gz")
)

# Keep only the compounds that remain in `general_data`.
learned_mol_reps = learned_mol_reps[learned_mol_reps["Compound"].isin(general_data["Compound"])]
learned_mol_reps.shape

In [None]:
# Formatting.

learned_mol_reps = learned_mol_reps.merge(general_data[["Compound", "set_assignment"]], on="Compound")
learned_mol_reps.shape

In [None]:
# Do train/all splitting.

def _select_embedding(frame, prefix):
    """Return a new frame with only the columns of `frame` whose name starts with `prefix`."""
    return frame.loc[:, frame.columns.str.startswith(prefix)]


# Train sets
train_reps = learned_mol_reps.loc[learned_mol_reps["set_assignment"] == "train"]

X_train_32_gas = _select_embedding(train_reps, "vec32_gas_")
X_train_32_solv = _select_embedding(train_reps, "vec32_solv_")
X_train_128_gas = _select_embedding(train_reps, "vec128_gas_")
X_train_128_solv = _select_embedding(train_reps, "vec128_solv_")

# All (train + test)
X_all_32_gas = _select_embedding(learned_mol_reps, "vec32_gas_")
X_all_32_solv = _select_embedding(learned_mol_reps, "vec32_solv_")
X_all_128_gas = _select_embedding(learned_mol_reps, "vec128_gas_")
X_all_128_solv = _select_embedding(learned_mol_reps, "vec128_solv_")

In [None]:
X_train_32_gas.shape, X_train_32_solv.shape, X_train_128_gas.shape, X_train_128_solv.shape

In [None]:
X_all_32_gas.shape, X_all_32_solv.shape, X_all_128_gas.shape, X_all_128_solv.shape

In [None]:
# Train reducers
# One UMAP model per embedding type, each fitted on the training rows only and
# seeded with the same random_state.

# 32_gas
umap_reducer_32_gas = umap.UMAP(random_state=42).fit(X_train_32_gas)
print("32_gas done.")

# 32_solv
umap_reducer_32_solv = umap.UMAP(random_state=42).fit(X_train_32_solv)
print("32_solv done.")

# 128_gas
umap_reducer_128_gas = umap.UMAP(random_state=42).fit(X_train_128_gas)
print("128_gas done.")

# 128_solv
umap_reducer_128_solv = umap.UMAP(random_state=42).fit(X_train_128_solv)
print("128_solv done.")

In [None]:
# Get UMAP embeddings

umap_embedding_32_gas = umap_reducer_32_gas.transform(X_all_32_gas)
umap_embedding_32_solv = umap_reducer_32_solv.transform(X_all_32_solv)

umap_embedding_128_gas = umap_reducer_128_gas.transform(X_all_128_gas)
umap_embedding_128_solv = umap_reducer_128_solv.transform(X_all_128_solv)

In [None]:
# Save UMAP embeddings

# The embeddings were computed from `learned_mol_reps`, so their rows follow the
# order of that frame. They are joined onto `general_data` through the "Compound"
# key instead of being assigned by position, which would mislabel points if the
# two frames were ordered differently and fail if their lengths differed.
# NOTE(review): assumes "Compound" is unique in both frames - confirm.
umap_coordinates = pd.DataFrame({
    "Compound": learned_mol_reps["Compound"].to_numpy(),

    "UMAP_1__32_gas": umap_embedding_32_gas[:, 0],
    "UMAP_2__32_gas": umap_embedding_32_gas[:, 1],

    "UMAP_1__32_solv": umap_embedding_32_solv[:, 0],
    "UMAP_2__32_solv": umap_embedding_32_solv[:, 1],

    "UMAP_1__128_gas": umap_embedding_128_gas[:, 0],
    "UMAP_2__128_gas": umap_embedding_128_gas[:, 1],

    "UMAP_1__128_solv": umap_embedding_128_solv[:, 0],
    "UMAP_2__128_solv": umap_embedding_128_solv[:, 1],
})

# Dropping UMAP columns from an earlier run keeps the cell safe to re-run.
umap_column_names = [col for col in umap_coordinates.columns if col != "Compound"]
general_data = (
    general_data
    .drop(columns=umap_column_names, errors="ignore")
    .merge(umap_coordinates, on="Compound", how="left")
)

In [None]:
fig = px.scatter(
    general_data,
    x = "UMAP_1__128_gas",
    y = "UMAP_2__128_gas",
    color = "ca_class",
    hover_data = ["Compound"],
    color_discrete_sequence = px.colors.qualitative.Alphabet
)
fig

##### Plots for publication

In [None]:
# 32, FIA_gas
# UMAP coordinates of the "32_gas" embedding, colored by the gas-phase FIA value.

# Set before the plot call so that the figure created by seaborn uses this dpi.
plt.rcParams['figure.dpi'] = 300

ax = sns.scatterplot(
    data=general_data,
    x="UMAP_1__32_gas",
    y="UMAP_2__32_gas",
    hue="fia_gas-DSDBLYP",
    palette="inferno",
    edgecolor="black",
)

ax.set_xlabel("UMAP 1", size=14)
ax.set_ylabel("UMAP 2", size=14)

sns.move_legend(ax, "lower center", ncol=5, title="FIA$_{gas}$ / kJ mol$^{-1}$")

# Hide tick labels and tick marks on both axes.
ax.set(xticklabels=[])
ax.set(yticklabels=[])
ax.tick_params(left=False, bottom=False)

# Marker size of the scatter collection.
ax.collections[0].set_sizes([25])

# Hard-coded y-range. NOTE(review): presumably chosen to leave room for the legend at the bottom - confirm.
plt.ylim(-20, 24)

ax

***

In [None]:
# 32, FIA_solv
# UMAP coordinates of the "32_solv" embedding, colored by the solvent-phase FIA value.

# Set before the plot call so that the figure created by seaborn uses this dpi.
plt.rcParams['figure.dpi'] = 300

ax = sns.scatterplot(
    data=general_data,
    x="UMAP_1__32_solv",
    y="UMAP_2__32_solv",
    hue="fia_solv-DSDBLYP",
    palette="inferno",
    edgecolor="black",
)

ax.set_xlabel("UMAP 1", size=14)
ax.set_ylabel("UMAP 2", size=14)

sns.move_legend(ax, "lower center", ncol=5, title="FIA$_{solv}$ / kJ mol$^{-1}$")

# Hide tick labels and tick marks on both axes.
ax.set(xticklabels=[])
ax.set(yticklabels=[])
ax.tick_params(left=False, bottom=False)

# Marker size of the scatter collection.
ax.collections[0].set_sizes([25])

# Hard-coded y-range. NOTE(review): presumably chosen to leave room for the legend at the bottom - confirm.
plt.ylim(-20, 25)

ax

***

In [None]:
# 128, FIA_gas
# UMAP coordinates of the "128_gas" embedding: color encodes the central atom
# class, marker style the denticity class.

# Set before the plot call so that the figure created by seaborn uses this dpi.
plt.rcParams['figure.dpi'] = 300

ax = sns.scatterplot(
    data=general_data,
    x="UMAP_1__128_gas",
    y="UMAP_2__128_gas",
    hue="ca_class",
    palette=plotly.colors.qualitative.Light24_r,
    hue_order=formatted_ca.values(),  # colors are assigned in the order of `formatted_ca`
    edgecolor="black",
    style="denticity_class"
)

ax.set_xlabel("UMAP 1", size=14)
ax.set_ylabel("UMAP 2", size=14)

sns.move_legend(ax, "lower center", ncol=7, title="", fontsize="small", handletextpad=0, columnspacing=1)
# ax.legend_.remove()

# Hide tick labels and tick marks on both axes.
ax.set(xticklabels=[])
ax.set(yticklabels=[])
ax.tick_params(left=False, bottom=False)

# Marker size of the scatter collection.
ax.collections[0].set_sizes([60])

# Hard-coded y-range. NOTE(review): presumably chosen to leave room for the legend at the bottom - confirm.
plt.ylim(-45, 27)

ax

***

In [None]:
# 128, FIA_solv
# UMAP coordinates of the "128_solv" embedding: color encodes the central atom
# class, marker style the denticity class.

# Set before the plot call so that the figure created by seaborn uses this dpi.
plt.rcParams['figure.dpi'] = 300

ax = sns.scatterplot(
    data=general_data,
    x="UMAP_1__128_solv",
    y="UMAP_2__128_solv",
    hue="ca_class",
    palette=plotly.colors.qualitative.Light24_r,
    hue_order=formatted_ca.values(),  # colors are assigned in the order of `formatted_ca`
    edgecolor="black",
    style="denticity_class"
)

ax.set_xlabel("UMAP 1", size=14)
ax.set_ylabel("UMAP 2", size=14)

sns.move_legend(ax, "lower center", ncol=7, title="", fontsize="small", handletextpad=0, columnspacing=1)
# ax.legend_.remove()

# Hide tick labels and tick marks on both axes.
ax.set(xticklabels=[])
ax.set(yticklabels=[])
ax.tick_params(left=False, bottom=False)

# Marker size of the scatter collection.
ax.collections[0].set_sizes([60])

# Hard-coded y-range. NOTE(review): presumably chosen to leave room for the legend at the bottom - confirm.
plt.ylim(-50, 27)

ax