In [1]:
%load_ext lab_black
import argparse
import contextlib
import datetime
import io
import logging
import multiprocessing
import os
import random
import sys
from itertools import chain, combinations
from timeit import default_timer as timer

import altair as alt
import altair_viewer
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import peewee
from evolutionary_search import EvolutionaryAlgorithmSearchCV
from json_tricks import dumps, loads
from playhouse.shortcuts import model_to_dict
from scipy.stats import randint, uniform
from sklearn.datasets import load_iris
from tabulate import tabulate
from IPython.core.display import display, HTML

from active_learning.cluster_strategies import (
    DummyClusterStrategy,
    MostUncertainClusterStrategy,
    RandomClusterStrategy,
    RoundRobinClusterStrategy,
)
from active_learning.dataStorage import DataStorage
from active_learning.experiment_setup_lib import (
    ExperimentResult,
    classification_report_and_confusion_matrix,
    get_db,
    get_single_al_run_stats_row,
    get_single_al_run_stats_table_header,
    load_and_prepare_X_and_Y,
)
from active_learning.sampling_strategies import (
    BoundaryPairSampler,
    CommitteeSampler,
    RandomSampler,
    UncertaintySampler,
)

# Render Altair charts through altair_viewer ("vegascope" is an alternative).
# Note: visualise_top_n and compare_data further down switch the renderer to "html".
alt.renderers.enable("altair_viewer")

config = dict(
    datasets_path="../datasets",
    db="tunnel",  # forwarded to get_db as db_name_or_type
    param_list_id="best_global_score",
)

# NOTE(review): presumably binds the ExperimentResult model used below to this
# database — confirm in active_learning.experiment_setup_lib.get_db.
db = get_db(db_name_or_type=config["db"])



In [2]:
# Number of stored experiment runs per dataset.
# Equivalent SQL:
#   select count(*), dataset_name from experimentresult group by dataset_name;
runs_per_dataset = peewee.fn.COUNT(ExperimentResult.id_field).alias(
    "dataset_name_count"
)
results = ExperimentResult.select(
    ExperimentResult.dataset_name, runs_per_dataset
).group_by(ExperimentResult.dataset_name)

# Right-aligned, thousands-separated count followed by the dataset name.
for row in results:
    count = row.dataset_name_count
    print("{:>4,d} {}".format(count, row.dataset_name))

4,663 orange
4,642 sylva
4,669 hiva
4,680 ibn_sina
4,683 dwtc
4,626 zebra


In [3]:
#  SELECT param_list_id, avg(fit_score), stddev(fit_score), avg(global_score), stddev(global_score), avg(start_set_size) as sss, count(*) FROM experimentresult WHERE start_set_size = 1 GROUP BY param_list_id ORDER BY 7 DESC, 4 DESC LIMIT 30;
# Best parameter sets on dwtc, ranked by global_score_no_weak_acc (descending).
#
# Uses the `datetime` *module* imported in the first cell. Do not
# `from datetime import datetime` here: that rebinds the global name
# `datetime` to the class and breaks any later `datetime.<...>` usage.
# Per the original note, runs after this timestamp use no stopping criteria.
RUNS_SINCE = datetime.datetime(2020, 3, 24, 14, 0)

results = (
    ExperimentResult.select(
        ExperimentResult.param_list_id,
        ExperimentResult.fit_score,
        ExperimentResult.global_score_no_weak_acc,
        ExperimentResult.amount_of_user_asked_queries,
        ExperimentResult.classifier,
        ExperimentResult.dataset_name,
        ExperimentResult.test_fraction,
        ExperimentResult.sampling,
        ExperimentResult.cluster,
        ExperimentResult.nr_queries_per_iteration,
        ExperimentResult.with_uncertainty_recommendation,
        ExperimentResult.with_cluster_recommendation,
        ExperimentResult.uncertainty_recommendation_certainty_threshold,
        ExperimentResult.uncertainty_recommendation_ratio,
        ExperimentResult.cluster_recommendation_minimum_cluster_unity_size,
        ExperimentResult.cluster_recommendation_ratio_labeled_unlabeled,
        ExperimentResult.allow_recommendations_after_stop,
        ExperimentResult.stopping_criteria_uncertainty,
        ExperimentResult.stopping_criteria_acc,
        ExperimentResult.stopping_criteria_std,
        ExperimentResult.experiment_run_date,
    )
    .where(
        (ExperimentResult.amount_of_user_asked_queries < 500)
        & (ExperimentResult.dataset_name == "dwtc")
        & (ExperimentResult.experiment_run_date > RUNS_SINCE)
    )
    .order_by(ExperimentResult.global_score_no_weak_acc.desc())
    .limit(20)
)

