# Create Vector Stores for each repository from Modified CodeSearchNet dataset

In [1]:
from pathlib import Path
from tqdm import tqdm
from transformers import logging
import warnings
import shutil
import pandas as pd
from langchain.text_splitter import Language
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import GitLoader
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import logging

logging.set_verbosity_error()
warnings.filterwarnings("ignore")

In [2]:
DATA_DIR = Path("../../data")

REPOS_DIR = DATA_DIR / "repos"
VECTOR_STORES_DIR = DATA_DIR / "vector-stores"
PREPROCESSED_DATA_DIR = DATA_DIR / "preprocessed"

DATASET = "mcsn"
file_path = PREPROCESSED_DATA_DIR / f"method-level-{DATASET}.jsonl"
df = pd.read_json(file_path, lines=True)
df.head(1)

Unnamed: 0,repo_name,method_name,method_code,method_summary,original_method_code,method_path
0,apache/airflow,HttpHook.run,"def run(self, endpoint, data=None, headers=Non...",Performs the request,"def run(self, endpoint, data=None, headers=Non...",airflow/hooks/http_hook.py


In [3]:
EMBEDDINGS = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    model_kwargs={'device': 'cuda:2'},
    encode_kwargs={'normalize_embeddings': True},
)

In [4]:
repo_list = df["repo_name"].unique().tolist()
repo_list

['apache/airflow',
 'Azure/azure-sdk-for-python',
 'streamlink/streamlink',
 'open-mmlab/mmcv']

## Index repos

It is done only once

In [9]:
# # Load and index repositories
# for repo_name in tqdm(repo_list[1:]):
#     clone_url = f"https://github.com/{repo_name}.git"
#     repo_path = REPOS_DIR / repo_name.replace("/", "_")
#     persist_path = VECTOR_STORES_DIR / repo_name.replace("/", "_")
#     shutil.rmtree(repo_path, ignore_errors=True)
#     shutil.rmtree(persist_path, ignore_errors=True)
#     loader = GitLoader(
#         clone_url=None if repo_path.exists() else clone_url,
#         repo_path=repo_path.absolute(),
#         branch="master" if repo_name == "streamlink/streamlink" else "main",
#         file_filter=lambda file_path: file_path.endswith(".py"),
#     )
#     print(f"Loading {repo_name}...")
#     documents = loader.load()
#     print(f"Loaded {len(documents)} documents from {repo_name} into {repo_path}.")
#     python_splitter = RecursiveCharacterTextSplitter.from_language(
#         language=Language.PYTHON,
#         chunk_size=1000,
#         chunk_overlap=100,
#     )  # Splits code by definitions of classes and functions, then by lines
#     print("Splitting documents...")
#     chunks = python_splitter.split_documents(documents)
#     print(f"Splitted documents into {len(chunks)} chunks from {repo_name}.")
#     print(f"Persisting {repo_name}...")
#     vector_store = await FAISS.afrom_documents(
#         chunks, EMBEDDINGS
#     )
#     vector_store.save_local(persist_path)
#     print(f"Persisted {repo_name} into {persist_path}.\n")

In [10]:
# # See repo statistics
# for repo_name in tqdm(repo_list):
#     repo_path = REPOS_DIR / repo_name.replace("/", "_")
#     loader = GitLoader(
#         clone_url=None,
#         repo_path=repo_path.absolute(),
#         branch="master" if repo_name == "streamlink/streamlink" else "main",
#         file_filter=lambda file_path: file_path.endswith(".py"),
#     )
#     documents = loader.load()
#     python_splitter = RecursiveCharacterTextSplitter.from_language(
#         language=Language.PYTHON,
#         chunk_size=1000,
#         chunk_overlap=100,
#     )
#     chunks = python_splitter.split_documents(documents)
#     print(repo_name, len(documents), len(chunks))

## Load indexed repos and test retrive repo context

In [5]:
def load_vector_store(repo_name):
    return FAISS.load_local(
        VECTOR_STORES_DIR / repo_name.replace("/", "_"),
        EMBEDDINGS,
        allow_dangerous_deserialization=True,
    )


VECTOR_STORES = {repo_name: load_vector_store(repo_name) for repo_name in repo_list}

