In [34]:
import base64
import requests
from io import BytesIO
from PIL import Image
import os
import torch
from visual_bge.modeling import Visualized_BGE
from pymilvus import (
    utility,
    CollectionSchema, DataType, FieldSchema, model,
    connections, Collection, AnnSearchRequest, RRFRanker,
)
from tqdm import tqdm

In [54]:
ENDPOINT = os.getenv('ZILLIS_ENDPOINT')
TOKEN = os.getenv('ZILLIS_TOKEN')
connections.connect(uri=ENDPOINT, token=TOKEN)

COLLECTION_NAME = "odprt_index"

In [118]:
AUTO_ID = FieldSchema(
    name="auto_id",
    dtype=DataType.INT64,
    is_primary=True,
    auto_id=True
)

DOC_ID = FieldSchema(
    name="doc_id",
    dtype=DataType.VARCHAR,
    max_length=500
)

DOC_SOURCE = FieldSchema(
    name="doc_source",
    dtype=DataType.VARCHAR,
    max_length=1000,
    default_value="NA"
)

### TEXT FEATURES

TEXT = FieldSchema(
    name="text",
    dtype=DataType.VARCHAR,
    max_length=50000,
    default_value=""
)

TEXT_DENSE_EMBEDDING = FieldSchema(
    name="text_dense_embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=1024
)

TEXT_SPARSE_EMBEDDING = FieldSchema(
    name="text_sparse_embedding",
    dtype=DataType.SPARSE_FLOAT_VECTOR
)

### IMAGE FEATURES

DESCRIPTION = FieldSchema(
    name="description",
    dtype=DataType.VARCHAR,
    max_length=5000,
    default_value=""
)

DESCRIPTION_EMBEDDING = FieldSchema(
    name="description_embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=768
)

IMAGE_EMBEDDING = FieldSchema(
    name="image_embedding",
    dtype=DataType.FLOAT_VECTOR,
    dim=768 # Image embedding dim
)

### DEFINING THE SCHEMA

SCHEMA = CollectionSchema(
    fields=[AUTO_ID, DOC_ID, DOC_SOURCE, TEXT, TEXT_DENSE_EMBEDDING, TEXT_SPARSE_EMBEDDING, DESCRIPTION, DESCRIPTION_EMBEDDING, IMAGE_EMBEDDING],
    description="Schema for indexing documents and images",
    enable_dynamic_field=True
)

In [119]:
def create_collection(collection_name, schema):
    # Check if the collection exists
    if utility.has_collection(collection_name):
        print(f"Collection '{collection_name}' already exists")
        return Collection(name=collection_name) 
    else:
        # Create the collection
        return Collection(name=collection_name, schema=schema, using='default', shards_num=2)

def drop_collection(collection_name):
    # Check if the collection exists
    if utility.has_collection(collection_name):
        collection = Collection(name=collection_name)
        # Release the collection
        collection.release()
        # Drop the collection if it exists
        utility.drop_collection(collection_name)
        print(f"Collection '{collection_name}' has been dropped")
    else:
        print(f"Collection '{collection_name}' does not exist")

drop_collection(COLLECTION_NAME)
collection = create_collection(collection_name=COLLECTION_NAME, schema = SCHEMA)

Collection 'odprt_index' has been dropped


In [9]:
# Initialise Hyperbolic API Details
api_key = os.getenv("HYPERBOLIC_API_KEY")
api = "https://api.hyperbolic.xyz/v1/chat/completions"
headers = {
    "Content-Type": "application/json",
    "Authorization": f"Bearer {api_key}",
}

In [12]:
# Embedding Model
embedding_model = Visualized_BGE(model_name_bge="BAAI/bge-base-en-v1.5", model_weight="Visualized_base_en_v1.5.pth")
embedding_model.eval()

Visualized_BGE(
  (bge_encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSdpaSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): BertOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=T

In [15]:
# Load Image
image_path = "image.png"
img = Image.open(image_path)

# Encode Images for Payload
def encode_image(img):
    buffered = BytesIO()
    img.save(buffered, format="PNG")
    encoded_string = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return encoded_string

base64_img = encode_image(img)

In [25]:
# API payload for generating image summary
payload = {
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Generate me a summary of this image?"},
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{base64_img}"},
                },
            ],
        }
    ],
    "model": "Qwen/Qwen2-VL-7B-Instruct",
    "max_tokens": 300, # Max length for BGE-Embeddings
    "temperature": 0.7,
    "top_p": 0.9,
}

# Get response from VLM
response = requests.post(api, headers=headers, json=payload)

# Extract summary from response
summary = response.json().get("choices", [{}])[0].get("message", {}).get("content", "No summary available")

In [26]:
len(summary)

1406

In [107]:
# Helper function to generate embeddings
def generate_embeddings(image_path, text_summary):
    """Generates embeddings for both image and optional text summary."""
    description_embedding = []
    image_embedding = []

    with torch.no_grad():
        if text_summary and text_summary != "No summary available":
            description_embedding = embedding_model.encode(text=text_summary).tolist()[0]
        
        # Generate embedding for image
        image_embedding = embedding_model.encode(image=image_path).tolist()[0]

    return description_embedding, image_embedding

