In [16]:
from sqlalchemy import create_engine, text
import os
from dotenv import load_dotenv

load_dotenv()

DB_USER = os.getenv("DB_USER")
DB_PASSWORD = os.getenv("DB_PASSWORD") # encodes special characters
DB_HOST = os.getenv("DB_HOST")
DB_PORT = os.getenv("DB_PORT")
DB_NAME = os.getenv("DB_NAME")

DATABASE_URL = f"postgresql://{DB_USER}:{DB_PASSWORD}@{DB_HOST}:{DB_PORT}/{DB_NAME}"

engine = create_engine(DATABASE_URL)

with engine.connect() as conn:
    result = conn.execute(text("SELECT * FROM contracts"))
    rows = result.fetchall()

    for row in rows:
        print(row)

(1, 'Acme Corp', 'NDA', 24, 91, 'Passed', datetime.date(2024, 1, 15), 'Germany', 'Data Privacy Policy', 'EU')
(2, 'Zenith Solutions', 'Service Agreement', 36, 74, 'Pending', datetime.date(2023, 11, 20), 'India', 'Vendor Risk Policy', 'APAC')
(3, 'GlobalTech Ltd', 'Partnership', 48, 67, 'Failed', datetime.date(2023, 8, 10), 'United States', 'Financial Compliance Policy', 'North America')
(4, 'Bright Systems', 'Vendor Agreement', 18, 87, 'Passed', datetime.date(2024, 2, 5), 'Singapore', 'Vendor Risk Policy', 'APAC')
(5, 'Nova Innovations', 'Service Agreement', 12, 94, 'Passed', datetime.date(2024, 3, 12), 'United Kingdom', 'HR Compliance Policy', 'UK')
(6, 'Helix Enterprises', 'NDA', 36, 73, 'Pending', datetime.date(2023, 6, 18), 'France', 'Data Privacy Policy', 'EU')
(7, 'Pioneer Holdings', 'Partnership', 24, 86, 'Passed', datetime.date(2024, 4, 2), 'Canada', 'Financial Compliance Policy', 'North America')
(8, 'Apex Consulting', 'Service Agreement', 18, 76, 'Pending', datetime.date(2023

In [None]:
from pymilvus import connections, utility

# Connect to Milvus
connections.connect(
    alias="default",
    host="localhost",
    port="19530"
)

print("âœ… Connected to Milvus")

# Optional: Check existing collections
collections = utility.list_collections()
print("Existing collections:", collections)


âœ… Connected to Milvus
Existing collections: []


In [13]:
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

fields = [
    FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),

    FieldSchema(name="contract_id", dtype=DataType.INT64),

    FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=1024),

    FieldSchema(name="contract_type", dtype=DataType.VARCHAR, max_length=100),
    FieldSchema(name="text_chunk", dtype=DataType.VARCHAR, max_length=5000),
]



In [14]:
schema = CollectionSchema(fields, description="Policy Clause Embeddings")

collection = Collection(
    name="legal_policy_vectors",
    schema=schema
)

index_params = {
    "metric_type": "COSINE",
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128}
}

collection.create_index(
    field_name="embedding",
    index_params=index_params
)

collection.load()

In [4]:
from sentence_transformers import SentenceTransformer

# Load BGE-M3
model = SentenceTransformer("BAAI/bge-m3")

# Confirm dimension
test_embedding = model.encode("test sentence", normalize_embeddings=True)
print("Embedding dimension:", len(test_embedding))

  from .autonotebook import tqdm as notebook_tqdm
To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development
Loading weights: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 391/391 [00:00<00:00, 732.37it/s, Materializing param=pooler.dense.weight]                               


Embedding dimension: 1024


In [11]:
def chunk_text(text, chunk_size=800, overlap=150):
    """
    Chunk text with overlap for better retrieval context.
    """
    chunks = []
    start = 0
    text_length = len(text)

    while start < text_length:
        end = start + chunk_size
        chunk = text[start:end]
        chunks.append(chunk.strip())
        start += chunk_size - overlap

    return chunks

