# Imports

In [2]:
import datetime
import json
from time import time
import os
from collections import defaultdict
import pickle
from tqdm import tqdm

# from functools import partial
# from itertools import chain, islice
# from multiprocessing import Pool, cpu_count
from pathlib import Path

# from pprint import pprint

import dateutil.parser
import findspark
import lmdb
import numpy as np
import psycopg
import pyspark.sql.functions as sf
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from colabfit.tools.schema import (
    property_object_schema,
    config_df_schema,
    config_schema,
    property_object_df_schema,
)
from colabfit.tools.configuration import AtomicConfiguration, config_schema
from colabfit.tools.database import DataManager, PGDataLoader
from colabfit.tools.dataset import Dataset, dataset_schema
from colabfit.tools.property import Property, property_object_schema
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)
from colabfit.tools.schema import configuration_set_schema
import pyarrow as pa

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)
findspark.init()
format = "jdbc"
load_dotenv("./.env")

True

# Set up MTPU and Carolina Materials readers and data

In [18]:
# MTPU data


def convert_stress(keys, stress):
    stresses = {k: s for k, s in zip(keys, stress)}
    return [
        [stresses["xx"], stresses["xy"], stresses["xz"]],
        [stresses["xy"], stresses["yy"], stresses["yz"]],
        [stresses["xz"], stresses["yz"], stresses["zz"]],
    ]


SYMBOL_DICT = {"0": "Si", "1": "O"}


def mtpu_reader(filepath):
    with open(filepath, "rt") as f:
        energy = None
        forces = None
        coords = []
        cell = []
        symbols = []
        config_count = 0
        info = dict()
        for line in f:
            if line.strip().startswith("Size"):
                size = int(f.readline().strip())
            elif line.strip().lower().startswith("supercell"):
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
            elif line.strip().startswith("Energy"):
                energy = float(f.readline().strip())
            elif line.strip().startswith("PlusStress"):
                stress_keys = line.strip().split()[-6:]
                stress = [float(x) for x in f.readline().strip().split()]
                stress = convert_stress(stress_keys, stress)
            elif line.strip().startswith("AtomData:"):
                keys = line.strip().split()[1:]
                if "fx" in keys:
                    forces = []
                for i in range(size):
                    li = {
                        key: val for key, val in zip(keys, f.readline().strip().split())
                    }
                    symbols.append(SYMBOL_DICT[li["type"]])
                    if "cartes_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["cartes_x"],
                                    li["cartes_y"],
                                    li["cartes_z"],
                                ]
                            ]
                        )
                    elif "direct_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["direct_x"],
                                    li["direct_y"],
                                    li["direct_z"],
                                ]
                            ]
                        )

                    if "fx" in keys:
                        forces.append(
                            [float(f) for f in [li["fx"], li["fy"], li["fz"]]]
                        )

            elif line.startswith("END_CFG"):

                info["energy"] = energy
                if forces:
                    info["forces"] = forces
                info["stress"] = stress

                if "Si" in symbols and "O" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "11x11x11",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_SiO2_{config_count}"
                elif "Si" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "8x8x8",
                        "kinetic-energy-cutoff": {
                            "val": 884,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_Si_{config_count}"
                elif "O" in symbols:
                    info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "gamma-point",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    info["_name"] = f"{filepath.stem}_O_{config_count}"
                if "cartes_x" in keys:
                    config = AtomicConfiguration(
                        positions=coords, symbols=symbols, cell=cell, info=info
                    )
                elif "direct_x" in keys:
                    config = AtomicConfiguration(
                        scaled_positions=coords, symbols=symbols, cell=cell, info=info
                    )
                config_count += 1
                yield config
                forces = None
                stress = []
                coords = []
                cell = []
                symbols = []
                energy = None

In [19]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
data = list(mtpu_configs)
# data = [x for x in mtpu_configs]
# data[0].configuration_summary()

In [20]:
import colabfit.tools.configuration
from importlib import reload

reload(colabfit.tools.configuration)
AtomicConfiguration = colabfit.tools.configuration.AtomicConfiguration

In [23]:
# Carolina Materials data

SOFTWARE = "VASP"
METHODS = "DFT-PBE"
CM_PI_METADATA = {
    "software": {"value": SOFTWARE},
    "method": {"value": METHODS},
    "input": {"value": {"IBRION": 6, "NFREE": 4}},
}

CM_PROPERTY_MAP = {
    "formation-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
        }
    ],
    "_metadata": CM_PI_METADATA,
}
CO_MD = {
    key: {"field": key}
    for key in [
        "_symmetry_space_group_name_H-M",
        "_symmetry_Int_Tables_number",
        "_chemical_formula_structural",
        "_chemical_formula_sum",
        "_cell_volume",
        "_cell_formula_units_Z",
        "symmetry_dict",
        "formula_pretty",
    ]
}


