## Load Dataset

In [1]:
pip install datasets transformers

Defaulting to user installation because normal site-packages is not writeable
Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting transformers
  Downloading transformers-4.47.1-py3-none-any.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.1/44.1 kB[0m [31m1.6 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-18.1.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (3.3 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting requests>=2.32.2 (from datasets)
  Downloading requests-2.32.3-py3-none-any.whl.metadata (4.6 kB)
Collecting tqdm>=4.66.3 (from datasets)
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.7/57.7 kB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting xxhash (from datasets)
  Downloading xxhash-3.5.

In [6]:
from datasets import load_dataset, IterableDataset, IterableDatasetDict, load_from_disk

# Load the dataset previously saved on disk under ../data/hifi_tts/ (path relative to the notebook).
ds = load_from_disk("../data/hifi_tts/")

Loading dataset from disk:   0%|          | 0/35 [00:00<?, ?it/s]

In [7]:
# Display the dataset summary (feature names and number of rows).
ds

Dataset({
    features: ['speaker', 'file', 'duration', 'text', 'text_no_preprocessing', 'text_normalized', 'audio'],
    num_rows: 125989
})

## Model Firefly

In [1]:
import math
from functools import partial
from math import prod
from typing import Callable

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.parametrizations import weight_norm
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.checkpoint import checkpoint

In [2]:
def sequence_mask(length, max_length=None):
    """Build a boolean mask that marks the valid positions of each sequence.

    Args:
        length: Tensor of per-sequence lengths (assumed 1-D -- one entry per sequence).
        max_length: Width of the mask. When omitted, the largest value in ``length`` is used.

    Returns:
        Boolean tensor whose entry ``[b, t]`` is ``True`` exactly when ``t < length[b]``.
    """
    width = length.max() if max_length is None else max_length
    # Positions share dtype/device with `length` so the comparison needs no conversion.
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]

In [3]:
def init_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from a normal distribution if the class name contains "Conv1D".

    Intended for use with ``module.apply(init_weights)``.

    Args:
        m: Module to (possibly) re-initialise in place.
        mean: Mean of the normal distribution.
        std: Standard deviation of the normal distribution.

    NOTE(review): the substring test is case-sensitive, so a class named ``Conv1d`` does not
    match it -- confirm which module classes this is meant to target.
    """
    if "Conv1D" in type(m).__name__:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    """Return half of the dilated kernel span, ``(kernel_size * dilation - dilation) // 2``.

    Args:
        kernel_size: Kernel size of the convolution.
        dilation: Dilation of the convolution.

    Returns:
        The integer (floor) half of ``kernel_size * dilation - dilation``.
    """
    dilated_span = kernel_size * dilation - dilation
    return dilated_span // 2

In [4]:
def unpad1d(x: torch.Tensor, paddings: tuple[int, int]):
    """Strip ``paddings = (left, right)`` samples from the last dimension of ``x``.

    The end index is computed explicitly, so a right padding of 0 keeps the tail intact
    (a plain ``x[..., left:-0]`` slice would be empty).

    Args:
        x: Tensor to trim along its last dimension.
        paddings: Non-negative ``(left, right)`` amounts to remove.

    Returns:
        The trimmed view of ``x``.
    """
    left, right = paddings
    assert left >= 0 and right >= 0, (left, right)
    total = x.shape[-1]
    assert (left + right) <= total
    return x[..., left : total - right]

In [5]:
def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """Return how many samples must be appended so that the last conv window is complete.

    Args:
        x: Input tensor; only the size of its last dimension is used.
        kernel_size: (Effective) kernel size of the convolution.
        stride: Stride of the convolution.
        padding_total: Padding that will already be applied to the input.

    Returns:
        Difference between the length that yields a whole number of frames and the
        current length of ``x``.
    """
    current = x.shape[-1]
    # Possibly fractional frame count, rounded up to the next whole frame.
    frames = math.ceil((current - kernel_size + padding_total) / stride + 1)
    target = (frames - 1) * stride + (kernel_size - padding_total)
    return target - current

In [6]:
def pad1d(
    x: torch.Tensor,
    paddings: tuple[int, int],
    mode: str = "zeros",
    value: float = 0.0,
):
    """Pad the last dimension of ``x``; a thin wrapper around ``F.pad``.

    Two cases are handled on top of ``F.pad``:

    * ``mode="zeros"`` (the default, also accepted as ``"zero"``) is translated to constant
      padding with 0, because ``F.pad`` itself has no mode of that name.
    * ``mode="reflect"`` on an input that is not longer than the requested padding: extra
      zeros are first appended on the right so the reflection is possible, and the same
      number of samples is removed from the right of the result afterwards.

    Args:
        x: Tensor to pad along its last dimension.
        paddings: Non-negative ``(left, right)`` padding amounts.
        mode: ``"zeros"``/``"zero"`` or any mode understood by ``F.pad``.
        value: Fill value forwarded to ``F.pad`` (ignored for zero padding).

    Returns:
        The padded tensor, ``left + right`` samples longer than ``x``.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)

    if mode in ("zero", "zeros"):
        # Previously the default mode was passed straight to F.pad, which rejects it.
        return F.pad(x, paddings, "constant", 0.0)

    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            # Reflection needs an input longer than the padding; extend it temporarily.
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]

    return F.pad(x, paddings, mode, value)

