In [1]:
import os

In [2]:
# Restrict this process to GPU 0, with devices enumerated in PCI bus order.
# Must run before CUDA is initialised (i.e. before torch touches the GPU).
os.environ.update({
    'CUDA_VISIBLE_DEVICES': '0',
    'CUDA_DEVICE_ORDER': 'PCI_BUS_ID',
})

In [3]:
import sys
# Make the project-level helper modules (`utils`, `metrics`) importable. The path is
# relative to the current working directory — presumably this notebook's folder, one
# level below those modules (TODO confirm). It must be extended before the imports below.
sys.path.append('../')
from utils import load_data, replace_class_and_function_names, remove_docstrings
from metrics import accuracy_at_k
# NOTE(review): AutoModel / AutoTokenizer are not used by any cell shown here.
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer



In [4]:
import pandas as pd
import numpy as np
import scipy
from scipy.spatial.distance import cosine
import torch
from tqdm import tqdm

In [5]:
SEED = 42
np.random.seed(SEED)  # seeds NumPy's legacy global RNG only (not torch / `random`)

In [6]:
def print_metrics(preds, gts, ks=(1, 3, 5)):
    """Print top-k retrieval accuracy for every cut-off in ``ks``.

    Args:
        preds: one ranking of candidate indices per sample, best first
            (as produced by the evaluation cells of this notebook).
        gts: gold candidate index per sample.
        ks: cut-offs to report. A tuple default avoids a shared mutable default list;
            any iterable of ints may be passed.
    """
    for k in ks:
        print(f'accuracy@{k}: {accuracy_at_k(preds, gts, k=k)}')

In [7]:
# Use the GPU when one is visible, otherwise fall back to CPU so the notebook still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SentenceTransformer('sentence-transformers/all-MiniLM-L12-v2')
model.to(device)

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [8]:
# Put the encoder in eval mode (disables dropout). The trailing `;` suppresses the model
# repr that this expression would otherwise display, without printing a blank line.
model.eval();




In [9]:
# Retrieval setting to load: one of `cross_file_first`, `cross_file_random`, `in_file`.
settings = 'cross_file_first'
# NOTE(review): the leading positional args presumably pick split / task / language —
# confirm against utils.load_data.
load_args = ('train', 'r', 'python', settings)
data = load_data(*load_args)

Loading data: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:05<00:00,  5.43s/it]


In [11]:
# Set to an int for a quick run on a random subset; None evaluates every 'easy' sample.
N_SAMPLES = None

easy_samples = data['easy']
if N_SAMPLES is None:
    raw_samples = easy_samples
else:
    # Draw indices (without replacement, using the globally seeded NumPy RNG) rather than
    # the sample dicts themselves, so raw_samples stays a plain list of dicts.
    picked = np.random.choice(len(easy_samples), size=N_SAMPLES, replace=False)
    raw_samples = [easy_samples[i] for i in picked]

