In [13]:
import os
# Restrict this process to physical GPU 4. The variable is set before
# `import torch` so that it is already in place when CUDA is first initialised.
# NOTE(review): with a single visible device, that GPU should be addressed as
# "cuda:0" inside this process — confirm on the target machine.
os.environ["CUDA_VISIBLE_DEVICES"]="4"
import torch
import json
from ast import literal_eval # Parses a string that holds a Python literal (e.g. a list or dict) into the corresponding object.

In [14]:
import numpy as np

from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder
from typing import List

from pprint import pprint

class BiEncoderRetriever:
    """Dense bi-encoder retriever that ranks candidate texts against a question.

    A question encoder and a context encoder (both loaded from pretrained DPR
    checkpoints) map text to vectors; candidates are ranked by the dot product
    between the question vector and each candidate vector.
    """

    # Checkpoint names and tokenizer limits, kept in one place instead of being
    # repeated as literals throughout the methods.
    QUESTION_MODEL_NAME = "sivasankalpp/dpr-multidoc2dial-structure-question-encoder"
    CONTEXT_MODEL_NAME = "sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder"
    SUMMARY_MAX_LENGTH = 128
    QUESTION_MAX_LENGTH = 32

    def __init__(self) -> None:
        """Load the tokenizer and both encoders and move the encoders to the device."""
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        # The question-encoder tokenizer is used for questions and summaries alike.
        self.tokenizer = AutoTokenizer.from_pretrained(self.QUESTION_MODEL_NAME)
        self.question_encoder = DPRQuestionEncoder.from_pretrained(self.QUESTION_MODEL_NAME).to(self.device)
        self.ctxt_encoder = DPRContextEncoder.from_pretrained(self.CONTEXT_MODEL_NAME).to(self.device)

    def _tokenize(self, texts, max_length: int):
        """Tokenize `texts`, padded/truncated to `max_length`, on `self.device`."""
        input_dict = self.tokenizer(
            texts,
            padding='max_length',
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)
        # The encoders are called without token_type_ids. `pop` with a default
        # avoids a KeyError when the tokenizer does not produce that key.
        input_dict.pop("token_type_ids", None)
        return input_dict

    def encode_summaries(self, summaries: List[str]):
        """Encode candidate texts with the context encoder.

        Args:
            summaries: Candidate texts; each is padded/truncated to
                `SUMMARY_MAX_LENGTH` tokens.

        Returns:
            The context encoder's `pooler_output` tensor (on `self.device`),
            computed without gradient tracking so it can be converted to NumPy
            and cached.
        """
        input_dict = self._tokenize(summaries, self.SUMMARY_MAX_LENGTH)
        with torch.no_grad():
            return self.ctxt_encoder(**input_dict)['pooler_output']

    def encode_question(self, question: str):
        """Encode a question with the question encoder.

        Args:
            question: Query text; padded/truncated to `QUESTION_MAX_LENGTH` tokens.

        Returns:
            The question encoder's `pooler_output` tensor (on `self.device`),
            computed without gradient tracking.
        """
        input_dict = self._tokenize(question, self.QUESTION_MAX_LENGTH)
        with torch.no_grad():
            return self.question_encoder(**input_dict)['pooler_output']

    def retrieve_top_summaries(self, question: str, summaries: List[str], encoded_summaries: np.ndarray = None, topk: int = 5):
        """Return the `topk` entries of `summaries` that score highest for `question`.

        Args:
            question: Query text.
            summaries: Candidate texts to rank.
            encoded_summaries: Optional precomputed embeddings of `summaries`
                (one row per summary), as a NumPy array or a torch tensor.
                When None the summaries are encoded on the fly.
            topk: Number of summaries to return.

        Returns:
            A list of `topk` summaries ordered by descending score. When
            `topk >= len(summaries)` the input list itself is returned,
            unranked, and nothing is encoded.

        Raises:
            ValueError: If `encoded_summaries` does not have one row per summary.
        """
        # Nothing to rank: skip both encoder passes (this also covers an empty
        # candidate list, which would otherwise reach the tokenizer).
        if topk >= len(summaries):
            return summaries

        encoded_question = self.encode_question(question)
        if encoded_summaries is None:
            encoded_summaries = self.encode_summaries(summaries)
        else:
            # Accepts NumPy arrays as well as tensors.
            encoded_summaries = torch.as_tensor(encoded_summaries)
        # Align device and dtype with the question vector so that torch.mm does
        # not fail on, e.g., float64 arrays loaded from disk.
        encoded_summaries = encoded_summaries.to(
            device=encoded_question.device, dtype=encoded_question.dtype
        )
        if encoded_summaries.shape[0] != len(summaries):
            raise ValueError(
                f"encoded_summaries has {encoded_summaries.shape[0]} rows "
                f"but {len(summaries)} summaries were given"
            )

        # assumes encoded_question is a single-row matrix (one question) — TODO confirm
        scores = torch.mm(encoded_question, encoded_summaries.T)
        # Take the first row explicitly instead of squeeze(): squeeze() collapses
        # the result to a 0-d tensor when topk == 1, which cannot be iterated.
        top_indices = torch.topk(scores, topk, dim=1).indices[0].tolist()
        return [summaries[i] for i in top_indices]

