<a href="https://colab.research.google.com/github/eswar3330/USEReady_ESWAR1/blob/main/USEReady_AIML_Assignment_1_Eswar_Reddy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# --- Phase 0 & 1: Project Setup & Document Text Extraction ---

print("--- Setting up environment and starting text extraction ---")

# Install necessary libraries
!sudo apt install -y tesseract-ocr
!pip install pytesseract
!pip install python-docx
!pip install pandas numpy

print("\nLibraries installed successfully.")

import pandas as pd
import numpy as np
import os
from PIL import Image
import pytesseract
import docx

# Define possible file extensions to try
POSSIBLE_EXTENSIONS = ['.docx', '.png', '.jpg', '.jpeg']

print("\nDefining text extraction functions...")

def extract_text_from_docx(docx_path):
    """Return the paragraph text of a .docx file, one paragraph per line.

    Any failure while opening or reading the file is reported on stdout and
    signalled to the caller by returning None.
    """
    try:
        paragraphs = docx.Document(docx_path).paragraphs
        return "\n".join(paragraph.text for paragraph in paragraphs)
    except Exception as e:
        print(f"Error extracting text from DOCX {os.path.basename(docx_path)}: {e}")
        return None

def extract_text_from_image(image_path):
    """Return the text Tesseract OCR recognises in an image file.

    The image is opened in a ``with`` block so its file handle is released as
    soon as OCR finishes, instead of staying open until garbage collection
    (which matters when a whole directory of scans is processed in a loop).

    Any failure while opening the image or running OCR is reported on stdout
    and signalled to the caller by returning None.
    """
    try:
        with Image.open(image_path) as img:
            return pytesseract.image_to_string(img)
    except Exception as e:
        print(f"Error extracting text from IMAGE {os.path.basename(image_path)}: {e}")
        return None

print("Text extraction functions defined.")


# Batch Process Documents and Extract Text
print("\nBatch processing documents and extracting text...")

# Expected layout: <data_dir>/train.csv, <data_dir>/test.csv and the document
# files themselves under <data_dir>/train and <data_dir>/test.
data_dir = 'data'
train_docs_dir = os.path.join(data_dir, 'train')
test_docs_dir = os.path.join(data_dir, 'test')

# Load ground truth CSVs to get file names
try:
    train_df_gt = pd.read_csv(os.path.join(data_dir, 'train.csv'))
    test_df_gt = pd.read_csv(os.path.join(data_dir, 'test.csv'))
    print("train.csv and test.csv loaded successfully.")
except FileNotFoundError:
    print(f"Error: train.csv or test.csv not found in '{data_dir}'. Please check file organization.")
    # Re-raise instead of calling exit(): inside a notebook exit() shuts the
    # kernel down (losing all state), whereas an exception just stops this cell
    # and shows which path was missing.
    raise

# Function to find the correct file path and extract text
def find_file_and_extract_text(base_file_name, directory):
    """Locate the document for ``base_file_name`` in ``directory`` and return its text.

    Each extension in POSSIBLE_EXTENSIONS is appended to the base name in turn;
    the first path that exists is handed to the matching extractor and that
    extractor's result (text, or None on failure) is returned. When no
    candidate exists, a warning is printed and None is returned.
    """
    image_extensions = ('.png', '.jpg', '.jpeg')
    for ext in POSSIBLE_EXTENSIONS:
        file_path = os.path.join(directory, base_file_name + ext)
        if not os.path.exists(file_path):
            continue
        if ext == '.docx':
            return extract_text_from_docx(file_path)
        if ext in image_extensions:
            return extract_text_from_image(file_path)
    print(f"Warning: No matching document found for '{base_file_name}' in '{directory}' with any supported extension. Skipping.")
    return None

# Process training documents
extracted_train_data = []
print(f"\nProcessing documents in '{train_docs_dir}'...")
for index, row in train_df_gt.iterrows():
    base_file_name = row['File Name']
    extracted_text = find_file_and_extract_text(base_file_name, train_docs_dir)
    extracted_train_data.append({'File Name': base_file_name, 'extracted_text': extracted_text})

train_extracted_df = pd.DataFrame(extracted_train_data)
# Filter out rows where text extraction failed (extracted_text is None)
original_train_count = len(train_extracted_df)
train_extracted_df = train_extracted_df.dropna(subset=['extracted_text'])
skipped_train_count = original_train_count - len(train_extracted_df)
if skipped_train_count > 0:
    print(f"Skipped {skipped_train_count} training documents due to missing files or extraction errors.")

print(f"Successfully extracted text from {len(train_extracted_df)} training documents.")
print("\nfile names of extracted training text:")
print(train_extracted_df)


# Process test documents
extracted_test_data = []
print(f"\nProcessing documents in '{test_docs_dir}'...")
for index, row in test_df_gt.iterrows():
    base_file_name = row['File Name']
    extracted_text = find_file_and_extract_text(base_file_name, test_docs_dir)
    extracted_test_data.append({'File Name': base_file_name, 'extracted_text': extracted_text})

test_extracted_df = pd.DataFrame(extracted_test_data)
# Filter out rows where text extraction failed (extracted_text is None)
original_test_count = len(test_extracted_df)
test_extracted_df = test_extracted_df.dropna(subset=['extracted_text'])
skipped_test_count = original_test_count - len(test_extracted_df)
if skipped_test_count > 0:
    print(f"Skipped {skipped_test_count} test documents due to missing files or extraction errors.")

print(f"Successfully extracted text from {len(test_extracted_df)} test documents.")
print("\n file names of extracted test text:")
print(test_extracted_df)

print("\n--- Document Text Extraction Complete ---")

# Store the extracted DataFrames for future phases
%store train_extracted_df
%store test_extracted_df
%store train_df_gt
%store test_df_gt
print("\nExtracted text DataFrames and ground truth CSVs stored for next phase.")

--- Setting up environment and starting text extraction ---
Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
tesseract-ocr is already the newest version (4.1.1-2.1build1).
0 upgraded, 0 newly installed, 0 to remove and 35 not upgraded.

Libraries installed successfully.

Defining text extraction functions...
Text extraction functions defined.

Batch processing documents and extracting text...
train.csv and test.csv loaded successfully.

Processing documents in 'data/train'...
Skipped 1 training documents due to missing files or extraction errors.
Successfully extracted text from 9 training documents.

file names of extracted training text:
                                           File Name  \
