# In [None]:
import ast
import os
import sqlite3
import time

import numpy as np
import requests
from openai import OpenAI
from openai import APITimeoutError
from requests.exceptions import Timeout
from tqdm.notebook import tqdm

# Load the OpenAI API key from a local config.py module.
try:
    from config import OPENAI_API_KEY
except ImportError as err:
    # Only OPENAI_API_KEY is read by this script; no org ID is needed.
    raise ImportError(
        "Please create a config.py file with your OPENAI_API_KEY"
    ) from err

print("Setting up OpenAI client...")
client = OpenAI(api_key=OPENAI_API_KEY)

print("Initializing database...")
conn = sqlite3.connect('p2025_db.sqlite')
cursor = conn.cursor()

print("Creating table if not exists...")
# NOTE: despite the BLOB declared type, add_chunk writes `embedding` as a
# comma-separated string of floats and `shape` as str(tuple); SQLite's
# dynamic typing stores those values as TEXT.
cursor.execute('''
CREATE TABLE IF NOT EXISTS document_chunks
(id INTEGER PRIMARY KEY, content TEXT, embedding BLOB, shape TEXT)
''')

def encode_text(text, max_retries=10, backoff_factor=2, timeout=30):
    """Embed ``text`` with the OpenAI embeddings API, retrying on failure.

    Uses the module-level ``client`` and the ``text-embedding-ada-002``
    model. A failed attempt ``k`` (0-based) is followed by a wait of
    ``backoff_factor * 2**k`` seconds before the next attempt; no wait is
    performed after the final attempt.

    Args:
        text: The string to embed.
        max_retries: Total number of attempts (including the first one).
        backoff_factor: Base wait, in seconds, for the exponential backoff.
        timeout: Per-request timeout in seconds, passed to the API call.

    Returns:
        ``(embedding, shape)``: the embedding as a NumPy array and its
        ``.shape`` tuple.

    Raises:
        Exception: If every attempt fails. The last underlying error is
            chained as ``__cause__``.
    """
    print(f"Starting to encode text of length {len(text)}")
    last_error = None
    for attempt in range(max_retries):
        try:
            print(f"Attempt {attempt + 1} to encode text")
            response = client.embeddings.create(
                model="text-embedding-ada-002",
                input=[text],
                timeout=timeout,
            )
            embedding = np.array(response.data[0].embedding)
            print("Successfully encoded text")
            return embedding, embedding.shape
        except (Timeout, APITimeoutError) as e:
            # The OpenAI v1 client is httpx-based and reports timeouts as
            # openai.APITimeoutError; requests' Timeout alone would miss them.
            last_error = e
            reason = "Request timed out."
        except Exception as e:
            # NOTE(review): this also retries errors unlikely to be transient
            # (e.g. bad credentials); consider failing fast on those.
            last_error = e
            reason = f"Error occurred: {e}."

        if attempt + 1 == max_retries:
            # Final attempt failed: give up now instead of sleeping first.
            print(reason)
            break
        wait_time = backoff_factor * (2 ** attempt)
        print(f"{reason} Retrying in {wait_time} seconds...")
        time.sleep(wait_time)

    print("Failed to encode text after all attempts")
    raise Exception("Failed to encode text after all attempts") from last_error

def add_chunk(content, embedding, shape):
    """Insert a single chunk row into ``document_chunks`` and commit.

    Args:
        content: The chunk text.
        embedding: Iterable of numbers; each element is stored via ``str()``
            and the results are joined with commas.
        shape: The embedding's shape, stored as ``str(shape)``.

    Uses the module-level ``cursor`` and ``conn``; commits after every row.
    """
    serialized_embedding = ','.join(str(value) for value in embedding)
    row = (content, serialized_embedding, str(shape))
    cursor.execute(
        'INSERT INTO document_chunks (content, embedding, shape) VALUES (?, ?, ?)',
        row,
    )
    conn.commit()

