In [10]:
from github import Github
import requests
import time
import json
import ast
import os
from tqdm import tqdm

# GitHub access token, read from the GITHUB_TOKEN environment variable so that
# no credential is ever stored in the source file or in version control.
# NOTE(review): a literal personal access token used to be embedded on this
# line; treat that token as compromised and revoke it in the GitHub settings.
ACCESS_TOKEN = os.environ.get('GITHUB_TOKEN', '')

# Shared PyGithub client used by the functions below. An empty token is turned
# into None so the client falls back to unauthenticated access instead of
# sending an empty credential.
g = Github(ACCESS_TOKEN or None)

def extract_functions_from_content(file_content):
    """
    从代码内容中提取所有函数定义，忽略非 Python 3 语法的错误
    """
    try:
        tree = ast.parse(file_content)
    except SyntaxError as e:
        print(f"解析代码文件失败，错误：{e}")
        return {}
    functions = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.FunctionDef):
            func_name = node.name
            func_code = ast.get_source_segment(file_content, node)
            functions[func_name] = func_code
    return functions

def get_python_files_in_repo(repo):
    """
    List the paths of all non-test Python files in a repository.

    The repository tree is traversed breadth-first through
    ``repo.get_contents``. A path is kept when it ends with ``.py`` and does
    not contain the substring ``test`` (case-insensitive) anywhere in it.

    Args:
        repo: An object exposing ``get_contents(path)`` that returns entries
            with ``type`` (``"dir"`` or ``"file"``) and ``path`` attributes,
            such as a PyGithub repository.

    Returns:
        A list of file paths in traversal order. Directories whose listing
        fails are reported and skipped; if even the root listing fails, an
        empty list is returned.
    """
    # Function-scope import keeps this block self-contained.
    from collections import deque

    python_files = []
    root_path = ""
    try:
        # A deque gives O(1) removal from the front of the work queue.
        pending = deque(repo.get_contents(root_path))
    except Exception as e:
        # Treat a failing root listing like a failing sub-directory listing:
        # report it and carry on rather than aborting the caller's run.
        print(f"无法获取目录内容：{root_path}, 错误：{e}")
        return python_files

    while pending:
        entry = pending.popleft()
        if entry.type == "dir":
            try:
                pending.extend(repo.get_contents(entry.path))
            except Exception as e:
                print(f"无法获取目录内容：{entry.path}, 错误：{e}")
        elif (
            entry.type == "file"
            and entry.path.endswith('.py')
            and 'test' not in entry.path.lower()
        ):
            python_files.append(entry.path)
    return python_files

def save_data_pairs(all_data_pairs):
    """
    将数据对保存到文件
    """
    with open('data_pairs.json', 'w', encoding='utf-8') as f:
        json.dump(all_data_pairs, f, ensure_ascii=False, indent=4)
    print(f"数据对已保存到 data_pairs.json，当前总数：{len(all_data_pairs)}")

def _download_text(url, timeout=30):
    """Return the body of *url* as text, or None when the download fails."""
    try:
        # Without a timeout a single stalled connection blocks the whole run.
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as e:
        print(f"无法下载文件：{url}, 错误：{e}")
        return None
    if response.status_code != 200:
        print(f"无法下载文件：{url}")
        return None
    return response.text


def _collect_repo_functions(repo):
    """Map each non-test Python file of *repo* to the functions it defines."""
    repo_functions = {}
    for py_file in get_python_files_in_repo(repo):
        try:
            file_contents = repo.get_contents(py_file)
            code_content = file_contents.decoded_content.decode('utf-8', errors='ignore')
            repo_functions[py_file] = extract_functions_from_content(code_content)
        except Exception as e:
            print(f"无法处理文件：{py_file}, 错误：{e}")
    return repo_functions


