# Fine‑Tuning Qwen2.5 em flashcards médicos com QLoRA

Este notebook demonstra como fazer o fine-tuning do modelo Qwen2.5-1.5B-Instruct em um conjunto de dados de flashcards médicos usando QLoRA (adaptadores LoRA treinados sobre o modelo base quantizado em 4 bits). A abordagem é leve o suficiente para rodar em uma única GPU T4 no Google Colab.

In [None]:
# Upgrade pip itself before installing anything else.
!pip -q install -U pip

# Constrain NumPy to the 2.0.x series (>=2, <2.1).
!pip -q install "numpy>=2,<2.1"

# Install or upgrade the remaining dependencies; fsspec and gcsfs are pinned
# to exact versions, the others only have lower (and for transformers, upper) bounds.
!pip -q install -U \
  "transformers>=4.41,<5" \
  "datasets>=2.19.0" \
  "accelerate>=0.28.0" \
  "peft>=0.10.0" \
  "bitsandbytes>=0.43.0" \
  "trl>=0.9.6" \
  "evaluate>=0.4.1" \
  "fsspec==2025.3.0" \
  "gcsfs==2025.3.0"

In [None]:
import torch
# Print the first CUDA device's name and its total memory in GiB (rounded to
# two decimals), or a notice when no CUDA device is available.
if not torch.cuda.is_available():
    print("CUDA indisponível")
else:
    gpu_index = 0
    print(torch.cuda.get_device_name(gpu_index))
    total_bytes = torch.cuda.get_device_properties(gpu_index).total_memory
    print(round(total_bytes / 1024 ** 3, 2))

## Configuração

In [None]:
# Identifiers passed to `from_pretrained` and `load_dataset` in later cells.
MODEL_NAME = "Qwen/Qwen2.5-1.5B-Instruct"
DATASET_NAME = "flwrlabs/medical-meadow-medical-flashcards"

# Tokenization and optimization hyperparameters.
MAX_SEQ_LEN = 1024      # truncation limit (in tokens) used when tokenizing
BATCH_SIZE = 4          # per-device batch size for both training and evaluation
GRAD_ACCUM_STEPS = 4    # gradient accumulation steps per optimizer update
LEARNING_RATE = 2e-4
NUM_EPOCHS = 1

# Output locations. The adapter and merged-model directories are derived from
# OUTPUT_DIR so that changing it moves every artifact together.
OUTPUT_DIR = "outputs"
ADAPTER_OUTPUT_DIR = f"{OUTPUT_DIR}/adapter"
MERGED_OUTPUT_DIR = f"{OUTPUT_DIR}/merged"

# Shared random seed for the RNGs and the train/eval split.
SEED = 42

In [None]:
import random, numpy as np
# Seed Python's, NumPy's and PyTorch's random number generators with the
# shared SEED so repeated runs start from the same RNG state.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Carregar o Dataset

In [None]:
from datasets import load_dataset
# Load the dataset's "train" split, then hold out 5% of it (seeded with SEED)
# as the evaluation set.
dataset = load_dataset(DATASET_NAME, split="train")
split = dataset.train_test_split(seed=SEED, test_size=0.05)
train_dataset, eval_dataset = split["train"], split["test"]

In [None]:
from transformers import AutoTokenizer
# Load the tokenizer that matches MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# Padding needs a pad token; fall back to the EOS token when none is defined.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

## Formatar dados

In [None]:
def apply_template(example):
    """Render one dataset row as a single chat-formatted training string.

    Args:
        example: Mapping with ``instruction``, ``input`` and ``output`` fields.

    Returns:
        ``{"text": ...}`` where the value is the system/user/assistant
        conversation rendered by ``tokenizer.apply_chat_template`` as text
        (``tokenize=False``).
    """
    # Join only the non-empty parts: a blank `input` no longer leaves a
    # trailing newline in the user turn, and a None value no longer raises.
    parts = (example["instruction"], example["input"])
    user_content = "\n".join(part for part in parts if part)
    messages = [
        {"role": "system", "content": "You are a helpful medical assistant."},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": example["output"]},
    ]
    text = tokenizer.apply_chat_template(messages, tokenize=False)
    return {"text": text}

# Replace every original column with the single rendered "text" column,
# for both the training and the evaluation split.
train_dataset = train_dataset.map(
    apply_template,
    remove_columns=train_dataset.column_names,
)
eval_dataset = eval_dataset.map(
    apply_template,
    remove_columns=eval_dataset.column_names,
)

## Tokenizar os prompts

In [None]:
def tokenize_function(sample):
    """Tokenize the ``text`` field and attach labels for causal-LM training.

    Args:
        sample: Mapping whose ``text`` entry holds the string(s) to tokenize.

    Returns:
        The tokenizer output, truncated to ``MAX_SEQ_LEN`` and left unpadded,
        with an extra ``labels`` entry that is a copy of ``input_ids``.
    """
    encoded = tokenizer(
        sample["text"],
        max_length=MAX_SEQ_LEN,
        truncation=True,
        padding=False,
    )
    # Labels start as a copy of the input ids; padding positions are handled
    # later, at collation time.
    encoded["labels"] = encoded["input_ids"].copy()
    return encoded

# Tokenize both splits in batches, dropping the columns that existed before
# tokenization so only the tokenizer outputs and labels remain.
train_dataset = train_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=train_dataset.column_names,
)
eval_dataset = eval_dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=eval_dataset.column_names,
)

## Definir o collator de dados

