In [1]:
import torch
import torch.nn as nn
from torch.nn import functional 
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torchaudio

import string
import itertools
from collections import Counter
from tqdm import tqdm
import unidecode
import speechbrain
from speechbrain.nnet.loss.transducer_loss import TransducerLoss

The torchaudio backend is switched to 'soundfile'. Note that 'sox_io' is not supported on Windows.
The torchaudio backend is switched to 'soundfile'. Note that 'sox_io' is not supported on Windows.


In [2]:
#Librispeech的数据可以通过torchaudio去简单的导入
#SpeechData_Train = torchaudio.datasets.LIBRISPEECH('./data', 'train-clean-100', 'LibriSpeech', False)
#(tensor([[-0.0033,  0.0000,  0.0003,  ...,  0.0056,  0.0054,  0.0055]]), 16000, 'PETER HAD ASKED HIM OF COURSE FOR MATTHEW CUTHBERT HAD NEVER BEEN KNOWN TO VOLUNTEER INFORMATION ABOUT ANYTHING IN HIS WHOLE LIFE AND YET HERE WAS MATTHEW CUTHBERT AT HALF PAST THREE ON THE AFTERNOON OF A BUSY DAY PLACIDLY DRIVING OVER THE HOLLOW AND UP THE HILL', 103, 1240, 10)

In [3]:
#label ='#abcdefghijklmnopqrstuvwxyz\' '，其中‘#’在这里也解释为NULL（空）
#spectrogram = torchaudio.transforms.MFCC()
#train_sets = []
#test_sets = []
#for _ ,a in enumerate(SpeechData_Train):
#    speech_signal = spectrogram(a[0]) #speech tensor
#    speech_sets.append(speech_signal)
#    text_signal = a[2].lower()
#    text_tensor = []
#    for i in text_signal:
#        if i not in label:
#            print(i)
#        text_tensor.append(label.index(i))
#    text_tensor = torch.Tensor(text_tensor).long()
#    text_sets.append(text_tensor)
#torch.save(speech_sets,'./speech_sets.pth')
#torch.save(text_sets,'./text_sets.pth')

In [4]:
#预先处理好的音频特征和文本数据
speech_sets = torch.load('./speech_sets.pth')
text_sets = torch.load('./text_sets.pth')
NULL_INDEX = 0 #起始字符、空字符

