In [1]:
import os
import random
from io import open
import unicodedata
import string
import re

import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from pathlib import Path
import kaldiio
import sys
import gc
import json
import time
from lm_utils import load_dataset, ParallelSentenceIterator


%matplotlib inline

print_use = False

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"]="3"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")

In [3]:
train_label="/home1/meichaoyang/workspace/git/espnet/egs/aishell2/asr2/data/local/lm_train/train.txt"
valid_label="/home1/meichaoyang/workspace/git/espnet/egs/aishell2/asr2/data/local/lm_train/valid.txt"
dict_path="/home1/meichaoyang/workspace/git/espnet/egs/aishell2/asr2/data/lang_1char/train_sp_units.txt"

### 读取词典

In [4]:
with open(dict_path, "rb") as f:
    dictionary = f.readlines()
char_list = [entry.decode("utf-8").split(" ")[0] for entry in dictionary]
char_list.insert(0, "<blank>")
char_list.append("<eos>")
char_list_dict = {x: i for i, x in enumerate(char_list)}
n_vocab = len(char_list)

In [5]:
train, n_train_tokens, n_train_oovs = load_dataset(train_label, char_list_dict)

933907it [00:07, 128065.04it/s]


In [6]:
val, n_val_tokens, n_val_oovs = load_dataset(valid_label, char_list_dict)

50492it [00:00, 134685.67it/s]


In [7]:
maxlen=100
batch_size=32
unk = char_list_dict["<unk>"]
eos = char_list_dict["<eos>"]

In [8]:
train_iter = ParallelSentenceIterator(
        train,
        batch_size,
        max_length=maxlen,
        sos=eos,
        eos=eos,
        shuffle=True,
    )

In [9]:
val_iter = ParallelSentenceIterator(
        train,
        batch_size,
        max_length=maxlen,
        sos=eos,
        eos=eos,
        shuffle=True,
    )

In [10]:
class RNN_LM(nn.Module):
    def __init__(self, n_vocab, n_layers, n_units, n_embed=None, dropout_rate=0.5):
        nn.Module.__init__(self)
        if n_embed is None:
            n_embed = n_units
        self.embed = nn.Embedding(n_vocab, n_embed)
        self.rnn = nn.ModuleList(
                [nn.LSTMCell(n_embed, n_units)]
                + [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
            )
        self.lo = nn.Linear(n_units, n_vocab)
        self.n_layers = n_layers
        self.n_units = n_units
        self.dropout = nn.ModuleList(
            [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
        )

    def zero_state(self, batchsize):
        """Initialize state."""
        p = next(self.parameters())
        return torch.zeros(batchsize, self.n_units).to(device=p.device, dtype=p.dtype)
    
    
    def forward(self, state, x):
        """Forward neural networks."""
        if state is None:
            h = [
                self.zero_state(x.size(0)).to(device)
                for n in range(self.n_layers)
            ]
            state = {"h": h}
            c = [
                self.zero_state(x.size(0)).to(device)
                for n in range(self.n_layers)
            ]
            state = {"c": c, "h": h}

        h = [None] * self.n_layers
        emb = self.embed(x)
        c = [None] * self.n_layers
        h[0], c[0] = self.rnn[0](
            self.dropout[0](emb), (state["h"][0], state["c"][0])
        )
        for n in range(1, self.n_layers):
            h[n], c[n] = self.rnn[n](
                self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])
            )
        state = {"c": c, "h": h}
        
        y = self.lo(self.dropout[-1](h[-1]))
        return state, y

