In [1]:
import numpy as np
import pandas as pd
import json
import pickle
from sklearn.metrics import r2_score, mean_absolute_error
from scipy.stats import pearsonr
import copy

##### load for plotting (with plotly)

In [2]:
import os

from _plotly_future_ import v4_subplots
import plotly.graph_objs as go
import plotly
import plotly.express as px
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
import plotly.figure_factory as ff

# The orca path below is specific to one developer machine. Allow it to be
# overridden through the ORCA_EXECUTABLE environment variable so the notebook
# can run elsewhere; the original location is kept as the fallback.
plotly.io.orca.config.executable = os.environ.get(
    "ORCA_EXECUTABLE",
    '/Users/chenruduan/opt/anaconda3/envs/mols_newplotly/bin/orca',
)
init_notebook_mode(connected=True)

# Both axes use the same frame style, so it is defined once and copied.
_axis_style = dict(
    showgrid=False, zeroline=False, ticks="inside", showline=True,
    tickwidth=1.5, linewidth=1.5, ticklen=10, linecolor='black',
    mirror="allticks", color="black",
)

# Shared figure layout that every plot in this notebook merges into its own layout.
glob_layout = go.Layout(
    font=dict(family='Helvetica', size=24, color='black'),
    margin=dict(l=100, r=10, t=10, b=100),
    xaxis=dict(_axis_style),
    yaxis=dict(_axis_style),
    legend_orientation="v",
    # The alpha channel is a 0-1 fraction; the previous value of 100 was out of range.
    paper_bgcolor='rgba(255,255,255,1)',
    plot_bgcolor='white',
)

# Named colors reused by the plots below.
blue = "rgba(0, 0, 255, 1)"
red = "rgba(255, 0, 0, 1)"
green = "rgba(0, 196, 64, 1)"
gray = "rgba(140, 140, 140, 1)"

##### load for ML parts 

In [3]:
from dfa_recommender.net import GatedNetwork, MySoftplus, TiledMultiLayerNN, MLP, finalMLP, ElementalGate
from dfa_recommender.dataset import SubsetDataset
from dfa_recommender.sampler import InfiniteSampler
from dfa_recommender.ml_utils import numpy_to_dataset
import torch
from torch.utils.data import DataLoader

##### DFAs that we considered

In [26]:
# Metadata columns that identify a complex (as opposed to per-functional columns).
base_keys = ["name", "path", "metal"]

# Standard functionals, listed in the order used throughout the notebook.
_standard_dfas = (
    "bp86", "blyp", "pbe",
    "tpss", "scan", "m06-l", "mn15-l",
    "b3p86", "b3pw91", "b3lyp",
    "tpssh", "scan0", "m06", "m06-2x",
    "wb97x", "LRC-wPBEh",
    "b2gpplyp", "pbe0-dh", "dsd-blyp-d3bj", "dsd-pbeb95-d3bj", "dsd-pbep86-d3bj",
)

# HF-exchange-tuned variants, named "<base>_hfx_<percent>". Only the listed
# percentages are included for each base functional.
_hfx_percentages = {
    "blyp": (10, 20, 30, 40, 50),
    "pbe": (10, 20, 40, 50),
    "scan": (10, 20, 30, 40, 50),
    "m06-l": (10, 30, 40, 50),
    "mn15-l": (10, 20, 30, 50),
}

functionals = list(_standard_dfas)
for _base, _percents in _hfx_percentages.items():
    functionals.extend("%s_hfx_%d" % (_base, _pct) for _pct in _percents)

# Independent copy: `functionals` is reassigned later, `all_functionals` keeps the full set.
all_functionals = list(functionals)

### Basic ground truth stats

In [6]:
from pkg_resources import resource_filename, Requirement
basepath = resource_filename(Requirement.parse("dfa_recommender"), "/dfa_recommender/data/")

In [7]:
## csv file that stores the computed vert SSE values at different methods (CSD data)
df = pd.read_csv(basepath + "/fe2_h2o_cs_res.csv")