In [12]:
def evaluate_retrieval(samples, code_transform=None, context_transform=None, batch_size=16):
    """Rank every sample's candidate snippets against two queries.

    For each sample, the candidate snippets in ``sample['context']`` are embedded and
    scored by dot product against
      * the embedding of ``sample['next_line']`` (reported as "symm"), and
      * the embedding of ``sample['code']`` (reported as "asymm").
    The model ends in a ``Normalize()`` layer, so these dot products are cosine similarities.

    Args:
        samples: iterable of dicts with keys 'next_line', 'code', 'context' (list of
            snippet strings) and 'golden_snippet_index'.
        code_transform: optional ``str -> str`` applied to ``sample['code']`` before encoding.
        context_transform: optional ``str -> str`` applied to each context snippet.
        batch_size: batch size used when encoding the context snippets.

    Returns:
        ``(symm_preds, asymm_preds, gts)``: two lists of numpy arrays with candidate
        indices sorted from most to least similar, and the list of gold indices.

    NOTE(review): the "symm"/"asymm" labels are kept to match the reported numbers. A
    one-line query against multi-line snippets looks more like *asymmetric* search in
    sentence-transformers terminology, so confirm the intended naming.
    NOTE(review): the model's max_seq_length is 128. Longer inputs are presumably
    truncated to their first tokens, which for ``code`` drops the lines nearest next_line.
    """
    if code_transform is None:
        code_transform = lambda text: text
    if context_transform is None:
        context_transform = lambda text: text

    symm_preds, asymm_preds, gts = [], [], []
    # One inference-mode context for the whole loop instead of re-entering it per sample.
    with torch.inference_mode():
        for sample in tqdm(samples):
            nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
            code_embedding = model.encode(code_transform(sample['code']), convert_to_tensor=True, convert_to_numpy=False)
            context_embedding = model.encode(
                [context_transform(snippet) for snippet in sample['context']],
                batch_size=batch_size,
                convert_to_tensor=True,
                convert_to_numpy=False,
            )

            # Similarities (higher = closer), not distances.
            symm_scores = nl_embedding @ context_embedding.T
            asymm_scores = code_embedding @ context_embedding.T

            # No .detach() needed: tensors created under inference_mode carry no autograd graph.
            symm_preds.append(symm_scores.argsort(descending=True).cpu().numpy())
            asymm_preds.append(asymm_scores.argsort(descending=True).cpu().numpy())
            gts.append(sample['golden_snippet_index'])
    return symm_preds, asymm_preds, gts


symm_preds, asymm_preds, gts = evaluate_retrieval(raw_samples)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98038/98038 [1:26:46<00:00, 18.83it/s]


In [13]:
# Baseline: unmodified code query and candidate snippets.
for label, preds in (('symm', symm_preds), ('asymm', asymm_preds)):
    print(label)
    print_metrics(preds, gts)

symm
accuracy@1: 0.77697
accuracy@3: 0.92849
accuracy@5: 0.97564
assym
accuracy@1: 0.13946
accuracy@3: 0.45182
accuracy@5: 0.78131


In [14]:
# Without docstrings: strip docstrings from the code query and from every candidate snippet

In [15]:
symm_preds, asymm_preds, gts = [], [], []

# Docstrings are stripped from the code query and from every candidate snippet before
# encoding; the next_line query is encoded unchanged.
with torch.inference_mode():
    for sample in tqdm(raw_samples):
        stripped_code = remove_docstrings(sample['code'])
        stripped_context = list(map(remove_docstrings, sample['context']))

        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        code_embedding = model.encode(stripped_code, convert_to_tensor=True, convert_to_numpy=False)
        context_embedding = model.encode(stripped_context, batch_size=16, convert_to_tensor=True, convert_to_numpy=False)

        # Embeddings are unit-normalized, so these dot products are cosine similarities.
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T

        # Candidate indices ordered from most to least similar.
        symm_preds.append(torch.argsort(symm_dist, descending=True).cpu().numpy())
        asymm_preds.append(torch.argsort(asymm_dist, descending=True).cpu().numpy())
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98038/98038 [1:24:02<00:00, 19.44it/s]


In [16]:
print('without docstrings')
for label, preds in (('symm', symm_preds), ('asymm', asymm_preds)):
    print(label)
    print_metrics(preds, gts)

without docstrings
symm
accuracy@1: 0.77258
accuracy@3: 0.92677
accuracy@5: 0.9756
assym
accuracy@1: 0.13559
accuracy@3: 0.44889
accuracy@5: 0.78066


In [17]:
# Without names: replace class and function names in the code query and in every candidate snippet

In [18]:
# Eyeball the renaming transform on the first two candidate snippets of the first three samples.
for sample in raw_samples[:3]:
    for snippet in sample['context'][:2]:
        print(replace_class_and_function_names(snippet), '-' * 50, sep='\n')


def func(config=ProdConfig):
    app = Flask(__name__)
    app.config.from_object(config)
    app.config.from_envvar('DOORMAN_SETTINGS', silent=True)

    register_blueprints(app)
    register_errorhandlers(app)
    register_loggers(app)
    register_extensions(app)
    register_auth_method(app)
    register_filters(app)

    return app
