### Extract Code Structure

In [1]:
import ast
from collections import defaultdict
from typing import List, Dict, Tuple

def extract_code_structure(code: str, max_chars: int = 2000) -> Tuple[List[Tuple[str, str]], Dict[str, List[str]], Dict[str, List[str]]]:
    """Parse Python source and collect function chunks, a call graph and class members.

    Args:
        code: Python source text to analyse.
        max_chars: Functions whose source segment is longer than this many
            characters are left out of the returned chunks (they still
            contribute to the call graph and class hierarchy).

    Returns:
        A 3-tuple ``(code_chunks, call_graph, class_hierarchy)``:

        * ``code_chunks`` - list of ``(function_name, source_segment)`` pairs,
          one per ``def`` found anywhere in the tree (top level, methods,
          nested functions).
        * ``call_graph`` - maps a function name to the names it calls through a
          plain identifier (``foo()``); attribute calls such as ``obj.foo()``
          are not recorded. Repeated calls appear repeatedly.
        * ``class_hierarchy`` - maps a class name to the names of the ``def``s
          directly in its body.

        If ``code`` cannot be parsed, all three containers are empty.
    """
    code_chunks: List[Tuple[str, str]] = []
    call_graph = defaultdict(list)
    class_hierarchy = defaultdict(list)

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # Unparseable input is expected in a best-effort pipeline; report
        # "nothing found" but let unrelated errors (and Ctrl-C) propagate.
        return code_chunks, {}, {}

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            func_code = ast.get_source_segment(code, node)
            if func_code and len(func_code) <= max_chars:
                code_chunks.append((node.name, func_code))
            # Walking the function subtree also visits nested defs, so their
            # calls are attributed to the enclosing function as well.
            for sub in ast.walk(node):
                if isinstance(sub, ast.Call) and isinstance(sub.func, ast.Name):
                    call_graph[node.name].append(sub.func.id)

        elif isinstance(node, ast.ClassDef):
            # Only record membership here. The method's source chunk is added
            # when ast.walk reaches the method's own FunctionDef node, so
            # appending it here too would emit every method twice.
            for item in node.body:
                if isinstance(item, ast.FunctionDef):
                    class_hierarchy[node.name].append(item.name)

    return code_chunks, dict(call_graph), dict(class_hierarchy)

In [2]:
import ast

