In [1]:
import numpy as np
import scipy as sp
import torch
from torch import nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
from pyro.infer.autoguide import AutoDelta, AutoNormal
from pyro.nn import PyroSample, PyroParam
import matplotlib.pyplot as plt

In [2]:
def logits_to_probs(logits):
    return torch.exp(logits - torch.logsumexp(logits, dim=-1, keepdim=True))

class IOHMM_model:
    def __init__(self, num_states, inputs, outputs, max_iter, tol,
                 initial_pi=None, theta_transition=None, theta_emission=None, sd=None):
    
        self.num_states = num_states
        self.inputs = inputs
        self.outputs = outputs
        self.T = inputs.shape[0]
        self.max_iter = max_iter
        self.tol = tol
        self.history = []

        # Initialize parameters
        self.initial_pi = torch.ones(num_states) / num_states if initial_pi is None else initial_pi
        self.theta_transition = torch.randn(num_states, num_states, inputs.shape[1] + 1) if theta_transition is None else theta_transition
        self.theta_emission = torch.randn(num_states, inputs.shape[1] + 1) if theta_emission is None else theta_emission
        self.sd = torch.ones(num_states, self.T) if sd is None else sd
        #self.theta_transition = pyro.param('theta_transition', torch.randn(num_states, num_states, inputs.shape[1] + 1))
        #self.theta_emission = pyro.param('theta_emission', torch.randn(num_states, inputs.shape[1] + 1))
        #self.sd = pyro.param('sd', torch.ones(num_states, self.T), constraint=dist.constraints.positive)
        #self.initial_pi = pyro.param('initial_pi', torch.ones(num_states) / num_states, constraint=dist.constraints.simplex)


    def model(self, inputs, outputs):
        
        # Initial state distribution
        z = pyro.sample('z_0', dist.Categorical(probs=self.initial_pi))

        for t in range(self.T):
            # Emission model
            emission_mean = torch.matmul(self.theta_emission[z], torch.cat((torch.ones(1), inputs[t, :])).unsqueeze(-1)).squeeze()
            sd = self.sd[z,t].exp()
            pyro.sample(f'obs_{t}', dist.Normal(emission_mean, sd), obs=outputs[t])

            if t < self.T - 1:
                # Transition model
                logits = torch.matmul(self.theta_transition[z], torch.cat((torch.ones(1), inputs[t, :])).unsqueeze(-1)).squeeze()
                transition_probs = logits_to_probs(logits)
                z = pyro.sample(f'z_{t+1}', dist.Categorical(probs=transition_probs))

    def guide(self, inputs, outputs):
        # Variational parameters for the initial state distribution
        q_initial_pi = pyro.param('q_initial_pi', self.initial_pi, constraint=dist.constraints.simplex)
        q_z = pyro.sample('z_0', dist.Categorical(probs=q_initial_pi))

        q_theta_transition = pyro.param('q_theta_transition', self.theta_transition)
        q_theta_emission = pyro.param('q_theta_emission', self.theta_emission)
        q_sd = pyro.param('q_sd', self.sd, constraint=dist.constraints.positive)


        for t in range(self.T - 1):
            logits = torch.matmul(q_theta_transition[q_z], torch.cat((torch.ones(1), inputs[t, :])))
            transition_probs = logits_to_probs(logits)
            q_z = pyro.sample(f'z_{t+1}', dist.Categorical(probs=transition_probs))

    def fit(self):
        optimizer = Adam({"lr": 0.01})
        svi = SVI(self.model, self.guide, optimizer, loss=Trace_ELBO())

        for step in range(self.max_iter):
            loss = svi.step(self.inputs, self.outputs)
            
            self.initial_pi = pyro.param('q_initial_pi')
            self.theta_emission = pyro.param('q_theta_emission')
            self.theta_transition = pyro.param('q_theta_transition')
            self.sd = pyro.param('q_sd')

            self.history.append(loss)
            if step % 10 == 0:
                print(f"Step {step}: loss = {loss}")
            if step > 0 and abs(self.history[-1] - self.history[-2]) < self.tol:
                break

