# Installing and Downloading

In [1]:
# !apt update
# !apt -y install ffmpeg
# !apt install espeak -y
!pip install wandb
!pip install transformers datasets phonemizer
!pip install pydub
#!pip install transformers --upgrade
#!pip install torchaudio
!pip install tqdm --upgrade
#!pip install torchaudio --upgrade
!pip install gdown
!pip install abydos

Collecting phonemizer
  Downloading phonemizer-3.2.1-py3-none-any.whl (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.6/90.6 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Collecting segments
  Downloading segments-2.2.1-py2.py3-none-any.whl (15 kB)
Collecting dlinfo
  Downloading dlinfo-1.2.1-py3-none-any.whl (3.6 kB)
Collecting clldutils>=1.7.3
  Downloading clldutils-3.12.0-py2.py3-none-any.whl (197 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m197.6/197.6 kB[0m [31m8.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting csvw>=1.5.6
  Downloading csvw-3.1.1-py2.py3-none-any.whl (56 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 kB[0m [31m5.2 MB/s[0m eta [36m0:00:00[0m
Collecting rfc3986<2
  Downloading rfc3986-1.5.0-py2.py3-none-any.whl (31 kB)
Collecting isodate
  Downloading isodate-0.6.1-py2.py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.7/41.7 kB[0m [31m

In [2]:
%%capture
!gdown https://drive.google.com/uc?id=16EnTGOzIgSJT69pDAruvbkB35NBJ5R0q
!unzip "base_audio (1).zip" -d base_audio
!gdown https://drive.google.com/uc?id=1dAJZyLpXHS2y-WCCMXw8m9_j0t1j70Cj
!unzip ai4talk_tokenizer.zip

# Imports

In [57]:
import pandas as pd
import numpy as np
import os
import re
from tqdm.auto import tqdm
import torch
import torchaudio
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from abydos import distance
import torch.nn as nn
from transformers import AutoModelForCTC, Wav2Vec2Processor, Wav2Vec2FeatureExtractor, Wav2Vec2Tokenizer, Wav2Vec2PhonemeCTCTokenizer, Wav2Vec2Model
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import json
import wandb
wandb.login(key='')
wandb.init(project="ASR", entity="ai_4_talk")

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [4]:
class CFG:
    device = 'cuda'
    train_batch_size = 2
    valid_batch_size = 2
    alpha_ctc = 1
    num_epoch = 5
    lr = 5e-5
    wd = 2e-6
    eta_min = 0
    tokenizer_path = 'ai4talk_tokenizer/'
    checkpoint_path = 'checkpoints/'

# Data

In [5]:
df = pd.read_csv('base_audio/ars.csv').rename(columns={'new_path': 'audio_path'})
df = df.drop_duplicates(subset=['transcription', 'start', 'end', 'fpath']).reset_index(drop=True)

In [6]:
df['length_audio'] = df.audio_path.apply(lambda x: torchaudio.load(x)[0].shape[-1])
df['length_text'] = df.transcription.apply(lambda x: len(x))

In [7]:
df = df[(df.length_audio < 100000) & (df.length_audio > 400)].reset_index(drop=True)

In [8]:
def create_folds(data, target, num_splits=5):
    if num_splits > 1:
        data.loc[:,'kfold'] = -1
        X = data.drop(columns=[target])
        y = data[target]
        mskf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)
        for fold, (trn_, val_) in enumerate(mskf.split(X, y)):
            data.loc[val_,'kfold'] = fold
    else:
        data.loc[:,'kfold'] = 0
    return data

In [9]:
df = create_folds(df, 'length_text')



In [29]:
train_df, valid_df = df[df.kfold != 0].sort_values(['length_audio'], ignore_index=True), df[df.kfold == 0].sort_values(['length_audio'], ignore_index=True)

# Dataset

In [12]:
class Tokenizer:
    def __init__(self, vocab_path):
        self.vocab = json.load(open(vocab_path))
        self.token_id = sorted(list(self.vocab.items()), key=lambda x: -len(x[0]))
        self.id_token = {id_: token for token, id_ in self.vocab.items()}
    
    def __call__(self, sentence):
        for token, id_ in self.token_id:
            sentence = sentence.replace(token, f' {id_} ')
        return torch.LongTensor(list(map(int, sentence.split())))
    
    def decode(self, sequence):
        result = ''
        for id_ in sequence:
            if self.id_token[id_] != '<pad>':
                result += self.id_token[id_]
        return result
    
    def batch_decode(self, batch):
        result = []
        for sequence in batch:
            result.append(self.decode(sequence))
        return result
    
    def pad(self, batch):
        return pad_sequence(batch, batch_first=True, padding_value=self.vocab['<pad>'])
    
    @property
    def vocab_size(self):
        return len(self.vocab)

In [13]:
class TalkDataset(Dataset):
    def __init__(self, df, feature_extractor, tokenizer, augmentations=None):
        super().__init__()
        self.df = df
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.augmentations = augmentations
    
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        path = self.df.iloc[idx]['audio_path']
        text = self.df.iloc[idx]['transcription']
        waveform, sample_rate = torchaudio.load(path)
        if self.augmentations:
            waveform = augmentations(waveform)
        sample = {}
        sample['input_values'] = self.feature_extractor(waveform[0], sampling_rate=sample_rate, return_tensors="pt",)['input_values'][0]
        sample['input_ids'] = self.tokenizer(text)
        sample['text'] = text.lower()
        return sample

In [14]:
class WaveTextCollator:
    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        
    def __call__(self, batch):
        input_values = [{'input_values': feature['input_values']} for feature in batch]
        input_values = self.feature_extractor.pad(
            input_values,
            padding=True,
            return_tensors="pt",
        )
        input_ids = [feature['input_ids'] for feature in batch]
        text_lens = torch.LongTensor([text.size(0) for text in input_ids])
        input_ids = self.tokenizer.pad(
            input_ids,
        )
        texts = [feature['text'] for feature in batch]
        
        return input_values, texts, input_ids, text_lens

In [15]:
class AverageMeter:
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [16]:
class CTCrossEntropyLoss:
    def __init__(self, alpha):
        self.ctc_criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)
        self.ce_criterion = nn.CrossEntropyLoss(ignore_index=0)
        self.alpha = alpha
        
    def __call__(self, output, enc_pad_texts, output_lenghts, text_lens):
        log_probs = nn.functional.log_softmax(output, dim=-1, dtype=torch.float32).transpose(0, 1)
        ctc_loss = self.ctc_criterion(log_probs, enc_pad_texts, output_lenghts, text_lens)
        
        #ce_loss = self.ce_criterion(output.view(output.size(0), -1), enc_pad_texts.view(-1))
        return self.alpha * ctc_loss

In [17]:
def metric(y_true, y_pred):
    phonetic = distance.PhoneticEditDistance()
    return sum(phonetic.dist(t, p) for t, p in zip(y_true, y_pred)) / len(y_true)

In [30]:
def train_epoch(model, loader, criterion, optimizer, scheduler):
    loss_avg = AverageMeter()
    model.train()
    for input_values, texts, input_ids, text_lens in tqdm(loader):
        optimizer.zero_grad()
        output = model(input_values['input_values'].to(CFG.device), input_values['attention_mask'].to(CFG.device))
        output_lenghts = torch.full(
            size=(output.size(0),),
            fill_value=output.size(1),
            dtype=torch.long
        )
        loss = criterion(output, input_ids.to(CFG.device), output_lenghts, text_lens)
        loss_avg.update(loss.item(), len(texts))
        loss.backward()
        optimizer.step()
        scheduler.step()
    return loss_avg.avg

In [53]:
def valid_epoch(model, loader, criterion):
    model.eval()
    loss_avg = AverageMeter()
    metric_avg = AverageMeter()
    with torch.no_grad():
        for input_values, texts, input_ids, text_lens in tqdm(loader):
            output = model(input_values['input_values'].to(CFG.device), input_values['attention_mask'].to(CFG.device))
            output_lenghts = torch.full(
                size=(output.size(0),),
                fill_value=output.size(1),
                dtype=torch.long
            )
            loss = criterion(output, input_ids.to(CFG.device), output_lenghts, text_lens, )
            loss_avg.update(loss.item(), len(texts))
            pred_ids = torch.argmax(output.detach().cpu(), dim=-1).numpy()
            pred_str = tokenizer.batch_decode(pred_ids)
            try:
                metric_avg.update(metric(texts, pred_str), len(texts))
            except:
                pass
        print(texts, pred_str)
    return loss_avg.avg, metric_avg.avg

In [32]:
def fit(model, train_loader, valid_loader, criterion, optimizer, scheduler, num_epoch, checkpoint_path):
    for epoch in tqdm(range(num_epoch)):
        train_loss = train_epoch(model, train_loader, criterion, optimizer, scheduler)
        valid_loss, valid_ped = valid_epoch(model, valid_loader, criterion)
        wandb.log({'Epoch': epoch+1, 'Train loss': train_loss, 'Valid loss': valid_loss, 'Valid PED': valid_ped, 'LR': scheduler.get_last_lr()[0]})
        torch.save({'model': model.state_dict(), checkpoint_path + f'epoch{epoch+1}_ped{valid_ped}.pt'})

In [33]:
class Wav2VecCTC(nn.Module):
    def __init__(self, vocab_size, dropout=0.0):
        super().__init__()
        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
        self.encoder.config.mask_time_length = 1
        self.dropout = nn.Dropout(dropout)
        output_hidden_size = self.encoder.config.hidden_size
        self.lm_head = nn.Linear(output_hidden_size, vocab_size)
        
    def forward(self, input_values, attention_mask):
        outputs = self.encoder(
            input_values,
            attention_mask=attention_mask,
        )

        hidden_states = outputs[0]
        hidden_states = self.dropout(hidden_states)

        logits = self.lm_head(hidden_states)
        return logits

In [34]:
tokenizer = Tokenizer(CFG.tokenizer_path + 'vocab.json')
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-xlsr-53-espeak-cv-ft")
model = Wav2VecCTC(tokenizer.vocab_size).to(CFG.device)

Some weights of the model checkpoint at facebook/wav2vec2-xlsr-53-espeak-cv-ft were not used when initializing Wav2Vec2Model: ['lm_head.bias', 'lm_head.weight']
- This IS expected if you are initializing Wav2Vec2Model from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing Wav2Vec2Model from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [35]:
train_dataset = TalkDataset(train_df, feature_extractor, tokenizer)
valid_dataset = TalkDataset(valid_df, feature_extractor, tokenizer)
collator = WaveTextCollator(feature_extractor, tokenizer)
train_loader = DataLoader(train_dataset, batch_size=CFG.train_batch_size, shuffle=False, collate_fn=collator, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=CFG.valid_batch_size, shuffle=False, collate_fn=collator, pin_memory=True)
criterion = CTCrossEntropyLoss(CFG.alpha_ctc)
optimizer = AdamW(model.parameters(), lr=CFG.lr, weight_decay=CFG.wd)
scheduler = CosineAnnealingLR(optimizer, T_max=len(train_loader) * CFG.num_epoch, eta_min=CFG.eta_min)

In [36]:
fit(model, train_loader, valid_loader, criterion, optimizer, scheduler, CFG.num_epoch, CFG.checkpoint_path)
wandb.finish()

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/22812 [00:00<?, ?it/s]

  0%|          | 0/5704 [00:00<?, ?it/s]

IndexError: list index out of range

In [52]:
phonetic = distance.PhoneticEditDistance()
phonetic.dist('pnt͡ʃ', 'ant͡ʃ')

IndexError: list index out of range

In [3]:
'ant͡ʃ'[3]

'͡'