In [None]:
import os
import re
from glob import glob
from ast import literal_eval

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from IPython.core.display import display, HTML

from sklearn.metrics import roc_auc_score

In [None]:

pd.set_option('display.max_rows', 5000)
pd.set_option('display.max_columns', 500)
pd.set_option('display.width', 1000)

plt.style.use('dark_background')

display(HTML("<style>"
    + "#notebook { padding-top:0px; } " ""
    + ".container { width:100%; } "
    + ".end_space { min-height:0px; } "
    + "</style>"))

In [None]:
# Root directory that is searched recursively for experiment logs.
# Set the CANOMALY_LOGS_PATH environment variable to point the notebook at another
# location without editing this cell; the fallback is the original machine-specific path.
logs_path = os.environ.get(
    'CANOMALY_LOGS_PATH',
    'C:\\Users\\emace\\AImageLab\\SRV-Continual\\results\\canomaly\\logs',
)
# Extension of the log files to load.
logs_ext = '.pyd'

In [None]:
def delete_lines(content: str, filenames: list[str]):
    """Remove from each file every line that contains ``content``.

    Every file is read in full, the lines containing ``content`` as a substring are
    dropped (each one is reported on stdout) and the file is rewritten in place only
    when at least one line was removed.

    Args:
        content: Substring identifying the lines to delete. Must be non-empty.
        filenames: Paths of the text files to process.

    Raises:
        ValueError: If ``content`` is empty. An empty string is a substring of every
            line, so it would silently erase the content of all the files.
    """
    if not content:
        raise ValueError('content must be a non-empty string')

    for filename in filenames:
        print(f'In file {filename}:')
        with open(filename, 'r') as f:
            lines = f.readlines()
        end_lines = []
        for i, line in enumerate(lines):
            if content in line:
                print(f'Deleted line {i}: {line}')
            else:
                end_lines.append(line)
        if len(lines) == len(end_lines):
            print('No line to delete')
        else:
            # Rewrite only when something changed, leaving untouched files as they are.
            with open(filename, 'w') as f:
                f.writelines(end_lines)

In [None]:
# delete_lines('443c988d-d031-4168-951f-e60ac3df1a76', glob(logs_path + '/*-fmnist*/*-2*/*' + logs_ext, recursive=True))
# delete_lines('443c988d-d031-4168-951f-e60ac3df1a76', glob(logs_path + '/*-mnist*/*-2*/*' + logs_ext, recursive=True))

In [None]:
# Load every experiment found under `logs_path`.
# Each log line is expected to be the Python literal of one experiment dict;
# `exp_list` keeps them in reading order, `exp_dict` indexes them by their 'id'
# (an experiment read later replaces an earlier one with the same id).
exp_dict = {}
exp_list = []
# Match `nan` only as a standalone token: a plain str.replace would also rewrite
# any string value that merely contains the letters "nan".
nan_token = re.compile(r'\bnan\b')
for log_file in glob(logs_path + '/**/*' + logs_ext, recursive=True):
    print(log_file)
    with open(log_file, 'r') as f:
        exps = []
        for i, line in enumerate(f):
            try:
                # `nan` is not a valid literal for literal_eval, so it is read as None.
                exps.append(literal_eval(nan_token.sub('None', line)))
            except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
                # Only the errors literal_eval raises on malformed input are tolerated;
                # anything else (e.g. KeyboardInterrupt) now propagates.
                # NOTE(review): `exps[:-1]` prints every previously parsed experiment
                # of this file except the last one; `exps[-1:]` may have been intended.
                print(f'Unparsed line {i}:\n\t{exps[:-1]}\n-->\t{line}')
    exp_list.extend(exps)
    exp_dict.update({exp['id']: exp for exp in exps})

In [None]:
def print_exp_info(exp: dict):
    """Print the experiment dict without its 'logs', 'results' and 'knowledge' entries."""
    hidden_keys = ('logs', 'results', 'knowledge')
    info = {}
    for key, value in exp.items():
        if key in hidden_keys:
            continue
        info[key] = value
    print(info)

In [None]:
# Usage example: show_exp_images(exp_list[0], show_origins=True)
def show_exp_images(exp: dict, show_origins=False):
    """Show, for every task, a 2x5 grid built from the first ten logged images.

    Each cell is a 28x28 RGB picture titled with the sample label. The reconstruction
    (clipped to [0, 1]) fills the red and blue channels; when ``show_origins`` is true
    the original image fills the green channel, otherwise that channel stays at zero.

    Args:
        exp: Experiment dict; reads ``exp['results'][task]['images']`` and
            ``exp['knowledge'][task]``.
        show_origins: Whether to overlay the original image in the green channel.
    """
    n_rows, n_cols = 2, 5
    for task, task_results in exp['results'].items():
        samples = task_results['images']
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(15, 8))
        fig.suptitle(f'TASK {task} {exp["knowledge"][task]}', fontsize=30)
        # Row-major walk over the grid, so idx equals row * 5 + column.
        for idx, cell in enumerate(axs.flat):
            sample = samples[idx]
            cell.set_title(sample['label'])
            orig = np.array(sample['original'][0])
            recon = np.array(sample['reconstruction'][0]).clip(0, 1)
            image = np.zeros((28, 28, 3), dtype=float)
            if show_origins:
                image[..., 1] = orig
            image[..., 0] = recon
            image[..., 2] = recon
            cell.imshow(image)
        plt.show()

