In [2]:
import torch, torchaudio
import torch.nn as nn
from torch.nn import functional as F
import torch.optim as optim
import wandb
from torchsummary import summary

from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [3]:
class LibriSpeechOne(Dataset):
    """Dataset wrapping exactly one in-memory utterance.

    Items are 6-tuples so they can be unpacked the same way
    ``LibriSpeechDataset.__getitem__`` unpacks its underlying dataset: the
    waveform sits in slot 0, the transcript in slot 2, and the remaining
    slots are ``None`` placeholders.
    """

    def __init__(self, audio, label):
        self.audio, self.label = audio, label

    def __len__(self):
        # There is only ever a single utterance.
        return 1

    def __getitem__(self, idx):
        # ``idx`` is ignored: every index yields the same utterance.
        item = [None] * 6
        item[0] = self.audio
        item[2] = self.label
        return tuple(item)


class LibriSpeechDataset(Dataset):
    """LibriSpeech utterances as (mel spectrogram, character-index label) samples.

    Args:
        dataset_type: ``"train"`` (LibriSpeech ``train-clean-100``), ``"valid"``
            (``test-clean``) or ``"one"`` (a single in-memory utterance).
        data: only for ``"one"``; an ``(audio, label)`` pair forwarded to
            ``LibriSpeechOne``.
        dataset_dir: directory the LibriSpeech archives are downloaded to and
            read from. Defaults to ``DEFAULT_DATASET_DIR``.

    Raises:
        ValueError: if ``dataset_type`` is not one of the values above.
        TypeError: if ``dataset_type`` is ``"one"`` and ``data`` is missing.
    """

    DEFAULT_DATASET_DIR = "/home/asblab2/sinarasi/mie1517/new_code/data"

    # A character's class index is its position in this string
    # ("'" -> 0, " " -> 1, "a" -> 2, ..., "z" -> 27).
    ALPHABET = "' abcdefghijklmnopqrstuvwxyz"

    def __init__(self, dataset_type, data=None, dataset_dir=None):
        self.dataset_dir = self.DEFAULT_DATASET_DIR if dataset_dir is None else dataset_dir

        if dataset_type == "train":
            self.dataset = torchaudio.datasets.LIBRISPEECH(self.dataset_dir, url="train-clean-100", download=True)
        elif dataset_type == "valid":
            self.dataset = torchaudio.datasets.LIBRISPEECH(self.dataset_dir, url="test-clean", download=True)
        elif dataset_type == "one":
            if data is None:
                raise TypeError('dataset_type "one" requires data=(audio, label)')
            self.dataset = LibriSpeechOne(*data)
        else:
            # ValueError is a subclass of Exception, so existing handlers still match.
            raise ValueError("Invalid dataset type!")

        transforms = [torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)]
        if dataset_type == "train":
            # Random frequency/time masking is a training-time augmentation.
            # Applying it to validation or single-utterance inputs would
            # corrupt the data being evaluated, so it is train-only.
            transforms += [
                torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
                torchaudio.transforms.TimeMasking(time_mask_param=100),
            ]
        self.audio_transform = nn.Sequential(*transforms)

        self.text_to_int = {char: index for index, char in enumerate(self.ALPHABET)}
        self.int_to_text = {index: char for char, index in self.text_to_int.items()}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(spectrogram, label, spectrogram_length, label_length)``.

        ``spectrogram`` is laid out ``(time, n_mels)`` and ``label`` is a list of
        character indices. A transcript character outside ``ALPHABET`` (after
        lower-casing) raises ``KeyError``.
        """
        # Only the waveform (slot 0) and the transcript (slot 2) are used.
        audio, _, sentence, _, _, _ = self.dataset[index]
        # squeeze(0) drops the leading axis (assumes mono audio - TODO confirm),
        # then time is moved first so batches can be padded along it.
        spectrogram = self.audio_transform(audio).squeeze(0).transpose(0, 1)
        label = [self.text_to_int[char] for char in sentence.lower()]
        # Halved because the ASR model's first convolution uses stride 2.
        spectrogram_length = spectrogram.shape[0] // 2
        label_length = len(label)
        return spectrogram, label, spectrogram_length, label_length


def collate(data):
    """
    Pad spectrograms and labels within the batch to the same length.

    Args:
        data: iterable of ``(spectrogram, label, spectrogram_length, label_length)``
            samples, with ``spectrogram`` laid out ``(time, n_mels)`` and ``label``
            a sequence of integer class indices.

    Returns:
        Tuple ``(spectrograms, labels, spectrogram_lengths, label_lengths)``:
        ``spectrograms`` is ``(batch, n_mels, max_time)`` and ``labels`` is a
        ``torch.long`` tensor ``(batch, max_label_len)``, both zero padded; the
        two length tensors hold the per-sample lengths passed in.
    """
    spectrograms, labels, spectrogram_lengths, label_lengths = [], [], [], []
    for spectrogram, label, spectrogram_length, label_length in data:
        spectrograms.append(torch.Tensor(spectrogram))
        # Labels are class indices, so keep them integral instead of float32.
        labels.append(torch.tensor(label, dtype=torch.long))
        spectrogram_lengths.append(spectrogram_length)
        label_lengths.append(label_length)

    # pad_sequence pads along time (dim 0 of each sample); the transpose then
    # yields the (batch, n_mels, time) layout.
    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).transpose(1, 2)
    # Padding value 0 is also a real class index; label_lengths tells the
    # consumer how many entries of each row are genuine.
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, torch.tensor(spectrogram_lengths), torch.tensor(label_lengths)

In [4]:

class CNNLayerNorm(nn.Module):
    """LayerNorm over the feature axis of a (batch, channel, feature, time) tensor."""

    def __init__(self, n_feats):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        # nn.LayerNorm normalises the trailing axis, so swap `feature` into
        # last place, normalise, then swap back to the incoming layout.
        time_major = torch.transpose(x, 2, 3).contiguous()     # (batch, channel, time, feature)
        normalised = self.layer_norm(time_major)
        return torch.transpose(normalised, 2, 3).contiguous()  # (batch, channel, feature, time)


class ResidualCNN(nn.Module):
    """Two LayerNorm -> GELU -> Dropout -> Conv2d stages wrapped in a skip connection.

    The skip addition needs the convolutions to preserve the input shape
    (e.g. ``stride=1`` with an odd ``kernel``, given ``padding=kernel//2``).
    """

    def __init__(self, channels, kernel, stride, dropout, n_feats):
        super().__init__()

        def conv_stage():
            # One normalise / activate / regularise / convolve stage.
            return [
                CNNLayerNorm(n_feats),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Conv2d(channels, channels, kernel, stride, padding=kernel//2),
            ]

        self.layers = nn.Sequential(*conv_stage(), *conv_stage())

    def forward(self, x):
        # Input and output are both (batch, channel, feature, time).
        out = self.layers(x)
        out += x  # skip connection, added in place onto the conv output
        return out

class BiGRU(nn.Module):
    """LayerNorm -> GELU -> single-layer bidirectional GRU, followed by dropout.

    Takes ``(batch, time, rnn_dim)`` and returns ``(batch, time, 2 * hidden_size)``
    (forward and backward directions concatenated).
    """

    def __init__(self, rnn_dim, hidden_size, dropout):
        super().__init__()
        stages = [
            nn.LayerNorm(rnn_dim),
            nn.GELU(),
            nn.GRU(input_size=rnn_dim, hidden_size=hidden_size, num_layers=1, batch_first=True, bidirectional=True),
        ]
        self.bigru = nn.Sequential(*stages)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # The GRU is the last stage, so the Sequential yields its
        # (outputs, final_hidden) pair; only the per-step outputs are kept.
        outputs, _final_hidden = self.bigru(x)
        return self.dropout(outputs)


class ASR(nn.Module):
    """Speech recogniser: stride-2 conv stem + residual CNNs -> linear -> BiGRU stack -> classifier.

    Args:
        dropout: dropout probability used throughout the network.
        hidden_size: width of the linear projection and of each GRU direction.
        rnn_layers: number of stacked ``BiGRU`` blocks.
        rescnn_layers: number of ``ResidualCNN`` blocks after the stem conv.
        n_mels: number of mel bins in the input spectrogram.

    ``forward`` takes ``(batch, n_mels, time)`` and returns ``(scores, None)``
    where ``scores`` is ``(batch, ceil(time / 2), 29)``; the caller applies
    ``log_softmax``.
    """

    def __init__(self, dropout, hidden_size, rnn_layers, rescnn_layers, n_mels):

        super().__init__()

        self.dropout = dropout
        # Feature bins left after the stem conv (kernel 3, padding 1, stride 2):
        # floor((n_mels - 1) / 2) + 1 == ceil(n_mels / 2). Plain `n_mels // 2`
        # is one short for odd n_mels and mis-sizes every layer built from it.
        self.n_mels = (n_mels + 1) // 2
        self.lin_start = 128  # not read anywhere in this class
        self.lin_end = 29     # 28 characters + 1 CTC blank
        self.hidden_size = hidden_size
        self.gru_layers = rnn_layers

        # Process the mel spectrogram via residual Conv2D layers
        self.rescnn_layers = nn.Sequential(
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            *[ResidualCNN(32, kernel=3, stride=1, dropout=dropout, n_feats=self.n_mels) for _ in range(rescnn_layers)]
        )

        # Linear layers: project flattened (channel * feature) vectors to hidden_size
        self.fc1 = nn.Sequential(
            nn.LayerNorm(self.n_mels * 32),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(self.n_mels * 32, self.hidden_size),
            nn.LayerNorm(self.hidden_size),
            nn.GELU(),
            nn.Dropout(self.dropout)
        )

        # GRU architecture: the first block consumes hidden_size features, later
        # blocks consume the 2 * hidden_size output of the previous bidirectional GRU
        self.gru = nn.Sequential(*[
                    BiGRU(rnn_dim=hidden_size if i==0 else hidden_size*2,
                                    hidden_size=hidden_size, dropout=dropout)
                    for i in range(self.gru_layers)
                ])

        # Classifier.
        # NOTE(review): LayerNorm/GELU/Dropout are applied to the class scores
        # themselves, ahead of the caller's log_softmax. That is unusual for an
        # output layer; confirm it is intended. Left unchanged so existing
        # checkpoints keep loading and behaving the same.
        self.fc2 = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.lin_end),
            nn.LayerNorm(self.lin_end),
            nn.GELU(),
            nn.Dropout(self.dropout)
        )

    def forward(self, x):
        x = x.unsqueeze(1)          # add a channel axis: (batch, 1, n_mels, time)
        x = self.rescnn_layers(x)   # (batch, 32, feature, time')
        batch, channels, feats, frames = x.shape
        x = x.reshape(batch, channels * feats, frames)
        x = x.transpose(1, 2)       # linear layers expect (batch, time', features)
        x = self.fc1(x)
        x = self.gru(x)
        x = self.fc2(x)
        return x, None

In [5]:

# Hyper-parameters for a training run; the whole dict is also handed to wandb
# as the run config.
hp = dict(
    # Optimisation
    batch_size=10,
    learning_rate=5e-4,
    lr_factor=0.75,     # ReduceLROnPlateau multiplier
    lr_patience=10,     # ReduceLROnPlateau patience
    epochs=100,
    # Input features. LibriSpeechDataset hard-codes the same sample rate and
    # mel count rather than reading them from here; `sample_rate` is not read
    # by the code in this notebook section.
    n_mels=128,
    sample_rate=16000,
    # Model
    dropout=0.1,
    hidden_size=512,
    rnn_layers=5,
    cnn_layers=3,
    architecture=1,     # not read by the code in this notebook section
)


def compute_validation_loss(net, criterion, dataloader):
    """Return the mean per-batch loss of ``net`` over ``dataloader``.

    Args:
        net: model whose forward returns ``(scores, _)`` with ``scores`` shaped
            ``(batch, time, classes)``.
        criterion: loss called as ``criterion(log_probs, label, data_len, label_len)``
            with ``log_probs`` shaped ``(time, batch, classes)``.
        dataloader: yields ``(data, label, data_len, label_len)`` batches.

    Returns:
        Mean of the per-batch loss values as a Python float.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Note:
        ``net`` is switched to eval mode and left that way.
    """
    net.eval()

    # Run on whatever device the model lives on instead of assuming CUDA;
    # a parameter-less model falls back to the previous CUDA behaviour.
    first_param = next(net.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cuda")

    losses = []
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for data, label, data_len, label_len in tqdm(dataloader):
            data, label = data.to(device), label.to(device)
            data_len, label_len = data_len.to(device), label_len.to(device)
            out, _ = net(data)
            out = F.log_softmax(out, dim=2)
            out = out.transpose(0, 1)  # the loss takes batch as the second dim
            losses.append(criterion(out, label, data_len, label_len).item())

    if not losses:
        raise ValueError("compute_validation_loss received an empty dataloader")

    return sum(losses) / len(losses)


def train():
    """Train the ASR model on LibriSpeech with CTC loss, logging to wandb.

    Each epoch runs one pass over the training set, then evaluates on the
    validation set, steps the plateau LR scheduler on the validation loss and
    writes a checkpoint whenever the validation or the training loss reaches a
    new minimum. Requires CUDA; all settings come from the module-level ``hp``.
    """

    # Authenticate with wandb
    wandb.login()

    # Model
    asr_model = ASR(hp["dropout"], hp["hidden_size"], hp["rnn_layers"], hp["cnn_layers"], hp["n_mels"]).cuda()

    # Data: both loaders share everything except the dataset and shuffling
    train_dataset = LibriSpeechDataset("train")
    valid_dataset = LibriSpeechDataset("valid")
    loader_kwargs = dict(batch_size=hp["batch_size"],
                         collate_fn=collate,
                         num_workers=3,
                         pin_memory=False)
    train_loader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
    valid_loader = DataLoader(dataset=valid_dataset, shuffle=False, **loader_kwargs)

    # Optimisation
    optimizer = optim.Adam(asr_model.parameters(), lr=hp["learning_rate"])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=hp["lr_factor"], patience=hp["lr_patience"])
    criterion = nn.CTCLoss(blank=28, zero_infinity=True)

    # Checkpoint destinations
    best_valid_path = "/home/asblab2/sinarasi/mie1517/MIE1517-Project/Speech Recog/best_model.pth"
    best_train_path = "/home/asblab2/sinarasi/mie1517/MIE1517-Project/Speech Recog/mid_model.pth"

    # Lowest losses seen so far
    min_valid_loss, min_train_loss = 1e10, 1e10

    with wandb.init(project="MIE1517", config=hp):
        wandb.watch(asr_model, log="all")

        for epoch in range(hp["epochs"]):

            # Validation leaves the model in eval mode, so re-enable training mode
            asr_model = asr_model.train()

            epoch_losses = []
            progress = tqdm(train_loader, desc="Epoch {0} / {1}".format(epoch, hp["epochs"]))
            for batch in progress:
                data, label, data_len, label_len = (tensor.cuda() for tensor in batch)
                out, _ = asr_model(data)
                # CTCLoss takes (time, batch, classes) log-probabilities
                log_probs = F.log_softmax(out, dim=2).transpose(0, 1)
                loss = criterion(log_probs, label, data_len, label_len)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()
                epoch_losses.append(loss.item())

            train_loss = sum(epoch_losses) / len(epoch_losses)
            avg_valid_loss = compute_validation_loss(asr_model, criterion, valid_loader)
            scheduler.step(avg_valid_loss)

            # Checkpoint on a new validation-loss minimum ...
            if avg_valid_loss < min_valid_loss:
                print("Saved valid checkpoint!")
                torch.save(asr_model.state_dict(), best_valid_path)
                min_valid_loss = avg_valid_loss
            # ... and, separately, on a new training-loss minimum
            if train_loss < min_train_loss:
                print("Saved train checkpoint!")
                torch.save(asr_model.state_dict(), best_train_path)
                min_train_loss = train_loss

            current_lr = optimizer.param_groups[0]['lr']
            print("Losses:", train_loss, avg_valid_loss, current_lr)

            wandb.log({"train_loss": train_loss,
                        "valid_loss": avg_valid_loss}, step=epoch)


In [6]:
# Seed torch's CPU and CUDA random generators and ask cuDNN for
# deterministic kernels.
seed = 0
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn
torch.backends.cudnn.deterministic = True

# Left commented out so that running this cell does not start a training run.
# train()