## **AI Search with RefinedWeb Dataset and OLMo 2 Augmentation**

# Table of Contents
- [0. Setup](#0-setup)  
- [1. Data Loading](#1-data-loading)  
- [2. Data Exploration](#2-data-exploration)  
- [3. Data Preprocessing](#3-data-preprocessing)  
    - [3.1 Data Cleaning](#31-data-cleaning)   
    - [3.2 Feature Engineering](#32-feature-engineering)  
- [4. Brand Sentiment Analysis](#4-brand-sentiment-analysis)  
    - [4.1 Lexicon-Based](#41-lexicon-based)  
    - [4.2 Transformer-based](#42-transformer-based)
- [5. Brand-Specific Analysis](#5-brand-specific-analysis)  

# 0. Setup

In [1]:
# Mount Google Drive so the project folder under '/content/drive/My Drive' is reachable.
# NOTE(review): this only works inside a Google Colab runtime; elsewhere it raises
# NotImplementedError (as the recorded output below shows), and every later cell that
# reads from '/content/drive/...' will then fail.
from google.colab import drive
drive.mount('/content/drive')

NotImplementedError: Mounting drive is unsupported in this environment. Use PyDrive instead. See examples at https://colab.research.google.com/notebooks/io.ipynb#scrollTo=7taylj9wpsA2.

In [None]:
import os

# Project folder on the mounted Google Drive.
folder_path = '/content/drive/My Drive/Digitas'

# Sanity check: list the folder's contents to confirm the mount worked
# (os.listdir raises if the folder is not reachable).
files = os.listdir(folder_path)
print(files)

In [None]:
# IPython shell magic: install the project's dependencies from the requirements file on Drive.
!pip install -r "/content/drive/My Drive/Digitas/requirements.txt"

In [None]:
import pandas as pd
import re
import os
import time
import torch
import nltk
import spacy
import glob
import subprocess

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import size, col, udf, pandas_udf, PandasUDFType, arrays_zip, array_contains, substring, length, explode, first, avg, when, monotonically_increasing_id
from pyspark.sql.functions import to_date, dayofmonth, month, year
from pyspark.sql.types import DoubleType, IntegerType, StringType, FloatType, BooleanType, ArrayType, StructType, StructField
from pyspark.ml.feature import Tokenizer
from huggingface_hub import HfApi
# from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from collections import defaultdict
from emoji import demojize
from urllib.parse import urlparse
from nltk.tokenize import word_tokenize
from nltk.sentiment.vader import SentimentIntensityAnalyzer
from datasets import load_dataset

# Fetch the NLTK data packages used in this notebook: WordNet, the Punkt
# tokenizer models and the VADER sentiment lexicon.
for _nltk_package in ('wordnet', 'punkt', 'vader_lexicon'):
    nltk.download(_nltk_package)

# 1. Data Loading


## 1.1 Generating Paths Files
This section generates `data/refinedweb/refinedweb_paths.txt`, which lists the download URLs of the RefinedWeb Parquet files. If the file already exists, it is checked against the current repository listing and rewritten only when URLs are missing.

In [None]:
def generate_paths_file(dataset_id, output_file, directory_prefix=None, api=None):
    """Write the download URLs of a Hugging Face dataset's Parquet files to ``output_file``.

    Every file in the dataset repository ``dataset_id`` is listed. Only ``.parquet``
    files are kept, optionally restricted to repo paths starting with
    ``directory_prefix``, and one ``resolve/main`` download URL is written per line.

    If ``output_file`` already exists and already contains every URL found in the
    repository, it is left untouched. Otherwise it is overwritten with the fresh list.

    Args:
        dataset_id: Hugging Face dataset repo id, e.g. ``"tiiuae/falcon-refinedweb"``.
        output_file: Path of the text file to write. A bare file name is allowed.
        directory_prefix: Optional repo-path prefix that selects a subset of files.
        api: Optional object exposing ``list_repo_files(repo_id=..., repo_type=...)``.
            Defaults to a new ``HfApi()``.

    Returns:
        The list of Parquet URLs found in the repository (previously ``None``).
    """
    if api is None:
        api = HfApi()

    # List all files in the dataset repository and keep the matching Parquet ones.
    files = api.list_repo_files(repo_id=dataset_id, repo_type="dataset")
    parquet_urls = [
        f"https://huggingface.co/datasets/{dataset_id}/resolve/main/{path}"
        for path in files
        if path.endswith(".parquet")
        and (directory_prefix is None or path.startswith(directory_prefix))
    ]

    if not parquet_urls:
        # An empty list is trivially a subset of any existing file, so without this
        # guard a stale file would be reported as "valid".
        print(f"No Parquet files found in {dataset_id} (prefix: {directory_prefix}).")
        return parquet_urls

    # os.path.dirname("paths.txt") is "" and os.makedirs("") raises, so only create real dirs.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    # If the file exists, verify it already lists every current URL.
    if os.path.exists(output_file):
        print(f"{output_file} already exists. Verifying contents...")
        with open(output_file, "r", encoding="utf-8") as fh:
            existing_urls = {line.strip() for line in fh if line.strip()}
        if set(parquet_urls) <= existing_urls:
            print(f"{output_file} is valid with {len(existing_urls)} URLs.")
            return parquet_urls
        print(f"Updating {output_file} with new URLs...")

    # Save URLs to file, one per line.
    with open(output_file, "w", encoding="utf-8") as fh:
        fh.writelines(url + "\n" for url in parquet_urls)
    print(f"Saved {len(parquet_urls)} URLs to {output_file}")
    return parquet_urls


# Generate paths for RefinedWeb (only the Parquet files under the repo's "data/" folder).
# This queries the Hugging Face Hub, so it needs network access.
generate_paths_file("tiiuae/falcon-refinedweb", "data/refinedweb/refinedweb_paths.txt", directory_prefix="data/")

## 1.2 Download Dataset
The following code runs the `download_parquet.sh` script to download the RefinedWeb Parquet files listed in the paths file. The files are saved to `data/refinedweb/`.

In [None]:
def download_dataset(dataset_name, paths_file, timeout=3600, script="download_parquet.sh"):
    """Run the Parquet download script for ``dataset_name`` and report whether it succeeded.

    Executes ``bash <script> <dataset_name> <absolute paths_file>``. The script's stdout
    and stderr go to ``data/<dataset_name>/download.log``, and the log is echoed when the
    script finishes.

    Args:
        dataset_name: Dataset folder name, e.g. ``"refinedweb"``.
        paths_file: Text file of URLs to download (one per line).
        timeout: Seconds before the script is killed (default: one hour).
        script: Path of the download shell script.

    Returns:
        True if the script exited with status 0. False if the paths file is missing or
        empty, the script could not be started, it failed, or it timed out.
    """
    paths_file = os.path.abspath(paths_file)
    if not os.path.exists(paths_file):
        print(f"Error: {paths_file} does not exist. Skipping download for {dataset_name}.")
        return False
    if os.path.getsize(paths_file) == 0:
        print(f"Error: {paths_file} is empty. Skipping download for {dataset_name}.")
        return False

    print(f"Downloading {dataset_name} dataset...")
    log_file = os.path.join("data", dataset_name, "download.log")
    try:
        os.makedirs(os.path.dirname(log_file), exist_ok=True)
        with open(log_file, "w") as log:
            # subprocess.run kills the script when the timeout expires. Popen.wait(timeout=...)
            # only raises, which would leave the download running in the background.
            completed = subprocess.run(
                ["bash", script, dataset_name, paths_file],
                stdout=log,
                stderr=subprocess.STDOUT,
                timeout=timeout,
            )
    except subprocess.TimeoutExpired:
        print(f"Error running download script for {dataset_name}: timed out after {timeout} seconds; script terminated.")
        return False
    except (OSError, subprocess.SubprocessError) as e:
        print(f"Error running download script for {dataset_name}: {e}")
        return False

    with open(log_file, "r", errors="replace") as log:
        print(log.read())
    return completed.returncode == 0

# # Downloads RefinedWeb parquet files
# download_dataset("refinedweb", "data/refinedweb/refinedweb_paths.txt")

Next, we verify that the downloaded Parquet files are present and readable:

In [None]:
def verify_parquet_files(directory, spark_session=None):
    """Report how many Parquet files ``directory`` holds and sanity-check the first one.

    The first file (in sorted order) is read with Spark, and its row count and the number
    of nulls in every column are printed. Read errors are printed, not raised.

    Args:
        directory: Folder containing ``*.parquet`` files.
        spark_session: SparkSession to read with. Defaults to the currently active session.
            If there is none, the Spark check is skipped with a message. This avoids
            depending on a global ``spark`` that may not have been created yet.
    """
    parquet_files = sorted(glob.glob(os.path.join(directory, "*.parquet")))
    print(f"Number of Parquet files in {directory}: {len(parquet_files)}")
    if not parquet_files:
        print(f"No Parquet files found in {directory}")
        return

    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F

    session = spark_session if spark_session is not None else SparkSession.getActiveSession()
    if session is None:
        print("No active SparkSession; create one before verifying Parquet files.")
        return

    sample_path = parquet_files[0]
    try:
        df_sample = session.read.parquet(sample_path)
        print(f"Number of rows in {os.path.basename(sample_path)}: {df_sample.count()}")

        # A single aggregation counts the nulls in every column, instead of one Spark job per column.
        null_counts = df_sample.select(
            [F.count(F.when(F.col(c).isNull(), 1)).alias(c) for c in df_sample.columns]
        ).first().asDict()
        print("\nNull value counts per column:")
        for column, n_null in null_counts.items():
            print(f"{column}: {n_null} nulls")
    except Exception as e:  # report any read/analysis failure without aborting the notebook
        print(f"Error reading {sample_path}: {e}")

# NOTE(review): this cell comes before the SparkSession cell in section 1.3. Run that
# cell first so the sample Parquet file can actually be read with Spark.
verify_parquet_files("data/refinedweb")

## 1.3 Load and Filter Documents with Brand Mentions

PySpark is used to process the large-scale dataset. The Spark session is initialised below:

In [None]:
# Single-machine Spark session; local[*] runs one worker thread per available core.
# NOTE(review): in local mode the executors live inside the driver JVM, so
# spark.executor.memory probably has no effect here. Confirm before tuning it.
_spark_settings = {
    "spark.driver.memory": "12g",
    "spark.executor.memory": "6g",
    "spark.sql.shuffle.partitions": "2",
    "spark.memory.fraction": "0.8",
    "spark.memory.storageFraction": "0.2",
}

_builder = SparkSession.builder.appName("AI Search Pipeline").master("local[*]")
for _setting, _value in _spark_settings.items():
    _builder = _builder.config(_setting, _value)
spark = _builder.getOrCreate()

# Keep the notebook output quiet: only errors are logged.
spark.sparkContext.setLogLevel("ERROR")

# For more detailed logging
# sc = spark.sparkContext
# sc.setLogLevel("INFO")

In [None]:
import duckdb
import os

def load_and_filter_data(input_dir, output_file, brands=None, require_all=True):
    """Filter the RefinedWeb Parquet shards down to documents that mention the target brands.

    An in-memory DuckDB connection scans every ``*.parquet`` file in ``input_dir`` and keeps
    rows whose ``content`` contains the brand names (case-insensitive substring match). The
    result is written to the single Parquet file ``output_file``.

    Args:
        input_dir: Directory holding the downloaded Parquet shards.
        output_file: Destination Parquet file. Parent directories are created as needed.
        brands: Brand names to search for. Defaults to HSBC, Barclays, Lloyds and NatWest.
        require_all: If True (default), every brand must be mentioned. If False, any one is enough.

    Raises:
        ValueError: If ``brands`` is an empty sequence.
    """
    if brands is None:
        brands = ["HSBC", "Barclays", "Lloyds", "NatWest"]
    if not brands:
        raise ValueError("brands must contain at least one brand name")

    # Create output directory if needed (dirname is "" for a bare file name).
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    def _sql_string(value):
        # Quote a value as a SQL string literal, doubling embedded single quotes so that
        # paths or names containing "'" cannot break (or alter) the statement.
        return "'" + str(value).replace("'", "''") + "'"

    joiner = " AND " if require_all else " OR "
    brand_conditions = joiner.join(
        f"LOWER(content) LIKE {_sql_string('%' + brand.lower() + '%')}" for brand in brands
    )
    source = _sql_string(os.path.join(input_dir, "*.parquet"))

    con = duckdb.connect()
    try:
        con.execute(f"""
            CREATE OR REPLACE TABLE filtered_brands AS
            SELECT *
            FROM {source}
            WHERE {brand_conditions}
        """)

        # Export to Parquet
        con.execute(f"COPY filtered_brands TO {_sql_string(output_file)} (FORMAT PARQUET)")
        print(f"Filtered + Exported to {output_file}")
    finally:
        # Release the connection even if the scan or export fails.
        con.close()

#load_and_filter_data("data/refinedweb", "data/filtered_data/brands_articles.parquet")

In [None]:
# Load the brand-filtered articles from Drive (presumably the output of
# load_and_filter_data above, copied into the Drive project folder).
df = spark.read.parquet("/content/drive/My Drive/Digitas/data/filtered_data/brands_articles.parquet")
df.show()

In [None]:
# count() runs a full Spark job over the loaded data.
print(f"Number of documents mentioning all 4 brands together: {df.count()}")

# 2. Data Pre-processing

This segment cleans and transforms the raw dataset to make it suitable for sentiment analysis, removing noise and extracting useful features.

In [None]:
# Filtering for non-null/non-empty text
# Keeps rows whose `content` is neither null nor exactly "". Whitespace-only text still passes.
df = df.filter(col("content").isNotNull() & (col("content") != ""))

In [None]:
# Dropping irrelevant columns
# These Common Crawl bookkeeping columns and the image URLs are not referenced later in this notebook.
df = df.drop("dump", "segment", "image_urls")

df.show()

## 2.1 Filter Valid URLs

In [None]:
import requests
from pyspark.sql.types import BooleanType
import time

# Function to check if a URL is valid
def is_url_valid(url, timeout=5):
    """Return True if ``url`` answers with HTTP 200 after following redirects.

    A HEAD request is tried first so no body is downloaded. Servers that do not implement
    HEAD (405 Method Not Allowed / 501 Not Implemented) are retried with a streamed GET
    whose body is never read, so live pages on such servers are no longer dropped.

    Empty/None URLs and any ``requests`` error (connection failure, timeout, malformed URL)
    count as invalid.

    Args:
        url: The URL to check.
        timeout: Per-request timeout in seconds (default 5).
    """
    if not url:
        return False
    try:
        # Send a HEAD request to minimize data transfer
        response = requests.head(url, timeout=timeout, allow_redirects=True)
        if response.status_code in (405, 501):
            # HEAD is optional for servers; fall back to GET but only fetch the headers.
            with requests.get(url, timeout=timeout, allow_redirects=True, stream=True) as get_response:
                return get_response.status_code == 200
        return response.status_code == 200
    except requests.RequestException:
        # Handle connection errors, timeouts, invalid URLs, etc.
        return False

# Network checks are not deterministic (sites go up and down). Marking the UDF
# nondeterministic stops Spark's optimizer from duplicating or reordering calls to it,
# which would otherwise fire extra HTTP requests and could give inconsistent results.
is_url_valid_udf = udf(is_url_valid, BooleanType()).asNondeterministic()

df = df.withColumn("is_url_valid", is_url_valid_udf(col("url")))
# df.select("url", "is_url_valid").show(truncate=False)

In [None]:
# Filter rows where URL is valid
# Rows whose check returned False or null are dropped. NOTE(review): Spark is lazy, so the
# HTTP checks actually run when a later action (count/show/write) evaluates df.
df = df.filter(col("is_url_valid") == True)

# print(f"Number of documents with valid URLs mentioning all 4 brands: {df.count()}")
# df.select("url", "is_url_valid").show(truncate=False)

In [None]:
# The helper flag is no longer needed once the filter is in the plan.
df = df.drop("is_url_valid")

In [None]:
# # Write the DataFrame as Parquet
# output_path = "data/filtered_data/valid_articles.parquet"

# df.write.mode("overwrite").parquet(output_path)

In [None]:
# Reload the URL-validated articles from Drive (presumably a copy of the checkpoint written by
# the commented-out cell above). This replaces the in-session df built by the URL-check cells.
df = spark.read.parquet("/content/drive/My Drive/Digitas/data/filtered_data/valid_articles.parquet")

## 2.2 Deduplication by Date

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.types import StringType
import re
from datetime import datetime
from dateutil import parser

# URL date patterns, tried in order; the first one that yields a real date wins.
url_date_patterns = [
    r'/(\d{4})/([a-z]{3})/(\d{2})/',
    r'/(\d{4})/(\d{2})/(\d{2})/',
    r'/(\d{4})-(\d{2})-(\d{2})/',
    r'/(\d{4})\.(\d{2})\.(\d{2})/',
    r'/(\d{4})_(\d{2})_(\d{2})/',
    r'(\d{4})/(\d{2})(\d{2})/',
    r'/(\d{4})/(\d{2})/',
    r'/(\d{4})-(\d{2})/',
    r'(\d{4})[-_\.](\d{2})[-_\.](\d{2})',
    r'post[-_]?(\d{4})[-_](\d{2})[-_](\d{2})',
    r'(\d{8})',
]

month_abbrev_map = {
    "jan": "01", "feb": "02", "mar": "03", "apr": "04",
    "may": "05", "jun": "06", "jul": "07", "aug": "08",
    "sep": "09", "oct": "10", "nov": "11", "dec": "12"
}

# TEXT patterns (all matches are parsed; the earliest valid date is kept)
text_date_patterns = [
    r'(?:Published|Posted|Updated|Created|First published|Last updated)[:\s]*([A-Za-z]{3,9}[\s\-.,]?\d{1,2}(?:st|nd|rd|th)?[\s,]+(?:\d{4}))',
    r'(?:Published|Posted|Updated|Date|Created)[:\s]*([\d]{1,2}[\s\-/.][A-Za-z]{3,9}[\s\-/,]+[\d]{4})',
    r'(?:Published|Posted|Updated|Date)[:\s]*([\d]{4}[-/\.][\d]{1,2}[-/\.][\d]{1,2})',
    r'([A-Za-z]{3,9}\s\d{4})',
    r'(\d{4}/\d{2}/\d{2})',
    r'(\d{2}[-/\.]\d{2}[-/\.]\d{4})',
    r'(\d{4}[-/\.]\d{2})',
    r'/(\d{4})/(\d{1,2})/(\d{1,2})/',
    r'/(\d{4})/(\d{1,2})/'
]

# dateutil fills missing components from `default`. Without it they come from *today*, so
# "March 2020" parsed to a different day on every run. Day/month 1 keeps results
# deterministic, and year 1 (outside the accepted range) rejects matches with no real year.
_DATE_PARSE_DEFAULT = datetime(1, 1, 1)


def _valid_ymd(year, month, day, max_year):
    """Return 'YYYY-MM-DD' if the parts form a real calendar date in [1900, max_year], else None."""
    try:
        y, m, d = int(year), int(month), int(day)
        if not 1900 <= y <= max_year:
            return None
        # The datetime constructor rejects impossible dates such as 2021-02-30 or month 13.
        datetime(y, m, d)
    except (TypeError, ValueError):
        return None
    return f"{y:04d}-{m:02d}-{d:02d}"


# Combining extraction from 'url', revert to 'content' if date not found in url
@F.udf(StringType())
def extract_combined_date_udf(url, text):
    """Spark UDF returning a best-guess publication date ('YYYY-MM-DD') or None.

    The URL is tried first, using `url_date_patterns` (month-only matches use day 01). If
    that fails, every `text_date_patterns` match in the text is fuzzy-parsed with
    dateutil, and the earliest date between 1900 and next year is returned.
    """
    max_year = datetime.now().year + 1

    def try_url_date(url_str):
        if not url_str or not isinstance(url_str, str):
            return None
        if not re.search(r'https?://', url_str):
            return None
        for pattern in url_date_patterns:
            match = re.search(pattern, url_str, flags=re.IGNORECASE)
            if not match:
                continue
            parts = match.groups()
            if len(parts) == 3:
                year, month, day = parts
                if month.isalpha():
                    month = month_abbrev_map.get(month.lower())
                    if not month:
                        continue
            elif len(parts) == 2:
                year, month = parts
                day = "01"
            elif len(parts) == 1:
                val = parts[0]
                if len(val) == 8:
                    year, month, day = val[:4], val[4:6], val[6:]
                elif len(val) == 4:
                    year, month, day = val, "01", "01"
                else:
                    continue
            else:
                continue
            date_str = _valid_ymd(year, month, day, max_year)
            if date_str:
                return date_str
        return None

    def try_text_date(text_str):
        if not text_str or not isinstance(text_str, str):
            return None
        text_str = re.sub(r'\s+', ' ', text_str).strip()
        probable_dates = []
        for pattern in text_date_patterns:
            for match in re.findall(pattern, text_str):
                raw = ' '.join(match) if isinstance(match, tuple) else match
                try:
                    parsed = parser.parse(raw, fuzzy=True, default=_DATE_PARSE_DEFAULT)
                except (ValueError, OverflowError, TypeError):
                    # Unparseable fragment; dateutil's ParserError is a ValueError.
                    continue
                if 1900 <= parsed.year <= max_year:
                    probable_dates.append(parsed.date().isoformat())
        # ISO date strings sort chronologically, so min() is the earliest date.
        return min(probable_dates) if probable_dates else None

    # Try extracting from URL first
    date_from_url = try_url_date(url)
    if date_from_url:
        return date_from_url

    # If failed, try from text
    return try_text_date(text)

df = df.withColumn("published_date", extract_combined_date_udf(F.col("url"), F.col("content")))

In [None]:
# Spark df
# df.select("content","url", "timestamp", "published_date").show(truncate=False)

# # Pandas df for readability
# Only the 20 rows that are displayed are collected. Pulling every full article body (and
# running the date UDF on every row) just to show a handful is slow and can exhaust driver memory.
pdf = df.select("content", "url", "timestamp", "published_date").limit(20).toPandas()
# Truncate content
pdf["content"] = pdf["content"].str[:50] + "..."

# Display full URL
with pd.option_context("display.max_colwidth", None):
    display(pdf.head(20))

In [None]:
# Deduplicate by extracted publication date: keep one document per date.
# dropDuplicates treats nulls as equal, so every row whose date could not be extracted
# would collapse into a single row. Undated rows are instead deduplicated on exact content.
_dated_rows = df.filter(col("published_date").isNotNull()).dropDuplicates(["published_date"])
_undated_rows = df.filter(col("published_date").isNull()).dropDuplicates(["content"])
df = _dated_rows.unionByName(_undated_rows)

df.show(50)
print(f"Number of documents mentioning all 4 brands together after deduplication: {df.count()}")

## 2.3 Data Cleaning

In [None]:
# Text normalisation (no tokenisation happens here): flatten line breaks, lowercase,
# drop punctuation and collapse runs of whitespace.
def clean_text(text):
    """Return a lowercased, punctuation-free, single-spaced copy of ``text``.

    Non-string input (e.g. None) yields an empty string.
    """
    if not isinstance(text, str):
        return ""
    single_line = re.sub(r'[\n\r]', ' ', text)
    without_punct = re.sub(r'[^\w\s]', '', single_line.lower())
    return re.sub(r'\s+', ' ', without_punct).strip()

# Apply clean_text to every row as a Spark UDF; the result goes in a new `clean_text` column.
clean_udf = udf(clean_text, StringType())
df = df.withColumn("clean_text", clean_udf(col("content")))

# 3. Extracting Topics and Implicit Rankings

In [None]:
# Convert Spark DataFrame to Pandas DataFrame
# toPandas() collects every row onto the driver, so the filtered dataset must fit in memory.
df = df.toPandas()

# Stop Spark session to free up resources
# NOTE(review): later cells (section 3.1 onward) still call df.withColumn / spark.read, which
# need a Spark DataFrame and a live session; as written they will fail after this point.
spark.stop()

In [None]:
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
# Define topics
# Candidate labels for the zero-shot topic classifier; "Other" is also its fallback label.
topics = ["Sustainability", "Financial Resilience and Performance", "Technological Innovation", "Customer Service", "Regulatory Compliance", "Governance and Leadership", "Other"]

# Replace with actual brand names from your analysis
brands = ["HSBC", "Barclays", "Lloyds", "NatWest"]

In [None]:
from typing import List, Tuple

import torch
from transformers import pipeline

# Initialize pipeline with device_map
classifier = pipeline(
    "zero-shot-classification",
    model="facebook/bart-large-mnli",
    device_map="auto"  # Automatically offloads to CPU if needed
)

# Sentiment checkpoint for extract_ranking. The plain "distilbert-base-uncased" model has no
# trained classification head, so its "sentiment" scores are meaningless. Use the SST-2
# fine-tuned variant, whose labels are POSITIVE / NEGATIVE.
SENTIMENT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
_sentiment_pipeline = None


def _get_sentiment_pipeline():
    """Build the sentiment pipeline on first use and reuse it (loading it per row is very slow)."""
    global _sentiment_pipeline
    if _sentiment_pipeline is None:
        _sentiment_pipeline = pipeline("sentiment-analysis", model=SENTIMENT_MODEL, device_map="auto")
    return _sentiment_pipeline


def categorize_topic(text: str) -> str:
    """Return the top zero-shot label from ``topics`` for the first 1000 characters of ``text``.

    Empty or non-string input is labelled "Other".
    """
    if not text or not isinstance(text, str):
        return "Other"
    text_preview = text[:1000]
    # Device placement is left to the pipeline (device_map="auto"). Moving classifier.model by
    # hand can leave inputs and weights on different devices, and
    # torch.cuda.max_memory_allocated() is the peak usage so far, not the GPU's capacity.
    result = classifier(text_preview, topics, multi_label=False)
    print(f"Text: {text_preview[:200]}... -> Topic: {result['labels'][0]}, Scores: {result['scores']}")
    return result['labels'][0]


def _positive_probability(prediction) -> float:
    """Map one sentiment prediction to P(positive) in [0, 1], whatever label was predicted."""
    score = float(prediction['score'])
    return score if str(prediction['label']).upper() == 'POSITIVE' else 1.0 - score


def extract_ranking(text: str) -> List[Tuple[str, int, float]]:
    """Rank ``brands`` by the sentiment of the sentences that mention them.

    ``text`` is split naively on '. '. For each brand, the sentences containing its name
    (case-insensitive, first 512 characters each) are scored, and the brand's score is the mean
    probability that those sentences are positive. Unmentioned brands, or brands whose
    scoring fails, get 0.0.

    Returns:
        ``(brand, rank, score)`` tuples, highest score first, with ranks starting at 1.
        Empty or non-string input returns [].
    """
    if not text or not isinstance(text, str):
        return []
    sentiment_pipeline = _get_sentiment_pipeline()
    sentences = text.split('. ')
    rankings = {brand: 0.0 for brand in brands}

    for brand in brands:
        brand_lower = brand.lower()
        mentions = [s[:512] for s in sentences if brand_lower in s.lower()]
        if not mentions:
            continue
        try:
            predictions = sentiment_pipeline(mentions)
        except Exception as e:  # keep the row usable if one batch fails (e.g. out of memory)
            print(f"Error processing sentiment for {brand}: {e}")
            continue
        # Using the raw 'score' would rank a confidently *negative* sentence as highly as a
        # positive one, so convert each prediction to P(positive) first.
        scores = [_positive_probability(p) for p in predictions]
        rankings[brand] = sum(scores) / len(scores)

    ranked_brands = sorted(rankings.items(), key=lambda x: x[1], reverse=True)
    return [(brand, idx + 1, float(score)) for idx, (brand, score) in enumerate(ranked_brands)]



In [None]:
# Apply topic categorization
# Runs the zero-shot classifier once per row on the raw `content`.
df['topic'] = df['content'].apply(categorize_topic)
df.head(40)

In [None]:
# Apply ranking extraction
# Each row gets a list of (brand, rank, score) tuples from extract_ranking.
df['brand_rankings'] = df['content'].apply(extract_ranking)

In [None]:
df.head(40)

In [None]:
# Display results
df[['url', 'topic', 'brand_rankings']].head()

In [None]:
def aggregate_rankings(df):
    """Average per-document brand scores by topic and rank brands within each topic.

    Args:
        df: pandas DataFrame with a ``topic`` column and a ``brand_rankings`` column holding
            lists of ``(brand, rank, score)`` triples. Malformed or missing entries are skipped.

    Returns:
        DataFrame with columns ``topic, brand, rank, score``. Here ``score`` is the mean score
        per (topic, brand), and ``rank`` is 1 for the highest mean within a topic (ties share
        the minimum rank). Rows are sorted by topic, then rank. Empty input yields an empty frame.
    """
    # One row per (document, brand) triple; ignore_index gives a clean RangeIndex for masking.
    exploded = df[['topic', 'brand_rankings']].explode('brand_rankings', ignore_index=True)
    is_triple = exploded['brand_rankings'].map(
        lambda x: isinstance(x, (list, tuple)) and len(x) == 3
    ).astype(bool)
    records = exploded.loc[is_triple]

    # Build the long table directly from the triples. Unlike json_normalize followed by a
    # column assignment, this also works when there are no triples at all.
    long_df = pd.DataFrame(records['brand_rankings'].tolist(), columns=['brand', 'rank', 'score'])
    long_df.insert(0, 'topic', records['topic'].tolist())
    long_df['score'] = long_df['score'].astype(float)

    # Group by topic and brand, compute mean score
    summary = long_df.groupby(['topic', 'brand']).agg({'score': 'mean'}).reset_index()

    # Assign ranks based on mean score in descending order (highest score = 1)
    summary['rank'] = summary.groupby('topic')['score'].rank(ascending=False, method='min').astype(int)

    # Sort by topic and rank
    summary = summary.sort_values(['topic', 'rank'])

    return summary[['topic', 'brand', 'rank', 'score']]

# Per-topic brand leaderboard built from the row-level `brand_rankings`.
df_aggregated = aggregate_rankings(df)
print(df_aggregated)

# 4. Baseline LLM Brand Ranking

# 5. LLM Brand Ranking with RAG Context

RefinedWeb columns are explained below:

| Column Name | Description |
|-------------|-------------|
| `content` | The main textual content of the document record (e.g., the body of a document, article, or code snippet). This will be used as the primary field for training language model and analysis in our study. |
| `url` | The URL of the web page or resource from which the content was sourced. |
| `timestamp` | The date and time when the web page was crawled or the data was extracted from the source (e.g., Common Crawl). |
| `dump` | This refers to the specific Common Crawl (CC) dump from which the data was sourced. CC releases monthly dumps (e.g., CC-MAIN-2023-06), allowing users to trace the data back to its original crawl.|
| `segment` | Identifies the segment or subset of the Common Crawl dump from which the record originates.|
| `image_urls` | A list of URLs pointing to images found on the web page.|

## 3.1 Data Cleaning


In [None]:
# Text normalisation for sentiment analysis (redefines the earlier clean_text).
def clean_text(text):
    """Return a normalised copy of ``text`` in which emojis are kept as ``:name:`` tokens.

    Steps: turn existing colons into spaces, convert emojis to their ``:name:`` aliases,
    flatten line breaks, lowercase, drop punctuation (keeping ``:`` and ``_`` so emoji tokens
    survive) and digits, then collapse whitespace. Non-string input yields an empty string.
    """
    if not isinstance(text, str):
        return ""

    # Colons from the original text become spaces, so the only colons left are emoji delimiters.
    text = text.replace(":", " ")
    # Emojis must be converted *before* the punctuation filter: they are not word characters,
    # so [^\w\s] would otherwise delete them before demojize ever sees them.
    # The padded delimiters keep adjacent emojis/words apart (e.g. " :thumbs_up: ").
    text = demojize(text, delimiters=(" :", ": "))
    text = re.sub(r'[\n\r]', ' ', text)     # removes newlines and carriage returns
    text = re.sub(r'[^\w\s:]', '', text.lower())  # removes punctuation, keeps emoji tokens, lowercases
    text = re.sub(r'\d+', '', text)   # removes digits
    text = re.sub(r'\s+', ' ', text).strip()  # collapses whitespace
    return text

# NOTE(review): at this point in the notebook `df` is a pandas DataFrame (see the toPandas()
# call in section 3) and the Spark session has been stopped, so this Spark UDF call will fail
# unless the Spark pipeline is re-created first.
clean_udf = udf(clean_text, StringType())
df = df.withColumn("clean_text", clean_udf(col("content")))


## 3.2 Feature Engineering

The aim of this part is to extract additional information and columns from the data to enable more detailed sentiment analysis, such as brand mentions and content types.

### Extraction of Brand Mentions: *brand_mention* and *mention_count*

Our analysis focuses on the following **4 brands**: HSBC, Lloyds, Barclays, and Monzo.

In [None]:
# UK banks
# Lower-case brand names, matched against the lower-cased `clean_text`.
BRANDS = ["barclays", "lloyds", "hsbc", "monzo"]

In [None]:
# Banking-related context terms to confirm brand relevance
BANKING_CONTEXT = [
    "finance", "financial", "bank", "banking", "account", "savings", "current", "mortgage", "loan", "credit", "debit", "card",
    "app", "mobile", "online", "branch", "atm", "transfer", "fees", "overdraft", "service", "support"
]

# Negative context terms to exclude false positives
# Keyed by brand. In the (commented-out) extractor below, a document containing any listed
# phrase is not counted for that brand.
# NOTE(review): "revolut" has an entry here but is not in BRANDS, so it is never consulted.
NEGATIVE_CONTEXT = {
    "revolut": ["revolution", "revolutionary", "national revolution"],
    "barclays": ["barclays center", "barclays arena"],
    "lloyds": ["lloyds of london"],
    "hsbc": [],
    "monzo": []
}

In [None]:
# def extract_brands_and_counts(text):
#     if not isinstance(text, str):
#         return [], []
#     text_lower = text.lower()
#     tokens = word_tokenize(text_lower)

#     brands_found = []
#     counts = []

#     for brand in BRANDS:
#         # Initialize count
#         brand_count = 0

#         # Check for brand in tokens with word boundaries
#         brand_pattern = r'\b' + re.escape(brand) + r'\b'
#         matches = re.findall(brand_pattern, text_lower)
#         brand_count += len(matches)

#         # Validate with banking context (at least one banking term nearby)
#         has_banking_context = False
#         for context in BANKING_CONTEXT:
#             if context in text_lower:
#                 has_banking_context = True
#                 break

#         # Check for negative context to exclude false positives
#         has_negative_context = False
#         for negative_term in NEGATIVE_CONTEXT.get(brand, []):
#             if negative_term in text_lower:
#                 has_negative_context = True
#                 break

#         # Only include brand if it has banking context and no negative context
#         if brand_count > 0 and has_banking_context and not has_negative_context:
#             brands_found.append(brand)
#             counts.append(brand_count)

#     return brands_found, counts

# @udf(ArrayType(StringType()))
# def extract_brands(text):
#     brands, _ = extract_brands_and_counts(text)
#     return brands

# @udf(ArrayType(IntegerType()))
# def extract_mention_counts(text):
#     _, counts = extract_brands_and_counts(text)
#     return counts

# df = df.withColumn("brand_name", extract_brands(col("clean_text")))
# df = df.withColumn("mention_count", extract_mention_counts(col("clean_text")))

# # Filter rows with at least one valid brand mention
# df = df.filter(col("brand_name").isNotNull() & (col("brand_name").getItem(0).isNotNull()))
# print(f"Number of rows with brand mentions: {df.count()}")
# df.show()

Save Spark Dataframe with brand mentions to Parquet files as a checkpoint:

In [None]:
# df.write.mode("overwrite").parquet("data/temp/olmo_brand_mentions")

In [None]:
# Reload the brand-mention checkpoint (the directory written by the commented-out cell above).
brand_mentions_dir = "data/temp/olmo_brand_mentions"
df = spark.read.parquet(f"{brand_mentions_dir}/*.parquet")
print(f"Number of rows with brand mentions: {df.count()}")
df.show()

### Brand Mentions by Brand

In [None]:
# Viewing mentions for each brand
# For each brand, keep the rows whose `brand_name` array contains it. Each iteration
# runs two Spark jobs (show and count).
for brand in BRANDS:
    print(f"\n=== Documents mentioning '{brand}' ===")
    brand_df = df.filter(array_contains(col("brand_name"), brand))
    brand_df.select("clean_text", "brand_name", "mention_count").show()
    print(f"Total number of documents in RefinedWeb dataset mentioning '{brand}': {brand_df.count()}")

### Classification of Brand-related Content: *content_type*

Content types help tailor the sentiment method to the source (e.g. VADER for user-generated content, FinBERT for news).

In [None]:
def classify_content(url, clean_text):
    """Assign a coarse content-type label to a document from its URL and cleaned text.

    Rules are checked in this order, and the first match wins: user_generated, news_article,
    customer_review, regulatory_document, advertising_content, forum_post, owned_media,
    faq_knowledge_base, otherwise "miscellaneous". forum_post is checked before owned_media so
    that specific forum sites (e.g. thestudentroom.co.uk) are not swallowed by the generic
    "co.uk" rule.

    Matching:
    * Entries that look like a domain (contain "." but no "/", e.g. "x.com", "ft.com",
      "gov.uk") must equal the URL's host name or be a parent domain of it. So "ft.com" no
      longer matches "microsoft.com" and "x.com" no longer matches "fox.com".
    * Advertising terms must be whole URL tokens ("ad", "ads") or prefix a token
      ("promo" -> "promotions"), so "ad" no longer matches "download" or "showthread".
    * Other entries are plain substrings of the lower-cased URL / text, as before.

    Non-string ``url`` or ``clean_text`` is treated as an empty string.
    """
    url = url.lower() if isinstance(url, str) else ""
    text = clean_text.lower() if isinstance(clean_text, str) else ""

    try:
        # Prefix "//" for scheme-less URLs so urlparse still recognises the host.
        host = urlparse(url if "://" in url else "//" + url).hostname or ""
    except ValueError:  # e.g. malformed IPv6 literal; must not crash the Spark task
        host = ""
    url_tokens = set(re.split(r'[^a-z0-9]+', url))

    def url_has(term):
        if "." in term and "/" not in term:
            return host == term or host.endswith("." + term)
        return term in url

    def text_has(term):
        return term in text

    def url_has_ad_term(term):
        # Very short terms must be an exact token; longer ones may prefix a token.
        if len(term) <= 3:
            return term in url_tokens
        return any(token.startswith(term) for token in url_tokens)

    # Social media or blogs
    user_gen_domains = ["reddit", "twitter", "x.com", "facebook", "linkedin", "instagram", "tiktok", "pinterest", "forum", "discuss", "community", "medium", "wordpress", "blogger", "tumblr", "substack", "blog"]
    if any(url_has(domain) for domain in user_gen_domains):
        return "user_generated"

    # News article: Reputable news sources or news-related keywords
    news_domains = ["bbc", "guardian", "telegraph", "ft.com", "reuters", "bloomberg", "cnn", "nytimes", "independent", "dailymail", "sky.com", "news", "times"]
    news_keywords = ["breaking news"]
    if any(url_has(domain) for domain in news_domains) or any(text_has(keyword) for keyword in news_keywords):
        return "news_article"

    # Customer review: Review platforms or review-related keywords
    review_keywords = ["trustpilot", "feefo", "reviews", "yelp", "google.com/reviews"]
    if any(url_has(keyword) or text_has(keyword) for keyword in review_keywords):
        return "customer_review"

    # Regulatory document: Official or compliance-related sources or keywords
    regulatory_keywords = ["fca.org.uk", "bankofengland", "gov.uk"]
    if any(url_has(keyword) or text_has(keyword) for keyword in regulatory_keywords):
        return "regulatory_document"

    # Advertising content: Promotional keywords in the URL
    advertising_keywords = ["ads", "campaign", "promo", "sponsor", "advert", "promotion", "ad"]
    if any(url_has_ad_term(term) for term in advertising_keywords):
        return "advertising_content"

    # Forum post: Specific forum platforms or discussion keywords
    forum_keywords = ["moneysavingexpert", "thestudentroom", "forums"]
    if any(url_has(keyword) or text_has(keyword) for keyword in forum_keywords):
        return "forum_post"

    # Owned media: Brand or institutional domains or brand mentions
    # NOTE(review): "co.uk" is the generic UK commercial suffix, so this labels most UK sites
    # as owned media; consider narrowing it.
    owned_media_domains = ["gov.uk", "ac.uk", "co.uk", "barclays", "lloyds", "hsbc", "monzo"]
    if any(url_has(domain) for domain in owned_media_domains):
        return "owned_media"

    # FAQ/Knowledge base: Support or informational keywords
    faq_keywords = ["faq", "how to", "guide", "tutorial"]
    if any(url_has(keyword) or text_has(keyword) for keyword in faq_keywords):
        return "faq_knowledge_base"

    # Default: Other
    return "miscellaneous"

# Label every row with classify_content, applied as a Spark UDF over `url` and `clean_text`.
content_type_udf = udf(classify_content, StringType())
df = df.withColumn("content_type", content_type_udf(col("url"), col("clean_text")))

In [None]:
df.show()

### Summary of Final Columns

| Column Name    | Description                                                                 |
|----------------|-----------------------------------------------------------------------------|
| `content`      | The original textual content of the document record (e.g., the body of an article or code snippet), retained as the primary source text for analysis. |
| `clean_text`   | The processed version of the `content` column, where newlines, punctuation, digits, and excessive whitespace are removed, text is lowercased, and emojis are converted to text for consistency in analysis. |
| `url`          | The URL of the web page or resource from which the content was sourced, used for content type classification and brand context. |
| `brand_name`   | An array of the target brand names (from `BRANDS`) detected in `clean_text`, representing brand mentions for targeted sentiment analysis. |
| `mention_count`| An array of mention counts aligned with `brand_name` (one count per detected brand), quantifying how often each brand is referenced. |
| `content_type` | A categorical label (e.g., `user_generated`, `news_article`, `customer_review`, etc.) assigned from the `url` and, for some categories, keywords in `clean_text`, indicating the type of content for further analysis. |

In [None]:
df.printSchema()

In [None]:
# Shape of the DataFrame; count() triggers a full Spark job.
num_rows = df.count()
num_cols = len(df.columns)
print(f"Number of rows: {num_rows}")
print(f"Number of columns: {num_cols}")

In [None]:
# Presumably cached for the repeated sentiment aggregations that follow.
# NOTE(review): cache() is lazy, so the data is only materialised by the next action.
df.cache()

# 4. Brand Sentiment Analysis

In this section, sentiment analysis is performed on UK bank brand mentions using a hybrid approach that combines a lexicon-based method (VADER) with a transformer-based model (FinBERT). The aim is to analyse the emotional tone (positive, neutral or negative) of the brand mentions.

## 4.1 Lexicon-Based (VADER)

The following code performs brand sentiment analysis using NLTK's VADER (Valence Aware Dictionary and sEntiment Reasoner), a lexicon-based tool specifically designed for detecting sentiment in user-generated texts. VADER is fast and handles slang, emojis, and short texts well, making it ideal for analysing sentiment in data sources such as social media and reviews.

VADER calculates 4 sentiment metrics for each text input:
- `vader_score` (compound score): A normalized weighted composite score ranging from -1 (negative) to +1 (positive). Derived from the sum of valence scores of individual words, adjusted for modifiers (e.g., "very good" amplifies positivity).
- `positive_score`, `neutral_score`, `negative_score`: Proportional metrics representing the text's positive, neutral, and negative sentiment (each ranges 0–1). The 3 scores sum to 1.

`sentiment_label` is assigned based on the compound `vader_score`.

In [None]:
# Initialise VADER
# One shared analyzer, used by vader_sentiment below.
sid = SentimentIntensityAnalyzer()

# Calculates VADER sentiment
def vader_sentiment(text):
    """Return VADER polarity scores for ``text`` as a dict keyed compound/pos/neu/neg.

    The keys match the field names of ``vader_schema``, which the Spark UDF maps onto the
    output struct. Empty, whitespace-only or non-string input gets all-zero scores.
    """
    if not isinstance(text, str) or text.strip() == "":
        # Must use the schema's key names. The former "positive"/"neutral"/"negative" keys do
        # not match the struct fields and would likely come back from the UDF as nulls.
        return {"compound": 0.0, "pos": 0.0, "neu": 0.0, "neg": 0.0}
    return sid.polarity_scores(text)

# Spark struct returned by the VADER UDF: one nullable float per VADER score
vader_schema = StructType(
    [StructField(field_name, FloatType(), nullable=True) for field_name in ("compound", "pos", "neu", "neg")]
)

# Wrap the Python scorer as a UDF and score the cleaned text of every row
vader_udf = udf(vader_sentiment, vader_schema)
df = df.withColumn("vader_sentiment", vader_udf(col("clean_text")))

### Sentiment Scores and Label

In [None]:
# Flatten the VADER struct into one column per score
for score_column, struct_path in (
    ("vader_score", "vader_sentiment.compound"),
    ("positive_score", "vader_sentiment.pos"),
    ("neutral_score", "vader_sentiment.neu"),
    ("negative_score", "vader_sentiment.neg"),
):
    df = df.withColumn(score_column, col(struct_path))

# Sentiment Label: compound score mapped with +/-0.05 thresholds; the struct is no longer needed
sentiment_label_expr = (
    when(col("vader_score") > 0.05, "Positive")
    .when(col("vader_score") < -0.05, "Negative")
    .otherwise("Neutral")
)
df = df.withColumn("sentiment_label", sentiment_label_expr).drop("vader_sentiment")

# Preview the first 7 rows of the VADER outputs next to the brand/content columns
print("\nVADER Sentiment Scores and Labels:")
df.select(
    "clean_text", "brand_name", "mention_count", "content_type", "vader_score", "positive_score",
    "neutral_score", "negative_score", "sentiment_label"
).show(7)

### Overall Sentiment Aggregation: *avg_vader_score*

In [None]:
# Associates sentiment with each brand
df_exploded = df.select(
    explode(col("brand_name")).alias("brand"),
    col("vader_score"),
    col("content_type")
)

# Sentiment by brand and content_type
sentiment_summary = df_exploded.groupBy("brand", "content_type").agg(
    avg("vader_score").alias("avg_vader_score")
).orderBy("brand", "content_type")

# Sentiment label
sentiment_summary = sentiment_summary.withColumn(
    "avg_sentiment_label",
    when(col("avg_vader_score") > 0.05, "Positive")
    .when(col("avg_vader_score") < -0.05, "Negative")
    .otherwise("Neutral")
)

print("VADER Sentiment Summary by Brand and Content Type:")
sentiment_summary.show(truncate=False)

## 4.2 Transformer-Based (FinBERT)
**FinBERT** model is implemented for brand sentiment analysis of UK financial services brands due to its:
- **Domain-specialisation**: Built by further pre-training BERT on a large financial news corpus and then fine-tuning it for sentiment classification on the Financial PhraseBank dataset (sentences from financial news annotated as positive, neutral or negative). It therefore has a good understanding of key financial concepts, such as financial metrics, market movements, and regulatory language.
- **Sentiment granularity**: 3-class (positive/neutral/negative)
- **Numerical sensitivity**: Handles earnings and percentages well.

FinBERT understands context better than VADER, excelling in more complex texts such as news articles, regulatory documents, and reports.

The following outputs are computed:
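- `finbert_label` – the overall sentiment class: `neutral` when the absolute value of `finbert_score` is below 0.15, otherwise `positive` or `negative` according to its sign.
- `finbert_score` – the sentiment polarity score, calculated as the average Positive probability minus the average Negative probability across all chunks.
- `finbert_confidence` – how confident FinBERT is in its predictions: the average, across chunks, of the highest class probability in each chunk.
- `finbert_dist_positive`, `finbert_dist_neutral`, `finbert_dist_negative` – the average probability of each class across all chunks.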


In [None]:
# Load ProsusAI/finbert and wrap it in a Hugging Face sentiment-analysis pipeline
# that returns the probability of every class for each input text.
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline

# CPU / GPU checks
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    print("Warning: Flash Attention requires a CUDA-capable GPU. Falling back to standard attention.")
else:
    print(f"Using device: {device}")

# Enables Flash Attention
# NOTE(review): this only toggles PyTorch's scaled-dot-product-attention flash kernel;
# whether this BERT model actually routes attention through SDPA depends on the
# installed transformers version. Confirm before relying on it.
torch.backends.cuda.enable_flash_sdp(True)

# FinBERT tokenizer and model
finbert_tokenizer = AutoTokenizer.from_pretrained("ProsusAI/finbert", use_fast=True)
finbert_model = AutoModelForSequenceClassification.from_pretrained("ProsusAI/finbert").to(device)

# Additional optimisation using xformers
# NOTE(review): memory_efficient_attention is imported but never used below; this
# block only reports whether xformers is installed.
try:
    from xformers.ops import memory_efficient_attention
    print("Using xformers for memory-efficient attention")
except ImportError:
    print("xformers not installed. Using PyTorch Flash Attention.")

# FinBERT pipeline
# - return_all_scores=True: one {label, score} entry per class for every input;
#   analyze_finbert iterates over these lists. NOTE(review): newer transformers
#   releases deprecate this flag in favour of top_k=None.
# - torch_dtype: NOTE(review): the model object is already loaded above, so this
#   dtype may not be applied to it. float16 is generally a GPU-only optimisation. Confirm.
# - truncation/max_length: inputs longer than 512 tokens are cut by the tokenizer.
finbert_pipeline = pipeline(
    task="sentiment-analysis",
    model=finbert_model,
    tokenizer=finbert_tokenizer,
    device=0 if torch.cuda.is_available() else -1,
    torch_dtype=torch.float16,  # Keep for memory efficiency
    return_all_scores=True,
    truncation=True,
    padding=True,
    max_length=512,
    batch_size=32
)

The large document texts are then split into context-level chunks: each chunk is centred on one brand mention and captures up to ±2 sentences around it.

The text is first split into individual sentences, and sentences containing brand mentions are flagged. A chunk is then formed around each flagged sentence by adding up to 2 sentences on either side; windows longer than the 510-token budget are split further.

In [None]:
# Splits text into context-based chunks
def prepare_chunks(text, tokenizer, brands=None, max_tokens=510):
    """Split *text* into context windows centred on brand mentions.

    Every sentence (from ``nltk.sent_tokenize``) that contains any brand name
    (case-insensitive substring match) yields one window of up to five
    sentences: the mention plus up to two sentences on either side. Each
    window is packed greedily into chunks of at most ``max_tokens`` tokens,
    counted with ``tokenizer.tokenize``. A single sentence longer than
    ``max_tokens`` is truncated and emitted as its own chunk, after any
    pending context has been flushed, so every chunk holds contiguous text.

    Args:
        text: Document text. Non-strings and blank strings yield ``[]``.
        tokenizer: Hugging Face tokenizer, used only to count tokens and to
            turn truncated tokens back into a string.
        brands: Iterable of brand names. ``None``, an empty iterable, or one
            holding only blank names yields ``[]``.
        max_tokens: Token budget per chunk (presumably 510 so a chunk plus the
            two BERT special tokens fits the pipeline's 512-token limit).

    Returns:
        list[str]: Chunks in document order within each window. Windows of
        nearby mentions can overlap, so one sentence may appear in more than
        one chunk.
    """
    if not isinstance(text, str) or not text.strip():
        return []

    # Blank brand names would match every sentence, so drop them up front.
    brand_keys = [b.lower() for b in (brands or []) if isinstance(b, str) and b.strip()]
    if not brand_keys:
        return []

    sentences = nltk.sent_tokenize(text)

    chunks = []
    for idx in _find_mention_indices(sentences, brand_keys):
        # Context window: +/-2 sentences around the mention (slicing clips at both ends)
        window = sentences[max(0, idx - 2): idx + 3]
        chunks.extend(_pack_sentences(window, tokenizer, max_tokens))
    return chunks


def _find_mention_indices(sentences, brand_keys):
    """Return indices of non-blank sentences containing any lower-cased brand key."""
    return [
        i
        for i, sentence in enumerate(sentences)
        if sentence.strip() and any(key in sentence.lower() for key in brand_keys)
    ]


def _pack_sentences(sentences, tokenizer, max_tokens):
    """Greedily pack consecutive non-blank sentences into chunks of <= max_tokens tokens."""
    chunks = []
    current = []
    current_tokens = 0

    for sentence in sentences:
        if not sentence.strip():
            continue
        tokens = tokenizer.tokenize(sentence)

        if len(tokens) > max_tokens:
            # Flush pending context first so it is not later glued to sentences
            # that come after this one, then emit the truncated sentence alone.
            if current:
                chunks.append(" ".join(current))
                current, current_tokens = [], 0
            chunks.append(tokenizer.convert_tokens_to_string(tokens[:max_tokens]))
            continue

        if current and current_tokens + len(tokens) > max_tokens:
            chunks.append(" ".join(current))
            current, current_tokens = [], 0
        current.append(sentence)
        current_tokens += len(tokens)

    if current:
        chunks.append(" ".join(current))
    return chunks

In [None]:
# Analyzes sentiment of text using FinBERT, processing chunks around brand mentions
def _neutral_finbert_result():
    """Return the neutral default (label, polarity, confidence, distribution); a fresh dict each call."""
    return "neutral", 0.0, 0.0, {"positive": 0.0, "neutral": 1.0, "negative": 0.0}


def analyze_finbert(text):
    """Score the brand-related sentiment of *text* with FinBERT.

    The text is split into brand-centred chunks by ``prepare_chunks`` (using
    the global ``BRANDS`` list). All chunks are scored in one
    ``finbert_pipeline`` call, and the per-class probabilities are averaged
    across chunks.

    Args:
        text: Raw document text. Non-string or blank input is not scored.

    Returns:
        tuple: ``(label, polarity, confidence, distribution)``:

        - ``distribution`` maps ``"positive"``/``"neutral"``/``"negative"``
          to the mean class probability.
        - ``polarity`` is mean positive minus mean negative probability.
        - ``label`` is ``"neutral"`` when ``abs(polarity) < 0.15``, otherwise
          ``"positive"``/``"negative"`` by the sign of ``polarity``.
        - ``confidence`` is the mean of each chunk's top class probability.

        Numbers are rounded to 4 decimals. The neutral default
        ``("neutral", 0.0, 0.0, {"positive": 0.0, "neutral": 1.0,
        "negative": 0.0})`` is returned in three cases: the text yields no
        chunks, scoring fails, or the GPU runs out of memory.
    """
    if not isinstance(text, str) or not text.strip():
        return _neutral_finbert_result()

    try:
        chunks = prepare_chunks(text, finbert_tokenizer, brands=BRANDS)
        if not chunks or all(not c.strip() for c in chunks):
            return _neutral_finbert_result()

        results = finbert_pipeline(chunks)
        cumulative_scores = {"positive": 0.0, "neutral": 0.0, "negative": 0.0}
        confidences = []

        for r in results:
            # With all scores requested, each chunk yields a list of {label, score} dicts
            if not isinstance(r, list):
                continue
            for entry in r:
                cumulative_scores[entry["label"].lower()] += entry["score"]
            confidences.append(max(entry["score"] for entry in r))

        count = len(confidences)
        if count == 0:
            return _neutral_finbert_result()

        # Normalizes all the scores
        avg_scores = {k: v / count for k, v in cumulative_scores.items()}
        avg_confidence = sum(confidences) / count

        # Polarity score
        polarity = avg_scores["positive"] - avg_scores["negative"]

        # Final predicted label: small polarity counts as neutral
        if abs(polarity) < 0.15:
            final_label = "neutral"
        else:
            final_label = "positive" if polarity > 0 else "negative"

        return (
            final_label,
            round(polarity, 4),
            round(avg_confidence, 4),
            {k: round(v, 4) for k, v in avg_scores.items()},
        )

    except torch.cuda.OutOfMemoryError:
        # Release cached GPU memory so later documents can still be scored,
        # then fall back to the neutral default for this one.
        torch.cuda.empty_cache()
        print(f"FinBERT ran out of GPU memory on text {text[:50]}...; returning neutral")
        return _neutral_finbert_result()
    except Exception as e:
        print(f"FinBERT error on text {text[:50]}...: {str(e)}")
        return _neutral_finbert_result()

In [None]:
# Row index for joining the FinBERT results back onto df.
# monotonically_increasing_id() depends on partition IDs and Spark treats it as
# non-deterministic. Cache df here so the later join reuses the same ids that are
# exported to Pandas, instead of recomputing them.
df = df.withColumn("row_id", monotonically_increasing_id()).cache()

# Convert to Pandas for transformer processing (this action also materialises the cache)
pandas_df = df.select("row_id", "text").toPandas()

In [None]:
# Runs analysis: one (label, score, confidence, distribution) tuple per document
finbert_results = list(map(analyze_finbert, pandas_df["text"]))

# Sentiment label, score and confidence are the first three tuple positions
for position, column_name in enumerate(("finbert_label", "finbert_score", "finbert_confidence")):
    pandas_df[column_name] = [result[position] for result in finbert_results]

# Individual sentiment scores come from the distribution dict in position 3
for column_name, class_name in (
    ("finbert_dist_positive", "positive"),
    ("finbert_dist_neutral", "neutral"),
    ("finbert_dist_negative", "negative"),
):
    pandas_df[column_name] = [result[3][class_name] for result in finbert_results]


In [None]:
# Saved to CSV
finbert_csv = "data/finbert_results.csv"
# pandas' to_csv does not create missing directories, so make sure data/ exists first
os.makedirs(os.path.dirname(finbert_csv), exist_ok=True)
pandas_df[[
    "row_id", "finbert_label", "finbert_score", "finbert_confidence",
    "finbert_dist_positive", "finbert_dist_neutral", "finbert_dist_negative"
]].to_csv(finbert_csv, index=False)

# Transforms back to Spark and attaches the FinBERT columns to df by row_id
transformer_df = spark.read.csv(finbert_csv, header=True, inferSchema=True)
df = df.join(transformer_df, "row_id").drop("row_id")

In [None]:
# Preview the first 7 rows of the FinBERT outputs next to the brand/content columns
print("\nFinBERT Sentiment Scores:")
df.select(
     "clean_text", "brand_name", "mention_count", "content_type",
    "finbert_label", "finbert_score", "finbert_confidence", "finbert_dist_positive", "finbert_dist_neutral", "finbert_dist_negative"
).show(7)

## 4.3 Hybrid Approach (Combining VADER and FinBERT labels)
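The following section combines the VADER and FinBERT scores, weighted by `content_type`. For `user_generated` and `customer_review` content, VADER is up-weighted (60% VADER, 40% FinBERT); all other content types, such as `news_article` and `regulatory_document`, use the FinBERT score alone. The combined score is mapped to a label with ±0.05 thresholds, producing the `hybrid_sentiment_label` column.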

In [None]:
# Content types where VADER (suited to informal, user-written text) is blended in;
# every other content type relies on FinBERT alone.
VADER_WEIGHTED_CONTENT_TYPES = frozenset({"user_generated", "customer_review"})
VADER_WEIGHT = 0.6    # adjust weights if necessary; keep them summing to 1
FINBERT_WEIGHT = 0.4


@udf(StringType())
def hybrid_sentiment(vader_score, finbert_score, content_type):
    """Combine VADER and FinBERT polarity scores into a single label.

    For content types in ``VADER_WEIGHTED_CONTENT_TYPES`` the combined score
    is ``VADER_WEIGHT * vader_score + FINBERT_WEIGHT * finbert_score``. For
    every other content type it is ``finbert_score`` alone.

    Args:
        vader_score: VADER compound score, or None.
        finbert_score: FinBERT polarity (positive minus negative probability), or None.
        content_type: Content category label of the row.

    Returns:
        "Positive" if the combined score is > 0.05, "Negative" if < -0.05,
        otherwise "Neutral". Also "Neutral" when either score is missing.
    """
    if vader_score is None or finbert_score is None:
        return "Neutral"

    if content_type in VADER_WEIGHTED_CONTENT_TYPES:
        combined_score = (vader_score * VADER_WEIGHT) + (finbert_score * FINBERT_WEIGHT)
    else:
        combined_score = finbert_score

    # Sentiment thresholds
    if combined_score > 0.05:
        return "Positive"
    elif combined_score < -0.05:
        return "Negative"
    return "Neutral"

# Apply the hybrid UDF row by row to add hybrid_sentiment_label
df = df.withColumn("hybrid_sentiment_label", hybrid_sentiment(
    col("vader_score"),
    col("finbert_score"),
    col("content_type")
))

# Preview the first 10 hybrid labels
print("\nHybrid Sentiment Labels based on VADER and FinBERT results:")
df.select(
    "clean_text", "brand_name", "mention_count", "content_type", "hybrid_sentiment_label"
).show(10)

# 5. Brand-Specific Analysis

The objective of this section is to delve into sentiment insights for specific brands (e.g. Lloyds, Barclays), exploring how sentiment varies by content type, with visualizations for clarity.

In [None]:
from pyspark.sql.functions import lit

# One row per element of each document's brand_name array, carrying the
# document-level sentiment columns.
# mention_count is an integer (the size of brand_name), not an array, so it
# cannot be zipped with brand_name. Instead, each exploded element counts as one
# mention, which makes summing `mentions` downstream total the brand mentions.
# NOTE(review): assumes each brand_name element represents one mention. Confirm
# against the extraction step.
brand_sentiment_df = df.select(
    explode(col("brand_name")).alias("brand"),
    lit(1).alias("mentions"),
    col("content_type"),
    col("vader_score"),
    col("sentiment_label"),
    col("finbert_label"),
    col("finbert_score"),
    col("finbert_confidence"),
    col("hybrid_sentiment_label")
)

### Sentiment by Brand and Content Type

This reveals which content types drive positive or negative sentiment, which can guide brand reputation strategies.

In [None]:
# Keep only exploded rows whose brand is one of the configured BRANDS
brand_specific_df = brand_sentiment_df.filter(col("brand").isin(BRANDS))

# Sentiment summaries for each brand
for brand in BRANDS:
    brand_df = brand_specific_df.filter(col("brand") == brand)

    # total_mentions = sum of the per-row `mentions` column for each
    # (content type, hybrid label) combination of this brand
    brand_summary = brand_df.groupBy(
        "brand", "content_type", "hybrid_sentiment_label"
    ).agg({"mentions": "sum"}).withColumnRenamed("sum(mentions)", "total_mentions")

    print(f"\nBrand Sentiment Summary for {brand.capitalize()}:")
    brand_summary.orderBy("content_type", "hybrid_sentiment_label").show(50, truncate=False)

# 6. Filtering Positive and Negative Brand Mentions

From here onwards, the analysis is carried out at brand level, one brand at a time, starting with HSBC. HSBC mentions are therefore filtered first:

In [None]:
# Keep rows whose brand_name array contains "hsbc"
# (NOTE(review): assumes extracted brand names are lower-case; confirm)
hsbc_df = df.filter(array_contains(col("brand_name"), "hsbc"))
# hsbc_df.show()

Extracting only positive and only negative brand mentions for HSBC and saving them to 2 Parquet files:

In [None]:
# Split HSBC mentions by hybrid sentiment label
positive_hsbc_df = hsbc_df.filter(col("hybrid_sentiment_label") == "Positive")
negative_hsbc_df = hsbc_df.filter(col("hybrid_sentiment_label") == "Negative")

# Persist both splits; the RAG pipeline reads them back via the
# "data/filtered_brand_mentions/hsbc_*_mentions" glob
positive_hsbc_df.write.mode("overwrite").parquet("data/filtered_brand_mentions/hsbc_positive_mentions")
negative_hsbc_df.write.mode("overwrite").parquet("data/filtered_brand_mentions/hsbc_negative_mentions")

# 7. RAG Implementation

In [None]:
import logging
from datetime import datetime
from typing import Any

import faiss
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import Field
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Logging for tracking pipeline progress and errors (INFO and above)
logging.basicConfig(level=logging.INFO)

In [None]:
!pip install llama-cpp-python --timeout 1000

Using RAG, the filtered data will be fed into the **OLMo 2 model**. The `OLMo2Model` class below runs a pre-quantized GGUF build of OLMo 2 through `llama-cpp-python`, and also loads the matching Hugging Face tokenizer.

OLMo 2 model was initialised with **quantization**. Quantization lowers the memory requirements of loading and using a model by storing the weights in a lower precision while trying to preserve as much accuracy as possible. Weights are traditionally stored in full-precision (fp32) floating point representations, but half-precision (fp16 or bf16) have become increasingly popular data types given the large size of models. The chosen OLMo-2-0425-1B-Instruct-GGUF model is already quantized and is suitable for local deployment.

The init function initializes the OLMo 2 model. The generate function includes the following arguments:
* prompt: Input prompt or question
* context: Optional list of context strings to include (this is where we inject sentiment)
* max_new_tokens: Maximum number of new tokens to generate
* temperature: Sampling temperature; higher values make the output more varied
* top_k: Number of highest probability tokens to consider
* top_p: Cumulative probability cutoff for top-p sampling

This outputs a string of generated text responses.

In [None]:
import torch
from transformers import AutoTokenizer, pipeline
from llama_cpp import Llama  # for GGUF support

class OLMo2Model:
    """Text generator backed by a pre-quantized GGUF build of OLMo 2 (llama-cpp-python)."""

    def __init__(
        self,
        model_path: str = "allenai/OLMo-2-0425-1B-Instruct-GGUF",
        gguf_filename: str = "*Q4_K_M.gguf",
        n_gpu_layers: int = 0,
        n_ctx: int = 2048,
    ):
        """Load the Hugging Face tokenizer and the GGUF model.

        Args:
            model_path: Local path to a ``.gguf`` file, or a Hugging Face repo
                id hosting GGUF files (the default).
            gguf_filename: File name or glob pattern selecting the GGUF file
                inside the repo. Only used when ``model_path`` is not a local
                file. NOTE(review): confirm the repo ships exactly one file
                matching this pattern.
            n_gpu_layers: Number of layers to offload to the GPU; 0 keeps
                everything on the CPU. Offloading needs a GPU-enabled
                llama-cpp-python build.
            n_ctx: Context length in tokens.
        """
        # Load tokenizer (kept as an attribute; generate() below does not use it)
        self.tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B-Instruct")

        llama_kwargs = {"n_gpu_layers": n_gpu_layers, "n_ctx": n_ctx, "verbose": False}
        if os.path.isfile(model_path):
            self.model = Llama(model_path=model_path, **llama_kwargs)
        else:
            # Llama(model_path=...) only accepts a local file. A repo id has to be
            # downloaded first, which Llama.from_pretrained does via huggingface_hub.
            self.model = Llama.from_pretrained(
                repo_id=model_path, filename=gguf_filename, **llama_kwargs
            )

        # Informational label only; llama-cpp-python manages placement itself
        self.device = "cpu" if n_gpu_layers == 0 else "gpu"

    def generate(self, prompt: str, context: list = None, max_new_tokens: int = 100, temperature: float = 0.7, top_k: int = 50, top_p: float = 0.95) -> str:
        """Generate a completion for *prompt*, optionally prefixed with context.

        Args:
            prompt: Question or full prompt text.
            context: Optional list of context strings. When given, they are
                space-joined and sent as ``"Context: ...\\nQuestion: <prompt>"``.
            max_new_tokens: Maximum number of tokens to generate.
            temperature: Sampling temperature.
            top_k: Number of highest-probability tokens considered.
            top_p: Cumulative-probability cutoff for nucleus sampling.

        Returns:
            The generated text with surrounding whitespace stripped.
        """
        # Combine context with prompt if provided
        if context:
            context_text = " ".join(context)
            full_prompt = f"Context: {context_text}\nQuestion: {prompt}"
        else:
            full_prompt = prompt

        # Generate response using the GGUF model
        output = self.model(
            full_prompt,
            max_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop=["<|endoftext|>"]  # Stop token based on OLMo 2 chat template
        )
        return output["choices"][0]["text"].strip()

# Initializing model (module-level instance, also used for the no-context control generations)
olmo2 = OLMo2Model()

## 7.1 Vector Store

A vector store is a database that stores text embeddings (numerical representations of text generated by a model like BERT). These embeddings enable efficient similarity searches to retrieve relevant documents based on semantic meaning rather than exact keyword matches.

The FAISS index is implemented here. First, the cleaned text is converted into a vector using Sentence Transformer (all-MiniLM-L6-v2). Secondly, we build the FAISS index for fast search by taking all of the dataset vectors. It will then find the most similar items to a new query via FAISS. Lastly, FAISS returns similar vectors and documents.

This function processes data, creates vector embeddings from the cleaned text, and creates a vector store. This includes the following arguments:

* df: Spark DataFrame containing text data (e.g., hsbc_df)
* sentiment_filter: Optional filter for sentiment (e.g., "Positive" or "Negative")
* index_path: File path to save the FAISS index
* metadata_path: File path to save the metadata CSV

It returns a tuple: (FAISS index, list of text data).

In [None]:
# Embedding model
# Sentence-transformer encoder shared by create_vector_store (documents) and the
# retriever (queries), so both live in the same embedding space
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

def create_vector_store(df, sentiment_filter=None, index_path="data/vector_store/hsbc_index.faiss", metadata_path="data/vector_store/metadata.csv"):
    """Embed the ``clean_text`` rows of *df* and build a FAISS L2 index over them.

    Args:
        df: Spark DataFrame with a ``clean_text`` column and, when filtering,
            a ``hybrid_sentiment_label`` column.
        sentiment_filter: If given, keep only rows whose
            ``hybrid_sentiment_label`` equals this value (e.g. "Positive").
        index_path: Where the FAISS index is written.
        metadata_path: Where a CSV of the indexed texts and their sentiment
            label is written.

    Returns:
        tuple: ``(index, texts)``, where ``texts[i]`` is the document stored at
        position ``i`` of ``index``. Rows with null ``clean_text`` are skipped.
        If no rows remain, the index is empty.

    Note:
        The default paths are shared between calls, so calling this twice with
        the defaults overwrites the first call's files. Pass distinct paths
        per store.
    """
    # Filter dataframe based on sentiment if specified
    if sentiment_filter:
        df_filtered = df.filter(col("hybrid_sentiment_label") == sentiment_filter)
    else:
        df_filtered = df

    # Null texts cannot be embedded; dropping them keeps texts and index rows aligned
    texts = [
        row["clean_text"]
        for row in df_filtered.select("clean_text").collect()
        if row["clean_text"] is not None
    ]

    if texts:
        embeddings = embedding_model.encode(texts, show_progress_bar=True)
        index = faiss.IndexFlatL2(embeddings.shape[1])
        # FAISS indexes store float32 vectors; make the dtype explicit
        index.add(embeddings.astype("float32"))
    else:
        # encode([]) has no second dimension to read, so ask the model for it
        index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
        logging.warning(f"No texts to index for sentiment filter {sentiment_filter!r}")

    # Create the directories the output files actually live in
    for path in (index_path, metadata_path):
        parent_dir = os.path.dirname(path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    faiss.write_index(index, index_path)
    sentiment_value = sentiment_filter if sentiment_filter else "Mixed"
    metadata = pd.DataFrame({"text": texts, "sentiment": [sentiment_value] * len(texts)})
    metadata.to_csv(metadata_path, index=False)
    logging.info(f"Saved vector store to {index_path} and metadata to {metadata_path}")
    return index, texts

## 7.2 RAG Service and Query System

RAG combines a **retrieval** step (retrieves relevant documents from the vector store) with a **generation** step (using OLMo 2 to generate answers). This enhances the model's responses by grounding them in specific and retrieved context.

A custom **retriever** for brand sentiment data is defined below:

In [None]:
class BrandSentimentRetriever(BaseRetriever):
    """LangChain retriever returning brand-mention documents from a FAISS index.

    Row ``i`` of ``index`` must correspond to ``metadata_df.iloc[i]``. That
    row's ``text`` column becomes the document content, and its optional
    ``sentiment`` column is copied into the document metadata.
    """

    index: Any = Field(..., description="FAISS index for vector search")
    metadata_df: pd.DataFrame = Field(..., description="DataFrame containing metadata")
    embedding_model: Any = Field(..., description="SentenceTransformer model")
    top_k: int = Field(default=5, description="Number of documents to retrieve")
    brand: str = Field(default="HSBC", description="Brand name recorded in each document's metadata")

    class Config:
        arbitrary_types_allowed = True # allowing non-serializable and non-standard types (like a FAISS index object)

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> list[Document]:
        """Return up to ``top_k`` documents nearest to *query* in embedding space.

        Args:
            query: User input string to search for.
            run_manager: Optional LangChain run manager (not used here).

        Returns:
            Retrieved documents, nearest first, with ``sentiment`` and ``brand`` metadata.
        """
        if self.top_k <= 0 or self.metadata_df.empty:
            return []

        # Encodes query and searches the index
        query_embedding = self.embedding_model.encode([query])
        _, indices = self.index.search(query_embedding, self.top_k)

        documents = []
        for idx in indices[0]:
            # FAISS pads results with -1 when it has fewer than top_k vectors;
            # iloc[-1] would silently return the last row instead.
            if idx < 0:
                continue
            row = self.metadata_df.iloc[idx]
            documents.append(
                Document(
                    page_content=row['text'],
                    metadata={
                        'sentiment': row.get('sentiment', 'N/A'),
                        'brand': self.brand
                    }
                )
            )
        return documents

OLMo 2 model is integrated with the retriever to generate text responses. The **generator** is defined here:

In [None]:
class SentimentRAG:
    """Retrieval-augmented generation over brand-mention documents with OLMo 2.

    Each ``invoke`` call retrieves documents for the query, places them in an
    OLMo 2 chat-formatted prompt, and returns the model's answer.
    """

    # Model shared by all instances created without an explicit ``model``, so
    # several RAG variants do not each load their own copy of the weights.
    _shared_model = None

    def __init__(self, retriever, model=None):
        """
        Args:
            retriever: Object with an ``invoke(query)`` method returning
                documents that expose ``page_content``.
            model: Optional object with a ``generate(prompt)`` method (e.g. an
                ``OLMo2Model``). Defaults to one lazily-created ``OLMo2Model``
                shared across instances.
        """
        self.retriever = retriever
        if model is None:
            if SentimentRAG._shared_model is None:
                SentimentRAG._shared_model = OLMo2Model()
            model = SentimentRAG._shared_model
        self.olmo2 = model

    # Formats retrieved documents into a single string for the prompt
    def format_docs(self, docs):
        return "\n".join(d.page_content for d in docs)

    def _build_prompt(self, context, query):
        """Build an OLMo 2 chat prompt: role tags on their own lines, ending with an open assistant turn."""
        return (
            "<|system|>\n"
            "You are a sentiment analysis expert. Answer based only on your knowledge "
            "and the additional context provided.\n"
            "Provide a ranking of UK banks from best to worst.\n"
            "<|user|>\n"
            f"Context: {context}\n"
            f"Question: {query}\n"
            "<|assistant|>\n"
        )

    # Generates a response using OLMo 2 with retrieved context
    def invoke(self, query):
        """Answer *query* using retrieved context.

        Returns "Error generating response" if retrieval or generation raises.
        """
        try:
            docs = self.retriever.invoke(query)
            prompt = self._build_prompt(self.format_docs(docs), query)
            response = self.olmo2.generate(prompt)
            return response.strip()

        except Exception as e:
            logging.error(f"Error generating response: {str(e)}")
            return "Error generating response"

## 7.3 Main Pipeline

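The following main pipeline executes the end-to-end sentiment analysis pipeline with RAG experiments. It compares a control case (direct generation with no retrieved context) against generations grounded in positive and in negative HSBC mentions, and returns the RAG instances it creates.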

In [None]:
def run_sentiment_pipeline():
    """Run the HSBC sentiment-injection RAG experiment end to end.

    Steps:
        1. Load the positive/negative HSBC mention Parquet files.
        2. Build one FAISS vector store per sentiment.
        3. Wrap each store in a SentimentRAG.
        4. Compare OLMo 2's answers to a fixed ranking question under three
           conditions: no context (control), positive context and negative
           context.
        5. Loop over user queries until the user types 'quit' or stdin is
           exhausted.

    Returns:
        tuple: ``(positive_rag, negative_rag, control_rag)``. ``control_rag``
        is always ``None``, because the control case calls the global
        ``olmo2`` model directly without a retriever.

    Raises:
        Exception: any failure is logged at CRITICAL level and re-raised.
    """
    try:
        logging.info("Starting pipeline with RAG")

        # 1. Loads HSBC data
        logging.info("Loading HSBC data...")
        hsbc_df = spark.read.parquet("data/filtered_brand_mentions/hsbc_*_mentions")

        # 2. Sets up one vector store per sentiment. Distinct file paths are passed
        # because create_vector_store's default paths are shared: the second call
        # would otherwise overwrite the first store on disk.
        logging.info("Creating vector stores...")
        positive_index, positive_texts = create_vector_store(
            hsbc_df,
            "Positive",
            index_path="data/vector_store/hsbc_positive_index.faiss",
            metadata_path="data/vector_store/hsbc_positive_metadata.csv",
        )
        negative_index, negative_texts = create_vector_store(
            hsbc_df,
            "Negative",
            index_path="data/vector_store/hsbc_negative_index.faiss",
            metadata_path="data/vector_store/hsbc_negative_metadata.csv",
        )

        # Metadata frames aligned row-for-row with each index
        positive_metadata = pd.DataFrame({"text": positive_texts, "sentiment": ["Positive"] * len(positive_texts)})
        negative_metadata = pd.DataFrame({"text": negative_texts, "sentiment": ["Negative"] * len(negative_texts)})

        # 3. Sets up RAG for each experimental case
        logging.info("Initializing RAG for experiments...")
        positive_retriever = BrandSentimentRetriever(
            index=positive_index, metadata_df=positive_metadata, embedding_model=embedding_model, top_k=5
        )
        negative_retriever = BrandSentimentRetriever(
            index=negative_index, metadata_df=negative_metadata, embedding_model=embedding_model, top_k=5
        )
        positive_rag = SentimentRAG(positive_retriever)
        negative_rag = SentimentRAG(negative_retriever)
        # The control case has no retriever: it calls the global olmo2 model directly
        control_rag = None

        # 4. Conducts experiments with the specified prompt
        query = "What is the best bank in the UK? Provide a ranking from best to worst"
        logging.info(f"Testing with query: {query}")

        # Control case: No RAG context injected, direct generation
        print("\nControl Case (No Context):")
        control_response = olmo2.generate(query)
        print(control_response)

        # Positive case: RAG with positive HSBC mentions
        print("\nPositive Case (With Positive HSBC Context):")
        positive_response = positive_rag.invoke(query)
        print(positive_response)

        # Negative case: RAG with negative HSBC mentions
        print("\nNegative Case (With Negative HSBC Context):")
        negative_response = negative_rag.invoke(query)
        print(negative_response)

        # 5. Loop for further queries
        while True:
            try:
                user_query = input("\nEnter a new query (or 'quit'): ").strip()
            except EOFError:
                # No interactive stdin (e.g. run as a script): stop the loop cleanly
                break
            if user_query.lower() == 'quit':
                break
            if not user_query:
                continue
            print("\nControl Case (No Context, Direct Generation):", olmo2.generate(user_query))
            print("Positive Case (With Positive HSBC Context):", positive_rag.invoke(user_query))
            print("Negative Case (With Negative HSBC Context):", negative_rag.invoke(user_query))
        return positive_rag, negative_rag, control_rag

    except Exception as e:
        logging.critical(f"Pipeline failed: {str(e)}")
        raise

# Run the experiment, then stop the Spark session.
# NOTE(review): in a Jupyter kernel __name__ is normally "__main__" as well, so this
# runs whenever the cell is executed. Confirm that is intended.
if __name__ == "__main__":
    rags = run_sentiment_pipeline()
    spark.stop()

In [None]:
# Stop the Spark session. NOTE(review): the __main__ guard in the previous cell
# already calls spark.stop(); confirm this second call is still needed.
spark.stop()