In [66]:
# # Install dependencies (uncomment to run; the PyPI name for sklearn is scikit-learn)
# %pip install nltk scikit-learn PyStemmer tqdm

In [35]:
import json

import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords
STOPWORDS = set(stopwords.words('english'))

import re
import nltk
import pickle as pkl
import numpy as np
from tqdm import tqdm
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
import Stemmer

# Snowball stemmer shared by tokenizer_word; created before the function that uses it.
stemmer = Stemmer.Stemmer('english')

def tokenizer_word(string):
    """Lower-case, split, stop-word filter and stem a piece of text.

    Args:
        string: raw text.

    Returns:
        List of stemmed tokens. Tokens are the non-empty pieces between runs of
        non-word characters or underscores; those in STOPWORDS (checked before
        stemming) are dropped.
    """
    # Raw string: "\W" in a normal literal is an invalid escape (SyntaxWarning on 3.12+).
    # '\W' already covers ',' so only '_' (a word character) needs to be added.
    pieces = re.split(r"[\W_]+", string.lower())
    tokens = [w for w in pieces if w and w not in STOPWORDS]
    return stemmer.stemWords(tokens)

[nltk_data] Downloading package stopwords to /home/nilesh/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [15]:
# Raw text, one entry per line (surrounding whitespace stripped).
# `with` closes each file instead of leaking the handle.
with open('GZ-Eurlex-4.3K/raw/trn_X.txt') as f:
    trnX = [x.strip() for x in f]
with open('GZ-Eurlex-4.3K/raw/tst_X.txt') as f:
    tstX = [x.strip() for x in f]
with open('GZ-Eurlex-4.3K/raw/Y.txt') as f:
    Y = [x.strip() for x in f]

In [92]:
# Vocabulary cap passed as max_features to every TF-IDF vectorizer below.
# Larger datasets (amazon-1m, wikipedia-1m) need a bigger cap: the zestxml paper
# used 500000 / 1000000.
MAX_FEATURES = 50_000

In [93]:
unigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(1, 1), max_df=0.5, norm=None, max_features=MAX_FEATURES)
bigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(2, 2), max_df=0.5, norm=None, max_features=MAX_FEATURES)

%time unigram_trn_X_Xf = unigram_vectorizer.fit_transform(trnX)
%time bigram_trn_X_Xf = bigram_vectorizer.fit_transform(trnX)

%time unigram_tst_X_Xf = unigram_vectorizer.transform(tstX)
%time bigram_tst_X_Xf = bigram_vectorizer.transform(tstX)

trn_X_Xf = normalize(sp.hstack([unigram_trn_X_Xf, bigram_trn_X_Xf]))
tst_X_Xf = normalize(sp.hstack([unigram_tst_X_Xf, bigram_tst_X_Xf]))
Xf = np.concatenate([unigram_vectorizer.get_feature_names_out(), bigram_vectorizer.get_feature_names_out()])

CPU times: user 1min 8s, sys: 1.19 s, total: 1min 9s
Wall time: 3min 4s
CPU times: user 1min 40s, sys: 2.5 s, total: 1min 43s
Wall time: 4min 31s
CPU times: user 8.78 s, sys: 142 ms, total: 8.93 s
Wall time: 23.7 s
CPU times: user 10.5 s, sys: 55.3 ms, total: 10.5 s
Wall time: 29.2 s


In [88]:
unigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(1, 1), max_df=0.5, norm=None, max_features=MAX_FEATURES)
bigram_vectorizer = TfidfVectorizer(lowercase=False, tokenizer=tokenizer_word, ngram_range=(2, 2), max_df=0.5, norm=None, max_features=MAX_FEATURES)

%time unigram_Y_Yf = unigram_vectorizer.fit_transform(Y)
%time bigram_Y_Yf = bigram_vectorizer.fit_transform(Y)
lbl_Y_Yf = sp.csr_matrix((np.full(len(Y), unigram_vectorizer.idf_.max()), 
                          np.arange(len(Y)), 
                          range(len(Y)+1)), 
                         (len(Y), len(Y)))

Y_Yf = normalize(sp.hstack([unigram_Y_Yf, bigram_Y_Yf, lbl_Y_Yf]))
Yf = np.concatenate([unigram_vectorizer.get_feature_names_out(), 
                     bigram_vectorizer.get_feature_names_out(), 
                     [f'__label__{i}__{Y[i][:50]}' for i in range(len(Y))]])

CPU times: user 112 ms, sys: 125 µs, total: 112 ms
Wall time: 330 ms
CPU times: user 203 ms, sys: 974 µs, total: 204 ms
Wall time: 643 ms


## Visualize

In [28]:
def get_text(x, text, X_Xf, sep=' ', K=-1, attr='bold underline'):
    """Render the stored entries of row ``x`` of a sparse matrix, heaviest first.

    Args:
        x: row index into ``X_Xf``.
        text: sequence mapping a column index to its display name.
        X_Xf: sparse matrix whose row slice exposes ``.indices``/``.data`` (e.g. CSR).
        sep: separator placed between rendered entries.
        K: number of entries to show; -1 shows every stored entry. Any other
            value is used as a slice bound on the weight-sorted entries.
        attr: space-separated ``bcolors`` attribute names passed to ``_c``.

    Returns:
        ``'<x> : \\n'`` followed by ``name(weight, col)`` items joined by ``sep``,
        sorted by decreasing weight, with each name styled by ``_c``.
    """
    # Slice the row once; indexing X_Xf[x] / X_Xf[x, i] per entry repeats sparse lookups.
    row = X_Xf[x]
    if K == -1:
        K = row.nnz
    order = np.argsort(-row.data)[:K]
    parts = ['%s(%.2f, %d)' % (_c(text[i], attr=attr), w, i)
             for i, w in zip(row.indices[order], row.data[order])]
    return '%d : \n' % x + sep.join(parts)

