In [None]:
import os
import glob
import numpy as np
from ase.io import read
import matplotlib.pyplot as plt

def get_energy_from_outcar(file_path):
    """Return the number at the end of the last 'sigma' line of a file.

    The file is scanned from its last line towards its first. The first
    line encountered that contains the substring ``'sigma'`` is used: its
    final whitespace-separated token is converted to ``float``.

    Args:
        file_path: Path of the OUTCAR file to read.

    Returns:
        The parsed value as a float, or ``None`` when no line contains
        ``'sigma'``.

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the last token of the matching line is not numeric.
    """
    with open(file_path) as outcar:
        lines = outcar.readlines()

    # Iterate in true reverse order so the final entry in the file wins.
    # (Indexing with ``lines[-i]`` starting at i == 0 would look at the very
    # first line before the last one.)
    for line in reversed(lines):
        if 'sigma' in line:
            return float(line.split()[-1])
    return None

def collect_energies(directory='.', pattern='OUTCAR*', exclude_terms=('idm', 're', 'zpe', 'sp')):
    """Collect energies from output files stored in numbered sub-directories.

    ``directory`` is walked recursively and every path matching ``pattern``
    is considered. Paths are ordered by the integer value of their first
    folder below ``directory`` (e.g. ``01``, ``02``, ...), and each file is
    parsed with ``get_energy_from_outcar``.

    Args:
        directory: Root folder to search.
        pattern: Glob pattern applied inside every visited folder.
        exclude_terms: Substrings; a file is ignored when its path relative
            to ``directory`` contains any of them.

    Returns:
        A list of energies (floats) in ascending folder-number order. Files
        for which no energy could be parsed are left out.
    """
    candidates = []
    for root, _dirs, _files in os.walk(directory):
        # Escape the folder part so characters like '[' in a directory name
        # are not interpreted as glob wildcards.
        for path in glob.glob(os.path.join(glob.escape(root), pattern)):
            # Work on the path relative to the search root: the exclusion
            # terms and the numeric sort key must not depend on where
            # ``directory`` itself lives (or on the path separator).
            relative = os.path.relpath(path, directory)
            if any(term in relative for term in exclude_terms):
                continue

            top_level = relative.split(os.sep)[0]
            try:
                index = int(top_level)
            except ValueError:
                # Not inside a numbered folder, so it cannot be ordered.
                print(f" ### {path}: skipped (not inside a numbered directory)")
                continue
            candidates.append((index, relative, path))

    # Sort by folder number; the relative path breaks ties deterministically.
    candidates.sort()

    energies = []
    for _index, _relative, path in candidates:
        energy = get_energy_from_outcar(path)
        if energy is not None:
            energies.append(energy)
            print(f" ### {path}: {energy}")

    return energies

def _read_image_structure(index):
    """Load the structure of image ``index`` (CONTCAR first, POSCAR fallback)."""
    folder = f'{index:02}'
    try:
        return read(f'{folder}/CONTCAR')
    except Exception:
        # Missing or unreadable CONTCAR: fall back to the POSCAR in the same
        # folder. ``Exception`` (not a bare except) keeps Ctrl-C working.
        return read(f'{folder}/POSCAR')


def _minimum_image_displacements(pos_a, pos_b, cell):
    """Return the shortest periodic displacement vectors ``pos_a - pos_b``."""
    delta = np.asarray(pos_a, dtype=float) - np.asarray(pos_b, dtype=float)
    cell = np.array(cell, dtype=float)

    # Without a valid 3D cell there is nothing to wrap.
    if abs(np.linalg.det(cell)) < 1e-12:
        return delta

    # Step 1: wrap every displacement into the range [-0.5, 0.5) in
    # fractional coordinates (rows of ``cell`` are the lattice vectors).
    fractional = delta @ np.linalg.inv(cell)
    fractional -= np.round(fractional)
    wrapped = fractional @ cell

    # Step 2: for skewed cells the wrapped vector is not always the shortest
    # one, so compare it against all 27 neighbouring periodic images and
    # keep the candidate with the smallest total length.
    shifts = np.array(
        [(a, b, c) for a in (-1, 0, 1) for b in (-1, 0, 1) for c in (-1, 0, 1)],
        dtype=float,
    ) @ cell
    candidates = wrapped[:, None, :] + shifts[None, :, :]
    squared = np.einsum('ijk,ijk->ij', candidates, candidates)
    best = np.argmin(squared, axis=1)
    return candidates[np.arange(len(wrapped)), best]


