In [1]:
from torchvision.datasets import ImageFolder

from lib.cnn_classifiers import train_transform, val_transform


# Both datasets are class-per-sub-directory image folders.
# NOTE(review): ImageFolder derives class indices from the sorted sub-directory
# names of *each* root independently, so train and val labels only agree if
# both roots contain the same set of class folders — confirm.
train_ds = ImageFolder('data/GroZiVitroBasic/', transform=train_transform)
val_ds = ImageFolder('data/sodas/', transform=val_transform)

In [2]:
import numpy as np
import torch
from torch.utils.data.sampler import Sampler
import random


class TripletSampler(Sampler):
    """
    Batch sampler yielding flat index lists laid out as
    ``[anchor, positive, negative, anchor, positive, negative, ...]``.

    Every sample serves as an anchor once per epoch. Its positive is drawn
    uniformly from the *other* samples sharing its label, and its negative
    uniformly from the samples with a different label. Anchors with no possible
    positive or no possible negative are skipped. A trailing incomplete batch
    is dropped, so every yielded batch has exactly ``batch_size`` indices and
    ``len(self)`` is the exact number of batches per epoch.

    Args:
        dataset: object with a ``samples`` attribute holding ``(item, label)``
            pairs, e.g. ``torchvision.datasets.ImageFolder``.
        batch_size: indices per batch; must be a positive multiple of 3.
        shuffle: if True, visit anchors in a random order (``random.shuffle``).

    Raises:
        ValueError: if ``batch_size`` is not positive or not divisible by 3.
    """
    def __init__(self, dataset, batch_size, shuffle=True):
        if batch_size <= 0:
            raise ValueError(
                'Batch size should be positive.'
            )
        if not batch_size % 3 == 0:
            raise ValueError(
                'Batch size should be divisible by 3.'
            )

        # Unlike ``zip(*dataset.samples)``, this also works for an empty dataset.
        self.sample_labels = np.array([label for _, label in dataset.samples])
        self.batch_size = batch_size
        self.shuffle = shuffle

        num_samples = len(self.sample_labels)
        # A stable argsort groups sample indices by label into contiguous runs
        # of ``self._order``, each run staying in ascending index order.
        self._order = np.argsort(self.sample_labels, kind='stable')
        # Inverse permutation: position of every sample index within ``self._order``.
        self._rank = np.empty(num_samples, dtype=np.int64)
        self._rank[self._order] = np.arange(num_samples)
        # label -> (start, count) of that label's run in ``self._order``.
        # On the sorted labels, ``return_index`` gives each run's first position.
        uniq_labels, starts, counts = np.unique(
            self.sample_labels[self._order],
            return_index=True, return_counts=True,
        )
        self._runs = {
            label: (int(start), int(count))
            for label, start, count in zip(uniq_labels.tolist(), starts, counts)
        }
        # An anchor is usable iff its label has another member (a positive)
        # and some sample carries a different label (a negative).
        self._num_anchors = sum(
            count for _, count in self._runs.values()
            if 1 < count < num_samples
        )

    def __len__(self):
        """Exact number of full batches one pass of ``__iter__`` yields."""
        return self._num_anchors // (self.batch_size // 3)

    def _draw_excluding(self, begin, end, gap_begin, gap_end):
        """
        Uniformly draw a sample index from ``self._order[begin:end]`` while
        skipping positions ``[gap_begin, gap_end)``, which must lie inside
        ``[begin, end)`` and leave at least one position available.
        """
        slot = begin + np.random.randint((end - begin) - (gap_end - gap_begin))
        if slot >= gap_begin:
            # Jump over the excluded gap so every remaining slot is equally likely.
            slot += gap_end - gap_begin
        return int(self._order[slot])

    def __iter__(self):
        num_samples = len(self.sample_labels)
        anchor_idxs = list(range(num_samples))
        if self.shuffle:
            random.shuffle(anchor_idxs)

        batch = []
        for anchor_idx in anchor_idxs:
            start, count = self._runs[self.sample_labels[anchor_idx]]
            end = start + count
            if count < 2 or count == num_samples:
                # No other sample shares the label, or no sample has a different one.
                continue

            anchor_slot = int(self._rank[anchor_idx])
            # Positive: anywhere in the anchor's run except the anchor itself.
            pos_idx = self._draw_excluding(start, end, anchor_slot, anchor_slot + 1)
            # Negative: anywhere outside the anchor's run.
            neg_idx = self._draw_excluding(0, num_samples, start, end)

            batch.extend([anchor_idx, pos_idx, neg_idx])
            if len(batch) == self.batch_size:
                yield batch
                batch = []

In [3]:
# 10 triplets per batch (TripletSampler requires a multiple of 3); load in the main process.
batch_size, num_workers = 3 * 10, 0

In [4]:
from torch.utils.data import DataLoader

# TripletSampler emits whole batches of interleaved (anchor, positive, negative)
# indices, so it is passed as ``batch_sampler`` rather than ``sampler``.
train_loader = DataLoader(
    train_ds,
    batch_sampler=TripletSampler(train_ds, batch_size=batch_size),
    num_workers=num_workers,
)

# Validation uses plain fixed-order batching.
val_loader = DataLoader(
    val_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    shuffle=False,
)

