Catalyst supports models with multiple input arguments and multiple outputs.
Suppose that we need to train a siamese network. First, we need to create a dataset that yields pairs of images together with a same-class indicator, which can later be used in a contrastive loss.
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
class SiameseDataset(Dataset):
    """Dataset yielding image pairs with a same-class indicator.

    Each item is ``(original_image, pair_image, label)`` where ``label`` is a
    one-element float tensor holding ``1.0`` when the pair was drawn from the
    same class and ``0.0`` otherwise.

    Args:
        images: collection of images (or whatever is needed to load them);
            its length defines the dataset size.
        labels: class labels; presumably aligned with ``images`` by index.
    """

    def __init__(self, images, labels):
        self.images = images
        self.labels = labels

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Build one (image, paired image, same-class flag) sample.

        Args:
            idx: index of the anchor image.

        Returns:
            Tuple ``(original_image, pair_image, label)``; ``label`` is a
            float32 tensor of shape ``[1]``.
        """
        original_image = ...  # load image using `idx`

        # Coin flip: pair the anchor with the same class or with another one.
        is_same = np.random.uniform() >= 0.5
        if is_same:
            pair_image = ...  # load image from the same class and with index != `idx`
        else:
            pair_image = ...  # load image from another class

        # Explicit float conversion keeps the label dtype independent of
        # how NumPy represents the boolean.
        label = torch.tensor([float(is_same)], dtype=torch.float32)
        return original_image, pair_image, label
        # OR
        # return {"first": original_image, "second": pair_image, "labels": label}
Do not forget about the contrastive loss:
import torch.nn as nn
class ContrastiveLoss(nn.Module):
    """Margin-based contrastive loss over pairwise distances.

    Uses the label convention of ``SiameseDataset``: ``1`` marks a pair from
    the same class, ``0`` a pair from different classes. Same-class pairs are
    penalised by their squared distance; different-class pairs are penalised
    only while they are closer than ``margin``.

    Args:
        margin: distance beyond which different-class pairs stop
            contributing to the loss.
    """

    def __init__(self, margin=1.0):
        super().__init__()
        self.margin = margin

    def forward(self, l2_distance, labels):
        """Compute the mean contrastive loss for a batch.

        Args:
            l2_distance (torch.Tensor): distances between paired embeddings,
                shape ``[B]``.
            labels (torch.Tensor): same-class indicators with ``B`` elements
                (for example ``[B]`` or ``[B, 1]``).

        Returns:
            Scalar tensor with the loss averaged over the batch.
        """
        # Align labels with the distances; a [B, 1] label tensor would
        # otherwise broadcast against [B] into a [B, B] matrix.
        labels = labels.reshape(l2_distance.shape).to(l2_distance.dtype)

        same_class_term = labels * l2_distance.pow(2)
        margin_gap = (self.margin - l2_distance).clamp(min=0.0)
        other_class_term = (1.0 - labels) * margin_gap.pow(2)

        loss = (same_class_term + other_class_term).mean()
        return loss
Suppose you have a model that accepts two tensors, first and second, and returns the embeddings for both input batches together with the distance between them:
import torch.nn as nn
class SiameseMNIST(nn.Module):
    """Siamese network sharing one MLP encoder between two inputs.

    Args:
        in_features: number of input features per sample after flattening.
        out_features: size of the produced embedding.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(in_features, in_features * 2),
            nn.ReLU(),
            nn.Linear(in_features * 2, out_features),
        )

    def get_embeddings(self, batch):
        """Generate embeddings for a given batch of images.

        Args:
            batch (torch.Tensor): batch with images,
                expected shapes - [B, C, H, W] (flattened internally
                to [B, C * H * W]) or [B, in_features].

        Returns:
            embeddings (torch.Tensor) for a given batch of images,
            output shapes - [B, out_features].
        """
        # nn.Linear only transforms the last dimension, so image-shaped
        # input has to be flattened to get one embedding per sample.
        if batch.dim() > 2:
            batch = batch.flatten(start_dim=1)
        return self.layers(batch)

    def forward(self, first, second):
        """Forward pass.

        Args:
            first (torch.Tensor): batch with images,
                expected shapes - [B, C, H, W].
            second (torch.Tensor): batch with images,
                expected shapes - [B, C, H, W].

        Returns:
            embeddings (torch.Tensor) for a first batch of images,
                output shapes - [B, out_features]
            embeddings (torch.Tensor) for a second batch of images,
                output shapes - [B, out_features]
            L2 distance (torch.Tensor) between first and second image embeddings,
                output shapes - [B,]
        """
        first = self.get_embeddings(first)
        second = self.get_embeddings(second)
        # Euclidean distance between the paired embeddings, one per sample.
        distance = (first - second).pow(2).sum(dim=1).sqrt()
        return first, second, distance
Then, using the Python API:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from catalyst import dl
# Data: one loader is reused for both the "train" and "valid" stages here.
dataset = SiameseDataset(...)
loader = DataLoader(dataset, num_workers=1, batch_size=32)
loaders = dict(train=loader, valid=loader)

# Model, loss and optimizer.
model = SiameseMNIST(...)
criterion = ContrastiveLoss(margin=1.1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Key names for the runner: inputs should match the argument names of
# `SiameseMNIST.forward`; the three output names line up with the three
# values that `forward` returns.
runner = dl.SupervisedRunner(
    input_key=["first", "second"],
    output_key=["first_emb", "second_emb", "l2_distance"],
    target_key=["labels"],
    loss_key="loss",
)

runner.train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    loaders=loaders,
    num_epochs=10,
    # callbacks=[],
    logdir="./siamese_logs",
    # Validation settings: track "loss" on the "valid" loader, lower is better.
    valid_loader="valid",
    valid_metric="loss",
    minimize_valid_metric=True,
    load_best_on_end=True,
    verbose=True,
)