In [None]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from datasets import load_dataset
import requests
from tqdm import tqdm
import time
import ast
import astunparse
from rank_bm25 import BM25Okapi

# GitHub API token. Prefer exporting GITHUB_TOKEN in the environment so the
# secret is never hard-coded in the notebook; alternatively replace the
# fallback string on the `github_token = ...` line with a real token.
_GITHUB_TOKEN_PLACEHOLDER = "use your own token"
github_token = os.environ.get("GITHUB_TOKEN") or "use your own token"
# Fail fast: a plain `is None` check can never trigger for a string literal,
# so the placeholder would otherwise be sent as credentials on every API call.
if not github_token or github_token == _GITHUB_TOKEN_PLACEHOLDER:
    raise ValueError("GitHub API token not set.")

# Pick the execution device up front: CUDA when a GPU is visible, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained CodeT5-small checkpoint, used as-is (no fine-tuning).
model_name = "Salesforce/codet5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Move the weights onto the chosen device and report it.
model = model.to(device)
print("Using device:", device)

In [None]:
# Load the "test" split of the SWE-bench Lite dataset.
test_dataset = load_dataset(
    "princeton-nlp/SWE-bench_Lite",
    split="test",
)

In [None]:
def get_changed_files_api(repo_owner, repo_name, base_commit, github_token, timeout=30):
    """List the file paths changed by ``base_commit`` relative to its parent.

    Calls GitHub's compare endpoint with ``{base_commit}^...{base_commit}``.
    NOTE(review): this returns the files touched *by* the base commit itself,
    which is a retrieval heuristic, not necessarily where the issue's fix
    belongs — confirm this is intended.

    Args:
        repo_owner: GitHub owner/organisation name.
        repo_name: Repository name.
        base_commit: Commit SHA whose own changes are listed.
        github_token: Token sent in the ``Authorization`` header.
        timeout: Seconds to wait for the HTTP response (prevents hanging forever).

    Returns:
        A list of file paths (possibly empty), or None if the request or the
        JSON decoding failed (the error is printed).
    """
    api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/compare/{base_commit}^...{base_commit}"
    headers = {
        "Authorization": f"token {github_token}",
        "Accept": "application/vnd.github.v3+json"
    }
    try:
        response = requests.get(api_url, headers=headers, timeout=timeout)
        response.raise_for_status()
        comparison = response.json()
    except (requests.exceptions.RequestException, ValueError) as e:
        # ValueError covers a non-JSON body on requests versions whose JSON
        # errors do not derive from RequestException.
        print(f"Error getting changed files for commit {base_commit}: {e}")
        return None
    # Use .get so a response without a "files" key yields [] instead of a KeyError.
    return [file["filename"] for file in comparison.get("files", [])]

