In [None]:
import pandas as pd
import gradio as gr
import torch
import accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM
import os 
from mistralai import Mistral
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from vectordb import VectorDB
import faiss
import weaviate
from weaviatedb import WeaviateDB

In [None]:
# Client for the Weaviate instance at http://localhost:8080. It is handed to
# WeaviateDB below and also queried directly in a later cell.
weaviate_client = weaviate.Client(url="http://localhost:8080")

In [None]:
from config import API_KEY

In [None]:
# Feature flag read by query_to_filters: when truthy, the query is sent to the
# Mistral chat API to be turned into structured filters.
USE_MISTRAL = True

In [None]:
# Mistral API client authenticated with API_KEY from the local config module;
# used by query_to_filters for the chat completion call.
mistral_client = Mistral(api_key=API_KEY)


In [None]:
# Path (relative to the working directory) of the document metadata CSV.
file_path = "document_metadata.csv"

# Document metadata table. Later cells read the columns title, author,
# created_date, category, tags and content from it.
df = pd.read_csv(file_path)

In [None]:
import json

# NOTE(review): repeats the USE_MISTRAL assignment from an earlier cell with
# the same value; on a top-to-bottom run this re-assignment changes nothing.
USE_MISTRAL = True 


def query_to_filters(query: str, max_new_tokens: int = 150):
    """
    Convert an English or Arabic query about documents into a structured filter dict.

    The query is embedded in a few-shot prompt and, when ``USE_MISTRAL`` is
    truthy, sent to the Mistral chat API. The span from the first ``{`` to the
    last ``}`` of the reply is parsed as JSON.

    Args:
        query: Free-text user query.
        max_new_tokens: Accepted for backward compatibility; it is not
            forwarded to the model call.

    Returns:
        dict that always contains the keys ``author``, ``category``,
        ``year_min``, ``year_max``, ``tags`` and ``keywords``. Keys the model
        omitted are set to None, and a single string tag is wrapped in a list
        (a blank one becomes None). When no reply is available or it cannot be
        parsed as a JSON object, every key is None.
    """

    prompt = f"""
Convert this query about documents into a JSON object with keys:
author, category, year_min, year_max, tags, keywords.
Use null if not specified. Respond ONLY with valid JSON.
- The 'keywords' field should always be in English (the language of the CSV),
even if the input query is in Arabic.
- If multiple tags are mentioned, return them as a list.
- If a date range is mentioned, fill year_min and year_max.
- If only a single year is mentioned, set both year_min and year_max to that year.
- If a field is not mentioned, set it to null.

English examples:
"Reports by John Smith" => {{"author":"John Smith","category":"Report","year_min":null,"year_max":null,"tags":null,"keywords":null}}
"Documents about financial performance in 2023" => {{"author":null,"category":null,"year_min":2023,"year_max":2023,"tags":null,"keywords":"financial performance"}}
"Policies by HR with tag onboarding between 2021 and 2022" => {{"author":"HR","category":"Policy","year_min":2021,"year_max":2022,"tags":["onboarding"],"keywords":null}}
"Documents tagged marketing and sales in 2023" => {{"author":null,"category":null,"year_min":2023,"year_max":2023,"tags":["marketing","sales"],"keywords":null}}

Arabic examples:
"تقارير من جون سميث" => {{"author":"John Smith","category":"Report","year_min":null,"year_max":null,"tags":null,"keywords":null}}
"مستندات عن الأداء المالي في ٢٠٢٣" => {{"author":null,"category":null,"year_min":2023,"year_max":2023,"tags":null,"keywords":"financial performance"}}
"سياسات من قسم الموارد البشرية مع علامات onboarding بين 2021 و 2022" => {{"author":"HR","category":"Policy","year_min":2021,"year_max":2022,"tags":["onboarding"],"keywords":null}}
"مستندات عن التسويق والمبيعات في 2023" => {{"author":null,"category":null,"year_min":2023,"year_max":2023,"tags":["marketing","sales"],"keywords":null}}

Query: "{query}"
JSON:
"""

    # The keys the prompt asks the model for; these are what callers index.
    filter_keys = ("author", "category", "year_min", "year_max", "tags", "keywords")

    # Stays None when no backend is enabled, so the fallback below is taken
    # instead of tripping over an unbound local.
    raw_output = None
    if USE_MISTRAL:
        response = mistral_client.chat.complete(
            model="ministral-3b-latest",
            messages=[{"role": "user", "content": prompt}],
            temperature=0,
        )
        raw_output = response.choices[0].message.content

    try:
        # Tolerate chatter around the JSON: keep only the outermost {...} span.
        start = raw_output.find("{")
        end = raw_output.rfind("}") + 1
        filters = json.loads(raw_output[start:end])
        if not isinstance(filters, dict):
            raise ValueError("model reply is not a JSON object")
    except (AttributeError, TypeError, ValueError):
        # AttributeError/TypeError: reply missing or not a string;
        # ValueError: no parsable JSON object (includes json.JSONDecodeError).
        print("!!! Fallback triggered: returning empty filters")
        print("Raw model response:", raw_output)
        return dict.fromkeys(filter_keys)

    # Guarantee every filter key exists so callers can index without KeyError.
    for key in filter_keys:
        filters.setdefault(key, None)

    # The prompt asks for a list of tags; small models sometimes return a bare
    # string, which callers would otherwise iterate character by character.
    tags = filters["tags"]
    if isinstance(tags, str):
        filters["tags"] = [tags] if tags.strip() else None

    return filters


