In [1]:
import re
import seaborn as sns
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.metrics import mean_absolute_error
from scipy.stats import pearsonr

from training.sigopt_utils import build_sigopt_name

# Global matplotlib styling shared by every figure drawn in this notebook.
# All options are applied in one validated update of the rcParams mapping.
plt.rcParams.update({
    # Default canvas size and axes frame thickness.
    "figure.figsize": (13, 8),
    "axes.linewidth": 1.0,
    # Major tick geometry (length, then width) for both axes.
    "xtick.major.size": 4,
    "ytick.major.size": 4,
    "ytick.major.width": 2,
    "xtick.major.width": 2,
    # Render text with matplotlib's built-in mathtext rather than external LaTeX.
    "text.usetex": False,
    # Default line appearance.
    "lines.linewidth": 3,
    "lines.color": "g",
    # Fonts.
    "font.size": 16,
    "font.family": "sans-serif",
    "font.sans-serif": "Arial",
    "mathtext.fontset": "dejavusans",
})

In [2]:
# Name of the predicted property; the prediction files are read with this
# column as the reference value and "predicted_<target_prop>" as the model output.
target_prop = "dft_e_hull"

# Evaluation splits for which "<split>_predictions.json" files are loaded.
test_set_types = [
    "test_set",
    "holdout_set_B_sites",
    "holdout_set_series",
]

# One dict per trained configuration, written as a compact table of
# (model_type, interpolation, relaxed, exp_id) rows.
experimental_settings = [
    dict(model_type=model_type, interpolation=interpolation, relaxed=relaxed, exp_id=exp_id)
    for model_type, interpolation, relaxed, exp_id in (
        ("CGCNN", False, False, 596732),
        ("CGCNN", False, True, 596731),
        ("CGCNN", True, False, 596670),
        ("CGCNN", True, True, 596671),
        ("Painn", True, False, 596674),
        ("Painn", True, True, 596675),
        ("e3nn", True, False, 596676),
        ("e3nn", True, True, 596677),
        ("e3nn_contrastive", True, False, 596672),
        ("e3nn_contrastive", True, False, 597227),
        ("e3nn_contrastive", True, True, 596673),
        ("e3nn_contrastive", True, True, 597288),
    )
]

# Composition series for the violin plots. Each entry is a pair of element
# pairs: the first pair is labelled as sharing the A site (x / 1-x), the second
# pair is shown with fixed 0.5 subscripts in the plot title.
series = [
    [["La", "Pr"], ["Y", "Ni"]],
    [["K", "Ba"], ["Ti", "Al"]],
    [["Y", "La"], ["In", "Mg"]],
    [["Mg", "Pr"], ["Ni", "V"]],
]

In [3]:
def plot_hex(true_values, pred_values, test_set_type, sigopt_name, exp_id, pure_interp=False):
    """Save a hexbin parity plot of predicted versus reference values as a PDF.

    The figure shows a dashed identity line, a hexbin density of
    (``true_values``, ``pred_values``) with a colorbar, and an annotation with
    the Pearson correlation coefficient and the mean absolute error.

    Parameters
    ----------
    true_values, pred_values : array-like
        Reference and predicted values; passed unchanged to
        ``mean_absolute_error``, ``pearsonr`` and ``Axes.hexbin``.
    test_set_type : str
        Selects the common x/y range of the plot and is part of the file name.
        ``"test_set"`` and ``"holdout_set_B_sites"`` have dedicated ranges;
        any other value uses ``[0, 0.7]``.
    sigopt_name : str
        Run name used in the file name (ignored when ``pure_interp`` is true).
    exp_id : str or int
        Experiment id used in the file name (ignored when ``pure_interp`` is true).
    pure_interp : bool, optional
        When true, the figure is saved as
        ``figures/<test_set_type>_pure_interpolation_hexbin.pdf`` instead of
        ``figures/<test_set_type>_<sigopt_name>_<exp_id>_hexbin.pdf``.

    Returns
    -------
    None
        The figure is written to disk and closed.
    """
    # The three cases are mutually exclusive, so they must form one chain;
    # with two separate ``if`` statements the "test_set" range was always
    # overwritten by the final ``else``.
    if test_set_type == "test_set":
        hex_xylim = [-0.1, 1.1]
    elif test_set_type == "holdout_set_B_sites":
        hex_xylim = [-0.1, 0.6]
    else:
        hex_xylim = [0, 0.7]

    hex_figsize = (4, 3.2)

    mae = mean_absolute_error(true_values, pred_values)
    r, _ = pearsonr(true_values, pred_values)

    fig, ax = plt.subplots(figsize=hex_figsize)
    # Raw strings keep the mathtext backslash without relying on Python
    # passing through the invalid "\m" escape sequence.
    ax.set_xlabel(r"DFT $E_{\mathrm{hull}}$ (eV/atom)")
    ax.set_ylabel(r"ML $E_{\mathrm{hull}}$ (eV/atom)")
    # Identity line: perfect predictions fall on it.
    ax.axline((hex_xylim[0], hex_xylim[0]), (hex_xylim[1], hex_xylim[1]), color='black', linestyle='--', linewidth=1)

    hb = ax.hexbin(
        true_values, pred_values,
        cmap='viridis', gridsize=50, bins=None, mincnt=1, edgecolors='none',
        extent=[hex_xylim[0], hex_xylim[1], hex_xylim[0], hex_xylim[1]]
    )

    ax.annotate("r = %.3f\nMAE = %.3f" % (r, mae), xy=(0.05, 0.95), xycoords='axes fraction', ha='left', va='top')
    cb = fig.colorbar(hb)
    cb.set_label('Count')
    fig.tight_layout()

    if pure_interp:
        fig.savefig("figures/" + test_set_type + "_pure_interpolation_hexbin.pdf")
        print("Completed pure interpolation")
    else:
        # str() so that integer experiment ids work, matching the print below.
        fig.savefig("figures/" + test_set_type + "_" + sigopt_name + "_" + str(exp_id) + "_hexbin.pdf")
        print("Completed " + sigopt_name + " " + str(exp_id))

    # Close explicitly so figures do not accumulate when called in a loop.
    plt.close(fig)