# INTERESTING: even when the computation is not restricted to weak/no_weak,
# clusters are still used!

# One row per selected result. "id" is the 0-based rank; later cells pass it
# as `top_n` and look the row up via table[top_n].
table = [
    {"id": rank, **vars(result)["__data__"]} for rank, result in enumerate(results)
]

display(HTML(tabulate(table, headers="keys", tablefmt="html")))

id,param_list_id,fit_score,global_score_no_weak_acc,amount_of_user_asked_queries,classifier,dataset_name,test_fraction,sampling,cluster,nr_queries_per_iteration,with_uncertainty_recommendation,with_cluster_recommendation,uncertainty_recommendation_certainty_threshold,uncertainty_recommendation_ratio,cluster_recommendation_minimum_cluster_unity_size,cluster_recommendation_ratio_labeled_unlabeled,allow_recommendations_after_stop,stopping_criteria_uncertainty,stopping_criteria_acc,stopping_criteria_std,experiment_run_date
0,d153e81ce1ee57d35246ec2016def2ea,0.832164,0.678943,440,RF,dwtc,0.5,uncertainty_max_margin,dummy,10,True,False,0.7,0.001,0.6,0.61,True,1,1,1,2020-04-07 06:46:42.691024
1,b2c3481495339dd5680f25471ee9d033,0.831608,0.633469,361,RF,dwtc,0.5,uncertainty_max_margin,MostUncertain_max_margin,10,True,False,0.66,0.1,0.54,0.63,False,1,1,1,2020-04-09 20:44:41.870104
2,1aee853c5404ebd688bc2dff3feb04b5,0.814866,0.629504,430,RF,dwtc,0.5,uncertainty_lc,MostUncertain_max_margin,10,True,True,0.589256,0.1,0.564668,0.749682,False,1,1,1,2020-03-30 04:42:43.726968
3,78f6cb6f5fa82a7398956d1ea3195729,0.814186,0.623357,471,RF,dwtc,0.5,uncertainty_lc,MostUncertain_max_margin,10,True,False,0.65,0.01,0.76,0.68,True,1,1,1,2020-03-30 16:42:28.328699
4,a9a838b75896624a8abeb5b762918e77,0.818637,0.617734,424,RF,dwtc,0.5,uncertainty_max_margin,MostUncertain_max_margin,10,True,False,0.52,0.1,0.9,1.0,False,1,1,1,2020-04-06 05:45:48.001882
5,dd8703be1d25dfd26ffc4d3488080f7c,0.816888,0.617545,422,RF,dwtc,0.5,uncertainty_lc,dummy,10,True,False,0.79,0.0001,0.73,0.63,False,1,1,1,2020-04-07 23:20:20.636977
6,cc82393f124dc3c02fa6d6736952b969,0.838352,0.610387,320,RF,dwtc,0.5,uncertainty_max_margin,dummy,10,True,False,0.62,0.0001,0.96,0.51,True,1,1,1,2020-04-03 19:40:24.601820
7,1553cfb328fa74bc19244bb75fa486b4,0.807553,0.609532,450,RF,dwtc,0.5,uncertainty_max_margin,MostUncertain_lc,10,True,True,0.626507,0.0001,0.942284,0.506822,True,1,1,1,2020-03-29 19:43:48.450170
8,a98696e2d5df2c0162b98e9d5d32bc28,0.827349,0.606718,292,RF,dwtc,0.5,uncertainty_max_margin,dummy,10,True,False,0.7,0.01,0.94,0.53,True,1,1,1,2020-04-05 07:00:05.275220
9,ee1c1507cb5ad51ce3a7359b2a6340e9,0.812137,0.596607,279,RF,dwtc,0.5,uncertainty_lc,dummy,10,True,False,0.58,0.01,0.78,0.81,False,1,1,1,2020-04-05 17:18:33.392355


In [4]:
# SELECT id_field, param_list_id, dataset_path, start_set_size as sss, sampling, cluster, allow_recommendations_after_stop as SA, stopping_criteria_uncertainty as SCU, stopping_criteria_std as SCS, stopping_criteria_acc as SCA, amount_of_user_asked_queries as "#q", acc_test, fit_score, global_score_norm, thread_id, end_time from experimentresult where param_list_id='31858014d685a3f1ba3e4e32690ddfc3' order by end_time, fit_score desc, param_list_id;
# Cache: rank in `table` -> list of ExperimentResult runs for that parameter set.
loaded_data = {}


