In [1]:
# ===================================================
# Imports (standard library first, then third-party)
# ===================================================
import json
import os
import time
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import List, Dict

import spacy
import vertexai
from neo4j import GraphDatabase
from qdrant_client import QdrantClient
from qdrant_client.http import models as q_models
from qdrant_client.http.models import Filter, FieldCondition, MatchValue
from rank_bm25 import BM25Okapi
from vertexai.generative_models import GenerativeModel
from vertexai.preview.language_models import TextEmbeddingModel

# ================== Configuration ==================

PROJECT_ID = "bing-tan-sndbx-c"
LOCATION = "europe-west4"

# Initialise the SDK *before* any model handle is created, so the handles
# below are bound to PROJECT_ID / LOCATION instead of whatever ambient
# defaults the SDK would otherwise pick up.
vertexai.init(project=PROJECT_ID, location=LOCATION)

# ===================================================
# Global models (reuse for performance)
# ===================================================
EMBED_MODEL = TextEmbeddingModel.from_pretrained("gemini-embedding-001")
MODEL = GenerativeModel(model_name="gemini-2.0-flash-001")

# ---------------- Local Qdrant -------------------
QDRANT_URL = "http://localhost:6333"
qdrant = QdrantClient(url=QDRANT_URL)

# Names of the collections currently present on the Qdrant server.
# NOTE: this only lists them; nothing is created here.
existing_collections = [c.name for c in qdrant.get_collections().collections]

# ---------------- NLP Model -----------------------
nlp = spacy.load("en_core_web_sm")  # or a bigger NER model if available

# ---------------- Local Neo4j ---------------------
NEO4J_URI = "bolt://localhost:7687"
NEO4J_USER = "neo4j"
# Allow the password to come from the environment so real credentials never
# have to be committed; the local-dev default is kept for backward compatibility.
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "testpassword")
neo4j_driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USER, NEO4J_PASSWORD))

E0000 00:00:1761129815.567785 5989700 alts_credentials.cc:93] ALTS creds ignored. Not running on GCP and untrusted ALTS is not enabled.


In [None]:

# ===================================================
# Trace setup
# ===================================================
TRACE_LOGS = []            # trace entries, appended in call order
START_TIME = time.time()   # wall-clock reference taken when this cell runs


def trace(event: str, start_time: float, data=None):
    """Append one timing record for `event` to TRACE_LOGS.

    Args:
        event: Label stored under the "event" key.
        start_time: A `time.time()` value; the recorded duration is the
            wall-clock time elapsed since it, in milliseconds (2 decimals).
        data: Any object; it is stringified and clipped to 400 characters.
    """
    elapsed_ms = (time.time() - start_time) * 1000
    entry = dict(
        timestamp=datetime.now().isoformat(timespec="seconds"),
        event=event,
        duration_ms=round(elapsed_ms, 2),
        data=str(data)[:400],
    )
    TRACE_LOGS.append(entry)

# ===================================================
# Config constants
# ===================================================
# Character budget for the context text assembled for the LLM prompt
# (checked in fetch_and_build_context and query_system).
MAX_CONTEXT_CHARS = 8000
TOP_K_CHUNKS = 10         # number of chunks returned by vector search
TOP_K_ARTICLES = 10       # how many article ids to consider from KG
# Caps on how many entities / relations are listed in one formatted chunk block
# (applied in format_chunk_payload).
MAX_ENTITIES = 10
MAX_RELATIONS = 20
# Qdrant collection that is searched and scrolled by the pipeline.
COLLECTION_NAME = "news"

# ===================================================
# Helper functions
# ===================================================
def embed_text(text: str) -> list:
    """Embed `text` with the shared EMBED_MODEL and return the vector values.

    A single-item batch is sent and the first result's `.values` is returned.
    An "Embedding" trace entry records the input length and vector size.
    """
    started = time.time()
    responses = EMBED_MODEL.get_embeddings([text])
    vector = responses[0].values
    trace("Embedding", started, {"text_len": len(text), "vector_dim": len(vector)})
    return vector

