# Imports

In [1]:
import datetime
import json
import os
import pickle

# from functools import partial
# from itertools import chain, islice
# from multiprocessing import Pool, cpu_count
from pathlib import Path

# from pprint import pprint

import dateutil.parser
import findspark
import lmdb
import numpy as np
import psycopg
import pyspark.sql.functions as sf
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)

from colabfit.tools.configuration import AtomicConfiguration, config_schema
from colabfit.tools.database import DataManager, PGDataLoader
from colabfit.tools.dataset import Dataset, dataset_schema
from colabfit.tools.property import Property, property_object_schema
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)
from colabfit.tools.schema import configuration_set_schema
import pyarrow as pa

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)
findspark.init()
format = "jdbc"
load_dotenv("./.env")

True

# Set up MTPU and Carolina Materials readers and data

In [2]:
# MTPU data


def convert_stress(keys, stress):
    stresses = {k: s for k, s in zip(keys, stress)}
    return [
        [stresses["xx"], stresses["xy"], stresses["xz"]],
        [stresses["xy"], stresses["yy"], stresses["yz"]],
        [stresses["xz"], stresses["yz"], stresses["zz"]],
    ]


SYMBOL_DICT = {"0": "Si", "1": "O"}


def mtpu_reader(filepath):
    with open(filepath, "rt") as f:
        energy = None
        forces = None
        coords = []
        cell = []
        symbols = []
        config_count = 0
        for line in f:
            if line.strip().startswith("Size"):
                size = int(f.readline().strip())
            elif line.strip().lower().startswith("supercell"):
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
                cell.append([float(x) for x in f.readline().strip().split()])
            elif line.strip().startswith("Energy"):
                energy = float(f.readline().strip())
            elif line.strip().startswith("PlusStress"):
                stress_keys = line.strip().split()[-6:]
                stress = [float(x) for x in f.readline().strip().split()]
                stress = convert_stress(stress_keys, stress)
            elif line.strip().startswith("AtomData:"):
                keys = line.strip().split()[1:]
                if "fx" in keys:
                    forces = []
                for i in range(size):
                    li = {
                        key: val for key, val in zip(keys, f.readline().strip().split())
                    }
                    symbols.append(SYMBOL_DICT[li["type"]])
                    if "cartes_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["cartes_x"],
                                    li["cartes_y"],
                                    li["cartes_z"],
                                ]
                            ]
                        )
                    elif "direct_x" in keys:
                        coords.append(
                            [
                                float(c)
                                for c in [
                                    li["direct_x"],
                                    li["direct_y"],
                                    li["direct_z"],
                                ]
                            ]
                        )

                    if "fx" in keys:
                        forces.append(
                            [float(f) for f in [li["fx"], li["fy"], li["fz"]]]
                        )

            elif line.startswith("END_CFG"):
                if "cartes_x" in keys:
                    config = AtomicConfiguration(
                        positions=coords, symbols=symbols, cell=cell
                    )
                elif "direct_x" in keys:
                    config = AtomicConfiguration(
                        scaled_positions=coords, symbols=symbols, cell=cell
                    )
                config.info["energy"] = energy
                if forces:
                    config.info["forces"] = forces
                config.info["stress"] = stress

                if "Si" in symbols and "O" in symbols:
                    config.info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "11x11x11",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    config.info["_name"] = f"{filepath.stem}_SiO2_{config_count}"
                elif "Si" in symbols:
                    config.info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "8x8x8",
                        "kinetic-energy-cutoff": {
                            "val": 884,
                            "units": "eV",
                        },
                    }
                    config.info["_name"] = f"{filepath.stem}_Si_{config_count}"
                elif "O" in symbols:
                    config.info["input"] = {
                        "kpoint-scheme": "Monkhorst-Pack",
                        "kpoints": "gamma-point",
                        "kinetic-energy-cutoff": {
                            "val": 1224,
                            "units": "eV",
                        },
                    }
                    config.info["_name"] = f"{filepath.stem}_O_{config_count}"
                config_count += 1
                yield config
                forces = None
                stress = []
                coords = []
                cell = []
                symbols = []
                energy = None