# Number of CS ligands per complex = how many times "CS" occurs in its name.
nCS = [complex_name.count("CS") for complex_name in df["name"].values]
df["nCS"] = nCS
nCS

[6, 5, 4, 4, 3, 3, 2, 2, 1, 0]

##### sanity check on ligand fields

In [27]:
# Methods whose raw vertical SSE is plotted against the number of CS ligands.
fs = ["dlpno-CCSD_T", "blyp", "b3lyp", "blyp_hfx_50"]

# One marker-only trace per method; hover text shows the complex name.
data = [
    go.Scatter(
        x=df["nCS"].values,
        y=df[method + ".vertsse"].values,
        mode='markers',
        opacity=1,
        marker=dict(symbol='circle', size=8),
        text=df["name"].values,
        showlegend=True,
        name=method,
    )
    for method in fs
]

# Start from the shared notebook layout, then apply figure-specific settings.
layout = go.Layout()
layout.update(glob_layout)
layout.legend.update(x=1, y=1, bgcolor="rgba(0,0,0,0)")
layout.xaxis.update(title="number of CS")
layout.yaxis.update(title="vertSSE (kcal/mol)")
layout.update(height=500, width=700, showlegend=True)

fig = go.Figure(data=data, layout=layout)
iplot(fig)

##### MAEs of different DFAs for LS $Fe_2(CS)_n(H_2O)_{6-n}$

In [28]:
x, y = [], []

# Mean absolute value of the "delta.<DFA>.vertsse" column for every functional,
# ignoring NaN entries.
mae_dict = {
    dfa: np.nanmean(np.abs(df["delta.%s.vertsse" % dfa].values))
    for dfa in functionals
}
# Re-order from lowest to highest MAE.
mae_dict = dict(sorted(mae_dict.items(), key=lambda pair: pair[1]))

# Only the 20 functionals with the lowest MAE are shown.
_lowest = list(mae_dict.items())[:20]
data = [go.Bar(x=[dfa for dfa, _ in _lowest],
               y=[mae for _, mae in _lowest], name='all'),]

layout = go.Layout()
layout.update(glob_layout)
layout.xaxis.update(title="DFA")
layout.yaxis.update(title="MAE (kcal/mol)")
layout.update(width=1000, height=500, boxmode='group')

fig = dict(data=data, layout=layout)
iplot(fig)

##### Top DFAs

In [29]:
# Functionals highlighted in this figure.
fs = ["m06-l_hfx_30", "blyp_hfx_50", "scan_hfx_40", "pbe0-dh"]

# Reference method; every trace is plotted as (DFA - reference).
ref = "dlpno-CCSD_T"
ref_vertsse = df[ref + ".vertsse"].values

data = []
for f in fs:
    trace0 = go.Scatter(
        x=df["nCS"].values,
        y=df[f + ".vertsse"].values - ref_vertsse,
        mode='markers',
        opacity=0.7,
        marker=dict(
            symbol='circle',
            size=8,
        ),
        text=df["name"].values,
        showlegend=True,
        name=f,
    )
    data += [trace0]

# Reference line (reference minus itself). It previously reused the leaked loop
# variable `f` as its name, so hovering over it showed the last functional
# ("pbe0-dh") instead of the reference method.
trace0 = go.Scatter(
    x=df["nCS"].values,
    y=ref_vertsse - ref_vertsse,
    mode='markers+lines',
    opacity=1,
    marker=dict(
        symbol='circle',
        size=8,
        color="black"
    ),
    text=df["name"].values,
    showlegend=False,
    name=ref,
)
data += [trace0]

layout = go.Layout()
layout.update(glob_layout)
layout.legend.update(x=1, y=1, bgcolor="rgba(0,0,0,0)")
layout['xaxis'].update({"title": "number of CS"})
# The plotted quantity is the deviation from the reference, not the raw vertSSE.
layout['yaxis'].update({"title": "vertSSE error (kcal/mol)",})
layout.update(height=500, width=700, showlegend=True)
fig = go.Figure(data=data, layout=layout)
iplot(fig)

