In [1]:
%load_ext lab_black
import argparse
import contextlib
import datetime
import io
import logging
import multiprocessing
import os
import random
import sys
from itertools import chain, combinations
from timeit import default_timer as timer

import altair as alt
import altair_viewer
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import peewee
from evolutionary_search import EvolutionaryAlgorithmSearchCV
from json_tricks import dumps, loads
from playhouse.shortcuts import model_to_dict
from scipy.stats import randint, uniform
from sklearn.datasets import load_iris
from tabulate import tabulate
from IPython.core.display import display, HTML

from active_learning.cluster_strategies import (
    DummyClusterStrategy,
    MostUncertainClusterStrategy,
    RandomClusterStrategy,
    RoundRobinClusterStrategy,
)
from active_learning.dataStorage import DataStorage
from active_learning.experiment_setup_lib import (
    ExperimentResult,
    classification_report_and_confusion_matrix,
    get_db,
    get_single_al_run_stats_row,
    get_single_al_run_stats_table_header,
)
from active_learning.sampling_strategies import (
    BoundaryPairSampler,
    CommitteeSampler,
    RandomSampler,
    UncertaintySampler,
)

# Charts open through altair_viewer; "vegascope" is the other renderer that was tried.
alt.renderers.enable("altair_viewer")

# Notebook-wide settings; only "db" is consumed in this section (by get_db below).
config = dict(
    datasets_path="../datasets",
    db="tunnel",
    param_list_id="best_global_score",
)

db = get_db(db_name_or_type=config["db"])



In [2]:
# Number of stored experiment results per dataset; mirrors
#   SELECT COUNT(*), dataset_name FROM experimentresult GROUP BY dataset_name;
results = (
    ExperimentResult
    .select(
        ExperimentResult.dataset_name,
        peewee.fn.COUNT(ExperimentResult.id_field).alias("dataset_name_count"),
    )
    .group_by(ExperimentResult.dataset_name)
)

for result in results:
    # count right-aligned in a 4-wide column with thousands separators
    print(f"{result.dataset_name_count:>4,d} {result.dataset_name}")

  56 dwtc


In [3]:
# SELECT id_field, param_list_id, dataset_path, start_set_size as sss, sampling, cluster, allow_recommendations_after_stop as SA, stopping_criteria_uncertainty as SCU, stopping_criteria_std as SCS, stopping_criteria_acc as SCA, amount_of_user_asked_queries as "#q", acc_test, fit_score, global_score_norm, thread_id, end_time from experimentresult where param_list_id='31858014d685a3f1ba3e4e32690ddfc3' order by end_time, fit_score desc, param_list_id;
# Cache of fetched ExperimentResult rows, keyed by the rank they were fetched for.
loaded_data = {}


def pre_fetch_data(top_n=0, table=None):
    """Load every ExperimentResult of the ``top_n``-th ranked parameter set.

    The parameter set is ``table[top_n]["param_list_id"]``; all of its runs,
    ordered by dataset name, are stored as a list in ``loaded_data[top_n]``.

    Args:
        top_n: index of the row in ``table`` whose ``param_list_id`` is loaded.
        table: list of row dicts carrying a ``"param_list_id"`` key. Defaults to
            the notebook-global ``table`` built by the ranking cell further down;
            the original version read that global implicitly, which made this
            function depend on cell execution order.

    Raises:
        NameError: if ``table`` is omitted and no global ``table`` exists yet.
        IndexError: if ``top_n`` is out of range for ``table``.
    """
    if table is None:
        try:
            table = globals()["table"]
        except KeyError:
            raise NameError(
                "pre_fetch_data needs a ranking table: pass table=... or run the "
                "cell that builds the global `table` first"
            ) from None

    best_param_list_id = table[top_n]["param_list_id"]

    results = (
        ExperimentResult.select()
        .where(ExperimentResult.param_list_id == best_param_list_id)
        .order_by(ExperimentResult.dataset_name)
    )

    loaded_data[top_n] = list(results)
    print(f"Loaded Top {top_n} data")


# pre_fetch_data(0)

In [4]:
def _result_row(idx, result):
    """Flatten one aggregated peewee result into a plain dict for ``tabulate``.

    Aggregate aliases (``avg_acc_test``, ``count``, ...) are read from the
    instance's attributes, while the real model field ``param_list_id`` is read
    from ``result.__data__``; peewee's bookkeeping attributes are dropped.
    Key order: ``id``, the aggregate attributes, then ``param_list_id``.
    """
    internal = ("__data__", "_dirty", "__rel__")
    row = {"id": idx}
    row.update(
        (key, value) for key, value in vars(result).items() if key not in internal
    )
    row["param_list_id"] = result.__data__["param_list_id"]
    return row