def _wait_for_rate_limit():
    """Sleep until the GitHub core rate limit resets when it is nearly used up."""
    rate_limit = g.get_rate_limit().core
    if rate_limit.remaining < 10:
        # NOTE(review): if `reset` is a naive UTC datetime (as in older
        # PyGithub releases), .timestamp() interprets it as local time and the
        # computed wait is off by the UTC offset - confirm against the
        # installed PyGithub version.
        reset_timestamp = rate_limit.reset.timestamp()
        sleep_time = max(0, reset_timestamp - time.time()) + 10
        print(f"速率限制接近上限，等待 {sleep_time} 秒...")
        time.sleep(sleep_time)


def find_and_save_functions():
    """
    Pair every ``test_*`` function listed in ``test_code_files.json`` with the
    function it tests and checkpoint the result to ``data_pairs.json``.

    For each entry of the input file (keys ``repository_full_name``,
    ``file_path`` and ``download_url``) the test file is downloaded and its
    ``test_*`` functions are extracted. The owning repository is crawled for
    function definitions, and each test function is matched to the first
    repository function whose name equals the test name without the
    ``test_`` prefix. The accumulated pairs are saved after every processed
    entry so that an interrupted run loses little work.

    Entries whose download, parsing or repository crawl fails are reported and
    skipped; they never abort the run.
    """
    with open('test_code_files.json', 'r', encoding='utf-8') as f:
        test_code_files = json.load(f)

    all_data_pairs = []

    # Crawling a repository costs one API request per directory and file, so
    # the result is reused when several test files belong to the same
    # repository. The cache is bounded (oldest entry evicted first) to keep
    # memory use flat over thousands of repositories.
    repo_cache = {}
    max_cached_repos = 16

    for idx, file_info in enumerate(tqdm(test_code_files, desc="处理文件")):
        repo_name = file_info['repository_full_name']
        test_file_path = file_info['file_path']
        download_url = file_info['download_url']

        print(f"\n处理仓库：{repo_name}, 测试文件：{test_file_path} [{idx+1}/{len(test_code_files)}]")

        test_file_content = _download_text(download_url)
        if test_file_content is None:
            continue

        # Keep only the functions that follow the test_ naming convention.
        test_functions = {
            name: code
            for name, code in extract_functions_from_content(test_file_content).items()
            if name.startswith('test_')
        }
        if not test_functions:
            print(f"测试文件 {test_file_path} 中未找到测试函数")
            continue

        repo_functions = repo_cache.get(repo_name)
        if repo_functions is None:
            # Check the quota before spending API requests; doing it here
            # means no `continue` path can skip the check.
            _wait_for_rate_limit()
            try:
                repo = g.get_repo(repo_name)
                repo_functions = _collect_repo_functions(repo)
            except Exception as e:
                print(f"无法获取仓库：{repo_name}, 错误：{e}")
                continue
            if len(repo_cache) >= max_cached_repos:
                # Dicts preserve insertion order, so this drops the oldest.
                repo_cache.pop(next(iter(repo_cache)))
            repo_cache[repo_name] = repo_functions

        for test_func_name, test_func_code in test_functions.items():
            original_func_name = test_func_name[5:]  # strip the "test_" prefix
            for code_file, funcs in repo_functions.items():
                if original_func_name in funcs:
                    all_data_pairs.append({
                        'input_code': funcs[original_func_name],
                        'test_code': test_func_code,
                        'repo_name': repo_name,
                        'code_file': code_file,
                        'test_file': test_file_path
                    })
                    print(f"匹配成功：{original_func_name} 在文件 {code_file} 中找到")
                    break  # the first file defining the name is taken as the match
            else:
                print(f"在仓库 {repo_name} 中未找到函数 {original_func_name}")

        # Checkpoint after every processed test file.
        save_data_pairs(all_data_pairs)

# Start the crawl only when this file is executed directly, not when it is
# imported as a module.
if __name__ == "__main__":
    find_and_save_functions()


处理文件:   0%|          | 0/8576 [00:00<?, ?it/s]


处理仓库：hzy46/Deep-Learning-21-Examples, 测试文件：chapter_11/reader.py [1/8576]


处理文件:   0%|          | 0/8576 [02:16<?, ?it/s]


KeyboardInterrupt: 