In [5]:
#构建编码器 LAS金字塔式encoder
#将语音信号编码成一个固定大小的向量
#embed_size:mfcc的维度，为40
#num_hiddens:上下文向量的维度
class Encoder(nn.Module):
    def __init__(self, embed_size, num_hiddens, num_layers, dropout = 0, **kwargs):
        super(Encoder,self).__init__(**kwargs)
        self.embed_size = embed_size
        self.num_hiddens = num_hiddens
        self.rnn1 = nn.GRU(embed_size, num_hiddens, num_layers, dropout=dropout)
        self.rnn2 = nn.GRU(2*num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.rnn3 = nn.GRU(2*num_hiddens, num_hiddens, num_layers, dropout=dropout)
        self.linear = nn.Linear(num_hiddens, num_hiddens)
        
    def forward(self, signal_mfcc):
        batch_size = signal_mfcc.shape[0]
        signal_mfcc = signal_mfcc.permute(1,0,2)
        output,state= self.rnn1(signal_mfcc)
        if len(output)%2 == 0:
            output = output.permute(1,0,2).reshape(batch_size,-1,2*self.num_hiddens)
        else:
            padding_matrix = torch.zeros(batch_size,self.num_hiddens).unsqueeze(dim = 1).to(signal_mfcc.device)
            output = torch.cat((output.permute(1,0,2),padding_matrix),dim = 1).reshape(batch_size,-1,2*self.num_hiddens)
        output = output.permute(1,0,2)
        output, state = self.rnn2(output)
        if len(output)%2 == 0:
            output = output.permute(1,0,2).reshape(batch_size,-1,2*self.num_hiddens)
        else:
            padding_matrix = torch.zeros(batch_size,self.num_hiddens).unsqueeze(dim = 1).to(signal_mfcc.device)
            output = torch.cat((output.permute(1,0,2),padding_matrix),dim = 1).reshape(batch_size,-1,2*self.num_hiddens)
        output = output.permute(1,0,2)
        output, state = self.rnn3(output)
        return (output,state)

In [6]:
#预测器
class Predictor(torch.nn.Module):
    def __init__(self, num_outputs, predictor_dim, joiner_dim, NULL_INDEX):
        super(Predictor, self).__init__()
        self.embed = torch.nn.Embedding(num_outputs, predictor_dim)
        self.rnn = torch.nn.GRUCell(input_size=predictor_dim, hidden_size=predictor_dim)
        self.linear = torch.nn.Linear(predictor_dim, joiner_dim)
    
        self.initial_state = torch.nn.Parameter(torch.randn(predictor_dim))
        self.start_symbol = NULL_INDEX # In the original paper, a vector of 0s is used; just using the null index instead is easier when using an Embedding layer.

    def forward_one_step(self, input, previous_state):
        embedding = self.embed(input)
        state = self.rnn.forward(embedding, previous_state)
        out = self.linear(state)
        return out, state

    def forward(self, y):
        batch_size = y.shape[0]
        U = y.shape[1]
        outs = []
        state = torch.stack([self.initial_state] * batch_size).to(y.device)
        for u in range(U+1): # need U+1 to get null output for final timestep 
            if u == 0:
                decoder_input = torch.tensor([self.start_symbol] * batch_size).to(y.device)
            else:
                decoder_input = y[:,u-1]
            out, state = self.forward_one_step(decoder_input, state)
            outs.append(out)
        out = torch.stack(outs, dim=1)
        return out

In [7]:
class Joiner(torch.nn.Module):
    def __init__(self, num_outputs, joiner_dim):
        super(Joiner, self).__init__()
        self.linear = torch.nn.Linear(joiner_dim, num_outputs)

    def forward(self, encoder_out, predictor_out):
        out = encoder_out + predictor_out
        out = torch.nn.functional.relu(out)
        out = self.linear(out)
        return out

In [8]:
#将编码器、预测器打包起来
class Transducer(torch.nn.Module):
    def __init__(self, embed_size, encoder_dim, num_layers, dropout, num_outputs, predictor_dim, joiner_dim, NULL_INDEX):
        super(Transducer, self).__init__()
        self.encoder = Encoder(embed_size, encoder_dim, num_layers, dropout = dropout)
        self.predictor = Predictor(num_outputs, predictor_dim, joiner_dim, NULL_INDEX)
        self.joiner = Joiner(num_outputs,joiner_dim)
        self.transducer_loss = TransducerLoss(0)

        if torch.cuda.is_available(): self.device = "cuda:0"
        else: self.device = "cpu"
        self.to(self.device)

    def compute_forward_prob(self, joiner_out, T, U, y, NULL_INDEX):
        """
        joiner_out: tensor of shape (B, T_max, U_max+1, #labels)
        T: list of input lengths
        U: list of output lengths 
        y: label tensor (B, U_max+1)
        """
        B = joiner_out.shape[0]
        T_max = joiner_out.shape[1]
        U_max = joiner_out.shape[2] - 1
        log_alpha = torch.zeros(B, T_max, U_max+1).to(self.device)
        #print(log_alpha.shape)
        for t in range(T_max):
            for u in range(U_max+1):
                if u == 0:
                    if t == 0:
                        log_alpha[:, t, u] = 0.

                    else: #t > 0
                        log_alpha[:, t, u] = log_alpha[:, t-1, u] + joiner_out[:, t-1, 0, NULL_INDEX] 
                  
                else: #u > 0
                    if t == 0:
                        log_alpha[:, t, u] = log_alpha[:, t,u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)
            
                    else: #t > 0
                        log_alpha[:, t, u] = torch.logsumexp(torch.stack([
                            log_alpha[:, t-1, u] + joiner_out[:, t-1, u, NULL_INDEX],
                            log_alpha[:, t, u-1] + torch.gather(joiner_out[:, t, u-1], dim=1, index=y[:,u-1].view(-1,1) ).reshape(-1)
                        ]), dim=0)
    
        log_probs = []
        for b in range(B):
            log_prob = log_alpha[b, T[b]-1, U[b]] + joiner_out[b, T[b]-1, U[b], NULL_INDEX]
            log_probs.append(log_prob)
        log_probs = torch.stack(log_probs) 
        return log_prob

    def compute_loss(self, x, y, T, U,NULL_INDEX):
        encoder_out = self.encoder.forward(x)[0].permute(1,0,2)
        T = T.long()
        predictor_out = self.predictor.forward(y)
        U = U.long()
        joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
        loss = -self.compute_forward_prob(joiner_out, T, U, y, NULL_INDEX).mean()
        return loss
    def greedy_search(self, x, T):
        y_batch = []
        B = len(x)
        encoder_out = self.encoder.forward(x)[0]
        U_max = 300
        for b in range(B):
            t = 0; u = 0; y = [self.predictor.start_symbol]; predictor_state = self.predictor.initial_state.unsqueeze(0)
            while t < T[b] and u < U_max:
                predictor_input = torch.tensor([ y[-1] ]).to(x.device)
                g_u, predictor_state = self.predictor.forward_one_step(predictor_input, predictor_state)
                f_t = encoder_out[b, t]
                h_t_u = self.joiner.forward(f_t, g_u)
                argmax = h_t_u.max(-1)[1].item()
                if argmax == NULL_INDEX:
                    t += 1
                else: # argmax == a label
                    u += 1
                    y.append(argmax)
            y_batch.append(y[1:]) # 去掉首元素
        return y_batch
    def compute_loss1(self, x, y, T, U):
        encoder_out = self.encoder.forward(x)[0].permute(1,0,2)
        predictor_out = self.predictor.forward(y)
        joiner_out = self.joiner.forward(encoder_out.unsqueeze(2), predictor_out.unsqueeze(1)).log_softmax(3)
        #loss = -self.compute_forward_prob(joiner_out, T, U, y).mean()
        T = T.to(joiner_out.device)
        U = U.to(joiner_out.device)
        loss = self.transducer_loss(joiner_out, y, T, U) #, blank_index=NULL_INDEX, reduction="mean")
        return loss

In [9]:
#数据处理，构建数据的dataoader
class SpeechDataset(torch.utils.data.Dataset):
    def __init__(self, speech_sets, text_sets, batch_size):
        self.length = len(speech_sets)
        self.SpeechSets = speech_sets
        self.TextSets = text_sets
        collate = Collate()
        self.loader = torch.utils.data.DataLoader(self, batch_size=batch_size, shuffle=False, collate_fn=collate)
        
    def __len__(self):
        return self.length
    
    def __getitem__(self, idx):
        x = self.SpeechSets[idx]
        y = self.TextSets[idx]
        
        return (x.squeeze(dim = 0).permute(1,0), y)
    
class Collate:
    def __call__(self, batch):
        """
        batch: list of tuples (input string, output string)[T,Mel_dim]\[U]
        Returns a minibatch of strings, encoded as labels and padded to have the same length.
        """
        x = []; y = []
        batch_size = len(batch)
        for index in range(batch_size):
            x_,y_ = batch[index]
            x.append(x_)
            y.append(y_)

    # pad all sequences to have same length
        T = [len(x_) for x_ in x]
        U = [len(y_) for y_ in y]
        T_max = max(T)
        U_max = max(U)
        for index in range(batch_size):
            x[index] = torch.cat((torch.tensor([[NULL_INDEX] * 40 ] * (T_max - len(x[index]))),x[index]),dim = 0)
            y[index] = torch.cat((torch.tensor([NULL_INDEX] * (U_max - len(y[index]))), y[index]),dim = 0)

        # stack into single tensor
        x = torch.stack(x)
        y = torch.stack(y).long()
        T = torch.tensor(T)
        U = torch.tensor(U)

        return (x,y,T,U)

In [14]:
class Trainer:
    def __init__(self, model, lr):
        self.model = model
        self.lr = lr
        self.optimizer = torch.optim.Adam(model.parameters(), lr=self.lr)
  
    def train(self, dataset, print_interval = 20):
        train_loss = 0
        num_samples = 0 
        self.model.train()
        pbar = tqdm(dataset.loader)
        for idx, batch in enumerate(pbar):
            x,y,T,U = batch
            x = x.to(self.model.device)
            y = y.to(self.model.device)
            T = T.to(self.model.device)
            U = U.to(self.model.device)
            batch_size = len(x)
            num_samples += batch_size
            #loss = self.model.compute_loss1(x,y,torch.ceil(T/4).long(),U,0)
            loss = self.model.compute_loss1(x,y,torch.ceil(T/4).long(),U)
            self.optimizer.zero_grad()
            
            loss.backward()
            self.optimizer.step()
            with torch.no_grad():
                pbar.set_description("%.2f" % loss.item())
                train_loss += loss.item() * batch_size
                #if idx % print_interval == 0:
                #    #self.model.eval()
                #    guesses = self.model.greedy_search(x,T)
                #   #self.model.train()
                #    print("\n")
                #    for b in range(2):
                #        print("guess:", guesses[b])
                #        print("truth:", y[b,:U[b]])
                #        print("")
                train_loss /= num_samples
        return train_loss

    def test(self, dataset, print_interval=20):
        test_loss = 0
        num_samples = 0
        self.model.eval()
        pbar = tqdm(dataset.loader)
        for idx, batch in enumerate(pbar):
            x,y,T,U = batch
            x = x.to(self.model.device)
            y = y.to(self.model.device)
            T = T.to(self.model.device)
            U = U.to(self.model.device)
            batch_size = len(x)
            num_samples += batch_size
            loss = self.model.compute_loss(x,y,torch.ceil(T/4).long(),U,0)
            pbar.set_description("%.2f" % loss.item())
            test_loss += loss.item() * batch_size
            if idx % print_interval == 0:
                print("\n")
                print("guess:", self.model.greedy_search(x,T)[0])
                print("truth:", y[0,:U[0]])
                print("")
        test_loss /= num_samples
        return test_loss

In [15]:
end_length = int(0.1 * len(speech_sets))
train_sets = SpeechDataset(speech_sets[:end_length],text_sets[:end_length],batch_size = 10)
tranducer = Transducer(embed_size = 40, encoder_dim = 100, num_layers = 1, dropout  = 0, 
                       num_outputs = 29, predictor_dim = 100, joiner_dim = 100, NULL_INDEX = 0)
tranducer.load_state_dict(torch.load('./tranducer.pth'))
trainer = Trainer(model=tranducer, lr=0.003)
num_epochs = 10
train_losses=[]
test_losses=[]
for epoch in range(num_epochs):
    train_loss = trainer.train(train_sets)
    #test_loss = trainer.test(train_sets)
    train_losses.append(train_loss)
    #test_losses.append(test_loss)
    print("Epoch %d: train loss = %f, test loss = null" % (epoch, train_loss))

0.14: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.73it/s]
0.08: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.73it/s]
0.05: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.73it/s]
0.03: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.73it/s]
0.03: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.72it/s]
0.02: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.73it/s]
0.02: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.73it/s]
0.02: 100%|██████████████████████████████████████████████████████████████████████████| 286/286 [02:45<00:00,  1.73it/s]
0.01: 100%|█████████████████████████████

In [16]:
torch.save(tranducer.state_dict(),'./tranducer.pth')