In [None]:
import utils

# Request an "fcpe" predictor from utils.get_f0_predictor, configured to run
# on the CPU. The tunable settings are gathered in one dict and forwarded as
# keyword arguments.
f0_predictor_options = dict(
    sampling_rate=22050,
    hop_length=256,
    device="cpu",
    threshold=0.05,
)
f0_predictor = utils.get_f0_predictor("fcpe", **f0_predictor_options)

In [None]:
import torch

# Deserialize the checkpoint onto the CPU and display its top-level keys.
# NOTE(review): torch.load unpickles the file, so only use it on checkpoints
# that come from a trusted source.
fcpe_checkpoint_path = "/workspace/pretrained_models/fcpe_c_v001.pt"
model = torch.load(fcpe_checkpoint_path, map_location="cpu")
model.keys()

In [None]:
from modules.cfm.dit import DiT
import torch

# Smoke test for DiT: build it, report its parameter count, then run a single
# forward pass on random inputs and display the output shape.
dit_feature_channels = 80
dit_hidden_channels = 192
dit_utt_emb_dim = 512

dit = DiT(
    in_channels=dit_feature_channels * 2,
    hidden_channels=dit_hidden_channels,
    out_channels=dit_feature_channels,
    filter_channels=dit_hidden_channels * 4,
    dropout=0.05,
    n_layers=6,
    n_heads=2,
    kernel_size=3,
    utt_emb_dim=dit_utt_emb_dim,
)

# Total number of elements across all DiT parameters.
dit_param_total = 0
for dit_param in dit.parameters():
    dit_param_total += dit_param.numel()
print(f"DiT parameter count: {dit_param_total}")

# Random stand-in inputs: 128 frames for the main sequence, 32 for the
# conditioning sequence, batch size 1.
dit_num_frames = 128
dit_num_cond_frames = 32

x = torch.randn(1, dit_feature_channels, dit_num_frames)
x_mask = torch.ones(1, 1, dit_num_frames)
mu = torch.randn(1, dit_feature_channels, dit_num_frames)
t = torch.Tensor([0.2])
spks = torch.randn(1, dit_utt_emb_dim)
cond = torch.randn(1, dit_hidden_channels, dit_num_cond_frames)
cond_mask = torch.ones(1, 1, dit_num_cond_frames)

dit_output = dit(x, x_mask, mu, t, spks, cond, cond_mask)
dit_output.shape

In [None]:
import torch
from models import SynthesizerTrn

# Build the VC model (SynthesizerTrn) from a keyword configuration.
vc_model_config = dict(
    spec_channels=80,
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    speaker_embedding=512,
    n_speakers=10,
    ssl_dim=768,
    ppgs_dim=40,
)
vc_model = SynthesizerTrn(**vc_model_config)

# Random stand-in inputs for a batch of one, all sharing the same frame count.
vc_num_frames = 56
c = torch.randn(1, 768, vc_num_frames)
c_lengths = torch.Tensor([vc_num_frames])
ppgs = torch.randn(1, 40, vc_num_frames)
spec = torch.randn(1, 80, vc_num_frames)
f0 = torch.randn(1, 1, vc_num_frames)
uv = torch.ones(1, vc_num_frames)
g = torch.randn(1, 512)  # defined here but not passed to the forward call below

# Total number of elements across all VC model parameters.
vc_param_total = 0
for vc_param in vc_model.parameters():
    vc_param_total += vc_param.numel()
print(f"VC model parameter count: {vc_param_total}")

# Forward pass. Per the notebook author's note the result is
# (prior_loss, diff_loss, f0_pred, lf0) -- not verifiable from this cell alone.
vc_model(c=c, f0=f0, uv=uv, spec=spec, ppgs=ppgs, c_lengths=c_lengths)

In [None]:
# Display the `blocks` attribute of the VC model decoder's estimator for
# inspection; requires `vc_model` to have been built by an earlier cell.
vc_model.decoder.estimator.blocks

In [None]:
# Compute the conditional latent with the VC model; `spec` and `c_lengths` are
# each passed wrapped in a one-element list.
# NOTE: this rebinds `cond`, `cond_mask` and `g`, replacing the random tensors
# that earlier cells assigned to those names.
cond, cond_mask, g = vc_model.compute_conditional_latent([spec], [c_lengths])
# Show the shapes of the three returned values.
g.shape, cond.shape, cond_mask.shape

In [None]:
# Run the VC model's `infer` method; relies on `c`, `f0`, `uv`, `spec` and
# `ppgs` already being defined by an earlier cell. The second returned value
# is discarded.
# NOTE(review): `deocder_output` looks like a misspelling of `decoder_output`;
# the name is left unchanged here in case other cells refer to it.
deocder_output, _ = vc_model.infer(c=c, f0=f0, uv=uv, spec=spec, ppgs=ppgs)
deocder_output.shape

