In [None]:
import ast
import os
import re
from urllib.parse import quote

import astunparse
import numpy as np
import requests
import torch
from datasets import load_dataset
from rank_bm25 import BM25Okapi
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# GitHub API token. It is read from the GITHUB_TOKEN environment variable so the
# secret never has to be typed into (and saved with) the notebook. The original
# placeholder string is kept as the fallback value.
_GITHUB_TOKEN_PLACEHOLDER = "use your own token"
github_token = os.environ.get("GITHUB_TOKEN", _GITHUB_TOKEN_PLACEHOLDER)
if github_token == _GITHUB_TOKEN_PLACEHOLDER:
    print("Warning: GITHUB_TOKEN is not set; GitHub API requests will be sent with a placeholder token.")

# Directory that holds the fine-tuned model and its tokenizer.
model_dir = "./codet5-finetuned-swe-bench-optimized"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSeq2SeqLM.from_pretrained(model_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Assigning the result keeps the cell from echoing the whole module repr.
model = model.to(device)

In [None]:
# Fetch the list of files changed by a commit through the GitHub API.
def get_changed_files_api(repo_owner, repo_name, base_commit, github_token, timeout=30):
    """Return the paths of the files that ``base_commit`` changed.

    The GitHub "compare" endpoint is queried with the range
    ``<base_commit>^...<base_commit>``, i.e. the commit against its parent.

    Args:
        repo_owner: Owner (user or organisation) part of the repository name.
        repo_name: Repository name.
        base_commit: Commit reference whose changes are listed.
        github_token: Token sent in the ``Authorization`` header.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list with the ``filename`` of every entry in the response's
        ``files`` array, or ``None`` when the request fails or the response
        does not have the expected shape (the error is printed).
    """
    api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/compare/{base_commit}^...{base_commit}"
    headers = {"Authorization": f"token {github_token}", "Accept": "application/vnd.github.v3+json"}
    try:
        # Without a timeout a stalled connection would block the run forever.
        response = requests.get(api_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        comparison = response.json()
        return [file["filename"] for file in comparison["files"]]
    except requests.exceptions.RequestException as e:
        print(f"Error retrieving changed files: {e}")
        return None
    except (KeyError, TypeError, ValueError) as e:
        # The body was not JSON, or lacks "files"/"filename"; report it the
        # same way as a transport failure instead of raising to the caller.
        print(f"Error retrieving changed files: unexpected response format ({e!r})")
        return None

# Fetch the contents of the given files at a specific commit through the GitHub API.
def get_code_from_commit_api(repo_owner, repo_name, commit_hash, changed_files, github_token, timeout=30):
    """Download the raw text of each path in ``changed_files`` at ``commit_hash``.

    Args:
        repo_owner: Owner (user or organisation) part of the repository name.
        repo_name: Repository name.
        commit_hash: Commit reference passed as the ``ref`` query parameter.
        changed_files: Iterable of repository-relative file paths.
        github_token: Token sent in the ``Authorization`` header.
        timeout: Seconds to wait for each HTTP request before giving up.

    Returns:
        A dict mapping each original path to the response text, or ``None``
        as soon as any single request fails (the error is printed and the
        remaining paths are not fetched).

    NOTE(review): a path that does not exist at ``commit_hash`` makes the
    whole call return ``None`` -- confirm that all-or-nothing is intended.
    """
    headers = {"Authorization": f"token {github_token}", "Accept": "application/vnd.github.v3.raw"}
    base_code_dict = {}
    for file_path in changed_files:
        # Percent-encode the path (quote keeps "/") so characters such as
        # "#", "?" or spaces cannot be misread as URL syntax.
        api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/contents/{quote(file_path)}"
        try:
            # Without a timeout a stalled connection would block the run forever.
            response = requests.get(
                api_url,
                headers=headers,
                params={"ref": commit_hash},
                timeout=timeout,
            )
            response.raise_for_status()
            base_code_dict[file_path] = response.text
        except requests.exceptions.RequestException as e:
            print(f"Error retrieving code for {file_path}: {e}")
            return None
    return base_code_dict

In [None]:
# Simple preprocessing: drop long (triple-quoted) comments.
def remove_long_comments(code):
    """Delete every triple-quoted span from ``code`` and return the result.

    Two passes are made, first for triple double quotes and then for triple
    single quotes; each match is non-greedy and may span several lines.
    The match is purely textual, so any triple-quoted literal is removed,
    whether it is a docstring, a block comment or a string value.
    """
    triple_quoted_patterns = (r'""".*?"""', r"'''.*?'''")
    for pattern in triple_quoted_patterns:
        code = re.sub(pattern, '', code, flags=re.DOTALL)
    return code

# Class and helper used to split source code into individual functions via the AST.
class FunctionVisitor(ast.NodeVisitor):
    """AST visitor that records every ``def`` and ``async def`` it encounters.

    Each record is a dict with the keys ``name``, ``code`` (the node rendered
    by ``astunparse.unparse``), ``start_line``, ``end_line`` and ``filepath``.
    Children of a recorded node are visited too, so nested functions and
    methods produce their own records in addition to appearing inside the
    ``code`` of their parent.
    """

    def __init__(self, file_path):
        super().__init__()
        self.functions = []
        self.file_path = file_path

    def _record(self, node):
        # Store the description first, then descend so nested definitions are
        # appended after their enclosing function.
        description = {
            "name": node.name,
            "code": astunparse.unparse(node),
            "start_line": node.lineno,
            "end_line": node.end_lineno,
            "filepath": self.file_path
        }
        self.functions.append(description)
        self.generic_visit(node)

    def visit_FunctionDef(self, node):
        self._record(node)

    def visit_AsyncFunctionDef(self, node):
        self._record(node)

def split_code_into_functions_ast(code_text, file_path):
    """Parse ``code_text`` and return one record per function definition.

    The text is first passed through ``remove_long_comments``. If the stripped
    text is no longer valid Python, the untouched ``code_text`` is parsed
    instead, so one awkward literal does not discard every function in the
    file.

    Args:
        code_text: Source code of a single file.
        file_path: Path stored in each record and used in the error message.

    Returns:
        The list collected by ``FunctionVisitor`` (line numbers refer to
        whichever text was actually parsed), or ``[]`` when parsing fails;
        the error is printed.
    """
    stripped_text = remove_long_comments(code_text)
    try:
        try:
            tree = ast.parse(stripped_text)
        except SyntaxError:
            # Deleting a triple-quoted literal can leave a body with no
            # statement (docstring-only stub) or an assignment with no value.
            # Fall back to the original source rather than losing the file.
            tree = ast.parse(code_text)
        visitor = FunctionVisitor(file_path)
        visitor.visit(tree)
        return visitor.functions
    except Exception as e:
        print(f"Error parsing code from {file_path}: {e}")
        return []

In [None]:
def find_relevant_functions(problem_statement, all_functions, top_n=3):
    """Rank function records against the issue text with BM25 and keep the best.

    Both the function code and the problem statement are split into tokens
    with the module-level ``tokenizer``; the functions form the BM25 corpus
    and the problem statement is the query.

    Args:
        problem_statement: Issue text used as the query.
        all_functions: List of dicts that each carry a ``"code"`` entry.
        top_n: Maximum number of functions to return (default 3).

    Returns:
        Up to ``top_n`` entries of ``all_functions``, highest score first.
        An empty list is returned when there is nothing to rank.
    """
    # Nothing to rank: skip building an index over an empty corpus, which has
    # no average document length for BM25 to work with.
    if not all_functions or top_n <= 0:
        return []

    func_texts = [func["code"] for func in all_functions]
    tokenized_functions = [tokenizer.tokenize(text) for text in func_texts]
    bm25 = BM25Okapi(tokenized_functions)
    tokenized_query = tokenizer.tokenize(problem_statement)
    doc_scores = bm25.get_scores(tokenized_query)

    # argsort is ascending, so reverse it before taking the leading entries.
    top_indices = np.argsort(doc_scores)[::-1][:top_n]
    return [all_functions[i] for i in top_indices]

In [None]:
def generate_patch(example):
    """Generate a patch suggestion for one dataset example.

    Steps: list the files changed by the example's ``base_commit``, keep the
    ``.py`` ones, download them, split them into functions, pick the most
    relevant functions for the problem statement, build a prompt and decode
    the model's beam-search output.

    Relies on the module-level ``github_token``, ``tokenizer``, ``model`` and
    ``device``.

    Args:
        example: Mapping with the keys ``"repo"`` (``"owner/name"``),
            ``"base_commit"`` and ``"problem_statement"``.

    Returns:
        The decoded model output, or ``None`` when any preparation step
        yields nothing (a message naming the step is printed).
    """
    repo = example["repo"]
    base_commit = example["base_commit"]
    problem_statement = example["problem_statement"]
    repo_owner, repo_name = repo.split("/")

    # Step 1: which files did the commit touch?
    touched_files = get_changed_files_api(repo_owner, repo_name, base_commit, github_token)
    if not touched_files:
        print(f"No changed files found for commit {base_commit}.")
        return None

    # Step 2: only Python sources can be parsed further down.
    python_files = [path for path in touched_files if path.endswith(".py")]
    if not python_files:
        print(f"No .py files found for commit {base_commit}.")
        return None

    # Step 3: download the file contents.
    sources_by_path = get_code_from_commit_api(
        repo_owner, repo_name, base_commit, python_files, github_token
    )
    if sources_by_path is None:
        print(f"Failed to retrieve code for commit {base_commit}.")
        return None

    # Step 4: break every file down into function records.
    candidate_functions = [
        record
        for path, source in sources_by_path.items()
        for record in split_code_into_functions_ast(source, path)
    ]
    if not candidate_functions:
        print(f"No functions found in commit {base_commit}.")
        return None

    # Step 5: keep the functions most relevant to the issue text.
    chosen_functions = find_relevant_functions(problem_statement, candidate_functions)

    # Step 6: assemble the prompt.
    context_chunks = []
    for func in chosen_functions:
        context_chunks.append(
            f"File: {func['filepath']}, Function: {func['name']}\nCode:\n{func['code']}\n\n"
        )
    selected_code = "".join(context_chunks)

    prompt_input = (
        f"Issue: {problem_statement}\n"
        f"Code Context:\n{selected_code}\n"
        "Task: Provide a patch to fix the issue."
    )

    encoded = tokenizer(
        prompt_input,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    inputs = encoded.to(device)

    # Step 7: generate with beam search, gradients disabled.
    generation_options = {
        "max_new_tokens": 256,
        "pad_token_id": tokenizer.eos_token_id,
        "num_beams": 5,
        "early_stopping": True,
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_options)

    best_sequence = outputs[0]
    return tokenizer.decode(best_sequence, skip_special_tokens=True)

In [None]:
# Load the evaluation split of SWE-bench Verified.
inference_dataset = load_dataset(
    "princeton-nlp/SWE-bench_Verified",
    split="test",
)

# Make sure the directory that receives the generated patches exists.
output_dir = "patches_verified"
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Inference loop: generate a patch for every example and save it to its own file.
progress = tqdm(inference_dataset, desc="Inference on Verified Dataset")
for idx, example in enumerate(progress):
    patch = generate_patch(example)

    # An empty or missing patch is reported and nothing is written for it.
    if not patch:
        print(f"Example {idx}: Patch generation skipped.")
        continue

    patch_path = os.path.join(output_dir, f"patch_{idx}.txt")
    with open(patch_path, "w", encoding="utf-8") as f:
        f.write(patch)
    print(f"Example {idx} generated patch:\n{patch}\n")