def load_row(txn, row):
    try:
        data = pickle.loads(txn.get(f"{row}".encode("ascii")))
        return data
    except TypeError:
        return False


def config_from_row(row: dict, row_num: int):
    coords = row.pop("cart_coords")
    a_num = row.pop("atomic_numbers")
    cell = [
        row.pop(x)
        for x in [
            "_cell_length_a",
            "_cell_length_b",
            "_cell_length_c",
            "_cell_angle_alpha",
            "_cell_angle_beta",
            "_cell_angle_gamma",
        ]
    ]
    symmetry_dict = {str(key): val for key, val in row.pop("symmetry_dict").items()}
    for key in symmetry_dict:
        key = str(key)
    info = {}
    info = row
    info["symmetry_dict"] = symmetry_dict
    info["_name"] = f"carolina_materials_{row_num}"
    if row_num % 10 == 0:
        info["_labels"] = [row_num % 10, "bcc"]
    else:
        info["_labels"] = [row_num % 10, "fcc"]
    config = AtomicConfiguration(
        scaled_positions=coords,
        numbers=a_num,
        cell=cell,
        info=info,
    )
    return config
    # return AtomicConfiguration.from_ase(config)


def carmat_reader(fp: Path):
    parent = fp.parent
    env = lmdb.open(str(parent))
    txn = env.begin()
    row_num = 0
    rows = []
    while row_num <= 100000:
        row = load_row(txn, row_num)
        if row is False:
            env.close()
            break
        rows.append(row)
        yield config_from_row(row, row_num)
        row_num += 1
    env.close()
    return False
    # return rows

In [24]:
PI_METADATA = {
    "software": {"value": "Quantum ESPRESSO"},
    "method": {"value": "DFT-PBE"},
    "input": {"field": "input"},
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            # "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            # "_metadata": PI_METADATA,
        },
    ],
    "cauchy-stress": [
        {
            "stress": {"field": "stress", "units": "GPa"},
            "volume-normalized": {"value": True, "units": None},
        }
    ],
    "_metadata": PI_METADATA,
}

# Connect to DB and run loader

In [23]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/21 14:49:53 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/21 14:49:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/21 14:49:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/21 14:49:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))

