# Text Classification using Sentence Transformers
In this notebook we classify text using document embeddings from a pre-trained [Sentence Transformer](https://www.sbert.net) model. SentenceTransformers is a framework for sentence and text embeddings that works particularly well on shorter text. It was introduced in 2019 and uses a Siamese BERT architecture to produce semantically meaningful sentence embeddings that can be compared using cosine similarity. You can use a [pretrained embedding model](https://www.sbert.net/docs/pretrained_models.html) or train your own on a corpus.
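
To make "compared using cosine similarity" concrete, here is a minimal sketch in NumPy. The three vectors are made-up stand-ins for sentence embeddings (real embeddings have far more dimensions), and `cosine_similarity` is a helper defined here for illustration, not a function from the SentenceTransformers API.

```python
import numpy as np

def cosine_similarity(a, b):
    """Cosine of the angle between two vectors: 1.0 means same direction."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

# Hypothetical 4-dimensional "embeddings" (illustrative values only)
emb_oil_prices = np.array([0.9, 0.1, 0.0, 0.2])
emb_crude_costs = np.array([0.8, 0.2, 0.1, 0.3])
emb_football = np.array([0.0, 0.9, 0.4, 0.0])

sim_related = cosine_similarity(emb_oil_prices, emb_crude_costs)
sim_unrelated = cosine_similarity(emb_oil_prices, emb_football)
print(f"related:   {sim_related:.3f}")
print(f"unrelated: {sim_unrelated:.3f}")
```

Texts with similar meaning should receive embeddings that point in similar directions, so their cosine similarity is closer to 1 than that of unrelated texts.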

Our goal is to classify each article in the AG News dataset into its correct category: "World", "Sports", "Business", or "Sci/Tec".

**Notes:**  
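- GPU acceleration is strongly recommended: the code falls back to the CPU when no GPU is available, but computing the embeddings will be much slower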

**References:**
- Read the [Sentence-BERT paper](https://arxiv.org/abs/1908.10084) by Reimers & Gurevych

In [2]:
!pip install sentence_transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting sentence_transformers
  Downloading sentence-transformers-2.2.2.tar.gz (85 kB)
  Preparing metadata (setup.py) ... done
Collecting transformers<5.0.0,>=4.6.0
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
Collecting sentencepiece
  Downloading sentencepiece-0.1.97-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
Collecting huggingface-hub>=0.4.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)

In [3]:
import os
import numpy as np
import pandas as pd
import string
import time
import urllib.request
import zipfile
import torch

from sklearn.linear_model import LogisticRegression
from sentence_transformers import SentenceTransformer

import warnings
warnings.filterwarnings('ignore')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Download and prepare data

In [5]:
# Download the data
os.makedirs('../data', exist_ok=True)
if not os.path.exists('../data/agnews'):
    url = 'https://storage.googleapis.com/aipi540-datasets/agnews.zip'
    urllib.request.urlretrieve(url, filename='../data/agnews.zip')
    # The context manager closes the archive even if extraction fails
    with zipfile.ZipFile('../data/agnews.zip', 'r') as zip_ref:
        zip_ref.extractall('../data/agnews')

train_df = pd.read_csv('../data/agnews/train.csv')
test_df = pd.read_csv('../data/agnews/test.csv')

# Combine title and description of article to use as input documents for model
train_df['full_text'] = train_df.apply(lambda x: ' '.join([x['Title'],x['Description']]),axis=1)
test_df['full_text'] = test_df.apply(lambda x: ' '.join([x['Title'],x['Description']]),axis=1)

# Create dictionary to store mapping of labels
ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

train_df.head(20)

Unnamed: 0,Class Index,Title,Description,full_text
0,3,Wall St. Bears Claw Back Into the Black (Reuters),"Reuters - Short-sellers, Wall Street's dwindli...",Wall St. Bears Claw Back Into the Black (Reute...
1,3,Carlyle Looks Toward Commercial Aerospace (Reu...,Reuters - Private investment firm Carlyle Grou...,Carlyle Looks Toward Commercial Aerospace (Reu...
2,3,Oil and Economy Cloud Stocks' Outlook (Reuters),Reuters - Soaring crude prices plus worries\ab...,Oil and Economy Cloud Stocks' Outlook (Reuters...
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...,Iraq Halts Oil Exports from Main Southern Pipe...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco...","Oil prices soar to all-time record, posing new..."
5,3,"Stocks End Up, But Near Year Lows (Reuters)",Reuters - Stocks ended slightly higher on Frid...,"Stocks End Up, But Near Year Lows (Reuters) Re..."
6,3,Money Funds Fell in Latest Week (AP),AP - Assets of the nation's retail money marke...,Money Funds Fell in Latest Week (AP) AP - Asse...
7,3,Fed minutes show dissent over inflation (USATO...,USATODAY.com - Retail sales bounced back a bit...,Fed minutes show dissent over inflation (USATO...
8,3,Safety Net (Forbes.com),Forbes.com - After earning a PH.D. in Sociolog...,Safety Net (Forbes.com) Forbes.com - After ear...
9,3,Wall St. Bears Claw Back Into the Black,"NEW YORK (Reuters) - Short-sellers, Wall Stre...",Wall St. Bears Claw Back Into the Black NEW Y...
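
The row-wise `apply` used to build `full_text` works, but concatenating whole columns is usually faster on a dataset of this size. Below is a minimal sketch on a tiny made-up frame (the two rows are illustrative, not taken from AG News). It assumes neither column contains missing values, and it also maps the numeric class index to its name.

```python
import pandas as pd

ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}

# Tiny stand-in for train_df with the same column names
toy_df = pd.DataFrame({
    "Class Index": [3, 2],
    "Title": ["Oil prices rise", "Home team wins final"],
    "Description": ["Crude futures climbed on Friday.", "A late goal decided the match."],
})

# Column-wise concatenation, equivalent to the row-wise ' '.join
toy_df["full_text"] = toy_df["Title"] + " " + toy_df["Description"]

# Human-readable label for each row
toy_df["label"] = toy_df["Class Index"].map(ag_news_label)

print(toy_df[["label", "full_text"]])
```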


In [10]:
train_df[train_df['full_text'].str.contains('oil')]

Unnamed: 0,Class Index,Title,Description,full_text
3,3,Iraq Halts Oil Exports from Main Southern Pipe...,Reuters - Authorities have halted oil export\f...,Iraq Halts Oil Exports from Main Southern Pipe...
4,3,"Oil prices soar to all-time record, posing new...","AFP - Tearaway world oil prices, toppling reco...","Oil prices soar to all-time record, posing new..."
5,3,"Stocks End Up, But Near Year Lows (Reuters)",Reuters - Stocks ended slightly higher on Frid...,"Stocks End Up, But Near Year Lows (Reuters) Re..."
11,3,No Need for OPEC to Pump More-Iran Gov,TEHRAN (Reuters) - OPEC can do nothing to dou...,No Need for OPEC to Pump More-Iran Gov TEHRAN...
12,3,Non-OPEC Nations Should Up Output-Purnomo,JAKARTA (Reuters) - Non-OPEC oil exporters sh...,Non-OPEC Nations Should Up Output-Purnomo JAK...
...,...,...,...,...
119913,1,Scant Progress on Post-Kyoto as Climate Talks ...,Reuters - U.N. talks on climate\change ended e...,Scant Progress on Post-Kyoto as Climate Talks ...
119950,4,Digitized And Brought To Life,Digital technology is radically changing the 1...,Digitized And Brought To Life Digital technolo...
119967,3,Sabotage Stops Iraq's North Oil Exports (Reuters),Reuters - Saboteurs blew up Iraq's northern\ex...,Sabotage Stops Iraq's North Oil Exports (Reute...
119974,3,Russia to hold Yukos auction despite US ruling,MOSCOW - Russia said on Friday it would go ahe...,Russia to hold Yukos auction despite US ruling...


In [11]:
# View the first five documents
for i in range(5):
    print(train_df.iloc[i]['full_text'])
    print()

Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\band of ultra-cynics, are seeing green again.

Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\which has a reputation for making well-timed and occasionally\controversial plays in the defense industry, has quietly placed\its bets on another part of the market.

Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring crude prices plus worries\about the economy and the outlook for earnings are expected to\hang over the stock market next week during the depth of the\summer doldrums.

Iraq Halts Oil Exports from Main Southern Pipeline (Reuters) Reuters - Authorities have halted oil export\flows from the main pipeline in southern Iraq after\intelligence showed a rebel militia could strike\infrastructure, an oil official said on Saturday.

Oil prices soar to all-time record, posing new menace to US economy (AFP) AFP - Tearaway world

## Create document embeddings
We load the pre-trained [`all-MiniLM-L6-v2`](https://huggingface.co/sentence-transformers/all-MiniLM-L6-v2) model and use it to create embeddings for the training and test set text. The model was trained on 1.1 billion sentence pairs and maps sentences and short documents to 384-dimensional embeddings, which can be used, for example, to calculate the similarity between documents.

In [12]:
# Load pre-trained model
senttrans_model = SentenceTransformer('all-MiniLM-L6-v2',device=device)

# Create embeddings for training set text
# Passing the whole list lets encode() process documents in batches
# instead of encoding them one at a time
X_train = train_df['full_text'].values.tolist()
X_train = senttrans_model.encode(X_train)

# Create embeddings for test set text
X_test = test_df['full_text'].values.tolist()
X_test = senttrans_model.encode(X_test)
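
Embedding the full dataset is the slowest step in this notebook, so it can help to cache the result on disk and reload it on later runs. The sketch below is one possible approach, not part of the original workflow: `load_or_compute_embeddings` and `fake_encode` are hypothetical names, and the stand-in encoder exists only so the example runs without downloading a model. In the notebook, `senttrans_model.encode` would be passed as `encode_fn`.

```python
import os
import tempfile
import numpy as np

def load_or_compute_embeddings(path, docs, encode_fn):
    """Return embeddings from `path` if cached, else compute and save them."""
    if os.path.exists(path):
        return np.load(path)
    embeddings = np.asarray(encode_fn(docs))
    np.save(path, embeddings)
    return embeddings

# Stand-in encoder (hypothetical): two crude numeric features per document
def fake_encode(docs):
    return np.array([[len(d), d.count(" ")] for d in docs], dtype=np.float32)

docs = ["oil prices soar", "stocks end up"]
cache_path = os.path.join(tempfile.mkdtemp(), "toy_embeddings.npy")

first = load_or_compute_embeddings(cache_path, docs, fake_encode)   # computes and saves
second = load_or_compute_embeddings(cache_path, docs, fake_encode)  # loads from disk
print(first.shape, np.array_equal(first, second))
```

A cache like this is only valid while the documents and the model stay the same, so the file should be deleted or renamed whenever either changes.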


## Train classification model
Finally, we use the embeddings as features to train a softmax (multinomial logistic) regression model that classifies the documents.
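
For intuition, softmax regression scores each class with a linear function of the embedding and converts the scores into probabilities with the softmax function. The sketch below uses random, made-up weights and a random 384-dimensional vector in place of a real embedding. In a fitted multinomial scikit-learn model, `coef_` and `intercept_` play the roles of the hypothetical `W` and `b` here.

```python
import numpy as np

def softmax(scores):
    """Convert raw class scores into probabilities that sum to 1."""
    shifted = scores - np.max(scores)  # subtract the max for numerical stability
    exp_scores = np.exp(shifted)
    return exp_scores / exp_scores.sum()

rng = np.random.default_rng(0)
n_classes, n_features = 4, 384

W = rng.normal(scale=0.1, size=(n_classes, n_features))  # hypothetical weights
b = np.zeros(n_classes)                                   # hypothetical intercepts
x = rng.normal(size=n_features)                           # stand-in embedding

probs = softmax(W @ x + b)
predicted_class = int(np.argmax(probs)) + 1  # class indices in this dataset start at 1
print(probs.round(3), predicted_class)
```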

In [13]:
# Train a classification model using logistic regression classifier
y_train = train_df['Class Index']
logreg_model = LogisticRegression(solver='saga')
logreg_model.fit(X_train,y_train)
preds = logreg_model.predict(X_train)
acc = sum(preds==y_train)/len(y_train)
print('Accuracy on the training set is {:.3f}'.format(acc))

Accuracy on the training set is 0.900


## Evaluate model performance

In [14]:
# Evaluate performance on the test set
y_test = test_df['Class Index']
preds = logreg_model.predict(X_test)
acc = sum(preds==y_test)/len(y_test)
print('Accuracy on the test set is {:.3f}'.format(acc))

Accuracy on the test set is 0.896


In [17]:
from sklearn.metrics import classification_report, confusion_matrix
print(classification_report(y_test, preds))
confusion_matrix(y_test, preds)

              precision    recall  f1-score   support

           1       0.91      0.89      0.90      1900
           2       0.96      0.97      0.97      1900
           3       0.85      0.86      0.85      1900
           4       0.87      0.86      0.87      1900

    accuracy                           0.90      7600
   macro avg       0.90      0.90      0.90      7600
weighted avg       0.90      0.90      0.90      7600



array([[1694,   51,   96,   59],
       [  31, 1848,   13,    8],
       [  79,    9, 1625,  187],
       [  63,   18,  176, 1643]])
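
The numbers in the classification report can be recovered directly from the confusion matrix, which also shows where the errors concentrate. A short sketch using the matrix printed above (rows are true classes, columns are predicted classes, as in scikit-learn's convention):

```python
import numpy as np

labels = ["World", "Sports", "Business", "Sci/Tec"]
# Confusion matrix reported above: rows are true classes, columns are predictions
cm = np.array([[1694,   51,   96,   59],
               [  31, 1848,   13,    8],
               [  79,    9, 1625,  187],
               [  63,   18,  176, 1643]])

accuracy = np.trace(cm) / cm.sum()
recall = np.diag(cm) / cm.sum(axis=1)     # correct / all true members of the class
precision = np.diag(cm) / cm.sum(axis=0)  # correct / all predictions of the class

# Largest off-diagonal cell = most frequent single confusion
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
true_idx, pred_idx = np.unravel_index(np.argmax(off_diag), off_diag.shape)

print(f"accuracy: {accuracy:.3f}")
for name, p, r in zip(labels, precision, recall):
    print(f"{name:9s} precision={p:.2f} recall={r:.2f}")
print(f"most common error: {labels[true_idx]} predicted as {labels[pred_idx]}")
```

The two largest off-diagonal cells (187 and 176) both involve Business and Sci/Tec, which matches the lower precision and recall for classes 3 and 4 in the report.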

In [18]:
test_df['preds']=preds
test_df

Unnamed: 0,Class Index,Title,Description,full_text,preds
0,3,Fears for T N pension after talks,Unions representing workers at Turner Newall...,Fears for T N pension after talks Unions repre...,3
1,4,The Race is On: Second Private Team Sets Launc...,"SPACE.com - TORONTO, Canada -- A second\team o...",The Race is On: Second Private Team Sets Launc...,4
2,4,Ky. Company Wins Grant to Study Peptides (AP),AP - A company founded by a chemistry research...,Ky. Company Wins Grant to Study Peptides (AP) ...,4
3,4,Prediction Unit Helps Forecast Wildfires (AP),AP - It's barely dawn when Mike Fitzpatrick st...,Prediction Unit Helps Forecast Wildfires (AP) ...,4
4,4,Calif. Aims to Limit Farm-Related Smog (AP),AP - Southern California's smog-fighting agenc...,Calif. Aims to Limit Farm-Related Smog (AP) AP...,4
...,...,...,...,...,...
7595,1,Around the world,Ukrainian presidential candidate Viktor Yushch...,Around the world Ukrainian presidential candid...,4
7596,2,Void is filled with Clement,With the supply of attractive pitching options...,Void is filled with Clement With the supply of...,2
7597,2,Martinez leaves bitter,Like Roger Clemens did almost exactly eight ye...,Martinez leaves bitter Like Roger Clemens did ...,2
7598,3,5 of arthritis patients in Singapore take Bext...,SINGAPORE : Doctors in the United States have ...,5 of arthritis patients in Singapore take Bext...,3


In [19]:
# Copy the filtered rows so that adding columns later does not modify a view of test_df
incorrect_df = test_df[test_df['preds']!=test_df['Class Index']].copy()
incorrect_df

Unnamed: 0,Class Index,Title,Description,full_text,preds
9,4,"Card fraud unit nets 36,000 cards","In its first two years, the UK's dedicated car...","Card fraud unit nets 36,000 cards In its first...",3
23,4,Some People Not Eligible to Get in on Google IPO,Google has billed its IPO as a way for everyda...,Some People Not Eligible to Get in on Google I...,3
24,4,Rivals Try to Turn Tables on Charles Schwab,By MICHAEL LIEDTKE SAN FRANCISCO (AP) -- W...,Rivals Try to Turn Tables on Charles Schwab By...,3
33,1,"Man Sought #36;50M From McGreevey, Aides Say ...",AP - The man who claims Gov. James E. McGreeve...,"Man Sought #36;50M From McGreevey, Aides Say ...",3
43,4,Spam suspension hits Sohu.com shares (FT.com),"FT.com - Shares in Sohu.com, a leading US-list...",Spam suspension hits Sohu.com shares (FT.com) ...,3
...,...,...,...,...,...
7567,4,This week in merger news,This week saw three merger deals worth about \...,This week in merger news This week saw three m...,3
7585,1,Pricey Drug Trials Turn Up Few New Blockbusters,The \$500 billion drug industry is stumbling b...,Pricey Drug Trials Turn Up Few New Blockbuster...,3
7589,2,The Newest Hope ; Marriage of Necessity Just M...,"NEW YORK - The TV lights were on, the cameras ...",The Newest Hope ; Marriage of Necessity Just M...,1
7595,1,Around the world,Ukrainian presidential candidate Viktor Yushch...,Around the world Ukrainian presidential candid...,4


In [22]:
incorrect_df['true label'] = incorrect_df['Class Index'].apply(lambda x: ag_news_label[x])
incorrect_df['pred label'] = incorrect_df['preds'].apply(lambda x: ag_news_label[x])
incorrect_df

Unnamed: 0,Class Index,Title,Description,full_text,preds,true label,pred label
9,4,"Card fraud unit nets 36,000 cards","In its first two years, the UK's dedicated car...","Card fraud unit nets 36,000 cards In its first...",3,Sci/Tec,Business
23,4,Some People Not Eligible to Get in on Google IPO,Google has billed its IPO as a way for everyda...,Some People Not Eligible to Get in on Google I...,3,Sci/Tec,Business
24,4,Rivals Try to Turn Tables on Charles Schwab,By MICHAEL LIEDTKE SAN FRANCISCO (AP) -- W...,Rivals Try to Turn Tables on Charles Schwab By...,3,Sci/Tec,Business
33,1,"Man Sought #36;50M From McGreevey, Aides Say ...",AP - The man who claims Gov. James E. McGreeve...,"Man Sought #36;50M From McGreevey, Aides Say ...",3,World,Business
43,4,Spam suspension hits Sohu.com shares (FT.com),"FT.com - Shares in Sohu.com, a leading US-list...",Spam suspension hits Sohu.com shares (FT.com) ...,3,Sci/Tec,Business
...,...,...,...,...,...,...,...
7567,4,This week in merger news,This week saw three merger deals worth about \...,This week in merger news This week saw three m...,3,Sci/Tec,Business
7585,1,Pricey Drug Trials Turn Up Few New Blockbusters,The \$500 billion drug industry is stumbling b...,Pricey Drug Trials Turn Up Few New Blockbuster...,3,World,Business
7589,2,The Newest Hope ; Marriage of Necessity Just M...,"NEW YORK - The TV lights were on, the cameras ...",The Newest Hope ; Marriage of Necessity Just M...,1,Sports,World
7595,1,Around the world,Ukrainian presidential candidate Viktor Yushch...,Around the world Ukrainian presidential candid...,4,World,Sci/Tec


In [29]:
display(incorrect_df[incorrect_df.index==7567]['full_text'])

7567    This week in merger news This week saw three m...
Name: full_text, dtype: object
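
The output above is still truncated because pandas shortens long strings when it displays a Series. To read a misclassified article in full, one option is to pull out the cell value itself. The frame below is a made-up stand-in for `incorrect_df` (its text is invented for illustration):

```python
import pandas as pd

# Made-up frame standing in for incorrect_df; the text is illustrative only
toy_incorrect = pd.DataFrame(
    {"full_text": ["A short headline. " + "Some longer description text " * 5]},
    index=[7567],
)

# .loc with a row label and a column name returns the cell value itself,
# so printing it shows the whole string rather than a shortened preview
full_text = toy_incorrect.loc[7567, "full_text"]
print(full_text)
```

Another option is `pd.set_option('display.max_colwidth', None)`, which stops pandas from truncating long strings in all subsequent displays.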