-
Notifications
You must be signed in to change notification settings - Fork 1
/
acoustic_system.py
53 lines (40 loc) · 1.7 KB
/
acoustic_system.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
class AcousticSystem(torch.nn.Module):
    """Whole audio system: audio -> prediction probability distribution.

    The pipeline applied by ``forward`` is, in order:

    * ``defender`` (optional): audio -> audio when ``defense_type`` is
      ``'wave'``, or spectrogram -> spectrogram when it is ``'spec'``.
    * ``transform`` (optional): audio -> spectrogram.
    * ``classifier``: spectrogram -> prediction probability distribution, or
      audio -> prediction probability distribution when ``transform`` is
      ``None``.
    """

    # The two stages at which the defender can be inserted.
    _DEFENSE_TYPES = ('wave', 'spec')

    def __init__(self,
                 classifier: torch.nn.Module,
                 transform,
                 defender: torch.nn.Module = None,
                 defense_type: str = 'wave'):
        """Assemble the pipeline from its components.

        Args:
            classifier: Final stage; called on the (possibly defended and
                transformed) input and its result is returned by ``forward``.
            transform: Callable applied between the waveform stage and the
                spectrogram stage, or ``None`` to skip the conversion.
            defender: Optional defense module; ``None`` disables defense.
            defense_type: ``'wave'`` to run the defender before ``transform``,
                ``'spec'`` to run it after.

        Raises:
            NotImplementedError: If ``defense_type`` is neither ``'wave'``
                nor ``'spec'``.
        """
        super().__init__()

        # Reject a bad configuration before storing any state.
        if defense_type not in self._DEFENSE_TYPES:
            raise NotImplementedError('argument defense_type should be \'wave\' or \'spec\'!')

        self.classifier = classifier
        self.transform = transform
        self.defender = defender
        self.defense_type = defense_type

    def forward(self, x, defend=True):
        """Run the input through defender, transform and classifier.

        Args:
            x: Input handed to the first active stage.
            defend: When falsy, the defender is bypassed even if one is
                configured.

        Returns:
            Whatever ``classifier`` returns for the processed input.
        """
        # NOTE: the input is used as-is; no amplitude rescaling (e.g. dividing
        # by 2**15) is performed here.
        use_defender = bool(defend) and self.defender is not None
        output = x

        # Defense on the waveform, before any transform.
        if use_defender and self.defense_type == 'wave':
            output = self.defender(output)

        # Convert waveform to spectrogram.
        if self.transform is not None:
            output = self.transform(output)

        # Defense on the spectrogram, after the transform.
        if use_defender and self.defense_type == 'spec':
            output = self.defender(output)

        # Prediction from the final representation.
        return self.classifier(output)