In [1]:
### Fine-tune PhoBERT with AttentionPooling for Text Classification
!pip install torch torchvision torchaudio --quiet
!pip install transformers --quiet
!pip install scikit-learn --quiet

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel, get_scheduler
from torch.optim import AdamW
from sklearn.metrics import classification_report
import pandas as pd
import numpy as np
from tqdm import tqdm
import os
import json

### Custom Attention Pooling Layer
import torch.nn as nn

class AttentionPooling(nn.Module):
    """Pool a sequence of token vectors into one vector via learned attention.

    A single linear layer assigns one score to every token. The scores are
    softmax-normalised along the sequence axis (with padding excluded) and
    used as weights for a weighted sum of the token vectors.
    """

    def __init__(self, hidden_size):
        """Create the scoring layer for token vectors of size ``hidden_size``."""
        super().__init__()
        # Maps each token's hidden vector to one unnormalised attention score.
        self.attn = nn.Linear(hidden_size, 1)

    def forward(self, hidden_states, mask):
        """Return the attention-weighted sum of ``hidden_states``.

        Args:
            hidden_states: tensor of shape (batch, seq_len, hidden).
            mask: tensor of shape (batch, seq_len); entries equal to 0 mark
                positions (padding) that must receive no attention weight.

        Returns:
            Tensor of shape (batch, hidden).

        Note:
            A row whose mask is entirely 0 ends up with identical scores at
            every position, so softmax yields uniform weights for that row.
        """
        scores = self.attn(hidden_states).squeeze(-1)  # (batch, seq_len)

        # Fill masked positions with the most negative value representable in
        # the scores' own dtype. A hard-coded -1e9 does not fit in float16 and
        # makes masked_fill raise when running in half precision / autocast.
        fill_value = torch.finfo(scores.dtype).min
        scores = scores.masked_fill(mask == 0, fill_value)

        weights = torch.softmax(scores, dim=1)  # (batch, seq_len)
        # (batch, 1, seq_len) @ (batch, seq_len, hidden) -> (batch, 1, hidden)
        pooled = torch.bmm(weights.unsqueeze(1), hidden_states).squeeze(1)
        return pooled


### PhoBERT + AttentionPooling
class PhoBERTWithAttention(nn.Module):
    """Sequence classifier: pretrained encoder -> attention pooling -> linear head.

    ``forward`` returns a dict with the keys ``"loss"`` and ``"logits"``;
    the loss is ``None`` unless ``labels`` is supplied.
    """

    def __init__(self, model_name="vinai/phobert-base", num_labels=3):
        super().__init__()
        # Submodule names and creation order are kept stable: they determine
        # the keys of state_dict() and the order of parameter initialisation.
        self.phobert = AutoModel.from_pretrained(model_name)
        encoder_dim = self.phobert.config.hidden_size
        self.pooling = AttentionPooling(encoder_dim)
        self.fc = nn.Linear(encoder_dim, num_labels)

    def forward(self, input_ids, attention_mask, labels=None):
        """Encode the batch, pool over tokens and classify.

        The attention mask is passed both to the encoder and to the pooling
        layer. When ``labels`` is given, the cross-entropy between the logits
        and the labels is returned under ``"loss"``.
        """
        token_states = self.phobert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state  # (batch, seq_len, hidden)

        logits = self.fc(self.pooling(token_states, attention_mask))

        if labels is None:
            return {"loss": None, "logits": logits}
        return {
            "loss": nn.functional.cross_entropy(logits, labels),
            "logits": logits,
        }


### Load tokenizer and build the model
MODEL_NAME = "vinai/phobert-base"
# use_fast=False selects the slow (Python) tokenizer implementation.
# trust_remote_code is deliberately left at its default (False): the PhoBERT
# tokenizer is implemented inside the transformers library, so there is no
# need to allow execution of code downloaded from the Hub repository.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=False)
# Three output classes, matching the three labels used by the dataset.
model = PhoBERTWithAttention(MODEL_NAME, num_labels=3)

