# 🤙 Sesame-CSM (1B) on NVIDIA Brev

<div style="background: linear-gradient(90deg, #00ff87 0%, #60efff 100%); padding: 1px; border-radius: 8px; margin: 20px 0;">
    <div style="background: #0a0a0a; padding: 20px; border-radius: 7px;">
        <p style="color: #60efff; margin: 0;"><strong>⚡ Powered by Brev</strong> | Converted from <a href="https://github.com/unslothai/notebooks/blob/main/nb/Sesame-CSM_(1B).ipynb" style="color: #00ff87;">Unsloth Notebook</a></p>
    </div>
</div>

## 📋 Configuration

<table style="width: auto; margin-left: 0; border-collapse: collapse; border: 2px solid #808080;">
    <thead>
        <tr style="border-bottom: 2px solid #808080;">
            <th style="text-align: left; padding: 8px 12px; border-right: 2px solid #808080; font-weight: bold;">Parameter</th>
            <th style="text-align: left; padding: 8px 12px; font-weight: bold;">Value</th>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Model</strong></td>
            <td style="text-align: left; padding: 8px 12px;">Sesame-CSM (1B)</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Recommended GPU</strong></td>
            <td style="text-align: left; padding: 8px 12px;">T4</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Min VRAM</strong></td>
            <td style="text-align: left; padding: 8px 12px;">12 GB</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Batch Size</strong></td>
            <td style="text-align: left; padding: 8px 12px;">4</td>
        </tr>
        <tr>
            <td style="text-align: left; padding: 8px 12px; border-right: 1px solid #808080;"><strong>Categories</strong></td>
            <td style="text-align: left; padding: 8px 12px;">audio, text-to-speech, fine-tuning</td>
        </tr>
    </tbody>
</table>

## 🔧 Key Adaptations for Brev

- ✅ Replaced Colab-specific installation with conda-based Unsloth
- ✅ Converted magic commands to subprocess calls
- ✅ Removed Google Drive dependencies
- ✅ Updated Colab file paths to use `/workspace/`
- ✅ Added `device_map="auto"` for multi-GPU support
- ✅ Optimized batch sizes for NVIDIA GPUs
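
The magic-command conversion listed above can be sketched roughly as follows. This is a minimal illustration and not the converter's actual code; `pip_install_argv` and `pip_install` are hypothetical helper names. A Colab line such as `!pip install unsloth` becomes a `subprocess` call against the current interpreter, so packages land in the environment the kernel is running in.

```python
import subprocess
import sys


def pip_install_argv(*packages, no_deps=False, quiet=True):
    """Build the argv list equivalent to a `!pip install ...` magic command."""
    argv = [sys.executable, "-m", "pip", "install"]
    if quiet:
        argv.append("-q")
    if no_deps:
        argv.append("--no-deps")
    argv.extend(packages)
    return argv


def pip_install(*packages, **options):
    """Run pip in a subprocess; raises CalledProcessError on failure."""
    subprocess.check_call(pip_install_argv(*packages, **options))


# Colab original: !pip install -q --no-deps trl==0.22.2
print(pip_install_argv("trl==0.22.2", no_deps=True)[1:])
# ['-m', 'pip', 'install', '-q', '--no-deps', 'trl==0.22.2']
```

Building the argument list separately from running it keeps the command easy to inspect before anything is installed.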

## 📚 Resources

- [Unsloth Documentation](https://docs.unsloth.ai/)
- [Brev Documentation](https://docs.nvidia.com/brev)
- [Original Notebook](https://github.com/unslothai/notebooks/blob/main/nb/Sesame-CSM_(1B).ipynb)



<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your local device, follow [our guide](https://docs.unsloth.ai/get-started/install-and-update). This notebook is licensed under [LGPL-3.0](https://github.com/unslothai/notebooks?tab=LGPL-3.0-1-ov-file#readme).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News


Unsloth's [Docker image](https://hub.docker.com/r/unsloth/unsloth) is here! Start training with no setup & environment issues. [Read our Guide](https://docs.unsloth.ai/new/how-to-train-llms-with-unsloth-and-docker).

[gpt-oss RL](https://docs.unsloth.ai/new/gpt-oss-reinforcement-learning) is now supported with the fastest inference & lowest VRAM. Try our [new notebook](https://github.com/unslothai/notebooks/blob/main/nb/gpt-oss-(20B)-GRPO.ipynb) which creates kernels!

Introducing [Vision](https://docs.unsloth.ai/new/vision-reinforcement-learning-vlm-rl) and [Standby](https://docs.unsloth.ai/basics/memory-efficient-rl) for RL! Train Qwen, Gemma etc. VLMs with GSPO - even faster with less VRAM.

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [None]:
# Environment Check for Brev
import sys
import os
import shutil

print(f"Python executable: {sys.executable}")
print(f"Python version: {sys.version}")

# Configure PyTorch cache directories to avoid permission errors
# MUST be set before any torch imports
# Prefer /ephemeral for Brev instances (larger scratch space)

# Test if /ephemeral exists and is actually writable (not just readable)
use_ephemeral = False
if os.path.exists("/ephemeral"):
    try:
        test_file = "/ephemeral/.write_test"
        with open(test_file, "w") as f:
            f.write("test")
        os.remove(test_file)
        use_ephemeral = True
    except (PermissionError, OSError):
        pass

if use_ephemeral:
    cache_base = "/ephemeral/torch_cache"
    triton_cache = "/ephemeral/triton_cache"
    tmpdir = "/ephemeral/tmp"
    print("Using /ephemeral for cache (Brev scratch space)")
else:
    cache_base = os.path.expanduser("~/.cache/torch/inductor")
    triton_cache = os.path.expanduser("~/.cache/triton")
    tmpdir = os.path.expanduser("~/.cache/tmp")
    print("Using home directory for cache")

# Set ALL PyTorch/Triton cache and temp directories
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_base
os.environ["TORCH_COMPILE_DIR"] = cache_base
os.environ["TRITON_CACHE_DIR"] = triton_cache
os.environ["XDG_CACHE_HOME"] = os.path.expanduser("~/.cache")
os.environ["TMPDIR"] = tmpdir  # Override system /tmp
os.environ["TEMP"] = tmpdir
os.environ["TMP"] = tmpdir

# Create cache directories with proper permissions (777 to ensure writability)
for cache_dir in [cache_base, triton_cache, tmpdir, os.environ["XDG_CACHE_HOME"]]:
    os.makedirs(cache_dir, mode=0o777, exist_ok=True)

# Clean up any old compiled caches that point to /tmp
old_cache = os.path.join(os.getcwd(), "unsloth_compiled_cache")
if os.path.exists(old_cache):
    print(f"⚠️  Removing old compiled cache: {old_cache}")
    shutil.rmtree(old_cache, ignore_errors=True)

print(f"✅ PyTorch cache: {cache_base}")

try:
    from unsloth import FastLanguageModel
    import transformers
    print("\n✅ Unsloth already available")
    print(f"   Unsloth: {FastLanguageModel.__module__}")
    print(f"   Transformers: {transformers.__version__}")
    
    # Check that the installed transformers matches the pinned version.
    # transformers.__version__ avoids the deprecated pkg_resources API.
    if transformers.__version__ != "4.56.2":
        print(f"   ⚠️  Transformers {transformers.__version__} != 4.56.2, may need adjustment")
    
    print("   ✅ All packages OK, skipping installation")
except ImportError:
    print("\n⚠️  Unsloth not found - installing required packages...")
    import subprocess
    
    # Find uv in common locations
    uv_paths = [
        "uv",  # In PATH
        os.path.expanduser("~/.venv/bin/uv"),
        os.path.expanduser("~/.cargo/bin/uv"),
        "/usr/local/bin/uv"
    ]
    
    uv_cmd = None
    for path in uv_paths:
        try:
            result = subprocess.run([path, "--version"], capture_output=True, timeout=2)
            if result.returncode == 0:
                uv_cmd = path
                print(f"   Found uv at: {path}")
                break
        except (FileNotFoundError, subprocess.TimeoutExpired):
            continue
    
    print(f"\nInstalling packages into: {sys.executable}")
    
    if uv_cmd:
        print("Using uv package manager...\n")
        try:
            subprocess.check_call([uv_cmd, "pip", "install", "unsloth"])
            subprocess.check_call([uv_cmd, "pip", "install", "transformers==4.56.2"])
            subprocess.check_call([uv_cmd, "pip", "install", "--no-deps", "trl==0.22.2"])
            print("\n✅ Installation complete")
        except subprocess.CalledProcessError as e:
            print(f"⚠️  uv install failed: {e}")
            uv_cmd = None  # Fall back to pip
    
    if not uv_cmd:
        print("Using pip package manager...\n")
        try:
            # Ensure pip is available
            subprocess.run([sys.executable, "-m", "ensurepip", "--upgrade"], 
                         capture_output=True, timeout=30)
            # Install packages
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "unsloth"])
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "transformers==4.56.2"])
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "--no-deps", "trl==0.22.2"])
            print("\n✅ Installation complete")
        except subprocess.CalledProcessError as e:
            print(f"❌ Installation failed: {e}")
            print("   This may be due to permission issues.")
            print("   Packages may already be installed - attempting to continue...")
    
    # Verify installation
    try:
        from unsloth import FastLanguageModel
        print("✅ Unsloth is now available")
    except ImportError as e:
        print(f"❌ Unsloth still not available: {e}")
        print("⚠️  Please check setup script ran successfully or restart instance")

### Unsloth

`FastModel` supports loading nearly any model now! This includes Vision and Text models!

In [None]:
from unsloth import FastModel
from transformers import CsmForConditionalGeneration
import torch

model, processor = FastModel.from_pretrained(
    model_name = "unsloth/csm-1b",
    max_seq_length = 2048, # Choose any for long context!
    dtype = None, # Leave as None for auto-detection
    auto_model = CsmForConditionalGeneration,
    load_in_4bit = False, # Select True for 4bit - reduces memory usage
)

We now add LoRA adapters so we only need to update 1 to 10% of all parameters!

In [None]:
model = FastModel.get_peft_model(
    model,
    r = 32, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 32,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)
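
To see why LoRA trains only a small share of the weights, consider one linear layer. A rank-`r` adapter adds two matrices with `r * (d_in + d_out)` parameters next to the frozen `d_in * d_out` weight. The sketch below works through that arithmetic; the 2048 x 2048 layer size is a made-up example, not CSM's actual dimensions.

```python
def lora_param_count(d_in, d_out, r):
    """Trainable parameters a rank-r LoRA adapter adds to one linear layer."""
    # A is (r x d_in) and B is (d_out x r).
    return r * (d_in + d_out)


def lora_fraction(d_in, d_out, r):
    """Adapter parameters relative to the frozen weight matrix."""
    return lora_param_count(d_in, d_out, r) / (d_in * d_out)


# Hypothetical 2048 x 2048 projection with r = 32, as in the cell above
print(lora_param_count(2048, 2048, 32))        # 131072
print(f"{lora_fraction(2048, 2048, 32):.1%}")  # 3.1%
```

Doubling `r` doubles the adapter size, so the rank is the main knob for trading capacity against memory.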

<a name="Data"></a>
### Data Prep  

We will use the `MrDragonFox/Elise` dataset, which is designed for training TTS models. Ensure that your dataset follows the required format: **text, audio** for single-speaker models or **source, text, audio** for multi-speaker models. You can modify this section to accommodate your own dataset, but it must keep this column structure for preprocessing to work.
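
Before swapping in your own dataset, a quick column check can catch format problems early. The helper below is a minimal sketch; `missing_columns` is a hypothetical name, and it checks only the column names described above, not the audio contents.

```python
def missing_columns(column_names, multi_speaker=False):
    """Return the required columns that are absent from `column_names`."""
    required = ["text", "audio"]
    if multi_speaker:
        required = ["source"] + required
    return [name for name in required if name not in column_names]


print(missing_columns(["text", "audio"]))                      # []
print(missing_columns(["text", "audio"], multi_speaker=True))  # ['source']
```

With a loaded Hugging Face dataset you would pass its `column_names` attribute, for example `missing_columns(raw_ds.column_names)`.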

In [None]:
#@title Dataset Prep functions
from datasets import load_dataset, Audio, Dataset
import os
from transformers import AutoProcessor
processor = AutoProcessor.from_pretrained("unsloth/csm-1b")

raw_ds = load_dataset("MrDragonFox/Elise", split="train")

# Getting the speaker id is important for multi-speaker models and speaker consistency
speaker_key = "source"
if "source" not in raw_ds.column_names and "speaker_id" not in raw_ds.column_names:
    print("Unsloth: No speaker found, adding default \"source\" of 0 for all examples")
    new_column = ["0"] * len(raw_ds)
    raw_ds = raw_ds.add_column("source", new_column)
elif "source" not in raw_ds.column_names and "speaker_id" in raw_ds.column_names:
    speaker_key = "speaker_id"

target_sampling_rate = 24000
raw_ds = raw_ds.cast_column("audio", Audio(sampling_rate=target_sampling_rate))

def preprocess_example(example):
    conversation = [
        {
            "role": str(example[speaker_key]),
            "content": [
                {"type": "text", "text": example["text"]},
                {"type": "audio", "path": example["audio"]["array"]},
            ],
        }
    ]

    try:
        model_inputs = processor.apply_chat_template(
            conversation,
            tokenize=True,
            return_dict=True,
            output_labels=True,
            text_kwargs = {
                "padding": "max_length", # pad to the max_length
                "max_length": 256, # this should be the max length of audio
                "pad_to_multiple_of": 8,
                "padding_side": "right",
            },
            audio_kwargs = {
                "sampling_rate": 24_000,
                "max_length": 240001, # max input_values length of the whole dataset
                "padding": "max_length",
            },
            common_kwargs = {"return_tensors": "pt"},
        )
    except Exception as e:
        print(f"Error processing example with text '{example['text'][:50]}...': {e}")
        return None

    required_keys = ["input_ids", "attention_mask", "labels", "input_values", "input_values_cutoffs"]
    processed_example = {}
    # print(model_inputs.keys())
    for key in required_keys:
        if key not in model_inputs:
            print(f"Warning: Required key '{key}' not found in processor output for example.")
            return None

        value = model_inputs[key][0]
        processed_example[key] = value


    # Final check (optional but good)
    if not all(isinstance(processed_example[key], torch.Tensor) for key in processed_example):
         print(f"Error: Not all required keys are tensors in final processed example. Keys: {list(processed_example.keys())}")
         return None

    return processed_example

processed_ds = raw_ds.map(
    preprocess_example,
    remove_columns=raw_ds.column_names,
    desc="Preprocessing dataset",
)
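
The audio `max_length` of 240001 in the cell above is described as the longest `input_values` length in this dataset, which at 24 kHz is about 10 seconds. For a different dataset you would likely need to recompute it. The sketch below shows the arithmetic on toy data; `longest_audio` and `duration_seconds` are hypothetical helpers.

```python
def longest_audio(arrays):
    """Number of samples in the longest waveform."""
    return max(len(a) for a in arrays)


def duration_seconds(num_samples, sampling_rate=24_000):
    """Convert a sample count to seconds at the given sampling rate."""
    return num_samples / sampling_rate


# Toy waveforms standing in for raw_ds[i]["audio"]["array"]
toy = [[0.0] * 12_000, [0.0] * 36_000]
print(longest_audio(toy))                    # 36000
print(duration_seconds(longest_audio(toy)))  # 1.5
print(round(duration_seconds(240_001), 2))   # 10.0
```

On the real data, the equivalent would be along the lines of `longest_audio(ex["audio"]["array"] for ex in raw_ds)`.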

<a name="Train"></a>
### Train the model
Now let's use the Hugging Face `Trainer`! More docs here: [Transformers docs](https://huggingface.co/docs/transformers/main_classes/trainer). We do 60 steps to speed things up; for a full run, set `num_train_epochs=1` and remove the `max_steps` setting.

In [None]:
from transformers import TrainingArguments, Trainer
from unsloth import is_bfloat16_supported

trainer = Trainer(
    model = model,
    train_dataset = processed_ds,
    args = TrainingArguments(
        per_device_train_batch_size=4,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01, # Turn this on if overfitting
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir="/workspace/outputs",
        report_to = "none", # Use TrackIO/WandB etc
    ),
)
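
The batch settings above combine into an effective batch size. As a rough sketch of the arithmetic (the helper names are illustrative, and a single GPU is assumed unless stated otherwise):

```python
def effective_batch_size(per_device, grad_accum, num_devices=1):
    """Examples contributing to each optimizer step."""
    return per_device * grad_accum * num_devices


def examples_seen(max_steps, per_device, grad_accum, num_devices=1):
    """Total examples processed over `max_steps` optimizer steps."""
    return max_steps * effective_batch_size(per_device, grad_accum, num_devices)


print(effective_batch_size(4, 4))  # 16
print(examples_seen(60, 4, 4))     # 960
```

If memory is tight, halving `per_device_train_batch_size` and doubling `gradient_accumulation_steps` keeps the effective batch size unchanged.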

In [None]:
import subprocess
import sys

# Enhanced GPU check for NVIDIA Brev
print("=" * 60)
print("GPU Information")
print("=" * 60)

# Run nvidia-smi
subprocess.run(['nvidia-smi'], check=False)

# PyTorch CUDA info
import torch
print(f"\nPyTorch CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Version: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        props = torch.cuda.get_device_properties(i)
        print(f"    Memory: {props.total_memory / 1024**3:.2f} GB")
print("=" * 60)


# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

In [None]:
trainer_stats = trainer.train()

In [None]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

<a name="Inference"></a>
### Inference
Let's run the model! You can change the prompts

In [None]:
# Fix torch compilation cache permissions
import os
import shutil

# Test if /ephemeral is writable (not just readable)
use_ephemeral = False
if os.path.exists("/ephemeral"):
    try:
        test_file = "/ephemeral/.write_test"
        with open(test_file, "w") as f:
            f.write("test")
        os.remove(test_file)
        use_ephemeral = True
    except (PermissionError, OSError):
        pass

if use_ephemeral:
    cache_dir = "/ephemeral/torch_cache"
    triton_cache = "/ephemeral/triton_cache"
    tmpdir = "/ephemeral/tmp"
else:
    cache_dir = os.path.expanduser("~/.cache/torch/inductor")
    triton_cache = os.path.expanduser("~/.cache/triton")
    tmpdir = os.path.expanduser("~/.cache/tmp")

# Create directories with full write permissions
for d in [cache_dir, triton_cache, tmpdir]:
    os.makedirs(d, mode=0o777, exist_ok=True)

# Set ALL PyTorch/Triton cache and temp directories
os.environ["TORCHINDUCTOR_CACHE_DIR"] = cache_dir
os.environ["TORCH_COMPILE_DIR"] = cache_dir
os.environ["TRITON_CACHE_DIR"] = triton_cache
os.environ["TMPDIR"] = tmpdir  # Override system /tmp
os.environ["TEMP"] = tmpdir
os.environ["TMP"] = tmpdir

# Clean up any old compiled caches
old_cache = os.path.join(os.getcwd(), "unsloth_compiled_cache")
if os.path.exists(old_cache):
    shutil.rmtree(old_cache, ignore_errors=True)

print(f"✅ Torch cache: {cache_dir}")
print(f"✅ Temp dir: {tmpdir}")

from IPython.display import Audio, display
import soundfile as sf

text = "We just finished fine tuning a text to speech model... and it's pretty good!"
speaker_id = 0
inputs = processor(f"[{speaker_id}]{text}", add_special_tokens=True).to("cuda")
audio_values = model.generate(
    **inputs,
    max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
    # play with these parameters to tweak results
    # depth_decoder_top_k=0,
    # depth_decoder_top_p=0.9,
    # depth_decoder_do_sample=True,
    # depth_decoder_temperature=0.9,
    # top_k=0,
    # top_p=1.0,
    # temperature=0.9,
    # do_sample=True,
    #########################################################
    output_audio=True
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_without_context.wav", audio, 24000)
display(Audio(audio, rate=24000))
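
The comment in the cell above says 125 tokens correspond to about 10 seconds of audio, which implies 12.5 tokens per second. Under that assumption, a target duration can be converted to `max_new_tokens` as sketched here; `tokens_for_seconds` is a hypothetical helper.

```python
import math

# Derived from the notebook's comment: 125 tokens is about 10 seconds of audio.
TOKENS_PER_SECOND = 125 / 10


def tokens_for_seconds(seconds):
    """Smallest max_new_tokens value that covers `seconds` of audio."""
    return math.ceil(seconds * TOKENS_PER_SECOND)


print(tokens_for_seconds(10))  # 125
print(tokens_for_seconds(30))  # 375
```

Generation can stop before the limit, so this is an upper bound on length rather than a guaranteed duration.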

In [None]:
text = "Sesame is a super cool TTS model which can be fine tuned with Unsloth."

speaker_id = 0
# Another equivalent way to prepare the inputs
conversation = [
    {"role": str(speaker_id), "content": [{"type": "text", "text": text}]},
]
audio_values = model.generate(
    **processor.apply_chat_template(
        conversation,
        tokenize=True,
        return_dict=True,
    ).to("cuda"),
    max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer speech increase this
    # play with these parameters to tweak results
    # depth_decoder_top_k=0,
    # depth_decoder_top_p=0.9,
    # depth_decoder_do_sample=True,
    # depth_decoder_temperature=0.9,
    # top_k=0,
    # top_p=1.0,
    # temperature=0.9,
    # do_sample=True,
    #########################################################
    output_audio=True
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_without_context.wav", audio, 24000)
display(Audio(audio, rate=24000))

#### Voice and style consistency

Sesame CSM's power comes from providing audio context for each speaker. Let's pass a sample utterance from our dataset to ground speaker identity and style.

In [None]:
speaker_id = 0

utterance = raw_ds[3]["audio"]["array"]
utterance_text = raw_ds[3]["text"]
text = "Sesame is a super cool TTS model which can be fine tuned with Unsloth."

# CSM will fill in the audio for the last text.
# You can even provide a conversation history back in as you generate new audio

conversation = [
    {"role": str(speaker_id), "content": [{"type": "text", "text": utterance_text},{"type": "audio", "path": utterance}]},
    {"role": str(speaker_id), "content": [{"type": "text", "text": text}]},
]

inputs = processor.apply_chat_template(
        conversation,
        tokenize=True,
        return_dict=True,
    )
audio_values = model.generate(
    **inputs.to("cuda"),
    max_new_tokens=125, # 125 tokens is 10 seconds of audio, for longer text increase this
    # play with these parameters to tweak results
    # depth_decoder_top_k=0,
    # depth_decoder_top_p=0.9,
    # depth_decoder_do_sample=True,
    # depth_decoder_temperature=0.9,
    # top_k=0,
    # top_p=1.0,
    # temperature=0.9,
    # do_sample=True,
    #########################################################
    output_audio=True
)
audio = audio_values[0].to(torch.float32).cpu().numpy()
sf.write("example_with_context.wav", audio, 24000)
display(Audio(audio, rate=24000))
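
The same pattern extends to longer histories: each earlier turn carries both text and audio, and the final turn carries only the text to be spoken. The helper below is a sketch that builds such a list in the format used above; `build_conversation` is a hypothetical name, and the audio values are placeholders.

```python
def build_conversation(history, new_text, speaker_id=0):
    """Build a CSM-style conversation.

    history: list of (speaker_id, text, audio) tuples, oldest first.
    new_text: the text the model should speak next.
    """
    conversation = []
    for past_speaker, past_text, past_audio in history:
        conversation.append({
            "role": str(past_speaker),
            "content": [
                {"type": "text", "text": past_text},
                {"type": "audio", "path": past_audio},
            ],
        })
    # The last turn has text only; the model generates its audio.
    conversation.append({
        "role": str(speaker_id),
        "content": [{"type": "text", "text": new_text}],
    })
    return conversation


# Placeholder waveform standing in for raw_ds[3]["audio"]["array"]
example = build_conversation([(0, "Hello there.", [0.0, 0.1])], "Nice to meet you.")
print(len(example), example[-1]["role"])  # 2 0
```

The result can be passed to `processor.apply_chat_template` in the same way as the `conversation` list in the cell above.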

<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, not the full model. To save merged 16-bit weights, see the next section.

In [None]:
model.save_pretrained("lora_model")  # Local saving
processor.save_pretrained("lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# processor.push_to_hub("your_name/lora_model", token = "...") # Online saving
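
To reload the saved adapters later, one option is to point `FastModel.from_pretrained` at the local directory. This is a sketch under the assumption that the loader accepts an adapter directory the same way it accepts a Hub model name; the arguments mirror the loading cell earlier in this notebook, and `load_finetuned` is a hypothetical wrapper.

```python
def load_finetuned(path="lora_model"):
    """Reload the model and processor from a saved adapter directory.

    Assumes FastModel.from_pretrained resolves a local LoRA directory
    like a Hub model name; adjust if your Unsloth version differs.
    """
    # Imported lazily so this function can be defined without a GPU present.
    from unsloth import FastModel
    from transformers import CsmForConditionalGeneration

    return FastModel.from_pretrained(
        model_name=path,
        max_seq_length=2048,
        dtype=None,
        auto_model=CsmForConditionalGeneration,
        load_in_4bit=False,
    )


# model, processor = load_finetuned("lora_model")
```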

### Saving to float16

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", processor, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", processor, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", processor, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", processor, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False:
    model.save_pretrained("model")
    processor.save_pretrained("model")
if False:
    model.push_to_hub("hf/model", token = "")
    processor.push_to_hub("hf/model", token = "")


And we're done! If you have questions about Unsloth, find a bug, want to keep up with the latest LLM news, or need help with a project, feel free to join our [Discord](https://discord.gg/unsloth)!

**Additional Resources:**

- 📚 [Unsloth Documentation](https://docs.unsloth.ai) - Complete guides and examples
- 💬 [Unsloth Discord](https://discord.gg/unsloth) - Community support
- 📖 [More Notebooks](https://github.com/unslothai/notebooks) - Full collection on GitHub
- 🚀 [Brev Documentation](https://docs.nvidia.com/brev) - Deploy and scale on NVIDIA GPUs