In [56]:
import os
from urllib.parse import quote

import pandas as pd
from dotenv import load_dotenv
from sqlalchemy import create_engine, text

load_dotenv()

# Neon Postgres connection settings, read from the environment (.env loaded above).
DB_USER = os.getenv("DB_USER_NEON")
DB_PASSWORD = os.getenv("DB_PW_NEON")
DB_HOST = os.getenv("DB_NEON_HOST")
DB_NAME = os.getenv("DB_NEON_NAME")

# Fail fast with a readable message instead of silently building a URL such as
# "postgresql://None:None@None/None" and failing later with an opaque error.
_missing_db_vars = [
    env_name
    for env_name, env_value in (
        ("DB_USER_NEON", DB_USER),
        ("DB_PW_NEON", DB_PASSWORD),
        ("DB_NEON_HOST", DB_HOST),
        ("DB_NEON_NAME", DB_NAME),
    )
    if not env_value
]
if _missing_db_vars:
    raise RuntimeError(
        "Missing required environment variables: " + ", ".join(_missing_db_vars)
    )

# Percent-encode the credentials: characters such as "@", ":" or "/" in a
# password would otherwise be parsed as URL delimiters.
DATABASE_URL = (
    f"postgresql://{quote(DB_USER, safe='')}:{quote(DB_PASSWORD, safe='')}"
    f"@{DB_HOST}/{DB_NAME}?sslmode=require&channel_binding=require"
)

engine = create_engine(
    DATABASE_URL,
    pool_size=5,
    max_overflow=10,
    pool_pre_ping=True,   # check a pooled connection is alive before using it
    pool_recycle=300,     # replace pooled connections older than 300 seconds
)

# Smoke test: confirm the connection works and show the contracts table.
with engine.connect() as conn:
    result = conn.execute(text("SELECT * FROM contracts"))
    rows = result.fetchall()

    for row in rows:
        print(row)

(1, 'Acme Corp', 'NDA', 24, 91, 'Passed', datetime.date(2024, 1, 15), 'Germany', 'Data Privacy Policy', 'EU')
(2, 'Zenith Solutions', 'Service Agreement', 36, 74, 'Pending', datetime.date(2023, 11, 20), 'India', 'Vendor Risk Policy', 'APAC')
(3, 'GlobalTech Ltd', 'Partnership', 48, 67, 'Failed', datetime.date(2023, 8, 10), 'United States', 'Financial Compliance Policy', 'North America')
(4, 'Bright Systems', 'Vendor Agreement', 18, 87, 'Passed', datetime.date(2024, 2, 5), 'Singapore', 'Vendor Risk Policy', 'APAC')
(5, 'Nova Innovations', 'Service Agreement', 12, 94, 'Passed', datetime.date(2024, 3, 12), 'United Kingdom', 'HR Compliance Policy', 'UK')
(6, 'Helix Enterprises', 'NDA', 36, 73, 'Pending', datetime.date(2023, 6, 18), 'France', 'Data Privacy Policy', 'EU')
(7, 'Pioneer Holdings', 'Partnership', 24, 86, 'Passed', datetime.date(2024, 4, 2), 'Canada', 'Financial Compliance Policy', 'North America')
(8, 'Apex Consulting', 'Service Agreement', 18, 76, 'Pending', datetime.date(2023

In [2]:
from pymilvus import connections, utility

# Connect to Milvus under the "default" alias; the endpoint URI and API token
# are read from the MILVUS_URI / MILVUS_API_KEY environment variables.
connections.connect(
    alias="default",
    uri=os.getenv("MILVUS_URI"),
    token=os.getenv("MILVUS_API_KEY")
)

print("‚úÖ Connected to Milvus")

# Optional sanity check: list the collections that already exist on the server.
collections = utility.list_collections()
print("Existing collections:", collections)


‚úÖ Connected to Milvus
Existing collections: []


In [3]:
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

# Field definitions for the "legal_policy_vectors" collection.
# insert_document supplies values for every field except `id`, in this order.
fields = [
    # Primary key; auto_id=True, so insert_document does not provide a value for it.
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),

    # Contract identifier; milvus_search filters on it with ids returned by the SQL lookup.
    FieldSchema(name="contract_id", dtype=DataType.INT64),

    # Embedding vector; 1024 matches the BAAI/bge-m3 dimension printed by the smoke-test cell.
    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),

    # Metadata and raw chunk text returned alongside search hits.
    FieldSchema(name="contract_type", dtype=DataType.VARCHAR, max_length=100),
    FieldSchema(name="text_chunk", dtype=DataType.VARCHAR, max_length=5000),
]



