# Associate scraped spectra with collations and remove those attached to atoms which are symmetrically equivalent

## Goals:
- Associate the spectra with collations drawn from the MP API.
- Remove spectra which are attached to symmetrically equivalent atoms.


## Inputs:
- {}\_{}\_MP\_API\_collations.json: File containing the structures and their associated Bader charges, matched by MP ID and by structure matching, stored in XAS_Collation form. Each line of the file is one collation serialized as a dictionary using the as_dict() method built into MSONables. Produced by PP-A.
- {}\_{}\_scraped\_spectra+coord.json: File containing the spectra scraped from the API based on MP ID. Each line of the file is one spectrum serialized as a dictionary using the as_dict() method built into MSONables. Produced by PP-B.


## Outputs:
- {}\_{}\_MP\_API+Spec+Symm\_collations.json: File containing the collations which have structures, Bader charges, and spectra associated with them (spectra attached to symmetrically equivalent atoms are purged).

In [None]:
import os
from pymatgen.ext.matproj import MPRester
# Root of the local data tree and the sub-folder that the cells below read
# their input files from and write their output file to.
data_prefix = '/Users/steventorrisi/Documents/TRIXS/data/'
storage_directory = os.path.join(data_prefix, 'MP_OQMD_combined')

# Element pairs handled by every loop in this notebook, processed in this order.
target_elements_groups = [
    ('Co', 'O'),
    ('Fe', 'O'),
    ('V', 'O'),
    ('Cu', 'O'),
    ('Ni', 'O'),
    ('Cr', 'O'),
    ('Mn', 'O'),
    ('Ti', 'O'),
]
# The first element of each pair above, collected as a set.
target_metals = {'Co', 'Ni', 'Fe', 'Cr', 'V', 'Mn', 'Cu', 'Ti'}

## Import statements

In [None]:
from pymatgen.ext.matproj import MPRester
from pymatgen.core import Structure
from pymatgen.analysis.structure_matcher import StructureMatcher, ElementComparator
from tqdm import tqdm, tqdm_notebook
from pprint import pprint
import json
import os
import numpy as np
from monty.json import MSONable
import gc
import matplotlib.pyplot as plt
from trixs.spectra.core import XAS_Spectrum, XAS_Collation
from pymatgen.analysis.local_env import CrystalNN

# Shared structure matcher configured with an ElementComparator; it is the
# `matcher` that are_all_same_structures() calls further down.
matcher = StructureMatcher(
    comparator=ElementComparator(),
)

#  Load Pre-computed MP Collations

Loop through the MP structures written out by an earlier scrape of the Materials Project API.
Each line of those files is loaded as an XAS_Collation object, and the collations are indexed by their MP ID.

In [None]:
# Two views of the pre-computed collations, both keyed by element pair:
#   mp_cols[pair]      -> list of XAS_Collation objects in file order
#   mp_id_to_col[pair] -> the same objects, looked up by their mp_id
mp_cols = {pair: [] for pair in target_elements_groups}
mp_id_to_col = {pair: {} for pair in target_elements_groups}

for pair in tqdm_notebook(target_elements_groups):
    collation_path = os.path.join(
        storage_directory,
        '{}_{}_MP_API_collations.json'.format(*pair),
    )
    pair_cols = mp_cols[pair]
    pair_index = mp_id_to_col[pair]

    # The file holds one JSON-serialized collation per line.
    with open(collation_path, 'r') as handle:
        for raw_line in handle:
            collation = XAS_Collation.from_dict(json.loads(raw_line))
            pair_cols.append(collation)
            pair_index[collation.mp_id] = collation

    print("Loaded in {} for {}".format(len(pair_cols), pair))

# Load Pre-computed MP Spectra

In [None]:
# Attach every scraped spectrum to the collation that shares its MP ID,
# creating a new collation when no pre-computed one exists for that ID.
for pair in tqdm_notebook(target_elements_groups):
    file_name = '{}_{}_scraped_spectra+coord.json'.format(pair[0], pair[1])
    read_target = os.path.join(storage_directory, file_name)

    with open(read_target, 'r') as f:
        # One JSON-serialized spectrum per line; stream the file instead of
        # loading every line into memory first.
        for line in f:
            cur_dict = json.loads(line)
            try:
                cur_spec = XAS_Spectrum.from_dict(cur_dict)
            except Exception:
                # Was a bare `except:`, which would also swallow
                # KeyboardInterrupt / SystemExit. Any ordinary failure of
                # from_dict still falls back to the atomate-document parser.
                # The line is parsed again rather than reusing cur_dict, as
                # the original did (presumably to hand over an untouched
                # dict -- TODO confirm).
                cur_spec = XAS_Spectrum.from_atomate_document(json.loads(line))

            mp_id = cur_spec.metadata['id']
            cur_col = mp_id_to_col[pair].get(mp_id)
            if cur_col is not None:
                cur_col.mp_spectra.append(cur_spec)
                continue

            # No collation with this ID yet: start one from the spectrum's
            # own structure and register it in both lookup containers.
            new_col = XAS_Collation(cur_spec.structure, mp_id=mp_id,
                                    mp_spectra=[cur_spec])
            mp_cols[pair].append(new_col)
            mp_id_to_col[pair][new_col.mp_id] = new_col

