/
dprnn_separator.py
131 lines (107 loc) · 4.31 KB
/
dprnn_separator.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from collections import OrderedDict
from typing import Dict, List, Optional, Tuple, Union
import torch
from packaging.version import parse as V
from torch_complex.tensor import ComplexTensor
from espnet2.enh.layers.complex_utils import is_complex
from espnet2.enh.layers.dprnn import DPRNN, merge_feature, split_feature
from espnet2.enh.separator.abs_separator import AbsSeparator
# Feature flag: True when the installed torch version is 1.9.0 or newer.
# NOTE(review): not referenced in the visible part of this file — confirm
# whether it is still needed.
is_torch_1_9_plus = V(torch.__version__) >= V("1.9.0")
class DPRNNSeparator(AbsSeparator):
    """Mask-based separator built around a Dual-Path RNN (DPRNN).

    One mask is estimated per output stream (every speaker, plus one extra
    noise stream when ``predict_noise`` is enabled) and multiplied with the
    input feature.
    """

    def __init__(
        self,
        input_dim: int,
        rnn_type: str = "lstm",
        bidirectional: bool = True,
        num_spk: int = 2,
        predict_noise: bool = False,
        nonlinear: str = "relu",
        layer: int = 3,
        unit: int = 512,
        segment_size: int = 20,
        dropout: float = 0.0,
    ):
        """Build the DPRNN network and the mask activation.

        Args:
            input_dim: input feature dimension
            rnn_type: RNN cell type, select from 'RNN', 'LSTM' and 'GRU';
                forwarded unchanged to DPRNN.
            bidirectional: whether the inter-chunk RNN layers are
                bidirectional.
            num_spk: number of speakers
            predict_noise: whether to additionally estimate a noise signal
                (adds one output stream).
            nonlinear: activation applied to the network output to obtain
                the masks, select from 'relu', 'tanh', 'sigmoid'
            layer: number of stacked RNN layers. Default is 3.
            unit: dimension of the hidden state.
            segment_size: dual-path segment size
            dropout: dropout ratio. Default is 0.

        Raises:
            ValueError: if ``nonlinear`` is not one of the supported names.
        """
        super().__init__()

        self._num_spk = num_spk
        self.predict_noise = predict_noise
        self.segment_size = segment_size

        # One output stream per speaker, plus one for the noise if requested.
        if self.predict_noise:
            self.num_outputs = self.num_spk + 1
        else:
            self.num_outputs = self.num_spk

        # The network emits num_outputs values for every input feature bin.
        self.dprnn = DPRNN(
            input_size=input_dim,
            hidden_size=unit,
            output_size=input_dim * self.num_outputs,
            rnn_type=rnn_type,
            num_layers=layer,
            bidirectional=bidirectional,
            dropout=dropout,
        )

        supported = ("sigmoid", "relu", "tanh")
        if nonlinear not in supported:
            raise ValueError("Not supporting nonlinear={}".format(nonlinear))
        activation_classes = {
            "sigmoid": torch.nn.Sigmoid,
            "relu": torch.nn.ReLU,
            "tanh": torch.nn.Tanh,
        }
        self.nonlinear = activation_classes[nonlinear]()

    def forward(
        self,
        input: Union[torch.Tensor, ComplexTensor],
        ilens: torch.Tensor,
        additional: Optional[Dict] = None,
    ) -> Tuple[List[Union[torch.Tensor, ComplexTensor]], torch.Tensor, OrderedDict]:
        """Estimate the masks and apply them to the input feature.

        Args:
            input (torch.Tensor or ComplexTensor): Encoded feature [B, T, N]
            ilens (torch.Tensor): input lengths [Batch]
            additional (Dict or None): other data included in model
                NOTE: not used in this model

        Returns:
            masked (List[Union(torch.Tensor, ComplexTensor)]): [(B, T, N), ...]
                the input multiplied by each speaker mask
            ilens (torch.Tensor): (B,), returned unchanged
            others predicted data, e.g. masks: OrderedDict[
                'mask_spk1': torch.Tensor(Batch, Frames, Freq),
                'mask_spk2': torch.Tensor(Batch, Frames, Freq),
                ...
                'mask_spkn': torch.Tensor(Batch, Frames, Freq),
                'noise1': input multiplied by the noise mask
                    (only present when predict_noise is enabled),
            ]
        """
        # Masks are estimated from the magnitude when the input is complex;
        # a real-valued input is used directly.
        magnitude = abs(input) if is_complex(input) else input
        batch, frames, feat_dim = magnitude.shape

        # B, T, N -> B, N, T -> overlapping segments of shape B, N, L, K
        chunks, padding = split_feature(
            magnitude.transpose(1, 2), segment_size=self.segment_size
        )

        # B, N*num_outputs, L, K -> B, N*num_outputs, T
        estimate = merge_feature(self.dprnn(chunks), padding)

        # B, T, N*num_outputs -> B, T, N, num_outputs
        estimate = estimate.transpose(1, 2).view(
            batch, frames, feat_dim, self.num_outputs
        )

        # One (B, T, N) mask per output stream.
        speaker_masks = self.nonlinear(estimate).unbind(dim=3)
        if self.predict_noise:
            # The last stream is reserved for the noise estimate.
            *speaker_masks, noise_mask = speaker_masks

        masked = [input * mask for mask in speaker_masks]

        others = OrderedDict()
        for spk_idx, mask in enumerate(speaker_masks, start=1):
            others["mask_spk{}".format(spk_idx)] = mask
        if self.predict_noise:
            others["noise1"] = input * noise_mask

        return masked, ilens, others

    @property
    def num_spk(self):
        """Number of speakers this separator was configured with."""
        return self._num_spk