<div style="display:flex; align-items:flex-start; margin-bottom:1rem;">
  <!-- Left: Book cover -->
  <img
    src="https://adb-1376134742576436.16.azuredatabricks.net/files/Images/book_cover.JPG"
    style="width:35%; margin-right:1rem; border-radius:4px; box-shadow:0 2px 6px rgba(0,0,0,0.1);"
    alt="Book Cover"/>
  <!-- Right: Metadata -->
  <div style="flex:1;">
    <!-- O'Reilly logo above title -->
    <div style="display:flex; flex-direction:column; align-items:flex-start; margin-bottom:0.75rem;">
      <img
        src="https://cdn.oreillystatic.com/images/sitewide-headers/oreilly_logo_mark_red.svg"
        style="height:2rem; margin-bottom:0.25rem;"
        alt="O’Reilly"/>
      <span style="font-size:1.75rem; font-weight:bold; line-height:1.2;">
        AI, ML and GenAI in the Lakehouse
      </span>
    </div>
    <!-- Details, now each on its own line -->
    <div style="font-size:0.9rem; color:#555; margin-bottom:1rem; line-height:1.4;">
      <div><strong>Name:</strong> 09-02-RAG Data Preparation</div>
      <div><strong>Author:</strong> Bennie Haelen</div>
      <div><strong>Date:</strong> 7-5-2025</div>
    </div>
    <!-- Purpose -->
    <div style="font-weight:600; margin-bottom:0.75rem;">
      Purpose: This notebook demonstrates how to prepare data for a RAG solution
    </div>
    <!-- Outline -->
    <div style="margin-top:0;">
      <h3 style="margin:0 0 0.25rem;">Table of Contents</h3>
      <ol style="padding-left:1.25rem; margin:0; color:#333;">
        <li>Fetch Wikipedia articles and load them into a DataFrame</li>
        <li>Extract and clean the text content, then split it into manageable chunks</li>
        <li>Calculate the embeddings</li>
        <li>Store the embeddings in a Delta table</li>
      </ol>
    </div>
  </div>
</div>


# Notebook Initialization

## Load our Common Libraries

In [0]:
 %pip install -qq -U llama-index pydantic wikipedia-api requests beautifulsoup4
 %pip install transformers[torch]
 
dbutils.library.restartPython()

## Load our Common Functions

In [0]:
%run ./9-00-Common-Code

# Fetch Wikipedia articles and load them into a DataFrame

To start, you need to fetch Wikipedia articles and load them into a DataFrame.

Steps:
1. Define a list of Wikipedia article titles to fetch
2. Create a function to fetch Wikipedia content
3. Use Spark to create a DataFrame with the articles
4. Ensure that each article is represented as a separate record in the DataFrame

In [0]:
## Import required libraries
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.schema import Document
from llama_index.core.utils import set_global_tokenizer
from transformers import AutoTokenizer
from typing import Iterator
from pyspark.sql.functions import col, udf, length, pandas_udf, explode
from pyspark.sql.types import StructType, StructField, StringType, TimestampType, LongType
import os
import pandas as pd 
import io
import requests
import re
from datetime import datetime

## Ingest the Wikipedia articles

### Initialize the Wikipedia topics

In [0]:
# Cap every Arrow batch exchanged between Spark and the Python workers at 10
# records (this governs how many rows each pandas UDF call receives at once).
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 10)

# Fully-qualified catalog.schema.table name for the raw article text.
# NOTE(review): table_name is not referenced later in this notebook — confirm it is still needed.
table_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.lab_wikipedia_raw_text"

# Define Wikipedia articles to fetch (page titles as they appear in the article URL).
# Each title must appear only once: a duplicate would be fetched, chunked and
# embedded twice, producing duplicate rows in the embeddings table.
WIKIPEDIA_TOPICS = [
    "Neural_network_(machine_learning)",
    "Supervised_learning",
    "Natural_language_processing",
    "Symbolic_artificial_intelligence",
    "Machine_learning",
    "Deep_learning",
    "Neural_network",
    "Decision_tree_learning",
    "Support_vector_machine",
    "Random_forest",
    "Gradient_boosting",
    "K-means_clustering",
    "Principal_component_analysis",
    "Linear_regression",
    "Logistic_regression",
    "Naive_Bayes_classifier",
    "K-nearest_neighbors_algorithm",
    "Ensemble_learning",
    "Cross-validation_(statistics)",
    "Overfitting",
    "Backpropagation",
    "Gradient_descent",
    "Stochastic_gradient_descent",
    "Feature_selection",
    "Dimensionality_reduction",
    "Clustering",
    "Classification",
    "Regression_analysis",
    "Convolutional_neural_network",
    "Recurrent_neural_network",
    "Long_short-term_memory",
    "Generative_adversarial_network",
    "Autoencoder",
    "Multilayer_perceptron",
    "Perceptron",
    "Activation_function",
    "Batch_normalization",
    "Residual_neural_network",
    "Attention_(machine_learning)",
    "Self-attention",
    "Variational_autoencoder",
    "Generative_artificial_intelligence",
    "DALL-E",
    "Midjourney",
    "Stable_Diffusion",
    "Diffusion_model",
    "Optical_character_recognition",
    "Object_detection",
    "Facial_recognition_system"
]