def calculate_mwcc(energies, init_index=1):
    """Compute cumulative mass-weighted path lengths between images.

    Structures are read from folders named with two-digit indices starting
    at ``init_index`` (``f'{index:02}/CONTCAR'``, falling back to
    ``POSCAR``). For each pair of consecutive images the displacement of
    every atom is reduced to its shortest periodic image using the cell of
    the later image, and ``sqrt(mass) * |displacement|`` is summed over
    all atoms.

    NOTE(review): the per-step length is a sum over atoms of
    ``sqrt(m) * |d|`` (kept from the original implementation), not
    ``sqrt(sum(m * |d|**2))`` -- confirm this is the intended definition.

    Args:
        energies: Sequence whose length determines how many images are
            processed; the values themselves are not used.
        init_index: Folder number of the first image.

    Returns:
        A NumPy array of cumulative distances. It starts at 0 and has one
        entry per image (a single ``0`` when fewer than two images are
        given).

    Raises:
        ValueError: If two consecutive images contain a different number of
            atoms.
    """
    mwcc_list = [0]
    n_images = len(energies)
    if n_images < 2:
        return np.cumsum(mwcc_list)

    # Each structure is read once and reused as the "previous" image of the
    # next step.
    previous = _read_image_structure(init_index)
    for offset in range(1, n_images):
        current = _read_image_structure(init_index + offset)

        pos_current = current.get_positions()
        pos_previous = previous.get_positions()
        if pos_current.shape != pos_previous.shape:
            raise ValueError(
                f'Images {init_index + offset - 1:02} and '
                f'{init_index + offset:02} have different numbers of atoms'
            )

        displacements = _minimum_image_displacements(
            pos_current, pos_previous, current.get_cell()
        )
        distances = np.linalg.norm(displacements, axis=1)
        step = np.sum(np.sqrt(current.get_masses()) * distances)

        mwcc_list.append(float(step))
        previous = current

    return np.cumsum(mwcc_list)

def plot_reaction_pathway(coordinates_list, energies_list, labels, markers, colors, 
                         image_paths=None, image_positions=None, figsize=(14, 7)):
    """Draw one energy profile per data series on a shared axis.

    Every series is plotted as energy relative to its own first value
    against the supplied coordinates.

    Args:
        coordinates_list: One sequence of x values per series.
        energies_list: One sequence of energies per series.
        labels: Legend label for each series.
        markers: Marker symbol for each series.
        colors: For each series, a sequence whose items 0, 1 and 2 are the
            line colour, marker edge colour and marker face colour.
        image_paths: Accepted for interface compatibility; not used here.
        image_positions: Accepted for interface compatibility; not used here.
        figsize: Figure size passed to ``plt.subplots``.

    Returns:
        The ``(fig, ax)`` pair created by ``plt.subplots``.
    """
    figure, axes = plt.subplots(figsize=figsize, dpi=300)

    series = zip(coordinates_list, energies_list, labels, markers, colors)
    for path_coords, path_energies, name, symbol, palette in series:
        # Shift the profile so that its first point sits at zero.
        shifted = []
        for value in path_energies:
            shifted.append(value - path_energies[0])

        axes.plot(
            path_coords,
            shifted,
            marker=symbol,
            markersize=10,
            color=palette[0],
            mec=palette[1],
            mfc=palette[2],
            lw=3,
            label=name,
        )

    axes.set_xlabel('Mass-weighted coordinate (\u212B√amu)', fontsize=20)
    axes.set_ylabel('Relative energy (eV)', fontsize=20)
    for which_axis in ('x', 'y'):
        axes.tick_params(axis=which_axis, labelsize=20)
    axes.legend(fontsize=20)

    return figure, axes