In [None]:
def compute_weighted_auc(anomalies: np.ndarray, scores: np.ndarray):
    """Compute the ROC AUC with per-sample weights that balance the two classes.

    Normal samples (``anomalies == 0``) are weighted by the fraction of anomalies and
    anomalous samples (``anomalies == 1``) by the fraction of normals, so both classes
    carry the same total weight.

    Args:
        anomalies: Binary array, 1 for anomalous samples and 0 for normal ones.
        scores: Anomaly scores, one per sample.

    Returns:
        The value returned by ``roc_auc_score`` with those sample weights.
    """
    n_samples = len(anomalies)
    n_anomalies = anomalies.sum().item()
    n_normals = n_samples - n_anomalies
    # Force a float dtype: with integer scores, zeros_like would create an integer
    # array and the fractional weights below would be truncated to 0.
    weights = np.zeros_like(scores, dtype=float)
    weights[anomalies == 0] = n_anomalies / n_samples
    weights[anomalies == 1] = n_normals / n_samples
    return roc_auc_score(anomalies, scores, sample_weight=weights)


def compute_all_aucs(anomalies, scores, rng=None):
    """Compute the plain, the class-weighted and the class-balanced ROC AUC.

    The balanced AUC is computed on all the samples of the less frequent class plus
    an equally sized random subset (without replacement) of the more frequent class.

    Args:
        anomalies: Binary array, 1 for anomalous samples and 0 for normal ones.
        scores: Anomaly scores, one per sample.
        rng: Optional random generator exposing ``choice`` (e.g. the result of
            ``np.random.default_rng(seed)``) used for the subsampling, which makes
            the balanced AUC reproducible. When omitted, the global ``np.random``
            state is used as before.

    Returns:
        Tuple ``(total_auc, weighted_auc, balanced_auc)``.
    """
    total_auc = roc_auc_score(anomalies, scores)
    weighted_auc = compute_weighted_auc(anomalies, scores)

    n_anomalies = anomalies.sum().item()
    n_normals = len(anomalies) - n_anomalies

    # On a tie the label 0 is treated as the minority class (nothing is discarded then).
    minority_label = 1 if n_anomalies < n_normals else 0
    majority_label = 1 - minority_label
    n_per_class = min(n_anomalies, n_normals)

    minority_idxs = np.where(anomalies == minority_label)[0]
    majority_idxs = np.where(anomalies == majority_label)[0]
    choice = np.random.choice if rng is None else rng.choice
    sampled_majority_idxs = choice(majority_idxs, size=n_per_class, replace=False)

    idxs = np.concatenate((minority_idxs, sampled_majority_idxs))
    balanced_auc = roc_auc_score(anomalies[idxs], scores[idxs])

    return total_auc, weighted_auc, balanced_auc

In [None]:
def compute_exp_metrics(exp: dict, per_task=True):
    """Build the table of weighted AUCs of an experiment, one row per task.

    Labels are accumulated task after task into the known set; a sample counts as an
    anomaly when its target is not in that set. The 'total' column holds the weighted
    AUC over all the samples of the task. For tasks after the first, when ``per_task``
    is true, one AUC is also stored for each task seen so far, computed after dropping
    the samples whose target is a known label that does not belong to that task.
    Otherwise the overall AUC is stored in the column of the current task.

    Args:
        exp: Experiment dict; reads ``exp['results'][task]['targets']``,
            ``exp['results'][task]['rec_errs']`` and ``exp['knowledge'][task]``.
        per_task: Whether to compute the per-task breakdown.

    Returns:
        Tuple ``(final_auc, average_auc, metrics)``: the 'total' value of the last
        task, the mean of the 'total' column and the full DataFrame.
    """
    results = exp['results']
    task_labels = exp['knowledge']
    label_columns = [str(labels) for labels in task_labels.values()]
    metrics = pd.DataFrame(index=results, columns=label_columns + ['total'], dtype='float')

    knowledge = []
    for t, task in enumerate(results):
        knowledge.extend(task_labels[task])
        targets = np.array(results[task]['targets'])
        scores = np.array(results[task]['rec_errs'])
        anomalies = (~np.isin(targets, knowledge)).astype(int)
        auc = compute_weighted_auc(anomalies, scores)
        metrics.loc[task, 'total'] = auc

        if t == 0 or not per_task:
            # No breakdown requested or possible: reuse the overall AUC.
            metrics.loc[task, str(task_labels[task])] = auc
            continue

        known_labels = np.array(knowledge)
        # Only the tasks seen so far, i.e. the first t + 1 ones.
        for in_task in list(results)[:t + 1]:
            # Known labels that belong to tasks other than `in_task`.
            other_known = ~np.isin(known_labels, task_labels[in_task])
            excluded_labels = known_labels[other_known].tolist()
            keep = ~np.isin(targets, excluded_labels)
            in_anomalies = (~np.isin(targets[keep], knowledge)).astype(int)
            in_auc = compute_weighted_auc(in_anomalies, scores[keep])
            metrics.loc[task, str(task_labels[in_task])] = in_auc

    # `task` is still bound to the last task of the loop.
    final_auc = metrics.loc[task, 'total']
    average_auc = metrics.loc[:, 'total'].mean()
    return final_auc, average_auc, metrics

