In [1]:
import torch 
from torch import nn, Tensor
import torch.functional as F

In [3]:
# Embedding table with 1000 rows of width 128, plus one random row index
# each for the anchor, positive and negative examples.
embedding = nn.Embedding(num_embeddings=1000, embedding_dim=128)
anchor_ids, positive_ids, negative_ids = (
    torch.randint(low=0, high=1000, size=(1,)) for _ in range(3)
)


In [6]:
# Three independent standard-normal batches of identical shape, used as the
# anchor, positive and negative inputs of the triplet loss.
anchor, positive, negative = (torch.randn(32, 60, 24) for _ in range(3))

In [7]:
# Display the shape of the anchor batch.
anchor.shape

torch.Size([32, 60, 24])

In [8]:
# Triplet margin loss that measures distance with nn.PairwiseDistance,
# evaluated on the anchor / positive / negative batches.
triplet_loss = nn.TripletMarginWithDistanceLoss(
    distance_function=nn.PairwiseDistance(),
)
output = triplet_loss(anchor=anchor, positive=positive, negative=negative)

In [10]:
# Display the value returned by the triplet loss.
output

tensor(1.1234)

---

In [2]:
# Random 3-element leaf tensor that tracks gradients; echo it for inspection.
input1 = torch.randn(3).requires_grad_()
input1

tensor([ 1.5075, -1.4304, -1.1908], requires_grad=True)

In [4]:
# Targets built from the signs of three standard-normal samples; echo them.
target = torch.sign(torch.randn(3))
target

tensor([-1., -1., -1.])

---

In [None]:
# Multi-label margin loss on a single sample with four class scores.
loss = nn.MultiLabelMarginLoss()
x = torch.tensor([[0.1, 0.2, 0.4, 0.8]], dtype=torch.float32)
# Target labels are read up to the first -1, so only labels 3 and 0 count
# as targets here; the entry after the -1 is not considered.
y = torch.tensor([[3, 0, -1, 1]], dtype=torch.long)
# Hand computation of the expected value:
# 0.25 * ((1-(0.1-0.2)) + (1-(0.1-0.4)) + (1-(0.8-0.2)) + (1-(0.8-0.4)))


In [None]:
# Evaluate the loss on the scores x and the target labels y.
loss(x, y)

---

In [5]:
# Padded-target setup for a CTC loss example.
T = 50      # length of every input sequence
C = 20      # class count, blank included
N = 16      # batch size
S = 30      # padded target length (longest target in the batch)
S_min = 10  # shortest target length, for demonstration purposes

# Random log-probabilities of shape (T, N, C), normalised over the class axis
# and re-created as a leaf tensor that requires gradients.
input = torch.log_softmax(torch.randn(T, N, C), dim=2).detach().requires_grad_()

# Random padded targets: 0 is the blank, so labels are drawn from 1..C-1.
target = torch.randint(1, C, (N, S), dtype=torch.long)

# Every input uses all T steps; each target length is drawn from [S_min, S).
input_lengths = torch.full((N,), T, dtype=torch.long)
target_lengths = torch.randint(S_min, S, (N,), dtype=torch.long)


In [6]:
# Display the shape of the log-probability tensor.
input.shape

torch.Size([50, 16, 20])

In [7]:
# Display the shape of the target tensor.
target.shape

torch.Size([16, 30])

In [8]:
# Display the shape of the input-length tensor.
input_lengths.shape

torch.Size([16])

In [None]:
# Evaluate the CTC loss on the current input/target tensors and backpropagate.
ctc_loss = nn.CTCLoss()
loss = ctc_loss(
    log_probs=input,
    targets=target,
    input_lengths=input_lengths,
    target_lengths=target_lengths,
)
loss.backward()
# Un-padded targets: all target sequences are concatenated into one 1-D tensor.
T = 50      # length of every input sequence
C = 20      # class count, blank included
N = 16      # batch size

# Random log-probabilities of shape (T, N, C) as a gradient-tracking leaf.
input = torch.log_softmax(torch.randn(T, N, C), dim=2).detach().requires_grad_()
input_lengths = torch.full((N,), T, dtype=torch.long)

# Per-sample target lengths in [1, T), then that many labels in total
# (0 is the blank, so labels come from 1..C-1).
target_lengths = torch.randint(1, T, (N,), dtype=torch.long)
target = torch.randint(1, C, (int(target_lengths.sum()),), dtype=torch.long)

ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()
# Targets are un-padded and unbatched (effectively N=1)
T = 50      # Input sequence length
C = 20      # Number of classes (including blank)
# Random unbatched log-probabilities of shape (T, C). The tensor is 2-D, so
# the class axis is dim 1; normalising over dim 2 would raise an IndexError.
input = torch.randn(T, C).log_softmax(1).detach().requires_grad_()
# The lengths for this unbatched example are 0-dim (scalar) tensors.
input_lengths = torch.tensor(T, dtype=torch.long)
# Random target length in [1, T) and that many labels (0 = blank, 1..C-1 = classes)
target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
target = torch.randint(low=1, high=C, size=(int(target_lengths),), dtype=torch.long)
ctc_loss = nn.CTCLoss()
loss = ctc_loss(input, target, input_lengths, target_lengths)
loss.backward()

In [None]:
# backward() only works when the loss depends on tensors that track
# gradients, so take gradient-enabled leaf views of the three inputs for
# this cell instead of relying on how they were created.
anchor_g, positive_g, negative_g = (
    t.detach().requires_grad_() for t in (anchor, positive, negative)
)

# Built-in Distance Function
triplet_loss = \
    nn.TripletMarginWithDistanceLoss(distance_function=nn.PairwiseDistance())
output = triplet_loss(anchor_g, positive_g, negative_g)
output.backward()


# Custom Distance Function
def l_infinity(x1, x2):
    """Return the largest absolute element-wise difference along dim 1.

    NOTE(review): the reduction axis is fixed to dim=1. For inputs with more
    than two dims, confirm this is the intended feature axis.
    """
    return torch.max(torch.abs(x1 - x2), dim=1).values


triplet_loss = (
    nn.TripletMarginWithDistanceLoss(distance_function=l_infinity, margin=1.5))
output = triplet_loss(anchor_g, positive_g, negative_g)
output.backward()

# Custom Distance Function (Lambda)
# cosine_similarity is provided by torch.nn.functional, so it is reached
# through `nn` here; torch.functional does not define it.
triplet_loss = (
    nn.TripletMarginWithDistanceLoss(
        distance_function=lambda x, y: 1.0 - nn.functional.cosine_similarity(x, y)))
output = triplet_loss(anchor_g, positive_g, negative_g)
output.backward()