def extract_entities_and_relations(text: str):
    """Run the spaCy pipeline on `text` and return `(entities, relations)`.

    entities: one `{"name", "type"}` dict per named-entity span.
    relations: at most one `{"subject", "predicate", "object", "metadata"}`
        dict per sentence, built from the first subject token, the first
        verb (upper-cased lemma) and the first object token of the sentence.
        The confidence value is a fixed 0.8.
    """
    started = time.time()
    parsed = nlp(text)

    entities = []
    for span in parsed.ents:
        entities.append({"name": span.text, "type": span.label_})

    # lightweight relation extraction (not used heavily here)
    relations = []
    for sentence in parsed.sents:
        subjects, objects, verbs = [], [], []
        for token in sentence:
            if token.dep_ in ("nsubj", "nsubjpass"):
                subjects.append(token)
            if token.dep_ in ("dobj", "pobj"):
                objects.append(token)
            if token.pos_ == "VERB":
                verbs.append(token)

        # A triple needs all three parts; otherwise the sentence is skipped.
        if not (subjects and objects and verbs):
            continue

        relations.append({
            "subject": subjects[0].text,
            "predicate": verbs[0].lemma_.upper().replace(" ", "_"),
            "object": objects[0].text,
            "metadata": {"confidence": 0.8}
        })

    trace("Entity & relation extraction (query)", started, {"entities": len(entities), "relations": len(relations)})
    return entities, relations

def bm25_entity_ranking(query: str, entities: List[Dict], max_entities=10):
    """Return the entities that best match `query` according to BM25Okapi.

    Each entity's lower-cased, whitespace-split "name" is one document of the
    BM25 corpus; the query is tokenised the same way. Entities are ordered by
    descending score, those with a score <= 0 are dropped, and at most
    `max_entities` are returned. An empty `entities` yields `[]`.
    """
    if not entities:
        return []

    corpus = [entity["name"].lower().split() for entity in entities]
    ranker = BM25Okapi(corpus)
    scores = ranker.get_scores(query.lower().split())

    ranked = sorted(zip(scores, entities), key=lambda pair: pair[0], reverse=True)
    matching = [entity for score, entity in ranked if score > 0]
    return matching[:max_entities]

# ===================================================
# Formatting helpers (chunks + relations)
# ===================================================
def format_chunk_payload(payload: dict, query: str = None, relations: List[dict] = None,
                         max_content_chars: int = 2000) -> str:
    """
    Create a compact human-readable block for one chunk.

    Args:
        payload: Chunk payload; the keys read are content, title, author,
            date, publication and entities (a list of dicts with a "name").
            Missing or None values are rendered as empty text.
        query: When given, entities are ranked against it with
            bm25_entity_ranking; otherwise the first MAX_ENTITIES are kept.
        relations: Relation dicts (subject / predicate / object) of the parent
            article. Only relations touching a kept entity are listed, capped
            at MAX_RELATIONS.
        max_content_chars: Content longer than this is cut and marked with
            "[TRUNCATED]".

    Returns:
        A multi-line string: title/date, author/publication, content, and
        optional "Entities:" and "Relations:" lines.
    """
    entities = payload.get("entities") or []
    if query:
        entities = bm25_entity_ranking(query, entities, MAX_ENTITIES)
    else:
        entities = entities[:MAX_ENTITIES]
    entity_names = ", ".join(e["name"] for e in entities)

    # Build the lookup set once instead of once (twice, in fact) per relation.
    kept_names = {e["name"] for e in entities}
    relations_filtered = [
        r for r in (relations or [])
        if r.get("subject") in kept_names or r.get("object") in kept_names
    ][:MAX_RELATIONS]
    relations_str = "; ".join(f"{r['subject']} {r['predicate']} {r['object']}" for r in relations_filtered)

    # `or ""` also covers keys that are present but hold None, which would
    # otherwise print the literal text "None" (or crash len() for content).
    title = payload.get("title") or ""
    date = payload.get("date") or ""
    author = payload.get("author") or ""
    publication = payload.get("publication") or ""

    # Keep chunk content reasonably short: truncate if necessary
    content = payload.get("content") or ""
    if len(content) > max_content_chars:
        content = content[:max_content_chars] + "\n[TRUNCATED]"

    block = (
        f"{title} {date}\n"
        f"{author} | {publication}\n"
        f"{content}\n"
    )
    if entity_names:
        block += f"Entities: {entity_names}\n"
    if relations_str:
        block += f"Relations: {relations_str}\n"
    return block

