In [None]:
import json
from collections import defaultdict
from functools import partialmethod
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import yaml
from loguru import logger
from PIL import Image
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle
from tqdm.auto import tqdm

from encord_active.lib.common.iterator import DatasetIterator
from encord_active.lib.metrics.execute import execute_metrics
from encord_active.lib.project.project_file_structure import ProjectFileStructure

# Silence library output so the notebook cells stay readable:
#   * remove the sinks registered on the loguru logger;
#   * make `disable=True` the default for every tqdm progress bar created afterwards.
# The tqdm default is injected as a keyword argument through `partialmethod`, so a
# call site can still turn its own bar back on by passing `disable=False` explicitly.
# NOTE: re-running this cell wraps `tqdm.__init__` once more; the effect is unchanged.
logger.remove()
tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

# Utility functions

In [None]:
def get_data_hashes_from_project(project_fs: ProjectFileStructure, subset_size=None):
    """Collect one ``(label_hash, du_hash)`` pair per element visited in a project.

    Args:
        project_fs: File structure of the project; only ``project_dir`` is used here.
        subset_size: Forwarded unchanged to ``DatasetIterator``.

    Returns:
        A list of ``(label_hash, du_hash)`` tuples, in iteration order.
    """
    dataset_iter = DatasetIterator(project_fs.project_dir, subset_size)
    collected = []
    for _ in dataset_iter.iterate():
        # The yielded items are not needed: the hashes are read from the iterator
        # object itself (presumably those of the element it is currently on).
        collected.append((dataset_iter.label_hash, dataset_iter.du_hash))
    return collected

def get_data_from_data_hashes(project_fs: ProjectFileStructure, data_hashes: list[tuple[str, str]]):
    """Load the image array and the class label for every requested data unit.

    Args:
        project_fs: File structure of the project the data units belong to.
        data_hashes: Iterable of ``(label_hash, du_hash)`` pairs to load.

    Returns:
        A pair ``(image_arrays, class_labels)`` of equally long lists, ordered like
        ``data_hashes``. Both lists are empty when ``data_hashes`` is empty.
    """
    samples = [get_data_sample(project_fs, data_hash) for data_hash in data_hashes]
    if not samples:
        # zip(*[]) yields nothing, so unpacking it into two names would raise
        # ValueError; an empty selection is a legitimate input (e.g. the
        # unlabeled pool has been exhausted).
        return [], []
    image_arrays, class_labels = zip(*samples)
    return list(image_arrays), list(class_labels)

def get_data_sample(project_fs: ProjectFileStructure, data_hash: tuple[str, str]):
    """Load a single sample: its flattened pixel array and its "digit" label.

    Args:
        project_fs: File structure of the project the data unit belongs to.
        data_hash: ``(label_hash, du_hash)`` pair identifying the data unit.

    Returns:
        A pair ``(image_array, class_label)``. ``image_array`` is the image's
        pixel data flattened to one dimension; ``class_label`` is whatever
        ``get_classification_label`` returns for the "digit" classification
        (``None`` when the data unit has no such classification).
    """
    label_hash, du_hash = data_hash
    lr_struct = project_fs.label_row_structure(label_hash)

    # get classification label
    label_row = json.loads(lr_struct.label_row_file.read_text())
    class_label = get_classification_label(label_row, du_hash, class_name="digit")

    # get image: the file is named "<du_hash>.<ext>", where <ext> is the part of the
    # data unit's `data_type` after the last "/" (presumably a MIME type such as
    # "image/png" — TODO confirm).
    file_extension = label_row["data_units"][du_hash]["data_type"].split("/")[-1]
    image_path = lr_struct.images_dir / f"{du_hash}.{file_extension}"
    # Close the file handle deterministically instead of leaving it to garbage
    # collection; `flatten()` returns a copy, so the array outlives the image.
    with Image.open(image_path) as image:
        image_array = np.asarray(image).flatten()

    return image_array, class_label

def get_classification_label(label_row, du_hash: str, class_name: str):
    """Look up the answer stored for a named classification of one data unit.

    Args:
        label_row: Parsed label row; must provide ``"data_units"`` and
            ``"classification_answers"``.
        du_hash: Key of the data unit inside ``label_row["data_units"]``.
        class_name: Value matched against each classification's ``"name"``.

    Returns:
        The ``"answers"`` value of the first matching classification, or ``None``
        when the data unit has no classification with that name.
    """
    classifications = label_row["data_units"][du_hash]["labels"]["classifications"]
    matches = list(filter(lambda entry: entry["name"] == class_name, classifications))
    if not matches:
        return None
    # Only the first match is used; its hash keys into the answers table.
    answer_record = label_row["classification_answers"][matches[0]["classificationHash"]]
    return answer_record["classifications"][0]["answers"]