--------------------------------------------------
class cls(object):
class cls(CRUDMixin, db.Model):
class cls(object):
    def func(cls, **kwargs):
    def func(self, commit=True, **kwargs):
    def func(self, commit=True):
    def func(self, commit=True):
    def func(cls, record_id):
def func(tablename, nullable=False, pk_name='id', **kwargs):
--------------------------------------------------
def func(config=ProdConfig):
    app = Flask(__name__)
    app.config.from_object(config)
    app.config.from_envvar('DOORMAN_SETTINGS', silent=True)

    register_blueprints(app)
    register_errorhandlers(app)
    register_loggers(app)
    register_extension

In [19]:
symm_preds, asymm_preds, gts = [], [], []

# Class and function names are replaced in the code query and in every candidate
# snippet; the next_line query is encoded unchanged.
with torch.inference_mode():
    for sample in tqdm(raw_samples):
        renamed_code = replace_class_and_function_names(sample['code'])
        renamed_context = list(map(replace_class_and_function_names, sample['context']))

        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        code_embedding = model.encode(renamed_code, convert_to_tensor=True, convert_to_numpy=False)
        context_embedding = model.encode(renamed_context, batch_size=16, convert_to_tensor=True, convert_to_numpy=False)

        # Cosine similarities (the model's output embeddings are unit-normalized).
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T

        # Rank candidates, most similar first.
        symm_preds.append(torch.argsort(symm_dist, descending=True).cpu().numpy())
        asymm_preds.append(torch.argsort(asymm_dist, descending=True).cpu().numpy())
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98038/98038 [1:29:30<00:00, 18.25it/s]


In [20]:
print('without names')
for label, preds in (('symm', symm_preds), ('asymm', asymm_preds)):
    print(label)
    print_metrics(preds, gts)

without names
symm
accuracy@1: 0.5793
accuracy@3: 0.84291
accuracy@5: 0.94924
assym
accuracy@1: 0.14272
accuracy@3: 0.45467
accuracy@5: 0.78229


In [21]:
# Without docstrings and without names: apply both transforms (rename first, then strip docstrings)

In [22]:
symm_preds, asymm_preds, gts = [], [], []

# Both transforms are applied to the code query and to each candidate snippet:
# names are replaced first, then docstrings are removed. next_line is left untouched.
with torch.inference_mode():
    for sample in tqdm(raw_samples):
        anonymized_code = remove_docstrings(replace_class_and_function_names(sample['code']))
        anonymized_context = [
            remove_docstrings(replace_class_and_function_names(snippet))
            for snippet in sample['context']
        ]

        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        code_embedding = model.encode(anonymized_code, convert_to_tensor=True, convert_to_numpy=False)
        context_embedding = model.encode(anonymized_context, batch_size=16, convert_to_tensor=True, convert_to_numpy=False)

        # Dot products of unit-normalized embeddings == cosine similarities.
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T

        symm_preds.append(torch.argsort(symm_dist, descending=True).cpu().numpy())
        asymm_preds.append(torch.argsort(asymm_dist, descending=True).cpu().numpy())
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98038/98038 [1:25:27<00:00, 19.12it/s]


In [23]:
print('without names and docstrings')
for label, preds in (('symm', symm_preds), ('asymm', asymm_preds)):
    print(label)
    print_metrics(preds, gts)

without names and docstrings
symm
accuracy@1: 0.5295
accuracy@3: 0.81043
accuracy@5: 0.93949
assym
accuracy@1: 0.14221
accuracy@3: 0.45543
accuracy@5: 0.78327


In [24]:
# Only names: reduce each candidate snippet to its class names, function names and docstrings

In [25]:
import re

