# MP Structure Filter
# Goals:
- Query all Materials Project (MP) structures that contain the target element pairs (each target metal together with oxygen) using the Materials Project API, and attach Bader charges to them where available.

## Inputs:
- Joey's Bader charge data for MP structures. Use the MP Bader Charges IPython notebook in the data folder to set up a local MongoDB from which it can be queried.
## Outputs:
{}\_{}\_MP\_API\_collations.json (e.g. Co\_O\_MP\_API\_collations.json): File containing the structures and their associated Bader charges (matched either by MP ID or by structure matching) in XAS_Collation form. Each collation is serialized as a dictionary using the as_dict() method built into MSONables, one JSON object per line. This is the preferred format for later steps in the pre-processing.
##### Optional:
{}\_{}\_MP\_API\_structures.json: File containing the pymatgen Structures, serialized as dictionaries using the as_dict() method built into pymatgen structures. Each line is a JSON object of the form {'structure': ..., 'id': MP ID}.

In [None]:
import os
# Root of the TRIXS data tree. Set the TRIXS_DATA_PREFIX environment variable
# to point elsewhere without editing the notebook; otherwise the original
# author's location is used.
data_prefix = os.environ.get('TRIXS_DATA_PREFIX',
                             '/Users/steventorrisi/Documents/TRIXS/data/')
# Directory into which the per-pair JSON output files are written.
storage_directory = os.path.join(data_prefix, 'Pre-Processing')


In [None]:
from pymatgen.ext.matproj import MPRester
from tqdm import tqdm, tqdm_notebook
import json

from trixs.spectra.core import XAS_Collation
from pymatgen.core.structure import Structure
from pymatgen.analysis.structure_matcher import StructureMatcher, ElementComparator

# Materials Project API client. The API key is read from the MP_API_KEY
# environment variable instead of being hard-coded in the notebook. If the
# variable is unset, MPRester receives None and falls back to pymatgen's own
# PMG_MAPI_KEY setting.
# NOTE(review): the personal key previously committed here should be revoked.
mpr = MPRester(os.environ.get('MP_API_KEY'))
# (metal, 'O') element pairs to query from MP.
# NB if you add more, ENSURE THE FIRST ATOM IS THE METAL: target_metals below
# and the output file names are built from the first element of each pair.
target_elements_groups = [('Co', 'O'), ('Fe', 'O'), ('V', 'O'), ('Cu', 'O'),
                          ('Ni', 'O'), ('Cr', 'O'), ('Mn', 'O'), ('Ti', 'O')]
# Derived from the pairs so the metal set can never drift out of sync with them.
target_metals = {metal for metal, _ in target_elements_groups}
O = {'O'}


## Connect to local MongoDB with Bader charge information

Connect to the local MongoDB in order to associate MP structures with their Bader charges, and instantiate the structure matcher.

In [None]:
from pymongo import MongoClient
conn = MongoClient()
# Collection holding the refined Bader-charge documents.
bader = conn['vasp_jhm_test']['bader_refined']
# Index every Bader document by the MP task id it was computed from.
all_charges = {bader_doc['original_task_id']: bader_doc
               for bader_doc in bader.find()}
# Shared matcher used by both the ID-based and structure-based lookups below.
matcher = StructureMatcher(attempt_supercell=True,
                           comparator=ElementComparator())



In [None]:
def bader_doc_to_colls(doc):
    """
    Build one XAS_Collation per structure stored in a Bader document.

    Each Bader document from the database carries several structures
    (typically different stages of relaxation). Every resulting collation:
    - shares the document's MP id;
    - stores, as ``mp_bader``, a per-site list of
      ``(element, oxidation_state)`` pairs read from each site's first
      species entry;
    - gets an extra ``elements`` attribute holding the set of element
      symbols present.

    :param doc: Bader document with an ``'original_task_id'`` key and a
        ``'structures'`` key (a list of serialized pymatgen Structures).
    :return: list of XAS_Collation objects, in the order of
        ``doc['structures']``.
    """
    colls = []
    mp_id = doc['original_task_id']
    for struc_dict in doc['structures']:
        cur_struc = Structure.from_dict(struc_dict)
        # Serialize each site once; the Bader charge lives in the
        # 'oxidation_state' field of the site's first species.
        species_dicts = [site.as_dict()['species'][0] for site in cur_struc.sites]
        cur_species = [spec['element'] for spec in species_dicts]
        cur_oxy = [(spec['element'], spec['oxidation_state'])
                   for spec in species_dicts]
        new_col = XAS_Collation(structure=cur_struc, mp_bader=cur_oxy, mp_id=mp_id)
        # Read by match_bader_by_structure to pre-filter candidates.
        new_col.elements = set(cur_species)
        colls.append(new_col)
    return colls

