In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(sci_mode=False)

In [2]:
input_tensor = torch.randn(5, 3, requires_grad=True)
target_tensor = torch.randn(5, 3)

target_tensor = F.softmax(target_tensor, dim=1)
print(input_tensor.shape)
print(input_tensor)
print(target_tensor.shape)
print(target_tensor)


torch.Size([5, 3])
tensor([[-0.8394,  1.6679,  0.5327],
        [ 0.5723,  0.3403, -1.8713],
        [ 0.7562,  0.2992,  0.8017],
        [ 0.0177, -0.0691, -0.6578],
        [-0.6957,  1.6073,  0.4324]], requires_grad=True)
torch.Size([5, 3])
tensor([[0.0891, 0.2255, 0.6854],
        [0.3820, 0.4924, 0.1256],
        [0.2391, 0.6474, 0.1135],
        [0.1500, 0.3117, 0.5384],
        [0.1093, 0.7942, 0.0965]])


In [39]:
# [batch, seq_len, hidden_size]

input_tensor = torch.tensor(
    [
        [
            [0.6677, 1.0659, -2.6843],
            [-1.7150, -0.6264, 1.6480],
            [0.1499, 0.9312, -0.8252],
            [-0.0703, 2.6022, 0.8705],
            [-0.2852, -1.4555, -0.5974],
        ],
        [
            [1.3257, -1.1151, 0.1241],
            [-3.5583, 0.1384, -0.6323],
            [-0.2360, -0.1520, -0.3012],
            [0.6702, 0.7576, 1.5766],
            [-0.8387, -0.0367, -0.5284],
        ],
    ]
)

target_tensor = torch.tensor(
    [
        [
            [0.3662, 0.3598, 0.2739],
            [0.5289, 0.2933, 0.1778],
            [0.8479, 0.0478, 0.1043],
            [-100, -100, -100],
            [-100, -100, -100],
        ],
        [
            [0.0752, 0.8602, 0.0647],
            [0.3488, 0.2728, 0.3784],
            [0.8237, 0.1216, 0.0547],
            [0.1242, 0.7966, 0.0792],
            [-100, -100, -100],
        ],
    ]
)

print("input_tensor", input_tensor.shape)
print(input_tensor)
print("target_tensor", target_tensor.shape)
print(target_tensor)

input_tensor = input_tensor.flatten(end_dim=1)
target_tensor = target_tensor.flatten(end_dim=1)


print("input_tensor", input_tensor.shape)
print(input_tensor)
print("target_tensor", target_tensor.shape)
print(target_tensor)

print("\n", "*" * 20 + "Masking" + "*" * 20, "\n")

mask = None
if mask is None:
    mask = ~torch.any(target_tensor == -100, dim=1)

print(mask)

input_tensor = input_tensor[mask]
target_tensor = target_tensor[mask]

input_tensor = F.softmax(input_tensor, dim=1)
print("input_tensor", input_tensor.shape)
print(input_tensor)
print("target_tensor", target_tensor.shape)
print(target_tensor)


kl_loss = nn.KLDivLoss(reduction="batchmean")

print("\n", "*" * 20 + "KL Div 1" + "*" * 20, "\n")
output_1 = kl_loss(torch.log(input_tensor), target_tensor)
print(output_1)

print("\n", "*" * 20 + "KL Div 2" + "*" * 20, "\n")
output_2 = kl_loss(torch.log(target_tensor), input_tensor)
print(output_2)


print(output_1 + output_2)

input_tensor torch.Size([2, 5, 3])
tensor([[[ 0.6677,  1.0659, -2.6843],
         [-1.7150, -0.6264,  1.6480],
         [ 0.1499,  0.9312, -0.8252],
         [-0.0703,  2.6022,  0.8705],
         [-0.2852, -1.4555, -0.5974]],

        [[ 1.3257, -1.1151,  0.1241],
         [-3.5583,  0.1384, -0.6323],
         [-0.2360, -0.1520, -0.3012],
         [ 0.6702,  0.7576,  1.5766],
         [-0.8387, -0.0367, -0.5284]]])