In [15]:
# Build the retriever once: the constructor loads the tokenizer and both DPR
# encoders, and the resulting instance is reused by the cells below.
rt = BiEncoderRetriever()

In [16]:
# For every utterance of the first MAX_DIALOGS dialogs, retrieve the TOP_K best
# matching personas and write one JSON object per dialog to OUTPUT_PATH.
INPUT_PATH = "./persona_dialog.jsonl"
OUTPUT_PATH = "./output.jsonl"
MAX_DIALOGS = 500  # number of dialogs (non-blank input lines) to process
TOP_K = 3          # personas kept per utterance (stored under the "3persona" key)


def _parse_dialog_line(line):
    """Parse one input line into a dict.

    JSON is tried first because the input is a .jsonl file; lines that are not
    valid JSON fall back to `literal_eval`, which is how this notebook parsed
    every line before (it accepts Python-literal syntax such as single quotes).
    """
    try:
        return json.loads(line)
    except json.JSONDecodeError:
        return literal_eval(line)


print("start!")

# The output file is opened once in "w" mode: it is truncated at the start of
# the run and stays open, instead of being reopened in append mode per dialog.
with open(INPUT_PATH, "r", encoding="utf-8") as dialogs:
    with open(OUTPUT_PATH, "w", encoding="utf-8") as fout:
        # Skip blank lines (e.g. a trailing newline), which cannot be parsed.
        records = (line for line in dialogs if line.strip())
        for num, dialog in enumerate(records):
            if num >= MAX_DIALOGS:
                break
            print(f"now {num} dialog")

            dic = _parse_dialog_line(dialog)
            dialog_list = dic["current"]
            persona_list = dic["persona_list"]

            # The persona list is the same for every utterance of a dialog, so
            # it is encoded once and the vectors are reused for each query.
            # When there are at most TOP_K personas the retriever returns the
            # list as-is, so no encoding is needed.
            persona_vectors = None
            if dialog_list and len(persona_list) > TOP_K:
                persona_vectors = rt.encode_summaries(persona_list).detach().cpu().numpy()

            epl_list = [
                {
                    "utterance": utterance,
                    "3persona": rt.retrieve_top_summaries(
                        utterance, persona_list, persona_vectors, topk=TOP_K
                    ),
                }
                for utterance in dialog_list
            ]
            newdic = {"persona_list": persona_list, "extracted_persona_list": epl_list}
            # ensure_ascii=False keeps non-ASCII text readable; the file is
            # opened as UTF-8 above so such text can always be written.
            fout.write(json.dumps(newdic, ensure_ascii=False) + "\n")

start!
now 0 dialog
['speaker 1: how is the preparation for your exam going?', "speaker 2: it's not.. really.  i've been procrastinating a bit.  any tips on how to buckle down?", 'speaker 1: procrastinating is a hard habit to break. just stop doing it.', "speaker 2: haha easier said than done!  kind of like monitoring your blood sugar, how's that going for you lately?  do you feel well?", "speaker 1: you are right my friend. it's still all over the place.", 'speaker 2: how does that work?  is it based on what you eat or is hereditary?  would exercising more help at all?', 'speaker 1: they told me that i got it by eating wrong in the first place by eating bad most of my life.', "speaker 2: oh, yeah, that's how my mom is.  she's right at the point where if it gets any higher she would be technically diabetic and have to be put on medicine.  do you take any medicine for it?", 'speaker 1: i use to but through time i was able to control it with diet. i think i got off kilter at christmas wi

In [8]:
# Quick manual check of the retriever on a hand-written persona list.
# The question itself is deliberately included as the last candidate.
question = "What is Sarah's favorite animal?"
personalist = [
    "Sarah is 24 years old.",
    "Sarah currently lives in Canada.",
    "Sarah is a swim coach at Sarah's local pool.",
    "Sarah is studying to be a computer programmer.",
    "Sarah is also a graduate student.",
    "Sarah is now looking for a new job.",
    "Sarah's mother is very traditional while Sarah prefers to be more free spirited.",
    "Sarah's family and Sarah are from India.",
    "Sarah's favorite music genre is death metal.",
    "Sarah is a famous twitch streamer.",
    "Sarah likes watching war documentaries.",
    "Sarah's favorite food is mexican food.",
    "What is Sarah's favorite animal?",
]

# No precomputed encodings are passed, so the personas are encoded on the fly.
top_personas = rt.retrieve_top_summaries(
    question,
    personalist,
    encoded_summaries=None,
    topk=3,
)
pprint(top_personas)

["What is Sarah's favorite animal?",
 "Sarah's favorite food is mexican food.",
 "Sarah's mother is very traditional while Sarah prefers to be more free "
 'spirited.']
