In [19]:
from concurrent.futures import ThreadPoolExecutor

In [1]:
from api_models import set_llm_and_embed
# Project helper from api_models; presumably configures the default LLM and
# embedding model used by the llama_index calls in later cells — TODO confirm.
set_llm_and_embed()



In [3]:
import os


base_dir = 'data_repos/ftlr/datasets/eTour'


def _read_listing(listing_path):
    """Return the stripped, non-blank lines of a file-name listing (handle closed on exit)."""
    with open(listing_path) as listing:
        return [line.strip() for line in listing if line.strip()]


def _read_text(path):
    """Return the whole content of ``path``, closing the handle afterwards."""
    with open(path) as text_file:
        return text_file.read()


# Java source files to index (paths are relative to <base_dir>/code).
all_code_files_path = os.path.join(base_dir, 'all_code_filenames.txt')
all_code_files = _read_listing(all_code_files_path)

# Requirement texts keyed by file name (files live in <base_dir>/req).
all_req_files_path = os.path.join(base_dir, 'all_req_filenames.txt')
all_req_contents = {
    req_name: _read_text(os.path.join(base_dir, 'req', req_name))
    for req_name in _read_listing(all_req_files_path)
}

In [6]:
import json
import javalang

from tqdm.auto import tqdm

# Keys of the per-class dicts built by parse_java_file; get_graph_nodes reuses
# several of them as keys on graph-node dicts and Document metadata.
FILE_NAME_LABEL = "File Name"
CLASS_NAME_LABEL = "Class Name"
DOCSTRING_LABEL = "Docstring"
ATTRIBUTES_LABEL = "Attributes"
ATTRIBUTE_NAME_LABEL = "Attribute Name"
ATTRIBUTES_TYPE_LABEL = "Attribute Type"
METHODS_LABEL = "Methods"
METHOD_NAME_LABEL = "Method Name"
# Key for the "Class.method(ParamType,...)" string built in get_graph_nodes.
METHOD_SIGNATURE_LABEL = "Signature"
METHOD_DOCSTRING = "Method Docstring"
METHOD_PARAMETERS_LABEL = "Method Parameters"
METHOD_RETURN_LABEL = "Method Return"
PARAM_NAME_LABEL = "Parameter Name"
PARAM_TYPE_LABEL = "Parameter Type"
# NOTE(review): not referenced by any cell shown here.
PARAM_DESCRIPTION_LABEL = "Description"
# Call-graph edge keys, filled from the call-graph JSON by add_method_calls.
CALLS_LABEL = "calls"
CALLED_BY_LABEL = "called_by"


def _fallback_class_info(file_name):
    """Minimal record (file name + class name guessed from the file name) for
    files that cannot be turned into a full class/interface record."""
    return {
        FILE_NAME_LABEL: file_name,
        CLASS_NAME_LABEL: file_name.split('.')[0],
    }


