-
Notifications
You must be signed in to change notification settings - Fork 48
/
model.py
968 lines (819 loc) · 37.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
#!/usr/bin/env python
"""
model.py for harmonic-plus-noise NSF with trainable sinc filter
version: 9
"""
from __future__ import absolute_import
from __future__ import print_function
import sys
import numpy as np
import torch
import torch.nn as torch_nn
import torch.nn.functional as torch_nn_func
import core_scripts.other_tools.debug as nii_debug
import config as prj_conf
__author__ = "Xin Wang"
__email__ = "wangxin@nii.ac.jp"
__copyright__ = "Copyright 2020, Xin Wang"
##############
# Building blocks (torch.nn modules + dimension operation)
#
# For blstm
class BLSTMLayer(torch_nn.Module):
""" Wrapper over lstm
Input tensor: (batchsize=1, length, dim_in)
Output tensor: (batchsize=1, length, dim_out)
Recurrency is conducted along "length"
"""
def __init__(self, input_dim, output_dim):
super(BLSTMLayer, self).__init__()
if output_dim % 2 != 0:
print("Output_dim of BLSTMLayer is {:d}".format(output_dim))
print("BLSTMLayer expects a layer size of even number")
sys.exit(1)
# bi-directional LSTM
self.l_blstm = torch_nn.LSTM(input_dim, output_dim // 2, \
bidirectional=True)
def forward(self, x):
# permute to (length, batchsize=1, dim)
blstm_data, _ = self.l_blstm(x.permute(1, 0, 2))
# permute it backt to (batchsize=1, length, dim)
return blstm_data.permute(1, 0, 2)
#
# 1D dilated convolution that keep the input/output length
class Conv1dKeepLength(torch_nn.Conv1d):
""" Wrapper for causal convolution
Input tensor: (batchsize=1, length, dim_in)
Output tensor: (batchsize=1, length, dim_out)
https://github.com/pytorch/pytorch/issues/1333
Note: Tanh is optional
"""
def __init__(self, input_dim, output_dim, dilation_s, kernel_s,
causal = False, stride = 1, groups=1, bias=True, \
tanh = True, pad_mode='constant'):
super(Conv1dKeepLength, self).__init__(
input_dim, output_dim, kernel_s, stride=stride,
padding = 0, dilation = dilation_s, groups=groups, bias=bias)
self.pad_mode = pad_mode
self.causal = causal
# input & output length will be the same
if self.causal:
# left pad to make the convolution causal
self.pad_le = dilation_s * (kernel_s - 1)
self.pad_ri = 0
else:
# pad on both sizes
self.pad_le = dilation_s * (kernel_s - 1) // 2
self.pad_ri = dilation_s * (kernel_s - 1) - self.pad_le
if tanh:
self.l_ac = torch_nn.Tanh()
else:
self.l_ac = torch_nn.Identity()
def forward(self, data):
# permute to (batchsize=1, dim, length)
# add one dimension (batchsize=1, dim, ADDED_DIM, length)
# pad to ADDED_DIM
# squeeze and return to (batchsize=1, dim, length)
# https://github.com/pytorch/pytorch/issues/1333
x = torch_nn_func.pad(data.permute(0, 2, 1).unsqueeze(2), \
(self.pad_le, self.pad_ri, 0, 0),
mode = self.pad_mode).squeeze(2)
# tanh(conv1())
# permmute back to (batchsize=1, length, dim)
output = self.l_ac(super(Conv1dKeepLength, self).forward(x))
return output.permute(0, 2, 1)
#
# Moving average
class MovingAverage(Conv1dKeepLength):
""" Wrapper to define a moving average smoothing layer
Note: MovingAverage can be implemented using TimeInvFIRFilter too.
Here we define another Module dicrectly on Conv1DKeepLength
"""
def __init__(self, feature_dim, window_len, causal=False, \
pad_mode='replicate'):
super(MovingAverage, self).__init__(
feature_dim, feature_dim, 1, window_len, causal,
groups=feature_dim, bias=False, tanh=False, \
pad_mode=pad_mode)
# set the weighting coefficients
torch_nn.init.constant_(self.weight, 1/window_len)
# turn off grad for this layer
for p in self.parameters():
p.requires_grad = False
def forward(self, data):
return super(MovingAverage, self).forward(data)
#
# FIR filter layer
class TimeInvFIRFilter(Conv1dKeepLength):
""" Wrapper to define a FIR filter over Conv1d
Note: FIR Filtering is conducted on each dimension (channel)
independently: groups=channel_num in conv1d
"""
def __init__(self, feature_dim, filter_coef,
causal=True, flag_train=False):
""" __init__(self, feature_dim, filter_coef,
causal=True, flag_train=False)
feature_dim: dimension of input data
filter_coef: 1-D tensor of filter coefficients
causal: FIR is causal or not (default: true)
flag_train: whether train the filter coefficients (default false)
Input data: (batchsize=1, length, feature_dim)
Output data: (batchsize=1, length, feature_dim)
"""
super(TimeInvFIRFilter, self).__init__(
feature_dim, feature_dim, 1, filter_coef.shape[0], causal,
groups=feature_dim, bias=False, tanh=False)
if filter_coef.ndim == 1:
# initialize weight using provided filter_coef
with torch.no_grad():
tmp_coef = torch.zeros([feature_dim, 1,
filter_coef.shape[0]])
tmp_coef[:, 0, :] = filter_coef
tmp_coef = torch.flip(tmp_coef, dims=[2])
self.weight = torch.nn.Parameter(tmp_coef,
requires_grad=flag_train)
else:
print("TimeInvFIRFilter expects filter_coef to be 1-D tensor")
print("Please implement the code in __init__ if necessary")
sys.exit(1)
def forward(self, data):
return super(TimeInvFIRFilter, self).forward(data)
class TimeVarFIRFilter(torch_nn.Module):
""" TimeVarFIRFilter
Given sequences of filter coefficients and a signal, do filtering
Filter coefs: (batchsize=1, signal_length, filter_order = K)
Signal: (batchsize=1, signal_length, 1)
For batch 0:
For n in [1, sequence_length):
output(0, n, 1) = \sum_{k=1}^{K} signal(0, n-k, 1)*coef(0, n, k)
Note: filter coef (0, n, :) is only used to compute the output
at (0, n, 1)
"""
def __init__(self):
super(TimeVarFIRFilter, self).__init__()
def forward(self, signal, f_coef):
"""
Filter coefs: (batchsize=1, signal_length, filter_order = K)
Signal: (batchsize=1, signal_length, 1)
Output: (batchsize=1, signal_length, 1)
For n in [1, sequence_length):
output(0, n, 1)= \sum_{k=1}^{K} signal(0, n-k, 1)*coef(0, n, k)
This method may be not efficient:
Suppose signal [x_1, ..., x_N], filter [a_1, ..., a_K]
output [y_1, y_2, y_3, ..., y_N, *, * ... *]
= a_1 * [x_1, x_2, x_3, ..., x_N, 0, ..., 0]
+ a_2 * [ 0, x_1, x_2, x_3, ..., x_N, 0, ..., 0]
+ a_3 * [ 0, 0, x_1, x_2, x_3, ..., x_N, 0, ..., 0]
"""
signal_l = signal.shape[1]
order_k = f_coef.shape[-1]
# pad to (batchsize=1, signal_length + filter_order-1, dim)
padded_signal = torch_nn_func.pad(signal, (0, 0, 0, order_k - 1))
y = torch.zeros_like(signal)
# roll and weighted sum, only take [0:signal_length]
for k in range(order_k):
y += torch.roll(padded_signal, k, dims=1)[:, 0:signal_l, :] \
* f_coef[:, :, k:k+1]
# done
return y
# Sinc filter generator
class SincFilter(torch_nn.Module):
""" SincFilter
Given the cut-off-frequency, produce the low-pass and high-pass
windowed-sinc-filters.
If input cut-off-frequency is (batchsize=1, signal_length, 1),
output filter coef is (batchsize=1, signal_length, filter_order).
For each time step in [1, signal_length), we calculate one
filter for low-pass sinc filter and another for high-pass filter.
Example:
import scipy
import scipy.signal
import numpy as np
filter_order = 31
cut_f = 0.2
sinc_layer = SincFilter(filter_order)
lp_coef, hp_coef = sinc_layer(torch.ones(1, 10, 1) * cut_f)
w, h1 = scipy.signal.freqz(lp_coef[0, 0, :].numpy(), [1])
w, h2 = scipy.signal.freqz(hp_coef[0, 0, :].numpy(), [1])
plt.plot(w, 20*np.log10(np.abs(h1)))
plt.plot(w, 20*np.log10(np.abs(h2)))
plt.plot([cut_f * np.pi, cut_f * np.pi], [-100, 0])
"""
def __init__(self, filter_order):
super(SincFilter, self).__init__()
# Make the filter oder an odd number
# [-(M-1)/2, ... 0, (M-1)/2]
#
self.half_k = (filter_order - 1) // 2
self.order = self.half_k * 2 +1
def hamming_w(self, n_index):
""" prepare hamming window for each time step
n_index (batchsize=1, signal_length, filter_order)
For each time step, n_index will be [-(M-1)/2, ... 0, (M-1)/2]
n_index[0, 0, :] = [-(M-1)/2, ... 0, (M-1)/2]
n_index[0, 1, :] = [-(M-1)/2, ... 0, (M-1)/2]
...
output (batchsize=1, signal_length, filter_order)
output[0, 0, :] = hamming_window
output[0, 1, :] = hamming_window
...
"""
# Hamming window
return 0.54 + 0.46 * torch.cos(2 * np.pi * n_index / self.order)
def sinc(self, x):
""" Normalized sinc-filter sin( pi * x) / pi * x
https://en.wikipedia.org/wiki/Sinc_function
Assume x (batchsize, signal_length, filter_order) and
x[0, 0, :] = [-half_order, - half_order+1, ... 0, ..., half_order]
x[:, :, self.half_order] -> time index = 0, sinc(0)=1
"""
y = torch.zeros_like(x)
y[:,:,0:self.half_k]=torch.sin(np.pi * x[:, :, 0:self.half_k]) \
/ (np.pi * x[:, :, 0:self.half_k])
y[:,:,self.half_k+1:]=torch.sin(np.pi * x[:, :, self.half_k+1:]) \
/ (np.pi * x[:, :, self.half_k+1:])
y[:,:,self.half_k] = 1
return y
def forward(self, cut_f):
""" lp_coef, hp_coef = forward(self, cut_f)
cut-off frequency cut_f (batchsize=1, length, dim = 1)
lp_coef: low-pass filter coefs (batchsize, length, filter_order)
hp_coef: high-pass filter coefs (batchsize, length, filter_order)
"""
# create the filter order index
with torch.no_grad():
# [- (M-1) / 2, ..., 0, ..., (M-1)/2]
lp_coef = torch.arange(-self.half_k, self.half_k + 1,
device=cut_f.device)
# [[[- (M-1) / 2, ..., 0, ..., (M-1)/2],
# [- (M-1) / 2, ..., 0, ..., (M-1)/2],
# ...
# ],
# [[- (M-1) / 2, ..., 0, ..., (M-1)/2],
# [- (M-1) / 2, ..., 0, ..., (M-1)/2],
# ...
# ]]
lp_coef = lp_coef.repeat(cut_f.shape[0], cut_f.shape[1], 1)
hp_coef = torch.arange(-self.half_k, self.half_k + 1,
device=cut_f.device)
hp_coef = hp_coef.repeat(cut_f.shape[0], cut_f.shape[1], 1)
# temporary buffer of [-1^n] for gain norm in hp_coef
tmp_one = torch.pow(-1, hp_coef)
# unnormalized filter coefs with hamming window
lp_coef = cut_f * self.sinc(cut_f * lp_coef) \
* self.hamming_w(lp_coef)
hp_coef = (self.sinc(hp_coef) \
- cut_f * self.sinc(cut_f * hp_coef)) \
* self.hamming_w(hp_coef)
# normalize the coef to make gain at 0/pi is 0 dB
# sum_n lp_coef[n]
lp_coef_norm = torch.sum(lp_coef, axis=2).unsqueeze(-1)
# sum_n hp_coef[n] * -1^n
hp_coef_norm = torch.sum(hp_coef * tmp_one, axis=2).unsqueeze(-1)
lp_coef = lp_coef / lp_coef_norm
hp_coef = hp_coef / hp_coef_norm
# return normed coef
return lp_coef, hp_coef
#
# Up sampling
class UpSampleLayer(torch_nn.Module):
""" Wrapper over up-sampling
Input tensor: (batchsize=1, length, dim)
Ouput tensor: (batchsize=1, length * up-sampling_factor, dim)
"""
def __init__(self, feature_dim, up_sampling_factor, smoothing=False):
super(UpSampleLayer, self).__init__()
# wrap a up_sampling layer
self.scale_factor = up_sampling_factor
self.l_upsamp = torch_nn.Upsample(scale_factor=self.scale_factor)
if smoothing:
self.l_ave1 = MovingAverage(feature_dim, self.scale_factor)
self.l_ave2 = MovingAverage(feature_dim, self.scale_factor)
else:
self.l_ave1 = torch_nn.Identity()
self.l_ave2 = torch_nn.Identity()
return
def forward(self, x):
# permute to (batchsize=1, dim, length)
up_sampled_data = self.l_upsamp(x.permute(0, 2, 1))
# permute it backt to (batchsize=1, length, dim)
# and do two moving average
return self.l_ave1(self.l_ave2(up_sampled_data.permute(0, 2, 1)))
# Neural filter block (1 block)
class NeuralFilterBlock(torch_nn.Module):
""" Wrapper over a single filter block
"""
def __init__(self, signal_size, hidden_size,\
kernel_size=3, conv_num=10):
super(NeuralFilterBlock, self).__init__()
self.signal_size = signal_size
self.hidden_size = hidden_size
self.kernel_size = kernel_size
self.conv_num = conv_num
self.dilation_s = [np.power(2, x) for x in np.arange(conv_num)]
# ff layer to expand dimension
self.l_ff_1 = torch_nn.Linear(signal_size, hidden_size, \
bias=False)
self.l_ff_1_tanh = torch_nn.Tanh()
# dilated conv layers
tmp = [Conv1dKeepLength(hidden_size, hidden_size, x, \
kernel_size, causal=True, bias=False) \
for x in self.dilation_s]
self.l_convs = torch_nn.ModuleList(tmp)
# ff layer to de-expand dimension
self.l_ff_2 = torch_nn.Linear(hidden_size, hidden_size//4, \
bias=False)
self.l_ff_2_tanh = torch_nn.Tanh()
self.l_ff_3 = torch_nn.Linear(hidden_size//4, signal_size, \
bias=False)
self.l_ff_3_tanh = torch_nn.Tanh()
# a simple scale
self.scale = torch_nn.Parameter(torch.tensor([0.1]),
requires_grad=False)
return
def forward(self, signal, context):
"""
Assume: signal (batchsize=1, length, signal_size)
context (batchsize=1, length, hidden_size)
Output: (batchsize=1, length, signal_size)
"""
# expand dimension
tmp_hidden = self.l_ff_1_tanh(self.l_ff_1(signal))
# loop over dilated convs
# output of a d-conv is input + context + d-conv(input)
for l_conv in self.l_convs:
tmp_hidden = tmp_hidden + l_conv(tmp_hidden) + context
# to be consistent with legacy configuration in CURRENNT
tmp_hidden = tmp_hidden * self.scale
# compress the dimesion and skip-add
tmp_hidden = self.l_ff_2_tanh(self.l_ff_2(tmp_hidden))
tmp_hidden = self.l_ff_3_tanh(self.l_ff_3(tmp_hidden))
output_signal = tmp_hidden + signal
return output_signal
#
# Sine waveform generator
#
# Sine waveform generator
class SineGen(torch_nn.Module):
""" Definition of sine generator
SineGen(samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False)
samp_rate: sampling rate in Hz
harmonic_num: number of harmonic overtones (default 0)
sine_amp: amplitude of sine-wavefrom (default 0.1)
noise_std: std of Gaussian noise (default 0.003)
voiced_thoreshold: F0 threshold for U/V classification (default 0)
flag_for_pulse: this SinGen is used inside PulseGen (default False)
Note: when flag_for_pulse is True, the first time step of a voiced
segment is always sin(np.pi) or cos(0)
"""
def __init__(self, samp_rate, harmonic_num = 0,
sine_amp = 0.1, noise_std = 0.003,
voiced_threshold = 0,
flag_for_pulse=False):
super(SineGen, self).__init__()
self.sine_amp = sine_amp
self.noise_std = noise_std
self.harmonic_num = harmonic_num
self.dim = self.harmonic_num + 1
self.sampling_rate = samp_rate
self.voiced_threshold = voiced_threshold
self.flag_for_pulse = flag_for_pulse
def _f02uv(self, f0):
# generate uv signal
uv = torch.ones_like(f0)
uv = uv * (f0 > self.voiced_threshold)
return uv
def _f02sine(self, f0_values):
""" f0_values: (batchsize, length, dim)
where dim indicates fundamental tone and overtones
"""
# convert to F0 in rad. The interger part n can be ignored
# because 2 * np.pi * n doesn't affect phase
rad_values = (f0_values / self.sampling_rate) % 1
# initial phase noise (no noise for fundamental component)
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],\
device = f0_values.device)
rand_ini[:, 0] = 0
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
if not self.flag_for_pulse:
# for normal case
sines = torch.sin(torch.cumsum(rad_values, dim=1) *2*np.pi)
else:
# If necessary, make sure that the first time step of every
# voiced segments is sin(pi) or cos(0)
# This is used for pulse-train generation
# identify the last time step in unvoiced segments
uv = self._f02uv(f0_values)
uv_1 = torch.roll(uv, shifts=-1, dims=1)
uv_1[:, -1, :] = 1
u_loc = (uv < 1) * (uv_1 > 0)
# get the instantanouse phase
tmp_cumsum = torch.cumsum(rad_values, dim=1)
# different batch needs to be processed differently
for idx in range(f0_values.shape[0]):
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
# stores the accumulation of i.phase within
# each voiced segments
tmp_cumsum[idx, :, :] = 0
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
# rad_values - tmp_cumsum: remove the accumulation of i.phase
# within the previous voiced segment.
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
# get the sines
sines = torch.cos(i_phase * 2 * np.pi)
return sines
def forward(self, f0):
""" sine_tensor, uv = forward(f0)
input F0: tensor(batchsize=1, length, dim=1)
f0 for unvoiced steps should be 0
output sine_tensor: tensor(batchsize=1, length, dim)
output uv: tensor(batchsize=1, length, 1)
"""
with torch.no_grad():
f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, \
device=f0.device)
# fundamental component
f0_buf[:, :, 0] = f0[:, :, 0]
for idx in np.arange(self.harmonic_num):
# idx + 2: the (idx+1)-th overtone, (idx+2)-th harmonic
f0_buf[:, :, idx+1] = f0_buf[:, :, 0] * (idx+2)
# generate sine waveforms
sine_waves = self._f02sine(f0_buf) * self.sine_amp
# generate uv signal
#uv = torch.ones(f0.shape)
#uv = uv * (f0 > self.voiced_threshold)
uv = self._f02uv(f0)
# noise: for unvoiced should be similar to sine_amp
# std = self.sine_amp/3 -> max value ~ self.sine_amp
#. for voiced regions is self.noise_std
noise_amp = uv * self.noise_std + (1-uv) * self.sine_amp / 3
noise = noise_amp * torch.randn_like(sine_waves)
# first: set the unvoiced part to 0 by uv
# then: additive noise
sine_waves = sine_waves * uv + noise
return sine_waves, uv, noise
#####
## Model definition
##
## For condition module only provide Spectral feature to Filter block
class CondModuleHnSincNSF(torch_nn.Module):
""" Condition module for hn-sinc-NSF
Upsample and transform input features
CondModuleHnSincNSF(input_dimension, output_dimension, up_sample_rate,
blstm_dimension = 64, cnn_kernel_size = 3)
Spec, F0, cut_off_freq = CondModuleHnSincNSF(features, F0)
Both input features should be frame-level features
If x doesn't contain F0, just ignore the returned F0
CondModuleHnSincNSF(input_dim, output_dim, up_sample,
blstm_s = 64, cnn_kernel_s = 3,
voiced_threshold = 0):
input_dim: sum of dimensions of input features
output_dim: dim of the feature Spec to be used by neural filter-block
up_sample: up sampling rate of input features
blstm_s: dimension of the features from blstm (default 64)
cnn_kernel_s: kernel size of CNN in condition module (default 3)
voiced_threshold: f0 > voiced_threshold is voiced, otherwise unvoiced
"""
def __init__(self, input_dim, output_dim, up_sample, \
blstm_s = 64, cnn_kernel_s = 3, voiced_threshold = 0):
super(CondModuleHnSincNSF, self).__init__()
# input feature dimension
self.input_dim = input_dim
self.output_dim = output_dim
self.up_sample = up_sample
self.blstm_s = blstm_s
self.cnn_kernel_s = cnn_kernel_s
self.cut_f_smooth = up_sample * 4
self.voiced_threshold = voiced_threshold
# the blstm layer
self.l_blstm = BLSTMLayer(input_dim, self.blstm_s)
# the CNN layer (+1 dim for cut_off_frequence of sinc filter)
self.l_conv1d = Conv1dKeepLength(self.blstm_s, \
self.output_dim, \
dilation_s = 1, \
kernel_s = self.cnn_kernel_s)
# Upsampling layer for hidden features
self.l_upsamp = UpSampleLayer(self.output_dim, \
self.up_sample, True)
# separate layer for up-sampling normalized F0 values
self.l_upsamp_f0_hi = UpSampleLayer(1, self.up_sample, True)
# Upsampling for F0: don't smooth up-sampled F0
self.l_upsamp_F0 = UpSampleLayer(1, self.up_sample, False)
# Another smoothing layer to smooth the cut-off frequency
# for sinc filters. Use a larger window to smooth
self.l_cut_f_smooth = MovingAverage(1, self.cut_f_smooth)
def get_cut_f(self, hidden_feat, f0):
""" cut_f = get_cut_f(self, feature, f0)
feature: (batchsize, length, dim=1)
f0: (batchsize, length, dim=1)
"""
# generate uv signal
uv = torch.ones_like(f0) * (f0 > self.voiced_threshold)
# hidden_feat is between (-1, 1) after conv1d with tanh
# (-0.2, 0.2) + 0.3 = (0.1, 0.5)
# voiced: (0.1, 0.5) + 0.4 = (0.5, 0.9)
# unvoiced: (0.1, 0.5) = (0.1, 0.5)
return hidden_feat * 0.2 + uv * 0.4 + 0.3
def forward(self, feature, f0):
""" spec, f0 = forward(self, feature, f0)
feature: (batchsize, length, dim)
f0: (batchsize, length, dim=1), which should be F0 at frame-level
spec: (batchsize, length, self.output_dim), at wave-level
f0: (batchsize, length, 1), at wave-level
"""
tmp = self.l_upsamp(self.l_conv1d(self.l_blstm(feature)))
# concatenat normed F0 with hidden spectral features
context = torch.cat((tmp[:, :, 0:self.output_dim-1], \
self.l_upsamp_f0_hi(feature[:, :, -1:])), \
dim=2)
# hidden feature for cut-off frequency
hidden_cut_f = tmp[:, :, self.output_dim-1:]
# directly up-sample F0 without smoothing
f0_upsamp = self.l_upsamp_F0(f0)
# get the cut-off-frequency from output of CNN
cut_f = self.get_cut_f(hidden_cut_f, f0_upsamp)
# smooth the cut-off-frequency using fixed average smoothing
cut_f_smoothed = self.l_cut_f_smooth(cut_f)
# return
return context, f0_upsamp, cut_f_smoothed, hidden_cut_f
# For source module
class SourceModuleHnNSF(torch_nn.Module):
""" SourceModule for hn-nsf
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0)
sampling_rate: sampling_rate in Hz
harmonic_num: number of harmonic above F0 (default: 0)
sine_amp: amplitude of sine source signal (default: 0.1)
add_noise_std: std of additive Gaussian noise (default: 0.003)
note that amplitude of noise in unvoiced is decided
by sine_amp
voiced_threshold: threhold to set U/V given F0 (default: 0)
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
uv (batchsize, length, 1)
"""
def __init__(self, sampling_rate, harmonic_num=0, sine_amp=0.1,
add_noise_std=0.003, voiced_threshod=0):
super(SourceModuleHnNSF, self).__init__()
self.sine_amp = sine_amp
self.noise_std = add_noise_std
# to produce sine waveforms
self.l_sin_gen = SineGen(sampling_rate, harmonic_num,
sine_amp, add_noise_std, voiced_threshod)
# to merge source harmonics into a single excitation
self.l_linear = torch_nn.Linear(harmonic_num+1, 1)
self.l_tanh = torch_nn.Tanh()
def forward(self, x):
"""
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
F0_sampled (batchsize, length, 1)
Sine_source (batchsize, length, 1)
noise_source (batchsize, length 1)
"""
# source for harmonic branch
sine_wavs, uv, _ = self.l_sin_gen(x)
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
# source for noise branch, in the same shape as uv
noise = torch.randn_like(uv) * self.sine_amp / 3
return sine_merge, noise, uv
# For Filter module
class FilterModuleHnSincNSF(torch_nn.Module):
""" Filter for Hn-sinc-NSF
FilterModuleHnSincNSF(signal_size, hidden_size, sinc_order = 31,
block_num = 5, kernel_size = 3,
conv_num_in_block = 10)
signal_size: signal dimension (should be 1)
hidden_size: dimension of hidden features inside neural filter block
sinc_order: order of the sinc filter
block_num: number of neural filter blocks in harmonic branch
kernel_size: kernel size in dilated CNN
conv_num_in_block: number of d-conv1d in one neural filter block
Usage:
output = FilterModuleHnSincNSF(har_source, noi_source, cut_f, context)
har_source: source for harmonic branch (batchsize, length, dim=1)
noi_source: source for noise branch (batchsize, length, dim=1)
cut_f: cut-off-frequency of sinc filters (batchsize, length, dim=1)
context: hidden features to be added (batchsize, length, dim)
output: (batchsize, length, dim=1)
"""
def __init__(self, signal_size, hidden_size, sinc_order = 31, \
block_num = 5, kernel_size = 3, conv_num_in_block = 10):
super(FilterModuleHnSincNSF, self).__init__()
self.signal_size = signal_size
self.hidden_size = hidden_size
self.kernel_size = kernel_size
self.block_num = block_num
self.conv_num_in_block = conv_num_in_block
self.sinc_order = sinc_order
# filter blocks for harmonic branch
tmp = [NeuralFilterBlock(signal_size, hidden_size, \
kernel_size, conv_num_in_block) \
for x in range(self.block_num)]
self.l_har_blocks = torch_nn.ModuleList(tmp)
# filter blocks for noise branch (only one block, 5 sub-blocks)
tmp = [NeuralFilterBlock(signal_size, hidden_size, \
kernel_size, conv_num_in_block // 2) \
for x in range(1)]
self.l_noi_blocks = torch_nn.ModuleList(tmp)
# sinc filter generators and time-variant filtering layer
self.l_sinc_coef = SincFilter(self.sinc_order)
self.l_tv_filtering = TimeVarFIRFilter()
# done
def forward(self, har_component, noi_component, cond_feat, cut_f):
"""
"""
# harmonic component
for l_har_block in self.l_har_blocks:
har_component = l_har_block(har_component, cond_feat)
# noise componebt
for l_noi_block in self.l_noi_blocks:
noi_component = l_noi_block(noi_component, cond_feat)
# get sinc filter coefficients
lp_coef, hp_coef = self.l_sinc_coef(cut_f)
# time-variant filtering
har_signal = self.l_tv_filtering(har_component, lp_coef)
noi_signal = self.l_tv_filtering(noi_component, hp_coef)
# get output
return har_signal + noi_signal
## FOR MODEL
class Model(torch_nn.Module):
""" Model definition
"""
def __init__(self, in_dim, out_dim, args, mean_std=None):
super(Model, self).__init__()
torch.manual_seed(1)
# mean std of input and output
in_m, in_s, out_m, out_s = self.prepare_mean_std(in_dim,out_dim,\
args, mean_std)
self.input_mean = torch_nn.Parameter(in_m, requires_grad=False)
self.input_std = torch_nn.Parameter(in_s, requires_grad=False)
self.output_mean = torch_nn.Parameter(out_m, requires_grad=False)
self.output_std = torch_nn.Parameter(out_s, requires_grad=False)
self.input_dim = in_dim
self.output_dim = out_dim
# configurations
# amplitude of sine waveform (for each harmonic)
self.sine_amp = 0.1
# standard deviation of Gaussian noise for additive noise
self.noise_std = 0.003
# dimension of hidden features in filter blocks
self.hidden_dim = 64
# upsampling rate on input acoustic features (16kHz * 5ms = 80)
# assume input_reso has the same value
self.upsamp_rate = prj_conf.input_reso[0]
# sampling rate (Hz)
self.sampling_rate = prj_conf.wav_samp_rate
# CNN kernel size in filter blocks
self.cnn_kernel_s = 3
# number of filter blocks (for harmonic branch)
# noise branch only uses 1 block
self.filter_block_num = 5
# number of dilated CNN in each filter block
self.cnn_num_in_block = 10
# number of harmonic overtones in source
self.harmonic_num = 7
# order of sinc-windowed-FIR-filter
self.sinc_order = 31
# the three modules
self.m_cond = CondModuleHnSincNSF(self.input_dim, \
self.hidden_dim, \
self.upsamp_rate, \
cnn_kernel_s=self.cnn_kernel_s)
self.m_source = SourceModuleHnNSF(self.sampling_rate,
self.harmonic_num,
self.sine_amp, self.noise_std)
self.m_filter = FilterModuleHnSincNSF(self.output_dim, \
self.hidden_dim, \
self.sinc_order, \
self.filter_block_num, \
self.cnn_kernel_s, \
self.cnn_num_in_block)
# done
return
def prepare_mean_std(self, in_dim, out_dim, args, data_mean_std=None):
"""
"""
if data_mean_std is not None:
in_m = torch.from_numpy(data_mean_std[0])
in_s = torch.from_numpy(data_mean_std[1])
out_m = torch.from_numpy(data_mean_std[2])
out_s = torch.from_numpy(data_mean_std[3])
if in_m.shape[0] != in_dim or in_s.shape[0] != in_dim:
print("Input dim: {:d}".format(in_dim))
print("Mean dim: {:d}".format(in_m.shape[0]))
print("Std dim: {:d}".format(in_s.shape[0]))
print("Input dimension incompatible")
sys.exit(1)
if out_m.shape[0] != out_dim or out_s.shape[0] != out_dim:
print("Output dim: {:d}".format(out_dim))
print("Mean dim: {:d}".format(out_m.shape[0]))
print("Std dim: {:d}".format(out_s.shape[0]))
print("Output dimension incompatible")
sys.exit(1)
else:
in_m = torch.zeros([in_dim])
in_s = torch.ones([in_dim])
out_m = torch.zeros([out_dim])
out_s = torch.ones([out_dim])
return in_m, in_s, out_m, out_s
def normalize_input(self, x):
""" normalizing the input data
"""
return (x - self.input_mean) / self.input_std
def normalize_target(self, y):
""" normalizing the target data
"""
return (y - self.output_mean) / self.output_std
def denormalize_output(self, y):
""" denormalizing the generated output from network
"""
return y * self.output_std + self.output_mean
def forward(self, x):
""" definition of forward method
Assume x (batchsize=1, length, dim)
Return output(batchsize=1, length)
"""
# assume x[:, :, -1] is F0, denormalize F0
f0 = x[:, :, -1:]
# normalize the input features data
feat = self.normalize_input(x)
# condition module
# feature-to-filter-block, f0-up-sampled, cut-off-f-for-sinc,
# hidden-feature-for-cut-off-f
cond_feat, f0_upsamped, cut_f, hid_cut_f = self.m_cond(feat, f0)
# source module
# harmonic-source, noise-source (for noise branch), uv
har_source, noi_source, uv = self.m_source(f0_upsamped)
# neural filter module (including sinc-based FIR filtering)
# output
output = self.m_filter(har_source, noi_source, cond_feat, cut_f)
if self.training:
# just in case we need to penalize the hidden feauture for
# cut-off-freq.
return [output.squeeze(-1), hid_cut_f]
else:
return output.squeeze(-1)
class Loss():
""" Wrapper to define loss function
"""
def __init__(self, args):
"""
"""
# frame shift (number of points)
self.frame_hops = [80, 40, 640]
# frame length
self.frame_lens = [320, 80, 1920]
# fft length
self.fft_n = [512, 128, 2048]
# window type in stft
self.win = torch.hann_window
# floor in log-spectrum-amplitude calculating
self.amp_floor = 0.00001
# loss function
self.loss = torch_nn.MSELoss()
# weight to penalize hidden features for cut-off-frequency
# for experiments on CMU-arctic, ATR-F009, VCTK, cutoff_w = 0.0
self.cutoff_w = 0.0
def compute(self, outputs, target):
""" Loss().compute(outputs, target) should return
the Loss in torch.tensor format
Assume output and target as (batchsize=1, length)
"""
# hidden-feature for cut-off-frequency
cut_f = outputs[1]
# generated signal
output = outputs[0]
# convert from (batchsize=1, length, dim=1) to (1, length)
if target.ndim == 3:
target.squeeze_(-1)
# compute loss
loss = 0
for frame_shift, frame_len, fft_p in \
zip(self.frame_hops, self.frame_lens, self.fft_n):
x_stft = torch.stft(output, fft_p, frame_shift, frame_len, \
window=self.win(frame_len), onesided=True,
pad_mode="constant")
y_stft = torch.stft(target, fft_p, frame_shift, frame_len, \
window=self.win(frame_len), onesided=True,
pad_mode="constant")
x_sp_amp = torch.log(torch.norm(x_stft, 2, -1).pow(2) + \
self.amp_floor)
y_sp_amp = torch.log(torch.norm(y_stft, 2, -1).pow(2) + \
self.amp_floor)
loss += self.loss(x_sp_amp, y_sp_amp)
# A norm on cut_f, which forces sinc-cut-off-frequency
# to be close to the U/V-decided value
# Experiments on CMU-arctic, ATR-F009, and VCTK don't use it
# by setting self.cutoff_w = 0.0
# However, just in case
loss += self.cutoff_w * self.loss(cut_f, torch.zeros_like(cut_f))
return loss
if __name__ == "__main__":
print("Definition of model")