In [84]:
import numpy as np
import pandas as pd
import re
import nltk
from ast import literal_eval
from collections import Counter
nltk.download('stopwords')
from nltk.corpus import stopwords
from scipy import sparse as sp_sparse
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score 
from sklearn.preprocessing import MultiLabelBinarizer

[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [5]:
#to read data, as tags are string of list contaning tags.
# we use literal eval to convert string of list into list
def read_data(file):
  data=pd.read_csv(file,sep='\t')
  data.tags=data.tags.apply(literal_eval)
  return data


In [6]:
train=read_data('data/train.tsv')
validation=read_data('data/validation.tsv')

In [7]:
X_train,X_val=train['title'].values,validation['title'].values
y_train,y_val=train['tags'].values,validation['tags'].values

In [8]:
REPLACE_BY_SPACE_RE = re.compile('[/(){}\[\]\|@,;]')
BAD_SYMBOLS_RE = re.compile('[^0-9a-z #+_]')
STOPWORDS = set(stopwords.words('english'))

def preprocess(text):
    """
        text: a string
        
        return: modified initial string
    """
    text = text.lower()# lowercase text
    text = REPLACE_BY_SPACE_RE.sub('',text)# replace REPLACE_BY_SPACE_RE symbols by space in text
    text = BAD_SYMBOLS_RE.sub('', text)# delete symbols which are in BAD_SYMBOLS_RE from text
    text = ' '.join([x for x in text.split() if x and x not in STOPWORDS]) # delete stopwords from text
    return text

In [9]:
X_train=[preprocess(x) for x in X_train]
X_val=[preprocess(x) for x in X_val]

In [10]:
X_train[:10]

['draw stacked dotplot r',
 'mysql select records datetime field less specified value',
 'terminate windows phone 81 app',
 'get current time specific country via jquery',
 'configuring tomcat use ssl',
 'awesome nested set plugin add new children tree various levels',
 'create map json response ruby rails 3',
 'rspec test method called',
 'springboot catalina lifecycle exception',
 'import data excel mysql database using php']

In [11]:
y_train[:10]

array([list(['r']), list(['php', 'mysql']), list(['c#']),
       list(['javascript', 'jquery']), list(['java']),
       list(['ruby-on-rails']), list(['ruby', 'ruby-on-rails-3', 'json']),
       list(['ruby']), list(['java', 'spring', 'spring-mvc']),
       list(['php', 'codeigniter'])], dtype=object)

In [12]:
##################################################
################# Bag of words ###################
##################################################

words_freq=Counter()
tags_freq=Counter()

for text in X_train:
  for words in text.split():
    words_freq[words]+=1

for tag_list in y_train:
  for tags in tag_list:
    tags_freq[tags]+=1


In [14]:
words_freq.items()



In [15]:
tags_freq.items()

dict_items([('r', 1727), ('php', 13907), ('mysql', 3092), ('c#', 19077), ('javascript', 19078), ('jquery', 7510), ('java', 18661), ('ruby-on-rails', 3344), ('ruby', 2326), ('ruby-on-rails-3', 692), ('json', 2026), ('spring', 1346), ('spring-mvc', 618), ('codeigniter', 786), ('class', 509), ('html', 4668), ('ios', 3256), ('c++', 6469), ('eclipse', 992), ('python', 8940), ('list', 693), ('objective-c', 4338), ('swift', 1465), ('xaml', 438), ('asp.net', 3939), ('wpf', 1289), ('multithreading', 1118), ('image', 672), ('performance', 512), ('twitter-bootstrap', 501), ('linq', 964), ('xml', 1347), ('numpy', 502), ('ajax', 1767), ('django', 1835), ('laravel', 525), ('android', 2818), ('rest', 456), ('asp.net-mvc', 1244), ('web-services', 633), ('string', 1573), ('excel', 443), ('winforms', 1468), ('arrays', 2277), ('c', 3119), ('sockets', 579), ('osx', 490), ('entity-framework', 649), ('mongodb', 350), ('opencv', 401), ('xcode', 900), ('uitableview', 460), ('algorithm', 419), ('python-2.7', 4

In [36]:
dict_size=8000
words_list=sorted(words_freq.keys(),key=lambda x : words_freq[x],reverse=True)[:dict_size]
words_to_index={w:ii for ii,w in enumerate(words_list)}
total_words=words_to_index.keys()

def bag_of_words(text,wti,dictsize):
  ''' text: input text,
      wti: word to index dict,
      dictsize: size of dict'''
  ans=np.zeros(dictsize)
  for words in text.split():
    if words in wti:
      ans[wti[words]]+=1
  return ans


In [39]:
X_train_bag=sp_sparse.vstack([sp_sparse.csr_matrix(bag_of_words(x,words_to_index,dict_size)) for x in X_train ])
X_val_bag=sp_sparse.vstack([sp_sparse.csr_matrix(bag_of_words(x,words_to_index,dict_size)) for x in X_val ])

In [46]:
y_train.shape,X_train_bag.shape

((100000,), (100000, 8000))

In [47]:
y_train[:5]

array([list(['r']), list(['php', 'mysql']), list(['c#']),
       list(['javascript', 'jquery']), list(['java'])], dtype=object)

In [55]:
mlb=MultiLabelBinarizer(classes=sorted(tags_freq.keys()))
y_train=mlb.fit_transform(y_train)
y_val=mlb.transform(y_val)

In [57]:
y_train.shape

(100000, 100)

In [81]:
lr = LogisticRegression(penalty='l1', C=2, solver='liblinear')
ovr = OneVsRestClassifier(lr)
ovr.fit(X_train_bag, y_train)

OneVsRestClassifier(estimator=LogisticRegression(C=2, class_weight=None,
                                                 dual=False, fit_intercept=True,
                                                 intercept_scaling=1,
                                                 l1_ratio=None, max_iter=100,
                                                 multi_class='auto',
                                                 n_jobs=None, penalty='l1',
                                                 random_state=None,
                                                 solver='liblinear', tol=0.0001,
                                                 verbose=0, warm_start=False),
                    n_jobs=None)

In [82]:
pred=ovr.predict(X_val_bag)

In [83]:
print('Accuracy:', accuracy_score(y_val, pred))
print('F1-score macro:', f1_score(y_val, pred, average='macro'))
print('F1-score micro:', f1_score(y_val, pred, average='micro'))
print('F1-score weighted:', f1_score(y_val, pred, average='weighted'))

Accuracy: 0.35933333333333334
F1-score macro: 0.520527057495733
F1-score micro: 0.6766430860390592
F1-score weighted: 0.657715046187274
