In [1]:
%%capture
!pip install transformers torch torchvision torchaudio numpy

In [2]:
import torch
from transformers import AutoModel, AutoTokenizer, BertForQuestionAnswering, BertModel, BertTokenizer
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
import torch.optim as optim
from datasets import load_dataset
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
class ClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, qa_prompt="What is the sentiment?"):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.qa_prompt = qa_prompt

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
      inputs = self.tokenizer(self.qa_prompt, self.texts[idx], return_tensors="pt",
                                 padding="max_length", truncation=True, max_length=512).to(device)
      inputs = {k: v.squeeze(0) for k, v in inputs.items()}
      return inputs, torch.tensor(self.labels[idx], dtype=torch.long)

In [5]:
# 2. Định nghĩa các mô hình
class ClassificationModel(nn.Module):
    def __init__(self, input_size, num_classes, hidden_size=256, dropout_prob=0.1):
      super(ClassificationModel, self).__init__()
      self.fc1 = nn.Linear(input_size, hidden_size)
      self.bn1 = nn.BatchNorm1d(hidden_size)
      self.dropout1 = nn.Dropout(dropout_prob)
      self.fc2 = nn.Linear(hidden_size, hidden_size // 2)
      self.bn2 = nn.BatchNorm1d(hidden_size // 2)
      self.dropout2 = nn.Dropout(dropout_prob)
      self.fc3 = nn.Linear(hidden_size // 2, num_classes)

    def forward(self, x):
      x = F.relu(self.fc1(x))
      x = self.bn1(x)
      x = self.dropout1(x)
      x = F.relu(self.fc2(x))
      x = self.bn2(x)
      x = self.dropout2(x)
      x = self.fc3(x)
      return x

In [6]:
# 3. Khởi tạo các model
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
qa_model = AutoModel.from_pretrained("chuthienlong/pretrain_squad")

classification_model = ClassificationModel(768, 5) # 5 class, embedding 768

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



config.json:   0%|          | 0.00/561 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/265M [00:00<?, ?B/s]

In [7]:
# 4. Load datasets
sst5_dataset = load_dataset("SetFit/sst5", split="train[:1000]")
texts = sst5_dataset["text"]
labels = sst5_dataset["label"]
classification_dataset = ClassificationDataset(texts, labels, tokenizer)
classification_dataloader = DataLoader(classification_dataset, batch_size=128, shuffle = True)

README.md:   0%|          | 0.00/421 [00:00<?, ?B/s]

Repo card metadata block was not found. Setting CardData to empty.


train.jsonl:   0%|          | 0.00/1.32M [00:00<?, ?B/s]

dev.jsonl:   0%|          | 0.00/171k [00:00<?, ?B/s]

test.jsonl:   0%|          | 0.00/343k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8544 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1101 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2210 [00:00<?, ? examples/s]

In [8]:
# 5. Train model phân loại
classification_model.to(device)
qa_model.to(device)
optimizer_classification = optim.AdamW(classification_model.parameters(), lr = 1e-4)
scheduler_classification = ReduceLROnPlateau(optimizer_classification, mode="max", factor=0.5, patience=3, verbose=True)
num_epochs_classification = 20
criterion = nn.CrossEntropyLoss()

best_f1 = 0
epochs_no_improve = 0
patience = 5

for epoch in range(num_epochs_classification):
  classification_model.train()
  all_predictions = []
  all_labels = []
  for inputs, labels in tqdm(classification_dataloader, desc=f"Epoch {epoch+1} Classification Model"):
    labels=labels.to(device)
    with torch.no_grad():
        outputs = qa_model(**inputs)
        logits = outputs.last_hidden_state.mean(dim=1).to(device)
        concatenated_logits = logits.to(device)
    outputs = classification_model(concatenated_logits)
    loss = criterion(outputs, labels)
    optimizer_classification.zero_grad()
    loss.backward()
    optimizer_classification.step()
    predictions = torch.argmax(outputs, dim=1)
    all_predictions.extend(predictions.cpu().numpy())
    all_labels.extend(labels.cpu().numpy())
  classification_model.eval()
  accuracy = accuracy_score(all_labels, all_predictions)
  f1 = f1_score(all_labels, all_predictions, average="weighted")
  scheduler_classification.step(f1)
  print(f"Epoch {epoch+1} -  Accuracy: {accuracy:.4f}, F1: {f1:.4f}")
  if f1 > best_f1:
    best_f1 = f1
    epochs_no_improve = 0
  else:
    epochs_no_improve += 1
  if epochs_no_improve > patience:
    print(f"Early stopping at epoch {epoch+1}")
    break
print("Completed")

Epoch 1 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.23s/it]


Epoch 1 -  Accuracy: 0.2300, F1: 0.2294


Epoch 2 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.16s/it]


Epoch 2 -  Accuracy: 0.2880, F1: 0.2865


Epoch 3 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.16s/it]


Epoch 3 -  Accuracy: 0.3640, F1: 0.3638


Epoch 4 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.17s/it]


Epoch 4 -  Accuracy: 0.3870, F1: 0.3856


Epoch 5 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.17s/it]


Epoch 5 -  Accuracy: 0.4040, F1: 0.4017


Epoch 6 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.17s/it]


Epoch 6 -  Accuracy: 0.4490, F1: 0.4476


Epoch 7 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.18s/it]


Epoch 7 -  Accuracy: 0.4890, F1: 0.4859


Epoch 8 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.18s/it]


Epoch 8 -  Accuracy: 0.4990, F1: 0.4990


Epoch 9 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.18s/it]


Epoch 9 -  Accuracy: 0.5160, F1: 0.5140


Epoch 10 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 10 -  Accuracy: 0.5530, F1: 0.5480


Epoch 11 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 11 -  Accuracy: 0.5580, F1: 0.5557


Epoch 12 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 12 -  Accuracy: 0.5640, F1: 0.5628


Epoch 13 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 13 -  Accuracy: 0.5950, F1: 0.5931


Epoch 14 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 14 -  Accuracy: 0.6000, F1: 0.5973


Epoch 15 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 15 -  Accuracy: 0.6260, F1: 0.6224


Epoch 16 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 16 -  Accuracy: 0.6240, F1: 0.6201


Epoch 17 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 17 -  Accuracy: 0.6520, F1: 0.6509


Epoch 18 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 18 -  Accuracy: 0.6610, F1: 0.6590


Epoch 19 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]


Epoch 19 -  Accuracy: 0.6540, F1: 0.6514


Epoch 20 Classification Model: 100%|██████████| 8/8 [00:09<00:00,  1.19s/it]

Epoch 20 -  Accuracy: 0.6720, F1: 0.6709
Completed