In [15]:
import os

def insert_document(file_path, contract_id, contract_type):
    with open(file_path, "r", encoding="utf-8") as f:
        text = f.read()

    # Step 1: Chunk the text
    chunks = chunk_text(text)

    # Step 2: Generate embeddings (normalized for cosine similarity)
    embeddings = model.encode(
        chunks,
        normalize_embeddings=True
    )

    # Step 3: Prepare data in EXACT schema order
    data = [
        [contract_id] * len(chunks),        # contract_id
        embeddings.tolist(),                # embedding
        [contract_type] * len(chunks),      # contract_type
        chunks                              # text_chunk
    ]

    # Step 4: Insert into Milvus
    collection.insert(data)
    collection.flush()

    print(f"âœ… Inserted {len(chunks)} chunks for Contract ID {contract_id}")

In [7]:
import torch

def check_device():
    if torch.cuda.is_available():
        print("âœ… GPU Available")
        print("GPU Name:", torch.cuda.get_device_name(0))
        return "cuda"
    else:
        print("âš  GPU Not Available, using CPU")
        return "cpu"

device = check_device()

âš  GPU Not Available, using CPU


In [8]:
collection_name = "legal_policy_vectors"

if utility.has_collection(collection_name):
    utility.drop_collection(collection_name)
    print("âœ… Old collection dropped successfully.")
else:
    print("No existing collection found.")

âœ… Old collection dropped successfully.


In [28]:
for i in [23,30]:
    insert_document("D:\\rag_assignment\\rag_data\\rag_data\\partnership\\partnership_v3.txt",i,"Partnership")

âœ… Inserted 11 chunks for Contract ID 23
âœ… Inserted 11 chunks for Contract ID 30


In [None]:
SYSTEM_PROMPT_EXTRACT = """
You are a legal data extraction assistant.

Your task is to extract structured filters from a user query
and return ONLY valid JSON.

Available database columns:
- vendor_name (string)
- contract_type (must be one of: NDA, Service Agreement, Vendor Agreement, Partnership, General)
- compliance_score (integer)
- audit_status (Passed, Failed, Pending)
- jurisdiction (string)
- region (string)
- policy_name (string)
- duration_months (integer)
- contract_date (date)

Rules:

1. Ignore capitalization differences.
2. Only return JSON.
3. Include only fields explicitly mentioned in the query.
4. Do NOT add extra fields.
5. If nothing relevant is found, return {}.
6. Always strictly use the column names provided above.

Numeric Filtering Rules:

â€¢ If query says:
  - "above X", "greater than X", "more than X"
    â†’ use: "compliance_score_min": X

  - "below X", "less than X"
    â†’ use: "compliance_score_max": X

  - "between X and Y"
    â†’ use: "compliance_score_between": [X, Y]

â€¢ For duration:
  - "longer than X months"
    â†’ use: "duration_min": X

  - "shorter than X months"
    â†’ use: "duration_max": X

â€¢ For relative dates:
  - "last X months"
    â†’ use: "last_n_months": X

â€¢ Never output natural language.
â€¢ Never explain anything.
â€¢ Output must be valid JSON only.

Examples:

Query: Show failed vendor agreements in APAC with compliance score above 70
Output:
{
  "contract_type": "Vendor Agreement",
  "audit_status": "Failed",
  "region": "APAC",
  "compliance_score_min": 70
}

Query: Contracts between 60 and 80 score from last 3 months
Output:
{
  "compliance_score_between": [60, 80],
  "last_n_months": 3
}
"""

In [9]:
import json
from ollama import chat

def extract_filters_from_query(user_query):
    response = chat(
        model='gpt-oss:20b-cloud',
        messages=[
            {'role': 'system', 'content': SYSTEM_PROMPT_EXTRACT},
            {'role': 'user', 'content': user_query}
        ],
        options={
            "temperature": 0
        }
    )

    content = response.message.content.strip()

    # Try parsing JSON safely
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        print("âš  LLM did not return valid JSON:")
        print(content)
        return {}

