In [1]:
import IPython.display
from pathlib import Path

import os
import numpy as np

import torch
from torch import nn
import pandas as pd
import whisper
import torchaudio
import torchaudio.transforms as at

from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

from tqdm.notebook import tqdm
import evaluate

from transformers import (
    AdamW,
    get_linear_schedule_with_warmup
)

In [None]:
import gdown
gdown.download('https://drive.google.com/u/0/uc?id=19oAw8wWn3Y7z6CKChRdAyGOB9yupL_Xt', 'jvs.zip', quiet=False)
!unzip jvs.zip -d ./jvs

In [2]:
DATASET_DIR = "jvs/jvs_ver1"
SAMPLE_RATE = 16000
BATCH_SIZE = 2
TRAIN_RATE = 0.8

AUDIO_MAX_LENGTH = 480000
TEXT_MAX_LENGTH = 120
SEED = 3407
DEVICE = "gpu" if torch.cuda.is_available() else "cpu"
seed_everything(SEED, workers=True)

Global seed set to 3407


3407

In [3]:
def load_wave(wave_path, sample_rate:int=16000) -> torch.Tensor:
    waveform, sr = torchaudio.load(wave_path, normalize=True)
    if sample_rate != sr:
        waveform = at.Resample(sr, sample_rate)(waveform)
    return waveform

In [4]:
dataset_dir = Path(DATASET_DIR)
transcripts_path_list = list(dataset_dir.glob("*/*/transcripts_utf8.txt"))
print(len(transcripts_path_list))

400


In [6]:
def get_audio_file_list(transcripts_path_list, text_max_length=120, audio_max_sample_length=480000, sample_rate=16000):
    audio_transcript_pair_list = []
    for transcripts_path in tqdm(transcripts_path_list):
        # audioファイルのディレクトリ確認
        audio_dir = transcripts_path.parent / "wav24kHz16bit"
        if not audio_dir.exists():
            print(f"{audio_dir}は存在しません。")
            continue

        # 翻訳テキストからAudioIdとテキストを取得
        with open(transcripts_path, "r") as f:
            text_list = f.readlines()
        for text in text_list:
            audio_id, text = text.replace("\n", "").split(":")
            #print(audio_id, text)

            audio_path = audio_dir / f"{audio_id}.wav"
            if audio_path.exists():
                # データのチェック
                audio = load_wave(audio_path, sample_rate=sample_rate)[0]
                if len(text) > text_max_length or len(audio) > audio_max_sample_length:
                    print(len(text), len(audio))
                    continue
                audio_transcript_pair_list.append((audio_id, str(audio_path), text))
    return audio_transcript_pair_list

