In [39]:
import torch
from parler_tts import ParlerTTSForConditionalGeneration  # type: ignore
from transformers import AutoTokenizer
from IPython.display import Audio, display

In [40]:
def set_seeds(seed: int) -> None:
    """Seed torch's CPU and (when present) CUDA generators and pin cuDNN flags.

    Args:
        seed: Value handed to every torch seeding call made here.
    """
    # cuDNN flags are set regardless of whether a GPU is present.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    # CPU generator.
    torch.manual_seed(seed)

    # Nothing further to seed without a GPU.
    if not torch.cuda.is_available():
        return

    # Current CUDA device first, then every CUDA device.
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    
# Choose the compute device once: use the GPU when torch can see one,
# otherwise warn and fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    print("WARNING: CUDA is not available. Running on CPU.")
    device = "cpu"

In [3]:
# Load model and tokenizer
MODEL_PATH = "parler-tts/parler-tts-mini-v1"
# Load the tokenizer from the same checkpoint as the model, so the token ids
# fed to the model always come from the vocabulary it was published with and
# only one repository has to be downloaded.
TOKENIZER_PATH = MODEL_PATH
model = ParlerTTSForConditionalGeneration.from_pretrained(MODEL_PATH).to(device)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)



In [82]:
# `prompt` is the text to be spoken; `description` is the free-text voice description.
prompt = "Hey, how are you doing today? Can you tell that my voice changes across the seeds?"
description = "Jon's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."

# Tokenise each string into PyTorch tensors and place the token ids on `device`.
input_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
prompt_input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

In [83]:
def multiple_seeded_tts_generations(gen_kwargs: dict, start_seed: int = 42, iterations: int = 3) -> list:
    """Run ``model.generate`` repeatedly, re-seeding torch before each run.

    Run ``i`` (counting from 0) is seeded with ``start_seed + i``, so the
    seeds used are consecutive: ``start_seed, start_seed + 1, ...``.

    Args:
        gen_kwargs: Keyword arguments forwarded unchanged to ``model.generate``.
        start_seed: Seed used for the first generation.
        iterations: Number of generations to produce.

    Returns:
        A list with one entry per iteration: the first item of each
        ``model.generate`` output, moved to the CPU and converted to NumPy.
    """
    generations = []
    for offset in range(iterations):
        # Derive each seed from the fixed base. Accumulating the loop index
        # into `start_seed` would skip seeds (42, 43, 45, 48, ...).
        set_seeds(start_seed + offset)
        output = model.generate(**gen_kwargs)
        generations.append(output[0].cpu().numpy())
    return generations

In [84]:
# Baseline run: description ids and prompt ids only, starting from seed 42.
gen_kwargs = {"input_ids": input_ids, "prompt_input_ids": prompt_input_ids}
synths = multiple_seeded_tts_generations(gen_kwargs, 42)

In [85]:
# Show an inline audio player for every generated clip.
for synth in synths:
    # `display` is the function imported from IPython.display in the imports
    # cell; this notebook defines no `ipd` alias.
    display(Audio(synth, rate=model.config.sampling_rate))

### Continue generation with Enrolment

In [86]:
# Encode one of the generated clips so it can be passed in as enrolment.
idx = 0  # Zero-based index into `synths` of the clip judged best

# Convert the NumPy clip to a tensor and prepend the batch and channel dimensions.
audio = torch.Tensor(synths[idx])[None, None]

# Run the model's audio encoder, keep only the codes, drop the singleton
# dimensions and cast to integers: this is the encoded enrolment data.
encodeds = model.audio_encoder.encode(audio.to(device))["audio_codes"].squeeze().long()
encodeds.shape

torch.Size([9, 392])

In [87]:
# Extend the spoken text: the original prompt followed by one more sentence.
additional_prompt = "Does my voice now match the voice you selected?"
new_prompt = f"{prompt} {additional_prompt}"
prompt_input_ids = tokenizer(new_prompt, return_tensors="pt").input_ids.to(device)

# Generate again, this time supplying the encoded enrolment audio as decoder_input_ids.
gen_kwargs = {
    "input_ids": input_ids,
    "prompt_input_ids": prompt_input_ids,
    "decoder_input_ids": encodeds,
}
# NOTE: a different start seed is used to see whether the voice stays
# consistent without reusing the earlier seeds.
synths = multiple_seeded_tts_generations(gen_kwargs, start_seed=4242)

In [88]:
# Play each continuation with the enrolment audio trimmed off the front.
for synth in synths:
    # Skip as many leading samples as the enrolment clip contained.
    continuation = synth[audio.shape[-1]:]
    # `display` is the function imported from IPython.display in the imports
    # cell; this notebook defines no `ipd` alias.
    display(Audio(continuation, rate=model.config.sampling_rate))

## Comparison against no enrolment

In [89]:
# Tokenise only the additional sentence; no enrolment audio is supplied this time.
prompt_input_ids = tokenizer(additional_prompt, return_tensors="pt").input_ids.to(device)
# Generate using description ids and prompt ids only.
gen_kwargs = {"input_ids": input_ids, "prompt_input_ids": prompt_input_ids}
# Same start seed as the enrolment run.
synths = multiple_seeded_tts_generations(gen_kwargs, start_seed=4242)

In [90]:
# Show an inline audio player for every clip generated without enrolment.
for synth in synths:
    # `display` is the function imported from IPython.display in the imports
    # cell; this notebook defines no `ipd` alias.
    display(Audio(synth, rate=model.config.sampling_rate))