class PositionalEncoding(torch.nn.Module):
    """Sinusoidal embedding of a scalar value (used here for the time step).

    The input is multiplied by ``scale`` and by ``channels // 2``
    geometrically spaced frequencies between 1 and 1/10000; the sines and the
    cosines of the result are concatenated along the last axis.

    Args:
        channels (int): Width of the produced embedding. Must be at least 2.

    Raises:
        ValueError: If ``channels`` is smaller than 2 (no sin/cos pair fits).
    """

    def __init__(self, channels):
        super().__init__()
        if channels < 2:
            raise ValueError(f"`channels` must be >= 2 to hold a sin/cos pair, got {channels}")
        self.channels = channels

    def forward(self, x, scale=1000):
        """Embed ``x``.

        Args:
            x (torch.Tensor): Values to embed, a 0-dim scalar or a ``[B]`` vector.
            scale (float): Multiplier applied to ``x`` before the sinusoids.

        Returns:
            torch.Tensor: ``[B, channels]`` embedding (``B == 1`` for a 0-dim input).
        """
        # A 0-dim tensor (e.g. a squeezed batch of size 1) is promoted to [1].
        if x.ndim < 1:
            x = x.unsqueeze(0)

        half_channels = self.channels // 2
        # With a single frequency (channels 2 or 3) `half_channels - 1` is 0;
        # clamp the denominator so the only frequency is exp(0) == 1 instead of
        # raising ZeroDivisionError.
        log_step = math.log(10000) / max(half_channels - 1, 1)
        freqs = torch.exp(torch.arange(half_channels, device=x.device).float() * -log_step)

        emb = scale * x.unsqueeze(1) * freqs.unsqueeze(0)
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)

        # An odd width leaves one slot unfilled; zero-pad it so the output width
        # always equals `channels`, which downstream linear layers rely on.
        if self.channels % 2 == 1:
            emb = torch.nn.functional.pad(emb, (0, 1))
        return emb
