In [None]:
import utils # utils.py
import os
import json
import torch
from config import ROOT, RAW_DIR, FORMMATED_DIR, INDEX_DIR
from pyserini.search.faiss import FaissSearcher, DprQueryEncoder
from transformers import DPRReader, DPRReaderTokenizer
from pyserini.search.lucene import LuceneSearcher
from cnc_highlighting.encode import BertForHighlightPrediction

K=10

# dense searcher
query_encoder = DprQueryEncoder("facebook/dpr-question_encoder-multiset-base")
dense_searcher = FaissSearcher(f"{INDEX_DIR}/dpr-ctx_encoder-multiset-base", query_encoder)

class DenseDocumentRetriever: 
    def __init__(self, searcher, docs_dir=FORMMATED_DIR, k=K):
        self.searcher = searcher
        self.docs_dir = docs_dir
        self.k = k
    
    def get_document_content(self, docid):
        ''' return the paragraph content given the docid from raw jsonl files '''
        file_name = docid.split('_')[0] + '_' + docid.split('_')[1] + '_' + docid.split('_')[2] + '.jsonl'
        with open(os.path.join(self.docs_dir, file_name), "r") as open_file:
            for line in open_file:
                data = json.loads(line)
                if data["id"] == docid:
                    return data["contents"]
        print("Paragraph not found.")
        return None

    def search_documents(self, query):
        ''' return the top k documents given the query '''
        hits = self.searcher.search(query, k=self.k)
        return hits

    def extract_titles_and_texts(self, hits):
        ''' Extract and return titles and texts from the top k hits '''
        titles = [hits[i].docid for i in range(len(hits))]
        texts = [self.get_document_content(hits[i].docid) for i in range(len(hits))]
        return titles, texts
    
    def retrieve_and_process_documents(self, query):
        ''' Retrieve the top k documents and prepare their data reader processing '''
        hits = self.search_documents(query)
        titles, texts = self.extract_titles_and_texts(hits)
        return titles, texts
  

class DprHighlighter:
    ''' 
    The DprHighlighter serves as the baseline to compare with generator 
    https://huggingface.co/facebook/dpr-reader-multiset-base
    '''
    def __init__(self, model_name: str = 'facebook/dpr-reader-multiset-base', tokenizer_name: str = 'facebook/dpr-reader-multiset-base', device: str = 'cpu'):
        self.device = device
        self.model = DPRReader.from_pretrained(model_name)
        self.model.to(self.device)
        self.tokenizer = DPRReaderTokenizer.from_pretrained(tokenizer_name)

    @staticmethod
    def find_max_idx(logits, dim=-1):
        probs = torch.softmax(logits, dim=dim)
        return torch.argmax(probs)

    @staticmethod
    def extract_answer_span(tokenizer, token_ids, start_position, end_position):
        answer_tokens = token_ids[start_position : end_position + 1]
        answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
        return answer
    
    def highlighting_outputs(self, target, ref_titles, references):
        ''' 
        target: the target paragraph that should be highlighted
        ref_titles: the IDs of the reference paragraphs
        references: the retrieved paragrpah, which are the reference for our highlighting work
        BERT input: [CLS] <texts> [SEP] <questions> 
        TODO: 
            - 限制start_logits & end_logits在最後面? (作為好多個reference的最終highlight)
            - titles should be concatenated with texts
            - handle paragraph that is too long
        '''
        
        targets = [target] * len(ref_titles) # our target paragraph that should be highlighted
        
        encoded_inputs = self.tokenizer(
            questions=references,       # retrieved documents are the reference for our highlighting work
            # titles=ref_titles,        # TODO: titles should be concatenated with texts
            texts=targets,
            padding=True if len(targets) > 1 else False,
            return_tensors="pt", 
            truncation=True             # TODO: handle paragraph that is too long
        )
        
        outputs = self.model(**encoded_inputs)

        return encoded_inputs, outputs
    

    def visualize_highlight_span(self, encoded_inputs, ref_titles, relevance_logits, start_logits, end_logits):
        num_ref = start_logits.shape[0]

        # Sort the relevance logits in descending order
        relevance_probs = torch.softmax(relevance_logits, dim=-1)
        sorted_indices = torch.argsort(relevance_probs, descending=True)

        for i in sorted_indices:

            start_idx = self.find_max_idx(start_logits[i])
            end_idx = self.find_max_idx(end_logits[i])
            highlighted_span = self.extract_answer_span(
                self.tokenizer,
                encoded_inputs['input_ids'][i],
                start_idx,
                end_idx
            )
            
            print(f"{relevance_probs[i]:.4f} reference {ref_titles[i]}:")
            print(f"start_idx: {start_idx}, end_idx: {end_idx}, span: {highlighted_span}")

