# **Kannada ASR with a fine-tuned multilingual XLSR-Wav2Vec2 model (Transformers)**

In [None]:
# Show the attached GPU and the driver/CUDA versions before loading the model.
!nvidia-smi

Wed Apr 10 10:03:51 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  Tesla T4                       Off | 00000000:00:04.0 Off |                    0 |
| N/A   50C    P8              12W /  70W |      0MiB / 15360MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [None]:
%%capture
!pip install torchaudio
!pip install librosa
!pip install jiwer
!pip install datasets
!pip install pyarrow==12.0.0
!pip install samplerate
!pip install resampy
!pip install transformers[torch]
!pip install accelerate -U
!pip install sagemaker
!pip install ffmpeg-python
!pip install Js2Py
!pip install textdistance
!pip install audio_similarity==1.0.0

In [None]:
import os
import pandas as pd
import zipfile
from datasets import Dataset
from google.colab import drive

In [None]:
# Mount Google Drive. The dataset, model checkpoint and recordings used below all live under it.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [None]:
# Everything lives under one Drive folder; the male/female clip folders sit inside it.
path_to_dir = '/content/drive/My Drive/Unisys/Unisys/container_0'
path_to_dir_male = os.path.join(path_to_dir, 'kn_male')
path_to_dir_female = os.path.join(path_to_dir, 'kn_female')

# Parallel lists filled by the indexing cells below: audio file path <-> transcript.
paths, texts = [], []

In [None]:
# Index the male clips: pair every .wav file in kn_male with its transcript.
# The TSV has no header row. Column 0 holds the clip id (the file name without
# its .wav extension) and column 1 holds the transcript.
line_df_male = pd.read_csv(os.path.join(path_to_dir, 'line_index_male.tsv'), sep="\t", header=None)
# Build the id -> transcript lookup once instead of a boolean-mask scan per file.
# Keeping the first duplicate matches what `.values[0]` used to pick.
male_transcripts = line_df_male.drop_duplicates(subset=[0]).set_index(0)[1].to_dict()

# os.listdir order is arbitrary; sorting keeps the seeded shuffle/split reproducible.
fils = sorted(os.listdir(path_to_dir_male))
for fil in fils:
    stem, ext = os.path.splitext(fil)
    if ext.lower() != '.wav':
        continue
    paths.append(os.path.join(path_to_dir_male, fil))
    # A KeyError here means the clip has no row in the TSV.
    texts.append(male_transcripts[stem])

In [None]:
# Index the female clips: pair every .wav file in kn_female with its transcript.
# The TSV has no header row. Column 0 holds the clip id (the file name without
# its .wav extension) and column 1 holds the transcript.
line_df_female = pd.read_csv(os.path.join(path_to_dir, 'line_index_female.tsv'), sep="\t", header=None)
# Build the id -> transcript lookup once instead of a boolean-mask scan per file.
# Keeping the first duplicate matches what `.values[0]` used to pick.
female_transcripts = line_df_female.drop_duplicates(subset=[0]).set_index(0)[1].to_dict()

# os.listdir order is arbitrary; sorting keeps the seeded shuffle/split reproducible.
fils = sorted(os.listdir(path_to_dir_female))
for fil in fils:
    stem, ext = os.path.splitext(fil)
    if ext.lower() != '.wav':
        continue
    paths.append(os.path.join(path_to_dir_female, fil))
    # A KeyError here means the clip has no row in the TSV.
    texts.append(female_transcripts[stem])

In [None]:
# Pair each clip path with its transcript in a Hugging Face Dataset.
data = Dataset.from_dict(dict(path=paths, text=texts))

In [None]:
# Shuffle with a fixed seed.
# NOTE(review): Dataset.train_test_split (next cell) appears to shuffle as well, so
# this pass may be redundant. Confirm before removing it, because removing it changes the split.
data = data.shuffle(seed=42)

In [None]:
# Hold out 20% of the clips for evaluation. The fixed seed makes the split
# repeatable as long as the input order is the same.
dataset = data.train_test_split(test_size=0.2,seed=42)

