# Transformers utils

In [None]:
#|default_exp ml.transformers

In [None]:
#|hide
from fastcore.test import *
from nbdev.showdoc import *

In [None]:
#|export
from copy import deepcopy
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from bellek.utils import NestedDict
from bellek.logging import get_logger

# Module-level logger shared by the functions below, named after this module.
log = get_logger(__name__)

In [None]:
#|export

def merge_adapters_and_publish(
    model_id: str,
    torch_dtype=torch.float16,
    device_map={"": 0},
    merged_model_id: str=None,
):
    """Merge a PEFT model's adapters into its base model and push the result to the hub.

    Args:
        model_id: Identifier passed to `AutoPeftModelForCausalLM.from_pretrained`
            and `AutoTokenizer.from_pretrained`.
        torch_dtype: Either a dtype object, the string "auto", or the name of
            an attribute on `torch` (e.g. "float16"), which is resolved with
            `getattr(torch, ...)`.
        device_map: Forwarded unchanged to `from_pretrained`.
        merged_model_id: Target repository id. When `None`, defaults to
            `f"{model_id}-merged"`.

    Returns:
        The repository id the merged model and tokenizer were pushed to.
    """
    from peft import AutoPeftModelForCausalLM

    # Only non-"auto" strings are dtype names; anything else is passed through as-is.
    dtype_given_by_name = isinstance(torch_dtype, str) and torch_dtype != "auto"
    resolved_dtype = getattr(torch, torch_dtype) if dtype_given_by_name else torch_dtype

    log.info(f"Loading model and tokenizer for {model_id}")
    peft_model = AutoPeftModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=resolved_dtype,
        device_map=device_map,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    log.info("Merging adapters to model...")
    merged_model = peft_model.merge_and_unload()

    target_repo_id = f"{model_id}-merged" if merged_model_id is None else merged_model_id
    log.info(f"Pushing merged model to HF hub as {target_repo_id}")
    # Model first, then tokenizer, both into the same repository.
    for artifact in (merged_model, tokenizer):
        artifact.push_to_hub(target_repo_id)
    return target_repo_id

In [None]:
#|export

def load_tokenizer_model(
    model_name_or_path: str,
    *,
    auto_model_cls=None,
    device_map={"": 0},
    **model_kwargs,
):
    """Load a causal language model together with its tokenizer.

    Args:
        model_name_or_path: Identifier passed to `auto_model_cls.from_pretrained`.
        auto_model_cls: Loader class to use. When `None`, it is
            `AutoPeftModelForCausalLM` if "-peft" occurs in
            `model_name_or_path`, otherwise `AutoModelForCausalLM`.
            Only these two classes are supported.
        device_map: Forwarded unchanged to `from_pretrained`.
        **model_kwargs: Extra keyword arguments for `from_pretrained`.
            Two keys are normalized first:
            - `quantization_config`: a non-empty `dict` is converted to a
              `BitsAndBytesConfig`.
            - `torch_dtype`: a string other than "auto" is resolved to the
              attribute of that name on `torch`; dtype objects and "auto"
              are passed through untouched.

    Returns:
        A `(tokenizer, model)` tuple. For a plain model the tokenizer is
        loaded from `model_name_or_path`; for a PEFT model it is loaded from
        the active PEFT config's `base_model_name_or_path`.

    Raises:
        ValueError: If `auto_model_cls` is neither `AutoModelForCausalLM`
            nor `AutoPeftModelForCausalLM`.
    """
    if auto_model_cls is None:
        if "-peft" in model_name_or_path:
            from peft import AutoPeftModelForCausalLM
            auto_model_cls = AutoPeftModelForCausalLM
        else:
            auto_model_cls = AutoModelForCausalLM

    # Validate the loader class up front so an unsupported class is rejected
    # before any model weights are loaded.
    is_peft_model = auto_model_cls != AutoModelForCausalLM
    if is_peft_model:
        from peft import AutoPeftModelForCausalLM
        if auto_model_cls != AutoPeftModelForCausalLM:
            raise ValueError(f"Unknown auto_model_cls: {auto_model_cls}")

    # Setup quantization config
    quantization_config = model_kwargs.get("quantization_config")
    if quantization_config and isinstance(quantization_config, dict):
        from transformers import BitsAndBytesConfig
        model_kwargs["quantization_config"] = BitsAndBytesConfig(**quantization_config)

    # Setup torch dtype: only dtype *names* need resolving. A torch.dtype
    # object must be passed through, since getattr() with a non-string
    # attribute name raises TypeError.
    torch_dtype = model_kwargs.get("torch_dtype")
    if isinstance(torch_dtype, str) and torch_dtype not in ("", "auto"):
        model_kwargs["torch_dtype"] = getattr(torch, torch_dtype)

    # Load model
    model = auto_model_cls.from_pretrained(
        model_name_or_path,
        device_map=device_map,
        **model_kwargs,
    )

    # Load tokenizer: a PEFT model takes it from the base model named in the
    # active PEFT config rather than from model_name_or_path.
    if is_peft_model:
        tokenizer_id = model.active_peft_config.base_model_name_or_path
    else:
        tokenizer_id = model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, trust_remote_code=True)
    return tokenizer, model

In [None]:
#|export

def preprocess_config(config: NestedDict):
    """Return a copy of `config` with precision settings matched to the current GPU.

    The input is deep-copied and never modified. The CUDA compute capability
    of the current device decides between bfloat16 (major version >= 8) and
    float16. A setting is only rewritten when the config already holds a
    truthy value at the corresponding path:

    - `pretrained_model.torch_dtype` is replaced by the chosen dtype name.
    - `pretrained_model.quantization_config.bnb_4bit_compute_dtype` is set
      when `...quantization_config.load_in_4bit` is truthy.
    - `trainer.training_args.bf16` / `fp16` are both set when either one is
      truthy, so that exactly one of them ends up `True`.

    Returns:
        The adjusted copy of the config.
    """
    cfg = deepcopy(config)

    # Set float precision
    capability_major = torch.cuda.get_device_capability()[0]
    use_bf16 = capability_major >= 8
    if use_bf16:
        log.info("GPU supports bfloat16.")
    else:
        log.info("GPU does not support bfloat16.")
    # The model dtype and the 4-bit compute dtype always share the same name.
    dtype_name = "bfloat16" if use_bf16 else "float16"

    if cfg.at("pretrained_model.torch_dtype"):
        cfg.set("pretrained_model.torch_dtype", dtype_name)

    if cfg.at("pretrained_model.quantization_config.load_in_4bit"):
        cfg.set("pretrained_model.quantization_config.bnb_4bit_compute_dtype", dtype_name)

    mixed_precision_requested = (
        cfg.at("trainer.training_args.bf16") or cfg.at("trainer.training_args.fp16")
    )
    if mixed_precision_requested:
        cfg.set("trainer.training_args.bf16", use_bf16)
        cfg.set("trainer.training_args.fp16", not use_bf16)

    return cfg

In [None]:
#|hide
# Notebook-only cell (marked hidden above): calls nbdev's export step when run.
import nbdev; nbdev.nbdev_export()