In [1]:
import torch
import torch.nn.functional as F
from transformers import RobertaTokenizer, RobertaModel
from collections import defaultdict
import json

In [2]:
# The tokenizer and the encoder are both loaded from the same pretrained
# checkpoint, so the checkpoint id is spelled out once and reused.
CODEBERT_CHECKPOINT = "microsoft/codebert-base"

tokenizer = RobertaTokenizer.from_pretrained(CODEBERT_CHECKPOINT)
model = RobertaModel.from_pretrained(CODEBERT_CHECKPOINT)

In [None]:
# Load the pre-computed clusters from disk into `cluster`, keyed by integer id.
# Looking up an id that is not in the file yields an empty list (defaultdict).
cluster = defaultdict(list)

# JSON interchange text is UTF-8; passing the encoding explicitly keeps the read
# independent of the platform's locale default.
with open("cluster.json", "r", encoding="utf-8") as f:
    data = json.load(f)

# The file is fully parsed above, so the handle does not need to stay open here.
# JSON object keys are always strings, hence the int() conversion.
# Only the first element of every entry is kept — presumably the datapoint dict
# (with "query"/"code" fields); TODO confirm what the remaining elements hold.
for k, value in data.items():
    cluster[int(k)] = [v[0] for v in value]

In [12]:
# Upper bound on how many datapoints are taken from a cluster for this
# experiment; the cluster list is sliced to at most this many entries.
MAX_TEST_SIZE: int = 10

In [None]:
# Take at most MAX_TEST_SIZE datapoints from cluster 0 and tokenize each one.
cluster_0 = cluster[0]
test_datapoints = cluster_0[:MAX_TEST_SIZE]

# Options common to both tokenizer calls: truncate anything longer than
# max_length, pad anything shorter up to max_length, return PyTorch tensors.
shared_tokenizer_kwargs = dict(
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

# NOTE: despite the "embeddings" in their names, these lists hold the tokenizer
# output (e.g. "input_ids"), one entry per datapoint — not model embeddings.
query_embeddings_list, code_embeddings_list = [], []
for td in test_datapoints:
    # Queries are fixed to 32 tokens, code snippets to 128 tokens.
    query_encoding = tokenizer(td["query"], max_length=32, **shared_tokenizer_kwargs)
    query_embeddings_list.append(query_encoding)
    code_encoding = tokenizer(td["code"], max_length=128, **shared_tokenizer_kwargs)
    code_embeddings_list.append(code_encoding)

In [None]:
# Sanity check: print the token count of the first tokenized query, then of the
# first tokenized code snippet (first row of each "input_ids" tensor).
for encodings in (query_embeddings_list, code_embeddings_list):
    first_input_ids = encodings[0]["input_ids"]
    print(len(first_input_ids[0]))

32
128


In [None]:
# Run every tokenized query / code snippet through the encoder and keep the
# last-layer hidden states, one tensor per datapoint.
query_hidden_states = []
code_hidden_states = []

# Inference only: no gradients are needed, so autograd is disabled for the
# whole loop instead of being re-entered on every iteration.
with torch.no_grad():
    # Iterate over the encodings that actually exist rather than
    # range(MAX_TEST_SIZE): the cluster may contain fewer than MAX_TEST_SIZE
    # datapoints, in which case indexing by position would raise IndexError.
    for query_inputs, code_inputs in zip(query_embeddings_list, code_embeddings_list):
        query_hidden_states.append(model(**query_inputs).last_hidden_state)
        code_hidden_states.append(model(**code_inputs).last_hidden_state)

In [None]:
# Record the shape of every hidden-state tensor as a plain
# (batch_size, sequence_length, embed_dim) tuple.
# The unpacking also asserts that each tensor is 3-dimensional.
query_embed_dims = []
code_embed_dims = []

for qhs in query_hidden_states:
    batch_size, query_len, embed_dim = qhs.size()
    query_embed_dims.append((batch_size, query_len, embed_dim))

# The code loop must walk the code hidden states, read each code tensor's own
# size, and append to the code list (the original iterated the still-empty
# result list, re-read the last query tensor and appended to the query list).
for chs in code_hidden_states:
    batch_size, code_len, embed_dim = chs.size()
    code_embed_dims.append((batch_size, code_len, embed_dim))

torch.Size([1, 32, 768])
