In [None]:
# pip install pandas requests tqdm torch torchaudio transformers jiwer

# Main - API

In [None]:
import os
import requests

import pandas as pd

from tqdm import tqdm

# Show the current working directory
print(os.getcwd())

# Locations of the CSV file, the folder containing MP3 files and the output CSV.
# Everything hangs off one root folder: set the ASR_DATA_ROOT environment
# variable to use a copy stored elsewhere (the fallback is the original
# author's local path, so default behaviour is unchanged there).
DATA_ROOT = os.environ.get('ASR_DATA_ROOT', r'C:\Users\Featherine\Downloads\HTX xData Assignment')
CSV_PATH = os.path.join(DATA_ROOT, 'common_voice', 'cv-valid-dev.csv')
SAVE_PATH = os.path.join(DATA_ROOT, 'asr', 'cv-valid-dev-saved.csv')
AUDIO_FOLDER = os.path.join(DATA_ROOT, 'common_voice', 'cv-valid-dev')

# ASR endpoint; set ASR_API_URL if your API is hosted elsewhere
API_URL = os.environ.get('ASR_API_URL', 'http://localhost:8001/asr')

def transcribe_audio(file_path, timeout=300):
    """
    Sends an audio file to the ASR API and returns the transcription.

    Args:
        file_path: Path of the audio file, uploaded as the 'file' form field.
        timeout: Seconds to wait on the API before giving up (default 300).
            Without a timeout a stalled server would block the run forever.

    Returns:
        The 'transcription' field of the JSON response, or '' when the file
        cannot be read, the request fails, or the response is not a JSON
        object carrying a transcription.
    """
    try:
        with open(file_path, 'rb') as audio_file:
            response = requests.post(API_URL, files={'file': audio_file}, timeout=timeout)
        response.raise_for_status()
        payload = response.json()
    except (OSError, ValueError) as e:
        # requests' RequestException derives from IOError (an alias of OSError),
        # so this covers connection, timeout and HTTP-status errors as well as
        # an unreadable file; ValueError covers a body that is not valid JSON.
        print(f"Error transcribing {file_path}: {e}")
        return ''

    if not isinstance(payload, dict):
        print(f"Error transcribing {file_path}: unexpected response {payload!r}")
        return ''
    # A null/missing transcription is reported as an empty string
    return payload.get('transcription') or ''

def main():
    """
    Transcribes every audio file listed in the CSV through the ASR API.

    Reads CSV_PATH, sends each row's audio file (its 'filename' resolved
    against AUDIO_FOLDER) to transcribe_audio, stores the result in the
    'generated_text' column and writes the frame to SAVE_PATH. Rows with no
    usable filename, or whose file is missing on disk, are reported and left
    untouched.
    """
    # Load the CSV file into a DataFrame
    df = pd.read_csv(CSV_PATH)

    # Ensure there's a 'generated_text' column that can hold strings; a column
    # read back from CSV as all-NaN would otherwise be float-typed.
    if 'generated_text' not in df.columns:
        df['generated_text'] = ''
    df['generated_text'] = df['generated_text'].astype(object)

    # Iterate over each row and transcribe the corresponding audio file
    for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Transcribing audio files"):
        mp3_filename = row.get('filename')
        # A blank CSV cell is read as NaN, which is truthy, so test for it explicitly
        if pd.isna(mp3_filename) or not str(mp3_filename).strip():
            print(f"No filename found for row {index}. Skipping.")
            continue

        audio_path = os.path.join(AUDIO_FOLDER, str(mp3_filename))
        if not os.path.isfile(audio_path):
            print(f"File {audio_path} does not exist. Skipping.")
            continue

        # Transcribe the audio file
        df.at[index, 'generated_text'] = transcribe_audio(audio_path)

    # Make sure the output folder exists so a long run is not lost at the last step
    save_dir = os.path.dirname(SAVE_PATH)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Save the updated DataFrame to the output CSV file
    df.to_csv(SAVE_PATH, index=False)
    print(f"Transcription complete. Updated CSV saved to {SAVE_PATH}.")

# Run the API-based transcription when this cell/script is executed directly
if __name__ == '__main__':
    main()


# Main - Manual (runs the model locally instead of calling the API, to speed things up)

In [None]:
import os

import pandas as pd
import requests
import torch
import torchaudio
from jiwer import wer
from tqdm import tqdm
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer

