### Create Module Graph of Repository

In [3]:
import pickle
import os
from collections import defaultdict
import networkx as nx
import json
from tqdm.auto import tqdm
import ast
import pathlib

# Default path of the JSON file that stores the per-file parse results;
# used as the default ``f_name`` argument of ``create_module_graph``.
modules_classes_file = 'dataset/module_imports_and_classes.json'



def get_all_imports(module_imports_data):
    """Collect every distinct import name referenced by top-level nodes.

    ``module_imports_data`` maps file -> node name -> node details. Each
    node's ``'imports'`` entry is iterated (its keys, when it is a dict).

    Returns a de-duplicated list; the order is arbitrary.
    """
    unique_imports = set()
    for file_nodes in module_imports_data.values():
        for node_details in file_nodes.values():
            unique_imports.update(node_details['imports'])
    return list(unique_imports)


def extract_file_paths(module_imports_data):
    """Return the dotted path of every top-level node of every parsed file.

    ``module_imports_data`` maps a '/'-separated file path to a dict keyed
    by node (class/function) name. Each result has the form
    ``<dotted.module.path>.<node_name>``.

    Returns a de-duplicated list; the order is arbitrary.
    """
    node_file_paths = set()
    for file_name, file_contents in module_imports_data.items():
        # Strip only the trailing extension. A blanket replace('.py', '')
        # would also eat the start of any path component beginning with
        # "py" (e.g. "pkg/python_utils/mod.py" -> "pkgthon_utils.mod").
        if file_name.endswith('.py'):
            file_name = file_name[:-len('.py')]
        module_path = file_name.replace('/', '.')
        for node_name in file_contents:
            node_file_paths.add(f"{module_path}.{node_name}")

    return list(node_file_paths)


def get_imports_file_map(module_imports, node_file_paths):
    """Map each import name to the node paths it can refer to.

    A node path matches an import when the import is a suffix of the path
    made of whole dotted components: ``pkg.mod.Foo`` matches ``mod.Foo``
    and ``Foo``, but not ``od.Foo``.

    Args:
        module_imports: iterable of dotted import names.
        node_file_paths: iterable of dotted node paths (it is iterated once
            per import, so it must be re-iterable).

    Returns:
        A ``defaultdict(list)`` keyed by import name; imports without any
        match do not appear as keys.
    """
    import_packages_map = defaultdict(list)
    for module_import in tqdm(module_imports, desc='Creating import file map'):
        # Require a '.' in front of the suffix so that a partial component
        # (e.g. "utils.Foo" vs "my_utils.Foo") is not treated as a match.
        dotted_suffix = f".{module_import}"
        for fp in node_file_paths:
            if fp == module_import or fp.endswith(dotted_suffix):
                import_packages_map[module_import].append(fp)

    return import_packages_map


def get_used_aliases(tree, aliases):
    """Find which known imports are referenced by name inside ``tree``.

    Every ``ast.Name`` under ``tree`` whose identifier is a key of
    ``aliases`` contributes the import that alias maps to.

    Returns a dict mapping each referenced import to a fresh empty list.
    """
    referenced = {
        aliases[name_node.id]
        for name_node in ast.walk(tree)
        if isinstance(name_node, ast.Name) and name_node.id in aliases
    }
    return {import_name: [] for import_name in referenced}


def create_aliases_dict(tree):
    """Map every locally bound import name in ``tree`` to what it imports.

    * ``import pkg.mod as m``   -> ``{'m': 'pkg.mod'}``
    * ``from pkg import X as Y`` -> ``{'Y': 'pkg.X'}``
    * ``from . import sibling``  -> ``{'sibling': 'sibling'}``

    Imports anywhere in the tree are collected (not only top-level ones);
    a later binding of the same local name overwrites an earlier one.
    The relative-import level (leading dots) is not recorded.
    """
    all_aliases = dict()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                all_aliases[alias.asname or alias.name] = alias.name

        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                local_name = alias.asname or alias.name
                # ``from . import x`` has ``node.module is None``; formatting
                # it blindly would produce the bogus name "None.x".
                if node.module:
                    all_aliases[local_name] = f"{node.module}.{alias.name}"
                else:
                    all_aliases[local_name] = alias.name
    return all_aliases


