In [1]:
import torch

In [2]:
# Observed token ids; a path must emit target_seq[k] at its k-th vertex.
target_seq = torch.tensor((1, 3, 2, 4, 0))

In [3]:
# Graph size (vertices 0..9) and number of distinct tokens a vertex can emit.
num_vertices, vocab_size = 10, 5

In [4]:
# Row u holds the probabilities of stepping from vertex u to each vertex v.
# Vertex 9 has no outgoing edges (all-zero row), so dfs cannot extend a path past it.
transition_matrix = torch.tensor(
    [
        [0, 0.9, 0.04, 0, 0.06, 0, 0, 0, 0, 0],  # 0 -> 1 / 2 / 4
        [0, 0, 0, 0, 0, 0, 1, 0, 0, 0],          # 1 -> 6
        [0, 0, 0, 1, 0, 0, 0, 0, 0, 0],          # 2 -> 3
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0],          # 3 -> 8
        [0, 0, 0, 0, 0, 0.5, 0.5, 0, 0, 0],      # 4 -> 5 / 6
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],          # 5 -> 7
        [0, 0, 0, 0, 0, 0, 0, 1, 0, 0],          # 6 -> 7
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],          # 7 -> 9
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],          # 8 -> 9
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],          # 9: terminal
    ]
)
# Uniform emissions: every vertex emits each of the vocab_size tokens with equal probability.
emission_matrix = torch.ones(num_vertices, vocab_size) / vocab_size

# Sanity checks: shapes agree with the config cell, every transition row is either a
# probability distribution or all zeros (terminal vertex), and emission rows sum to 1.
assert transition_matrix.shape == (num_vertices, num_vertices)
assert emission_matrix.shape == (num_vertices, vocab_size)
transition_row_sums = transition_matrix.sum(dim=1)
assert (
    torch.isclose(transition_row_sums, torch.ones_like(transition_row_sums))
    | (transition_row_sums == 0)
).all(), "each transition row must sum to 1 (or 0 for a terminal vertex)"
assert torch.allclose(emission_matrix.sum(dim=1), torch.ones(num_vertices))

In [5]:
def dfs(target_seq, transition_matrix, emission_matrix, seq, prob, storage, start_vertex=0):
    """Enumerate every path that can emit ``target_seq`` and record its probability.

    Recursive depth-first search over the graph described by ``transition_matrix``.
    Every complete path has exactly ``len(target_seq)`` vertices and begins at
    ``start_vertex``; only transitions with a positive probability are followed.
    Every feasible path is enumerated, so the cost grows with the number of such paths.

    Args:
        target_seq: observed token ids; ``target_seq[k]`` must be emitted by the
            k-th vertex of the path.
        transition_matrix: ``transition_matrix[u][v]`` is the probability of
            stepping from vertex ``u`` to vertex ``v``.
        emission_matrix: ``emission_matrix[v][t]`` is the probability that vertex
            ``v`` emits token ``t``.
        seq: partial path built so far; pass ``[]`` to start a new search.
        prob: probability accumulated so far; pass ``1`` to start a new search.
            At the root it is multiplied into the start probability.
        storage: list that receives one ``(path, probability)`` tuple per complete path.
        start_vertex: vertex every path starts from (default ``0``).

    Returns:
        None. Results are appended to ``storage``.
    """
    if len(seq) == len(target_seq):
        storage.append((seq, prob))
        return
    if not seq:
        # Root call: the path starts at start_vertex, which must emit the first token.
        # The caller's prob is kept as a prior instead of being discarded.
        start_prob = prob * emission_matrix[start_vertex][target_seq[0]]
        dfs(target_seq, transition_matrix, emission_matrix, [start_vertex], start_prob, storage, start_vertex)
        return
    position = len(seq)
    for vertex, p in enumerate(transition_matrix[seq[-1]]):
        if p > 0:  # prune impossible transitions
            new_prob = prob * p * emission_matrix[vertex][target_seq[position]]
            dfs(target_seq, transition_matrix, emission_matrix, seq + [vertex], new_prob, storage, start_vertex)
                
