In [None]:
import torch

# Load the pretrained FCPE checkpoint onto CPU and list its top-level keys.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
# The result is used as a mapping (``.keys()``); presumably a state dict or a
# checkpoint wrapper dict — confirm against how the file was saved.
model = torch.load("/workspace/pretrained_models/fcpe_c_v001.pt", map_location="cpu")
model.keys()

In [None]:
from modules.cfm.dit import DiT
import torch

# Smoke test for the DiT backbone: build it with a small config, report its
# parameter count, then run one forward pass on random inputs and display the
# output shape.
dit = DiT(
    in_channels=128 * 2,  # presumably x and mu stacked on channels (2 x 128) — TODO confirm
    hidden_channels=192,
    out_channels=128,
    filter_channels=192 * 4,
    dropout=0.05,
    n_layers=8,
    n_heads=4,
    dim_head=64,
    kernel_size=3,
    utt_emb_dim=384,  # same width as ``spks`` below (384)
    use_skip_connections=False,
)

# print dit parameter count
print(f"DiT parameter count: {sum(p.numel() for p in dit.parameters())}")

# Random inputs. 3-D tensors are presumably (batch, channels, frames) — confirm
# against DiT.forward.
x = torch.randn(1, 128, 384)
x_mask = torch.ones(1, 1, 384)  # all-ones mask over 384 positions
mu = torch.randn(1, 128, 384)
t = torch.Tensor([0.2])  # one scalar value for the single batch item
spks = torch.randn(1, 384)
# NOTE(review): cond and cond_mask are built but not passed to the call below.
cond = torch.randn(1, 192, 32)
cond_mask = torch.ones(1, 1, 32)

dit(x, x_mask, mu, t, spks).shape

In [None]:
import torch
from models import SynthesizerTrn

# Smoke test for the full VC model: construct SynthesizerTrn, report its
# parameter count and run one forward pass on random inputs
# (565 feature frames, 124 PPG units).
vc_model = SynthesizerTrn(
    spec_channels=128,
    hidden_channels=192,
    filter_channels=768,
    n_heads=4,
    dim_head=64,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.1,
    speaker_embedding=384,
    n_speakers=10,
    ssl_dim=768,
)

# Synthetic inputs; 3-D tensors are presumably (batch, channels, frames).
c = torch.randn(1, 768, 565)  # 768 channels, same as ssl_dim
c_lengths = torch.LongTensor([565])
energy = torch.randn(1, 1, 565)
spec = torch.randn(1, 128, 565)  # 128 channels, same as spec_channels
f0 = torch.randn(1, 1, 565)
uv = torch.ones(1, 565)
# NOTE(review): g is created but not passed to the forward call below.
g = torch.randn(1, 384)
ppg = torch.randint(0, 40, (1, 124))  # integer ids in [0, 40)
ppg_lengths = torch.LongTensor([124])
# Random per-unit values in [0, 10), presumably durations; their sum is not
# forced to equal the 565 frames above.
ppg_dur = torch.randint(0, 10, (1, 124)).float()


# print VC model parameter count
print(f"VC model parameter count: {sum(p.numel() for p in vc_model.parameters())}")

# The dataset-loop cell below unpacks the result as (prior_loss, diff_loss,
# loss_dur, f0_pred, lf0, energy_pred, speaker_logits) — confirm against
# SynthesizerTrn.forward.
vc_model(
    c=c,
    f0=f0,
    uv=uv,
    spec=spec,
    energy=energy,
    c_lengths=c_lengths,
    ppg=ppg,
    ppg_lengths=ppg_lengths,
    ppg_dur=ppg_dur,
)

In [None]:
# Inference on the same synthetic inputs as the cell above; infer returns a
# pair and only the first element is kept, then its shape is displayed.
out, _ = vc_model.infer(
    c=c,
    spec=spec,
    f0=f0,
    uv=uv,
    energy=energy,
    ppg=ppg,
    ppg_lengths=ppg_lengths,
    ppg_dur=ppg_dur,
)
out.shape

