# Imports

In [None]:
import os
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import boto3


def create_s3_writer(bucket_name, access_key, secret_key, endpoint_url=None):
    """Build a closure that uploads content to a single S3 bucket.

    One boto3 S3 client is created from the given credentials (and optional
    endpoint) and captured by the returned function.

    Returns:
        A function ``write_to_s3(content, file_key)`` that returns the
        ``s3://`` URI of the uploaded object, or an ``"Error: ..."`` string
        if the upload raises.
    """
    client = boto3.client(
        "s3",
        endpoint_url=endpoint_url,
        aws_access_key_id=access_key,
        aws_secret_access_key=secret_key,
    )

    def write_to_s3(content, file_key):
        try:
            client.put_object(Body=content, Key=file_key, Bucket=bucket_name)
        except Exception as e:
            # Best-effort: the failure is reported as a string, not raised.
            return f"Error: {str(e)}"
        return f"s3://{bucket_name}/{file_key}"

    return write_to_s3


# Create the S3 writer function
# NOTE: every argument below is a placeholder string; substitute real values
# before running this cell.
s3_writer = create_s3_writer(
    bucket_name="your-bucket-name",
    access_key="your-access-key",
    secret_key="your-secret-key",
    endpoint_url="your-endpoint-url",  # Optional, remove if using standard S3
)


# Create a UDF that uses the S3 writer
@udf(returnType=StringType())
def write_content_to_s3(content, id):
    """Spark UDF: upload ``content`` under ``path/to/files/<id>.txt``.

    Returns ``None`` for null content; otherwise returns whatever the
    module-level ``s3_writer`` returns for the upload.
    """
    if content is not None:
        return s3_writer(content, f"path/to/files/{id}.txt")
    return None


# Apply the UDF to your DataFrame
# Adds a "file_path" column holding the UDF result for each row's
# content_column / id pair. `df` must already exist in the notebook session.
df = df.withColumn("file_path", write_content_to_s3(df.content_column, df.id))


# Copy connection settings from environment variables onto the loader.
# `loader` is not created in this cell; it must come from another cell.
loader.access_key = os.getenv("SPARK_ID")
loader.access_secret = os.getenv("SPARK_KEY")
loader.endpoint = os.getenv("SPARK_ENDPOINT")

In [None]:
from pyspark.sql import DataFrame
from colabfit.tools.schema import *
from pyspark import StorageLevel
from ibis import _
from collections import defaultdict
from tqdm import tqdm
import pyspark.sql.functions as sf
from colabfit.tools.utilities import unstring_df_val


def get_config_ds_data(self, dataset_id, table_name):
    """Accumulate per-dataset statistics over configuration rows.

    Reads the rows of ``table_name`` whose ``dataset_ids`` contain
    ``dataset_id`` as record batches, converts each batch to Spark in chunks
    of 10000 rows, and accumulates: the number of configurations, the summed
    ``nsites``, the set of elements, per-atomic-number counts, and the
    distinct ``dimension_types`` / ``nperiodic_dimensions`` values.

    NOTE: the aggregation is unfinished -- ``config_data`` only holds
    placeholder values and the function does not return anything yet.

    Args:
        dataset_id: Dataset id used in the ``dataset_ids.contains`` predicate.
        table_name: Table name handed to ``self._get_table_split``.
    """
    # defaultdict calls its factory with no arguments; a one-argument lambda
    # would raise TypeError on the first unseen key.
    elements_counts = defaultdict(int)
    total_elements = 0
    elements = set()
    dimension_types = set()
    nperiodic_dimensions = set()
    nsites = 0
    predicate = _.dataset_ids.contains(dataset_id)
    bucket_name, schema_name, table_n = self._get_table_split(table_name)
    with self.session.transaction() as tx:
        table = tx.bucket(bucket_name).schema(schema_name).table(table_n)
        rec_batch_reader = table.select(predicate=predicate, internal_row_id=False)
        nconfigurations = 0
        df_schema = config_schema
        unstr_schema = config_df_schema
        schema_type_dict = {f.name: f.dataType for f in unstr_schema}
        string_cols = [f.name for f in unstr_schema if f.dataType.typeName() == "array"]
        # Build each column's conversion UDF once instead of once per chunk.
        string_col_udfs = {
            col: sf.udf(unstring_df_val, schema_type_dict[col]) for col in string_cols
        }
        chunk_size = 10000
        for rec_batch in tqdm(rec_batch_reader):
            print("reading next batch")
            if rec_batch is None or rec_batch.num_rows == 0:
                break
            print(f"Read {rec_batch.num_rows} rows")
            total_rows = rec_batch.num_rows
            nconfigurations += total_rows
            print(f"Total rows read: {nconfigurations}")
            for i in range(0, total_rows, chunk_size):
                batch_chunk = rec_batch.slice(i, min(chunk_size, total_rows - i))
                pandas_chunk_df = batch_chunk.to_pandas()
                chunk_spark_df = self.spark.createDataFrame(
                    pandas_chunk_df, schema=df_schema
                )
                for col, string_col_udf in string_col_udfs.items():
                    chunk_spark_df = chunk_spark_df.withColumn(
                        col, string_col_udf(sf.col(col))
                    )
                nsites += chunk_spark_df.agg({"nsites": "sum"}).first()[0]
                batch_elements = sorted(
                    chunk_spark_df.withColumn(
                        "exploded_elements", sf.explode("elements")
                    )
                    .agg(sf.collect_set("exploded_elements").alias("exploded_elements"))
                    .select("exploded_elements")
                    .take(1)[0][0]
                )
                elements.update(batch_elements)
                atomic_ratios_df = chunk_spark_df.select("atomic_numbers").withColumn(
                    "single_element", sf.explode("atomic_numbers")
                )
                total_elements += atomic_ratios_df.count()
                atomic_ratios_df = atomic_ratios_df.groupBy("single_element").count()
                # Array-valued entries come back as lists, which are not
                # hashable; store them as tuples so they can live in a set.
                dimension_types.update(
                    tuple(value) if isinstance(value, list) else value
                    for value in chunk_spark_df.agg(
                        sf.collect_set("dimension_types")
                    ).collect()[0][0]
                )
                nperiodic_dimensions.update(
                    tuple(value) if isinstance(value, list) else value
                    for value in chunk_spark_df.agg(
                        sf.collect_set("nperiodic_dimensions")
                    ).collect()[0][0]
                )
                for row in atomic_ratios_df.collect():
                    elements_counts[row["single_element"]] += row["count"]

    nelements = len(elements)

    # Placeholder result; not yet populated from the accumulators above.
    config_data = {
        "elements": None,
        "elements_ratios": None,
    }

In [None]:
from colabfit.tools.utilities import *
from colabfit.tools.schema import *
import pyspark.sql.functions as sf

# Load the configuration rows of one dataset and run `unstring_df_val` over
# every column that `config_df_schema` declares as an array.
cos = spark.table("ndb.colabfit.dev.co_oodcat")
cos = cos.filter(sf.col("dataset_ids").contains("DS_wmgdq06mzdys_0"))
unstr_schema = config_df_schema
# Target Spark type for each column, keyed by column name.
schema_type_dict = {f.name: f.dataType for f in unstr_schema}
array_cols = [f.name for f in unstr_schema if f.dataType.typeName() == "array"]
df_schema = config_schema  # not used further in this cell
for col in array_cols:
    # One UDF per column because the UDF return type differs per column.
    string_col_udf = sf.udf(unstring_df_val, schema_type_dict[col])
    cos = cos.withColumn(col, string_col_udf(sf.col(col)))