def parse_file(file_path):
    """Parse one Python source file into per-definition metadata.

    Returns a dict keyed by definition name:

    * class entries hold ``type`` ('class'), ``imports``, ``docstring``,
      ``body`` (source text) and ``functions`` (name -> function entry);
    * function entries hold ``type`` ('function'), ``imports``,
      ``docstring`` and ``body``.

    ``imports`` is the ``get_used_aliases`` result for that definition.
    Entries are keyed by bare name, so a later definition with the same
    name overwrites an earlier one. Only ``ast.FunctionDef`` is handled;
    ``async def`` functions are not recorded.

    Raises whatever reading or ``ast.parse`` raises (e.g. ``SyntaxError``).
    """
    # Local import keeps the block self-contained. ``tokenize.open`` decodes
    # the file the way Python itself does (coding cookie / BOM, UTF-8 by
    # default) instead of relying on the platform's locale encoding.
    import tokenize

    # The context manager closes the handle; a bare open().read() leaked it.
    with tokenize.open(file_path) as source_file:
        file_contents = source_file.read()

    tree = ast.parse(file_contents)
    all_aliases = create_aliases_dict(tree)
    all_nodes_contents = dict()
    # AST nodes already recorded, so a definition reached again through the
    # outer walk is not recorded a second time.
    visited = set()

    def get_function_details(function_node):
        """Build the metadata entry for a single function definition."""
        function_imports = get_used_aliases(function_node, all_aliases)
        details = {
            'type': 'function',
            'imports': function_imports,
            'docstring': ast.get_docstring(function_node),
            'body': ast.get_source_segment(file_contents, function_node),
        }
        return details

    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef) and node not in visited:
            class_imports = get_used_aliases(node, all_aliases)
            all_nodes_contents[node.name] = {
                'type': 'class',
                'imports': class_imports,
                'docstring': ast.get_docstring(node),
                'body': ast.get_source_segment(file_contents, node),
                'functions': dict()
            }
            visited.add(node)
            # ast.walk(node) reaches every descendant, so any function
            # defined anywhere inside the class is filed under this class.
            for subnode in ast.walk(node):
                if isinstance(subnode, ast.FunctionDef) and subnode not in visited:
                    all_nodes_contents[node.name]['functions'][subnode.name] = get_function_details(subnode)
                    visited.add(subnode)

        elif isinstance(node, ast.FunctionDef) and node not in visited:
            all_nodes_contents[node.name] = get_function_details(node)
            visited.add(node)

    return all_nodes_contents


def parse_files_in_dir(dir_path):
    """Parse every ``*.py`` file found recursively under ``dir_path``.

    Returns a dict mapping each file path (as a string) to the result of
    ``parse_file`` for that file. Progress is shown with ``tqdm``.
    """
    modules = [str(path) for path in pathlib.Path(dir_path).glob('**/*.py')]
    all_module_imports = dict()
    progress = tqdm(enumerate(modules), total=len(modules), desc='Parsing files')
    for _, module in progress:
        all_module_imports[module] = parse_file(module)

    return all_module_imports

def get_init_module_paths(repository):
    """List package paths derived from the ``__init__.py`` files of a repo.

    Each ``__init__.py`` found recursively under ``repository`` is returned
    as its path string with the ``/__init__`` part removed, i.e.
    ``pkg/sub/__init__.py`` becomes ``pkg/sub.py``.
    """
    modules = []
    for init_path in pathlib.Path(repository).glob('**/__init__.py'):
        modules.append(str(init_path).replace('/__init__', ''))
    return modules


