In [None]:
import os
from os.path import exists, join
import glob
import shutil
import time

import matplotlib
import matplotlib.pyplot as plt
import numpy as np

from rhocalc.aims import aims_calc, aims_fields, aims_parser
from rholearn import utils

import chemiscope
from dft_settings import *

# Display the selected structures in a chemiscope widget, structure panel only.
# NOTE(review): STRUCTURE is not defined in this cell; presumably it comes from
# the star import of dft_settings — confirm.
chemiscope.show(STRUCTURE, mode="structure")

In [None]:
from rholearn import io
import pandas as pd
from settings import ALL_IDXS, ALL_STRUCTURES

# Summarise basic metadata of every structure in a dataframe (one row per
# structure, indexed by structure id) and show the rows of the structures
# selected for the current experiment.
metric_names = [
    "curr_experiment",  # "*" marks structures that are in STRUCTURE_ID
    "category",
    "size",
    "z_depth_max",  # extent of the atomic positions along z
    "num_atoms",
    "supercell",
]
# NOTE(review): further per-structure columns (num_abfs, df_error_percent,
# RI coefficient / overlap / lsoap file sizes, linalg loss, ML-vs-RI MAE) used
# to be appended here from the processed calc_info.pickle files. That code was
# commented out (and relied on hard-coded absolute paths), so it is dropped.

metric_rows = []
for A, frame in zip(ALL_IDXS, ALL_STRUCTURES):
    z_coords = frame.positions[:, 2]
    metric_rows.append(
        [
            "*" if A in STRUCTURE_ID else "",
            frame.info["category"],
            frame.info["size"],
            np.max(z_coords) - np.min(z_coords),
            frame.get_global_number_of_atoms(),
            frame.info.get("supercell"),  # None when the key is absent
        ]
    )

metrics = pd.DataFrame(metric_rows, columns=metric_names, index=ALL_IDXS)

# The frame is indexed by structure id (ALL_IDXS), so select the experiment's
# structures by label. Positional ``.iloc`` only returns the right rows when
# the ids happen to be exactly 0..N-1.
display(metrics.loc[list(STRUCTURE_ID)])

## 1. SCF

In [None]:
write_geom = True  # set to False to use the optimized geometry instead

# Expected output file of each SCF calculation, one per selected structure
all_aims_outs = [join(SCF_DIR(idx), "aims.out") for idx in STRUCTURE_ID]

# Settings shared by all SCF calculations: the base settings overlaid with the
# SCF-specific ones (copied first so BASE_AIMS_KWARGS itself is not modified)
aims_kwargs = BASE_AIMS_KWARGS.copy()
aims_kwargs.update(SCF_KWARGS)

In [None]:
# Per-structure calculation inputs, keyed by structure id
calcs = {A: {"atoms": structure} for A, structure in zip(STRUCTURE_ID, STRUCTURE)}

# Move the output of any earlier run aside as aims.out.previous (presumably so
# a stale file is not mistaken for the result of this submission).
for aims_out in all_aims_outs:
    if exists(aims_out):
        os.replace(aims_out, aims_out + ".previous")

for A, structure in zip(STRUCTURE_ID, STRUCTURE):
    scf_dir = SCF_DIR(A)

    # If an earlier run left a geometry.in.next_step, promote it to
    # geometry.in, keeping the old geometry.in as geometry.in.previous.
    # NOTE(review): with write_geom=True the promoted geometry.in is presumably
    # overwritten again by run_aims_array — confirm the intended combination.
    next_geometry = join(scf_dir, "geometry.in.next_step")
    if exists(next_geometry):
        shutil.copy(
            join(scf_dir, "geometry.in"),
            join(scf_dir, "geometry.in.previous"),
        )
        shutil.copy(next_geometry, join(scf_dir, "geometry.in"))

    # Define cube settings for Hartree Potential: origin at the centre of mass
    # of the structure that is actually submitted, 100 points with a 0.1 step
    # along each Cartesian axis. The centre of mass is taken from the atoms in
    # `calcs` rather than from ALL_STRUCTURES[A], which comes from a different
    # settings module and is only equivalent if ids are list positions.
    com = structure.get_center_of_mass()
    calcs[A]["aims_kwargs"] = {
        "cubes": (
            f"cube origin {com[0]} {com[1]} {com[2]} \n"
            "cube edge 100 0.1 0.0 0.0 \n"
            "cube edge 100 0.0 0.1 0.0 \n"
            "cube edge 100 0.0 0.0 0.1 \n"
        )
    }

# Run the SCF in AIMS
aims_calc.run_aims_array(
    calcs=calcs,
    aims_path=AIMS_PATH,
    aims_kwargs=aims_kwargs,
    sbatch_kwargs=SBATCH_KWARGS,
    run_dir=SCF_DIR,
    load_modules=HPC_KWARGS["load_modules"],
    export_vars=HPC_KWARGS["export_vars"],
    run_command="srun",
    write_geom=write_geom,
)

