# Getting Started with Fine-Tuning Moshi 7B

This notebook shows a simple example of how to LoRA-finetune Moshi 7B. You can run it in Google Colab on an A100 GPU.

<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/moshi-finetune/blob/main/tutorials/moshi_finetune.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

Check out the `moshi-finetune` GitHub repo to learn more: https://github.com/kyutai-labs/moshi-finetune/


## Installation

Clone the `moshi-finetune` repo:


In [1]:
%cd /content/
!git clone https://github.com/kyutai-labs/moshi-finetune.git

/content
Cloning into 'moshi-finetune'...
remote: Enumerating objects: 245, done.[K
remote: Counting objects: 100% (55/55), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 245 (delta 36), reused 35 (delta 29), pack-reused 190 (from 1)[K
Receiving objects: 100% (245/245), 638.97 KiB | 13.31 MiB/s, done.
Resolving deltas: 100% (135/135), done.


Install all required dependencies:


In [None]:
%pip install -e /content/moshi-finetune

## Prepare dataset


In [None]:
from pathlib import Path

from huggingface_hub import snapshot_download

# Local destination for the DailyTalkContiguous dataset. The `train_data` path in
# the training config below points at a file inside this directory.
dataset_dir = Path("/content/data/daily-talk-contiguous")
dataset_dir.mkdir(parents=True, exist_ok=True)

# Fetch the dataset repository snapshot from the Hugging Face Hub into dataset_dir.
local_dir = snapshot_download(
    repo_id="kyutai/DailyTalkContiguous",
    repo_type="dataset",
    local_dir=str(dataset_dir),
)

## Start training


In [4]:
# CUDA device selection for training: enumerate GPUs in PCI bus order and expose
# only GPU 0 to this process. These only take effect if set before CUDA is
# initialised in this kernel.
import os

os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [6]:
# Training configuration, written to /content/example.yaml so that
# TrainArgs.load can read it in the cells below.
# For your own use case, adapt the data paths, model repo, run_dir and the
# other hyperparameters.
import yaml

config = """
# data
data:
  train_data: '/content/data/daily-talk-contiguous/dailytalk.jsonl' # Fill
  eval_data: '' # Optionally Fill
  shuffle: true

# model
moshi_paths:
  hf_repo_id: "kyutai/moshiko-pytorch-bf16"


full_finetuning: false # Activate lora.enable if partial finetuning
lora:
  enable: true
  rank: 128
  scaling: 2.
  ft_embed: false

# training hyperparameters
first_codebook_weight_multiplier: 100.
text_padding_weight: .5


# tokens per training steps = batch_size x num_GPUs x duration_sec
# we recommend a sequence duration of 300 seconds
# If you run into memory error, you can try reduce the sequence length
duration_sec: 50
batch_size: 1
max_steps: 30

gradient_checkpointing: true # Activate checkpointing of layers

# optim
optim:
  lr: 2.e-6
  weight_decay: 0.1
  pct_start: 0.05

# other
seed: 0
log_freq: 10
eval_freq: 1
do_eval: False
ckpt_freq: 10

save_adapters: True

run_dir: "/content/test"  # Fill
"""

# Parse the YAML text (dropping its comments) and write the result to disk.
parsed_config = yaml.safe_load(config)
with open("/content/example.yaml", "w") as yaml_file:
    yaml_file.write(yaml.dump(parsed_config))

In [None]:
# The run directory must not already exist: the run-dir cell below raises
# RuntimeError if it does (unless `overwrite_run_dir` is enabled in the config).
# Only uncomment and run this if a previous run already created /content/test.
# ! rm -r /content/test

# Start Training

In [14]:
# Cell 2 - Imports
import dataclasses
import logging
import os
import pprint
import shutil
from contextlib import ExitStack
from pathlib import Path

import torch
import torch.cuda
import torch.distributed as dist
from torch.optim import AdamW, lr_scheduler

from finetune.args import TrainArgs
from finetune.checkpointing import Checkpointer
from finetune.data.data_loader import build_data_loader
from finetune.data.interleaver import InterleavedTokenizer, Interleaver
from finetune.distributed import (
    BACKEND,
    avg_aggregate,
    get_rank,
    get_world_size,
    is_torchrun,
    set_device,
)
from finetune.eval import evaluate
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.monitoring.metrics_logger import (
    MetricsLogger,
    eval_log_msg,
    get_eval_logs,
    get_train_logs,
    train_log_msg,
)
from finetune.monitoring.utils import set_logger
from finetune.utils import TrainState, logged_closing, set_random_seed
from finetune.wrapped_model import get_fsdp_model
from moshi.models import loaders

# Logger named "train", used by the helper functions defined in the cells below.
logger = logging.getLogger("train")

In [15]:
# Cell 3 - Utility logging function
def main_logger_info(message: str) -> None:
    """Log ``message`` at INFO level, but only from the rank-0 process.

    Other ranks return without logging, so multi-process runs do not print
    every message once per process.
    """
    if get_rank() != 0:
        return
    logger.info(message)

In [16]:
# Cell 4 - Entry point
def train(config: str):
    """Load ``TrainArgs`` from the YAML file at ``config`` and run ``_train``.

    ``_train`` runs inside an ``ExitStack`` so that anything it registers on the
    stack is closed when it returns or raises.

    NOTE(review): ``_train`` is not defined anywhere in this notebook (the
    training steps are inlined in the cells below instead), so calling this
    function raises NameError. It is never called here.
    """
    train_args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False)
    set_logger(logging.INFO)

    cleanup_stack = ExitStack()
    with cleanup_stack:
        _train(train_args, cleanup_stack)
    logger.info("Closed everything!")

In [21]:
import dataclasses
import os
import pprint
import shutil
from contextlib import ExitStack
from pathlib import Path

import torch
import torch.distributed as dist

# Parse the YAML config written above into TrainArgs, with
# drop_extra_fields=False as in the `train` entry point.
args = TrainArgs.load("/content/example.yaml", drop_extra_fields=False)

# NOTE(review): this ExitStack is never entered, closed, or used by the
# following cells.
exit_stack = ExitStack()



In [27]:
from pathlib import Path
import os, pprint, dataclasses, shutil, torch
from contextlib import ExitStack

# Assumes `args` (TrainArgs) was loaded by the previous cell.
set_random_seed(args.seed)
# Let the CUDA caching allocator use expandable segments to reduce fragmentation.
# NOTE(review): presumably only effective if set before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Validate the finetuning mode *before* touching the run directory, so a
# misconfigured run cannot delete the outputs of a previous run and then fail.
if args.full_finetuning:
    assert not args.lora.enable, "LoRA should not be enabled for full finetuning."
else:
    assert args.lora.enable, "LoRA should be enabled for partial finetuning"

# Init run dir: refuse to reuse an existing directory unless overwriting is allowed.
run_dir = Path(args.run_dir)
if run_dir.exists():
    if not args.overwrite_run_dir:
        raise RuntimeError(f"Run dir {run_dir} already exists.")
    shutil.rmtree(run_dir)

run_dir.mkdir(exist_ok=True, parents=True)

# Save the resolved arguments alongside the run's outputs.
args_path = run_dir / "args.yaml"
if not args_path.exists():
    args.save(args_path)

print("TrainArgs:", pprint.pformat(dataclasses.asdict(args)))

TrainArgs: {'batch_size': 1,
 'ckpt_freq': 10,
 'data': {'eval_data': '',
          'shuffle': True,
          'train_data': '/content/data/daily-talk-contiguous/dailytalk.jsonl'},
 'do_ckpt': True,
 'do_eval': False,
 'duration_sec': 50.0,
 'eval_freq': 1,
 'first_codebook_weight_multiplier': 100.0,
 'full_finetuning': False,
 'gradient_checkpointing': True,
 'log_freq': 10,
 'lora': {'enable': True, 'ft_embed': False, 'rank': 128, 'scaling': 2.0},
 'max_norm': 1.0,
 'max_steps': 30,
 'moshi_paths': {'config_path': None,
                 'hf_repo_id': 'kyutai/moshiko-pytorch-bf16',
                 'mimi_path': None,
                 'moshi_path': None,
                 'tokenizer_path': None},
 'num_ckpt_keep': 3,
 'num_microbatches': 1,
 'optim': {'lr': 2e-06, 'pct_start': 0.05, 'weight_decay': 0.1},
 'overwrite_run_dir': False,
 'param_dtype': 'bfloat16',
 'run_dir': '/content/test',
 'save_adapters': True,
 'seed': 0,
 'text_padding_weight': 0.5,
 'wandb': {'key': None, 'offline':

## Load models

In [28]:
# Load models: resolve the checkpoint files (Moshi weights, Mimi weights, text
# tokenizer, optional config) for the configured Hugging Face repo. Paths set in
# `args.moshi_paths` are passed through. NOTE(review): presumably they override
# the repo defaults; confirm in moshi.models.loaders.
checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
    hf_repo=args.moshi_paths.hf_repo_id,
    moshi_weights=args.moshi_paths.moshi_path,
    mimi_weights=args.moshi_paths.mimi_path,
    tokenizer=args.moshi_paths.tokenizer_path,
    config_path=args.moshi_paths.config_path,
)

# LM hyperparameters: the repo's own config if it ships one, otherwise moshi's
# built-in defaults. Take a copy before adding the LoRA settings so that the
# module-level `loaders._lm_kwargs` dict is not mutated for the rest of the
# kernel session (re-running this cell would otherwise keep modifying it).
lm_config = dict(
    loaders._lm_kwargs if checkpoint_info.raw_config is None else checkpoint_info.raw_config
)
lm_config["lora"] = args.lora.enable
lm_config["lora_rank"] = args.lora.rank
lm_config["lora_scaling"] = args.lora.scaling



model.safetensors:   0%|          | 0.00/15.4G [00:00<?, ?B/s]

(…)nizer-e351c8d8-checkpoint125.safetensors:   0%|          | 0.00/385M [00:00<?, ?B/s]

tokenizer_spm_32k_3.model:   0%|          | 0.00/553k [00:00<?, ?B/s]

In [30]:
# In this notebook Mimi only feeds the InterleavedTokenizer and is never trained:
# switch it to eval mode and freeze every parameter.
mimi = checkpoint_info.get_mimi(device="cuda").eval()
for frozen_param in mimi.parameters():
    frozen_param.requires_grad_(False)

In [42]:
# Special text-token ids passed to the Interleaver below.
#
# The upstream moshi-finetune training script takes these from the LM model
# (`model.text_padding_token_id`, `model.end_of_text_padding_id`,
# `model.zero_token_id`). They are model-level conventions, not SentencePiece
# pieces. For example, "<zero>" is not in the vocabulary, and
# `SentencePieceProcessor.piece_to_id` silently returns the <unk> id for unknown
# pieces, so looking these ids up by text gives wrong values. That is why the
# previous `wget` + `piece_to_id` approach was dropped. The tokenizer itself is
# already downloaded by `CheckpointInfo` and loaded in the next cell.
#
# NOTE(review): the values below mirror moshi's LMModel conventions (padding id
# = `existing_text_padding_id` if set, else `text_card`; end-of-padding id 0;
# zero token id -1). Confirm them against the installed `moshi` version, or read
# them off the model once it is built.
_existing_pad_id = lm_config.get("existing_text_padding_id")
text_padding_token_id = (
    lm_config["text_card"] if _existing_pad_id is None else _existing_pad_id
)
end_of_text_padding_id = 0
zero_token_id = -1

In [43]:
# Text tokenizer for the checkpoint resolved by CheckpointInfo.from_hf_repo above.
spm = checkpoint_info.get_text_tokenizer()

In [44]:
# Build the text/audio interleaver. The positional arguments passed, in order, are:
# the text tokenizer, Mimi's frame rate, and the text-padding, end-of-text-padding
# and zero token ids from the cell above.
# NOTE(review): keep_main_only=True presumably keeps only the main speaker's text
# stream; confirm in finetune.data.interleaver.
interleaver = Interleaver(
    spm,
    mimi.frame_rate,
    text_padding_token_id,
    end_of_text_padding_id,
    zero_token_id,
    keep_main_only=True,
)
# Combines Mimi and the interleaver into the tokenizer consumed by the data loader.
# NOTE(review): duration_sec presumably sets the length of each training segment.
interleaved_tokenizer = InterleavedTokenizer(
    mimi, interleaver, duration_sec=args.duration_sec
)

In [45]:
# Training data loader for a single-process run. This notebook never initialises
# a torch.distributed process group, so the loader is told it is rank 0 of a
# world of size 1.
single_process = {"rank": 0, "world_size": 1}
data_loader = build_data_loader(
    args=args.data,
    instruct_tokenizer=interleaved_tokenizer,
    seed=args.seed,
    batch_size=args.batch_size,
    is_eval=False,
    **single_process,
)

In [46]:
# Sanity check: print the first batch produced by the data loader, if any.
first_batch = next(iter(data_loader), None)
if first_batch is not None:
    print(first_batch)

Batch(codes=tensor([[[   0,  634,    0,  ...,    0,  705,  367],
         [1049,  958, 1784,  ...,  727,  142, 1586],
         [1515, 1597, 1523,  ...,  427,  861,  886],
         ...,
         [1443,  555, 1572,  ..., 1030, 1030, 1030],
         [1871,  666,  825,  ...,  976,  976,  976],
         [2008, 1648, 2008,  ..., 2008, 2008, 2008]]], device='cuda:0'), condition_attributes=None)


In [47]:
import safetensors

def get_model_fp8_no_fsdp(args: TrainArgs, checkpointer_info):
    """Build the Moshi LM on one device (no FSDP) and load its checkpoint weights.

    Despite the name, no FP8 conversion is performed. Weights are cast to
    ``args.param_dtype`` (``"bfloat16"`` in this notebook's config, matching the
    bf16 checkpoint). The previous code ignored that setting and cast to float16.
    The name is kept so existing calls keep working.

    Args:
        args: Training arguments. Reads ``param_dtype``, ``gradient_checkpointing``,
            ``full_finetuning`` and the ``lora`` settings.
        checkpointer_info: The ``loaders.CheckpointInfo`` built earlier. Its
            ``moshi_weights`` must point to a safetensors file.

    Returns:
        The model moved to CUDA. With LoRA enabled (and not full finetuning), only
        parameters whose name contains "lora" (plus "emb" ones when
        ``args.lora.ft_embed``) require grad. Otherwise every parameter does.

    Raises:
        RuntimeError: if any parameter is still on the meta device after loading,
            i.e. it was neither in the checkpoint nor initialised as a LoRA layer.
    """
    # `import safetensors` alone does not import the `safetensors.torch` submodule,
    # so import the function explicitly instead of relying on another module
    # having loaded it.
    from safetensors.torch import load_file

    param_dtype = getattr(torch, args.param_dtype)
    use_lora = args.lora.enable and not args.full_finetuning

    # Build the architecture on the meta device. No weight memory is allocated
    # until real tensors are assigned from the checkpoint below.
    with torch.device("meta"):
        model = checkpointer_info.get_moshi(
            device="meta",
            dtype=param_dtype,
            lm_kwargs_overrides={
                "gradient_checkpointing": args.gradient_checkpointing,
                "lora": args.lora.enable,
                "lora_rank": args.lora.rank,
                "lora_scaling": args.lora.scaling,
            },
            load_weight=False,
        )

    # Without FSDP nothing shards or broadcasts the weights, so every process loads
    # the full checkpoint itself. Loading on rank 0 only would leave other ranks
    # with meta tensors that cannot be moved to CUDA.
    model_state_dict = load_file(checkpointer_info.moshi_weights)
    # Convert in place, one entry at a time, so only one copy of each tensor is kept.
    for key, tensor in model_state_dict.items():
        model_state_dict[key] = tensor.to(param_dtype)
    # assign=True replaces the meta parameters with the loaded tensors.
    # strict=False because the LoRA parameters are not in the base checkpoint.
    model.load_state_dict(model_state_dict, strict=False, assign=True)
    del model_state_dict

    if use_lora:
        # The LoRA parameters are still on meta; give them real initial values.
        from finetune.wrapped_model import initialize_lora_parameters

        initialize_lora_parameters(model, param_dtype)

    still_meta = [name for name, param in model.named_parameters() if param.is_meta]
    if still_meta:
        raise RuntimeError(
            f"{len(still_meta)} parameters are still on the meta device after "
            f"loading the checkpoint, e.g. {still_meta[:5]}"
        )

    # Synchronise only when a process group actually exists. This notebook runs
    # as a single process without one, where barrier() would raise.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()

    # Set requires_grad flags
    if use_lora:
        for name, param in model.named_parameters():
            if "lora" in name:
                param.requires_grad = True
            elif args.lora.ft_embed and "emb" in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
    else:
        for param in model.parameters():
            param.requires_grad = True

    # Move to CUDA device
    return model.cuda()

NameError: name 'CheckpointInfo' is not defined