In [1]:
import torch
from torch import Tensor

In [2]:
# Target token ids, one per position; each id is later used as a column index
# into emission_matrix, so it must be smaller than vocab_size.
target_seq = torch.tensor([1,3,2,4,0])

In [3]:
# Size of the toy graph (rows of emission_matrix) and of the token vocabulary
# (columns of emission_matrix). num_vertices must match the 10x10 transition matrix below.
num_vertices = 10
vocab_size = 5

In [4]:
# transition_matrix[u][v] is the probability of stepping from vertex u to vertex v
# (rows are the source vertex, columns the destination). Every non-zero entry has
# v > u, so the graph is acyclic; the last row is all zeros (no outgoing edges).
transition_matrix = torch.tensor(
    [
        [0, 0.9, 0.04, 0, 0.06, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0.5, 0.5, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
    ]
)
# Uniform emissions: every vertex emits each of the vocab_size tokens with
# probability 1 / vocab_size.
emission_matrix = torch.ones(num_vertices, vocab_size) / vocab_size

In [5]:
def dfs(target_seq, transition_matrix, emission_matrix, seq, prob, storage):
    """
    Brute-force enumeration of every vertex path that can emit ``target_seq``.

    Paths always start at vertex 0 and contain exactly ``len(target_seq)``
    vertices. Inputs are in probability space (not log space).

    Arguments:
        target_seq: sequence of token ids, one per path position
        transition_matrix: transition_matrix[u][v] is the probability of the edge u -> v
        emission_matrix: emission_matrix[v][t] is the probability that vertex v emits token t
        seq: list of vertices visited so far (pass [] to start the search)
        prob: probability accumulated along ``seq`` (ignored when ``seq`` is empty,
            because the start vertex's emission probability replaces it)
        storage: list that receives one ``(path, probability)`` tuple per complete path
    Returns:
        None; results are appended to ``storage``.
    """
    depth = len(seq)

    # A path with one vertex per target token is complete: record it and stop.
    if depth == len(target_seq):
        storage.append((seq, prob))
        return

    # Nothing visited yet: every path is forced to begin at vertex 0.
    if depth == 0:
        dfs(
            target_seq,
            transition_matrix,
            emission_matrix,
            [0],
            emission_matrix[0][target_seq[0]],
            storage,
        )
        return

    # Otherwise extend the path along every outgoing edge with non-zero probability.
    outgoing = transition_matrix[seq[-1]]
    for vertex in range(len(outgoing)):
        edge_prob = outgoing[vertex]
        if not edge_prob > 0:
            continue
        extended_prob = prob * edge_prob * emission_matrix[vertex][target_seq[len(seq)]]
        dfs(target_seq, transition_matrix, emission_matrix, seq + [vertex], extended_prob, storage)
                
def max_prob_path(storage):
    """
    Pick the most probable path out of ``storage``.

    Arguments:
        storage: iterable of ``(path, probability)`` tuples
    Returns:
        ``(path, probability)`` of the first entry whose probability is strictly
        larger than every earlier one; ``([], 0)`` when no entry has a
        probability above zero (including an empty ``storage``).
    """
    best_seq, best_prob = [], 0
    for candidate_seq, candidate_prob in storage:
        # Strict comparison: on ties the earliest path is kept.
        if not candidate_prob > best_prob:
            continue
        best_seq, best_prob = candidate_seq, candidate_prob
    return best_seq, best_prob

In [6]:
# Accumulator for the (path, probability) tuples produced by dfs
storage = []

In [7]:
# Enumerate every path of len(target_seq) vertices starting at vertex 0;
# the initial prob of 1 is replaced by the start vertex's emission probability.
dfs(target_seq, transition_matrix, emission_matrix, [], 1, storage)

In [8]:
# Brute-force probability of the target: sum over all enumerated paths.
# Kept for comparison with the dynamic-programming result further down.
total_prob = sum([prob for seq, prob in storage])

In [9]:
# Single most probable path among the enumerated ones
max_prob_path(storage)

([0, 1, 6, 7, 9], tensor(0.0003))

In [10]:
def dfs_lynchpin(target_seq, transition_matrix, emission_matrix, seq, prob, storage, assignments=None):
    """
    Brute-force path enumeration with optional per-position vertex constraints.

    Works like an exhaustive depth-first search over all paths that start at
    vertex 0 and contain ``len(target_seq)`` vertices, but a path is only kept
    if it visits the assigned vertex at every constrained position. Inputs are
    in probability space (not log space).

    Arguments:
        target_seq: sequence of token ids, one per path position
        transition_matrix: transition_matrix[u][v] is the probability of the edge u -> v
        emission_matrix: emission_matrix[v][t] is the probability that vertex v emits token t
        seq: list of vertices visited so far (pass [] to start the search)
        prob: probability accumulated along ``seq`` (ignored when ``seq`` is empty)
        storage: list that receives one ``(path, probability)`` tuple per complete path
        assignments: optional sequence with one entry per target position. A
            non-negative entry is the vertex the path must be at for that
            position; a negative entry (conventionally -1) leaves it free.
            ``None`` means no constraints at all.
    Returns:
        None; results are appended to ``storage``.
    """
    if len(seq) == len(target_seq):
        storage.append((seq, prob))
        return
    if len(seq) == 0:
        # Every path starts at vertex 0, so pinning position 0 to any other
        # vertex (a positive assignment) leaves no valid path at all. The DP
        # implementation masks row 0 the same way.
        if assignments is not None and assignments[0] > 0:
            return
        start_prob = emission_matrix[0][target_seq[0]]
        dfs_lynchpin(target_seq, transition_matrix, emission_matrix, [0], start_prob, storage, assignments)
    else:
        next_candidates = transition_matrix[seq[-1]]
        current_assignment = -1 if assignments is None else assignments[len(seq)]
        for i, p in enumerate(next_candidates):
            # Any negative assignment means "free", matching the `>= 0` test
            # used by the DP's span correction.
            if p > 0 and (current_assignment < 0 or current_assignment == i):
                new_prob = prob * p * emission_matrix[i][target_seq[len(seq)]]
                dfs_lynchpin(
                    target_seq,
                    transition_matrix,
                    emission_matrix,
                    seq + [i],
                    new_prob,
                    storage,
                    assignments)

In [11]:
# Start from an empty accumulator so the constrained search below does not
# mix its paths with the unconstrained ones collected earlier.
storage = []

In [12]:
# One entry per target position: -1 leaves the position free, otherwise the
# value is the vertex the path must visit there (position 1 -> vertex 4,
# position 3 -> vertex 7).
assignments = torch.tensor([-1, 4, -1, 7, -1])

In [13]:
# Enumerate only the paths that respect the vertex assignments above
dfs_lynchpin(target_seq, transition_matrix, emission_matrix, [], 1, storage, assignments)

In [14]:
# Show the constrained paths together with their probabilities
storage

[([0, 4, 5, 7, 9], tensor(9.6000e-06)), ([0, 4, 6, 7, 9], tensor(9.6000e-06))]

In [32]:
# Brute-force probability restricted to the constrained paths; `storage` must
# still hold the dfs_lynchpin results when this cell runs.
total_lynchpin_prob = sum([prob for seq, prob in storage])

In [15]:
import einops

In [16]:
def vector_gather(vectors, indices):
    """
    Gathers (batched) vectors according to indices.

    For every batch element n and every k, the result holds the row
    ``vectors[n, indices[n, k], :]``.

    Arguments:
        vectors: Tensor[N, L, D]
        indices: Tensor[N, K] or Tensor[N], integer positions along the L axis
            (torch.gather requires an int64 index)
    Returns:
        Tensor[N, K, D] or Tensor[N, D]
    """
    N, L, D = vectors.shape
    # A 1-D index selects a single vector per batch element; treat it as K == 1
    # and drop that axis again before returning.
    squeeze = indices.ndim == 1
    if squeeze:
        indices = indices.unsqueeze(-1)
    N2, K = indices.shape
    assert N == N2, f"batch size mismatch: vectors has {N}, indices has {N2}"
    # torch.gather needs an index of the same rank as its input. expand()
    # broadcasts the (N, K) index across D as a view, so nothing is copied and
    # no third-party helper is required.
    index = indices.unsqueeze(-1).expand(N, K, D)
    out = torch.gather(vectors, dim=1, index=index)
    if squeeze:
        out = out.squeeze(1)
    return out

In [17]:
def logsumexp(x: Tensor, dim: int) -> Tensor:
    """
    Numerically stable log-sum-exp along ``dim`` that tolerates all ``-inf`` slices.

    The reduced dimension is kept with size 1. A slice that consists only of
    ``-inf`` yields ``-inf`` instead of NaN.
    See https://github.com/pytorch/pytorch/issues/31829 and
    https://github.com/thu-coai/DA-Transformer/blob/main/fs_plugins/custom_ops/dag_loss.py
    """
    neg_inf = -float("inf")

    # The shift is treated as a constant (detached) so it carries no gradient.
    slice_max = x.max(dim=dim, keepdim=True).values.detach()
    all_neg_inf = slice_max == neg_inf

    # Subtracting -inf from -inf would give NaN, so those slices are shifted by 0.
    shift = slice_max.masked_fill(all_neg_inf, 0)
    summed = (x - shift).exp().sum(dim=dim, keepdim=True)

    # For all -inf slices: log(1) + (-inf) == -inf, with no NaN on the way.
    log_sum = summed.masked_fill(all_neg_inf, 1).log()
    return log_sum + shift.masked_fill(all_neg_inf, neg_inf)

In [18]:
def dag_loss_raw(targets, transition_matrix, emission_probs):
    """
    Forward dynamic program of the directed acyclic graph (DAG) loss.

    All scores are expected in log space. Entry ``dp[b, i, v]`` is the log of the
    total probability of all paths that start at vertex 0, emit the first
    ``i + 1`` target tokens and sit at vertex ``v`` for position ``i``.
    Unreachable entries are ``-inf``.

    Args:
        targets (torch.Tensor): The target sequence of shape (batch_size, m).
        transition_matrix (torch.Tensor): Log transition scores of shape (batch_size, l, l),
            indexed as [batch, source vertex, destination vertex].
        emission_probs (torch.Tensor): Log emission scores of shape (batch_size, l, vocab_size).

    Returns:
        torch.Tensor: The full DP table of shape (batch_size, m, l), created on
        ``transition_matrix.device``. The DAG loss is read from one of its entries.
    """
    batch_size, m = targets.shape
    _, l, _ = emission_probs.shape

    # Everything starts out unreachable.
    dp = torch.full((batch_size, m, l), -float("inf"), device=transition_matrix.device)

    # Position 0 is always vertex 0, scored by its emission of the first target token.
    first_token = targets[:, :1].unsqueeze(-1)
    dp[:, 0, 0] = torch.gather(emission_probs, dim=2, index=first_token)[:, 0, 0]

    # vector_gather selects along axis 1, so put the vocabulary there:
    # (batch_size, vocab_size, l).
    token_major_emissions = emission_probs.transpose(1, 2)

    for step in range(1, m):
        # (batch, l, 1) + (batch, l, l): score of being at each source vertex and
        # taking each edge; reducing over the source axis marginalises it out.
        incoming = dp[:, step - 1, :].unsqueeze(2) + transition_matrix
        reached = logsumexp(incoming, dim=1).squeeze(1)
        dp[:, step, :] = vector_gather(token_major_emissions, targets[:, step]) + reached
    return dp

In [19]:
# Move to log space, as dag_loss_raw expects; zero-probability edges become -inf.
# NOTE: this rebinds the name, so the cell is not idempotent (running it twice
# takes the log of log-probabilities), and the dfs cells above, which expect
# plain probabilities, must run before it.
transition_matrix = torch.log(transition_matrix)

In [20]:
# Same conversion to log space for the emission probabilities.
# NOTE: rebinding makes this cell non-idempotent; run it exactly once.
emission_matrix = torch.log(emission_matrix)

In [21]:
# Add a leading batch axis of size 1: the DP functions take
# (batch_size, l, l) transitions and (batch_size, l, vocab_size) emissions.
transition_matrix = transition_matrix.unsqueeze(0)
emission_matrix = emission_matrix.unsqueeze(0)

In [22]:
# Targets also need the batch axis: shape (1, m)
target_seq = target_seq.unsqueeze(0)

In [23]:
# Unconstrained DP table in log space, shape (batch_size, m, l)
dp = dag_loss_raw(target_seq, transition_matrix, emission_matrix)



In [24]:
# Back to probability space for readability (-inf entries show up as 0)
torch.exp(dp)

tensor([[[2.0000e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 3.6000e-02, 1.6000e-03, 0.0000e+00, 2.4000e-03,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 3.2000e-04, 0.0000e+00,
          2.4000e-04, 7.4400e-03, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 1.5360e-03, 6.4000e-05, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 3.2000e-04]]])

In [25]:
# Brute-force total from dfs, to compare with the DP value in the next cell
float(total_prob)

0.00032000005012378097

In [26]:
# DP result: probability of being at the last vertex at the last target
# position (batch element 0); expected to match the brute-force total above.
float(torch.exp(dp[0][-1][-1]))

0.0003199999628122896

In [27]:
def correct_dp_slice_for_spans(
    dp_slice: Tensor, target_span_indices: Tensor, iteration: int
) -> Tensor:
    """
    Restrict one row of the DP table to the vertex pinned for that target position.

    @param dp_slice: (batch_size, num_vertices, 1) tensor of log-scores for the
        target position ``iteration - 1``
    @param target_span_indices: (batch_size, seq_len) tensor of indices; a
        non-negative element is the vertex that target position must be aligned
        to, a negative element leaves the position unconstrained
    @param iteration: int, the target position currently being filled; the
        constraint that is applied is the one stored for ``iteration - 1``

    @return (batch_size, num_vertices) tensor (the trailing axis of ``dp_slice``
        is squeezed away). For batch rows with a non-negative index, every score
        except the one of the pinned vertex is replaced by -inf, so later steps
        only accumulate paths through that vertex. Rows with a negative index
        are returned unchanged.
    """
    scores = dp_slice.squeeze(2)
    neg_inf = -float("inf")

    pinned_vertex = target_span_indices[:, iteration - 1]
    is_pinned = pinned_vertex >= 0

    # gather/scatter do not accept negative indices, so unconstrained rows
    # temporarily point at vertex 0. Whatever they pick up is thrown away by
    # the final torch.where, which keeps the original scores for those rows.
    gather_index = pinned_vertex.masked_fill(~is_pinned, 0).unsqueeze(1)

    # A row of -inf with only the pinned vertex's score copied back in.
    pinned_only = torch.full_like(scores, neg_inf).scatter_(
        1, gather_index, scores.gather(1, gather_index)
    )

    # The (batch_size, 1) condition broadcasts across the vertex axis.
    return torch.where(is_pinned.unsqueeze(-1), pinned_only, scores)

In [28]:
def dag_loss_raw_lynchpin(targets, transition_matrix, emission_probs, assignments):
    """
    Calculates the directed acyclic graph (DAG) loss given the targets, transition matrix, and emission probabilities,
    restricted to paths that respect the given vertex assignments.
    It returns the dynamic programming table of which one of the entries is the DAG loss.

    Args:
        targets (torch.Tensor): The target sequence of shape (batch_size, m).
        transition_matrix (torch.Tensor): The log transition matrix of shape (batch_size, l, l).
        emission_probs (torch.Tensor): The log emission probabilities of shape (batch_size, l, vocab_size).
        assignments (torch.Tensor): The assignments of shape (batch_size, m).

    Returns:
        torch.Tensor: The DP table of shape (batch_size, m, l). In every row whose
        position is assigned, all entries except the assigned vertex are -inf.

    The assignments tensor is a tensor of shape (batch_size, m) where each element is an integer from 0 to l-1,
    or negative. If the element is negative (conventionally -1), then the corresponding element in the target
    sequence is free to be any of the l vertices. Otherwise, at that position in the target sequence, the vertex
    must be the one specified by the assignment. This holds for every position, including the last one.
    """
    batch_size, m = targets.shape
    _, l, _ = emission_probs.shape
    # Every entry starts out unreachable.
    dp = torch.full((batch_size, m, l), -float("inf"), device=transition_matrix.device)
    initial_probs = torch.gather(
        emission_probs, dim=2, index=targets[:, 0].unsqueeze(1).unsqueeze(2)
    )
    dp[:, 0, 0] = initial_probs.squeeze(2).squeeze(1)
    # assumes that transition_matrix and emission_probs are already in log space
    # also we need to transpose emission_probs so it is vocab_size x l
    # so the vector gather works
    emission_probs = emission_probs.transpose(1, 2)
    for i in range(1, m):
        # Before extending the paths, mask the previous row so that it agrees
        # with the assignment for position i - 1; the masked row is written back
        # so the returned table reflects the constraint as well.
        dp[:, i - 1, :] = correct_dp_slice_for_spans(
            dp[:, i - 1, :].unsqueeze(-1), assignments, i
        )
        dp[:, i, :] = vector_gather(emission_probs, targets[:, i]) + (
            logsumexp(
                dp[:, i - 1, :].unsqueeze(-1) + transition_matrix,
                dim=1,
            ).squeeze(1)
        )
    # The loop only masks rows 0..m-2 (each step corrects the row before it), so
    # the assignment for the final position would otherwise be ignored, and with
    # m == 1 nothing would be enforced at all. Passing m makes the helper read
    # assignments[:, m - 1].
    dp[:, m - 1, :] = correct_dp_slice_for_spans(
        dp[:, m - 1, :].unsqueeze(-1), assignments, m
    )
    return dp

In [29]:
# Add the batch axis: dag_loss_raw_lynchpin takes assignments of shape (batch_size, m)
assignments = assignments.unsqueeze(0)

In [30]:
# Constrained DP table in log space, shape (batch_size, m, l)
dp2 = dag_loss_raw_lynchpin(target_seq, transition_matrix, emission_matrix, assignments)

In [31]:
# Log-space table; rows for assigned positions keep only the assigned vertex
dp2

tensor([[[ -1.6094,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
              -inf,     -inf,     -inf],
         [    -inf,     -inf,     -inf,     -inf,  -6.0323,     -inf,     -inf,
              -inf,     -inf,     -inf],
         [    -inf,     -inf,     -inf,     -inf,     -inf,  -8.3349,  -8.3349,
              -inf,     -inf,     -inf],
         [    -inf,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
           -9.2512,     -inf,     -inf],
         [    -inf,     -inf,     -inf,     -inf,     -inf,     -inf,     -inf,
              -inf,     -inf, -10.8606]]])

In [33]:
# Brute-force total over the constrained paths, to compare with the DP value below
float(total_lynchpin_prob)

1.920000067912042e-05

In [34]:
# Constrained DP result at the last vertex / last target position (batch
# element 0); expected to match the brute-force total above.
float(torch.exp(dp2[0][-1][-1]))

1.920001523103565e-05