Model

In [None]:
import os
import torch
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
import torch.nn.functional as F
import torchaudio
import numpy as np


class CNNLayerNorm(nn.Module):
    """LayerNorm applied along the feature axis of a CNN activation map.

    Expects input laid out as (batch, channel, feature, time) and returns a
    tensor with the same layout.
    """

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        # nn.LayerNorm works on the trailing axis, so move `feature` last:
        # (batch, channel, feature, time) -> (batch, channel, time, feature)
        feature_last = x.permute(0, 1, 3, 2).contiguous()
        normalized = self.layer_norm(feature_last)
        # Restore the original (batch, channel, feature, time) layout.
        return normalized.permute(0, 1, 3, 2).contiguous()


class ResidualCNN(nn.Module):
    """Pre-activation residual block (see https://arxiv.org/pdf/1603.05027.pdf)
        that normalizes with layer norm rather than batch norm.

    Each of the two stages runs layer norm -> GELU -> dropout -> conv, and the
    block input is added back onto the result of the second stage.
    """

    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super().__init__()

        same_padding = kernel // 2
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=same_padding)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=same_padding)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        # x: (batch, channel, feature, time)
        shortcut = x
        stages = (
            (self.layer_norm1, self.dropout1, self.cnn1),
            (self.layer_norm2, self.dropout2, self.cnn2),
        )
        out = x
        for norm, drop, conv in stages:
            out = conv(drop(F.gelu(norm(out))))
        # Skip connection: add the untouched block input in place.
        out += shortcut
        return out  # (batch, channel, feature, time)


class BidirectionalGRU(nn.Module):
    """Single bidirectional GRU layer preceded by layer norm + GELU and
    followed by dropout.

    Because the GRU is bidirectional, the last axis of the output is
    ``2 * hidden_size`` wide.
    """

    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()

        self.BiGRU = nn.GRU(
            input_size=rnn_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=batch_first,
            bidirectional=True,
        )
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        # nn.GRU returns (output, h_n); only the per-step outputs are used.
        recurrent_out = self.BiGRU(activated)[0]
        return self.dropout(recurrent_out)


class SpeechRecognitionModel(nn.Module):
    """CNN + residual CNN + bidirectional GRU acoustic model.

    Args:
        n_cnn_layers: number of ``ResidualCNN`` blocks after the stem conv.
        n_rnn_layers: number of stacked ``BidirectionalGRU`` layers.
        rnn_dim: hidden size of each GRU direction (and of the projection
            feeding the first GRU).
        n_class: number of output classes per time step.
        n_feats: size of the feature axis of the input spectrogram.
        stride: stride of the stem convolution (applied to both the feature
            and the time axis).
        dropout: dropout probability used throughout the network.

    ``forward`` takes a (batch, 1, feature, time) tensor and returns
    un-normalized scores shaped (batch, time', n_class), where time' is the
    time axis after the strided stem convolution.
    """

    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        # Size of the feature axis after the stem conv (kernel 3, padding 1):
        # floor((n_feats + 2 * 1 - 3) / stride) + 1. Unlike `n_feats // 2 + 1`
        # this is correct for both even and odd n_feats and for any stride.
        n_feats = (n_feats - 1) // stride + 1
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)  # cnn for extracting hierarchical features

        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
        # The tensor stays (batch, time, feature) through the whole stack, so
        # every GRU layer must be batch-first; otherwise layers after the first
        # would treat the batch axis as the time axis.
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=True)
            for i in range(n_rnn_layers)
        ])
        # A bidirectional GRU emits rnn_dim * 2 features; with no GRU layers the
        # classifier sees the rnn_dim-wide projection directly.
        classifier_in = rnn_dim * 2 if n_rnn_layers > 0 else rnn_dim
        self.classifier = nn.Sequential(
            nn.Linear(classifier_in, rnn_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)  # (batch, 32, feature, time)
        x = self.rescnn_layers(x)
        # Merge channel and feature axes: (batch, 32 * feature, time)
        x = x.flatten(1, 2)
        x = x.transpose(1, 2)  # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x


Train and test

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import wandb
from jiwer import wer
import Model
import Utils
from ASR import WB


def train(model, device, batch_iterator, criterion, optimizer, scheduler, epoch):
    """Train ``model`` for one pass over ``batch_iterator``.

    Each item of ``batch_iterator`` must unpack to
    ``(spectrograms, labels, input_lengths, label_lengths)``. The scheduler is
    stepped once per batch. Every 50 batches, and on the last batch, the loss
    is printed; when ``WB`` is truthy the loss and the mean word error rate of
    that batch are also logged to wandb.

    Args:
        model: network returning scores shaped (batch, time, n_class).
        device: device the spectrograms and labels are moved to.
        batch_iterator: sized iterable of training batches.
        criterion: loss called as ``criterion(log_probs, labels, input_lengths,
            label_lengths)`` with log_probs shaped (time, batch, n_class).
        optimizer: optimizer stepped after every batch.
        scheduler: learning-rate scheduler stepped after every batch.
        epoch: epoch number, used only in the progress message.
    """
    model.train()
    data_len = len(batch_iterator)
    for batch_idx, _data in enumerate(batch_iterator):
        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()

        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)  # using log_softmax instead of softmax for numerical stability
        output = output.transpose(0, 1)  # (time, batch, n_class)

        # as_tensor accepts numpy arrays as well as plain sequences of ints.
        loss = criterion(output, labels, torch.as_tensor(input_lengths), torch.as_tensor(label_lengths))
        loss.backward()

        optimizer.step()
        scheduler.step()

        # enumerate() stops at data_len - 1, so that is the index of the last batch.
        if batch_idx % 50 == 0 or batch_idx == data_len - 1:
            batches_done = batch_idx + 1
            # Both counters are in batches so the fraction and the percentage agree.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batches_done, data_len,
                100. * batches_done / data_len, loss.item()))
            if WB:
                wandb.log({"train_loss": loss.item()})
                decoded_preds, decoded_targets = Utils.greedy_decoder(output.transpose(0, 1), labels,
                                                                      label_lengths)
                wer_sum = sum(wer(target, pred) for target, pred in zip(decoded_targets, decoded_preds))

                wandb.log({"train_wer": wer_sum / len(decoded_preds)})


