In [1]:
from data import file_to_lists
from model import get_code_embeddings, get_cos_similarity, retrieve_topN

from tqdm import tqdm

repo = 'keon/algorithms'
repo_info = {}
function_list = file_to_lists(f"content/output/{repo}/directory_info.json")
repo_info["funcs"] = function_list


print("Generating code embeddings for dataset ... ")
code_embeddings = []
for func in tqdm(repo_info["funcs"]):
    code_embeddings.append(get_code_embeddings(func))
    
print("Dataset code embeddings generated!")

  from .autonotebook import tqdm as notebook_tqdm


Generating code embeddings for dataset ... 


100%|███████████████████████████████████████████████████████████████████████████████| 1171/1171 [01:02<00:00, 18.63it/s]

Dataset code embeddings generated!





In [7]:
class CodeSearchEngine:
    def __init__(self, repo_info, code_embeddings):
        self.repo_info = repo_info
        self.code_embeddings = code_embeddings

    def search(self, query, n):
        code_embeddings = self.code_embeddings
        input_embedding = get_code_embeddings(query)
        similarities = get_cos_similarity(input_embedding, code_embeddings)
        similar_func_names = retrieve_topN(self.repo_info, similarities, n)
        
        print('The most similiar {n} code snippets:')
        for func_name in similar_func_names:
            print(f'\n------------------------------------------------------------------\n {func_name}')

# Instantiate the CodeSearchEngine
se = CodeSearchEngine(repo_info, code_embeddings)


In [10]:
from IPython.core.magic import (register_line_magic, register_cell_magic)

@register_line_magic
def search(line):
    query = line.strip()
    n = int(input("How many similar code snippets you want to retrieve: \n"))
    se.search(query, n)

In [11]:
%search
"""
def test_topsort(self):
    res_recursive = top_sort_recursive(self.depGraph)
    self.assertTrue(res_recursive.index('g') < res_recursive.index('e'))
    
    res_iterative = top_sort(self.depGraph)
    self.assertTrue(res_iterative.index('g') < res_iterative.index('e'))

"""

How many similar code snippets you want to retrieve: 
 5


The most similiar {n} code snippets:

------------------------------------------------------------------
 def main():
    (m, n) = map(int, input('Enter two positive integers: ').split())
    count_paths(m, n)

------------------------------------------------------------------
 def as_list(self):
    """ Return interval as list. """
    return list(self)

------------------------------------------------------------------
 def __repr__(self):
    return 'Interval ({}, {})'.format(self.start, self.end)

------------------------------------------------------------------
 def decrypt(data, d, n):
    return pow(int(data), int(d), int(n))

------------------------------------------------------------------
 def multiply(multiplicand: list, multiplier: list) -> list:
    """
    :type A: List[List[int]]
    :type B: List[List[int]]
    :rtype: List[List[int]]
    """
    (multiplicand_row, multiplicand_col) = (len(multiplicand), len(multiplicand[0]))
    (multiplier_row, multiplier_col) = (le

"\ndef test_topsort(self):\n    res_recursive = top_sort_recursive(self.depGraph)\n    self.assertTrue(res_recursive.index('g') < res_recursive.index('e'))\n    \n    res_iterative = top_sort(self.depGraph)\n    self.assertTrue(res_iterative.index('g') < res_iterative.index('e'))\n\n"