In [1]:
import os

# Single on-disk location for every Hugging Face download made by this notebook.
cache_dir = '/scratch/hakeem.at/Queryable-Shared-Reference-Repository/notebooks/pretrained_models'

# Point all three cache-related environment variables at the same directory.
os.environ.update({
    env_name: cache_dir
    for env_name in ('HF_HOME', 'TRANSFORMERS_CACHE', 'HUGGINGFACE_HUB_CACHE')
})


In [2]:
import warnings
import logging
import os
import sys
from contextlib import contextmanager

# Silence Python warnings in this process and, via the environment variable,
# in any process that inherits it.
warnings.filterwarnings("ignore")
os.environ['PYTHONWARNINGS'] = 'ignore'

# Raise the threshold of the known-noisy loggers to ERROR.
for noisy_name in (
    "docling",
    "docling.backend",
    "docling.datamodel",
    "docling_parse",
    "PIL",
    "pdfplumber",
    "pdfminer",
):
    logging.getLogger(noisy_name).setLevel(logging.ERROR)

# Any logger registered so far whose name mentions docling or pdf is pushed
# further up to CRITICAL (this includes most of the loggers configured above).
pdf_related_loggers = [
    registered
    for registered in logging.Logger.manager.loggerDict
    if 'docling' in registered.lower() or 'pdf' in registered.lower()
]
for logger_name in pdf_related_loggers:
    logging.getLogger(logger_name).setLevel(logging.CRITICAL)

@contextmanager
def suppress_stdout_stderr():
    """Temporarily send everything written to sys.stdout / sys.stderr to os.devnull.

    Both streams are swapped for one shared handle on the null device while the
    ``with`` body runs. The previous stream objects are put back and the handle
    is closed when the body finishes, whether it returns or raises.
    """
    with open(os.devnull, 'w') as sink:
        saved_streams = (sys.stdout, sys.stderr)
        sys.stdout = sys.stderr = sink
        try:
            yield
        finally:
            # Restore before the handle is closed by the enclosing `with`.
            sys.stdout, sys.stderr = saved_streams

from docling.document_converter import DocumentConverter, PdfFormatOption


In [3]:
# Smoke test for suppress_stdout_stderr: the print is written to os.devnull,
# so this cell is expected to display nothing.
with suppress_stdout_stderr():
    print("hi")

In [4]:
import os
import time
from collections import Counter
from pathlib import Path
import numpy as np
import pandas as pd
import json
import random
import re
from tqdm.auto import tqdm
from pydantic import BaseModel, Field

from langchain_community.document_loaders import PyPDFLoader, BSHTMLLoader, Docx2txtLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter, TokenTextSplitter
from langchain_experimental.text_splitter import SemanticChunker
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS

import torch
import torch.nn.functional as F
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, AutoModel
# from adapters import AutoAdapterModel
from multiprocessing import Pool, cpu_count

In [5]:
from docling.document_converter import DocumentConverter, PdfFormatOption
from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.backend.docling_parse_v2_backend  import DoclingParseV2DocumentBackend
from docling_core.types.doc.document import DoclingDocument
from docling.chunking import HybridChunker
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.transforms.chunker.tokenizer.huggingface import HuggingFaceTokenizer
from docling_core.types.doc.labels import DocItemLabel
import multiprocessing as mp

In [6]:
import mistune

In [7]:
# PDF pipeline settings handed to the DocumentConverter further down;
# the OCR flag is switched off.
pipeline_options = PdfPipelineOptions()
pipeline_options.do_ocr = False

In [8]:
import torch
import gc

def clear_gpu(*items):
    """Run garbage collection and release cached CUDA memory.

    Args:
        *items: Objects the caller no longer needs. Only this function's own
            references to them are dropped; the caller must also ``del`` (or
            rebind) its own names for the objects to become collectable.

    Returns:
        None.

    The CUDA calls are skipped when no CUDA device is available, so the helper
    is safe to call on a CPU-only machine.
    """
    # Drops the local tuple only -- it cannot remove references held by the caller.
    del items
    # Collect first so memory freed by the collector can be returned by empty_cache().
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

In [9]:
# Seed every RNG this notebook draws from so a fresh Restart & Run All
# reproduces the same results, including the expansion-type assignment,
# which is made with np.random.shuffle.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Optional: uncomment to additionally opt in to deterministic cuDNN behaviour.
# if torch.cuda.is_available():
#     torch.backends.cudnn.deterministic = True