def extract_functions_from_code(code: str):
    """Return a mapping of function name to source text for every ``def`` in ``code``.

    Every ``ast.FunctionDef`` in the tree is included (top-level functions,
    methods and nested functions). When two definitions share a name, the one
    visited later by ``ast.walk`` replaces the earlier one.

    Args:
        code: Python source text.

    Returns:
        ``{name: source_segment}``; an empty dict when ``code`` cannot be parsed.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # Best effort: unparseable source simply yields no functions. Other
        # exceptions (and KeyboardInterrupt) are no longer swallowed.
        return {}

    return {
        node.name: ast.get_source_segment(code, node)
        for node in ast.walk(tree)
        if isinstance(node, ast.FunctionDef)
    }

def extract_call_graph(code: str):
    """Map each function in ``code`` to the distinct names it calls.

    Only calls through a bare identifier (``foo()``) are recorded; attribute
    calls such as ``obj.foo()`` are ignored. Callees are de-duplicated and kept
    in the order ``ast.walk`` first encounters them, so the result is stable
    across runs (a ``set`` round-trip would make the order depend on string
    hashing).

    Args:
        code: Python source text.

    Returns:
        ``{function_name: [callee_name, ...]}``. Functions that call nothing map
        to an empty list. An empty dict is returned when ``code`` cannot be
        parsed.
    """
    call_graph = {}
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # Best effort: unparseable source has no call graph.
        return call_graph

    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            callees = [
                child.func.id
                for child in ast.walk(node)
                if isinstance(child, ast.Call) and isinstance(child.func, ast.Name)
            ]
            # dict.fromkeys removes duplicates while preserving first-seen order.
            call_graph[node.name] = list(dict.fromkeys(callees))
    return call_graph

def extract_class_hierarchy(code: str):
    """Map each class in ``code`` to the names of the methods defined in its body.

    Only ``def`` statements directly inside the class body are listed; other
    members (assignments, nested classes) are skipped. If two classes share a
    name, the one visited later by ``ast.walk`` wins.

    Args:
        code: Python source text.

    Returns:
        ``{class_name: [method_name, ...]}``; an empty dict when ``code`` cannot
        be parsed.
    """
    class_map = {}
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError, RecursionError):
        # Best effort: unparseable source has no classes to report.
        return class_map

    for node in ast.walk(tree):
        if isinstance(node, ast.ClassDef):
            class_map[node.name] = [
                member.name for member in node.body if isinstance(member, ast.FunctionDef)
            ]
    return class_map


### Summarization Utils

In [3]:
from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import torch

# Setup device
# Device index handed to both pipelines below: 0 when CUDA is available,
# otherwise -1.
# NOTE(review): presumably 0 = first GPU and -1 = CPU per the transformers
# pipeline convention (the recorded output reports "cpu") — confirm.
device = 0 if torch.cuda.is_available() else -1

# ========== FUNCTION-LEVEL MODEL: CodeT5 ========== #
# Tokenizer and seq2seq weights are loaded from the same checkpoint and wrapped
# in a text2text-generation pipeline that summarizes individual functions.
tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base-multi-sum")
model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base-multi-sum")

func_summarizer = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    device=device,
    batch_size=8,
)

# ========== EMBEDDING MODEL: SBERT ========== #
# Sentence-embedding model used by rank_by_graph_and_embedding and
# cosine_similarity. No device argument is passed here, unlike the pipelines.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# ========== FILE/REPO-LEVEL MODEL: Longformer (LED) ========== #
# Summarization pipeline shared by the file-level and repo-level summarizers.
# The max_length/min_length given here are construction-time settings; the
# callers in this notebook pass their own max_length/min_length on every call.
# NOTE(review): truncation=True is presumably meant to clip over-long inputs to
# the model's limit — confirm it is honoured when set at construction time.
file_summarizer = pipeline(
    "summarization",
    model="allenai/led-base-16384",
    tokenizer="allenai/led-base-16384",
    device=device,
    truncation=True,
    max_length=128,
    min_length=64
)

# ========== GRAPH-AWARE PROMPT BUILDER ========== #
def format_graph_context(func_summaries, call_graph=None, class_hierarchy=None):
    """
    Build the plain-text prompt handed to the file-level summarizer.

    The prompt always opens with a fixed instruction line. Up to three optional
    sections follow, each emitted only when its argument is non-empty:

    - ``func_summaries`` ({name: summary}) -> one "- name: summary" bullet each.
    - ``call_graph`` ({caller: [callees]}) -> one bullet per caller that has at
      least one callee; callers with no callees are omitted.
    - ``class_hierarchy`` ({class: [methods]}) -> one bullet per class.

    Returns the sections joined by newlines as a single string.
    """
    prompt = ["You are summarizing a Python module."]

    if func_summaries:
        prompt.append("Function Summaries:")
        prompt.extend(
            f"- {fn_name}: {fn_summary}"
            for fn_name, fn_summary in func_summaries.items()
        )

    if call_graph:
        # The header is emitted even when every caller turns out to have no callees.
        prompt.append("\nCall Graph:")
        prompt.extend(
            f"- {source} → {', '.join(targets)}"
            for source, targets in call_graph.items()
            if targets
        )

    if class_hierarchy:
        prompt.append("\nClass Hierarchy:")
        prompt.extend(
            f"- {class_name}: [{', '.join(members)}]"
            for class_name, members in class_hierarchy.items()
        )

    return "\n".join(prompt)


Device set to use cpu
Device set to use cpu


### Graph Ranking

In [4]:
import torch
from sentence_transformers import util

def rank_by_graph_and_embedding(summaries: Dict[str, str], call_graph: Dict[str, List[str]], top_k: int = 5):
    """Pick the ``top_k`` summaries with the highest combined graph/embedding score.

    The score of a function is ``0.5 * out_degree + 0.5 * embedding_norm`` where
    ``out_degree`` is the number of callees listed for it in ``call_graph``
    (0 when the function is absent) and ``embedding_norm`` is the L2 norm of the
    ``embed_model`` encoding of its summary.

    Args:
        summaries: ``{function_name: summary_text}``.
        call_graph: ``{function_name: [callee_name, ...]}``.
        top_k: Maximum number of summaries to return.

    Returns:
        Summary texts ordered from highest to lowest score. An empty list is
        returned for empty ``summaries`` without invoking the encoder.
    """
    # Nothing to rank: return early instead of encoding an empty batch, whose
    # result is not guaranteed to be 2-D for the norm(dim=1) call below.
    if not summaries:
        return []

    # Graph degree as rough importance
    graph_score = {name: len(callees) for name, callees in call_graph.items()}

    # Embedding richness
    # NOTE(review): if embed_model L2-normalises its outputs, every norm is the
    # same and the ranking reduces to graph degree alone — confirm.
    names = list(summaries.keys())
    texts = list(summaries.values())
    embeddings = embed_model.encode(texts, convert_to_tensor=True)
    norms = embeddings.norm(dim=1).cpu().tolist()

    # Combine scores
    combined = [
        (name, summaries[name], 0.5 * graph_score.get(name, 0) + 0.5 * norm)
        for name, norm in zip(names, norms)
    ]

    # Select top-k (stable sort: ties keep their insertion order)
    combined.sort(key=lambda entry: entry[2], reverse=True)
    return [summary for _, summary, _ in combined[:top_k]]


### Summarize File

In [5]:
# def summarize_file_with_graph(code: str, top_k: int = 5):
#     chunks, call_graph, _ = extract_code_structure(code)
#     if not chunks:
#         return "No valid chunks"

#     names, funcs = zip(*chunks)
#     results = func_summarizer(list(funcs), max_length=64, truncation=True, do_sample=False)
#     summaries = {name: res["generated_text"].strip() for name, res in zip(names, results)}

#     top_summaries = rank_by_graph_and_embedding(summaries, call_graph, top_k)
#     return file_summarizer(" ".join(top_summaries), max_length=128, min_length=64, do_sample=False)[0]["summary_text"]

def summarize_file_with_graph(code_text: str, top_k: int = 5):
    """
    Produce a natural-language summary for one source file.

    Steps:
    - Extract every function with extract_functions_from_code; return the
      string "No functions found." when there are none.
    - Summarize each function with the CodeT5 pipeline (func_summarizer).
    - Keep the summaries of the top_k functions with the longest source text
      (source length is used as a proxy for importance).
    - Combine those summaries with the file's call graph and class hierarchy
      into a structured prompt via format_graph_context.
    - Summarize that prompt with the LED pipeline (file_summarizer) and return
      the generated text.
    """
    functions = extract_functions_from_code(code_text)
    if not functions:
        return "No functions found."

    func_names = list(functions)
    func_sources = [functions[name] for name in func_names]

    # One generated summary per function, in the same order as func_names.
    generated = func_summarizer(func_sources, max_length=64, do_sample=False)
    summary_by_name = {}
    for position, name in enumerate(func_names):
        summary_by_name[name] = generated[position]["generated_text"].strip()

    # Longest source first; the sort is stable, so ties keep extraction order.
    longest_first = sorted(
        summary_by_name,
        key=lambda name: len(functions[name]),
        reverse=True,
    )
    selected = {name: summary_by_name[name] for name in longest_first[:top_k]}

    # The call graph and class hierarchy cover the whole file, not only the
    # selected functions.
    prompt = format_graph_context(
        selected,
        extract_call_graph(code_text),
        extract_class_hierarchy(code_text),
    )

    led_output = file_summarizer(
        prompt,
        max_length=128,
        min_length=64,
        no_repeat_ngram_size=3,
        do_sample=False,
    )
    return led_output[0]["summary_text"]



### Summarize Repo

In [10]:
# def summarize_repo_with_graph(file_contents: Dict[str, str], top_files: int = 5, top_k_funcs: int = 5):
#     file_summaries = []
#     for file_path, code in list(file_contents.items())[:top_files]:
#         try:
#             summary = summarize_file_with_graph(code, top_k=top_k_funcs)
#             file_summaries.append(summary)
#         except Exception:
#             continue

#     if not file_summaries:
#         return "No summary generated"

#     repo_input = " ".join(" ".join(s.split()[:1024]) for s in file_summaries)
#     return file_summarizer(repo_input, max_length=256, min_length=100, do_sample=False)[0]["summary_text"]


def summarize_repo_with_graph(file_dict: dict, top_files=5, top_k_funcs=5):
    """
    Produce a repository-level summary from per-file summaries.

    - Takes the first ``top_files`` entries of ``file_dict`` ({path: source}).
    - Summarizes each with summarize_file_with_graph (``top_k_funcs`` is passed
      through as its ``top_k``); a file whose summarization raises is reported
      on stdout and skipped.
    - Joins the collected summaries with blank lines and runs the LED pipeline
      (file_summarizer) over the result.

    Returns the generated summary text, or "No valid summaries found." when no
    file produced a summary.
    """
    collected = []
    selected_files = list(file_dict.items())[:top_files]

    for file_path, code_text in selected_files:
        try:
            collected.append(summarize_file_with_graph(code_text, top_k=top_k_funcs))
        except Exception as e:
            print(f"Skipped file {file_path} due to: {e}")

    if len(collected) == 0:
        return "No valid summaries found."

    led_output = file_summarizer(
        "\n\n".join(collected),
        max_length=256,
        min_length=100,
        no_repeat_ngram_size=3,
        do_sample=False,
    )
    return led_output[0]["summary_text"]


### Evaluate Unsupervised

In [7]:
import evaluate
from sentence_transformers import util
from sentence_transformers import SentenceTransformer

# Sentence-embedding model used by cosine_similarity below. This rebinds the
# module-level name embed_model, which the Summarization Utils cell already
# assigned from the same "all-MiniLM-L6-v2" checkpoint.
embed_model = SentenceTransformer("all-MiniLM-L6-v2")

# Load HF metrics once
# Both metric objects are module-level and reused by every evaluate_* helper.
rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")

def cosine_similarity(preds, refs, batch_size=2):
    """Mean cosine similarity between aligned prediction/reference texts.

    Both lists are encoded with ``embed_model``; the diagonal of the resulting
    ``util.cos_sim`` matrix (entry i compares ``preds[i]`` with ``refs[i]``) is
    averaged and returned as a Python float.

    Args:
        preds: Predicted texts.
        refs: Reference texts, aligned by index with ``preds``.
        batch_size: Batch size forwarded to ``embed_model.encode``.
    """
    encode_options = {"convert_to_tensor": True, "batch_size": batch_size}
    pred_vectors = embed_model.encode(preds, **encode_options)
    ref_vectors = embed_model.encode(refs, **encode_options)

    similarity_matrix = util.cos_sim(pred_vectors, ref_vectors)
    matched_pairs = similarity_matrix.diag()
    return matched_pairs.mean().item()

# ---------------- FUNCTION LEVEL ---------------- #
def evaluate_function_level(preds: list, refs: list):
    """Score function-level summaries with ROUGE and BERTScore and print a report.

    Args:
        preds: Generated summaries.
        refs: Reference summaries, aligned by index with ``preds``.

    Returns:
        ``(rouge_scores, bert_f1)`` where ``rouge_scores`` is the result of the
        module-level ``rouge`` metric and ``bert_f1`` is the mean of the
        per-example BERTScore F1 values.
    """
    rouge_scores = rouge.compute(predictions=preds, references=refs)

    bert_result = bertscore.compute(predictions=preds, references=refs, lang="en")
    f1_values = bert_result["f1"]
    bert_f1 = sum(f1_values) / len(f1_values)

    reported_metrics = ("rouge1", "rouge2", "rougeL", "rougeLsum")
    print("\nFunction-Level ROUGE:")
    for metric in reported_metrics:
        print(f"{metric}: {rouge_scores[metric]:.4f}")

    print(f"Function-Level BERTScore: {bert_f1:.4f}")
    return rouge_scores, bert_f1

# ---------------- FILE/REPO LEVEL ---------------- #
def evaluate_unsupervised_level(pred: str, ref: str, label: str = "Repo"):
    """Compare one generated summary with one reference text and print the scores.

    Args:
        pred: Generated summary.
        ref: Text to compare against.
        label: Prefix used in the printed report (e.g. "File" or "Repo").

    Returns:
        ``(bert, cos)``: the BERTScore F1 of the single pair and the cosine
        similarity computed by ``cosine_similarity``.
    """
    bert_result = bertscore.compute(predictions=[pred], references=[ref], lang="en")
    bert = bert_result["f1"][0]
    cos = cosine_similarity([pred], [ref])

    header = f"\n{label}-Level Evaluation:"
    print(header)
    print(f"BERTScore: {bert:.4f}")
    print(f"Cosine Similarity: {cos:.4f}")

    return bert, cos

In [8]:
# from datasets import load_dataset

# dataset = load_dataset("code_search_net", "python", split="train[:1%]")
# dataset = dataset.filter(lambda x: x["func_code_string"])

# # Build repo → file → code
# from collections import defaultdict
# repo_map = defaultdict(lambda: defaultdict(str))
# for item in dataset:
#     repo = item["repository_name"]
#     path = item["func_path_in_repository"]
#     repo_map[repo][path] += "\n" + item["func_code_string"]

# # Pick top repo
# repo_name = list(repo_map.keys())[0]
# files = repo_map[repo_name]
# print(f"Summarizing {repo_name}...")

# # Run summarization
# summary = summarize_repo_with_graph(files, top_files=5, top_k_funcs=5)
# raw_code = "\n".join(files.values())[:5000]

# # Evaluate
# bert, cos = evaluate_summary(summary, raw_code)
# print("\nREPO SUMMARY:\n", summary)
# print(f"\nBERTScore: {bert:.4f} | Cosine Similarity: {cos:.4f}")


### Main Pipeline

In [11]:
from datasets import load_dataset
from collections import defaultdict

# Load and group data
# NOTE(review): "train[:1%]" presumably selects the first 1% of the Python
# training split — confirm against the datasets slicing syntax.
dataset = load_dataset("code_search_net", "python", split="train[:1%]")
# Keep only rows that have both non-empty code and a non-empty docstring; the
# docstring is the reference for the function-level evaluation below.
dataset = dataset.filter(lambda x: x["func_code_string"] and x["func_documentation_string"])

# Group repo → file → code
# repo_map[repository_name][file_path] is the list of function source strings
# recorded for that file, in dataset order.
repo_map = defaultdict(lambda: defaultdict(list))
for item in dataset:
    repo_map[item["repository_name"]][item["func_path_in_repository"]].append(item["func_code_string"])

# ========== FUNCTION-LEVEL (First 100 functions) ==========

print("\n================ FUNCTION LEVEL EVALUATION ================\n")

# Select the evaluation rows once, and clamp the sample size so a filtered
# dataset with fewer than 100 rows does not make select() fail on an
# out-of-range index.
func_eval_rows = dataset.select(range(min(100, len(dataset))))

func_codes = [item["func_code_string"] for item in func_eval_rows]
func_refs = [item["func_documentation_string"] for item in func_eval_rows]

if func_codes:
    # One CodeT5 summary per function, compared against its docstring.
    func_preds = [out["generated_text"].strip() for out in func_summarizer(func_codes, max_length=64, do_sample=False)]
    evaluate_function_level(func_preds, func_refs)
else:
    # Nothing survived the filter: skip the metrics instead of dividing by zero.
    func_preds = []
    print("No functions available for function-level evaluation.")

# ========== FILE-LEVEL (Top 3 repos × 3 files each) ==========

print("\n================ FILE LEVEL EVALUATION ================\n")

file_bert_scores = []
file_cos_scores = []

# Up to 3 files from each of the first 3 repositories. Each file is rebuilt by
# joining its functions, and that joined code is also the reference text the
# generated summary is scored against.
for repo_name in list(repo_map.keys())[:3]:
    for file_path, func_list in list(repo_map[repo_name].items())[:3]:
        raw_code = "\n".join(func_list)
        try:
            file_summary = summarize_file_with_graph(raw_code, top_k=5)
            bert, cos = evaluate_unsupervised_level(file_summary, raw_code, label="File")
            file_bert_scores.append(bert)
            file_cos_scores.append(cos)
        except Exception as e:
            print(f"Skipped file {file_path} due to: {e}")

# Both lists grow together, so checking one is enough. Guarding the averages
# avoids a ZeroDivisionError when every file was skipped.
if file_bert_scores:
    print(f"\nAvg File BERTScore: {sum(file_bert_scores)/len(file_bert_scores):.4f}")
    print(f"Avg File Cosine Similarity: {sum(file_cos_scores)/len(file_cos_scores):.4f}")
else:
    print("\nNo files were evaluated; skipping file-level averages.")

# ========== REPO-LEVEL (Top 5 repos) ==========

print("\n================ REPO LEVEL EVALUATION ================\n")

repo_bert_scores = []
repo_cos_scores = []

for repo_name in list(repo_map.keys())[:5]:
    # Rebuild up to 5 files of the repository by joining their functions.
    file_contents = {
        path: "\n".join(funcs)
        for path, funcs in list(repo_map[repo_name].items())[:5]
    }

    try:
        repo_summary = summarize_repo_with_graph(file_contents, top_files=5, top_k_funcs=5)
        # The reference text is the first 5000 characters of the joined files.
        raw_repo_code = "\n".join(file_contents.values())[:5000]
        bert, cos = evaluate_unsupervised_level(repo_summary, raw_repo_code, label="Repo")
        repo_bert_scores.append(bert)
        repo_cos_scores.append(cos)
    except Exception as e:
        print(f"Skipped repo {repo_name} due to: {e}")

# Both lists grow together, so checking one is enough. Guarding the averages
# avoids a ZeroDivisionError when every repository was skipped.
if repo_bert_scores:
    print(f"\nAvg Repo BERTScore: {sum(repo_bert_scores)/len(repo_bert_scores):.4f}")
    print(f"Avg Repo Cosine Similarity: {sum(repo_cos_scores)/len(repo_cos_scores):.4f}")
else:
    print("\nNo repositories were evaluated; skipping repo-level averages.")





🔍 Function-Level ROUGE:
rouge1: 0.4376
rouge2: 0.3668
rougeL: 0.4244
rougeLsum: 0.4346
🔍 Function-Level BERTScore: 0.8576




Your max_length is set to 128, but your input_length is only 109. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=54)



🔍 File-Level Evaluation:
BERTScore: 0.7810
Cosine Similarity: 0.6971

🔍 File-Level Evaluation:
BERTScore: 0.7890
Cosine Similarity: 0.3769

🔍 File-Level Evaluation:
BERTScore: 0.7721
Cosine Similarity: 0.4991


Your max_length is set to 128, but your input_length is only 89. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=44)



🔍 File-Level Evaluation:
BERTScore: 0.7643
Cosine Similarity: 0.5989

🔍 File-Level Evaluation:
BERTScore: 0.7748
Cosine Similarity: 0.6304

🔍 File-Level Evaluation:
BERTScore: 0.7658
Cosine Similarity: 0.5190


Your max_length is set to 128, but your input_length is only 105. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=52)



🔍 File-Level Evaluation:
BERTScore: 0.7733
Cosine Similarity: 0.5520

🔍 File-Level Evaluation:
BERTScore: 0.7504
Cosine Similarity: 0.5895

✅ Avg File BERTScore: 0.7714
✅ Avg File Cosine Similarity: 0.5579




Your max_length is set to 128, but your input_length is only 109. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=54)
Your max_length is set to 256, but your input_length is only 253. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=126)



🔍 Repo-Level Evaluation:
BERTScore: 0.7812
Cosine Similarity: 0.7010


Your max_length is set to 128, but your input_length is only 89. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=44)
Your max_length is set to 128, but your input_length is only 62. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=31)



🔍 Repo-Level Evaluation:
BERTScore: 0.7665
Cosine Similarity: 0.5092


Your max_length is set to 128, but your input_length is only 105. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=52)
Your max_length is set to 128, but your input_length is only 74. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=37)



🔍 Repo-Level Evaluation:
BERTScore: 0.7678
Cosine Similarity: 0.4784


Your max_length is set to 128, but your input_length is only 47. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=23)
Your max_length is set to 128, but your input_length is only 88. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=44)



🔍 Repo-Level Evaluation:
BERTScore: 0.7601
Cosine Similarity: 0.4561


Your max_length is set to 128, but your input_length is only 70. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=35)
Your max_length is set to 128, but your input_length is only 111. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=55)
Your max_length is set to 128, but your input_length is only 51. Since this is a summarization task, where outputs shorter than the input are typically wanted, you might consider decreasing max_length manually, e.g. summarizer('...', max_length=25)



🔍 Repo-Level Evaluation:
BERTScore: 0.7798
Cosine Similarity: 0.4907

✅ Avg Repo BERTScore: 0.7711
✅ Avg Repo Cosine Similarity: 0.5271
