In [26]:
import torch
import numpy as np

from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder
from typing import List

from pprint import pprint

class BiEncoderRetriever:
    """Rank candidate sentences against a query with a DPR bi-encoder.

    The query is embedded by a ``DPRQuestionEncoder`` and the candidates by a
    ``DPRContextEncoder``; relevance is the dot product of the two pooled
    embeddings (higher means more relevant).
    """

    QUESTION_ENCODER_NAME = "sivasankalpp/dpr-multidoc2dial-structure-question-encoder"
    CONTEXT_ENCODER_NAME = "sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder"
    QUESTION_MAX_LENGTH = 32
    SUMMARY_MAX_LENGTH = 128

    def __init__(self) -> None:
        """Load the tokenizer and both encoders, placing the models on GPU when available."""
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # A single tokenizer (the question encoder's) is shared by both encoders.
        self.tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_ENCODER_NAME)
        self.question_encoder = DPRQuestionEncoder.from_pretrained(self.QUESTION_ENCODER_NAME).to(self.device)
        self.ctxt_encoder = DPRContextEncoder.from_pretrained(self.CONTEXT_ENCODER_NAME).to(self.device)

    def _encode(self, encoder, texts, max_length: int):
        """Tokenize ``texts`` (padded/truncated to ``max_length``) and return ``encoder``'s pooled output."""
        input_dict = self.tokenizer(
            texts,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        # The encoders are fed without token_type_ids; pop() with a default also
        # tolerates tokenizers that never emit that key.
        input_dict.pop("token_type_ids", None)
        # Inference only: skip autograd bookkeeping so the result can be cached
        # (e.g. converted to numpy) without an explicit detach().
        with torch.no_grad():
            return encoder(**input_dict)['pooler_output']

    def encode_summaries(self, summaries: List[str]):
        """Embed candidate sentences with the context encoder.

        Args:
            summaries: Candidate sentences; each is padded/truncated to 128 tokens.

        Returns:
            The context encoder's ``pooler_output`` tensor for the batch.
        """
        return self._encode(self.ctxt_encoder, summaries, self.SUMMARY_MAX_LENGTH)

    def encode_question(self, question: str):
        """Embed a single query with the question encoder.

        Args:
            question: Query text; padded/truncated to 32 tokens.

        Returns:
            The question encoder's ``pooler_output`` tensor.
        """
        return self._encode(self.question_encoder, question, self.QUESTION_MAX_LENGTH)

    def retrieve_top_summaries(self, question: str, summaries: List[str], encoded_summaries: np.ndarray = None, topk: int = 5, verbose: bool = True):
        """Return the ``topk`` candidates that score highest against ``question``.

        Args:
            question: Query text.
            summaries: Candidate sentences, aligned row-for-row with
                ``encoded_summaries`` when that is supplied.
            encoded_summaries: Optional pre-computed candidate embeddings
                (NumPy array or torch tensor). When ``None`` the candidates are
                encoded on the fly.
            topk: Number of candidates to return.
            verbose: When true (default), print every candidate's score,
                truncated to an integer, before returning.

        Returns:
            A list of ``(summary, score)`` tuples ordered from highest to lowest
            score. NOTE: when ``topk >= len(summaries)`` the ``summaries`` list is
            returned unchanged (plain strings, unranked, no scores); this legacy
            behaviour is preserved for existing callers.
        """
        # Nothing to rank away: keep the legacy short-circuit, but take it
        # before paying for any encoder forward pass.
        if topk >= len(summaries):
            return summaries

        encoded_question = self.encode_question(question)
        if encoded_summaries is None:
            encoded_summaries = self.encode_summaries(summaries)
        else:
            # as_tensor accepts ndarrays and tensors alike; matching the question
            # embedding's dtype/device keeps torch.mm from failing on float64 caches.
            encoded_summaries = torch.as_tensor(
                encoded_summaries,
                dtype=encoded_question.dtype,
                device=encoded_question.device,
            )

        # (1, dim) x (dim, n) -> (1, n); keep the single row as a 1-D score vector.
        scores = torch.mm(encoded_question, encoded_summaries.T)[0]

        if verbose:
            print("all scores : ")
            for v in scores:
                print(int(v.item()), end=" ")
            print()

        # topk on the 1-D vector always yields 1-D results, so topk == 1 no
        # longer degenerates into a non-iterable 0-d tensor.
        best = torch.topk(scores, topk)
        return [
            (summaries[i], score)
            for i, score in zip(best.indices.tolist(), best.values.tolist())
        ]

In [27]:
# Build the retriever once and reuse it in every cell below; construction
# loads the tokenizer and both encoders via from_pretrained.
rt = BiEncoderRetriever()

In [28]:
# Query and the pool of persona sentences it will be ranked against.
question = "What does Sarah do for a living"
personalist = [
    "Sarah is 24 years old.",
    "Sarah currently lives in Canada.",
    "Sarah is a swim coach at Sarah's local pool.",
    "Sarah is studying to be a computer programmer.",
    "Sarah is also a graduate student.",
    "Sarah is now looking for a new job.",
    "Sarah's mother is very traditional while Sarah prefers to be more free spirited.",
    "Sarah's family and Sarah are from India.",
    "Sarah's favorite music genre is death metal.",
    "Sarah is a famous twitch streamer.",
    "Sarah likes watching war documentaries.",
    "Sarah's favorite food is mexican food.",
]

In [29]:
# encoded_summaries=None: no cached embeddings, so the retriever encodes the
# persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=5,
    )
)

