In [1]:
import base64
import io

from IPython.core.display import HTML
from IPython.display import Audio
import librosa
import librosa.display
import matplotlib.pyplot as plt
from nemo.collections.tts.models import HifiGanModel
import numpy as np
import soundfile as sf
import torch

from models import DelightfulHiFi
from models.config import PreprocessingConfig
from models.tts.delightful_tts.delightful_tts_refined import DelightfulTTS
from models.vocoder.hifigan import HifiGan
from training.datasets.hifi_libri_dataset import HifiLibriDataset, selected_speakers_ids
from training.preprocess import TacotronSTFT

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Sampling rate (Hz) used below when saving, playing and plotting audio.
sample_rate = 44100


### Load the pretrained NeMo HiFi-GAN weights

In [2]:
# Load NVIDIA's pretrained HiFi-GAN vocoder through NeMo and display its module structure.
# In the cells visible here it is only referenced from commented-out code.
hifigan_model = HifiGanModel.from_pretrained(model_name="nvidia/tts_hifigan")
hifigan_model

[NeMo W 2024-04-22 08:41:07 nemo_logging:393] If you intend to do training or fine-tuning, please call the ModelPT.setup_training_data() method and provide a valid configuration file to setup the train data loader.
    Train config : 
    dataset:
      _target_: nemo.collections.tts.data.datalayers.MelAudioDataset
      manifest_filepath: /home/fkreuk/data/train_finetune.txt
      min_duration: 0.75
      n_segments: 8192
    dataloader_params:
      drop_last: false
      shuffle: true
      batch_size: 64
      num_workers: 4
    
[NeMo W 2024-04-22 08:41:07 nemo_logging:393] If you intend to do validation, please call the ModelPT.setup_validation_data() or ModelPT.setup_multiple_validation_data() method and provide a valid configuration file to setup the validation data loader(s). 
    Validation config : 
    dataset:
      _target_: nemo.collections.tts.data.datalayers.MelAudioDataset
      manifest_filepath: /home/fkreuk/data/val_finetune.txt
      min_duration: 3
      n_segmen

[NeMo I 2024-04-22 08:41:07 nemo_logging:381] PADDING: 0


[NeMo W 2024-04-22 08:41:07 nemo_logging:393] Using torch_stft is deprecated and has been removed. The values have been forcibly set to False for FilterbankFeatures and AudioToMelSpectrogramPreprocessor. Please set exact_pad to True as needed.


[NeMo I 2024-04-22 08:41:07 nemo_logging:381] PADDING: 0


    


[NeMo I 2024-04-22 08:41:09 nemo_logging:381] Model HifiGanModel was successfully restored from /home/you/.cache/huggingface/hub/models--nvidia--tts_hifigan/snapshots/3ba1fed954276287015654bf4c78060ffc9a4772/tts_hifigan.nemo.


