In [16]:
import os 
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from neo4j import GraphDatabase
import re
from typing import List
import time
import polars as pl
from util import chunk_text

def remove_html_tags(text):
    """Strip HTML/XML-style tags from ``text`` and return what is left.

    Every span from a ``<`` up to the next ``>`` is deleted, including tags
    whose attributes wrap onto several lines. Character entities such as
    ``&amp;`` are not decoded.

    Args:
        text: The markup string to clean.

    Returns:
        ``text`` with every ``<...>`` span removed.
    """
    # re.DOTALL lets "." cross newlines; without it a tag that is broken over
    # two lines would be left in the output untouched.
    return re.sub(r'<.*?>', '', text, flags=re.DOTALL)

# Load variables from a local .env file first so the connection settings below
# can be overridden without editing the notebook.
load_dotenv()

# Neo4j connection settings. Each value may be supplied through the environment
# (NEO4J_URI / NEO4J_USER / NEO4J_PASSWORD); the fallbacks keep the notebook
# working unchanged when none of them are defined.
# NOTE(review): the fallback password is still committed to source -- move it
# into .env and drop the default once every environment sets NEO4J_PASSWORD.
URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
AUTH = (
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "fairusecases"),
)

    
class Custom_Embeddings:
    """Thin wrapper around the ``all-MiniLM-L6-v2`` sentence-transformer.

    Exposes ``embed_documents`` for a batch of texts and ``embed_query`` for a
    single text. Both return plain Python lists of floats, as annotated.
    """

    def __init__(self):
        """Load the model, authenticating with the ``HUGGING_FACE`` env var.

        Raises:
            KeyError: If ``HUGGING_FACE`` is not set in the environment.
        """
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2', token = os.environ["HUGGING_FACE"])

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts in one batch.

        Args:
            texts: The strings to embed.

        Returns:
            One embedding vector per input text, in input order. An empty
            input yields an empty list without calling the model.
        """
        texts = list(texts)
        if not texts:
            return []
        # Encode the whole batch at once: each row of the result is one full
        # vector. Indexing ``encode(single_string)[0]`` would instead pick out
        # only the first component of that string's vector.
        return [vector.tolist() for vector in self.model.encode(texts)]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single text and return its vector as a list of floats."""
        # Wrap the query in a list so the model returns a batch of one row.
        return self.model.encode([query])[0].tolist()

In [17]:
def get_all_parts(tx):
    """Fetch every document attached to a case opinion.

    Args:
        tx: Transaction object exposing ``run``; supplied by the session when
            this function is passed to ``execute_read``.

    Returns:
        The query result converted with ``to_df()``, with columns
        ``WestLawCaseName``, ``Document`` and ``Label`` (the node's label list).
    """
    query = """
        MATCH (c:Case)-[:HAS_OPINION]-(o:Opinion)-[:OF]-(f)
        RETURN c.WestLawCaseName AS WestLawCaseName, f.Document AS Document, Labels(f) AS Label
        """
    records = tx.run(query)
    return records.to_df()

# Open a short-lived connection, run the read query in a managed transaction,
# and keep the resulting frame in `df` for the following cells.
with GraphDatabase.driver(URI, auth=AUTH) as driver, driver.session(database="neo4j") as session:
    df = session.execute_read(get_all_parts)

In [18]:
# Convert the pandas frame returned by the driver into polars, and load the
# MiniLM model once so later cells can reuse it.
df = pl.from_pandas(df)
embedder = Custom_Embeddings()

df = df.with_columns(
    # Keep only the first label of each node. The native list expression avoids
    # a Python call per row and yields null for an empty label list.
    pl.col("Label").list.first(),
    # Strip markup, then split each document into chunks with util.chunk_text.
    pl.col("Document")
    .map_elements(lambda doc: chunk_text(remove_html_tags(doc)))
    .alias("Chunks"),
).explode("Chunks")  # one output row per chunk

In [None]:
# Chunks must be longer than this many characters to be kept and embedded.
MIN_CHUNK_CHARS = 25

# Filter on length BEFORE embedding so the slow model only runs on chunks that
# are actually kept. The surviving rows and their embeddings are the same as
# when embedding first; only the column order (CLen before Embeddings) differs.
df = (
    df.with_columns(pl.col("Chunks").map_elements(len).alias("CLen"))
    # map_elements skips nulls by default, so a null chunk gets a null CLen and
    # is dropped by this comparison as well.
    .filter(pl.col("CLen") > MIN_CHUNK_CHARS)
    .with_columns(
        pl.col("Chunks").map_elements(embedder.embed_query).alias("Embeddings")
    )
)

In [5]:
# Keep only the rows that actually have chunk text.
has_no_chunk = pl.col("Chunks").is_null()
df = df.filter(~has_no_chunk)

In [6]:
def set_mini_embeddings(tx, ob):
    """Create (or update) the Chunk node for one row and link it to its source.

    The source node ``f`` is located through its case name and full document
    text. The Chunk is merged on its text alone and the embedding is then set
    as a property, so re-running the load updates the existing node instead of
    creating a second one when the stored vector differs.

    Args:
        tx: Transaction object exposing ``run``; supplied by the session when
            this function is passed to ``execute_write``.
        ob: Mapping with the keys ``WestLawCaseName``, ``Document``,
            ``Chunks`` (the chunk text) and ``Embeddings`` (its vector).

    Returns:
        The summary obtained from ``result.consume()``.
    """
    result = tx.run("""
        MATCH (c:Case {WestLawCaseName: $WestLawCaseName})-[:HAS_OPINION]-(o:Opinion)-[:OF]-(f {Document: $Document})
        MERGE (ch:Chunk {Document: $Chunks})
        SET ch.MiniEmbedding = $MiniEmbedding
        MERGE (f)-[:FROM]-(ch)
        """,
        WestLawCaseName=ob["WestLawCaseName"],
        Document=ob["Document"],
        Chunks=ob["Chunks"],
        MiniEmbedding=ob["Embeddings"],
    )

    # Consume inside the transaction function: the raw Result cannot be read
    # once the managed transaction has closed, the summary can.
    return result.consume()

