In [None]:
import torch

# Inspect the pretrained FCPE checkpoint, mapped onto the CPU.
# torch.load unpickles the file, so only point it at trusted checkpoints.
fcpe_ckpt_path = "/workspace/pretrained_models/fcpe_c_v001.pt"
model = torch.load(fcpe_ckpt_path, map_location="cpu")
model.keys()  # assumes the checkpoint is dict-like (e.g. a state dict) — confirm

In [None]:
from modules.cfm.dit import DiT
import torch

# Forward-pass smoke test for a small DiT configuration.
n_ch = 128     # channel count shared by x, mu and the output
seq_len = 384  # length of the last axis of x / mu / x_mask

dit = DiT(
    in_channels=2 * n_ch,
    hidden_channels=192,
    out_channels=n_ch,
    filter_channels=4 * 192,
    dropout=0.05,
    n_layers=8,
    n_heads=4,
    dim_head=64,
    kernel_size=3,
    utt_emb_dim=384,
    use_skip_connections=False,
)

dit_param_total = sum(param.numel() for param in dit.parameters())
print(f"DiT parameter count: {dit_param_total}")

# Random inputs, drawn in a fixed order so the RNG stream stays stable.
x = torch.randn(1, n_ch, seq_len)
x_mask = torch.ones(1, 1, seq_len)
mu = torch.randn(1, n_ch, seq_len)
t = torch.Tensor([0.2])
spks = torch.randn(1, 384)
# cond / cond_mask are built here but are not passed to dit below.
cond = torch.randn(1, 192, 32)
cond_mask = torch.ones(1, 1, 32)

dit(x, x_mask, mu, t, spks).shape

In [None]:
import torch
from models import SynthesizerTrn

# Smoke test for the VC synthesizer: build it, report its size, run one forward pass.
vc_model = SynthesizerTrn(
    spec_channels=128,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    speaker_embedding=384,
    n_speakers=10,
    ssl_dim=768,
    ppgs_dim=40,
)

n_frames = 565
# Dummy inputs sized to match the constructor (ssl_dim=768, ppgs_dim=40,
# spec_channels=128, speaker_embedding=384), drawn in a fixed order.
c = torch.randn(1, 768, n_frames)
c_lengths = torch.Tensor([n_frames])
ppgs = torch.randn(1, 40, n_frames)
spec = torch.randn(1, 128, n_frames)
f0 = torch.randn(1, 1, n_frames)
uv = torch.ones(1, n_frames)
g = torch.randn(1, 384)

# Print the VC model parameter count.
vc_param_total = sum(p.numel() for p in vc_model.parameters())
print(f"VC model parameter count: {vc_param_total}")

# Expected outputs: (prior_loss, diff_loss, f0_pred, lf0)
vc_model(c=c, f0=f0, uv=uv, spec=spec, ppgs=ppgs, c_lengths=c_lengths)

In [None]:
# Inference pass on the same dummy inputs; only the first returned value is kept.
infer_outputs = vc_model.infer(c=c, spec=spec, f0=f0, uv=uv, ppgs=ppgs, c_lengths=c_lengths)
out, _ = infer_outputs
out.shape

In [None]:
# Conditioning from a one-item batch; the method is called with lists of specs and lengths.
conditional = vc_model.compute_conditional_latent([spec], [c_lengths])
g, cond, cond_mask = conditional
(g.shape, cond.shape, cond_mask.shape)

In [None]:
from modules.reference_encoder import MelStyleEncoder
import torch


mel_encoder = MelStyleEncoder(
    in_channels=128,
    hidden_channels=256,
    utt_channels=512,
    kernel_size=5,
    p_dropout=0.1,
    n_heads=4,
    dim_head=64,
)

# Print the MelStyleEncoder parameter count (label previously said "VC model").
print(f"MelStyleEncoder parameter count: {sum(p.numel() for p in mel_encoder.parameters())}")

# Random 128-channel input with an all-ones mask.
spec = torch.randn(1, 128, 56)
spec_mask = torch.ones(1, 1, 56)

# The encoder returns three tensors; g is presumably the utterance-level embedding.
g, cond, cond_mask = mel_encoder(spec, spec_mask)
# Show shapes only, consistent with the other smoke-test cells.
print(g.shape, cond.shape, cond_mask.shape)

In [None]:
import random
from pathlib import Path