def parse_java_file(file_path):
    """Extract the first top-level class or interface declared in a Java file.

    Args:
        file_path: path to a ``.java`` source file (read as UTF-8, BOM tolerated).

    Returns:
        dict with CLASS_NAME_LABEL, DOCSTRING_LABEL ("" when absent),
        ATTRIBUTES_LABEL (list of name/type dicts), METHODS_LABEL (list of dicts
        with name, docstring or None, parameter name/type dicts, and return type
        or "void") and FILE_NAME_LABEL. When the file fails to tokenize/parse,
        or declares no class/interface, only FILE_NAME_LABEL and
        CLASS_NAME_LABEL are set.
    """
    # basename works with either path separator, unlike split('/').
    file_name = os.path.basename(file_path)
    with open(file_path, 'r', encoding='utf-8-sig') as file:
        content = file.read()

    try:
        tree = javalang.parse.parse(content)
    except (javalang.parser.JavaSyntaxError, javalang.tokenizer.LexerError):
        # Keep unparsable files in the corpus as bare class nodes.
        return _fallback_class_info(file_name)

    # Only class/interface declarations qualify; other top-level types are skipped.
    class_decl = next(
        (
            type_decl for type_decl in tree.types
            if isinstance(type_decl, (javalang.tree.ClassDeclaration,
                                      javalang.tree.InterfaceDeclaration))
        ),
        None,
    )
    if class_decl is None:
        # Nothing to describe (e.g. an enum-only file): fall back instead of
        # failing on `class_decl.name`.
        return _fallback_class_info(file_name)

    class_info = {
        CLASS_NAME_LABEL: class_decl.name,
        DOCSTRING_LABEL: class_decl.documentation or "",
        ATTRIBUTES_LABEL: [],
        METHODS_LABEL: [],
        FILE_NAME_LABEL: file_name,
    }

    # NOTE(review): javalang's filter() appears to walk the whole subtree, so
    # fields/methods of nested or anonymous classes are attributed to this
    # class as well — confirm this is intended.
    for _, node in class_decl.filter(javalang.tree.FieldDeclaration):
        for field in node.declarators:
            class_info[ATTRIBUTES_LABEL].append({
                ATTRIBUTE_NAME_LABEL: field.name,
                ATTRIBUTES_TYPE_LABEL: node.type.name,
            })

    # `.type.name` is presumably the bare type name (no array dims / generics);
    # get_graph_nodes builds method signatures from these names.
    for _, node in class_decl.filter(javalang.tree.MethodDeclaration):
        method_info = {
            METHOD_NAME_LABEL: node.name,
            METHOD_DOCSTRING: node.documentation or None,
            METHOD_PARAMETERS_LABEL: [
                {
                    PARAM_NAME_LABEL: param.name,
                    PARAM_TYPE_LABEL: param.type.name,
                }
                for param in node.parameters
            ],
            METHOD_RETURN_LABEL: node.return_type.name if node.return_type else "void",
        }
        class_info[METHODS_LABEL].append(method_info)

    return class_info

# Parse every listed Java file (under <base_dir>/code) into a class-info record.
class_info_objects = [
    parse_java_file(os.path.join(base_dir, 'code', code_file))
    for code_file in tqdm(all_code_files)
]

  0%|          | 0/116 [00:00<?, ?it/s]

In [8]:
def get_graph_nodes(class_info_objects):
    """Flatten parsed class records into a dict of graph nodes.

    Class nodes are keyed by class name; method nodes are keyed by a
    ``Class.method(ParamType,...)`` signature (parameter types joined with ","
    and no spaces). Each node carries a "type" of "Class" or "Method".

    Args:
        class_info_objects: records produced by parse_java_file.

    Returns:
        dict mapping node key -> node dict. Also prints the node count.
    """
    graph_nodes = dict()
    for class_info in class_info_objects:
        class_name = class_info[CLASS_NAME_LABEL]
        class_node = {
            CLASS_NAME_LABEL: class_name,
            FILE_NAME_LABEL: class_info[FILE_NAME_LABEL],
            "type": "Class",
        }
        # Fallback records from unparsable files lack these keys.
        for optional_key in (DOCSTRING_LABEL, ATTRIBUTES_LABEL):
            if optional_key in class_info:
                class_node[optional_key] = class_info[optional_key]
        graph_nodes[class_name] = class_node

        for method_info in class_info.get(METHODS_LABEL, []):
            method_name = method_info[METHOD_NAME_LABEL]
            param_types = ','.join(
                param[PARAM_TYPE_LABEL] for param in method_info[METHOD_PARAMETERS_LABEL]
            )
            signature = f'{class_name}.{method_name}({param_types})'

            method_node = {
                CLASS_NAME_LABEL: class_name,
                FILE_NAME_LABEL: class_info[FILE_NAME_LABEL],
                METHOD_NAME_LABEL: method_name,
                "type": "Method",
                METHOD_SIGNATURE_LABEL: signature,
            }

            # parse_java_file always sets this key (None when there is no
            # Javadoc), so copy only real docstrings.
            docstring = method_info.get(METHOD_DOCSTRING)
            if docstring:
                method_node[METHOD_DOCSTRING] = docstring

            graph_nodes[signature] = method_node

    print("Number of nodes in the graph:", len(graph_nodes))
    return graph_nodes

