# <center> <font size="+4"> Fine-Tuning Hibiki - 2B </font> </center>

This notebook shows how to LoRA-finetune Hibiki 2B. The recommended GPU is an **A100**.

Check out the `moshi-finetune` GitHub repo to learn more: https://github.com/kyutai-labs/moshi-finetune/

To replace the LLM with Gemma/LLaMA while keeping streaming capabilities, see https://chatgpt.com/share/68a33f04-f260-8011-bc35-2b71120ffb9d for guidelines.


# <br> 1. Setup

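Clone the `hibiki-finetune` repo from my GitHub profile (`imrnh`) and install it. The cell below is commented out; uncomment it when running in a fresh environment: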


In [1]:
# # Clone the repository
# !git clone https://github.com/imrnh/hibiki-finetune.git

# # Copy files to current directory (./)
# !cp -r hibiki-finetune/* ./
# !rm -rf hibiki-finetune

# Install deps
# %pip install -e ./

## Imports

In [2]:
from contextlib import ExitStack
import fire
import torch.cuda
import torch.distributed as dist
from torch.optim import AdamW, lr_scheduler
import os
import logging
import shutil
import pprint
import dataclasses
import copy
import yaml
from pathlib import Path
import json

# from torch.profiler import ProfilerActivity, profile

from finetune.args import TrainArgs
from finetune.checkpointing import Checkpointer
from finetune.data.data_loader import build_data_loader
from finetune.data.interleaver import InterleavedTokenizer, Interleaver
from finetune.distributed import (
    BACKEND,
    avg_aggregate,
    get_rank,
    get_world_size,
    is_torchrun,
    set_device,
)
from finetune.eval import evaluate
from finetune.loss import compute_loss_with_mask
from finetune.mixed_precision import (
    downcast_mixed_precision,
    prepare_mixed_precision,
    upcast_mixed_precision,
)
from finetune.monitoring.metrics_logger import (
    MetricsLogger,
    eval_log_msg,
    get_eval_logs,
    get_train_logs,
    train_log_msg,
)
from finetune.monitoring.utils import set_logger
from finetune.utils import TrainState, logged_closing, set_random_seed
from finetune.wrapped_model import get_fsdp_model
from moshi.models import loaders
from moshi.conditioners import ConditionAttributes
from moshi.modules.lora import LoRALinear


logger = logging.getLogger("train")

# Set environment variables to simulate `torchrun` for a single process.
# This is necessary for `dist.init_process_group` to work correctly.
# CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES are only honoured if they are set
# before the CUDA runtime is initialised, so they are set *before*
# `torch.cuda.is_available()` below queries the devices (assuming nothing
# imported earlier has already initialised CUDA).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "12355"  # arbitrary port; change it if already in use
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# French counterpart of `data.train_data`; its codes are used as the loss targets.
FRENCH_DATA_PATH = "/kaggle/input/hibiki-stereo-annotated-en-fr/french_stereo.jsonl"

2025-08-22 13:29:09.580289: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755869349.602988    2809 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755869349.609913    2809 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## `yaml` File Generation

In [3]:
# Training configuration. The next cell round-trips it through
# `yaml.safe_load` / `yaml.dump`, so the YAML comments below are documentation
# only and are not written to `trainer.yaml`.
config = """
# data
data:
  train_data: '/kaggle/input/hibiki-stereo-annotated-en-fr/english_stereo.jsonl'
  eval_data: ''
  shuffle: true

# model
moshi_paths:
  hf_repo_id: "kyutai/hibiki-2b-pytorch-bf16"

# full_finetuning must be the opposite of lora.enable (asserted at the start of training)
full_finetuning: false
lora:
  enable: true
  rank: 128
  scaling: 2.
  ft_embed: false

# training hyperparameters
first_codebook_weight_multiplier: 100.
text_padding_weight: .5


# tokens per training step = batch_size x num_GPUs x duration_sec
# the upstream moshi-finetune example recommends 300 s sequences; this notebook uses 12 s.
# If you run into out-of-memory errors, reduce duration_sec (or batch_size).
duration_sec: 12
batch_size: 2
max_steps: 15000

gradient_checkpointing: true # Activate checkpointing of layers

# optim (OneCycleLR: lr is the peak LR, reached after pct_start of max_steps)
optim:
  lr: 5.e-7
  weight_decay: 0.1
  pct_start: 0.05

# other
seed: 0
eval_freq: 10000
do_eval: False


# Checkpointing (written under <run_dir>/checkpoints/checkpoint_XXXXXX)
ckpt_freq: 150
log_freq: 150

# save only the LoRA adapter weights (ignored when full_finetuning is true)
save_adapters: True

run_dir: "save_dir"
"""

In [4]:
# Persist the config as `trainer.yaml`, the path passed to `train(...)` below.
# Round-tripping through safe_load/dump validates the YAML (and drops comments).
with open("trainer.yaml", "w", encoding="utf-8") as file:
    yaml.dump(yaml.safe_load(config), file)

# <br> 2. Fine-tune <br>

In [5]:
def main_logger_info(message: str) -> None:
    """Emit ``message`` through the ``train`` logger at INFO level, on rank 0 only.

    Other ranks return without logging, so multi-process runs do not duplicate
    every line.
    """
    is_main_process = get_rank() == 0
    if not is_main_process:
        return
    logger.info(message)

def get_condition_tensors(lm, batch_size: int, cfg_coef: float):
    """Build the default conditioning tensors for ``lm``.

    Every batch element is conditioned on ``description="very_good"``. When
    ``cfg_coef != 1.0``, ``batch_size`` negative conditions
    (``description="very_bad"``) are appended for classifier-free guidance.

    Args:
        lm: model exposing ``condition_provider`` (may be ``None``).
        batch_size: number of positive conditions to create.
        cfg_coef: CFG coefficient; any value other than 1.0 adds the negatives.

    Returns:
        Whatever ``lm.condition_provider(prepared)`` returns, or ``{}`` when the
        model has no condition provider or the provider has no conditioners.
    """
    provider = lm.condition_provider
    if provider is None or not provider.conditioners:
        return {}

    conditions = [
        ConditionAttributes(text={"description": "very_good"}, tensor={})
        for _ in range(batch_size)
    ]
    if cfg_coef != 1.0:
        # Extending the conditions with the negatives for the CFG.
        conditions += [
            ConditionAttributes(text={"description": "very_bad"}, tensor={})
            for _ in range(batch_size)
        ]

    prepared = provider.prepare(conditions)
    return provider(prepared)


def _atomic_torch_save(obj, path: Path) -> None:
    """``torch.save`` ``obj`` to ``path`` via a temp file + ``os.replace``.

    An interrupted save then never leaves a truncated file at ``path``.
    """
    tmp_path = path.with_name(path.name + ".tmp")
    torch.save(obj, tmp_path)
    os.replace(tmp_path, path)


def save_training_state(run_dir: Path, state: TrainState, optimizer, scheduler, step: int):
    """Save optimizer, scheduler, and training state for resuming.

    Files go to
    ``<run_dir>/checkpoints/checkpoint_<step:06d>/consolidated/training_state/``:
    ``optimizer.pt``, ``scheduler.pt`` and ``training_state.pt`` (``state.step``
    plus the torch CPU and, if available, CUDA RNG states). Only rank 0 writes
    anything; other ranks return immediately.

    Args:
        run_dir: training run directory.
        state: current TrainState (its ``step`` is stored).
        optimizer: optimizer whose ``state_dict()`` is saved.
        scheduler: LR scheduler whose ``state_dict()`` is saved.
        step: step number used to name the checkpoint directory.
    """
    if get_rank() != 0:  # Only save on rank 0
        return

    training_state_dir = (
        Path(run_dir) / "checkpoints" / f"checkpoint_{step:06d}" / "consolidated" / "training_state"
    )
    training_state_dir.mkdir(parents=True, exist_ok=True)

    # Each file is written atomically so a crash mid-save cannot corrupt a
    # checkpoint that `load_training_state` would later try to resume from.
    _atomic_torch_save(optimizer.state_dict(), training_state_dir / "optimizer.pt")
    _atomic_torch_save(scheduler.state_dict(), training_state_dir / "scheduler.pt")

    training_state = {
        'step': state.step,
        'torch_rng_state': torch.get_rng_state(),
        'cuda_rng_state': torch.cuda.get_rng_state() if torch.cuda.is_available() else None,
    }
    _atomic_torch_save(training_state, training_state_dir / "training_state.pt")

    main_logger_info(f"Saved training state at step {step}")


def load_training_state(checkpoint_path: str, state: TrainState, optimizer, scheduler, param_dtype):
    """Load optimizer, scheduler, and training state for resuming.

    Reads ``<checkpoint_path>/training_state/`` as written by
    ``save_training_state``: ``training_state.pt`` (step + RNG states),
    ``optimizer.pt`` and ``scheduler.pt``. Each file is optional.

    All files are read *before* anything is applied. A missing or corrupt file
    therefore cannot leave ``state.step``/RNG restored while the optimizer is
    not, which would contradict the caller's "starting from scratch" fallback.

    Args:
        checkpoint_path: the checkpoint's ``consolidated`` directory.
        state: TrainState whose ``step`` is restored.
        optimizer: optimizer to restore (same param groups as when saved).
        scheduler: LR scheduler to restore.
        param_dtype: unused; kept for call compatibility.

    Returns:
        True if the ``training_state`` directory exists and every file found was
        applied; False if it is missing or anything failed (the error is logged).
    """
    training_state_dir = Path(checkpoint_path) / "training_state"

    if not training_state_dir.exists():
        main_logger_info(f"No training state found at {training_state_dir}")
        return False

    def _load_if_present(file_name: str):
        path = training_state_dir / file_name
        return torch.load(path, map_location='cpu') if path.exists() else None

    # Phase 1: read everything; nothing has been modified if this fails.
    try:
        training_state = _load_if_present("training_state.pt")
        optimizer_state = _load_if_present("optimizer.pt")
        scheduler_state = _load_if_present("scheduler.pt")
    except Exception as e:
        main_logger_info(f"Failed to load training state: {e}")
        return False

    # Phase 2: apply. The optimizer is restored first because it is the most
    # likely to reject its state (e.g. mismatched param groups).
    try:
        if optimizer_state is not None:
            # `Optimizer.load_state_dict` itself casts floating-point state to
            # each param's device and dtype, so a manual pre-cast would be
            # overwritten. `step` counters are kept as saved (on CPU).
            optimizer.load_state_dict(optimizer_state)
            main_logger_info("Loaded optimizer state")

        if scheduler_state is not None:
            scheduler.load_state_dict(scheduler_state)
            main_logger_info("Loaded scheduler state")

        if training_state is not None:
            torch.set_rng_state(training_state['torch_rng_state'])
            if training_state['cuda_rng_state'] is not None and torch.cuda.is_available():
                torch.cuda.set_rng_state(training_state['cuda_rng_state'])
            state.step = training_state['step']
            main_logger_info(f"Resumed from step {state.step}")
    except Exception as e:
        main_logger_info(f"Failed to load training state: {e}")
        return False

    return True


def find_latest_checkpoint(run_dir: Path) -> str | None:
    """Find the latest checkpoint in the run directory."""
    checkpoints_dir = run_dir / "checkpoints"
    if not checkpoints_dir.exists():
        return None
    
    checkpoint_dirs = [d for d in checkpoints_dir.iterdir() if d.is_dir() and d.name.startswith("checkpoint_")]
    if not checkpoint_dirs:
        return None
    
    # Sort by step number
    checkpoint_dirs.sort(key=lambda x: int(x.name.split("_")[1]))
    latest_checkpoint = checkpoint_dirs[-1] / "consolidated" # Want to load training state from root.
    
    if latest_checkpoint.exists():
        return str(latest_checkpoint)
    return None


def load_lora_checkpoint(model, checkpoint_path: str):
    """Load LoRA (or consolidated) weights from a checkpoint directory into ``model``.

    Prefers ``lora.safetensors`` and falls back to ``consolidated.safetensors``.
    Loading is non-strict: model keys absent from the file are only reported.
    This is normal for an adapter-only file, where the frozen base weights are
    missing. If *none* of the file's keys match the model, the load is treated
    as a failure instead of being reported as a success.

    Args:
        model: module to load the weights into.
        checkpoint_path: checkpoint directory (e.g. ``.../consolidated``).

    Returns:
        True if weights were loaded, False otherwise (the reason is logged).
    """
    checkpoint_dir = Path(checkpoint_path)
    lora_checkpoint_path = checkpoint_dir / "lora.safetensors"
    consolidated_checkpoint_path = checkpoint_dir / "consolidated.safetensors"

    if lora_checkpoint_path.exists():
        checkpoint_to_load = lora_checkpoint_path
        main_logger_info(f"Loading LoRA checkpoint from {checkpoint_to_load}")
    elif consolidated_checkpoint_path.exists():
        checkpoint_to_load = consolidated_checkpoint_path
        main_logger_info(f"Loading consolidated checkpoint from {checkpoint_to_load}")
    else:
        main_logger_info(f"No checkpoint found at {checkpoint_path}")
        return False

    try:
        import safetensors.torch
        state_dict = safetensors.torch.load_file(checkpoint_to_load)
        # Non-strict: with a LoRA checkpoint only the adapter parameters are present.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
    except Exception as e:
        main_logger_info(f"Failed to load checkpoint: {e}")
        return False

    def _preview(keys) -> str:
        # Show at most 10 keys; add an ellipsis only when the list was truncated.
        return f"{list(keys)[:10]}{'...' if len(keys) > 10 else ''}"

    if missing_keys:
        main_logger_info(f"Missing keys when loading checkpoint: {_preview(missing_keys)}")
    if unexpected_keys:
        main_logger_info(f"Unexpected keys when loading checkpoint: {_preview(unexpected_keys)}")

    # With strict=False a complete key-name mismatch would otherwise load
    # nothing while still reporting success.
    if state_dict and len(unexpected_keys) == len(state_dict):
        main_logger_info(
            f"None of the {len(state_dict)} keys in {checkpoint_to_load} match the model; "
            "treating the checkpoint load as failed"
        )
        return False

    main_logger_info("Successfully loaded model checkpoint | Log from function: load_lora_checkpoint")
    return True


def train(config: str):
    """Training entry point.

    Loads ``TrainArgs`` from the YAML file at ``config`` (with
    ``drop_extra_fields=False``), switches logging to INFO and runs ``_train``.
    Anything ``_train`` registers on the ``ExitStack`` is closed when it
    returns or raises.

    Args:
        config: path to the training YAML file.
    """
    args: TrainArgs = TrainArgs.load(config, drop_extra_fields=False)

    # `_train` reads `resume_from_checkpoint`; default it to None when the
    # loaded args do not define it.
    try:
        args.resume_from_checkpoint
    except AttributeError:
        args.resume_from_checkpoint = None

    set_logger(logging.INFO)

    with ExitStack() as stack:
        _train(args, stack)
    logger.info("Closed everything!")

In [6]:
the_model = None  # debug handle: model built by the last `_train` call
the_optimizer = None  # debug handle: optimizer built by the last `_train` call


def _mask_to_ignore_index(codes: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``codes`` with every ``-1`` replaced by ``-100``.

    -100 is the default ``ignore_index`` of ``torch.nn.functional.cross_entropy``;
    presumably ``compute_loss_with_mask`` relies on it so missing target tokens
    add no loss — TODO confirm.
    """
    return codes.masked_fill(codes == -1, -100)


def _train(args: TrainArgs, exit_stack: ExitStack):
    """Run the LoRA fine-tuning loop described by ``args``.

    Sets up the (single-process) process group and the run directory. It resumes
    from ``args.resume_from_checkpoint`` or the newest checkpoint in
    ``run_dir``, then builds Mimi, the LM, the ``args.data`` and French data
    loaders, AdamW and OneCycleLR. Training runs until ``state.step`` reaches
    ``args.max_steps``, with periodic logging, optional evaluation and
    checkpointing. The model is fed the ``args.data`` codes and supervised with
    the French codes (text stream plus the first ``model.dep_q`` audio
    codebooks).

    Args:
        args: parsed training configuration.
        exit_stack: receives the metric loggers so they are closed on exit.
    """
    global the_model
    global the_optimizer

    # 1. Initial setup and checks
    set_random_seed(args.seed)
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

    # Init the process group. In a notebook this function may be re-run in the
    # same kernel, and initialising the default group twice raises, so an
    # existing group is reused.
    if "LOCAL_RANK" in os.environ:
        set_device()
        if not dist.is_initialized():
            logger.info("Going to init comms...")
            dist.init_process_group(backend=BACKEND)
    else:
        logger.error("PyTorch environment is not correctly initialized. This message should only be displayed when testing.")

    # 2. Init run dir
    main_logger_info(f"Run dir: {args.run_dir}")
    run_dir = Path(args.run_dir)

    # An explicit `resume_from_checkpoint` wins; otherwise resume from the newest
    # complete checkpoint under run_dir (if any).
    if hasattr(args, 'resume_from_checkpoint') and args.resume_from_checkpoint:
        resume_checkpoint = args.resume_from_checkpoint
        main_logger_info(f"\n****Resuming from specified checkpoint: {resume_checkpoint}****\n")
    else:
        resume_checkpoint = find_latest_checkpoint(run_dir)
        if resume_checkpoint:
            main_logger_info(f"Found latest checkpoint to resume from: {resume_checkpoint}")

    # A fresh run starts from an empty run dir; never wipe it when resuming.
    if is_torchrun() and not resume_checkpoint:
        if run_dir.exists():
            main_logger_info(f"Removing run dir {run_dir}...")
            shutil.rmtree(run_dir)

    if args.full_finetuning:
        assert not args.lora.enable, "LoRA should not be enabled for full finetuning."
    else:
        assert args.lora.enable, "LoRA should be enabled for partial finetuning"

    if dist.is_initialized():
        dist.barrier()
    run_dir.mkdir(exist_ok=True, parents=True)

    args_path = run_dir / "args.yaml"
    if not args_path.exists():
        args.save(args_path)

    main_logger_info(f"TrainArgs: {pprint.pformat(dataclasses.asdict(args))}")

    # 3. Get loggers (closed by `exit_stack`)
    metrics_logger: MetricsLogger = MetricsLogger(
        run_dir,
        tag="train",
        is_master=get_rank() == 0,
        wandb_args=args.wandb,
        config=dataclasses.asdict(args),
    )
    exit_stack.enter_context(logged_closing(metrics_logger, "metrics_logger"))

    eval_logger: MetricsLogger = MetricsLogger(
        run_dir,
        tag="eval",
        is_master=get_rank() == 0,
        wandb_args=args.wandb,
        config=dataclasses.asdict(args),
    )
    exit_stack.enter_context(logged_closing(eval_logger, "eval_logger"))

    # 4.1 Load the Mimi audio codec, LM config and text tokenizer
    main_logger_info("Loading Mimi and Moshi...")
    checkpoint_info = loaders.CheckpointInfo.from_hf_repo(
        hf_repo=args.moshi_paths.hf_repo_id,
        moshi_weights=args.moshi_paths.moshi_path,
        mimi_weights=args.moshi_paths.mimi_path,
        tokenizer=args.moshi_paths.tokenizer_path,
        config_path=args.moshi_paths.config_path,
    )

    lm_config = (
        loaders._lm_kwargs
        if checkpoint_info.raw_config is None
        else checkpoint_info.raw_config
    )
    lm_config["lora"] = args.lora.enable
    lm_config["lora_rank"] = args.lora.rank
    lm_config["lora_scaling"] = args.lora.scaling

    # Mimi stays frozen: it is only used to tokenize audio.
    mimi = checkpoint_info.get_mimi(device="cuda")
    mimi.eval()
    for p in mimi.parameters():
        p.requires_grad = False

    # 4.2 Load and shard model, prepare interleaver for audio/text tokens.
    model = get_fsdp_model(args, checkpoint_info)

    # Restore adapter weights when resuming. On failure, skip the optimizer and
    # scheduler state too, so everything consistently starts from scratch.
    if resume_checkpoint:
        success = load_lora_checkpoint(model, resume_checkpoint)
        if not success:
            main_logger_info("Failed to load checkpoint, starting from scratch")
            resume_checkpoint = None

    spm = checkpoint_info.get_text_tokenizer()

    interleaver = Interleaver(
        spm,
        mimi.frame_rate,
        model.text_padding_token_id,
        model.end_of_text_padding_id,
        model.zero_token_id,
        keep_main_only=True,
    )
    interleaved_tokenizer = InterleavedTokenizer(
        mimi, interleaver, duration_sec=args.duration_sec
    )

    # 5. Load data loaders
    data_loader = build_data_loader(
        instruct_tokenizer=interleaved_tokenizer,
        args=args.data,
        batch_size=args.batch_size,
        seed=args.seed,
        rank=get_rank(),  # DDP rank
        world_size=get_world_size(),  # DDP world_size
        is_eval=False,
    )

    # French loader: identical settings and seed; only the JSONL file differs.
    # NOTE(review): the loss pairs the i-th `args.data` batch with the i-th French
    # batch. This assumes both files list the same samples in the same order and
    # chunk identically at `duration_sec` — TODO confirm.
    french_data_args = copy.deepcopy(args.data)
    french_data_args.train_data = FRENCH_DATA_PATH
    main_logger_info(f"French data args: {french_data_args}")

    french_data_loader = build_data_loader(
        instruct_tokenizer=interleaved_tokenizer,
        args=french_data_args,
        batch_size=args.batch_size,
        seed=args.seed,
        rank=get_rank(),  # DDP rank
        world_size=get_world_size(),  # DDP world_size
        is_eval=False,
    )

    if args.do_eval:
        eval_data_loader = build_data_loader(
            instruct_tokenizer=interleaved_tokenizer,
            args=args.data,
            batch_size=args.batch_size,
            seed=None,
            rank=get_rank(),  # DDP rank
            world_size=get_world_size(),  # DDP world_size
            is_eval=True,
        )

    # 6. Mixed-precision dtypes.
    # `param_dtype` is used for forward/backward. `optim_dtype` is handed to the
    # finetune.mixed_precision helpers, presumably as the dtype of the master copy
    # AdamW updates — TODO confirm. It must be fp32: bf16 keeps only ~3
    # significant digits, so updates far smaller than a weight's magnitude
    # (e.g. with lr=5e-7) are rounded away and those weights never change.
    param_dtype = getattr(torch, args.param_dtype)
    optim_dtype = torch.float32

    # Cast LoRA adapters to the training param dtype. `Module.to` works in place.
    # This assumes `lora_A`/`lora_B` are submodules; were they plain tensors, the
    # discarded result would make this a no-op — TODO confirm.
    for module in model.modules():
        if isinstance(module, LoRALinear):
            module.lora_A.to(dtype=param_dtype)
            module.lora_B.to(dtype=param_dtype)

    assert args.lora is not None, "`args.lora` should be set to a valid value."

    # 7. Load optimizer
    optimizer = AdamW(
        model.parameters(),
        lr=args.optim.lr,
        betas=(0.9, 0.95),
        eps=1e-08,
        weight_decay=args.optim.weight_decay,
    )

    # OneCycleLR: warm up to `optim.lr` over the first `pct_start` of training, then anneal.
    scheduler = lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=args.optim.lr,
        total_steps=args.max_steps,
        pct_start=args.optim.pct_start,
    )

    state = TrainState(args.max_steps)

    # Expose to the debugging cells of this notebook.
    the_model = model
    the_optimizer = optimizer

    # Load training state if resuming
    if resume_checkpoint:
        success = load_training_state(resume_checkpoint, state, optimizer, scheduler, param_dtype)
        if success:
            # `Optimizer.load_state_dict` casts floating-point state to each
            # param's current dtype. Re-cast the moments to `optim_dtype`, which
            # is presumably the dtype `upcast_mixed_precision` gives params during
            # `optimizer.step()`. `step` counters are left as loaded.
            for param_state in optimizer.state.values():
                for key, value in list(param_state.items()):
                    if key != "step" and torch.is_tensor(value) and value.is_floating_point():
                        param_state[key] = value.to(optim_dtype)
            main_logger_info(f"Successfully resumed training from step {state.step}")
        else:
            main_logger_info("Failed to load training state, starting from scratch")

    # 8. Initialize checkpointer
    if args.do_ckpt:
        checkpointer = Checkpointer(
            model=model,
            state=state,
            config=lm_config,
            run_dir=run_dir,
            optimizer=optimizer,
            num_ckpt_keep=args.num_ckpt_keep,
            full_finetuning=args.full_finetuning,
        )

    # 9. Prepare mixed precision.
    # The optimizer must keep referencing the model's own Parameter objects.
    # `p.to(dtype)` returns a *new* tensor whenever the dtype changes, so
    # re-assigning param groups with it would leave AdamW stepping on detached
    # copies. Dtype switching is left entirely to the mixed-precision helpers.
    prepare_mixed_precision(
        model.parameters(), param_dtype=param_dtype, optim_dtype=optim_dtype
    )

    # 10. train!
    model.train()
    torch.cuda.empty_cache()

    while state.step < args.max_steps:
        state.start_step()
        is_last_step = state.step == args.max_steps

        optimizer.zero_grad()

        loss = torch.tensor([0.0], device="cuda")
        n_batch_tokens: int = 0
        n_real_tokens: int = 0

        for i in range(args.num_microbatches):
            # The `args.data` batch is the model input; the French batch supplies the targets.
            batch = next(data_loader)
            codes = batch.codes

            fr_batch = next(french_data_loader)
            fr_codes = fr_batch.codes

            # NOTE(review): this branch passes *prepared* conditions, whereas
            # `get_condition_tensors` also runs them through the condition
            # provider. TODO confirm which form `model(condition_tensors=...)` expects.
            condition_tensors = None
            if batch.condition_attributes is not None:
                condition_tensors = model.condition_provider.prepare(
                    batch.condition_attributes
                )

            if condition_tensors is None:
                condition_tensors = get_condition_tensors(model, args.batch_size, 1.0)

            # forward / backward
            output = model(codes=codes, condition_tensors=condition_tensors)

            # Targets: French text stream, then the first `dep_q` audio codebooks.
            text_loss_component = _mask_to_ignore_index(fr_codes[:, : model.audio_offset])
            audio_loss_component = _mask_to_ignore_index(
                fr_codes[:, model.audio_offset : model.audio_offset + model.dep_q]
            )

            text_loss = compute_loss_with_mask(
                output.text_logits,
                text_loss_component,
                output.text_mask,
                mode="text",
                text_padding_weight=args.text_padding_weight,
                text_padding_ids={
                    model.text_padding_token_id,
                    model.end_of_text_padding_id,
                },
            )

            audio_loss = compute_loss_with_mask(
                output.logits,
                audio_loss_component,
                output.mask,
                mode="audio",
                first_codebook_weight_multiplier=args.first_codebook_weight_multiplier,
            )

            mb_loss = text_loss + audio_loss
            mb_loss.backward()

            loss += mb_loss.detach()
            n_batch_tokens += output.text_mask.numel() + output.mask.numel()
            n_real_tokens += (
                torch.sum(output.text_mask).item() + torch.sum(output.mask).item()
            )

            if i < args.num_microbatches - 1:
                # synchronize CUDA to re-run backward
                assert args.num_microbatches > 1  # should not happen
                torch.cuda.synchronize()

        # Average the logged loss and the accumulated gradients over micro-batches.
        if args.num_microbatches > 1:
            loss /= args.num_microbatches
            for p in model.parameters():
                if p.requires_grad:
                    assert p.grad is not None
                    p.grad.div_(args.num_microbatches)

        # upcast params for optimizer update
        upcast_mixed_precision(model.parameters(), optim_dtype=optim_dtype)

        # clip grad norm
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        # optimizer step
        optimizer.step()

        # downcast params for forward & backward
        downcast_mixed_precision(model.parameters(), param_dtype=param_dtype)

        last_lr = scheduler.get_last_lr()[0]
        scheduler.step()

        # Host sync
        loss_item = loss.item()
        avg_loss = avg_aggregate(loss_item)

        if args.do_eval and ((args.eval_freq > 0 and state.step % args.eval_freq == 0) or is_last_step):
            # write perplexity to state
            evaluate(model, eval_data_loader, state, args)

            eval_logs = get_eval_logs(state.step, avg_loss, state.this_eval_perplexity, state.this_eval_loss,)

            main_logger_info(eval_log_msg(eval_logs))
            eval_logger.log(eval_logs, step=state.step)

        # Timing
        state.end_step(n_batch_tokens)

        if state.step % args.log_freq == 0:
            train_logs = get_train_logs(state, avg_loss, n_real_tokens, last_lr, torch.cuda.max_memory_allocated(), torch.cuda.memory_allocated(), args,)
            main_logger_info(train_log_msg(state, logs=train_logs, loss=avg_loss))
            metrics_logger.log(train_logs, step=state.step)

        if args.do_ckpt and ((args.ckpt_freq > 0 and state.step % args.ckpt_freq == 0) or is_last_step):
            checkpointer.save_checkpoint(save_only_lora=not args.full_finetuning and args.save_adapters, dtype=param_dtype,)

            # Save training state (optimizer, scheduler, random states)
            save_training_state(run_dir, state, optimizer, scheduler, state.step)

    main_logger_info("done!")

## Start Training

In [None]:
# Launch fine-tuning with the YAML written above.
train("trainer.yaml")

In [None]:
# Clean up the distributed environment
# Tear down the default process group created during training, if one exists,
# so the kernel is left clean.
process_group_active = dist.is_initialized()
if process_group_active:
    dist.destroy_process_group()

In [None]:
# # Check first 2 model params
# for i, (name, p) in enumerate(the_model.named_parameters()):
#     if p.requires_grad:
#         print(f"Model param[{i}]: {name}, device={p.device}, dtype={p.dtype}")
#     if i >= 1:  # stop after 2
#         break

# # # Check first 2 optimizer params
# for gi, group in enumerate(the_optimizer.param_groups):
#     # for pi, p in enumerate(group['params']):
#     #     if p.grad is not None or p.requires_grad:
#     #         print(f"Optimizer param[{gi}:{pi}]: device={p.device}, dtype={p.dtype}")
#     #     if pi >= 1:  # stop after 2 per group
#     #         break
#     # if gi >= 0:  # only first group
#     #     break
#     for p in group['params']:
#         print(p.dtype, p.requires_grad)

In [None]:
# Update optimizer param type.

# for gi, group in enumerate(the_optimizer.param_groups):
#     for pi, p in enumerate(group['params']):
#         if p.requires_grad:
#             group['params'][pi] = p.to(torch.bfloat16)  # re-assign in param group

In [None]:
# for gi, group in enumerate(the_optimizer.param_groups):
#     # for pi, p in enumerate(group['params']):
#     #     if p.grad is not None or p.requires_grad:
#     #         print(f"Optimizer param[{gi}:{pi}]: device={p.device}, dtype={p.dtype}")
#     #     if pi >= 1:  # stop after 2 per group
#     #         break
#     # if gi >= 0:  # only first group
#     #     break
#     for p in group['params']:
#         print(p.dtype, p.requires_grad)
#         # p.to(torch.bfloat16)

### Debug Mode

In [None]:
# """
# dict_keys(['en_codes', 'fr_codes', 'text_logits', 'text_loss_component', 'text_mask', 'text_padding_weight', 
#     'text_padding_token_id', 'end_of_text_padding_id', 'logits', 'audio_loss_component', 'output_mask', 
#     'first_codebook_weight_multiplier'])
# """


# gcs = gen_codes[-2]

# with torch.enable_grad():
#     text_loss = compute_loss_with_mask(
#         gcs['text_logits'],
#         gcs['text_loss_component'],
#         gcs['text_mask'],
#         mode="text",
#         text_padding_weight=gcs['text_padding_weight'],
#         text_padding_ids={
#             gcs['text_padding_token_id'],
#             gcs['end_of_text_padding_id'],
#         },
#     )
#     audio_loss = compute_loss_with_mask(
#         gcs['logits'],
#         gcs['audio_loss_component'],
#         gcs['output_mask'],
#         mode="audio",
#         first_codebook_weight_multiplier=gcs['first_codebook_weight_multiplier'],
#     )
    
#     print(text_loss)
#     print(audio_loss)