diff --git a/modules/commons/common_layers.py b/modules/commons/common_layers.py index ae27427c..0012b99c 100644 --- a/modules/commons/common_layers.py +++ b/modules/commons/common_layers.py @@ -154,7 +154,10 @@ def forward(self, x): # out, gate = x.chunk(2, dim=self.dim) # Using torch.split instead of chunk for ONNX export compatibility. out, gate = torch.split(x, x.size(self.dim) // 2, dim=self.dim) - return ATanGLUFunction.apply(out, gate) + if self.training: + return ATanGLUFunction.apply(out, gate) + else: + return out * torch.atan(gate) class KaimingNormalConv1d(torch.nn.Conv1d):