In [0]:
# Notebook parameters, declared as text widgets with an empty default value.
# Each entry is (widget name, label shown in the widget bar).
_WIDGET_SPECS = (
    ("BASE_MODEL_NAME", "Base model"),
    ("NEW_MODEL_NAME", "Train name"),
    ("EPOCHS", "Train epochs"),
    ("LEARNING_RATE", "Train learning rate"),
    ("DATASET_PATH", "Dataset path"),
)

for _widget_name, _widget_label in _WIDGET_SPECS:
    dbutils.widgets.text(_widget_name, "", _widget_label)

In [0]:
!pip install -q torch transformers datasets peft trl bitsandbytes accelerate tqdm torchvision

%restart_python


In [0]:
import os
import multiprocessing


# Get the number of available CPU threads.
# os.cpu_count() returns None when the count cannot be determined. In that same
# situation multiprocessing.cpu_count() raises NotImplementedError, so it cannot
# act as a fallback; default to a single thread instead.
num_threads = os.cpu_count() or 1
print(f"Number of available threads: {num_threads}")

# Set the number of threads for OpenMP, MKL and related native libraries.
# This is done before torch/transformers are imported in later cells.
# NOTE(review): OpenBLAS documents OPENBLAS_NUM_THREADS; confirm that
# BLAS_NUM_THREADS is actually read by a library used here.
for _thread_var in (
    "OMP_NUM_THREADS",
    "MKL_NUM_THREADS",
    "NUMEXPR_NUM_THREADS",
    "BLAS_NUM_THREADS",
):
    os.environ[_thread_var] = f"{num_threads}"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Set cache directory
CACHE_DIR = "/tmp/hf_cache"
# os.environ["HF_HOME"] = CACHE_DIR
# os.environ["HF_DATASETS_CACHE"] = CACHE_DIR


def _required_widget(name):
    """Return the stripped value of notebook widget ``name``.

    Raises:
        ValueError: if the widget is empty or whitespace-only. The widgets are
            declared with an empty default, so a forgotten parameter would
            otherwise surface later as an unrelated-looking error.
    """
    value = dbutils.widgets.get(name).strip()
    if not value:
        raise ValueError(f"Notebook parameter '{name}' must be provided")
    return value


# Model names
BASE_MODEL_NAME = _required_widget("BASE_MODEL_NAME")
NEW_MODEL_NAME = _required_widget("NEW_MODEL_NAME")
EPOCHS = int(_required_widget("EPOCHS"))
LEARNING_RATE = float(_required_widget("LEARNING_RATE"))

DATASET_PATH = _required_widget("DATASET_PATH")
# Directory where the merged model and tokenizer are written after training.
SAVE_MODEL_PATH = f"/Volumes/mlops/default/mlops_volume/{NEW_MODEL_NAME}"


In [0]:
from datasets import load_dataset


VOLUME = "/Volumes/mlops/default/mlops_volume"

# Stage the dataset inside the volume under its original file name. The copy
# destination is the full file path (not just the directory) so that the file
# loaded below is exactly the file that was written.
LOCAL_PATH = f"{VOLUME}/{os.path.basename(DATASET_PATH)}"
dbutils.fs.cp(DATASET_PATH, LOCAL_PATH)

print(f"Loading dataset: {LOCAL_PATH}")
# The "json" builder exposes everything passed via data_files as the "train" split.
dataset = load_dataset(
    "json",
    data_files=LOCAL_PATH,
    split="train",
    cache_dir=CACHE_DIR
)

# Fail here rather than deep inside the trainer when there is nothing to train on.
if len(dataset) == 0:
    raise ValueError(f"Dataset '{LOCAL_PATH}' contains no records")

print(f"Using {len(dataset):,} records for training")

In [0]:
import mlflow


print(f"Loading base model: {BASE_MODEL_NAME}")

# Point MLflow's model registry at Databricks before resolving the base model.
mlflow.set_registry_uri("databricks")

# Keyword arguments passed through to the loader: load on CPU with reduced
# peak host-memory usage.
_load_options = {"device_map": "cpu", "low_cpu_mem_usage": True}
loaded_model_pipeline = mlflow.transformers.load_model(BASE_MODEL_NAME, **_load_options)

# Unpack the two components of the loaded pipeline that the training cells use.
model, tokenizer = loaded_model_pipeline.model, loaded_model_pipeline.tokenizer

