In [None]:
%pylab inline

import pickle
import pandas as pd
import seaborn as sns
import pymatgen.core as mg
import plotly.express as px
from matplotlib.colors import LogNorm, Normalize, SymLogNorm
import warnings
import re
import itertools

import dft_analysis_tools as dat

def atom_num_sort(x):
    """Return the sorted atomic numbers of the elements named in formula string ``x``.

    The formula is split on runs of digits, so every element symbol has to be
    followed by a count to be separated from the next one (e.g. "Cr2Ge2Te6").
    """
    # Raw string: "\d" inside a normal string literal is an invalid escape sequence.
    symbols=[tok for tok in re.split(r"\d+",x) if tok]
    return sorted(mg.Element(sym).number for sym in symbols)

In [None]:
import re

def Better_Elem_List(formula,site_specs,subs,quiet=True):
    '''
    Assign the elements of a chemical formula to a fixed list of lattice sites.

    formula: chemical formula, e.g. 'Cr2Ge2Te6'. An element written without an
        explicit count (e.g. the Cr in 'CrGeTe3') is taken to occur once.
    site_specs: list of sites specifying shape of list. Ex: ['A','X','X']
    subs: Dictionary of substituted elements at each site. Ex: {'A':['Br','Cr',...],'X':[...]}
    quiet: if False, print a message for every atom that finds no free site.

    Returns (e_list, ind_list):
        e_list: copy of site_specs in which each filled site label is replaced by the
            element placed there (unfilled sites keep their label).
        ind_list: for each site, the position of its element within the formula
            (None for unfilled sites).

    Raises KeyError if the formula contains an element that is not listed in subs.
    '''
    e_list=list(site_specs)
    ind_list=[None]*len(site_specs)

    # element -> site type; the first site that lists an element wins
    site_of={}
    for site,site_elems in subs.items():
        for el in site_elems:
            site_of.setdefault(el,site)

    # Parse each symbol together with its own count so that a missing count cannot
    # shift the symbol/count pairing. A lowercase-leading run is still captured as a
    # token so that it fails the lookup below with KeyError instead of being skipped.
    tokens=re.findall(r'([A-Z][a-z]*|[a-z]+)(\d*)',formula)
    for i,(elem,count) in enumerate(tokens):
        site_type=site_of[elem]
        occ=int(count) if count else 1
        for _ in range(occ):
            try:
                ind=e_list.index(site_type)
            except ValueError:
                # no free site of this type is left for the atom
                if not quiet:
                    print(f'excess atom {elem} found in {formula}, repeat maybe?')
                continue
            e_list[ind]=elem
            ind_list[ind]=i
    return e_list,ind_list

# Site layout passed to Better_Elem_List: two A sites, two B sites, one X site.
site_specs=['A','A','B','B','X']

# Elements allowed on each site type.
subs={
    'A':[
        'Ti','V','Cr','Mn','Fe','Co','Ni','Cu','Y','Nb',
        'Ru','Ta','Mo','Tc','Sc','Zr','Hf','Ag','Au','Ir',
        'Os','Pd','Pt','Rh','W','Re','Zn','Hg','Cd',
    ],
    'B':['Ge','Si','P','In','Pb','As','Sn','Sb','Al'],
    'X':['Te','Se','S','Po'],
}

In [None]:
# NOTE: pickle.load can execute arbitrary code - only open pickle files from a trusted source.
with open('post_data-06-06-2023.pkl', 'rb') as file:
    jobs = pickle.load(file)

# List every finished job as "<workflow>/<name>".
for job in jobs:
    if job['state']=="JOB_FINISHED":
        print(f"{job['workflow']}/{job['name']}")

# Material prefix (text before the first "_") of each job whose workflow contains "mae_2020".
only_ground=[
    job['name'].split("_")[0]
    for job in jobs
    if "mae_2020" in job["workflow"]
]
only_ground

