In [5]:
#!/usr/bin/env python3
"""
Adobe Hackathon Round 1B - PDF Document Intelligence System
Context-aware extraction and ranking of document sections using LayoutLMv3 for heading detection
"""

import json
import os
import time
import re
import logging
from datetime import datetime
from typing import List, Dict, Tuple

import pdfplumber
import nltk
import torch
import torch.nn.functional as F
import numpy as np
from transformers import BertTokenizer, BertModel, LayoutLMv3Processor, LayoutLMv3ForTokenClassification
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.cluster import KMeans
from PIL import Image
import PyPDF2

logger = logging.getLogger(__name__)
# Install a root handler when none exists. Without any handler, records fall
# through to logging's last-resort handler, which only emits WARNING and above,
# so every logger.info() call in this file would be silently dropped.
logging.basicConfig(level=logging.INFO)
# basicConfig() is a no-op when the root logger already has handlers (for
# example inside a notebook or a test runner), so set the level explicitly too.
logging.getLogger().setLevel(logging.INFO)

class DocumentAnalyzer:
    def __init__(self):
        """Load the BERT-Tiny encoder, the LayoutLMv3 tagger and the NLTK resources.

        Both models are read from local directories
        (``./pretrained_models_bert_tiny`` and ``./models/layoutlmv3``) and put
        into eval mode. NLTK tokenizer and tagger data are downloaded only when
        they are not already installed.

        Raises:
            Exception: any loading failure is logged and re-raised unchanged.
        """
        try:
            self.bert_tokenizer = BertTokenizer.from_pretrained('./pretrained_models_bert_tiny')
            self.bert_model = BertModel.from_pretrained('./pretrained_models_bert_tiny')
            self.bert_model.eval()
            logger.info("BERT-Tiny model loaded successfully")

            # apply_ocr=False: words and boxes are supplied by pdfplumber, not OCR.
            self.layoutlm_processor = LayoutLMv3Processor.from_pretrained("./models/layoutlmv3", apply_ocr=False)
            # NOTE(review): ignore_mismatched_sizes=True lets loading succeed even
            # when the checkpoint has no matching classification head. In that
            # case transformers initialises classifier.weight/bias randomly and
            # the heading predictions are not meaningful until the model has
            # been fine-tuned on a heading/body task.
            self.layoutlm_model = LayoutLMv3ForTokenClassification.from_pretrained("./models/layoutlmv3", ignore_mismatched_sizes=True)
            self.layoutlm_model.eval()
            logger.info("LayoutLMv3 model and processor loaded successfully")

            # 'punkt_tab' is the sentence-tokenizer resource loaded by the NLTK
            # releases that also use the '_eng' tagger name below; 'punkt' is
            # kept for older releases. Missing packages are fetched on demand.
            nltk_resources = (
                ('tokenizers/punkt', 'punkt'),
                ('tokenizers/punkt_tab', 'punkt_tab'),
                ('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng'),
            )
            for resource_path, package in nltk_resources:
                try:
                    nltk.data.find(resource_path)
                except LookupError:
                    logger.info(f"Downloading NLTK {package} data")
                    nltk.download(package, quiet=True)

        except Exception as e:
            logger.error(f"Error initializing models: {str(e)}")
            raise

    def extract_domain_terms(self, persona: str, job_to_be_done: str) -> List[str]:
        """Extract up to 8 domain-specific terms from the persona and job text.

        The lower-cased ``persona`` and ``job_to_be_done`` are tokenised with the
        BERT tokenizer; word-piece continuations, stop words, special tokens and
        tokens shorter than 3 characters are dropped. The remaining token
        embeddings are clustered with KMeans (at most 8 clusters) and the token
        with the highest TF-IDF weight is taken from each cluster. Terms are then
        restricted to nouns, verbs and adjectives of at least 4 characters.

        Args:
            persona: Role description of the reader.
            job_to_be_done: Task the reader wants to accomplish.

        Returns:
            An ordered, de-duplicated list of at most 8 terms. Falls back to the
            first 8 surviving tokens when the part-of-speech filter removes
            everything, and to ``[]`` when no token survives the first filter.
        """
        combined_text = f"{persona} {job_to_be_done}".lower()
        stop_words = {'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'}
        special_tokens = {'[CLS]', '[SEP]'}
        inputs = self.bert_tokenizer(combined_text, padding=True, truncation=True, max_length=128, return_tensors='pt')
        with torch.no_grad():
            outputs = self.bert_model(**inputs)
            token_embeddings = outputs.last_hidden_state[0]
            tokens = self.bert_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

        # Filter once and keep each token paired with its embedding so the two
        # sequences cannot drift out of step.
        kept = [
            (token, token_embeddings[i].numpy())
            for i, token in enumerate(tokens)
            if not token.startswith('##')
            and token not in stop_words
            and len(token) >= 3
            and token not in special_tokens
        ]
        if not kept:
            return []

        valid_tokens = [token for token, _ in kept]
        n_clusters = min(8, len(valid_tokens))
        if n_clusters < 2:
            # A single token cannot be clustered; return it as-is.
            return valid_tokens[:8]

        embeddings_array = np.array([embedding for _, embedding in kept])
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        cluster_labels = kmeans.fit_predict(embeddings_array)

        tfidf = TfidfVectorizer().fit([combined_text])
        vocab = tfidf.vocabulary_
        tfidf_scores = tfidf.transform([combined_text]).toarray()[0]

        def tfidf_weight(token: str) -> float:
            """Return the TF-IDF weight of ``token``, or 0.0 when it is not in the vocabulary."""
            column = vocab.get(token)
            # Out-of-vocabulary tokens must score 0.0; defaulting the column to 0
            # would silently hand them the weight of an unrelated word.
            return float(tfidf_scores[column]) if column is not None else 0.0

        tokens_by_cluster = {}
        for token, label in zip(valid_tokens, cluster_labels):
            tokens_by_cluster.setdefault(int(label), []).append(token)

        domain_terms = []
        for cluster_id in range(n_clusters):
            cluster_tokens = tokens_by_cluster.get(cluster_id)
            if cluster_tokens:
                # max() keeps the first of several equally weighted tokens.
                domain_terms.append(max(cluster_tokens, key=tfidf_weight))

        tagged_terms = nltk.pos_tag(domain_terms)
        domain_terms = [term for term, pos in tagged_terms if pos.startswith(('NN', 'VB', 'JJ')) and len(term) >= 4]
        domain_terms = list(dict.fromkeys(domain_terms))[:8]
        if not domain_terms:
            domain_terms = valid_tokens[:8]

        logger.info(f"Extracted domain terms: {domain_terms}")
        return domain_terms

    def extract_context_keywords(self, persona: str, job_to_be_done: str) -> List[str]:
        """Build a list of at most 10 context keywords for the persona and job.

        Alphabetic words of 3+ letters are taken from the lower-cased persona and
        job text. A non-stop-word is kept when it has at least 5 letters or
        contains one of the domain terms. If nothing qualifies, the first
        non-stop-words of 4+ letters are used instead.

        Args:
            persona: Role description of the reader.
            job_to_be_done: Task the reader wants to accomplish.

        Returns:
            De-duplicated keywords in order of first appearance (at most 10).
        """
        combined_text = f"{persona} {job_to_be_done}".lower()
        stop_words = set(['a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'])
        words = re.findall(r'\b[a-zA-Z]{3,}\b', combined_text)
        domain_terms = self.extract_domain_terms(persona, job_to_be_done)

        candidates = [word for word in words if word not in stop_words]

        # A dict preserves insertion order, giving an ordered de-duplication.
        ordered_unique = {}
        for word in candidates:
            if len(word) >= 5 or any(term in word for term in domain_terms):
                ordered_unique.setdefault(word, None)
        keywords = list(ordered_unique)[:10]

        if not keywords:
            keywords = [word for word in candidates if len(word) >= 4][:10]

        logger.info(f"Extracted context keywords: {keywords}")
        return keywords

    def is_bullet_point(self, text: str) -> bool:
        """Detect bullet points and list-like structures"""
        bullet_patterns = [
            r'^\s*[•▪▫‣⁃\u2022]\s+', r'^\s*[-*+]\s+', r'^\s*\d+[\.\)]\s+', r'^\s*[a-zA-Z][\.\)]\s+',
            r'^\s*(tip|note|warning|caution|important):?\s+'
        ]
        return any(re.match(pattern, text, re.IGNORECASE) for pattern in bullet_patterns)

    def is_valid_heading(self, text: str) -> bool:
        """Validate heading to reject fragments and bullet points"""
        text = text.strip()
        if len(text) < 3 or len(text) > 80 or len(text.split()) < 2 or len(text.split()) > 15:
            return False
        if re.match(r'.*[:,]$', text) or text.lower().startswith(('see ', 'refer ', 'check ', 'visit ')) or self.is_bullet_point(text):
            return False
        return True

    def validate_pdf(self, pdf_path: str) -> bool:
        """Check that ``pdf_path`` can be opened and parsed by PyPDF2.

        Returns:
            True when a non-strict ``PdfReader`` can be constructed from the
            file; False (after logging the error) on any exception.
        """
        try:
            with open(pdf_path, 'rb') as handle:
                PyPDF2.PdfReader(handle, strict=False)
        except Exception as e:
            logger.error(f"PDF validation failed for {pdf_path}: {str(e)}")
            return False
        return True

    def extract_text_from_pdf(self, pdf_path: str, context_keywords: List[str]) -> List[Dict]:
        """Split a PDF into titled sections using LayoutLMv3 heading predictions.

        Every text line of a page is passed to LayoutLMv3 as one "word" with its
        bounding box. A line becomes a heading when the model labels it so with
        probability >= 0.7 and it passes ``is_valid_heading``; the lines that
        follow (bullet-like lines excluded) form the section body. Sections do
        not span pages and need at least 10 words of body text. If the main
        path raises, a paragraph-based fallback is tried instead.

        Args:
            pdf_path: Path of the PDF to process.
            context_keywords: Currently unused; kept for interface compatibility.

        Returns:
            At most 15 section dicts with keys ``document``, ``page``,
            ``section_title``, ``text`` and ``level``. Sections sharing document,
            title and page are merged. Returns ``[]`` when validation fails.
        """
        sections = []
        logger.info(f"Processing PDF: {pdf_path}")

        if not self.validate_pdf(pdf_path):
            logger.warning(f"Skipping {pdf_path} due to validation failure")
            return []

        document_name = os.path.basename(pdf_path)

        def flush_section(heading, body_lines):
            """Store the open section if it has a heading and at least 10 words of body."""
            if not heading or not body_lines:
                return
            combined_text = ' '.join(body_lines).strip()
            if len(combined_text.split()) >= 10:
                sections.append({
                    'document': document_name,
                    'page': heading['page'],
                    'section_title': heading['title'][:80],
                    'text': combined_text,
                    'level': heading['level']
                })

        def to_layoutlm_coord(value, extent):
            """Scale a page coordinate to LayoutLM's 0-1000 range, clamped to stay valid."""
            return min(1000, max(0, int(1000 * value / extent)))

        try:
            with pdfplumber.open(pdf_path) as pdf:
                char_sizes = [char['size'] for page in pdf.pages for char in page.chars]
                avg_font_size = sum(char_sizes) / max(1, len(char_sizes))

                for page_num, page in enumerate(pdf.pages, 1):
                    text_lines = page.extract_text_lines(return_chars=True) or []
                    if not text_lines:
                        logger.warning(f"No text extracted from page {page_num}")
                        continue

                    words = []
                    boxes = []
                    line_indices = []
                    for idx, line in enumerate(text_lines):
                        line_text = line['text'].strip()
                        # Skip fragments and bare page numbers.
                        if len(line_text) < 3 or re.match(r'^\d+$', line_text):
                            continue
                        words.append(line_text)
                        boxes.append([
                            to_layoutlm_coord(line['x0'], page.width),
                            to_layoutlm_coord(line['top'], page.height),
                            to_layoutlm_coord(line['x1'], page.width),
                            to_layoutlm_coord(line['bottom'], page.height),
                        ])
                        line_indices.append(idx)

                    if not words:
                        continue

                    # Use pdfplumber's built-in image conversion to avoid pypdfium2.
                    # Rendered only once the page is known to have usable lines.
                    page_image = page.to_image(resolution=300).original.convert("RGB")
                    layoutlm_labels = self._label_lines_with_layoutlm(page_image, words, boxes)

                    current_section = None
                    section_text = []
                    for layoutlm_label, line_idx in zip(layoutlm_labels, line_indices):
                        line = text_lines[line_idx]
                        line_text = line['text'].strip()
                        if self.is_bullet_point(line_text):
                            continue

                        if layoutlm_label == "heading" and self.is_valid_heading(line_text):
                            flush_section(current_section, section_text)
                            chars = line.get('chars') or []
                            font_size = max((char['size'] for char in chars), default=12)
                            # Headings noticeably larger than the document average are level 1.
                            level = 1 if font_size >= avg_font_size * 1.1 else 2
                            current_section = {'title': line_text.title(), 'page': page_num, 'level': level}
                            section_text = []
                        elif current_section:
                            section_text.append(line_text)

                    # Sections never continue onto the next page.
                    flush_section(current_section, section_text)

        except Exception as e:
            logger.error(f"Error processing {pdf_path}: {str(e)}")
            # Fallback: treat the first line of long paragraphs as a heading.
            # Sections already collected before the failure are kept.
            try:
                with pdfplumber.open(pdf_path) as pdf:
                    fallback_count = 0
                    for page_num, page in enumerate(pdf.pages, 1):
                        text = page.extract_text()
                        if not (text and len(text.strip()) > 50 and fallback_count < 5):
                            continue
                        chars = page.chars or []
                        font_size = max((char['size'] for char in chars), default=12)
                        if font_size < 12:
                            continue
                        paragraphs = [p.strip() for p in text.split('\n\n') if len(p.strip()) > 20]
                        for para in paragraphs[:3]:
                            first_line = para.split('\n')[0].strip()
                            if len(first_line.split()) >= 2 and len(para.split()) >= 10 and self.is_valid_heading(first_line):
                                sections.append({
                                    'document': document_name,
                                    'page': page_num,
                                    'section_title': first_line.title()[:80],
                                    'text': para,
                                    'level': 2
                                })
                                fallback_count += 1
            except Exception as fallback_error:
                logger.error(f"Fallback extraction failed for {pdf_path}: {str(fallback_error)}")

        # Merge sections that share document, title (case-insensitive) and page.
        section_map = {}
        for section in sections:
            key = (section['document'], section['section_title'].lower(), section['page'])
            existing = section_map.get(key)
            if existing is None:
                section_map[key] = section
            else:
                existing['text'] += ' ' + section['text']
        merged_sections = list(section_map.values())

        logger.info(f"Extracted {len(merged_sections)} sections from {pdf_path}")
        return merged_sections[:15]

    def _label_lines_with_layoutlm(self, page_image, words: List[str], boxes: List[List[int]]) -> List[str]:
        """Return one "heading"/"body" label per entry of ``words``.

        LayoutLMv3 emits one logit row per sub-word token (special tokens
        included), so token positions are mapped back to their source line with
        ``encoding.word_ids`` and the first token of each line decides its label.
        A line is a heading only when class 1 is the arg-max and its probability
        is at least 0.7. Lines cut off by the 512-token truncation stay "body".

        NOTE(review): ``word_ids`` needs a fast tokenizer; with a slow tokenizer
        it raises and the caller's fallback extraction takes over.
        """
        heading_label_id = 1  # id2label: 0 -> body, 1 -> heading
        encoding = self.layoutlm_processor(page_image, words, boxes=boxes, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = self.layoutlm_model(**encoding).logits[0]
        predictions = torch.argmax(logits, dim=-1).tolist()
        heading_probs = F.softmax(logits, dim=-1)[:, heading_label_id].tolist()

        labels = ["body"] * len(words)
        labelled = set()
        for token_idx, word_idx in enumerate(encoding.word_ids(batch_index=0)):
            # None marks special/padding tokens; later sub-tokens of a line are ignored.
            if word_idx is None or word_idx in labelled:
                continue
            labelled.add(word_idx)
            if predictions[token_idx] == heading_label_id and heading_probs[token_idx] >= 0.7:
                labels[word_idx] = "heading"
        return labels

    def encode_text_bert(self, texts: List[str], batch_size: int = 16) -> np.ndarray:
        """Encode texts with mean pooling"""
        embeddings = []
        for i in range(0, len(texts), batch_size):
            batch_texts = [re.sub(r'\s+', ' ', t.strip())[:500] for t in texts[i:i + batch_size]]
            inputs = self.bert_tokenizer(batch_texts, padding=True, truncation=True, max_length=256, return_tensors='pt')
            with torch.no_grad():
                outputs = self.bert_model(**inputs)
                attention_mask = inputs['attention_mask']
                token_embeddings = outputs.last_hidden_state
                input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
                batch_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
                embeddings.extend(batch_embeddings.numpy())
        return np.array(embeddings)

    def compute_relevance_scores(self, sections: List[Dict], persona: str, job_to_be_done: str, context_keywords: List[str]) -> List[Tuple[Dict, float]]:
        """Compute relevance scores with enhanced keyword weighting"""
        if not sections:
            return []

        logger.info(f"Computing relevance scores for {len(sections)} sections")
        keywords_str = " ".join(context_keywords)
        query = f"As a {persona}, I need to {job_to_be_done}. Focus on: {keywords_str}"
        domain_terms = self.extract_domain_terms(persona, job_to_be_done)
        section_texts = [f"{section['section_title']} {section['text']}" for section in sections]
        all_texts = [query] + section_texts
        embeddings = self.encode_text_bert(all_texts, batch_size=16)
        query_embedding = embeddings[0:1]
        section_embeddings = embeddings[1:]

        similarities = cosine_similarity(query_embedding, section_embeddings)[0]
        adjusted_scores = []
        for i, (sim, section) in enumerate(zip(similarities, sections)):
            word_count = len(section['text'].split())
            length_factor = 0.5 if word_count < 20 else min(1.5, 1 + 0.3 * np.log(word_count / 20))
            title_quality = 1.5 if any(word in section['section_title'].lower() for word in ['form', 'fillable', 'onboarding', 'compliance', 'create', 'manage', 'signature']) else 1.0
            section_content = f"{section['section_title']} {section['text']}".lower()
            keyword_matches = sum(1 for keyword in context_keywords + domain_terms if keyword in section_content)
            keyword_boost = 3.0 if keyword_matches >= 3 else 2.0 if keyword_matches == 2 else 1.5 if keyword_matches == 1 else 1.0
            level_boost = 1.2 if section['level'] == 1 else 1.0 if section['level'] == 2 else 0.8
            domain_boost = 1 + sum(0.15 for term in domain_terms if term in section_content)
            adjusted_score = sim * length_factor * title_quality * keyword_boost * level_boost * domain_boost
            adjusted_scores.append(adjusted_score)

        scored_sections = list(zip(sections, adjusted_scores))
        scored_sections.sort(key=lambda x: x[1], reverse=True)
        return scored_sections[:max(7, len(scored_sections))]

    def extract_key_sentences(self, text: str, persona: str, job_to_be_done: str, context_keywords: List[str], top_k: int = 5) -> str:
        """Pick the sentences of ``text`` most relevant to the persona and job.

        Sentences need at least 15 words and 50 characters and must not end with
        ``...`` or ``,``. Each candidate is scored by cosine similarity to a
        query built from persona, job and keywords, boosted by keyword hits and
        action verbs. The best ``max(top_k, 2)`` are returned in original order.

        When fewer than two sentences qualify, up to three paragraphs of 15+
        words are retried; if that also fails, the first 400 characters of that
        paragraph text are returned. If no paragraph qualifies either, the
        original text is used so short sections do not produce an empty result.

        Args:
            text: Section body to summarise.
            persona: Role description of the reader.
            job_to_be_done: Task the reader wants to accomplish.
            context_keywords: Keywords added to the query and used for boosts.
            top_k: Desired number of sentences (at least 2 are always taken).

        Returns:
            The selected sentences joined by single spaces, or a truncated
            excerpt of the text when sentence selection is not possible.
        """
        def is_usable(sentence: str) -> bool:
            """Return True for sentences long enough and not visibly cut off."""
            return len(sentence.split()) >= 15 and len(sentence) >= 50 and not sentence.endswith(('...', ','))

        label_prefixes = ('note:', 'tip:', 'warning:', 'caution:')
        sentences = [
            s.strip() for s in nltk.sent_tokenize(text)
            if is_usable(s) and not s.lower().startswith(label_prefixes)
        ]
        if len(sentences) < 2:
            paragraphs = re.split(r'\n{2,}', text)
            combined_text = ' '.join([p.strip() for p in paragraphs if len(p.strip().split()) >= 15][:3])
            if not combined_text:
                # No paragraph reached 15 words. Keep the original text rather
                # than returning an empty refined_text for a short section.
                combined_text = text.strip()
            sentences = [s.strip() for s in nltk.sent_tokenize(combined_text) if is_usable(s)]
            if len(sentences) < 2:
                return combined_text[:400]

        logger.info(f"Extracting {top_k} key sentences from {len(sentences)} sentences")
        keywords_str = " ".join(context_keywords)
        query = f"As a {persona}, I need to {job_to_be_done}. Key aspects: {keywords_str}"
        embeddings = self.encode_text_bert([query] + sentences, batch_size=16)
        # Row 0 is the query; the remaining rows line up with ``sentences``.
        similarities = cosine_similarity(embeddings[0:1], embeddings[1:])[0]

        action_words = ('create', 'manage', 'fill', 'sign', 'distribute', 'collect')
        adjusted_similarities = []
        for sim, sentence in zip(similarities, sentences):
            sentence_lower = sentence.lower()
            keyword_boost = 1 + sum(0.25 for keyword in context_keywords if keyword in sentence_lower)
            action_boost = 0.2 if any(word in sentence_lower for word in action_words) else 0.0
            adjusted_similarities.append(sim * keyword_boost * (1 + action_boost))

        sentence_scores = list(enumerate(adjusted_similarities))
        sentence_scores.sort(key=lambda x: x[1], reverse=True)
        # Restore document order so the excerpt reads naturally.
        top_indices = sorted(idx for idx, _ in sentence_scores[:max(top_k, 2)])
        return ' '.join(sentences[i] for i in top_indices)

def load_input_config(input_dir: str) -> Tuple[str, str]:
    """Read the persona and job description from ``<input_dir>/input.json``.

    Each of ``persona`` and ``job_to_be_done`` may be a plain value or an
    object; for objects the persona is taken from ``role`` (else ``name``) and
    the job from ``task`` (else ``description``).

    Args:
        input_dir: Directory that contains ``input.json``.

    Returns:
        ``(persona, job_to_be_done)`` as strings.

    Raises:
        FileNotFoundError: ``input.json`` does not exist in ``input_dir``.
        ValueError: either value is missing or empty.
    """
    config_path = os.path.join(input_dir, "input.json")
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"input.json not found at {config_path}")

    with open(config_path, 'r', encoding='utf-8') as handle:
        config = json.load(handle)

    def resolve(value, primary_key, secondary_key):
        """Unwrap an object-style entry, preferring ``primary_key``."""
        if isinstance(value, dict):
            return value.get(primary_key) or value.get(secondary_key)
        return value

    persona = resolve(config.get('persona'), 'role', 'name')
    job_to_be_done = resolve(config.get('job_to_be_done'), 'task', 'description')

    if not (persona and job_to_be_done):
        raise ValueError("Both 'persona' and 'job_to_be_done' must be specified in input.json")
    return str(persona), str(job_to_be_done)