graph_nodes = get_graph_nodes(class_info_objects)


Number of nodes in the graph: 1060


In [11]:
def _lookup_class_file_name(graph_nodes, class_name):
    """Return the file name stored on the class node for ``class_name`` (tried
    as given and as its last dotted segment), or None if no class node exists."""
    if not isinstance(class_name, str):
        return None
    for key in (class_name, class_name.rsplit('.', 1)[-1]):
        class_node = graph_nodes.get(key)
        if class_node is not None and class_node.get("type") == "Class":
            return class_node[FILE_NAME_LABEL]
    return None


def add_method_calls(graph_nodes, cdg_file_path):
    """Attach call-graph edges from a JSON call-dependency graph (in place).

    The JSON maps a method signature to a dict with "calls", "called_by",
    "class_name" and "method_name". For every method node found in it, its
    calls/called_by lists are copied onto the node; neighbours missing from
    ``graph_nodes`` are added as new method nodes.

    Args:
        graph_nodes: dict produced by get_graph_nodes; mutated in place.
        cdg_file_path: path to the call-graph JSON file.
    """
    with open(cdg_file_path) as cdg_file:
        cdg = json.load(cdg_file)

    present, absent = 0, 0
    # Snapshot: nodes added below are not themselves expanded.
    for node in list(graph_nodes.values()):
        if node["type"] != "Method":
            continue
        method_key = node[METHOD_SIGNATURE_LABEL]
        if method_key not in cdg:
            continue

        cdg_entry = cdg[method_key]
        graph_nodes[method_key][CALLS_LABEL] = cdg_entry["calls"]
        graph_nodes[method_key][CALLED_BY_LABEL] = cdg_entry["called_by"]

        for node_call in cdg_entry["calls"] + cdg_entry["called_by"]:
            if node_call not in cdg:
                absent += 1
                print(f"Node {node_call} not found")
                continue

            present += 1
            cdg_call_node = cdg[node_call]
            if node_call not in graph_nodes:
                # The neighbour lives in its own class's file, not in the file
                # of the method currently being expanded.
                callee_file = _lookup_class_file_name(graph_nodes, cdg_call_node['class_name'])
                graph_nodes[node_call] = {
                    "type": "Method",
                    METHOD_SIGNATURE_LABEL: node_call,
                    CLASS_NAME_LABEL: cdg_call_node['class_name'],
                    METHOD_NAME_LABEL: cdg_call_node['method_name'],
                    # NOTE(review): falls back to the current node's file when
                    # the neighbour's class has no class node.
                    FILE_NAME_LABEL: callee_file if callee_file is not None else node[FILE_NAME_LABEL],
                }
            graph_nodes[node_call][CALLED_BY_LABEL] = cdg_call_node["called_by"]
            graph_nodes[node_call][CALLS_LABEL] = cdg_call_node["calls"]

    print(f"Present: {present}, Absent: {absent}")
    print("Number of nodes in the graph:", len(graph_nodes))

cdg_path = os.path.join(base_dir, 'etour_method_callgraph.json')
add_method_calls(graph_nodes, cdg_path)

Present: 1704, Absent: 0
Number of nodes in the graph: 1143


In [9]:
from llama_index.core import Document
from tqdm.auto import tqdm

