In [1]:
import torch
from unixcoder import UniXcoder

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load the pretrained UniXcoder checkpoint, place it on `device`, and put it
# in evaluation mode for inference.
model = UniXcoder("microsoft/unixcoder-base")
model = model.to(device)
model.eval()
print(device)

cuda


In [2]:
from datasets import load_dataset

# Load the CoNaLa dataset from the Hugging Face Hub; its 'train' and 'test'
# splits are indexed below.
# NOTE(review): loading prints a `trust_remote_code` notice (see the cell
# output) — decide explicitly whether to pass that flag.
conala = load_dataset(path="neulab/conala")

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.


In [3]:
# Inspect the first training record to see which fields are available.
conala["train"][0]

{'question_id': 41067960,
 'intent': 'How to convert a list of multiple integers into a single integer?',
 'rewritten_intent': "Concatenate elements of a list 'x' of multiple integers to a single integer",
 'snippet': 'sum(d * 10 ** i for i, d in enumerate(x[::-1]))'}

In [4]:
# Evaluation functions

def get_hit_rate(k, topk_list, n=None):
    """Return the percentage of queries whose correct item ranked within the top ``k``.

    Args:
        k: Cut-off. A query counts as a hit when its 0-based position is ``< k``.
        topk_list: Iterable of 0-based positions of the correct item, one per query.
        n: Number of queries to divide by. Defaults to ``len(topk_list)``.

    Returns:
        ``hits / n * 100`` (Recall@k as a percentage), or 0.0 when ``n`` is 0.
    """
    # Materialise once so generators and tensors are also accepted.
    topk_list = list(topk_list)
    if n is None:
        n = len(topk_list)
    hits = sum(1 for topk in topk_list if topk < k)
    if n == 0:
        # Nothing to evaluate: report 0% instead of raising ZeroDivisionError.
        return 0.0
    return hits / n * 100

def mean_reciprocal_rank(ranks):
    """Return the mean reciprocal rank (MRR) as a percentage.

    Args:
        ranks: 1-based ranks of the correct item, one per query. Accepts a
            tensor or any sequence that ``torch.as_tensor`` can convert.

    Returns:
        ``mean(1 / rank) * 100`` as a Python float, or 0.0 for empty input.

    Raises:
        ValueError: If any rank is below 1. This catches 0-based positions
            passed without the ``+1`` shift, which would otherwise give ``inf``.
    """
    ranks = torch.as_tensor(ranks).float()
    if ranks.numel() == 0:
        # The mean of an empty tensor is NaN; report 0% instead.
        return 0.0
    if (ranks < 1).any():
        raise ValueError("ranks must be 1-based (every rank >= 1)")
    reciprocal_ranks = 1.0 / ranks
    mrr = torch.mean(reciprocal_ranks)
    return mrr.item() * 100

In [14]:
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

class TokenDataset(Dataset):
    """Map-style dataset of paired (intent, snippet) token-id sequences.

    Item ``i`` pairs ``intent_tokens_ids[i]`` with ``snippet_tokens_ids[i]``.
    Each item is returned as a dict of 1-D int64 tensors under the keys
    ``'intent'`` and ``'snippet'``.
    """

    def __init__(self, intent_tokens_ids, snippet_tokens_ids):
        """Store the two index-aligned sequences of token-id lists.

        Raises:
            ValueError: If the two sequences differ in length, since the
                pairing would otherwise be silently broken.
        """
        super().__init__()
        if len(intent_tokens_ids) != len(snippet_tokens_ids):
            raise ValueError(
                f"intent/snippet length mismatch: "
                f"{len(intent_tokens_ids)} != {len(snippet_tokens_ids)}"
            )
        self.intent_tokens_ids = intent_tokens_ids
        self.snippet_tokens_ids = snippet_tokens_ids
        self.len = len(intent_tokens_ids)

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        # Token ids are integers. Build int64 tensors directly instead of
        # float32 (torch.Tensor), which cannot represent ids above 2**24 exactly.
        return {
            'intent': torch.tensor(self.intent_tokens_ids[index], dtype=torch.long),
            'snippet': torch.tensor(self.snippet_tokens_ids[index], dtype=torch.long),
        }

# Trainset

In [11]:
# Collect (rewritten intent, snippet) training pairs, skipping records whose
# rewritten intent is missing or empty.
raw_train_intent_texts = [data['rewritten_intent'] for data in conala['train']]
train_intent_texts, train_snippet_texts = [], []

