In [1]:
from pathlib import Path
from atomate2.jdftx.io.JDFTXOutfile import JDFTXOutfile
from atomate2.jdftx.io.JDFTXInfile import JDFTXInfile
from pytest import approx
import pytest
from pymatgen.util.typing import PathLike
from pymatgen.core.units import Ha_to_eV
import os



A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.1.0 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/Users/richb/anaconda3/envs/dev3_atomate2/lib/python3.12/site-packages/ipykernel_launcher.py", line 18, in <module>
    app.launch_new_instance()
  File "/Users/richb/anaconda3/envs/dev3_atomate2/lib/python3.12/site-packages/traitlets/config/application.py", line 1075, in launch_instance
    app.start()
  File "/Users/richb/anaconda3/envs/dev3_atomate2/lib/python3.12/site-packages/ipykernel/kernelapp.py", line 7

In [63]:
import os
from functools import wraps
import math
from ase import Atom, Atoms
from atomate2.jdftx.io.JMinSettings import JMinSettings, JMinSettingsElectronic, JMinSettingsFluid, JMinSettingsIonic, JMinSettingsLattice
import numpy as np
from dataclasses import dataclass, field
import scipy.constants as const
from atomate2.jdftx.io.data import atom_valence_electrons
from atomate2.jdftx.io.JOutStructures import JOutStructures
from pymatgen.core import Structure
from pymatgen.core.trajectory import Trajectory
from typing import List, Optional
from pymatgen.core.units import Ha_to_eV, ang_to_bohr, bohr_to_ang


class ClassPrintFormatter():
    def __str__(self) -> str:
        '''
        Generic, readable rendering of an instance for the command line:
        the class repr on the first line, followed by one "name = value"
        line per instance attribute, ordered by attribute name
        '''
        state = self.__dict__
        attr_lines = []
        for name in sorted(state):
            attr_lines.append(' = '.join((str(name), str(state[name]))))
        return str(self.__class__) + '\n' + '\n'.join(attr_lines)


def get_start_lines(text: list[str], start_key: Optional[str]="*************** JDFTx", add_end: Optional[bool]=False) -> list[int]:
    '''
    Get the line numbers corresponding to the beginning of separate JDFTx calculations
    (in case of multiple calculations appending the same out file)

    Args:
        text: output of read_file for out file
        start_key: substring that marks the first line of a calculation
        add_end: if True, the index of the last line of text is appended
            after the start lines (nothing is appended for an empty text)

    Returns:
        indices of the lines containing start_key, in ascending order,
        optionally followed by the index of the last line
    '''
    start_lines = []
    # Track the last index explicitly instead of relying on the loop variable
    # leaking out of the loop: for an empty text it would never be bound.
    last_index = None
    for last_index, line in enumerate(text):
        if start_key in line:
            start_lines.append(last_index)
    if add_end and last_index is not None:
        start_lines.append(last_index)
    return start_lines

def find_key_first(key_input, tempfile):
    '''
    Locate the earliest line of the output file that contains the key.

    Parameters
    ----------
    key_input: str
        key string to match (converted with str() before matching)
    tempfile: List[str]
        output from readlines() function in read_file method

    Returns
    -------
    line: int | None
        index of the first line containing the key, or None if no line does
    '''
    needle = str(key_input)
    for idx, content in enumerate(tempfile):
        if needle in content:
            return idx
    return None


def find_key(key_input, tempfile):
    '''
    Locate the latest line of the output file that contains the key.

    Parameters
    ----------
    key_input: str
        key string to match (converted with str() before matching)
    tempfile: List[str]
        output from readlines() function in read_file method

    Returns
    -------
    line: int | None
        index of the last line containing the key, or None if no line does
    '''
    hits = find_all_key(str(key_input), tempfile)
    return hits[-1] if hits else None


def find_first_range_key(key_input: str, tempfile: list[str], startline: int=0, endline: int=-1, skip_pound:bool = False) -> list[int]:
    '''
    Find all lines that begin exactly with key_input within a range of lines

    Parameters
    ----------
    key_input: str
        key string to match (converted with str() before matching)
    tempfile: List[str]
        output from readlines() function in read_file method
    startline: int
        line to start searching from (inclusive)
    endline: int
        line to stop searching at (exclusive); -1, or any value past the end
        of tempfile, means "search through the last line"
    skip_pound: bool
        whether to ignore leading whitespace and pound signs before matching

    Returns
    -------
    L: list[int]
        list of line numbers where key_input occurs. When there is no match
        the list holds the single sentinel value len(tempfile), which is not
        a valid index into tempfile.
    '''
    key_input = str(key_input)
    nlines = len(tempfile)
    # Clamp the upper bound so an oversized endline cannot index past the file
    if endline == -1 or endline > nlines:
        endline = nlines
    L = []
    for i in range(startline, endline):
        line = tempfile[i]
        if skip_pound:
            # Alternately drop leading whitespace and '#' characters until the
            # line stops changing, so any nesting depth of "# # # ..." is handled
            while True:
                stripped = line.lstrip().lstrip('#')
                if stripped == line:
                    break
                line = stripped
        if line.startswith(key_input):
            L.append(i)
    if not L:
        L = [nlines]
    return L

def key_exists(key_input, tempfile):
    '''
    Report whether any line of the output file contains the key.

    Parameters
    ----------
    key_input: str
        key string to match
    tempfile: List[str]
        output from readlines() function in read_file method

    Returns
    -------
    bool
        True if find_key locates the key, False otherwise
    '''
    return find_key(key_input, tempfile) is not None

