In [52]:
import json
import os
import uuid
from abc import ABC, abstractmethod

import numpy as np
import tqdm
import weaviate
import weaviate.classes as wvc

from sklearn.metrics.pairwise import distance_metrics, pairwise_distances
from weaviate.classes.config import Property, DataType
from weaviate.classes.query import MetadataQuery
from weaviate.util import generate_uuid5

from text2sql.engine.embeddings import SentenceTransformerEmbedder

In [2]:
# Single embedder instance, reused below for both the corpus lines and the query text.
sentence_transformer_embedder = SentenceTransformerEmbedder(model_path="sentence-transformers/LaBSE")



In [3]:
# Test corpus: lines of the Aeneid (public domain), one text per non-blank line.
# https://classics.mit.edu/Virgil/aeneid.1.i.html
AENEID_TEXT_PATH = "aeneid_sample.txt"
AENEID_EMBEDDINGS_PATH = "aeneid_sample_embeddings.npy"

# Explicit encoding: the sample contains non-ASCII punctuation (e.g. U+2019), which the
# platform default encoding may not decode correctly. Assumes the file is UTF-8 — TODO confirm.
with open(AENEID_TEXT_PATH, encoding="utf-8") as f:
    # Strip *before* filtering so whitespace-only lines are dropped instead of kept as "".
    texts = [stripped for stripped in (raw.strip() for raw in f.read().split("\n")) if stripped]

# Reuse cached embeddings only while they still line up with `texts`; a stale cache
# (e.g. after the text file changed) is recomputed rather than tripping the assert below.
embeddings = np.load(AENEID_EMBEDDINGS_PATH) if os.path.exists(AENEID_EMBEDDINGS_PATH) else None
if embeddings is None or len(embeddings) != len(texts):
    embeddings = sentence_transformer_embedder.embed(texts, verbose=True)
    np.save(AENEID_EMBEDDINGS_PATH, embeddings)
assert len(embeddings) == len(texts)

# "line" is the 1-based position among the non-blank lines kept above,
# not necessarily the raw line number in the file.
data = [{"line": line + 1, "text": text} for line, text in enumerate(texts)]

In [4]:
class BaseRetriever(ABC):

    @abstractmethod
    def query():
        pass

In [35]:
class LocalRetriever(BaseRetriever):
    """Brute-force, in-memory nearest-neighbour retriever.

    Every query computes the distance from the query vector to *all* stored embeddings
    with ``sklearn.metrics.pairwise.pairwise_distances`` and sorts them, which makes it a
    simple exact baseline to compare other retrievers against.
    """

    def __init__(self, embeddings: list[list[float]] | np.ndarray, data: list[dict], distance_metric: str = "cosine"):
        """Store the embeddings and their payload records.

        Args:
            embeddings: one vector per record; row ``i`` belongs to ``data[i]``.
            data: payload dicts returned with each hit (stored by reference, not copied).
            distance_metric: a key of ``sklearn.metrics.pairwise.distance_metrics()``.

        Raises:
            ValueError: if ``embeddings`` and ``data`` differ in length, or the metric is unknown.
        """
        if len(embeddings) != len(data):
            raise ValueError("The number of embeddings must equal the number of data!")
        if distance_metric not in distance_metrics():
            raise ValueError(f"Unknown distance metric '{distance_metric}', must be one of {list(distance_metrics().keys())}")
        self.distance_metric = distance_metric
        self.embeddings = np.array(embeddings)
        self.data = data

    def query(self, query_vector: list[float] | np.ndarray, top_k: int = 10) -> list[dict]:
        """Return the ``top_k`` stored records closest to ``query_vector``.

        Args:
            query_vector: a single vector; it is flattened to one row before comparison.
            top_k: maximum number of hits. ``0`` returns ``[]``; ``None`` returns every record.

        Returns:
            Dicts ``{"id": row index, "distance": float, "data": payload}`` sorted by
            ascending distance; equal distances keep corpus order.

        Raises:
            ValueError: if ``top_k`` is negative (slicing with a negative bound would
                otherwise silently return almost the whole corpus).
        """
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if top_k == 0 or len(self.data) == 0:
            return []
        query_row = np.asarray(query_vector).reshape(1, -1)
        distances = pairwise_distances(query_row, self.embeddings, metric=self.distance_metric)[0]
        # Stable sort: ties are returned in corpus order, so results are deterministic.
        nearest = np.argsort(distances, kind="stable")[:top_k]
        return [{"id": int(i), "distance": float(distances[i]), "data": self.data[i]} for i in nearest]


