# Transformers utils for experiments

In [None]:
#|default_exp hf.transformers.experiment

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export
from copy import deepcopy
import torch
from bellek.utils import NestedDict, generate_time_id
from bellek.logging import get_logger

# Module-level logger used by the helpers exported from this notebook.
log = get_logger(__name__)

In [None]:
#|export

def _precision_settings():
    """Choose float-precision settings for the current CUDA device.

    Returns a ``(torch_dtype, bf16, fp16, bnb_4bit_compute_dtype)`` tuple, or
    ``None`` when CUDA is unavailable (``torch.cuda.get_device_capability``
    would raise in that case).
    """
    if not torch.cuda.is_available():
        return None
    major, _ = torch.cuda.get_device_capability()
    # Devices with compute capability >= 8.0 are treated as bfloat16-capable.
    if major >= 8:
        log.info("GPU supports bfloat16.")
        return "bfloat16", True, False, "bfloat16"
    log.info("GPU does not support bfloat16.")
    return "float16", False, True, "float16"


def preprocess_config(config: NestedDict):
    """Return a deep copy of ``config`` adjusted for the current hardware and run.

    Precision: when CUDA is available, the float settings are aligned with the
    GPU. Devices with compute capability >= 8 get bfloat16; others get float16.
    The following keys are updated:

    - ``pretrained_model.torch_dtype``, only if it is already set;
    - ``pretrained_model.quantization_config.bnb_4bit_compute_dtype``, only if
      ``load_in_4bit`` is enabled;
    - ``trainer.training_args.bf16`` / ``fp16``, only if either flag is
      already enabled.

    When CUDA is not available these keys are left unchanged (a warning is
    logged) instead of failing on the device-capability query.

    Model id: ``hfhub.model_id`` gets ``-peft`` appended when ``trainer.lora``
    is set. It also gets a time-based suffix from ``generate_time_id()``
    unless the id contains ``"debug"``. ``hfhub.model_id`` must be set; a
    missing id raises ``TypeError`` on the string concatenation.

    Args:
        config: Experiment configuration, accessed with dotted key paths.

    Returns:
        The updated copy; the caller's ``config`` is not modified.
    """
    config = deepcopy(config)

    # Set float precision
    precision = _precision_settings()
    if precision is None:
        log.warning("CUDA is not available; leaving precision settings unchanged.")
    else:
        torch_dtype, bf16, fp16, bnb_4bit_compute_dtype = precision
        # Only override dtypes the config opted into; never introduce new ones.
        if config.at("pretrained_model.torch_dtype"):
            config.set("pretrained_model.torch_dtype", torch_dtype)
        if config.at("pretrained_model.quantization_config.load_in_4bit"):
            config.set("pretrained_model.quantization_config.bnb_4bit_compute_dtype", bnb_4bit_compute_dtype)
        # Mixed precision was requested: make the bf16/fp16 flags match the device.
        if config.at("trainer.training_args.bf16") or config.at("trainer.training_args.fp16"):
            config.set("trainer.training_args.bf16", bf16)
            config.set("trainer.training_args.fp16", fp16)

    # Generate unique model id
    model_id = config.at("hfhub.model_id")
    if config.at("trainer.lora"):
        model_id += "-peft"
    # Ids containing "debug" keep a stable name (no time suffix);
    # presumably so debug runs reuse one hub repo — confirm with callers.
    if "debug" not in model_id:
        model_id += f"-{generate_time_id()}"
    config.set("hfhub.model_id", model_id)

    return config


In [None]:
#|hide
# Run nbdev's export so this notebook's `#|export` cells are written to the library package.
import nbdev; nbdev.nbdev_export()