def get_module_to_file_imports(module_imports_data, init_files):
    """Resolve imported module names to the repository files providing them.

    Every import name from ``get_all_imports`` is split into
    ``<module>.<node>``. A parsed file is associated with ``<module>`` when
    its dotted name ends with ``<module>`` and it defines ``<node>``; an
    entry of ``init_files`` is associated when its dotted name ends with
    ``<module>``. Imports without a module part are skipped.

    Side effect: the resulting mapping is also written to
    ``module_to_file.json`` in the current working directory.

    Returns a dict mapping module name -> list of file paths (only modules
    with at least one match are present).
    """
    # NOTE: replace('.py', '') removes every ".py" occurrence in the dotted
    # name, not just the trailing extension.
    dotted_files = [
        (path, path.replace('/', '.').replace('.py', ''))
        for path in module_imports_data.keys()
    ]
    dotted_inits = [
        (path, path.replace('/', '.').replace('.py', ''))
        for path in init_files
    ]

    matches = defaultdict(set)
    all_import_names = get_all_imports(module_imports_data)
    for import_name in tqdm(all_import_names, desc='Creating module to file map'):
        module_import, _, node = import_name.rpartition('.')

        if module_import == '':
            continue

        for path, dotted in dotted_files:
            if dotted.endswith(module_import) and node in module_imports_data[path]:
                matches[module_import].add(path)

        for path, dotted in dotted_inits:
            if dotted.endswith(module_import):
                matches[module_import].add(path)

    module_to_file = {module: list(files) for module, files in matches.items()}
    with open('module_to_file.json', 'w') as f:
        json.dump(module_to_file, f, indent=4)

    return module_to_file


def add_files_to_module_imports(module_imports_data, module_to_file):
    """Attach resolved file lists to the imports of every top-level node.

    Mutates ``module_imports_data`` in place: each node's ``'imports'`` is
    replaced by a dict mapping the import name to
    ``module_to_file[<module part of the import>]``, or to an empty list
    when that module is unknown. Imports nested under a class's
    ``'functions'`` are not touched. Returns ``None``.
    """
    for file_contents in module_imports_data.values():
        for node_content in file_contents.values():
            node_content['imports'] = {
                # rpartition('.')[0] is everything before the last dot
                # ('' when the import has no dot at all).
                import_name: module_to_file.get(import_name.rpartition('.')[0], [])
                for import_name in node_content['imports']
            }


def add_class_node_to_graph(graph, full_class_name, class_content):
    """Add a ``type='class'`` node for ``full_class_name`` if it is missing.

    The node receives ``docstring`` and ``body`` from ``class_content`` and
    ``name`` set to the last dotted component of ``full_class_name``.
    An already existing node is left untouched.
    """
    if graph.has_node(full_class_name):
        return

    graph.add_node(full_class_name, type='class')
    attributes = graph.nodes[full_class_name]
    attributes['docstring'] = class_content['docstring']
    attributes['body'] = class_content['body']
    attributes['name'] = full_class_name.split('.')[-1]
        

def add_function_node_to_graph(graph, full_function_name, function_content):
    """Add a ``type='function'`` node for ``full_function_name`` if missing.

    The node receives ``docstring`` and ``body`` from ``function_content``
    and ``name`` set to the last dotted component of the full name.
    An already existing node is left untouched.
    """
    if graph.has_node(full_function_name):
        return

    graph.add_node(full_function_name, type='function')
    attributes = graph.nodes[full_function_name]
    attributes['docstring'] = function_content['docstring']
    attributes['body'] = function_content['body']
    attributes['name'] = full_function_name.split('.')[-1]


def clean_file_name(name):
    """Turn a '/'-separated file path into the dotted name used for graph nodes.

    Slashes become dots, ``.py`` is removed and, for package initialisers,
    the ``.__init__`` part is dropped (``pkg/__init__.py`` -> ``pkg``).
    """
    # NOTE: str.replace removes every ".py" occurrence, not only the
    # trailing extension; other code in this file applies the same
    # transformation, so the names stay mutually consistent.
    dotted = name.replace('/', '.').replace('.py', '')
    if dotted.endswith('.__init__'):
        dotted = dotted.replace('.__init__', '')
    return dotted


