# Imports

In [16]:
import datetime
import json
import os
from collections import defaultdict
import pickle
from tqdm import tqdm

# from functools import partial
# from itertools import chain, islice
# from multiprocessing import Pool, cpu_count
from pathlib import Path

# from pprint import pprint

import dateutil.parser
import findspark
import lmdb
import numpy as np
import psycopg
import pyspark.sql.functions as sf
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

from colabfit.tools.configuration import AtomicConfiguration, config_schema
from colabfit.tools.database import DataManager, PGDataLoader
from colabfit.tools.dataset import Dataset, dataset_schema
from colabfit.tools.property import Property, property_object_schema
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)
from colabfit.tools.schema import configuration_set_schema
import pyarrow as pa

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)
findspark.init()
format = "jdbc"
load_dotenv("./.env")

True

# Set up MTPU and Carolina Materials readers and data

In [17]:
# MTPU data


def convert_stress(keys, stress):
    stresses = {k: s for k, s in zip(keys, stress)}
    return [
        [stresses["xx"], stresses["xy"], stresses["xz"]],
        [stresses["xy"], stresses["yy"], stresses["yz"]],
        [stresses["xz"], stresses["yz"], stresses["zz"]],
    ]


SYMBOL_DICT = {"0": "Si", "1": "O"}


def mtpu_reader(filepath):
    with open(filepath, "rt") as f:
        energy = None
        forces = None
        coords = []
        cell = []
        symbols = []
        config_count = 0
        info = dict()
        for line in f:
            if line.strip().startswith("Size"):
                size = int(f.readline().strip())
            elif line.strip().lower().startswith("supercell"):
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
            elif line.strip().startswith("Energy"):
                energy = float(f.readline().strip())
            elif line.strip().startswith("PlusStress"):
                stress_keys = line.strip().split()[-6:]
                stress = [float(x) for x in f.readline().strip().split()]
                stress = convert_stress(stress_keys, stress)
            elif line.strip().startswith("AtomData:"):
                keys = line.strip().split()[1:]
                if "fx" in keys:
                    forces = []
                for i in range(size):
                    li = {
                        key: val for key, val in zip(keys, f.readline().strip().split())
                    }
                    symbols.append(SYMBOL_DICT[li["type"]])
                    if "cartes_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["cartes_x"],
                                    li["cartes_y"],
                                    li["cartes_z"],
                                ]
                            ]
                        )
                    elif "direct_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["direct_x"],
                                    li["direct_y"],
                                    li["direct_z"],
                                ]
                            ]
                        )

                    if "fx" in keys:
                        forces.append(
                            [float(f) for f in [li["fx"], li["fy"], li["fz"]]]
                        )

            elif line.startswith("END_CFG"):

                info["energy"] = energy
                if forces:
                    info["forces"] = forces
                info["stress"] = stress

                if "Si" in symbols and "O" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "11x11x11",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_SiO2_{config_count}"
                elif "Si" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "8x8x8",
                        "kinetic-energy-cutoff": {
                            "val": 884,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_Si_{config_count}"
                elif "O" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "gamma-point",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_O_{config_count}"
                if "cartes_x" in keys:
                    config = AtomicConfiguration(
                        positions=coords, symbols=symbols, cell=cell, info=info
                    )
                elif "direct_x" in keys:
                    config = AtomicConfiguration(
                        scaled_positions=coords, symbols=symbols, cell=cell, info=info
                    )
                config_count += 1
                yield config
                forces = None
                stress = []
                coords = []
                cell = []
                symbols = []
                energy = None

In [18]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
data = list(mtpu_configs)
# data = [x for x in mtpu_configs]
# data[0].configuration_summary()

In [19]:
import colabfit.tools.configuration
from importlib import reload

reload(colabfit.tools.configuration)
AtomicConfiguration = colabfit.tools.configuration.AtomicConfiguration

In [20]:
print(data[0])

AtomicConfiguration(name=Unified_training_set_SiO2_1061, symbols='Si4', pbc=False, cell=[[3.85085, 0.0, 0.077017], [-1.925425, 3.334933, -0.038508], [0.127258, 0.0, 6.362934]])


In [21]:
loader.spark.serialize()

NameError: name 'loader' is not defined

In [None]:
data[0].spark_row

{'id': 'CO_47706510123393079',
 'hash': 47706510123393079,
 'last_modified': datetime.datetime(2024, 5, 20, 17, 14, 38),
 'dataset_ids': None,
 'metadata': None,
 'chemical_formula_hill': 'Si4',
 'chemical_formula_reduced': 'Si',
 'chemical_formula_anonymous': 'A',
 'elements': "['Si']",
 'elements_ratios': '[1.0]',
 'atomic_numbers': '[14, 14, 14, 14]',
 'nsites': 4,
 'nelements': 1,
 'nperiodic_dimensions': 0,
 'cell': '[[3.85085, 0.0, 0.077017], [-1.925425, 3.334933, -0.038508], [0.127258, 0.0, 6.362934]]',
 'dimension_types': '[0, 0, 0]',
 'pbc': '[False, False, False]',
 'positions': '[[1.892001, 1.11132, 0.400465], [1.955509, -1.11132, 3.581973], [1.895339, -1.11132, -0.400508], [1.958847, 1.11132, 2.781]]',
 'names': "['Unified_training_set_Si_0']",
 'labels': None,
 'configuration_set_ids': None}

In [None]:
# Carolina Materials data

SOFTWARE = "VASP"
METHODS = "DFT-PBE"
CM_PI_METADATA = {
    "software": {"value": SOFTWARE},
    "method": {"value": METHODS},
    "input": {"value": {"IBRION": 6, "NFREE": 4}},
}

CM_PROPERTY_MAP = {
    "formation-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
        }
    ],
    "_metadata": CM_PI_METADATA,
}
CO_MD = {
    key: {"field": key}
    for key in [
        "_symmetry_space_group_name_H-M",
        "_symmetry_Int_Tables_number",
        "_chemical_formula_structural",
        "_chemical_formula_sum",
        "_cell_volume",
        "_cell_formula_units_Z",
        "symmetry_dict",
        "formula_pretty",
    ]
}


def load_row(txn, row):
    try:
        data = pickle.loads(txn.get(f"{row}".encode("ascii")))
        return data
    except TypeError:
        return False


def config_from_row(row: dict, row_num: int):
    coords = row.pop("cart_coords")
    a_num = row.pop("atomic_numbers")
    cell = [
        row.pop(x)
        for x in [
            "_cell_length_a",
            "_cell_length_b",
            "_cell_length_c",
            "_cell_angle_alpha",
            "_cell_angle_beta",
            "_cell_angle_gamma",
        ]
    ]
    symmetry_dict = {str(key): val for key, val in row.pop("symmetry_dict").items()}
    for key in symmetry_dict:
        key = str(key)
    info = {}
    info = row
    info["symmetry_dict"] = symmetry_dict
    info["_name"] = f"carolina_materials_{row_num}"
    if row_num % 10 == 0:
        info["_labels"] = [row_num % 10, "bcc"]
    else:
        info["_labels"] = [row_num % 10, "fcc"]
    config = AtomicConfiguration(
        scaled_positions=coords,
        numbers=a_num,
        cell=cell,
        info=info,
    )
    return config
    # return AtomicConfiguration.from_ase(config)