# ===================================================
# KG helpers
# ===================================================
def kg_query_articles_by_entities(entity_names: List[str]) -> List[str]:
    """
    Given entity names, return the ids of articles that mention them.

    Articles are ranked by how many distinct requested entities they mention
    (most first, ties broken by article id) and capped at TOP_K_ARTICLES, so
    the cap keeps the best-matching articles rather than an arbitrary subset.
    Blank and duplicate names are ignored.

    Returns:
        A list of distinct article ids; `[]` when no usable names are given
        or when the Neo4j query fails (the failure is printed, not raised).
    """
    # Drop empty names and duplicates while preserving the caller's order.
    names = list(dict.fromkeys(n for n in (entity_names or []) if n))
    if not names:
        return []

    def _tx(tx):
        # Aggregating in RETURN groups by article id, which keeps ids distinct
        # and gives ORDER BY a meaningful relevance signal before LIMIT.
        q = """
            MATCH (e:Entity)<-[:MENTIONS]-(a:Article)
            WHERE e.name IN $names
            RETURN a.id AS article_id, count(DISTINCT e) AS matched
            ORDER BY matched DESC, article_id
            LIMIT $limit
        """
        result = tx.run(q, names=names, limit=TOP_K_ARTICLES)
        return [rec["article_id"] for rec in result]

    try:
        with neo4j_driver.session() as session:
            return session.execute_read(_tx)
    except Exception as e:
        # Deliberate best-effort: the pipeline can still answer from vector hits.
        print(f"⚠️ Neo4j article query failed: {e}")
        return []

def kg_fetch_relations_for_articles(article_ids: List[str]) -> Dict[str, List[dict]]:
    """
    Look up relation triples in Neo4j for the given article ids.

    For every article, the query follows MENTIONS to an entity and then any
    outgoing relationship to another entity. The result maps each article id
    to a list of `{subject, predicate, object, metadata}` dicts; `metadata`
    is always an empty dict. An empty input, or a failed query (which is
    printed rather than raised), yields `{}`.
    """
    if not article_ids:
        return {}

    def collect_triples(tx):
        cypher = """
            MATCH (a:Article)-[:MENTIONS]->(s:Entity)-[r]->(o:Entity)
            WHERE a.id IN $article_ids
            RETURN a.id AS article_id, s.name AS subject, type(r) AS predicate, o.name AS object, r AS metadata
        """
        grouped = {}
        for row in tx.run(cypher, article_ids=article_ids):
            article_id = row["article_id"]
            triple = {
                "subject": row["subject"],
                "predicate": row["predicate"],
                "object": row["object"],
                "metadata": {},  # the raw relationship is not forwarded
            }
            if article_id not in grouped:
                grouped[article_id] = []
            grouped[article_id].append(triple)
        return grouped

    try:
        with neo4j_driver.session() as kg_session:
            return kg_session.execute_read(collect_triples)
    except Exception as exc:
        print(f"⚠️ Neo4j relations query failed: {exc}")
        return {}

