In [5]:
# from huggingface_hub import notebook_login

# notebook_login() #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

In [6]:
#! pip install accelerate -U

In [1]:
import os
os. chdir('../')

In [2]:
import torch
from dataclasses import dataclass
from typing import Any, Dict, List, Union
from constants.mir_constants import WAV2VEC2_ARGS
from datasets import Dataset,Audio,load_metric,load_dataset, DatasetDict
from transformers import Seq2SeqTrainingArguments,Seq2SeqTrainer
from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor, WhisperForConditionalGeneration
from training.wav2vec2_finetune import Wav2Vec2SpeechRecognition, SpeechRecognitionData
import json
from dataclasses import dataclass, asdict
import pandas as pd
from jiwer import wer

In [3]:
WAV2VEC2_ARGS.MODEL_BACKBONE = "openai/whisper-large"
WAV2VEC2_ARGS.MODEL_SAVE_PATH="/scratch/users/gmenon/model_artefacts/whisper-base.pt"
print(json.dumps(asdict(WAV2VEC2_ARGS), indent = 4))

{
    "TRAIN_FILE_PATH": "/scratch/users/gmenon/train_song_metadata_en_demucs_cleaned_filtered_095.csv",
    "TEST_FILE_PATH": "/scratch/users/gmenon/validation_song_metadata_en_demucs_cleaned_filtered_005.csv",
    "MODEL_BACKBONE": "openai/whisper-large",
    "BATCH_SIZE": 1,
    "NUM_EPOCHS": 15,
    "MODEL_SAVE_PATH": "/scratch/users/gmenon/model_artefacts/whisper-base.pt",
    "FINETUNE_STRATEGY": [
        "freeze_unfreeze",
        10
    ],
    "LR_SCHEDULER": "reduce_on_plateau_schedule"
}


In [4]:
train_data = pd.read_csv(WAV2VEC2_ARGS.TRAIN_FILE_PATH)
test_data = pd.read_csv(WAV2VEC2_ARGS.TEST_FILE_PATH)
train_audio_dataset = Dataset.from_dict({"audio": list(train_data["consolidated_file_path"]), "transcription": list(train_data["transcription"])}).cast_column("audio", Audio(sampling_rate=16000))
test_audio_dataset = Dataset.from_dict({"audio": list(test_data["consolidated_file_path"]), "transcription": list(test_data["transcription"])}).cast_column("audio", Audio(sampling_rate=16000))

train_audio_dataset[0]

{'audio': {'path': '/scratch/users/gmenon/wav_clips/separated/htdemucs/01cef35811fd4a3fa63a3ab8bba5430c/vocals.wav',
  'array': array([-0.01242626, -0.06679466, -0.00535162, ...,  0.19485795,
          0.20158702,  0.        ]),
  'sampling_rate': 16000},
 'transcription': "right about now i'm fifty fifty"}

In [7]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need different padding methods
        # first treat the audio inputs by simply returning torch tensors
        input_features = [{"input_features": feature["input_features"]} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        # get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
        # pad the labels to max length
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        # if bos token is appended in previous tokenization step,
        # cut bos token here as it's append later anyways
        if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch["labels"] = labels

        return batch

In [9]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(WAV2VEC2_ARGS.MODEL_BACKBONE)
tokenizer = WhisperTokenizer.from_pretrained(WAV2VEC2_ARGS.MODEL_BACKBONE, language="English", task="transcribe")
processor = WhisperProcessor.from_pretrained(WAV2VEC2_ARGS.MODEL_BACKBONE, language="English", task="transcribe")
model = WhisperForConditionalGeneration.from_pretrained(WAV2VEC2_ARGS.MODEL_BACKBONE)
data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)

# Update some model config
model.config.forced_decoder_ids = None
model.config.suppress_tokens = []

