In [1]:
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import Audio
from scipy.signal import resample
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader


from transformers import WhisperTokenizer
from transformers import WhisperFeatureExtractor
from transformers import WhisperForConditionalGeneration

import evaluate

# Word-error-rate metric object; used at the end of the notebook to score transcriptions.
wer = evaluate.load(path="wer")

from scipy.signal import resample

def down_sample_audio(audio_original, original_sample_rate, target_sample_rate=16000):
    """Resample an audio signal to ``target_sample_rate`` Hz.

    Args:
        audio_original: Sequence or array of audio samples.
        original_sample_rate: Sampling rate of ``audio_original`` in Hz; must be positive.
        target_sample_rate: Desired sampling rate in Hz. Defaults to 16000, the rate
            this notebook feeds to the Whisper feature extractor.

    Returns:
        A NumPy array with the resampled signal. When the two rates are equal the
        input samples are returned unchanged (as an array).

    Raises:
        ValueError: If either sampling rate is not strictly positive.
    """
    if original_sample_rate <= 0 or target_sample_rate <= 0:
        raise ValueError(
            "Sampling rates must be positive, got "
            f"original={original_sample_rate}, target={target_sample_rate}"
        )

    # Nothing to do when the audio is already at the requested rate.
    if original_sample_rate == target_sample_rate:
        return np.asarray(audio_original)

    # Calculate the number of samples for the target sample rate
    num_samples = int(len(audio_original) * target_sample_rate / original_sample_rate)

    # Empty (or too short) input leaves nothing to resample to; return an empty
    # array instead of asking scipy for zero output samples.
    if num_samples == 0:
        return np.asarray(audio_original)[:0]

    # Resample the audio array to the target sample rate
    return resample(audio_original, num_samples)

In [None]:
# Load the training parquet file; the records are later read from the 'train' split.
atco_asr_data = load_dataset(
    path="parquet",
    data_files="train-00000-of-00005-c6681348ac8543dc.parquet",
)

In [None]:
# Checkpoint shared by the tokenizer, feature extractor and model.
WHISPER_CHECKPOINT = "openai/whisper-small"

tokenizer = WhisperTokenizer.from_pretrained(WHISPER_CHECKPOINT, language='english', task='transcribe')
feature_extractor = WhisperFeatureExtractor.from_pretrained(WHISPER_CHECKPOINT, language='english', task='transcribe')

# Prefer the GPU, but fall back to the CPU so loading the model does not crash
# on machines where CUDA is unavailable.
model = WhisperForConditionalGeneration.from_pretrained(WHISPER_CHECKPOINT).to(
    'cuda' if torch.cuda.is_available() else 'cpu'
)

In [None]:

class whisper_training_dataset(torch.utils.data.Dataset):
    """Map-style dataset that turns audio/transcript records into Whisper training examples.

    Each item is a dict with:
      * ``input_features``: output of the module-level ``feature_extractor`` for the audio.
      * ``labels``: token ids from the module-level ``tokenizer`` with padding positions
        replaced by -100 (the conventional ignore index for the loss) and the first
        token dropped.
    """

    # Rate (Hz) reported to the feature extractor. It must match the rate that
    # ``down_sample_audio`` resamples to.
    _SAMPLE_RATE = 16000

    def __init__(self, dataset, max_len):
        """Store the underlying records and the label length.

        Args:
            dataset: Hugging Face dataset object whose records provide an ``audio``
                entry (with ``array`` and ``sampling_rate``) and a ``text`` transcript.
            max_len: Length the tokenized transcript is padded/truncated to (the
                returned labels are one token shorter because the first is dropped).
        """
        self.dataset = dataset
        self.max_len = max_len
        self.bos_token = model.config.decoder_start_token_id

    def __len__(self):
        """Return the number of records in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build the ``input_features`` / ``labels`` pair for record ``idx``."""
        item = self.dataset[idx]

        audio = item['audio']
        waveform = audio["array"]
        source_rate = audio["sampling_rate"]

        # The feature extractor is told the audio is sampled at 16 kHz, so bring the
        # waveform to that rate first whenever the recording was stored differently
        # (the evaluation code resamples the same way).
        if source_rate != self._SAMPLE_RATE:
            waveform = down_sample_audio(waveform, source_rate)

        input_features = feature_extractor(
            waveform, sampling_rate=self._SAMPLE_RATE, return_tensors='pt'
        ).input_features[0]

        # Process the transcription
        transcription = item["text"]

        # Create labels: pad/truncate to a fixed length, then mask padding with -100.
        encoded = tokenizer(
            transcription,
            padding="max_length",
            max_length=self.max_len,
            truncation=True,
            return_tensors="pt",
        )
        labels = encoded["input_ids"].masked_fill(encoded['attention_mask'].ne(1), -100)

        # Drop the leading token; presumably the model adds its own decoder start
        # token when it builds decoder inputs from the labels — TODO confirm.
        labels = labels[0][1:]

        return {
            "input_features": input_features,
            "labels": labels
        }