def pre_fetch_data(top_n=0):
    """Fetch every stored run of the parameter set at position ``top_n`` of ``table``.

    All ExperimentResult rows sharing that row's ``param_list_id`` are loaded,
    ordered by ``dataset_name``, and stored in ``loaded_data[top_n]``. This
    replaces any earlier entry for that rank.

    Parameters
    ----------
    top_n : int
        Index into ``table`` (0 = highest global_score_no_weak_acc).
    """
    wanted_id = table[top_n]["param_list_id"]
    same_params = (
        ExperimentResult.select()
        .where(ExperimentResult.param_list_id == wanted_id)
        .order_by(ExperimentResult.dataset_name)
    )

    cached = loaded_data[top_n] = []
    cached.extend(same_params)
    print("Loaded Top " + str(top_n) + " data")


pre_fetch_data(0)

Loaded Top 0 data


In [149]:
def visualise_top_n(top_n=0):
    """Plot every cached run of the ``top_n``-th parameter set.

    Reads ``loaded_data[top_n]``, so call ``pre_fetch_data(top_n)`` first.
    Each run produces one row with two rect charts side by side:
    ``all_unlabeled_roc_auc_scores`` and ``test_acc`` per AL cycle.
    Horizontally, each bar spans the queries asked in that cycle. The colour
    shows where the labels came from (Oracle / Weak Cluster / Weak Certainty /
    Ground Truth).

    Side effect: switches the global Altair renderer to "html".

    Returns
    -------
    altair.VConcatChart
        All per-run rows stacked vertically.
    """
    # Single-letter recommendation codes -> readable legend labels.
    source_labels = {
        "A": "Oracle",
        "C": "Weak Cluster",
        "U": "Weak Certainty",
        "G": "Ground Truth",
    }

    alt.renderers.enable("html")

    rows = []
    for run in loaded_data[top_n]:
        metrics = loads(run.metrics_per_al_cycle)
        # NOTE(review): element [0] of each cycle entry looks like a
        # classification-report dict ("weighted avg", "accuracy") — verify.
        reports = [cycle[0] for cycle in metrics["test_data_metrics"][0]]
        roc_auc = metrics["all_unlabeled_roc_auc_scores"]

        frame = pd.DataFrame(
            {
                "iteration": range(len(roc_auc)),
                "all_unlabeled_roc_auc_scores": roc_auc,
                "query_length": metrics["query_length"],
                "recommendation": metrics["recommendation"],
                "query_strong_accuracy_list": metrics["query_strong_accuracy_list"],
                "f1": [report["weighted avg"]["f1-score"] for report in reports],
                "test_acc": [report["accuracy"] for report in reports],
            }
        )

        # Bar extent: from the running query total before this cycle
        # (asked_queries_end, despite its name) to the total after it.
        frame["asked_queries"] = frame["query_length"].cumsum()
        frame["asked_queries_end"] = frame["asked_queries"].shift(fill_value=0)
        frame["recommendation"] = frame["recommendation"].replace(source_labels)

        base = (
            alt.Chart(frame)
            .mark_rect()
            .encode(
                x=alt.X("asked_queries_end", title="#asked queries (weak and oracle)"),
                x2="asked_queries",
                color=alt.Color("recommendation", scale=alt.Scale(scheme="tableau10")),
                tooltip=[
                    "iteration",
                    "f1",
                    "test_acc",
                    "all_unlabeled_roc_auc_scores",
                    "query_strong_accuracy_list",
                    "query_length",
                    "recommendation",
                ],
            )
            .properties(title=run.dataset_name)
            .interactive()
        )

        roc_panel = base.encode(
            alt.Y("all_unlabeled_roc_auc_scores", scale=alt.Scale(domain=[0, 1]))
        ).properties(title=run.dataset_name + ": roc_auc")
        acc_panel = base.encode(
            alt.Y("test_acc", scale=alt.Scale(domain=[0, 1]))
        ).properties(title=run.dataset_name + ": test_acc")
        rows.append(alt.hconcat(roc_panel, acc_panel))

    return alt.vconcat(*rows).configure()


visualise_top_n(0)

In [6]:
# Rank 1 in `table` (second-highest global_score_no_weak_acc): fetch its runs,
# then plot them (the chart is the cell's displayed output).
pre_fetch_data(1)
visualise_top_n(1)

Loaded Top 1 data


In [7]:
# Rank 2 in `table` (third-highest global_score_no_weak_acc): fetch its runs,
# then plot them (the chart is the cell's displayed output).
pre_fetch_data(2)
visualise_top_n(2)

