# Transfer learning a CNN classifier

The pretrained CNN classifiers can clearly extract meaningful information from an input image. However, they are not suitable as-is for our use case: none of the classes they were trained on are the ones we are interested in. To solve this, we will **replace the final classification layer** with a layer that outputs the correct number of classes (four in our case: CocaCola, Fanta, Pepsi and Sprite). Next, we train that new layer, along with the final layers of the network, on our training data.
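
As a rough illustration of what replacing the final layer means, the sketch below builds a tiny stand-in network and swaps its 1000-class head for a four-class one. The class `TinyClassifier` and its layer sizes are made up for this example and are not part of `lib.cnn_classifiers`. torchvision's ResNet models expose their final layer as `fc` in the same way.

```python
import torch
from torch import nn


class TinyClassifier(nn.Module):
    """Hypothetical stand-in for a pretrained classifier."""

    def __init__(self, num_features=32, num_classes=1000):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, num_features, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
        )
        self.fc = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.fc(self.backbone(x))


toy_clf = TinyClassifier()
old_out_shape = tuple(toy_clf(torch.zeros(2, 3, 16, 16)).shape)

# Replace the head: same input features, but only four output classes
toy_clf.fc = nn.Linear(toy_clf.fc.in_features, 4)
new_out_shape = tuple(toy_clf(torch.zeros(2, 3, 16, 16)).shape)
```

The new layer starts from random weights, which is why it has to be trained before the classifier is useful.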

## Create datasets

As usual, we want to validate our model on unseen data. We therefore need two datasets: one for training and one for validation.

Think of a PyTorch `Dataset` simply as a list of all items that are present in your dataset. For example, you can index a `Dataset` instance using square bracket notation and you can get the number of items in the dataset using `len()`. In our case, such an item consists of an image and the corresponding class label.
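
The list-like behaviour comes down to two methods, `__getitem__` and `__len__`. The plain-Python sketch below shows that protocol with a made-up `ToyDataset` class and placeholder strings instead of real images. In practice you would subclass `torch.utils.data.Dataset` and return image tensors.

```python
class ToyDataset:
    """Hypothetical map-style dataset holding (item, label) pairs in memory."""

    def __init__(self, items, labels):
        assert len(items) == len(labels)
        self.items = items
        self.labels = labels

    def __len__(self):
        # Number of items in the dataset, used by len(dataset)
        return len(self.items)

    def __getitem__(self, idx):
        # One item is an (image, label) pair, used by dataset[idx]
        return self.items[idx], self.labels[idx]


toy_ds = ToyDataset(["img_a", "img_b", "img_c"], [0, 2, 1])
num_items = len(toy_ds)
second_item = toy_ds[1]
```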

When your data has a rather unconventional structure, you will need to write your own sub-class of `Dataset`. Our data, however, is clearly structured: all images are inside a folder that has the name of the class. In such a case, we can make use of `ImageFolder` to create our `Dataset` object.

In [1]:
from torchvision.datasets import ImageFolder

from lib.cnn_classifiers import train_transform, val_transform


train_ds = ImageFolder(
    'data/sodas/train/',
    transform=train_transform
)

val_ds = ImageFolder(
    'data/sodas/query/',
    transform=val_transform
)
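
`ImageFolder` derives the integer labels from the folder names: it sorts the class folder names and numbers them from zero. The plain-Python sketch below mimics that mapping. The folder names are an assumption, taken to match the four class names mentioned earlier.

```python
# Assumed folder names under data/sodas/train/ (listed out of order on purpose)
class_folders = ["Sprite", "CocaCola", "Pepsi", "Fanta"]

# Sort the names, then assign consecutive indices starting at 0
classes = sorted(class_folders)
class_to_idx = {name: idx for idx, name in enumerate(classes)}
```

On the real datasets, the same information is available as `train_ds.classes` and `train_ds.class_to_idx`.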

## Create data loaders

During training, we will pass **batches** of data through the network. These batches are created by `DataLoader`s. Loading data is rather expensive, as it requires many IO operations and often also includes computing image transformations. If we did this one image at a time, we would create a serious bottleneck every time a batch is created.

To avoid such a bottleneck, `DataLoader`s can run **multiple worker processes**. Each worker will **take samples from the dataset** and perform the IO operations and image transformations. With `num_workers`, you can configure the number of workers you want to use. A sane choice is **the number of CPU cores on your device**.
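
Rather than hard-coding the value, `num_workers` can be derived from the machine the notebook runs on. A minimal sketch using only the standard library:

```python
import os

# os.cpu_count() can return None when the count cannot be determined,
# so fall back to a single worker in that case
num_workers = os.cpu_count() or 1
```

On shared machines it may be worth choosing a lower number, since other processes compete for the same cores.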

We create a data loader for both datasets.

In [2]:
batch_size = 12  # The number of images in each batch
num_workers = 12  # Use the number of CPUs here

In [3]:
from torch.utils.data import DataLoader

train_loader = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

val_loader = DataLoader(
    dataset=val_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
)
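
To see what a `DataLoader` yields, the sketch below uses made-up tensors instead of the soda images: 30 fake 8x8 "images" with labels from four classes. `num_workers=0` keeps loading in the main process so the example runs anywhere.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 30 fake RGB images of 8x8 pixels, with labels cycling through 4 classes
fake_imgs = torch.zeros(30, 3, 8, 8)
fake_labels = torch.arange(30) % 4
toy_ds = TensorDataset(fake_imgs, fake_labels)

toy_loader = DataLoader(toy_ds, batch_size=12, shuffle=False, num_workers=0)

# Each batch is an (images, labels) pair; the last batch may be smaller
batch_shapes = [tuple(imgs.shape) for imgs, _ in toy_loader]
num_batches = len(toy_loader)
```

With 30 items and a batch size of 12, the loader produces two full batches and one final batch with the remaining 6 items.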

## Define what should happen in a single training step

A training loop consists of multiple training epochs, and a training epoch consists of multiple training steps. The training step is where the most important calculation happens: it defines **how to calculate the loss**. For classification problems, **cross-entropy loss** is frequently used. We will use it here as well.

In [4]:
from torch import nn
import torch.nn.functional as F


def run_train_step(model, batch, batch_idx, num_classes):
    imgs, labels = batch

    # The model returns one raw score ("logit") per class. When we pass these
    # logits through a softmax function, we get a "probability distribution"
    # over all classes.
    class_logits = model(imgs)

    # F.cross_entropy applies log-softmax internally and expects the integer
    # class labels directly, so no one-hot encoding is needed.
    # (batch_idx and num_classes are unused here; they are kept so that
    # existing callers do not need to change.)
    return F.cross_entropy(class_logits, labels)
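
To make the cross-entropy loss concrete, here is a small plain-Python worked example for a single sample with four classes. The logit values are made up, and the helper names `softmax` and `cross_entropy` are defined here only for illustration. It is meant to show the arithmetic, not how the loss is computed in the notebook itself.

```python
import math


def softmax(logits):
    # Subtract the maximum for numerical stability; the result is unchanged
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, label):
    # Negative log of the probability assigned to the true class
    return -math.log(softmax(logits)[label])


logits = [2.0, 0.5, 0.1, -1.0]  # made-up scores for 4 classes
probs = softmax(logits)

loss_if_class_0 = cross_entropy(logits, 0)  # true class has the highest score
loss_if_class_3 = cross_entropy(logits, 3)  # true class has the lowest score
loss_uniform = cross_entropy([0.0, 0.0, 0.0, 0.0], 2)
```

The loss is small when the true class receives a high probability and large when it does not. With identical logits, every class gets probability 1/4 and the loss equals `log(4)`.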

## Call the training step on each batch in a training epoch

Now that we have our smallest step in the training loop, we can put it into an entire training **epoch**. When the epoch is over, the neural network will have seen each training sample once.

In [5]:
def run_train_epoch(model, train_loader, optimizer,
                    epoch_idx, writer=None, device='cpu'):
    """
    Run a training epoch.
    """
    # Put model in train mode
    model.train()

    num_classes = len(train_loader.dataset.classes)

    for batch_idx, train_batch in tqdm(enumerate(train_loader),
                                       total=len(train_loader),
                                       leave=False, desc='Train batch'):
        train_batch = batch_to_device(train_batch, device)
        loss = run_train_step(model, train_batch, batch_idx,
                              num_classes)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

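        # Log the training loss once per batch. A global step is used so that
        # batches from the same epoch do not all land on the same x-axis position.
        if writer is not None:
            global_step = epoch_idx * len(train_loader) + batch_idx
            writer.add_scalar("Loss/train", loss.item(), global_step)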


def batch_to_device(batch, device):
    batch[0] = batch[0].to(device)
    batch[1] = batch[1].to(device)

    return batch
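
The three optimizer calls inside the epoch loop are the heart of training: clear the old gradients, backpropagate the loss, and update the weights. The sketch below runs that sequence once on a made-up linear model with random data, just to show its effect on the parameters.

```python
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)

# Toy setup: 12 samples with 8 features each, 4 classes
toy_model = nn.Linear(8, 4)
toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
feats = torch.randn(12, 8)
labels = torch.randint(0, 4, (12,))

weights_before = toy_model.weight.detach().clone()

loss = F.cross_entropy(toy_model(feats), labels)

toy_optimizer.zero_grad()  # clear gradients left over from a previous step
loss.backward()            # compute gradients of the loss w.r.t. the weights
toy_optimizer.step()       # move the weights against the gradient

weights_after = toy_model.weight.detach().clone()
```

Without `zero_grad()`, gradients from successive batches would accumulate, because `backward()` adds to the existing `.grad` values.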

## Define what should happen during a *validation epoch*

Of course, we want to check how well our network performs on **unseen** data. Therefore, we will also add a **validation loop** to our training.

As mentioned before, applying softmax to the logits that the network returns will yield a sort of probability distribution over all class labels. Computing this distribution for every query image and stacking the results gives a matrix with one row per query and one column per class. This matrix plays the same role as the **similarity matrix** we have seen in the first notebook.

In [6]:
import numpy as np

def run_val_epoch(model, val_loader, epoch_idx, writer, device='cpu'):
    # Put model in eval mode
    model.eval()

    logits, q_labels = compute_logits_from_dataloader(
        model,
        val_loader,
        device
    )
    
    # Compute similarity matrix by applying softmax to logits
    sim_mat = F.softmax(logits, dim=1)

    # Log average precision
    idx_to_class = {
        idx: class_name
        for class_name, idx in val_loader.dataset.class_to_idx.items()
    }

    # Create an array with the labels (indices) in the dataset
    uniq_labels = np.array(list(idx_to_class))

    for label in uniq_labels:
        ap = calc_ap(label, sim_mat, uniq_labels, q_labels)
        writer.add_scalar(f"AP_val/{idx_to_class[label]}", ap, epoch_idx)


def compute_logits_from_dataloader(model, data_loader, device='cpu'):
    all_logits = []
    all_labels = []

    for batch_idx, batch in tqdm(enumerate(data_loader),
                                 total=len(data_loader),
                                 leave=False):
        batch = batch_to_device(batch, device)
        imgs, labels = batch

        with torch.no_grad():
            logits = model(imgs)

        all_logits.append(logits)
        all_labels.append(labels)

    all_logits = torch.cat(all_logits).cpu()
    all_labels = torch.cat(all_labels).cpu().numpy()

    return all_logits, all_labels
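
The helper `calc_ap` lives in `lib.evaluation_metrics` and is not shown here. As a rough idea of what an average precision (AP) score measures, the plain-Python sketch below implements one common definition: rank the queries by their score for a class, then average the precision at every rank where a relevant query appears. The function name and the toy scores are assumptions for illustration, and the real helper may differ in its details.

```python
def average_precision(scores, is_relevant):
    """AP for one class, given a score and a relevance flag per query."""
    # Rank query indices from highest to lowest score
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)

    hits = 0
    precisions = []
    for rank, i in enumerate(order, start=1):
        if is_relevant[i]:
            hits += 1
            # Precision among the top `rank` results
            precisions.append(hits / rank)

    return sum(precisions) / len(precisions) if precisions else 0.0


# Relevant queries ranked first and third: AP = (1/1 + 2/3) / 2
ap_mixed = average_precision([0.9, 0.8, 0.7, 0.6], [True, False, True, False])

# All relevant queries ranked on top: perfect AP
ap_perfect = average_precision([0.9, 0.8, 0.2, 0.1], [True, True, False, False])
```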

## Put everything together in a training loop

In [7]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch

from lib.metric_learning import match_embeddings
from lib.evaluation_metrics import calc_ap


def run_training(model, optimizer, train_loader, val_loader, num_epochs=10):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    writer = SummaryWriter()

    for epoch_idx in tqdm(range(num_epochs), desc='Epoch'):
        run_train_epoch(model, train_loader, optimizer,
                        epoch_idx, writer, device)
        run_val_epoch(model, val_loader, epoch_idx, writer,
                      device)

    writer.flush()
    writer.close()

## Start TensorBoard for logging

In [8]:
import tensorboard

%load_ext tensorboard
%tensorboard --logdir runs

## Define the model

In [9]:
from lib.cnn_classifiers import get_cnn_clf, get_top_parameters


model = get_cnn_clf("resnet50", num_classes=4)

# Get the parameters at the end of the network
# These are the ones we will be training
top_parameters = get_top_parameters(model)

# Freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Now unfreeze the top ones
for param in top_parameters:
    param.requires_grad = True
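
The effect of freezing is easy to check by counting the parameters that still require gradients. The sketch below does so on a made-up two-layer model, where the last layer plays the role of the "top" parameters. What `get_top_parameters` selects on the real ResNet-50 is not shown in this notebook, so this is only an analogy.

```python
from torch import nn

toy_model = nn.Sequential(
    nn.Linear(16, 8),
    nn.ReLU(),
    nn.Linear(8, 4),
)

# Freeze everything first
for param in toy_model.parameters():
    param.requires_grad = False

# Then unfreeze only the last layer
head_params = list(toy_model[-1].parameters())
for param in head_params:
    param.requires_grad = True

num_total = sum(p.numel() for p in toy_model.parameters())
num_trainable = sum(
    p.numel() for p in toy_model.parameters() if p.requires_grad
)
```

Here only the last layer's 8 * 4 weights and 4 biases remain trainable; the first layer keeps its values during training.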

## Run the training loop

In [10]:
from torch.optim import SGD

# The top parameters need to be optimized
optimizer = SGD(top_parameters, lr=0.1)

run_training(model, optimizer, train_loader,
             val_loader, num_epochs=25)

## Save the trained model

In [11]:
state_dict = model.state_dict()
torch.save(state_dict, 'tl_cnn_clf.pth')
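
A state dict only contains the weights, not the architecture. To use the saved file later, the same model has to be rebuilt first and the weights loaded into it. The sketch below shows that round trip on a made-up linear model and a temporary file. For the model in this notebook, the rebuild step would presumably be `get_cnn_clf("resnet50", num_classes=4)`.

```python
import os
import tempfile

import torch
from torch import nn

torch.manual_seed(0)
toy_model = nn.Linear(8, 4)

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "toy_model.pth")
    torch.save(toy_model.state_dict(), path)

    # A fresh model with the same architecture but different random weights
    restored = nn.Linear(8, 4)
    restored.load_state_dict(torch.load(path, map_location="cpu"))

# Switch to eval mode before using the model for inference
restored.eval()

x = torch.randn(2, 8)
same_output = torch.equal(toy_model(x), restored(x))
```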