## Define our ingestion function

In [0]:
def fetch_wikipedia_article(title, max_retries=3, base_delay=2):
    """
    Fetch Wikipedia article content using the Wikipedia (MediaWiki) API.

    Queries the TextExtracts ``prop=extracts`` module for the plain-text content of
    the whole article. Missing/empty extracts, non-200 responses, timeouts and other
    exceptions are retried; HTTP 429 responses use exponential back-off with jitter.

    Parameters:
        title (str): The title of the Wikipedia article (e.g., "Albert_Einstein").
        max_retries (int): Total number of request attempts (default 3).
        base_delay (float): Base delay in seconds between attempts (default 2).

    Returns:
        dict or None: A dictionary with keys 'title', 'content', 'url', 'fetch_time'
                      and 'content_length', or None if the article doesn't exist or
                      couldn't be fetched.
    """
    # `time` and `random` drive the retry back-off but are not imported by the
    # notebook's import cell, so bring them into scope here.
    import random
    import time

    url = "https://en.wikipedia.org/w/api.php"

    # Add User-Agent header to be respectful to Wikipedia
    headers = {
        'User-Agent': 'Educational Research Bot 1.0 (contact@example.com)'
    }

    # MediaWiki treats a boolean parameter as true whenever it is present, whatever
    # its value, so sending 'exintro': False returned only the lead section.
    # 'exintro' is left out entirely to get the full article text.
    params = {
        'action': 'query',
        'format': 'json',
        'titles': title,
        'prop': 'extracts',
        'explaintext': True,
        'exsectionformat': 'plain'
    }

    for attempt in range(max_retries):
        more_attempts_left = attempt < max_retries - 1
        try:
            response = requests.get(url, params=params, headers=headers, timeout=15)

            if response.status_code == 200:
                pages = response.json()['query']['pages']
                # A single requested title yields a single entry in `pages`
                page_id = next(iter(pages))

                if page_id == '-1':
                    # Article not found
                    print(f"Article '{title}' not found")
                    return None

                page_data = pages[page_id]
                if 'extract' not in page_data:
                    print(f"No extract found for '{title}'")
                else:
                    content = page_data['extract'].strip()
                    if content:
                        return {
                            'title': title,
                            'content': content,
                            'url': f"https://en.wikipedia.org/wiki/{title}",
                            'fetch_time': datetime.now(),
                            'content_length': len(content)
                        }
                    print(f"Empty content for '{title}'")

                # No usable text came back: wait and retry, or give up
                if not more_attempts_left:
                    return None
                time.sleep(base_delay)

            elif response.status_code == 429:
                if not more_attempts_left:
                    # No point sleeping when no further attempt will be made
                    break
                # Rate limited - exponential backoff with jitter
                wait_time = base_delay * (2 ** attempt) + random.uniform(0, 1)
                print(f"Rate limited (429) for {title}, waiting {wait_time:.1f}s before retry {attempt + 1}/{max_retries}")
                time.sleep(wait_time)

            else:
                print(f"HTTP {response.status_code} error for {title}")
                if not more_attempts_left:
                    return None
                time.sleep(base_delay)

        except requests.exceptions.Timeout:
            print(f"Timeout for {title} (attempt {attempt + 1}/{max_retries})")
            if not more_attempts_left:
                return None
            # Linear back-off for timeouts
            time.sleep(base_delay * (attempt + 1))

        except Exception as e:
            print(f"Error for {title}: {type(e).__name__}: {e}")
            if not more_attempts_left:
                return None
            time.sleep(base_delay)

    # All retries failed
    print(f"Failed to fetch {title} after {max_retries} attempts")
    return None


## Ingest the Articles

In [0]:
## Fetch Wikipedia articles
print("Fetching Wikipedia articles...")
articles_data = []

