In [21]:
import spacy
import numpy as np
import re
import itertools
from collections import Counter
import json
import pandas as pd
import torch
import dill as pickle
import time
from data_module.data_preprocessor import preprocess_question, get_label
from collections import Counter
from pandas_ml import ConfusionMatrix

## Load Test Dataset, Models, and Preprocessing Module

In [2]:
df_test = pd.read_csv("ir_test_dataset.csv")

In [3]:
models = []
models.append(torch.load('ensemble_learning_related/model1.model'))
models.append(torch.load('ensemble_learning_related/model2.model'))
models.append(torch.load('ensemble_learning_related/model3.model'))
models.append(torch.load('ensemble_learning_related/model4.model'))

In [4]:
text_field = pickle.load(open("ensemble_learning_related/text_vocab.pkl", "rb"))
label_field = pickle.load(open("ensemble_learning_related/label_vocab.pkl", "rb"))

## Create function for evaluation

In [11]:
def model_prediction(model, text_field, label_field, test_data):
    res = []
    time_predictions = []
    for text in test_data:
        start_time = time.time()
        text = preprocess_question(text, text_field, use_gpu=True)
        model.eval()
        y = model(text)
        label_string = get_label(y, label_field)
        time_predictions.append(time.time() - start_time)
        res.append(label_string)
        del text
        torch.cuda.empty_cache()
    avg_time = np.average(time_predictions)
    return res, avg_time

In [12]:
def most_voted(res_column_stack):
    most_voted_res = []
    for row in res_column_stack:
        c = Counter(row)
        most_voted_res.append(int(c.most_common(1)[0][0]))
    return most_voted_res

In [13]:
def ensemble_model_prediction(models, text_field, label_field, test_data):
    res_all = []
    avg_time_each_model = []
    for model in models:
        res_tmp, avg_time_tmp = model_prediction(model, text_field, label_field, test_data)
        res_all.append(res_tmp)
        avg_time_each_model.append(avg_time_tmp)
    res_all = np.column_stack(res_all)
    most_voted_res = most_voted(res_all)
    return most_voted_res, avg_time_each_model

## Run ensemble model

In [40]:
ensemble_res, avg_time_each_model = ensemble_model_prediction(models, text_field, label_field, df_test.text.values)

  logloss = F.log_softmax(logit) # log of softmax


In [41]:
avg_time_each_model

[0.0063748300075531,
 0.0075662493705749515,
 0.009007121801376344,
 0.010398153305053712]

In [29]:
confusion_matrix = ConfusionMatrix(df_test.label, ensemble_res)

In [30]:
confusion_matrix.print_stats()

Confusion Matrix:

Predicted  -1   0    1  __all__
Actual                         
-1         25   4   16       45
0           6  50   18       74
1           2  12  867      881
__all__    33  66  901     1000


Overall Statistics:

Accuracy: 0.942
95% CI: (0.9256635046176478, 0.955667139747181)
No Information Rate: ToDo
P-Value [Acc > NIR]: 2.183484722865502e-06
Kappa: 0.7097823367525641
Mcnemar's Test P-Value: ToDo


Class Statistics:

Classes                                        -1          0          1
Population                                   1000       1000       1000
P: Condition positive                          45         74        881
N: Condition negative                         955        926        119
Test outcome positive                          33         66        901
Test outcome negative                         967        934         99
TP: True Positive                              25         50        867
TN: True Negative                             947    

## Calculate average time predictions

In [32]:
avg_time_model_1 = avg_time_each_model[0]

In [31]:
df_train = pd.read_csv("ir_train_dataset.csv")

In [33]:
ensemble_res, avg_time_each_model = ensemble_model_prediction(models, text_field, label_field, df_train.text.values)

  logloss = F.log_softmax(logit) # log of softmax


In [37]:
avg_time_model_1 = (avg_time_each_model[0] + avg_time_model_1)/2

In [38]:
avg_time_model_1

0.006443462557262845

In [39]:
1/avg_time_model_1 #records per sec

155.19605974474626