all scores : 
72 72 73 74 71 73 73 71 64 69 66 69 
[('Sarah is studying to be a computer programmer.', 74.15234375),
 ('Sarah is now looking for a new job.', 73.7214584350586),
 ("Sarah's mother is very traditional while Sarah prefers to be more free "
  'spirited.',
  73.61085510253906),
 ("Sarah is a swim coach at Sarah's local pool.", 73.37893676757812),
 ('Sarah currently lives in Canada.', 72.65416717529297)]


In [30]:
# Multi-sentence dialogue turn used as the query, plus both speakers' personas.
question = "That's a bummer. Hopefully once you move out you can at least find something active that you enjoy doing with friends. How is the job search going?"
personalist = [
    "Speaker 1 is professional basketball player.",
    "Speaker 2 is high school student.",
    "Speaker 2 is waiting to get a job.",
    "Speaker 2 would like to become an engineer.",
    "Speaker 2 used to play basketball.",
    "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.",
]

In [31]:
# encoded_summaries=None: no cached embeddings, so the retriever encodes the
# persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
62 63 65 59 60 60 
[('Speaker 2 is waiting to get a job.', 65.63642120361328),
 ('Speaker 2 is high school student.', 63.490013122558594),
 ('Speaker 1 is professional basketball player.', 62.49553680419922)]


In [32]:
# Dialogue turn that ends in a sports question, plus both speakers' personas.
question = "I should be able to do better. I just need to work harder at it. I missed 3 free throws, so my coach is pushing me hard at practice. Do you play basketball or any sports?"
personalist = [
    "Speaker 1 is professional basketball player.",
    "Speaker 2 is high school student.",
    "Speaker 2 is waiting to get a job.",
    "Speaker 2 would like to become an engineer.",
    "Speaker 2 used to play basketball.",
    "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.",
]

In [33]:
# encoded_summaries=None: no cached embeddings, so the retriever encodes the
# persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
56 56 56 53 56 62 
[("Speaker 2's dad is strict, preventing the speaker 2 from doing basketball "
  'as a profession.',
  62.10819625854492),
 ('Speaker 2 is high school student.', 56.86595916748047),
 ('Speaker 2 used to play basketball.', 56.648990631103516)]


In [34]:
# Statement-style dialogue turn (no explicit question) with a small persona pool.
question = "This is meant as advice and a little funny. Buy one that is bright-colored so it won't blend in your surroundings and you lose it easily."
personalist = [
    "Speaker 1 works as waiter.",
    "Speaker 1 regrets career choices.",
    "Speaker 2’s roommates hate Speaker 2’s parakeet.",
    "Speaker 2’s favorite color is orange.",
]

In [35]:
# encoded_summaries=None: no cached embeddings, so the retriever encodes the
# persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=2,
    )
)

all scores : 
53 55 57 62 
[('Speaker 2’s favorite color is orange.', 62.0982551574707),
 ('Speaker 2’s roommates hate Speaker 2’s parakeet.', 57.907135009765625)]


내가 직접 만든 예시

