In [1]:
import datetime
import json
import os
from pathlib import Path
from pprint import pprint

import dateutil.parser
import findspark
import numpy as np
import pyspark
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

from colabfit.tools.configuration import AtomicConfiguration
from colabfit.tools.property import Property
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)

# Make the local Spark installation importable.
# NOTE(review): pyspark is already imported above, so this call presumably only
# matters when SPARK_HOME is not otherwise on sys.path — confirm, and if it is
# needed, move it before the pyspark imports.
findspark.init()
# NOTE(review): shadows the builtin `format` and is not used in the visible cells.
format = "jdbc"
# Environment variables read later (CLASSPATH, PGS_USER, PGS_PASS) are presumably
# provided by this local .env file.
load_dotenv("./.env")

True

In [2]:
def convert_stress(keys, stress):
    """Expand six named stress components into a symmetric 3x3 nested list.

    ``keys`` names each value in ``stress`` and must include "xx", "yy", "zz",
    "xy", "xz" and "yz". Rows and columns are ordered x, y, z; off-diagonal
    entries are mirrored (e.g. both [0][1] and [1][0] come from "xy").

    Raises:
        KeyError: if a required component name is missing from ``keys``.
    """
    component = dict(zip(keys, stress))
    axes = ("x", "y", "z")
    matrix = []
    for row_axis in axes:
        row = []
        for col_axis in axes:
            # Sorting the axis pair maps "yx" -> "xy", "zx" -> "xz", "zy" -> "yz".
            row.append(component["".join(sorted(row_axis + col_axis))])
        matrix.append(row)
    return matrix


SYMBOL_DICT = {"0": "Si", "1": "O"}


def _dft_input_for(symbols):
    """Return ``(label, input_dict)`` with the DFT settings for a structure.

    ``label`` ("SiO2", "Si" or "O") is used in configuration names. Returns
    ``None`` when the structure contains neither Si nor O.
    """
    has_si = "Si" in symbols
    has_o = "O" in symbols
    if has_si and has_o:
        label, kpoints, cutoff = "SiO2", "11x11x11", 1224
    elif has_si:
        label, kpoints, cutoff = "Si", "8x8x8", 884
    elif has_o:
        label, kpoints, cutoff = "O", "gamma-point", 1224
    else:
        return None
    return label, {
        "kpoint-scheme": "Monkhorst-Pack",
        "kpoints": kpoints,
        "kinetic-energy-cutoff": {
            "val": cutoff,
            "units": "eV",
        },
    }