def validation(model, device, val_loader, criterion, epoch):
    """Evaluate ``model`` on ``val_loader`` and report mean loss and mean WER.

    Each item of ``val_loader`` must unpack to
    ``(spectrograms, labels, input_lengths, label_lengths)``. The per-batch
    losses are averaged over the number of batches; the word error rate is
    averaged over utterances. When ``epoch`` is a multiple of 10, up to ten
    decoded examples from the last batch are printed. When ``WB`` is truthy the
    two metrics are logged to wandb.

    Args:
        model: network returning scores shaped (batch, time, n_class).
        device: device the spectrograms and labels are moved to.
        val_loader: sized iterable of validation batches.
        criterion: loss called as ``criterion(log_probs, labels, input_lengths,
            label_lengths)`` with log_probs shaped (time, batch, n_class).
        epoch: current epoch number; controls the sample print-out.
    """
    print('\nevaluating...')
    model.eval()
    val_loss = 0
    val_wer = []
    decoded_preds, decoded_targets = [], []
    num_batches = len(val_loader)
    with torch.no_grad():
        for _data in val_loader:
            spectrograms, labels, input_lengths, label_lengths = _data
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (time, batch, n_class)

            # as_tensor accepts numpy arrays as well as plain sequences of ints.
            loss = criterion(output, labels, torch.as_tensor(input_lengths), torch.as_tensor(label_lengths))
            val_loss += loss.item() / num_batches

            decoded_preds, decoded_targets = Utils.greedy_decoder(output.transpose(0, 1), labels, label_lengths)
            val_wer.extend(wer(target, pred) for target, pred in zip(decoded_targets, decoded_preds))

    if not val_wer:
        # Nothing was evaluated; avoid dividing by zero below.
        print('val set is empty, skipping evaluation\n')
        return

    avg_wer = sum(val_wer) / len(val_wer)

    print(
        'val set: Average loss: {:.4f}, Average WER: {:.4f}\n'.format(val_loss, avg_wer))

    # print a sample of the val data and decoded predictions against the true labels
    if epoch % 10 == 0:
        print('Ground Truth -> Decoded Prediction')
        # The last batch may hold fewer than ten utterances, so slice instead of indexing 0..9.
        for target, pred in list(zip(decoded_targets, decoded_preds))[:10]:
            print('{} -> {}'.format(target, pred))

    if WB:
        wandb.log({"val_loss": val_loss})
        wandb.log({"val_wer": avg_wer})


