In [10]:
import numpy as np
from load_data import *
from sklearn.metrics import classification_report

In [11]:
# Relative locations of the ATIS CSV files used for training and testing.
TRAIN_PATH = "./atis/atis_train_actual.csv"
TEST_PATH = "./atis/atis_test_actual.csv"

In [12]:
# Build the loader from the train and test CSVs; no validation file is supplied
# (valid_path=None), so the validation split is made from the training data below.
# NOTE(review): the class is named SnipsDataLoader but is pointed at ATIS files —
# presumably the loader is dataset-agnostic; confirm in load_data.
data_loader = SnipsDataLoader(train_path=TRAIN_PATH, valid_path=None, test_path=TEST_PATH)
# valid_size=0.05 presumably holds out 5% of the training examples, and
# keep_class_ratios=True presumably makes the split stratified — TODO confirm in load_data.
# NOTE(review): no random seed is set anywhere visible; if the split shuffles,
# the reported metrics may change between runs.
data_loader.split_train_valid(valid_size=0.05, keep_class_ratios=True)

In [13]:
# Pull the training and validation splits out of the loader; each call is
# unpacked into an (inputs, labels) pair.
X_train, y_train = data_loader.get_train_data()
X_valid, y_valid = data_loader.get_valid_data()

In [14]:
# Turn the train/validation inputs into the encodings consumed by the model.
feature_extractor = FeatureExtractor(X_train, X_valid)
# keep_words_threshold=5 presumably drops words seen fewer than 5 times —
# TODO confirm against FeatureExtractor in load_data.
feature_extractor.extract_features(keep_words_threshold=5)
# NOTE: X_train / X_valid are rebound from the loader's inputs to their encodings,
# so this cell is not idempotent: re-running it without first re-running the
# cell that fetches the splits would feed encodings back into FeatureExtractor.
X_train = feature_extractor.get_train_encodings()
X_valid = feature_extractor.get_valid_encodings()

In [15]:
class MultinomialNaiveBayes():
    """Multinomial Naive Bayes classifier with additive (Laplace/Lidstone) smoothing.

    Expects a dense 2-D matrix of non-negative feature counts (one row per
    example, one column per vocabulary entry) and integer class labels in
    ``0 .. num_labels - 1``.

    Attributes set by ``fit``:
        vocab_probs:      (vocab_size, num_labels) smoothed P(word | class).
        prior_probs:      (num_labels,) empirical P(class).
        vocab_log_probs:  element-wise log of ``vocab_probs``.
        prior_log_probs:  element-wise log of ``prior_probs``; ``-inf`` for a
                          class index that never occurs in the training labels.
    """

    def __init__(self, alpha=1.0):
        """Create an unfitted model.

        Args:
            alpha: Additive smoothing strength; must be positive. The default
                of 1.0 is classic Laplace smoothing.

        Raises:
            ValueError: If ``alpha`` is not strictly positive.
        """
        if alpha <= 0:
            raise ValueError("alpha must be positive, got {!r}".format(alpha))
        self.alpha = alpha

    def fit(self, X, y):
        """Estimate class priors and per-class word probabilities.

        Args:
            X: Array-like of shape (num_examples, vocab_size) with non-negative
                feature counts.
            y: Array-like of shape (num_examples,) with non-negative integer
                class labels.

        Raises:
            ValueError: If the shapes are inconsistent, the input is empty,
                or ``X`` / ``y`` contain negative values.
            TypeError: If ``y`` is not of an integer dtype.
        """
        X = np.asarray(X)
        y = np.asarray(y)

        if X.ndim != 2:
            raise ValueError("X must be 2-dimensional, got shape {}".format(X.shape))
        if y.ndim != 1 or y.shape[0] != X.shape[0]:
            raise ValueError(
                "y must be 1-dimensional with one label per row of X; "
                "got X shape {} and y shape {}".format(X.shape, y.shape)
            )
        if y.size == 0:
            raise ValueError("cannot fit on an empty training set")
        if not np.issubdtype(y.dtype, np.integer):
            raise TypeError("y must contain integer class labels, got dtype {}".format(y.dtype))
        # A negative label would silently wrap around when indexing the identity matrix.
        if y.min() < 0:
            raise ValueError("class labels must be non-negative")
        # Negative counts would make the smoothed probabilities (and their logs) meaningless.
        if np.any(X < 0):
            raise ValueError("X must contain non-negative feature counts")

        num_examples, vocab_size = X.shape
        num_labels = int(y.max()) + 1
        y_one_hot = np.eye(num_labels)[y]
        X_row_sum = np.sum(X, axis=1, keepdims=True)

        # (vocab_size, num_labels): how often each word occurs in each class.
        word_counts = np.dot(X.T, y_one_hot)
        # (1, num_labels): total word occurrences per class.
        class_totals = np.dot(X_row_sum.T, y_one_hot)

        self.vocab_probs = (self.alpha + word_counts) / (self.alpha * vocab_size + class_totals)
        self.prior_probs = np.mean(y_one_hot, axis=0)
        self.vocab_log_probs = np.log(self.vocab_probs)
        # A class index with no training examples has prior 0; log(0) = -inf is the
        # intended score (such a class can never win the argmax), so silence the
        # divide-by-zero warning instead of letting it leak to the caller.
        with np.errstate(divide="ignore"):
            self.prior_log_probs = np.log(self.prior_probs)

    def predict(self, X):
        """Return the most probable class index for every row of ``X``.

        Args:
            X: Array-like of shape (num_examples, vocab_size) with feature counts.

        Returns:
            1-D integer array of predicted class indices.

        Raises:
            AttributeError: If called before ``fit``.
        """
        if not hasattr(self, "vocab_log_probs"):
            raise AttributeError("model is not fitted; call fit() before predict()")
        X = np.asarray(X)
        # Unnormalised log-posterior: sum of word log-likelihoods plus the log prior.
        post_probs = np.dot(X, self.vocab_log_probs) + self.prior_log_probs
        predictions = np.argmax(post_probs, axis=1)
        return predictions

