<a href="https://colab.research.google.com/github/kevin89887634/ChatGPT-wechat-bot/blob/master/Whisper_%E6%89%B9%E9%87%8F%E8%BD%AC%E5%86%99_v4_0_(%E6%9C%80%E7%BB%88%E7%89%88).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-
"""
Whisper 批量转写 v4.0 (最终版)

核心特性：
1.  【一站式路径】：只需配置一个项目总文件夹，输入输出路径全自动管理。
2.  【智能依赖安装】：自动检查并安装所需库，实现开箱即用。
3.  【企业级健壮性】：包含断点续传、错误重试、文件验证和详细日志。
4.  【GPT文本校对】：可选的 GPT（gpt-3.5-turbo）文本润色功能，仅对识别为中文的转写结果生效，提升转写质量。
"""
import os
import sys
import subprocess
import importlib
import re
import json
import time
import hashlib
from datetime import datetime
import logging
from typing import Dict, Any, Tuple

# ==============================================================================
# 1. User configuration -- the only section you are expected to edit
# ==============================================================================

# Root project folder on Google Drive. Input media is read from its
# "_source_audios" subfolder; the "srt", "txt_corrected", "json" and "logs"
# output subfolders are derived from it below.
PROJECT_FOLDER = "/content/drive/MyDrive/MyWhisperProject"

# --- Optional settings ---

# Whisper model size handed to whisper.load_model(): "tiny", "base", "small", "medium", "large".
MODEL_SIZE = "medium"

# Proof-read Chinese transcripts with GPT (needs an OPENAI_API_KEY to be available).
ENABLE_GPT_CORRECTION = True

# Extra attempts a file gets after a recoverable error (e.g. a failed network/GPT call).
RETRY_COUNT = 3

# ==============================================================================
# 2. Derived paths (no changes needed)
# ==============================================================================

SOURCE_SUBFOLDER = "_source_audios"  # input subfolder name inside PROJECT_FOLDER

# Every other location hangs off PROJECT_FOLDER.
INPUT_FOLDER = os.path.join(PROJECT_FOLDER, SOURCE_SUBFOLDER)
OUTPUT_SRT, OUTPUT_TXT, OUTPUT_JSON, LOG_FOLDER = (
    os.path.join(PROJECT_FOLDER, subfolder)
    for subfolder in ("srt", "txt_corrected", "json", "logs")
)
PROGRESS_FILE = os.path.join(LOG_FOLDER, "progress.json")
# One log file per run, named after the local start time of this process.
SESSION_LOG_FILE = os.path.join(
    LOG_FOLDER,
    f"session_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log",
)

# ==============================================================================
# 3. Initialisation and environment setup (automatic)
# ==============================================================================

# Module-wide logger; its handlers are attached by setup_environment().
logger = logging.getLogger(__name__)

def auto_install_dependencies():
    """
    Check for, and if necessary install, every library and tool the script needs.

    Python packages are probed by module name and installed with
    ``pip install -U`` into the running interpreter when missing.
    ``ffprobe`` (part of ffmpeg) is probed on PATH and installed through
    ``apt-get`` when absent, which only works on Debian/Ubuntu-style systems
    such as Colab.

    Raises:
        subprocess.CalledProcessError: if a pip installation fails.
        Exception: whatever the ffmpeg installation raised, re-raised unchanged.
    """
    print("📋 正在检查环境与依赖库...")

    # Import name -> pip distribution name. tqdm drives the progress bar in main().
    required_libs = {"whisper": "openai-whisper", "openai": "openai", "tqdm": "tqdm"}
    for lib_name, install_name in required_libs.items():
        try:
            importlib.import_module(lib_name)
            print(f"  ✅ 依赖库 '{lib_name}' 已安装。")
            continue
        except ImportError:
            print(f"  ⚠️ 依赖库 '{lib_name}' 未找到，正在自动安装...")

        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", install_name])
        except subprocess.CalledProcessError as e:
            print(f"  ❌ 安装 '{install_name}' 失败: {e}")
            print(f"  请尝试手动运行 'pip install -U {install_name}' 后再试。")
            raise
        # Modules installed after start-up may be missed by the import system's
        # cached directory listings until the caches are invalidated.
        importlib.invalidate_caches()
        print(f"  ✅ 成功安装 '{install_name}'。")

    # System tools (ffmpeg/ffprobe), mainly for the Colab environment.
    try:
        # Only the exit status matters here; the version output is discarded.
        subprocess.run(['ffprobe', '-version'], capture_output=True, check=True)
        print("  ✅ 工具 'ffmpeg' 已安装。")
    except (FileNotFoundError, subprocess.CalledProcessError):
        print("  ⚠️ 工具 'ffmpeg' 未找到，正在自动安装 (需要 sudo 权限)...")
        try:
            # Refresh the package index first; a stale index can make the
            # install step fail. Best effort: a failed refresh is not fatal.
            subprocess.run(['apt-get', 'update', '-qq'], check=False)
            subprocess.check_call(['apt-get', 'install', '-y', 'ffmpeg'])
            print("  ✅ 成功安装 'ffmpeg'。")
        except Exception as e:
            print(f"  ❌ 自动安装 'ffmpeg' 失败: {e}")
            print("  如果不在 Colab 环境，请根据您的操作系统手动安装 ffmpeg。")
            raise