def create_class_node_doc(graph_node):
    """Render a class graph node as a llama_index Document.

    The text lists the class name, its attributes as "name: type" lines (when
    there are any) and the class docstring (when the key is present). Metadata
    holds the file name plus a "type" of "Class", which is hidden from both the
    embedding and the LLM.
    """
    header = f"{graph_node['type']} Name: " + graph_node[CLASS_NAME_LABEL] + "\n"
    pieces = [header]

    has_attributes = ATTRIBUTES_LABEL in graph_node and len(graph_node[ATTRIBUTES_LABEL]) > 0
    if has_attributes:
        pieces.append("Attributes: \n")
        pieces.extend(
            f"{attribute[ATTRIBUTE_NAME_LABEL]}: {attribute[ATTRIBUTES_TYPE_LABEL]}\n"
            for attribute in graph_node[ATTRIBUTES_LABEL]
        )

    if DOCSTRING_LABEL in graph_node:
        pieces.append(f"\n{graph_node[DOCSTRING_LABEL]}\n")

    return Document(
        text="".join(pieces),
        metadata={
            FILE_NAME_LABEL: graph_node[FILE_NAME_LABEL],
            "type": "Class",
        },
        excluded_embed_metadata_keys=["type"],
        excluded_llm_metadata_keys=["type"],
    )


def create_method_node_doc(graph_node, show_calls=False):
    """Render a method graph node as a llama_index Document.

    The text holds the class name, method name, signature, the method docstring
    (if any) and, when ``show_calls`` is true, the call / called-by signatures.
    The call lists are also stored in metadata as ", "-joined strings (None
    when the node has no call data). Signature, calls, called_by and type are
    excluded from both embedding and LLM metadata.

    Args:
        graph_node: a "Method" node from get_graph_nodes / add_method_calls.
        show_calls: include call-graph neighbours in the document text.

    Returns:
        llama_index ``Document``.
    """
    content = f"Class Name: {graph_node[CLASS_NAME_LABEL]}\n"
    content += f"{graph_node['type']} Name: {graph_node[METHOD_NAME_LABEL]}\n"
    content += f"Signature: {graph_node[METHOD_SIGNATURE_LABEL]}\n"

    # parse_java_file records None for methods without Javadoc; a bare key
    # check would embed the literal text "None".
    docstring = graph_node.get(METHOD_DOCSTRING)
    if docstring:
        content += f"\n{docstring}\n"

    calls = graph_node.get(CALLS_LABEL)
    called_by = graph_node.get(CALLED_BY_LABEL)

    if show_calls:
        if calls is not None:
            content += "\nCalls: \n" + "".join(f"{call}\n" for call in calls)
        if called_by is not None:
            content += "\nCalled By: \n" + "".join(f"{caller}\n" for caller in called_by)

    excluded_keys = [METHOD_SIGNATURE_LABEL, CALLS_LABEL, CALLED_BY_LABEL, "type"]
    return Document(
        text=content,
        metadata={
            FILE_NAME_LABEL: graph_node[FILE_NAME_LABEL],
            CALLS_LABEL: ", ".join(calls) if calls is not None else None,
            CALLED_BY_LABEL: ", ".join(called_by) if called_by is not None else None,
            "type": "Method",
            METHOD_SIGNATURE_LABEL: graph_node[METHOD_SIGNATURE_LABEL],
        },
        # Separate list objects so the two exclusion lists are never aliased.
        excluded_embed_metadata_keys=excluded_keys,
        excluded_llm_metadata_keys=list(excluded_keys),
    )

# One Document per graph node, dispatching on the node's type.
docs = [
    (create_class_node_doc if graph_node["type"] == "Class" else create_method_node_doc)(graph_node)
    for graph_node in tqdm(graph_nodes.values(), desc="Creating Documents")
]

Creating Documents:   0%|          | 0/1060 [00:00<?, ?it/s]

In [None]:
from llama_index.core.indices import KnowledgeGraphIndex
import kuzu
from llama_index.graph_stores.kuzu import KuzuGraphStore
from llama_index.core import StorageContext


