In [68]:
import torch
import torch.nn as nn
import torchaudio.transforms as T
import numpy as np
from datasets import load_dataset
from transformers import (
    Wav2Vec2ForCTC, Wav2Vec2Processor,
    HubertForCTC, 
    WhisperProcessor, WhisperForConditionalGeneration
)

In [69]:
print("1. Loading Real German Data (flozi00/asr-german-mixed-evals)...")

# Load dataset (streaming mode)
dataset_stream = load_dataset(
    "flozi00/asr-german-mixed-evals", 
    split="train", 
    streaming=True
)

1. Loading Real German Data (flozi00/asr-german-mixed-evals)...


In [70]:
def get_next_sample():
    sample = next(iter(dataset_stream))
    
    # 1. Get the raw list of numbers
    audio_data = sample["audio"]["array"]
    orig_sr = sample["audio"]["sampling_rate"]
    
    # 2. CRITICAL FIX: Convert List -> NumPy Array -> Tensor
    # The error happened because 'audio_data' was a list. 
    # np.array() fixes it.
    audio_tensor = torch.from_numpy(np.array(audio_data)).float()
    
    # 3. Manual Resampling to 16000 Hz (Standard for these models)
    if orig_sr != 16000:
        resampler = T.Resample(orig_sr, 16000)
        audio_tensor = resampler(audio_tensor)
    
    # 4. Extract Text
    text = sample["references"]
    
    return audio_tensor, text

In [71]:
audio_check, text_check = get_next_sample()
print(f"   Data Loaded Successfully.")
print(f"   Sample Text: {text_check}")
print(f"   Audio Shape: {audio_check.shape}")

   Data Loaded Successfully.
   Sample Text: Sie hätten jedenfalls sogleich die sicherste Kontrolle für meine Darstellung an ihr, auf der anderen Seite gewinnt aber diese vielleicht an Unbefangenheit und historischer Treue.
   Audio Shape: torch.Size([193280])


In [72]:
class BaselineCNN(nn.Module):
    def __init__(self, n_classes=32):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(1, 32, 3, padding=1), nn.ReLU(),
            nn.Conv1d(32, 64, 3, padding=1), nn.ReLU(),
            nn.Conv1d(64, n_classes, 3, padding=1)
        )
    def forward(self, x):
        # FIX: Handle 1D input [Time] -> 3D [1, 1, Time]
        if x.ndim == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        # FIX: Handle 2D input [Batch, Time] -> 3D [Batch, 1, Time]
        elif x.ndim == 2:
            x = x.unsqueeze(1)
            
        return self.cnn(x).permute(0, 2, 1) 

In [73]:
def manual_ctc_decode(logits, vocab):
    probs = torch.softmax(logits, dim=-1)
    best_path = torch.argmax(probs, dim=-1)[0]
    
    decoded_chars = []
    prev_idx = -1
    
    for idx in best_path:
        idx = idx.item()
        if idx != prev_idx and idx != 0:
            char = vocab.get(idx, "")
            decoded_chars.append(char)
        prev_idx = idx
        
    return "".join(decoded_chars)

In [74]:
print("   Loading HuBERT (German)...")
processor_3 = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
model_3 = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

   Loading HuBERT (German)...


In [75]:
print("   Loading Whisper (German)...")
processor_4 = WhisperProcessor.from_pretrained("openai/whisper-small")
model_4 = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

   Loading Whisper (German)...


In [76]:
print("   Loading Wav2Vec2 XLS-R (German)...")
w2v_id = "facebook/wav2vec2-large-xlsr-53-german"
processor_5 = Wav2Vec2Processor.from_pretrained(w2v_id)
model_5 = Wav2Vec2ForCTC.from_pretrained(w2v_id)

   Loading Wav2Vec2 XLS-R (German)...


In [77]:
class SimpleRNN(nn.Module):
    def __init__(self, input_size=1, hidden_size=128, n_classes=32):
        super().__init__()
        self.rnn = nn.GRU(
            input_size=input_size, 
            hidden_size=hidden_size, 
            batch_first=True, 
            bidirectional=True
        )
        self.fc = nn.Linear(hidden_size * 2, n_classes)

    def forward(self, x):
        if x.ndim == 1:
            x = x.unsqueeze(0).unsqueeze(2)
        
        output, _ = self.rnn(x)
        return self.fc(output)
    
model_6 = SimpleRNN()

In [78]:
import torch
import torch.nn as nn
from transformers import (
    Wav2Vec2ForCTC, Wav2Vec2Processor,
    HubertForCTC, 
    WhisperProcessor, WhisperForConditionalGeneration
)
from jiwer import wer

In [79]:
def manual_ctc_decode(logits, processor):
    # Rule-based decoding (Greedy Search) without using the library's .decode()
    predicted_ids = torch.argmax(logits, dim=-1)
    transcription = processor.batch_decode(predicted_ids)[0]
    return transcription

In [80]:
class BaselineCNN(nn.Module):
    def __init__(self, n_classes=32):
        super().__init__()
        self.cnn = nn.Sequential(
            nn.Conv1d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(64, n_classes, kernel_size=3, padding=1)
        )

    def forward(self, x):
        # FIX 1: If input is just 1D audio [Time], make it [1, 1, Time]
        if x.ndim == 1:
            x = x.unsqueeze(0).unsqueeze(0)
        # FIX 2: If input is 2D batch [Batch, Time], make it [Batch, 1, Time]
        elif x.ndim == 2:
            x = x.unsqueeze(1)
            
        return self.cnn(x).permute(0, 2, 1)

