In [None]:
# ------ CRYSTAL STRUCTURE ANALYSIS ------

# Crystal Structure Analysis
#   1. import crystal structures
#   2. relax each structure with the BFGS algorithm (cell parameters may change)
#   3. calculate the relative energy per atom (with respect to a reference structure)
#   4. calculate lattice parameters (average bond lengths are not computed yet)
#   5. create JSON raw data files

from pathlib import Path
from ase.io import read, write
from ase.optimize import BFGS
from ase.filters import UnitCellFilter
from graph_pes.models import load_model
import json

# Imports structures from in_dir and creates dict structures = {file_path: structure}
def import_crystal_structures(in_dir):
    """Read every structure file found (recursively) under ``in_dir``.

    Parameters
    ----------
    in_dir : str or Path
        Directory searched recursively. Every file (directories are skipped)
        is read with ``ase.io.read`` using its default arguments.

    Returns
    -------
    dict[Path, ase.Atoms]
        Mapping from each file's path to the structure read from it, in
        sorted path order. Empty if no files were found.
    """
    in_dir = Path(in_dir)
    structures = {}

    # rglob('*') also yields sub-directories, which cannot be parsed as
    # structures; sorting gives a deterministic processing order.
    for file in sorted(in_dir.rglob('*')):
        if not file.is_file():
            continue
        structures[file] = read(file)

    if structures:
        print(f"Imported {len(structures)} structures from {in_dir}\n")

    return structures

# Calculate reference energy with a given model (including relaxation)
def reference_energy(path_to_reference_structure, path_to_model, fmax, steps, OVERWRITE):
    """Relax a reference structure with a model and return its energy per atom.

    The reference structure is relaxed with BFGS under a ``UnitCellFilter``
    (atomic positions and cell parameters both relax). Outputs are stored next
    to the reference structure in ``Relaxed_Reference_Structures``: the
    trajectory in ``Trajectories`` and its final frame as a CIF in
    ``Final Trajectory Frame``.

    A cached relaxation is reused when both the trajectory and the final CIF
    exist and ``OVERWRITE`` is false. The CIF is only written after a
    relaxation has finished, so a trajectory left by an interrupted run is
    not mistaken for a converged one.

    Parameters
    ----------
    path_to_reference_structure : str or Path
        Structure file readable by ``ase.io.read``.
    path_to_model : str or Path
        Model file loaded with ``graph_pes.models.load_model``.
    fmax : float
        BFGS force convergence criterion (passed to ``run``).
    steps : int
        Maximum number of BFGS steps (passed to ``run``).
    OVERWRITE : bool
        Redo the relaxation even when cached results exist.

    Returns
    -------
    float
        Energy of the relaxed reference structure divided by its number of
        atoms, as reported by the model's ASE calculator.
    """
    reference_path = Path(path_to_reference_structure)
    model_path = Path(path_to_model)

    # Create output directories next to the reference structure file
    relaxed_dir = reference_path.parent / "Relaxed_Reference_Structures"
    traj_dir = relaxed_dir / "Trajectories"
    final_traj_dir = relaxed_dir / "Final Trajectory Frame"
    traj_dir.mkdir(parents=True, exist_ok=True)
    final_traj_dir.mkdir(parents=True, exist_ok=True)

    # The trajectory name keeps the model's full file name (suffix included),
    # unlike the CIF, so that previously cached trajectories are still found.
    traj_path = traj_dir / f"{model_path.name}_{reference_path.stem}.traj"
    final_traj_out_path = final_traj_dir / f"{model_path.stem}_relaxed_{reference_path.stem}.cif"

    # Load model and calculator
    reference_calc = load_model(model_path).ase_calculator()

    cached = traj_path.exists() and final_traj_out_path.exists()
    if OVERWRITE or not cached:
        print(f"Relaxing reference structure: {traj_path.name}")

        # Invalidate the completion marker until this relaxation finishes
        final_traj_out_path.unlink(missing_ok=True)

        ref_struct = read(reference_path)
        ref_struct.calc = reference_calc

        # Relax positions and cell together; the context manager closes the
        # trajectory file so it is complete before being read back below.
        with BFGS(UnitCellFilter(ref_struct),
                  logfile=None,
                  trajectory=traj_path) as opt_ref:
            opt_ref.run(fmax=fmax, steps=steps)

    print(f"Reading relaxed reference structure: {traj_path.name}")

    # Final frame of the relaxation; writing the CIF marks the run as complete
    relaxed_structure = read(traj_path, index=-1)
    write(final_traj_out_path, relaxed_structure)

    # Energy of the relaxed reference structure via the public ASE API
    relaxed_structure.calc = reference_calc
    ref_energy = float(relaxed_structure.get_potential_energy())

    return ref_energy / len(relaxed_structure)

