In [2]:
# !pip install transformers torch scikit-learn pandas IProgress

In [5]:
import torch
import numpy as np
import pandas as pd
import pdfplumber

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

from transformers import AutoTokenizer, DistilBertModel

In [11]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_name)
distilbert = DistilBertModel.from_pretrained(model_name)
distilbert.to(device)
distilbert.eval()



DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

# Convert PDF to text and extract sections

In [144]:
pdf_path = "RSM.pdf"

pages = []

with pdfplumber.open(pdf_path) as pdf:
    for page_num, page in enumerate(pdf.pages):
        text = page.extract_text()
        if text is None:
            text = ""
        pages.append(text)

full_text = "\n\n".join(pages)
table_of_contents = full_text[3250:20000].split('\n')

In [202]:
def isSection(line):
    elems = line.split(' ')
    if not line or len(elems) == 0:
        return False
    if elems[0][0].isnumeric():
        return True
    elif elems[0] == 'Appendix':
        return True
    return False

def getSectionMeta(section):
    multi_line_section_map = {
        '2.2.3': 17,
        '2.3.4.1': 22,
        '2.3.6': 22,
        '2.4.6': 35,
        '2.4.14.4': 54,
        'Appendix 4': 71,
    }
    
    removed_page = section[:-2]
    elems = removed_page.split(' ')[:-1]
    is_appendix = elems[0] == 'Appendix'
    section_number = " ".join(elems[:2]) if is_appendix else elems[0]
    title = " ".join(elems[3:]) if is_appendix else " ".join(elems[1:])
    page = multi_line_section_map[section_number]-1 if section_number in multi_line_section_map else int(section[-2:].strip(' '))
    return {
        "section_number": section_number,
        "title": title,
        "page": page,
    }


def getSections(section_metas, pages):
    full_sections = []
    n_sections = len(section_metas)
    n_pages = len(pages)

    for i, meta in enumerate(section_metas):
        start_page_idx = meta["page"] - 1
        if i + 1 < n_sections:
            next_meta = section_metas[i + 1]
            end_page_idx = next_meta["page"] - 1
        else:
            next_meta = None
            end_page_idx = n_pages - 1

        start_page_text = pages[start_page_idx]
        start_idx = start_page_text.find(meta['title'])

        if next_meta is not None:
            end_page_text = pages[end_page_idx]
            end_idx = end_page_text.find(next_meta['title'])

        else:
            # Last section goes to end of last page
            end_idx = len(pages[end_page_idx])

        if start_page_idx == end_page_idx:
            text_chunks = [pages[start_page_idx][start_idx:end_idx]]
        else:
            text_chunks = []
            text_chunks.append(pages[start_page_idx][start_idx:])
            for p in range(start_page_idx + 1, end_page_idx):
                text_chunks.append(pages[p])

            text_chunks.append(pages[end_page_idx][:end_idx])

        if len(text_chunks) == 1 and not text_chunks[0]:
            continue

        section_text = "\n\n".join(text_chunks)
            
        full_sections.append({
            **meta,
            "text": section_text,
        })

    return full_sections

def printSection(section_number):
    section = sections.loc[sections.section_number == section_number].to_dict('records')[0]
    print(f"{section['section_number']} {section['title']} (page {section['page']+1})\n\n{section['text']}")

section_metas = [getSectionMeta(s) for s in table_of_contents if isSection(s)]
sections = pd.DataFrame(getSections(section_metas, pages))

In [204]:
section_texts = sections['text']
questions = pd.Series(['What do I do if I spilled something?', 'I got something in eye, what do I do?','What personal protective equipment do I need?'])

# TF_IDF

In [205]:
all_texts = pd.concat(
    [section_texts, questions],
    ignore_index=True
)

tfidf = TfidfVectorizer()
tfidf_matrix = tfidf.fit_transform(all_texts)

n_sections = len(section_texts)
section_tfidf = tfidf_matrix[:n_sections]
question_tfidf = tfidf_matrix[n_sections:] 

def rank_sections_tfidf(q_idx):
    q_vec = question_tfidf[q_idx]
    sims = cosine_similarity(q_vec, section_tfidf)[0]
    ranked = np.argsort(-sims)
    return ranked

def top_n_sections_tfidf(q_idx, n_sections):
    ranks = rank_sections_tfidf(q_idx)
    return np.argsort(ranks)[:n_sections]

