In [18]:
import torch
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, precision_recall_curve, average_precision_score, roc_auc_score

In [19]:
# Read the merged corpus from disk
merged_dataset = pd.read_csv('../data/merged.csv')

# Give every distinct intent an integer id, numbered in order of first appearance
labels = merged_dataset['intent'].unique().tolist()
label_map = dict(zip(labels, range(len(labels))))
merged_dataset['encoded_label'] = merged_dataset['intent'].map(label_map)

# Select the rows of each partition ('train', 'val', 'test') via the 'partition' column
train_data, val_data, test_data = (
    merged_dataset.loc[merged_dataset['partition'] == split_name]
    for split_name in ('train', 'val', 'test')
)


In [20]:
# Fetch the tokenizer that belongs to the multilingual BERT checkpoint
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')


def encode_split(split_frame):
    """Tokenize one partition and tensorize its labels.

    Args:
        split_frame: DataFrame with an 'utt' column (texts) and an
            'encoded_label' column (integer class ids).

    Returns:
        (encodings, labels): the tokenizer output for all utterances (called
        with truncation and padding enabled) and a tensor of the labels.
    """
    encodings = tokenizer(list(split_frame['utt']), truncation=True, padding=True)
    return encodings, torch.tensor(list(split_frame['encoded_label']))


# Training partition
train_encodings, train_labels = encode_split(train_data)

# Test partition
test_encodings, test_labels = encode_split(test_data)


Downloading model.safetensors:   6%|▌         | 41.9M/672M [55:46<13:58:03, 12.5kB/s]


In [21]:
# Fine-tuning with few-shot learning
# Run on the GPU when one is visible, otherwise on the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Classification head has 210 outputs (hard-coded); the integer ids produced by
# label_map are used as targets, so they must all be below this number.
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-uncased', num_labels=210)
model.to(device)

# Bundle token ids, attention masks and labels so they are batched together
train_input_ids = torch.tensor(train_encodings['input_ids'])
train_attention_mask = torch.tensor(train_encodings['attention_mask'])
train_dataset = torch.utils.data.TensorDataset(train_input_ids, train_attention_mask, train_labels)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

epochs = 3

for epoch in range(epochs):
    model.train()
    # Running totals for this epoch: summed batch losses, hits, and samples seen
    train_loss = correct = total = 0

    for batch in train_loader:
        input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)

        # One optimisation step; passing labels makes the model return the loss itself
        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits
        loss.backward()
        optimizer.step()

        # Bookkeeping uses the logits computed before the parameter update
        predictions = logits.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
        train_loss += loss.item()

    # Report the mean of the per-batch losses and the epoch-level training accuracy
    train_loss /= len(train_loader)
    accuracy = correct / total
    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Accuracy = {accuracy:.4f}")

# Persist only the weights (state dict), not the full model object
torch.save(model.state_dict(), '../models/few-shot-bert.pt')

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingu

Epoch 1: Train Loss = 3.9039, Accuracy = 0.3344
Epoch 2: Train Loss = 2.0450, Accuracy = 0.6817
Epoch 3: Train Loss = 1.1126, Accuracy = 0.8059


In [25]:
# Data Augmentation to improve model's accuracy
# Random Word Masking: masking random words in the input encourages the model
# to learn the context of the sentence

import random

# Rebuild the classifier and restore the fine-tuned weights written by the
# previous training stage, so this cell can also start from the checkpoint.
# map_location=device remaps the stored tensors onto the current device; without
# it, a checkpoint saved from a GPU cannot be deserialized on a CPU-only kernel.
model = BertForSequenceClassification.from_pretrained('bert-base-multilingual-uncased', num_labels=210)
model.load_state_dict(torch.load('../models/few-shot-bert.pt', map_location=device))
model.to(device)

def random_word_masking(input_ids, attention_mask, mask_prob=0.15,
                        mask_token_id=None, special_token_ids=None):
    """Return a copy of ``input_ids`` with random tokens replaced by the mask token.

    Each maskable position is replaced independently with probability
    ``mask_prob``. A position is maskable when its attention mask is 1 and its
    token id is not one of ``special_token_ids``; special tokens and padding
    are therefore always left intact. ``input_ids`` itself is not modified.

    Args:
        input_ids: integer tensor of token ids, one row per sequence.
        attention_mask: tensor of the same shape, 1 for real tokens, 0 for padding.
        mask_prob: per-token probability of being replaced.
        mask_token_id: id written into masked positions; defaults to
            ``tokenizer.mask_token_id``.
        special_token_ids: iterable of ids that must never be replaced;
            defaults to the tokenizer's CLS, SEP and PAD ids.

    Returns:
        A new tensor with the same shape, dtype and device as ``input_ids``.
    """
    # Only fall back to the notebook-level tokenizer when the caller gave no ids
    if mask_token_id is None:
        mask_token_id = tokenizer.mask_token_id
    if special_token_ids is None:
        special_token_ids = (
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
        )

    # Start from "every attended position", then carve out the special tokens
    maskable = attention_mask.to(input_ids.device) == 1
    for special_id in special_token_ids:
        if special_id is not None:
            maskable &= (input_ids != special_id)

    # One uniform draw per position; torch.rand is in [0, 1), so mask_prob=1.0
    # selects every maskable token and mask_prob=0.0 selects none
    selected = torch.rand(input_ids.shape, device=input_ids.device) < mask_prob

    # masked_fill is out-of-place, so the caller's tensor stays untouched
    return input_ids.masked_fill(maskable & selected, mask_token_id)

