### Reference

In [58]:
import torch

In [59]:
def compute_ngram_loss(probs, transition, tgt_tokens, ngrams_order=2):
    """Return the negated expected clipped n-gram precision for each batch element.

    Args:
        probs: batch_size x num_vertices x vocab_size token probabilities
            (linear space) emitted at each vertex.
        transition: batch_size x num_vertices x num_vertices transition
            probabilities (linear space).
        tgt_tokens: batch_size x tgt_len target token ids. The targets are
            assumed to carry no padding, and every sequence must produce the
            same number of distinct n-grams because they are packed into one
            dense tensor (ragged lists make the tensor construction fail).
        ngrams_order: length of the n-grams to match (defaults to 2).

    Returns:
        Tensor of shape (batch_size,) holding minus the clipped precision.
    """
    device = tgt_tokens.device

    # Collect the distinct n-grams of every target sequence, in order of first
    # appearance, together with how often each one occurs in that sequence.
    ngrams_per_seq = []
    counts_per_seq = []
    for seq in tgt_tokens.tolist():
        occurrences = {}
        for start in range(len(seq) - ngrams_order + 1):
            key = tuple(seq[start:start + ngrams_order])
            occurrences[key] = occurrences.get(key, 0) + 1
        ngrams_per_seq.append([list(key) for key in occurrences])
        counts_per_seq.append(list(occurrences.values()))

    # bsz, number of ngram, length of ngram
    ngrams_tensor_bsz = torch.tensor(ngrams_per_seq, dtype=torch.long, device=device)
    # bsz, number of ngram
    ngrams_max_count_bsz = torch.tensor(counts_per_seq, device=device)

    # Probability of passing through each vertex, starting from vertex 0.
    arrival_prob = torch.ones(transition.size(0), 1).to(transition)
    for i in range(1, transition.size(-1)):
        incoming = torch.mul(arrival_prob[:, 0:i], transition[:, 0:i, i])
        arrival_prob = torch.cat([arrival_prob, incoming.sum(dim=-1, keepdim=True)], dim=-1)

    # Expected number of n-grams (paths of ngrams_order vertices) in the graph.
    expected_tol_num_of_ngrams = arrival_prob.unsqueeze(1)
    for _ in range(ngrams_order - 1):
        expected_tol_num_of_ngrams = torch.bmm(expected_tol_num_of_ngrams, transition)
    expected_tol_num_of_ngrams = expected_tol_num_of_ngrams.sum(dim=-1).sum(dim=-1)

    num_vertices = probs.size(-2)

    def token_probs(position):
        # Probability of each n-gram's token at `position` at every vertex:
        # bsz, number of ngram, prelen. The output shape is fixed explicitly;
        # a bare .squeeze() here would also drop the n-gram axis when there is
        # a single distinct n-gram and make the later product broadcast across
        # batch elements.
        tokens = ngrams_tensor_bsz[:, :, position]
        index = tokens.unsqueeze(1).expand(-1, num_vertices, -1)
        return probs.gather(-1, index).transpose(1, 2)

    # Expected number of times each target n-gram is produced by the graph.
    expected_matched_num_of_ngrams = torch.mul(arrival_prob.unsqueeze(1), token_probs(0))
    for position in range(1, ngrams_order):
        expected_matched_num_of_ngrams = torch.mul(
            torch.bmm(expected_matched_num_of_ngrams, transition),
            token_probs(position),
        )
    expected_matched_num_of_ngrams = expected_matched_num_of_ngrams.sum(dim=-1)

    # Clip every expected match count by the n-gram's count in the target.
    cutted_expected_matched_num_of_ngrams = torch.min(
        expected_matched_num_of_ngrams,
        ngrams_max_count_bsz.to(expected_matched_num_of_ngrams),
    ).sum(dim=-1)

    cutted_precision = cutted_expected_matched_num_of_ngrams / expected_tol_num_of_ngrams

    return -cutted_precision

### Other Implementation

In [60]:
from typing import Tuple, Union, List
from torch import Tensor
from collections import defaultdict

