# Contract Metadata Extraction

The notebook demonstrates a complete **LLM based** pipeline that:

1. Extracts raw text from PDF contracts
2. Chunk the pdf if token size is greater than 25000. Since, I am using GPT-4o, it has a token limitation upto 30k.
3. Pass the prompt and chunk text to LLM model
4. LLM model generates the json output
    - If pdf size is less than max token (25000), entire pdf passed to LLM
    - If it exceeds 25,000 tokens, the text is divided into multiple chunks. The LLM produces separate JSON for each chunk, which are then de‑duplicated and merged into a single JSON payload.
5. OpenAI model **classify** the contract type along with metadata and **return strict JSON** metadata
6. Persists the result with guard‑rails for invalid output


## Imports

In [46]:
from __future__ import annotations
from pathlib import Path
import asyncio
from unidecode import unidecode
from itertools import chain
import PyPDF2
from textwrap import dedent
import logging, json, re, time, math, os, hashlib, pathlib
from pydantic import BaseModel, Field
from openai import AsyncOpenAI
from dotenv import load_dotenv
from typing import Dict, Type, Optional, List, Any, Iterable, Awaitable, Dict, Any
from importlib import import_module
from pkgutil import iter_modules
log = logging.getLogger(__name__)
load_dotenv()


OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
DEFAULT_MODEL = os.getenv("DEFAULT_MODEL")
TEMPERATURE = os.getenv("TEMPERATURE")
MAX_TOKENS_DEFAULT = int(os.getenv("TOKEN_LIMIT"))
TIKTOKEN = os.getenv("TIKTOKEN")

# Load pdf file from data directory

In [2]:
def load_pdf_text(path: Path) -> str:
    reader = PyPDF2.PdfReader(str(path))
    return "\n".join([unidecode(p.extract_text()) or "" for p in reader.pages])

# LLM calling function

- ChatMessage (Pydantic model): Typed container for each message in the prompt (role = "system", "user", or "assistant"; content = text).
- LLMClient:
    - __init__ : Stores the model name and instantiates AsyncOpenAI with OPENAI_API_KEY.
    - __aenter__ / __aexit__ : you use async with LLMClient() as client: so the underlying HTTP session is opened/closed cleanly.
    - chat_completion : Accepts a list of ChatMessage

In [3]:
class ChatMessage(BaseModel):
    role: str
    content: str

class LLMClient:
    def __init__(self, model: str = DEFAULT_MODEL):
        self.model = model
        self.client = AsyncOpenAI(api_key=OPENAI_API_KEY)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.client.close()

    async def chat_completion(self, messages: list[ChatMessage]) -> str:
        formatted_messages = [m.model_dump() for m in messages]
        response = await self.client.chat.completions.create(
            model=self.model,
            messages=formatted_messages,
            temperature=float(TEMPERATURE),
        )
        return response.choices[0].message.content

# Schema

### Schema Defination

In [4]:
class Party(BaseModel):
    role: str
    name: str

class Clause(BaseModel):
    name: str
    present: bool
    text: Optional[str] = None

class GenericContractSchema(BaseModel):
    """Fallback schema able to hold any contract."""
    contract_type: str
    parties: List[Party] = Field(default_factory=list)
    effective_date: Optional[str] = None
    termination_date: Optional[str] = None
    governing_law: Optional[str] = None
    renewal_terms: Optional[str] = None
    clauses: List[Clause] = Field(default_factory=list)
    custom_fields: Dict[str, Any] = Field(default_factory=dict)


### Employment Schema

In [5]:
class EmploymentSchema(GenericContractSchema):
    contract_type:str='employment'
    employee_name: Optional[str]=None
    employer_name: Optional[str]=None
    compensation: Optional[str]=None
    probation_period: Optional[str]=None

### NDA Schema

In [6]:
class NDASchema(GenericContractSchema):
    contract_type:str='nda'
    confidentiality_period: Optional[str]=None
    non_compete: Optional[bool]=None


