In [None]:
import sys
import importlib
sys.path.insert(1, '../')

import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt


from utils.utils import generate_linear_data


In [None]:
# Pyro stuff

import pyro
import pyro.distributions as dist
from pyro import poutine 
from pyro.infer.autoguide import AutoDelta 
from pyro.infer import SVI, TraceEnum_ELBO, JitTraceEnum_ELBO, infer_discrete, config_enumerate
from pyro.ops.indexing import Vindex
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings


# Data generation
In the cells below we generate the data. 
It essentially consists of a ball moving up and down in a one-dimensional image. 

In [None]:

# Build the synthetic data set with the project helper `generate_linear_data`
# (argument semantics live in utils.utils and are not visible in this notebook).
# NOTE(review): `pixels` is presumably laid out as [pixel, time] -- the cells below
# slice and split it along its last axis; confirm against the helper.
pixels = generate_linear_data(num_balls=1, num_pixels=15, time =60*40, sf=10, max_period =5, noise_level=0.0)
print(pixels)
# Show only the first 100 columns so the motion is visible.
plt.imshow(pixels[:,:100], aspect='auto')
plt.show()

Before we proceed we need to do some slight preprocessing. 
Essentially I will take the data that I have above and I will divide it into smaller sequences. 

In [None]:
num_sequences = 40
# Length of every sub-sequence; a remainder at the end of the recording is dropped.
elem_per_sequence = pixels.shape[-1] // num_sequences
if elem_per_sequence == 0:
    # The original `while` loop never terminated in this situation.
    raise ValueError(
        "pixels has {} time steps, fewer than num_sequences={}".format(
            pixels.shape[-1], num_sequences
        )
    )
sequences = []

# Take exactly `num_sequences` contiguous, non-overlapping windows along the last axis.
# (The previous strict `<` bound silently discarded the last full window whenever
# the number of time steps was an exact multiple of `num_sequences`.)
for i in range(num_sequences):
    start = i * elem_per_sequence
    end = start + elem_per_sequence
    sequences.append(pixels[:, start:end])

# Stack the windows and reorder the axes so each sequence is [len, dim]
sequences = torch.tensor(np.array(sequences))
sequences = sequences.permute(0, 2, 1).float()
sequences.dtype



In [None]:
# Sanity check of the tensor layout; the models below unpack it as
# (num_sequences, length, data_dim).
sequences.shape

# Model definition
We will define two models. First we will define a simple HMM with neural emission probabilities. Then we will define an autoregressive HMM with neural emission probabilities. 


The first thing to do is to define the general parameters that our model will use, like the number of states of our markov chain. 

In [None]:
# Definition of parameters

# Number of hidden states passed to the models as `num_states`.
states = 2
# Subsample size handed to the 'sequences' plate. model_1 and model_2 also capture
# this value as a default argument at definition time.
batch_size = 10
# Size of the second axis of `pixels`; the training loop divides the SVI loss by it.
num_observations = pixels.shape[1]


In [None]:
# Simple HMM model

def model_0(sequences, num_states=5, batch_size=None):
    """Plain HMM with Bernoulli emissions, looping over sequences one at a time.

    Args:
        sequences: tensor unpacked as (num_sequences, length, data_dim).
        num_states: number of hidden states of the Markov chain.
        batch_size: optional subsample size for the sequential 'sequences' plate.

    Sample sites: global 'probs_x' / 'probs_y', and per sequence ``i`` and time
    step ``t`` the enumerated state 'x_i_t' and the observed pixels 'y_i_t'.
    """
    with ignore_jit_warnings():
        num_sequences, length, data_dim = (int(size) for size in sequences.shape)

    # Prior on the transition probabilities (mass concentrated on the diagonal)
    transition_prior = dist.Dirichlet(0.9 * torch.eye(num_states) + 0.1).to_event(1)
    probs_x = pyro.sample("probs_x", transition_prior)

    # Prior on the per-state, per-pixel emission probabilities
    emission_prior = dist.Beta(0.1, 0.9).expand([num_states, data_dim]).to_event(2)
    probs_y = pyro.sample("probs_y", emission_prior)

    pixels_plate = pyro.plate('pixels', data_dim, dim=-1)
    for seq_idx in pyro.plate('sequences', len(sequences), batch_size):
        observed = sequences[seq_idx]
        state = 0  # every chain starts from state 0

        for step in pyro.markov(range(length)):
            state = pyro.sample(
                f'x_{seq_idx}_{step}',
                dist.Categorical(probs_x[state]),
                infer={'enumerate': 'parallel'},
            )

            with pixels_plate:
                pyro.sample(
                    f'y_{seq_idx}_{step}',
                    dist.Bernoulli(probs_y[state.squeeze(-1)]),
                    obs=observed[step],
                )


