# Imports

In [None]:
import os
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import boto3


def create_s3_writer(bucket_name, access_key, secret_key, endpoint_url=None):
    """Build a ``write_to_s3(content, file_key)`` callable bound to one bucket.

    The boto3 client is created on the first write instead of here. The
    returned closure is meant to be captured by a Spark UDF, which pickles it
    and ships it to the executors; a closure that already holds a live boto3
    client generally cannot be pickled. Do not call the writer on the driver
    before handing it to Spark, or the client will be captured after all.

    Args:
        bucket_name: Target bucket for every upload.
        access_key: Passed to boto3 as ``aws_access_key_id``.
        secret_key: Passed to boto3 as ``aws_secret_access_key``.
        endpoint_url: Optional custom S3 endpoint; ``None`` uses boto3's default.

    Returns:
        A function ``write_to_s3(content, file_key) -> str``.
    """
    s3_client = None  # created lazily inside write_to_s3

    def write_to_s3(content, file_key):
        """Upload ``content`` under ``file_key``.

        Returns the ``s3://bucket/key`` URI on success, or a string starting
        with ``"Error: "`` if client creation or the upload fails.
        """
        nonlocal s3_client
        try:
            if s3_client is None:
                s3_client = boto3.client(
                    "s3",
                    aws_access_key_id=access_key,
                    aws_secret_access_key=secret_key,
                    endpoint_url=endpoint_url,
                )
            s3_client.put_object(Bucket=bucket_name, Key=file_key, Body=content)
            return f"s3://{bucket_name}/{file_key}"
        except Exception as e:
            # Deliberately best-effort: the error is reported in the returned
            # string so one bad row does not fail the whole Spark task.
            return f"Error: {str(e)}"

    return write_to_s3


# Create the S3 writer function
# NOTE: every value below is a literal placeholder string; replace all of them
# with real settings before running this cell.
s3_writer = create_s3_writer(
    bucket_name="your-bucket-name",
    access_key="your-access-key",
    secret_key="your-secret-key",
    endpoint_url="your-endpoint-url",  # Optional, remove if using standard S3
)


# Create a UDF that uses the S3 writer
@udf(returnType=StringType())
def write_content_to_s3(content, id):
    """Spark UDF: upload ``content`` to ``path/to/files/<id>.txt``.

    Returns whatever ``s3_writer`` returns for the upload, or ``None`` when
    ``content`` is null so that null rows are passed through untouched.
    """
    if content is not None:
        return s3_writer(content, f"path/to/files/{id}.txt")
    return None


# Apply the UDF to your DataFrame
# `df` must already exist and expose `content_column` and `id`; the new
# `file_path` column holds the UDF's return value for each row.
df = df.withColumn("file_path", write_content_to_s3(df.content_column, df.id))

In [None]:
import boto3
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType


class S3Writer:
    """S3 uploader that stores only connection settings and builds its boto3
    client on demand.

    The client is left out of the pickled state, so an instance can be
    serialized (for example when captured by a Spark UDF) and each copy creates
    its own client the first time it is needed.
    """

    def __init__(self, bucket_name, access_key, secret_key, endpoint_url=None):
        self.bucket_name = bucket_name
        self.access_key = access_key
        self.secret_key = secret_key
        self.endpoint_url = endpoint_url
        # Populated lazily by get_client().
        self.s3_client = None

    def __getstate__(self):
        """Return the instance state without the boto3 client."""
        return {
            name: value
            for name, value in self.__dict__.items()
            if name != "s3_client"
        }

    def __setstate__(self, state):
        """Restore the settings; the client is rebuilt lazily after unpickling."""
        self.__dict__.update(state)
        self.s3_client = None

    def get_client(self):
        """Return the cached boto3 S3 client, creating it on first use."""
        if self.s3_client is not None:
            return self.s3_client
        self.s3_client = boto3.client(
            "s3",
            aws_access_key_id=self.access_key,
            aws_secret_access_key=self.secret_key,
            endpoint_url=self.endpoint_url,
        )
        return self.s3_client

    def write_file(self, content, file_key):
        """Upload ``content`` under ``file_key``.

        Returns the ``s3://bucket/key`` URI on success, or ``"Error: <message>"``
        if creating the client or uploading raises.
        """
        try:
            self.get_client().put_object(
                Bucket=self.bucket_name, Key=file_key, Body=content
            )
        except Exception as e:
            return f"Error: {str(e)}"
        return f"s3://{self.bucket_name}/{file_key}"


from dotenv import load_dotenv

# Load variables from a .env file into the process environment, then build the
# writer from them. os.getenv returns None for any variable that is not set.
load_dotenv()

# NOTE(review): the secret is read from AWS_SECRET, not the more common
# AWS_SECRET_ACCESS_KEY; confirm this matches the .env file.
s3_writer = S3Writer(
    bucket_name=os.getenv("S3_BUCKET_NAME"),
    access_key=os.getenv("AWS_ACCESS_KEY_ID"),
    secret_key=os.getenv("AWS_SECRET"),
    endpoint_url=os.getenv("S3_ENDPOINT_URL"),
)

In [1]:
import datetime
import json
from time import time
import os
from collections import defaultdict
import pickle
from tqdm import tqdm

# from functools import partial
# from itertools import chain, islice
# from multiprocessing import Pool, cpu_count
from pathlib import Path

# from pprint import pprint

import dateutil.parser
import findspark
import lmdb
import numpy as np
import psycopg
import pyspark.sql.functions as sf
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from colabfit.tools.schema import (
    property_object_schema,
    config_df_schema,
    config_schema,
    property_object_df_schema,
)
from colabfit.tools.configuration import AtomicConfiguration, config_schema
from colabfit.tools.database import DataManager, PGDataLoader
from colabfit.tools.dataset import Dataset, dataset_schema
from colabfit.tools.property import Property, property_object_schema
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)
from colabfit.tools.schema import configuration_set_schema
import pyarrow as pa

# Load the formation-energy property definition from a JSON file in the
# current working directory.
with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)
findspark.init()
# NOTE: this rebinds the builtin `format` for the rest of the session.
format = "jdbc"
load_dotenv("./.env")

True

# Connect to DB and run loader