In [None]:
# Unpack the two splits produced by train_test_split.
train_dataset, test_dataset = dataset['train'], dataset['test']

In [None]:
import re

# Punctuation stripped from transcripts before comparison with model output.
# A raw string avoids Python's invalid-escape warnings for sequences like "\,"
# and leaves the regular expression itself unchanged.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\–\…]'


def remove_special_characters(batch):
    """Strip ignored punctuation from ``batch["text"]``, lowercase it, and append one space.

    Mutates and returns ``batch`` so the function can be passed to ``Dataset.map``.
    """
    batch["text"] = re.sub(chars_to_ignore_regex, '', batch["text"]).lower() + " "
    return batch

In [None]:
# Normalise the transcripts of both splits (punctuation removed, lowercased, trailing space appended).
train_dataset = train_dataset.map(remove_special_characters)
test_dataset = test_dataset.map(remove_special_characters)

Map:   0%|          | 0/3520 [00:00<?, ? examples/s]

Map:   0%|          | 0/880 [00:00<?, ? examples/s]

In [None]:
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Local checkpoint of the Kannada XLSR-Wav2Vec2 CTC model; processor and model share one directory.
MODEL_DIR = "/content/drive/My Drive/Unisys/Unisys/container_0/wav2vec2-large-xlsr-kn"
processor = Wav2Vec2Processor.from_pretrained(MODEL_DIR)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_DIR)

# Sampling rate the model expects. The original pipeline assumed 48 kHz input,
# so that case keeps using this cached resampler.
TARGET_SAMPLING_RATE = 16_000
resampler = torchaudio.transforms.Resample(48_000, TARGET_SAMPLING_RATE)


def speech_file_to_array_fn(batch):
    """Load ``batch["path"]`` and store a mono 16 kHz waveform in ``batch["speech"]``.

    The file's own sampling rate is honoured. 48 kHz audio uses the cached
    ``resampler``, 16 kHz audio is left untouched, and any other rate goes
    through ``torchaudio.functional.resample``. Multi-channel audio is averaged
    down to mono. Mutates and returns ``batch`` for use with ``Dataset.map``.
    """
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    if sampling_rate == 48_000:
        speech_array = resampler(speech_array)
    elif sampling_rate != TARGET_SAMPLING_RATE:
        speech_array = torchaudio.functional.resample(speech_array, sampling_rate, TARGET_SAMPLING_RATE)
    # torchaudio.load yields (channels, frames). Averaging over channels downmixes
    # stereo and gives the same 1-D result as squeeze() for mono.
    batch["speech"] = speech_array.mean(dim=0).numpy()
    return batch


# Decode the test split to 16 kHz arrays, then sanity-check the model on two clips.
test_dataset = test_dataset.map(speech_file_to_array_fn)
# Slice rows first. `test_dataset["speech"][:2]` would materialise the audio of
# every test clip just to keep two of them.
sample_batch = test_dataset[:2]
inputs = processor(sample_batch["speech"], sampling_rate=16_000, return_tensors="pt", padding=True)
with torch.no_grad():
    # The processor may not return an attention mask; pass None in that case.
    logits = model(inputs.input_values, attention_mask=inputs.get("attention_mask")).logits

predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", sample_batch["text"])

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Map:   0%|          | 0/880 [00:00<?, ? examples/s]

Prediction: ['ಇದವಳೆಗೆ ಸರಿಯಾಗಿ ಬರುತ್ತದೆ', 'ನಿವು ತಲುಪಿದ್ದೆರಿ']
Reference: ['ಇದೇ ವೇಳೆಗೆ ಸರಿಯಾಗಿ ಬರುತ್ತದೆ ', 'ನೀವು ತಲುಪಿದ್ದೀರಿ ']


## Test with your own audio

