In [1]:
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments, AutoTokenizer
import torch
from snac import SNAC
import librosa
import numpy as np
from scipy.io.wavfile import write

# The hub id below only appears in log messages; the tokenizer and the weights
# are both read from the local `model_path` directory.
model_name = "unsloth/orpheus-3b-0.1-pretrained"
model_path = "models/orpheus"
print(f"Using Model: {model_name}")

print("Loading Tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_path)
print("Loaded Tokenizer from File")

# Use the GPU when one is visible to torch, otherwise stay on the CPU.
print("Setting device to cuda")
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device now set to {device}")

# The SNAC codec is left on whatever device `from_pretrained` puts it on.
print("Loading snac")
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
print("Completed loading snac")

# Load the language model in bfloat16, switch it to eval mode and place it on
# the selected device (a no-op when that device is the CPU).
print(f"Loading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.bfloat16)
model.eval()
model = model.to(device)
print("Completed Loading Model")

  from .autonotebook import tqdm as notebook_tqdm


Using Model: unsloth/orpheus-3b-0.1-pretrained
Loading Tokenizer
Loaded Tokenizer from File
Setting device to cuda
Device now set to cuda
Loading snac
Completed loading snac
Loading model: unsloth/orpheus-3b-0.1-pretrained


Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  2.42it/s]


Completed Loading Model


In [2]:
def tokenize_audio(audio_file_path, snac_model):
    """Encode an audio file into a flat list of token ids built from SNAC codes.

    The file is loaded with ``librosa.load`` at 24000 Hz, converted to a float32
    tensor with two leading singleton dimensions and passed to
    ``snac_model.encode``. The first three code tensors are interleaved into
    frames of seven ids; slot ``k`` of every frame is shifted by
    ``128266 + k * 4096``.

    Parameters:
    audio_file_path: Path handed to ``librosa.load``.
    snac_model: Object exposing ``encode(waveform)`` that returns an indexable
        sequence of code tensors.

    Returns:
    list: Seven ids per entry along dimension 1 of the first code tensor.
    """
    samples, _ = librosa.load(audio_file_path, sr=24000)
    batch = torch.from_numpy(samples)[None].to(dtype=torch.float32)[None]

    with torch.inference_mode():
        codes = snac_model.encode(batch)

    level_0, level_1, level_2 = codes[0][0], codes[1][0], codes[2][0]

    # One entry per slot of a frame: (code row, entries consumed per frame, position inside that group).
    slot_layout = (
        (level_0, 1, 0),
        (level_1, 2, 0),
        (level_2, 4, 0),
        (level_2, 4, 1),
        (level_1, 2, 1),
        (level_2, 4, 2),
        (level_2, 4, 3),
    )

    token_ids = []
    for frame in range(codes[0].shape[1]):
        for slot, (row, stride, shift) in enumerate(slot_layout):
            token_ids.append(row[stride * frame + shift].item() + 128266 + slot * 4096)

    return token_ids

In [3]:
def prepare_inputs(
    fpath_audio_ref,
    audio_ref_transcript: str,
    text_prompts: list[str],
    snac_model,
    tokenizer,
    device=None,
):
    """Build a left-padded batch of prompts for zero-shot generation.

    Every row is laid out as
    ``[128259, transcript ids, end ids, reference audio ids, 128258, 128262,
    128259, prompt ids, end ids]`` where ``end ids`` is
    ``[128009, 128260, 128261, 128257]``. Shorter rows are padded on the left
    with 128263 and the padded positions get 0 in the attention mask.

    Parameters:
    fpath_audio_ref: Reference audio path, forwarded to ``tokenize_audio``.
    audio_ref_transcript (str): Transcript of the reference audio.
    text_prompts (list[str]): Texts to synthesise; must not be empty.
    snac_model: Codec forwarded to ``tokenize_audio``.
    tokenizer: Callable returning a mapping with an ``input_ids`` tensor of
        shape (1, n) when called with ``return_tensors="pt"``.
    device: Target device for the returned tensors. Defaults to "cuda" when
        available and "cpu" otherwise, so CPU-only hosts no longer crash.

    Returns:
    tuple: ``(input_ids, attention_mask)``, both int64 with one row per prompt.

    Raises:
    ValueError: If ``text_prompts`` is empty.
    """
    if not text_prompts:
        raise ValueError("text_prompts must contain at least one prompt")
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    audio_tokens = tokenize_audio(fpath_audio_ref, snac_model)

    start_tokens = torch.tensor([[128259]], dtype=torch.int64)
    end_tokens = torch.tensor([[128009, 128260, 128261, 128257]], dtype=torch.int64)
    final_tokens = torch.tensor([[128258, 128262]], dtype=torch.int64)
    pad_token_id = 128263

    # Reference part of the prompt (the original note labels the text section
    # "SOH SOT Text EOT EOH"). It is identical for every prompt, so build it once.
    transcript_ids = tokenizer(audio_ref_transcript, return_tensors="pt")["input_ids"]
    # Explicit dtype: an empty token list would otherwise become a float tensor.
    audio_ids = torch.tensor([audio_tokens], dtype=torch.int64)
    reference_ids = torch.cat(
        [start_tokens, transcript_ids, end_tokens, audio_ids, final_tokens], dim=1
    )

    # Append the text to be spoken to the shared reference prefix.
    prompt_rows = []
    for prompt in text_prompts:
        prompt_ids = tokenizer(prompt, return_tensors="pt")["input_ids"]
        prompt_rows.append(torch.cat([reference_ids, start_tokens, prompt_ids, end_tokens], dim=1))

    max_length = max(row.shape[1] for row in prompt_rows)

    padded_rows = []
    mask_rows = []
    for row in prompt_rows:
        padding = max_length - row.shape[1]
        # Left padding keeps the end of every prompt aligned on the last column.
        padded_rows.append(
            torch.cat([torch.full((1, padding), pad_token_id, dtype=torch.int64), row], dim=1)
        )
        mask_rows.append(
            torch.cat(
                [
                    torch.zeros((1, padding), dtype=torch.int64),
                    torch.ones((1, row.shape[1]), dtype=torch.int64),
                ],
                dim=1,
            )
        )

    input_ids = torch.cat(padded_rows, dim=0).to(device)
    attention_mask = torch.cat(mask_rows, dim=0).to(device)
    return input_ids, attention_mask

In [4]:
# `generate` warned "Setting `pad_token_id` to `eos_token_id`:128258 for open-end
# generation." when no pad id was supplied. The same id is now passed explicitly,
# which keeps the generated ids unchanged and removes the warning.
def inference(model, input_ids, attention_mask, max_new_tokens=990):
    """Sample a continuation for every prompt row with ``model.generate``.

    Parameters:
    model: Object exposing a Hugging Face style ``generate`` method.
    input_ids: Prompt ids, one row per prompt.
    attention_mask: Mask matching ``input_ids``.
    max_new_tokens (int): Upper bound on generated tokens per row (default 990).

    Returns:
    Whatever ``model.generate`` returns for the given inputs.
    """
    stop_token_id = 128258

    with torch.no_grad():
        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.5,
            top_p=0.9,
            repetition_penalty=1.1,
            num_return_sequences=1,
            eos_token_id=stop_token_id,
            # Rows that stop early are filled with the stop id, which
            # convert_tokens_to_speech later strips out.
            pad_token_id=stop_token_id,
        )

    # NOTE(review): the original notebook left appending 128262 to the output as
    # an open question; it is intentionally not done here.
    return generated_ids