def get_code_from_commit_api(repo_owner, repo_name, commit_hash, changed_files, github_token, timeout=30):
    """Fetch the raw text of each path in ``changed_files`` as of ``commit_hash``.

    Args:
        repo_owner: GitHub owner/organisation name.
        repo_name: Repository name.
        commit_hash: Git ref at which the files are read.
        changed_files: Iterable of repository-relative file paths.
        github_token: Token sent in the ``Authorization`` header.
        timeout: Seconds to wait for each HTTP response.

    Returns:
        A dict mapping file path -> file text. Paths that return HTTP 404 at
        ``commit_hash`` (e.g. files deleted by that commit) are skipped rather
        than failing the whole lookup. Returns None on any other request error
        (the error is printed).
    """
    from urllib.parse import quote

    headers = {
        "Authorization": f"token {github_token}",
        "Accept": "application/vnd.github.v3.raw"
    }
    base_code_dict = {}
    for file_path in changed_files:
        # Quote the path so spaces, '#', '?' etc. don't corrupt the URL ('/' is kept).
        api_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/contents/{quote(file_path)}?ref={commit_hash}"
        try:
            response = requests.get(api_url, headers=headers, timeout=timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            status = e.response.status_code if e.response is not None else None
            if status == 404:
                # The file does not exist at this ref (typically removed by the
                # commit); there is no code to retrieve, so just skip it.
                print(f"Skipping {file_path}: not found at commit {commit_hash}.")
                continue
            print(f"Error retrieving code for file {file_path} in commit {commit_hash}: {e}")
            return None
        base_code_dict[file_path] = response.text
    return base_code_dict

In [None]:
class FunctionVisitor(ast.NodeVisitor):
    """AST visitor that records every (async) function definition it meets.

    Each record is a dict with the function ``name``, its ``code`` regenerated
    by ``astunparse.unparse``, its ``start_line`` / ``end_line`` taken from
    the node, and the ``filepath`` given at construction. The walk continues
    into every function body, so nested functions and methods are recorded
    too.
    """

    def __init__(self, file_path):
        super().__init__()
        self.file_path = file_path
        self.functions = []

    def _record(self, node):
        """Append a record for ``node`` and then keep visiting its children."""
        entry = dict(
            name=node.name,
            code=astunparse.unparse(node),
            start_line=node.lineno,
            end_line=node.end_lineno,
            filepath=self.file_path,
        )
        self.functions.append(entry)
        self.generic_visit(node)

    # Sync and async definitions are recorded identically.
    visit_FunctionDef = _record
    visit_AsyncFunctionDef = _record

def split_code_into_functions_ast(code_text, file_path):
    """Parse ``code_text`` and return one record per function definition.

    Records are produced by ``FunctionVisitor``. If parsing or visiting raises
    anything, the error is printed and an empty list is returned, so an
    unparsable file is skipped instead of aborting the caller.
    """
    try:
        module = ast.parse(code_text)
        collector = FunctionVisitor(file_path)
        collector.visit(module)
    except Exception as err:
        print(f"AST parsing error in {file_path}: {err}")
        return []
    return collector.functions

In [None]:
def _fetch_python_sources(repo_owner, repo_name, commit_hash):
    """Return ``{path: source}`` for the ``.py`` files changed by ``commit_hash``.

    Prints the reason and returns None when the commit has no changed files,
    no Python files, or the file contents cannot be retrieved.
    """
    changed_files = get_changed_files_api(repo_owner, repo_name, commit_hash, github_token)
    if not changed_files:
        print(f"No changed files for commit {commit_hash}.")
        return None
    # Only Python sources can be split into functions below.
    py_files = [f for f in changed_files if f.endswith(".py")]
    if not py_files:
        print(f"No Python files for commit {commit_hash}.")
        return None

    base_code_dict = get_code_from_commit_api(repo_owner, repo_name, commit_hash, py_files, github_token)
    if base_code_dict is None:
        print(f"Failed to retrieve code for commit {commit_hash}.")
        return None
    return base_code_dict


def _select_top_functions(all_functions, query, top_n):
    """Return the ``top_n`` function records whose code scores highest under BM25.

    Both the function code and the query are split with the global CodeT5
    tokenizer before scoring. Ties keep their original order.
    """
    tokenized_functions = [tokenizer.tokenize(func["code"]) for func in all_functions]
    bm25 = BM25Okapi(tokenized_functions)
    doc_scores = bm25.get_scores(tokenizer.tokenize(query))
    ranked = sorted(range(len(doc_scores)), key=lambda i: doc_scores[i], reverse=True)
    return [all_functions[i] for i in ranked[:top_n]]


def _build_prompt(problem_statement, selected_functions):
    """Format the issue text and the selected function records into the model prompt."""
    selected_code = "".join(
        f"File: {func['filepath']}, Function: {func['name']}\n"
        f"Start Line: {func['start_line']}, End Line: {func['end_line']}\n"
        f"Code:\n{func['code']}\n\n"
        for func in selected_functions
    )
    return f"Issue: {problem_statement}\nCode Context:\n{selected_code}\nTask: Provide a patch to fix the issue."


def generate_patch(example, top_n=3, max_input_tokens=512, max_new_tokens=256, num_beams=5):
    """Generate a candidate patch for one SWE-bench example with CodeT5.

    Pipeline: fetch the Python files changed by the example's base commit,
    split them into functions, keep the ``top_n`` functions most similar to
    the issue text (BM25), build a prompt, and decode the model's beam-search
    output.

    Args:
        example: Dataset row; reads the "repo" ("owner/name"), "base_commit"
            and "problem_statement" fields.
        top_n: Number of retrieved functions placed in the prompt.
        max_input_tokens: Encoder input is truncated to this many tokens.
        max_new_tokens: Upper bound on generated tokens.
        num_beams: Beam width for generation.

    Returns:
        The decoded patch text, or None if any step fails (reason printed).
    """
    try:
        repo_owner, repo_name = example["repo"].split("/")
        commit_hash = example["base_commit"]
        problem_statement = example["problem_statement"]

        base_code_dict = _fetch_python_sources(repo_owner, repo_name, commit_hash)
        if base_code_dict is None:
            return None

        all_functions = []
        for file_path, code in base_code_dict.items():
            all_functions.extend(split_code_into_functions_ast(code, file_path))
        if not all_functions:
            print(f"No functions found in commit {commit_hash}.")
            return None

        selected_functions = _select_top_functions(all_functions, problem_statement, top_n)
        prompt_input = _build_prompt(problem_statement, selected_functions)

        # A single truncating call replaces the old encode -> slice -> decode ->
        # re-encode round trip. No max_length padding: this is a single
        # sequence, so padding only adds wasted encoder work.
        # NOTE(review): the issue text comes first, so a long issue can push
        # the code context and task line out of the truncated window.
        inputs = tokenizer(
            prompt_input,
            return_tensors="pt",
            max_length=max_input_tokens,
            truncation=True,
        ).to(device)

        # Use the tokenizer's own pad id; pad=eos is a decoder-only workaround
        # and is unnecessary for this encoder-decoder model.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=pad_id,
                num_beams=num_beams,
                early_stopping=True
            )
        return tokenizer.decode(outputs[0], skip_special_tokens=True)

    except Exception as e:
        # Best-effort per example: report and move on to the next one.
        print(f"[Error in generate_patch] {e}")
        return None

In [None]:
# Directory that receives one text file per generated patch.
output_dir = "initial_patches"
os.makedirs(output_dir, exist_ok=True)

# Generate a patch for every test example; save and echo it, or report a skip.
for idx, example in enumerate(tqdm(test_dataset, desc="Generating Initial Patches")):
    patch = generate_patch(example)
    if not patch:
        print(f"--- Example {idx} Skipped ---")
    else:
        patch_path = os.path.join(output_dir, f"patch_{idx}.txt")
        # Write as UTF-8 explicitly so non-ASCII model output cannot raise an encoding error.
        with open(patch_path, "w", encoding="utf-8") as out_file:
            out_file.write(patch)
        print(f"--- Example {idx} Patch ---")
        print(patch)
    # One-second pause per example — presumably to throttle GitHub API calls.
    time.sleep(1)