In [2]:
from transformers import DPRQuestionEncoder, DPRContextEncoder, DPRQuestionEncoderTokenizer, DPRContextEncoderTokenizer, AutoModelForSeq2SeqLM, AutoTokenizer
from rank_bm25 import BM25Okapi
from nltk.tokenize import word_tokenize
import torch

# Base class for retrieval
class Retriever:
    def retrieve_passages(self, book_text, claim):
        raise NotImplementedError("Subclasses must implement this method.")


In [11]:
# DPR-based Retriever
class DPRRetriever(Retriever):
    def __init__(self):
        self.question_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        self.question_encoder_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
        self.context_encoder = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
        self.context_encoder_tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")

    def retrieve_passages(self, book_text, claim, top_k=5):
        question_input = self.question_encoder_tokenizer(claim, return_tensors='pt')
        context_input = self.context_encoder_tokenizer(book_text, truncation=True, padding=True, return_tensors='pt', max_length=512)
        question_emb = self.question_encoder(**question_input).pooler_output
        context_emb = self.context_encoder(**context_input).pooler_output
        cos = torch.nn.CosineSimilarity(dim=1)
        similarities = cos(question_emb, context_emb)
        top_k_indices = similarities.topk(k=top_k).indices
        return [book_text[idx] for idx in top_k_indices]


In [12]:
# BM25-based Retriever
class BM25Retriever(Retriever):
    def __init__(self):
        self.tokenizer = word_tokenize

    def retrieve_passages(self, book_text, claim, top_k=5):
        tokenized_corpus = [self.tokenizer(doc) for doc in book_text]
        bm25 = BM25Okapi(tokenized_corpus)
        tokenized_claim = self.tokenizer(claim)
        doc_scores = bm25.get_scores(tokenized_claim)
        top_k_indices = sorted(range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True)[:top_k]
        return [book_text[idx] for idx in top_k_indices]


In [13]:
# Entailment checking function using T5
class EntailmentChecker:
    def __init__(self):
        self.model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
        self.tokenizer = AutoTokenizer.from_pretrained("t5-small")

    def check_entailment(self, claim, passages):
        results = []
        for passage in passages:
            input_text = f"claim: {claim} context: {passage}"
            input_ids = self.tokenizer(input_text, return_tensors='pt').input_ids
            outputs = self.model.generate(input_ids)
            result = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            results.append(result)
        return results

In [16]:
# Example usage
book_text = ["", ""]  # This should be a list of passages
claim = "."

# Choose the retriever
retriever = BM25Retriever() 

# Retrieve passages
retrieved_passages = retriever.retrieve_passages(book_text, claim)

# Check entailment
entailment_checker = EntailmentChecker()
entailment_results = entailment_checker.check_entailment(claim, retrieved_passages)

print(entailment_results)

['Another piece of text from the book', 'a claim from the book']