0  6683127-House-Rental-Contract-GERALDINE-GALINA...   
1  6683129-House-Rental-Contract-Geraldine-Galina...   
2                        18325926-Rental-Agreement-1   
4                          36199312-Rental-Agreement   
5  44737744-Maddireddy-B

In [None]:
# --- Phase 2: Data Preprocessing & Ground Truth Alignment ---

print("\n--- Starting Data Preprocessing & Ground Truth Alignment ---")

# Recall stored DataFrames from previous phase
%store -r train_extracted_df
%store -r test_extracted_df
%store -r train_df_gt
%store -r test_df_gt

print("Recalled stored DataFrames.")

# Merge Extracted Text with Ground Truth Metadata
print("\nMerging extracted text with ground truth metadata...")

# Merge training data
# Using 'File Name' as the key for merging
train_data = pd.merge(train_extracted_df, train_df_gt, on='File Name', how='inner')
print(f"Merged training data shape: {train_data.shape}")
print("\nMerged Training Data Head:")
print(train_data.head())
print("\nMerged Training Data Info (check for missing values):")
train_data.info()

# Merge testing data
test_data = pd.merge(test_extracted_df, test_df_gt, on='File Name', how='inner')
print(f"\nMerged testing data shape: {test_data.shape}")
print("\nMerged Testing Data Head:")
print(test_data.head())
print("\nMerged Testing Data Info (check for missing values):")
test_data.info()


# Define Target Fields and Create Standardized Field List
print("\nDefining target fields and standardizing list...")

# These are the fields we need to extract, as per the problem statement
TARGET_FIELDS = [
    'Aggrement Value',
    'Aggrement Start Date',
    'Aggrement End Date',
    'Renewal Notice (Days)',
    'Party One',
    'Party Two'
]

print(f"\nTarget fields to extract: {TARGET_FIELDS}")

# Verify all target columns are present in the merged DataFrames
for field in TARGET_FIELDS:
    if field not in train_data.columns:
        print(f"Error: Target field '{field}' not found in train_data DataFrame. Check CSV columns.")
        exit()
    if field not in test_data.columns:
        print(f"Error: Target field '{field}' not found in test_data DataFrame. Check CSV columns.")
        exit()
print("All target fields confirmed present in merged DataFrames.")


# Handle Missing Ground Truth Values in Training Data
print("\nHandling missing ground truth values in training data...")
# For IE, if a ground truth value is missing in the CSV, it means that field is not present
# in that particular document for training purposes. We will leave them as NaN.
# During evaluation, we'll need to handle these correctly (e.g., if GT is NaN, it's not a matchable field).

print("Missing ground truth values in train_data before further processing:")
print(train_data[TARGET_FIELDS].isnull().sum())
print("\nMissing ground truth values in test_data before further processing:")
print(test_data[TARGET_FIELDS].isnull().sum())

# Insight: If 'Renewal Notice (Days)' has NaN, it might mean it's not applicable for some contracts.
# We need to decide how to handle this during training/evaluation.
# For exact match, an NaN in GT means we can't get a 'True' match for it.

print("\n--- Data Preprocessing & Ground Truth Alignment Complete ---")

# Store the merged DataFrames for the next phase
%store train_data
%store test_data
%store TARGET_FIELDS
print("Merged data and target fields stored for next phase.")


--- Starting Data Preprocessing & Ground Truth Alignment ---
Recalled stored DataFrames.

Merging extracted text with ground truth metadata...
Merged training data shape: (9, 8)

Merged Training Data Head:
                                           File Name  \
0  6683127-House-Rental-Contract-GERALDINE-GALINA...   
1  6683129-House-Rental-Contract-Geraldine-Galina...   
2                        18325926-Rental-Agreement-1   
3                          36199312-Rental-Agreement   
4  44737744-Maddireddy-Bhargava-Reddy-Rental-Agre...   

                                      extracted_text  Aggrement Value  \
0  House Rental Contract\nKNOWN ALL MEN BY THESE ...             6500   
1  \n\n\n\n\nHouse Rental Contract\nKNOWN ALL MEN...             6500   
2  \n\n\n\n\n\n\n\n\n\nRENTAL AGREEMENT\nThis dee...             4000   
3  RENEWAL OF RENTAL AGREEMENT\n\nThis AGREEMENT ...             3800   
4  RENTfft\tENT\nThis Rental Agreement is made an...             3000   

  Aggrement Start

In [None]:
# --- Phase 3: ML-based Information Extraction (Setup + Pseudo-Annotation) ---

print("\n--- Starting ML-based Information Extraction (Phase 3) ---")

# Install required libraries
print("\nInstalling required libraries...")
!pip install transformers datasets accelerate evaluate seqeval -q
!pip install fuzzywuzzy python-Levenshtein -q

print("Libraries installed successfully.")

# Load previously stored data
%store -r train_data
%store -r test_data

print("Train and test data loaded.")

# Fields we want to extract
TARGET_FIELDS = [
    'Aggrement Value',
    'Aggrement Start Date',
    'Aggrement End Date',
    'Renewal Notice (Days)',
    'Party One',
    'Party Two'
]
print(f"Target fields: {TARGET_FIELDS}")

# NER tag setup
NER_TAG_PREFIXES = {
    'Aggrement Value': 'AGREEMENT_VALUE',
    'Aggrement Start Date': 'AGREEMENT_START_DATE',
    'Aggrement End Date': 'AGREEMENT_END_DATE',
    'Renewal Notice (Days)': 'RENEWAL_NOTICE_DAYS',
    'Party One': 'PARTY_ONE',
    'Party Two': 'PARTY_TWO'
}

all_ner_tags = ['O']
for prefix in NER_TAG_PREFIXES.values():
    all_ner_tags.append(f"B-{prefix}")
    all_ner_tags.append(f"I-{prefix}")

tag_to_id = {tag: i for i, tag in enumerate(all_ner_tags)}
id_to_tag = {i: tag for tag, i in tag_to_id.items()}

print(f"NER tag list prepared: {all_ner_tags}")

# Imports for processing
from datetime import datetime
import re
from fuzzywuzzy import fuzz, process
import unicodedata

# Helper to add ordinal suffix to a number
def ordinal(n):
    """Return ``n`` followed by its English ordinal suffix (1st, 2nd, 3rd, 4th, ...)."""
    # Numbers whose last two digits fall in 10..20 always take 'th'
    # (11th, 12th, 13th), overriding the last-digit rule.
    if 10 <= n % 100 <= 20:
        suffix = 'th'
    else:
        suffix = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
    return f"{n}{suffix}"

