In [13]:
from typing import List, Optional

import numpy as np
import torch
from transformers import AutoTokenizer, DPRContextEncoder, DPRQuestionEncoder

class BiEncoderRetriever:
    """Rank short text summaries against a question with a DPR bi-encoder.

    The question and the summaries are embedded by separate DPR encoders
    (question / context). Each summary is scored by the dot product with the
    question embedding, and the highest-scoring summaries are returned.
    """

    # Default checkpoints, i.e. the multidoc2dial "structure" DPR pair used originally.
    QUESTION_MODEL = "sivasankalpp/dpr-multidoc2dial-structure-question-encoder"
    CONTEXT_MODEL = "sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder"

    def __init__(
        self,
        question_model: str = QUESTION_MODEL,
        ctx_model: str = CONTEXT_MODEL,
        device: Optional[str] = None,
    ) -> None:
        """Load the tokenizer and both encoders onto ``device``.

        Args:
            question_model: Hub id or local path of the DPR question encoder.
                Its tokenizer is also used for the summaries (contexts).
            ctx_model: Hub id or local path of the DPR context encoder.
            device: Torch device string. Defaults to ``"cuda:0"`` when CUDA
                is available, otherwise ``"cpu"``.
        """
        if device is None:
            device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(question_model)
        self.question_encoder = DPRQuestionEncoder.from_pretrained(question_model).to(self.device)
        self.ctxt_encoder = DPRContextEncoder.from_pretrained(ctx_model).to(self.device)

    def _tokenize(self, texts, max_length: int):
        """Tokenize ``texts`` (str or list of str), padded/truncated to ``max_length``, on ``self.device``."""
        inputs = self.tokenizer(
            texts,
            padding="max_length",
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        # token_type_ids are deliberately not passed to the encoders. pop() with a
        # default avoids a KeyError for tokenizers that never emit them.
        inputs.pop("token_type_ids", None)
        return inputs

    @torch.no_grad()
    def encode_summaries(self, summaries: List[str], max_length: int = 128) -> torch.Tensor:
        """Embed ``summaries`` with the context encoder.

        Runs without autograd, so the result can be moved to numpy
        (``.cpu().numpy()``) and cached for ``retrieve_top_summaries``.

        Returns:
            The encoder's ``pooler_output``, one row per summary.
        """
        return self.ctxt_encoder(**self._tokenize(summaries, max_length))["pooler_output"]

    @torch.no_grad()
    def encode_question(self, question: str, max_length: int = 32) -> torch.Tensor:
        """Embed a single ``question`` with the question encoder.

        Returns:
            The encoder's ``pooler_output`` (a batch containing one question).
        """
        return self.question_encoder(**self._tokenize(question, max_length))["pooler_output"]

    @torch.no_grad()
    def retrieve_top_summaries(
        self,
        question: str,
        summaries: List[str],
        encoded_summaries: Optional[np.ndarray] = None,
        topk: int = 5,
    ) -> List[str]:
        """Return the ``topk`` summaries most similar to ``question``, best first.

        Args:
            question: Query text.
            summaries: Candidate texts.
            encoded_summaries: Optional precomputed summary embeddings, one row per
                entry of ``summaries`` in the same order (numpy array or tensor).
                They are cast to the question embedding's dtype and device, so
                float64 caches work. When ``None``, the summaries are encoded here.
            topk: Maximum number of summaries to return. It is clamped to
                ``len(summaries)``; a value ``<= 0`` yields ``[]``.

        Returns:
            Up to ``topk`` summaries ordered by descending dot-product score.
            If ``topk >= len(summaries)``, every summary is returned, still ranked.

        Raises:
            ValueError: if ``encoded_summaries`` does not have one row per summary.
        """
        if not summaries or topk <= 0:
            return []

        question_vec = self.encode_question(question)
        if encoded_summaries is None:
            summary_vecs = self.encode_summaries(summaries)
        else:
            # as_tensor accepts numpy arrays and tensors alike. The dtype cast avoids
            # a float64-vs-float32 error in torch.mm.
            summary_vecs = torch.as_tensor(encoded_summaries).to(
                device=question_vec.device, dtype=question_vec.dtype
            )

        if summary_vecs.shape[0] != len(summaries):
            raise ValueError(
                f"encoded_summaries has {summary_vecs.shape[0]} rows "
                f"but {len(summaries)} summaries were given"
            )

        # (1, d) x (d, n) -> (1, n). Drop the batch dim so topk yields a flat index
        # list even when k == 1 (squeeze() on the result would give a 0-d tensor).
        scores = torch.mm(question_vec, summary_vecs.T).squeeze(0)
        k = min(topk, len(summaries))
        top_indices = torch.topk(scores, k).indices.tolist()
        return [summaries[i] for i in top_indices]

In [14]:
# Build the retriever once: loads the tokenizer plus the DPR question and context
# encoders (presumably fetched from the Hugging Face Hub on first use) onto GPU if available.
rt = BiEncoderRetriever()

In [15]:
# Example 1: a single-subject persona ("Sarah") and a question about her occupation.
question = "What does Sarah do for a living"
personalist = [
    'Sarah is 24 years old.',
    'Sarah currently lives in Canada.',
    "Sarah is a swim coach at Sarah's local pool.",
    'Sarah is studying to be a computer programmer.',
    'Sarah is also a graduate student.',
    'Sarah is now looking for a new job.',
    "Sarah's mother is very traditional while Sarah prefers to be more free spirited.",
    "Sarah's family and Sarah are from India.",
    "Sarah's favorite music genre is death metal.",
    'Sarah is a famous twitch streamer.',
    'Sarah likes watching war documentaries.',
    "Sarah's favorite food is mexican food.",
]

In [28]:
# Rank the persona facts against the question. encoded_summaries=None means the
# facts are encoded on the fly in this call (no cached encodings are used).
# NOTE(review): the saved output under this cell lists "Speaker 2" facts that are
# not in the Sarah list above; it is stale from an earlier run. Re-run to refresh.
rt.retrieve_top_summaries(question, personalist, encoded_summaries=None, topk=5)

['Speaker 2 is high school student.', 'Speaker 2 is waiting to get a job.', 'Speaker 2 would like to become an engineer.', 'Speaker 2 used to play basketball.', "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession."]


In [36]:
# Example 2: a dialogue turn asking about the job search, against a two-speaker persona.
question = (
    "That's a bummer. Hopefully once you move out you can at least find "
    "something active that you enjoy doing with friends. How is the job search going?"
)
personalist = [
    "Speaker 1 is professional basketball player.",
    "Speaker 2 is high school student.",
    "Speaker 2 is waiting to get a job.",
    "Speaker 2 would like to become an engineer.",
    "Speaker 2 used to play basketball.",
    "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.",
]

In [37]:
# Rank the persona facts against the dialogue turn. encoded_summaries=None means
# the facts are encoded on the fly in this call (no cached encodings are used).
rt.retrieve_top_summaries(question, personalist, encoded_summaries=None, topk=3)

['Speaker 2 is waiting to get a job.', 'Speaker 2 is high school student.', 'Speaker 1 is professional basketball player.']


In [38]:
# Example 3: a basketball-related dialogue turn against the same two-speaker persona.
question = (
    "I should be able to do better. I just need to work harder at it. "
    "I missed 3 free throws, so my coach is pushing me hard at practice. "
    "Do you play basketball or any sports?"
)
personalist = [
    "Speaker 1 is professional basketball player.",
    "Speaker 2 is high school student.",
    "Speaker 2 is waiting to get a job.",
    "Speaker 2 would like to become an engineer.",
    "Speaker 2 used to play basketball.",
    "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.",
]

In [39]:
# Rank the persona facts against the dialogue turn. encoded_summaries=None means
# the facts are encoded on the fly in this call (no cached encodings are used).
rt.retrieve_top_summaries(question, personalist, encoded_summaries=None, topk=3)

["Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.", 'Speaker 2 is high school student.', 'Speaker 2 used to play basketball.']


In [40]:
# Example 4: an indirect cue ("bright-colored") that should surface the color preference.
# NOTE: these facts use typographic apostrophes (’), unlike the earlier cells.
question = (
    "This is meant as advice and a little funny. "
    "Buy one that is bright-colored so it won't blend in your surroundings and you lose it easily."
)
personalist = [
    "Speaker 1 works as waiter.",
    "Speaker 1 regrets career choices.",
    "Speaker 2’s roommates hate Speaker 2’s parakeet.",
    "Speaker 2’s favorite color is orange.",
]

In [41]:
# Rank the persona facts against the dialogue turn. encoded_summaries=None means
# the facts are encoded on the fly in this call (no cached encodings are used).
rt.retrieve_top_summaries(question, personalist, encoded_summaries=None, topk=2)

['Speaker 2’s favorite color is orange.', 'Speaker 2’s roommates hate Speaker 2’s parakeet.']