In [4]:
# Build the collection schema from the field list defined in the previous cell.
schema = CollectionSchema(fields, description="Policy Clause Embeddings")

# Handle to the "legal_policy_vectors" collection used by insert_document and milvus_search.
# NOTE(review): a later cell calls utility.drop_collection on this same name; make sure
# `collection` is (re)created after that cell has run and before anything is inserted.
collection = Collection(
    name="legal_policy_vectors",
    schema=schema
)

# IVF_FLAT index with cosine similarity; milvus_search uses the same metric and nprobe=10.
index_params = {
    "metric_type": "COSINE",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128}
}

collection.create_index(
    field_name="embedding",
    index_params=index_params
)

# Load the collection so that it can be searched.
collection.load()

In [16]:
import os
from huggingface_hub import InferenceClient
from dotenv import load_dotenv

load_dotenv()

# Dedicated name for the Hugging Face client: later cells bind `client` to a
# Groq client, and sharing one name made the notebook depend on which cell ran last.
hf_client = InferenceClient(
    provider="hf-inference",
    api_key=os.getenv("HF_TOKEN"),
)

# Smoke test of the BAAI/bge-m3 model: similarity of one source sentence
# against three candidates (one score per candidate is printed).
result = hf_client.sentence_similarity(
    "That is a happy person",  # source sentence
    [
        "That is a happy dog",
        "That is a very happy person",
        "Today is a sunny day"
    ],
    model="BAAI/bge-m3",
)

print(result)

[0.8589122891426086, 0.9666369557380676, 0.7509795427322388]


In [69]:
import os
from huggingface_hub import InferenceClient
import numpy as np

# HF_TOKEN must be set in the environment (or via .env) before this cell runs.
# os.environ["HF_TOKEN"] = "your_token_here"

# Embedding client; milvus_search reuses `em_client` to embed user queries.
em_client = InferenceClient(
    api_key=os.getenv("HF_TOKEN")
)

# Generate an embedding for a sample sentence with the same model used for the documents.
embedding = em_client.feature_extraction(
    "test sentence",
    model="BAAI/bge-m3"
)

# Convert to a numpy array
embedding = np.array(embedding)

# len() reports the size of the first axis; the recorded output is 1024, which
# matches dim=1024 of the Milvus `embedding` field.
print("Embedding dimension:", len(embedding))

# Scale to unit length, mirroring normalize_embeddings=True used by insert_document.
# NOTE(review): documents are embedded with a local SentenceTransformer while queries go
# through this API client — assumes both return comparable vectors; TODO confirm.
embedding = embedding / np.linalg.norm(embedding)

Embedding dimension: 1024


In [19]:
def chunk_text(text, chunk_size=800, overlap=150):
    """
    Split ``text`` into overlapping character windows for retrieval.

    Args:
        text: The string to split.
        chunk_size: Maximum number of characters per chunk; must be positive.
        overlap: Number of characters shared by consecutive chunks; must
            satisfy ``0 <= overlap < chunk_size``.

    Returns:
        A list of whitespace-stripped, non-empty chunks in document order.
        An empty ``text`` yields an empty list.

    Raises:
        ValueError: If ``chunk_size`` or ``overlap`` would stop the window
            from advancing (previously this looped forever) or would skip text.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be a positive integer")
    if not 0 <= overlap < chunk_size:
        raise ValueError("overlap must satisfy 0 <= overlap < chunk_size")

    step = chunk_size - overlap
    text_length = len(text)
    chunks = []

    for start in range(0, text_length, step):
        end = start + chunk_size
        chunk = text[start:end].strip()
        # Whitespace-only windows carry no content worth embedding.
        if chunk:
            chunks.append(chunk)
        # Once a window reaches the end of the text, any further window would
        # only repeat the tail of this one, so stop here.
        if end >= text_length:
            break

    return chunks

In [20]:
import os

from sentence_transformers import SentenceTransformer
import numpy as np

# Load model once (not inside function)
model = SentenceTransformer("BAAI/bge-m3")

print("‚úÖ Model loaded")
def insert_document(file_path, contract_id, contract_type):
    """
    Read a UTF-8 text file, chunk it, embed the chunks and insert them into Milvus.

    Uses the module-level ``chunk_text``, ``model`` (SentenceTransformer) and
    ``collection`` (Milvus collection) defined in other cells.

    Args:
        file_path: Path of the text file to ingest.
        contract_id: Contract identifier stored with every chunk.
        contract_type: Contract type label stored with every chunk.

    Returns:
        None. Prints a one-line summary of what was inserted.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        document_text = f.read()

    # Step 1: Chunk the text, dropping empty chunks (nothing to embed or store).
    chunks = [chunk for chunk in chunk_text(document_text) if chunk]

    if not chunks:
        # An empty file would otherwise send empty batches to the encoder and Milvus.
        print(f"No text chunks found in {file_path}; nothing inserted for Contract ID {contract_id}")
        return

    # Step 2: Generate embeddings (normalized for cosine similarity)
    embeddings = model.encode(
        chunks,
        normalize_embeddings=True
    )

    # Step 3: Column-wise data in schema order; `id` is omitted because it is auto-generated.
    data = [
        [contract_id] * len(chunks),        # contract_id
        embeddings.tolist(),                # embedding
        [contract_type] * len(chunks),      # contract_type
        chunks                              # text_chunk
    ]

    # Step 4: Insert into Milvus
    collection.insert(data)
    collection.flush()

    print(f"‚úÖ Inserted {len(chunks)} chunks for Contract ID {contract_id}")

