In [13]:
import numpy as np
from sklearn.metrics import classification_report, f1_score
from load_data import *

In [2]:
# Snips CSV splits, given as paths relative to the current working directory.
TRAIN_PATH, TEST_PATH = './snips/snips_train_actual.csv', './snips/snips_test_actual.csv'

In [3]:
# Fraction of the training data held out for validation (passed as valid_size).
VALID_SIZE = 0.05

# Second positional argument is None -- presumably "no separate validation file",
# so validation rows are carved out of the training set below (confirm in load_data).
# keep_class_ratios=True presumably stratifies the split by label.
# NOTE(review): no random seed is passed; if split_train_valid samples randomly,
# re-runs may produce a different split -- check whether it accepts a seed.
data_loader = SnipsDataLoader(TRAIN_PATH, None, TEST_PATH)
data_loader.split_train_valid(valid_size=VALID_SIZE, keep_class_ratios=True)

In [4]:
# Raw (inputs, labels) pairs for the training and held-out validation splits.
(X_train, y_train), (X_valid, y_valid) = (
    data_loader.get_train_data(),
    data_loader.get_valid_data(),
)

In [5]:
# Passed as keep_words_threshold; presumably a minimum word frequency for a word
# to enter the vocabulary -- confirm in load_data.FeatureExtractor.
KEEP_WORDS_THRESHOLD = 5

# This cell rebinds X_train / X_valid from raw inputs to their encodings, so it
# re-fetches the raw inputs from data_loader instead of reading X_train / X_valid.
# Otherwise re-running the cell on its own would feed already-encoded matrices
# back into FeatureExtractor.
# NOTE(review): assumes get_train_data() / get_valid_data() return the same split
# on every call (no re-sampling) -- confirm in load_data.
train_inputs = data_loader.get_train_data()[0]
valid_inputs = data_loader.get_valid_data()[0]

feature_extractor = FeatureExtractor(train_inputs, valid_inputs)
feature_extractor.extract_features(keep_words_threshold=KEEP_WORDS_THRESHOLD)
X_train = feature_extractor.get_train_encodings()
X_valid = feature_extractor.get_valid_encodings()

In [6]:
class MultinomialNaiveBayes():
    """Multinomial naive Bayes classifier over word-count features.

    P(word | label) is estimated with additive smoothing:

        (alpha + count(word, label)) / (alpha * vocab_size + total_words(label))

    With the default ``alpha=1`` this is Laplace (add-one) smoothing.

    Attributes set by ``fit`` (V = vocabulary size, K = number of labels):
        vocab_probs: (V, K) array, smoothed P(word | label); columns sum to 1.
        prior_probs: (K,) array, label frequencies in the training ``y``.
        vocab_log_probs: elementwise log of ``vocab_probs``.
        prior_log_probs: elementwise log of ``prior_probs``.
    """

    def __init__(self, alpha=1.0):
        """
        Args:
            alpha: pseudo-count added to every (word, label) count. Must be > 0
                so that unseen (word, label) pairs do not get log-probability -inf.

        Raises:
            ValueError: if ``alpha`` is not strictly positive.
        """
        if alpha <= 0:
            raise ValueError(f"alpha must be > 0, got {alpha!r}")
        self.alpha = alpha

    def fit(self, X, y):
        """Estimate label priors and smoothed per-label word probabilities.

        Args:
            X: 2-D array of shape (num_examples, vocab_size) holding word counts
                (assumed non-negative).
            y: integer label ids in ``0 .. K-1``, one per row of ``X``;
                K is taken as ``max(y) + 1``.

        Returns:
            self, so calls can be chained: ``model.fit(X, y).predict(X_new)``.
        """
        _, vocab_size = X.shape
        num_labels = np.amax(y) + 1
        # (N, K): one row per example, a single 1.0 in that example's label column.
        y_one_hot = np.eye(num_labels)[y]
        # (V, K): total count of each word over the examples of each label.
        word_counts = X.T @ y_one_hot
        # (1, K): total number of word tokens over the examples of each label.
        label_totals = np.sum(X, axis=1, keepdims=True).T @ y_one_hot

        self.vocab_probs = (self.alpha + word_counts) / (self.alpha * vocab_size + label_totals)
        # A label id below max(y) that never occurs in y gets prior 0, so its log
        # prior is -inf (numpy warns) and predict() never chooses it.
        self.prior_probs = np.mean(y_one_hot, axis=0)
        self.vocab_log_probs = np.log(self.vocab_probs)
        self.prior_log_probs = np.log(self.prior_probs)
        return self

    def predict(self, X):
        """Return the most probable label id for each row of ``X``.

        Each label is scored by its log prior plus the count-weighted sum of its
        word log-probabilities, i.e. the log joint up to a per-row constant. On a
        tie the lowest label id wins, because ``np.argmax`` returns the first maximum.

        Args:
            X: (num_examples, vocab_size) word counts over the vocabulary used in ``fit``.

        Returns:
            (num_examples,) integer array of predicted label ids.
        """
        log_joint = X @ self.vocab_log_probs + self.prior_log_probs
        return np.argmax(log_joint, axis=1)

