In [1]:
import os
import datasets 
from datasets import Dataset
from collections import defaultdict
import re
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union
from safetensors import safe_open


In [2]:
dataset_path = '/vision/u/eatang/ml_superb/eighth_version/'
languages = []

sources = os.listdir(dataset_path)

for source in sources:
    if source[0] != ".":
        languages.extend(os.listdir(os.path.join(dataset_path, source)))
        
languages = set([x for x in languages if '.' not in x])

In [82]:
def remove_punctuation(text):
    return re.sub(r'[^\w\s]', '', text)


all_paths = {}
all_sentences = {}
for duration in ["10min", "1h"]:
    for split in ["train", "test"]:
        language_to_paths = defaultdict(list)
        language_to_sentences = defaultdict(list)
        for language in languages:
            for source in sources:
                source_lang_path = os.path.join(dataset_path, source, language)
                if os.path.exists(os.path.join(source_lang_path, f'transcript_{duration}_{split}.txt')):
                    with open(os.path.join(source_lang_path, f'transcript_{duration}_{split}.txt'), 'r') as file:
                        lines = [line.rstrip() for line in file]
                        sentences = []
                        paths = []
                        for line in lines:
                            sentence = " ".join(re.split(r'[ \t]+', line)[2:])
                            sentence = remove_punctuation(sentence).lower().strip()
                            if len(sentence) <= 1:
                                continue
                            if len(re.split(r'[ \t]+', line)[0]) > 0:
                                sentences.append(sentence)
                                paths.append(os.path.join(source_lang_path, 'wav', re.split(r'[ \t]+', line)[0] + '.wav'))

                        language_to_paths[language].extend(paths)
                        language_to_sentences[language].extend(sentences)
        all_paths[duration + split] = language_to_paths
        all_sentences[duration + split] = language_to_sentences

In [55]:
from transformers import WhisperProcessor,WhisperForConditionalGeneration, Wav2Vec2Processor, Wav2Vec2ForCTC
import torchaudio
from datasets import Dataset
import torch
import os
import numpy as np
from jiwer import wer

# Function to load and preprocess audio files
def load_audio(path):
    speech, _ = torchaudio.load(path)
    return speech.squeeze().numpy()

# Preprocess the dataset
def preprocess(batch):
    audio = [load_audio(path) for path in batch["audio"]]
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt").to("cuda")
    labels = processor.tokenizer(text=batch["sentence"], return_tensors="pt", padding=True).input_ids
    inputs["labels"] = labels.to("cuda")
    return inputs

# Function to decode model predictions
def decode_predictions(pred_ids):
    pred_ids = pred_ids.cpu().numpy()
    pred_str = processor.batch_decode(pred_ids)
    return pred_str

# Evaluate WER
def compute_metrics(batch):
    label_ids = batch.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_ids = batch.predictions
    
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    pred_str = [remove_punctuation(x).lower().strip() for x in pred_str]
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    
    wer_score = wer(label_str, pred_str)
    return {"wer": wer_score}

# Evaluate WER
def compute_wer(batch, language=LANG):
    inputs = {key: batch[key].to("cuda") for key in batch if key != "audio" and key != "sentence"}

    with torch.no_grad():
        pred_ids = model.generate(inputs["input_features"], language=language)
    
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    pred_str = [remove_punctuation(x).lower().strip() for x in pred_str]
    label_str = processor.batch_decode(batch["labels"].cpu().numpy(), skip_special_tokens=True)
    
    wer_score = wer(label_str, pred_str)
    return wer_score



In [48]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


In [101]:
# LANG = "swahili"
# MODEL_ID = "openai/whisper-tiny"
# processor = WhisperProcessor.from_pretrained(MODEL_ID, language=LANG)
# model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to("cuda")

languages = ['ssw', 'swa', 'xho']

# Prepare the data for Hugging Face datasets
train_data = {
    "audio": [],
    "sentence": []
}

for language in languages: 
    test_data = {
        "audio": [],
        "sentence": []
    }
