In [66]:
# # Install dependencies (uncomment to run; the PyPI package is `scikit-learn` — the `sklearn` alias is deprecated)
# %pip install nltk scikit-learn PyStemmer tqdm

In [35]:
import json

import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
STOPWORDS = set(stopwords.words('english'))

import re
import nltk
import pickle as pkl
import numpy as np
from tqdm import tqdm
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
import Stemmer

# English stemmer used by every tokenizer_word call. It is defined before the function
# so the cell reads top-down.
stemmer = Stemmer.Stemmer('english')

# Split on runs of non-word characters or underscores (',' is already matched by \W).
# Raw string: "\W" inside a normal literal is an invalid escape sequence
# (DeprecationWarning, SyntaxWarning from Python 3.12).
_TOKEN_SPLIT_RE = re.compile(r"[\W_]+")

def tokenizer_word(string):
    """Tokenize text into stemmed, stopword-free, lowercase tokens.

    Args:
        string: raw text.

    Returns:
        List of stems, in order of appearance, for every non-empty token that is
        not in STOPWORDS. Stopwords are matched on the lowercased token before stemming.
    """
    tokens = [w for w in _TOKEN_SPLIT_RE.split(string.lower()) if w and w not in STOPWORDS]
    return stemmer.stemWords(tokens)

[nltk_data] Downloading package stopwords to /home/nilesh/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [None]:
def read_sparse_mat(filename):
    """Load a sparse matrix stored as text into a CSR matrix.

    The first line is the header ``"<num_rows> <num_cols>"``. Every following
    line is one row of whitespace-separated ``<col_index>:<value>`` pairs; an
    empty line is a row without non-zeros.

    Args:
        filename: path of the text file to read.

    Returns:
        ``scipy.sparse.csr_matrix`` of shape ``(num_rows, num_cols)`` with
        integer column indices and float64 values.

    Raises:
        ValueError: if the header is malformed, an entry cannot be parsed, or
            the number of row lines differs from ``num_rows``.
    """
    with open(filename) as f:
        # split() rather than split(' '): tolerates repeated/trailing whitespace.
        nr, nc = map(int, f.readline().split())
        row_inds, row_vals, indptr = [], [], [0]
        for line in tqdm(f, total=nr):
            pairs = [entry.split(':') for entry in line.split()]
            # Explicit dtypes: an empty row would otherwise produce a float64 array
            # and silently upcast the concatenated column indices.
            row_inds.append(np.array([int(p[0]) for p in pairs], dtype=np.int64))
            row_vals.append(np.array([float(p[1]) for p in pairs], dtype=np.float64))
            indptr.append(indptr[-1] + len(pairs))

    n_read = len(indptr) - 1
    if n_read != nr:
        raise ValueError(f'{filename}: header declares {nr} rows but {n_read} row lines were read')

    # np.concatenate([]) raises, so a matrix with zero rows needs explicit empty arrays.
    indices = np.concatenate(row_inds) if row_inds else np.empty(0, dtype=np.int64)
    data = np.concatenate(row_vals) if row_vals else np.empty(0, dtype=np.float64)
    return sp.csr_matrix((data, indices, np.asarray(indptr, dtype=np.int64)), shape=(nr, nc))

In [155]:
# Dataset root; every path below is relative to it.
DATA_DIR = 'GZ-Eurlex-4.3K'

def read_lines(path):
    """Return the whitespace-stripped lines of a UTF-8 text file (the file is closed on return)."""
    with open(path, encoding='utf-8') as f:
        return [line.strip() for line in f]

trn_X_Y = read_sparse_mat(f'{DATA_DIR}/trn_X_Y.txt')

trnX = read_lines(f'{DATA_DIR}/raw/trn_X.txt')
tstX = read_lines(f'{DATA_DIR}/raw/tst_X.txt')
Y = read_lines(f'{DATA_DIR}/raw/Y.txt')
parent_Y = read_lines(f'{DATA_DIR}/raw/Y.parent.txt')

