In [1]:
!pip install transformers datasets jiwer librosa

Collecting jiwer
  Downloading jiwer-3.0.4-py3-none-any.whl.metadata (2.6 kB)
Collecting rapidfuzz<4,>=3 (from jiwer)
  Downloading rapidfuzz-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Downloading jiwer-3.0.4-py3-none-any.whl (21 kB)
Downloading rapidfuzz-3.10.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m31.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hInstalling collected packages: rapidfuzz, jiwer
Successfully installed jiwer-3.0.4 rapidfuzz-3.10.0


In [5]:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
import os
import csv
from jiwer import cer
import warnings

# Suppress FutureWarning noise emitted by the library calls below.
warnings.filterwarnings("ignore", category=FutureWarning)

# Use the first GPU in half precision when CUDA is available; otherwise CPU in full precision.
if torch.cuda.is_available():
    device, torch_dtype = "cuda:0", torch.float16
else:
    device, torch_dtype = "cpu", torch.float32

model_id = "openai/whisper-large-v3"

# Load the checkpoint directly in the chosen dtype, then move it to the chosen device.
model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    low_cpu_mem_usage=True,
    use_safetensors=True,
)
model.to(device)

# Tokenizer and feature extractor that belong to the same checkpoint.
processor = AutoProcessor.from_pretrained(model_id)

# Speech-recognition pipeline wired to the model/processor loaded above.
pipe = pipeline(
    task="automatic-speech-recognition",
    model=model,
    tokenizer=processor.tokenizer,
    feature_extractor=processor.feature_extractor,
    torch_dtype=torch_dtype,
    device=device,
)

def load_reference_from_csv(csv_file):
    """Load reference transcripts from a CSV into a ``{audio file name: text}`` dict.

    The CSV must have ``audio_file`` and ``stt_text`` columns. File names are
    normalised so they can be matched against ``.wav`` names found on disk:

    * a trailing ``'p'`` is dropped (presumably to repair names like ``'x.wavp'``
      — TODO confirm against the dataset);
    * an upper-case ``'WAV'`` suffix is lower-cased to ``'wav'``.

    If several rows normalise to the same name, the last one wins.

    Args:
        csv_file: path to the reference CSV.

    Returns:
        dict mapping the normalised audio file name to its ``stt_text`` value.
    """
    references = {}
    # 'utf-8-sig' transparently skips a leading BOM (common in spreadsheet exports);
    # with plain 'utf-8' the BOM ends up in the first header and row['audio_file']
    # raises KeyError. BOM-less files read exactly as before.
    # newline='' is required by the csv module so quoted fields keep embedded newlines.
    with open(csv_file, 'r', encoding='utf-8-sig', newline='') as f:
        for row in csv.DictReader(f):
            name = row['audio_file']
            if name.endswith('p'):
                name = name[:-1]
            elif name.endswith('WAV'):
                name = name[:-3] + 'wav'
            references[name] = row['stt_text']
    return references

def process_wav_folders(folder_paths, reference_csv, asr=None, language="korean"):
    """Transcribe every reference-matched ``.wav`` file under ``folder_paths``.

    Each folder is walked recursively. A ``.wav`` file's basename is looked up in
    the references loaded by ``load_reference_from_csv``. Only when a non-empty
    reference exists is the file sent through the ASR pipeline. Progress is
    printed every 10 matched files, and a running corpus CER every 100.

    Args:
        folder_paths: iterable of directory paths to scan.
        reference_csv: path to the CSV with ``audio_file`` / ``stt_text`` columns.
        asr: ASR callable taking ``(path, generate_kwargs=...)`` and returning a
            dict with a ``"text"`` key; defaults to the module-level ``pipe``.
        language: language passed to Whisper generation (default ``"korean"``).

    Returns:
        ``(all_references, all_hypotheses)``: two equal-length lists of strings,
        aligned by index.
    """
    if asr is None:
        asr = pipe  # module-level Whisper pipeline built in the setup cell

    all_hypotheses = []
    all_references = []

    # Load reference transcripts, keyed by audio file name.
    references = load_reference_from_csv(reference_csv)
    print(len(references))

    file_count = 0
    for folder_path in folder_paths:
        for root, _, files in os.walk(folder_path):
            # os.walk yields names in arbitrary order; sort for reproducible runs.
            for file in sorted(files):
                if not file.endswith('.wav'):
                    continue

                # Look the reference up BEFORE transcribing, so no full Whisper
                # decode is spent on files that cannot be scored.
                reference_text = references.get(file)
                if reference_text is None:
                    print(f"경고: {file}에 대한 참조 텍스트를 찾을 수 없습니다")
                    continue
                if not reference_text.strip():
                    # jiwer raises ValueError on empty references when scoring lists.
                    print(f"경고: {file}의 참조 텍스트가 비어 있어 건너뜁니다")
                    continue

                wav_path = os.path.join(root, file)
                result = asr(wav_path, generate_kwargs={"language": language})
                # Whisper output typically starts with a space; drop surrounding whitespace.
                hypothesis_text = result["text"].strip()

                all_references.append(reference_text)
                all_hypotheses.append(hypothesis_text)
                file_count += 1

                # Progress: running CER every 100 files, otherwise a plain count every 10
                # (elif avoids printing two progress lines at multiples of 100).
                if file_count % 100 == 0:
                    # Score per-utterance lists, not " ".join(...). Joining adds fake
                    # separator characters to the reference length, lets the alignment
                    # cross utterance boundaries, and runs Levenshtein on one huge string.
                    current_cer = cer(all_references, all_hypotheses)
                    print(f"{file_count}개 파일 처리 완료. 현재 CER: {current_cer:.4f}")
                elif file_count % 10 == 0:
                    print(f"{file_count}개 파일 처리 완료")

    return all_references, all_hypotheses