def _configure_console_logging():
    """Set the module logger to INFO and give it exactly one console handler."""
    logger.setLevel(logging.INFO)
    # This logger owns its output. Propagating to handlers that a hosting
    # runtime may have installed on the root logger would print lines twice.
    logger.propagate = False
    # Inspect this logger's own handlers. logger.hasHandlers() also counts
    # handlers on ancestor loggers, so a pre-configured root logger would cause
    # the setup to be skipped and INFO output to vanish.
    has_console = any(not isinstance(h, logging.FileHandler) for h in logger.handlers)
    if not has_console:
        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger.addHandler(stream_handler)


def _attach_session_file_handler():
    """Send log records to SESSION_LOG_FILE, retiring file handlers that point elsewhere."""
    target = os.path.abspath(SESSION_LOG_FILE)
    for handler in list(logger.handlers):
        # setup may run again in the same interpreter (e.g. a re-run notebook
        # cell) with a new SESSION_LOG_FILE; close handlers for older files.
        if isinstance(handler, logging.FileHandler) and handler.baseFilename != target:
            logger.removeHandler(handler)
            handler.close()
    if any(isinstance(h, logging.FileHandler) for h in logger.handlers):
        return
    os.makedirs(LOG_FOLDER, exist_ok=True)
    file_handler = logging.FileHandler(SESSION_LOG_FILE, encoding='utf-8')
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
    logger.addHandler(file_handler)


def _create_project_directories():
    """Create every project folder (input, outputs, logs) if it does not exist yet."""
    logger.info("📁 正在创建项目目录结构...")
    for dir_path in (PROJECT_FOLDER, INPUT_FOLDER, OUTPUT_SRT, OUTPUT_TXT, OUTPUT_JSON, LOG_FOLDER):
        try:
            os.makedirs(dir_path, exist_ok=True)
        except OSError as e:
            logger.error(f"创建目录失败: {dir_path} - {e}")
            raise
    logger.info("✅ 项目目录结构准备就绪。")


def setup_environment():
    """
    Prepare the runtime: logging, Google Drive, project folders and the OpenAI key.

    Google Drive is mounted *before* anything is written below PROJECT_FOLDER.
    On Colab that folder sits under the /content/drive mount point, and files
    created there before mounting would occupy the mount point, which Colab's
    drive.mount() rejects.

    On Colab the API key comes from userdata 'OPENAI_API_KEY'. Elsewhere it
    comes from the OPENAI_API_KEY environment variable. A missing key only
    disables GPT correction.

    Raises:
        OSError: if a project directory cannot be created.
    """
    # Bind `openai` as a module global in every branch: correct_text_with_gpt()
    # looks it up at module level and treats its absence as "GPT disabled".
    global openai

    _configure_console_logging()

    try:
        from google.colab import drive, userdata
    except ImportError:
        drive = userdata = None
        logger.warning("⚠️ 未在 Google Colab 环境中运行，跳过驱动器挂载和 userdata 配置。")

    if drive is not None:
        if not os.path.ismount("/content/drive"):
            logger.info("🔗 正在挂载 Google Drive...")
            drive.mount("/content/drive")
        else:
            logger.info("🔗 Google Drive 已挂载。")

    # Only now is it safe to touch PROJECT_FOLDER (including the log folder).
    _create_project_directories()
    _attach_session_file_handler()
    logger.info("日志系统初始化完成。")

    if not ENABLE_GPT_CORRECTION:
        return

    import openai
    if userdata is not None:
        try:
            openai.api_key = userdata.get('OPENAI_API_KEY')
            if openai.api_key:
                logger.info("🔑 成功加载 OpenAI API Key。")
            else:
                logger.warning("⚠️ 未在 Colab userdata 中找到'OPENAI_API_KEY'，GPT 校对功能将禁用。")
        except Exception:
            openai.api_key = None
            logger.error("⚠️ 加载 OpenAI API Key 失败，GPT 校对功能将禁用。")
    else:
        openai.api_key = os.environ.get('OPENAI_API_KEY')
        if openai.api_key:
            logger.info("🔑 成功从环境变量加载 OpenAI API Key。")
        else:
            logger.warning("⚠️ 未在环境变量中找到 OpenAI API Key，GPT 校对功能将禁用。")