In [None]:
# Alias so method bodies pasted into the notebook can refer to `self`.
self = loader
# Row filter: rows whose dataset_id equals dm.dataset_id (`dm` and `_` come
# from other cells).
predicate = _.dataset_id == dm.dataset_id

# The result is not assigned here; in a notebook it is shown as cell output.
read_batches_to_spark(
    loader, "colabfit", "dev", "po_oodcat", predicate, False, property_object_schema
)

In [None]:
import boto3
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
import sys

from dotenv import load_dotenv

# Read credentials and endpoint from the environment (after loading .env)
# and bundle them into the dict shape that get_s3_client() indexes.
# NOTE: `os` is not imported in this cell; it relies on an earlier import.
load_dotenv()
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
spark_conf = {
    "access_key": access_key,
    "access_secret": access_secret,
    "endpoint": endpoint,
}


def get_s3_client(spark_conf):
    """Create a boto3 S3 client from a connection-settings dict.

    ``spark_conf`` must provide the keys ``"endpoint"``, ``"access_key"`` and
    ``"access_secret"``. The client is created with SSL disabled, s3v4
    signatures, path-style addressing and a dummy region name.
    """
    client_config = boto3.session.Config(
        signature_version="s3v4", s3={"addressing_style": "path"}
    )
    return boto3.client(
        "s3",
        use_ssl=False,
        endpoint_url=spark_conf["endpoint"],
        aws_access_key_id=spark_conf["access_key"],
        aws_secret_access_key=spark_conf["access_secret"],
        region_name="fake-region",
        config=client_config,
    )


# Smoke test: upload a small text object, then HEAD it.
bucket_name = "colabfit-data"
key = "gpw_METADATA/test.txt"
content = "Hello, world!"
# Path string built from bucket and key; not used by the calls below.
file_path = f"/vdev/{bucket_name}/{key}"
s3_client = get_s3_client(spark_conf)
s3_client.put_object(Bucket=bucket_name, Key=key, Body=content)
# head_object raises if the object is missing, so reaching the next cell
# means the upload is visible.
response = s3_client.head_object(Bucket=bucket_name, Key=key)


class S3FileManager:
    """Read and write objects in one S3 bucket with a lazily created client.

    The boto3 client is excluded from the pickled state and re-created on
    first use after unpickling.
    """

    def __init__(self, bucket_name, access_id, secret_key, endpoint_url=None):
        self.bucket_name = bucket_name
        self.access_id = access_id
        self.secret_key = secret_key
        self.endpoint_url = endpoint_url
        # Created on demand by get_client().
        self.s3_client = None

    def __getstate__(self):
        """Return the instance state without the client."""
        picklable = dict(self.__dict__)
        picklable.pop("s3_client")
        return picklable

    def __setstate__(self, state):
        """Restore the state and leave the client unset until needed."""
        self.__dict__.update(state)
        self.s3_client = None

    def get_client(self):
        """Return the cached boto3 client, creating it on first call."""
        if self.s3_client is not None:
            return self.s3_client
        client_config = boto3.session.Config(
            signature_version="s3v4", s3={"addressing_style": "path"}
        )
        self.s3_client = boto3.client(
            "s3",
            use_ssl=False,
            endpoint_url=self.endpoint_url,
            aws_access_key_id=self.access_id,
            aws_secret_access_key=self.secret_key,
            region_name="fake-region",
            config=client_config,
        )
        return self.s3_client

    def write_file(self, content, file_key):
        """Upload ``content`` under ``file_key``.

        Returns a ``(path, size)`` tuple, where ``size`` is
        ``sys.getsizeof(content)``, or an ``"Error: ..."`` string on failure.
        """
        try:
            self.get_client().put_object(
                Bucket=self.bucket_name, Key=file_key, Body=content
            )
            stored_path = f"/vdev/{self.bucket_name}/{file_key}"
            return (stored_path, sys.getsizeof(content))
        except Exception as e:
            return f"Error: {str(e)}"

    def read_file(self, file_key):
        """Return the object's body decoded as UTF-8, or an error string."""
        try:
            obj = self.get_client().get_object(Bucket=self.bucket_name, Key=file_key)
            raw = obj["Body"].read()
            return raw.decode("utf-8")
        except Exception as e:
            return f"Error: {str(e)}"


# S3FileManager.__init__ names its credential parameters `access_id` and
# `secret_key`; passing `access_key=` / `access_secret=` raises TypeError.
s3m = S3FileManager(
    bucket_name=bucket_name,
    access_id=access_key,
    secret_key=access_secret,
    endpoint_url=endpoint,
)
s3m.get_client()
s3m.write_file(content=content, file_key=key)

In [1]:
import datetime
import json
from time import time
import os
from collections import defaultdict
import pickle
from tqdm import tqdm

# from functools import partial
# from itertools import chain, islice
# from multiprocessing import Pool, cpu_count
from pathlib import Path

# from pprint import pprint

import dateutil.parser
import findspark
import lmdb
import numpy as np
import psycopg
import pyspark.sql.functions as sf
from ase.atoms import Atoms
from ase.io.cfg import read_cfg
from dotenv import load_dotenv
from pyspark.sql import SparkSession
from pyspark.sql.types import (
    ArrayType,
    BooleanType,
    DoubleType,
    FloatType,
    IntegerType,
    LongType,
    StringType,
    StructField,
    StructType,
    TimestampType,
)
from colabfit.tools.schema import (
    property_object_schema,
    config_df_schema,
    config_schema,
    property_object_df_schema,
)
from colabfit.tools.configuration import AtomicConfiguration, config_schema
from colabfit.tools.database import DataManager, PGDataLoader
from colabfit.tools.dataset import Dataset, dataset_schema
from colabfit.tools.property import Property, property_object_schema
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    cauchy_stress_pd,
    potential_energy_pd,
)
from colabfit.tools.schema import configuration_set_schema
import pyarrow as pa

# Load the formation-energy property definition from a JSON file in the
# working directory.
with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)
findspark.init()
# NOTE(review): this rebinds the builtin name `format` for the whole session.
format = "jdbc"
load_dotenv("./.env")

True

# Connect to DB and run loader

In [23]:
# Jar path(s) for Spark come from the CLASSPATH environment variable.
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
# JDBC connection settings for the local "colabfit" PostgreSQL database.
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/21 14:49:53 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/21 14:49:53 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/21 14:49:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/21 14:49:55 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
import os
import pickle
import sys
import time
from pathlib import Path

from ase.io import iread
from dotenv import load_dotenv
from tqdm import tqdm

from colabfit.tools.configuration import AtomicConfiguration
from colabfit.tools.database import DataManager, SparkDataLoader
from colabfit.tools.property_definitions import (
    atomic_forces_pd,
    free_energy_pd,
    potential_energy_pd,
)