## Status report pre-structural purge

In [None]:
#master_path = '/Users/steventorrisi/Documents/TRIXS/data/MP_Xas/master_trans_metal_oxides.json'
# Per element pair, report how many collations carry MP spectra and how many
# MP spectra exist in total, before any pruning is applied.
print("Total spectra found before pruning symmetrically equivalent sites")
for pair in target_elements_groups:
    pair_cols = mp_cols[pair]
    n_uniq_mp_struc = sum(1 for col in pair_cols if col.has_mp_spectra())
    n_mp_spec = sum(len(col.mp_spectra) for col in pair_cols)
    report = "Found {} unique structures with spectra with {} total spectra for {}".format(
        n_uniq_mp_struc, n_mp_spec, pair)
    print(report)
    

## Structure pruning 

In [None]:
def determine_uniqueness(strucs1, struc2):
    """Return True when ``struc2`` matches none of the structures in ``strucs1``.

    Only members of ``strucs1`` whose ``present_species`` attribute equals
    ``struc2.present_species`` are compared. Each of those is compared with a
    ``StructureMatcher`` built with ``attempt_supercell=True`` and an
    ``ElementComparator``.

    :param strucs1: structures already accepted; each must expose a
        ``present_species`` attribute.
    :param struc2: the candidate structure, also exposing ``present_species``.
    :return: False as soon as one comparison fits, True otherwise (including
        when there is nothing to compare against).
    """
    candidates = [known for known in strucs1
                  if known.present_species == struc2.present_species]

    # Nothing shares the candidate's species, so it cannot be a duplicate.
    if not candidates:
        return True

    structure_matcher = StructureMatcher(attempt_supercell=True,
                                         comparator=ElementComparator())
    return not any(structure_matcher.fit(known, struc2) for known in candidates)

## Determine unique structures, first doing ones which have spectra (so as to not accidentally throw away those).

In [None]:
# Split every pair's collations by whether they report having spectra.
has_spectra = {pair: [c for c in mp_cols[pair] if c.has_spectra()]
               for pair in target_elements_groups}
lacks_spectra = {pair: [c for c in mp_cols[pair] if not c.has_spectra()]
                 for pair in target_elements_groups}

# NOTE(review): unique_collations is filled here, but the later cells visible
# in this notebook keep working from mp_cols -- confirm that is intended.
unique_structures = {pair: [] for pair in target_elements_groups}
unique_collations = {pair: [] for pair in target_elements_groups}

for pair in target_elements_groups:
    with_spectra = has_spectra[pair]
    without_spectra = lacks_spectra[pair]
    print("{} collations with spectra before purge:{}".format(pair, len(with_spectra)))
    print("{} collations without spectra before purge:{}".format(pair, len(without_spectra)))

    seen_structures = unique_structures[pair]
    kept_collations = unique_collations[pair]

    progress = tqdm_notebook(with_spectra,
                             desc='Looping through {} structures'.format(pair))
    for col in progress:
        spectra = col.mp_spectra
        if len(spectra) > 0:
            # Every spectrum is loaded, but only the first one's structure
            # represents the collation in the comparison.
            loaded_structures = [XAS_Spectrum.load_from_object(spec).structure
                                 for spec in spectra]
            representative = loaded_structures[0]
        else:
            representative = col.structure

        # determine_uniqueness() filters on this attribute.
        representative.present_species = {str(sp) for sp in representative.species}

        if determine_uniqueness(seen_structures, representative):
            seen_structures.append(representative)
            kept_collations.append(col)

    # Collations without spectra are all kept, with no structure comparison.
    kept_collations.extend(without_spectra)
    # Drop this pair's structure list once the pair has been processed.
    del unique_structures[pair]

    n_kept_with_spectra = sum(1 for c in kept_collations if c.has_spectra())
    n_kept = len(kept_collations)
    print("After purging, unique collations with spectra / total for pair: {}, {}/{}".format(
        pair, n_kept_with_spectra, n_kept))