# Generate embeddings
description_embedding, image_embedding = generate_embeddings(image_path, summary)

In [120]:
data = [
    {
        "doc_id": "doc_001",
        "doc_source": "uploaded_image",
        "text": "",
        "text_dense_embedding": [0.0] * 1024,  # Placeholder for text dense embedding
        "text_sparse_embedding": [],  # Empty sparse vector
        "description": summary if summary != "No summary available" else "",
        "description_embedding": description_embedding,
        "image_embedding": image_embedding[:768] if len(image_embedding) > 768 else image_embedding  # Ensure correct 768-dim image embedding
    }
]

In [121]:
from tqdm import tqdm

def batch_ingestion(collection, data):
    batch_size = 100
    total_elements = len(data)  # Ensure batching considers the number of records
    total_batches = (total_elements + batch_size - 1) // batch_size

    # Using tqdm to create a progress bar
    for start in tqdm(range(0, total_elements, batch_size), 
                      total=total_batches,
                      desc="Ingesting batches"):
        end = min(start + batch_size, total_elements)
        batch = data[start:end]  # Slice batch correctly
        collection.insert(batch)  # Insert batch into collection


In [122]:
batch_ingestion(collection=collection, data = data)


Ingesting batches: 100%|██████████| 1/1 [00:00<00:00,  3.66it/s]


In [126]:
def create_all_indexes(collection: Collection) -> None:
    # dense embeddings index
    collection.create_index(
        field_name="text_dense_embedding",
        index_params={
            "metric_type": "COSINE",
            "index_type": "HNSW",
            "params": {
                "M": 5,
                "efConstruction": 512
            }
        },
        index_name="dense_embeddings_index"
    )
    
    print("Dense embeddings index created")

    # sparse embeddings index
    collection.create_index(
        field_name="text_sparse_embedding",
        index_params={
            "metric_type": "IP",
            "index_type": "SPARSE_INVERTED_INDEX",
            "params": {
                "drop_ratio_build": 0.2
            }
        },
        index_name="sparse_embeddings_index"
    )

    print("Sparse embeddings index created")

    # description embeddings index
    collection.create_index(
        field_name="description_embedding",
        index_params={
            "metric_type": "COSINE",
            "index_type": "HNSW"
        },
        index_name="description_embedding_index"
    )
    
    print("description_embedding index created")

    # sparse embeddings index
    collection.create_index(
        field_name="image_embedding",
        index_params={
            "metric_type": "COSINE",
            "index_type": "HNSW",
        },
        index_name="image_embedding_index"
    )
    
    print("image_embedding index created")
    # load
    collection.load()
    print("Collection loaded")

In [127]:
create_all_indexes(collection)

Dense embeddings index created
Sparse embeddings index created
description_embedding index created
image_embedding index created
Collection loaded


In [128]:
collection

<Collection>:
-------------
<name>: odprt_index
<description>: Schema for indexing documents and images
<schema>: {'auto_id': True, 'description': 'Schema for indexing documents and images', 'fields': [{'name': 'auto_id', 'description': '', 'type': <DataType.INT64: 5>, 'is_primary': True, 'auto_id': True}, {'name': 'doc_id', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 500}}, {'name': 'doc_source', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 1000}, 'default_value': string_data: "NA"
}, {'name': 'text', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_length': 50000}, 'default_value': string_data: ""
}, {'name': 'text_dense_embedding', 'description': '', 'type': <DataType.FLOAT_VECTOR: 101>, 'params': {'dim': 1024}}, {'name': 'text_sparse_embedding', 'description': '', 'type': <DataType.SPARSE_FLOAT_VECTOR: 104>}, {'name': 'description', 'description': '', 'type': <DataType.VARCHAR: 21>, 'params': {'max_len

In [133]:
def hybrid_search(query: str) -> str:
    query_embedding = embedding_model.encode(text=query).tolist()[0]
    
    search_results = collection.hybrid_search(
            reqs=[
                AnnSearchRequest(
                    data=[query_embedding],  # content vector embedding
                    anns_field='description_embedding',  # content vector field
                    param={"metric_type": "COSINE", "params": {"M": 64, "efConstruction": 512}}, 
                    limit=3
                ),
                AnnSearchRequest(
                    data=[query_embedding],  # keyword vector embedding
                    anns_field='image_embedding',  # keyword vector field
                    param={"metric_type": "COSINE", "params": {"M": 64, "efConstruction": 512}}, 
                    limit=3
                )
            ],
            output_fields=['doc_id', 'text', 'doc_source'],
            # using RRFRanker here for reranking
            rerank=RRFRanker(),
            limit=3
            )
    
    hits = search_results[0]
    
    context = []
    for res in hits:
        text = res.text
        source = res.doc_source
        context.append(f"Source: {source} \n Context: {text}")
    
    return "\n\n".join(context)

In [134]:
hybrid_search("Model Architecture")

'Source: uploaded_image \n Context: '