In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .appName("pyspark_alto_text_chunking") \
    .master("local[4]") \
    .config("spark.driver.memory", "32g") \
    .config("spark.executor.memory", "32g") \
    .config("spark.sql.orc.enableVectorizedReader", "false") \
    .config("spark.sql.parquet.columnarReaderBatchSize", "256") \
    .config("spark.sql.orc.columnarReaderBatchSize", "256") \
    .config("spark.sql.shuffle.partitions", "1024") \
    .getOrCreate()

spark.sparkContext.setLogLevel("ERROR")

25/06/26 17:26:15 WARN Utils: Your hostname, WTDDXK0DJDFN resolves to a loopback address: 127.0.0.1; using 10.21.25.239 instead (on interface en0)
25/06/26 17:26:15 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/06/26 17:26:16 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [2]:
spark.sql("SHOW TABLES").show(truncate=False)

+---------+-----------------------------------+-----------+
|namespace|tableName                          |isTemporary|
+---------+-----------------------------------+-----------+
|default  |alto_sentence                      |false      |
|default  |alto_sentence_en_doc_vectors       |false      |
|default  |alto_sentence_en_entities          |false      |
|default  |alto_sentence_en_sample_005pct     |false      |
|default  |alto_sentence_sample_001pct        |false      |
|default  |alto_sentence_with_fasttext_lang   |false      |
|default  |alto_sentence_with_lang            |false      |
|default  |df_with_spellcheck_sampled_5pct    |false      |
|default  |df_with_spellcheck_sampled_pt_05pct|false      |
|default  |iiif_manifests                     |false      |
|default  |iiif_manifests_old                 |false      |
|default  |images                             |false      |
|default  |manifests_with_images              |false      |
|default  |manifests_with_images_and_tex

In [11]:
spark.table("plain_text_renderings").printSchema()

root
 |-- id: string (nullable = true)
 |-- raw_text_url: string (nullable = true)
 |-- text: string (nullable = true)
 |-- download_status: string (nullable = true)



In [12]:
spark.table("plain_text_renderings").count()

226145

In [4]:
# Install required packages if not already installed
# !pip install langchain langchain-text-splitters

import pandas as pd
import pyspark.sql.functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType
from langchain.text_splitter import RecursiveCharacterTextSplitter
from typing import Iterator

# Define the chunking function that will be applied to each partition
def chunk_text_partition(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """
    Function to chunk text using LangChain's RecursiveCharacterTextSplitter.
    This function will be applied to each partition of the DataFrame.
    """
    # Initialize the text splitter
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,  # Maximum chunk size
        chunk_overlap=200,  # Overlap between chunks
    )
    
    for df in iterator:
        # Process each row in the partition
        chunked_rows = []
        
        for _, row in df.iterrows():
            if pd.isna(row['text']) or row['text'] is None or row['text'].strip() == '':
                # Skip empty or null text
                continue
                
            # Split the text into chunks
            chunks = text_splitter.split_text(row['text'])
            
            # Create a row for each chunk
            for chunk_idx, chunk in enumerate(chunks):
                chunked_rows.append({
                    'id': row['id'],
                    'chunk_text': chunk,
                    'chunk_index': chunk_idx,
                    'total_chunks': len(chunks),
                })
        
        # Return the chunked data as a DataFrame
        if chunked_rows:
            yield pd.DataFrame(chunked_rows)
        else:
            # Return empty DataFrame with correct schema if no chunks
            yield pd.DataFrame(columns=['id', 'chunk_text', 
                                      'chunk_index', 'total_chunks'])

# Define the return schema for mapInPandas
chunk_schema = StructType([
    StructField("id", StringType(), True),
    StructField("chunk_text", StringType(), True),
    StructField("chunk_index", IntegerType(), True),
    StructField("total_chunks", IntegerType(), True),
])

print("Text chunking functions defined successfully!")

Text chunking functions defined successfully!


In [5]:
# Apply text chunking to a sample of the data first (for testing)
# Let's start with a small sample to test the chunking
filtered_df = spark.table("plain_text_renderings").filter(
    F.col("text").isNotNull() & 
    (F.col("text") != "") & 
    (F.col("download_status") == "success")
)