class UNet(nn.Module):
    """1-D U-Net with Conformer blocks that predicts a vector field.

    The network has ``num_blocks`` encoder stages (ResNet block, Conformer
    block, stride-2 downsampling on all but the last stage), two middle
    stages, and ``num_blocks`` decoder stages that each consume one encoder
    skip connection (ResNet block, Conformer block, x2 upsampling on all but
    the last stage). A final masked conv block and a 1x1 convolution project
    to ``out_channels``.

    Args:
        in_channels (int): Channels of ``x_t`` and of ``mean``; the first
            ResNet block receives their concatenation (``2 * in_channels``).
            Also the width of the sinusoidal time embedding.
        model_channels (int): Channel width used by every stage.
        out_channels (int): Channels of the prediction.
        num_blocks (int): Number of encoder (and decoder) stages, at least 1.
        transformer_num_heads (int): Attention heads per Conformer block.
        transformer_dim_head (int): Width of each attention head.
        transformer_ff_mult (int): Feed-forward expansion of the Conformer.
        transformer_conv_expansion_factor (int): Conformer conv expansion.
        transformer_conv_kernel_size (int): Conformer depthwise kernel size.
        transformer_dropout (float): Dropout used for attention, feed-forward
            and conv parts of the Conformer.

    Raises:
        ValueError: If ``num_blocks`` is smaller than 1.
    """

    def __init__(
        self,
        in_channels: int,
        model_channels: int,
        out_channels: int,
        num_blocks: int,
        transformer_num_heads: int = 4,
        transformer_dim_head: int = 64,
        transformer_ff_mult: int = 1,
        transformer_conv_expansion_factor: int = 2,
        transformer_conv_kernel_size: int = 31,
        transformer_dropout: float = 0.05,
    ):
        super().__init__()
        if num_blocks < 1:
            raise ValueError(f"`num_blocks` must be >= 1, got {num_blocks}")

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Time step -> sinusoidal features -> MLP, shared by every ResNet block.
        self.time_encoder = PositionalEncoding(in_channels)
        time_embed_channels = model_channels * 4
        self.time_embed = nn.Sequential(
            nn.Linear(in_channels, time_embed_channels),
            nn.SiLU(),
            nn.Linear(time_embed_channels, time_embed_channels),
        )

        # Every Conformer block in the network shares these hyper-parameters.
        transformer_kwargs = {
            "dim_head": transformer_dim_head,
            "num_heads": transformer_num_heads,
            "ff_mult": transformer_ff_mult,
            "conv_expansion_factor": transformer_conv_expansion_factor,
            "conv_kernel_size": transformer_conv_kernel_size,
            "dropout": transformer_dropout,
        }

        # Encoder path. The first stage sees `x_t` and `mean` stacked on the
        # channel axis, hence `in_channels * 2`.
        self.input_blocks = nn.ModuleList([])
        block_in_channels = in_channels * 2
        block_out_channels = model_channels
        for level in range(num_blocks):
            block = nn.ModuleList([])
            block.append(
                ResNetBlock1D(
                    in_channels=block_in_channels,
                    out_channels=block_out_channels,
                    time_embed_channels=time_embed_channels,
                )
            )
            block.append(self._create_transformer_block(block_out_channels, **transformer_kwargs))
            # The last stage keeps its resolution; `None` marks "no resampling".
            if level != num_blocks - 1:
                block.append(Downsample1D(block_out_channels))
            else:
                block.append(None)

            block_in_channels = block_out_channels
            self.input_blocks.append(block)

        # Bottleneck: two ResNet + Conformer stages at the lowest resolution.
        self.middle_blocks = nn.ModuleList([])
        for _ in range(2):
            block = nn.ModuleList([])
            block.append(
                ResNetBlock1D(
                    in_channels=block_out_channels,
                    out_channels=block_out_channels,
                    time_embed_channels=time_embed_channels,
                )
            )
            block.append(self._create_transformer_block(block_out_channels, **transformer_kwargs))
            self.middle_blocks.append(block)

        # Decoder path. Each stage receives the running features concatenated
        # with one encoder skip connection, hence twice the channel count.
        self.output_blocks = nn.ModuleList([])
        block_in_channels = block_out_channels * 2
        block_out_channels = model_channels
        for level in range(num_blocks):
            block = nn.ModuleList([])
            block.append(
                ResNetBlock1D(
                    in_channels=block_in_channels,
                    out_channels=block_out_channels,
                    time_embed_channels=time_embed_channels,
                )
            )
            block.append(self._create_transformer_block(block_out_channels, **transformer_kwargs))
            if level != num_blocks - 1:
                block.append(Upsample1D(block_out_channels))
            else:
                block.append(None)

            block_in_channels = block_out_channels * 2
            self.output_blocks.append(block)

        self.conv_block = ConvBlock1D(model_channels, model_channels)
        self.conv = nn.Conv1d(model_channels, self.out_channels, 1)

    def _create_transformer_block(
        self,
        dim,
        dim_head: int = 64,
        num_heads: int = 4,
        ff_mult: int = 1,
        conv_expansion_factor: int = 2,
        conv_kernel_size: int = 31,
        dropout: float = 0.05,
    ):
        """Build one non-causal ConformerBlock; ``dropout`` is used for all three dropout sites."""
        return ConformerBlock(
            dim=dim,
            dim_head=dim_head,
            heads=num_heads,
            ff_mult=ff_mult,
            conv_expansion_factor=conv_expansion_factor,
            conv_kernel_size=conv_kernel_size,
            attn_dropout=dropout,
            ff_dropout=dropout,
            conv_dropout=dropout,
            conv_causal=False,
        )

    def forward(self, x_t, mean, mask, t):
        """Predict the vector field for the noisy sample ``x_t``.

        Shapes:
            - x_t: :math:`[B, C, T]`
            - mean: :math:`[B, C, T]`
            - mask: :math:`[B, 1, T]`
            - t: :math:`[B]` (a 0-dim tensor is accepted for ``B == 1``)

        Returns:
            torch.Tensor: :math:`[B, C_out, T]` prediction, zeroed where ``mask`` is 0.
        """
        t = self.time_encoder(t)
        t = self.time_embed(t)

        # Condition on the encoder mean by stacking it on the channel axis.
        x_t = pack([x_t, mean], "b * t")[0]

        hidden_states = []
        mask_states = [mask]

        for res_net_block, transformer, downsample in self.input_blocks:
            x_t = res_net_block(x_t, mask, t)
            x_t = transformer(x_t, mask)

            hidden_states.append(x_t)

            if downsample is not None:
                x_t = downsample(x_t * mask)
                # Stride-2 conv keeps ceil(T / 2) frames, same as `::2`.
                mask = mask[:, :, ::2]
                mask_states.append(mask)

        for res_net_block, transformer in self.middle_blocks:
            mask = mask_states[-1]
            x_t = res_net_block(x_t, mask, t)
            x_t = transformer(x_t, mask)

        for res_net_block, transformer, upsample in self.output_blocks:
            skip = hidden_states.pop()
            # The x2 transposed conv always returns an even length, so a stage
            # that was downsampled from an odd length comes back one frame too
            # long; drop the surplus so the skip connection lines up.
            x_t = x_t[:, :, : skip.shape[-1]]
            x_t = pack([x_t, skip], "b * t")[0]
            mask = mask_states.pop()
            x_t = res_net_block(x_t, mask, t)
            x_t = transformer(x_t, mask)

            if upsample is not None:
                x_t = upsample(x_t * mask)

        # Feed the final conv block's output (not its input) to the projection;
        # previously the block's result was computed and then thrown away.
        output = self.conv_block(x_t, mask)
        output = self.conv(output)

        return output * mask
