In [1]:
import os

import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

In [2]:
# Fine-tuned models whose saved predictions are combined into the ensemble.
# Each name is also the sub-directory and file-name suffix of that model's results.
model_names = [
    'deberta-v3-large',
    'stsb-deberta-v3-large',
    'stsb-roberta-large',
]

In [3]:
# Load each model's saved predictions and suffix its score/prediction columns with the
# model name, so the per-model frames can later be placed side by side without clashing.
preds = []
pred_path = '../train/training_results'
for model_name in model_names:
    pred = (
        pd.read_csv(f'{pred_path}/{model_name}/prediction_{model_name}.csv')
        # errors='raise': fail right here if a file lacks the expected columns,
        # instead of with a KeyError several cells later.
        .rename(
            columns={'prediction': f'{model_name}_pred', 'similarity': f'{model_name}_score'},
            errors='raise',
        )
    )
    preds.append(pred)

In [4]:
# Sanity check: every model's prediction file must cover the same number of samples.
# (An empty `preds` passes, as before; later cells will fail on it anyway.)
pred_lengths = {model_name: len(pred) for model_name, pred in zip(model_names, preds)}
assert len(set(pred_lengths.values())) <= 1, f'prediction files differ in length: {pred_lengths}'

In [5]:
# Join the per-model frames side by side. pd.concat(axis=1) aligns rows by index only, and the
# columns every file shares (label, ids, texts) come out duplicated. Verify those shared copies
# agree across files -- if they don't, the rows are misaligned and the ensemble would be garbage --
# then keep only the first copy of each duplicated column.
shared_cols = [col for col in preds[0].columns if all(col in pred.columns for pred in preds[1:])]
for pred in preds[1:]:
    pd.testing.assert_frame_equal(preds[0][shared_cols], pred[shared_cols], check_dtype=False)
all_preds = pd.concat(preds, axis=1)
all_preds = all_preds.loc[:, ~all_preds.columns.duplicated()]
all_preds.head()

Unnamed: 0,label,id_1,id_2,text_1,text_2,deberta-v3-large_score,deberta-v3-large_pred,stsb-deberta-v3-large_score,stsb-deberta-v3-large_pred,stsb-roberta-large_score,stsb-roberta-large_pred
0,1,1089874,1089925,"PCCW's chief operating officer, Mike Butcher, ...",Current Chief Operating Officer Mike Butcher a...,0.999707,1,0.999622,1,0.999675,1
1,1,3019446,3019327,The world's two largest automakers said their ...,Domestic sales at both GM and No. 2 Ford Motor...,0.999596,1,0.999587,1,0.999539,1
2,1,1945605,1945824,According to the federal Centers for Disease C...,The Centers for Disease Control and Prevention...,0.999748,1,0.999688,1,0.999692,1
3,0,1430402,1430329,A tropical storm rapidly developed in the Gulf...,A tropical storm rapidly developed in the Gulf...,0.000936,0,0.982065,1,0.014677,0
4,0,3354381,3354396,The company didn't detail the costs of the rep...,But company officials expect the costs of the ...,0.001443,0,0.000571,0,0.002013,0


In [16]:
# Per-model classification reports: compute each once, print it, and save it next to
# that model's prediction file.
for model_name in model_names:
    report = classification_report(all_preds['label'], all_preds[f'{model_name}_pred'], digits=4)
    print(f'`{model_name}` model classification report:')
    print(report)
    with open(f'{pred_path}/{model_name}/classification_report_{model_name}.txt', 'w') as f:
        f.write(report)

`deberta-v3-large` model classification report:
              precision    recall  f1-score   support

           0     0.8809    0.8062    0.8419       578
           1     0.9064    0.9451    0.9253      1147

    accuracy                         0.8986      1725
   macro avg     0.8936    0.8757    0.8836      1725
weighted avg     0.8978    0.8986    0.8974      1725

`stsb-deberta-v3-large` model classification report:
              precision    recall  f1-score   support

           0     0.8696    0.8304    0.8496       578
           1     0.9165    0.9372    0.9267      1147

    accuracy                         0.9014      1725
   macro avg     0.8930    0.8838    0.8881      1725
weighted avg     0.9007    0.9014    0.9009      1725

