In [44]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import os
from pyannote.audio import Pipeline
from pydub import AudioSegment
from dotenv import load_dotenv
from openai import OpenAI
from tempfile import TemporaryDirectory
from utils import load_yaml

# Load variables from a .env file into the process environment at import time.
# NOTE(review): presumably this supplies API credentials (e.g. for OpenAI()) — confirm.
load_dotenv()


class STT:
    """Speech-to-text runner built on Whisper.

    Transcription is done either with a locally loaded
    ``openai/whisper-large-v3`` model or, when ``use_api`` is true, through the
    OpenAI transcription API. Audio files are read from
    ``<base_path>/<category_id>/audio`` and transcripts are written to
    ``<base_path>/<category_id>/origin_txt``.
    """

    def __init__(self, config, diarization_file=None, use_api=False):
        """Read settings and, for local mode, load the Whisper model.

        Args:
            config: Mapping with a ``"settings"`` entry providing ``device``,
                ``category_id`` and ``base_path``.
            diarization_file: Path of the audio file used by
                ``transcribe_with_speaker_diarization``.
            use_api: When true, ``run_stt`` transcribes through the OpenAI API
                and the local model is not loaded up front.
        """
        self.diarization_file = diarization_file
        self.use_api = use_api

        settings = config["settings"]
        self.device = settings["device"]
        self.category_id = settings["category_id"]
        self.base_path = settings["base_path"]
        self.folder_path = os.path.join(self.base_path, self.category_id)

        # Half precision on GPU-style devices, full precision otherwise.
        self.torch_dtype = (
            torch.float16 if self.device in ("mps", "cuda") else torch.float32
        )
        self.model_id = "openai/whisper-large-v3"

        # The checkpoint is large; in API mode it is only loaded if something
        # actually asks for ``self.model``.
        self._model = None
        if not self.use_api:
            self._model = self._load_model()

    def _load_model(self):
        """Load the local Whisper checkpoint identified by ``self.model_id``."""
        return AutoModelForSpeechSeq2Seq.from_pretrained(
            self.model_id,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )

    @property
    def model(self):
        """Local Whisper model, loaded on first access if not loaded yet."""
        if getattr(self, "_model", None) is None:
            self._model = self._load_model()
        return self._model

    @model.setter
    def model(self, value):
        self._model = value

    @staticmethod
    def _segment_bounds(total_ms, segment_length_ms):
        """Return ``(start, end)`` millisecond pairs that tile ``total_ms``.

        ``segment_length_ms`` must be positive. The last pair is clipped to
        ``total_ms``; a zero-length input yields an empty list.
        """
        return [
            (start, min(start + segment_length_ms, total_ms))
            for start in range(0, total_ms, segment_length_ms)
        ]

    def _build_local_pipeline(self):
        """Create the local automatic-speech-recognition pipeline."""
        self.model.to(self.device)
        processor = AutoProcessor.from_pretrained(self.model_id)
        return pipeline(
            "automatic-speech-recognition",
            model=self.model,
            tokenizer=processor.tokenizer,
            feature_extractor=processor.feature_extractor,
            max_new_tokens=128,
            chunk_length_s=30,
            batch_size=16,
            torch_dtype=self.torch_dtype,
            device=self.device,
        )

    def transcribe_with_speaker_diarization(self):
        """Diarize ``self.diarization_file`` and transcribe each speaker turn.

        The Hugging Face token is read from the ``HF_TOKEN`` (or
        ``HUGGING_FACE_HUB_TOKEN``) environment variable.

        Returns:
            A list of dicts with ``speaker``, ``start``, ``end`` (seconds, as
            reported by the diarization turns) and ``text``.

        Raises:
            ValueError: If no diarization file was configured.
            RuntimeError: If the diarization pipeline could not be loaded.
        """
        if not self.diarization_file:
            raise ValueError(
                "diarization_file must be set to run speaker diarization"
            )

        # Read the token from the environment instead of a hard-coded
        # placeholder; None lets the hub client fall back to its own login.
        hf_token = os.environ.get("HF_TOKEN") or os.environ.get(
            "HUGGING_FACE_HUB_TOKEN"
        )
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=hf_token
        )
        if diarization_pipeline is None:
            # pyannote reports gated/unauthorised models by returning None.
            raise RuntimeError(
                "Could not load 'pyannote/speaker-diarization'; "
                "check the Hugging Face token and model access."
            )
        diarization = diarization_pipeline(self.diarization_file)

        transcription_pipeline = pipeline(
            "automatic-speech-recognition", model=self.model_id
        )
        audio = AudioSegment.from_file(self.diarization_file)

        results = []
        # Segments go to a private temp directory so nothing is left behind in
        # the working directory and concurrent runs cannot clash.
        with TemporaryDirectory() as tmp_dir:
            segment_path = os.path.join(tmp_dir, "temp_segment.wav")
            for turn, _, speaker in diarization.itertracks(yield_label=True):
                # Extract the audio for this turn (pydub slices in ms).
                start_ms = int(turn.start * 1000)
                end_ms = int(turn.end * 1000)
                segment = audio[start_ms:end_ms]

                # The pipeline is given a file path; export() hands back an
                # open handle, which is closed right away.
                segment.export(segment_path, format="wav").close()

                # Transcribe the segment with Whisper.
                transcript = transcription_pipeline(segment_path)["text"]

                results.append(
                    {
                        "speaker": speaker,
                        "start": turn.start,
                        "end": turn.end,
                        "text": transcript,
                    }
                )

        return results

    def whisper_api(self, file_path, segment_length_ms):
        """Transcribe ``file_path`` through the OpenAI API in fixed-size chunks.

        Args:
            file_path: Path of the audio file to transcribe.
            segment_length_ms: Maximum chunk length in milliseconds.

        Returns:
            The chunk transcripts joined with single spaces. Audio shorter
            than one chunk is transcribed as a single chunk.

        Raises:
            ValueError: If ``segment_length_ms`` is not positive.
        """
        if segment_length_ms <= 0:
            raise ValueError("segment_length_ms must be a positive integer")

        client = OpenAI()
        audio = AudioSegment.from_file(file_path)

        txt_list = []
        with TemporaryDirectory() as tmp_dir:
            bounds = self._segment_bounds(len(audio), segment_length_ms)
            for i, (start, end) in enumerate(bounds):
                segment_path = os.path.join(tmp_dir, f"audio_{i}.mp3")
                # export() returns an open handle; close it before re-reading.
                audio[start:end].export(segment_path, format="mp3").close()
                with open(segment_path, "rb") as audio_file:
                    transcription = client.audio.transcriptions.create(
                        model="whisper-1",
                        file=audio_file,
                        response_format="text",
                    )
                txt_list.append(transcription)

        return " ".join(txt_list)

    def run_stt(self):
        """Transcribe every unprocessed file in the audio folder.

        A file counts as processed when ``origin_txt`` already holds a ``.txt``
        named after the part of the audio file name before its first dot.
        The local pipeline is only built when a file actually needs it.
        """
        audio_dir = os.path.join(self.folder_path, "audio")
        output_dir = os.path.join(self.folder_path, "origin_txt")
        os.makedirs(output_dir, exist_ok=True)

        audio_list = [x for x in os.listdir(audio_dir) if not x.startswith(".")]

        print(f"총 오디오파일의 개수는 {len(audio_list)}개 입니다.")

        # List the outputs once instead of on every iteration.
        processed = set(os.listdir(output_dir))
        pipe = None

        for audio in audio_list:
            # NOTE(review): names sharing the text before the first dot
            # (a.b.mp3 / a.c.mp3) map to the same output; kept for
            # compatibility with already written transcripts.
            txt_name = f"{audio.split('.')[0]}.txt"
            if txt_name in processed:
                print(f"{audio}는 이미 처리된 파일입니다")
                continue

            audio_path = os.path.join(audio_dir, audio)

            if self.use_api:
                segment_length_ms = 5 * 60 * 1000
                result = self.whisper_api(audio_path, segment_length_ms)
            else:
                if pipe is None:
                    pipe = self._build_local_pipeline()
                result = pipe(audio_path)["text"]

            with open(
                os.path.join(output_dir, txt_name), "w", encoding="utf-8"
            ) as file:
                file.write(result)
            processed.add(txt_name)
            print(f"{audio}가 txt로 변환되었습니다")

In [None]:
# Load the STT configuration; STT.__init__ reads its "settings" entry.
# NOTE(review): load_yaml comes from the project's utils module — presumably it parses the YAML file into a dict.
config = load_yaml("../config/stt.yaml")
# use_api=True makes run_stt transcribe through STT.whisper_api.
stt = STT(config, use_api=True)

# Run STT
stt.run_stt()