In [16]:
def calculate_accuracy(predictions, targets):
    """Return the fraction of predictions that equal their targets.

    Both inputs are converted to NumPy arrays first, so plain Python lists are
    compared element-wise rather than as whole objects.

    Args:
        predictions: Array-like of predicted labels.
        targets: Array-like of true labels, with the same shape as ``predictions``.

    Returns:
        Accuracy as a float in [0, 1] (NaN if both inputs are empty).

    Raises:
        ValueError: If the two inputs do not have the same shape. Without this
            check, shapes such as (n,) and (n, 1) would broadcast to (n, n)
            and silently yield a meaningless score.
    """
    predictions = np.asarray(predictions)
    targets = np.asarray(targets)
    if predictions.shape != targets.shape:
        raise ValueError(
            "predictions and targets must have the same shape, got {} and {}".format(
                predictions.shape, targets.shape
            )
        )
    return float(np.mean(predictions == targets))

In [17]:
# Fit the classifier on the encoded training split and predict labels for the
# validation split.
model = MultinomialNaiveBayes()
model.fit(X_train, y_train)
y_predict = model.predict(X_valid)
# Last expression of the cell: the validation accuracy is shown as the cell output.
calculate_accuracy(y_predict, y_valid)

0.9256198347107438

In [18]:
# Per-class precision / recall / F1 on the validation split. classification_report
# returns a formatted string, so print() is what renders it as a table.
# Classes are listed by integer index.
print(classification_report(y_valid, y_predict))

              precision    recall  f1-score   support

           0       0.96      0.95      0.95       183
           1       1.00      0.67      0.80         3
           2       0.73      0.90      0.81        21
           3       1.00      0.75      0.86         4
           4       1.00      1.00      1.00        13
           5       0.60      0.38      0.46         8
           6       0.88      1.00      0.93         7
           7       1.00      1.00      1.00         3

    accuracy                           0.93       242
   macro avg       0.90      0.83      0.85       242
weighted avg       0.93      0.93      0.92       242



In [19]:
# Display the loader's mapping from integer class index to intent label name,
# to interpret the per-class rows of the classification report.
data_loader.index_to_text_label_mapping

{0: 'atis_flight',
 1: 'atis_flight_time',
 2: 'atis_airfare',
 3: 'atis_aircraft',
 4: 'atis_ground_service',
 5: 'atis_airline',
 6: 'atis_abbreviation',
 7: 'atis_quantity'}