In [None]:
import torch
def data_collator(features):
    """Pad a list of tokenized examples into one batch of tensors.

    Args:
        features: Sequence of mappings, each with ``input_ids``,
            ``attention_mask`` and ``labels`` integer sequences.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` tensors,
        padded (batch first) with the tokenizer's pad token id, 0 and -100
        respectively.
    """

    def _stack(key, fill_value):
        # Convert each example's sequence to a long tensor, then pad them
        # to a common length.
        sequences = [torch.tensor(item[key], dtype=torch.long) for item in features]
        return torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=fill_value
        )

    return {
        "input_ids": _stack("input_ids", tokenizer.pad_token_id),
        "attention_mask": _stack("attention_mask", 0),
        "labels": _stack("labels", -100),
    }

## Carregar o modelo com QLoRA e aplicar adaptadores LoRA

In [None]:
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
# Pick the 4-bit compute dtype from the GPU's compute capability instead of
# hard-coding bfloat16: native bfloat16 needs compute capability >= 8.0, and
# the T4 this notebook targets is older than that, so it gets float16 (which
# also matches the fp16 mixed-precision training used below).
_has_native_bf16 = torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16 if _has_native_bf16 else torch.float16,
)

# Load the base model with 4-bit quantization and prepare it for k-bit training.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapters (rank 16, alpha 32, dropout 0.05) to the listed
# projection modules.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)

def print_trainable_parameters(model):
    """Print the trainable parameter count, the total count, and their ratio.

    Args:
        model: Object exposing ``named_parameters()`` whose parameters provide
            ``numel()`` and ``requires_grad``.

    Prints three lines: number of elements in parameters that require
    gradients, number of elements in all parameters, and the first divided
    by the second.
    """
    sizes = [(param.numel(), param.requires_grad) for _, param in model.named_parameters()]
    total = sum(count for count, _ in sizes)
    trainable = sum(count for count, needs_grad in sizes if needs_grad)
    print(trainable)
    print(total)
    print(trainable / total)
# Report how many of the model's parameters are trainable after get_peft_model.
print_trainable_parameters(model)

## Treinamento

In [None]:
from transformers import TrainingArguments, Trainer

# Training configuration: log every 10 steps, evaluate and checkpoint every
# 50 steps (keeping only the latest checkpoint), fp16 mixed precision, and a
# paged 8-bit AdamW optimizer.
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=1,
    fp16=True,
    bf16=False,
    optim="paged_adamw_8bit",
    # Forward the notebook's seed: the Trainer seeds itself from this
    # argument, so without it editing SEED would not affect training.
    seed=SEED,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# Train, then run a final evaluation; the metrics are the cell's output.
trainer.train()
trainer.evaluate()

## Salvar o adaptador e o tokenizador

In [None]:
# Write the trained model and the tokenizer to ADAPTER_OUTPUT_DIR.
# NOTE(review): `model` is the PEFT-wrapped model here, so this presumably
# stores only the LoRA adapter weights rather than the full base model — confirm.
model.save_pretrained(ADAPTER_OUTPUT_DIR)
tokenizer.save_pretrained(ADAPTER_OUTPUT_DIR)

## Opcional: mesclar o LoRA ao modelo base

In [None]:
from peft import PeftModel
# Optional, best-effort step: reload the unquantized base model in bfloat16,
# apply the saved adapter, fold the LoRA weights into the base weights and
# save the result. A failure is reported but does not stop the notebook.
try:
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )
    peft_model = PeftModel.from_pretrained(base_model, ADAPTER_OUTPUT_DIR)
    merged_model = peft_model.merge_and_unload()
    merged_model.save_pretrained(MERGED_OUTPUT_DIR)
    # Save the tokenizer next to the merged weights so MERGED_OUTPUT_DIR can
    # be loaded on its own.
    tokenizer.save_pretrained(MERGED_OUTPUT_DIR)
except Exception as e:
    # Keep the step optional, but say what failed and with which error type.
    print(f"Falha ao mesclar o adaptador (etapa opcional): {type(e).__name__}: {e}")

## Demonstração de inferência

In [None]:
from peft import PeftModel
# Reload the 4-bit quantized base model, attach the saved adapter from
# ADAPTER_OUTPUT_DIR, and switch the result to evaluation mode.
model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
    ),
    ADAPTER_OUTPUT_DIR,
)
model.eval()

def chat(query, max_new_tokens=1280, temperature=0.7):
    """Generate the model's answer to a single user question.

    Args:
        query: The user's question as plain text.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; must be greater than zero because
            sampling is enabled explicitly.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        stripped), with special tokens removed.
    """
    messages = [
        {"role": "system", "content": "You are a helpful medical assistant."},
        {"role": "user", "content": query},
    ]
    # add_generation_prompt=True ends the prompt with the assistant-turn
    # header, so the model answers instead of continuing the user turn.
    prompt = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # do_sample=True makes the temperature take effect regardless of the
        # checkpoint's default generation settings.
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
        )
    # Drop the prompt tokens so that only the generated answer is decoded.
    prompt_length = inputs.input_ids.shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)

# Sample clinical questions used to exercise the fine-tuned model.
queries = [
    "A 67-year-old with atrial fibrillation is on warfarin and starts trimethoprim-sulfamethoxazole. What complication is most likely, and why?",
    "A patient with chronic kidney disease has persistently high phosphate and low calcium. What happens to PTH over time, and what bone changes can result?",
    "After 10 days of clindamycin, a patient develops watery diarrhea and abdominal cramping. What is the likely diagnosis and first-line treatment?",
    "A newborn becomes lethargic with poor feeding and vomiting. Labs show ammonia is markedly elevated with normal glucose. Name a likely metabolic disorder category and an initial management step.",
    "A 24-year-old has fever, dysuria, and flank pain with nausea. Urinalysis shows WBC casts. What diagnosis is most likely and what empiric antibiotics are commonly used?"
]
# Print each question first, then the generated answer, then a blank separator line.
for q in queries:
    print(q)
    print(chat(q))
    print()