In [51]:
# PDF inputs are parsed with the DoclingParseV2 backend using the pipeline
# options configured earlier in the notebook.
pdf_format_option = PdfFormatOption(
    pipeline_options=pipeline_options,
    backend=DoclingParseV2DocumentBackend,
)
converter = DocumentConverter(format_options={InputFormat.PDF: pdf_format_option})

# Chunker shared by the chunk-building cells below.
chunker = HybridChunker(merge_peers=True, max_tokens=1024)

In [49]:
# Re-load previously converted documents: one JSON-serialised DoclingDocument per line.
output_file = "docling_processed_text.jsonl"
loaded_docs = []
# Explicit UTF-8 so decoding does not depend on the platform's default encoding.
with open(output_file, "r", encoding="utf-8") as f:
    for line in f:
        # Tolerate blank lines (e.g. a trailing newline at the end of the file).
        if not line.strip():
            continue
        doc = DoclingDocument.model_validate(json.loads(line))
        loaded_docs.append(doc)

In [55]:
from collections import defaultdict

# filename -> {first 200 characters of a chunk -> (chunk index, full chunk text)}
chunk_lookup = defaultdict(dict)

min_chunk_length = 100
skip_headings = ['reference','references', 'bibliography', 'works cited', 'citations']
for doc_idx, loaded_doc in enumerate(tqdm(loaded_docs, total = len(loaded_docs))):
    doc_chunks = list(chunker.chunk(dl_doc = loaded_doc))
    for chunk_idx, chunk in enumerate(doc_chunks):
        body = chunk.text.strip()
        # Drop chunks that are too short to be useful.
        if len(body) <= min_chunk_length:
            continue
        # Drop chunks that sit under a bibliography-style heading.
        headings = getattr(chunk.meta, "headings", None)
        if headings and any(heading.lower() in skip_headings for heading in headings):
            continue
        # `body` is longer than min_chunk_length here, so it is never empty.
        # A later chunk with the same 200-character prefix overwrites an earlier one.
        chunk_lookup[loaded_doc.origin.filename][body[:200].strip()] = (chunk_idx, body)

  0%|          | 0/249 [00:00<?, ?it/s]

In [25]:
# Maximum number of evaluation rows kept after filtering.
max_dp = 900
# NOTE(review): the file has a .jsonl extension but is read without lines=True --
# confirm the on-disk format before changing this call.
responses_raw = pd.read_json("prompt_thresholding_responses.jsonl")
# Keep non-borderline questions that were answered with the explicit "I don't know" prompt.
keep_mask = (responses_raw['question_type'] != 'borderline') & (responses_raw['prompt_type'] == "explicit_idk")
# .copy() gives an independent frame, so adding columns to `df` later does not
# write into a slice of the raw data.
df = responses_raw[keep_mask].iloc[:max_dp].copy()
df.head()

Unnamed: 0,question,question_type,context,source,ground_truth,prompt_type,raw_response
1500,What is the DOI for the Supporting Information...,answerable,The Supporting Information is available free o...,Holzlechner et al. - 2017 - In Situ Characteri...,The Supporting Information is available free o...,explicit_idk,The DOI for the Supporting Information is 10.1...
1502,What software was used to process the MALDI MS...,unanswerable,The Supporting Information is available free o...,Holzlechner et al. - 2017 - In Situ Characteri...,,explicit_idk,I don't know.
1503,What method was used to measure local variance...,answerable,Table /1 the derivative of the sum of local co...,Cachier and Pennec - 2000 - 3D non-rigid regis...,cal variance using the difference of mean rela...,explicit_idk,The method used to measure local variance usin...
1505,What specific programming language was used to...,unanswerable,Table /1 the derivative of the sum of local co...,Cachier and Pennec - 2000 - 3D non-rigid regis...,,explicit_idk,I don't know.
1506,What cell lines were obtained from American Ty...,answerable,For details about the chemicals used in the st...,Zhang et al. - 2023 - Single-cell lipidomics e...,"The commercial human cell lines, including pan...",explicit_idk,


In [56]:
# Sanity check: every dataframe context should match a chunk (by its first
# 200 characters) in the lookup built for the same source file.
matched = 0
unmatched_examples = []