def print_hits(hits, display_top_n=10):
    for i in range(display_top_n):
        print(f'{i+1:2} {hits[i].docid:7} {hits[i].score:.5f}')
        print(utils.retrieve_paragraph_from_docid(hits[i].docid))
    print()

Some weights of the model checkpoint at facebook/dpr-question_encoder-multiset-base were not used when initializing DPRQuestionEncoder: ['question_encoder.bert_model.pooler.dense.bias', 'question_encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRQuestionEncoder from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRQuestionEncoder from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
target_paragraph = utils.retrieve_paragraph_from_docid("20221025_10-Q_789019_part1_item2_para334")
print(target_paragraph)

On January 18, 2022, we entered into a definitive agreement to acquire Activision Blizzard, Inc. Activision Blizzard for 95.00 per share in an all-cash transaction valued at 68.7 billion, inclusive of Activision Blizzard s net cash. The acquisition has been approved by Activision Blizzard s shareholders, and we expect it to close in fiscal year 2023, subject to the satisfaction of certain regulatory approvals and other customary closing conditions.


In [None]:

dense_retriever = DenseDocumentRetriever(dense_searcher)

dense_hits = dense_retriever.search_documents(target_paragraph)

# titles & texts format for DPR highlighter
dense_titles, dense_texts = dense_retriever.retrieve_and_process_documents(target_paragraph)

print("Retrieval results from dense retriever:")
print_hits(dense_hits)
print()

Retrieval results from dense retriever:
 1 20221025_10-Q_789019_part1_item2_para334 93.43729
On January 18, 2022, we entered into a definitive agreement to acquire Activision Blizzard, Inc. Activision Blizzard for 95.00 per share in an all-cash transaction valued at 68.7 billion, inclusive of Activision Blizzard s net cash. The acquisition has been approved by Activision Blizzard s shareholders, and we expect it to close in fiscal year 2023, subject to the satisfaction of certain regulatory approvals and other customary closing conditions.
 2 20220426_10-Q_789019_part1_item2_para492 93.06297
On January 18, 2022, we entered into a definitive agreement to acquire Activision Blizzard, Inc. Activision Blizzard for 95.00 per share in an all-cash transaction valued at 68.7 billion, inclusive of Activision Blizzard s net cash. We expect this acquisition to close in fiscal year 2023, subject to approval by Activision Blizzard s shareholders, the satisfaction of certain regulatory approvals, an

In [None]:
dpr_highlighter = DprHighlighter(model_name="facebook/dpr-reader-multiset-base", tokenizer_name="facebook/dpr-reader-multiset-base")

dense_encoded_inputs, dense_outputs = dpr_highlighter.highlighting_outputs(target_paragraph, dense_titles, dense_texts)

print("Highlighting results from dense retriever: (desc)")
dpr_highlighter.visualize_highlight_span(dense_encoded_inputs, dense_titles, dense_outputs.relevance_logits, dense_outputs.start_logits, dense_outputs.end_logits)
# dpr_highlighter.visualize_highlight_span(dense_encoded_inputs, dense_outputs.start_logits, dense_outputs.end_logits)
print()


Some weights of the model checkpoint at facebook/dpr-reader-multiset-base were not used when initializing DPRReader: ['span_predictor.encoder.bert_model.pooler.dense.bias', 'span_predictor.encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRReader from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRReader from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRReaderTokenizer'.


Highlighting results from dense retriever: (desc)
0.7695 reference 20220125_10-Q_789019_part1_item2_para475:
start_idx: 110, end_idx: 110, span: blizzard
0.0764 reference 20221025_10-Q_789019_part1_item1_para158:
start_idx: 128, end_idx: 132, span: january 18, 2022
0.0677 reference 20220426_10-Q_789019_part1_item1_para169:
start_idx: 125, end_idx: 129, span: january 18, 2022
0.0245 reference 20220426_10-Q_789019_part1_item2_para492:
start_idx: 115, end_idx: 110, span: 
0.0245 reference 20220125_10-Q_789019_part1_item2_para474:
start_idx: 115, end_idx: 110, span: 
0.0228 reference 20221025_10-Q_789019_part1_item2_para334:
start_idx: 118, end_idx: 144, span: 95. 00 per share in an all - cash transaction valued at 68. 7 billion, inclusive of activision blizzard s net cash.
0.0059 reference 20221025_10-Q_789019_part2_item1a_para29:
start_idx: 347, end_idx: 347, span: blizzard
0.0029 reference 20220426_10-Q_789019_part2_item1a_para28:
start_idx: 151, end_idx: 331, span: advance our business

In [None]:
dense_texts

['On January 18, 2022, we entered into a definitive agreement to acquire Activision Blizzard, Inc. Activision Blizzard for 95.00 per share in an all-cash transaction valued at 68.7 billion, inclusive of Activision Blizzard s net cash. The acquisition has been approved by Activision Blizzard s shareholders, and we expect it to close in fiscal year 2023, subject to the satisfaction of certain regulatory approvals and other customary closing conditions.',
 'On January 18, 2022, we entered into a definitive agreement to acquire Activision Blizzard, Inc. Activision Blizzard for 95.00 per share in an all-cash transaction valued at 68.7 billion, inclusive of Activision Blizzard s net cash. We expect this acquisition to close in fiscal year 2023, subject to approval by Activision Blizzard s shareholders, the satisfaction of certain regulatory approvals, and other customary closing conditions.',
 'On January 18, 2022, we entered into a definitive agreement to acquire Activision Blizzard, Inc. A

In [None]:
dense_encoded_inputs['input_ids'].shape

torch.Size([10, 422])

In [None]:
from transformers import DPRReader, DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-multiset-base")
model = DPRReader.from_pretrained("facebook/dpr-reader-multiset-base")

target_paragraph = utils.retrieve_paragraph_from_docid("20221025_10-Q_789019_part1_item2_para334")
reference = utils.retrieve_paragraph_from_docid("20220426_10-Q_789019_part2_item1a_para28")

encoded_inputs = tokenizer(
    questions=[reference],
    texts=[target_paragraph],
    return_tensors="pt",
    )

print(encoded_inputs['input_ids'].shape)

outputs = model(**encoded_inputs)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
relevance_logits = outputs.relevance_logits

start_idx = torch.argmax(torch.softmax(start_logits, dim=-1))
end_idx = torch.argmax(torch.softmax(end_logits, dim=-1))

# visualize the highlighted span
answer_tokens = encoded_inputs['input_ids'][0][start_idx : end_idx + 1]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

print(f"reference: {reference}")
print(f"start_idx: {start_idx}, end_idx: {end_idx}")
print(f"span: {answer}")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRReaderTokenizer'.
Some weights of the model checkpoint at facebook/dpr-reader-multiset-base were not used when initializing DPRReader: ['span_predictor.encoder.bert_model.pooler.dense.bias', 'span_predictor.encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRReader from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRReader from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([1, 407])
reference: Acquisitions, joint ventures, and strategic alliances may have an adverse effect on our business. We expect to continue making acquisitions and entering into joint ventures and strategic alliances as part of our long-term business strategy. For example, in October 2018 we completed our acquisition of GitHub, Inc. GitHub for 7.5 billion, in March 2021 we completed our acquisition of ZeniMax Media Inc. for 8.1 billion, and in March 2022 we completed our acquisition of Nuance Communications, Inc. for 18.8 billion. In January 2022 we announced a definitive agreement to acquire Activision Blizzard, Inc. for 68.7 billion. These acquisitions and other transactions and arrangements involve significant challenges and risks, including that they do not advance our business strategy, that we get an unsatisfactory return on our investment, that we have difficulty integrating and retaining new employees, business systems, and technology, that they distract management 

In [None]:
print(f"tokens length of target_paragraph: {len(tokenizer.tokenize(target_paragraph))}")
print(f"tokents length of reference: {len(tokenizer.tokenize(reference))}")

tokens length of target_paragraph: 91
tokents length of reference: 313


In [None]:
start_logits.shape

torch.Size([1, 407])

In [None]:
from transformers import DPRReader, DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-multiset-base")
model = DPRReader.from_pretrained("facebook/dpr-reader-multiset-base")

query = "What company does micfrosoft acquire?"
target_paragraph = utils.retrieve_paragraph_from_docid("20221025_10-Q_789019_part1_item2_para334")

encoded_inputs = tokenizer(
    questions=[query],
    texts=[target_paragraph],
    return_tensors="pt",
    )

print(encoded_inputs['input_ids'].shape)

outputs = model(**encoded_inputs)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
relevance_logits = outputs.relevance_logits

start_idx = torch.argmax(torch.softmax(start_logits, dim=-1))
end_idx = torch.argmax(torch.softmax(end_logits, dim=-1))

# visualize the highlighted span
answer_tokens = encoded_inputs['input_ids'][0][start_idx : end_idx + 1]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

print(f"reference: {reference}")
print(f"start_idx: {start_idx}, end_idx: {end_idx}")
print(f"span: {answer}")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRReaderTokenizer'.
Some weights of the model checkpoint at facebook/dpr-reader-multiset-base were not used when initializing DPRReader: ['span_predictor.encoder.bert_model.pooler.dense.bias', 'span_predictor.encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRReader from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRReader from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([1, 103])
reference: Acquisitions, joint ventures, and strategic alliances may have an adverse effect on our business. We expect to continue making acquisitions and entering into joint ventures and strategic alliances as part of our long-term business strategy. For example, in October 2018 we completed our acquisition of GitHub, Inc. GitHub for 7.5 billion, in March 2021 we completed our acquisition of ZeniMax Media Inc. for 8.1 billion, and in March 2022 we completed our acquisition of Nuance Communications, Inc. for 18.8 billion. In January 2022 we announced a definitive agreement to acquire Activision Blizzard, Inc. for 68.7 billion. These acquisitions and other transactions and arrangements involve significant challenges and risks, including that they do not advance our business strategy, that we get an unsatisfactory return on our investment, that we have difficulty integrating and retaining new employees, business systems, and technology, that they distract management 

In [None]:
print(query)
print(f"tokents length of query: {len(tokenizer.tokenize(query))}")
print(f"tokents length of target_paragraph: {len(tokenizer.tokenize(target_paragraph))}")

What company does micfrosoft acquire?
tokents length of query: 9
tokents length of target_paragraph: 91


In [None]:
target_paragraph

'On January 18, 2022, we entered into a definitive agreement to acquire Activision Blizzard, Inc. Activision Blizzard for 95.00 per share in an all-cash transaction valued at 68.7 billion, inclusive of Activision Blizzard s net cash. The acquisition has been approved by Activision Blizzard s shareholders, and we expect it to close in fiscal year 2023, subject to the satisfaction of certain regulatory approvals and other customary closing conditions.'

In [None]:
from transformers import DPRReader, DPRReaderTokenizer

tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-multiset-base")
model = DPRReader.from_pretrained("facebook/dpr-reader-multiset-base")

query = "What company does micfrosoft acquire?"
# reference = utils.retrieve_paragraph_from_docid("20220426_10-Q_789019_part2_item1a_para28")

encoded_inputs = tokenizer(
    questions=["What is love ?"],
    texts=["'What Is Love' is a song recorded by the artist Haddaway"],
    return_tensors="pt",
    )

print(encoded_inputs['input_ids'].shape)

outputs = model(**encoded_inputs)
start_logits = outputs.start_logits
end_logits = outputs.end_logits
relevance_logits = outputs.relevance_logits

start_idx = torch.argmax(torch.softmax(start_logits, dim=-1))
end_idx = torch.argmax(torch.softmax(end_logits, dim=-1))

# visualize the highlighted span
answer_tokens = encoded_inputs['input_ids'][0][start_idx : end_idx + 1]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)

print(f"reference: {reference}")
print(f"start_idx: {start_idx}, end_idx: {end_idx}")
print(f"span: {answer}")

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DPRQuestionEncoderTokenizer'. 
The class this function is called from is 'DPRReaderTokenizer'.
Some weights of the model checkpoint at facebook/dpr-reader-multiset-base were not used when initializing DPRReader: ['span_predictor.encoder.bert_model.pooler.dense.bias', 'span_predictor.encoder.bert_model.pooler.dense.weight']
- This IS expected if you are initializing DPRReader from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DPRReader from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([1, 22])
reference: Acquisitions, joint ventures, and strategic alliances may have an adverse effect on our business. We expect to continue making acquisitions and entering into joint ventures and strategic alliances as part of our long-term business strategy. For example, in October 2018 we completed our acquisition of GitHub, Inc. GitHub for 7.5 billion, in March 2021 we completed our acquisition of ZeniMax Media Inc. for 8.1 billion, and in March 2022 we completed our acquisition of Nuance Communications, Inc. for 18.8 billion. In January 2022 we announced a definitive agreement to acquire Activision Blizzard, Inc. for 68.7 billion. These acquisitions and other transactions and arrangements involve significant challenges and risks, including that they do not advance our business strategy, that we get an unsatisfactory return on our investment, that we have difficulty integrating and retaining new employees, business systems, and technology, that they distract management f