In [8]:
!pip -q install datasets transformers scikit-learn

[0m

In [9]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
from sklearn.metrics import accuracy_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
dataset = load_dataset("ag_news")

train_small = dataset["train"].shuffle(seed=42).select(range(2000))
test_small  = dataset["test"].shuffle(seed=42).select(range(500))

label2name = {0: "World", 1: "Sports", 2: "Business", 3: "Sci/Tech"}

len(train_small), len(test_small), train_small[0]

(2000,
 500,
 {'text': 'Bangladesh paralysed by strikes Opposition activists have brought many towns and cities in Bangladesh to a halt, the day after 18 people died in explosions at a political rally.',
  'label': 0})

In [11]:
from collections import Counter
from sklearn.metrics import accuracy_score
import torch

label_counts = Counter(int(x["label"]) for x in train_small)
default_label = label_counts.most_common(1)[0][0]

world_keywords = [
    "government", "president", "country", "nation", "war", "minister",
    "election", "vote", "conflict", "leader", "peace", "leader"
]
sports_keywords = [
    "match", "win", "team", "game", "player", "cup", "league",
    "goal", "coach", "season", "tournament", "final", "olympic", "score"
]
business_keywords = [
    "market", "stock", "company", "business", "profit", "trade", "share",
    "investment", "bank", "loan", "revenue", "tax", "oil", "dollar", "gas"
]
scitech_keywords = [
    "technology", "software", "internet", "research", "science",
    "computer", "device", "phone", "chip", "AI", "robot", "network", "data",
    "nasa", "space", "satellite"
]

def baseline_predict(text: str) -> int:
    t = text.lower()
    scores = [0, 0, 0, 0]  # [World, Sports, Business, Sci/Tech]

    for w in world_keywords:
        if w in t:
            scores[0] += 1
    for w in sports_keywords:
        if w in t:
            scores[1] += 1
    for w in business_keywords:
        if w in t:
            scores[2] += 1
    for w in scitech_keywords:
        if w in t:
            scores[3] += 1

    if max(scores) == 0:
        return default_label

    return scores.index(max(scores))

texts  = [x["text"] for x in test_small]
labels = [int(x["label"]) for x in test_small]

baseline_preds = [baseline_predict(t) for t in texts]
baseline_acc = accuracy_score(labels, baseline_preds)

print("Baseline accuracy:", baseline_acc)

Baseline accuracy: 0.53


In [12]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_name = "textattack/distilbert-base-uncased-ag-news"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): DistilBertSdpaAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)


In [13]:
from sklearn.metrics import accuracy_score

def pipeline_predict(text_list):
    enc = tokenizer(
        text_list,
        padding=True,
        truncation=True,
        max_length=128,
        return_tensors="pt"
    )
    enc = {k: v.to(device) for k, v in enc.items()}

    with torch.no_grad():
        outputs = model(**enc)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)

    return preds.cpu().tolist()

texts  = [x["text"] for x in test_small]
labels = [int(x["label"]) for x in test_small]

batch_size = 32
pipeline_preds = []

for i in range(0, len(texts), batch_size):
    batch_text = texts[i:i+batch_size]
    batch_pred = pipeline_predict(batch_text)
    pipeline_preds.extend(batch_pred)

pipeline_acc = accuracy_score(labels, pipeline_preds)

print("Baseline accuracy :", baseline_acc)
print("Pipeline accuracy :", pipeline_acc)

Baseline accuracy : 0.53
Pipeline accuracy : 0.928


In [14]:
label2name = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}

print(f"Baseline accuracy on the test set:  {baseline_acc:.4f}")
print(f"Pipeline accuracy on the test set:  {pipeline_acc:.4f}")
print()

print("Examples where the baseline and the pipeline make different predictions:\n")

count = 0
for text, y_true, y_base, y_pipe in zip(texts, labels, baseline_preds, pipeline_preds):
    if y_base != y_pipe and count < 5:
        print("=" * 80)
        print("Text:          ", text)
        print("True label:    ", label2name[y_true])
        print("Baseline pred: ", label2name[y_base])
        print("Pipeline pred: ", label2name[y_pipe])
        count += 1

if count == 0:
    print("The baseline and the pipeline made almost identical predictions.")
    print("Below are three random examples from the test set.\n")
    for i in range(3):
        print("=" * 80)
        print("Text:          ", texts[i])
        print("True label:    ", label2name[labels[i]])
        print("Baseline pred: ", label2name[baseline_preds[i]])
        print("Pipeline pred: ", label2name[pipeline_preds[i]])


Baseline accuracy on the test set:  0.5300
Pipeline accuracy on the test set:  0.9280

Examples where the baseline and the pipeline make different predictions:

Text:           Indian board plans own telecast of Australia series The Indian cricket board said on Wednesday it was making arrangements on its own to broadcast next month #39;s test series against Australia, which is under threat because of a raging TV rights dispute.
True label:     Sports
Baseline pred:  Sci/Tech
Pipeline pred:  Sports
Text:           REVIEW: 'Half-Life 2' a Tech Masterpiece (AP) AP - It's been six years since Valve Corp. perfected the first-person shooter with "Half-Life." Video games have come a long way since, with better graphics and more options than ever. Still, relatively few games have mustered this one's memorable characters and original science fiction story.
True label:     Sci/Tech
Baseline pred:  Sports
Pipeline pred:  Sci/Tech
Text:           China's inflation rate slows sharply but problems r