HifiGanModel(
  (audio_to_melspec_precessor): FilterbankFeatures()
  (trg_melspec_fn): FilterbankFeatures()
  (generator): Generator(
    (conv_pre): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))
    (ups): ModuleList(
      (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
      (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
      (2): ConvTranspose1d(128, 64, kernel_size=(4,), stride=(2,), padding=(1,))
      (3): ConvTranspose1d(64, 32, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (resblocks): ModuleList(
      (0): ModuleList(
        (0): ResBlock1(
          (convs1): ModuleList(
            (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
            (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
            (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
          )
          (convs2): ModuleList(
            (0

In [2]:
# Checkpoint files used by the following cells to restore the TTS model and the vocoder.
delightful_checkpoint_path = "checkpoints/epoch=555-step=66164.ckpt"
hifi_gan_checkpoint_path = "checkpoints/logs_44100_vocoder_Mel44100_WAV44100_epoch=19-step=44480.ckpt"

In [4]:
# Restore the project's HifiGan module from its checkpoint and display its structure.
hifi_gan = HifiGan.load_from_checkpoint(hifi_gan_checkpoint_path)
hifi_gan

    


HifiGan(
  (generator): Generator(
    (conv_pre): Conv1d(80, 512, kernel_size=(7,), stride=(1,), padding=(3,))
    (ups): ModuleList(
      (0): ConvTranspose1d(512, 256, kernel_size=(16,), stride=(8,), padding=(4,))
      (1): ConvTranspose1d(256, 128, kernel_size=(16,), stride=(8,), padding=(4,))
      (2): ConvTranspose1d(128, 64, kernel_size=(4,), stride=(4,))
      (3): ConvTranspose1d(64, 32, kernel_size=(4,), stride=(2,), padding=(1,))
    )
    (resblocks): ModuleList(
      (0): ModuleList(
        (0): ResBlock1(
          (convs1): ModuleList(
            (0): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
            (1): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(3,), dilation=(3,))
            (2): Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(5,), dilation=(5,))
          )
          (convs2): ModuleList(
            (0-2): 3 x Conv1d(256, 256, kernel_size=(3,), stride=(1,), padding=(1,))
          )
        )
        (1): ResBloc

### Load DelightfulTTS from the checkpoint

In [3]:
# Restore the DelightfulTTS module from its checkpoint and display its structure.
delightful_tts = DelightfulTTS.load_from_checkpoint(delightful_checkpoint_path)
delightful_tts


 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars.


DelightfulTTS(
  (acoustic_model): AcousticModel(
    (encoder): Conformer(
      (layer_stack): ModuleList(
        (0-5): 6 x ConformerBlock(
          (conditioning): Conv1dGLU(
            (bsconv1d): BSConv1d(
              (pointwise): PointwiseConv1d(
                (conv): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))
              )
              (depthwise): DepthWiseConv1d(
                (conv): Conv1d(1024, 1024, kernel_size=(7,), stride=(1,), padding=(3,), groups=1024)
              )
            )
            (embedding_proj): Linear(in_features=1280, out_features=512, bias=True)
            (softsign): Softsign()
          )
          (ff): FeedForward(
            (dropout): Dropout(p=0.1, inplace=False)
            (ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (conv_1): Conv1d(512, 2048, kernel_size=(3,), stride=(1,), padding=(1,))
            (act): LeakyReLU(negative_slope=0.3)
            (conv_2): Conv1d(2048, 512, kernel_size=(1,), 

In [3]:
# Build DelightfulHiFi from the two checkpoint paths and display its structure.
# `model` is the object the synthesis cells further down call.

model = DelightfulHiFi(
    delightful_checkpoint_path=delightful_checkpoint_path,
    hifi_checkpoint_path=hifi_gan_checkpoint_path,
)
model


 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars.
 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars.
    


DelightfulHiFi(
  (delightful_tts): DelightfulTTS(
    (acoustic_model): AcousticModel(
      (encoder): Conformer(
        (layer_stack): ModuleList(
          (0-5): 6 x ConformerBlock(
            (conditioning): Conv1dGLU(
              (bsconv1d): BSConv1d(
                (pointwise): PointwiseConv1d(
                  (conv): Conv1d(512, 1024, kernel_size=(1,), stride=(1,))
                )
                (depthwise): DepthWiseConv1d(
                  (conv): Conv1d(1024, 1024, kernel_size=(7,), stride=(1,), padding=(3,), groups=1024)
                )
              )
              (embedding_proj): Linear(in_features=1280, out_features=512, bias=True)
              (softsign): Softsign()
            )
            (ff): FeedForward(
              (dropout): Dropout(p=0.1, inplace=False)
              (ln): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
              (conv_1): Conv1d(512, 2048, kernel_size=(3,), stride=(1,), padding=(1,))
              (act): LeakyReLU(

In [10]:
# Load the HiFi-GAN discriminator checkpoint

disc_checkpoint_path = "checkpoints/do_02500000"
# Map the tensors onto `device` instead of a hard-coded "cuda" so this cell
# also works on CPU-only machines.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
disc_checkpoint = torch.load(disc_checkpoint_path, map_location=device)
disc_checkpoint.keys()

{'mpd': OrderedDict([('discriminators.0.convs.0.bias',
               tensor([ 0.0226,  0.1185, -0.1512,  0.0198,  0.0208,  0.0752,  0.0573,  0.0813,
                       -0.0063,  0.0535, -0.1612,  0.0348,  0.0291, -0.1596,  0.0112, -0.0021,
                        0.0135, -0.1671,  0.0219,  0.0354,  0.0921,  0.0782, -0.1423, -0.1023,
                       -0.1484,  0.0767,  0.0113, -0.0996, -0.0010, -0.0013,  0.0528,  0.0160],
                      device='cuda:0')),
              ('discriminators.0.convs.0.weight_g',
               tensor([[[[ 1.2726e+01]]],
               
               
                       [[[ 1.2378e+01]]],
               
               
                       [[[ 1.6236e+00]]],
               
               
                       [[[ 6.0151e+00]]],
               
               
                       [[[ 1.2137e+00]]],
               
               
                       [[[ 9.8345e+00]]],
               
               
                       [[[ 

In [11]:
# Load the discriminator weights into the hifi-gan model
# The 'mpd' and 'msd' entries of the checkpoint dict are copied into the matching
# discriminator sub-modules; only the result of the last call is shown as cell output.

hifi_gan.discriminator.MPD.load_state_dict(disc_checkpoint['mpd'])
hifi_gan.discriminator.MSD.load_state_dict(disc_checkpoint['msd'])

<All keys matched successfully>

In [2]:
# Same for the generator

gen_checkpoint_path = "checkpoints/generator_v1"
# Map the tensors onto `device` instead of a hard-coded "cuda" so this cell
# also works on CPU-only machines.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
gen_checkpoint = torch.load(gen_checkpoint_path, map_location=device)

# Number of top-level entries in the checkpoint dict.
len(gen_checkpoint)

1

In [3]:
# Copy the 'generator' entry of the checkpoint dict into the vocoder's generator.
hifi_gan.generator.load_state_dict(gen_checkpoint['generator'])

<All keys matched successfully>

In [3]:
# Paths to the checkpoints
# NOTE(review): this rebinds `delightful_checkpoint_path`, which an earlier cell already
# set and used to build the models — re-run the loading cells if the new path should
# take effect.
delightful_checkpoint_path = "checkpoints/logs_22050_tts_epoch=114-step=6900.ckpt"

# Alternative checkpoint names: epoch=24-step=14200.ckpt, epoch=61-step=35216.ckpt
# `hifi_checkpoint` is also used further down to name the saved wav file.
hifi_checkpoint = "epoch=129-step=14300"
hifi_checkpoint_path = f"checkpoints/{hifi_checkpoint}.ckpt"

In [4]:
# Mel-spectrogram extractor used further down to plot the synthesised audio.
# It is configured from `sample_rate` so it cannot drift from the rate used
# when the audio is saved and played.
preprocess_config = PreprocessingConfig("english_only", sample_rate)

# Build the STFT module from the preprocessing config and move it to `device`.
tacotronSTFT = TacotronSTFT(
    filter_length=preprocess_config.stft.filter_length,
    hop_length=preprocess_config.stft.hop_length,
    win_length=preprocess_config.stft.win_length,
    n_mel_channels=preprocess_config.stft.n_mel_channels,
    sampling_rate=preprocess_config.sampling_rate,
    mel_fmin=preprocess_config.stft.mel_fmin,
    mel_fmax=preprocess_config.stft.mel_fmax,
    center=False,
).to(device)

In [5]:
# Instantiate the dataset with its default arguments and inspect the mel
# spectrogram of the first item.
dataset = HifiLibriDataset()
dataset[0].mel


 NeMo-text-processing :: INFO     :: Creating ClassifyFst grammars.


tensor([[ -3.5546,  -3.7388,  -3.5969,  ...,  -7.4779,  -7.7555,  -7.4785],
        [ -3.1470,  -3.0982,  -3.0717,  ...,  -9.2397,  -8.6195,  -8.1689],
        [ -3.4514,  -2.3811,  -0.8866,  ...,  -9.3172,  -9.3316,  -8.7459],
        ...,
        [ -5.1811,  -4.5938,  -4.3260,  ..., -10.0529, -10.1092, -10.1201],
        [ -4.9618,  -4.9986,  -4.6231,  ...,  -9.8174,  -9.7945,  -9.9479],
        [ -4.9749,  -4.9377,  -4.8909,  ..., -10.0607,  -9.9447, -10.1064]])

In [11]:
# Vocode the mel spectrogram of dataset item 15 with the model's HiFi-GAN.
mel = dataset[15].mel

with torch.no_grad():
    # Call the module itself rather than `.forward` so registered hooks run.
    wav = model.hifi_gan(mel.to(device))

# Convert once and reuse the array for both the file and the inline player.
wav_np = wav.squeeze().cpu().numpy()

# Save the audio to a file
# NOTE(review): the file name uses `hifi_checkpoint`, while `model` was built from
# `hifi_gan_checkpoint_path` — confirm the name matches the weights actually loaded.
sf.write(f"results/hifi-{hifi_checkpoint}.wav", wav_np, sample_rate)

Audio(wav_np, rate=sample_rate)


In [12]:
# Libri-speaker: vocode the mel spectrogram of dataset item 10111.
mel = dataset[10111].mel

with torch.no_grad():
    # Call the module itself rather than `.forward` so registered hooks run.
    wav = model.hifi_gan(mel.to(device))

Audio(wav.squeeze().cpu().numpy(), rate=sample_rate)


In [12]:
# Inspect the shape of the current `mel` tensor.
mel.shape

torch.Size([160, 315])

In [4]:
def plot_spectrogram(mel: np.ndarray) -> str:
    r"""Render a mel spectrogram and return the image as a base64 string.

    The figure is drawn off-screen, written to an in-memory PNG and always
    closed again, so calling this in a loop does not accumulate open figures
    even when plotting or saving fails.

    Args:
        mel (np.ndarray): Spectrogram array handed to
            ``librosa.display.specshow`` (assumed to be laid out as
            ``(n_mels, frames)`` — TODO confirm against the caller).

    Returns:
        str: The PNG image encoded as base64 text, ready to embed in a
        ``data:image/png;base64,...`` URI.
    """
    # Use an explicit figure/axes pair instead of pyplot's implicit "current
    # figure" so this function cannot draw onto, or close, someone else's plot.
    fig, ax = plt.subplots(dpi=80, figsize=(10, 3))
    try:
        img = librosa.display.specshow(
            mel,
            x_axis="time",
            y_axis="mel",
            sr=sample_rate,
            ax=ax,
        )
        ax.set_title("Spectrogram")
        fig.colorbar(img, ax=ax, format="%+2.0f dB")

        # Write the PNG into memory rather than onto disk.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
    finally:
        # Release the figure whether or not rendering succeeded.
        plt.close(fig)

    return base64.b64encode(buf.getvalue()).decode("utf-8")

In [27]:
# Look up entry 99 of the acoustic model's `emb_g` embedding
# (presumably the speaker-embedding table — TODO confirm).
model.delightful_tts.acoustic_model.emb_g(torch.tensor(99, device=device))

tensor([-0.3965,  0.3189, -1.8614,  ..., -0.8290, -0.2550, -0.2367],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)

In [28]:
# Mapping of the selected speakers to integer ids, as defined by the dataset module
# (imported with the other dependencies at the top of the notebook).
selected_speakers_ids

{'Cori Samuel': 0,
 'Tony Oliva': 1,
 'John Van Stan': 2,
 'Helen Taylor': 3,
 '40': 4,
 '1088': 5,
 '3307': 6,
 '5935': 7,
 '215': 8,
 '6594': 9,
 '3867': 10,
 '5733': 11,
 '5181': 12}

In [32]:
# Number of items in the dataset.
len(dataset)

145531

In [8]:
# Multi-line sample with mixed quote styles, used to inspect the text normalizer's output.
text_tts = """As the snake shook its head, a deafening shout behind Harry made both of them jump.
‘DUDLEY! MR DURSLEY! COME AND LOOK AT THIS SNAKE! YOU WON’T BELIEVE WHAT IT’S DOING!’
"How did you know it was me?" she asked.
"My dear Professor, I’ve never seen a cat sit so stiffly."
"You’d be stiff if you’d been sitting on a brick wall all day," said Professor McGonagall.
"""

# `normilize_text` is the method name exactly as it is called on the model here.
normalized_text = model.normilize_text(text_tts)
normalized_text

"As the snake shook its head, a deafening shout behind Harry made both of them jump.;; 'DUDLEY! MR DURSLEY! COME AND LOOK AT THIS SNAKE! YOU WON'T BELIEVE WHAT IT'S DOING!';; 'How did you know it was me?' she asked.;; 'My dear Professor, I've never seen a cat sit so stiffly.';; 'You'd be stiff if you'd been sitting on a brick wall all day,' said Professor McGonagall.;; "

In [14]:
# Run the model's tokenizer on the normalized text; it returns a pair that is unpacked
# into `phon` and `tokens`. Display `phon` (the phonemized text).
phon, tokens = model.tokenizer(normalized_text)
phon

'æz ðə snˈeɪk ʃˈʊk ɪts hˈɛd, ɐ dˈɛfənɪŋ ʃˈaʊt bᵻhˌaɪnd hˈæɹi mˌeɪd bˈoʊθ ʌv ðˌɛm dʒˈʌmp.;; dˈʌdli! mˈɪstɚ dˈɜːsli! kˈʌm ænd lˈʊk æt ðɪs snˈeɪk! juː woʊnt bᵻlˈiːv wʌt ˈɪts dˈuːɪŋ!;; hˌaʊ dˈɪd juː nˈoʊ ɪt wʌz mˌiː?ʃiː ˈæskt.;; maɪ dˈɪɹ pɹəfˈɛsɚ, aɪv nˈɛvɚ sˈiːn ɐ kˈæt sˈɪt sˌoʊ stˈɪfli.;; juːd biː stˈɪf ɪf juːd bˌɪn sˈɪɾɪŋ ˌɔn ɐ bɹˈɪk wˈɔːl ˈɔːl dˈeɪ,sˈɛd pɹəfˈɛsɚ mə ɡˈɑːneɪɡˌɔːl.;; '

In [13]:
# Decode the tokens with the inner tokenizer and join the pieces into one string for inspection.
"".join(model.tokenizer.tokenizer.decode(tokens))

'<en_us>æz ðə snˈeɪk ʃˈʊk ɪts hˈɛd, ɐ dˈɛfənɪŋ ʃˈaʊt bᵻhˌaɪnd hˈæɹi mˌeɪd bˈoʊθ ʌv ðˌɛm dʒˈʌmp.;; dˈʌdli! mˈɪstɚ dˈɜːsli! kˈʌm ænd lˈʊk æt ðɪs snˈeɪk! juː woʊnt bᵻlˈiːv wʌt ˈɪts dˈuːɪŋ!;; hˌaʊ dˈɪd juː nˈoʊ ɪt wʌz mˌiː?ʃiː ˈæskt.;; maɪ dˈɪɹ pɹəfˈɛsɚ, aɪv nˈɛvɚ sˈiːn ɐ kˈæt sˈɪt sˌoʊ stˈɪfli.;; juːd biː stˈɪf ɪf juːd bˌɪn sˈɪɾɪŋ ˌɔn ɐ bɹˈɪk wˈɔːl ˈɔːl dˈeɪ,sˈɛd pɹəfˈɛsɚ mə ɡˈɑːneɪɡˌɔːl.;; <end>'

In [5]:
# Synthesise one sentence with several speakers and collect the results in an
# HTML table with one row per speaker (speaker id, audio player, mel image).
#
# This cell needs `model`, `tacotronSTFT` and `plot_spectrogram` from the cells
# above; running it on a kernel where those cells were not executed raises
# NameError.

text_tts = """As the snake shook its head, a deafening shout behind Harry made both of them jump."""

html_rows = []
for speaker_id in [0, 1, 2, 3, 4]:
    with torch.no_grad():
        speaker_id_ = torch.tensor([int(speaker_id)], device=device)
        # Call the module itself rather than `.forward` so registered hooks run.
        wav = model(text_tts, speaker_id_)
        # Recompute the mel spectrogram from the generated waveform for plotting.
        mel = tacotronSTFT.get_mel_from_wav(wav.squeeze())

    mel_base64 = plot_spectrogram(mel.detach().cpu().numpy())
    audio_html = Audio(wav.squeeze().cpu().numpy(), rate=sample_rate)._repr_html_()

    # Add a row to the HTML table
    html_rows.append(f"""<tr>
    <td>{speaker_id}</td>
    <td>{audio_html}</td>
    <td><img src='data:image/png;base64,{mel_base64}' /></td>
</tr>""")

rows_html = "\n".join(html_rows)

# The headings sit outside the table, the header row has one cell per column
# of the body rows, and the table is closed, so the markup is valid HTML.
html = f"""<h4>TTS: </h4> {text_tts}
<h4>Speakers: </h4>
<table border='1'>
<tr>
    <th>SpeakerID</th>
    <th>Audio</th>
    <th>Mel</th>
</tr>
{rows_html}
</table>
"""

# Save result as HTML (explicit encoding: the text may contain non-ASCII quotes).
# The `logs/` directory must already exist.
with open("logs/demo.html", "w", encoding="utf-8") as f:
    f.write(html)

HTML(html)

NameError: name 'tacotronSTFT' is not defined