In [None]:
from torch.utils.data import DataLoader
from data_utils import TextAudioCollate, TextAudioSpeakerLoader
import utils

# Smoke test on real data: run the VC model's forward pass on the first three
# batches of the dataset described by the config file. Requires ``vc_model``
# from an earlier cell.
hps = utils.get_hparams_from_file("/home/cfm-vc/configs/config_ppgs.json")

collate_fn = TextAudioCollate()
# NOTE(review): despite the ``eval_`` names this reads ``hps.data.training_files``.
eval_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps, all_in_mem=False)
eval_loader = DataLoader(
    eval_dataset,
    num_workers=1,
    shuffle=False,
    batch_size=1,
    pin_memory=False,
    drop_last=False,
    collate_fn=collate_fn,
)

for idx, batch in enumerate(eval_loader):
    c, f0, spec, lengths, uv, energy, sid, ppg, ppg_lengths, ppg_dur = batch
    # Use this batch's lengths. Passing the ``c_lengths`` left over from the
    # synthetic-input cell (always 565) mismatched every real batch. Rebinding
    # the name also keeps later cells that read ``c_lengths`` consistent with
    # the last ``spec`` produced here.
    c_lengths = lengths

    (
        prior_loss,
        diff_loss,
        loss_dur,
        f0_pred,
        lf0,
        energy_pred,
        speaker_logits,
    ) = vc_model(
        c=c,
        f0=f0,
        uv=uv,
        spec=spec,
        energy=energy,
        c_lengths=c_lengths,
        ppg=ppg,
        ppg_lengths=ppg_lengths,
        ppg_dur=ppg_dur,
    )

    # Stop after three batches (idx 0, 1, 2).
    if idx == 2:
        break

In [None]:
# Compute the conditional latent from the current ``spec`` (the last batch of
# the dataset loop when cells run in order) and display the three shapes.
# NOTE(review): spec and c_lengths are wrapped in one-element lists —
# presumably the method takes a list of references; confirm. c_lengths must
# describe this spec; check which cell last assigned it.
g, cond, cond_mask = vc_model.compute_conditional_latent([spec], [c_lengths])
g.shape, cond.shape, cond_mask.shape

In [None]:
from modules.reference_encoder import MelStyleEncoder
import torch


# Smoke test for the reference (style) encoder: build it, report its parameter
# count and run it once on a random 128-channel, 56-frame input.
# NOTE(review): this cell overwrites the global ``spec``, ``g``, ``cond`` and
# ``cond_mask`` used by other cells.
mel_encoder = MelStyleEncoder(
    in_channels=128,
    hidden_channels=256,
    utt_channels=512,
    kernel_size=5,
    p_dropout=0.1,
    n_heads=4,
    dim_head=64,
)

# print MelStyleEncoder parameter count (the label previously said "VC model")
print(f"MelStyleEncoder parameter count: {sum(p.numel() for p in mel_encoder.parameters())}")

spec = torch.randn(1, 128, 56)
spec_mask = torch.ones(1, 1, 56)  # all-ones mask over the 56 positions


g, cond, cond_mask = mel_encoder(spec, spec_mask)
# Print shapes only, consistent with the other two outputs (printing ``g``
# itself dumped the full tensor).
print(g.shape, cond.shape, cond_mask.shape)

In [None]:
import random
from pathlib import Path

