In [1]:
import itertools
from ast import literal_eval
from collections import namedtuple

import numpy as np
import pyspark.sql.functions as sf
from ase.units import create_units
from pyspark.sql import Row
from pyspark.sql.types import FloatType, StructField, StructType

In [3]:
def trim_prefix(metadata_path):
    """
    Rewrite an absolute metadata path so that it starts with ``data``.

    Only the leading prefix is rewritten; later occurrences of the root inside
    the path are left alone.

    * ``/vdev/colabfit-data/MD...``   -> ``data/MD...``
    * ``/vdev/colabfit-data/data...`` -> ``data...``
    * ``None`` and any other path are returned unchanged.

    :param metadata_path: path string, or None.
    :returns: the rewritten path, the original path, or None.
    """
    if metadata_path is None:
        return None
    root = "/vdev/colabfit-data"
    if metadata_path.startswith(f"{root}/MD"):
        # Swap the root for "data", keeping the "/MD..." remainder.
        return "data" + metadata_path[len(root):]
    if metadata_path.startswith(f"{root}/data"):
        # Drop the root and its trailing slash; the remainder already starts with "data".
        return metadata_path[len(root) + 1:]
    return metadata_path


trim_prefix("/vdev/colabfit-data/MD/1989/MD_1206814276940281681221989.json")

'MD/1989/MD_1206814276940281681221989.json'

In [2]:
# Example: a path that already sits under ".../colabfit-data/data" takes the second branch.
trim_prefix("/vdev/colabfit-data/data/MD/9668/MD_8737377704434833080489668.json")

'data/MD/9668/MD_8737377704434833080489668.json'

In [None]:
# NOTE(review): this cell defines `trim_prefix` a second time; whichever cell
# runs last wins. Consider keeping a single definition.
def trim_prefix(metadata_path):
    """
    Rewrite an absolute metadata path so that it starts with ``data``.

    Only the leading prefix is rewritten; later occurrences of the root inside
    the path are left alone.

    * ``/vdev/colabfit-data/MD...``   -> ``data/MD...``
    * ``/vdev/colabfit-data/data...`` -> ``data...``
    * ``None`` and any other path are returned unchanged.

    :param metadata_path: path string, or None.
    :returns: the rewritten path, the original path, or None.
    """
    if metadata_path is None:
        return None
    root = "/vdev/colabfit-data"
    if metadata_path.startswith(f"{root}/MD"):
        # Swap the root for "data", keeping the "/MD..." remainder.
        return "data" + metadata_path[len(root):]
    if metadata_path.startswith(f"{root}/data"):
        # Drop the root and its trailing slash; the remainder already starts with "data".
        return metadata_path[len(root) + 1:]

    return metadata_path

In [None]:
# Unit table built by ASE's `create_units("2014")`, extended below with
# lower-case aliases and a "kbar" entry.
UNITS = create_units("2014")

UNITS.update(
    {
        alias: UNITS[ase_name]
        for alias, ase_name in (
            ("angstrom", "Ang"),
            ("bohr", "Bohr"),
            ("hartree", "Hartree"),
            ("rydberg", "Rydberg"),
            ("debye", "Debye"),
        )
    }
)
UNITS["kbar"] = 1000 * UNITS["bar"]


# `unit` lists the unit strings accepted without conversion (the first entry is
# the one written back after conversion); `dtype` is the Python type of the value.
prop_info = namedtuple("prop_info", "unit dtype")
energy_info = prop_info(unit=["eV"], dtype=float)
force_info = prop_info(unit=["eV/angstrom", "eV/angstrom^3"], dtype=list)
stress_info = prop_info(unit=["eV/angstrom^3"], dtype=list)

# Property column name -> expected unit/dtype description.
MAIN_KEY_MAP = dict(
    potential_energy=energy_info,
    atomic_forces_00=force_info,
    cauchy_stress=stress_info,
    atomization_energy=energy_info,
    formation_energy=energy_info,
    band_gap=energy_info,
    free_energy=energy_info,
)