def reader(filepath):
    """Parse an MTP ``.cfg`` file and yield one AtomicConfiguration per record.

    Recognised sections: ``Size``, ``Supercell`` (three cell rows), ``Energy``,
    ``PlusStress`` (six components, expanded with ``convert_stress``) and
    ``AtomData:`` (per-atom ``type``, either ``cartes_*`` or ``direct_*``
    coordinates, and optional ``fx``/``fy``/``fz`` forces). A configuration is
    emitted at each ``END_CFG`` line.

    Atom types are mapped to element symbols with ``SYMBOL_DICT``. Each
    configuration gets ``info["energy"]``; ``info["forces"]`` and
    ``info["stress"]`` only when the record provides them; and, when it contains
    Si and/or O, ``info["input"]`` (DFT settings) and ``info["names"]``.

    Args:
        filepath: path to the ``.cfg`` file (``str`` or ``Path``).

    Raises:
        ValueError: if a record has ``AtomData`` before its ``Size`` line, or
            ends without ``cartes_*``/``direct_*`` coordinates.
    """
    filepath = Path(filepath)
    with open(filepath, "rt") as f:
        # Per-record state; reset after every END_CFG so nothing leaks between
        # configurations.
        size = None
        energy = None
        forces = None
        stress = None
        keys = []
        coords = []
        cell = []
        symbols = []
        config_count = 0
        for line in f:
            stripped = line.strip()
            if stripped.startswith("Size"):
                size = int(f.readline().strip())
            elif stripped.lower().startswith("supercell"):
                cell = [[float(x) for x in f.readline().split()] for _ in range(3)]
            elif stripped.startswith("Energy"):
                energy = float(f.readline().strip())
            elif stripped.startswith("PlusStress"):
                # The last six header tokens name the components consumed by
                # convert_stress.
                stress_keys = stripped.split()[-6:]
                values = [float(x) for x in f.readline().split()]
                stress = convert_stress(stress_keys, values)
            elif stripped.startswith("AtomData:"):
                if size is None:
                    raise ValueError(
                        f"AtomData before Size in record {config_count} of {filepath}"
                    )
                keys = stripped.split()[1:]
                if "cartes_x" in keys:
                    pos_cols = ("cartes_x", "cartes_y", "cartes_z")
                elif "direct_x" in keys:
                    pos_cols = ("direct_x", "direct_y", "direct_z")
                else:
                    pos_cols = None
                has_forces = "fx" in keys
                if has_forces:
                    forces = []
                for _ in range(size):
                    atom = dict(zip(keys, f.readline().split()))
                    symbols.append(SYMBOL_DICT[atom["type"]])
                    if pos_cols is not None:
                        coords.append([float(atom[col]) for col in pos_cols])
                    if has_forces:
                        forces.append([float(atom[col]) for col in ("fx", "fy", "fz")])
            elif stripped.startswith("END_CFG"):
                if "cartes_x" in keys:
                    config = AtomicConfiguration(
                        positions=coords, symbols=symbols, cell=cell
                    )
                elif "direct_x" in keys:
                    config = AtomicConfiguration(
                        scaled_positions=coords, symbols=symbols, cell=cell
                    )
                else:
                    # Without this, the previous record's `config` would be re-emitted.
                    raise ValueError(
                        f"Record {config_count} of {filepath} has no "
                        "cartes_* or direct_* coordinates"
                    )
                config.info["energy"] = energy
                if forces:
                    config.info["forces"] = forces
                if stress is not None:
                    config.info["stress"] = stress

                dft = _dft_input_for(symbols)
                if dft is not None:
                    label, dft_input = dft
                    config.info["input"] = dft_input
                    config.info["names"] = f"{filepath.stem}_{label}_{config_count}"
                config_count += 1
                yield config
                size = None
                energy = None
                forces = None
                stress = None
                keys = []
                coords = []
                cell = []
                symbols = []
                energy = None

In [3]:
# Parse every configuration in the unified MTP training set into memory and
# show the summary of the first one.
mtpu_configs = reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
data = list(mtpu_configs)
data[0].configuration_summary()

{'nsites': 4,
 'elements': ['Si'],
 'nelements': 1,
 'elements_ratios': [1.0],
 'chemical_formula_anonymous': 'A',
 'chemical_formula_reduced': 'Si',
 'chemical_formula_hill': 'Si4',
 'dimension_types': [0, 0, 0],
 'nperiodic_dimensions': 0}

In [4]:
# Spark session with the JDBC driver jar(s) named by CLASSPATH on its classpath.
JARFILE = os.environ.get("CLASSPATH")
spark = SparkSession.builder.appName(
    "PostgreSQL Connection with PySpark"
).config("spark.jars", JARFILE).getOrCreate()

# Postgres connection settings; credentials come from the environment, never
# from the notebook.
url = "jdbc:postgresql://localhost:5432/colabfit"
user, password = os.environ.get("PGS_USER"), os.environ.get("PGS_PASS")
properties = dict(user=user, password=password, driver="org.postgresql.Driver")

