In [1]:
import os
import datasets 
from datasets import Dataset
from collections import defaultdict
import re
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch

from dataclasses import dataclass
from typing import Any, Dict, List, Union
from safetensors import safe_open
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Optional, Tuple

In [2]:
dataset_path = '/vision/u/eatang/ml_superb/eighth_version/'
languages = []

sources = os.listdir(dataset_path)

for source in sources:
    if source[0] != ".":
        languages.extend(os.listdir(os.path.join(dataset_path, source)))
        
languages = set([x for x in languages if '.' not in x])

In [3]:
def remove_punctuation(text):
    return re.sub(r'[^\w\s]', '', text)


all_paths = {}
all_sentences = {}
for duration in ["10min", "1h"]:
    for split in ["train", "test"]:
        language_to_paths = defaultdict(list)
        language_to_sentences = defaultdict(list)
        for language in languages:
            for source in sources:
                source_lang_path = os.path.join(dataset_path, source, language)
                if os.path.exists(os.path.join(source_lang_path, f'transcript_{duration}_{split}.txt')):
                    with open(os.path.join(source_lang_path, f'transcript_{duration}_{split}.txt'), 'r') as file:
                        lines = [line.rstrip() for line in file]
                        sentences = []
                        paths = []
                        for line in lines:
                            sentence = " ".join(re.split(r'[ \t]+', line)[2:])
                            sentence = remove_punctuation(sentence).lower().strip()
                            if len(sentence) <= 1:
                                continue
                            if len(re.split(r'[ \t]+', line)[0]) > 0:
                                sentences.append(sentence)
                                paths.append(os.path.join(source_lang_path, 'wav', re.split(r'[ \t]+', line)[0] + '.wav'))

                        language_to_paths[language].extend(paths)
                        language_to_sentences[language].extend(sentences)
        all_paths[duration + split] = language_to_paths
        all_sentences[duration + split] = language_to_sentences

In [4]:
all_paths['10mintrain'].keys()

