In [1]:
import streamlit as st
from neo4j import GraphDatabase, basic_auth
from llm import embeddings
import pickle
from tqdm import tqdm

In [2]:
# Neo4j connection settings are read from Streamlit's `st.secrets` rather than
# being hardcoded in the notebook.
URI, USERNAME, PASSWORD = (
    st.secrets[secret_name]
    for secret_name in ("NEO4J_URI", "NEO4J_USERNAME", "NEO4J_PASSWORD")
)

driver = GraphDatabase.driver(URI, auth=basic_auth(USERNAME, PASSWORD))
# Fail fast if the server is unreachable, then confirm the credentials are accepted;
# the final expression's value is the cell's displayed output.
driver.verify_connectivity()
driver.verify_authentication()

True

In [3]:
# Open one long-lived session that the embedding (read) cell and the write-back
# cell below both reuse via the global name `session`.
# NOTE(review): the session (and `driver`) are never closed in the visible
# notebook; call `session.close()` / `driver.close()` when finished, or open
# short-lived `with driver.session() as s:` blocks per cell instead.
session = driver.session()

In [4]:
# Fetch every tweet that has text, embed the text, and cache the records (with
# their embeddings) to tweets.pkl so the embedding step does not have to be
# repeated when the write-back cell is re-run.
records = session.execute_read(
    lambda tx: [dict(r) for r in tx.run("""
MATCH (n:Tweet)
WHERE n.text IS NOT NULL
RETURN n.id AS id, n.favorites AS favorites, n.text AS text
""")]
)
# Tweets without text are excluded in the query above: str(None) would
# otherwise embed the literal string "None" and attach a meaningless vector.
texts = [str(result["text"]) for result in records]
text_embeddings = embeddings.embed_documents(texts)
# zip() silently truncates on a length mismatch, which would leave some tweets
# without an embedding; fail loudly instead.
assert len(text_embeddings) == len(records), (
    f"expected {len(records)} embeddings, got {len(text_embeddings)}"
)
for result, embedding in zip(records, text_embeddings):
    result["text_embedding"] = embedding
with open("tweets.pkl", "wb") as f:
    pickle.dump(records, f)

In [5]:
# Reload the cached records and write each `text_embedding` back onto its Tweet
# node (matched by id), one write transaction per batch of `batch_size` rows.
# NOTE(review): pickle.load can execute arbitrary code; only load a tweets.pkl
# that this notebook produced itself.
with open("tweets.pkl", "rb") as f:
    records = pickle.load(f)

batch_size = 100
batches = [records[i:i + batch_size] for i in range(0, len(records), batch_size)]

# The query text is identical for every batch, so define it once outside the loop.
SET_EMBEDDINGS_QUERY = """
UNWIND $batch AS row
MATCH (n:Tweet {id: row.id})
SET n.text_embedding = row.text_embedding
"""


def _write_embeddings(tx, rows):
    """Set `text_embedding` on each Tweet whose id appears in `rows`.

    Returns the query's ResultSummary. The result is consumed here, inside the
    transaction function, because a Result must not escape the managed
    transaction that produced it.
    """
    return tx.run(SET_EMBEDDINGS_QUERY, batch=rows).consume()


for batch in tqdm(batches):
    # Pass the batch as an argument rather than capturing the loop variable in a closure.
    session.execute_write(_write_embeddings, batch)

100%|██████████| 25/25 [00:56<00:00,  2.24s/it]