In [5]:
def split_triplet_tensor(tensor):
    """
    Undo the anchor/positive/negative interleaving of a triplet batch.

    Rows 0, 3, 6, ... are taken as anchors, rows 1, 4, 7, ... as positives and
    rows 2, 5, 8, ... as negatives. Each group is a stride-3 slice along the
    first dimension.

    Returns:
        ``(anchors, positives, negatives)`` tuple.
    """
    anchors, positives, negatives = (tensor[offset::3] for offset in range(3))
    return anchors, positives, negatives

In [6]:
from torch import nn


# Triplet margin loss with margin 1.0; other arguments keep PyTorch defaults.
triplet_loss = nn.TripletMarginLoss(margin=1.0)


def train_step(model, batch, batch_idx):
    """
    Compute the triplet loss for one interleaved triplet batch.

    Args:
        model: network mapping a batch of images to per-sample outputs.
        batch: ``(imgs, labels)`` pair whose rows are ordered as
            anchor, positive, negative, anchor, positive, negative, ...
        batch_idx: index of the batch within the epoch (currently unused).

    Returns:
        Scalar loss tensor produced by ``triplet_loss``.

    Raises:
        AssertionError: if positives do not share the anchor's label or
            negatives do (only when assertions are enabled).
    """
    imgs, labels = batch

    a_labels, p_labels, n_labels = split_triplet_tensor(labels)
    # Sanity-check that the sampler produced well-formed triplets.
    assert torch.all(a_labels == p_labels)
    assert torch.all(a_labels != n_labels)

    embeddings = model(imgs)
    # TripletMarginLoss measures distance along the last dimension only. If the
    # model returns feature maps (e.g. (B, C, 1, 1)), flatten them to (B, C*H*W)
    # so the loss compares whole embedding vectors; (B, D) outputs are untouched.
    if embeddings.dim() > 2:
        embeddings = embeddings.flatten(start_dim=1)

    a_embs, p_embs, n_embs = split_triplet_tensor(embeddings)
    return triplet_loss(a_embs, p_embs, n_embs)


def val_step(model, batch, batch_idx):
    """
    Placeholder validation step: computes nothing and returns None.

    TODO: implement (e.g. embed ``batch`` with ``model`` and return whatever a
    validation metric needs).
    """
    return None

In [9]:
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter


def run_train(model, optimizer, train_loader, val_loader=None,
              num_epochs=10):
    """
    Train ``model`` with ``train_step`` and optionally run ``val_step`` each epoch.

    The per-batch training loss is logged to TensorBoard under ``Loss/train``.
    Its global step advances by one per optimizer step.

    Args:
        model: module to train; moved to CUDA if available, else CPU.
        optimizer: optimizer over the trainable parameters of ``model``.
        train_loader: iterable of ``[imgs, labels]`` batches with a ``len``.
        val_loader: optional iterable of ``[imgs, labels]`` batches. When
            given, ``val_step(model, batch, batch_idx)`` runs on each batch
            under ``torch.no_grad()`` after every training epoch.
        num_epochs: number of passes over ``train_loader``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    writer = SummaryWriter()
    global_step = 0
    try:
        for epoch in tqdm(range(num_epochs), desc='Epoch'):
            # Training epoch
            model.train()
            for batch_idx, train_batch in tqdm(enumerate(train_loader),
                                               total=len(train_loader),
                                               leave=False, desc='Train batch'):
                train_batch[0] = train_batch[0].to(device)
                train_batch[1] = train_batch[1].to(device)
                loss = train_step(model, train_batch, batch_idx)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                # One point per optimizer step. Using ``epoch`` as the step would
                # stack every batch of an epoch on the same x-coordinate.
                writer.add_scalar("Loss/train", loss.item(), global_step)
                global_step += 1

            if val_loader is None:
                continue

            # Validation epoch
            model.eval()
            val_outputs = []  # TODO: reduce into a validation metric and log it.
            with torch.no_grad():
                for batch_idx, val_batch in tqdm(enumerate(val_loader),
                                                 leave=False, desc='Val batch'):
                    # Default collation yields a list, so move each tensor
                    # (a list has no ``.to``).
                    val_batch[0] = val_batch[0].to(device)
                    val_batch[1] = val_batch[1].to(device)
                    val_outputs.append(val_step(model, val_batch, batch_idx))
    finally:
        # Flush and close even if training is interrupted, so logged points survive.
        writer.flush()
        writer.close()

In [11]:
from lib.cnn_classifiers import get_cut_off_cnn
from torch.optim import SGD


model = get_cut_off_cnn('resnet50')

# Fine-tune only ``layer4`` (presumably the last ResNet stage): freeze every
# parameter, then re-enable gradients for ``layer4``. Its parameters are also
# the only ones handed to the optimizer.
model.requires_grad_(False)
model.layer4.requires_grad_(True)

optimizer = SGD(model.layer4.parameters(), lr=0.01)

# NOTE(review): val_loader is not passed, so no validation runs here.
run_train(model, optimizer, train_loader, num_epochs=50)

Epoch:   0%|          | 0/50 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

Train batch:   0%|          | 0/67 [00:00<?, ?it/s]

In [18]:
import tensorboard

In [19]:
# Register the TensorBoard IPython extension, which provides the %tensorboard magic.
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [22]:
# Show TensorBoard inline for ./runs, which is where SummaryWriter() without an
# explicit log_dir writes its event files (PyTorch default: runs/<datetime>_<host>).
%tensorboard --logdir runs