In [7]:
chunks_df = filtered_df.limit(10).mapInPandas(
    chunk_text_partition,
    schema=chunk_schema
    )


chunks_df.show(truncate=False)

[Stage 5:>                                                          (0 + 1) / 1]

+--------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

In [6]:
!hadoop fs -ls /user/hive/warehouse/

Found 20 items
drwxr-xr-x   - ubuntu supergroup          0 2025-06-16 16:47 /user/hive/warehouse/alto_sentence
drwxr-xr-x   - ubuntu supergroup          0 2025-06-18 15:16 /user/hive/warehouse/alto_sentence_en_doc_vectors
drwxr-xr-x   - ubuntu supergroup          0 2025-06-20 10:37 /user/hive/warehouse/alto_sentence_en_entities
drwxr-xr-x   - ubuntu supergroup          0 2025-06-18 10:15 /user/hive/warehouse/alto_sentence_en_sample_005pct
drwxr-xr-x   - ubuntu supergroup          0 2025-06-17 08:47 /user/hive/warehouse/alto_sentence_sample_001pct
drwxr-xr-x   - ubuntu supergroup          0 2025-06-17 15:06 /user/hive/warehouse/alto_sentence_with_fasttext_lang
drwxr-xr-x   - ubuntu supergroup          0 2025-06-17 15:28 /user/hive/warehouse/alto_sentence_with_lang
drwxr-xr-x   - ubuntu supergroup          0 2025-06-16 16:08 /user/hive/warehouse/alto_sentences
drwxr-xr-x   - ubuntu supergroup          0 2025-06-02 14:17 /user/hive/warehouse/df_with_spellcheck_sampled_5pct
drwxr-xr-x   - 

In [5]:
!hadoop fs -rm -r /user/hive/warehouse/plain_text_chunks