## Prune spectra which are redundant by symmetry

In [None]:
def are_all_same_structures(struc_list):
    """Return True when every structure in ``struc_list`` fits the first one.

    Comparisons use the module-level ``matcher``. A list with fewer than two
    structures has nothing to disagree with and is reported as consistent;
    previously an empty list raised ``IndexError`` on ``struc_list[0]``.

    :param struc_list: sequence of structures accepted by ``matcher.fit``.
    :return: False (after printing a warning) at the first structure that does
        not fit the first entry, True otherwise.
    """
    if len(struc_list) < 2:
        return True

    main = struc_list[0]
    for secondary in struc_list[1:]:
        if not matcher.fit(main, secondary):
            print("We got a serious problem")
            return False
    return True

The great symmetry purge

In [None]:
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
verbose = False  # not read anywhere in this cell

# For every collation with more than one MP spectrum, keep a single spectrum
# per set of symmetrically equivalent absorbing sites.
for pair in target_elements_groups:
    print("Spectra before purge for :", pair,
          sum(len(col.mp_spectra) for col in mp_cols[pair] if col.has_mp_spectra()))
    print("MP collations before purge:",
          len([col for col in mp_cols[pair] if col.has_mp_spectra()]))

    for col in tqdm_notebook([col for col in mp_cols[pair] if col.has_spectra()],
                             desc='Looping through {} collations with unique structures'.format(pair)):
        # With zero or one MP spectrum there is no risk of redundancy by
        # symmetry. (The old `== 1` test let an empty list through to the
        # structure check below, which then failed on an empty list.)
        if len(col.mp_spectra) < 2:
            continue

        # Materialize the spectra and collect their structures / absorbers.
        cur_specs = [XAS_Spectrum.load_from_object(spec) for spec in col.mp_spectra]
        cur_strucs = [spec.structure for spec in cur_specs]
        absorbing_indices = [spec.absorbing_site for spec in cur_specs]

        # All spectra of one collation must refer to the same structure,
        # otherwise the site indices below are not comparable.
        assert are_all_same_structures(cur_strucs)
        cur_struc = cur_strucs[0]

        sg = SpacegroupAnalyzer(cur_struc).get_space_group_operations()

        unique_sites = []    # absorbing-site indices kept so far
        specs_to_keep = []   # positions in cur_specs of the kept spectra
        for spec_pos, idx1 in enumerate(absorbing_indices):
            site1 = cur_struc.sites[idx1]
            # BUG FIX: the original assigned to a misspelled `unqiue`, so the
            # flag never became False and nothing was ever pruned.
            unique = not any(
                sg.are_symmetrically_equivalent([site1], [cur_struc.sites[idx2]],
                                                symm_prec=0.015)
                for idx2 in unique_sites
            )
            if unique:
                unique_sites.append(idx1)
                specs_to_keep.append(spec_pos)

        col.mp_spectra = [cur_specs[i] for i in specs_to_keep]

    print("Spectra after purge for pair:", pair,
          sum(len(col.mp_spectra) for col in mp_cols[pair] if col.has_mp_spectra()))
    print("Collations after purge for pair:", pair,
          len([col for col in mp_cols[pair] if col.has_mp_spectra()]))
    print('=====================')

In [None]:
# Write each pair's collations to its own file, one JSON document per line.
for pair in target_elements_groups:
    out_name = "{}_{}_MP_API+Spec+Symm_collations.json".format(*pair)
    out_path = os.path.join(storage_directory, out_name)
    with open(out_path, 'w') as handle:
        handle.writelines(json.dumps(col.as_dict()) + '\n'
                          for col in mp_cols[pair])

Sanity check : Eyeball spectral ranges

In [None]:
# For each pair: histogram the smallest x value of every MP spectrum and list
# the absorbing elements that were encountered.
for pair in target_elements_groups:
    lowest_x_values = []
    absorbers = []
    for col in mp_cols[pair]:
        if not col.has_mp_spectra():
            continue
        for raw_spec in col.mp_spectra:
            spec = XAS_Spectrum.load_from_object(raw_spec)
            struc = spec.structure  # fetched as in the original; not used below
            absorber = spec.absorbing_element
            absorbers.append(absorber)
            # Print the metadata of any spectrum whose absorbing element is
            # not the first element of the pair.
            if absorber != pair[0]:
                print(spec.metadata)
            lowest_x_values.append(np.min(spec.x))

    plt.hist(lowest_x_values)
    plt.title(pair)
    plt.show()
    print(set(absorbers))