# Clean and expand ground truth values for matching
def normalize_gt_for_matching(field_name, value):
    """Expand one ground-truth value into the strings to look for in a document.

    field_name: CSV column name; the substrings 'Date', 'Value'/'Days' and
        'Party' (checked in that order) select which extra variants are built.
    value: the ground-truth cell (string, number or NaN).

    Returns a list of distinct, lower-case, whitespace-normalised candidate
    strings in a deterministic order (first produced, first listed). Returns
    an empty list when the value is missing or blank.

    Variants produced:
    * every field: the value itself, whitespace-collapsed;
    * non-numeric fields: an ASCII, punctuation-free form, with and without
      spaces;
    * date fields: the date rendered in several numeric and written-out
      layouts, including ordinal day forms ("20th day of may 2008");
    * value/days fields: the number as written, without thousands separators,
      as float text and (if whole) as integer text, plus currency-prefixed
      or "day(s)"-suffixed forms;
    * party fields: each side of an "and/or" separately.
    """
    if pd.isna(value) or str(value).strip() == '':
        return []

    value_str = str(value).strip()

    def ascii_alnum(text):
        # Drop accents and punctuation, collapse whitespace, lower-case.
        ascii_text = unicodedata.normalize('NFKD', text).encode('ascii', 'ignore').decode('utf-8')
        alnum_only = re.sub(r'[^a-zA-Z0-9\s]', '', ascii_text)
        return re.sub(r'\s+', ' ', alnum_only).strip().lower()

    is_date = 'Date' in field_name
    is_numeric = not is_date and ('Value' in field_name or 'Days' in field_name)
    is_party = not is_date and not is_numeric and 'Party' in field_name

    possible_matches = [re.sub(r'\s+', ' ', value_str).strip().lower()]

    if not is_numeric:
        # Not applied to numbers: stripping punctuation would turn '15.0' into
        # '150', a different quantity that could match unrelated text.
        cleaned = ascii_alnum(value_str)
        possible_matches.append(cleaned)
        possible_matches.append(cleaned.replace(' ', ''))

    if is_date:
        # First layout that parses wins, so ambiguous values such as
        # 01/02/2008 are read day-first.
        input_formats = (
            '%d.%m.%Y', '%d/%m/%Y', '%Y-%m-%d', '%m/%d/%Y',
            '%d-%b-%Y', '%B %d, %Y', '%d %B %Y', '%b %d, %Y',
            '%d %B, %Y', '%d %b, %Y', '%Y/%m/%d', '%Y.%m.%d',
        )
        dt_obj = None
        for fmt in input_formats:
            try:
                dt_obj = datetime.strptime(value_str, fmt)
                break
            except ValueError:
                continue

        if dt_obj is not None:
            output_formats = (
                '%d.%m.%Y', '%d/%m/%Y', '%Y-%m-%d', '%m/%d/%Y',
                '%B %d, %Y', '%d %B %Y', '%b %d, %Y', '%d %b %Y',
                '%Y/%m/%d', '%Y.%m.%d',
            )
            for fmt in output_formats:
                try:
                    possible_matches.append(dt_obj.strftime(fmt).lower())
                except ValueError:
                    # Best effort: skip a layout the platform cannot render.
                    pass

            day_with_suffix = ordinal(dt_obj.day)
            month_full = dt_obj.strftime('%B')
            month_abbr = dt_obj.strftime('%b')
            year = dt_obj.year

            possible_matches.extend([
                f"{day_with_suffix} day of {month_full} {year}".lower(),
                f"{day_with_suffix} day of {month_abbr} {year}".lower(),
                f"{day_with_suffix} of {month_full} {year}".lower(),
                f"{month_full} {dt_obj.day}, {year}".lower(),
                f"{month_abbr} {dt_obj.day}, {year}".lower(),
                f"{month_full} {day_with_suffix}, {year}".lower(),
            ])

    elif is_numeric:
        # Take the first number as written. Anchoring on a digit keeps the
        # period of an abbreviation such as "Rs." from being read as a
        # decimal point ("Rs. 6,500" must not become 0.65).
        number = re.search(r'\d+(?:,\d+)*(?:\.\d+)?', value_str)
        if number:
            as_written = number.group(0)
            plain = as_written.replace(',', '')
            float_val = float(plain)
            numeric_forms = [as_written, plain, str(float_val).lower()]
            if float_val.is_integer():
                numeric_forms.append(str(int(float_val)))
            numeric_forms = list(dict.fromkeys(numeric_forms))
            possible_matches.extend(numeric_forms)

            if 'Value' in field_name:
                for form in numeric_forms:
                    possible_matches.extend([
                        'rs. ' + form,
                        'rs ' + form,
                        'p ' + form,
                        'inr ' + form,
                        '₹ ' + form,
                    ])
            if 'Days' in field_name:
                for form in numeric_forms:
                    possible_matches.extend([form + ' days', form + ' day'])

    elif is_party:
        if "and/or" in value_str.lower():
            for part in value_str.lower().split("and/or"):
                possible_matches.append(ascii_alnum(part))

    # Order-preserving de-duplication keeps the result reproducible between
    # runs (a set would reorder equal-length candidates unpredictably).
    normalized = (re.sub(r'\s+', ' ', s).strip() for s in possible_matches)
    return [s for s in dict.fromkeys(normalized) if s]


def _find_free_span(text, search_str, claimed_spans):
    """Return (start, end) of the first usable occurrence of search_str in text, or None.

    An occurrence is usable when it does not overlap any span in
    claimed_spans and is not glued to a neighbouring character of the same
    kind: a candidate starting/ending with a digit may not touch another
    digit, one starting/ending with a letter may not touch another letter.
    That stops '15' from matching inside '2015' while still finding '6500'
    in 'P6500.00'.
    """
    def guard(ch, template):
        if ch.isdigit():
            return template % r'\d'
        if ch.isalpha():
            return template % r'[A-Za-z]'
        return ''

    pattern = re.compile(
        guard(search_str[0], r'(?<!%s)') + re.escape(search_str) + guard(search_str[-1], r'(?!%s)'),
        re.IGNORECASE,
    )
    for match in pattern.finditer(text):
        start, end = match.span()
        if all(end <= taken_start or start >= taken_end for taken_start, taken_end in claimed_spans):
            return start, end
    return None


