In [1]:
from copy import deepcopy
from math import isnan, nan
from matplotlib import pyplot as plt
from pathlib import Path
from plotly.subplots import make_subplots
from pymatgen.core.structure import Structure
from pymatgen.ext.matproj import MPRester
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
from scipy.stats import gaussian_kde, spearmanr
from sklearn.metrics import mean_absolute_error, root_mean_squared_error, mean_absolute_percentage_error

import copy
import json
import numpy as np
import os
import pandas as pd
import plotly.graph_objs as go

In [2]:
# Import sys and set the path of the customized modules
import sys
sys.path.append('/home/vtrinquet/Documents/Doctorat/JNB_Scripts_Clusters/Custom_Modules/')
sys.path.append('/home/vtrinquet/Documents/Doctorat/JNB_Scripts_Clusters/NLO/HT/shg/acquisition')

# Import importlib.reload to be able to reload dynamically modified modules
from importlib import reload

# Import custom modules (reloadable when modified)
import SHG_Tensor_Func as shg
shg = reload(shg)

In [3]:
cur_v = int(Path(os.getcwd()).parent.name.split("_")[0])
print(cur_v)

19


# Process successfully computed

In [4]:
#   Paths to the candidates datasets
candidate_df_path    = Path(f"../../{cur_v-1}_al_it/data/df_selected_v{cur_v-1}.pkl.gz")
    
#   Loading the candidates datasets
df_prev_selected = pd.read_pickle(candidate_df_path)
print(df_prev_selected.shape)
display(df_prev_selected.head())

(519, 11)


Unnamed: 0,dKP,dKP_unc,dKP_unc_cal_ma,dKP_unc_cal_nll,structure,nsites,ehull,bandgap,spg,reduced_formula,added_from
agm2000129361,0.031509,0.220027,0.318939,4.337121,"{'@module': 'pymatgen.core.structure', '@class...",6,0.008973,7.3598,189,PF5,Pareto
agm003738833,0.096076,0.130822,0.189632,2.578731,"{'@module': 'pymatgen.core.structure', '@class...",12,0.027481,7.3123,30,LiAlF4,Pareto
agm002138386,0.112327,0.171703,0.248891,3.384572,"{'@module': 'pymatgen.core.structure', '@class...",8,0.0,7.2732,5,KBe2F5,Pareto
agm005052162,0.142544,0.138844,0.201261,2.736868,"{'@module': 'pymatgen.core.structure', '@class...",7,0.014591,6.9649,82,NaLiBeF4,Pareto
agm1000019892,0.202417,0.222,0.321798,4.376006,"{'@module': 'pymatgen.core.structure', '@class...",18,0.013705,6.8742,1,CO2,Pareto


In [5]:
path_df_outputs = Path(f'../data/df_outputs.pkl.gz')
str_origin = f"hg v{cur_v-1}"
if path_df_outputs.exists():
    df_outputs = pd.read_pickle(path_df_outputs)
    df_tmp = df_outputs[df_outputs['origin']!=str_origin]
    print(f"{df_tmp.shape = }")
    df_tmp = df_outputs[df_outputs['origin']==str_origin]
    print(f"{df_tmp.shape = }")
    print(f"{df_outputs.shape = }")