In [5]:
def _fallback_audio(snac_model):
    """Produce placeholder audio after decoding the real codes has failed.

    Tries, in order: decoding an all-zero minimal frame, calling
    ``snac_model.get_silence()`` when such an attribute exists, and finally a
    one second 440 Hz sine wave. Returns None if even that fails.
    """
    try:
        # Smallest valid frame: 1, 2 and 4 zero codes for the three layers.
        minimal_codes = [torch.zeros((1, width), dtype=torch.long) for width in (1, 2, 4)]
        return snac_model.decode(minimal_codes)
    except Exception as e2:
        print(f"First fallback failed: {e2}")

    try:
        silence = getattr(snac_model, 'get_silence', lambda: None)()
        if silence is not None:
            return silence

        # Last resort: an audible placeholder tone. Callers cannot tell this
        # apart from real output except through the messages printed above.
        import math
        sample_rate = 24000
        duration = 1  # seconds
        t = torch.arange(0, duration, 1 / sample_rate)
        return torch.sin(2 * math.pi * 440 * t).unsqueeze(0)
    except Exception as e3:
        print(f"All fallbacks failed: {e3}")
        return None


def redistribute_codes(code_list, snac_model, debug=True):
    """
    Split a flat, frame-interleaved code sequence into three layers and decode it.

    Codes are consumed in frames of seven. Slot ``k`` of a frame has
    ``k * 4096`` subtracted and is clamped to ``[0, 4095]``. Slot 0 goes to
    layer 1, slots 1 and 4 to layer 2, and slots 2, 3, 5 and 6 to layer 3.
    A trailing partial frame is padded with zeros and an empty input becomes a
    single all-zero frame.

    Parameters:
    code_list (iterable): Integer-like codes (ints or 0-d tensors). Not modified.
    snac_model: Object exposing ``decode(codes)``.
    debug (bool): Whether to print debug information

    Returns:
    The value returned by ``snac_model.decode``. If decoding raises, a
    best-effort placeholder is returned instead (see ``_fallback_audio``),
    which may be None.
    """
    vocab_size = 4096  # codes per slot
    frame_size = 7
    # Target layer (0-based) for each slot of a frame.
    slot_to_layer = (0, 1, 2, 2, 1, 2, 2)

    # Plain ints keep the tensor construction below cheap even when the caller
    # passes 0-d (possibly GPU) tensors, and leave the caller's list untouched.
    flat_codes = [int(code) for code in code_list]

    remainder = len(flat_codes) % frame_size
    if remainder:
        flat_codes.extend([0] * (frame_size - remainder))

    if not flat_codes:
        if debug:
            print("Warning: Not enough codes. Padding with zeros.")
        flat_codes = [0] * frame_size

    # The length is now a positive multiple of seven, so every slot of every
    # frame exists and no per-index bounds checks are needed.
    layers = ([], [], [])
    for position, code in enumerate(flat_codes):
        slot = position % frame_size
        # Clamping also maps negative (invalid) codes to 0.
        clamped = min(max(code - slot * vocab_size, 0), vocab_size - 1)
        layers[slot_to_layer[slot]].append(clamped)

    if debug:
        for number, layer in enumerate(layers, start=1):
            print(f"Layer {number} length: {len(layer)}, range: {min(layer)}-{max(layer)}")

    codes = [torch.tensor(layer, dtype=torch.long).unsqueeze(0) for layer in layers]

    try:
        # No gradients are needed for decoding; skipping them saves memory.
        with torch.no_grad():
            audio_hat = snac_model.decode(codes)
        if debug:
            print(f"Successfully generated audio with shape: {audio_hat.shape}")
        return audio_hat
    except Exception as e:
        print(f"Error during decoding: {e}")

    return _fallback_audio(snac_model)