def pseudo_annotate_document_refined(document_text, ground_truth_row, ner_tag_prefixes, tag_to_id, fuzzy_match_threshold=85):
    """Locate each ground-truth field value in the document text.

    document_text: extracted text of one document (None/empty is allowed).
    ground_truth_row: mapping (dict or pandas Series) from field name to the
        ground-truth value; must contain 'File Name' for the not-found message.
    ner_tag_prefixes: field name -> entity name used to build the 'B-' label.
    tag_to_id, fuzzy_match_threshold: accepted for interface compatibility;
        not used by the matching below.

    For every field with a non-empty value, the candidate strings from
    normalize_gt_for_matching are tried longest first; matching is
    case-insensitive and performed directly on the original text, so the
    returned offsets always index into document_text. A span already taken by
    an earlier field is not reused, so annotations never overlap.

    Returns a list of dicts with keys 'start', 'end', 'label', 'value' and
    'found_text', sorted by 'start'. Fields that cannot be located are
    reported on stdout and omitted.
    """
    text_original = document_text if document_text else ""

    claimed_spans = []
    found = []

    for field, tag_prefix in ner_tag_prefixes.items():
        gt_value = ground_truth_row.get(field)
        if pd.isna(gt_value) or str(gt_value).strip() == '':
            continue

        original_gt_str = str(gt_value).strip()
        search_strings = normalize_gt_for_matching(field, original_gt_str)

        span = None
        # Longest candidate first: the most specific rendering of the value
        # is the least likely to hit unrelated text.
        for search_str in sorted(search_strings, key=len, reverse=True):
            if not search_str:
                continue
            span = _find_free_span(text_original, search_str, claimed_spans)
            if span is not None:
                break

        if span is None:
            print(f"  INFO: '{original_gt_str}' not found for '{field}' in file '{ground_truth_row['File Name']}'")
            continue

        start_idx, end_idx = span
        claimed_spans.append(span)
        found.append({
            'start': start_idx,
            'end': end_idx,
            'label': f"B-{tag_prefix}",
            'value': original_gt_str,
            'found_text': text_original[start_idx:end_idx]
        })

    return sorted(found, key=lambda x: x['start'])


# Run pseudo-annotation on training data
train_annotated_data = []
for index, row in train_data.iterrows():
    annotations = pseudo_annotate_document_refined(row['extracted_text'], row, NER_TAG_PREFIXES, tag_to_id)
    train_annotated_data.append({
        'id': str(index),
        'document_text': row['extracted_text'],
        'annotations': annotations,
        'original_row': row.to_dict()
    })

print(f"Annotated {len(train_annotated_data)} documents.")

# Show an example annotation
if train_annotated_data:
    first_doc = train_annotated_data[0]
    snippet = first_doc['document_text'][:500].replace('\n', ' ')
    print(f"\nFile: {first_doc['original_row']['File Name']}")
    print(f"Text Preview: {snippet}...")
    print("Annotations:")
    if first_doc['annotations']:
        for ann in first_doc['annotations']:
            context_start = max(0, ann['start'] - 20)
            context_end = min(len(first_doc['document_text']), ann['end'] + 20)
            context = first_doc['document_text'][context_start:context_end].replace('\n', ' ')
            print(f"  Label: {ann['label']}, Value: '{ann['value']}' (Found: '{ann['found_text']}', Context: '...{context}...')")
    else:
        print("  No annotations found.")
else:
    print("No training data available.")

print("\n--- Pseudo-Annotation Complete. Ready for tokenization ---")

# Save for next step
%store train_annotated_data
%store all_ner_tags
%store tag_to_id
%store id_to_tag



--- Starting ML-based Information Extraction (Phase 3) ---

Installing required libraries...
Libraries installed successfully.
Train and test data loaded.
Target fields: ['Aggrement Value', 'Aggrement Start Date', 'Aggrement End Date', 'Renewal Notice (Days)', 'Party One', 'Party Two']
NER tag list prepared: ['O', 'B-AGREEMENT_VALUE', 'I-AGREEMENT_VALUE', 'B-AGREEMENT_START_DATE', 'I-AGREEMENT_START_DATE', 'B-AGREEMENT_END_DATE', 'I-AGREEMENT_END_DATE', 'B-RENEWAL_NOTICE_DAYS', 'I-RENEWAL_NOTICE_DAYS', 'B-PARTY_ONE', 'I-PARTY_ONE', 'B-PARTY_TWO', 'I-PARTY_TWO']
  INFO: '6500' not found for 'Aggrement Value' in file '6683127-House-Rental-Contract-GERALDINE-GALINATO-v2-Page-1'
  INFO: '20.05.2008' not found for 'Aggrement End Date' in file '6683127-House-Rental-Contract-GERALDINE-GALINATO-v2-Page-1'
  INFO: '15.0' not found for 'Renewal Notice (Days)' in file '6683127-House-Rental-Contract-GERALDINE-GALINATO-v2-Page-1'
  INFO: 'Antonio Levy S. Ingles, Jr. and/or Mary Rose C. Ingles' not

In [None]:
# --- Phase 3: ML-based Information Extraction (Tokenization & Model Fine-tuning) ---
# This cell restores the pseudo-annotated training data and tag mappings saved
# by the previous cell, imports the Hugging Face tooling and loads the tokenizer.

print("\n--- Starting Tokenization and Model Fine-tuning ---")

# Load all necessary data from previous step
# (%store -r restores variables saved earlier with %store, so this cell can
# run without re-executing the annotation cell in the same session.)
%store -r train_annotated_data
%store -r all_ner_tags
%store -r tag_to_id
%store -r id_to_tag
%store -r test_data

print("Loaded annotated training data and tag mappings.")

# Step 3.3: Prepare dataset for Hugging Face Transformers
print("\nPreparing dataset for Hugging Face Transformers...")

from datasets import Dataset
from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification
)
import torch
import warnings
# Silence FutureWarning messages for the remainder of the session.
warnings.filterwarnings("ignore", category=FutureWarning)

# Base checkpoint used for both the tokenizer (here) and the token
# classification model loaded further down; both must come from the same
# checkpoint.
MODEL_CHECKPOINT = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)