### Service Schema

In [7]:
class MSASchema(GenericContractSchema):
    contract_type:str='service agreement'
    payment_terms: Optional[str]=None
    indemnification: Optional[bool]=None
    limitation_of_liability: Optional[str]=None


# Utils

In [8]:
async def gather_with_concurrency(concurrency:int,*aws:Iterable[Awaitable[Any]])->list[Any]:
    sem=asyncio.Semaphore(concurrency)
    async def wrap(coro):
        async with sem:
            return await coro
    return await asyncio.gather(*[wrap(a) for a in aws])

## Pdf Chunking

In [9]:
def _get_encoder(model: str | None = None):
    """
    Return a tiktoken encoder if available; otherwise fall back to a 4‑chars≈1‑token
    approximation.
    """
    try:
        import tiktoken
        if model:
            return tiktoken.encoding_for_model(model)
        return tiktoken.get_encoding(TIKTOKEN)
    except Exception:
        class _Approx:
            def encode(self, text: str) -> list[int]:
                return [0] * math.ceil(len(text) / 4)
        return _Approx()

In [10]:
def token_len(text: str, model: str | None = None) -> int:
    try:
        return len(_get_encoder(model).encode(text))
    except Exception:
        return math.ceil(len(text) / 4)

In [11]:
def split_by_tokens(
    text: str,
    *,
    max_tokens: int = MAX_TOKENS_DEFAULT,
    model: str | None = None,
) -> List[str]:
    """
    Split *text* into chunks whose approximate token length ≤ *max_tokens*.
    """
    enc = _get_encoder(model)
    try:
        toks = enc.encode(text)
        chunk_lists = [toks[i:i+max_tokens] for i in range(0, len(toks), max_tokens)]
        if hasattr(enc, "decode"):
            return [enc.decode(chunk) for chunk in chunk_lists]
        raise AttributeError
    except AttributeError:
        char_step = max_tokens * 4
        return [text[i:i+char_step] for i in range(0, len(text), char_step)]

## Merge LLM Output and remove duplicates

In [12]:
def _dedup_list(lst: List[Any]) -> List[Any]:
    seen = set()
    out = []
    for item in lst:
        key = json.dumps(item, sort_keys=True, default=str)
        if key not in seen:
            seen.add(key)
            out.append(item)
    return out