# Build a slim TTS filelist: pool every line of the source filelists, shuffle,
# drop unwanted entries, and keep at most ``files_max_per_speaker`` lines per
# speaker. Output lines are ``filename|speaker|language|text``.
train_all = [
    "/workspace/metadata/filelists/xphoneBERT/en_borderlands2_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_baldursgate3_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_worldofwarcraft_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_mario_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_gametts_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/pl_archolos_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_borderlands2_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_warcraft_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_sqnarrator_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_emotional_train_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_emotional_train_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/ru_witcher3_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_witcher3_skyrim_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_fallout4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_naruto_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_kcd_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/pl_witcher3_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/de_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/en_diablo4_xphone.csv",
    # "/workspace/metadata/filelists/xphoneBERT/fr_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/pl_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/ru_diablo4_xphone.csv",
    "/workspace/metadata/filelists/xphoneBERT/ru_skyrim_xphone.csv",
    # "/workspace/metadata/filelists/xphoneBERT/jp_one_piece_xphone.csv",
    # "/workspace/metadata/filelists/xphoneBERT/jp_skyrim_xphone.csv",
    # "/workspace/dataset/fr/Fallout4/fr_fallout4_xphone.csv",
    "/workspace/dataset/de/Fallout4/de_fallout4_xphone.csv",
    "/workspace/dataset/en/Fallout4/en_fallout4_xphone.csv",
]

all_lines = []

for file in train_all:
    # Explicit UTF-8: the filelists contain non-ASCII text.
    with open(file, "r", encoding="utf-8") as f:
        all_lines.extend(f.readlines())

random.shuffle(all_lines)

files_max_per_speaker = 50
# Length bounds in samples, compared against ``st_size // 2`` below.
# NOTE(review): that estimate assumes 16-bit mono PCM at 22050 Hz and ignores
# the WAV header — confirm it matches the dataset.
min_audio_length = 0.3 * 22050
max_audio_length = 12.0 * 22050

speaker_files_dict = {}

with open("/workspace/tts_train_slim.csv", "w", encoding="utf-8") as wf:
    for line in all_lines:
        cols = line.split("|")
        filename = cols[0]
        speaker = cols[1]
        language = cols[2]
        # Drop any line terminator here and write exactly one "\n" below. Either
        # this column is last (and carries "\n", except possibly on a file's final
        # line), or more columns follow (and it carries none). Without the
        # normalization, output lines could run together.
        text_orig = cols[3].rstrip("\r\n")

        filename = filename.replace("/mnt/datasets/TTS_Data", "/workspace/dataset")

        # Skip lines containing version tags ("v1" also matches "v10").
        if any(
            v in text_orig
            for v in ["v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10"]
        ):
            continue

        audio_path = Path(filename)
        if not audio_path.exists():
            continue

        # Keep only clips whose estimated length lies in [min, max]. The previous
        # chained test ``max < n < min`` could never be true (max > min), so no
        # clip was ever filtered by length.
        n_samples = audio_path.stat().st_size // 2
        if not (min_audio_length <= n_samples <= max_audio_length):
            continue

        if any(char in "#[]{}*" for char in text_orig):
            continue

        kept = speaker_files_dict.setdefault(speaker, [])
        if len(kept) < files_max_per_speaker:
            kept.append(line)
            wf.write(f"{filename}|{speaker}|{language}|{text_orig}\n")

In [None]:
import ppgs
import random
from glob import glob
import torch
import torchaudio
import utils
from modules.mel_processing import mel_spectrogram_torch
from modules.commons import dedup_seq

# Inspect one random GameTTS clip: infer its PPG, align it to the mel frame
# count, then collapse it into a deduplicated sequence of most-probable classes.
random_wav = random.choice(
    glob("/workspace/dataset/de/GameTTS/**/*.wav", recursive=True)
)
# Load speech audio at the sample rate ppgs expects
audio = ppgs.load.audio(random_wav)

# Choose a gpu index to use for inference. Set to None to use cpu.
gpu = None

# Infer PPGs
ppg = ppgs.from_audio(audio, ppgs.SAMPLE_RATE, gpu=gpu).float()

# Mel spectrogram. The call below is given a 22050 Hz rate, but torchaudio.load
# returns the file's native rate. Resample when they differ, otherwise the mel
# frame count (and the PPG alignment derived from it) would be wrong.
audio, sr = torchaudio.load(random_wav)
if sr != 22050:
    audio = torchaudio.functional.resample(audio, sr, 22050)
mel = mel_spectrogram_torch(audio, 1024, 128, 22050, 256, 1024, 0, 8000)