In [7]:
class ConvolutionalNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1, groups=1):
        super(ConvolutionalNet, self).__init__()
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            groups=groups,
        )
        self.stride = stride
        self.kernel_size = (kernel_size - 1) * dilation + 1
        self.dilation = dilation

    def forward(self, x):
        pad = self.kernel_size - self.stride
        extra_padding = get_extra_padding_for_conv1d(
            x, self.kernel_size, self.stride, pad
        )
        x = pad1d(x, (pad, extra_padding), mode="constant", value = 0) # Cari tahu value ini apa
        return self.conv(x).contiguous()

    def weight_norm(self, name="weight", dim = 0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_parametrizations(self, name="weight"):
        self.conv = remove_parametrizations(self.conv, name)
        return self

In [8]:
class TransConvNet(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, stride=1):
        super(TransConvNet, self).__init__()
        self.conv = nn.ConvTranspose1d(
            in_channels, out_channels, kernel_size, stride=stride, dilation=dilation
        )
        self.stride = stride
        self.kernel_size = kernel_size

    def forward(self, x):
        x = self.conv(x)
        pad = self.kernel_size - self.stride
        padding_right = math.ceil(pad)
        padding_left = pad - padding_right
        x = unpad1d(x, (padding_left, padding_right))
        return x.contiguous()

    def weight_norm(self, name="weight", dim=0):
        self.conv = weight_norm(self.conv, name=name, dim=dim)
        return self

    def remove_parametrizations(self, name="weight"):
        self.conv = remove_parametrizations(self.conv, name)
        return self

In [9]:
class ResBlock1(torch.nn.Module):
    """Residual block made of three (SiLU -> conv -> SiLU -> conv) units with skip connections.

    Two parallel lists of three weight-normalised ``ConvolutionalNet`` layers are built;
    unit ``i`` uses ``convs1[i]`` followed by ``convs2[i]``, both with ``dilation[i]``.

    Args:
        channels: Number of input and output channels of every convolution.
        kernel_size: Kernel size shared by all convolutions.
        dilation: Sequence whose first three entries give the dilation of each unit.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super().__init__()

        def build_convs():
            # One weight-normalised conv per dilation entry (only the first three are used).
            return nn.ModuleList(
                [
                    ConvolutionalNet(
                        channels, channels, kernel_size, stride=1, dilation=dilation[idx]
                    ).weight_norm()
                    for idx in range(3)
                ]
            )

        self.convs1 = build_convs()
        self.convs1.apply(init_weights)

        self.convs2 = build_convs()
        self.convs2.apply(init_weights)

    def forward(self, x):
        """Run the three residual units in sequence."""
        for first, second in zip(self.convs1, self.convs2):
            hidden = first(F.silu(x))
            hidden = second(F.silu(hidden))
            x = hidden + x
        return x

    def remove_parametrizations(self):
        """Remove the weight parametrization from every convolution of the block."""
        for conv in (*self.convs1, *self.convs2):
            conv.remove_parametrizations()

In [10]:
class ParallelBlock(nn.Module):
    """Runs several ``ResBlock1`` branches on the same input and averages their outputs.

    Args:
        channels: Channel count passed to every branch.
        kernel_sizes: Kernel size of each branch.
        dilation_sizes: Dilation tuple of each branch; must have as many entries as
            ``kernel_sizes``.
    """

    def __init__(
        self,
        channels: int,
        kernel_sizes: tuple[int] = (3, 7, 11),
        dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
    ):
        super().__init__()

        assert len(kernel_sizes) == len(dilation_sizes)

        branches = [
            ResBlock1(channels, kernel, dilations)
            for kernel, dilations in zip(kernel_sizes, dilation_sizes)
        ]
        self.blocks = nn.ModuleList(branches)

    def forward(self, x):
        """Return the element-wise mean of all branch outputs."""
        branch_outputs = [branch(x) for branch in self.blocks]
        stacked = torch.stack(branch_outputs, dim=0)
        return stacked.mean(dim=0)

    def remove_parametrizations(self):
        """Remove the weight parametrization from every branch."""
        for branch in self.blocks:
            branch.remove_parametrizations()

In [11]:
class HiFiGANGenerator(nn.Module):
    """Generator head: pre-conv, a chain of upsampling stages, post-conv and ``tanh``.

    Every stage applies SiLU, a weight-normalised ``TransConvNet`` that halves the channel
    count, and a ``ParallelBlock``. The product of ``upsample_rates`` must equal
    ``hop_length``.

    Args (keyword-only):
        hop_length: Expected overall upsampling factor.
        upsample_rates: Stride of each upsampling stage.
        upsample_kernel_sizes: Kernel size of each upsampling stage.
        resblock_kernel_sizes: Kernel sizes handed to every ``ParallelBlock``.
        resblock_dilation_sizes: Dilations handed to every ``ParallelBlock``.
        num_mels: Number of input channels of the pre-convolution.
        upsample_initial_channel: Channels after the pre-convolution.
        pre_conv_kernel_size: Kernel size of the pre-convolution.
        post_conv_kernel_size: Kernel size of the post-convolution.
        post_activation: Factory called without arguments to build the activation that is
            applied before the post-convolution.
    """

    def __init__(
        self,
        *,
        hop_length: int = 512,
        upsample_rates: tuple[int] = (8, 8, 2, 2, 2),
        upsample_kernel_sizes: tuple[int] = (16, 16, 8, 2, 2),
        resblock_kernel_sizes: tuple[int] = (3, 7, 11),
        resblock_dilation_sizes: tuple[tuple[int]] = ((1, 3, 5), (1, 3, 5), (1, 3, 5)),
        num_mels: int = 128,
        upsample_initial_channel: int = 512,
        pre_conv_kernel_size: int = 7,
        post_conv_kernel_size: int = 7,
        post_activation: Callable = partial(nn.SiLU, inplace=True),
    ):
        super().__init__()

        total_rate = prod(upsample_rates)
        assert total_rate == hop_length, f"hop_length must be {total_rate}"

        self.conv_pre = ConvolutionalNet(
            num_mels, upsample_initial_channel, pre_conv_kernel_size, stride=1
        ).weight_norm()

        self.num_upsamples = len(upsample_rates)
        self.num_kernels = len(resblock_kernel_sizes)

        self.noise_convs = nn.ModuleList()
        self.ups = nn.ModuleList()

        # Stage i maps C / 2**i channels to C / 2**(i + 1) channels.
        for stage, (rate, kernel) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            stage_in = upsample_initial_channel // (2**stage)
            stage_out = upsample_initial_channel // (2 ** (stage + 1))
            self.ups.append(
                TransConvNet(stage_in, stage_out, kernel, stride=rate).weight_norm()
            )

        self.resblocks = nn.ModuleList()

        for stage in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (stage + 1))
            self.resblocks.append(
                ParallelBlock(ch, resblock_kernel_sizes, resblock_dilation_sizes)
            )

        self.activation_post = post_activation()

        # `ch` still holds the channel count of the last stage here.
        self.conv_post = ConvolutionalNet(
            ch, 1, post_conv_kernel_size, stride=1
        ).weight_norm()

        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        """Map input features to a single-channel output squashed by ``tanh``."""
        hidden = self.conv_pre(x)

        for stage in range(self.num_upsamples):
            hidden = F.silu(hidden, inplace=True)
            hidden = self.ups[stage](hidden)
            hidden = self.resblocks[stage](hidden)

        hidden = self.activation_post(hidden)
        return torch.tanh(self.conv_post(hidden))

    def remove_parametrizations(self):
        """Remove the weight parametrization from all convolutions of the generator."""
        for up in self.ups:
            up.remove_parametrizations()
        for block in self.resblocks:
            block.remove_parametrizations()
        self.conv_pre.remove_parametrizations()
        self.conv_post.remove_parametrizations()

In [12]:
def drop_path(
    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
):
    """Stochastic depth: zero out whole samples of ``x`` with probability ``drop_prob``.

    One Bernoulli draw is made per entry of the first dimension and broadcast over all
    remaining dimensions, so the function works for tensors of any rank.

    Args:
        x: Input tensor; the first dimension indexes the samples.
        drop_prob: Probability of dropping a sample.
        training: Dropping is only applied when this is ``True``.
        scale_by_keep: When ``True`` (and the keep probability is positive), surviving
            samples are divided by the keep probability.

    Returns:
        ``x`` itself when nothing is dropped, otherwise ``x`` multiplied by the random mask.
    """
    if not training or drop_prob == 0.0:
        return x

    survival = 1 - drop_prob
    # Shape (batch, 1, 1, ...): one draw per sample, broadcast over the rest.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survival)
    if scale_by_keep and survival > 0.0:
        mask.div_(survival)
    return x * mask

In [13]:
class DropPath(nn.Module):
    """Module form of ``drop_path`` (stochastic depth applied per sample).

    Args:
        drop_prob: Probability of dropping a sample.
        scale_by_keep: Forwarded to ``drop_path``.
    """

    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
        super().__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        """Apply ``drop_path`` using the module's training flag."""
        return drop_path(
            x,
            drop_prob=self.drop_prob,
            training=self.training,
            scale_by_keep=self.scale_by_keep,
        )

    def extra_repr(self):
        """Show the drop probability with three decimals in the module repr."""
        rounded = round(self.drop_prob, 3)
        return f"drop_prob={rounded:0.3f}"

In [14]:
class LayerNorm(nn.Module):
    r"""Layer normalisation over the channel axis for two memory layouts.

    * ``channels_last`` (default): channels are the last dimension and ``F.layer_norm``
      is used directly.
    * ``channels_first``: channels are dimension 1; statistics are computed over that
      dimension by hand. The affine parameters are broadcast as ``weight[:, None]``, which
      fits an input with exactly one dimension after the channel axis.

    Args:
        normalized_shape: Number of channels.
        eps: Value added to the variance for numerical stability.
        data_format: ``"channels_last"`` or ``"channels_first"``.

    Raises:
        NotImplementedError: If ``data_format`` is neither of the two supported values.
    """  # noqa: E501

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        supported = ("channels_last", "channels_first")
        if data_format not in supported:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        """Normalise ``x`` over its channel axis and apply the affine parameters."""
        if self.data_format == "channels_first":
            mean = x.mean(1, keepdim=True)
            centered = x - mean
            variance = centered.pow(2).mean(1, keepdim=True)
            normed = centered / torch.sqrt(variance + self.eps)
            return self.weight[:, None] * normed + self.bias[:, None]
        if self.data_format == "channels_last":
            return F.layer_norm(
                x, self.normalized_shape, self.weight, self.bias, self.eps
            )

In [15]:
class ConvNeXtBlock(nn.Module):
    r"""1-D ConvNeXt block.

    Pipeline: depthwise conv -> permute to (N, L, C) -> LayerNorm (channels_last) ->
    Linear -> GELU -> Linear -> optional layer scale -> permute back to (N, C, L) ->
    drop path -> optional residual connection.

    Args:
        dim (int): Number of input channels.
        drop_path (float): Stochastic depth rate. Default: 0.0
        layer_scale_init_value (float): Init value for Layer Scale; a non-positive value
            disables layer scale. Default: 1e-6.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
        kernel_size (int): Kernel size for depthwise conv. Default: 7.
        dilation (int): Dilation for depthwise conv. Default: 1.
    """  # noqa: E501

    def __init__(
        self,
        dim: int,
        drop_path: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        mlp_ratio: float = 4.0,
        kernel_size: int = 7,
        dilation: int = 1,
    ):
        super().__init__()

        # The documented `dilation` argument used to be dropped here; it is now forwarded.
        # ConvolutionalNet derives its padding from the dilated kernel span, so the output
        # length is unaffected.
        self.dwconv = ConvolutionalNet(
            dim,
            dim,
            kernel_size=kernel_size,
            dilation=dilation,
            groups=dim,
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6)
        hidden_dim = int(mlp_ratio * dim)
        # pointwise/1x1 convs, implemented with linear layers
        self.pwconv1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Linear(hidden_dim, dim)
        self.gamma = (
            nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, apply_residual: bool = True):
        """Apply the block.

        Args:
            x: Input in (N, C, L) layout, as implied by the permutes below.
            apply_residual: When ``True`` the input is added back to the block output.

        Returns:
            Tensor in (N, C, L) layout.
        """
        residual = x

        x = self.dwconv(x)
        x = x.permute(0, 2, 1)  # (N, C, L) -> (N, L, C)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * x

        x = x.permute(0, 2, 1)  # (N, L, C) -> (N, C, L)
        x = self.drop_path(x)

        if apply_residual:
            x = residual + x

        return x

In [27]:
class ConvNeXtEncoder(nn.Module):
    """Stack of ConvNeXt stages, each preceded by a channel-changing projection layer.

    The first projection ("stem") is a ``ConvolutionalNet`` with a fixed kernel size of 7
    followed by a channels-first ``LayerNorm``; later projections are a channels-first
    ``LayerNorm`` followed by a kernel-size-1 ``nn.Conv1d``. ``kernel_size`` only affects
    the ``ConvNeXtBlock`` layers inside the stages.

    Args:
        input_channels: Channels of the input tensor.
        depths: Number of ``ConvNeXtBlock`` layers per stage.
        dims: Channel width per stage; must have the same length as ``depths``.
        drop_path_rate: Final stochastic-depth rate; per-block rates grow linearly from 0.
        layer_scale_init_value: Forwarded to every ``ConvNeXtBlock``.
        kernel_size: Depthwise kernel size of every ``ConvNeXtBlock``.
    """

    def __init__(
        self,
        input_channels: int = 3,
        depths: list[int] = [3, 3, 9, 3],
        dims: list[int] = [96, 192, 384, 768],
        drop_path_rate: float = 0.0,
        layer_scale_init_value: float = 1e-6,
        kernel_size: int = 7,
    ):
        super().__init__()
        assert len(depths) == len(dims)

        self.downsample_layers = nn.ModuleList()
        self.downsample_layers.append(
            nn.Sequential(
                ConvolutionalNet(
                    input_channels,
                    dims[0],
                    kernel_size=7,
                ),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            )
        )

        # One projection between every pair of consecutive stage widths.
        for width_in, width_out in zip(dims[:-1], dims[1:]):
            self.downsample_layers.append(
                nn.Sequential(
                    LayerNorm(width_in, eps=1e-6, data_format="channels_first"),
                    nn.Conv1d(width_in, width_out, kernel_size=1),
                )
            )

        self.stages = nn.ModuleList()
        dp_rates = torch.linspace(0, drop_path_rate, sum(depths)).tolist()

        offset = 0
        for depth, width in zip(depths, dims):
            blocks = [
                ConvNeXtBlock(
                    dim=width,
                    drop_path=dp_rates[offset + j],
                    layer_scale_init_value=layer_scale_init_value,
                    kernel_size=kernel_size,
                )
                for j in range(depth)
            ]
            self.stages.append(nn.Sequential(*blocks))
            offset += depth

        self.norm = LayerNorm(dims[-1], eps=1e-6, data_format="channels_first")
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights and zero bias for every ``Conv1d``/``Linear`` module."""
        if not isinstance(m, (nn.Conv1d, nn.Linear)):
            return
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.constant_(m.bias, 0)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Run every projection/stage pair in order, then the final normalisation."""
        for project, stage in zip(self.downsample_layers, self.stages):
            x = stage(project(x))

        return self.norm(x)

In [28]:
class FireflyArchitecture(nn.Module):
    """Composite model: optional spectrogram transform -> backbone -> optional quantizer -> head.

    Args:
        backbone: Encoder applied to the (transformed) input.
        head: Decoder applied to the (quantized) features.
        quantizer: Optional quantizer. It must expose ``downsample_factor`` (an iterable of
            ints) and return an object with a ``.z`` attribute when called. ``encode`` and
            ``decode`` additionally need its ``encode``/``decode`` methods.
        spec_transform: Optional transform applied first in ``forward``. ``encode`` and
            ``decode`` require it because they read ``spec_transform.hop_length``.
    """

    def __init__(
        self,
        backbone: nn.Module,
        head: nn.Module,
        quantizer: nn.Module,
        spec_transform: nn.Module,
    ):
        super().__init__()

        self.backbone = backbone
        self.head = head
        self.quantizer = quantizer
        self.spec_transform = spec_transform
        # forward() treats the quantizer as optional, so construction must not fail without
        # one; with no quantizer there is no extra temporal downsampling (factor 1).
        self.downsample_factor = (
            math.prod(quantizer.downsample_factor) if quantizer is not None else 1
        )

    def forward(self, x: torch.Tensor, template=None, mask=None) -> tuple:
        """Run the full pipeline.

        Args:
            x: Model input (passed through ``spec_transform`` first when one is set).
            template: Unused.
            mask: Optional tensor multiplied into the features after the backbone and
                again after the quantizer.

        Returns:
            ``(output, vq_result)`` where ``vq_result`` is whatever the quantizer returned,
            or ``None`` when no quantizer is configured. A 2-D head output gets a
            singleton channel dimension inserted at position 1.
        """
        if self.spec_transform is not None:
            x = self.spec_transform(x)

        x = self.backbone(x)
        if mask is not None:
            x = x * mask

        # Defined up-front so the return statement is valid without a quantizer.
        vq_result = None
        if self.quantizer is not None:
            vq_result = self.quantizer(x)
            x = vq_result.z

            if mask is not None:
                x = x * mask

        x = self.head(x)

        if x.ndim == 2:
            x = x[:, None, :]

        return x, vq_result

    def encode(self, audios, audio_lengths):
        """Encode audio into quantizer indices.

        Args:
            audios: Input audio; cast to float before the spectrogram transform.
            audio_lengths: Per-item audio lengths, used to build the validity mask.

        Returns:
            ``(indices, feature_lengths)`` where ``feature_lengths`` is the mel length
            divided by ``downsample_factor``.
        """
        audios = audios.float()

        mels = self.spec_transform(audios)
        mel_lengths = audio_lengths // self.spec_transform.hop_length
        mel_masks = sequence_mask(mel_lengths, mels.shape[2])
        mel_masks_float_conv = mel_masks[:, None, :].float()
        mels = mels * mel_masks_float_conv

        # Encode
        encoded_features = self.backbone(mels) * mel_masks_float_conv
        feature_lengths = mel_lengths // self.downsample_factor

        return self.quantizer.encode(encoded_features), feature_lengths

    def decode(self, indices, feature_lengths) -> tuple[torch.Tensor, torch.Tensor]:
        """Decode quantizer indices back through the head.

        Args:
            indices: Quantizer indices; ``indices.shape[2]`` is taken as the feature length.
            feature_lengths: Per-item valid feature lengths.

        Returns:
            ``(output, audio_lengths)`` with positions beyond each item's length zeroed.
        """
        mel_masks = sequence_mask(
            feature_lengths * self.downsample_factor,
            indices.shape[2] * self.downsample_factor,
        )
        mel_masks_float_conv = mel_masks[:, None, :].float()
        audio_lengths = (
            feature_lengths * self.downsample_factor * self.spec_transform.hop_length
        )

        audio_masks = sequence_mask(
            audio_lengths,
            indices.shape[2] * self.downsample_factor * self.spec_transform.hop_length,
        )
        audio_masks_float_conv = audio_masks[:, None, :].float()

        z = self.quantizer.decode(indices) * mel_masks_float_conv
        x = self.head(z) * audio_masks_float_conv

        return x, audio_lengths

    def remove_parametrizations(self):
        """Call ``remove_parametrizations`` on the backbone and head when they provide it."""
        if hasattr(self.backbone, "remove_parametrizations"):
            self.backbone.remove_parametrizations()

        if hasattr(self.head, "remove_parametrizations"):
            self.head.remove_parametrizations()

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.parameters()).device

