# In [None]:
import copy
import math

import matplotlib.pyplot as plt
import numpy as np

import cogni_scan.src.dbutil as dbutil
import cogni_scan.src.modeler.model as model

dbutil.SimpleSQL.setDatabaseName("scans")

# Group every model under the dataset it uses. Only the first 8 characters
# of the dataset id are used as the key.
# NOTE(review): presumably a short prefix of a longer id; distinct datasets
# sharing that prefix would be merged — confirm that cannot happen.
models_per_ds = {}
for m in model.getModels():
    dsid = m.getDatasetID()[:8]
    models_per_ds.setdefault(dsid, []).append(m)


class PredictionsSummary:
    """Tally, per scan, how many predictions were correct or wrong.

    Predictions are bucketed by the scan's ground-truth label: ``'HH'``
    scans feed the healthy tallies and ``'HD'`` scans feed the demented
    tallies. Predictions for any other label are ignored.
    """

    def __init__(self):
        # Each dict maps scan_id -> number of predictions in that bucket.
        self._healthy_correct = {}
        self._healthy_wrong = {}
        self._demented_correct = {}
        self._demented_wrong = {}

    @staticmethod
    def _merge(correct_counts, wrong_counts):
        """Combine two scan_id->count maps into sorted (id, (correct, wrong)) pairs."""
        scans = {}
        for scan_id, count in correct_counts.items():
            scans.setdefault(scan_id, [0, 0])[0] += count
        for scan_id, count in wrong_counts.items():
            scans.setdefault(scan_id, [0, 0])[1] += count
        merged = [(scan_id, (correct, wrong))
                  for scan_id, (correct, wrong) in scans.items()]
        # sorted() is stable, so scans with equal correct counts keep their
        # insertion order.
        return sorted(merged, key=lambda item: item[1][0], reverse=True)

    def getDemented(self):
        """Return ``[(scan_id, (correct, wrong)), ...]`` for ``'HD'`` scans.

        The list is sorted by the number of correct predictions, descending.
        """
        return self._merge(self._demented_correct, self._demented_wrong)

    def getHealthy(self):
        """Return ``[(scan_id, (correct, wrong)), ...]`` for ``'HH'`` scans.

        The list is sorted by the number of correct predictions, descending.
        """
        return self._merge(self._healthy_correct, self._healthy_wrong)

    def add(self, scan_info, cutoff=0.5):
        """Record one prediction for one scan.

        Args:
            scan_info: mapping with keys ``'pred'`` (the model score),
                ``'label'`` (``'HH'`` or ``'HD'``; anything else is ignored)
                and ``'scan_id'``.
            cutoff: decision threshold. A score ``<= cutoff`` is read as a
                healthy prediction and a score ``> cutoff`` as a demented one,
                so a score exactly at the cutoff is never counted as correct
                for both labels.
        """
        pred = scan_info['pred']
        label = scan_info['label']
        scan_id = scan_info['scan_id']
        # Both comparisons are False for a NaN score, so NaN is counted as a
        # wrong prediction whatever the label.
        if label == 'HH':
            correct = pred <= cutoff
            bucket = self._healthy_correct if correct else self._healthy_wrong
        elif label == 'HD':
            correct = pred > cutoff
            bucket = (self._demented_correct if correct
                      else self._demented_wrong)
        else:
            return
        bucket[scan_id] = bucket.get(scan_id, 0) + 1


def classifyPredictions(ps, m, cutoff=0.5):
    """Add every testing-set prediction of model ``m`` to summary ``ps``.

    ``cutoff`` is forwarded to ``PredictionsSummary.add`` as the decision
    threshold.
    """
    for scan_info in m.getTestingPredictions():
        ps.add(scan_info, cutoff=cutoff)


# Summarise the models of one dataset and print both per-scan tallies.
# The trailing `break` means only the first dataset in models_per_ds is
# inspected.
for dsid, models in models_per_ds.items():
    print("----------------------", dsid, sep="\n")
    ps = PredictionsSummary()
    for m in models:
        classifyPredictions(ps, m)
    print(ps.getDemented(), "----------------------", ps.getHealthy(),
          sep="\n")
    break


def plotCorrectWrong(x, title='Correct/Wrong Predictions in Testing Dataset',
                     width=0.4, figsize=(14, 12)):
    """Draw a stacked bar chart of correct vs. wrong predictions per scan.

    Args:
        x: sequence of ``(scan_id, (correct, wrong))`` tuples, i.e. the
            format returned by ``PredictionsSummary.getDemented()`` and
            ``getHealthy()``.
        title: axes title.
        width: bar width.
        figsize: matplotlib figure size (width, height) in inches.

    Returns:
        The ``(fig, ax)`` pair, so callers can customise or save the plot.
    """
    # Scan ids are identifiers, not quantities. Passing them as strings makes
    # matplotlib lay the bars out categorically, with one evenly spaced slot
    # per scan. As raw ints they would be placed at numeric x positions
    # (e.g. 14 and 6961), where 0.4-wide bars are effectively invisible.
    labels = [str(scan_id) for scan_id, _ in x]
    counter = {
        'Correct': np.array([counts[0] for _, counts in x]),
        'Wrong': np.array([counts[1] for _, counts in x]),
    }
    print(counter)

    fig, ax = plt.subplots(figsize=figsize)
    # Running top of each stack: the 'Wrong' bars sit on the 'Correct' bars.
    bottom = np.zeros(len(labels))
    for key, count in counter.items():
        p = ax.bar(labels, count, width, label=key, bottom=bottom)
        bottom += count
        ax.bar_label(p, label_type='center')
    ax.set_title(title)
    ax.legend()
    return fig, ax


# Hard-coded (scan_id, (correct, wrong)) pairs; every scan has 7 predictions
# in total.
# NOTE(review): looks like a pasted printout of ps.getDemented() from the
# loop above rather than live data — confirm before reusing.
x = [
    (4363, (7, 0)), (2230, (7, 0)), (2145, (7, 0)),
    (3031, (7, 0)), (5773, (7, 0)), (4258, (7, 0)),
    (1374, (6, 1)), (14, (6, 1)), (2564, (5, 2)),
    (6360, (5, 2)), (5120, (5, 2)), (5271, (4, 3)),
    (934, (3, 4)), (5982, (2, 5)), (2047, (2, 5)),
    (1260, (1, 6)), (5868, (1, 6)), (4400, (1, 6)),
    (6961, (1, 6)), (1741, (0, 7)), (4898, (0, 7)),
]
plotCorrectWrong(x)