for idx, row in tqdm(df.iterrows(), total=len(df)):
    filename = row['source']
    context = row['context'][:200].strip()

    # .get() keeps the defaultdict from growing entries for unknown files.
    if context in chunk_lookup.get(filename, {}):
        matched += 1
        continue

    unmatched_examples.append({
        'idx': idx,
        'filename': filename,
        'context_len': len(context),
        'context_preview': context[:100]
    })

print(f"Exact match rate: {matched}/{len(df)} ({matched*100/len(df):.1f}%)")

# Show one failing row plus a few candidate keys to help diagnose mismatches.
if unmatched_examples:
    print(f"\nFirst unmatched example:")
    ex = unmatched_examples[0]
    print(f"Filename: {ex['filename']}")
    print(f"Context preview: {ex['context_preview']}")

    if ex['filename'] in chunk_lookup:
        candidate_keys = chunk_lookup[ex['filename']]
        print(f"Filename found, {len(candidate_keys)} chunks available")
        for key in list(candidate_keys)[:3]:
            print(f"Sample chunk start: {key[:100]}")

  0%|          | 0/900 [00:00<?, ?it/s]

Exact match rate: 900/900 (100.0%)


In [61]:
from transformers import AutoTokenizer
from collections import defaultdict
import numpy as np

# Qwen/Qwen3-8B tokenizer, stored under cache_dir; the cells below use it to
# count tokens per chunk.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B", cache_dir=cache_dir)

In [64]:
def expand_top(filename, current_idx, target_tokens, *, chunk_store=None, index_store=None):
    """Pad a chunk with the chunks that precede it until a token target is reached.

    The anchor chunk ends up last in the returned text, with earlier chunks
    placed in front of it in ascending index order. If the document does not
    have enough preceding chunks, the collected ones are repeated until the
    target is met (repeats of the same chunk end up adjacent after sorting).
    If nothing precedes the anchor, it is returned unexpanded.

    Args:
        filename: Document key used to look up the chunks.
        current_idx: Index of the anchor chunk within that document.
        target_tokens: Desired minimum token count of the result.
        chunk_store: Optional mapping ``filename -> {chunk_idx: (num_tokens, text)}``.
            Defaults to the notebook-level ``chunk_data``.
        index_store: Optional mapping ``filename -> sorted list of chunk indices``.
            Defaults to the notebook-level ``chunk_indices``.

    Returns:
        ``(text, token_count)``. ``token_count`` is the sum of the stored
        per-chunk counts; the blank-line separators are not counted, so
        re-tokenising the text can give a slightly different number.

    Raises:
        KeyError: ``current_idx`` is not stored for ``filename``.
        ValueError: ``current_idx`` is missing from the index list.
    """
    doc_chunks = (chunk_data if chunk_store is None else chunk_store)[filename]
    current_tokens, current_text = doc_chunks[current_idx]

    if current_tokens >= target_tokens:
        return current_text, current_tokens

    sorted_idx = (chunk_indices if index_store is None else index_store)[filename]
    current_pos = sorted_idx.index(current_idx)
    needed_tokens = target_tokens - current_tokens

    # Walk backwards from the anchor, nearest chunk first.
    expanded_chunks = []
    accumulated_tokens = 0
    for idx in reversed(sorted_idx[:current_pos]):
        if accumulated_tokens >= needed_tokens:
            break
        num_tokens, text = doc_chunks[idx]
        expanded_chunks.append((idx, text, num_tokens))
        accumulated_tokens += num_tokens

    # Not enough material: repeat what was collected. The `> 0` guard prevents
    # an endless loop when every collected chunk reports zero tokens.
    if accumulated_tokens > 0:
        while accumulated_tokens < needed_tokens:
            for entry in list(expanded_chunks):
                if accumulated_tokens >= needed_tokens:
                    break
                expanded_chunks.append(entry)
                accumulated_tokens += entry[2]

    # Restore document order for the padding.
    expanded_chunks.sort(key=lambda entry: entry[0])

    padding_text = "\n\n".join(text for _, text, _ in expanded_chunks)
    final_text = padding_text + "\n\n" + current_text if padding_text else current_text

    return final_text, accumulated_tokens + current_tokens