def standardize_energy(row: Row):
    """
    Convert the properties of ``row`` listed in ``MAIN_KEY_MAP`` to standard units.

    Columns that are not keys of ``MAIN_KEY_MAP``, or whose value is None, are
    left untouched. For every other property the value is, in order:

    1. parsed with ``literal_eval`` into a float64 array when the expected
       dtype is ``list``;
    2. shifted by ``<prop>_reference`` when that column exists and is not None
       (``<prop>_reference_unit`` must equal the property's unit);
    3. multiplied by ``nsites`` when ``<prop>_per_atom`` is True;
    4. rescaled with the ``UNITS`` table when its unit is not one of the
       accepted units for the property.

    The unit column is then overwritten with the first accepted unit.

    Unit strings are products/quotients of ``UNITS`` keys with optional integer
    exponents, e.g. ``"hartree/bohr^3"``. An exponent applies to the unit it
    follows, not to the property value.

    :param row: Row holding property, unit and optional reference/per-atom columns.
    :returns: a new Row with converted values and updated unit columns.
    :raises RuntimeError: on mismatched reference units, a per-atom value
        without ``nsites``, or a non-integer exponent in a unit string.
    :raises KeyError: if a unit name is not a key of ``UNITS``.
    """
    rowdict = row.asDict()
    for prop_name, val in rowdict.items():
        if prop_name not in MAIN_KEY_MAP or val is None:
            continue
        p_info = MAIN_KEY_MAP[prop_name]
        unit_col = f"{prop_name}_unit"
        if prop_name[-2:] == "00":
            # e.g. "atomic_forces_00" reads its unit from "atomic_forces_unit".
            unit_col = f"{prop_name[:-3]}_unit"
        units = rowdict[unit_col]
        if p_info.dtype == list:
            # List-valued properties arrive as their string representation.
            prop_val = np.array(literal_eval(val), dtype=np.float64)
        else:
            prop_val = val
        ref_en_col = f"{prop_name}_reference"
        if ref_en_col in rowdict and rowdict[ref_en_col] is not None:
            if rowdict[f"{ref_en_col}_unit"] != units:
                raise RuntimeError(
                    "Units of the reference energy and energy must be the same"
                )
            prop_val += rowdict[ref_en_col]
        per_atom_col = f"{prop_name}_per_atom"
        if per_atom_col in rowdict and rowdict[per_atom_col] is True:
            if rowdict["nsites"] is None:
                raise RuntimeError("nsites must be provided to convert per-atom")
            prop_val *= rowdict["nsites"]
        if units not in p_info.unit:
            # Evaluate the unit expression left to right: "*" multiplies,
            # "/" divides, "^n" raises the unit it follows to the n-th power.
            for product_term in units.split("*"):
                for position, term in enumerate(product_term.split("/")):
                    name, has_exponent, exponent = term.partition("^")
                    scale = float(UNITS[name])
                    if has_exponent:
                        try:
                            scale **= int(exponent)
                        except ValueError:
                            raise RuntimeError(
                                f"There may be something wrong with the units: {exponent}"
                            )
                    if position == 0:
                        prop_val = prop_val * scale
                    else:
                        prop_val = prop_val / scale
        if p_info.dtype == list:
            prop_val = prop_val.tolist()
        rowdict[prop_name] = prop_val
        rowdict[unit_col] = p_info.unit[0]
    return Row(**rowdict)

In [None]:
from pyspark.sql.types import DoubleType, ArrayType
import numpy as np
import itertools
from ast import literal_eval
import pyspark.sql.functions as sf


@sf.udf(returnType=DoubleType())
def standardize_energy_udf(en_col, unit_col):
    """
    Convert a scalar energy value to eV.

    Unit strings are products/quotients of ``UNITS`` keys with optional integer
    exponents (e.g. ``"kcal/mol"``). An exponent applies to the unit it follows,
    not to the energy value.

    :param en_col: energy value, or None.
    :param unit_col: unit string describing ``en_col``.
    :returns: the energy in eV as a Python float, or None when ``en_col`` is None.
    :raises RuntimeError: if an exponent in the unit string is not an integer.
    :raises KeyError: if a unit name is not a key of ``UNITS``.
    """
    if en_col is None:
        return None
    prop_val = en_col
    if unit_col != "eV":
        # Evaluate the unit expression left to right: "*" multiplies,
        # "/" divides, "^n" raises the unit it follows to the n-th power.
        for product_term in unit_col.split("*"):
            for position, term in enumerate(product_term.split("/")):
                name, has_exponent, exponent = term.partition("^")
                scale = float(UNITS[name])
                if has_exponent:
                    try:
                        scale **= int(exponent)
                    except ValueError:
                        raise RuntimeError(
                            f"There may be something wrong with the units: {exponent}"
                        )
                if position == 0:
                    prop_val = prop_val * scale
                else:
                    prop_val = prop_val / scale
    # Return a plain Python float so the value matches the declared DoubleType.
    return float(prop_val)


@sf.udf(returnType=ArrayType(ArrayType(DoubleType())))
def standardize_array_udf(af_col, unit_col, unit):
    """
    Parse a stringified nested array and convert it to an accepted unit.

    Unit strings are products/quotients of ``UNITS`` keys with optional integer
    exponents (e.g. ``"hartree/bohr^3"``). An exponent applies to the unit it
    follows, not to the array values.

    :param af_col: string representation of a nested list of numbers, or None.
    :param unit_col: unit string describing ``af_col``.
    :param unit: collection of unit strings that need no conversion.
    :returns: nested list of floats; None when ``af_col`` is None and an empty
        list when ``af_col`` is ``"[]"`` (both match the declared array type,
        unlike the string ``"[]"``).
    :raises RuntimeError: if an exponent in the unit string is not an integer.
    :raises KeyError: if a unit name is not a key of ``UNITS``.
    """
    if af_col is None:
        return None
    if af_col == "[]":
        return []
    prop_val = np.array(literal_eval(af_col), dtype=np.float64)
    if unit_col not in unit:
        # Evaluate the unit expression left to right: "*" multiplies,
        # "/" divides, "^n" raises the unit it follows to the n-th power.
        for product_term in unit_col.split("*"):
            for position, term in enumerate(product_term.split("/")):
                name, has_exponent, exponent = term.partition("^")
                scale = float(UNITS[name])
                if has_exponent:
                    try:
                        scale **= int(exponent)
                    except ValueError:
                        raise RuntimeError(
                            f"There may be something wrong with the units: {exponent}"
                        )
                if position == 0:
                    prop_val = prop_val * scale
                else:
                    prop_val = prop_val / scale
    return prop_val.tolist()