def _al_cycle_frame(metrics, fit_score):
    """Build a one-row-per-AL-cycle DataFrame from a decoded ``metrics_per_al_cycle``.

    ``metrics["test_data_metrics"][0]`` holds one entry per cycle whose first
    element is a report dict with ``"accuracy"`` and ``"weighted avg"`` keys
    (looks like sklearn's ``classification_report(output_dict=True)`` -- TODO
    confirm). ``fit_score`` is broadcast to every row. ``acc_diff`` is the change
    in test accuracy versus the previous cycle (NaN for the first cycle).
    """
    cycle_reports = [cycle[0] for cycle in metrics["test_data_metrics"][0]]
    frame = pd.DataFrame(
        {
            "iteration": range(len(metrics["all_unlabeled_roc_auc_scores"])),
            "all_unlabeled_roc_auc_scores": metrics["all_unlabeled_roc_auc_scores"],
            "query_length": metrics["query_length"],
            "recommendation": metrics["recommendation"],
            "query_strong_accuracy_list": metrics["query_strong_accuracy_list"],
            "f1": [report["weighted avg"]["f1-score"] for report in cycle_reports],
            "test_acc": [report["accuracy"] for report in cycle_reports],
            "fit_score": fit_score,
        }
    )
    frame["acc_diff"] = frame["test_acc"].diff()
    return frame


def better_results_top(top_n, budget, weak_clust, weak_cert):
    """Show the ``top_n``-th best dwtc parameter set and its per-recommendation sums.

    Parameter sets (``param_list_id``) are ranked among dwtc runs that asked
    fewer than ``budget`` oracle queries and used the given weak-supervision
    flags: first by number of runs, then by mean test accuracy. The selected
    set's summary row is displayed, all of its runs are stored in
    ``loaded_data[0]``, and for every run the per-cycle metrics are summed per
    recommendation source and displayed.

    Args:
        top_n: 0-based rank of the parameter set to show.
        budget: strict upper bound on ``amount_of_user_asked_queries``.
        weak_clust: required value of ``with_cluster_recommendation``.
        weak_cert: required value of ``with_uncertainty_recommendation``.
    """
    results = (
        ExperimentResult.select(
            ExperimentResult.param_list_id,
            peewee.fn.AVG(ExperimentResult.acc_test).alias("avg_acc_test"),
            peewee.fn.AVG(ExperimentResult.fit_score).alias("avg_fit_score"),
            peewee.fn.STDDEV(ExperimentResult.fit_score).alias("stddev_fit_score"),
            peewee.fn.AVG(ExperimentResult.global_score_no_weak_acc).alias(
                "avg_global_score"
            ),
            peewee.fn.STDDEV(ExperimentResult.global_score_no_weak_acc).alias(
                "stddev_global_score"
            ),
            peewee.fn.AVG(ExperimentResult.amount_of_user_asked_queries).alias(
                "avg_amount_oracle"
            ),
            peewee.fn.STDDEV(ExperimentResult.amount_of_user_asked_queries).alias(
                "std_amount_oracle"
            ),
            peewee.fn.COUNT(ExperimentResult.param_list_id).alias("count"),
        )
        .where(
            (ExperimentResult.amount_of_user_asked_queries < budget)
            & (ExperimentResult.dataset_name == "dwtc")
            # & (ExperimentResult.experiment_run_date > (datetime(2020, 3, 24, 14, 0)))
            # & (ExperimentResult.experiment_run_date > (datetime(2020, 5, 8, 9, 20)))
            & (ExperimentResult.with_cluster_recommendation == weak_clust)
            & (ExperimentResult.with_uncertainty_recommendation == weak_cert)
        )
        .group_by(ExperimentResult.param_list_id)
        .order_by(
            peewee.fn.COUNT(ExperimentResult.id_field).desc(),
            peewee.fn.AVG(ExperimentResult.acc_test).desc(),
        )
        # exactly one parameter set: the one at rank top_n
        .limit(1)
        .offset(top_n)
    )

    table = [_result_row(idx, result) for idx, result in enumerate(results)]
    if not table:
        # the original crashed with IndexError on table[0] here
        print(
            f"No parameter set at rank {top_n} for budget < {budget}, "
            f"weak_clust={weak_clust}, weak_cert={weak_cert}"
        )
        return

    display(HTML(tabulate(table, headers="keys", tablefmt="html")))

    best_param_list_id = table[0]["param_list_id"]

    # NOTE(review): stored under key 0 regardless of top_n (unlike pre_fetch_data),
    # kept as-is because other cells may read loaded_data[0].
    loaded_data[0] = list(
        ExperimentResult.select()
        .where(ExperimentResult.param_list_id == best_param_list_id)
        .order_by(ExperimentResult.dataset_name)
    )

    for result in loaded_data[0]:
        cycles = _al_cycle_frame(loads(result.metrics_per_al_cycle), result.fit_score)
        # One summary per run; previously only the last run's frame was shown
        # because the display sat after the loop.
        display(
            HTML(
                tabulate(
                    cycles.groupby(["recommendation"]).sum(),
                    headers="keys",
                    tablefmt="html",
                )
            )
        )


