In [1]:
import sys
sys.path.append('/home/ubuntu/asr/speechbrain')

In [2]:
import logging
import os
import sys
from pathlib import Path
import librosa
from taylor_series_linear_attention import TaylorSeriesLinearAttn

import torch
from torch import Tensor
from hyperpyyaml import load_hyperpyyaml
from inspect import signature 

import speechbrain as sb
from speechbrain.utils.distributed import if_main_process, run_on_main

logger = logging.getLogger(__name__)
from speechbrain.inference.ASR import EncoderDecoderASR

### Pretrained Conformer

In [3]:
# Baseline model: the published Conformer + TransformerLM LibriSpeech recipe,
# fetched from `source` and cached under `savedir`.
pretrained_conformer = EncoderDecoderASR.from_hparams(
    source="speechbrain/asr-conformer-transformerlm-librispeech",
    savedir="pretrained_models/asr-transformer-transformerlm-librispeech",
)



In [4]:
# LinearAttn: a wrapper that exposes TaylorSeriesLinearAttn through the
# argument list used by the attention layers it replaces.
class LinearAttn(torch.nn.Module):
    """Wrapper around TaylorSeriesLinearAttn.

    ``forward`` accepts ``(query, key, value, attn_mask, key_padding_mask,
    pos_embs)`` but only ``query`` and ``attn_mask`` are forwarded to the
    wrapped module; the remaining arguments are accepted and ignored.
    """

    def __init__(self, dim = 512, dim_head = 16, heads = 8):
        super().__init__()
        self.attn = TaylorSeriesLinearAttn(dim = dim, dim_head = dim_head, heads = heads)

    def forward(self, query, key, value, attn_mask, key_padding_mask, pos_embs):
        """Apply the wrapped attention to ``query``.

        Returns a 2-tuple: the attention output and the wrapped
        ``self.attn`` module itself (not a tensor of attention weights).
        """
        # key, value, key_padding_mask and pos_embs are deliberately unused.
        # NOTE(review): attn_mask is handed over as `mask`; confirm that the
        # mask shape/polarity expected by TaylorSeriesLinearAttn matches what
        # the caller supplies as attn_mask.
        attended = self.attn(query, mask=attn_mask)
        return attended, self.attn

In [5]:
# Path to a single .flac utterance under the LibriSpeech test-clean directory.
# NOTE(review): absolute, machine-specific path; no later cell visible here reads this variable.
audio_file = "/home/ubuntu/asr/datasets/LibriSpeech/test-clean/61/70970/61-70970-0000.flac"