# Filelists merged into the slim TTS training list. Commented-out entries
# (fr / jp sets) are deliberately excluded.
_XPHONE_DIR = "/workspace/metadata/filelists/xphoneBERT"
_xphone_names = [
    "en_borderlands2",
    "en_baldursgate3",
    "en_worldofwarcraft",
    "en_mario",
    "de_gametts",
    "pl_archolos",
    "de_borderlands2",
    "en_warcraft",
    "en_sqnarrator",
    "en_emotional_train",
    "de_emotional_train",
    "ru_witcher3",
    "en_witcher3_skyrim",
    "en_fallout4",
    "en_naruto",
    "de_kcd",
    "pl_witcher3",
    "de_diablo4",
    "en_diablo4",
    # "fr_diablo4",
    "pl_diablo4",
    "ru_diablo4",
    "ru_skyrim",
    # "jp_one_piece",
    # "jp_skyrim",
]
train_all = [f"{_XPHONE_DIR}/{name}_xphone.csv" for name in _xphone_names] + [
    # "/workspace/dataset/fr/Fallout4/fr_fallout4_xphone.csv",
    "/workspace/dataset/de/Fallout4/de_fallout4_xphone.csv",
    "/workspace/dataset/en/Fallout4/en_fallout4_xphone.csv",
]

# Pool every filelist row. Transcripts span de/pl/ru/en text, so decode as UTF-8
# explicitly instead of relying on the platform's default encoding.
all_lines = []
for file in train_all:
    with open(file, "r", encoding="utf-8") as f:
        all_lines.extend(f.readlines())

# Shuffle with a dedicated, seeded RNG. The per-speaker cap applied afterwards keeps
# the first N accepted rows per speaker, so a fixed seed makes the selected subset
# reproducible without disturbing the global `random` state used by later cells.
SHUFFLE_SEED = 1234
random.Random(SHUFFLE_SEED).shuffle(all_lines)

files_max_per_speaker = 50
# Clip-length bounds in samples; 22050 is presumably the corpus sample rate — TODO confirm.
min_audio_length = 0.3 * 22050
max_audio_length = 12.0 * 22050

# Transcripts containing any of these tags are skipped ("v1" already covers "v10").
_version_tags = ["v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10"]
# Transcripts containing any of these characters are skipped.
_bad_text_chars = "#[]{}*"

# speaker -> accepted source lines (at most files_max_per_speaker each)
speaker_files_dict = {}

with open("/workspace/tts_train_slim.csv", "w", encoding="utf-8") as wf:
    for line in all_lines:
        cols = line.split("|")
        if len(cols) < 4:
            continue  # blank or malformed row
        filename, speaker, language = cols[0], cols[1], cols[2]
        # Drop the line terminator here and add exactly one "\n" on write, so rows
        # stay separate even when extra columns follow the transcript.
        text_orig = cols[3].rstrip("\r\n")

        filename = filename.replace("/mnt/datasets/TTS_Data", "/workspace/dataset")

        # Cheap checks first: per-speaker cap and transcript filters need no I/O.
        if len(speaker_files_dict.get(speaker, ())) >= files_max_per_speaker:
            continue
        if any(v in text_orig for v in _version_tags):
            continue
        if any(char in _bad_text_chars for char in text_orig):
            continue

        wav_path = Path(filename)
        if not wav_path.exists():
            continue
        # size // 2 approximates the sample count for 16-bit mono PCM (header ignored).
        # Keep only clips inside [min_audio_length, max_audio_length].
        n_samples = wav_path.stat().st_size // 2
        if not (min_audio_length <= n_samples <= max_audio_length):
            continue

        speaker_files_dict.setdefault(speaker, []).append(line)
        wf.write(f"{filename}|{speaker}|{language}|{text_orig}\n")

In [None]:
import ppgs

# Inference device for ppgs: an integer GPU index, or None to run on CPU.
gpu = None

# ppgs.load.audio loads the clip at the sample rate ppgs expects.
reference_wav = "/workspace/dataset/ru/Witcher3/wavs/0x00100118.wav"
audio = ppgs.load.audio(reference_wav)

# Phonetic posteriorgrams (PPGs) for the clip.
mu = ppgs.from_audio(audio, ppgs.SAMPLE_RATE, gpu=gpu)
mu.shape

In [None]:
mu[0, 0]  # first batch item, first entry along dim 1 (assumes mu is batch-first — confirm)

In [None]:
from modules.mel_processing import mel_spectrogram_torch
import torchaudio
import utils
from glob import glob
import random
import torch  # used below; don't rely on an earlier cell having imported it

# Sample rate passed to mel_spectrogram_torch below.
TARGET_SR = 22050

# A random precomputed PPG unit file from the German GameTTS set, loaded for inspection.
random_ppg = random.choice(glob("/workspace/dataset/de/GameTTS/**/*.ppg_unit.pt", recursive=True))
feature = torch.load(random_ppg)