# Exclusive upper bound on oracle queries (runs with at most 210 queries qualify).
QUERY_BUDGET = 211

# (rank, with_cluster_recommendation, with_uncertainty_recommendation);
# for the combined setting the second-ranked parameter set is shown.
for top_n, weak_clust, weak_cert in (
    (0, False, False),
    (0, True, False),
    (0, False, True),
    (1, True, True),
):
    better_results_top(top_n, QUERY_BUDGET, weak_clust, weak_cert)

id,avg_acc_test,avg_fit_score,stddev_fit_score,avg_global_score,stddev_global_score,avg_amount_oracle,std_amount_oracle,count,param_list_id
0,0.801661,0.85991,,0.615803,,210,,1,d1cae789088256dbb7e26edd0ada3717


recommendation,iteration,all_unlabeled_roc_auc_scores,query_length,query_strong_accuracy_list,f1,test_acc,fit_score,acc_diff
A,231,18.9359,210,0,15.2447,15.4005,18.0581,0.333679
G,0,0.630812,4,0,0.44892,0.467982,0.85991,0.0


id,avg_acc_test,avg_fit_score,stddev_fit_score,avg_global_score,stddev_global_score,avg_amount_oracle,std_amount_oracle,count,param_list_id
0,0.806854,0.862888,,0.583971,,210,,1,47dada3f4267151b75cb4badfa593d0c


recommendation,iteration,all_unlabeled_roc_auc_scores,query_length,query_strong_accuracy_list,f1,test_acc,fit_score,acc_diff
A,231,18.5866,210,0,14.5096,14.883,18.1206,0.320526
G,0,0.754505,4,0,0.49708,0.486327,0.862888,0.0


id,avg_acc_test,avg_fit_score,stddev_fit_score,avg_global_score,stddev_global_score,avg_amount_oracle,std_amount_oracle,count,param_list_id
0,0.808584,0.863876,,0.624376,,210,,1,c05ad3a7c9e554bc547fecb2ad115542


recommendation,iteration,all_unlabeled_roc_auc_scores,query_length,query_strong_accuracy_list,f1,test_acc,fit_score,acc_diff
A,408,18.8112,210,0,15.3839,15.5151,18.1414,0.327103
G,0,0.713996,4,0,0.501203,0.520942,0.863876,0.0
U,970,28.4741,247,31,24.2074,24.2897,26.7802,-0.03946


id,avg_acc_test,avg_fit_score,stddev_fit_score,avg_global_score,stddev_global_score,avg_amount_oracle,std_amount_oracle,count,param_list_id
0,0.792316,0.854504,,0.58973,,210,,1,c2c0090ad2fc0ae613e0dd7160c08c10


recommendation,iteration,all_unlabeled_roc_auc_scores,query_length,query_strong_accuracy_list,f1,test_acc,fit_score,acc_diff
A,231,18.583,210,0,14.8466,14.9823,17.9446,0.334372
G,0,0.725939,4,0,0.449467,0.457944,0.854504,0.0


In [22]:
# Top 100 individual dwtc runs by test accuracy, together with each run's settings.
# Roughly: SELECT param_list_id, acc_test_oracle, acc_test, ... FROM experimentresult
#          WHERE dataset_name = 'dwtc' ORDER BY acc_test DESC LIMIT 100;
# NOTE(review): this rebinds `datetime` from the module (imported in the first cell)
# to the class. Kept because the commented-out date filters use `datetime(...)`
# directly and later cells may rely on it; `timedelta` is unused here.
from datetime import datetime, timedelta

# & (ExperimentResult.experiment_run_date > (datetime(2020, 3, 24, 14, 0))) # no stopping criterias
#  & (ExperimentResult.experiment_run_date > (datetime(2020, 3, 30, 12, 23))) # optics

