# Text Topic/Document Classification
Using the GDI text dataset, we set out to build a deep-learning-based classifier that assigns each document to one of three topics (business, politics, sport).


In [1]:
!pip install bert-embedding
from bert_embedding import BertEmbedding
import numpy as np
import pandas as pd
from fastai.vision import Path  # fastai v1 re-exports pathlib.Path with an added .ls() helper
from tqdm import tqdm
from sklearn.metrics.pairwise import cosine_similarity as cs
import operator


Next, we load the text files and their topic labels into a pandas DataFrame.

In [2]:
classes = ["business", "politics", "sport"]
targets = []
texts = []
for c in tqdm(classes):
    for tf in Path(f"../input/PatriotHack-master/data/{c}/{c}").ls():
        with open(tf, 'r', encoding = 'unicode_escape') as text:
            targets.append(c)
            texts.append(text.read())
df = pd.DataFrame()
df["Target"] = targets
df["Text"] = texts
df.head(3)

100%|██████████| 3/3 [00:03<00:00,  1.30s/it]


|   | Target | Text |
|---|---|---|
| 0 | business | FAO warns on impact of subsidies\n\nBillions o... |
| 1 | business | Singapore growth at 8.1% in 2004\n\nSingapore'... |
| 2 | business | Bank payout to Pinochet victims\n\nA US bank h... |
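In this notebook fastai is imported only for `Path.ls()`. A minimal sketch of the same loading loop using just the standard library's `pathlib` follows; `load_corpus` is a hypothetical helper, the temporary directory and its files are toy stand-ins for `../input/PatriotHack-master/data/`, and reading with `latin-1` is our assumption in place of the original `unicode_escape` codec.

```python
import tempfile
from pathlib import Path

import pandas as pd


def load_corpus(root, classes):
    """Read every file under root/<c>/<c>/ into a DataFrame with Target and Text columns."""
    rows = []
    for c in classes:
        # iterdir() order is filesystem-dependent, so sort for a reproducible row order
        for tf in sorted(Path(root, c, c).iterdir()):
            # latin-1 maps every byte to a character, so decoding cannot fail (assumed codec)
            rows.append({"Target": c, "Text": tf.read_text(encoding="latin-1")})
    return pd.DataFrame(rows, columns=["Target", "Text"])


# Hypothetical two-topic corpus laid out like ../input/PatriotHack-master/data/<c>/<c>/
with tempfile.TemporaryDirectory() as tmp:
    for c, body in [("business", "Shares rose."), ("sport", "Team wins.")]:
        topic_dir = Path(tmp, c, c)
        topic_dir.mkdir(parents=True)
        (topic_dir / "001.txt").write_text(body, encoding="latin-1")
    demo = load_corpus(tmp, ["business", "sport"])

print(demo)
```

Building a list of row dicts and creating the DataFrame once at the end avoids keeping two parallel lists (`targets`, `texts`) in sync.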


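Then, we load a pretrained BERT model for word embeddings: `bert_12_768_12` (12 layers, 768-dimensional hidden states, 12 attention heads), pretrained on BookCorpus and English Wikipedia with an uncased vocabulary (`book_corpus_wiki_en_uncased`).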

In [3]:
be = BertEmbedding(model='bert_12_768_12', dataset_name='book_corpus_wiki_en_uncased')

Vocab file is not found. Downloading.
Downloading /tmp/.mxnet/models/book_corpus_wiki_en_uncased-a6607397.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/vocab/book_corpus_wiki_en_uncased-a6607397.zip...
Downloading /tmp/.mxnet/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip from https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/models/bert_12_768_12_book_corpus_wiki_en_uncased-75cc780f.zip...


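Next, we create one embedding per document by averaging the embeddings of its tokens. Note that `BertEmbedding` truncates each input to `max_seq_length` tokens (a constructor argument not set here, so its default applies), so for long articles only the opening tokens contribute to the average. Some text files are empty; they produce no token embeddings, so we drop those rows.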

In [4]:
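def getVector(sent):
    # be() returns one (tokens, token_embeddings) pair per input; 'sum' sets oov_way (word pieces are summed)
    # Average the 768-d token vectors and prepend a batch axis -> shape (1, 768)
    return np.mean(np.array(be([sent], 'sum')[0][1]), axis=0)[None]
df["Embedding"] = df["Text"].apply(getVector)
df.head(3)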



|   | Target | Text | Embedding |
|---|---|---|---|
| 0 | business | FAO warns on impact of subsidies\n\nBillions o... | [[-0.21912901, 0.19704409, 0.203581, 0.3530849... |
| 1 | business | Singapore growth at 8.1% in 2004\n\nSingapore'... | [[-0.7985942, -0.19978853, 0.35024175, 0.33936... |
| 2 | business | Bank payout to Pinochet victims\n\nA US bank h... | [[-0.02139423, -0.47071832, 0.16319667, 0.1360... |
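To make the pooling step concrete, here is the same averaging that `getVector` applies, run on a few hypothetical 4-dimensional token vectors instead of BERT's 768-dimensional ones (`mean_pool` is a hypothetical name). It also shows the empty-document case: with nothing to average, the result is a 1-D NaN array rather than a `(1, d)` row, which is why such rows fail the two-dimensional shape check in the next cell and are dropped.

```python
import warnings

import numpy as np


def mean_pool(token_embs):
    """Average per-token vectors into one (1, d) document vector, as getVector does."""
    # Mean over the token axis, then [None] prepends a batch axis
    return np.mean(np.array(token_embs), axis=0)[None]


# Three hypothetical 4-d token vectors standing in for BERT's 768-d outputs
tokens = [np.array([1.0, 0.0, 2.0, 4.0]),
          np.array([3.0, 0.0, 2.0, 0.0]),
          np.array([2.0, 3.0, 2.0, 2.0])]
doc_vec = mean_pool(tokens)
print(doc_vec.shape, doc_vec)  # (1, 4) [[2. 1. 2. 2.]]

# An empty document has no tokens: the mean is NaN and the result is only 1-D
with warnings.catch_warnings():
    warnings.simplefilter("ignore", RuntimeWarning)  # silence the empty-slice warnings
    empty_vec = mean_pool([])
print(empty_vec.shape)  # (1,)
```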


In [5]:
# Empty files yield a 1-D NaN array instead of a (1, 768) vector; keep only the 2-D embeddings
df = df[df["Embedding"].apply(lambda e: e.ndim == 2)]

Then, we compute one embedding per topic (a class centroid) by averaging the embeddings of all documents with that label.

In [6]:
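cembs = {}
for c in classes:
    # Stack this class's (1, 768) document vectors into an (n, 768) matrix and average into a (1, 768) centroid
    cembs[c] = np.mean(np.vstack(df.loc[df["Target"] == c, "Embedding"].tolist()), axis=0, keepdims=True)
cembs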

{'business': array([[-0.100995, -0.120632,  0.27129 ,  0.031712, ..., -0.031566, -0.290568,  0.156345, -0.109788]], dtype=float32),
 'politics': array([[-0.058113, -0.165831,  0.244364, -0.262047, ..., -0.001339, -0.29378 , -0.144061,  0.133773]], dtype=float32),
 'sport': array([[-0.175366, -0.068401,  0.310909, -0.25563 , ..., -0.332931, -0.143316, -0.109809, -0.05016 ]], dtype=float32)}
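A toy run of the centroid step, using plain lists in place of the DataFrame columns and three hypothetical 2-d document vectors stored as `(1, d)` arrays like the `Embedding` column. `np.vstack` turns a class's vectors into an `(n, d)` matrix, and `keepdims=True` keeps each centroid in the `(1, d)` shape that `cosine_similarity` expects.

```python
import numpy as np

# Hypothetical 2-d document vectors and labels; each vector is (1, d) like df["Embedding"]
targets = ["business", "business", "sport"]
embeddings = [np.array([[1.0, 0.0]]), np.array([[3.0, 2.0]]), np.array([[0.0, 5.0]])]

centroids = {}
for c in ["business", "sport"]:
    # Stack this class's (1, d) rows into an (n, d) matrix, then average over documents
    stacked = np.vstack([e for t, e in zip(targets, embeddings) if t == c])
    centroids[c] = stacked.mean(axis=0, keepdims=True)  # keepdims keeps the (1, d) shape

print(centroids["business"], centroids["sport"])  # [[2. 1.]] [[0. 5.]]
```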

Let's test it on a short query.

In [7]:
query = "This is a business article about the stock market. We find that the price of Google will skyrocket."
qemb = getVector(query)  # (1, 768) query embedding

First, we predict the query's topic: the class whose centroid has the highest cosine similarity to the query embedding.
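For reference, writing $\mathbf{c}_k$ for the centroid of topic $k$ and $\mathbf{q}$ for the query embedding (our notation), the score computed by `cs` and the resulting prediction are:

$$
\cos(\mathbf{c}_k, \mathbf{q}) = \frac{\mathbf{c}_k \cdot \mathbf{q}}{\lVert \mathbf{c}_k \rVert \, \lVert \mathbf{q} \rVert},
\qquad
\hat{y} = \arg\max_{k \in \{\text{business},\ \text{politics},\ \text{sport}\}} \cos(\mathbf{c}_k, \mathbf{q})
$$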

In [8]:
results = {}
for k,v in cembs.items():
    results[k] =  cs(v, qemb)[0][0]
results = sorted(results.items(), key=operator.itemgetter(1))
subsection = df[df["Target"]==results[-1][0]]
print(results[-1][0])

business
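The nearest-centroid rule can also be sketched without scikit-learn, which makes explicit that `cs` L2-normalizes both vectors and takes their dot product. `cosine` and `predict_topic` are hypothetical helpers and the vectors are toy data; on the notebook's data you would pass `cembs` and `qemb`.

```python
import numpy as np


def cosine(a, b):
    """Cosine similarity between rows of a (m, d) and b (n, d); returns an (m, n) matrix."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T


def predict_topic(centroids, qemb):
    """Return the topic whose (1, d) centroid is most cosine-similar to the (1, d) query."""
    scores = {c: cosine(v, qemb)[0, 0] for c, v in centroids.items()}
    # max() picks the top class directly; sorting all scores is not needed
    return max(scores, key=scores.get), scores


# Hypothetical 2-d centroids and query vector
centroids = {"business": np.array([[2.0, 1.0]]), "sport": np.array([[0.0, 5.0]])}
label, scores = predict_topic(centroids, np.array([[4.0, 1.0]]))
print(label)  # business
```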


Then, within the predicted topic, we retrieve the document whose embedding is most similar to the query.

In [9]:
eresults = {}
for x, doc in zip(subsection["Text"],subsection["Embedding"]):
    eresults[x] =  cs(doc, qemb)[0][0]
sorted(eresults.items(), key=operator.itemgetter(1))[-1][0]

'Dollar hits new low versus euro\n\nThe US dollar has continued its record-breaking slide and has tumbled to a new low against the euro.\n\nInvestors are betting that the European Central Bank (ECB) will not do anything to weaken the euro, while the US is thought to favour a declining dollar. The US is struggling with a ballooning trade deficit and analysts said one of the easiest ways to fund it was by allowing a depreciation of the dollar. They have predicted that the dollar is likely to fall even further.\n\nThe US currency was trading at $1.364 per euro at 1800 GMT on Monday. This compares with $1.354 to the euro in late trading in New York on Friday, which was then a record low.\n\nThe dollar has weakened sharply since September when it traded about $1.20 against the euro. It has lost 7% this year, while against the Japanese yen it is down 3.2%. Traders said that thin trading levels had amplified Monday\'s move. "It\'s not going to take much to push [the dollar] one way or the oth
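Keying `eresults` by article text means two identical articles would share one dictionary entry, and the final lookup returns only the single best match. A sketch that ranks documents by position and returns the top `k` instead; `top_k_docs` is a hypothetical helper and the texts and 2-d embeddings are toy data.

```python
import numpy as np


def top_k_docs(texts, doc_embs, qemb, k=3):
    """Return the k (text, similarity) pairs most cosine-similar to qemb, best first."""
    m = np.vstack(doc_embs)                          # (n, d) matrix of document vectors
    m = m / np.linalg.norm(m, axis=1, keepdims=True)
    q = qemb / np.linalg.norm(qemb)                  # (1, d) unit query vector
    sims = (m @ q.T).ravel()                         # one similarity per document
    order = np.argsort(-sims)[:k]                    # indices of the k largest similarities
    return [(texts[i], float(sims[i])) for i in order]


# Hypothetical documents and 2-d embeddings
texts = ["Dollar slides", "Oil prices rise", "Team wins cup"]
embs = [np.array([[1.0, 0.1]]), np.array([[0.8, 0.6]]), np.array([[0.0, 1.0]])]
for text, score in top_k_docs(texts, embs, np.array([[1.0, 0.0]]), k=2):
    print(f"{score:.3f}  {text}")
```

With the notebook's variables this would be called as `top_k_docs(subsection["Text"].tolist(), subsection["Embedding"].tolist(), qemb)`.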

In [10]:
query

'This is a business article about the stock market. We find that the price of Google will skyrocket.'

In [11]:
# import re
# df["Text"] = df["Text"].apply(lambda x: "[CLS] "+re.sub(r"[\n.,-]", " [SEP] ", x).replace("\"", "").replace("  ", " "))