def create_graph_module_nodes(nxg, file_name):
    """Create the chain of module nodes for a dotted module name.

    For ``a.b.c`` the nodes ``a``, ``a.b`` and ``a.b.c`` are added (each
    with ``name`` set to its last component and ``type='module'``) and
    linked parent -> child with ``type='module2module'`` edges. Nodes and
    edges that already exist are left untouched.

    A single-component name (e.g. ``setup``) now also gets its module node;
    previously nothing was created for it, so the node only appeared later,
    implicitly and without ``name``/``type`` attributes.
    """
    fp_units = file_name.split('.')

    parent = None
    for depth, unit in enumerate(fp_units, start=1):
        current = '.'.join(fp_units[:depth])

        if not nxg.has_node(current):
            nxg.add_node(current, name=unit, type='module')

        # Guarded so an existing edge keeps whatever attributes it has.
        if parent is not None and not nxg.has_edge(parent, current):
            nxg.add_edge(parent, current, type='module2module')

        parent = current


def create_graph_nodes(nxg, module_imports_data):
    """Populate ``nxg`` with module, class and function nodes.

    For every parsed file the module chain is created, then one node per
    top-level class/function and one node per class method. Edges added:
    ``module2class`` / ``module2function`` from the file's module node to
    each top-level node, and ``class2function`` from a class to each of its
    functions.
    """
    for file_name, file_contents in tqdm(module_imports_data.items(), desc='Creating graph nodes'):
        module_name = clean_file_name(file_name)
        create_graph_module_nodes(nxg, module_name)

        for node_name, node_content in file_contents.items():
            qualified_name = f"{module_name}.{node_name}"
            node_type = node_content['type']

            if node_type == 'class':
                add_class_node_to_graph(nxg, qualified_name, node_content)

                for method_name, method_content in node_content['functions'].items():
                    qualified_method = f"{qualified_name}.{method_name}"
                    add_function_node_to_graph(nxg, qualified_method, method_content)
                    nxg.add_edge(qualified_name, qualified_method, type='class2function')

            elif node_type == 'function':
                add_function_node_to_graph(nxg, qualified_name, node_content)

            # Link the owning module to the node whatever its type is.
            nxg.add_edge(module_name, qualified_name, type=f'module2{node_type}')
            
        

def create_graph_edges(nxg, module_imports_data, module_to_file):
    """Add import edges from each top-level node to the modules it imports.

    For every import of a node whose module part is a key of
    ``module_to_file``, an edge of type ``module2module`` is added from the
    node to the dotted name of each file listed for that module.

    Raises ``AssertionError`` when either end of such an edge is not
    already a node of ``nxg``.
    """
    for file_name, file_contents in tqdm(module_imports_data.items(), desc='Creating graph edges'):
        f_name = clean_file_name(file_name)

        for node_name, node_content in file_contents.items():
            node_module = f"{f_name}.{node_name}"

            for import_name in node_content['imports']:
                # Everything before the last dot; '' for dot-less imports.
                module_import = import_name.rpartition('.')[0]
                if module_import == '' or module_import not in module_to_file:
                    continue

                for imported_file in module_to_file[module_import]:
                    module_file_name = imported_file.replace('/', '.').replace('.py', '')
                    assert nxg.has_node(node_module), f"Node not found: {node_module}, {file_name}"
                    assert nxg.has_node(module_file_name), f"Module not found: {module_file_name}"
                    nxg.add_edge(node_module, module_file_name, type='module2module')
                    
                    