# dict.fromkeys de-duplicates while keeping first-seen order, so a title listed
# more than once is fetched (and later chunked/embedded) only once.
for topic in dict.fromkeys(WIKIPEDIA_TOPICS):
    print(f"Fetching: {topic}")
    article_data = fetch_wikipedia_article(topic)
    # fetch_wikipedia_article returns None on failure; also guard against empty content
    if article_data and article_data['content']:
        articles_data.append(article_data)
        print(f"✓ Successfully fetched {topic} ({len(article_data['content'])} characters)")
    else:
        print(f"✗ Failed to fetch {topic}")

print(f"\nSuccessfully fetched {len(articles_data)} articles")

## Define a schema for the fetched articles

In [0]:
## Explicit Spark schema for the fetched article records
# (column order mirrors the keys returned by fetch_wikipedia_article; all nullable)
schema = StructType(
    [
        StructField(field_name, field_type, nullable=True)
        for field_name, field_type in [
            ("title", StringType()),
            ("content", StringType()),
            ("url", StringType()),
            ("fetch_time", TimestampType()),
            ("content_length", LongType()),
        ]
    ]
)

## Convert the articles to a pandas DataFrame

In [0]:
# Convert to pandas DataFrame first, then to Spark DataFrame.
# Passing `columns=` selects each dict key by name, in the schema's column order,
# instead of relabelling whatever order pandas inferred. It also still yields the
# expected five columns when no article could be fetched; assigning `.columns`
# afterwards would raise a length-mismatch error in that case.
pandas_df = pd.DataFrame(articles_data, columns=schema.fieldNames())

df = spark.createDataFrame(pandas_df, schema)
df.display()

# Extract/clean the content & split it into manageable chunks

Steps:
1. Define a function to clean Wikipedia text
2. Define a function to split the text content into chunks
3. Apply the functions to create a new DataFrame with the text chunks

In [0]:
## Define a function to clean Wikipedia text
def clean_wikipedia_text(text):
    """
    Clean Wikipedia text by removing special formatting and references.

    Parameters:
        text (str): Raw Wikipedia text content (None or "" is accepted).

    Returns:
        str: Text with citation markers such as "[1]" and "== Heading ==" style
             section headers removed, every whitespace run collapsed to a single
             space, and leading/trailing whitespace trimmed. Returns "" for
             None/empty input.
    """
    # If the input text is None or empty, return an empty string
    if not text:
        return ""

    # Remove citation-style references like [1], [2], etc.
    text = re.sub(r'\[\d+\]', '', text)

    # Remove section headers such as "== Heading ==" BEFORE collapsing whitespace.
    # Restricting the match to a single line stops a stray "==" from pairing with
    # another "==" further down the article and deleting everything in between.
    # Doing it first also means the gap a header leaves is collapsed below.
    text = re.sub(r'={2,}[^\n]*?={2,}', '', text)

    # Replace multiple whitespace characters (spaces, tabs, newlines) with a single space
    text = re.sub(r'\s+', ' ', text)

    # Trim leading and trailing whitespace
    return text.strip()


## Define our Chunking Method