In [0]:
def format_instruction(example):
    """Render one dataset record as a chat-formatted training string.

    Builds a three-turn conversation (system instruction, user request that
    embeds the record's docstring, assistant reply containing the record's
    code) and renders it with the module-level ``tokenizer``'s chat template.

    Args:
        example: mapping with the keys ``'docstring'`` and ``'code'``.

    Returns:
        The value returned by ``tokenizer.apply_chat_template`` when called
        with ``tokenize=False`` and ``add_generation_prompt=False``.
    """
    system_turn = {
        "role": "system",
        "content": "You are a coding assistant. Your task is to write a Python function that matches the given documentation.",
    }

    docstring = example['docstring']
    user_turn = {
        "role": "user",
        "content": f"Write a Python function for the following documentation:\n\n```python\n{docstring}\n```",
    }

    assistant_turn = {"role": "assistant", "content": example['code']}

    conversation = [system_turn, user_turn, assistant_turn]
    # The assistant answer is already part of the conversation, so no
    # generation prompt is appended after it.
    return tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
    )


In [0]:
from peft import LoraConfig
from trl import SFTConfig, SFTTrainer


# LoRA adapter configuration for causal-LM fine-tuning.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=4,
    bias="none",
    task_type="CAUSAL_LM",
)

# The install cell does not pin trl, and newer TRL releases renamed SFTConfig's
# `max_seq_length` field to `max_length`. Pick whichever field the installed
# version declares so building the config does not fail with an
# unexpected-keyword TypeError.
_seq_len_field = (
    "max_length"
    if "max_length" in SFTConfig.__dataclass_fields__
    else "max_seq_length"
)

training_arguments = SFTConfig(
    # Checkpoints are written here, relative to the current working directory.
    output_dir="./trainer_progress",

    # max_steps=5,
    per_device_train_batch_size=16,

    optim="adamw_torch",
    logging_strategy="steps",
    logging_steps=1,
    save_steps=20,
    learning_rate=LEARNING_RATE,

    num_train_epochs=EPOCHS,
    gradient_accumulation_steps=1,
    weight_decay=0.0,
    # Mixed precision is disabled; the base model is loaded with device_map="cpu".
    fp16=False,
    bf16=False,
    max_grad_norm=1.0,
    # Constant learning rate with no warmup.
    warmup_ratio=0.0,
    lr_scheduler_type="constant",
    # Sequence-length limit of 128 tokens, passed under the version-appropriate name.
    **{_seq_len_field: 128},
    packing=False,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    report_to="mlflow",

    ddp_find_unused_parameters=False
)

# Supervised fine-tuning trainer: wraps the base model with the LoRA adapter and
# renders each record to text with format_instruction.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
    args=training_arguments,
    formatting_func=format_instruction
)

In [0]:
# Run supervised fine-tuning with the trainer built in the previous cell.
# resume_from_checkpoint=False is passed explicitly — presumably so that
# checkpoints left in the output directory by an earlier run are not picked up
# (confirm against the intended job behaviour).
print("Launching training...")
trainer.train(resume_from_checkpoint=False)

In [0]:
import json

# Fold the trained LoRA adapter weights into the base model so the result can be
# saved and served as a plain model.
print("Merging adapter weights with the base model...")
merged_model = trainer.model.merge_and_unload()

# Write the merged model and its tokenizer side by side in the same directory.
print("Saving model...")
merged_model.save_pretrained(SAVE_MODEL_PATH)
tokenizer.save_pretrained(SAVE_MODEL_PATH)

# Log the model
# The saved directory is logged as a text-generation transformers model and
# registered under NEW_MODEL_NAME in the same call.
with mlflow.start_run() as run:
    model_info = mlflow.transformers.log_model(
        transformers_model=SAVE_MODEL_PATH,
        artifact_path="model",
        registered_model_name=NEW_MODEL_NAME,
        task="text-generation"
    )

model_version = model_info.registered_model_version
print(f"Model successfully registered as '{NEW_MODEL_NAME}' with version: {model_version}")

# Return the model version
# dbutils.notebook.exit accepts a string, so the result is serialised as JSON;
# a calling notebook or job can recover the fields with json.loads.
dbutils.notebook.exit(json.dumps({"model": NEW_MODEL_NAME, "version": model_version}))