In [None]:
# Vectorized model whose emissions also condition on the previous observation
def model_2(sequences, num_states=2, batch_size=batch_size):
    """HMM vectorized over sequences with emissions indexed by (state, previous y, pixel).

    Args:
        sequences: tensor unpacked as (num_sequences, length, data_dim).
        num_states: number of hidden states.
        batch_size: subsample size of the 'sequences' plate (defaults to the
            notebook-level `batch_size` captured when this cell is run).

    Sample sites: 'probs_x', 'probs_y', and for every time step ``t`` the
    enumerated state 'x_t' and the observed pixels 'y_t'.
    """
    num_sequences, length, data_dim = (int(size) for size in sequences.shape)

    # Prior on the transition probabilities
    transition_prior = dist.Dirichlet(0.9 * torch.eye(num_states) + 0.1).to_event(1)
    probs_x = pyro.sample("probs_x", transition_prior)

    # One emission probability per (state, previous pixel value, pixel)
    emission_prior = dist.Beta(0.1, 0.9).expand([num_states, 2, data_dim]).to_event(3)
    probs_y = pyro.sample("probs_y", emission_prior)

    pixels_plate = pyro.plate('pixels', data_dim, dim=-1)
    with pyro.plate('sequences', num_sequences, batch_size, dim=-2) as batch:
        state, prev_obs = 0, 0
        for step in pyro.markov(range(length)):
            state = pyro.sample(
                f'x_{step}',
                dist.Categorical(probs_x[state]),
                infer={'enumerate': 'parallel'},
            ).long()

            with pixels_plate as tones:
                prev_obs = pyro.sample(
                    f'y_{step}',
                    dist.Bernoulli(probs_y[state, prev_obs, tones]),
                    obs=sequences[batch, step],
                ).long()

In [None]:
# Baseline HMM, vectorized over sequences.
# I am using it as a threshold with which to compare our model:
# if our model is only as good as this baseline then there is no point in having something more complicated.
# NOTE: as written the emissions depend only on the hidden state (probs_y is
# [num_states, data_dim]); the variant that also conditions on the previous
# observation is model_2.
def model_1(sequences, num_states=2, batch_size=batch_size):
    """Vectorized HMM with Bernoulli emissions and a random initial state.

    Args:
        sequences: tensor unpacked as (num_sequences, length, data_dim).
        num_states: number of hidden states.
        batch_size: subsample size of the 'sequences' plate (defaults to the
            notebook-level `batch_size` captured when this cell is run).

    Sample sites: 'probs_x', 'probs_y', and for every time step ``t`` the
    enumerated state 'x_t' and the observed pixels 'y_t'.
    """
    with ignore_jit_warnings():
        num_sequences, length, data_dim = map(int, sequences.shape)

    # Prior on the transition probabilities
    probs_x = pyro.sample(
        "probs_x",
        dist.Dirichlet(0.9 * torch.eye(num_states) + 0.1).to_event(1)
    )

    # Prior on the per-state, per-pixel emission probabilities
    probs_y = pyro.sample(
        "probs_y",
        dist.Beta(0.1, 0.9).expand([num_states, data_dim]).to_event(2)
    )

    pixels_plate = pyro.plate('pixels', data_dim, dim=-1)
    with pyro.plate('sequences', num_sequences, batch_size, dim=-2) as batch:
        # Random initial state. `np.random.rand` has to be called: the original
        # multiplied the function object itself, which raises a TypeError.
        x = int(np.random.rand() * num_states)
        for t in pyro.markov(range(length)):
            # Sites are named by time step only. The original formatted in `i`,
            # which is not defined in this function and only resolved to a
            # global left over from an earlier cell.
            x = pyro.sample(
                'x_{}'.format(t),
                dist.Categorical(probs_x[x]),
                infer={'enumerate': 'parallel'}
            )

            with pixels_plate:
                pyro.sample(
                    'y_{}'.format(t),
                    dist.Bernoulli(probs_y[x.squeeze(-1)]),
                    obs=sequences[batch, t]
                )