audio, sr = torchaudio.load("/workspace/dataset/ru/Witcher3/wavs/0x00100118.wav")
# mel_spectrogram_torch is told the audio is TARGET_SR mono; make that true.
if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True)
if sr != TARGET_SR:
    audio = torchaudio.functional.resample(audio, sr, TARGET_SR)
mel = mel_spectrogram_torch(audio, 1024, 80, TARGET_SR, 256, 1024, 0, 8000)
mel.shape

In [None]:
import math
import monotonic_align
import torch

# Randomly initialised lookup from the 40 PPG classes into mel space, so each PPG
# frame gets a mean vector comparable to a mel frame.
n_mels = mel.shape[1]  # 80 for the mel computed earlier; must match for the matmuls below
ppg_embedding = torch.nn.Embedding(40, n_mels)

# nn.Embedding only accepts integer indices, but ppgs.from_audio returns float
# posteriors. Take the most likely class per frame (class axis assumed to be dim 1,
# i.e. mu is (batch, 40, frames) — TODO confirm). Integer ids pass through unchanged.
ppg_ids = mu.argmax(dim=1) if mu.is_floating_point() else mu

y_mask = torch.ones(1, 1, mel.shape[-1])

mu_x = ppg_embedding(ppg_ids).transpose(1, 2)  # (1, n_mels, T_ppg) if ppg_ids is (1, T_ppg)
print(mu_x.shape)
x_mask = torch.ones(1, 1, ppg_ids.shape[-1])

# Pairwise validity mask between PPG frames and mel frames.
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)

# NOTE(review): standard MAS gives every PPG frame at least one mel frame, so it needs
# T_ppg <= T_mel. If the PPG frame rate exceeds the mel frame rate, run-length
# dedup the ids (dedup_seq) before aligning.
with torch.no_grad():
    # Unit-variance Gaussian log-likelihood of each mel frame under each PPG mean,
    # expanded as -0.5 * (y^2 - 2*y*mu + mu^2) + const.
    const = -0.5 * math.log(2 * math.pi) * n_mels
    factor = -0.5 * torch.ones(mu_x.shape, dtype=mu_x.dtype, device=mu_x.device)
    y_square = torch.matmul(factor.transpose(1, 2), mel**2)
    y_mu_double = torch.matmul(2.0 * (factor * mu_x).transpose(1, 2), mel)
    mu_square = torch.sum(factor * (mu_x**2), 1).unsqueeze(-1)
    log_prior = y_square - y_mu_double + mu_square + const

    # Most likely monotonic alignment between PPG frames and mel frames.
    attn = (
        monotonic_align.maximum_path(
            log_prior,
            attn_mask.squeeze(1),
        )
        .unsqueeze(1)
        .detach()
    )

print(torch.sum(attn, -1))
# Log duration (mel frames per PPG frame); 1e-6 keeps log finite for zero durations.
logw_ = torch.log(1e-6 + torch.sum(attn, -1)) * x_mask
logw_

In [None]:
from itertools import groupby

def collate_1d_or_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Pad a list of tensors into one batch tensor, dispatching on rank.

    Rank-1 items go to ``collate_1d``; any other rank goes to ``collate_2d``
    (which does not take ``shift_id``). The remaining arguments are forwarded
    unchanged; see those functions for their meaning.
    """
    rank = len(values[0].shape)
    if rank != 1:
        return collate_2d(values, pad_idx, left_pad, shift_right, max_len)
    return collate_1d(values, pad_idx, left_pad, shift_right, max_len, shift_id)


def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Stack variable-length 1-D tensors into a (len(values), size) tensor.

    Args:
        values: non-empty list of 1-D tensors; the output takes dtype/device from values[0].
        pad_idx: fill value for positions not covered by a sequence.
        left_pad: if True, right-align each sequence (padding goes on the left).
        shift_right: if True, shift each sequence one step right, dropping its last
            element and writing ``shift_id`` in its first slot.
        max_len: output width; defaults to the longest input. Must be >= every length.
        shift_id: value written at the first slot when ``shift_right`` is set.
    """
    size = max_len if max_len is not None else max(v.size(0) for v in values)
    batch = values[0].new(len(values), size).fill_(pad_idx)

    for i, src in enumerate(values):
        n = len(src)
        row = batch[i]
        dst = row[size - n:] if left_pad else row[:n]
        assert dst.numel() == src.numel()
        if not shift_right:
            dst.copy_(src)
            continue
        dst[1:] = src[:-1]
        dst[0] = shift_id
    return batch


