In [None]:
import soundfile as sf
import os
import torch
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch.optim as optim
from tqdm import tqdm
# Default device string: "cuda" when a CUDA device is visible to torch, otherwise "cpu".
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
from torch.nn.utils.rnn import pad_sequence

from jiwer import wer
import jiwer
# Text normalisation pipeline handed to jiwer for both references and
# hypotheses before the word error rate is computed.
transformation = jiwer.Compose(
    [
        # case-insensitive comparison
        jiwer.ToLowerCase(),
        # turn whitespace characters into plain spaces ...
        jiwer.RemoveWhiteSpace(replace_by_space=True),
        # ... then collapse runs of spaces into one
        jiwer.RemoveMultipleSpaces(),
        # final step: split every sentence into its list of words
        jiwer.ReduceToListOfListOfWords(word_delimiter=" "),
    ]
)

In [None]:
# dataset class

audio_path = "primock57/output/audio_utterances/"


class ClinDataset(torch.utils.data.Dataset):
    """Utterance-level speech dataset pairing audio files with transcripts.

    The transcript file is read one utterance per line in the form
    ``<text>|<utterance id>``. Only files in the audio directory whose name
    (minus extension) matches an utterance id are kept.

    Items are built with the module-level ``processor``; it only has to exist
    by the time items are fetched, not when the dataset is constructed.
    """

    def __init__(self, filename, device=None, audio_dir=None):
        """Index the transcripts and the audio files that have one.

        Args:
            filename: path of the transcript file.
            device: value stored as ``self.device``; defaults to the
                module-level ``DEVICE``.
            audio_dir: directory holding the audio files; defaults to the
                module-level ``audio_path``.

        Raises:
            ValueError: if a non-blank transcript line has no ``|`` separator.
        """
        self.device = DEVICE if device is None else device
        self.audio_dir = audio_path if audio_dir is None else audio_dir

        # utterance id -> transcript text
        self.transcripts = {}
        with open(filename, 'r') as vf:
            for line_no, line in enumerate(vf, start=1):
                if not line.strip():
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                fields = line.split('|')
                if len(fields) < 2:
                    raise ValueError(
                        f"{filename}:{line_no}: expected '<text>|<id>', got {line!r}"
                    )
                self.transcripts[fields[1].strip()] = fields[0]

        # Sorted so the item order does not depend on the file system's
        # directory listing order.
        self.wavpaths = sorted(
            wp for wp in os.listdir(self.audio_dir)
            if os.path.splitext(wp)[0] in self.transcripts
        )

        print('num files', len(self.wavpaths))

    def __len__(self):
        """Number of audio files that have a transcript."""
        return len(self.wavpaths)

    def __getitem__(self, idx):
        """Load one utterance.

        Returns:
            dict with ``"aud"`` (processed audio with the leading batch
            dimension squeezed out), ``"lab"`` (token ids of the transcript,
            likewise squeezed) and ``"trans"`` (the upper-cased transcript).
        """
        wav_name = self.wavpaths[idx]
        audio, sample_rate = sf.read(os.path.join(self.audio_dir, wav_name))
        audio = processor(audio, sampling_rate=sample_rate, return_tensors="pt").input_values

        # NOTE(review): upper-casing presumably matches the pretrained
        # tokenizer's vocabulary - confirm if the checkpoint changes.
        transcript = self.transcripts[os.path.splitext(wav_name)[0]].upper()
        # Tokenise through the processor's tokenizer directly; this is what the
        # `as_target_processor()` context manager delegated to.
        labels = processor.tokenizer(transcript, return_tensors="pt").input_ids

        return {"lab": labels.squeeze(0), "aud": audio.squeeze(0), "trans": transcript}

    def collate(self, batch):
        """Merge a list of items into one padded batch.

        Audio is right-padded with zeros; labels are right-padded with -100.

        Returns:
            dict with ``"input_values"``, ``"labels"`` and ``"trans"`` (the
            list of transcript strings, in batch order).
        """
        audios = [item["aud"] for item in batch]
        labels = [item["lab"] for item in batch]
        trans = [item["trans"] for item in batch]

        return {
            "input_values": pad_sequence(audios, batch_first=True),
            "labels": pad_sequence(labels, batch_first=True, padding_value=-100),
            "trans": trans,
        }

