# Use Baal in production (Image classification)

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/baal-org/baal/blob/master/notebooks/baal_prod_cls.ipynb)

In this tutorial, we will show you how to use Baal during your labeling task.

**NOTE** In this tutorial, we assume that we do not know the labels!

### Install baal

```bash
pip install baal
```

We will first need a dataset! For the purpose of this demo, we will use a classification dataset, but Baal
works on more than computer vision! As long as we can estimate the uncertainty of a prediction, Baal can be used.

We will use the [Natural Images Dataset](https://www.kaggle.com/prasunroy/natural-images).
Please extract the data in `/tmp/natural_images`.


In [1]:
from glob import glob
import os
from sklearn.model_selection import train_test_split
files = glob('/tmp/natural_images/*/*.jpg')
classes = os.listdir('/tmp/natural_images')
train, test = train_test_split(files, random_state=1337)  # Split 75% train, 25% validation
print(f"Train: {len(train)}, Valid: {len(test)}, Num. classes : {len(classes)}")


Introducing `baal.active.FileDataset` and `baal.active.ActiveLearningDataset`

`FileDataset` is simply an object that loads data and implements `def label(self, idx: int, lbl: Any)`.
This method is necessary to label items in the dataset. You can set any value you want for unlabelled items;
in our example, we use -1.

`ActiveLearningDataset` is a wrapper around a `Dataset` that performs data management.
When you iterate over it, it will return labelled items only.

To learn more on dataset management, visit [this notebook](./fundamentals/active-learning.ipynb).
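
To make the `label` contract concrete, here is a minimal pure-Python sketch of a file-backed dataset. `ToyFileDataset` and `UNLABELLED` are hypothetical names written for illustration, not Baal's `FileDataset`. The sketch only mirrors the two ideas described here: unlabelled items carry -1, and `label(idx, lbl)` overwrites that placeholder.

```python
from typing import Any, List, Tuple

UNLABELLED = -1  # placeholder value, same convention as in this tutorial


class ToyFileDataset:
    """Hypothetical stand-in for a file-backed dataset (not Baal's class)."""

    def __init__(self, files: List[str], labels: List[Any]):
        assert len(files) == len(labels)
        self.files = list(files)
        self.lbls = list(labels)

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, idx: int) -> Tuple[str, Any]:
        # A real dataset would load and transform the image here.
        return self.files[idx], self.lbls[idx]

    def label(self, idx: int, lbl: Any) -> None:
        """Record the label supplied by the annotator for item `idx`."""
        self.lbls[idx] = lbl


toy = ToyFileDataset(["a/cat/1.jpg", "a/dog/2.jpg"], [UNLABELLED, UNLABELLED])
toy.label(1, 3)
n_labelled = sum(lbl != UNLABELLED for lbl in toy.lbls)
print(f"{n_labelled}/{len(toy)} items labelled")
```

Keeping the placeholder in a named constant makes it easy to count or filter unlabelled items later.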



In [2]:
from baal.active import FileDataset, ActiveLearningDataset
from torchvision import transforms

train_transform = transforms.Compose([transforms.RandomHorizontalFlip(),
                                      transforms.Resize(224),
                                      transforms.RandomCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

# We use -1 to specify that the data is unlabeled.
train_dataset = FileDataset(train, [-1] * len(train), train_transform)

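# Use a deterministic center crop at evaluation time so that metrics and
# pool predictions do not depend on a random crop.
test_transform = transforms.Compose([transforms.Resize(224),
                                     transforms.CenterCrop(224),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])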

# We use -1 to specify that the data is unlabeled.
test_dataset = FileDataset(test, [-1] * len(test), test_transform)
active_learning_ds = ActiveLearningDataset(train_dataset, pool_specifics={'transform': test_transform})



We now have two unlabeled datasets: train and validation. We wrap the training dataset in an
`ActiveLearningDataset` object, which takes care of the split between labeled and unlabeled samples.
We are now ready to use Active Learning.
We will use a technique called MC-Dropout. Baal supports other techniques (see README) and offers a similar API
for each of them.
When using MC-Dropout with Baal, you can use any model as long as it contains some Dropout layers. These layers are essential to compute
the uncertainty of the model.
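
To illustrate why Dropout layers matter, the following sketch runs a toy linear classifier several times with a fresh dropout mask on each pass. Everything here (`stochastic_forward`, the weights, the input) is made up for illustration and does not use Baal or PyTorch. The point is only that keeping dropout active at prediction time yields a distribution of outputs whose spread reflects model uncertainty.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def stochastic_forward(x, weights, p_drop, rng):
    """One forward pass of a toy linear classifier with dropout on the inputs."""
    keep = 1.0 - p_drop
    # Sample a fresh dropout mask on every call, even at prediction time.
    dropped = [xi / keep if rng.random() < keep else 0.0 for xi in x]
    logits = [sum(w * xi for w, xi in zip(row, dropped)) for row in weights]
    return softmax(logits)


rng = random.Random(0)
x = [1.0, 2.0, -1.0, 0.5]
weights = [[0.5, -0.2, 0.1, 0.3],   # class 0
           [-0.3, 0.4, 0.2, -0.1]]  # class 1

# 15 stochastic passes over the same input.
samples = [stochastic_forward(x, weights, p_drop=0.5, rng=rng) for _ in range(15)]
mean_probs = [sum(s[c] for s in samples) / len(samples) for c in range(2)]
spread = max(s[0] for s in samples) - min(s[0] for s in samples)
print("mean probabilities:", mean_probs)
print("spread of P(class 0) across passes:", spread)
```

With dropout disabled, every pass would return the same probabilities and the spread would be zero, leaving nothing for an uncertainty heuristic to measure.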

Baal provides several models, but it also supports custom models through `baal.bayesian.dropout.MCDropoutModule`.

In this example, we will use VGG-16, a popular model from `torchvision`.

In [3]:
import torch
from torch import nn, optim
from baal.modelwrapper import ModelWrapper
from torchvision.models import vgg16
from baal.bayesian.dropout import MCDropoutModule
USE_CUDA = torch.cuda.is_available()
model = vgg16(pretrained=False, num_classes=len(classes))
# This will modify all Dropout layers to be usable at test time which is
# required to perform Active Learning.
model = MCDropoutModule(model)
if USE_CUDA:
  model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)

# ModelWrapper is an object similar to keras.Model.
baal_model = ModelWrapper(model, criterion)



### Heuristics

To rank uncertainty, we will use a heuristic. For classification and segmentation, BALD is the recommended
heuristic. We will also add noise to the heuristic to lower the selection bias added by the AL process.
This is done by specifying `shuffle_prop` in the heuristic constructor.
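
As a rough illustration of what BALD measures, the sketch below computes the score by hand: the entropy of the averaged prediction minus the average entropy of the individual stochastic predictions. The function names and the `[pass][class]` layout are chosen for readability and are assumptions of this example; they are not necessarily the layout Baal uses internally.

```python
import math
from typing import List


def entropy(probs: List[float]) -> float:
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def bald_score(mc_probs: List[List[float]]) -> float:
    """mc_probs[t][c] is the probability of class c on stochastic pass t."""
    n_iter = len(mc_probs)
    n_classes = len(mc_probs[0])
    mean_probs = [sum(p[c] for p in mc_probs) / n_iter for c in range(n_classes)]
    expected_entropy = sum(entropy(p) for p in mc_probs) / n_iter
    return entropy(mean_probs) - expected_entropy


# Passes agree: score close to zero. Passes disagree confidently: high score.
agree = [[0.9, 0.1], [0.9, 0.1], [0.9, 0.1]]
disagree = [[0.95, 0.05], [0.05, 0.95], [0.9, 0.1]]
scores = [bald_score(agree), bald_score(disagree)]

# Rank pool items from most to least uncertain.
ranking = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
print(scores, ranking)
```

An item scores high only when the stochastic passes disagree with each other, which is what distinguishes BALD from plain predictive entropy.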


In [4]:
from baal.active.heuristics import BALD
heuristic = BALD(shuffle_prop=0.1)
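
One way to picture the effect of `shuffle_prop` is to take a ranking and randomly permute a fraction of its positions. The helper below is a hypothetical sketch of that idea, written for this tutorial. It is not Baal's implementation, and the exact noise scheme in the library may differ.

```python
import random
from typing import List


def shuffle_subset(ranking: List[int], shuffle_prop: float, rng: random.Random) -> List[int]:
    """Permute a random `shuffle_prop` fraction of positions; leave the rest in place."""
    n_shuffled = int(len(ranking) * shuffle_prop)
    positions = rng.sample(range(len(ranking)), n_shuffled)
    values = [ranking[p] for p in positions]
    rng.shuffle(values)
    noisy = list(ranking)
    for p, v in zip(positions, values):
        noisy[p] = v
    return noisy


ranking = list(range(20))  # position 0 holds the most uncertain item
noisy = shuffle_subset(ranking, shuffle_prop=0.1, rng=random.Random(0))
print(noisy)
```

Because only a small fraction of positions move, most of the top of the ranking is preserved while a few less uncertain items occasionally get a chance to be labeled.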


### Oracle
When the AL process requires a new item to be labeled, we need to provide an Oracle. In your case, the Oracle will
most likely be a human labeler. For this example, we're lucky: the class label is in the image path!


In [5]:
# This function would do the work that a human would do.
def get_label(img_path):
  return classes.index(img_path.split('/')[-2])
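
The path splitting above assumes `/` separators. A slightly more portable sketch, assuming the same `<root>/<class>/<file>.jpg` layout, can rely on `pathlib` instead. `label_from_path` and `example_classes` are hypothetical names used only for this illustration.

```python
from pathlib import Path
from typing import List


def label_from_path(img_path: str, class_names: List[str]) -> int:
    """Return the index of the parent directory name within `class_names`."""
    class_name = Path(img_path).parent.name
    try:
        return class_names.index(class_name)
    except ValueError:
        raise ValueError(f"Unknown class directory {class_name!r} in {img_path}") from None


example_classes = ["airplane", "cat", "dog"]
print(label_from_path("/tmp/natural_images/cat/cat_0001.jpg", example_classes))
```

Raising an explicit error for an unknown directory makes a misplaced file easier to spot than a bare `list.index` failure.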



### Labeling process
The labeling will go like this:
1. Label all the test set and some samples from the training set.
2. Train the model for a few epochs on the training set.
3. Select the top-K most uncertain samples according to the heuristic.
4. Label those samples.
5. If not done, go back to 2.
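
The steps above can be sketched as a plain Python loop before bringing in any model. In this toy version, `oracle` and `fake_uncertainty` are hypothetical stand-ins: the first plays the human annotator, the second replaces "train, predict, apply the heuristic" with random scores. Only the bookkeeping is meant to be representative.

```python
import random
from typing import Dict, List

rng = random.Random(0)
true_labels = [i % 3 for i in range(50)]   # hidden from the "model"
labelled: Dict[int, int] = {}              # dataset index -> label


def oracle(idx: int) -> int:
    """Stand-in for a human annotator."""
    return true_labels[idx]


def fake_uncertainty(pool: List[int]) -> List[float]:
    """Placeholder for 'train, predict, apply heuristic'; returns random scores."""
    return [rng.random() for _ in pool]


# 1. Label an initial random subset.
for idx in rng.sample(range(len(true_labels)), 10):
    labelled[idx] = oracle(idx)

K = 5
for step in range(3):
    # 2. (Training would happen here.)
    pool = [i for i in range(len(true_labels)) if i not in labelled]
    if not pool:
        break  # 5. Nothing left to label.
    # 3. Rank the pool and keep the top-K most uncertain items.
    scores = fake_uncertainty(pool)
    top_k = sorted(range(len(pool)), key=lambda j: scores[j], reverse=True)[:K]
    # 4. Ask the oracle for labels; pool positions map back to dataset indices.
    for j in top_k:
        labelled[pool[j]] = oracle(pool[j])
    print(f"step {step}: {len(labelled)} labelled")
```

Note that the heuristic ranks positions within the pool, so each selected position has to be mapped back to a dataset index before the label is stored.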



In [6]:
import numpy as np
# 1. Label all the test set and some samples from the training set.
for idx in range(len(test_dataset)):
  img_path = test_dataset.files[idx]
  test_dataset.label(idx, get_label(img_path))
  
# Let's label 100 training examples randomly first.
# Note: `label` expects indices relative to the pool of unlabelled items. Nothing is labelled yet,
# so pool indices and dataset indices coincide here.
train_idxs = np.random.permutation(np.arange(len(train_dataset)))[:100].tolist()
labels = [get_label(train_dataset.files[idx]) for idx in train_idxs]
active_learning_ds.label(train_idxs, labels)

print(f"Num. labeled: {len(active_learning_ds)}/{len(train_dataset)}")


In [7]:
# 2. Train the model for a few epochs on the training set.
baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)
baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)

print("Metrics:", {k:v.avg for k,v in baal_model.metrics.items()})


In [8]:
# 3. Select the top-K most uncertain samples according to the heuristic.
pool = active_learning_ds.pool
if len(pool) == 0:
  raise ValueError("We're done!")

# We make 15 MCDropout iterations to approximate the uncertainty.
predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)
# We will label the 10 most uncertain samples.
top_uncertainty = heuristic(predictions)[:10]


In [9]:
# 4. Label those samples.
oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)
labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]
print(list(zip(labels, oracle_indices)))
active_learning_ds.label(top_uncertainty, labels)
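
The heuristic returns positions within the unlabelled pool, while `train_dataset.files` is indexed over the full dataset, which is why the cell above converts the indices first. The helper below is a hypothetical, pure-Python sketch of that kind of conversion using a boolean mask. It is meant to convey the idea and is not Baal's implementation.

```python
from typing import List


def pool_to_dataset_index(labelled_mask: List[bool], pool_indices: List[int]) -> List[int]:
    """Translate positions in the unlabelled pool into positions in the full dataset."""
    unlabelled_positions = [i for i, is_lbl in enumerate(labelled_mask) if not is_lbl]
    return [unlabelled_positions[p] for p in pool_indices]


# Items 1 and 3 are already labelled, so the pool covers dataset items [0, 2, 4, 5].
mask = [False, True, False, True, False, False]
dataset_indices = pool_to_dataset_index(mask, [0, 2, 3])
print(dataset_indices)
```

Mixing the two index spaces is an easy mistake once some items are labelled, because the pool shrinks and its positions shift.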



In [None]:
# 5. If not done, go back to 2.
for step in range(5): # 5 Active Learning steps!
  # 2. Train the model for a few epochs on the training set.
  print(f"Training on {len(active_learning_ds)} items!")
  baal_model.train_on_dataset(active_learning_ds, optimizer, batch_size=16, epoch=5, use_cuda=USE_CUDA)
  baal_model.test_on_dataset(test_dataset, batch_size=16, use_cuda=USE_CUDA)

  print("Metrics:", {k:v.avg for k,v in baal_model.metrics.items()})
  
  # 3. Select the top-K most uncertain samples according to the heuristic.
  pool = active_learning_ds.pool
  if len(pool) == 0:
    print("We're done!")
    break
  predictions = baal_model.predict_on_dataset(pool, batch_size=16, iterations=15, use_cuda=USE_CUDA, verbose=False)
  top_uncertainty = heuristic(predictions)[:10]
  # 4. Label those samples.
  oracle_indices = active_learning_ds._pool_to_oracle_index(top_uncertainty)
  labels = [get_label(train_dataset.files[idx]) for idx in oracle_indices]
  active_learning_ds.label(top_uncertainty, labels)
  
  

And we're done!
Be sure to save the dataset and the model.


In [11]:
torch.save({
  'active_dataset': active_learning_ds.state_dict(),
  'model': baal_model.state_dict(),
  'metrics': {k:v.avg for k,v in baal_model.metrics.items()}
}, '/tmp/baal_output.pth')


## Support
Submit an issue or reach out to us on Slack!