In [33]:
from __future__ import annotations
from collections import defaultdict
from monty.serialization import loadfn

from typing import Literal, TYPE_CHECKING
if TYPE_CHECKING:
    from pathlib import Path
    from typing import Sequence

In [16]:
def get_geometry_for_single_dataset(
    dataset: Literal["BEGDB_H2O", "WATER27", "H2O_alkali_clusters", "H2O_halide_clusters"],
    geometries_filename: str | Path = "geometries.json.gz",
    return_type: Literal["pmg", "ase"] = "pmg",
) -> dict:
    """
    Get the geometries for a single dataset.

    Includes charge and spin multiplicity info.

    Args:
        dataset : Literal["BEGDB_H2O", "WATER27", "H2O_alkali_clusters", "H2O_halide_clusters"]
            Name of the dataset.
        geometries_filename : str | Path
            Name of the file containing the geometries. Don't recommend changing this.
        return_type : Literal["pmg", "ase"] = "pmg"
            "pmg" returns the objects exactly as stored in the file (pymatgen
            molecule objects); "ase" returns the result of calling
            ``to_ase_atoms()`` on each of them.
    Returns:
        dict, mapping each system label in the dataset to its geometry, with
        charge and spin info.
    Raises:
        ValueError : if ``return_type`` is neither "pmg" nor "ase".
        KeyError : if ``dataset`` is not a top-level key of the geometries file.
    """
    # Fail fast, before touching the file: previously any unrecognised value
    # (e.g. "ASE") silently fell through and returned pymatgen objects.
    if return_type not in ("pmg", "ase"):
        raise ValueError(
            f'return_type must be "pmg" or "ase", got {return_type!r}'
        )

    all_geometries = loadfn(geometries_filename)
    if dataset not in all_geometries:
        # Still a KeyError, as with plain indexing, but now says what is valid.
        raise KeyError(
            f"Unknown dataset {dataset!r}; available datasets: {list(all_geometries)}"
        )

    geometries = all_geometries[dataset]
    if return_type == "ase":
        return {label: molecule.to_ase_atoms() for label, molecule in geometries.items()}
    return geometries

In [55]:
def get_total_energies_by_dataset(
    dataset: Literal["BEGDB_H2O", "WATER27", "H2O_alkali_clusters", "H2O_halide_clusters"],
    functional: str | Sequence[str] | None = None,
    energies_filename: str | Path = "total_energies.json.gz",
) -> dict:
    """
    Get total energies for a single dataset.

    Optionally filter by functional or a list of functionals.
    Accepts input such as r2SCAN@HF or SCAN-FLOSIC.

    The energies file is structured as:

    ```
    {
        dataset : {
            functional_for_energy : {
                functional_for_density : {
                    molecules in that dataset
                }
            }
        }
    }
    ```

    Thus requesting dataset=WATER27 and functional = r2SCAN@HF corresponds to
        functional_for_energy = r2SCAN
        functional_for_density = HF

    A label without "@" (e.g. r2SCAN) means the same functional is used for
    both the energy and the density.

    In the BEGDB dataset, the "*dmono*" entries correspond to the distorted
    monomers contained within each oligomer.

    Args:
        dataset : Literal["BEGDB_H2O", "WATER27", "H2O_alkali_clusters", "H2O_halide_clusters"]
            Name of the dataset.
        functional : str | Sequence[str] | None = None
            If None, returns all entries in a given dataset.
            If a str, returns a single functional.
            If a Sequence (list, tuple, etc.) of str's, returns that subset of functionals.
        energies_filename : str | Path
            Name of the file containing the total energies. Don't recommend changing this.
    Returns:
        dict, keyed by functional label ("X" or "X@Y"), whose values are the
        energies stored for that label, corresponding to the systems in the
        geometry file.
    Raises:
        KeyError : if ``dataset`` is not in the energies file, or a requested
            functional label is not available for that dataset.
    """
    all_energies = loadfn(energies_filename)
    if dataset not in all_energies:
        # Still a KeyError, as with plain indexing, but now says what is valid.
        raise KeyError(
            f"Unknown dataset {dataset!r}; available datasets: {list(all_energies)}"
        )
    dataset_energies = all_energies[dataset]

    def _available_labels() -> list:
        # Flatten the two nesting levels into "energy" / "energy@density" labels.
        return [
            f"{energy_dfa}" if energy_dfa == density_dfa else f"{energy_dfa}@{density_dfa}"
            for energy_dfa, by_density in dataset_energies.items()
            for density_dfa in by_density
        ]

    if functional is None:
        requested = _available_labels()
    elif isinstance(functional, str):
        # A bare string is one label, not a sequence of characters.
        requested = [functional]
    else:
        requested = list(functional)

    energies = {}
    for label in requested:
        # First piece is the energy functional, last piece the density
        # functional; for a label without "@" both are the label itself.
        pieces = label.split("@")
        energy_dfa, density_dfa = pieces[0], pieces[-1]
        try:
            energies[label] = dataset_energies[energy_dfa][density_dfa]
        except KeyError:
            raise KeyError(
                f"Functional {label!r} not found in dataset {dataset!r}; "
                f"available functionals: {_available_labels()}"
            ) from None

    return energies