def get_kuzu_graph_store(collection_name):
    """Open the Kuzu database at ``collection_name`` and wrap it for llama_index.

    Note: despite the name, this returns a ``StorageContext`` whose graph store
    is a ``KuzuGraphStore``, not the graph store itself.
    """
    return StorageContext.from_defaults(
        graph_store=KuzuGraphStore(kuzu.Database(collection_name))
    )

def _split_joined_signatures(joined):
    """Split a ", "-joined metadata string back into method signatures.

    Non-string values (create_method_node_doc stores None when a node has no
    call data) yield []. Empty fragments are dropped, so an empty call list
    (joined to "") does not produce a triple with a blank target.
    NOTE(review): assumes no signature contains ", " (get_graph_nodes joins
    parameter types with "," only).
    """
    if not isinstance(joined, str):
        return []
    return [signature for signature in joined.split(", ") if signature]


def create_kg_index(docs, storage_context=None):
    """Build a KnowledgeGraphIndex from class and method documents.

    Class documents go through ``KnowledgeGraphIndex.from_documents`` (up to 5
    triplets per chunk). Method documents are not extracted; instead their
    call-graph metadata becomes explicit (signature, "calls"/"called_by",
    other signature) triplets upserted together with the document.

    Args:
        docs: Documents whose metadata "type" is "Class" or "Method".
        storage_context: optional StorageContext (e.g. backed by a graph store).

    Returns:
        The populated KnowledgeGraphIndex.
    """
    class_docs = [doc for doc in docs if doc.metadata["type"] == "Class"]
    method_docs = [doc for doc in docs if doc.metadata["type"] == "Method"]

    print("Number of Class Nodes:", len(class_docs))
    index = KnowledgeGraphIndex.from_documents(
        class_docs,
        max_triplets_per_chunk=5,
        show_progress=True,
        storage_context=storage_context,
    )

    for doc in tqdm(method_docs, desc="Adding Method Node triples"):
        signature = doc.metadata[METHOD_SIGNATURE_LABEL]
        for relation, metadata_key in (("calls", CALLS_LABEL), ("called_by", CALLED_BY_LABEL)):
            for other in _split_joined_signatures(doc.metadata.get(metadata_key)):
                index.upsert_triplet_and_node((signature, relation, other), doc)
    return index

# Build the KG in the Kuzu database: the storage context must be passed in,
# otherwise the index uses the default graph store and the Kuzu database
# opened here is never used.
kuzu_store = get_kuzu_graph_store("etour_kg")
kg_index = create_kg_index(docs, storage_context=kuzu_store)

In [13]:
from indexing.utils import get_parser
# Project helper from indexing.utils.
# NOTE(review): `parser` is not referenced by any cell shown here — confirm it is still needed.
parser = get_parser()

In [19]:
# code_nodes = code_parser.create_nodes_with_fltr_setting(use_docstring=False)


Creating Requirement Docs:   0%|          | 0/58 [00:00<?, ?it/s]

58

In [10]:
import chromadb
from llama_index.core import StorageContext
from llama_index.vector_stores.chroma import ChromaVectorStore


def get_vector_storage_context(chroma_db_path, collection_name):
    """Open (or create) a persistent Chroma collection and wrap it for llama_index.

    Args:
        chroma_db_path: directory of the persistent Chroma database.
        collection_name: collection to get or create.

    Returns:
        ``(storage_context, vector_store)`` — a StorageContext backed by the
        ChromaVectorStore, and that vector store.
    """
    client = chromadb.PersistentClient(path=f"{chroma_db_path}")
    collection = client.get_or_create_collection(collection_name)
    chroma_store = ChromaVectorStore(chroma_collection=collection)
    return StorageContext.from_defaults(vector_store=chroma_store), chroma_store

