In [1]:
from transformers import AutoModel, AutoTokenizer

checkpoint = "Salesforce/codet5p-110m-embedding"
device = 'cuda'

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to(device)

# gpu oom을 방지하기 위한 cpu 모델 추가
cpu_model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to('cpu')

In [2]:
from datasets import load_dataset

dataset = load_dataset('KonradSzafer/stackoverflow_python_preprocessed')

# answer list 구성
answer_texts = [data['answer'] for data in dataset['train']]

In [3]:
from torch.utils.data import Dataset, DataLoader


class TokenDataset(Dataset):
    def __init__(self, answer_texts, tokenizer):
        super().__init__()
        answer_tokens_ids = [tokenizer.encode(text, return_tensors="pt")[0] for text in answer_texts]
        self.answer_tokens_ids = answer_tokens_ids
        self.len = len(answer_texts)
    
    def __len__(self):
        return self.len

    def __getitem__(self, index):
        return self.answer_tokens_ids[index]


# 임베딩을 구하기 위한 dataloader 클래스 구성
token_dataset = TokenDataset(answer_texts, tokenizer)
token_dataloader = DataLoader(token_dataset, batch_size=1, shuffle=False)

Token indices sequence length is longer than the specified maximum sequence length for this model (678 > 512). Running this sequence through the model will result in indexing errors


In [4]:
from tqdm import tqdm
import torch
import pickle

answer_embeddings = []

for batch in tqdm(token_dataloader):
    if len(batch[0]) <= 8000:
        with torch.no_grad():
            embedding = model(batch.to(device))[0].detach().cpu()
    else:
        with torch.no_grad():
            embedding = cpu_model(batch)[0]
    answer_embeddings.append(embedding)


with open('answer_embeddings.pkl', 'wb') as f:
    pickle.dump(answer_embeddings, f)

100%|██████████| 3296/3296 [02:15<00:00, 24.36it/s] 


In [4]:
from tqdm import tqdm
import torch
import pickle

with open('answer_embeddings.pkl', 'rb') as f:
    answer_embeddings = pickle.load(f)

In [5]:
# 정규화 한 임베딩 텐서 구성
norm_answer_embeddings = [torch.nn.functional.normalize(answer, p=2, dim=-1) for answer in answer_embeddings]
answer_embeddings_concat = torch.stack(norm_answer_embeddings).detach().cpu()
answer_embeddings_concat.shape

torch.Size([3296, 256])

In [6]:
# 중복된 query를 방지하기 위해 unique한 query list 및 query 별 answer list 구성
# example
# unique_questions = [query_1, query_2, query_3, ...]
# all_answers = [[answer_1_for_query_1, answer_2_for_query_2, ...], [answer_1_for_query_2, answer_2_for_query_2, ...], ...]

prev_question = ''
prev_answers = []
unique_questions = []
unique_answers = []
all_answers = []

for data in dataset['train']:
    title = data['title']
    question = data['question']
    answer = data['answer']
    if prev_question != question:
        if prev_answers:
            all_answers.append(prev_answers)
            prev_answers = []
        prev_question = question
        unique_questions.append(title + '\n' + question)
        unique_answers.append(answer)
        prev_answers.append(answer)
    else:
        prev_answers.append(answer)
all_answers.append(prev_answers)

In [13]:
def get_relevant_documents(query, topk=5):
    question_tokens_ids = tokenizer.encode(query, return_tensors="pt")
    if len(question_tokens_ids[0]) <= 5000:
        question_embedding = model(question_tokens_ids.to(device))[0]
    else:
        question_embedding = cpu_model(question_tokens_ids)[0]
    
    question_tokens_ids.detach().cpu()
    question_embedding.detach().cpu()
    norm_question_embedding = torch.nn.functional.normalize(question_embedding, p=2, dim=-1)
    similarity = torch.matmul(answer_embeddings_concat.to('cpu'), norm_question_embedding.T.to('cpu'))
    argsorted = list(torch.argsort(similarity, descending=True)[:topk])
    return [answer_texts[idx] for idx in argsorted]

In [14]:
# GT answer 중 하나라도 retrieve가 되면 true로 판별

is_retrieved = []
for query, answers in tqdm(zip(unique_questions, all_answers)):
    retrived_documents = get_relevant_documents(query)
    retrieved = False
    for answer in answers:
        if not retrieved:
            for doc in retrived_documents:
                if answer == doc:
                    retrieved = True
                    break
    is_retrieved.append(retrieved)

  similarity = torch.matmul(answer_embeddings_concat.to('cpu'), norm_question_embedding.T.to('cpu'))
962it [02:18,  6.97it/s]


In [17]:
true_ratio = sum([1 for res in is_retrieved if res ])/len(is_retrieved)
print(f'True Ratio:{true_ratio*100:0.2f}%')

True Ratio:76.30%