### Predict the vertical spin-splitting energy on the controlled fe-h2o-cs set

In [11]:
# Fix the random state of both numpy and torch so the run is repeatable.
_seed = 0
np.random.seed(_seed)
torch.manual_seed(_seed)

# Inference runs on the CPU with a small thread pool and no loader workers.
torch.set_num_threads(4)
num_workers = 0
device = torch.device('cpu')

##### set path for relevant data files: the previous model is tested directly on the out-of-distribution CSD complexes that have more diverse ligands and connectivities

In [13]:
# NOTE: pickle.load can execute arbitrary code; only load files shipped with
# the dfa_recommender package. Files are opened with `with` so the handles are
# closed right after loading (the original left them open).
with open(basepath +  "/X_fe2_h2o_cs.pkl", "rb") as fin:
    X = pickle.load(fin) ## features
df_org = pd.read_csv(basepath + "/labeled_res.csv") ## csv file that stores the computed vert SSE values at different methods (self-assembled complexes, used in training)
# The scalers are indexed by functional name further down. Unpickling them
# emits a scikit-learn version-mismatch warning when the installed version
# differs from the one used to save them.
with open(basepath +  "/abs-reg-y_scalers.pkl", "rb") as fin:
    y_scalers = pickle.load(fin) ## sklearn.preprocessing.StandardScaler object created on the stats of training data


Trying to unpickle estimator StandardScaler from version 0.24.2 when using version 1.0.2. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to:
https://scikit-learn.org/stable/modules/model_persistence.html#security-maintainability-limitations



##### predict |DFA - DLPNO-CCSD(T)| vertical spin splitting

In [14]:
# For every functional, run its saved model on the fe-h2o-cs features and store
# the true and predicted |DFA - DLPNO-CCSD(T)| values under "<f>.y_t" and "<f>.y_hat".
res_all = {}
for f in all_functionals:
    y_scaler = y_scalers[f]

    # Ground truth: absolute deviation, transformed with the scaler for this functional.
    y_t = np.abs(df["delta.%s.vertsse"%f].values)
    y_t = y_scaler.transform(y_t.reshape(-1, 1)).reshape(-1, )
    # Two identical label columns are stacked; only column 0 is read below.
    # (presumably the layout numpy_to_dataset expects -- TODO confirm)
    y_t = np.stack((y_t, y_t), axis=-1)
    data_te = numpy_to_dataset(X, y_t, regression=True) ## org

    # A single batch holding the whole test set.
    te_l = SubsetDataset(data_te, list(range(len(data_te))))
    te_loader = DataLoader(te_l, len(te_l), num_workers=num_workers)

    # NOTE: pickle.load can execute arbitrary code; only load packaged model files.
    # `with` closes the file on every iteration (the original leaked the handle).
    with open(basepath + "/models-trends/mergedG10-abs-reg-%s.pkl"%f, "rb") as fin:
        best_model = pickle.load(fin)
    best_model.eval()

    preds_r, labels_r = [], []
    with torch.no_grad():
        for xb, yb in te_loader:
            _pred_r = best_model(xb.to(device))
            preds_r.append(_pred_r.cpu().numpy())
            labels_r.append(yb[:, 0].cpu().numpy())

    # Concatenate instead of taking element [0] so no batch is silently dropped
    # if the loader ever yields more than one; then undo the label scaling.
    y_t = y_scaler.inverse_transform(np.concatenate(labels_r).reshape(-1, 1)).reshape(-1, )
    y_hat = y_scaler.inverse_transform(np.concatenate(preds_r).reshape(-1, 1)).reshape(-1, )

    # Per-functional diagnostics (only used by the optional print below).
    mae = mean_absolute_error(y_t, y_hat)
    scaled_mae = mae/(np.max(y_t) - np.min(y_t))
    R2 = r2_score(y_t, y_hat)
    rval = pearsonr(y_t, y_hat)[0]
    # print(f, "mae: ", round(mae, 5), "scaled mae: ", round(scaled_mae, 5), "R2: ", round(R2, 4), "r val: ", round(rval, 4))

    res_all[f + ".y_t"] = np.abs(y_t)
    res_all[f + ".y_hat"] = np.abs(y_hat)