# ===================================================
# Context builder (chunk-aware)
# ===================================================
def fetch_and_build_context(combined_article_ids: List[str], query: str = None, post_date="2018-01-01") -> str:
    """
    Build a context string from the Qdrant chunks of the given articles.

    Chunks whose payload `article_id` is one of `combined_article_ids` are
    scrolled from Qdrant, grouped per article, and formatted with
    format_chunk_payload together with that article's KG relations. Articles
    are processed in the order given, so the caller controls priority.

    Args:
        combined_article_ids: Article ids to pull chunks for.
        query: Forwarded to format_chunk_payload for entity ranking.
        post_date: Chunks whose `date` sorts (as text) before this value are
            skipped. Chunks without a date are kept. Pass None to disable.

    Returns:
        Formatted chunk blocks joined by blank lines, never longer than
        MAX_CONTEXT_CHARS (separators included); "" when there is nothing
        to add. Qdrant failures are printed and treated as "no chunks".
    """
    if not combined_article_ids:
        return ""

    # One "should" condition per article id == match any of them.
    filter_conditions = [FieldCondition(key="article_id", match=MatchValue(value=a_id))
                         for a_id in combined_article_ids]

    # Overall cap on fetched points; note this is not a per-article cap.
    try:
        article_chunk_records, _ = qdrant.scroll(
            collection_name=COLLECTION_NAME,
            scroll_filter=Filter(should=filter_conditions),
            with_payload=True,
            limit=TOP_K_CHUNKS * len(combined_article_ids)
        )
    except Exception as e:
        print(f"⚠️ Qdrant scroll for article chunks failed: {e}")
        article_chunk_records = []

    # Group chunk records by article id; records without a payload are unusable.
    chunks_by_article = {}
    for rec in article_chunk_records:
        if not rec.payload:
            continue
        chunks_by_article.setdefault(rec.payload.get("article_id"), []).append(rec)

    if not chunks_by_article:
        # Nothing to format, so skip the Neo4j round-trip as well.
        return ""

    relations_by_article = kg_fetch_relations_for_articles(combined_article_ids)

    separator = "\n\n"
    context_parts = []
    total_len = 0

    for aid in combined_article_ids:
        article_chunks = chunks_by_article.get(aid)
        if not article_chunks:
            continue

        # Restore reading order inside the article; a missing chunk_id sorts first.
        try:
            article_chunks.sort(key=lambda r: r.payload.get("chunk_id") or 0)
        except TypeError:
            # Mixed/incomparable chunk_id types: keep the order Qdrant returned.
            pass

        rels = relations_by_article.get(aid, [])
        for rec in article_chunks:
            payload = rec.payload

            # Optional date filter. A None date is kept instead of raising
            # on the comparison; non-string dates are compared by their text.
            chunk_date = payload.get("date")
            if post_date and chunk_date is not None and str(chunk_date) < post_date:
                continue

            chunk_text = format_chunk_payload(payload, query=query, relations=rels)
            # Count the join separator too so the final string honours the budget.
            cost = len(chunk_text) + (len(separator) if context_parts else 0)
            if total_len + cost > MAX_CONTEXT_CHARS:
                # Budget reached: stop entirely so lower-priority articles
                # cannot jump ahead of the chunk that did not fit.
                return separator.join(context_parts)
            context_parts.append(chunk_text)
            total_len += cost

    return separator.join(context_parts)

# ===================================================
# Main query pipeline (vector search over chunks + KG)
# ===================================================
def _search_chunks(query_embedding):
    """Return the top chunk hits for `query_embedding` from Qdrant.

    Prefers `query_points` when the installed client offers it and falls back
    to `search` otherwise, so the call keeps working on older clients.
    """
    if hasattr(qdrant, "query_points"):
        response = qdrant.query_points(
            collection_name=COLLECTION_NAME,
            query=query_embedding,
            limit=TOP_K_CHUNKS,
            with_payload=True,
        )
        return list(response.points)
    return list(qdrant.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_embedding,
        limit=TOP_K_CHUNKS,
        with_payload=True,
    ))


def _summarize_relations(article_ids, relations_by_article, per_article=5):
    """Render up to `per_article` relations per article as one line each."""
    lines = []
    for aid in article_ids:
        rels = relations_by_article.get(aid, [])[:per_article]
        if not rels:
            continue
        triples = "; ".join(f"{r['subject']} -[{r['predicate']}]-> {r['object']}" for r in rels)
        lines.append(f"Article {aid}: {triples}")
    return "\n".join(lines)


def _build_prompt(query, entity_summary, context_text):
    """Assemble the grounded-answer prompt sent to the LLM."""
    return (
        "You are given relevant document snippets and a small KG-derived relation summary. "
        "Answer the user's question using only the provided information. If the answer is not present, say you don't know.\n\n"
        f"User query: {query}\n\n"
        f"KG relation summary:\n{entity_summary}\n\n"
        f"Relevant snippets:\n{context_text}\n\n"
        # Trailing spaces matter: adjacent literals are glued together as-is.
        "Base every statement on direct evidence from the snippets or KG facts. "
        "If evidence is ambiguous or missing, reply: "
        '"I don’t know based on the provided information."'
    )