In [6]:
# Define training procedure
class ASR(sb.core.Brain):
    """Brain subclass for a joint CTC + seq2seq ASR model.

    Every model component, loss, searcher, metric factory and scheduler is
    read from ``self.modules`` / ``self.hparams``; this class only wires them
    together for the forward pass, the loss and the per-stage bookkeeping.
    """

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches to the output probabilities.

        Returns ``(p_ctc, p_seq, wav_lens, hyps)``. ``hyps`` stays ``None``
        unless a search is run: always at TEST, and at VALID only on epochs
        that are a multiple of ``hparams.valid_search_interval``.
        """
        batch = batch.to(self.device)
        wavs, wav_lens = batch.sig
        tokens_bos, _ = batch.tokens_bos

        # compute features
        feats = self.hparams.compute_features(wavs)
        current_epoch = self.hparams.epoch_counter.current
        feats = self.modules.normalize(feats, wav_lens, epoch=current_epoch)

        # Add feature augmentation if specified.
        if stage == sb.Stage.TRAIN and hasattr(self.hparams, "fea_augment"):
            # The augmented lengths returned here are not used afterwards.
            feats, _ = self.hparams.fea_augment(feats, wav_lens)
            tokens_bos = self.hparams.fea_augment.replicate_labels(tokens_bos)

        # forward modules
        src = self.modules.CNN(feats)

        enc_out, pred = self.modules.Transformer(
            src, tokens_bos, wav_lens, pad_idx=self.hparams.pad_index
        )

        # output layer for ctc log-probabilities
        logits = self.modules.ctc_lin(enc_out)
        p_ctc = self.hparams.log_softmax(logits)

        # output layer for seq2seq log-probabilities
        pred = self.modules.seq_lin(pred)
        p_seq = self.hparams.log_softmax(pred)

        # Compute outputs
        hyps = None
        is_valid_search = (
            stage == sb.Stage.VALID
            and current_epoch % self.hparams.valid_search_interval == 0
        )
        is_test_search = stage == sb.Stage.TEST

        if is_valid_search or is_test_search:
            # Note: For valid_search, for the sake of efficiency, we only perform beamsearch with
            # limited capacity and no LM to give user some idea of how the AM is doing

            # Decide searcher for inference: valid or test search
            if stage == sb.Stage.VALID:
                hyps, _, _, _ = self.hparams.valid_search(
                    enc_out.detach(), wav_lens
                )
            else:
                hyps, _, _, _ = self.hparams.test_search(
                    enc_out.detach(), wav_lens
                )

        return p_ctc, p_seq, wav_lens, hyps

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss (CTC+NLL) given predictions and targets.

        Outside of TRAIN the accuracy metric is always updated; the WER
        metric is updated only when hypotheses were decoded (same condition
        as the search in ``compute_forward``).
        """

        (p_ctc, p_seq, wav_lens, hyps) = predictions

        ids = batch.id
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        tokens, tokens_lens = batch.tokens

        if stage == sb.Stage.TRAIN:
            # Labels must be extended if parallel augmentation or concatenated
            # augmentation was performed on the input (increasing the time dimension)
            if hasattr(self.hparams, "fea_augment"):
                (
                    tokens,
                    tokens_lens,
                    tokens_eos,
                    tokens_eos_lens,
                ) = self.hparams.fea_augment.replicate_multiple_labels(
                    tokens, tokens_lens, tokens_eos, tokens_eos_lens
                )

        loss_seq = self.hparams.seq_cost(
            p_seq, tokens_eos, length=tokens_eos_lens
        ).sum()

        loss_ctc = self.hparams.ctc_cost(
            p_ctc, tokens, wav_lens, tokens_lens
        ).sum()

        loss = (
            self.hparams.ctc_weight * loss_ctc
            + (1 - self.hparams.ctc_weight) * loss_seq
        )

        if stage != sb.Stage.TRAIN:
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if current_epoch % valid_search_interval == 0 or (
                stage == sb.Stage.TEST
            ):
                # The tokenizer is taken from the hyperparameters instead of a
                # module-level global: no global `tokenizer` is defined in this
                # file, so the previous lookup raised NameError here.
                # NOTE(review): assumes the hparams YAML declares a `tokenizer`
                # entry — confirm for custom YAML files.
                tokenizer = self.hparams.tokenizer

                # Decode token terms to words
                predicted_words = [
                    tokenizer.decode_ids(utt_seq).split(" ") for utt_seq in hyps
                ]
                target_words = [wrd.split(" ") for wrd in batch.wrd]
                self.wer_metric.append(ids, predicted_words, target_words)

            # compute the accuracy of the one-step-forward prediction
            self.acc_metric.append(p_seq, tokens_eos, tokens_eos_lens)
        return loss

    def on_evaluate_start(self, max_key=None, min_key=None):
        """Perform checkpoint averaging before evaluation.

        The checkpoints selected by ``max_key`` / ``min_key`` are averaged
        and the result is loaded into ``hparams.model``, which is then put
        in eval mode.
        """
        super().on_evaluate_start()

        ckpts = self.checkpointer.find_checkpoints(
            max_key=max_key, min_key=min_key
        )
        ckpt = sb.utils.checkpoints.average_checkpoints(
            ckpts, recoverable_name="model"
        )

        self.hparams.model.load_state_dict(ckpt, strict=True)
        self.hparams.model.eval()
        print("Loaded the average")

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch.

        Fresh accuracy and WER metric objects are created for every
        non-training stage.
        """
        if stage != sb.Stage.TRAIN:
            self.acc_metric = self.hparams.acc_computer()
            self.wer_metric = self.hparams.error_rate_computer()

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch.

        Collects the stage statistics, logs them, and saves checkpoints
        (after VALID and after TEST).
        """
        # Compute/store important stats
        stage_stats = {"loss": stage_loss}
        if stage == sb.Stage.TRAIN:
            self.train_stats = stage_stats
        else:
            stage_stats["ACC"] = self.acc_metric.summarize()
            current_epoch = self.hparams.epoch_counter.current
            valid_search_interval = self.hparams.valid_search_interval
            if (
                current_epoch % valid_search_interval == 0
                or stage == sb.Stage.TEST
            ):
                stage_stats["WER"] = self.wer_metric.summarize("error_rate")

        # log stats and save checkpoint at end-of-epoch
        if stage == sb.Stage.VALID:
            lr = self.hparams.noam_annealing.current_lr
            steps = self.optimizer_step
            optimizer = self.optimizer.__class__.__name__

            epoch_stats = {
                "epoch": epoch,
                "lr": lr,
                "steps": steps,
                "optimizer": optimizer,
            }
            self.hparams.train_logger.log_stats(
                stats_meta=epoch_stats,
                train_stats=self.train_stats,
                valid_stats=stage_stats,
            )
            self.checkpointer.save_and_keep_only(
                meta={"ACC": stage_stats["ACC"], "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=self.hparams.avg_checkpoints,
            )

        elif stage == sb.Stage.TEST:
            self.hparams.train_logger.log_stats(
                stats_meta={"Epoch loaded": self.hparams.epoch_counter.current},
                test_stats=stage_stats,
            )
            # Only the main process writes the WER report file.
            if if_main_process():
                with open(self.hparams.test_wer_file, "w") as w:
                    self.wer_metric.write_stats(w)

            # save the averaged checkpoint at the end of the evaluation stage
            # delete the rest of the intermediate checkpoints
            # ACC is set to 1.1 so checkpointer only keeps the averaged checkpoint
            self.checkpointer.save_and_keep_only(
                meta={"ACC": 1.1, "epoch": epoch},
                max_keys=["ACC"],
                num_to_keep=1,
            )

    def on_fit_batch_end(self, batch, outputs, loss, should_step):
        """At the end of the optimizer step, apply noam annealing."""
        if should_step:
            self.hparams.noam_annealing(self.optimizer)

### Finetuned Conformer

In [7]:
# Hyperparameter YAML for the Conformer recipe, with the dataset location
# overridden for this machine.
hparams_file = "/home/ubuntu/asr/hparam/conformer_large.yaml"
overrides = dict(data_folder="/home/ubuntu/asr/datasets/LibriSpeech")

with open(hparams_file) as yaml_stream:
    hparams = load_hyperpyyaml(yaml_stream, overrides)

In [8]:
# Build the training-side Brain purely from the loaded hyperparameters.
asr_brain = ASR(
    hparams=hparams,
    modules=hparams["modules"],
    opt_class=hparams["Adam"],
    checkpointer=hparams["checkpointer"],
)

In [9]:
# Restore the fine-tuned Conformer weights into the training-side model.
# map_location="cpu" lets a checkpoint written on a GPU machine be read on a
# host without CUDA; load_state_dict then copies the values into the existing
# parameters.
conformer_ckpt_path = "/home/ubuntu/asr/speechbrain/recipes/LibriSpeech/ASR/transformer/results/conformer_large/3407/save/CKPT+2024-04-19+19-19-29+00/model.ckpt"
asr_brain.hparams.model.load_state_dict(
    torch.load(conformer_ckpt_path, map_location="cpu"),
    strict=True,
)

# Inference interface that will receive the fine-tuned weights in the next cell.
finetuned_conformer = EncoderDecoderASR.from_hparams(
    source="speechbrain/asr-conformer-transformerlm-librispeech",
    savedir="pretrained_models/asr-transformer-transformerlm-librispeech",
)




In [10]:
# Copy each fine-tuned sub-module from the Brain into the inference interface
# by writing its state dict to a file and reading it back.
# Columns: source module, destination module, weights file, completion message.
conformer_weight_transfers = [
    (
        asr_brain.modules.CNN,
        finetuned_conformer.mods.asr_model[0],
        'weights_file_cnn_1',
        "Exported and loaded Weights of ConvolutionFrontEnd....",
    ),
    (
        asr_brain.modules.Transformer,
        finetuned_conformer.mods.asr_model[1],
        'weights_file_transformer_1',
        "Exported and loaded Weights of Transformer module....",
    ),
    (
        asr_brain.modules.seq_lin,
        finetuned_conformer.mods.asr_model[2],
        'weights_file_seq_lin_1',
        "Exported and loaded Weights of linear layer related to seq_lin module....",
    ),
    (
        asr_brain.modules.ctc_lin,
        finetuned_conformer.mods.asr_model[3],
        'weights_file_ctc_lin_1',
        "Exported and loaded Weights of linear layer related to ctc_lin module....",
    ),
    (
        asr_brain.modules.normalize,
        finetuned_conformer.mods.normalizer,
        'weights_file_normalizer_1',
        "Exported and loaded Weights of normalizer module....",
    ),
]

for src_module, dst_module, weights_path, done_message in conformer_weight_transfers:
    torch.save(src_module.state_dict(), weights_path)
    dst_module.load_state_dict(torch.load(weights_path))
    print(done_message)

Exported and loaded Weights of ConvolutionFrontEnd....
Exported and loaded Weights of Transformer module....
Exported and loaded Weights of linear layer related to seq_lin module....
Exported and loaded Weights of linear layer related to ctc_lin module....
Exported and loaded Weights of normalizer module....


### Finetuned TSConformer

In [11]:
# Rebuild the hyperparameters and the Brain from scratch so that the
# TSConformer variant does not share module objects with the Conformer above.
hparams_file = "/home/ubuntu/asr/hparam/conformer_large.yaml"
overrides = {"data_folder": "/home/ubuntu/asr/datasets/LibriSpeech"}

with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

asr_brain = ASR(
        modules=hparams["modules"],
        opt_class=hparams["Adam"],
        hparams=hparams,
        checkpointer=hparams["checkpointer"],
    )

finetuned_tsconformer = EncoderDecoderASR.from_hparams(source="speechbrain/asr-conformer-transformerlm-librispeech", savedir="pretrained_models/asr-transformer-transformerlm-librispeech")

# Replace the attention layer of every encoder layer in both models with the
# Taylor-series linear attention wrapper. Iterating over the layer containers
# (instead of a hard-coded count of 12) keeps this in step with whatever
# encoder depth the YAML / pretrained model actually defines.
brain_encoder_layers = asr_brain.modules.Transformer.encoder.layers
inference_encoder_layers = finetuned_tsconformer.mods.transformer.encoder.layers
for brain_layer, inference_layer in zip(brain_encoder_layers, inference_encoder_layers):
    brain_layer.mha_layer = LinearAttn(dim = 512, dim_head = 16, heads = 8)
    inference_layer.mha_layer = LinearAttn(dim = 512, dim_head = 16, heads = 8)

# Load the fine-tuned TSConformer checkpoint. strict=True makes any mismatch
# between the patched architecture and the checkpoint keys fail loudly.
# map_location="cpu" lets a checkpoint written on a GPU machine be read on a
# host without CUDA.
tsconformer_ckpt_path = "/home/ubuntu/asr/speechbrain/recipes/LibriSpeech/ASR/transformer/results/tsconformer_large/3407/save/CKPT+2024-04-24+11-40-51+00/model.ckpt"
asr_brain.hparams.model.load_state_dict(
    torch.load(tsconformer_ckpt_path, map_location="cpu"),
    strict=True,
)



<All keys matched successfully>

In [12]:
# Copy each fine-tuned sub-module from the Brain into the TSConformer
# inference interface by writing its state dict to a file and reading it back.
# Columns: source module, destination module, weights file, completion message.
tsconformer_weight_transfers = [
    (
        asr_brain.modules.CNN,
        finetuned_tsconformer.mods.asr_model[0],
        'weights_file_cnn_2',
        "Exported and loaded Weights of ConvolutionFrontEnd....",
    ),
    (
        asr_brain.modules.Transformer,
        finetuned_tsconformer.mods.asr_model[1],
        'weights_file_transformer_2',
        "Exported and loaded Weights of Transformer module....",
    ),
    (
        asr_brain.modules.seq_lin,
        finetuned_tsconformer.mods.asr_model[2],
        'weights_file_seq_lin_2',
        "Exported and loaded Weights of linear layer related to seq_lin module....",
    ),
    (
        asr_brain.modules.ctc_lin,
        finetuned_tsconformer.mods.asr_model[3],
        'weights_file_ctc_lin_2',
        "Exported and loaded Weights of linear layer related to ctc_lin module....",
    ),
    (
        asr_brain.modules.normalize,
        finetuned_tsconformer.mods.normalizer,
        'weights_file_normalizer_2',
        "Exported and loaded Weights of normalizer module....",
    ),
]

for src_module, dst_module, weights_path, done_message in tsconformer_weight_transfers:
    torch.save(src_module.state_dict(), weights_path)
    dst_module.load_state_dict(torch.load(weights_path))
    print(done_message)

Exported and loaded Weights of ConvolutionFrontEnd....
Exported and loaded Weights of Transformer module....
Exported and loaded Weights of linear layer related to seq_lin module....
Exported and loaded Weights of linear layer related to ctc_lin module....
Exported and loaded Weights of normalizer module....


In [23]:
def transcribe_audio(audio_file, cpretrained=pretrained_conformer, cfinetuned=finetuned_conformer, tscfinetuned=finetuned_tsconformer):
    """Transcribe one audio file with all three models.

    Returns a 4-tuple: the input ``audio_file`` followed by the transcripts
    from ``cpretrained``, ``cfinetuned`` and ``tscfinetuned``, in that order.
    The model defaults are bound when this function is defined.
    """
    models_in_order = (cpretrained, cfinetuned, tscfinetuned)
    transcripts = [model.transcribe_file(audio_file) for model in models_in_order]
    return (audio_file, *transcripts)

In [15]:
# Note-only cell: the bare string below is neither assigned nor passed to any
# call; it just records an example input path.
'''

Input Audio File : /home/ubuntu/asr/datasets/LibriSpeech/test-clean/908/157963/908-157963-0001.flac

'''

/home/ubuntu/asr/datasets/LibriSpeech/test-clean/61/70970/61-70970-0000.flac


"YOUNG FITZOOTH HAD BEEN COMMANDED TO HIS MOTHER'S CHAMBER SO SOON AS HE HAD COME OUT FROM HIS CONVERSE WITH THE SQUIRE"

In [None]:
import gradio as gr

# Output widgets, listed in the same order as the values returned by
# transcribe_audio: the audio file first, then one text box per model.
demo_outputs = [
    gr.Audio(label="Audio Transcript"),
    gr.Textbox(label="Pretrained Conformer", lines=3),
    gr.Textbox(label="Finetuned Conformer", lines=3),
    gr.Textbox(label="Finetuned TSConformer", lines=3),
]

# A single text input: the path of the audio file to transcribe.
demo = gr.Interface(
    fn=transcribe_audio,
    inputs=["text"],
    outputs=demo_outputs,
    title="ASR for Noisy Audio!",
)
demo.launch(share=True, debug=True)



Running on local URL:  http://127.0.0.1:7860





Could not create share link. Please check your internet connection or our status page: https://status.gradio.app.