In [None]:
# Wrap the 'train' split so each record yields features plus labels built from
# transcripts padded/truncated to 60 tokens.
dataset = whisper_training_dataset(atco_asr_data['train'], 60)

# Batches of 8 examples, reshuffled on every pass over the data.
train_dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

In [None]:
from IPython.display import clear_output

# Fine-tune `model` for a fixed number of optimizer steps, plotting the running
# loss every 10 steps, then switch the model back to evaluation mode.

model.train()  # Set model to training mode

# Send batches to wherever the model's parameters live instead of assuming a GPU.
device = next(model.parameters()).device

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

max_steps = 100
step = 0

running_loss = []  # one loss value per optimizer step

progress_bar = tqdm(range(max_steps), leave=False)

# Cycle over the dataloader until `max_steps` updates are done. The step budget
# is checked in the `while` condition because `break` only leaves the inner
# `for` loop; an unconditional outer loop would keep training forever.
while step < max_steps:
    step_at_epoch_start = step

    for batch in train_dataloader:

        optimizer.zero_grad()  # Reset gradients
        input_features = batch["input_features"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass; passing labels makes the model return the loss.
        outputs = model(input_features, labels=labels)
        loss = outputs.loss

        loss_value = loss.item()
        running_loss.append(loss_value)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=loss_value)

        step += 1

        # Refresh the loss curve in place every 10 steps.
        if step % 10 == 0:
            plt.plot(running_loss)
            clear_output(wait=True)
            plt.show()

        if step >= max_steps:
            break

    # An empty dataloader would otherwise spin in the outer loop without progress.
    if step == step_at_epoch_start:
        raise RuntimeError("train_dataloader produced no batches; cannot train.")

progress_bar.close()
model.eval()

In [None]:
# Persist the fine-tuned model object to disk.
# NOTE(review): this pickles the whole model object rather than a state_dict;
# confirm that is the intended artifact format before relying on it elsewhere.
torch.save(obj=model, f='finetuned_atco.pt')

In [None]:
# Transcribe every validation sample with the fine-tuned model and report the
# word error rate against the reference transcripts.
model.eval()

# NOTE: this rebinds `atco_asr_data` (previously the training file) to the
# validation file; re-run the training-data cell before rebuilding the training
# dataset. The records are still read through the 'train' split key below.
atco_asr_data = load_dataset('parquet',data_files="validation-00000-of-00002-7a5ea3756991bf72.parquet")

# Run generation on the device that holds the model's parameters rather than
# assuming a GPU.
eval_device = next(model.parameters()).device

predictions = []
references = []

for sample in tqdm(atco_asr_data['train'],total=len(atco_asr_data['train'])):
    audio = sample['audio']['array']
    sample_rate = sample['audio']['sampling_rate']
    text = sample['text']

    audio = down_sample_audio(audio, sample_rate) # downsample the audio to 16000Hz for WHISPER

    # Use the feature extractor's default padding so evaluation features are
    # built the same way as the training features (no `padding` override there).
    input_features = feature_extractor(
        raw_speech=audio,
        sampling_rate=16000,
        return_tensors='pt',
    ).input_features

    # Generate predictions with no gradient computation
    with torch.no_grad():
        op = model.generate(input_features.to(eval_device), language='english', task='transcribe')

    # Decode predictions
    text_preds = tokenizer.batch_decode(op, skip_special_tokens=True)

    # Collect the prediction(s) and the matching reference transcript.
    predictions.extend(text_preds)
    references.append(text)

print(f'The WER after training is {wer.compute(predictions=predictions, references=references) * 100}%')