In [7]:
def calculate_accuracy(predictions, targets):
    """Return the fraction of positions where ``predictions`` equals ``targets``.

    Both inputs are flattened with ``np.ravel`` before comparing, for two reasons:
    - Plain lists are compared elementwise (``list == list`` would yield one bool).
    - A (N,) prediction vector can be scored against a (N, 1) target column.
      Comparing those directly would broadcast to an (N, N) matrix and give a
      meaningless score.

    Args:
        predictions: predicted labels, any array-like.
        targets: true labels, any array-like with the same number of elements.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if the inputs have different numbers of elements, or are empty.
    """
    predictions = np.ravel(predictions)
    targets = np.ravel(targets)
    if predictions.size != targets.size:
        raise ValueError(
            f"predictions has {predictions.size} elements but targets has {targets.size}"
        )
    if predictions.size == 0:
        raise ValueError("cannot compute accuracy of empty inputs")
    return float(np.mean(predictions == targets))

In [8]:
# Fit on the encoded training split, predict the held-out validation split,
# and display validation accuracy (bare last expression).
model = MultinomialNaiveBayes()
model.fit(X=X_train, y=y_train)
y_predict = model.predict(X=X_valid)
calculate_accuracy(predictions=y_predict, targets=y_valid)

0.9739130434782609

In [11]:
# Per-class precision / recall / F1 on the validation split. classification_report
# returns a plain-text table, so print() is used to render its line breaks.
report_text = classification_report(y_true=y_valid, y_pred=y_predict, digits=5)
print(report_text)

              precision    recall  f1-score   support

           0    0.97980   1.00000   0.98980        97
           1    0.97059   1.00000   0.98507        99
           2    1.00000   1.00000   1.00000       100
           3    1.00000   1.00000   1.00000        98
           4    0.92708   0.90816   0.91753        98
           5    0.97872   0.93878   0.95833        98
           6    0.96040   0.97000   0.96517       100

    accuracy                        0.97391       690
   macro avg    0.97380   0.97385   0.97370       690
weighted avg    0.97382   0.97391   0.97375       690



In [12]:
# Integer class ids (as used in y_valid / y_predict) -> intent names.
index_to_intent = data_loader.index_to_text_label_mapping
index_to_intent

{0: 'AddToPlaylist',
 1: 'BookRestaurant',
 2: 'GetWeather',
 3: 'RateBook',
 4: 'SearchCreativeWork',
 5: 'SearchScreeningEvent',
 6: 'PlayMusic'}

In [16]:
# Support-weighted mean of the per-class F1 scores on the validation split.
weighted_f1 = f1_score(y_true=y_valid, y_pred=y_predict, average='weighted')
print(weighted_f1)

0.9737452137920579