def extract_code_elements(code, exclude_functions=('__init__', '__str__', '__len__')):
    """Extract identifier-level "keywords" from a Python source snippet using regexes.

    Args:
        code: source text (may be partial, e.g. only ``def``/``class`` header lines).
        exclude_functions: function names dropped from ``function_names``; the default
            reproduces the previously hard-coded set.

    Returns:
        dict with keys
          * ``class_names``    - names following ``class``
          * ``function_names`` - names following ``def``, minus ``exclude_functions``
          * ``class_fields``   - attribute names accessed as ``self.<name>``
          * ``function_args``  - parameter names from ``def`` signatures, without ``self``,
                                 ``*``/``**`` prefixes, type annotations or defaults
          * ``docstrings``     - every triple-quoted string literal (quotes included),
                                 in source order, duplicates kept
        Every list except ``docstrings`` is de-duplicated in order of first appearance.
        Deterministic ordering matters: ``list(set(...))`` of strings changes order between
        interpreter runs (string-hash randomization), which changed the joined keyword
        text and therefore its embedding.

    Being regex-based, it also picks up matches inside comments/strings, a ``)`` inside a
    default value ends a signature early, and a ``,`` inside a default splits it.
    """
    class_pattern = r"\bclass\s+(\w+)"
    function_pattern = r"\bdef\s+(\w+)"
    class_field_pattern = r"\bself\.(\w+)"
    function_arg_pattern = r"\bdef\s+\w+\(([^)]*)\)"
    docstring_pattern = r'""".*?"""|\'\'\'.*?\'\'\''

    def unique(items):
        # dict keys keep insertion order -> order-preserving de-duplication
        return list(dict.fromkeys(items))

    class_names = re.findall(class_pattern, code)
    function_names = re.findall(function_pattern, code)
    class_fields = re.findall(class_field_pattern, code)
    signatures = re.findall(function_arg_pattern, code)
    # DOTALL so multi-line triple-quoted strings are matched.
    docstrings = re.findall(docstring_pattern, code, re.DOTALL)

    function_args = []
    for signature in signatures:
        # Remove all whitespace, including newlines of multi-line signatures, before splitting.
        for raw_arg in re.sub(r"\s+", "", signature).split(','):
            # 'b:int=2' -> 'b', '**kwargs' -> 'kwargs'; bare '*' / '/' markers and empty
            # pieces (trailing commas) are not identifiers and get dropped.
            name = raw_arg.split('=')[0].split(':')[0].lstrip('*')
            if name.isidentifier() and name != 'self':
                function_args.append(name)

    excluded = set(exclude_functions)
    return {
        "class_names": unique(class_names),
        "function_names": [name for name in unique(function_names) if name not in excluded],
        "class_fields": unique(class_fields),
        "function_args": unique(function_args),
        "docstrings": docstrings,
    }

In [26]:
# Inspect the extracted elements of one candidate snippet. The last sample is referenced
# explicitly (it is the value the leaked loop variable `sample` held after the previous
# evaluation loop), so this cell no longer depends on which loop cell ran last.
d = extract_code_elements(raw_samples[-1]['context'][0])
d