def create_nxg(module_imports_data, module_to_file):
    """Build the directed module graph from the parsed repository data.

    Nodes are created first, then import edges, and finally every node
    reachable through an edge gets a ``parent`` attribute.
    Returns the populated ``nx.DiGraph``.
    """
    graph = nx.DiGraph()
    create_graph_nodes(graph, module_imports_data)
    create_graph_edges(graph, module_imports_data, module_to_file)
    add_parent_attribute(graph)
    return graph


def add_parent_attribute(nxg):
    """Set a ``parent`` attribute on every node that has an incoming edge.

    The graph is traversed depth-first from each not-yet-visited node (in
    ``nxg.nodes`` order). Each time an edge ``node -> neighbour`` is
    examined, ``neighbour['parent']`` is set to ``node``; a node with
    several predecessors therefore ends up with the one examined last.
    Nodes without incoming edges get no ``parent`` attribute.

    The traversal uses an explicit stack rather than recursion, so long
    chains of nodes cannot hit Python's recursion limit. Edges are examined
    in exactly the order of the recursive formulation.
    """
    visited = set()

    for root in nxg.nodes:
        if root in visited:
            continue

        visited.add(root)
        # Each frame is (node, iterator over its remaining neighbours),
        # mirroring one recursive call that can be resumed later.
        stack = [(root, iter(nxg.neighbors(root)))]
        while stack:
            node, remaining = stack[-1]
            for neighbour in remaining:
                nxg.nodes[neighbour]['parent'] = node
                if neighbour not in visited:
                    visited.add(neighbour)
                    stack.append((neighbour, iter(nxg.neighbors(neighbour))))
                    break  # descend now; resume this node's neighbours later
            else:
                stack.pop()  # all neighbours of `node` have been examined

def create_module_graph(repository, f_name=modules_classes_file):
    """Parse ``repository`` and build, persist and return its module graph.

    Steps: parse every Python file, resolve imports to files, dump the
    parse results to ``f_name`` as JSON, build the graph and pickle it to
    ``{repository}_module_graph.gpickle``. ``get_module_to_file_imports``
    additionally writes ``module_to_file.json``.

    Args:
        repository: root directory of the code base to analyse.
        f_name: path of the JSON file receiving the parse results.

    Returns:
        The graph produced by ``create_nxg``.
    """
    # The repository is always re-parsed. An earlier version first loaded
    # ``f_name`` when it existed, but that value was overwritten right away,
    # so the load had no effect and has been dropped.
    module_imports_data = parse_files_in_dir(repository)
    init_files = get_init_module_paths(repository)
    module_to_file = get_module_to_file_imports(module_imports_data, init_files)
    add_files_to_module_imports(module_imports_data, module_to_file)

    # Make sure the target directory exists before writing the JSON dump.
    output_dir = os.path.dirname(f_name)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(f_name, 'w') as f:
        json.dump(module_imports_data, f, indent=4)

    nxg = create_nxg(module_imports_data, module_to_file)
    # write_nxg takes (graph, path); the graph argument used to be missing,
    # which raised a TypeError.
    write_nxg(nxg, f'{repository}_module_graph.gpickle')

    return nxg

def load_nxg(repository):
    """Load the pickled graph stored at ``{repository}_module_graph.gpickle``."""
    graph_path = f'{repository}_module_graph.gpickle'
    # NOTE(review): unpickling can execute arbitrary code; only load files
    # that were produced by a trusted run.
    with open(graph_path, 'rb') as graph_file:
        return pickle.load(graph_file)

def write_nxg(nxg, f_name):
    """Pickle ``nxg`` to the file ``f_name`` using the highest protocol."""
    with open(f_name, 'wb') as graph_file:
        pickle.dump(nxg, graph_file, protocol=pickle.HIGHEST_PROTOCOL)

## Create a Module Graph

In [None]:
# Root directory of the code base to analyse; also the prefix of the
# pickled graph's file name.
repository = 'llama_index_local'
# Parse the repository, build the graph and write it to disk.
nxg = create_module_graph(repository)

In [7]:
# Reload the previously pickled graph for this repository from disk.
nxg = load_nxg(repository)