model_1 = BaselineCNN()

In [81]:
class SimpleRNN(nn.Module):
    def __init__(self, input_size=1, hidden_size=128, n_classes=32):
        super().__init__()
        self.rnn = nn.GRU(input_size, hidden_size, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(hidden_size*2, n_classes)
    def forward(self, x):
        if x.ndim == 1: x = x.unsqueeze(0).unsqueeze(2) 
        x, _ = self.rnn(x)
        return self.fc(x)

In [82]:
print("   Loading Pre-trained Brains (HuBERT, Whisper, Wav2Vec2)...")
# HuBERT
p_hubert = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
m_hubert = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

   Loading Pre-trained Brains (HuBERT, Whisper, Wav2Vec2)...


In [83]:
# Whisper
p_whisper = WhisperProcessor.from_pretrained("openai/whisper-small")
m_whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")

In [84]:
# Wav2Vec2 XLS-R (German)
p_w2v = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-large-xlsr-53-german")
m_w2v = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53-german")

In [85]:
results = []
ground_truth = text_check

In [86]:
# We just need to check if it runs. Output will be garbage.
cnn = BaselineCNN()
with torch.no_grad():
    out = cnn(audio_check)
    pred_1 = "random_init_output" # Placeholder as it's untrained

In [None]:
# 2. Manual CTC (Applied to Wav2Vec2 Logits)
with torch.no_grad():
    logits = m_w2v(audio_check.unsqueeze(0)).logits
    pred_2 = manual_ctc_decode(logits, p_w2v)

In [88]:
# 3. HuBERT
with torch.no_grad():
    logits = m_hubert(audio_check.unsqueeze(0)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    pred_3 = p_hubert.batch_decode(pred_ids)[0]

In [None]:
# 4. Whisper
with torch.no_grad():
    input_features = p_whisper(audio_check, sampling_rate=16000, return_tensors="pt").input_features
    gen_ids = m_whisper.generate(input_features)
    pred_4 = p_whisper.batch_decode(gen_ids, skip_special_tokens=True)[0]

Using custom `forced_decoder_ids` from the (generation) config. This is deprecated in favor of the `task` and `language` flags/config options.
Transcription using a multilingual Whisper will default to language detection followed by transcription instead of translation to English. This might be a breaking change for your use case. If you want to instead always translate your audio to English, make sure to pass `language='en'`. See https://github.com/huggingface/transformers/pull/28687 for more details.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [90]:
# 5. Wav2Vec2 XLS-R
with torch.no_grad():
    logits = m_w2v(audio_check.unsqueeze(0)).logits
    pred_ids = torch.argmax(logits, dim=-1)
    pred_5 = p_w2v.batch_decode(pred_ids)[0]

In [91]:
# 6. RNN (Random)
rnn = SimpleRNN()
with torch.no_grad():
    out = rnn(audio_check)
    pred_6 = "random_init_output"

In [92]:
print("\n" + "="*60)
print(f"FINAL PROJECT REPORT (German Sample)")
print("="*60)
print(f"GROUND TRUTH: \"{ground_truth}\"")
print("-" * 60)
print(f"{'METHOD':<25} | {'WER':<10} | {'TRANSCRIPTION (First 50 chars)'}")
print("-" * 60)

predictions = [
    ("1. Baseline CNN", pred_1),
    ("2. Manual CTC (Rule)", pred_2),
    ("3. HuBERT (Transf)", pred_3),
    ("4. Whisper (Enc-Dec)", pred_4),
    ("5. Wav2Vec2 (XLS-R)", pred_5),
    ("6. Simple RNN", pred_6)
]


FINAL PROJECT REPORT (German Sample)
GROUND TRUTH: "Sie hätten jedenfalls sogleich die sicherste Kontrolle für meine Darstellung an ihr, auf der anderen Seite gewinnt aber diese vielleicht an Unbefangenheit und historischer Treue."
------------------------------------------------------------
METHOD                    | WER        | TRANSCRIPTION (First 50 chars)
------------------------------------------------------------


In [93]:
for name, pred in predictions:
    # Calculate Word Error Rate (WER)
    # If pred is garbage, WER is high (1.0 or more)
    if "random" in pred:
        error_rate = 1.0
    else:
        error_rate = wer(ground_truth, pred)
        
    print(f"{name:<25} | {error_rate:.4f}     | {pred[:50]}...")

1. Baseline CNN           | 1.0000     | random_init_output...
2. Manual CTC (Rule)      | 0.2800     | sie hätten jedenfalls sogleich die sicherste kontr...
3. HuBERT (Transf)        | 1.1200     | ZHE HADN YEDEN FIL SOGLAICH DI ZE HESTO CONTROLLEF...
4. Whisper (Enc-Dec)      | 0.1200     |  Sie hätten jedenfalls zugleich die sicherste Kont...
5. Wav2Vec2 (XLS-R)       | 0.2800     | sie hätten jedenfalls sogleich die sicherste kontr...
6. Simple RNN             | 1.0000     | random_init_output...
