Skip to content

Commit

Permalink
Merge remote-tracking branch 'subuday/matcha_tts' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmartinrius committed Mar 5, 2024
2 parents 275229a + f15230b commit f6a23c1
Show file tree
Hide file tree
Showing 5 changed files with 461 additions and 0 deletions.
9 changes: 9 additions & 0 deletions TTS/tts/configs/matcha_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from dataclasses import dataclass, field

from TTS.tts.configs.shared_configs import BaseTTSConfig


@dataclass
class MatchaTTSConfig(BaseTTSConfig):
model: str = "matcha_tts"
num_chars: int = None
299 changes: 299 additions & 0 deletions TTS/tts/layers/matcha_tts/UNet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,299 @@
import math
from einops import pack, rearrange
import torch
from torch import nn
import conformer


class PositionalEncoding(torch.nn.Module):
def __init__(self, channels):
super().__init__()
self.channels = channels

def forward(self, x, scale=1000):
if x.ndim < 1:
x = x.unsqueeze(0)
emb = math.log(10000) / (self.channels // 2 - 1)
emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb

class ConvBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, num_groups=8):
super().__init__()
self.block = nn.Sequential(
nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
nn.GroupNorm(num_groups, out_channels),
nn.Mish()
)

def forward(self, x, mask=None):
if mask is not None:
x = x * mask
output = self.block(x)
if mask is not None:
output = output * mask
return output


class ResNetBlock1D(nn.Module):
def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8):
super().__init__()
self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups)
self.mlp = nn.Sequential(
nn.Mish(),
nn.Linear(time_embed_channels, out_channels)
)
self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1)

def forward(self, x, mask, t):
h = self.block_1(x, mask)
h += self.mlp(t).unsqueeze(-1)
h = self.block_2(h, mask)
output = h + self.conv(x * mask)
return output


class Downsample1D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1)

def forward(self, x):
return self.conv(x)


class Upsample1D(nn.Module):
def __init__(self, channels):
super().__init__()
self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1)

def forward(self, x):
return self.conv(x)


class ConformerBlock(conformer.ConformerBlock):
def __init__(
self,
dim: int,
dim_head: int = 64,
heads: int = 8,
ff_mult: int = 4,
conv_expansion_factor: int = 2,
conv_kernel_size: int = 31,
attn_dropout: float = 0.,
ff_dropout: float = 0.,
conv_dropout: float = 0.,
conv_causal: bool = False,
):
super().__init__(
dim=dim,
dim_head=dim_head,
heads=heads,
ff_mult=ff_mult,
conv_expansion_factor=conv_expansion_factor,
conv_kernel_size=conv_kernel_size,
attn_dropout=attn_dropout,
ff_dropout=ff_dropout,
conv_dropout=conv_dropout,
conv_causal=conv_causal,
)

def forward(self, x, mask,):
x = rearrange(x, "b c t -> b t c")
mask = rearrange(mask, "b 1 t -> b t")
output = super().forward(x=x, mask=mask.bool())
return rearrange(output, "b t c -> b c t")


class UNet(nn.Module):
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
num_blocks: int,
transformer_num_heads: int = 4,
transformer_dim_head: int = 64,
transformer_ff_mult: int = 1,
transformer_conv_expansion_factor: int = 2,
transformer_conv_kernel_size: int = 31,
transformer_dropout: float = 0.05,
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels

self.time_encoder = PositionalEncoding(in_channels)
time_embed_channels = model_channels * 4
self.time_embed = nn.Sequential(
nn.Linear(in_channels, time_embed_channels),
nn.SiLU(),
nn.Linear(time_embed_channels, time_embed_channels),
)

self.input_blocks = nn.ModuleList([])
block_in_channels = in_channels * 2
block_out_channels = model_channels
for level in range(num_blocks):
block = nn.ModuleList([])

