In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data

import albumentations as A
import cv2
import numpy as np
import skimage as ski

import matplotlib.pyplot as plt
import os
import copy

from tqdm import tqdm
from IPython.display import clear_output

import psutil
import pynvml
import sys

import struct

import sklearn.metrics as metrics

import gc

# Make the project's dataloaders directory importable (for `test_dataset`).
# Set SHARKCNN_DATALOADERS_DIR to point elsewhere; the default keeps the
# original author's local checkout path so existing setups are unaffected.
sys.path.append(os.environ.get(
    'SHARKCNN_DATALOADERS_DIR',
    '/home/meribejayson/Desktop/Projects/SharkCNN/training_models/dataloaders/',
))


from test_dataset import SharkDatasetTest as SharkDataset

In [2]:
# --- Notebook configuration ---------------------------------------------------

# Packed (ann_pred, lr_pred, label) records, loaded via read_from_binary_file.
output_file_path = 'preds_labels.dat'

# Initialise NVML and take a handle to GPU 0.
# NOTE(review): `handle` and the sizing constants below are not used by any
# later cell visible here — confirm before removing.
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

# Frame size in pixels.
image_width, image_height = 1920, 1080

# Budget: target_iters iterations of images_per_iter images each.
target_iters, images_per_iter = 300, 5

# Total number of pixels covered by that budget.
target_sample = (image_width * image_height) * (images_per_iter * target_iters)

In [3]:
def read_from_binary_file(file_path, subsample_divisor=8):
    """Memory-map a packed prediction file and return its leading records.

    The file is a flat array of packed records, each holding an ANN score
    (float32), an LR score (float32) and a label (uint32) — 12 bytes per
    record, no padding.

    Only the first ``total_records // subsample_divisor`` records are read.
    The default of 8 preserves this notebook's original behaviour (1/8 of the
    file; the old local was misleadingly named ``quarter_records``). Pass
    ``subsample_divisor=1`` to read every record.

    Args:
        file_path: Path to the binary record file.
        subsample_divisor: Positive integer; the number of leading records
            read is the total record count floor-divided by this value.

    Returns:
        Tuple ``(ann_preds, lr_preds, labels)``. ``ann_preds`` and
        ``lr_preds`` are float32 views backed by a read-only memory map;
        ``labels`` is an in-memory copy in NumPy's default integer dtype.
        All three are empty when no whole record is selected.

    Raises:
        ValueError: if ``subsample_divisor`` is less than 1.
    """
    if subsample_divisor < 1:
        raise ValueError(f"subsample_divisor must be >= 1, got {subsample_divisor}")

    record_dtype = np.dtype([('ann_pred', np.float32), ('lr_pred', np.float32), ('label', np.uint32)])

    # Trailing bytes that do not form a whole record are dropped by the floor division.
    total_records = os.path.getsize(file_path) // record_dtype.itemsize
    n_records = total_records // subsample_divisor

    if n_records == 0:
        # Avoid np.memmap for a zero-length selection (mmap raises on empty files).
        return (
            np.empty(0, dtype=np.float32),
            np.empty(0, dtype=np.float32),
            np.empty(0, dtype=int),
        )

    # Named `records`, not `data`, so it does not shadow the notebook-level
    # `torch.utils.data as data` import inside this function.
    records = np.memmap(file_path, dtype=record_dtype, mode='r', shape=(n_records,))

    ann_preds = records['ann_pred']
    lr_preds = records['lr_pred']
    labels = records['label'].astype(int)

    return ann_preds, lr_preds, labels

In [4]:
# Load ANN scores, LR scores and ground-truth labels (a leading subset of the
# records, per read_from_binary_file's default divisor).
(
    preds_ann,
    preds_lr,
    all_labels,
) = read_from_binary_file(output_file_path)

In [5]:
# Precision and recall of the ANN at every distinct score threshold.
# precision/recall carry one more element than the thresholds: the final
# point (precision=1, recall=0) has no threshold.
(
    precision_ann,
    recall_ann,
    pr_thresh_ann,
) = metrics.precision_recall_curve(all_labels, preds_ann)

In [None]:
# Plot the ANN precision-recall curve on an explicit figure/axes.
fig, ax = plt.subplots()
pr_plot_ann = metrics.PrecisionRecallDisplay(precision=precision_ann, recall=recall_ann)
pr_plot_ann.plot(ax=ax)
ax.set_title('Precision-Recall ANN')
plt.show()

# Release this figure explicitly. plt.clf() after plt.show() acts on pyplot's
# *current* figure (typically a new, empty one once show() has closed this
# one), so it does not free this figure and can emit an empty-figure repr.
plt.close(fig)
# The display object keeps references to precision_ann/recall_ann (and to the
# figure); drop it so a later `del` of those arrays actually frees memory.
del pr_plot_ann, fig, ax

