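It's better to install Whisper in a separate environment from ESPnet. We use `faster-whisper` instead of the original `openai-whisper` package because it is not only faster but also more stable. Create the environment before running this notebook:
```shell
conda create -n whisper python=3.10
conda activate whisper
pip install faster-whisper jiwer num2words
```
Besides `faster-whisper`, the notebook imports `jiwer` and `num2words` (both installed above). The `LD_LIBRARY_PATH` snippet below imports `nvidia.cublas`, `nvidia.cudnn` and `torch`, so those packages must also be installed in this environment.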

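You also need to set the `LD_LIBRARY_PATH` environment variable so that faster-whisper can load the cuBLAS and cuDNN libraries at runtime. The snippet below writes a conda activation hook, so the variable is set every time the `whisper` environment is activated: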

```shell
conda activate whisper
env_loc="${CONDA_PREFIX}"  # root of the currently active (whisper) environment
activator="${env_loc}/etc/conda/activate.d/cuda_activate.sh"
mkdir -p "$(dirname "${activator}")"  # activate.d may not exist yet
echo export LD_LIBRARY_PATH=`python3 -c 'import os; import nvidia.cublas.lib; import nvidia.cudnn.lib; import torch; print(os.path.dirname(nvidia.cublas.lib.__file__) + ":" + os.path.dirname(nvidia.cudnn.lib.__file__) + ":" + os.path.dirname(torch.__file__) +"/lib")'` > "${activator}"
chmod +x "${activator}"
conda deactivate
conda activate whisper
```

In [1]:
import os
import re
import unicodedata
from pathlib import Path

import faster_whisper
import jiwer
from tqdm import tqdm

# Every location is derived from the notebook's working directory
# (Path.cwd() is the same value the IPython `%pwd` magic reports).
PWD = Path.cwd()
prosody_dir = PWD.parent

# Whisper results for CSS10 German are stored under ./outputs/CSS10/german.
outputs_dir = PWD / 'outputs'
asr_dir = outputs_dir / 'CSS10' / 'german'
os.makedirs(outputs_dir, exist_ok=True)
os.makedirs(asr_dir, exist_ok=True)

# Audio to transcribe: presumably the JETS-synthesized clips, plus the CSS10 German dataset.
jets_dir = prosody_dir / 'outputs' / 'tts_train_jets_raw_phn_tacotron_g2p_en_no_space/CSS10/german'
data_dir = (prosody_dir / '../../datasets/CSS10/german/').resolve()

In [2]:
# List the model names that faster-whisper can load by name.
faster_whisper.available_models()

['tiny.en',
 'tiny',
 'base.en',
 'base',
 'small.en',
 'small',
 'medium.en',
 'medium',
 'large-v1',
 'large-v2',
 'large-v3',
 'large']

In [2]:
import os

# faster-whisper needs the cuBLAS and cuDNN libraries on the dynamic-loader path
# (see the LD_LIBRARY_PATH setup at the top of the notebook). Fail early with a
# readable message instead of a bare KeyError when the variable is unset.
ld_lib_path = os.environ.get('LD_LIBRARY_PATH', '')
assert 'cublas' in ld_lib_path and 'cudnn' in ld_lib_path, (
    f'LD_LIBRARY_PATH must contain the cuBLAS and cuDNN library directories; got {ld_lib_path!r}'
)

In [4]:
# Load Whisper on the GPU in float16; its results go into a model-specific
# subdirectory of asr_dir.
WHISPER_MODEL = 'large-v2'
model = faster_whisper.WhisperModel(WHISPER_MODEL, device='cuda', compute_type='float16')
# Guard so re-running this cell does not append the model name again
# (large-v2/large-v2/...).
if asr_dir.name != WHISPER_MODEL:
    asr_dir = asr_dir / WHISPER_MODEL
os.makedirs(asr_dir, exist_ok=True)

In [25]:
# Collect every (lower-cased) character used in the transcript column of
# transcript.txt (third '|'-separated field) to see what normalize_german
# has to handle. The file contains non-ASCII German text, so read it as UTF-8
# explicitly instead of relying on the platform default encoding.
german_letters = set()
with open(data_dir / 'transcript.txt', encoding='utf-8') as f:
    for line in f:
        german_letters.update(line.split('|')[2].lower())

In [26]:
# Display the collected character inventory as one sorted string.
''.join(sorted(german_letters))

" !',-.:;?abcdefghijklmnopqrstuvwxyzßàäéöü–"

In [6]:
def _fold_accents_keep_umlauts(text):
    """Strip diacritics from ``text`` except a diaeresis on a, o or u (ä, ö, ü).

    The text is NFD-decomposed first, so precomposed ('é') and decomposed
    ('e' + U+0301) spellings are treated alike, and recomposed (NFC) at the end.
    Expects lower-case input (an upper-case 'Ä' would lose its umlaut).
    """
    kept = []
    for ch in unicodedata.normalize('NFD', text):
        if unicodedata.combining(ch):
            # Keep only a diaeresis (U+0308) sitting directly on a, o or u.
            if ch == '\u0308' and kept and kept[-1] in 'aou':
                kept.append(ch)
            continue
        kept.append(ch)
    return unicodedata.normalize('NFC', ''.join(kept))


def normalize_german(text):
    """Normalize German text so transcripts and ASR output can be compared.

    Steps:
      1. lower-case;
      2. fold accents other than the umlauts (é -> e, à -> a, è -> e, ...),
         accepting both precomposed and decomposed Unicode;
      3. turn the punctuation that occurs in the CSS10 transcripts into spaces;
      4. drop every remaining character outside a-z, ß, ä, ö, ü and space;
      5. collapse runs of spaces and trim.

    ß and the umlauts are deliberately kept as-is (not mapped to ss / ae, oe, ue).

    Returns:
        The normalized string ('' if nothing survives).
    """
    text = text.lower()
    text = _fold_accents_keep_umlauts(text)
    # '-' is escaped so it is a literal rather than part of a ',-.' range.
    text = re.sub(r"[!',\-.:;?–]", ' ', text)
    text = re.sub(r'[^a-zßäöü ]', '', text)
    return ' '.join(text.split())

In [2]:
# Normalized transcripts (filename|normalized text) are cached next to the dataset.
# NOTE: the cache is not invalidated when normalize_german changes -- delete the
# file to rebuild it.
transcript_file = data_dir / 'transcript_normalized.txt'
if not transcript_file.exists():
    # Write to a temporary file and only rename it into place once complete, so an
    # interrupted run cannot leave a truncated cache that would be reused next time.
    tmp_transcript_file = transcript_file.with_name(transcript_file.name + '.tmp')
    with open(data_dir / 'transcript.txt', encoding='utf-8') as f:
        with open(tmp_transcript_file, 'w', encoding='utf-8') as norm_f:
            for line in f:
                # Four '|'-separated fields; only the filename and the third one are used.
                filename, _, transcript, _ = line.split('|')
                norm_f.write(f'{filename}|{normalize_german(transcript)}\n')
    tmp_transcript_file.replace(transcript_file)


def get_transcripts(path=None):
    """Load normalized transcripts as a ``{filename: transcript}`` dict.

    Args:
        path: file with one ``filename|transcript`` entry per line, read as
            UTF-8. Defaults to the module-level ``transcript_file``.

    Returns:
        Dict in file order. Blank lines are skipped; each entry is split on the
        first '|' only, so the transcript itself may contain '|'.
    """
    if path is None:
        path = transcript_file
    transcripts = {}
    with open(path, encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            filename, transcript = line.split('|', maxsplit=1)
            transcripts[filename] = transcript
    return transcripts

# Reference transcripts keyed by wav path relative to the dataset directory.
transcripts = get_transcripts()

In [8]:
# Utterances in transcript-file order (dict iteration follows insertion order).
filenames = list(transcripts)

In [9]:
# Extra options forwarded to WhisperModel.transcribe. The prepend/append
# punctuation sets are emptied so faster-whisper does not merge punctuation into
# neighbouring words (presumably only relevant for word timestamps -- confirm).
# Other candidates, currently not overridden: 'suppress_tokens',
# 'temperature': 0.0, 'condition_on_previous_text': False.
whisper_kwargs = dict(
    prepend_punctuations='',
    append_punctuations='',
)

In [20]:
from num2words import num2words
import numpy as np


def _spell_out_numbers(text):
    """Replace every run of digits in ``text`` with its German spelling.

    For exact powers of ten from 100 upward, num2words yields "einhundert",
    "eintausend", ...; the leading "ein" is dropped there (100 -> "hundert",
    1000 -> "tausend"). Only those two prefixes are stripped, so a spelling that
    merely starts with "ein" (e.g. "eine Million") is not cut to "e Million".
    """
    def _digits_to_words(match):
        digits = match.group(0)
        word = num2words(digits, lang='de')
        if re.fullmatch(r'100+', digits) and word.startswith(('einhundert', 'eintausend')):
            word = word[len('ein'):]
        return word

    return re.sub(r'\d+', _digits_to_words, text)


def _transcribe_text(audio, kwargs):
    """Transcribe ``audio`` in German with the global ``model``; join segment texts with spaces."""
    segments, _ = model.transcribe(audio, language='de', **kwargs)
    return ' '.join(segment.text for segment in segments)


def whisper_transcribe(filepath, kwargs=whisper_kwargs):
    """Transcribe one German audio file and spell out any numbers.

    If the first attempt yields only an empty or whitespace-only result,
    transcription is retried up to nine times. Retry k (k = 1..9) prepends
    k * 100 zero samples to the decoded audio, stopping at the first non-blank
    result.

    Args:
        filepath: audio file accepted by ``model.transcribe`` and
            ``faster_whisper.audio.decode_audio``.
        kwargs: extra keyword arguments for ``model.transcribe``.

    Returns:
        The raw transcription with digit runs replaced by German words
        (not yet normalized; see normalize_german).
    """
    text = _transcribe_text(filepath, kwargs)
    if not text.strip():
        # Whisper sometimes randomly produces nothing; retry with increasing
        # leading silence. Decode once instead of on every attempt.
        audio = faster_whisper.audio.decode_audio(filepath)
        for pads in range(1, 10):
            audio_pad = np.pad(audio, (pads * 100, 0), mode='constant', constant_values=0)
            text = _transcribe_text(audio_pad, kwargs)
            if text.strip():
                break
    return _spell_out_numbers(text)

In [21]:
# Spot-check a single clip before the full runs below (each takes over an hour).
whisper_transcribe(data_dir / filenames[176])

' Aber die hundert Segel, die jetzt von Jabarze kommen, zeigen in den Segelfalten keine Schriftzeichen mehr.'

In [22]:
def run_asr(filenames, audio_dir, asr_result_path):
    """Transcribe each file and write one ``filename|normalized_text`` line per file.

    Args:
        filenames: paths relative to ``audio_dir``; also written as the key
            of each result line.
        audio_dir: directory the filenames are resolved against.
        asr_result_path: output file, overwritten. Written as UTF-8 because the
            normalized German text contains ß/ä/ö/ü.
    """
    with open(asr_result_path, 'w', encoding='utf-8') as f:
        for filename in tqdm(filenames):
            text = normalize_german(whisper_transcribe(audio_dir / filename))
            f.write(f'{filename}|{text}\n')

In [13]:
import logging

# Root logger: only WARNING and above are emitted during the long ASR runs.
logging.basicConfig(level='WARNING')

In [5]:
# Ground-truth recordings are read straight from the dataset directory.
gt_dir = data_dir

# run_asr output files (filename|text), one per audio source.
gt_asr_path, jets_asr_path = asr_dir / 'gt_result.txt', asr_dir / 'jets_result.txt'

In [24]:
# Transcribe the ground-truth recordings (about 68 minutes for 7427 clips in the recorded run).
run_asr(filenames, audio_dir=gt_dir, asr_result_path=gt_asr_path)

100%|██████████| 7427/7427 [1:07:54<00:00,  1.82it/s]


In [25]:
# Transcribe the JETS outputs (about 68 minutes for 7427 clips in the recorded run).
run_asr(filenames, audio_dir=jets_dir, asr_result_path=jets_asr_path)

100%|██████████| 7427/7427 [1:07:34<00:00,  1.83it/s]


In [7]:
def eval_cer(transcripts, asr_result_path, cer_path):
    """Write per-utterance character error rates (CER) to a CSV file.

    Spaces are removed from both reference and hypothesis before scoring, so
    word-boundary differences are not counted as errors.

    Args:
        transcripts: ``{wav_file: normalized reference transcript}``.
        asr_result_path: ``wav_file|asr_text`` lines (UTF-8) as written by run_asr.
        cer_path: output CSV with header ``wav_file,gt_len,cer``; overwritten.
            ``gt_len`` is the length of the space-free reference, i.e. the CER
            denominator, so it can be used directly as a weight when averaging.
            ``cer`` is ``nan`` when the reference is empty (CER is undefined).

    Raises:
        KeyError: if an ASR result refers to a file missing from ``transcripts``.
    """
    # Open the input first so a missing result file does not clobber cer_path.
    with open(asr_result_path, encoding='utf-8') as f:
        with open(cer_path, 'w', encoding='utf-8') as cer_file:
            cer_file.write('wav_file,gt_len,cer\n')
            for line in f:
                wav_file, asr_output = line.strip('\n').split('|', maxsplit=1)
                reference = transcripts[wav_file].replace(' ', '')
                hypothesis = asr_output.replace(' ', '')
                if reference:
                    # Positional args: jiwer renamed `truth` to `reference` in 3.x.
                    cer = jiwer.cer(reference, hypothesis)
                else:
                    cer = float('nan')  # jiwer raises on an empty reference
                cer_file.write(f'{wav_file},{len(reference)},{cer}\n')

In [27]:
# Per-utterance CER of Whisper on the ground-truth recordings.
gt_cer_path = asr_dir / 'gt_cer.csv'
eval_cer(transcripts, gt_asr_path, gt_cer_path)

In [8]:
# Per-utterance CER of Whisper on the JETS outputs.
jets_cer_path = asr_dir / 'jets_cer.csv'
eval_cer(transcripts, jets_asr_path, jets_cer_path)