24/04/11 14:31:09 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/04/11 14:31:09 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/04/11 14:31:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [5]:
# Spark schemas for the configuration ("co") and property-object tables.
# Columns whose trailing comment names an ArrayType are declared as StringType
# because the JDBC connector does not support array columns; their values are
# stringified (stringify_lists) before the rows reach Spark.
config_schema = StructType(
    [
        StructField("id", StringType(), False),
        StructField("hash", StringType(), False),
        StructField("last_modified", TimestampType(), False),
        StructField("dataset_ids", StringType(), True),  # ArrayType(StringType())
        StructField("chemical_formula_hill", StringType(), True),
        StructField("chemical_formula_reduced", StringType(), True),
        StructField("chemical_formula_anonymous", StringType(), True),
        StructField("elements", StringType(), True),  # ArrayType(StringType())
        StructField("elements_ratios", StringType(), True),  # ArrayType(DoubleType()); ratios are floats, e.g. [1.0]
        StructField("atomic_numbers", StringType(), True),  # ArrayType(IntegerType())
        StructField("nsites", IntegerType(), True),
        StructField("nelements", IntegerType(), True),
        StructField("nperiodic_dimensions", IntegerType(), True),
        StructField("cell", StringType(), True),  # ArrayType(ArrayType(DoubleType()))
        StructField("dimension_types", StringType(), True),  # ArrayType(IntegerType())
        StructField("pbc", StringType(), True),  # ArrayType(BooleanType()); config.pbc holds bools
        StructField(
            "positions", StringType(), True
        ),  # ArrayType(ArrayType(DoubleType()))
        StructField("names", StringType(), True),  # ArrayType(StringType())
    ]
)
# One row per property object. Each property family (potential_energy,
# atomic_forces, cauchy_stress, free_energy, band_gap, formation_energy,
# adsorption_energy, atomization_energy) contributes a value column, its unit
# and option columns, and a *_property_id column; unused families stay null.
property_object_schema = StructType(
    [
        StructField("id", StringType(), False),
        StructField("hash", StringType(), False),
        StructField("last_modified", TimestampType(), False),
        StructField("configuration_ids", StringType(), True),  # ArrayType(StringType())
        StructField("dataset_ids", StringType(), True),  # ArrayType(StringType())
        StructField("metadata", StringType(), True),
        StructField("chemical_formula_hill", StringType(), True),
        StructField("potential_energy", DoubleType(), True),
        StructField("potential_energy_unit", StringType(), True),
        StructField("potential_energy_per_atom", BooleanType(), True),
        StructField("potential_energy_reference", DoubleType(), True),
        StructField("potential_energy_reference_unit", StringType(), True),
        StructField("potential_energy_property_id", StringType(), True),
        StructField(
            "atomic_forces", StringType(), True
        ),  # ArrayType(ArrayType(DoubleType()))
        StructField("atomic_forces_unit", StringType(), True),
        StructField("atomic_forces_property_id", StringType(), True),
        StructField(
            "cauchy_stress", StringType(), True
        ),  # ArrayType(ArrayType(DoubleType()))
        StructField("cauchy_stress_unit", StringType(), True),
        StructField("cauchy_stress_volume_normalized", BooleanType(), True),
        StructField("cauchy_stress_property_id", StringType(), True),
        StructField("free_energy", DoubleType(), True),
        StructField("free_energy_unit", StringType(), True),
        StructField("free_energy_per_atom", BooleanType(), True),
        StructField("free_energy_reference", DoubleType(), True),
        StructField("free_energy_reference_unit", StringType(), True),
        StructField("free_energy_property_id", StringType(), True),
        StructField("band_gap", DoubleType(), True),
        StructField("band_gap_unit", StringType(), True),
        StructField("band_gap_property_id", StringType(), True),
        StructField("formation_energy", DoubleType(), True),
        StructField("formation_energy_unit", StringType(), True),
        StructField("formation_energy_per_atom", BooleanType(), True),
        StructField("formation_energy_reference", DoubleType(), True),
        StructField("formation_energy_reference_unit", StringType(), True),
        StructField("formation_energy_property_id", StringType(), True),
        StructField("adsorption_energy", DoubleType(), True),
        StructField("adsorption_energy_unit", StringType(), True),
        StructField("adsorption_energy_per_atom", BooleanType(), True),
        StructField("adsorption_energy_reference", DoubleType(), True),
        StructField("adsorption_energy_reference_unit", StringType(), True),
        StructField("adsorption_energy_property_id", StringType(), True),
        StructField("atomization_energy", DoubleType(), True),
        StructField("atomization_energy_unit", StringType(), True),
        StructField("atomization_energy_per_atom", BooleanType(), True),
        StructField("atomization_energy_reference", DoubleType(), True),
        StructField("atomization_energy_reference_unit", StringType(), True),
        StructField("atomization_energy_property_id", StringType(), True),
    ]
)