In [3]:
x=np.arange(np.pi/2 , 11*np.pi/2, 0.1)
x=np.sin(x)
x_shift = x[:-1]
x = x[1:]
diff = x - x_shift

hidden_states = [0 if diff[i] > 0 else 1 for i in range(len(diff))]
y = [x[i] + np.random.normal(0.5,0.1) if hidden_states[i]==0 else x[i] - np.random.normal(0.5,0.1) for i in range(len(x))]

input = torch.tensor([x, diff], dtype=torch.float32).T
output = torch.tensor(y, dtype=torch.float32) 

  input = torch.tensor([x, diff], dtype=torch.float32).T


In [4]:
num_states = 2
max_iter = 1000
tol = 1e-4

iohmm = IOHMM_model(num_states, input, output, max_iter, tol)
print("theta_transition")
print(iohmm.theta_transition)
print("theta_emission")
print(iohmm.theta_emission)
print("initial_pi")
print(iohmm.initial_pi)
#print(iohmm.sd)

theta_transition
tensor([[[ 0.0288, -0.5775, -0.0028],
         [ 1.1988,  1.3219, -0.0804]],

        [[-0.7637, -0.2766, -0.0684],
         [ 0.1273, -0.3001, -1.1393]]])
theta_emission
tensor([[ 1.3791,  0.2689,  1.8250],
        [-0.9837, -0.6596, -0.3291]])
initial_pi
tensor([0.5000, 0.5000])


In [5]:
print(iohmm.theta_emission.shape)
print(iohmm.inputs.shape)

torch.Size([2, 3])
torch.Size([157, 2])


In [6]:
iohmm.fit()

Step 0: loss = 331.61166620254517
Step 10: loss = 322.52592515945435
Step 20: loss = 311.47847509384155
Step 30: loss = 298.66763496398926
Step 40: loss = 288.28627943992615
Step 50: loss = 277.9204510450363
Step 60: loss = 268.7274281978607
Step 70: loss = 259.3067419528961
Step 80: loss = 250.61321687698364
Step 90: loss = 243.62178218364716
Step 100: loss = 235.9726848602295
Step 110: loss = 231.3232283592224
Step 120: loss = 221.70822954177856
Step 130: loss = 217.27923381328583
Step 140: loss = 215.9348382949829
Step 150: loss = 209.93115484714508
Step 160: loss = 206.23529875278473
Step 170: loss = 203.8552144765854
Step 180: loss = 200.47974574565887
Step 190: loss = 198.9700504541397
Step 200: loss = 195.9186553955078
Step 210: loss = 194.19640910625458
Step 220: loss = 192.92591607570648
Step 230: loss = 191.32984459400177
Step 240: loss = 188.78779554367065
Step 250: loss = 186.92579889297485
Step 260: loss = 185.91550481319427
Step 270: loss = 183.95980834960938
Step 280: lo

In [7]:
# inspecting the learned parameters
print("theta_transition")
print(iohmm.theta_transition)
print("theta_emission")
print(iohmm.theta_emission)
print("initial_pi")
print(iohmm.initial_pi)
#print(iohmm.sd)

theta_transition
tensor([[[-0.3876, -0.2650,  0.3232],
         [ 1.6152,  1.0093, -0.4064]],

        [[-0.5064, -0.7262, -0.0456],
         [-0.1300,  0.1495, -1.1621]]], requires_grad=True)
theta_emission
tensor([[ 4.2261e-03,  1.0215e+00,  6.3538e+00],
        [-9.4851e-04,  9.9401e-01,  6.2670e+00]], requires_grad=True)
initial_pi
tensor([0.6557, 0.3443], grad_fn=<DivBackward0>)


In [8]:
q_initial_pi_value = pyro.param('q_initial_pi').detach().cpu().numpy()
print(f"Final q_initial_pi: {q_initial_pi_value}")


Final q_initial_pi: [0.6556922 0.3443078]