Loading weights: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 391/391 [00:00<00:00, 439.45it/s, Materializing param=pooler.dense.weight]                               


‚úÖ Model loaded


In [7]:
import torch

def check_device():
    """Print whether a CUDA GPU is available and return "cuda" or "cpu" accordingly."""
    has_gpu = torch.cuda.is_available()

    if not has_gpu:
        print("‚ö† GPU Not Available, using CPU")
        return "cpu"

    print("‚úÖ GPU Available")
    print("GPU Name:", torch.cuda.get_device_name(0))
    return "cuda"

device = check_device()

‚ö† GPU Not Available, using CPU


In [8]:
collection_name = "legal_policy_vectors"

# Reset step: remove any existing copy of the collection.
if utility.has_collection(collection_name):
    utility.drop_collection(collection_name)
    print("‚úÖ Old collection dropped successfully.")
else:
    print("No existing collection found.")

# Rebuild the collection straight away. Without this, `collection` (created in
# an earlier cell) would point at a dropped collection and the ingestion and
# search cells below would fail when the notebook is run top to bottom.
# Reuses `schema` and `index_params` from the collection-creation cell.
collection = Collection(name=collection_name, schema=schema)
collection.create_index(field_name="embedding", index_params=index_params)
collection.load()

‚úÖ Old collection dropped successfully.


In [33]:
# Root folder of the source documents. Override with the RAG_DATA_DIR environment
# variable on other machines; the default keeps the original local path.
RAG_DATA_DIR = os.getenv(
    "RAG_DATA_DIR",
    "C:\\Users\\mbalasubramanian\\Documents\\rag_bot_assignment\\rag_data",
)
partnership_doc_path = os.path.join(RAG_DATA_DIR, "partnership", "partnership_v3.txt")

# The same partnership document is ingested for both contract ids.
for contract_id in (23, 30):
    insert_document(partnership_doc_path, contract_id, "Partnership")

‚úÖ Inserted 11 chunks for Contract ID 23
‚úÖ Inserted 11 chunks for Contract ID 30