# Create the Spark loader, read credentials from the environment and load
# the pickled OC20 data mapping.
load_dotenv()
loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
# loader.set_vastdb_session(
PKL_FP = Path("data/oc20_data_mapping.pkl")
# pickle.load can execute arbitrary code; only use with a trusted file.
with open(PKL_FP, "rb") as f:
    OC20_MAP = pickle.load(f)

df.rdd  
You can only parallelize one time, so don't try to do a DataFrame select from an RDD.  
Updating to SDK 5.1 in a couple of weeks.  
boto3 and S3 are the Amazon file-system interactions, mostly for adding metadata TO FILES (not to the database) and interacting with the files.  
Make sure to call spark.stop() at the end of the Python file.

In [29]:
def update_co_rows_cs_id(self, co_ids: list[str], cs_id: str):
    """Add ``cs_id`` to ``configuration_set_ids`` of the given configurations.

    ``configuration_set_ids`` is treated as a bracketed string list: any
    existing occurrence of ``cs_id`` is removed, the surrounding brackets are
    stripped, and ``", <cs_id>]"`` is appended after a leading ``"["``.

    Connects to the local ``colabfit`` database with the module-level
    ``user`` / ``password`` values.

    Args:
        co_ids: Ids of the rows in ``configurations`` to update.
        cs_id: Configuration set id to append.

    NOTE(review): a row whose list is empty ends up as ``"[, <cs_id>]"`` --
    confirm whether that case can occur.
    """
    # Keyword arguments avoid hand-formatting a conninfo string, which breaks
    # when the user or password contains spaces or quotes.
    with psycopg.connect(
        dbname="colabfit",
        user=user,
        password=password,
        host="localhost",
        port=5432,
    ) as conn:
        # Only the complete statement is sent; a truncated
        # "UPDATE ... SET configuration_set_ids =" fails with a syntax error
        # and would abort before this update could run.
        conn.execute(
            """UPDATE configurations
                SET configuration_set_ids = concat(%s::text,
                rtrim(ltrim(replace(configuration_set_ids,%s,''),
                '['),']') || ', ', %s::text)
            WHERE id = ANY(%s)""",
            ("[", f"{cs_id}", f"{cs_id}]", co_ids),
        )
        conn.commit()

In [None]:
# You were trying to get  postgresql to recognize the WHERE id = ANY() array syntax

In [112]:
# Fetch all configuration rows whose id is in `co_ids` (defined elsewhere in
# the notebook) from the local "colabfit" database.
with psycopg.connect(
    dbname="colabfit",
    user=os.environ.get("PGS_USER"),
    password=os.environ.get("PGS_PASS"),
    host="localhost",
) as conn:
    with conn.cursor() as cur:

        # cur.execute(
        #     "UPDATE configurations SET configuration_set_ids = configuration_set_ids || %(cs_id)s WHERE id = ANY(%(co_ids)s)",
        #     {"cs_id": cs["id"], "co_ids": co_ids},
        # )
        # data = cur.fetchall()
        # The list is passed as one parameter so it binds to ANY(%s) as an array.
        cur.execute(
            "SELECT * FROM public.configurations WHERE id = ANY(%s)",
            [co_ids],
        )
        data2 = cur.fetchall()
    conn.commit()

# In progress

Upsert appears to be this for postgres:
```
update the_table
    set id = id || array[5,6]
where id = 4;
```
* ~~Check for upsert function from pyspark to concatenate lists of relationships instead of primary key id collision~~
* There is no pyspark-upsert function. Will have to manage this possibly through a different sql-based library
* Written: find duplicates, but convert to access database, not download full dataframe
* I see this being used with batches of hashes during upload, something like:
    ```
    for batch in batches:
        hash_duplicates = find_duplicates(batch, loader/database)
        hash_duplicates.make_change_to_append_dataset-ids
        hash_duplicates.write-to-database
    ```
* Where would be the best place to catch duplicates? Keeping in mind that this might be a bulk operation (i.e. on the order of millions, like with ANI1/ANI2x variations)

In [9]:
# Same setup as the earlier connection cell: Spark session with the JDBC jar
# from CLASSPATH, plus credentials for the local "colabfit" database.
JARFILE = os.environ.get("CLASSPATH")
spark = (
    SparkSession.builder.appName("PostgreSQL Connection with PySpark")
    .config("spark.jars", JARFILE)
    .getOrCreate()
)
url = "jdbc:postgresql://localhost:5432/colabfit"
user = os.environ.get("PGS_USER")
password = os.environ.get("PGS_PASS")
properties = {
    "user": user,
    "password": password,
    "driver": "org.postgresql.Driver",
}
loader = PGDataLoader(appname="colabfit", env="./.env")

24/05/30 09:52:06 WARN Utils: Your hostname, arktos resolves to a loopback address: 127.0.1.1; using 172.24.21.25 instead (on interface enp5s0)
24/05/30 09:52:06 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
24/05/30 09:52:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/05/30 09:52:08 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [None]:
# Build a DataManager for the MTPU dataset. `mtpu_reader` and `PROPERTY_MAP`
# are not defined in the visible cells and must already exist in the session.
mtpu_ds_id = "DS_y7nrdsjtuwom_0"
mtpu_configs = mtpu_reader(Path("data/mtpu_2023/Unified_training_set.cfg"))
dm2 = DataManager(
    nprocs=4,
    configs=mtpu_configs,
    prop_defs=[potential_energy_pd, atomic_forces_pd, cauchy_stress_pd],
    prop_map=PROPERTY_MAP,
    dataset_id=mtpu_ds_id,
)

In [None]:
from colabfit.tools.schema import *
import pyspark.sql.functions as sf

# Gather rows from the DataManager, split the pairs into configuration rows
# and property-object rows, and build a Spark DataFrame of the latter.
rows = dm._gather_co_po_rows(dm.prop_defs, dm.prop_map, dm.dataset_id, dm.configs)
# zip(*rows) transposes the sequence of pairs into two parallel tuples.
co_rows, po_rows = list(zip(*rows))
spark_df = loader.spark.createDataFrame(po_rows, schema=property_object_df_schema)
table_name = loader.prop_object_table


def write_table(
    self,
    spark_df,
    table_name: str,
    ids_filter: list[str] = None,
    check_length_col: str = None,
):
    """Include self.table_prefix in the table name when passed to this function

    Writes ``spark_df`` to the table ``table_name``, creating the table first
    if Spark's catalog does not know it. Array columns are converted to
    strings with ``stringify_df_val`` before the rows are collected and
    inserted as Arrow record batches.

    Args:
        spark_df: DataFrame to write; must have an ``id`` column.
        table_name: Dotted table name; parts 1-3 are used as bucket, schema
            and table.
        ids_filter: If given, only rows whose ``id`` is in this list are
            written.
        check_length_col: Currently unused.

    Raises:
        ValueError: If ``self.check_unique_ids`` reports duplicate ids.
    """
    if ids_filter is not None:
        spark_df = spark_df.filter(sf.col("id").isin(ids_filter))
    ids = [x["id"] for x in spark_df.select("id").collect()]
    all_unique = self.check_unique_ids(table_name, ids)
    if not all_unique:
        raise ValueError("Duplicate IDs found in table. Not writing.")
    table_split = table_name.split(".")
    string_cols = [f.name for f in spark_df.schema if f.dataType.typeName() == "array"]
    string_col_udf = sf.udf(stringify_df_val, StringType())
    for col in string_cols:
        spark_df = spark_df.withColumn(col, string_col_udf(sf.col(col)))
    arrow_schema = spark_schema_to_arrow_schema(spark_df.schema)
    if not self.spark.catalog.tableExists(table_name):
        print(f"Creating table {table_name}")
        with self.session.transaction() as tx:
            schema = tx.bucket(table_split[1]).schema(table_split[2])
            schema.create_table(table_split[3], arrow_schema)
    rows = spark_df.collect()
    if not rows:
        # Nothing to insert; transposing zero rows yields no column arrays.
        return
    arrow_rec_batches = pa.table(
        [pa.array(col) for col in zip(*rows)],
        schema=arrow_schema,
    ).to_batches()
    with self.session.transaction() as tx:
        table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])
        # Insert every batch, not only the first, so no rows are dropped.
        for rec_batch in arrow_rec_batches:
            table.insert(rec_batch)


