In [2]:
import sys
sys.path.append("vits_torch/")
import utils
from models import SynthesizerTrn
from text.symbols import symbols

hps = utils.get_hparams_from_file("vits_torch/logs/vits-base-en-US-MadisonNeural/config.json")
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model)
_ = net_g.eval()

_ = utils.load_checkpoint("vits_torch/logs/vits-base-en-US-MadisonNeural/G_800000.pth", net_g, None)



In [21]:
import mlx.nn as nn

x = nn.Sequential(nn.Conv1d(2, 1, 1))

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        # TODO: fix this
        return nn.init.normal(mean, std)(m)
    else:
        return m

In [22]:
model = nn.Sequential(nn.Linear(5, 10), nn.ReLU(), nn.Linear(10, 5))
init_fn = nn.init.uniform(low=-0.1, high=0.1)
model.apply(init_weights)

Sequential(
  (layers.0): Linear(input_dims=5, output_dims=10, bias=True)
  (layers.1): ReLU()
  (layers.2): Linear(input_dims=10, output_dims=5, bias=True)
)

In [16]:
model

Sequential(
  (layers.0): Linear(input_dims=5, output_dims=10, bias=True)
  (layers.1): ReLU()
  (layers.2): Linear(input_dims=10, output_dims=5, bias=True)
)

In [13]:
_ = x.apply(init_weights)

In [3]:
import torch

import mlx.core as mx
import mlx.nn as nn

In [4]:
def convert(model, rules=None):
    params = {}
    if rules is not None and type(model) in rules:
        out = rules[type(model)](model, rules)
        return out
    if isinstance(model, torch.Tensor):
        return mx.array(model.detach().numpy())
    if isinstance(model, torch.nn.ModuleList):
        return [convert(n, rules) for n in model.children()]
    # if isinstance(model, torch.nn.Conv1d):
    #     return {
    #         "weight": convert(model.weight).transpose(0, 2, 1),
    #         "bias": convert(model.bias),
    #     }
    for k, n in model.named_children():
        if rules is not None and k in rules:
            params.update(rules[k](n, rules))
        else:
            params[k] = convert(n, rules)
    for k, p in model.named_parameters(recurse=False):
        params[k] = convert(p)
    return params

In [5]:
params = convert(net_g)
params.keys()

dict_keys(['enc_p', 'dec', 'enc_q', 'flow', 'dp'])

In [6]:
params

{'enc_p': {'emb': {'weight': array([[-0.0110035, 0.00055253, -0.00105711, ..., -0.00946123, -0.0029717, -0.00927133],
          [0.0231356, -0.0140018, 5.54741e-05, ..., 0.00037821, -0.0537527, 0.0121214],
          [0.34511, -0.197778, -0.0869552, ..., 0.14594, -0.225829, 0.171807],
          ...,
          [0.0108739, 0.149518, -0.0654389, ..., 0.116465, -0.0402058, -0.0387925],
          [0.243, -0.0450219, -0.0510859, ..., 0.132216, 0.0824754, 0.0816978],
          [0.0386524, -0.0313871, 0.0122042, ..., -0.00492342, 0.0601239, 0.0159901]], dtype=float32)},
  'encoder': {'drop': {},
   'attn_layers': [{'conv_q': {'weight': array([[[0.00452969],
              [0.0360925],
              [-0.0546473],
              ...,
              [-0.0115284],
              [0.104641],
              [0.0720642]],
             [[0.0516898],
              [-0.0742255],
              [0.0171386],
              ...,
              [0.0547673],
              [-0.0681109],
              [0.0245506]],
   

In [None]:
class Encoder(nn.Module):
    def __init__(
        self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=4, **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout)
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x

In [36]:
class TextEncoder(nn.Module):
    def __init__(
        self, n_vocab, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
    ):
        super().__init__()
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)

        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)


class VITS(nn.Module):
    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers: int = 0,
        gin_channels: int = 0,
        dtype: mx.Dtype = mx.float16,
    ):
        super().__init__()
        self.enc_p = TextEncoder(
            n_vocab, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
        )

        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)

In [37]:
vits = VITS()