In [11]:
class ClassifierWithState(nn.Module):
    """A wrapper for pytorch RNNLM."""

    def __init__(
        self, rnnlm, lossfun=nn.CrossEntropyLoss(reduction="none"), label_key=-1
    ):
        """Initialize class.

        :param torch.nn.Module predictor : The RNNLM
        :param function lossfun : The loss function to use
        :param int/str label_key :

        """
        
        super(ClassifierWithState, self).__init__()
        self.lossfun = lossfun
        self.y = None
        self.loss = None
        self.label_key = int(label_key)
        self.predictor = rnnlm

    def forward(self, state, *args, **kwargs):
        """Compute the loss value for an input and label pair.

        Notes:
            It also computes accuracy and stores it to the attribute.
            When ``label_key`` is ``int``, the corresponding element in ``args``
            is treated as ground truth labels. And when it is ``str``, the
            element in ``kwargs`` is used.
            The all elements of ``args`` and ``kwargs`` except the groundtruth
            labels are features.
            It feeds features to the predictor and compare the result
            with ground truth labels.

        :param torch.Tensor state : the LM state
        :param list[torch.Tensor] args : Input minibatch
        :param dict[torch.Tensor] kwargs : Input minibatch
        :return loss value
        :rtype torch.Tensor

        """

        if not (-len(args) <= self.label_key < len(args)):
            msg = "Label key %d is out of bounds" % self.label_key
            raise ValueError(msg)
        t = args[self.label_key]
        if self.label_key == -1:
            args = args[:-1]
        else:
            args = args[: self.label_key] + args[self.label_key + 1 :]


        self.y = None
        self.loss = None
        state, self.y = self.predictor(state, *args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        return state, self.loss

    def predict(self, state, x):
        """Predict log probabilities for given state and input x using the predictor.

        :param torch.Tensor state : The current state
        :param torch.Tensor x : The input
        :return a tuple (new state, log prob vector)
        :rtype (torch.Tensor, torch.Tensor)
        """
        if hasattr(self.predictor, "normalized") and self.predictor.normalized:
            return self.predictor(state, x)
        else:
            state, z = self.predictor(state, x)
            return state, F.log_softmax(z, dim=1)



In [12]:
class DefaultRNNLM(nn.Module):
  

    def __init__(self, n_vocab, layer, unit, embed_unit, dropout_rate=0.5):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary

        """
        nn.Module.__init__(self)

        self.model = ClassifierWithState(
            RNN_LM(n_vocab, layer, unit, embed_unit, dropout_rate)
        )



    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)

        """
        loss = 0
        logp = 0
        count = torch.tensor(0).long()
        state = None
        batch_size, sequence_length = x.shape
        for i in range(sequence_length):
            # Compute the loss at this time step and accumulate it
            state, loss_batch = self.model(state, x[:, i], t[:, i])
            non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype)
            loss += loss_batch.mean() * non_zeros
            logp += torch.sum(loss_batch * non_zeros)
            count += int(non_zeros)
        return loss / batch_size, loss, count.to(loss.device)

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        new_state, scores = self.model.predict(state, y[-1].unsqueeze(0))
        return scores.squeeze(0), new_state


In [13]:
import time
import math


def asMinutes(s):
    m = math.floor(s / 60)
    s -= m * 60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s / (percent)
    rs = es - s
    return '%s (- %s)' % (asMinutes(s), asMinutes(rs))

In [None]:
layer = 3
unit = 256
embed_unit = 200
dropout_rate = 0.5

model = DefaultRNNLM(n_vocab, layer, unit, embed_unit, dropout_rate)
model.to(device) # 589MB

optimizier = torch.optim.Adam(model.parameters(),
                                     lr=1e-3,
#                                      momentum=momentum,
                                     weight_decay=1e-5)

In [25]:
train_iter.iteration = 0
train_iter.epoch = 0

In [26]:
print_every = 100
plot_every  = 100
epoch=20

start = time.time()
n_iters = len(train)
plot_losses = []
print_loss_total = 0  # Reset every print_every
plot_loss_total = 0  # Reset every plot_every

while train_iter.epoch < epoch:
    e=train_iter.epoch
    i=train_iter.iteration
    
    data = train_iter.__next__()
    x = torch.tensor([i[0] for i in data]).to(device)
    t = torch.tensor([i[1] for i in data]).to(device)
    
    loss, _, _ = model(x, t)
    
    print_loss_total += float(loss)
    plot_loss_total += float(loss)
    optimizier.zero_grad()
    loss.backward()

    optimizier.step()
    
    if (i+1) % print_every == 0:
        print_loss_avg = print_loss_total / print_every
        print_loss_total = 0
        txt = 'Epoch %d | Iter %d | %s (%d %d%%) %.4f' % (e+1, i+1, timeSince(start, (e *n_iters +i+1) / (n_iters * epoch)),
                                             (i+1), (train_iter.epoch *n_iters +i+1) / (n_iters*epoch) * 100, print_loss_avg)
        print(txt)

    if (i+1) % plot_every == 0:
        plot_loss_avg = plot_loss_total / plot_every
        plot_losses.append(plot_loss_avg)
        plot_loss_total = 0

Epoch 1 | Iter 100 | 0m 5s (- 16868m 5s) (100 0%) 80.6223
Epoch 1 | Iter 200 | 0m 10s (- 16930m 52s) (200 0%) 75.6375
Epoch 1 | Iter 300 | 0m 15s (- 16237m 9s) (300 0%) 67.7000
Epoch 1 | Iter 400 | 0m 21s (- 16555m 15s) (400 0%) 77.3295
Epoch 1 | Iter 500 | 0m 26s (- 16595m 5s) (500 0%) 73.6129
Epoch 1 | Iter 600 | 0m 31s (- 16575m 53s) (600 0%) 74.7071
Epoch 1 | Iter 700 | 0m 37s (- 16600m 8s) (700 0%) 69.3573
Epoch 1 | Iter 800 | 0m 42s (- 16587m 21s) (800 0%) 68.2111
Epoch 1 | Iter 900 | 0m 48s (- 16726m 55s) (900 0%) 74.0991
Epoch 1 | Iter 1000 | 0m 54s (- 16887m 29s) (1000 0%) 78.6547
Epoch 1 | Iter 1100 | 0m 59s (- 16836m 8s) (1100 0%) 66.8761


KeyboardInterrupt: 

In [191]:
train[0]

array([3288, 4235, 1661,  952, 1661], dtype=int32)