In [None]:
# Wait until all AIMS calcs have finished.
# A calculation counts as finished once its aims.out contains the marker line
# below. The list of outputs is copied so that `all_aims_outs` stays intact and
# this cell (and the ones using that list) can be re-run.
# NOTE: this blocks indefinitely if a calculation never writes the marker
# (e.g. the job was killed); interrupt the kernel in that case.
poll_interval_s = 10  # pause between polls instead of spinning on the filesystem
pending = list(all_aims_outs)
while pending:
    still_running = []
    for aims_out in pending:
        try:
            with open(aims_out, "r") as f:
                # Basic check to see if AIMS calc has finished
                finished = "Leaving FHI-aims." in f.read()
        except FileNotFoundError:
            finished = False  # output not written yet
        if not finished:
            still_running.append(aims_out)
    pending = still_running
    if pending:
        time.sleep(poll_interval_s)

In [None]:
# For every selected structure: record SCF convergence, derive two additional
# Fermi-level estimates from the KS-orbital info, and store everything in
# calc_info.pickle inside the SCF directory.
converged = []
for A, frame in zip(STRUCTURE_ID, STRUCTURE):
    kso_path = join(SCF_DIR(A), "ks_orbital_info.out")

    # Parsed contents of the aims output for this structure
    calc_info = aims_parser.parse_aims_out(SCF_DIR(A))
    converged.append(calc_info["scf"]["converged"])

    # Estimate 1: energy of the highest occupied KS orbital (VBM).
    # The HOMO index is 1-based, the orbital list 0-based, hence the -1.
    kso_info = aims_parser.get_ks_orbital_info(kso_path)
    homo_idx = aims_fields.get_homo_kso_idx(kso_info)
    vbm_energy = kso_info[homo_idx - 1]["energy_eV"]

    # Estimate 2: Fermi energy obtained by integration, using the total
    # atomic-number sum as the electron count
    integrated_fermi = aims_fields.calculate_fermi_energy(
        kso_info_path=kso_path,
        n_electrons=frame.get_atomic_numbers().sum(),
        gaussian_width=LDOS["gaussian_width"],
        interpolation_truncation=0.5,
    )
    print(
        f"Fermi energy for {A}: ChemPot: {calc_info['fermi_eV']}, "
        f"VBM: {vbm_energy}, integrated: {integrated_fermi}"
    )

    # Persist both estimates next to the parsed info
    calc_info.update(vbm_eV=vbm_energy, fermi_integrated_eV=integrated_fermi)
    utils.pickle_dict(join(SCF_DIR(A), "calc_info.pickle"), calc_info)

assert all(converged)

In [None]:
# Plot DOS with different alignments: one panel per energy reference, one
# curve per structure (DOS divided by the number of atoms).
gaussian_width = 0.3
e_grid = np.linspace(-15, 5, 1000)
highlight_id = 6  # structure drawn in green; all others are drawn in gray
alignments = ["fermi_eV", "vbm_eV", "fermi_integrated_eV"]  # keys of calc_info

fig, axes = plt.subplots(len(alignments), 1, figsize=(10, 5), sharex=True, sharey=True)
for A, frame in zip(STRUCTURE_ID, STRUCTURE):
    calc_info = utils.unpickle_dict(join(SCF_DIR(A), "calc_info.pickle"))
    # Read the KS-orbital info from the SCF directory, i.e. the same file the
    # Fermi energies in calc_info.pickle were derived from.
    kso_info_path = join(SCF_DIR(A), "ks_orbital_info.out")
    _, dos = aims_fields.calculate_dos(
        kso_info_path, gaussian_width=gaussian_width, e_grid=e_grid
    )
    dos_per_atom = dos / frame.get_global_number_of_atoms()

    for ax, target_energy in zip(axes, alignments):
        ax.plot(
            e_grid - calc_info[target_energy],  # shift so the reference is at 0
            dos_per_atom,
            c="green" if A == highlight_id else "gray",
        )

# Axis formatting is independent of the structures, so apply it once per panel
for ax in axes:
    ax.set_xlim(-10, 10)
    ax.set_ylabel("Total DOS")
# x axis is shared; label only the bottom panel (explicitly, rather than via
# whatever `ax` the loops happened to leave behind)
axes[-1].set_xlabel("Energy (eV)")

## 2. RI

In [None]:
# Expected output file of each RI calculation, one per selected structure
all_aims_outs = [join(RI_DIR(idx), "aims.out") for idx in STRUCTURE_ID]
# Disabled: keep a timestamped copy of any existing aims.out and remove it
# for aims_out in all_aims_outs:
#     if exists(aims_out):
#         shutil.copy(aims_out, aims_out + ".copy." + utils.timestamp())
#         os.remove(aims_out)