In [3]:
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
# data = [x for x in mtpu_configs]
# data[0].configuration_summary()

In [4]:
# Carolina Materials data

SOFTWARE = "VASP"
METHODS = "DFT-PBE"
CM_PI_METADATA = {
    "software": {"value": SOFTWARE},
    "method": {"value": METHODS},
    "input": {"value": {"IBRION": 6, "NFREE": 4}},
}

CM_PROPERTY_MAP = {
    "formation-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
        }
    ],
    "_metadata": CM_PI_METADATA,
}
CO_MD = {
    key: {"field": key}
    for key in [
        "_symmetry_space_group_name_H-M",
        "_symmetry_Int_Tables_number",
        "_chemical_formula_structural",
        "_chemical_formula_sum",
        "_cell_volume",
        "_cell_formula_units_Z",
        "symmetry_dict",
        "formula_pretty",
    ]
}


def load_row(txn, row):
    try:
        data = pickle.loads(txn.get(f"{row}".encode("ascii")))
        return data
    except TypeError:
        return False


def config_from_row(row: dict, row_num: int):
    coords = row.pop("cart_coords")
    a_num = row.pop("atomic_numbers")
    cell = [
        row.pop(x)
        for x in [
            "_cell_length_a",
            "_cell_length_b",
            "_cell_length_c",
            "_cell_angle_alpha",
            "_cell_angle_beta",
            "_cell_angle_gamma",
        ]
    ]
    symmetry_dict = {str(key): val for key, val in row.pop("symmetry_dict").items()}
    for key in symmetry_dict:
        key = str(key)
    info = {}
    info = row
    info["symmetry_dict"] = symmetry_dict
    info["_name"] = f"carolina_materials_{row_num}"
    if row_num % 10 == 0:
        info["_labels"] = [row_num % 10, "bcc"]
    else:
        info["_labels"] = [row_num % 10, "fcc"]
    config = AtomicConfiguration(
        scaled_positions=coords,
        numbers=a_num,
        cell=cell,
        info=info,
    )
    return config
    # return AtomicConfiguration.from_ase(config)


def carmat_reader(fp: Path):
    parent = fp.parent
    env = lmdb.open(str(parent))
    txn = env.begin()
    row_num = 0
    rows = []
    while row_num <= 10000:
        row = load_row(txn, row_num)
        if row is False:
            env.close()
            break
        rows.append(row)
        yield config_from_row(row, row_num)
        row_num += 1
    env.close()
    return False
    # return rows

In [5]:
PI_METADATA = {
    "software": {"value": "Quantum ESPRESSO"},
    "method": {"value": "DFT-PBE"},
    "input": {"field": "input"},
}
PROPERTY_MAP = {
    "potential-energy": [
        {
            "energy": {"field": "energy", "units": "eV"},
            "per-atom": {"value": False, "units": None},
            # "_metadata": PI_METADATA,
        }
    ],
    "atomic-forces": [
        {
            "forces": {"field": "forces", "units": "eV/angstrom"},
            # "_metadata": PI_METADATA,
        },
    ],
    "cauchy-stress": [
        {
            "stress": {"field": "stress", "units": "GPa"},
            "volume-normalized": {"value": True, "units": None},
        }
    ],
    "_metadata": PI_METADATA,
}

# Connect to DB and run loader

