# $\texttt{dopanim}$: Annotation Times in Active Learning

In [None]:
import numpy as np
import os
import torch
import pylab as plt
import sys

# TODO: Append the path to your `multi-annotator-machine-learning` project.
sys.path.append("../../")

from maml.data import Dopanim
from maml.data import SSLDatasetWrapper
from skactiveml.pool import UncertaintySampling, RandomSampling
from skactiveml.utils import MISSING_LABEL
from skactiveml.classifier import SklearnClassifier
from sklearn.linear_model import LogisticRegression
from tqdm.auto import tqdm

# TODO: Adjust data path to your dataset.
# This path is passed as the first argument to `Dopanim` when the training and test splits are loaded.
DATA_PATH = "."

# TODO: Adjust flag for downloading the dataset.
# The flag is forwarded unchanged as `download=DOWNLOAD` to `Dopanim` for both splits.
DOWNLOAD = False

# TODO: Adjust path for saving figures.
# Directory in which the figure `annotation_times_active_learning.pdf` is saved.
FIGURE_PATH = "."

### Load Training and Test Data

In [None]:
# Self-supervised backbone: the DINOv2 model is fetched through `torch.hub`.
repo_or_dir = "facebookresearch/dinov2"
model = "dinov2_vits14"
ssl_model = torch.hub.load(repo_or_dir=repo_or_dir, model=model)

# The test split is created first because its transform is reused for the training split.
test_ds = Dopanim(DATA_PATH, version='test', variant='full', download=DOWNLOAD)
train_ds = Dopanim(
    DATA_PATH,
    version='train',
    variant='full',
    transform=test_ds.transform,
    download=DOWNLOAD,
)

# Annotation times as averages per sample: entries equal to -1 are replaced by NaN so that
# `np.nanmean` skips them when averaging over the last axis.
# NOTE(review): -1 presumably marks annotators without an annotation for a sample -- verify
# against `Dopanim.load_annotation_times`.
train_times = train_ds.load_annotation_times(train_ds.observation_ids, train_ds.annotators)
train_times[train_times == -1] = np.nan
train_times = np.nanmean(train_times, axis=-1)

# Wrap both splits with the SSL model and enable caching.
train_ds = SSLDatasetWrapper(dataset=train_ds, model=ssl_model, cache=True)
test_ds = SSLDatasetWrapper(dataset=test_ds, model=ssl_model, cache=True)


def _to_numpy_arrays(dataset):
    """Stack the 'x' and 'y' entries of all samples of `dataset` into two numpy arrays."""
    # A single pass over the dataset collects both entries of every sample.
    samples = [(sample['x'].numpy(), sample['y'].numpy()) for sample in dataset]
    features = np.array([x for x, _ in samples])
    targets = np.array([y for _, y in samples])
    return features, targets


# Numpy arrays of the training and test datasets for usage with `scikit-activeml`.
X_train, y_train = _to_numpy_arrays(train_ds)
X_test, y_test = _to_numpy_arrays(test_ds)

### Define Helper Functions for Active Learning

In [None]:
def al_cycle(al_strat: str="random", num_init: int=10, num_acq: int=19, batch_size: int=10, seed: int=42):
    """
    Helper function for performing an active learning cycle.

    The labels of `num_init` randomly drawn training samples are revealed first. Afterwards, `num_acq`
    acquisition cycles follow, each revealing the labels of the `batch_size` samples returned by the query
    strategy. A fresh logistic regression classifier is fitted after the initialization and after every
    acquisition cycle. The function reads the module-level arrays `X_train`, `y_train`, `X_test`, `y_test`,
    and `train_times`.

    Parameters
    ----------
    al_strat : 'random' or 'margin'
        Name of the active learning strategy.
    num_init : int
        Number of initially labelled samples.
    num_acq : int
        Number of label acquisition cycles.
    batch_size : int
        Number of label acquisitions per cycle.
    seed : int
        Seed to ensure reproducibility.

    Returns
    -------
    learning_curve : list
        Learning curve as a list of `num_acq + 1` dictionaries: one for the initial labelled set and one per
        acquisition cycle. Each dictionary contains 'num_samples' (number of currently labelled samples),
        'anno_time' (mean of `train_times` over the initial samples for the first entry and over the most
        recently queried batch afterwards), and 'test_acc' (result of `clf.score(X_test, y_test)`).

    Raises
    ------
    NotImplementedError
        If `al_strat` is neither 'random' nor 'margin'.
    """
    np.random.seed(seed)
    y_pool = np.full(shape=y_train.shape, fill_value=MISSING_LABEL)

    # Passing the population size directly draws the same indices as passing `range(len(X_train))`.
    init_indices = np.random.choice(len(X_train), size=num_init, replace=False)
    y_pool[init_indices] = y_train[init_indices]

    # The strategy name is validated exactly once here, so no further check is needed inside the loop.
    if al_strat == 'random':
        qs = RandomSampling(random_state=seed)
    elif al_strat == 'margin':
        qs = UncertaintySampling(random_state=seed, method='margin_sampling')
    else:
        raise NotImplementedError(
            f"Unknown active learning strategy {al_strat!r}; expected 'random' or 'margin'."
        )

    # The set of classes does not change between cycles, so it is computed only once.
    classes = np.unique(y_train)

    learning_curve = []
    clf = None
    query_idx = None
    for i_acq in tqdm(range(num_acq + 1)):
        if i_acq != 0:
            # Only margin sampling receives the classifier fitted in the previous iteration.
            query_kwargs = {'clf': clf} if al_strat == 'margin' else {}
            query_idx = qs.query(X=X_train, y=y_pool, batch_size=batch_size, **query_kwargs)
            y_pool[query_idx] = y_train[query_idx]

        # A new classifier is fitted from scratch on all labels revealed so far.
        clf = SklearnClassifier(
            LogisticRegression(random_state=seed, max_iter=3000),
            classes=classes,
            random_state=seed,
        )
        clf.fit(X_train, y_pool)

        labelled_indices = np.where(~np.isnan(y_pool))[0]
        # Before the first query, the annotation time refers to the initial samples; afterwards it refers
        # to the batch acquired in the current cycle.
        recent_indices = labelled_indices if query_idx is None else query_idx
        learning_curve.append({
            'num_samples': len(labelled_indices),
            'anno_time': np.mean(train_times[recent_indices]),
            'test_acc': clf.score(X_test, y_test),
        })
    return learning_curve

