In [1]:
import torch
from unixcoder import UniXcoder

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Load the pretrained checkpoint, move it to the chosen device and switch it
# to evaluation mode.
model = UniXcoder("microsoft/unixcoder-base")
model = model.to(device)
model = model.eval()

In [2]:
from datasets import load_dataset

dataset = load_dataset('KonradSzafer/stackoverflow_python_preprocessed')

# Build the answer list: one answer string per row of the train split,
# then turn every answer into its token-id sequence.
answer_texts = list(map(lambda row: row['answer'], dataset['train']))
answer_token_ids = model.tokenize(answer_texts)

In [5]:
from torch.utils.data import Dataset, DataLoader

class TokenDataset(Dataset):
    """Map-style dataset that serves pre-tokenized sequences as tensors.

    Args:
        tokens_ids: Indexable, sized collection holding one sequence of
            integer token ids per sample.
    """

    def __init__(self, tokens_ids):
        super().__init__()
        self.tokens_ids = tokens_ids
        self.len = len(tokens_ids)

    def __len__(self):
        """Return the number of samples."""
        return self.len

    def __getitem__(self, index):
        """Return sample `index` as a ``torch.long`` tensor of token ids."""
        # torch.Tensor(...) would build a float32 tensor, which cannot represent
        # every integer above 2**24 and forces callers to cast back. Token ids
        # are indices, so create them as integers from the start.
        return torch.tensor(self.tokens_ids[index], dtype=torch.long)
    
# Wrap the tokenized answers and iterate over them in dataset order, one
# sequence per batch.
# NOTE(review): batch_size=1 presumably because the sequences are unpadded and
# differ in length — confirm before increasing it.
token_dataset = TokenDataset(tokens_ids=answer_token_ids)
token_dataloader = DataLoader(dataset=token_dataset, shuffle=False, batch_size=1)

In [13]:
from tqdm import tqdm

token_embeddings = []

# Inference only: disable autograd so no computation graph is built (and then
# thrown away) for every forward pass.
with torch.no_grad():
    for batch in tqdm(token_dataloader):
        inputs = batch.to(device).long()
        # The model returns a pair; only its second element is kept.
        # NOTE(review): presumably the per-sequence embedding, shape (1, 768),
        # judging by the stacked shape printed further down — confirm.
        _, token_embedding = model(inputs)

        # Keep results on the CPU so device memory does not grow with the
        # number of answers.
        token_embeddings.append(token_embedding.detach().cpu())

100%|██████████| 3296/3296 [00:39<00:00, 82.75it/s] 


In [14]:
# Build the L2-normalized answer embedding matrix.
norm_answer_embeddings = [torch.nn.functional.normalize(answer, p=2, dim=-1) for answer in token_embeddings]
answer_embeddings_concat = torch.stack(norm_answer_embeddings).detach().cpu()
# Collapse the per-item batch axis explicitly. A dimension-less squeeze() would
# also drop the leading axis when there is exactly one answer, leaving a 1-D
# tensor instead of a (1, hidden) matrix.
# NOTE(review): assumes every entry holds a single embedding (batch size 1).
answer_embeddings_concat = answer_embeddings_concat.reshape(len(norm_answer_embeddings), -1)
answer_embeddings_concat.shape

torch.Size([3296, 768])

In [15]:
# To avoid evaluating the same query more than once, build a list of unique
# queries and, for each of them, the list of its answers.
# example
# unique_questions = [query_1, query_2, query_3, ...]
# all_answers = [[answer_1_for_query_1, answer_2_for_query_1, ...], [answer_1_for_query_2, answer_2_for_query_2, ...], ...]
#
# NOTE: only *consecutive* rows sharing the same question are merged, so this
# relies on the dataset keeping all answers to one question adjacent.

prev_question = ''
prev_answers = []
unique_questions = []   # "<title>\n<question>" per group
unique_answers = []     # first answer of each group
all_answers = []        # all answers of each group, aligned with unique_questions

for data in dataset['train']:
    title = data['title']
    question = data['question']
    answer = data['answer']
    # `not all_answers` forces the very first row to open a group even if its
    # question happens to equal the initial sentinel (''), and an empty dataset
    # leaves all three lists empty and aligned.
    if not all_answers or prev_question != question:
        prev_question = question
        prev_answers = []
        all_answers.append(prev_answers)
        unique_questions.append(title + '\n' + question)
        unique_answers.append(answer)
    prev_answers.append(answer)

In [46]:
def get_relevant_documents(query, topk=5):
    """Return the answers whose embeddings are most similar to `query`.

    The query is embedded with the module-level `model`, L2-normalized, and
    scored against every row of `answer_embeddings_concat` with a dot product
    (cosine similarity, since those rows are L2-normalized when the matrix is
    built).

    Args:
        query: Query text to embed.
        topk: Maximum number of answers to return.

    Returns:
        Up to `topk` entries of `answer_texts`, ordered from most to least
        similar.
    """
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        question_tokens_ids = torch.tensor(model.tokenize([query]), dtype=torch.long).to(device)
        _, question_embedding = model(question_tokens_ids)

    # Take the first (only) row of the batch and score on the CPU, where the
    # answer matrix is moved as well.
    norm_question_embedding = torch.nn.functional.normalize(question_embedding[0], p=2, dim=-1).cpu()
    similarity = torch.matmul(answer_embeddings_concat.cpu(), norm_question_embedding)

    # .tolist() yields plain Python ints for indexing the answer list.
    top_indices = torch.argsort(similarity, descending=True)[:topk].tolist()
    return [answer_texts[idx] for idx in top_indices]

In [49]:
# A query counts as a hit (True) if at least one of its ground-truth (GT)
# answers is among the retrieved documents.

is_retrieved = []
# zip() has no length, so pass `total` explicitly to get a real progress bar.
num_queries = min(len(unique_questions), len(all_answers))
for query, answers in tqdm(zip(unique_questions, all_answers), total=num_queries):
    retrieved_documents = get_relevant_documents(query)
    is_retrieved.append(any(answer in retrieved_documents for answer in answers))

962it [00:13, 71.92it/s]


In [51]:
# Share of queries for which at least one ground-truth answer was retrieved.
hit_count = sum(1 for flag in is_retrieved if flag)
true_ratio = hit_count / len(is_retrieved)
print(f'True Ratio:{true_ratio*100:0.2f}%')

True Ratio:59.88%