def train_and_validation(hparams, batch_iterators):
    """Build the model and its optimisation objects from ``hparams``, then run
    one training pass followed by one validation pass for every epoch.

    ``batch_iterators[0]`` is the training iterator and ``batch_iterators[1]``
    the validation iterator; any further entries are ignored.
    """
    train_loader, val_loader = batch_iterators[0], batch_iterators[1]
    epochs = hparams['epochs']

    # Fixed seed so weight initialisation is repeatable between runs.
    torch.manual_seed(7)
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)

    model_arg_keys = ('n_cnn_layers', 'n_rnn_layers', 'rnn_dim',
                      'n_class', 'n_feats', 'stride', 'dropout')
    model = Model.SpeechRecognitionModel(*[hparams[key] for key in model_arg_keys])
    model = model.to(device)

    peak_lr = hparams['learning_rate']
    optimizer = optim.AdamW(model.parameters(), peak_lr)
    criterion = nn.CTCLoss(blank=28).to(device)
    # The scheduler is stepped once per batch inside train(), hence one
    # scheduler step per element of train_loader per epoch.
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=peak_lr,
        steps_per_epoch=len(train_loader),
        epochs=epochs,
        anneal_strategy='linear',
    )

    for finished_epochs in range(epochs):
        epoch = finished_epochs + 1  # epochs are reported 1-based
        train(model, device, train_loader, criterion, optimizer, scheduler, epoch)
        validation(model, device, val_loader, criterion, epoch)


Utils

In [None]:
import torch


class TextTransform:
    """Two-way lookup between transcript characters and integer labels."""

    def __init__(self):
        # One "<symbol> <id>" pair per line; <SPACE> stands for the blank character.
        char_map_str = """
        ' 0
        <SPACE> 1
        a 2
        b 3
        c 4
        d 5
        e 6
        f 7
        g 8
        h 9
        i 10
        j 11
        k 12
        l 13
        m 14
        n 15
        o 16
        p 17
        q 18
        r 19
        s 20
        t 21
        u 22
        v 23
        w 24
        x 25
        y 26
        z 27
        """
        pairs = [entry.split() for entry in char_map_str.strip().splitlines()]
        self.char_map = {symbol: int(number) for symbol, number in pairs}
        self.index_map = {int(number): symbol for symbol, number in pairs}
        # Decode id 1 straight to a literal blank instead of the token name.
        self.index_map[1] = ' '

    def text_to_int(self, text):
        """ Translate every character of ``text`` into its integer id """
        return [
            self.char_map['<SPACE>'] if symbol == ' ' else self.char_map[symbol]
            for symbol in text
        ]

    def int_to_text(self, labels):
        """ Translate a sequence of integer ids back into a string """
        decoded = ''.join(self.index_map[number] for number in labels)
        return decoded.replace('<SPACE>', ' ')


# Module-level TextTransform instance; greedy_decoder below uses it to turn
# integer label sequences back into text.
text_transform = TextTransform()


def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
    """
    Greedy (best-path) decoder: at every time step take the class with the
    highest score, drop blanks and, when ``collapse_repeated`` is set, drop a
    class that repeats the immediately preceding time step.

    Returns ``(decodes, targets)``: the decoded prediction strings and the
    reference strings built from the first ``label_lengths[i]`` entries of
    ``labels[i]``.
    """
    best_paths = torch.argmax(output, dim=2)
    decodes, targets = [], []
    for row, path in enumerate(best_paths):
        reference = labels[row][:label_lengths[row]].tolist()
        targets.append(text_transform.int_to_text(reference))

        steps = path.tolist()
        kept = []
        for position, index in enumerate(steps):
            if index == blank_label:
                continue
            repeats_previous = position > 0 and index == steps[position - 1]
            if collapse_repeated and repeats_previous:
                continue
            kept.append(index)
        decodes.append(text_transform.int_to_text(kept))
    return decodes, targets



In [None]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn as nn
from Utils import TextTransform
from HyperParameters import hparams