def forward(self, x, x_lengths, y, y_lengths):
    """Run the training forward pass and return the decoder loss.

    The text encoder output is aligned to the mel frames with monotonic
    alignment search (computed without gradients), the per-frame mean is
    gathered through that alignment, and the decoder is asked for its
    loss on the target mel given that mean.

    Args:
        x (torch.Tensor):
            Input text sequence ids. :math:`[B, T_en]`

        x_lengths (torch.Tensor):
            Lengths of input text sequences. :math:`[B]`

        y (torch.Tensor):
            Target mel-spectrogram frames. :math:`[B, T_de, C_mel]`

        y_lengths (torch.Tensor):
            Lengths of target mel-spectrogram frames. :math:`[B]`

    Returns:
        dict: ``loss`` (scalar decoder loss), ``alignments``
        (:math:`[B, T_en, T_de]` hard alignment) and ``durations_log``
        (log-duration output of the encoder, returned untouched).
    """
    y = y.transpose(1, 2)
    y_max_length = y.size(2)

    # NOTE(review): the masking below assumes the encoder returns
    # o_mask as [B, 1, T_en] and o_mean / o_log_scale as [B, C_mel, T_en]
    # — confirm against the glow_tts Encoder.
    o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None)

    y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype)
    attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2)

    # Gaussian log-likelihood of every mel frame under every text position,
    # expanded into four terms so it can be computed with matmuls.
    with torch.no_grad():
        o_scale = torch.exp(-2 * o_log_scale)
        logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1)
        logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2))
        logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y)
        logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1)
        logp = logp1 + logp2 + logp3 + logp4
        attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach()

    # Align encoded text with mel-spectrogram and get mu_y segment
    c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2)

    # The loss used to be bound to `_` and dropped, which left the method
    # returning None; hand it back so a training loop can optimise it.
    decoder_loss = self.decoder(x_1=y, mean=c_mean, mask=y_mask)

    return {
        "loss": decoder_loss,
        "alignments": attn.squeeze(1),
        "durations_log": o_log_dur,
    }
class TestMatchTTS(unittest.TestCase):
    """Smoke test: a MatchaTTS training forward pass runs on random inputs."""

    @staticmethod
    def _create_inputs(batch_size=8):
        """Return random (token ids, token lengths, mels, mel lengths, speaker ids) on `device`."""
        # The tensors are drawn in a fixed order so the seeded RNG stream is
        # consumed identically on every run.
        token_ids = torch.randint(0, 24, (batch_size, 128)).long().to(device)
        token_lengths = torch.randint(100, 129, (batch_size,)).long().to(device)
        # Make sure at least one sequence spans the full padded length.
        token_lengths[-1] = 128
        mels = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device)
        mel_lens = torch.randint(20, 30, (batch_size,)).long().to(device)
        spk_ids = torch.randint(0, 5, (batch_size,)).long().to(device)
        return token_ids, token_lengths, mels, mel_lens, spk_ids

    def _test_forward(self, batch_size):
        """Build a fresh model in train mode and push one batch through `forward`."""
        token_ids, token_lengths, mels, mel_lens, _ = self._create_inputs(batch_size)
        model = MatchaTTS(MatchaTTSConfig(num_chars=32)).to(device)
        model.train()
        model.forward(token_ids, token_lengths, mels, mel_lens)

    def test_forward(self):
        """Exercise the forward pass with a single-item batch and a small batch."""
        for batch_size in (1, 3):
            self._test_forward(batch_size)