def read_and_chunk_file(file_path, chunk_size=3500, overlap=500):
    """Read a UTF-8 text file and split it into overlapping chunks.

    Each chunk is at most ``chunk_size`` characters. Every chunk except the
    last is shortened to end just after the last '.' in its window (or,
    failing that, the last newline), so chunks tend to end on sentence or
    paragraph boundaries. Each next chunk starts ``overlap`` characters
    before the previous chunk's end.

    Args:
        file_path: Path of the text file to read.
        chunk_size: Maximum characters per chunk; must be > 0.
        overlap: Characters shared between consecutive chunks; must satisfy
            ``0 <= overlap < chunk_size`` so the window always advances.

    Returns:
        List of chunk strings in file order (empty for an empty file).

    Raises:
        ValueError: If ``chunk_size``/``overlap`` are out of range.
        OSError, UnicodeDecodeError: If the file cannot be read or decoded.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must satisfy 0 <= overlap < chunk_size ({chunk_size}), "
            f"got {overlap}"
        )

    print(f"Reading file: {file_path}")
    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read()

    print(f"Chunking file (chunk size: {chunk_size}, overlap: {overlap})")
    total = len(content)
    # A break point must sit at or beyond `overlap` inside the window (and
    # never at index 0) so the next start, end - overlap, is strictly past
    # the current start. Accepting an earlier break froze the window and
    # looped forever.
    min_break = max(overlap, 1)
    chunks = []
    start = 0
    with tqdm(total=total, desc="Chunking progress") as pbar:
        while start < total:
            end = start + chunk_size
            if end >= total:
                # Last window: keep the remainder and stop; any further
                # window would lie entirely inside this chunk.
                chunks.append(content[start:])
                pbar.update(total - start)
                break

            window = content[start:end]
            # Prefer ending after a '.', else after a newline, else cut hard.
            break_at = window.rfind('.', min_break)
            if break_at == -1:
                break_at = window.rfind('\n', min_break)
            if break_at != -1:
                end = start + break_at + 1

            chunks.append(content[start:end])
            next_start = end - overlap
            # Advance the bar by the distance actually moved through the text.
            pbar.update(next_start - start)
            start = next_start

    print(f"Created {len(chunks)} chunks")
    return chunks

# Chunk the source document, embed each chunk, and store it in SQLite.
# NOTE: rows are only ever INSERTed, so re-running this against an existing
# p2025_db.sqlite appends a second copy of every chunk.
print("Reading and chunking file...")
chunks = read_and_chunk_file('p2025.txt')

print("Processing chunks and adding to database...")
for text_chunk in tqdm(chunks, desc="Processing chunks"):
    print(f"Encoding chunk (length: {len(text_chunk)})")
    chunk_embedding, chunk_shape = encode_text(text_chunk)
    print(f"Adding chunk to database (embedding shape: {chunk_shape})")
    add_chunk(text_chunk, chunk_embedding, chunk_shape)

print(f"Added {len(chunks)} chunks to the database.")

def retrieve_chunks(query, top_k=5):
    """Return the stored chunks most similar to ``query``, best match first.

    Embeds ``query`` with ``encode_text`` and scores every row of
    ``document_chunks`` by cosine similarity (a full table scan). Rows whose
    stored shape differs from the query embedding's shape, or whose vector
    has zero norm, are skipped.

    Args:
        query: Text to search for.
        top_k: Maximum number of chunks to return.

    Returns:
        A list of 1-tuples ``(content,)`` (the same row shape as a
        ``fetchall()`` of one column), ordered by descending similarity.
        Empty if no comparable embeddings are stored.
    """
    print(f"Retrieving chunks for query: '{query}'")
    query_embedding, query_shape = encode_text(query)
    print(f"Query embedding shape: {query_shape}")
    # Loop-invariant: the query norm is the same for every comparison.
    query_norm = np.linalg.norm(query_embedding)

    cursor.execute('SELECT id, embedding, shape FROM document_chunks')
    results = cursor.fetchall()

    print(f"Comparing query to {len(results)} stored chunks")
    similarities = []
    for chunk_id, emb, shape in tqdm(results, desc="Comparing embeddings"):
        # `shape` is str(tuple) text read from the DB; literal_eval parses it
        # without eval()'s ability to execute arbitrary stored code.
        stored_shape = ast.literal_eval(shape)
        emb_array = np.array([float(x) for x in emb.split(',')]).reshape(stored_shape)

        if emb_array.shape != query_shape:
            print(f"Warning: Embedding shape mismatch. Query: {query_shape}, Stored: {emb_array.shape}")
            continue

        denominator = query_norm * np.linalg.norm(emb_array)
        if denominator == 0:
            # Cosine similarity is undefined; a NaN would corrupt the ranking.
            continue
        similarity = np.dot(query_embedding, emb_array) / denominator
        similarities.append((chunk_id, similarity))

    if not similarities:
        print("No valid embeddings found for comparison.")
        return []

    top = sorted(similarities, key=lambda item: item[1], reverse=True)[:top_k]
    top_ids = [chunk_id for chunk_id, _ in top]
    if not top_ids:
        return []

    placeholders = ','.join('?' for _ in top_ids)
    cursor.execute(
        f'SELECT id, content FROM document_chunks WHERE id IN ({placeholders})',
        top_ids,
    )
    content_by_id = dict(cursor.fetchall())
    # SQL `IN (...)` does not return rows in the order of the id list, so
    # restore the similarity ranking explicitly.
    return [(content_by_id[chunk_id],) for chunk_id in top_ids]

# Run one sample query end to end, preview the hits, then release the DB.
print("Testing retrieval...")
test_query = "What does Project 2025 say about BLM's move west?"
relevant_chunks = retrieve_chunks(test_query)

print("\nRelevant chunks for the query:")
for rank, row in enumerate(relevant_chunks, start=1):
    preview = row[0][:200]  # only the first 200 characters of the chunk
    print(f"Chunk {rank}:")
    print(preview + "...")
    print()

print("Closing database connection...")
conn.close()
print("Done!")