In [1]:
from os import listdir, walk
from os.path import isfile, join
import numpy as np

In [2]:
def get_filenames_from_dir(file_dir, include_sub_dir=False):
    """Return the names of the regular files found under ``file_dir``.

    Parameters
    ----------
    file_dir : str
        Directory to scan. A trailing path separator is allowed.
    include_sub_dir : bool, default False
        If False, list only the regular files directly inside ``file_dir``.
        If True, walk the whole tree (bottom-up, ``topdown=False``) and also
        include files located in subdirectories.

    Returns
    -------
    list of str
        File paths relative to ``file_dir`` (e.g. ``"sub/x.txt"`` for a file
        inside a subdirectory when ``include_sub_dir`` is True).
    """
    from os.path import relpath

    if not include_sub_dir:
        return [f for f in listdir(file_dir) if isfile(join(file_dir, f))]

    # relpath strips the directory prefix robustly. The previous
    # ``.replace(file_dir + "/", "")`` silently did nothing when ``file_dir``
    # already ended in "/" (e.g. '../tests_pred/'), and could also remove a
    # matching substring from the middle of a path.
    filenames = []
    for root, _, files in walk(file_dir, topdown=False):
        for f in files:
            filenames.append(relpath(join(root, f), file_dir))
    return filenames


# Prediction files are read from a single flat directory (sub-folders ignored).
filenames = get_filenames_from_dir('../tests_pred/', include_sub_dir=False)


def get_pred_labels(filepath, header_lines=3):
    """Read predicted and gold labels from a tab-separated prediction file.

    The first ``header_lines`` lines are skipped. Every remaining non-blank
    line is split on tabs; field 2 is parsed as the predicted label and
    field 3 as the gold label.

    Parameters
    ----------
    filepath : str
        Path to the prediction file.
    header_lines : int, default 3
        Number of leading lines to skip before the data rows.

    Returns
    -------
    tuple of (list of int, list of int)
        ``(pred, labels)`` in file order.
    """
    pred, labels = [], []
    # ``with`` guarantees the handle is closed; the previous version never
    # closed it, leaking one file descriptor per call.
    with open(filepath) as file:
        for _ in range(header_lines):
            # The None default avoids a bare StopIteration on short files,
            # which would turn into RuntimeError inside the calling generators.
            next(file, None)
        for raw in file:
            # Only drop the line terminator: a full ``strip()`` would also eat
            # leading tabs and shift the column indices if field 0 is empty.
            row = raw.rstrip('\r\n')
            if not row.strip():
                continue  # e.g. a trailing empty line
            fields = row.split('\t')
            pred.append(int(fields[2]))
            labels.append(int(fields[3]))
    return pred, labels


def PRF1(pred, labels, digits=3):
    """Compute precision, recall and F1 for binary labels (positive class = 1).

    Parameters
    ----------
    pred : iterable of int
        Predicted labels (1 = positive, 0 = negative).
    labels : iterable of int
        Gold labels aligned with ``pred``. Pairs beyond the shorter input are
        ignored (``zip`` semantics).
    digits : int, default 3
        Number of decimals each metric is rounded to.

    Returns
    -------
    tuple of float
        ``(precision, recall, f1)``. A metric whose denominator is zero (no
        predicted positives, no gold positives, or P + R == 0) is reported as
        0.0 instead of raising ZeroDivisionError.
    """
    tp = fp = fn = 0
    for p, y in zip(pred, labels):
        if p == 1 and y == 1:
            tp += 1
        elif p == 1 and y == 0:
            fp += 1
        elif p == 0 and y == 0:
            pass  # true negatives do not enter P, R or F1
        else:
            # Any other pair (including unexpected label values) is counted as
            # a false negative, matching the original implementation.
            fn += 1

    P = tp / (tp + fp) if tp + fp else 0.0
    R = tp / (tp + fn) if tp + fn else 0.0
    F1 = 2 * P * R / (P + R) if P + R else 0.0
    return round(P, digits), round(R, digits), round(F1, digits)


def getBaseLinePRF1(filenames, cond, pred_dir='../tests_pred/'):
    """Yield (P, R, F1) for every baseline (non-augmented) prediction file.

    A file is selected when its name does NOT contain ``'aug'`` and DOES
    contain ``cond`` (plain substring test, e.g. a model name).

    Parameters
    ----------
    filenames : iterable of str
        Candidate file names, relative to ``pred_dir``.
    cond : str
        Substring that must appear in a selected file name.
    pred_dir : str, default '../tests_pred/'
        Directory holding the prediction files (previously hard-coded).

    Yields
    ------
    tuple of float
        ``PRF1(pred, labels, digits=3)`` for each selected file, in input order.
    """
    for fname in filenames:
        if 'aug' in fname or cond not in fname:
            continue
        pred, labels = get_pred_labels(join(pred_dir, fname))
        yield PRF1(pred, labels, digits=3)
            

