## Import Functions

In [None]:
import whisper
from Scripts.functions import InferenceResult, length_to_mask
from functions import StyleTTS2_Helper
import soundfile as sf
import torch
import torch.nn.functional as F
import os
import torch # Deep Learning Framework

import soundfile as sf
from nltk.tokenize import word_tokenize # Tokenizers divide strings into lists of substrings
import time # Used for timing operations
import yaml

import torch.nn.functional as F
import whisper

from dataclasses import dataclass

from models import *
from utils import *
from text_utils import TextCleaner

import phonemizer

from Utils.PLBERT.util import load_plbert

import IPython.display as ipd

from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
# Tell phonemizer where the eSpeak NG shared library lives (Windows install path).
os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"  # <-- adjust if different
# Reproducibility settings: must run before any torch.randn(...) call in this notebook.
torch.manual_seed(0)  # Fixed seed so torch's random draws repeat from run to run
torch.backends.cudnn.benchmark = False  # Do not let cuDNN auto-select convolution algorithms
torch.backends.cudnn.deterministic = True  # Only use deterministic cuDNN algorithms

%cd ..

## Create Helper Functions


In [None]:
def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: 1-D tensor holding one sequence length per batch entry.

    Returns:
        Boolean tensor of shape (len(lengths), lengths.max()). Entry [i, j] is
        True when j + 1 > lengths[i] (a padded position) and False for the
        valid positions of sequence i.
    """
    longest = lengths.max()
    # Positions [0, 1, ..., longest - 1] as one row, repeated for every batch entry
    # and cast to the tensor type of `lengths` so the two can be compared.
    positions = torch.arange(longest).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    # Compare every position (1-based) against its own sequence length.
    return torch.gt(positions + 1, lengths.unsqueeze(1))

In [None]:
def analyzeAudio(wav):
    """Transcribe `wav` with Whisper's "tiny" model and print the result.

    Prints the full transcript first, then one line per segment with its start
    and end time, text, average log-probability and no-speech probability.
    The model is loaded anew on every call. Nothing is returned.

    Args:
        wav: Audio handed straight to `model.transcribe(audio=...)`.
            NOTE(review): Whisper presumably expects 16 kHz audio when given a
            raw array, while this notebook synthesises at 24000 Hz - confirm
            whether resampling is needed before trusting the transcript.
    """
    transcript = whisper.load_model("tiny").transcribe(audio=wav)
    print(transcript["text"])

    # One report line per recognised segment
    for seg in transcript["segments"]:
        report_line = (
            f"[{seg['start']:.2f} -> {seg['end']:.2f}] "
            f"text='{seg['text']}' "
            f"avg_logprob={seg['avg_logprob']:.3f} "
            f"no_speech_prob={seg['no_speech_prob']:.3f}"
        )
        print(report_line)

In [None]:
def generateAudio(pipe, name, text):
    """Run the full text-to-speech pipeline once and persist its outputs.

    Args:
        pipe: Pipeline object providing `device`, `inference(text, noise=...)`
            and `synthesizeSpeech(result)`.
        name: Label used both as the latent folder name (via `result.save`)
            and as the stem of the written wav file.
        text: Sentence to synthesise.

    Returns:
        The inference result produced by `pipe.inference`.

    Side effects:
        Saves the latents through `save(name)` and writes
        "outputs/audio/<name>.wav" at a sample rate of 24000.
    """
    # Fresh random starting noise for the style sampler
    start_noise = torch.randn(1, 1, 256).to(pipe.device)
    latents = pipe.inference(text, noise=start_noise)
    latents.save(name)

    waveform = pipe.synthesizeSpeech(latents)
    wav_path = "outputs/audio/" + name + ".wav"
    sf.write(wav_path, waveform, samplerate=24000)

    return latents

In [None]:
def interpolateAllLatents(to_change, reference, interpolation_percentage):
    """Build a new InferenceResult whose `h_text` is a blend of two results.

    Every field except "h_text" is copied unchanged from `reference`. For
    "h_text", the latent of `to_change` is linearly resampled along its last
    axis to the reference length when the shapes differ, then mixed as
    reference * (1 - p) + to_change * p.

    Args:
        to_change: Result supplying the latent that is blended in.
        reference: Result supplying the base latent and all other fields.
        interpolation_percentage: Blend weight p given to `to_change`.

    Returns:
        A new InferenceResult.
    """
    blended_fields = {}

    for field_name in reference.__dataclass_fields__:
        latent_ground_truth = getattr(reference, field_name)
        latent_target = getattr(to_change, field_name)

        if field_name == "h_text":
            print("Starting interpolation for " + field_name)

            if latent_ground_truth.shape != latent_target.shape:
                print(f"Shape mismatch with ground_truth={latent_ground_truth.shape}, target={latent_target.shape}")

                # F.interpolate in "linear" mode needs 3-D input, so add an axis when both are smaller
                if max(latent_ground_truth.dim(), latent_target.dim()) < 3:
                    latent_ground_truth = latent_ground_truth.unsqueeze(1)
                    latent_target = latent_target.unsqueeze(1)

                resampled = F.interpolate(
                    input=latent_target,
                    size=latent_ground_truth.shape[-1],
                    mode="linear",
                    align_corners=False
                )
                latent_target = resampled.squeeze(0)

            blended_fields[field_name] = (
                latent_ground_truth * (1 - interpolation_percentage)
                + latent_target * interpolation_percentage
            )
        else:
            # All other latents are taken from the reference as-is
            blended_fields[field_name] = latent_ground_truth

    return InferenceResult(**blended_fields)

In [None]:
def interpolateAttribute(to_change, reference, interpolation_percentage: float):
    """Blend two tensors, resampling `to_change` to the reference length first.

    When the shapes differ, `to_change` is linearly resampled along its last
    axis to `reference.shape[-1]` (both tensors get an extra axis at position 1
    beforehand if both have fewer than 3 dimensions), and the leading axis of
    the resampled tensor is squeezed when it has size 1.

    Args:
        to_change: Tensor that is blended in.
        reference: Tensor that defines the target length and the base values.
        interpolation_percentage: Weight given to `to_change`.

    Returns:
        reference * (1 - interpolation_percentage) + to_change * interpolation_percentage
    """
    shapes_match = reference.shape == to_change.shape

    if not shapes_match:
        print(f"Shape mismatch with ground_truth={reference.shape}, target={to_change.shape}")

        # "linear" interpolation needs 3-D input
        if max(reference.dim(), to_change.dim()) < 3:
            reference, to_change = reference.unsqueeze(1), to_change.unsqueeze(1)

        resampled = F.interpolate(
            input=to_change,
            size=reference.shape[-1],
            mode="linear",
            align_corners=False
        )
        to_change = resampled.squeeze(0)

    weight_reference = 1 - interpolation_percentage
    return reference * weight_reference + to_change * interpolation_percentage

In [None]:
def addNoise(reference: "InferenceResult", target: "InferenceResult", interpolation_percentage: float, attribute: str):
    """Pad one latent of `reference` with Gaussian noise to the target length, then blend.

    The latent named `attribute` is read from both results. When the target
    latent is longer along the last axis, the reference latent is extended
    with `torch.randn` noise by the missing number of columns so that both
    can be mixed element-wise.

    Previously the padding code sat in the `else` branch of the shape check,
    so it only ran when the shapes already matched (adding zero columns) and a
    real length mismatch went straight into the blend unpadded.

    Args:
        reference: Result whose latent is padded (the "ground truth").
        target: Result whose latent defines the desired length.
        interpolation_percentage: Weight given to the target latent.
        attribute: Name of the latent field to read from both results.

    Returns:
        latent_reference * (1 - interpolation_percentage)
        + latent_target * interpolation_percentage.
        If both latents have fewer than 3 dimensions they get an extra axis at
        position 1 first, as before for equally shaped inputs.

    Raises:
        RuntimeError: Propagated from the blend when the shapes still cannot be
            broadcast, e.g. when the reference is longer than the target
            (nothing is padded in that case).
    """
    latent_reference = getattr(reference, attribute)
    latent_target = getattr(target, attribute)
    diff = latent_target.size(-1) - latent_reference.size(-1)

    print("Adding Noise to " + attribute + " for ground truth")
    if latent_reference.shape != latent_target.shape:
        print(f"Shape mismatch with ground_truth={latent_reference.shape}, target={latent_target.shape}")

    if (latent_reference.dim() < 3) and (latent_target.dim() < 3):
        latent_reference = latent_reference.unsqueeze(1)
        latent_target = latent_target.unsqueeze(1)

    if diff > 0:
        # Extend the reference along its last axis with random noise
        noise = torch.randn(
            *latent_reference.shape[:-1],
            diff,
            device=latent_reference.device,
            dtype=latent_reference.dtype,
        )
        latent_reference = torch.cat([latent_reference, noise], dim=-1)

    return latent_reference * (1 - interpolation_percentage) + latent_target * interpolation_percentage

In [None]:
def addNumber(to_change: torch.Tensor, reference: torch.Tensor, number: float):
    """Right-pad the last axis of `to_change` with a constant up to the reference length.

    Args:
        to_change: Tensor to extend.
        reference: Tensor whose last-axis length is the desired length.
        number: Fill value for the appended entries.

    Returns:
        A new tensor with `number` appended along the last axis (same device
        and dtype as `to_change`). If `reference` is shorter than `to_change`,
        a message is printed and `to_change` itself is returned unchanged.
    """
    missing = reference.size(-1) - to_change.size(-1)

    if missing >= 0:
        pad_shape = (*to_change.shape[:-1], missing)
        filler = torch.full(pad_shape, number, device=to_change.device, dtype=to_change.dtype)
        return torch.cat([to_change, filler], dim=-1)

    print("Reference is smaller then whats to be changed")
    return to_change

In [None]:
def resizeAttribute(to_change, reference):
    """Linearly resample `to_change` along its last axis to the length of `reference`.

    Args:
        to_change: Tensor to resample.
        reference: Tensor whose last-axis length is the target length.

    Returns:
        The resampled tensor. Two 2-D inputs are temporarily given a leading
        axis for the interpolation; a leading axis of size 1 is squeezed away
        from the result.
    """
    both_2d = reference.dim() == 2 and to_change.dim() == 2
    if both_2d:
        # "linear" interpolation needs 3-D input
        reference, to_change = reference.unsqueeze(0), to_change.unsqueeze(0)

    target_length = reference.shape[-1]
    return F.interpolate(
        to_change,
        size=target_length,
        mode='linear',
        align_corners=False
    ).squeeze(0)

## Classes

In [None]:
@dataclass
class InferenceResult:
    """Bundle of the intermediate tensors that are passed on to the decoder."""

    h_text: torch.Tensor  # output of the text encoder
    h_aligned: torch.Tensor  # h_text multiplied with the alignment matrix
    f0_pred: torch.Tensor  # first output of predictor.F0Ntrain
    a_pred: torch.Tensor  # predicted alignment matrix
    n_pred: torch.Tensor  # second output of predictor.F0Ntrain
    style_vector_prosodic: torch.Tensor  # first 128 columns of the sampled style vector

    def save(self, folder: str):
        """Write every tensor field to "outputs/latent/<folder>/<field>.pt".

        The directory is created when missing. Fields that do not hold a
        tensor are skipped with a printed warning.

        Args:
            folder: Sub-folder name below "outputs/latent/".
        """
        target_dir = "outputs/latent/" + folder
        os.makedirs(target_dir, exist_ok=True)

        # Walk over all instance attributes
        for name, value in vars(self).items():
            if not isinstance(value, torch.Tensor):
                print(f"⚠️ Skipping {name} (not a tensor)")
                continue

            path = os.path.join(target_dir, f"{name}.pt")
            torch.save(value, path)
            print(f"✅ Saved {name} -> {path}")

In [None]:
class StyleTTS2_Helper:
    """Wrapper around a StyleTTS2 model that exposes the inference stages individually.

    Setup order: `load_models()` -> `load_checkpoints()` -> `sample_diffusion()`.
    Afterwards `inference()` produces an `InferenceResult` and
    `synthesizeSpeech()` turns such a result into a waveform.
    """

    def __init__(self):
        """Create the phonemizer and text cleaner; models are loaded separately."""
        self.model = None  # set by load_models()
        self.params = None  # checkpoint weights, set by load_models()
        self.sampler = None  # set by sample_diffusion()

        # Splits words into phonemes (symbols that represent how words are pronounced)
        self.global_phonemizer = phonemizer.backend.EspeakBackend(
            language='en-us',
            preserve_punctuation=True,  # Keeps punctuation such as , . ? !
            with_stress=True  # Adds stress marks to vowels
        )

        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Callable that turns a phoneme string into a list of token IDs (see preprocessText)
        self.textcleaner = TextCleaner()

    def load_models(self, yml_path="Models/LJSpeech/config.yml",
                    checkpoint_path="Models/LJSpeech/epoch_2nd_00100.pth"):
        """Build the model from a YAML config and read the checkpoint weights.

        Args:
            yml_path: YAML file with the model settings and the paths of the
                pretrained components (keys `ASR_config`, `ASR_path`,
                `F0_path`, `PLBERT_dir`, `model_params`).
            checkpoint_path: Checkpoint file whose `'net'` entry holds the
                per-component weights. Defaults to the previously hard-coded
                path.

        Side effects:
            Sets `self.model` (all components in eval mode on `self.device`)
            and `self.params` (weights kept on the CPU until
            `load_checkpoints()` copies them into the model).
        """
        # Close the config file deterministically instead of leaking the handle
        with open(yml_path) as config_file:
            config = yaml.safe_load(config_file)

        # Pretrained ASR (Automatic Speech Recognition) model, used as text aligner
        ASR_config = config.get('ASR_config', False)
        ASR_path = config.get('ASR_path', False)
        text_aligner = load_ASR_models(ASR_path, ASR_config)

        # Pretrained F0 (pitch) model
        F0_path = config.get('F0_path', False)
        pitch_extractor = load_F0_models(F0_path)

        # Pretrained PL-BERT model
        BERT_path = config.get('PLBERT_dir', False)
        plbert = load_plbert(BERT_path)

        self.model = build_model(
            recursive_munch(config['model_params']),  # Allows attribute-style access to keys of model_params
            text_aligner,  # Automatic Speech Recognition model
            pitch_extractor,  # F0 model
            plbert  # BERT model
        )

        for key in self.model:
            self.model[key].eval()
            self.model[key].to(self.device)

        params_whole = torch.load(checkpoint_path, map_location='cpu')
        self.params = params_whole['net']

    def load_checkpoints(self):
        """Copy the weights in `self.params` into the matching components of `self.model`.

        A strict load is tried first. If it fails (typically because the
        checkpoint keys carry a leading `module.` prefix), the prefix is
        stripped from the keys that have it and the weights are loaded
        non-strictly. Components without an entry in `self.params` are left
        untouched. All components end up in eval mode.
        """
        from collections import OrderedDict

        prefix = "module."

        for key in self.model:
            if key not in self.params:
                continue

            state_dict = self.params[key]
            try:
                self.model[key].load_state_dict(state_dict)
            except RuntimeError:
                # Only strip the prefix where it is present; blindly cutting the
                # first 7 characters would mangle keys that never had it.
                new_state_dict = OrderedDict(
                    (k[len(prefix):] if k.startswith(prefix) else k, v)
                    for k, v in state_dict.items()
                )
                self.model[key].load_state_dict(new_state_dict, strict=False)

        for key in self.model:
            self.model[key].eval()

    def sample_diffusion(self):
        """Create the diffusion sampler used by `computeStyleVector` and store it in `self.sampler`."""
        self.sampler = DiffusionSampler(
            self.model.diffusion.diffusion,
            sampler=ADPM2Sampler(),
            sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),  # empirical parameters
            clamp=False
        )

    def preprocessText(self, text):
        """Turn raw text into a tensor of token IDs.

        Args:
            text: Input sentence.

        Returns:
            LongTensor of shape (1, number_of_tokens) on `self.device`, with a
            leading 0 token marking the start.
        """
        # 1. Clean the text
        text = text.strip()  # Removes whitespace from beginning and end of the string
        text = text.replace('"', '')  # Removes " to prevent unpredictable behavior

        # 2. Text -> phonemes
        phonemes = self.global_phonemizer.phonemize([text])  # text -> list of phoneme strings
        phonemes = word_tokenize(phonemes[0])  # Split into individual tokens
        phonemes = ' '.join(phonemes)  # Join the tokens again, separated by a single space

        # 3. Phonemes -> token IDs
        tokens = self.textcleaner(phonemes)  # Look up the numeric ID per phoneme symbol
        tokens.insert(0, 0)  # Insert leading 0 to mark the start

        # 4. Token IDs -> tensor with a batch axis
        tokens = torch.LongTensor(tokens).to(self.device).unsqueeze(0)

        return tokens

    def predictDuration(self, bert_encoder_with_style, input_lengths):
        """Predict per-phoneme durations and expand them into an alignment matrix.

        Args:
            bert_encoder_with_style: Style-conditioned text features fed to the
                predictor's LSTM.
            input_lengths: Number of phonemes (used as the row count of the
                alignment matrix).

        Returns:
            Tensor of shape (input_lengths, total_frames) created on the CPU.
            Row i holds ones in the frame range assigned to phoneme i and
            zeros elsewhere.
        """
        # Duration predictor: frames per phoneme
        d_pred, _ = self.model.predictor.lstm(bert_encoder_with_style)
        d_pred = self.model.predictor.duration_proj(d_pred)
        d_pred = torch.sigmoid(d_pred).sum(axis=-1)  # Sum over the last axis -> predicted duration per phoneme
        d_pred = torch.round(d_pred.squeeze()).clamp(min=1)  # Whole frames, at least one frame per phoneme
        d_pred[-1] += 5  # Give the last phoneme 5 extra frames so it is not cut off too early

        # Alignment matrix between phonemes (rows) and audio frames (columns)
        a_pred = torch.zeros(input_lengths, int(d_pred.sum().data))
        current_frame = 0
        for i in range(a_pred.size(0)):  # Iterate over the phonemes
            a_pred[i, current_frame:current_frame + int(d_pred[i].data)] = 1  # Mark the frames of phoneme i
            current_frame += int(d_pred[i].data)  # The next phoneme starts where this one ended

        return a_pred

    def computeStyleVector(self, noise, h_bert, embedding_scale, diffusion_steps):
        """Sample a style vector with the diffusion sampler and split it in two halves.

        Args:
            noise: Starting noise for the sampler.
            h_bert: BERT output; its first batch entry is used as the
                conditioning embedding.
            embedding_scale: Passed to the sampler as `embedding_scale`.
            diffusion_steps: Passed to the sampler as `num_steps`.

        Returns:
            Tuple `(style_vector_acoustic, style_vector_prosodic)`: columns
            128 onwards and columns 0-127 of the sampled vector. In this file
            the former is fed to the predictor, the latter to the decoder.
        """
        style_vector = self.sampler(
            noise,
            embedding=h_bert[0].unsqueeze(0),
            embedding_scale=embedding_scale,
            num_steps=diffusion_steps
        ).squeeze(0)

        # NOTE(review): upstream StyleTTS2 appears to name these two halves the
        # other way round (first half = acoustic/reference style) - the usage
        # here matches upstream, only the names may be swapped. Confirm.
        style_vector_acoustic = style_vector[:, 128:]  # Right half
        style_vector_prosodic = style_vector[:, :128]  # Left half

        return style_vector_acoustic, style_vector_prosodic

    def inference(self, text, noise, diffusion_steps=5, embedding_scale=1):
        """Run all stages before the decoder for one sentence.

        Args:
            text: Sentence to synthesise.
            noise: Starting noise for the style sampler.
            diffusion_steps: Number of sampler steps.
            embedding_scale: Scale passed to the sampler.

        Returns:
            `InferenceResult` with the encoded text, the alignment, the aligned
            text features, the F0 and N predictions and the style half used by
            the decoder.
        """
        tokens = self.preprocessText(text)

        with torch.no_grad():
            input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)  # Number of tokens
            text_mask = length_to_mask(input_lengths).to(tokens.device)  # Padding mask derived from the length

            h_text = self.model.text_encoder(tokens, input_lengths, text_mask)  # Text encoder features
            h_bert = self.model.bert(tokens, attention_mask=(~text_mask).int())
            bert_encoder = self.model.bert_encoder(h_bert).transpose(-1, -2)  # BERT-based text features

            style_vector_acoustic, style_vector_prosodic = self.computeStyleVector(noise, h_bert, embedding_scale, diffusion_steps)

            # Condition the text features on the style vector
            bert_encoder_with_style = self.model.predictor.text_encoder(bert_encoder, style_vector_acoustic, input_lengths, text_mask)

            a_pred = self.predictDuration(bert_encoder_with_style, input_lengths)

            # Expand the per-phoneme text features to per-frame features
            h_aligned = h_text @ a_pred.unsqueeze(0).to(self.device)

            # Same expansion for the style-conditioned features, then predict F0 and N
            bert_encoder_with_style_per_frame = (bert_encoder_with_style.transpose(-1, -2) @ a_pred.unsqueeze(0).to(self.device))
            f0_pred, n_pred = self.model.predictor.F0Ntrain(bert_encoder_with_style_per_frame, style_vector_acoustic)

        return InferenceResult(
            h_text=h_text,
            h_aligned=h_aligned,
            f0_pred=f0_pred,
            a_pred=a_pred,
            n_pred=n_pred,
            style_vector_prosodic=style_vector_prosodic,
        )

    @torch.no_grad()
    def synthesizeSpeech(self, inferenceResult):
        """Decode an `InferenceResult` into a waveform.

        Args:
            inferenceResult: Object providing `h_aligned`, `f0_pred`, `n_pred`
                and `style_vector_prosodic`.

        Returns:
            NumPy array with the decoder output, squeezed and moved to the CPU.
        """
        # Gradient tracking is already disabled by the decorator
        out = self.model.decoder(
            inferenceResult.h_aligned,
            inferenceResult.f0_pred,
            inferenceResult.n_pred,
            inferenceResult.style_vector_prosodic.squeeze().unsqueeze(0)
        )

        return out.squeeze().cpu().numpy()

## Main Function

### Load Models

In [None]:
# Build the pipeline; the three setup calls must run in this order.
pipe = StyleTTS2_Helper()
pipe.load_models()  # builds self.model and loads self.params
pipe.load_checkpoints()  # puts params into self.model
pipe.sample_diffusion()  # builds self.sampler

### Initialize Values

In [None]:
# Sampler settings handed to pipe.computeStyleVector in the inference cell
diffusion_steps = 5
embedding_scale = 1

interpolation_percentage = 0.45 # How much of Target to be used, small interpolation_percentage means more of ground_truth (Minimization)

# The two sentences compared in this experiment
name_gt = "ground_truth"
text_gt = "This is a medium-length sentence, maybe it will work out?"

name_target = "target"
text_target = "This is a longer sentence to see how the model copes with different lengths"

# Independent starting noise for each sentence's style sampling
# (the draw order matters because the torch seed is fixed)
noise_gt = torch.randn(1, 1, 256).to(pipe.device)
noise_target = torch.randn(1, 1, 256).to(pipe.device)

tokens_gt = pipe.preprocessText(text_gt)
tokens_target = pipe.preprocessText(text_target)
print("tokens_gt.shape (before zeros):", tokens_gt.shape)
print("tokens_gt (before zeros):", tokens_gt)
# Pad the ground-truth tokens on the right with token ID 16 up to the target length.
# NOTE(review): the print labels say "zeros" although the fill value is 16 -
# presumably 16 is the ID of a blank/space symbol; verify against the symbol table.
tokens_gt = addNumber(tokens_gt, tokens_target, 16)
print("tokens_gt.shape (after zeros):", tokens_gt.shape)
print("tokens_gt (after zeros):", tokens_gt)
print("tokens_target.shape :", tokens_target.shape)
print("tokens_target:", tokens_target)

### Inference

In [None]:
# Manual, step-by-step version of StyleTTS2_Helper.inference for both sentences.
# It is spelled out here because tokens_gt was padded beforehand, so the
# text-based pipe.inference(...) entry point cannot be used for it.
with torch.no_grad():
    # Number of phoneme / Length of tokens, shape[-1] = size of the last axis
    input_lengths_gt = torch.LongTensor([tokens_gt.shape[-1]]).to(tokens_gt.device)
    input_lengths_target = torch.LongTensor([tokens_target.shape[-1]]).to(tokens_target.device)

    # Creates a bitmask based on number of phonemes
    text_mask_gt = length_to_mask(input_lengths_gt).to(tokens_gt.device)
    text_mask_target = length_to_mask(input_lengths_target).to(tokens_target.device)

    # Creates acoustic text encoder (phoneme -> feature vectors)
    print("\n h_text:")
    h_text_gt = pipe.model.text_encoder(tokens_gt, input_lengths_gt, text_mask_gt)
    print(h_text_gt)
    h_text_target = pipe.model.text_encoder(tokens_target, input_lengths_target, text_mask_target)
    print(h_text_target)

    # Creates prosodic text encoder (phoneme -> feature vectors)
    h_bert_gt = pipe.model.bert(tokens_gt, attention_mask=(~text_mask_gt).int())
    h_bert_target = pipe.model.bert(tokens_target, attention_mask=(~text_mask_target).int())
    bert_encoder_gt = pipe.model.bert_encoder(h_bert_gt).transpose(-1, -2)
    bert_encoder_target = pipe.model.bert_encoder(h_bert_target).transpose(-1, -2)

    ## Function Call
    style_vector_gt_acoustic, style_vector_gt_prosodic = pipe.computeStyleVector(noise_gt, h_bert_gt, embedding_scale, diffusion_steps)
    style_vector_target_acoustic, style_vector_target_prosodic = pipe.computeStyleVector(noise_target, h_bert_target, embedding_scale, diffusion_steps)

    # AdaIN, Adding information of style vector to phoneme
    bert_encoder_gt_with_style = pipe.model.predictor.text_encoder(bert_encoder_gt, style_vector_gt_acoustic, input_lengths_gt, text_mask_gt)
    bert_encoder_target_with_style = pipe.model.predictor.text_encoder(bert_encoder_target, style_vector_target_acoustic, input_lengths_target, text_mask_target)

    ## Function Call
    a_pred_gt = pipe.predictDuration(bert_encoder_gt_with_style, input_lengths_gt)
    a_pred_target = pipe.predictDuration(bert_encoder_target_with_style, input_lengths_target)
    # a_pred_mixed = resizeAttribute(a_pred_target, a_pred_gt)
    print("\na_pred_gt:", a_pred_gt.shape)
    print("a_pred_target:", a_pred_target.shape)
    # Disabled together with the a_pred_mixed assignment above; printing it
    # while it is undefined raised a NameError and aborted the cell.
    # print("a_pred_mixed:", a_pred_mixed.shape)

    # Multiply alignment matrix with h_text
    h_aligned_gt = h_text_gt @ a_pred_gt.unsqueeze(0).to(pipe.device)  # (B, D_text, T_frames)
    h_aligned_target = h_text_target @ a_pred_target.unsqueeze(0).to(pipe.device)
    # h_aligned_mixed = h_text_gt @ a_pred_mixed.unsqueeze(0).to(pipe.device)  # (B, D_text, T_frames)
    # h_aligned_mixed = (1 - interpolation_percentage) * h_aligned_mixed + interpolation_percentage * h_aligned_target
    print("\nh_aligned_gt:", h_aligned_gt.shape)
    print("h_aligned_target:", h_aligned_target.shape)
    # print("h_aligned_mixed:", h_aligned_mixed.shape)

    # Multiply per-phoneme embedding (bert_encoder_with_style) with frame-per-phoneme matrix -> per-frame text embedding
    bert_encoder_gt_with_style_per_frame = (bert_encoder_gt_with_style.transpose(-1, -2) @ a_pred_gt.unsqueeze(0).to(pipe.device))
    bert_encoder_target_with_style_per_frame = (bert_encoder_target_with_style.transpose(-1, -2) @ a_pred_target.unsqueeze(0).to(pipe.device))
    # bert_encoder_mixed_with_style_per_frame = (bert_encoder_gt_with_style.transpose(-1, -2) @ a_pred_mixed.unsqueeze(0).to(pipe.device))

    f0_pred_gt, n_pred_gt = pipe.model.predictor.F0Ntrain(bert_encoder_gt_with_style_per_frame, style_vector_gt_acoustic)
    f0_pred_target, n_pred_target = pipe.model.predictor.F0Ntrain(bert_encoder_target_with_style_per_frame, style_vector_target_acoustic)
    # f0_pred_mixed, n_pred_mixed = pipe.model.predictor.F0Ntrain(bert_encoder_mixed_with_style_per_frame, style_vector_gt_acoustic)

# Bundle the tensors computed in the no_grad block above so the decoder can consume them.
inferenceResult_gt = InferenceResult(
    h_text=h_text_gt,
    h_aligned=h_aligned_gt,
    f0_pred=f0_pred_gt,
    a_pred=a_pred_gt,
    n_pred=n_pred_gt,
    style_vector_prosodic=style_vector_gt_prosodic,
)

inferenceResult_target = InferenceResult(
    h_text=h_text_target,
    h_aligned=h_aligned_target,
    f0_pred=f0_pred_target,
    a_pred=a_pred_target,
    n_pred=n_pred_target,
    style_vector_prosodic=style_vector_target_prosodic,
)

# The "mixed" variant is currently disabled (kept as a string literal, not executed).
"""
inferenceResult_mixed = InferenceResult(
    h_text=h_text_gt,
    h_aligned=h_aligned_mixed,
    f0_pred=f0_pred_mixed,
    a_pred=a_pred_gt,
    n_pred=n_pred_mixed,
    style_vector_prosodic=style_vector_gt_prosodic,
)
"""

# Decode both results into waveforms
audio_gt = pipe.synthesizeSpeech(inferenceResult_gt)
audio_target = pipe.synthesizeSpeech(inferenceResult_target)
# audio_mixed = pipe.synthesizeSpeech(inferenceResult_mixed)

# Inline audio players in the notebook
print("\nground truth")
display(ipd.Audio(audio_gt, rate=24000))
print("target")
display(ipd.Audio(audio_target, rate=24000))
# print("mixed")
# display(ipd.Audio(audio_mixed, rate=24000))

# NOTE(review): audio_gt is written to "..._target.wav" and audio_target to
# "..._reference.wav". This matches the addNumber(to_change=tokens_gt,
# reference=tokens_target) roles but reads like a swap - confirm the naming is intended.
sf.write("outputs/audio/padding_empty_phonemes_target.wav", audio_gt, samplerate=24000)
sf.write("outputs/audio/padding_empty_phonemes_reference.wav", audio_target, samplerate=24000)


In [None]:
# Transcribe the synthesised clips with Whisper to check their intelligibility.
analyzeAudio(audio_gt)
analyzeAudio(audio_target)
# Disabled: audio_mixed is never assigned while the "mixed" synthesis above is
# commented out, so this call raised a NameError.
# analyzeAudio(audio_mixed)