In [1]:
import random
import torch
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# enable tqdm in pandas
tqdm.pandas()

In [2]:
# set to True to use the gpu (if there is one available)
use_gpu = True

# select device
device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
print(f'device: {device.type}')

device: cpu


In [3]:
# random seed
seed = 33

# set random seed
if seed is not None:
    print(f'random seed: {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

random seed: 33


In [4]:
train_df = pd.read_csv('../data/ag_news_csv/train.csv', header=None)
train_df.columns = ['class index', 'title', 'description']
train_df

Unnamed: 0,class index,title,description
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."
...,...,...,...
119995,1,Pakistan's Musharraf Says Won't Quit as Army C...,KARACHI (Reuters) - Pakistani President Perve...
119996,2,Renteria signing a top-shelf deal,Red Sox general manager Theo Epstein acknowled...
119997,2,Saban not going to Dolphins yet,The Miami Dolphins will put their courtship of...
119998,2,Today's NFL games,PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...


In [5]:
labels = open('../data/ag_news_csv/classes.txt').read().splitlines()
classes = train_df['class index'].map(lambda i: labels[i-1])
train_df.insert(1, 'class', classes)
train_df

Unnamed: 0,class index,class,title,description
0,3,Business,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli..."
1,3,Business,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...
2,3,Business,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...
3,3,Business,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...
4,3,Business,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco..."
...,...,...,...,...
119995,1,World,Pakistan's Musharraf Says Won't Quit as Army C...,KARACHI (Reuters) - Pakistani President Perve...
119996,2,Sports,Renteria signing a top-shelf deal,Red Sox general manager Theo Epstein acknowled...
119997,2,Sports,Saban not going to Dolphins yet,The Miami Dolphins will put their courtship of...
119998,2,Sports,Today's NFL games,PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...


In [6]:
pd.Series(train_df['class']).value_counts()

class
Business    30000
Sci/Tech    30000
Sports      30000
World       30000
Name: count, dtype: int64

However, the text contains some spurious backslashes in some parts of the text. They are meant to represent newlines in the original text. An example can be seen below, between the words "dwindling" and "band".

In [7]:
print(train_df.loc[0, 'description'])

Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.


We will replace the backslashes with spaces on the whole column using pandas replace method.

In [8]:
train_df['text'] = train_df['title'].str.lower() + " " + train_df['description'].str.lower()
train_df['text'] = train_df['text'].str.replace('\\', ' ', regex=False)
train_df

Unnamed: 0,class index,class,title,description,text
0,3,Business,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli...",wall st. bears claw back into the black (reute...
1,3,Business,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...,carlyle looks toward commercial aerospace (reu...
2,3,Business,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...,oil and economy cloud stocks' outlook (reuters...
3,3,Business,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...,iraq halts oil exports from main southern pipe...
4,3,Business,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco...","oil prices soar to all-time record, posing new..."
...,...,...,...,...,...
119995,1,World,Pakistan's Musharraf Says Won't Quit as Army C...,KARACHI (Reuters) - Pakistani President Perve...,pakistan's musharraf says won't quit as army c...
119996,2,Sports,Renteria signing a top-shelf deal,Red Sox general manager Theo Epstein acknowled...,renteria signing a top-shelf deal red sox gene...
119997,2,Sports,Saban not going to Dolphins yet,The Miami Dolphins will put their courtship of...,saban not going to dolphins yet the miami dolp...
119998,2,Sports,Today's NFL games,PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...,today's nfl games pittsburgh at ny giants time...


In [None]:
#from nltk.tokenize import word_tokenize
# train_df['tokens'] = train_df['text'].apply(word_tokenize)
# train_df

In [9]:
from nltk.tokenize import TweetTokenizer

tokenizer = TweetTokenizer()
train_df['tokens'] = train_df['text'].apply(tokenizer.tokenize)
train_df

Unnamed: 0,class index,class,title,description,text,tokens
0,3,Business,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli...",wall st. bears claw back into the black (reute...,"[wall, st, ., bears, claw, back, into, the, bl..."
1,3,Business,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...,carlyle looks toward commercial aerospace (reu...,"[carlyle, looks, toward, commercial, aerospace..."
2,3,Business,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...,oil and economy cloud stocks' outlook (reuters...,"[oil, and, economy, cloud, stocks, ', outlook,..."
3,3,Business,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...,iraq halts oil exports from main southern pipe...,"[iraq, halts, oil, exports, from, main, southe..."
4,3,Business,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco...","oil prices soar to all-time record, posing new...","[oil, prices, soar, to, all-time, record, ,, p..."
...,...,...,...,...,...,...
119995,1,World,Pakistan's Musharraf Says Won't Quit as Army C...,KARACHI (Reuters) - Pakistani President Perve...,pakistan's musharraf says won't quit as army c...,"[pakistan's, musharraf, says, won't, quit, as,..."
119996,2,Sports,Renteria signing a top-shelf deal,Red Sox general manager Theo Epstein acknowled...,renteria signing a top-shelf deal red sox gene...,"[renteria, signing, a, top-shelf, deal, red, s..."
119997,2,Sports,Saban not going to Dolphins yet,The Miami Dolphins will put their courtship of...,saban not going to dolphins yet the miami dolp...,"[saban, not, going, to, dolphins, yet, the, mi..."
119998,2,Sports,Today's NFL games,PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...,today's nfl games pittsburgh at ny giants time...,"[today's, nfl, games, pittsburgh, at, ny, gian..."


Now we will load the GloVe word embeddings.

In [10]:
from gensim.models import KeyedVectors
glove = KeyedVectors.load_word2vec_format("../data/glove/glove.6B.300d.txt", no_header=True)
glove.vectors.shape

(400000, 300)

The word embeddings have been pretrained in a different corpus, so it would be a good idea to estimate how good our tokenization matches the GloVe vocabulary.

In [11]:
from collections import Counter

def count_unknown_words(data, vocabulary):
    counter = Counter()
    for row in tqdm(data):
        counter.update(tok for tok in row if tok not in vocabulary)
    return counter

# find out how many times each unknown token occurrs in the corpus
c = count_unknown_words(train_df['tokens'], glove.key_to_index)

# find the total number of tokens in the corpus
total_tokens = train_df['tokens'].map(len).sum()

# find some statistics about occurrences of unknown tokens
unk_tokens = sum(c.values())
percent_unk = unk_tokens / total_tokens
distinct_tokens = len(list(c))

print(f'total number of tokens: {total_tokens:,}')
print(f'number of unknown tokens: {unk_tokens:,}')
print(f'number of distinct unknown tokens: {distinct_tokens:,}')
print(f'percentage of unkown tokens: {percent_unk:.2%}')
print('top 50 unknown words:')
for token, n in c.most_common(10):
    print(f'\t{n}\t{token}')

  0%|          | 0/120000 [00:00<?, ?it/s]

total number of tokens: 5,226,844
number of unknown tokens: 131,074
number of distinct unknown tokens: 22,558
percentage of unkown tokens: 2.51%
top 50 unknown words:
	44316	#39
	2984	</b>
	2983	<b>
	2119	href
	2117	</a>
	1813	=/
	1813	quickinfo
	1813	fullquote
	1307	#36
	1023	world's


Glove embeddings seem to have a good coverage on this dataset -- only 1.25% of the tokens in the dataset are unknown, i.e., don't appear in the GloVe vocabulary.

Still, we will need a way to handle these unknown tokens. Our approach will be to add a new embedding to GloVe that will be used to represent them. This new embedding will be initialized as the average of all the GloVe embeddings.

We will also add another embedding, this one initialized to zeros, that will be used to pad the sequences of tokens so that they all have the same length. This will be useful when we train with mini-batches.

In [12]:
# string values corresponding to the new embeddings
unk_tok = '[UNK]'
pad_tok = '[PAD]'

# initialize the new embedding values
unk_emb = glove.vectors.mean(axis=0)
pad_emb = np.zeros(300)

# add new embeddings to glove
glove.add_vectors([unk_tok, pad_tok], [unk_emb, pad_emb])

# get token ids corresponding to the new embeddings
unk_id = glove.key_to_index[unk_tok]
pad_id = glove.key_to_index[pad_tok]

unk_id, pad_id

(400000, 400001)

In [15]:
from sklearn.model_selection import train_test_split

train_df, dev_df = train_test_split(train_df, train_size=0.8)
train_df.reset_index(inplace=True)
dev_df.reset_index(inplace=True)

We will now add a new column to our dataframe that will contain the padded sequences of token ids.

In [16]:
threshold = 10
tokens = train_df['tokens'].explode().value_counts()
vocabulary = set(tokens[tokens > threshold].index.tolist())
print(f'vocabulary size: {len(vocabulary):,}')

vocabulary size: 15,561


In [17]:
# find the length of the longest list of tokens
max_tokens = train_df['tokens'].map(len).max()

# return unk_id for infrequent tokens too
def get_id(tok):
    if tok in vocabulary:
        return glove.key_to_index.get(tok, unk_id)
    else:
        return unk_id

# function that gets a list of tokens and returns a list of token ids,
# with padding added accordingly
def token_ids(tokens):
    tok_ids = [get_id(tok) for tok in tokens]
    pad_len = max_tokens - len(tok_ids)
    return tok_ids + [pad_id] * pad_len

# add new column to the dataframe
train_df['token ids'] = train_df['tokens'].progress_map(token_ids)
train_df

  0%|          | 0/76800 [00:00<?, ?it/s]

Unnamed: 0,level_0,index,class index,class,title,description,text,tokens,token ids
0,78634,94532,4,Sci/Tech,DuPont Faces New Complaint,Chemical giant DuPont Co. withheld information...,dupont faces new complaint chemical giant dupo...,"[dupont, faces, new, complaint, chemical, gian...","[14424, 1919, 50, 4499, 2291, 1752, 14424, 164..."
1,86794,109473,3,Business,"Weldon: Reports Say JNJ, Guidant Merger Seen W...",Johnson amp; Johnson (nyse: JNJ - news - peop...,"weldon: reports say jnj, guidant merger seen w...","[weldon, :, reports, say, jnj, ,, guidant, mer...","[400000, 45, 687, 203, 281164, 1, 31027, 3176,..."
2,8634,98666,3,Business,Oracle claims support of PeopleSoft shareholders,PeopleSoft has remained defiant in the long-ru...,oracle claims support of peoplesoft shareholde...,"[oracle, claims, support, of, peoplesoft, shar...","[9094, 1267, 280, 3, 32099, 3258, 32099, 31, 1..."
3,32942,68360,2,Sports,Sixers grab preseason win over Spurs,Another victory escaped the clutches of the Sa...,sixers grab preseason win over spurs another v...,"[sixers, grab, preseason, win, over, spurs, an...","[25679, 7987, 10387, 320, 74, 7506, 170, 651, ..."
4,23026,1629,3,Business,US consumer prices decline unexpectedly in July,"NEW YORK, August 17 (New Ratings) The US cons...",us consumer prices decline unexpectedly in jul...,"[us, consumer, prices, decline, unexpectedly, ...","[95, 1493, 468, 1943, 9052, 6, 375, 50, 196, 1..."
...,...,...,...,...,...,...,...,...,...
76795,86026,77640,4,Sci/Tech,It #39;s flak jacket time for Microsoft #39;s ...,What kind of E-mail is landing in Martin Taylo...,it #39;s flak jacket time for microsoft #39;s ...,"[it, #39, ;, s, flak, jacket, time, for, micro...","[20, 400000, 89, 1534, 400000, 400000, 79, 10,..."
76796,75648,79026,4,Sci/Tech,CA #39;s Open-Source Ingres Now Available,The company is shipping the open-source databa...,ca #39;s open-source ingres now available the ...,"[ca, #39, ;, s, open-source, ingres, now, avai...","[855, 400000, 89, 1534, 39943, 400000, 114, 77..."
76797,70712,33312,1,World,Iraqi PM: #39;Terrorists pouring in #39;,Iraq #39;s interim Prime Minister Iyad Allawi ...,iraqi pm: #39;terrorists pouring in #39; iraq...,"[iraqi, pm, :, #39, ;, terrorists, pouring, in...","[710, 3345, 45, 400000, 89, 2712, 11545, 6, 40..."
76798,27587,3777,1,World,Sharon presses on with Gaza plan,JERUSALEM (Reuters) - Israeli Prime Minister A...,sharon presses on with gaza plan jerusalem (re...,"[sharon, presses, on, with, gaza, plan, jerusa...","[2548, 15704, 13, 17, 1166, 394, 2013, 23, 108..."


In [18]:
max_tokens = dev_df['tokens'].map(len).max()
dev_df['token ids'] = dev_df['tokens'].progress_map(token_ids)
dev_df

  0%|          | 0/19200 [00:00<?, ?it/s]

Unnamed: 0,level_0,index,class index,class,title,description,text,tokens,token ids
0,90311,85643,4,Sci/Tech,Intel Rolls Out Single-Core #39;Madison #39; ...,Intel Corp. is refreshing its 64-bit Itanium 2...,intel rolls out single-core #39;madison #39; ...,"[intel, rolls, out, single-core, #39, ;, madis...","[5438, 7474, 66, 400000, 400000, 89, 5880, 400..."
1,85105,65462,2,Sports,Rice Shipped to Seahawks for Draft Pick (AP),AP - The Seattle Seahawks finally got Jerry Ri...,rice shipped to seahawks for draft pick (ap) a...,"[rice, shipped, to, seahawks, for, draft, pick...","[1818, 7424, 4, 12954, 10, 1737, 2065, 23, 158..."
2,59692,6588,2,Sports,Ulmer gives New Zealand historic gold,Athens - Sarah Ulmer gave New Zealand their fi...,ulmer gives new zealand historic gold athens -...,"[ulmer, gives, new, zealand, historic, gold, a...","[64400, 1829, 50, 1272, 1590, 764, 3264, 11, 4..."
3,59101,81832,4,Sci/Tech,Study of nicotine addiction advances,Scientists have said for decades that nicotine...,study of nicotine addiction advances scientist...,"[study, of, nicotine, addiction, advances, sci...","[806, 3, 17447, 10433, 6309, 2154, 33, 16, 10,..."
4,93147,37627,4,Sci/Tech,PeopleSoft Rolls Out Upgrade Incentive Program,PeopleSoft wants its customers to get quot;ag...,peoplesoft rolls out upgrade incentive program...,"[peoplesoft, rolls, out, upgrade, incentive, p...","[32099, 7474, 66, 6837, 8503, 371, 32099, 1025..."
...,...,...,...,...,...,...,...,...,...
19195,16493,93864,4,Sci/Tech,Internet auction house eBay launches Philippin...,AFP - Online auction site eBay said it has lau...,internet auction house ebay launches philippin...,"[internet, auction, house, ebay, launches, phi...","[925, 4473, 166, 10891, 8338, 2848, 825, 23, 1..."
19196,42112,29656,1,World,Probe launched into parliament intrusion,Britain has launched an inquiry into the first...,probe launched into parliament intrusion brita...,"[probe, launched, into, parliament, intrusion,...","[3615, 1169, 75, 668, 19508, 695, 31, 1169, 29..."
19197,78858,104497,2,Sports,Falcons sign Alge Crumpler to six-year contrac...,After signing a new six-year contract worth ab...,falcons sign alge crumpler to six-year contrac...,"[falcons, sign, alge, crumpler, to, six-year, ...","[8942, 1100, 400000, 400000, 4, 35650, 953, 37..."
19198,7554,118508,1,World,UHaiti Gov #39;t Negotiates With Ex-Soldiers,With hundreds of UN troops and Haitian police ...,uhaiti gov #39;t negotiates with ex-soldiers w...,"[uhaiti, gov, #39, ;, t, negotiates, with, ex-...","[400000, 15919, 400000, 89, 2159, 400000, 17, ..."


Now we will get a numpy 2-dimensional array corresponding to the token ids, and a 1-dimensional array with the gold classes. Note that the classes are one-based (i.e., they start at one), but we need them to be zero-based, so we need to subtract one from this array.

In [19]:
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, index):
        x = torch.tensor(self.x[index])
        y = torch.tensor(self.y[index])
        return x, y

Next, we construct our PyTorch model, which is a feed-forward neural network with two layers:

In [20]:
from torch import nn
import torch.nn.functional as F

class Model(nn.Module):
    def __init__(self, vectors, pad_id, hidden_dim, output_dim, dropout):
        super().__init__()
        # embeddings must be a tensor
        if not torch.is_tensor(vectors):
            vectors = torch.tensor(vectors)
        # keep padding id
        self.padding_idx = pad_id
        # embedding layer
        self.embs = nn.Embedding.from_pretrained(vectors, padding_idx=pad_id)
        # feedforward layers
        self.layers = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(vectors.shape[1], hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )
        
    def forward(self, x):
        # get boolean array with padding elements set to false
        not_padding = torch.isin(x, self.padding_idx, invert=True)
        # get lengths of examples (excluding padding)
        lengths = torch.count_nonzero(not_padding, axis=1)
        # get embeddings
        x = self.embs(x)
        # calculate means
        x = x.sum(dim=1) / lengths.unsqueeze(dim=1)
        # pass to rest of the model
        output = self.layers(x)
        # calculate softmax if we're not in training mode
        #if not self.training:
        #    output = F.softmax(output, dim=1)
        return output

Next, we implement the training procedure. We compute the loss and accuracy on the development partition after each epoch.

In [21]:
from torch import optim
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score

# hyperparameters
lr = 1e-3
weight_decay = 0
batch_size = 500
shuffle = True
n_epochs = 5
hidden_dim = 50
output_dim = len(labels)
dropout = 0.1
vectors = glove.vectors

# initialize the model, loss function, optimizer, and data-loader
model = Model(vectors, pad_id, hidden_dim, output_dim, dropout).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
train_ds = MyDataset(train_df['token ids'], train_df['class index'] - 1)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle)
dev_ds = MyDataset(dev_df['token ids'], dev_df['class index'] - 1)
dev_dl = DataLoader(dev_ds, batch_size=batch_size, shuffle=shuffle)

train_loss = []
train_acc = []

dev_loss = []
dev_acc = []

In [22]:
# train the model
for epoch in range(n_epochs):
    losses = []
    gold = []
    pred = []
    model.train()
    for X, y_true in tqdm(train_dl, desc=f'epoch {epoch+1} (train)'):
        # clear gradients
        model.zero_grad()
        # send batch to right device
        X = X.to(device)
        y_true = y_true.to(device)
        # predict label scores
        y_pred = model(X)
        # compute loss
        loss = loss_func(y_pred, y_true)
        # accumulate for plotting
        losses.append(loss.detach().cpu().item())
        gold.append(y_true.detach().cpu().numpy())
        pred.append(np.argmax(y_pred.detach().cpu().numpy(), axis=1))
        # backpropagate
        loss.backward()
        # optimize model parameters
        optimizer.step()
    train_loss.append(np.mean(losses))
    train_acc.append(accuracy_score(np.concatenate(gold), np.concatenate(pred)))
    
    model.eval()
    with torch.no_grad():
        losses = []
        gold = []
        pred = []
        for X, y_true in tqdm(dev_dl, desc=f'epoch {epoch+1} (dev)'):
            X = X.to(device)
            y_true = y_true.to(device)
            y_pred = model(X)
            loss = loss_func(y_pred, y_true)
            losses.append(loss.cpu().item())
            gold.append(y_true.cpu().numpy())
            pred.append(np.argmax(y_pred.cpu().numpy(), axis=1))
        dev_loss.append(np.mean(losses))
        dev_acc.append(accuracy_score(np.concatenate(gold), np.concatenate(pred)))

epoch 1 (train):   0%|          | 0/154 [00:00<?, ?it/s]

epoch 1 (dev):   0%|          | 0/39 [00:00<?, ?it/s]

epoch 2 (train):   0%|          | 0/154 [00:00<?, ?it/s]

epoch 2 (dev):   0%|          | 0/39 [00:00<?, ?it/s]

epoch 3 (train):   0%|          | 0/154 [00:00<?, ?it/s]

epoch 3 (dev):   0%|          | 0/39 [00:00<?, ?it/s]

epoch 4 (train):   0%|          | 0/154 [00:00<?, ?it/s]

epoch 4 (dev):   0%|          | 0/39 [00:00<?, ?it/s]

epoch 5 (train):   0%|          | 0/154 [00:00<?, ?it/s]

epoch 5 (dev):   0%|          | 0/39 [00:00<?, ?it/s]

In [23]:
import matplotlib.pyplot as plt
%matplotlib inline

x = np.arange(n_epochs) + 1

plt.plot(x, train_loss)
plt.plot(x, dev_loss)
plt.legend(['train', 'dev'])
plt.xlabel('epoch')
plt.ylabel('loss')
plt.grid(True)

In [24]:
plt.plot(x, train_acc)
plt.plot(x, dev_acc)
plt.legend(['train', 'dev'])
plt.xlabel('epoch')
plt.ylabel('accuracy')
plt.grid(True)

Next, we evaluate on the testing partition:

In [30]:
# repeat all preprocessing done above, this time on the test set
test_df = pd.read_csv('../data/ag_news_csv/test.csv', header=None)
test_df.columns = ['class index', 'title', 'description']
test_df['text'] = test_df['title'].str.lower() + " " + test_df['description'].str.lower()
test_df['text'] = test_df['text'].str.replace('\\', ' ', regex=False)

#test_df['tokens'] = test_df['text'].progress_map(word_tokenize)
test_df['tokens'] = test_df['text'].progress_map(tokenizer.tokenize)

max_tokens = dev_df['tokens'].map(len).max()
test_df['token ids'] = test_df['tokens'].progress_map(token_ids)

  0%|          | 0/7600 [00:00<?, ?it/s]

  0%|          | 0/7600 [00:00<?, ?it/s]

In [31]:
from sklearn.metrics import classification_report

# set model to evaluation mode
model.eval()

dataset = MyDataset(test_df['token ids'], test_df['class index'] - 1)
data_loader = DataLoader(dataset, batch_size=batch_size)
y_pred = []

# don't store gradients
with torch.no_grad():
    for X, _ in tqdm(data_loader):
        X = X.to(device)
        # predict one class per example
        y = torch.argmax(model(X), dim=1)
        # convert tensor to numpy array (sending it back to the cpu if needed)
        y_pred.append(y.cpu().numpy())
        # print results
    print(classification_report(dataset.y, np.concatenate(y_pred), target_names=labels))

  0%|          | 0/16 [00:00<?, ?it/s]

              precision    recall  f1-score   support

       World       0.92      0.88      0.90      1900
      Sports       0.95      0.97      0.96      1900
    Business       0.84      0.87      0.85      1900
    Sci/Tech       0.87      0.87      0.87      1900

    accuracy                           0.90      7600
   macro avg       0.90      0.90      0.90      7600
weighted avg       0.90      0.90      0.90      7600