In [8]:
# Size of the loaded graph as (node count, edge count).
nxg.number_of_nodes(), nxg.number_of_edges()

(20604, 32824)

In [10]:
from collections import deque

def find_nodes_within_distance(graph, start_node, distance):
    """Return the nodes reachable from ``start_node`` in at most ``distance`` hops.

    A breadth-first search follows ``graph.neighbors`` and records the
    minimal hop count of every node it reaches.

    Args:
        graph: object exposing ``neighbors(node)``.
        start_node: node the search starts from (reported with distance 0).
        distance: maximum number of hops; a negative value yields ``[]``.

    Returns:
        A list of ``(node, hops)`` tuples sorted by ascending hop count.
    """
    if distance < 0:
        return []

    # A node is recorded when it is first discovered, not when it is
    # dequeued. BFS discovers nodes in non-decreasing distance, so the first
    # value is the shortest one and must not be overwritten by a later,
    # longer route to the same node.
    hops = {start_node: 0}
    queue = deque([start_node])

    while queue:
        current = queue.popleft()
        next_depth = hops[current] + 1
        if next_depth > distance:
            continue
        for neighbour in graph.neighbors(current):
            if neighbour not in hops:
                hops[neighbour] = next_depth
                queue.append(neighbour)

    return sorted(hops.items(), key=lambda item: item[1])

In [None]:
# Manual check: list every node within `distance` hops of one module node.
test_node = 'llama_index_local.llama-index-core.llama_index.core.selectors'
distance = 2
nodes = find_nodes_within_distance(nxg, test_node, distance)
print(f"Nodes within distance {distance} from {test_node}: {len(nodes)}")
# Each entry is a (node, hops) tuple.
for node in nodes:
    print(node)

#### Creating LLM-Generated Docstrings

#### Configure LLMs

In [81]:
import openai
import streamlit as st
import requests
from prompt_templates.prompts import SYSTEM_PROMPT


# API credentials are read from Streamlit's secrets store.
openai_apikey = st.secrets["OPENAI_API_KEY"]
hf_api_key = st.secrets['HUGGINGFACEHUB_API_TOKEN']
any_scale_api_key = st.secrets['ANY_SCALE_API_TOKEN']


def get_llm_response(client, model_name, prompt, system_prompt):
    """Send a system + user message pair to a chat model and return its reply.

    ``client`` must expose ``chat.completions.create``. The request uses a
    temperature of 0.7. If the reply text cannot be read from the returned
    completion, the error and the prompt are printed and the string
    "Error while generating summary" is returned. Errors raised by the
    request itself are not caught here.
    """
    messages = [
        {"role": "system", "content": f"{system_prompt}"},
        {"role": "user", "content": prompt},
    ]
    chat_completion = client.chat.completions.create(
        model=f"{model_name}",
        messages=messages,
        temperature=0.7,
    )

    try:
        return chat_completion.choices[0].message.content
    except Exception as e:
        print(e)
        print(prompt)
        return "Error while generating summary"


def get_any_scale_response(mode_name, user_prompt, system_prompt):
    """Query model ``mode_name`` through the Anyscale endpoint.

    Builds an ``openai.OpenAI`` client pointed at the Anyscale base URL and
    delegates to ``get_llm_response``.
    """
    anyscale_client = openai.OpenAI(
        base_url="https://api.endpoints.anyscale.com/v1",
        api_key=f"{any_scale_api_key}",
    )
    return get_llm_response(anyscale_client, mode_name, user_prompt, system_prompt)


def get_gpt_response(model_name, user_prompt, system_prompt):
    """Query an OpenAI chat model and return its reply text.

    Builds an ``openai.OpenAI`` client with the configured API key and
    delegates to ``get_llm_response``.
    """
    openai_client = openai.OpenAI(api_key=f"{openai_apikey}")
    return get_llm_response(openai_client, model_name, user_prompt, system_prompt)


