In [None]:
# %load_ext autoreload
# %autoreload 2

In [None]:
import os
from collections import defaultdict

import pandas as pd
import torch
import torchaudio as ta
from data_utils import SpeechDataset, filter_yt_df, yt_data_to_df
from eval_utils import _normalize_text, _wer
from IPython.display import Audio
from tqdm import tqdm

In [None]:
# Sample rate in Hz used for resampling and passed to every processor/dataset below.
target_sr = 16_000

In [None]:
# Root folder with the downloaded data. The location is machine specific, so it can be
# overridden with the STT_VIDEO_DIR environment variable; the original path stays the default.
video_dir = os.environ.get('STT_VIDEO_DIR', r'F:/BIG_FILES/AI_DATA/2024_STT')
if not os.path.isdir(video_dir):
    # Fail early with the offending path in the message instead of a separate print.
    raise FileNotFoundError(f'Video directory not found: {video_dir}')

# Cache folder handed to every `from_pretrained` call below.
# `exist_ok=True` replaces the check-then-create sequence.
cache_dir = './cache'
os.makedirs(cache_dir, exist_ok=True)

In [None]:
# Build the three data frames from the download folder (transcripts are loaded as well).
# `yt_df` is the frame the filtering below works on.
yt_df, video_df, segment_df = yt_data_to_df(video_dir, do_load_transcripts=True)

# Show the first three rows of each frame, then the overall sizes.
display(*[frame.head(3) for frame in (yt_df, video_df, segment_df)])
print(
    f'Number of videos: {len(video_df)}',
    f'Number of segments: {len(segment_df)}',
    sep='\n',
)

In [None]:
# --- Overview of the unfiltered data -------------------------------------------------------
print(f"Original number of segments: {len(yt_df)}")
print(f'Number of german segments: {yt_df.language.eq("de").sum()}')
print(f'Number of english segments: {yt_df.language.eq("en").sum()}')
print(f"Number of auto generated segments: {yt_df.is_generated.eq(True).sum()}")
print(f"Number of manual segments: {yt_df.is_generated.eq(False).sum()}")
# Durations are summed and divided by 3600 to report hours.
print(f"Total duration: {yt_df.segment_duration.sum() / 3600:.2f}h")
print(f"Total duration (manual): {yt_df.loc[yt_df.is_generated.eq(False), 'segment_duration'].sum() / 3600:.2f}h")
print(f"Valid audio segments: {yt_df.loc[yt_df.valid_audio, 'segment_id'].count()} / {len(yt_df)}")

# --- Filter settings (None disables the respective bound) ----------------------------------
min_segment_length = None
max_segment_length = 30
target_language = "de"
use_auto_generated = False
min_words = None
max_words = None
# Columns that are no longer needed once filtering is done.
drop_columns = ["language", "is_generated", "num_segments", "segment_durations", "segment_id", "valid_audio"]

# Apply the filter, then keep only rows flagged as having valid audio.
yt_df_filtered = (
    filter_yt_df(
        yt_df,
        min_segment_length=min_segment_length,
        max_segment_length=max_segment_length,
        language=target_language,
        use_auto_generated=use_auto_generated,
        min_words=min_words,
        max_words=max_words,
    )
    .loc[lambda frame: frame.valid_audio]
    .reset_index(drop=True)
)
print(f"Filtered number of segments: {len(yt_df_filtered)}")
print(f"Total duration: {yt_df_filtered.segment_duration.sum() / 3600:.2f}h")

yt_df_filtered = yt_df_filtered.drop(columns=drop_columns)
yt_df_filtered.head(3)

## Exploratory data analysis (EDA)

In [None]:
# Distribution of segment durations in the filtered data.
yt_df_filtered.hist(
    column='segment_duration',
    bins=30,
    figsize=(10, 5),
    grid=False,
    color='#86bf91',
    zorder=2,
    rwidth=0.9,
)

In [None]:
# Distribution of the word count per segment in the filtered data.
yt_df_filtered.hist(
    column='num_words',
    bins=30,
    figsize=(10, 5),
    grid=False,
    color='#fe3e12',
    zorder=2,
    rwidth=0.9,
)

