In [1]:
import os

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

In [3]:
import sys
sys.path.append('../')
from utils import load_data
from metrics import accuracy_at_k
from transformers import AutoModel, AutoTokenizer



In [4]:
import pandas as pd
import numpy as np
import scipy
from scipy.spatial.distance import cosine
import torch
from tqdm import tqdm

In [5]:
np.random.seed(42)

In [6]:
checkpoint = "Salesforce/codet5p-110m-embedding"
device = "cuda"

tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModel.from_pretrained(checkpoint, trust_remote_code=True).to(device)

In [7]:
model.eval()
print()




In [8]:
# settings :`cross_file_first`, `cross_file_random`, or `in_file`
settings = 'cross_file_first'
data = load_data('train', 'r', 'python', settings)

Loading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:04<00:00,  4.93s/it]


In [9]:
n_samples = 1_000

In [12]:
raw_samples = np.random.choice(data['easy'], n_samples)
# raw_samples = data['easy']

In [20]:
def print_metrics(preds, gts, ks=[1, 3, 5]):
    for k in ks:
        print(f'accuracy@{k}: {accuracy_at_k(preds, gts, k=k)}')

In [13]:
def get_embedding(src, tokenizer, model, device=device):
    tokens = tokenizer.encode(src, truncation=True, max_length=512, return_tensors="pt")
    embedd = model(tokens.to(device))[0].detach().cpu()
    return embedd

In [14]:
# samples = []
# for i, raw_sample in tqdm(enumerate(raw_samples), total=len(raw_samples)):
#     sample = {}
#     sample['code'] = tokenizer.encode(raw_sample['code'], return_tensors="pt")
#     sample['context'] = [tokenizer.encode(c, return_tensors="pt") for c in raw_sample['context']]
#     sample['target'] = raw_sample['golden_snippet_index']
#     sample['next_line'] = tokenizer.encode(raw_sample['next_line'], return_tensors="pt")
#     samples.append(sample)

In [15]:
# preds, gts = [], []

# for sample in tqdm(samples):
#     with torch.inference_mode():
#         nl_embedding = model(sample['next_line'].to(device))[0].detach().cpu()
#         context_embedding = torch.stack([model(ct[:10].to(device))[0].detach().cpu() for ct in sample['context']])
#         dist = nl_embedding @ context_embedding.T
#         preds.append(dist.argmax().item())
#         gts.append(sample['target'])

In [16]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = get_embedding(sample['next_line'], tokenizer, model, device)
        
        code_embedding = get_embedding(sample['code'], tokenizer, model, device)
        
        context_embedding = torch.stack([get_embedding(c, tokenizer, model, device) for c in sample['context']])
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:24<00:00,  6.90it/s]


In [49]:
for k in [1, 3, 5]:
    v = np.mean([accuracy_at_k([np.random.permutation(p) for p in asymm_preds], gts, k=k) for _ in range(100)])
    print(f'accuracy@{k}: {v}')

accuracy@1: 0.15948
accuracy@3: 0.47827
accuracy@5: 0.8007700000000001


In [50]:
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

symm
accuracy@1: 0.754
accuracy@3: 0.909
accuracy@5: 0.969
assym
accuracy@1: 0.142
accuracy@3: 0.43
accuracy@5: 0.797


In [23]:
for sample in raw_samples[:10]:
    print(sample['next_line'])
    print('-' * 50)

    with ctx.cd(".php-build"):
--------------------------------------------------
    plugin = ApiritifPytestPlugin(config)
--------------------------------------------------
            _reparse_binary_expression([Value(3, BQScalarType.INTEGER), '+'])
--------------------------------------------------
        [reservation.ListReservationsRequest],
--------------------------------------------------
        grid, grid_name, x, y, z, t, grid_t_idx, grid_x_idx, grid_z_idx = read_file_3d(f, var, x_name='x', y_name='y', z_name='z')
--------------------------------------------------
    url(r'^home/$', cache_page(cache_time)(HomeListAPIView.as_view()), name='home'),
--------------------------------------------------
        self.add_ramp(WordRamp, 'sentence')
--------------------------------------------------
                    self._DataFilePaths.append(ROOT_DIR + '/' + fits_dir)