In [34]:
# System prompt for extract_filters_from_query. The JSON keys requested here are
# the keys get_contracts_dynamic reads, so keep the two in sync when adding filters.
# Plain ASCII bullets/arrows are used so the prompt survives any file encoding.
SYSTEM_PROMPT_EXTRACT = """
You are a legal data extraction assistant.

Your task is to extract structured filters from a user query
and return ONLY valid JSON.

Available database columns:
- vendor_name (string)
- contract_type (must be one of: NDA, Service Agreement, Vendor Agreement, Partnership, General)
- compliance_score (integer)
- audit_status (Passed, Failed, Pending)
- jurisdiction (string)
- region (string)
- policy_name (string)
- duration_months (integer)
- contract_date (date)

Rules:

1. Ignore capitalization differences.
2. Only return JSON.
3. Include only fields explicitly mentioned in the query.
4. Do NOT add extra fields.
5. If nothing relevant is found, return {}.
6. Use only the column names provided above, or the derived filter keys defined in the Numeric Filtering Rules below.

Numeric Filtering Rules:

- If query says:
  - "above X", "greater than X", "more than X"
    -> use: "compliance_score_min": X

  - "below X", "less than X"
    -> use: "compliance_score_max": X

  - "between X and Y"
    -> use: "compliance_score_between": [X, Y]

- For duration:
  - "longer than X months"
    -> use: "duration_min": X

  - "shorter than X months"
    -> use: "duration_max": X

- For relative dates:
  - "last X months"
    -> use: "last_n_months": X

- Never output natural language.
- Never explain anything.
- Output must be valid JSON only.

Examples:

Query: Show failed vendor agreements in APAC with compliance score above 70
Output:
{
  "contract_type": "Vendor Agreement",
  "audit_status": "Failed",
  "region": "APAC",
  "compliance_score_min": 70
}

Query: Contracts between 60 and 80 score from last 3 months
Output:
{
  "compliance_score_between": [60, 80],
  "last_n_months": 3
}
"""

In [65]:
import json
import os
from groq import Groq

client = Groq(api_key=os.getenv("GROQ_API_KEY"))

def extract_filters_from_query(user_query):
    """
    Ask the LLM to turn a natural-language query into a dict of structured filters.

    Uses the module-level Groq ``client`` and ``SYSTEM_PROMPT_EXTRACT``.

    Args:
        user_query: The user's question in plain text.

    Returns:
        A dict of filters parsed from the model's JSON reply, or ``{}`` when the
        reply is empty, is not valid JSON, or is not a JSON object.
    """
    completion = client.chat.completions.create(
        model="openai/gpt-oss-20b",
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT_EXTRACT},
            {"role": "user", "content": user_query}
        ],
        temperature=0,
        max_completion_tokens=300,
        top_p=1,
        stream=False
    )

    # The message content may be None when the model returns no text; treat that
    # as an empty reply instead of crashing on None.strip().
    content = (completion.choices[0].message.content or "").strip()

    # Remove a Markdown code fence if the model wrapped the JSON in one.
    if content.startswith("```"):
        content = content.replace("```json", "").replace("```", "").strip()

    try:
        filters = json.loads(content)
    except json.JSONDecodeError:
        filters = None

    # Callers use filters.get(...), so only a JSON object is an acceptable reply.
    if isinstance(filters, dict):
        return filters

    print("‚ö† LLM did not return valid JSON:")
    print(content)
    return {}

In [37]:
from sqlalchemy import text
from datetime import datetime, timedelta

def get_contracts_dynamic(filters):
    """
    Return the ids of contracts that match a dict of structured filters.

    Uses the module-level SQLAlchemy ``engine``. All values are passed as bound
    parameters; column names come only from the fixed lists below.

    Supported keys in ``filters``:
        vendor_name, contract_type, audit_status, region, jurisdiction,
        policy_name: case-insensitive substring match.
        compliance_score, duration_months: exact match.
        compliance_score_min / compliance_score_max: inclusive bounds.
        compliance_score_between: ``[low, high]``, inclusive.
        duration_min / duration_max: inclusive bounds on ``duration_months``.
        last_n_months: contracts dated within the last N months, with a month
            approximated as 30 days.

    Args:
        filters: Dict of filters; ``None`` or ``{}`` selects every contract.

    Returns:
        A list of result rows, each holding a single ``contract_id``.

    Raises:
        ValueError: If ``compliance_score_between`` is not a two-element sequence.
    """
    filters = filters or {}
    clauses = []
    params = {}

    # Text filters: case-insensitive substring match.
    text_columns = (
        "vendor_name",
        "contract_type",
        "audit_status",
        "region",
        "jurisdiction",
        "policy_name",
    )
    for column in text_columns:
        value = filters.get(column)
        if value:
            clauses.append(f"{column} ILIKE :{column}")
            params[column] = f"%{value}%"

    # Numeric filters as (filter key, column, operator). Compare against None
    # rather than truthiness so that a legitimate bound of 0 is not dropped.
    numeric_filters = (
        ("compliance_score", "compliance_score", "="),
        ("compliance_score_min", "compliance_score", ">="),
        ("compliance_score_max", "compliance_score", "<="),
        ("duration_months", "duration_months", "="),
        ("duration_min", "duration_months", ">="),
        ("duration_max", "duration_months", "<="),
    )
    for key, column, operator in numeric_filters:
        value = filters.get(key)
        if value is not None and value != "":
            clauses.append(f"{column} {operator} :{key}")
            params[key] = value

    # Range filter (between).
    score_range = filters.get("compliance_score_between")
    if score_range:
        if not isinstance(score_range, (list, tuple)) or len(score_range) != 2:
            raise ValueError(
                "compliance_score_between must be a two-element list: [low, high]"
            )
        # BETWEEN needs low <= high; order the bounds in case they arrive reversed.
        low, high = sorted(score_range)
        clauses.append("compliance_score BETWEEN :score_min AND :score_max")
        params["score_min"] = low
        params["score_max"] = high

    # Relative date filter (e.g. last 3 months).
    months = filters.get("last_n_months")
    if months:
        date_threshold = datetime.today() - timedelta(days=30 * int(months))
        clauses.append("contract_date >= :date_threshold")
        params["date_threshold"] = date_threshold

    base_query = "SELECT contract_id FROM contracts WHERE 1=1"
    full_query = base_query + "".join(f" AND {clause}" for clause in clauses)

    with engine.connect() as conn:
        result = conn.execute(text(full_query), params)
        rows = result.fetchall()

    return rows