def _wants_endpoints(answer):
    """Return True when the reply to the yes/no prompt counts as 'yes'.

    'yes' / 'y' (any case) are affirmative; as before, a reply that parses
    as a number is treated as affirmative too.
    """
    if answer.lower() in ('yes', 'y'):
        return True
    try:
        float(answer)
    except ValueError:
        return False
    return True


def _activation_energy(energies):
    """Return the highest energy minus the lowest energy preceding it."""
    max_energy = max(energies)
    before_peak = energies[:energies.index(max_energy)]
    # When the maximum is the very first point there is nothing before it,
    # so the first energy itself is the reference.
    min_energy = min(before_peak) if before_peak else energies[0]
    return max_energy - min_energy


def main():
    """Build the energy series, compute path coordinates and show the plot.

    With ``use_predefined_data`` enabled, hard-coded DFT/SGPR/BCM energies
    are used. Otherwise energies are collected via ``collect_energies`` and
    the user may add initial/final state energies interactively; in that
    case the image folders are read starting from index 0 instead of 1.
    """
    # work_dir = '../work/dir'
    # os.chdir(work_dir)

    use_predefined_data = True

    if use_predefined_data:
        dft_energies = [-101.898, -101.739, -101.733, -101.549, -103.627,
                        -103.646, -103.628, -103.686, -103.765, -103.798, -103.780]

        sgpr_energies = [-101.90625964924551, -101.73851605316437, -101.73265493472852,
                         -101.54821557363309, -103.62763907572068, -103.64616501143016,
                         -103.62837792744673, -103.68629474004544, -103.76470008304156,
                         -103.79818572660722, -103.77987138738669]

        bcm_energies = dft_energies.copy()
        init_index = 1
    else:
        print("Energy in OUTCAR reading...")
        energies = collect_energies()

        print('\n - Include first/final Energy difference?')
        whether = input(' -- yes/no? :  ')

        init_index = 1
        if _wants_endpoints(whether):
            try:
                initial = float(input(' --- Initial state energy (eV) :  '))
                final = float(input(' --- Final state energy (eV) :  '))
            except ValueError:
                # Same outcome as before (endpoints are not added), but the
                # user is now told why.
                print(' --- Invalid number; endpoint energies ignored.')
            else:
                energies.insert(0, initial)
                energies.append(final)
                # The inserted endpoints correspond to folder 00 onwards.
                init_index = 0

        if not energies:
            raise ValueError('No energies were collected; nothing to plot.')

        E_a = _activation_energy(energies)
        print(f"\n Activation Energy: {E_a} eV")

        dft_energies = energies
        sgpr_energies = energies
        bcm_energies = energies

    print("\n Calculating mass weighted coordinates...")
    # Pass the starting folder index along; previously it was computed but
    # never used, so the endpoint mode read the wrong folders.
    dft_coords = calculate_mwcc(dft_energies, init_index)
    sgpr_coords = calculate_mwcc(sgpr_energies, init_index)
    bcm_coords = calculate_mwcc(bcm_energies, init_index)

    labels = ['DFT', 'SGPR', 'BCM']
    markers = ['D', 'o', '^']
    colors = [
        ['#f28e2b', 'black', '#f28e2b'],  # DFT: color, mec, mfc
        ['#4e79a7', 'white', '#4e79a7'],  # SGPR
        ['#76b7b2', 'white', '#76b7b2']   # BCM
    ]

    image_paths = [
    ]

    image_positions = [
        [-0.05, 0.3, 0.4, 0.2, 300, 300],
        [0.33, 0.55, 0.3, 0.3, 150, 150],
        [0.7, 0.15, 0.3, 0.3, 140, 140]
    ]

    fig, ax = plot_reaction_pathway(
        [dft_coords, sgpr_coords, bcm_coords],
        [dft_energies, sgpr_energies, bcm_energies],
        labels, markers, colors,
        image_paths, image_positions
    )

    plt.tight_layout()
    plt.show()
    
# Run main() only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    main()