def find_all_key(key_input, tempfile, startline = 0):
    '''
    Find every line, from startline onward, that contains the key anywhere
    in the line.

    NOTE: an older comment marked this function as deprecated in favour of
    find_first_range_key; a later comment (Ben) disputes that, because this
    function does not require the key to sit at the beginning of the line.

    Parameters
    ----------
    key_input: str
        key string to match (converted with str() before matching)
    tempfile: List[str]
        output from readlines() function in read_file method
    startline: int
        first line index to inspect

    Returns
    -------
    L: list[int]
        indices of the matching lines in ascending order (empty if none)
    '''
    needle = str(key_input)
    return [idx for idx in range(startline, len(tempfile)) if needle in tempfile[idx]]

@dataclass
class JDFTXOutfileSlice(ClassPrintFormatter):
    '''
    A class to read and process a JDFTx out file

    Attributes:
        see JDFTx documentation for tag info and typing
    '''

    prefix: str = None

    jstrucs: JOutStructures = None
    jsettings_fluid: JMinSettingsFluid = None
    jsettings_electronic: JMinSettingsElectronic = None
    jsettings_lattice: JMinSettingsLattice = None
    jsettings_ionic: JMinSettingsIonic = None

    xc_func: str = None

    lattice_initial: list[list[float]] = None
    lattice_final: list[list[float]] = None
    lattice: list[list[float]] = None
    a: float = None
    b: float = None
    c: float = None

    fftgrid: list[int] = None
    geom_opt: bool = None
    geom_opt_type: str = None

    # grouping fields related to electronic parameters.
    # Used by the get_electronic_output() method
    _electronic_output = [ 
    "EFermi", "Egap", "Emin", "Emax", "HOMO",
    "LUMO", "HOMO_filling", "LUMO_filling", "is_metal"
    ]
    EFermi: float = None
    Egap: float = None
    Emin: float = None
    Emax: float = None
    HOMO: float = None
    LUMO: float = None
    HOMO_filling: float = None
    LUMO_filling: float = None
    is_metal: bool = None

    broadening_type: str = None
    broadening: float = None
    kgrid: list = None
    truncation_type: str = None
    truncation_radius: float = None
    pwcut: float = None
    rhocut: float = None

    pp_type: str = None
    total_electrons: float = None
    semicore_electrons: int = None
    valence_electrons: float = None
    total_electrons_uncharged: int = None
    semicore_electrons_uncharged: int = None
    valence_electrons_uncharged: int = None
    Nbands: int = None

    atom_elements: list = None
    atom_elements_int: list = None
    atom_types: list = None
    spintype: str = None
    Nspin: int = None
    Nat: int = None
    atom_coords_initial: list[list[float]] = None
    atom_coords_final: list[list[float]] = None
    atom_coords: list[list[float]] = None

    has_solvation: bool = False
    fluid: str = None

    # #@ Cooper added @#
    # Ecomponents: dict = field(default_factory=dict)
    # is_gc: bool = False # is it a grand canonical calculation
    # trajectory_positions: list[list[list[float]]] = None
    # trajectory_lattice: list[list[list[float]]] = None
    # trajectory_forces: list[list[list[float]]] = None
    # trajectory_ecomponents: list[dict] = None
    # # is_converged: bool = None #TODO implement this

    @property
    def t_s(self) -> float:
        '''
        Returns the total time in seconds for the calculation

        Returns:
        -------
        t_s: float
            The total time in seconds for the calculation
        '''
        t_s = None
        if self.jstrucs:
            t_s = self.jstrucs.t_s
        return t_s
    

    @property
    def is_converged(self) -> bool:
        '''
        Returns True if the electronic and geometric optimization have converged
        (or only the former if a single-point calculation)
        '''
        converged = self.jstrucs.elec_converged
        if self.geom_opt:
            converged = converged and self.jstrucs.geom_converged
        return converged


    @property
    def trajectory(self) -> Trajectory:
        '''
        Build a pymatgen Trajectory from the structures in self.jstrucs.
        The lattice is flagged as constant when the lattice-minimize settings
        report zero iterations.
        '''
        lattice_is_fixed = self.jsettings_lattice.nIterations == 0
        return Trajectory.from_structures(
            structures=self.jstrucs,
            constant_lattice=lattice_is_fixed,
        )

    
    @property
    def electronic_output(self) -> dict:
        '''
        Return a dictionary with all relevant electronic information.
        Returns values corresponding to these keys in _electronic_output
        field.
        '''
        dct = {}
        for field in self.__dataclass_fields__:
            if field in self._electronic_output:
                value = getattr(self, field)
                dct[field] = value
        return dct
    
    
    @property
    def structure(self) -> Structure:
        structure = self.jstrucs[-1]
        return structure
    
    
    @classmethod
    def from_out_slice(cls, text: list[str]) -> "JDFTXOutfileSlice":
        '''
        Read slice of out file into a JDFTXOutfileSlice instance

        The parsing steps below are order dependent: several of them read
        attributes that an earlier step stored on the instance.

        Parameters:
        ----------
        text: list[str]
            lines of the out file slice to read

        Returns:
        -------
        instance: JDFTXOutfileSlice
            new instance with the parsed attributes set
        '''
        instance = cls()

        # Minimizer settings come first: _set_geomopt_vars reads
        # jsettings_ionic / jsettings_lattice, and _set_jstrucs reads the
        # geom_opt_type that _set_geomopt_vars stores.
        instance._set_min_settings(text)
        instance._set_geomopt_vars(text)
        instance._set_jstrucs(text)
        # prefix must be known before _set_eigvars, which passes it on to
        # locate the eigStats dump header.
        instance.prefix = instance._get_prefix(text)
        spintype, Nspin = instance._get_spinvars(text)
        instance.xc_func =instance._get_xc_func(text)
        instance.spintype = spintype
        instance.Nspin = Nspin
        broadening_type, broadening = instance._get_broadeningvars(text)
        instance.broadening_type = broadening_type
        instance.broadening = broadening
        instance.kgrid = instance._get_kgrid(text)
        truncation_type, truncation_radius = instance._get_truncationvars(text)
        instance.truncation_type = truncation_type
        instance.truncation_radius = truncation_radius
        instance.pwcut = instance._get_pw_cutoff(text)
        instance.rhocut = instance._get_rho_cutoff(text)
        instance.fftgrid = instance._get_fftgrid(text)
        # Fillings need Nspin, the broadening values and HOMO/LUMO/EFermi;
        # the metal check in turn needs the fillings.
        instance._set_eigvars(text)
        instance._set_orb_fillings()
        instance.is_metal = instance._determine_is_metal()
        instance._set_fluid(text)
        # total_electrons (read from jstrucs) and the atom lists must be set
        # before _set_pseudo_vars, whose SG15 branch reads all of them.
        instance._set_total_electrons(text)
        instance._set_Nbands(text)
        instance._set_atom_vars(text)
        instance._set_pseudo_vars(text)
        instance._set_lattice_vars(text)
        # check_solvation only inspects the fluid attribute set above
        instance.has_solvation = instance.check_solvation()

        

        #@ Cooper added @#
        # NOTE(review): is_gc and Ecomponents are not declared dataclass
        # fields in the visible class body (their declarations are commented
        # out), so they are created here as plain instance attributes.
        instance.is_gc = key_exists('target-mu', text)
        instance._set_ecomponents(text)
        # instance._build_trajectory(templines)

        return instance
    

    def _get_xc_func(self, text: list[str]) -> str:
        '''
        Read the exchange-correlation functional used in the calculation

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        -------
        xc_func: str
            last whitespace-separated token of the last line that
            contains 'elec-ex-corr'
        '''
        idx = find_key('elec-ex-corr', text)
        tokens = text[idx].split()
        return tokens[-1]
        
    

    def _get_prefix(self, text: list[str]) -> str:
        '''
        Read the output prefix from the last 'dump-name' line of the out file

        Parameters
        ----------
            text: list[str]
                output of read_file for out file

        Returns
        -------
            prefix: str
                the part of the dump name before its first '.', or None
                when the dump name contains no '.'
        '''
        idx = find_key('dump-name', text)
        dumpname = text[idx].split()[1]
        head, dot, _tail = dumpname.partition('.')
        return head if dot else None
    
    def _get_spinvars(self, text: list[str]) -> tuple[str, int]:
        '''
        Read the spin treatment from the last 'spintype ' line of the out file

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        -------
        spintype: str
            type of spin in calculation (None for 'no-spin')
        Nspin: int
            number of spin types in calculation

        Raises
        ------
        NotImplementedError
            for any spin type other than 'no-spin' and 'z-spin'
        '''
        idx = find_key('spintype ', text)
        label = text[idx].split()[1]
        # recognised spin labels -> (reported spintype, number of spin channels)
        known = {'no-spin': (None, 1), 'z-spin': ('z-spin', 2)}
        if label not in known:
            raise NotImplementedError('have not considered this spin yet')
        return known[label]
    
    def _get_broadeningvars(self, text:list[str]) -> tuple[str, float]:
        '''
        Read the smearing type and width from the last 'elec-smearing ' line

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        -------
        broadening_type: str
            type of electronic smearing (None if no smearing line is present)
        broadening: float
            parameter for electronic smearing (0 if no smearing line is present)
        '''
        idx = find_key('elec-smearing ', text)
        if idx is None:
            return None, 0
        tokens = text[idx].split()
        return tokens[1], float(tokens[2])
    
    def _get_truncationvars(self, text:list[str]) -> tuple[str, float]:
        '''
        Get truncation type and value from out file text

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        -------
        truncation_type: str | None
            type of coulomb truncation (None for 'Periodic' or when no
            'coulomb-interaction' line is present)
        truncation_radius: float | None
            radius of truncation (only set if truncation_type is spherical);
            the value read from the file is divided by ang_to_bohr

        Raises
        ------
        ValueError
            if the truncation type is not one of the known types, or if a
            slab / wire truncation is not along the '001' direction
        '''
        maptypes = {'Periodic': None, 'Slab': 'slab', 'Cylindrical': 'wire', 'Wire': 'wire',
                    'Spherical': 'spherical', 'Isolated': 'box'}
        line = find_key('coulomb-interaction', text)
        if line is None:
            return None, None
        tokens = text[line].split()
        # Unknown types map to the 'error' sentinel so that the ValueError
        # below is actually reachable (a plain maptypes[...] lookup raised
        # KeyError before the check could ever run).
        truncation_type = maptypes.get(tokens[1], 'error')
        if truncation_type == 'error':
            raise ValueError('Problem with this truncation!')
        # the direction is only read when the line has exactly three tokens
        direc = tokens[2] if len(tokens) == 3 else None
        if truncation_type == 'slab' and direc != '001':
            raise ValueError('BGW slab Coulomb truncation must be along z!')
        if truncation_type == 'wire' and direc != '001':
            raise ValueError('BGW wire Coulomb truncation must be periodic in z!')
        truncation_radius = None
        if truncation_type == 'spherical':
            line = find_key('Initialized spherical truncation of radius', text)
            truncation_radius = float(text[line].split()[5]) / ang_to_bohr
        return truncation_type, truncation_radius
    
    
    def _get_pw_cutoff(self, text:list[str]) -> float:
        '''
        Read the electron (plane-wave) cutoff from the last 'elec-cutoff ' line

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        -------
        pwcut: float
            second token of the line, multiplied by Ha_to_eV
        '''
        idx = find_key('elec-cutoff ', text)
        raw_cutoff = text[idx].split()[1]
        return float(raw_cutoff) * Ha_to_eV
    
    
    def _get_rho_cutoff(self, text:list[str]) -> float:
        '''
        Read the charge-density cutoff from the last 'elec-cutoff ' line

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        -------
        rhocut: float
            third token of the line multiplied by Ha_to_eV when the line has
            exactly three tokens; otherwise four times the plane-wave cutoff
            (self.pwcut, or a freshly parsed value if that is still None)
        '''
        idx = find_key('elec-cutoff ', text)
        tokens = text[idx].split()
        if len(tokens) == 3:
            return float(tokens[2]) * Ha_to_eV
        # no explicit density cutoff on the line: fall back to 4 x pwcut
        pwcut = self._get_pw_cutoff(text) if self.pwcut is None else self.pwcut
        return float(pwcut * 4)
    

    def _get_fftgrid(self, text:list[str]) -> list[int]:
        '''
        Read the FFT grid from the first 'Chosen fftbox size' line

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        -------
        fftgrid: list[int]
            tokens 6 through 8 of that line, converted to int
        '''
        idx = find_key_first('Chosen fftbox size', text)
        tokens = text[idx].split()
        return list(map(int, tokens[6:9]))
    

    def _get_kgrid(self, text:list[str]) -> list[int]:
        '''
        Read the kpoint grid from the last 'kpoint-folding ' line

        Parameters
        ----------
            text: list[str]
                output of read_file for out file

        Returns
        -------
            kgrid: list[int]
                tokens 1 through 3 of that line, converted to int
        '''
        idx = find_key('kpoint-folding ', text)
        tokens = text[idx].split()
        return list(map(int, tokens[1:4]))
    
    
    def _get_eigstats_varsdict(self, text:list[str], prefix:str | None) -> dict[str, float]:
        '''
        Read the eigenvalue statistics that follow the last eigStats dump
        header in the out file text

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        prefix: str
            prefix for the eigStats section in the out file (None for no prefix)

        Returns
        -------
        varsdict: dict[str, float]
            dictionary of eigenvalue statistics (Emin, HOMO, EFermi, LUMO,
            Emax, Egap), each multiplied by Ha_to_eV

        Raises
        ------
        ValueError
            if the eigStats dump header is not found
        '''
        tag = "" if prefix is None else f"{prefix}."
        header = find_key(f"Dumping '{tag}eigStats' ...", text)
        if header is None:
            raise ValueError('Must run DFT job with "dump End EigStats" to get summary gap information!')
        # (name, line offset below the header, token index holding the value)
        layout = (
            ("Emin", 1, 1),
            ("HOMO", 2, 1),
            ("EFermi", 3, 2),
            ("LUMO", 4, 1),
            ("Emax", 5, 1),
            ("Egap", 6, 2),
        )
        return {
            name: float(text[header + offset].split()[column]) * Ha_to_eV
            for name, offset, column in layout
        }
    
    
    def _set_eigvars(self, text:list[str]) -> None:
        '''
        Set the eigenvalue statistics variables
        
        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        eigstats = self._get_eigstats_varsdict(text, self.prefix)
        self.Emin = eigstats["Emin"]
        self.HOMO = eigstats["HOMO"]
        self.EFermi = eigstats["EFermi"]
        self.LUMO = eigstats["LUMO"]
        self.Emax = eigstats["Emax"]
        self.Egap = eigstats["Egap"]
    

    def _get_pp_type(self, text:list[str]) -> str:
        '''
        Determine the pseudopotential library from the file name on the
        last 'Reading pseudopotential file' line

        Parameters
        ----------
        text: list[str]
            output of read_file for out file

        Returns
        ----------
        pptype: str
            Pseudopotential library used ("GBRV" or "SG15"); when both names
            occur in the file name, the one occurring later wins

        Raises
        ------
        ValueError
            if neither library name occurs in the file name
        '''
        skey = "Reading pseudopotential file"
        idx = find_key(skey, text)
        ppfile_example = text[idx].split(skey)[1].split(":")[0].strip("'")
        found = [library for library in ("GBRV", "SG15") if library in ppfile_example]
        if not found:
            raise ValueError(f"Could not determine pseudopotential type from file name {ppfile_example}")
        # prefer the library name whose first occurrence is furthest right
        return max(found, key=ppfile_example.index)
    
    
    def _set_pseudo_vars(self, text:list[str]) -> None:
        '''
        Set the pseudopotential variables   

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        self.pp_type = self._get_pp_type(text)
        if self.pp_type == "SG15":
            self._set_pseudo_vars_SG15(text)
        elif self.pp_type == "GBRV":
            self._set_pseudo_vars_GBRV(text)
    
    
    def _set_pseudo_vars_SG15(self, text:list[str]) -> None:
        '''
        Set the electron-count variables for SG15 pseudopotentials.
        Reads self.atom_types, self.atom_elements and self.total_electrons,
        so those must already be set.

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        begin = find_key('---------- Setting up pseudopotentials ----------', text)
        stop = find_first_range_key('Initialized ', text, startline = begin)[0]
        # 'valence electrons' lines strictly inside the pseudopotential section
        elec_lines = [idx for idx in find_all_key('valence electrons', text) if begin < idx < stop]
        per_line_totals = [int(float(text[idx].split()[0])) for idx in elec_lines]
        totals_by_type = dict(zip(self.atom_types, per_line_totals))
        total_per_atom = np.array([totals_by_type[element] for element in self.atom_elements])
        valence_per_atom = np.array([atom_valence_electrons[element] for element in self.atom_elements])
        semicore_per_atom = total_per_atom - valence_per_atom
        self.total_electrons_uncharged = total_per_atom.sum()
        self.valence_electrons_uncharged = valence_per_atom.sum()
        self.semicore_electrons_uncharged = semicore_per_atom.sum()
        self.semicore_electrons = self.semicore_electrons_uncharged
        # valence count follows the actual total, so it accounts for a charged system
        self.valence_electrons = self.total_electrons - self.semicore_electrons


    def _set_pseudo_vars_GBRV(self, text:list[str]) -> None:
        ''' TODO: implement this method
        '''
        self.total_electrons_uncharged = None
        self.valence_electrons_uncharged = None
        self.semicore_electrons_uncharged = None
        self.semicore_electrons = None
        self.valence_electrons = None


    def _collect_settings_lines(self, text:list[str], start_flag:str) -> list[int]:
        '''
        Collect the lines of settings from the out file text

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        start_flag: str
            key to start collecting settings lines

        Returns
        -------
        lines: list[int]
            list of line numbers where settings occur
        '''
        started = False
        lines = []
        for i, line in enumerate(text):
            if started:
                if line.strip().split()[-1].strip() == "\\":
                    lines.append(i)
                else:
                    started = False
            elif start_flag in line:
                started = True
                #lines.append(i) # we DONT want to do this
            elif len(lines):
                break
        return lines
    

    def _create_settings_dict(self, text:list[str], start_flag:str) -> dict:
        '''
        Create a dictionary of settings from the out file text

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        start_flag: str
            key to start collecting settings lines

        Returns
        -------
        settings_dict: dict
            dictionary of settings
        '''
        lines = self._collect_settings_lines(text, start_flag)
        settings_dict = {}
        for line in lines:
            line_text_list = text[line].strip().split()
            key = line_text_list[0]
            value = line_text_list[1]
            settings_dict[key] = value
        return settings_dict
    
    
    def _get_settings_object(self, text:list[str], settings_class: JMinSettings) -> JMinSettings:
        '''
        Get the settings object from the out file text
        
        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        settings_class: JMinSettings
            settings class to create object from
            
        Returns
        -------
        settings_obj: JMinSettings
            settings object
        '''
        settings_dict = self._create_settings_dict(text, settings_class.start_flag)
        if len(settings_dict):
            settings_obj = settings_class(**settings_dict)
        else:
            settings_obj = None
        return settings_obj
    

    def _set_min_settings(self, text:list[str]) -> None:
        '''
        Parse and store the four minimizer settings objects
        (fluid, electronic, lattice, ionic) from the out file text

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        targets = (
            ("jsettings_fluid", JMinSettingsFluid),
            ("jsettings_electronic", JMinSettingsElectronic),
            ("jsettings_lattice", JMinSettingsLattice),
            ("jsettings_ionic", JMinSettingsIonic),
        )
        for attr_name, settings_class in targets:
            setattr(self, attr_name, self._get_settings_object(text, settings_class))
    

    def _set_geomopt_vars(self, text:list[str]) -> None:
        ''' 
        Set vars geom_opt and geom_opt_type for initializing self.jstrucs

        Parameters
        ----------
            text: list[str]
                output of read_file for out file
        '''
        if self.jsettings_ionic is None or self.jsettings_lattice is None:
            self._set_min_settings(text)
        #
        if self.jsettings_ionic is None or self.jsettings_lattice is None:
            raise ValueError("Unknown issue in setting settings objects")
        else:
            if self.jsettings_lattice.nIterations > 0:
                self.geom_opt = True
                self.geom_opt_type = "lattice"
            elif self.jsettings_ionic.nIterations > 0:
                self.geom_opt = True
                self.geom_opt_type = "ionic"
            else:
                self.geom_opt = False
                self.geom_opt_type = "single point"


    def _set_jstrucs(self, text:list[str]) -> None:
        '''
        Build the JOutStructures object from the out file text and store it
        in self.jstrucs; self.geom_opt_type is passed on as the iteration type

        Parameters
        ----------
            text: list[str]
                output of read_file for out file
        '''
        iter_type = self.geom_opt_type
        self.jstrucs = JOutStructures.from_out_slice(text, iter_type=iter_type)


    def _set_orb_fillings(self) -> None:
        '''
        Calculate and set HOMO and LUMO fillings
        '''
        if self.broadening_type is not None:
            self.HOMO_filling = (2 / self.Nspin) * self.calculate_filling(self.broadening_type, self.broadening, self.HOMO, self.EFermi)
            self.LUMO_filling = (2 / self.Nspin) * self.calculate_filling(self.broadening_type, self.broadening, self.LUMO, self.EFermi)
        else:
            self.HOMO_filling = (2 / self.Nspin)
            self.LUMO_filling = 0


    def _set_fluid(self, text: list[str]) -> None: # Is this redundant to the fluid settings?
        '''
        Set the fluid attribute from the first line that starts with 'fluid ';
        the literal value 'None' is stored as None

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        first_hit = find_first_range_key('fluid ', text)[0]
        fluid_name = text[first_hit].split()[1]
        self.fluid = None if fluid_name == 'None' else fluid_name


    def _set_total_electrons(self, text:str) -> None:
        '''
        Set the total_Electrons class variable

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        total_electrons = self.jstrucs[-1].elecMinData[-1].nElectrons
        self.total_electrons = total_electrons
        # lines = find_all_key('nElectrons', text)
        # if len(lines) > 1:
        #     idx = 4
        # else:
        #     idx = 1  #nElectrons was not printed in scf iterations then
        # self.total_electrons = float(text[lines[-1]].split()[idx])
    

    def _set_Nbands(self, text: list[str]) -> None:
        '''
        Set the Nbands class variable from the last token of the first
        line that contains 'elec-n-bands'

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        first_hit = find_all_key('elec-n-bands', text)[0]
        self.Nbands = int(text[first_hit].split()[-1])
    

    def _set_atom_vars(self, text: list[str]) -> None:
        '''
        Set the atom variables: element per atom, number of atoms, unique
        atom types (in order of first appearance), a 1-based integer type
        index per atom, and the initial and final coordinates

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        startline = find_key('Input parsed successfully', text)
        endline = find_key('---------- Initializing the Grid ----------', text)
        # 'ion ' lines between the end of input parsing and grid initialization
        ion_lines = find_first_range_key('ion ', text, startline = startline, endline = endline)
        atom_elements = [text[idx].split()[1] for idx in ion_lines]
        self.Nat = len(atom_elements)
        initial_coords = [text[idx].split()[2:5] for idx in ion_lines]
        self.atom_coords_initial = np.array(initial_coords, dtype = float)
        # dict.fromkeys keeps first-appearance order while dropping repeats
        atom_types = list(dict.fromkeys(atom_elements))
        self.atom_elements = atom_elements
        type_index = {atom_type: number for number, atom_type in enumerate(atom_types, start=1)}
        self.atom_elements_int = [type_index[element] for element in self.atom_elements]
        self.atom_types = atom_types
        # final coordinates: the Nat lines right below the last positions header
        first_coord_line = find_key('# Ionic positions in', text) + 1
        final_rows = [text[idx].split()[2:5] for idx in range(first_coord_line, first_coord_line + self.Nat)]
        self.atom_coords_final = np.array(final_rows, dtype = float)
        self.atom_coords = self.atom_coords_final.copy()
    

    def _set_lattice_vars(self, text: list[str]) -> None:
        '''
        Set the lattice variables
        
        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        self.lattice_initial = self.jstrucs[0].lattice.matrix
        self.lattice_final = self.jstrucs[-1].lattice.matrix
        self.lattice = self.lattice_final.copy()
        # This block was throwing errors
        # lines = find_all_key('R =', text)
        # line = lines[0]
        # lattice_initial = np.array([x.split()[1:4] for x in text[(line + 1):(line + 4)]], dtype = float).T / ang_to_bohr
        # self.lattice_initial = lattice_initial.copy()
        # templines = find_all_key('LatticeMinimize', text)
        # if len(templines) > 0:
        #     line = templines[-1]
        #     lattice_final = np.array([x.split()[1:4] for x in text[(line + 1):(line + 4)]], dtype = float).T / ang_to_bohr
        #     self.lattice_final = lattice_final.copy()
        #     self.lattice = lattice_final.copy()
        # else:
        #     self.lattice = lattice_initial.copy()
        self.a, self.b, self.c = np.sum(self.lattice**2, axis = 1)**0.5


    def _set_ecomponents(self, text: list[str]) -> None:
        '''
        Set the energy components dictionary (self.Ecomponents) from the
        block that starts at the last '# Energy components:' line

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        header_idx = find_key("# Energy components:", text)
        self.Ecomponents = self._read_ecomponents(header_idx, text)


    def calculate_filling(self, broadening_type, broadening, eig, EFermi):
        '''
        Fractional occupation of a state under the given smearing scheme

        Parameters
        ----------
        broadening_type: str
            smearing scheme: 'Fermi', 'Gauss', 'MP1' or 'Cold'
        broadening: float
            smearing width
        eig: float
            eigenvalue of the state
        EFermi: float
            Fermi level

        Returns
        -------
        filling: float
            occupation of the state for the requested scheme

        Raises
        ------
        NotImplementedError
            for any other broadening type
        '''
        # most broadening implementations do not have the denominator factor of 2,
        # but JDFTx does currently; remove it if this is reused for other codes
        x = (eig - EFermi) / (2.0 * broadening)
        if broadening_type == 'Fermi':
            return 0.5 * (1 - np.tanh(x))
        if broadening_type == 'Gauss':
            return 0.5 * (1 - math.erf(x))
        if broadening_type == 'MP1':
            return 0.5 * (1 - math.erf(x)) - x * np.exp(-1 * x**2) / (2 * np.pi**0.5)
        if broadening_type == 'Cold':
            shifted = x + 0.5**0.5
            return 0.5* (1 - math.erf(shifted)) + np.exp(-1 * shifted**2) / (2 * np.pi)**0.5
        raise NotImplementedError('Have not added other broadening types')
    

    def _determine_is_metal(self) -> bool:
        '''
        Determine if the system is a metal based on the fillings of HOMO and LUMO

        Returns
        --------
        is_metal: bool
            True if system is metallic
        '''
        TOL_PARTIAL = 0.01
        is_metal = True
        if self.HOMO_filling / (2 / self.Nspin) > (1 - TOL_PARTIAL) and self.LUMO_filling / (2 / self.Nspin) < TOL_PARTIAL:
            is_metal = False
        return is_metal
    

    def check_solvation(self) -> bool:
        '''
        Report whether the calculation used implicit solvation

        The decision is based solely on whether self.fluid is set.

        Returns
        --------
        has_solvation: bool
            True if self.fluid is not None
        '''
        return self.fluid is not None
    

    def write(self):
        '''
        Writing is intentionally unsupported for JDFTx out files

        Raises
        ------
        NotImplementedError
            always; an out file is only ever read, never written
        '''
        #don't need a write method since will never do that
        raise NotImplementedError('There is no need to write a JDFTx out file')
    

    def _build_trajectory(self, text):
        '''
        Builds the trajectory lists and sets the instance attributes.

        For every optimization step found in the out file the ionic
        positions, forces, energy components and lattice are collected and
        stored as self.trajectory_positions, self.trajectory_forces,
        self.trajectory_ecomponents and self.trajectory_lattice.

        Parameters
        ----------
        text: list[str]
            output of read_file for out file
        '''
        # Needs to handle LatticeMinimize and IonicMinimize steps in one run
        # can do this by checking if lattice vectors block is present and if
        # so adding it to the lists. If it isn't present, copy the last 
        # lattice from the list.
        # initialize lattice list with starting lattice and remove it
        # from the list after iterating through all the optimization steps
        trajectory_positions = []
        trajectory_lattice = [self.lattice_initial]
        trajectory_forces = []
        trajectory_ecomponents = []

        ion_lines = find_first_range_key('# Ionic positions in', text)
        force_lines = find_first_range_key('# Forces in', text)
        ecomp_lines = find_first_range_key('# Energy components:', text)
        # enumerate() yields (index, item) pairs, so the zipped triple must be
        # unpacked as a nested tuple
        for iline, (ion_line, force_line, ecomp_line) in enumerate(zip(ion_lines, force_lines, ecomp_lines)):
            # the self.Nat lines following each header hold one atom per line,
            # with the three vector components in columns 2-4
            coords = np.array([text[i].split()[2:5] for i in range(ion_line + 1, ion_line + self.Nat + 1)], dtype = float)
            forces = np.array([text[i].split()[2:5] for i in range(force_line + 1, force_line + self.Nat + 1)], dtype = float)
            ecomp = self._read_ecomponents(ecomp_line, text)
            # NOTE(review): endline is the *previous* ionic-positions line, and
            # for iline == 0 the index -1 wraps around to the last one; the
            # search window looks reversed -- confirm against
            # find_first_range_key before relying on the lattice trajectory
            lattice_lines = find_first_range_key('# Lattice vectors:', text, startline=ion_line, endline=ion_lines[iline-1])
            if len(lattice_lines) == 0: # if no lattice lines found, append last lattice
                trajectory_lattice.append(trajectory_lattice[-1])
            else:
                line = lattice_lines[0]
                trajectory_lattice.append(np.array([x.split()[1:4] for x in text[(line + 1):(line + 4)]], dtype = float).T / ang_to_bohr)
            trajectory_positions.append(coords)
            trajectory_forces.append(forces)
            trajectory_ecomponents.append(ecomp)
        trajectory_lattice = trajectory_lattice[1:] # remove starting lattice

        self.trajectory_positions = trajectory_positions
        self.trajectory_lattice = trajectory_lattice
        self.trajectory_forces = trajectory_forces
        self.trajectory_ecomponents = trajectory_ecomponents
    

    def _read_ecomponents(self, line: int, text: list[str]) -> dict:
        '''
        Read the energy components from the out file text

        Parsing starts on the line after `line` and stops once the final
        energy entry ("G" if self.is_gc else "F") has been read, or at the
        first blank line, whichever comes first. Lines starting with "--"
        are treated as separators and skipped. Each value is multiplied by
        Ha_to_eV before being stored.

        Parameters
        ----------
        line: int
            line number where energy components are found
        text: list[str]
            output of read_file for out file

        Returns
        -------
        Ecomponents: dict
            dictionary of energy components; always a dict, holding whatever
            was collected if the final entry is never reached
        '''
        Ecomponents = {}
        final_E_type = "G" if self.is_gc else "F"
        for tmp_line in text[line+1:]:
            chars = tmp_line.strip().split()
            if not chars:
                # a blank line ends the block; indexing chars[0] would fail here
                break
            if tmp_line.startswith("--"):
                continue
            E_type = chars[0]
            Energy = float(chars[-1]) * Ha_to_eV
            Ecomponents[E_type] = Energy
            if E_type == final_E_type:
                break
        return Ecomponents
        

    

    def to_dict(self) -> dict:
        '''
        Convert the dataclass instance to a dictionary representation

        Returns
        -------
        dct: dict
            maps every name in self.__dataclass_fields__ to the current
            value of that attribute (values are not copied)
        '''
        return {name: getattr(self, name) for name in self.__dataclass_fields__}