In [None]:
# Draw one random segment (unseeded, so the pick changes between runs) and load its audio.
# NOTE: `res_wave` and `transcript` defined here are reused by the SeamlessM4Tv2 demo cell.
sample_segment = yt_df_filtered.sample(n=1).iloc[0]
audio_file_path, transcript = sample_segment['segment_path'], sample_segment['transcript']
wave, sr = ta.load(audio_file_path)
print(f'Loaded audio file: {audio_file_path}')
print(f'Wave shape: {wave.shape}')
print(f'Sample rate: {sr}')
display(Audio(wave.numpy(), rate=sr))

# Resample from the file's native rate to the rate used for the models.
res_wave = ta.transforms.Resample(orig_freq=sr, new_freq=target_sr)(wave)
print(f'Wave shape after resampling: {res_wave.shape}')
display(Audio(res_wave.numpy(), rate=target_sr))

# Print the transcript wrapped at a fixed number of words per line.
words = transcript.split()
words_per_line = 15
for line_start in range(0, len(words), words_per_line):
    line_words = words[line_start:line_start + words_per_line]
    print(' '.join(line_words))

# Model evaluation

In [None]:
def show_output(wave, sr, gt, decoded_output):
    """Render an audio player and print reference vs. prediction together with the WER.

    Args:
        wave: Audio tensor; converted with ``.numpy()`` for the player.
        sr: Sample rate handed to the audio player.
        gt: Ground-truth transcript.
        decoded_output: Transcript produced by the model (stripped only for the raw line).
    """
    display(Audio(wave.numpy(), rate=sr))
    report_lines = [
        f"Ground Truth: {gt}",
        f"Predicted:    {decoded_output.strip()}",
        "---",
        f"Normalized Ground Truth: {_normalize_text(gt)}",
        f"Normalized Predicted:    {_normalize_text(decoded_output)}",
        # _wer receives the raw strings plus the normalizer to apply.
        f'WER: {_wer(gt, decoded_output, _normalize_text)}',
    ]
    print("\n".join(report_lines))

In [None]:
def custom_collate_fn(x):
    """Identity collate function: return the batch exactly as it was received.

    Passed as ``collate_fn`` to the DataLoaders in this notebook so that the inference
    helpers get the individual sample dicts instead of a pre-collated structure.
    """
    return x

## Whisper

In [None]:
from transformers import (AutoProcessor, WhisperConfig,
                          WhisperForConditionalGeneration)

# Checkpoint evaluated in this section.
MODEL_ID = "openai/whisper-small"

# Half precision when CUDA is available, full precision otherwise.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

In [None]:
# Load config, weights (in the dtype chosen above) and the matching processor.
whisper_config = WhisperConfig.from_pretrained(MODEL_ID, cache_dir=cache_dir)
whisper_model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_ID,
    config=whisper_config,
    cache_dir=cache_dir,
    torch_dtype=torch_dtype,
)
# `torch_dtype` already falls back to float32 when no GPU is present, so the device has to
# fall back as well; unconditionally moving to 'cuda' raised on CPU-only machines.
whisper_model.eval().to('cuda' if torch.cuda.is_available() else 'cpu')
whisper_processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=cache_dir)

In [None]:
# Keyword arguments for the processor call (assumed to be forwarded by SpeechDataset — TODO confirm).
processor_args = dict(return_tensors="pt", sampling_rate=target_sr)

