# This script contains the code to calculate formation energies and save the necessary data for binary oxides.

In [None]:
import os
import pickle
from collections import defaultdict

import numpy as np
import matplotlib.pyplot as plt

from pymatgen.ext.matproj import MPRester
from pymatgen.core.composition import Composition
from pymatgen.core.periodic_table import Element
from pymatgen.analysis.pourbaix_diagram import PourbaixDiagram, PourbaixPlotter

from matminer.featurizers.site import CrystalNNFingerprint
from matminer.featurizers.structure import SiteStatsFingerprint

# Initialize the MP Rester.
# SECURITY: API keys should not live in source control. The key is taken from
# the PMG_MAPI_KEY environment variable when it is set; the literal below is
# only a backward-compatible fallback and should be rotated and removed.
mpr = MPRester(os.environ.get("PMG_MAPI_KEY", 'pn8XdbGhMrv90STu'))


In [None]:
# Reference energies used by calc_formation_ene, indexed by element symbol.
# NOTE: pickle can execute arbitrary code on load; only use a trusted ele2gs.p.
with open("ele2gs.p", "rb") as gs_file:
    ele2gs = pickle.load(gs_file)
# The oxygen reference is overridden with a fixed value
# (presumably a corrected O2 reference energy per atom -- TODO confirm source).
ele2gs["O"] = -4.95

In [None]:
def pourbaix_data_multielement(pbx, element1, element2, ph, voltage, entries_of_interest):
    """
    Evaluate the Pourbaix decomposition energy of every entry of interest.

    Arguments:
    pbx:                 Object exposing get_decomposition_energy(entry, pH=..., V=...),
                         e.g. a PourbaixDiagram built as
                             entries = mpr.get_pourbaix_entries([metal1, metal2])
                             pbx = PourbaixDiagram(entries, filter_solids=True, comp_dict=...)
    element1:            First metal, as a string, e.g. "Mn". Not used in the body;
                         kept so existing callers keep working.
    element2:            Second metal, as a string. Not used in the body either.
    ph:                  pH value(s), forwarded unchanged as the pH keyword.
    voltage:             Voltage value(s), forwarded unchanged as the V keyword
                         (expected to match ph in shape).
    entries_of_interest: Iterable of entries to evaluate; each must provide
                         `entry_id` and `name`. It is traversed exactly once.

    Returns:
    material_names:     List with the `name` of every entry, in input order.
    materials_proj_ids: List with the `entry_id` of every entry, in input order.
    energies:           List with the value returned by
                        pbx.get_decomposition_energy for every entry, in input order.
    """
    # One pass over the entries; the energy is evaluated before the identifiers are read.
    records = [
        (pbx.get_decomposition_energy(candidate, pH=ph, V=voltage),
         candidate.entry_id,
         candidate.name)
        for candidate in entries_of_interest
    ]

    energies = [record[0] for record in records]
    materials_proj_ids = [record[1] for record in records]
    material_names = [record[2] for record in records]

    return material_names, materials_proj_ids, energies

def calc_formation_ene(entry):
    """
    Return the formation energy per atom of `entry`, computed locally.

    Materials Project formation energies are inconsistent, so they are recomputed here
    (see https://matsci.org/t/formation-energy-calculation/41574): starting from
    entry.energy_per_atom, the reference energy of each element (looked up in the
    module-level `ele2gs` mapping by element symbol) is subtracted, weighted by that
    element's share of the atoms in entry.composition.

    The result and entry.entry_id are also printed.
    Raises KeyError if an element of the entry is missing from `ele2gs`.
    """
    # Amount of each element, keyed by the string form of the composition key.
    atom_counts = dict((str(species), amount) for species, amount in entry.composition.items())
    n_atoms = sum(atom_counts.values())

    formation_energy = entry.energy_per_atom
    for symbol, amount in atom_counts.items():
        formation_energy -= ele2gs[symbol] * amount / n_atoms

    print(formation_energy)
    print(entry.entry_id)
    return formation_energy


In [None]:
# Resume from the previously saved results, so element pairs that are already
# present in the mapping are skipped by the loop below.
# To start from scratch instead:  binary_oxide_data = defaultdict()
# Older results file (no longer used):  "binary_oxide_data.p"
# NOTE: pickle can execute arbitrary code on load; only open trusted files.
with open("binary_oxide_data_form_ene_self_calc_fingerprints.p", "rb") as results_file:
    binary_oxide_data = pickle.load(results_file)
# Earlier element selections, kept for reference only:
# actives = [ "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Nb", "Mo"]
# pairers = ["Li", "Be", "Na", "Mg", "K", "Ca", "Rb", "Sr", "Cs", "Ba",
#            "Sc",
#            "Zn", "Ga", "Ge", "As", "Se", "Br", "Y", "Se", "Tc", "Cd",
#            "Ru", "Rh", "Pd", "Ag",
#            "In", "Sn", "Sb", "Te", "I", "Hg", "Tl", "Pb", "Bi",
#            "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au"]
elements = ["Cr", "Mn", "Fe", "Co"]