In [66]:
import math
import os
from dataclasses import dataclass
from functools import wraps
from dataclasses import dataclass
from typing import List, Optional


class ClassPrintFormatter():
    """Mixin that gives a class a readable, attribute-by-attribute str()."""

    def __str__(self) -> str:
        """Return the class repr followed by one 'name = value' line per instance attribute, sorted by name"""
        attributes = self.__dict__
        attribute_lines = []
        for attr_name in sorted(attributes):
            attribute_lines.append(str(attr_name) + " = " + str(attributes[attr_name]))
        header = str(self.__class__)
        return header + "\n" + "\n".join(attribute_lines)


def check_file_exists(func):
    """Check if file exists (and continue normally) or raise an exception if it does not

    Decorator for functions taking a single file name argument (str or
    path-like). The wrapped function is only called when the file exists.

    Raises
    ------
    OSError
        if the given file name does not point to an existing regular file
    """

    @wraps(func)
    def wrapper(filename):
        if not os.path.isfile(filename):
            # f-string formatting also handles pathlib.Path objects, for which
            # plain `str + Path` concatenation raises TypeError
            raise OSError(f"'{filename}' file doesn't exist!")
        return func(filename)

    return wrapper


@check_file_exists
def read_file(file_name: str) -> list[str]:
    '''
    Read a whole file into a list of lines

    Parameters
    ----------
    file_name: Path or str
        name of file to read

    Returns
    -------
    text: list[str]
        the lines of the file, line endings included
    '''
    with open(file_name, 'r') as handle:
        return handle.readlines()


