# Imports

In [2]:
import datetime
import json
from time import time
import os
from collections import defaultdict
import pickle
from tqdm import tqdm

# from functools import partial
# from itertools import chain, islice
# from multiprocessing import Pool, cpu_count
from pathlib import Path

# from pprint import pprint

import dateutil.parser
import findspark
import lmdb
import numpy as np
import psycopg
import pyspark.sql.functions as sf
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from colabfit.tools.schema import (
    property_object_schema,
    config_df_schema,
    config_schema,
    property_object_df_schema,
)
from colabfit.tools.configuration import AtomicConfiguration, config_schema
from colabfit.tools.database import DataManager, PGDataLoader
from colabfit.tools.dataset import Dataset, dataset_schema
from colabfit.tools.property import Property, property_object_schema
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)
from colabfit.tools.schema import configuration_set_schema
import pyarrow as pa

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)
findspark.init()
format = "jdbc"
load_dotenv("./.env")

True

In [19]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
data = list(mtpu_configs)
# data = [x for x in mtpu_configs]
# data[0].configuration_summary()

In [20]:
import colabfit.tools.configuration
from importlib import reload

reload(colabfit.tools.configuration)
AtomicConfiguration = colabfit.tools.configuration.AtomicConfiguration

# Connect to DB and run loader

In [23]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/21 14:49:53 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/21 14:49:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/21 14:49:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/21 14:49:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
import os
import pickle
import sys
import time
from pathlib import Path

from ase.io import iread
from dotenv import load_dotenv
from tqdm import tqdm

from colabfit.tools.configuration import AtomicConfiguration
from colabfit.tools.database import DataManager, SparkDataLoader
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    free_energy_pd,
    potential_energy_pd,
)

load_dotenv()
loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
# loader.set_vastdb_session(
PKL_FP = Path("data/oc20_data_mapping.pkl")
with open(PKL_FP, "rb") as f:
    OC20_MAP = pickle.load(f)

In [181]:
#     endpoint=endpoint, access_key=access_key, access_secret=access_secret
# )

DATASET_FP = Path("data/s2ef_train_200K/s2ef_train_200K")
DATASET_NAME = "OC20_S2EF_train_200K"
LICENSE = "https://creativecommons.org/licenses/by/4.0/legalcode"
PUBLICATION = "https://doi.org/10.1021/acscatal.0c04525"
DATA_LINK = (
    "https://github.com/Open-Catalyst-Project/ocp/blob"
    "/main/DATASET.md#open-catalyst-2020-oc20"
)
AUTHORS = [
    "Lowik Chanussot",
    "Abhishek Das",
    "Siddharth Goyal",
    "Thibaut Lavril",
    "Muhammed Shuaibi",
    "Morgane Riviere",
    "Kevin Tran",
    "Javier Heras-Domingo",
    "Caleb Ho",
    "Weihua Hu",
    "Aini Palizhati",
    "Anuroop Sriram",
    "Brandon Wood",
    "Junwoong Yoon",
    "Devi Parikh",
    "C. Lawrence Zitnick",
    "Zachary Ulissi",
]
DATASET_DESC = (
    "OC20_S2EF_train_200K is the 200K subset of the OC20 Structure to Energy and "
    "Forces dataset. "
)
ELEMENTS = None


GLOB_STR = "*.extxyz"
PI_METADATA = {
    "software": {"value": "VASP"},
    "method": {"value": "DFT-rPBE"},
    "basis_set": {"value": "def2-TZVPP"},
    "input": {
        "value": {
            "EDIFFG": "1E-3",
        },
    },
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "reference-energy": {"field": "reference_energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            "_metadata": PI_METADATA,
        }
    ],
    "free-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            "_metadata": PI_METADATA,
        },
    ],
}
CO_METADATA = {
    key: {"field": key}
    for key in [
        "constraints",
        "bulk_id",
        "ads_id",
        "bulk_mpid",
        "bulk_symbols",
        "ads_symbols",
        "miller_index",
        "shift",
        "top",
        "adsorption_site",
        "class",
        "anomaly",
        "system_id",
        "frame_number",
    ]
}