# "Seen" labels: columns of trn_X_Y with at least one non-zero entry.
nnz = trn_X_Y.getnnz(0)
seen_labels = np.where(nnz > 0)[0]
seen_Y = [Y[i] for i in seen_labels]

# The label-feature cell stacks per-label matrices built from Y, parent_Y and nnz,
# so all three must describe the same labels.
assert trn_X_Y.shape[1] == len(Y), (trn_X_Y.shape, len(Y))
assert len(parent_Y) == len(Y), (len(parent_Y), len(Y))

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 45000/45000 [00:00<00:00, 155883.19it/s]


In [92]:
# Vocabulary cap applied to each of the unigram and bigram TF-IDF vectorizers.
# For bigger datasets (amazon-1m, wikipedia-1m) raise it; the ZestXML paper
# used 500000 / 1000000.
MAX_FEATURES = 50_000

In [93]:
unigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(1, 1), max_df=0.5, norm=None, max_features=MAX_FEATURES)
bigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(2, 2), max_df=0.5, norm=None, max_features=MAX_FEATURES)

%time unigram_trn_X_Xf = unigram_vectorizer.fit_transform(trnX)
%time bigram_trn_X_Xf = bigram_vectorizer.fit_transform(trnX)

%time unigram_tst_X_Xf = unigram_vectorizer.transform(tstX)
%time bigram_tst_X_Xf = bigram_vectorizer.transform(tstX)

trn_X_Xf = normalize(sp.hstack([unigram_trn_X_Xf, bigram_trn_X_Xf]))
tst_X_Xf = normalize(sp.hstack([unigram_tst_X_Xf, bigram_tst_X_Xf]))
Xf = np.concatenate([unigram_vectorizer.get_feature_names_out(), bigram_vectorizer.get_feature_names_out()])

CPU times: user 1min 8s, sys: 1.19 s, total: 1min 9s
Wall time: 3min 4s
CPU times: user 1min 40s, sys: 2.5 s, total: 1min 43s
Wall time: 4min 31s
CPU times: user 8.78 s, sys: 142 ms, total: 8.93 s
Wall time: 23.7 s
CPU times: user 10.5 s, sys: 55.3 ms, total: 10.5 s
Wall time: 29.2 s


In [159]:
unigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(1, 1), max_df=0.5, norm=None, max_features=MAX_FEATURES)
bigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(2, 2), max_df=0.5, norm=None, max_features=MAX_FEATURES)
parent_vectorizer = TfidfVectorizer(lowercase=False, norm=None)

%time unigram_Y_Yf = unigram_vectorizer.fit_transform(Y)
%time bigram_Y_Yf = bigram_vectorizer.fit_transform(Y)
%time parent_Y_Yf = parent_vectorizer.fit_transform(parent_Y)
lbl_Y_Yf = sp.csr_matrix((np.full(len(seen_Y), unigram_vectorizer.idf_.max()), 
                          np.arange(len(seen_Y)), 
                          np.append([0], np.cumsum(nnz > 0))), 
                         (len(Y), len(seen_Y)))

Y_Yf = normalize(sp.hstack([unigram_Y_Yf, bigram_Y_Yf, lbl_Y_Yf, parent_Y_Yf]))
Yf = np.concatenate([unigram_vectorizer.get_feature_names_out(), 
                     bigram_vectorizer.get_feature_names_out(), 
                     [f'__label__{i}__{Y[i][:50]}' for i in range(len(Y)) if nnz[i] > 0], 
                     [f'__parent__{yf}' for yf in parent_vectorizer.get_feature_names_out()]])

CPU times: user 95.7 ms, sys: 38 µs, total: 95.8 ms
Wall time: 94.8 ms
CPU times: user 121 ms, sys: 967 µs, total: 122 ms
Wall time: 122 ms
CPU times: user 26.1 ms, sys: 0 ns, total: 26.1 ms
Wall time: 26.1 ms


## Visualize