def expand_bottom(filename, current_idx, target_tokens, *, chunk_store=None, index_store=None):
    """Pad a chunk with the chunks that follow it until a token target is reached.

    The anchor chunk comes first in the returned text, followed by later chunks
    in the order they were collected. If the document does not have enough
    following chunks, the collected ones are repeated until the target is met.
    If nothing follows the anchor, it is returned unexpanded.

    Args:
        filename: Document key used to look up the chunks.
        current_idx: Index of the anchor chunk within that document.
        target_tokens: Desired minimum token count of the result.
        chunk_store: Optional mapping ``filename -> {chunk_idx: (num_tokens, text)}``.
            Defaults to the notebook-level ``chunk_data``.
        index_store: Optional mapping ``filename -> sorted list of chunk indices``.
            Defaults to the notebook-level ``chunk_indices``.

    Returns:
        ``(text, token_count)``. ``token_count`` is the sum of the stored
        per-chunk counts; the blank-line separators are not counted, so
        re-tokenising the text can give a slightly different number.

    Raises:
        KeyError: ``current_idx`` is not stored for ``filename``.
        ValueError: ``current_idx`` is missing from the index list.
    """
    doc_chunks = (chunk_data if chunk_store is None else chunk_store)[filename]
    current_tokens, current_text = doc_chunks[current_idx]

    if current_tokens >= target_tokens:
        return current_text, current_tokens

    sorted_idx = (chunk_indices if index_store is None else index_store)[filename]
    current_pos = sorted_idx.index(current_idx)
    needed_tokens = target_tokens - current_tokens

    # Walk forwards from the anchor, nearest chunk first.
    expanded_chunks = []
    accumulated_tokens = 0
    for idx in sorted_idx[current_pos + 1:]:
        if accumulated_tokens >= needed_tokens:
            break
        num_tokens, text = doc_chunks[idx]
        expanded_chunks.append((idx, text, num_tokens))
        accumulated_tokens += num_tokens

    # Not enough material: repeat what was collected. The `> 0` guard prevents
    # an endless loop when every collected chunk reports zero tokens.
    if accumulated_tokens > 0:
        while accumulated_tokens < needed_tokens:
            for entry in list(expanded_chunks):
                if accumulated_tokens >= needed_tokens:
                    break
                expanded_chunks.append(entry)
                accumulated_tokens += entry[2]

    padding_text = "\n\n".join(text for _, text, _ in expanded_chunks)
    final_text = current_text + "\n\n" + padding_text if padding_text else current_text

    return final_text, accumulated_tokens + current_tokens