## GFSQ

In [29]:
pip install vector_quantize_pytorch

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [30]:
## MODIFICATION

from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from vector_quantize_pytorch import GroupedResidualFSQ

@dataclass
class FSQResult:
    # Represents quantization outputs (as produced by DownsampleFSQ.forward)
    z: torch.Tensor        # quantized features, after the upsample path and size adjustment
    codes: torch.Tensor    # quantizer indices with their last two dims swapped (indices.mT)
    latents: torch.Tensor  # output of the downsample path, before quantization

class DownsampleFSQ(nn.Module):
    """Downsamples features, quantizes them with grouped residual FSQ, and upsamples back.

    Args:
        input_dim: Channels of the input features.
        n_codebooks: Number of residual quantizers.
        n_groups: Number of groups of the grouped quantizer.
        levels: FSQ levels passed to the quantizer.
        downsample_factor: Stride of each downsampling stage (mirrored when upsampling).
        downsample_dims: Channels after each downsampling stage; defaults to ``input_dim``
            for every stage.
    """

    def __init__(
        self,
        input_dim=512,
        n_codebooks=9,
        n_groups=1,
        levels=(8, 5, 5, 5),
        downsample_factor=(2, 2),
        downsample_dims=None
    ):
        super().__init__()

        # Handle default case for downsample dimensions
        self.downsample_dims = downsample_dims or [input_dim] * len(downsample_factor)
        self.dims_sequence = [input_dim] + list(self.downsample_dims)

        # Configure quantizer
        self._setup_quantizer(self.dims_sequence[-1], levels, n_codebooks, n_groups)

        # Build network components
        self._build_downsample_network(downsample_factor)
        self._build_upsample_network(downsample_factor)

        # Store configuration
        self.downsample_factor = downsample_factor

        # Initialize weights
        self._initialize_network_weights()

    def _setup_quantizer(self, dim, levels, n_codebooks, n_groups):
        """Configures the FSQ quantizer with specified parameters"""
        self.residual_fsq = GroupedResidualFSQ(
            dim=dim,
            levels=levels,
            num_quantizers=n_codebooks,
            groups=n_groups
        )

    def _build_downsample_network(self, factors):
        """Constructs the downsampling pathway (one strided conv block per factor)."""
        layers = [
            self._create_conv_block(
                self.dims_sequence[idx],
                self.dims_sequence[idx + 1],
                factor,
            )
            for idx, factor in enumerate(factors)
        ]
        self.downsample = nn.Sequential(*layers)

    def _build_upsample_network(self, factors):
        """Constructs the upsampling pathway, mirroring the downsampling stages in reverse.

        Each stage uses a transposed convolution. Previously the same strided convolution
        as the downsample path was used here, which shortened the sequence a second time
        instead of restoring its length.
        """
        layers = [
            self._create_conv_block(
                self.dims_sequence[idx + 1],
                self.dims_sequence[idx],
                factors[idx],
                upsample=True,
            )
            for idx in reversed(range(len(factors)))
        ]
        self.upsample = nn.Sequential(*layers)

    def _create_conv_block(self, in_dim, out_dim, factor, upsample=False):
        """Creates one resampling block: a (transposed) strided conv followed by a ConvNeXt block.

        Args:
            in_dim: Input channels.
            out_dim: Output channels.
            factor: Used as both kernel size and stride of the resampling convolution.
            upsample: When ``True`` a ``TransConvNet`` is used, otherwise a
                ``ConvolutionalNet``.
        """
        resampler_cls = TransConvNet if upsample else ConvolutionalNet
        return nn.Sequential(
            resampler_cls(in_dim, out_dim, kernel_size=factor, stride=factor),
            ConvNeXtBlock(dim=out_dim)
        )

    def _initialize_network_weights(self):
        """Initializes Conv1d/Linear weights with a truncated normal and zero bias."""
        def init_weights(m):
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.trunc_normal_(m.weight, std=0.02)
                nn.init.constant_(m.bias, 0)
        self.apply(init_weights)

    def _adjust_output_size(self, tensor, target_size):
        """Pads (with zeros) or trims the last dimension, centred, to ``target_size``."""
        current_size = tensor.shape[-1]
        size_diff = target_size - current_size

        if size_diff == 0:
            return tensor

        if size_diff > 0:
            pad_left = size_diff // 2
            pad_right = size_diff - pad_left
            return F.pad(tensor, (pad_left, pad_right))

        # Too long: drop half of the surplus from the left, the remainder from the right.
        surplus = -size_diff
        trim_left = surplus // 2
        trim_right = current_size - (surplus - trim_left)
        return tensor[..., trim_left:trim_right]

    def forward(self, z) -> FSQResult:
        """Downsample, quantize and upsample ``z``.

        Returns:
            ``FSQResult`` whose ``z`` has the same last-dimension size as the input.
        """
        # Store original dimensions
        target_size = z.shape[-1]

        # Process through networks; the quantizer is fed with the last two dims swapped.
        compressed = self.downsample(z)
        quantized, indices = self.residual_fsq(compressed.mT)

        # Create initial result
        result = FSQResult(
            z=quantized.mT,
            codes=indices.mT,
            latents=compressed
        )

        # Upsample and adjust dimensions
        result.z = self.upsample(result.z)
        result.z = self._adjust_output_size(result.z, target_size)

        return result

    def encode(self, z):
        """Return quantizer indices rearranged from ``g b l r`` to ``b (g r) l``."""
        compressed = self.downsample(z)
        _, indices = self.residual_fsq(compressed.mT)
        return rearrange(indices, "g b l r -> b (g r) l")

    def decode(self, indices):
        """Reconstruct features from indices laid out as ``b (g r) l`` and upsample them."""
        # Reshape indices for processing
        grouped_indices = rearrange(
            indices,
            "b (g r) l -> g b l r",
            g=self.residual_fsq.groups
        )

        # Reconstruct signal
        quantized = self.residual_fsq.get_output_from_indices(grouped_indices)
        return self.upsample(quantized.mT)