### Dataset class
class JsonlDataset(Dataset):
    """Text-classification dataset backed by a JSON Lines file.

    Every non-blank line of the file is parsed as one JSON object that must
    provide ``"messages"`` (a list of objects with a ``"content"`` string)
    and ``"label"`` (one of the keys of ``LABEL_MAP``).
    """

    # Label string -> class index. Defined once on the class instead of being
    # rebuilt on every __getitem__ call.
    LABEL_MAP = {"no": 0, "extrinsic": 1, "intrinsic": 2}

    def __init__(self, path, tokenizer, max_len=256):
        """Read all samples from ``path`` into memory.

        Args:
            path: path of the UTF-8 encoded JSONL file.
            tokenizer: callable invoked as ``tokenizer(text, truncation=True,
                padding="max_length", max_length=max_len,
                return_tensors="pt")``; its result is indexed with
                ``"input_ids"`` and ``"attention_mask"``.
            max_len: length every sample is padded / truncated to.
        """
        self.tokenizer = tokenizer
        self.max_len = max_len
        with open(path, "r", encoding="utf-8") as f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # are not valid JSON and would make json.loads raise; skip them.
            self.samples = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the tokenised sample at ``idx``.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"`` (the tokenizer
            output without its leading batch axis) and ``"labels"`` (a scalar
            ``torch.long`` tensor).

        Raises:
            KeyError: if the sample's label is not a key of ``LABEL_MAP``.
        """
        sample = self.samples[idx]
        # All message contents are joined into one space-separated string.
        text = " ".join(m["content"] for m in sample["messages"])
        label = self.LABEL_MAP[sample["label"]]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )

        # squeeze(0) removes only the batch axis added by return_tensors="pt".
        # A bare squeeze() would also drop the sequence axis when max_len == 1.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            "labels": torch.tensor(label, dtype=torch.long),
        }


### Build the three dataset splits and their batch loaders
# The splits are read in the order train -> val -> test, all with the shared
# tokenizer and the dataset's default maximum length.
train_dataset, val_dataset, test_dataset = (
    JsonlDataset(split_path, tokenizer)
    for split_path in ("../train.jsonl", "../val.jsonl", "../test.jsonl")
)

# Only the training split is shuffled; val/test keep the file order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=16)
val_loader, test_loader = (
    DataLoader(split, shuffle=False, batch_size=16)
    for split in (val_dataset, test_dataset)
)


### Optimisation setup: device, optimizer and learning-rate schedule
num_epochs = 3

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)

# The optimizer is created after the model has been moved to its device.
optimizer = AdamW(model.parameters(), lr=2e-5)

# One scheduler step is taken per batch, so the schedule spans every batch of
# every epoch; no warm-up phase is used.
num_training_steps = len(train_loader) * num_epochs
scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)


### Training loop
# Runs num_epochs passes over train_loader. After every epoch the model
# weights and the tokenizer files are written to a per-epoch sub-directory of
# save_dir; after the last epoch they are written once more to SAVE_DIR.
save_dir = "./checkpoints_attention"
os.makedirs(save_dir, exist_ok=True)

for epoch in range(num_epochs):
    print(f"\n===== Epoch {epoch+1}/{num_epochs} =====")
    
    # Switch the model to training mode.
    model.train()
    total_loss = 0  # sum of the per-batch losses seen so far in this epoch
    loop = tqdm(enumerate(train_loader), total=len(train_loader), desc="Training", leave=True)
    for step, batch in loop:
        # Move the batch tensors to the device the model lives on.
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs["loss"]

        # Clear the gradients of the previous step, back-propagate, update the
        # weights, then advance the learning-rate schedule (once per batch).
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Show the current batch loss and the running mean in the progress bar.
        total_loss += loss.item()
        avg_loss = total_loss / (step + 1)
        loop.set_postfix(loss=loss.item(), avg_loss=avg_loss)

    # Mean of the per-batch losses over the whole epoch.
    avg_train_loss = total_loss / len(train_loader)
    print(f"Average training loss for epoch {epoch+1}: {avg_train_loss:.4f}")

    # Save a checkpoint for this epoch. Only the state_dict (weights) is
    # stored, so loading it requires constructing the model class first.
    epoch_save_path = os.path.join(save_dir, f"epoch_{epoch+1}")
    os.makedirs(epoch_save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(epoch_save_path, "pytorch_model.bin"))
    tokenizer.save_pretrained(epoch_save_path)
    print(f"Saved checkpoint to {epoch_save_path}")

# Save the final weights and tokenizer files to a separate directory.
SAVE_DIR = "./model_attention"
os.makedirs(SAVE_DIR, exist_ok=True)
torch.save(model.state_dict(), os.path.join(SAVE_DIR, "pytorch_model.bin"))
tokenizer.save_pretrained(SAVE_DIR)
print(f"Final model saved to {SAVE_DIR}")

  from .autonotebook import tqdm as notebook_tqdm



===== Epoch 1/3 =====


Training: 100%|██████████| 10/10 [02:05<00:00, 12.58s/it, avg_loss=1.11, loss=1.11]


Average training loss for epoch 1: 1.1148
Saved checkpoint to ./checkpoints_attention\epoch_1

===== Epoch 2/3 =====


Training: 100%|██████████| 10/10 [01:55<00:00, 11.56s/it, avg_loss=1.08, loss=1.06]


Average training loss for epoch 2: 1.0769
Saved checkpoint to ./checkpoints_attention\epoch_2

===== Epoch 3/3 =====


Training: 100%|██████████| 10/10 [01:52<00:00, 11.29s/it, avg_loss=1.07, loss=1.09]


Average training loss for epoch 3: 1.0687
Saved checkpoint to ./checkpoints_attention\epoch_3
Final model saved to ./model_attention