# Convert annotated documents to tokens and aligned NER tags
def get_tokens_and_labels_for_hf(annotated_doc, tag_to_id_map):
    """Turn one annotated document into whitespace tokens with per-token tag ids.

    annotated_doc: dict with 'document_text' (str) and 'annotations', a list
        of dicts carrying character offsets 'start'/'end' and a 'B-...' label.
    tag_to_id_map: tag name -> integer id; must contain 'O', every annotation
        label and, for multi-word entities, the matching 'I-...' tag.

    Text outside annotations is split on whitespace and tagged 'O'. Inside an
    annotation the first word gets the 'B-' tag and the following words the
    'I-' tag. Annotations are processed in order of 'start'; one that begins
    before the previous annotation ended is skipped, because emitting it would
    duplicate words already produced. The caller's annotation list is left
    untouched.

    Returns {"tokens": [...], "ner_tags": [...]} or None when the document has
    no text, no annotations, or yields no tokens.
    """
    text = annotated_doc['document_text']
    annotations = annotated_doc['annotations']
    if not text or not annotations:
        return None

    outside_id = tag_to_id_map['O']
    tokens, labels = [], []
    current_char_idx = 0

    # sorted() rather than list.sort(): the list belongs to the caller.
    for ann in sorted(annotations, key=lambda x: x['start']):
        if ann['start'] < current_char_idx:
            # Overlaps text that has already been tokenized.
            continue

        outside_words = text[current_char_idx:ann['start']].split()
        tokens.extend(outside_words)
        labels.extend([outside_id] * len(outside_words))

        entity_words = text[ann['start']:ann['end']].split()
        current_char_idx = ann['end']
        if not entity_words:
            continue

        tokens.append(entity_words[0])
        labels.append(tag_to_id_map[ann['label']])
        if len(entity_words) > 1:
            inside_id = tag_to_id_map[ann['label'].replace('B-', 'I-')]
            tokens.extend(entity_words[1:])
            labels.extend([inside_id] * (len(entity_words) - 1))

    trailing_words = text[current_char_idx:].split()
    tokens.extend(trailing_words)
    labels.extend([outside_id] * len(trailing_words))

    return {"tokens": tokens, "ner_tags": labels} if tokens else None

# Process all annotated examples, dropping documents for which the converter
# returned nothing (no text or no annotations).
processed_train_examples = [
    example
    for example in (get_tokens_and_labels_for_hf(doc, tag_to_id) for doc in train_annotated_data)
    if example
]

# Build Hugging Face dataset
train_hf_dataset = Dataset.from_list(processed_train_examples)
print(f"Processed {len(processed_train_examples)} training examples.")

# Show sample
if not train_hf_dataset:
    print("No examples to show.")
else:
    sample = train_hf_dataset[0]
    readable_labels = [id_to_tag[i] for i in sample['ner_tags']]
    print("\nSample tokenized example:")
    print(f"Tokens: {sample['tokens']}")
    print(f"Labels: {readable_labels}")

# Align labels with tokens (handles subword splitting)
def tokenize_and_align_labels(examples):
    """Tokenize a batch of pre-split documents and expand word labels to sub-word tokens.

    examples: batch with "tokens" (list of word lists) and "ner_tags" (list of
    tag-id lists, one id per word).

    Positions with no source word get -100. The first sub-word of a word
    receives the word's tag; later sub-words of the same word receive that
    tag too, except that a 'B-' tag is continued as the corresponding 'I-'
    tag. Returns the tokenizer output with an added "labels" entry.
    """
    encoded = tokenizer(
        examples["tokens"],
        truncation=True,
        is_split_into_words=True,
    )

    aligned_labels = []
    for batch_idx, word_labels in enumerate(examples["ner_tags"]):
        label_ids = []
        last_word = None
        for word_id in encoded.word_ids(batch_index=batch_idx):
            if word_id is None:
                label_ids.append(-100)
            elif word_id != last_word:
                # First sub-word of a new word: take the word's tag as is.
                label_ids.append(word_labels[word_id])
            else:
                # Continuation sub-word: a 'B-' tag becomes its 'I-' twin.
                word_tag = id_to_tag[word_labels[word_id]]
                if word_tag.startswith("B-"):
                    label_ids.append(tag_to_id[word_tag.replace("B-", "I-")])
                else:
                    label_ids.append(word_labels[word_id])
            last_word = word_id
        aligned_labels.append(label_ids)

    encoded["labels"] = aligned_labels
    return encoded

# Tokenize the dataset
# map(..., batched=True) hands tokenize_and_align_labels whole batches; the raw
# word/tag columns are dropped afterwards so only the tokenizer outputs and
# the aligned "labels" column remain.
tokenized_train_dataset = train_hf_dataset.map(tokenize_and_align_labels, batched=True)
tokenized_train_dataset = tokenized_train_dataset.remove_columns(['tokens', 'ner_tags'])

print("\nDataset tokenized and aligned.")
print(f"Input IDs: {tokenized_train_dataset[0]['input_ids']}")
# -100 marks positions that carry no label (shown as IGNORE).
print(f"Labels: {[id_to_tag[i] if i != -100 else 'IGNORE' for i in tokenized_train_dataset[0]['labels']]}")

# Step 3.4: Load model and prepare for fine-tuning
print("\nLoading pre-trained model and setting up training...")

# Token-classification head sized to the NER tag list; the id<->tag mappings
# are stored in the model config.
model = AutoModelForTokenClassification.from_pretrained(
    MODEL_CHECKPOINT,
    num_labels=len(all_ner_tags),
    id2label=id_to_tag,
    label2id=tag_to_id
)

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

# Define training arguments
# No evaluation is configured (eval_strategy="no") and no eval dataset is
# passed to the Trainer below, so the run trains on the training set only.
# NOTE(review): save_strategy="epoch" together with num_train_epochs=10
# presumably writes a checkpoint under ./results after every epoch — confirm
# that the disk usage is acceptable.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    num_train_epochs=10,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    save_strategy="epoch",
    load_best_model_at_end=False,
    push_to_hub=False,
    disable_tqdm=False,
    # Mixed precision is requested only when a CUDA device is available.
    fp16=torch.cuda.is_available(),
)

# Trainer handles training loop
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Start training
print("\nFine-tuning model (this might take some time)...")
trainer.train()

print("\n--- Fine-tuning Complete ---")

# Step 3.5: Save the model and tokenizer
# Both are written to the same directory.
print("\nSaving the fine-tuned model and tokenizer...")
output_model_dir = "./fine_tuned_ner_model"
model.save_pretrained(output_model_dir)
tokenizer.save_pretrained(output_model_dir)
print(f"Model saved to '{output_model_dir}'")

print("\n--- Phase 3 (Tokenization & Model Fine-tuning) Complete ---")

# Store path for next use
%store output_model_dir


In [None]:
# --- Diagnostic: Inspecting Pseudo-Annotated Training Data ---

