In [1]:
import sys

# Put ../src on the import path (presumably where the adthena_task package
# imported further down lives -- confirm against the repository layout).
# The membership guard keeps this cell idempotent: re-running it does not
# stack duplicate entries on sys.path.
SRC_DIR = "../src"
if SRC_DIR not in sys.path:
    sys.path.append(SRC_DIR)

In [2]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer, CountVectorizer
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.pipeline import Pipeline
import nltk
from nltk.corpus import stopwords
from nltk.stem import SnowballStemmer
from nltk.tokenize import wordpunct_tokenize, word_tokenize
import string

from adthena_task.constants import DATA_DIR_TRAIN, DATA_DIR_TEST, SEED

In [3]:
# Fetch the NLTK stop-word corpus needed by stopwords.words(...) below.
nltk.download("stopwords")

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/lukaszbala/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


True

In [4]:
def clean_text(text):
    """Normalise a raw query into a space-joined string of stemmed tokens.

    The value is coerced with ``str`` and lower-cased, split with
    ``wordpunct_tokenize``, filtered, and every surviving token is passed
    through ``stemmer.stem``.

    Relies on the notebook-level globals ``stop_words`` and ``stemmer``;
    both must be defined before the first call.

    Args:
        text: Value to clean. Non-string values are converted with ``str``
            (so ``None`` becomes the token ``"none"``).

    Returns:
        str: The stemmed tokens joined by single spaces; an empty string
        when every token is filtered out.
    """
    # Build the exclusion set once per call instead of re-creating and
    # linearly scanning a list for every token. Membership semantics are
    # unchanged: a token is dropped only if it is a stop word or is exactly
    # one character from string.punctuation.
    excluded = set(stop_words).union(string.punctuation)
    tokens = [
        stemmer.stem(word)
        for word in wordpunct_tokenize(str(text).lower())
        if word not in excluded
    ]
    return " ".join(tokens)

In [5]:
# The training CSV is read without a header row; the columns are then named
# explicitly (set_axis raises if the file does not have exactly two columns).
data = pd.read_csv(DATA_DIR_TRAIN, header=None).set_axis(
    ["Query", "Label"], axis="columns"
)

In [6]:
# Shared resources that clean_text looks up at call time.
stemmer = SnowballStemmer("english")
stop_words = {*stopwords.words("english")}

In [7]:
# Add the cleaned version of every query as a new "Clean_query" column.
data = data.assign(Clean_query=data["Query"].apply(clean_text))

In [8]:
# Hold out 20% of the cleaned queries for evaluation; SEED fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    data["Clean_query"],
    data["Label"],
    test_size=0.2,
    random_state=SEED,
)

In [13]:
# Unigram TF-IDF features. The vectorizer is fitted on the training split
# only and then re-used, unchanged, to transform the held-out split.
vectorizer = TfidfVectorizer(
    ngram_range=(1, 1),
)
vector_train = vectorizer.fit_transform(X_train)
vector_test = vectorizer.transform(X_test)

In [15]:
# Baseline classifier with default hyper-parameters, scored by accuracy on
# the held-out split.
# NOTE(review): the saved output of this cell is a KeyboardInterrupt, i.e. the
# last recorded run was interrupted -- check the training cost before relying
# on a Restart & Run All.
model = LogisticRegression().fit(vector_train, y_train)
accuracy_score(
    y_true=y_test,
    y_pred=model.predict(vector_test),
)

KeyboardInterrupt: 

In [None]:
## Confusion matrix

In [None]:
# Pin the row/column order to model.classes_. Without `labels`, the matrix is
# built only from the labels that occur in y_test or the predictions, so a
# class missing from the held-out split would make its shape disagree with
# the model.classes_ index/columns used when the matrix is plotted.
conf_matrix = confusion_matrix(
    y_test,
    model.predict(vector_test),
    labels=model.classes_,
)

In [None]:
# Heatmap of the confusion matrix: rows are the true labels, columns the
# predicted labels, both labelled with model.classes_.
# The string literals use plain ASCII quotes; the previous typographic quotes
# are not valid Python string delimiters and made the cell a SyntaxError.
fig, ax = plt.subplots(figsize=(8, 6))
df_cm = pd.DataFrame(conf_matrix, index=model.classes_, columns=model.classes_)
sns.heatmap(df_cm, annot=True, fmt="d", ax=ax)
ax.set_title("Confusion matrix")
ax.set_ylabel("True label")
ax.set_xlabel("Predicted label");