
# Module 1: Symptom Severity Classification using PhoBERT

## Objective
Develop a classification model to predict **eye symptom severity** across four levels:

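- **None** – No noticeable symptoms
- **Mild** – Minor discomfort or early symptoms
- **Moderate** – Clear and persistent symptoms
- **Severe** – Significant or serious eye strain symptoms

In the CSV files these levels are stored as labels 1–4 (in this order); the notebook shifts them to class ids 0–3 for training.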

## Methodology
- Use **PhoBERT (vinai/phobert-base-v2)**, a Vietnamese pre-trained BERT-based language model.
- Fine-tune PhoBERT for **multi-class classification**.
- Evaluate using multiple metrics:
  - **Accuracy**
  - **Precision**
  - **Recall**
  - **F1-score**
  - **Confusion Matrix**


## 1. Import Libraries

In [1]:
import os
import json
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import AutoModel, AutoTokenizer, get_linear_schedule_with_warmup
from sklearn.metrics import (
    classification_report, 
    accuracy_score, 
    f1_score,
    precision_score,
    recall_score,
    confusion_matrix,
    roc_auc_score,
    roc_curve
)
from tqdm.auto import tqdm
from typing import Dict, List, Tuple
import warnings
warnings.filterwarnings('ignore')

# Global plotting look shared by every figure in this notebook
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")

# Report the runtime so results can be tied to a PyTorch build / accelerator
_env_report = {
    "PyTorch version": torch.__version__,
    "CUDA available": torch.cuda.is_available(),
}
for _caption, _value in _env_report.items():
    print(f"{_caption}: {_value}")

PyTorch version: 2.10.0+cpu
CUDA available: False


## 2. Configuration

In [None]:
# Configuration
CONFIG = {
    'model_name': 'vinai/phobert-base-v2',
    'num_classes': 4,
    'max_length': 128,
    'batch_size': 16,
    'learning_rate': 2e-5,
    'num_epochs': 5,
    'dropout': 0.3,
    'warmup_ratio': 0.1,
    'seed': 42,
    'save_dir': 'models/module1'
}

# Set seed for reproducibility
torch.manual_seed(CONFIG['seed'])
np.random.seed(CONFIG['seed'])

# Device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"\n🖥️  Using device: {device}")

# Label mapping: 0-based model class id -> severity name.
# The CSV files store labels 1-4; they are shifted down by one before training.
# (The Vietnamese names were previously mis-encoded mojibake, e.g. "Nh·∫π".)
LABEL_MAP = {
    0: "None",   # no symptoms
    1: "Nhẹ",    # mild
    2: "Vừa",    # moderate
    3: "Nặng",   # severe
}

## 3. Dataset Class

In [None]:
class EyeSymptomDataset(Dataset):
    """Map-style dataset pairing symptom texts with integer severity labels.

    Texts are tokenized lazily in ``__getitem__`` into fixed-length
    (``max_length``) tensors, padded/truncated by the supplied tokenizer.
    """

    def __init__(self, texts: List[str], labels: List[int], tokenizer, max_length: int = 128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        # Dataset size is defined by the text list.
        return len(self.texts)

    def __getitem__(self, idx):
        # str() so non-string cells (e.g. numbers, NaN) still tokenize.
        raw_text, target = str(self.texts[idx]), self.labels[idx]

        encoded = self.tokenizer(
            raw_text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt',
        )

        # Drop the leading batch dimension of size 1 added by return_tensors='pt'.
        sample = {name: encoded[name].squeeze(0) for name in ('input_ids', 'attention_mask')}
        sample['label'] = torch.tensor(target, dtype=torch.long)
        return sample

## 4. Model Architecture

In [None]:
class PhoBERTClassifier(nn.Module):
    """Pretrained encoder plus a two-layer MLP head for severity classification.

    Pipeline: encoder -> hidden state at sequence position 0 -> dropout ->
    Linear(hidden, 256) -> ReLU -> Dropout -> Linear(256, num_classes).

    NOTE: the attribute names ``bert``, ``dropout`` and ``classifier`` (and the
    layer order inside ``classifier``) determine the state_dict keys stored in
    checkpoints; renaming or reordering them breaks loading saved models.
    """

    def __init__(self, model_name: str, num_classes: int = 4, dropout: float = 0.3):
        super().__init__()

        self.bert = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)

        hidden_size = self.bert.config.hidden_size
        head_layers = [
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits, one row per input sequence."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # First-position token as the sequence summary (BERT's [CLS] slot;
        # presumably <s> for PhoBERT's RoBERTa-style tokenizer — confirm).
        summary = encoded.last_hidden_state[:, 0, :]
        return self.classifier(self.dropout(summary))

## 5. Metrics and Visualization Functions

In [None]:
def calculate_metrics(y_true, y_pred, y_probs=None, labels=None):
    """Compute accuracy plus macro/weighted and per-class precision, recall and F1.

    Args:
        y_true: ground-truth class ids.
        y_pred: predicted class ids, same length as ``y_true``.
        y_probs: accepted for API compatibility; currently unused.
        labels: optional explicit class ids to report on (e.g.
            ``range(num_classes)``). When given, every listed class gets
            per-class entries (0.0 if it never occurs) and macro averages
            include it. When omitted, the sorted union of ids seen in
            ``y_true``/``y_pred`` is used, which is sklearn's own default.

    Returns:
        dict with 'accuracy', '{precision,recall,f1}_{macro,weighted}' and
        '{precision,recall,f1}_class_{k}' for each reported class id ``k``.
        Per-class keys use the class id itself rather than the position in
        sklearn's output array. The old positional keys shifted a class's
        scores into another class's slot whenever some class was absent.
    """
    # Resolve class ids once and pass them to every sklearn call so the order
    # of the per-class arrays is known exactly.
    if labels is not None:
        class_ids = list(labels)
    else:
        class_ids = np.union1d(np.asarray(y_true), np.asarray(y_pred)).tolist()

    scorers = (('precision', precision_score), ('recall', recall_score), ('f1', f1_score))

    metrics = {'accuracy': accuracy_score(y_true, y_pred)}
    for name, scorer in scorers:
        for average in ('macro', 'weighted'):
            metrics[f'{name}_{average}'] = scorer(
                y_true, y_pred, labels=class_ids, average=average, zero_division=0
            )

    per_class = {
        name: scorer(y_true, y_pred, labels=class_ids, average=None, zero_division=0)
        for name, scorer in scorers
    }
    for position, class_id in enumerate(class_ids):
        for name, _ in scorers:
            metrics[f'{name}_class_{class_id}'] = per_class[name][position]

    return metrics


def plot_confusion_matrix(y_true, y_pred, labels, title='Confusion Matrix', class_ids=None):
    """Plot a confusion matrix of raw counts as an annotated heatmap.

    Args:
        y_true: ground-truth class ids.
        y_pred: predicted class ids.
        labels: display names, one per class, in class-id order.
        title: figure title.
        class_ids: the class ids that ``labels`` name, in the same order.
            Defaults to ``0..len(labels)-1``, the encoding used in this notebook.
            Passing them to ``confusion_matrix`` keeps the matrix
            ``len(labels) x len(labels)`` even when a class never appears, so
            rows/columns always line up with the tick labels.

    Returns:
        The current matplotlib Figure.
    """
    if class_ids is None:
        class_ids = list(range(len(labels)))
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=labels, yticklabels=labels,
                cbar_kws={'label': 'Count'})
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    return plt.gcf()


def plot_normalized_confusion_matrix(y_true, y_pred, labels, title='Normalized Confusion Matrix',
                                     class_ids=None):
    """Plot a row-normalized confusion matrix (share of each true class) as a heatmap.

    Args:
        y_true: ground-truth class ids.
        y_pred: predicted class ids.
        labels: display names, one per class, in class-id order.
        title: figure title.
        class_ids: the class ids that ``labels`` name, in the same order.
            Defaults to ``0..len(labels)-1``, the encoding used in this notebook.

    Returns:
        The current matplotlib Figure.
    """
    if class_ids is None:
        class_ids = list(range(len(labels)))
    # labels= keeps the matrix aligned with the tick names when a class is absent.
    # normalize='true' divides each row by its support. sklearn maps a
    # zero-support row to 0 instead of the 0/0 NaN that the manual division
    # produced.
    cm_normalized = confusion_matrix(y_true, y_pred, labels=class_ids, normalize='true')

    plt.figure(figsize=(10, 8))
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='YlOrRd',
                xticklabels=labels, yticklabels=labels,
                cbar_kws={'label': 'Percentage'})
    plt.title(title, fontsize=14, fontweight='bold')
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.tight_layout()
    return plt.gcf()


def plot_metrics_per_class(metrics, labels):
    """Draw grouped bars of precision, recall and F1 for each class.

    Args:
        metrics: dict holding 'precision_class_{i}', 'recall_class_{i}' and
            'f1_class_{i}' for every i in 0..len(labels)-1.
        labels: class display names, in class-id order.

    Returns:
        The matplotlib Figure with the bar chart.
    """
    class_indices = range(len(labels))
    bar_width = 0.25

    # (legend text, metrics-key prefix, horizontal offset of the bar group)
    series = [
        ('Precision', 'precision', -bar_width),
        ('Recall', 'recall', 0.0),
        ('F1-Score', 'f1', bar_width),
    ]
    # Gather all heights up front so a missing key fails before any figure exists.
    heights = {
        prefix: [metrics[f'{prefix}_class_{i}'] for i in class_indices]
        for _, prefix, _ in series
    }

    positions = np.arange(len(labels))
    fig, ax = plt.subplots(figsize=(12, 6))
    for legend_name, prefix, offset in series:
        ax.bar(positions + offset, heights[prefix], bar_width, label=legend_name, alpha=0.8)

    ax.set_xlabel('Class', fontsize=12)
    ax.set_ylabel('Score', fontsize=12)
    ax.set_title('Metrics per Class', fontsize=14, fontweight='bold')
    ax.set_xticks(positions)
    ax.set_xticklabels(labels)
    ax.legend()
    ax.set_ylim(0, 1.1)
    ax.grid(axis='y', alpha=0.3)

    plt.tight_layout()
    return fig


def plot_training_history(history):
    """Plot per-epoch loss and validation metrics on a 2x2 grid.

    Args:
        history: dict of per-epoch lists with keys 'train_loss', 'val_loss',
            'val_accuracy' and 'val_f1' (as built by ``Module1Trainer.train``).

    Returns:
        The matplotlib Figure containing the four panels.
    """
    from matplotlib.ticker import MaxNLocator

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    def _draw(ax, curves, title, ylabel):
        """Plot each (values, label, style) curve against 1-based epoch numbers."""
        for values, label, style in curves:
            # Training prints epochs as 1..N, so label the x-axis the same way.
            ax.plot(range(1, len(values) + 1), values, label=label, **style)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))  # no fractional epochs
        ax.legend()
        ax.grid(True, alpha=0.3)

    _draw(axes[0, 0],
          [(history['train_loss'], 'Train Loss', {'marker': 'o'}),
           (history['val_loss'], 'Val Loss', {'marker': 's'})],
          'Loss over Epochs', 'Loss')
    _draw(axes[0, 1],
          [(history['val_accuracy'], 'Validation Accuracy', {'marker': 'o', 'color': 'green'})],
          'Validation Accuracy over Epochs', 'Accuracy')
    _draw(axes[1, 0],
          [(history['val_f1'], 'Validation F1', {'marker': 'o', 'color': 'red'})],
          'Validation F1 Score over Epochs', 'F1 Score')
    _draw(axes[1, 1],
          [(history['val_accuracy'], 'Accuracy', {'marker': 'o'}),
           (history['val_f1'], 'F1 Score', {'marker': 's'})],
          'Validation Metrics Comparison', 'Score')

    plt.tight_layout()
    return fig

## 6. Training Class

In [None]:
class Module1Trainer:
    """Fine-tuning pipeline for the Module 1 severity classifier.

    Uses AdamW with a linear warmup/decay LR schedule and cross-entropy loss.
    It validates after every epoch and checkpoints the model whenever the
    validation weighted F1 improves.
    """

    def __init__(
        self,
        model: nn.Module,
        train_loader: DataLoader,
        val_loader: DataLoader,
        device: str = 'cpu',
        learning_rate: float = 2e-5,
        num_epochs: int = 10,
        warmup_ratio: float = 0.1,
    ):
        """Set up optimizer, LR schedule, loss and metric history.

        Args:
            model: module called as ``model(input_ids, attention_mask)`` that
                returns class logits; it is moved to ``device``.
            train_loader: yields dicts with 'input_ids', 'attention_mask', 'label'.
            val_loader: same batch format as ``train_loader``.
            device: device string the model and every batch are moved to.
            learning_rate: peak AdamW learning rate.
            num_epochs: number of passes over ``train_loader``.
            warmup_ratio: fraction of all optimizer steps spent in linear LR
                warmup (0.1 reproduces the previously hard-coded behaviour).
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device
        self.num_epochs = num_epochs

        # Optimizer
        self.optimizer = AdamW(self.model.parameters(), lr=learning_rate)

        # The scheduler is stepped once per batch, so its horizon is in batches.
        total_steps = len(train_loader) * num_epochs
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=int(warmup_ratio * total_steps),
            num_training_steps=total_steps,
        )

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Start below any attainable F1 so the first epoch always writes a
        # checkpoint. With 0.0 and a strict '>', a run whose F1 stayed 0.0
        # would never create best_model.pt.
        self.best_val_f1 = float('-inf')
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_accuracy': [],
            'val_f1': [],
            'val_precision': [],
            'val_recall': [],
        }

    def train_epoch(self) -> float:
        """Run one optimization pass over ``train_loader``; return the mean batch loss."""
        self.model.train()
        total_loss = 0.0

        progress_bar = tqdm(self.train_loader, desc="Training")

        for batch in progress_bar:
            input_ids = batch['input_ids'].to(self.device)
            attention_mask = batch['attention_mask'].to(self.device)
            labels = batch['label'].to(self.device)

            # Forward pass
            self.optimizer.zero_grad()
            logits = self.model(input_ids, attention_mask)
            loss = self.criterion(logits, labels)

            # Backward pass; clip the global grad norm to 1.0 before stepping.
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            self.scheduler.step()

            total_loss += loss.item()
            progress_bar.set_postfix({'loss': f'{loss.item():.4f}'})

        # max(..., 1) avoids ZeroDivisionError on an empty loader.
        return total_loss / max(len(self.train_loader), 1)

    def evaluate(self, data_loader: DataLoader) -> Dict:
        """Evaluate the current model on ``data_loader`` without gradient tracking.

        Returns:
            The dict from ``calculate_metrics`` plus 'loss' (mean batch loss),
            'predictions', 'labels' and 'probabilities' (per-sample lists).
        """
        self.model.eval()
        total_loss = 0.0
        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Evaluating"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)

                logits = self.model(input_ids, attention_mask)
                loss = self.criterion(logits, labels)

                total_loss += loss.item()

                probs = torch.softmax(logits, dim=1)
                preds = torch.argmax(logits, dim=1)

                all_preds.extend(preds.cpu().numpy())
                all_labels.extend(labels.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

        metrics = calculate_metrics(all_labels, all_preds, all_probs)
        metrics['loss'] = total_loss / max(len(data_loader), 1)
        metrics['predictions'] = all_preds
        metrics['labels'] = all_labels
        metrics['probabilities'] = all_probs

        return metrics

    def train(self, save_dir: str = "models/module1", restore_best: bool = True):
        """Train for ``num_epochs`` epochs, checkpointing the best model by val weighted F1.

        Args:
            save_dir: directory for ``best_model.pt`` and ``training_history.json``
                (created if missing).
            restore_best: if True, reload the best checkpoint's weights into
                ``self.model`` after the last epoch. Later evaluation and
                inference through this trainer then use the weights saved as
                "best", not the last epoch's weights.

        Returns:
            The history dict of per-epoch metric lists.
        """
        os.makedirs(save_dir, exist_ok=True)

        print(f"\n🚀 Starting training for {self.num_epochs} epochs...")

        best_epoch = None
        for epoch in range(self.num_epochs):
            print(f"\n📍 Epoch {epoch + 1}/{self.num_epochs}")

            train_loss = self.train_epoch()
            val_results = self.evaluate(self.val_loader)

            # Store plain Python floats. sklearn may return numpy scalars, and by
            # default torch.load(weights_only=True) refuses to unpickle those
            # from the checkpoint's history.
            val_f1 = float(val_results['f1_weighted'])
            self.history['train_loss'].append(float(train_loss))
            self.history['val_loss'].append(float(val_results['loss']))
            self.history['val_accuracy'].append(float(val_results['accuracy']))
            self.history['val_f1'].append(val_f1)
            self.history['val_precision'].append(float(val_results['precision_weighted']))
            self.history['val_recall'].append(float(val_results['recall_weighted']))

            print(f"\n📊 Metrics:")
            print(f"  Train Loss: {train_loss:.4f}")
            print(f"  Val Loss: {val_results['loss']:.4f}")
            print(f"  Val Accuracy: {val_results['accuracy']:.4f}")
            print(f"  Val Precision (weighted): {val_results['precision_weighted']:.4f}")
            print(f"  Val Recall (weighted): {val_results['recall_weighted']:.4f}")
            print(f"  Val F1 (weighted): {val_results['f1_weighted']:.4f}")

            if val_f1 > self.best_val_f1:
                self.best_val_f1 = val_f1
                best_epoch = epoch
                self.save_model(save_dir, epoch)
                print(f"\n✅ New best model saved! F1: {self.best_val_f1:.4f}")

        with open(os.path.join(save_dir, "training_history.json"), 'w', encoding='utf-8') as f:
            json.dump(self.history, f, indent=2, ensure_ascii=False)

        # The in-memory weights differ from the best ones only if a later epoch ran.
        if restore_best and best_epoch is not None and best_epoch != self.num_epochs - 1:
            self._restore_best_weights(save_dir)
            print(f"\n  Restored best weights from epoch {best_epoch + 1}")

        print(f"\n🎉 Training completed! Best Val F1: {self.best_val_f1:.4f}")

        return self.history

    def _restore_best_weights(self, save_dir: str):
        """Load 'model_state_dict' from ``save_dir/best_model.pt`` into ``self.model``."""
        checkpoint = torch.load(
            os.path.join(save_dir, "best_model.pt"),
            map_location=self.device,
            weights_only=True,
        )
        self.model.load_state_dict(checkpoint['model_state_dict'])

    def save_model(self, save_dir: str, epoch: int):
        """Write model/optimizer state, epoch, best F1 and history to ``save_dir/best_model.pt``."""
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'best_val_f1': float(self.best_val_f1),
            'history': self.history,
        }
        torch.save(checkpoint, os.path.join(save_dir, "best_model.pt"))

## 7. Load and Explore Data

Data collection form (Cách thu thập data): https://forms.gle/APoRvHVKu9yAXzaC9

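Full data: 1716

Train: 1200

Val: 516 (note: the output of the cell below is stale — it was not re-run after the split was reset, so it still reports 172 validation samples)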

In [2]:
# Load data. Set the EYE_SYMPTOM_DATA_DIR environment variable to point at the
# folder containing train.csv / val.csv; the author's local path is the fallback.
DATA_DIR = os.environ.get('EYE_SYMPTOM_DATA_DIR', r"D:\School\NLP")
train_df = pd.read_csv(os.path.join(DATA_DIR, "train.csv"))
val_df = pd.read_csv(os.path.join(DATA_DIR, "val.csv"))

print(f"📊 Data Statistics:")
print(f"  Training samples: {len(train_df)}")
print(f"  Validation samples: {len(val_df)}")
print(f"\n📋 Training data preview:")
display(train_df.head())

üìä Data Statistics:
  Training samples: 1200
  Validation samples: 172

üìã Training data preview:


Unnamed: 0,text,label,text_cleaned,severity,text_length,word_count
0,Kh√¥ng c√≥ d·∫•u hi·ªáu kh√¥ hay ƒë·ªè m·∫Øt,1,Kh√¥ng c√≥ d·∫•u hi·ªáu kh√¥ hay ƒë·ªè m·∫Øt,,32,8
1,Kh√¥ng c√≥ d·∫•u hi·ªáu kh√¥ hay ƒë·ªè m·∫Øt after long sc...,1,Kh√¥ng c√≥ d·∫•u hi·ªáu kh√¥ hay ƒë·ªè m·∫Øt after long sc...,,55,12
2,"M·∫Øt m·ªèi r√µ r·ªát, ph·∫£i nheo khi nh√¨n",3,"M·∫Øt m·ªèi r√µ r·ªát, ph·∫£i nheo khi nh√¨n",V·ª´a,34,8
3,Kh√¥ng c√≥ d·∫•u hi·ªáu kh√¥ hay ƒë·ªè m·∫Øt,1,Kh√¥ng c√≥ d·∫•u hi·ªáu kh√¥ hay ƒë·ªè m·∫Øt,,32,8
4,"Severe eye strain, hard to keep eyes open t·ª´ s...",4,"Severe eye strain, hard to keep eyes open t·ª´ s...",N·∫∑ng,57,12


In [None]:
# Class distribution. CSV labels are 1..num_classes while LABEL_MAP is keyed by
# the 0-based class id, hence the "- 1" when looking up names.
csv_labels = list(range(1, CONFIG['num_classes'] + 1))
class_names = [LABEL_MAP[i] for i in range(CONFIG['num_classes'])]

# Reindex onto the full label range so every class gets a bar (count 0 if
# absent). Without this, a missing class shifts later bars under the wrong
# tick name, and the tick names no longer match the number of bars.
train_counts = train_df['label'].value_counts().reindex(csv_labels, fill_value=0)
val_counts = val_df['label'].value_counts().reindex(csv_labels, fill_value=0)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
panels = [
    (axes[0], train_counts, 'Training Set - Class Distribution', {}),
    (axes[1], val_counts, 'Validation Set - Class Distribution', {'color': 'orange'}),
]
for ax, counts, title, bar_kwargs in panels:
    ax.bar(range(len(counts)), counts.values, alpha=0.7, **bar_kwargs)
    ax.set_xlabel('Severity Level')
    ax.set_ylabel('Count')
    ax.set_title(title)
    ax.set_xticks(range(len(counts)))
    ax.set_xticklabels(class_names)
    ax.grid(axis='y', alpha=0.3)

plt.tight_layout()
plt.show()

print("\n📈 Class distribution:")
for split_name, split_df, counts in (
    ("Training set", train_df, train_counts),
    ("Validation set", val_df, val_counts),
):
    print(f"\n{split_name}:")
    total = max(len(split_df), 1)  # avoid ZeroDivisionError on an empty split
    for csv_label, count in counts.items():
        print(f"  {LABEL_MAP[csv_label - 1]}: {count} ({count/total*100:.2f}%)")
    # Labels outside the expected range were dropped by reindex; surface them.
    unexpected = sorted(set(split_df['label'].dropna()) - set(csv_labels))
    if unexpected:
        print(f"  WARNING: labels outside 1..{CONFIG['num_classes']} (not plotted): {unexpected}")

## 8. Prepare Data

In [None]:
# Initialize tokenizer
print(f"Loading tokenizer: {CONFIG['model_name']}")
tokenizer = AutoTokenizer.from_pretrained(CONFIG['model_name'])


def _to_class_ids(raw_labels, num_classes):
    """Shift 1-based CSV severity labels to 0-based class ids, failing fast on bad values.

    CrossEntropyLoss needs targets in [0, num_classes). Checking for missing or
    out-of-range labels here gives a readable error up front instead of a
    failure partway through training.
    """
    class_ids = raw_labels - 1
    invalid = raw_labels.isna() | (class_ids < 0) | (class_ids >= num_classes)
    if invalid.any():
        bad_values = raw_labels[invalid].unique().tolist()
        raise ValueError(
            f"Expected labels in 1..{num_classes}, found {bad_values} "
            f"in {int(invalid.sum())} row(s)"
        )
    return class_ids.tolist()


# Create datasets (labels converted from 1-4 to 0-3)
train_dataset = EyeSymptomDataset(
    train_df['text_cleaned'].tolist(),
    _to_class_ids(train_df['label'], CONFIG['num_classes']),
    tokenizer,
    max_length=CONFIG['max_length']
)

val_dataset = EyeSymptomDataset(
    val_df['text_cleaned'].tolist(),
    _to_class_ids(val_df['label'], CONFIG['num_classes']),
    tokenizer,
    max_length=CONFIG['max_length']
)

# Create dataloaders (only the training set is shuffled)
train_loader = DataLoader(
    train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=CONFIG['batch_size']
)

print(f"\n✅ Data preparation completed!")
print(f"  Training batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

## 9. Initialize Model

In [None]:
# Initialize model
print(f"Initializing model: {CONFIG['model_name']}")
model = PhoBERTClassifier(
    model_name=CONFIG['model_name'],
    num_classes=CONFIG['num_classes'],
    dropout=CONFIG['dropout']
)

# Count parameters (all are trainable unless something froze them)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\n📊 Model Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"\n🏗️  Model architecture:")
print(model)

## 10. Train Model

In [None]:
# Build the trainer from the shared CONFIG, then run the full fine-tuning loop.
# NOTE: CONFIG['warmup_ratio'] is not forwarded here.
trainer_settings = dict(
    device=device,
    learning_rate=CONFIG['learning_rate'],
    num_epochs=CONFIG['num_epochs'],
)
trainer = Module1Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    **trainer_settings,
)

history = trainer.train(save_dir=CONFIG['save_dir'])

## 11. Visualize Training History

In [None]:
# Plot training history
fig = plot_training_history(history)

# Save before plt.show() so the PNG is written regardless of what the active
# backend does to the figure when it is shown.
fig.savefig(f"{CONFIG['save_dir']}/training_history.png", dpi=300, bbox_inches='tight')
plt.show()
print(f"\n💾 Training history plot saved!")

## 12. Evaluate on Validation Set

In [None]:
# Evaluate the model currently held by the trainer on the validation set
print("\n🔍 Evaluating on validation set...")
val_results = trainer.evaluate(val_loader)

print("\n" + "="*60)
print("FINAL VALIDATION METRICS")
print("="*60)
print(f"\nOverall Metrics:")
overall_rows = [
    ('Accuracy', 'accuracy'),
    ('Precision (macro)', 'precision_macro'),
    ('Precision (weighted)', 'precision_weighted'),
    ('Recall (macro)', 'recall_macro'),
    ('Recall (weighted)', 'recall_weighted'),
    ('F1-Score (macro)', 'f1_macro'),
    ('F1-Score (weighted)', 'f1_weighted'),
]
for caption, key in overall_rows:
    print(f"  {caption}: {val_results[key]:.4f}")

print(f"\nPer-Class Metrics:")
for i in range(CONFIG['num_classes']):
    print(f"\n  Class {i} ({LABEL_MAP[i]}):")
    print(f"    Precision: {val_results[f'precision_class_{i}']:.4f}")
    print(f"    Recall: {val_results[f'recall_class_{i}']:.4f}")
    print(f"    F1-Score: {val_results[f'f1_class_{i}']:.4f}")

## 13. Detailed Classification Report

In [None]:
# Generate detailed classification report
labels_list = [LABEL_MAP[i] for i in range(CONFIG['num_classes'])]

print("\n" + "="*60)
print("DETAILED CLASSIFICATION REPORT")
print("="*60)
# Passing labels= keeps the report aligned with target_names even when a class
# is absent from both true and predicted labels. Otherwise sklearn raises,
# because the number of observed classes differs from len(target_names).
print(classification_report(
    val_results['labels'],
    val_results['predictions'],
    labels=list(range(CONFIG['num_classes'])),
    target_names=labels_list,
    digits=4,
    zero_division=0,
))

## 14. Confusion Matrix Visualization

In [None]:
# Plot confusion matrix (counts). Each figure is saved before plt.show() so the
# PNG is written regardless of what the backend does to a shown figure.
fig1 = plot_confusion_matrix(
    val_results['labels'],
    val_results['predictions'],
    labels_list,
    'Confusion Matrix - Validation Set'
)
fig1.savefig(f"{CONFIG['save_dir']}/confusion_matrix.png", dpi=300, bbox_inches='tight')
plt.show()

# Plot normalized confusion matrix (percentages)
fig2 = plot_normalized_confusion_matrix(
    val_results['labels'],
    val_results['predictions'],
    labels_list,
    'Normalized Confusion Matrix - Validation Set'
)
fig2.savefig(f"{CONFIG['save_dir']}/confusion_matrix_normalized.png", dpi=300, bbox_inches='tight')
plt.show()

print("\n💾 Confusion matrices saved!")

## 15. Per-Class Metrics Visualization

In [None]:
# Plot per-class metrics (saved before plt.show(), as in the cells above)
fig = plot_metrics_per_class(val_results, labels_list)
fig.savefig(f"{CONFIG['save_dir']}/metrics_per_class.png", dpi=300, bbox_inches='tight')
plt.show()

print("\n💾 Per-class metrics plot saved!")

## 16. Prediction Function

In [None]:
class Module1Predictor:
    """Inference wrapper: loads a trained Module 1 checkpoint and classifies raw text."""

    def __init__(
        self,
        model_path: str,
        tokenizer_name: str = "vinai/phobert-base-v2",
        device: str = 'cpu',
        num_classes: int = 4,
        weights_only: bool = True,
    ):
        """Load the tokenizer, build the classifier and restore trained weights.

        Args:
            model_path: checkpoint file. Either the dict written by
                ``Module1Trainer.save_model`` (weights under 'model_state_dict')
                or a bare state_dict.
            tokenizer_name: Hugging Face id used for both the tokenizer and the
                PhoBERT encoder.
            device: device to run inference on.
            num_classes: size of the classification head. It must match the
                checkpoint; it was previously hard-coded to 4, which is still
                the default.
            weights_only: forwarded to ``torch.load``. Leave True unless the
                checkpoint is trusted and fails to load because it pickled
                non-tensor objects such as numpy scalars. False runs arbitrary
                pickle code.
        """
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Load model
        self.model = PhoBERTClassifier(tokenizer_name, num_classes=num_classes)
        checkpoint = torch.load(model_path, map_location=device, weights_only=weights_only)
        # Trainer checkpoints wrap the weights; a bare state_dict is used as-is.
        state_dict = checkpoint.get('model_state_dict', checkpoint)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()

        # 0-based class id -> severity name
        self.severity_map = LABEL_MAP

    def _class_name(self, class_id: int) -> str:
        """Severity name for a class id, or 'class_<id>' if the map has no entry."""
        return self.severity_map.get(class_id, f"class_{class_id}")

    def predict(self, text: str, max_length: int = 128) -> Dict:
        """Predict the severity of a single text.

        Args:
            text: input text.
            max_length: tokenizer padding/truncation length.

        Returns:
            dict with 'label' (1-based, the CSV scale), 'severity' (name),
            'confidence' (probability of the predicted class), 'probabilities'
            (name -> probability for every class) and 'original_text'.
        """
        encoding = self.tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        with torch.no_grad():
            logits = self.model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)[0]  # single input -> 1-D vector

        pred_class = int(torch.argmax(probs).item())

        return {
            'label': pred_class + 1,  # convert back to the 1-4 CSV scale
            'severity': self._class_name(pred_class),
            'confidence': probs[pred_class].item(),
            'probabilities': {
                self._class_name(i): probs[i].item()
                for i in range(probs.shape[0])
            },
            'original_text': text
        }

## 17. Test Predictions

In [None]:
# Load predictor from the best checkpoint written during training
predictor = Module1Predictor(
    model_path=f"{CONFIG['save_dir']}/best_model.pt",
    device=device
)

# Test examples (Vietnamese). These were previously mis-encoded mojibake,
# which fed garbage text to the model.
test_texts = [
    "Mắt tôi hơi khô và ngứa nhẹ",                # "My eyes are a bit dry and slightly itchy"
    "Mắt rất đau và đỏ, nhìn mờ",                 # "Eyes very painful and red, blurry vision"
    "Cảm giác khô mắt một chút thôi",             # "Just a slight feeling of dry eyes"
    "Mắt bình thường, không có triệu chứng gì",   # "Eyes normal, no symptoms at all"
]

print("\n" + "="*60)
print("TEST PREDICTIONS")
print("="*60)

for i, text in enumerate(test_texts, 1):
    result = predictor.predict(text)

    print(f"\n{i}. Text: {text}")
    print(f"   Predicted Severity: {result['severity']} (Label: {result['label']})")
    print(f"   Confidence: {result['confidence']:.4f}")
    print(f"   Probabilities:")
    for severity, prob in result['probabilities'].items():
        print(f"     {severity}: {prob:.4f}")

## 18. Save Final Results

In [None]:
# Save final results. Metric values are cast to float so the JSON never depends
# on numpy scalar types.
final_metric_keys = [
    'accuracy',
    'precision_macro',
    'precision_weighted',
    'recall_macro',
    'recall_weighted',
    'f1_macro',
    'f1_weighted',
]
results_summary = {
    'model_name': CONFIG['model_name'],
    'num_classes': CONFIG['num_classes'],
    'training_config': CONFIG,
    'best_val_f1': float(trainer.best_val_f1),
    'final_metrics': {key: float(val_results[key]) for key in final_metric_keys},
    'per_class_metrics': {
        LABEL_MAP[i]: {
            'precision': float(val_results[f'precision_class_{i}']),
            'recall': float(val_results[f'recall_class_{i}']),
            'f1': float(val_results[f'f1_class_{i}'])
        }
        for i in range(CONFIG['num_classes'])
    }
}

with open(f"{CONFIG['save_dir']}/results_summary.json", 'w', encoding='utf-8') as f:
    json.dump(results_summary, f, indent=2, ensure_ascii=False)

print("\n" + "="*60)
print("✅ TRAINING COMPLETED SUCCESSFULLY!")
print("="*60)
print(f"\n📁 All results saved to: {CONFIG['save_dir']}")
print(f"\n📊 Files saved:")
for saved_file in (
    "best_model.pt",
    "training_history.json",
    "results_summary.json",
    "training_history.png",
    "confusion_matrix.png",
    "confusion_matrix_normalized.png",
    "metrics_per_class.png",
):
    print(f"  - {saved_file}")

## 19. Summary

### What this notebook does:

1. **Data Loading**: Loads training and validation datasets for eye symptom classification
2. **Data Preprocessing**: Tokenizes Vietnamese text using PhoBERT tokenizer
3. **Model Architecture**: Uses PhoBERT with a custom classification head
4. **Training**: Fine-tunes the model with AdamW optimizer and learning rate scheduling
5. **Evaluation**: Comprehensive metrics including:
   - Accuracy
   - Precision (macro & weighted)
   - Recall (macro & weighted)
   - F1-Score (macro & weighted)
   - Per-class metrics
   - Confusion Matrix
6. **Visualization**: Training curves, confusion matrices, and per-class performance
7. **Inference**: Easy-to-use predictor class for making predictions on new text

### Key Features:
- Multi-class classification (4 severity levels)
- Vietnamese language support via PhoBERT
- Comprehensive evaluation metrics
- Beautiful visualizations
- Model checkpointing
- Easy inference interface