print("\n--- Inspecting Pseudo-Annotated Training Data ---")

# Load necessary variables from previous steps
%store -r train_annotated_data
%store -r id_to_tag
%store -r tag_to_id

# Re-define tag prefixes in case not loaded
NER_TAG_PREFIXES = {
    'Aggrement Value': 'AGREEMENT_VALUE',
    'Aggrement Start Date': 'AGREEMENT_START_DATE',
    'Aggrement End Date': 'AGREEMENT_END_DATE',
    'Renewal Notice (Days)': 'RENEWAL_NOTICE_DAYS',
    'Party One': 'PARTY_ONE',
    'Party Two': 'PARTY_TWO'
}

import pandas as pd

if 'train_annotated_data' not in locals():
    print("Error: Data not found. Please run Phase 3 Part 1 to generate annotations.")
else:
    print(f"Documents available: {len(train_annotated_data)}")

    total_annotations = sum(len(doc['annotations']) for doc in train_annotated_data)
    print(f"Total annotations generated: {total_annotations}")

    # Summarize annotation counts per document
    print("\n--- Annotations per Document (Summary) ---")
    summary = []
    for doc in train_annotated_data:
        file_name = doc['original_row']['File Name']
        annotation_count = len(doc['annotations'])

        field_count = {prefix: 0 for prefix in set(NER_TAG_PREFIXES.values())}
        for ann in doc['annotations']:
            if ann['label'].startswith('B-'):
                prefix = ann['label'][2:]
                if prefix in field_count:
                    field_count[prefix] += 1

        summary.append({
            'File Name': file_name,
            'Total Annotations': annotation_count,
            'Field Counts': field_count
        })

    print(pd.DataFrame(summary).to_string())

    # Show detailed annotations for first 3 documents with data
    print("\n--- Detailed Annotations (First 3 Non-Empty Documents) ---")
    shown = 0
    for doc in train_annotated_data:
        if shown >= 3:
            break
        if doc['annotations']:
            print(f"\nDocument: {doc['original_row']['File Name']} (ID: {doc['id']})")
            snippet = doc['document_text'][:200].replace('\n', ' ')
            print(f"Text Snippet: {snippet}...")
            print("Annotations:")
            for ann in doc['annotations']:
                s = max(0, ann['start'] - 15)
                e = min(len(doc['document_text']), ann['end'] + 15)
                context = doc['document_text'][s:e].replace('\n', ' ')
                print(f"  - Label: {ann['label']}, Value: '{ann['value']}' (Found: '{ann['found_text']}', Context: '...{context}...')")
            shown += 1

    if shown == 0:
        print("No documents with annotations available to show.")

print("\n--- Inspection Complete ---")



--- Inspecting Pseudo-Annotated Training Data ---
Documents available: 9
Total annotations generated: 20

--- Annotations per Document (Summary) ---
                                                    File Name  Total Annotations                                                                                                                          Field Counts