for record, intent in zip(conala['train'], raw_train_intent_texts):
    if not intent:
        continue
    train_intent_texts.append(intent)
    train_snippet_texts.append(record['snippet'])

In [12]:
# Tokenise intents and snippets with UniXcoder's tokenizer; the outputs stay
# index-aligned with the text lists.
train_intent_tokens_ids, train_snippet_tokens_ids = map(
    model.tokenize, (train_intent_texts, train_snippet_texts)
)

In [13]:
# Wrap the train token ids in a dataset and iterate it one pair at a time.
# Without shuffling, the embeddings come out in the same order as the texts.
# NOTE(review): batch_size=1 is presumably needed because the sequences are
# not padded to a common length — confirm.
token_dataset = TokenDataset(
    intent_tokens_ids=train_intent_tokens_ids,
    snippet_tokens_ids=train_snippet_tokens_ids,
)
token_dataloader = DataLoader(token_dataset, shuffle=False, batch_size=1)

In [15]:
# Encode every training (intent, snippet) pair with the model and keep the
# embeddings on the CPU, in dataloader order.
train_intent_embeddings = []
train_snippet_embeddings = []

# Pure inference: torch.no_grad() stops autograd from recording a graph on
# every forward pass, which saves memory and time and makes .detach() unnecessary.
with torch.no_grad():
    for batch in tqdm(token_dataloader):
        inputs = batch['intent'].to(device).long()
        outputs = batch['snippet'].to(device).long()
        # The model returns a pair; only the second element is kept as the
        # embedding (presumably the pooled sentence embedding — confirm).
        _, intent_embedding = model(inputs)
        _, snippet_embedding = model(outputs)

        train_intent_embeddings.append(intent_embedding.cpu())
        train_snippet_embeddings.append(snippet_embedding.cpu())

100%|██████████| 2300/2300 [00:39<00:00, 57.87it/s]


In [16]:
# L2-normalise each embedding along dim 1 so that dot products between them
# are cosine similarities.
train_norm_intent_embeddings = [
    torch.nn.functional.normalize(vec, dim=1, p=2)
    for vec in train_intent_embeddings
]
train_norm_snippet_embeddings = [
    torch.nn.functional.normalize(vec, dim=1, p=2)
    for vec in train_snippet_embeddings
]

In [17]:
# Stack the per-item snippet embeddings into a single matrix along dim 0.
train_snippet_emb_concat = torch.cat(train_norm_snippet_embeddings, dim=0)

In [18]:
# For each intent, score every training snippet by dot product (cosine, since
# both sides were L2-normalised) and record the 0-based position at which the
# intent's own snippet (the one with the same index) appears in the ranking.
train_topk_list = []
for query_idx, query_emb in enumerate(tqdm(train_norm_intent_embeddings)):
    scores = train_snippet_emb_concat.matmul(query_emb.T).squeeze()
    ranking = scores.argsort(descending=True)
    position = torch.where(ranking == query_idx)[0].item()
    train_topk_list.append(position)

100%|██████████| 2300/2300 [00:00<00:00, 3307.29it/s]


In [20]:
# Report Recall@k for several cut-offs, then MRR. Positions are 0-based,
# hence the +1 to turn them into ranks.
n = len(train_topk_list)
for cutoff in (10, 50, 100, 200):
    print(f'Recall@{cutoff}:\t{get_hit_rate(cutoff, train_topk_list, n):0.2f}%')
print(f'MRR: {mean_reciprocal_rank(torch.Tensor(train_topk_list)+1):0.2f}%')

Recall@10:	88.91%
Recall@50:	95.87%
Recall@100:	97.39%
Recall@200:	98.65%
MRR: 72.47%


# Testset

In [21]:
# Collect (rewritten intent, snippet) test pairs, skipping records whose
# rewritten intent is missing or empty.
raw_test_intent_texts = [data['rewritten_intent'] for data in conala['test']]
test_intent_texts, test_snippet_texts = [], []

for record, intent in zip(conala['test'], raw_test_intent_texts):
    if not intent:
        continue
    test_intent_texts.append(intent)
    test_snippet_texts.append(record['snippet'])

In [22]:
# Tokenise test intents and snippets; the outputs stay index-aligned with the
# text lists.
test_intent_tokens_ids, test_snippet_tokens_ids = map(
    model.tokenize, (test_intent_texts, test_snippet_texts)
)

In [25]:
# Rebind the dataset/dataloader to the test token ids, iterated one pair at a
# time without shuffling so embeddings follow the text order.
token_dataset = TokenDataset(
    intent_tokens_ids=test_intent_tokens_ids,
    snippet_tokens_ids=test_snippet_tokens_ids,
)
token_dataloader = DataLoader(token_dataset, shuffle=False, batch_size=1)