# `loader` is passed explicitly as the function's `self` argument.
write_table(loader, spark_df, table_name)

In [None]:
# For every property-object row of one dataset, set multiplicity to 0 and
# last_modified to the current UTC time, then write both columns back.
self = loader

dataset_id = "DS_y7nrdsjtuwom_0"


# Spark schema of the selected columns plus the internal "$row_id" column.
spark_schema = StructType(
    [
        StructField("id", StringType(), False),
        StructField("multiplicity", IntegerType(), True),
        StructField("last_modified", TimestampType(), False),
        StructField("$row_id", IntegerType(), False),
    ]
)
with self.session.transaction() as tx:
    table_name = self.prop_object_table
    table_path = table_name.split(".")
    table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
    # internal_row_id=True is requested so "$row_id" is part of each batch.
    rec_batches = table.select(
        predicate=table["dataset_id"] == dataset_id,
        columns=["id", "multiplicity", "last_modified"],
        internal_row_id=True,
    )
    for rec_batch in rec_batches:
        df = self.spark.createDataFrame(
            rec_batch.to_struct_array().to_pandas(), schema=spark_schema
        )
        print(f"length of df: {df.count()}")
        df = df.withColumn("multiplicity", sf.lit(0))
        # Formatting and re-parsing drops sub-second precision from "now".
        update_time = dateutil.parser.parse(
            datetime.datetime.now(tz=datetime.timezone.utc).strftime(
                "%Y-%m-%dT%H:%M:%SZ"
            )
        )
        df = df.withColumn("last_modified", sf.lit(update_time).cast("timestamp"))
        arrow_schema = pa.schema(
            [
                pa.field("id", pa.string()),
                pa.field("multiplicity", pa.int32()),
                pa.field("last_modified", pa.timestamp("us")),
                pa.field("$row_id", pa.int32()),
            ]
        )
        # zip(*rows) transposes the collected rows into per-column tuples.
        update_table = pa.table(
            [pa.array(col) for col in zip(*df.collect())], schema=arrow_schema
        )
        table.update(
            rows=update_table,
            columns=["multiplicity", "last_modified"],
        )

In [None]:
# Read the rows of the test configuration table whose id is in `id_batch`
# into a Spark DataFrame. `id_batch`, `update_cols` and `spark_schema` come
# from other cells.
with self.session.transaction() as tx:
    table_path = "ndb.colabfit.dev.test_co".split(".")
    table = tx.bucket(table_path[1]).schema(table_path[2]).table(table_path[3])
    # This unfiltered select is overwritten by the next assignment.
    rec_batch = table.select()

    rec_batch = table.select(
        predicate=table["id"].isin(id_batch),
        columns=update_cols + ["id"],
        internal_row_id=True,
    )

    # Materialise the reader before converting to pandas.
    rec_batch = rec_batch.read_all()
    duplicate_df = self.spark.createDataFrame(
        rec_batch.to_struct_array().to_pandas(), schema=spark_schema
    )

In [None]:
from importlib import reload

import colabfit.tools.utilities
import colabfit.tools.dataset
import colabfit.tools.database
import colabfit.tools.configuration_set
import colabfit.tools.schema

# Reload the colabfit modules edited during the session and rebind the
# notebook-level names to the reloaded objects.
# NOTE: colabfit.tools.configuration_set is imported above but not reloaded.
reload(colabfit.tools.utilities)
reload(colabfit.tools.schema)
reload(colabfit.tools.dataset)
reload(colabfit.tools.database)
DataManager = colabfit.tools.database.DataManager
ConfigurationSet = colabfit.tools.configuration_set.ConfigurationSet
Dataset = colabfit.tools.dataset.Dataset
property_object_df_schema = colabfit.tools.schema.property_object_df_schema
property_object_schema = colabfit.tools.schema.property_object_schema
##############

import json
import lmdb
import pickle
from colabfit.tools.database import DataManager, SparkDataLoader

# Create the loader, open a VAST DB session with credentials from the
# environment, and build a DataManager for the Carolina MatDB dataset.
# `carmat_reader` and `CM_PROPERTY_MAP` must already exist in the session.
loader = SparkDataLoader(table_prefix="ndb.colabfit.dev")
load_dotenv()
access_key = os.getenv("SPARK_ID")
access_secret = os.getenv("SPARK_KEY")
endpoint = os.getenv("SPARK_ENDPOINT")
loader.set_vastdb_session(
    endpoint=endpoint, access_key=access_key, access_secret=access_secret
)

with open("formation_energy.json", "r") as f:
    formation_energy_pd = json.load(f)

carmat_config_gen = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))
carmat_ds_id = "DS_y7nrdsjtuw0g_0"


dm = DataManager(
    nprocs=1,
    configs=carmat_config_gen,
    prop_defs=[formation_energy_pd],
    prop_map=CM_PROPERTY_MAP,
    dataset_id=carmat_ds_id,
)
# Replace the configs with a fresh reader over the same file.
# NOTE(review): presumably because the first reader is a one-shot generator -- confirm.
dm.configs = carmat_reader(Path("data/carolina_matdb/base/all/data.mdb"))

# 4-tuples of (regex, None, name, description); not consumed in this cell.
match = [
    (r".*3.*", None, "3_configurations", "Carmat with 3"),
    (r".*4.*", None, "4_configurations", "Carmat with 4"),
]
# dm.load_co_po_to_vastdb(loader)

In [None]:
from vastdb.session import Session
from dotenv import load_dotenv

# Open a VAST DB session from environment credentials, then print the
# projections of the production configuration table and keep its Arrow schema.
# NOTE: `os` is not imported in this cell; it relies on an earlier import.
load_dotenv()
vast_db_access = os.getenv("VAST_DB_ACCESS")
vast_db_secret = os.getenv("VAST_DB_SECRET")
endpoint = os.getenv("VAST_DB_ENDPOINT")
session = Session(access=vast_db_access, secret=vast_db_secret, endpoint=endpoint)

with session.transaction() as tx:
    # Parts 1-3 of the dotted name are used as bucket, schema and table.
    table_split = "ndb.colabfit-prod.prod.co_20240820".split(".")
    table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])
    print(table.projections())

    sch = table.arrow_schema

In [None]:
from colabfit.tools.utilities import *
from colabfit.tools.database import *
from colabfit.tools.schema import *
from functools import partial