def show_aucs_per_task(task_aucs: pd.DataFrame):
    """Draw the AUC table as an annotated heatmap (tasks on y, classes on x)."""
    fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10))
    sns.heatmap(task_aucs, ax=ax, annot=True, cbar=False, cmap='Reds')
    # `ax` is the only axes of the figure, so it is also the current one.
    ax.set_ylabel('Task')
    ax.set_xlabel('Class')
    ax.set_title('Auc per task and class')


In [None]:
def reconstruction_confusion_matrix(exp: dict):
    """Compute the mean reconstruction error of every class after every task.

    The columns are the distinct targets found in the first entry of
    ``exp['results']``; the rows are named after each ``exp['knowledge']`` key
    followed by its label list, and are paired positionally with the tasks in
    ``exp['results']``.

    Args:
        exp: Experiment dict; reads ``exp['knowledge']`` and, for every task,
            ``exp['results'][task]['targets']`` and ``['rec_errs']``.

    Returns:
        Float DataFrame holding, for each row and class, the mean of the
        reconstruction errors of the samples with that target.
    """
    results = exp['results']
    first_task_results = next(iter(results.values()))
    labels = np.unique(np.array(first_task_results['targets'])).tolist()

    row_names = []
    for key in exp['knowledge']:
        row_names.append(key + str(exp['knowledge'][key]))

    matrix = pd.DataFrame(index=row_names, columns=labels, dtype='float')
    for row_name, task_results in zip(row_names, results.values()):
        scores = np.array(task_results['rec_errs'])
        targets = np.array(task_results['targets'])
        for label in labels:
            label_scores = scores[targets == label]
            matrix.loc[row_name, label] = label_scores.mean()

    return matrix

def show_conf_matrix(cmatrix: pd.DataFrame):
    """Draw the reconstruction-error table as an annotated heatmap (tasks on y, classes on x)."""
    fig, ax = plt.subplots(figsize=(10, 10))
    sns.heatmap(data=cmatrix, ax=ax, annot=True, cbar=False, cmap='Reds')
    # `ax` is the only axes of the figure, so it is also the current one.
    ax.set_ylabel('Task')
    ax.set_xlabel('Class')
    ax.set_title('Reconstruction error per task and class')

In [None]:
## print metrics of experiments
def exp_disclosure(exp: dict, info=True, images=False, origins=False, aucs=False, cmatrix=False):
    """Show the selected views of an experiment, in a fixed order.

    Args:
        exp: Experiment dict to inspect.
        info: Print the experiment info via ``print_exp_info``.
        images: Show the image grids via ``show_exp_images``.
        origins: Forwarded to ``show_exp_images`` as its second argument.
        aucs: Compute the metrics, print the final and average AUC and plot the
            AUC table.
        cmatrix: Compute and plot the reconstruction-error matrix.
    """
    if info:
        print_exp_info(exp)

    if images:
        show_exp_images(exp, origins)

    if aucs:
        exp_metrics = compute_exp_metrics(exp)
        final_auc, average_auc = exp_metrics[0], exp_metrics[1]
        print(f'final {final_auc} average {average_auc}')
        show_aucs_per_task(exp_metrics[2])

    if cmatrix:
        show_conf_matrix(reconstruction_confusion_matrix(exp))

In [None]:
# Show info and reconstruction grids (without original overlay) of two experiments.
for exp_id in (
    'd7d4d5a0-c112-4c91-8b5d-6b04d396aeea',
    'ce93a2f9-78f2-4586-a9cc-4263c79a803f',
):
    exp_disclosure(exp_dict[exp_id], images=True, origins=False, cmatrix=False)


In [None]:
# AE: joint learns a bit of everything
# exp_disclosure(exp_dict['a8b7df79-7031-4f95-ba9c-aaae316ced5e'], images=True)
for exp_id in (
    '3dfef3e0-f68c-4194-9caa-ded9be3005e8',
    'e6ba6f77-0d03-48bd-a342-5a2f9f54a240',
):
    exp_disclosure(exp_dict[exp_id], images=True, origins=False, cmatrix=False)

In [None]:
# Show info and reconstruction grids of two more experiments.
for exp_id in (
    '6ca895ec-4878-4a60-b125-3264d816ef2a',
    'e3c2cf61-59ca-4df8-b1d5-204a80579857',
):
    exp_disclosure(exp_dict[exp_id], images=True)