else:
    # Load previous json outputs
    path_json_output = (Path('..') / ".." / f"{cur_v-1}_al_it" / "data" / f'outputs_v{cur_v-1}.json')

    # decompress json file
    os.system(f"gunzip {str(path_json_output)}.gz")
    
    # read file
    with open(path_json_output, 'r') as myfile:
        data=myfile.read()
    
    # parse file
    list_dict_outputs = json.loads(data)

    # recompress json file
    os.system(f"gzip {str(path_json_output)}")

    # Get basic outputs
    list_matid = []
    list_structures = []
    list_dijk = []
    list_epsij = []
    for output in list_dict_outputs:
        list_matid.append(output['metadata']['material_id'])
        list_structures.append(Structure.from_dict(output['output']["output"]['structure']).as_dict())
        list_dijk.append(output['output']['output']['dijk'])
        list_epsij.append(output['output']['output']['epsinf'])
    
    df_prev_outputs = pd.DataFrame()
    df_prev_outputs.index = list_matid
    df_prev_outputs['structure'] = list_structures
    df_prev_outputs['dijk'] = list_dijk
    df_prev_outputs['epsij'] = list_epsij
    
    # print(df_prev_outputs.shape)
    # display(df_prev_outputs.head())

    # Every structure-related fields
    
    list_formulae = []
    list_crystal_system = []
    list_elements = []
    list_nelements = []
    list_nsites = []
    list_pg_symbol = []
    list_spg_number = []
    list_spg_symbol = []
    for ir, r in df_prev_outputs.iterrows():
        struc = Structure.from_dict(r['structure'])
    
        list_formulae.append(struc.composition.reduced_formula)
        list_crystal_system.append(SpacegroupAnalyzer(structure=struc, symprec=1e-3).get_lattice_type())
        list_elements.append([str(elem) for elem in struc.composition.elements])
        list_nelements.append(len(struc.composition.elements))
        list_nsites.append(len(struc))
        list_pg_symbol.append(SpacegroupAnalyzer(structure=struc, symprec=1e-3).get_point_group_symbol())
        list_spg_number.append(struc.get_space_group_info()[1])
        list_spg_symbol.append(struc.get_space_group_info()[0])
    
    df_prev_outputs['formula_reduced'] = list_formulae
    df_prev_outputs['crystal_system'] = list_crystal_system
    df_prev_outputs['elements'] = list_elements
    df_prev_outputs['nelements'] = list_nelements
    df_prev_outputs['nsites'] = list_nsites
    df_prev_outputs['pg_symbol'] = list_pg_symbol
    df_prev_outputs['spg_number'] = list_spg_number
    df_prev_outputs['spg_symbol'] = list_spg_symbol
    
    # Every dijk-, epsij-related fields
    
    list_dRMS = []
    list_dKP = []
    list_inv2 = []
    list_inv3 = []
    list_refractive_index = []
    for ir, r in df_prev_outputs.iterrows():
        list_dRMS.append(shg.get_dRMS(r['dijk']))
        list_dKP.append(shg.get_dKP_weird(r['dijk']))
        _, inv2, inv3 = shg.get_invariants(np.array(r['dijk']))
        list_inv2.append(inv2**(1/6))
        list_inv3.append(inv3**(1/6))
    
        eps_ij = np.array(r['epsij'])
        try:
            list_refractive_index.append(np.trace(eps_ij**0.5)/3)
        except TypeError:
            list_refractive_index.append(np.nan)
    
    df_prev_outputs['dRMS'] = list_dRMS
    df_prev_outputs['dKP'] = list_dKP
    df_prev_outputs['dinv2'] = list_inv2
    df_prev_outputs['dinv3'] = list_inv3
    df_prev_outputs['n'] = list_refractive_index

    # Every src-related fields
    
    with MPRester() as mpr:
        mp_bandgap = []
        mp_ehull = []
        mp_is_gap_direct = []
        mp_is_magnetic = []
        mp_n = []
        mp_database_IDs = []
        mp_theoretical = []
        mp_epsij = []
        list_src = []
        list_origin = []
        for ir, r in df_prev_outputs.iterrows():
            if 'mp' in ir:
                docs = mpr.materials.summary.search(material_ids=[ir], 
                                                    fields=[
                                                            'band_gap',
                                                            # 'energy_above_hull',
                                                            'is_gap_direct',
                                                            'is_magnetic',
                                                            'n',
                                                            'database_IDs',
                                                            'theoretical',
                                                            ])[0]
                docs_dielec = mpr.materials.dielectric.search(material_ids=[ir],
                                                              fields=['electronic',])
                try:
                    gap_el_bs = mpr.get_bandstructure_by_material_id(ir).get_band_gap()['energy']
                except Exception:
                    gap_el_bs = float('nan')
                try:
                    gap_dos = mpr.get_dos_by_material_id(ir).get_gap()
                except Exception:
                    gap_dos = float('nan')
            
                list_src.append('Materials Project')
                # mp_ehull.append(docs.energy_above_hull)
                mp_is_gap_direct.append(docs.is_gap_direct)
                mp_is_magnetic.append(docs.is_magnetic)
                mp_n.append(docs.n)
                mp_database_IDs.append(docs.database_IDs)
                mp_theoretical.append(docs.theoretical)
                try:
                    mp_epsij.append(np.array(docs_dielec[0].electronic))
                except IndexError:
                    mp_epsij.append(float('nan'))
                gaps_to_consider = []
                if not isnan(gap_dos) and gap_dos!=0:
                    gaps_to_consider.append(gap_dos)
                if not isnan(gap_el_bs) and gap_el_bs!=0:
                    gaps_to_consider.append(gap_el_bs)
                if docs.band_gap!=0:
                    gaps_to_consider.append(docs.band_gap)
                if len(gaps_to_consider)>0:
                    mp_bandgap.append(min(gaps_to_consider))
                else:
                    mp_bandgap.append(0)


            elif 'agm' in ir:
                list_src.append('Alexandria')
                # mp_ehull.append(df_agm.loc[ir]["ehull"])
                # mp_is_gap_direct.append(df_agm.loc[ir]["bandgap_direct"]==df_agm.loc[ir]['bandgap'])
                mp_is_gap_direct.append(nan)
                mp_is_magnetic.append(nan)
                mp_n.append(nan)
                mp_database_IDs.append(nan)
                mp_theoretical.append(nan)
                mp_epsij.append(nan)
                mp_bandgap.append(df_prev_selected.loc[ir]["bandgap"])


            mp_ehull.append(df_prev_selected.loc[ir]["ehull"])
    
    
    df_prev_outputs['src_bandgap'] = mp_bandgap
    df_prev_outputs['src_ehull'] = mp_ehull
    df_prev_outputs['src_is_gap_direct'] = mp_is_gap_direct
    df_prev_outputs['src_is_magnetic'] = mp_is_magnetic
    df_prev_outputs['src_n'] = mp_n
    df_prev_outputs['src_DB_IDs'] = mp_database_IDs
    df_prev_outputs['src_theoretical'] = mp_theoretical
    df_prev_outputs['src_epsij'] = mp_epsij
    
    df_prev_outputs['src'] = list_src
    df_prev_outputs['origin'] = [f'hg v{cur_v-1}']*len(df_prev_outputs)
    
    print(df_prev_outputs.shape)
    display(df_prev_outputs.head())

    # Let's load the outputs of the previous selected compounds and recover all the relevant data into one dataframe
    path_df_prev_outputs = (Path(f'../../{cur_v-1}_al_it') / "data" / f'df_outputs.pkl.gz')
    df_outputs = pd.read_pickle(path_df_prev_outputs)

    print(df_outputs.shape)
    display(df_outputs.head())

    # Combine
    print(f"{df_prev_outputs.shape = }")
    print(f"{df_outputs.shape = }")
    df_outputs = pd.concat([df_outputs, df_prev_outputs])
    print(f"{df_outputs.shape = }")


    df_outputs.to_pickle(path_df_outputs)

df_tmp.shape = (2456, 26)
df_tmp.shape = (320, 26)
df_outputs.shape = (2776, 26)


# Visualize successfully computed

In [6]:
# dKP-Eg-(ns)-KDE with log-lin inset


# ==============================================================================================================================
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.02, column_widths=[0.8,0.2])

