# 🤖 Bài thực hành 2: Huấn luyện mô hình phân loại bệnh với PhoBERT
Trong bài này, sinh viên sẽ fine-tune mô hình PhoBERT để dự đoán bệnh từ câu mô tả triệu chứng.

In [28]:
%pip install transformers torch pandas scikit-learn -q


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.3.1[0m[39;49m -> [0m[32;49m25.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [29]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from sklearn.model_selection import train_test_split

In [35]:
# Bước 1: Load dữ liệu đã xử lý
df = pd.read_csv('/Users/dinhquanghien/Documents/Học tập/pre_2/module/processed_medical_chat_dataset.csv')
df['binary_labels'] = df['binary_labels'].apply(eval)
df.head()

Unnamed: 0,text,labels,labels_list,binary_labels
0,Tôi bị ho và sốt suốt ba ngày nay.,"['viêm phổi', 'sốt virus']","['viêm phổi', 'sốt virus']","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,Tôi thấy đau bụng dữ dội và tiêu chảy.,['ngộ độc thực phẩm'],['ngộ độc thực phẩm'],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,Tôi thường xuyên chóng mặt và nhức đầu.,"['thiếu máu', 'cao huyết áp']","['thiếu máu', 'cao huyết áp']","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,"Khó thở, ho khan và cảm thấy mệt mỏi.","['covid-19', 'viêm phổi']","['covid-19', 'viêm phổi']","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,"Tôi bị ngứa, nôn, mệt mỏi, sụt cân, sốt cao, d...",['Vàng da'],['Vàng da'],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [36]:
# Bước 2: Định nghĩa Dataset
class MedicalChatDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        inputs = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_len, return_tensors='pt')
        return {
            'input_ids': inputs['input_ids'].squeeze(0),
            'attention_mask': inputs['attention_mask'].squeeze(0),
            'labels': torch.FloatTensor(label)
        }

In [37]:
# Bước 3: Định nghĩa mô hình
class PhoBERTClassifier(nn.Module):
    def __init__(self, num_labels):
        super(PhoBERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained("vinai/phobert-base")
        self.dropout = nn.Dropout(0.3)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_labels)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs[0][:, 0]
        x = self.dropout(pooled_output)
        return torch.sigmoid(self.fc(x))

In [38]:
# Bước 4: Khởi tạo DataLoader
from sklearn.model_selection import train_test_split

train_texts, val_texts, train_labels, val_labels = train_test_split(df['text'], df['binary_labels'], test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

train_dataset = MedicalChatDataset(train_texts.tolist(), train_labels.tolist(), tokenizer)
val_dataset = MedicalChatDataset(val_texts.tolist(), val_labels.tolist(), tokenizer)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4)

In [39]:
# Bước 5: Huấn luyện mô hình với Focal Loss và Class Weights
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss
import numpy as np

# Sử dụng MPS (Metal Performance Shaders) cho Mac M1/M2/M3
device = torch.device("mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
print(f"🖥️  Using device: {device}")

# Calculate positive class weights from training data
train_labels_array = np.array(train_labels.tolist())
pos_counts = train_labels_array.sum(axis=0)
neg_counts = len(train_labels_array) - pos_counts

# Avoid division by zero
pos_counts = np.where(pos_counts == 0, 1, pos_counts)

# Calculate weights: neg/pos ratio (higher weight for rare classes)
pos_weights = neg_counts / pos_counts
pos_weights = torch.FloatTensor(pos_weights).to(device)

print(f"\n📊 Class imbalance statistics:")
print(f"   Min positive samples: {pos_counts.min():.0f}")
print(f"   Max positive samples: {pos_counts.max():.0f}")
print(f"   Avg positive samples: {pos_counts.mean():.1f}")
print(f"   Min weight: {pos_weights.min():.2f}")
print(f"   Max weight: {pos_weights.max():.2f}")
print(f"   Avg weight: {pos_weights.mean():.2f}\n")

model = PhoBERTClassifier(num_labels=len(train_labels.iloc[0])).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)  # Increased LR

# Use weighted BCE with pos_weights
criterion = nn.BCELoss(reduction='none')  # We'll apply weights manually

def focal_loss(outputs, targets, alpha=0.25, gamma=2.0):
    """Focal loss for handling class imbalance"""
    bce_loss = nn.functional.binary_cross_entropy(outputs, targets, reduction='none')
    pt = torch.where(targets == 1, outputs, 1 - outputs)
    focal_weight = (alpha * targets + (1 - alpha) * (1 - targets)) * ((1 - pt) ** gamma)
    loss = focal_weight * bce_loss
    return loss.mean()

def train_epoch(model, loader, device):
    """Train for one epoch with focal loss"""
    model.train()
    total_loss = 0
    for batch in loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        
        # Use focal loss
        loss = focal_loss(outputs, labels, alpha=0.75, gamma=2.0)  # High alpha for positive class
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # Gradient clipping
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)

def evaluate(model, loader, device, threshold=0.2):
    """Evaluate with adaptive threshold"""
    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []
    total_loss = 0

    with torch.no_grad():
        for batch in loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            outputs = model(input_ids, attention_mask)
            loss = focal_loss(outputs, labels)
            total_loss += loss.item()

            all_probs.append(outputs.cpu().numpy())
            preds = (outputs > threshold).float()
            all_preds.append(preds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    all_preds = np.vstack(all_preds)
    all_probs = np.vstack(all_probs)
    all_labels = np.vstack(all_labels)

    # Debug info
    print(f"\n   🔍 Debug (threshold={threshold}):")
    print(f"      Avg prob:        {all_probs.mean():.4f} (target: ~{all_labels.mean():.4f})")
    print(f"      Max prob:        {all_probs.max():.4f}")
    print(f"      P75 prob:        {np.percentile(all_probs, 75):.4f}")
    print(f"      P95 prob:        {np.percentile(all_probs, 95):.4f}")
    print(f"      Positive preds:  {all_preds.sum():.0f} (target: {all_labels.sum():.0f})")
    
    # Calculate metrics
    metrics = {
        'loss': total_loss / len(loader),
        'exact_match': accuracy_score(all_labels, all_preds),
        'hamming_acc': 1 - hamming_loss(all_labels, all_preds),
        'precision_micro': precision_score(all_labels, all_preds, average='micro', zero_division=0),
        'recall_micro': recall_score(all_labels, all_preds, average='micro', zero_division=0),
        'f1_micro': f1_score(all_labels, all_preds, average='micro', zero_division=0),
        'f1_samples': f1_score(all_labels, all_preds, average='samples', zero_division=0),
    }

    return metrics

# Training with learning rate warmup
print("\n🚀 Starting training with Focal Loss...\n")
print("=" * 100)
best_f1 = -1.0
epochs = 5
threshold = 0.2  # Lower threshold

for epoch in range(epochs):
    train_loss = train_epoch(model, train_loader, device)
    val_metrics = evaluate(model, val_loader, device, threshold=threshold)
    
    print(f"\n📊 Epoch {epoch+1}/{epochs}")
    print(f"   Train Loss:       {train_loss:.4f}")
    print(f"   Val Loss:         {val_metrics['loss']:.4f}")
    print(f"   Hamming Acc:      {val_metrics['hamming_acc']:.4f}")
    print(f"   Precision (micro): {val_metrics['precision_micro']:.4f}")
    print(f"   Recall (micro):    {val_metrics['recall_micro']:.4f}")
    print(f"   F1 (micro):        {val_metrics['f1_micro']:.4f}")
    print(f"   F1 (samples):      {val_metrics['f1_samples']:.4f}")
    print("=" * 100)
    
    metric_to_track = val_metrics['f1_micro'] if val_metrics['f1_micro'] > 0 else val_metrics['hamming_acc']
    
    if metric_to_track > best_f1:
        best_f1 = metric_to_track
        torch.save(model.state_dict(), 'phobert_medchat_model_best.pt')
        print(f"   ✅ Best model saved! (F1-micro: {val_metrics['f1_micro']:.4f})")

print(f"\n🎉 Training completed! Best F1-micro: {best_f1:.4f}")

🖥️  Using device: mps

📊 Class imbalance statistics:
   Min positive samples: 1
   Max positive samples: 83
   Avg positive samples: 28.0
   Min weight: 27.95
   Max weight: 2403.00
   Avg weight: 1237.58


🚀 Starting training with Focal Loss...


   🔍 Debug (threshold=0.2):
      Avg prob:        0.1614 (target: ~0.0116)
      Max prob:        0.7889
      P75 prob:        0.1893
      P95 prob:        0.2685
      Positive preds:  10519 (target: 601)

📊 Epoch 1/5
   Train Loss:       0.0083
   Val Loss:         0.0049
   Hamming Acc:      0.8078
   Precision (micro): 0.0565
   Recall (micro):    0.9884
   F1 (micro):        0.1068
   F1 (samples):      0.1097
   ✅ Best model saved! (F1-micro: 0.1068)

   🔍 Debug (threshold=0.2):
      Avg prob:        0.1179 (target: ~0.0116)
      Max prob:        0.8814
      P75 prob:        0.1336
      P95 prob:        0.1873
      Positive preds:  1982 (target: 601)

📊 Epoch 2/5
   Train Loss:       0.0020
   Val Loss:         0.0019
   Hamming

In [None]:
# Bước 6: Lưu và kiểm tra mô hình
import os

# Kiểm tra xem có mô hình best không
if os.path.exists('phobert_medchat_model_best.pt'):
    print("✅ Mô hình tốt nhất đã được lưu tại: phobert_medchat_model_best.pt")
    best_model_path = 'phobert_medchat_model_best.pt'
else:
    print("⚠️  Không tìm thấy best model, sử dụng mô hình cuối cùng")
    # Lưu mô hình cuối cùng
    torch.save(model.state_dict(), 'phobert_medchat_model_last.pt')
    print("✅ Mô hình cuối cùng đã được lưu tại: phobert_medchat_model_last.pt")
    best_model_path = 'phobert_medchat_model_last.pt'

# Test inference
print(f"\n🧪 Testing inference với mô hình: {best_model_path}")
model.load_state_dict(torch.load(best_model_path))
model.eval()

test_text = "Tôi bị ho, sốt và khó thở"
test_inputs = tokenizer(test_text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
test_inputs = {k: v.to(device) for k, v in test_inputs.items()}

with torch.no_grad():
    test_output = model(test_inputs['input_ids'], test_inputs['attention_mask'])
    test_probs = test_output.cpu().numpy()[0]
    
    # Show top 5 predictions
    top_5_indices = np.argsort(test_probs)[-5:][::-1]
    
    print(f"\nInput: '{test_text}'")
    print(f"Top 5 predictions:")
    for idx in top_5_indices:
        print(f"   Label {idx}: {test_probs[idx]:.4f}")
    
    # Binary predictions with threshold
    test_preds = (test_output > 0.3).float().cpu().numpy()[0]
    predicted_labels = [i for i, val in enumerate(test_preds) if val == 1]
    print(f"\nPredicted label indices (threshold=0.3): {predicted_labels}")
    
print("✅ Model inference hoạt động tốt!")