Code adapted from the article [Text classification with BERT](https://medium.com/@khang.pham.exxact/text-classification-with-bert-7afaacc5e49b).

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def load_data(data_file):
    """Read a JSON-lines headline file and return its texts and labels.

    Args:
        data_file: Location of a file holding one JSON object per line, each
            with a ``headline`` and an ``is_sarcastic`` field.

    Returns:
        A ``(texts, labels)`` pair of plain Python lists, row-aligned, taken
        from the ``headline`` and ``is_sarcastic`` columns respectively.
    """
    frame = pd.read_json(data_file, lines=True)
    headline_column, label_column = frame['headline'], frame['is_sarcastic']
    return headline_column.tolist(), label_column.tolist()

In [4]:
# Relative path of the JSON-lines headline file used for training/validation.
data_file = 'datasets/Sarcasm_Headlines_Dataset.json'
# Parallel lists: texts[i] is a headline, labels[i] its is_sarcastic value.
texts, labels = load_data(data_file)

In [5]:
class TextClassificationDataset(Dataset):
    """Dataset that tokenizes a single text each time an item is requested.

    Args:
        texts: Indexable collection of input strings.
        labels: Indexable collection of labels, aligned with ``texts``.
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors='pt', max_length=..., padding='max_length', truncation=True)``
            whose result is indexable by ``'input_ids'`` and ``'attention_mask'``.
        max_length: Length forwarded to the tokenizer for padding/truncation.
    """

    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        """Number of samples, defined by the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the tokenized sample at ``idx`` as a dict of tensors.

        The dict has the keys ``input_ids``, ``attention_mask`` and ``labels``.
        """
        sample_text, sample_label = self.texts[idx], self.labels[idx]
        tokenizer_options = dict(
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        encoded = self.tokenizer(sample_text, **tokenizer_options)

        # flatten() collapses each encoded tensor to 1-D so that batching can
        # stack samples along a new leading axis.
        item = {key: encoded[key].flatten() for key in ('input_ids', 'attention_mask')}
        item['labels'] = torch.tensor(sample_label)
        return item

In [6]:
class BERTClassifier(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear classification layer.

    Args:
        bert_model_name: Checkpoint identifier handed to ``BertModel.from_pretrained``.
        num_classes: Number of output logits produced per input sequence.
    """

    def __init__(self, bert_model_name, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        # The head maps the encoder's hidden size to one logit per class.
        hidden_size = self.bert.config.hidden_size
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return unnormalized class logits for a batch of encoded inputs."""
        encoder_out = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Classification uses the encoder's pooled output, regularized by dropout.
        return self.fc(self.dropout(encoder_out.pooler_output))

In [7]:
def train(model, data_loader, optimizer, scheduler, device):
    """Run one training epoch and report the mean batch loss.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns class logits.
        data_loader: Iterable of batches; each batch is a mapping with the
            keys ``input_ids``, ``attention_mask`` and ``labels``.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per batch, right after
            the optimizer.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        float: Mean cross-entropy loss over the batches seen, or ``nan`` when
        ``data_loader`` yields no batches. Callers may ignore the value.
    """
    model.train()
    # The criterion is stateless, so build it once instead of once per batch.
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0.0
    batch_count = 0

    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        loss_sum += loss.item()
        batch_count += 1

    # Avoid a ZeroDivisionError on an empty loader; nan signals "no data".
    return loss_sum / batch_count if batch_count else float('nan')

In [9]:
def evaluate(model, data_loader, device):
    """Score the model on every batch of ``data_loader``.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns class logits of shape (batch, classes).
        data_loader: Iterable of batches; each batch is a mapping with the
            keys ``input_ids``, ``attention_mask`` and ``labels`` (tensors).
        device: Device the model inputs are moved to.

    Returns:
        tuple: ``(accuracy, report)`` as produced by ``accuracy_score`` and
        ``classification_report`` on the collected labels and predictions.
    """
    model.eval()
    predictions, true_labels = [], []

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            logits = model(input_ids, attention_mask)
            # The predicted class is the index of the largest logit per row.
            predictions.extend(logits.argmax(dim=1).tolist())
            # Labels are only compared on the host, so read them straight from
            # the batch instead of copying them to the device and back.
            true_labels.extend(batch['labels'].tolist())

    accuracy = accuracy_score(true_labels, predictions)
    report = classification_report(true_labels, predictions)
    return accuracy, report

In [10]:
def predict_sarcasm(text, model, tokenizer, device, max_length=128):
    """Classify a single headline as sarcastic or not.

    Args:
        text: The headline to classify.
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns class logits.
        tokenizer: Callable that encodes ``text`` into ``input_ids`` and
            ``attention_mask`` tensors.
        device: Device the encoded tensors are moved to.
        max_length: Padding/truncation length forwarded to the tokenizer.

    Returns:
        str: ``"sarcastic"`` when class 1 has the largest logit, otherwise
        ``"not sarcastic"``.
    """
    model.eval()
    encoded = tokenizer(
        text,
        return_tensors='pt',
        max_length=max_length,
        padding='max_length',
        truncation=True,
    )
    input_ids, attention_mask = (
        encoded[key].to(device) for key in ('input_ids', 'attention_mask')
    )

    with torch.no_grad():
        logits = model(input_ids, attention_mask)
        predicted_class = logits.argmax(dim=1)

    if predicted_class.item() == 1:
        return "sarcastic"
    return "not sarcastic"

In [11]:
# Hyperparameters for fine-tuning; every later cell reads these names.
bert_model_name = 'bert-base-uncased'  # checkpoint for both tokenizer and encoder
num_classes = 2        # the validation report lists labels 0 and 1
max_length = 128       # tokenizer padding/truncation length
batch_size = 16
num_epochs = 4
learning_rate = 2e-5

# Seed PyTorch's global RNG so that classifier-head initialization, dropout
# and DataLoader shuffling are repeatable across Restart & Run All.
random_seed = 42
torch.manual_seed(random_seed);

In [12]:
# Tokenizer matching the encoder checkpoint.
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# Hold out 20% of the headlines for validation; the fixed random_state keeps
# the split identical between runs.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

# Training data is reshuffled on every pass over the loader.
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation data keeps its order.
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size)

In [18]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the classifier, then move its parameters to the selected device.
model = BERTClassifier(bert_model_name, num_classes)
model = model.to(device)

In [14]:
# Use PyTorch's AdamW: the AdamW class exported by transformers is deprecated
# in favour of torch.optim.AdamW. eps and weight_decay are passed explicitly
# to keep the hyperparameters the transformers implementation used by default
# (torch's own defaults are eps=1e-8 and weight_decay=1e-2).
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, eps=1e-6, weight_decay=0.0)

# The scheduler is stepped once per batch, so the schedule spans every batch
# of every epoch; with no warmup the rate decays linearly from learning_rate.
total_steps = len(train_data_loader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)



In [18]:
# Fine-tune for num_epochs passes, validating after each one.
for epoch in range(num_epochs):
    print(f"Epoch: {epoch + 1}/{num_epochs}")
    train(model, train_data_loader, optimizer, scheduler, device)
    accuracy, report = evaluate(model, val_data_loader, device)
    # Accuracy line first, then the per-class report on the following lines.
    print(f'Validation accuracy: {accuracy:.4f}', report, sep='\n')

Epoch: 1/4
Validation accuracy: 0.9148
              precision    recall  f1-score   support

           0       0.89      0.97      0.93      2996
           1       0.96      0.84      0.90      2346

    accuracy                           0.91      5342
   macro avg       0.92      0.91      0.91      5342
weighted avg       0.92      0.91      0.91      5342

Epoch: 2/4
Validation accuracy: 0.9339
              precision    recall  f1-score   support

           0       0.94      0.94      0.94      2996
           1       0.92      0.93      0.92      2346

    accuracy                           0.93      5342
   macro avg       0.93      0.93      0.93      5342
weighted avg       0.93      0.93      0.93      5342

Epoch: 3/4
Validation accuracy: 0.9292
              precision    recall  f1-score   support

           0       0.92      0.96      0.94      2996
           1       0.94      0.90      0.92      2346

    accuracy                           0.93      5342
   macro av

In [19]:
# Persist only the model's state_dict (not the module object) to a file
# addressed relative to the current working directory.
torch.save(model.state_dict(), 'bert_sarcasm_model.pth')

In [19]:
# Restore the saved weights into the existing model instance.
# map_location remaps the stored tensors onto the active device, so a
# checkpoint written on a GPU machine also loads on a CPU-only one.
# weights_only=True restricts unpickling to tensors and plain containers.
model.load_state_dict(torch.load('bert_sarcasm_model.pth', map_location=device, weights_only=True))
# Switch to inference mode (disables dropout); the cell displays the module.
model.eval()

BERTClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwis

In [20]:
# Spot-check the fine-tuned model on one headline.
test_text = "thirtysomething scientists unveil doomsday clock of hair loss"
sentiment = predict_sarcasm(test_text, model, tokenizer, device)
# Headline on the first line, verdict on the second.
print(test_text, f"Sarcastic?: {sentiment}", sep="\n")

thirtysomething scientists unveil doomsday clock of hair loss
Sarcastic?: sarcastic


In [21]:
# Spot-check: echo the headline, then the model's verdict.
text = "richard branson's global-warming donation nearly as much as cost of failed balloon trips"
print(text)
print("Sarcastic?:", predict_sarcasm(text, model, tokenizer, device))

richard branson's global-warming donation nearly as much as cost of failed balloon trips
Sarcastic?: sarcastic


In [22]:
# Spot-check: echo the headline, then the model's verdict.
text = "florist who turned away gay couple wants supreme court to hear her case"
print(text)
print("Sarcastic?:", predict_sarcasm(text, model, tokenizer, device))

florist who turned away gay couple wants supreme court to hear her case
Sarcastic?: not sarcastic


In [23]:
# Spot-check: echo the sentence, then the model's verdict.
text = "she went to answer the door without knowing it was me all along"
print(text)
print("Sarcastic?:", predict_sarcasm(text, model, tokenizer, device))

she went to answer the door without knowing it was me all along
Sarcastic?: not sarcastic


In [21]:
# Spot-check: echo the headline, then the model's verdict.
text = "bread scientist confused by lack of gluten in blood"
print(text)
print("Sarcastic?:", predict_sarcasm(text, model, tokenizer, device))

bread scientist confused by lack of gluten in blood
Sarcastic?: sarcastic


In [22]:
# Load a second headline file (v2) with the same loader, for extra evaluation.
# NOTE(review): this path is rooted at '../data/kaggle_sarcasm/' whereas the
# training file is read from 'datasets/' — confirm both resolve from the
# directory the notebook runs in.
# NOTE(review): if this file shares headlines with the training file, scores
# computed on it are not a held-out estimate — TODO confirm the overlap.
other_texts, other_labels = load_data('../data/kaggle_sarcasm/Sarcasm_Headlines_Dataset_v2.json')

In [25]:
# Compare true and predicted labels for the first five v2 headlines.
for text, label in zip(other_texts[:5], other_labels[:5]):
    print(text, f"True label: {label}", sep="\n")
    # end="\n\n" leaves one blank line between consecutive samples.
    print("Predicted label:", predict_sarcasm(text, model, tokenizer, device), end="\n\n")

thirtysomething scientists unveil doomsday clock of hair loss
True label: 1
Predicted label: sarcastic

dem rep. totally nails why congress is falling short on gender, racial equality
True label: 0
Predicted label: not sarcastic

eat your veggies: 9 deliciously different recipes
True label: 0
Predicted label: not sarcastic

inclement weather prevents liar from getting to work
True label: 1
Predicted label: sarcastic

mother comes pretty close to using word 'streaming' correctly
True label: 1
Predicted label: sarcastic



In [26]:
# Count correct (tru) and incorrect (fals) predictions over the whole v2 set.
# Inference runs in batches through the same Dataset/DataLoader pipeline used
# for validation, instead of one forward pass per headline. max_length matches
# predict_sarcasm's default of 128, so inputs are padded the same way.
other_dataset = TextClassificationDataset(other_texts, other_labels, tokenizer, max_length)
other_data_loader = DataLoader(other_dataset, batch_size=batch_size)

tru, fals = 0, 0
model.eval()
with torch.no_grad():
    for batch in other_data_loader:
        logits = model(batch['input_ids'].to(device), batch['attention_mask'].to(device))
        # Compare on the host, where the loader leaves the label tensor.
        hits = (logits.argmax(dim=1).cpu() == batch['labels']).sum().item()
        tru += hits
        fals += len(batch['labels']) - hits


In [27]:
# Fraction of v2 headlines whose predicted label matches the true label.
# NOTE(review): only a held-out score if the v2 file does not overlap with
# the training data — confirm before quoting this number.
print("Accuracy:", tru / (tru + fals))

Accuracy: 0.9780215940459136