def max_prob_path(storage):
    """Return the ``(path, probability)`` pair with the highest probability.

    Args:
        storage: iterable of ``(path, probability)`` pairs, e.g. as filled by ``dfs``.

    Returns:
        ``(path, probability)`` of the first entry with the largest probability
        (ties keep the earliest entry), or ``([], 0)`` when ``storage`` is empty.
        Paths whose probability is exactly 0 (e.g. after float underflow on long
        sequences) are still returned rather than being reported as "no path".
    """
    # max() only replaces its candidate on a strictly greater key, so ties keep the
    # earliest entry.
    best_seq, best_prob = max(storage, key=lambda entry: entry[1], default=([], 0))
    return best_seq, best_prob

In [6]:
storage = list()  # collects (path, probability) tuples from dfs

In [7]:
# Reset first so re-running this cell does not append every path a second time.
storage = []
dfs(target_seq, transition_matrix, emission_matrix, [], 1, storage)

In [8]:
# Most probable path among all enumerated paths, with its probability.
max_prob_path(storage=storage)

([0, 1, 6, 7, 9], tensor(0.0003))

In [9]:
def dfs_lynchpin(target_seq, transition_matrix, emission_matrix, seq, prob, storage, assignments=None, start_vertex=0):
    """Like ``dfs``, but only enumerate paths that agree with pinned vertices.

    Recursive depth-first search over the graph described by ``transition_matrix``.
    Every complete path has exactly ``len(target_seq)`` vertices and begins at
    ``start_vertex``; only transitions with a positive probability are followed.

    Args:
        target_seq: observed token ids; ``target_seq[k]`` must be emitted by the
            k-th vertex of the path.
        transition_matrix: ``transition_matrix[u][v]`` is the probability of
            stepping from vertex ``u`` to vertex ``v``.
        emission_matrix: ``emission_matrix[v][t]`` is the probability that vertex
            ``v`` emits token ``t``.
        seq: partial path built so far; pass ``[]`` to start a new search.
        prob: probability accumulated so far; pass ``1`` to start a new search.
            At the root it is multiplied into the start probability.
        storage: list that receives one ``(path, probability)`` tuple per complete path.
        assignments: optional per-position constraints indexed like ``target_seq``
            (list or integer tensor). ``-1`` leaves a position free; any other
            value ``v`` requires the vertex at that position to be ``v``. Position 0
            is checked too: a pin that conflicts with ``start_vertex`` yields no paths.
            ``None`` means no constraints.
        start_vertex: vertex every path starts from (default ``0``).

    Returns:
        None. Results are appended to ``storage``.
    """
    def _allowed(position, vertex):
        # True when the position is unpinned or pinned to exactly this vertex.
        if assignments is None:
            return True
        pinned = int(assignments[position])  # normalize 0-d tensors and ints alike
        return pinned == -1 or pinned == vertex

    if len(seq) == len(target_seq):
        storage.append((seq, prob))
        return
    if not seq:
        # Root call: the path must start at start_vertex, so a conflicting pin on
        # position 0 rules out every path.
        if not _allowed(0, start_vertex):
            return
        start_prob = prob * emission_matrix[start_vertex][target_seq[0]]
        dfs_lynchpin(
            target_seq,
            transition_matrix,
            emission_matrix,
            [start_vertex],
            start_prob,
            storage,
            assignments,
            start_vertex)
        return
    position = len(seq)
    for vertex, p in enumerate(transition_matrix[seq[-1]]):
        # Prune impossible transitions and vertices that violate the pin here.
        if p > 0 and _allowed(position, vertex):
            new_prob = prob * p * emission_matrix[vertex][target_seq[position]]
            dfs_lynchpin(
                target_seq,
                transition_matrix,
                emission_matrix,
                seq + [vertex],
                new_prob,
                storage,
                assignments,
                start_vertex)

In [10]:
storage = list()  # collects (path, probability) tuples from dfs_lynchpin

In [11]:
# -1 leaves a path position free; any other value pins that position to the given vertex.
assignments = torch.tensor([-1, 4, -1, 7, -1], dtype=torch.long)

In [12]:
# Reset first so re-running this cell does not append duplicate paths.
storage = []
dfs_lynchpin(target_seq, transition_matrix, emission_matrix, [], 1, storage, assignments)

In [13]:
# Every path that satisfies the pins in `assignments`, with its probability.
storage

[([0, 4, 5, 7, 9], tensor(9.6000e-06)), ([0, 4, 6, 7, 9], tensor(9.6000e-06))]