In [13]:
import os
# Must be set before torch initializes CUDA (hence before `import torch`): the process then
# only sees physical GPU 4, which appears inside this process as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"]="4"
import torch
import json
from ast import literal_eval # plain eval() would also turn a string into a list; literal_eval does it while accepting only Python literals

In [14]:
import numpy as np

from transformers import AutoTokenizer, DPRQuestionEncoder, DPRContextEncoder
from typing import List

from pprint import pprint

class BiEncoderRetriever:
    """Rank candidate texts against a question with a pair of DPR encoders.

    The question and the candidates ("summaries") are embedded by separate DPR
    encoders; candidates are scored by inner product with the question embedding
    and the highest-scoring ones are returned.
    """

    DEFAULT_QUESTION_MODEL = "sivasankalpp/dpr-multidoc2dial-structure-question-encoder"
    DEFAULT_CTX_MODEL = "sivasankalpp/dpr-multidoc2dial-structure-ctx-encoder"

    def __init__(self, question_model: str = DEFAULT_QUESTION_MODEL, ctx_model: str = DEFAULT_CTX_MODEL, device: str = None) -> None:
        """Load the tokenizer and both encoders.

        Args:
            question_model: Hub id or path of the question encoder. Its tokenizer is
                used for both questions and candidates.
            ctx_model: Hub id or path of the context (candidate) encoder.
            device: torch device string; defaults to "cuda:0" when CUDA is available,
                otherwise "cpu".
        """
        self.device = device or ("cuda:0" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(question_model)
        self.question_encoder = DPRQuestionEncoder.from_pretrained(question_model).to(self.device)
        self.ctxt_encoder = DPRContextEncoder.from_pretrained(ctx_model).to(self.device)

    def _tokenize(self, texts, max_length: int):
        """Tokenize `texts` to padded/truncated tensors of length `max_length` on `self.device`."""
        inputs = self.tokenizer(texts, padding='max_length', max_length=max_length, truncation=True, return_tensors="pt").to(self.device)
        # The encoders are called without segment ids; pop() also tolerates tokenizers that never emit them.
        inputs.pop("token_type_ids", None)
        return inputs

    @torch.no_grad()
    def encode_summaries(self, summaries: List[str], max_length: int = 128):
        """Embed candidate texts with the context encoder.

        Returns the encoder's `pooler_output`, one row per input text, computed without
        autograd tracking (so `.cpu().numpy()` works directly on the result).
        """
        return self.ctxt_encoder(**self._tokenize(summaries, max_length))['pooler_output']

    @torch.no_grad()
    def encode_question(self, question: str, max_length: int = 32):
        """Embed a single question with the question encoder; returns its `pooler_output`."""
        return self.question_encoder(**self._tokenize(question, max_length))['pooler_output']

    @torch.no_grad()
    def retrieve_top_summaries(self, question: str, summaries: List[str], encoded_summaries: np.ndarray = None, topk: int = 5):
        """Return the `topk` candidates scoring highest against `question`.

        Args:
            question: Query text.
            summaries: Candidate texts.
            encoded_summaries: Optional precomputed candidate embeddings (array or tensor)
                whose row i corresponds to summaries[i]; encoded on the fly when None.
            topk: Number of candidates to return.

        Returns:
            `summaries` itself, unranked, when `topk >= len(summaries)`. Otherwise a list
            of `topk` candidates in descending score order.
        """
        # Nothing would be filtered out: skip the (expensive) encoder passes entirely.
        if topk >= len(summaries):
            return summaries

        encoded_question = self.encode_question(question)
        if encoded_summaries is None:
            encoded_summaries = self.encode_summaries(summaries)
        else:
            # numpy arrays default to float64; torch.mm needs the same dtype/device as the question.
            encoded_summaries = torch.as_tensor(encoded_summaries).to(device=encoded_question.device, dtype=encoded_question.dtype)

        # Inner-product scores of shape (1, n); take row 0 so topk returns a 1-D index tensor
        # even when topk == 1 (squeeze() would produce a non-iterable 0-d tensor there).
        scores = torch.mm(encoded_question, encoded_summaries.T)[0]
        top_indices = torch.topk(scores, topk).indices.tolist()
        return [summaries[i] for i in top_indices]

In [15]:
rt = BiEncoderRetriever()

In [21]:
# For each utterance of the first N_DIALOGS dialogs, keep the TOPK persona sentences the
# bi-encoder scores highest, and write one JSON object per dialog to OUTPUT_PATH.
INPUT_PATH = "./persona_dialog.jsonl"
OUTPUT_PATH = "./output.jsonl"
N_DIALOGS = 500  # number of input lines (dialogs) to process
TOPK = 3         # persona sentences kept per utterance (the "3persona" key below)

print("start!")
with open(INPUT_PATH, "r", encoding="utf-8") as dialogs, open(OUTPUT_PATH, "w", encoding="utf-8") as fout:
    for num, dialog in enumerate(dialogs):
        if num == N_DIALOGS:
            break
        print(f"now {num} dialog")
        # Lines are parsed as Python literals rather than strict JSON; literal_eval is the
        # safe alternative to eval() for this.
        dic = literal_eval(dialog)
        dialog_list = dic["current"]
        persona_list = dic["persona_list"]

        # The candidate set is identical for every utterance of a dialog: encode it once instead
        # of once per utterance. Skipped when retrieval returns the whole list anyway.
        encoded_personas = None
        if len(persona_list) > TOPK:
            encoded_personas = rt.encode_summaries(persona_list).detach().cpu().numpy()

        epl_list = []
        for utterance in dialog_list:
            # Drop the first two whitespace-separated tokens (presumably a speaker prefix — TODO
            # confirm against the data) and rejoin the rest with spaces so words stay separated.
            question = " ".join(utterance.split()[2:])
            top_personas = rt.retrieve_top_summaries(question, persona_list, encoded_personas, topk=TOPK)
            epl_list.append({"utterance": utterance, "3persona": top_personas})

        newdic = {"persona_list": persona_list, "extracted_persona_list": epl_list}
        fout.write(json.dumps(newdic, ensure_ascii=False) + "\n")

start!
now 0 dialog
now 1 dialog
now 2 dialog
now 3 dialog
now 4 dialog
now 5 dialog
now 6 dialog
now 7 dialog
now 8 dialog
now 9 dialog
now 10 dialog
now 11 dialog
now 12 dialog
now 13 dialog
now 14 dialog
now 15 dialog
now 16 dialog
now 17 dialog
now 18 dialog
now 19 dialog
now 20 dialog
now 21 dialog
now 22 dialog
now 23 dialog
now 24 dialog
now 25 dialog
now 26 dialog
now 27 dialog
now 28 dialog
now 29 dialog
now 30 dialog
now 31 dialog
now 32 dialog
now 33 dialog
now 34 dialog
now 35 dialog
now 36 dialog
now 37 dialog
now 38 dialog
now 39 dialog
now 40 dialog
now 41 dialog
now 42 dialog
now 43 dialog
now 44 dialog
now 45 dialog
now 46 dialog
now 47 dialog


In [8]:
# Sanity check on one hand-written persona. encoded_summaries=None, so the candidates are
# embedded inside the call rather than taken from a precomputed array.
# The last candidate is the question text itself, presumably so it should rank first.
question = "What is Sarah's favorite animal?"
personalist = [
    'Sarah is 24 years old.',
    'Sarah currently lives in Canada.',
    "Sarah is a swim coach at Sarah's local pool.",
    'Sarah is studying to be a computer programmer.',
    'Sarah is also a graduate student.',
    'Sarah is now looking for a new job.',
    "Sarah's mother is very traditional while Sarah prefers to be more free spirited.",
    "Sarah's family and Sarah are from India.",
    "Sarah's favorite music genre is death metal.",
    'Sarah is a famous twitch streamer.',
    'Sarah likes watching war documentaries.',
    "Sarah's favorite food is mexican food.",
    "What is Sarah's favorite animal?",
]
pprint(rt.retrieve_top_summaries(question, personalist, encoded_summaries=None, topk=3))

["What is Sarah's favorite animal?",
 "Sarah's favorite food is mexican food.",
 "Sarah's mother is very traditional while Sarah prefers to be more free "
 'spirited.']