In [61]:
def find_ngrams(target_seqs, n, as_tensor=True) -> Union[Tuple[Tensor, Tensor], Tuple[List, List]]:
    """
    Given a 2D tensor of target sequences, and n-gram order, calculate
    which n-grams are present in the target sequences as well as the
    number of times they occur. We assume that the target sequences
    have no padding.

    N-grams are listed per sequence in order of first appearance. When
    as_tensor is true, every sequence must contain the same number of
    distinct n-grams, since the per-sequence lists are packed into dense
    tensors; the tensors are created on the device of target_seqs.

    @param target_seqs: 2D tensor of target sequences
    @param n: n-gram order
    @param as_tensor: whether to return information as tensors or lists

    @return: n-grams present in the target sequences and their counts
    """
    _, seq_len = target_seqs.shape
    assert seq_len >= n, "Sequence length must be greater than or equal to n-gram order"

    ngrams = []
    counts = []

    # Convert to Python lists once instead of slicing the tensor per window.
    for seq in target_seqs.tolist():
        ngram_dict = defaultdict(int)
        for i in range(seq_len - n + 1):
            ngram_dict[tuple(seq[i:i + n])] += 1

        ngrams.append(list(ngram_dict.keys()))
        counts.append(list(ngram_dict.values()))

    if as_tensor:
        # Keep the results on the same device as the input so they can be
        # used directly as indices into tensors living on that device.
        device = getattr(target_seqs, "device", None)
        ngrams = torch.tensor(ngrams, dtype=torch.long, device=device)
        counts = torch.tensor(counts, dtype=torch.long, device=device)

    return ngrams, counts

In [62]:
def passing_probability(transitions: Tensor, return_log: bool = False) -> Tensor:
    """
    Given a tensor of transition probabilities, calculate the probability
    of passing through each vertex in the graph, we'll assume that transitions
    is log-probabilities. If return_log is true, we return the passing
    probabilities in log-space otherwise we return them in linear space.

    Vertex 0 is the start vertex and is passed with probability 1.

    @param transitions: tensor of transition log-probabilities with shape
        (batch_size, num_vertices, num_vertices)
    @param return_log: whether to return log-probabilities instead of
        linear-space probabilities
    @return: tensor of passing probabilities with shape
        (batch_size, num_vertices)
    """
    batch_size, num_vertices, _ = transitions.shape

    # Accumulate in the dtype of the input so that, e.g., float64 transitions
    # are not silently truncated to the default float32.
    dtype = transitions.dtype if transitions.is_floating_point() else torch.get_default_dtype()

    # log(0) = -inf for every vertex that has not been reached yet,
    # log(1) = 0 for the start vertex.
    probs = torch.full(
        (batch_size, num_vertices),
        float("-inf"),
        dtype=dtype,
        device=transitions.device,
    )
    probs[:, 0] = 0.0

    for i in range(1, num_vertices):
        # Column i holds the log-probabilities of moving into vertex i.
        transition_column = transitions[:, :, i]
        probs[:, i] = torch.logsumexp(probs + transition_column, dim=-1)

    if return_log:
        return probs
    return torch.exp(probs)

In [63]:
def log_bmm(log_A, log_B):
    """
    Batched matrix product evaluated entirely in log space.

    Args:
        log_A: log(A), a tensor of shape (b, m, n).
        log_B: log(B), a tensor of shape (b, n, p).

    Returns:
        log(A @ B), a tensor of shape (b, m, p).
    """
    # Both operands must be 3-D; unpacking the shapes enforces that.
    _batch, _rows, _inner = log_A.shape
    _, _, _cols = log_B.shape

    # Products become sums in log space: broadcast (b, m, n, 1) against
    # (b, 1, n, p) to get every pairwise term, shape (b, m, n, p).
    pairwise_terms = log_A.unsqueeze(3) + log_B.unsqueeze(1)

    # The sum over the shared inner axis becomes a log-sum-exp over axis 2.
    return torch.logsumexp(pairwise_terms, dim=2)