# The displayed columns are selected by the ranking query itself. Previously a
# second `.limit(1)` query per row fetched them, which cost up to 100 extra
# round-trips and could show an arbitrary other run of the same param_list_id
# instead of the run that was actually ranked.
results = (
    ExperimentResult.select(
        ExperimentResult.param_list_id,
        ExperimentResult.acc_test_oracle,
        ExperimentResult.acc_test,
        ExperimentResult.fit_score,
        ExperimentResult.fit_time,
        ExperimentResult.amount_of_all_labels,
        ExperimentResult.amount_of_user_asked_queries,
        # ExperimentResult.classifier,
        # ExperimentResult.test_fraction,
        ExperimentResult.sampling,
        ExperimentResult.cluster,
        # ExperimentResult.nr_queries_per_iteration,
        ExperimentResult.with_uncertainty_recommendation,
        ExperimentResult.with_cluster_recommendation,
        ExperimentResult.uncertainty_recommendation_certainty_threshold,
        ExperimentResult.uncertainty_recommendation_ratio,
        ExperimentResult.cluster_recommendation_minimum_cluster_unity_size,
        ExperimentResult.cluster_recommendation_ratio_labeled_unlabeled,
        ExperimentResult.allow_recommendations_after_stop,
        # ExperimentResult.stopping_criteria_uncertainty,
        # ExperimentResult.stopping_criteria_acc,
        # ExperimentResult.stopping_criteria_std,
        ExperimentResult.experiment_run_date,
    )
    .where(
        # (ExperimentResult.amount_of_user_asked_queries < 211)
        (ExperimentResult.dataset_name == "dwtc")
        # & (ExperimentResult.experiment_run_date > (datetime(2020, 3, 24, 14, 0)))
        # & (ExperimentResult.experiment_run_date > (datetime(2020, 5, 8, 9, 20)))
        # & (ExperimentResult.with_cluster_recommendation == True)
        # & (ExperimentResult.with_uncertainty_recommendation == True)
    )
    .order_by(ExperimentResult.acc_test.desc())
    .limit(100)
)

# Module-level on purpose: pre_fetch_data() reads table[top_n]["param_list_id"].
# Column order is id, then the selected fields in select order.
table = [{"id": idx, **result.__data__} for idx, result in enumerate(results)]

display(HTML(tabulate(table, headers="keys", tablefmt="html")))

id,param_list_id,acc_test_oracle,acc_test,fit_score,fit_time,amount_of_all_labels,amount_of_user_asked_queries,sampling,cluster,with_uncertainty_recommendation,with_cluster_recommendation,uncertainty_recommendation_certainty_threshold,uncertainty_recommendation_ratio,cluster_recommendation_minimum_cluster_unity_size,cluster_recommendation_ratio_labeled_unlabeled,allow_recommendations_after_stop,experiment_run_date
0,53532c1d1a70bcf9496663f9ae18365a,0.802008,0.809969,0.864666,43.0757,214,210,uncertainty_max_margin,dummy,True,False,0.92,0.01,0.79,0.81,True,2020-05-27 19:44:45.771448
1,c05ad3a7c9e554bc547fecb2ad115542,0.824161,0.808584,0.863876,137.119,461,210,uncertainty_max_margin,dummy,True,False,0.86,0.0001,0.76,0.57,True,2020-05-27 19:33:02.785327
2,031c64ec57dba32c6075d0273620c764,0.797508,0.807892,0.863481,22.2081,214,210,uncertainty_max_margin,dummy,True,False,0.98,0.0001,0.93,0.74,True,2020-05-27 19:25:11.489584
3,38923187d2daa75ed438ad3f57cbcacd,0.819315,0.8072,0.863086,40.6649,214,210,uncertainty_max_margin,dummy,False,False,0.85,0.0001,0.56,0.95,True,2020-05-27 20:17:38.070149
4,47dada3f4267151b75cb4badfa593d0c,0.817584,0.806854,0.862888,22.0207,214,210,uncertainty_max_margin,dummy,False,True,1.0,0.001,0.88,0.52,True,2020-05-27 19:25:11.419208
5,cefdf267c2f74643ce4ad63c7e0a6603,0.806854,0.805123,0.861897,43.22,214,210,uncertainty_max_margin,dummy,False,False,0.92,1e-05,0.94,0.95,True,2020-05-27 20:16:11.438671
6,f346618e7f3e5e5625b074bf438b8851,0.8072,0.803738,0.861103,27.0071,223,210,uncertainty_max_margin,dummy,False,True,0.9,0.01,0.85,0.61,True,2020-05-27 19:25:39.983269
7,c9c31b034649bfe59278325721b4bb87,0.809277,0.803738,0.861103,48.3172,220,210,uncertainty_max_margin,dummy,False,True,0.89,1e-05,0.52,0.79,True,2020-05-27 20:17:04.714805
8,e3db57eec4eef400b9e07ad8cd87046e,0.788854,0.802354,0.860308,47.9345,444,210,uncertainty_max_margin,dummy,True,False,0.92,1e-06,0.51,0.97,True,2020-05-27 19:25:37.269651
9,5d80617c389a45cf74c773cc5ec274d2,0.81343,0.802008,0.860109,50.167,214,210,uncertainty_max_margin,dummy,True,False,0.91,1e-05,0.94,0.65,True,2020-05-27 19:31:35.979137