In [36]:
# In-memory baseline over the same corpus; "cosine" is LocalRetriever's default metric, spelled out here.
aeneid_local_retriever = LocalRetriever(embeddings, data, distance_metric="cosine")

In [37]:
# Query with a line that appears verbatim in the corpus, so the nearest hit should be that line itself.
query_text = "Before his eyes his goddess mother stood:"
query_vector = sentence_transformer_embedder.embed(query_text)


In [38]:
# Five nearest lines from the in-memory baseline, each pretty-printed as JSON.
local_responses = aeneid_local_retriever.query(query_vector, top_k=5)
for hit in local_responses:
    print(json.dumps(hit, indent=2))


{
  "id": 434,
  "distance": 0.0,
  "data": {
    "line": 435,
    "text": "Before his eyes his goddess mother stood:"
  }
}
{
  "id": 826,
  "distance": 0.38323378562927246,
  "data": {
    "line": 827,
    "text": "His mother goddess, with her hands divine,"
  }
}
{
  "id": 487,
  "distance": 0.47407424449920654,
  "data": {
    "line": 488,
    "text": "Of her unhappy lord: the spectre stares,"
  }
}
{
  "id": 919,
  "distance": 0.49701905250549316,
  "data": {
    "line": 920,
    "text": "Her mother Leda\u2019s present, when she came"
  }
}
{
  "id": 967,
  "distance": 0.5039515495300293,
  "data": {
    "line": 968,
    "text": "He walks Iulus in his mother\u2019s sight,"
  }
}


In [21]:
def weaviate_properties_from_dict(data_sample: dict) -> list[Property]:
    """Infer a Weaviate property schema from one sample record.

    Each key of ``data_sample`` becomes a ``Property`` whose ``DataType`` is chosen from
    the Python type of its value. Lists are typed by their first element, so they must be
    non-empty and are assumed homogeneous.

    Args:
        data_sample: a representative record (e.g. the first item of the data to import).

    Returns:
        One ``Property`` per key, in the dict's iteration order.

    Raises:
        ValueError: if a list is empty, or a value / list element has an unsupported type.
    """
    # Order matters: bool must be tested before int because bool is a subclass of int,
    # otherwise True/False would be typed as INT.
    scalar_types = [
        (bool, DataType.BOOL),
        (str, DataType.TEXT),
        (uuid.UUID, DataType.UUID),
        (int, DataType.INT),
        (float, DataType.NUMBER),
        (dict, DataType.OBJECT),
    ]
    array_types = [
        (bool, DataType.BOOL_ARRAY),
        (str, DataType.TEXT_ARRAY),
        (uuid.UUID, DataType.UUID_ARRAY),
        (int, DataType.INT_ARRAY),
        (float, DataType.NUMBER_ARRAY),
        (dict, DataType.OBJECT_ARRAY),
    ]

    def _lookup(value, table):
        # First matching Python type wins; None means "unsupported".
        for py_type, dtype in table:
            if isinstance(value, py_type):
                return dtype
        return None

    properties = []
    for key, value in data_sample.items():
        if isinstance(value, list):
            if not value:
                raise ValueError(f"Cannot infer the element type of an empty list for {key=}")
            prop_dtype = _lookup(value[0], array_types)
        else:
            prop_dtype = _lookup(value, scalar_types)
        # Fail loudly instead of reusing a dtype left over from a previous key.
        if prop_dtype is None:
            raise ValueError(f"Unknown type for {key=} and {value=}")
        properties.append(Property(name=key, data_type=prop_dtype))
    return properties

