In [1]:
%cd /home/jupyter-yuki090626/yamashita/ErnieASR/scripts/ernie_asr

/home/jupyter-yuki090626/yamashita/ErnieASR/scripts/ernie_asr


In [2]:
import os
import re
import sys
import time
import glob
import json
import pickle
import kaldiio
import numpy as np
import pandas as pd
import subprocess
from itertools import chain
from pathlib import Path
from pathlib import PurePath
from tqdm.auto import tqdm
from scipy.io import wavfile
from sklearn.preprocessing import LabelEncoder

import torch
from torch import Tensor
import torch.utils.data
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

### データ読み込み(CSJ/core)

In [3]:
# load trainset
data = []
with open('./data/train/csj_core_train.json', 'r') as f:
    data = [json.loads(d) for d in f.readlines()]

In [4]:
phones_list = ["", "BOS", "EOS", "PAD"]
for d in data:
    for phone in d["phones"]:
        if phone not in phones_list:
            phones_list.append(phone)

In [5]:
len(phones_list)

41

In [6]:
phones_list

['',
 'BOS',
 'EOS',
 'PAD',
 'e',
 ':',
 'q',
 't',
 'o',
 'ky',
 'r',
 'i',
 'ts',
 'u',
 'd',
 'a',
 'g',
 'k',
 'n',
 'm',
 'sh',
 's',
 'z',
 'w',
 'b',
 'N',
 'h',
 'py',
 'j',
 'ch',
 'ry',
 'y',
 'hy',
 'p',
 'f',
 'my',
 'ny',
 'by',
 'gy',
 'dy',
 'ty']

In [7]:
labelEncoder = LabelEncoder()
labelEncoder.fit(list(phones_list))

LabelEncoder()

### wavデータ取得

In [8]:
# !rm -rf data/train
!mkdir data/train

mkdir: cannot create directory ‘data/train’: File exists


In [9]:
core_data = glob.glob("data/csj-data/core/*")

In [10]:
sym2path = {}

for dir_path in core_data:
    if dir_path == 'data/csj-data/core/$MYVIMRC':
        continue
    wav_symbol = dir_path.split('/')[-1]
    with open(dir_path+"/"+wav_symbol+"-wav.list", "r") as f:
        wavlist = f.readlines()
    f.close()
    wav_path = wavlist[0].strip()
    sym2path[wav_symbol] = wav_path

In [11]:
wavscpf = open("data/train/wav.scp", "w")
utt2spkf = open("data/train/utt2spk", "w")

for d in data:
    start, _ , duration = d["utt2info"]
    wavscp = d['uttname'] + ' sox ' + sym2path[d['uttname'].split('-')[0]] + ' -t wav -r 16000 -c 1 - trim {} {}|'.format(start, round(duration,3))
    utt2spk = d['uttname'] + " " + d['uttname'].split('-')[0]
    
    wavscpf.write(wavscp + "\n")
    utt2spkf.write(utt2spk + "\n")

wavscpf.close()
utt2spkf.close()    

In [12]:
!tail data/train/wav.scp

A11M0846-000443 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 855.126 1.321|
A11M0846-000444 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 856.747 0.937|
A11M0846-000445 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 857.961 2.719|
A11M0846-000446 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 861.759 0.892|
A11M0846-000447 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 863.092 0.428|
A11M0846-000448 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 863.782 1.202|
A11M0846-000449 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 865.335 0.676|
A11M0846-000450 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 868.068 1.97|
A11M0846-000451 sox /disk107/DATA/CSJ_RAW/WAV/core/A11M0846.wav -t wav -r 16000 -c 1 - trim 870.57 3.323|
A11M0846-000452 sox /disk107/DATA/CSJ_R

### MFCC関連

In [15]:
!bash run_mfcc.sh data/train data/mfcc data/log conf/mfcc.conf 32

utils/fix_data_dir.sh: file data/train/utt2spk is not in sorted order or not unique, sorting it
utils/fix_data_dir.sh: file data/train/spk2utt is not in sorted order or not unique, sorting it
utils/fix_data_dir.sh: file data/train/wav.scp is not in sorted order or not unique, sorting it
fix_data_dir.sh: kept all 51675 utterances.
fix_data_dir.sh: old files are kept in data/train/.backup
steps/make_mfcc.sh --write-utt2num-frames true --mfcc-config conf/mfcc.conf --nj 32 --cmd utils/run.pl data/train data/log data/mfcc
utils/validate_data_dir.sh: Successfully validated data-directory data/train
steps/make_mfcc.sh: [info]: no segments file exists: assuming wav.scp indexed by utterance.
steps/make_mfcc.sh: Succeeded creating MFCC features for train
steps/compute_cmvn_stats.sh data/train data/log/cmvnlog data/mfcc
Succeeded creating CMVN stats for train
fix_data_dir.sh: kept all 51675 utterances.
fix_data_dir.sh: old files are kept in data/train/.backup


In [13]:
with open('data/train/feats.scp') as f:
    lines = f.readlines()
f.close()
utt2feats = {}
for i in lines:
    i = i.strip()
    utt, feat_path = i.split(' ')
    utt2feats[utt] = feat_path

### Dataset

