## 1. Import libraries

In [7]:
from datasets import load_dataset
from unsloth import FastLanguageModel
from huggingface_hub import login
from dotenv import load_dotenv
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported

## 2. Load dataset

In [4]:
ds = load_dataset("openlifescienceai/medmcqa")
del ds["test"]

### 2.1. Format data

In [None]:
data_prompt = """Choose the correct option for the following question.

### Question:
{}

### Choice:
{}

### Answer:
"""

# Mapping chỉ số sang nhãn
id2label = {
    0: 'A',
    1: 'B',
    2: 'C',
    3: 'D'
}

# Hàm xử lý dữ liệu và tạo prompt
def formatting_prompt(examples):
    questions = examples["question"]
    opas = examples["opa"]
    opbs = examples["opb"]
    opcs = examples["opc"]
    opds = examples["opd"]
    cops = examples["cop"]

    texts = []
    for idx in range(len(questions)):
        question = questions[idx]
        opa = opas[idx]
        opb = opbs[idx]
        opc = opcs[idx]
        opd = opds[idx]
        answer = id2label[cops[idx]]

        # Thêm đáp án đúng vào phần trả lời
        if answer == "A":
            answer += " " + opa
        elif answer == "B":
            answer += " " + opb
        elif answer == "C":
            answer += " " + opc
        elif answer == "D":
            answer += " " + opd

        # Gộp các lựa chọn thành một chuỗi
        choices = f"A. {opa} B. {opb} C. {opc} D. {opd}"
        text = data_prompt.format(question, choices)
        texts.append(text)

    return {"text": texts}

# Áp dụng hàm xử lý lên tập dữ liệu
process_ds = ds.map(formatting_prompt, batched=True)

## 3. Load pre-trained model

In [5]:
# Thiết lập độ dài chuỗi tối đa
max_seq_length = 2048

# Load mô hình đã nén 4-bit
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Llama-3.2-1B-bnb-4bit",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    dtype=None,
)

# Thiết lập PEFT với LoRA
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    lora_dropout=0,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "up_proj",
        "down_proj", "o_proj", "gate_proj"
    ],
    use_rslora=True,
    use_gradient_checkpointing="unsloth",
    random_state=42,
    loftq_config=None,
)

# In thông tin các tham số có thể huấn luyện
print(model.print_trainable_parameters())


==((====))==  Unsloth 2025.6.12: Fast Llama patching. Transformers: 4.53.1.
   \\   /|    NVIDIA GeForce RTX 4060 Ti. Num GPUs = 1. Max memory: 15.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.6.12 patched 16 layers with 16 QKV layers, 16 O layers and 16 MLP layers.


trainable params: 11,272,192 || all params: 1,247,086,592 || trainable%: 0.9039
None


## 4. Finetuning

In [None]:
# Thiết lập tham số huấn luyện
args = TrainingArguments(
    output_dir="med-mcqa-llama-3.2-3B-4bit-lora",
    logging_dir="logs",
    learning_rate=3e-4,
    lr_scheduler_type="linear",
    per_device_train_batch_size=32,
    gradient_accumulation_steps=8,
    num_train_epochs=2,
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    eval_steps=50,
    save_steps=50,
    logging_steps=50,
    save_total_limit=1,
    load_best_model_at_end=True,
    fp16=not is_bfloat16_supported(),
    bf16=is_bfloat16_supported(),
    optim="adamw_8bit",
    weight_decay=0.01,
    warmup_steps=10,
    seed=0,
)

# Khởi tạo trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=args,
    train_dataset=process_ds["train"],
    eval_dataset=process_ds["validation"],
    dataset_text_field="text",
)

# Bắt đầu huấn luyện
trainer.train()


## 5. Save model

In [None]:
# Đăng nhập Hugging Face Hub bằng access token cá nhân
login(token=load_dotenv('HF_TOKEN'))  # Thay bằng token thật

# Lưu mô hình local
model.save_pretrained("unsloth-llama-trained")

# Tên repo Hugging Face bạn muốn upload mô hình
PEFT_MODEL = "dainlieu/Llama-3.2-3B-bnb-4bit-MedMCQA"

# Đẩy mô hình lên Hugging Face
model.push_to_hub(PEFT_MODEL, use_auth_token=True)

## 6. Inference

In [8]:
def infer_from_hf(
    model_path="dainlieu/Llama-3.2-3B-bnb-4bit-MedMCQA",
    prompt="""Question: What is the capital of France?
Choices:
A. Berlin
B. Paris
C. Madrid
D. Rome
Answer:"""
):
    # ✅ Load mô hình từ Hugging Face đã fine-tune
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name = model_path,
        max_seq_length = 2048,
        dtype = None,              # Tự chọn float16/bfloat16
        load_in_4bit = True,
    )

    model.eval()
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=32,
        do_sample=False,
        temperature=0.7,
        top_p=0.95,
    )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("\n--- Output ---\n", answer)

# Gọi hàm
infer_from_hf()

==((====))==  Unsloth 2025.6.12: Fast Llama patching. Transformers: 4.53.1.
   \\   /|    NVIDIA GeForce RTX 4060 Ti. Num GPUs = 1. Max memory: 15.996 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.7.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.3.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.30. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.6.12 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.
The following generation flags are not valid and may be ignored: ['temperature', 'top_p']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
LlamaForCausalLM has no `_prepare_4d_causal_attention_mask_with_cache_position` method defined in its base modeling class. Compiled forward passes will be sub-optimal. If you're writing code, see Llama for an example implementation. If you're a user, please report this issue on GitHub.



--- Output ---
 Question: What is the capital of France?
Choices:
A. Berlin
B. Paris
C. Madrid
D. Rome
Answer: B