--------------------------------------------------
    day_req = ReqRecord.objects.filter(uri='/lab/get_proxy/'

In [66]:
import re
def remove_docstrings(code):
    docstring_pattern = r'(""".*?"""|\'\'\'.*?\'\'\')'
    cleaned_code = re.sub(docstring_pattern, '', code, flags=re.DOTALL)
    return cleaned_code
    
def replace_class_and_function_names(code):
    """
    Replace all class and function names in the given code snippet with 'cls' and 'func' respectively.
    """
    # Regular expression to find class and function definitions
    class_pattern = r"\bclass\s+(\w+)"
    function_pattern = r"\bdef\s+(\w+)"

    # Replace class names with 'cls'
    replaced_code = re.sub(class_pattern, "class cls", code)

    # Replace function names with 'func', excluding special methods like __init__
    replaced_code = re.sub(function_pattern, lambda m: "def func" if not m.group(1).startswith('__') else m.group(0), replaced_code)

    return replaced_code

In [64]:
print(raw_samples[1]['context'][0])

class http(object):
    log = log.getChild('http')

    @staticmethod
    def target(*args, **kwargs):
        return HTTPTarget(*args, **kwargs)

    @staticmethod
    def request(method, address, session=None,
                params=None, headers=None, cookies=None, data=None, json=None, files=None,
                encrypted_cert=None, allow_redirects=True, timeout=30):
        """

        :param method: str
        :param address: str
        :return: response
        :rtype: HTTPResponse
        """
        http.log.info("Request: %s %s", method, address)
        msg = "Request: params=%r, headers=%r, cookies=%r, data=%r, json=%r, files=%r, allow_redirects=%r, timeout=%r"
        http.log.debug(msg, params, headers, cookies, data, json, files, allow_redirects, timeout)

        if headers is None:
            headers = {}
        if "User-Agent" not in headers:
            headers["User-Agent"] = "Apiritif"

        if session is None:
            session = requests.Session()

   

In [65]:
print(remove_docstrings(raw_samples[1]['context'][0]))

class http(object):
    log = log.getChild('http')

    @staticmethod
    def target(*args, **kwargs):
        return HTTPTarget(*args, **kwargs)

    @staticmethod
    def request(method, address, session=None,
                params=None, headers=None, cookies=None, data=None, json=None, files=None,
                encrypted_cert=None, allow_redirects=True, timeout=30):
        
        http.log.info("Request: %s %s", method, address)
        msg = "Request: params=%r, headers=%r, cookies=%r, data=%r, json=%r, files=%r, allow_redirects=%r, timeout=%r"
        http.log.debug(msg, params, headers, cookies, data, json, files, allow_redirects, timeout)

        if headers is None:
            headers = {}
        if "User-Agent" not in headers:
            headers["User-Agent"] = "Apiritif"

        if session is None:
            session = requests.Session()

        if encrypted_cert is not None:
            certificate_file_path, passphrase = encrypted_cert
            adapter = SSLAd

In [67]:
print(replace_class_and_function_names(raw_samples[1]['context'][0]))

class cls(object):
    log = log.getChild('http')

    @staticmethod
    def func(*args, **kwargs):
        return HTTPTarget(*args, **kwargs)

    @staticmethod
    def func(method, address, session=None,
                params=None, headers=None, cookies=None, data=None, json=None, files=None,
                encrypted_cert=None, allow_redirects=True, timeout=30):
        """

        :param method: str
        :param address: str
        :return: response
        :rtype: HTTPResponse
        """
        http.log.info("Request: %s %s", method, address)
        msg = "Request: params=%r, headers=%r, cookies=%r, data=%r, json=%r, files=%r, allow_redirects=%r, timeout=%r"
        http.log.debug(msg, params, headers, cookies, data, json, files, allow_redirects, timeout)

        if headers is None:
            headers = {}
        if "User-Agent" not in headers:
            headers["User-Agent"] = "Apiritif"

        if session is None:
            session = requests.Session()

        i

In [68]:
# Wihout docstrings

In [71]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = get_embedding(sample['next_line'], tokenizer, model, device)
        
        code_embedding = get_embedding(remove_docstrings(sample['code']), tokenizer, model, device)
        
        context_embedding = torch.stack([get_embedding(remove_docstrings(c), tokenizer, model, device) for c in sample['context']])
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:17<00:00,  7.25it/s]


In [72]:
print('without docstring')
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

without docstring
symm
accuracy@1: 0.744
accuracy@3: 0.91
accuracy@5: 0.971
assym
accuracy@1: 0.132
accuracy@3: 0.431
accuracy@5: 0.793


In [73]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = get_embedding(sample['next_line'], tokenizer, model, device)
        
        code_embedding = get_embedding(replace_class_and_function_names(sample['code']), tokenizer, model, device)
        
        context_embedding = torch.stack([get_embedding(replace_class_and_function_names(c), tokenizer, model, device) for c in sample['context']])
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [02:27<00:00,  6.77it/s]


In [74]:
print('without names')
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

without names
symm
accuracy@1: 0.598
accuracy@3: 0.839
accuracy@5: 0.951
assym
accuracy@1: 0.136
accuracy@3: 0.43
accuracy@5: 0.781


In [34]:
# Without docstring, names

In [35]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = get_embedding(sample['next_line'], tokenizer, model, device)
        
        code_embedding = get_embedding(replace_class_and_function_names(sample['code']), tokenizer, model, device)
        
        context_embedding = torch.stack([get_embedding(replace_class_and_function_names(c), tokenizer, model, device) for c in sample['context']])
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [10:05<00:00,  8.26it/s]


In [36]:
print('without docstring, names')
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

accuracy@1: 0.5416
accuracy@3: 0.828
accuracy@5: 0.9478