In [23]:
# Path to the JDBC driver jar, taken from the CLASSPATH environment variable.
JARFILE = os.environ.get("CLASSPATH")
# getOrCreate() reuses a running session if there is one; the log output below
# this cell shows that in that case only runtime SQL configurations take effect.
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
# JDBC connection properties; `user` and `password` are also read as globals
# by update_co_rows_cs_id.
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/21 14:49:53 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/21 14:49:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/21 14:49:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/21 14:49:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
import os
import pickle
import sys
import time
from pathlib import Path

from ase.io import iread
from dotenv import load_dotenv
from tqdm import tqdm

from colabfit.tools.configuration import AtomicConfiguration
from colabfit.tools.database import DataManager, SparkDataLoader
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    free_energy_pd,
    potential_energy_pd,
)

load_dotenv()
loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
# Credentials and endpoint come from the environment; they are only read here,
# the set_vastdb_session call that would use them is still commented out.
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
# loader.set_vastdb_session(
PKL_FP = Path("data/oc20_data_mapping.pkl")
# NOTE(review): pickle.load executes arbitrary code from the file; only load
# pickles from a trusted source.
with open(PKL_FP, "rb") as f:
    OC20_MAP = pickle.load(f)

Notes:

* `df.rdd`: you can only parallelize one time, so don't try to do a DataFrame select from an RDD.
* Updating to SDK 5.1 in a couple of weeks.
* boto3 and S3 are the Amazon file-system interactions, mostly for adding metadata TO FILES (not to the database) and for interacting with the files themselves.
* Make sure to call `spark.stop()` at the end of the Python file.