In [None]:
# Here we put the code for a neural arhmm
class PixelGenerator(nn.Module):
    """Network producing Bernoulli logits for the pixels from (state, previous pixels).

    Args:
        args: dictionary with the integer entries 'hidden_dim' (number of
            discrete states), 'nn_dim' (hidden layer width) and 'nn_channels'
            (number of convolution channels).
        data_dim: number of pixels per observation.
    """

    def __init__(self, args, data_dim):
        # nn.Module must be initialised before attributes are assigned on it.
        super().__init__()
        self.args = args
        self.data_dim = data_dim
        self.x_to_hidden = nn.Linear(args['hidden_dim'], args['nn_dim'])
        self.y_to_hidden = nn.Linear(args['nn_channels'] * data_dim, args['nn_dim'])
        self.conv = nn.Conv1d(1, args['nn_channels'], 3, padding=1)
        self.hidden_to_logits = nn.Linear(args['nn_dim'], data_dim)
        self.relu = nn.ReLU()

    def forward(self, x, y):
        """Return logits of shape broadcast(x batch dims, y batch dims) + (data_dim,).

        Args:
            x: integer tensor of state indices whose last dimension has size 1.
            y: float tensor of previous pixels whose last dimension is data_dim.
        """
        # One-hot encode the state along the last dimension.
        x_onehot = y.new_zeros(x.shape[:-1] + (self.args['hidden_dim'],)).scatter_(-1, x, 1)
        # Convolve over the pixel axis, then flatten channels back into the last dimension.
        y_conv = self.relu(self.conv(y.reshape(-1, 1, self.data_dim))).reshape(y.shape[:-1] + (-1,))
        # The x and y branches are combined by broadcasting their batch dimensions.
        h = self.relu(self.x_to_hidden(x_onehot) + self.y_to_hidden(y_conv))
        return self.hidden_to_logits(h)


# Lazily created, notebook-global emission network shared across calls of model_5.
pixel_generator = None
def model_5(sequences, num_states, args, batch_size=None):
    """HMM whose emission logits come from the PixelGenerator network.

    Args:
        sequences: tensor unpacked as (num_sequences, max_length, data_dim).
        num_states: accepted for signature compatibility with the other models;
            the number of states actually used is ``args['hidden_dim']``.
        args: dictionary of network settings ('hidden_dim', 'nn_dim', 'nn_channels').
        batch_size: optional subsample size of the 'sequences' plate.

    Sample sites: 'probs_x', and for every time step ``t`` the enumerated
    state 'x_t' and the observed pixels 'y_t'.
    """
    num_sequences, max_length, data_dim = map(int, sequences.shape)
    global pixel_generator
    if pixel_generator is None:
        pixel_generator = PixelGenerator(args, data_dim)
    # Register the network with Pyro on every call. Without this its weights
    # never reach the param store, so SVI would leave them at their random
    # initial values.
    pyro.module('pixel_generator', pixel_generator)

    probs_x = pyro.sample('probs_x', dist.Dirichlet(0.9 * torch.eye(args['hidden_dim']) + 0.1).to_event(1))

    with pyro.plate('sequences', num_sequences, batch_size, dim=-2) as batch:
        x = 0
        # No previous observation exists at t=0, so start from an all-zero image.
        y = torch.zeros(data_dim)
        for t in pyro.markov(range(max_length)):
            x = pyro.sample(
                "x_{}".format(t), dist.Categorical(probs_x[x]),
                infer={'enumerate':'parallel'})

            with pyro.plate('tones_{}'.format(t), data_dim, dim=-1):
                # The observed pixels become the "previous y" of the next step.
                y = pyro.sample(
                        'y_{}'.format(t),
                        dist.Bernoulli(logits=pixel_generator(x,y)),
                        obs=sequences[batch,t])