# Relax every imported structure with each model and write per-structure JSON energetics
def energetics_calculator(in_dir, relaxed_dir, fmax, steps,
                          path_to_models, energetics_data_dir,
                          path_to_reference_structure, OVERWRITE):
    """Relax each structure with each model and write its energetics as JSON.

    For every model in ``path_to_models``:

    1. ``reference_energy`` provides the reference energy per atom.
    2. Every structure returned by ``import_crystal_structures(in_dir)`` is
       relaxed (on a copy) with BFGS under a ``UnitCellFilter``, so atomic
       positions and cell both relax, using ``fmax`` and ``steps``.
    3. The trajectory is written to ``<relaxed_dir>/Trajectories`` and its
       final frame as a CIF to ``<relaxed_dir>/Final Trajectory Frame``.
    4. Lattice parameters, reference/raw/relative energy per atom and forces
       of the final frame are written as JSON to
       ``<energetics_data_dir>/<model file name>/<structure file stem>``
       (no file extension).

    A structure is skipped for a model when both its trajectory and its JSON
    file already exist and ``OVERWRITE`` is false. The JSON file is written
    last, so a run interrupted mid-relaxation is redone rather than skipped.

    Parameters
    ----------
    in_dir : str or Path
        Directory searched recursively for input structure files.
    relaxed_dir : str or Path
        Directory receiving trajectories and final-frame CIF files.
    fmax : float
        BFGS force convergence criterion (passed to ``run``).
    steps : int
        Maximum number of BFGS steps (passed to ``run``).
    path_to_models : iterable of str or Path
        Model files loaded with ``graph_pes.models.load_model``.
    energetics_data_dir : str or Path
        Root directory for the JSON output.
    path_to_reference_structure : str or Path
        Structure whose relaxed energy per atom is subtracted to give the
        relative energy.
    OVERWRITE : bool
        Redo relaxations even when cached results exist (also forwarded to
        ``reference_energy``).

    Returns
    -------
    None
        Results are written to disk; progress is printed.
    """
    # Import structures
    structure_dict = import_crystal_structures(in_dir)

    # Output directories shared by all models
    traj_dir = Path(relaxed_dir) / "Trajectories"
    final_traj_dir = Path(relaxed_dir) / "Final Trajectory Frame"
    traj_dir.mkdir(parents=True, exist_ok=True)
    final_traj_dir.mkdir(parents=True, exist_ok=True)

    existing_traj_files_counter = 0

    for path_to_model in path_to_models:
        model_path = Path(path_to_model)

        # Reference energy for this model (relaxes the reference if needed)
        ref_energy_per_atom = float(reference_energy(path_to_reference_structure, path_to_model,
                                                     fmax, steps, OVERWRITE))

        calculator = load_model(model_path).ase_calculator()
        data_out_dir = Path(energetics_data_dir) / model_path.name

        counter = 0

        for file_path, structure in structure_dict.items():
            traj_out_path = traj_dir / f"{model_path.stem}_relaxed_{file_path.stem}.traj"
            final_traj_out_path = final_traj_dir / f"{model_path.stem}_relaxed_{file_path.stem}.cif"
            data_out_path = data_out_dir / file_path.stem

            if not OVERWRITE and traj_out_path.exists() and data_out_path.exists():
                existing_traj_files_counter += 1
                continue

            data_out_dir.mkdir(parents=True, exist_ok=True)
            # Invalidate the completion marker until this structure is done
            data_out_path.unlink(missing_ok=True)

            # Relax a copy: structure_dict is shared by every model, and an
            # in-place relaxation would start each model from the geometry
            # relaxed by the previous one.
            atoms = structure.copy()
            atoms.calc = calculator

            # Relax positions and cell; the context manager closes the
            # trajectory file so it is complete before being read back.
            with BFGS(UnitCellFilter(atoms),
                      logfile=None,
                      trajectory=traj_out_path) as opt:
                opt.run(fmax=fmax, steps=steps)

            # Write final relaxed structure
            final_structure = read(traj_out_path, index=-1)
            write(final_traj_out_path, final_structure)

            # Energy and raw (unconstrained) forces of the final structure
            final_structure.calc = calculator
            raw_energy_per_atom = float(final_structure.get_potential_energy()) / len(final_structure)
            forces = final_structure.get_forces(apply_constraint=False)

            # Relative energy per atom
            relative_energy_per_atom = raw_energy_per_atom - ref_energy_per_atom

            # Lattice parameters of the relaxed cell
            a, b, c, alpha, beta, gamma = map(float, final_structure.cell.cellpar())

            data = {
                "lattice_parameters": {
                    "a": a,
                    "b": b,
                    "c": c,
                    "alpha": alpha,
                    "beta": beta,
                    "gamma": gamma
                },
                "reference_energy/atom": ref_energy_per_atom,
                "raw_energy/atom": raw_energy_per_atom,
                "relative_energy/atom": relative_energy_per_atom,
                "forces": forces.tolist()
            }

            with open(data_out_path, "w", encoding="utf-8") as f:
                f.write(json.dumps(data, indent=2) + "\n")

            counter += 1

        if counter:
            print(f"\nRelaxed and analyzed {counter} crystal structures with {model_path.name}\n")

    if existing_traj_files_counter:
        print(f"\nSkipped energetics analysis for {existing_traj_files_counter} analyzed files")

# ------ Run configuration ------

# Model files to evaluate; each one gets its own set of relaxations and JSON data.
models_to_analyse = [
    "MACE_Models/medium-0b3.pt",
    "MACE_Models/medium-mpa-0.pt",
    "MACE_Models/medium-omat-0.pt",
]

# When True, relaxations are redone even if a trajectory already exists.
set_OVERWRITE = False

# Output location of the per-model JSON raw data.
crystalline_analysis_dir = Path("Analysis").joinpath("Crystalline Analysis")
set_energetics_data_dir = crystalline_analysis_dir.joinpath("Raw Data")
crystalline_analysis_dir.mkdir(parents=True, exist_ok=True)
set_energetics_data_dir.mkdir(exist_ok=True)

# energetics_calculator writes its results to disk and prints a summary.
# It has no return statement, so formation_energy ends up as None.
formation_energy = energetics_calculator(
    in_dir="Carbon_Structures/Crystalline/Downloaded",
    relaxed_dir="Carbon_Structures/Crystalline/Relaxed",
    path_to_reference_structure="Carbon_Structures/Graphite_mp169.cif",
    path_to_models=models_to_analyse,
    energetics_data_dir=set_energetics_data_dir,
    fmax=0.02,
    steps=200,
    OVERWRITE=set_OVERWRITE,
)


In [None]:
# ------  GRAPHICAL ANALYSIS OF RAW DATA RESULTS ------
