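This notebook demonstrates the full `SemanticCodeSearch` pipeline for searching code with a natural-language query: each function extracted from a repository is summarized with the CodeT5 summarization model (`Salesforce/codet5-base-multi-sum`), and the summaries are ranked against the query using sentence embeddings (`all-mpnet-base-v2`) and cosine similarity.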

### Download test repositories and run `inspect4py` on them

In [1]:
# Example repository picked from https://github.com, written as `owner/name`.
# The value is interpolated into the clone URL and into the inspect4py
# input/output paths used by the following cells.
repo = 'keon/algorithms'

In [2]:
!inspect4py --version

inspect4py, version 0.0.6


In [3]:
# Work inside ./content so the clone and the inspect4py output live side by side.
!mkdir -p content/output
%cd content/

# Clone the example repository and run inspect4py over it. Later cells read
# output/<repo>/directory_info.json and use its "source_code" entries.
# NOTE(review): -sc / -rm presumably enable source-code and README extraction — confirm with `inspect4py --help`.
# NOTE(review): re-running this cell after a successful clone makes `git clone` fail, because the target directory is no longer empty.
!mkdir -p {repo} && git clone {f"https://github.com/{repo}.git"} {repo}
!inspect4py -i {repo} -o output/{repo} -sc -rm
# Go back up so later relative paths (content/output/...) resolve from the notebook directory.
%cd ..

/cs/home/cd271/Documents/Project/Examples/RepoAnalysis/SemanticCodeSearch/Text2code/content
Cloning into 'keon/algorithms'...
remote: Enumerating objects: 5162, done.[K
remote: Counting objects: 100% (26/26), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 5162 (delta 11), reused 16 (delta 3), pack-reused 5136[K
Receiving objects: 100% (5162/5162), 1.42 MiB | 10.99 MiB/s, done.
Resolving deltas: 100% (3231/3231), done.
Updating files: 100% (477/477), done.
Creating jsDir:output/keon/algorithms/algorithms/json_files
Creating jsDir:output/keon/algorithms/algorithms/algorithms/json_files
Creating jsDir:output/keon/algorithms/algorithms/algorithms/streaming/json_files
Creating jsDir:output/keon/algorithms/algorithms/algorithms/map/json_files
Error when processing separate_chaining_hashtable.py:  <class 'AttributeError'>
Error when processing hashtable.py:  <class 'AttributeError'>
Creating jsDir:output/keon/algorithms/algorithms/algorithms/stack/json_files
Erro

### Extract docstrings and functions from repositories.

In [4]:
import json

def funcs_to_lists(funcs, func_codes, docs):
    """Collect source code and docstring text from a mapping of functions.

    Both output lists are extended in place; nothing is returned.

    Args:
        funcs: Mapping of function name -> info dict. The info dict may carry
            a "source_code" string and a "doc" dict.
        func_codes: List that receives every non-None "source_code" value.
        docs: List that receives one "<name> <text>" entry per function whose
            "doc" dict has a non-None "full", "long_description" or
            "short_description" value (first match in that order wins).
    """
    doc_keys = ("full", "long_description", "short_description")
    for name, info in funcs.items():
        source = info.get("source_code")
        if source is not None:
            func_codes.append(source)

        doc = info.get("doc")
        if doc is None:
            continue

        # Pick the most complete description that is actually present.
        candidates = (doc.get(key) for key in doc_keys)
        text = next((value for value in candidates if value is not None), None)
        if text is not None:
            docs.append(f"{name} {text}")

def file_to_lists(filename):
    """Load an inspect4py directory-info JSON file and gather its functions.

    The top-level "readme_files" entry is dropped; every remaining top-level
    value is iterated as a list of per-file dicts. Top-level functions and
    class methods of each file are passed through ``funcs_to_lists``.

    Args:
        filename: Path to the JSON file to read.

    Returns:
        Tuple ``(func_codes, docs)``: the collected source-code strings and
        the collected "<name> <docstring>" strings.
    """
    with open(filename, "r") as handle:
        repo_data = json.load(handle)
    repo_data.pop("readme_files", None)

    func_codes, docs = [], []
    for file_entries in repo_data.values():
        for entry in file_entries:
            functions = entry.get("functions")
            if functions is not None:
                funcs_to_lists(functions, func_codes, docs)

            classes = entry.get("classes")
            if classes is None:
                continue
            for class_info in classes.values():
                methods = class_info.get("methods")
                if methods is not None:
                    funcs_to_lists(methods, func_codes, docs)
    return func_codes, docs

In [5]:
repo_info = {}  # NOTE(review): not used by any cell visible here — confirm before removing
# Only the function source strings are needed below; the docstring list is discarded.
function_list, _ = file_to_lists(f"content/output/{repo}/directory_info.json")

### Load the CodeT5 code-summarization model and tokenizer

In [6]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Hugging Face checkpoint used to generate a natural-language summary per function.
# NOTE(review): the checkpoint name suggests CodeT5-base fine-tuned for multilingual summarization — confirm on the model card.
model_name = "Salesforce/codet5-base-multi-sum"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# No explicit device placement here; get_code_sum moves its inputs to model.device.
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def get_code_sum(funcs):
    """Generate a natural-language summary for each code snippet.

    Relies on the notebook-level ``tokenizer`` and ``model`` objects.

    Args:
        funcs: List of source-code strings to summarize.

    Returns:
        List of summary strings, one per input snippet, in input order.
        An empty input yields an empty list without touching the model.
    """
    if len(funcs) == 0:
        # Nothing to tokenize; avoid handing an empty batch to the tokenizer.
        return []

    # Calling the tokenizer directly is the supported replacement for the
    # deprecated ``batch_encode_plus``; the arguments keep the same meaning.
    inputs = tokenizer(
        funcs,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        # Inputs must live on the same device as the model weights.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}
        outputs = model.generate(**inputs)

    summaries = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return summaries

In [8]:
from tqdm import tqdm

# Summarize the extracted functions one at a time, keeping the results in
# the same order as function_list.
code_sum = []
for func_source in tqdm(function_list, desc="Generating code summaries"):
    code_sum.extend(get_code_sum([func_source]))

Generating code summaries: 100%|████████████████████████████████████████████████████| 1171/1171 [06:05<00:00,  3.21it/s]


In [9]:
# Example for input query
query = ['check if the number in bitsum is vilad']

In [11]:
from sentence_transformers import SentenceTransformer

# Use the GPU for sentence embeddings when one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Sentence-embedding model used by get_embedding for both the summaries and the query.
sum_model = SentenceTransformer("all-mpnet-base-v2", device=device)

def get_embedding(code_sum, query_sum):
    """Embed the code summaries and the query text with ``sum_model``.

    Args:
        code_sum: List of code-summary strings.
        query_sum: Query text(s) to embed, in whatever form
            ``sum_model.encode`` accepts.

    Returns:
        Tuple ``(code_embeddings, query_embeddings)``, each the tensor
        returned by ``sum_model.encode(..., convert_to_tensor=True)``.
    """
    # Encode the argument rather than the notebook-level ``query`` global, so
    # the function honours whatever query the caller passes in.
    code_embeddings = sum_model.encode(code_sum, convert_to_tensor=True)
    query_embeddings = sum_model.encode(query_sum, convert_to_tensor=True)
    return code_embeddings, query_embeddings
    
code_sum_embeddings, query_sum_embedding = get_embedding(code_sum, query)

In [13]:
from torch.nn import CosineSimilarity

# Cosine similarity between the query embedding and each code-summary
# embedding, converted to a plain Python list for ranking.
# NOTE(review): assumes a single query embedding of shape (1, dim) broadcast against (n_functions, dim) — confirm if `query` ever holds more than one string.
cosine_sim = CosineSimilarity(dim=1)
similarities = cosine_sim(query_sum_embedding, code_sum_embeddings).tolist()

In [14]:
def find_top_n_index(lst, n):
    """Return the indices of the ``n`` largest values in ``lst``.

    Indices are ordered from the largest value to the smallest; equal values
    keep their original left-to-right order. The input list is not modified.

    Args:
        lst: List of comparable scores.
        n: Number of indices wanted. Values larger than ``len(lst)`` are
            clamped, and non-positive values give an empty result.

    Returns:
        List of at most ``n`` distinct indices into ``lst``.
    """
    count = max(0, min(n, len(lst)))
    # sorted() is stable even with reverse=True, so ties resolve to the
    # lowest index first.
    ranked = sorted(range(len(lst)), key=lambda i: lst[i], reverse=True)
    return ranked[:count]

In [15]:
# Rank on a copy so the original scores in `similarities` stay available for printing.
sim = similarities.copy()
# Indices of the five best matches; code_sum was built one entry per function,
# so the same index addresses both `similarities` and `function_list`.
index = find_top_n_index(sim,5)
print('Similiar code snippet:\n')
for i in index:
    print(f'Similarity: {similarities[i]}, \n{function_list[i]} \n--------------------------------------------------------------')

Similiar code snippet:

Similarity: 0.8362308740615845, 
def _check_every_number_in_bitsum(bitsum, sum_signs):
    for val in bitsum:
        if val != 0 and val != sum_signs:
            return False
    return True 
--------------------------------------------------------------
Similarity: 0.7320685982704163, 
def test_get_bit(self):
    self.assertEqual(1, get_bit(22, 2))
    self.assertEqual(0, get_bit(22, 3)) 
--------------------------------------------------------------
Similarity: 0.653519868850708, 
def test_has_alternative_bit_fast(self):
    self.assertTrue(has_alternative_bit_fast(5))
    self.assertFalse(has_alternative_bit_fast(7))
    self.assertFalse(has_alternative_bit_fast(11))
    self.assertTrue(has_alternative_bit_fast(10)) 
--------------------------------------------------------------
Similarity: 0.649571418762207, 
def test_has_alternative_bit(self):
    self.assertTrue(has_alternative_bit(5))
    self.assertFalse(has_alternative_bit(7))
    self.assertFalse(has