whisper_dataset = SpeechDataset(yt_df_filtered, whisper_processor, processor_args, target_sr)
# Fetch one item up front as a quick check that the dataset can be indexed.
whisper_sample = whisper_dataset[0]
# Fixed order (no shuffling); the identity collate keeps the per-sample dicts.
whisper_loader = torch.utils.data.DataLoader(
    whisper_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

In [None]:
def whisper_inferece(model, batch, processor, language="de"):
    """Transcribe one batch of samples with a Whisper generation model.

    Args:
        model: Model exposing ``generate``, ``device`` and ``dtype``; the inputs are moved
            to ``model.device`` and cast to ``model.dtype``.
        batch: List of sample dicts, each holding an ``"input_features"`` tensor. The
            tensors are stacked and axis 1 is squeezed (assumes each sample carries a
            leading axis of size 1 — TODO confirm against SpeechDataset).
        processor: Object whose ``batch_decode`` turns generated ids into strings.
        language: Language passed to ``model.generate``; defaults to ``"de"`` as before.

    Returns:
        list[str]: One decoded transcript per sample, special tokens skipped.
    """
    input_features = torch.stack([sample["input_features"] for sample in batch]).squeeze(1)
    # Take device and dtype from the model instead of the hard-coded 'cuda' and the
    # notebook-global `torch_dtype`: that global is re-assigned by a later section, so
    # re-running this function afterwards would cast the inputs to the wrong dtype.
    input_features = input_features.to(device=model.device, dtype=model.dtype)
    with torch.no_grad():
        generated_ids = model.generate(input_features, language=language)
    # Locals are released on return; no explicit `del` is needed.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
# Run whisper-small over the whole loader and collect one record per segment.
whisper_results = defaultdict(list)
for batch in tqdm(whisper_loader):
    decoded_outputs = whisper_inferece(whisper_model, batch, whisper_processor)
    for sample, decoded_output in zip(batch, decoded_outputs):
        record = {
            'decoded_output': decoded_output.strip(),
            'normalized_decoded_output': _normalize_text(decoded_output),
            'gt': sample['transcript'].strip(),
            'normalized_gt': _normalize_text(sample['transcript']),
            'audio_path': sample['audio_path'],
        }
        for column, value in record.items():
            whisper_results[column].append(value)
    del decoded_outputs

whisper_results_df = pd.DataFrame(whisper_results)
# Drop rows whose normalized reference is empty; WER is only computed for the rest.
whisper_results_df['valid'] = whisper_results_df['normalized_gt'].map(lambda text: len(text) > 0)
whisper_results_df = whisper_results_df.loc[whisper_results_df['valid']].reset_index(drop=True)
# Per-row WER on the raw strings; _wer is given the normalizer to apply.
whisper_results_df['wer'] = [
    _wer(reference, hypothesis, _normalize_text)
    for reference, hypothesis in zip(whisper_results_df['gt'], whisper_results_df['decoded_output'])
]
whisper_results_df.head()

In [None]:
# Persist the per-segment whisper-small results as CSV in the working directory.
whisper_results_df.to_csv(path_or_buf='whisper-small_results.csv', index=False)

In [None]:
# Mean WER in percent.
whisper_wer = whisper_results_df['wer'].mean()*100
print(f'WHISPER WER: {whisper_wer:.2f}%')

# One-line summary of the WER distribution (raw values, not multiplied by 100).
whisper_wer_stats = [
    ('WER', whisper_results_df['wer'].mean()),
    ('Median', whisper_results_df['wer'].median()),
    ('Std', whisper_results_df['wer'].std()),
    ('Min', whisper_results_df['wer'].min()),
    ('Max', whisper_results_df['wer'].max()),
    ('25%', whisper_results_df['wer'].quantile(0.25)),
    ('50%', whisper_results_df['wer'].quantile(0.50)),
    ('75%', whisper_results_df['wer'].quantile(0.75)),
]
print(f"{MODEL_ID}: " + ", ".join(f"{label}: {value:.2f}" for label, value in whisper_wer_stats))

# Listen to and inspect the segment with the highest WER.
worst_wer_idx = whisper_results_df['wer'].idxmax()
worst_wer_row = whisper_results_df.loc[worst_wer_idx]
worst_wer_wave, sr = ta.load(worst_wer_row['audio_path'])
show_output(
    worst_wer_wave,
    sr,
    gt=worst_wer_row['gt'],
    decoded_output=worst_wer_row['decoded_output'],
)

## Distil-Whisper large-v3 (German)

In [None]:
from transformers import (AutoProcessor, WhisperConfig,
                          WhisperForConditionalGeneration)

# NOTE: this rebinds the notebook-wide MODEL_ID and torch_dtype set in the Whisper section.
MODEL_ID = "primeline/distil-whisper-large-v3-german"
torch_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

whisper_dist_config = WhisperConfig.from_pretrained(MODEL_ID, cache_dir=cache_dir)
whisper_dist_model = WhisperForConditionalGeneration.from_pretrained(
    MODEL_ID,
    config=whisper_dist_config,
    cache_dir=cache_dir,
    torch_dtype=torch_dtype,
)
# The dtype above falls back to float32 without a GPU, so the device falls back too;
# unconditionally moving to 'cuda' raised on CPU-only machines.
whisper_dist_model.eval().to('cuda' if torch.cuda.is_available() else 'cpu')
whisper_dist_processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=cache_dir)

In [None]:
# Keyword arguments for the processor call (assumed to be forwarded by SpeechDataset — TODO confirm).
processor_args = dict(return_tensors="pt", sampling_rate=target_sr)