In [66]:
query = "Show failed contracts of Pacific Trade Co in APAC"

filters = extract_filters_from_query(query)

print(filters)

{'vendor_name': 'Pacific Trade Co', 'audit_status': 'Failed', 'region': 'APAC'}


In [41]:
def process_user_query(user_query):
    """Extract filters from the query, look up matching contract ids, and echo both steps."""
    # LLM-based filter extraction
    extracted = extract_filters_from_query(user_query)
    print("üîé Extracted Filters:", extracted)

    # Structured lookup in Postgres
    matches = get_contracts_dynamic(extracted)
    print("üìÑ Matching Contracts:", matches)

    return matches

query = "Show contracts in India"
print(process_user_query(query))

üîé Extracted Filters: {'jurisdiction': 'India'}
üìÑ Matching Contracts: [(2,), (10,), (12,), (22,), (30,)]
[(2,), (10,), (12,), (22,), (30,)]


In [42]:
from sqlalchemy import text

def get_contract_rows_by_ids(contract_id_rows):
    """
    Fetch the full ``contracts`` rows for the given contract ids.

    Args:
        contract_id_rows: Either a list of integer ids or a list of result
            rows whose first element is the id. Falsy input returns ``[]``.

    Returns:
        A list of rows from the ``contracts`` table.
    """
    if not contract_id_rows:
        return []

    # The first element decides the input shape: bare ints are used as-is,
    # anything else is treated as a row whose first column is the id.
    first_item = contract_id_rows[0]
    ids = (
        contract_id_rows
        if isinstance(first_item, int)
        else [entry[0] for entry in contract_id_rows]
    )

    lookup = text("""
        SELECT *
        FROM contracts
        WHERE contract_id = ANY(:ids)
    """)

    with engine.connect() as conn:
        return conn.execute(lookup, {"ids": ids}).fetchall()

In [43]:
query = "Show contracts in India"

rows = get_contract_rows_by_ids(process_user_query(query))
for row in rows:
    print(row)

üîé Extracted Filters: {'jurisdiction': 'India'}
üìÑ Matching Contracts: [(2,), (10,), (12,), (22,), (30,)]
(2, 'Zenith Solutions', 'Service Agreement', 36, 74, 'Pending', datetime.date(2023, 11, 20), 'India', 'Vendor Risk Policy', 'APAC')
(10, 'Orion Supplies', 'Vendor Agreement', 24, 81, 'Passed', datetime.date(2024, 1, 30), 'India', 'Vendor Risk Policy', 'APAC')
(12, 'IndoLogix Pvt Ltd', 'Service Agreement', 48, 71, 'Pending', datetime.date(2023, 10, 3), 'India', 'Vendor Risk Policy', 'APAC')
(22, 'AsiaTrade Logistics', 'Service Agreement', 12, 90, 'Passed', datetime.date(2024, 5, 5), 'India', 'Vendor Risk Policy', 'APAC')
(30, 'Pacific Alliance Group', 'Partnership', 36, 69, 'Failed', datetime.date(2023, 6, 22), 'India', 'Vendor Risk Policy', 'APAC')


In [None]:
import numpy as np