In [None]:
def rotate_vec(saxis,vec):
    """Rotate a vector given in the saxis basis into Cartesian coordinates.

    saxis: [sx, sy, sz], direction of the quantisation axis in Cartesian coordinates
        (it does not need to be normalised).
    vec: [mx, my, mz], components in the saxis basis, where the third component
        lies along saxis.

    Returns the Cartesian components as a list [x, y, z].

    The rotation is Rz(a)*Ry(b), with a the azimuthal and b the polar angle of saxis,
    so [0, 0, m] maps onto m times the unit vector along saxis. This is the
    basis -> Cartesian direction of the transform documented for VASP's SAXIS tag
    (NOTE(review): assumes the moments come from VASP - confirm). The previous body
    used the inverse (Cartesian -> basis) matrix with `mz` in place of `mx` in its
    second row, which flipped the sign of in-plane moments.
    """
    sx,sy,sz=saxis
    mx,my,mz=vec #vec components in saxis basis
    a=arctan2(sy,sx) #azimuthal angle of saxis, measured from the x axis
    b=arctan2(sqrt(sx**2+sy**2),sz) #polar angle of saxis, measured from the z axis
    return [cos(b)*cos(a)*mx-sin(a)*my+sin(b)*cos(a)*mz,
            cos(b)*sin(a)*mx+cos(a)*my+sin(b)*sin(a)*mz,
            -sin(b)*mx+cos(b)*mz]

def _empty_job_record():
    """Return a per-(material, spin state) record with every field unset."""
    return {'job_status': None, 'energy': None, 'total_mag': [None,None,None], 'mag_table': None,'NIONS': None,
            'lattice_vecs': None, 'pos': None, 'bandgap': None, 'bg_direct': None, 'bg_transition': None}


def _parse_finished_job(job,mat,s_s,spin_orbit_states,saxis):
    """Extract the results of one finished job from job['data'].

    Returns a dict with the keys NIONS, total_mag, mag_table, energy, pos,
    lattice_vecs, bandgap, bg_direct and bg_transition. A missing 'lattice_vecs'
    entry is reported and stored as None; any other missing entry raises KeyError.
    """
    data=job['data']
    NIONS=data['NIONS']
    if s_s in spin_orbit_states:
        mag_table=data['magnetization'][-3:] #last three tables: x, y and z components
        totals=[sum([ion_mag[-1] for ion_mag in mag_table[k][2]]) for k in range(3)]
        total_mag=rotate_vec(saxis[s_s],totals)
    else:
        mag_table=data['magnetization'][-1][2] #get table of ion magnetizations
        tots=[ion_mag[-1] for ion_mag in mag_table] #total magnetization of each ion
        total_mag=list(sum(tots)*saxis[s_s]) #total magnetization vector of material

    es=data['energy']
    if isinstance(es,list):
        # a list of values: keep the last energy and the last position table
        energy=es[-1]
        pos=array(data['position_force'][-1])[:,0:3]
    else:
        energy=es
        pos=array(data['position_force'])[:,0:3]

    # Missing lattice vectors are tolerated. They are handled here, before the band gap
    # is read, so that the band-gap fields always come from this job.
    if 'lattice_vecs' in data:
        lattice_vecs=array(data['lattice_vecs'][-1])[:,0:3]
    else:
        print(f"KeyError on {mat}, {s_s}, missing value 'lattice_vecs'")
        lattice_vecs=None

    bg=data['bandgap']
    if not isinstance(bg,dict):
        raise KeyError("bandgap")
    return {'NIONS': NIONS, 'total_mag': total_mag, 'mag_table': mag_table, 'energy': energy,
            'pos': pos, 'lattice_vecs': lattice_vecs, 'bandgap': bg['energy'],
            'bg_direct': bg['direct'], 'bg_transition': bg['transition']}


