# Imports

In [2]:
import datetime
import json
import os
from collections import defaultdict
import pickle
from tqdm import tqdm

# from functools import partial
# from itertools import chain, islice
# from multiprocessing import Pool, cpu_count
from pathlib import Path

# from pprint import pprint

import dateutil.parser
import findspark
import lmdb
import numpy as np
import psycopg
import pyspark.sql.functions as sf
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from colabfit.tools.schema import (
    property_object_schema,
    config_df_schema,
    config_schema,
    property_object_df_schema,
)
from colabfit.tools.configuration import AtomicConfiguration, config_schema
from colabfit.tools.database import DataManager, PGDataLoader
from colabfit.tools.dataset import Dataset, dataset_schema
from colabfit.tools.property import Property, property_object_schema
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)
from colabfit.tools.schema import configuration_set_schema
import pyarrow as pa

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)
findspark.init()
format = "jdbc"
load_dotenv("./.env")

True

# Set up MTPU and Carolina Materials readers and data

In [18]:
# MTPU data


def convert_stress(keys, stress):
    stresses = {k: s for k, s in zip(keys, stress)}
    return [
        [stresses["xx"], stresses["xy"], stresses["xz"]],
        [stresses["xy"], stresses["yy"], stresses["yz"]],
        [stresses["xz"], stresses["yz"], stresses["zz"]],
    ]


SYMBOL_DICT = {"0": "Si", "1": "O"}


def mtpu_reader(filepath):
    with open(filepath, "rt") as f:
        energy = None
        forces = None
        coords = []
        cell = []
        symbols = []
        config_count = 0
        info = dict()
        for line in f:
            if line.strip().startswith("Size"):
                size = int(f.readline().strip())
            elif line.strip().lower().startswith("supercell"):
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
            elif line.strip().startswith("Energy"):
                energy = float(f.readline().strip())
            elif line.strip().startswith("PlusStress"):
                stress_keys = line.strip().split()[-6:]
                stress = [float(x) for x in f.readline().strip().split()]
                stress = convert_stress(stress_keys, stress)
            elif line.strip().startswith("AtomData:"):
                keys = line.strip().split()[1:]
                if "fx" in keys:
                    forces = []
                for i in range(size):
                    li = {
                        key: val for key, val in zip(keys, f.readline().strip().split())
                    }
                    symbols.append(SYMBOL_DICT[li["type"]])
                    if "cartes_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["cartes_x"],
                                    li["cartes_y"],
                                    li["cartes_z"],
                                ]
                            ]
                        )
                    elif "direct_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["direct_x"],
                                    li["direct_y"],
                                    li["direct_z"],
                                ]
                            ]
                        )

                    if "fx" in keys:
                        forces.append(
                            [float(f) for f in [li["fx"], li["fy"], li["fz"]]]
                        )

            elif line.startswith("END_CFG"):

                info["energy"] = energy
                if forces:
                    info["forces"] = forces
                info["stress"] = stress

                if "Si" in symbols and "O" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "11x11x11",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_SiO2_{config_count}"
                elif "Si" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "8x8x8",
                        "kinetic-energy-cutoff": {
                            "val": 884,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_Si_{config_count}"
                elif "O" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "gamma-point",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_O_{config_count}"
                if "cartes_x" in keys:
                    config = AtomicConfiguration(
                        positions=coords, symbols=symbols, cell=cell, info=info
                    )
                elif "direct_x" in keys:
                    config = AtomicConfiguration(
                        scaled_positions=coords, symbols=symbols, cell=cell, info=info
                    )
                config_count += 1
                yield config
                forces = None
                stress = []
                coords = []
                cell = []
                symbols = []
                energy = None

In [19]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
data = list(mtpu_configs)
# data = [x for x in mtpu_configs]
# data[0].configuration_summary()

In [20]:
import colabfit.tools.configuration
from importlib import reload

reload(colabfit.tools.configuration)
AtomicConfiguration = colabfit.tools.configuration.AtomicConfiguration

In [23]:
# Carolina Materials data

SOFTWARE = "VASP"
METHODS = "DFT-PBE"
CM_PI_METADATA = {
    "software": {"value": SOFTWARE},
    "method": {"value": METHODS},
    "input": {"value": {"IBRION": 6, "NFREE": 4}},
}

CM_PROPERTY_MAP = {
    "formation-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
        }
    ],
    "_metadata": CM_PI_METADATA,
}
CO_MD = {
    key: {"field": key}
    for key in [
        "_symmetry_space_group_name_H-M",
        "_symmetry_Int_Tables_number",
        "_chemical_formula_structural",
        "_chemical_formula_sum",
        "_cell_volume",
        "_cell_formula_units_Z",
        "symmetry_dict",
        "formula_pretty",
    ]
}