In [None]:
# Browser-side recorder injected into the notebook output. It starts recording
# once microphone access is granted, and exposes a JS Promise named `data` that
# resolves to a base64 data URL of the recording after the button is pressed.
# get_audio() reads that Promise with eval_js('data').
AUDIO_HTML = """
<script>
var my_div = document.createElement("DIV");
var my_btn = document.createElement("BUTTON");
var t = document.createTextNode("Press to start recording");

my_btn.appendChild(t);
my_div.appendChild(my_btn);
document.body.appendChild(my_div);

var base64data = 0;
var reader;
var recorder, gumStream;
var recordButton = my_btn;
// Resolver of the `data` Promise; called once the recording is fully encoded.
var resolveData = null;

var handleSuccess = function(stream) {
  gumStream = stream;
  recorder = new MediaRecorder(stream);
  recorder.ondataavailable = function(e) {
    var url = URL.createObjectURL(e.data);
    var preview = document.createElement('audio');
    preview.controls = true;
    preview.src = url;
    document.body.appendChild(preview);

    reader = new FileReader();
    reader.onloadend = function() {
      base64data = reader.result;
      // Resolve only when the data URL is ready, instead of after a fixed
      // 2 s sleep that could hand back the placeholder 0 for slow encodes.
      if (resolveData) {
        resolveData(base64data.toString());
      }
    };
    reader.readAsDataURL(e.data);
  };
  recorder.start();
};

recordButton.innerText = "Recording... press to stop";

navigator.mediaDevices.getUserMedia({audio: true})
  .then(handleSuccess)
  .catch(function(err) {
    recordButton.innerText = "Microphone unavailable: " + err;
  });

function toggleRecording() {
  if (recorder && recorder.state == "recording") {
      recorder.stop();
      gumStream.getAudioTracks()[0].stop();
      recordButton.innerText = "Saving the recording... pls wait!"
  }
}

var data = new Promise(resolve => {
  resolveData = resolve;
  recordButton.onclick = toggleRecording;
});
</script>
"""

In [None]:
from IPython.display import HTML, Audio
from js2py import eval_js
from base64 import b64decode
import wave
from scipy.io.wavfile import read as wav_read
import io
import numpy as np
import ffmpeg

def write_wav(f, sr, x, normalized=False):
    """Write ``x`` as a 16-bit PCM WAV file.

    Args:
        f: Output path or writable binary file object (anything ``wave.open`` accepts).
        sr: Sampling rate in Hz.
        x: Samples, 1-D for mono or 2-D ``(frames, channels)`` for multi-channel.
            Values are cast to int16 without any scaling.
        normalized: Unused; kept so existing callers keep working.
    """
    samples = np.asarray(x)
    n_channels = samples.shape[1] if samples.ndim == 2 else 1
    # `with` finalises the header and closes the handle even if writing fails.
    with wave.open(f, "wb") as wav_file:
        wav_file.setnchannels(n_channels)
        wav_file.setsampwidth(2)  # 2 bytes per sample == int16
        wav_file.setframerate(sr)
        # C-order bytes of a (frames, channels) array are already frame-interleaved.
        wav_file.writeframes(samples.astype(np.short).tobytes())

def get_audio(output_path="/content/drive/My Drive/Unisys/demo1.wav"):
    """Record audio in the browser, convert it to WAV with ffmpeg, and save it.

    Renders ``AUDIO_HTML`` (a microphone recorder) and waits for its ``data``
    Promise to deliver the recording as a base64 data URL.

    Args:
        output_path: Where the 16-bit WAV file is written. Defaults to the
            previously hard-coded Drive location.

    Returns:
        ``output_path``.

    Raises:
        RuntimeError: if the browser returned no recording or ffmpeg failed.
    """
    # The Promise lives in the *browser*, so it must be read through Colab's
    # JS bridge. js2py's eval_js runs a separate JS interpreter inside Python,
    # where `data` does not exist ("ReferenceError: data is not defined").
    from google.colab.output import eval_js

    display(HTML(AUDIO_HTML))
    data = eval_js('data')
    _, sep, payload = str(data).partition(',')
    if not sep:
        raise RuntimeError("No audio was recorded (browser returned %r)." % (data,))
    binary = b64decode(payload)

    # Transcode whatever container the browser produced to WAV.
    process = (
        ffmpeg
        .input('pipe:0')
        .output('pipe:1', format='wav')
        .run_async(pipe_stdin=True, pipe_stdout=True, pipe_stderr=True, quiet=True, overwrite_output=True)
    )
    output, err = process.communicate(input=binary)
    if process.returncode != 0:
        raise RuntimeError("ffmpeg failed to decode the recording:\n" + err.decode(errors="replace"))

    # Overwrite the RIFF chunk-size field (bytes 4:8) with the real size of the
    # stream. Presumably ffmpeg cannot seek back to fill it in when writing to a pipe.
    riff_chunk_size = len(output) - 8
    riff = output[:4] + riff_chunk_size.to_bytes(4, "little") + output[8:]
    sr, audio = wav_read(io.BytesIO(riff))

    write_wav(output_path, sr, audio)
    return output_path