In [None]:
import arabic_reshaper
from bidi.algorithm import get_display
import re

def normalize_arabic(text):
    """Fold common Arabic letter variants, then reshape and bidi-reorder the text.

    Non-string input yields an empty string. For strings, the alef variants
    are mapped to bare alef and alef maqsura to yeh, surrounding whitespace is
    stripped, and the result is passed through ``arabic_reshaper.reshape`` and
    ``get_display``.
    """
    if isinstance(text, str):
        folded = text
        # Applied in this order; each pair is (variant, canonical letter).
        for variant, canonical in (("أ", "ا"), ("إ", "ا"), ("آ", "ا"), ("ى", "ي")):
            folded = folded.replace(variant, canonical)
        return get_display(arabic_reshaper.reshape(folded.strip()))
    return ""


In [None]:
from sentence_transformers import SentenceTransformer

# Sentence-embedding model used by WeaviateDB below.
# NOTE(review): loading by bare name presumably fetches the weights from the
# model hub when they are not cached locally — confirm.
model = SentenceTransformer('msmarco-MiniLM-L6-cos-v5')
# Save a copy of the loaded model under models/ (runs on every execution).
model.save('models/msmarco-MiniLM-L6-cos-v5')

In [None]:
# Project helper (from weaviatedb) combining the Weaviate client with the
# embedding model; later cells call its compute_embeddings, insert_documents,
# query and vector_search methods.
db = WeaviateDB(client=weaviate_client, model=model)

In [None]:
# def compute_embeddings(df, embed_client):
#     embeddings = []
#     for i, row in df.iterrows():
#         text = f"{row['title']} {row['tags']} {row['content']}"
#         response = embed_client.embeddings.create(
#             model="mistral-embed",
#             inputs=[text]
#         )
#         embeddings.append(response.data[0].embedding)
#     df["embedding"] = embeddings
#     return df


In [None]:
def compute_embeddings(df, model):
    """Embed every row of ``df`` and store the vectors in an ``embedding`` column.

    Each row's ``title``, ``tags`` and ``content`` are joined with single
    spaces and passed to ``model.encode(..., convert_to_numpy=True)``; the
    result is stored as a list of Python floats.

    Args:
        df: DataFrame with ``title``, ``tags`` and ``content`` columns.
            It is modified in place.
        model: Object exposing ``encode(text, convert_to_numpy=True)``.

    Returns:
        The same ``df`` object, now carrying the ``embedding`` column.
    """
    def _embed(record):
        joined = f"{record['title']} {record['tags']} {record['content']}"
        vector = model.encode(joined, convert_to_numpy=True)
        return list(map(float, vector))

    df["embedding"] = [_embed(record) for _, record in df.iterrows()]
    return df


In [None]:
# Rebinds df to the result of WeaviateDB.compute_embeddings (the method on db,
# not the free function defined above).
# NOTE(review): presumably returns the frame with an added embedding column —
# confirm in weaviatedb.
df = db.compute_embeddings(df)

In [None]:
# Hand the frame to WeaviateDB.insert_documents.
# NOTE(review): presumably writes one Weaviate object per row; whether a
# re-run creates duplicates depends on the helper — confirm.
db.insert_documents(df)