PI_METADATA = {
    "software": {"value": "Quantum ESPRESSO"},
    "method": {"value": "DFT-PBE"},
    "input": {"field": "input"},
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            # "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            # "_metadata": PI_METADATA,
        },
    ],
    "cauchy-stress": [
        {
            "stress": {"field": "stress", "units": "GPa"},
            "volume-normalized": {"value": True, "units": None},
        }
    ],
    "_metadata": PI_METADATA,
}
spark = SparkSession.builder.appName("ColabfitIngestData").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
# print(loader.spark)
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = list(mtpu_configs)
print(mtpu_configs[0])
co_po_rows = []
for config in tqdm(mtpu_configs):
    config.set_dataset_id(mtpu_ds_id)
    co_po_rows.append(
        (
            config.spark_row,
            Property.from_definition(
                [potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
                configuration=config,
                property_map=PROPERTY_MAP,
            ).spark_row,
        )
    )

24/05/20 17:14:41 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


AtomicConfiguration(name=Unified_training_set_SiO2_1061, symbols='Si4', pbc=False, cell=[[3.85085, 0.0, 0.077017], [-1.925425, 3.334933, -0.038508], [0.127258, 0.0, 6.362934]])


100%|██████████| 1062/1062 [00:00<00:00, 2495.36it/s]


In [24]:
config_df = loader.spark.read.jdbc(
    url=url, table="configurations", properties=properties
)

In [None]:
from colabfit.tools.database import batched


def gather_co_po_in_batches_no_pool(self):
    chunk_size = 10000
    config_chunks = batched(self.configs, chunk_size)
    for chunk in config_chunks:
        yield list(
            self._gather_co_po_rows(
                self.prop_defs, self.prop_map, self.dataset_id, chunk
            )
        )

In [None]:
from colabfit.tools.utilities import (
    add_elem_to_row_dict,
    unstringify_row_dict,
    stringify_row_dict,
    get_spark_field_type,
    spark_schema_to_arrow_schema,
    arrow_record_batch_to_rdd,
)
from functools import partial


def append_ith_element_to_rdd(row_elem):
    """
    row_elem: tuple created by joining two RDD.zipWithIndex
    new_co_ids: list of configuration ids
    """
    (index, (po_row, new_co_ids)) = row_elem
    val = po_row.get("configuration_ids")
    if val is None:
        val = new_co_ids
    else:
        val.extend(new_co_ids)
        val = list(set(val))
    po_row["configuration_ids"] = val
    return po_row


def find_existing_pos_append_elems(
    self,
    rdd,
    table_name: str,
    ids: list[str],
    cols: list[str],
    elems: list[str],
    write_schema: StructType,
):
    if isinstance(cols, str):
        cols = [cols]
    if isinstance(elems, str):
        elems = [elems]
    col_types = {"id": StringType(), "$row_id": IntegerType()}
    for col in cols:
        col_types[col] = get_spark_field_type(write_schema, col)
    update_cols = [col for col in col_types if col != "id"]
    query_schema = StructType(
        [
            StructField(col, col_types[col], False)
            for i, col in enumerate(cols + ["id", "$row_id"])
        ]
    )
    partial_batch_to_rdd = partial(arrow_record_batch_to_rdd, query_schema)
    batched_ids = batched(ids, 10000)
    new_ids = []
    existing_ids = []
    for id_batch in batched_ids:
        id_batch = list(set(id_batch))
        # We only have to use vastdb-sdk here bc we need the '$row_id' column
        with self.session.transaction() as tx:
            # string would be 'ndb.colabfit.dev.[table name]'
            table_path = table_name.split(".")
            table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
            rec_batch = table.select(
                predicate=table["id"].isin(id_batch),
                columns=cols + ["id"],
                internal_row_id=True,
            )
            rec_batch = rec_batch.read_all()
            rdd = self.spark.sparkContext.parallelize(
                list(partial_batch_to_rdd(rec_batch))
            )
            print(f"length of rdd: {rdd.count()}")
        rdd = rdd.map(unstringify_row_dict)
        for col, elem in zip(cols, elems):
            # Add 'labels' to this?
            if col == "configuration_ids":
                po_co_id_map = rdd.map(lambda x: {x["id"]: x[col][0]}).collect()
                rdd = rdd.zipWithIndex()
                co_ids = rdd.map(lambda x: po_co_id_map[x["id"]]).zipWithIndex()
                rdd = rdd.join(co_ids).map(append_ith_element_to_rdd)

            else:
                partial_add = partial(add_elem_to_row_dict, col, elem)
                rdd = rdd.map(partial_add)
        existing_ids_batch = rdd.map(lambda x: x["id"]).collect()
        new_ids_batch = [id for id in id_batch if id not in existing_ids_batch]
        rdd = rdd.map(stringify_row_dict)
        rdd_collect = rdd.map(lambda x: [x[col] for col in update_cols]).collect()
        update_schema = StructType(
            [StructField(col, col_types[col], False) for col in update_cols]
        )
        arrow_schema = spark_schema_to_arrow_schema(update_schema)
        update_table = pa.table(
            [pa.array(col) for col in zip(*rdd_collect)], schema=arrow_schema
        )
        with self.session.transaction() as tx:
            table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
            table.update(rows=update_table)
        new_ids.extend(new_ids_batch)
        existing_ids.extend(existing_ids_batch)

    return (new_ids, list(set(existing_ids)))

In [None]:
def reduce_po_rdd(po_rdd):
    po_co_ids = (
        po_rdd.map(lambda x: (x["id"], x["configuration_ids"][0]))
        .groupByKey()
        .mapValues(list)
    )
    po_id_map = po_co_ids.collectAsMap()
    broadcast_map = spark.sparkContext.broadcast(po_id_map)

    def replace_id_val(row):
        row["configuration_ids"] = broadcast_map.value[row["id"]]
        return row

    po_rdd = po_rdd.map(replace_id_val)
    po_rdd = (
        po_rdd.map(lambda x: (x["id"], x))
        .reduceByKey(lambda a, b: a)
        .map(lambda x: x[1])
    )
    return po_rdd

In [None]:
atomic_ratios_df = (
    config_df.select("atomic_numbers")
    .withColumn("exploded_atom", sf.explode("atomic_numbers"))
    .groupBy(sf.col("exploded_atom").alias("atomic_number"))
    .show()
    .count()
    .withColumn("ratio", sf.col("count") / row_dict["nsites"])
    .select("ratio", "atomic_number")
    .withColumn(
        "element",
        sf.udf(lambda x: ELEMENT_MAP[x], StringType())(sf.col("atomic_number")),
    )
    .select("element", "ratio")
    .collect()
)

In [None]:
from colabfit.tools.utilities import unstringify

batches = dm.gather_co_po_in_batches_no_pool()
batch1 = next(batches)
cos, pos = zip(*batch1)
pos_rdd = spark.sparkContext.parallelize(pos)
print(pos_rdd.count())
pos_rdd_reduced = reduce_po_rdd(pos_rdd)
print(pos_rdd_reduced.count())

po = loader.read_table(loader.prop_object_table)
po_rdd = po.rdd.map(unstringify).map(lambda x: x.asDict())
po_co_id_map = (
    po_rdd.map(lambda x: (x["id"], x["configuration_ids"][0]))
    .groupByKey()
    .mapValues(list)
    .collect()
)
po_co_id_map = dict(po_co_id_map)
co_ids = po_rdd.map(lambda x: po_co_id_map[x["id"]]).zipWithIndex()
co_ids = co_ids.map(lambda x: (x[1], x[0]))
rdd = po_rdd.zipWithIndex()
rdd = rdd.map(lambda x: (x[1], x[0]))
joined_rdd = rdd.join(co_ids).map(append_ith_element_of_list_to_spark_rdd_column)

In [None]:
from importlib import reload

import colabfit.tools.utilities
import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set

reload(colabfit.tools.utilities)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset

df.rdd  
you can only parallelize one time so don't try to do a dataframe select from an rdd  
updating to sdk 5.1 in a couple weeks  
boto3 and s3 are the amazon file system interactions, mostly for adding metadata TO FILES (not to the database) and interacting with the files as FileExistsError. 
Make sure to spark.stop() at end of  python file.

In [16]:
carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"
dm = DataManager(
    nprocs=4,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)

NameError: name 'carmat_reader' is not defined

In [26]:
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

Dataset ID: DS_y7nrdsjtuwom_0


In [12]:
batch = next(dm.gather_co_po_in_batches())

In [29]:
def update_co_rows_cs_id(self, co_ids: list[str], cs_id: str):
    with psycopg.connect(
        """dbname=colabfit user=%s password=%s host=localhost port=5432"""
        % (
            user,
            password,
        )
    ) as conn:
        # dbname=self.database_name,
        # user=self.properties["user"],
        # password=self.properties["password"],
        # host="localhost",
        # port="5432",
        cur = conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = 
            """
        )
        cur = conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = concat(%s::text, 
                rtrim(ltrim(replace(configuration_set_ids,%s,''), 
                
                '['),']') || ', ', %s::text)
            WHERE id = ANY(%s)""",
            ("[", f"{cs_id}", f"{cs_id}]", co_ids),
            # ("[", f", {cs_id}", f", {cs_id}]"),
        )
        # cur.fetchall()
        conn.commit()

