In [2]:
import os

In [3]:
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'

In [1]:
import sys
sys.path.append('../')
from utils import load_data, replace_class_and_function_names, remove_docstrings
from metrics import accuracy_at_k
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer



In [4]:
import pandas as pd
import numpy as np
import scipy
from scipy.spatial.distance import cosine
import torch
from tqdm import tqdm

In [5]:
np.random.seed(42)

In [6]:
def print_metrics(preds, gts, ks=[1, 3, 5]):
    for k in ks:
        print(f'accuracy@{k}: {accuracy_at_k(preds, gts, k=k)}')

In [7]:
device = "cuda"
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
model.to(device)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [8]:
model.eval()
print()




In [9]:
# settings :`cross_file_first`, `cross_file_random`, or `in_file`
settings = 'cross_file_first'
data = load_data('train', 'r', 'python', settings)

Loading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.31s/it]


In [10]:
n_samples = 1_000

In [11]:
raw_samples = np.random.choice(data['easy'], n_samples)
# raw_samples = data['easy']

In [12]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        
        code_embedding = model.encode(sample['code'], convert_to_tensor=True, convert_to_numpy=False)
        
        context_embedding = model.encode(sample['context'], batch_size=16, convert_to_tensor=True, convert_to_numpy=False)
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).detach().cpu().numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).detach().cpu().numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:51<00:00, 19.35it/s]


In [15]:
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

symm
accuracy@1: 0.78
accuracy@3: 0.922
accuracy@5: 0.976
assym
accuracy@1: 0.188
accuracy@3: 0.511
accuracy@5: 0.828


In [16]:
# Wihout docstrings

In [18]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        
        code_embedding = model.encode(remove_docstrings(sample['code']), convert_to_tensor=True, convert_to_numpy=False)
        
        context_embedding = model.encode([remove_docstrings(c) for c in  sample['context']], batch_size=16, convert_to_tensor=True, convert_to_numpy=False)
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).detach().cpu().numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).detach().cpu().numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:50<00:00, 19.97it/s]


In [19]:
print('without docstrings')
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

without docstrings
symm
accuracy@1: 0.773
accuracy@3: 0.922
accuracy@5: 0.98
assym
accuracy@1: 0.17
accuracy@3: 0.489
accuracy@5: 0.83


In [30]:
# Without names

In [33]:
for sample in raw_samples[:3]:
    for cntx in sample['context'][:2]:
        print(replace_class_and_function_names(cntx))
        print('-' * 50)


def func():
def func(x, partitioning_dims: int):
def func(
    x: T,
    activation_partitioning_dims: Optional[int],
    logical_axis_names: Tuple[str, ...],
) -> T:
T = TypeVar('T')
--------------------------------------------------
class cls(nn.Module, param_remapping.ParameterRemappable):
class cls:
  def func(self):
  def __call__(self, inputs: Array, *args, **kwargs) -> Array:
  def func(self, start_idx: int, end_idx: Optional[int],
                            inputs: Array, *args, **kwargs) -> Array:
  def __call__(self, inputs: Array, *args, **kwargs) -> Array:
  def func(self, start_idx: int, end_idx: Optional[int],
                            inputs: Array, *args, **kwargs) -> Array:
--------------------------------------------------
class cls(ClientMixin):
    """
    The representation of a job on a remote hoplite server
    """

    def __init__(self, address, port=5000, name="", uuid="", api_key="", config={}):
        """
        :param address: IP address or hostname of

In [20]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        
        code_embedding = model.encode(replace_class_and_function_names(sample['code']), convert_to_tensor=True, convert_to_numpy=False)
        
        context_embedding = model.encode([replace_class_and_function_names(c) for c in  sample['context']], batch_size=16, convert_to_tensor=True, convert_to_numpy=False)
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).detach().cpu().numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).detach().cpu().numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:54<00:00, 18.45it/s]


In [21]:
print('without names')
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

without names
symm
accuracy@1: 0.561
accuracy@3: 0.842
accuracy@5: 0.964
assym
accuracy@1: 0.167
accuracy@3: 0.487
accuracy@5: 0.818


In [22]:
# Without docstring, renaming

In [23]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        
        code_embedding = model.encode(remove_docstrings(replace_class_and_function_names(sample['code'])), convert_to_tensor=True, convert_to_numpy=False)
        
        context_embedding = model.encode([remove_docstrings(replace_class_and_function_names(c)) for c in  sample['context']], batch_size=16, convert_to_tensor=True, convert_to_numpy=False)
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).detach().cpu().numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).detach().cpu().numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:50<00:00, 19.65it/s]


In [24]:
print('without names and docstrings')
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

without names and docstrings
symm
accuracy@1: 0.512
accuracy@3: 0.824
accuracy@5: 0.957
assym
accuracy@1: 0.157
accuracy@3: 0.477
accuracy@5: 0.82


In [25]:
# Only names

In [58]:
import re

def extract_code_elements(code):
    """
    Extract class names, function names, class fields, and function arguments from the given code snippet.
    """
    # Regular expressions for class and function definitions, class fields, and function arguments
    class_pattern = r"\bclass\s+(\w+)"
    function_pattern = r"\bdef\s+(\w+)"
    class_field_pattern = r"\bself\.(\w+)"
    function_arg_pattern = r"\bdef\s+\w+\(([^)]*)\)"
    docstring_pattern = r'""".*?"""|\'\'\'.*?\'\'\''

    # Extract class and function names
    class_names = re.findall(class_pattern, code)
    function_names = re.findall(function_pattern, code)

    # Extract class fields and function arguments
    class_fields = re.findall(class_field_pattern, code)
    function_args = re.findall(function_arg_pattern, code)
    
    # Extract docstrings 
    docstrings = re.findall(docstring_pattern, code, re.DOTALL)

    # Process function arguments to split them into individual arguments
    processed_function_args = []
    for args in function_args:
        args = args.replace(' ', '').split(',')
        # Remove 'self' from arguments
        args = [arg for arg in args if arg != 'self' and arg]
        args = [arg.split('=')[0].strip() for arg in args]
        processed_function_args.extend(args)

    # Create a dictionary with unique names
    unique_names = {
        "class_names": list(set(class_names)),
        "function_names": list(set(function_names) - set(['__init__', '__str__', '__len__'])),
        "class_fields": list(set(class_fields)),
        "function_args": list(set(processed_function_args)),
        "docstrings": docstrings,
    }

    return unique_names

In [36]:
d = extract_code_elements(sample['context'][0])
d

{'class_names': ['BufferedStream'],
 'function_names': ['_bytes_remaining',
  'stream_end_position',
  'read',
  'stream_exhausted'],
 'class_fields': [],
 'function_args': ['start', 'stream', 'size']}

In [60]:
symm_preds, asymm_preds, gts = [], [], []

for sample in tqdm(raw_samples):
    with torch.inference_mode():
        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)

        # raw code
        code_embedding = model.encode(sample['code'], convert_to_tensor=True, convert_to_numpy=False)

        # only keywords
        new_context = []
        for c in sample['context']:
            code_meta = extract_code_elements(c)
            s = ""
            s += " ".join(code_meta['class_names']) + " "
            s += " ".join(code_meta['function_names']) + " "
            s += " ".join(code_meta['docstrings']) + " "
            
            new_context.append(s)
        context_embedding = model.encode(new_context, batch_size=16, convert_to_tensor=True, convert_to_numpy=False)
        
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T
        
        symm_preds.append(symm_dist.argsort(descending=True).detach().cpu().numpy())
        asymm_preds.append(asymm_dist.argsort(descending=True).detach().cpu().numpy())
        
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:48<00:00, 20.44it/s]


In [61]:
print('context - only keywords and dosctrings')
print('symm')
print_metrics(symm_preds, gts)
print('assym')
print_metrics(asymm_preds, gts)

context - only keywords and dosctrings
symm
accuracy@1: 0.723
accuracy@3: 0.901
accuracy@5: 0.97
assym
accuracy@1: 0.169
accuracy@3: 0.51
accuracy@5: 0.835