class BatchIterator:
    """Re-iterable mini-batch iterator over aligned arrays.

    Yields ``(batch_x, batch_y, batch_input_lengths, batch_label_lengths)``
    tuples of at most ``batch_size`` items; the final batch may be smaller.
    The two length sequences are stored as numpy arrays, so the slices handed
    out are numpy arrays whether or not shuffling is enabled.

    Args:
        x: indexable inputs; must support slicing and, when shuffling, indexing
            with an integer index array.
        y: targets aligned with ``x``.
        input_lengths: per-item input lengths aligned with ``x``.
        label_lengths: per-item label lengths aligned with ``x``.
        batch_size: maximum number of items per batch; must be positive.
        shuffle: when true, the data is re-permuted each time iteration starts.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, x, y, input_lengths, label_lengths, batch_size, shuffle=True):
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        self.x = x
        self.y = y
        # Convert up front so batches have the same type with and without shuffling.
        self.input_lengths = np.asarray(input_lengths)
        self.label_lengths = np.asarray(label_lengths)
        self.batch_size = batch_size
        self.num_samples = len(x)
        # Ceiling division: a trailing partial batch still counts.
        self.num_batches = (self.num_samples + batch_size - 1) // batch_size
        self.shuffle = shuffle
        self.current_batch = 0

    def __iter__(self):
        if self.shuffle:
            # Apply one permutation to all four sequences so they stay aligned.
            indices = np.random.permutation(self.num_samples)
            self.x = self.x[indices]
            self.y = self.y[indices]
            self.input_lengths = self.input_lengths[indices]
            self.label_lengths = self.label_lengths[indices]
        self.current_batch = 0
        return self

    def __len__(self):
        return self.num_batches

    def __next__(self):
        if self.current_batch >= self.num_batches:
            raise StopIteration

        start_idx = self.current_batch * self.batch_size
        end_idx = min(start_idx + self.batch_size, self.num_samples)
        self.current_batch += 1

        return (
            self.x[start_idx:end_idx],
            self.y[start_idx:end_idx],
            self.input_lengths[start_idx:end_idx],
            self.label_lengths[start_idx:end_idx],
        )


def load_wavs_data(load_again=False, save=False,
                   path=r".\an4"):
    """Build, or reload from cache, MFCC features and integer transcripts per split.

    Args:
        load_again: when true, return the tensors cached under ``data/`` if all
            four cache files exist. If any is missing the data is rebuilt from
            ``path`` instead of failing.
        save: when true, write freshly built data to the cache files under
            ``data/`` (the directory is created if needed).
        path: dataset root; its ``test``, ``train`` and ``val`` sub-directories
            are scanned recursively for ``.txt`` transcripts.

    Returns:
        ``(all_spectrogram, all_labels, all_input_lengths, all_label_lengths)``,
        four dicts keyed by ``"test"``, ``"train"`` and ``"val"``. For a split
        with data, the spectrogram entry is a zero-padded tensor built as
        (utterance, 1, n_mfcc, time) -- assuming mono audio, TODO confirm --
        the labels entry is a zero-padded integer tensor, and the two length
        entries are lists of ints. A split with no transcripts keeps empty lists.
    """
    cache_paths = (
        "data/all_spectrogram.pt",
        "data/all_labels.pt",
        "data/all_input_lengths.pt",
        "data/all_label_lengths.pt",
    )

    # Only use the cache when it is complete; otherwise fall through and rebuild.
    if load_again and all(os.path.exists(cache_path) for cache_path in cache_paths):
        return tuple(torch.load(cache_path) for cache_path in cache_paths)

    text_transform = TextTransform()

    valid_file = ("test", "train", "val")
    all_spectrogram = {split: [] for split in valid_file}
    all_labels = {split: [] for split in valid_file}
    all_input_lengths = {split: [] for split in valid_file}
    all_label_lengths = {split: [] for split in valid_file}

    # One MFCC transform per sample rate instead of a new one for every file.
    mfcc_transforms = {}

    for split in os.listdir(path):
        if split not in valid_file:
            continue
        spectrogram = []
        labels = []
        input_lengths = []
        label_lengths = []
        for root2, _dirs2, files2 in os.walk(os.path.join(path, split)):
            for file in files2:
                if not file.endswith(".txt"):
                    continue
                # Swap only the extension; str.replace would touch every ".txt" in the name.
                wav = os.path.splitext(file)[0] + ".wav"
                # The audio lives in a sibling directory tree named "wav" instead of "txt".
                # NOTE(review): this rewrites every "txt" substring of the directory path.
                root2_wav = root2.replace("txt", "wav")

                waveform, sample_rate = torchaudio.load(os.path.join(root2_wav, wav))
                if sample_rate not in mfcc_transforms:
                    mfcc_transforms[sample_rate] = torchaudio.transforms.MFCC(
                        sample_rate=sample_rate, n_mfcc=13)
                mfcc = mfcc_transforms[sample_rate](waveform)
                mfcc = mfcc.squeeze(0)
                mfcc = mfcc.transpose(0, 1)  # time-major so pad_sequence pads along time
                spectrogram.append(mfcc)

                with open(os.path.join(root2, file), 'r') as f:
                    # Drop a trailing line break; it has no entry in the character map.
                    text = f.read().rstrip("\r\n")
                int_text = text_transform.text_to_int(text.lower())
                # Class indices are integers, so store them as a long tensor.
                label = torch.tensor(int_text, dtype=torch.long)
                labels.append(label)
                # Halved to follow the model's stride-2 stem convolution, which
                # shortens the time axis; presumably must track hparams['stride'] -- verify.
                input_lengths.append(mfcc.shape[0] // 2)
                label_lengths.append(len(label))

        if not spectrogram:
            # pad_sequence cannot handle an empty list; leave this split empty.
            continue

        # Padding is done per split, so each split is padded to its own longest utterance.
        all_spectrogram[split] = (
            nn.utils.rnn.pad_sequence(spectrogram, batch_first=True).unsqueeze(1).transpose(2, 3)
        )
        all_labels[split] = nn.utils.rnn.pad_sequence(labels, batch_first=True)
        all_input_lengths[split] = input_lengths
        all_label_lengths[split] = label_lengths

    if save:
        os.makedirs("data", exist_ok=True)
        built = (all_spectrogram, all_labels, all_input_lengths, all_label_lengths)
        for payload, cache_path in zip(built, cache_paths):
            torch.save(payload, cache_path)
    return all_spectrogram, all_labels, all_input_lengths, all_label_lengths


def get_batch_iterator(data_type, batch_size=hparams["batch_size"]):
    """Return a ``BatchIterator`` over the requested split.

    ``data_type`` must be ``"test"``, ``"train"`` or ``"val"``; anything else
    raises ``ValueError``. The data comes from ``load_wavs_data`` called with
    ``load_again=True`` and ``save=True``.
    """
    known_splits = ("test", "train", "val")
    if data_type not in known_splits:
        raise ValueError("data_type must be one of [test, train, val]")

    spectrograms, labels, input_lengths, label_lengths = load_wavs_data(load_again=True, save=True)
    return BatchIterator(
        spectrograms[data_type],
        labels[data_type],
        input_lengths[data_type],
        label_lengths[data_type],
        batch_size,
    )


Main

In [None]:
import TrainAndEvaluation
import DataLoader
import wandb
from HyperParameters import hparams
from HyperParameters import WB

# Free-text description of the experiment; first part of the wandb run name.
DESCRIPTION = 'initial work'
# Label of the model variant; follows DESCRIPTION in the wandb run name.
RUN = 'Complex Model'


def init_w_and_b():
    """Start a Weights & Biases run for this experiment when ``WB`` is truthy.

    The run name is built from ``DESCRIPTION``, ``RUN`` and the number of
    epochs. All entries of ``hparams`` are recorded in the run config, together
    with the fixed architecture and dataset labels.
    """
    epochs = hparams['epochs']
    learning_rate = hparams['learning_rate']

    if WB:
        wandb.init(
            # Set the project where this run will be logged
            group="Complex Model initial work",
            project="ASR",
            # We pass a run name (otherwise it will be randomly assigned, like sunshine-lollypop-10).
            # The name carries the epoch count, not the whole hparams dict;
            # the full set of hyperparameters goes into `config` below.
            name=f"{DESCRIPTION}{RUN}_{epochs}_epochs",
            notes='checking if log is work properly',
            # Track hyperparameters and run metadata
            config={
                **hparams,
                "learning_rate": learning_rate,
                "architecture": "assembly",
                "dataset": "AN4",
                "epochs": epochs,
            })


def main():
    """Run the experiment: optionally start wandb, build the batch iterators
    for the three splits, train and validate, then close the wandb run.
    """
    if WB:
        wandb.login()
        init_w_and_b()

    train_batch_iterator = DataLoader.get_batch_iterator("train", hparams["batch_size"])
    val_batch_iterator = DataLoader.get_batch_iterator("val", hparams["batch_size"])
    test_batch_iterator = DataLoader.get_batch_iterator("test", hparams["batch_size"])
    # train_and_validation takes element 0 as the training iterator and
    # element 1 as the validation iterator, so "val" must come second;
    # otherwise the model would be validated on the test split.
    all_iterators = [train_batch_iterator, val_batch_iterator, test_batch_iterator]

    TrainAndEvaluation.train_and_validation(hparams, all_iterators)

    if WB:
        wandb.finish()


# Run the experiment only when executed as a script or notebook cell, not on import.
if __name__ == "__main__":
    main()



evaluating...
val set: Average loss: 12.1804, Average WER: 1.0000


evaluating...
val set: Average loss: 3.5923, Average WER: 1.0000


evaluating...
val set: Average loss: 3.3025, Average WER: 1.0000


evaluating...
val set: Average loss: 3.1093, Average WER: 1.0000


evaluating...
val set: Average loss: 3.0277, Average WER: 1.0000


evaluating...
val set: Average loss: 2.9881, Average WER: 1.0000