# Take the first batch from the DataManager, build configuration and
# property-object DataFrames, and stringify the array columns of the latter.
self = loader
batches = dm.gather_co_po_in_batches_no_pool()
batch = next(batches)

# zip(*batch) transposes the sequence of pairs into two parallel tuples.
co_rows, po_rows = list(zip(*batch))
codf = loader.spark.createDataFrame(co_rows, schema=config_df_schema)
podf = loader.spark.createDataFrame(po_rows, schema=property_object_df_schema)
df = podf
string_cols = [f.name for f in df.schema if f.dataType.typeName() == "array"]
string_col_udf = sf.udf(stringify_df_val, StringType())
for col in string_cols:
    df = df.withColumn(col, string_col_udf(sf.col(col)))

column_name = "atomic_forces_00"


####################################################################################
def get_max_string_length(df, column_name):
    """Return the largest string length found in ``column_name`` of ``df``."""
    lengths = df.select(column_name).select(
        sf.length(column_name).alias("string_length")
    )
    longest_row = lengths.agg(sf.max("string_length")).collect()[0]
    return longest_row[0]


def split_long_string_cols(df, column_name: str, max_string_length: int = 60000):
    """
    Splits a long string column into multiple columns based on a maximum string length.

    The first ``max_string_length`` characters stay in ``column_name``; the
    following slices go to 19 overflow columns named after ``column_name``
    with its last ``_``-separated part replaced by ``01`` .. ``19``. Slices
    that start beyond the end of the string are null. The rewritten columns
    end up after the untouched ones in the result.

    :param df: Input DataFrame with array cols already stringified
    :param column_name: Name of the column containing the long string
    :param max_string_length: Maximum length for each split string
    :return: DataFrame with the long string split across multiple columns
    :raises ValueError: if any overflow column is missing from ``df``
    """
    longest = get_max_string_length(df, column_name)
    # The aggregate is None for an empty or all-null column; nothing to split.
    if longest is None or longest <= max_string_length:
        print("no columns truncated")
        return df
    print("columns truncated")
    base_name = "_".join(column_name.split("_")[:-1])
    overflow_columns = [f"{base_name}_{i+1:02}" for i in range(19)]
    if not all(col in df.columns for col in overflow_columns):
        raise ValueError("Overflow columns not found in target DataFrame schema")
    all_columns = [column_name] + overflow_columns
    # Slices are built under temporary names so each one still reads the
    # unsplit source column, then swapped in for the real columns below.
    tmp_columns = [f"{col_name}_tmp" for col_name in all_columns]
    substring_exprs = [
        sf.when(
            sf.length(sf.col(column_name)) - (i * max_string_length) > 0,
            sf.substring(
                sf.col(column_name), (i * max_string_length + 1), max_string_length
            ),
        )
        .otherwise(sf.lit(None))
        .alias(col_name)
        for i, col_name in enumerate(tmp_columns)
    ]
    df = df.select("*", *substring_exprs)
    for tmp_col, col in zip(tmp_columns, all_columns):
        df = df.drop(col).withColumnRenamed(tmp_col, col)
    return df


# Replaces existing split columns instead of adding duplicate names.
def split_long_string(df, col_name, thresh):
    """Split ``col_name`` into ``NSITES_COL_SPLITS`` slices of ``thresh`` chars.

    Slice ``i`` is stored in a column named after ``col_name`` with its last
    ``_``-separated part replaced by the two-digit index ``i``. Slices past
    the end of the string (and slices of a null value) become ``""``.
    Existing columns with one of those names are replaced rather than
    duplicated, so the result has unique column names.

    :param df: Input DataFrame
    :param col_name: Name of the string column to split
    :param thresh: Maximum number of characters per slice
    :return: DataFrame with the remaining original columns followed by the
        slice columns
    """
    base_name = "_".join(col_name.split("_")[:-1])
    split_names = [f"{base_name}_{i:02}" for i in range(NSITES_COL_SPLITS)]
    replaced = set(split_names)
    # Pass through only the columns that are not about to be regenerated.
    columns = [sf.col(c) for c in df.columns if c not in replaced]
    for i, split_name in enumerate(split_names):
        piece = sf.col(col_name).substr(i * thresh + 1, thresh)
        columns.append(
            sf.when(piece != "", piece).otherwise(sf.lit("")).alias(split_name)
        )
    return df.select(columns)


# Small threshold (500 // 19 == 26 characters) to force splitting on test data.
df1 = split_long_string(podf, "atomic_forces_00", 500 // 19)

In [None]:
# Step-by-step write to an existing table: check ids, stringify array
# columns, split the over-long column, then insert using the table's own
# Arrow schema. Relies on `self`, `spark_df` and `table_name` from other cells.
ids_filter = None
check_length_col = "atomic_forces_00"
_MAX_STRING_LEN = 2500
if ids_filter is not None:
    spark_df = spark_df.filter(sf.col("id").isin(ids_filter))
all_unique = self.check_unique_ids(table_name, spark_df)
if not all_unique:
    raise ValueError("Duplicate IDs found in table. Not writing.")
table_split = table_name.split(".")
string_cols = [f.name for f in spark_df.schema if f.dataType.typeName() == "array"]
string_col_udf = sf.udf(stringify_df_val, StringType())
for col in string_cols:
    spark_df = spark_df.withColumn(col, string_col_udf(sf.col(col)))
if check_length_col is not None:
    spark_df = split_long_string_cols(spark_df, check_length_col, _MAX_STRING_LEN)
# arrow_schema = spark_schema_to_arrow_schema(spark_df.schema)
# for field in arrow_schema:
#     field = field.with_nullable(True)
# if not self.spark.catalog.tableExists(table_name):
#     print(f"Creating table {table_name}")
#     with self.session.transaction() as tx:
#         schema = tx.bucket(table_split[1]).schema(table_split[2])
#         schema.create_table(table_split[3], arrow_schema)
# The schema is read from the existing table instead of derived from spark_df.
with self.session.transaction() as tx:
    table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])
    arrow_schema = table.arrow_schema

# zip(*rows) transposes the collected rows into per-column tuples.
arrow_rec_batch = pa.table(
    [pa.array(col) for col in zip(*spark_df.collect())],
    schema=arrow_schema,
).to_batches()

with self.session.transaction() as tx:
    table = tx.bucket(table_split[1]).schema(table_split[2]).table(table_split[3])

    for rec_batch in arrow_rec_batch:
        table.insert(rec_batch)

In [None]:
# This one leaves empty strings
# substring_exprs2 = [
#     sf.substring(
#         df[column_name],
#         (i * max_string_length + 1),
#         max_string_length,
#     ).alias(col_name)
#     for i, col_name in enumerate(tmp_columns)
# ]

In [None]:
import numpy as np
from functools import partial


def generate_random_float_array(n):
    """Return an ``(n, 3)`` NumPy array of random floats from ``np.random.rand``."""
    return np.random.rand(n, 3)


def split_string(s, max_length=60000):
    """Cut ``s`` into consecutive pieces of at most ``max_length`` characters.

    Returns ``[None]`` when ``s`` is ``None`` and an empty list for an empty
    string.
    """
    if s is None:
        return [None]
    pieces = []
    for start in range(0, len(s), max_length):
        pieces.append(s[start : start + max_length])
    return pieces