res_all["name"] = df["name"]

### Analyze

##### sort based on ML predicted |DFA - DLPNO-CCSD(T)| vertical spin splitting to select DFAs

In [15]:
# One column per "<functional>.y_t" / "<functional>.y_hat" entry plus "name".
df_res = pd.DataFrame(res_all)

In [30]:
# Functionals to exclude from the selection (none by default).
removed = []
# Filter while keeping the order of `all_functionals`. The original round-trip
# through `set` made the order depend on string hashing, so it could differ
# between kernel sessions. dict.fromkeys keeps the de-duplication the set gave.
functionals = list(dict.fromkeys(f for f in all_functionals if f not in removed))

# Not read in this cell; kept so the cell still defines the same names as before.
thresh = 0.0
lmstds = []
err = {f: [] for f in functionals}

errs_t, errs_hat, best_fs, lower_bound, true_best_fs, ranks = [], [], [], [], [], []
for _, row in df_res.iterrows():
    # Functionals without a ground-truth value get an infinite error so they can
    # never be ranked first. np.inf replaces the magic sentinel 1000, which would
    # have out-ranked any real error above 1000.
    labeled = {f: not np.isnan(row[f + ".y_t"]) for f in functionals}
    true_pairs = sorted(
        (row[f + ".y_t"] if labeled[f] else np.inf, f) for f in functionals
    )
    pred_pairs = sorted(
        (row[f + ".y_hat"] if labeled[f] else np.inf, f) for f in functionals
    )

    picked = pred_pairs[0][1]      # functional with the lowest predicted error
    true_best = true_pairs[0][1]   # functional with the lowest true error
    picked_true_err = row[picked + ".y_t"]

    # Rank = position of the picked functional's true error in the sorted true
    # errors (0 means the picked functional is also the truly best one).
    ranks.append([value for value, _ in true_pairs].index(picked_true_err))
    errs_t.append(picked_true_err)
    errs_hat.append(row[picked + ".y_hat"])
    best_fs.append(picked)
    lower_bound.append(row[true_best + ".y_t"])
    true_best_fs.append(true_best)

df_sel = pd.DataFrame.from_dict({"errs_t": errs_t, "errs_hat": errs_hat,
                                 "best_fs": best_fs, "lower_bound": lower_bound, "true_best_fs": true_best_fs,
                                 "ranks": ranks, "name": df_res["name"].values,
                                })

In [31]:
# Per-complex selection table: the DFA with the lowest predicted error (best_fs)
# next to the DFA with the lowest true error (true_best_fs).
df_sel

Unnamed: 0,errs_t,errs_hat,best_fs,lower_bound,true_best_fs,ranks,name
0,0.173648,0.544491,pbe0-dh,0.173648,pbe0-dh,0,fe_oct_2_CS_1_CS_1_CS_1_CS_1_CS_1_CS_1_s_1_conf_1
1,2.693905,1.127551,scan_hfx_40,0.649002,pbe0-dh,5,fe_oct_2_CS_1_CS_1_CS_1_CS_1_CS_1_water_1_s_1_...
2,1.353778,1.357697,pbe0-dh,0.374734,blyp_hfx_50,2,fe_oct_2_CS_1_CS_1_CS_1_CS_1_water_1_water_1_s...
3,1.01938,1.266669,pbe0-dh,0.890124,blyp_hfx_50,1,fe_oct_2_CS_1_CS_1_CS_1_water_1_CS_1_water_1_s...
4,1.54918,1.15949,m06-l_hfx_30,0.251791,blyp_hfx_50,4,fe_oct_2_water_1_water_1_CS_1_CS_1_water_1_CS_...
5,0.167158,1.302298,blyp_hfx_50,0.167158,blyp_hfx_50,0,fe_oct_2_water_1_water_1_water_1_CS_1_CS_1_CS_...
6,0.982086,1.247723,scan_hfx_40,0.695455,mn15-l_hfx_50,2,fe_oct_2_water_1_water_1_water_1_CS_1_water_1_...
7,0.65416,1.131048,scan_hfx_40,0.65416,scan_hfx_40,0,fe_oct_2_water_1_water_1_water_1_water_1_CS_1_...
8,0.505398,1.034265,scan_hfx_40,0.359566,m06,1,fe_oct_2_water_1_water_1_water_1_water_1_water...
9,0.303311,1.291487,scan_hfx_40,0.303311,scan_hfx_40,0,fe_oct_2_water_1_water_1_water_1_water_1_water...