def get_hf_response(model_name, prompt):
    """POST ``prompt`` to the Hugging Face inference endpoint of ``model_name``.

    The prompt is sent as the JSON request body with a bearer-token
    header; the decoded JSON response is returned as-is.
    """
    endpoint = f"https://api-inference.huggingface.co/models/{model_name}"
    auth_headers = {"Authorization": f"Bearer {hf_api_key}"}
    reply = requests.post(endpoint, headers=auth_headers, json=prompt)
    return reply.json()


class LLM():
    """Thin dispatcher that routes a prompt to the configured LLM backend.

    ``llm_config`` must provide ``'model_id'`` (model name passed to the
    backend) and ``'type'`` (one of ``'openai'``, ``'hf'``, ``'anyscale'``).
    """

    def __init__(self, llm_config):
        self.model_name = llm_config['model_id']
        self.model_type = llm_config['type']

    def get_response(self, prompt, system_prompt=SYSTEM_PROMPT):
        """Return the backend's response to ``prompt``.

        Raises:
            NotImplementedError: if the configured model type is unknown.
        """
        if self.model_type == 'openai':
            summary = get_gpt_response(self.model_name, prompt, system_prompt)
        elif self.model_type == 'hf':
            # get_hf_response accepts (model_name, prompt) only; passing the
            # system prompt as a third argument raised a TypeError, so it is
            # not forwarded to this backend.
            summary = get_hf_response(self.model_name, prompt)
        elif self.model_type == 'anyscale':
            summary = get_any_scale_response(self.model_name, prompt, system_prompt)
        else:
            raise NotImplementedError(f"Unsupported model type: {self.model_type}")
        return summary

#### Generate Summaries

In [85]:
import os
from prompt_templates.prompts import (
    FUNC_SUMMARIZATION_PROMPT, 
    COMBINE_FUNCTION_SUMMARIZATION_PROMPT,
    CLASS_SUMMARIZATION_PROMPT,
    COMBINE_CLASS_SUMMARIZATION_PROMPT,
    COMBINE_MODULE_SUMMARIZATION_PROMPT
)
from langchain_text_splitters import (
    Language,
    RecursiveCharacterTextSplitter,
)

# Inputs with at least this many whitespace-separated words are summarised
# chunk by chunk (see get_prompt_response) instead of in a single request.
TOKEN_LIMIT = 50000
# Chunking parameters handed to the text splitter below.
CHUNK_SIZE = 20000
CHUNK_OVERLAP = 2000
# Splitter configured for Python source, used by summarize_docs.
python_splitter = RecursiveCharacterTextSplitter.from_language(
    language=Language.PYTHON, chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP)

# Directory in which summarize_node caches one text file per node.
summaries_dir = 'summaries'

def summarize_docs(llm, document, node_prompt, combine_prompt):
    """Summarise a long document chunk by chunk, then merge the results.

    ``document`` is split with ``python_splitter``; each chunk is summarised
    with ``node_prompt`` and the newline-joined partial summaries are sent
    to the model once more with ``combine_prompt``. Returns that final
    response.
    """
    chunks = python_splitter.create_documents([document])
    partial_summaries = [
        llm.get_response(f"{node_prompt}\n\n{chunk.page_content}")
        for chunk in chunks
    ]

    merged = "\n".join(partial_summaries)
    return llm.get_response(f"{combine_prompt}\n\n Summaries: \n\n{merged}")


def get_prompt_response(llm, prompt, node_prompt, combine_prompt):
    """Summarise ``prompt``, chunking it first when it is too long.

    Length is measured as the number of whitespace-separated words. At or
    above ``TOKEN_LIMIT`` the text goes through ``summarize_docs``;
    otherwise it is sent in one request prefixed with ``node_prompt``.
    """
    if len(prompt.split()) >= TOKEN_LIMIT:
        return summarize_docs(llm, prompt, node_prompt, combine_prompt)
    return llm.get_response(f"{node_prompt}\n\n{prompt}")


