In [None]:
%%capture
!pip install speechbrain
import speechbrain 
# here we download the material needed for this tutorial: images and an example based on mini-librispeech
!wget https://www.dropbox.com/s/b61lo6gkpuplanq/MiniLibriSpeechTutorial.tar.gz?dl=0
!tar -xvzf MiniLibriSpeechTutorial.tar.gz?dl=0
# downloading mini_librispeech dev data
!wget https://www.openslr.org/resources/31/dev-clean-2.tar.gz
!tar -xvzf dev-clean-2.tar.gz

In [None]:
from speechbrain.lobes.features import Fbank
import torch 
import speechbrain as sb
# Define fine-tuning procedure 
class EncDecFineTune(sb.Brain):
    """Brain subclass that fine-tunes an encoder-decoder ASR model.

    The methods below read the following entries:

    * ``self.modules``: ``compute_features``, ``normalize``, ``enc``, ``emb``,
      ``dec`` and ``seq_lin``.
    * ``self.hparams``: ``log_softmax`` and ``seq_cost``.
    """

    def on_stage_start(self, stage, epoch=None):
        """Enable gradients on the modules to fine-tune when training starts.

        Arguments
        ---------
        stage : sb.Stage
            Stage being started; only ``sb.Stage.TRAIN`` triggers any work.
        epoch : int, optional
            Unused here; accepted so the hook can be called with or without it.
        """
        if stage != sb.Stage.TRAIN:
            return

        # NOTE(review): presumably needed because the pretrained modules are
        # loaded with frozen parameters -- confirm against the loader used.
        trainable = [
            self.modules.enc,
            self.modules.emb,
            self.modules.dec,
            self.modules.seq_lin,
        ]
        for module in trainable:
            for param in module.parameters():
                param.requires_grad = True

    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batch to the output log-probabilities.

        Arguments
        ---------
        batch : object
            Batch exposing ``to(device)``, ``signal`` (waveforms, lengths) and
            ``tokens_bos`` (BOS-prefixed tokens, lengths).
        stage : sb.Stage
            Unused here.

        Returns
        -------
        tuple
            ``(p_seq, wav_lens)``: the result of ``hparams.log_softmax`` applied
            to the ``seq_lin`` output, and the waveform lengths.
        """
        # Moving the batch moves the tensors it holds, so the unpacked items
        # below need no additional per-tensor transfer.
        batch = batch.to(self.device)
        wavs, wav_lens = batch.signal
        tokens_bos, _ = batch.tokens_bos

        # Acoustic features -> normalization -> encoder
        feats = self.modules.compute_features(wavs)
        feats = self.modules.normalize(feats, wav_lens)
        x = self.modules.enc(feats)

        # The decoder is fed the BOS-prefixed token sequence
        e_in = self.modules.emb(tokens_bos)
        h, _ = self.modules.dec(e_in, x, wav_lens)

        # Output layer for seq2seq log-probabilities
        logits = self.modules.seq_lin(h)
        p_seq = self.hparams.log_softmax(logits)

        return p_seq, wav_lens

    def compute_objectives(self, predictions, batch, stage):
        """Compute the sequence loss given predictions and targets.

        Only ``hparams.seq_cost`` is evaluated, against the EOS-terminated
        token sequences of the batch (no CTC term is computed here).

        Arguments
        ---------
        predictions : tuple
            ``(p_seq, wav_lens)`` as returned by ``compute_forward``.
        batch : object
            Batch exposing ``tokens_eos`` (tokens, lengths).
        stage : sb.Stage
            Unused here.

        Returns
        -------
        The value returned by ``hparams.seq_cost``.
        """
        p_seq, _ = predictions
        tokens_eos, tokens_eos_lens = batch.tokens_eos
        return self.hparams.seq_cost(p_seq, tokens_eos, tokens_eos_lens)

    def fit_batch(self, batch):
        """Train the parameters given a single batch in input.

        Returns the loss detached from the computation graph.
        """
        predictions = self.compute_forward(batch, sb.Stage.TRAIN)
        loss = self.compute_objectives(predictions, batch, sb.Stage.TRAIN)
        loss.backward()
        # The update is applied only when the gradient check passes; gradients
        # are cleared in either case so they never leak into the next batch.
        if self.check_gradients(loss):
            self.optimizer.step()
        self.optimizer.zero_grad()
        return loss.detach()


In [None]:
from speechbrain.pretrained import EncoderDecoderASR
# Fetch the pretrained ASR system; its files are stored under ./pretrained_ASR.
asr_model = EncoderDecoderASR.from_hparams(
    source="speechbrain/asr-crdnn-rnnlm-librispeech",
    savedir="./pretrained_ASR",
)

# Modules handed to the Brain; EncDecFineTune looks them up by these keys.
pretrained_encoder = asr_model.mods.encoder
modules = {
    "enc": pretrained_encoder.model,
    "emb": asr_model.hparams.emb,
    "dec": asr_model.hparams.dec,
    # we use the same features and normalization as the pretrained model
    "compute_features": pretrained_encoder.compute_features,
    "normalize": pretrained_encoder.normalize,
    "seq_lin": asr_model.hparams.seq_lin,
}

# Loss (NLL with label smoothing 0.1) and log-softmax used by EncDecFineTune.
hparams = {
    "seq_cost": lambda log_probs, targets, lengths: sb.nnet.losses.nll_loss(
        log_probs, targets, lengths, label_smoothing=0.1
    ),
    "log_softmax": sb.nnet.activations.Softmax(apply_log=True),
}

# Plain SGD with a learning rate of 1e-5.
brain = EncDecFineTune(
    modules,
    hparams=hparams,
    opt_class=lambda params: torch.optim.SGD(params, lr=1e-5),
)
brain.tokenizer = asr_model.tokenizer

Downloading:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.41k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/480M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/212M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/253k [00:00<?, ?B/s]

In [None]:
import speechbrain as sb
import torch

In [None]:
from parse_data import parse_to_json # parse_data is a local library downloaded before (see Installing Dependencies step) 
# Parse the extracted dev-clean-2 folder.
# NOTE(review): presumably this writes the `data.json` manifest that the next
# cell loads -- confirm in parse_data.py.
parse_to_json("./LibriSpeech/dev-clean-2")

In [None]:
from speechbrain.dataio.dataset import DynamicItemDataset
# Build the dataset from the "data.json" manifest in the working directory.
dataset = DynamicItemDataset.from_json("data.json")

In [None]:
# We limit the dataset to 100 utterances, sorted by the "length" key, to keep
# the training short in this Colab example.
# NOTE(review): with the default sort direction these are presumably the
# shortest utterances -- confirm.
dataset = dataset.filtered_sorted(
    sort_key="length",
    select_n=100,
)

In [None]:
# Audio pipeline: read each example's "file_path" and expose the result as "signal".
dataset.add_dynamic_item(
    sb.dataio.dataio.read_audio,
    takes="file_path",
    provides="signal",
)

In [None]:
# 3. Define text pipeline:
@sb.utils.data_pipeline.takes("words")
@sb.utils.data_pipeline.provides(
    "words", "tokens_list", "tokens_bos", "tokens_eos", "tokens"
)
def text_pipeline(words):
    """Turn a transcript into the token sequences used for fine-tuning.

    Yields, in the order declared in ``provides``: the transcript itself, the
    token ids from the pretrained model's tokenizer, the ids prefixed with the
    BOS index, the ids followed by the EOS index, and the plain ids as a tensor.
    """
    yield words

    token_ids = asr_model.tokenizer.encode_as_ids(words)
    yield token_ids

    # we use the same bos and eos indexes as in the pretrained model
    yield torch.LongTensor([asr_model.hparams.bos_index] + token_ids)
    yield torch.LongTensor(token_ids + [asr_model.hparams.eos_index])

    yield torch.LongTensor(token_ids)

In [None]:
# Register the text pipeline on the dataset; its `takes`/`provides` keys are
# declared by the decorators on `text_pipeline`.
dataset.add_dynamic_item(text_pipeline)

In [None]:
# Keys to return for every example of the dataset.
output_keys = [
    "id",
    "signal",
    "words",
    "tokens_list",
    "tokens_bos",
    "tokens_eos",
    "tokens",
]
dataset.set_output_keys(output_keys)
# Show the first example as a sanity check.
dataset[0]

{'id': '1272-141231-0013',
 'signal': tensor([ 3.0518e-05, -1.2207e-04, -1.2207e-04,  ..., -8.5449e-04,
         -7.6294e-04, -6.4087e-04]),
 'tokens': tensor([  2, 100,  59,  99, 191]),
 'tokens_bos': tensor([  0,   2, 100,  59,  99, 191]),
 'tokens_eos': tensor([  2, 100,  59,  99, 191,   0]),
 'tokens_list': [2, 100, 59, 99, 191],
 'words': 'THE TWENTIES'}

In [None]:
# Fine-tune for two epochs on the reduced dataset, in batches of 8 without
# shuffling; incomplete final batches are dropped.
# NOTE(review): `h` is kept as in the original, but confirm that `fit` returns
# something useful before relying on it.
train_loader_options = {"batch_size": 8, "drop_last": True, "shuffle": False}
h = brain.fit(
    range(2),
    train_set=dataset,
    train_loader_kwargs=train_loader_options,
)

100%|██████████| 12/12 [07:46<00:00, 38.91s/it, train_loss=1.24]
100%|██████████| 12/12 [07:45<00:00, 38.78s/it, train_loss=1.27]