whisper_dist_dataset = SpeechDataset(yt_df_filtered, whisper_dist_processor, processor_args, target_sr)
# Fetch one item up front as a quick check that the dataset can be indexed.
whisper_dist_sample = whisper_dist_dataset[0]
# Fixed order (no shuffling); the identity collate keeps the per-sample dicts.
whisper_dist_loader = torch.utils.data.DataLoader(
    whisper_dist_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

In [None]:
def whisper_dist_inferece(model, batch, processor, language="de"):
    """Transcribe one batch of samples with the distilled Whisper model.

    Args:
        model: Model exposing ``generate``, ``device`` and ``dtype``; the inputs are moved
            to ``model.device`` and cast to ``model.dtype``.
        batch: List of sample dicts, each holding an ``"input_features"`` tensor. The
            tensors are stacked and axis 1 is squeezed (assumes each sample carries a
            leading axis of size 1 — TODO confirm against SpeechDataset).
        processor: Object whose ``batch_decode`` turns generated ids into strings.
        language: Language passed to ``model.generate``; defaults to ``"de"`` as before.

    Returns:
        list[str]: One decoded transcript per sample, special tokens skipped.
    """
    input_features = torch.stack([sample["input_features"] for sample in batch]).squeeze(1)
    # Device and dtype come from the model rather than a hard-coded 'cuda' and the
    # notebook-global `torch_dtype`, so the function does not depend on which section
    # last assigned that global and also works when the model lives on the CPU.
    input_features = input_features.to(device=model.device, dtype=model.dtype)
    with torch.no_grad():
        generated_ids = model.generate(input_features, language=language)
    # Locals are released on return; no explicit `del` is needed.
    return processor.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
# Run the distilled model over the whole loader and collect one record per segment.
# A leftover `break` used to stop this loop after the first batch, so the statistics below
# covered only 16 segments and were not comparable with the other models; it was removed.
whisper_dist_results = defaultdict(list)
for batch in tqdm(whisper_dist_loader):
    decoded_outputs = whisper_dist_inferece(whisper_dist_model, batch, whisper_dist_processor)
    for sample, decoded_output in zip(batch, decoded_outputs):
        whisper_dist_results['decoded_output'].append(decoded_output.strip())
        whisper_dist_results['normalized_decoded_output'].append(_normalize_text(decoded_output))
        whisper_dist_results['gt'].append(sample['transcript'].strip())
        whisper_dist_results['normalized_gt'].append(_normalize_text(sample['transcript']))
        whisper_dist_results['audio_path'].append(sample['audio_path'])
    del decoded_outputs

whisper_dist_results_df = pd.DataFrame(whisper_dist_results)
# Drop rows whose normalized reference is empty; WER is only computed for the rest.
whisper_dist_results_df['valid'] = whisper_dist_results_df['normalized_gt'].apply(lambda x: len(x) > 0)
whisper_dist_results_df = whisper_dist_results_df[whisper_dist_results_df['valid']].reset_index(drop=True)
# Per-row WER on the raw strings; _wer is given the normalizer to apply.
whisper_dist_results_df['wer'] = whisper_dist_results_df.apply(lambda x: _wer(x['gt'], x['decoded_output'], _normalize_text), axis=1)
whisper_dist_results_df.head()

In [None]:
# Persist the per-segment results of the distilled model as CSV in the working directory.
whisper_dist_results_df.to_csv(path_or_buf='distil-large-v3_results.csv', index=False)

In [None]:
# Mean WER in percent. The label names the distilled model so this line can be told apart
# from the whisper-small output (it was a copy-pasted 'WHISPER WER' before).
whisper_dist_wer = whisper_dist_results_df['wer'].mean()*100
print(f'DISTIL-WHISPER WER: {whisper_dist_wer:.2f}%')

# One-line summary of the WER distribution (raw values, not multiplied by 100):
# mean, median, std, min, max and the 25% / 50% / 75% quantiles.
print(
    f"{MODEL_ID}: WER: {whisper_dist_results_df['wer'].mean():.2f}, Median: {whisper_dist_results_df['wer'].median():.2f}, Std: {whisper_dist_results_df['wer'].std():.2f}, Min: {whisper_dist_results_df['wer'].min():.2f}, Max: {whisper_dist_results_df['wer'].max():.2f}, 25%: {whisper_dist_results_df['wer'].quantile(0.25):.2f}, 50%: {whisper_dist_results_df['wer'].quantile(0.50):.2f}, 75%: {whisper_dist_results_df['wer'].quantile(0.75):.2f}"
)

# Listen to and inspect the segment with the highest WER.
worst_wer_idx = whisper_dist_results_df['wer'].idxmax()
worst_wer_row = whisper_dist_results_df.loc[worst_wer_idx]
worst_wer_wave, sr = ta.load(worst_wer_row['audio_path'])
show_output(worst_wer_wave, sr, worst_wer_row['gt'], worst_wer_row['decoded_output'])

In [None]:
# Inspect a single result row, addressed by its label in the re-indexed results frame.
sample_idx = 14
sample_row = whisper_dist_results_df.loc[sample_idx]
sample_wave, sr = ta.load(sample_row['audio_path'])
show_output(
    sample_wave,
    sr,
    gt=sample_row['gt'],
    decoded_output=sample_row['decoded_output'],
)

## Wav2Vec2ForCTC

In [None]:
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# NOTE: rebinds the notebook-wide MODEL_ID used in the summary prints.
MODEL_ID = "jonatasgrosman/wav2vec2-large-xlsr-53-german"

wav2vec_processor = Wav2Vec2Processor.from_pretrained(MODEL_ID, cache_dir=cache_dir)
wav2vec_model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID, cache_dir=cache_dir)
# Switch to eval mode, then move to the GPU; the result is bound to `_` as before.
wav2vec_model.eval()
_ = wav2vec_model.to('cuda')