# ROW 1 COL 1 =================================================================================================================
fig.add_trace(go.Scatter(x=df_outputs[df_outputs['origin']!=str_origin]['src_bandgap'],
                         y=df_outputs[df_outputs['origin']!=str_origin]['dKP'],
                         mode='markers',
                         marker = dict(color = df_outputs[df_outputs['origin']!=str_origin]['n'],
                                       showscale=True,
                                       colorbar = dict(title = '<i>n<sub>s</sub></i>', 
                                                       title_font_size = 36,
                                                       orientation = 'v',
                                                    #    tickvals=[1,2,3,4,5,6],
                                                       x = 0.9,
                                                       y=0.6, 
                                                       len = 0.9),
                                       colorscale = 'plasma', 
                                       ),
                         showlegend=True,
                         name=f'Old',
                         text=[mpid for mpid in df_outputs[df_outputs['origin']!=str_origin].index.values]
                         ))

fig.add_trace(go.Scatter(x=df_outputs[df_outputs['origin']==str_origin]['src_bandgap'],
                         y=df_outputs[df_outputs['origin']==str_origin]['dKP'],
                         mode='markers',
                         marker = dict(symbol='diamond',
                                       size=8,
                                       color = 'lightgreen',
                                       showscale=False,
                                    #    colorbar = dict(title = '<i>n<sub>s</sub></i>', 
                                    #                    title_font_size = 24,
                                    #                    title_side = 'top',
                                    #                    orientation = 'h',
                                    #                    tickvals=[1,2,3,4,5,6],
                                    #                    x = 0.75,
                                    #                    y=0.75, 
                                    #                    len = 0.4),
                                    #    colorscale = ['lightgray','orange','aqua','lightgreen','white','pink'],
                                       line=dict(width=2, color="black")
                                       ),
                         showlegend=True,
                         name=f'v{cur_v-1} outputs',
                         text=[mpid for mpid in df_outputs[df_outputs['origin']==str_origin].index.values]
                         ))


# AXES
fig.update_xaxes(title = '<i>E<sub>g</sub></i> (eV)',
                 title_font_size=36,
                 range = [-0.1, 8.4],
                 row=1, col=1)
fig.update_yaxes(title = '<i>d</i><sub>KP</sub> (pm/V)',
                 title_font_size=36)


fig.update_layout(font={'family':'Arial', 'size': 20},
                 )

# ROW 1 COL 2 =================================================================================================================
density = gaussian_kde(np.reshape(df_outputs['dKP'].values, (1,len(df_outputs))))
density.covariance_factor = lambda : .02 #Smoothing parameter
density._compute_covariance()

x_vals = np.linspace(min(df_outputs['dKP'].values),
                     max(df_outputs['dKP'].values),
                     200) # Specifying the limits of our data
kde_dist = density(x_vals)

fig.add_trace(go.Scatter(x=kde_dist, 
                         y=x_vals, 
                         mode='lines', 
                         marker_color='indianred',
                         fill='tozerox',
                         showlegend=False),
              row=1, col=2)

fig.update_xaxes(title = 'Distribution',
                 title_font_size=36,
                 row=1, col=2)
fig.update_yaxes(title = '', row=1, col=2)

# INSET =======================================================================================================================
inset = copy.deepcopy(fig.data[0])
inset.xaxis = 'x3'
inset.yaxis = 'y3'

inset_candidates_v1 = copy.deepcopy(fig.data[1])
inset_candidates_v1.xaxis = 'x3'
inset_candidates_v1.yaxis = 'y3'
inset


fig.update_layout(
    xaxis3=dict(domain      = [0.30, 0.75],
                anchor      = 'y3',
                range       = [-0.1, 8.4],
                linecolor   = 'black'
               ),
    yaxis3=dict(domain      = [0.60, 0.98],
                anchor      = 'x3',
                range       = [-3.5, 2.6],
                type        = 'log',
                tickvals    = [0.001, 0.01, 1, 100],
                linecolor   = 'black'
    ))

fig.add_trace(inset)
fig.add_trace(inset_candidates_v1)

fig.update_xaxes(showspikes=True, spikecolor="gray", spikethickness=2, spikesnap="cursor", spikemode="across")
fig.update_yaxes(showspikes=True, spikecolor="gray", spikethickness=2, spikesnap="cursor", spikemode='across')
fig.update_layout(hoverdistance=5)

# THEME OF GRAPH
fig.update_layout(template='simple_white')

fig.update_layout(
    width=1000,
    height=500
)

fig.show()

name_fig = "dKP-Eg_KDE_cur_res"
if not Path(f'../figures/pdf/{name_fig}.pdf').exists():
    fig.write_image(f'../figures/pdf/{name_fig}.pdf')
    fig.write_image(f'../figures/svg/{name_fig}.svg')
    fig.write_image(f'../figures/png/{name_fig}.png', scale=10)

## Remove outliers and visualize

Let's remove the instances with NaN and potential outliers

In [7]:
lim_up   = 170
lim_down = 0.00001
lim_gap_down = 0.00001 # eV
# lim_ehull_up = 0.025 # eV/atom
lim_n_up = 20

# Outliers
outliers = df_outputs[(df_outputs['dKP']>lim_up) |                 \
                      (df_outputs['dKP']<=lim_down) |              \
                      (df_outputs['src_bandgap']<lim_gap_down) |   \
                      (df_outputs['n']>lim_n_up) #|  \
                    #   (df_outputs['src_ehull']>lim_ehull_up)
                      ]
print(f'{outliers.shape = }')
display(outliers.head())

path_df_outputs_filtout = Path(f'../data/df_outputs_filtout.pkl.gz')
if path_df_outputs_filtout.exists():
    df_outputs_filtout = pd.read_pickle(path_df_outputs_filtout)