0  6683127-House-Rental-Contract-GERALDINE-GALINATO-v2-Page-1                  1  {'AGREEMENT_VALUE': 0, 'AGREEMENT_END_DATE': 0, 'PARTY_ONE': 0, 'RENEWAL_NOTICE_DAYS': 0, 'AGREEMENT_START_DATE': 1, 'PARTY_TWO': 0}
1         6683129-House-Rental-Contract-Geraldine-Galinato-v2                  3  {'AGREEMENT_VALUE': 0, 'AGREEMENT_END_DATE': 0, 'PARTY_ONE': 0, 'RENEWAL_NOTICE_DAYS': 1, 'AGREEMENT_START_DATE': 1, 'PARTY_TWO': 1}
2                                 18325926-Rental-Agreement-1                  2  {'AGREEMENT_VALUE': 0, 'AGREEMENT_END_DATE': 0, 'PARTY_ONE': 1, 'RENEWAL_NOTICE_DAYS': 0, 'AGREEMENT_START_DATE': 0, 'PARTY

In [None]:
# --- Phase 4: Metadata Extraction using Gemini API ---

print("\n--- Starting Metadata Extraction using Gemini API ---")

# 4.1 Install and Configure Google Generative AI SDK
print("\n4.1 Installing Google Generative AI SDK...")
!pip install -q google-generativeai

import google.generativeai as genai
import os
import json
import re
import pandas as pd

# Load required data from previous phases
%store -r test_data
%store -r TARGET_FIELDS
print("Recalled test data and target fields.")

# Set your Gemini API key here (replace with your actual key)
API_KEY = "AIzaSyB6OAtSRJ0aGeKmLGVo4BTMhpPRsmCXFs4"
genai.configure(api_key=API_KEY)

# Initialize the Gemini model
GEMINI_MODEL_NAME = "models/gemini-2.5-pro"
model_gemini = genai.GenerativeModel(GEMINI_MODEL_NAME)
print(f"Gemini model '{GEMINI_MODEL_NAME}' initialized.")

# 4.2 Define function to call Gemini API and parse output
print("\n4.2 Defining Gemini API call and parsing function...")

def _canonical_key(name):
    """Reduce a field/key name to lowercase alphanumerics so spelling variants compare equal."""
    return re.sub(r"[^a-z0-9]", "", str(name).lower())


def extract_metadata_with_gemini(document_text, model, target_fields):
    """
    Extract metadata fields from document text using a Gemini model.

    Args:
        document_text: Text of the document. Anything that is not a string, or
            that has fewer than 50 non-padding characters, is treated as
            "nothing to extract" and no API call is made.
        model: Object exposing ``generate_content(prompt, generation_config=...)``
            whose result has a ``.text`` attribute (e.g. ``genai.GenerativeModel``).
        target_fields: Names of the fields to extract; they become the keys of
            the returned dictionary.

    Returns:
        dict mapping every name in ``target_fields`` to the extracted value as a
        stripped string, or ``None`` when the field is missing, null or empty.
        On any API or parsing failure the error is printed and every field is
        ``None`` (best-effort: this function does not raise for those cases).
    """
    empty_result = {field: None for field in target_fields}

    # Guard against None / NaN / other non-text inputs as well as near-empty text.
    if not isinstance(document_text, str) or len(document_text.strip()) < 50:
        return empty_result

    fields_str = ", ".join(f'"{field}"' for field in target_fields)
    prompt = f"""
    You are an expert metadata extractor. Extract the following fields from the document below:
    {fields_str}

    Return the result as a JSON object. Use null if a field is not found.

    Document:
    ---
    {document_text}
    ---

    JSON Output:
    """

    try:
        # A plain dict is an accepted form of generation config; temperature 0
        # keeps the extraction as deterministic as the service allows.
        response = model.generate_content(
            prompt, generation_config={"temperature": 0.0, "top_p": 1.0}
        )
        response_text = response.text.strip()

        # The model often wraps its answer in a markdown fence, with or without
        # a "json" language tag; unwrap it when present.
        fenced = re.search(
            r"```(?:json)?\s*(.*)```", response_text, re.DOTALL | re.IGNORECASE
        )
        json_str = fenced.group(1) if fenced else response_text

        extracted_data = json.loads(json_str)
        if not isinstance(extracted_data, dict):
            print(
                f"Unexpected JSON shape (document starts with): '{document_text[:100]}...': "
                f"expected an object, got {type(extracted_data).__name__}"
            )
            return empty_result

        final_output = {}
        for field in target_fields:
            val = extracted_data.get(field)
            if val is None:
                # Fall back to the first non-null key that matches once case,
                # spaces and punctuation are ignored (e.g. "renewal_notice_days").
                wanted = _canonical_key(field)
                val = next(
                    (
                        v
                        for k, v in extracted_data.items()
                        if v is not None and _canonical_key(k) == wanted
                    ),
                    None,
                )

            # Compare against None explicitly so legitimate falsy values such
            # as 0 are kept instead of being reported as "not found".
            text = None if val is None else str(val).strip()
            final_output[field] = text if text and text.lower() != "null" else None

        return final_output

    except json.JSONDecodeError as e:
        print(f"JSON parse error (document starts with): '{document_text[:100]}...': {e}")
        return empty_result

    except Exception as e:
        # Top-level boundary for one document: report and continue with the batch.
        print(f"Gemini error (document starts with): '{document_text[:100]}...': {e}")
        return empty_result

# 4.3 Run metadata extraction for all test documents
print("\n4.3 Generating predictions for test documents using Gemini API...")

gemini_predicted_metadata_list = []
for index, row in test_data.iterrows():
    file_name = row['File Name']
    extracted_text = row['extracted_text']

    print(f"  Processing '{file_name}'...")
    prediction = extract_metadata_with_gemini(extracted_text, model_gemini, TARGET_FIELDS)
    prediction['File Name'] = file_name
    gemini_predicted_metadata_list.append(prediction)

# Convert predictions to DataFrame
gemini_predictions_df = pd.DataFrame(gemini_predicted_metadata_list)
print("\nGemini Predictions Preview:")
print(gemini_predictions_df.head())

print("\n--- Gemini API Metadata Extraction Complete ---")

# Save predictions for evaluation phase
%store gemini_predictions_df



--- Starting Metadata Extraction using Gemini API ---

4.1 Installing Google Generative AI SDK...
Recalled test data and target fields.
Gemini model 'models/gemini-2.5-pro' initialized.

4.2 Defining Gemini API call and parsing function...

4.3 Generating predictions for test documents using Gemini API...
  Processing '24158401-Rental-Agreement'...
  Processing '95980236-Rental-Agreement'...
  Processing '156155545-Rental-Agreement-Kns-Home'...
  Processing '228094620-Rental-Agreement'...

Gemini Predictions Preview:
  Aggrement Value Aggrement Start Date Aggrement End Date  \
0           12000           2008-04-01         2009-03-31   
1            9000           2010-04-01         2011-02-28   
2           12000           2012-12-15         2013-11-14   
3         15000.0         July 7, 2013       June 6, 2014   

  Renewal Notice (Days)           Party One  \
0                    60       Sri Hanumaiah   
1                  None   Mrs. S.Sakunthala   
2                    30      

In [None]:
# --- Phase 5: Evaluation & Final Conclusion ---

print("\n--- Starting Evaluation & Conclusion ---")

# 5.1 Load Required Data
print("\n5.1 Recalling test data, ground truth, and predictions...")
%store -r test_data
%store -r gemini_predictions_df
%store -r TARGET_FIELDS

# Load ground truth from CSV
import pandas as pd
import os

try:
    test_df_gt = pd.read_csv(os.path.join("data", "test.csv"))
    print("Ground truth file 'test.csv' loaded successfully.")
except FileNotFoundError:
    print("Error: 'test.csv' not found in the 'data/' directory.")
    exit()

print("All data loaded for evaluation.")

# 5.2 Normalize Values for Matching
print("\n5.2 Normalizing values for comparison...")

from datetime import datetime
import re

def normalize_value_for_comparison(field_name, value):
    """
    Convert a ground-truth or predicted value to a canonical, comparable form.

    The normalization applied depends on substrings of ``field_name``:

    * contains ``"Date"``: returns an ISO ``YYYY-MM-DD`` string when the value
      matches one of the known formats (ordinal suffixes such as "1st" are
      tolerated); otherwise the lowercased, whitespace-collapsed text.
    * contains ``"Value"`` or ``"Days"``: returns the first number found in
      the text as a ``float`` (thousands separators removed), or ``None``
      when the text holds no number.
    * contains ``"Party"``: returns the lowercased name with punctuation
      removed and whitespace collapsed.
    * anything else: lowercased, whitespace-collapsed text.

    Args:
        field_name: Name of the metadata field the value belongs to.
        value: Raw value (string, number, ``None`` or NaN).

    Returns:
        The normalized value, or ``None`` when ``value`` is missing
        (``None``, NaN, or blank text).
    """
    import math

    # Empty CSV cells are loaded by pandas as float NaN; treat them as missing
    # rather than as the literal text "nan".
    if value is None or (isinstance(value, float) and math.isnan(value)):
        return None

    value_str = str(value).strip()
    if not value_str:
        return None

    # Handle dates
    if "Date" in field_name:
        # NOTE: for ambiguous numeric dates the day-first formats win because
        # they are listed before the month-first ones. Both sides of a
        # comparison go through the same order, so results stay consistent.
        date_formats = (
            '%Y-%m-%d', '%d.%m.%Y', '%d/%m/%Y', '%m/%d/%Y',
            '%B %d, %Y', '%d %B %Y', '%b %d, %Y', '%d %b %Y',
            '%Y/%m/%d', '%Y.%m.%d',
        )
        without_ordinals = re.sub(
            r'(\d+)(st|nd|rd|th)', r'\1', value_str, flags=re.IGNORECASE
        )
        # Try the raw text first, then the ordinal-free variant (skipped when identical).
        for candidate in dict.fromkeys((value_str, without_ordinals)):
            for fmt in date_formats:
                try:
                    return datetime.strptime(candidate, fmt).strftime('%Y-%m-%d')
                except ValueError:
                    continue
        return re.sub(r'\s+', ' ', value_str).lower()

    # Handle numbers
    if "Value" in field_name or "Days" in field_name:
        # Pick out the first numeric token instead of stripping every non-digit
        # character: stripping keeps stray dots, so "Rs. 12,000" used to become
        # ".12000" and compare as 0.12.
        number = re.search(r'\d+(?:\.\d+)?|\.\d+', value_str.replace(',', ''))
        return float(number.group()) if number else None

    # Handle party names
    if "Party" in field_name:
        cleaned = re.sub(r'[^a-zA-Z0-9\s]', '', value_str)
        return re.sub(r'\s+', ' ', cleaned).strip().lower()

    # Default normalization
    return re.sub(r'\s+', ' ', value_str).strip().lower()

print("Normalization function ready.")

# 5.3 Score the Gemini predictions against the ground truth
print("\n5.3 Evaluating Gemini predictions...")

# Per-field tallies; recall_scores is kept as a (currently unfilled) placeholder
recall_scores = {}
evaluation_results = {
    field: dict.fromkeys(("true_matches", "false_matches", "not_extracted_in_gt"), 0)
    for field in TARGET_FIELDS
}

# Join ground truth and predictions on the file name; columns present in both
# frames receive the "_gt" / "_pred" suffixes
evaluation_df = test_df_gt.merge(
    gemini_predictions_df,
    on="File Name",
    how="inner",
    suffixes=("_gt", "_pred"),
)

print(f"Documents evaluated: {len(evaluation_df)}")

# Classify every (document, field) pair into exactly one tally bucket
for _, row in evaluation_df.iterrows():
    for field in TARGET_FIELDS:
        norm_gt = normalize_value_for_comparison(field, row.get(f"{field}_gt"))
        norm_pred = normalize_value_for_comparison(field, row.get(f"{field}_pred"))

        if norm_gt is None:
            # Nothing to recall for this field in this document
            outcome = "not_extracted_in_gt"
        elif norm_pred is not None and norm_gt == norm_pred:
            outcome = "true_matches"
        else:
            outcome = "false_matches"
        evaluation_results[field][outcome] += 1

# Report recall = true matches / fields that have a ground-truth value
print("\nPer-Field Recall Scores:")
for field, scores in evaluation_results.items():
    true = scores["true_matches"]
    false = scores["false_matches"]
    total = true + false

    if total == 0:
        print(f"- {field}: N/A (No ground truth available)")
        continue

    recall = true / total
    print(f"- {field}: {recall:.2f} (True: {true}, False: {false}, GT Total: {total})")

print("\n--- Evaluation Complete ---")

# 5.4 Assemble the final write-up (README-style markdown) and print it
print("\n--- Preparing Final Conclusion for README.md ---")

# Static part that precedes the per-field scores
report_header = """
# AI/ML System for Metadata Extraction from Documents

## 1. Problem Statement
Develop an AI/ML system to extract the following metadata fields from documents:

- Aggrement Value
- Aggrement Start Date
- Aggrement End Date
- Renewal Notice (Days)
- Party One
- Party Two

Manual rules were not allowed for the extraction system.

## 2. Approach Overview

### Phase 1: Text Extraction
- `.docx`: Extracted via `python-docx`
- Images (`.png`, `.jpg`): OCR via `pytesseract`
- Robust handling for missing files

### Phase 2: Dataset Alignment
- Extracted text aligned with `train.csv` and `test.csv` to build structured datasets

### Phase 3: BERT Fine-tuning (NER)
- Used `bert-base-uncased` for token classification
- Failed to learn effectively due to very limited data
- Model defaulted to predicting all 'O' (outside) tags

### Phase 4: Gemini API (LLM-based Extraction)
- Sent full document text to Gemini API with prompt for JSON extraction
- Output parsed and evaluated for accuracy
- Bypassed need for large labeled training sets

## 3. Evaluation Results

Per-field recall (True / (True + False)):

"""

# Static part that follows the per-field scores
report_footer = """
## 4. Instructions to Run

1. Clone repo and open the notebook in Google Colab
2. Upload `train.csv` and `test.csv` into the `data/` folder
3. Add document files to `data/train/` and `data/test/`
4. Replace the placeholder with your Gemini API key
5. Run cells sequentially

## 5. Future Enhancements

- Human-annotated training set for better fine-tuning
- Advanced weak labeling for pseudo-annotations
- Combine Gemini API with classical NLP models
- Confidence scores per field
- Optional FastAPI wrapper for RESTful use

"""

# One bullet per field, built from the tallies gathered during evaluation
score_lines = []
for field, results in evaluation_results.items():
    true = results['true_matches']
    false = results['false_matches']
    total = true + false

    if total == 0:
        score_lines.append(f"- {field}: N/A (No GT available)\n")
        continue

    recall = true / total
    score_lines.append(f"- {field}: {recall:.2f} (True: {true}, False: {false})\n")

final_conclusion_text = report_header + "".join(score_lines) + report_footer

print(final_conclusion_text)
print("\n--- Phase 5 Complete ---")



--- Starting Evaluation & Conclusion ---

5.1 Recalling test data, ground truth, and predictions...
Ground truth file 'test.csv' loaded successfully.
All data loaded for evaluation.

5.2 Normalizing values for comparison...
Normalization function ready.

5.3 Evaluating Gemini predictions...
Documents evaluated: 4

Per-Field Recall Scores:
- Aggrement Value: 1.00 (True: 4, False: 0, GT Total: 4)
- Aggrement Start Date: 1.00 (True: 4, False: 0, GT Total: 4)
- Aggrement End Date: 0.75 (True: 3, False: 1, GT Total: 4)
- Renewal Notice (Days): 0.50 (True: 2, False: 2, GT Total: 4)
- Party One: 0.25 (True: 1, False: 3, GT Total: 4)
- Party Two: 0.00 (True: 0, False: 4, GT Total: 4)

--- Evaluation Complete ---

--- Preparing Final Conclusion for README.md ---

# AI/ML System for Metadata Extraction from Documents

## 1. Problem Statement
Develop an AI/ML system to extract the following metadata fields from documents:

- Aggrement Value
- Aggrement Start Date
- Aggrement End Date
- Renewal N