In [6]:
samples = [df[df["repo_name"] == repo_name].sample(1, random_state=42) for repo_name in repo_list]
tdf = pd.concat(samples, ignore_index=True)
tdf

Unnamed: 0,repo_name,method_name,method_code,method_summary,original_method_code,method_path
0,apache/airflow,DbApiHook.insert_rows,"def insert_rows(self, table, rows, target_fiel...",A generic way to insert a set of tuples into a...,"def insert_rows(self, table, rows, target_fiel...",airflow/hooks/dbapi_hook.py
1,Azure/azure-sdk-for-python,_MinidomXmlToObject.get_entry_properties_from_...,"def get_entry_properties_from_node(entry, incl...",get properties from entry xml,"def get_entry_properties_from_node(entry, incl...",azure-servicemanagement-legacy/azure/servicema...
2,streamlink/streamlink,format_time,def format_time(elapsed):\n hours = int(ela...,Formats elapsed seconds into a human readable ...,"def format_time(elapsed):\n """"""Formats elap...",src/streamlink_cli/utils/progress.py
3,open-mmlab/mmcv,frames2video,"def frames2video(frame_dir,\n ...",Read the frame images from a directory and joi...,"def frames2video(frame_dir,\n ...",mmcv/video/io.py


In [7]:
K = 70


def retrieve_repo_context(method_code, repo_name):
    context = VECTOR_STORES[repo_name].similarity_search(method_code, k=K)
    return "\n\n".join(
        [f"File path: {d.metadata['file_path']}\nFile content:\n```{d.page_content}```" for d in context]
    )

In [8]:
tqdm.pandas()
tdf["repo_context"] = tdf.progress_apply(
    lambda x: retrieve_repo_context(x.get("method_code"), x.get("repo_name")), axis=1
)
for c in tdf.repo_context:
    print(len(c))
    break

100%|██████████| 4/4 [00:04<00:00,  1.16s/it]

59036





In [14]:
model_names = [
    "deepseek-ai/deepseek-coder-1.3b-instruct",
    "deepseek-ai/deepseek-coder-6.7b-instruct",
    "deepseek-ai/deepseek-coder-33b-instruct",
    "bigcode/starcoder2-15b-instruct-v0.1",
    "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
]

idx = 1
MODEL_NAME = model_names[idx]

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

In [20]:
%%time
for row in tdf.itertuples():
    print(row.repo_name)
    overall_context_length = 0
    relevance_scores = []
    context = VECTOR_STORES[row.repo_name].similarity_search_with_relevance_scores(row.method_code, k=70)
    # context = VECTOR_STORES[row.repo_name].max_marginal_relevance_search(
    #     row.method_code, k=K, fetch_k=K * 5)
    for c in context:
        if c[0].page_content not in row.original_method_code:
            overall_context_length += len(tokenizer.encode(c[0].page_content))
            relevance_scores.append(c[1])
    print(overall_context_length)
    # print(max(relevance_scores), sum(relevance_scores) / len(relevance_scores), min(relevance_scores))


apache/airflow
13779
Azure/azure-sdk-for-python
12557
streamlink/streamlink
14144
open-mmlab/mmcv
14821
CPU times: user 1min 21s, sys: 9.33 s, total: 1min 30s
Wall time: 11.1 s


In [51]:
dict(c[0])

{'page_content': 'else:\n                    lst.append(str(cell))\n            values = tuple(lst)\n            sql = f"INSERT /*+ APPEND */ INTO {table} {target_fields} VALUES ({\',\'.join(values)})"\n            cur.execute(sql)\n            if i % commit_every == 0:\n                conn.commit()  # type: ignore[attr-defined]\n                self.log.info("Loaded %s into %s rows so far", i, table)\n        conn.commit()  # type: ignore[attr-defined]\n        cur.close()\n        conn.close()  # type: ignore[attr-defined]\n        self.log.info("Done loading. Loaded a total of %s rows", i)',
 'metadata': {'source': 'airflow/providers/oracle/hooks/oracle.py',
  'file_path': 'airflow/providers/oracle/hooks/oracle.py',
  'file_name': 'oracle.py',
  'file_type': '.py'},
 'type': 'Document'}