In [None]:
from modules.perceiver_encoder import PerceiverResampler
import torch


# Smoke test for PerceiverResampler: build it and run it once on random input.
resampler_hidden_channels = 192

resampler = PerceiverResampler(
    hidden_channels=resampler_hidden_channels,
    depth=2,
    num_latents=32,
    dim_head=64,
    heads=8,
    ff_mult=4,
)

# Use dedicated names for the test input. Binding this 192-channel tensor to
# `spec` would overwrite the 80-channel `spec` that the vc_model cells pass to
# the model, so re-running those cells afterwards would see the wrong tensor.
resampler_input = torch.randn(1, resampler_hidden_channels, 56)
resampler_input_mask = torch.ones(1, 1, 56)

resampler(resampler_input, resampler_input_mask)

In [2]:
import random
from pathlib import Path

# Build a slimmed-down training filelist: pool the rows of every filelist
# below, shuffle them, drop unusable rows, and keep at most
# `files_max_per_speaker` rows per speaker.
train_all = [
    "/workspace/metadata/filelists/xphoneBERT/en_borderlands2_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_baldursgate3_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_worldofwarcraft_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_mario_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_gametts_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/pl_archolos_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_borderlands2_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_warcraft_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_sqnarrator_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_emotional_train_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_emotional_train_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/ru_witcher3_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_witcher3_skyrim_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_fallout4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_naruto_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_kcd_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/pl_witcher3_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/fr_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/pl_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/ru_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/ru_skyrim_xphone.csv",
    # "/workspace/metadata/filelists/xphoneBERT/jp_one_piece_xphone.csv",
    # "/workspace/metadata/filelists/xphoneBERT/jp_skyrim_xphone.csv",
    # "/workspace/dataset/fr/Fallout4/fr_fallout4_xphone.csv",
    # "/workspace/dataset/de/Fallout4/de_fallout4_xphone.csv",
    # "/workspace/dataset/en/Fallout4/en_fallout4_xphone.csv",
]

# Pool the raw rows of every filelist. The encoding is explicit because the
# lists cover several languages and must not depend on the locale default.
all_lines = []
for filelist_path in train_all:
    with open(filelist_path, "r", encoding="utf-8") as filelist:
        all_lines.extend(filelist)

# Shuffle with a fixed seed so the per-speaker subset picked below is the same
# on every run; a private Random instance leaves the global RNG untouched.
shuffle_seed = 1234
random.Random(shuffle_seed).shuffle(all_lines)

sample_rate = 22050
files_max_per_speaker = 50
min_audio_length = 0.3 * sample_rate  # in samples
max_audio_length = 12.0 * sample_rate  # in samples

# Rows whose original text contains any of these markers are dropped.
# "v10" needs no entry of its own: it already contains "v1".
version_markers = tuple(f"v{i}" for i in range(1, 10))
forbidden_chars = "#[]{}*"

# speaker -> raw rows kept for that speaker (capped at files_max_per_speaker).
speaker_files_dict = {}

with open("/workspace/tts_train_slim.csv", "w", encoding="utf-8") as wf:
    for line in all_lines:
        # Strip the line terminator so the last column is clean and every
        # output row gets exactly one newline, even when the final row of a
        # filelist had none.
        row = line.rstrip("\r\n")
        if not row:
            continue

        # Pipe-separated row: column 0 is the audio path, 1 the speaker,
        # 3 the language, and the last column the original text.
        cols = row.split("|")
        filename = cols[0].replace("/mnt/datasets/TTS_Data", "/workspace/dataset")
        speaker = cols[1]
        language = cols[3]
        text_orig = cols[-1]

        # Cheap text filters first, before touching the filesystem.
        if any(marker in text_orig for marker in version_markers):
            continue
        if any(char in forbidden_chars for char in text_orig):
            continue
        # A text-length filter (4..350 characters on the second-to-last
        # column) was commented out in the original and stays disabled.

        audio_path = Path(filename)
        if not audio_path.exists():
            continue

        # Approximate the sample count as two bytes per sample.
        # NOTE(review): assumes 16-bit mono audio and ignores any file
        # header -- confirm against the dataset format.
        num_samples = audio_path.stat().st_size // 2
        # The original test `max < n < min` could never be true, so no clip
        # was ever rejected by length; reject anything outside [min, max].
        if not (min_audio_length <= num_samples <= max_audio_length):
            continue

        kept_lines = speaker_files_dict.setdefault(speaker, [])
        if len(kept_lines) >= files_max_per_speaker:
            continue

        kept_lines.append(line)
        wf.write(f"{filename}|{speaker}|{language}|{text_orig}\n")

In [None]:
# Import the package under an alias: a plain `import ppgs` rebinds the name
# `ppgs`, which an earlier cell uses for the random tensor passed to vc_model
# as its `ppgs` argument.
import ppgs as ppgs_package

print(ppgs_package.MAX_INFERENCE_FRAMES)