In [1]:
import os
import json
import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from transformers import Trainer
from datasets import load_metric, Audio
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor
) 
import re
import librosa

import warnings
warnings.simplefilter('ignore')
from transformers.trainer_utils import get_last_checkpoint, is_main_process
from transformers.utils import check_min_version, send_example_telemetry
from transformers.utils.versions import require_version
from transformers import TrainingArguments


wer_metric = load_metric("wer")


2023-08-30 01:15:47.681662: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class BengaliDataset(Dataset):
    def __init__(self, df):
        self.df = df
        self.processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor, 
    tokenizer=tokenizer)
    
    def __getitem__(self, idx):
        # First read and pre-process the audio file
        audio = self.read_audio(self.df.loc[idx]['path'])
        audio = processor(
            audio, 
            sampling_rate=16000
        ).input_values[0]
        
        with processor.as_target_processor():
            labels = processor(self.df.loc[idx]['sentence']).input_ids
        return {'input_values': audio, 'labels': labels}
        
    def __len__(self):
        return len(self.df)
    
    def read_audio(self, mp3_path):
        target_sr = 16000  # Set the target sampling rate
        
        audio, sr = librosa.load(mp3_path, sr=None) 
        audio_array = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        
        return audio_array

In [3]:
def save_vocab(dataframe):
    """
    Saves the processed vocab file as 'vocab.json', to be ingested by tokenizer
    """
    vocab = construct_vocab(dataframe['sentence'].tolist())
    vocab_dict = {v: k for k, v in enumerate(vocab)}
    vocab_dict["__"] = vocab_dict[" "]
    _ = vocab_dict.pop(" ")
    vocab_dict["[UNK]"] = len(vocab_dict)
    vocab_dict["[PAD]"] = len(vocab_dict)
    
    print(vocab_dict)
    
    new_vocab_dict = {"ben": vocab_dict}


    with open('vocab.json', 'w') as fl:
        json.dump(new_vocab_dict, fl)


def ctc_data_collator(batch):
    """
    data collator function to dynamically pad the data
    """
    input_features = [{"input_values": sample["input_values"]} for sample in batch]
    label_features = [{"input_ids": sample["labels"]} for sample in batch]
    batch = processor.pad(
        input_features,
        padding=True,
        return_tensors="pt",
    )
    with processor.as_target_processor():
        labels_batch = processor.pad(
            label_features,
            padding=True,
            return_tensors="pt",
        )
        
    labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
    batch["labels"] = labels
    return batch

def construct_vocab(texts):
    """
    Get unique characters from all the text in a list
    """
    all_text = " ".join(texts)
    vocab = list(set(all_text))
    return vocab
    
### Data cleaning, remove punctuations and lowercase
def remove_special_characters(string):

    chars_to_ignore_regex = ', ? . ! - \; \: \" “ % ” �'
    
    clean_text = re.sub(chars_to_ignore_regex, "", string).lower() + " "
  
    return clean_text


### Word Error Rate (Evaluation Metrics)
def compute_metrics(pred):


    pred_logits = pred.predictions
    pred_ids = np.argmax(pred_logits, axis=-1)

    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    wer = wer_metric.compute(predictions=pred_str, references=label_str)

    return {"wer": wer}

In [4]:
output_dir = './checkpoint-mms'
model_name = 'facebook/mms-1b-all'

# Load the CSV file
df = pd.read_csv("/home/ubuntu/bengali/data/train.csv")

df['sentence'] = df['sentence'].apply(lambda x: remove_special_characters(x))

# Add the full path to the audio files
df['path'] = df['id'].apply(lambda x: os.path.join('/home/ubuntu/bengali/data/train_mp3s', x+'.mp3'))

# Filter the dataset
train = df[df['split'] == 'train'].sample(frac=0.0008, random_state=10).reset_index(drop=True)
val = df[df['split'] == 'valid'].sample(frac=0.001, random_state=10).reset_index(drop=True)

print(f"Training on samples: {len(train)}, Validation on samples: {len(val)}")

Training on samples: 747, Validation on samples: 30


In [5]:
save_vocab(df)