In [None]:
# F1 at every point of the ANN precision-recall curve.
# Where precision and recall are both 0 the plain formula gives 0/0 = NaN (with
# a RuntimeWarning); F1 is defined as 0 there instead.
f1_denominator_ann = precision_ann + recall_ann
f1_score_ann = np.zeros_like(f1_denominator_ann)
np.divide(
    2 * precision_ann * recall_ann,
    f1_denominator_ann,
    out=f1_score_ann,
    where=f1_denominator_ann > 0,
)
del f1_denominator_ann

In [None]:
# F1 as a function of the ANN confidence threshold.
# f1_score_ann has one more entry than pr_thresh_ann (the final
# precision/recall point has no threshold), so its last value is dropped.
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(pr_thresh_ann, f1_score_ann[:-1], color='blue')

ax.set_xlabel('Confidence')
ax.set_ylabel('F1-Score')
ax.set_title('F1-Score ANN')
plt.show()

# Close this figure explicitly; plt.clf() here would only clear pyplot's
# current figure (typically a new, empty one after show()).
plt.close(fig)
del fig, ax

In [None]:
# Free the ANN F1 array; the trailing `;` keeps gc.collect()'s return value
# (an object count) out of the cell output.
del f1_score_ann

gc.collect();

In [6]:
# Threshold-restricted precision summaries for the ANN.
# NOTE: despite the MAP50 / MAP50-95 names these are not detection-style mean
# average precision (no IoU, no area under the PR curve). Each value is the
# mean of the precision-curve points whose score threshold exceeds a cut-off;
# MAP50-95 averages that over the cut-offs in `confs`.
confs = [0.5, 0.55, 0.60, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]

# precision_recall_curve returns one more precision value than thresholds, so
# drop the last precision point to align the two arrays index-for-index.
precision_at_thresh_ann = precision_ann[:-1]

mean_precision_by_conf_ann = {}
for conf in confs:
    precision_above_conf = precision_at_thresh_ann[pr_thresh_ann > conf]
    # With no threshold above `conf`, np.mean([]) would give NaN (plus a
    # RuntimeWarning) and poison MAP50-95; report 0.0 for that cut-off instead.
    mean_precision_by_conf_ann[conf] = (
        float(precision_above_conf.mean()) if precision_above_conf.size else 0.0
    )

map_50_ann = mean_precision_by_conf_ann[0.5]
map_50_95_ann = sum(mean_precision_by_conf_ann.values()) / len(confs)

# The boolean-indexed selection is a full copy; release it.
del precision_at_thresh_ann, precision_above_conf

print("ANN:")
print(f"MAP50: {map_50_ann} MAP50-95: {map_50_95_ann}\n")

ANN:
MAP50: 0.029351181581156672 MAP50-95: 0.022157586795707437



In [None]:
# Free the ANN precision-recall arrays; the trailing `;` keeps gc.collect()'s
# return value out of the cell output.
del precision_ann, recall_ann, pr_thresh_ann

gc.collect();

In [None]:
# ROC curve (false/true positive rates per threshold) for the ANN scores.
fpr_ann, tpr_ann, tp_thresh_ann = metrics.roc_curve(
    y_true=all_labels,
    y_score=preds_ann,
)

In [None]:
# Plot the ANN ROC curve on an explicit figure/axes.
fig, ax = plt.subplots()
roc_plot_ann = metrics.RocCurveDisplay(fpr=fpr_ann, tpr=tpr_ann)
roc_plot_ann.plot(ax=ax)
ax.set_title('ROC ANN')
plt.show()

# Release this figure explicitly. plt.clf() after plt.show() acts on pyplot's
# *current* figure (typically a new, empty one once show() has closed this
# one), so it does not free this figure and can emit an empty-figure repr.
plt.close(fig)
# The display object keeps references to fpr_ann/tpr_ann (and to the figure);
# drop it so a later `del` of those arrays actually frees memory.
del roc_plot_ann, fig, ax

In [None]:
# ROC AUC for the ANN. For binary labels, roc_auc_score re-runs roc_curve over
# every sample (re-sorting all scores) and integrates it with metrics.auc.
# Integrating the curve already computed above gives the same value without
# that second full pass.
roc_auc_score_ann = metrics.auc(fpr_ann, tpr_ann)

print(f"The area under the ANN ROC curve is {roc_auc_score_ann}")

In [None]:
# Free the ANN scores and ROC arrays; the trailing `;` keeps gc.collect()'s
# return value out of the cell output.
del preds_ann
del fpr_ann, tpr_ann, tp_thresh_ann

gc.collect();

In [None]:
# Precision and recall of the LR model at every distinct score threshold.
# precision/recall carry one more element than the thresholds: the final
# point (precision=1, recall=0) has no threshold.
(
    precision_lr,
    recall_lr,
    pr_thresh_lr,
) = metrics.precision_recall_curve(all_labels, preds_lr)

