In [None]:
# MODEL
import torch
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer 

# Dataset that the encoding cell filters by language and embeds.
codesearchnet = load_dataset('code_search_net')

# Checkpoint identifiers: one encoder for the natural-language query side,
# one for the source-code side.
model_name_query, model_name_code = 'bert-base-uncased', "huggingface/CodeBERTa-small-v1"

# Tokenizers are created first (query, then code), followed by the models in
# the same order. `add_prefix_space=True` is passed to both tokenizers because
# the inputs are fed in as pre-split word lists later on.
tokenizer_query, tokenizer_code = (
    AutoTokenizer.from_pretrained(repo_id, add_prefix_space=True)
    for repo_id in (model_name_query, model_name_code)
)

model_query, model_code = (
    AutoModel.from_pretrained(repo_id)
    for repo_id in (model_name_query, model_name_code)
)

In [None]:
# SAVE ENCODINGS TO DISC
import os
import torch
from tqdm.auto import tqdm

#os.mkdir('./data')
#os.mkdir('./data/train')
#os.mkdir('./data/train/query')
#os.mkdir('./data/train/code')
#os.mkdir('./data/validation')
#os.mkdir('./data/validation/query')
#os.mkdir('./data/validation/code')
#os.mkdir('./data/test')
#os.mkdir('./data/test/query')
#os.mkdir('./data/test/code')

# Samples whose tokenised query or code exceeds this many tokens are skipped
# by the encoding loop.
max_length = 512
# Width of one pooled embedding; the output file holds two such vectors
# (query, code) per sample.
hidden_size = 768

# Pick the device once, falling back to the CPU when no GPU is present, so
# that `device` always names the place the models actually live on.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model_query = model_query.to(device)
model_code = model_code.to(device)

# Encode every Python (docstring, code) pair of each split and store the two
# mean-pooled embeddings in a file-backed tensor of shape
# (num_samples, 2, hidden_size): index 0 is the query vector, index 1 the code
# vector. Gradients are not needed for pure inference.
with torch.no_grad():
    for split in ['train', 'validation', 'test']:
        print("Filtering python in {}".format(split))
        data = codesearchnet[split].filter(lambda x: x['language'] == 'python')
        print("Filtering DONE")

        # The dataset knows its own row count; no need to pull a whole column
        # into memory just to measure it.
        num_samples = len(data)

        filename = './data/{}/average_pooling_outputs_v2'.format(split)
        # The target folder must exist before the backing file can be opened.
        os.makedirs(os.path.dirname(filename), exist_ok=True)

        # File-backed storage (shared=True): writes to `samples` go to `filename`.
        samples = torch.FloatTensor(torch.FloatStorage.from_file(filename, shared=True, size=num_samples * 2 * hidden_size)).reshape(num_samples, 2, hidden_size)

        # Send inputs to wherever each model currently lives instead of
        # relying on a separately maintained device name.
        query_device = model_query.device
        code_device = model_code.device

        num_skipped = 0
        progress_bar = tqdm(total=num_samples, desc=split)

        for i, (query, code) in enumerate(zip(data['func_documentation_tokens'], data['func_code_tokens'])):
            progress_bar.update(1)

            query_tok = tokenizer_query(query, is_split_into_words = True, padding='max_length', return_tensors='pt')
            code_tok = tokenizer_code(code, is_split_into_words = True, padding='max_length', return_tensors='pt')

            # No truncation is requested, so over-long inputs come back longer
            # than max_length; those samples are not encoded.
            if len(query_tok['input_ids'][0]) > max_length or len(code_tok['input_ids'][0]) > max_length:
                # Zero the row explicitly so a re-run over an existing file
                # cannot leave a stale vector from an earlier run behind.
                samples[i].zero_()
                num_skipped += 1
                continue

            # Number of attended positions minus the first and last token,
            # which are excluded from the average.
            # NOTE(review): an input with no tokens would give length 0 and a
            # NaN mean below - confirm such samples cannot occur.
            query_length = int(query_tok['attention_mask'].sum()) - 2
            code_length = int(code_tok['attention_mask'].sum()) - 2

            query_out = model_query(**query_tok.to(query_device))
            code_out = model_code(**code_tok.to(code_device))

            # Average positions 1..length, i.e. everything between the first
            # and last attended token (assumes padding sits on the right -
            # TODO confirm for both tokenizers).
            query_pool = query_out.last_hidden_state[:, 1:query_length + 1, :].mean(dim=1)
            code_pool = code_out.last_hidden_state[:, 1:code_length + 1, :].mean(dim=1)

            # Two (1, hidden_size) rows stacked into the (2, hidden_size) slot.
            samples[i] = torch.cat((query_pool, code_pool), 0)

        progress_bar.close()
        print("Skipped {} of {} samples in {} (longer than {} tokens)".format(num_skipped, num_samples, split, max_length))
        