## U-Net untuk LLDM

Arsitektur U-Net yang diusulkan menggunakan 4 level network, di mana setiap level terdiri atas 2 residual network dan 1 attention layer (pada middle block strukturnya R-A-R; R: Residual, A: Attention).

U-Net ini juga perlu memproses timestep menjadi embedding. Attention mechanism yang digunakan memuat rotary embedding, yang berperan penting untuk menangkap dan mempelajari fitur temporal (acuan berasal dari LLaMA). Selain itu, attention yang digunakan berupa self-attention.

In [52]:
# Encoder: takes 160-channel inputs; its last stage width (384) has to match the
# quantizer's input_dim and the head's num_mels configured below.
backbone = ConvNeXtEncoder(
    input_channels = 160,
    depths = [3, 3, 9, 3],
    dims = [128, 256, 288, 384],
    drop_path_rate = 0.2,
    kernel_size = 7
)
# Decoder head: the product of upsample_rates (8 * 8 * 2 * 2 * 2 = 512) must equal
# hop_length, which HiFiGANGenerator asserts on construction.
head = HiFiGANGenerator(
    hop_length=512,
    upsample_rates=[8, 8, 2, 2, 2],
    upsample_kernel_sizes=[16, 16, 4, 4, 4],
    resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    num_mels = 384, # class default here is 128; "512" presumably refers to a reference config -- verify
    upsample_initial_channel=384, # class default is 512
    pre_conv_kernel_size=13,
    post_conv_kernel_size=13
)
# Quantizer: two stride-2 stages, i.e. an overall downsample factor of 4.
quantizer = DownsampleFSQ(
    input_dim=384,
    n_groups=8,
    n_codebooks=1,
    levels=[8, 5, 5, 5],
    downsample_factor=[2, 2]
)