def query_system(query: str):
    """Answer `query` from vector-search chunks plus knowledge-graph facts.

    Steps: extract query entities; in parallel look up KG article ids and
    embed the query; run a chunk-level vector search; merge article ids
    (vector hits first, then KG-only candidates); build a context bounded by
    MAX_CONTEXT_CHARS; ask the LLM. Each step is recorded via trace().

    Returns:
        `(llm_text, full_query)`: the model's answer (or a fixed error text
        if generation fails) and the exact prompt that was sent.
    """
    total_start = time.time()

    # --- Extract query entities ---
    t0 = time.time()
    query_entities, _ = extract_entities_and_relations(query)
    trace("Extract query entities", t0, {"entities": query_entities})

    # --- Parallel KG traversal (article ids) + embedding ---
    entity_names = [ent["name"] for ent in query_entities]
    t0 = time.time()
    with ThreadPoolExecutor(max_workers=2) as executor:
        kg_future = executor.submit(kg_query_articles_by_entities, entity_names)
        emb_future = executor.submit(embed_text, query)
        candidate_article_ids = kg_future.result()
        query_embedding = emb_future.result()
    trace("Neo4j KG traversal + Embedding", t0, {"articles_found": len(candidate_article_ids)})

    # --- Vector DB search (chunk-level) ---
    t0 = time.time()
    try:
        vector_chunk_records = _search_chunks(query_embedding)
    except Exception as e:
        print(f"⚠️ Qdrant vector search failed: {e}")
        vector_chunk_records = []
    trace("Qdrant vector search (chunks)", t0, {"hits": len(vector_chunk_records)})

    # --- Combine article ids (prioritize vector hits, then KG candidates) ---
    t0 = time.time()  # fresh timer: this step must not include the search time
    # Ordered de-duplication. Compare against None, not truthiness, because
    # 0 is a valid article id.
    vector_article_ids_ordered = list(dict.fromkeys(
        aid
        for aid in ((pt.payload or {}).get("article_id") for pt in vector_chunk_records)
        if aid is not None
    ))
    vector_id_set = set(vector_article_ids_ordered)
    kg_only_article_ids = [
        cid for cid in candidate_article_ids if cid not in vector_id_set
    ][:TOP_K_ARTICLES]
    combined_article_ids = vector_article_ids_ordered + kg_only_article_ids
    trace("Combine KG + vector article ids", t0, {"combined": len(combined_article_ids)})

    # --- Build context: vector-returned chunks first, then KG-based chunks ---
    t0 = time.time()
    relations_by_article = kg_fetch_relations_for_articles(combined_article_ids)

    context_parts = []
    total_len = 0
    for rec in vector_chunk_records:
        payload = rec.payload or {}
        chunk_block = format_chunk_payload(
            payload,
            query=query,
            relations=relations_by_article.get(payload.get("article_id"), []),
        )
        if total_len + len(chunk_block) > MAX_CONTEXT_CHARS:
            break
        context_parts.append(chunk_block)
        total_len += len(chunk_block)

    # If there is still room, add chunks of articles found only through the KG.
    if kg_only_article_ids and total_len < MAX_CONTEXT_CHARS:
        extra_context = fetch_and_build_context(kg_only_article_ids, query=query)
        if extra_context:
            if total_len + len(extra_context) > MAX_CONTEXT_CHARS:
                extra_context = extra_context[:(MAX_CONTEXT_CHARS - total_len)] + "\n[TRUNCATED]"
            context_parts.append(extra_context)
            total_len += len(extra_context)

    context_text = "\n\n".join(context_parts)
    trace("Build context text", t0, {"context_len": len(context_text)})

    # --- Generate final response ---
    t0 = time.time()
    entity_summary = _summarize_relations(combined_article_ids, relations_by_article)
    full_query = _build_prompt(query, entity_summary, context_text)

    try:
        response = MODEL.generate_content(full_query)
        llm_text = response.text
    except Exception as e:
        # Best-effort: surface the failure but still return the prompt for inspection.
        print(f"⚠️ LLM generation failed: {e}")
        llm_text = "Error generating response."

    trace("Gemini LLM generation", t0, {"response_len": len(llm_text)})

    trace("Total query runtime", total_start)
    return llm_text, full_query