else:
    df_outputs_filtout = df_outputs.drop(outliers.index.values)
    df_outputs_filtout.to_pickle(path_df_outputs_filtout)

print(f'{df_outputs.shape = }')
print(f'{df_outputs_filtout.shape = }')

outliers.shape = (258, 26)


Unnamed: 0,formula_reduced,crystal_system,dRMS,dijk,elements,epsij,src_bandgap,src_DB_IDs,src_ehull,src_epsij,...,nsites,pg_symbol,spg_number,spg_symbol,structure,dKP,src,origin,dinv2,dinv3
mp-7457,SnF2,tetragonal,1.539305e-11,"[[[0.0, 0.0, 0.0], [0.0, 0.0, 3.26535874190781...","[F, Sn]","[[3.4703555436558315, 0.0, 0.0], [0.0, 3.47035...",3.2743,{'icsd': ['icsd-14195']},0.005852,"[[3.311556995, 5.421010862427522e-20, -2.71050...",...,12,422,92,P4_12_12,"{'@module': 'pymatgen.core.structure', '@class...",2.759732e-11,Materials Project,Naccarato,1.545791e-07,5.545824e-11
mp-3427,LiAlO2,tetragonal,3.290391e-13,"[[[0.0, 0.0, 0.0], [0.0, 0.0, 6.97997335399112...","[Al, Li, O]","[[2.774615752063739, 0.0, 0.0], [0.0, 2.774615...",4.5928,"{'icsd': ['icsd-430358', 'icsd-430359', 'icsd-...",0.01406,"[[2.75284424, 0.0, 0.0], [0.0, 2.75284424, 0.0...",...,16,422,92,P4_12_12,"{'@module': 'pymatgen.core.structure', '@class...",5.899154e-13,Materials Project,Naccarato,1.190601e-08,1.185466e-12
mp-757031,Mg2TiO4,tetragonal,4.228159e-11,"[[[0.0, 0.0, 0.0], [0.0, 0.0, -8.9692788852692...","[Mg, O, Ti]","[[4.100792324670548, 0.0, 0.0], [0.0, 4.100792...",3.3894,{},0.0,,...,28,422,91,P4_122,"{'@module': 'pymatgen.core.structure', '@class...",7.580424e-11,Materials Project,Naccarato,3.031809e-07,1.523326e-10
mp-3079,CaTa4O11,hexagonal,4.559622e-10,"[[[0.0, 0.0, -8.651273937676019e-10], [0.0, 0....","[Ca, O, Ta]","[[5.101235262917198, -1.2008901979278439e-15, ...",3.5694,"{'icsd': ['icsd-18306', 'icsd-108808', 'icsd-1...",0.0,,...,32,622,182,P6_322,"{'@module': 'pymatgen.core.structure', '@class...",8.453326e-10,Materials Project,Naccarato,1.570148e-06,1.88463e-09
mp-541462,SrTa4O11,hexagonal,4.984819e-10,"[[[0.0, 0.0, 9.458028576342811e-10], [0.0, 0.0...","[O, Sr, Ta]","[[5.1063906085247615, -2.0059476025066026e-11,...",3.5575,{'icsd': ['icsd-79704']},0.0,"[[5.152425241993238, 3.7866870166718205e-18, 0...",...,32,622,182,P6_322,"{'@module': 'pymatgen.core.structure', '@class...",9.241621e-10,Materials Project,Naccarato,1.666305e-06,2.060377e-09


df_outputs.shape = (2776, 26)
df_outputs_filtout.shape = (2518, 26)


In [8]:
df_outputs_filtout[df_outputs_filtout['origin']==f'hg v{cur_v-1}'].shape

(276, 26)

In [9]:
# dKP-Eg-(ns)-KDE with log-lin inset

# ==============================================================================================================================
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.02, column_widths=[0.8,0.2])

# ROW 1 COL 1 =================================================================================================================
fig.add_trace(go.Scatter(x=df_outputs_filtout[df_outputs_filtout['origin']!=str_origin]['src_bandgap'],
                         y=df_outputs_filtout[df_outputs_filtout['origin']!=str_origin]['dKP'],
                         mode='markers',
                         marker = dict(color = df_outputs_filtout[df_outputs_filtout['origin']!=str_origin]['n'],
                                       showscale=True,
                                       colorbar = dict(title = '<i>n<sub>s</sub></i>', 
                                                       title_font_size = 36,
                                                       orientation = 'v',
                                                    #    tickvals=[1,2,3,4,5,6],
                                                       x = 0.9,
                                                       y=0.6, 
                                                       len = 0.9),
                                       colorscale = 'plasma', 
                                       ),
                         showlegend=True,
                         name=f'Old',
                         text=[mpid for mpid in df_outputs_filtout[df_outputs_filtout['origin']!=str_origin].index.values]
                         ))

fig.add_trace(go.Scatter(x=df_outputs_filtout[df_outputs_filtout['origin']==str_origin]['src_bandgap'],
                         y=df_outputs_filtout[df_outputs_filtout['origin']==str_origin]['dKP'],
                         mode='markers',
                         marker = dict(symbol='diamond',
                                       size=8,
                                       color = 'lightgreen',
                                       showscale=False,
                                    #    colorbar = dict(title = '<i>n<sub>s</sub></i>', 
                                    #                    title_font_size = 24,
                                    #                    title_side = 'top',
                                    #                    orientation = 'h',
                                    #                    tickvals=[1,2,3,4,5,6],
                                    #                    x = 0.75,
                                    #                    y=0.75, 
                                    #                    len = 0.4),
                                    #    colorscale = ['lightgray','orange','aqua','lightgreen','white','pink'],
                                       line=dict(width=2, color="black")
                                       ),
                         showlegend=True,
                         name=f'v{cur_v-1} outputs',
                         text=[mpid for mpid in df_outputs_filtout[df_outputs_filtout['origin']==str_origin].index.values]
                         ))


