<a href="https://colab.research.google.com/github/nicolas-dufour/self-unsupervised-low-res-speech/blob/master/ASR_project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ASR_project

## Install & setup


In [None]:
# Load Git folder
import os
import subprocess
import urllib.parse
from getpass import getpass

repo_user = 'nicolas-dufour'
user = input('Github Username: ')
password = getpass('Password: ')
repo_name = 'self-unsupervised-low-res-speech'
# Both credentials are percent-encoded (including '/') so that characters such
# as '@', ':' or '/' cannot corrupt the user-info part of the URL.
clone_url = 'https://{0}:{1}@github.com/{2}/{3}.git'.format(
    urllib.parse.quote(user, safe=''),
    urllib.parse.quote(password, safe=''),
    repo_user,
    repo_name,
)
# An argument list is used instead of a shell command string, so the
# credentials are never parsed by a shell.
# NOTE(review): git keeps this URL, credentials included, as the remote in
# .git/config; the final `git push` cell appears to rely on that.
subprocess.run(['git', 'clone', clone_url], check=False)
clone_url, password = "", "" # removing the password from the variables
# Bad password fails silently so make sure the repo was copied
assert os.path.exists(f"/content/{repo_name}"), "Incorrect Password or Repo Not Found, please try again"

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

In [None]:
%%capture
!pip install transformers
!pip install datasets
!sudo apt-get install festival espeak-ng mbrola
!pip install torchaudio
!pip install phonemizer
!pip install pytorch_lightning
!pip install wandb

In [None]:
# Work from the root of the cloned repository (presumably so the project
# modules imported below, e.g. `phonemize` and `dataloader`, can be resolved).
%cd /content/self-unsupervised-low-res-speech/

In [None]:
# List the contents of the current working directory.
ls

In [None]:
%load_ext autoreload
%autoreload 2
import urllib
from phonemize import phonemize_labels
from transformers import Wav2Vec2Tokenizer, Wav2Vec2ForCTC , Wav2Vec2FeatureExtractor


from datasets import load_dataset
import pytorch_lightning as pl
import torch
import numpy as np
import torchaudio
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import Audio

from dataloader import CommonVoiceDataModule
from metrics import PER

## Create the data_module (instance of CommonVoiceDataModule)

### Get the dataset download URL from https://commonvoice.mozilla.org/fr/datasets

In [None]:
# Prompt for the dataset URL; it is passed as the first argument to
# CommonVoiceDataModule in the next code cell.
url = input('Url:')

### Choose a language from https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md

In [None]:
# Build the data module from the URL entered above, using language code 'el',
# phoneme labels and batches of 4, then run its two preparation stages.
data_module = CommonVoiceDataModule(
    url,
    'el',
    batch_size=4,
    label_type='phonemes',
    phonemize=True,
    labels_folder=None,
)
data_module.prepare_data()
data_module.setup()

### Data visualization

In [None]:
# Training dataloader; the next cell draws one batch from it for inspection.
train_loader=data_module.train_dataloader()

In [None]:
# Draw one batch and inspect its first example: waveform plot, raw label ids,
# decoded label text, and an audio player.
sounds, tokens = next(iter(train_loader))
token = tokens[0]
sound = np.array(sounds[0])   # first item of the batch, as a NumPy array
freq = 16000   # Hz, playback rate given to the audio widget
print(f"raw label: {np.array(token)}")
print(f"phonetic label: {data_module.tokenizer.decode(np.array(token))}")
plt.plot(sound)
Audio(sound, rate=freq)

## Useful functions for CTC loss

In [None]:
def len_phoneme(phonms):
  """Count the non-zero labels in each row of a padded label batch.

  Args:
    phonms: 2-D tensor of shape (batch size, max target length) in which
      the value 0 marks padding.

  Returns:
    1-D CPU tensor of dtype ``torch.long`` with one entry per row, equal to
    the number of entries in that row that are different from 0.

  Raises:
    ValueError: if ``phonms`` is not 2-D.
  """
  if phonms.dim() != 2:
    raise ValueError("phonms must be 2-D (batch size, max target length), got shape {}".format(tuple(phonms.shape)))
  # A single vectorized reduction replaces the per-row Python-level sum.
  # The result is moved to the CPU, where the original implementation built it.
  return (phonms != 0).sum(dim=1).cpu()

def len_logits(logts, expected_vocab_size=48):
  """Return the input length of every sequence in a batch of network outputs.

  Every sequence is reported with the full time dimension as its length.

  Args:
    logts: 3-D tensor of shape (max input length, batch size, vocab size).
    expected_vocab_size: size the last dimension must have; pass ``None``
      to skip the check. Defaults to 48.

  Returns:
    1-D ``torch.long`` tensor of length ``batch size`` filled with
    ``max input length``.

  Raises:
    ValueError: if the last dimension differs from ``expected_vocab_size``
      (``ValueError`` is a subclass of ``Exception``, so existing handlers
      still catch it).
  """
  max_length,bs,vocab_size=logts.shape
  if expected_vocab_size is not None and vocab_size!=expected_vocab_size:
    raise ValueError("Vocab size is not consistent: expected {}, got {}".format(expected_vocab_size, vocab_size))

  return torch.full(size=(bs,), fill_value=max_length, dtype=torch.long)