block.append(
ResNetBlock1D(
in_channels=block_in_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)

block.append(
self._create_transformer_block(
block_out_channels,
dim_head=transformer_dim_head,
num_heads=transformer_num_heads,
ff_mult=transformer_ff_mult,
conv_expansion_factor=transformer_conv_expansion_factor,
conv_kernel_size=transformer_conv_kernel_size,
dropout=transformer_dropout,
)
)

if level != num_blocks - 1:
block.append(Downsample1D(block_out_channels))
else:
block.append(None)

block_in_channels = block_out_channels
self.input_blocks.append(block)

self.middle_blocks = nn.ModuleList([])
for i in range(2):
block = nn.ModuleList([])

block.append(
ResNetBlock1D(
in_channels=block_out_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)

block.append(
self._create_transformer_block(
block_out_channels,
dim_head=transformer_dim_head,
num_heads=transformer_num_heads,
ff_mult=transformer_ff_mult,
conv_expansion_factor=transformer_conv_expansion_factor,
conv_kernel_size=transformer_conv_kernel_size,
dropout=transformer_dropout,
)
)

self.middle_blocks.append(block)

self.output_blocks = nn.ModuleList([])
block_in_channels = block_out_channels * 2
block_out_channels = model_channels
for level in range(num_blocks):
block = nn.ModuleList([])

block.append(
ResNetBlock1D(
in_channels=block_in_channels,
out_channels=block_out_channels,
time_embed_channels=time_embed_channels
)
)

block.append(
self._create_transformer_block(
block_out_channels,
dim_head=transformer_dim_head,
num_heads=transformer_num_heads,
ff_mult=transformer_ff_mult,
conv_expansion_factor=transformer_conv_expansion_factor,
conv_kernel_size=transformer_conv_kernel_size,
dropout=transformer_dropout,
)
)

if level != num_blocks - 1:
block.append(Upsample1D(block_out_channels))
else:
block.append(None)

block_in_channels = block_out_channels * 2
self.output_blocks.append(block)

self.conv_block = ConvBlock1D(model_channels, model_channels)
self.conv = nn.Conv1d(model_channels, self.out_channels, 1)

def _create_transformer_block(
self,
dim,
dim_head: int = 64,
num_heads: int = 4,
ff_mult: int = 1,
conv_expansion_factor: int = 2,
conv_kernel_size: int = 31,
dropout: float = 0.05,
):
return ConformerBlock(
dim=dim,
dim_head=dim_head,
heads=num_heads,
ff_mult=ff_mult,
conv_expansion_factor=conv_expansion_factor,
conv_kernel_size=conv_kernel_size,
attn_dropout=dropout,
ff_dropout=dropout,
conv_dropout=dropout,
conv_causal=False,
)

def forward(self, x_t, mean, mask, t):
t = self.time_encoder(t)
t = self.time_embed(t)

x_t = pack([x_t, mean], "b * t")[0]

hidden_states = []
mask_states = [mask]

for block in self.input_blocks:
res_net_block, transformer, downsample = block

x_t = res_net_block(x_t, mask, t)
x_t = transformer(x_t, mask)

hidden_states.append(x_t)

if downsample is not None:
x_t = downsample(x_t * mask)
mask = mask[:, :, ::2]
mask_states.append(mask)

for block in self.middle_blocks:
res_net_block, transformer = block
mask = mask_states[-1]
x_t = res_net_block(x_t, mask, t)
x_t = transformer(x_t, mask)

for block in self.output_blocks:
res_net_block, transformer, upsample = block

x_t = pack([x_t, hidden_states.pop()], "b * t")[0]
mask = mask_states.pop()
x_t = res_net_block(x_t, mask, t)
x_t = transformer(x_t, mask)

if upsample is not None:
x_t = upsample(x_t * mask)

output = self.conv_block(x_t)
output = self.conv(x_t)

return output * mask
32 changes: 32 additions & 0 deletions TTS/tts/layers/matcha_tts/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
from torch import nn
import torch.nn.functional as F

from TTS.tts.layers.matcha_tts.UNet import UNet


class Decoder(nn.Module):
def __init__(self):
super().__init__()
self.sigma_min = 1e-5
self.predictor = UNet(
in_channels=80,
model_channels=256,
out_channels=80,
num_blocks=2
)

def forward(self, x_1, mean, mask):
"""
Shapes:
- x_1: :math:`[B, C, T]`
- mean: :math:`[B, C ,T]`
- mask: :math:`[B, 1, T]`
"""
t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype)
x_0 = torch.randn_like(x_1)
x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1
u_t = x_1 - (1 - self.sigma_min) * x_0
v_t = self.predictor(x_t, mean, mask, t.squeeze())
loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1])
return loss
Loading

0 comments on commit f6a23c1

Please sign in to comment.