def match_bader_by_id(all_charges, mp_id, structure):
    """
    Look up Bader charges for ``structure`` via its MP id.

    The Bader document stored under ``mp_id`` holds several structures
    (e.g. different relaxation stages). The first one that the module-level
    ``matcher`` fits to ``structure`` supplies the charges.

    :param all_charges: dict mapping MP task id -> Bader document.
    :param mp_id: MP task id of ``structure``.
    :param structure: pymatgen Structure to match against.
    :return: list of ``(element, bader_charge)`` tuples, one per site of the
        matched Bader structure, or ``[]`` if there is no document for
        ``mp_id`` or none of its structures match.
    """
    # Check whether Bader charges exist for this MP id at all.
    bader_mp_doc = all_charges.get(mp_id)
    if not bader_mp_doc:
        return []
    # Each Bader document has multiple structures; try them one at a time.
    for bader_struc_dict in bader_mp_doc['structures']:
        bader_struc = Structure.from_dict(bader_struc_dict)
        if matcher.fit(structure, bader_struc):
            species_dicts = [site.as_dict()['species'][0]
                             for site in bader_struc.sites]
            return [(spec['element'], spec['oxidation_state'])
                    for spec in species_dicts]
    # Return [] (not None) when nothing matches, so callers always get a list.
    return []
            
def match_bader_by_structure(bader_candidate_cols, mp_id, structure):
    """
    Borrow Bader charges from a structurally equivalent collation.

    Only candidates whose ``elements`` set equals the set of species in
    ``structure`` and whose MP id differs from ``mp_id`` are considered. The
    first of those that ``matcher`` fits to ``structure`` wins.

    :return: ``(mp_bader, mp_id)`` of the matching candidate, or
        ``([], None)`` when nothing matches.
    """
    species_here = {str(site.specie) for site in structure.sites}
    eligible = [
        cand for cand in bader_candidate_cols
        if species_here == cand.elements and cand.mp_id != mp_id
    ]
    for cand in eligible:
        if not matcher.fit(structure, cand.structure):
            continue
        return cand.mp_bader, cand.mp_id
    return [], None

## Query MP Server

In [None]:
# For every (metal, 'O') pair:
#   1. query MP for all structures containing both elements;
#   2. attach Bader charges, first by MP id, then by structure matching;
#   3. write the collations and the bare structures as JSON-lines files.
for pair in target_elements_groups:
    data = mpr.query(criteria={'elements': {"$all": pair}},
                     properties=['final_structure', 'formation_energy_per_atom',
                                 'energy', 'task_id', 'icsd_ids', 'icsd'])

    file_name = '{}_{}_MP_API_collations.json'.format(pair[0], pair[1])
    write_target = os.path.join(storage_directory, file_name)

    # Structure-matching candidates: every collation built from a Bader
    # document whose elements include both elements of this pair.
    pairset = set(pair)
    bader_candidate_docs = [doc for doc in all_charges.values()
                            if pairset.issubset(doc['elements'])]
    bader_candidate_cols = [col for doc in bader_candidate_docs
                            for col in bader_doc_to_colls(doc)]

    matched_by_structure = 0
    matched_by_id = 0
    with open(write_target, 'w', encoding='utf-8') as f:
        for dat in tqdm_notebook(data, desc='Matching Bader, writing for {}'.format(pair)):
            cur_struc = dat['final_structure']
            cur_mpid = dat['task_id']
            cur_icsds = dat['icsd_ids']
            cur_bader = match_bader_by_id(all_charges, cur_mpid, cur_struc)
            if cur_bader:
                matched_by_id += 1
            else:
                # Fall back to charges from a structurally equivalent entry.
                # NOTE(review): the donor MP id (assoc_id) is not stored, so
                # the output does not record where borrowed charges came from.
                cur_bader, assoc_id = match_bader_by_structure(
                    bader_candidate_cols, cur_mpid, cur_struc)
                if cur_bader:
                    matched_by_structure += 1

            cur_col = XAS_Collation(structure=cur_struc,
                                    mp_id=cur_mpid,
                                    icsd_ids=cur_icsds,
                                    mp_bader=cur_bader)
            # One JSON object per line.
            f.write(json.dumps(cur_col.as_dict()) + '\n')

    print('Bader charges matched by structure:', matched_by_structure)
    print('Bader charges matched by ID:', matched_by_id)

    # Optional output: the bare structures keyed by MP id.
    file_name = '{}_{}_MP_API_structures.json'.format(pair[0], pair[1])
    write_target = os.path.join(storage_directory, file_name)
    with open(write_target, 'w', encoding='utf-8') as f:
        for dat in tqdm_notebook(data, desc='Writing {}'.format(pair)):
            cur_struc = dat['final_structure'].as_dict()
            cur_mpid = dat['task_id']
            f.write(json.dumps({'structure': cur_struc, 'id': cur_mpid}) + '\n')