In [26]:
# Encode every test (intent, snippet) pair with the model and keep the
# embeddings on the CPU, in dataloader order.
test_intent_embeddings = []
test_snippet_embeddings = []

# Pure inference: torch.no_grad() stops autograd from recording a graph on
# every forward pass, which saves memory and time and makes .detach() unnecessary.
with torch.no_grad():
    for batch in tqdm(token_dataloader):
        inputs = batch['intent'].to(device).long()
        outputs = batch['snippet'].to(device).long()
        # The model returns a pair; only the second element is kept as the
        # embedding (presumably the pooled sentence embedding — confirm).
        _, intent_embedding = model(inputs)
        _, snippet_embedding = model(outputs)

        test_intent_embeddings.append(intent_embedding.cpu())
        test_snippet_embeddings.append(snippet_embedding.cpu())

100%|██████████| 477/477 [00:09<00:00, 52.93it/s]


In [27]:
# L2-normalise each test embedding along dim 1 so dot products are cosine
# similarities.
test_norm_intent_embeddings = [
    torch.nn.functional.normalize(vec, dim=1, p=2)
    for vec in test_intent_embeddings
]
test_norm_snippet_embeddings = [
    torch.nn.functional.normalize(vec, dim=1, p=2)
    for vec in test_snippet_embeddings
]

In [28]:
# Stack the per-item test snippet embeddings into one matrix along dim 0.
test_snippet_emb_concat = torch.cat(test_norm_snippet_embeddings, dim=0)

In [30]:
# For each test intent, rank all test snippets by dot-product similarity and
# record the 0-based position of the snippet with the same index.
test_topk_list = []
for query_idx, query_emb in enumerate(tqdm(test_norm_intent_embeddings)):
    scores = test_snippet_emb_concat.matmul(query_emb.T).squeeze()
    ranking = scores.argsort(descending=True)
    position = torch.where(ranking == query_idx)[0].item()
    test_topk_list.append(position)

100%|██████████| 477/477 [00:00<00:00, 8526.53it/s]


In [31]:
# Report Recall@k for several cut-offs, then MRR. Positions are 0-based,
# hence the +1 to turn them into ranks.
n = len(test_topk_list)
for cutoff in (10, 50, 100, 200):
    print(f'Recall@{cutoff}:\t{get_hit_rate(cutoff, test_topk_list, n):0.2f}%')
print(f'MRR: {mean_reciprocal_rank(torch.Tensor(test_topk_list)+1):0.2f}%')

Recall@10:	95.81%
Recall@50:	99.79%
Recall@100:	100.00%
Recall@200:	100.00%
MRR: 79.63%


# Trainset + Testset

In [32]:
# Pool both splits: train items first, then test. Intents and snippets stay
# index-aligned because both lists are extended in the same order.
total_norm_intent_embeddings = [*train_norm_intent_embeddings, *test_norm_intent_embeddings]
total_norm_snippet_embeddings = [*train_norm_snippet_embeddings, *test_norm_snippet_embeddings]

In [33]:
# Stack the pooled snippet embeddings into one matrix along dim 0.
total_snippet_emb_concat = torch.cat(total_norm_snippet_embeddings, dim=0)

In [35]:
# For each pooled intent, rank all pooled snippets by dot-product similarity
# and record the 0-based position of the snippet with the same index.
total_topk_list = []
for query_idx, query_emb in enumerate(tqdm(total_norm_intent_embeddings)):
    scores = total_snippet_emb_concat.matmul(query_emb.T).squeeze()
    ranking = scores.argsort(descending=True)
    position = torch.where(ranking == query_idx)[0].item()
    total_topk_list.append(position)

100%|██████████| 2777/2777 [00:00<00:00, 3192.27it/s]


In [36]:
# Report Recall@k for several cut-offs, then MRR. Positions are 0-based,
# hence the +1 to turn them into ranks.
n = len(total_topk_list)
for cutoff in (10, 50, 100, 200):
    print(f'Recall@{cutoff}:\t{get_hit_rate(cutoff, total_topk_list, n):0.2f}%')
print(f'MRR: {mean_reciprocal_rank(torch.Tensor(total_topk_list)+1):0.2f}%')

Recall@10:	87.94%
Recall@50:	95.14%
Recall@100:	97.23%
Recall@200:	98.38%
MRR: 71.33%