def carmat_reader(fp: Path):
    parent = fp.parent
    env = lmdb.open(str(parent))
    txn = env.begin()
    row_num = 0
    rows = []
    while row_num <= 10000:
        row = load_row(txn, row_num)
        if row is False:
            env.close()
            break
        rows.append(row)
        yield config_from_row(row, row_num)
        row_num += 1
    env.close()
    return False
    # return rows

In [None]:
PI_METADATA = {
    "software": {"value": "Quantum ESPRESSO"},
    "method": {"value": "DFT-PBE"},
    "input": {"field": "input"},
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            # "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            # "_metadata": PI_METADATA,
        },
    ],
    "cauchy-stress": [
        {
            "stress": {"field": "stress", "units": "GPa"},
            "volume-normalized": {"value": True, "units": None},
        }
    ],
    "_metadata": PI_METADATA,
}

# Connect to DB and run loader

In [23]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/21 14:49:53 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/21 14:49:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/21 14:49:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/21 14:49:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))

PI_METADATA = {
    "software": {"value": "Quantum ESPRESSO"},
    "method": {"value": "DFT-PBE"},
    "input": {"field": "input"},
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            # "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            # "_metadata": PI_METADATA,
        },
    ],
    "cauchy-stress": [
        {
            "stress": {"field": "stress", "units": "GPa"},
            "volume-normalized": {"value": True, "units": None},
        }
    ],
    "_metadata": PI_METADATA,
}
spark = SparkSession.builder.appName("ColabfitIngestData").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
# print(loader.spark)
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = list(mtpu_configs)
print(mtpu_configs[0])
co_po_rows = []
for config in tqdm(mtpu_configs):
    config.set_dataset_id(mtpu_ds_id)
    co_po_rows.append(
        (
            config.spark_row,
            Property.from_definition(
                [potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
                configuration=config,
                property_map=PROPERTY_MAP,
            ).spark_row,
        )
    )

24/05/20 17:14:41 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


AtomicConfiguration(name=Unified_training_set_SiO2_1061, symbols='Si4', pbc=False, cell=[[3.85085, 0.0, 0.077017], [-1.925425, 3.334933, -0.038508], [0.127258, 0.0, 6.362934]])


100%|██████████| 1062/1062 [00:00<00:00, 2495.36it/s]


In [None]:
co_po_rows[0][0]

{'id': 'CO_47706510123393079',
 'hash': 47706510123393079,
 'last_modified': datetime.datetime(2024, 5, 20, 17, 14, 41),
 'dataset_ids': "['DS_y7nrdsjtuwom_0']",
 'metadata': None,
 'chemical_formula_hill': 'Si4',
 'chemical_formula_reduced': 'Si',
 'chemical_formula_anonymous': 'A',
 'elements': "['Si']",
 'elements_ratios': '[1.0]',
 'atomic_numbers': '[14, 14, 14, 14]',
 'nsites': 4,
 'nelements': 1,
 'nperiodic_dimensions': 0,
 'cell': '[[3.85085, 0.0, 0.077017], [-1.925425, 3.334933, -0.038508], [0.127258, 0.0, 6.362934]]',
 'dimension_types': '[0, 0, 0]',
 'pbc': '[False, False, False]',
 'positions': '[[1.892001, 1.11132, 0.400465], [1.955509, -1.11132, 3.581973], [1.895339, -1.11132, -0.400508], [1.958847, 1.11132, 2.781]]',
 'names': "['Unified_training_set_Si_0']",
 'labels': None,
 'configuration_set_ids': None}

In [None]:
import datetime

In [None]:
dateutil.parser.parse(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

datetime.datetime(2024, 5, 20, 17, 10, 24)

In [None]:
try:
    print("making co po rows")
    co_rows, po_rows = list(zip(*co_po_rows))
    print("making cos dataframes...")
    cos_dataframe = spark.createDataFrame(co_rows, schema=config_schema)
    print("Done!")
    print(cos_dataframe.show(1, False))
    print("making pos dataframes...")
    pos_dataframe = spark.createDataFrame(po_rows, schema=property_object_schema)
    print("Done!")
    pos_dataframe.show(1, False)
    try:
        # loader.write_table(
        #     co_rows,
        #     loader.config_table,
        #     config_schema,
        # )
        # loader.write_table(
        #     po_rows,
        #     loader.prop_object_table,
        #     property_object_schema,
        # )
        print(loader.config_table)
        cos_dataframe.write.mode("append").saveAsTable(loader.config_table)
        print(loader.prop_object_table)
        pos_dataframe.write.mode("append").saveAsTable(loader.prop_object_table)
    except:
        print("loader write failed")
except:
    print("error getting df")

making co po rows
making cos dataframes...
Done!


                                                                                

+--------------------+-----------------+-------------------+---------------------+--------+---------------------+------------------------+--------------------------+--------+---------------+----------------+------+---------+--------------------+---------------------------------------------------------------------------------------+---------------+---------------------+----------------------------------------------------------------------------------------------------------------------------+-----------------------------+------+---------------------+
|id                  |hash             |last_modified      |dataset_ids          |metadata|chemical_formula_hill|chemical_formula_reduced|chemical_formula_anonymous|elements|elements_ratios|atomic_numbers  |nsites|nelements|nperiodic_dimensions|cell                                                                                   |dimension_types|pbc                  |positions                                                                

24/05/20 14:52:03 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+----------------------+-------------------+-------------------+------------------------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------+----------------+-------+---------------------+------------------+---------------------+-------------------------+--------------------------+-------------------------------+-------------------------------------------------------------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

24/05/20 14:52:04 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/05/20 14:52:04 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
24/05/20 14:52:04 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
24/05/20 14:52:04 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 69.09% for 11 writers
24/05/20 14:52:04 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 63.33% for 12 writers
24/05/20 14:52:04 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 58.46% for 13 writers
24/05/20 14:52:04 WARN MemoryManager: Total allocation exceeds 95.

gpw_test_propobjects


24/05/20 14:52:05 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 95.00% for 8 writers
24/05/20 14:52:05 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 84.44% for 9 writers
24/05/20 14:52:05 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 76.00% for 10 writers
24/05/20 14:52:05 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 69.09% for 11 writers
24/05/20 14:52:05 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 63.33% for 12 writers
24/05/20 14:52:05 WARN MemoryManager: Total allocation exceeds 95.00% (1,020,054,720 bytes) of heap memory
Scaling row group sizes to 58.46% for 13 writers
24/05/20 14:52:05 WARN MemoryManager: Total allocation exceeds 95.

In [24]:
config_df = loader.spark.read.jdbc(
    url=url, table="configurations", properties=properties
)

In [101]:
def append_elem(col_array, elem):
    print(col_array)
    unstrung = eval(col_array)
    unstrung.append(elem)
    unstrung = list(set(unstrung))
    return str(unstrung)

In [25]:
config_df.show(1, False)

+----------------------+-------------------+-------------------+---------------------+--------+---------------------+------------------------+--------------------------+----------------+-----------------+----------------------------+------+---------+--------------------+-------------------------------------------------------------------------------------------------+---------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------+----------+---------------------+
|id                    |hash               |last_modified      |dataset_ids          |metadata|chemical_formula_hill|chemical

In [None]:
# NEXT: find a way to check wehther cs-ids is null and handle the column =+ cs_id as array
# OR find a way to use the lambda function to use the user-defined function (append element)

In [118]:
labels = ["fcc", 6]
config_set_id = "test_config_set_id"
config_df.withColumn("filter_labels", sf.lit(labels)).withColumn(
    "labels_unstrung", sf.from_json(sf.col("labels"), ArrayType(StringType()))
).withColumn(
    "has_labels",
    sf.forall(
        "filter_labels",
        lambda x: sf.array_contains(col=sf.col("labels_unstrung"), value=x),
    ),
).withColumn(
    "new_cs_id", sf.lit([config_set_id])
).withColumn(
    "configuration_set_ids",
    sf.when(
        condition=sf.col("has_labels") == True,
        value=sf.array_union(config_df["configuration_set_ids"], sf.col("dataset_ids")),
    ),
    # .otherwise(config_df["configuration_set_ids"]),
)
# .withColumn(
#     "configuration_set_ids",
#     sf.transform_values(
#         "configuration_set_ids", lambda k, cs_ids: append_elem(cs_ids, config_set_id)
#     ),
# )
# .show(
#     10, False
# )

AnalysisException: [DATATYPE_MISMATCH.BINARY_ARRAY_DIFF_TYPES] Cannot resolve "array_union(configuration_set_ids, dataset_ids)" due to data type mismatch: Input to function `array_union` should have been two "ARRAY" with same element type, but it's ["STRING", "STRING"].;
'Project [id#42, hash#43, last_modified#44, dataset_ids#45, metadata#46, chemical_formula_hill#47, chemical_formula_reduced#48, chemical_formula_anonymous#49, elements#50, elements_ratios#51, atomic_numbers#52, nsites#53, nelements#54, nperiodic_dimensions#55, cell#56, dimension_types#57, pbc#58, positions#59, names#60, labels#61, CASE WHEN (has_labels#2158 = true) THEN array_union(configuration_set_ids#62, dataset_ids#45) END AS configuration_set_ids#2210, filter_labels#2111, labels_unstrung#2134, has_labels#2158, new_cs_id#2184]
+- Project [id#42, hash#43, last_modified#44, dataset_ids#45, metadata#46, chemical_formula_hill#47, chemical_formula_reduced#48, chemical_formula_anonymous#49, elements#50, elements_ratios#51, atomic_numbers#52, nsites#53, nelements#54, nperiodic_dimensions#55, cell#56, dimension_types#57, pbc#58, positions#59, names#60, labels#61, configuration_set_ids#62, filter_labels#2111, labels_unstrung#2134, has_labels#2158, test_config_set_id AS new_cs_id#2184]
   +- Project [id#42, hash#43, last_modified#44, dataset_ids#45, metadata#46, chemical_formula_hill#47, chemical_formula_reduced#48, chemical_formula_anonymous#49, elements#50, elements_ratios#51, atomic_numbers#52, nsites#53, nelements#54, nperiodic_dimensions#55, cell#56, dimension_types#57, pbc#58, positions#59, names#60, labels#61, configuration_set_ids#62, filter_labels#2111, labels_unstrung#2134, forall(filter_labels#2111, lambdafunction(array_contains(labels_unstrung#2134, lambda x_28#2159), lambda x_28#2159, false)) AS has_labels#2158]
      +- Project [id#42, hash#43, last_modified#44, dataset_ids#45, metadata#46, chemical_formula_hill#47, chemical_formula_reduced#48, chemical_formula_anonymous#49, elements#50, elements_ratios#51, atomic_numbers#52, nsites#53, nelements#54, nperiodic_dimensions#55, cell#56, dimension_types#57, pbc#58, positions#59, names#60, labels#61, configuration_set_ids#62, filter_labels#2111, from_json(ArrayType(StringType,true), labels#61, Some(America/New_York)) AS labels_unstrung#2134]
         +- Project [id#42, hash#43, last_modified#44, dataset_ids#45, metadata#46, chemical_formula_hill#47, chemical_formula_reduced#48, chemical_formula_anonymous#49, elements#50, elements_ratios#51, atomic_numbers#52, nsites#53, nelements#54, nperiodic_dimensions#55, cell#56, dimension_types#57, pbc#58, positions#59, names#60, labels#61, configuration_set_ids#62, array(fcc, cast(6 as string)) AS filter_labels#2111]
            +- Relation [id#42,hash#43,last_modified#44,dataset_ids#45,metadata#46,chemical_formula_hill#47,chemical_formula_reduced#48,chemical_formula_anonymous#49,elements#50,elements_ratios#51,atomic_numbers#52,nsites#53,nelements#54,nperiodic_dimensions#55,cell#56,dimension_types#57,pbc#58,positions#59,names#60,labels#61,configuration_set_ids#62] JDBCRelation(configurations) [numPartitions=1]


In [10]:
first = config_df.first()

In [83]:
first

Row(id='CO_1000272601413209564', hash='1000272601413209564', last_modified=datetime.datetime(2024, 5, 3, 15, 10, 43), dataset_ids="['DS_y7nrdsjtuw0g_0']", metadata=None, chemical_formula_hill='In2N4Y2', chemical_formula_reduced='InN2Y', chemical_formula_anonymous='A2BC', elements="['In', 'N', 'Y']", elements_ratios='[0.25, 0.5, 0.25]', atomic_numbers='[39, 39, 49, 49, 7, 7, 7, 7]', nsites=8, nelements=3, nperiodic_dimensions=0, cell='[[3.42184625, 0.0, 0.0], [-1.7109231249999994, 2.9634057803445177, 0.0], [0.0, 0.0, 11.13673602]]', dimension_types='[0, 0, 0]', pbc='[False, False, False]', positions='[[0.0, 0.0, 0.0], [0.0, 0.0, 5.56836801], [-1.710923098524749e-08, 1.9756038634410311, 2.784184005], [1.7109231421092315, 0.9878019169034866, 8.352552015], [-1.710923098524749e-08, 1.9756038634410311, 6.919232818060203], [1.7109231421092315, 0.9878019169034866, 4.217503201939798], [1.7109231421092315, 0.9878019169034866, 1.350864808060202], [-1.710923098524749e-08, 1.9756038634410311, 9.785

In [26]:
config_rdd = config_df.rdd

In [31]:
first_rdd = config_rdd.take(1)

In [36]:
for x in first_rdd[0]:
    print(x)

CO_1000272601413209564
1000272601413209564
2024-05-03 15:10:43
['DS_y7nrdsjtuw0g_0']
None
In2N4Y2
InN2Y
A2BC
['In', 'N', 'Y']
[0.25, 0.5, 0.25]
[39, 39, 49, 49, 7, 7, 7, 7]
8
3
0
[[3.42184625, 0.0, 0.0], [-1.7109231249999994, 2.9634057803445177, 0.0], [0.0, 0.0, 11.13673602]]
[0, 0, 0]
[False, False, False]
[[0.0, 0.0, 0.0], [0.0, 0.0, 5.56836801], [-1.710923098524749e-08, 1.9756038634410311, 2.784184005], [1.7109231421092315, 0.9878019169034866, 8.352552015], [-1.710923098524749e-08, 1.9756038634410311, 6.919232818060203], [1.7109231421092315, 0.9878019169034866, 4.217503201939798], [1.7109231421092315, 0.9878019169034866, 1.350864808060202], [-1.710923098524749e-08, 1.9756038634410311, 9.7858712119398]]
['carolina_materials_4263']
[3, 'fcc']
None


In [40]:
from functools import partial
from pyspark.sql import Row

In [42]:
Row(first_rdd[0].asDict())

<Row({'id': 'CO_1000272601413209564', 'hash': '1000272601413209564', 'last_modified': datetime.datetime(2024, 5, 3, 15, 10, 43), 'dataset_ids': "['DS_y7nrdsjtuw0g_0']", 'metadata': None, 'chemical_formula_hill': 'In2N4Y2', 'chemical_formula_reduced': 'InN2Y', 'chemical_formula_anonymous': 'A2BC', 'elements': "['In', 'N', 'Y']", 'elements_ratios': '[0.25, 0.5, 0.25]', 'atomic_numbers': '[39, 39, 49, 49, 7, 7, 7, 7]', 'nsites': 8, 'nelements': 3, 'nperiodic_dimensions': 0, 'cell': '[[3.42184625, 0.0, 0.0], [-1.7109231249999994, 2.9634057803445177, 0.0], [0.0, 0.0, 11.13673602]]', 'dimension_types': '[0, 0, 0]', 'pbc': '[False, False, False]', 'positions': '[[0.0, 0.0, 0.0], [0.0, 0.0, 5.56836801], [-1.710923098524749e-08, 1.9756038634410311, 2.784184005], [1.7109231421092315, 0.9878019169034866, 8.352552015], [-1.710923098524749e-08, 1.9756038634410311, 6.919232818060203], [1.7109231421092315, 0.9878019169034866, 4.217503201939798], [1.7109231421092315, 0.9878019169034866, 1.350864808060

In [50]:
config_rdd.map(unstringify).take(1)

[Row(id='CO_1000272601413209564', hash='1000272601413209564', last_modified=datetime.datetime(2024, 5, 3, 15, 10, 43), dataset_ids=['DS_y7nrdsjtuw0g_0'], metadata=None, chemical_formula_hill='In2N4Y2', chemical_formula_reduced='InN2Y', chemical_formula_anonymous='A2BC', elements=['In', 'N', 'Y'], elements_ratios=[0.25, 0.5, 0.25], atomic_numbers=[39, 39, 49, 49, 7, 7, 7, 7], nsites=8, nelements=3, nperiodic_dimensions=0, cell=[[3.42184625, 0.0, 0.0], [-1.7109231249999994, 2.9634057803445177, 0.0], [0.0, 0.0, 11.13673602]], dimension_types=[0, 0, 0], pbc=[False, False, False], positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 5.56836801], [-1.710923098524749e-08, 1.9756038634410311, 2.784184005], [1.7109231421092315, 0.9878019169034866, 8.352552015], [-1.710923098524749e-08, 1.9756038634410311, 6.919232818060203], [1.7109231421092315, 0.9878019169034866, 4.217503201939798], [1.7109231421092315, 0.9878019169034866, 1.350864808060202], [-1.710923098524749e-08, 1.9756038634410311, 9.7858712119398]], 

In [None]:
test_rdd.take(10)

In [1]:
config_rdd.take(10)

NameError: name 'config_rdd' is not defined

In [49]:
def unstringify(row):
    row_dict = row.asDict()
    for key, val in row_dict.items():
        if isinstance(val, str) and len(val) > 0 and val[0] in ["{", "["]:
            dval = eval(row[key])
            row_dict[key] = dval
    new_row = Row(**row_dict)
    return new_row

df.rdd  
you can only parallelize one time so don't try to do a dataframe select from an rdd  
updating to sdk 5.1 in a couple weeks  
boto3 and s3 are the amazon file system interactions, mostly for adding metadata TO FILES (not to the database) and interacting with the files as FileExistsError. 
Make sure to spark.stop() at end of  python file.

In [7]:
carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"
# carmat_ds_id2 = "duplicate_ds_id"
dm = DataManager(
    nprocs=4,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
# dm_dup = DataManager(
#     nprocs=4,
#     configs=carmat_config_gen,
#     prop_defs=[formation_energy_pd],
#     prop_map=CM_PROPERTY_MAP,
#     dataset_id=carmat_ds_id2,
# )

Dataset ID: DS_y7nrdsjtuw0g_0


In [8]:
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

Dataset ID: DS_y7nrdsjtuwom_0


In [12]:
batch = next(dm.gather_co_po_in_batches())

In [13]:
cos = [x[0] for x in batch]
cos_dataframe = loader.spark.createDataFrame(cos, schema=config_schema)

In [11]:
rows = loader.spark.createDataFrame(cos, schema=config_schema).collect()

                                                                                

In [13]:
row_dict = defaultdict(list)
for i, row in enumerate(rows):
    if i > 10:
        break
    for key, val in row.asDict().items():
        row_dict[key].append(val)
print(row_dict)

defaultdict(<class 'list'>, {'id': ['CO_2035392092515548233', 'CO_395357802651642160', 'CO_1353971176978508405', 'CO_673544646621950284', 'CO_1180491453360028556', 'CO_1235470615381239321', 'CO_1125763852340156480', 'CO_2185718778256801767', 'CO_1547613879869181226', 'CO_2262213674274439455', 'CO_2212061042677459331'], 'hash': ['2035392092515548233', '395357802651642160', '1353971176978508405', '673544646621950284', '1180491453360028556', '1235470615381239321', '1125763852340156480', '2185718778256801767', '1547613879869181226', '2262213674274439455', '2212061042677459331'], 'last_modified': [datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.datetime(2024, 5, 8, 15, 17, 24), datetime.date

In [30]:
import pyarrow as pa

config_schema = StructType(
    [
        StructField("id", StringType(), False),
        StructField("hash", StringType(), False),
        StructField("last_modified", TimestampType(), False),
        StructField("dataset_ids", StringType(), True),  # ArrayType(StringType())
        StructField("metadata", StringType(), True),
        StructField("chemical_formula_hill", StringType(), True),
        StructField("chemical_formula_reduced", StringType(), True),
        StructField("chemical_formula_anonymous", StringType(), True),
        StructField("elements", StringType(), True),  # ArrayType(StringType())
        StructField("elements_ratios", StringType(), True),  # ArrayType(IntegerType())
        StructField("atomic_numbers", StringType(), True),  # ArrayType(IntegerType())
        StructField("nsites", IntegerType(), True),
        StructField("nelements", IntegerType(), True),
        StructField("nperiodic_dimensions", IntegerType(), True),
        StructField("cell", StringType(), True),  # ArrayType(ArrayType(DoubleType()))
        StructField("dimension_types", StringType(), True),  # ArrayType(IntegerType())
        StructField("pbc", StringType(), True),  # ArrayType(IntegerType())
        StructField(
            "positions", StringType(), True
        ),  # ArrayType(ArrayType(DoubleType()))
        StructField("names", StringType(), True),  # ArrayType(StringType()),
        StructField("labels", StringType(), True),  # ArrayType(StringType())
        StructField(
            "configuration_set_ids", StringType(), True
        ),  # ArrayType(StringType())
    ]
)


def convert_type(spark_type):
    type_name = spark_type.dataType.typeName()
    if type_name in data_type_map:
        print(data_type_map[type_name])
        return pa.field(spark_type.name, data_type_map[type_name])
    else:
        raise ValueError(f"Unsupported PySpark data type: {spark_type}")


# data_type_map = {
#     StringType: pa.string(),
#     IntegerType: pa.int32(),
#     FloatType: pa.float64(),
#     DoubleType: pa.float64(),
#     BooleanType: pa.bool_(),
#     TimestampType: pa.timestamp("ns"),
#     # DateType: pa.date32(),
#     ArrayType: lambda dt: pa.list_(convert_type(dt)),
#     StructType: lambda st: pa.struct([convert_type(f) for f in st.fields]),
# }
data_type_map = {
    "string": pa.string(),
    "integer": pa.int32(),
    "float": pa.float64(),
    # DoubleType: pa.float64(),
    # BooleanType: pa.bool_(),
    "timestamp": pa.timestamp("ns"),
    # DateType: pa.date32(),
    # ArrayType: lambda dt: pa.list_(convert_type(dt)),
    "struct": lambda st: pa.struct([convert_type(f) for f in st.fields]),
}

In [14]:
config_set_query = (
    cos_dataframe.withColumn(
        "names_unstrung", sf.from_json(sf.col("names"), ArrayType(StringType()))
    )
    .withColumn(
        "labels_unstrung",
        sf.from_json(sf.col("labels"), ArrayType(StringType())),
    )
    .withColumn(
        "dataset_ids_unstrung", sf.from_json("dataset_ids", ArrayType(StringType()))
    )
    .drop("names", "labels", "dataset_ids")
    .withColumnRenamed("names_unstrung", "names")
    .withColumnRenamed("labels_unstrung", "labels")
    .withColumnRenamed("dataset_ids_unstrung", "dataset_ids")
    .filter(sf.array_contains(sf.col("dataset_ids"), carmat_ds_id))
)

In [15]:
config_set_query.show(1, False)

+----------------------+-------------------+-------------------+--------+---------------------+------------------------+--------------------------+-----------------------+--------------------+----------------------------------+------+---------+--------------------+-----------------------------------------------------------------------------------------------------------------------------------+---------------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------

In [17]:
label_match = "fcc|1"
names_match = "materials_10"
if names_match:
    config_set_query = (
        config_set_query.withColumn("labels_exploded", sf.explode(sf.col("labels")))
        .withColumn("names_exploded", sf.explode(sf.col("names")))
        .drop("names")
        .filter(sf.regexp_like(sf.col("names_exploded"), sf.lit(rf"{names_match}")))
    )
if label_match:
    config_set_query = config_set_query.filter(
        sf.regexp_like(sf.col("labels_exploded"), sf.lit(rf"{label_match}"))
    )
config_set_query.show(1, False)

+---------------------+------------------+-------------------+--------+---------------------+------------------------+--------------------------+------------------------+--------------------+----------------------------------------+------+---------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+---------------+---------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+--------+-------------------+--------------

In [18]:
loader.spark.sql(
    "UPDATE config_table SET configuration_set_ids = array_append(configuration_set_ids, %s) WHERE id = %s"
) % (config_set.id, co_ids)

ParseException: 
[PARSE_SYNTAX_ERROR] Syntax error at or near '('.(line 1, pos 60)

== SQL ==
UPDATE config_table SET configuration_set_ids = array_append(configuration_set_ids, %s) WHERE id = %s
------------------------------------------------------------^^^


In [21]:
from colabfit.tools.configuration_set import ConfigurationSet

In [20]:
from importlib import reload
import colabfit.tools.configuration_set

reload(colabfit.tools.configuration_set)
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
reload(colabfit.tools.database)
from colabfit.tools.database import DataManager, PGDataLoader

In [16]:
co_ids = [x[0] for x in config_set_query.select("id").distinct().collect()]

In [17]:
co_ids

['CO_1180491453360028556',
 'CO_879568364406847005',
 'CO_881342345556419709',
 'CO_395357802651642160',
 'CO_107450829658643358',
 'CO_2096618530829203889',
 'CO_991509346356160256',
 'CO_1491805601289580349',
 'CO_917048308752776257',
 'CO_1690862290738836150',
 'CO_874500111687326151',
 'CO_1125763852340156480',
 'CO_797754777732167492',
 'CO_1964436062852611062',
 'CO_1559941325392132777',
 'CO_1353971176978508405',
 'CO_286992457615631475',
 'CO_1552121326771475428',
 'CO_1325933560876578330',
 'CO_938770350513000754',
 'CO_387863327590048542',
 'CO_37932453843969185',
 'CO_1638546733562907466',
 'CO_1425676531260078729',
 'CO_1105782185213068538',
 'CO_326107300816987243',
 'CO_1615978422284641163',
 'CO_1047317511485860216',
 'CO_48167453662657190',
 'CO_1586766357518215377',
 'CO_856131307896983346',
 'CO_1554352833473561771',
 'CO_2035392092515548233',
 'CO_2162096324626582854',
 'CO_127707683762503658',
 'CO_617556824345540542',
 'CO_59433596177554941',
 'CO_14603548386437399

In [22]:
CS = ConfigurationSet(
    name="test", config_df=config_set_query, description="test description"
)

<class 'list'>
[Row(element='Ga', ratio=0.012887595268940222), Row(element='I', ratio=0.016706765146289347), Row(element='Pt', ratio=0.012332385551087706), Row(element='Se', ratio=0.016555344314147753), Row(element='Tl', ratio=0.006023184211854569), Row(element='Ni', ratio=0.017144203105809514), Row(element='Os', ratio=0.011339737873715026), Row(element='Co', ratio=0.014014839241549877), Row(element='Fe', ratio=0.016269327186769184), Row(element='Ru', ratio=0.012214613792755354), Row(element='Mg', ratio=0.010919124451099483), Row(element='Ti', ratio=0.02444605212241533), Row(element='Ag', ratio=0.008883355485640258), Row(element='H', ratio=0.07577771421841614), Row(element='Te', ratio=0.012988542490367953), Row(element='Al', ratio=0.018389218836751518), Row(element='C', ratio=0.01552904756296583), Row(element='S', ratio=0.027171627100964046), Row(element='Li', ratio=0.021552231774820397), Row(element='Ca', ratio=0.007991655029695307), Row(element='Zr', ratio=0.010077897605868398), Row(

In [29]:
def update_co_rows_cs_id(self, co_ids: list[str], cs_id: str):
    with psycopg.connect(
        """dbname=colabfit user=%s password=%s host=localhost port=5432"""
        % (
            user,
            password,
        )
    ) as conn:
        # dbname=self.database_name,
        # user=self.properties["user"],
        # password=self.properties["password"],
        # host="localhost",
        # port="5432",
        cur = conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = 
            """
        )
        cur = conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = concat(%s::text, 
                rtrim(ltrim(replace(configuration_set_ids,%s,''), 
                
                '['),']') || ', ', %s::text)
            WHERE id = ANY(%s)""",
            ("[", f"{cs_id}", f"{cs_id}]", co_ids),
            # ("[", f", {cs_id}", f", {cs_id}]"),
        )
        # cur.fetchall()
        conn.commit()

In [None]:
# You were trying to get  postgresql to recognize the WHERE id = ANY() array syntax

In [28]:
update_co_rows_cs_id(loader, co_ids, CS.spark_row["id"])

In [67]:
CS.spark_row

{'id': 'CS_258512216210800474',
 'hash': 258512216210800474,
 'last_modified': None,
 'nconfigurations': 110,
 'nsites': 1594,
 'nelements': 62,
 'elements': ['Ag',
  'Al',
  'As',
  'Au',
  'B',
  'Ba',
  'Be',
  'Bi',
  'Br',
  'C',
  'Ca',
  'Cd',
  'Cl',
  'Co',
  'Cr',
  'Cs',
  'Cu',
  'F',
  'Fe',
  'Ga',
  'Ge',
  'H',
  'Hf',
  'Hg',
  'I',
  'In',
  'Ir',
  'K',
  'Li',
  'Mg',
  'Mn',
  'Mo',
  'N',
  'Na',
  'Nb',
  'Ni',
  'O',
  'Os',
  'P',
  'Pb',
  'Pd',
  'Pt',
  'Rb',
  'Re',
  'Rh',
  'Ru',
  'S',
  'Sb',
  'Sc',
  'Se',
  'Si',
  'Sn',
  'Ta',
  'Tc',
  'Te',
  'Ti',
  'Tl',
  'V',
  'W',
  'Y',
  'Zn',
  'Zr'],
 'dataset_id': None,
 'name': 'test',
 'description': 'test description',
 'total_elements_ratios': [0.013801756587202008,
  0.006273525721455458,
  0.009410288582183186,
  0.010037641154328732,
  0.016938519447929738,
  0.006273525721455458,
  0.014429109159347553,
  0.01819322459222083,
  0.027603513174404015,
  0.012547051442910916,
  0.00313676286072772

In [78]:
from colabfit.tools.utilities import _empty_dict_from_schema
from colabfit.tools.schema import configuration_set_schema
import dateutil.parser

cs = _empty_dict_from_schema(configuration_set_schema)
cs["nconfigurations"] = 200
cs["dataset_id"] = carmat_ds_id
cs["name"] = "test"
cs["description"] = "test description for test"
cs["nelements"] = 25
cs["last_modified"] = dateutil.parser.parse(
    datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
)
cs["id"] = "CS_y7nrdsjtuw0g_0"
cs["hash"] = hash(cs["name"])

In [79]:
loader.write_table([cs], "configuration_sets", configuration_set_schema)

In [111]:
co_ids = [x[0] for x in config_set_query.select("id").distinct().collect()]

In [9]:
dm.load_data_to_pg_in_batches(loader)

Loading data to PostgreSQL: : 0batch [00:00, ?batch/s]24/05/14 18:12:58 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
24/05/14 18:13:01 ERROR Executor: Exception in task 13.0 in stage 0.0 (TID 13)2]
java.sql.BatchUpdateException: Batch entry 0 INSERT INTO configurations ("id","hash","last_modified","dataset_ids","metadata","chemical_formula_hill","chemical_formula_reduced","chemical_formula_anonymous","elements","elements_ratios","atomic_numbers","nsites","nelements","nperiodic_dimensions","cell","dimension_types","pbc","positions","names","labels","configuration_set_ids") VALUES (('CO_425343146072841686'),('425343146072841686'),('2024-05-14 18:12:56-04'::timestamp),('[''DS_y7nrdsjtuw0g_0'']'),(NULL),('Ge2LiRh6Zr'),('Ge2LiRh6Zr'),('A6B2CD'),('[''Ge'', ''Li'', ''Rh'', ''Zr'']'),('[0.2, 0.1, 0.6, 0.1]'),('[32, 32, 3, 45, 45, 45, 45, 45, 45, 40]'),('10'::int4),('4':

Py4JJavaError: An error occurred while calling o46.jdbc.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 0.0 failed 1 times, most recent failure: Lost task 0.0 in stage 0.0 (TID 0) (mart4.physics.nyu.edu executor driver): java.sql.BatchUpdateException: Batch entry 0 INSERT INTO configurations ("id","hash","last_modified","dataset_ids","metadata","chemical_formula_hill","chemical_formula_reduced","chemical_formula_anonymous","elements","elements_ratios","atomic_numbers","nsites","nelements","nperiodic_dimensions","cell","dimension_types","pbc","positions","names","labels","configuration_set_ids") VALUES (('CO_2035392092515548233'),('2035392092515548233'),('2024-05-14 18:12:55-04'::timestamp),('[''DS_y7nrdsjtuw0g_0'']'),(NULL),('H6BrCaRh2'),('BrCaH6Rh2'),('A6B2CD'),('[''Br'', ''Ca'', ''H'', ''Rh'']'),('[0.1, 0.1, 0.6, 0.2]'),('[35, 20, 1, 1, 1, 1, 1, 1, 45, 45]'),('10'::int4),('4'::int4),('0'::int4),('[[5.39874426, 0.0, 0.0], [2.6993721300000004, 4.67544967769542, 0.0], [2.6993721300000004, 1.5584832258984738, 4.4080562295931855]]'),('[0, 0, 0]'),('[False, False, False]'),('[[0.0, 0.0, 0.0], [5.398744260000001, 3.116966451796947, 2.2040281147965928], [5.398744260000001, 4.84493341938597, 3.4258852752451454], [5.39874426, 1.3889994842079243, 0.9821709543480398], [6.895207550832455, 3.980949935591458, 0.9821709543480398], [3.902280969167545, 2.252982968002436, 3.4258852752451454], [6.895207550832455, 2.252982968002436, 3.4258852752451454], [3.902280969167545, 3.980949935591458, 0.9821709543480398], [8.098116390000001, 4.67544967769542, 3.306042172194889], [2.6993721300000004, 1.5584832258984735, 1.1020140573982964]]'),('[''carolina_materials_0'']'),('[0, ''bcc'']'),(NULL)) was aborted: ERROR: duplicate key value violates unique constraint "idkey"
  Detail: Key (id)=(CO_2035392092515548233) already exists.  Call getNextException to see other errors in the batch.
	at org.postgresql.jdbc.BatchResultHandler.handleError(BatchResultHandler.java:165)
	at org.postgresql.core.v3.QueryExecutorImpl.processResults(QueryExecutorImpl.java:2413)
	at org.postgresql.core.v3.QueryExecutorImpl.execute(QueryExecutorImpl.java:579)
	at org.postgresql.jdbc.PgStatement.internalExecuteBatch(PgStatement.java:912)
	at org.postgresql.jdbc.PgStatement.executeBatch(PgStatement.java:936)
	at org.postgresql.jdbc.PgPreparedStatement.executeBatch(PgPreparedStatement.java:1733)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.savePartition(JdbcUtils.scala:751)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.$anonfun$saveTable$1(JdbcUtils.scala:902)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.$anonfun$saveTable$1$adapted(JdbcUtils.scala:901)
	at org.apache.spark.rdd.RDD.$anonfun$foreachPartition$2(RDD.scala:1039)
	at org.apache.spark.rdd.RDD.$anonfun$foreachPartition$2$adapted(RDD.scala:1039)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: org.postgresql.util.PSQLException: ERROR: duplicate key value violates unique constraint "idkey"
  Detail: Key (id)=(CO_2035392092515548233) already exists.
	at org.postgresql.core.v3.QueryExecutorImpl.receiveErrorResponse(QueryExecutorImpl.java:2725)
	at org.postgresql.core.v3.QueryExecutorImpl.processResults(QueryExecutorImpl.java:2412)
	... 21 more

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:989)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2398)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2419)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2438)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2463)
	at org.apache.spark.rdd.RDD.$anonfun$foreachPartition$1(RDD.scala:1039)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:151)
	at org.apache.spark.rdd.RDDOperationScope$.withScope(RDDOperationScope.scala:112)
	at org.apache.spark.rdd.RDD.withScope(RDD.scala:410)
	at org.apache.spark.rdd.RDD.foreachPartition(RDD.scala:1037)
	at org.apache.spark.sql.Dataset.$anonfun$foreachPartition$1(Dataset.scala:3514)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.sql.Dataset.$anonfun$withNewRDDExecutionId$1(Dataset.scala:4309)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.Dataset.withNewRDDExecutionId(Dataset.scala:4307)
	at org.apache.spark.sql.Dataset.foreachPartition(Dataset.scala:3514)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.saveTable(JdbcUtils.scala:901)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcRelationProvider.createRelation(JdbcRelationProvider.scala:70)
	at org.apache.spark.sql.execution.datasources.SaveIntoDataSourceCommand.run(SaveIntoDataSourceCommand.scala:48)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult$lzycompute(commands.scala:75)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.sideEffectResult(commands.scala:73)
	at org.apache.spark.sql.execution.command.ExecutedCommandExec.executeCollect(commands.scala:84)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.$anonfun$applyOrElse$1(QueryExecution.scala:107)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$6(SQLExecution.scala:125)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:201)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:108)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:900)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:66)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:107)
	at org.apache.spark.sql.execution.QueryExecution$$anonfun$eagerlyExecuteCommands$1.applyOrElse(QueryExecution.scala:98)
	at org.apache.spark.sql.catalyst.trees.TreeNode.$anonfun$transformDownWithPruning$1(TreeNode.scala:461)
	at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(origin.scala:76)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDownWithPruning(TreeNode.scala:461)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.org$apache$spark$sql$catalyst$plans$logical$AnalysisHelper$$super$transformDownWithPruning(LogicalPlan.scala:32)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning(AnalysisHelper.scala:267)
	at org.apache.spark.sql.catalyst.plans.logical.AnalysisHelper.transformDownWithPruning$(AnalysisHelper.scala:263)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:32)
	at org.apache.spark.sql.catalyst.plans.logical.LogicalPlan.transformDownWithPruning(LogicalPlan.scala:32)
	at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:437)
	at org.apache.spark.sql.execution.QueryExecution.eagerlyExecuteCommands(QueryExecution.scala:98)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted$lzycompute(QueryExecution.scala:85)
	at org.apache.spark.sql.execution.QueryExecution.commandExecuted(QueryExecution.scala:83)
	at org.apache.spark.sql.execution.QueryExecution.assertCommandExecuted(QueryExecution.scala:142)
	at org.apache.spark.sql.DataFrameWriter.runCommand(DataFrameWriter.scala:859)
	at org.apache.spark.sql.DataFrameWriter.saveToV1Source(DataFrameWriter.scala:388)
	at org.apache.spark.sql.DataFrameWriter.saveInternal(DataFrameWriter.scala:361)
	at org.apache.spark.sql.DataFrameWriter.save(DataFrameWriter.scala:248)
	at org.apache.spark.sql.DataFrameWriter.jdbc(DataFrameWriter.scala:756)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:829)
Caused by: java.sql.BatchUpdateException: Batch entry 0 INSERT INTO configurations ("id","hash","last_modified","dataset_ids","metadata","chemical_formula_hill","chemical_formula_reduced","chemical_formula_anonymous","elements","elements_ratios","atomic_numbers","nsites","nelements","nperiodic_dimensions","cell","dimension_types","pbc","positions","names","labels","configuration_set_ids") VALUES (('CO_2035392092515548233'),('2035392092515548233'),('2024-05-14 18:12:55-04'::timestamp),('[''DS_y7nrdsjtuw0g_0'']'),(NULL),('H6BrCaRh2'),('BrCaH6Rh2'),('A6B2CD'),('[''Br'', ''Ca'', ''H'', ''Rh'']'),('[0.1, 0.1, 0.6, 0.2]'),('[35, 20, 1, 1, 1, 1, 1, 1, 45, 45]'),('10'::int4),('4'::int4),('0'::int4),('[[5.39874426, 0.0, 0.0], [2.6993721300000004, 4.67544967769542, 0.0], [2.6993721300000004, 1.5584832258984738, 4.4080562295931855]]'),('[0, 0, 0]'),('[False, False, False]'),('[[0.0, 0.0, 0.0], [5.398744260000001, 3.116966451796947, 2.2040281147965928], [5.398744260000001, 4.84493341938597, 3.4258852752451454], [5.39874426, 1.3889994842079243, 0.9821709543480398], [6.895207550832455, 3.980949935591458, 0.9821709543480398], [3.902280969167545, 2.252982968002436, 3.4258852752451454], [6.895207550832455, 2.252982968002436, 3.4258852752451454], [3.902280969167545, 3.980949935591458, 0.9821709543480398], [8.098116390000001, 4.67544967769542, 3.306042172194889], [2.6993721300000004, 1.5584832258984735, 1.1020140573982964]]'),('[''carolina_materials_0'']'),('[0, ''bcc'']'),(NULL)) was aborted: ERROR: duplicate key value violates unique constraint "idkey"
  Detail: Key (id)=(CO_2035392092515548233) already exists.  Call getNextException to see other errors in the batch.
	at org.postgresql.jdbc.BatchResultHandler.handleError(BatchResultHandler.java:165)
	at org.postgresql.core.v3.QueryExecutorImpl.processResults(QueryExecutorImpl.java:2413)
	at org.postgresql.core.v3.QueryExecutorImpl.execute(QueryExecutorImpl.java:579)
	at org.postgresql.jdbc.PgStatement.internalExecuteBatch(PgStatement.java:912)
	at org.postgresql.jdbc.PgStatement.executeBatch(PgStatement.java:936)
	at org.postgresql.jdbc.PgPreparedStatement.executeBatch(PgPreparedStatement.java:1733)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.savePartition(JdbcUtils.scala:751)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.$anonfun$saveTable$1(JdbcUtils.scala:902)
	at org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils$.$anonfun$saveTable$1$adapted(JdbcUtils.scala:901)
	at org.apache.spark.rdd.RDD.$anonfun$foreachPartition$2(RDD.scala:1039)
	at org.apache.spark.rdd.RDD.$anonfun$foreachPartition$2$adapted(RDD.scala:1039)
	at org.apache.spark.SparkContext.$anonfun$runJob$5(SparkContext.scala:2438)
	at org.apache.spark.scheduler.ResultTask.runTask(ResultTask.scala:93)
	at org.apache.spark.TaskContext.runTaskWithListeners(TaskContext.scala:166)
	at org.apache.spark.scheduler.Task.run(Task.scala:141)
	at org.apache.spark.executor.Executor$TaskRunner.$anonfun$run$4(Executor.scala:620)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.executor.Executor$TaskRunner.run(Executor.scala:623)
	at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1128)
	at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:628)
	... 1 more
Caused by: org.postgresql.util.PSQLException: ERROR: duplicate key value violates unique constraint "idkey"
  Detail: Key (id)=(CO_2035392092515548233) already exists.
	at org.postgresql.core.v3.QueryExecutorImpl.receiveErrorResponse(QueryExecutorImpl.java:2725)
	at org.postgresql.core.v3.QueryExecutorImpl.processResults(QueryExecutorImpl.java:2412)
	... 21 more


In [112]:
with psycopg.connect(
    dbname="colabfit",
    user=os.environ.get("PGS_USER"),
    password=os.environ.get("PGS_PASS"),
    host="localhost",
) as conn:
    with conn.cursor() as cur:

        # cur.execute(
        #     "UPDATE configurations SET configuration_set_ids = configuration_set_ids || %(cs_id)s WHERE id = ANY(%(co_ids)s)",
        #     {"cs_id": cs["id"], "co_ids": co_ids},
        # )
        # data = cur.fetchall()
        cur.execute(
            "SELECT * FROM public.configurations WHERE id = ANY(%s)",
            [co_ids],
        )
        data2 = cur.fetchall()
    conn.commit()

In [109]:
co_ids.append("CO_215290934057753943")

In [113]:
data2

[]

In [None]:
dm2.load_data_to_pg_in_batches(loader)

# In progress

Upsert appears to be this for postgres:
```
update the_table
    set id = id || array[5,6]
where id = 4;
```
* ~~Check for upsert function from pyspark to concatenate lists of relationships instead of primary key id collision~~
* There is no pyspark-upsert function. Will have to manage this possibly through a different sql-based library
* Written: find duplicates, but convert to access database, not download full dataframe
* I see this being used with batches of hashes during upload: something like
    ``` for batch in batches:
            hash_duplicates = find_duplicates(batch, loader/database)
            hash_duplicates.make_change_to_append_dataset-ids
            hash_duplicates.write-to-database
* Where would be the best place to catch duplicates? Keeping in mind that this might be a bulk operation (i.e. on the order of millions, like with ANI1/ANI2x variations)

In [None]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

In [None]:
carmat_ds_id = "DS_y7nrdsjtuw0g_0"

In [None]:
from importlib import reload

import colabfit.tools.dataset
import colabfit.tools.database

reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager

Dataset = colabfit.tools.dataset.Dataset

In [None]:
def find_duplicate_hash(spark_rows: dict, loader):
    # hashes = loader.spark.createDataFrame([x["hash"] for x in spark_rows])
    hashes = [x["hash"] for x in spark_rows]
    duplicates = loader.spark.read.jdbc(
        url=url,
        table="configurations",
        properties=properties,
    ).filter(sf.col("hash").isin(hashes))
    # dupl_hashes = df.filter(df.hash.isin(hashes)).select("hash").collect()
    return duplicates

In [None]:
dupl_rows = [x.spark_row for x in dm_dup.configs]

In [None]:
config_df = (
    loader.spark.read.jdbc(
        url=loader.url, table=loader.config_table, properties=loader.properties
    )
    .withColumn(
        "ds_ids_unstrung",
        sf.from_json(sf.col("dataset_ids"), sf.ArrayType(sf.StringType())),
    )
    .filter(sf.array_contains("ds_ids_unstrung", dm.dataset_id))
    .drop("ds_ids_unstrung")
)

In [None]:
dup_hashes = find_duplicate_hash(dupl_rows, loader)

In [None]:
dup_hashes.show(10, False)

In [None]:
[x["hash"] for x in dup_hashes]

# outer

In [None]:
"""
Can we make the configuration and the property instance/data object at the same time?
In this way, we would only have to pass through the data one time.

Workflow:
create database access object
create data reader as function? of the database access object
reader returns ase.Atoms-style objects (AtomicConfiguration)
DOs and PIs are now one object
These DOs point to a configuration
The configuration may already exist in the database, so we keep track of the hash added to the DO


"""

In [None]:
cos = json.load(Path("sample_db/co_ds1.json").open("r"))

In [None]:
with open(Path("sample_db/co_ds1.json"), "r") as f:
    co_json = spark.sparkContext.parallelize(json.load(f))

In [None]:
co = co_json.map(_parse_config).map(stringify_lists)
co_df = spark.createDataFrame(co, config_schema)

In [None]:
def parse_configs(co_path, spark):
    with open(co_path, "r") as f:
        co_json = spark.sparkContext.parallelize(json.load(f))
    co = co_json.map(_parse_config).map(stringify_lists)
    co_df = spark.createDataFrame(co, config_schema)
    return co_df

In [None]:
parse_configs("sample_db/co_ds1.json", spark).show()