In [62]:
# Request two functionals, then display only the r2SCAN25 entry.
get_total_energies_by_dataset(
    dataset="BEGDB_H2O",
    functional=["r2SCAN@HF", "r2SCAN25"],
)["r2SCAN25"]

{'total_energies': {'10PP1': -764.3157173934909,
  '10PP2': -764.3156873483101,
  '2Cs': -152.8403721197716,
  '3UUD': -229.27444429578347,
  '3UUU': -229.27322966943854,
  '4Ci': -305.70896950626025,
  '4Py': -305.7041465128648,
  '4S4': -305.710445966687,
  '5CAA': -382.13791052033264,
  '5CAB': -382.1367778009257,
  '5CAC': -382.13823725931405,
  '5CYC': -382.1409738594434,
  '5FRA': -382.1357579848993,
  '5FRB': -382.1386795728674,
  '5FRC': -382.1346494787938,
  '6BAG': -458.57096543021794,
  '6BK1': -458.57261340852466,
  '6BK2': -458.57203992638546,
  '6CA': -458.5730616579142,
  '6CB1': -458.5696289504374,
  '6CB2': -458.56951898255863,
  '6CC': -458.57138784983914,
  '6PR': -458.57323983125724,
  '7BI1': -535.0019531617615,
  '7BI2': -535.001832798326,
  '7CA1': -535.005390910201,
  '7CA2': -535.0040987082122,
  '7CH1': -535.003646810475,
  '7CH2': -535.0026423778413,
  '7CH3': -535.0013308557448,
  '7HM1': -535.0000722941754,
  '7PR1': -535.0081266192365,
  '7PR2': -535.00768

In [60]:
# Show the labels of every system stored for the BEGDB_H2O dataset.
get_geometry_for_single_dataset(
    dataset="BEGDB_H2O",
).keys()

dict_keys(['5FRC', '5FRB', '4S4', '4Py', '5CYC', '5FRA', '5CAA', '7CA2', '5CAC', '5CAB', '7CA1', '10PP1', '6CB2', 'H2O', '10PP2', '6CB1', '4Ci', '8D2d', '7CH1', '7CH3', '9S4DA', '7CH2', '6BK2', '7HM1', '7PR1', '6CC', '6BK1', '7PR2', '6CA', '7PR3', '2Cs', '6PR', '3UUD', '9D2dDD', '7BI1', '8S4', '6BAG', '7BI2', '3UUU', '9D2dDD_dmono_8', '9D2dDD_dmono_1', '9D2dDD_dmono_6', '9D2dDD_dmono_7', '9D2dDD_dmono_9', '9D2dDD_dmono_5', '9D2dDD_dmono_2', '9D2dDD_dmono_3', '9D2dDD_dmono_4', '7CA2_dmono_1', '7CA2_dmono_6', '7CA2_dmono_7', '7CA2_dmono_5', '7CA2_dmono_2', '7CA2_dmono_3', '7CA2_dmono_4', '6CB1_dmono_1', '6CB1_dmono_6', '6CB1_dmono_5', '6CB1_dmono_2', '6CB1_dmono_3', '6CB1_dmono_4', '8D2d_dmono_8', '8D2d_dmono_1', '8D2d_dmono_6', '8D2d_dmono_7', '8D2d_dmono_5', '8D2d_dmono_2', '8D2d_dmono_3', '8D2d_dmono_4', '2Cs_dmono_1', '2Cs_dmono_2', '6BK1_dmono_1', '6BK1_dmono_6', '6BK1_dmono_5', '6BK1_dmono_2', '6BK1_dmono_3', '6BK1_dmono_4', '7PR2_dmono_1', '7PR2_dmono_6', '7PR2_dmono_7', '7PR2_dmo