In [36]:
# Hand-written query: hobbies question with no speaker named.
question = "Do you have any hobbies or interests outside of school and work?"
personalist = [
    "Speaker 1 is professional basketball player.",
    "Speaker 2 is high school student.",
    "Speaker 2 is waiting to get a job.",
    "Speaker 2 would like to become an engineer.",
    "Speaker 2 used to play basketball.",
    "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
59 62 56 62 62 63 
[("Speaker 2's dad is strict, preventing the speaker 2 from doing basketball "
  'as a profession.',
  63.28627014160156),
 ('Speaker 2 would like to become an engineer.', 62.90799331665039),
 ('Speaker 2 is high school student.', 62.81760025024414)]


In [37]:
# Same hobbies query, but naming "Speaker 1" explicitly in the question.
question = "Does Speaker 1 have any hobbies or interests outside of school and work?"
personalist = [
    "Speaker 1 is professional basketball player.",
    "Speaker 2 is high school student.",
    "Speaker 2 is waiting to get a job.",
    "Speaker 2 would like to become an engineer.",
    "Speaker 2 used to play basketball.",
    "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
74 74 67 68 71 72 
[('Speaker 1 is professional basketball player.', 74.68264770507812),
 ('Speaker 2 is high school student.', 74.30900573730469),
 ("Speaker 2's dad is strict, preventing the speaker 2 from doing basketball "
  'as a profession.',
  72.72579193115234)]


In [38]:
# Hobbies query reframed around childhood ("When you were a kid, ...").
question = "When you were a kid, did you have any hobbies or interests outside of school and work?"
personalist = [
    "Speaker 1 is professional basketball player.",
    "Speaker 2 is high school student.",
    "Speaker 2 is waiting to get a job.",
    "Speaker 2 would like to become an engineer.",
    "Speaker 2 used to play basketball.",
    "Speaker 2's dad is strict, preventing the speaker 2 from doing basketball as a profession.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
56 60 54 59 61 60 
[('Speaker 2 used to play basketball.', 61.53269577026367),
 ('Speaker 2 is high school student.', 60.758575439453125),
 ("Speaker 2's dad is strict, preventing the speaker 2 from doing basketball "
  'as a profession.',
  60.383460998535156)]


페르소나 수를 늘려보자

In [39]:
# Larger pool: six persona sentences per speaker, queried about outdoor activity.
question = "Do you enjoy outdoor adventures or prefer staying indoors?"
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
62 60 53 62 55 58 55 54 61 62 50 54 
[('Speaker 1 has a fear of deep water.', 62.31932830810547),
 ('Speaker 2 loves mystery novels and has a collection.', 62.11561965942383),
 ('Speaker 1 is passionate rock climber.', 62.04978561401367)]


In [40]:
# Twelve-persona pool, queried about a profession that captures moments.
question = "Your perspective on things is unique. Does your profession involve capturing moments or emotions?"
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
59 57 59 52 65 58 60 54 61 55 50 49 
[('Speaker 1 is a professional photographer.', 65.19570922851562),
 ('Speaker 2 is studying marine biology.', 61.44655990600586),
 ('Speaker 2 is a professional dancer specializing in contemporary dance.',
  60.75514602661133)]


In [41]:
# Twelve-persona pool, queried about an art form involving movement.
question = "You have such a fluid movement. Is there a specific art form you're involved in?"
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
62 61 63 59 60 59 67 60 62 61 56 56 
[('Speaker 2 is a professional dancer specializing in contemporary dance.',
  67.8241195678711),
 ('Speaker 1 is a trained classical pianist.', 63.435401916503906),
 ('Speaker 1 is passionate rock climber.', 62.821388244628906)]


In [42]:
# Twelve-persona pool, queried about having lived in multiple places.
question = "You seem to have a diverse cultural understanding. Have you lived in multiple places?"
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
57 60 57 55 58 60 58 66 57 57 54 50 
[('Speaker 2 has lived in three different countries.', 66.7225112915039),
 ('Speaker 1 volunteered in a wildlife sanctuary last summer.',
  60.29518127441406),
 ('Speaker 1 grew up on a farm in the countryside.', 60.26553726196289)]


질문 말고 평문으로 해보자

In [43]:
# Declarative query (a statement, not a question) about physical challenges in nature.
question = "I've always admired people who challenge themselves physically, especially in nature."
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
65 56 58 63 58 55 59 53 60 58 52 54 
[('Speaker 1 is passionate rock climber.', 65.91217041015625),
 ('Speaker 1 has a fear of deep water.', 63.267154693603516),
 ('Speaker 2 is studying marine biology.', 60.49422836303711)]


In [44]:
# Declarative query about capturing moments.
question = "Capturing moments can be so powerful. I've always believed that some people can see the world in a different light."
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
57 57 57 51 60 57 58 54 58 57 53 50 
[('Speaker 1 is a professional photographer.', 60.550079345703125),
 ('Speaker 2 is studying marine biology.', 58.92174530029297),
 ('Speaker 2 is a professional dancer specializing in contemporary dance.',
  58.4135856628418)]


In [45]:
# Declarative query about movement resembling an art performance.
question = "There's something about the way you move; it reminds me of an art performance."
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
57 56 59 52 56 56 63 56 58 54 53 51 
[('Speaker 2 is a professional dancer specializing in contemporary dance.',
  63.07608413696289),
 ('Speaker 1 is a trained classical pianist.', 59.283756256103516),
 ('Speaker 2 is studying marine biology.', 58.70583724975586)]


In [46]:
# Declarative query about a worldview shaped by various cultures.
question = "Your worldview seems so expansive, like you've absorbed bits from various cultures."
personalist = [
    "Speaker 1 is passionate rock climber.",
    "Speaker 1 grew up on a farm in the countryside.",
    "Speaker 1 is a trained classical pianist.",
    "Speaker 1 has a fear of deep water.",
    "Speaker 1 is a professional photographer.",
    "Speaker 1 volunteered in a wildlife sanctuary last summer.",
    "Speaker 2 is a professional dancer specializing in contemporary dance.",
    "Speaker 2 has lived in three different countries.",
    "Speaker 2 is studying marine biology.",
    "Speaker 2 loves mystery novels and has a collection.",
    "Speaker 2 has a twin sibling.",
    "Speaker 2 is allergic to peanuts.",
]
# encoded_summaries=None: the retriever encodes the persona sentences on the fly.
pprint(
    rt.retrieve_top_summaries(
        question=question,
        summaries=personalist,
        encoded_summaries=None,
        topk=3,
    )
)

all scores : 
53 55 56 48 52 52 53 58 56 57 55 50 
[('Speaker 2 has lived in three different countries.', 58.41655349731445),
 ('Speaker 2 loves mystery novels and has a collection.', 57.47710037231445),
 ('Speaker 1 is a trained classical pianist.', 56.0594367980957)]
