In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [102]:
random_tensor = torch.rand(16, 12, 12000)
batch, seq, vocab = random_tensor.size(0), random_tensor.size(1), random_tensor.size(2)

In [103]:
labels = torch.randint(0, 12000, (16, 12))

# Custom NLL loss calculation

In [136]:
torch.manual_seed(42)
total_loss = 0
total_tokens = 0
for _ in range(5):
    random_tensor = torch.rand(16, 12, 12000)
    batch, seq, vocab = random_tensor.size(0), random_tensor.size(1), random_tensor.size(2)
    labels = torch.randint(0, 12000, (16, 12))
    log_probs = F.log_softmax(random_tensor, 2)
    labels = labels.unsqueeze(2)
    log_probs_gathered = torch.gather(log_probs, 2, labels).squeeze(2).sum(1)
    nll_loss = -log_probs_gathered.mean()
    # Would backpropagate here: nll_loss.backward()
    
    total_loss += nll_loss.item()*batch
    total_tokens += batch*seq 
avg_loss = total_loss/total_tokens
print(f"Average loss: {avg_loss}")
ppl = torch.exp(torch.tensor(avg_loss))
print(f"PPL: {ppl.item()}")

Average loss: 9.441426595052084
PPL: 12599.6748046875


# PyTorch NLLLoss()

In [142]:
torch.manual_seed(42)
criterion = nn.NLLLoss(reduction='sum')
total_loss = 0
total_tokens = 0
for _ in range(5):
    random_tensor = torch.rand(16, 12, 12000)
    batch, seq, vocab = random_tensor.size(0), random_tensor.size(1), random_tensor.size(2)
    labels = torch.randint(0, 12000, (16, 12))
    log_probs = F.log_softmax(random_tensor, 2)
    log_probs = log_probs.view(batch*seq, -1)
    labels = labels.view(batch*seq)
    pt_loss = criterion(log_probs, labels)
    print(pt_loss)
    # Would backpropagate here: pt_loss.backward()
    
    total_loss += pt_loss.item()#*batch#*seq
    total_tokens += batch*seq
avg_loss = total_loss/total_tokens
print(f"Average loss: {avg_loss}")
ppl = torch.exp(torch.tensor(avg_loss))
print(f"PPL: {ppl.item()}")

tensor(1814.5616)
tensor(1812.5277)
tensor(1814.5669)
tensor(1809.8182)
tensor(1812.2954)
Average loss: 9.441426976521809
PPL: 12599.6865234375


In [125]:
random_tensor = torch.rand(16, 12, 12000)
batch, seq, vocab = random_tensor.size(0), random_tensor.size(1), random_tensor.size(2)
labels = torch.randint(0, 12000, (16, 12))
log_probs = F.log_softmax(random_tensor, 2)
log_probs = log_probs.view(batch*seq, -1)
labels = labels.view(batch*seq)
pt_loss = criterion(log_probs, labels)

In [127]:
pt_loss*seq

tensor(113.1933)

In [16]:
mask = torch.tril(torch.ones(10, 10, dtype=torch.bool, device="cpu"), diagonal=0)

In [18]:
mask = mask.long()

In [19]:
mask

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])