#     for path, sentence in zip(all_paths["1htrain"][language], all_sentences["1htrain"][language]):
#         train_data["audio"].append(path)
#         train_data["sentence"].append(sentence)
    for path, sentence in zip(all_paths["10mintest"][language], all_sentences["10mintest"][language]):
        test_data["audio"].append(path)
        test_data["sentence"].append(sentence)

    # test_data["audio"] = test_data["audio"][:32]
    # test_data["sentence"] = test_data["sentence"][:32]

    # Create a Hugging Face dataset
#     train_dataset = Dataset.from_dict(train_data)
    test_dataset = Dataset.from_dict(test_data)

#     train_set = train_dataset.map(preprocess, batched=True, batch_size=32).with_format("torch")
    test_set = test_dataset.map(preprocess, batched=True, batch_size=32).with_format("torch")
#     train_dataloader = DataLoader(train_set, batch_size=32)
    test_dataloader = DataLoader(test_set, batch_size=32)
    
    model.generation_config.language = LANG
    model.generation_config.task = "transcribe"
    model.generation_config.forced_decoder_ids = None

    # Compute WER for the entire dataset
    wer_scores = []

    # Use tqdm to wrap your dataloader to show a progress bar
    for batch in tqdm(test_dataloader, desc="Processing batches"):
        wer_scores.append(compute_wer(batch, language=LANG))

    average_wer = np.mean(wer_scores)
    print(f"{language} Average WER: {average_wer}")


Map:   0%|          | 0/116 [00:00<?, ? examples/s]

Processing batches: 100%|████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:06<00:00,  1.68s/it]


ssw Average WER: 0.5841399899541198


Map:   0%|          | 0/317 [00:00<?, ? examples/s]

Processing batches: 100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:15<00:00,  1.55s/it]


swa Average WER: 0.5750793521499591


Map:   0%|          | 0/314 [00:00<?, ? examples/s]

Processing batches: 100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:14<00:00,  1.41s/it]

xho Average WER: 0.6668938540877771





In [91]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./",  # change to a repo name of your choice
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-4,
    warmup_steps=50,
    max_steps=300,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=64,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=100,
    logging_steps=10,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)



In [92]:
from transformers import Seq2SeqTrainer

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_set,
    eval_dataset=test_set,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs


In [93]:
trainer.train()

Step,Training Loss,Validation Loss,Wer
100,0.5065,0.662768,0.725646
200,0.2743,0.570663,0.610344
300,0.0933,0.559565,0.605384


TrainOutput(global_step=300, training_loss=0.6795074792702993, metrics={'train_runtime': 582.7457, 'train_samples_per_second': 32.947, 'train_steps_per_second': 0.515, 'total_flos': 4.692359503872e+17, 'train_loss': 0.6795074792702993, 'epoch': 4.285714285714286})

In [95]:
model.generation_config.language = LANG
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

# Compute WER for the entire dataset
wer_scores = []

# Use tqdm to wrap your dataloader to show a progress bar
for batch in tqdm(test_dataloader, desc="Processing batches"):
    wer_scores.append(compute_wer(batch, language=LANG))

average_wer = np.mean(wer_scores)
print(f"{language} Average WER: {average_wer}")

Processing batches: 100%|██████████████████████████████████████████████████████████████████████████████████| 24/24 [00:38<00:00,  1.59s/it]

xho Average WER: 0.6248699750350295





In [106]:
for batch in dataloader:
    inputs = {key: batch[key].to("cuda") for key in batch if key != "audio" and key != "sentence"}
    model.generation_config.language = "swahili"
    model.generation_config.task = "transcribe"


    with torch.no_grad():
        pred_ids = model.generate(inputs["input_features"], language="swahili", )

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    pred_str = [remove_punctuation(x).lower().strip() for x in pred_str]
    label_str = processor.batch_decode(batch["labels"].cpu().numpy(), skip_special_tokens=True)

    wer_score = wer(label_str, pred_str)
    break

In [131]:
i = 6

In [132]:
pred_str[i]

'mbunji inum blonganjin'

In [133]:
label_str[i]

'mbônji i nnumb loñge njiñ'

In [134]:
wer([label_str[i]], [pred_str[i]])

1.0