def load_row(txn, row):
    try:
        data = pickle.loads(txn.get(f"{row}".encode("ascii")))
        return data
    except TypeError:
        return False


def config_from_row(row: dict, row_num: int):
    coords = row.pop("cart_coords")
    a_num = row.pop("atomic_numbers")
    cell = [
        row.pop(x)
        for x in [
            "_cell_length_a",
            "_cell_length_b",
            "_cell_length_c",
            "_cell_angle_alpha",
            "_cell_angle_beta",
            "_cell_angle_gamma",
        ]
    ]
    symmetry_dict = {str(key): val for key, val in row.pop("symmetry_dict").items()}
    for key in symmetry_dict:
        key = str(key)
    info = {}
    info = row
    info["symmetry_dict"] = symmetry_dict
    info["_name"] = f"carolina_materials_{row_num}"
    if row_num % 10 == 0:
        info["_labels"] = [row_num % 10, "bcc"]
    else:
        info["_labels"] = [row_num % 10, "fcc"]
    config = AtomicConfiguration(
        scaled_positions=coords,
        numbers=a_num,
        cell=cell,
        info=info,
    )
    return config
    # return AtomicConfiguration.from_ase(config)


def carmat_reader(fp: Path):
    parent = fp.parent
    env = lmdb.open(str(parent))
    txn = env.begin()
    row_num = 0
    rows = []
    while row_num <= 100000:
        row = load_row(txn, row_num)
        if row is False:
            env.close()
            break
        rows.append(row)
        yield config_from_row(row, row_num)
        row_num += 1
    env.close()
    return False
    # return rows

In [24]:
PI_METADATA = {
    "software": {"value": "Quantum ESPRESSO"},
    "method": {"value": "DFT-PBE"},
    "input": {"field": "input"},
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            # "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            # "_metadata": PI_METADATA,
        },
    ],
    "cauchy-stress": [
        {
            "stress": {"field": "stress", "units": "GPa"},
            "volume-normalized": {"value": True, "units": None},
        }
    ],
    "_metadata": PI_METADATA,
}

# Connect to DB and run loader

In [23]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/21 14:49:53 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/21 14:49:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/21 14:49:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/21 14:49:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))