# Print full arrays (no "..." truncation) so the string holds every value.
np.set_printoptions(threshold=np.inf)
x = generate_random_float_array(10000)
# NumPy has no `np.arr`; np.asarray is used instead. array2string already
# returns a single str, so no join is needed.
xstr = np.array2string(np.asarray(x), separator=",").replace("\n", "")
splx = split_string(xstr)

In [None]:
from colabfit.tools.utilities import _write_value
from functools import partial

In [None]:
# Pre-bind the first four positional arguments of `_write_value`; the
# remaining arguments are supplied when part_write is called.
part_write = partial(
    _write_value,
    "CO/positions",
    "txt",
    "/scratch/gw2338/vast/data-lake-main/spark/scripts",
    "positions",
)

In [None]:
# Spark schema with two non-nullable timestamp columns.
timestamp_struct = StructType(
    [
        StructField("datetime", TimestampType(), False),
        StructField("datetime2", TimestampType(), False),
    ]
)
import pyarrow as pa
import datetime

# Arrow counterpart with microsecond and nanosecond precision.
timestamp_pyarrow_schema = pa.schema(
    [
        pa.field("datetime", pa.timestamp("us")),
        pa.field("datetime2", pa.timestamp("ns")),
    ]
)
datetimes = [datetime.datetime.now(tz=datetime.timezone.utc) for x in range(10)]
# A struct schema needs one value per field in every row, so each datetime
# is paired with itself; bare datetimes are rejected by createDataFrame.
dtdf = spark.createDataFrame(
    [(dt, dt) for dt in datetimes],
    schema=timestamp_struct,
)

In [None]:
import pyspark.sql.functions as sf
from ast import literal_eval
from colabfit.tools.schema import *


def join_cols(*cols):
    """Concatenate string fragments and parse the result as a Python literal.

    Fragments that are ``None`` or exactly ``"[]"`` are skipped. Returns
    ``None`` when nothing is left to parse.
    """
    # Inside a UDF the arguments are plain Python values, not Spark Columns,
    # so they are used directly rather than wrapped in sf.col().
    joined = "".join(c for c in cols if c is not None and c != "[]")
    if not joined:
        # literal_eval("") raises SyntaxError; report "no data" instead.
        return None
    return literal_eval(joined)


def create_join_udf():
    """Return a Spark UDF that re-joins split string columns into an array.

    The UDF takes any number of string columns, concatenates them with
    ``join_cols`` and returns the parsed value as ``ArrayType(DoubleType())``.
    """
    return sf.udf(join_cols, ArrayType(DoubleType()))

In [None]:
from colabfit.tools.schema import *

old_ds = spark.table("ndb.`colabfit-prod`.prod.ds_20240820")
# new_ds = spark.table('ndb.`colabfit-prod`.prod.ds_20240903')
old_schema = dataset_df_schema
# StructType.add() appends in place and returns the same object, which would
# also change old_schema and the imported dataset_df_schema. Build a separate
# StructType for the extended schema instead.
new_schema = StructType(
    dataset_df_schema.fields
    + [
        StructField("doi", StringType(), True),
        StructField("publication_year", StringType(), True),
    ]
)
# old_ds = old_ds.withColumn('doi')

In [None]:
import pandas as pd

# Load the DOI / publication-year CSV into pandas, then into Spark with an
# explicit all-string schema.
with open("dois_pub_years.csv", "r") as f:
    df = pd.read_csv(f)

dois_df = spark.createDataFrame(
    df,
    schema=StructType(
        [
            StructField("id", StringType(), True),
            StructField("doi", StringType(), True),
            StructField("publication-year", StringType(), True),
        ]
    ),
)
# NOTE(review): the schema above already names the first column "id", so
# this rename looks like a no-op -- confirm against the CSV header.
dois_df = dois_df.withColumnRenamed("colabfit-id", "id")
dois_df = dois_df.withColumnRenamed("publication-year", "publication_year")
dois_df.show()


# Left-join the DOI data onto the dataset table so rows without a DOI are kept.
new_ds = old_ds.join(dois_df, on="id", how="left")
# Bare expressions: only the last one in a notebook cell is displayed.
old_ds.count()
new_ds.count()
import pyspark.sql.functions as sf

new_ds.filter(sf.col("doi").isNotNull()).count()
new_ds.printSchema()
# Overwrites the production dataset table.
# NOTE(review): `schema=` is not a named parameter of saveAsTable and is
# passed through as a writer option -- confirm it has the intended effect.
new_ds.write.mode("overwrite").saveAsTable(
    "ndb.`colabfit-prod`.prod.ds_20240903", schema=new_schema
)
# new_ds.write.mode("overwrite").saveAsTable(
#     "ndb.colabfit.dev.ds_20240903", schema=new_schema
# )
# Copy the production configuration table into the dev schema.
cos = spark.table("ndb.`colabfit-prod`.prod.co_20240903")
cos.write.mode("overwrite").saveAsTable("ndb.colabfit.dev.co_20240903")

In [None]:
# Remove one dataset row from the production dataset table by rewriting the
# table without it.
# NOTE(review): this reads from and overwrites the same table in one job --
# confirm the connector supports that before re-running.
dss = spark.table("ndb.`colabfit-prod`.prod.ds_20240903")
dss = dss.filter(sf.col("id") != "DS_otx1qc9f3pm4_0")
dss.write.mode("overwrite").saveAsTable(
    "ndb.`colabfit-prod`.prod.ds_20240903", schema=new_schema
)

In [None]:
# Find configurations of dataset DS_otx1qc9f3pm4_0 that no property object of
# that dataset refers to, then build a configuration table without them.
old_cos = spark.table("ndb.`colabfit-prod`.prod.co_20240820")
new_pos = spark.table("ndb.`colabfit-prod`.prod.po_20240820")
new_pos_oc20 = new_pos.filter(sf.col("dataset_id") == "DS_otx1qc9f3pm4_0")
# Rename so the configuration ids can be joined on "id".
new_pos_oc20 = new_pos_oc20.select("configuration_id").withColumnRenamed(
    "configuration_id", "id"
)
cos_not_in_new_pos = old_cos.filter(sf.col("dataset_ids").contains("DS_otx1qc9f3pm4_0"))
# left_anti keeps only the rows that have no match on the right side.
cos_not_in_new_pos = cos_not_in_new_pos.join(new_pos_oc20, on="id", how="left_anti")
cos_not_in_new_pos.count()
new_cos = old_cos.join(cos_not_in_new_pos, on="id", how="left_anti")

In [None]:
# For updating single rows in database: example of updating publication year in dataset table
from vastdb.session import Session
from dotenv import load_dotenv

# Read the VAST DB credentials and endpoint from the environment (a .env file
# is loaded first); os.getenv returns None for any variable that is not set.
load_dotenv()
vast_db_access = os.getenv("VAST_DB_ACCESS")
vast_db_secret = os.getenv("VAST_DB_SECRET")
endpoint = os.getenv("VAST_DB_ENDPOINT")
session = Session(access=vast_db_access, secret=vast_db_secret, endpoint=endpoint)