# Dataset locations. Everything hangs off one root folder: set the
# ASR_DATA_ROOT environment variable to use a copy stored elsewhere (the
# fallback is the original author's local path).
DATA_ROOT = os.environ.get('ASR_DATA_ROOT', r'C:\Users\Featherine\Downloads\HTX xData Assignment')
CSV_PATH = os.path.join(DATA_ROOT, 'common_voice', 'cv-valid-dev.csv')
SAVE_PATH = os.path.join(DATA_ROOT, 'asr', 'cv-valid-dev-saved.csv')
AUDIO_FOLDER = os.path.join(DATA_ROOT, 'common_voice', 'cv-valid-dev')

# Sample rate every clip is resampled to before inference
TARGET_SAMPLE_RATE = 16000

# Load pre-trained model and tokenizer
model_name = "facebook/wav2vec2-base-960h"
# model_name = "facebook/wav2vec2-large-960h"
tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_name)
model = Wav2Vec2ForCTC.from_pretrained(model_name)

# One Resample transform per source rate, built on first use and then reused
_resamplers = {}


def _load_mono_waveform(audio_path):
    """Load an audio file as a 1-D waveform tensor at TARGET_SAMPLE_RATE."""
    waveform, sample_rate = torchaudio.load(audio_path)
    # torchaudio returns (channels, frames); average the channels so that a
    # multi-channel file is not mistaken for a batch of separate clips.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    # Resample if necessary
    if sample_rate != TARGET_SAMPLE_RATE:
        if sample_rate not in _resamplers:
            _resamplers[sample_rate] = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
            )
        waveform = _resamplers[sample_rate](waveform)
    return waveform.squeeze(0)


def transcribe_locally(audio_path):
    """Transcribe one audio file with the in-process wav2vec2 model and return the text."""
    waveform = _load_mono_waveform(audio_path)
    input_values = tokenizer(waveform.numpy(), return_tensors="pt").input_values
    # Inference only: skip autograd bookkeeping to save memory and time
    with torch.no_grad():
        logits = model(input_values).logits
    predicted_ids = logits.argmax(dim=-1)
    return tokenizer.decode(predicted_ids[0])


# Load the CSV file into a DataFrame
df = pd.read_csv(CSV_PATH)

# Ensure there's a 'generated_text' column that can hold strings; a column
# read back from CSV as all-NaN would otherwise be float-typed.
if 'generated_text' not in df.columns:
    df['generated_text'] = ''
df['generated_text'] = df['generated_text'].astype(object)

# Iterate over each row and transcribe the corresponding audio file
for index, row in tqdm(df.iterrows(), total=df.shape[0], desc="Transcribing audio files"):
    mp3_filename = row.get('filename')
    # A blank CSV cell is read as NaN, which is truthy, so test for it explicitly
    if pd.isna(mp3_filename) or not str(mp3_filename).strip():
        print(f"No filename found for row {index}. Skipping.")
        continue

    audio_path = os.path.join(AUDIO_FOLDER, str(mp3_filename))
    if not os.path.isfile(audio_path):
        print(f"File {audio_path} does not exist. Skipping.")
        continue

    # Transcribe the audio file; one undecodable clip should not abort the
    # whole run and discard every transcription produced so far.
    try:
        transcription = transcribe_locally(audio_path)
    except (RuntimeError, OSError) as e:
        print(f"Error transcribing {audio_path}: {e}")
        continue

    df.at[index, 'generated_text'] = transcription

# Make sure the output folder exists so a long run is not lost at the last step
save_dir = os.path.dirname(SAVE_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

# Save the updated DataFrame to the output CSV file
df.to_csv(SAVE_PATH, index=False)
print(f"Transcription complete. Updated CSV saved to {SAVE_PATH}.")

In [1]:
import pandas as pd
from jiwer import wer
SAVE_PATH = r'C:\Users\Featherine\Downloads\HTX xData Assignment\asr\cv-valid-dev-saved.csv'

test_df = pd.read_csv(SAVE_PATH)

# Rows without a reference transcript cannot be scored, so leave them out
scored_df = test_df.dropna(subset=['text'])

# Calculate WER
# Both sides are lower-cased so letter case never counts as an error.
# A blank generated_text is read back from CSV as NaN; fillna('') keeps it an
# empty hypothesis instead of turning it into the literal word "nan".
text = scored_df['text'].astype(str).str.lower().tolist()
text_gen = scored_df['generated_text'].fillna('').astype(str).str.lower().tolist()
wer_score = wer(text, text_gen)
print(f"Test WER: {wer_score}")

Test WER: 0.14384739667483193


In [None]:
# Preview the first rows of the saved transcriptions rather than dumping the whole frame
test_df.head()

In [None]:
# Preview the first rows of the frame built by the local-inference cell
df.head()