# Main entry point. The guard is always true inside a notebook cell; it is kept so the
# same code can also run as a script.
if __name__ == "__main__":
    # Validation sub-folders to evaluate; uncomment others to include them.
    folder_paths = [
        # "/kaggle/input/val-wav/Validation/AI스피커",
        # "/kaggle/input/val-wav/Validation/AI챗봇",
        "/kaggle/input/val-wav/Validation/스튜디오",
        # "/kaggle/input/val-wav/Validation/음성수집도구",
    ]
    # CSV file holding the reference transcripts.
    reference_csv = "/kaggle/input/val-wav/audio_text_pairs_with_labels.csv"

    references, hypotheses = process_wav_folders(folder_paths, reference_csv)

    # Overall corpus CER. Score per-utterance lists rather than one " "-joined string
    # (joining inflates the reference length with separators and lets the alignment
    # cross utterance boundaries). Drop empty references first, because jiwer raises
    # ValueError on them; an empty result set would also make jiwer raise.
    scored_pairs = [
        (ref, hyp) for ref, hyp in zip(references, hypotheses) if ref and ref.strip()
    ]
    if scored_pairs:
        scored_refs = [ref for ref, _ in scored_pairs]
        scored_hyps = [hyp for _, hyp in scored_pairs]
        overall_cer = cer(scored_refs, scored_hyps)
        print(f"전체 CER: {overall_cer:.4f}")
    else:
        print("전체 CER: 채점할 파일이 없습니다")

    # Save every reference/hypothesis pair to a text file.
    with open("transcription_results.txt", "w", encoding="utf-8") as f:
        for ref, hyp in zip(references, hypotheses):
            f.write(f"참조: {ref}\n")
            f.write(f"가설: {hyp}\n")
            f.write("\n")

    print("결과가 transcription_results.txt 파일에 저장되었습니다.")

159970
노인남여_노인대화07_F_LSS00_65_수도권_녹음실_08369.wav
노인남여_노인대화07_M_HHS00_60_강원_녹음실_08421.wav
노인남여_노인대화07_M_KMS00_64_수도권_녹음실_08484.wav
노인남여_노인대화07_F_HJS00_64_수도권_녹음실_08634.wav
노인남여_노인대화07_F_JML00_61_수도권_녹음실_08545.wav
노인남여_노인대화07_M_KMS00_64_수도권_녹음실_08493.wav
노인남여_노인대화07_F_JML00_61_수도권_녹음실_08021.wav
노인남여_노인대화07_F_JML00_61_수도권_녹음실_07611.wav
노인남여_노인대화07_F_LAJ00_63_수도권_녹음실_08250.wav
노인남여_노인대화07_F_CSO00_62_수도권_녹음실_08103.wav
10개 파일 처리 완료
노인남여_노인대화07_F_YEL00_62_수도권_녹음실_07720.wav
노인남여_노인대화07_F_JYJ00_61_수도권_녹음실_07712.wav
노인남여_노인대화07_F_KMK00_61_수도권_녹음실_07575.wav
노인남여_노인대화07_F_PYJ00_67_수도권_녹음실_07432.wav
노인남여_노인대화07_F_HJS00_64_수도권_녹음실_07410.wav
노인남여_노인대화07_F_HJS00_64_수도권_녹음실_08506.wav
노인남여_노인대화07_F_PYJ00_67_수도권_녹음실_07621.wav
노인남여_노인대화07_F_LSS00_65_수도권_녹음실_08472.wav
노인남여_노인대화07_M_LK000_61_수도권_녹음실_07933.wav
노인남여_노인대화07_F_JML00_61_수도권_녹음실_08650.wav
20개 파일 처리 완료
노인남여_노인대화07_F_CSO00_62_수도권_녹음실_07486.wav
노인남여_노인대화07_M_HHS00_60_강원_녹음실_08233.wav
노인남여_노인대화07_F_LSS00_65_수도권_녹음실_08390.wav
노인남여_노인대화07_M_KMS00_64_수도권

KeyboardInterrupt: 