# Fresh optimizer for the augmentation stage
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
additional_epochs = 3

for epoch in range(additional_epochs):
    model.train()
    # Running totals for this epoch: summed batch losses, hits, and samples seen
    train_loss = correct = total = 0

    for input_ids, attention_mask, labels in train_loader:
        # Build the masked copy from the batch as delivered by the loader,
        # then move everything to the training device
        masked_input_ids = random_word_masking(input_ids, attention_mask).to(device)
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # Two forward passes per batch: clean input first, masked input second
        outputs_orig = model(input_ids, attention_mask=attention_mask, labels=labels)
        outputs_masked = model(masked_input_ids, attention_mask=attention_mask, labels=labels)

        # Train on the average of both losses
        loss = (outputs_orig.loss + outputs_masked.loss) * 0.5
        loss.backward()
        optimizer.step()

        # Accuracy is tracked on the clean (unmasked) input only
        predictions = outputs_orig.logits.argmax(dim=1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
        train_loss += loss.item()

    # Report the mean of the per-batch losses and the epoch-level training accuracy
    train_loss /= len(train_loader)
    accuracy = correct / total
    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Accuracy = {accuracy:.4f}")

# Overwrite the checkpoint with the further-trained weights
torch.save(model.state_dict(), '../models/few-shot-bert.pt')


Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingu

Epoch 1: Train Loss = 0.8649, Accuracy = 0.8746
Epoch 2: Train Loss = 0.6202, Accuracy = 0.9065
Epoch 3: Train Loss = 0.4824, Accuracy = 0.9277


In [26]:
# Evaluation
model.eval()

# Wrap the encoded test partition; order is kept fixed (no shuffling)
test_input_ids = torch.tensor(test_encodings['input_ids'])
test_attention_mask = torch.tensor(test_encodings['attention_mask'])
test_dataset = torch.utils.data.TensorDataset(test_input_ids, test_attention_mask, test_labels)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

test_loss = 0.0
correct_predictions = 0
total_predictions = 0

# Gradients are not needed for scoring
with torch.no_grad():
    for batch in test_loader:
        input_ids, attention_mask, labels = (tensor.to(device) for tensor in batch)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        logits = outputs.logits

        predictions = torch.argmax(logits, dim=1)
        test_loss += loss.item()
        correct_predictions += (predictions == labels).sum().item()
        total_predictions += len(labels)

# Mean of the per-batch losses, and the fraction of correctly classified samples
test_loss /= len(test_loader)
accuracy = correct_predictions / total_predictions

print(f"Test Loss = {test_loss:.4f}")
print(f"Accuracy = {accuracy:.4f}")


Test Loss = 0.5355
Accuracy = 0.8629


In [28]:
# Evaluate the model
# Per-sample accumulators: predicted class id, true class id, and the full
# class-probability vector (kept for ROC and precision-recall curves)
predictions, true_labels, probs = [], [], []

model.eval()

with torch.no_grad():
    for input_ids, attention_mask, labels in test_loader:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        # No labels are passed here, so only logits are produced
        batch_logits = model(input_ids, attention_mask=attention_mask).logits
        probabilities = torch.softmax(batch_logits, dim=1)
        preds = batch_logits.argmax(dim=1)

        predictions += preds.cpu().tolist()
        true_labels += labels.tolist()
        probs += probabilities.cpu().tolist()

# Stack the per-sample probability lists into one 2-D numpy array
probs = np.array(probs)

# Number of classes is the width of the probability matrix collected above.
num_classes = probs.shape[1]

# Class ids are 0-based (label_map is built with enumerate), so the valid ids
# are 0 .. num_classes - 1.
labels = np.arange(num_classes)

# Print classification report; zero_division=0 reports 0 (instead of warning)
# for classes that never occur among the predictions or the true labels
report = classification_report(true_labels, predictions, zero_division=0)
print(report)

# Confusion matrix over the full id range, so row/column i always means class i
# even when some classes are missing from the test partition
cm = confusion_matrix(true_labels, predictions, labels=labels)

# Row-normalise; rows with no true samples stay all-zero instead of becoming NaN
row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm,
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

# Per-class counts; minlength pads with zeros so both arrays have exactly
# num_classes entries and line up with `labels` on the x axis
true_label_counts = np.bincount(true_labels, minlength=num_classes)
predicted_label_counts = np.bincount(predictions, minlength=num_classes)


def plot_label_distribution(label_ids, counts, color, series_name):
    """Draw one bar chart of per-class counts.

    Args:
        label_ids: x positions, one per class.
        counts: bar heights, same length as ``label_ids``.
        color: matplotlib colour of the bars.
        series_name: legend entry, also used in the title.
    """
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.bar(label_ids, counts, color=color, alpha=0.7, label=series_name)
    ax.set_xlabel('Label')
    ax.set_ylabel('Count')
    ax.set_title(f'Distribution of {series_name}')
    ax.legend()
    plt.show()


# Plot bar chart for true labels
plot_label_distribution(labels, true_label_counts, 'blue', 'True Labels')

# Plot bar chart for predicted labels
plot_label_distribution(labels, predicted_label_counts, 'green', 'Predicted Labels')

KeyboardInterrupt: 