In [48]:
class WeaviateRetriever(BaseRetriever):
    """Retriever backed by a Weaviate collection that stores externally computed vectors.

    The collection is created with no vectorizer, so vectors are supplied explicitly both
    when populating (``populate_collection``) and when searching (``query``).
    The retriever owns an open client connection: call ``close()`` when finished, or use
    it as a context manager (``with WeaviateRetriever(...) as r: ...``).
    """

    def __init__(self, host: str, port: int, grpc_port: int, collection_name: str):
        self.collection_name = collection_name
        self.host = host
        self.port = port
        self.grpc_port = grpc_port
        self.client = self._get_weaviate_client(host, port, grpc_port)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False  # never suppress exceptions

    def close(self):
        """Close the underlying Weaviate client connection."""
        self.client.close()

    def _get_weaviate_client(self, host: str, port: int, grpc_port: int):
        """Connect to a local Weaviate instance and verify it is ready.

        Raises:
            Exception: if the server is not ready. The client is closed first so the
                connection is not leaked.
        """
        client: weaviate.WeaviateClient = weaviate.connect_to_local(
            host=host,
            port=port,
            grpc_port=grpc_port,
        )
        if not client.is_ready():
            client.close()
            raise Exception("weaviate client not ready")
        return client

    def _create_weaviate_collection(self, properties: list[Property]):
        """Create ``self.collection_name`` with the given properties and no vectorizer."""
        self.client.collections.create(
            self.collection_name,
            vectorizer_config=wvc.config.Configure.Vectorizer.none(),
            properties=properties,
        )
        return True

    def populate_collection(self, embeddings: list[list[float]] | np.ndarray, data: list[dict], delete_existing: bool = False, verbose: bool = True):
        """Add (embedding, record) pairs to the collection, creating it if needed.

        Object UUIDs are derived from each record via ``generate_uuid5``, so re-importing
        an identical record targets the same object id.

        Args:
            embeddings: one vector per record; ``embeddings[i]`` belongs to ``data[i]``.
            data: property dicts; the schema is inferred from ``data[0]`` on creation.
            delete_existing: drop the collection first and rebuild it from scratch.
            verbose: show a tqdm progress bar while queueing objects.

        Returns:
            The result of ``get_collection_info()`` after the import.

        Raises:
            ValueError: on mismatched lengths, or empty ``data`` when the collection must be created.
            Exception: if any object failed to import.
        """
        if len(embeddings) != len(data):
            raise ValueError("The number of embeddings must equal the number of data!")
        if delete_existing:
            self.client.collections.delete(self.collection_name)
        if not self.client.collections.exists(self.collection_name):
            if len(data) == 0:
                raise ValueError("Cannot infer collection properties from empty data!")
            properties = weaviate_properties_from_dict(data[0])
            self._create_weaviate_collection(properties)
        collection = self.client.collections.get(self.collection_name)
        with collection.batch.dynamic() as batch:
            for embedding, datum in tqdm.tqdm(zip(embeddings, data), total=len(data), disable=not verbose):
                batch.add_object(
                    uuid=generate_uuid5(datum),
                    properties=datum,
                    # plain Python floats rather than numpy scalars
                    vector=np.asarray(embedding).tolist(),
                )
        # Checked after the `with` block: pending objects are only flushed (and their
        # failures recorded) when the batch context exits.
        failed_objects = collection.batch.failed_objects
        if len(failed_objects) > 0:
            raise Exception(f"Failed to import {len(failed_objects)} objects")
        return self.get_collection_info()

    def get_collection_info(self):
        """Return ``{"collection_name", "properties", "count"}`` for the collection.

        Raises:
            ValueError: if the collection does not exist yet.
        """
        if not self.client.collections.exists(self.collection_name):
            raise ValueError(f"Collection '{self.collection_name}' does not exist! please do populate_collection() first!")
        collection = self.client.collections.get(self.collection_name)
        # NOTE(review): this is `config.get().to_dict()`, i.e. the whole collection
        # config, not only its property list.
        properties = collection.config.get().to_dict()
        count = collection.aggregate.over_all(total_count=True).total_count
        return {
            "collection_name": self.collection_name,
            "properties": properties,
            "count": count,
        }

    def query(self, query_vector: list[float] | np.ndarray, top_k: int = 10) -> list[dict]:
        """Return up to ``top_k`` objects nearest to ``query_vector``.

        Args:
            query_vector: a single vector; it is flattened and sent as a list of floats.
            top_k: maximum number of hits; ``0`` returns ``[]`` without querying.

        Returns:
            Dicts ``{"id": object uuid str, "distance": float, "data": properties dict}``
            in the order Weaviate returns them.

        Raises:
            ValueError: if ``top_k`` is negative.
        """
        if top_k is not None and top_k < 0:
            raise ValueError(f"top_k must be non-negative, got {top_k}")
        if top_k == 0:
            return []
        collection = self.client.collections.get(self.collection_name)
        response = collection.query.near_vector(
            near_vector=np.asarray(query_vector).ravel().tolist(),
            limit=top_k,
            return_metadata=MetadataQuery(distance=True),
        )
        return [
            {
                "id": str(obj.uuid),
                "distance": float(obj.metadata.distance),
                "data": dict(obj.properties),
            }
            for obj in response.objects
        ]