print(ppg.shape, mel.shape)

# Stretch the PPG along its last axis to the number of mel frames.
ppg = utils.repeat_expand_2d(ppg.squeeze(0), mel.shape[-1], mode="nearest")

print(ppg.shape)

sparse_ppg = ppgs.sparsify(ppg=ppg, method="percentile", threshold=torch.Tensor([0.85]))
# NOTE(review): if sparse_ppg is 2-D (classes, frames), dim=1 reduces over time
# rather than over classes — confirm the layout returned by ppgs.sparsify.
most_probable_ppg = torch.argmax(sparse_ppg, dim=1)
features, feature_dur = dedup_seq(most_probable_ppg)
print(features, feature_dur)

In [None]:
# Display the mel tensor computed in the cell above.
mel

In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from modules.mel_processing import mel_spectrogram_torch
import torchaudio
import utils

audio, sr = torchaudio.load(random_wav)
mel = mel_spectrogram_torch(audio, 1024, 80, 22050, 256, 1024, 0, 8000)
print(mel.shape)
utils.plot_spectrogram_to_numpy(mel[0].cpu().numpy(), return_figure=True)[1]

In [None]:
from modules.mel_processing import mel_spectrogram_torch
import torchaudio
import utils
from glob import glob
import random
import torch
from modules.commons import average_over_durations

# Pick a random precomputed PPG-unit file and compute the mel spectrogram of the
# matching wav (same path with ``.ppg_unit.pt`` replaced by ``.wav``).
random_ppg_unit = random.choice(
    glob("/workspace/dataset/de/GameTTS/**/*.ppg_unit.pt", recursive=True)
)
# NOTE(review): torch.load unpickles — only load files produced by this project.
ppg_unit = torch.load(random_ppg_unit)

wav_path = random_ppg_unit.replace(".ppg_unit.pt", ".wav")
audio, sr = torchaudio.load(wav_path)
# The mel call is given 22050 Hz, but torchaudio.load returns the native rate.
# Resample so the frame count matches that assumption, which the stored unit
# durations presumably also use — confirm.
if sr != 22050:
    audio = torchaudio.functional.resample(audio, sr, 22050)
mel = mel_spectrogram_torch(audio, 1024, 128, 22050, 256, 1024, 0, 8000)
mel.shape

In [None]:
# Aggregate ``mel`` (previous cell) by the stored PPG-unit durations, which get
# a leading dimension via unsqueeze(0), and display the result's shape.
# NOTE(review): presumably averages mel frames within each unit and expects a
# batch dim on both arguments — confirm against average_over_durations.
unit_dur = ppg_unit["ppg_unit_dur"].unsqueeze(0)
average_over_durations(mel, unit_dur).shape

In [None]:
from collections import defaultdict
from glob import glob
import pydub
from tqdm import tqdm

# Build a VC filelist restricted to clips that have a precomputed PPG-unit file:
# keep clips longer than 0.3 s and shorter than 10 s, drop speakers with five or
# fewer such clips, and renumber the remaining speakers from 0.
all_ppgs = glob("/workspace/dataset/de/GameTTS/**/*.ppg_unit.pt", recursive=True)
all_ppgs = [ppg.replace(".ppg_unit.pt", ".wav") for ppg in all_ppgs]
files = "/workspace/vc_train.csv"

# Wav paths not yet assigned to a speaker. A set makes the per-line membership
# test and removal O(1); the list version was O(n) per CSV line. Removing a path
# once used still lists each clip at most once. ``all_ppgs`` is left unmodified.
pending_wavs = set(all_ppgs)

speaker_files = defaultdict(list)