In [None]:
# Keyword arguments for the processor call (assumed to be forwarded by SpeechDataset — TODO confirm).
# Processor-side padding stays disabled: wav2vec_inferece pads each batch itself.
processor_args = dict(return_tensors="pt", sampling_rate=target_sr)

wav2vec_dataset = SpeechDataset(yt_df_filtered, wav2vec_processor, processor_args, target_sr)
# Fetch one item up front as a quick check that the dataset can be indexed.
wav2vec_sample = wav2vec_dataset[0]
# Fixed order (no shuffling); the identity collate keeps the per-sample dicts.
wav2vec_loader = torch.utils.data.DataLoader(
    wav2vec_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=custom_collate_fn,
)

In [None]:
def wav2vec_inferece(model, batch, processor):
    """Transcribe one batch with a CTC model using greedy (argmax) decoding.

    Args:
        model: Callable model exposing ``device``; called with the padded inputs and an
            optional ``attention_mask``, and expected to return an object with ``logits``.
        batch: List of sample dicts, each holding an ``"input_features"`` tensor whose
            first row is the 1-D input (assumes shape (1, num_samples) — TODO confirm
            against SpeechDataset).
        processor: Object whose ``batch_decode`` turns predicted ids into strings. If its
            ``feature_extractor.return_attention_mask`` is truthy, an attention mask that
            marks the padded positions is passed to the model.

    Returns:
        list[str]: One decoded transcript per sample, special tokens skipped.
    """
    waveforms = [sample["input_features"][0] for sample in batch]
    lengths = torch.tensor([waveform.shape[-1] for waveform in waveforms])
    max_input_length = int(lengths.max())

    # Right-pad every input with zeros to the longest one in the batch.
    input_values = torch.stack(
        [
            torch.nn.functional.pad(waveform, (0, max_input_length - waveform.shape[-1]))
            for waveform in waveforms
        ]
    ).to(model.device)  # follow the model instead of a hard-coded 'cuda'

    # Zero-padding without a mask lets the padded tail influence the output for checkpoints
    # that were trained with attention masks. Only build the mask when the feature extractor
    # asks for one; other checkpoints should not receive it.
    attention_mask = None
    feature_extractor = getattr(processor, "feature_extractor", None)
    if getattr(feature_extractor, "return_attention_mask", False):
        positions = torch.arange(max_input_length).unsqueeze(0)
        attention_mask = (positions < lengths.unsqueeze(1)).long().to(model.device)

    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids, skip_special_tokens=True)

