# Active Learning in MNIST dataset using Random Forest

**Prerequisites**:
You need to have encord-active [installed](https://docs.encord.com/docs/active-installation).

This notebook shows you how to plug your model into Encord Active and use it to run several iterations of the active learning cycle on the MNIST sandbox projects.

It follows five steps:
1. Download the MNIST sandbox projects.
2. Set up the model.
3. Select the acquisition functions.
4. Simulate the active learning workflow.
5. Compare and visualize the simulation results.

## Download the MNIST sandbox projects

In [None]:
from pathlib import Path
from encord_active.lib.common.active_learning import get_data_hashes_from_project
from encord_active.lib.project.project_file_structure import ProjectFileStructure
from encord_active.lib.project.sandbox_projects import fetch_prebuilt_project

def init_project_data(project_name, subset_size):
    """Fetch a prebuilt sandbox project and pick a subset of its data hashes.

    The project is stored in a folder named after ``project_name`` inside the
    current working directory.

    Args:
        project_name: Name of the prebuilt project to fetch.
        subset_size: Forwarded to ``get_data_hashes_from_project`` to control
            how many data hashes are selected.

    Returns:
        A ``(project_file_structure, data_hashes)`` tuple.
    """
    target_dir = Path.cwd() / project_name
    fetch_prebuilt_project(project_name, target_dir)

    file_structure = ProjectFileStructure(target_dir)
    selected_hashes = get_data_hashes_from_project(file_structure, subset_size)
    return file_structure, selected_hashes

# Fetch the train and test sandbox projects and select a subset of each.
project_fs_train, data_hashes_train = init_project_data(
    "[open-source][train]-mnist-dataset",
    subset_size=5000,
)
project_fs_test, data_hashes_test = init_project_data(
    "[open-source][test]-mnist-dataset",
    subset_size=1000,
)

print(f"Train dataset size: {len(data_hashes_train)}")
print(f"Test dataset size: {len(data_hashes_test)}")

# Settings of the simulated active learning workflow.
initial_data_amount = 500  # samples used to train the initial model
n_iterations = 15  # active learning cycles to run
batch_size_to_label = 300  # samples "annotated" in every cycle
class_name = "digit"  # name of the classification to work with

print(f"Amount of data samples used to train the initial model: {initial_data_amount}")
print(f"Number of iterations in the active learning workflow: {n_iterations}")
print(f"Number of data samples annotated in each iteration: {batch_size_to_label}")

## Set up the model

In [None]:
from sklearn.ensemble import RandomForestClassifier
from encord_active.lib.metrics.acquisition_metrics.common import SKLearnClassificationModel

def init_and_train_model(X, y, n_estimators=500, random_state=None):
    """Train a random forest on ``(X, y)`` and wrap it for Encord Active.

    Args:
        X: Training samples, passed unchanged to ``RandomForestClassifier.fit``.
        y: Target labels matching ``X``.
        n_estimators: Number of trees in the forest. Defaults to 500, the
            value that used to be hard-coded.
        random_state: Optional seed forwarded to the classifier so runs can be
            reproduced. The default ``None`` keeps the previous unseeded
            behaviour.

    Returns:
        The fitted forest wrapped in ``SKLearnClassificationModel``.
    """
    forest = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state)
    forest.fit(X, y)
    return SKLearnClassificationModel(forest)

def get_accuracy_score(wrapped_model, X, y):
    """Score the estimator held by ``wrapped_model`` on ``(X, y)``.

    The wrapper keeps the underlying estimator in its ``_model`` attribute;
    the result is whatever that estimator's ``score(X, y)`` returns (mean
    accuracy for scikit-learn classifiers).
    """
    estimator = wrapped_model._model
    return estimator.score(X, y)

## Select the acquisition functions

The purpose of an acquisition function is to recommend which data to label next in each cycle of the active learning workflow.

Choose from those already implemented in Encord Active or write your own.

In [None]:
from encord_active.lib.metrics.acquisition_metrics.acquisition_functions import Entropy, LeastConfidence, Margin, Variance
from encord_active.lib.metrics.heuristic.random import RandomImageMetric

# Each entry pairs an acquisition function with the ordering applied to its scores:
# 'asc' (ascending) or 'desc' (descending). The ordering is passed as `rank_by`
# when the best ranked data samples are selected for labeling in the simulation.
# NOTE(review): presumably 'desc' is for metrics where a higher score means a more
# informative sample and 'asc' for metrics where a lower score does — confirm
# against the documentation of each acquisition function.
selected_acq_funcs = [
    (Entropy, "desc"),
    (LeastConfidence, "desc"),
    (Margin, "asc"),
    (Variance, "asc"),
    (RandomImageMetric, "asc"),
]