dict_keys(['glg', 'kab', 'lav', 'pan', 'ind', 'deu', 'tgk', 'ben', 'slv', 'yor', 'vie', 'orm', 'slk', 'msa', 'hin', 'ven', 'uig', 'nld', 'ast', 'kan', 'mon', 'khm', 'rus', 'mlt', 'srp', 'fin', 'sun', 'por', 'hun', 'est', 'tso', 'tur', 'tos', 'urd', 'uzb', 'yue', 'sna', 'kir', 'bul', 'org_jpn', 'cnh', 'ina', 'fil', 'ara', 'heb', 'wol', 'ceb', 'mya', 'fas', 'zul', 'aze', 'ltz', 'ori', 'xty', 'eus', 'ron', 'ell', 'pus', 'swa', 'nep', 'tha', 'nno', 'bre', 'sin', 'hsb', 'isl', 'hye', 'umb', 'epo', 'kmr', 'nan', 'pol', 'ssw', 'mhr', 'kor', 'sah', 'afr', 'amh', 'fra', 'cat', 'nya', 'snd', 'hau', 'asm', 'lug', 'kaz', 'mrj', 'lit', 'tsn', 'myv', 'ibo', 'hrv', 'luo', 'ckb', 'grn', 'sot', 'kat', 'bel', 'tam', 'lao', 'guj', 'div', 'mal', 'xho', 'ful', 'kea', 'ita', 'frr', 'cym', 'lga', 'tok', 'ces', 'cmn', 'eng', 'gle', 'bos', 'mri', 'kin', 'jpn', 'tel', 'chv', 'dan', 'ukr', 'bas', 'oci', 'jav', 'som', 'lin', 'nso', 'skr', 'ory', 'swe', 'bak', 'abk', 'mar', 'tat', 'mkd', 'nob', 'nbl', 'spa', 'kam'

In [4]:
from transformers import WhisperProcessor,WhisperForConditionalGeneration, Wav2Vec2Processor, Wav2Vec2ForCTC
import torchaudio
from datasets import Dataset
import torch
import os
import numpy as np
from jiwer import wer

# Function to load and preprocess audio files
def load_audio(path):
    speech, _ = torchaudio.load(path)
    return speech.squeeze().numpy()

# Preprocess the dataset
def preprocess(batch):
    audio = [load_audio(path) for path in batch["audio"]]
    inputs = processor(audio, sampling_rate=16000, return_tensors="pt").to("cuda")
    labels = processor.tokenizer(text=batch["sentence"], return_tensors="pt", padding=True).input_ids
    inputs["labels"] = labels.to("cuda")
    return inputs

# Function to decode model predictions
def decode_predictions(pred_ids):
    pred_ids = pred_ids.cpu().numpy()
    pred_str = processor.batch_decode(pred_ids)
    return pred_str

# Evaluate WER
def compute_metrics(batch):
    label_ids = batch.label_ids
    label_ids[label_ids == -100] = processor.tokenizer.pad_token_id
    pred_ids = batch.predictions
    
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    pred_str = [remove_punctuation(x).lower().strip() for x in pred_str]
    label_str = processor.batch_decode(label_ids, skip_special_tokens=True)
    
    wer_score = wer(label_str, pred_str)
    return {"wer": wer_score}

# Evaluate WER
def compute_wer(batch, language="swahili"):
    inputs = {key: batch[key].to("cuda") for key in batch if key != "audio" and key != "sentence"}

    with torch.no_grad():
        pred_ids = model.generate(inputs["input_features"], language=language)
    
    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    pred_str = [remove_punctuation(x).lower().strip() for x in pred_str]
    label_str = processor.batch_decode(batch["labels"].cpu().numpy(), skip_special_tokens=True)
    
    wer_score = wer(label_str, pred_str)
    return wer_score



In [5]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels
        batch["group_idx"] = torch.LongTensor([feature["group_idx"] for feature in features])
        return batch


In [6]:
LANG = "swahili"
MODEL_ID = "openai/whisper-tiny"
processor = WhisperProcessor.from_pretrained(MODEL_ID, language=LANG)
model = WhisperForConditionalGeneration.from_pretrained(MODEL_ID).to("cuda")

languages = ['ssw', 'xho', 'swa']

# Prepare the data for Hugging Face datasets
train_data = {
    "audio": [],
    "sentence": [],
    "group_idx": []
}

test_data = {
    "audio": [],
    "sentence": [],
    "group_idx": []
}

group_str_to_idx = {
    "ssw" : 0,
    "xho" : 1,
    "swa" : 2,
}

for language in languages: 
    for path, sentence in zip(all_paths["10mintrain"][language], all_sentences["10mintrain"][language]):
        train_data["audio"].append(path)
        train_data["sentence"].append(sentence)
        train_data["group_idx"].append(group_str_to_idx[language])
    for path, sentence in zip(all_paths["10mintest"][language], all_sentences["10mintest"][language]):
        test_data["audio"].append(path)
        test_data["sentence"].append(sentence)
        test_data["group_idx"].append(group_str_to_idx[language])

# Create a Hugging Face dataset
train_dataset = Dataset.from_dict(train_data)
test_dataset = Dataset.from_dict(test_data)

train_set = train_dataset.map(preprocess, batched=True, batch_size=32).with_format("torch")
test_set = test_dataset.map(preprocess, batched=True, batch_size=32).with_format("torch")
train_dataloader = DataLoader(train_set, batch_size=32)
test_dataloader = DataLoader(test_set, batch_size=32)


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/755 [00:00<?, ? examples/s]

Map:   0%|          | 0/747 [00:00<?, ? examples/s]

In [7]:
class LossComputer:
    def __init__(self, criterion, is_robust, n_groups, group_counts, group_str_fn, alpha=None, gamma=0.1, adj=None, min_var_weight=0, step_size=0.01, normalize_loss=False, btl=False):
        self.criterion = criterion
        self.is_robust = is_robust
        self.gamma = gamma
        self.alpha = alpha
        self.min_var_weight = min_var_weight
        self.step_size = step_size
        self.normalize_loss = normalize_loss
        self.btl = btl

        self.n_groups = n_groups
        self.group_counts = group_counts.cuda()
        self.group_frac = self.group_counts/self.group_counts.sum()
        self.group_str = group_str_fn

        if adj is not None:
            self.adj = torch.from_numpy(adj).float().cuda()
        else:
            self.adj = torch.zeros(self.n_groups).float().cuda()

        if is_robust:
            assert alpha, 'alpha must be specified'

        # quantities maintained throughout training
        self.adv_probs = torch.ones(self.n_groups).cuda()/self.n_groups
        self.exp_avg_loss = torch.zeros(self.n_groups).cuda()
        self.exp_avg_initialized = torch.zeros(self.n_groups).byte().cuda()

        self.reset_stats()

    def loss(self, yhat, y, group_idx=None, is_training=False):
        # compute per-sample and per-group losses
#         print(yhat.shape, y.shape)
        per_sample_losses = self.criterion(yhat, y).mean(1)
        group_loss, group_count = self.compute_group_avg(per_sample_losses, group_idx)
        group_acc, group_count = self.compute_group_avg((torch.argmax(yhat,1)==y).float().mean(1), group_idx)

        # update historical losses
        self.update_exp_avg_loss(group_loss, group_count)

        # compute overall loss
        if self.is_robust and not self.btl:
            actual_loss, weights = self.compute_robust_loss(group_loss, group_count)
        elif self.is_robust and self.btl:
             actual_loss, weights = self.compute_robust_loss_btl(group_loss, group_count)
        else:
            actual_loss = per_sample_losses.mean()
            weights = None

        # update stats
        self.update_stats(actual_loss, group_loss, group_acc, group_count, weights)

        return actual_loss

    def compute_robust_loss(self, group_loss, group_count):
        adjusted_loss = group_loss
        if torch.all(self.adj>0):
            adjusted_loss += self.adj/torch.sqrt(self.group_counts)
        if self.normalize_loss:
            adjusted_loss = adjusted_loss/(adjusted_loss.sum())
        self.adv_probs = self.adv_probs * torch.exp(self.step_size*adjusted_loss.data)
        self.adv_probs = self.adv_probs/(self.adv_probs.sum())

        robust_loss = group_loss @ self.adv_probs
        return robust_loss, self.adv_probs

    def compute_robust_loss_btl(self, group_loss, group_count):
        adjusted_loss = self.exp_avg_loss + self.adj/torch.sqrt(self.group_counts)
        return self.compute_robust_loss_greedy(group_loss, adjusted_loss)

    def compute_robust_loss_greedy(self, group_loss, ref_loss):
        sorted_idx = ref_loss.sort(descending=True)[1]
        sorted_loss = group_loss[sorted_idx]
        sorted_frac = self.group_frac[sorted_idx]

        mask = torch.cumsum(sorted_frac, dim=0)<=self.alpha
        weights = mask.float() * sorted_frac /self.alpha
        last_idx = mask.sum()
        weights[last_idx] = 1 - weights.sum()
        weights = sorted_frac*self.min_var_weight + weights*(1-self.min_var_weight)

        robust_loss = sorted_loss @ weights

        # sort the weights back
        _, unsort_idx = sorted_idx.sort()
        unsorted_weights = weights[unsort_idx]
        return robust_loss, unsorted_weights

    def compute_group_avg(self, losses, group_idx):
        # compute observed counts and mean loss for each group
#         print(losses.shape, group_idx.shape)
        group_map = (group_idx == torch.arange(self.n_groups).unsqueeze(1).long().cuda()).float()
        group_count = group_map.sum(1)
        group_denom = group_count + (group_count==0).float() # avoid nans
        group_loss = (group_map @ losses.view(-1))/group_denom
        return group_loss, group_count

    def update_exp_avg_loss(self, group_loss, group_count):
        prev_weights = (1 - self.gamma*(group_count>0).float()) * (self.exp_avg_initialized>0).float()
        curr_weights = 1 - prev_weights
        self.exp_avg_loss = self.exp_avg_loss * prev_weights + group_loss*curr_weights
        self.exp_avg_initialized = (self.exp_avg_initialized>0) + (group_count>0)

    def reset_stats(self):
        self.processed_data_counts = torch.zeros(self.n_groups).cuda()
        self.update_data_counts = torch.zeros(self.n_groups).cuda()
        self.update_batch_counts = torch.zeros(self.n_groups).cuda()
        self.avg_group_loss = torch.zeros(self.n_groups).cuda()
        self.avg_group_acc = torch.zeros(self.n_groups).cuda()
        self.avg_per_sample_loss = 0.
        self.avg_actual_loss = 0.
        self.avg_acc = 0.
        self.batch_count = 0.

    def update_stats(self, actual_loss, group_loss, group_acc, group_count, weights=None):
        # avg group loss
        denom = self.processed_data_counts + group_count
        denom += (denom==0).float()
        prev_weight = self.processed_data_counts/denom
        curr_weight = group_count/denom
        self.avg_group_loss = prev_weight*self.avg_group_loss + curr_weight*group_loss

        # avg group acc
        self.avg_group_acc = prev_weight*self.avg_group_acc + curr_weight*group_acc

        # batch-wise average actual loss
        denom = self.batch_count + 1
        self.avg_actual_loss = (self.batch_count/denom)*self.avg_actual_loss + (1/denom)*actual_loss

        # counts
        self.processed_data_counts += group_count
        if self.is_robust:
            self.update_data_counts += group_count*((weights>0).float())
            self.update_batch_counts += ((group_count*weights)>0).float()
        else:
            self.update_data_counts += group_count
            self.update_batch_counts += (group_count>0).float()
        self.batch_count+=1

        # avg per-sample quantities
        group_frac = self.processed_data_counts/(self.processed_data_counts.sum())
        self.avg_per_sample_loss = group_frac @ self.avg_group_loss
        self.avg_acc = group_frac @ self.avg_group_acc

    def get_model_stats(self, model, args, stats_dict):
        model_norm_sq = 0.
        for param in model.parameters():
            model_norm_sq += torch.norm(param) ** 2
        stats_dict['model_norm_sq'] = model_norm_sq.item()
        stats_dict['reg_loss'] = args.weight_decay / 2 * model_norm_sq.item()
        return stats_dict

    def get_stats(self, model=None, args=None):
        stats_dict = {}
        for idx in range(self.n_groups):
            stats_dict[f'avg_loss_group:{idx}'] = self.avg_group_loss[idx].item()
            stats_dict[f'exp_avg_loss_group:{idx}'] = self.exp_avg_loss[idx].item()
            stats_dict[f'avg_acc_group:{idx}'] = self.avg_group_acc[idx].item()
            stats_dict[f'processed_data_count_group:{idx}'] = self.processed_data_counts[idx].item()
            stats_dict[f'update_data_count_group:{idx}'] = self.update_data_counts[idx].item()
            stats_dict[f'update_batch_count_group:{idx}'] = self.update_batch_counts[idx].item()

        stats_dict['avg_actual_loss'] = self.avg_actual_loss.item()
        stats_dict['avg_per_sample_loss'] = self.avg_per_sample_loss.item()
        stats_dict['avg_acc'] = self.avg_acc.item()

        # Model stats
        if model is not None:
            assert args is not None
            stats_dict = self.get_model_stats(model, args, stats_dict)

        return stats_dict

    def log_stats(self, logger, is_training):
        if logger is None:
            return

        logger.write(f'Average incurred loss: {self.avg_per_sample_loss.item():.3f}  \n')
        logger.write(f'Average sample loss: {self.avg_actual_loss.item():.3f}  \n')
        logger.write(f'Average acc: {self.avg_acc.item():.3f}  \n')
        for group_idx in range(self.n_groups):
            logger.write(
                f'  {self.group_str(group_idx)}  '
                f'[n = {int(self.processed_data_counts[group_idx])}]:\t'
                f'loss = {self.avg_group_loss[group_idx]:.3f}  '
                f'exp loss = {self.exp_avg_loss[group_idx]:.3f}  '
                f'adjusted loss = {self.exp_avg_loss[group_idx] + self.adj[group_idx]/torch.sqrt(self.group_counts)[group_idx]:.3f}  '
                f'adv prob = {self.adv_probs[group_idx]:3f}   '
                f'acc = {self.avg_group_acc[group_idx]:.3f}\n')
        logger.flush()

In [8]:
def group_str_fn(group_idx):
    x = {
        0: "ssw",
        1: "xho",
        2: "tsn"
    }

    return x[group_idx]


train_loss_computer = LossComputer(torch.nn.CrossEntropyLoss(reduction='none'), True, 3, torch.Tensor([5,5,5]), group_str_fn, alpha=0.2)
val_loss_computer = LossComputer(torch.nn.CrossEntropyLoss(reduction='none'), True, 3, torch.Tensor([5,5,5]), group_str_fn, alpha=0.2)

def compute_loss(model, inputs, return_outputs=False):
    inputs_no_idx = {k: v for k, v in inputs.items() if k != "group_idx"}
    outputs = model(**inputs_no_idx).logits
    if model.training:
        return train_loss_computer.loss(outputs.permute(0,2,1), inputs["labels"], inputs["group_idx"])
    else:
        return val_loss_computer.loss(outputs.permute(0,2,1), inputs["labels"], inputs["group_idx"])




In [25]:
def prediction_step(
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        **gen_kwargs,
    ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """
        Perform an evaluation step on `model` using `inputs`.

        Subclass and override to inject custom behavior.

        Args:
            model (`nn.Module`):
                The model to evaluate.
            inputs (`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument `labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (`bool`):
                Whether or not to return the loss only.
            gen_kwargs:
                Additional `generate` specific kwargs.

        Return:
            Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
            labels (each being optional).
        """

        if not self.args.predict_with_generate or prediction_loss_only:
            return super().prediction_step(
                model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
            )

        has_labels = "labels" in inputs
        inputs = self._prepare_inputs(inputs)

        # Priority (handled in generate):
        # non-`None` gen_kwargs > model.generation_config > default GenerationConfig()
        if len(gen_kwargs) == 0 and hasattr(self, "_gen_kwargs"):
            gen_kwargs = self._gen_kwargs.copy()
        if "num_beams" in gen_kwargs and gen_kwargs["num_beams"] is None:
            gen_kwargs.pop("num_beams")
        if "max_length" in gen_kwargs and gen_kwargs["max_length"] is None:
            gen_kwargs.pop("max_length")

        default_synced_gpus = True if is_deepspeed_zero3_enabled() else False
        gen_kwargs["synced_gpus"] = (
            gen_kwargs["synced_gpus"] if gen_kwargs.get("synced_gpus") is not None else default_synced_gpus
        )

        generation_inputs = inputs.copy()
        # If the `decoder_input_ids` was created from `labels`, evict the former, so that the model can freely generate
        # (otherwise, it would continue generating from the padded `decoder_input_ids`)
        if (
            "labels" in generation_inputs
            and "decoder_input_ids" in generation_inputs
            and generation_inputs["labels"].shape == generation_inputs["decoder_input_ids"].shape
        ):
            generation_inputs = {
                k: v for k, v in inputs.items() if k not in ("decoder_input_ids", "decoder_attention_mask")
            }
        generated_tokens = self.model.generate(**generation_inputs, **gen_kwargs)

        # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
        # TODO: remove this hack when the legacy code that initializes generation_config from a model config is
        # removed in https://github.com/huggingface/transformers/blob/98d88b23f54e5a23e741833f1e973fdf600cc2c5/src/transformers/generation/utils.py#L1183
        if self.model.generation_config._from_model_config:
            self.model.generation_config._from_model_config = False

        # Retrieves GenerationConfig from model.generation_config
        gen_config = self.model.generation_config
        # in case the batch is shorter than max length, the output should be padded
        if generated_tokens.shape[-1] < gen_config.max_length:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_length)
        elif gen_config.max_new_tokens is not None and generated_tokens.shape[-1] < gen_config.max_new_tokens + 1:
            generated_tokens = self._pad_tensors_to_max_len(generated_tokens, gen_config.max_new_tokens + 1)

        with torch.no_grad():
            if has_labels:
                with self.compute_loss_context_manager():
                    inputs_no_idx = {k: v for k, v in inputs.items() if k != "group_idx"}
                    outputs = model(**inputs_no_idx).logits
                loss = val_loss_computer.loss(outputs.permute(0,2,1), inputs["labels"], inputs["group_idx"])
#                 elif self.label_smoother is not None:
#                     loss = self.label_smoother(outputs, inputs["labels"]).mean().detach()
#                 else:
#                     loss = (outputs["loss"] if isinstance(outputs, dict) else outputs[0]).mean().detach()
            else:
                loss = None

        if self.args.prediction_loss_only:
            return loss, None, None

        if has_labels:
            labels = inputs["labels"]
            if labels.shape[-1] < gen_config.max_length:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_length)
            elif gen_config.max_new_tokens is not None and labels.shape[-1] < gen_config.max_new_tokens + 1:
                labels = self._pad_tensors_to_max_len(labels, gen_config.max_new_tokens + 1)
        else:
            labels = None

        return loss, generated_tokens, labels

In [26]:
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./",  # change to a repo name of your choice
    per_device_train_batch_size=64,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-4,
    warmup_steps=50,
    max_steps=300,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=64,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1,
    logging_steps=100,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)



