In [None]:
import yaml
from pathlib import Path
import numpy as np
from PIL import Image
from sklearn.utils import shuffle
from encord_active.lib.common.iterator import DatasetIterator
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from encord_active.lib.project.project_file_structure import ProjectFileStructure
from collections import defaultdict
import json
from tqdm import tqdm
from encord_active.lib.metrics.execute import execute_metrics
import pandas as pd
import matplotlib.pyplot as plt

# Load files

In [None]:
# Load the notebook settings from config.yaml, resolved against the current working directory.
config = yaml.safe_load(Path("config.yaml").read_text())
# The config must contain a "project_dir" entry; it is used as the project path below.
project_dir = Path(config["project_dir"])
# TODO: change project_dir in the config file to separate train and test entries.
# Until then the train project is simply the configured project_dir.
project_dir_train = project_dir
project_fs_train = ProjectFileStructure(project_dir_train)

In [None]:
# utility functions

def get_data_hashes_from_project(project_dir: Path, subset_size=None):
    """Collect one (label_hash, du_hash) pair per item yielded by the project's DatasetIterator.

    Args:
        project_dir: project path handed to ``DatasetIterator``.
        subset_size: forwarded unchanged to ``DatasetIterator`` (presumably caps how many
            data units are visited; confirm against DatasetIterator).

    Returns:
        A list of (label_hash, du_hash) tuples in iteration order.
    """
    dataset_iterator = DatasetIterator(project_dir, subset_size)
    collected_hashes = []
    # The yielded (data unit, image path) values are not needed: the hashes are read from
    # the iterator's own attributes, which reflect the item it is currently positioned on.
    for _data_unit, _image_path in dataset_iterator.iterate():
        collected_hashes.append((dataset_iterator.label_hash, dataset_iterator.du_hash))
    return collected_hashes

def get_data_from_data_hashes(project_fs: ProjectFileStructure, data_hashes: list[tuple[str, str]]):
    """Load the samples identified by `data_hashes` and split them into images and labels.

    Args:
        project_fs: file structure of the project the samples are read from.
        data_hashes: (label_hash, du_hash) pairs; each one is loaded with ``get_data_sample``.

    Returns:
        A pair ``(image_arrays, class_labels)`` of lists aligned with `data_hashes`.
        Both lists are empty when `data_hashes` is empty.
    """
    samples = [get_data_sample(project_fs, data_hash) for data_hash in data_hashes]
    if not samples:
        # zip(*[]) yields nothing, so unpacking it into two names would raise ValueError.
        return [], []
    image_arrays, class_labels = zip(*samples)
    return list(image_arrays), list(class_labels)

def get_data_sample(project_fs: ProjectFileStructure, data_hash: tuple[str, str]):
    """Load one data unit as a (flattened image array, classification label) pair.

    Args:
        project_fs: file structure used to locate the label row file and the images directory.
        data_hash: (label_hash, du_hash) pair identifying the data unit.

    Returns:
        ``(image_array, class_label)`` where `image_array` is the 1-D pixel array of the image
        and `class_label` is the answer of the classification named "Classification", or None
        when the data unit has no such classification.
    """
    label_hash, du_hash = data_hash
    lr_struct = project_fs.label_row_structure(label_hash)

    # get classification label
    label_row = json.loads(lr_struct.label_row_file.read_text())
    class_label = get_classification_label(label_row, du_hash, class_name="Classification")

    # get image: the file extension is whatever follows the last "/" in the data unit's
    # data_type (presumably a MIME type such as "image/jpeg"; TODO confirm)
    image_extension = label_row["data_units"][du_hash]["data_type"].split("/")[-1]
    image_path = lr_struct.images_dir / f"{du_hash}.{image_extension}"

    # Close the image file deterministically; flatten() returns a copy, so the array
    # stays valid after the image has been closed.
    with Image.open(image_path) as image:
        image_array = np.asarray(image).flatten()

    return image_array, class_label