In [None]:
import torchaudio
from datasets import load_dataset
from transformers import (
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
    AutoTokenizer,
    AutoModelWithLMHead
)
import torch
import re
import sys

# `load_metric` is no longer imported. It was unused here, and it has been
# removed from recent `datasets` releases, where importing it raises ImportError.

model_name = "/content/drive/My Drive/Unisys/Unisys/container_0/wav2vec2-large-xlsr-kn"
# Fall back to CPU so the notebook still runs on a runtime without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor_name = "/content/drive/My Drive/Unisys/Unisys/container_0/wav2vec2-large-xlsr-kn"

# Raw string: same regex, without Python's invalid-escape warnings.
chars_to_ignore_regex = r'[\,\?\.\!\-\;\:\"\“\%\‘\”\�\–\…]'

model = Wav2Vec2ForCTC.from_pretrained(model_name).to(device)
processor = Wav2Vec2Processor.from_pretrained(processor_name)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
resampler = torchaudio.transforms.Resample(orig_freq=48_000, new_freq=16_000)


def load_file_to_data(file):
    """Load an audio file as a mono waveform at ``resampler.new_freq`` (16 kHz).

    Returns a dict with ``"speech"`` (1-D numpy array) and ``"sampling_rate"``,
    the input format ``predict`` expects. The file's real sampling rate is used:
    48 kHz audio goes through the cached ``resampler``, audio already at the
    target rate is left untouched, and any other rate is converted with
    ``torchaudio.functional.resample``.
    """
    speech, file_rate = torchaudio.load(file)
    # Downmix (channels, frames) -> (frames,); identical to squeeze(0) for mono input.
    speech = speech.mean(dim=0)
    if file_rate == resampler.orig_freq:
        speech = resampler(speech)
    elif file_rate != resampler.new_freq:
        speech = torchaudio.functional.resample(speech, file_rate, resampler.new_freq)
    return {"speech": speech.numpy(), "sampling_rate": resampler.new_freq}

def predict(data):
    """Transcribe speech with the global ``model``/``processor`` using greedy (argmax) CTC decoding.

    Args:
        data: dict with ``"speech"`` and ``"sampling_rate"``, e.g. as returned
            by ``load_file_to_data``.

    Returns:
        list[str]: one decoded transcription per row of the processor's batch.
    """
    features = processor(data["speech"], sampling_rate=data["sampling_rate"], padding=True, return_tensors="pt")
    input_values = features.input_values.to(device)
    # The processor may not return an attention mask; only forward one if present.
    attention_mask = features.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(device)
    with torch.no_grad():
        logits = model(input_values, attention_mask=attention_mask).logits

    # Most likely token id per frame, then decode every row of the batch.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)


In [None]:
# Record a clip from the browser microphone and save it; the cell shows the saved WAV path.
get_audio()

JsException: ReferenceError: data is not defined

In [None]:
# Transcribe a pre-recorded clip stored on Drive.
sample_audio = load_file_to_data('/content/drive/My Drive/Unisys/md.wav')
predicted_output = predict(sample_audio)

In [None]:
# Show the list of transcriptions returned by predict().
print(predicted_output)

['ಪಿರಗಲೆು']


In [None]:
import textdistance
# The transcription to look up among the known transcripts.
key = predicted_output[0]

# Candidate transcripts: the first 100 rows of the female line index.
limited_test_dataset = line_df_female.head(100)

