# Iteratively Find Improved Consensus Labels for Multiannotator Data using Cleanlab

This example shows how to improve consensus labels by combining the CROWDLAB algorithm with iterative model retraining. For an introductory tutorial on finding consensus labels with the multiannotator module, see [Find Best Consensus Labels for Multiannotator Data using Cleanlab](linke to hui wen's).

The following code uses the [cifar10h](linke cifar10h) multiannotator dataset, in which 2751 annotators each labeled 200 examples, together covering all 10,000 test images of the original [cifar10](link cifar10) dataset. However, **any multiannotator image classification dataset should work with the code below**.

## 1. Install and import required dependencies, build example folder

In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from datetime import datetime

from utils.model_training import train_model
from utils.model_training import sum_xval_folds
from cleanlab.multiannotator import get_majority_vote_label
from cleanlab.multiannotator import get_label_quality_multiannotator

In [None]:
now = datetime.now() # Current date and time
experiment_path = "./experiment_" + str(int(now.timestamp()))

if not os.path.exists(experiment_path):
    os.makedirs(experiment_path)
    print(f"Directory {experiment_path} created")
else:
    print(f"Directory {experiment_path} already exists")

print(f"Experiment saved in {experiment_path}")

In [None]:
# Download the cifar10 test images used for model training (or generate them yourself with: cifar2png cifar10 ./data/cifar10 --name-with-batch-index)
!wget -nc 'https://cleanlab-public.s3.amazonaws.com/Multiannotator/cifar-10/cifar10_test.tar.gz'
!tar -xzf cifar10_test.tar.gz

# Download the precomputed cifar10h multiannotator labels and image paths
!cd $experiment_path && wget -nc 'https://cleanlab-public.s3.amazonaws.com/Multiannotator/cifar-10h/cifar-10h-worst25-coin20/c10h_labels_worst25_coin20.npy'
!cd $experiment_path && wget -nc 'https://cleanlab-public.s3.amazonaws.com/Multiannotator/cifar-10h/cifar-10h-worst25-coin20/c10h_image_paths.npy'

## 2. Load multiannotator labels and generate consensus labels for them
`multiannotator_labels` for this example is a precomputed subset of the original `cifar10h` annotator labels. The subset starts with the 25 worst annotators and then incrementally adds annotators, from worst to best, whenever they share annotations with annotators already in the subset, until each of the 10,000 examples has at least one annotation.

We use this subset because `cifar10` is unusually easy to label, so annotator agreement on the original dataset is too high for our method to contribute a meaningful improvement. Additionally, in practice it is rare to have 50 annotators label a single example; this subset yields significantly sparser annotations.

In [None]:
# Load labels
multiannotator_labels = np.load(f'{experiment_path}/c10h_labels_worst25_coin20.npy')

# Load image paths and prepend the current working directory so they resolve on this machine
image_paths = np.load(f'{experiment_path}/c10h_image_paths.npy', allow_pickle=True)
path = os.getcwd()
image_paths = [f"{path}/{image_path}" for image_path in image_paths]
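To check how sparse a label matrix like the one described above is, you can count the non-missing entries per row and per column. The following is a minimal sketch on a small made-up array (`toy_labels` is a hypothetical stand-in, not the real `multiannotator_labels`); the same three lines should apply to any array that uses `np.nan` for missing annotations.

```python
import numpy as np

# Toy stand-in for `multiannotator_labels`: 4 examples x 3 annotators,
# with np.nan where an annotator did not label an example.
toy_labels = np.array([
    [0.0, np.nan, 0.0],
    [np.nan, 1.0, np.nan],
    [2.0, 2.0, 1.0],
    [np.nan, np.nan, 3.0],
])

annotated = ~np.isnan(toy_labels)
annotations_per_example = annotated.sum(axis=1)    # labels received by each example
annotations_per_annotator = annotated.sum(axis=0)  # examples labeled by each annotator

print("Annotations per example:", annotations_per_example)
print("Annotations per annotator:", annotations_per_annotator)
print("Fraction of matrix filled:", annotated.mean())
```

A low fill fraction and few annotations per example indicate the sparse regime this example targets.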

<div class="alert alert-info">
Bringing Your Own Data (BYOD)?

You can easily replace the above with your own multiannotator dataset, and continue with the rest of the example.

`multiannotator_labels` should be a numpy array or pandas DataFrame with each column representing an annotator and each row representing an example. Your classes (and entries of `multiannotator_labels`) should be represented as integer indices 0, 1, ..., num_classes - 1, where examples that are not annotated by a particular annotator are represented using `np.nan`.
    
If working with images, `image_paths` should be a list of absolute or relative path strings, where each index corresponds to the example in that row of `multiannotator_labels`.

If working with other data, `image_paths` should be a list of examples, where each index corresponds to that row of `multiannotator_labels`.

</div>
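If your annotations are stored as one record per (example, annotator) pair, they need to be reshaped into the examples-by-annotators layout described above. Here is a minimal sketch, assuming a hypothetical `records` list of `(example_index, annotator_index, class_index)` tuples; the names `records` and `byod_labels` are illustrative only.

```python
import numpy as np

# Hypothetical raw annotation records: (example_index, annotator_index, class_index)
records = [
    (0, 0, 2),
    (0, 1, 2),
    (1, 1, 0),
    (2, 0, 1),
    (2, 2, 1),
]

num_examples = 1 + max(r[0] for r in records)
num_annotators = 1 + max(r[1] for r in records)

# Start fully unannotated (np.nan), then fill in the labels that exist
byod_labels = np.full((num_examples, num_annotators), np.nan)
for example_idx, annotator_idx, class_idx in records:
    byod_labels[example_idx, annotator_idx] = class_idx

print(byod_labels)
```

The resulting array has one row per example and one column per annotator, with `np.nan` wherever an annotator did not label an example.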


Before training our machine learning model, we must first obtain consensus labels from the annotators that labeled the data. The simplest way to obtain an initial set of consensus labels is majority vote.

In [None]:
consensus_labels = get_majority_vote_label(multiannotator_labels)
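For intuition, majority vote can be sketched in a few lines of NumPy. This is a simplified illustration, not cleanlab's implementation: the function name `simple_majority_vote` is hypothetical, it breaks ties by picking the lowest class index, and cleanlab's `get_majority_vote_label` may handle ties differently.

```python
import numpy as np

def simple_majority_vote(labels, num_classes):
    """Return the most frequently chosen class per row, ignoring np.nan entries.

    Simplified illustration: ties go to the lowest class index.
    """
    labels = np.asarray(labels, dtype=float)
    votes = np.zeros((labels.shape[0], num_classes), dtype=int)
    for k in range(num_classes):
        # NaN never compares equal to k, so missing annotations are skipped
        votes[:, k] = (labels == k).sum(axis=1)
    return votes.argmax(axis=1)

toy_labels = np.array([
    [0, 0, 1],
    [np.nan, 2, 2],
    [1, np.nan, np.nan],
])
toy_consensus = simple_majority_vote(toy_labels, num_classes=3)
print(toy_consensus)  # [0 2 1]
```

Note that an example with a single annotation simply inherits that annotator's label, which is one reason sparse annotations make majority vote unreliable.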

## 3. Train model and use cleanlab to get better consensus labels

Next, we train our model on the consensus labels obtained by majority vote and compute out-of-sample predicted probabilities. We then use these `pred_probs` to generate more informed `consensus_labels` with Cleanlab's [CROWDLAB](link to something) algorithm. We then use these `consensus_labels` to train a better model that produces more accurate `pred_probs`. This process iterates until the `consensus_labels` stop improving.

`train_model()` trains a `resnet18` image model using cross validation to get out-of-sample predicted probabilities on the whole dataset. The function can be replaced with a custom training algorithm.
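If you swap in your own training routine, the key requirement is that each example's predicted probabilities come from a model that did not see that example during training. The sketch below illustrates this with K-fold cross-validation and a toy nearest-centroid classifier on made-up 2D features. Everything in it (`out_of_sample_pred_probs`, the centroid model, the toy data) is a hypothetical stand-in, not part of `utils.model_training` or cleanlab.

```python
import numpy as np

def out_of_sample_pred_probs(features, labels, num_classes, num_folds=5, seed=0):
    """K-fold cross-validation with a toy nearest-centroid classifier."""
    rng = np.random.default_rng(seed)
    n = len(labels)
    folds = np.array_split(rng.permutation(n), num_folds)
    pred_probs = np.zeros((n, num_classes))
    for holdout in folds:
        train = np.setdiff1d(np.arange(n), holdout)
        # "Fit": one centroid per class, computed from training rows only
        centroids = np.zeros((num_classes, features.shape[1]))
        seen = np.zeros(num_classes, dtype=bool)
        for k in range(num_classes):
            members = train[labels[train] == k]
            if len(members) > 0:
                centroids[k] = features[members].mean(axis=0)
                seen[k] = True
        # "Predict": softmax over negative squared distances to each centroid;
        # classes absent from the training fold get probability 0
        dists = ((features[holdout, None, :] - centroids[None, :, :]) ** 2).sum(axis=2)
        scores = np.where(seen, -dists, -np.inf)
        scores = scores - scores.max(axis=1, keepdims=True)
        probs = np.exp(scores)
        pred_probs[holdout] = probs / probs.sum(axis=1, keepdims=True)
    return pred_probs

# Toy data: three well-separated clusters, one per class
rng = np.random.default_rng(0)
toy_labels = np.repeat([0, 1, 2], 20)
toy_features = rng.normal(size=(60, 2)) + 5.0 * toy_labels[:, None]

toy_pred_probs = out_of_sample_pred_probs(toy_features, toy_labels, num_classes=3)
print(toy_pred_probs.shape)
```

Whatever model you use, the output should be an array with one row per example and one column per class, with rows in the same order as `multiannotator_labels`.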

In [None]:
model_type = "resnet18" # You can also try with "swin_base_patch4_window7_224"

# Load model arguments
train_args = {
    "num_cv_folds": 5, 
    "verbose": 1, 
    "epochs": 1, 
    "holdout_frac": 0.2, 
    "time_limit": 21600, 
    "random_state": 123
}

In [None]:
# Repeatedly retrain the model on improved consensus labels and save the results of each iteration
indices_changed = set()
seen_consensus_labels = list()
model_results = {}
itter = 0

while tuple(consensus_labels) not in seen_consensus_labels:
    seen_consensus_labels.append(tuple(consensus_labels)) # record the current consensus labels in the history
    model_results['itter'] = itter
    model_xval_results_folder = f'{experiment_path}/xval_results_itter{itter}' # + [model_type]

    # Zip consensus labels with their corresponding image_paths
    consensus_data = pd.DataFrame(zip(image_paths,consensus_labels), columns=["image", "label"])
    
    # Train model
    train_model(model_type, consensus_data, model_xval_results_folder, **train_args)
    pred_probs, labels, images = sum_xval_folds(model_type, model_xval_results_folder, **train_args)
    
    # Get improved consensus labels with label quality multiannotator using model pred probs
    label_quality_multiannotator = get_label_quality_multiannotator(multiannotator_labels, pred_probs, verbose=False)
    consensus_labels = label_quality_multiannotator["label_quality"]["consensus_label"].tolist()
    
    # Compare the new consensus labels with those from the previous iteration
    labels_changed = np.array(consensus_labels) != np.array(seen_consensus_labels[-1])
    label_changes_from_prior = np.sum(labels_changed)

    # Count indices that changed for the first time in this iteration
    num_indices_before = len(indices_changed)
    indices_changed.update(np.where(labels_changed)[0].tolist())
    unique_indices = len(indices_changed) - num_indices_before

    print("Num changes in consensus labels from previous iteration: ", label_changes_from_prior)
    print("Num unique indices changed: ", unique_indices)
    
    results = {
        "pred_probs": pred_probs,
        "consensus_labels_in": labels, # consensus labels used to train the model
        "images": images, 
        "consensus_labels_out": consensus_labels, # new consensus labels generated from pred_probs
        "label_changes_from_prior": label_changes_from_prior, # number of consensus labels changed since the previous iteration
        "unique_indices_added": unique_indices, # number of indices whose label changed for the first time in this iteration
    }
    
    model_results[itter] = results
    itter+=1
    
    if unique_indices == 0: # no more label improvement
        break
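The loop above stops once the consensus labels return to a state that has already been seen (or no new index changes). The "already seen" check can be isolated as in the sketch below. The helper `iterate_until_stable`, the `max_iters` safety cap, and the toy update rule are all hypothetical additions for illustration; in the notebook, the update step is the train-then-CROWDLAB sequence.

```python
def iterate_until_stable(initial_labels, update_fn, max_iters=20):
    """Apply update_fn repeatedly until the labels repeat a previously seen state."""
    seen = set()     # a set gives fast membership checks on label tuples
    history = []     # labels used as input at each iteration, in order
    labels = tuple(initial_labels)
    while labels not in seen and len(history) < max_iters:
        seen.add(labels)
        history.append(labels)
        labels = tuple(update_fn(labels))
    return labels, history

target = (1, 1, 2)  # hypothetical "ideal" labels for the toy update rule

def toy_update(labels):
    """Hypothetical update: correct the first label that disagrees with target."""
    labels = list(labels)
    for i, (current, wanted) in enumerate(zip(labels, target)):
        if current != wanted:
            labels[i] = wanted
            break
    return labels

final_labels, history = iterate_until_stable((0, 0, 0), toy_update)
print(final_labels)   # (1, 1, 2)
print(len(history))   # 4
```

Checking against every earlier state, not just the previous one, also guards against the labels cycling between a few configurations without ever settling.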

## 4. Measure consensus label accuracy and model performance

Since our annotators labeled `cifar10`, a dataset for which highly accurate `true_labels` exist, we can measure and report the accuracy of our method against the ground truth.

In a real multiannotator setting this would not be possible, since ground truth labels would not exist. Instead, we have [benchmarked](link paper or blogpost maybe) this method to verify that it produces more accurate `consensus_labels` than common industry methods for obtaining a consensus.

In [None]:
# Download and load ground truth labels (used only for metrics; normally this information is not available)
!cd $experiment_path && wget -nc 'https://cleanlab-public.s3.amazonaws.com/Multiannotator/cifar-10h/cifar-10h-worst25-coin20/c10h_test_labels.npy'

true_labels = np.load(f'{experiment_path}/c10h_test_labels.npy')

In [None]:
# Calculate CROWDLAB performance against the ground truth for each iteration
for i in range(itter):
    pred_probs = model_results[i]['pred_probs']
    consensus_labels_in = model_results[i]['consensus_labels_in']
    consensus_labels_out = model_results[i]['consensus_labels_out']

    acc_model_gtruth = (pred_probs.argmax(axis=1) == true_labels).mean()
    acc_consensus_gtruth = (consensus_labels_in == true_labels).mean()

    results = {
        "consensus_gtruth_accuracy": acc_consensus_gtruth, # consensus labels accuracy 
        "model_gtruth_accuracy": acc_model_gtruth,         # model label accuracy
    }
    
    model_results[i].update(results)

# Calculate accuracy of final CROWDLAB generated consensus label
final_consensus_labels = model_results[itter-1]["consensus_labels_out"]
final_consensus_accuracy = (final_consensus_labels == true_labels).mean()

print("Final consensus label accuracy (vs ground truth): ", final_consensus_accuracy)
print("Final model predictions accuracy (vs ground truth): ", model_results[itter-1]["model_gtruth_accuracy"])

Finally, let's plot the results over multiple iterations and observe the improvement. Here the x-axis value -1 corresponds to the initial `consensus_labels` generated by majority vote without the help of a model. Majority vote is the most common way to generate consensus labels and therefore makes a good baseline to compare CROWDLAB against.

As you can see, the accuracy of the `consensus_labels` improves over several iterations of model training.

In [None]:
consensus_gtruth_accuracy = [model_results[i]["consensus_gtruth_accuracy"] 
                             for i in range(itter)] + [final_consensus_accuracy]

model_gtruth_accuracy = [model_results[i]["model_gtruth_accuracy"] for i in range(itter)]


# Plot consensus label accuracy and model accuracy against the ground truth
plt.rcParams["figure.figsize"] = (17,6)

plt.subplot(1, 2, 1)
plt.plot(range(-1, itter),consensus_gtruth_accuracy)
plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Consensus Label Accuracy", fontsize=14)
plt.title("Consensus Label vs Ground Truth Accuracy", fontsize=14, fontweight="bold")

plt.subplot(1, 2, 2)
plt.plot(range(0, itter), model_gtruth_accuracy)
plt.xlabel("Iteration", fontsize=14)
plt.ylabel("Model Prediction Accuracy", fontsize=14)
plt.title("Model Predictions vs Ground Truth Accuracy", fontsize=14, fontweight="bold")

plt.show()

In [None]:
if consensus_gtruth_accuracy[0] >= consensus_gtruth_accuracy[-1]:  # check that the final consensus labels beat the majority-vote baseline
    raise Exception("Cleanlab failed to improve baseline consensus label accuracy.")