def get_classification_label(label_row, du_hash: str, class_name: str):
    """Look up the answer name of the classification called `class_name` on one data unit.

    Args:
        label_row: parsed label row; read via its "data_units" and "classification_answers" keys.
        du_hash: key of the data unit inside ``label_row["data_units"]``.
        class_name: value the classification's "name" field has to equal.

    Returns:
        The "name" of the first answer of the first matching classification, or None when
        the data unit carries no classification with that name.
    """
    classifications = label_row["data_units"][du_hash]["labels"]["classifications"]
    matches = list(filter(lambda entry: entry["name"] == class_name, classifications))
    if not matches:
        return None

    answer_entry = label_row["classification_answers"][matches[0]["classificationHash"]]
    # TODO: check whether this path also holds for text classifications (written for radio buttons)
    return answer_entry["classifications"][0]["answers"][0]["name"]

def train_model(X_train, y_train, model=None):
    """Fit `model` on the training data and return it.

    When no model is supplied, a default ``LogisticRegression`` is created; it only serves
    as a dummy model example.
    """
    estimator = LogisticRegression() if model is None else model
    estimator.fit(X_train, y_train)
    return estimator

def get_model_accuracy(X_test, y_test, model):
    """Return ``accuracy_score`` of the model's predictions for `X_test` against `y_test`."""
    return accuracy_score(y_test, model.predict(X_test))

def get_n_best_ranked_data_samples(project_fs: ProjectFileStructure, data_hashes, n, acq_func_instance, rank_by: str):
    """Score the project with an acquisition function and return the `n` best-ranked data hashes.

    Args:
        project_fs: project file structure; its `project_dir` is passed to ``execute_metrics``
            and the resulting CSV is read from its `metrics` directory.
        data_hashes: iterable of (label_hash, du_hash) pairs. Only result rows whose identifier
            starts with "<label_hash>_<du_hash>" for one of these pairs are considered.
        n: maximum number of samples to return.
        acq_func_instance: acquisition function to execute; its
            ``metadata.get_unique_name()`` determines the name of the CSV file that is read.
        rank_by: "asc" to pick the lowest scores, "desc" to pick the highest.

    Returns:
        A list of at most `n` tuples built by ``get_data_hash_from_identifier``, best first.

    Raises:
        ValueError: if `rank_by` is neither "asc" nor "desc".
    """
    # Fail fast with a useful message instead of running the metric first and raising afterwards.
    if rank_by not in ("asc", "desc"):
        raise ValueError(f"rank_by must be 'asc' or 'desc', got {rank_by!r}")

    execute_metrics([acq_func_instance], data_dir=project_fs.project_dir)
    unique_acq_func_name = acq_func_instance.metadata.get_unique_name()
    acq_func_results = pd.read_csv(project_fs.metrics / f"{unique_acq_func_name}.csv")

    # filter acquisition function results to only contain data samples specified in data_hashes
    str_data_hashes = tuple(f"{label_hash}_{du_hash}" for label_hash, du_hash in data_hashes)
    is_candidate = acq_func_results["identifier"].str.startswith(str_data_hashes, na=False)
    candidate_scores = acq_func_results[is_candidate][["identifier", "score"]]

    if rank_by == "asc":  # the first n data samples if they were sorted by ascending score order
        best_n = candidate_scores.nsmallest(n, "score", keep="first")["identifier"]
    else:  # "desc": the first n data samples if they were sorted by descending score order
        best_n = candidate_scores.nlargest(n, "score", keep="first")["identifier"]
    return [get_data_hash_from_identifier(identifier) for identifier in best_n]
    
def get_data_hash_from_identifier(identifier: str):
    """Return the first two "_"-separated fields of `identifier` as a tuple.

    For an identifier shaped like "<label_hash>_<du_hash>[_...]" this is the
    (label_hash, du_hash) pair; an identifier without "_" yields a 1-tuple.
    """
    leading_fields = identifier.split("_")[:2]
    return tuple(leading_fields)

In [None]:
# Prepare the (label_hash, du_hash) pairs of the train project.
# `subset_size` is forwarded to DatasetIterator (None presumably means "no limit"; TODO confirm).
subset_size = None
data_hashes_train = get_data_hashes_from_project(project_dir_train, subset_size)