In [49]:
# Scratch estimate: parameter count of one ResBlock1(channels=512), multiplied by 18.
# NOTE(review): the origin of the factor 18 is not shown here -- document or remove.
sum(p.numel() for p in ResBlock1(channels=512).parameters()) * 18

85045248

In [53]:
# Assemble the full model. With spec_transform=None, forward() skips the transform and
# feeds its input straight to the backbone; encode()/decode() read
# spec_transform.hop_length and therefore cannot be used with this instance.
model = FireflyArchitecture(
    backbone=backbone,
    head=head,
    quantizer=quantizer,
    spec_transform=None
)

In [54]:
# Smoke test: push a random input with 160 channels and length 256 through the full model.
x, vq = model(torch.randn(1, 160, 256))

In [55]:
# Run the encoder and the quantizer separately to inspect the intermediate results.
encoded = backbone(torch.randn(1, 160, 256))
quantized = quantizer(encoded)

In [56]:
# Shape of the head's output for the quantized features.
head(quantized.z).shape

torch.Size([1, 1, 131072])

In [55]:
# Encoder output for a shorter (length-124) random input, used by the scratch checks below.
intermediate = backbone(torch.randn(1, 160, 124))

In [58]:
# Scratch check: a kernel-3, padding-1 Conv1d keeps the sequence length unchanged.
# NOTE(review): the backbone configured in this notebook ends at 384 channels, so this
# 768-channel conv (and its recorded output) looks like it comes from an earlier
# configuration -- verify on a fresh kernel.
nn.Conv1d(in_channels=768, out_channels=768, kernel_size=3, stride=1, padding=1)(intermediate).shape