def avg_lcs(learning_curve: list):
    """
    Computes averages and standard deviations of the learning curves.

    The statistics are computed per cycle and per key across all given learning curves. The number of
    cycles and the keys are taken from the first learning curve.

    Parameters
    ----------
    learning_curve : list
        List of learning curves, where each learning curve is itself a list of dictionaries (one
        dictionary per cycle) mapping keys to numeric values.

    Returns
    -------
    avg_lc : list
        Averages of the given learning curves as a list of dictionaries (one per cycle).
    std_lc : list
        Standard deviations of the given learning curves as a list of dictionaries (one per cycle).
        Both lists are empty if no learning curve is given.
    """
    # Without any learning curve there is nothing to aggregate.
    if not learning_curve:
        return [], []

    # One bucket per cycle, collecting the values of every key across all learning curves.
    collected = [{key: [] for key in cycle} for cycle in learning_curve[0]]
    for lc in learning_curve:
        for cycle, bucket in zip(lc, collected):
            for key, value in cycle.items():
                bucket[key].append(value)

    avg_lc = [{key: np.mean(values) for key, values in bucket.items()} for bucket in collected]
    std_lc = [{key: np.std(values) for key, values in bucket.items()} for bucket in collected]
    return avg_lc, std_lc

### Perform Active Learning

In [None]:
# Active learning setup shared by both strategies.
num_init = 100
batch_size = 100
num_acq = 19
al_kwargs = dict(num_init=num_init, num_acq=num_acq, batch_size=batch_size)

# Run both strategies once per seed; the seed is printed to track the progress.
lcs_random, lcs_margin = [], []
for seed in range(50):
    print(seed)
    lcs_random.append(al_cycle('random', seed=seed, **al_kwargs))
    lcs_margin.append(al_cycle('margin', seed=seed, **al_kwargs))

# Aggregate the learning curves of each strategy over all seeds.
avg_lc_random, avg_lc_random_std = avg_lcs(lcs_random)
avg_lc_margin, avg_lc_margin_std = avg_lcs(lcs_margin)

### Plot Active Learning Results

In [None]:
def _plot_mean_std(avg_lc, std_lc, metric, label, color):
    """Plot the mean of `metric` over the number of samples with a band of one standard deviation."""
    num_samples = [d['num_samples'] for d in avg_lc]
    means = [d[metric] for d in avg_lc]
    stds = [d[metric] for d in std_lc]
    plt.plot(num_samples, means, label=label, color=color)
    plt.fill_between(num_samples, np.subtract(means, stds), np.add(means, stds), alpha=.3, color=color)


# Label, color, averaged learning curve, and standard deviations of each strategy in plotting order.
strategy_curves = (
    ('Random', "#e580ffff", avg_lc_random, avg_lc_random_std),
    ('Margin', "#5fd3bcff", avg_lc_margin, avg_lc_margin_std),
)

plt.figure(figsize=(16, 5))

# Left panel: 'test_acc' over the number of labelled samples.
plt.subplot(121)
for strat_label, strat_color, strat_avg, strat_std in strategy_curves:
    _plot_mean_std(strat_avg, strat_std, 'test_acc', strat_label, strat_color)
plt.xticks(np.arange(0, 2250, step=500))
plt.yticks(np.arange(0.55, 0.95, step=0.1))
plt.legend()

# Right panel: 'anno_time' over the number of labelled samples.
plt.subplot(122)
for strat_label, strat_color, strat_avg, strat_std in strategy_curves:
    _plot_mean_std(strat_avg, strat_std, 'anno_time', strat_label, strat_color)
plt.xticks(np.arange(0, 2250, step=500))
plt.yticks([7, 7.5, 8.0])

plt.savefig(os.path.join(FIGURE_PATH, "annotation_times_active_learning.pdf"))