# Hybrid Search with Weaviate

**Imports**

In [None]:
import pandas as pd
from pandarallel import pandarallel
import os
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

import weaviate
from weaviate.classes.config import Property, DataType
import weaviate.classes as wvc
import weaviate.classes.config as wc

# Notebook-wide pandas options: turn off the chained-assignment check and
# raise the display limits for rows and sequence items to 500.
pd.set_option("mode.chained_assignment", None)
pd.set_option("display.max_rows", 500)
pd.set_option("display.max_seq_items", 500)

# Initialize pandarallel with its progress bar switched on.
pandarallel.initialize(progress_bar=True)

# Suppress the Hugging Face warning about tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

**Constants**

In [None]:
# Run dotenv loading first so the lookups below can see its values.
load_dotenv()

# Paths of the preprocessing outputs; each is None when the variable is unset.
PREP_OUTPUT_KRP = os.environ.get("PREP_OUTPUT_KRP")
PREP_OUTPUT_RRB = os.environ.get("PREP_OUTPUT_RRB")
PREP_OUTPUT_GSZH = os.environ.get("PREP_OUTPUT_GSZH")

# Paths of the combined data outputs; each is None when the variable is unset.
DATA_OUTPUT_FULL = os.environ.get("DATA_OUTPUT_FULL")
DATA_OUTPUT_CHUNKS = os.environ.get("DATA_OUTPUT_CHUNKS")
DATA_EMBEDDINGS = os.environ.get("DATA_EMBEDDINGS")

# Load data

In [None]:
# Read the parquet file referenced by DATA_EMBEDDINGS.
df = pd.read_parquet(DATA_EMBEDDINGS)

# Parse the date column and localize it to UTC.
df["date"] = pd.to_datetime(df["date"]).dt.tz_localize("UTC")

# Discard the stored word_count and year columns, then rebuild year from the
# parsed date so that it is re-added as the last column.
df.drop(columns=["word_count", "year"], inplace=True)
df["year"] = df["date"].dt.year

In [None]:
# Print a summary of the frame including deep memory usage.
# DataFrame.info() writes its report itself and returns None, so it is called
# directly instead of being wrapped in display(), which would additionally
# render a stray "None".
df.info(memory_usage="deep")

# Show ten randomly sampled rows, transposed so each sampled row is a column.
df.sample(10).T

# Weaviate

In [None]:
# Connect via the client's embedded mode; `client` is reused by every cell below.
client = weaviate.connect_to_embedded()
# Use this code line if Weaviate is already running, e.g. from the Streamlit app.
# client = weaviate.connect_to_local(port=8079, grpc_port=50050)

In [None]:
# Show what Weaviate's meta endpoint reports.
meta = client.get_meta()
display(meta)

# Print Weaviate's live state first, then its ready state.
for probe in (client.is_live, client.is_ready):
    print(probe())

In [None]:
# Schema of the "stazh" collection: property name -> Weaviate data type.
# The insertion order of this mapping is the order the properties are declared in.
property_types = {
    "identifier": DataType.TEXT,
    "date": DataType.DATE,
    "year": DataType.INT,
    "title": DataType.TEXT,
    "link": DataType.TEXT,
    "stazh_ident": DataType.TEXT,
    "series": DataType.TEXT,
    "chunk_text": DataType.TEXT,
    "ref": DataType.TEXT,
}

# Create the collection without a vectorizer (the ingest cell passes vectors
# explicitly) and with BM25 parameters b=0.75 and k1=1.2.
# The inverted index config can additionally take stopwords_additions,
# stopwords_preset and stopwords_removals; they are left unset here.
client.collections.create(
    "stazh",
    vectorizer_config=wc.Configure.Vectorizer.none(),
    inverted_index_config=wc.Configure.inverted_index(bm25_b=0.75, bm25_k1=1.2),
    properties=[
        Property(name=name, data_type=data_type)
        for name, data_type in property_types.items()
    ],
)

In [None]:
# Print the name of every collection known to the client.
collections_by_name = client.collections.list_all()
for collection_config in collections_by_name.values():
    print(collection_config.name)

# Fetch the non-simplified description of all collections and print it.
schema = client.collections.list_all(simple=False)
print(schema)

In [None]:
# Uncomment the next line to delete the collection.
# client.collections.delete("stazh")

In [None]:
# Ingest every row of `df` into the "stazh" collection.
collection = client.collections.get("stazh")

# Columns copied one-to-one into the object properties. The "embeddings"
# column is not a property; it is passed separately as the object's vector.
property_names = (
    "identifier",
    "date",
    "year",
    "title",
    "link",
    "stazh_ident",
    "series",
    "chunk_text",
    "ref",
)

with collection.batch.dynamic() as batch:
    for data in df.to_dict(orient="records"):
        batch.add_object(
            properties={name: data[name] for name in property_names},
            # Convert the embedding to a plain Python list before sending it.
            vector=data["embeddings"].tolist(),
        )

# The batch records failed objects on the batch wrapper. Report them here so a
# partial import does not go unnoticed.
failed_objects = collection.batch.failed_objects
if failed_objects:
    print(f"{len(failed_objects)} objects failed to import. First failure:")
    print(failed_objects[0])