torch.Size([1, 768, 124])

# Inspect the shape of the encoder output computed above.
intermediate.shape

In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange

class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable scale.

    Args:
        dim: Size of the last dimension (length of the scale vector).
        eps: Value added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the root mean square of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to the input dtype, then apply the scale."""
        normed = self._norm(x.float())
        return normed.type_as(x) * self.scale

class SinusoidalTimeEmbedding(nn.Module):
    """
    Time embedding using sinusoidal positions, similar to transformer positional encoding.
    This helps the model understand the diffusion timestep.

    The sinusoidal features (``dim // 2`` sines followed by ``dim // 2`` cosines) are passed
    through a two-layer MLP (``dim -> 4 * dim -> dim``) with a SiLU in between.

    Args:
        dim: Embedding size; must be a positive even integer so that the concatenated
            sine/cosine features have exactly ``dim`` entries.

    Raises:
        ValueError: If ``dim`` is odd or smaller than 2.
    """
    def __init__(self, dim):
        super().__init__()
        # An odd dim would yield dim - 1 features and only fail later inside the MLP.
        if dim < 2 or dim % 2 != 0:
            raise ValueError(f"dim must be a positive even integer, got {dim}")
        self.dim = dim
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.SiLU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """Embed timesteps.

        Args:
            x: Timesteps; indexed as ``x[:, None]``, so a 1-D tensor is expected.

        Returns:
            Tensor with ``dim`` features per timestep.
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids a division by zero when dim == 2 (a single frequency).
        log_step = math.log(10000) / max(half_dim - 1, 1)
        frequencies = torch.exp(torch.arange(half_dim, device=x.device) * -log_step)
        angles = x[:, None] * frequencies[None, :]
        features = torch.cat((angles.sin(), angles.cos()), dim=-1)
        return self.mlp(features)

class DownBlock(nn.Module):
    """
    Downsampling block: two residual blocks, a stride-2 convolution and optional attention.

    Args:
        in_channels: Channels of the input and of both residual blocks.
        out_channels: Channels produced by the downsampling convolution.
        time_dim: Accepted but not used by this block.
        use_attention: When ``True`` an ``RMSNorm`` + ``RotaryAttention`` pair is applied
            after downsampling (``RotaryAttention``/``ModelArgs`` are expected to be
            defined elsewhere in the notebook).
    """
    def __init__(self, in_channels, out_channels, time_dim, use_attention=False):
        super().__init__()
        self.res1 = ResBlock1(in_channels)
        self.res2 = ResBlock1(in_channels)
        self.downsample = nn.Conv1d(in_channels, out_channels, 4, stride=2, padding=1)

        if not use_attention:
            self.attention = None
            return

        # Attention operates on (B, L, C), hence the norm over the channel dimension.
        self.norm = RMSNorm(out_channels)
        self.attention = RotaryAttention(
            ModelArgs(dim=out_channels, n_heads=8, n_kv_heads=4),
            dim=out_channels,
            n_heads=8,
            n_kv_heads=4,
            max_seq_len=2048
        )

    def forward(self, x):
        """Apply the residual blocks, downsample, then (optionally) attend."""
        hidden = self.downsample(self.res2(self.res1(x)))

        if self.attention is None:
            return hidden

        # [B, C, L] -> [B, L, C] for normalisation/attention, then back.
        tokens = self.norm(hidden.transpose(1, 2))
        return self.attention(tokens).transpose(1, 2)

class UpBlock(nn.Module):
    """
    Upsampling block: two residual blocks, a stride-2 transposed convolution, an optional
    skip concatenation and optional attention.

    Args:
        in_channels: Channels of the input and of both residual blocks.
        out_channels: Channels produced by the upsampling convolution.
        time_dim: Accepted but not used by this block.
        use_attention: When ``True`` an ``RMSNorm`` + ``RotaryAttention`` pair is applied
            last (``RotaryAttention``/``ModelArgs`` are expected to be defined elsewhere
            in the notebook).

    NOTE(review): the norm and attention are sized for ``out_channels``, but when
    ``skip_x`` is given the concatenated tensor has ``out_channels`` plus the skip's
    channels -- verify the intended order of concatenation and attention.
    """
    def __init__(self, in_channels, out_channels, time_dim, use_attention=False):
        super().__init__()
        self.res1 = ResBlock1(in_channels)
        self.res2 = ResBlock1(in_channels)
        self.upsample = TransConvNet(in_channels, out_channels, 4, stride=2)

        if not use_attention:
            self.attention = None
            return

        self.norm = RMSNorm(out_channels)
        self.attention = RotaryAttention(
            ModelArgs(dim=out_channels, n_heads=8, n_kv_heads=4),
            dim=out_channels,
            n_heads=8,
            n_kv_heads=4,
            max_seq_len=2048
        )

    def forward(self, x, skip_x=None):
        """Apply the residual blocks, upsample, concatenate the skip, then (optionally) attend."""
        hidden = self.upsample(self.res2(self.res1(x)))

        if skip_x is not None:
            # Skip features are appended along the channel dimension.
            hidden = torch.cat([hidden, skip_x], dim=1)

        if self.attention is None:
            return hidden

        # [B, C, L] -> [B, L, C] for normalisation/attention, then back.
        tokens = self.norm(hidden.transpose(1, 2))
        return self.attention(tokens).transpose(1, 2)

class DiffusionUNet(nn.Module):
    """
    U-Net built from residual blocks with rotary attention.

    Layout:
        * ``init_conv`` projects ``in_channels`` to ``model_channels``.
        * ``num_levels`` down blocks, each doubling the channel count; the
          input of every down block is kept as a skip connection.
        * A residual -> attention -> residual middle block.
        * ``num_levels`` up blocks, each halving the channel count. After each
          up block the matching skip is concatenated on the channel axis and a
          1x1 convolution fuses the result back to the block's output width.
        * ``final_res`` + ``final_conv`` map back to ``out_channels``.

    Attention is enabled on every level except level 0, in both directions.

    Args:
        in_channels: Channels of the input tensor.
        model_channels: Width after the initial projection.
        out_channels: Channels of the returned tensor.
        time_dim: Stored on ``self.time_dim`` and forwarded to the blocks;
            defaults to ``model_channels * 4``.
        num_levels: Number of down (and up) blocks.
    """
    def __init__(
        self,
        in_channels=4,
        model_channels=64,
        out_channels=4,
        time_dim=None,
        num_levels=4
    ):
        super().__init__()

        self.time_dim = time_dim or model_channels * 4
        time_dim = self.time_dim

        # Time embedding
        self.time_mlp = SinusoidalTimeEmbedding(model_channels)

        # Initial projection
        self.init_conv = ConvolutionalNet(in_channels, model_channels, kernel_size=3)

        # Down blocks - increasing channel dimensions
        ch = model_channels
        self.down_blocks = nn.ModuleList()
        skip_channels = []  # width of the skip saved in front of each down block

        for level in range(num_levels):
            use_attention = level > 0  # Use attention from second level onwards
            out_ch = ch * 2

            self.down_blocks.append(
                DownBlock(ch, out_ch, time_dim, use_attention)
            )
            skip_channels.append(ch)
            ch = out_ch

        # Middle block with Residual-Attention-Residual structure
        self.mid_block1 = ResBlock1(ch)
        self.mid_attn = RotaryAttention(
            ModelArgs(dim=ch, n_heads=8, n_kv_heads=4),
            dim=ch,
            n_heads=8,
            n_kv_heads=4,
            max_seq_len=2048
        )
        self.mid_block2 = ResBlock1(ch)

        # Up blocks - decreasing channel dimensions.
        # Each up block receives exactly `ch` channels (the previous stage's
        # output). The skip was saved before its down block, so it can only be
        # concatenated after upsampling; a 1x1 conv then fuses
        # `out_ch + skip_ch` back to `out_ch`, keeping `ch` in step with the
        # tensors that actually flow through forward().
        self.up_blocks = nn.ModuleList()
        self.skip_fuse = nn.ModuleList()
        for level in reversed(range(num_levels)):
            use_attention = level > 0  # Match down block attention
            skip_ch = skip_channels.pop()
            out_ch = ch // 2

            self.up_blocks.append(
                UpBlock(ch, out_ch, time_dim, use_attention)
            )
            self.skip_fuse.append(
                nn.Conv1d(out_ch + skip_ch, out_ch, kernel_size=1)
            )
            ch = out_ch

        # Final blocks
        self.final_res = ResBlock1(ch)
        self.final_conv = ConvolutionalNet(ch, out_channels, kernel_size=3)

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Xavier-initialise Conv1d/Linear weights and zero their biases."""
        # NOTE(review): if the conv wrappers defined elsewhere wrap their layers
        # in weight_norm parametrizations, writing to `m.weight` may not reach
        # the underlying parameters - confirm.
        if isinstance(m, (nn.Conv1d, nn.Linear)):
            torch.nn.init.xavier_uniform_(m.weight)
            if hasattr(m, 'bias') and m.bias is not None:
                torch.nn.init.zeros_(m.bias)

    @staticmethod
    def _match_length(x, target_len):
        """Crop or right-pad (with zeros) the last axis of ``x`` to ``target_len``."""
        current_len = x.shape[-1]
        if current_len > target_len:
            return x[..., :target_len]
        if current_len < target_len:
            return F.pad(x, (0, target_len - current_len))
        return x

    def forward(self, x, time):
        """
        Args:
            x: Input tensor laid out as [B, C, L] with ``in_channels`` channels.
            time: Timestep input handed to the time embedding.

        Returns:
            Tensor with ``out_channels`` channels; every up stage is aligned to
            the length of its skip connection before fusion.
        """
        # Time embedding
        # NOTE(review): the embedding is computed but never injected into any
        # block, so the network output does not depend on `time`. Wiring it in
        # needs the embedding's output shape, which is defined elsewhere.
        t = self.time_mlp(time)

        # Initial projection
        x = self.init_conv(x)

        # Down path: remember the input of every block as its skip connection
        skips = []
        for block in self.down_blocks:
            skips.append(x)
            x = block(x)

        # Middle block (R-A-R); attention works on [B, L, C]
        x = self.mid_block1(x)
        x = self.mid_attn(x.transpose(1, 2)).transpose(1, 2)
        x = self.mid_block2(x)

        # Up path: upsample, align to the skip's length, concatenate, fuse
        for block, fuse in zip(self.up_blocks, self.skip_fuse):
            skip_x = skips.pop()
            x = block(x)
            x = self._match_length(x, skip_x.shape[-1])
            x = fuse(torch.cat([x, skip_x], dim=1))

        # Final processing
        x = self.final_res(x)
        x = self.final_conv(x)

        return x

def create_diffusion_model(
    in_channels=4,
    base_channels=64,
    out_channels=4,
    num_levels=4,
):
    """
    Build a DiffusionUNet with the given configuration.

    Args:
        in_channels: Forwarded as ``in_channels``.
        base_channels: Forwarded as ``model_channels``.
        out_channels: Forwarded as ``out_channels``.
        num_levels: Forwarded as ``num_levels``; defaults to 4, the value that
            used to be hard-coded here.

    Returns:
        The constructed DiffusionUNet instance.
    """
    return DiffusionUNet(
        in_channels=in_channels,
        model_channels=base_channels,
        out_channels=out_channels,
        num_levels=num_levels,
    )

In [34]:
# Build the U-Net with create_diffusion_model's default arguments.
# NOTE(review): the saved output of this cell is a NameError for
# `RotaryAttention`, so that class was not defined in the kernel when this ran.
# Confirm a cell defining it exists and executes before this one.
model = create_diffusion_model()

NameError: name 'RotaryAttention' is not defined