Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] faq docs #1202

Merged
merged 8 commits into from May 17, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
145 changes: 145 additions & 0 deletions docs/faq/multiple_keys.rst
@@ -0,0 +1,145 @@
Multiple input and output keys
==============================================================================

Catalyst supports models with multiple input arguments and multiple outputs.

Suppose that we need to train a siamese network.
Firstly, need to create a dataset which will yield pairs of images and same class indicator
which later can be used in contrastive loss.


.. code-block:: python

import cv2
import numpy as np
from torch.utils.data import Dataset

class SiameseDataset(Dataset):
def __init__(self, images, labels):
self.images = images
self.labels = labels

def __len__(self):
return len(self.image)

def __getitem__(self, idx):
original_image = ... # load image using `idx`
is_same = np.random.uniform() >= 0.5 # use same or opposite class
if is_same:
pair_image = ... # load image from the same class and with index != `idx`
else:
pair_image = ... # load image from another class
label = torch.FloatTensor([is_same])
return original_image, pair_image, label
# OR
# return {"first": original_image, "second": pair_image, "labels": label}


Do not forget about contrastive loss:

.. code-block:: python

import torch.nn as nn

class ContrastiveLoss(nn.Module):
def __init__(self, margin=1.0):
super().__init__()
self.margin = margin

def forward(self, l2_distance, labels):
# ...
return loss



Suppose you have a model which accepts two tensors - `first` and `second`
and returns embeddings for an input batches and distance between them:

.. code-block:: python

import torch.nn as nn

class SiameseMNIST(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.layers = nn.Sequential(
nn.Linear(in_features, in_features * 2),
nn.ReLU(),
nn.Linear(in_features * 2, out_features),
)
self.

def get_embeddings(self, batch):
"""Generate embeddings for a given batch of images.

Args:
batch (torch.Tensor): batch with images,
expected shapes - [B, C, H, W].

Returns:
embeddings (torch.Tensor) for a given batch of images,
output shapes - [B, out_features].
"""
return self.layers(batch)


def forward(self, first, second):
"""Forward pass.

Args:
first (torch.Tensor): batch with images,
expected shapes - [B, C, H, W].
second (torch.Tensor): batch with images,
expected shapes - [B, C, H, W].

Returns:
embeddings (torch.Tensor) for a first batch of images,
output shapes - [B, out_features]
embeddings (torch.Tensor) for a second batch of images,
output shapes - [B, out_features]
absolute distance (torch.Tensor) between first and second image embeddings,
output shapes - [B,]
"""
first = self.get_embeddings(first)
second = self.get_embeddings(second)
difference = torch.sqrt(torch.sum(torch.pow(first - second, 2), 1))
return first, second, distance


And then for python API:

.. code-block:: python

import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl

dataset = SiameseDataset(...)
loader = DataLoader(dataset, batch_size=32, num_workers=1)
loaders = {"train": loader, "valid": loader}

model = SiameseMNIST(...)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = ContrastiveLoss(margin=1.1)

runner = dl.SupervisedRunner(
input_key=["first", "second"], # specify model inputs, should be the same as in forward method
output_key=["first_emb", "second_emb", "l2_distance"],
target_key=["labels"],
loss_key="loss",
)
runner.train(
model=model,
criterion=criterion,
optimizer=optimizer,
loaders=loaders,
num_epochs=10,
# callbacks=[],
logdir="./siamese_logs",
valid_loader="valid",
valid_metric="loss",
minimize_valid_metric=True,
verbose=True,
load_best_on_end=True,
)
1 change: 1 addition & 0 deletions docs/index.rst
Expand Up @@ -227,6 +227,7 @@ Indices and tables
faq/ddp

faq/multi_components
faq/multiple_keys
faq/early_stopping
faq/checkpointing
faq/debugging
Expand Down