In [13]:
def _dedup_parties(parties: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    seen = set()
    out = []
    for party in parties:
        key = party.get("name", "").strip().lower()
        if key and key not in seen:
            seen.add(key)
            out.append(party)
    return out

In [14]:
def _dedup_clauses(clauses: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    seen = set()
    out = []
    for clause in clauses:
        key = (clause.get("name", "").strip().lower(), bool(clause.get("present")))
        if key not in seen:
            seen.add(key)
            out.append(clause)
    return out

In [15]:
def _postprocess_contract(data: Dict[str, Any]) -> Dict[str, Any]:
    """Remove duplicates inside a single contract JSON."""
    if not isinstance(data, dict):
        return data
    if "parties" in data and isinstance(data["parties"], list):
        data["parties"] = _dedup_parties(data["parties"])
    if "clauses" in data and isinstance(data["clauses"], list):
        data["clauses"] = _dedup_clauses(data["clauses"])
    return data

In [16]:
def merge_json_objects(objs: List[Dict[str, Any]]) -> Dict[str, Any]:
    if not objs:
        return {}
    result: Dict[str, Any] = {}
    for obj in objs:
        for k, v in obj.items():
            if v in (None, [], {}, "", 0):
                continue
            if k not in result or result[k] in (None, [], {}, "", 0):
                result[k] = v
            else:
                if isinstance(v, list) and isinstance(result[k], list):
                    result[k] = _dedup_list(result[k] + v)
                elif isinstance(v, dict) and isinstance(result[k], dict):
                    merged = result[k]
                    merged.update({kk: vv for kk, vv in v.items() if vv not in (None, [], {}, "", 0)})
                    result[k] = merged
    return _postprocess_contract(result)

## Clear json

In [17]:
_JSON_FENCE_RE = re.compile(r"```[a-zA-Z0-9_]*|```")

def _strip_markdown_fence(text:str)->str:
    return _JSON_FENCE_RE.sub("", text).strip()

def _remove_trailing_commas(s:str)->str:
    # commas before } or ]
    return re.sub(r",\s*([}\]])", r"\\1", s)

def _add_missing_commas(s:str)->str:
    # "next_key" -> "value"
    return re.sub(r'("\s*)("[a-zA-Z0-9_]+"\s*:)', r'", \2', s)

CLEAN_STEPS = (_strip_markdown_fence, _remove_trailing_commas, _add_missing_commas)

def fix_and_load(blob:str)->dict|None:
    """Try json.loads; if it fails run and retry."""
    try:
        return json.loads(blob)
    except Exception as e:
        cleaned = blob
        for fn in CLEAN_STEPS:
            cleaned = fn(cleaned)
        try:
            return json.loads(cleaned)
        except Exception as e2:
            log.debug("JSON repair failed: %s", e2)
            return None

# SCHEMA MAP

In [18]:
SCHEMA_MAP: Dict[str, Type[GenericContractSchema]] = {
    EmploymentSchema().contract_type.lower(): EmploymentSchema,
    NDASchema().contract_type.lower(): NDASchema,
    MSASchema().contract_type.lower(): MSASchema,
}

In [19]:
def _schema_menu_json() -> str:
    menu = {}
    for name, schema_cls in SCHEMA_MAP.items():
        menu[name] = schema_cls.schema()
    menu["generic"] = GenericContractSchema.schema()
    return json.dumps(menu, indent=2)

# Metadata Extraction

## Prompt

In [20]:
PROMPT = dedent("""You are an expert contract analyst.

**Goal**  
Return a single **valid JSON object** that follows the selected schema and
contains the best structured metadata you can extract.

### 1. Decide `contract_type`
Give the document a short label (one or two words).

### 2. Pick schema
Select the schema from the menu whose name is closest to that label.
If nothing fits, pick `generic`.

### 3. Extract
Fill every key. Use null where data is absent.  
Unknown but important extra fields → `custom_fields` (key/value).  
Detect common clauses (indemnification, confidentiality, non‑compete, etc.)
and list them in `clauses`.

### Output rules
- **Return only JSON** (no markdown, no explanations).  
- Must parse with `json.loads` on first try.

### Schema menu
{schema_menu}

### Contract
{contract_text}
### End
""")

In [21]:
def _extract_json(blob: str) -> str | None:
    """
    Extract the *largest* JSON object embedded in `blob`.

    Using the first "{" and the last "}" is a robust way to keep the whole
    object even when the model spills log‑probs or other text before / after.
    """
    start = blob.find("{")
    end = blob.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return None
    return blob[start : end + 1]

In [22]:
async def _extract_single(
    pdf: Path, text: str, client: LLMClient
) -> GenericContractSchema:
    """
    Call the LLM once for a text *chunk* and return a parsed schema instance.

    Includes automatic JSON‑repair fallback.
    """
    prompt = PROMPT.format(schema_menu=_schema_menu_json(), contract_text=text)
    messages = [ChatMessage(role="system", content=prompt)]

    start_ts = time.perf_counter()
    raw = await client.chat_completion(messages)
    latency = time.perf_counter() - start_ts

    json_blob = _extract_json(raw) or raw
    data = fix_and_load(json_blob)
    if data is None:
        log.error(
            "Unrecoverable JSON for %s (first 120 chars): %s",
            pdf.name,
            json_blob[:120],
        )
        return GenericContractSchema(contract_type="unknown", parties=[])

    schema_cls: Type[GenericContractSchema] = SCHEMA_MAP.get(
        data.get("contract_type", "").lower(), GenericContractSchema
    )
    extraction = schema_cls(**data)
    log.info("Extracted chunk for %s in %.1fs", pdf.name, latency)
    return extraction

In [23]:
async def extract_metadata(
    pdf: Path,
    client: LLMClient,
    *,
    save: bool = True,
    output_dir: Optional[Path] = None,
    max_tokens: int = MAX_TOKENS_DEFAULT,
) -> GenericContractSchema:
    """
    High‑level orchestration for a *single PDF*.

    1. Load & (if necessary) chunk the PDF text.
    2. Run `_extract_single` on each chunk (concurrently limited by caller).
    3. Merge the partial JSONs → deduplicate lists.
    4. Persist result to disk (optional).
    """
    text = load_pdf_text(pdf)
    model_name = client.model

    if token_len(text, model_name) <= max_tokens:
        final_extraction = await _extract_single(pdf, text, client)
    else:
        chunks = split_by_tokens(text, max_tokens=max_tokens, model=model_name)
        extractions: List[GenericContractSchema] = []
        for idx, chunk in enumerate(chunks):
            log.info("Processing chunk %d/%d for %s", idx + 1, len(chunks), pdf.name)
            ext = await _extract_single(pdf, chunk, client)
            extractions.append(ext)

        merged_json = merge_json_objects(
            [e.model_dump(mode="json", exclude_none=True) for e in extractions]
        )
        schema_cls: Type[GenericContractSchema] = SCHEMA_MAP.get(
            merged_json.get("contract_type", "").lower(), GenericContractSchema
        )
        final_extraction = schema_cls(**merged_json)

    if save:
        import hashlib

        if output_dir:
            output_dir.mkdir(parents=True, exist_ok=True)
            short_hash = hashlib.md5(str(pdf).encode()).hexdigest()[:6]
            json_name = f"{pdf.stem}_{short_hash}.json"
            out_path = output_dir / json_name
        else:
            out_path = pdf.with_suffix(".json")

        out_path.write_text(
            final_extraction.model_dump_json(indent=2, exclude_none=True),
            encoding="utf-8",
        )
        log.info("Wrote %s", out_path)

    return final_extraction

# Run Main

In [24]:
async def _proc(pdf:Path,client:LLMClient,*,save:bool,od:Optional[Path]):
    meta = await extract_metadata(pdf,client,save=save,output_dir=od)
    if not save:
        print(json.dumps(meta.model_dump(mode='json',exclude_none=True),indent=2))

In [63]:
async def main(dir_path:Path,concurrency:int,save:bool,output_dir:Optional[Path]):
    dir_path = Path(dir_path)
    pdfs = [f for f in dir_path.rglob('*') if (
            f.is_file()
            and f.suffix.lower() == '.pdf' 
            and f.stat().st_size > 0
        )
    ]
    if not pdfs:
        raise SystemExit('No PDFs found')
    async with LLMClient() as client:
        await gather_with_concurrency(concurrency,*(_proc(p,client,save=save,od=output_dir) for p in pdfs))

In [64]:
import nest_asyncio
nest_asyncio.apply()
try:
    asyncio.run(main("../data/full_contract_pdf",2,True,Path("output")))
except KeyboardInterrupt:
    print('Cancelled')

/var/folders/_n/rk12tlkj4_zd449jbv9n6p0m0000gn/T/ipykernel_42917/932756393.py:4: PydanticDeprecatedSince20: The `schema` method is deprecated; use `model_json_schema` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  menu[name] = schema_cls.schema()
/var/folders/_n/rk12tlkj4_zd449jbv9n6p0m0000gn/T/ipykernel_42917/932756393.py:5: PydanticDeprecatedSince20: The `schema` method is deprecated; use `model_json_schema` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  menu["generic"] = GenericContractSchema.schema()


## The output you can find inside the notebook/output folder 