In [27]:
from transformers import Seq2SeqTrainer

data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_set,
    eval_dataset=test_set,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
max_steps is given, it will override any value given in num_train_epochs


In [28]:
trainer._signature_columns =['input_features',
 'attention_mask',
 'decoder_input_ids',
 'decoder_attention_mask',
 'head_mask',
 'decoder_head_mask',
 'cross_attn_head_mask',
 'encoder_outputs',
 'past_key_values',
 'decoder_inputs_embeds',
 'decoder_position_ids',
 'labels',
 'use_cache',
 'output_attentions',
 'output_hidden_states',
 'return_dict',
 'labels',
 'label_ids',
 'label',
  'group_idx']

# /viscam/u/eatang/miniconda3/lib/python3.10/site-packages/transformers/generation/utils.py:1542
# comment out this line of code

In [29]:
trainer.compute_loss = compute_loss
trainer.prediction_step = prediction_step
trainer.train()

Step,Training Loss,Validation Loss


NameError: name 'self' is not defined

In [None]:
model.generation_config.language = LANG
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None

# Compute WER for the entire dataset
wer_scores = []

# Use tqdm to wrap your dataloader to show a progress bar
for batch in tqdm(test_dataloader, desc="Processing batches"):
    wer_scores.append(compute_wer(batch, language=LANG))