In [None]:
# Smoke test: run a sample query through db.query, asking for 5 results, and
# print whatever it returns.
res = db.query("financial report 2023", top_k=5)
print(res)

In [None]:
# def semantic_search(query, embed_client, df, index, top_k=5):
#     response = embed_client.embeddings.create(
#         model="mistral-embed",
#         inputs=[query]
#     )
#     query_vec = np.array(response.data[0].embedding).astype("float32").reshape(1, -1)

#     distances, indices = index.search(query_vec, top_k)

#     sims = 1 / (1 + distances[0])

#     results = df.iloc[indices[0]].copy()
#     results["similarity"] = sims
    
#     return results[["title", "author", "category", "tags", "similarity"]]


In [None]:
# def semantic_search_faiss(query, vdb, top_k=5):
#     query_vec = model.encode([query], convert_to_numpy=True).astype("float32")
#     results = vdb.search(query_vec, top_k=top_k)
#     return results[["title", "author", "category", "tags", "similarity"]]

In [None]:
def _coerce_year(value):
    """Return ``value`` as an int year, or None when it is missing or not numeric."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def _as_terms(value):
    """Return ``value`` (None, a scalar, or a list/tuple/set) as a list of non-blank strings."""
    if value is None:
        return []
    items = value if isinstance(value, (list, tuple, set)) else [value]
    terms = (str(item).strip() for item in items if item is not None)
    return [term for term in terms if term]


def search_csv(query):
    """Filter the module-level ``df`` using structured filters extracted from ``query``.

    ``query_to_filters`` turns the query into a filter dict; each filter that
    is present narrows the result set:

    - ``author`` / ``category``: case-insensitive literal substring match.
    - ``year_min`` / ``year_max``: inclusive bounds on the year of
      ``created_date``; rows whose date cannot be parsed are dropped when a
      year bound is applied.
    - ``tags``: row kept if its ``tags`` value contains any of the tags
      (case-insensitive, literal).
    - ``keywords``: row kept if its normalized ``content`` contains any of the
      whitespace-separated keywords; in this case the returned frame also
      carries a ``content_normalized`` column.

    Args:
        query: Free-text user query (English or Arabic).

    Returns:
        A new DataFrame with the matching rows; ``df`` itself is not modified.
    """
    filters = query_to_filters(query)
    print("Structured query:", filters)

    results = df.copy()

    # regex=False: model-extracted values such as "J. Smith" or "(HR" must be
    # matched literally, not interpreted as regular expressions.
    for column in ("author", "category"):
        wanted = filters.get(column)
        if wanted:
            results = results[
                results[column].str.contains(str(wanted), case=False, na=False, regex=False)
            ]

    # The model may return years as strings; coerce before comparing.
    year_min = _coerce_year(filters.get("year_min"))
    year_max = _coerce_year(filters.get("year_max"))
    if year_min or year_max:
        # Parse dates once; unparseable values become NaT and fail both bounds.
        created_year = pd.to_datetime(results["created_date"], errors="coerce").dt.year
        in_range = created_year.notna()
        if year_min:
            in_range &= created_year >= year_min
        if year_max:
            in_range &= created_year <= year_max
        results = results[in_range]

    # _as_terms wraps a bare string so it is not split into characters.
    tag_terms = _as_terms(filters.get("tags"))
    if tag_terms:
        tag_pattern = "|".join(re.escape(tag) for tag in tag_terms)
        results = results[results["tags"].str.contains(tag_pattern, case=False, na=False)]

    keyword_text = " ".join(_as_terms(filters.get("keywords")))
    if keyword_text:
        words = re.split(r"\s+", keyword_text)
        words_normalized = [normalize_arabic(w) for w in words]
        # Drop empties: an empty alternative would match every row.
        pattern = "|".join(re.escape(w) for w in words_normalized if w)
        if pattern:
            # assign() builds a new frame, avoiding a write into a filtered slice.
            results = results.assign(
                content_normalized=results["content"].apply(normalize_arabic)
            )
            results = results[
                results["content_normalized"].str.contains(
                    pattern, case=False, regex=True, na=False
                )
            ]

    return results


In [None]:
# def semantic_search_gradio(query):
#     return semantic_search_faiss(query, vdb=vdb, top_k=5)

In [None]:
# Sanity check: use the first document's title as the query text, so the
# nearest-neighbour search has a known document to find.
query_text = df.iloc[0]['title']  # or first document content
# Ask the helper for the 5 nearest documents under the cosine metric.
result = db.vector_search(query_text, top_k=5, metric="cosine")
print(result)


In [None]:
def search_with_filters(user_query, metric="cosine", top_k=5):
    """Run a vector query against ``db`` restricted by filters extracted from the query.

    ``query_to_filters`` converts ``user_query`` into a filter dict. Every
    filter that is present becomes one operand of an ``And`` where-filter:
    ``Equal`` on ``author`` / ``category``, ``GreaterThanEqual`` /
    ``LessThanEqual`` on ``created_date`` for the year bounds, and one
    ``Equal`` operand per tag. With no filters, ``where_filter`` is None.

    Args:
        user_query: Free-text user query (English or Arabic).
        metric: Similarity metric name forwarded to ``db.query``.
        top_k: Number of results forwarded to ``db.query``.

    Returns:
        Whatever ``db.query`` returns.
    """
    filters = query_to_filters(user_query)
    operands = []

    # .get(): the model may omit keys, which must read as "no filter" rather
    # than raise KeyError.
    for field in ("author", "category"):
        value = filters.get(field)
        if value:
            operands.append({
                "path": [field],
                "operator": "Equal",
                "valueString": value
            })

    # NOTE(review): compares created_date against a bare year via valueInt;
    # this assumes the stored property is comparable to an int — confirm
    # against the Weaviate schema.
    for key, operator in (("year_min", "GreaterThanEqual"), ("year_max", "LessThanEqual")):
        year = filters.get(key)
        if year:
            operands.append({
                "path": ["created_date"],
                "operator": operator,
                "valueInt": year
            })

    tags = filters.get("tags")
    if isinstance(tags, str):
        # A bare string would otherwise be iterated character by character.
        tags = [tags]
    # Each tag is its own operand of the And filter, so all tags must match.
    for tag in tags or []:
        operands.append({
            "path": ["tags"],
            "operator": "Equal",
            "valueString": tag
        })

    where_filter = {"operator": "And", "operands": operands} if operands else None

    return db.query(user_query, top_k=top_k, metric=metric, where_filter=where_filter)


In [None]:
# Spot check straight against Weaviate (bypassing WeaviateDB): fetch
# document_id and title for up to 5 objects of the "Document" class. As the
# last expression in the cell, the response is displayed as the cell output.
weaviate_client.query.get("Document", ["document_id", "title"]).with_limit(5).do()


In [None]:
def gradio_search(query, metric, top_k):
    """Gradio callback: run ``db.query`` and shape the hits into table rows.

    Args:
        query: Search text from the textbox.
        metric: Similarity metric name from the dropdown.
        top_k: Requested number of results from the slider.

    Returns:
        A list of rows, one per hit, each holding the hit's ``document_id``,
        ``title``, ``author``, ``created_date``, ``category``, ``tags`` and
        ``content`` (None where a hit lacks the key). An empty list when
        ``db.query`` returns nothing.
    """
    hits = db.query(query, metric=metric, top_k=top_k)
    if not hits:
        return []

    # Column order must line up with the headers of the results table.
    columns = (
        "document_id",
        "title",
        "author",
        "created_date",
        "category",
        "tags",
        "content",
    )
    return [[hit.get(column) for column in columns] for hit in hits]


In [None]:
# Gradio UI: a query box, a metric dropdown and a result-count slider in one
# row, a Search button, and a table for the hits.
with gr.Blocks() as iface:
    gr.Markdown("## Weaviate Document Search")
    
    # Inputs, passed to gradio_search in this order: query, metric, top_k.
    with gr.Row():
        query_input = gr.Textbox(label="Enter your query (English or Arabic)", lines=2, placeholder="Search documents...")
        metric_input = gr.Dropdown(["cosine", "dot", "euclidean"], value="cosine", label="Vector similarity metric")
        top_k_input = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Number of results")
    
    search_button = gr.Button("Search")
    # Header order matches the column order of the rows gradio_search returns.
    results_table = gr.DataFrame(headers=["ID", "Title", "Author", "Date", "Category", "Tags", "Content"])
    
    # Clicking Search calls gradio_search and renders its rows in the table.
    search_button.click(
        gradio_search,
        inputs=[query_input, metric_input, top_k_input],
        outputs=[results_table]
    )

# Start the Gradio app.
iface.launch()