In [14]:
class Dataset(torch.utils.data.Dataset):

    def __init__(self, data, utt2feats, phones_list, transform=None):
        self.transform = transform
        self.data = data
        self.data_num = len(data)
        self.utt2feats = utt2feats
        self.phones_list = phones_list
        self.phonel = len(self.phones_list)
        self.seq_length = 256
        
    def __len__(self):
        return self.data_num

    def __getitem__(self, idx):
        
        utt = self.data[idx]['uttname']
        phones_label = self.data[idx]["phones"]
        
        feats = torch.tensor(kaldiio.load_mat(utt2feats[utt]))
        label = torch.tensor(labelEncoder.transform(phones_label))
        
        return feats, label

In [15]:
dataset = Dataset(data, utt2feats, phones_list)

In [16]:
input1, test1 = dataset.__getitem__(0)

In [17]:
input1.shape

torch.Size([359, 30])

### model定義

In [18]:
class RNN(torch.nn.Module):
    """RNN module
    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        bidir = typ[0] == "b"
        self.nbrnn = (
            torch.nn.LSTM(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
            if "lstm" in typ
            else torch.nn.GRU(
                idim,
                cdim,
                elayers,
                batch_first=True,
                dropout=dropout,
                bidirectional=bidir,
            )
        )
        if bidir:
            self.l_last = torch.nn.Linear(cdim * 2, hdim)
        else:
            self.l_last = torch.nn.Linear(cdim, hdim)
        self.typ = typ

        self.logsoftmax = torch.nn.LogSoftmax(dim=2)
        
    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward
        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        xs_pack = torch.nn.utils.rnn.pack_padded_sequence(xs_pad, ilens, batch_first=True, enforce_sorted=False)
        
        self.nbrnn.flatten_parameters()
        
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when previous state is passed,
            # it means that we're streaming the input
            # and therefore cannot propagate backward BRNN state
            # (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
            
        ys, states = self.nbrnn(xs_pack, hx=prev_state)
        
        # ys: utt list of frame x cdim x 2 (2: means bidirectional)
        ys_pad, ilens = torch.nn.utils.rnn.pad_packed_sequence(ys, batch_first=True)
        
        ys_pad = self.l_last(ys_pad)

        xs_pad = self.logsoftmax(ys_pad)
        return xs_pad, ilens, states  # x: utt list of frame x dim


def reset_backward_rnn_state(states):
    """Sets backward BRNN states to zeroes
    Useful in processing of sliding windows over the inputs
    """
    if isinstance(states, (list, tuple)):
        for state in states:
            state[1::2] = 0.0
    else:
        states[1::2] = 0.0
    return states


In [19]:
net = RNN(30, 3, 512, len(phones_list), 0, typ="blstm")
net.double()

RNN(
  (nbrnn): LSTM(30, 512, num_layers=3, batch_first=True, bidirectional=True)
  (l_last): Linear(in_features=1024, out_features=41, bias=True)
  (logsoftmax): LogSoftmax()
)

In [20]:
ctcloss = torch.nn.CTCLoss(blank=0)

In [21]:
# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
optimizer = optim.Adam(net.parameters(), lr=0.001)

In [22]:
#並列で動作させるにはpadding
#https://qiita.com/iBotamon/items/acffef7852faadb420fd
#https://takoroy-ai.hatenadiary.jp/entry/2018/07/02/224216

In [23]:
BOS_ID = 1
EOS_ID = 2
PAD_ID = 3

In [24]:
def collate_fn(batch):
    fbanks = []
    tokens = []
    for feat, label in batch:
        fbanks.append(feat)
        tokens.append(label)
    ilens = torch.tensor([x.shape[0] for x in fbanks])
    olens = torch.tensor([x.shape[0] for x in tokens])
    return pad_sequence(fbanks, batch_first=True), ilens, pad_sequence(tokens, batch_first=True), olens

In [25]:
trainloader = torch.utils.data.DataLoader(
                       dataset,
                       batch_size=32,
                       shuffle=True,
                       collate_fn=collate_fn)

In [26]:
gpu_ids = [0]
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    net= net.to(device)
else:
    device = torch.device('cpu')

net = net.cuda()
#net = torch.nn.DataParallel(net, device_ids=gpu_ids).cuda()

In [27]:
# !mkdir model

In [28]:
exp_model = "model/exp_0717"

### 学習ループ

In [29]:
epocs = 50
epoc_start = 0

In [30]:
if epoc_start > 0:
    net.load_state_dict(torch.load(f'model/exp_0702_net_weight_{epoc_start}', map_location=device))

In [None]:
loss_ave = 0
prev_state = None

for j in tqdm(range(epoc_start, epocs)):
    for i, data in enumerate(trainloader, 0):
        a = time.time()
        # get the inputs; data is a list of [inputs, labels]
        inputs, inlen, label, oulen = data
        output, ilen_y, prev_state = net(inputs.double().cuda(), inlen, prev_state=None)
        net.zero_grad()
        # loss = ctcloss(output.cpu().transpose(0,1), torch.tensor(label), inlen, oulen)
        loss = ctcloss(output.cpu().transpose(0,1), torch.tensor(label), inlen, oulen) / len(output)
        loss_ave += loss
        # loss.backward(retain_graph=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm(net.parameters(), 400)
        optimizer.step()
        b = time.time()
        print ("\r{}/{} loss: {} step took {}".format(i+1,len(trainloader),loss,b-a),end='')
        #print ("output:\n{} \n labels:\n{}".format(output,label))
    torch.save(net.state_dict(), exp_model + "_net_weight_" + str(j))
    print("\repoc : {}    average loss : {}".format(j, loss_ave/len(trainloader)))
    loss_ave = 0

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

  if sys.path[0] == '':
  app.launch_new_instance()


1162/1615 loss: 0.03723900425008437 step took 2.25383687019348141