def find_pdf_files(pdfs_dir: str) -> List[str]:
    """List the PDF files directly inside ``pdfs_dir``.

    Only regular files whose name ends in ``.pdf`` (case-insensitive) are
    returned. Names are sorted because ``os.listdir`` gives no ordering
    guarantee, and the processing order affects the generated output.

    Args:
        pdfs_dir: Directory to scan (not recursive).

    Returns:
        Paths of the PDFs, each joined with ``pdfs_dir``, sorted by file name.

    Raises:
        FileNotFoundError: the directory does not exist or holds no PDF file.
    """
    if not os.path.exists(pdfs_dir):
        raise FileNotFoundError(f"PDFs directory not found: {pdfs_dir}")
    pdf_files = sorted(
        name for name in os.listdir(pdfs_dir)
        # isfile() rejects sub-directories that merely end in ".pdf".
        if name.lower().endswith('.pdf') and os.path.isfile(os.path.join(pdfs_dir, name))
    )
    if not pdf_files:
        raise FileNotFoundError(f"No PDF files found in {pdfs_dir}")
    return [os.path.join(pdfs_dir, name) for name in pdf_files]

def main():
    """Run the full pipeline and write ``./output/output.json``.

    Reads the persona and job from ``./input/input.json``, extracts sections
    from every PDF in ``./input/PDFs``, ranks them, extracts key sentences for
    each ranked section and writes the result as JSON.

    Returns:
        The dictionary that was written to the output file.

    Raises:
        ValueError: no section could be extracted from any document.
        Exception: any other failure is logged and re-raised.
    """
    start_time = time.time()
    try:
        input_dir = "./input"
        output_dir = "./output"
        pdfs_dir = os.path.join(input_dir, "PDFs")
        os.makedirs(output_dir, exist_ok=True)

        logger.info("Loading input configuration")
        persona, job_to_be_done = load_input_config(input_dir)
        logger.info(f"Persona: {persona}")
        logger.info(f"Job to be done: {job_to_be_done}")

        pdf_paths = find_pdf_files(pdfs_dir)
        pdf_files = list(map(os.path.basename, pdf_paths))
        logger.info(f"Processing {len(pdf_files)} documents: {pdf_files}")

        logger.info("Initializing document analyzer")
        analyzer = DocumentAnalyzer()

        context_keywords = analyzer.extract_context_keywords(persona, job_to_be_done)
        all_sections = []
        for pdf_path in pdf_paths:
            all_sections += analyzer.extract_text_from_pdf(pdf_path, context_keywords)

        if not all_sections:
            raise ValueError("No sections extracted from any document")

        logger.info(f"Total sections extracted: {len(all_sections)}")
        top_sections = analyzer.compute_relevance_scores(all_sections, persona, job_to_be_done, context_keywords)

        metadata = {
            "input_documents": pdf_files,
            "persona": persona,
            "job_to_be_done": job_to_be_done,
            "context_keywords": context_keywords,
            "processing_timestamp": datetime.now().isoformat(),
            "total_sections_processed": len(all_sections)
        }

        # Rank is the 1-based position in the score-sorted list.
        extracted_sections = []
        for rank, (section, _) in enumerate(top_sections, 1):
            extracted_sections.append({
                "document": section['document'],
                "section_title": section['section_title'],
                "importance_rank": rank,
                "page_number": section['page'],
                "heading_level": section['level']
            })

        sub_section_analysis = []
        for section, _ in top_sections:
            refined_text = analyzer.extract_key_sentences(
                section['text'], persona, job_to_be_done, context_keywords
            )
            sub_section_analysis.append({
                "document": section['document'],
                "section_title": section['section_title'],
                "refined_text": refined_text,
                "page_number": section['page']
            })

        output_data = {
            "metadata": metadata,
            "extracted_sections": extracted_sections,
            "sub_section_analysis": sub_section_analysis
        }

        output_path = os.path.join(output_dir, "output.json")
        with open(output_path, 'w', encoding='utf-8') as out_file:
            json.dump(output_data, out_file, indent=2, ensure_ascii=False)

        processing_time = time.time() - start_time
        logger.info(f"Processing completed in {processing_time:.2f} seconds")
        logger.info(f"Output saved to {output_path}")
        for rank, (section, score) in enumerate(top_sections, 1):
            logger.info(f"{rank}. {section['document']} - {section['section_title']} (Level: {section['level']}, Score: {score:.3f})")

        return output_data

    except Exception as e:
        logger.error(f"Processing failed: {str(e)}")
        raise

if __name__ == "__main__":
    main()

Some weights of LayoutLMv3ForTokenClassification were not initialized from the model checkpoint at ./models/layoutlmv3 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
No text extracted from page 10
No text extracted from page 12