In [28]:
def get_text(x, text, X_Xf, sep=' ', K=-1, attr='bold underline'):
    """Format the non-zero features of row ``x`` of a CSR matrix, heaviest first.

    Args:
        x: row index into ``X_Xf``.
        text: sequence mapping a column index to its feature name.
        X_Xf: scipy CSR matrix whose columns are indexed like ``text``.
        sep: separator placed between formatted features.
        K: number of features to show; -1 (default) shows every non-zero of the row.
        attr: space-separated ``bcolors`` attribute names passed to ``_c``.

    Returns:
        ``"<x> : \\n"`` followed by ``name(value, column)`` entries in decreasing
        order of value. Equal values are ordered by column index, so the same row
        of two matrices prints in the same order.
    """
    row = X_Xf[x]          # slice the row once instead of three times
    row.sum_duplicates()   # canonical form: sorted column indices, no duplicate entries
    # A stable sort over canonically ordered indices breaks ties by column index
    # (the default quicksort left the tie order arbitrary).
    order = np.argsort(-row.data, kind='stable')
    if K != -1:
        order = order[:K]
    entries = ['%s(%.2f, %d)' % (_c(text[i], attr=attr), val, i)
               for i, val in zip(row.indices[order], row.data[order])]
    return '%d : \n' % x + sep.join(entries)

In [29]:
class bcolors:
    """Named ANSI escape sequences.

    ``_c`` looks these up by attribute name; ``ENDC`` resets all styling.
    """

    # Foreground colours, 90-range codes.
    purple = '\033[95m'
    blue = '\033[94m'
    green = '\033[92m'
    warn = '\033[93m'       # dark yellow
    fail = '\033[91m'       # dark red

    # Foreground colours, 30-range codes.
    white = '\033[37m'
    yellow = '\033[33m'
    red = '\033[31m'

    # Reset and text attributes.
    ENDC = '\033[0m'
    bold = '\033[1m'
    underline = '\033[4m'
    reverse = '\033[7m'

    # Background colours.
    on_grey = '\033[40m'
    on_yellow = '\033[43m'
    on_red = '\033[41m'
    on_blue = '\033[44m'
    on_green = '\033[42m'
    on_magenta = '\033[45m'
    
def _c(*args, attr='bold'):
    """Wrap the space-joined ``str`` of ``args`` in ANSI styling.

    ``attr`` is a space-separated list of ``bcolors`` attribute names. Their codes
    are emitted in order before the text, and ``bcolors.ENDC`` follows it. An unknown
    name raises ``KeyError``.
    """
    codes = vars(bcolors)
    prefix = ''.join(codes[name] for name in attr.split())
    body = ' '.join(map(str, args))
    return prefix + body + bcolors.ENDC

In [73]:
# Reference features shipped with the dataset, loaded for a side-by-side comparison
# with the features rebuilt above. `with` closes each name file once it has been read.
with open('GZ-Eurlex-4.3K/Xf.txt', encoding='utf-8') as f:
    orig_Xf = [line.strip() for line in f]
orig_trn_X_Xf = read_sparse_mat('GZ-Eurlex-4.3K/trn_X_Xf.txt')

with open('GZ-Eurlex-4.3K/Yf.txt', encoding='utf-8') as f:
    orig_Yf = [line.strip() for line in f]
orig_Y_Yf = read_sparse_mat('GZ-Eurlex-4.3K/Y_Yf.txt')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4271/4271 [00:00<00:00, 20766.12it/s]


In [158]:
# Compare rebuilt vs. shipped document features for one training point.
# A seeded generator keeps the sample stable under Restart & Run All;
# change X_SAMPLE_SEED to inspect another row.
X_SAMPLE_SEED = 0
x = np.random.default_rng(X_SAMPLE_SEED).integers(len(trnX))

# Row x must refer to the same document in both matrices for the comparison to be meaningful.
assert trn_X_Xf.shape[0] == orig_trn_X_Xf.shape[0] == len(trnX)

