Skip to content

Commit

Permalink
add hsmm helper
Browse files Browse the repository at this point in the history
  • Loading branch information
da03 committed Oct 11, 2021
1 parent e51fecc commit 08713b6
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
50 changes: 50 additions & 0 deletions tests/extensions.py
Expand Up @@ -165,6 +165,56 @@ def enumerate(semiring, edge):
ls = [s for (_, s) in chains[N]]
return semiring.unconvert(semiring.sum(torch.stack(ls, dim=1), dim=1)), ls

@staticmethod
def enumerate_hsmm(semiring, init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z):
    """Brute-force the HSMM partition by scoring every segmentation.

    Parameters mirror ``SemiMarkov.hsmm``: ``init_z_1`` is the log-prior
    over the auxiliary start state z_{-1}, the two transition tables give
    log P(z_t | z_{t-1}) and log P(l_t | z_t), and ``emission_n_l_z``
    (batch x N x K x C) gives the segment emission scores.

    Returns the unconverted semiring sum over all chains together with
    the list of individual chain scores.
    """
    ssize = semiring.size()
    batch, N, K, C = emission_n_l_z.shape

    # Broadcast unbatched parameter tables so everything carries a batch dim.
    if init_z_1.dim() == 1:
        init_z_1 = init_z_1.unsqueeze(0).expand(batch, C)  # batch, C
        transition_z_to_z = transition_z_to_z.unsqueeze(0).expand(batch, C, C)
        transition_z_to_l = transition_z_to_l.unsqueeze(0).expand(batch, C, K)

    # Lift all tables into the semiring (adds a leading ssize dim).
    init_z_1 = semiring.convert(init_z_1)  # ssize, batch, C
    transition_z_to_z = semiring.convert(transition_z_to_z)  # ssize, batch, C, C
    transition_z_to_l = semiring.convert(transition_z_to_l)  # ssize, batch, C, K
    emission_n_l_z = semiring.convert(emission_n_l_z)  # ssize, batch, N, K, C

    def score_for(chain):
        # Multiply the factors of one segmentation, left to right.
        first_state, _ = chain[0]
        total = semiring.fill(torch.zeros(ssize, batch), torch.tensor(True), semiring.one)
        # P(z_{-1})
        total = semiring.mul(total, init_z_1[:, :, first_state])
        prev, pos = first_state, 0
        for state, length in chain[1:]:
            # P(z_t | z_{t-1})
            total = semiring.mul(total, transition_z_to_z[:, :, prev, state])
            # P(l_t | z_t)
            total = semiring.mul(total, transition_z_to_l[:, :, state, length])
            # P(x_{n:n+l_t} | z_t, l_t)
            total = semiring.mul(total, emission_n_l_z[:, :, pos, length, state])
            prev, pos = state, pos + length
        return total

    # partial[n] holds every chain whose segment lengths sum to n.
    partial = {0: [[(c, 0)] for c in range(C)]}
    for n in range(1, N + 1):
        extended = []
        for k in range(1, K):
            prefixes = partial.get(n - k)
            if prefixes is None:
                continue
            extended.extend(prefix + [(c, k)] for prefix in prefixes for c in range(C))
        partial[n] = extended

    scores = [score_for(chain) for chain in partial[N]]
    return semiring.unconvert(semiring.sum(torch.stack(scores, dim=1), dim=1)), scores


### Tests

Expand Down
23 changes: 23 additions & 0 deletions tests/test_algorithms.py
Expand Up @@ -499,3 +499,26 @@ def ignore_alignment(data):
# assert torch.isclose(count, alpha).all()
struct = model(semiring, max_gap=1)
alpha = struct.sum(vals)


@pytest.mark.parametrize("model_test", ["SemiMarkov"])
@pytest.mark.parametrize("semiring", [LogSemiring, MaxSemiring])
def test_hsmm(model_test, semiring):
    """Check that three ways of computing the HSMM partition agree."""
    C, K, batch, N = 5, 3, 2, 5
    init_z_1 = torch.rand(batch, C)
    transition_z_to_z = torch.rand(C, C)
    transition_z_to_l = torch.rand(C, K)
    emission_n_l_z = torch.rand(batch, N, K, C)

    enum = algorithms[model_test][1]
    # (1) brute-force enumeration directly from the HSMM parameters
    partition1 = enum.enumerate_hsmm(
        semiring, init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z
    )[0]
    # (2) convert parameters to edge potentials, then enumerate
    edge = SemiMarkov.hsmm(init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z)
    partition2 = enum.enumerate(semiring, edge)[0]
    # (3) dynamic program over the same edge potentials
    partition3 = algorithms[model_test][0](semiring).logpartition(edge)[0]

    assert torch.isclose(partition1, partition2).all()
    assert torch.isclose(partition2, partition3).all()
41 changes: 41 additions & 0 deletions torch_struct/semimarkov.py
Expand Up @@ -173,3 +173,44 @@ def from_parts(edge):
labels[on[i][0], on[i][1] + on[i][2]] = on[i][3]
# print(edge.nonzero(), labels)
return labels, (C, K)

# Adapters
@staticmethod
def hsmm(init_z_1, transition_z_to_z, transition_z_to_l, emission_n_l_z):
    """
    Convert HSMM log-probability tables to semi-Markov edge scores.
    Parameters:
        init_z_1: C or b x C (init_z_1[i] = log P(z_{-1}=i), note that z_{-1} is an
                  auxiliary state whose purpose is to induce a distribution over z_0.)
        transition_z_to_z: C x C (transition_z_to_z[i][j] = log P(z_{n+1}=j | z_n=i),
                           note that the (from, to) order is transposed relative
                           to `edges` below.)
        transition_z_to_l: C x K (transition_z_to_l[i][j] = log P(l_n=j | z_n=i))
        emission_n_l_z: b x N x K x C (emission_n_l_z[b][n][k][c]
                        = log P(x_{n:n+k} | z_n=c, l_n=k))
    Returns:
        edges: b x N x K x C x C (N matching the emission table), where
            edges[b, n, k, c2, c1]
            = log P(z_n=c2 | z_{n-1}=c1) + log P(l_n=k | z_n=c2)
              + log P(x_{n:n+l_n} | z_n=c2, l_n=k), if n > 0
            = the same sum plus log P(z_{-1}=c1), if n = 0
    """
    batch, N, K, C = emission_n_l_z.shape
    # new_zeros matches both dtype and device of the emission scores.
    edges = emission_n_l_z.new_zeros(batch, N, K, C, C)

    # initial state: log P(z_{-1}=c1) enters only at position n=0,
    # broadcast over k and c2; c1 is the trailing (previous-state) dim.
    if init_z_1.dim() == 1:
        init_z_1 = init_z_1.unsqueeze(0).expand(batch, -1)
    edges[:, 0, :, :, :] += init_z_1.view(batch, 1, 1, C)

    # transitions: log P(z_n=c2 | z_{n-1}=c1); transpose so c2 lands on the
    # next-to-last dim and c1 on the last, as `edges` expects.
    edges += transition_z_to_z.transpose(-1, -2).view(1, 1, 1, C, C)

    # l given z: log P(l_n=k | z_n=c2)
    edges += transition_z_to_l.transpose(-1, -2).view(1, 1, K, C, 1)

    # emissions: log P(x_{n:n+l_n} | z_n=c2, l_n=k), independent of c1
    edges += emission_n_l_z.unsqueeze(-1)

    return edges

0 comments on commit 08713b6

Please sign in to comment.