# ==============================================================================
# 4.【 核心功能模块 (无需修改) 】
# ==============================================================================

# --- 进度管理 ---
class ProgressManager:
    """
    Persist per-file progress in a JSON file so an interrupted batch can resume.

    ``self.data`` holds two maps keyed by the file's content hash:
    ``processed_files`` (hash -> {"file_path": ..., plus the info given to
    add_success}) and ``failed_files`` (hash -> {"file_path", "reason",
    "timestamp"}). Every mutation is saved to disk immediately.
    """

    def __init__(self, filepath: str):
        self.filepath = filepath
        self.data = self._load()

    def _load(self) -> Dict[str, Any]:
        """Return the saved state, or an empty state if the file is missing or unusable."""
        empty_state = {"processed_files": {}, "failed_files": {}}
        try:
            with open(self.filepath, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except FileNotFoundError:
            return empty_state
        except (ValueError, OSError) as e:  # ValueError covers JSON and UTF-8 decode errors
            # An empty state means every file gets processed again; make the
            # reset visible instead of silent.
            logger.warning(f"⚠️ 进度文件无法读取，将忽略已有进度: {self.filepath} ({e})")
            return empty_state
        if not isinstance(data, dict):
            logger.warning(f"⚠️ 进度文件格式无效，将忽略已有进度: {self.filepath}")
            return empty_state
        return data

    def save(self):
        """Write ``self.data`` to disk atomically; I/O failures are logged, not raised."""
        tmp_path = self.filepath + ".tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(self.data, f, ensure_ascii=False, indent=4)
            # Write a side file and swap it in with a single rename. An
            # interruption mid-write then cannot leave a truncated progress
            # file, which _load would discard along with all saved progress.
            os.replace(tmp_path, self.filepath)
        except OSError as e:
            logger.error(f"无法保存进度文件: {e}")

    def is_completed(self, file_hash: str, base_name: str) -> bool:
        """True if the hash is recorded as processed AND its .srt/.txt/.json outputs all exist."""
        return file_hash in self.data.get("processed_files", {}) and \
               all(os.path.exists(os.path.join(p, f"{base_name}{ext}")) for p, ext in
                   [(OUTPUT_SRT, ".srt"), (OUTPUT_TXT, ".txt"), (OUTPUT_JSON, ".json")])

    def add_success(self, file_hash: str, file_path: str, info: Dict[str, Any]):
        """Record a success (clearing any earlier failure for the same hash) and save."""
        self.data.setdefault("processed_files", {})[file_hash] = {"file_path": file_path, **info}
        self.data.setdefault("failed_files", {}).pop(file_hash, None)
        self.save()

    def add_failure(self, file_hash: str, file_path: str, reason: str):
        """Record (or overwrite) a failure with its reason and current local time, then save."""
        self.data.setdefault("failed_files", {})[file_hash] = {
            "file_path": file_path,
            "reason": reason,
            "timestamp": datetime.now().isoformat(),
        }
        self.save()

    def get_summary(self) -> Tuple[int, int]:
        """Return ``(processed_count, failed_count)`` over everything ever recorded."""
        return len(self.data.get("processed_files", {})), len(self.data.get("failed_files", {}))

# --- 文件处理 ---

def get_file_hash(file_path: str, chunk_size: int = 1024 * 1024) -> str:
    """
    Return the hex MD5 digest of a file's contents.

    The digest is used as the progress-tracking key for a file, not for
    security. The file is streamed in ``chunk_size``-byte reads (1 MiB by
    default) so large media files are hashed with few read calls and bounded
    memory.

    Raises:
        ValueError: if ``chunk_size`` is not positive (a zero-size read would
            end the loop immediately and yield the digest of an empty file).
        FileNotFoundError: if ``file_path`` does not exist.
    """
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    digest = hashlib.md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

def validate_audio_file(file_path: str) -> Tuple[bool, Dict[str, Any]]:
    """
    Check that a file exists, can be probed by ffprobe, has an audio stream
    and is not shorter than 0.1 s.

    Returns:
        ``(True, {"duration": seconds_or_None})`` when the file is usable.
        The duration is ``None`` if ffprobe reports none, which can happen
        for containers without a duration header. Otherwise returns
        ``(False, {"error": message})``.
    """
    if not os.path.exists(file_path):
        return False, {"error": "文件不存在"}
    try:
        result = subprocess.run(
            ["ffprobe", "-v", "quiet", "-print_format", "json", "-show_format", "-show_streams", file_path],
            capture_output=True, check=True,
            # ffprobe emits UTF-8 JSON; decode it explicitly rather than with the
            # locale codec, which may fail on non-ASCII metadata tags.
            encoding="utf-8", errors="replace",
        )
        info = json.loads(result.stdout)
    except (OSError, subprocess.CalledProcessError, ValueError):
        # OSError: ffprobe missing/unrunnable; ValueError: unparsable JSON output.
        return False, {"error": "文件已损坏或格式无法识别"}

    audio_streams = [s for s in info.get("streams") or [] if s.get("codec_type") == "audio"]
    if not audio_streams:
        return False, {"error": "文件不包含音频流"}

    def _as_seconds(raw: Any):
        """Parse an ffprobe duration field; None when absent or not numeric (e.g. "N/A")."""
        try:
            return float(raw)
        except (TypeError, ValueError):
            return None

    # Prefer the container duration, falling back to the audio streams' own durations.
    duration = _as_seconds((info.get("format") or {}).get("duration"))
    if duration is None:
        stream_durations = [d for d in (_as_seconds(s.get("duration")) for s in audio_streams) if d is not None]
        duration = max(stream_durations) if stream_durations else None

    if duration is None:
        # Length unknown, but the file does contain audio; let Whisper try it
        # instead of rejecting it as "too short".
        return True, {"duration": None}
    if duration < 0.1:
        return False, {"error": "音频时长过短"}
    return True, {"duration": duration}

def _split_text_for_gpt(text: str, max_chars: int) -> list:
    """Pack newline-separated lines into chunks of at most ``max_chars`` characters.

    Joining the chunks with "\n" reproduces ``text``. The only exception is a
    single line longer than ``max_chars``: it is cut into pieces, and those
    pieces end up separated by newlines.
    """
    if max_chars < 1:
        raise ValueError("max_chars must be at least 1")
    chunks, current, current_len = [], [], 0
    for line in text.split("\n"):
        pieces = [line[i:i + max_chars] for i in range(0, len(line), max_chars)] or [""]
        for piece in pieces:
            added = len(piece) + (1 if current else 0)  # +1 for the joining newline
            if current and current_len + added > max_chars:
                chunks.append("\n".join(current))
                current, current_len, added = [], 0, len(piece)
            current.append(piece)
            current_len += added
    if current:
        chunks.append("\n".join(current))
    return chunks


def correct_text_with_gpt(text: str, max_chunk_chars: int = 1500) -> str:
    """
    Proof-read ``text`` with gpt-3.5-turbo and return the corrected text.

    The text is returned unchanged when it is blank, when ENABLE_GPT_CORRECTION
    is off, or when no module-level ``openai`` with an ``api_key`` has been set
    up. Long transcripts are sent in chunks of at most ``max_chunk_chars``
    characters. A single request for a long transcript would exceed the
    model's context/output limits and fail on every retry.

    Raises:
        IOError: if any API call fails; the caller treats this as retryable.
    """
    if not text.strip():
        return text  # nothing to correct; avoid a pointless API call
    client = globals().get("openai")
    if not ENABLE_GPT_CORRECTION or client is None or not getattr(client, "api_key", None):
        return text

    logger.info("  调用 GPT 进行文本校对...")
    corrected_chunks = []
    try:
        for chunk in _split_text_for_gpt(text, max_chunk_chars):
            if not chunk.strip():
                corrected_chunks.append(chunk)
                continue
            response = client.chat.completions.create(
                model="gpt-3.5-turbo",
                messages=[
                    {"role": "system", "content": "你是一个专业的中文文本校对编辑。请修正以下文本中的错别字、标点和语法问题，使其更流畅、准确，但必须保持原意。直接返回修正后的文本。"},
                    {"role": "user", "content": chunk}
                ],
                temperature=0.2, max_tokens=4000
            )
            content = response.choices[0].message.content
            # An empty/None reply would silently drop this part of the transcript.
            corrected_chunks.append(content.strip() if content and content.strip() else chunk)
    except Exception as e:
        logger.error(f"  GPT 校对失败: {e}")
        raise IOError("GPT API call failed") from e  # retryable for the caller
    return "\n".join(corrected_chunks)

def _format_srt_timestamp(seconds: float) -> str:
    """Format an offset in seconds as an SRT ``HH:MM:SS,mmm`` timestamp (hours not capped at 24)."""
    total_ms = max(0, int(round(seconds * 1000)))
    hours, remainder = divmod(total_ms, 3_600_000)
    minutes, remainder = divmod(remainder, 60_000)
    secs, millis = divmod(remainder, 1000)
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def process_single_file(file_path: str, model: Any, progress_manager: ProgressManager) -> str:
    """
    Run the whole pipeline for one file: hash, skip check, validation,
    transcription, optional GPT correction, and writing .srt/.txt/.json.

    Args:
        file_path: media file to transcribe.
        model: loaded Whisper model (only its ``transcribe`` method is used).
        progress_manager: records successes and failures by content hash.

    Returns:
        One of the following status strings:
        - "success"
        - "skipped" (already completed)
        - "error_file_not_found"
        - "error_validation"
        - "error_retryable" (an IOError/OSError, e.g. a failed GPT call or
          output write; the caller may retry)
        - "error_critical"

    NOTE: outputs are named after the input's base name only, so two inputs
    with the same base name (different folders or extensions) write to the
    same output files.
    """
    base_name = os.path.splitext(os.path.basename(file_path))[0]

    try:
        file_hash = get_file_hash(file_path)
    except FileNotFoundError:
        return "error_file_not_found"

    if progress_manager.is_completed(file_hash, base_name):
        return "skipped"

    is_valid, validation_info = validate_audio_file(file_path)
    if not is_valid:
        progress_manager.add_failure(file_hash, file_path, f"文件验证失败: {validation_info.get('error')}")
        return "error_validation"

    try:
        result = model.transcribe(file_path, verbose=False)
        segments = result["segments"]
        raw_text = "\n".join(seg["text"].strip() for seg in segments)
        # GPT proof-reading is only applied to transcripts detected as Chinese.
        is_chinese = "zh" in (result.get("language") or "")
        corrected_text = correct_text_with_gpt(raw_text) if is_chinese else raw_text

        srt_entries = [
            f"{idx}\n{_format_srt_timestamp(seg['start'])} --> {_format_srt_timestamp(seg['end'])}\n"
            f"{seg['text'].strip()}\n\n"
            for idx, seg in enumerate(segments, start=1)
        ]
        with open(os.path.join(OUTPUT_SRT, base_name + ".srt"), "w", encoding="utf-8") as f:
            f.write("".join(srt_entries))

        with open(os.path.join(OUTPUT_TXT, base_name + ".txt"), "w", encoding="utf-8") as f:
            f.write(corrected_text)

        result['corrected_text'] = corrected_text
        with open(os.path.join(OUTPUT_JSON, base_name + ".json"), "w", encoding="utf-8") as f:
            json.dump(result, f, ensure_ascii=False, indent=4)

        progress_manager.add_success(file_hash, file_path, {
            "language": result.get("language"),
            "duration": validation_info.get('duration'),
            "timestamp": datetime.now().isoformat(),
        })
        return "success"
    except IOError as e:  # IOError == OSError: GPT failures and output-write errors land here
        progress_manager.add_failure(file_hash, file_path, f"可恢复错误: {e}")
        return "error_retryable"
    except Exception as e:
        logger.error(f"处理文件 {file_path} 时发生不可恢复的严重错误: {e}", exc_info=True)
        progress_manager.add_failure(file_hash, file_path, f"严重错误: {e}")
        return "error_critical"

# ==============================================================================
# 5.【 主程序入口 】
# ==============================================================================

def main():
    """
    Orchestrate the batch run.

    Steps: dependencies -> environment -> input scan -> model load ->
    per-file processing with retries -> final summary.
    """

    # Step 1: install / verify dependencies
    try:
        auto_install_dependencies()
        import whisper  # importable only once the installation step has run
        from tqdm import tqdm  # progress bar used in step 5
    except Exception:
        print("❌ 依赖安装失败，程序无法继续。请检查错误信息。")
        return

    # Step 2: environment (logging, Drive, folders, API key)
    setup_environment()
    logger.info("🚀 启动 Whisper 批量转写增强版 v4.0")
    logger.info(f"项目文件夹: {PROJECT_FOLDER}")

    # Step 3: scan inputs recursively; sorted so the processing order is stable across runs
    logger.info(f"🔍 正在扫描输入文件夹: {INPUT_FOLDER}")
    valid_ext = {".mp3", ".mp4", ".m4a", ".wav", ".webm", ".aac", ".mov", ".wmv", ".flac", ".ogg"}
    all_files = sorted(
        os.path.join(root, name)
        for root, _, files in os.walk(INPUT_FOLDER)
        for name in files
        if os.path.splitext(name)[1].lower() in valid_ext
    )

    if not all_files:
        logger.warning(f"⚠️ 在 '{INPUT_FOLDER}' 中没有找到任何有效的音视频文件。")
        logger.warning("请确保您的音频文件已放入正确的子文件夹中。")
        return
    logger.info(f"📦 发现 {len(all_files)} 个有效文件。")

    # Step 4: progress manager and model
    progress_manager = ProgressManager(PROGRESS_FILE)
    logger.info(f"⏳ 正在加载 Whisper 模型: {MODEL_SIZE} (此过程可能需要几分钟)...")
    try:
        model = whisper.load_model(MODEL_SIZE)
        logger.info("✅ Whisper 模型加载完成。")
    except Exception as e:
        logger.error(f"❌ 加载 Whisper 模型失败: {e}。程序已终止。")
        return

    # Step 5: batch processing
    max_attempts = RETRY_COUNT + 1  # the first try plus RETRY_COUNT retries
    with tqdm(total=len(all_files), desc="总体进度", unit="个") as pbar:
        for file_path in all_files:
            filename = os.path.basename(file_path)
            pbar.set_description(f"处理: {filename[:25]:<25}")

            for attempt in range(1, max_attempts + 1):
                status = process_single_file(file_path, model, progress_manager)
                # Stop on any non-retryable outcome; never wait after the final attempt.
                if status != "error_retryable" or attempt == max_attempts:
                    break
                logger.warning(f"文件 {filename} 出现可恢复错误 (尝试 {attempt}/{max_attempts})，将在5秒后重试...")
                time.sleep(5)

            # Every file (skipped, succeeded or failed) advances the bar by exactly
            # one, so the bar finishes at len(all_files).
            pbar.update(1)
            current_success, current_failed = progress_manager.get_summary()
            pbar.set_postfix({"成功": current_success, "失败": current_failed})

    # Step 6: summary
    logger.info("🎉 批量处理全部完成！")
    final_success, final_failed = progress_manager.get_summary()
    logger.info(f"📊 最终统计: 成功 {final_success} 个, 失败 {final_failed} 个。")
    if final_failed > 0:
        logger.warning("--- 以下文件处理失败 ---")
        for details in progress_manager.data.get('failed_files', {}).values():
            logger.warning(f"  - 文件: {details['file_path']}")
            logger.warning(f"    原因: {details['reason']}")
        logger.warning("----------------------")


# Entry point: start the batch when executed as a script, or when this code
# runs as a notebook cell (where __name__ is also "__main__").
if __name__ == "__main__":
    main()