diff --git a/README.md b/README.md index 06440c8..2bb006b 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,14 @@ # EEND (End-to-End Neural Diarization) EEND (End-to-End Neural Diarization) is a neural-network-based speaker diarization method. -- https://www.isca-speech.org/archive/Interspeech_2019/abstracts/2899.html -- https://arxiv.org/abs/1909.06247 (to appear at ASRU 2019) +- BLSTM EEND (INTERSPEECH 2019) + - https://www.isca-speech.org/archive/Interspeech_2019/abstracts/2899.html +- Self-attentive EEND (ASRU 2019) + - https://ieeexplore.ieee.org/abstract/document/9003959/ + +An EEND extension for a variable number of speakers is also provided in this repository. +- Self-attentive EEND with encoder-decoder based attractors + - https://arxiv.org/abs/2005.09921 ## Install tools ### Requirements @@ -48,6 +54,7 @@ cd egs/mini_librispeech/v1 ```bash ./run.sh ``` +- If you use encoder-decoder based attractors [3], modify `run.sh` to use `conf/eda/{train,infer}.yaml` - See `RESULT.md` and compare with your result. ## CALLHOME two-speaker experiment @@ -57,28 +64,39 @@ If you use your local machine, use "run.pl". If you use Grid Engine, use "queue.pl" If you use SLURM, use "slurm.pl". For more information about cmd.sh see http://kaldi-asr.org/doc/queue.html. -- Modify `egs/callhome/v1/run_prepare_shared.sh` according to storage paths of your copora. +- Modify `egs/callhome/v1/run_prepare_shared.sh` according to storage paths of your corpora. ### Data preparation ```bash cd egs/callhome/v1 ./run_prepare_shared.sh +# If you want to conduct 1-4 speaker experiments, run the script below. +# You also have to set paths to your corpora properly. +./run_prepare_shared_eda.sh ``` -### Self-attention-based model (latest configuration) +### Self-attention-based model using 2-speaker mixtures ```bash ./run.sh ``` -### BLSTM-based model (old configuration) +### BLSTM-based model using 2-speaker mixtures ```bash local/run_blstm.sh ``` +### Self-attention-based model with EDA using 1-4-speaker mixtures +```bash +./run_eda.sh +``` ## References [1] Yusuke Fujita, Naoyuki Kanda, Shota Horiguchi, Kenji Nagamatsu, Shinji Watanabe, " End-to-End Neural Speaker Diarization with Permutation-free Objectives," Proc. Interspeech, pp. 4300-4304, 2019 [2] Yusuke Fujita, Naoyuki Kanda, Shota Horiguchi, Yawen Xue, Kenji Nagamatsu, Shinji Watanabe, " -End-to-End Neural Speaker Diarization with Self-attention," arXiv preprints arXiv:1909.06247, 2019 +End-to-End Neural Speaker Diarization with Self-attention," Proc. ASRU, pp. 296-303, 2019 + +[3] Shota Horiguchi, Yusuke Fujita, Shinji Watanabe, Yawen Xue, Kenji Nagamatsu, " +End-to-End Speaker Diarization for an Unknown Number of Speakers with Encoder-Decoder Based Attractors," Proc. 
INTERSPEECH, 2020 + ## Citation diff --git a/eend/bin/infer.py b/eend/bin/infer.py index 09bf9cc..4c90378 100755 --- a/eend/bin/infer.py +++ b/eend/bin/infer.py @@ -45,6 +45,17 @@ parser.add_argument('--transformer-encoder-n-heads', default=4, type=int) parser.add_argument('--transformer-encoder-n-layers', default=2, type=int) parser.add_argument('--save-attention-weight', default=0, type=int) + +attractor_args = parser.add_argument_group('attractor') +attractor_args.add_argument('--use-attractor', action='store_true', + help='Enable encoder-decoder attractor mode') +attractor_args.add_argument('--shuffle', action='store_true', + help='Shuffle the order in time-axis before input to the network') +attractor_args.add_argument('--attractor-loss-ratio', default=1.0, type=float, + help='weighting parameter') +attractor_args.add_argument('--attractor-encoder-dropout', default=0.1, type=float) +attractor_args.add_argument('--attractor-decoder-dropout', default=0.1, type=float) +attractor_args.add_argument('--attractor-threshold', default=0.5, type=float) args = parser.parse_args() system_info.print_system_info() diff --git a/eend/bin/train.py b/eend/bin/train.py index dcacb35..b7743a4 100755 --- a/eend/bin/train.py +++ b/eend/bin/train.py @@ -34,7 +34,7 @@ help='input transform') parser.add_argument('--lr', default=0.001, type=float) parser.add_argument('--optimizer', default='adam', type=str) -parser.add_argument('--num-speakers', default=2, type=int) +parser.add_argument('--num-speakers', type=int) parser.add_argument('--gradclip', default=-1, type=int, help='gradient clipping. if < 0, no clipping') parser.add_argument('--num-frames', default=2000, type=int, @@ -63,6 +63,16 @@ parser.add_argument('--transformer-encoder-dropout', default=0.1, type=float) parser.add_argument('--gradient-accumulation-steps', default=1, type=int) parser.add_argument('--seed', default=777, type=int) + +attractor_args = parser.add_argument_group('attractor') +attractor_args.add_argument('--use-attractor', action='store_true', + help='Enable encoder-decoder attractor mode') +attractor_args.add_argument('--shuffle', action='store_true', + help='Shuffle the order in time-axis before input to the network') +attractor_args.add_argument('--attractor-loss-ratio', default=1.0, type=float, + help='weighting parameter') +attractor_args.add_argument('--attractor-encoder-dropout', default=0.1, type=float) +attractor_args.add_argument('--attractor-decoder-dropout', default=0.1, type=float) args = parser.parse_args() system_info.print_system_info() diff --git a/eend/chainer_backend/diarization_dataset.py b/eend/chainer_backend/diarization_dataset.py index 5bee614..ba46bb8 100644 --- a/eend/chainer_backend/diarization_dataset.py +++ b/eend/chainer_backend/diarization_dataset.py @@ -26,6 +26,7 @@ def _gen_frame_indices( class KaldiDiarizationDataset(chainer.dataset.DatasetMixin): + def __init__( self, data_dir, @@ -40,7 +41,8 @@ def __init__( use_last_samples=False, label_delay=0, n_speakers=None, - ): + shuffle=False, + ): self.data_dir = data_dir self.dtype = dtype self.chunk_size = chunk_size @@ -64,9 +66,11 @@ def __init__( label_delay=self.label_delay, subsampling=self.subsampling): self.chunk_indices.append( - (rec, st * self.subsampling, ed * self.subsampling)) + (rec, st * self.subsampling, ed * self.subsampling)) print(len(self.chunk_indices), " chunks") + self.shuffle = shuffle + def __len__(self): return len(self.chunk_indices) @@ -83,4 +87,19 @@ def get_example(self, i): Y = feature.transform(Y, self.input_transform) 
Y_spliced = feature.splice(Y, self.context_size) Y_ss, T_ss = feature.subsample(Y_spliced, T, self.subsampling) + + # If the sample contains more than "self.n_speakers" speakers, + # extract top-(self.n_speakers) speakers + if self.n_speakers and T_ss.shape[1] > self.n_speakers: + selected_speakers = np.argsort(T_ss.sum(axis=0))[::-1][:self.n_speakers] + T_ss = T_ss[:, selected_speakers] + + # If self.shuffle is True, shuffle the order in time-axis + # This operation improves the performance of EEND-EDA + if self.shuffle: + order = np.arange(Y_ss.shape[0]) + np.random.shuffle(order) + Y_ss = Y_ss[order] + T_ss = T_ss[order] + return Y_ss, T_ss diff --git a/eend/chainer_backend/encoder_decoder_attractor.py b/eend/chainer_backend/encoder_decoder_attractor.py new file mode 100644 index 0000000..9eb4e92 --- /dev/null +++ b/eend/chainer_backend/encoder_decoder_attractor.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +# Copyright 2020 Hitachi, Ltd. (author: Shota Horiguchi) +# Licensed under the MIT license. + +from chainer import Chain, cuda +import chainer.functions as F +import chainer.links as L + + +class EncoderDecoderAttractor(Chain): + + def __init__(self, n_units, encoder_dropout=0.1, decoder_dropout=0.1): + super(EncoderDecoderAttractor, self).__init__() + with self.init_scope(): + self.encoder = L.NStepLSTM(1, n_units, n_units, encoder_dropout) + self.decoder = L.NStepLSTM(1, n_units, n_units, decoder_dropout) + self.counter = L.Linear(n_units, 1) + self.n_units = n_units + + def forward(self, xs, zeros): + hx, cx, _ = self.encoder(None, None, xs) + _, _, attractors = self.decoder(hx, cx, zeros) + return attractors + + def estimate(self, xs, max_n_speakers=15): + """ + Calculate attractors from embedding sequences + without prior knowledge of the number of speakers + + Args: + xs: List of (T,D)-shaped embeddings + max_n_speakers (int) + Returns: + attractors: List of (N,D)-shaped attractors + probs: List of attractor existence probabilities + """ + + xp = cuda.get_array_module(xs[0]) + zeros = [xp.zeros((max_n_speakers, self.n_units), dtype=xp.float32) for _ in xs] + attractors = self.forward(xs, zeros) + probs = [F.sigmoid(F.flatten(self.counter(att))) for att in attractors] + return attractors, probs + + def __call__(self, xs, n_speakers): + """ + Calculate attractors from embedding sequences with given number of speakers + + Args: + xs: List of (T,D)-shaped embeddings + n_speakers: List of number of speakers, or None if the number of speakers is unknown (ex. 
test phase) + Returns: + loss: Attractor existence loss + attractors: List of (N,D)-shaped attractors + """ + + xp = cuda.get_array_module(xs[0]) + zeros = [xp.zeros((n_spk + 1, self.n_units), dtype=xp.float32) for n_spk in n_speakers] + attractors = self.forward(xs, zeros) + labels = F.concat([xp.array([[1] * n_spk + [0]], xp.int32) for n_spk in n_speakers], axis=1) + logit = F.concat([F.reshape(self.counter(att), (-1, n_spk + 1)) for att, n_spk in zip(attractors, n_speakers)], axis=1) + loss = F.sigmoid_cross_entropy(logit, labels) + + # The final attractor does not correspond to a speaker so remove it + # attractors = [att[:-1] for att in attractors] + attractors = [att[slice(0, att.shape[0] - 1)] for att in attractors] + return loss, attractors diff --git a/eend/chainer_backend/infer.py b/eend/chainer_backend/infer.py index 8507444..0254e52 100755 --- a/eend/chainer_backend/infer.py +++ b/eend/chainer_backend/infer.py @@ -11,7 +11,7 @@ from chainer import serializers from scipy.ndimage import shift from eend.chainer_backend.models import BLSTMDiarization -from eend.chainer_backend.models import TransformerDiarization +from eend.chainer_backend.models import TransformerDiarization, TransformerEDADiarization from eend.chainer_backend.utils import use_single_gpu from eend import feature from eend import kaldi_data @@ -32,26 +32,39 @@ def infer(args): # Prepare model in_size = feature.get_input_dim( - args.frame_size, - args.context_size, - args.input_transform) + args.frame_size, + args.context_size, + args.input_transform) if args.model_type == "BLSTM": model = BLSTMDiarization( - in_size=in_size, - n_speakers=args.num_speakers, - hidden_size=args.hidden_size, - n_layers=args.num_lstm_layers, - embedding_layers=args.embedding_layers, - embedding_size=args.embedding_size) + in_size=in_size, + n_speakers=args.num_speakers, + hidden_size=args.hidden_size, + n_layers=args.num_lstm_layers, + embedding_layers=args.embedding_layers, + embedding_size=args.embedding_size + ) elif args.model_type == 'Transformer': - model = TransformerDiarization( + if args.use_attractor: + model = TransformerEDADiarization( + in_size, + n_units=args.hidden_size, + n_heads=args.transformer_encoder_n_heads, + n_layers=args.transformer_encoder_n_layers, + dropout=0, + attractor_encoder_dropout=args.attractor_encoder_dropout, + attractor_decoder_dropout=args.attractor_decoder_dropout, + ) + else: + model = TransformerDiarization( args.num_speakers, in_size, n_units=args.hidden_size, n_heads=args.transformer_encoder_n_heads, n_layers=args.transformer_encoder_n_layers, - dropout=0) + dropout=0 + ) else: raise ValueError('Unknown model type.') @@ -75,7 +88,12 @@ def infer(args): Y_chunked = Variable(Y[start:end]) if args.gpu >= 0: Y_chunked.to_gpu(gpuid) - hs, ys = model.estimate_sequential(hs, [Y_chunked]) + hs, ys = model.estimate_sequential( + hs, [Y_chunked], + n_speakers=args.num_speakers, + th=args.attractor_threshold, + shuffle=args.shuffle + ) if args.gpu >= 0: ys[0].to_cpu() out_chunks.append(ys[0].data) @@ -88,6 +106,8 @@ def infer(args): if hasattr(model, 'label_delay'): outdata = shift(np.vstack(out_chunks), (-model.label_delay, 0)) else: + max_n_speakers = max([o.shape[1] for o in out_chunks]) + out_chunks = [np.insert(o, o.shape[1], np.zeros((max_n_speakers - o.shape[1], o.shape[0])), axis=1) for o in out_chunks] outdata = np.vstack(out_chunks) with h5py.File(outpath, 'w') as wf: wf.create_dataset('T_hat', data=outdata) diff --git a/eend/chainer_backend/models.py b/eend/chainer_backend/models.py index 
013a15c..ed95756 100644 --- a/eend/chainer_backend/models.py +++ b/eend/chainer_backend/models.py @@ -8,7 +8,9 @@ from itertools import permutations from chainer import cuda from chainer import reporter +from chainer import configuration from eend.chainer_backend.transformer import TransformerEncoder +from eend.chainer_backend.encoder_decoder_attractor import EncoderDecoderAttractor """ T: number of frames @@ -69,6 +71,196 @@ def batch_pit_loss(ys, ts, label_delay=0): return loss, labels +def batch_pit_loss_faster(ys, ts, label_delay=0): + """ + PIT loss over mini-batch. + Args: + ys: B-length list of predictions + ts: B-length list of labels + Returns: + loss: (1,)-shape mean cross entropy over mini-batch + labels: B-length list of permuted labels + """ + + n_speakers = ts[0].shape[1] + xp = chainer.backend.get_array_module(ys[0]) + # (B, T, C) + ys = F.pad_sequence(ys, padding=-1) + + losses = [] + for shift in range(n_speakers): + # rolled along with speaker-axis + ts_roll = [xp.roll(t, -shift, axis=1) for t in ts] + ts_roll = F.pad_sequence(ts_roll, padding=-1) + # loss: (B, T, C) + loss = F.sigmoid_cross_entropy(ys, ts_roll, reduce='no') + # sum over time: (B, C) + loss = F.sum(loss, axis=1) + losses.append(loss) + # losses: (B, C, C) + losses = F.stack(losses, axis=2) + # losses[b, i, j] is a loss between + # `i`-th speaker in y and `(i+j)%C`-th speaker in t + + perms = xp.array( + list(permutations(range(n_speakers))), + dtype='i', + ) + # y_inds: [0,1,2,3] + y_ind = xp.arange(n_speakers, dtype='i') + # perms -> relation to t_inds -> t_inds + # 0,1,2,3 -> 0+j=0,1+j=1,2+j=2,3+j=3 -> 0,0,0,0 + # 0,1,3,2 -> 0+j=0,1+j=1,2+j=3,3+j=2 -> 0,0,1,3 + t_inds = xp.mod(perms - y_ind, n_speakers) + + losses_perm = [] + for t_ind in t_inds: + losses_perm.append( + F.mean(losses[:, y_ind, t_ind], axis=1)) + # losses_perm: (B, Perm) + losses_perm = F.stack(losses_perm, axis=1) + + min_loss = F.sum(F.min(losses_perm, axis=1)) + + min_loss = F.sum(F.min(losses_perm, axis=1)) + n_frames = np.sum([t.shape[0] for t in ts]) + min_loss = min_loss / n_frames + + min_indices = xp.argmin(losses_perm.array, axis=1) + labels_perm = [t[:, perms[idx]] for t, idx in zip(ts, min_indices)] + + return min_loss, labels_perm + + +def standard_loss(ys, ts, label_delay=0): + losses = [F.sigmoid_cross_entropy(y, t) * len(y) for y, t in zip(ys, ts)] + loss = F.sum(F.stack(losses)) + n_frames = np.sum([t.shape[0] for t in ts]) + loss = loss / n_frames + return loss + + +def batch_pit_n_speaker_loss(ys, ts, n_speakers_list): + """ + PIT loss over mini-batch. 
+ Args: + ys: B-length list of predictions (pre-activations) + ts: B-length list of labels + n_speakers_list: list of n_speakers in batch + Returns: + loss: (1,)-shape mean cross entropy over mini-batch + labels: B-length list of permuted labels + """ + max_n_speakers = ts[0].shape[1] + xp = chainer.backend.get_array_module(ys[0]) + # (B, T, C) + ys = F.pad_sequence(ys, padding=-1) + + losses = [] + for shift in range(max_n_speakers): + # rolled along with speaker-axis + ts_roll = [xp.roll(t, -shift, axis=1) for t in ts] + ts_roll = F.pad_sequence(ts_roll, padding=-1) + # loss: (B, T, C) + loss = F.sigmoid_cross_entropy(ys, ts_roll, reduce='no') + # sum over time: (B, C) + loss = F.sum(loss, axis=1) + losses.append(loss) + # losses: (B, C, C) + losses = F.stack(losses, axis=2) + # losses[b, i, j] is a loss between + # `i`-th speaker in y and `(i+j)%C`-th speaker in t + + perms = xp.array( + list(permutations(range(max_n_speakers))), + dtype='i', + ) + # y_ind: [0,1,2,3] + y_ind = xp.arange(max_n_speakers, dtype='i') + # perms -> relation to t_inds -> t_inds + # 0,1,2,3 -> 0+j=0,1+j=1,2+j=2,3+j=3 -> 0,0,0,0 + # 0,1,3,2 -> 0+j=0,1+j=1,2+j=3,3+j=2 -> 0,0,1,3 + t_inds = xp.mod(perms - y_ind, max_n_speakers) + + losses_perm = [] + for t_ind in t_inds: + losses_perm.append( + F.mean(losses[:, y_ind, t_ind], axis=1)) + # losses_perm: (B, Perm) + losses_perm = F.stack(losses_perm, axis=1) + + # masks: (B, Perms) + def select_perm_indices(num, max_num): + perms = list(permutations(range(max_num))) + sub_perms = list(permutations(range(num))) + return [ + [x[:num] for x in perms].index(perm) + for perm in sub_perms] + masks = xp.full_like(losses_perm.array, xp.inf) + for i, t in enumerate(ts): + n_speakers = n_speakers_list[i] + indices = select_perm_indices(n_speakers, max_n_speakers) + masks[i, indices] = 0 + losses_perm += masks + + min_loss = F.sum(F.min(losses_perm, axis=1)) + n_frames = np.sum([t.shape[0] for t in ts]) + min_loss = min_loss / n_frames + + min_indices = xp.argmin(losses_perm.array, axis=1) + labels_perm = [t[:, perms[idx]] for t, idx in zip(ts, min_indices)] + labels_perm = [t[:, :n_speakers] for t, n_speakers in zip(labels_perm, n_speakers_list)] + + return min_loss, labels_perm + + +def add_silence_labels(ts): + xp = cuda.get_array_module(ts[0]) + # pad label's speaker-dim to be model's n_speakers + for i, t in enumerate(ts): + ts[i] = xp.pad( + t, + [(0, 0), (0, 1)], + mode='constant', + constant_values=0., + ) + return ts + + +def pad_labels(ts, out_size): + xp = cuda.get_array_module(ts[0]) + # pad label's speaker-dim to be model's n_speakers + for i, t in enumerate(ts): + if t.shape[1] < out_size: + # padding + ts[i] = xp.pad( + t, + [(0, 0), (0, out_size - t.shape[1])], + mode='constant', + constant_values=0., + ) + elif t.shape[1] > out_size: + # truncate + raise ValueError + return ts + + +def pad_results(ys, out_size): + xp = cuda.get_array_module(ys[0]) + # pad label's speaker-dim to be model's n_speakers + ys_padded = [] + for i, y in enumerate(ys): + if y.shape[1] < out_size: + # padding + ys_padded.append(F.concat([y, chainer.Variable(xp.zeros((y.shape[0], out_size - y.shape[1]), dtype=y.dtype))], axis=1)) + elif y.shape[1] > out_size: + # truncate + raise ValueError + else: + ys_padded.append(y) + return ys_padded + + def calc_diarization_error(pred, label, label_delay=0): """ Calculates diarization error stats for reporting. 
@@ -148,6 +340,7 @@ def dc_loss(embedding, label): class TransformerDiarization(chainer.Chain): + def __init__(self, n_speakers, in_size, @@ -188,13 +381,15 @@ def forward(self, xs, activation=None): ys = [F.get_item(y, slice(0, ilen)) for y, ilen in zip(ys, ilens)] return ys - def estimate_sequential(self, hx, xs): + def estimate_sequential(self, hx, xs, **kwargs): ys = self.forward(xs, activation=F.sigmoid) return None, ys def __call__(self, xs, ts): ys = self.forward(xs) - loss, labels = batch_pit_loss(ys, ts) + # loss, labels = batch_pit_loss_faster(ys, ts) + n_speakers = [t.shape[1] for t in ts] + loss, labels = batch_pit_n_speaker_loss(ys, ts, n_speakers) reporter.report({'loss': loss}, self) report_diarization_error(ys, labels, self) return loss @@ -211,7 +406,115 @@ def save_attention_weight(self, ofile, batch_index=0): np.save(ofile, np.array(att_weights)) +class TransformerEDADiarization(chainer.Chain): + + def __init__(self, in_size, n_units, n_heads, n_layers, dropout, + attractor_loss_ratio=1.0, + attractor_encoder_dropout=0.1, + attractor_decoder_dropout=0.1): + """ Self-attention-based diarization model. + + Args: + in_size (int): Dimension of input feature vector + n_units (int): Number of units in a self-attention block + n_heads (int): Number of attention heads + n_layers (int): Number of transformer-encoder layers + dropout (float): dropout ratio + attractor_loss_ratio (float) + attractor_encoder_dropout (float) + attractor_decoder_dropout (float) + """ + super(TransformerEDADiarization, self).__init__() + with self.init_scope(): + self.enc = TransformerEncoder( + in_size, n_layers, n_units, h=n_heads + ) + self.eda = EncoderDecoderAttractor( + n_units, + encoder_dropout=attractor_encoder_dropout, + decoder_dropout=attractor_decoder_dropout, + ) + self.attractor_loss_ratio = attractor_loss_ratio + + def forward(self, xs, n_speakers=None, activation=None): + ilens = [x.shape[0] for x in xs] + # xs: (B, T, F) + xs = F.pad_sequence(xs, padding=-1) + pad_shape = xs.shape + # emb: (B*T, E) + emb = self.enc(xs) + ys = emb + # emb: [(T, E), ...] + emb = F.separate(emb.reshape(pad_shape[0], pad_shape[1], -1), axis=0) + emb = [F.get_item(e, slice(0, ilen)) for e, ilen in zip(emb, ilens)] + + return emb + + def estimate_sequential(self, hx, xs, **kwargs): + emb = self.forward(xs) + ys_active = [] + n_speakers = kwargs.get('n_speakers') + th = kwargs.get('th') + shuffle = kwargs.get('shuffle') + if shuffle: + xp = cuda.get_array_module(emb[0]) + orders = [xp.arange(e.shape[0]) for e in emb] + for order in orders: + xp.random.shuffle(order) + attractors, probs = self.eda.estimate([e[order] for e, order in zip(emb, orders)]) + else: + attractors, probs = self.eda.estimate(emb) + ys = [F.matmul(e, att, transb=True) for e, att in zip(emb, attractors)] + ys = [F.sigmoid(y) for y in ys] + for p, y in zip(probs, ys): + if n_speakers is not None: + ys_active.append(y[:, :n_speakers]) + elif th is not None: + silence = np.where(cuda.to_cpu(p.data) < th)[0] + n_spk = silence[0] if silence.size else None + ys_active.append(y[:, :n_spk]) + else: + NotImplementedError('n_speakers or th has to be given.') + return None, ys_active + + def __call__(self, xs, ts): + n_speakers = [t.shape[1] for t in ts] + emb = self.forward(xs, n_speakers) + attractor_loss, attractors = self.eda(emb, n_speakers) + # ys: [(T, C), ...] 
+ ys = [F.matmul(e, att, transb=True) for e, att in zip(emb, attractors)] + + max_n_speakers = max(n_speakers) + ts_padded = pad_labels(ts, max_n_speakers) + ys_padded = pad_results(ys, max_n_speakers) + + if configuration.config.train: + # with chainer.using_config('enable_backprop', False): + loss, labels = batch_pit_n_speaker_loss(ys_padded, ts_padded, n_speakers) + loss = standard_loss(ys, labels) + else: + loss, labels = batch_pit_n_speaker_loss(ys_padded, ts_padded, n_speakers) + loss = standard_loss(ys, labels) + + reporter.report({'loss': loss}, self) + reporter.report({'attractor_loss': attractor_loss}, self) + report_diarization_error(ys, labels, self) + return loss + attractor_loss * self.attractor_loss_ratio + + def save_attention_weight(self, ofile, batch_index=0): + att_weights = [] + for l in range(self.enc.n_layers): + att_layer = getattr(self.enc, f'self_att_{l}') + # att.shape is (B, h, T, T); pick the first sample in batch + att_w = att_layer.att[batch_index, ...] + att_w.to_cpu() + att_weights.append(att_w.data) + # save as (n_layers, h, T, T)-shaped arryay + np.save(ofile, np.array(att_weights)) + + class BLSTMDiarization(chainer.Chain): + def __init__(self, n_speakers=4, dropout=0.25, @@ -270,7 +573,7 @@ def forward(self, xs, hs=None, activation=None): ems = [ems] return [hy1, cy1, hy_emb, cy_emb], ys, ems - def estimate_sequential(self, hx, xs): + def estimate_sequential(self, hx, xs, **kwargs): hy, ys, ems = self.forward(xs, hx, activation=F.sigmoid) return hy, ys diff --git a/eend/chainer_backend/train.py b/eend/chainer_backend/train.py index 253fa41..2007a4a 100755 --- a/eend/chainer_backend/train.py +++ b/eend/chainer_backend/train.py @@ -12,7 +12,7 @@ from chainer import training from chainer.training import extensions from eend.chainer_backend.models import BLSTMDiarization -from eend.chainer_backend.models import TransformerDiarization +from eend.chainer_backend.models import TransformerDiarization, TransformerEDADiarization from eend.chainer_backend.transformer import NoamScheduler from eend.chainer_backend.updater import GradientAccumulationUpdater from eend.chainer_backend.diarization_dataset import KaldiDiarizationDataset @@ -24,13 +24,7 @@ def _convert(batch, device): def to_device_batch(batch): if device is None: return batch - src_xp = chainer.backend.get_array_module(*batch) - xp = device.xp - concat = src_xp.concatenate(batch, axis=0) - sections = list(np.cumsum( - [len(x) for x in batch[:-1]], dtype=np.int32)) - concat_dst = device.send(concat) - batch_dst = xp.split(concat_dst, sections) + batch_dst = [device.send(x) for x in batch] return batch_dst return {'xs': to_device_batch([x for x, _ in batch]), 'ts': to_device_batch([t for _, t in batch])} @@ -57,7 +51,8 @@ def train(args): use_last_samples=True, label_delay=args.label_delay, n_speakers=args.num_speakers, - ) + shuffle=args.shuffle, + ) dev_set = KaldiDiarizationDataset( args.valid_data_dir, chunk_size=args.num_frames, @@ -70,29 +65,45 @@ def train(args): use_last_samples=True, label_delay=args.label_delay, n_speakers=args.num_speakers, - ) + shuffle=args.shuffle, + ) # Prepare model Y, T = train_set.get_example(0) if args.model_type == 'BLSTM': + assert args.num_speakers is not None model = BLSTMDiarization( - in_size=Y.shape[1], - n_speakers=args.num_speakers, - hidden_size=args.hidden_size, - n_layers=args.num_lstm_layers, - embedding_layers=args.embedding_layers, - embedding_size=args.embedding_size, - dc_loss_ratio=args.dc_loss_ratio, - ) + in_size=Y.shape[1], + 
n_speakers=args.num_speakers, + hidden_size=args.hidden_size, + n_layers=args.num_lstm_layers, + embedding_layers=args.embedding_layers, + embedding_size=args.embedding_size, + dc_loss_ratio=args.dc_loss_ratio, + ) elif args.model_type == 'Transformer': - model = TransformerDiarization( + if args.use_attractor: + model = TransformerEDADiarization( + Y.shape[1], + n_units=args.hidden_size, + n_heads=args.transformer_encoder_n_heads, + n_layers=args.transformer_encoder_n_layers, + dropout=args.transformer_encoder_dropout, + attractor_loss_ratio=args.attractor_loss_ratio, + attractor_encoder_dropout=args.attractor_encoder_dropout, + attractor_decoder_dropout=args.attractor_decoder_dropout, + ) + else: + assert args.num_speakers is not None + model = TransformerDiarization( args.num_speakers, Y.shape[1], n_units=args.hidden_size, n_heads=args.transformer_encoder_n_heads, n_layers=args.transformer_encoder_n_layers, - dropout=args.transformer_encoder_dropout) + dropout=args.transformer_encoder_dropout + ) else: raise ValueError('Possible model_type are "Transformer" and "BLSTM"') @@ -125,20 +136,20 @@ def train(args): serializers.load_npz(args.initmodel, model) train_iter = iterators.MultiprocessIterator( - train_set, - batch_size=args.batchsize, - repeat=True, shuffle=True, - # shared_mem=64000000, - shared_mem=None, - n_processes=4, n_prefetch=2) + train_set, + batch_size=args.batchsize, + repeat=True, shuffle=True, + # shared_mem=64000000, + shared_mem=None, + n_processes=4, n_prefetch=2) dev_iter = iterators.MultiprocessIterator( - dev_set, - batch_size=args.batchsize, - repeat=False, shuffle=False, - # shared_mem=64000000, - shared_mem=None, - n_processes=4, n_prefetch=2) + dev_set, + batch_size=args.batchsize, + repeat=False, shuffle=False, + # shared_mem=64000000, + shared_mem=None, + n_processes=4, n_prefetch=2) if args.gradient_accumulation_steps > 1: updater = GradientAccumulationUpdater( @@ -148,12 +159,12 @@ def train(args): train_iter, optimizer, converter=_convert, device=gpuid) trainer = training.Trainer( - updater, - (args.max_epochs, 'epoch'), - out=os.path.join(args.model_save_dir)) + updater, + (args.max_epochs, 'epoch'), + out=os.path.join(args.model_save_dir)) evaluator = extensions.Evaluator( - dev_iter, model, converter=_convert, device=gpuid) + dev_iter, model, converter=_convert, device=gpuid) trainer.extend(evaluator) if args.optimizer == 'noam': @@ -168,13 +179,13 @@ def train(args): # MICRO AVERAGE metrics = [ - ('diarization_error', 'speaker_scored', 'DER'), - ('speech_miss', 'speech_scored', 'SAD_MR'), - ('speech_falarm', 'speech_scored', 'SAD_FR'), - ('speaker_miss', 'speaker_scored', 'MI'), - ('speaker_falarm', 'speaker_scored', 'FA'), - ('speaker_error', 'speaker_scored', 'CF'), - ('correct', 'frames', 'accuracy')] + ('diarization_error', 'speaker_scored', 'DER'), + ('speech_miss', 'speech_scored', 'SAD_MR'), + ('speech_falarm', 'speech_scored', 'SAD_FR'), + ('speaker_miss', 'speaker_scored', 'MI'), + ('speaker_falarm', 'speaker_scored', 'FA'), + ('speaker_error', 'speaker_scored', 'CF'), + ('correct', 'frames', 'accuracy')] for num, den, name in metrics: trainer.extend(extensions.MicroAverage( 'main/{}'.format(num), diff --git a/eend/feature.py b/eend/feature.py index 2863a91..a5d9714 100644 --- a/eend/feature.py +++ b/eend/feature.py @@ -5,17 +5,17 @@ import numpy as np import librosa -import scipy.signal + def get_input_dim( frame_size, context_size, transform_type, - ): +): if transform_type.startswith('logmel23'): frame_size = 23 else: - fft_size = 1 << 
(frame_size-1).bit_length() + fft_size = 1 << (frame_size - 1).bit_length() frame_size = int(fft_size / 2) + 1 input_dim = (2 * context_size + 1) * frame_size return input_dim @@ -72,15 +72,15 @@ def transform( mel_basis = librosa.filters.mel(sr, n_fft, n_mels) Y = np.dot(Y ** 2, mel_basis.T) Y = np.log10(np.maximum(Y, 1e-10)) - #b = np.ones(300)/300 - #mean = scipy.signal.convolve2d(Y, b[:, None], mode='same') - # - # simple 2-means based threshoding for mean calculation + # b = np.ones(300)/300 + # mean = scipy.signal.convolve2d(Y, b[:, None], mode='same') + + # simple 2-means based threshoding for mean calculation powers = np.sum(Y, axis=1) - th = (np.max(powers) + np.min(powers))/2.0 + th = (np.max(powers) + np.min(powers)) / 2.0 for i in range(10): th = (np.mean(powers[powers >= th]) + np.mean(powers[powers < th])) / 2 - mean = np.mean(Y[powers > th,:], axis=0) + mean = np.mean(Y[powers > th, :], axis=0) Y = Y - mean elif transform_type == 'logmel23_mvn': n_fft = 2 * (Y.shape[1] - 1) @@ -121,13 +121,13 @@ def splice(Y, context_size=0): (n_frames, n_featdim * (2 * context_size + 1))-shaped """ Y_pad = np.pad( - Y, - [(context_size, context_size), (0, 0)], - 'constant') + Y, + [(context_size, context_size), (0, 0)], + 'constant') Y_spliced = np.lib.stride_tricks.as_strided( - Y_pad, - (Y.shape[0], Y.shape[1] * (2 * context_size + 1)), - (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False) + np.ascontiguousarray(Y_pad), + (Y.shape[0], Y.shape[1] * (2 * context_size + 1)), + (Y.itemsize * Y.shape[1], Y.itemsize), writeable=False) return Y_spliced @@ -148,7 +148,7 @@ def stft( (n_frames, n_bins)-shaped np.complex64 array """ # round up to nearest power of 2 - fft_size = 1 << (frame_size-1).bit_length() + fft_size = 1 << (frame_size - 1).bit_length() # HACK: The last frame is ommited # as librosa.stft produces such an excessive frame if len(data) % frame_shift == 0: @@ -192,13 +192,12 @@ def get_frame_labels( """ filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec] speakers = np.unique( - [kaldi_obj.utt2spk[seg['utt']] for seg - in filtered_segments]).tolist() + [kaldi_obj.utt2spk[seg['utt']] for seg + in filtered_segments]).tolist() if n_speakers is None: n_speakers = len(speakers) es = end * frame_shift if end is not None else None - data, rate = kaldi_obj.load_wav( - rec, start * frame_shift, es) + data, rate = kaldi_obj.load_wav(rec, start * frame_shift, es) n_frames = _count_frames(len(data), frame_size, frame_shift) T = np.zeros((n_frames, n_speakers), dtype=np.int32) if end is None: @@ -207,9 +206,9 @@ def get_frame_labels( for seg in filtered_segments: speaker_index = speakers.index(kaldi_obj.utt2spk[seg['utt']]) start_frame = np.rint( - seg['st'] * rate / frame_shift).astype(int) + seg['st'] * rate / frame_shift).astype(int) end_frame = np.rint( - seg['et'] * rate / frame_shift).astype(int) + seg['et'] * rate / frame_shift).astype(int) rel_start = rel_end = None if start <= start_frame and start_frame < end: rel_start = start_frame - start @@ -246,13 +245,13 @@ def get_labeledSTFT( (n_frmaes, n_speakers)-shaped np.int32 array. 
""" data, rate = kaldi_obj.load_wav( - rec, start * frame_shift, end * frame_shift) + rec, start * frame_shift, end * frame_shift) Y = stft(data, frame_size, frame_shift) filtered_segments = kaldi_obj.segments[rec] # filtered_segments = kaldi_obj.segments[kaldi_obj.segments['rec'] == rec] speakers = np.unique( - [kaldi_obj.utt2spk[seg['utt']] for seg - in filtered_segments]).tolist() + [kaldi_obj.utt2spk[seg['utt']] for seg + in filtered_segments]).tolist() if n_speakers is None: n_speakers = len(speakers) T = np.zeros((Y.shape[0], n_speakers), dtype=np.int32) @@ -266,9 +265,9 @@ def get_labeledSTFT( if use_speaker_id: all_speaker_index = all_speakers.index(kaldi_obj.utt2spk[seg['utt']]) start_frame = np.rint( - seg['st'] * rate / frame_shift).astype(int) + seg['st'] * rate / frame_shift).astype(int) end_frame = np.rint( - seg['et'] * rate / frame_shift).astype(int) + seg['et'] * rate / frame_shift).astype(int) rel_start = rel_end = None if start <= start_frame and start_frame < end: rel_start = start_frame - start diff --git a/egs/callhome/v1/conf/eda/adapt.yaml b/egs/callhome/v1/conf/eda/adapt.yaml new file mode 100644 index 0000000..e31975a --- /dev/null +++ b/egs/callhome/v1/conf/eda/adapt.yaml @@ -0,0 +1,27 @@ +# adapt options +sampling_rate: 8000 +frame_size: 200 +frame_shift: 80 +model_type: Transformer +max_epochs: 100 +gradclip: 5 +batchsize: 64 +hidden_size: 256 +num_frames: 500 +num_speakers: 2 +input_transform: logmel23_mn +optimizer: adam +lr: 1e-5 +context_size: 7 +subsampling: 10 +gradient_accumulation_steps: 1 +transformer_encoder_n_heads: 4 +transformer_encoder_n_layers: 4 +transformer_encoder_dropout: 0.1 +use_attractor: True +shuffle: True +attractor_loss_ratio: 0.01 +attractor_encoder_dropout: 0.1 +attractor_decoder_dropout: 0.1 +seed: 777 +gpu: 0 diff --git a/egs/callhome/v1/conf/eda/infer.yaml b/egs/callhome/v1/conf/eda/infer.yaml new file mode 100644 index 0000000..d4753cf --- /dev/null +++ b/egs/callhome/v1/conf/eda/infer.yaml @@ -0,0 +1,17 @@ +# inference options +sampling_rate: 8000 +frame_size: 200 +frame_shift: 80 +model_type: Transformer +hidden_size: 256 +input_transform: logmel23_mn +context_size: 7 +subsampling: 10 +chunk_size: 200000000 +transformer_encoder_n_heads: 4 +transformer_encoder_n_layers: 4 +use_attractor: True +shuffle: True +attractor_encoder_dropout: 0.1 +attractor_decoder_dropout: 0.1 +gpu: 0 diff --git a/egs/callhome/v1/conf/eda/train.yaml b/egs/callhome/v1/conf/eda/train.yaml new file mode 100644 index 0000000..1966e2e --- /dev/null +++ b/egs/callhome/v1/conf/eda/train.yaml @@ -0,0 +1,27 @@ +# training options +sampling_rate: 8000 +frame_size: 200 +frame_shift: 80 +model_type: Transformer +max_epochs: 25 +gradclip: 5 +batchsize: 64 +hidden_size: 256 +num_frames: 500 +input_transform: logmel23_mn +optimizer: noam +context_size: 7 +subsampling: 10 +noam_scale: 1.0 +gradient_accumulation_steps: 1 +transformer_encoder_n_heads: 4 +transformer_encoder_n_layers: 4 +transformer_encoder_dropout: 0.1 +noam_warmup_steps: 100000 +use_attractor: True +shuffle: True +attractor_loss_ratio: 1.0 +attractor_encoder_dropout: 0.1 +attractor_decoder_dropout: 0.1 +seed: 777 +gpu: 0 diff --git a/egs/callhome/v1/conf/eda/train_2spk.yaml b/egs/callhome/v1/conf/eda/train_2spk.yaml new file mode 100644 index 0000000..b494b17 --- /dev/null +++ b/egs/callhome/v1/conf/eda/train_2spk.yaml @@ -0,0 +1,28 @@ +# training options +sampling_rate: 8000 +frame_size: 200 +frame_shift: 80 +model_type: Transformer +max_epochs: 100 +gradclip: 5 +batchsize: 64 +hidden_size: 256 
+num_frames: 500 +num_speakers: 2 +input_transform: logmel23_mn +optimizer: noam +context_size: 7 +subsampling: 10 +noam_scale: 1.0 +gradient_accumulation_steps: 1 +transformer_encoder_n_heads: 4 +transformer_encoder_n_layers: 4 +transformer_encoder_dropout: 0.1 +noam_warmup_steps: 100000 +use_attractor: True +shuffle: True +attractor_loss_ratio: 1.0 +attractor_encoder_dropout: 0.1 +attractor_decoder_dropout: 0.1 +seed: 777 +gpu: 0 diff --git a/egs/callhome/v1/run_eda.sh b/egs/callhome/v1/run_eda.sh new file mode 100755 index 0000000..aba91f5 --- /dev/null +++ b/egs/callhome/v1/run_eda.sh @@ -0,0 +1,249 @@ +#!/bin/bash + +# Copyright 2019-2020 Hitachi, Ltd. (author: Yusuke Fujita, Shota Horiguchi) +# Licensed under the MIT license. +# +stage=0 + +# The datasets for training must be formatted as kaldi data directory. +# Also, make sure the audio files in wav.scp are 'regular' wav files. +# Including piped commands in wav.scp makes training very slow +train_2spk_set=data/simu/data/swb_sre_tr_ns2_beta2_100000 +valid_2spk_set=data/simu/data/swb_sre_cv_ns2_beta2_500 +train_set=data/simu/data/swb_sre_tr_ns1n2n3n4_beta2n2n5n9_100000 +valid_set=data/simu/data/swb_sre_cv_ns1n2n3n4_beta2n2n5n9_500 +adapt_set=data/eval/callhome1_spkall +adapt_valid_set=data/eval/callhome2_spkall + +# Base config files for {train,infer}.py +train_2spk_config=conf/eda/train_2spk.yaml +train_config=conf/eda/train.yaml +infer_config=conf/eda/infer.yaml +adapt_config=conf/eda/adapt.yaml + +# Additional arguments passed to {train,infer}.py. +# You need not edit the base config files above +train_2spk_args= +train_args= +infer_args= +adapt_args= + +# 2-speaker model averaging options +average_2spk_start=91 +average_2spk_end=100 + +# Model averaging options +average_start=16 +average_end=25 + +# Adapted model averaging options +adapt_average_start=91 +adapt_average_end=100 + +# Resume training from snapshot at this epoch +# TODO: not tested +resume=-1 + +# Debug purpose +debug= + +. path.sh +. cmd.sh +. 
parse_options.sh || exit + +set -eu + +if [ "$debug" != "" ]; then + # debug mode + train_set=data/simu/data/swb_sre_tr_ns2_beta2_1000 + train_config=conf/debug/train.yaml + average_start=3 + average_end=5 + adapt_config=conf/debug/adapt.yaml + adapt_average_start=6 + adapt_average_end=10 +fi + +# Parse the config file to set bash variables like: $train_frame_shift, $infer_gpu +eval `yaml2bash.py --prefix train $train_config` +eval `yaml2bash.py --prefix infer $infer_config` + +# Append gpu reservation flag to the queuing command +if [ $train_gpu -le 0 ]; then + train_cmd+=" --gpu 1" +fi +if [ $infer_gpu -le 0 ]; then + infer_cmd+=" --gpu 1" +fi + +# Build directry names for an experiment +# - Training (2 speakers) +# exp/diarize/model/{train_2spk_id}.{valid_2spk_id}.{train_2spk_config_id} +# - Training (1-4 speakers, finetune from the 2-speaker model) +# exp/diarize/model/{train_id}.{valid_id}.{train_config_id} +# - Adapation from non-adapted averaged model +# exp/diarize/model/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id} +# - Decoding +# exp/diarize/infer/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id}.{infer_config_id} +# - Scoring +# exp/diarize/scoring/{train_id}.{valid_id}.{train_config_id}.{avgid}.{adapt_config_id}.{infer_config_id} +train_2spk_id=$(basename $train_2spk_set) +valid_2spk_id=$(basename $valid_2spk_set) +train_id=$(basename $train_set) +valid_id=$(basename $valid_set) +train_2spk_config_id=$(echo $train_2spk_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') +train_config_id=$(echo $train_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') +infer_config_id=$(echo $infer_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') +adapt_config_id=$(echo $adapt_config | sed -e 's%conf/%%' -e 's%/%_%' -e 's%\.yaml$%%') + +# Additional arguments are added to config_id +train_2spk_config_id+=$(echo $train_2spk_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') +train_config_id+=$(echo $train_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') +infer_config_id+=$(echo $infer_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') +adapt_config_id+=$(echo $adapt_args | sed -e 's/\-\-/_/g' -e 's/=//g' -e 's/ \+//g') + +model_2spk_id=$train_2spk_id.$valid_2spk_id.$train_2spk_config_id +model_2spk_dir=exp/diarize/model/$model_2spk_id +if [ $stage -le 1 ]; then + echo "training 2-speaker model at $model_2spk_dir." + if [ -d $model_2spk_dir ]; then + echo "$model_2spk_dir already exists. " + echo " if you want to retry, please remove it." + exit 1 + fi + work=$model_2spk_dir/.work + mkdir -p $work + $train_cmd $work/train.log \ + train.py \ + -c $train_2spk_config \ + $train_2spk_args \ + $train_2spk_set $valid_2spk_set $model_2spk_dir \ + || exit 1 +fi + +ave_id=avg${average_2spk_start}-${average_2spk_end} +if [ $stage -le 2 ]; then + echo "averaging model parameters into $model_2spk_dir/$ave_id.nnet.npz" + if [ -s $model_2spk_dir/$ave_id.nnet.npz ]; then + echo "$model_2spk_dir/$ave_id.nnet.npz already exists. " + echo " if you want to retry, please remove it." + exit 1 + fi + models=`eval echo $model_2spk_dir/snapshot_epoch-{$average_2spk_start..$average_2spk_end}` + model_averaging.py $model_2spk_dir/$ave_id.nnet.npz $models || exit 1 +fi + +model_id=$train_id.$valid_id.$train_config_id +model_dir=exp/diarize/model/$model_id +if [ $stage -le 3 ]; then + echo "training model at $model_dir." + if [ -d $model_dir ]; then + echo "$model_dir already exists. " + echo " if you want to retry, please remove it." 
+ exit 1 + fi + work=$model_dir/.work + mkdir -p $work + $train_cmd $work/train.log \ + train.py \ + -c $train_config \ + $train_args \ + --initmodel $model_2spk_dir/$ave_id.nnet.npz \ + $train_set $valid_set $model_dir \ + || exit 1 +fi + +ave_id=avg${average_start}-${average_end} +if [ $stage -le 4 ]; then + echo "averaging model parameters into $model_dir/$ave_id.nnet.npz" + if [ -s $model_dir/$ave_id.nnet.npz ]; then + echo "$model_dir/$ave_id.nnet.npz already exists. " + echo " if you want to retry, please remove it." + exit 1 + fi + models=`eval echo $model_dir/snapshot_epoch-{$average_start..$average_end}` + model_averaging.py $model_dir/$ave_id.nnet.npz $models || exit 1 +fi + +adapt_model_dir=exp/diarize/model/$model_id.$ave_id.$adapt_config_id +if [ $stage -le 5 ]; then + echo "adapting model at $adapt_model_dir" + if [ -d $adapt_model_dir ]; then + echo "$adapt_model_dir already exists. " + echo " if you want to retry, please remove it." + exit 1 + fi + work=$adapt_model_dir/.work + mkdir -p $work + $train_cmd $work/train.log \ + train.py \ + -c $adapt_config \ + $adapt_args \ + --initmodel $model_dir/$ave_id.nnet.npz \ + $adapt_set $adapt_valid_set $adapt_model_dir \ + || exit 1 +fi + +adapt_ave_id=avg${adapt_average_start}-${adapt_average_end} +if [ $stage -le 6 ]; then + echo "averaging models into $adapt_model_dir/$adapt_ave_id.nnet.gz" + if [ -s $adapt_model_dir/$adapt_ave_id.nnet.npz ]; then + echo "$adapt_model_dir/$adapt_ave_id.nnet.npz already exists." + echo " if you want to retry, please remove it." + exit 1 + fi + models=`eval echo $adapt_model_dir/snapshot_epoch-{$adapt_average_start..$adapt_average_end}` + model_averaging.py $adapt_model_dir/$adapt_ave_id.nnet.npz $models || exit 1 +fi + +infer_dir=exp/diarize/infer/$model_id.$ave_id.$adapt_config_id.$adapt_ave_id.$infer_config_id +if [ $stage -le 7 ]; then + echo "inference at $infer_dir" + if [ -d $infer_dir ]; then + echo "$infer_dir already exists. " + echo " if you want to retry, please remove it." + exit 1 + fi + for dset in callhome2_spkall; do + work=$infer_dir/$dset/.work + mkdir -p $work + $train_cmd $work/infer.log \ + infer.py -c $infer_config \ + data/eval/${dset} \ + $adapt_model_dir/$adapt_ave_id.nnet.npz \ + $infer_dir/$dset \ + || exit 1 + done +fi + +scoring_dir=exp/diarize/scoring/$model_id.$ave_id.$adapt_config_id.$adapt_ave_id.$infer_config_id +if [ $stage -le 8 ]; then + echo "scoring at $scoring_dir" + if [ -d $scoring_dir ]; then + echo "$scoring_dir already exists. " + echo " if you want to retry, please remove it." + exit 1 + fi + for dset in callhome2_spkall; do + work=$scoring_dir/$dset/.work + mkdir -p $work + find $infer_dir/$dset -iname "*.h5" > $work/file_list_$dset + for med in 1 11; do + for th in 0.3 0.4 0.5 0.6 0.7; do + make_rttm.py --median=$med --threshold=$th \ + --frame_shift=$infer_frame_shift --subsampling=$infer_subsampling --sampling_rate=$infer_sampling_rate \ + $work/file_list_$dset $scoring_dir/$dset/hyp_${th}_$med.rttm + md-eval.pl -c 0.25 \ + -r data/eval/$dset/rttm \ + -s $scoring_dir/$dset/hyp_${th}_$med.rttm > $scoring_dir/$dset/result_th${th}_med${med}_collar0.25 2>/dev/null || exit + done + done + done +fi + +if [ $stage -le 9 ]; then + for dset in callhome2_spkall; do + best_score.sh $scoring_dir/$dset + done +fi +echo "Finished !" 
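
For reference, the inference stage driven by `run_eda.sh` above ends up in `estimate_sequential` (see the `models.py` hunk earlier in this patch), where the number of speakers is decided by thresholding the attractor existence probabilities and the frame-wise activities are sigmoids of embedding-attractor dot products. Below is a minimal NumPy sketch of that decoding rule, with hypothetical toy inputs standing in for the encoder/EDA outputs; only the thresholding and the matmul mirror the patch, and names such as `eda_decode` are illustrative, not part of the repository.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def eda_decode(emb, attractors, existence_logits, th=0.5, n_speakers=None):
    """Turn EDA outputs into frame-wise speaker activities.

    emb:              (T, D) frame embeddings from the Transformer encoder
    attractors:       (max_n_speakers, D) attractors from the LSTM decoder
    existence_logits: (max_n_speakers,) outputs of the `counter` linear layer
    th:               --attractor-threshold; keep attractors until the first
                      existence probability drops below this value
    n_speakers:       oracle speaker count; if given, the threshold is ignored
    """
    probs = sigmoid(existence_logits)
    if n_speakers is None:
        below = np.where(probs < th)[0]
        n_speakers = int(below[0]) if below.size else len(probs)
    # frame-wise posteriors: sigmoid of dot products with the kept attractors
    ys = sigmoid(emb @ attractors[:n_speakers].T)  # (T, n_speakers)
    return ys, probs

# Toy usage with random stand-ins for the network outputs.
rng = np.random.default_rng(0)
emb = rng.standard_normal((100, 256)).astype(np.float32)
att = rng.standard_normal((15, 256)).astype(np.float32)
logits = np.array([4.0, 3.0, 2.5, -3.0] + [-5.0] * 11)  # ~3 "existing" attractors
ys, probs = eda_decode(emb, att, logits)
print(ys.shape)  # (100, 3)
```

When `--num-speakers` is given at inference time (as in the 2-speaker recipes), the threshold is bypassed and the first `n_speakers` attractors are used directly, which matches the `n_speakers is not None` branch in `estimate_sequential`.
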
diff --git a/egs/callhome/v1/run_prepare_shared_eda.sh b/egs/callhome/v1/run_prepare_shared_eda.sh new file mode 100755 index 0000000..57617d8 --- /dev/null +++ b/egs/callhome/v1/run_prepare_shared_eda.sh @@ -0,0 +1,236 @@ +#!/bin/bash + +# Copyright 2019 Hitachi, Ltd. (author: Yusuke Fujita, Shota Horiguchi) +# Licensed under the MIT license. +# +# This script prepares kaldi-style data sets shared with different experiments +# - data/xxxx +# callhome, sre, swb2, and swb_cellular datasets +# - data/simu_${simu_outputs} +# simulation mixtures generated with various options + +stage=0 + +# Modify corpus directories +# - callhome_dir +# CALLHOME (LDC2001S97) +# - swb2_phase1_train +# Switchboard-2 Phase 1 (LDC98S75) +# - data_root +# LDC99S79, LDC2002S06, LDC2001S13, LDC2004S07, +# LDC2006S44, LDC2011S01, LDC2011S04, LDC2011S09, +# LDC2011S10, LDC2012S01, LDC2011S05, LDC2011S08 +# - musan_root +# MUSAN corpus (https://www.openslr.org/17/) +callhome_dir=/export/corpora/NIST/LDC2001S97 +swb2_phase1_train=/export/corpora/LDC/LDC98S75 +data_root=/export/corpora5/LDC +musan_root=/export/corpora/JHU/musan +# Modify simulated data storage area. +# This script distributes simulated data under these directories +simu_actual_dirs=( +/export/c05/$USER/diarization-data +/export/c08/$USER/diarization-data +/export/c09/$USER/diarization-data +) + +# data preparation options +max_jobs_run=4 +sad_num_jobs=30 +sad_opts="--extra-left-context 79 --extra-right-context 21 --frames-per-chunk 150 --extra-left-context-initial 0 --extra-right-context-final 0 --acwt 0.3" +sad_graph_opts="--min-silence-duration=0.03 --min-speech-duration=0.3 --max-speech-duration=10.0" +sad_priors_opts="--sil-scale=0.1" + +# simulation options +simu_opts_overlap=yes +simu_opts_num_speaker_array=(1 2 3 4) +simu_opts_sil_scale_array=(2 2 5 9) +simu_opts_rvb_prob=0.5 +simu_opts_num_train=100000 +simu_opts_min_utts=10 +simu_opts_max_utts=20 + +. path.sh +. cmd.sh +. parse_options.sh || exit + +if [ $stage -le 0 ]; then + echo "prepare kaldi-style datasets" + # Prepare CALLHOME dataset. This will be used to evaluation. + if ! validate_data_dir.sh --no-text --no-feats data/callhome1_spkall \ + || ! validate_data_dir.sh --no-text --no-feats data/callhome2_spkall; then + # imported from https://github.com/kaldi-asr/kaldi/blob/master/egs/callhome_diarization/v1 + local/make_callhome.sh $callhome_dir data + # Generate two-speaker subsets + for dset in callhome1 callhome2; do + # Extract two-speaker recordings in wav.scp + copy_data_dir.sh data/${dset} data/${dset}_spkall + # Regenerate segments file from fullref.rttm + # $2: recid, $4: start_time, $5: duration, $8: speakerid + awk '{printf "%s_%s_%07d_%07d %s %.2f %.2f\n", \ + $2, $8, $4*100, ($4+$5)*100, $2, $4, $4+$5}' \ + data/callhome/fullref.rttm | sort > data/${dset}_spkall/segments + utils/fix_data_dir.sh data/${dset}_spkall + # Speaker ID is '[recid]_[speakerid] + awk '{split($1,A,"_"); printf "%s %s_%s\n", $1, A[1], A[2]}' \ + data/${dset}_spkall/segments > data/${dset}_spkall/utt2spk + utils/fix_data_dir.sh data/${dset}_spkall + # Generate rttm files for scoring + steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + data/${dset}_spkall/utt2spk data/${dset}_spkall/segments \ + data/${dset}_spkall/rttm + utils/data/get_reco2dur.sh data/${dset}_spkall + done + fi + # Prepare a collection of NIST SRE and SWB data. This will be used to train, + if ! 
validate_data_dir.sh --no-text --no-feats data/swb_sre_comb; then + local/make_sre.sh $data_root data + # Prepare SWB for x-vector DNN training. + local/make_swbd2_phase1.pl $swb2_phase1_train \ + data/swbd2_phase1_train + local/make_swbd2_phase2.pl $data_root/LDC99S79 \ + data/swbd2_phase2_train + local/make_swbd2_phase3.pl $data_root/LDC2002S06 \ + data/swbd2_phase3_train + local/make_swbd_cellular1.pl $data_root/LDC2001S13 \ + data/swbd_cellular1_train + local/make_swbd_cellular2.pl $data_root/LDC2004S07 \ + data/swbd_cellular2_train + # Combine swb and sre data + utils/combine_data.sh data/swb_sre_comb \ + data/swbd_cellular1_train data/swbd_cellular2_train \ + data/swbd2_phase1_train \ + data/swbd2_phase2_train data/swbd2_phase3_train data/sre + fi + # musan data. "back-ground + if ! validate_data_dir.sh --no-text --no-feats data/musan_noise_bg; then + local/make_musan.sh $musan_root data + utils/copy_data_dir.sh data/musan_noise data/musan_noise_bg + awk '{if(NR>1) print $1,$1}' $musan_root/noise/free-sound/ANNOTATIONS > data/musan_noise_bg/utt2spk + utils/fix_data_dir.sh data/musan_noise_bg + fi + # simu rirs 8k + if ! validate_data_dir.sh --no-text --no-feats data/simu_rirs_8k; then + mkdir -p data/simu_rirs_8k + if [ ! -e sim_rir_8k.zip ]; then + wget --no-check-certificate http://www.openslr.org/resources/26/sim_rir_8k.zip + fi + unzip sim_rir_8k.zip -d data/sim_rir_8k + find $PWD/data/sim_rir_8k -iname "*.wav" \ + | awk '{n=split($1,A,/[\/\.]/); print A[n-3]"_"A[n-1], $1}' \ + | sort > data/simu_rirs_8k/wav.scp + awk '{print $1, $1}' data/simu_rirs_8k/wav.scp > data/simu_rirs_8k/utt2spk + utils/fix_data_dir.sh data/simu_rirs_8k + fi + # Automatic segmentation using pretrained SAD model + # it will take one day using 30 CPU jobs: + # make_mfcc: 1 hour, compute_output: 18 hours, decode: 0.5 hours + sad_nnet_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a + sad_work_dir=exp/segmentation_1a/tdnn_stats_asr_sad_1a + if ! validate_data_dir.sh --no-text $sad_work_dir/swb_sre_comb_seg; then + if [ ! -d exp/segmentation_1a ]; then + wget http://kaldi-asr.org/models/4/0004_tdnn_stats_asr_sad_1a.tar.gz + tar zxf 0004_tdnn_stats_asr_sad_1a.tar.gz + fi + steps/segmentation/detect_speech_activity.sh \ + --nj $sad_num_jobs \ + --graph-opts "$sad_graph_opts" \ + --transform-probs-opts "$sad_priors_opts" $sad_opts \ + data/swb_sre_comb $sad_nnet_dir mfcc_hires $sad_work_dir \ + $sad_work_dir/swb_sre_comb || exit 1 + fi + # Extract >1.5 sec segments and split into train/valid sets + if ! 
validate_data_dir.sh --no-text --no-feats data/swb_sre_cv; then + copy_data_dir.sh data/swb_sre_comb data/swb_sre_comb_seg + awk '$4-$3>1.5{print;}' $sad_work_dir/swb_sre_comb_seg/segments > data/swb_sre_comb_seg/segments + cp $sad_work_dir/swb_sre_comb_seg/{utt2spk,spk2utt} data/swb_sre_comb_seg + fix_data_dir.sh data/swb_sre_comb_seg + utils/subset_data_dir_tr_cv.sh data/swb_sre_comb_seg data/swb_sre_tr data/swb_sre_cv + fi +fi + +simudir=data/simu +if [ $stage -le 1 ]; then + echo "simulation of mixture" + mkdir -p $simudir/.work + random_mixture_cmd=random_mixture_nooverlap.py + make_mixture_cmd=make_mixture_nooverlap.py + if [ "$simu_opts_overlap" == "yes" ]; then + random_mixture_cmd=random_mixture.py + make_mixture_cmd=make_mixture.py + fi + + for ((i=0; i<${#simu_opts_sil_scale_array[@]}; ++i)); do + simu_opts_num_speaker=${simu_opts_num_speaker_array[i]} + simu_opts_sil_scale=${simu_opts_sil_scale_array[i]} + for dset in swb_sre_tr swb_sre_cv; do + if [ "$dset" == "swb_sre_tr" ]; then + n_mixtures=${simu_opts_num_train} + else + n_mixtures=500 + fi + simuid=${dset}_ns${simu_opts_num_speaker}_beta${simu_opts_sil_scale}_${n_mixtures} + # check if you have the simulation + if ! validate_data_dir.sh --no-text --no-feats $simudir/data/$simuid; then + # random mixture generation + $train_cmd $simudir/.work/random_mixture_$simuid.log \ + $random_mixture_cmd --n_speakers $simu_opts_num_speaker --n_mixtures $n_mixtures \ + --speech_rvb_probability $simu_opts_rvb_prob \ + --sil_scale $simu_opts_sil_scale \ + data/$dset data/musan_noise_bg data/simu_rirs_8k \ + \> $simudir/.work/mixture_$simuid.scp + nj=100 + mkdir -p $simudir/wav/$simuid + # distribute simulated data to $simu_actual_dir + split_scps= + for n in $(seq $nj); do + split_scps="$split_scps $simudir/.work/mixture_$simuid.$n.scp" + mkdir -p $simudir/.work/data_$simuid.$n + actual=${simu_actual_dirs[($n-1)%${#simu_actual_dirs[@]}]}/$simudir/wav/$simuid/$n + mkdir -p $actual + ln -nfs $actual $simudir/wav/$simuid/$n + done + utils/split_scp.pl $simudir/.work/mixture_$simuid.scp $split_scps || exit 1 + + $simu_cmd --max-jobs-run 32 JOB=1:$nj $simudir/.work/make_mixture_$simuid.JOB.log \ + $make_mixture_cmd --rate=8000 \ + $simudir/.work/mixture_$simuid.JOB.scp \ + $simudir/.work/data_$simuid.JOB $simudir/wav/$simuid/JOB + utils/combine_data.sh $simudir/data/$simuid $simudir/.work/data_$simuid.* + steps/segmentation/convert_utt2spk_and_segments_to_rttm.py \ + $simudir/data/$simuid/utt2spk $simudir/data/$simuid/segments \ + $simudir/data/$simuid/rttm + utils/data/get_reco2dur.sh $simudir/data/$simuid + fi + simuid_concat=${dset}_ns"$(IFS="n"; echo "${simu_opts_num_speaker_array[*]}")"_beta"$(IFS="n"; echo "${simu_opts_sil_scale_array[*]}")"_${n_mixtures} + mkdir -p $simudir/data/$simuid_concat + for f in `ls -F $simudir/data/$simuid | grep -v "/"`; do + cat $simudir/data/$simuid/$f >> $simudir/data/$simuid_concat/$f + done + done + done +fi + +if [ $stage -le 3 ]; then + # compose eval/callhome2_spkall + eval_set=data/eval/callhome2_spkall + if ! 
validate_data_dir.sh --no-text --no-feats $eval_set; then + utils/copy_data_dir.sh data/callhome2_spkall $eval_set + cp data/callhome2_spkall/rttm $eval_set/rttm + awk -v dstdir=wav/eval/callhome2_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome2_spkall/wav.scp > $eval_set/wav.scp + mkdir -p wav/eval/callhome2_spkall + wav-copy scp:data/callhome2_spkall/wav.scp scp:$eval_set/wav.scp + utils/data/get_reco2dur.sh $eval_set + fi + + # compose eval/callhome1_spkall + adapt_set=data/eval/callhome1_spkall + if ! validate_data_dir.sh --no-text --no-feats $adapt_set; then + utils/copy_data_dir.sh data/callhome1_spkall $adapt_set + cp data/callhome1_spkall/rttm $adapt_set/rttm + awk -v dstdir=wav/eval/callhome1_spkall '{print $1, dstdir"/"$1".wav"}' data/callhome1_spkall/wav.scp > $adapt_set/wav.scp + mkdir -p wav/eval/callhome1_spkall + wav-copy scp:data/callhome1_spkall/wav.scp scp:$adapt_set/wav.scp + utils/data/get_reco2dur.sh $adapt_set + fi +fi diff --git a/egs/mini_librispeech/v1/conf/eda/infer.yaml b/egs/mini_librispeech/v1/conf/eda/infer.yaml new file mode 100644 index 0000000..97eb177 --- /dev/null +++ b/egs/mini_librispeech/v1/conf/eda/infer.yaml @@ -0,0 +1,18 @@ +# inference options +sampling_rate: 8000 +frame_size: 200 +frame_shift: 80 +model_type: Transformer +hidden_size: 256 +num_speakers: 2 +input_transform: logmel23_mn +context_size: 7 +subsampling: 10 +chunk_size: 2000 +transformer_encoder_n_heads: 4 +transformer_encoder_n_layers: 2 +use_attractor: True +shuffle: True +attractor_encoder_dropout: 0.1 +attractor_decoder_dropout: 0.1 +gpu: 0 diff --git a/egs/mini_librispeech/v1/conf/eda/train.yaml b/egs/mini_librispeech/v1/conf/eda/train.yaml new file mode 100644 index 0000000..4b0cc07 --- /dev/null +++ b/egs/mini_librispeech/v1/conf/eda/train.yaml @@ -0,0 +1,28 @@ +# training options +sampling_rate: 8000 +frame_size: 200 +frame_shift: 80 +model_type: Transformer +max_epochs: 10 +gradclip: 5 +batchsize: 64 +hidden_size: 256 +num_frames: 500 +num_speakers: 2 +input_transform: logmel23_mn +optimizer: noam +context_size: 7 +subsampling: 10 +noam_scale: 1.0 +gradient_accumulation_steps: 1 +transformer_encoder_n_heads: 4 +transformer_encoder_n_layers: 2 +transformer_encoder_dropout: 0.1 +noam_warmup_steps: 25000 +use_attractor: True +shuffle: True +attractor_loss_ratio: 1.0 +attractor_encoder_dropout: 0.1 +attractor_decoder_dropout: 0.1 +seed: 777 +gpu: 0 diff --git a/egs/mini_librispeech/v1/run.sh b/egs/mini_librispeech/v1/run.sh index bebeaab..4315e5b 100755 --- a/egs/mini_librispeech/v1/run.sh +++ b/egs/mini_librispeech/v1/run.sh @@ -14,6 +14,9 @@ valid_set=data/simu/data/dev_clean_2_ns2_beta2_500 # Base config files for {train,infer}.py train_config=conf/train.yaml infer_config=conf/infer.yaml +# If you want to use EDA-EEND, uncommend two lines below. +# train_config=conf/eda/train.yaml +# infer_config=conf/eda/infer.yaml # Additional arguments passed to {train,infer}.py. # You need not edit the base config files above
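
Looking back at the `models.py` hunk, `batch_pit_loss_faster` and `batch_pit_n_speaker_loss` avoid evaluating the binary cross-entropy once per permutation: the labels are rolled along the speaker axis so that a (C, C) table of pairwise losses can be built in C passes, after which each permutation's cost is just a lookup. Below is a small self-contained NumPy sketch of that trick for a single sample; the function names and the toy check are mine, and the -1 padding handling and the masking of invalid permutations via `select_perm_indices` are omitted.

```python
import numpy as np
from itertools import permutations

def bce(y_logit, t):
    """Element-wise sigmoid cross-entropy with logits (numerically stable)."""
    return np.maximum(y_logit, 0) - y_logit * t + np.log1p(np.exp(-np.abs(y_logit)))

def pit_loss_roll(ys, ts):
    """Permutation-invariant BCE for one sample via the roll trick.

    ys: (T, C) pre-activations, ts: (T, C) 0/1 labels.
    losses[i, j] is the loss of pairing predicted speaker i with
    reference speaker (i + j) % C, so every permutation's cost can be
    looked up instead of being recomputed.
    """
    T, C = ts.shape
    # losses[i, j]: sum over time of BCE(ys[:, i], roll(ts, -j)[:, i])
    losses = np.stack(
        [bce(ys, np.roll(ts, -j, axis=1)).sum(axis=0) for j in range(C)],
        axis=1)                                   # (C, C)
    perms = np.array(list(permutations(range(C))))
    y_ind = np.arange(C)
    t_ind = np.mod(perms - y_ind, C)              # column index per permutation
    per_perm = losses[y_ind, t_ind].mean(axis=1)  # (n_perms,)
    best = per_perm.argmin()
    return per_perm[best] / T, ts[:, perms[best]]

# Toy check: confident logits built from reordered labels recover the reordering.
t = (np.random.default_rng(0).random((10, 3)) > 0.5).astype(float)
y = 10 * (t[:, [2, 0, 1]] - 0.5)      # speakers permuted as [2, 0, 1]
loss, labels = pit_loss_roll(y, t)
print(loss, np.allclose(labels, t[:, [2, 0, 1]]))
```

The patch additionally pads labels and predictions to the batch-wide maximum speaker count and masks out permutations that are invalid for samples with fewer speakers, which is what allows a single `TransformerEDADiarization` model to be trained on the mixed 1-4-speaker simulations prepared by `run_prepare_shared_eda.sh`.
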