In [31]:
# Import libraries for generic data preprocessing
import os
import numpy as np
import pandas as pd

# Import BERT transformer libraries
import torch
from torch import tensor
from torch.nn import Softmax
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification

### Set Random Seed

In [3]:
# One seed shared by every RNG the notebook can draw from, so reruns repeat.
SEED = 10
np.random.seed(SEED)
# torch is the library that actually runs the model, so seed it as well;
# seeding NumPy alone leaves any torch-side randomness uncontrolled.
torch.manual_seed(SEED)

# Environment flag consumed by the tokenizers backend; set here, before the
# tokenizer is created. Presumably it silences the parallelism/fork warning
# -- TODO confirm against the tokenizers documentation.
os.environ["TOKENIZERS_PARALLELISM"] = 'false'

### Initialize BERT Tokenizer and Pre-Trained Classifier

In [10]:
# Checkpoint identifiers, kept together so they are easy to swap.
TOKENIZER_CHECKPOINT = 'distilbert-base-uncased'
CLASSIFIER_CHECKPOINT = 'dkhara/bert-news'

# Fast DistilBERT tokenizer loaded from the base uncased checkpoint.
tokenizer = DistilBertTokenizerFast.from_pretrained(TOKENIZER_CHECKPOINT)

# Sequence-classification model loaded from the news-classifier checkpoint.
# NOTE(review): the tokenizer and the classifier come from different
# checkpoints -- confirm the classifier was trained with this vocabulary.
model = DistilBertForSequenceClassification.from_pretrained(CLASSIFIER_CHECKPOINT)

Downloading:   0%|          | 0.00/268M [00:00<?, ?B/s]

### Create Test Article

In [5]:
# Hand-written sample input: a list holding a single article string.
articles = [
    'The Republican party won the presidential election in 2016. Implying, Donald Trump was the 45th president of the United States of America.',
]

### Tokenize Article

In [12]:
# Encode the article strings as PyTorch tensors ('pt'). Sequences are padded
# to a common length and cut off at 512 tokens.
tokenized_articles = tokenizer(
    text=articles,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors='pt',
)

In [13]:
# Pull the token-id tensor out of the tokenizer's output.
pt_articles = tokenized_articles["input_ids"]

### Test Article on Pre-Trained BERT News Model

In [19]:
# Run the classifier on the complete tokenizer output rather than on the
# input ids alone. Unpacking forwards every tensor the tokenizer returned
# (including the attention mask, if it produced one), so padding added by
# `padding=True` is not treated as real text when several articles of
# different lengths are batched together.
# Inference only, so skip autograd bookkeeping.
with torch.no_grad():
    pred = model(**tokenized_articles)

In [21]:
# Raw (pre-softmax) class scores from the model output.
logits = pred["logits"]

In [24]:
# Predicted class index per article.
# Take the argmax along the last (class) axis: without `dim`, torch.argmax
# flattens the whole logits tensor and returns a single flat index, which is
# not a class id once more than one article is in the batch.
# squeeze() collapses the single-article case to a scalar tensor.
#
# This seems correct!
# Class index 13 refers to the politics topic.
torch.argmax(logits, dim=-1).squeeze()

tensor(13)

In [57]:
# Softmax turns one article's logits into multi-class probabilities.
sm = Softmax(dim=0)

# Logits of the first article, detached from the autograd graph so they can
# be converted to NumPy.
first_article = logits[0].detach()

# Notice:
# In the saved output essentially all of the probability mass sits on
# index 13, i.e. the model is also very certain the topic is politics.
probabilities = sm(first_article).numpy()

# Display the probabilities rounded to two decimal places.
np.around(probabilities, 2)

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)