`stsb-roberta-large` model classification report:
              precision    recall  f1-score   support

           0     0.8405    0.8478    0.8441       578
           1     0.9229    0.9189    0.9209      1147

    accuracy                   

In [7]:
# Majority vote across the per-model predictions. The per-row modes come back in sorted
# order, so on a tie column 0 holds the smallest label; with an odd number of binary
# voters a tie cannot occur.
vote_columns = [f'{name}_pred' for name in model_names]
row_modes = all_preds[vote_columns].mode(axis=1)
all_preds['final_pred'] = row_modes[0]
all_preds.head()

Unnamed: 0,label,id_1,id_2,text_1,text_2,deberta-v3-large_score,deberta-v3-large_pred,stsb-deberta-v3-large_score,stsb-deberta-v3-large_pred,stsb-roberta-large_score,stsb-roberta-large_pred,final_pred
0,1,1089874,1089925,"PCCW's chief operating officer, Mike Butcher, ...",Current Chief Operating Officer Mike Butcher a...,0.999707,1,0.999622,1,0.999675,1,1
1,1,3019446,3019327,The world's two largest automakers said their ...,Domestic sales at both GM and No. 2 Ford Motor...,0.999596,1,0.999587,1,0.999539,1,1
2,1,1945605,1945824,According to the federal Centers for Disease C...,The Centers for Disease Control and Prevention...,0.999748,1,0.999688,1,0.999692,1,1
3,0,1430402,1430329,A tropical storm rapidly developed in the Gulf...,A tropical storm rapidly developed in the Gulf...,0.000936,0,0.982065,1,0.014677,0,0
4,0,3354381,3354396,The company didn't detail the costs of the rep...,But company officials expect the costs of the ...,0.001443,0,0.000571,0,0.002013,0,0


In [8]:
# Export the full ensemble table, plus the rows the majority vote got wrong.
# Nothing else in this notebook creates the output directory, so make sure it exists.
os.makedirs('./ensemble_results', exist_ok=True)
all_preds.to_csv('./ensemble_results/ensemble_prediction.csv', index=False)
ensemble_wrongs = all_preds[all_preds['label'] != all_preds['final_pred']]
ensemble_wrongs.to_csv('./ensemble_results/ensemble_wrongs.csv', index=False)
ensemble_wrongs[['label', 'text_1', 'text_2', 'final_pred']].reset_index(drop=True)

Unnamed: 0,label,text_1,text_2,final_pred
0,0,Ballmer has been vocal in the past warning tha...,"In the memo, Ballmer reiterated the open-sourc...",1
1,1,"Snow's remark ``has a psychological impact,'' ...",Snow's remark on the dollar's effects on expor...,0
2,1,Another body was pulled from the water on Thur...,Two more bodies were seen floating down the ri...,0
3,1,"Amgen shares gained 93 cents, or 1.45 percent,...",Shares of Allergan were up 14 cents at $78.40 ...,0
4,0,"During a screaming match in 1999, Carolyn told...","She, in turn, occasionally told John that she ...",1
...,...,...,...,...
146,0,"A 32-count indictment ""strikes at one of the v...",The newly unsealed 32-count indictment alleges...,1
147,1,"The council includes 13 Shiites, five Kurds, f...","There are five ethnic Kurds, five Sunni Muslim...",0
148,1,Brendsel and chief financial officer Vaughn Cl...,The company's chief executive retired and chie...,0
149,0,The delay comes on the heels of Boeing Chairma...,"On Monday, it also announced the resignation o...",1


In [14]:
# Score the majority-vote ensemble once, print the report, and save it with the other exports.
ensemble_report = classification_report(all_preds['label'], all_preds['final_pred'], digits=4)
print('Ensemble Classification Report')
print(ensemble_report)
with open('./ensemble_results/ensemble_classification_report.txt', 'w') as f:
    f.write(ensemble_report)

Ensemble Classification Report
              precision    recall  f1-score   support

           0     0.8917    0.8408    0.8655       578
           1     0.9220    0.9486    0.9351      1147

    accuracy                         0.9125      1725
   macro avg     0.9069    0.8947    0.9003      1725
weighted avg     0.9119    0.9125    0.9118      1725