In [None]:
# Plot the LR precision-recall curve on an explicit figure/axes.
fig, ax = plt.subplots()
pr_plot_lr = metrics.PrecisionRecallDisplay(precision=precision_lr, recall=recall_lr)
pr_plot_lr.plot(ax=ax)
ax.set_title('Precision-Recall LR')
plt.show()

# Release this figure explicitly. plt.clf() after plt.show() acts on pyplot's
# *current* figure (typically a new, empty one once show() has closed this
# one), so it does not free this figure and can emit an empty-figure repr.
plt.close(fig)
# The display object keeps references to precision_lr/recall_lr (and to the
# figure); drop it so a later `del` of those arrays actually frees memory.
del pr_plot_lr, fig, ax

In [None]:
# F1 at every point of the LR precision-recall curve.
# Where precision and recall are both 0 the plain formula gives 0/0 = NaN (with
# a RuntimeWarning); F1 is defined as 0 there instead.
f1_denominator_lr = precision_lr + recall_lr
f1_score_lr = np.zeros_like(f1_denominator_lr)
np.divide(
    2 * precision_lr * recall_lr,
    f1_denominator_lr,
    out=f1_score_lr,
    where=f1_denominator_lr > 0,
)
del f1_denominator_lr

In [None]:
# F1 as a function of the LR confidence threshold.
# f1_score_lr has one more entry than pr_thresh_lr (the final precision/recall
# point has no threshold), so its last value is dropped.
fig, ax = plt.subplots(figsize=(12, 8))
ax.plot(pr_thresh_lr, f1_score_lr[:-1], color='blue')

ax.set_xlabel('Confidence')
ax.set_ylabel('F1-Score')
ax.set_title('F1-Score LR')
plt.show()

# Close this figure explicitly; plt.clf() here would only clear pyplot's
# current figure (typically a new, empty one after show()).
plt.close(fig)
del fig, ax

In [None]:
# Free the LR F1 array; the trailing `;` keeps gc.collect()'s return value
# (an object count) out of the cell output.
del f1_score_lr

gc.collect();

In [None]:
# Threshold-restricted precision summaries for the LR model.
# NOTE: despite the MAP50 / MAP50-95 names these are not detection-style mean
# average precision (no IoU, no area under the PR curve). Each value is the
# mean of the precision-curve points whose score threshold exceeds a cut-off;
# MAP50-95 averages that over the cut-offs in `confs`.
confs = [0.5, 0.55, 0.60, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95]

# precision_recall_curve returns one more precision value than thresholds, so
# drop the last precision point to align the two arrays index-for-index.
precision_at_thresh_lr = precision_lr[:-1]

mean_precision_by_conf_lr = {}
for conf in confs:
    precision_above_conf = precision_at_thresh_lr[pr_thresh_lr > conf]
    # With no threshold above `conf`, np.mean([]) would give NaN (plus a
    # RuntimeWarning) and poison MAP50-95; report 0.0 for that cut-off instead.
    mean_precision_by_conf_lr[conf] = (
        float(precision_above_conf.mean()) if precision_above_conf.size else 0.0
    )

map_50_lr = mean_precision_by_conf_lr[0.5]
map_50_95_lr = sum(mean_precision_by_conf_lr.values()) / len(confs)

# The boolean-indexed selection is a full copy; release it.
del precision_at_thresh_lr, precision_above_conf

print("LR:")
print(f"MAP50: {map_50_lr} MAP50-95: {map_50_95_lr}\n")

In [None]:
# Free the LR precision-recall arrays; the trailing `;` keeps gc.collect()'s
# return value out of the cell output.
del precision_lr, recall_lr, pr_thresh_lr

gc.collect();

In [None]:
# ROC curve (false/true positive rates per threshold) for the LR scores.
fpr_lr, tpr_lr, tp_thresh_lr = metrics.roc_curve(
    y_true=all_labels,
    y_score=preds_lr,
)

In [None]:
# Plot the LR ROC curve on an explicit figure/axes.
fig, ax = plt.subplots()
roc_plot_lr = metrics.RocCurveDisplay(fpr=fpr_lr, tpr=tpr_lr)
roc_plot_lr.plot(ax=ax)
ax.set_title('ROC LR')
plt.show()

# Release this figure explicitly. plt.clf() after plt.show() acts on pyplot's
# *current* figure (typically a new, empty one once show() has closed this
# one), so it does not free this figure and can emit an empty-figure repr.
plt.close(fig)
# The display object keeps references to fpr_lr/tpr_lr (and to the figure);
# drop it so those arrays can be freed once no longer needed.
del roc_plot_lr, fig, ax

In [None]:
# ROC AUC for the LR model. For binary labels, roc_auc_score re-runs roc_curve
# over every sample (re-sorting all scores) and integrates it with metrics.auc.
# Integrating the curve already computed above gives the same value without
# that second full pass.
roc_auc_score_lr = metrics.auc(fpr_lr, tpr_lr)

print(f"The area under the LR ROC curve is {roc_auc_score_lr}")