In [22]:
import json
import os

from abc import ABC, abstractmethod

import numpy as np
import weaviate

from sklearn.metrics.pairwise import distance_metrics, pairwise_distances

from text2sql.engine.embeddings import SentenceTransformerEmbedder

In [14]:
# Shared sentence embedder: used both to embed the corpus lines and to embed queries.
sentence_transformer_embedder = SentenceTransformerEmbedder(model_path="sentence-transformers/LaBSE")



In [38]:
# Test with text from the Aeneid (public domain):
# https://classics.mit.edu/Virgil/aeneid.1.i.html
AENEID_TEXT_PATH = "aeneid_sample.txt"
AENEID_EMBEDDINGS_PATH = "aeneid_sample_embeddings.npy"

# One passage per non-blank line. Strip *before* filtering so whitespace-only
# lines are dropped instead of becoming empty passages.
with open(AENEID_TEXT_PATH, encoding="utf-8") as f:
    texts = [line.strip() for line in f]
texts = [t for t in texts if t]

# Embedding is slow, so cache it on disk. Recompute when the cache is missing or
# its row count no longer matches the text (this detects added/removed lines,
# not edited wording — delete the .npy file after editing the sample).
embeddings = None
if os.path.exists(AENEID_EMBEDDINGS_PATH):
    embeddings = np.load(AENEID_EMBEDDINGS_PATH)
    if len(embeddings) != len(texts):
        embeddings = None
if embeddings is None:
    embeddings = sentence_transformer_embedder.embed(texts, verbose=True)
    np.save(AENEID_EMBEDDINGS_PATH, embeddings)
assert len(embeddings) == len(texts)

In [16]:
class BaseRetriever(ABC):

    @abstractmethod
    def query():
        pass

In [39]:
class LocalRetriever(BaseRetriever):
    """In-memory nearest-neighbour retriever over a fixed set of embeddings.

    Row ``i`` of ``embeddings`` is paired with ``data[i]``; ``query`` returns the
    ``top_k`` items whose embeddings are closest to the query vector under
    ``distance_metric``.
    """

    def __init__(self, embeddings: list[list[float]] | np.ndarray, data: list[dict], distance_metric: str = "cosine"):
        """Store the corpus and validate the configuration.

        Args:
            embeddings: One vector per entry of ``data``, converted with ``np.array``
                (presumably shape ``(len(data), dim)`` — not checked here).
            data: Payloads returned alongside their distances by ``query``.
            distance_metric: A key of ``sklearn.metrics.pairwise.distance_metrics()``.

        Raises:
            ValueError: If ``embeddings`` and ``data`` differ in length, or the
                metric is not one of those names.
        """
        if len(embeddings) != len(data):
            raise ValueError("The number of embeddings must equal the number of data!")
        known_metrics = distance_metrics()
        if distance_metric not in known_metrics:
            raise ValueError(f"Unknown distance metric '{distance_metric}', must be one of {list(known_metrics.keys())}")
        self.distance_metric = distance_metric
        self.embeddings = np.array(embeddings)
        self.data = data

    def query(self, query_vector: list[float] | np.ndarray, top_k: int = 10) -> list[dict]:
        """Return the ``top_k`` stored items closest to ``query_vector``.

        Args:
            query_vector: A single vector. It is flattened to one row, so its total
                size must equal the embedding dimension.
            top_k: Maximum number of results. ``0`` yields an empty list; values
                larger than the corpus return every item.

        Returns:
            Dicts ``{"distance": float, "data": <item>}`` in ascending distance
            order; equidistant items keep their corpus order.

        Raises:
            ValueError: If ``top_k`` is negative.
        """
        # Python slicing would silently treat a negative top_k as "drop items from
        # the end" (e.g. -1 returns all but the farthest), so reject it explicitly.
        if top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        query_row = np.asarray(query_vector).reshape(1, -1)
        distances = pairwise_distances(query_row, self.embeddings, metric=self.distance_metric)[0]
        # Stable sort so ties are returned deterministically, in corpus order.
        nearest = np.argsort(distances, kind="stable")[:top_k]
        return [{"distance": float(distances[i]), "data": self.data[i]} for i in nearest]


In [40]:
# Pair each passage with its 1-based position in `texts` so hits can be traced back.
data = [{"line": position, "text": passage} for position, passage in enumerate(texts, start=1)]
aeneid_retriever = LocalRetriever(embeddings=embeddings, data=data)

In [45]:
# Query with a verse that also appears verbatim in the corpus, embedded by the same
# embedder instance used for the corpus (assuming the .npy cache came from it too).
query_text = "Before his eyes his goddess mother stood:"
query_vector = sentence_transformer_embedder.embed(query_text)
responses = aeneid_retriever.query(query_vector=query_vector, top_k=5)

In [46]:
# Pretty-print each hit (distance plus matching line record) as indented JSON.
for hit in responses:
    print(json.dumps(hit, indent=2))

{
  "distance": 0.0,
  "data": {
    "line": 434,
    "text": "Before his eyes his goddess mother stood:"
  }
}
{
  "distance": 0.38323378562927246,
  "data": {
    "line": 826,
    "text": "His mother goddess, with her hands divine,"
  }
}
{
  "distance": 0.48103129863739014,
  "data": {
    "line": 487,
    "text": "Of her unhappy lord: the specter stares,"
  }
}
{
  "distance": 0.49619847536087036,
  "data": {
    "line": 967,
    "text": "He walks Iulus in his mother's sight,"
  }
}
{
  "distance": 0.4964240789413452,
  "data": {
    "line": 919,
    "text": "Her mother Leda's present, when she came"
  }
}