In [None]:
# Print only the first object yielded by the collection iterator, if there is one.
collection = client.collections.get("stazh")
first_item = next(iter(collection.iterator()), None)
if first_item is not None:
    print(first_item)

In [None]:
# Get total count of all items in the collection.
collection = client.collections.get("stazh")
# Aggregate over the whole collection, requesting only the total count.
response = collection.aggregate.over_all(total_count=True)

print(response.total_count)

# Lexical search

In [None]:
# Metadata requested alongside each BM25 hit.
bm25_metadata = wvc.query.MetadataQuery(score=True, distance=True, certainty=True)

# Run a BM25 keyword search for the term "Steuerreform".
collection = client.collections.get("stazh")
response = collection.query.bm25(
    query="Steuerreform",
    # Optional arguments that can be enabled here:
    #   query_properties=["title"]  -> define which fields to search over
    #   filters=wvc.query.Filter.by_property("year").equal(2012)
    #   filters=wvc.query.Filter.by_property("year").less_than(2012)
    return_metadata=bm25_metadata,
    offset=0,
    limit=100,
    auto_limit=4,
)

# Keep title and series of the first hit per identifier; any later hit whose
# identifier was already collected is skipped.
seen = []
final_results = []

for hit in response.objects:
    props = hit.properties
    if props["identifier"] not in seen:
        final_results += [props["title"], props["series"]]
        seen.append(props["identifier"])

# One value per line: title, then series, for each kept hit.
for line in final_results:
    print(line)

# Vector search

In [None]:
# Name of the embedding model handed to SentenceTransformer.
model_path = "jinaai/jina-embeddings-v2-base-de"
model = SentenceTransformer(
    model_path,
    # NOTE(review): presumably allows code shipped with the model repository to
    # run locally; confirm the model source is trusted.
    trust_remote_code=True,
    # NOTE(review): "mps" looks like the Apple GPU backend; this will likely
    # fail on machines without it. TODO confirm.
    device="mps",
)
# Set the model's maximum sequence length to 512.
model.max_seq_length = 512


def embed_query(query):
    """Encode ``query`` with the module-level ``model``.

    Returns whatever ``model.encode`` produces when tensor conversion is
    turned off and embedding normalization is turned on.
    """
    encode_options = {"convert_to_tensor": False, "normalize_embeddings": True}
    return model.encode(query, **encode_options)

In [None]:
# Pure vector search: embed the query text and look up its nearest neighbours.
query = "Steuerreform"
query_embedding = embed_query(query)

collection = client.collections.get("stazh")
# No target_vector is passed. The collection is created without named vectors
# and objects are ingested with a single default vector, so naming a target
# vector would point the search at a vector that does not exist.
response = collection.query.near_vector(
    near_vector=list(query_embedding),
    limit=20,
    auto_limit=3,
    return_metadata=wvc.query.MetadataQuery(distance=True),
)

# Keep title and series of the first hit per identifier. A set makes the
# "already seen" check constant-time.
seen = set()
final_results = []

for item in response.objects:
    identifier = item.properties["identifier"]
    if identifier in seen:
        continue
    seen.add(identifier)
    final_results.extend((item.properties["title"], item.properties["series"]))

# One value per line: title, then series, for each kept hit.
for elem in final_results:
    print(elem)

# Hybrid search

In [None]:
# Hybrid search: the same query is sent both as text and as an embedding.
query = "Steuerreform"
query_embedding = embed_query(query)

# Restrict hits to objects whose year lies between 1803 and 1995 (both inclusive).
from_year = wvc.query.Filter.by_property("year").greater_or_equal(1803)
to_year = wvc.query.Filter.by_property("year").less_or_equal(1995)

collection = client.collections.get("stazh")
response = collection.query.hybrid(
    query=query,
    vector=list(query_embedding),
    alpha=0.7,
    fusion_type=wvc.query.HybridFusion.RELATIVE_SCORE,
    filters=from_year & to_year,
    limit=5,
    auto_limit=2,
)

# Collect title and series once per identifier, skipping repeated identifiers.
seen = []
final_results = []

for hit in response.objects:
    fields = hit.properties
    is_new = fields["identifier"] not in seen
    if is_new:
        final_results.append(fields["title"])
        final_results.append(fields["series"])
        seen.append(fields["identifier"])

# One value per line: title, then series, for each kept hit.
for entry in final_results:
    print(entry)

# Search by document

In [None]:
# Find documents similar to one known document, identified by its stazh_ident.
ident = "StAZH ABl 1987 (S. 1079)"

collection = client.collections.get("stazh")
# Look up the stored objects for this document. The filter has to compare
# against `ident`; no variable named `year` exists in this notebook.
response = collection.query.fetch_objects(
    filters=wvc.query.Filter.by_property("stazh_ident").equal(ident)
)

# Fail with a clear message instead of an opaque IndexError when nothing matches.
if not response.objects:
    raise LookupError(f"No object found with stazh_ident {ident!r}")

# Use the first matching object as the reference for the similarity search.
uuid = response.objects[0].uuid

response = collection.query.near_object(near_object=uuid)

# Print title and year of each object returned by the similarity search.
for item in response.objects:
    print(
        item.properties["title"],
        item.properties["year"],
    )