# Candidate partner elements, listed group by group.
pairers = (
    ["Li", "Be", "Na", "Mg", "K", "Ca", "Rb", "Sr", "Cs", "Ba"]    # alkali and alkaline-earth metals
    + ["Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn"]  # 3d transition metals
    + ["Ga", "Ge", "As", "Se", "Br"]                               # period-4 p-block
    + ["Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd"]  # 4d transition metals
    + ["In", "Sn", "Sb", "Te", "I"]                                # period-5 p-block
    + ["Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg"]        # 5d transition metals
    + ["Tl", "Pb", "Bi"]                                           # period-6 p-block
    + ["La", "Ce", "Nd", "Pr", "Sm", "Eu", "Gd", "Tb", "Dy", "Ho", "Er", "Tm", "Yb", "Lu"]  # lanthanides
)

#all_ = elements
# Bounds of the pH / voltage window that is sampled.
# These could be varied, e.g. to also look at pH up to 2 or 3.
min_ph, max_ph = 0, 0
min_v, max_v = 1, 2

# Structure featurizer: CrystalNN 'ops' site fingerprints, summarised over the
# sites of a structure with the statistics listed in `stats`.
ssf = SiteStatsFingerprint(
    CrystalNNFingerprint.from_preset('ops', distance_cutoffs=None, x_diff_weight=0),
    stats=('mean', 'std_dev', 'minimum', 'maximum'))

# pH and V are passed to the Pourbaix diagram as a numpy 'meshgrid'. The grid
# does not depend on the element pair, so it is built once:
# 1 pH point in [min_ph, max_ph] and 11 voltage points in [min_v, max_v].
pH, V = np.mgrid[
    min_ph : max_ph : 1 * 1j,
    min_v : max_v : 11 * 1j,
]

for ele in ['Fe', 'Al', 'Ca']:
    for ele2 in ['Si']:
        pair_key = ele + "_" + ele2
        # Skip pairs that were already processed in a previous run.
        if pair_key in binary_oxide_data or ele == ele2:
            continue
        print(ele, ele2)

        # All Materials Project Pourbaix entries for the ele/ele2/O/H system.
        entries = mpr.get_pourbaix_entries([ele, ele2])

        # Group the solid ele-ele2-O entries by metal fraction: the Pourbaix
        # diagram has to be rebuilt for every composition, so it is built once
        # per group rather than once per entry.
        composition2entries = defaultdict(list)
        for entry in entries:
            if "ion" in entry.entry_id:
                continue
            composition = entry.composition.as_dict()
            # Test element symbols, not substrings of the formula: a substring
            # test for "H" would also reject Hf/Ho/Hg/Rh compounds, and one for
            # "O" would accept Co/Os/Mo compounds that contain no oxygen.
            if ele not in composition or ele2 not in composition:
                continue
            if "O" not in composition or "H" in composition:
                continue
            amount_ele1 = composition[ele]
            amount_ele2 = composition[ele2]
            n_metal = amount_ele1 + amount_ele2
            # Keep only oxides with two O per metal atom. isclose guards
            # against round-off when the amounts are not integers.
            if not np.isclose(composition["O"] / n_metal, 2):
                continue
            composition2entries[amount_ele1 / n_metal].append(entry)

        structures = []
        fingerprints = []
        material_names_ = []
        material_ids_ = []
        energies_ = []
        form_enes_ = []

        for ele1_fraction, comp_entries in composition2entries.items():
            pbx_dat = PourbaixDiagram(entries, filter_solids=True,
                                      comp_dict={ele: ele1_fraction,
                                                 ele2: 1 - ele1_fraction})

            material_names, material_ids, energies = pourbaix_data_multielement(
                pbx_dat, ele, ele2, pH, V, comp_entries)
            assert len(material_names) == len(material_ids) == len(energies)

            for name, mp_id, energy in zip(material_names, material_ids, energies):
                # Gather everything for this material before storing anything,
                # so a failure part-way through cannot leave the result lists
                # with different lengths.
                try:
                    structure = mpr.get_structure_by_material_id(mp_id)
                    fingerprint = np.array(ssf.featurize(structure))
                    ref_query = mpr.get_entry_by_material_id(mp_id)
                    ref_form_ene = calc_formation_ene(ref_query)
                except ValueError as e:
                    # Best effort: report the material that failed and move on.
                    print(f"Skipping {mp_id}: {e}")
                    continue

                structures.append(structure)
                fingerprints.append(fingerprint)
                material_names_.append(name)
                material_ids_.append(mp_id)
                energies_.append(energy)
                form_enes_.append(ref_form_ene)

        assert (len(structures) == len(fingerprints) == len(material_ids_)
                == len(material_names_) == len(energies_) == len(form_enes_))

        binary_oxide_data[pair_key] = {
            "names": material_names_,
            "mp_ids": material_ids_,
            "energies": energies_,
            "structures": structures,
            "fingerprints": fingerprints,
            "formation_energies": form_enes_
        }
        # Checkpoint after every pair; the context manager makes sure the file
        # is flushed and closed before the next pair is processed.
        with open("binary_oxide_data_form_ene_self_calc_fingerprints.p", "wb") as results_file:
            pickle.dump(binary_oxide_data, results_file)
    