In [1]:
!pip -q install transformers accelerate bitsandbytes trl mlflow boto3

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m80.4/80.4 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m60.1/60.1 MB[0m [31m14.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m423.1/423.1 kB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.8/8.8 MB[0m [31m102.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 MB[0m [31m88.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m55.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m139.3/139.3 kB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m14.1/14.1 MB[0m [31m121.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
import os
import torch
from datasets import load_dataset, Dataset, concatenate_datasets
from transformers import (
    AutoModelForCausalLM,          # AutoModel for language modeling tasks
    AutoTokenizer,                # AutoTokenizer for tokenization
    BitsAndBytesConfig,           # Configuration for BitsAndBytes
    TrainingArguments,            # Training arguments for model training
    TrainerCallback
)
from peft import LoraConfig, PeftModel,PeftConfig
from trl import SFTTrainer

import pandas as pd
import numpy as np
import os
import logging as log
from datetime import datetime
import matplotlib.pyplot as plt

import mlflow
from mlflow.tracking import MlflowClient

import warnings
warnings.filterwarnings('ignore')

from data_prep import get_dataset, tokenize_and_mask
from peft_lora_config import Peft_Config

def setup_logging():
    # Remove all handlers associated with the root logger object.
    for handler in log.root.handlers[:]:
        log.root.removeHandler(handler)

    log.basicConfig(
        level=log.WARNING,
        format="%(asctime)s [%(levelname)s] %(message)s",
        handlers=[log.StreamHandler()],
    )

logger = log.getLogger(__name__)
setup_logging()

# Dataset EDA

In [3]:
ds = load_dataset("tarudesu/ViHealthQA")
ds_train = ds["train"].to_pandas()
ds_test = ds["test"].to_pandas()

ds_train = pd.concat([ds_train, ds_test])

README.md: 0.00B [00:00, ?B/s]

train.csv: 0.00B [00:00, ?B/s]

val.csv: 0.00B [00:00, ?B/s]

test.csv: 0.00B [00:00, ?B/s]

Generating train split:   0%|          | 0/7009 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/993 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2013 [00:00<?, ? examples/s]

In [4]:
ds_train

Unnamed: 0,id,question,answer,link
0,1,Đang chích ngừa viêm gan B có chích ngừa Covid...,Nếu anh/chị đang tiêm ngừa vaccine phòng bệnh ...,https://vnexpress.net/tu-van-tiem-vaccine-covi...
1,2,"Đau đầu, căng thẳng do công việc, suy giảm trí...",Tình trạng đau đầu theo bạn mô tả thì chưa rõ....,https://www.vinmec.com/vi/suc-khoe-tong-quat/t...
2,3,Đặt lưu lượng khí hệ thống Jackson-Rees thấp h...,Hệ thống Jackson – Rees dùng khi gây mê để trá...,https://www.vinmec.com/vi/suc-khoe-tong-quat/t...
3,4,Bé 13 tháng tuổi uống thuốc Acyclovir có được ...,Acyclovir có thể sử dụng cho cả trẻ dưới 13 th...,https://www.vinmec.com/vi/suc-khoe-tong-quat/t...
4,5,Vừa qua ngày 4/6 tôi có bị con chó ở nhà cắn x...,Bệnh dại là bệnh nguy hiểm và nếu có chỉ định ...,https://vnexpress.net/tu-van-tiem-vaccine-covi...
...,...,...,...,...
2008,2009,Sốt kèm nhức đầu sau khi ngủ dậy là bệnh gì?,"Bạn có biểu hiện sốt, nhức đầu sau khi ngủ dậy...",https://www.vinmec.com/vi/suc-khoe-tong-quat/t...
2009,2010,Trước Tết em đã làm IVF một lần ở một bệnh việ...,"Theo các nghiên cứu, dự trữ buồng trứng của ng...",https://vnexpress.net/tu-van-vo-sinh-hiem-muon...
2010,2011,Bệnh nhân tiền sử tiểu đường tuýp 2 nóng rát t...,Anh đã xuất hiện biến chứng viêm đa thần kinh ...,https://www.vinmec.com/vi/tin-tuc/hoi-dap-bac-...
2011,2012,Cháu 34 tuổi có tiền sử bị dị ứng với đồ ăn nh...,"Với tiền sử như đã nêu, anh nên thực hiện tiêm...",https://vnexpress.net/tu-van-tiem-vaccine-covi...


# Set up PEFT and Lora Configs

In [5]:
config = Peft_Config()

In [6]:
# Step 2 :Load tokenizer and model with QLoRA configuration
compute_dtype = getattr(torch, config.bnb_4bit_compute_dtype)

bnb_config = BitsAndBytesConfig(
    load_in_4bit=config.use_4bit,
    bnb_4bit_quant_type=config.bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=config.use_nested_quant,
)

In [7]:
# Step 3 :Check GPU compatibility with bfloat16
if compute_dtype == torch.float16 and config.use_4bit:
    major, _ = torch.cuda.get_device_capability()
    if major >= 8:
        print("=" * 80)
        print("Your GPU supports bfloat16: accelerate training with bf16=True")
        print("=" * 80)

In [8]:
# Step 4 :Load base model
model = AutoModelForCausalLM.from_pretrained(
    config.model_name,
    quantization_config=bnb_config,
    device_map=config.device_map
)
model.config.use_cache = False
model.config.pretraining_tp = 1

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [9]:
# Step 5 :Load LLaMA tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.model_name, trust_remote_code=True)

special_tokens = {"additional_special_tokens": ["<|user|>", "<|assistant|>"]}
tokenizer.add_special_tokens(special_tokens)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "right"

tokenizer.chat_template = """{% for message in messages %}
{% if message['role'] == 'user' %}
<|user|>
{{ message['content'].strip() }}
{{ eos_token }}
{% elif message['role'] == 'assistant' %}
<|assistant|>
{{ message['content'].strip() }}
{{ eos_token }}
{% endif %}
{% endfor %}
{% if add_generation_prompt %}
<|assistant|>
{% endif %}"""

model.resize_token_embeddings(len(tokenizer))

model.config.pad_token_id = tokenizer.pad_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.eos_token_id = tokenizer.eos_token_id

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [10]:
# Step 6 :Load LoRA configuration
peft_config = LoraConfig(
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    r=config.lora_r,
    bias="none",
    task_type="CAUSAL_LM",
)

## s3 Config

In [11]:
import boto3
from dotenv import load_dotenv
import os

load_dotenv()

S3_BUCKET = "mlflow-artifacts-monitor"
s3_client = boto3.client(
    "s3",
    aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
    aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
    region_name=os.environ["AWS_DEFAULT_REGION"]
)

## mlflow tracking

In [12]:
# mlflow set tracking
url = "https://victoria-communicable-sometimes.ngrok-free.dev"
mlflow.set_tracking_uri(url)
tracking_uri = mlflow.get_tracking_uri()
print(f"Current tracking uri: {tracking_uri}")

Current tracking uri: https://victoria-communicable-sometimes.ngrok-free.dev


In [13]:
mlflow.set_experiment("healthcarechatbot")

<Experiment: artifact_location='mlflow-artifacts:/1', creation_time=1760804990524, experiment_id='1', last_update_time=1760804990524, lifecycle_stage='active', name='healthcarechatbot', tags={}>

## training config

In [14]:
# Step 7 : Set training parameters
training_arguments = TrainingArguments(
    # --- Logging ---
    report_to="mlflow",
    run_name=f"{config.model_name_finetuned}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%s')}",

    # --- Paths & Core training ---
    output_dir=config.output_dir,
    num_train_epochs=config.num_train_epochs,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    optim=config.optim,
    save_steps=config.save_steps,
    logging_steps=config.logging_steps,
    learning_rate=config.learning_rate,
    weight_decay=config.weight_decay,
    fp16=config.fp16,
    bf16=config.bf16,
    max_grad_norm=config.max_grad_norm,
    max_steps=config.max_steps,
    warmup_ratio=config.warmup_ratio,
    group_by_length=config.group_by_length,
    lr_scheduler_type=config.lr_scheduler_type,

    # --- val ---
    eval_steps=200,
    save_strategy="steps",
    eval_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss"
)

In [15]:
ds_train, ds_val = get_dataset()

message = ds_train['messages']

tokenized_chat = tokenizer.apply_chat_template(message, tokenize=True, add_generation_prompt=True, return_tensors="pt")
print(tokenizer.decode(tokenized_chat[0]))

Map:   0%|          | 0/7009 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

<|assistant|> 



In [16]:
ds_train

Dataset({
    features: ['messages'],
    num_rows: 7009
})

In [17]:
ds_train[0]["messages"]

[{'content': 'Đang chích ngừa viêm gan B có chích ngừa Covid-19 được không?',
  'role': 'user'},
 {'content': 'Nếu anh/chị đang tiêm ngừa vaccine phòng bệnh viêm gan B, anh/chị vẫn có thể tiêm phòng vaccine phòng Covid-19, tuy nhiên vaccine Covid-19 phải được tiêm cách trước và sau mũi vaccine viêm gan B tối thiểu là 14 ngày.',
  'role': 'assistant'}]

In [18]:
from mlflow.models import infer_signature

sample = ds_train[0]["messages"]

# MLflow infers schema from the provided sample input/output/params
signature = infer_signature(
  model_input=sample[0]["content"],
  model_output=sample[1]["content"],
  # Parameters are saved with default values if specified
  params={"max_new_tokens": 256, "repetition_penalty": 1.15, "return_full_text": False},
)

signature

inputs: 
  [string (required)]
outputs: 
  [string (required)]
params: 
  ['max_new_tokens': long (default: 256), 'repetition_penalty': double (default: 1.15), 'return_full_text': boolean (default: False)]

In [19]:
max_length = 1024

# Apply mapping (non-batched for simplicity; batched mapping can be used for speed)
tokenized = ds_train.map(
    lambda example: tokenize_and_mask(example, tokenizer, max_length), remove_columns=["messages"])
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
train_dataset = tokenized

tokenized = ds_val.map(
    lambda example: tokenize_and_mask(example, tokenizer, max_length), remove_columns=["messages"])
tokenized.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
val_dataset = tokenized

Map:   0%|          | 0/7009 [00:00<?, ? examples/s]

Map:   0%|          | 0/200 [00:00<?, ? examples/s]

In [20]:
train_dataset[0]

{'input_ids': tensor([    1, 32000, 29871,  ...,     2,     2,     2]),
 'attention_mask': tensor([1, 1, 1,  ..., 0, 0, 0]),
 'labels': tensor([-100, -100, -100,  ...,    2,    2,    2])}

## Training

In [21]:
from transformers import EarlyStoppingCallback

# Step 8 :Set supervised fine-tuning parameters
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    peft_config=peft_config,
    args=training_arguments,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=1)]
)

Truncating train dataset:   0%|          | 0/7009 [00:00<?, ? examples/s]

Truncating eval dataset:   0%|          | 0/200 [00:00<?, ? examples/s]

In [22]:
class MLflowLossCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        if logs is not None:
            for k, v in logs.items():
                if isinstance(v, (int, float)):
                    mlflow.log_metric(k, v, step=state.global_step)

trainer.add_callback(MLflowLossCallback)

In [23]:
with mlflow.start_run() as run:
    trainer.train()

    # --- Step 9: Log training loss curve ---
    history = trainer.state.log_history
    steps = [h["step"] for h in history if "loss" in h]
    losses = [h["loss"] for h in history if "loss" in h]

    for step, loss in zip(steps, losses):
        mlflow.log_metric("train_loss", loss, step=step)

    # --- Step 10: Save trained adapter model (checkpoint) ---
    checkpoint_dir = "checkpoint_model"
    trainer.model.save_pretrained(checkpoint_dir)
    tokenizer.save_pretrained(checkpoint_dir)

    # --- Step 11: Merge and save final model with adapter ---
    model = trainer.model
    if hasattr(model, "merge_and_unload"):
        base_model = AutoModelForCausalLM.from_pretrained(
            config.model_name,
            dtype="float32",
            device_map=config.device_map
        )
        base_model.resize_token_embeddings(len(tokenizer))

        model = PeftModel.from_pretrained(base_model, checkpoint_dir)
        model = model.merge_and_unload()

        model.config.pad_token_id = tokenizer.pad_token_id
        model.config.eos_token_id = tokenizer.eos_token_id
        model.config.bos_token_id = tokenizer.bos_token_id

    merged_dir = "merged_model"
    model.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)

    # --- Step 12: Upload merged model to S3 ---
    s3_client = boto3.client("s3")
    bucket = "mlflow-artifacts-monitor"
    s3_prefix = f"models/health-llm/{run.info.run_id}"

    for root, _, files in os.walk(merged_dir):
        for f in files:
            path = os.path.join(root, f)
            key = f"{s3_prefix}/{os.path.relpath(path, merged_dir)}"
            s3_client.upload_file(path, bucket, key)

    model_uri = f"s3://{bucket}/{s3_prefix}"

    # --- Step 13: Register the model metadata in MLflow ---
    REGISTERED_MODEL_NAME = "health-llm"

    result = mlflow.register_model(
        model_uri=model_uri,
        name=REGISTERED_MODEL_NAME
    )

    # --- Step 14: Update metadata and tags ---
    client = MlflowClient()

    client.set_registered_model_tag(
        name=REGISTERED_MODEL_NAME, key="use_case", value="patient_service"
    )

    client.update_registered_model(
        name=REGISTERED_MODEL_NAME,
        description="A health-specific chatbot about daily Vietnamese sickness questions"
    )

    client.set_model_version_tag(
        name=REGISTERED_MODEL_NAME,
        version=result.version,
        key="validation_status",
        value="testing",
    )

    # --- Step 15: Create alias for easier reference ---
    client.set_registered_model_alias(
        name=REGISTERED_MODEL_NAME,
        alias="champion",
        version=result.version,
    )

    print(f"Model registered successfully: version {result.version}")
    print(f"S3 path: {model_uri}")
    print(f"MLflow tracking: {run.info.run_id}")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


Step,Training Loss,Validation Loss,Entropy,Num Tokens,Mean Token Accuracy
200,0.6146,0.584878,0.697266,1638400.0,0.860943
400,0.5294,0.541673,0.682713,3276800.0,0.867548
600,0.5505,0.530657,0.680998,4915200.0,0.869088
800,0.5328,0.528639,0.680757,6553600.0,0.86934


Registered model 'health-llm' already exists. Creating a new version of this model...
2025/10/19 17:06:59 INFO mlflow.store.model_registry.abstract_store: Waiting up to 300 seconds for model version to finish creation. Model name: health-llm, version 3
Created version '3' of model 'health-llm'.


Model registered successfully: version 3
S3 path: s3://mlflow-artifacts-monitor/models/health-llm/cee9a10db9eb4f1e99bbf0848ddf86fa
MLflow tracking: cee9a10db9eb4f1e99bbf0848ddf86fa
🏃 View run unruly-smelt-423 at: https://victoria-communicable-sometimes.ngrok-free.dev/#/experiments/1/runs/cee9a10db9eb4f1e99bbf0848ddf86fa
🧪 View experiment at: https://victoria-communicable-sometimes.ngrok-free.dev/#/experiments/1
