In [1]:
import random
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from huggingface_hub import notebook_login
from datasets import load_dataset, DatasetDict, Audio, Dataset, concatenate_datasets
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration, Seq2SeqTrainingArguments, Seq2SeqTrainer, pipeline
import evaluate
import jiwer
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch
import os
from tqdm import tqdm
import gradio as gr
import seaborn as sns
import time
import matplotlib.pyplot as plt

In [2]:
directory = "Prepared_Datasets"
all_dirs = os.listdir(directory)
test_batch_dirs = [f"{directory}/{dir_name}" for dir_name in all_dirs if dir_name.startswith('processed_test_batch')]
test_batch_datasets = [Dataset.load_from_disk(batch_dir) for batch_dir in test_batch_dirs]
test_dataset = concatenate_datasets(test_batch_datasets)

metric = evaluate.load("wer")
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
model.generation_config.language = "polish"
model.generation_config.task = "transcribe"
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")

def compute_metrics(pred):
    pred_ids = pred["predictions"]
    label_ids = pred["label_ids"]
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    pred_str = [p.lower() for p in pred_str]
    label_str = [l.lower() for l in label_str]

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)
    return {"eval_wer": wer}



In [5]:

def get_random_samples(dataset, num_samples=100):
    random_indices = random.sample(range(len(dataset)), num_samples)
    return dataset.select(random_indices)

def transcribe_samples(model, processor, samples):
    results = []
    times = []
    predictions = []
    references = []
    wer_list = []

    for sample in tqdm(samples, desc="Transcribing Samples", unit="sample"):
        input_features = torch.tensor([sample["input_features"]])
        reference_str = processor.tokenizer.decode(sample["labels"], skip_special_tokens=True)

        attention_mask = torch.ones(input_features.shape[:2], dtype=torch.long)

        start_time = time.time()
        pred_ids = model.generate(input_features, attention_mask=attention_mask)
        pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)[0]
        end_time = time.time()

        predictions.append(pred_str)
        references.append(reference_str)
        times.append(end_time - start_time)

        # WER computation for each sample
        pred = {
            "predictions": processor.tokenizer(pred_str, return_tensors="pt", padding=True, truncation=True).input_ids,
            "label_ids": processor.tokenizer(reference_str, return_tensors="pt", padding=True, truncation=True).input_ids,
        }
        metrics = compute_metrics(pred)
        wer_list.append(metrics["eval_wer"])

        results.append({"reference": reference_str, "transcription": pred_str})

    return results, times, wer_list, predictions, references

def evaluate_models(model_paths, samples, metric):
    results = {}
    times_all = {}
    wer_all = {}

    for model_path in model_paths:
        processor = WhisperProcessor.from_pretrained(model_path, language="pl", task="transcribe")
        model = WhisperForConditionalGeneration.from_pretrained(model_path)
        model.config.forced_decoder_ids = processor.tokenizer.convert_tokens_to_ids(["<|pl|>"])
        model.config.language = "pl"

        transcriptions, times, wer_list, predictions, references = transcribe_samples(model, processor, samples)

        # Calculate overall WER
        overall_pred = {
            "predictions": processor.tokenizer(predictions, padding=True, truncation=True, return_tensors="pt").input_ids,
            "label_ids": processor.tokenizer(references, padding=True, truncation=True, return_tensors="pt").input_ids,
        }
        overall_metrics = compute_metrics(overall_pred)

        results[model_path] = overall_metrics["eval_wer"]
        times_all[model_path] = times
        wer_all[model_path] = wer_list

    return results, times_all, wer_all

def plot_inference_times_and_wer(times_all, wer_all):
    # Plot for Inference Times (Boxplot)
    plt.figure(figsize=(8, 5))
    all_times = [times_all[model_path] for model_path in times_all]
    plt.boxplot(all_times, vert=True, patch_artist=True, boxprops=dict(facecolor="skyblue"))
    plt.title("Inference Time Comparison for All Models")
    plt.xlabel("Model")
    plt.ylabel("Inference Time (s)")
    plt.xticks(range(1, len(times_all) + 1), list(times_all.keys()))
    plt.tight_layout()
    plt.show()

    # Plot for WER (Bar chart)
    plt.figure(figsize=(8, 5))
    wer_values = [wer_all[model_path] for model_path in wer_all]
    model_names = list(wer_all.keys())
    avg_wer = [sum(wer) / len(wer) for wer in wer_values]  # Average WER per model
    plt.bar(model_names, avg_wer, color='orange')
    plt.title("WER Comparison for All Models")
    plt.xlabel("Model")
    plt.ylabel("WER (%)")
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    model_folders = ["openai/whisper-tiny", "openai/whisper-large-v3-turbo"]
    model_folders.extend(folder_name for folder_name in os.listdir() if folder_name.startswith("whisper"))
    num_samples = 1

    samples = get_random_samples(test_dataset, num_samples=num_samples)
    wer_results, times_results, wer_sample_results = evaluate_models(model_folders, samples, metric)

    print("\nWER Results:")
    for model, wer in wer_results.items():
        print(f"{model}: {wer:.2f}%")

    plot_inference_times_and_wer(times_results, wer_sample_results)

Transcribing Samples: 100%|██████████| 1/1 [00:00<00:00,  1.45sample/s]
Transcribing Samples:   0%|          | 0/1 [00:00<?, ?sample/s]


RuntimeError: Given groups=1, weight of size [1280, 128, 3], expected input[1, 80, 3000] to have 128 channels, but got 80 channels instead