
Inference for the HMM model #70

@danoneata

Description


Hello! I was playing with the HMM distribution and got some results that I don't really understand. More precisely, I've set the following parameters:

import numpy as np
import torch

t = torch.tensor([[0.99, 0.01], [0.01, 0.99]]).log()  # transition log-probabilities
e = torch.tensor([[0.50, 0.50], [0.50, 0.50]]).log()  # uniform emission log-probabilities
i = torch.tensor(np.array([0.99, 0.01])).log()        # initial state log-probabilities
x = torch.randint(0, 2, size=(1, 8))                  # a batch of one random binary sequence of length 8

and I was expecting the model to stay in hidden state 0 regardless of the observed data x: it starts in state 0, the transition matrix makes it very likely to remain there, and the uniform emissions mean the observations carry no information about the hidden state. But when plotting the argmax, it appears that the model jumps back and forth between the two states:

import matplotlib.pyplot as plt
import torch_struct

def show_chain(chain):
    # Collapse the argmax edge indicators to a (states x positions) image.
    plt.imshow(chain.detach().sum(-1).transpose(0, 1))

dist = torch_struct.HMM(t, e, i, x)
show_chain(dist.argmax[0])

[plot: the decoded chain alternates between the two hidden states across positions]
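
For comparison, here is a minimal hand-rolled Viterbi decode, independent of torch_struct (I am assuming t[i, j] = log p(z_{n+1} = j | z_n = i), e[i, k] = log p(x_n = k | z_n = i), a single unbatched observation sequence, and I reuse the t, e, i, x defined above with the batch dimension of x dropped; torch_struct's indexing conventions may differ). With uniform emissions it decodes the all-zeros path for any x:

import torch

def viterbi(log_t, log_e, log_init, obs):
    # log_t[i, j] = log p(z_{n+1} = j | z_n = i), log_e[i, k] = log p(x_n = k | z_n = i),
    # log_init[i] = log p(z_1 = i), obs: (N,) tensor of observed symbols.
    score = log_init + log_e[:, obs[0]]  # best log-score of a length-1 path ending in each state
    backptrs = []
    for step in range(1, obs.shape[0]):
        # cand[a, b]: best path ending in a, then transitioning a -> b and emitting obs[step] from b
        cand = score.unsqueeze(1) + log_t + log_e[:, obs[step]].unsqueeze(0)
        backptrs.append(cand.argmax(dim=0))
        score = cand.max(dim=0).values
    # Backtrack from the best final state.
    path = [score.argmax().item()]
    for bp in reversed(backptrs):
        path.append(bp[path[-1]].item())
    return list(reversed(path))

print(viterbi(t, e, i, x[0]))  # expected: [0, 0, 0, 0, 0, 0, 0, 0]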

I must be missing something obvious, but shouldn't dist.argmax correspond to argmax_z p(z | x, Θ)? Thank you!
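
To make the expectation concrete: the joint score log p(z, x) = log i[z_1] + sum_n log t[z_n, z_{n+1}] + sum_n log e[z_n, x_n] penalizes every state change by roughly log(0.01) - log(0.99) ≈ -4.6 nats, while the uniform emissions contribute the same amount to every path. A quick check (the log_joint helper below is just an illustration, not part of torch_struct, and again reuses the tensors defined above):

def log_joint(path, log_t, log_e, log_init, obs):
    # Score a single state sequence `path` against the observations `obs`.
    score = log_init[path[0]] + log_e[path[0], obs[0]]
    for n in range(1, len(path)):
        score = score + log_t[path[n - 1], path[n]] + log_e[path[n], obs[n]]
    return score

stay = [0] * 8
alternate = [0, 1, 0, 1, 0, 1, 0, 1]
print(log_joint(stay, t, e, i, x[0]))       # ≈ -5.63
print(log_joint(alternate, t, e, i, x[0]))  # ≈ -37.8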