In [6]:
def convert_tokens_to_speech(generated_ids, snac_model):
    """Turn generated token ids into decoded audio, one result per batch row.

    For every row: keep only what follows the row's own last 128257 token (the
    whole row when it has none), drop every 128258 token, cut the remainder to
    a multiple of seven, subtract 128266 and hand the resulting list of ints to
    ``redistribute_codes``.

    The crop point is computed per row. Deriving one column for the whole batch
    from the last row-major hit mis-crops rows whose marker sits elsewhere.

    Parameters:
    generated_ids: 2-D integer tensor, one row per prompt.
    snac_model: Codec forwarded to ``redistribute_codes``.

    Returns:
    list: One ``redistribute_codes`` result per row, in row order.
    """
    token_to_find = 128257
    token_to_remove = 128258
    code_offset = 128266
    frame_size = 7

    my_samples = []
    for row in generated_ids:
        # For a 1-D row the hit indices are ascending, so the last one is the
        # row's final marker position.
        marker_positions = (row == token_to_find).nonzero(as_tuple=True)[0]
        if marker_positions.numel() > 0:
            row = row[marker_positions[-1].item() + 1:]

        row = row[row != token_to_remove]

        usable_length = (row.size(0) // frame_size) * frame_size
        # .tolist() yields plain Python ints instead of 0-d tensors.
        code_list = (row[:usable_length] - code_offset).tolist()

        my_samples.append(redistribute_codes(code_list, snac_model))

    return my_samples

In [7]:
def to_wav_from(samples: list) -> list[np.ndarray]:
    """Return every sample as a squeezed NumPy array.

    Tensors are detached, squeezed, moved to the CPU and converted with
    ``.numpy()``; any other value is passed through ``np.squeeze``.
    """
    def _as_array(sample):
        if not isinstance(sample, torch.Tensor):
            # Non-tensor input is treated as array-like already.
            return np.squeeze(sample)
        return sample.detach().squeeze().to('cpu').numpy()

    return [_as_array(sample) for sample in samples]

In [8]:
def zero_shot_tts(fpath_audio_ref, audio_ref_transcript, texts: list[str], model, snac_model, tokenizer):
    """Run the full pipeline: build prompts, generate, decode, convert to arrays.

    Chains ``prepare_inputs`` -> ``inference`` -> ``convert_tokens_to_speech``
    -> ``to_wav_from`` and returns the result of the last step.
    """
    prompt_ids, prompt_mask = prepare_inputs(
        fpath_audio_ref, audio_ref_transcript, texts, snac_model, tokenizer
    )
    generated = inference(model, prompt_ids, prompt_mask)
    decoded = convert_tokens_to_speech(generated, snac_model)
    return to_wav_from(decoded)

In [9]:
def save_wav(samples: list[np.ndarray], sample_rate: int, filenames: list[str]):
    """Save audio samples as float32 .wav files.

    Each sample is first normalised with ``to_wav_from``, so both PyTorch
    tensors and NumPy arrays are accepted.

    Args:
        samples (list): Audio samples (NumPy arrays or PyTorch tensors).
        sample_rate (int): Sample rate in Hz.
        filenames (list[str]): Output paths, paired with ``samples`` by position.
            If the two lists differ in length the surplus entries of the longer
            one are ignored.
    """
    wav_data = to_wav_from(samples)

    for data, filename in zip(wav_data, filenames):
        write(filename, sample_rate, data.astype(np.float32))
        print(f"saved to {filename}")

In [None]:
from pathlib import Path

# Sentences to synthesise with each reference voice.
texts = [
    "He questioned what'll be the result of the experiment.",
    "He would'nt have done it if he knew the consequences.",
    "If we were'nt focused, wed miss the opportunity.",
    "She would'nt even try to resolve the issue.",
    # "They were'nt always this careful with their work.",
    # "They were'nt participating in the survey.",
    # "They've attended the conference.",
    # "They've made their decision.",
    # "We were'nt able to find the solution.",
    # "We'd like to go hiking this weekend.",
    # "We'd worked all day on the project.",
    # "We're concerned about the upcoming deadline.",
]

# (reference audio path, transcript of that audio) pairs used as voice prompts.
prompt_pairs = [
    # ("pre_audio/01.wav", "This British etiquette is often passed down through many generations from family to family to family, which has led to a very strong cultural emphasis on politeness. Politeness even in everyday interactions and being mindful about not causing inconvenience to your fellow human beings."),
    ("pre_audio/02.wav", "This is another phrase that will come in very handy for you all when it comes to sounding polite."),
    # ("pre_audio/03.wav", "A classic example is when you're at the dinner table and you ask someone to pass you the salt.")
]

# The output directory does not depend on the loop, so create it once.
out_dir = Path('out')
out_dir.mkdir(parents=True, exist_ok=True)

for fpath_audio, audio_transcript in prompt_pairs:
    print(f"zero shot: {fpath_audio} {audio_transcript}")
    wav_forms = zero_shot_tts(fpath_audio, audio_transcript, texts, model, snac_model, tokenizer)

    # One file per generated waveform: out/<reference stem>_<index>.wav
    stem = Path(fpath_audio).stem
    file_names = [f"{out_dir.as_posix()}/{stem}_{i}.wav" for i in range(len(wav_forms))]
    save_wav(wav_forms, 24000, file_names)

zero shot: pre_audio/02.wav This is another phrase that will come in very handy for you all when it comes to sounding polite.
Layer 1 length: 81, range: 0-4087
Layer 2 length: 162, range: 0-4085
Layer 3 length: 324, range: 0-4082
Successfully generated audio with shape: torch.Size([1, 1, 165888])
Layer 1 length: 81, range: 48-4091
Layer 2 length: 162, range: 19-4080
Layer 3 length: 324, range: 14-4079
Successfully generated audio with shape: torch.Size([1, 1, 165888])
Layer 1 length: 81, range: 0-4095
Layer 2 length: 162, range: 0-4060
Layer 3 length: 324, range: 0-4065
Successfully generated audio with shape: torch.Size([1, 1, 165888])
saved to out/02_0.wav
saved to out/02_1.wav
saved to out/02_2.wav
