In [None]:
import h5py
import numpy as np

In [None]:
# Experiment configuration: edit these to point the notebook at a different run.
# Every derived path below is built from them, so they are the single source of truth.
SECONDS_PER_CLIP = 5            # aggregation window in seconds; used in file/dir names
DATA_TYPE = 'unbalanced_data'   # dataset variant; used in file/dir names
CLASS_LABELS = ['Safe', 'Violent', 'Sexual', 'Both']  # presumably CLASS_LABELS[i] names class id i — confirm
DPI = 150                       # NOTE(review): not used in the visible cells; presumably figure DPI

# NOTE(review): absolute, machine-specific path — consider making this configurable.
HOME_PATH = 'C:/Users/Bhagyashree/Desktop/project/kidsguard-dataset/'
# Previously this was a literal ('aggregate_3_sec/balanced_data_classifier') with a
# no-op .format(), so it silently ignored SECONDS_PER_CLIP / DATA_TYPE.
METRICS_PATH = 'metrics/aggregate_{0}_sec/{1}_classifier/evaluation_metric_split_0.hdf5'.format(SECONDS_PER_CLIP, DATA_TYPE)
ROC_PLOT_PATH = 'aggregate_{0}_sec_{1}_classifier_roc.pdf'.format(SECONDS_PER_CLIP, DATA_TYPE)

In [None]:
def load_data(name, path=None):
    """Read one dataset from the metrics HDF5 file fully into memory.

    Parameters
    ----------
    name : str
        Key of the dataset inside the file (e.g. 'y_true'). Must refer to a
        dataset, not a group.
    path : str, optional
        HDF5 file to read. Defaults to ``HOME_PATH + METRICS_PATH``, resolved at
        call time so re-running the config cell takes effect without having to
        re-run this definition cell.

    Returns
    -------
    numpy.ndarray (or numpy scalar for a scalar dataset)
        A copy of the dataset's contents; the file is closed before returning.
    """
    if path is None:
        path = HOME_PATH + METRICS_PATH
    # The context manager releases the file handle (the original leaked one per
    # call). Indexing with [()] copies the data out before the file closes.
    with h5py.File(path, 'r') as f:
        return f[name][()]

In [None]:
from sklearn.metrics import confusion_matrix

def get_confusion_matrix(y_true, y_pred, normalise=False, labels=None):
    """Compute a confusion matrix, optionally row-normalised.

    Rows are true classes and columns are predicted classes (scikit-learn's
    convention).

    Parameters
    ----------
    y_true, y_pred : array-like of shape (n_samples,)
        Ground-truth and predicted class ids.
    normalise : bool, default False
        If True, divide each row by its total so entries are per-true-class
        fractions. Rows with no samples are left as 0 instead of NaN.
    labels : array-like, optional
        Class ids to index the matrix by, forwarded to scikit-learn. Pass e.g.
        ``range(len(CLASS_LABELS))`` to get a fixed-size matrix even when some
        class never appears in ``y_true`` or ``y_pred``.

    Returns
    -------
    numpy.ndarray of shape (n_classes, n_classes)
        Integer counts, or float64 fractions when ``normalise`` is True.
    """
    conf_mat = confusion_matrix(y_true, y_pred, labels=labels)
    if normalise:
        row_totals = conf_mat.sum(axis=1, keepdims=True)
        # Guard 0/0 for classes absent from y_true: those rows stay at 0.
        conf_mat = np.divide(
            conf_mat,
            row_totals,
            out=np.zeros(conf_mat.shape, dtype=float),
            where=row_totals != 0,
        )
    return conf_mat

In [None]:
def get_error_rates(confusion_matrix):
    """Per-class one-vs-rest TP / FP / FN / TN counts from a confusion matrix.

    Parameters
    ----------
    confusion_matrix : array-like of shape (n_classes, n_classes)
        Square matrix with true classes on rows and predicted classes on columns.

    Returns
    -------
    tp, fp, fn, tn : numpy.ndarray, each of shape (n_classes,)
        Element i holds the counts for class i treated as the positive class.
    """
    # Use the argument. The original loop read the global `conf_mat`, and its
    # builtin sum() over a 2-D array returned column sums rather than a scalar.
    cm = np.asarray(confusion_matrix)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp  # predicted as class i, actually another class
    fn = cm.sum(axis=1) - tp  # actually class i, predicted as another class
    # Every cell outside row i and column i is a true negative for class i.
    tn = cm.sum() - (tp + fp + fn)
    return tp, fp, fn, tn

In [None]:
def get_precision_recall(tp, fp, fn):
    """Per-class precision and recall from TP / FP / FN counts.

    Parameters
    ----------
    tp, fp, fn : array-like of numbers, same shape
        Per-class true-positive, false-positive and false-negative counts.

    Returns
    -------
    precision, recall : numpy.ndarray of float
        ``tp / (tp + fp)`` and ``tp / (tp + fn)``. Where a denominator is 0
        (e.g. a class that is never predicted), the result is 0 instead of
        NaN or a ZeroDivisionError; this is the value scikit-learn reports
        with ``zero_division=0``.
    """
    tp = np.asarray(tp, dtype=float)
    fp = np.asarray(fp, dtype=float)
    fn = np.asarray(fn, dtype=float)

    def _safe_ratio(num, den):
        # Elementwise num / den, with 0 wherever den == 0.
        return np.divide(num, den, out=np.zeros_like(num), where=den != 0)

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    return precision, recall

In [None]:
# Arrays stored in the metrics file for this evaluation split.
y_true = load_data(name='y_true')              # ground-truth labels
y_pred = load_data(name='y_pred')              # predicted labels
y_pred_score = load_data(name='y_pred_score')  # NOTE(review): not used in the visible cells

In [None]:
# Number of ground-truth samples per class id.
label_true = [0] * len(CLASS_LABELS)
for class_id in y_true:
    label_true[class_id] = label_true[class_id] + 1

In [None]:
# Number of predicted samples per class id.
label_pred = [0] * len(CLASS_LABELS)
for class_id in y_pred:
    label_pred[class_id] = label_pred[class_id] + 1

In [None]:
# scikit-learn's confusion_matrix expects 1-D label vectors. The previous
# `np.reshape(104, 1)` reshaped the literal 104, not the data, and replaced the
# loaded labels with array([104]). Flatten the actual arrays instead; np.ravel is
# a no-op on data that is already 1-D, so this cell is safe to re-run.
y_true = np.ravel(y_true)
y_pred = np.ravel(y_pred)
conf_mat = get_confusion_matrix(y_true, y_pred)
# Samples per true class (row sums); compare with label_true.
np.sum(conf_mat, axis=1)

In [None]:
# Row-normalised view: the fraction of each true class that lands in each predicted class.
np.true_divide(conf_mat.astype(float), conf_mat.sum(axis=1, keepdims=True))

In [None]:
# Per-class one-vs-rest counts, followed by precision / recall.
tp, fp, fn, tn = get_error_rates(conf_mat)
for caption, counts in (('TP: ', tp), ('FP: ', fp), ('FN: ', fn), ('TN: ', tn)):
    print(caption, counts)

precision, recall = get_precision_recall(tp, fp, fn)
for caption, scores in (('\nPrecision: ', precision), ('Recall: ', recall)):
    print(caption, scores)