# Stream the frame row by row (each row as a dict) and persist every chunk
# through its own execute_write call.
obs = df.iter_rows(named=True)

with GraphDatabase.driver(URI, auth=AUTH) as driver, driver.session(database="neo4j") as session:
    for ob in obs:
        session.execute_write(set_mini_embeddings, ob)

In [75]:
def get_unembedded_gemini(tx):
    """Fetch the text of every Chunk that has no Gemini embedding yet.

    Args:
        tx: Transaction object exposing ``run``; supplied by the session when
            this function is passed to ``execute_read``.

    Returns:
        The query result converted with ``to_df()``, with a single ``Chunks``
        column holding each node's ``Document`` text.
    """
    query = """
        MATCH (n:Chunk)
        WHERE n.GeminiEmbeddings IS NULL
        RETURN n.Document AS Chunks
        """
    pending = tx.run(query)
    return pending.to_df()

# Read the chunks still missing a Gemini embedding into `df`, using a
# connection that is closed again as soon as the query has run.
with GraphDatabase.driver(URI, auth=AUTH) as driver, driver.session(database="neo4j") as session:
    df = session.execute_read(get_unembedded_gemini)

In [None]:
# Convert the pandas result into polars and display its summary statistics.
df = pl.from_pandas(df)
# describe must be called: the bare attribute only echoes the bound method.
df.describe()

In [78]:
from google import genai
import time

# Gemini client used as the default by embed_document; the API key is read
# from the GEMINI_API environment variable (KeyError when it is missing).
client = genai.Client(
    api_key=os.environ["GEMINI_API"],
)

def embed_document(text, client = client, model = "text-embedding-004", delay = .5):
    """Embed ``text`` with the Gemini embedding API and return its vector.

    Args:
        text: Content to embed; passed as ``contents`` to ``embed_content``.
        client: Client whose ``models.embed_content`` is called. Defaults to
            the notebook-level ``client``.
        model: Name of the embedding model to request.
        delay: Seconds to pause after the request (presumably to stay under
            the API rate limit -- confirm). Zero or negative disables it.

    Returns:
        The ``values`` of the first embedding in the response.
    """
    result = client.models.embed_content(model=model, contents=text)

    if delay > 0:
        time.sleep(delay)

    # Read the vector through its named field rather than by position in the
    # model's (name, value) iteration order, which could change silently.
    return result.embeddings[0].values

In [79]:
def set_gemini_embeddings(tx, ob, embedding):
    """Store a Gemini embedding on the Chunk node(s) with the given text.

    Args:
        tx: Transaction object exposing ``run``; supplied by the session when
            this function is passed to ``execute_write``.
        ob: Mapping whose ``Chunks`` entry is the chunk text used to find the
            node through its ``Document`` property.
        embedding: The vector to write to ``GeminiEmbeddings``.

    Returns:
        The summary obtained from ``result.consume()``.
    """
    result = tx.run("""
        MATCH (ch:Chunk {Document: $Chunks})
        SET ch.GeminiEmbeddings = $embedding
        """, Chunks = ob["Chunks"], embedding = embedding
    )

    # Consume inside the transaction function: the raw Result cannot be read
    # once the managed transaction has closed, the summary can.
    return result.consume()

In [80]:
# Walk the un-embedded chunks one at a time: request a Gemini embedding for the
# chunk text, then store it on the matching Chunk node.
obs = df.iter_rows(named=True)

with GraphDatabase.driver(URI, auth=AUTH) as driver, driver.session(database="neo4j") as session:
    for ob in obs:
        embedding = embed_document(ob["Chunks"])
        session.execute_write(set_gemini_embeddings, ob, embedding)

In [None]:
def create_mini_vector_index(tx):
    """Create the cosine vector index over ``Chunk.MiniEmbedding`` if absent.

    The index is named ``MiniEmbeddingIndex`` and is declared with 384
    dimensions.

    Args:
        tx: Transaction object exposing ``run``.

    Returns:
        None.
    """
    statement = """
        CREATE VECTOR INDEX MiniEmbeddingIndex IF NOT EXISTS
        FOR (c:Chunk)
        ON c.MiniEmbedding
        OPTIONS { indexConfig: {
        `vector.dimensions`: 384,
        `vector.similarity_function`: 'cosine'
        }}
        """
    tx.run(statement)

# Create the MiniLM vector index through a short-lived connection.
with GraphDatabase.driver(URI, auth=AUTH) as driver, driver.session(database="neo4j") as session:
    session.execute_write(create_mini_vector_index)

In [None]:
def create_gemini_vector_index(tx):
    """Create the cosine vector index over ``Chunk.GeminiEmbeddings`` if absent.

    The index is named ``GeminiEmbeddingIndex`` and is declared with 768
    dimensions.

    Args:
        tx: Transaction object exposing ``run``.

    Returns:
        None.
    """
    statement = """
        CREATE VECTOR INDEX GeminiEmbeddingIndex IF NOT EXISTS
        FOR (c:Chunk)
        ON c.GeminiEmbeddings
        OPTIONS { indexConfig: {
        `vector.dimensions`: 768,
        `vector.similarity_function`: 'cosine'
        }}
        """
    tx.run(statement)

# Create the Gemini vector index through a short-lived connection.
with GraphDatabase.driver(URI, auth=AUTH) as driver, driver.session(database="neo4j") as session:
    session.execute_write(create_gemini_vector_index)