average_wer = np.mean(wer_scores)
print(f"{language} Average WER: {average_wer}")

In [106]:
for batch in dataloader:
    inputs = {key: batch[key].to("cuda") for key in batch if key != "audio" and key != "sentence"}
    model.generation_config.language = "swahili"
    model.generation_config.task = "transcribe"


    with torch.no_grad():
        pred_ids = model.generate(inputs["input_features"], language="swahili", )

    pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    pred_str = [remove_punctuation(x).lower().strip() for x in pred_str]
    label_str = processor.batch_decode(batch["labels"].cpu().numpy(), skip_special_tokens=True)

    wer_score = wer(label_str, pred_str)
    break

In [131]:
i = 6

In [132]:
pred_str[i]

'mbunji inum blonganjin'

In [133]:
label_str[i]

'mbônji i nnumb loñge njiñ'

In [134]:
wer([label_str[i]], [pred_str[i]])

1.0

In [18]:
def group_str_fn(group_idx):
    x = {
        0: "ssw",
        1: "xho",
        2: "tsn"
    }

    return x[group_idx]

# torch.manual_seed(42)

# y_hat = torch.rand((8, 32)).cuda()
# y = torch.randint(8,32,(8,)).cuda().long()
# group_idxs = torch.Tensor([0,0,1,2,1,0,1,0]).cuda()
# # group_idxs = torch.Tensor([0,0,0,0,0,0,0,0]).cuda()


# print(loss_computer.loss(y_hat, y, group_idxs))

# y_hat = torch.rand((8, 32)).cuda()
# y = torch.randint(8,32,(8,)).cuda().long()
# group_idxs = torch.Tensor([0,0,1,0,1,0,1,0]).cuda()
# # group_idxs = torch.Tensor([0,0,0,0,0,0,0,0]).cuda()


# print(loss_computer.loss(y_hat, y, group_idxs))

train_loss_computer = LossComputer(torch.nn.CrossEntropyLoss(reduction='none'), True, 3, torch.Tensor([5,5,5]), group_str_fn, alpha=0.2)
val_loss_computer = LossComputer(torch.nn.CrossEntropyLoss(reduction='none'), True, 3, torch.Tensor([5,5,5]), group_str_fn, alpha=0.2)

def compute_loss(model, inputs, return_outputs=False):
    inputs_no_idx = {k: v for k, v in inputs.items() if k != "group_idx"}
    outputs = model(**inputs_no_idx).logits
    if model.training:
        return train_loss_computer.loss(outputs.permute(0,2,1), inputs["labels"], inputs["group_idx"])
    else:
        return val_loss_computer.loss(outputs.permute(0,2,1), inputs["labels"], inputs["group_idx"])