# AXES
fig.update_xaxes(title = '<i>E<sub>g</sub></i> (eV)',
                 title_font_size=36,
                 range = [-0.1, 8.4],
                 row=1, col=1)
fig.update_yaxes(title = '<i>d</i><sub>KP</sub> (pm/V)',
                 title_font_size=36)


fig.update_layout(font={'family':'Arial', 'size': 20},
                 )

# ROW 1 COL 2 =================================================================================================================
density = gaussian_kde(np.reshape(df_outputs_filtout['dKP'].values, (1,len(df_outputs_filtout))))
density.covariance_factor = lambda : .02 #Smoothing parameter
density._compute_covariance()

x_vals = np.linspace(min(df_outputs_filtout['dKP'].values),
                     max(df_outputs_filtout['dKP'].values),
                     200) # Specifying the limits of our data
kde_dist = density(x_vals)

fig.add_trace(go.Scatter(x=kde_dist, 
                         y=x_vals, 
                         mode='lines', 
                         marker_color='indianred',
                         fill='tozerox',
                         showlegend=False),
              row=1, col=2)

fig.update_xaxes(title = 'Distribution',
                 title_font_size=36,
                 row=1, col=2)
fig.update_yaxes(title = '', row=1, col=2)

# INSET =======================================================================================================================
inset = copy.deepcopy(fig.data[0])
inset.xaxis = 'x3'
inset.yaxis = 'y3'

inset_candidates_v1 = copy.deepcopy(fig.data[1])
inset_candidates_v1.xaxis = 'x3'
inset_candidates_v1.yaxis = 'y3'
inset


fig.update_layout(
    xaxis3=dict(domain      = [0.30, 0.75],
                anchor      = 'y3',
                range       = [-0.1, 8.4],
                linecolor   = 'black'
               ),
    yaxis3=dict(domain      = [0.60, 0.98],
                anchor      = 'x3',
                range       = [-3.5, 2.6],
                type        = 'log',
                tickvals    = [0.001, 0.01, 1, 100],
                linecolor   = 'black'
    ))

fig.add_trace(inset)
fig.add_trace(inset_candidates_v1)

fig.update_xaxes(showspikes=True, spikecolor="gray", spikethickness=2, spikesnap="cursor", spikemode="across")
fig.update_yaxes(showspikes=True, spikecolor="gray", spikethickness=2, spikesnap="cursor", spikemode='across')
fig.update_layout(hoverdistance=5)

# THEME OF GRAPH
fig.update_layout(template='simple_white')

fig.update_layout(
    width=1000,
    height=500
)

fig.show()

name_fig = "dKP-Eg_KDE_cur_res_filtout"
if not Path(f'../figures/pdf/{name_fig}.pdf').exists():
    fig.write_image(f'../figures/pdf/{name_fig}.pdf')
    fig.write_image(f'../figures/svg/{name_fig}.svg')
    fig.write_image(f'../figures/png/{name_fig}.png', scale=10)
    fig.write_html(f'../figures/html/{name_fig}.html')

In [10]:
str_origin_nac = "Naccarato"
# dKP-Eg-(ns)-KDE with log-lin inset

# ==============================================================================================================================
fig = make_subplots(rows=1, cols=2, shared_yaxes=True, horizontal_spacing=0.02, column_widths=[0.8,0.2])

# ROW 1 COL 1 =================================================================================================================
fig.add_trace(go.Scatter(x=df_outputs_filtout[df_outputs_filtout['origin']==str_origin_nac]['src_bandgap'],
                         y=df_outputs_filtout[df_outputs_filtout['origin']==str_origin_nac]['dKP'],
                         mode='markers',
                         marker = dict(color = df_outputs_filtout[df_outputs_filtout['origin']==str_origin_nac]['n'],
                                       showscale=True,
                                       colorbar = dict(title = '<i>n<sub>s</sub></i>', 
                                                       title_font_size = 36,
                                                       orientation = 'v',
                                                    #    tickvals=[1,2,3,4,5,6],
                                                       x = 0.9,
                                                       y=0.6, 
                                                       len = 0.9),
                                       colorscale = 'plasma', 
                                       ),
                         showlegend=True,
                         name=f'Old',
                         text=[mpid for mpid in df_outputs_filtout[df_outputs_filtout['origin']==str_origin_nac].index.values]
                         ))

fig.add_trace(go.Scatter(x=df_outputs_filtout[df_outputs_filtout['origin']!=str_origin_nac]['src_bandgap'],
                         y=df_outputs_filtout[df_outputs_filtout['origin']!=str_origin_nac]['dKP'],
                         mode='markers',
                         marker = dict(symbol='circle',
                                       size=5,
                                       color = 'lightgreen',
                                       showscale=False,
                                    #    colorbar = dict(title = '<i>n<sub>s</sub></i>', 
                                    #                    title_font_size = 24,
                                    #                    title_side = 'top',
                                    #                    orientation = 'h',
                                    #                    tickvals=[1,2,3,4,5,6],
                                    #                    x = 0.75,
                                    #                    y=0.75, 
                                    #                    len = 0.4),
                                    #    colorscale = ['lightgray','orange','aqua','lightgreen','white','pink'],
                                       line=dict(width=1, color="black")
                                       ),
                         showlegend=True,
                         name=f'New',
                         text=[mpid for mpid in df_outputs_filtout[df_outputs_filtout['origin']!=str_origin_nac].index.values]
                         ))


# AXES
fig.update_xaxes(title = '<i>E<sub>g</sub></i> (eV)',
                 title_font_size=36,
                 range = [-0.1, 8.4],
                 row=1, col=1)
fig.update_yaxes(title = '<i>d</i><sub>KP</sub> (pm/V)',
                 title_font_size=36)