In [None]:
#Arguments for our neural networks

args = {
    # number of discrete hidden states (size of the one-hot input and of probs_x in model_5)
    'hidden_dim':2,
    # width of the hidden layer of PixelGenerator
    'nn_dim':48,
    # number of output channels of the 1-D convolution over pixels
    'nn_channels':1
}

In [None]:
# Model to train. The calls in this cell and in the training loop pass `args`,
# which among the models defined above only model_5 accepts.
model = model_5
# AutoDelta guide over the sites whose name starts with 'probs_'; every other
# site of the model is hidden from the guide by poutine.block.
guide = AutoDelta(poutine.block(
    model, 
    expose_fn=lambda msg: msg['name'].startswith('probs_')
))

#this depends on whether we are using model zero or not
first_available_dim = -3
# Run the guide once, then replay the enumerated model against that trace so
# the shapes of all sites can be printed.
guide_trace = poutine.trace(guide).get_trace(
    sequences, states, batch_size = batch_size, args= args
)
model_trace = poutine.trace(
    poutine.replay(poutine.enum(model, first_available_dim), guide_trace)).get_trace(
        sequences, batch_size=batch_size, num_states=states, args=args
    )
print(model_trace.format_shapes())

In [None]:
# Training args
lr = 0.05                # learning rate passed to the Adam optimizer
num_steps = 500          # number of SVI steps
max_plate_nesting = 2    # passed to TraceEnum_ELBO
batch_size = 10          # re-assigns the value already set in the parameter cell
report_freq = 2          # print the loss every `report_freq` steps


In [None]:
# Start from an empty parameter store, then build the optimizer and the
# enumeration-aware ELBO.
pyro.clear_param_store()
optim = Adam({'lr': lr})
elbo = TraceEnum_ELBO(
    max_plate_nesting = max_plate_nesting)


svi = SVI(model, guide, optim, elbo)

# Per-step loss, divided by `num_observations`.
loss = []


for step in range(num_steps):
    loss.append(svi.step(sequences, num_states=states, batch_size=batch_size, args=args)/num_observations)
    if step % report_freq == 0:
        print("step: svi - step {}, loss {}".format(step,loss[-1]))



In [None]:
# Only used by the commented-out title below.
model_title = 'arhmm '
# Plot the negated normalised loss per step (higher is better).
plt.plot(-np.array(loss))
#plt.title(model_title + str(states))
#plt.savefig("elbo_"+model+"_"+str(states)+".png",dpi=300)


In [None]:
# Helpers that sample from the learned parameters, to check that the values make sense
def ppc_vanilla_hmm(probs_x, probs_y, length_sample):
    """Sample a state path and observations from a plain HMM.

    Args:
        probs_x: transition matrix, indexed by the current state.
        probs_y: Bernoulli emission probabilities, indexed by the state.
        length_sample: number of time steps to simulate.

    Returns:
        (x, y): ``x`` holds the initial state 0 followed by the sampled states
        (``length_sample + 1`` entries); ``y`` holds one sampled observation
        per time step.
    """
    state_path = [0]
    observations = []
    for _ in range(length_sample):
        current = dist.Categorical(probs_x[state_path[-1]]).sample()
        state_path.append(current)
        observations.append(dist.Bernoulli(probs_y[current]).sample())
    return state_path, observations