target_tensor torch.Size([2, 5, 3])
tensor([[[     0.3662,      0.3598,      0.2739],
         [     0.5289,      0.2933,      0.1778],
         [     0.8479,      0.0478,      0.1043],
         [  -100.0000,   -100.0000,   -100.0000],
         [  -100.0000,   -100.0000,   -100.0000]],

        [[     0.0752,      0.8602,      0.0647],
         [     0.3488,      0.2728,      0.3784],
         [     0.8237,      0.1216,      0.0547],
         [     0.1242,      0.7966,      0.0792],
         [  -100.0000,   -100.0000,   -100.0000]]])
input_tensor torch.Size([10, 3])
tensor([

tensor(1.0687)
tensor(0.9326)


In [40]:
logits = torch.tensor(
    [
        [
            [0.6677, 1.0659, -2.6843],
            [-1.7150, -0.6264, 1.6480],
            [0.1499, 0.9312, -0.8252],
            [-0.0703, 2.6022, 0.8705],
            [-0.2852, -1.4555, -0.5974],
        ],
        [
            [1.3257, -1.1151, 0.1241],
            [-3.5583, 0.1384, -0.6323],
            [-0.2360, -0.1520, -0.3012],
            [0.6702, 0.7576, 1.5766],
            [-0.8387, -0.0367, -0.5284],
        ],
    ]
)

labels = torch.tensor(
    [
        [
            [0.3662, 0.3598, 0.2739],
            [0.5289, 0.2933, 0.1778],
            [0.8479, 0.0478, 0.1043],
            [-100, -100, -100],
            [-100, -100, -100],
        ],
        [
            [0.0752, 0.8602, 0.0647],
            [0.3488, 0.2728, 0.3784],
            [0.8237, 0.1216, 0.0547],
            [0.1242, 0.7966, 0.0792],
            [-100, -100, -100],
        ],
    ]
)

print(logits)
print(labels)
print()

logits = F.softmax(logits, dim=2)

tensor_truth = labels.flatten(end_dim=1)
tensor_pred = logits.flatten(end_dim=1)

# tensor_pred = F.softmax(tensor_pred, dim=1)

mask = None
if mask is None:
    mask = ~torch.any(tensor_truth == -100, dim=1)

tensor_pred = tensor_pred[mask]
tensor_truth = tensor_truth[mask]

print(tensor_pred)
print(tensor_truth)
print()

loss_fct = nn.KLDivLoss(reduction="batchmean")
loss_1 = loss_fct(torch.log(tensor_pred), tensor_truth)
loss_2 = loss_fct(torch.log(tensor_truth), tensor_pred)

loss = loss_1 + loss_2

print(loss_1)
print(loss_2)
print(loss)


tensor([[[ 0.6677,  1.0659, -2.6843],
         [-1.7150, -0.6264,  1.6480],
         [ 0.1499,  0.9312, -0.8252],
         [-0.0703,  2.6022,  0.8705],
         [-0.2852, -1.4555, -0.5974]],

        [[ 1.3257, -1.1151,  0.1241],
         [-3.5583,  0.1384, -0.6323],
         [-0.2360, -0.1520, -0.3012],
         [ 0.6702,  0.7576,  1.5766],
         [-0.8387, -0.0367, -0.5284]]])
tensor([[[     0.3662,      0.3598,      0.2739],
         [     0.5289,      0.2933,      0.1778],
         [     0.8479,      0.0478,      0.1043],
         [  -100.0000,   -100.0000,   -100.0000],
         [  -100.0000,   -100.0000,   -100.0000]],

        [[     0.0752,      0.8602,      0.0647],
         [     0.3488,      0.2728,      0.3784],
         [     0.8237,      0.1216,      0.0547],
         [     0.1242,      0.7966,      0.0792],
         [  -100.0000,   -100.0000,   -100.0000]]])

tensor([[0.3962, 0.5900, 0.0139],
        [0.0304, 0.0904, 0.8791],
        [0.2808, 0.6133, 0.1059],
        [

In [None]:
# collated_dataset = data_collator(features=dataset)

# collated_dataset["input_ids"]

# attention_mask = collated_dataset["attention_mask"]
# mask = attention_mask.unsqueeze(1).float()

# print(mask.shape)
# print(mask)
# print(*["-" * 100] * 10, sep="\n")
# print(attention_mask.shape)
# print(attention_mask)

# mask * torch.randn(8, 360, 1024)