# Code-switching Pipeline POC

This is heavily based on <a href="https://github.com/openai/whisper/blob/main/notebooks/Multilingual_ASR.ipynb">this notebook</a>, and aims to combine the different tasks whisper is trained on to gather multilingual transcriptions.



Key idea:
For each frame,

In [None]:
from IPython.display import display, Audio, HTML

In [None]:
import torch
import transformers
import datasets

In [None]:
device = 'CUDA' if torch.cuda.is_available() else 'CPU'

In [None]:
# Load code switching dataset (e.g. ASCEND)
dataset = datasets.load_dataset('CAiRE/ASCEND')

In [None]:
# Give example
ex = dataset['train'][2]
SAMPLING_RATE = ex['audio']['sampling_rate']

display(HTML('<h1> Example Audio Segment</h1><hr>'))
display(Audio(ex['audio']['array'], rate=SAMPLING_RATE))
display(HTML(f"Transcription: {ex['transcription']}"))

In this example we have an example of true code switching, where the utterance goes from chinese language to an english phrase and back to chinese particles.

In a zero-shot setting, we achieve the following results:

In [None]:
from transformers import WhisperForConditionalGeneration, WhisperProcessor

In [None]:
processor = WhisperProcessor.from_pretrained('openai/whisper-medium')
model = WhisperForConditionalGeneration.from_pretrained('openai/whisper-medium')

In [None]:
input_features = processor(ex['audio']['array'], sampling_rate=SAMPLING_RATE, return_tensors='pt').input_features

In [None]:
pred = model.generate(input_features)

In [None]:
pred

In [None]:
transcription = processor.batch_decode(pred)

In [None]:
transcription

# NOTES

To do:
Test how WhisperModel works shape wise
Get frame/timestamp level language identification
Get transcription for each language
Collapse function

https://arxiv.org/abs/2006.11477

https://arxiv.org/pdf/2112.06223

https://arxiv.org/abs/2409.09543

# Baseline Evaluation and Fine-tuned evaluation

In [1]:
# Baseline Evaluation
import torch
from transformers import AutoModelForSpeechSeq2Seq, WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
from datasets import Audio, load_dataset
import evaluate

dataset = load_dataset("CAiRE/ASCEND")

common_voice = dataset.cast_column("audio", Audio(sampling_rate=16000))
feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-large-v3-turbo")
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3-turbo", task="transcribe")
tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-large-v3-turbo", task="transcribe")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
torch.set_default_dtype(torch_dtype)

model_id = "openai/whisper-large-v3-turbo"

model = AutoModelForSpeechSeq2Seq.from_pretrained(
    model_id, torch_dtype=torch_dtype, use_safetensors=True
)
model = model.to(device)
model = model.float()



  from .autonotebook import tqdm as notebook_tqdm


In [2]:

def prepare_dataset(batch, extractor, tokenizer):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array 
    batch["input_features"] = extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids 
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch


common_voice = common_voice.map(prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=4, 
                                fn_kwargs={'extractor':feature_extractor, 'tokenizer':tokenizer})
model.generation_config.task = "transcribe"
model.generation_config.forced_decoder_ids = None


In [3]:

def compute_metrics(pred, metric = evaluate.load("wer")):

    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [4]:
from dataclasses import dataclass
from typing import Any, Dict, List, Union

@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch


data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

In [5]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./finetune",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    learning_rate=1e-5,
    warmup_steps=10,
    max_steps=500,
    gradient_checkpointing=True,
    eval_strategy="steps",
    per_device_eval_batch_size=2,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=100,
    eval_steps=100,
    logging_steps=25,
    report_to=["tensorboard"],
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    fp16=False,     # MPS에서 fp16 비활성화
    remove_unused_columns=False
)

train_dataset = common_voice["train"].select(range(2))
eval_dataset = common_voice["test"].select(range(2))
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    processing_class=processor.feature_extractor,
)