In [64]:
def compute_ngram_loss2(probs, transition, tgt_tokens, ngrams_order=2):
    """Return the negated expected clipped n-gram precision, computed in log space.

    Args:
        probs: batch_size x num_vertices x vocab_size token probabilities
            (linear space) emitted at each vertex.
        transition: batch_size x num_vertices x num_vertices transition
            probabilities (linear space).
        tgt_tokens: batch_size x tgt_len target token ids. The targets are
            assumed to carry no padding (all the same length).
        ngrams_order: length of the n-grams to match (defaults to 2).

    Returns:
        Tensor of shape (batch_size,) holding minus the clipped precision.
    """
    ngrams_tensor_bsz, ngrams_max_count_bsz = find_ngrams(tgt_tokens, ngrams_order)
    # Make sure the indices and counts live next to the probabilities.
    ngrams_tensor_bsz = ngrams_tensor_bsz.to(probs.device)  # bsz, number of ngram, length of ngram
    log_max_count = ngrams_max_count_bsz.to(probs).log()  # bsz, number of ngram

    # Take the log of the transitions once and stay in log space until the
    # final result; exp()/log() round-trips would reintroduce the underflow
    # that the log-space formulation is meant to avoid.
    # NOTE(review): zero transitions map to -inf here; gradients with respect
    # to those entries may not be finite - confirm before training with this.
    log_transition = transition.log()
    num_vertices = probs.size(-2)

    def log_token_probs(position):
        # Log-probability of each n-gram's token at `position` at every
        # vertex: bsz, number of ngram, prelen. The output shape is fixed
        # explicitly; a bare .squeeze() would also drop the n-gram axis when
        # there is a single distinct n-gram and break the batch alignment.
        tokens = ngrams_tensor_bsz[:, :, position]
        index = tokens.unsqueeze(1).expand(-1, num_vertices, -1)
        return probs.gather(-1, index).transpose(1, 2).log()

    # bsz, 1, prelen
    log_arrival = passing_probability(log_transition, return_log=True).unsqueeze(1)

    # Log of the expected number of n-grams in the graph.
    log_total = log_arrival
    for _ in range(ngrams_order - 1):
        log_total = log_bmm(log_total, log_transition)
    log_total = torch.logsumexp(log_total, dim=-1)
    log_total = torch.logsumexp(log_total, dim=-1)  # bsz

    # Log of the expected number of matches of each target n-gram.
    log_matched = log_arrival + log_token_probs(0)
    for position in range(1, ngrams_order):
        log_matched = log_bmm(log_matched, log_transition) + log_token_probs(position)
    log_matched = torch.logsumexp(log_matched, dim=-1)  # bsz, number of ngram

    # log is monotonic, so clipping the logs equals the log of the clipped values.
    log_clipped = torch.logsumexp(torch.min(log_matched, log_max_count), dim=-1)

    cutted_precision = torch.exp(log_clipped - log_total)

    return -cutted_precision

### Test Cases

#### Test Case 1

In [65]:
# Toy emission scores: 2 sequences x 5 vertices x 4 vocabulary entries.
emissions = torch.tensor([
    # these are not valid probabilities, but it is just for testing
    [
        [0.1, 0.2, 0.3, 0.4],
        [0.5, 0.6, 0.7, 0.8],
        [-0.1, -0.2, -0.3, -0.4],
        [-0.5, -0.6, -0.7, -0.8],
        [1.0, 0.0, -1.0, 0.9]
    ],
    [
        [2.1, 2.2, 2.3, 2.4],
        [2.5, 2.6, 2.7, 2.8],
        [-2.1, -2.2, -2.3, -2.4],
        [-2.5, -2.6, -2.7, -2.8],
        [2.0, 3.0, -2.0, 2.9]
    ]
])
# Read the sizes back from the literal so the later cells stay in sync with it.
batch_size, num_vertices, vocab_size = emissions.shape

In [66]:
# Random transition logits; no seed is set, so the values differ between runs.
transitions = torch.randn(batch_size, num_vertices, num_vertices)

In [67]:
# Lower triangle including the diagonal (diagonal=0): the entries to disable.
mask = torch.tril(torch.ones(num_vertices, num_vertices), diagonal=0)

In [68]:
# Set the masked logits to -inf so only entries strictly above the diagonal stay finite.
transitions[mask.expand_as(transitions) == 1] = float('-inf')

In [69]:
# Row-wise softmax; the last row is entirely -inf and therefore turns into NaN.
transitions = torch.softmax(transitions, dim=-1)

In [70]:
# Replace the NaN entries produced by all -inf rows with zero probability.
transitions[transitions.isnan()] = 0.0

In [71]:
# Make the raw emission scores non-negative before they are normalised below.
emissions = torch.abs(emissions)
# transitions come from a softmax with NaNs replaced by 0.0, so they are
# already non-negative at this point and abs leaves them unchanged.
transitions = torch.abs(transitions)

In [72]:
# Normalise each vertex's scores into a distribution over the vocabulary.
emissions = torch.softmax(emissions, dim=-1)

In [73]:
# Two target sequences of length 3, one per batch element.
tgt_tokens = torch.tensor([
    [0, 3, 1],
    [0, 2, 1]
])

In [74]:
# Loss from the reference implementation.
ref = compute_ngram_loss(emissions, transitions, tgt_tokens)

In [75]:
# Loss from the log-space implementation, on the same inputs.
ours = compute_ngram_loss2(emissions, transitions, tgt_tokens)

In [76]:
# Show both results side by side for comparison.
ref, ours

(tensor([-0.1139, -0.1227]), tensor([-0.0315, -0.0327]))