# print() (not a bare expression) so the ANSI styling codes are rendered.
print(get_text(x, Xf, trn_X_Xf))
print(get_text(x, orig_Xf, orig_trn_X_Xf))

11318 : 
[1m[4mpoultrymeat[0m(0.55, 37453) [1m[4m1906[0m(0.29, 3189) [1m[4mfresh poultrymeat[0m(0.20, 77447) [1m[4mretail[0m(0.15, 39802) [1m[4mtemperatur[0m(0.14, 44550) [1m[4mretail trade[0m(0.14, 92528) [1m[4mcut[0m(0.13, 19254) [1m[4mâ c[0m(0.13, 99977) [1m[4mmarket standard[0m(0.12, 82998) [1m[4mâ[0m(0.12, 48983) [1m[4m9 februari[0m(0.11, 60498) [1m[4mregul eec[0m(0.11, 91315) [1m[4mstorag[0m(0.11, 43401) [1m[4mfresh[0m(0.10, 24176) [1m[4mfebruari 1993[0m(0.10, 76205) [1m[4mcertain market[0m(0.09, 67166) [1m[4mstiffen[0m(0.09, 43334) [1m[4m90[0m(0.08, 10023) [1m[4mcommunit[0m(0.08, 18124) [1m[4mpoint[0m(0.08, 37084) [1m[4mperform[0m(0.08, 36406) [1m[4mtrade regul[0m(0.08, 97335) [1m[4mtreatment[0m(0.07, 45448) [1m[4mshop[0m(0.07, 42003) [1m[4mpurpos suppli[0m(0.07, 89841) [1m[4m1235[0m(0.07, 2018) [1m[4m173 6[0m(0.07, 52511) [1m[4mtake due[0m(0.07, 96331) [1m[4msuitabl human[0m(0.07, 95910) [1m

In [160]:
# Compare rebuilt vs. shipped label features for one label.
# A seeded generator keeps the sample stable under Restart & Run All;
# change Y_SAMPLE_SEED to inspect another label.
Y_SAMPLE_SEED = 0
y = np.random.default_rng(Y_SAMPLE_SEED).integers(len(Y))

# Row y must refer to the same label in both matrices for the comparison to be meaningful.
assert Y_Yf.shape[0] == orig_Y_Yf.shape[0] == len(Y)

# print() (not a bare expression) so the ANSI styling codes are rendered.
print(get_text(y, Yf, Y_Yf))
print(get_text(y, orig_Yf, orig_Y_Yf))

3075 : 
[1m[4mtelevis[0m(0.47, 4029) [1m[4mhdtv[0m(0.28, 1951) [1m[4mdefinit[0m(0.28, 1086) [1m[4mdigit televis[0m(0.28, 8000) [1m[4mhdtv digit[0m(0.28, 10781) [1m[4mtelevis hdtv[0m(0.28, 17735) [1m[4mdefinit televis[0m(0.28, 7801) [1m[4mhigh definit[0m(0.28, 10872) [1m[4m__label__3075__high-definition television; HDTV; digital televisi[0m(0.28, 22038) [1m[4m__parent__4432[0m(0.26, 23937) [1m[4mdigit[0m(0.23, 1167) [1m[4mhigh[0m(0.21, 1984)
3075 : 
[1m[4mtelevis[0m(0.47, 4029) [1m[4mdefinit[0m(0.28, 1086) [1m[4mhdtv[0m(0.28, 1951) [1m[4mdefinit televis[0m(0.28, 7801) [1m[4mdigit televis[0m(0.28, 8000) [1m[4mhdtv digit[0m(0.28, 10781) [1m[4mhigh definit[0m(0.28, 10872) [1m[4mtelevis hdtv[0m(0.28, 17735) [1m[4m__label__3075__high-definition television; HDTV; digital televisi[0m(0.28, 22038) [1m[4m__parent__4432[0m(0.26, 23920) [1m[4mdigit[0m(0.23, 1167) [1m[4mhigh[0m(0.21, 1984)