# Start training
trainer.evaluate()
trainer.train()
trainer.evaluate()

max_steps is given, it will override any value given in num_train_epochs
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.43.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
100%|██████████| 1/1 [00:00<00:00,  8.71it/s]
  0%|          | 0/500 [00:00<?, ?it/s]`use_cache = True` is incompatib

{'loss': 0.0658, 'grad_norm': 0.0028760205022990704, 'learning_rate': 9.693877551020408e-06, 'epoch': 25.0}


 10%|█         | 50/500 [1:10:38<7:57:23, 63.65s/it]  

{'loss': 0.0, 'grad_norm': 0.0002822221431415528, 'learning_rate': 9.183673469387756e-06, 'epoch': 50.0}


 15%|█▌        | 75/500 [1:24:00<3:40:14, 31.09s/it]

{'loss': 0.0, 'grad_norm': 0.0001588374434504658, 'learning_rate': 8.673469387755103e-06, 'epoch': 75.0}


 20%|██        | 100/500 [1:36:58<3:27:13, 31.08s/it]

{'loss': 0.0, 'grad_norm': 0.00012207365944050252, 'learning_rate': 8.16326530612245e-06, 'epoch': 100.0}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
                                                     
 20%|██        | 100/500 [1:37:09<3:27:13, 31.08s/it]

{'eval_loss': 1.45229172706604, 'eval_model_preparation_time': 0.0071, 'eval_wer': 100.0, 'eval_runtime': 11.8075, 'eval_samples_per_second': 0.169, 'eval_steps_per_second': 0.085, 'epoch': 100.0}


 25%|██▌       | 125/500 [1:50:21<3:14:17, 31.09s/it]

{'loss': 0.0, 'grad_norm': 0.0001010647029033862, 'learning_rate': 7.653061224489796e-06, 'epoch': 125.0}


 30%|███       | 150/500 [2:03:18<3:01:21, 31.09s/it]

{'loss': 0.0, 'grad_norm': 8.74739489518106e-05, 'learning_rate': 7.1428571428571436e-06, 'epoch': 150.0}


 35%|███▌      | 175/500 [2:16:16<2:48:21, 31.08s/it]

{'loss': 0.0, 'grad_norm': 7.836943404981866e-05, 'learning_rate': 6.63265306122449e-06, 'epoch': 175.0}


 40%|████      | 200/500 [2:29:14<2:35:24, 31.08s/it]

{'loss': 0.0, 'grad_norm': 7.061767246341333e-05, 'learning_rate': 6.122448979591837e-06, 'epoch': 200.0}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
                                                     
 40%|████      | 200/500 [2:29:25<2:35:24, 31.08s/it]

{'eval_loss': 1.45560622215271, 'eval_model_preparation_time': 0.0071, 'eval_wer': 100.0, 'eval_runtime': 11.8079, 'eval_samples_per_second': 0.169, 'eval_steps_per_second': 0.085, 'epoch': 200.0}


 45%|████▌     | 225/500 [2:42:37<2:22:40, 31.13s/it]

{'loss': 0.0, 'grad_norm': 6.453368405345827e-05, 'learning_rate': 5.6122448979591834e-06, 'epoch': 225.0}


 50%|█████     | 250/500 [2:55:34<2:09:32, 31.09s/it]

{'loss': 0.0, 'grad_norm': 6.028938514646143e-05, 'learning_rate': 5.1020408163265315e-06, 'epoch': 250.0}


 55%|█████▌    | 275/500 [3:08:32<1:56:34, 31.09s/it]

{'loss': 0.0, 'grad_norm': 5.703227361664176e-05, 'learning_rate': 4.591836734693878e-06, 'epoch': 275.0}


 60%|██████    | 300/500 [3:21:30<1:43:36, 31.08s/it]

{'loss': 0.0, 'grad_norm': 5.411282108980231e-05, 'learning_rate': 4.081632653061225e-06, 'epoch': 300.0}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
                                                     
 60%|██████    | 300/500 [3:21:41<1:43:36, 31.08s/it]