fig.update_layout(font={'family':'Arial', 'size': 20},
                 )

# ROW 1 COL 2 =================================================================================================================
density = gaussian_kde(np.reshape(df_outputs_filtout['dKP'].values, (1,len(df_outputs_filtout))))
density.covariance_factor = lambda : .02 #Smoothing parameter
density._compute_covariance()

x_vals = np.linspace(min(df_outputs_filtout['dKP'].values),
                     max(df_outputs_filtout['dKP'].values),
                     200) # Specifying the limits of our data
kde_dist = density(x_vals)

fig.add_trace(go.Scatter(x=kde_dist, 
                         y=x_vals, 
                         mode='lines', 
                         marker_color='indianred',
                         fill='tozerox',
                         showlegend=False),
              row=1, col=2)

fig.update_xaxes(title = 'Distribution',
                 title_font_size=36,
                 row=1, col=2)
fig.update_yaxes(title = '', row=1, col=2)

# INSET =======================================================================================================================
inset = copy.deepcopy(fig.data[0])
inset.xaxis = 'x3'
inset.yaxis = 'y3'

inset_candidates_v1 = copy.deepcopy(fig.data[1])
inset_candidates_v1.xaxis = 'x3'
inset_candidates_v1.yaxis = 'y3'
inset


fig.update_layout(
    xaxis3=dict(domain      = [0.30, 0.75],
                anchor      = 'y3',
                range       = [-0.1, 8.4],
                linecolor   = 'black'
               ),
    yaxis3=dict(domain      = [0.60, 0.98],
                anchor      = 'x3',
                range       = [-3.5, 2.6],
                type        = 'log',
                tickvals    = [0.001, 0.01, 1, 100],
                linecolor   = 'black'
    ))

fig.add_trace(inset)
fig.add_trace(inset_candidates_v1)

fig.update_xaxes(showspikes=True, spikecolor="gray", spikethickness=2, spikesnap="cursor", spikemode="across")
fig.update_yaxes(showspikes=True, spikecolor="gray", spikethickness=2, spikesnap="cursor", spikemode='across')
fig.update_layout(hoverdistance=5)

# THEME OF GRAPH
fig.update_layout(template='simple_white')

fig.update_layout(
    width=1000,
    height=500
)

fig.show()

name_fig = "dKP-Eg_KDE_all_res_filtout"
if not Path(f'../figures/pdf/{name_fig}.pdf').exists():
    fig.write_image(f'../figures/pdf/{name_fig}.pdf')
    fig.write_image(f'../figures/svg/{name_fig}.svg')
    fig.write_image(f'../figures/png/{name_fig}.png', scale=10)
    fig.write_html(f'../figures/html/{name_fig}.html')

In [11]:
df_outputs_filtout[df_outputs_filtout['origin']!="Naccarato"].shape

(1960, 26)

# Accuracy previous model

In [12]:
#   Paths to the candidates datasets
candidate_df_path    = Path(f"../../{cur_v-1}_al_it/data/df_selected_v{cur_v-1}.pkl.gz")
    
#   Loading the candidates datasets
df_prev_selected = pd.read_pickle(candidate_df_path)
print(df_prev_selected.shape)
display(df_prev_selected.head())

(519, 11)


Unnamed: 0,dKP,dKP_unc,dKP_unc_cal_ma,dKP_unc_cal_nll,structure,nsites,ehull,bandgap,spg,reduced_formula,added_from
agm2000129361,0.031509,0.220027,0.318939,4.337121,"{'@module': 'pymatgen.core.structure', '@class...",6,0.008973,7.3598,189,PF5,Pareto
agm003738833,0.096076,0.130822,0.189632,2.578731,"{'@module': 'pymatgen.core.structure', '@class...",12,0.027481,7.3123,30,LiAlF4,Pareto
agm002138386,0.112327,0.171703,0.248891,3.384572,"{'@module': 'pymatgen.core.structure', '@class...",8,0.0,7.2732,5,KBe2F5,Pareto
agm005052162,0.142544,0.138844,0.201261,2.736868,"{'@module': 'pymatgen.core.structure', '@class...",7,0.014591,6.9649,82,NaLiBeF4,Pareto
agm1000019892,0.202417,0.222,0.321798,4.376006,"{'@module': 'pymatgen.core.structure', '@class...",18,0.013705,6.8742,1,CO2,Pareto


In [13]:
df_prev_selected_filtout = df_prev_selected.filter(items=df_outputs_filtout.index, axis=0)

In [14]:
# Data
data_prev_outputs = {'x': df_outputs_filtout[df_outputs_filtout['origin']==str_origin]['dKP'], 
                     'y': df_prev_selected_filtout['dKP'], 
                     'error_y': df_prev_selected_filtout['dKP_unc'],
                    #  'color': df_outputs_filtout['bandgap'],
                     'color': df_prev_selected_filtout['dKP_unc']/df_prev_selected_filtout['dKP']}


mae = mean_absolute_error(data_prev_outputs['x'], data_prev_outputs['y'])
mape = mean_absolute_percentage_error(data_prev_outputs['x'], data_prev_outputs['y'])
rmse = root_mean_squared_error(data_prev_outputs['x'], data_prev_outputs['y'])
spearmanrho = spearmanr(data_prev_outputs['x'], data_prev_outputs['y'])
print(f"MAE = {mae}")
print(f"MAPE = {mape}")
print(f"RMSE = {rmse}")
print(f"Rho_sp = {spearmanrho.statistic}")

