Skip to content

Commit

Permalink
Merge 10df75a into d828506
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Jan 28, 2020
2 parents d828506 + 10df75a commit 94445bb
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 0 deletions.
2 changes: 2 additions & 0 deletions torch_struct/__init__.py
Expand Up @@ -14,6 +14,7 @@
from .cky_crf import CKY_CRF
from .deptree import DepTree
from .linearchain import LinearChain
from .factorial_hmm import FactorialHMM
from .semimarkov import SemiMarkov
from .alignment import Alignment
from .rl import SelfCritical
Expand Down Expand Up @@ -43,6 +44,7 @@
CKY_CRF,
DepTree,
LinearChain,
FactorialHMM,
SemiMarkov,
LogSemiring,
StdSemiring,
Expand Down
56 changes: 56 additions & 0 deletions torch_struct/factorial_hmm.py
@@ -0,0 +1,56 @@
import torch
from .helpers import _Struct, Chart
import math


class FactorialHMM(_Struct):
def _dp(self, scores, lengths=None, force_grad=False):
transition, emission = scores
semiring = self.semiring
transition.requires_grad_(True)
emission.requires_grad_(True)
batch, L, K, K2 = transition.shape
batch, N, K, K, K = emission.shape
assert L == 3
assert K == K2

transition = semiring.convert(transition)
emission = semiring.convert(emission)


ssize = semiring.size()

state_out = Chart((batch, N, L, K), transition, semiring)
state_in = Chart((batch, N, L, K), transition, semiring)
emit = Chart((batch, N, K, K, K), transition, semiring)

emit[0, :] = emission[:, :, 0]

def make_out(val, i):
state_out[i, 0] = semiring.sum(semiring.sum(val, 4), 2)
state_out[i, 1] = semiring.sum(semiring.sum(val, 4), 3)
state_out[i, 2] = semiring.sum(semiring.sum(val, 3), 2)

make_out(emit[0, :], 0)

for i in range(1, N):
# print(transition.shape, state_out[i-1, :].unsqueeze(-2).shape)
state_in = semiring.dot(state_out[i-1, :].unsqueeze(-2), transition)
# print(state_in[..., None, :, None].shape, emission[:, :, i].shape)
emit[i, :] = semiring.times(state_in[..., 0, :, None, None],
state_in[..., 1, None, :, None],
state_in[..., 2, None, None, :],
emission[:, :, i])
make_out(emit[i, :], i)

log_Z = semiring.sum(emit[N-1, :])
return log_Z, [scores], None

@staticmethod
def _rand():
batch = torch.randint(2, 5, (1,))
K = torch.randint(2, 5, (1,))
N = torch.randint(2, 5, (1,))
transition = torch.rand(batch, 3, K, K)
emission = torch.rand(batch, N, K, K, K)
return (transition, emission), (batch.item(), N.item())
11 changes: 11 additions & 0 deletions torch_struct/test_algorithms.py
Expand Up @@ -2,6 +2,7 @@
from .cky_crf import CKY_CRF
from .deptree import DepTree, deptree_nonproj, deptree_part
from .linearchain import LinearChain
from .factorial_hmm import FactorialHMM
from .semimarkov import SemiMarkov
from .alignment import Alignment
from .semirings import (
Expand Down Expand Up @@ -325,6 +326,15 @@ def test_params(data, seed):
c = vals.grad.detach()
assert torch.isclose(b, c).all()

def test_factorial_hmm():
model = FactorialHMM
semiring = StdSemiring
struct = model(semiring)
vals, (batch, N) = model._rand()
alpha = struct.sum(vals)
print(alpha)
assert False


@given(data())
@settings(max_examples=50, deadline=None)
Expand Down Expand Up @@ -416,6 +426,7 @@ def test_hmm():
LinearChain().sum(out)



@given(data())
def test_sparse_max(data):
model = data.draw(sampled_from([LinearChain]))
Expand Down

0 comments on commit 94445bb

Please sign in to comment.