## Goal : Creating a new dataset that can load data. 

1. Download data from DALI.
2. Load it into Huggingface Datasets.
3. Break down the entire audios into smaller six second chunks. Make this parametrizable and create another metadata.csv file that has this information to be passed into the Wav2vec models.
4. Finetune the Wav2vec model with this dataset.
5. Evaluation procedure to make sure the Wav2vec works.



***Note:** If the **.wav** files are too big and taking time, we may have to revert to **.mp3** files. However, in the initial experimentations done, mp3 files are not playable by Ipython.display and hence i omitted using them in the model*


In [None]:
!pip freeze

In [None]:
!pip install dali-dataset
#!pip install sox
#!pip install torchaudio==0.11.0
! pip install torchaudio --upgrade
! pip install datasets --upgrade
#!pip install datasets==1.18.3
#!pip install soundfile==0.12.1
#!pip install --force-reinstall soundfile
# ! pip install spacy
# ! pip install spacy-langdetect
# !python3 -m spacy download en_core_web_sm
!pip install fasttext
! pip install fasttext-langdetect


In [1]:
import datasets
datasets.__version__


'2.12.0'

In [2]:
import os
import DALI as dali_code
import logging
import soundfile
from typing import Dict, Optional, List
import pandas as pd
import numpy as np