# Scatter plot for previous outputs.
scatter_outputs_prev = go.Scatter(
    x=data_prev_outputs['x'],
    y=data_prev_outputs['y'],
    mode='markers',
    error_y=dict(
        array=data_prev_outputs['error_y'],
        color='gray',
        thickness=1,
        visible=True,
        ),
    marker=dict(
        color=data_prev_outputs['color'],
        colorbar=dict(
            title="<i>E<sub>g</sub></i> (eV)"
        ),
        colorscale="Plasma"
    ),
    line=dict(color='darkorange'),
    name=f'v{cur_v-1} outputs',
    showlegend=False,
    text=[mpid for mpid in df_outputs_filtout.index.values]
)

ideal = go.Scatter(
    x=[-100,200],
    y=[-100,200],
    mode="lines",
    line=dict(color='gray', dash='dot'),
    showlegend=False
)

# Layout
layout = go.Layout(
    # title=dict(text='Scatter Plot'),
    xaxis=dict(title='<i>d</i><sub>KP</sub> (pm/V)',  range=[-5,180]),
    yaxis=dict(title='<i>d&#770;</i><sub>KP</sub> (pm/V)', range=[-5,180]),
    # legend=dict(font=dict(size=12)),
)

# Create figure
fig = go.Figure(data=[scatter_outputs_prev,ideal], layout=layout)

fig.update_layout(
    autosize=False,
    font_size=20,
    width=600,
    height=600,
    # plot_bgcolor="white",
    template='simple_white',
)
fig.update_layout(
    xaxis = dict(
        # tickmode = 'array',
        # tickvals = [1, 2,3,4,5,6,7,8],
        # ticktext = ['One', 'Three', 'Five', 'Seven', 'Nine', 'Eleven']
        showgrid=False,
    ),
    yaxis = dict(
        # tickmode = 'array',
        # tickvals = [1, 2,3,4,5,6,7,8],
        # ticktext = ['One', 'Three', 'Five', 'Seven', 'Nine', 'Eleven']
        showgrid=False,
    ),
)
fig.update_traces(marker=dict(
    colorbar=dict(
        title="<i><d&#770;</i><sub>KP</sub>>/<i>d&#770;</i><sub>KP</sub>",
        # tickmode = 'array',
        # tickvals = [0,0.1,0.2,0.3,0.4]
        ),
    cmin=0,
    # cmax=0.6,
))

# Show figure
fig.show()

name_fig = "dKP_comparison_prev_pred_ML_vs_FP_filtout"
if not Path(f'../figures/pdf/{name_fig}.pdf').exists():
    fig.write_image(f'../figures/pdf/{name_fig}.pdf')
    fig.write_image(f'../figures/svg/{name_fig}.svg')
    fig.write_image(f'../figures/png/{name_fig}.png', scale=10)
    fig.write_html(f'../figures/html/{name_fig}.html')
    with open(f'../figures/{name_fig}.json', 'w') as f:
        json.dump(obj={'mape': mape, 'mae': mae, 'rmse': rmse, 'rho_sp': spearmanrho.statistic},
                  fp=f)

MAE = 4.807854065868976
MAPE = 915.9656182557774
RMSE = 10.742142062618846
Rho_sp = 0.885168876014897


# Interesting materials

In [15]:
def get_pareto_front(Xs=None, Ys=None, maxX=True, maxY=True):
    '''Pareto frontier selection process'''
    sorted_list = sorted([[Xs[i], Ys[i]] for i in range(len(Xs))], reverse=maxY) # sorts the pairs wrt X descending (highest first) and then only Y descending
    pareto_front = [sorted_list[0]] # saves the highest X with highest Y ("last point")
    for pair in sorted_list[1:]:
        if maxY:
            if pair[1] >= pareto_front[-1][1]: # if Y higher than the last saved point (X is lower anyway), saves this point
                pareto_front.append(pair)
        else:
            if pair[1] <= pareto_front[-1][1]:
                pareto_front.append(pair)
    
    # '''Plotting process'''
    # plt.scatter(Xs,Ys,marker='.')
    # pf_X = [pair[0] for pair in pareto_front]
    # pf_Y = [pair[1] for pair in pareto_front]
    # plt.scatter(pf_X, pf_Y,color='red',marker='.')
    # plt.xlabel("Objective 1")
    # plt.ylabel("Objective 2")
    # plt.show()

    return pareto_front

def peel_pareto_fronts(nb_fronts, df_entries, X_name, Y_name, Y_unc_name=None, coef_unc=1):
    df_pareto = pd.DataFrame(columns=df_entries.columns)
    df_wo_pareto = deepcopy(df_entries)

    for i in range(nb_fronts):
        Xs = df_wo_pareto[X_name].tolist()
        if Y_unc_name:
            Ys = df_wo_pareto[Y_name]+df_wo_pareto[Y_unc_name]*coef_unc
            Ys = Ys.tolist()
        else:
            Ys = df_wo_pareto[Y_name].tolist()

        pareto_front = get_pareto_front(Xs=Xs, Ys=Ys)
        df_pareto_new = pd.DataFrame(columns=df_entries.columns)

        for xs, ys in pareto_front:
            if Y_unc_name:
                df_match = df_entries[(df_entries[X_name]==xs) & (df_entries[Y_name]+df_entries[Y_unc_name]*coef_unc==ys)]
            else:
                df_match = df_entries[(df_entries[X_name]==xs) & (df_entries[Y_name]==ys)]
            if len(df_match) > 1:
                print("There are multiple occurences of the same bandgap")
            df_pareto_new = pd.concat([df_pareto_new.astype(df_match.dtypes), df_match])
        df_pareto = pd.concat([df_pareto.astype(df_match.dtypes), df_pareto_new])
        # df_pareto = df_pareto.reset_index(drop=True)
        # display(df_pareto)
    
        df_wo_pareto = df_wo_pareto.drop(df_pareto_new.index, axis=0)

    return df_pareto, df_wo_pareto


    # df_wo_pareto = pd.concat([df,df_pareto], ignore_index=True).drop_duplicates(subset=['mp_id'], keep=False, ignore_index=True)

    # print(f"Number of selected instances = {len(pareto_front)}")