def get_start_lines(text: list[str], start_key: Optional[str]="*************** JDFTx", add_end: Optional[bool]=False) -> list[int]:
    '''
    Get the line numbers corresponding to the beginning of separate JDFTx calculations
    (in case of multiple calculations appending the same out file)

    Args:
        text: output of read_file for out file
        start_key: substring that marks the first line of a calculation
        add_end: if True, additionally append the index of the last line of
            text as an end marker (nothing is appended for empty text)

    Returns:
        indices of the lines containing start_key, in file order, optionally
        followed by the end marker
    '''
    start_lines = []
    # last_idx stays None when text is empty, so the end marker is only added
    # when there is an actual last line to point at
    last_idx = None
    for last_idx, line in enumerate(text):
        if start_key in line:
            start_lines.append(last_idx)
    # NOTE(review): the end marker is the index of the last line, so using it
    # as an exclusive slice end leaves that final line out -- confirm intended
    if add_end and last_idx is not None:
        start_lines.append(last_idx)
    return start_lines

def read_outfile_slices(file_name: str) -> list[list[str]]:
    '''
    Read an out file and split it into one list of lines per JDFTx call

    The file is cut at the boundaries returned by
    get_start_lines(..., add_end=True); each slice runs from one boundary up
    to, but not including, the next one.

    Parameters
    ----------
    file_name: Path or str
        name of file to read

    Returns
    -------
    texts: list[list[str]]
        list of out file slices (individual calls of JDFTx)
    '''
    all_lines = read_file(file_name)
    boundaries = get_start_lines(all_lines, add_end=True)
    return [
        all_lines[begin:end]
        for begin, end in zip(boundaries[:-1], boundaries[1:])
    ]