def recover_tokens(output_tokens, blank=0):
    """Collapse frame-level predictions into label sequences.

    For every sequence, consecutive repetitions of the same token are merged
    into one, then all occurrences of ``blank`` are dropped.

    Args:
        output_tokens: iterable of sequences of token ids, e.g. a 2-D tensor
            of shape (batch size, time) or a list of lists of ints.
        blank: id removed after merging repetitions. Defaults to 0.

    Returns:
        A list with one list of Python ints per input sequence. An empty
        sequence yields an empty list.
    """
    recovered_tokens = []
    for sequence in output_tokens:
        # Convert once to plain Python ints instead of calling .item() per element.
        ids = sequence.tolist() if hasattr(sequence, "tolist") else list(sequence)
        # Keep a token only when it differs from the one right before it.
        merged = [tok for pos, tok in enumerate(ids) if pos == 0 or tok != ids[pos - 1]]
        recovered_tokens.append([tok for tok in merged if tok != blank])
    return recovered_tokens

## Construction of the CTC network with Wav2Vec2

In [None]:
#with Pytorch Lightning | complete version
class CTCNetwork(pl.LightningModule):
    """Pretrained Wav2Vec2 model with a new linear output head, trained with CTC loss.

    All pretrained weights are frozen at construction time; only the newly
    created output head is trainable until `relaxation` unfreezes more layers.
    Validation and test quality are tracked with the project's `PER` metric.
    """

    def __init__(self):
        super().__init__()

        self.phonemeSizeAlphabet=48   #the size of the phonetic alphabet being 48
        # Index 0 is the CTC blank; infinite losses are replaced by zero.
        self.criterion = torch.nn.CTCLoss(blank=0, reduction='mean', zero_infinity=True)

        self.feature_extractor = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
        for param in self.feature_extractor.parameters():
            param.requires_grad = False
        # Replace the pretrained head with one sized for the phoneme alphabet.
        # The input width is read from the head being replaced rather than hard-coded.
        head_in_features = self.feature_extractor.lm_head.in_features
        self.feature_extractor.lm_head=torch.nn.Linear(in_features=head_in_features, out_features=self.phonemeSizeAlphabet, bias=True)

        self.val_per = PER()
        self.test_per = PER()

    def forward(self, x_audio):
        """Return log-probabilities with the first two axes of the model logits swapped.

        The model logits are permuted with (1, 0, 2) and passed through a
        log-softmax over the last axis, which is the layout consumed by the
        CTC criterion below.
        """
        x_logits = self.feature_extractor(x_audio).logits.permute(1,0,2)
        log_prob = torch.nn.functional.log_softmax(x_logits, dim=2) #logarithmized probabilities of the outputs
        return log_prob

    def relaxation(self,type_relax):
        """Unfreeze parameters.

        Args:
            type_relax: "soft" unfreezes parameters whose names start with the
                prefixes of encoder layers 10 and 11 and of the output head;
                "hard" unfreezes every parameter. Any other value leaves the
                model unchanged.
        """
        if type_relax=="soft":
            trainable_prefixes = (
                'feature_extractor.wav2vec2.encoder.layers.11',
                'feature_extractor.wav2vec2.encoder.layers.10',
                'feature_extractor.lm_head',
            )
            for name,param in self.named_parameters():
                if name.startswith(trainable_prefixes):
                    param.requires_grad = True
        elif type_relax=="hard":
            for param in self.parameters():
                param.requires_grad = True

    def _ctc_loss(self, log_prob, phonemes):
        """CTC loss of `log_prob` against the padded targets `phonemes`."""
        return self.criterion(log_prob,phonemes,len_logits(log_prob),len_phoneme(phonemes))

    @staticmethod
    def _greedy_decode(log_prob):
        """Take the most likely token per frame and collapse it into label sequences."""
        return recover_tokens(log_prob.argmax(dim=2).permute(1,0))

    def training_step(self, batch, batch_nb):
        """Compute and log the CTC loss for one training batch."""
        # REQUIRED
        x_audio, phonemes = batch
        log_prob = self(x_audio)
        loss = self._ctc_loss(log_prob, phonemes)
        self.log('train_loss', loss, on_epoch=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_nb):
        """Log the CTC loss and update the validation PER metric for one batch."""
        # REQUIRED
        x_audio, phonemes = batch
        log_prob = self(x_audio)
        loss = self._ctc_loss(log_prob, phonemes)
        decoded_tokens = self._greedy_decode(log_prob)
        self.log('val_loss', loss, on_epoch=True, on_step=False)
        self.val_per(decoded_tokens, phonemes)
        return loss

    def validation_epoch_end(self, losses):
        """Log the accumulated validation PER and reset the metric."""
        self.log('val_per',self.val_per.compute())
        self.val_per.reset()

    def test_step(self, batch, batch_nb):
        """Update the test PER metric for one batch."""
        # REQUIRED
        x_audio, phonemes = batch
        log_prob = self(x_audio)
        decoded_tokens = self._greedy_decode(log_prob)
        self.test_per(decoded_tokens, phonemes)

    def test_epoch_end(self, losses):
        """Log the accumulated test PER and reset the metric."""
        self.log('test_per',self.test_per.compute())
        self.test_per.reset()

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with learning rate 2e-3."""
        # REQUIRED
        # can return multiple optimizers and learning_rate schedulers
        # (LBFGS it is automatically supported, no need for closure function)
        return torch.optim.Adam(self.parameters(), lr=2e-3)

In [None]:
# Build the network, then unfreeze the parameters selected by the "soft" mode
# of CTCNetwork.relaxation (encoder layers 10 and 11 plus the output head).
model = CTCNetwork()
model.relaxation("soft")

In [None]:
wandb_logger = pl.loggers.WandbLogger(project='ASR Project')
# Keep the checkpoint with the lowest 'val_per', the value logged by
# CTCNetwork.validation_epoch_end.
# NOTE(review): dirpath points into Google Drive, but the cell that mounts the
# drive is commented out near the top of the notebook - confirm it is mounted.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    mode='min',
    monitor='val_per',
    dirpath='/content/drive/MyDrive/self-supervised-speech/models',
    # The template now references the monitored metric; the previous
    # 'val_f1_score' key is never logged by the model.
    filename='asr_model_wav2vec_fr-{epoch:02d}-{val_per:.2f}'
)
trainer = pl.Trainer(
    gpus = 1,
    progress_bar_refresh_rate =20,
    logger = wandb_logger,
    callbacks=[checkpoint_callback])

In [None]:
#trainer.tune(model)
# Train the model on the data module with the trainer configured above.
trainer.fit(model, data_module) 

## Look at what the model is producing

In [None]:
# Run the model on one training batch and compare the first example's
# ground-truth label with the per-frame argmax of the model output.
audio,label=next(iter(data_module.train_dataloader()))
model.eval()
# Inference only: no_grad avoids building an autograd graph for this forward pass.
with torch.no_grad():
    pred=model(audio)
print('pred shape: {}'.format(pred.shape))
print('Groundtruth Phonemes:  {}'.format(data_module.tokenizer.decode(np.array(label[0]))))
print('Phonemes produce from the model:{}'.format(data_module.tokenizer.decode(np.array(pred.argmax(dim=2)[:,0]))))
Audio(np.array(audio[0]), rate=freq)


In [None]:
# Input lengths computed by len_logits for the prediction above.
len_logits(pred)

In [None]:
# Target lengths computed by len_phoneme for the labels of the same batch.
len_phoneme(label)

## Debugging

In [None]:
#definition of CTCloss
ctc_loss=torch.nn.CTCLoss(zero_infinity=True)

#arguments
logits=torch.ones((10,4,48)).log_softmax(2)   # logits:Log_probs      : Tensor of size (max input length  ,  batch size  ,  number of classes)
# Targets must be an integer tensor (a float tensor is rejected by CTCLoss).
# Each row holds the distinct labels 1..7: a target made of 7 identical labels
# would need blanks between the repeats and could not be aligned to only
# 10 input frames, so its infinite loss would be zeroed by zero_infinity.
phoneme=torch.arange(1,8).repeat(4,1)         # phoneme : Targets        : Tensor of size (    batch size     ,          max target length  )
lenlog=len_logits(logits)                     # lenlog: Input_lengths  : Tensor of size               (batch size )    --> indicates the input length of each sequence of the batch
lenpho= len_phoneme(phoneme)                  # Target_lengths :Tensor of size                (batch size )    --> indicates the target length of each sequence of the batch

#compute the loss
lossDebug=ctc_loss(logits,phoneme,lenlog,lenpho)
print(lossDebug)

## Git push

In [None]:
# Git Ignore setup
!echo 'lightning_logs' >> .gitignore
!echo 'wandb' >> .gitignore

In [None]:
# Save to git
# Set the commit identity globally for this machine, stage every change in the
# working tree, commit with a fixed message and push to the default remote.
# NOTE(review): no credentials are given here; the push presumably relies on
# the credentials embedded in the clone URL of the first cell - confirm.
!git config --global user.email "nicolas.dufourn@gmail.com"
!git config --global user.name "Nicolas DUFOUR"
!git add --all
!git commit -m "Added logging and checkpointing"
!git push 