In [120]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '3'
import numpy as np

try:
    import tensorflow  # required in Colab to avoid protobuf compatibility issues
except ImportError:
    pass

import torch
import pandas as pd
import whisper
import torchaudio
from torchaudio.transforms import Resample,MelSpectrogram


from tqdm.notebook import tqdm


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [121]:
class LibriSpeech(torch.utils.data.Dataset):
    """
    A simple class to wrap LibriSpeech and trim/pad the audio to 30 seconds.
    It will drop the last few seconds of a very small portion of the utterances.
    """
    def __init__(self, split="test-clean", device=DEVICE):
        self.dataset = torchaudio.datasets.LIBRISPEECH(
            root=os.path.expanduser("~/.cache"),
            url=split,
            download=True,
        )
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        audio, sample_rate, text, _, _, _ = self.dataset[item]
        
        assert sample_rate == 16000
        audio = whisper.pad_or_trim(audio.flatten()).to(self.device)
        
        mel = whisper.log_mel_spectrogram(audio)
        
        return (mel, text)
    

# 예시 사용
# dataset = MyAudioDataset(device='cuda')
# loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True)
# for waveform, filename in loader:
#     print(filename)

In [122]:
default_dataset = LibriSpeech("test-clean")

In [123]:
default_dataset[0][0].shape

torch.Size([80, 3000])

In [124]:
default_dataset[1][0].shape

torch.Size([80, 3000])

In [125]:
class MyAudioDataset(torch.utils.data.Dataset):
    def __init__(self, directory="bigdata", device='cpu', sample_rate=16000, n_mels=80):
        self.directory = directory
        self.device = device
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.filenames = [f for f in os.listdir(directory) if f.endswith('.mp3')]
        self.mel_transformer = MelSpectrogram(sample_rate=self.sample_rate, n_mels=self.n_mels)

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        filename = self.filenames[idx]
        path = os.path.join(self.directory, filename)
        waveform, sample_rate = torchaudio.load(path)
        
        # 오디오를 모노로 변환
        waveform = waveform.mean(dim=0, keepdim=True)
        
        # 리샘플링
        if sample_rate != self.sample_rate:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.sample_rate)
            waveform = resampler(waveform)

#         # 멜 스펙트로그램 변환
#         mel = self.mel_transformer(waveform)

        audio = whisper.pad_or_trim(waveform.flatten()).to(self.device)
        
        mel = whisper.log_mel_spectrogram(audio)

        return mel, filename

In [126]:
dataset = MyAudioDataset(directory='./bigdata/processed_audio', device='cuda:0')
loader = torch.utils.data.DataLoader(dataset, batch_size=16)

In [127]:
dataset[0][0].shape

torch.Size([80, 3000])

In [128]:
dataset[1][0].shape

torch.Size([80, 3000])

In [129]:
model = whisper.load_model("base.en")
print(
    f"Model is {'multilingual' if model.is_multilingual else 'English-only'} "
    f"and has {sum(np.prod(p.shape) for p in model.parameters()):,} parameters."
)

Model is English-only and has 71,825,408 parameters.


In [130]:
# predict without timestamps for short-form transcription
options = whisper.DecodingOptions(language="en", without_timestamps=True)

In [131]:
hypotheses = []
references = []

for mels, texts in tqdm(loader):
    results = model.decode(mels, options)
    hypotheses.extend([result.text for result in results])
    references.extend(texts)

  0%|          | 0/1448 [00:00<?, ?it/s]

In [132]:
data = pd.DataFrame(dict(hypothesis=hypotheses, reference=references))
data


Unnamed: 0,hypothesis,reference
0,"Officer, to come at 6 today with a tight drive...",call_1_1.mp3
1,"59-13, carefully outside unit, and you have th...",call_1_2.mp3
2,Let me film you 980-9930,call_1_3.mp3
3,the column over side drive 15 A-13 or her neck.,call_1_4.mp3
4,"So,",call_1_5.mp3
...,...,...
23151,I'm going to get a right. Get out of the house...,call_99_37.mp3
23152,"I'm going to get out of that house, man. Okay,...",call_99_38.mp3
23153,I'm going to need to get out of there.,call_99_39.mp3
23154,Okay.,call_99_40.mp3


In [133]:
data.to_csv('./bigdata/result/15sec_whisper.csv', index=False)

In [None]:
# import jiwer
# from whisper.normalizers import EnglishTextNormalizer

# normalizer = EnglishTextNormalizer()

In [None]:
# data["hypothesis_clean"] = [normalizer(text) for text in data["hypothesis"]]
# data["reference_clean"] = [normalizer(text) for text in data["reference"]]
# data


In [None]:
# wer = jiwer.wer(list(data["reference_clean"]), list(data["hypothesis_clean"]))

# print(f"WER: {wer * 100:.2f} %")