In [1]:
from torchvision.datasets import ImageFolder

from lib.cnn_classifiers import train_transform, val_transform


# Training images: one sub-folder per class, loaded with the training transform.
train_ds = ImageFolder('data/GroZiVitroBasic/', transform=train_transform)

# Validation uses two folder datasets (gallery first, then queries), both with
# the validation transform.
val_gallery_ds, val_query_ds = (
    ImageFolder(folder, transform=val_transform)
    for folder in ('data/sodas/gallery', 'data/sodas/queries')
)

In [2]:
from lib.triplet_sampler import TripletSampler, split_triplet_tensor

In [3]:
# Settings shared by all three DataLoaders below.
# num_workers = 0 keeps data loading in the main process.
num_workers = 0
# 30 samples per batch, written as a product of 3.
# NOTE(review): the factor 3 presumably matches anchor/positive/negative
# triplets produced by TripletSampler — confirm against lib.triplet_sampler.
batch_size = 10 * 3

In [4]:
from torch.utils.data import DataLoader

# Training batches are composed by TripletSampler, which therefore owns the
# batch size (DataLoader's own batch_size/shuffle are not used here).
train_loader = DataLoader(
    train_ds,
    num_workers=num_workers,
    batch_sampler=TripletSampler(train_ds, batch_size=batch_size),
)

# Gallery and query loaders are configured identically: fixed order, plain
# sequential batches.
val_gallery_loader, val_query_loader = (
    DataLoader(
        val_ds,
        num_workers=num_workers,
        shuffle=False,
        batch_size=batch_size,
    )
    for val_ds in (val_gallery_ds, val_query_ds)
)

In [5]:
from torch import nn


# Triplet margin loss with margin 1.0 on the L2 (p=2) pairwise distance;
# p=2.0 is the library default, spelled out here for readability.
triplet_loss = nn.TripletMarginLoss(p=2.0, margin=1.0)


def model_forward(model, imgs):
    """Run ``model`` on ``imgs`` and return L2-normalised embeddings.

    Args:
        model: callable applied to ``imgs``; its output must have the batch
            as its first dimension.
        imgs: input batch passed to ``model`` unchanged.

    Returns:
        A 2-D tensor ``(batch, features)`` whose rows have unit L2 norm
        (all-zero rows stay zero instead of becoming NaN).
    """
    embeddings = model(imgs)
    # Flatten everything after the batch dimension. A bare ``squeeze()`` would
    # also drop the batch dimension for a batch of one sample and make the
    # per-row normalisation below fail.
    embeddings = embeddings.flatten(start_dim=1)
    # ``normalize`` clamps the norm with a small epsilon, so a zero embedding
    # does not trigger a 0/0 division.
    return nn.functional.normalize(embeddings, p=2, dim=1)


def train_step(model, batch, batch_idx):
    """Compute the triplet loss for one training batch.

    Args:
        model: network handed to ``model_forward``.
        batch: ``(imgs, labels)`` pair; both are split into three parts by
            ``split_triplet_tensor`` and treated as anchor, positive and
            negative, in that order.
        batch_idx: index of the batch; accepted but not used here.

    Returns:
        The value of ``triplet_loss`` for the anchor, positive and negative
        embeddings.

    Raises:
        AssertionError: if an anchor label differs from its positive label or
            equals its negative label.
    """
    images, targets = batch

    # Sanity-check the sampler output before spending time on the forward pass.
    anchor_lbl, positive_lbl, negative_lbl = split_triplet_tensor(targets)
    assert torch.all(anchor_lbl == positive_lbl)
    assert torch.all(anchor_lbl != negative_lbl)

    anchor, positive, negative = split_triplet_tensor(
        model_forward(model, images)
    )
    return triplet_loss(anchor, positive, negative)

In [6]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
import torch

from lib.metric_learning import match_embeddings
from lib.evaluation_metrics import calc_ap

def _embed_loader(model, loader, device, desc):
    """Embed every batch of ``loader`` without tracking gradients.

    Args:
        model: network handed to ``model_forward``.
        loader: iterable of ``(imgs, labels)`` batches supporting ``len()``.
        device: device the batches are moved to before the forward pass.
        desc: label shown on the progress bar.

    Returns:
        ``(embeddings, labels)``: both concatenated over all batches and moved
        to the CPU.
    """
    embeddings = []
    labels = []
    with torch.no_grad():
        for batch in tqdm(loader, total=len(loader), leave=False, desc=desc):
            imgs, batch_labels = batch_to_device(batch, device)
            embeddings.append(model_forward(model, imgs))
            labels.append(batch_labels)
    return torch.cat(embeddings).cpu(), torch.cat(labels).cpu()