# Split the Spark-style table identifier into bucket / schema / table. Element 0
# ("ndb") is unused, and backticks are stripped from the bucket and schema parts.
table_split = "ndb.`colabfit-prod`.prod.ds".split(".")
bucket_name = table_split[1].replace("`", "")
schema_name = table_split[2].replace("`", "")
table_name = table_split[3]


with session.transaction() as tx:
    table = tx.bucket(bucket_name).schema(schema_name).table(table_name)
    # Disabled lookup: selecting with internal_row_id=True is presumably how the
    # "$row_id" value used below was obtained -- TODO confirm.
    # rec_batch = table.select(
    #     predicate=table["id"] == "DS_otx1qc9f3pm4_0",
    #     columns=["id", "publication_year"],
    #     internal_row_id=True,
    # )
    # rec_batch = rec_batch.read_all()
    # One-row payload: the new publication_year plus the "$row_id" of the target row.
    # NOTE(review): the row id 524288 is hard-coded; verify it still belongs to
    # DS_otx1qc9f3pm4_0 before re-running this cell.
    pa_table = pa.table(
        {"id": ["DS_otx1qc9f3pm4_0"], "publication_year": ["2024"], "$row_id": [524288]}
    )
    # Only the publication_year column is listed for update.
    table.update(
        rows=pa_table,
        columns=["publication_year"],
    )

In [None]:
from colabfit.tools.utilities import _empty_dict_from_schema


def to_spark_row(self, config_df, prop_df):
    """Aggregate configuration and property data into one dataset row dict.

    The row starts from ``_empty_dict_from_schema(dataset_schema)``; only the
    keys assigned below are filled in, every other key keeps the value that
    helper supplied.

    Args:
        self: Object exposing ``configuration_set_ids``; only its length is
            read.
        config_df: Spark DataFrame providing the columns ``id``, ``elements``,
            ``atomic_numbers``, ``nsites``, ``nperiodic_dimensions`` and
            ``dimension_types``. ``atomic_numbers``, ``elements`` and
            ``dimension_types`` are converted with ``unstring_df_val`` to the
            types declared in ``config_df_schema`` (presumably they arrive as
            strings -- confirm against the table definition).
        prop_df: Spark DataFrame providing the property columns selected
            below.

    Returns:
        dict: The populated row dictionary.

    Raises:
        AssertionError: If the number of exploded ``atomic_numbers`` entries
            differs from the summed ``nsites``.
    """
    row_dict = _empty_dict_from_schema(dataset_schema)
    # Formatting to whole seconds and parsing back drops the sub-second part.
    row_dict["last_modified"] = dateutil.parser.parse(
        datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
    )
    row_dict["nconfiguration_sets"] = len(self.configuration_set_ids)
    config_df = config_df.select(
        "id",
        "elements",
        "atomic_numbers",
        "nsites",
        "nperiodic_dimensions",
        "dimension_types",
        # "labels",
    )
    prop_df = prop_df.select(
        "atomization_energy",
        "atomic_forces_00",
        "adsorption_energy",
        "electronic_band_gap",
        "cauchy_stress",
        "formation_energy",
        "energy",
    )

    # Convert the columns that are exploded / collected below to the data types
    # declared for them in config_df_schema.
    carray_cols = ["atomic_numbers", "elements", "dimension_types"]
    carray_types = {
        field.name: field.dataType
        for field in config_df_schema
        if field.name in carray_cols
    }
    for col_name in carray_cols:
        unstr_udf = sf.udf(unstring_df_val, carray_types[col_name])
        config_df = config_df.withColumn(col_name, unstr_udf(sf.col(col_name)))

    row_dict["nsites"] = config_df.agg({"nsites": "sum"}).first()[0]
    row_dict["elements"] = sorted(
        config_df.select("elements")
        .withColumn("exploded_elements", sf.explode("elements"))
        .agg(sf.collect_set("exploded_elements").alias("exploded_elements"))
        .select("exploded_elements")
        .take(1)[0][0]
    )
    row_dict["nelements"] = len(row_dict["elements"])

    # One row per atom: the exploded atomic numbers must add up to the site total.
    atomic_ratios_df = config_df.select("atomic_numbers").withColumn(
        "single_element", sf.explode("atomic_numbers")
    )
    total_elements = atomic_ratios_df.count()
    assert total_elements == row_dict["nsites"], (
        f"exploded atomic_numbers count ({total_elements}) does not match "
        f"summed nsites ({row_dict['nsites']})"
    )
    atomic_ratios_df = atomic_ratios_df.groupBy("single_element").count()
    atomic_ratios_df = atomic_ratios_df.withColumn(
        "ratio", sf.col("count") / total_elements
    )
    atomic_ratios_coll = (
        atomic_ratios_df.withColumn(
            "element",
            sf.udf(lambda x: ELEMENT_MAP[x], StringType())(sf.col("single_element")),
        )
        .select("element", "ratio")
        .collect()
    )
    # Ratios are ordered by element symbol, matching the sorted "elements" list.
    row_dict["total_elements_ratios"] = [
        row["ratio"]
        for row in sorted(atomic_ratios_coll, key=lambda row: row["element"])
    ]
    row_dict["nperiodic_dimensions"] = config_df.agg(
        sf.collect_set("nperiodic_dimensions")
    ).collect()[0][0]
    row_dict["dimension_types"] = (
        config_df.select("dimension_types")
        .agg(sf.collect_set("dimension_types"))
        .collect()[0][0]
    )

    # All property statistics come from a single pass over prop_df instead of
    # one Spark job per value. count(column) counts only non-null values, and
    # variance / mean skip nulls, so no explicit "is not null" filter is needed.
    counted_props = [
        "atomization_energy",
        "adsorption_energy",
        "electronic_band_gap",
        "cauchy_stress",
        "formation_energy",
        "energy",
    ]
    prop_stats = prop_df.agg(
        sf.count(sf.lit(1)).alias("nproperty_objects"),
        *[sf.count(sf.col(name)).alias(f"{name}_count") for name in counted_props],
        # when() yields NULL for null or "[]" forces, so count() skips those rows.
        sf.count(sf.when(sf.col("atomic_forces_00") != "[]", 1)).alias(
            "atomic_forces_count"
        ),
        sf.variance("energy").alias("energy_variance"),
        sf.mean("energy").alias("energy_mean"),
    ).first()
    row_dict.update(prop_stats.asDict())
    return row_dict

In [None]:
from collections import namedtuple
from pathlib import Path

# Metadata for the dataset being assembled in this cell.
DATASET_FP = Path(
    "/scratch/gw2338/vast/data-lake-main/spark/scripts/gw_scripts/data/s2ef_val_ood_both/"
)

DATASET_NAME = "OC20_S2EF_val_ood_both"
DATASET_ID = "DS_889euoe7akyy_0"
DOI = None

PUBLICATION_YEAR = "2024"
AUTHORS = [
    "Lowik Chanussot",
    "Abhishek Das",
    "Siddharth Goyal",
    "Thibaut Lavril",
    "Muhammed Shuaibi",
    "Morgane Riviere",
    "Kevin Tran",
    "Javier Heras-Domingo",
    "Caleb Ho",
    "Weihua Hu",
    "Aini Palizhati",
    "Anuroop Sriram",
    "Brandon Wood",
    "Junwoong Yoon",
    "Devi Parikh",
    "C. Lawrence Zitnick",
    "Zachary Ulissi",
]

