# OLX case study: Binary classification

In the second notebook, we develop different approaches for binary image classification. The goal is to distinguish photos with clocks from photos with other objects. Our solutions can be categorized as follows:
1. Postprocessing the output of ImageNet-pretrained models in two different ways.
2. Transfer learning with pretrained feature extractors and the available OLX data.

All models are implemented in PyTorch. This is an arbitrary choice; the same could certainly be done in TensorFlow, for example. Since I did not have access to a GPU during the experiments, GPU support is currently disabled.

We now implement and train the models. Afterwards, we investigate the quality of their binary decisions as well as their confidence scores (where available). The model that performs best in that regard is saved for future use.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
sys.path.append('..')

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models
from sklearn.metrics import (
    confusion_matrix, roc_curve, 
    precision_recall_curve, auc
)

from utils import (
    DataImport, BinarySet, SummedProbabilities,
    BalancedSampler, ClassifierTraining,
    analyze_predictions, predict_loader
)

## Data import

Let us import the data again.

In [None]:
data_dir = '../data' # data directory

guess_label = lambda file_name:'clock' if '_' not in file_name else 'other' # label from file name

data = DataImport(data_dir, guess_label)
data.print_summary()

## Preprocessing

We then define some small preprocessing pipelines. They determine how the images are processed before being fed into a model, e.g. by resizing or center-cropping them to 224 × 224 pixels and by standardizing each color channel with the ImageNet mean and standard deviation. This step is important because the pretrained models expect inputs that are preprocessed in the same way as their training data.
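Per channel, `transforms.Normalize(MEAN, STD)` computes `(x - mean) / std` on the tensor produced by `ToTensor()`, whose values lie in [0, 1]. The NumPy sketch below shows this step and its inverse, which is what the plotting cell further down applies before displaying images. The function names are hypothetical, and the image is assumed to be in height × width × channel layout.

```python
import numpy as np

# ImageNet channel statistics, as used in the notebook
MEAN = np.array([0.485, 0.456, 0.406])
STD = np.array([0.229, 0.224, 0.225])

def normalize(image):
    """Standardize an HxWx3 image with values in [0, 1], channel by channel."""
    return (image - MEAN) / STD  # broadcasts over the last (channel) axis

def denormalize(image):
    """Undo normalize() and clip to the valid display range [0, 1]."""
    return np.clip(image * STD + MEAN, 0.0, 1.0)

rng = np.random.default_rng(0)
image = rng.random((4, 4, 3))  # toy 4x4 RGB image
restored = denormalize(normalize(image))
print(np.allclose(restored, image))  # True
```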

In [None]:
SHAPE = (224, 224)
MEAN = (0.485, 0.456, 0.406) # ImageNet channel means
STD = (0.229, 0.224, 0.225) # ImageNet channel standard deviations

transform = {
    # resized images
    'resize': transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(SHAPE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]),
    # cropped images
    'crop': transforms.Compose([
        transforms.ToPILImage(),
        transforms.CenterCrop(SHAPE),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]),
    # full-sized images
    'full': transforms.Compose([
        transforms.ToPILImage(),
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ])
}

In [None]:
resize_set = BinarySet(data, target='clock', transform=transform['resize'])
crop_set = BinarySet(data, target='clock', transform=transform['crop'])
full_set = BinarySet(data, target='clock', transform=transform['full'])

In [None]:
batch_size = 32 # number of samples per mini-batch

resize_loader = DataLoader(resize_set, batch_size=batch_size, shuffle=True)
crop_loader = DataLoader(crop_set, batch_size=batch_size, shuffle=True)
full_loader = DataLoader(full_set, batch_size=1, shuffle=True) # one image per batch, since full-sized images differ in shape and cannot be stacked

For instance, we can have a look at some of the resized images.

In [None]:
images, labels = next(iter(resize_loader)) # draw one mini-batch
plot_size = (3, 4)
plot_ids = np.random.choice(np.arange(len(images)), size=np.prod(plot_size), replace=False)

fig, axes = plt.subplots(nrows=plot_size[0], ncols=plot_size[1], figsize=(8, 6))
for idx, ax in enumerate(axes.ravel().tolist()):
    image = np.clip(images[plot_ids[idx]].numpy().transpose(1,2,0) * STD + MEAN, 0, 1)
    label = labels[plot_ids[idx]].numpy()
    ax.imshow(image)
    ax.set_title(label)
    ax.set(xticks=[], yticks=[])
fig.tight_layout()

## Pretrained models

In the following, we realize two straightforward ways of classifying images with and without clocks. Both are based on postprocessing the output of ImageNet-trained models. The 1000 standard ImageNet classes include several types of clocks and watches. This suggests two possibilities:
- First, one might simply check whether a clock-related class is among the top k predictions of the pretrained model.
- Second, one could sum up the predicted probabilities of all relevant classes in order to obtain a continuous rather than a binary estimate.
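Both decision rules can be sketched on a single vector of class probabilities, such as `torch.softmax(logits, dim=1)` for one image. The helper names below are hypothetical and not the functions from `utils`; the toy example uses six classes instead of 1000.

```python
import numpy as np

def topk_contains_target(probs, target_ids, k):
    """Rule 1: True if any target class is among the k most probable classes."""
    top_k = np.argsort(probs)[::-1][:k]  # class indices, most probable first
    return bool(np.isin(top_k, target_ids).any())

def summed_target_probability(probs, target_ids):
    """Rule 2: total probability mass assigned to the target classes."""
    return float(np.sum(probs[list(target_ids)]))

# toy softmax output over 6 classes; classes 3 and 5 play the role of clocks
probs = np.array([0.4, 0.25, 0.15, 0.1, 0.06, 0.04])
target_ids = (3, 5)
print(topk_contains_target(probs, target_ids, k=2))  # False
print(topk_contains_target(probs, target_ids, k=4))  # True
print(round(summed_target_probability(probs, target_ids), 2))  # 0.14
```

Rule 2 becomes a binary decision as well once the summed probability is compared against a threshold.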

To start with, we import a pretrained model. We opt for AlexNet as the simplest choice. Of course, one might also experiment with more complex models such as VGG16 or ResNet-101 (commented out below), which might yield better performance.

In [None]:
pretrained_model = models.alexnet(weights=models.AlexNet_Weights.DEFAULT)
# pretrained_model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
# pretrained_model = models.resnet101(weights=models.ResNet101_Weights.DEFAULT)

for param in pretrained_model.parameters():
    param.requires_grad = False # freeze model parameters

On that basis, we construct a model that accumulates the probability of being any type of clock. To that end, one simply adds up the relevant class probabilities predicted by the pretrained model. Unlike TensorFlow/Keras, PyTorch unfortunately has no `decode_predictions` function, so the relevant class indices have to be selected manually.
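Although there is no `decode_predictions`, torchvision's weight enums carry the ImageNet class names in `weights.meta["categories"]`. A small keyword search over that list, sketched below with a hypothetical helper that is not part of `utils`, offers one way to double-check the hard-coded `clock_ids`. Keyword matching may return unintended classes, so the result should be inspected by hand.

```python
def find_class_ids(categories, keywords):
    """Return (index, name) pairs whose name contains any of the keywords."""
    keywords = [kw.lower() for kw in keywords]
    return [
        (idx, name)
        for idx, name in enumerate(categories)
        if any(kw in name.lower() for kw in keywords)
    ]

# toy excerpt standing in for the 1000 ImageNet category names
categories = ["apiary", "analog clock", "wallet", "stopwatch", "digital watch"]
print(find_class_ids(categories, ["clock", "watch"]))
# [(1, 'analog clock'), (3, 'stopwatch'), (4, 'digital watch')]
```

With torchvision installed, the real list would be passed as `find_class_ids(models.AlexNet_Weights.DEFAULT.meta["categories"], ["clock", "watch"])`.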

In [None]:
clock_ids = (409, 530, 531, 826, 892) # analog clock, digital clock, digital watch, stopwatch, wall clock
clock_model = SummedProbabilities(pretrained_model, clock_ids)

We now evaluate the first model on the full-sized images. For a given number k of top predicted classes to consider, we compute the confusion matrix, accuracy, precision and recall. Although the required k is surprisingly high, the classification performance of this simple approach is "ok-ish". A more detailed analysis could certainly be done.

In [None]:
k = 70 # top k classes to include

summary = analyze_predictions(
    pretrained_model,
    full_loader,
    k=k,
    target_ids=clock_ids
)

print('Confusion matrix:\n', summary['confusion'])
print('Accuracy: {:.2f}'.format(summary['accuracy']))
print('Precision: {:.2f}'.format(summary['precision']))
print('Recall: {:.2f}'.format(summary['recall']))
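The printed measures follow directly from the 2 × 2 confusion matrix. Assuming `analyze_predictions` uses scikit-learn's layout (rows are true labels, columns are predictions, ordered other/clock, i.e. `[[TN, FP], [FN, TP]]`), the computation can be sketched as follows; the helper name is hypothetical.

```python
import numpy as np

def binary_metrics(cm):
    """Accuracy, precision and recall from a confusion matrix [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = np.asarray(cm)
    accuracy = (tp + tn) / (tn + fp + fn + tp)
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0  # share of predicted clocks that are clocks
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0  # share of clocks that are found
    return float(accuracy), float(precision), float(recall)

# toy confusion matrix, not the OLX results
print(binary_metrics([[50, 10], [5, 35]]))
```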

For the second model, which accumulates the probability of being a clock over all relevant ImageNet classes, the evaluation proceeds analogously. A surprisingly low threshold of 0.007 yields an "acceptable" classifier. Its performance on the full-sized images is comparable to that of the first model. We observe similar results on the resized images, whose aspect ratios differ from the originals.

In [None]:
threshold = 0.007 # probability threshold

summary = analyze_predictions(
    clock_model,
    full_loader,
    threshold=threshold
)

print('Confusion matrix:\n', summary['confusion'])
print('Accuracy: {:.2f}'.format(summary['accuracy']))
print('Precision: {:.2f}'.format(summary['precision']))
print('Recall: {:.2f}'.format(summary['recall']))

In summary, both approaches provide a viable classifier. The first one makes purely binary decisions, whereas the second one yields a continuous classification score with a probabilistic interpretation. Given its small values, however, the usefulness of this score is debatable. One interpretation is that the clocks and other objects in the OLX data set differ somewhat (though not too much) from the images found in ImageNet.

## Transfer learning

As an alternative, we pursue a transfer-learning approach. Here, the same pretrained model serves as a feature extractor, and a new binary classification head is trained on our data. This will hopefully yield a well-behaved continuous classification score. On the downside, we do not expect this to generalize well to data distributions that are very different from our data set.

Since our data set is small, we use data augmentation. The first step is to define an appropriate augmentation pipeline (random rotation, random resized crop, color jitter and horizontal flip). We simply use some more or less reasonable settings; the parameters should be tuned more carefully in the future.

In [None]:
transform['train'] = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomRotation(
        45, # rotate by up to ±45 degrees; too (or not enough) aggressive?
        interpolation=transforms.InterpolationMode.BILINEAR
    ),
    transforms.RandomResizedCrop(SHAPE),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD)
])

Our data set is also slightly imbalanced. Hence, we implement an oversampling scheme that generates roughly balanced mini-batches. In conjunction with data augmentation, this mitigates the imbalance to some degree. We also split the data into a training and a validation set. Note that both loaders draw from the same `train_set`, so the validation images are augmented as well.

In [None]:
val_frac = 0.3 # fraction of samples for validation

train_set = BinarySet(
    data,
    target='clock',
    transform=transform['train']
)

indices = np.random.permutation(np.arange(len(train_set)))
split_idx = int(np.floor((1 - val_frac) * len(train_set)))
train_ids = indices[:split_idx].tolist()
val_ids = indices[split_idx:].tolist()

train_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    sampler=BalancedSampler(train_set, indices=train_ids)
)

val_loader = DataLoader(
    train_set,
    batch_size=batch_size,
    sampler=BalancedSampler(train_set, indices=val_ids)
)
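The implementation of `BalancedSampler` is not shown here. One common way to obtain roughly balanced mini-batches is to draw indices with replacement, weighting each sample by the inverse frequency of its class; PyTorch provides this directly as `torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True)`. The NumPy sketch below, with a hypothetical helper name, illustrates the weighting and is not necessarily how `BalancedSampler` works.

```python
import numpy as np

def balanced_indices(labels, num_samples, rng):
    """Draw sample indices with replacement so that each class is equally likely."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    # inverse class frequency: every class receives the same total weight
    class_weight = dict(zip(classes, 1.0 / counts))
    weights = np.array([class_weight[label] for label in labels])
    return rng.choice(len(labels), size=num_samples, replace=True, p=weights / weights.sum())

rng = np.random.default_rng(0)
labels = np.array([1] * 20 + [0] * 80)  # imbalanced toy labels: 20 % positives
drawn = balanced_indices(labels, 10_000, rng)
print(labels[drawn].mean())  # close to 0.5
```

Because minority samples are drawn repeatedly, oversampling pairs naturally with random augmentation, which makes the repeated draws differ from each other.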

Now we define our binary classification architecture. It consists of the pretrained feature extractor followed by a linear classifier with a single output. Note that the final model output is not passed through a sigmoid function, i.e. the model returns logits.

In [None]:
binary_model = nn.Sequential(
    pretrained_model.features,
    nn.AdaptiveAvgPool2d(output_size=(6, 6)),
    nn.Flatten(),
    nn.Linear(in_features=256*6*6, out_features=1)
)
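The input size of the linear layer follows from AlexNet's feature extractor, which outputs 256 channels. Assuming torchvision's AlexNet layer settings (listed in the code), the spatial size can be tracked with the standard formula floor((n + 2p - k) / s) + 1. For 224 × 224 inputs the feature map is already 6 × 6; for the variable-size full images, `nn.AdaptiveAvgPool2d((6, 6))` brings it back to 6 × 6, so the classifier always receives 256 · 6 · 6 = 9216 features.

```python
# (kernel, stride, padding) of the spatial layers in torchvision's AlexNet.features
ALEXNET_LAYERS = [
    (11, 4, 2),  # conv1
    (3, 2, 0),   # max pool
    (5, 1, 2),   # conv2
    (3, 2, 0),   # max pool
    (3, 1, 1),   # conv3
    (3, 1, 1),   # conv4
    (3, 1, 1),   # conv5
    (3, 2, 0),   # max pool
]

def feature_map_size(n, layers=ALEXNET_LAYERS):
    """Spatial side length after the feature extractor for an n x n input."""
    for kernel, stride, padding in layers:
        n = (n + 2 * padding - kernel) // stride + 1
    return n

for side in (224, 500):
    print(side, "->", feature_map_size(side))
# 224 -> 6
# 500 -> 14
print("linear input:", 256 * 6 * 6)  # 9216
```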

We use binary cross entropy as the loss function and choose an "arbitrary" optimizer and learning rate. Only the weights of the final classification layer are trained, since the pretrained parameters were frozen above. Adam's `weight_decay` parameter adds an L2 penalty on those weights. A systematic hyperparameter optimization is beyond the scope of this case study.

In [None]:
criterion = nn.BCEWithLogitsLoss(reduction='mean') # requires logits

optimizer = torch.optim.Adam(binary_model.parameters(), lr=0.001, weight_decay=0.2)

classifier = ClassifierTraining(
    binary_model,
    criterion,
    optimizer,
    train_loader,
    val_loader
)
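`nn.BCEWithLogitsLoss` fuses the sigmoid and the binary cross entropy, which is why the model output stays unactivated. For a logit z and a label y in {0, 1}, the loss -[y log σ(z) + (1 - y) log(1 - σ(z))] can be rewritten as max(z, 0) - z·y + log(1 + exp(-|z|)), which avoids overflow for large |z|. The NumPy sketch below checks the two forms against each other; it illustrates the formula and is not PyTorch's actual implementation.

```python
import numpy as np

def bce_naive(z, y):
    """Sigmoid followed by binary cross entropy (breaks down for large |z|)."""
    p = 1.0 / (1.0 + np.exp(-z))
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def bce_with_logits(z, y):
    """Numerically stable form, averaged like reduction='mean'."""
    return np.mean(np.maximum(z, 0) - z * y + np.log1p(np.exp(-np.abs(z))))

z = np.array([-2.0, -0.5, 0.0, 1.5, 3.0])  # toy logits
y = np.array([0.0, 1.0, 0.0, 1.0, 1.0])    # toy labels
print(np.isclose(bce_with_logits(z, y), np.mean(bce_naive(z, y))))  # True
print(bce_with_logits(np.array([100.0]), np.array([0.0])))  # 100.0, whereas bce_naive gives inf
```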

Time to start the training!

In [None]:
history = classifier.fit(no_epochs=100, log_interval=5)

Let us briefly look at the obligatory loss curves below. Keep in mind that only a single layer has been trained.

In [None]:
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(np.array(history['train_loss']), label='train', alpha=0.7)
ax.plot(np.array(history['val_loss']), label='val.', alpha=0.7)
ax.set(xlabel='epoch', ylabel='loss')
ax.set_xlim((0, history['no_epochs']))
ax.legend()
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
ax.set_axisbelow(True)
fig.tight_layout()

Next, we evaluate the final accuracies on the training and validation sets. They show no signs of overfitting.

In [None]:
train_loss, train_acc = classifier.test(train_loader)
val_loss, val_acc = classifier.test(val_loader)

print('Train acc.: {:.4f}'.format(train_acc))
print('Val. acc.: {:.4f}'.format(val_acc))

For a comparison with the non-learning approaches above, we also evaluate the performance on the full-sized images, which are neither resized nor augmented. One might criticize that this set overlaps with the training data. The threshold parameter could also be tuned further.
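One simple way to tune the threshold, sketched here with a hypothetical helper, is to sweep candidate values over held-out scores and keep the one that maximizes a chosen criterion, for example the F1 score. To avoid optimistic estimates, this would be done on the validation split rather than on images that were also used for training.

```python
import numpy as np

def best_threshold(y_true, scores, candidates):
    """Return the candidate threshold with the highest F1 score, and that score."""
    y_true = np.asarray(y_true).astype(bool)
    scores = np.asarray(scores, dtype=float)
    best_t, best_f1 = None, -1.0
    for t in candidates:
        y_pred = scores >= t
        tp = np.sum(y_pred & y_true)
        fp = np.sum(y_pred & ~y_true)
        fn = np.sum(~y_pred & y_true)
        f1 = 2 * tp / (2 * tp + fp + fn) if tp > 0 else 0.0
        if f1 > best_f1:  # keep the first (lowest) threshold among ties
            best_t, best_f1 = float(t), float(f1)
    return best_t, best_f1

# toy scores, e.g. sigmoid outputs of the binary model
y_true = [0, 0, 0, 1, 1, 1]
scores = [0.03, 0.22, 0.47, 0.38, 0.71, 0.93]
print(best_threshold(y_true, scores, np.linspace(0.05, 0.95, 19)))
```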

In [None]:
threshold = 0.5 # probability threshold

summary = analyze_predictions(
    binary_model,
    full_loader,
    threshold=threshold
)

print('Confusion matrix:\n', summary['confusion'])
print('Accuracy: {:.2f}'.format(summary['accuracy']))
print('Precision: {:.2f}'.format(summary['precision']))
print('Recall: {:.2f}'.format(summary['recall']))

Well, that looks quite good. The performance is better than that of the previous approaches. Moreover, we now have a meaningful classification score. We export the learned weights of the classification head so that they can be loaded and reused later on. We only have to keep the limitations in mind when deploying the model: for a model that generalizes beyond our toy scenario, we would need data sets that contain more images and more object categories.

In [None]:
weights_file = '../weights.pt'
torch.save(binary_model[-1].state_dict(), weights_file)

## Addendum: Classifier evaluation

In [None]:
y_pred, y_true = predict_loader(
    binary_model,
    full_loader,
    return_true=True
)
y_pred = y_pred.numpy()
y_true = y_true.numpy()

In [None]:
fpr, tpr, thresholds = roc_curve(y_true, y_pred)
print('Area under ROC curve: {:.4f}'.format(auc(fpr, tpr)))
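The area under the ROC curve has a handy interpretation: it equals the probability that a randomly chosen clock image receives a higher score than a randomly chosen other image, with ties counted as one half. A brute-force NumPy sketch of this pairwise view (hypothetical helper, quadratic in the data size, so only suitable for small sets) gives the same value as `auc(fpr, tpr)` computed from `roc_curve`.

```python
import numpy as np

def pairwise_auc(y_true, scores):
    """P(score of a positive > score of a negative), ties counted as 1/2."""
    y = np.asarray(y_true).astype(bool)
    s = np.asarray(scores, dtype=float)
    diff = s[y][:, None] - s[~y][None, :]  # all positive/negative pairs
    return float(np.mean((diff > 0) + 0.5 * (diff == 0)))

print(pairwise_auc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```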

In [None]:
precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
print('Area under PR curve: {:.4f}'.format(auc(recall, precision)))
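The trapezoidal `auc(recall, precision)` interpolates linearly between points of the PR curve, which the scikit-learn documentation notes can be too optimistic; `sklearn.metrics.average_precision_score` is the usual alternative. It computes the step-wise sum AP = Σ_n (R_n - R_{n-1}) P_n over decreasing score thresholds. A NumPy sketch of that definition, with a hypothetical helper name, could look like this:

```python
import numpy as np

def average_precision(y_true, scores):
    """Step-wise average precision: sum of (R_n - R_{n-1}) * P_n over thresholds."""
    y_true = np.asarray(y_true, dtype=float)
    scores = np.asarray(scores, dtype=float)
    order = np.argsort(-scores, kind="mergesort")  # highest score first
    y_sorted, s_sorted = y_true[order], scores[order]
    # position of the last sample in each group of tied scores = one threshold each
    last = np.r_[np.where(np.diff(s_sorted))[0], len(s_sorted) - 1]
    tp = np.cumsum(y_sorted)[last]
    fp = (last + 1) - tp
    precision = tp / (tp + fp)
    recall = tp / y_true.sum()
    recall_prev = np.r_[0.0, recall[:-1]]
    return float(np.sum((recall - recall_prev) * precision))

print(round(average_precision([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]), 4))  # 0.8333
```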

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(8, 4))
axes[0].plot(fpr, tpr)
axes[0].set(title='ROC curve', xlabel='FPR', ylabel='TPR')
axes[1].plot(recall, precision)
axes[1].set(title='PR curve', xlabel='recall', ylabel='precision')
for ax in axes:
    ax.grid(visible=True, which='both', color='lightgray', linestyle='-')
    ax.set_axisbelow(True)
axes[1].yaxis.set_label_position('right')
axes[1].yaxis.tick_right()
fig.tight_layout()