In [None]:
# You were trying to get  postgresql to recognize the WHERE id = ANY() array syntax

In [112]:
with psycopg.connect(
    dbname="colabfit",
    user=os.environ.get("PGS_USER"),
    password=os.environ.get("PGS_PASS"),
    host="localhost",
) as conn:
    with conn.cursor() as cur:

        # cur.execute(
        #     "UPDATE configurations SET configuration_set_ids = configuration_set_ids || %(cs_id)s WHERE id = ANY(%(co_ids)s)",
        #     {"cs_id": cs["id"], "co_ids": co_ids},
        # )
        # data = cur.fetchall()
        cur.execute(
            "SELECT * FROM public.configurations WHERE id = ANY(%s)",
            [co_ids],
        )
        data2 = cur.fetchall()
    conn.commit()

# In progress

Upsert appears to be this for postgres:
```
update the_table
    set id = id || array[5,6]
where id = 4;
```
* ~~Check for upsert function from pyspark to concatenate lists of relationships instead of primary key id collision~~
* There is no pyspark-upsert function. Will have to manage this possibly through a different sql-based library
* Written: find duplicates, but convert to access database, not download full dataframe
* I see this being used with batches of hashes during upload: something like
    ``` for batch in batches:
            hash_duplicates = find_duplicates(batch, loader/database)
            hash_duplicates.make_change_to_append_dataset-ids
            hash_duplicates.write-to-database
* Where would be the best place to catch duplicates? Keeping in mind that this might be a bulk operation (i.e. on the order of millions, like with ANI1/ANI2x variations)

In [9]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/30 09:52:06 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/30 09:52:06 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/30 09:52:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/30 09:52:08 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

In [None]:
from pyspark.sql import Row


def write_value_to_file(path_prefix, extension, BUCKET_DIR, write_column, row):
    """i.e.: partial(_write_value(
    'CO/positions',
    'txt',
    '/save/here'
    'positions',
    )
    """
    id = row["id"]
    value = row[write_column]
    row_dict = row.copy()
    split = id[-4:]
    filename = f"{id}.{extension}"
    full_path = Path(BUCKET_DIR) / path_prefix / split / filename
    full_path.parent.mkdir(parents=True, exist_ok=True)
    full_path.write_text(str(value))
    # row_dict = row.asDict()
    row_dict[write_column] = str(full_path)
    return Row(**row_dict)


from functools import partial

part_write = partial(
    write_value_to_file,
    "CO/positions",
    "txt",
    "/scratch/gw2338/vast/data-lake-main/spark/scripts",
    "positions",
)

In [None]:
configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
co_rows = [x.spark_row for x in configs]
rdd = sc.parallelize(co_rows)
rdd.foreachPartition(part_write)

In [None]:
config_list = list(mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg")))
dm2.configs = config_list[:50]
dm2.load_co_po_to_vastdb(loader)
dm2.configs = config_list[25:]
dm2.load_co_po_to_vastdb(loader)

In [None]:
from importlib import reload

import colabfit.tools.utilities
import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set

reload(colabfit.tools.utilities)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset
##############

import json
import lmdb
import pickle
from colabfit.tools.database import DataManager, SparkDataLoader

loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
load_dotenv()
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
loader.set_vastdb_session(
    endpoint=endpoint, access_key=access_key, access_secret=access_secret
)

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)

carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"


dm = DataManager(
    nprocs=1,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
dm.configs = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))

match = [
    (r".*3.*", None, "3_configurations", "Carmat with 3"),
    (r".*4.*", None, "4_configurations", "Carmat with 4"),
]
# dm.load_co_po_to_vastdb(loader)

In [None]:
batches = dm2.gather_co_po_in_batches()
batch = next(batches)
cos, pos = zip(*batch)
rdd = loader.spark.sparkContext.parallelize(cos)
ids_coll = rdd.map(lambda x: x["id"]).collect()
loader.spark.read.table(loader.config_table).select(sf.col("id")).filter(
    sf.col("id").isin(broadcast_ids.value)
).count()

In [None]:
from colabfit.tools.utilities import unstringify
from colabfit.tools.schema import (
    configuration_set_df_schema,
    dataset_df_schema,
    property_object_df_schema,
    config_df_schema,
)


def read_table(self, table_name: str, unstring: bool = False):
    """
    Include self.table_prefix in the table name when passed to this function.
    Ex: loader.read_table(loader.config_table, unstring=True)
    Arguments:
        table_name {str} -- Name of the table to read from database
    Keyword Arguments:
        unstring {bool} -- Convert stringified lists to lists (default: {False})
    Returns:
        DataFrame -- Spark DataFrame
    """
    schema_dict = {
        self.config_table: config_df_schema,
        self.config_set_table: configuration_set_df_schema,
        self.dataset_table: dataset_df_schema,
        self.prop_object_table: property_object_df_schema,
    }
    if unstring:
        df = self.spark.read.table(table_name)
        return df.rdd.map(unstringify).toDF(schema_dict[table_name])
    else:
        return self.spark.read.table(table_name)


def get_pos_cos_by_filter(self, filter_conditions):
    po_df = self.read_table(self.prop_object_table, unstring=True).withColumnRenamed(
        "id", "po_id"
    )
    po_df = po_df.withColumn(
        "configuration_id", sf.explode(sf.col("configuration_ids"))
    ).drop("configuration_ids")
    co_df = self.read_table(self.config_table, unstring=True).withColumnRenamed(
        "id", "co_id"
    )
    for i, (column, operand, condition) in enumerate(filter_conditions):
        if operand == "in":
            po_df = po_df.filter(sf.col(column).isin(condition))
        elif operand == "like":
            po_df = po_df.filter(sf.col(column).like(condition))
        elif operand == "rlike":
            po_df = po_df.filter(sf.col(column).rlike(condition))
        elif operand == "==":
            po_df = po_df.filter(sf.col(column) == condition)
        elif operand == "array_contains":
            po_df = po_df.filter(sf.array_contains(sf.col(column), condition))
        elif operand == ">":
            po_df = po_df.filter(sf.col(column) > condition)
        elif operand == "<":
            po_df = po_df.filter(sf.col(column) < condition)
        elif operand == ">=":
            po_df = po_df.filter(sf.col(column) >= condition)
        elif operand == "<=":
            po_df = po_df.filter(sf.col(column) <= condition)
        else:
            raise ValueError(f"Operand {operand} not implemented in get_pos_cos_filter")
    co_po_df = co_df.join(po_df, co_df["co_id"] == po_df["configuration_id"], "inner")
    return co_po_df


get_pos_cos_by_filter(
    loader, [("dataset_ids", "array_contains", mtpu_ds_id), ("method", "like", "DFT%")]
)
df = get_pos_cos_by_filter(
    loader,
    [
        ("dataset_ids", "array_contains", mtpu_ds_id),
        ("method", "like", "DFT%"),
        ("potential_energy", "<", -56729.0),
    ],
    [("nsites", ">=", 63)],
)


def get_pos_cos_by_filter(
    self,
    po_filter_conditions: list[tuple],
    co_filter_conditions: list[tuple] = None,
):
    po_df = self.read_table(self.prop_object_table, unstring=True).withColumnRenamed(
        "id", "po_id"
    )
    po_df = po_df.withColumn(
        "configuration_id", sf.explode(sf.col("configuration_ids"))
    ).drop("configuration_ids")
    co_df = self.read_table(self.config_table, unstring=True).withColumnRenamed(
        "id", "co_id"
    )
    po_df = get_filtered_table(self, po_df, po_filter_conditions)
    if co_filter_conditions is not None:
        co_df = get_filtered_table(self, co_df, co_filter_conditions)
    co_po_df = co_df.join(po_df, co_df["co_id"] == po_df["configuration_id"], "inner")
    return co_po_df


def get_filtered_table(
    self, df: DataFrame, filter_conditions: list[tuple[str, str, str]]
):
    for i, (column, operand, condition) in enumerate(filter_conditions):
        if operand == "in":
            df = df.filter(sf.col(column).isin(condition))
        elif operand == "like":
            df = df.filter(sf.col(column).like(condition))
        elif operand == "rlike":
            df = df.filter(sf.col(column).rlike(condition))
        elif operand == "==":
            df = df.filter(sf.col(column) == condition)
        elif operand == "array_contains":
            df = df.filter(sf.array_contains(sf.col(column), condition))
        elif operand == ">":
            df = df.filter(sf.col(column) > condition)
        elif operand == "<":
            df = df.filter(sf.col(column) < condition)
        elif operand == ">=":
            df = df.filter(sf.col(column) >= condition)
        elif operand == "<=":
            df = df.filter(sf.col(column) <= condition)
        else:
            raise ValueError(f"Operand {operand} not implemented in get_pos_cos_filter")
    return df


df1 = read_filter_table(loader, [("id", "==", "CO_47706510123393079")])

In [29]:
dataset_id = mtpu_ds_id
name_label_match = [
    (".*Si.*3.*", None, "All_si_with_zero", "All Si with zero description"),
    (".*Si.*4.*", None, "All_si_with_two", "All Si with two description"),
]

In [None]:
begin = time()
dm.create_configuration_sets(loader, match)
end = time()
print(f"Time elapsed: {end - begin}")

In [None]:
def create_configuration_sets(
    self,
    loader,
    # below args in order:
    # [config-name-regex-pattern], [config-label-regex-pattern], \
    # [config-set-name], [config-set-description]
    name_label_match: list[tuple],
):
    config_set_rows = []
    # Load unstrung dataframe of configs, filter for just includes ds-id
    config_df = loader.read_table(table_name=loader.config_table, unstring=True)
    config_df = config_df.filter(
        sf.array_contains(sf.col("dataset_ids"), self.dataset_id)
    )
    for i, (names_match, label_match, cs_name, cs_desc) in tqdm(
        enumerate(name_label_match), desc="Creating Configuration Sets"
    ):
        print(
            f"names match: {names_match}, label {label_match}, cs_name {cs_name}, cs_desc {cs_desc}"
        )
        if names_match:
            config_set_query = config_df.withColumn(
                "names_exploded", sf.explode(sf.col("names"))
            ).filter(sf.col("names_exploded").rlike(names_match))
        # Currently an AND operation on labels: labels col contains x AND y
        if label_match is not None:
            if isinstance(label_match, str):
                label_match = [label_match]
            for label in label_match:
                config_set_query = config_set_query.filter(
                    sf.array_contains(sf.col("labels"), label)
                )
        co_ids = [x["id"] for x in config_set_query.select("id").distinct().collect()]
        loader.find_existing_rows_append_elem(
            table_name=loader.config_table,
            ids=co_ids,
            cols="configuration_set_ids",
            elems=cs_name,
            edit_schema=config_df_schema,
            write_schema=config_schema,
        )
        config_set = ConfigurationSet(
            name=cs_name,
            description=cs_desc,
            config_df=config_set_query,
            dataset_id=self.dataset_id,
        )
        config_set_rows.append(config_set.spark_row)
    loader.write_table(
        config_set_rows, loader.config_set_table, schema=configuration_set_schema
    )

In [None]:
from colabfit.tools.utilities import _write_value
from functools import partial

In [None]:
part_write = partial(
    _write_value,
    "CO/positions",
    "txt",
    "/scratch/gw2338/vast/data-lake-main/spark/scripts",
    "positions",
)

In [None]:
from schema import dataset_schema
from colabfit.tools.dataset import Dataset


def create_dataset(
    self,
    loader,
    name: str,
    authors: list[str],
    publication_link: str,
    data_link: str,
    description: str,
    other_links: list[str] = None,
    dataset_id: str = None,
    labels: list[str] = None,
    data_license: str = "CC-BY-ND-4.0",
):
    cs_ids = loader.read_table(loader.config_set_table).select("id").collect()
    if len(cs_ids) == 0:
        cs_ids = None
    else:
        cs_ids = [x["id"] for x in cs_ids]
    config_df = loader.read_table(loader.config_table, unstring=True)
    config_df = config_df.filter(sf.array_contains(sf.col("dataset_ids"), dataset_id))
    prop_df = loader.read_table(loader.prop_object_table, unstring=True)
    prop_df = prop_df.filter(sf.array_contains(sf.col("dataset_ids"), dataset_id))
    ds = Dataset(
        name=name,
        authors=authors,
        config_df=config_df,
        prop_df=prop_df,
        publication_link=publication_link,
        data_link=data_link,
        description=description,
        other_links=other_links,
        dataset_id=dataset_id,
        labels=labels,
        data_license=data_license,
        configuration_set_ids=cs_ids,
    )
    loader.write_table([ds.spark_row], loader.dataset_table, schema=dataset_schema)

In [None]:
t0 = time()
create_dataset(
    dm,
    loader,
    "carolina_materials",
    ["author one", "author two"],
    "https://www.carolina_materials.com",
    "https://www.carolina_materials.com/data",
    "Carolina Materials is a ... description",
    dataset_id=dm.dataset_id,
)
print(f"Time elapsed: {time() - t0}")

In [None]:
from importlib import reload

import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set
import colabfit.tools.schema

reload(colabfit.tools.configuration_set)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
reload(colabfit.tools.schema)
configuration_set_schema = colabfit.tools.schema.configuration_set_schema
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset

In [None]:
def find_duplicate_hash(spark_rows: dict, loader):
    # hashes = loader.spark.createDataFrame([x["hash"] for x in spark_rows])
    hashes = [x["hash"] for x in spark_rows]
    duplicates = loader.spark.read.jdbc(
        url=url,
        table="configurations",
        properties=properties,
    ).filter(sf.col("hash").isin(hashes))
    # dupl_hashes = df.filter(df.hash.isin(hashes)).select("hash").collect()
    return duplicates