# Getting Started with Fine-Tuning Moshi 7B

This notebook shows a simple example of how to fine-tune Moshi 7B with LoRA. You can run it in Google Colab on an A100 GPU.

<a target="_blank" href="https://colab.research.google.com/github/kyutai-labs/moshi-finetune/blob/main/tutorials/moshi_finetune.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

Check out the `moshi-finetune` GitHub repo to learn more: https://github.com/kyutai-labs/moshi-finetune/


## Installation

Clone the `moshi-finetune` repo:


In [1]:
%cd /content/
!git clone https://github.com/kyutai-labs/moshi-finetune.git

/content
Cloning into 'moshi-finetune'...
remote: Enumerating objects: 245, done.[K
remote: Counting objects: 100% (55/55), done.[K
remote: Compressing objects: 100% (26/26), done.[K
remote: Total 245 (delta 36), reused 35 (delta 29), pack-reused 190 (from 1)[K
Receiving objects: 100% (245/245), 638.97 KiB | 13.31 MiB/s, done.
Resolving deltas: 100% (135/135), done.


Install all required dependencies:


In [None]:
%pip install -e /content/moshi-finetune

## Prepare dataset


In [None]:
from pathlib import Path

from huggingface_hub import snapshot_download

Path("/content/data/daily-talk-contiguous").mkdir(parents=True, exist_ok=True)

# Download the dataset
local_dir = snapshot_download(
    "kyutai/DailyTalkContiguous",
    repo_type="dataset",
    local_dir="/content/data/daily-talk-contiguous",
)
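To sanity-check the download, you can summarise the training manifest before training. The sketch below assumes that each line of `dailytalk.jsonl` is a JSON object with a numeric `duration` field in seconds (check a few lines of the actual file before relying on this); the helpers `summarize_lines` and `summarize_manifest` are hypothetical and not part of `moshi-finetune`.

```python
import json
from collections.abc import Iterable
from pathlib import Path


def summarize_lines(lines: Iterable[str]) -> tuple[int, float]:
    """Count JSONL entries and sum their "duration" fields (assumed to be seconds)."""
    count, total_sec = 0, 0.0
    for line in lines:
        line = line.strip()
        if not line:  # tolerate blank lines
            continue
        entry = json.loads(line)
        count += 1
        total_sec += float(entry["duration"])  # assumed key name
    return count, total_sec


def summarize_manifest(path: Path) -> tuple[int, float]:
    """Open a JSONL manifest and summarise it line by line."""
    with path.open() as f:
        return summarize_lines(f)


manifest = Path("/content/data/daily-talk-contiguous/dailytalk.jsonl")
if manifest.exists():
    n, secs = summarize_manifest(manifest)
    print(f"{n} entries, {secs / 3600:.1f} hours of audio")
```

Splitting the parsing from the file handling keeps the logic easy to test on in-memory lines.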

## Start training


In [4]:
# Number GPUs in PCI bus order (matching nvidia-smi) and expose only GPU 0 to training
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [6]:
# define training configuration
# for your own use cases, you might want to change the data paths, model path, run_dir, and other hyperparameters
import yaml

config = """
# data
data:
  train_data: '/content/data/daily-talk-contiguous/dailytalk.jsonl' # Fill
  eval_data: '' # Optionally Fill
  shuffle: true

# model
moshi_paths:
  hf_repo_id: "kyutai/moshiko-pytorch-bf16"


full_finetuning: false # Activate lora.enable if partial finetuning
lora:
  enable: true
  rank: 128
  scaling: 2.
  ft_embed: false

# training hyperparameters
first_codebook_weight_multiplier: 100.
text_padding_weight: .5


# audio seconds per training step = batch_size x num_GPUs x duration_sec
# we recommend a sequence duration of 300 seconds
# if you run into out-of-memory errors, try reducing the sequence length (duration_sec)
duration_sec: 50
batch_size: 1
max_steps: 30

gradient_checkpointing: true # Activate checkpointing of layers

# optim
optim:
  lr: 2.e-6
  weight_decay: 0.1
  pct_start: 0.05

# other
seed: 0
log_freq: 10
eval_freq: 1
do_eval: False
ckpt_freq: 10

save_adapters: True

run_dir: "/content/test"  # Fill
"""

# parse the YAML string and write it to /content/example.yaml, which is passed to the training script
with open("/content/example.yaml", "w") as file:
    yaml.dump(yaml.safe_load(config), file)
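Before launching, it can help to check what these values imply. The sketch below assumes a single GPU and Mimi's frame rate of 12.5 frames per second, and re-derives in plain Python the learning-rate curve that PyTorch's `OneCycleLR` produces with its default settings (`div_factor=25`, `final_div_factor=1e4`, cosine annealing). The helpers `cos_anneal` and `one_cycle_lr` are hypothetical illustrations, not calls into PyTorch or `moshi-finetune`.

```python
import math

# Values copied from the YAML config above.
batch_size, num_gpus, duration_sec = 1, 1, 50
max_steps, max_lr, pct_start = 30, 2e-6, 0.05
MIMI_FRAME_RATE = 12.5  # assumption: Mimi emits 12.5 frames per second

audio_sec_per_step = batch_size * num_gpus * duration_sec
frames_per_step = audio_sec_per_step * MIMI_FRAME_RATE
print(f"{audio_sec_per_step} s of audio, {frames_per_step:.0f} frames per step")


def cos_anneal(start: float, end: float, pct: float) -> float:
    """Cosine interpolation from `start` (pct=0) to `end` (pct=1)."""
    return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)


def one_cycle_lr(step: int, total_steps: int, max_lr: float, pct_start: float,
                 div_factor: float = 25.0, final_div_factor: float = 1e4) -> float:
    """Approximate OneCycleLR's default two-phase cosine schedule at `step`."""
    initial_lr = max_lr / div_factor
    min_lr = initial_lr / final_div_factor
    warmup_end = pct_start * total_steps - 1  # may be fractional, as in PyTorch
    if step <= warmup_end:
        # Warm-up phase: rise from initial_lr to max_lr.
        pct = step / warmup_end if warmup_end > 0 else 1.0
        return cos_anneal(initial_lr, max_lr, pct)
    # Annealing phase: decay from max_lr to min_lr by the last step.
    pct = (step - warmup_end) / (total_steps - 1 - warmup_end)
    return cos_anneal(max_lr, min_lr, pct)


for step in (0, 1, max_steps // 2, max_steps - 1):
    print(step, f"{one_cycle_lr(step, max_steps, max_lr, pct_start):.2e}")
```

With `pct_start: 0.05` and only 30 steps, the warm-up is shorter than a single step, so the learning rate is near its peak almost immediately and then decays for the rest of the run.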

In [None]:
# The run_dir must not exist yet when training starts.
# Only uncomment and run this if a previous torchrun already created /content/test.
# ! rm -r /content/test

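## Training loop

The cells below lay out the training logic step by step. Running them only defines `train` and `_train`; training itself is launched with `torchrun` at the end of this section.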

In [14]:
# Imports
import dataclasses
import logging
import os
import pprint
import shutil
from contextlib import ExitStack
from pathlib import Path

import torch
import torch.cuda
import torch.distributed as dist
from torch.optim import AdamW, lr_scheduler

from finetune.args import TrainArgs
from finetune.checkpointing import Checkpointer
from finetune.data.data_loader import build_data_loader
from finetune.data.interleaver import InterleavedTokenizer, Interleaver
from finetune.distributed import (
    BACKEND,
    avg_aggregate,
    get_rank,
    get_world_size,
    is_torchrun,
    set_device,
)
from finetune.eval import evaluate
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.monitoring.metrics_logger import (
    MetricsLogger,
    eval_log_msg,
    get_eval_logs,
    get_train_logs,
    train_log_msg,
)
from finetune.monitoring.utils import set_logger
from finetune.utils import TrainState, logged_closing, set_random_seed
from finetune.wrapped_model import get_fsdp_model
from moshi.models import loaders

logger = logging.getLogger("train")

In [15]:
# Utility: log only from the main (rank 0) process
def main_logger_info(message: str) -> None:
    if get_rank() == 0:
        logger.info(message)

In [16]:
# Entry point: load TrainArgs from a YAML config file and run the training loop
def train(config: str):
    args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False)
    set_logger(logging.INFO)

    with ExitStack() as exit_stack:
        _train(args, exit_stack)
    logger.info("Closed everything!")

In [None]:
def _train(args: TrainArgs, exit_stack: ExitStack):
    set_random_seed(args.seed)
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Init NCCL
    if "LOCAL_RANK" in os.environ:
        set_device()
        logger.info("Going to init comms...")
        dist.init_process_group(backend=BACKEND)
    else:
        logger.error("PyTorch environment is not correctly initialized.")

    # Init run dir
    main_logger_info(f"Run dir: {args.run_dir}")
    run_dir = Path(args.run_dir)

    if is_torchrun():
        if run_dir.exists() and not args.overwrite_run_dir:
            raise RuntimeError(f"Run dir {run_dir} already exists.")
        elif run_dir.exists():
            main_logger_info(f"Removing run dir {run_dir}...")
            shutil.rmtree(run_dir)

    if args.full_finetuning:
        assert not args.lora.enable, "LoRA should not be enabled for full finetuning."
    else:
        assert args.lora.enable, "LoRA should be enabled for partial finetuning"

    dist.barrier()
    run_dir.mkdir(exist_ok=True, parents=True)

    args_path = run_dir / "args.yaml"
    if not args_path.exists():
        args.save(args_path)

    main_logger_info(f"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}")

    # Loggers
    metrics_logger = MetricsLogger(
        run_dir, "train", get_rank() == 0, args.wandb, dataclasses.asdict(args)
    )
    exit_stack.enter_context(logged_closing(metrics_logger, "metrics_logger"))

    eval_logger = MetricsLogger(
        run_dir, "eval", get_rank() == 0, args.wandb, dataclasses.asdict(args)
    )
    exit_stack.enter_context(logged_closing(eval_logger, "eval_logger"))

    # Load models
    main_logger_info("Loading Mimi and Moshi...")
    checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
        hf_repo=args.moshi_paths.hf_repo_id,
        moshi_weights=args.moshi_paths.moshi_path,
        mimi_weights=args.moshi_paths.mimi_path,
        tokenizer=args.moshi_paths.tokenizer_path,
        config_path=args.moshi_paths.config_path,
    )

    lm_config = (
        loaders._lm_kwargs if checkpoint_info.raw_config is None else checkpoint_info.raw_config
    )
    lm_config["lora"] = args.lora.enable
    lm_config["lora_rank"] = args.lora.rank
    lm_config["lora_scaling"] = args.lora.scaling

    mimi = checkpoint_info.get_mimi(device="cuda")
    mimi.eval()
    for p in mimi.parameters():
        p.requires_grad = False

    model = get_fsdp_model(args, checkpoint_info)

    spm = checkpoint_info.get_text_tokenizer()

    interleaver = Interleaver(
        spm,
        mimi.frame_rate,
        model.text_padding_token_id,
        model.end_of_text_padding_id,
        model.zero_token_id,
        keep_main_only=True,
    )
    interleaved_tokenizer = InterleavedTokenizer(
        mimi, interleaver, duration_sec=args.duration_sec
    )

    # Data loaders
    data_loader = build_data_loader(
        instruct_tokenizer=interleaved_tokenizer,
        args=args.data,
        batch_size=args.batch_size,
        seed=args.seed,
        rank=get_rank(),
        world_size=get_world_size(),
        is_eval=False,
    )

    if args.do_eval:
        eval_data_loader = build_data_loader(
            instruct_tokenizer=interleaved_tokenizer,
            args=args.data,
            batch_size=args.batch_size,
            seed=None,
            rank=get_rank(),
            world_size=get_world_size(),
            is_eval=True,
        )

    # Optimizer / Scheduler
    param_dtype = getattr(torch, args.param_dtype)
    optim_dtype = torch.float32
    optimizer = AdamW(
        model.parameters(),
        lr=args.optim.lr,
        betas=(0.9, 0.95),
        eps=1e-08,
        weight_decay=args.optim.weight_decay,
    )
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.optim.lr,
        total_steps=args.max_steps,
        pct_start=args.optim.pct_start,
    )

    state = TrainState(args.max_steps)

    # Checkpointer
    if args.do_ckpt:
        checkpointer = Checkpointer(
            model=model,
            state=state,
            config=lm_config,
            run_dir=run_dir,
            optimizer=optimizer,
            num_ckpt_keep=args.num_ckpt_keep,
            full_finetuning=args.full_finetuning,
        )

    # Mixed precision
    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )

    # Train loop
    model.train()
    torch.cuda.empty_cache()

    while state.step < args.max_steps:
        state.start_step()
        is_last_step = state.step == args.max_steps

        optimizer.zero_grad()
        loss = torch.tensor([0.0], device="cuda")
        n_batch_tokens = 0
        n_real_tokens = 0

        for i in range(args.num_microbatches):
            batch = next(data_loader)
            codes = batch.codes

            condition_tensors = None
            if batch.condition_attributes is not None:
                condition_tensors = model.condition_provider.prepare(
                    batch.condition_attributes
                )

            output = model(codes=codes, condition_tensors=condition_tensors)
            text_loss = compute_loss_with_mask(
                output.text_logits,
                codes[:, : model.audio_offset],
                output.text_mask,
                mode="text",
                text_padding_weight=args.text_padding_weight,
                text_padding_ids={
                    model.text_padding_token_id,
                    model.end_of_text_padding_id,
                },
            )
            audio_loss = compute_loss_with_mask(
                output.logits,
                codes[:, model.audio_offset : model.audio_offset + model.dep_q],
                output.mask,
                mode="audio",
                first_codebook_weight_multiplier=args.first_codebook_weight_multiplier,
            )

            mb_loss = text_loss + audio_loss
            mb_loss.backward()

            loss += mb_loss.detach()
            n_batch_tokens += output.text_mask.numel() + output.mask.numel()
            n_real_tokens += (
                torch.sum(output.text_mask).item() + torch.sum(output.mask).item()
            )

            if i < args.num_microbatches - 1:
                assert args.num_microbatches > 1
                torch.cuda.synchronize()

        if args.num_microbatches > 1:
            loss /= args.num_microbatches
            for p in model.parameters():
                if p.requires_grad:
                    assert p.grad is not None
                    p.grad.div_(args.num_microbatches)

        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()
        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

        last_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        avg_loss = avg_aggregate(loss.item())

        if args.do_eval and (
            (args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step
        ):
            evaluate(model, eval_data_loader, state, args)
            eval_logs = get_eval_logs(
                state.step, avg_loss, state.this_eval_perplexity, state.this_eval_loss
            )
            main_logger_info(eval_log_msg(eval_logs))
            eval_logger.log(eval_logs, step=state.step)

        state.end_step(n_batch_tokens)

        if state.step % args.log_freq == 0:
            train_logs = get_train_logs(
                state,
                avg_loss,
                n_real_tokens,
                last_lr,
                torch.cuda.max_memory_allocated(),
                torch.cuda.memory_allocated(),
                args,
            )
            main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss))
            metrics_logger.log(train_logs, step=state.step)

        if args.do_ckpt and (
            (args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step
        ):
            checkpointer.save_checkpoint(
                save_only_lora=not args.full_finetuning and args.save_adapters,
                dtype=param_dtype,
            )

    main_logger_info("done!")
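The two loss terms in the loop are weighted: `text_padding_weight` scales the contribution of text padding tokens, and `first_codebook_weight_multiplier` scales the first audio codebook. The plain-Python sketch below illustrates that idea as a masked, weighted cross-entropy normalised by the total weight. It is an assumption about how such weighting typically works, not the code of `finetune.loss.compute_loss_with_mask`; `cross_entropy`, `weighted_masked_ce` and the padding id `0` are hypothetical.

```python
import math
from collections.abc import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Negative log-softmax probability of `target` (numerically stable)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]


def weighted_masked_ce(logits, targets, mask, weights) -> float:
    """Sum of weight * CE over unmasked positions, divided by the sum of those weights."""
    num, den = 0.0, 0.0
    for row, tgt, keep, w in zip(logits, targets, mask, weights):
        if keep and w > 0:
            num += w * cross_entropy(row, tgt)
            den += w
    return num / den if den > 0 else 0.0


# Hypothetical text stream where token id 0 stands in for a padding token.
PAD_ID, text_padding_weight = 0, 0.5
targets = [3, 0, 0, 1]
mask = [True, True, True, False]  # the last position is masked out
weights = [text_padding_weight if t == PAD_ID else 1.0 for t in targets]
uniform = [[0.0] * 4 for _ in targets]  # uniform logits give CE = log(4) everywhere
print(weighted_masked_ce(uniform, targets, mask, weights))
```

For the audio term, the same idea would apply with a weight of `first_codebook_weight_multiplier` on codebook 0 and `1.0` on the remaining codebooks. Normalising by the total weight keeps the loss scale comparable when the proportion of padding changes between batches.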

In [None]:
# Start training: uncomment the command below to run the repo's train module with torchrun on one GPU
# !cd /content/moshi-finetune && torchrun --nproc-per-node 1 -m train /content/example.yaml

## Inference

Once the model has been trained, you can also run inference on the Colab GPU, using Gradio to tunnel audio between your local client and the notebook.

More details on how to set this up can be found in the [moshi readme](https://github.com/kyutai-labs/moshi?tab=readme-ov-file#python-pytorch).
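The inference command needs the path of a saved LoRA checkpoint. The training loop saves whenever `state.step % ckpt_freq == 0` or on the last step, so the sketch below lists the resulting steps and paths. The directory layout (`<run_dir>/checkpoints/checkpoint_<zero-padded step>/consolidated/lora.safetensors`) is inferred from the inference command and should be treated as an assumption; `checkpoint_steps` and `lora_checkpoint_path` are hypothetical helpers.

```python
from pathlib import Path


def checkpoint_steps(max_steps: int, ckpt_freq: int) -> list[int]:
    """Steps at which a checkpoint is saved, following the condition in `_train`."""
    steps = []
    for step in range(1, max_steps + 1):  # state.step runs from 1 to max_steps in the loop
        is_last_step = step == max_steps
        if (ckpt_freq > 0 and step % ckpt_freq == 0) or is_last_step:
            steps.append(step)
    return steps


def lora_checkpoint_path(run_dir: str, step: int) -> Path:
    """Assumed location of the consolidated LoRA weights for a given step."""
    return (Path(run_dir) / "checkpoints" / f"checkpoint_{step:06d}"
            / "consolidated" / "lora.safetensors")


for step in checkpoint_steps(max_steps=30, ckpt_freq=10):
    print(lora_checkpoint_path("/content/test", step))
```

Point `--lora-weight` (and the matching `config.json` for `--config-path`) at the checkpoint step you actually want to serve.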


In [None]:
%pip install gradio

In [None]:
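# with max_steps: 30 and ckpt_freq: 10 the last checkpoint is checkpoint_000030; adjust the step to match your run
!python -m moshi.server --gradio-tunnel --lora-weight=/content/test/checkpoints/checkpoint_000030/consolidated/lora.safetensors --config-path=/content/test/checkpoints/checkpoint_000030/consolidated/config.json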