## Simulate the active learning workflow

Run simulations of the active learning workflow with each acquisition function.

![active learning cycle](https://raw.githubusercontent.com/encord-team/encord-active/main/docs/docs/images/active-learning/active-learning-cycle.svg)

In [None]:
from typing import List
import numpy as np
from PIL import Image
from collections import defaultdict
from tqdm.auto import tqdm
from encord_active.lib.metrics.execute import execute_metrics
from encord_active.lib.common.active_learning import get_data,    get_metric_results, get_n_best_ranked_data_samples

def transform_image_data(images: List["Image.Image"]) -> List[np.ndarray]:
    """Convert images into flat, scaled feature vectors.

    Each image is converted to a NumPy array, flattened to one dimension and
    divided by 255 (assumes 8-bit pixel values, which maps them to [0, 1]).

    Args:
        images: Images to convert. ``Image`` is the ``PIL.Image`` module, so the
            element type is ``Image.Image``; it is written as a forward
            reference so the annotation is not evaluated at definition time.

    Returns:
        One 1-D array per input image, in the same order.
    """
    return [np.asarray(image).flatten() / 255 for image in images]

# Accuracy per acquisition function name, keyed by iteration (0 = initial model).
accuracy_logger = defaultdict(dict)

# The test set is the same for every simulation, so load and preprocess it once.
X_test, y_test = get_data(project_fs_test, data_hashes_test, class_name)
X_test = transform_image_data(X_test)

for acq_func, rank_order in selected_acq_funcs:
    # Mockup of the initial labeling phase: every simulation starts from the
    # same labeled subset so the acquisition functions can be compared.
    labeled_data_hashes_train = data_hashes_train[:initial_data_amount]
    unlabeled_data_hashes_train = set(data_hashes_train[initial_data_amount:])

    # Train the initial model
    X, y = get_data(project_fs_train, labeled_data_hashes_train, class_name)
    X = transform_image_data(X)
    model = init_and_train_model(X, y)
    accuracy_logger[acq_func.__name__][0] = get_accuracy_score(model, X_test, y_test)

    for it in tqdm(range(1, n_iterations + 1), disable=False, desc=f"Analyzing {acq_func.__name__} performance"):
        # The random baseline takes no model; all other acquisition functions
        # are built around the model trained in the previous step.
        if acq_func is RandomImageMetric:
            acq_func_instance = acq_func()
        else:
            acq_func_instance = acq_func(model)

        # Run the acquisition function
        execute_metrics([acq_func_instance], data_dir=project_fs_train.project_dir, use_cache_only=True)

        # Get the data scores
        acq_func_results = get_metric_results(project_fs_train, acq_func_instance)

        # Select the data to label next (only among still-unlabeled samples)
        data_to_label_next, _ = get_n_best_ranked_data_samples(
            acq_func_results,
            batch_size_to_label,
            rank_by=rank_order,
            filter_by_data_hashes=unlabeled_data_hashes_train,
        )

        # Mockup of the labeling phase. The new images get the same
        # preprocessing as the initial training data before joining `X`.
        X_new, y_new = get_data(project_fs_train, data_to_label_next, class_name)
        X.extend(transform_image_data(X_new))
        y.extend(y_new)
        unlabeled_data_hashes_train.difference_update(data_to_label_next)

        # Train the model with the newly labeled data
        model = init_and_train_model(X, y)
        accuracy_logger[acq_func.__name__][it] = get_accuracy_score(model, X_test, y_test)

In [None]:
# misc: beautify function names
# Rename the logged entries (when present) to friendlier labels.
for metric_cls, display_name in (
    (RandomImageMetric, "Random"),
    (LeastConfidence, "Least Confidence"),
):
    logged_name = metric_cls.__name__
    if logged_name in accuracy_logger:
        accuracy_logger[display_name] = accuracy_logger.pop(logged_name)

## Compare and visualize the simulation results

In [None]:
import matplotlib.pyplot as plt

# Draw one accuracy curve per acquisition function.
for curve_label, accuracy_by_iteration in accuracy_logger.items():
    iterations, accuracies = zip(*accuracy_by_iteration.items())
    plt.plot(iterations, accuracies, label=curve_label)

# Axes: one tick per iteration, with 0 being the initial model.
plt.xticks(range(n_iterations + 1))
plt.xlabel("Iteration")
plt.ylabel("Model Accuracy")

plt.legend()
plt.show()