nearest_match, max_similarity = None, 0

In [None]:
# Debug helper (disabled): uncomment to print every candidate transcript.
#for index, row in limited_test_dataset.iterrows():
#  print(row[1])

In [None]:
# Find the candidate transcript closest to the model output (normalised
# Levenshtein similarity) and locate its reference recording.
MIN_SIMILARITY = 0.1  # similarities at or below this count as "no match"

max_similarity = 0
nearest_match = None
p = -1
# Reset so a stale path from an earlier run is never played by mistake.
audio_path = None

for index, row in limited_test_dataset.iterrows():
    candidate = row[1]
    if not isinstance(candidate, str):
        continue  # missing transcript (NaN); .lower() would raise
    similarity = textdistance.levenshtein.normalized_similarity(key.lower(), candidate.lower())
    if similarity > max_similarity:
        max_similarity = similarity
        nearest_match = candidate
        p = index

if nearest_match and max_similarity > MIN_SIMILARITY:
    print("Nearest match found in text:", nearest_match)
    m = limited_test_dataset.loc[p, 0]
    audio_path = os.path.join(path_to_dir_female, m + ".wav")
    print("Audio path:", audio_path)
else:
    print("No sufficiently similar match found.")


Nearest match found in text: ಪ್ರಯೋಗಶೀಲತೆ
Audio path: /content/drive/My Drive/Unisys/Unisys/container_0/kn_female/knf_08476_00919951349.wav


In [None]:
from IPython.display import Audio
# Play the reference clip for the nearest transcript match, if the search found one.
if globals().get("audio_path"):
    display(Audio(audio_path))
else:
    print("No matching audio to play.")


In [None]:
from collections import Counter
import math

def cosine_similarity(str1, str2):
    """Cosine similarity between the character-count vectors of two strings.

    Each string is treated as a bag of characters. Returns a float in [0, 1],
    and 0.0 when either input is empty or None (the unguarded formula would
    divide by zero in that case).
    """
    counts1 = Counter(str1 or "")
    counts2 = Counter(str2 or "")
    norm1 = math.sqrt(sum(v * v for v in counts1.values()))
    norm2 = math.sqrt(sum(v * v for v in counts2.values()))
    if norm1 == 0 or norm2 == 0:
        return 0.0
    # Counter returns 0 for missing keys, so characters absent from str2 contribute nothing.
    dot = sum(count * counts2[ch] for ch, count in counts1.items())
    return dot / (norm1 * norm2)


# Character-level cosine similarity between the transcription and its nearest match.
if nearest_match:
    cosine_sim = cosine_similarity(key, nearest_match)
    print("Cosine similarity:", cosine_sim)
else:
    print("No nearest match to compare against.")


Cosine similarity: 0.7051102404077105


In [None]:
from audio_similarity import AudioSimilarity

# Paths to the original (reference) and comparison audio files.
# NOTE(review): get_audio() saves its recording to demo1.wav, but this compares
# demo.wav. Confirm which recording is intended.
original_path = '/content/drive/My Drive/Unisys/Unisys/container_0/kn_female/knf_05550_01969100159.wav'
generated_path = '/content/drive/My Drive/Unisys/demo.wav'

# Sampling rate passed to AudioSimilarity, and per-metric weights (they sum to 1.0).
sample_rate = 16000
weights = dict(
    zcr_similarity=0.2,
    rhythm_similarity=0.2,
    chroma_similarity=0.2,
    energy_envelope_similarity=0.1,
    spectral_contrast_similarity=0.1,
    perceptual_similarity=0.2,
)

audio_similarity = AudioSimilarity(original_path, generated_path, sample_rate, weights)
similarity_score = audio_similarity.stent_weighted_audio_similarity()
print(f"Stent Weighted Audio Similarity: {similarity_score}")

Loading original files:: 100%|██████████| 1/1 [00:00<00:00, 164.61it/s]
Loading comparison files:: 100%|██████████| 1/1 [00:00<00:00, 88.10it/s]


Stent Weighted Audio Similarity: 0.6851944976699905