def summarize_code_node(llm, node, nxg):
    """Summarise the source of a class or function node.

    Function nodes use the function prompts, every other node type the
    class prompts. The summary is stored in the node's ``summary``
    attribute and returned.
    """
    attributes = nxg.nodes[node]
    if attributes['type'] == 'function':
        node_prompt = FUNC_SUMMARIZATION_PROMPT
        combine_prompt = COMBINE_FUNCTION_SUMMARIZATION_PROMPT
    else:
        node_prompt = CLASS_SUMMARIZATION_PROMPT
        combine_prompt = COMBINE_CLASS_SUMMARIZATION_PROMPT

    summary = get_prompt_response(llm, attributes['body'], node_prompt, combine_prompt)
    nxg.nodes[node]['summary'] = summary
    return summary


def summarize_module_node(llm, node, nxg):
    """Summarise a module node from the summaries of its neighbours.

    Every neighbour of ``node`` is summarised through ``summarize_node``;
    the ``"<neighbour>:\\n<summary>"`` entries are joined with newlines and
    summarised with the module prompt. The result is stored in the node's
    ``summary`` attribute and returned.
    """
    child_summaries = [
        (child, summarize_node(llm, child, nxg))
        for child in nxg.neighbors(node)
    ]

    combined_summary = "\n".join(f"{child}:\n{text}" for child, text in child_summaries)
    summary = get_prompt_response(
        llm,
        combined_summary,
        COMBINE_MODULE_SUMMARIZATION_PROMPT,
        COMBINE_MODULE_SUMMARIZATION_PROMPT,
    )

    nxg.nodes[node]['summary'] = summary
    return summary


def summarize_node(llm, node, nxg):
    """Return the summary of ``node``, using an on-disk cache.

    The summary is read from ``{summaries_dir}/{node}.txt`` when that file
    exists. Otherwise it is generated (module nodes via
    ``summarize_module_node``, all others via ``summarize_code_node``),
    written to that file and returned.
    """
    os.makedirs(summaries_dir, exist_ok=True)
    summary_path = f"{summaries_dir}/{node}.txt"
    if os.path.exists(summary_path):
        with open(summary_path, 'r') as f:
            return f.read()

    node_type = nxg.nodes[node]['type']
    if node_type == 'module':
        summary = summarize_module_node(llm, node, nxg)
    else:
        summary = summarize_code_node(llm, node, nxg)

    with open(summary_path, 'w') as f:
        f.write(summary)

    # Without this return a freshly generated summary came back as None, so
    # summarize_module_node combined the text "None" for uncached children.
    return summary

In [75]:
import json
# Load the LLM configurations (keyed by a short model alias). The context
# manager closes the file, which the bare open() call never did.
with open('llms.json', 'r') as llms_file:
    llms = json.load(llms_file)

In [None]:
from tqdm.auto import tqdm
# Reload the graph and pick the model configuration to summarise with.
nxg = load_nxg(repository)

llm = LLM(llms['mistral8x7b'])
for node in tqdm(nxg.nodes, total=nxg.number_of_nodes()):
    summarize_node(llm, node, nxg)
    # NOTE(review): the loop stops after the first node - presumably a
    # smoke test; remove the break to summarise every node.
    break

In [None]:
# Sample function source kept as a plain string (it is not executed here).
# NOTE(review): presumably input for a later summarisation cell - not
# visible in this part of the notebook.
code = \
"""
def load_document(uploaded_files: List[UploadedFile]) -> List[Document]:
    # Read documents
    temp_dir = tempfile.TemporaryDirectory()
    for file in uploaded_files:
        temp_filepath = os.path.join(temp_dir.name, file.name)
        with open(temp_filepath, "wb") as f:
            f.write(file.getvalue())

    reader = SimpleDirectoryReader(input_dir=temp_dir.name)
    return reader.load_data()
"""