In [10]:
# The models do well on label 1 but worse on label 0.
# Collect the label-0 samples that *every* model predicts as 1 (unanimous false positives):
# the majority vote cannot recover these, so they are a subset of the ensemble's false positives.
print('False Positives')
vote_columns = [f'{name}_pred' for name in model_names]
# `avg_pred` is kept as a column on all_preds so later cells can reuse it.
all_preds['avg_pred'] = all_preds[vote_columns].mean(axis=1)
is_unanimous_fp = (all_preds['label'] == 0) & (all_preds['avg_pred'] == 1)
false_positives = (
    all_preds[is_unanimous_fp]
    .reset_index(drop=True)
    .drop(columns=['avg_pred'])  # helper column is not part of the exported table
)
false_positives.to_csv('./ensemble_results/ensemble_false_positives.csv', index=False)
print("Number of false positives:", len(false_positives))
false_positives.head()

False Positives
Number of false positives: 50


Unnamed: 0,label,id_1,id_2,text_1,text_2,deberta-v3-large_score,deberta-v3-large_pred,stsb-deberta-v3-large_score,stsb-deberta-v3-large_pred,stsb-roberta-large_score,stsb-roberta-large_pred,final_pred
0,0,749900,749726,Ballmer has been vocal in the past warning tha...,"In the memo, Ballmer reiterated the open-sourc...",0.999575,1,0.999488,1,0.999535,1,1
1,0,44397,44620,Garner said the self-proclaimed mayor of Baghd...,Garner said self-proclaimed Baghdad mayor Moha...,0.999706,1,0.999692,1,0.999591,1,1
2,0,2287353,2287336,"""It appears from our initial report that this ...","Said Mr. Burke: ""It was a textbook landing con...",0.999734,1,0.99956,1,0.999644,1,1
3,0,2836872,2837029,"But, to the dismay of Reaganites, there is no ...",But there is no mention of the economic recove...,0.995793,1,0.998911,1,0.99719,1,1
4,0,1793425,1793488,A Florida grand jury investigating pharmaceuti...,A Florida grand jury indicted 19 people for co...,0.999713,1,0.999503,1,0.999554,1,1


In [11]:
# Counterpart for label 1: samples that *every* model predicts as 0 (unanimous false negatives).
# The mask is computed directly from the per-model prediction columns, so this cell does not
# depend on an `avg_pred` column having been added to all_preds by another cell.
print('False negatives')
pred_columns = [f'{name}_pred' for name in model_names]
all_models_predict_0 = (all_preds[pred_columns] == 0).all(axis=1)
false_negatives = (
    all_preds[(all_preds['label'] == 1) & all_models_predict_0]
    # drop the helper column if another cell added it to all_preds
    .drop(columns=['avg_pred'], errors='ignore')
    .reset_index(drop=True)
)
false_negatives.to_csv('./ensemble_results/ensemble_false_negatives.csv', index=False)
print("Number of false negatives:", len(false_negatives))
false_negatives.head()

False negatives
Number of false negatives: 33


Unnamed: 0,label,id_1,id_2,text_1,text_2,deberta-v3-large_score,deberta-v3-large_pred,stsb-deberta-v3-large_score,stsb-deberta-v3-large_pred,stsb-roberta-large_score,stsb-roberta-large_pred,final_pred
0,1,1596213,1596237,Another body was pulled from the water on Thur...,Two more bodies were seen floating down the ri...,0.001019,0,0.000606,0,0.233863,0,0
1,1,1401946,1401697,"Amgen shares gained 93 cents, or 1.45 percent,...",Shares of Allergan were up 14 cents at $78.40 ...,0.00199,0,0.001213,0,0.009337,0,0
2,1,1958143,1958023,The blue-chip Dow Jones industrial average .DJ...,The Dow Jones Industrial Average [$INDU] ended...,0.001158,0,0.004025,0,0.005881,0,0
3,1,813244,813532,But church members and observers say they expe...,But church members and observers say they anti...,0.001374,0,0.004524,0,0.012105,0,0
4,1,636412,636322,"As planned, the services will be rolled into Y...",The reporting services will also be rolled int...,0.003012,0,0.028222,0,0.001802,0,0
