# Finetuning a CNN classifier

## Create datasets

In [1]:
from torchvision.datasets import ImageFolder

from lib.cnn_classifiers import train_transform, val_transform


# Training images are read from data/GroZiVitroBasic/ and validation images
# from data/sodas/queries; ImageFolder treats every sub-directory of a root
# as one class.
# NOTE(review): each ImageFolder derives its class indices from its own
# sub-directories, so train and val labels only line up if both roots hold
# the same class folders — confirm.
train_ds = ImageFolder('data/GroZiVitroBasic/', transform=train_transform)
val_ds = ImageFolder('data/sodas/queries', transform=val_transform)

## Create data loaders

In [2]:
# Images per batch and background worker processes used by both loaders.
batch_size, num_workers = 60, 8

In [3]:
from torch.utils.data import DataLoader

# Training batches are reshuffled every epoch; validation batches are
# always produced in the same, dataset order.
train_loader = DataLoader(train_ds, batch_size=batch_size,
                          shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_ds, batch_size=batch_size,
                        shuffle=False, num_workers=num_workers)

## Define what should happen in a single training step

In [1]:
from torch import nn
import torch.nn.functional as F


def run_train_step(model, batch, batch_idx):
    """
    Compute the classification loss for one training batch.

    Args:
        model: classifier mapping a batch of images to class logits of shape
            ``(batch, num_classes)``.
        batch: ``(images, labels)`` pair; ``labels`` holds integer class
            indices of shape ``(batch,)`` as produced by ``ImageFolder``.
        batch_idx: index of the batch within the epoch (unused; kept for a
            uniform step signature).

    Returns:
        Scalar tensor with the mean cross-entropy loss over the batch.
    """
    imgs, labels = batch

    class_logits = model(imgs)

    # Single-label targets given as class indices need softmax cross-entropy;
    # binary cross-entropy would require float targets shaped like the logits.
    return F.cross_entropy(class_logits, labels)

## Call the training step on each batch in a training epoch

In [5]:
def run_train_epoch(model, train_loader, optimizer,
                    epoch_idx, writer=None, device='cpu'):
    """
    Run one training epoch over ``train_loader``.

    Args:
        model: module being trained; it is switched to train mode here.
        train_loader: sized iterable of ``(images, labels)`` batches.
        optimizer: optimizer updating the trainable parameters of ``model``.
        epoch_idx: zero-based index of the current epoch, used to derive the
            TensorBoard step.
        writer: optional ``SummaryWriter``; when given, every batch loss is
            logged under ``"Loss/train"``.
        device: device each batch is moved to before the forward pass.
    """
    # Put model in train mode
    model.train()

    num_batches = len(train_loader)
    for batch_idx, train_batch in tqdm(enumerate(train_loader),
                                       total=num_batches,
                                       leave=False, desc='Train batch'):
        train_batch = batch_to_device(train_batch, device)
        loss = run_train_step(model, train_batch, batch_idx)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Log the training loss
        if writer is not None:
            # One step per batch across all epochs; logging every batch at
            # ``epoch_idx`` would stack an epoch's points on a single x value.
            global_step = epoch_idx * num_batches + batch_idx
            writer.add_scalar("Loss/train", loss.item(), global_step)


def batch_to_device(batch, device):
    """
    Move the images and labels of a batch to ``device``.

    Args:
        batch: sequence (list or tuple) whose first two items are the image
            and label tensors; any further items are passed through as-is.
        device: target device, e.g. ``'cpu'`` or ``torch.device('cuda')``.

    Returns:
        A new list ``[images, labels, *rest]`` with the first two items on
        ``device``. The input is not modified, so tuple batches work too.
    """
    imgs, labels, *rest = batch
    return [imgs.to(device), labels.to(device), *rest]

## Define what should happen during a *validation epoch*

In [6]:
def run_val_epoch(model, val_loader,
                  epoch_idx, writer, device='cpu'):
    """
    Evaluate the classifier on ``val_loader`` and log metrics to TensorBoard.

    Logs top-1 accuracy under ``"Accuracy/val"`` and, for every class that
    occurs in the validation set, the one-vs-rest average precision of that
    class's logit under ``"AP_val/<class name>"``.

    Args:
        model: classifier producing ``(batch, num_classes)`` logits; it is
            switched to eval mode here.
        val_loader: loader over an ``ImageFolder``-style dataset (its
            ``class_to_idx`` is used for the tag names).
        epoch_idx: epoch index used as the TensorBoard step.
        writer: ``SummaryWriter`` receiving the scalars.
        device: device the batches are moved to.
    """
    # Put model in eval mode
    model.eval()

    logits, labels = compute_logits_from_dataloader(model, val_loader, device)
    labels = torch.as_tensor(labels)

    # NOTE(review): logit column i is assumed to mean class index i of the
    # validation dataset; this holds only if train and val share the same
    # class folders — confirm.
    idx_to_class = {
        idx: label
        for label, idx in val_loader.dataset.class_to_idx.items()
    }

    accuracy = (logits.argmax(dim=1) == labels).float().mean().item()
    writer.add_scalar("Accuracy/val", accuracy, epoch_idx)

    num_logit_classes = logits.shape[1]
    for label in torch.unique(labels).tolist():
        if label >= num_logit_classes:
            # The model has no output for this class; AP is undefined.
            continue
        ap = _average_precision(logits[:, label], labels == label)
        writer.add_scalar(f"AP_val/{idx_to_class[label]}", ap, epoch_idx)


