In [None]:
!git clone https://github.com/huydsai02/project-2.git
!pip install -q -r /kaggle/working/project-2/requirements.txt

In [None]:
# `APL` is defined in model.py inside the cloned repo, so cd there to make it importable,
# then change into /kaggle/input (the data paths used below are all absolute).
%cd /kaggle/working/project-2/model/apl_update
from model import APL
%cd /kaggle/input

In [None]:
from jiwer import wer
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch, json, os, time, librosa, transformers, gc
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
import torch.optim as optim
from torch.utils.data import DataLoader
from python_speech_features import fbank
# from multiprocessing import get_context
from pyctcdecode import build_ctcdecoder
import pandas as pd
import numpy as np
from tqdm import tqdm
import warnings

In [None]:
# Hugging Face checkpoint shared by the feature processor and the CTC model.
W2V_CHECKPOINT = "facebook/wav2vec2-large-960h-lv60-self"

processor = Wav2Vec2Processor.from_pretrained(W2V_CHECKPOINT)
model_wav2vec = Wav2Vec2ForCTC.from_pretrained(W2V_CHECKPOINT)

In [None]:
# Token -> integer id mapping; its size also defines the padding / CTC blank id used later.
vocab_path = '/kaggle/input/model-project-2/vocab_39.json'
with open(vocab_path) as vocab_file:
    dict_vocab = json.load(vocab_file)

def text_to_tensor(text, vocab=None):
    """Map a whitespace-separated token string to a list of vocabulary ids.

    Despite the name, this returns a plain Python list (padding and tensor
    conversion happen later, in the batch collation step).

    Args:
        text: token sequence such as "sil aa b sil"; matching is case-insensitive.
            Runs of whitespace and leading/trailing spaces are ignored. An empty
            string yields an empty list.
        vocab: token -> id mapping. Defaults to the global ``dict_vocab``.

    Returns:
        list of the ids ``vocab`` assigns to each token, in order.

    Raises:
        KeyError: if a token is missing from the vocabulary (the message names
            the token and the full input string).
    """
    if vocab is None:
        vocab = dict_vocab
    ids = []
    # split() rather than split(" "): the latter yields '' for double or trailing
    # spaces, which then fails the vocabulary lookup.
    for token in text.lower().split():
        if token not in vocab:
            raise KeyError(f"token {token!r} not in vocabulary (input: {text!r})")
        ids.append(vocab[token])
    return ids

In [None]:
import torch
from torch.utils.data import Dataset
import numpy as np

class MDD_Dataset(Dataset):
    """Map-style dataset over a DataFrame with 'Path', 'Canonical' and 'Transcript' columns.

    MDD presumably stands for mispronunciation detection & diagnosis (confirm).
    Audio is decoded lazily, one file per ``__getitem__`` call.
    """

    def __init__(self, data):
        # Snapshot the needed columns as plain lists for fast positional indexing.
        self.len_data = len(data)
        self.path, self.canonical, self.transcript = (
            list(data[column]) for column in ('Path', 'Canonical', 'Transcript')
        )

    def __getitem__(self, index):
        """Return (waveform at 16 kHz, canonical ids, transcript ids, raw transcript)."""
        waveform = librosa.load(self.path[index], sr=16000)[0]
        canonical_ids = text_to_tensor(self.canonical[index])
        raw_transcript = self.transcript[index]
        return waveform, canonical_ids, text_to_tensor(raw_transcript), raw_transcript

    def __len__(self):
        return self.len_data

In [None]:
# Train / dev split metadata (one row per utterance).
csv_dir = '/kaggle/input/project-2-new-data/csv'
df_train, df_dev = (pd.read_csv(os.path.join(csv_dir, f'{split}.csv')) for split in ('train', 'dev'))

In [None]:
# Root directory holding one sub-folder per utterance id.
audio_folder = os.path.join('/kaggle/input/project-2-new-data', 'audio')

In [None]:
def _audio_path(utt_id):
    """Map an utterance id to <audio_folder>/<id zero-padded to 5 chars>/audio.wav."""
    return os.path.join(audio_folder, str(utt_id).zfill(5), 'audio.wav')

# NOTE(review): this overwrites the 'Path' column in place, so re-running the cell
# without reloading the CSVs would join the prefix twice.
for _frame in (df_train, df_dev):
    _frame['Path'] = _frame['Path'].apply(_audio_path)
del _frame

In [None]:
# Preview a few rows instead of rendering the entire training frame.
df_train.head()

In [None]:
# Run configuration.
min_wer = 100  # best dev WER seen so far; initialised high so the first evaluation always checkpoints
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
num_epoch = 150