In [None]:
# Run the wav2vec2 model over the whole loader and collect one record per segment.
wav2vec_results = defaultdict(list)
for batch in tqdm(wav2vec_loader):
    decoded_outputs = wav2vec_inferece(wav2vec_model, batch, wav2vec_processor)
    for sample, decoded_output in zip(batch, decoded_outputs):
        record = {
            'decoded_output': decoded_output.strip(),
            'normalized_decoded_output': _normalize_text(decoded_output),
            'gt': sample['transcript'].strip(),
            'normalized_gt': _normalize_text(sample['transcript']),
            'audio_path': sample['audio_path'],
        }
        for column, value in record.items():
            wav2vec_results[column].append(value)
    del decoded_outputs

wav2vec_results_df = pd.DataFrame(wav2vec_results)
# Drop rows whose normalized reference is empty; WER is only computed for the rest.
wav2vec_results_df['valid'] = wav2vec_results_df['normalized_gt'].map(lambda text: len(text) > 0)
wav2vec_results_df = wav2vec_results_df.loc[wav2vec_results_df['valid']].reset_index(drop=True)
# Per-row WER on the raw strings; _wer is given the normalizer to apply.
wav2vec_results_df['wer'] = [
    _wer(reference, hypothesis, _normalize_text)
    for reference, hypothesis in zip(wav2vec_results_df['gt'], wav2vec_results_df['decoded_output'])
]
wav2vec_results_df.head()

In [None]:
# Persist the per-segment wav2vec2 results as CSV in the working directory.
wav2vec_results_df.to_csv(path_or_buf='wav2vec_results.csv', index=False)

In [None]:
# Mean WER in percent.
wav2vec_wer = wav2vec_results_df['wer'].mean()*100
print(f'WAV2VEC WER: {wav2vec_wer:.2f}%')

# One-line summary of the WER distribution (raw values, not multiplied by 100).
wav2vec_wer_stats = [
    ('WER', wav2vec_results_df['wer'].mean()),
    ('Median', wav2vec_results_df['wer'].median()),
    ('Std', wav2vec_results_df['wer'].std()),
    ('Min', wav2vec_results_df['wer'].min()),
    ('Max', wav2vec_results_df['wer'].max()),
    ('25%', wav2vec_results_df['wer'].quantile(0.25)),
    ('50%', wav2vec_results_df['wer'].quantile(0.50)),
    ('75%', wav2vec_results_df['wer'].quantile(0.75)),
]
print(f"{MODEL_ID}: " + ", ".join(f"{label}: {value:.2f}" for label, value in wav2vec_wer_stats))

# Listen to and inspect the segment with the highest WER.
worst_wer_idx = wav2vec_results_df['wer'].idxmax()
worst_wer_row = wav2vec_results_df.loc[worst_wer_idx]
worst_wer_wave, sr = ta.load(worst_wer_row['audio_path'])
show_output(
    worst_wer_wave,
    sr,
    gt=worst_wer_row['gt'],
    decoded_output=worst_wer_row['decoded_output'],
)

In [None]:
# Inspect a single result row, addressed by its label in the re-indexed results frame.
sample_idx = 5
sample_row = wav2vec_results_df.loc[sample_idx]
sample_wave, sr = ta.load(sample_row['audio_path'])
show_output(
    sample_wave,
    sr,
    gt=sample_row['gt'],
    decoded_output=sample_row['decoded_output'],
)

## SeamlessM4Tv2

In [None]:
from transformers import AutoProcessor, SeamlessM4Tv2Model

# NOTE: rebinds the notebook-wide MODEL_ID.
MODEL_ID = "facebook/seamless-m4t-v2-large"

seamless_processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=cache_dir)
seamless_model = SeamlessM4Tv2Model.from_pretrained(MODEL_ID, cache_dir=cache_dir)
# Switch to eval mode, then move to the GPU; the result is bound to `_` as before.
seamless_model.eval()
_ = seamless_model.to('cuda')

In [None]:
# Single-sample ASR demo. Relies on `res_wave` and `transcript` from the random-sample cell
# in the EDA section, so that cell has to be run first.
inputs = seamless_processor(
    audios=res_wave[0],
    return_tensors="pt",
    sampling_rate=target_sr,
).to("cuda")
# Text-only generation with German ("deu") as target language.
model_output = seamless_model.generate(**inputs, tgt_lang="deu", generate_speech=False)
# The first element of the output is decoded; the first decoded string is kept.
decoded_output = seamless_processor.batch_decode(model_output[0], skip_special_tokens=True)[0]
show_output(res_wave, target_sr, gt=transcript, decoded_output=decoded_output)