In [None]:
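# Hybrid search example: rank documents with SQLite FTS5 (keyword search) and sqlite-vec
# (vector similarity), then merge the two rankings with Reciprocal Rank Fusion (RRF).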

In [None]:
# !pip install --upgrade pip
# !pip install sqlite-vec
# !pip install pandas
# !pip install openai


In [1]:
import json
import os
import sqlite3
import struct
import time
from typing import List

import numpy as np
import openai
import pandas as pd
import sqlite_vec
from openai import AzureOpenAI, OpenAIError


In [2]:
top_k = 10  # result depth used by the search queries below

# Azure OpenAI connection settings are read from the environment (the same
# variable names the openai SDK itself uses) so credentials never live in the
# notebook. The fallbacks are placeholders only and will not authenticate.
openai_embedding_api_base = os.environ.get("AZURE_OPENAI_ENDPOINT", "https://<redacted>.openai.azure.com/")
openai_embedding_api_key = os.environ.get("AZURE_OPENAI_API_KEY", "<redacted>")
openai_embedding_api_version = os.environ.get("OPENAI_API_VERSION", "2024-02-15-preview")
openai_embedding_model = "text-embedding-ada-002"  # passed as `model` to embeddings.create

In [3]:
def serialize_f32(vec):
    """Pack a sequence of numbers into raw float32 bytes (native byte order).

    This is the binary vector form handed to sqlite-vec, both for inserts into
    the vec0 table and for the MATCH argument of vector queries.
    """
    as_float32 = np.asarray(vec, dtype=np.float32)
    return as_float32.tobytes()

def reciprocal_rank_fusion(fts_results, vec_results, k=60):
    """Merge two ranked result lists with Reciprocal Rank Fusion (RRF).

    A document at 0-based position ``rank`` in a list earns ``1 / (k + rank + 1)``
    from that list; its fused score is the sum over both lists.

    Args:
        fts_results: rows shaped ``(id,)``, best match first.
        vec_results: rows shaped ``(rowid, distance)``, best match first; only
            the order is used, the distance value itself is ignored.
        k: smoothing constant; larger values shrink the score gap between ranks.

    Returns:
        List of ``(id, score)`` tuples sorted by descending score. Equal scores
        keep the order in which ids were first seen (FTS list, then vector list).
    """
    fused = {}

    def _add_ranked(doc_ids):
        # Accumulate this list's reciprocal-rank contribution per document.
        for rank, doc_id in enumerate(doc_ids):
            fused[doc_id] = fused.get(doc_id, 0) + 1 / (k + rank + 1)

    _add_ranked(doc_id for (doc_id,) in fts_results)
    _add_ranked(doc_id for doc_id, _distance in vec_results)

    return sorted(fused.items(), key=lambda item: item[1], reverse=True)
  
def or_words(input_string):
    """Turn free text into an FTS5 MATCH expression matching any of its words.

    The input is split on whitespace. Each word becomes an FTS5 string literal
    (wrapped in double quotes, with embedded double quotes doubled), and the
    literals are joined with ``OR``.

    Quoting stops punctuation in user text (e.g. ``well-being`` or ``AI?``) from
    being parsed as FTS5 query syntax. Left unquoted, it makes MATCH raise
    ``sqlite3.OperationalError`` (fts5 syntax error). Quoted strings still go
    through the table's tokenizer, so stemming and case folding apply as before.
    FTS5 operator words (``AND``, ``NOT``, ``NEAR``) and ``*`` prefix markers in
    the input are no longer interpreted as operators.

    Example: ``'electricity cars'`` -> ``'"electricity" OR "cars"'``.
    Blank input yields ``''``.
    """
    quoted = ('"' + word.replace('"', '""') + '"' for word in input_string.split())
    return ' OR '.join(quoted)

def lookup_row(id):
    """Return the text stored in ``mango_lookup`` for document ``id``.

    Queries through the module-level cursor ``cur``. Returns the ``content``
    column of the matching row, or ``''`` when no row has that id.
    """
    sql = '  \n    SELECT content FROM mango_lookup WHERE id = ?\n    '
    match = cur.execute(sql, (id,)).fetchone()
    if match is None:
        return ''
    return match[0]

# Function to generate vectors for text
def generate_embedding(text):
    """Embed ``text`` with the configured Azure OpenAI embedding deployment.

    Reads the module-level ``openai_embedding_*`` settings. Throttling errors
    and unrecognised OpenAI errors are retried up to ``max_attempts`` times.
    The wait grows 1.5x before each retry, capped at ``max_backoff`` seconds.

    Returns:
        list[float]: the embedding vector on success.
        None: if ``text`` is None, the request raises a non-OpenAI exception,
            or every retry attempt failed.
        str: a message starting with ``'Error'`` for non-retryable failures
            (deployment not found, content filter, other 40x, connection error).
            Callers must check for this before treating the result as a vector.
    """
    max_attempts = 6
    max_backoff = 60
    if text is None:
        return None

    # The client owns an HTTP connection pool. The context manager closes it on
    # exit instead of leaving one open pool per call for the garbage collector.
    with AzureOpenAI(
        api_version=openai_embedding_api_version,
        azure_endpoint=openai_embedding_api_base,
        api_key=openai_embedding_api_key,
    ) as client:
        counter = 0
        incremental_backoff = 1  # seconds; multiplied by 1.5 before each retry
        while counter < max_attempts:
            try:
                response = client.embeddings.create(
                    input=text,
                    model=openai_embedding_model,
                )
                return response.data[0].embedding
            except OpenAIError as ex:
                # Not every OpenAIError subclass defines ``code``. Read it
                # defensively rather than raise AttributeError inside this handler.
                code = str(getattr(ex, 'code', None))
                message = str(ex)
                if code == "429" or getattr(ex, 'status_code', None) == 429:
                    # Compute the delay first so the message reports the actual wait.
                    incremental_backoff = min(max_backoff, incremental_backoff * 1.5)
                    print('OpenAI Throttling Error - Waiting to retry after', incremental_backoff, 'seconds...')
                    counter += 1
                    time.sleep(incremental_backoff)
                elif code == "DeploymentNotFound":
                    print('Error: Deployment not found')
                    return 'Error: Deployment not found'
                elif code == "content_filter":
                    # Must precede the generic 40x branch. Content-filter
                    # rejections come back as HTTP 400 errors whose message
                    # starts with 'Error code: 400', so that branch would
                    # otherwise always win.
                    print('Content Filter Error', code)
                    return "Error: Content could not be extracted due to Azure OpenAI content filter." + code
                elif 'Error code: 40' in message:
                    print('Error: ' + message)
                    return 'Error:' + message
                elif 'Connection error' in message:
                    print('Error: Connection error')
                    return 'Error: Connection error'
                else:
                    print('API Error:', ex)
                    print('API Error Code:', getattr(ex, 'code', None))
                    incremental_backoff = min(max_backoff, incremental_backoff * 1.5)
                    counter += 1
                    time.sleep(incremental_backoff)
            except Exception as ex:
                counter += 1
                print('Error - Retry count:', counter, ex)
                return None

    print('Error: giving up after', max_attempts, 'attempts')
    return None

In [4]:
# Create an in-memory SQLite database and load the sqlite-vec extension into it.
db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
try:
    sqlite_vec.load(db)
finally:
    # Turn extension loading back off even if the load fails, so SQL run on this
    # connection later cannot load arbitrary shared libraries.
    db.enable_load_extension(False)

# vec_version() is provided by the extension, so this also confirms it loaded.
sqlite_version, vec_version = db.execute(
    "select sqlite_version(), vec_version()"
).fetchone()
print(f"sqlite_version={sqlite_version}, vec_version={vec_version}")


sqlite_version=3.40.1, vec_version=v0.1.1


In [5]:
# Embed a probe sentence once to learn the embedding dimensionality; it sizes
# the vec0 column of the vector table.
test_vec = generate_embedding('The quick brown fox')
# generate_embedding reports failure as None or an 'Error...' string rather than
# raising. len() of such a string would silently yield a bogus dimension.
assert isinstance(test_vec, list) and test_vec, f'Embedding request failed: {test_vec!r}'
dims = len(test_vec)
print('Dims in Vector Embeddings:', dims)

Dims in Vector Embeddings: 1536


In [6]:
# A single cursor shared with the rest of the notebook (lookup_row reads the global `cur`).
cur = db.cursor()

# Keyword index: FTS5 with the porter stemmer over unicode61 tokens.
# `id` is stored with each row but not included in the full-text index.
cur.execute('CREATE VIRTUAL TABLE mango_fts USING fts5(id UNINDEXED, content, tokenize="porter unicode61");')

# Vector index: a sqlite-vec vec0 table whose column width matches the embedding size.
# vec0 rows are keyed by an implicit rowid, which the inserts set to the document id.
vec_table_sql = f'CREATE VIRTUAL TABLE mango_vec USING vec0(embedding float[{dims}])'
cur.execute(vec_table_sql)

# Plain table mapping document id -> original text, used when printing results.
cur.execute('CREATE TABLE mango_lookup (id INTEGER PRIMARY KEY, content TEXT);')

<sqlite3.Cursor at 0x776b84f1aec0>

In [7]:
# Insert some sample data into mango_fts, mango_lookup and mango_vec.
# The same integer id keys all three tables: the FTS `id` column, the lookup
# table's primary key, and the vec0 rowid.
fts_data = [
    (1, 'The quick brown fox jumps over the lazy dog.'),
    (2, 'Artificial intelligence is transforming the world.'),
    (3, 'Climate change is a pressing global issue.'),
    (4, 'The stock market fluctuates based on various factors.'),
    (5, 'Remote work has become more prevalent during the pandemic.'),
    (6, 'Electric vehicles are becoming more popular.'),
    (7, 'Quantum computing has the potential to revolutionize technology.'),
    (8, 'Healthcare innovation is critical for societal well-being.'),
    (9, 'Space exploration expands our understanding of the universe.'),
    (10, 'Cybersecurity threats are evolving and becoming more sophisticated.')
]

cur.executemany('''  
INSERT INTO mango_fts (id, content) VALUES (?, ?)  
''', fts_data)  

cur.executemany('''  
  INSERT INTO mango_lookup (id, content) VALUES (?, ?)  
''', fts_data)  

# Generate embeddings for the content and insert into mango_vec.
for doc_id, content in fts_data:
    embedding = generate_embedding(content)
    # generate_embedding returns None or an 'Error...' string instead of raising.
    # Stop here with a clear message rather than fail later with an opaque error.
    assert isinstance(embedding, list), f'Embedding failed for id {doc_id}: {embedding!r}'
    assert len(embedding) == dims, f'Unexpected embedding size {len(embedding)} for id {doc_id} (expected {dims})'
    cur.execute('''  
    INSERT INTO mango_vec (rowid, embedding) VALUES (?, ?)  
    ''', (doc_id, serialize_f32(embedding)))  

# Commit changes
db.commit()

In [10]:
# Full-text search query
# NOTE: no sample document contains the token "ai", so for this query the
# ranking comes entirely from the vector leg.
fts_search_query = "AI"
# fts_search_query = "technology innovation"
# fts_search_query = "electricity cars"
# fts_search_query = "medical"

# Keyword leg: FTS5 matches for any query word, ordered by FTS5's `rank`
# (bm25 by default). It uses the same depth (top_k) as the vector leg so both
# lists feed RRF on equal terms.
fts_results = cur.execute('''  
  SELECT id FROM mango_fts WHERE mango_fts MATCH ? ORDER BY rank limit ?  
''', (or_words(fts_search_query), top_k)).fetchall()  

# Vector leg: k-nearest-neighbour search over the stored embeddings.
query_embedding = generate_embedding(fts_search_query)
# generate_embedding signals failure with None or an 'Error...' string.
assert isinstance(query_embedding, list), f'Query embedding failed: {query_embedding!r}'
vec_results = cur.execute('''  
  SELECT rowid, distance FROM mango_vec WHERE embedding MATCH ? and K = ?  
  ORDER BY distance  
''', [serialize_f32(query_embedding), top_k]).fetchall()  

# Combine results using RRF
combined_results = reciprocal_rank_fusion(fts_results, vec_results)

# Print combined results (`doc_id` rather than `id` avoids shadowing the builtin).
for doc_id, score in combined_results:
    print(f'ID: {doc_id}, Content: {lookup_row(doc_id)}, RRF Score: {score}')

ID: 8, Content: Healthcare innovation is critical for societal well-being., RRF Score: 0.01639344262295082
ID: 2, Content: Artificial intelligence is transforming the world., RRF Score: 0.016129032258064516
ID: 9, Content: Space exploration expands our understanding of the universe., RRF Score: 0.015873015873015872
ID: 1, Content: The quick brown fox jumps over the lazy dog., RRF Score: 0.015625
ID: 5, Content: Remote work has become more prevalent during the pandemic., RRF Score: 0.015384615384615385
ID: 10, Content: Cybersecurity threats are evolving and becoming more sophisticated., RRF Score: 0.015151515151515152
ID: 7, Content: Quantum computing has the potential to revolutionize technology., RRF Score: 0.014925373134328358
ID: 6, Content: Electric vehicles are becoming more popular., RRF Score: 0.014705882352941176
ID: 3, Content: Climate change is a pressing global issue., RRF Score: 0.014492753623188406
ID: 4, Content: The stock market fluctuates based on various factors., RRF 

In [None]:
# Close the connection. The database was opened as ":memory:", so every table
# created in this notebook is discarded once the connection is closed.
db.close()  