def run_train(model, optimizer, train_loader, val_gallery_loader,
              val_query_loader, num_epochs=10):
    """Train ``model`` and log loss and per-class validation AP to TensorBoard.

    Each epoch runs one pass over ``train_loader`` (optimising the loss
    returned by ``train_step``), then embeds the gallery and query sets,
    matches them with ``match_embeddings`` and logs ``calc_ap`` for every
    class present in the gallery.

    Args:
        model: network to train; moved to CUDA when available, else CPU.
        optimizer: optimizer stepping the trainable parameters of ``model``.
        train_loader: loader yielding training batches for ``train_step``.
        val_gallery_loader: loader over the gallery set; its dataset must
            expose ``class_to_idx``.
        val_query_loader: loader over the query set.
        num_epochs: number of train + validation rounds.

    Returns:
        None. Scalars are written through a default ``SummaryWriter``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    # Invert class_to_idx so numeric labels can be logged under class names.
    idx_to_class = {
        idx: label
        for label, idx in val_gallery_loader.dataset.class_to_idx.items()
    }

    writer = SummaryWriter()
    # try/finally: pending events are written and the file is closed even if
    # training is interrupted or raises.
    try:
        batches_per_epoch = len(train_loader)

        for epoch in tqdm(range(num_epochs), desc='Epoch'):
            # Training epoch
            model.train()
            for batch_idx, train_batch in tqdm(enumerate(train_loader),
                                               total=batches_per_epoch,
                                               leave=False,
                                               desc='Train batch'):
                train_batch = batch_to_device(train_batch, device)
                loss = train_step(model, train_batch, batch_idx)
                # Use a per-batch global step; logging every batch at step
                # `epoch` would stack all of an epoch's points on one x-value.
                global_step = epoch * batches_per_epoch + batch_idx
                writer.add_scalar("Loss/train", loss.item(), global_step)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Validation epoch
            model.eval()
            gallery_embeddings, gallery_labels = _embed_loader(
                model, val_gallery_loader, device, 'Val gallery batch')
            query_embeddings, query_labels = _embed_loader(
                model, val_query_loader, device, 'Val query batch')
            gallery_labels = gallery_labels.numpy()
            query_labels = query_labels.numpy()

            sim_mat = match_embeddings(gallery_embeddings,
                                       query_embeddings).numpy()

            # One AP value per distinct class: dict.fromkeys de-duplicates
            # while keeping first-seen order and the original label type, so
            # classes with several gallery images are not logged repeatedly.
            for label in dict.fromkeys(gallery_labels):
                ap = calc_ap(label, sim_mat, gallery_labels, query_labels)
                writer.add_scalar(f"AP_val/{idx_to_class[label]}", ap, epoch)
    finally:
        # close() flushes pending events before releasing the file.
        writer.close()


def batch_to_device(batch, device):
    """Return a copy of ``batch`` with its first two items moved to ``device``.

    Args:
        batch: sequence (list or tuple) whose items 0 and 1 support
            ``.to(device)``; any further items are passed through untouched.
        device: target device handed to ``.to``.

    Returns:
        A new list; the caller's ``batch`` object is left unmodified.
    """
    # Copy first: avoids mutating the caller's batch and lets immutable
    # tuples be accepted as well as lists.
    moved = list(batch)
    moved[0] = moved[0].to(device)
    moved[1] = moved[1].to(device)

    return moved

In [7]:
import tensorboard

# Load the TensorBoard notebook extension, then open an inline dashboard that
# reads event files from the `runs` directory.
# NOTE(review): `runs` is presumably where the argument-less SummaryWriter()
# in run_train writes its logs — confirm if a custom log_dir is ever passed.
%load_ext tensorboard
%tensorboard --logdir runs

In [None]:
from lib.metric_learning import get_cut_off_cnn
from torch.optim import SGD


# Backbone returned by the project helper for the 'resnet50' architecture.
model = get_cut_off_cnn('resnet50')

# Fine-tune only `layer4`: freeze every parameter, then re-enable gradients
# for that block alone.
model.requires_grad_(False)
model.layer4.requires_grad_(True)

# The optimizer only receives the parameters that are actually trainable.
optimizer = SGD(params=model.layer4.parameters(), lr=0.01)

run_train(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_gallery_loader=val_gallery_loader,
    val_query_loader=val_query_loader,
    num_epochs=50,
)

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Val gallery batch:   0%|          | 0/1 [00:00<?, ?it/s]

Val query batch:   0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Val gallery batch:   0%|          | 0/1 [00:00<?, ?it/s]

Val query batch:   0%|          | 0/1 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]