In [7]:
def prepare_dataset(batch):
    # load and resample audio data from 48 to 16kHz
    audio = batch["audio"]

    # compute log-Mel input features from input audio array
    batch["input_features"] = feature_extractor(audio["array"], sampling_rate=audio["sampling_rate"]).input_features[0]

    # encode target text to label ids
    batch["labels"] = tokenizer(batch["transcription"]).input_ids
    return batch

In [8]:
tokenizer

WhisperTokenizer(name_or_path='openai/whisper-large', vocab_size=50258, model_max_length=1024, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': '<|endoftext|>', 'additional_special_tokens': ['<|endoftext|>', '<|startoftranscript|>', '<|en|>', '<|zh|>', '<|de|>', '<|es|>', '<|ru|>', '<|ko|>', '<|fr|>', '<|ja|>', '<|pt|>', '<|tr|>', '<|pl|>', '<|ca|>', '<|nl|>', '<|ar|>', '<|sv|>', '<|it|>', '<|id|>', '<|hi|>', '<|fi|>', '<|vi|>', '<|he|>', '<|uk|>', '<|el|>', '<|ms|>', '<|cs|>', '<|ro|>', '<|da|>', '<|hu|>', '<|ta|>', '<|no|>', '<|th|>', '<|ur|>', '<|hr|>', '<|bg|>', '<|lt|>', '<|la|>', '<|mi|>', '<|ml|>', '<|cy|>', '<|sk|>', '<|te|>

In [42]:
train_audio_dataset = train_audio_dataset.map(prepare_dataset, num_proc=4)
test_audio_dataset = test_audio_dataset.map(prepare_dataset, num_proc=4)

Map (num_proc=4):   0%|          | 0/9538 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/507 [00:00<?, ? examples/s]

In [43]:
metric = load_metric("wer")

In [9]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [10]:
training_args = Seq2SeqTrainingArguments(
    output_dir="/scratch/users/gmenon/whisper-large-dali",  # change to a repo name of your choice
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
    learning_rate=1e-5,
    warmup_steps=500,
    max_steps=4000,
    gradient_checkpointing=True,
    fp16=True,
    evaluation_strategy="steps",
    per_device_eval_batch_size=4,
    predict_with_generate=True,
    generation_max_length=225,
    save_steps=1000,
    eval_steps=1000,
    logging_steps=25,
    load_best_model_at_end=True,
    metric_for_best_model="wer",
    greater_is_better=False,
    push_to_hub=False,
)

In [11]:

trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=train_audio_dataset,
    eval_dataset=test_audio_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

In [47]:
processor.save_pretrained(training_args.output_dir)

In [48]:
#trainer.train(resume_from_checkpoint="/scratch/users/gmenon/whisper-large-dali/checkpoint-2000")
trainer.train()



Step,Training Loss,Validation Loss,Wer
1000,1.0435,0.977644,85.845718
2000,0.9779,0.882608,60.721868
3000,0.4158,0.847412,122.505308
4000,0.387,0.822984,75.194621


TrainOutput(global_step=4000, training_loss=0.8582798521518707, metrics={'train_runtime': 9988.2756, 'train_samples_per_second': 1.602, 'train_steps_per_second': 0.4, 'total_flos': 3.39664899907584e+19, 'train_loss': 0.8582798521518707, 'epoch': 1.68})

In [68]:
for sample_num in range(20):
    sample = test_audio_dataset[sample_num]["audio"]
    input_features = processor(sample["array"], sampling_rate=16000, return_tensors="pt").input_features 
    predicted_ids = model.generate(input_features.cuda())
    #transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    print(f"Predicted Transcript = {transcription[0]},Actual Transcript = {test_audio_dataset[sample_num]['transcription']}")


Predicted Transcript =  toso learn from your mistakes,Actual Transcript = so learn from your mistakes
Predicted Transcript = to the right line,Actual Transcript = i've been connected to the right line
Predicted Transcript =  tothe truth to be found,Actual Transcript = the truth to be found
Predicted Transcript =  toshe said the way my blue eyes shine,Actual Transcript = he said the way myblue eyes shined
Predicted Transcript =  tothe only one taking all my own,Actual Transcript = you leave me once again home alone
Predicted Transcript =  towhile they are in command,Actual Transcript = while they are in commend
Predicted Transcript =  tomy love of mine,Actual Transcript = a life all mine
Predicted Transcript =  toso i never went back,Actual Transcript = so i never went back
Predicted Transcript =  toto keep you alive,Actual Transcript = in you i taste god
Predicted Transcript =  toto stay a while,Actual Transcript = could stay a while
Predicted Transcript =  toi love my life,Actual Tran

In [18]:
model = WhisperForConditionalGeneration.from_pretrained("/scratch/users/gmenon/whisper-large-dali/checkpoint-2000").cuda()
for sample_num in range(20):
    sample = test_audio_dataset[sample_num]["audio"]
    input_features = processor(sample["array"], sampling_rate=16000, return_tensors="pt").input_features 
    predicted_ids = model.generate(input_features.cuda())
    #transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    print(f"Predicted Transcript = {transcription[0]},Actual Transcript = {test_audio_dataset[sample_num]['transcription']}")


Predicted Transcript =  toso learn from your mistakes,Actual Transcript = so learn from your mistakes
Predicted Transcript = to the right line,Actual Transcript = i've been connected to the right line
Predicted Transcript =  tothe truth to be found,Actual Transcript = the truth to be found
Predicted Transcript =  toshe said the way my blue eyes shine,Actual Transcript = he said the way myblue eyes shined
Predicted Transcript =  tothe only one taking all my own,Actual Transcript = you leave me once again home alone
Predicted Transcript =  towhile they are in command,Actual Transcript = while they are in commend
Predicted Transcript =  tomy love of mine,Actual Transcript = a life all mine
Predicted Transcript =  toso i never went back,Actual Transcript = so i never went back
Predicted Transcript =  toto keep you alive,Actual Transcript = in you i taste god
Predicted Transcript =  toto stay a while,Actual Transcript = could stay a while
Predicted Transcript =  toi love my life,Actual Tran

In [19]:
model = WhisperForConditionalGeneration.from_pretrained("/scratch/users/gmenon/whisper-large-dali/checkpoint-1000").cuda()
for sample_num in range(20):
    sample = test_audio_dataset[sample_num]["audio"]
    input_features = processor(sample["array"], sampling_rate=16000, return_tensors="pt").input_features 
    predicted_ids = model.generate(input_features.cuda())
    #transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    print(f"Predicted Transcript = {transcription[0]},Actual Transcript = {test_audio_dataset[sample_num]['transcription']}")


Predicted Transcript =  toso learn from your mistakes,Actual Transcript = so learn from your mistakes
Predicted Transcript = you through the right line,Actual Transcript = i've been connected to the right line
Predicted Transcript =  tonot true to be found,Actual Transcript = the truth to be found
Predicted Transcript =  toyou said the way my blue eyes shine,Actual Transcript = he said the way myblue eyes shined
Predicted Transcript =  toyou're really wanting to get on my own,Actual Transcript = you leave me once again home alone
Predicted Transcript =  towhile they are incomming,Actual Transcript = while they are in commend
Predicted Transcript = oh i'm all alone,Actual Transcript = a life all mine
Predicted Transcript =  toso i never went back,Actual Transcript = so i never went back
Predicted Transcript = so we're gonna make it round to round,Actual Transcript = in you i taste god
Predicted Transcript = you say you're wild,Actual Transcript = could stay a while
Predicted Transcript 

In [20]:
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large").cuda()
for sample_num in range(20):
    sample = test_audio_dataset[sample_num]["audio"]
    input_features = processor(sample["array"], sampling_rate=16000, return_tensors="pt").input_features 
    predicted_ids = model.generate(input_features.cuda())
    #transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    print(f"Predicted Transcript = {transcription[0]},Actual Transcript = {test_audio_dataset[sample_num]['transcription']}")


Predicted Transcript =  So learn from your mistakes,Actual Transcript = so learn from your mistakes
Predicted Transcript =  Through the right line,Actual Transcript = i've been connected to the right line
Predicted Transcript =  The truth is bound,Actual Transcript = the truth to be found
Predicted Transcript =  You set the way my blue eyes shine,Actual Transcript = he said the way myblue eyes shined
Predicted Transcript =  I'm a fool,Actual Transcript = you leave me once again home alone
Predicted Transcript =  While they are in command,Actual Transcript = while they are in commend
Predicted Transcript =  My love...,Actual Transcript = a life all mine
Predicted Transcript =  So I never went back,Actual Transcript = so i never went back
Predicted Transcript =  I'm going to make a hole in the wall.,Actual Transcript = in you i taste god
Predicted Transcript =  To say I was,Actual Transcript = could stay a while
Predicted Transcript =  I love my heart,Actual Transcript = my love my life


In [21]:
model = WhisperForConditionalGeneration.from_pretrained("/scratch/users/gmenon/whisper-large-dali/checkpoint-3000").cuda()
for sample_num in range(20):
    sample = test_audio_dataset[sample_num]["audio"]
    input_features = processor(sample["array"], sampling_rate=16000, return_tensors="pt").input_features 
    predicted_ids = model.generate(input_features.cuda())
    #transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
    print(f"Predicted Transcript = {transcription[0]},Actual Transcript = {test_audio_dataset[sample_num]['transcription']}")


Predicted Transcript =  toso learn from your mistakes,Actual Transcript = so learn from your mistakes
Predicted Transcript =  toto the right line,Actual Transcript = i've been connected to the right line
Predicted Transcript =  tothe truth to the sound,Actual Transcript = the truth to be found
Predicted Transcript =  toyou said the way my blue eyes shine,Actual Transcript = he said the way myblue eyes shined
Predicted Transcript =  tothe rain it once again on my own,Actual Transcript = you leave me once again home alone
Predicted Transcript =  towhile they are in command,Actual Transcript = while they are in commend
Predicted Transcript =  tomy momma,Actual Transcript = a life all mine
Predicted Transcript =  toso i never went back,Actual Transcript = so i never went back
Predicted Transcript =  toyou know i love you,Actual Transcript = in you i taste god
Predicted Transcript =  tocan stay a while,Actual Transcript = could stay a while
Predicted Transcript =  toi lost my life,Actual Tr

In [51]:
model = WhisperForConditionalGeneration.from_pretrained("/scratch/users/gmenon/whisper-large-dali/checkpoint-2000").cpu()

whisper_outputs = model(processor(test_audio_dataset["audio"][0]["array"], sampling_rate=16000, return_tensors="pt").input_features.cpu() ,decoder_input_ids=tokenizer(test_audio_dataset["transcription"][0],return_tensors="pt").input_ids.cpu())

In [68]:
whisper_outputs[0].shape,whisper_outputs[0]

(torch.Size([1, 10, 51865]),
 tensor([[[ 3.2344e+00,  4.6969e-01, -3.3156e-02,  ..., -3.0436e+00,
           -2.5017e+00, -5.5387e+00],
          [ 8.0211e+00,  8.1208e+00,  3.3465e+00,  ...,  2.3013e+00,
            1.6144e+00, -2.7801e+00],
          [ 2.6590e+01,  3.0826e+01,  2.9286e+01,  ...,  2.7998e+01,
            2.7803e+01,  2.6338e+01],
          ...,
          [ 9.7797e+00,  1.0055e+01,  1.0368e+01,  ...,  4.7854e+00,
            5.2226e+00,  6.1670e+00],
          [ 3.4403e+01,  3.4782e+01,  2.8662e+01,  ...,  2.3286e+01,
            2.2944e+01,  2.0881e+01],
          [ 7.1971e+00,  6.3629e+00,  5.8428e+00,  ..., -3.8433e-01,
           -1.0985e+00, -2.6368e+00]]], grad_fn=<UnsafeViewBackward0>))

In [65]:
whisper_outputs.encoder_last_hidden_state.shape

torch.Size([1, 1500, 1280])

In [67]:
whisper_outputs.encoder_last_hidden_state

tensor([[[-7.0498e-01, -8.8163e-01,  4.4531e-01,  ...,  2.7841e-01,
          -4.3470e-01, -1.8660e+00],
         [-3.6542e-01, -5.9699e-01,  7.9428e-01,  ...,  4.4049e-01,
           4.8307e-01, -1.9065e+00],
         [-7.4569e-01,  8.6501e-01,  4.6751e-01,  ...,  3.3038e-01,
           9.0526e-01, -2.4736e+00],
         ...,
         [ 5.6839e-02, -5.0746e-01,  1.7135e-01,  ...,  8.8492e-01,
          -1.2692e-01,  2.3672e-01],
         [-6.4457e-02, -2.1929e-01, -1.1233e-03,  ...,  9.9467e-01,
          -2.6303e-01,  1.6834e-01],
         [-1.8695e-01,  3.2099e-01, -1.0913e-01,  ...,  1.0678e+00,
          -2.9235e-01,  6.2301e-02]]], grad_fn=<NativeLayerNormBackward0>)

In [11]:
from jiwer import wer

reference = []
hypothesis = []
#model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large").cuda()
model = WhisperForConditionalGeneration.from_pretrained("/scratch/users/gmenon/whisper-large-dali/checkpoint-2000").cuda()
for sample_num in range(507):
    transcription = ""
    sample = test_audio_dataset[sample_num]["audio"]
    input_features = processor(sample["array"], sampling_rate=16000, return_tensors="pt").input_features.cuda()
    predicted_ids = model.generate(input_features)
    #transcription = processor.batch_decode(predicted_ids, skip_special_tokens=False)
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
    transcription = str.lstrip(transcription)
    if transcription.startswith('to'):
        transcription= transcription[-len(transcription) +2:]
    print(transcription)
    reference.append(test_audio_dataset[sample_num]['transcription'])
    hypothesis.append(transcription)
    #print(f"Predicted Transcript = {transcription[0]},Actual Transcript = {test_audio_dataset[sample_num]['transcription']}")

error = wer(reference, hypothesis)
print(f"Word Error Rate = {error}")

so learn from your mistakes
 the right line
the truth to be found
she said the way my blue eyes shine
the only one taking all my own
while they are in command
my love of mine
so i never went back
to keep you alive
to stay a while
i love my life
so what you want to know
how can you treat me like a child
it's hard to put the fire out
i'd walk out of my life
open your eyes and wake up
oh what you do to me
love the sorrow it's infall
is where she lies broken inside
closing time open all the doors
many places i have been
the peace of love
i wanna be a hero
i see your face i feel your love
you can make the most of the distance
i'm free and true
is the reason it ain't me
we believe in love
i can feel her on my skin
well before i leave
you make me do so good
we don't wanna fight
so you friends try to tell me
you still can touch my heart
and both of us must try
when i thought that i spoke with the score
won't you die alone in the open rain
i'm just trying to say
i wish i could be there for you


In [91]:
list(pd.read_csv(WAV2VEC2_ARGS.TEST_FILE_PATH)[506:].consolidated_file_path)

['/scratch/users/gmenon/wav_clips/separated/htdemucs/13aa830f017b4f1b9649db51a7cbf7bf/vocals.wav']

In [92]:
WhisperForConditionalGeneration.from_pretrained('openai/whisper-large').model.encoder


WhisperEncoder(
  (conv1): Conv1d(80, 1280, kernel_size=(3,), stride=(1,), padding=(1,))
  (conv2): Conv1d(1280, 1280, kernel_size=(3,), stride=(2,), padding=(1,))
  (embed_positions): Embedding(1500, 1280)
  (layers): ModuleList(
    (0-31): 32 x WhisperEncoderLayer(
      (self_attn): WhisperAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=False)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (activation_fn): GELUActivation()
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (layer_norm): LayerNorm((1280,), eps=