In [29]:
class bcolors:
    purple = '\033[95m'
    blue = '\033[94m'
    green = '\033[92m'
    warn = '\033[93m' # dark yellow
    fail = '\033[91m' # dark red
    white = '\033[37m'
    yellow = '\033[33m'
    red = '\033[31m'
    
    ENDC = '\033[0m'
    bold = '\033[1m'
    underline = '\033[4m'
    reverse = '\033[7m'
    
    on_grey = '\033[40m'
    on_yellow = '\033[43m'
    on_red = '\033[41m'
    on_blue = '\033[44m'
    on_green = '\033[42m'
    on_magenta = '\033[45m'
    
def _c(*args, attr='bold'):
    string = ''.join([bcolors.__dict__[a] for a in attr.split()])
    string += ' '.join([str(arg) for arg in args])+bcolors.ENDC
    return string

In [72]:
def read_sparse_mat(filename):
    """Load a sparse matrix from a text file into a ``scipy.sparse.csr_matrix``.

    Layout parsed below: the first line is ``"<nr> <nc>"``. Each following line
    is one row of whitespace-separated ``"<col>:<value>"`` pairs, and an empty
    line is an empty row.

    Args:
        filename: path to the text file.

    Returns:
        csr_matrix of shape ``(nr, nc)`` with float64 values.

    Raises:
        ValueError: if the number of row lines does not match the header.
    """
    with open(filename) as f:
        # split() rather than split(' ') tolerates repeated/trailing header whitespace.
        nr, nc = map(int, f.readline().split())
        data, inds, indptr = [], [], [0]
        for line in tqdm(f, total=nr):
            for entry in line.split():
                parts = entry.split(':')
                inds.append(int(parts[0]))
                data.append(float(parts[1]))
            indptr.append(len(inds))
    n_rows_read = len(indptr) - 1
    if n_rows_read != nr:
        raise ValueError(f'{filename}: header declares {nr} rows but {n_rows_read} were read')
    # Explicit dtypes keep zero-row / all-empty files valid: np.concatenate([]) would
    # raise, and an empty float array would otherwise be used as column indices.
    return sp.csr_matrix((np.asarray(data, dtype=np.float64),
                          np.asarray(inds, dtype=np.int64),
                          np.asarray(indptr, dtype=np.int64)),
                         shape=(nr, nc))

In [73]:
# Reference feature names (one per line) and sparse feature matrices shipped with
# the dataset, used below to compare against the features built in this notebook.
# `with` closes the name files instead of leaking the handles.
with open('GZ-Eurlex-4.3K/Xf.txt') as f:
    orig_Xf = [x.strip() for x in f]
orig_trn_X_Xf = read_sparse_mat('GZ-Eurlex-4.3K/trn_X_Xf.txt')

with open('GZ-Eurlex-4.3K/Yf.txt') as f:
    orig_Yf = [x.strip() for x in f]
orig_Y_Yf = read_sparse_mat('GZ-Eurlex-4.3K/Y_Yf.txt')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4271/4271 [00:00<00:00, 20766.12it/s]


In [96]:
x = 1

# Row x rendered from this notebook's features, then from the reference features.
for feat_names, feat_mat in [(Xf, trn_X_Xf), (orig_Xf, orig_trn_X_Xf)]:
    print(get_text(x, feat_names, feat_mat))

1 : 
[1m[4mlanguag[0m(0.62, 30043) [1m[4moffici languag[0m(0.30, 85515) [1m[4minstitut communiti[0m(0.29, 79910) [1m[4mlanguag use[0m(0.26, 81485) [1m[4mdraft[0m(0.18, 20889) [1m[4minstitut[0m(0.17, 27777) [1m[4mperson subject[0m(0.13, 87211) [1m[4mstate person[0m(0.13, 95152) [1m[4mjurisdict member[0m(0.12, 80670) [1m[4mrule procedur[0m(0.12, 92843) [1m[4mfour[0m(0.11, 24028) [1m[4mone offici[0m(0.10, 85654) [1m[4mcourt justic[0m(0.10, 70483) [1m[4mjurisdict[0m(0.09, 28866) [1m[4mjustic[0m(0.09, 28869) [1m[4mcourt[0m(0.08, 18873) [1m[4msend[0m(0.08, 41617) [1m[4mcommuniti shall[0m(0.08, 68776) [1m[4mcommuniti may[0m(0.07, 68687) [1m[4mdocument[0m(0.07, 20692) [1m[4muse[0m(0.07, 46489) [1m[4msender[0m(0.07, 41619) [1m[4mrecognis offici[0m(0.07, 90586) [1m[4mlanguag languag[0m(0.07, 81480) [1m[4mstate send[0m(0.06, 95205) [1m[4mlanguag one[0m(0.06, 81483) [1m[4muse european[0m(0.06, 98079) [1m[4mone[0m(0

In [94]:
y = 1

# Label y rendered from this notebook's label features, then from the reference ones.
for feat_names, feat_mat in [(Yf, Y_Yf), (orig_Yf, orig_Y_Yf)]:
    print(get_text(y, feat_names, feat_mat))

1 : 
[1m[4m__label__1__financing[0m(0.81, 19072) [1m[4mfinanc[0m(0.58, 1630)
1 : 
[1m[4m__label__1__financing[0m(0.81, 19072) [1m[4mfinanc[0m(0.58, 1630)