def train_model(X_train, y_train, model=None):
    """Fit a model on the given training data and return it.

    Args:
        X_train: Training samples, passed straight to ``model.fit``.
        y_train: Training targets, passed straight to ``model.fit``.
        model: Estimator to fit. When ``None``, a fresh ``LogisticRegression``
            with default settings is used as a dummy example model.

    Returns:
        The fitted estimator (the same object that was passed in, if any).
    """
    estimator = LogisticRegression() if model is None else model
    estimator.fit(X_train, y_train)
    return estimator

def get_model_accuracy(X_test, y_test, model):
    """Score a fitted model on held-out data.

    Args:
        X_test: Samples passed to ``model.predict``.
        y_test: Ground-truth targets the predictions are compared with.
        model: Fitted estimator exposing ``predict``.

    Returns:
        The value of ``accuracy_score(y_test, predictions)``.
    """
    return accuracy_score(y_test, model.predict(X_test))

def get_n_best_ranked_data_samples(project_fs: ProjectFileStructure, data_hashes, n, acq_func_instance, rank_by: str):
    """Run an acquisition function and pick the ``n`` best-ranked candidate samples.

    The acquisition function is executed on the project at ``project_fs.project_dir``,
    its CSV results are read back from ``project_fs.metrics`` and restricted to the
    candidates in ``data_hashes`` before ranking.

    Args:
        project_fs: File structure of the project to run the function on.
        data_hashes: Iterable of candidate ``(label_hash, du_hash)`` pairs.
        n: Maximum number of samples to return.
        acq_func_instance: Acquisition function (metric) instance to execute.
        rank_by: ``"asc"`` to take the ``n`` lowest scores, ``"desc"`` to take the
            ``n`` highest scores.

    Returns:
        A list of at most ``n`` ``(label_hash, du_hash)`` tuples, best-ranked first.

    Raises:
        ValueError: If ``rank_by`` is neither ``"asc"`` nor ``"desc"``.
    """
    # Validate up front so a typo does not cost a full metric run first.
    if rank_by not in ("asc", "desc"):
        raise ValueError(f"rank_by must be 'asc' or 'desc', got {rank_by!r}")

    execute_metrics([acq_func_instance], data_dir=project_fs.project_dir)
    unique_acq_func_name = acq_func_instance.metadata.get_unique_name()
    acq_func_results = pd.read_csv(project_fs.metrics / f"{unique_acq_func_name}.csv")

    # Keep only the rows that belong to the requested data samples. An identifier
    # starts with "<label_hash>_<du_hash>", so its first two "_"-separated fields are
    # extracted and looked up in a set. This is one hash lookup per row instead of a
    # prefix comparison against every candidate, and it cannot match a different data
    # unit whose key merely shares a prefix. It relies on hashes containing no "_",
    # the same assumption get_data_hash_from_identifier makes.
    wanted_keys = {f"{label_hash}_{du_hash}" for label_hash, du_hash in data_hashes}
    identifier_keys = acq_func_results["identifier"].str.extract(r"^([^_]*_[^_]*)", expand=False)
    # Missing or malformed identifiers extract to NaN and are therefore excluded.
    filtered_results = acq_func_results.loc[identifier_keys.isin(wanted_keys), ["identifier", "score"]]

    if rank_by == "asc":  # the first n data samples when sorted by ascending score
        best_n = filtered_results.nsmallest(n, "score", keep="first")["identifier"]
    else:  # "desc": the first n data samples when sorted by descending score
        best_n = filtered_results.nlargest(n, "score", keep="first")["identifier"]
    return [get_data_hash_from_identifier(identifier) for identifier in best_n]
    
def get_data_hash_from_identifier(identifier: str):
    """Extract the leading ``(label_hash, du_hash)`` pair from a result identifier.

    Args:
        identifier: String whose fields are separated by ``"_"``; the first two
            fields are taken as label hash and data unit hash.

    Returns:
        A tuple with the first two fields. It is shorter than two elements when
        the identifier contains no ``"_"``.
    """
    leading_fields = identifier.split("_", 2)[:2]
    return tuple(leading_fields)

# Load files

In [None]:
# All settings of this notebook live under the "active-learning-in-mnist" key of
# config.yaml, which is looked up relative to the current working directory.
config = yaml.safe_load(Path("config.yaml").read_text())["active-learning-in-mnist"]

# train: only the (label_hash, du_hash) pairs are collected here; images and labels
# are read from disk later, at the moment a sample gets "labeled" in the workflow
project_dir_train = Path(config["train"]["project_dir"])
project_fs_train = ProjectFileStructure(project_dir_train)
data_hashes_train = get_data_hashes_from_project(project_fs_train, subset_size=None)
# shuffle data hashes with a fixed seed, so the initially labeled subset (the first
# `initial_data_amount` entries) is the same on every run
data_hashes_train = shuffle(data_hashes_train, random_state=42)
print(f"Train dataset size: {len(data_hashes_train)}")

