A bug in HiFiGANScaleDiscriminator #4595

Closed
JoeyHeisenberg opened this issue Aug 23, 2022 · 2 comments
Labels: Bug, TTS, Wontfix

@JoeyHeisenberg

It seems that weight norm doesn't work in HiFiGANScaleDiscriminator: the module only builds Conv1d layers, but apply_weight_norm (and apply_spectral_norm) check for torch.nn.Conv2d, so the normalization is never applied (a possible fix is sketched after the code).

class HiFiGANScaleDiscriminator(torch.nn.Module):
    """HiFi-GAN scale discriminator module."""

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        kernel_sizes: List[int] = [15, 41, 5, 3],
        channels: int = 128,
        max_downsample_channels: int = 1024,
        max_groups: int = 16,
        bias: int = True,
        downsample_scales: List[int] = [2, 2, 4, 4, 1],
        nonlinear_activation: str = "LeakyReLU",
        nonlinear_activation_params: Dict[str, Any] = {"negative_slope": 0.1},
        use_weight_norm: bool = True,
        use_spectral_norm: bool = False,
    ):
        """Initilize HiFiGAN scale discriminator module.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            kernel_sizes (List[int]): List of four kernel sizes. The first will be used
                for the first conv layer, and the second is for downsampling part, and
                the remaining two are for the last two output layers.
            channels (int): Initial number of channels for conv layer.
            max_downsample_channels (int): Maximum number of channels for downsampling
                layers.
            bias (bool): Whether to add bias parameter in convolution layers.
            downsample_scales (List[int]): List of downsampling scales.
            nonlinear_activation (str): Activation function module name.
            nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation
                function.
            use_weight_norm (bool): Whether to use weight norm. If set to true, it will
                be applied to all of the conv layers.
            use_spectral_norm (bool): Whether to use spectral norm. If set to true, it
                will be applied to all of the conv layers.

        """
        super().__init__()
        self.layers = torch.nn.ModuleList()

        # check kernel size is valid
        assert len(kernel_sizes) == 4
        for ks in kernel_sizes:
            assert ks % 2 == 1

        # add first layer
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels,
                    channels,
                    # NOTE(kan-bayashi): Use always the same kernel size
                    kernel_sizes[0],
                    bias=bias,
                    padding=(kernel_sizes[0] - 1) // 2,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]

        # add downsample layers
        in_chs = channels
        out_chs = channels
        # NOTE(kan-bayashi): Remove hard coding?
        groups = 4
        for downsample_scale in downsample_scales:
            self.layers += [
                torch.nn.Sequential(
                    torch.nn.Conv1d(
                        in_chs,
                        out_chs,
                        kernel_size=kernel_sizes[1],
                        stride=downsample_scale,
                        padding=(kernel_sizes[1] - 1) // 2,
                        groups=groups,
                        bias=bias,
                    ),
                    getattr(torch.nn, nonlinear_activation)(
                        **nonlinear_activation_params
                    ),
                )
            ]
            in_chs = out_chs
            # NOTE(kan-bayashi): Remove hard coding?
            out_chs = min(in_chs * 2, max_downsample_channels)
            # NOTE(kan-bayashi): Remove hard coding?
            groups = min(groups * 4, max_groups)

        # add final layers
        out_chs = min(in_chs * 2, max_downsample_channels)
        self.layers += [
            torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_chs,
                    out_chs,
                    kernel_size=kernel_sizes[2],
                    stride=1,
                    padding=(kernel_sizes[2] - 1) // 2,
                    bias=bias,
                ),
                getattr(torch.nn, nonlinear_activation)(**nonlinear_activation_params),
            )
        ]
        self.layers += [
            torch.nn.Conv1d(
                out_chs,
                out_channels,
                kernel_size=kernel_sizes[3],
                stride=1,
                padding=(kernel_sizes[3] - 1) // 2,
                bias=bias,
            ),
        ]

        if use_weight_norm and use_spectral_norm:
            raise ValueError("Either use use_weight_norm or use_spectral_norm.")

        # apply weight norm
        if use_weight_norm:
            self.apply_weight_norm()

        # apply spectral norm
        if use_spectral_norm:
            self.apply_spectral_norm()

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Calculate forward propagation.

        Args:
            x (Tensor): Input noise signal (B, 1, T).

        Returns:
            List[Tensor]: List of output tensors of each layer.

        """
        outs = []
        for f in self.layers:
            x = f(x)
            outs += [x]

        return outs

    def apply_weight_norm(self):
        """Apply weight normalization module from all of the layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def apply_spectral_norm(self):
        """Apply spectral normalization module from all of the layers."""

        def _apply_spectral_norm(m: torch.nn.Module):
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.utils.spectral_norm(m)
                logging.debug(f"Spectral norm is applied to {m}.")

        self.apply(_apply_spectral_norm)
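
Because the module only ever builds torch.nn.Conv1d layers, the isinstance(m, torch.nn.Conv2d) checks in apply_weight_norm and apply_spectral_norm never match, so neither normalization is applied even when use_weight_norm=True. Below is a minimal sketch of one possible fix, not necessarily the change merged upstream: match Conv1d (or both conv types) in the two helpers.

    # Sketch of a possible fix for the two helpers above. Assumption: matching
    # Conv1d (and Conv2d for safety) is enough, since this discriminator only
    # creates torch.nn.Conv1d layers.
    def apply_weight_norm(self):
        """Apply weight normalization to all of the conv layers."""

        def _apply_weight_norm(m: torch.nn.Module):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.weight_norm(m)
                logging.debug(f"Weight norm is applied to {m}.")

        self.apply(_apply_weight_norm)

    def apply_spectral_norm(self):
        """Apply spectral normalization to all of the conv layers."""

        def _apply_spectral_norm(m: torch.nn.Module):
            if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d)):
                torch.nn.utils.spectral_norm(m)
                logging.debug(f"Spectral norm is applied to {m}.")

        self.apply(_apply_spectral_norm)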

@JoeyHeisenberg added the Bug label on Aug 23, 2022
@kan-bayashi added the Wontfix and TTS labels on Aug 23, 2022
@kan-bayashi
Member

Thank you for your report, you are right.

@kan-bayashi
Member

Related: kan-bayashi/ParallelWaveGAN#309
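
For anyone who wants to confirm whether a given checkout is affected, a quick check (a sketch, not part of the original report) is to count weight-normalized parameters: torch.nn.utils.weight_norm replaces a conv's weight with weight_g/weight_v. The import path below is an assumption; adjust it to where HiFiGANScaleDiscriminator lives in your ESPnet checkout.

# Assumed import path; adjust to your ESPnet checkout.
from espnet2.gan_tts.hifigan.hifigan import HiFiGANScaleDiscriminator

d = HiFiGANScaleDiscriminator(use_weight_norm=True)

# weight_norm registers weight_g/weight_v in place of weight, so this count
# is 0 while the bug is present and > 0 once the Conv1d check is in place.
n = sum(1 for name, _ in d.named_parameters() if name.endswith("weight_g"))
print(f"conv layers with weight norm applied: {n}")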