# Seed the RNGs so weight init and DataLoader shuffling are reproducible.
# torch.manual_seed seeds the CPU and all CUDA generators.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
# Lazy datasets: each audio file is read from disk only when its item is requested.
train_dataset, dev_dataset = (MDD_Dataset(frame) for frame in (df_train, df_dev))

In [None]:
# Keep every child of the CTC model except the last two (the output head), giving a
# frozen encoder whose forward output exposes `.last_hidden_state`.
model_embed = torch.nn.Sequential(*(list(model_wav2vec.children())[:-2])).to(device)
model_embed.requires_grad_(False)  # feature extractor only: no gradients
model_embed.eval()

# Drop the full CTC model to free the parts not kept above.
del model_wav2vec
gc.collect();

In [None]:
def _pad_ids(ids, length, pad_id):
    """Return a right-padded copy of ``ids`` with ``length`` entries (the input is not modified)."""
    return list(ids) + [pad_id] * (length - len(ids))


def collate_fn(batch):
    """Collate ``MDD_Dataset`` items into model-ready tensors on ``device``.

    Each item is ``(waveform, canonical_ids, label_ids, transcript)``. Waveforms are
    zero-padded to the longest one in the batch. Id lists are right-padded with
    ``len(dict_vocab)``, the same id used as the CTC blank.

    NOTE: this runs the frozen wav2vec2 encoder (``model_embed``) on ``device``
    inside collation. The loaders below use the default ``num_workers=0``; using
    worker processes would require extra care with CUDA.

    Returns:
        fbank: float tensor (B, T, 81): 80 mel filterbank energies plus frame
            energy, truncated to T = number of wav2vec2 output frames.
        phonetic: wav2vec2 ``last_hidden_state`` (B, T, hidden).
        linguistic: long tensor (B, max canonical length), padded.
        labels: long tensor (B, max label length), padded.
        targets_length: long tensor (B,) with the unpadded label lengths.
        transcripts: list of raw transcript strings.
    """
    sr = 16000
    pad_id = len(dict_vocab)
    with torch.no_grad():
        max_wav = max(row[0].shape[0] for row in batch)
        max_ling = max(len(row[1]) for row in batch)
        max_label = max(len(row[2]) for row in batch)
        target_length = [len(row[2]) for row in batch]

        waveforms, fbanks, linguistic, labels, transcripts = [], [], [], [], []
        for wav, ling_ids, label_ids, transcript in batch:
            pad_wav = np.concatenate([wav, np.zeros(max_wav - wav.shape[0])])
            waveforms.append(pad_wav)
            # 20 ms window and 20 ms hop.
            melfbank, energy = fbank(pad_wav, sr, winlen=0.02, winstep=0.02, nfilt=80)
            fbanks.append(np.concatenate([melfbank, energy.reshape(-1, 1)], axis=1))
            # Pad copies: the dataset's lists are no longer mutated in place.
            linguistic.append(_pad_ids(ling_ids, max_ling, pad_id))
            labels.append(_pad_ids(label_ids, max_label, pad_id))
            transcripts.append(transcript)

        # NOTE(review): waveforms are already zero-padded, so the processor
        # normalises over the padding too and no attention_mask reaches the
        # encoder. Hugging Face recommends an attention_mask for lv60 checkpoints — confirm.
        inputs = processor(waveforms, return_tensors="pt", sampling_rate=sr, padding="longest")
        phonetic = model_embed(inputs.input_values.to(device)).last_hidden_state

        # All padded waveforms have equal length, so the fbank arrays stack cleanly.
        # Trim to the encoder's frame count so both feature streams align in time.
        fbank_t = torch.tensor(np.stack(fbanks), dtype=torch.float, device=device)[:, :phonetic.shape[1], :]
        linguistic_t = torch.tensor(linguistic, dtype=torch.long, device=device)
        labels_t = torch.tensor(labels, dtype=torch.long, device=device)
        targets_length = torch.tensor(target_length, dtype=torch.long, device=device)

    return fbank_t, phonetic, linguistic_t, labels_t, targets_length, transcripts

In [None]:
# collate_fn already runs the wav2vec2 encoder and moves batches to `device`;
# the default num_workers=0 keeps that work in the main process.
_loader_kwargs = dict(batch_size=32, collate_fn=collate_fn)
train_loader = DataLoader(dataset=train_dataset, shuffle=True, **_loader_kwargs)
dev_loader = DataLoader(dataset=dev_dataset, **_loader_kwargs)

