In [None]:
# Imports

!pip install transformers[torch] datasets evaluate jiwer soundfile librosa accelerate

import torch
import torchaudio
from transformers import (
    Wav2Vec2Processor,
    Wav2Vec2ForCTC,
    TrainingArguments,
    Trainer
)
from datasets import Dataset, DatasetDict
import evaluate
import pandas as pd
import numpy as np
import librosa
import soundfile as sf
from pathlib import Path
import re
import json
from dataclasses import dataclass
from typing import Dict, List, Union, Any

from google.colab import drive
import pandas as pd
import os
import librosa
import numpy as np

Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (12 kB)
Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading rapidfuzz-3.14.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (3.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m82.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rapidfuzz, jiwer, evaluate
Successfully installed evaluate-0.4.6 jiwer-4.0.0 rapidfuzz-3.14.1


In [None]:
# Mount Google Drive

# Expose the user's Drive under /content/drive so the dataset files can be
# read through the normal file system.
drive.mount('/content/drive')

for status_line in (
    "Google Drive mounted successfully!",
    "\nLet's explore your data structure...",
):
    print(status_line)

Mounted at /content/drive
Google Drive mounted successfully!

Let's explore your data structure...


In [None]:
# Define paths

base_path = '/content/drive/MyDrive/data'
test_audio_path = '/content/drive/MyDrive/data/test'
metadata_path = '/content/drive/MyDrive/data/metadata.csv'

# Report which of the expected locations are actually present.
print(f"\nChecking data paths:")
print(f"Base data folder exists: {os.path.exists(base_path)}")
print(f"Test audio folder exists: {os.path.exists(test_audio_path)}")
print(f"Metadata CSV exists: {os.path.exists(metadata_path)}")

# Top-level listing of the data folder; sub-folders also show their item count.
if os.path.exists(base_path):
    print(f"\nContents of data folder:")
    for item in os.listdir(base_path):
        item_path = os.path.join(base_path, item)
        print(f"{item}/ ({len(os.listdir(item_path))} items)" if os.path.isdir(item_path) else f"{item}")

# Count the WAV recordings and preview up to five file names.
if os.path.exists(test_audio_path):
    audio_files = list(filter(lambda name: name.endswith('.wav'), os.listdir(test_audio_path)))
    print(f"\nFound {len(audio_files)} WAV files in test folder")
    if len(audio_files) > 0:
        print("First few audio files:")
        for i, file in enumerate(audio_files[:5]):
            print(f"  {i+1}. {file}")

# Load and examine metadata
if not os.path.exists(metadata_path):
    print("Metadata file not found. Please check the path.")
else:
    print(f"\nLoading metadata...")
    metadata_df = pd.read_csv(metadata_path)
    print(f"Metadata shape: {metadata_df.shape}")
    print(f"\nColumns: {list(metadata_df.columns)}")
    print(f"\nFirst few rows:")
    print(metadata_df.head())

    # Basic statistics (the subject/gender/utterance_type columns are required;
    # only the duration column is treated as optional).
    print(f"\nBasic Statistics:")
    print(f"Total records: {len(metadata_df)}")
    print(f"Unique subjects: {metadata_df['subject'].nunique()}")
    print(f"Gender distribution: {metadata_df['gender'].value_counts().to_dict()}")
    print(f"Utterance types: {metadata_df['utterance_type'].value_counts().to_dict()}")

    has_duration = 'duration' in metadata_df.columns
    if has_duration:
        print(f"Duration stats: min={metadata_df['duration'].min():.2f}s, max={metadata_df['duration'].max():.2f}s, mean={metadata_df['duration'].mean():.2f}s")


Checking data paths:
Base data folder exists: True
Test audio folder exists: True
Metadata CSV exists: True

Contents of data folder:
metadata.csv
readme.txt
get_stats.ipynb
.DS_Store
test/ (400 items)
metadata.gsheet

Found 400 WAV files in test folder
First few audio files:
  1. video1_M02_159.wav
  2. video1_F01_66.wav
  3. video1_F02_125.wav
  4. video1_M01_3.wav
  5. video2_M03_268.wav

Loading metadata...
Metadata shape: (400, 9)

Columns: ['file_name', 'transcript', 'utterance_type', 'subject', 'gender', 'age', 'diagnosis', 'comment', 'duration']

First few rows:
                    file_name                               transcript  \
0  data/test/video1_M01_1.wav      zmiany zauważyłem od ośmiu miesięcy   
1  data/test/video1_M01_2.wav  i pogarszało się z tygodnia na tydzień.   
2  data/test/video1_M01_3.wav               i teraz już nie mogę mówić   
3  data/test/video1_M01_4.wav                                        a   
4  data/test/video1_M01_5.wav                       

In [None]:
# Check if GPU is available

cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if cuda_available else torch.device("cpu")
print(f"Using device: {device}")
if cuda_available:
    # Device 0 is the only GPU inspected here.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Set random seeds for reproducibility (PyTorch first, then NumPy)
for seed_everything in (torch.manual_seed, np.random.seed):
    seed_everything(42)

# Load evaluation metric for later use
wer_metric = evaluate.load("wer")

for status_line in (
    "All packages installed and imported successfully!",
    "Ready for data preprocessing...",
):
    print(status_line)

Using device: cuda
GPU: Tesla T4
GPU Memory: 15.8 GB


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading builder script: 0.00B [00:00, ?B/s]

All packages installed and imported successfully!
Ready for data preprocessing...


In [None]:
# Load Polish Wav2Vec2 Model and Split Data

# Define paths
# NOTE: despite its name, test_audio_path is set to the Drive root in this cell;
# prepare_dataset_entry joins it with the metadata `file_name` column.
test_audio_path = '/content/drive/MyDrive'
metadata_path = '/content/drive/MyDrive/data/metadata.csv'

# Load pretrained Polish Wav2Vec2 model (processor and CTC model share one checkpoint)
model_name = "facebook/wav2vec2-large-xlsr-53-polish"
processor, model = (
    loader.from_pretrained(model_name)
    for loader in (Wav2Vec2Processor, Wav2Vec2ForCTC)
)

print(f"Loaded Polish model: {model_name}")
print(f"Vocabulary size: {len(processor.tokenizer)}")
print(f"Sample rate: {processor.feature_extractor.sampling_rate}")

# Load metadata
metadata_df = pd.read_csv(metadata_path)
print(f"\nLoaded metadata with {len(metadata_df)} records")

# Prepare data for HuggingFace Dataset
def prepare_dataset_entry(row):
    """Build one dataset record (a plain dict) from a metadata row.

    `row` must provide 'file_name' and 'transcript'. The remaining descriptive
    fields are read with `row.get` and fall back to an empty string when the
    key is absent. The audio path is the module-level `test_audio_path` joined
    with the row's file name.
    """
    entry = {
        'audio': os.path.join(test_audio_path, row['file_name']),
        'sentence': row['transcript'],
        'file_name': row['file_name'],
    }
    # Optional speaker/utterance descriptors, copied through unchanged.
    for optional_key in ('utterance_type', 'subject', 'gender', 'age', 'diagnosis'):
        entry[optional_key] = row.get(optional_key, '')
    return entry

# Convert to list of dictionaries
dataset_entries = list(map(prepare_dataset_entry, (row for _, row in metadata_df.iterrows())))

# Split data into train/validation/test (80/10/10)
from sklearn.model_selection import train_test_split

# NOTE(review): both splits shuffle individual rows and do not look at the
# 'subject' field, so recordings of one speaker can end up in several splits.
# Confirm this is intended before reading the test WER as speaker-independent.

# First split: 80% train, 20% temp
train_data, temp_data = train_test_split(dataset_entries, random_state=42, test_size=0.2)

# Second split: 10% val, 10% test from the 20% temp
val_data, test_data = train_test_split(temp_data, random_state=42, test_size=0.5)

print(f"\nData split:")
print(f"Train: {len(train_data)} samples")
print(f"Validation: {len(val_data)} samples")
print(f"Test: {len(test_data)} samples")

# Create HuggingFace datasets (one per split, same order as above)
train_dataset, val_dataset, test_dataset = map(
    Dataset.from_list, (train_data, val_data, test_data)
)

# Create dataset dictionary
dataset_dict = DatasetDict({
    'train': train_dataset,
    'validation': val_dataset,
    'test': test_dataset
})

print(f"\nDataset created successfully!")
print(f"Train dataset: {len(dataset_dict['train'])} samples")
print(f"Validation dataset: {len(dataset_dict['validation'])} samples")
print(f"Test dataset: {len(dataset_dict['test'])} samples")

print("\nReady for audio preprocessing...")

preprocessor_config.json:   0%|          | 0.00/158 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/378 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

vocab.json:   0%|          | 0.00/390 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.26G [00:00<?, ?B/s]

Loaded Polish model: facebook/wav2vec2-large-xlsr-53-polish
Vocabulary size: 42
Sample rate: 16000

Loaded metadata with 400 records

Data split:
Train: 320 samples
Validation: 40 samples
Test: 40 samples

Dataset created successfully!
Train dataset: 320 samples
Validation dataset: 40 samples
Test dataset: 40 samples

Ready for audio preprocessing...


In [None]:
# Audio Preprocessing and Dataset Preparation (Fixed)

# First, let's check and fix the audio paths
print("Checking audio file paths...")

# Check a few samples (at most three) to see the path structure
preview_count = min(3, len(dataset_dict["train"]))
for i, sample_path in enumerate(dataset_dict["train"][:preview_count]["audio"]):
    print(f"Sample {i}: {sample_path}")
    print(f"Exists: {os.path.exists(sample_path)}")

# Function to load and preprocess audio with fixed paths
def preprocess_function(examples):
    """Load a batch of audio files and turn them into model inputs and CTC labels.

    Intended for `Dataset.map(..., batched=True)`.

    Args:
        examples: dict of equally long lists with at least the keys "audio"
            (file paths) and "sentence" (transcripts).

    Returns:
        dict with "input_values", "attention_mask" and "labels". Rows whose
        audio file is missing or unreadable, or whose transcript is missing or
        blank, are dropped, so the output may be shorter than the input. If
        nothing survives, all three values are empty lists.

    Notes:
        * Audio is padded to the longest clip of *this map batch*; the stored
          attention mask marks which samples are real.
        * Labels are stored unpadded. Padding them here would put pad ids into
          the stored labels, and the data collator only masks the padding it
          adds itself, so those ids would be fed to the CTC loss as real
          targets (presumably the pad id is also the CTC blank for this
          checkpoint - confirm).
    """
    # Resample to whatever rate the feature extractor was trained with
    # (printed as 16000 for the checkpoint used in this notebook).
    target_rate = processor.feature_extractor.sampling_rate

    audio_arrays = []
    valid_sentences = []

    for audio_path, sentence in zip(examples["audio"], examples["sentence"]):
        # An empty CSV cell is read by pandas as NaN (a float), not a string.
        if not isinstance(sentence, str) or not sentence.strip():
            print(f"Warning: Missing transcript for: {audio_path}")
            continue

        # Check if file exists
        if not os.path.exists(audio_path):
            print(f"Warning: File not found: {audio_path}")
            continue

        try:
            audio_array, _ = librosa.load(audio_path, sr=target_rate)
        except Exception as e:
            # Best effort: report the unreadable file and keep going.
            print(f"Error loading {audio_path}: {e}")
            continue

        audio_arrays.append(audio_array)
        valid_sentences.append(sentence)

    # Skip batch if no valid audio files
    if not audio_arrays:
        return {"input_values": [], "attention_mask": [], "labels": []}

    # Process audio with the processor (normalisation + in-batch padding)
    inputs = processor(
        audio_arrays,
        sampling_rate=target_rate,
        return_tensors="pt",
        padding=True,
        return_attention_mask=True
    )

    # Tokenise transcripts without padding: one variable-length id list per row.
    label_ids = processor.tokenizer(valid_sentences).input_ids

    return {
        "input_values": inputs.input_values,
        "attention_mask": inputs.attention_mask,
        "labels": label_ids
    }

# Apply preprocessing to datasets
def _preprocess_split(split):
    """Run preprocess_function over one split, replacing all original columns."""
    return split.map(
        preprocess_function,
        remove_columns=split.column_names,
        batched=True,
        batch_size=4,
        num_proc=1
    )

print("\nPreprocessing train dataset...")
train_dataset = _preprocess_split(dataset_dict["train"])

print("Preprocessing validation dataset...")
val_dataset = _preprocess_split(dataset_dict["validation"])

print("Preprocessing test dataset...")
test_dataset = _preprocess_split(dataset_dict["test"])

# Filter out empty samples
def filter_empty(example):
    """Return True only when the example has non-empty audio inputs and labels."""
    return all(len(example[column]) > 0 for column in ("input_values", "labels"))

# Drop rows without audio or labels from every split (train, validation, test order)
train_dataset, val_dataset, test_dataset = (
    split.filter(filter_empty)
    for split in (train_dataset, val_dataset, test_dataset)
)

# Update dataset dictionary so later cells see the preprocessed splits
dataset_dict = DatasetDict({
    "train": train_dataset,
    "validation": val_dataset,
    "test": test_dataset
})

print("\nPreprocessing complete!")
print(f"Train dataset: {len(dataset_dict['train'])} samples")
print(f"Validation dataset: {len(dataset_dict['validation'])} samples")
print(f"Test dataset: {len(dataset_dict['test'])} samples")

# Data collator for dynamic padding during training
@dataclass
class DataCollatorCTCWithPadding:
    """Pad a list of preprocessed examples into one training batch for CTC.

    Audio inputs and label ids have different lengths, so they are padded
    separately: inputs with the feature extractor, labels with the tokenizer.
    """

    # Processor whose feature extractor pads the audio and whose tokenizer pads the labels.
    processor: Wav2Vec2Processor
    # Padding strategy forwarded to both pad calls (True pads to the longest item).
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        """Collate `features` into padded tensors.

        Args:
            features: examples with "input_values" and "labels", and optionally
                an "attention_mask" produced during preprocessing.

        Returns:
            Batch with "input_values", "labels" (padding replaced by -100) and
            the attention mask returned by the feature extractor.
        """
        # If preprocessing stored an attention mask, pass it along. Otherwise a
        # fresh all-ones mask would be built from the stored lengths, and any
        # zeros that preprocessing already appended would count as real audio.
        has_stored_mask = all("attention_mask" in feature for feature in features)

        if has_stored_mask:
            input_features = [
                {
                    "input_values": feature["input_values"],
                    "attention_mask": feature["attention_mask"],
                }
                for feature in features
            ]
            # return_attention_mask=True makes the extractor extend each stored
            # mask with zeros when it lengthens the matching input.
            batch = self.processor.feature_extractor.pad(
                input_features,
                padding=self.padding,
                return_attention_mask=True,
                return_tensors="pt",
            )
        else:
            input_features = [{"input_values": feature["input_values"]} for feature in features]
            batch = self.processor.pad(
                input_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # Pad labels using tokenizer directly to avoid conflicts
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        labels_batch = self.processor.tokenizer.pad(
            label_features,
            padding=self.padding,
            return_tensors="pt",
        )

        # Replace padding with -100 to ignore in loss computation
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

# Create data collator (pads every batch to its longest example)
data_collator = DataCollatorCTCWithPadding(processor, True)

for status_line in ("\nData collator created!", "Ready for training setup..."):
    print(status_line)

Checking audio file paths...
Sample 0: /content/drive/MyDrive/data/test/video1_M01_4.wav
Exists: True
Sample 1: /content/drive/MyDrive/data/test/video1_M01_19.wav
Exists: True
Sample 2: /content/drive/MyDrive/data/test/video2_F03_213.wav
Exists: True

Preprocessing train dataset...


Map:   0%|          | 0/320 [00:00<?, ? examples/s]

model.safetensors:   0%|          | 0.00/1.26G [00:00<?, ?B/s]



Preprocessing validation dataset...


Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Preprocessing test dataset...


Map:   0%|          | 0/40 [00:00<?, ? examples/s]

Filter:   0%|          | 0/320 [00:00<?, ? examples/s]

Filter:   0%|          | 0/40 [00:00<?, ? examples/s]

Filter:   0%|          | 0/40 [00:00<?, ? examples/s]


Preprocessing complete!
Train dataset: 320 samples
Validation dataset: 40 samples
Test dataset: 40 samples

Data collator created!
Ready for training setup...


In [None]:
# Training Setup and Configuration

# Define evaluation metrics
def compute_metrics(eval_pred):
    """Compute Word Error Rate (WER) for evaluation.

    Args:
        eval_pred: pair `(predictions, labels)`; predictions are per-frame
            logits and labels are token ids with -100 marking padding.

    Returns:
        dict with the single key "wer".
    """
    predictions, labels = eval_pred

    # Greedy CTC decoding: most likely token per frame.
    predicted_ids = np.argmax(np.asarray(predictions), axis=-1)

    # Replace -100 with pad token id for decoding. np.where builds a new array,
    # so the caller's label array is left untouched.
    labels = np.asarray(labels)
    label_ids = np.where(labels == -100, processor.tokenizer.pad_token_id, labels)

    predicted_sentences = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    # References are plain character sequences, not CTC frame outputs, so
    # repeated ids must not be merged (genuine double letters would vanish).
    label_sentences = processor.batch_decode(
        label_ids, skip_special_tokens=True, group_tokens=False
    )

    # Compute WER
    wer = wer_metric.compute(predictions=predicted_sentences, references=label_sentences)

    return {"wer": wer}

# Training arguments
# Evaluation runs every 100 steps and checkpoints are written every 200 steps;
# the checkpoint with the lowest validation WER is restored at the end.
training_args = TrainingArguments(
    output_dir="./wav2vec2-polish-finetuned",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=2,
    eval_strategy="steps",
    eval_steps=100,
    save_steps=200,
    logging_steps=50,
    learning_rate=3e-4,
    warmup_steps=500,
    max_steps=2000,
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    # The collator needs the columns as stored, so the Trainer must not drop any.
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    report_to=[]
)

print("Training arguments configured:")
print(f"  - Batch size (train): {training_args.per_device_train_batch_size}")
print(f"  - Batch size (eval): {training_args.per_device_eval_batch_size}")
print(f"  - Learning rate: {training_args.learning_rate}")
print(f"  - Max steps: {training_args.max_steps}")
print(f"  - FP16: {training_args.fp16}")
print(f"  - Output directory: {training_args.output_dir}")

# Freeze feature encoder to speed up training
model.freeze_feature_encoder()
print("\nFeature encoder frozen for faster training")

# Create trainer
# `processing_class` replaces the deprecated `tokenizer` keyword, which makes
# Trainer emit a FutureWarning on current Transformers releases.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["validation"],
    processing_class=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

print("\nTrainer created successfully!")
print(f"Training dataset size: {len(dataset_dict['train'])}")
print(f"Validation dataset size: {len(dataset_dict['validation'])}")
print(f"Test dataset size: {len(dataset_dict['test'])}")

# Check model parameters (trainable = parameters that still require gradients)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel parameters:")
print(f"  - Total parameters: {total_params:,}")
print(f"  - Trainable parameters: {trainable_params:,}")
print(f"  - Frozen parameters: {total_params - trainable_params:,}")

print("\nReady to start training!")
print("Run trainer.train() in the next step to begin fine-tuning...")

Training arguments configured:
  - Batch size (train): 4
  - Batch size (eval): 4
  - Learning rate: 0.0003
  - Max steps: 2000
  - FP16: True
  - Output directory: ./wav2vec2-polish-finetuned

Feature encoder frozen for faster training


  trainer = Trainer(



Trainer created successfully!
Training dataset size: 320
Validation dataset size: 40
Test dataset size: 40

Model parameters:
  - Total parameters: 315,481,770
  - Trainable parameters: 311,271,594
  - Frozen parameters: 4,210,176

Ready to start training!
Run trainer.train() in the next step to begin fine-tuning...


In [None]:
print("Evaluating baseline performance without interfering with training models...")

# Load baseline model into separate variables (won't affect training)
baseline_model_name = "facebook/wav2vec2-large-xlsr-53-polish"
baseline_processor_eval, baseline_model_eval = (
    loader.from_pretrained(baseline_model_name)
    for loader in (Wav2Vec2Processor, Wav2Vec2ForCTC)
)

print(f"Loaded separate baseline model for evaluation: {baseline_model_name}")

# Move to device and set to eval mode (both calls return the module itself)
baseline_model_eval = baseline_model_eval.to(device).eval()

# Safe evaluation function that doesn't modify original model/processor
def safe_evaluate_model(eval_model, eval_processor, dataset, dataset_name="dataset"):
    """Greedy-decode `dataset` with `eval_model` and report its WER.

    Args:
        eval_model: CTC model, already in eval mode and on its target device.
        eval_processor: processor used to decode predictions and references.
        dataset: indexable collection of preprocessed examples providing
            "input_values", "attention_mask" and "labels".
        dataset_name: label used in the printed report.

    Returns:
        The word error rate as a float, or NaN when `dataset` is empty.
    """
    # Local import keeps this function self-contained.
    from torch.nn.utils.rnn import pad_sequence

    print(f"\nEvaluating on {dataset_name} ({len(dataset)} samples)...")

    if len(dataset) == 0:
        print("  No samples to evaluate.")
        return float("nan")

    # Run on whatever device the model lives on, not on a module-level global.
    model_device = next(eval_model.parameters()).device
    pad_token_id = eval_processor.tokenizer.pad_token_id

    predictions = []
    references = []

    # Process in small batches
    batch_size = 4
    for i in range(0, len(dataset), batch_size):
        batch_end = min(i + batch_size, len(dataset))
        items = [dataset[idx] for idx in range(i, batch_end)]

        # Pad to the longest example of this batch. Stacking directly only works
        # while every example in the batch happens to have the same length.
        input_values = pad_sequence(
            [torch.tensor(item["input_values"]) for item in items],
            batch_first=True,
        ).to(model_device)
        attention_mask = pad_sequence(
            [torch.tensor(item["attention_mask"]) for item in items],
            batch_first=True,
        ).to(model_device)

        # Model inference (no gradients)
        with torch.no_grad():
            logits = eval_model(input_values, attention_mask=attention_mask).logits
            predicted_ids = torch.argmax(logits, dim=-1)

        # Decode predictions using eval_processor
        batch_predictions = eval_processor.batch_decode(predicted_ids, skip_special_tokens=True)

        # Decode references using eval_processor
        batch_references = []
        for item in items:
            label_tensor = torch.tensor(item["labels"])
            # Replace -100 with pad token for decoding
            label_tensor[label_tensor == -100] = pad_token_id
            # group_tokens=False: references are plain character sequences, so
            # repeated ids are genuine double letters and must be kept.
            reference = eval_processor.decode(
                label_tensor, skip_special_tokens=True, group_tokens=False
            )
            batch_references.append(reference)

        predictions.extend(batch_predictions)
        references.extend(batch_references)

        # Show progress every 10 batches and after the final batch
        if (i // batch_size + 1) % 10 == 0 or batch_end == len(dataset):
            print(f"  Processed {batch_end}/{len(dataset)} samples")

    # Compute WER using the module-level wer_metric
    wer = wer_metric.compute(predictions=predictions, references=references)

    # Show some examples
    print(f"\n{dataset_name} Results:")
    print(f"WER: {wer:.4f}")
    print("\nSample predictions vs references:")
    for i in range(min(3, len(predictions))):
        print(f"  {i+1}. Prediction: '{predictions[i]}'")
        print(f"     Reference:  '{references[i]}'")
        print()

    return wer

# Evaluate baseline model on all datasets using separate variables
banner = "="*60
print(banner)
print("SAFE BASELINE MODEL EVALUATION")
print(banner)

# Evaluate on validation set
baseline_val_wer = safe_evaluate_model(
    baseline_model_eval,
    baseline_processor_eval,
    dataset_dict["validation"],
    "Validation",
)

# Evaluate on test set
baseline_test_wer = safe_evaluate_model(
    baseline_model_eval,
    baseline_processor_eval,
    dataset_dict["test"],
    "Test",
)

# Evaluate on a small sample of training set (first 50 rows at most)
train_sample_size = min(50, len(dataset_dict["train"]))
train_sample = dataset_dict["train"].select(range(train_sample_size))
baseline_train_wer = safe_evaluate_model(
    baseline_model_eval,
    baseline_processor_eval,
    train_sample,
    "Training Sample",
)

# Summary
print(banner)
print("BASELINE RESULTS SUMMARY")
print(banner)
print(f"Training Sample WER: {baseline_train_wer:.4f}")
print(f"Validation WER:      {baseline_val_wer:.4f}")
print(f"Test WER:            {baseline_test_wer:.4f}")
print(banner)

# Clean up evaluation variables to free memory
del baseline_model_eval, baseline_processor_eval
if torch.cuda.is_available():
    torch.cuda.empty_cache()

for closing_line in (
    "\nBaseline evaluation complete!",
    "Evaluation variables cleaned up - your training models are unaffected.",
    "\nYour original model and processor variables remain unchanged for training.",
    "\nNext: Proceed with fine-tuning using your existing model and processor.",
):
    print(closing_line)


Evaluating baseline performance without interfering with training models...
Loaded separate baseline model for evaluation: facebook/wav2vec2-large-xlsr-53-polish
SAFE BASELINE MODEL EVALUATION

Evaluating on Validation (40 samples)...
  Processed 40/40 samples

Validation Results:
WER: 0.8069

Sample predictions vs references:
  1. Prediction: 'moci'
     Reference:  'dziewięć'

  2. Prediction: 'rotu się źle czuje dzisiaj po lekach żle się czuje po letach i jak nigdy się tak nie czułam jak siaj dzisiaj źle czy nie wiem cze miałm zmienione te'
     Reference:  'poprostu się źle czuję dzisiaj po lekach źle się czuję po lekach nigdy się tak nie czułam jak się dzisiaj źle czuję nie wiem czy miałam zmienione te leki'

  3. Prediction: 'krakób jest piękne miasto pełno zabytku'
     Reference:  'kraków jest pięknym miastem pełno zabytków'


Evaluating on Test (40 samples)...
  Processed 40/40 samples

Test Results:
WER: 0.8477

Sample predictions vs references:
  1. Prediction: 'krakupti jes

In [None]:
# Training with Early Stopping

from transformers import EarlyStoppingCallback

print("Setting up improved training with early stopping...")

# Training arguments with early stopping and regularization
# Evaluation and checkpointing share the same 50-step cadence so the best
# checkpoint (lowest validation WER) can be restored at the end.
training_args = TrainingArguments(
    output_dir="./wav2vec2-polish-finetuned-v2",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    eval_strategy="steps",
    eval_steps=50,
    save_steps=50,
    logging_steps=25,
    learning_rate=1e-4,
    warmup_steps=100,
    max_steps=1000,
    fp16=torch.cuda.is_available(),
    push_to_hub=False,
    # The collator needs the columns as stored, so the Trainer must not drop any.
    remove_unused_columns=False,
    dataloader_pin_memory=False,
    save_total_limit=3,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    report_to=[],
    # Regularization
    weight_decay=0.01,
    dataloader_drop_last=False,
)

# Early stopping callback: stop after 3 consecutive evaluations whose WER does
# not improve on the best value by at least 0.01.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,
    early_stopping_threshold=0.01
)

print("Training configuration:")
print(f"  - Batch size: {training_args.per_device_train_batch_size}")
print(f"  - Learning rate: {training_args.learning_rate}")
print(f"  - Early stopping patience: 3 evaluations")
print(f"  - Weight decay: {training_args.weight_decay}")

# Reload the original model so this run starts from the pretrained weights
print("\nReloading fresh model...")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-xlsr-53-polish")
model.freeze_feature_encoder()

# Create new trainer with early stopping
# `processing_class` replaces the deprecated `tokenizer` keyword, which makes
# Trainer emit a FutureWarning on current Transformers releases.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["validation"],
    processing_class=processor.feature_extractor,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping]  # Add early stopping callback
)

print("\nStarting improved training with early stopping...")

try:
    training_results = trainer.train()

    print("\nTraining completed!")
    print(f"Final training loss: {training_results.training_loss:.4f}")
    print(f"Steps completed: {training_results.global_step}")

    # Evaluate on test set
    print("\nEvaluating on test set...")
    test_results = trainer.evaluate(eval_dataset=dataset_dict["test"])
    print(f"Test WER: {test_results['eval_wer']:.4f}")

except Exception as e:
    print(f"Training failed: {e}")
    # Re-raise so the cell stops here. Otherwise the save step below would
    # write out an untrained or partially trained model as the "improved" one.
    raise

# Save the improved model (only reached when training and evaluation succeeded)
print("\nSaving the improved model...")
trainer.save_model("./wav2vec2-polish-finetuned-v2")
processor.save_pretrained("./wav2vec2-polish-finetuned-v2")
print("Improved model saved!")

Setting up improved training with early stopping...
Training configuration:
  - Batch size: 2
  - Learning rate: 0.0001
  - Early stopping patience: 3 evaluations
  - Weight decay: 0.01

Reloading fresh model...


  trainer = Trainer(



Starting improved training with early stopping...


Step,Training Loss,Validation Loss,Wer
50,-178.4227,-185.974121,0.786207
100,-184.5124,-199.247116,0.737931
150,-227.9012,-197.745209,0.731034
200,-222.6662,-199.675293,0.724138
250,-224.7321,-198.669144,0.724138



Training completed!
Final training loss: -207.9243
Steps completed: 250

Evaluating on test set...


Test WER: 0.7152

Saving the improved model...
Improved model saved!