In [0]:
## Define a function to split the text content into chunks
@pandas_udf("array<string>")
def read_as_chunk(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    """
    Iterator pandas UDF: clean each article and split it into sentence-aligned chunks.

    The tokenizer and splitter are built once per invocation of this function,
    before any batch is processed, rather than once per batch.

    Parameters:
        batch_iter: iterator of pandas Series of raw article text (entries may be None/empty).

    Yields:
        pandas Series, one per input batch and of the same length, whose elements are
        lists of chunk strings. Null/empty articles yield an empty list, and chunks of
        100 characters or fewer (after stripping) are dropped.
    """
    # Set llama2 as the global llama_index tokenizer (used to measure chunk size in tokens)
    set_global_tokenizer(
        AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
    )
    # Sentence splitter from llama_index to split on sentences
    splitter = SentenceSplitter(chunk_size=500, chunk_overlap=50)

    def split_article(content):
        """Clean one article and return its chunks longer than 100 characters."""
        if not content:
            return []
        cleaned_content = clean_wikipedia_text(content)
        chunks = splitter.split_text(cleaned_content)
        # Filter out very short chunks
        return [chunk for chunk in chunks if len(chunk.strip()) > 100]

    for batch in batch_iter:
        yield pd.Series([split_article(content) for content in batch])



In [0]:
## Apply the chunking function
# One output row per chunk: run the chunking UDF, explode each article's chunk
# array (carrying title/url along), then drop chunks of 100 characters or fewer.
chunk_arrays = df.withColumn("chunks", read_as_chunk(col("content")))
exploded_chunks = chunk_arrays.select("title", "url", explode("chunks").alias("content"))
df_chunks = exploded_chunks.filter(length("content") > 100)

df_chunks.display()
chunk_count = df_chunks.count()
print(f"Created {chunk_count} text chunks from Wikipedia articles")

# Calculate Embeddings and Store Them in a Delta Table

## Define our get-embedding UDF

In [0]:
from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, FloatType

def get_embedding_udf(batch_size=10):
    """
    Returns a Spark UDF that takes a single text string and returns
    a fixed-length float array (the embedding) by calling a deployed model.
    """
    # Lazily import the MLflow Deployments API and get a client for Databricks
    import mlflow.deployments
    deploy_client = mlflow.deployments.get_deploy_client("databricks")
    
    def embed(text):
        """
        Inner function that:
        1) Handles empty or whitespace-only input by returning a zero-vector.
        2) Calls the deployed embedding endpoint with the text.
        3) Catches errors and falls back to returning a zero-vector.
        """
        try:
            # If text is None, empty, or only whitespace, return a default zero-vector
            if not text or not str(text).strip():
                return [0.0] * 1024
            
            # Call the embedding service via MLflow deployments
            response = deploy_client.predict(
                endpoint="databricks-bge-large-en",   # your deployed endpoint name
                inputs={"input": [str(text)]}         # wrap text in a list
            )
            
            # Extract and return the embedding for the first (and only) input
            return response.data[0]['embedding']
        
        except Exception as e:
            # Log or print the error, then return a zero-vector to keep things flowing
            print(f"Error embedding text: {e}")
            return [0.0] * 1024
    
    # Wrap the `embed` function as a Spark UDF that returns ArrayType(FloatType())
    return udf(embed, ArrayType(FloatType()))


# Invoke the UDF to calculate the embeddings

In [0]:
# 1) Print a log so you know when embedding computation starts
print("Computing embeddings for text chunks...")

# 2) Apply the embedding UDF to every row's "content" and store the result in a new
#    column "embedding". This is lazy: nothing is sent to the endpoint until an action runs.
df_chunk_emd = df_chunks.withColumn(
    "embedding",
    get_embedding_udf()(col("content"))
)

# 3) Preview only a few rows. Every action re-evaluates the UDF (the DataFrame is not
#    cached), so displaying the full frame would call the model endpoint for every
#    displayed row and then again when the table is written below.
EMBEDDING_PREVIEW_ROWS = 5
df_chunk_emd.limit(EMBEDDING_PREVIEW_ROWS).display()

# 4) Calling `.count()` runs the upstream lineage and returns the number of rows
#    that will be embedded, which we log
print(f"Computed embeddings for {df_chunk_emd.count()} text chunks")


# Store the embeddings in a Delta Table

Steps:
1. Create the Delta table schema
2. Save the DataFrame containing the computed embeddings as a Delta table

In [0]:
# Create the Delta table
embedding_table_name = f"{CATALOG_NAME}.{SCHEMA_NAME}.lab_wikipedia_text_embeddings"

# Drop any previous version so re-running the notebook does not append duplicates
drop_table_sql = f"""
DROP TABLE IF EXISTS {embedding_table_name}
"""
spark.sql(drop_table_sql)
print(f"Dropped Delta table: {embedding_table_name}")


# SQL command to create the table. `id` is generated by Delta for rows that don't
# supply it, and Change Data Feed is enabled on the table
# (presumably so it can back a Delta Sync vector search index — confirm).
create_table_sql = f"""
CREATE TABLE  {embedding_table_name} (
  id BIGINT GENERATED BY DEFAULT AS IDENTITY,
  title STRING,
  url STRING,
  content STRING,
  embedding ARRAY<FLOAT>
) TBLPROPERTIES (delta.enableChangeDataFeed = true)
"""

spark.sql(create_table_sql)
print(f"Created Delta table: {embedding_table_name}")

## Save the DataFrame into the table created above (append keeps its identity
## column and table properties instead of replacing the table definition)
df_chunk_emd.write.mode("append").saveAsTable(embedding_table_name)

# Count from the saved table rather than df_chunk_emd: re-counting the DataFrame
# would recompute its uncached lineage, including the chunking UDF.
saved_count = spark.table(embedding_table_name).count()
print(f"Saved {saved_count} records to Delta table")

## Verify the data was saved correctly
verification_df = spark.sql(f"SELECT COUNT(*) as total_records FROM {embedding_table_name}")
verification_df.display()


## Query the Delta Table

In [0]:
# Query the table by the same fully-qualified name it was created with, rather
# than a hard-coded catalog/schema that may not match CATALOG_NAME / SCHEMA_NAME.
spark.table(embedding_table_name).display()