# ===================================================
# Example usage
# ===================================================
# Demo run: ask one question, print the answer and the exact prompt, then dump
# the collected trace entries to stdout and to trace_log.json.
if __name__ == "__main__":
    # Example query
    answer, full_query = query_system(
        "Why did Donald Trump accuse China?"
    )

    print("\n--- FINAL ANSWER ---\n")
    print(answer)
    print("\n--- FULL PROMPT SENT TO LLM ---\n")
    print(full_query)

    # Measured from START_TIME (set when the trace setup ran), so this covers
    # more than the query_system call itself.
    total_duration_ms = (time.time() - START_TIME) * 1000
    print(f"\n=== TRACE SUMMARY (Total: {round(total_duration_ms, 2)} ms) ===\n")
    for log in TRACE_LOGS:
        print(f"[{log['timestamp']}] {log['event']}: {log['duration_ms']} ms | {log['data']}")

    # Persist the trace; the file in the working directory is overwritten on every run.
    with open("trace_log.json", "w") as f:
        json.dump(TRACE_LOGS, f, indent=2)


  vector_results = qdrant.search(



--- FINAL ANSWER ---

Donald Trump threatened to impose tariffs on billions of dollars worth of Chinese imports to punish Beijing for intellectual property abuses (Exclusive: China shuns U.S. request for talks on airline website dispute over Taiwan / Matthew Miller, Michael Martina, David Shepardson / Reuters).


--- FULL PROMPT SENT TO LLM ---

You are given relevant document snippets and a small KG-derived relation summary. Answer the user's question using only the provided information. If the answer is not present, say you don't know.

User query: Why did Donald Trump accuse China?

Relevant snippets:
China is dismissing unfavorable media reports as fake because that's what Trump does 2017-03-02 00:00:00
Tim Hume | Vice News
China is dismissing unfavorable media reports as fake news because that’s what Trump does  China is dismissing unfavorable media reports as fake news because that’s what Trump does  In his short political career, Donald Trump has made a habit of dismissing unfa

In [None]:
# Deployment notes only: this cell is a bare string literal, so none of the
# gcloud commands below are executed by the notebook.
# NOTE(review): `--allow-unauthenticated` presumably exposes the Qdrant service
# without authentication, and both images use the `latest` tag — confirm this
# is intended before running these commands.
"""
# Qdrant on Cloud Run
gcloud run deploy qdrant-service \
  --image=qdrant/qdrant:latest \
  --region=europe-west4 \
  --memory=2Gi \
  --allow-unauthenticated

# Neo4j on Compute Engine (VM)
gcloud compute instances create-with-container neo4j-vm \
  --container-image=neo4j:latest \
  --machine-type=e2-medium \
  --boot-disk-size=20GB \
  --tags=http-server,https-server

"""

In [11]:
# Sanity check: show which collections the connected Qdrant instance reports.
print(qdrant.get_collections())


collections=[CollectionDescription(name='news')]


In [3]:
# Example: find the chunks most similar to a sample text.
# NOTE(review): "tarifs" looks like a misspelling of "tariffs"; it is left
# unchanged because it is runtime query text — confirm whether it is intended.
query_text = "China & tarifs"
query_emb = embed_text(query_text)

# Use the shared collection constant instead of repeating the literal name.
results = qdrant.search(
    collection_name=COLLECTION_NAME,
    query_vector=query_emb,
    limit=5,
    with_payload=True
)

if not results:
    # Guard: indexing results[0] on an empty hit list would raise IndexError.
    print("No results returned.")
else:
    print(results[0].payload.keys())
    print("Top 5 similar articles:")
    # Hits are chunk-level, so the same article can appear more than once.
    for res in results:
        print(f"Title: {res.payload.get('title')}")
        print(f"Article ID: {res.payload.get('article_id')}")
        print("---")


dict_keys(['article_id', 'chunk_id', 'title', 'content', 'entities', 'author', 'date', 'section', 'publication'])
Top 5 similar articles:
Title: Exclusive: China shuns U.S. request for talks on airline website dispute over Taiwan
Article ID: 43
---
Title: Factbox: Investments by automakers in the U.S. and China since Trump came to power
Article ID: 24
---
Title: Exclusive: China shuns U.S. request for talks on airline website dispute over Taiwan
Article ID: 43
---
Title: Exclusive: China shuns U.S. request for talks on airline website dispute over Taiwan
Article ID: 43
---
Title: Exclusive: China shuns U.S. request for talks on airline website dispute over Taiwan
Article ID: 43
---


  results = qdrant.search(


In [4]:
def _dump_rows(cypher, heading, emit):
    """Run `cypher` in its own session, print `heading`, then call `emit` per record."""
    with neo4j_driver.session() as kg_session:
        records = kg_session.run(cypher)
        print(heading)
        for record in records:
            emit(record)


# Fetch all articles
_dump_rows(
    "MATCH (a:Article) RETURN a.id AS id, a.title AS title",
    "Articles in KG:",
    lambda rec: print(rec["id"], "-", rec["title"]),
)

# Fetch all entities
_dump_rows(
    "MATCH (e:Entity) RETURN e.name AS name, e.type AS type",
    "\nEntities in KG:",
    lambda rec: print(rec["name"], "-", rec["type"]),
)

# Fetch all relationships
_dump_rows(
    """
        MATCH (s:Entity)-[r]->(o:Entity)
        RETURN s.name AS subject, type(r) AS predicate, o.name AS object, r
    """,
    "\nRelationships in KG:",
    lambda rec: print(f"{rec['subject']} -[{rec['predicate']}]-> {rec['object']}"),
)


Articles in KG:
0 - We should take concerns about the health of liberal democracy seriously
1 - Colts GM Ryan Grigson says Andrew Luck's contract makes it difficult to build the team
2 - Trump denies report he ordered Mueller fired
3 - France's Sarkozy reveals his 'Passions' but insists no come-back on cards
4 - Paris Hilton: Woman In Black For Uncle Monty's Funeral
5 - ECB's Coeure: If we decide to cut rates, we'd have to consider tiering
6 - Venezuela detains six military, police officials: family members, activists
7 - You Can Trick Your Brain Into Being More Focused
8 - How to watch the Google I/O keynote live
9 - China is dismissing unfavorable media reports as fake because that's what Trump does
10 - “Elizabeth Warren called me!” is turning into a Twitter meme
11 - Hudson's Bay's chairman's buyout bid pits retail versus real estate
12 - Joakim Noah's Victoria Secret Model GF Lais Ribeiro Rocks Thong Bikini In Malibu
13 - Jermaine Jackson Rips Quincy Jones For Scrubbing Michael's 

In [106]:
# Count, per article title, how many MENTIONS edges lead to an Entity node.
with neo4j_driver.session() as kg_session:
    per_article_counts = kg_session.run("""
        MATCH (a:Article)-[:MENTIONS]->(e:Entity)
        RETURN a.title AS title, COUNT(e) AS entity_count
    """)
    print("Entity mentions per article:")
    for row in per_article_counts:
        print(f"{row['title']}: {row['entity_count']} entities")


Entity mentions per article:
We should take concerns about the health of liberal democracy seriously: 22 entities
Colts GM Ryan Grigson says Andrew Luck's contract makes it difficult to build the team: 25 entities
Trump denies report he ordered Mueller fired: 15 entities
France's Sarkozy reveals his 'Passions' but insists no come-back on cards: 31 entities
Paris Hilton: Woman In Black For Uncle Monty's Funeral: 7 entities
ECB's Coeure: If we decide to cut rates, we'd have to consider tiering: 13 entities
Venezuela detains six military, police officials: family members, activists: 45 entities
You Can Trick Your Brain Into Being More Focused: 45 entities
How to watch the Google I/O keynote live: 18 entities
China is dismissing unfavorable media reports as fake because that's what Trump does: 39 entities
“Elizabeth Warren called me!” is turning into a Twitter meme: 63 entities
Hudson's Bay's chairman's buyout bid pits retail versus real estate: 61 entities
Joakim Noah's Victoria Secret Mo

In [None]:
# RESET VECTOR DB COLLECTION

# collection_name = "news"
# embedding_dim = 3072  # correct for Vertex AI gemini embeddings

# # Delete existing collection if it exists
# existing_collections = [c.name for c in qdrant.get_collections().collections]
# if collection_name in existing_collections:
#     qdrant.delete_collection(collection_name=collection_name)

# # Recreate collection with correct embedding dimension
# qdrant.recreate_collection(
#     collection_name=collection_name,
#     vectors_config=q_models.VectorParams(
#         size=embedding_dim,
#         distance=q_models.Distance.COSINE
#     )
# )


  qdrant.recreate_collection(


True