In [1]:
from CRF import CRF
from utils import crf_train_loop
import numpy as np
import torch

In [2]:
# two dice one is fair, one is loaded
fair_dice = np.array([1/6]*6)
loaded_dice = np.array([0.04,0.04,0.04,0.04,0.04,0.8])

probabilities = {'fair': fair_dice,
                'loaded': loaded_dice}

In [5]:
# if dice is fair at time t, 0.6 chance we stay fair, 0.4 chance it is loaded at time 2
transition_mat = {'fair': np.array([0.8, 0.2, 0.0]),
                 'loaded': np.array([0.65, 0.35, 0.0]),
                 'start': np.array([0.5, 0.5, 0.0])}
states = ['fair', 'loaded', 'start']
state2ix = {'fair': 0,
           'loaded': 1,
           'start': 2}

log_likelihood = np.hstack([np.log(fair_dice).reshape(-1,1), 
                            np.log(loaded_dice).reshape(-1,1)])

In [6]:
def simulate_data(n_timesteps):
    data = np.zeros(n_timesteps)
    prev_state = 'start'
    state_list = np.zeros(n_timesteps)
    for n in range(n_timesteps):
        next_state = np.random.choice(states, p=transition_mat[prev_state])
        state_list[n] = state2ix[next_state]
        next_data = np.random.choice([0,1,2,3,4,5], p=probabilities[next_state])
        data[n] = next_data
    return data, state_list

In [10]:
n_obs = 15
rolls = np.zeros((5000, n_obs)).astype(int)
targets = np.zeros((5000, n_obs)).astype(int)

for i in range(5000):
    data, dices = simulate_data(n_obs)
    rolls[i] = data.reshape(1, -1).astype(int)
    targets[i] = dices.reshape(1, -1).astype(int)


In [7]:
model = CRF(2, log_likelihood)

In [8]:
model = crf_train_loop(model, rolls, targets, 1, 0.01)

Epoch 0: Batch 0/100 loss is 20.5139
Epoch 0: Batch 1/100 loss is 10.8498
Epoch 0: Batch 2/100 loss is 11.8594
Epoch 0: Batch 3/100 loss is 11.9312
Epoch 0: Batch 4/100 loss is 11.0403
Epoch 0: Batch 5/100 loss is 12.0504
Epoch 0: Batch 6/100 loss is 12.0493
Epoch 0: Batch 7/100 loss is 11.3213
Epoch 0: Batch 8/100 loss is 11.4374
Epoch 0: Batch 9/100 loss is 11.4760
Epoch 0: Batch 10/100 loss is 10.3746
Epoch 0: Batch 11/100 loss is 10.7225
Epoch 0: Batch 12/100 loss is 10.2565
Epoch 0: Batch 13/100 loss is 10.4401
Epoch 0: Batch 14/100 loss is 10.3484
Epoch 0: Batch 15/100 loss is 10.8308
Epoch 0: Batch 16/100 loss is 11.3905
Epoch 0: Batch 17/100 loss is 10.5387
Epoch 0: Batch 18/100 loss is 9.6910
Epoch 0: Batch 19/100 loss is 10.5773
Epoch 0: Batch 20/100 loss is 9.7850
Epoch 0: Batch 21/100 loss is 9.7704
Epoch 0: Batch 22/100 loss is 10.0050
Epoch 0: Batch 23/100 loss is 9.9496
Epoch 0: Batch 24/100 loss is 10.0994
Epoch 0: Batch 25/100 loss is 10.9641
Epoch 0: Batch 26/100 loss

In [11]:
torch.save(model.state_dict(), "./checkpoint.hdf5")

In [8]:
model.load_state_dict(torch.load("./checkpoint.hdf5"))

In [11]:
preds = model.forward(rolls[1])

In [35]:
rolls[0]

array([2, 3, 4, 5, 5, 5, 1, 5, 3, 2, 5, 5, 5, 3, 5])

In [36]:
model.forward(rolls[0])[0]

array([0, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0])

In [37]:
targets[0]

array([0, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 1])

In [26]:
list(model.parameters())[0].data.numpy()

array([[-0.86563134, -0.40748784, -0.54984874],
       [-1.3820231 , -0.59524935, -0.516026  ]], dtype=float32)