In [49]:
# Connection settings for the local Weaviate instance.
weaviate_host = "localhost"
weaviate_port = 8081  # HTTP port
weaviate_grpc_port = 50051  # gRPC port (previously misnamed `weaviate_gpu_port`)

aeneid_weaviate_retriever = WeaviateRetriever(
    host=weaviate_host,
    port=weaviate_port,
    grpc_port=weaviate_grpc_port,
    collection_name="AeneidLabse",
)

In [50]:
# Rebuild the collection from scratch and upload every (embedding, record) pair.
info = aeneid_weaviate_retriever.populate_collection(
    embeddings,
    data,
    delete_existing=True,
    verbose=True,
)

100%|██████████| 1066/1066 [00:00<00:00, 4509.84it/s]


In [53]:
# Same query against Weaviate; hits should match the in-memory baseline above.
weaviate_responses = aeneid_weaviate_retriever.query(query_vector, top_k=5)
for hit in weaviate_responses:
    print(json.dumps(hit, indent=2))


{
  "id": "1ecd05b6-22f0-5a8a-bd06-d3906706d060",
  "distance": -2.384185791015625e-07,
  "data": {
    "text": "Before his eyes his goddess mother stood:",
    "line": 435
  }
}
{
  "id": "138b2937-9715-5573-beca-f2e6e6b36d34",
  "distance": 0.38323378562927246,
  "data": {
    "text": "His mother goddess, with her hands divine,",
    "line": 827
  }
}
{
  "id": "630e9061-18e4-5c36-a93b-5258f9541043",
  "distance": 0.47407418489456177,
  "data": {
    "text": "Of her unhappy lord: the spectre stares,",
    "line": 488
  }
}
{
  "id": "c68e098b-e587-52b7-8302-1dfcec7d2bc4",
  "distance": 0.4970189332962036,
  "data": {
    "text": "Her mother Leda\u2019s present, when she came",
    "line": 920
  }
}
{
  "id": "9639d4a1-576e-58b9-89dc-c30f7a94a4d1",
  "distance": 0.503951370716095,
  "data": {
    "text": "He walks Iulus in his mother\u2019s sight,",
    "line": 968
  }
}