In [17]:
from sqlalchemy import text
from datetime import datetime, timedelta

def get_contracts_dynamic(filters):
    base_query = "SELECT contract_id FROM contracts WHERE 1=1"
    params = {}

    # ðŸ”¹ Text Filters
    if filters.get("vendor_name"):
        base_query += " AND vendor_name ILIKE :vendor_name"
        params["vendor_name"] = f"%{filters['vendor_name']}%"

    if filters.get("contract_type"):
        base_query += " AND contract_type ILIKE :contract_type"
        params["contract_type"] = f"%{filters['contract_type']}%"

    if filters.get("audit_status"):
        base_query += " AND audit_status ILIKE :audit_status"
        params["audit_status"] = f"%{filters['audit_status']}%"

    if filters.get("region"):
        base_query += " AND region ILIKE :region"
        params["region"] = f"%{filters['region']}%"

    if filters.get("jurisdiction"):
        base_query += " AND jurisdiction ILIKE :jurisdiction"
        params["jurisdiction"] = f"%{filters['jurisdiction']}%"

    if filters.get("policy_name"):
        base_query += " AND policy_name ILIKE :policy_name"
        params["policy_name"] = f"%{filters['policy_name']}%"

    # ðŸ”¹ Compliance Score Filters
    if filters.get("compliance_score_min"):
        base_query += " AND compliance_score >= :compliance_score_min"
        params["compliance_score_min"] = filters["compliance_score_min"]

    if filters.get("compliance_score_max"):
        base_query += " AND compliance_score <= :compliance_score_max"
        params["compliance_score_max"] = filters["compliance_score_max"]

    # ðŸ”¹ Range filter (between)
    if filters.get("compliance_score_between"):
        base_query += " AND compliance_score BETWEEN :score_min AND :score_max"
        params["score_min"] = filters["compliance_score_between"][0]
        params["score_max"] = filters["compliance_score_between"][1]

    # ðŸ”¹ Duration Filters
    if filters.get("duration_min"):
        base_query += " AND duration_months >= :duration_min"
        params["duration_min"] = filters["duration_min"]

    if filters.get("duration_max"):
        base_query += " AND duration_months <= :duration_max"
        params["duration_max"] = filters["duration_max"]

    # ðŸ”¹ Relative Date Filtering (e.g., last 3 months)
    if filters.get("last_n_months"):
        months = filters["last_n_months"]
        date_threshold = datetime.today() - timedelta(days=30 * months)

        base_query += " AND contract_date >= :date_threshold"
        params["date_threshold"] = date_threshold

    with engine.connect() as conn:
        result = conn.execute(text(base_query), params)
        rows = result.fetchall()

    return rows

In [12]:
query = "Show failed contracts of Pacific Trade Co in APAC"

filters = extract_filters_from_query(query)

print(filters)

{'vendor_name': 'Pacific Trade Co', 'audit_status': 'Failed', 'region': 'APAC'}


In [19]:
def process_user_query(user_query):
    # Step 1: Extract filters using LLM
    filters = extract_filters_from_query(user_query)
    print("ðŸ”Ž Extracted Filters:", filters)

    # Step 2: Fetch contract IDs from Postgres
    contract_ids = get_contracts_dynamic(filters)
    print("ðŸ“„ Matching Contracts:", contract_ids)

    return contract_ids

query = "Show contract on 2023-11-8"
print(process_user_query(query))

ðŸ”Ž Extracted Filters: {'contract_date': '2023-11-8'}
ðŸ“„ Matching Contracts: [(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (15,), (16,), (17,), (18,), (19,), (20,), (21,), (22,), (23,), (24,), (25,), (26,), (27,), (28,), (29,), (30,)]
[(1,), (2,), (3,), (4,), (5,), (6,), (7,), (8,), (9,), (10,), (11,), (12,), (13,), (14,), (15,), (16,), (17,), (18,), (19,), (20,), (21,), (22,), (23,), (24,), (25,), (26,), (27,), (28,), (29,), (30,)]