def unpack_balsam(pickle_file,
                  spin_states=None,
                  materials=None,
                  spin_orbit_states=None,
                  saxis=None,
                  read_state=True):
    """Load a pickled list of job dicts and collect results per material and spin state.

    pickle_file: path of the pickle file holding the list of jobs. Each job is a dict
        with a 'name' of the form '<material>_<spin state>', a 'state' and a 'data' dict.
    spin_states: spin-state labels to collect; defaults to
        ['fm_op', 'afm_op', 'fm_ip', 'afm_ip'].
    materials: materials to collect; defaults to every material prefix found in the file
        (in arbitrary order, because the prefixes are gathered through a set).
    spin_orbit_states: spin states whose magnetization is read as three (x, y, z) tables
        and rotated with rotate_vec; defaults to the same four labels as spin_states.
    saxis: dict mapping spin state -> quantisation axis; defaults to [0,0,1] for the
        '*_op' states and [1,0,0] for the '*_ip' states.
    read_state: if truthy, only jobs whose 'state' ends with 'FINISHED' are parsed;
        if falsy, the state is ignored and every recognised job is parsed.

    Returns (job_table, job_data):
        job_table: one row per material: [material, NIONS] followed, for each spin
            state, by status, energy, m_x, m_y, m_z, mag_table, pos, lattice_vecs,
            bandgap, bg_direct, bg_transition.
        job_data: {material: {spin state: record}} holding the same fields by name.

    Jobs whose material or spin state is not recognised (including names without "_")
    are reported and skipped. A finished job with a missing entry other than
    'lattice_vecs' is stored with status "FINISHED_INCORRECTLY" and empty fields.
    """
    default_states=['fm_op', 'afm_op', 'fm_ip', 'afm_ip']
    if spin_states is None:
        spin_states=list(default_states)
    if spin_orbit_states is None:
        spin_orbit_states=list(default_states)
    if saxis is None:
        saxis={'fm_op': array([0,0,1]), 'afm_op': array([0,0,1]), 'fm_ip': array([1,0,0]), 'afm_ip': array([1,0,0])}

    # NOTE: pickle.load can execute arbitrary code - only open files from a trusted source.
    with open(pickle_file, 'rb') as file:
        jobs = pickle.load(file)

    if materials is None:
        materials=list(set([j['name'].split("_")[0] for j in jobs]))
    job_data={m: {s: _empty_job_record() for s in spin_states} for m in materials}

    for job in jobs:
        # partition() yields an empty spin state for a name without "_" instead of raising
        mat,_,s_s=job['name'].partition("_")
        if mat not in materials or s_s not in spin_states:
            print(f"I don't recognize {mat} {s_s}")
            continue
        job_status=job['state'] if read_state else ""
        record=_empty_job_record() #stays empty if the job did not finish or cannot be parsed
        if job_status.endswith('FINISHED') or not read_state:
            try:
                record.update(_parse_finished_job(job,mat,s_s,spin_orbit_states,saxis))
            except KeyError as e:
                print(f"KeyError on {mat}, {s_s}, missing value {e}")
                job_status="FINISHED_INCORRECTLY"
        record['job_status']=job_status
        job_data[mat][s_s]=record

    #write data in format:
    # material NIONS then, per spin state: status energy m_x m_y m_z mag_table pos lattice_vecs bandgap bg_direct bg_transition
    n_dat=11 #number of data per spin state
    t_len=2+n_dat*len(spin_states)
    job_table=[]
    for mat in materials:
        row=[None]*t_len
        row[0]=mat
        for k,s_s in enumerate(spin_states):
            rec=job_data[mat][s_s]
            row[(2+k*n_dat):(2+n_dat+k*n_dat)]=(
                [rec['job_status'],rec['energy']]
                +rec['total_mag']
                +[rec['mag_table'],rec['pos'],rec['lattice_vecs'],
                  rec['bandgap'],rec['bg_direct'],rec['bg_transition']])
            # NIONS is taken from the first spin state that provides it
            if row[1] is None and rec['NIONS'] is not None:
                row[1]=rec['NIONS']
        job_table.append(row)

    return job_table,job_data

In [None]:
# Parse the job pickle: `table` holds one row per material and `data` the nested
# {material: {spin state: fields}} dict (the two values returned by unpack_balsam).
table,data=unpack_balsam('post_data-06-06-2023.pkl')

In [None]:
# Spin-state labels used below to build and look up the per-state DataFrame columns.
spin_states=['fm_op', 'afm_op', 'fm_ip', 'afm_ip']

In [None]:
# Column names for the rows of `table`: material, atom count, then eleven fields per spin state.
per_state_fields=["Status","E","m_x","m_y","m_z","mag_table","pos","lattice","bandgap","bg_direct","bg_transition"]
header=['Material','NATOMS']+[f"{s_s} {field}" for s_s in spin_states for field in per_state_fields]

dt=pd.DataFrame(data=table,columns=header)