In [206]:
top_n_sections_tfidf(0,10)

array([55, 61, 62, 63, 47, 58, 57, 56, 59, 50])

# DistilBERT embedding baseline

In [207]:
@torch.no_grad()
def encode_texts(texts, max_length=256):
    all_embs = []
    batch_size = 8
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        tokens = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(device)

        outputs = distilbert(**tokens)
        hidden = outputs.last_hidden_state

        # Compute mean embedding
        mask = tokens["attention_mask"].unsqueeze(-1)
        emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1)
        all_embs.append(emb.cpu().numpy())

    return np.vstack(all_embs)


In [208]:
section_texts = sections["text"].tolist()
question_texts = questions.tolist()

section_embeddings = encode_texts(section_texts)
question_embeddings = encode_texts(question_texts)

section_embeddings.shape, question_embeddings.shape


((124, 768), (3, 768))

In [209]:
def rank_sections_embedding(q_idx):
    q_vec = question_embeddings[q_idx:q_idx+1]        # shape (1, H)
    sims = cosine_similarity(q_vec, section_embeddings)[0]  # (n_sections,)
    ranked = np.argsort(-sims)                        # indices of sections_df, best first
    return ranked

def top_n_sections_embedding(q_idx, n_sections):
    ranks = rank_sections_embedding(q_idx)
    return np.argsort(ranks)[:n_sections]

In [210]:
top_n_sections_embedding(0,10)

array([ 68,  87,  49,  45,  78,  91,  88, 120,  59, 102])

# Attention-Based Method

In [232]:
@torch.no_grad()
def attention_scores_for_question(question_text, section_texts, max_length=256, batch_size=8):
    num_sections = len(section_texts)
    num_layers = distilbert.config.num_hidden_layers
    num_heads = distilbert.config.n_heads 
    
    scores = np.zeros((num_layers, num_heads, num_sections), dtype=np.float32)
    
    for start in range(0, num_sections, batch_size):
        end = min(start + batch_size, num_sections)
        num_sections_in_batch = end - start
        batch_sections = section_texts[start:end]
        batch_question = [question_text] * num_sections_in_batch

        tokens = tokenizer(
            batch_question,
            batch_sections,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt"
        ).to(device)

        outputs = distilbert(**tokens, output_attentions=True)
        attentions = outputs.attentions
        input_ids = tokens["input_ids"]

        for i in range(num_sections_in_batch):
            section_idx = start + i
            
            ids = input_ids[i]
            sep_id = tokenizer.sep_token_id
            sep_positions = (ids == sep_id).nonzero(as_tuple=False).flatten().tolist()

            # should be of the form [CLS] question [SEP] section [SEP]
            first_sep, second_sep = sep_positions[0], sep_positions[1]
            q_idx = list(range(1, first_sep))
            s_idx = list(range(first_sep + 1, second_sep))

            for layer_idx, batch_layer_attentions in enumerate(attentions):
                layer_attentions = batch_layer_attentions[i]
                for head_idx in range(num_heads):
                    head_attention = layer_attentions[head_idx]
                    q2s_attention = head_attention[np.ix_(q_idx, s_idx)]
                    scores[layer_idx, head_idx, section_idx] = q2s_attention.mean().item()

    return scores

In [237]:
def rank_sections_attention(q_idx):
    q_text = questions.tolist()[q_idx]
    sec_texts = sections["text"].tolist()

    scores = attention_scores_for_question(
        q_text,
        sec_texts,
        max_length=256,
        batch_size=8
    )

    ranked = np.argsort(-scores) # (layers, heads, ranks)
    return ranked

def top_n_sections_attention(q_idx, n_sections):
    ranks = rank_sections_attention(q_idx)
    return np.argsort(ranks)[:,:,:n_sections]

In [239]:
x = top_n_sections_attention(1,5)

print(x.shape)
print(x[0][0])

(6, 12, 5)
[86 42 61 73 17]


# Decoding & Testing