LICENSE = "CC-BY-4.0"
PUBLICATION = "https://doi.org/10.1021/acscatal.0c04525"
DATA_LINK = "https://fair-chem.github.io/core/datasets/oc20.html"
DESCRIPTION = (
    "OC20_S2EF_val_ood_both is the out-of-domain validation set of the OC20 "
    "Structure to Energy and Forces (S2EF) dataset featuring both unseen catalyst "
    "composition and unseen adsorbate. Features include energy, "
    "atomic forces and data from the OC20 mappings file, including "
    "adsorbate id, materials project bulk id and miller index."
)

# Attribute names of the lightweight stand-in record built below.
self_arguments = [
    "configuration_set_ids",
    "authors",
    "description",
    "data_license",
    "publication_link",
    "data_link",
    "other_links",
    "name",
    "publication_year",
    "doi",
]

# Immutable record type exposing the attributes listed in self_arguments.
dataset = namedtuple("dataset", self_arguments)

# Populate the record from the constants above; doi is taken from DOI so that
# editing the constant is enough to change the record.
ds1 = dataset(
    configuration_set_ids=[],
    authors=AUTHORS,
    description=DESCRIPTION,
    data_license=LICENSE,
    publication_link=PUBLICATION,
    data_link=DATA_LINK,
    other_links=None,
    name=DATASET_NAME,
    publication_year=PUBLICATION_YEAR,
    doi=DOI,
)

print(ds1)
import pyspark.sql.functions as sf
from colabfit.tools.utilities import unstring_df_val
from colabfit.tools.dataset import *
from colabfit.tools.schema import *

# Restrict the work-in-progress tables to the dataset identified by DATASET_ID
# (defined with the other metadata constants earlier in this cell) instead of
# repeating the id literal in each filter.
cos = spark.table("ndb.colabfit.dev.co_wip").filter(
    sf.col("dataset_ids").contains(DATASET_ID)
)
pos = spark.table("ndb.colabfit.dev.po_wip").filter(
    sf.col("dataset_id") == DATASET_ID
)

# carray_cols = [f.name for f in config_df_schema if f.dataType.typeName() == "array"]

# cschema_type_dict = {f.name: f.dataType for f in config_df_schema}
# cstring_cols = [f.name for f in config_df_schema if f.dataType.typeName() == "array"]

# for col in cstring_cols:
#     string_col_udf = sf.udf(unstring_df_val, cschema_type_dict[col])
#     cos = cos.withColumn(col, string_col_udf(sf.col(col)))


# Build the dataset row from the filtered configuration and property frames.
ro = to_spark_row(ds1, cos, pos)

In [None]:
# Top-level, step-by-step build of a dataset row dict: `self` is bound to ds1
# and the frames to `pos` / `cos`, so every intermediate stays inspectable as a
# notebook global.
prop_df = pos
config_df = cos
from colabfit.tools.utilities import _empty_dict_from_schema

self = ds1


row_dict = _empty_dict_from_schema(dataset_schema)
# Formatting to whole seconds and parsing back drops the sub-second part.
row_dict["last_modified"] = dateutil.parser.parse(
    datetime.datetime.now(tz=datetime.timezone.utc).strftime("%Y-%m-%dT%H:%M:%SZ")
)
row_dict["nconfiguration_sets"] = len(self.configuration_set_ids)
config_df = config_df.select(
    "id",
    "elements",
    "atomic_numbers",
    "nsites",
    "nperiodic_dimensions",
    "dimension_types",
    # "labels",
)
# Total number of sites over all configurations.
row_dict["nsites"] = config_df.agg({"nsites": "sum"}).first()[0]
# Sorted set of distinct element symbols.
# NOTE(review): explode() assumes `elements` / `atomic_numbers` are already
# array columns here -- no string-to-array conversion is applied in this cell.
row_dict["elements"] = sorted(
    config_df.select("elements")
    .withColumn("exploded_elements", sf.explode("elements"))
    .agg(sf.collect_set("exploded_elements").alias("exploded_elements"))
    .select("exploded_elements")
    .take(1)[0][0]
)
row_dict["nelements"] = len(row_dict["elements"])
# One row per atom; its count must equal the summed nsites.
atomic_ratios_df = config_df.select("atomic_numbers").withColumn(
    "single_element", sf.explode("atomic_numbers")
)
total_elements = atomic_ratios_df.count()
print(total_elements, row_dict["nsites"])
assert total_elements == row_dict["nsites"]
# Fraction of all atoms contributed by each atomic number.
atomic_ratios_df = atomic_ratios_df.groupBy("single_element").count()
atomic_ratios_df = atomic_ratios_df.withColumn("ratio", sf.col("count") / total_elements)
# Map atomic numbers to element symbols via ELEMENT_MAP and bring the result
# to the driver.
atomic_ratios_coll = (
    atomic_ratios_df.withColumn(
        "element",
        sf.udf(lambda x: ELEMENT_MAP[x], StringType())(sf.col("single_element")),
    )
    .select("element", "ratio")
    .collect()
)
# Ratios ordered by element symbol, i.e. the same order as row_dict["elements"].
row_dict["total_elements_ratios"] = [
    x[1] for x in sorted(atomic_ratios_coll, key=lambda x: x["element"])
]
# Distinct values observed across all configurations.
row_dict["nperiodic_dimensions"] = config_df.agg(
    sf.collect_set("nperiodic_dimensions")
).collect()[0][0]
row_dict["dimension_types"] = (
    config_df.select("dimension_types")
    .agg(sf.collect_set("dimension_types"))
    .collect()[0][0]
)

nproperty_objects = prop_df.count()
row_dict["nproperty_objects"] = nproperty_objects
# Per-property count of rows where the value is not null (one Spark job each).
for prop in [
    "atomization_energy",
    "adsorption_energy",
    "electronic_band_gap",
    "cauchy_stress",
    "formation_energy",
    "energy",
]:
    row_dict[f"{prop}_count"] = prop_df.select(prop).where(f"{prop} is not null").count()

# Forces are counted when the column differs from the literal string "[]";
# null values also fail the comparison and are not counted.
row_dict[f"atomic_forces_count"] = (
    prop_df.select("atomic_forces_00").filter(sf.col("atomic_forces_00") != "[]").count()
)
# Variance and mean of the non-null energy values.
prop = "energy"
row_dict[f"{prop}_variance"] = (
    prop_df.select(prop).where(f"{prop} is not null").agg(sf.variance(prop))
).first()[0]
row_dict[f"{prop}_mean"] = (
    prop_df.select(prop).where(f"{prop} is not null").agg(sf.mean(prop))
).first()[0]
row_dict["nconfigurations"] = config_df.count()
# Descriptive metadata copied from the `self` record.
row_dict["authors"] = self.authors
row_dict["description"] = self.description
row_dict["license"] = self.data_license
# NOTE: str() of a dict stores the Python repr (single quotes, None), not JSON.
row_dict["links"] = str(
    {
        "source-publication": self.publication_link,
        "source-data": self.data_link,
        "other": self.other_links,
    }
)
row_dict["name"] = self.name
row_dict["publication_year"] = self.publication_year
row_dict["doi"] = self.doi