In [None]:
# Prepare and submit the RI fitting calculation of every selected structure.
# NOTE(review): this cell reads BASE_AIMS / RI / SBATCH / HPC / CUBE / LDOS,
# whereas other cells in this notebook use BASE_AIMS_KWARGS / SBATCH_KWARGS /
# HPC_KWARGS / RI_KWARGS / CUBE_KWARGS. Confirm which set of names the settings
# module actually exports and align the cells.
calcs = {}
for A, frame in zip(STRUCTURE_ID, STRUCTURE):
    # Create the RI dir if needed (no separate existence check, so no race)
    os.makedirs(RI_DIR(A), exist_ok=True)
    calcs[A] = {"atoms": frame}

    # Get SCF calculation info and path to KS-orbital info
    calc_info = utils.unpickle_dict(join(SCF_DIR(A), "calc_info.pickle"))
    kso_info_path = join(SCF_DIR(A), "ks_orbital_info.out")

    if FIELD_NAME == "ildos":  # define KSO weights and write to file

        # Save LDOS settings. LDOS["target_energy"] names a key of calc_info;
        # it is replaced by that structure's value in a copy of the settings.
        ldos_kwargs = dict(LDOS)
        ldos_kwargs["target_energy"] = calc_info[ldos_kwargs["target_energy"]]
        utils.pickle_dict(join(RI_DIR(A), "ldos_settings.pkl"), ldos_kwargs)
        print(f"Structure {A}, target_energy: {ldos_kwargs['target_energy']}")

        # Write KS-orbital weight vector
        kso_weights = aims_fields.get_kso_weight_vector_for_named_field(
            field_name=FIELD_NAME, kso_info_path=kso_info_path, **ldos_kwargs
        )
        np.savetxt(join(RI_DIR(A), "ks_orbital_weights.in"), kso_weights)

    elif FIELD_NAME == "edensity":
        # The total-density fit must be explicitly configured for this field
        assert RI.get("ri_fit_total_density") is not None

    # Specify tailored cube edges
    if RI.get("output") == ["cube ri_fit"] and CUBE["slab"] is True:
        calcs[A]["aims_kwargs"] = aims_calc.get_aims_cube_edges_slab(
            frame, CUBE.get("n_points")
        )

    # Copy density matrix restart
    for density_matrix in glob.glob(join(SCF_DIR(A), "D*.csc")):
        shutil.copy(density_matrix, RI_DIR(A))


# And the general settings for all calcs
aims_kwargs = BASE_AIMS.copy()
aims_kwargs.update(RI)

# Run the RI fitting procedure in AIMS
aims_calc.run_aims_array(
    calcs=calcs,
    aims_path=AIMS_PATH,
    aims_kwargs=aims_kwargs,
    sbatch_kwargs=SBATCH,
    run_dir=RI_DIR,
    load_modules=HPC["load_modules"],
    export_vars=HPC["export_vars"],
    run_command="srun",
)

In [None]:
# Wait until all AIMS calcs have finished.
# A calculation counts as finished once its aims.out contains the marker line
# below. The list of outputs is copied so that `all_aims_outs` stays intact and
# this cell can be re-run.
# NOTE: this blocks indefinitely if a calculation never writes the marker
# (e.g. the job was killed); interrupt the kernel in that case.
poll_interval_s = 10  # pause between polls instead of spinning on the filesystem
pending = list(all_aims_outs)
while pending:
    still_running = []
    for aims_out in pending:
        try:
            with open(aims_out, "r") as f:
                # Basic check to see if AIMS calc has finished
                finished = "Leaving FHI-aims." in f.read()
        except FileNotFoundError:
            finished = False  # output not written yet
        if not finished:
            still_running.append(aims_out)
    pending = still_running
    if pending:
        time.sleep(poll_interval_s)

# Remove the density matrix restart files
for A in STRUCTURE_ID:
    for density_matrix in glob.glob(join(RI_DIR(A), "D*.csc")):
        os.remove(density_matrix)

## 3. Process -> metatensor

In [None]:
# Hand the RI outputs of the selected structures to
# aims_calc.process_aims_results_sbatch_array, passing the script name
# "run-process-aims.sh", the RI run directories and the SBATCH_KWARGS settings.
# Only the "coeffs" and "ovlp" outputs are requested.
# NOTE(review): presumably this writes and submits a SLURM array job whose
# results end up in PROCESSED_DIR(A) — confirm against rhocalc.
aims_calc.process_aims_results_sbatch_array(
    "run-process-aims.sh",
    structure_idxs=STRUCTURE_ID,
    run_dir=RI_DIR,
    process_what=["coeffs", "ovlp"],
    **SBATCH_KWARGS,
)