def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Stack variable-length 2-D tensors (T_i, D) into a (len(values), size, D) tensor.

    Args:
        values: non-empty list of 2-D tensors sharing their second dimension D;
            the output takes dtype/device from values[0].
        pad_idx: fill value for rows not covered by a sequence.
        left_pad: if True, place each sequence at the end of its slot.
        shift_right: if True, shift each sequence one row down, dropping its last
            row; the first row keeps the pad value.
        max_len: output length; defaults to the longest input. Must be >= every length.
    """
    size = max_len if max_len is not None else max(v.size(0) for v in values)
    feat_dim = values[0].shape[1]
    batch = values[0].new(len(values), size, feat_dim).fill_(pad_idx)

    for i, src in enumerate(values):
        n = len(src)
        row = batch[i]
        dst = row[size - n:] if left_pad else row[:n]
        assert dst.numel() == src.numel()
        if not shift_right:
            dst.copy_(src)
            continue
        dst[1:] = src[:-1]
    return batch


def dedup_seq(seq):
    """Run-length encode each row of a 2-D tensor.

    Args:
        seq: tensor of shape (B, L).

    Returns:
        (vals, counts): two int64 tensors of shape (B, max_runs), on ``seq``'s device.
        Row i of ``vals`` holds the values of the consecutive runs in ``seq[i]`` and
        ``counts`` their lengths. Shorter rows are right-padded with 0, the same layout
        ``collate_1d`` produces with ``pad_idx=0``. Empty rows and an empty batch are
        handled instead of raising.
    """
    import torch
    from torch.nn.utils.rnn import pad_sequence

    batch_size, _ = seq.shape  # enforce a 2-D (B, L) input
    if batch_size == 0:
        empty = torch.zeros((0, 0), dtype=torch.long, device=seq.device)
        return empty, empty.clone()

    vals, counts = [], []
    for row in seq:
        # unique_consecutive collapses runs in one vectorised call.
        row_vals, row_counts = torch.unique_consecutive(row, return_counts=True)
        vals.append(row_vals.long())
        counts.append(row_counts.long())
    return (
        pad_sequence(vals, batch_first=True, padding_value=0),
        pad_sequence(counts, batch_first=True, padding_value=0),
    )

In [None]:
import torch
import numpy as np
from sklearn.cluster import MiniBatchKMeans, KMeans
import ppgs
import time
from glob import glob
import random
from modules.mel_processing import mel_spectrogram_torch
import torchaudio
import utils


# Pick a random precomputed PPG tensor. Note `random_wav` holds the .ppg.pt path;
# the matching waveform sits next to it with a .wav suffix.
random_wav = random.choice(glob("/workspace/dataset/de/GameTTS/**/*.ppg.pt", recursive=True))

TARGET_SR = 22050
audio, sr = torchaudio.load(random_wav.replace(".ppg.pt", ".wav"))
# mel_spectrogram_torch is told the audio is TARGET_SR mono; make that true.
if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True)
if sr != TARGET_SR:
    audio = torchaudio.functional.resample(audio, sr, TARGET_SR)
mel = mel_spectrogram_torch(audio, 1024, 80, TARGET_SR, 256, 1024, 0, 8000)
print(mel.shape)

# Cluster PPG frames into 40 discrete units, then run-length encode the labels.
SEED = 1234
np.random.seed(SEED)
kmeans = KMeans(n_clusters=40, random_state=SEED, verbose=False)
feature = torch.load(random_wav)

start_time = time.time()

# Squeeze the leading dim and transpose to a float32 (frames, dims) matrix for
# sklearn (assumes the stored tensor is (1, dims, frames) — TODO confirm).
feature = feature.squeeze(0).numpy().T.astype(np.float32)

# fit_predict fits the model and returns per-frame labels in one pass; a separate
# fit() beforehand would cluster the same data twice.
features = kmeans.fit_predict(feature)

torch_features = torch.LongTensor(features).unsqueeze(0)
torch_features, _ = dedup_seq(torch_features)
print(torch_features.shape)
print(f"Time taken: {time.time() - start_time}")

In [None]:
# Repeat the unit extraction on another random PPG file, reusing `kmeans`.
random_ppg = random.choice(glob("/workspace/dataset/de/GameTTS/**/*.ppg.pt", recursive=True))
feature = torch.load(random_ppg)

# Squeeze the leading dim and transpose to a float32 (frames, dims) matrix for
# sklearn (assumes the stored tensor is (1, dims, frames) — TODO confirm).
feature = feature.squeeze(0).numpy().T.astype(np.float32)

# fit_predict fits and labels in one pass; a prior fit() would cluster twice.
features = kmeans.fit_predict(feature)

torch_features = torch.LongTensor(features).unsqueeze(0)
torch_features, _ = dedup_seq(torch_features)
torch_features = torch_features.squeeze(0)
torch_features.shape