# 😻 Kani TTS - Fast and Expressive Speech Generation Model

[![](https://dcbadge.limes.pink/api/server/https://discord.gg/4fZ4mjD3)](https://discord.gg/4fZ4mjD3)

### Welcome to Kani TTS Fine-Tuning! Here you can adapt our breakthrough neural text-to-speech model to your own speaker, creating a personalized voice with the same speed and quality of generation.

<img src="https://www.nineninesix.ai/kitty.png" width="300">

In [None]:
#@title Installing dependencies

logo = """
===============================================
          N I N E N I N E S I X  😼
===============================================

          /\\_/\\
         ( -.- )───┐
          > ^ <    │
===============================================

"""
print(logo)

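# Quote the version specifiers: unquoted, the shell reads ">" as output redirection
!pip install "transformers==4.54.0" "trl>=0.18.2" "peft>=0.15.2" omegaconf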

import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import trl
import os
from dataclasses import dataclass, field
from typing import Optional, List

os.environ["WANDB_DISABLED"] = "true"

print(f"📦 PyTorch version: {torch.__version__}")
print(f"🤗 Transformers version: {transformers.__version__}")
print(f"📊 TRL version: {trl.__version__}")


          N I N E N I N E S I X  😼

          /\_/\
         ( -.- )───┐
          > ^ <    │


📦 PyTorch version: 2.8.0+cu126
🤗 Transformers version: 4.54.0
📊 TRL version: 0.20.0


# Log in to Hugging Face

In [None]:
!git config --global credential.helper store
!hf auth login

## YOUR BASE MODEL

In [None]:
MODEL_ID = "nineninesix/kani-tts-450m-0.2-pt"

# Nano Dataset

The following section preprocesses your dataset. The dataset must already be stored as audio tokens produced by the NVIDIA Nano codec: four token columns (one per codec layer) alongside the transcription text.



---

‼️ This [repository](https://github.com/nineninesix-ai/nano-codec-dataset-pipeline) will help you prepare your own dataset by tokenizing it using Nano Codec.

---
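As a quick sanity check before configuring anything, the sketch below shows one row in the layout this notebook expects. It is a minimal illustration under stated assumptions: the column names are the defaults used in the config cell, the token values are made up, and `encoded_len` is treated as a count of codec frames at 12.5 frames per second, which is how the duration filter in the processor interprets it. `duration_sec` and `check_row` are hypothetical helpers, not part of the notebook.

```python
# Toy example row (made-up values) using the notebook's default column names.
# Assumption: encoded_len is the number of codec frames, at 12.5 frames per second.
FRAMES_PER_SECOND = 12.5
CODEBOOK_SIZE = 4032  # codes in each layer are expected to lie in [0, 4032)
LAYER_COLUMNS = [f"nano_layer_{k}" for k in range(1, 5)]

row = {
    "text": "Hello world",
    "nano_layer_1": [101, 101, 57, 990],
    "nano_layer_2": [12, 12, 3000, 44],
    "nano_layer_3": [7, 7, 7, 4031],
    "nano_layer_4": [0, 0, 5, 18],
    "encoded_len": 4,
}


def duration_sec(example: dict) -> float:
    """Audio duration implied by the frame count."""
    return example["encoded_len"] / FRAMES_PER_SECOND


def check_row(example: dict) -> list:
    """Return a list of problems found in one row (empty if it looks fine)."""
    problems = []
    for col in LAYER_COLUMNS:
        codes = example[col]
        if len(codes) != example["encoded_len"]:
            problems.append(f"{col}: {len(codes)} frames, expected {example['encoded_len']}")
        if any(not 0 <= c < CODEBOOK_SIZE for c in codes):
            problems.append(f"{col}: code outside [0, {CODEBOOK_SIZE})")
    return problems


print(duration_sec(row))  # 0.32
print(check_row(row))     # []
```

Running `check_row` on a few rows of your own dataset can catch column-name or frame-count mismatches before the full preprocessing run.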



In [None]:
#@title Make Dataset Config

"""
Multi-Speaker TTS Dataset Configuration
========================================

This configuration system enables flexible dataset construction for training Text-to-Speech
models with optional multi-speaker support. It allows you to combine multiple Hugging Face
datasets, where each dataset can represent a different speaker or voice characteristic.

Core Concepts
-------------

1. **Multi-Speaker Training**:
   Each HFDataset can be configured with a unique `speaker_id` that gets prepended to every
   text example before tokenization. For example, if speaker_id="alice", the training sample
   "Hello world" becomes "alice: Hello world". This conditions the model to generate speech
   in that speaker's voice.

2. **Single-Speaker Training**:
   Set `speaker_id=None` for every dataset (the field must then be typed `Optional[str]`;
   simply omitting it keeps the dataclass default). The model will train without speaker
   conditioning, learning a single voice from all combined data.

3. **Dataset Merging**:
   Multiple datasets (HFDataset1, HFDataset2, HFDataset3, etc.) are automatically merged
   into one unified training dataset. This enables combining different speakers, languages,
   or recording conditions in a single model.

Configuration Components
------------------------

**CategoricalFilter** (Optional):
    Filters a dataset to include only specific samples based on a column value.
    Common use case: Extracting one speaker from a multi-speaker dataset.

    Example:
        CategoricalFilter(column_name="speaker", value="ex02")
        # Keeps only samples where speaker=="ex02"

**HFDatasetN** (Dataset Definition):
    Defines a single dataset source with the following key parameters:

    - reponame: Hugging Face dataset repository identifier
    - split: Dataset split to use ("train", "test", "validation")
    - text_col_name: Column containing the text transcriptions
    - nano_layer_1/2/3/4: Names of the columns holding the audio codec token sequences
      (one column per codec layer, 4 layers). If your dataset uses different column names,
      set them here; they are renamed to the standard names when the dataset is loaded.
    - encoded_len: Name of the column holding the number of codec frames per sample
      (the duration filter assumes 12.5 frames per second)
    - speaker_id: (Optional) Voice identity prefix added to all text samples
    - max_len: (Optional) Random sampling limit to prevent data imbalance
    - categorical_filter: (Optional) Filter to select subset of samples

**Config** (Top-Level):
    - max_duration_sec: Global filter - excludes all samples longer than this (in seconds)
    - hf_datasets: List of HFDataset definitions to merge

Key Features
------------

1. **Speaker Identity Control**:
   The `speaker_id` parameter is NOT a column name - it's the actual speaker identifier
   that will be injected into training examples:

   ```python
   # Configuration
   speaker_id = "simon"

   # Original text
   "How are you today?"

   # Becomes (before tokenization)
   "simon: How are you today?"
   ```

2. **Dataset Balancing with max_len**:
   When datasets have vastly different sizes, use `max_len` to randomly sample a fixed
   number of examples. This prevents larger datasets from dominating training.

   ```python
   # Dataset with 50,000 samples - limit to 2,000
   HFDataset2(
       reponame="...",
       speaker_id="puck",
       max_len=2000  # Randomly selects 2000 samples
   )
   ```

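   To include all available samples, set `max_len=None` (with the field typed
   `Optional[int]`) or leave the field out of the dataclass, as HFDataset1 does.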

3. **Speaker Filtering**:
   Use `categorical_filter` to extract a single speaker from multi-speaker datasets:

   ```python
   HFDataset1(
       reponame="multi_speaker_dataset",
       speaker_id="simon",
       categorical_filter=CategoricalFilter(
           column_name="speaker",
           value="ex02"  # Keep only this speaker from source dataset
       )
   )
   ```

4. **Duration Filtering**:
   The global `max_duration_sec` parameter ensures no training sample exceeds a certain
   audio length. This is crucial for:
   - Memory management (long sequences need more GPU RAM)
   - Training stability (very long sequences can cause gradient issues)
   - Consistent batch processing

Workflow
--------

1. **Define Datasets**: Create HFDataset1, HFDataset2, etc. with your sources
2. **Configure Speakers**: Add `speaker_id` for multi-speaker training (or omit for single)
3. **Balance Data**: Use `max_len` if needed to equalize dataset contributions
4. **Filter Samples**: Apply `categorical_filter` to select specific speakers
5. **Set Duration Limit**: Configure `max_duration_sec` for your GPU memory
6. **Build Config**: Collect all datasets in Config.hf_datasets list
7. **Run Merge Script**: The system combines everything into one training dataset

Example Configurations
----------------------

**Multi-Speaker Setup**:
```python
Config(
    max_duration_sec=12,
    hf_datasets=[
        HFDataset1(reponame="...", speaker_id="simon", max_len=2000),
        HFDataset2(reponame="...", speaker_id="puck", max_len=2000),
        HFDataset3(reponame="...", speaker_id="kore", max_len=2000),
    ]
)
# Model learns 3 distinct voices with balanced data
```

**Single-Speaker Setup**:
```python
Config(
    max_duration_sec=15,
    hf_datasets=[
        HFDataset1(reponame="...", speaker_id=None, categorical_filter=CategoricalFilter(...)),
        HFDataset2(reponame="...", speaker_id=None),
        # speaker_id=None everywhere - single voice training
    ]
)
```

**Filtered Multi-Speaker**:
```python
Config(
    max_duration_sec=10,
    hf_datasets=[
        HFDataset1(
            reponame="large_multi_speaker_corpus",
            speaker_id="alice",
            categorical_filter=CategoricalFilter(column_name="speaker", value="spk_001"),
            max_len=3000
        ),
        HFDataset2(
            reponame="small_single_speaker",
            speaker_id="bob",
            max_len=None  # use all samples
        ),
    ]
)
```

Technical Notes
---------------

- All audio must be pre-encoded as codec tokens (nano_layer_1 through nano_layer_4)
- The `encoded_len` column must contain the number of codec frames; duration is computed
  as encoded_len / 12.5 (i.e. 12.5 frames per second)
- Duration filtering is applied to each dataset BEFORE merging, and `max_len` sampling
  is applied after duration filtering
- Speaker IDs are converted to lowercase automatically during preprocessing
- Random sampling (max_len) uses a fixed seed (42) for reproducibility
"""


@dataclass
class CategoricalFilter:
    column_name: str = "speaker"
    value: str = "ex02"


@dataclass
class HFDataset1:
    reponame: str = "nineninesix/expresso-conversational-en-nano-codec-dataset"
    name: Optional[str] = None
    split: str = "train"
    text_col_name: str = "text"
    nano_layer_1: str = "nano_layer_1"
    nano_layer_2: str = "nano_layer_2"
    nano_layer_3: str = "nano_layer_3"
    nano_layer_4: str = "nano_layer_4"
    encoded_len: str = "encoded_len"
    speaker_id: Optional[str] = "simon"  # None = no speaker prefix
    categorical_filter: Optional[CategoricalFilter] = field(default_factory=CategoricalFilter)  # or None to keep all rows

@dataclass
class HFDataset2:
    reponame: str = "nineninesix/puck-gemini-flash-en-nano-codec-dataset"
    name: Optional[str] = None
    split: str = "train"
    text_col_name: str = "text"
    nano_layer_1: str = "nano_layer_1"
    nano_layer_2: str = "nano_layer_2"
    nano_layer_3: str = "nano_layer_3"
    nano_layer_4: str = "nano_layer_4"
    encoded_len: str = "encoded_len"
    speaker_id: Optional[str] = "puck"  # None = no speaker prefix
    max_len: Optional[int] = 2000  # None = use all samples


@dataclass
class HFDataset3:
    reponame: str = "nineninesix/kore-gemini-flash-en-nano-codec-dataset"
    name: Optional[str] = None
    split: str = "train"
    text_col_name: str = "text"
    nano_layer_1: str = "nano_layer_1"
    nano_layer_2: str = "nano_layer_2"
    nano_layer_3: str = "nano_layer_3"
    nano_layer_4: str = "nano_layer_4"
    encoded_len: str = "encoded_len"
    speaker_id: Optional[str] = "kore"  # None = no speaker prefix
    max_len: Optional[int] = 2000  # None = use all samples



@dataclass
class Config:
    max_duration_sec: Optional[int] = 12
    hf_datasets: List = field(default_factory=lambda: [HFDataset1(),
                                                       HFDataset2(),
                                                       HFDataset3()])
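To make the order of operations concrete, here is a plain-Python stand-in for what the processor below does with this config. It is a sketch, not the notebook's implementation: `ToySource` and `build_pool` are hypothetical names, rows are plain dicts, and Python's `random` replaces the `datasets` shuffle/select calls. It follows the processor's order: categorical filter, duration filter, seeded `max_len` subset, lowercased speaker prefix, then a shuffle of the merged pool (the notebook's final shuffle is unseeded; this one is seeded so the result is repeatable).

```python
import random
from dataclasses import dataclass
from typing import List, Optional

FRAMES_PER_SECOND = 12.5  # assumed codec frame rate, as in the duration filter


@dataclass
class ToySource:
    """Hypothetical stand-in for one HFDataset entry; rows are plain dicts."""
    rows: List[dict]
    speaker_id: Optional[str] = None
    max_len: Optional[int] = None
    filter_column: Optional[str] = None
    filter_value: Optional[str] = None


def build_pool(sources: List[ToySource], max_duration_sec: Optional[float], seed: int = 42) -> List[str]:
    pool = []
    for src in sources:
        rows = src.rows
        # 1. categorical_filter: keep one speaker from a multi-speaker source
        if src.filter_column is not None:
            rows = [r for r in rows if r[src.filter_column] == src.filter_value]
        # 2. per-dataset duration filter
        if max_duration_sec:
            rows = [r for r in rows if r["encoded_len"] / FRAMES_PER_SECOND <= max_duration_sec]
        # 3. max_len balancing: seeded random subset, capped at what is available
        if src.max_len is not None:
            rows = random.Random(seed).sample(rows, min(src.max_len, len(rows)))
        # 4. speaker prefix (lowercased), or raw text for single-speaker training
        for r in rows:
            if src.speaker_id is not None:
                pool.append(f"{src.speaker_id.lower()}: {r['text']}")
            else:
                pool.append(r["text"])
    # 5. merge: shuffle the combined pool
    random.Random(seed).shuffle(pool)
    return pool


# Usage: a small filtered source plus a large source capped at 20 samples
alice = ToySource(
    rows=[{"text": f"a{i}", "encoded_len": 100, "speaker": "spk_001" if i % 2 else "spk_002"}
          for i in range(10)],
    speaker_id="Alice",
    filter_column="speaker",
    filter_value="spk_001",
)
bob = ToySource(
    rows=[{"text": f"b{i}", "encoded_len": 50 if i % 2 == 0 else 250} for i in range(100)],
    speaker_id="bob",
    max_len=20,
)
pool = build_pool([alice, bob], max_duration_sec=12)
print(len(pool))  # 25
```

Here the 250-frame rows (20 s) are dropped before the 20-sample cap is applied, so `max_len` only ever draws from clips that pass the duration limit.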


In [None]:
#@title Dataset Processor


"""
He who dares peer within shall renounce all understanding, yet in its stead shall find faith!
"""

from concurrent.futures import ProcessPoolExecutor, as_completed
import locale
import multiprocessing as mp
from typing import Optional

import numpy as np
from datasets import Dataset, load_dataset, concatenate_datasets
from omegaconf import OmegaConf
from transformers import AutoTokenizer


class TrainDataPreProcessor:
    def __init__(self, tokenizer_name: str, max_dur: int, speaker_id: Optional[str] = None) -> None:
        self.text_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.max_dur = max_dur
        self.speaker_id = speaker_id
        # Workaround for non-UTF-8 locale errors in some Colab runtimes
        locale.getpreferredencoding = lambda: "UTF-8"

        self.tokeniser_length = 64400
        self.start_of_text = 1
        self.end_of_text = 2
        self.start_of_speech = self.tokeniser_length + 1
        self.end_of_speech = self.tokeniser_length + 2
        self.start_of_human = self.tokeniser_length + 3
        self.end_of_human = self.tokeniser_length + 4
        self.start_of_ai = self.tokeniser_length + 5
        self.end_of_ai = self.tokeniser_length + 6
        self.pad_token = self.tokeniser_length + 7
        self.audio_tokens_start = self.tokeniser_length + 10
        self.codebook_size = 4032

    def add_codes(self, example) -> dict:
        nano_layers = ['nano_layer_1', 'nano_layer_2', 'nano_layer_3', 'nano_layer_4']
        # Stack the four layers into a [num_frames, 4] array
        codes = np.array([example[layer] for layer in nano_layers]).T
        # Offset layer i by i * codebook_size so every layer gets its own id range
        all_codes = codes + np.array([self.codebook_size * i for i in range(4)])

        # Drop frames that exactly repeat the previous frame
        all_codes = self.remove_consecutive_duplicates_np(all_codes)

        # Shift into the audio-token region of the vocabulary and flatten frame by frame
        all_codes = all_codes + self.audio_tokens_start
        example["codes_list"] = all_codes.flatten().tolist()
        return example

    def remove_consecutive_duplicates_np(self, arr: np.ndarray)->np.ndarray:
        if arr.ndim != 2:
            raise ValueError("2D array expected [num_frames, frame_size]")

        mask = np.any(arr[1:] != arr[:-1], axis=1)
        keep = np.insert(mask, 0, True)
        return arr[keep]


    def create_input_ids(self, example):
        if self.speaker_id is not None:
            text_prompt = f"{self.speaker_id.lower()}: {example['text']}"
        else:
            text_prompt = example["text"]

        text_ids = self.text_tokenizer.encode(text_prompt, add_special_tokens=True)
        text_ids.append(self.end_of_text)

        example["text_tokens"] = text_ids
        input_ids = (
            [self.start_of_human]
            + example["text_tokens"]
            + [self.end_of_human]
            + [self.start_of_ai]
            + [self.start_of_speech]
            + example["codes_list"]
            + [self.end_of_speech]
            + [self.end_of_ai]
        )
        example["input_ids"] = input_ids
        example["labels"] = input_ids
        example["attention_mask"] = [1] * len(input_ids)
        return example

    def __call__(self, dataset: Dataset) -> Dataset:
        print(f'🔄 SHARD PROCESSING: Processing shard with {len(dataset)} samples...')

        if self.max_dur:
            print(f'📊 FILTER: max duration is -- {self.max_dur} sec --')
            dataset_len = len(dataset)
            # encoded_len counts codec frames; dividing by 12.5 (frames/s) gives seconds
            dataset = dataset.filter(lambda i: i['encoded_len'] / 12.5 <= self.max_dur)
            filtered_len = len(dataset)
            print(f'✅ COMPLETE {filtered_len} rows from {dataset_len}')

        dataset = dataset.map(  self.add_codes,
                                remove_columns=['nano_layer_1', 'nano_layer_2', 'nano_layer_3', 'nano_layer_4'],
                                desc='Add Audio Codes: ')
        dataset = dataset.filter(lambda x: x["codes_list"] is not None, desc='Check codes list')
        dataset = dataset.filter(lambda x: len(x["codes_list"]) > 0, desc='Check codes list length')
        dataset = dataset.map(self.create_input_ids, remove_columns=["text", "codes_list"],
                                desc='Create input ids: ')

        columns_to_keep = ["input_ids", "labels", "attention_mask"]
        columns_to_remove = [col for col in dataset.column_names if col not in columns_to_keep]
        dataset = dataset.remove_columns(columns_to_remove)

        print(f'✅ SHARD PROCESSING: Completed shard with {len(dataset)} samples')
        return dataset


def process_shard(shard_idx, shard_data, tokenizer_name, max_dur, speaker_id):
    print(f'🚀 WORKER {shard_idx}: Starting processing...')
    processor = TrainDataPreProcessor(tokenizer_name, max_dur, speaker_id)
    processed_shard = processor(shard_data)
    print(f'✅ WORKER {shard_idx}: Completed processing')
    return processed_shard


class ItemDataset:
    def __init__(self, item_cfg: OmegaConf, tokenizer_name: str, max_dur: int, n_shards: int = None):
        print(f'📦 DATASET: Loading dataset "{item_cfg.name}" from {item_cfg.reponame}...')
        self.item_cfg = item_cfg
        self.tokenizer_name = tokenizer_name
        self.max_dur = max_dur
        self.speaker_id = self.item_cfg.get('speaker_id')
        self.max_len = self.item_cfg.get('max_len')

        if n_shards is None:
            self.n_shards = min(mp.cpu_count(), 8)
        else:
            self.n_shards = n_shards

        self.dataset = load_dataset(
            self.item_cfg.reponame,
            self.item_cfg.name,
            split=self.item_cfg.split,
            num_proc=10
            )

        print(f'📊 DATASET: Loaded {len(self.dataset)} samples from {item_cfg.name}')
        print(f'🔧 DATASET: Will process with {self.n_shards} shards')

        cat_filter = self.item_cfg.get('categorical_filter')
        if cat_filter:
            # Copy to plain values so the filter function does not capture the whole config
            column, value = cat_filter.column_name, cat_filter.value
            print(f'🔧 DATASET: Filtering by {column} = {value}')
            self.dataset = self.dataset.filter(lambda x: x[column] == value)
            print(f'✅ DATASET: Filtered {len(self.dataset)} samples')

        print(f'🔄 DATASET: Renaming columns...')
        rename_dict = {
            self.item_cfg.text_col_name: 'text',
            self.item_cfg.nano_layer_1: 'nano_layer_1',
            self.item_cfg.nano_layer_2: 'nano_layer_2',
            self.item_cfg.nano_layer_3: 'nano_layer_3',
            self.item_cfg.nano_layer_4: 'nano_layer_4',
            self.item_cfg.encoded_len: 'encoded_len',
        }
        self.dataset = self.dataset.rename_columns(rename_dict)
        print(f'✅ DATASET: Column renaming completed for {item_cfg.name}')


    def __call__(self):
        print(f'🔄 DATASET: Starting parallel processing of {self.item_cfg.name}...')

        shards = []
        for i in range(self.n_shards):
            shard = self.dataset.shard(num_shards=self.n_shards, index=i)
            shards.append((shard, i))
            print(f'📦 SHARD {i}: Created with {len(shard)} samples')

        processed_shards = []

        with ProcessPoolExecutor(max_workers=self.n_shards) as executor:

            future_to_shard = {
                executor.submit(process_shard, shard_idx, shard, self.tokenizer_name, self.max_dur, self.speaker_id): shard_idx
                for shard, shard_idx in shards
            }

            for future in as_completed(future_to_shard):
                shard_idx = future_to_shard[future]
                try:
                    processed_shard = future.result()
                    processed_shards.append((shard_idx, processed_shard))
                    print(f'✅ COMPLETED: Shard {shard_idx} processing finished')
                except Exception as exc:
                    print(f'❌ ERROR: Shard {shard_idx} generated an exception: {exc}')
                    raise exc

        processed_shards.sort(key=lambda x: x[0])
        final_shards = [shard for _, shard in processed_shards]

        print(f'🔗 DATASET: Concatenating {len(final_shards)} processed shards...')
        final_dataset = concatenate_datasets(final_shards)
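        if self.max_len is not None:
            # Seeded random subset, capped at the available size so small datasets don't fail
            n_samples = min(self.max_len, len(final_dataset))
            final_dataset = final_dataset.shuffle(seed=42).select(range(n_samples))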
        print(f'✅ DATASET: {self.item_cfg.name} processing completed! Final size: {len(final_dataset)} samples')

        return final_dataset


class DatasetProcessor:
    def __init__(self, dataset_config, tokenizer_name: str, n_shards_per_dataset: int = None):
        print(f'🚀 INIT: Initializing DatasetProcessor...')
        self.cfg = OmegaConf.structured(dataset_config)
        self.tokenizer_name = tokenizer_name
        self.n_shards_per_dataset = n_shards_per_dataset
        print(f'✅ INIT: DatasetProcessor initialized with {len(self.cfg.hf_datasets)} datasets to process')
        if n_shards_per_dataset:
            print(f'🔧 INIT: Each dataset will be processed with {n_shards_per_dataset} shards')

    def __call__(self):
        print(f'🔄 MASTER: Starting master dataset processing...')
        datasets = []

        for i, item_cfg in enumerate(self.cfg.hf_datasets, 1):
            print(f'📦 MASTER: Processing dataset {i}/{len(self.cfg.hf_datasets)}: {item_cfg.name}')
            item_ds_maker = ItemDataset(
                item_cfg=item_cfg,
                tokenizer_name=self.tokenizer_name,
                max_dur = self.cfg.max_duration_sec,
                n_shards=self.n_shards_per_dataset
            )
            datasets.append(item_ds_maker())

        print(f'🔗 MASTER: Concatenating all datasets...')
        final_dataset = concatenate_datasets(datasets)
        final_dataset = final_dataset.shuffle()
        print(f'🎉 MASTER: All datasets processed and concatenated! Final dataset size: {len(final_dataset)} samples')
        return final_dataset
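The token layout built by `add_codes` and `create_input_ids` is easier to follow without NumPy, so here is a dependency-free sketch of the same arithmetic, reusing the constants hard-coded in `TrainDataPreProcessor`. `frames_to_audio_tokens` and `build_input_ids` are hypothetical helper names for illustration; the `text_ids` argument stands in for the tokenizer output (which already includes its own special tokens), and this sketch does not run a tokenizer.

```python
# Constants copied from TrainDataPreProcessor
TOKENISER_LENGTH = 64400
END_OF_TEXT = 2
START_OF_SPEECH = TOKENISER_LENGTH + 1
END_OF_SPEECH = TOKENISER_LENGTH + 2
START_OF_HUMAN = TOKENISER_LENGTH + 3
END_OF_HUMAN = TOKENISER_LENGTH + 4
START_OF_AI = TOKENISER_LENGTH + 5
END_OF_AI = TOKENISER_LENGTH + 6
AUDIO_TOKENS_START = TOKENISER_LENGTH + 10  # first audio token id: 64410
CODEBOOK_SIZE = 4032


def frames_to_audio_tokens(layers):
    """layers: four equal-length lists of codec indices, one per Nano layer."""
    frames = list(zip(*layers))  # [num_frames, 4], like np.array(codes).T
    # Drop a frame only when all four codes repeat the previous frame
    kept = [f for i, f in enumerate(frames) if i == 0 or f != frames[i - 1]]
    tokens = []
    for frame in kept:
        for k, code in enumerate(frame):
            # layer k occupies ids [64410 + 4032*k, 64410 + 4032*(k+1))
            tokens.append(AUDIO_TOKENS_START + k * CODEBOOK_SIZE + code)
    return tokens


def build_input_ids(text_ids, audio_tokens):
    """text_ids: tokenizer output for the (optionally speaker-prefixed) prompt."""
    return (
        [START_OF_HUMAN] + text_ids + [END_OF_TEXT, END_OF_HUMAN]
        + [START_OF_AI, START_OF_SPEECH] + audio_tokens + [END_OF_SPEECH, END_OF_AI]
    )


layers = [[0, 0, 5], [1, 1, 6], [2, 2, 7], [3, 3, 8]]
audio = frames_to_audio_tokens(layers)
print(audio)  # [64410, 68443, 72476, 76509, 64415, 68448, 72481, 76514]
print(build_input_ids([1, 500], audio))
```

The second input frame is dropped because it repeats the first in all four layers; a frame that differs in even one layer is kept.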



In [None]:
dataset_config = Config()
dataset_ = DatasetProcessor(dataset_config, MODEL_ID, n_shards_per_dataset=4)
train_dataset = dataset_()

In [None]:
train_dataset

Dataset({
    features: ['input_ids', 'labels', 'attention_mask'],
    num_rows: 10604
})
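To spot-check a processed row, you can invert that layout: locate the speech markers and subtract each layer's offset. This sketch assumes the special-token ids and offsets hard-coded in `TrainDataPreProcessor`, and that text token ids stay below 64400 so the markers are unambiguous; `split_sample` is a hypothetical helper, not part of the notebook. Because repeated frames were dropped during preprocessing, the recovered layers can be shorter than the original columns.

```python
# Special-token ids and offsets as hard-coded in TrainDataPreProcessor
TOKENISER_LENGTH = 64400
START_OF_SPEECH = TOKENISER_LENGTH + 1
END_OF_SPEECH = TOKENISER_LENGTH + 2
AUDIO_TOKENS_START = TOKENISER_LENGTH + 10
CODEBOOK_SIZE = 4032
NUM_LAYERS = 4


def split_sample(input_ids):
    """Split one processed row into (prompt_ids, per-layer codec indices)."""
    start = input_ids.index(START_OF_SPEECH)
    end = input_ids.index(END_OF_SPEECH, start)
    prompt_ids = input_ids[:start]
    audio = input_ids[start + 1:end]
    if len(audio) % NUM_LAYERS:
        raise ValueError("audio span is not a whole number of 4-token frames")
    layers = [[] for _ in range(NUM_LAYERS)]
    for pos, tok in enumerate(audio):
        k = pos % NUM_LAYERS  # tokens are interleaved frame by frame
        code = tok - AUDIO_TOKENS_START - k * CODEBOOK_SIZE
        if not 0 <= code < CODEBOOK_SIZE:
            raise ValueError(f"token {tok} is outside the id range of layer {k + 1}")
        layers[k].append(code)
    return prompt_ids, layers


# Synthetic row: human/text/ai markers, one audio frame, closing markers
sample = [64403, 1, 500, 2, 64404, 64405, 64401, 64410, 68443, 72476, 76509, 64402, 64406]
prompt_ids, layers = split_sample(sample)
print(layers)  # [[0], [1], [2], [3]]
```

In the notebook you could apply it to a real row with `split_sample(train_dataset[0]["input_ids"])`.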

# Model Fine-tuning

In [None]:
print("📚 Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

print("🧠 Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype="bfloat16",
)

In [None]:
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=['q_proj', 'k_proj', 'v_proj', 'out_proj', 'w1', 'w2', 'w3', 'in_proj'],
    bias="none",
    modules_to_save=None,
    use_rslora=True
)

lora_model = get_peft_model(model, lora_config)
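Because `use_rslora=True`, PEFT scales the LoRA update by `lora_alpha / sqrt(r)` instead of `lora_alpha / r`. With `r=16` and `lora_alpha=16` that is a factor of 4 rather than 1, which is worth keeping in mind when tuning `learning_rate` or changing `r`. The arithmetic as a small sketch (`lora_scaling` is a hypothetical helper):

```python
import math


def lora_scaling(r: int, lora_alpha: float, use_rslora: bool) -> float:
    """Factor applied to the low-rank update B @ A."""
    return lora_alpha / math.sqrt(r) if use_rslora else lora_alpha / r


print(lora_scaling(16, 16, use_rslora=True))   # 4.0
print(lora_scaling(16, 16, use_rslora=False))  # 1.0
print(lora_scaling(64, 16, use_rslora=True))   # 2.0
```

With rsLoRA the scale shrinks with the square root of the rank rather than linearly, so raising `r` weakens the update less than it would under standard LoRA.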

In [None]:
lora_sft_config = SFTConfig(
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.1,
    weight_decay=0.02,
    optim="adamw_torch",

    overwrite_output_dir=True,
    output_dir="./checkpoints",
    save_strategy="no",
    remove_unused_columns=True,
)
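A rough estimate of what this config implies, assuming a single GPU and the 10,604 training rows reported above. `estimate_schedule` is a hypothetical helper; the Trainer's exact step count can differ with multiple devices or dataloader settings, while the warmup length is taken from `warmup_ratio` by rounding up its product with the total step count.

```python
import math


def estimate_schedule(num_rows, per_device_batch, grad_accum, num_devices=1,
                      epochs=1, warmup_ratio=0.0):
    """Rough optimizer-step count for a Trainer-style loop (hypothetical helper)."""
    effective_batch = per_device_batch * grad_accum * num_devices
    batches_per_epoch = math.ceil(num_rows / (per_device_batch * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    total_steps = steps_per_epoch * epochs
    warmup_steps = math.ceil(total_steps * warmup_ratio)
    return effective_batch, total_steps, warmup_steps


print(estimate_schedule(10604, per_device_batch=1, grad_accum=4, warmup_ratio=0.1))
# (4, 2651, 266)
```

So each optimizer step sees 4 samples, and the cosine schedule spends roughly the first 266 of about 2,651 steps warming up.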

In [None]:
print("🏗️  Creating LoRA SFT trainer...")
lora_sft_trainer = SFTTrainer(
    model=lora_model,
    args=lora_sft_config,
    train_dataset=train_dataset,
)

print("\n🚀 Starting LoRA + SFT training...")
lora_sft_trainer.train()

print("🎉 LoRA + SFT training completed!")

# Save merged model

Merge the LoRA adapter weights into the base model to obtain a standalone checkpoint that loads with `AutoModelForCausalLM.from_pretrained` and no longer requires PEFT.

In [None]:
print("\n🔄 Merging LoRA weights...")
merged_model = lora_model.merge_and_unload()
merged_model.save_pretrained("./checkpoints/lora_kani_model_ft_exp")
tokenizer.save_pretrained("./checkpoints/lora_kani_model_ft_exp")
print("💾 Merged model saved!")


🔄 Merging LoRA weights...
💾 Merged model saved!
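`merge_and_unload()` folds each adapted layer's scaled low-rank update into its base weight, W' = W + s·(B A), so the merged model produces the same outputs as base-plus-adapter in eval mode (dropout off). A toy check of that identity with a 2×2 weight and rank 1, in plain Python; all values are made up, and `matmul`/`add` are small illustrative helpers.

```python
def matmul(X, Y):
    """Plain-Python matrix product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]


def add(X, Y, s=1.0):
    """Elementwise X + s * Y."""
    return [[x + s * y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


W = [[1.0, 2.0], [3.0, 4.0]]  # frozen base weight (out x in)
B = [[0.5], [-1.0]]           # LoRA "up" matrix (out x r), r = 1
A = [[2.0, 1.0]]              # LoRA "down" matrix (r x in)
scaling = 4.0                 # e.g. lora_alpha / sqrt(r) with rsLoRA
x = [[1.0], [3.0]]            # input column vector

# Unmerged: base path plus scaled low-rank path
y_adapter = add(matmul(W, x), matmul(B, matmul(A, x)), scaling)
# Merged: fold the update into one dense matrix
W_merged = add(W, matmul(B, A), scaling)
y_merged = matmul(W_merged, x)
print(y_adapter == y_merged, y_merged)  # True [[17.0], [-5.0]]
```

Merging trades the ability to swap adapters for a plain checkpoint with no extra matrix multiplications at inference time.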


## Publish to the Hub

After training, you can upload the merged checkpoint to the Hugging Face Hub:
```bash
huggingface-cli upload <namespace/repo> ./checkpoints/lora_kani_model_ft_exp --private
```

# Inference

If you’d like to try out the model you’ve just fine-tuned, simply open the inference Colab using the [link](https://colab.research.google.com/drive/1mvzGs7jtAMSUz8wvNlL5uFmgFEyAPjDh?usp=sharing).