In [11]:
# Chroma data for a dataset lives under indices/<dataset>, in a collection named after it.
indices_dir = 'indices'
dataset_name = base_dir.rsplit('/', 1)[-1]
storage_context, vector_store = get_vector_storage_context(
    chroma_db_path=f"{indices_dir}/{dataset_name}",
    collection_name=f"{dataset_name}",
)

In [25]:
from indexing.utils import run_pipeline_multithreaded

# transformations = get_transformations(
#     llm=llm,
#     summary_extractor=True,
#     summary_template=summary_template,
# )

# Run the project's ingestion pipeline over all documents on 4 worker threads.
nodes = run_pipeline_multithreaded(
    docs,
    num_threads=4,
    show_progress=True,
)

In [None]:
from indexing.indices import create_vector_index

# Embed the full pipeline output. A slice such as `docs[:1]` would index only
# the first document and ignore `nodes` entirely.
indexed_nodes, vector_index = create_vector_index(
    nodes=nodes,
    storage_context=storage_context,
    num_threads=8,
    pickle_dir=f'{indices_dir}/embedded_{dataset_name}',
    show_progress=True
)

In [None]:
import pickle


# Replaces any `vector_index` built earlier in this session with the pickled one.
# NOTE: pickle.load can execute arbitrary code — only load files this project wrote.
with open('vector_index.pkl', 'rb') as index_file:
    vector_index = pickle.load(index_file)

In [None]:
# Replaces any `kg_index` built earlier in this session with the pickled one.
# NOTE: pickle.load can execute arbitrary code — only load files this project wrote.
with open('kg_index.pkl', 'rb') as index_file:
    kg_index = pickle.load(index_file)
kg_qe = kg_index.as_query_engine()

In [50]:
# Plain vector-retrieval query engine over `vector_index`.
vi_qe = vector_index.as_query_engine()

In [54]:
import pickle

# Load the persisted keyword index (same pickle pattern as the vector/KG
# indices). A '' placeholder here makes `.as_query_engine()` raise.
# NOTE: pickle.load can execute arbitrary code — only load files this project wrote.
with open('keyword_index.pkl', 'rb') as index_file:
    keyword_index = pickle.load(index_file)
if not hasattr(keyword_index, 'as_query_engine'):
    raise TypeError(
        f"keyword_index.pkl does not contain an index (got {type(keyword_index).__name__}); "
        "rebuild and re-pickle the keyword index first"
    )
ki_qe = keyword_index.as_query_engine()

In [46]:
from retrievers import custom_query_engine

# Project-defined engine (retrievers.custom_query_engine) combining the vector
# index with the KG index; the combination strategy is defined there, not here.
custom_kg_qe = custom_query_engine(vector_index, kg_index)

In [47]:
# Same project-defined combination, using the keyword index instead of the KG index.
custom_ki_qe = custom_query_engine(vector_index, keyword_index)

In [59]:
# Engines to evaluate, keyed by the label printed in the evaluation loop.
# Comment entries in/out to choose which configurations are run.
query_engines = {
    # "Vector Index": vi_qe,
    "Knowledge Graph Index": kg_qe,
    "Keyword Index": ki_qe,
    # "Custom KG": custom_kg_qe,
    # "Custom Keyword": custom_ki_qe
}

In [17]:
import pickle

# Refuse to overwrite the saved index with a placeholder (e.g. the '' stub):
# every index used in this notebook exposes `as_query_engine`.
if not hasattr(keyword_index, 'as_query_engine'):
    raise TypeError(
        f"refusing to pickle a non-index keyword_index ({type(keyword_index).__name__})"
    )
with open('keyword_index.pkl', 'wb') as f:
    pickle.dump(keyword_index, f)

In [15]:
# Prompt sent once per requirement; `{requirement}` is filled with the
# requirement text. The answer must be a bare JSON list because
# get_post_processing_results parses each response with json.loads.
query_template = \
"""
What are the names of the classes that are related to the following use case?
{requirement}

Provide the answer in a list format and provide ONLY the list of class names as a JSON list.
[<"Class 1 Name">, <"Class 2 Name">, ... <"Class N Name">] where N can be up to 10.
"""