In [29]:
def update_co_rows_cs_id(self, co_ids: list[str], cs_id: str):
    """Add ``cs_id`` to ``configuration_set_ids`` for the given configurations.

    ``configuration_set_ids`` is handled as a bracketed text list. For every
    row whose ``id`` is in ``co_ids`` the statement removes any existing
    occurrence of ``cs_id``, strips the outer brackets, and rebuilds the value
    as ``[<old contents>, <cs_id>]``.

    NOTE(review): a row whose list is empty or held only ``cs_id`` ends up with
    a dangling separator (e.g. ``[, <cs_id>]``); confirm whether that matters.

    Args:
        self: Unused; kept so the function can be bound as a loader method.
        co_ids: Configuration ids to update; sent as one array for ``ANY(%s)``.
        cs_id: Configuration-set id to add.

    Relies on the module-level ``user`` and ``password`` for credentials.
    """
    # Keyword arguments instead of a %-formatted conninfo string, so
    # credentials containing spaces or quotes cannot break the connection
    # string. This matches the connect() call used in the SELECT test cell.
    with psycopg.connect(
        dbname="colabfit",
        user=user,
        password=password,
        host="localhost",
        port=5432,
    ) as conn:
        # The earlier draft also sent an unfinished
        # "UPDATE configurations SET configuration_set_ids =" statement, which
        # is a syntax error and failed before this UPDATE could run.
        conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = concat(%s::text,
                rtrim(ltrim(replace(configuration_set_ids,%s,''),
                '['),']') || ', ', %s::text)
            WHERE id = ANY(%s)""",
            ("[", f"{cs_id}", f"{cs_id}]", co_ids),
        )
        conn.commit()

In [None]:
# Note to self: you were trying to get PostgreSQL to recognize the `WHERE id = ANY(...)` array syntax.

In [112]:
# Check that a Python list can be bound to `ANY(%s)`: fetch every
# configuration whose id is in `co_ids` (defined elsewhere in the notebook).
with psycopg.connect(
    dbname="colabfit",
    user=os.environ.get("PGS_USER"),
    password=os.environ.get("PGS_PASS"),
    host="localhost",
) as conn:
    with conn.cursor() as cur:

        # cur.execute(
        #     "UPDATE configurations SET configuration_set_ids = configuration_set_ids || %(cs_id)s WHERE id = ANY(%(co_ids)s)",
        #     {"cs_id": cs["id"], "co_ids": co_ids},
        # )
        # data = cur.fetchall()
        # The list is wrapped in another list so it is sent as a single
        # array parameter rather than one parameter per id.
        cur.execute(
            "SELECT * FROM public.configurations WHERE id = ANY(%s)",
            [co_ids],
        )
        data2 = cur.fetchall()
    conn.commit()

# In progress

Upsert appears to be this for postgres:
```
update the_table
    set id = id || array[5,6]
where id = 4;
```
* ~~Check for upsert function from pyspark to concatenate lists of relationships instead of primary key id collision~~
* There is no pyspark-upsert function. Will have to manage this possibly through a different sql-based library
* Written: find duplicates, but convert to access database, not download full dataframe
* I see this being used with batches of hashes during upload, something like:
    ```
    for batch in batches:
        hash_duplicates = find_duplicates(batch, loader/database)
        hash_duplicates.make_change_to_append_dataset-ids
        hash_duplicates.write-to-database
    ```
* Where would be the best place to catch duplicates? Keeping in mind that this might be a bulk operation (i.e. on the order of millions, like with ANI1/ANI2x variations)

In [9]:
# Same Spark/PostgreSQL setup as the earlier connection cell, re-run here.
# Path to the JDBC driver jar, taken from the CLASSPATH environment variable.
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
# JDBC connection properties built from the environment credentials.
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/30 09:52:06 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/30 09:52:06 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/30 09:52:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/30 09:52:08 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
# Build a DataManager for the MTPU dataset.
# `mtpu_reader` and `PROPERTY_MAP` are not defined in the visible part of this
# notebook; they must be provided by another cell.
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

In [None]:
from colabfit.tools.schema import *
import pyspark.sql.functions as sf

# Gather (configuration, property) row pairs from the DataManager and build a
# Spark DataFrame from the property rows only.
# NOTE(review): this uses `dm`, while the cell above creates `dm2`; confirm
# which manager is intended.
rows = dm._gather_co_po_rows(dm.prop_defs, dm.prop_map, dm.dataset_id, dm.configs)
# zip(*rows) transposes the pairs into one tuple of COs and one tuple of POs.
co_rows, po_rows = list(zip(*rows))
spark_df = loader.spark.createDataFrame(po_rows, schema=property_object_df_schema)
table_name = loader.prop_object_table


def write_table(
    self,
    spark_df,
    table_name: str,
    ids_filter: list[str] = None,
    check_length_col: str = None,
):
    """Write ``spark_df`` into a VAST DB table, creating the table if needed.

    Include self.table_prefix in the table name when passed to this function.
    The name is split on "." and parts 1, 2 and 3 are used as bucket, schema
    and table.

    The body used to sit at module level, below a docstring-only function. In
    that form the function did nothing and the statements failed on the
    undefined ``ids_filter`` and ``self``. It is restored as the function body
    here.

    Args:
        self: Loader providing ``check_unique_ids``, ``spark`` and ``session``.
        spark_df: DataFrame to write; it must have an ``id`` column.
        table_name: Fully prefixed table name.
        ids_filter: If given, only rows whose ``id`` is in this list are written.
        check_length_col: Currently unused by this implementation.

    Raises:
        ValueError: If ``self.check_unique_ids`` reports the ids are not unique.
    """
    if ids_filter is not None:
        spark_df = spark_df.filter(sf.col("id").isin(ids_filter))
    ids = [x["id"] for x in spark_df.select("id").collect()]
    all_unique = self.check_unique_ids(table_name, ids)
    if not all_unique:
        raise ValueError("Duplicate IDs found in table. Not writing.")
    table_split = table_name.split(".")
    # Array-typed columns are converted to strings before writing.
    string_cols = [
        f.name for f in spark_df.schema if f.dataType.typeName() == "array"
    ]
    string_col_udf = sf.udf(stringify_df_val, StringType())
    for col in string_cols:
        spark_df = spark_df.withColumn(col, string_col_udf(sf.col(col)))
    arrow_schema = spark_schema_to_arrow_schema(spark_df.schema)
    if not self.spark.catalog.tableExists(table_name):
        print(f"Creating table {table_name}")
        with self.session.transaction() as tx:
            schema = tx.bucket(table_split[1]).schema(table_split[2])
            schema.create_table(table_split[3], arrow_schema)
    # Transpose the collected rows into columns for pyarrow.
    arrow_batches = pa.table(
        [pa.array(col) for col in zip(*spark_df.collect())],
        schema=arrow_schema,
    ).to_batches()
    with self.session.transaction() as tx:
        table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])
        # Insert every batch. Taking only to_batches()[0] would silently drop
        # rows if pyarrow produced more than one batch, and would raise
        # IndexError if it produced none.
        for rec_batch in arrow_batches:
            table.insert(rec_batch)


write_table(loader, spark_df, table_name)

In [None]:
# Set `multiplicity` to 0 and refresh `last_modified` for every property-object
# row belonging to one dataset, using the VAST DB table API directly.
# `self` is an alias for the loader (presumably so method-body code can be run
# at cell level).
self = loader

dataset_id = "DS_y7nrdsjtuwom_0"


# Spark schema for the selected columns plus the internal "$row_id" column.
spark_schema = StructType(
    [
        StructField("id", StringType(), False),
        StructField("multiplicity", IntegerType(), True),
        StructField("last_modified", TimestampType(), False),
        StructField("$row_id", IntegerType(), False),
    ]
)
with self.session.transaction() as tx:
    table_name = self.prop_object_table
    table_path = table_name.split(".")
    table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
    # internal_row_id=True is requested so the rows can be passed to
    # table.update() below.
    rec_batches = table.select(
        predicate=table["dataset_id"] == dataset_id,
        columns=["id", "multiplicity", "last_modified"],
        internal_row_id=True,
    )
    for rec_batch in rec_batches:
        df = self.spark.createDataFrame(
            rec_batch.to_struct_array().to_pandas(), schema=spark_schema
        )
        print(f"length of df: {df.count()}")
        df = df.withColumn("multiplicity", sf.lit(0))
        # Formatting without fractional seconds and re-parsing truncates the
        # current UTC time to whole seconds.
        update_time = dateutil.parser.parse(
            datetime.datetime.now(tz=datetime.timezone.utc).strftime(
                "%Y-%m-%dT%H:%M:%SZ"
            )
        )
        df = df.withColumn("last_modified", sf.lit(update_time).cast("timestamp"))
        # NOTE(review): "$row_id" is declared int32 here; confirm that matches
        # the type the table actually returns.
        arrow_schema = pa.schema(
            [
                pa.field("id", pa.string()),
                pa.field("multiplicity", pa.int32()),
                pa.field("last_modified", pa.timestamp("us")),
                pa.field("$row_id", pa.int32()),
            ]
        )
        # Transpose collected rows into columns for pyarrow.
        update_table = pa.table(
            [pa.array(col) for col in zip(*df.collect())], schema=arrow_schema
        )
        table.update(
            rows=update_table,
            columns=["multiplicity", "last_modified"],
        )

In [None]:
# Look up existing rows of the test configuration table by id.
# `id_batch`, `update_cols` and `spark_schema` must come from other cells.
with self.session.transaction() as tx:
    table_path = "ndb.colabfit.dev.test_co".split(".")
    table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
    # NOTE: the result of this unfiltered select is overwritten by the
    # filtered select just below without being read.
    rec_batch = table.select()

    rec_batch = table.select(
        predicate=table["id"].isin(id_batch),
        columns=update_cols + ["id"],
        internal_row_id=True,
    )

    rec_batch = rec_batch.read_all()
    duplicate_df = self.spark.createDataFrame(
        rec_batch.to_struct_array().to_pandas(), schema=spark_schema
    )

In [None]:
import dateutil.parser
import datetime
import pyarrow as pa
from colabfit.tools.utilities import *
from pyspark.sql.types import TimestampType

# Timestamp round-trip test through Spark: ten timezone-aware UTC datetimes,
# each stored in two columns.
datetimes = [datetime.datetime.now(tz=datetime.timezone.utc) for x in range(10)]

# Pair each datetime with itself so every row is a (datetime, datetime2) tuple.
datetimes = list(zip(*[datetimes, datetimes]))
dtdf = spark.createDataFrame(
    datetimes,
    schema=StructType(
        [
            StructField("datetime", TimestampType()),
            StructField("datetime2", TimestampType()),
        ]
    ),
)
# Arrow schema with deliberately different precisions per column
# (microseconds vs nanoseconds).
arrschema = pa.schema(
    [
        pa.field("datetime", pa.timestamp("us")),
        pa.field("datetime2", pa.timestamp("ns")),
    ]
)

dtdf.write.mode("append").saveAsTable("ndb.colabfit.dev.time_test")
dtdf_read = spark.read.table("ndb.colabfit.dev.time_test")
dtdf_read.show(truncate=False)


from vastdb.session import Session
from dotenv import load_dotenv

# Same timestamp test, written through the vastdb SDK instead of Spark.
load_dotenv()

access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
session = Session(endpoint=endpoint, access=access_key, secret=access_secret)
with session.transaction() as tx:
    sch = tx.bucket("colabfit").schema("dev")
    # NOTE: create_table is called unconditionally with a fixed name.
    sch.create_table("time_test3", arrschema)
    table = tx.bucket("colabfit").schema("dev").table("time_test3")
    # Transpose the collected Spark rows into columns; only the first Arrow
    # batch is inserted.
    rec_batch = pa.table(
        [pa.array(col) for col in zip(*dtdf.collect())],
        arrschema,
    ).to_batches()[0]
    print(rec_batch)
    table.insert(rec_batch)

# Read the table back in a separate transaction to inspect what was stored.
with session.transaction() as tx:
    sch = tx.bucket("colabfit").schema("dev")
    table = tx.bucket("colabfit").schema("dev").table("time_test3")
    print(table.select().read_all())


# NOTE(review): this reads "time_test2", which is neither of the tables written
# in this cell ("time_test" and "time_test3"); confirm the intended table.
dtdf2 = spark.read.table("ndb.colabfit.dev.time_test2")
dtdf2.show(truncate=False)

In [None]:
from pyspark.sql import Row


def write_value_to_file(path_prefix, extension, BUCKET_DIR, write_column, row):
    """Write one column of ``row`` to a text file and point the row at the file.

    The value is written to
    ``<BUCKET_DIR>/<path_prefix>/<last 4 chars of id>/<id>.<extension>`` and the
    returned row carries that path in ``write_column`` in place of the value.

    Meant to be pre-bound with ``functools.partial``, e.g.::

        partial(
            write_value_to_file,
            'CO/positions',
            'txt',
            '/save/here',
            'positions',
        )

    Args:
        path_prefix: Sub-directory under ``BUCKET_DIR``.
        extension: File extension, without the dot.
        BUCKET_DIR: Root directory for the written files.
        write_column: Name of the column whose value is written out.
        row: A mapping or a pyspark ``Row`` with at least ``id`` and
            ``write_column``.

    Returns:
        A new ``Row`` with ``write_column`` replaced by the file path string.
    """
    # pyspark Row is a tuple subclass without .copy(), so convert through
    # asDict() when it is available and fall back to a plain dict copy.
    row_dict = row.asDict() if hasattr(row, "asDict") else dict(row)
    row_id = row_dict["id"]
    # Shard files by the last four characters of the id.
    shard = row_id[-4:]
    full_path = Path(BUCKET_DIR) / path_prefix / shard / f"{row_id}.{extension}"
    full_path.parent.mkdir(parents=True, exist_ok=True)
    full_path.write_text(str(row_dict[write_column]))
    row_dict[write_column] = str(full_path)
    return Row(**row_dict)


from functools import partial

# Pre-bind everything except `row`, leaving a one-argument callable that writes
# the "positions" column under CO/positions with a .txt extension.
part_write = partial(
    write_value_to_file,
    "CO/positions",
    "txt",
    "/scratch/gw2338/vast/data-lake-main/spark/scripts",
    "positions",
)

In [None]:
# Write the positions of every MTPU configuration to files, one row at a time.
configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
co_rows = [x.spark_row for x in configs]
rdd = sc.parallelize(co_rows)
# part_write takes a single row. foreachPartition would pass it an iterator
# over a whole partition, so indexing it with row["id"] would fail; foreach
# calls it once per row.
rdd.foreach(part_write)

In [None]:
# Load the MTPU configurations in two passes.
# The slices [:50] and [25:] overlap on items 25-49, so those configurations
# are loaded twice (presumably on purpose, to exercise duplicate handling).
config_list = list(mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg")))
dm2.configs = config_list[:50]
dm2.load_co_po_to_vastdb(loader)
dm2.configs = config_list[25:]
dm2.load_co_po_to_vastdb(loader)

In [None]:
from importlib import reload

import colabfit.tools.utilities
import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set
import colabfit.tools.schema

reload(colabfit.tools.utilities)
reload(colabfit.tools.schema)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset
property_object_df_schema = colabfit.tools.schema.property_object_df_schema
property_object_schema = colabfit.tools.schema.property_object_schema
##############

import json
import lmdb
import pickle
from colabfit.tools.database import DataManager, SparkDataLoader

# Set up a loader with a VAST DB session and a DataManager for the
# Carolina MatDB dataset.
# NOTE(review): load_dotenv() runs after the loader is constructed here, but
# before it in an earlier cell; confirm the loader does not need the
# environment at construction time.
loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
load_dotenv()
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
loader.set_vastdb_session(
    endpoint=endpoint, access_key=access_key, access_secret=access_secret
)

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)

# `carmat_reader` and `CM_PROPERTY_MAP` are not defined in the visible part of
# this notebook.
carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"


dm = DataManager(
    nprocs=1,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
# Replace the configs with a freshly created reader over the same file.
dm.configs = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))

# Regex-based selection tuples; defined here but not used in this cell.
match = [
    (r".*3.*", None, "3_configurations", "Carmat with 3"),
    (r".*4.*", None, "4_configurations", "Carmat with 4"),
]
# dm.load_co_po_to_vastdb(loader)

In [None]:
def _gather_co_po_rows(
    prop_defs: list[dict],
    prop_map: dict,
    dataset_id,
    configs: list[AtomicConfiguration],
):
    """Lazily build one ``Property`` per configuration.

    Despite the name, this notebook version yields ``Property`` objects, not
    Spark rows. Each configuration is tagged with ``dataset_id`` before its
    property is built.

    Args:
        prop_defs: Property definitions passed to ``Property.from_definition``.
        prop_map: Property map passed to ``Property.from_definition``.
        dataset_id: Dataset id set on every configuration.
        configs: Iterable of configurations; it is consumed lazily.

    Yields:
        The ``Property`` built for each configuration, with
        ``standardize_energy=True``.
    """
    for config in configs:
        config.set_dataset_id(dataset_id)
        yield Property.from_definition(
            definitions=prop_defs,
            configuration=config,
            property_map=prop_map,
            standardize_energy=True,
        )


from colabfit.tools.property import Property
from colabfit.tools.property_definitions import *


propgen = _gather_co_po_rows(dm.prop_defs, dm.prop_map, dm.dataset_id, dm.configs)
p = next(propgen)

In [2]:
from importlib import reload

import colabfit.tools.utilities
import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration
import colabfit.tools.configuration_set
import colabfit.tools.property
import colabfit.tools.schema

reload(colabfit.tools.utilities)
reload(colabfit.tools.dataset)
reload(colabfit.tools.configuration)
reload(colabfit.tools.database)
reload(colabfit.tools.property)
reload(colabfit.tools.schema)
AtomicConfiguration = colabfit.tools.configuration.AtomicConfiguration
DataManager = colabfit.tools.database.DataManager
SparkDataLoader = colabfit.tools.database.SparkDataLoader
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset
Property = colabfit.tools.property.Property
dataset_df_schema = colabfit.tools.schema.dataset_df_schema
dataset_schema = colabfit.tools.schema.dataset_schema

#################################################

# Take one batch of (CO, PO) rows, stringify the PO array columns, and join
# POs to COs on the configuration id.
# The recorded output below this cell shows it failed with
# "NameError: name 'dm' is not defined", so `dm` must be created first.
batches = dm.gather_co_po_in_batches_no_pool()
batch = next(batches)
cos, pos = zip(*batch)
from colabfit.tools.schema import *
import pyspark.sql.functions as sf

podf = spark.createDataFrame(pos, schema=property_object_df_schema).limit(100)
from colabfit.tools.utilities import stringify_df_val

schema = property_object_df_schema
string_cols = [f.name for f in schema if f.dataType.typeName() == "array"]
string_col_udf = sf.udf(stringify_df_val, StringType())
for col in string_cols:
    podf = podf.withColumn(col, string_col_udf(sf.col(col)))
# NOTE(review): the split_long_string defined later in this notebook requires
# (df, col_name, thresh); confirm which implementation this call resolves to.
podf2 = split_long_string(podf)

codf = spark.createDataFrame(cos, schema=config_df_schema)
# Rename the CO "id" so it lines up with the PO "configuration_id" join key.
combindf = codf.withColumnRenamed("id", "configuration_id").join(
    podf, on="configuration_id", how="inner"
)
combindf.select("id", "configuration_id", "nsites", "atomic_forces_00").limit(1).show()

NameError: name 'dm' is not defined

In [None]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as sf
from pyspark.sql.types import StringType

NSITES_COL_SPLITS = 20
from pyspark.sql import DataFrame
from pyspark.sql.functions import substring, length, col, lit

In [None]:
# Remove the four "*3" test tables if they exist, then display the dataset
# table.
spark.sql("drop table if exists ndb.colabfit.dev.test_co3")
spark.sql("drop table if exists ndb.colabfit.dev.test_po3")
spark.sql("drop table if exists ndb.colabfit.dev.test_cs3")
spark.sql("drop table if exists ndb.colabfit.dev.test_ds3")

loader.read_table(loader.dataset_table).show(truncate=False)

In [None]:
with loader.session.transaction() as tx:
    table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])
    sch = table.arrow_schema

In [None]:
# Collect the "id" column of the `ids` DataFrame into one Python list.
# Uses the `sf` alias: the bare name `pyspark` is never bound by the imports
# visible in this notebook.
ids.select(sf.collect_list("id")).first()[0]

In [None]:
from colabfit.tools.schema import *

# Write one batch of property objects through the loader and read the table
# back.
batches = dm.gather_co_po_in_batches_no_pool()
batch = next(batches)
# Transpose (CO, PO) pairs into separate CO and PO sequences.
cos, pos = zip(*batch)
co_df = spark.createDataFrame(cos, schema=config_df_schema)
po_df = spark.createDataFrame(pos, schema=property_object_df_schema)
loader.write_table(po_df, loader.prop_object_table)

po = loader.read_table(loader.prop_object_table)
po.count()
# Ids of the batch that was just written (a plain Python list).
ids = [x["id"] for x in po_df.select("id").collect()]
# new, old = loader.find_existing_po_rows_append_elem(ids, podf)
# new, old = loader.find_existing_po_rows_append_elem(ids, podf)

In [None]:
from colabfit.tools.database import batched
from colabfit.tools.utilities import *
from colabfit.tools.schema import *
from colabfit.tools.database import *
from colabfit.tools.configuration import *

import pyarrow as pa

self = loader
import dateutil.parser
import datetime
import pyspark.sql.functions as sf


cols = ["dataset_ids"]
elems = ["DS_y7nrdsjtuw0g_1"]


if isinstance(cols, str):
    cols = [cols]


if isinstance(elems, str):
    elems = [elems]


col_types = {
    "id": StringType(),
    "last_modified": TimestampType(),
    "$row_id": IntegerType(),
}


arr_cols = []
for col in cols:
    col_types[col] = get_spark_field_type(config_schema, col)
    is_arr = get_spark_field_type(config_df_schema, col)
    if is_arr.typeName() == "array":
        arr_cols.append(col)


update_cols = [col for col in col_types if col not in ["id", "$row_id"]]


total_write_cols = update_cols + ["$row_id"]
ids = [x["id"] for x in co_df.select("id").collect()]
batched_ids = batched(ids, 10000)
new_ids = []
existing_ids = []


id_batch = next(batched_ids)


id_batch = list(set(id_batch))


# We only have to use vastdb-sdk here bc we need the '$row_id' column
# Fetch the rows of the configuration table whose id is in `id_batch`,
# together with the columns that will be updated.
with self.session.transaction() as tx:
    table_path = self.config_table.split(".")
    table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
    rec_batch = table.select(
        predicate=table["id"].isin(id_batch),
        columns=update_cols + ["id"],
        internal_row_id=True,
    )
    # Field order here is update columns, then "id", then "$row_id".
    # NOTE(review): createDataFrame gets dicts from to_pylist(); confirm the
    # selected columns line up with this schema.
    spark_schema = StructType(
        [StructField(col, col_types[col], True) for i, col in enumerate(update_cols)]
        + [
            StructField("id", StringType(), False),
            StructField("$row_id", IntegerType(), False),
        ]
    )
    rec_batch = rec_batch.read_all()
    duplicate_df = self.spark.createDataFrame(
        rec_batch.to_pylist(), schema=spark_schema
    )
    print(f"length of df: {duplicate_df.count()}")


unstring_udf = sf.udf(unstring_df_val, ArrayType(StringType()))


for col_name, col_type in col_types.items():
    if col_name in arr_cols:
        duplicate_df = duplicate_df.withColumn(col_name, unstring_udf(sf.col(col_name)))


# Merge the new element into each array column of the existing rows.
for col, elem in zip(cols, elems):
    if col == "labels":
        # Labels come from the incoming rows, not from `elem`. Keep the
        # selection as a DataFrame: .collect() returns a list, which cannot be
        # renamed or joined. Assign the result back, otherwise the merged
        # labels are discarded.
        co_df_labels = co_df.select("id", "labels")
        duplicate_df = (
            duplicate_df.withColumnRenamed("labels", "labels_dup")
            .join(
                co_df_labels.withColumnRenamed("labels", "labels_co_df"),
                on="id",
            )
            .withColumn(
                "labels",
                sf.array_distinct(sf.array_union("labels_dup", "labels_co_df")),
            )
            # Drop the helper columns so only the merged "labels" remains.
            .drop("labels_dup", "labels_co_df")
        )
    else:
        # Append `elem` to the existing array without duplicating it.
        duplicate_df = duplicate_df.withColumn(
            col,
            sf.array_distinct(sf.array_union(sf.col(col), sf.array(sf.lit(elem)))),
        )


existing_ids_batch = [x["id"] for x in duplicate_df.select("id").collect()]

new_ids_batch = [id for id in id_batch if id not in existing_ids_batch]
string_udf = sf.udf(stringify_df_val, StringType())
print(arr_cols)


for col_name in duplicate_df.columns:
    print("stringifying column: ", col_name)
    if col_name in arr_cols:
        duplicate_df = duplicate_df.withColumn(col_name, string_udf(sf.col(col_name)))


update_time = dateutil.parser.parse(
    datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
)

duplicate_df = duplicate_df.withColumn(
    "last_modified", sf.lit(update_time).cast("timestamp")
)

update_schema = StructType(
    [StructField(col, col_types[col], False) for col in total_write_cols]
)

arrow_schema = spark_schema_to_arrow_schema(update_schema)
update_table = pa.table(
    [pa.array(col) for col in zip(*duplicate_df.select(total_write_cols).collect())],
    schema=arrow_schema,
)

with self.session.transaction() as tx:
    table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
    table.update(
        rows=update_table,
        columns=update_cols,
    )
new_ids.extend(new_ids_batch)
existing_ids.extend(existing_ids_batch)

In [None]:
spark.sql("show tables in ndb.colabfit.dev").show()
spark.sql("drop table ndb.colabfit.dev.gpw_test_configs")
spark.sql("drop table ndb.colabfit.dev.gpw_test_prop_objects")
spark.sql("drop table ndb.colabfit.dev.gpw_test_config_sets")
spark.sql("drop table ndb.colabfit.dev.gpw_test_datasets")

In [None]:
from colabfit.tools.utilities import *
from colabfit.tools.database import *
from colabfit.tools.schema import *
from functools import partial

# Build CO/PO DataFrames from one batch and stringify every array column of
# the PO DataFrame, as input for the string-splitting helpers defined below
# in this cell.
self = loader
batches = dm.gather_co_po_in_batches_no_pool()
batch = next(batches)

co_rows, po_rows = list(zip(*batch))
codf = loader.spark.createDataFrame(co_rows, schema=config_df_schema)
podf = loader.spark.createDataFrame(po_rows, schema=property_object_df_schema)
# `df` is the stringified working copy; `podf` itself keeps its array columns.
df = podf
string_cols = [f.name for f in df.schema if f.dataType.typeName() == "array"]
string_col_udf = sf.udf(stringify_df_val, StringType())
for col in string_cols:
    df = df.withColumn(col, string_col_udf(sf.col(col)))

column_name = "atomic_forces_00"


####################################################################################
def get_max_string_length(df, column_name):
    """Return the length of the longest string in ``column_name`` of ``df``.

    The value is the single cell of a ``max`` aggregation over the per-row
    string lengths.
    """
    lengths = df.select(column_name).select(
        sf.length(column_name).alias("string_length")
    )
    first_row = lengths.agg(sf.max("string_length")).collect()[0]
    return first_row[0]


def split_long_string_cols(df, column_name: str, max_string_length: int = 60000):
    """
    Splits a long string column into multiple columns based on a maximum string length.

    The first chunk stays in ``column_name`` and the following chunks go into
    the 19 overflow columns ``<base>_01`` ... ``<base>_19``, where ``<base>`` is
    ``column_name`` without its last ``_``-separated part. Chunks that would
    start past the end of the string are null.

    NOTE(review): anything beyond 20 * max_string_length characters is not
    stored in any column; confirm inputs cannot be that long.

    :param df: Input DataFrame with array cols already stringified
    :param column_name: Name of the column containing the long string
    :param max_string_length: Maximum length for each split string
    :return: DataFrame with the long string split across multiple columns
    :raises ValueError: if any overflow column is missing from ``df``
    """
    longest = get_max_string_length(df, column_name)
    # The aggregation yields None for an empty DataFrame or an all-null
    # column; there is nothing to split then, and `None <= int` would raise.
    if longest is None or longest <= max_string_length:
        print("no columns truncated")
        return df
    print("columns truncated")
    base_name = "_".join(column_name.split("_")[:-1])
    overflow_columns = [f"{base_name}_{i+1:02}" for i in range(19)]
    if not all(col in df.columns for col in overflow_columns):
        raise ValueError("Overflow columns not found in target DataFrame schema")
    all_columns = [column_name] + overflow_columns
    # Chunks are first written to "<name>_tmp" columns because every chunk is
    # cut from the original column, which must stay intact until all are built.
    tmp_columns = [f"{col_name}_tmp" for col_name in all_columns]
    substring_exprs = [
        sf.when(
            sf.length(sf.col(column_name)) - (i * max_string_length) > 0,
            # Spark substring positions are 1-based.
            sf.substring(
                sf.col(column_name), (i * max_string_length + 1), max_string_length
            ),
        )
        .otherwise(sf.lit(None))
        .alias(col_name)
        for i, col_name in enumerate(tmp_columns)
    ]
    df = df.select("*", *substring_exprs)
    # Swap each original column for its freshly built chunk.
    for tmp_col, col in zip(tmp_columns, all_columns):
        df = df.drop(col).withColumnRenamed(tmp_col, col)
    return df


# Make this to replace the columns, not just add duplicate names
def split_long_string(df, col_name, thresh):
    """Append ``NSITES_COL_SPLITS`` fixed-width chunks of ``col_name`` to ``df``.

    Chunk ``i`` holds ``thresh`` characters starting at position
    ``i * thresh + 1`` and is aliased ``<base>_<i:02>``, where ``<base>`` is
    ``col_name`` without its last ``_``-separated part. A chunk that is not a
    non-empty string becomes ``""``. All existing columns are kept, so an alias
    equal to an existing column name produces a duplicate name.
    """
    prefix = "_".join(col_name.split("_")[:-1])
    selected = [sf.col(name) for name in df.columns]
    for idx in range(NSITES_COL_SPLITS):
        start = idx * thresh + 1
        chunk_expr = sf.when(
            sf.col(col_name).substr(start, thresh) != "",
            sf.col(col_name).substr(start, thresh),
        ).otherwise(sf.lit(""))
        selected.append(chunk_expr.alias(f"{prefix}_{idx:02}"))
    return df.select(selected)


df1 = split_long_string(podf, "atomic_forces_00", 500 // 19)

In [None]:
# Cell-level version of the write-table logic with long-string splitting.
# It expects `spark_df`, `table_name` and `self` to exist already.
ids_filter = None
check_length_col = "atomic_forces_00"
_MAX_STRING_LEN = 2500
if ids_filter is not None:
    spark_df = spark_df.filter(sf.col("id").isin(ids_filter))
# NOTE(review): check_unique_ids receives the DataFrame here, whereas the
# write_table cell passes a list of ids; confirm which one it expects.
all_unique = self.check_unique_ids(table_name, spark_df)
if not all_unique:
    raise ValueError("Duplicate IDs found in table. Not writing.")
table_split = table_name.split(".")
# Array-typed columns are converted to strings before writing.
string_cols = [f.name for f in spark_df.schema if f.dataType.typeName() == "array"]
string_col_udf = sf.udf(stringify_df_val, StringType())
for col in string_cols:
    spark_df = spark_df.withColumn(col, string_col_udf(sf.col(col)))
if check_length_col is not None:
    spark_df = split_long_string_cols(spark_df, check_length_col, _MAX_STRING_LEN)
# arrow_schema = spark_schema_to_arrow_schema(spark_df.schema)
# for field in arrow_schema:
#     field = field.with_nullable(True)
# if not self.spark.catalog.tableExists(table_name):
#     print(f"Creating table {table_name}")
#     with self.session.transaction() as tx:
#         schema = tx.bucket(table_split[1]).schema(table_split[2])
#         schema.create_table(table_split[3], arrow_schema)
# The Arrow schema is read from the existing table instead of being derived
# from the DataFrame, so the table must already exist.
with self.session.transaction() as tx:
    table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])
    arrow_schema = table.arrow_schema

# Transpose the collected rows into columns for pyarrow.
arrow_rec_batch = pa.table(
    [pa.array(col) for col in zip(*spark_df.collect())],
    schema=arrow_schema,
).to_batches()

with self.session.transaction() as tx:
    table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])

    # Insert every batch, not just the first one.
    for rec_batch in arrow_rec_batch:
        table.insert(rec_batch)

In [None]:
# This one leaves empty strings
# substring_exprs2 = [
#     sf.substring(
#         df[column_name],
#         (i * max_string_length + 1),
#         max_string_length,
#     ).alias(col_name)
#     for i, col_name in enumerate(tmp_columns)
# ]

In [None]:
import numpy as np
from functools import partial


def generate_random_float_array(n):
    """Return an array of shape ``(n, 3)`` filled by ``np.random.rand``."""
    shape = (n, 3)
    return np.random.rand(*shape)


def split_string(s, max_length=60000):
    """Split ``s`` into consecutive chunks of at most ``max_length`` characters.

    Args:
        s: String to split, or ``None``.
        max_length: Maximum length of each chunk; must be positive.

    Returns:
        ``[None]`` when ``s`` is ``None``; otherwise the chunks in order
        (an empty list for an empty string).

    Raises:
        ValueError: If ``s`` is not ``None`` and ``max_length`` is not positive.
    """
    if s is None:
        return [None]
    # A negative step would make range() empty and silently drop the whole
    # string, so reject non-positive lengths explicitly.
    if max_length <= 0:
        raise ValueError(f"max_length must be positive, got {max_length}")
    return [s[i : i + max_length] for i in range(0, len(s), max_length)]


# Disable NumPy's "..." summarisation so the whole array is rendered as text.
np.set_printoptions(threshold=np.inf)
x = generate_random_float_array(10000)
# array2string wraps its output across lines; strip the newlines to get a
# single-line string. `x` is already an ndarray, so it is passed directly
# (NumPy has no `np.arr` function).
xstr = np.array2string(x, separator=",").replace("\n", "")
splx = split_string(xstr)

In [None]:
from colabfit.tools.utilities import _write_value
from functools import partial

In [None]:
# Pre-bind the first four positional arguments of _write_value so that only
# the remaining arguments have to be supplied on each call.
# NOTE(review): what each bound value means depends on _write_value's
# signature, which is not visible here — confirm against colabfit.tools.utilities.
part_write = partial(
    _write_value,
    "CO/positions",
    "txt",
    "/scratch/gw2338/vast/data-lake-main/spark/scripts",
    "positions",
)

In [None]:
# Spark schema with two non-nullable timestamp columns.
timestamp_struct = StructType(
    [
        StructField("datetime", TimestampType(), False),
        StructField("datetime2", TimestampType(), False),
    ]
)
import pyarrow as pa
import datetime

# Arrow schema with the same field names; the two fields use different
# resolutions (microseconds and nanoseconds).
timestamp_pyarrow_schema = pa.schema(
    [
        pa.field("datetime", pa.timestamp("us")),
        pa.field("datetime2", pa.timestamp("ns")),
    ]
)
datetimes = [datetime.datetime.now(tz=datetime.timezone.utc) for _ in range(10)]
# A StructType schema needs one value per field in every row, so each
# timestamp is supplied for both columns instead of passing bare datetimes.
dtdf = spark.createDataFrame(
    [(dt, dt) for dt in datetimes],
    schema=timestamp_struct,
)

In [None]:
import pyspark.sql.functions as sf
from ast import literal_eval
from colabfit.tools.schema import *


def create_join_udf():
    """Build a Spark UDF that reassembles a value split across string columns.

    The returned UDF accepts any number of string columns, concatenates their
    per-row values in argument order (skipping ``None`` and ``"[]"``), and
    parses the concatenation with ``ast.literal_eval``.

    Returns:
        A UDF declared with return type ``ArrayType(DoubleType())``. For a row
        in which every fragment is ``None`` or ``"[]"`` it yields ``None``.
    """

    def join_cols(*cols):
        # Inside a UDF the arguments are already the row's values, so they are
        # used directly; wrapping them in sf.col would create Column objects
        # that can neither be joined as strings nor tested for truth.
        joined = "".join(c for c in cols if c is not None and c != "[]")
        if not joined:
            # Nothing to parse: literal_eval("") would raise SyntaxError.
            return None
        return literal_eval(joined)

    return sf.udf(join_cols, ArrayType(DoubleType()))

In [None]:
join_udf = create_join_udf()
# udf_join = sf.udf(join_cols, ArrayType(DoubleType()))
# Column names atomic_forces_00 .. atomic_forces_18.
join_cols = [f"atomic_forces_{i:02}" for i in range(19)]
# Add a "joined_cols" column computed by the UDF from those columns and
# display the result.
spark.table(loader.prop_object_table).withColumn(
    "joined_cols", join_udf(*join_cols)
).show()

In [None]:
def get_pos_cos_by_filter(
    self,
    po_filter_conditions: list[tuple[str, str, str | int | float | list]] = None,
    co_filter_conditions: list[tuple[str, str, str | int | float | list | None]] = None,
):
    """
    Join filtered configurations with filtered property objects.

    Both tables are read with ``self.read_table(..., unstring=True)``. Columns
    present in both tables are renamed with a ``prop_object_`` or
    ``configuration_`` prefix, and the two tables are then inner-joined on
    ``configuration_id``.

    Args:
        self: Object providing ``read_table``, ``prop_object_table`` and
            ``config_table``.
        po_filter_conditions: ``(column, operand, value)`` tuples applied to
            the property-object table before its columns are renamed.
            ``None`` applies no filter.
        co_filter_conditions: Same format, applied to the configuration table
            after its overlapping columns have been renamed, so such columns
            must be referenced by their ``configuration_``-prefixed name.
            ``None`` applies no filter.

    Returns:
        The joined DataFrame.

    Raises:
        ValueError: If a condition uses an operand that
            ``get_filtered_table`` does not implement.

    example filter conditions:
    po_filter_conditions = [("dataset_id", "==", "ds_id1"),
                            ("method", "like", "DFT%")]
    co_filter_conditions = [("nsites", ">", 15),
                            ('labels', 'array_contains', 'label1')]
    """
    # The caller's conditions are used as given; they must not be replaced by
    # hard-coded debugging values.
    po_df = self.read_table(self.prop_object_table, unstring=True)
    po_df = get_filtered_table(self, po_df, po_filter_conditions)
    po_df = po_df.drop("chemical_formula_hill")

    co_df = self.read_table(self.config_table, unstring=True)
    # Prefix every column name shared by both tables so the join output has
    # no ambiguous names.
    overlap_cols = [col for col in po_df.columns if col in co_df.columns]
    po_df = po_df.select(
        [
            col if col not in overlap_cols else sf.col(col).alias(f"prop_object_{col}")
            for col in po_df.columns
        ]
    )
    co_df = co_df.select(
        [
            (
                col
                if col not in overlap_cols
                else sf.col(col).alias(f"configuration_{col}")
            )
            for col in co_df.columns
        ]
    )
    co_df = get_filtered_table(self, co_df, co_filter_conditions)
    # NOTE(review): presumably co_df gains "configuration_id" from a renamed
    # shared "id" column — confirm against the table schemas.
    co_po_df = co_df.join(po_df, on="configuration_id", how="inner")
    return co_po_df


def get_filtered_table(
    self,
    df,
    filter_conditions: list[tuple[str, str, str | int | float | list]] | None = None,
):
    if filter_conditions is None:
        return df
    for i, (column, operand, condition) in enumerate(filter_conditions):
        if operand == "in":
            df = df.filter(sf.col(column).isin(condition))
        elif operand == "like":
            df = df.filter(sf.col(column).like(condition))
        elif operand == "rlike":
            df = df.filter(sf.col(column).rlike(condition))
        elif operand == "==":
            df = df.filter(sf.col(column) == condition)
        elif operand == "array_contains":
            df = df.filter(sf.array_contains(sf.col(column), condition))
        elif operand == ">":
            df = df.filter(sf.col(column) > condition)
        elif operand == "<":
            df = df.filter(sf.col(column) < condition)
        elif operand == ">=":
            df = df.filter(sf.col(column) >= condition)
        elif operand == "<=":
            df = df.filter(sf.col(column) <= condition)
        else:
            raise ValueError(f"Operand {operand} not implemented in get_pos_cos_filter")
    return df


# Ad-hoc invocation: get_pos_cos_by_filter is defined at module level in this
# notebook, so the loader is passed explicitly as its `self` argument.
get_pos_cos_by_filter(
    loader, [("dataset_id", "==", "DS_rf10ovxd13ne_0")], [("nsites", "<", "10")]
)

In [None]:
import pyspark.sql.functions as sf

loader.config_table = "ndb.colabfit.ingest.co"
co = loader.read_table(loader.config_table)
# `spark.read` is a DataFrameReader property and cannot be called; a table is
# loaded by name with `spark.table`.
ds_ids = [
    row["id"]
    for row in spark.table("ndb.colabfit.ingest.ds").limit(10).select("id").collect()
]
# One alternation branch per dataset id: a row matches when its dataset_ids
# value contains any of the ids.
id_regexes = "(" + ")|(".join([f".*{ds_id}.*" for ds_id in ds_ids]) + ")"
# Count matching rows of the `co` frame read above (this cell defines no `co1`).
co.filter(sf.col("dataset_ids").rlike(id_regexes)).count()