def ppc_arhmm(probs_x, probs_y, length_sample, dim_data):
    """Sample a state path and observations from the autoregressive HMM.

    Args:
        probs_x: transition matrix of shape [num_states, num_states].
        probs_y: Bernoulli emission probabilities indexed as
            [state, previous pixel value, pixel].
        length_sample: number of time steps to simulate.
        dim_data: number of pixels per observation.

    Returns:
        (x, y): ``x`` holds a uniformly drawn initial state followed by the
        sampled states; ``y`` holds an all-zero initial observation followed
        by the sampled observations. Both have ``length_sample + 1`` entries.
    """
    # Take the number of states from the transition matrix rather than from a
    # notebook-level global, so the helper works for any learned model size.
    num_states = probs_x.shape[0]
    x = [int(np.random.rand() * num_states)]
    y = [torch.zeros(dim_data).long()]
    pixel_idx = torch.arange(dim_data)
    for _ in range(length_sample):
        x.append(dist.Categorical(probs_x[x[-1]]).sample())
        # Each pixel's probability depends on the new state and that pixel's previous value.
        y.append(dist.Bernoulli(probs_y[x[-1], y[-1], pixel_idx]).sample().long())

    return x, y


In [None]:
# Read the point estimates stored by the AutoDelta guide.
# NOTE(review): 'AutoDelta.probs_y' only exists when the trained model samples a
# 'probs_y' site (model_0/1/2). model_5 samples only 'probs_x', and ppc_arhmm
# indexes probs_y as [state, previous y, pixel], which matches model_2 -- confirm
# which model this cell is meant to follow.
values = pyro.get_param_store()
probs_x = values['AutoDelta.probs_x']
probs_y = values['AutoDelta.probs_y']
# Simulate 1000 steps of 15 pixels (15 matches num_pixels in the data-generation cell).
x, y = ppc_arhmm(probs_x, probs_y, 1000, 15)
y_np = np.array([a.numpy() for a in y])
# x[0] is a plain Python int (the initial state), so it is skipped here.
x_np = np.array([a.numpy() for a in x[1:]])
plt.imshow(y_np[:100].T, aspect='auto')
plt.show()
#plt.savefig('simulated_data_arhmm_10_states.png', dpi=300)



In [None]:
# Show the sampled hidden-state path as a single-row image.
plt.imshow(np.expand_dims(x_np,-1).T, aspect='auto')
plt.show()


In [None]:
@infer_discrete(first_available_dim=-3, temperature=0)
@config_enumerate
def viterbi_decoder_arhmm(sequence, probs_x, probs_y, num_states=2):
    """Decode the hidden states of one sequence under the autoregressive HMM.

    The function is wrapped with ``config_enumerate`` and ``infer_discrete``
    (temperature 0), so the returned states are those chosen by Pyro's
    discrete inference given the observed pixels.

    Args:
        sequence: tensor unpacked as (length, data_dim).
        probs_x: transition matrix, indexed by the previous state.
        probs_y: emission probabilities indexed as [state, previous y, pixel].
        num_states: unused in the body; kept for the caller-facing signature.

    Returns:
        List starting with the initial state 0 followed by one state per time step.
    """
    length, data_dim = (int(size) for size in sequence.shape)
    pixels_plate = pyro.plate('pixels', data_dim, dim=-1)
    decoded = [0]
    state, prev_obs = 0, 0
    for step in pyro.markov(range(length)):
        state = pyro.sample(
            f'x_{step}',
            dist.Categorical(probs_x[state]),
        ).long()
        decoded.append(state)

        with pixels_plate as tones:
            prev_obs = pyro.sample(
                f'y_{step}',
                dist.Bernoulli(probs_y[state, prev_obs, tones]),
                obs=sequence[step],
            ).long()

    return decoded


In [None]:
# Decode the first sequence and keep entries 1..29 of the returned list
# (entry 0 is the fixed initial state, a plain int).
infered_states_0 = [a.numpy() for a in viterbi_decoder_arhmm(sequences[0], probs_x, probs_y)[1:30]]
plt.imshow(np.expand_dims(infered_states_0,-1).T, aspect='auto')
plt.show()


In [None]:
# Observed pixels of the first sequence (time steps 1..99), for visual comparison.
plt.imshow(sequences[0][1:100].T, aspect='auto')

In [None]:
# Scratch cell: draws a uniform random float in [0, 14).
np.random.rand()*14