In [16]:
class RequirementsNodesCreator:
    """Turn requirement texts into llama_index Documents, one per requirement file.

    Args:
        base_dir: dataset root; stored on the instance but not used by create_nodes.
        req_contents: mapping of requirement file name -> requirement text.
    """

    def __init__(self, base_dir: str, req_contents: dict[str, str]):
        self.base_dir = base_dir
        self.req_contents = req_contents

    def create_nodes(self):
        """Return one Document per requirement, with the text as content and the
        requirement file name under the "file_name" metadata key."""
        return [
            Document(text=requirement_text, metadata={"file_name": requirement_file})
            for requirement_file, requirement_text in tqdm(
                self.req_contents.items(), desc="Creating Requirement Docs"
            )
        ]

# Build one Document per requirement file; the cell displays how many were created.
req_parser = RequirementsNodesCreator(base_dir, all_req_contents)
req_nodes = req_parser.create_nodes()
len(req_nodes)

Creating Requirement Docs:   0%|          | 0/58 [00:00<?, ?it/s]

58

In [48]:
from collections import defaultdict
from typing import List
from llama_index.core.query_engine import BaseQueryEngine

def query_parallel(
        query_engine: BaseQueryEngine,
        query_template: str, 
        req_nodes: List[Document], 
        num_threads=8
    ):
    """Run ``query_engine.query`` for every requirement on a thread pool.

    Args:
        query_engine: llama_index query engine; ``query`` is called from worker threads.
        query_template: prompt with a ``{requirement}`` placeholder, filled with each node's text.
        req_nodes: requirement Documents carrying a ``"file_name"`` metadata entry.
        num_threads: size of the thread pool.

    Returns:
        List of ``(file_name, response)`` tuples in the same order as ``req_nodes``.

    Raises:
        Any exception raised by a query propagates from ``future.result()``.
    """
    from concurrent.futures import as_completed

    with ThreadPoolExecutor(max_workers=num_threads) as executor:
        future_info = {}
        for position, req_node in enumerate(req_nodes):
            future = executor.submit(
                query_engine.query,
                query_template.format(requirement=req_node.text)
            )
            future_info[future] = (position, req_node.metadata["file_name"])

        results = [None] * len(future_info)
        # Context-managed bar is closed even if a query raises; as_completed
        # makes it track real completion rather than submission order.
        with tqdm(total=len(future_info), desc="Processing", unit="Requirement") as progress_bar:
            for future in as_completed(future_info):
                position, file_name = future_info[future]
                results[position] = (file_name, future.result())
                progress_bar.update(1)
    return results

In [27]:
def _parse_class_name_list(text):
    """Return the class names encoded in an LLM answer, or None if it is not a
    JSON list of strings. Tolerates surrounding prose or Markdown code fences
    by parsing the span from the first '[' to the last ']'."""
    if not isinstance(text, str):
        return None
    candidate = text.strip()
    start, end = candidate.find('['), candidate.rfind(']')
    if start != -1 and end > start:
        candidate = candidate[start:end + 1]
    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    if not isinstance(parsed, list) or not all(isinstance(name, str) for name in parsed):
        return None
    return parsed


def get_post_processing_results(req_results):
    """Turn raw query responses into predicted class names per requirement.

    Args:
        req_results: iterable of ``(file_name, response)`` pairs (as returned by
            query_parallel); each response exposes a ``.response`` string.

    Returns:
        defaultdict mapping requirement file name -> sorted list of unique class
        names. A valid empty list yields an empty entry; requirements whose
        response cannot be parsed are omitted and printed, and a summary line
        reports whether all results were valid.
    """
    ok = True
    llm_results = defaultdict(set)
    for file_name, result in req_results:
        class_names = _parse_class_name_list(result.response)
        if class_names is None:
            print(file_name, result.response)
            ok = False
            continue
        llm_results[file_name].update(class_names)

    for key, names in llm_results.items():
        llm_results[key] = sorted(names)

    if not ok:
        print("Some results are invalid")
    else:
        print("All results are valid")

    return llm_results