In [6]:
def stringify_lists(row_dict):
    """
    Replace container-valued fields with their string representation, in place.

    Values that are numpy arrays, lists, tuples or dicts are replaced by
    ``str()`` of an equivalent built-in container. numpy arrays and scalars,
    including ones nested inside lists/tuples/dicts, are converted to plain
    Python values first. ``str(ndarray)`` would omit commas, round floats to
    the print precision and elide large arrays with "...", silently losing data.

    Spark and Vast both support array columns, but the connector does not,
    so keeping cell values in list format crashes the table.
    TODO: Remove when no longer necessary

    Returns:
        The same ``row_dict`` object, mutated in place.
    """
    for key, val in row_dict.items():
        # Reassigning existing keys while iterating is safe: the dict size is unchanged.
        if isinstance(val, (np.ndarray, list, tuple, dict)):
            row_dict[key] = str(_to_builtin(val))
    return row_dict


def _to_builtin(val):
    """Recursively convert numpy arrays/scalars inside containers to Python builtins."""
    if isinstance(val, np.ndarray):
        return val.tolist()
    if isinstance(val, np.generic):
        return val.item()
    if isinstance(val, dict):
        return {k: _to_builtin(v) for k, v in val.items()}
    if isinstance(val, list):
        return [_to_builtin(v) for v in val]
    if isinstance(val, tuple):
        return tuple(_to_builtin(v) for v in val)
    return val

In [7]:
def _empty_dict_from_schema(schema):
    empty_dict = {}
    for field in schema:
        empty_dict[field.name] = None
    return empty_dict

In [8]:
# SHORT_ID_STRING_NAME: key holding the short "CO_<hash>" id in update documents.
# ATOMS_NAME_FIELD: key in a configuration's `info` read as its name(s).
# NOTE(review): reader() writes info["names"]; confirm AtomicConfiguration also
# exposes it under "name".
SHORT_ID_STRING_NAME, ATOMS_NAME_FIELD = "colabfit-id", "name"

In [9]:
# Metadata shared by every property instance. In md_from_map, "value" entries are
# used literally and "field" entries are looked up on each configuration.
PI_METADATA = dict(
    software={"value": "Quantum ESPRESSO"},
    method={"value": "DFT-PBE"},
    input={"field": "input"},
)

# Property-definition name -> how to fill each of its keys from config.info
# ("field") or a constant ("value"), with units.
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": dict(field="energy", units="eV"),
            "per-atom": dict(value=False, units=None),
        }
    ],
    "atomic-forces": [
        {"forces": dict(field="forces", units="eV/angstrom")},
    ],
    "cauchy-stress": [
        {
            "stress": dict(field="stress", units="GPa"),
            "volume-normalized": dict(value=True, units=None),
        }
    ],
    "_metadata": PI_METADATA,
}