{'eval_loss': 1.467933177947998, 'eval_model_preparation_time': 0.0071, 'eval_wer': 100.0, 'eval_runtime': 11.7894, 'eval_samples_per_second': 0.17, 'eval_steps_per_second': 0.085, 'epoch': 300.0}


 65%|██████▌   | 325/500 [3:34:52<1:30:40, 31.09s/it]

{'loss': 0.0, 'grad_norm': 5.1301278290338814e-05, 'learning_rate': 3.5714285714285718e-06, 'epoch': 325.0}


 70%|███████   | 350/500 [3:47:50<1:17:44, 31.10s/it]

{'loss': 0.0, 'grad_norm': 4.96740140079055e-05, 'learning_rate': 3.0612244897959185e-06, 'epoch': 350.0}


 75%|███████▌  | 375/500 [4:00:48<1:04:44, 31.08s/it]

{'loss': 0.0, 'grad_norm': 4.7977020585676655e-05, 'learning_rate': 2.5510204081632657e-06, 'epoch': 375.0}


 80%|████████  | 400/500 [4:13:45<51:49, 31.09s/it]  

{'loss': 0.0, 'grad_norm': 4.667815665015951e-05, 'learning_rate': 2.0408163265306125e-06, 'epoch': 400.0}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
                                                   
 80%|████████  | 400/500 [4:13:57<51:49, 31.09s/it]

{'eval_loss': 1.4767768383026123, 'eval_model_preparation_time': 0.0071, 'eval_wer': 100.0, 'eval_runtime': 11.8142, 'eval_samples_per_second': 0.169, 'eval_steps_per_second': 0.085, 'epoch': 400.0}


 85%|████████▌ | 425/500 [4:27:08<38:51, 31.09s/it]  

{'loss': 0.0, 'grad_norm': 4.547726348391734e-05, 'learning_rate': 1.5306122448979593e-06, 'epoch': 425.0}


 90%|█████████ | 450/500 [4:40:06<25:54, 31.09s/it]

{'loss': 0.0, 'grad_norm': 4.490283390623517e-05, 'learning_rate': 1.0204081632653063e-06, 'epoch': 450.0}


 95%|█████████▌| 475/500 [4:53:03<12:58, 31.13s/it]

{'loss': 0.0, 'grad_norm': 4.449974949238822e-05, 'learning_rate': 5.102040816326531e-07, 'epoch': 475.0}


100%|██████████| 500/500 [5:06:01<00:00, 31.11s/it]

{'loss': 0.0, 'grad_norm': 4.429456384968944e-05, 'learning_rate': 0.0, 'epoch': 500.0}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
                                                   
100%|██████████| 500/500 [5:06:13<00:00, 31.11s/it]

{'eval_loss': 1.4796991348266602, 'eval_model_preparation_time': 0.0071, 'eval_wer': 100.0, 'eval_runtime': 11.7988, 'eval_samples_per_second': 0.17, 'eval_steps_per_second': 0.085, 'epoch': 500.0}


There were missing keys in the checkpoint model loaded: ['proj_out.weight'].
100%|██████████| 500/500 [5:06:31<00:00, 36.78s/it]


{'train_runtime': 18391.2999, 'train_samples_per_second': 0.109, 'train_steps_per_second': 0.027, 'train_loss': 0.003293265566131595, 'epoch': 500.0}


Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
Trainer.tokenizer is now deprecated. You should use Trainer.processing_class instead.
100%|██████████| 1/1 [00:00<00:00,  8.68it/s]


{'eval_loss': 1.45229172706604,
 'eval_model_preparation_time': 0.0071,
 'eval_wer': 100.0,
 'eval_runtime': 12.0499,
 'eval_samples_per_second': 0.166,
 'eval_steps_per_second': 0.083,
 'epoch': 500.0}