Closed
Description
Hello! I was playing with the `HMM` distribution and got some results that I don't really understand. Specifically, I set the following parameters:
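```python
import torch

# Log-space parameters: sticky transitions, uninformative emissions, start in state 0.
t = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()  # transition, C x C
e = torch.tensor([[0.50, 0.50], [0.50, 0.50]]).log()  # emission, V x C
i = torch.tensor([0.99, 0.01]).log()                  # initial state, C (float32 like t, e)
x = torch.randint(0, 2, size=(1, 8))                  # observations, batch x N
```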
I was expecting the model to stay in hidden state 0 regardless of the observed data `x`: it starts in state 0, and the transition matrix makes it very likely to remain there. But when I plot the argmax, the model appears to jump from one state to the other:
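```python
import matplotlib.pyplot as plt
import torch_struct

def show_chain(chain):
    # Sum out the last state axis of the edge tensor and plot states x positions.
    plt.imshow(chain.detach().sum(-1).transpose(0, 1))

dist = torch_struct.HMM(t, e, i, x)
show_chain(dist.argmax[0])  # argmax for the first (only) batch element
```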
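Since `dist.argmax[0]` is presumably an (N-1) x C x C edge indicator rather than a list of states, it can help to decode it into an explicit state sequence instead of reading the plot. The helper below, `edges_to_states`, is hypothetical and not part of torch-struct. It assumes the layout `edges[n][z_{n+1}][z_n]` (next state on the second-to-last axis, current state on the last); that layout is an assumption to check against the torch-struct linear-chain documentation.

```python
def edges_to_states(edges):
    """Hypothetical helper: turn a one-hot edge indicator, assumed to be indexed
    as edges[n][next_state][current_state], into the states z_0, ..., z_{N-1}."""
    states = []
    for n, edge in enumerate(edges):
        # Find the single active (next, current) cell for this transition.
        nxt, cur = next(
            (a, b)
            for a, row in enumerate(edge)
            for b, val in enumerate(row)
            if val > 0.5
        )
        if n == 0:
            states.append(cur)  # z_0 only appears as the "current" state of edge 0
        states.append(nxt)
    return states
```

With the tensors from the issue, this would be called as `edges_to_states(dist.argmax[0].detach().tolist())`; under the assumed layout, a path that never leaves state 0 decodes to a list of eight zeros.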
I must be missing something obvious, but shouldn't `dist.argmax` correspond to argmax_z p(z | x, Θ)? Thank you!
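For a chain this short, the quantity in the question can be checked directly by enumeration. Because p(x | Θ) does not depend on z, argmax_z p(z | x, Θ) = argmax_z log p(z, x | Θ), so it is enough to score every one of the 2^8 hidden paths. This is a minimal pure-Python sketch, not torch-struct code; the function names are hypothetical, and it assumes `log_trans[prev][cur]` and `log_emit[state][symbol]`. With the symmetric matrices used here, that index convention does not change the answer.

```python
import itertools
import math

def log_joint(path, obs, log_init, log_trans, log_emit):
    """log p(z, x) for one hidden path z (assumed indexing: trans[prev][cur], emit[state][symbol])."""
    score = log_init[path[0]] + log_emit[path[0]][obs[0]]
    for prev, cur, o in zip(path, path[1:], obs[1:]):
        score += log_trans[prev][cur] + log_emit[cur][o]
    return score

def brute_force_map(obs, log_init, log_trans, log_emit):
    """argmax_z p(z | x): p(x) is constant in z, so maximize log p(z, x) over all paths."""
    num_states = len(log_init)
    return max(
        itertools.product(range(num_states), repeat=len(obs)),
        key=lambda path: log_joint(path, obs, log_init, log_trans, log_emit),
    )

log = math.log
log_init = [log(0.99), log(0.01)]
log_trans = [[log(0.99), log(0.01)], [log(0.01), log(0.99)]]
log_emit = [[log(0.5), log(0.5)], [log(0.5), log(0.5)]]
obs = [1, 0, 0, 1, 1, 0, 1, 0]  # any 8-symbol sequence gives the same answer here
print(brute_force_map(obs, log_init, log_trans, log_emit))  # prints (0, 0, 0, 0, 0, 0, 0, 0)
```

With uniform emissions every path gets the same emission score, so only the initial and transition terms matter, and any path that starts in state 1 or switches states pays at least one factor of 0.01.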