logging.basicConfig(filename='app.log', filemode='w', format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', level=logging.INFO)
__all__ = ["DALIDataset"]

class DALIDataset():
    def __init__(self, data_path: str , file_path: Optional[str] = None):
        self._data_path = data_path
        if file_path is None:
            self._file_path = self._data_path + 'audio/'
        else:
            self._file_path = file_path

    @property
    def data_path(self):
        logging.info("Setting the data_path")
        return self._data_path

    @data_path.setter
    def data_path(self, data_path: str):
        logging.info("Setting the data_path")
        self._data_path = data_path

    @property
    def file_path(self):
        logging.info("Setting the data_path")
        return self._file_path

    @file_path.setter
    def file_path(self, file_path: str):
        logging.info("Setting the data_path")
        self._file_path = file_path

    def get_data(self) -> Dict:
        logging.info("Getting the data_path")
        if self._data_path is not None:
            dali_dataset = dali_code.get_the_DALI_dataset(self._data_path, skip=[], keep=[])
            logging.info(f"The DALI dataset has been downloaded")
            return dali_dataset
        else:
            raise TypeError(f"Set the data_path for the location of the DALI datasets; data_path = {self._data_path}")

    def download_data(self) -> NotImplementedError:
        # dali_data = self.get_data()
        # logging.info(f"Downloading the data into the file path = {self._data_path}data/")
        raise NotImplementedError

    def get_info(self) -> pd.DataFrame:
        logging.info(f"Getting the info related to the data from the data_path = {self._data_path}")
        if self._data_path is not None:
            dali_info = dali_code.get_info(self._data_path + 'info/DALI_DATA_INFO.gz')
            dali_df = pd.DataFrame(dali_info)[1:]
            dali_df.columns = dali_info[0]
            logging.info(f"The DALI dataset has {len(dali_info)} rows in it")
            return dali_df
        else:
            raise TypeError(f"Set the data_path for the location of the DALI datasets; data_path = {self._data_path}")

    def download_info(self) -> None:
        dali_df = self.get_info()
        logging.info(f"Downloading to the file path = {self._data_path}info/ ")
        dali_df.to_csv(self._data_path + 'info/dali_info.csv')
        logging.info(f"Download complete in the file path = {self._data_path}info/ ")

    def download_audio(self) -> List:
        logging.info("Downloading audio from youtube URLs associated with the info file")
        if self._data_path is not None or self._file_path is not None:
            dali_info = self.get_info()
            logging.info(f"The DALI Audio download has {len(dali_info)} errors in it")
            return dali_code.get_audio(dali_info, self._file_path, skip=[], keep=[])
        else:
            raise TypeError(f"Set the data_path & file_path for the location of the DALI datasets; "
                            f"data_path = {self._data_path}, file_path = {self._file_path}")






In [9]:
print(soundfile.__libsndfile_version__)


1.2.0


In [7]:
dali = DALIDataset(data_path="/home/users/gmenon/dali/DALI_v1.0/")

In [9]:
dali_dataset = dali.get_data()

## Explore the DALI Dataset

In this particular section, I will be going through what the DALI dataset has to offer and understand how it relates with one particular song that is being played. Initial experiment i ran took the song in audacity and me running through the lyrics and information from the following info and annotations. It helped me to understand what is happening and what the data within the particular cells mean.

**Goal** -  *I would like to have a spectogram with the entire song chunks that can then be used to break down the song into lines and then have the lyrics associated with them in them*

In [None]:
dali_dataset["e186227bb7474fa5a7738c9108f11972"].info



In [None]:
for k in dali_dataset["e186227bb7474fa5a7738c9108f11972"].annotations["annot"]["lines"][:5]:
    print(f"frequency = {k['freq']}, time = {k['time']}, text = {k['text']}")

In [None]:
lyrics = []
for k in dali_dataset["e186227bb7474fa5a7738c9108f11972"].annotations["annot"]["paragraphs"]:
    lyrics.append(k["text"])
"".join(lyrics)

In [None]:
BASE_AUDIO_PATH = "/home/users/gmenon/dali/DALI_v1.0/audio/wav_data/"
    
AUDIO_DOWNLOADED_FROM_YOUTUBE = [files.split('.')[0] for files in os.listdir(BASE_AUDIO_PATH)]
AUDIO_DOWNLOADED_FROM_YOUTUBE[:5]

In [None]:
# import json
# with open("/home/users/gmenon/dali/DALI_v1.0/info/dali_dataset.json", "w") as outfile:
#     json.dumps(dali_dataset, indent = 4)

In [None]:
dali_info = pd.read_csv("/home/users/gmenon/dali/DALI_v1.0/info/dali_info.csv") \
              .filter(["DALI_ID","NAME","YOUTUBE","WORKING"])

In [None]:
_replace_dict = {True: "Y", False: "N"}  
dali_info["AUDIO_DWNLD"] = dali_info["DALI_ID"].isin(AUDIO_DOWNLOADED_FROM_YOUTUBE).replace(_replace_dict)
dali_info["AUDIO_PATH"] = BASE_AUDIO_PATH + dali_info["DALI_ID"] + ".wav"
dali_info["AUDIO_DWNLD"].value_counts()

In [None]:
dali_info = dali_info.loc[dali_info['AUDIO_DWNLD'] == 'Y']

In [None]:
# Check whether the path looks good . It looks good
dali_info['AUDIO_PATH'][9]

In [None]:
dali_info

In [None]:
song_metadata = pd.read_csv("/home/users/gmenon/dali/DALI_v1.0/audio/wav_data/metadata.csv",delimiter= ",")
song_metadata.transcription[:10]

In [None]:
import os
import librosa
import warnings
BASE_AUDIO_PATH = "/home/users/gmenon/dali/DALI_v1.0/audio/wav_clips/"
corrupt_files = []
file_list = os.listdir(BASE_AUDIO_PATH)
for wav_file in file_list:
    try:
        librosa.load(BASE_AUDIO_PATH + wav_file)
    except:
        corrupt_files.append(BASE_AUDIO_PATH + wav_file)

len(corrupt_files)




In [None]:
song_metadata = pd.read_csv("/home/users/gmenon/dali/DALI_v1.0/audio/wav_clips/metadata.csv",delimiter= ",")
song_metadata['file_name'] = "/home/users/gmenon/dali/DALI_v1.0/audio/wav_clips/" + song_metadata["file_name"]
song_metadata["transcription"] = song_metadata["transcription"].astype(str)
song_metadata.shape

In [None]:
#https://www.adamsmith.haus/python/answers/how-to-filter-a-pandas-dataframe-with-a-list-by-%60in%60-or-%60not-in%60-in-python

song_metadata = song_metadata[~song_metadata.file_name.isin(corrupt_files)]
song_metadata.shape

In [11]:
song_metadata = pd.read_csv("home/users/gmenon/notebooks/song_metadata_cleaned.csv")
from ftlangdetect import detect

def is_english(text: str) -> str:
    """Simple Function to detect whether the language of the text is english or not.
    Returns a Boolean output for the same.Input the text column"""
    return detect(text, low_memory=True)["lang"]

chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\%\$\&\^\*\@\#\<\>\/\+\\=\_\\}\{\)\(\]\[\`1234567890]'
song_metadata["transcription"] = song_metadata["transcription"].replace(chars_to_ignore_regex, '', regex=True)
song_metadata = song_metadata[song_metadata.transcription.str.len()>8]
song_metadata["language"] = song_metadata["transcription"].apply(is_english)
song_metadata = song_metadata.loc[song_metadata["language"] == "en"]
song_metadata = song_metadata.head(20000)
song_metadata



Unnamed: 0.1,Unnamed: 0,file_name,transcription,language
90,90,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,betrogen kluge worte was,en
115,115,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,boy my life,en
116,116,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,ain't what it used to be,en
118,118,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,since you went out the door,en
119,119,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,all the times,en
...,...,...,...,...
29793,29862,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,don't blame this,en
29794,29863,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,sleeping satellite,en
29795,29864,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,did we fly,en
29796,29865,/home/users/gmenon/dali/DALI_v1.0/audio/wav_cl...,to the moon too soon,en


In [12]:
# import spacy
# from spacy.language import Language
# from spacy_langdetect import LanguageDetector

# def is_english(text: str) -> bool:
#     """Simple Function to detect whether the language of the text is english or not.
#     Returns a Boolean output for the same.Input the text column"""
    
#     def get_lang_detector(nlp, name):
#         return LanguageDetector(seed=42)  # We use the seed 42
    
#     nlp = spacy.load("en_core_web_sm")
#     import en_core_web_sm
#     nlp = en_core_web_sm.load()
#     nlp.add_pipe(factory_name="language_detector")
#     doc = nlp(text)
#     return doc._.language    

In [13]:
from datasets import load_dataset,Dataset, Audio
#audio_dataset = Dataset.from_dict({"audio": list(dali_info['AUDIO_PATH'])}).cast_column("audio", Audio())

audio_dataset = Dataset.from_dict({"audio": list(song_metadata["file_name"]), "transcription": list(song_metadata["transcription"])}).cast_column("audio", Audio(sampling_rate=16000))
audio_dataset = audio_dataset.train_test_split(test_size=0.2, shuffle=True)

In [14]:
import torch
torch.cuda.empty_cache()

In [15]:
# from datasets import load_dataset,Dataset, Audio
# #audio_dataset = Dataset.from_dict({"audio": list(dali_info['AUDIO_PATH'])}).cast_column("audio", Audio())

# audio_dataset = Dataset.from_dict({"audio": ["/home/users/gmenon/0a0c413b5290497c96d5327e2ef2ad8d.wav"]}).cast_column("audio", Audio(sampling_rate=16000))
# #audio_dataset = audio_dataset.train_test_split(test_size=0.2, shuffle=True)
# audio_dataset

In [16]:
#np.asarray("/home/users/gmenon/0a0c413b5290497c96d5327e2ef2ad8d.wav")

In [17]:
audio_dataset["train"][:6]["audio"]

[{'path': '/home/users/gmenon/dali/DALI_v1.0/audio/wav_clips/d94ade3cb26f40f28b6aa5fbccace99c.wav',
  'array': array([0.15964612, 0.24435514, 0.11763921, ..., 0.07339463, 0.10085566,
         0.164987  ]),
  'sampling_rate': 16000},
 {'path': '/home/users/gmenon/dali/DALI_v1.0/audio/wav_clips/830431a2f2504bc2abebe6df72b2f4c5.wav',
  'array': array([ 0.1501548 ,  0.12344176,  0.118747  , ..., -0.02714896,
         -0.09638148, -0.15420352]),
  'sampling_rate': 16000},
 {'path': '/home/users/gmenon/dali/DALI_v1.0/audio/wav_clips/946f56527d314bd9874358591888562f.wav',
  'array': array([ 0.56786728,  0.57429814,  0.22644949, ..., -0.35814488,
         -0.66134751, -0.57367814]),
  'sampling_rate': 16000},
 {'path': '/home/users/gmenon/dali/DALI_v1.0/audio/wav_clips/d8cc99e52f63412c9043a24b62d5c878.wav',
  'array': array([0.32682917, 0.62510657, 0.58776152, ..., 0.23291638, 0.15553989,
         0.04170347]),
  'sampling_rate': 16000},
 {'path': '/home/users/gmenon/dali/DALI_v1.0/audio/wav_c

In [18]:
import IPython.display as ipd
ipd.Audio(data=np.asarray(audio_dataset["train"][0]["audio"]["array"]), autoplay=True, rate=16000)

In [19]:
# np.asarray(audio_dataset[0]["audio"])

In [20]:
# import IPython.display as ipd
# import numpy as np
# import random

# print(audio_dataset["train"][rand_int]["transcription"])
# ipd.Audio(data=np.asarray(audio_dataset["train"][rand_int]["audio"]["array"]), autoplay=True, rate=16000)


### FINETUNE WAV2VEC USING THIS NEW DATASET

This will be interesting as we will be finetuning the dataset of Wav2vec2.0 speech recognition using the new work we are doing for Songs to Lyrics generation.

In [21]:
from datasets import ClassLabel
import random
import pandas as pd
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    display(HTML(df.to_html()))

In [22]:
show_random_elements(audio_dataset["train"].remove_columns("audio"))

Unnamed: 0,transcription
0,we're laughing way too proud
1,i've been down onto my knees
2,you greedy little bastard
3,ding dong ding dong ding dong ding
4,and high above the stars
5,but if it helps you to mend
6,through all the summers
7,ripped all the flowers in the garden
8,it's hard to let your children go
9,i hate these thoughts i can't deny


In [23]:
import re
chars_to_ignore_regex = '[\,\?\.\!\-\;\:\"\%\$\&\^\*\@\#\<\>\/\+\\=\_\\}\{\)\(\]\[\`1234567890]'

def remove_special_characters(batch):
  #  batch["transcription"] = re.sub(chars_to_ignore_regex, '', batch["transcription"]).lower() + ' '
    batch["transcription"] = re.sub(chars_to_ignore_regex, '', batch["transcription"]).upper()
    return batch

In [24]:
audio_dataset = audio_dataset.map(remove_special_characters)

Map:   0%|          | 0/16000 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

In [25]:
def extract_all_chars(batch):
    all_text = " ".join(batch["transcription"])
    vocab = list(set(all_text))
    return {"vocab": [vocab], "all_text": [all_text]}

In [26]:
vocabs = audio_dataset.map(extract_all_chars, batched=True, batch_size=-1, keep_in_memory=True, remove_columns=audio_dataset.column_names["train"])

Map:   0%|          | 0/16000 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

In [27]:
vocab_list = list(set(vocabs["train"]["vocab"][0]) | set(vocabs["test"]["vocab"][0]))

In [28]:
vocab_dict = {v: k for k, v in enumerate(vocab_list)}
vocab_dict

{'C': 0,
 'R': 1,
 'K': 2,
 'P': 3,
 'O': 4,
 'Y': 5,
 'M': 6,
 'J': 7,
 'H': 8,
 'T': 9,
 'D': 10,
 'U': 11,
 'X': 12,
 'L': 13,
 ' ': 14,
 'I': 15,
 'B': 16,
 'Q': 17,
 'E': 18,
 "'": 19,
 'W': 20,
 'F': 21,
 'V': 22,
 'S': 23,
 'Z': 24,
 'N': 25,
 'A': 26,
 'G': 27}

In [29]:

vocab_dict["|"] = vocab_dict[" "]
del vocab_dict[" "]

In [30]:
vocab_dict["[UNK]"] = len(vocab_dict)
vocab_dict["[PAD]"] = len(vocab_dict)
len(vocab_dict)

30

In [31]:
import json
with open('vocab.json', 'w') as vocab_file:
    json.dump(vocab_dict, vocab_file)

In [32]:
from transformers import Wav2Vec2CTCTokenizer

tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")

In [33]:
from transformers import Wav2Vec2FeatureExtractor

feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)


In [34]:
from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [35]:
import IPython.display as ipd
import numpy as np
import random

rand_int = random.randint(0, len(audio_dataset["train"]))

print(audio_dataset["train"][rand_int]["transcription"])
ipd.Audio(data=np.asarray(audio_dataset["train"][rand_int]["audio"]["array"]), autoplay=True, rate=16000)

FOR THE CHANCE TO LEAVE ME


In [36]:
rand_int = random.randint(0, len(audio_dataset["train"]))

print("Target text:", audio_dataset["train"][rand_int]["transcription"])
print("Input array shape:", np.asarray(audio_dataset["train"][rand_int]["audio"]["array"]).shape)
print("Sampling rate:", audio_dataset["train"][rand_int]["audio"]["sampling_rate"])

Target text: REACH UP FOR THE SUNRISE
Input array shape: (41575,)
Sampling rate: 16000


In [37]:
def prepare_dataset(batch):
    audio = batch["audio"]

    # batched output is "un-batched" to ensure mapping is correct
    batch["input_values"] = processor(audio["array"], sampling_rate=audio["sampling_rate"]).input_values[0]
    batch["input_length"] = len(batch["input_values"])
    
    with processor.as_target_processor():
        batch["labels"] = processor(batch["transcription"]).input_ids
    return batch

In [38]:
audio_dataset = audio_dataset.map(prepare_dataset, \
                                           remove_columns=audio_dataset.column_names["train"], \
                                           num_proc=1)


Map:   0%|          | 0/16000 [00:00<?, ? examples/s]

Map:   0%|          | 0/4000 [00:00<?, ? examples/s]

In [39]:
audio_dataset

DatasetDict({
    train: Dataset({
        features: ['input_values', 'input_length', 'labels'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['input_values', 'input_length', 'labels'],
        num_rows: 4000
    })
})

In [40]:
max_input_length_in_sec = 6
audio_dataset["train"] = audio_dataset["train"].filter(lambda x: x < max_input_length_in_sec * processor.feature_extractor.sampling_rate, input_columns=["input_length"])

Filter:   0%|          | 0/16000 [00:00<?, ? examples/s]

In [41]:
import torch

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union

@dataclass
class DataCollatorCTCWithPadding:
    """
    Data collator that will dynamically pad the inputs received.
    Args:
        processor (:class:`~transformers.Wav2Vec2Processor`)
            The processor used for proccessing the data.
        padding (:obj:`bool`, :obj:`str` or :class:`~transformers.tokenization_utils_base.PaddingStrategy`, `optional`, defaults to :obj:`True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:
            * :obj:`True` or :obj:`'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
              sequence if provided).
            * :obj:`'max_length'`: Pad to a maximum length specified with the argument :obj:`max_length` or to the
              maximum acceptable input length for the model if that argument is not provided.
            * :obj:`False` or :obj:`'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of
              different lengths).
    """

    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lenghts and need
        # different padding methods
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding=self.padding,
            return_tensors="pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding=self.padding,
                return_tensors="pt",
            )

        # replace padding with -100 to ignore loss correctly
        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)

        batch["labels"] = labels

        return batch

In [42]:
data_collator = DataCollatorCTCWithPadding(processor=processor, padding=True)

In [43]:
from datasets import load_dataset, load_metric

wer_metric = load_metric("wer")

  wer_metric = load_metric("wer")


In [44]:
def compute_metrics(pred):
    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [45]:
from transformers import Wav2Vec2ForCTC

model = Wav2Vec2ForCTC.from_pretrained(
    "facebook/wav2vec2-large-960h-lv60-self",
    ctc_loss_reduction="mean", 
    pad_token_id=processor.tokenizer.pad_token_id,
)

Some weights of Wav2Vec2ForCTC were not initialized from the model checkpoint at facebook/wav2vec2-large-960h-lv60-self and are newly initialized: ['wav2vec2.masked_spec_embed']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [46]:
model.freeze_feature_encoder()

In [47]:
from transformers import TrainingArguments

training_args = TrainingArguments(
  output_dir="songstolyrics_wav2vec",
  group_by_length=True,
  per_device_train_batch_size=8,
  evaluation_strategy="steps",
  num_train_epochs=30,
  fp16=True,
  gradient_checkpointing=True,
  save_steps=500,
  eval_steps=500,
  logging_steps=500,
  learning_rate=1e-4,
  weight_decay=0.005,
  warmup_steps=1000,
  save_total_limit=2,
)

In [48]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    data_collator=data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=audio_dataset["train"],
    eval_dataset=audio_dataset["test"],
    tokenizer=processor.feature_extractor,
)