In [None]:
# Once a structure's overlap has been processed (ri_ovlp.npz is present in its
# processed dir), delete the raw ri_ovlp.out from its RI dir.
for A in STRUCTURE_ID:
    if not exists(join(PROCESSED_DIR(A), "ri_ovlp.npz")):
        print(f"Structure {A} not yet processed")
        continue

    try:
        os.remove(join(RI_DIR(A), "ri_ovlp.out"))
    except FileNotFoundError:
        # Nothing left to delete, e.g. this cell was already run
        print(f"ri_ovlp.out already removed for structure {A}")

In [None]:
# Print the total "df_error_percent" stored in each structure's processed
# calc_info.pickle, followed by the mean over all selected structures.
# NOTE(review): other cells call utils.unpickle_dict; confirm that
# io.unpickle_dict is the intended (and available) function here.
maes = []
for A in STRUCTURE_ID:
    processed_info = io.unpickle_dict(join(PROCESSED_DIR(A), "calc_info.pickle"))
    mae = processed_info["df_error_percent"]["total"]
    maes.append(mae)
    print(f"A = {A}, %MAE = {mae}")

print(f"\nAverage %MAE = {np.mean(maes)}")

## 4. Rebuild

In [None]:
# Try a rebuild of the density from the target RI coefficients
import shutil
from settings import CUBE_KWARGS, LDOS_KWARGS, REBUILD_KWARGS, RI_KWARGS

# Expected output file of each rebuild calculation, one per selected structure
all_aims_outs = [join(REBUILD_DIR(idx), "aims.out") for idx in STRUCTURE_ID]
# Disabled: delete any aims.out left over from an earlier rebuild
# for aims_out in all_aims_outs:
#     if exists(aims_out):
#         os.remove(aims_out)

# Settings shared by all rebuild calculations: the base settings overlaid with
# the rebuild-specific ones (copied first so BASE_AIMS_KWARGS is not modified)
aims_kwargs = BASE_AIMS_KWARGS.copy()
aims_kwargs.update(REBUILD_KWARGS)

In [None]:
# Prepare and submit the rebuild calculation of every selected structure: each
# one reads the RI coefficients written by the RI step as its input.
calcs = {}
for A, frame in zip(STRUCTURE_ID, STRUCTURE):
    calcs[A] = {"atoms": frame}

    # Create the rebuild dir if needed (no separate existence check, so no race)
    os.makedirs(REBUILD_DIR(A), exist_ok=True)

    # Copy coefficients from RI dir to rebuild dir
    shutil.copy(
        join(RI_DIR(A), "ri_coeffs.out"),
        join(REBUILD_DIR(A), "ri_coeffs.in"),
    )

    # Specify tailored cube edges, only when a cube of the RI fit is requested.
    # NOTE(review): this tests RI_KWARGS although the calculation itself is
    # configured from REBUILD_KWARGS — confirm this is intended.
    if RI_KWARGS.get("output") != ["cube ri_fit"]:
        continue
    if frame.info["category"] == "bulk":
        print("Bulk structure, not assigning cube slab edges")
        continue

    if CUBE_KWARGS.get("slab") is True:
        get_cube_edges = aims_calc.get_aims_cube_edges_slab
    else:
        get_cube_edges = aims_calc.get_aims_cube_edges
    calcs[A]["aims_kwargs"] = get_cube_edges(frame, CUBE_KWARGS.get("n_points"))

# Run the rebuild from the RI coefficients in AIMS
aims_calc.run_aims_array(
    calcs=calcs,
    aims_path=AIMS_PATH,
    aims_kwargs=aims_kwargs,
    sbatch_kwargs=SBATCH_KWARGS,
    run_dir=REBUILD_DIR,
    load_modules=HPC_KWARGS["load_modules"],
    export_vars=HPC_KWARGS["export_vars"],
    run_command="srun",
)

In [None]:
# Wait until all AIMS calcs have finished.
# A calculation counts as finished once its aims.out contains the marker line
# below. The list of outputs is copied so that `all_aims_outs` stays intact and
# this cell can be re-run.
# NOTE: this blocks indefinitely if a calculation never writes the marker
# (e.g. the job was killed); interrupt the kernel in that case.
poll_interval_s = 10  # pause between polls instead of spinning on the filesystem
pending = list(all_aims_outs)
while pending:
    still_running = []
    for aims_out in pending:
        try:
            with open(aims_out, "r") as f:
                # Basic check to see if AIMS calc has finished
                finished = "Leaving FHI-aims." in f.read()
        except FileNotFoundError:
            finished = False  # output not written yet
        if not finished:
            still_running.append(aims_out)
    pending = still_running
    if pending:
        time.sleep(poll_interval_s)