@dataclass
class JDFTXOutfile(List[JDFTXOutfileSlice], ClassPrintFormatter):
    '''
    A class to read and process a JDFTx out file

    The object is a list of JDFTXOutfileSlice instances, one per slice
    returned by read_outfile_slices. Attribute reads that are not resolved on
    the list itself are forwarded to the last slice.
    '''

    @classmethod
    def from_file(cls, file_path: str):
        '''
        Build an instance from an out file on disk

        Parameters
        ----------
        file_path: Path or str
            path of the out file to read

        Returns
        -------
        instance: JDFTXOutfile
            list-like object holding one JDFTXOutfileSlice per slice of the file
        '''
        texts = read_outfile_slices(file_path)
        instance = cls()
        for text in texts:
            instance.append(JDFTXOutfileSlice.from_out_slice(text))
        return instance

    def __getattr__(self, name):
        '''
        Forward attribute reads to the last slice

        Python only calls __getattr__ after normal attribute lookup has
        failed, so attributes defined on the list itself are unaffected.

        Raises
        ------
        AttributeError
            if there are no slices to forward the lookup to
        '''
        if len(self):
            return getattr(self[-1], name)
        # none of the base classes defines __getattr__, so with no slices
        # available there is nothing left to delegate to
        raise AttributeError(f"'JDFTXOutfile' object has no attribute '{name}'")

    def __setattr__(self, name, value):
        # Do we want this? I don't imagine this class object should be modified
        # NOTE(review): this class body declares no annotated fields, so
        # `self.__annotations__` may not resolve to what is intended here
        # (it can fall through to __getattr__) -- confirm the desired behavior
        if name in self.__annotations__:
            super().__setattr__(name, value)
        elif self:
            setattr(self[-1], name, value)
        else:
            raise AttributeError(f"'JDFTXOutfile' object has no attribute '{name}'")


In [70]:
# Parse example_files/example_sp.out (relative to the working directory) and
# print the iteration type and t_s of its last slice, one per line.
test = JDFTXOutfile.from_file(Path.cwd() / "example_files" / "example_sp.out")
print(test[-1].jstrucs.iter_type, test[-1].t_s, sep="\n")
    

single point
None


In [61]:
from atomate2.jdftx.io.JDFTXInfile_master_format import *

In [62]:
# Tag name kept around while prototyping the in-file parser.
tag_ex = "fluid-anion"

instance = JDFTXInfile()
# Minimal in-file content; the rows ending in a backslash are continuation
# lines (presumably all belonging to the single "lattice" tag -- verify
# against the JDFTXInfile parser).
lines = [
    "fluid-anion Cl- 0.5",
    "lattice \\",
    "1.0 0.0 0.0 \\",
    "1.0 0.0 0.0 \\",
    "1.0 0.0 0.0",
    "latt-move-scale 0 0 0",
    "ion Cl 0.0 0.0 0.0 1",
    "ion-species GBRV_v1.5/$ID_pbe_v1.uspp",
]
# lines = instance._gather_tags(lines)

test = JDFTXInfile.from_str("\n".join(lines))