Using amp half precision backend


In [49]:
audio_dataset["test"]

Dataset({
    features: ['input_values', 'input_length', 'labels'],
    num_rows: 4000
})

In [None]:
trainer.train(resume_from_checkpoint = False)

The following columns in the training set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 15634
  Num Epochs = 30
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 58650


Step,Training Loss,Validation Loss


The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 4000
  Batch size = 8
Saving model checkpoint to songstolyrics_wav2vec/checkpoint-500
Configuration saved in songstolyrics_wav2vec/checkpoint-500/config.json
Model weights saved in songstolyrics_wav2vec/checkpoint-500/pytorch_model.bin
Feature extractor saved in songstolyrics_wav2vec/checkpoint-500/preprocessor_config.json
Deleting older checkpoint [songstolyrics_wav2vec/checkpoint-50500] due to args.save_total_limit
The following columns in the evaluation set  don't have a corresponding argument in `Wav2Vec2ForCTC.forward` and have been ignored: input_length. If input_length are not expected by `Wav2Vec2ForCTC.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num exam

In [None]:
trainer.save_state()

In [None]:
trainer.save_model("home/users/gmenon/songstolyrics_wav2vec_finetuned/")

In [None]:
processor.save_pretrained("home/users/gmenon/songstolyrics_wav2vec_processor/")