def expand_middle(filename, current_idx, target_tokens, *, chunk_store=None, index_store=None):
    """Pad a chunk on both sides with its neighbours until a token target is reached.

    The anchor chunk is surrounded by earlier chunks (in ascending index order)
    and later chunks. Padding is gathered in three passes:

    1. Take neighbours on each side until that side holds half of the missing tokens.
    2. If still short (one side ran out, or rounding), keep taking from the
       remaining earlier chunks, then from the remaining later chunks.
    3. If the document is exhausted, repeat the collected neighbours, each on
       its own side, until the target is met.

    If the anchor has no neighbours, it is returned unexpanded.

    Args:
        filename: Document key used to look up the chunks.
        current_idx: Index of the anchor chunk within that document.
        target_tokens: Desired minimum token count of the result.
        chunk_store: Optional mapping ``filename -> {chunk_idx: (num_tokens, text)}``.
            Defaults to the notebook-level ``chunk_data``.
        index_store: Optional mapping ``filename -> sorted list of chunk indices``.
            Defaults to the notebook-level ``chunk_indices``.

    Returns:
        ``(text, token_count)``. ``token_count`` is the sum of the stored
        per-chunk counts; the blank-line separators are not counted, so
        re-tokenising the text can give a slightly different number.

    Raises:
        KeyError: ``current_idx`` is not stored for ``filename``.
        ValueError: ``current_idx`` is missing from the index list.
    """
    doc_chunks = (chunk_data if chunk_store is None else chunk_store)[filename]
    current_tokens, current_text = doc_chunks[current_idx]

    if current_tokens >= target_tokens:
        return current_text, current_tokens

    sorted_idx = (chunk_indices if index_store is None else index_store)[filename]
    current_pos = sorted_idx.index(current_idx)

    # Both lists are ordered nearest-to-anchor first.
    indices_before = list(reversed(sorted_idx[:current_pos]))
    indices_after = sorted_idx[current_pos + 1:]

    needed_tokens = target_tokens - current_tokens
    needed_per_side = needed_tokens // 2

    before_chunks, after_chunks = [], []
    before_tokens = after_tokens = 0

    # Pass 1: fill each side up to half of the missing tokens.
    for idx in indices_before:
        if before_tokens >= needed_per_side:
            break
        num_tokens, text = doc_chunks[idx]
        before_chunks.append((idx, text, num_tokens))
        before_tokens += num_tokens

    for idx in indices_after:
        if after_tokens >= needed_per_side:
            break
        num_tokens, text = doc_chunks[idx]
        after_chunks.append((idx, text, num_tokens))
        after_tokens += num_tokens

    # Pass 2: still short -- continue with unused neighbours, earlier side first.
    if before_tokens + after_tokens < needed_tokens:
        for idx in indices_before[len(before_chunks):]:
            if before_tokens + after_tokens >= needed_tokens:
                break
            num_tokens, text = doc_chunks[idx]
            before_chunks.append((idx, text, num_tokens))
            before_tokens += num_tokens

        for idx in indices_after[len(after_chunks):]:
            if before_tokens + after_tokens >= needed_tokens:
                break
            num_tokens, text = doc_chunks[idx]
            after_chunks.append((idx, text, num_tokens))
            after_tokens += num_tokens

    # Pass 3: document exhausted -- repeat the collected neighbours. The `0 <`
    # guard prevents an endless loop when they contribute no tokens at all.
    collected_tokens = before_tokens + after_tokens
    if 0 < collected_tokens < needed_tokens:
        neighbours = before_chunks + after_chunks
        while before_tokens + after_tokens < needed_tokens:
            for idx, text, num_tokens in neighbours:
                if before_tokens + after_tokens >= needed_tokens:
                    break
                if idx < current_idx:
                    before_chunks.append((idx, text, num_tokens))
                    before_tokens += num_tokens
                else:
                    after_chunks.append((idx, text, num_tokens))
                    after_tokens += num_tokens

    # Earlier chunks were gathered nearest-first; put them back in document order.
    before_chunks.sort(key=lambda entry: entry[0])

    parts = []
    if before_chunks:
        parts.append("\n\n".join(text for _, text, _ in before_chunks))
    parts.append(current_text)
    if after_chunks:
        parts.append("\n\n".join(text for _, text, _ in after_chunks))

    final_text = "\n\n".join(parts)
    return final_text, before_tokens + current_tokens + after_tokens



In [None]:
# Context window size, in tokens, that the fill percentages are measured against.
MAX_CONTEXT = 32768

# Fractions of the window to fill, and the token budget each one translates to
# (truncated to an integer).
CONTEXT_PERCENTAGES = [0.10, 0.25, 0.50, 0.75, 0.95]
TARGET_LENGTHS = dict(
    zip(CONTEXT_PERCENTAGES, (int(MAX_CONTEXT * share) for share in CONTEXT_PERCENTAGES))
)
print("Target lengths:", TARGET_LENGTHS)
# Expected: {0.1: 3276, 0.25: 8192, 0.5: 16384, 0.75: 24576, 0.95: 31129}

# filename -> {chunk index -> (token count, chunk text)}
chunk_data = defaultdict(dict)
# filename -> ascending list of the chunk indices that were kept
chunk_indices = defaultdict(list)

min_chunk_length = 100
skip_headings = ['reference', 'references', 'bibliography', 'works cited', 'citations']

for doc_idx, loaded_doc in enumerate(tqdm(loaded_docs, total=len(loaded_docs))):
    filename = loaded_doc.origin.filename
    doc_chunks = list(chunker.chunk(dl_doc=loaded_doc))
    kept_for_doc = chunk_data[filename]

    for chunk_idx, chunk in enumerate(doc_chunks):
        text = chunk.text.strip()
        # Drop chunks that are too short to be useful.
        if len(text) <= min_chunk_length:
            continue

        # Drop chunks that sit under a bibliography-style heading.
        headings = getattr(chunk.meta, "headings", None)
        if headings and any(heading.lower() in skip_headings for heading in headings):
            continue

        # Token counts come from the Qwen tokenizer, without special tokens.
        token_count = len(tokenizer.encode(text, add_special_tokens=False))
        kept_for_doc[chunk_idx] = (token_count, text)

    chunk_indices[filename] = sorted(kept_for_doc)