def plot_hex_all(target_prop, test_set_types, experimental_settings, num_best_models=3):
    """Draw ensemble-averaged hexbin parity plots for every setting and test set.

    For each entry of ``experimental_settings`` and each name in
    ``test_set_types``, the prediction files
    ``./best_models/<model_type>/<sigopt_name>/<exp_id>/best_<i>/<test_set_type>_predictions.json``
    are read for ``i`` in ``range(num_best_models)``. The
    ``predicted_<target_prop>`` columns are averaged element-wise across the
    models and plotted against the ``<target_prop>`` column of ``best_0`` via
    ``plot_hex``.

    While handling a setting equal to the last entry of
    ``experimental_settings``, one extra "pure interpolation" plot per test set
    is drawn from the ``<target_prop>_interp`` column of ``best_0``.
    """
    for setting in experimental_settings:
        sigopt_name = build_sigopt_name(
            target_prop, setting["relaxed"], setting["interpolation"], setting["model_type"]
        )
        exp_id = str(setting["exp_id"])
        run_dir = "/".join([".", "best_models", setting["model_type"], sigopt_name, exp_id])
        # The pure-interpolation plot is only produced once, alongside the final setting.
        is_final_setting = setting == experimental_settings[-1]

        for test_set_type in test_set_types:
            # One predictions frame per ensemble member, in rank order.
            frames = []
            for rank in range(num_best_models):
                json_path = "/".join([run_dir, "best_" + str(rank), test_set_type + "_predictions.json"])
                with open(json_path) as handle:
                    frames.append(pd.read_json(handle))

            # Reference values are taken from the first member's file.
            reference = frames[0][target_prop].to_numpy()

            # Stack the members row-wise and average per sample.
            stacked = np.vstack([frame['predicted_' + target_prop].to_numpy() for frame in frames])
            ensemble_mean = np.mean(stacked, axis=0)

            plot_hex(reference, ensemble_mean, test_set_type, sigopt_name, exp_id)

            if is_final_setting:
                interpolated = frames[0][target_prop + '_interp'].to_numpy()
                plot_hex(reference, interpolated, test_set_type, None, None, pure_interp=True)
                
def _series_rows(df, elements):
    """Return the rows of ``df`` whose ``formula`` contains every symbol in ``elements``."""
    # Symbols are matched as literal substrings, not as regular expressions.
    masks = [df.formula.str.contains(element, regex=False) for element in elements]
    combined = masks[0]
    for mask in masks[1:]:
        combined = combined & mask
    return df[combined]


def _concentration_frame(rows, element, value_column, column_conc, column_entry):
    """Tabulate ``rows[value_column]`` against the concentration of ``element`` parsed from ``framework``."""
    # The concentration is the "0.xx" number that directly follows the element symbol.
    conc_pattern = re.compile(r'%s(0\.\d+)' % re.escape(element))
    records = []
    for framework, subdf in rows.groupby('framework'):
        conc = float(conc_pattern.findall(framework)[0])
        records.extend((conc, value) for value in subdf[value_column])
    # Build the frame once instead of appending row by row.
    return pd.DataFrame(records, columns=[column_conc, column_entry])