{'class_names': ['JobSpec'],
 'function_names': ['add_workspec_list',
  'all_files_triggered_to_stage_out',
  'is_final_status',
  'set_all_input_ready',
  'get_pilot_type',
  'all_events_done',
  'get_groups_of_output_files',
  'get_files_to_delete',
  'has_attribute',
  'get_job_status_from_attributes',
  'get_output_file_attributes',
  'get_job_params',
  'update_group_status_in_files',
  'add_out_file',
  'trigger_preparation',
  'set_input_file_paths',
  'set_attributes',
  'get_status',
  'set_start_time',
  'get_input_file_attributes',
  'is_pilot_closed',
  'add_in_file',
  'not_suppress_heartbeat',
  'get_output_file_specs',
  'set_end_time',
  'add_file',
  'set_groups_to_files',
  'get_input_file_specs',
  'get_workspec_list',
  'convert_job_json',
  'trigger_stage_out',
  'set_pilot_closed',
  'trigger_propagation',
  'to_event_data',
  'get_logfile_info',
  'add_event',
  'get_groups_of_input_files',
  'set_pilot_error',
  'set_one_attribute',
  'get_job_attributes_for_pan

In [27]:
symm_preds, asymm_preds, gts = [], [], []

# Only the candidate snippets are reduced to "keywords" (class names, function names and
# docstrings from extract_code_elements); both queries are encoded from the raw sample.
# Loop-variable names are kept unchanged in case later cells inspect the leaked globals.
keyword_groups = ('class_names', 'function_names', 'docstrings')

with torch.inference_mode():
    for sample in tqdm(raw_samples):
        nl_embedding = model.encode(sample['next_line'], convert_to_tensor=True, convert_to_numpy=False)
        code_embedding = model.encode(sample['code'], convert_to_tensor=True, convert_to_numpy=False)

        new_context = []
        for c in sample['context']:
            code_meta = extract_code_elements(c)
            # Each group is space-joined and followed by a single trailing space.
            s = ''.join(' '.join(code_meta[group]) + ' ' for group in keyword_groups)
            new_context.append(s)
        context_embedding = model.encode(new_context, batch_size=16, convert_to_tensor=True, convert_to_numpy=False)

        # Cosine similarities (the model's output embeddings are unit-normalized).
        symm_dist = nl_embedding @ context_embedding.T
        asymm_dist = code_embedding @ context_embedding.T

        # Candidate indices, most similar first.
        symm_preds.append(torch.argsort(symm_dist, descending=True).cpu().numpy())
        asymm_preds.append(torch.argsort(asymm_dist, descending=True).cpu().numpy())
        gts.append(sample['golden_snippet_index'])

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 98038/98038 [1:19:45<00:00, 20.49it/s]


In [34]:
# Last candidate snippet of the last sample, i.e. what the leaked loop variable `c` held
# after the keyword-extraction loop; referenced explicitly to avoid hidden kernel state.
print(raw_samples[-1]['context'][-1])

def synchronize(func):
    def wrapper(*args, **kwargs):
    def __init__(self):
    def get_elapsed_time(self):
    def get_elapsed_time_in_sec(self, precise=False):
    def reset(self):
    def __init__(self):
    def __getitem__(self, item):
    def __setitem__(self, item, value):
    def __contains__(self, item):
    def acquire(self):
    def release(self):
    def iteritems(self):
    def __init__(cls, *args,**kwargs):
    def __call__(cls, *args, **kwargs):
    def __init__(cls, *args,**kwargs):
    def __call__(cls, *args, **kwargs):
def enable_memory_profiling():
def setup_logger(name=None):
def make_logger(tmp_log, token=None, method_name=None, hook=None):
def dump_error_message(tmp_log, err_str=None, no_message=False):
def sleep(interval, stop_event, randomize=True):
def make_pool_file_catalog(jobspec_list):
def calc_adler32(file_name):
def get_output_file_report(jobspec):
def create_shards(input_list, size):
def update_job_attributes_with_workers(map_type, jobspec_list, wor

In [28]:
print('context - only keywords and docstrings')
for label, preds in (('symm', symm_preds), ('asymm', asymm_preds)):
    print(label)
    print_metrics(preds, gts)

context - only keywords and dosctrings
symm
accuracy@1: 0.72368
accuracy@3: 0.89186
accuracy@5: 0.96084
assym
accuracy@1: 0.14054
accuracy@3: 0.45268
accuracy@5: 0.78271


In [30]:
# Drop the notebook's reference to the encoder so its (GPU) memory can be reclaimed
# once nothing else references it.
del model

In [31]:
# Release the GPU tensors that the evaluation loops left behind as globals, then return
# PyTorch's cached-but-unused CUDA blocks to the driver. Names that are absent (e.g. a
# loop cell was skipped) are ignored instead of raising NameError.
for _leaked in ('symm_dist', 'asymm_dist', 'nl_embedding', 'code_embedding', 'context_embedding'):
    globals().pop(_leaked, None)
if torch.cuda.is_available():
    torch.cuda.empty_cache()