## Kiril

In [34]:
elem_to_excl = ["H", "B", "C", "N", "O", "S", "Se", "Te", "F", "Cl", "Br", "I"]
elem_to_focus = ["P", "As"]
# elem_to_excl = ["H", "B", "C", "N", "O", "Te", "F", "Cl", "Br", "I"]
# elem_to_focus = ["S", "Se"]

In [35]:
import plotly.express as px
import plotly.graph_objects as go

nb_pareto_peeled = 10

print("Select the Pareto front")
print(f"{df_outputs_filtout.shape = }")
df_pareto, df_wo_pareto = peel_pareto_fronts(nb_pareto_peeled, df_outputs_filtout, X_name='src_bandgap', Y_name='dKP')
print(f"Peel {nb_pareto_peeled} Pareto fronts")
print(f"{df_pareto.shape = }")

# Plotly scatter plot
fig = go.Figure()

# Add non-Pareto points
list_hovertext = []
for ir, r in df_wo_pareto.iterrows():
        list_hovertext.append(ir + " " + r['formula_reduced'])
fig.add_trace(go.Scatter(
    x=df_wo_pareto['src_bandgap'],
    y=df_wo_pareto['dKP'],
    mode='markers',
    name='Non-Pareto',
    marker=dict(color='lightgray'),
    hovertext=list_hovertext,
    hoverinfo='text'
))

# Add Pareto points
list_hovertext = []
for ir, r in df_pareto.iterrows():
        list_hovertext.append(ir + " " + r['formula_reduced'])
fig.add_trace(go.Scatter(
    x=df_pareto['src_bandgap'],
    y=df_pareto['dKP'],
    mode='markers',
    name='Pareto',
    marker=dict(color='gray'),
    hovertext=list_hovertext,
    hoverinfo='text'
))


# Convert excluded elements to drop indices
idx_to_drop = []
for ir, r in df_pareto.iterrows():
    for e in r['elements']:
        if e in elem_to_excl:
            idx_to_drop.append(ir)
df_pareto = df_pareto.drop(idx_to_drop, axis=0)
print("Exclude elements")
print(f"{df_pareto.shape = }")

# Add Pareto points
list_hovertext = []
for ir, r in df_pareto.iterrows():
        list_hovertext.append(ir + " " + r['formula_reduced'])
fig.add_trace(go.Scatter(
    x=df_pareto['src_bandgap'],
    y=df_pareto['dKP'],
    mode='markers',
    name='Exclude elements',
    marker=dict(color='blue'),
    hovertext=list_hovertext,
    hoverinfo='text'
))

# Focus elements for keeping
idx_to_keep = []
for ir, r in df_pareto.iterrows():
    for e in r['elements']:
        if e in elem_to_focus:
            idx_to_keep.append(ir)
df_pareto = df_pareto.loc[idx_to_keep]
print("Focus on elements")
print(f"{df_pareto.shape = }")

# Add Pareto points
list_hovertext = []
for ir, r in df_pareto.iterrows():
        list_hovertext.append(ir + " " + r['formula_reduced'])
fig.add_trace(go.Scatter(
    x=df_pareto['src_bandgap'],
    y=df_pareto['dKP'],
    mode='markers',
    name='Focus on elements',
    marker=dict(color='red'),
    hovertext=list_hovertext,
    hoverinfo='text'
))

# Add axis labels and title
# Update layout with larger font sizes
fig.update_layout(
    title="Materials selection Visualization",
    title_font=dict(size=16),  # Title font size
    xaxis=dict(
        title="Bandgap (eV)",
        title_font=dict(size=16),  # X-axis title font size
        tickfont=dict(size=14)  # X-axis tick font size
    ),
    yaxis=dict(
        title="d<sub>KP</sub> (pm/V)",
        title_font=dict(size=16),  # Y-axis title font size
        tickfont=dict(size=14)  # Y-axis tick font size
    ),
    legend=dict(
        font=dict(size=14)  # Legend font size
    ),
    template="plotly_white"
)

fig.show()

display(df_pareto[['formula_reduced', 'crystal_system', 'dKP', 'n', 'src_bandgap', 'nsites', 'spg_number', 'origin']])


Select the Pareto front
df_outputs_filtout.shape = (2518, 26)
Peel 10 Pareto fronts
df_pareto.shape = (630, 26)
Exclude elements
df_pareto.shape = (57, 26)
Focus on elements
df_pareto.shape = (36, 26)


Unnamed: 0,formula_reduced,crystal_system,dKP,n,src_bandgap,nsites,spg_number,origin
mp-34903,MgGeP2,triclinic,80.938905,3.153317,1.4962,8,122,hg v8
agm003161522,BaTlAs,hexagonal,141.523926,4.011406,0.9755,3,156,hg v6
mp-4666,CdSiP2,tetragonal,89.663214,3.360277,1.4257,8,122,hg v17
mp-1215555,Zn2Si2As3P,monoclinic,108.202015,3.6087,1.1823,8,5,hg v14
mp-1215555,Zn2Si2As3P,monoclinic,108.202015,3.6087,1.1823,8,5,hg v14
mp-1222618,Li2ZnCdP2,rhombohedral,116.894795,3.419502,1.0652,6,160,hg v18
mp-1105127,MnP4,monoclinic,165.114231,4.515302,0.5336,20,9,hg v14
mp-1222661,Li2MgCdP2,Tetragonal,41.539623,2.936768,2.004,6,115,Pareto
agm006085187,AlGa3P4,tetragonal,47.445923,3.233662,1.697,8,121,hg v18
mp-2490,GaP,cubic,48.751564,3.35453,1.6843,2,216,Naccarato