def _average_precision(scores, relevant):
    """Return the non-interpolated AP of ranking ``scores`` against boolean ``relevant``."""
    order = torch.argsort(scores, descending=True)
    hits = relevant[order].float()
    num_relevant = hits.sum().item()
    if num_relevant == 0:
        return 0.0
    ranks = torch.arange(1, len(hits) + 1, dtype=hits.dtype)
    precision_at_k = hits.cumsum(0) / ranks
    # Average the precision at each position where a relevant item appears.
    return (precision_at_k * hits).sum().item() / num_relevant


def compute_logits_from_dataloader(model, data_loader, device='cpu'):
    """
    Run ``model`` over every batch of ``data_loader`` without tracking gradients.

    Args:
        model: module applied to the image part of each batch.
        data_loader: sized iterable of ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(logits, labels)``: the concatenated model outputs as a CPU tensor
        and the concatenated labels as a NumPy array, both in loader order.
    """
    logit_chunks, label_chunks = [], []

    progress = tqdm(enumerate(data_loader), total=len(data_loader), leave=False)
    for _, batch in progress:
        imgs, batch_labels = batch_to_device(batch, device)

        with torch.no_grad():
            logit_chunks.append(model(imgs))
        label_chunks.append(batch_labels)

    stacked_logits = torch.cat(logit_chunks).cpu()
    stacked_labels = torch.cat(label_chunks).cpu().numpy()
    return stacked_logits, stacked_labels

## Put everything together in a training loop

In [7]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch

from lib.metric_learning import match_embeddings
from lib.evaluation_metrics import calc_ap


def run_training(model, optimizer, train_loader, val_gallery_loader,
                 val_query_loader=None, num_epochs=10):
    """
    Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Training loss and validation metrics are written to a fresh
    ``SummaryWriter`` that is always flushed and closed, even on error.

    Args:
        model: classifier to train; moved to CUDA when available.
        optimizer: optimizer over the trainable parameters of ``model``.
        train_loader: loader of training batches.
        val_gallery_loader: validation loader, used only when
            ``val_query_loader`` is ``None`` (a classifier needs a single
            labelled validation set; the name is kept for compatibility).
        val_query_loader: preferred validation loader.
        num_epochs: number of training epochs.
    """
    val_loader = (val_query_loader if val_query_loader is not None
                  else val_gallery_loader)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    writer = SummaryWriter()
    try:
        for epoch_idx in tqdm(range(num_epochs), desc='Epoch'):
            run_train_epoch(model, train_loader, optimizer,
                            epoch_idx, writer, device)
            if val_loader is not None:
                run_val_epoch(model, val_loader, epoch_idx, writer, device)
    finally:
        writer.flush()
        writer.close()

## Start TensorBoard for logging

In [8]:
import tensorboard

# Load the TensorBoard notebook extension and show the dashboard for the
# ./runs log directory inline in the notebook.
%load_ext tensorboard
%tensorboard --logdir runs

## Define the model

In [9]:
from lib.cnn_classifiers import get_cnn

model = get_cnn('resnet50')

# Replace the classification head so it predicts the training-set classes.
# NOTE(review): assumes get_cnn returns a torchvision-style ResNet whose head
# is the ``fc`` attribute (``layer4`` is used below) — confirm.
num_classes = len(train_ds.classes)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Freeze every parameter first...
for param in model.parameters():
    param.requires_grad = False

# ...then unfreeze only the last residual stage and the new head, which are
# the parts being finetuned.
for param in model.layer4.parameters():
    param.requires_grad = True

for param in model.fc.parameters():
    param.requires_grad = True

## Run the training loop

In [None]:
from torch.optim import SGD

# Optimize exactly the parameters left trainable above; frozen ones are skipped.
optimizer = SGD(
    [param for param in model.parameters() if param.requires_grad],
    lr=0.1,
)

# A classifier is validated on the single labelled ``val_loader``; there is
# no separate gallery set.
run_training(model, optimizer, train_loader,
             val_gallery_loader=None, val_query_loader=val_loader,
             num_epochs=50)

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/33 [00:00<?, ?it/s]

In [None]:
%debug

## Save the trained model

In [None]:
# Save only the weights (parameters and buffers), not the pickled module.
state_dict = model.state_dict()
torch.save(obj=state_dict, f='finetuned_embedding.pth')