UC14.txt Empty Response
Some results are invalid


In [None]:
# responses = list()
# for req_node in tqdm(req_nodes, desc="Querying Requirements"):
#     query = query_template.format(requirement=req_node.text)
#     results = query_engine.query(query)
    
#     print(f"Query: {query}")
#     print(f"Results: {results}")
#     print("=========================================")
#     responses.append((query, results))
#     break

In [56]:
from collections import defaultdict

def get_solutions(file_name):
    """Load ground-truth trace links.

    Each non-blank line has the form ``<requirement file>: <ClassName>.java``;
    everything from ``.java`` onwards is dropped from the class name.

    Args:
        file_name: path to the solution-links file.

    Returns:
        defaultdict(list) mapping requirement file name -> class names in file order.

    Raises:
        ValueError: a non-blank line has no ``": "`` separator.
    """
    solutions = defaultdict(list)
    with open(file_name) as links_file:
        for line_no, raw_line in enumerate(links_file, start=1):
            line = raw_line.rstrip('\r\n')
            if not line.strip():
                continue
            req_file, separator, target = line.partition(': ')
            if not separator:
                raise ValueError(
                    f"{file_name}:{line_no}: expected '<requirement>: <Class>.java', got {line!r}"
                )
            solutions[req_file].append(target.split('.java')[0])
    return solutions

# Ground-truth links live alongside the dataset, so derive the path from base_dir.
solutions = get_solutions(os.path.join(base_dir, 'etour_solution_links_english.txt'))

In [57]:
def compare_solutions(solutions, llm_results, result_file_name='results.json'):
    """Score predicted trace links against the ground truth (micro-averaged).

    Every requirement in ``solutions`` is evaluated. A requirement missing from
    ``llm_results`` (e.g. its response could not be parsed) counts as an empty
    prediction, so all its expected classes are false negatives instead of
    being silently skipped. Predictions for requirements absent from
    ``solutions`` are ignored.

    Args:
        solutions: requirement file name -> expected class names.
        llm_results: requirement file name -> predicted class names.
        result_file_name: where the per-requirement comparison is written as JSON.

    Returns:
        dict with "precision", "recall" and "f1" (the same values are printed).
    """
    per_requirement = []
    tp = fp = fn = 0

    for file_name, classes in solutions.items():
        expected = set(classes)
        # .get keeps a defaultdict from growing new keys.
        predicted = set(llm_results.get(file_name, ()))
        tp += len(expected & predicted)
        fp += len(predicted - expected)
        fn += len(expected - predicted)

        per_requirement.append({
            "file_name": file_name,
            "expected_classes": sorted(classes),
            "llm_classes": sorted(predicted),
        })

    with open(result_file_name, 'w') as f:
        json.dump(per_requirement, f, indent=4)

    precision = tp / (tp + fp) if tp + fp > 0 else 0
    recall = tp / (tp + fn) if tp + fn > 0 else 0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
    print("Precision: ", precision)
    print("Recall: ", recall)
    print("F1 Score: ", f1)
    return {"precision": precision, "recall": recall, "f1": f1}

In [None]:
for config, query_engine in query_engines.items():
    print(f"Evaluating for {config}")
    req_results = query_parallel(
        query_engine, 
        query_template, 
        req_nodes, 
        num_threads=8
    )
    llm_results = get_post_processing_results(req_results)
    print(f"Results for {config}")
    # One results file per configuration so later configs don't overwrite earlier ones.
    config_slug = config.lower().replace(' ', '_')
    compare_solutions(solutions, llm_results, result_file_name=f'results_{config_slug}.json')