# Shuffle the data hashes with a fixed seed so the order, and any split taken from it,
# is the same on every run.
data_hashes_train = shuffle(data_hashes_train, random_state=42)

print(f"Train dataset size: {len(data_hashes_train)}")

In [None]:
# Prepare the test data: load every image and classification label of the test project.
# NOTE: the test project currently points at the same directory as the train project, so
# accuracy measured on it is computed on data the model may also be trained on.
project_dir_test = project_dir  # TODO: change this for the real test project
project_fs_test = ProjectFileStructure(project_dir_test)
data_hashes_test = get_data_hashes_from_project(project_dir_test)
X_test, y_test = get_data_from_data_hashes(project_fs_test, data_hashes_test)
print(f"Test dataset size: {len(X_test)}")

In [None]:
# Configuration variables needed for the active learning workflow.
# TODO: move these settings to config.yaml.
initial_data_amount = 20  # initial amount of labeled data (taken from the start of the shuffled train hashes)
n_iterations = 10 # number of iterations of the active learning paradigm
batch_size_to_label = 5 # number of data samples labeled between AL iterations

In [None]:
# Load common acquisition functions (for active learning).
# NOTE(review): this import lives outside the notebook's top import cell; consider moving it there.
from encord_active.lib.metrics.acquisition_functions import Entropy, LeastConfidence, Margin, Variance

# Pair each acquisition function with the score order used later when selecting the
# best-ranked data samples: 'asc' (ascending) takes the lowest scores first,
# 'desc' (descending) takes the highest scores first.
acq_funcs = [(Entropy, "desc"), (LeastConfidence, "desc"), (Margin, "asc"), (Variance, "desc")]

In [None]:
# Simulate the active learning loop once per acquisition function and log the test accuracy
# after every (re)training: accuracy_logger[acquisition function name][iteration] = accuracy,
# where iteration 0 is the model trained on the initial labeled data only.
accuracy_logger = defaultdict(dict)
for acq_func, rank_order in acq_funcs:
    print(f"Analyzing acquisition function: {acq_func.__name__}")

    # mockup of the initial labeling phase: the first `initial_data_amount` samples count as
    # labeled, every remaining sample forms the unlabeled pool
    labeled_data_hashes_train = data_hashes_train[:initial_data_amount]
    unlabeled_data_hashes_train = set(data_hashes_train[initial_data_amount:])

    X, y = get_data_from_data_hashes(project_fs_train, labeled_data_hashes_train)
    model = train_model(X, y)
    accuracy_logger[acq_func.__name__][0] = get_model_accuracy(X_test, y_test, model)
    for it in tqdm(range(1, n_iterations + 1)):
        # Nothing is left to label once the pool is empty; stop before running the
        # acquisition function instead of failing on an empty selection.
        if not unlabeled_data_hashes_train:
            print(f"Unlabeled pool exhausted before iteration {it}; stopping early")
            break

        acq_func_instance = acq_func(model)
        data_to_label_next = get_n_best_ranked_data_samples(project_fs_train, unlabeled_data_hashes_train, batch_size_to_label, acq_func_instance, rank_by=rank_order)
        if not data_to_label_next:
            # The acquisition function ranked none of the pooled samples, so there is no new data to add.
            print(f"No samples were ranked at iteration {it}; stopping early")
            break

        # mockup of the labeling phase: move the selected samples from the pool to the labeled set
        X_new, y_new = get_data_from_data_hashes(project_fs_train, data_to_label_next)
        unlabeled_data_hashes_train.difference_update(data_to_label_next)

        X.extend(X_new)
        y.extend(y_new)
        model = train_model(X, y)
        accuracy_logger[acq_func.__name__][it] = get_model_accuracy(X_test, y_test, model)

In [None]:
# Draw one accuracy curve per acquisition function: x is the iteration, y the logged accuracy.
fig, ax = plt.subplots()
for acq_func_name, points in accuracy_logger.items():
    xs, ys = zip(*points.items())
    ax.plot(xs, ys, label=acq_func_name)

ax.set_xlabel("Iteration")
ax.set_ylabel("Model Accuracy")
ax.legend()
plt.show()