def oc_reader(fp: Path):
    """Yield one ``AtomicConfiguration`` per frame of an OC20 extxyz file.

    A sidecar file with the same path but a ``.txt`` suffix is read alongside
    ``fp``. Line ``i`` of that file must hold three comma-separated values,
    ``system_id,frame_number,reference_energy``, describing frame ``i``.

    Args:
        fp: Path to the extxyz file. Its stem must parse as an integer; that
            integer is embedded in each configuration's ``_name``.

    Yields:
        The result of ``AtomicConfiguration.from_ase(config, CO_METADATA)`` for
        each frame, after the frame's ``info`` dict has been extended with the
        ``OC20_MAP`` entry for its system id, the reference energy, the system
        id, the frame number, the fixed-atom indices and a generated ``_name``.

    Raises:
        ValueError: If the stem is not an integer, a sidecar line does not
            split into exactly three fields, or the energy is not a float.
        IndexError: If a frame has no constraints or the sidecar file has fewer
            lines than the extxyz file has frames.
        KeyError: If a system id is missing from ``OC20_MAP``.
    """
    fp_num = f"{int(fp.stem)}"
    prop_fp = fp.with_suffix(".txt")
    # Read the sidecar file completely and close it before yielding anything,
    # so no file handle is held open while the generator is suspended.
    with prop_fp.open("r") as prop_f:
        prop_lines = [line.strip() for line in prop_f]

    iter_configs = iread(fp, format="extxyz")
    for i, config in tqdm(enumerate(iter_configs)):
        system_id, frame_number, reference_energy = prop_lines[i].split(",")
        reference_energy = float(reference_energy)
        # Only the first constraint is consulted; its ``index`` attribute is
        # stored as-is.
        config.info["constraints-fix-atoms"] = config.constraints[0].index
        config_data = OC20_MAP[system_id]
        config.info.update(config_data)
        config.info["reference_energy"] = reference_energy
        config.info["system_id"] = system_id
        config.info["frame_number"] = frame_number
        config.info["_name"] = f"{DATASET_NAME}__file_{fp_num}_config_{i}"
        yield AtomicConfiguration.from_ase(config, CO_METADATA)


def oc_wrapper(dir_path: str):
    """Stream configurations from every ``*.extxyz`` file below ``dir_path``.

    The directory is searched recursively and the files are visited in sorted
    path order, each one being delegated to ``oc_reader``. When the directory
    does not exist a message is printed and nothing is yielded.
    """
    dir_path = Path(dir_path)
    if not dir_path.exists():
        print(f"Path {dir_path} does not exist")
        return
    xyz_paths = sorted(dir_path.rglob("*.extxyz"))
    print(xyz_paths)
    for xyz_path in xyz_paths:
        print(f"Reading {xyz_path}")
        yield from oc_reader(xyz_path)

In [153]:
AtomicConfiguration(co_md_map=CO_METADATA, info=info, **cdict)

AtomicConfiguration(symbols='C2H3Ir24ORe24Ti48', pbc=True, cell=[[12.46844256, 0.0, -0.0], [-6.23422128, 13.89198995, -1.15766583], [0.0, 0.0, 33.57230904]], tags=...)

In [158]:
from importlib import reload
import colabfit.tools.configuration

reload(colabfit.tools.configuration)
AtomicConfiguration = colabfit.tools.configuration.AtomicConfiguration

In [None]:
batches = dm.gather_co_po_in_batches()
batch = next(batches)
cos, pos = zip(*batch)
po_rdd = sc.parallelize(pos[:100])

In [None]:
from importlib import reload

import colabfit.tools.utilities
import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set
import colabfit.tools.property

reload(colabfit.tools.utilities)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
reload(colabfit.tools.property)
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset
Property = colabfit.tools.property.Property

`df.rdd`  
You can only parallelize once, so don't try to do a DataFrame select from an RDD.  
Updating to SDK 5.1 in a couple of weeks.  
boto3 and S3 are the Amazon file-system interactions, mostly for adding metadata TO FILES (not to the database) and for interacting with the files themselves.  
Make sure to call `spark.stop()` at the end of the Python file.