In [248]:
def createResults(top_n = 5):
    num_questions = len(questions)
    all_dfs = []
    for q in range(num_questions):
        attention_inds = top_n_sections_attention(q,top_n)
        embedding_inds = top_n_sections_embedding(q,top_n)
        tfidf_inds = top_n_sections_tfidf(q,top_n)
        
        embedding_sections = sections.iloc[embedding_inds]['section_number'].reset_index(drop=True)
        tfidf_sections = sections.iloc[tfidf_inds]['section_number'].reset_index(drop=True)
        results = {'question_ind':q,'rank':range(top_n),'tfidf': tfidf_sections, 'embedding':embedding_sections}
        for l,layer in enumerate(attention_inds):
            for h,head_inds in enumerate(layer):
                results[f"attention_{l}_{h}"] = sections.iloc[head_inds]['section_number'].reset_index(drop=True)
        all_dfs.append(pd.DataFrame(results))
    return pd.concat(all_dfs)

In [249]:
results = createResults()
results

Unnamed: 0,question_ind,rank,tfidf,embedding,attention_0_0,attention_0_1,attention_0_2,attention_0_3,attention_0_4,attention_0_5,...,attention_5_2,attention_5_3,attention_5_4,attention_5_5,attention_5_6,attention_5_7,attention_5_8,attention_5_9,attention_5_10,attention_5_11
0,0,0,2.4.2,2.4.6.3,2.4.14.4,2.4.21.4,2.4.4.2,Appendix 5,Appendix 7,Appendix 11,...,2.4.6,2.4.2.2,2.4.11.1,2.4.6.3,2.4.8,2.4.3,2.4.13.1,2.4.11.1,2.4.12.1,2.4.13
1,0,1,2.4.4.2,2.4.14.3,2.3.13.2,2.3.13.2,2.4.14,2.4.21.4,2.3.13.2,2.3.17,...,2.3.15,2.3.13.1,2.3.13.1,2.3.15,2.3.13.2,2.3.13,2.3.17,2.3.16,2.3.17,2.3.15
2,0,2,2.4.4.3,2.3.19,2.4.4.2,2.4.4.4,2.4.20,2.3.22,2.4.4.3,2.4.6,...,2.3.19,2.3.17.1,2.4.4.1,2.4,2.4.2.1,2.3.17.1,2.4.5,2.4.2,2.3.22,2.4.2.1
3,0,3,2.4.4.4,2.3.16,2.4.10,2.4.10,Appendix 2,Appendix 10,2.4.10,2.4.21.2,...,2.4.21.2,2.4.18,2.4.21.7,2.4.21,2.4.19,2.4.17.2,2.4.14.5,2.4.21.4,2.4.21.6,2.4.21.3
4,0,4,2.3.17.1,2.4.12.1,2.2.2,2.2.2,2.2.2,2.3.9,2.2.2,2.3.1,...,1.2.1,1.1.2,2.1,1.2,1.3,1.1.2,1.3.1,1.1.3,1.2,1.1.3
0,1,0,2.3.12.2,Appendix 2,2.4.14.2,2.4.18,2.4.21.1,2.5,2.4.21,Appendix 7,...,2.4.6.4,2.4.4.1,2.4.7,2.4.10,2.4.6.2,2.3.20,2.4.14.1,2.4.12.2,2.4.12.1,2.4.6.3
1,1,1,2.4.17.1,2.4.7,2.3.13.2,2.3.13.2,2.3.13.1,2.4.21.5,2.3.13.2,2.3.16,...,2.3.15,2.3.13.1,2.3.13.1,2.3.15,2.3.13.2,2.3.12.2,2.3.15,2.3.16,2.3.14,2.3.13.1
2,1,2,2.4.14.5,2.4.12.1,2.4.4.2,2.4.4.4,2.4.4.1,2.3.12.2,2.4.4.2,2.4.4.4,...,2.4.2,2.3.19,2.4.4.3,2.4.1,2.4.2.2,2.3.17,2.4.6,2.4.2.2,2.4.3,2.4.1
3,1,3,2.4.2,2.4.1,2.4.10,2.4.10,2.4.2.2,2.4.21.4,2.4.10,2.4.14.1,...,2.4.21.3,2.4.19,2.4.21.1,2.4.21.3,2.4.21,2.4.17.2,2.4.13.1,2.4.21.5,2.4.21.5,2.4.21.4
4,1,4,2.4.17,2.4.14.2,2.2.2,2.2.2,2.3.9,2.3.20,2.2.2,2.3.1,...,1.3,1.1.1.1,2.1,1.3,1.3.2,1.1.1.1,1.2,1.1.3,1.2.2,1.1.3