# Map each formula onto the site list; the two returned lists become two columns.
elems_elem_ind=dt['Material'].apply(
    lambda formula: pd.Series(
        Better_Elem_List(formula,site_specs=site_specs,subs=subs),
        index=['elems','elem_ind'],
    )
)
dt['elems']=elems_elem_ind['elems']
# Extend the site list with five more copies of its last entry.
dt["elems_full"]=dt["elems"].apply(lambda sites: sites+[sites[-1]]*5)
dt['elem_ind']=elems_elem_ind['elem_ind']

dt=dt.set_index("Material")
materials=dt.index.unique(level="Material")


def theta(v):
    """Return the angle (radians) between the z axis and the vector [vx,vy,vz].

    Returns nan if v is nan or contains a nan, and 0 for the zero vector.
    v may be a list, tuple or array of three numbers.
    """
    # isnan(v).any() also handles a bare nan and does not rely on `sum` being numpy's
    if isnan(v).any():
        return nan
    norm=sqrt(v[0]**2+v[1]**2+v[2]**2)
    # testing the norm works for any sequence type; `v==[0,0,0]` only did for lists
    if norm==0:
        return 0
    return arccos(v[2]/norm)
    
def phi(v):
    """Return the azimuthal angle (radians) of the vector [vx,vy,vz].

    Returns nan if v is nan or contains a nan, and 0 for the zero vector.
    v may be a list, tuple or array of three numbers.
    """
    # isnan(v).any() also handles a bare nan and does not rely on `sum` being numpy's
    if isnan(v).any():
        return nan
    # component-wise test: `v==[0,0,0]` raised ValueError for array inputs
    if v[0]==0 and v[1]==0 and v[2]==0:
        return 0
    return arctan2(v[1],v[0])
    
def mag(v):
    """Return the Euclidean length of [vx,vy,vz], or nan when the components sum to nan."""
    if not isnan(sum(v)):
        return sqrt(v[0]**2+v[1]**2+v[2]**2)
    return nan

site_names=["A1","A2","B1","B2","X1","X2","X3","X4","X5","X6"]

def _site_moment(mag_table,atom_ind,in_plane):
    """Return the [x,y,z] moment of one atom from a three-table magnetization entry.

    Returns nan when mag_table is None. For in-plane states the moment is rotated
    with rotate_vec using the fixed axis [1,0,0].
    """
    if mag_table is None:
        return nan
    moment=[mag_table[xyz][2][atom_ind][-1] for xyz in [0,1,2]]
    return rotate_vec([1,0,0],moment) if in_plane else moment

for s_s in spin_states:
    print(s_s)
    # astype(float) turns missing entries (None) into NaN and guarantees a numeric dtype
    # for the sqrt below, without relying on fillna's implicit object-dtype downcasting.
    dt[s_s+" E"]=dt[s_s+" E"].astype(float)
    for comp in ['x','y','z']:
        dt[s_s+" m_"+comp]=dt[s_s+" m_"+comp].astype(float)
    dt[s_s+" mag"]=sqrt(dt[s_s+" m_x"]**2+dt[s_s+" m_y"]**2+dt[s_s+" m_z"]**2)

    in_plane="ip" in s_s
    for atom_ind,atom in enumerate(site_names):
        dt[s_s+f" mu_{atom}"]=dt[s_s+" mag_table"].apply(
            lambda tbl: _site_moment(tbl,atom_ind,in_plane))
        dt[s_s+f" mag_{atom}"]=dt[s_s+f" mu_{atom}"].apply(mag)
        dt[s_s+f" theta_{atom}"]=dt[s_s+f" mu_{atom}"].apply(theta)
        dt[s_s+f" phi_{atom}"]=dt[s_s+f" mu_{atom}"].apply(phi)

    # Average over this state's A-site columns only. filter(like=...) is a substring
    # match, so "fm_op mag_A" would also pick up the "afm_op mag_A*" columns if present.
    a_site_cols=[f"{s_s} mag_{atom}" for atom in site_names if atom.startswith("A")]
    dt[f"{s_s} mag_avg"]=dt[a_site_cols].mean(axis=1)


def get_min_prop(x,p):
    """Return property ``p`` of row ``x`` for the row's ground spin state.

    The value is read from the column "<Ground Spin State> <p>"; nan is returned
    when the row's "Ground Spin State" entry is not a string.
    """
    ground=x["Ground Spin State"]
    return x[f"{ground} {p}"] if isinstance(ground,str) else nan

