Multiple input and output keys

Catalyst supports models with multiple input arguments and multiple outputs.

Suppose we need to train a siamese network. First, we need to create a dataset that yields pairs of images together with a same-class indicator, which can later be used in the contrastive loss.

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset

class SiameseDataset(Dataset):
    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        original_image = ... # load image using `idx`
        is_same = np.random.uniform() >= 0.5  # use same or opposite class
        if is_same:
            pair_image = ... # load image from the same class and with index != `idx`
        else:
            pair_image = ... # load image from another class
        label = torch.FloatTensor([is_same])
        return original_image, pair_image, label
        # OR
        # return {"first": original_image, "second": pair_image, "labels": label}

Do not forget about the contrastive loss:

import torch
import torch.nn as nn

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, l2_distance, labels):
        # standard contrastive loss formulation (Hadsell et al., 2006):
        # same-class pairs are pulled together, different-class pairs
        # are pushed at least `margin` apart
        labels = labels.view(-1)  # flatten [B, 1] targets to [B]
        positive_term = labels * l2_distance.pow(2)
        negative_term = (1 - labels) * torch.clamp(self.margin - l2_distance, min=0).pow(2)
        loss = torch.mean(positive_term + negative_term)
        return loss
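
As a quick sanity check of this formulation (a sketch, not part of the original recipe): pairs that match their label contribute little, while mismatched pairs dominate the loss:

import torch

criterion = ContrastiveLoss(margin=1.0)
l2_distance = torch.tensor([0.1, 0.1, 2.0, 2.0])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])  # 1 = same class, 0 = different class
print(criterion(l2_distance, labels))  # dominated by the (0.1, different) and (2.0, same) pairs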

Suppose you have a model that accepts two tensors, first and second, and returns embeddings for each input batch along with the distance between them:

import torch
import torch.nn as nn

class SiameseMNIST(nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, in_features * 2),
            nn.ReLU(),
            nn.Linear(in_features * 2, out_features),
        )

    def get_embeddings(self, batch):
        """Generate embeddings for a given batch of images.

        Args:
            batch (torch.Tensor): batch with images,
                expected shapes - [B, C, H, W].

        Returns:
            embeddings (torch.Tensor) for a given batch of images,
                output shapes - [B, out_features].
        """
        # flatten images [B, C, H, W] -> [B, C * H * W]; in_features must equal C * H * W
        return self.layers(batch.flatten(start_dim=1))


    def forward(self, first, second):
        """Forward pass.

        Args:
            first (torch.Tensor): batch with images,
                expected shapes - [B, C, H, W].
            second (torch.Tensor): batch with images,
                expected shapes - [B, C, H, W].

        Returns:
            embeddings (torch.Tensor) for a first batch of images,
                output shapes - [B, out_features]
            embeddings (torch.Tensor) for a second batch of images,
                output shapes - [B, out_features]
            L2 (Euclidean) distance (torch.Tensor) between first and second image embeddings,
                output shapes - [B,]
        """
        first = self.get_embeddings(first)
        second = self.get_embeddings(second)
        distance = torch.sqrt(torch.sum(torch.pow(first - second, 2), 1))
        return first, second, distance
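
A quick shape check of the model (a sketch; the 28x28 grayscale input and in_features=28 * 28 are assumptions for illustration, not part of the recipe):

import torch

model = SiameseMNIST(in_features=28 * 28, out_features=16)
first = torch.rand(32, 1, 28, 28)    # hypothetical batch of grayscale images
second = torch.rand(32, 1, 28, 28)
first_emb, second_emb, l2_distance = model(first, second)
print(first_emb.shape, second_emb.shape, l2_distance.shape)  # [32, 16], [32, 16], [32]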

And then, with the Python API:

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl

dataset = SiameseDataset(...)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

model = SiameseMNIST(...)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = ContrastiveLoss(margin=1.1)

runner = dl.SupervisedRunner(
    input_key=["first", "second"],  # model inputs; should match the arguments of the model's forward method
    output_key=["first_emb", "second_emb", "l2_distance"],  # names under which the three model outputs are stored
    target_key=["labels"],  # batch key with the same-class indicator
    loss_key="loss",
)
runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=10,
    # callbacks=[],
    logdir="./siamese_logs",
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
)
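
After training, the model can be used for plain PyTorch inference. A minimal sketch with placeholder tensors (the shapes and the thresholding rule are assumptions continuing the 28x28 example above, not part of the recipe):

import torch

model.eval()
with torch.no_grad():
    query = torch.rand(8, 1, 28, 28)     # hypothetical image batches
    gallery = torch.rand(8, 1, 28, 28)
    _, _, l2_distance = model(query, gallery)
    same_class = l2_distance < criterion.margin  # simple thresholding heuristic
print(same_class)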