with open(files, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines instead of failing on cols[1]
        cols = line.split("|")
        path = cols[0]
        if path in pending_wavs:
            # Decodes the whole clip just to read its duration.
            audio = pydub.AudioSegment.from_file(path)
            duration = audio.duration_seconds

            if 10 > duration > 0.3:
                # Drop the line terminator so a final line without "\n" does not
                # become a separate speaker key.
                speaker = cols[1].rstrip("\r\n")
                speaker_files[speaker].append(path)
                pending_wavs.remove(path)

speaker_idx = 0

with open("/workspace/vc_train_ppg.csv", "w", encoding="utf-8") as wf:
    # ``speaker_paths`` rather than ``files`` so the input path is not clobbered.
    for speaker, speaker_paths in speaker_files.items():
        if len(speaker_paths) > 5:
            for file in speaker_paths:
                wf.write(f"{file}|{speaker_idx}\n")
            speaker_idx += 1

In [None]:
import pydub
from tqdm import tqdm

# Append the clip duration (seconds) to each filelist line, keeping only clips of
# 0.3–11 s whose text is longer than two characters. Each ``*_ph*.csv`` input is
# written to a ``*_dur_ph*.csv`` sibling.
files = ["/workspace/tts_test_ph.csv", "/workspace/tts_train_ph_cleaned.csv"]

# /workspace/dataset/pl/Archolos/dia_captain_archolos_q305_amulet_03_06.wav|Captain|2|zaras, zaras! mɔʐɛ i ɲɛ dawɛm vam ʂalup, alɛ na arxɔlɔs vas dɔstart͡ʂɨwɛm. ɲɛ mamɨ dɔ t͡ʂɛɡɔ vrat͡sat͡ɕ.|Zaraz, zaraz! Może i nie dałem wam szalup, ale na Archolos was dostarczyłem. Nie mamy do czego wracać.
# i.e. filename|speaker|language|phonemes|text


speaker_id = 0
speaker_id_map = {}

for f in files:
    # Explicit UTF-8: lines contain IPA and non-ASCII text.
    with open(f, "r", encoding="utf-8") as rf:
        lines = rf.readlines()

    with open(f.replace("_ph", "_dur_ph"), "w", encoding="utf-8") as wf:
        for line in tqdm(lines):
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing empty line)
            cols = line.split("|")
            filename = cols[0]
            speaker_name = cols[1]
            language = cols[2]
            phoneme_text = cols[3]
            # Strip "\r" as well, so CRLF files do not leak it into the output.
            text = cols[4].rstrip("\r\n")

            # Decodes the whole clip just to read its duration.
            audio = pydub.AudioSegment.from_file(filename)
            duration = audio.duration_seconds
            if 11.0 >= duration >= 0.3 and len(text) > 2:
                wf.write(
                    f"{filename}|{speaker_name}|{language}|{phoneme_text}|{text}|{duration}\n"
                )

                # Consecutive integer ids for speakers, in first-seen order.
                # NOTE(review): speaker_id_map is not saved anywhere in this cell.
                if speaker_name not in speaker_id_map:
                    speaker_id_map[speaker_name] = speaker_id
                    speaker_id += 1

In [None]:
import random
from collections import defaultdict

# Rebuild the VC filelist: group clip paths by speaker, drop speakers with five
# or fewer clips, and renumber the remaining speakers from 0 in first-seen order.
files = "/workspace/vc_train.csv"


speaker_files = defaultdict(list)

with open(files, "r", encoding="utf-8") as f:
    for line in f:
        if not line.strip():
            continue  # skip blank lines instead of failing on cols[1]
        cols = line.split("|")
        path = cols[0]
        # Drop the line terminator. When the speaker is the last column, a final
        # line without "\n" would otherwise be counted as a separate speaker.
        speaker = cols[1].rstrip("\r\n")
        speaker_files[speaker].append(path)

speaker_idx = 0

with open("/workspace/vc_train_cleaned.csv", "w", encoding="utf-8") as wf:
    # ``speaker_paths`` rather than ``files`` so the input path is not clobbered.
    for speaker, speaker_paths in speaker_files.items():
        if len(speaker_paths) > 5:
            for file in speaker_paths:
                wf.write(f"{file}|{speaker_idx}\n")
            speaker_idx += 1