considered_ss=["fm_op","afm_op"]
energy_cols=[f"{ss} E" for ss in considered_ss]
energies=dt[energy_cols].astype(float)

# E_min is NaN unless every considered spin state has an energy.
dt["E_min"]=energies.min(axis=1,skipna=False)

# The ground spin state is only defined for rows with all energies present. idxmin is
# restricted to those rows so it never has to deal with NaN.
complete=energies.notna().all(axis=1)
ground_state=pd.Series(nan,index=dt.index,dtype=object)
if complete.any():
    ground_state.loc[complete]=energies.loc[complete].idxmin(axis=1).apply(lambda col: col.split()[0])
dt["Ground Spin State"]=ground_state

# Materials in only_ground may have just one of the states computed: take the lowest
# energy that exists. Rows without any energy keep NaN.
for mat in only_ground:
    mat_energies=energies.loc[mat]
    if mat_energies.notna().any():
        dt.at[mat,"E_min"]=mat_energies.min()
        dt.at[mat,"Ground Spin State"]=mat_energies.idxmin().split()[0]

mag_props=["m_x","m_y","m_z","mag_table","pos","lattice","mag","bandgap"]+[
    f"{x}_{atom}"for x,atom in itertools.product(["mu","theta","mag","phi"],site_names)]

# Copy every property of the ground spin state into a "Ground <property>" column.
for p in mag_props:
    dt[f"Ground {p}"]=dt.apply(lambda x: get_min_prop(x,p),axis=1)

# Per-row lists over all sites (in site_names order) of the ground-state moments and magnitudes.
dt["atom_mus"]=dt[[f"Ground mu_{atom}" for atom in site_names]].values.tolist()
dt["atom_mags"]=dt[[f"Ground mag_{atom}" for atom in site_names]].values.tolist()

# calc_formation_energy comes from dft_analysis_tools (not shown here).
dt["E_form"]=dt.apply(lambda x: dat.calc_formation_energy(x.elems_full,x.E_min),axis=1)

In [None]:
# Keep only materials with a ground-state energy. The explicit copy lets later cells
# add columns to dt_nona without writing into a slice of dt (SettingWithCopyWarning).
dt_nona=dt[dt["E_min"].notna()].copy()

In [None]:
def poscar_to_atoms(poscar_filename):
    """Load the structure in ``poscar_filename`` via ``Atoms.from_poscar`` and return it as a dict.

    NOTE(review): `js` is not imported anywhere in the visible notebook; the call
    pattern suggests the jarvis-tools package - confirm and add the import.
    """
    return dict(js.atoms.Atoms.from_poscar(poscar_filename).to_dict())

In [None]:
# assign() returns a new frame, so the columns are not written into a slice of dt.
dt_nona=dt_nona.assign(poscar_name=[f"contcars/{mat}" for mat in dt_nona.index])
# Apply over the poscar_name column: DataFrame.apply without axis=1 iterates over
# columns, so the previous `x.poscar_name` lookup could not work.
dt_nona["atoms"]=dt_nona["poscar_name"].apply(poscar_to_atoms)

In [None]:
# output as .json format (needed for atomwise data, i.e. train_folder_ff.py):
# ("Ground mag" was misspelled "Gound mag" here, which raised a KeyError)
out_data=dt_nona[["atoms","E_min","Ground mag","atom_mus","atom_mags"]]
out_data=out_data.reset_index()

#output columns must include a unique identifier as 'jid' and the jarvis atoms as 'atoms', all other columns
#are for various targets
#(renamed by position: the first column is the former index produced by reset_index)
out_data.columns=["jid","atoms","total_energy","total_mag","mus","mags"]

with open("id_prop-CGT-CONTCAR.json",'w') as f:
    f.write(out_data.to_json(orient='records',lines=False))

In [None]:
# Write the targets as a header-less, index-less CSV
# (default for graphwise data, i.e. train_folder.py).
out_folder="out_folder" #name of output folder, should be the same folder as the POSCARs
csv_columns=["poscar_name","E_min","Ground mag"]
out_data=dt_nona[csv_columns]
out_data.to_csv(f"{out_folder}/id_prop.csv",index=False,header=False)