In [None]:
import itertools
import pickle

import numpy as np
import matplotlib.pyplot as plt
import sklearn as sk
import sklearn.utils
import sklearn.model_selection
import sklearn.metrics
import sklearn.linear_model

In [None]:
# Ground-truth labels for the dev split. np.load on an .npz returns an NpzFile
# that holds the file open; read the array inside `with` so it is closed.
with np.load('results/results_neo_0.npz') as npz:
    y = npz['y']

In [None]:
# Class scores of run 0; the NpzFile is closed as soon as the array is read.
with np.load('results/results_neo_0_enh.npz') as npz:
    probas_0 = npz['preds']

In [None]:
# Class scores of run 1; the NpzFile is closed as soon as the array is read.
with np.load('results/results_neo_1_enh.npz') as npz:
    probas_1 = npz['preds']

In [None]:
# Class scores of run 2; the NpzFile is closed as soon as the array is read.
with np.load('results/results_neo_2_enh.npz') as npz:
    probas_2 = npz['preds']

In [None]:
# Training data; in this notebook only its keys are used (as the class display
# names of the confusion matrices). NOTE: pickle.load can execute arbitrary
# code, so only load this file from a trusted source.
with open('data_train.p', mode='rb') as pickle_fh:
    data_train = pickle.load(pickle_fh)

## Calculate Predictions

In [None]:
def MLE(probas):
    """Maximum-likelihood decision: pick the highest-scoring class per sample.

    Args:
        probas: array of per-class scores with classes along axis 1.

    Returns:
        Array of predicted class indices (argmax over axis 1).
    """
    return np.argmax(probas, axis=1)

In [None]:
# Per-run MLE predictions (argmax along axis 1 of each run's scores).
preds_0, preds_1, preds_2 = (MLE(run_probas) for run_probas in (probas_0, probas_1, probas_2))

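### MAP
Predictions are the argmax of the class probabilities multiplied element-wise by a per-class prior. Priors recorded for each model:
* GPT-2: `prior_manual = np.array([5.0, 23.5, 3.5, 2.25, 3.7])`
* GPT-Neo: `prior_manual = np.array([14.0, 2.5, 3.5, 3.5, 1.0])`

Note: the prior cell below currently uses `[18.5, 3.0, 4.5, 4.5, 1.0]`, which differs from the GPT-Neo value listed here.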

In [None]:
def MAP(probas, prior):
    """Maximum-a-posteriori decision: argmax of class scores reweighted by a prior.

    Args:
        probas: array of per-class scores with classes along axis 1.
        prior: per-class weights, broadcast against ``probas`` by element-wise
            multiplication. They need not be normalised: multiplying every
            weight by the same positive factor does not change the argmax.

    Returns:
        Array of predicted class indices (argmax over axis 1).
    """
    return np.argmax(probas * prior, axis=1)

In [None]:
# Hand-tuned class weights for MAP (only their relative sizes matter).
# NOTE(review): differs from the GPT-Neo prior listed in the "### MAP" notes
# ([14.0, 2.5, 3.5, 3.5, 1.0]); confirm which one is current.
prior_manual = np.asarray((18.5, 3.0, 4.5, 4.5, 1.0))
# Presumably per-class sample counts of the training data; TODO confirm, or derive from the data.
prior_empirical = np.asarray((461.0, 624.0, 339.0, 95.0, 72.0))

In [None]:
# Per-run MAP predictions with the empirical prior.
# NOTE(review): map_emp_* are not used anywhere else in the visible notebook.
map_emp_0, map_emp_1, map_emp_2 = [MAP(run_probas, prior_empirical) for run_probas in (probas_0, probas_1, probas_2)]

In [None]:
# Per-run MAP predictions with the hand-tuned prior.
map_0, map_1, map_2 = [MAP(run_probas, prior_manual) for run_probas in (probas_0, probas_1, probas_2)]

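### Grid search over the prior weights
Searches for the class-prior weights that maximise the macro-F1 of the summed three-run probabilities. By default the search ranges are empty, so the search is skipped.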

In [None]:
# Set to True to re-run the (slow, 5-D) prior grid search in the next cell.
# When False the ranges are empty and that cell is a no-op, so Run All stays fast.
RUN_GRID_SEARCH = False

if RUN_GRID_SEARCH:
    range_0 = np.arange(13, 19, .5)
    range_1 = np.arange(1, 6, .5)
    range_2 = np.arange(2, 7, .5)
    range_3 = np.arange(3, 8, .5)
    range_4 = np.arange(1, 3, .5)
else:
    range_0 = []
    range_1 = []
    range_2 = []
    range_3 = []
    range_4 = []

In [None]:
# Exhaustive search for the class-prior weights that maximise macro-F1 of the
# three-run ensemble (summed class scores). This is a no-op when the ranges are empty.
# NOTE(review): this scores against the same `y` used by the dev-set reports
# below, so weights chosen here make those reports optimistic.
ensemble_probas = probas_0 + probas_1 + probas_2  # loop-invariant: sum once, not once per candidate
best_f1 = 0
best_weights = []

for w0 in range_0:
    print(f'w0: {w0}')  # coarse progress indicator
    # Same iteration order as four nested loops (the last range varies fastest).
    for w1, w2, w3, w4 in itertools.product(range_1, range_2, range_3, range_4):
        cur_weights = np.array([w0, w1, w2, w3, w4])
        cur_preds = MAP(ensemble_probas, cur_weights)
        cur_f1 = sk.metrics.f1_score(y, cur_preds, average='macro')
        if cur_f1 > best_f1:  # strict '>' keeps the first of tied candidates
            best_f1 = cur_f1
            best_weights = [w0, w1, w2, w3, w4]

In [None]:
# Best prior found by the grid search, together with the macro-F1 it reached
# (shows ([], 0) when the search was skipped).
best_weights, best_f1

### Evaluate consistency across runs

In [None]:
# Agreement between each pair of runs. rand_score compares the two label
# vectors as partitions (whether pairs of samples are grouped together or
# apart), so it ignores which class id was assigned.
# NOTE(review): for label-level agreement consider cohen_kappa_score.
print('Pairwise Rand scores, MLE:')
mle_runs = (preds_0, preds_1, preds_2)
for first, second in ((0, 1), (0, 2), (1, 2)):
    print(sk.metrics.rand_score(mle_runs[first], mle_runs[second]))

In [None]:
# Same pairwise comparison for the MAP predictions. rand_score treats the
# label vectors as partitions, so it ignores which class id was assigned.
print('Pairwise Rand scores, MAP:')
map_runs = (map_0, map_1, map_2)
for first, second in ((0, 1), (0, 2), (1, 2)):
    print(sk.metrics.rand_score(map_runs[first], map_runs[second]))

In [None]:
import krippendorff_alpha

In [None]:
# The three runs are passed as the list of rater sequences.
# NOTE(review): presumably one coder per run, with nominal_metric treating class
# ids as unordered categories; confirm against the krippendorff_alpha module.
print("Krippendorff's alpha (MLE):")
mle_alpha = krippendorff_alpha.krippendorff_alpha(
    [preds_0, preds_1, preds_2],
    metric=krippendorff_alpha.nominal_metric,
)
print(mle_alpha)

In [None]:
# Same inter-run agreement measure, computed for the MAP predictions.
print("Krippendorff's alpha (MAP):")
map_alpha = krippendorff_alpha.krippendorff_alpha(
    [map_0, map_1, map_2],
    metric=krippendorff_alpha.nominal_metric,
)
print(map_alpha)

## Classification Reports

In [None]:
def report(y, preds, save_name=None, display_labels=None):
    """Print a classification report and plot a row-normalised confusion matrix.

    Args:
        y: ground-truth class ids.
        preds: predicted class ids (e.g. output of MLE/MAP).
        save_name: optional path. When truthy, the confusion-matrix figure is
            saved there before being shown.
        display_labels: one tick label per class. Defaults to the keys of the
            global ``data_train``, as before.
    """
    if display_labels is None:
        display_labels = list(data_train.keys())
    print(sk.metrics.classification_report(y, preds, zero_division=0))
    # Pin the class ids so the matrix always has one row/column per display
    # label. Without this, sklearn sizes the matrix from the labels that actually
    # occur, and the ticks misalign (or plotting fails) when a class is absent.
    # Assumes class ids are 0..n-1 in the same order as display_labels (preds
    # come from argmax column indices); TODO confirm for y.
    disp = sk.metrics.ConfusionMatrixDisplay.from_predictions(
        y, preds,
        labels=np.arange(len(display_labels)),
        display_labels=display_labels,
        xticks_rotation=45,
        normalize='true',
    )
    if save_name:
        disp.figure_.savefig(save_name, bbox_inches='tight')
    plt.show()

In [None]:
print("MLE:")
# One report (text + confusion matrix) per run; no figure is saved.
for run_preds in (preds_0, preds_1, preds_2):
    report(y, run_preds)

In [None]:
print("MAP:")
# One report per run, using the hand-tuned prior; no figure is saved.
for run_preds in (map_0, map_1, map_2):
    report(y, run_preds)

In [None]:
# Three-run ensemble (summed class scores) with the hand-tuned prior.
# NOTE(review): if prior_manual came from the grid search above, which scores
# against this same `y`, this report is optimistic.
ensemble_map = MAP(probas_0 + probas_1 + probas_2, prior_manual)
report(y, ensemble_map, save_name='confusion_MAP.png')

In [None]:
# Three-run ensemble (summed class scores), plain argmax without a prior.
ensemble_mle = MLE(probas_0 + probas_1 + probas_2)
report(y, ensemble_mle, save_name='confusion_MLE.png')

In [None]:
# Three-run ensemble (summed class scores) with the empirical prior.
ensemble_map_emp = MAP(probas_0 + probas_1 + probas_2, prior_empirical)
report(y, ensemble_map_emp, save_name='confusion_MAP_empiric.png')

## Test Data

In [None]:
# Ground-truth labels for the test split; the NpzFile is closed after reading.
with np.load('results/results_neo_test.npz') as npz:
    y_test = npz['y']

In [None]:
# Class scores on the test split; the NpzFile is closed after reading.
# NOTE(review): a single run is scored here, whereas the dev-set section also
# reports a three-run ensemble.
with np.load('results/results_neo_test_enh.npz') as npz:
    probas_test = npz['preds']

In [None]:
# Test-set MAP predictions with the same hand-tuned prior used on the dev split.
map_test = MAP(prior=prior_manual, probas=probas_test)

In [None]:
# Held-out evaluation; the confusion matrix is also written to disk.
report(y=y_test, preds=map_test, save_name='confusion_test.png')