In [2]:
from pathlib import Path
import time
import json

import IPython.display as ipd
import torch
from transformers import AutoTokenizer
from transformers.modeling_outputs import BaseModelOutput

from parler_tts import (
    ParlerTTSConfig,
    ParlerTTSForConditionalGeneration,
)

  from .autonotebook import tqdm as notebook_tqdm


### Load args

In [3]:
model_args_path = "/shared/production_tts/jordana_tts_args_v2.json"
model_args = json.load(open(model_args_path))

### Load model and text tokenizer

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the text tokenizer
prompt_tokenizer = AutoTokenizer.from_pretrained(
    model_args["prompt_tokenizer_name"],
    cache_dir=None,
    use_fast=True,
    padding_side="left",
)

# Load the model config
config = ParlerTTSConfig.from_pretrained(
    model_args["model_name_or_path"],
    cache_dir=None,
)

# Load the main model
model = ParlerTTSForConditionalGeneration.from_pretrained(
    model_args["model_name_or_path"],
    cache_dir=None,
    config=config,
)

model.to(device)
model.eval()



ParlerTTSForConditionalGeneration(
  (text_encoder): T5EncoderModel(
    (shared): Embedding(32128, 768)
    (encoder): T5Stack(
      (embed_tokens): Embedding(32128, 768)
      (block): ModuleList(
        (0): T5Block(
          (layer): ModuleList(
            (0): T5LayerSelfAttention(
              (SelfAttention): T5Attention(
                (q): Linear(in_features=768, out_features=768, bias=False)
                (k): Linear(in_features=768, out_features=768, bias=False)
                (v): Linear(in_features=768, out_features=768, bias=False)
                (o): Linear(in_features=768, out_features=768, bias=False)
                (relative_attention_bias): Embedding(32, 12)
              )
              (layer_norm): T5LayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (1): T5LayerFF(
              (DenseReluDense): T5DenseGatedActDense(
                (wi_0): Linear(in_features=768, out_features=2048, bias=False)
               

In [5]:
gen_kwargs = {
    "do_sample": model_args["do_sample"],
    "temperature": model_args["temperature"],
    "max_length": model_args["max_length"],
    "min_new_tokens": model_args["num_codebooks"] + 1,
}

### Load audio reference and specify sentence

In [14]:
audio_ref_embedding_root = Path("/shared/production_tts/style_embeddings/")

style = "default-default"

audio_ref_embedding_path = audio_ref_embedding_root / f"{style}.pt"

test_sentence = "I guess if you've truly hit rock bottom, the only place is up."

### Specify batch and audio normalization

In [15]:
generate_batch = True
normalize = True

### Generate

In [16]:
encoder_outputs = torch.load(audio_ref_embedding_path).unsqueeze(0).to(device)
attention_mask = torch.ones((1, 1), dtype=torch.long).to(
    device
)  # Encoder outputs is a single non-padded vector
encoder_outputs = BaseModelOutput(encoder_outputs)

prompt = prompt_tokenizer(test_sentence, return_tensors="pt")
prompt_input_ids = prompt["input_ids"].to(device)
prompt_attention_mask = prompt["attention_mask"].to(device)

# Pad prompt_input_ids and prompt_attention_mask to data_args.max_prompt_token_length, with leading zeros
zero_padding = torch.zeros(
    (1, model_args["max_prompt_token_length"] - prompt_input_ids.shape[1]),
    dtype=torch.long,
).to(device)
prompt_input_ids = torch.cat((zero_padding, prompt_input_ids), dim=1)
prompt_attention_mask = torch.cat((zero_padding, prompt_attention_mask), dim=1)

if generate_batch:
    batch = {}
    # Create 8 copies of the encoder_outputs, attention_mask, prompt_input_ids, and prompt_attention_mask
    encoder_outputs = BaseModelOutput(
        last_hidden_state=encoder_outputs.last_hidden_state.repeat(8, 1, 1),
        hidden_states=encoder_outputs.hidden_states,
    )
    attention_mask = attention_mask.repeat(8, 1)
    prompt_input_ids = prompt_input_ids.repeat(8, 1)
    prompt_attention_mask = prompt_attention_mask.repeat(8, 1)
    batch["encoder_outputs"] = encoder_outputs
    batch["attention_mask"] = attention_mask
    batch["prompt_input_ids"] = prompt_input_ids
    batch["prompt_attention_mask"] = prompt_attention_mask

    start_time = time.time()
    output_audios = model.generate(
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        **gen_kwargs,
    )
    end_time = time.time()
    print(f"Time taken for batch: {end_time - start_time}")

    for i, audio in enumerate(output_audios):
        print(f"Audio {i+1}")
        ipd.display(
            ipd.Audio(
                audio.cpu(),
                rate=model_args["discrete_audio_feature_sample_rate"],
                normalize=normalize,
            )
        )

else:
    start_time = time.time()
    output_audios = model.generate(
        encoder_outputs=encoder_outputs,
        attention_mask=attention_mask,
        prompt_input_ids=prompt_input_ids,
        prompt_attention_mask=prompt_attention_mask,
        **gen_kwargs,
    )
    end_time = time.time()
    print(f"Time taken for batch: {end_time - start_time}")

    ipd.Audio(
        output_audios[0].cpu(),
        rate=model_args["discrete_audio_feature_sample_rate"],
        normalize=normalize,
    )

Time taken for batch: 4.135527610778809
Audio 1


Audio 2


Audio 3


Audio 4


Audio 5


Audio 6


Audio 7


Audio 8