def getAugREDAPRF1(filenames, cond, pred_dir='../tests_pred/'):
    """Yield (P, R, F1) for every REDA-augmented prediction file.

    A file is selected when its name contains ``'reda'`` and ``cond`` but does
    NOT contain ``'ngram'`` (all plain substring tests).

    Parameters
    ----------
    filenames : iterable of str
        Candidate file names, relative to ``pred_dir``.
    cond : str
        Substring that must appear in a selected file name (e.g. a model name).
    pred_dir : str, default '../tests_pred/'
        Directory holding the prediction files (previously hard-coded).

    Yields
    ------
    tuple of float
        ``PRF1(pred, labels, digits=3)`` for each selected file, in input order.
    """
    for fname in filenames:
        # Names containing 'ngram' are excluded here even if they mention
        # 'reda'; getAugNgramPRF1 is the selector that picks those up.
        if 'ngram' in fname or 'reda' not in fname or cond not in fname:
            continue
        pred, labels = get_pred_labels(join(pred_dir, fname))
        yield PRF1(pred, labels, digits=3)
                
                
def getAugNgramPRF1(filenames, cond, pred_dir='../tests_pred/'):
    """Yield (P, R, F1) for every prediction file whose name contains 'ngram'.

    A file is selected when its name contains both ``'ngram'`` and ``cond``
    (plain substring tests).

    Parameters
    ----------
    filenames : iterable of str
        Candidate file names, relative to ``pred_dir``.
    cond : str
        Substring that must appear in a selected file name (e.g. a model name).
    pred_dir : str, default '../tests_pred/'
        Directory holding the prediction files (previously hard-coded).

    Yields
    ------
    tuple of float
        ``PRF1(pred, labels, digits=3)`` for each selected file, in input order.
    """
    for fname in filenames:
        if 'ngram' not in fname or cond not in fname:
            continue
        pred, labels = get_pred_labels(join(pred_dir, fname))
        yield PRF1(pred, labels, digits=3)
            

def mean(x, decimals=3):
    """Return the mean of ``x`` along axis 0, rounded to ``decimals`` places.

    ``x`` is converted with ``np.array`` first, so a list of ``(P, R, F1)``
    tuples yields a single averaged ``[P, R, F1]`` array. A plain def replaces
    the former named lambda (PEP 8, E731) so it has a real name and docstring.
    """
    return np.round(np.mean(np.array(x), axis=0), decimals)


def report(func, filenames=filenames, models=('bow', 'cnn', 'lstm', 'gru', 'ernieGram')):
    """Print per-model and overall average (P, R, F1) for one file selector.

    Parameters
    ----------
    func : callable
        Generator function called as ``func(filenames, model)`` that yields one
        ``(P, R, F1)`` tuple per matching prediction file (e.g. getBaseLinePRF1).
    filenames : list of str, default: the module-level ``filenames``
        NOTE: this default is bound once, when the defining cell runs;
        reassigning the global ``filenames`` later has no effect unless
        ``report`` is redefined.
    models : sequence of str
        Model names, each passed to ``func`` as its ``cond`` substring.

    Raises
    ------
    ValueError
        If no prediction file matches a model. Previously this printed ``nan``
        and the overall average then failed on a ragged array.
    """
    models_res = []
    for model in models:
        scores = list(func(filenames, model))
        if not scores:
            raise ValueError(f'No prediction files matched model {model!r}')
        model_res = mean(scores)
        print(f'The average Precision, Recall, and F1 for {model} is: ', model_res)
        models_res.append(model_res)
    print('The overall average Precision, Recall, and F1 is: ', mean(models_res))

In [3]:
# Baseline: prediction files whose names do not contain 'aug'.
report(func=getBaseLinePRF1)

The average Precision, Recall, and F1 for bow is:  [0.825 0.615 0.704]
The average Precision, Recall, and F1 for cnn is:  [0.805 0.628 0.705]
The average Precision, Recall, and F1 for lstm is:  [0.827 0.625 0.712]
The average Precision, Recall, and F1 for gru is:  [0.824 0.634 0.717]
The average Precision, Recall, and F1 for ernieGram is:  [0.958 0.78  0.859]
The overall average Precision, Recall, and F1 is:  [0.848 0.656 0.739]


In [4]:
# REDA augmentation: names containing 'reda' but not 'ngram'.
report(func=getAugREDAPRF1)

The average Precision, Recall, and F1 for bow is:  [0.817 0.633 0.713]
The average Precision, Recall, and F1 for cnn is:  [0.781 0.636 0.7  ]
The average Precision, Recall, and F1 for lstm is:  [0.814 0.634 0.713]
The average Precision, Recall, and F1 for gru is:  [0.819 0.638 0.717]
The average Precision, Recall, and F1 for ernieGram is:  [0.959 0.758 0.846]
The overall average Precision, Recall, and F1 is:  [0.838 0.66  0.738]


In [5]:
# N-gram augmentation: names containing 'ngram'.
report(func=getAugNgramPRF1)

The average Precision, Recall, and F1 for bow is:  [0.818 0.629 0.711]
The average Precision, Recall, and F1 for cnn is:  [0.762 0.638 0.693]
The average Precision, Recall, and F1 for lstm is:  [0.821 0.63  0.713]
The average Precision, Recall, and F1 for gru is:  [0.817 0.633 0.714]
The average Precision, Recall, and F1 for ernieGram is:  [0.953 0.76  0.846]
The overall average Precision, Recall, and F1 is:  [0.834 0.658 0.735]
