# In [None]:  (notebook cell marker left over from export)
from pyspark.sql.types import *
from delta.tables import *
from pyspark.sql.functions import *
from datetime import datetime
from transformers import AutoTokenizer, AutoModel
import numpy as np
import pandas as pd
import logging
import pyspark
import torch

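# Set Spark Config
# NOTE(review): no configuration is applied here, and `spark` is not defined in
# this cell; presumably the session is provided by the notebook runtime - confirm.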

def trim_whitespace(df: DataFrame) -> DataFrame:
    """Apply ``trim`` to every string-typed column of ``df``.

    Columns of any other type are left untouched, and column names and order
    are preserved.

    Args:
        df: The DataFrame whose string columns should be trimmed.

    Returns:
        A DataFrame in which each ``StringType`` column has been replaced by
        ``trim(col(name))``.
    """
    for field in df.schema.fields:
        # ``dataType`` is a DataType instance (e.g. ``StringType()``), not a
        # str, so it has to be checked with isinstance; comparing it to the
        # literal "StringType" is never true and would skip every column.
        if isinstance(field.dataType, StringType):
            df = df.withColumn(field.name, trim(col(field.name)))
    return df
    
def load(url: str) -> DataFrame:
    """Read the Parquet data at ``url`` and pass it through ``trim_whitespace``.

    Args:
        url: Location handed to ``spark.read.parquet``.

    Returns:
        The DataFrame returned by ``trim_whitespace`` for the data that was read.
    """
    raw_frame = spark.read.parquet(url)
    return trim_whitespace(raw_frame)

# Source DataFrame for the embedding job; "YOUR_DATA_PATH" is a placeholder
# that must be replaced with the real Parquet location.
DT = load("YOUR_DATA_PATH")

# Configure Logging
# INFO level, with timestamp, level name and message on every record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Model and Tokenizer Initialization (Singleton for PySpark workers)
# Hugging Face model identifier passed to AutoTokenizer / AutoModel below.
model_name = "intfloat/multilingual-e5-large"
# Hard-coded CUDA device: there is no CPU fallback, so moving the model or the
# inputs to this device requires a GPU.
# NOTE(review): this value is created at module level and referenced inside the
# pandas UDF; presumably it is shipped to the executors with the UDF closure -
# confirm before making it depend on the hardware of the machine running this cell.
device = torch.device("cuda")

class ModelWrapper:
    """Lazily load the tokenizer and model once and cache them on the class."""

    # Cached model; stays None until the first successful load.
    model = None
    # Cached tokenizer; stays None until the first successful load.
    tokenizer = None

    @staticmethod
    def get_model_and_tokenizer():
        """Return ``(tokenizer, model)``, loading both on first use.

        When both class attributes are already populated they are returned
        as-is. Otherwise the tokenizer and model named by ``model_name`` are
        loaded, the model is moved to ``device`` and converted to half
        precision, and both are stored on the class for later calls.
        """
        already_loaded = (
            ModelWrapper.model is not None and ModelWrapper.tokenizer is not None
        )
        if already_loaded:
            return ModelWrapper.tokenizer, ModelWrapper.model

        logger.info("Loading model and tokenizer...")
        ModelWrapper.tokenizer = AutoTokenizer.from_pretrained(model_name)
        loaded_model = AutoModel.from_pretrained(model_name)
        ModelWrapper.model = loaded_model.to(device).half()
        logger.info("Model and tokenizer loaded successfully.")
        return ModelWrapper.tokenizer, ModelWrapper.model

def _masked_mean_pool(last_hidden_state, attention_mask):
    """Average token vectors over real (non-padding) positions only.

    Args:
        last_hidden_state: Tensor indexed as (batch, token, hidden).
        attention_mask: Tensor indexed as (batch, token); 1 marks a real token
            and 0 marks padding.

    Returns:
        A float32 tensor with one pooled vector per batch row.
    """
    # Accumulate in float32: the model runs in half precision, and summing
    # many fp16 values loses accuracy.
    mask = attention_mask.unsqueeze(-1).to(torch.float32)
    summed = (last_hidden_state.to(torch.float32) * mask).sum(dim=1)
    # Guard against division by zero for a row with no real tokens.
    token_counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / token_counts


# UDF to Generate Embeddings with progress monitoring using logging
@pandas_udf("array<float>")
def generate_embeddings_udf(full_name_series: pd.Series) -> pd.Series:
    """Embed each name in ``full_name_series`` with the shared model.

    Texts are tokenized and run through the model in fixed-size batches. Each
    embedding is the attention-mask-weighted mean of the final hidden states,
    so padding added to align a batch does not influence the result and a
    text's embedding does not depend on which other texts share its batch.

    Null entries cannot be tokenized; they are skipped and produce a null
    embedding at the same position.

    NOTE(review): the multilingual-e5 model card recommends prefixing inputs
    with "query: " or "passage: "; texts are passed through unprefixed here,
    as before - confirm this is intended.

    Args:
        full_name_series: The texts to embed, one per row.

    Returns:
        A Series of the same length holding a list of floats per row, or None
        where the input was null.
    """
    tokenizer, model = ModelWrapper.get_model_and_tokenizer()

    batch_size = 512  # Adjust for optimal GPU utilization

    # Only non-null rows are sent to the tokenizer.
    present = full_name_series.notna().to_numpy()
    texts = full_name_series[present].tolist()
    total_batches = (len(texts) + batch_size - 1) // batch_size  # ceiling division

    start_time = datetime.now()
    logger.info(f"Starting embedding generation for {len(texts)} texts in {total_batches} batches.")

    pooled_rows = []
    for current_batch, start in enumerate(range(0, len(texts), batch_size), start=1):
        batch_texts = texts[start:start + batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)

        # Perform inference and compute embeddings
        with torch.no_grad():
            outputs = model(**inputs)
            pooled = _masked_mean_pool(outputs.last_hidden_state, inputs["attention_mask"])

        pooled_rows.extend(pooled.cpu().numpy().tolist())

        # Log progress with a simple linear estimate of the remaining time.
        elapsed_time = (datetime.now() - start_time).total_seconds()
        estimated_time_remaining = (elapsed_time / current_batch) * (total_batches - current_batch)

        logger.info(f"Processed batch {current_batch}/{total_batches}. "
                    f"Elapsed time: {elapsed_time:.2f}s, Estimated remaining: {estimated_time_remaining:.2f}s.")

    logger.info("Completed embedding generation.")

    # Put each embedding back at the position of its source row; rows that
    # were null keep None.
    results = [None] * len(full_name_series)
    for position, vector in zip(np.flatnonzero(present), pooled_rows):
        results[position] = vector

    # dtype=object keeps the per-row lists intact (and is well-defined for an
    # empty batch).
    return pd.Series(results, dtype=object)

# Add Embeddings Column with Monitoring
logger.info("Adding embeddings column to the DataFrame...")
# Attach an "embedding" column computed from the "full_name" column by the UDF.
df = DT.withColumn("embedding", generate_embeddings_udf(col("full_name")))
# NOTE(review): withColumn is a lazy Spark transformation, so the UDF has
# presumably not executed yet when the next message is logged; the embedding
# work is triggered by the df.show() action below - confirm and consider
# rewording or moving this message.
logger.info("Embedding generation completed. DataFrame updated.")

# Show the updated DataFrame
df.show()