train_num = int(len(transcripts_path_list) * TRAIN_RATE)
train_transcripts_path_list, eval_transcripts_path_list = transcripts_path_list[:train_num], transcripts_path_list[train_num:]
train_audio_transcript_pair_list = get_audio_file_list(train_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
eval_audio_transcript_pair_list = get_audio_file_list(eval_transcripts_path_list, TEXT_MAX_LENGTH, AUDIO_MAX_LENGTH, SAMPLE_RATE)
print("TRAIN AUDIO DATASET NUM: ", len(train_audio_transcript_pair_list))
print("EVAL AUDIO DATASET NUM: ", len(eval_audio_transcript_pair_list))

  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

TRAIN AUDIO DATASET NUM:  11982
EVAL AUDIO DATASET NUM:  3000


In [7]:
train_audio_transcript_pair_list

[('VOICEACTRESS100_001',
  'jvs/jvs_ver1/jvs094/whisper10/wav24kHz16bit/VOICEACTRESS100_001.wav',
  'また、東寺のように、五大明王と呼ばれる、主要な明王の中央に配されることも多い。'),
 ('VOICEACTRESS100_002',
  'jvs/jvs_ver1/jvs094/whisper10/wav24kHz16bit/VOICEACTRESS100_002.wav',
  'ニューイングランド風は、牛乳をベースとした、白いクリームスープであり、ボストンクラムチャウダーとも呼ばれる。'),
 ('VOICEACTRESS100_003',
  'jvs/jvs_ver1/jvs094/whisper10/wav24kHz16bit/VOICEACTRESS100_003.wav',
  'コンピュータゲームのメーカーや、業界団体などに関連する人物のカテゴリ。'),
 ('VOICEACTRESS100_004',
  'jvs/jvs_ver1/jvs094/whisper10/wav24kHz16bit/VOICEACTRESS100_004.wav',
  'サービスマネージャー導入駅のため、大井町駅から、遠隔管理している。'),
 ('VOICEACTRESS100_005',
  'jvs/jvs_ver1/jvs094/whisper10/wav24kHz16bit/VOICEACTRESS100_005.wav',
  'シルバーサーファー襲撃事件までに、リチャーズは、チーム名と共に、国際的にスーパーヒーロー、および、有名人として、認知されている。'),
 ('BASIC5000_0923',
  'jvs/jvs_ver1/jvs094/whisper10/wav24kHz16bit/BASIC5000_0923.wav',
  'ここから、ディズニーランド行きのバスに乗れると思います。'),
 ('BASIC5000_0090',
  'jvs/jvs_ver1/jvs094/whisper10/wav24kHz16bit/BASIC5000_0090.wav',
  'ガリレオは、二つの鉄の玉を、塔のてっぺんから落とした。'),
 ('UT

In [2]:
!pip3 install --upgrade setuptools

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting setuptools
  Downloading setuptools-68.1.2-py3-none-any.whl (805 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m805.1/805.1 KB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hInstalling collected packages: setuptools
  Attempting uninstall: setuptools
    Found existing installation: setuptools 59.6.0
    Uninstalling setuptools-59.6.0:
      Successfully uninstalled setuptools-59.6.0
Successfully installed setuptools-68.1.2


In [None]:
def text_kana_convert(text):
    text = pyopenjtalk.g2p(text, kana=True)
    return text
print(text_kana_convert("こんにちは、私の名前は、田中一郎です。"))

In [5]:
woptions = whisper.DecodingOptions(language="ja", without_timestamps=True)
wmodel = whisper.load_model("base")
wtokenizer = whisper.tokenizer.get_tokenizer(True, language="ja", task=woptions.task)

In [None]:
class JvsSpeechDataset(torch.utils.data.Dataset):
    def __init__(self, audio_info_list, tokenizer, sample_rate) -> None:
        super().__init__()

        self.audio_info_list = audio_info_list
        self.sample_rate = sample_rate
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.audio_info_list)
    
    def __getitem__(self, id):
        audio_id, audio_path, text = self.audio_info_list[id]

        # audio
        audio = load_wave(audio_path, sample_rate=self.sample_rate)
        audio = whisper.pad_or_trim(audio.flatten())
        mel = whisper.log_mel_spectrogram(audio)

        text = text_kana_convert(text)
        text = [*self.tokenizer.sot_sequence_including_notimestamps] + self.tokenizer.encode(text)
        labels = text[1:] + [self.tokenizer.eot]

        return {
            "input_ids": mel,
            "labels": labels,
            "dec_input_ids": text
        }

In [None]:
class WhisperDataCollatorWhithPadding:
    def __call__(sefl, features):
        input_ids, labels, dec_input_ids = [], [], []
        for f in features:
            input_ids.append(f["input_ids"])
            labels.append(f["labels"])
            dec_input_ids.append(f["dec_input_ids"])

        input_ids = torch.concat([input_id[None, :] for input_id in input_ids])
        
        label_lengths = [len(lab) for lab in labels]
        dec_input_ids_length = [len(e) for e in dec_input_ids]
        max_label_len = max(label_lengths+dec_input_ids_length)

        labels = [np.pad(lab, (0, max_label_len - lab_len), 'constant', constant_values=-100) for lab, lab_len in zip(labels, label_lengths)]
        dec_input_ids = [np.pad(e, (0, max_label_len - e_len), 'constant', constant_values=50257) for e, e_len in zip(dec_input_ids, dec_input_ids_length)] # 50257 is eot token id

        batch = {
            "labels": labels,
            "dec_input_ids": dec_input_ids
        }

        batch = {k: torch.tensor(np.array(v), requires_grad=False) for k, v in batch.items()}
        batch["input_ids"] = input_ids

        return batch

In [None]:
def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

In [None]:
dataset = JvsSpeechDataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

In [None]:
for b in loader:
    print(b["labels"].shape)
    print(b["input_ids"].shape)
    print(b["dec_input_ids"].shape)

    for token, dec in zip(b["labels"], b["dec_input_ids"]):
        token[token == -100] = wtokenizer.eot
        text = wtokenizer.decode(token, skip_special_tokens=False)
        print(text)

        dec[dec == -100] = wtokenizer.eot
        text = wtokenizer.decode(dec, skip_special_tokens=False)
        print(text)


    break

In [None]:
with torch.no_grad():
    audio_features = wmodel.encoder(b["input_ids"].cuda())
    input_ids = b["input_ids"]
    labels = b["labels"].long()
    dec_input_ids = b["dec_input_ids"].long()

        
    audio_features = wmodel.encoder(input_ids.cuda())
    print(dec_input_ids)
    print(input_ids.shape, dec_input_ids.shape, audio_features.shape)
    print(audio_features.shape)
    print()
out = wmodel.decoder(dec_input_ids.cuda(), audio_features)

In [None]:
print(out.shape)
print(out.view(-1, out.size(-1)).shape)
print(b["labels"].view(-1).shape)

In [None]:
tokens = torch.argmax(out, dim=2)
for token in tokens:
    token[token == -100] = wtokenizer.eot
    text = wtokenizer.decode(token, skip_special_tokens=True)
    print(text)

In [None]:
class Config:
    learning_rate = 0.0005
    weight_decay = 0.01
    adam_epsilon = 1e-8
    warmup_steps = 2
    batch_size = 16
    num_worker = 2
    num_train_epochs = 10
    gradient_accumulation_steps = 1
    sample_rate = SAMPLE_RATE

In [None]:
class WhisperModelModule(LightningModule):
    def __init__(self, cfg:Config, model_name="base", lang="ja", train_dataset=[], eval_dataset=[]) -> None:
        super().__init__()
        self.options = whisper.DecodingOptions(language=lang, without_timestamps=True)
        self.model = whisper.load_model(model_name)
        self.tokenizer = whisper.tokenizer.get_tokenizer(True, language="ja", task=self.options.task)

        # only decoder training
        for p in self.model.encoder.parameters():
            p.requires_grad = False
        
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics_wer = evaluate.load("wer")
        self.metrics_cer = evaluate.load("cer")

        self.cfg = cfg
        self.__train_dataset = train_dataset
        self.__eval_dataset = eval_dataset
    
    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_id):
        ###Mel spectogram
        input_ids = batch["input_ids"]
        
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()

        with torch.no_grad():
            audio_features = self.model.encoder(input_ids)

        out = self.model.decoder(dec_input_ids, audio_features)
        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))
        self.log("train/loss", loss, on_step=True, prog_bar=True, logger=True)
        return loss
    
    def validation_step(self, batch, batch_id):
        input_ids = batch["input_ids"]
        labels = batch["labels"].long()
        dec_input_ids = batch["dec_input_ids"].long()


        audio_features = self.model.encoder(input_ids)
        out = self.model.decoder(dec_input_ids, audio_features)

        loss = self.loss_fn(out.view(-1, out.size(-1)), labels.view(-1))

        out[out == -100] = self.tokenizer.eot
        labels[labels == -100] = self.tokenizer.eot

        o_list, l_list = [], []
        for o, l in zip(out, labels):
            o = torch.argmax(o, dim=1)
            o_list.append(self.tokenizer.decode(o, skip_special_tokens=True))
            l_list.append(self.tokenizer.decode(l, skip_special_tokens=True))
        cer = self.metrics_cer.compute(references=l_list, predictions=o_list)
        wer = self.metrics_wer.compute(references=l_list, predictions=o_list)

        self.log("val/loss", loss, on_step=True, prog_bar=True, logger=True)
        self.log("val/cer", cer, on_step=True, prog_bar=True, logger=True)
        self.log("val/wer", wer, on_step=True, prog_bar=True, logger=True)

        return {
            "cer": cer,
            "wer": wer,
            "loss": loss
        }

    def configure_optimizers(self):
        """オプティマイザーとスケジューラーを作成する"""
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() 
                            if not any(nd in n for nd in no_decay)],
                "weight_decay": self.cfg.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() 
                            if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, 
                          lr=self.cfg.learning_rate, 
                          eps=self.cfg.adam_epsilon)
        self.optimizer = optimizer

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.cfg.warmup_steps, 
            num_training_steps=self.t_total
        )
        self.scheduler = scheduler

        return [optimizer], [{"scheduler": scheduler, "interval": "step", "frequency": 1}]
    
    def setup(self, stage=None):
        """初期設定（データセットの読み込み）"""

        if stage == 'fit' or stage is None:
            self.t_total = (
                (len(self.__train_dataset) // (self.cfg.batch_size))
                // self.cfg.gradient_accumulation_steps
                * float(self.cfg.num_train_epochs)
            )
    
    def train_dataloader(self):
        """訓練データローダーを作成する"""
        dataset = JvsSpeechDataset(self.__train_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset, 
                          batch_size=self.cfg.batch_size, 
                          drop_last=True, shuffle=True, num_workers=self.cfg.num_worker,
                          collate_fn=WhisperDataCollatorWhithPadding()
                          )

    def val_dataloader(self):
        """バリデーションデータローダーを作成する"""
        dataset = JvsSpeechDataset(self.__eval_dataset, self.tokenizer, self.cfg.sample_rate)
        return torch.utils.data.DataLoader(dataset, 
                          batch_size=self.cfg.batch_size, 
                          num_workers=self.cfg.num_worker,
                          collate_fn=WhisperDataCollatorWhithPadding()
                          )
    
       

In [None]:
state_dict = torch.load(checkpoint_path)
print(state_dict.keys())
state_dict = state_dict['state_dict']

In [None]:
woptions = whisper.DecodingOptions(language="ja", without_timestamps=True)
dataset = JvsSpeechDataset(eval_audio_transcript_pair_list, wtokenizer, SAMPLE_RATE)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, collate_fn=WhisperDataCollatorWhithPadding())

refs = []
res = []
for b in tqdm(loader):
    input_ids = b["input_ids"].half().cuda()
    labels = b["labels"].long().cuda()
    with torch.no_grad():
        #audio_features = whisper_model.model.encoder(input_ids)
        #out = whisper_model.model.decoder(enc_input_ids, audio_features)
        results = whisper_model.model.decode(input_ids, woptions)
        for r in results:
            res.append(r.text)
        
        for l in labels:
            l[l == -100] = wtokenizer.eot
            ref = wtokenizer.decode(l, skip_special_tokens=True)
            refs.append(ref)

In [None]:
cer_metrics = evaluate.load("cer")
cer_metrics.compute(references=refs, predictions=res)

In [None]:
for k, v in zip(refs, res):
    print("-"*10)
    print(k)
    print(v)