PI_METADATA = {
    "software": {"value": "Quantum ESPRESSO"},
    "method": {"value": "DFT-PBE"},
    "input": {"field": "input"},
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            # "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            # "_metadata": PI_METADATA,
        },
    ],
    "cauchy-stress": [
        {
            "stress": {"field": "stress", "units": "GPa"},
            "volume-normalized": {"value": True, "units": None},
        }
    ],
    "_metadata": PI_METADATA,
}
spark = SparkSession.builder.appName("ColabfitIngestData").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
# print(loader.spark)
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = list(mtpu_configs)
print(mtpu_configs[0])
co_po_rows = []
for config in tqdm(mtpu_configs):
    config.set_dataset_id(mtpu_ds_id)
    co_po_rows.append(
        (
            config.spark_row,
            Property.from_definition(
                [potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
                configuration=config,
                property_map=PROPERTY_MAP,
            ).spark_row,
        )
    )

24/05/20 17:14:41 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


AtomicConfiguration(name=Unified_training_set_SiO2_1061, symbols='Si4', pbc=False, cell=[[3.85085, 0.0, 0.077017], [-1.925425, 3.334933, -0.038508], [0.127258, 0.0, 6.362934]])


100%|██████████| 1062/1062 [00:00<00:00, 2495.36it/s]


In [None]:
co_po_rows[0][0]

{'id': 'CO_47706510123393079',
 'hash': 47706510123393079,
 'last_modified': datetime.datetime(2024, 5, 20, 17, 14, 41),
 'dataset_ids': "['DS_y7nrdsjtuwom_0']",
 'metadata': None,
 'chemical_formula_hill': 'Si4',
 'chemical_formula_reduced': 'Si',
 'chemical_formula_anonymous': 'A',
 'elements': "['Si']",
 'elements_ratios': '[1.0]',
 'atomic_numbers': '[14, 14, 14, 14]',
 'nsites': 4,
 'nelements': 1,
 'nperiodic_dimensions': 0,
 'cell': '[[3.85085, 0.0, 0.077017], [-1.925425, 3.334933, -0.038508], [0.127258, 0.0, 6.362934]]',
 'dimension_types': '[0, 0, 0]',
 'pbc': '[False, False, False]',
 'positions': '[[1.892001, 1.11132, 0.400465], [1.955509, -1.11132, 3.581973], [1.895339, -1.11132, -0.400508], [1.958847, 1.11132, 2.781]]',
 'names': "['Unified_training_set_Si_0']",
 'labels': None,
 'configuration_set_ids': None}

In [None]:
import datetime

In [None]:
dateutil.parser.parse(datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"))

datetime.datetime(2024, 5, 20, 17, 10, 24)

In [None]:
try:
    print("making co po rows")
    co_rows, po_rows = list(zip(*co_po_rows))
    print("making cos dataframes...")
    cos_dataframe = spark.createDataFrame(co_rows, schema=config_schema)
    print("Done!")
    print(cos_dataframe.show(1, False))
    print("making pos dataframes...")
    pos_dataframe = spark.createDataFrame(po_rows, schema=property_object_schema)
    print("Done!")
    pos_dataframe.show(1, False)
    try:
        # loader.write_table(
        #     co_rows,
        #     loader.config_table,
        #     config_schema,
        # )
        # loader.write_table(
        #     po_rows,
        #     loader.prop_object_table,
        #     property_object_schema,
        # )
        print(loader.config_table)
        cos_dataframe.write.mode("append").saveAsTable(loader.config_table)
        print(loader.prop_object_table)
        pos_dataframe.write.mode("append").saveAsTable(loader.prop_object_table)
    except:
        print("loader write failed")
except:
    print("error getting df")

In [24]:
config_df = loader.spark.read.jdbc(
    url=url, table="configurations", properties=properties
)

In [101]:
def append_elem(col_array, elem):
    print(col_array)
    unstrung = eval(col_array)
    unstrung.append(elem)
    unstrung = list(set(unstrung))
    return str(unstrung)

In [25]:
config_df.show(1, False)

+----------------------+-------------------+-------------------+---------------------+--------+---------------------+------------------------+--------------------------+----------------+-----------------+----------------------------+------+---------+--------------------+-------------------------------------------------------------------------------------------------+---------------+---------------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------------+----------+---------------------+
|id                    |hash               |last_modified      |dataset_ids          |metadata|chemical_formula_hill|chemical

In [None]:
# NEXT: find a way to check wehther cs-ids is null and handle the column =+ cs_id as array
# OR find a way to use the lambda function to use the user-defined function (append element)

In [None]:
labels = ["fcc", 6]
config_set_id = "test_config_set_id"
config_df.withColumn("filter_labels", sf.lit(labels)).withColumn(
    "labels_unstrung", sf.from_json(sf.col("labels"), ArrayType(StringType()))
).withColumn(
    "has_labels",
    sf.forall(
        "filter_labels",
        lambda x: sf.array_contains(col=sf.col("labels_unstrung"), value=x),
    ),
).withColumn(
    "new_cs_id", sf.lit([config_set_id])
).withColumn(
    "configuration_set_ids",
    sf.when(
        condition=sf.col("has_labels") == True,
        value=sf.array_union(config_df["configuration_set_ids"], sf.col("dataset_ids")),
    ),
    # .otherwise(config_df["configuration_set_ids"]),
)
# .withColumn(
#     "configuration_set_ids",
#     sf.transform_values(
#         "configuration_set_ids", lambda k, cs_ids: append_elem(cs_ids, config_set_id)
#     ),
# )
# .show(
#     10, False
# )

df.rdd  
you can only parallelize one time so don't try to do a dataframe select from an rdd  
updating to sdk 5.1 in a couple weeks  
boto3 and s3 are the amazon file system interactions, mostly for adding metadata TO FILES (not to the database) and interacting with the files as FileExistsError. 
Make sure to spark.stop() at end of  python file.

In [16]:
carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"
# carmat_ds_id2 = "duplicate_ds_id"
dm = DataManager(
    nprocs=4,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
# dm_dup = DataManager(
#     nprocs=4,
#     configs=carmat_config_gen,
#     prop_defs=[formation_energy_pd],
#     prop_map=CM_PROPERTY_MAP,
#     dataset_id=carmat_ds_id2,
# )

NameError: name 'carmat_reader' is not defined

In [26]:
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

Dataset ID: DS_y7nrdsjtuwom_0


In [12]:
batch = next(dm.gather_co_po_in_batches())

In [13]:
cos = [x[0] for x in batch]
cos_dataframe = loader.spark.createDataFrame(cos, schema=config_schema)

In [11]:
rows = loader.spark.createDataFrame(cos, schema=config_schema).collect()

                                                                                

In [30]:
import pyarrow as pa


def spark_to_arrow_type(spark_type):
    type_name = spark_type.dataType.typeName()
    if type_name in data_type_map:
        print(data_type_map[type_name])
        return pa.field(spark_type.name, data_type_map[type_name])
    else:
        raise ValueError(f"Unsupported PySpark data type: {spark_type}")


data_type_map = {
    "string": pa.string(),
    "integer": pa.int32(),
    "float": pa.float64(),
    # DoubleType: pa.float64(),
    # BooleanType: pa.bool_(),
    "timestamp": pa.timestamp("ns"),
    # DateType: pa.date32(),
    # ArrayType: lambda dt: pa.list_(convert_type(dt)),
    "struct": lambda st: pa.struct([spark_to_arrow_type(f) for f in st.fields]),
}

In [14]:
config_set_query = (
    cos_dataframe.withColumn(
        "names_unstrung", sf.from_json(sf.col("names"), ArrayType(StringType()))
    )
    .withColumn(
        "labels_unstrung",
        sf.from_json(sf.col("labels"), ArrayType(StringType())),
    )
    .withColumn(
        "dataset_ids_unstrung", sf.from_json("dataset_ids", ArrayType(StringType()))
    )
    .drop("names", "labels", "dataset_ids")
    .withColumnRenamed("names_unstrung", "names")
    .withColumnRenamed("labels_unstrung", "labels")
    .withColumnRenamed("dataset_ids_unstrung", "dataset_ids")
    .filter(sf.array_contains(sf.col("dataset_ids"), carmat_ds_id))
)

In [15]:
config_set_query.show(1, False)

+----------------------+-------------------+-------------------+--------+---------------------+------------------------+--------------------------+-----------------------+--------------------+----------------------------------+------+---------+--------------------+-----------------------------------------------------------------------------------------------------------------------------------+---------------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------

In [17]:
label_match = "fcc|1"
names_match = "materials_10"
if names_match:
    config_set_query = (
        config_set_query.withColumn("labels_exploded", sf.explode(sf.col("labels")))
        .withColumn("names_exploded", sf.explode(sf.col("names")))
        .drop("names")
        .filter(sf.regexp_like(sf.col("names_exploded"), sf.lit(rf"{names_match}")))
    )
if label_match:
    config_set_query = config_set_query.filter(
        sf.regexp_like(sf.col("labels_exploded"), sf.lit(rf"{label_match}"))
    )
config_set_query.show(1, False)

+---------------------+------------------+-------------------+--------+---------------------+------------------------+--------------------------+------------------------+--------------------+----------------------------------------+------+---------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+---------------+---------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+--------+-------------------+--------------

In [20]:
from importlib import reload
import colabfit.tools.configuration_set

reload(colabfit.tools.configuration_set)
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
reload(colabfit.tools.database)
from colabfit.tools.database import DataManager, PGDataLoader

In [16]:
co_ids = [x[0] for x in config_set_query.select("id").distinct().collect()]

In [22]:
CS = ConfigurationSet(
    name="test", config_df=config_set_query, description="test description"
)

<class 'list'>
[Row(element='Ga', ratio=0.012887595268940222), Row(element='I', ratio=0.016706765146289347), Row(element='Pt', ratio=0.012332385551087706), Row(element='Se', ratio=0.016555344314147753), Row(element='Tl', ratio=0.006023184211854569), Row(element='Ni', ratio=0.017144203105809514), Row(element='Os', ratio=0.011339737873715026), Row(element='Co', ratio=0.014014839241549877), Row(element='Fe', ratio=0.016269327186769184), Row(element='Ru', ratio=0.012214613792755354), Row(element='Mg', ratio=0.010919124451099483), Row(element='Ti', ratio=0.02444605212241533), Row(element='Ag', ratio=0.008883355485640258), Row(element='H', ratio=0.07577771421841614), Row(element='Te', ratio=0.012988542490367953), Row(element='Al', ratio=0.018389218836751518), Row(element='C', ratio=0.01552904756296583), Row(element='S', ratio=0.027171627100964046), Row(element='Li', ratio=0.021552231774820397), Row(element='Ca', ratio=0.007991655029695307), Row(element='Zr', ratio=0.010077897605868398), Row(

In [29]:
def update_co_rows_cs_id(self, co_ids: list[str], cs_id: str):
    with psycopg.connect(
        """dbname=colabfit user=%s password=%s host=localhost port=5432"""
        % (
            user,
            password,
        )
    ) as conn:
        # dbname=self.database_name,
        # user=self.properties["user"],
        # password=self.properties["password"],
        # host="localhost",
        # port="5432",
        cur = conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = 
            """
        )
        cur = conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = concat(%s::text, 
                rtrim(ltrim(replace(configuration_set_ids,%s,''), 
                
                '['),']') || ', ', %s::text)
            WHERE id = ANY(%s)""",
            ("[", f"{cs_id}", f"{cs_id}]", co_ids),
            # ("[", f", {cs_id}", f", {cs_id}]"),
        )
        # cur.fetchall()
        conn.commit()

In [None]:
# You were trying to get  postgresql to recognize the WHERE id = ANY() array syntax

In [28]:
update_co_rows_cs_id(loader, co_ids, CS.spark_row["id"])

In [78]:
from colabfit.tools.utilities import _empty_dict_from_schema
from colabfit.tools.schema import configuration_set_schema
import dateutil.parser

cs = _empty_dict_from_schema(configuration_set_schema)
cs["nconfigurations"] = 200
cs["dataset_id"] = carmat_ds_id
cs["name"] = "test"
cs["description"] = "test description for test"
cs["nelements"] = 25
cs["last_modified"] = dateutil.parser.parse(
    datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
)
cs["id"] = "CS_y7nrdsjtuw0g_0"
cs["hash"] = hash(cs["name"])

In [79]:
loader.write_table([cs], "configuration_sets", configuration_set_schema)

In [111]:
co_ids = [x[0] for x in config_set_query.select("id").distinct().collect()]

In [None]:
dm.load_data_to_pg_in_batches(loader)

In [112]:
with psycopg.connect(
    dbname="colabfit",
    user=os.environ.get("PGS_USER"),
    password=os.environ.get("PGS_PASS"),
    host="localhost",
) as conn:
    with conn.cursor() as cur:

        # cur.execute(
        #     "UPDATE configurations SET configuration_set_ids = configuration_set_ids || %(cs_id)s WHERE id = ANY(%(co_ids)s)",
        #     {"cs_id": cs["id"], "co_ids": co_ids},
        # )
        # data = cur.fetchall()
        cur.execute(
            "SELECT * FROM public.configurations WHERE id = ANY(%s)",
            [co_ids],
        )
        data2 = cur.fetchall()
    conn.commit()

In [109]:
co_ids.append("CO_215290934057753943")

In [113]:
data2

[]

In [None]:
dm2.load_data_to_pg_in_batches(loader)

# In progress

Upsert appears to be this for postgres:
```
update the_table
    set id = id || array[5,6]
where id = 4;
```
* ~~Check for upsert function from pyspark to concatenate lists of relationships instead of primary key id collision~~
* There is no pyspark-upsert function. Will have to manage this possibly through a different sql-based library
* Written: find duplicates, but convert to access database, not download full dataframe
* I see this being used with batches of hashes during upload: something like
    ``` for batch in batches:
            hash_duplicates = find_duplicates(batch, loader/database)
            hash_duplicates.make_change_to_append_dataset-ids
            hash_duplicates.write-to-database
* Where would be the best place to catch duplicates? Keeping in mind that this might be a bulk operation (i.e. on the order of millions, like with ANI1/ANI2x variations)

In [9]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/30 09:52:06 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/30 09:52:06 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/30 09:52:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/30 09:52:08 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

In [27]:
self = dm2


def find_existing_rows_append_elem(
    self,
    table_name: str,
    ids: list[str],
    cols: list[str],
    elems: list[str],
    edit_schema: StructType,
    write_schema: StructType,
):
    if isinstance(cols, str):
        cols = [cols]
    if isinstance(elems, str):
        elems = [elems]
    col_types = {"id": StringType(), "$row_id": IntegerType()}
    edit_col_types = {"id": StringType(), "$row_id": IntegerType()}
    for col in cols:
        col_types[col] = get_spark_field_type(write_schema, col)
        edit_col_types[col] = get_spark_field_type(edit_schema, col)
    update_cols = [col for col in col_types if col != "id"]
    query_schema = StructType(
        [
            StructField(col, col_types[col], False)
            for i, col in enumerate(cols + ["id", "$row_id"])
        ]
    )
    edit_schema = StructType(
        [
            StructField(col, edit_col_types[col], False)
            for i, col in enumerate(cols + ["id", "$row_id"])
        ]
    )
    partial_batch_to_rdd = partial(arrow_record_batch_to_rdd, query_schema)
    with self.session.transaction() as tx:
        # string would be 'ndb.colabfit.dev.[table name]'
        table_path = table_name.split(".")
        table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
        rec_batch = table.select(
            predicate=_.id.isin(ids), columns=cols + ["id"], internal_row_id=True
        )
        rdd = self.spark.sparkContext.parallelize([])
        for batch in rec_batch:
            rdd = rdd.union(
                self.spark.sparkContext.parallelize(list(partial_batch_to_rdd(batch)))
            )
    rdd = rdd.map(unstringify_row_dict)

    def add_elem_to_row_dict(col, elem, row_dict):
        val = row_dict.get(col, [])
        row_dict[col] = list(set(val + [elem]))
        return row_dict

    for col, elem in zip(cols, elems):
        partial_add = partial(add_elem_to_row_dict, col, elem)
        rdd = rdd.map(partial_add)
    update_ids = rdd.map(lambda x: x["id"]).collect()
    new_ids = [id for id in ids if id not in update_ids]
    rdd = rdd.map(stringify_row_dict)
    rdd_collect = rdd.map(lambda x: [x[col] for col in update_cols]).collect()
    update_schema = StructType(
        [StructField(col, col_types[col], False) for col in update_cols]
    )
    arrow_schema = spark_schema_to_arrow_schema(update_schema)
    update_table = pa.table(
        [pa.array(col) for col in zip(*rdd_collect)], schema=arrow_schema
    )
    with self.session.transaction() as tx:
        table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
        table.update(rows=update_table)
    return (new_ids, update_ids)

In [None]:
loader.spark.sql(f"drop table if exists {loader.config_table}")
loader.spark.sql(f"drop table if exists {loader.prop_object_table}")

In [None]:
config_list = list(mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg")))
dm2.configs = config_list[:50]
dm2.load_co_po_to_vastdb(loader)
dm2.configs = config_list[25:]
dm2.load_co_po_to_vastdb(loader)

In [None]:
from importlib import reload

import colabfit.tools.dataset
import colabfit.tools.database

reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager

# Dataset = colabfit.tools.dataset.Dataset
##############

import json
import lmdb
import pickle
from colabfit.tools.database import DataManager, SparkDataLoader

carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"
loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
load_dotenv()
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
loader.set_vastdb_session(
    endpoint=endpoint, access_key=access_key, access_secret=access_secret
)

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)

dm = DataManager(
    nprocs=1,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
dm.configs = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
dm.load_co_po_to_vastdb(loader)

In [None]:
batches = dm2.gather_co_po_in_batches()
batch = next(batches)
cos, pos = zip(*batch)
rdd = loader.spark.sparkContext.parallelize(cos)
ids_coll = rdd.map(lambda x: x["id"]).collect()
loader.spark.read.table(loader.config_table).select(sf.col("id")).filter(
    sf.col("id").isin(broadcast_ids.value)
).count()

In [29]:
dataset_id = mtpu_ds_id
name_label_match = [
    (".*Si.*3.*", "None", "All_si_with_zero", "All Si with zero description"),
    (".*Si.*4.*", "None", "All_si_with_two", "All Si with two description"),
]
find_existing_rows_append_elem(
    loader,
    table_name=loader.config_table,
    ids=co_ids,
    cols="configuration_set_ids",
    elems="test1_cs-id",
    edit_schema=config_df_schema,
    write_schema=config_schema,
)

In [None]:
def create_configuration_set(
    self,
    loader,
    # below args in order:
    # [config-name-regex-pattern], [config-label-regex-pattern], \
    # [config-set-name], [config-set-description]
    name_label_match: list[tuple],
    dataset_id: str,
):
    # Load unstrung dataframe of configs, filter for just includes ds-id
    config_df = loader.read_table(table_name=loader.config_table, unstring=True)
    config_df = config_df.filter(sf.array_contains(sf.col("dataset_ids"), dataset_id))
    # for each set of name-label matches, filter the config_df
    # find the secret row id for the matching columns
    # use vastdb-sdk to update configuration_set_ids column with new configuration set id
    # Should be able to use current "find_dups_append_elem_sdk" function, but should rename this
    # to reflect general update usage, rather than find duplicates only,
    # or if necessary write very similar function that only updates, but assumes existence of row
    # in the table, since we're checking for that here already.
    for i, (names_match, label_match, cs_name, cs_desc) in enumerate(name_label_match):
        print(
            f"names match: {names_match}, label {label_match}, cs_name {cs_name}, cs_desc {cs_desc}"
        )
        if names_match:
            config_set_query = config_df.withColumn(
                "names_exploded", sf.explode(sf.col("names"))
            ).filter(sf.col("names_exploded").rlike(names_match))
        # Currently an AND operation on labels: labels col contains x AND y
        # if label_match:
        #     if isinstance(label_match, str):
        #         label_match = [label_match]
        #     for label in label_match:
        #         config_set_query = config_set_query.filter(
        #             sf.array_contains(sf.col("labels"), label)
        #         )
    co_ids = [x["id"] for x in config_set_query.select("id").distinct().collect()]
    find_existing_rows_append_elem(
        loader,
        table_name=loader.config_table,
        ids=co_ids,
        cols="configuration_set_ids",
        elems=cs_name,
        edit_schema=config_df_schema,
        write_schema=config_schema,
    )
    config_set = ConfigurationSet(
        name=cs_name,
        description=cs_desc,
        config_df=config_set_query,
    )
    row = config_set.spark_row
    loader.write_table([row], loader.config_set_table, schema=configuration_set_schema)

In [None]:
create_configuration_set(dm2, loader, name_label_match, dataset_id)
"""  File "<stdin>", line 37, in find_existing_rows_append_elem
  File "/ext3/miniconda3/lib/python3.12/site-packages/vastdb/table.py", line 351, in select
    query_data_request = internal_commands.build_query_data_request(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.12/site-packages/vastdb/internal_commands.py", line 2084, in build_query_data_request
    filter_obj = predicate.serialize(builder)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ext3/miniconda3/lib/python3.12/site-packages/vastdb/internal_commands.py", line 245, in serialize
    raise NotImplementedError(self.expr)  # an empty OR is equivalent to a 'FALSE' literal
"""

In [None]:
from importlib import reload

import colabfit.tools.dataset
import colabfit.tools.database

reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager

Dataset = colabfit.tools.dataset.Dataset

In [None]:
def find_duplicate_hash(spark_rows: dict, loader):
    # hashes = loader.spark.createDataFrame([x["hash"] for x in spark_rows])
    hashes = [x["hash"] for x in spark_rows]
    duplicates = loader.spark.read.jdbc(
        url=url,
        table="configurations",
        properties=properties,
    ).filter(sf.col("hash").isin(hashes))
    # dupl_hashes = df.filter(df.hash.isin(hashes)).select("hash").collect()
    return duplicates