In [57]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/07 15:01:14 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.
24/05/07 15:01:14 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [58]:
carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"
# carmat_ds_id2 = "duplicate_ds_id"
dm = DataManager(
    nprocs=4,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
# dm_dup = DataManager(
#     nprocs=4,
#     configs=carmat_config_gen,
#     prop_defs=[formation_energy_pd],
#     prop_map=CM_PROPERTY_MAP,
#     dataset_id=carmat_ds_id2,
# )

Dataset ID: DS_y7nrdsjtuw0g_0


In [8]:
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

Dataset ID: DS_y7nrdsjtuwom_0


In [59]:
batch = next(dm.gather_co_po_in_batches())

In [60]:
cos = [x[0] for x in batch]
cos_dataframe = loader.spark.createDataFrame(cos, schema=config_schema)

In [61]:
rows = loader.spark.createDataFrame(cos, schema=config_schema).collect()

In [53]:
from collections import defaultdict

In [54]:
row_dict = defaultdict(list)
for i, row in enumerate(rows):
    if i > 10:
        break
    for key, val in row.asDict().items():
        row_dict[key].append(val)
print(row_dict)

defaultdict(<class 'list'>, {'id': ['CO_2035392092515548233', 'CO_395357802651642160', 'CO_1353971176978508405', 'CO_673544646621950284', 'CO_1180491453360028556', 'CO_1235470615381239321', 'CO_1125763852340156480', 'CO_2185718778256801767', 'CO_1547613879869181226', 'CO_2262213674274439455', 'CO_2212061042677459331'], 'hash': ['2035392092515548233', '395357802651642160', '1353971176978508405', '673544646621950284', '1180491453360028556', '1235470615381239321', '1125763852340156480', '2185718778256801767', '1547613879869181226', '2262213674274439455', '2212061042677459331'], 'last_modified': [datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.datetime(2024, 5, 7, 12, 13, 15), datetime.date

In [30]:
import pyarrow as pa

config_schema = StructType(
    [
        StructField("id", StringType(), False),
        StructField("hash", StringType(), False),
        StructField("last_modified", TimestampType(), False),
        StructField("dataset_ids", StringType(), True),  # ArrayType(StringType())
        StructField("metadata", StringType(), True),
        StructField("chemical_formula_hill", StringType(), True),
        StructField("chemical_formula_reduced", StringType(), True),
        StructField("chemical_formula_anonymous", StringType(), True),
        StructField("elements", StringType(), True),  # ArrayType(StringType())
        StructField("elements_ratios", StringType(), True),  # ArrayType(IntegerType())
        StructField("atomic_numbers", StringType(), True),  # ArrayType(IntegerType())
        StructField("nsites", IntegerType(), True),
        StructField("nelements", IntegerType(), True),
        StructField("nperiodic_dimensions", IntegerType(), True),
        StructField("cell", StringType(), True),  # ArrayType(ArrayType(DoubleType()))
        StructField("dimension_types", StringType(), True),  # ArrayType(IntegerType())
        StructField("pbc", StringType(), True),  # ArrayType(IntegerType())
        StructField(
            "positions", StringType(), True
        ),  # ArrayType(ArrayType(DoubleType()))
        StructField("names", StringType(), True),  # ArrayType(StringType()),
        StructField("labels", StringType(), True),  # ArrayType(StringType())
        StructField(
            "configuration_set_ids", StringType(), True
        ),  # ArrayType(StringType())
    ]
)


def convert_type(spark_type):
    type_name = spark_type.dataType.typeName()
    if type_name in data_type_map:
        print(data_type_map[type_name])
        return pa.field(spark_type.name, data_type_map[type_name])
    else:
        raise ValueError(f"Unsupported PySpark data type: {spark_type}")


# data_type_map = {
#     StringType: pa.string(),
#     IntegerType: pa.int32(),
#     FloatType: pa.float64(),
#     DoubleType: pa.float64(),
#     BooleanType: pa.bool_(),
#     TimestampType: pa.timestamp("ns"),
#     # DateType: pa.date32(),
#     ArrayType: lambda dt: pa.list_(convert_type(dt)),
#     StructType: lambda st: pa.struct([convert_type(f) for f in st.fields]),
# }
data_type_map = {
    "string": pa.string(),
    "integer": pa.int32(),
    "float": pa.float64(),
    # DoubleType: pa.float64(),
    # BooleanType: pa.bool_(),
    "timestamp": pa.timestamp("ns"),
    # DateType: pa.date32(),
    # ArrayType: lambda dt: pa.list_(convert_type(dt)),
    "struct": lambda st: pa.struct([convert_type(f) for f in st.fields]),
}

In [62]:
config_set_query = (
    cos_dataframe.withColumn(
        "names_unstrung", sf.from_json(sf.col("names"), ArrayType(StringType()))
    )
    .withColumn(
        "labels_unstrung",
        sf.from_json(sf.col("labels"), ArrayType(StringType())),
    )
    .withColumn(
        "dataset_ids_unstrung", sf.from_json("dataset_ids", ArrayType(StringType()))
    )
    .drop("names", "labels", "dataset_ids")
    .withColumnRenamed("names_unstrung", "names")
    .withColumnRenamed("labels_unstrung", "labels")
    .withColumnRenamed("dataset_ids_unstrung", "dataset_ids")
    .filter(sf.array_contains(sf.col("dataset_ids"), carmat_ds_id))
)

In [63]:
config_set_query.show(1, False)

+----------------------+-------------------+-------------------+--------+---------------------+------------------------+--------------------------+-----------------------+--------------------+----------------------------------+------+---------+--------------------+-----------------------------------------------------------------------------------------------------------------------------------+---------------+---------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------

In [64]:
label_match = "fcc|1"
names_match = "materials_10"
if names_match:
    config_set_query = (
        config_set_query.withColumn("labels_exploded", sf.explode(sf.col("labels")))
        .withColumn("names_exploded", sf.explode(sf.col("names")))
        .drop("names")
        .filter(sf.regexp_like(sf.col("names_exploded"), sf.lit(rf"{names_match}")))
    )
if label_match:
    config_set_query = config_set_query.filter(
        sf.regexp_like(sf.col("labels_exploded"), sf.lit(rf"{label_match}"))
    )
config_set_query.show(1, False)

+---------------------+------------------+-------------------+--------+---------------------+------------------------+--------------------------+------------------------+--------------------+----------------------------------------+------+---------+--------------------+--------------------------------------------------------------------------------------------------------------------------------+---------------+---------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+---------------------+--------+-------------------+--------------

In [None]:
loader.spark.sql(
    "UPDATE config_table SET configuration_set_ids = array_append(configuration_set_ids, %s) WHERE id = %s"
) % (config_set.id, co_ids)

In [13]:
from colabfit.tools.configuration_set import ConfigurationSet

In [56]:
from importlib import reload
import colabfit.tools.configuration_set

reload(colabfit.tools.configuration_set)
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
reload(colabfit.tools.database)
from colabfit.tools.database import DataManager, PGDataLoader

In [65]:
co_ids = [x[0] for x in config_set_query.select("id").distinct().collect()]

In [66]:
CS = ConfigurationSet(
    name="test", config_df=config_set_query, description="test description"
)

<class 'list'>
[Row(element='Ga', ratio=0.009410288582183186), Row(element='Pt', ratio=0.006900878293601004), Row(element='Se', ratio=0.02070263488080301), Row(element='Tl', ratio=0.006900878293601004), Row(element='Fe', ratio=0.006273525721455458), Row(element='Co', ratio=0.014429109159347553), Row(element='Mg', ratio=0.00878293601003764), Row(element='Ti', ratio=0.0370138017565872), Row(element='H', ratio=0.05144291091593475), Row(element='Te', ratio=0.00439146800501882), Row(element='C', ratio=0.012547051442910916), Row(element='S', ratio=0.029485570890840654), Row(element='Cd', ratio=0.03450439146800502), Row(element='Tc', ratio=0.01944792973651192), Row(element='Hf', ratio=0.00878293601003764), Row(element='Cs', ratio=0.0075282308657465494), Row(element='Be', ratio=0.014429109159347553), Row(element='In', ratio=0.0056461731493099125), Row(element='Sb', ratio=0.005018820577164366), Row(element='Sn', ratio=0.006273525721455458), Row(element='Hg', ratio=0.0056461731493099125), Row(el

In [101]:
def update_co_rows_cs_id(self, co_ids: list[str], cs_id: str):
    with psycopg.connect(
        """dbname=colabfit user=%s password=%s host=localhost port=5432"""
        % (
            user,
            password,
        )
    ) as conn:
        # dbname=self.database_name,
        # user=self.properties["user"],
        # password=self.properties["password"],
        # host="localhost",
        # port="5432",

        cur = conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = concat(%s::text, 
                rtrim(ltrim(replace(configuration_set_ids,%s,''), 
                
                '['),']'), %s::text)
                WHERE id = ANY(%s::array)""",
            ("[", f", {cs_id}", f", {cs_id}]", co_ids),
            # WHERE id = ANY(%s)""",
            # (cs_id, co_ids),
        )
        # cur.fetchall()
        conn.commit()

In [None]:
# You were trying to get  postgresql to recognize the WHERE id = ANY() array syntax

In [102]:
update_co_rows_cs_id(loader, co_ids, CS.spark_row["id"])

SyntaxError: syntax error at or near "array"
LINE 6:                 WHERE id = ANY($4::array)
                                           ^

In [67]:
CS.spark_row

{'id': 'CS_258512216210800474',
 'hash': 258512216210800474,
 'last_modified': None,
 'nconfigurations': 110,
 'nsites': 1594,
 'nelements': 62,
 'elements': ['Ag',
  'Al',
  'As',
  'Au',
  'B',
  'Ba',
  'Be',
  'Bi',
  'Br',
  'C',
  'Ca',
  'Cd',
  'Cl',
  'Co',
  'Cr',
  'Cs',
  'Cu',
  'F',
  'Fe',
  'Ga',
  'Ge',
  'H',
  'Hf',
  'Hg',
  'I',
  'In',
  'Ir',
  'K',
  'Li',
  'Mg',
  'Mn',
  'Mo',
  'N',
  'Na',
  'Nb',
  'Ni',
  'O',
  'Os',
  'P',
  'Pb',
  'Pd',
  'Pt',
  'Rb',
  'Re',
  'Rh',
  'Ru',
  'S',
  'Sb',
  'Sc',
  'Se',
  'Si',
  'Sn',
  'Ta',
  'Tc',
  'Te',
  'Ti',
  'Tl',
  'V',
  'W',
  'Y',
  'Zn',
  'Zr'],
 'dataset_id': None,
 'name': 'test',
 'description': 'test description',
 'total_elements_ratios': [0.013801756587202008,
  0.006273525721455458,
  0.009410288582183186,
  0.010037641154328732,
  0.016938519447929738,
  0.006273525721455458,
  0.014429109159347553,
  0.01819322459222083,
  0.027603513174404015,
  0.012547051442910916,
  0.00313676286072772

In [78]:
from colabfit.tools.utilities import _empty_dict_from_schema
from colabfit.tools.schema import configuration_set_schema
import dateutil.parser

cs = _empty_dict_from_schema(configuration_set_schema)
cs["nconfigurations"] = 200
cs["dataset_id"] = carmat_ds_id
cs["name"] = "test"
cs["description"] = "test description for test"
cs["nelements"] = 25
cs["last_modified"] = dateutil.parser.parse(
    datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
)
cs["id"] = "CS_y7nrdsjtuw0g_0"
cs["hash"] = hash(cs["name"])

In [79]:
loader.write_table([cs], "configuration_sets", configuration_set_schema)

In [111]:
co_ids = [x[0] for x in config_set_query.select("id").distinct().collect()]

In [84]:
dm.load_data_to_pg_in_batches(loader)

Loading data to PostgreSQL: : 0batch [00:00, ?batch/s]24/05/03 15:10:44 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
Loading data to PostgreSQL: : 2batch [00:04,  2.20s/batch]


In [112]:
with psycopg.connect(
    dbname="colabfit",
    user=os.environ.get("PGS_USER"),
    password=os.environ.get("PGS_PASS"),
    host="localhost",
) as conn:
    with conn.cursor() as cur:

        # cur.execute(
        #     "UPDATE configurations SET configuration_set_ids = configuration_set_ids || %(cs_id)s WHERE id = ANY(%(co_ids)s)",
        #     {"cs_id": cs["id"], "co_ids": co_ids},
        # )
        # data = cur.fetchall()
        cur.execute(
            "SELECT * FROM public.configurations WHERE id = ANY(%s)",
            [co_ids],
        )
        data2 = cur.fetchall()
    conn.commit()

In [109]:
co_ids.append("CO_215290934057753943")

In [113]:
data2

[]

In [None]:
dm2.load_data_to_pg_in_batches(loader)

# In progress

Upsert appears to be this for postgres:
```
update the_table
    set id = id || array[5,6]
where id = 4;
```
* ~~Check for upsert function from pyspark to concatenate lists of relationships instead of primary key id collision~~
* There is no pyspark-upsert function. Will have to manage this possibly through a different sql-based library
* Written: find duplicates, but convert to access database, not download full dataframe
* I see this being used with batches of hashes during upload: something like
    ``` for batch in batches:
            hash_duplicates = find_duplicates(batch, loader/database)
            hash_duplicates.make_change_to_append_dataset-ids
            hash_duplicates.write-to-database
* Where would be the best place to catch duplicates? Keeping in mind that this might be a bulk operation (i.e. on the order of millions, like with ANI1/ANI2x variations)

In [None]:
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

In [None]:
carmat_ds_id = "DS_y7nrdsjtuw0g_0"

In [None]:
from importlib import reload

import colabfit.tools.dataset
import colabfit.tools.database

reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager

Dataset = colabfit.tools.dataset.Dataset

In [None]:
def find_duplicate_hash(spark_rows: dict, loader):
    # hashes = loader.spark.createDataFrame([x["hash"] for x in spark_rows])
    hashes = [x["hash"] for x in spark_rows]
    duplicates = loader.spark.read.jdbc(
        url=url,
        table="configurations",
        properties=properties,
    ).filter(sf.col("hash").isin(hashes))
    # dupl_hashes = df.filter(df.hash.isin(hashes)).select("hash").collect()
    return duplicates

In [None]:
dupl_rows = [x.spark_row for x in dm_dup.configs]

In [None]:
config_df = (
    loader.spark.read.jdbc(
        url=loader.url, table=loader.config_table, properties=loader.properties
    )
    .withColumn(
        "ds_ids_unstrung",
        sf.from_json(sf.col("dataset_ids"), sf.ArrayType(sf.StringType())),
    )
    .filter(sf.array_contains("ds_ids_unstrung", dm.dataset_id))
    .drop("ds_ids_unstrung")
)

In [None]:
dup_hashes = find_duplicate_hash(dupl_rows, loader)

In [None]:
dup_hashes.show(10, False)

In [None]:
[x["hash"] for x in dup_hashes]

# outer

In [None]:
"""
Can we make the configuration and the property instance/data object at the same time?
In this way, we would only have to pass through the data one time.

Workflow:
create database access object
create data reader as function? of the database access object
reader returns ase.Atoms-style objects (AtomicConfiguration)
DOs and PIs are now one object
These DOs point to a configuration
The configuration may already exist in the database, so we keep track of the hash added to the DO


"""

In [None]:
cos = json.load(Path("sample_db/co_ds1.json").open("r"))

In [None]:
with open(Path("sample_db/co_ds1.json"), "r") as f:
    co_json = spark.sparkContext.parallelize(json.load(f))

In [None]:
co = co_json.map(_parse_config).map(stringify_lists)
co_df = spark.createDataFrame(co, config_schema)

In [None]:
def parse_configs(co_path, spark):
    with open(co_path, "r") as f:
        co_json = spark.sparkContext.parallelize(json.load(f))
    co = co_json.map(_parse_config).map(stringify_lists)
    co_df = spark.createDataFrame(co, config_schema)
    return co_df

In [None]:
parse_configs("sample_db/co_ds1.json", spark).show()