def milvus_search(query_text, matching_contracts, top_k=3):
    """
    Run a vector search over the chunks of the given contracts.

    Uses the module-level ``em_client`` (embedding client) and ``collection``
    (Milvus collection).

    Args:
        query_text: Text to embed and search for.
        matching_contracts: Contract ids to restrict the search to, given either
            as bare integers or as result rows whose first element is the id
            (the same two shapes accepted by ``get_contract_rows_by_ids``).
        top_k: Maximum number of hits to return.

    Returns:
        A list of dicts with ``contract_id``, ``contract_type``, ``text_chunk``
        and ``score``. Empty when no contract ids are given or the embedding is
        a zero vector.

    Raises:
        ValueError: If the embedding service does not return a single vector.
    """
    # 1. Normalize the ids: accept rows or bare ints, cast to plain int so the
    #    filter expression holds integer literals only, and drop duplicates.
    contract_ids = []
    for item in matching_contracts or []:
        raw_id = item if isinstance(item, (int, np.integer)) else item[0]
        contract_ids.append(int(raw_id))
    contract_ids = list(dict.fromkeys(contract_ids))

    if not contract_ids:
        print("No matching contracts found")
        return []

    # 2. Create the query embedding.
    embedding = em_client.feature_extraction(
        query_text,
        model="BAAI/bge-m3"
    )
    embedding = np.asarray(embedding)

    # Accept a (1, dim) batch of one; anything else cannot be a single query vector.
    if embedding.ndim == 2 and embedding.shape[0] == 1:
        embedding = embedding[0]
    if embedding.ndim != 1:
        raise ValueError(
            f"Expected a single embedding vector, got shape {embedding.shape}"
        )

    # 3. Scale to unit length (the collection is indexed with COSINE).
    norm = np.linalg.norm(embedding)
    if norm == 0:
        print("Zero vector embedding")
        return []

    # Milvus expects a list of query vectors.
    query_embedding = [(embedding / norm).tolist()]

    # 4. Search parameters.
    search_params = {
        "metric_type": "COSINE",
        "params": {"nprobe": 10}
    }

    # 5. Restrict the search to the selected contracts (contract_id is INT64).
    id_list = ", ".join(str(contract_id) for contract_id in contract_ids)
    filter_expr = f"contract_id in [{id_list}]"

    # 6. Perform the search.
    results = collection.search(
        data=query_embedding,
        anns_field="embedding",
        param=search_params,
        limit=top_k,
        expr=filter_expr,
        output_fields=["contract_id", "contract_type", "text_chunk"]
    )

    # 7. Flatten the hits into plain dicts.
    retrieved_chunks = []
    for hits in results:
        for hit in hits:
            retrieved_chunks.append({
                "contract_id": hit.entity.get("contract_id"),
                "contract_type": hit.entity.get("contract_type"),
                "text_chunk": hit.entity.get("text_chunk"),
                "score": hit.score
            })

    return retrieved_chunks

In [45]:
from pymilvus import Collection

collection = Collection("legal_policy_vectors")
collection.load()

In [50]:

query = "Show contracts in EU on Intellectual property for Helix Enterprises"

# Structured filtering first, then vector search restricted to the matches.
filters = extract_filters_from_query(query)

matching_contracts = get_contracts_dynamic(filters)

vector_results = milvus_search(query, matching_contracts, top_k=5)

for r in vector_results:
    print("Score:", r["score"])
    print("Contract ID:", r["contract_id"])
    print("Text:", r["text_chunk"])
    print("-" * 50)

# Reuse the contracts matched above instead of calling process_user_query,
# which would repeat the LLM extraction and the SQL lookup for the same query.
rows = get_contract_rows_by_ids(matching_contracts)
for row in rows:
    print(row)

Score: 0.4746566712856293
Contract ID: 6
Text: of:

Unauthorized access to test systems

Exposure of client data

Security vulnerabilities discovered during testing

Cybersecurity incidents affecting confidential environments

The Receiving Party shall notify the Disclosing Party within seventy-two (72) hours of discovery and cooperate in remediation, investigation, and regulatory reporting obligations.

5. Intellectual Property Protection

All software, documentation, test data, and related materials remain the sole property of the Disclosing Party unless otherwise agreed in writing.

Test artifacts, reports, automation frameworks, and documentation created under this Agreement shall be used exclusively for the project scope and shall not be reused for competing commercial advantage without written authorization.

