# Whisperの動作確認

In [37]:
DATASET_DIR = "./data/speakers/aqua"
RAW_DATASET_DIR = f"{DATASET_DIR}/raw"
SPLIT_AUDIO_DIR = f"{DATASET_DIR}/split_audio"
SPLIT_TEXT_DIR = f"{DATASET_DIR}/split_text"

In [2]:
import os
from pathlib import Path
if Path(os.getcwd()).stem == "whisper":
    %cd ../../
!pwd

/workspace
/workspace


In [38]:
import whisper
import math
import numpy as np
import torch
import torchaudio
from dataclasses import dataclass

In [13]:
from src.whisper.stable_whisper import (
    modify_model,
    stabilize_timestamps
)

In [5]:
WHISPER_MODEL = "large"
WHISPER_LANG = "ja"
model = whisper.load_model(WHISPER_MODEL, "cuda")

In [6]:
modify_model(model)

In [7]:
dataset_dir = Path(RAW_DATASET_DIR)
fp_list = list(dataset_dir.glob("*.wav")) + list(dataset_dir.glob("*.mp3"))
fp_list = sorted(fp_list)
print(len(fp_list))

59


In [55]:
@dataclass
class AudioData():
    audio: np.ndarray
    sr: int
    
    def save_audio(self, fp):
        torchaudio.save(str(fp), torch.Tensor(self.audio).unsqueeze(0), self.sr)
    
    @classmethod
    def pick_audio(cls, audio, start_time, end_time):
        start_ind = max(0, math.floor(audio.sr*start_time))
        end_ind = min(len(audio.audio), math.ceil(audio.sr*end_time))
        return cls(
            audio.audio[start_ind:end_ind],
            audio.sr
        )

In [60]:
@dataclass
class SegmentAudio():
    id: int
    audio: AudioData
    text: str
    start: float
    end: float
    stop: False
    
    def save_text(self, fp):
        with open(str(fp), "w") as f:
            f.write(self.text)
            
    @classmethod
    def from_whisper_result(cls, segment_dic):
        """whisperの検出結果から音声を抽出"""
        id = segment_dic["id"]
        start = segment_dic["start"]
        end = segment_dic["end"]
        text = segment_dic["text"]
        #split_audio = audio.pick_audio(start, end)
        stop = False
        stop_word_list = ("。", "?", "!")
        for stop_word in stop_word_list:
            if len(stop_word) < len(text) and stop_word == text[-len(stop_word):]:
                stop = True
        return cls(
            id,
            AudioData(np.array([]), 0),
            text,
            start,
            end,
            stop
        )
    def update_pick_audio(self, audio: AudioData, end_time_room=0.005):
        """whisperの検出結果から音声を抽出"""
        
        self.audio = AudioData.pick_audio(audio, self.start, self.end+end_time_room)
    
    def is_error(self):
        if self.start == self.end:
            return True
    
    def skip(self, skip_time=0.5):
        if self.end - self.start < skip_time:
            return True
        return False
    
    def is_continuos_segment(self, segment, continuos_time=0.3):
        if self.stop:
            return False
        
        if self.end + continuos_time < segment.start:
            return True
        return False
    
    def merge_segment(self, segment):
        if not self.is_continuos_segment(segment):
            raise ValueError(f"連続したセグメントではありません。")
        self.text += segment.text
        self.end = segment.end
        self.stop = segment.stop
        
    def __repr__(self):
        return f"ID:{self.id:05d} {self.start:3f}-{self.end:3f}: {self.text}"

In [62]:
def transcribe(model: whisper, fp, lang="ja"):
    fp = Path(fp)
    if not fp.is_file():
        raise FileExistsError(f"{fp}は、存在しません。")
    audio = whisper.load_audio(str(fp))
    audio = whisper.pad_or_trim(audio)   
    res = model.transcribe(audio, language=lang,  verbose=False)
    segment_result = stabilize_timestamps(res, top_focus=True)
    raw_audio = AudioData(audio, whisper.audio.SAMPLE_RATE)
    segment_list = []
    for seg in segment_result:
        seg = SegmentAudio.from_whisper_result(seg)
        if len(segment_list) == 0:
            segment_list.append(
                seg
            )
            continue
        if segment_list[-1].is_continuos_segment(seg):
            segment_list[-1].merge_segment(seg)
        else:
            segment_list.append(seg)
    output_list = []
    for o in segment_list:
        o.update_pick_audio(raw_audio)
        output_list.append(o)
    return output_list

count = 0
for fp in fp_list[:15]:
    print(fp)
    output_list = transcribe(model, fp, lang=WHISPER_LANG)
    print("#", output_list)
    for o in output_list:
        count += 1
        o.save_text(f"{SPLIT_TEXT_DIR}/{count:010d}.txt")
        o.audio.save_audio(f"{SPLIT_AUDIO_DIR}/{count:010d}.wav")

data/speakers/aqua/raw/10時ですけど.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.960000-2.240000: 十時ですけど…, ID:00001 2.240000-3.280000: ご主人、何してるんですか?]
data/speakers/aqua/raw/11時〜….wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.000000-5.620000: 11時お腹すいたよご主人]
data/speakers/aqua/raw/12時！.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.800000-6.000000: 12時お昼〜ご飯ご飯ご主人, ID:00004 6.000000-7.000000: ご飯まだ〜?]
data/speakers/aqua/raw/13時です！.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.980000-2.980000: 13時です, ID:00001 2.980000-4.980000: 昼からも頑張っていきましょうね]
data/speakers/aqua/raw/14時.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.640000-2.440000: 14時…, ID:00001 2.440000-4.120000: ふぅ…, ID:00002 4.120000-5.060000: なんだかお昼寝したくなってきた…]
data/speakers/aqua/raw/15時です.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.800000-2.480000: 15時です。, ID:00001 2.480000-5.280000: ご主人様、そろそろお茶にしませんか?]
data/speakers/aqua/raw/16時！.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.640000-1.320000: 16時!, ID:00001 1.320000-1.340000: ゲームしてたらいつの間にかこんな時間に]
data/speakers/aqua/raw/17時ですよ.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.680000-2.520000: 17時ですよご主人, ID:00001 2.520000-4.520000: 今日のご飯は何かな?]
data/speakers/aqua/raw/18時です.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.640000-2.160000: 18時です、ご主人!, ID:00001 2.160000-3.300000: 今日、寿司が食べたいです。]
data/speakers/aqua/raw/19時になりました！.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.000000-5.300000: 19時になりました。ご主人、遊んでください!]
data/speakers/aqua/raw/1時ですよ.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 1.040000-3.060000: 一時ですよご主人, ID:00001 3.060000-3.080000: まだまだ夜はこれからですね]
data/speakers/aqua/raw/20時！.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.400000-2.400000: 二十時!, ID:00001 2.400000-2.900000: ご主人、今何考えてました?]
data/speakers/aqua/raw/21時.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.880000-4.880000: 21時もうすぐ1日が終わっちゃう, ID:00002 4.880000-5.880000: やることちゃんと終わった?]
data/speakers/aqua/raw/22時です。.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.560000-5.560000: 22時ですご主人様眠れないなら, ID:00003 5.560000-6.560000: それでも, ID:00004 6.560000-7.960000: 眠れして差し上げますよ]
data/speakers/aqua/raw/23時だよ！.wav


Use audio_for_mask for transcribe() to provide the original audio track as the path or bytes of the audio file.
  wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))


# [ID:00000 0.560000-2.480000: 23時だよ!, ID:00001 2.480000-4.320000: ご主人ゲームしよ!]