# dataframe index -> (filename, chunk index) of the chunk whose first 200
# characters equal the row's context prefix; the first match wins.
df_to_chunk = {}

for idx, row in tqdm(df.iterrows(), total=len(df)):
    filename = row['source']
    context_key = row['context'][:200].strip()

    matching_idx = next(
        (
            candidate_idx
            for candidate_idx, (_, candidate_text) in chunk_data[filename].items()
            if candidate_text[:200].strip() == context_key
        ),
        None,
    )
    if matching_idx is not None:
        df_to_chunk[idx] = (filename, matching_idx)

print(f"Mapped {len(df_to_chunk)}/{len(df)} rows to chunk indices")

# Expansion strategy name -> function implementing it.
expansion_functions = {
    'top': expand_top,
    'bottom': expand_bottom,
    'middle': expand_middle
}

# Build a balanced list of strategy labels (one per row) and shuffle it BEFORE
# assigning it to the frame. Shuffling `df[col].values` in place relies on
# pandas returning a writable view of the column, which is not guaranteed.
expansion_labels = np.tile(['top', 'bottom', 'middle'], len(df) // 3 + 1)[:len(df)]
np.random.shuffle(expansion_labels)
df['expansion_type'] = expansion_labels

# One record per (mapped row, target length): the row's context expanded with
# the strategy assigned to that row.
expanded_dataset = []

for idx, row in tqdm(df.iterrows(), total=len(df)):
    chunk_location = df_to_chunk.get(idx)
    if chunk_location is None:
        # The row could not be matched to a chunk; nothing to expand.
        continue

    filename, chunk_idx = chunk_location
    expansion_type = row['expansion_type']
    expand_fn = expansion_functions[expansion_type]

    for pct, target_len in TARGET_LENGTHS.items():
        expanded_text, actual_tokens = expand_fn(filename, chunk_idx, target_len)

        record = {
            'original_idx': idx,
            'source': filename,
            'question_type': row['question_type'],
            'query': row["question"],
            'original_context': row['context'],
            'expanded_context': expanded_text,
            'expansion_type': expansion_type,
            'target_pct': pct,
            'actual_pct': actual_tokens/MAX_CONTEXT,
            'target_tokens': target_len,
            'actual_tokens': actual_tokens,
            'ground_truth': row['ground_truth'],
            'prompt_type': row['prompt_type'],
            'raw_response': row['raw_response'],
        }
        expanded_dataset.append(record)

# Collect the records into a frame, report how many samples each
# (strategy, target percentage) pair produced, and write them out as JSON lines.
expanded_df = pd.DataFrame(expanded_dataset)
print(f"Created {len(expanded_df)} expanded samples")

distribution = expanded_df.groupby(['expansion_type', 'target_pct']).size().unstack()
print(f"Distribution:\n{distribution}")

expanded_df.to_json("context_expansion_dataset.jsonl", orient='records', lines=True)

In [66]:
# Spot check: re-tokenise one 50% sample and compare with the recorded count.
half_context_rows = expanded_df[expanded_df['target_pct'] == 0.5]
sample = half_context_rows.iloc[0]
actual = len(tokenizer.encode(sample['expanded_context'], add_special_tokens=False))
print(f"Target: {sample['target_tokens']}, Recorded: {sample['actual_tokens']}, Verified: {actual}")

# Summary statistics of the recorded token counts per expansion strategy.
tokens_by_strategy = expanded_df.groupby('expansion_type')['actual_tokens']
print(tokens_by_strategy.describe())

Target: 16384, Recorded: 17118, Verified: 17126
                 count          mean           std     min      25%      50%  \
expansion_type                                                                 
bottom          1500.0  16422.755333  10524.443741    41.0  8227.25  16560.5   
middle          1495.0  17152.258863  10169.104983  3279.0  8346.50  16705.0   
top             1495.0  16316.797324  10581.560186    20.0  8229.50  16600.0   

                    75%      max  
expansion_type                    
bottom          25034.0  32237.0  
middle          25144.0  32293.0  
top             25056.0  32627.0  