In [None]:
# APL is sized from the vocabulary. CTC uses id len(dict_vocab) as blank, so the
# model presumably emits len(dict_vocab) + 1 classes — confirm in model.py.
n_vocab = len(dict_vocab)
model = APL(n_vocab).to(device)

In [None]:
# The CTC blank is the id just past the last vocabulary entry (also the padding id in collate_fn).
blank_id = len(dict_vocab)
ctc_loss = nn.CTCLoss(blank=blank_id)

optimizer = optim.AdamW(model.parameters(), lr=1.5e-5)
# Multiplies the learning rate by 0.97 on each scheduler.step() (once per epoch in the loop below).
scheduler = ExponentialLR(optimizer, gamma=0.97)

In [None]:
# Decoder alphabet: output index i -> list_vocab[i]. The trailing space on every
# label makes the decoded string come out as space-separated tokens.
list_vocab = [f'{w} ' for w in sorted(dict_vocab, key=lambda w: dict_vocab[w])]

# CTC decoding is slow, so dev WER is only computed from this epoch on.
WER_START_EPOCH = 8

# Built once: the alphabet is the same for every batch and epoch.
decoder_ctc = build_ctcdecoder(labels=list_vocab)


def _ctc_forward(batch):
    """Run `model` on one collated batch and return (log_probs, ctc loss, transcripts).

    The model output is presumably batch-first (N, T, C); after transpose(0, 1)
    it is time-major (T, N, C), as nn.CTCLoss expects.
    """
    acoustic, phonetic, linguistic, labels, target_lengths, transcripts = batch
    outputs = model(acoustic, phonetic, linguistic).transpose(0, 1)
    # Every utterance is treated as spanning the full (padded) output length.
    input_lengths = torch.full(size=(outputs.shape[1],), fill_value=outputs.shape[0],
                               dtype=torch.long, device=device)
    log_probs = F.log_softmax(outputs, dim=2)
    loss = ctc_loss(log_probs, labels, input_lengths, target_lengths)
    return log_probs, loss, transcripts


def _mean(values):
    """Arithmetic mean of a list; NaN for an empty list instead of ZeroDivisionError."""
    return sum(values) / len(values) if values else float('nan')


print("Start Training")
# `with` guarantees the per-batch loss log is closed even if training raises.
with open('/kaggle/working/loss_file.txt', 'w') as loss_file:
    for epoch in range(num_epoch):
        # ---- train ----
        model.train()
        running_loss = []
        print(f'EPOCH {epoch}:')
        t1 = time.time()
        for i, data in enumerate(train_loader):
            optimizer.zero_grad()
            _, loss, _ = _ctc_forward(data)
            running_loss.append(loss.item())
            loss_file.write(f"Epoch {epoch}, train {i} loss: {loss.item()}\n")
            loss.backward()
            optimizer.step()

        scheduler.step()
        print(f"Training loss: {_mean(running_loss)}")

        # ---- evaluate ----
        running_loss = []
        worderrorrate = []
        compute_wer = epoch >= WER_START_EPOCH
        with torch.no_grad():
            model.eval()
            for i, data in enumerate(dev_loader):
                log_probs, loss, transcripts = _ctc_forward(data)
                running_loss.append(loss.item())
                loss_file.write(f"Epoch {epoch}, test {i} loss: {loss.item()}\n")

                if compute_wer:
                    # Back to batch-first so each row is one utterance's (T, C) log-probs.
                    batch_log_probs = log_probs.transpose(0, 1).cpu().numpy()
                    for ground_truth, utt_log_probs in zip(transcripts, batch_log_probs):
                        hypothesis = str(decoder_ctc.decode(utt_log_probs)).strip()
                        worderrorrate.append(wer(ground_truth, hypothesis))

        if compute_wer:
            # Mean of per-utterance WERs (not a corpus-level WER).
            epoch_wer = _mean(worderrorrate)

            with open('/kaggle/working/wer_file.txt', 'a') as wer_file:
                wer_file.write(f"Epoch {epoch}: {epoch_wer}\n")

            if epoch_wer < min_wer:
                min_wer = epoch_wer
                # NOTE(review): pickles the whole module; saving model.state_dict() would be
                # more portable, but it would change the checkpoint format that loaders expect.
                torch.save(model, '/kaggle/working/checkpoint_BaseMHA_Linguistic50.pth')
            print("wer checkpoint " + str(epoch) + ": " + str(epoch_wer))
            print("min_wer: " + str(min_wer))

        print(f"Eval loss: {_mean(running_loss)}")
        print(f"FINISH EPOCH {epoch} IN: {time.time() - t1}")
        # Persist progress each epoch so the log survives a kernel crash.
        loss_file.flush()

print('Finished Training')