# test: the complete test set (image arrays and labels) is loaded into memory once
# and reused to score every model trained below
project_dir_test = Path(config["test"]["project_dir"])
project_fs_test = ProjectFileStructure(project_dir_test)
data_hashes_test = get_data_hashes_from_project(project_fs_test)
X_test, y_test = get_data_from_data_hashes(project_fs_test, data_hashes_test)
print(f"Test dataset size: {len(data_hashes_test)}")

# active learning (AL) config variables
initial_data_amount = config["initial_data_amount"]  # samples labeled before the first iteration
n_iterations = config["n_iterations"]  # number of select -> label -> retrain rounds
batch_size_to_label = config["batch_size_to_label"]  # samples added to the labeled set per round
print(f"Initial amount of labeled data in the train dataset: {initial_data_amount}")
print(f"Number of iterations in the active learning (AL) workflow: {n_iterations}")
print(f"Number of data samples annotated between AL iterations: {batch_size_to_label}")

# Load acquisition functions

In [None]:
from encord_active.lib.metrics.acquisition_functions import Entropy, LeastConfidence, Margin, Variance
from encord_active.lib.metrics.heuristic.random import RandomImageMetric

# Each entry pairs an acquisition function class with the ordering used to select the
# next samples to label: 'asc' (ascending) picks the lowest scores first, 'desc'
# (descending) picks the highest scores first.
# NOTE(review): the orderings below presumably make each function select the samples
# the model is least certain about — confirm against every function's score definition.
# For the random baseline the ordering presumably makes no difference.
acq_funcs = [
    (Entropy, "desc"),
    (LeastConfidence, "desc"),
    (Margin, "asc"),
    (Variance, "asc"),
    (RandomImageMetric, "asc"),
]

# Run the active learning workflow with each acquisition function

In [None]:
# accuracy_logger[<acquisition function name>][<iteration>] -> accuracy on the test set
accuracy_logger = defaultdict(dict)
for acq_func, rank_order in acq_funcs:
    acq_name = acq_func.__name__
    # The random baseline is the only acquisition function constructed without a model.
    needs_model = acq_name != "RandomImageMetric"

    # mockup of the initial labeling phase: every acquisition function starts from the
    # same labeled subset, the remaining samples form the unlabeled pool
    labeled_data_hashes_train = data_hashes_train[:initial_data_amount]
    unlabeled_data_hashes_train = set(data_hashes_train[initial_data_amount:])

    X, y = get_data_from_data_hashes(project_fs_train, labeled_data_hashes_train)
    model = train_model(X, y)
    accuracy_logger[acq_name][0] = get_model_accuracy(X_test, y_test, model)

    # disable=False turns this one bar back on despite the notebook-wide tqdm silencing
    progress = tqdm(range(1, n_iterations + 1), disable=False, desc=f"Analyzing {acq_name} performance")
    for it in progress:
        acq_func_instance = acq_func(model) if needs_model else acq_func()
        data_to_label_next = get_n_best_ranked_data_samples(
            project_fs_train,
            unlabeled_data_hashes_train,
            batch_size_to_label,
            acq_func_instance,
            rank_by=rank_order,
        )

        # mockup of the labeling phase: "annotate" the selected samples by loading
        # their stored labels, then take them out of the unlabeled pool
        X_new, y_new = get_data_from_data_hashes(project_fs_train, data_to_label_next)
        unlabeled_data_hashes_train.difference_update(data_to_label_next)

        # grow the training set in place and retrain on all labeled data
        X += X_new
        y += y_new
        model = train_model(X, y)
        accuracy_logger[acq_name][it] = get_model_accuracy(X_test, y_test, model)

In [None]:
# misc: beautify function names — replace class names by the labels shown in the plot
display_labels = {
    RandomImageMetric.__name__: "Random",
    LeastConfidence.__name__: "Least Confidence",
}
for class_name, display_label in display_labels.items():
    # The membership check keeps the cell safe to re-run and tolerates acquisition
    # functions that were not evaluated. A renamed entry is re-inserted, so it moves
    # to the end of the mapping.
    if class_name in accuracy_logger:
        accuracy_logger[display_label] = accuracy_logger.pop(class_name)

# Show results of the active learning workflow

In [None]:
# Draw one accuracy-vs-iteration curve per acquisition function. An explicit
# Figure/Axes pair is used instead of pyplot's implicit "current figure", so the
# cell never draws onto a figure left over from earlier code.
fig, ax = plt.subplots()
for acq_func_name, points in accuracy_logger.items():
    # points maps iteration number -> accuracy; split it into x and y sequences
    xs, ys = zip(*points.items())
    ax.plot(xs, ys, label=acq_func_name)

ax.set_title("Model accuracy per active learning iteration")
ax.set_xlabel("Iteration")
ax.set_ylabel("Model Accuracy")
ax.set_xticks(range(n_iterations + 1))  # one tick per iteration, 0 = initial model
ax.legend()
plt.show()