In [10]:
# Build a Property for the first parsed configuration from the energy, forces and
# stress definitions, using PROPERTY_MAP to choose the config.info fields and units.
prop = Property.from_definition(
    [potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    data[0],
    property_map=PROPERTY_MAP,
)

In [11]:
# Inspect the property built above: its instance dict, Hill formula and Spark-row form.
instance = prop.instance
pprint(instance)
pprint(prop.chemical_formula_hill)
pprint(prop.to_spark_row())

{'atomic-forces': {'forces': {'source-unit': 'eV/angstrom',
                              'source-value': [[-0.000211, -2.8e-05, 0.000336],
                                               [-9.1e-05, 2.8e-05, -0.000302],
                                               [9.1e-05, 2.8e-05, 0.000304],
                                               [0.000211,
                                                -2.8e-05,
                                                -0.000339]]},
                   'instance-id': 1,
                   'property-id': 'tag:staff@noreply.colabfit.org,2022-05-30:property/atomic-forces'},
 'cauchy-stress': {'instance-id': 1,
                   'property-id': 'tag:staff@noreply.colabfit.org,2022-05-30:property/cauchy-stress',
                   'stress': {'source-unit': 'GPa',
                              'source-value': [[-0.02752, 0.0, -0.99821],
                                               [0.0, 0.01635, -0.0],
                                               [-0.9

In [12]:
def md_from_map(pmap_md, config: "AtomicConfiguration"):
    """Collect property-instance metadata for ``config`` as a JSON string.

    Each entry of ``pmap_md`` maps a metadata key to a spec dict:

    * ``{"value": v}`` -- use the literal ``v``;
    * ``{"field": name}`` -- read ``config.info[name]``, falling back to
      ``config.arrays[name]``; the key is skipped if neither contains it.

    Specs with neither ``"value"`` nor ``"field"`` are skipped (no metadata key
    is required). An optional ``"units"`` entry is recorded as
    ``"source-unit"`` next to ``"source-value"``.

    numpy arrays and numpy scalars (e.g. values taken from ``config.arrays``)
    are converted to plain Python values so they can be serialised.

    Returns:
        str: JSON of ``{key: {"source-value": ..., ["source-unit": ...]}}``.

    Raises:
        TypeError: if a value is neither JSON-serialisable nor a numpy type.
    """
    gathered_fields = {}
    for md_field, spec in pmap_md.items():
        if "value" in spec:
            v = spec["value"]
        elif "field" in spec:
            field_key = spec["field"]
            if field_key in config.info:
                v = config.info[field_key]
            elif field_key in config.arrays:
                v = config.arrays[field_key]
            else:
                # No keys are required; ignored if missing
                continue
        else:
            # No keys are required; ignored if missing
            continue

        entry = {"source-value": v}
        if "units" in spec:
            entry["source-unit"] = spec["units"]
        gathered_fields[md_field] = entry
    return json.dumps(gathered_fields, default=_json_default)


def _json_default(obj):
    """``json.dumps`` fallback that turns numpy arrays/scalars into builtins."""
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

In [13]:
# in-progress config parser
def co_to_spark_row(config: AtomicConfiguration):
    """Build a row dict keyed like ``config_schema`` for one configuration.

    Columns not produced here (e.g. ``dataset_ids``) stay ``None``. Array-like
    values (positions, pbc, atomic numbers and the summary lists) are left as
    Python/numpy containers; map the result through ``stringify_lists`` before
    handing it to Spark.
    """
    co_dict = _empty_dict_from_schema(config_schema)
    # config_schema declares "hash" as a StringType column.
    co_dict["hash"] = str(config._hash)
    co_dict["id"] = f"CO_{config._hash}"
    co_dict.update(config.configuration_summary())
    co_dict["cell"] = config.cell.tolist()
    co_dict["positions"] = config.positions
    co_dict["names"] = config.info[ATOMS_NAME_FIELD]
    co_dict["pbc"] = config.pbc
    # "last_modified" is a TimestampType column, so store a timezone-aware datetime
    # (whole seconds, like the former "%Y-%m-%dT%H:%M:%SZ" string) rather than text.
    co_dict["last_modified"] = datetime.datetime.now(
        tz=datetime.timezone.utc
    ).replace(microsecond=0)
    co_dict["atomic_numbers"] = config.numbers
    return co_dict

In [17]:
data[0].info[ATOMS_NAME_FIELD]  # name stored on the first parsed configuration

'Unified_training_set_Si_0'

In [14]:
# AtomicConfiguration's own Spark-row conversion, for comparison with the
# hand-written row builders in this notebook.
data[0].to_spark_row()

{'id': 'CO_47706510123393079',
 'hash': 47706510123393079,
 'last_modified': '2024-04-11T18:31:10Z',
 'dataset_ids': None,
 'chemical_formula_hill': 'Si4',
 'chemical_formula_reduced': 'Si',
 'chemical_formula_anonymous': 'A',
 'elements': ['Si'],
 'elements_ratios': [1.0],
 'atomic_numbers': array([14, 14, 14, 14]),
 'nsites': 4,
 'nelements': 1,
 'nperiodic_dimensions': 0,
 'cell': [[3.85085, 0.0, 0.077017],
  [-1.925425, 3.334933, -0.038508],
  [0.127258, 0.0, 6.362934]],
 'dimension_types': [0, 0, 0],
 'pbc': array([False, False, False]),
 'positions': array([[ 1.892001,  1.11132 ,  0.400465],
        [ 1.955509, -1.11132 ,  3.581973],
        [ 1.895339, -1.11132 , -0.400508],
        [ 1.958847,  1.11132 ,  2.781   ]]),
 'names': set()}

In [16]:
# Gather the shared property-instance metadata for the first configuration.
# md_from_map returns a JSON string, so decode it back to a dict for display.
md = json.loads(md_from_map(PROPERTY_MAP["_metadata"], data[0]))
pprint(md)

{'input': {'source-value': {'kinetic-energy-cutoff': {'units': 'eV',
                                                      'val': 884},
                            'kpoint-scheme': 'Monkhorst-Pack',
                            'kpoints': '8x8x8'}},
 'method': {'source-value': 'DFT-PBE'},
 'software': {'source-value': 'Quantum ESPRESSO'}}


In [18]:
# JSON text form of the metadata (property_object_schema stores "metadata" as a string).
json.dumps(md)

'{"software": {"source-value": "Quantum ESPRESSO"}, "method": {"source-value": "DFT-PBE"}, "input": {"source-value": {"kpoint-scheme": "Monkhorst-Pack", "kpoints": "8x8x8", "kinetic-energy-cutoff": {"val": 884, "units": "eV"}}}}'

In [None]:
def _parse_cf_config(config: AtomicConfiguration):
    """Convert one AtomicConfiguration into a row dict keyed like ``config_schema``.

    Array-valued columns (elements, ratios, atomic numbers, cell, pbc,
    positions, ...) are returned as lists / numpy arrays. Those columns are
    declared as StringType, so map the result through ``stringify_lists``
    before creating a Spark DataFrame. ``dataset_ids`` is left as ``None``.
    """
    summary = config.configuration_summary()
    co = _empty_dict_from_schema(config_schema)
    co["hash"] = str(hash(config))  # String
    co["id"] = f"CO_{co['hash']}"  # String
    # Timezone-aware UTC timestamp truncated to whole seconds.
    co["last_modified"] = datetime.datetime.now(tz=datetime.timezone.utc).replace(
        microsecond=0
    )  # timestamp

    co["nsites"] = summary["nsites"]  # int
    co["elements"] = summary["elements"]  # Array[string]
    co["nelements"] = summary["nelements"]  # int
    co["elements_ratios"] = summary["elements_ratios"]  # Array[float]
    co["chemical_formula_anonymous"] = summary["chemical_formula_anonymous"]  # String
    co["chemical_formula_reduced"] = summary["chemical_formula_reduced"]  # String
    co["chemical_formula_hill"] = summary["chemical_formula_hill"]  # String
    co["dimension_types"] = summary["dimension_types"]  # Array
    co["nperiodic_dimensions"] = summary["nperiodic_dimensions"]  # int

    co["atomic_numbers"] = config.arrays["numbers"]  # Array[int]
    # config.cell is presumably an ase Cell, which is neither an ndarray nor a
    # list, so stringify_lists would leave it untouched; use a plain 3x3 array.
    co["cell"] = np.asarray(config.cell)  # Array 3x3
    co["pbc"] = config.pbc  # Array len 3
    co["names"] = config.info[ATOMS_NAME_FIELD]  # Array

    # TODO: handle oversize positions lists
    co["positions"] = config.arrays["positions"]  # Array nsitesx3
    # TODO: populate co["dataset_ids"] (Array[string]) from the configuration's
    # dataset relationships.

    return co

In [None]:
# Exploratory: inspect the container type of the atomic numbers before stringification.
type(data[0].arrays["numbers"])

In [None]:
# Distribute the parsed configurations as a Spark RDD.
cos_parallel = spark.sparkContext.parallelize(data)

In [None]:
# Convert each configuration to a config_schema row dict, then stringify the
# array-valued columns for the JDBC connector.
configs_to_add = cos_parallel.map(_parse_cf_config).map(stringify_lists)

In [None]:
# Spot-check the first converted row (forces evaluation of the lazy RDD pipeline).
configs_to_add.take(1)

In [None]:
"""
Can we create the configuration and the property instance (PI) / data object (DO) at the same time?
That way, we would only need a single pass through the data.

Proposed workflow:
1. Create a database access object.
2. Create the data reader as a function (or method?) of the database access object.
3. The reader returns ase.Atoms-style objects (AtomicConfiguration).
4. DOs and PIs are now a single object.
5. Each DO points to a configuration.
6. The configuration may already exist in the database, so we record the configuration's hash on the DO.

"""

In [None]:
def _build_co_update_doc(configuration):
    """Build a configuration document and its hash for an upsert.

    Starts from ``configuration.todict()`` and adds: ``"hash"``; the short id
    ``"CO_<hash>"`` under ``SHORT_ID_STRING_NAME``; a UTC ``"last_modified"``
    string; ``"names"`` as a list; the stringified ``unique_identifiers``; and
    the ``configuration_summary()`` fields (nsites, elements, nelements,
    elements_ratios, chemical formulas, dimension_types, nperiodic_dimensions).

    ``configuration.info`` is read, not modified.

    Returns:
        tuple: ``(config_dict, co_hash)`` where ``co_hash`` is ``str(hash(configuration))``.
    """
    processed_fields = configuration.configuration_summary()
    co_hash = str(hash(configuration))
    config_dict = configuration.todict()
    config_dict.update(
        {
            "hash": co_hash,
            SHORT_ID_STRING_NAME: "CO_" + co_hash,
            "last_modified": datetime.datetime.now(tz=datetime.timezone.utc).strftime(
                "%Y-%m-%dT%H:%M:%SZ"
            ),
        }
    )
    # Read the name by key (a list is not a valid dict key) and without popping,
    # so the caller's configuration is left unchanged.
    names = configuration.info.get(ATOMS_NAME_FIELD, [])
    if isinstance(names, str):
        names = [names]
    config_dict["names"] = list(names)
    config_dict.update(
        {k: str(v.tolist()) for k, v in configuration.unique_identifiers.items()}
    )
    config_dict.update(processed_fields)
    return config_dict, co_hash

In [None]:
# Smoke test: build the update document for the first configuration and stringify it.
stringify_lists(_build_co_update_doc(data[0])[0])

In [None]:
# Load the sample configuration records. read_text() opens and closes the file
# itself, so no handle is leaked; JSON text is UTF-8.
cos = json.loads(Path("sample_db/co_ds1.json").read_text(encoding="utf-8"))

In [None]:
# Distribute the raw JSON configuration records as an RDD.
# NOTE(review): re-reads the same file already loaded into `cos` above.
with open(Path("sample_db/co_ds1.json"), "r") as f:
    co_json = spark.sparkContext.parallelize(json.load(f))

In [None]:
# Parse the JSON records into config_schema rows and build a DataFrame.
# NOTE(review): `_parse_config` is not defined anywhere in this notebook (only
# `_parse_cf_config`, which expects AtomicConfiguration objects rather than JSON
# dicts), so this cell raises NameError on a fresh kernel. Restore the JSON-record
# parser or drop this cell.
co = co_json.map(_parse_config).map(stringify_lists)
co_df = spark.createDataFrame(co, config_schema)

In [None]:
def parse_configs(co_path, spark):
    """Load configuration records from a JSON file into a ``config_schema`` DataFrame.

    Args:
        co_path: path to a JSON file, presumably holding a list of configuration records.
        spark: active ``SparkSession`` used to distribute the records.

    Returns:
        A Spark DataFrame with schema ``config_schema``.

    NOTE(review): relies on ``_parse_config``, which is not defined in this
    notebook, so calling this raises NameError on a fresh kernel.
    """
    with open(co_path, "r") as handle:
        records = json.load(handle)
    rows = (
        spark.sparkContext.parallelize(records)
        .map(_parse_config)
        .map(stringify_lists)
    )
    return spark.createDataFrame(rows, config_schema)

In [None]:
# Preview the sample configurations as a DataFrame.
parse_configs("sample_db/co_ds1.json", spark).show()

In [None]:
# Append the configuration rows to the Postgres "co" table over JDBC.
# Credentials (`user`, `password`) come from the environment via the setup cell.
url = "jdbc:postgresql://localhost:5432/colabfit"
properties = dict(user=user, password=password, driver="org.postgresql.Driver")
table_name, mode = "co", "append"
co_df.write.jdbc(url=url, table=table_name, mode=mode, properties=properties)

In [None]:
# co_df.write.