rm: `/user/hive/warehouse/plain_text_chunks': No such file or directory


In [7]:
chunks_df = filtered_df.mapInPandas(
    chunk_text_partition,
    schema=chunk_schema
)
chunks_df.write.mode("overwrite").saveAsTable("plain_text_chunks")

                                                                                

In [14]:
chunks_df = spark.table("plain_text_chunks")

In [13]:
chunks_df.count()

74321464

In [17]:
chunks_df.sample(False, 0.1).show(truncate=False)

+--------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [35]:
chunks_df.sample(False, 0.001, seed=2432).show(truncate=False)

+--------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [19]:
chunks_df.select(
    F.col("id")).distinct().count()

                                                                                

222016

In [75]:
!curl http://ec2-18-134-162-140.eu-west-2.compute.amazonaws.com:8080/embed \
    -X POST \
    -d '{"inputs": ["Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?", "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: Explain gravity"]}' \
    -H "Content-Type: application/json"


[[0.0016940979,-0.0059597637,0.0750649,-0.006030069,-0.006592515,-0.005500072,0.029896164,0.07186328,0.050576866,0.0044211494,-0.033076145,0.00533242,0.013747476,-0.030263918,-0.003593705,-0.009107297,-0.06286415,0.010302495,-0.07398327,-0.0013540615,-0.013639313,-0.025569657,0.022021921,0.045428324,-0.0015967515,-0.0130985,-0.0068412893,-0.15030286,0.019285405,-0.00508635,-0.026824344,-0.04590424,-0.015748486,-0.008193322,0.010005048,0.00092681893,0.023060283,0.0034530933,0.03054514,0.03787857,0.012027689,-0.024228439,-0.0036261538,0.002674322,-0.028576579,-0.04711566,0.012125036,0.027451687,0.00024184499,-0.010437698,0.005416246,-0.0030366671,0.03456879,0.0018360615,0.017727863,-0.0037667651,0.094490916,-0.005792111,-0.018560715,-0.046639744,-0.014115229,-0.00695486,-0.0019645046,-0.027797807,-0.012049322,-0.014677675,0.045211997,-0.028771272,0.017370926,0.011670752,-0.023536198,-0.051009517,0.008712503,0.030155754,0.0025296547,-0.01408278,-0.005262114,0.0064789443,-0.05187482,-0.018

In [74]:
!curl http://ec2-18-134-162-140.eu-west-2.compute.amazonaws.com:8080/embed \
    -X POST \
    -d '{"inputs": ["Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?"]}' \
    -H "Content-Type: application/json"


[[0.001696309,-0.005943154,0.075010054,-0.0061644707,-0.006531532,-0.0055167153,0.030077435,0.07241903,0.050697643,0.004612557,-0.032992333,0.0053520775,0.013721615,-0.030552454,-0.0035950416,-0.009111756,-0.06283226,0.010245329,-0.07414638,-0.001352189,-0.013635247,-0.025499964,0.022304371,0.045429233,-0.0018622963,-0.01307386,-0.0068500116,-0.14984737,0.019389473,-0.0050848783,-0.02686025,-0.04599062,-0.015610901,-0.008291267,0.009975431,0.00087851804,0.023103269,0.003389919,0.030638821,0.03800164,0.012134614,-0.024226045,-0.0037380874,0.0026746893,-0.02873874,-0.04702703,0.01219939,0.027594373,0.00022367797,-0.010450451,0.0053817662,-0.00312272,0.034633312,0.0018515004,0.01794282,-0.0038838324,0.09457226,-0.0058891745,-0.018547392,-0.046854295,-0.014121064,-0.0068823993,-0.0019473141,-0.027745515,-0.012080635,-0.014650064,0.04512695,-0.028760333,0.017349044,0.011681186,-0.023427147,-0.05099993,0.008744695,0.030099027,0.0025586332,-0.014121064,-0.0052765063,0.0064289705,-0.051993154,

In [48]:
from langchain_huggingface.embeddings import HuggingFaceEndpointEmbeddings

embeddings = HuggingFaceEndpointEmbeddings(
    model="http://ec2-18-134-162-140.eu-west-2.compute.amazonaws.com:8080")

In [62]:
# text = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: What is the capital of China?"
text = "What is the capital of China?"
query_result = embeddings.embed_query(text)
query_result[:3]

[0.0021538420114666224, 0.02161426842212677, 0.07863157987594604]

In [2]:
import requests
import json
import time
from typing import List, Optional, Union

class Qwen3Embedding():

    def __init__(self, endpoint: str = "http://ec2-18-134-162-140.eu-west-2.compute.amazonaws.com:8080"):
        """
        Initialize the Qwen3Embedding class with the model endpoint.

        Args:
            endpoint (str): The URL of the embedding service endpoint.
        """
        self.endpoint = endpoint

    def get_detailed_instruct(self, task_description: str, query: str) -> str:
        if task_description is None:
            task_description = self.instruction
        return f'Instruct: {task_description}\nQuery: {query}'

    def get_embeddings(self,
            sentences: Union[List[str], str], is_query: bool = False, instruction=None,
            max_retries: int = 5) -> Optional[dict]:
        """
        Get embeddings from the embedding service endpoint with retry logic.

        Args:
            sentences (Union[List[str], str]): The input sentences or a single sentence to embed.
            is_query (bool): Whether the input is a query. If True, uses a different instruction.
            instruction (str): Custom instruction for the embedding service.
            max_retries (int): Maximum number of retries for failed requests.

        Returns:
            dict: JSON response from the embedding service, or None if all retries failed
        """

        if isinstance(sentences, str):
            sentences = [sentences]
        if is_query:
            sentences = [
                self.get_detailed_instruct(instruction, sent)
                for sent in sentences]

        # Prepare the request payload
        payload = {
            "inputs": sentences
        }

        # Set headers
        headers = {
            "Content-Type": "application/json"
        }

        # Retryable status codes and exceptions
        # Server errors and rate limiting
        retryable_status_codes = {500, 502, 503, 504, 429}

        for attempt in range(
                max_retries + 1):  # +1 because we want max_retries actual retries
            try:
                # Make the POST request
                response = requests.post(
                    self.endpoint, json=payload, headers=headers, timeout=30)

                # If successful, return the result
                if response.status_code == 200:
                    return response.json()

                # Check if error is retryable
                if response.status_code in retryable_status_codes:
                    if attempt < max_retries:
                        wait_time = 8 ** attempt  # Exponential backoff: 1s, 2s, 4s
                        print(
                            f"Attempt {attempt + 1} failed with status {response.status_code}. Retrying in {wait_time}s...")
                        time.sleep(wait_time)
                        continue
                    else:
                        print(
                            f"All {max_retries} retries failed. Last status: {response.status_code}")
                        return None
                else:
                    # Non-retryable error (e.g., 400, 401, 404)
                    print(
                        f"Non-retryable error: {response.status_code} - {response.text}")
                    return None

            except (requests.exceptions.ConnectionError,
                    requests.exceptions.Timeout,
                    requests.exceptions.RequestException) as e:
                if attempt < max_retries:
                    wait_time = 8 ** attempt  # Exponential backoff
                    print(
                        f"Attempt {attempt + 1} failed with error: {e}. Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    print(f"All {max_retries} retries failed. Last error: {e}")
                    return None

        return None

In [77]:
embeddings = Qwen3Embedding(
    endpoint="http://ec2-18-134-162-140.eu-west-2.compute.amazonaws.com:8080")

query_outputs = embeddings.get_embeddings(
    ["What is the capital of China?", "Explain gravity"],
    is_query=True,
    instruction="Given a web search query, retrieve relevant passages that answer the query")

query_outputs

[[0.0016940979,
  -0.0059597637,
  0.0750649,
  -0.006030069,
  -0.006592515,
  -0.005500072,
  0.029896164,
  0.07186328,
  0.050576866,
  0.0044211494,
  -0.033076145,
  0.00533242,
  0.013747476,
  -0.030263918,
  -0.003593705,
  -0.009107297,
  -0.06286415,
  0.010302495,
  -0.07398327,
  -0.0013540615,
  -0.013639313,
  -0.025569657,
  0.022021921,
  0.045428324,
  -0.0015967515,
  -0.0130985,
  -0.0068412893,
  -0.15030286,
  0.019285405,
  -0.00508635,
  -0.026824344,
  -0.04590424,
  -0.015748486,
  -0.008193322,
  0.010005048,
  0.00092681893,
  0.023060283,
  0.0034530933,
  0.03054514,
  0.03787857,
  0.012027689,
  -0.024228439,
  -0.0036261538,
  0.002674322,
  -0.028576579,
  -0.04711566,
  0.012125036,
  0.027451687,
  0.00024184499,
  -0.010437698,
  0.005416246,
  -0.0030366671,
  0.03456879,
  0.0018360615,
  0.017727863,
  -0.0037667651,
  0.094490916,
  -0.005792111,
  -0.018560715,
  -0.046639744,
  -0.014115229,
  -0.00695486,
  -0.0019645046,
  -0.027797807,
  -0

In [78]:
documents = [
    "The capital of China is Beijing.",
    "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun."
]

doc_outputs = embeddings.get_embeddings(
    documents,
    is_query=False)

doc_outputs

[[0.0020910008,
  0.008498907,
  0.07896339,
  0.019156266,
  -0.005545087,
  0.013524055,
  0.03732324,
  0.04033608,
  0.046361763,
  -0.04056092,
  -0.00509822,
  0.013703926,
  0.0030549972,
  -0.029139109,
  0.0003597421,
  -0.009285843,
  -0.062954865,
  0.019043846,
  -0.079143256,
  -0.005840188,
  -0.009611859,
  0.0044237035,
  0.010758537,
  0.02990356,
  0.0020769485,
  -0.01726762,
  -0.004612006,
  -0.13463348,
  0.027205495,
  -0.013040651,
  -0.056749314,
  -0.039189402,
  -0.059672218,
  0.0068688253,
  0.0065371883,
  -0.005457962,
  0.0075545837,
  -0.00095697015,
  -0.0141648445,
  0.00060003856,
  0.012838296,
  -0.008319036,
  -0.0035861789,
  0.0142322965,
  -0.08292055,
  -0.0315224,
  -0.0017649846,
  0.002780975,
  -0.006087511,
  -0.011101416,
  0.0012963362,
  0.015536361,
  0.010106504,
  -0.008240342,
  0.010117746,
  0.020853799,
  0.11322882,
  -0.0042101066,
  -0.0057783574,
  0.028869303,
  -0.0138725545,
  -0.008813681,
  -0.017795991,
  -0.015266554,

In [3]:
import numpy as np
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField, StringType, IntegerType
from typing import Iterator
import pandas as pd

# Define the embedding function that will be applied to each partition
def embed_text_partition(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    """
    Function to generate embeddings for text chunks using Qwen3Embedding.
    This function will be applied to each partition of the DataFrame.
    
    Args:
        iterator: Iterator of pandas DataFrames (one per partition)
        
    Returns:
        Iterator of pandas DataFrames with embeddings added
    """
    # Initialize the embedding model
    embeddings_model = Qwen3Embedding(
        endpoint="http://ec2-18-134-162-140.eu-west-2.compute.amazonaws.com:8080"
    )
    
    # Process each DataFrame in the partition
    for df in iterator:
        # Skip empty DataFrames
        if df.empty:
            yield df
            continue
            
        # Process in batches of 32 for efficiency
        batch_size = 32
        result_dfs = []
        
        for i in range(0, len(df), batch_size):
            batch = df.iloc[i:i+batch_size]
            texts = batch['chunk_text'].tolist()
            
            # Get embeddings for the batch
            embedding_results = embeddings_model.get_embeddings(texts, is_query=False)
            
            # Add embeddings to the batch
            embedding_vectors = embedding_results
            batch = batch.copy()
            batch['embedding'] = embedding_vectors
            result_dfs.append(batch)
        
        # Combine all batches and yield
        if result_dfs:
            yield pd.concat(result_dfs)
        else:
            # Return empty DataFrame with correct schema if all embedding failed
            empty_df = df.copy()
            empty_df['embedding'] = None
            yield empty_df

# Define the output schema including embeddings
embedding_schema = StructType([
    StructField("id", StringType(), True),
    StructField("chunk_text", StringType(), True),
    StructField("chunk_index", IntegerType(), True),
    StructField("total_chunks", IntegerType(), True),
    StructField("embedding", ArrayType(FloatType()), True)
])

In [4]:
# Test the embedding function on a small sample
chunks_df = spark.table("plain_text_chunks")
sample_chunks_df = chunks_df.sample(False, 0.0001, seed=42)
sample_chunks_df.count()

                                                                                

7293

In [9]:
import pyspark.sql.functions as F

# Apply the embedding function to the sample
sample_embedded_df = sample_chunks_df.mapInPandas(
    embed_text_partition,
    schema=embedding_schema
)
sample_embedded_df

DataFrame[id: string, chunk_text: string, chunk_index: int, total_chunks: int, embedding: array<float>]

In [6]:
sample_embedded_df.write.mode("overwrite").saveAsTable("plain_text_chunks_with_embeddings")

                                                                                

In [7]:
sample_embedded_df = spark.table("plain_text_chunks_with_embeddings")

In [8]:
sample_embedded_df.show(truncate=False)


+--------+------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [9]:
sample_embedded_df_pd = sample_embedded_df.toPandas()

In [45]:
# See src/wc_simd/embed.py for embedding code

# Index Embeddings

This is for experimentation. Actual indexing code in `src/wc_simd/index.py`

## Peak at embeddings

In [6]:
i = 6

table_name = f"plain_text_chunks_with_embeddings_{i}"

spark.table(table_name).limit(1000).toPandas().to_csv(f"../data/tmp/{table_name}.csv", index=False)

## Prep metadata

In [49]:
spark.sql("SELECT contributors FROM works").printSchema()

root
 |-- contributors: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- agent: struct (nullable = true)
 |    |    |    |-- id: string (nullable = true)
 |    |    |    |-- identifiers: array (nullable = true)
 |    |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |    |-- identifierType: struct (nullable = true)
 |    |    |    |    |    |    |-- id: string (nullable = true)
 |    |    |    |    |    |    |-- label: string (nullable = true)
 |    |    |    |    |    |    |-- type: string (nullable = true)
 |    |    |    |    |    |-- type: string (nullable = true)
 |    |    |    |    |    |-- value: string (nullable = true)
 |    |    |    |-- label: string (nullable = true)
 |    |    |    |-- type: string (nullable = true)
 |    |    |-- primary: boolean (nullable = true)
 |    |    |-- roles: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- label: string (

In [None]:
import pyspark.sql.functions as F

(
    spark .table("works") .select(
        "id",
        F.explode("contributors").alias("contributors"),
        F.col("production").getItem(0).getField("dates").getItem(0).getField("label").alias("production_date"),
    ) .select(
        "id",
        "contributors.primary",
        F.col("contributors.agent.label").alias("contributor"),
        "production_date") .where(
        F.col("primary") == True) .drop("primary") .show(
        50,
        truncate=False))

+--------+-------------------------------------------------------------------------------------------------+-----------------------+
|id      |contributor                                                                                      |production_date        |
+--------+-------------------------------------------------------------------------------------------------+-----------------------+
|hg7q3wz8|Scot, Michael, approximately 1175-approximately 1234                                             |1549                   |
|na3nt32r|Smith, Alexander                                                                                 |NULL                   |
|fs2hg2wx|Lackington, James                                                                                |NULL                   |
|svxuqzm7|South Australia. Central Board of Health.                                                        |[1945]                 |
|z5he9j2w|Helvétius, Jean-Adrien, 1662-1727                          

## Indexing Function

In [9]:
import numpy as np
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField, StringType, IntegerType
from typing import Iterator
import pandas as pd
from langchain_elasticsearch import ElasticsearchStore, DenseVectorStrategy
import dotenv
import os
import time
import logging

# Load environment variables from .env file at the driver level
dotenv.load_dotenv()

# Get environment variables on the driver node
ES_CLOUD_ID = os.environ.get("ES_CLOUD_ID")
ES_USERNAME = os.environ.get("ES_USERNAME") 
ES_PASSWORD = os.environ.get("ES_PASSWORD")

# Define the embedding function that will be applied to each partition
def index_embeddings_partition(
        iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:

    # Initialize ElasticsearchStore using the pre-loaded environment variables
    db = ElasticsearchStore(
        # replace with your cloud ID
        es_cloud_id=ES_CLOUD_ID,
        index_name="vectorsearch",
        embedding=None,
        es_user=ES_USERNAME,
        es_password=ES_PASSWORD,
        # replace with your password
        strategy=DenseVectorStrategy(),  # strategy for dense vector search
    )

    # Process each DataFrame in the partition
    for df in iterator:
        # Skip empty DataFrames
        if df.empty:
            yield df
            continue

        # Combine "id" and "chunk_idx" to create a unique document ID
        df['es_doc_id'] = df['id'] + "_" + df['chunk_index'].astype(str)

        # Create list of texts
        texts = df['chunk_text'].tolist()
        # Create lists of embeddings from embedding column
        embeddings = df['embedding'].tolist()
        
        # Add embeddings with retry logic
        max_retries = 9999
        for attempt in range(max_retries + 1):
            try:
                db.add_embeddings(text_embeddings=zip(texts, embeddings),
                                  metadatas=df[['contributor', 'date']].to_dict(orient='records'),
                                  ids=df['es_doc_id'].tolist())
                break  # Success, exit retry loop
            except Exception as e:
                if attempt < max_retries:
                    wait_time = min(2 ** min(attempt, 10), 60)  # Exponential backoff, max 1 minutes
                    print(f"Attempt {attempt + 1} failed with error: {e}. Retrying in {wait_time}s...")
                    time.sleep(wait_time)
                else:
                    print(f"All {max_retries} retries failed. Last error: {e}")
                    raise e  # Re-raise the exception after all retries

        # Return just the document ID
        yield df[['es_doc_id']]


        
# Define the output schema including embeddings
index_embedding_schema = StructType([
    StructField("es_doc_id", StringType(), True),
])

In [4]:
import pyspark.sql.functions as F

work_contributor_w_date = spark.table("works").select(
    "id", F.explode("contributors").alias("contributors"),
    F.col("production").getItem(0).getField("dates").getItem(0).
    getField("label").alias("date"),).select(
    "id", "contributors.primary", F.col("contributors.agent.label").alias(
        "contributor"),
    "date").where(
    F.col("primary") == True).drop("primary")
work_contributor_w_date

DataFrame[id: string, contributor: string, date: string]

In [None]:
# df_idx = 0
# table_name = f"plain_text_chunks_with_embeddings_{df_idx}"
# df_to_index = spark.table(table_name).join(
#     work_contributor_w_date, on="id", how="left")
# df_to_index

DataFrame[id: string, chunk_text: string, chunk_index: int, total_chunks: int, embedding: array<float>, contributor: string, date: string]

In [None]:
# df_to_index.count()

                                                                                

269315

In [None]:
# for df_idx in range(1, 287):
#     table_name = f"plain_text_chunks_with_embeddings_{df_idx}"
#     df_to_index = spark.table(table_name).join(
#         work_contributor_w_date, on="id", how="left")

#     indexed_df = (
#         df_to_index
#         # .sample(False, 0.001, seed=42)
#         .mapInPandas(
#             index_embeddings_partition,
#             schema=index_embedding_schema
#         )
#     )
#     indexed_df.write.mode("overwrite").saveAsTable(
#         # f"{table_name}_indexed_sample"
#         table_name + "_indexed"
#     )

Attempt 1 failed with error: Connection timed out. Retrying in 1s... + 16) / 23]
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
Attempt 1 failed with error: Connection timed out. Retrying in 1s... + 16) / 23]
Attempt 1 failed with error: Connection timed out. Retrying in 1s...
          

In [91]:
spark.table(f"{table_name}_indexed_sample").show(truncate=False)

+-------------+
|es_doc_id    |
+-------------+
|es58wtq4_1136|
|es58wtq4_225 |
|es5vqmyy_886 |
|es5vqmyy_452 |
|eujegnnd_164 |
|ewyrtbke_17  |
|f4k6mvt4_119 |
|f53zme34_41  |
|f5594z6t_601 |
|f5594z6t_318 |
|f5594z6t_429 |
|f5594z6t_115 |
|f8yd4388_108 |
|fabujtup_15  |
|fcmrnpf2_1265|
|fdpva6s3_324 |
|fdtqn3ww_872 |
|erxhfqpn_369 |
|eskx4s96_618 |
|eskx4s96_579 |
+-------------+
only showing top 20 rows



## Completion Tables

In [1]:
indexed_tables_df = spark.sql("show tables").where(
    "tableName RLIKE 'plain_text_chunks_with_embeddings_[0-9]+_indexed$'")
indexed_tables_df.show(
    500, truncate=False)
indexed_tables_df.count()

NameError: name 'spark' is not defined

In [None]:
for df_idx in range(0, 229):
    table_name = f"plain_text_chunks_with_embeddings_{df_idx}"
    indexed_table_name = table_name + "_indexed"
    # result = spark.sql(f"""DROP TABLE IF EXISTS {indexed_table_name}""").collect()
    print(f"Dropped table {indexed_table_name}: {result}")

Dropped table plain_text_chunks_with_embeddings_0_indexed: []
Dropped table plain_text_chunks_with_embeddings_1_indexed: []
Dropped table plain_text_chunks_with_embeddings_2_indexed: []
Dropped table plain_text_chunks_with_embeddings_3_indexed: []
Dropped table plain_text_chunks_with_embeddings_4_indexed: []
Dropped table plain_text_chunks_with_embeddings_5_indexed: []
Dropped table plain_text_chunks_with_embeddings_6_indexed: []
Dropped table plain_text_chunks_with_embeddings_7_indexed: []
Dropped table plain_text_chunks_with_embeddings_8_indexed: []
Dropped table plain_text_chunks_with_embeddings_9_indexed: []
Dropped table plain_text_chunks_with_embeddings_10_indexed: []
Dropped table plain_text_chunks_with_embeddings_11_indexed: []
Dropped table plain_text_chunks_with_embeddings_12_indexed: []
Dropped table plain_text_chunks_with_embeddings_13_indexed: []
Dropped table plain_text_chunks_with_embeddings_14_indexed: []
Dropped table plain_text_chunks_with_embeddings_15_indexed: []
Dr

## Test Index

In [9]:
from sentence_transformers import SentenceTransformer

emeddings_model = SentenceTransformer("Qwen/Qwen3-Embedding-0.6B")

In [3]:
import numpy as np
from pyspark.sql.types import ArrayType, FloatType, StructType, StructField, StringType, IntegerType
from typing import Iterator
import pandas as pd
from langchain_elasticsearch import ElasticsearchStore, DenseVectorStrategy
import dotenv
import os
import time
import logging

# Load environment variables from .env file at the driver level
dotenv.load_dotenv()

# Get environment variables on the driver node
ES_CLOUD_ID = os.environ.get("ES_CLOUD_ID")
ES_USERNAME = os.environ.get("ES_USERNAME")
ES_PASSWORD = os.environ.get("ES_PASSWORD")

# Define the embedding function that will be applied to each partition

# Initialize ElasticsearchStore using the pre-loaded environment variables
db = ElasticsearchStore(
    # replace with your cloud ID
    es_cloud_id=ES_CLOUD_ID,
    index_name="vectorsearch_sharded",
    embedding=None,
    es_user=ES_USERNAME,
    es_password=ES_PASSWORD,
    # replace with your password
    strategy=DenseVectorStrategy(),  # strategy for dense vector search
)


In [None]:

# Now you can perform similarity search
query = "Flu remedies?"
query_embedding = emeddings_model.encode(query, prompt_name="query")
results = db.similarity_search_by_vector_with_relevance_scores(query_embedding, k=3)

for doc in results:
    print(doc)

In [17]:
es = db._store.client

doc_ids = ["bs3v5gcz_37", "n2dy86uj_55", "yhjnmwn3_221"]

response = es.mget(
    index="vectorsearch_sharded",
    body={
        "ids": doc_ids
    }
)

# 4. Inspect the returned docs
for doc in response["docs"]:
    if doc.get("found", False):
        _id, chunk_index = doc["_id"].split("_")
        url = f"https://wellcomecollection.org/works/{_id}"
        print(_id)
        print(chunk_index)
        print(url)
        print(doc["_source"]["text"])
        print(doc["_source"]["metadata"])

bs3v5gcz
37
https://wellcomecollection.org/works/bs3v5gcz
most important step of completely isolating the first few casesâno matter what their agesâwhenever the health and interest of the district are again threatened. This provision is a very positive gain to the District for it almost furnishes a guarantee that we shall be able to successfully cope with the danger of Small-pox spread when it once again presents itself. The actual cost of the outbreak was between jQ60 and Â£70. I made an unsuccessful attempt to induce a few neighbouring Authorities to combine, with the view of providing a small permanent structure for Small-pox isolation. The necessity for such provision is growing and will continue to do so ; for with the increased facilities of communication with the Metropolis, and with the growth of the number of residents who daily visit the City, the risks of imported infection to these outlying districts grows greater year by year. These 9 cases of Small-pox, like each of t

In [12]:
response

ObjectApiResponse({'docs': [{'_index': 'vectorsearch_sharded', '_id': 'bs3v5gcz_37', '_version': 1, '_seq_no': 118027, '_primary_term': 1, 'found': True, '_source': {'text': 'most important step of completely isolating the first few casesâ\x80\x94no matter what their agesâ\x80\x94whenever the health and interest of the district are again threatened. This provision is a very positive gain to the District for it almost furnishes a guarantee that we shall be able to successfully cope with the danger of Small-pox spread when it once again presents itself. The actual cost of the outbreak was between jQ60 and Â£70. I made an unsuccessful attempt to induce a few neighbouring Authorities to combine, with the view of providing a small permanent structure for Small-pox isolation. The necessity for such provision is growing and will continue to do so ; for with the increased facilities of communication with the Metropolis, and with the growth of the number of residents who daily visit the City, t