In [1]:
import random
import torch
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm

# enable tqdm in pandas
tqdm.pandas()

# set to True to use the gpu (if there is one available)
use_gpu = True

# select device
device = torch.device('cuda' if use_gpu and torch.cuda.is_available() else 'cpu')
print(f'device: {device.type}')

# random seed
seed = 1234

# set random seed
if seed is not None:
    print(f'random seed: {seed}')
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

device: cuda
random seed: 1234


In [2]:
def read_data(filename):
    # read csv file
    df = pd.read_csv(filename, header=None)
    # add column names
    df.columns = ['label', 'title', 'description']
    # make labels zero-based
    df['label'] -= 1
    # concatenate title and description, and remove backslashes
    df['text'] = df['title'] + " " + df['description']
    df['text'] = df['text'].str.replace('\\', ' ', regex=False)
    return df

In [3]:
labels = open('data/ag_news_csv/classes.txt').read().splitlines()
train_df = read_data('data/ag_news_csv/train.csv')
test_df = read_data('data/ag_news_csv/test.csv')
train_df

Unnamed: 0,label,title,description,text
0,2,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli...",Wall St. Bears Claw Back Into the Black (Reute...
1,2,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...,Carlyle Looks Toward Commercial Aerospace (Reu...
2,2,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...,Oil and Economy Cloud Stocks' Outlook (Reuters...
3,2,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...,Iraq Halts Oil Exports from Main Southern Pipe...
4,2,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco...","Oil prices soar to all-time record, posing new..."
...,...,...,...,...
119995,0,Pakistan's Musharraf Says Won't Quit as Army C...,KARACHI (Reuters) - Pakistani President Perve...,Pakistan's Musharraf Says Won't Quit as Army C...
119996,1,Renteria signing a top-shelf deal,Red Sox general manager Theo Epstein acknowled...,Renteria signing a top-shelf deal Red Sox gene...
119997,1,Saban not going to Dolphins yet,The Miami Dolphins will put their courtship of...,Saban not going to Dolphins yet The Miami Dolp...
119998,1,Today's NFL games,PITTSBURGH at NY GIANTS Time: 1:30 p.m. Line: ...,Today's NFL games PITTSBURGH at NY GIANTS Time...


In [4]:
from sklearn.model_selection import train_test_split

train_df, eval_df = train_test_split(train_df, train_size=0.9)
train_df.reset_index(inplace=True, drop=True)
eval_df.reset_index(inplace=True, drop=True)

print(f'train rows: {len(train_df.index):,}')
print(f'eval rows: {len(eval_df.index):,}')
print(f'test rows: {len(test_df.index):,}')

train rows: 108,000
eval rows: 12,000
test rows: 7,600


In [5]:
from datasets import Dataset, DatasetDict

ds = DatasetDict()
ds['train'] = Dataset.from_pandas(train_df)
ds['validation'] = Dataset.from_pandas(eval_df)
ds['test'] = Dataset.from_pandas(test_df)
ds

DatasetDict({
    train: Dataset({
        features: ['label', 'title', 'description', 'text'],
        num_rows: 108000
    })
    validation: Dataset({
        features: ['label', 'title', 'description', 'text'],
        num_rows: 12000
    })
    test: Dataset({
        features: ['label', 'title', 'description', 'text'],
        num_rows: 7600
    })
})

In [6]:
from transformers import AutoTokenizer

transformer_name = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(transformer_name)

In [7]:
def tokenize(examples):
    return tokenizer(examples['text'], truncation=True)

train_ds = ds['train'].map(tokenize, batched=True, remove_columns=['title', 'description', 'text'])
eval_ds = ds['validation'].map(tokenize, batched=True, remove_columns=['title', 'description', 'text'])
train_ds.to_pandas()

  0%|          | 0/108 [00:00<?, ?ba/s]

  0%|          | 0/12 [00:00<?, ?ba/s]

Unnamed: 0,label,input_ids,attention_mask
0,3,"[101, 3270, 11906, 1522, 1146, 7106, 1111, 251...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,0,"[101, 4222, 11404, 1174, 117, 1476, 1130, 2696...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,0,"[101, 158, 119, 156, 119, 12068, 5084, 1116, 9...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,2,"[101, 22087, 8223, 1611, 1106, 4417, 5572, 324...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,0,"[101, 7270, 118, 2733, 1383, 1111, 12448, 7430...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...
107995,0,"[101, 6096, 117, 10378, 3969, 5977, 1111, 8988...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
107996,0,"[101, 16409, 118, 16587, 159, 4064, 1106, 1564...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
107997,0,"[101, 19569, 5480, 10582, 2087, 1867, 158, 119...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
107998,0,"[101, 11560, 3881, 108, 3614, 132, 3498, 2944,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [8]:
from transformers import AutoConfig

config = AutoConfig.from_pretrained(transformer_name, num_labels=len(labels))

In [9]:
from transformers.models.distilbert.modeling_distilbert import DistilBertForSequenceClassification

model = (
    DistilBertForSequenceClassification
    .from_pretrained(transformer_name, config=config)
)

Some weights of the model checkpoint at distilbert-base-cased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['pre_classifier.weight', 'pre_classifier.bias', 'classifier

In [10]:
from transformers import TrainingArguments

num_epochs = 2
batch_size = 24
logging_steps = len(ds['train']) // batch_size
model_name = f'{transformer_name}-sequence-classification'
training_args = TrainingArguments(
    output_dir=model_name,
    log_level='error',
    num_train_epochs=num_epochs,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    evaluation_strategy='epoch',
    weight_decay=0.01,
    disable_tqdm=False,
    logging_steps=logging_steps,
)

In [11]:
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    y_true = eval_pred.label_ids
    y_pred = np.argmax(eval_pred.predictions, axis=-1)
    return {'accuracy': accuracy_score(y_true, y_pred)}

In [12]:
from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
)

In [13]:
trainer.train()



Epoch,Training Loss,Validation Loss,Accuracy
1,0.24,0.179792,0.938417
2,0.1274,0.184756,0.944083


TrainOutput(global_step=9000, training_loss=0.18369327799479165, metrics={'train_runtime': 932.0479, 'train_samples_per_second': 231.748, 'train_steps_per_second': 9.656, 'total_flos': 6597782385046272.0, 'train_loss': 0.18369327799479165, 'epoch': 2.0})

In [14]:
test_ds = ds['test'].map(tokenize, batched=True, remove_columns=['title', 'description', 'text'])
test_ds.to_pandas()

  0%|          | 0/8 [00:00<?, ?ba/s]

Unnamed: 0,label,input_ids,attention_mask
0,2,"[101, 11284, 1116, 1111, 157, 151, 12966, 1170...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
1,3,"[101, 1109, 6398, 1110, 1212, 131, 2307, 7219,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
2,3,"[101, 148, 1183, 119, 1881, 16387, 1116, 4468,...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
3,3,"[101, 11689, 15906, 6115, 12056, 1116, 1370, 2...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,3,"[101, 11917, 8914, 119, 19294, 4206, 1106, 215...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
...,...,...,...
7595,0,"[101, 5596, 1103, 1362, 5284, 5200, 3234, 1384...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
7596,1,"[101, 159, 7874, 1110, 2709, 1114, 13875, 1556...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
7597,1,"[101, 16247, 2972, 9178, 2409, 4271, 140, 1418...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
7598,2,"[101, 126, 1104, 1893, 8167, 10721, 4420, 1107...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."


In [15]:
output = trainer.predict(test_ds)
output

PredictionOutput(predictions=array([[ 0.33578217, -4.917484  ,  4.3817263 , -3.033465  ],
       [-1.4582787 , -3.9827685 , -2.7275746 ,  4.845849  ],
       [-1.1790004 , -3.325     , -2.9807591 ,  4.5253825 ],
       ...,
       [-0.7262152 ,  6.5225377 , -3.2033384 , -3.5806625 ],
       [-0.49669605, -5.0438304 ,  5.0929165 , -2.9071586 ],
       [-3.3449006 , -5.1755157 ,  2.1987727 ,  2.7667317 ]],
      dtype=float32), label_ids=array([2, 3, 3, ..., 1, 2, 2]), metrics={'test_loss': 0.18910787999629974, 'test_accuracy': 0.9472368421052632, 'test_runtime': 9.3543, 'test_samples_per_second': 812.462, 'test_steps_per_second': 33.888})

In [16]:
from sklearn.metrics import classification_report

y_true = output.label_ids
y_pred = np.argmax(output.predictions, axis=-1)
target_names = labels
print(classification_report(y_true, y_pred, target_names=target_names))

              precision    recall  f1-score   support

       World       0.96      0.96      0.96      1900
      Sports       0.99      0.99      0.99      1900
    Business       0.93      0.91      0.92      1900
    Sci/Tech       0.91      0.94      0.92      1900

    accuracy                           0.95      7600
   macro avg       0.95      0.95      0.95      7600
weighted avg       0.95      0.95      0.95      7600