{'‘': 0, ';': 1, '।': 2, 'ধ': 3, '!': 4, 'গ': 5, 'ঝ': 6, 'ঐ': 7, 'আ': 8, 'ঃ': 9, 'ফ': 10, '"': 11, 'ু': 12, 'ো': 13, 'ম': 14, 'স': 15, '-': 16, 'ে': 17, 'ষ': 18, '/': 19, 'ঢ়': 20, 'ঠ': 21, 'ঊ': 22, ':': 23, 'ন': 24, 'ড়': 25, 'ও': 26, 'ং': 27, 'ঙ': 28, 'ী': 29, '”': 30, 'ত': 31, '্': 32, 'ঞ': 33, '.': 34, 'ৰ': 35, 'উ': 36, 'া': 37, 'ণ': 38, 'ব': 39, 'থ': 40, 'ভ': 41, 'ূ': 42, 'ঈ': 43, "'": 44, 'ঢ': 45, 'ছ': 46, 'ই': 47, 'ৎ': 48, '“': 49, 'শ': 51, 'জ': 52, '॥': 53, 'ট': 54, 'য': 55, 'ল': 56, 'ড': 57, 'ি': 58, 'দ': 59, 'ক': 60, '–': 61, 'এ': 62, 'হ': 63, 'ৗ': 64, 'ৌ': 65, '—': 66, '‚': 67, 'র': 68, 'ঘ': 69, '়': 70, ',': 71, '৵': 72, 'ঋ': 73, '’': 74, '…': 75, 'ৈ': 76, 'ৃ': 77, 'খ': 78, '৷': 79, 'প': 80, 'য়': 81, 'অ': 82, 'ঁ': 83, 'চ': 84, 'ঔ': 85, '?': 86, '__': 50, '[UNK]': 87, '[PAD]': 88}


In [6]:
from transformers import Wav2Vec2ForCTC, AutoProcessor
import torch

# processor = AutoProcessor.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

tokenizer = Wav2Vec2CTCTokenizer(
    "vocab.json", 
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="__",
    target_lang = "ben"
)
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1, 
    sampling_rate=16000, 
    padding_value=0.0, 
    do_normalize=True, 
    return_attention_mask=False
)
processor = Wav2Vec2Processor(
    feature_extractor=feature_extractor, 
    tokenizer=tokenizer
)

# model = Wav2Vec2ForCTC.from_pretrained(
#     model_name,
#     ctc_loss_reduction="mean", 
#     pad_token_id=processor.tokenizer.pad_token_id,
#     vocab_size = len(tokenizer),
# )
processor.tokenizer.set_target_lang("ben")
model.load_adapter("ben")

In [7]:
with open('vocab.json', 'w') as fopen:
    json.dump(processor.tokenizer.vocab['ben'], fopen)

In [None]:
train_ds = BengaliDataset(train)
valid_ds = BengaliDataset(val)


model.freeze_feature_encoder()

training_args = TrainingArguments(
        output_dir=output_dir,
        group_by_length=True,
        per_device_train_batch_size=1,
        evaluation_strategy="steps",
        num_train_epochs=1,
        gradient_checkpointing=True,
        fp16=True,
        save_steps=100,
        eval_steps=100,
        logging_steps=10,
        learning_rate=3e-4,
        warmup_steps=500,
        save_total_limit=2,
        do_eval=False
)


trainer = Trainer(
    model=model,
    data_collator=ctc_data_collator,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,   
    eval_dataset=valid_ds,    
    tokenizer=processor.feature_extractor,
)

In [None]:
trainer.train()

In [None]:
python3 run_speech_recognition_ctc.py \
	--model_name_or_path="facebook/mms-1b-all" \
	--output_dir="./mms-1b" \
	--num_train_epochs="1" \
	--per_device_train_batch_size="5" \
	--learning_rate="2e-5" \
	--evaluation_strategy="steps" \
	--save_steps="400" \
	--eval_steps="100" \
	--layerdrop="0.0" \
	--save_total_limit="2" \
	--gradient_checkpointing \
	--fp16 \
    --do_train


In [None]:
~/.local/bin/deepspeed run_speech_recognition.py --deepspeed ds_config_zero3.json

In [None]:
~/.local/bin/deepspeed run_speech_recognition_ctc.py \
--deepspeed ds_config_zero3.json \
--model_name_or_path="./mms-1b-all" \
	--output_dir="./mms-1b" \
	--num_train_epochs="1" \
	--per_device_train_batch_size="16" \
	--learning_rate="2e-5" \
	--evaluation_strategy="steps" \
	--save_steps="400" \
	--eval_steps="400" \
    --logging_steps="50"\
	--layerdrop="0.0" \
	--save_total_limit="2" \
	--gradient_checkpointing \
--do_train 