##### absolute error distribution

In [32]:
# Distribution of the true absolute error of the functional picked for each complex.
hist_data = [df_sel['errs_t'].values]
group_labels = [""]
colors = ['black', blue, green, red]
fig = ff.create_distplot(
    hist_data,
    group_labels,
    show_hist=True,
    colors=colors,
    bin_size=0.5,
)

# Figure-specific settings first, then the shared notebook layout, then axis titles.
layout = go.Layout()
layout.legend.update(x=.5, y=1, bgcolor="rgba(0,0,0,0)")
layout.update(showlegend=False)
layout.update(width=550, height=500)
layout.update(glob_layout)
layout.xaxis.update(title="abs. err. (kcal/mol)")
layout.yaxis.update(title="frequency")
fig.layout.update(layout)
fig.show()

# Mean absolute error of the picked functionals over all complexes.
np.mean(df_sel['errs_t'])

0.9402004420757294

##### DFA ranks

In [33]:
# Number of rank bins on the x axis.
# NOTE(review): 48 is larger than the 43 entries listed in `functionals`, so the
# last bins are always empty -- confirm whether this should follow len(functionals).
_n_ranks = 48
xs = list(range(_n_ranks))

# Percentage of complexes whose picked functional has true rank r.
_n_complexes = len(df_sel)
y = [len(df_sel[df_sel["ranks"] == r]) * 100. / _n_complexes for r in xs]
data = [go.Bar(x=list(range(_n_ranks)),
               y=y, name='all', marker_color='rgba(0, 0, 0, 0.5)', showlegend=False),]

# ys[k] is the cumulative percentage of all ranks strictly below k.
_total = np.sum(y)
ys = [np.sum(y[:k]) * 100. / _total for k in xs]

# Draw the cumulative curve on the secondary axis as a staircase: a vertical
# riser at each rank followed by a horizontal tread to the next rank.
_step_style = dict(mode='lines', yaxis="y2",
                   line=dict(color='blue', width=2, dash='solid'),
                   showlegend=False)
for left in range(_n_ranks - 1):
    right = left + 1
    data.append(go.Scatter(x=[xs[left], xs[left]], y=[ys[left], ys[right]], **_step_style))
    data.append(go.Scatter(x=[xs[left], xs[right]], y=[ys[right], ys[right]], **_step_style))

layout = go.Layout()
layout.update(glob_layout)
layout.legend.update(x=1, y=1, bgcolor="rgba(0,0,0,0)")
layout.xaxis.update(title="DFA rank")
# Ticks are not mirrored on the left axis because the right side carries yaxis2.
layout.yaxis.update(title="percentage", mirror=False)
layout.update(yaxis2=dict(
    title="cumulative percentage",
    titlefont=dict(color="black"),
    tickfont=dict(color="black"),
    anchor="free",
    overlaying="y",
    side="right",
    position=1,
    range=[0, 100],
    showgrid=True,
    zeroline=True,
    ticks="inside",
    showline=True,
    tickwidth=3,
    linewidth=3,
    ticklen=10,
))
layout.update(width=600, height=500, boxmode='group')

fig = dict(data=data, layout=layout)
iplot(fig)