def _save_violin(frame, column_conc, column_entry, ylim, title, path):
    """Draw one violin plot of ``frame`` and write it to ``path``."""
    ax = sns.violinplot(x=column_conc, y=column_entry, data=frame, inner="points")
    ax.set_ylim(ylim)
    ax.set_title(title)
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


def plot_violin(target_prop, experimental_settings, series, num_best_models=3):
    """Save violin plots of predicted (and true) values along each composition series.

    For every experimental setting, the ``holdout_set_series`` predictions of
    the ``num_best_models`` best models are read from
    ``./best_models/<model_type>/<sigopt_name>/<exp_id>/best_<i>/holdout_set_series_predictions.json``.
    For every entry of ``series`` the rows whose ``formula`` contains all four
    element symbols are selected, the ``predicted_<target_prop>`` values are
    averaged across models (aligned by row position), and a violin plot of the
    mean prediction versus the concentration of the first element is saved to
    ``figures/``.

    While handling a setting equal to the last entry of
    ``experimental_settings``, a matching plot of the reference
    ``<target_prop>`` values (from ``best_0``) is saved as well.

    Parameters
    ----------
    target_prop : str
        Name of the reference column; predictions are read from
        ``"predicted_" + target_prop``.
    experimental_settings : list of dict
        Each dict provides ``model_type``, ``interpolation``, ``relaxed`` and
        ``exp_id``.
    series : list
        Entries of the form ``[[a0, a1], [b0, b1]]`` with element symbols.
    num_best_models : int, optional
        Number of ``best_<i>`` prediction files averaged per setting.

    Returns
    -------
    None
    """
    test_set_type = "holdout_set_series"
    ylim = [-0.2, 1]
    # Raw string: keeps the mathtext backslash without an invalid "\m" escape.
    column_entry = r"$E_{\mathrm{hull}}$ (eV/atom)"
    # NOTE: sns.set updates global plotting defaults, so it also affects
    # figures drawn after this call.
    sns.set(rc={'figure.figsize': (4, 2)})

    for experimental_setting in experimental_settings:
        sigopt_name = build_sigopt_name(target_prop, experimental_setting["relaxed"], experimental_setting["interpolation"], experimental_setting["model_type"])
        exp_id = str(experimental_setting["exp_id"])
        directory = "./best_models/" + experimental_setting["model_type"] + "/" + sigopt_name + "/" + exp_id

        test_set_dfs = []
        for i in range(num_best_models):
            with open(directory + "/" + "best_" + str(i) + "/" + test_set_type + "_predictions.json") as f:
                test_set_dfs.append(pd.read_json(f))

        # The true-value plots are produced once, alongside the final setting.
        is_last_setting = experimental_setting == experimental_settings[-1]

        for entry in series:
            a_site, b_site = entry[0], entry[1]
            elements = [a_site[0], a_site[1], b_site[0], b_site[1]]
            column_conc = a_site[0] + " on A site"
            comp_string = a_site[0] + "$_x$" + a_site[1] + "$_{1-x}$" + b_site[0] + "$_{0.5}$" + b_site[1] + "$_{0.5}$O$_3$"

            # One (concentration, prediction) table per ensemble member.
            to_plots = [
                _concentration_frame(
                    _series_rows(df, elements), a_site[0],
                    'predicted_' + target_prop, column_conc, column_entry,
                )
                for df in test_set_dfs
            ]

            # Concentrations come from the first member; predictions are the
            # row-wise mean over all members.
            to_plot_final = pd.DataFrame()
            to_plot_final[column_conc] = to_plots[0][column_conc]
            to_plot_final[column_entry] = pd.concat(
                [frame[[column_entry]] for frame in to_plots], axis=1
            ).mean(axis=1)

            _save_violin(
                to_plot_final, column_conc, column_entry, ylim, comp_string,
                "figures/" + test_set_type + "_" + sigopt_name + "_" + exp_id + "_series_" + comp_string + ".pdf",
            )
            print("Completed " + sigopt_name + " " + exp_id)

            if is_last_setting:
                true_to_plot = _concentration_frame(
                    _series_rows(test_set_dfs[0], elements), a_site[0],
                    target_prop, column_conc, column_entry,
                )
                _save_violin(
                    true_to_plot, column_conc, column_entry, ylim, comp_string,
                    "figures/" + test_set_type + "_true_values_series_" + comp_string + ".pdf",
                )
                print("Completed true values")

In [4]:
# plot_hex_all(target_prop, test_set_types, experimental_settings)

In [5]:
# plot_violin(target_prop, experimental_settings, series)  # returns None; figures are written to figures/