6. Remedies and Ind
--------------------------------------------------
Score: 0.4102223813533783
Contract ID: 6
Text: used exclusively for the project scope and shall not be

In [51]:
def build_context(contract_rows, vector_results):
    """
    Assemble the text context handed to the answer-generation prompt.

    Args:
        contract_rows: Rows from the ``contracts`` table, read by position
            (index 0 = contract id ... index 9 = region).
        vector_results: Dicts with ``contract_id``, ``score`` and ``text_chunk``.

    Returns:
        One string: a structured-metadata section followed by the retrieved clauses.
    """
    # One metadata block per contract row
    metadata_blocks = [
        f"""
Contract ID: {row[0]}
Vendor: {row[1]}
Contract Type: {row[2]}
Duration (months): {row[3]}
Compliance Score: {row[4]}
Audit Status: {row[5]}
Contract Date: {row[6]}
Jurisdiction: {row[7]}
Policy Name: {row[8]}
Region: {row[9]}
-------------------------------------
"""
        for row in contract_rows
    ]

    # One block per retrieved clause, with the score rounded to 3 decimals
    clause_blocks = [
        f"""
[Contract ID: {hit['contract_id']} | Score: {round(hit['score'],3)}]
{hit['text_chunk']}
-------------------------------------
"""
        for hit in vector_results
    ]

    sections = [
        "STRUCTURED CONTRACT DATA:\n",
        *metadata_blocks,
        "\nRELEVANT CONTRACT CLAUSES:\n",
        *clause_blocks,
    ]
    return "".join(sections)

In [64]:
from groq import Groq
import os
from dotenv import load_dotenv

load_dotenv()

# Initialize client once (important for performance)
client = Groq(api_key=os.getenv("GROQ_API_KEY"))

def generate_answer(user_query, context, stream=False):
    """
    Ask the LLM for a structured compliance answer grounded in ``context``.

    Uses the module-level Groq ``client``.

    Args:
        user_query: The user's question.
        context: Text produced by ``build_context``.
        stream: When truthy, print the reply incrementally as it arrives and
            return the accumulated text; otherwise return the full reply at once.

    Returns:
        The model's reply text.
    """
    system_prompt = """
You are Comp-Check Bot, a professional legal compliance assistant.

IMPORTANT RULES:
- Answer strictly based on the provided contract data and clauses.
- Do NOT hallucinate.
- If information is missing, clearly state: "Not found in the available contract records."
- Provide structured legal-style response.

Structure your answer as:
1. Executive Summary
2. Relevant Clauses
3. Risk Assessment
4. Missing Information (if any)
5. Final Compliance Status
"""

    user_message = f"""
User Query:
{user_query}

Context:
{context}
"""

    completion = client.chat.completions.create(
        model="openai/gpt-oss-20b",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ],
        temperature=0.2,  # kept low for compliance answers
        max_completion_tokens=2048,
        top_p=1,
        stream=stream,
    )

    # Non-streaming: the whole reply is available immediately.
    if not stream:
        return completion.choices[0].message.content

    # Streaming: echo each non-empty delta and collect it for the return value.
    pieces = []
    for chunk in completion:
        delta = chunk.choices[0].delta.content
        if delta:
            pieces.append(delta)
            print(delta, end="", flush=True)
    return "".join(pieces)

In [67]:
query = "What are the service agreement terms with Asiatrade Logistics?"

# Other sample queries to try:
# 1. Show summary about contract with Transcontinental Corp
# 2. Which contracts in the EU region are currently in "Pending" status and what are their risk scores?
# 3. Show me the Data Privacy related contracts in France and summarize their compliance requirements.
# 4. List all high-risk vendor agreements that failed compliance audit and explain possible reasons.
# 5.
# Step 1: Extract filters
filters = extract_filters_from_query(query)

# Step 2: Structured filtering
matching_contracts = get_contracts_dynamic(filters)

# Step 3: Vector search
vector_results = milvus_search(query, matching_contracts, top_k=5)

# Step 4: Get full rows
contract_rows = get_contract_rows_by_ids(
    [row[0] for row in matching_contracts]
)

# Step 5: Build context
context = build_context(contract_rows, vector_results)

# Step 6: Generate final answer
final_response = generate_answer(query, context)

print("\nüí° FINAL RESPONSE:\n")
print(final_response)

AttributeError: 'Groq' object has no attribute 'feature_extraction'