In [29]:
def update_co_rows_cs_id(self, co_ids: list[str], cs_id: str):
    """Append ``cs_id`` to ``configuration_set_ids`` for the given configurations.

    ``configuration_set_ids`` is treated as bracketed text: the statement
    removes any existing occurrence of ``cs_id``, strips the surrounding
    ``[`` / ``]``, appends ``", <cs_id>]"`` and puts the leading ``[`` back.

    Args:
        co_ids: Ids of the ``configurations`` rows to update; matched with
            ``WHERE id = ANY(...)``.
        cs_id: Configuration-set id to append.

    Note:
        Connection credentials come from the module-level ``user`` and
        ``password`` variables; ``self`` is not read.
    """
    # Keyword arguments instead of a %-formatted connection string, so the
    # credentials are passed through without any string interpolation.
    with psycopg.connect(
        dbname="colabfit",
        user=user,
        password=password,
        host="localhost",
        port=5432,
    ) as conn:
        # A truncated ``UPDATE ... SET configuration_set_ids =`` statement used
        # to be executed here first; it was incomplete SQL and made every call
        # fail before reaching the real update, so it has been removed.
        conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = concat(%s::text, 
                rtrim(ltrim(replace(configuration_set_ids,%s,''), 
                
                '['),']') || ', ', %s::text)
            WHERE id = ANY(%s)""",
            ("[", f"{cs_id}", f"{cs_id}]", co_ids),
        )
        conn.commit()

In [None]:
# NOTE to self: you were trying to get PostgreSQL to recognize the `WHERE id = ANY(...)` array syntax.

In [112]:
with psycopg.connect(
    dbname="colabfit",
    user=os.environ.get("PGS_USER"),
    password=os.environ.get("PGS_PASS"),
    host="localhost",
) as conn:
    with conn.cursor() as cur:

        # cur.execute(
        #     "UPDATE configurations SET configuration_set_ids = configuration_set_ids || %(cs_id)s WHERE id = ANY(%(co_ids)s)",
        #     {"cs_id": cs["id"], "co_ids": co_ids},
        # )
        # data = cur.fetchall()
        cur.execute(
            "SELECT * FROM public.configurations WHERE id = ANY(%s)",
            [co_ids],
        )
        data2 = cur.fetchall()
    conn.commit()

# In progress

Upsert appears to be this for postgres:
```
update the_table
    set id = id || array[5,6]
where id = 4;
```
* ~~Check for upsert function from pyspark to concatenate lists of relationships instead of primary key id collision~~
* There is no pyspark-upsert function. Will have to manage this possibly through a different sql-based library
* Written: find duplicates, but convert to access database, not download full dataframe
* I see this being used with batches of hashes during upload, something like:
    ```
    for batch in batches:
        hash_duplicates = find_duplicates(batch, loader/database)
        hash_duplicates.make_change_to_append_dataset-ids
        hash_duplicates.write-to-database
    ```
* Where would be the best place to catch duplicates? Keeping in mind that this might be a bulk operation (i.e. on the order of millions, like with ANI1/ANI2x variations)

In [9]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/30 09:52:06 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/30 09:52:06 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/30 09:52:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/30 09:52:08 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

In [None]:
from pyspark.sql import Row


def write_value_to_file(path_prefix, extension, BUCKET_DIR, write_column, row):
    """Write one column of ``row`` to its own file and point the row at that file.

    Meant to be partially applied before being mapped over rows, e.g.::

        partial(
            write_value_to_file,
            "CO/positions",
            "txt",
            "/save/here",
            "positions",
        )

    The value is written as ``str(value)`` to
    ``<BUCKET_DIR>/<path_prefix>/<last 4 chars of id>/<id>.<extension>``;
    missing parent directories are created.

    Args:
        path_prefix: Sub-directory below ``BUCKET_DIR``.
        extension: File extension, without the leading dot.
        BUCKET_DIR: Root directory for the written files.
        write_column: Name of the column whose value is written out.
        row: A mapping or a pyspark ``Row`` holding at least ``"id"`` and
            ``write_column``. It is not modified.

    Returns:
        A new ``Row`` with the same fields as ``row`` except that
        ``write_column`` holds the path of the written file as a string.
    """
    row_id = row["id"]
    value = row[write_column]
    # A pyspark Row is a tuple subclass and has no ``copy``; use ``asDict``
    # when it is available and fall back to ``copy`` for plain mappings.
    row_dict = row.asDict() if hasattr(row, "asDict") else row.copy()
    # The last four characters of the id pick the sub-directory.
    split = row_id[-4:]
    filename = f"{row_id}.{extension}"
    full_path = Path(BUCKET_DIR) / path_prefix / split / filename
    full_path.parent.mkdir(parents=True, exist_ok=True)
    full_path.write_text(str(value))
    row_dict[write_column] = str(full_path)
    return Row(**row_dict)


from functools import partial

part_write = partial(
    write_value_to_file,
    "CO/positions",
    "txt",
    "/scratch/gw2338/vast/data-lake-main/spark/scripts",
    "positions",
)

In [None]:
configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
co_rows = [x.spark_row for x in configs]
rdd = sc.parallelize(co_rows)
rdd.foreachPartition(part_write)

In [None]:
config_list = list(mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg")))
dm2.configs = config_list[:50]
dm2.load_co_po_to_vastdb(loader)
dm2.configs = config_list[25:]
dm2.load_co_po_to_vastdb(loader)

In [None]:
from importlib import reload

import colabfit.tools.utilities
import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set
import colabfit.tools.schema

reload(colabfit.tools.utilities)
reload(colabfit.tools.schema)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset
property_object_df_schema = colabfit.tools.schema.property_object_df_schema
property_object_schema = colabfit.tools.schema.property_object_schema
##############

import json
import lmdb
import pickle
from colabfit.tools.database import DataManager, SparkDataLoader

loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
load_dotenv()
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
loader.set_vastdb_session(
    endpoint=endpoint, access_key=access_key, access_secret=access_secret
)

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)

carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"


dm = DataManager(
    nprocs=1,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
dm.configs = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))

match = [
    (r".*3.*", None, "3_configurations", "Carmat with 3"),
    (r".*4.*", None, "4_configurations", "Carmat with 4"),
]
# dm.load_co_po_to_vastdb(loader)

In [None]:
# Scratch cell: take the first CO/PO batch, collapse duplicate property objects
# and count how many of the batch's configuration ids already exist in the table.
batches = dm.gather_co_po_in_batches()
batch = next(batches)
cos, pos = zip(*batch)
po_rdd = sc.parallelize(pos)
# reduce_po_rdd as defined in this notebook takes only the RDD; the extra
# ``match`` argument that used to be passed here raised a TypeError.
po_rdd = reduce_po_rdd(po_rdd)
kv_rdd = po_rdd.map(lambda x: (x["id"], x))
sum_rdd = kv_rdd.mapValues(lambda x: x["multiplicity"]).reduceByKey(lambda a, b: a + b)
join_rdd = kv_rdd.join(sum_rdd)
# Each joined value is (row, summed_multiplicity); overwrite the row's count.
result = join_rdd.map(lambda x: {**x[1][0], "multiplicity": x[1][1]})


rdd = loader.spark.sparkContext.parallelize(cos)
ids_coll = rdd.map(lambda x: x["id"]).collect()
# ``broadcast_ids`` was referenced below without ever being created; broadcast
# the collected ids so the filter has something to read.
broadcast_ids = loader.spark.sparkContext.broadcast(ids_coll)
loader.spark.read.table(loader.config_table).select(sf.col("id")).filter(
    sf.col("id").isin(broadcast_ids.value)
).count()

In [None]:
def reduce_po_rdd(po_rdd):
    """Collapse rows that share an ``id`` into one row with summed multiplicity.

    Rows are grouped by ``id``. For each group a copy of the first row seen is
    kept and the ``multiplicity`` values of the remaining rows are added to it.
    The returned RDD holds one such merged row per distinct id.
    """

    def merge_rows(rows):
        merged = None
        for row in rows:
            if merged is not None:
                merged["multiplicity"] += row["multiplicity"]
                continue
            # First row of the group becomes the template for the result.
            merged = row.copy()
            merged["multiplicity"] = row["multiplicity"]
        return merged

    keyed = po_rdd.map(lambda row: (row["id"], row))
    merged_by_id = keyed.groupByKey().mapValues(lambda rows: merge_rows(rows))
    return merged_by_id.map(lambda pair: pair[1])

In [None]:
import re
from pyspark.sql.functions import col, sum as spark_sum

In [1]:
from colabfit.tools.schema import *

set(list(property_object_schema.fieldNames())) - set(
    property_object_df_schema.fieldNames()
)

set()

In [None]:
atomic_re = re.compile(r"atomic_forces_(\d+)")
atomic_columns = [col for col in prop_df.columns if atomic_re.match(col)]
atomic_df = prop_df.select(*atomic_columns)
non_null_counts = atomic_df.select(
    [spark_sum(col(c).isNotNull().cast("int")).alias(c) for c in atomic_df.columns]
)
total_populated_cells = non_null_counts.select(
    sum([col(c) for c in non_null_counts.columns]).alias("total_populated_cells")
).collect()[0][0]

In [None]:
from types import GeneratorType
from itertools import islice
import pyspark.sql.functions as sf
import dateutil.parser
import datetime


def batched(configs, n):
    """Yield successive lists of up to ``n`` items taken from ``configs``.

    ``configs`` may be any iterable, including a generator. Every yielded list
    has ``n`` items except possibly the last, which may be shorter. Nothing is
    yielded for an empty input.
    """
    source = iter(configs)
    while chunk := list(islice(source, n)):
        yield chunk

In [None]:
def zero_multiplicity(
    self, dataset_id, table_name="ndb.colabfit.dev.gpw_test_propobjects"
):
    """Set ``multiplicity`` to 0 for every property-object row of a dataset.

    The ids belonging to ``dataset_id`` are collected through Spark from
    ``self.prop_object_table``. They are then processed in batches of 1000:
    each batch is read back from ``table_name`` inside a transaction on
    ``self.session``, its ``multiplicity`` is set to 0 and its
    ``last_modified`` to the current UTC time (second resolution), and the two
    columns are written back with ``table.update``.

    Args:
        dataset_id: Dataset whose property objects are reset.
        table_name: Dotted table name; parts 1, 2 and 3 (zero-based) are used
            as bucket, schema and table. Defaults to the table that was
            previously hard-coded here.

    Note:
        ``self`` is expected to be the loader (it must provide ``spark``,
        ``session`` and ``prop_object_table``). Earlier revisions read the ids
        through the global ``loader`` while using ``self`` for everything else.
    """
    id_rows = (
        self.spark.read.table(self.prop_object_table)
        .filter(sf.col("dataset_id") == dataset_id)
        .select("id")
        .collect()
    )
    ids = [row["id"] for row in id_rows]
    table_path = table_name.split(".")
    for id_batch in batched(ids, 1000):
        id_batch = list(set(id_batch))
        with self.session.transaction() as tx:
            table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
            # internal_row_id=True is kept from the original read so that the
            # rows handed back to ``update`` carry whatever row identifier the
            # select returns. NOTE(review): confirm against the VAST DB SDK
            # that ``update`` locates rows through that column.
            existing = table.select(
                predicate=table["id"].isin(id_batch),
                columns=["id", "multiplicity", "last_modified"],
                internal_row_id=True,
            ).read_all()
            print(f"length of df: {existing.num_rows}")
            if existing.num_rows == 0:
                continue
            update_time = dateutil.parser.parse(
                datetime.datetime.now(tz=datetime.timezone.utc).strftime(
                    "%Y-%m-%dT%H:%M:%SZ"
                )
            )
            rows = existing.to_pylist()
            for row in rows:
                row["multiplicity"] = 0
                row["last_modified"] = update_time
            # Reuse the schema that came back from the table rather than
            # declaring column types by hand. The previous hand-written schema
            # mixed a pyspark StructType with pyarrow types and referenced
            # ``pa.integer``, which does not exist.
            update_table = pa.RecordBatch.from_pylist(rows, schema=existing.schema)
            table.update(
                rows=update_table,
                columns=["multiplicity", "last_modified"],
            )

In [None]:
zero_multiplicity(loader, ds_id)

In [None]:
spark.sql("show tables in ndb.colabfit.dev").show()
spark.sql("drop table ndb.colabfit.dev.sample_configs")
spark.sql("drop table ndb.colabfit.dev.sample_prop_objects")
spark.sql("drop table ndb.colabfit.dev.sample_config_sets")

In [None]:
from colabfit.tools.utilities import *
from colabfit.tools.database import *
from colabfit.tools.schema import *
from functools import partial

self = loader
batches = dm.gather_co_po_in_batches()
batch = next(batches)

co_rows, po_rows = list(zip(*batch))
co_rdd = loader.spark.sparkContext.parallelize(co_rows)
po_rdd = loader.spark.sparkContext.parallelize(po_rows)
co_ids = co_rdd.map(lambda x: x["id"]).collect()
co_rdd2 = (
    co_rdd.map(lambda x: (x["id"], x)).reduceByKey(lambda a, b: a).map(lambda x: x[1])
)


po_ids = po_rdd.map(lambda x: x["id"]).collect()
if len(set(po_ids)) < len(po_ids):
    print(f"{len(po_ids) - len(set(po_ids))} duplicates found in PO RDD")
    po_rdd = loader.reduce_po_rdd(po_rdd)
co_ids = set(co_ids)
all_unique_co = loader.check_unique_ids(loader.config_table, co_ids)
all_unique_po = loader.check_unique_ids(loader.prop_object_table, po_ids)

new_po_ids, update_po_ids = loader.find_existing_po_rows_append_elem(
    po_rdd=po_rdd,
    ids=po_ids,
)
loader.write_table(
    po_rdd,
    loader.prop_object_table,
    property_object_schema,
    ids_filter=new_po_ids,
)

cos = list(cos)
cos.extend([cos[0], cos[1]])
co_rdd = sc.parallelize(cos)

In [None]:
import pyspark.sql.functions as sf
from colabfit.tools.schema import *
from colabfit.tools.utilities import *
from colabfit.tools.utilities import _empty_dict_from_schema
import dateutil.parser
import datetime

prop_df = loader.read_table(loader.prop_object_table, unstring=True)
prop_df = prop_df.filter(sf.col("dataset_id") == "DS_y7nrdsjtuwom_0")
config_df = loader.read_table(loader.config_table, unstring=True)
config_df = config_df.filter(
    sf.array_contains(sf.col("dataset_ids"), "DS_y7nrdsjtuwom_0")
)

In [None]:
from collections import Counter

# Count how often each configuration id appears among the property objects.
# The ids must be collected before they are counted, and the Row -> str
# extraction must run only once (a second pass would index plain strings).
ids = prop_df.select("configuration_id").collect()
ids = [x["configuration_id"] for x in ids]
idc = Counter(ids)


# Inspect the property values recorded for one specific configuration.
co_po_df.filter(sf.col("configuration_id") == "CO_1077499488010994550").select(
    "potential_energy", "potential_energy_reference", "atomic_forces", "cauchy_stress"
).show()

In [33]:
import numpy as np
from functools import partial


def generate_random_float_array(n):
    """Return an ``n`` x 3 array of random floats from ``np.random.rand``."""
    return np.random.rand(n, 3)


def split_string(s, max_length=60000):
    """Cut ``s`` into consecutive pieces of at most ``max_length`` characters.

    ``None`` is passed through as ``[None]``; an empty string gives ``[]``.
    """
    if s is not None:
        offsets = range(0, len(s), max_length)
        return [s[start : start + max_length] for start in offsets]
    return [None]


# Render a large random array as one untruncated, newline-free string and cut
# it into chunks with split_string.
np.set_printoptions(threshold=np.inf)
x = generate_random_float_array(10000)
# ``np.arr`` does not exist; ``x`` is already an ndarray, so it is passed
# directly. array2string already returns a str, so no join is needed.
xstr = np.array2string(x, separator=",").replace("\n", "")
splx = split_string(xstr)

In [34]:
batches = dm.gather_co_po_in_batches()
batch = next(batches)
cos, pos = zip(*batch)
po_rdd = sc.parallelize(pos[:100])
po_rdd2 = check_size_n_arrs_and_split(po_rdd, "atomic_forces")
keys = list(po_rdd2.map(lambda x: x.keys()).take(1)[0])
schema = property_object_schema
extra_keys = [x for x in keys if x not in schema.fieldNames()]
for key in extra_keys:
    schema.add(StructField(key, StringType(), True))

In [41]:
f"{(1+2):02d}"

'03'

In [38]:
def split_size_n_arrs(arr, n):
    """Slice ``arr`` into consecutive pieces of length ``n`` (last may be shorter).

    Works on anything that supports ``len`` and slicing. ``None`` is returned
    as ``[None]``.
    """
    return (
        [None]
        if arr is None
        else [arr[lo : lo + n] for lo in range(0, len(arr), n)]
    )


def split_arr_map(column, row_dict):
    """Replace ``row_dict[column]`` with its string form cut into 60000-char chunks.

    The value is converted to a NumPy array, rendered untruncated with ``,`` as
    separator, stripped of newlines, and split with ``split_size_n_arrs``.
    ``row_dict`` is modified in place and also returned.
    """
    as_array = np.array(row_dict[column])
    rendered = np.array2string(as_array, threshold=np.inf, separator=",")
    one_line = rendered.replace("\n", "")
    row_dict[column] = split_size_n_arrs(one_line, 60000)
    return row_dict


def flatten_arr(column, max_chunks, row_dict):
    """Spread the chunk list stored under ``column`` over numbered keys.

    The first chunk stays under ``column``; chunk ``k`` (k >= 2) is stored
    under ``f"{column}_{k}"``. Keys up to ``f"{column}_{max_chunks}"`` that
    have no chunk are set to ``None``. ``row_dict`` is modified in place and
    returned. An empty chunk list raises ``KeyError``.
    """
    chunks = row_dict[column]
    n_chunks = len(chunks)
    for position, chunk in enumerate(chunks, start=1):
        row_dict[f"{column}_{position}"] = chunk
    # The first chunk moves back under the plain column name.
    row_dict[column] = row_dict.pop(f"{column}_1")
    for position in range(n_chunks + 1, max_chunks + 1):
        row_dict[f"{column}_{position}"] = None
    return row_dict


def split_size_n_arrs_to_cols(rdd, column):
    """Chunk ``column`` in every row and spread the chunks over numbered columns.

    Each row goes through ``split_arr_map``; the largest chunk count found in
    the RDD is then handed to ``flatten_arr`` so that every row ends up with
    the same set of ``<column>_<k>`` keys.
    """
    chunked = rdd.map(lambda row: split_arr_map(column, row))
    widest = chunked.map(lambda row: len(row[column])).max()
    return chunked.map(lambda row: flatten_arr(column, widest, row))


def update_schema(rdd, schema):
    """Add a nullable string field to ``schema`` for every unknown row key.

    Only the first row of ``rdd`` is inspected. Every key of that row that is
    not already among ``schema.fieldNames()`` is added through ``schema.add``
    as ``StructField(key, StringType(), True)``.

    Args:
        rdd: RDD of dict-like rows.
        schema: Schema object that ``add`` is called on; the same object is
            returned.

    Returns:
        ``schema``. It is returned untouched when ``rdd`` is empty (this used
        to raise ``IndexError``).
    """
    first_rows = rdd.take(1)
    if not first_rows:
        return schema
    # Snapshot the known names once, before any field is added.
    known = set(schema.fieldNames())
    for key in first_rows[0].keys():
        if key not in known:
            schema.add(StructField(key, StringType(), True))
    return schema

60000

In [None]:
import os
import pickle
from pathlib import Path
from time import time
import pyspark.sql.functions as sf

from ase.io import iread
from dotenv import load_dotenv
from colabfit.tools.utilities import *

from colabfit.tools.configuration import AtomicConfiguration
from colabfit.tools.database import DataManager, SparkDataLoader
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    free_energy_pd,
    potential_energy_pd,
)


def unstringify(row):
    """Parse stringified list/dict columns of a row back into Python objects.

    Should be mapped as ``DataFrame.rdd.map(unstringify)``.

    Every string value that starts with ``{`` or ``[`` is passed through
    ``literal_eval``; the ``metadata`` column is always left as it is.

    Args:
        row: Object exposing ``asDict()``.

    Returns:
        A plain ``dict`` (not a ``Row``) holding the converted values.

    Raises:
        ValueError, SyntaxError: Propagated from ``literal_eval`` when a value
            looks like a literal but cannot be parsed.
    """
    row_dict = row.asDict()
    # Only existing keys are reassigned, so mutating during iteration is safe.
    for key, val in row_dict.items():
        if key == "metadata":
            continue
        if isinstance(val, str) and val.startswith(("{", "[")):
            row_dict[key] = literal_eval(val)
    # A ``Row`` used to be built here and then discarded; the dict has always
    # been what is returned, so the unused construction was dropped.
    return row_dict


load_dotenv()
loader = SparkDataLoader(
    table_prefix="ndb.colabfit.dev",
    # check_ids_batch_size=5000,
)
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
loader.set_vastdb_session(
    endpoint=endpoint,
    access_key=access_key,
    access_secret=access_secret,
)
dataset_id = "DS_y7nrdsjtuwom_0"
dataset_id = "DS_zdy2xz6y88nl_0"

if loader.spark.catalog.tableExists(loader.config_set_table):
    cs_ids = (
        loader.read_table(loader.config_set_table)
        .filter(sf.col("dataset_id") == dataset_id)
        .select("id")
        .collect()
    )
    if len(cs_ids) == 0:
        cs_ids = None
    else:
        cs_ids = [x["id"] for x in cs_ids]
else:
    cs_ids = None
# Load the configurations and property objects that belong to ``dataset_id``.
config_df = loader.read_table(loader.config_table, unstring=True)
config_df = config_df.filter(sf.array_contains(sf.col("dataset_ids"), dataset_id))
prop_df = loader.read_table(loader.prop_object_table, unstring=False)
prop_df = prop_df.filter(sf.col("dataset_id") == dataset_id).limit(80000)
# The column selection used to be a bare ``select(...)`` call with no receiver,
# and ``prop_df`` had been rebound to an RDD just before it, which broke the
# later ``prop_df.rdd``. NOTE(review): attaching the select to ``prop_df`` is
# the presumed intent; confirm.
prop_df = prop_df.select(
    "id",
    "dataset_id",
    "configuration_id",
    "atomic_forces",
    "potential_energy",
    "cauchy_stress",
    "formation_energy",
    "free_energy",
    "multiplicity",
)
# Keep ``prop_df`` as a DataFrame; the unstringified rows live in ``rdd``.
rdd = prop_df.rdd.map(unstringify)
rdd.take(1)
ds = Dataset(
    name=name,
    authors=authors,
    config_df=config_df,
    prop_df=prop_df,
    publication_link=publication_link,
    data_link=data_link,
    description=description,
    other_links=other_links,
    dataset_id=dataset_id,
    labels=labels,
    data_license=data_license,
    configuration_set_ids=cs_ids,
)
ds_rdd = loader.spark.sparkContext.parallelize([ds.spark_row])
loader.write_table(ds_rdd, loader.dataset_table, schema=dataset_schema)

In [None]:
def create_configuration_sets(
    self,
    loader,
    # below args in order:
    # [config-name-regex-pattern], [config-label-regex-pattern], \
    # [config-set-name], [config-set-description]
    name_label_match: list[tuple],
):
    """Build one configuration set per pattern tuple and write them all.

    For each tuple the dataset's configurations are narrowed by an optional
    regex on their names and by optional labels (all labels must be present).
    The matching configuration ids get ``cs_name`` appended to their
    ``configuration_set_ids`` through
    ``loader.find_existing_rows_append_elem``, and a ``ConfigurationSet`` is
    created from the matching rows. The resulting rows are written to
    ``loader.config_set_table`` in a single call at the end.

    Args:
        loader: Loader providing ``read_table``, ``write_table``,
            ``find_existing_rows_append_elem`` and the table names.
        name_label_match: Tuples of ``(name_regex, label_or_labels, cs_name,
            cs_description)``. ``name_regex`` and ``label_or_labels`` may be
            falsy / ``None`` to skip that filter.
    """
    config_set_rows = []
    # Load unstrung dataframe of configs, filter for just includes ds-id
    config_df = loader.read_table(table_name=loader.config_table, unstring=True)
    config_df = config_df.filter(
        sf.array_contains(sf.col("dataset_ids"), self.dataset_id)
    )
    for names_match, label_match, cs_name, cs_desc in tqdm(
        name_label_match, desc="Creating Configuration Sets"
    ):
        print(
            f"names match: {names_match}, label {label_match}, cs_name {cs_name}, cs_desc {cs_desc}"
        )
        # Start each pattern from the full dataset. Without this, a tuple with
        # no name regex hit an unbound variable on the first iteration and
        # silently reused the previous tuple's query on later ones.
        config_set_query = config_df
        if names_match:
            config_set_query = config_set_query.withColumn(
                "names_exploded", sf.explode(sf.col("names"))
            ).filter(sf.col("names_exploded").rlike(names_match))
        # Currently an AND operation on labels: labels col contains x AND y
        if label_match is not None:
            if isinstance(label_match, str):
                label_match = [label_match]
            for label in label_match:
                config_set_query = config_set_query.filter(
                    sf.array_contains(sf.col("labels"), label)
                )
        co_ids = [x["id"] for x in config_set_query.select("id").distinct().collect()]
        loader.find_existing_rows_append_elem(
            table_name=loader.config_table,
            ids=co_ids,
            cols="configuration_set_ids",
            elems=cs_name,
            edit_schema=config_df_schema,
            write_schema=config_schema,
        )
        config_set = ConfigurationSet(
            name=cs_name,
            description=cs_desc,
            config_df=config_set_query,
            dataset_id=self.dataset_id,
        )
        config_set_rows.append(config_set.spark_row)
    loader.write_table(
        config_set_rows, loader.config_set_table, schema=configuration_set_schema
    )

In [None]:
from colabfit.tools.utilities import _write_value
from functools import partial

In [None]:
part_write = partial(
    _write_value,
    "CO/positions",
    "txt",
    "/scratch/gw2338/vast/data-lake-main/spark/scripts",
    "positions",
)

In [None]:
from schema import dataset_schema
from colabfit.tools.dataset import Dataset


def create_dataset(
    self,
    loader,
    name: str,
    authors: list[str],
    publication_link: str,
    data_link: str,
    description: str,
    other_links: list[str] = None,
    dataset_id: str = None,
    labels: list[str] = None,
    data_license: str = "CC-BY-ND-4.0",
):
    """Assemble a ``Dataset`` row for ``dataset_id`` and write it to the dataset table.

    The configurations, property objects and configuration sets that belong to
    ``dataset_id`` are read through ``loader`` and handed to ``Dataset``
    together with the descriptive arguments. ``self`` is not read.

    Args:
        loader: Loader providing ``spark``, ``read_table``, ``write_table``
            and the table names.
        name, authors, publication_link, data_link, description, other_links,
        labels, data_license: Passed through to ``Dataset`` unchanged.
        dataset_id: Id used to select rows and passed to ``Dataset``.
    """
    # Only configuration sets of this dataset are attached, and a missing
    # configuration-set table counts as "no configuration sets". This mirrors
    # the stand-alone dataset cell elsewhere in this notebook.
    cs_ids = None
    if loader.spark.catalog.tableExists(loader.config_set_table):
        cs_rows = (
            loader.read_table(loader.config_set_table)
            .filter(sf.col("dataset_id") == dataset_id)
            .select("id")
            .collect()
        )
        cs_ids = [x["id"] for x in cs_rows] or None
    config_df = loader.read_table(loader.config_table, unstring=True)
    config_df = config_df.filter(sf.array_contains(sf.col("dataset_ids"), dataset_id))
    prop_df = loader.read_table(loader.prop_object_table, unstring=True)
    # Property objects are filtered on the scalar ``dataset_id`` column, as in
    # every other cell of this notebook that filters that table.
    # NOTE(review): confirm against the current property-object schema.
    prop_df = prop_df.filter(sf.col("dataset_id") == dataset_id)
    ds = Dataset(
        name=name,
        authors=authors,
        config_df=config_df,
        prop_df=prop_df,
        publication_link=publication_link,
        data_link=data_link,
        description=description,
        other_links=other_links,
        dataset_id=dataset_id,
        labels=labels,
        data_license=data_license,
        configuration_set_ids=cs_ids,
    )
    loader.write_table([ds.spark_row], loader.dataset_table, schema=dataset_schema)

In [None]:
t0 = time()
create_dataset(
    dm,
    loader,
    "carolina_materials",
    ["author one", "author two"],
    "https://www.carolina_materials.com",
    "https://www.carolina_materials.com/data",
    "Carolina Materials is a ... description",
    dataset_id=dm.dataset_id,
)
print(f"Time elapsed: {time() - t0}")

In [None]:
from importlib import reload

import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set
import colabfit.tools.schema

reload(colabfit.tools.configuration_set)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
reload(colabfit.tools.schema)
configuration_set_schema = colabfit.tools.schema.configuration_set_schema
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset

In [None]:
def find_duplicate_hash(spark_rows: list[dict], loader):
    """Return the stored configurations whose ``hash`` matches any given row.

    Args:
        spark_rows: Sequence of row dicts, each holding a ``"hash"`` entry.
            (The parameter was previously annotated as ``dict`` although it is
            iterated as a sequence of rows.)
        loader: Object whose ``spark`` session is used for the JDBC read.

    Returns:
        A DataFrame over the ``configurations`` table, read over JDBC with the
        module-level ``url`` and ``properties``, filtered to rows whose
        ``hash`` is one of the hashes in ``spark_rows``.
    """
    hashes = [row["hash"] for row in spark_rows]
    stored_configs = loader.spark.read.jdbc(
        url=url,
        table="configurations",
        properties=properties,
    )
    return stored_configs.filter(sf.col("hash").isin(hashes))