In [None]:
import pyspark.sql.functions as sf

pos = spark.table("ndb.colabfit.dev.pos_convert_v3")

# Sample of rows whose force unit is not one of the accepted force units.
po_forces = (
    pos.filter(sf.col("atomic_forces_unit") != "eV/angstrom")
    .filter(sf.col("atomic_forces_unit") != "eV/angstrom^3")
    .limit(100)
)

# Chain every conversion onto the same frame. Rebuilding from `po_forces` at
# each step would keep only the last conversion applied.
# NOTE(review): the `<col>_unit` columns are not rewritten here, so they still
# hold the pre-conversion unit afterwards -- confirm whether that is intended.
po_forces3 = po_forces
for col in [
    "potential_energy",
    "free_energy",
    "electronic_band_gap",
    "adsorption_energy",
    "atomization_energy",
    "formation_energy",
]:
    po_forces3 = po_forces3.withColumn(
        col,
        standardize_energy_udf(sf.col(col), sf.col(f"{col}_unit")),
    )

po_forces3 = po_forces3.withColumn(
    "atomic_forces_00",
    standardize_array_udf(
        sf.col("atomic_forces_00"),
        sf.col("atomic_forces_unit"),
        sf.lit(force_info.unit),
    ),
)

po_forces3 = po_forces3.withColumn(
    "cauchy_stress",
    standardize_array_udf(
        sf.col("cauchy_stress"), sf.col("cauchy_stress_unit"), sf.lit(stress_info.unit)
    ),
)

po_forces3.first()

In [None]:
pos = spark.table("ndb.colabfit.dev.pos_convert_v3")

# 100 (configuration_id, nsites) pairs taken from the configuration table.
co_nsites = (
    spark.table("ndb.colabfit.dev.co_convert")
    .select(sf.col("id").alias("configuration_id"), sf.col("nsites"))
    .limit(100)
)

# Keep only the rows of `pos` whose configuration is in that sample.
po_sites = pos.join(co_nsites, "configuration_id", "inner")

In [None]:
@sf.udf("float")
def convert_energy(val, unit):
    """
    Return ``val`` unchanged when ``unit`` equals the global ``ref_unit``.

    :raises RuntimeError: when ``unit`` differs from ``ref_unit``.
    """
    # NOTE(review): `ref_unit` is not assigned in any cell visible in this
    # notebook; it must exist as a global before this UDF is evaluated.
    if unit == ref_unit:
        return val
    raise RuntimeError("Units of the reference energy and energy must be the same")


df = spark.table("ndb.colabfit.dev.pos_convert")
unit = "eV/angstrom"
unit_col = "atomic_forces_unit"
col = "atomic_forces_00"

# Route rows whose unit differs from `unit` through `convert_energy`;
# all other rows keep their current value.
unit_differs = sf.col(unit_col) != unit
checked_value = convert_energy(col, unit_col)
df = df.withColumn(col, sf.when(unit_differs, checked_value).otherwise(sf.col(col)))

In [None]:
from pyspark.sql import Row


def alter_schema(row: Row):
    """
    Add a ``target_energy`` field to ``row``.

    The field takes ``free_energy`` when that is not None and falls back to
    ``potential_energy`` otherwise (which may itself be None).

    :param row: Row containing ``free_energy`` and ``potential_energy``.
    :returns: a new Row with the extra ``target_energy`` field.
    """
    fields = row.asDict()
    target = fields["free_energy"]
    if target is None:
        # No free energy recorded: fall back to the potential energy.
        target = fields["potential_energy"]
    fields["target_energy"] = target
    return Row(**fields)

In [None]:
# NOTE(review): `po_forces2` is not assigned in any cell visible in this
# notebook (only `po_forces` and `po_forces3` are) -- confirm which frame is
# intended before running on a fresh kernel.
#
# Build a fresh schema rather than calling `.add()` on `po_forces2.schema`:
# `StructType.add` appends to the schema object in place, so re-running this
# cell could append `target_energy` a second time.
targen_schema = StructType(
    po_forces2.schema.fields + [StructField("target_energy", FloatType(), True)]
)


po_targen = po_forces2.rdd.map(alter_schema).toDF(schema=targen_schema)

# Drop every column whose name contains one of these substrings; the energies
# are now represented by `target_energy`.
dropped_markers = (
    "unit",
    "property_id",
    "reference",
    "potential_energy",
    "free_energy",
)
po_targen2 = po_targen.select(
    [
        c
        for c in po_targen.columns
        if not any(marker in c for marker in dropped_markers)
    ]
)
po_targen2.first()
po_targen2.join(
    co_nsites, po_targen2.configuration_id == co_nsites.configuration_id, how="inner"
).first()