# Training split: batches of 4 utterances, padded by the dataset's own collate.
transcripts_file = "primock57/output/train_transcript.ref.txt"
dataset = ClinDataset(transcripts_file)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=4,
    collate_fn=dataset.collate,
)

In [None]:
# load pretrained model
# `processor` is used by the dataset to turn audio into input values and
# transcripts into label ids, and later to decode predictions back to text.
processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")

# Switch to training mode (train() returns the module itself).
model.train()

optimizer = optim.Adam(params=model.parameters(), lr=1e-6)

In [None]:
# Number of micro-batches whose gradients are accumulated per optimizer update.
grad_acc_step = 1

# FINE-TUNE
# 1 epoch
model.train()
# Feed tensors on whatever device the model currently lives on.
device = next(model.parameters()).device
optimizer.zero_grad()

step = 0  # micro-batches that contributed gradients so far
progress = tqdm(loader)
for batch in progress:
    input_values = batch["input_values"].to(device)
    labels = batch["labels"].to(device)

    # compute loss by passing labels
    loss = model(input_values, labels=labels).loss
    if not torch.isfinite(loss):
        # Skip NaN/inf batches. Nothing was back-propagated for this batch, so
        # gradients accumulated from earlier micro-batches are left untouched.
        continue

    # Per-sample loss, as in the original normalisation by batch size.
    # NOTE(review): presumably the model's CTC loss is summed over the batch -
    # confirm against model.config.ctc_loss_reduction.
    loss = loss / len(input_values)
    progress.set_postfix(loss=loss.item())

    # Scale so the accumulated gradient is the mean over the micro-batches.
    (loss / grad_acc_step).backward()
    step += 1

    # Update only once `grad_acc_step` micro-batches have been accumulated.
    if step % grad_acc_step == 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        optimizer.zero_grad()

# Apply whatever is left when the epoch length is not a multiple of grad_acc_step.
if step % grad_acc_step != 0:
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()


In [None]:
# wer of fine-tuned model

def print_wer(references, hypotheses):
    """Print the word error rate of ``hypotheses`` against ``references``.

    Two figures are printed:

    * ``WER`` on the raw strings;
    * ``WER without fillers`` on the strings passed through ``normalizer``,
      restricted to pairs where both reference and hypothesis are non-empty.

    Both computations apply the module-level ``transformation`` to each side.

    Args:
        references: sequence of reference transcripts.
        hypotheses: sequence of predicted transcripts, aligned with
            ``references``.

    Raises:
        ValueError: if the two sequences differ in length.
    """
    references = list(references)
    hypotheses = list(hypotheses)
    if len(references) != len(hypotheses):
        raise ValueError(
            f"got {len(references)} references but {len(hypotheses)} hypotheses"
        )

    print("WER:", wer(references, hypotheses, truth_transform=transformation, hypothesis_transform=transformation))

    # NOTE(review): `normalizer` is not defined in this part of the notebook;
    # it must be provided (a callable str -> str) before this function is called.
    # Normalise pair by pair so each row keeps its own cleaned text.
    clean_references = []
    clean_hypotheses = []
    for ref, hyp in zip(references, hypotheses):
        if hyp != "" and ref != "":
            clean_references.append(normalizer(ref))
            clean_hypotheses.append(normalizer(hyp))

    if not clean_references:
        print("WER without fillers: n/a (no non-empty reference/hypothesis pairs)")
        return

    print("WER without fillers:", wer(clean_references, clean_hypotheses, truth_transform=transformation, hypothesis_transform=transformation))



transcripts_file = "primock57/output/test_transcript.ref.txt"
dataset = ClinDataset(transcripts_file)
# The dataset's own collate is required: it is what produces the
# "input_values" / "labels" / "trans" keys read below.
loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=dataset.collate)

hypotheses = []
references = []

# Evaluate in inference mode without tracking gradients, then restore
# whichever mode the model was in.
was_training = model.training
model.eval()
device = next(model.parameters()).device
with torch.no_grad():
    for batch in loader:
        input_values = batch["input_values"].to(device)

        logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)

        # transcribe: one decoded string per utterance in the batch
        hypotheses.extend(processor.decode(ids) for ids in predicted_ids)
        references.extend(batch["trans"])
model.train(was_training)

print_wer(references, hypotheses)