## LOAD FROM ABOVE TRAINED MODELS AND EVALUATE RESULTS

In [None]:
processor = Wav2Vec2Processor.from_pretrained("home/users/gmenon/songstolyrics_wav2vec_processor/",local_files_only=True)

In [None]:
model = Wav2Vec2ForCTC.from_pretrained("home/users/gmenon/songstolyrics_wav2vec_finetuned/",local_files_only=True).cuda()

In [None]:
# from transformers import Wav2Vec2FeatureExtractor
# from transformers import Wav2Vec2Processor

# tokenizer = Wav2Vec2CTCTokenizer("./vocab.json", unk_token="[UNK]", pad_token="[PAD]", word_delimiter_token="|")
# feature_extractor = Wav2Vec2FeatureExtractor(feature_size=1, sampling_rate=16000, padding_value=0.0, do_normalize=True, return_attention_mask=False)
# processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

In [None]:
def map_to_result(batch):
    with torch.no_grad():
        input_values = torch.tensor(batch["input_values"], device="cuda").unsqueeze(0)
        logits = model(input_values).logits
    pred_ids = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(pred_ids)[0]
    batch["transcription"] = processor.decode(batch["labels"], group_tokens=False)
    return batch


In [None]:
import torch, torchaudio
results = audio_dataset["test"].map(map_to_result, remove_columns=audio_dataset["test"].column_names)
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["transcription"])))


----

In [None]:
pd.DataFrame(results[0:40])

In [None]:
import torch, torchaudio
results = audio_dataset["train"].map(map_to_result, remove_columns=audio_dataset["train"].column_names)
print("Test WER: {:.3f}".format(wer_metric.compute(predictions=results["pred_str"], references=results["transcription"])))