Loaded Top 2 data


In [8]:
# Rank 3 in `table` (fourth-highest global_score_no_weak_acc): fetch its runs,
# then plot them (the chart is the cell's displayed output).
pre_fetch_data(3)
visualise_top_n(3)

Loaded Top 3 data


In [9]:
# Rank 4 in `table` (fifth-highest global_score_no_weak_acc): fetch its runs,
# then plot them (the chart is the cell's displayed output).
pre_fetch_data(4)
visualise_top_n(4)

Loaded Top 4 data


In [10]:
# Rank 5 in `table` (sixth-highest global_score_no_weak_acc): fetch its runs,
# then plot them (the chart is the cell's displayed output).
pre_fetch_data(5)
visualise_top_n(5)

Loaded Top 5 data


In [157]:
def _al_cycle_frame(result, top_n_label):
    """Convert one ExperimentResult run into a per-AL-cycle DataFrame.

    Parameters
    ----------
    result : ExperimentResult
        A run whose ``metrics_per_al_cycle`` is a json_tricks-serialised dict.
    top_n_label : str
        Value written into the constant ``top_n`` column (used as plot colour).

    Returns
    -------
    pandas.DataFrame
        One row per AL cycle. It holds the raw metrics, the derived
        ``asked_queries`` / ``asked_queries_end`` columns, and readable
        ``recommendation`` labels.
    """
    metrics = loads(result.metrics_per_al_cycle)
    # NOTE(review): element [0] of each cycle entry looks like a
    # classification-report dict ("weighted avg", "accuracy") — verify.
    reports = [cycle[0] for cycle in metrics["test_data_metrics"][0]]
    roc_auc = metrics["all_unlabeled_roc_auc_scores"]

    frame = pd.DataFrame(
        {
            "iteration": range(len(roc_auc)),
            "all_unlabeled_roc_auc_scores": roc_auc,
            "query_length": metrics["query_length"],
            "recommendation": metrics["recommendation"],
            "query_strong_accuracy_list": metrics["query_strong_accuracy_list"],
            "f1": [report["weighted avg"]["f1-score"] for report in reports],
            "test_acc": [report["accuracy"] for report in reports],
            "top_n": top_n_label,
        }
    )

    # Running query total after each cycle, and the total before it.
    frame["asked_queries"] = frame["query_length"].cumsum()
    frame["asked_queries_end"] = frame["asked_queries"].shift(fill_value=0)

    frame["recommendation"] = frame["recommendation"].replace(
        {
            "A": "Oracle",
            "C": "Weak Cluster",
            "U": "Weak Certainty",
            "G": "Ground Truth",
        }
    )
    return frame


def compare_data(datasets, dataset_name="dwtc", y_domain=(0.3, 0.82)):
    """Overlay the test-accuracy curves of several parameter sets on one dataset.

    Parameters
    ----------
    datasets : list of list of ExperimentResult
        Typically ``[loaded_data[0], loaded_data[1], ...]``. The *position* in
        this list becomes the ``top_n`` legend label; it is not the rank in ``table``.
    dataset_name : str
        Only runs with this ``dataset_name`` are plotted; it is also the chart title.
    y_domain : tuple of float
        Fixed y-axis range for test accuracy. The default was tuned for dwtc.

    Returns
    -------
    altair.Chart
        Step line chart of ``test_acc`` against cumulative asked queries. Lines
        are coloured by ``top_n`` and the point shape shows the recommendation source.

    Side effect: switches the global Altair renderer to "html".
    """
    alt.renderers.enable("html")

    # Build every per-run frame first and concatenate once; calling
    # pd.concat inside the loop is quadratic.
    frames = [
        _al_cycle_frame(result, str(position))
        for position, runs in enumerate(datasets)
        for result in runs
        if result.dataset_name == dataset_name
    ]
    all_data = pd.concat(frames) if frames else pd.DataFrame()

    return (
        alt.Chart(all_data)
        .mark_line(point=True, interpolate="step-before")
        .encode(
            x=alt.X("asked_queries", scale=alt.Scale(type="linear")),
            y=alt.Y("test_acc", scale=alt.Scale(domain=list(y_domain), type="linear")),
            color="top_n",
            shape="recommendation",
        )
        # Title comes from the filter. A leaked loop variable would name
        # whichever run was iterated last, which can belong to another dataset.
        .properties(title=dataset_name)
        .interactive()
    )


compare_data(
    [
        loaded_data[0],
        loaded_data[1],
        # loaded_data[2], loaded_data[3], loaded_data[4],
    ]
)