## Pyro, Bayesian analysis, and an attempt at BCRF (Bayesian conditional random fields, Qi et al. 2005)

In [None]:
import pyro
from pyro.distributions import Normal, Uniform
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO, EmpiricalMarginal, TraceEnum_ELBO, JitTraceEnum_ELBO
from pyro.infer.mcmc import MCMC, NUTS
from pyro.optim import Adam
from pyro.util import ignore_jit_warnings
from pyro.contrib.autoguide import AutoDelta

import torch
import torch.nn as nn

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
plt.style.use('ggplot')
import seaborn as sns

In [None]:
# DATA_URL = "https://d2fefpcigoriu7.cloudfront.net/datasets/rugged_data.csv"
# data = pd.read_csv(DATA_URL, encoding="ISO-8859-1")
# df = data[["cont_africa", "rugged", "rgdppc_2000"]]
# df = df[np.isfinite(df.rgdppc_2000)]
# df["rgdppc_2000"] = np.log(df["rgdppc_2000"])

In [None]:
# data = torch.tensor(df.values, dtype=torch.float)
# x_data, y_data = data[:, :-1], data[:, -1]

In [None]:
# def model(x_data, y_data):
#     n = len(x_data)

#     # w, b, sigma parameter is outside of plate, independent of N
#     weight = pyro.sample("w", dist.Normal(torch.zeros(1, 2), torch.ones(1, 2)))
#     bias = pyro.sample("b", dist.Normal(torch.tensor([[0.]]), torch.tensor([[100.]])))
#     sigma = pyro.sample("epsilon", Uniform(0., 10.))

#     with pyro.plate("map", n):
#         mu = (x_data[:, 0] * weight[0][0] + x_data[:, 1] * weight[0][1] + bias).squeeze(1)
#         yhat = pyro.sample("yhat", Normal(mu, sigma), obs=y_data)
#         return yhat

## Stochastic variational inference for linear regression with a mean-field (diagonal) Gaussian guide

In [None]:
# from pyro.contrib.autoguide import AutoDiagonalNormal
# mean_field_guide = AutoDiagonalNormal(model)

# # inject callables into SVI instantiation
# svi = SVI(model, mean_field_guide, Adam({"lr": 0.03}), loss=Trace_ELBO(), num_samples=1000)

In [None]:
# pyro.clear_param_store()
# for j in range(2000):
#     loss = svi.step(x_data, y_data)
#     if j % 500 == 0:
#         print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(data)))

In [None]:
# for name, value in pyro.get_param_store().items():
#     print(name, pyro.param(name))

## No-U-turn Sampler

In [None]:
# pyro.clear_param_store()

# nuts = NUTS(model)
# sampler = MCMC(nuts,
#                num_samples=2000,
#                num_chains=1,
#                # burn-in
#                warmup_steps=100)
# traces = sampler.run(x_data, y_data)

In [None]:
# posteriors = traces.marginal(["w", "b", "epsilon"])

In [None]:
# a = posteriors.empirical["epsilon"]

In [None]:
# def get_marginal(traces, sites):
#     return EmpiricalMarginal(traces, sites)._get_samples_and_weights()[0].detach().cpu().numpy()

In [None]:
# posterior_weight = posteriors.empirical["w"]
# posterior_bias = posteriors.empirical["b"]
# posterior_epsilon = posteriors.empirical["epsilon"]

In [None]:
# sns.distplot(posterior_weight((10000, )).squeeze(1)[:, 0])
# sns.distplot(posterior_weight((10000, )).squeeze(1)[:, 1])
# sns.distplot(posterior_bias((10000, )))
# sns.distplot(posterior_epsilon((10000, )))

## Vanilla HMM before moving on to BCRF (Qi et al. 2005)

In [None]:
# adapted from the Pyro HMM tutorial: https://pyro.ai/examples/hmm.html

In [None]:
# Pyro's poutine handles effects 
from pyro import poutine
import dmm.polyphonic_data_loader as poly

In [None]:
class MockArgs:
    """Plain configuration holder for the JSB Chorales HMM experiment.

    All settings are class attributes and are read as ``args.<name>``
    (e.g. ``args.hidden_dim``, ``args.jit``).
    """

    # optimisation / data
    num_steps = 500
    learning_rate = 0.05
    batch_size = 32
    truncate = None

    # model sizes
    hidden_dim = 16
    nn_dim = 48
    nn_channels = 2
    raftery_parameterization = True

    # runtime switches
    jit = True
    cuda = True
    print_shapes = False


args = MockArgs()

In [None]:
if args.cuda:
    if torch.cuda.is_available():
        # Newly created float tensors default to CUDA from here on.
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        # args.cuda is hard-coded to True; don't crash on CPU-only machines.
        import warnings
        warnings.warn("args.cuda is set but CUDA is not available; using the CPU instead.")

In [None]:
# JSB Chorales dataset: at each time step, which of the 88 piano keys are pressed.
# Modeling Temporal Dependencies in High-Dimensional Sequences: Application to Polyphonic Music Generation and Transcription
# http://www-etud.iro.umontreal.ca/~boulanni/icml2012
data = poly.load_data(poly.JSB_CHORALES)

In [None]:
# Training split: note sequences (presumably padded to a common length) and
# the true length of each sequence.
sequences = data['train']['sequences']
lengths = data['train']['sequence_lengths']

In [None]:
# Only some of the piano keys are ever pressed. Summing over sequences (dim 0)
# and then time (dim 1) gives a per-note count; keep a boolean mask of the
# note columns pressed at least once.
notes_pressed = ((sequences == 1).sum(0).sum(0) > 0)

In [None]:
# Drop never-pressed note columns to reduce the observation dimension.
sequences = sequences[:, :, notes_pressed]

In [None]:
# Optionally cap every sequence: lengths are clamped in place and the time
# axis is sliced to match.
if args.truncate:
    lengths.clamp_(max=args.truncate)
    sequences = sequences[:, :args.truncate]
# Total number of observed time steps (used to normalise the loss printout in
# the training loop).
num_observations = float(lengths.sum())

In [None]:
def model(sequences, lengths, args, batch_size=None, include_prior=True):
    """HMM over binary note vectors with enumerated hidden states (adapted from the Pyro HMM tutorial).

    Generative story, per sequence:
      * x_t in {0, ..., args.hidden_dim - 1} ~ Categorical(probs_x[x_{t-1}]);
        x is initialised to 0, so x_0 is drawn from the transition row of state 0.
      * each of the data_dim notes at step t ~ Bernoulli(probs_y[x_t, note]).
    The x_t sites carry ``infer={"enumerate": "parallel"}``, so they are meant
    to be summed out by an enumeration-aware ELBO (e.g. JitTraceEnum_ELBO)
    rather than covered by the guide.

    Args:
        sequences: tensor that unpacks as (num_sequences, max_length, data_dim);
            its entries are the Bernoulli observations.
        lengths: per-sequence lengths, shape (num_sequences,); steps with
            t >= length are masked out.
        args: config object; only ``hidden_dim`` and ``jit`` are read here.
        batch_size: number of sequences subsampled per call (None = use all).
        include_prior: when False, the log-prob of the probs_x / probs_y prior
            sites is masked out.

    Returns:
        None; all observations are registered as pyro.sample sites.
    """
    # Sometimes it is safe to ignore jit warnings. Here we use the
    # pyro.util.ignore_jit_warnings context manager to silence warnings about
    # conversion to integer, since we know all three numbers will be the same
    # across all invocations to the model.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    with poutine.mask(mask=include_prior):
        # to_event(n) moves the n right-most batch dims into the event shape.
        # transition matrix: row i is p(x_t | x_{t-1} = i)
        probs_x = pyro.sample("probs_x", dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # emission probabilities: probs_y[i, n] = p(note n pressed | x_t = i)
        probs_y = pyro.sample("probs_y", dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))

    # plate over the data_dim note columns of each observation, at dim=-1
    nodes_plate = pyro.plate("nodes", data_dim, dim=-1)
    # We subsample batch_size items out of num_sequences items. Note that since
    # we're using dim=-1 for the notes plate, we need to batch over a different
    # dimension, here dim=-2.

    with pyro.plate("sequences", size=num_sequences, subsample_size=batch_size, dim=-2) as batch:
        # batch: 1-D tensor of subsampled sequence indices; masks built from it
        # get .unsqueeze(-1) below so they line up with dim=-2.
        batch_lengths = lengths[batch]
        # x = 0: the first step uses the transition row of state 0
        x = 0
        # jit: loop over the fixed max_length; otherwise stop at the longest sequence in this batch
        for t in pyro.markov(range(max_length if args.jit else batch_lengths.max())):
            # steps past a sequence's length contribute nothing
            with poutine.mask(mask=(t < batch_lengths).unsqueeze(-1)):
                hidden_states = probs_x[x]
                x = pyro.sample("x_{}".format(t), dist.Categorical(hidden_states), infer={"enumerate": "parallel"})
                # x_t takes values in {0, ..., hidden_dim - 1}.
                # NOTE: under parallel enumeration x carries extra enumeration dims,
                # so its shape is not simply (batch,) — hence the squeeze(-1) below.
                # nodes_plate is the plate over the data_dim note columns
                with nodes_plate:
                    # per-note emission probabilities given the current hidden state
                    probs_y_given_hidden_state = probs_y[x.squeeze(-1)]
                    # Bernoulli: each note is either pressed or not at step t
                    y = pyro.sample("y_{}".format(t), dist.Bernoulli(probs_y_given_hidden_state),
                                    # observed notes of the batch at step t
                                    obs=sequences[batch, t])


In [None]:
from pyro.contrib.autoguide import AutoDelta, AutoDiagonalNormal
# MAP inference: AutoDelta puts a point mass (Delta) on each latent site it sees,
# and poutine.block exposes only the global "probs_*" sites to it, so the
# enumerated x_t sites are left out of the guide.
guide = AutoDelta(poutine.block(model, expose_fn=lambda msg: msg["name"].startswith("probs_")))

In [None]:
# max_plate_nesting=2 matches the two plates in `model` (dims -1 and -2).
elbo = JitTraceEnum_ELBO(max_plate_nesting=2, strict_enumeration_warning=True)
# NOTE(review): args.learning_rate (0.05) is not used; the rate is hard-coded here.
optim = Adam({'lr': 1e-3})
svi = SVI(model, guide, optim, elbo)

In [None]:
# pyro.clear_param_store()
# for step in range(args.num_steps * 10):
#     loss = svi.step(sequences, lengths, args=args, batch_size=args.batch_size, include_prior=True)
#     if not step % 50:
#         print('{: >5d}\t{}'.format(step, loss / num_observations))

## Applying the model to the Turk NER data (2019-07-16)

In [None]:
import string
from pathlib import Path
from typing import Union, Tuple, Iterable
from smart_open import open

In [None]:
def simple_bioes_parser(path: Union[Path, str]) -> list:
    """Parse a tab-separated BIOES file into BIO-tagged, punctuation-free sentences.

    Each data line is expected to be ``word<TAB>tag``. A line shorter than two
    characters (after stripping trailing whitespace) ends the current sentence.
    Lines that do not split into exactly two tab-separated fields are skipped,
    as are tokens made up only of punctuation characters. Words are
    lower-cased, and BIOES prefixes are mapped to BIO (``E-`` -> ``I-``,
    ``S-`` -> ``B-``); everything after the first ``-`` of a tag is kept as-is,
    so labels such as ``S-WORK-OF-ART`` are supported.

    Args:
        path: location of the file, passed straight to ``open``.

    Returns:
        A list with one entry per non-empty sentence. Each entry is
        ``[words, tags]`` where ``words`` and ``tags`` are equal-length tuples.
        A trailing sentence that is not followed by a blank line is included.
    """
    punctuation = set(string.punctuation)
    bioes_to_bio = str.maketrans("ES", "IB")
    documents = []
    sentence = []

    with open(path, "r") as f:
        for raw_line in f:
            line = raw_line.rstrip()

            # naively detect sentence boundary
            if len(line) < 2:
                if sentence:
                    documents.append(list(zip(*sentence)))
                    sentence = []
                continue

            try:
                word, entity_type = line.split("\t")
            except ValueError:
                # not a "word<TAB>tag" line; skip it
                continue

            # Drop tokens consisting solely of punctuation ("." / "--" / "...").
            # (A substring test against string.punctuation would drop "()" but keep "))".)
            if all(ch in punctuation for ch in word):
                continue

            # skip lemmatization for later.
            word = word.lower()

            # Split on the first dash only, so multi-dash labels survive.
            prefix, sep, label = entity_type.partition("-")
            if sep:
                entity_type = prefix.translate(bioes_to_bio) + sep + label

            sentence.append((word, entity_type))

    # The file may not end with a blank line: flush the final sentence too.
    if sentence:
        documents.append(list(zip(*sentence)))

    return documents

In [None]:
# Parse the Turk NER file into [words, tags] pairs, one per sentence.
raw_documents = simple_bioes_parser("./data/turk_ner_20190716.txt")
# Number of tokens in each parsed sentence.
raw_lengths = [len(d[0]) for d in raw_documents]

In [None]:
START_TAG = "<START>"
STOP_TAG = "<STOP>"

# Give every distinct word and tag the next free index, in order of first appearance.
word_to_ix = {}
tag_to_ix = {}

# tag_to_ix[START_TAG] = len(tag_to_ix)

for sentence, tags in raw_documents:
    for word in sentence:
        # len() is evaluated before setdefault inserts, so a new key gets the next index
        word_to_ix.setdefault(word, len(word_to_ix))
    for tag in tags:
        tag_to_ix.setdefault(tag, len(tag_to_ix))

# tag_to_ix[STOP_TAG] = len(tag_to_ix)

In [None]:
# Word tuples of every sentence (tags dropped), used to build the gensim dictionary.
tokenized_sents = [i[0] for i in raw_documents]

In [None]:
from gensim.corpora import Dictionary

In [None]:
dct = Dictionary(tokenized_sents)
# Drop rare / very frequent tokens using gensim's default thresholds (see the gensim docs).
dct.filter_extremes()
dct.compactify()
# unknown token last in the vocabulary: "UNK" gets id len(dct), one past the last
# existing id (assuming compactified ids 0..len-1).
# NOTE(review): only token2id is written here; gensim's other per-id tables are not updated.
dct.token2id["UNK"] = len(dct)

In [None]:
# Vocabulary size, presumably including the "UNK" entry (assuming len(dct) counts token2id).
data_dim = len(dct)

In [None]:
class MockArgs:
    """Plain configuration holder for the BHMM experiment on the Turk NER data.

    ``data_dim`` is the size of the gensim dictionary ``dct`` and ``hidden_dim``
    is the number of distinct tags in ``tag_to_ix``; both are read once, when
    this class body executes.
    """

    # sizes derived from the vocabularies
    data_dim = len(dct)
    hidden_dim = len(tag_to_ix)

    # optimisation / batching
    num_steps = 1000
    batch_size = 32
    learning_rate = 0.05
    truncate = 200

    # network sizes and runtime flags
    nn_dim = 48
    nn_channels = 2
    raftery_parameterization = True
    print_shapes = False
    jit = True
    cuda = True


args = MockArgs()

In [None]:
def tensorize_entities(seq, dictionary):
    """Map a sequence of tags to a 1-D ``torch.long`` tensor of their indices.

    Args:
        seq: iterable of tag strings.
        dictionary: mapping from tag to integer index (e.g. ``tag_to_ix``).
            It was previously ignored in favour of the global ``tag_to_ix``.

    Returns:
        1-D long tensor with ``dictionary[tag]`` for every tag in ``seq``.

    Raises:
        KeyError: if a tag is missing from ``dictionary``.
    """
    return torch.tensor([dictionary[tag] for tag in seq], dtype=torch.long)

In [None]:
def tensorize_sentence(seq, dictionary=None, unk_token="UNK"):
    """Map a sequence of words to a 1-D ``torch.long`` tensor of vocabulary ids.

    Out-of-vocabulary words are mapped to the id of ``unk_token``. Previously
    they were mapped to 0, which silently collided with whichever real token
    owns id 0.

    Args:
        seq: sequence of word strings.
        dictionary: gensim-style dictionary exposing ``token2id`` and
            ``doc2idx``; defaults to the notebook's global ``dct``.
        unk_token: key in ``dictionary.token2id`` that represents unknown words.

    Returns:
        1-D long tensor of word ids.

    Raises:
        KeyError: if ``unk_token`` is not registered in ``dictionary.token2id``.
    """
    if dictionary is None:
        dictionary = dct
    unknown_index = dictionary.token2id[unk_token]
    idxs = dictionary.doc2idx(seq, unknown_word_index=unknown_index)
    return torch.tensor(idxs, dtype=torch.long)

In [None]:
# Split the parsed documents into parallel tuples of word sequences and tag sequences.
tokenized_sequences, tokenized_entities = list(zip(*raw_documents))

In [None]:
# Word ids per sentence (1-D long tensors) via the gensim dictionary.
idx_sents = list(map(tensorize_sentence, tokenized_sequences))

In [None]:
# Tag ids per sentence (1-D long tensors) via tag_to_ix.
idx_ents = list(map(lambda x: tensorize_entities(x, tag_to_ix), tokenized_entities))

In [None]:
# Sentence lengths for the NER data (rebinds the JSB `lengths` defined earlier).
lengths = torch.tensor([len(d[0]) for d in raw_documents], dtype=torch.long)

In [None]:
def pad_sequence(data: Iterable):
    """Right-pad 1-D tensors with zeros and stack them into one long tensor.

    Note that 0 is also a valid word/tag id; padded positions must be masked
    using the true lengths.

    Args:
        data: iterable of 1-D tensors. Any iterable works, including a
            generator, because it is materialised once up front.

    Returns:
        ``torch.long`` tensor of shape ``(num_tensors, max_length)``. Row ``k``
        holds the k-th input followed by zeros. An empty input gives a
        ``(0, 0)`` tensor.
    """
    # `data` is needed twice (lengths, then copying), and len() is required.
    tensors = list(data)
    if not tensors:
        return torch.zeros(0, 0, dtype=torch.long)

    lengths = [tensor.shape[0] for tensor in tensors]
    padded = torch.zeros(len(tensors), max(lengths), dtype=torch.long)
    for row, (tensor, length) in enumerate(zip(tensors, lengths)):
        padded[row, :length] = tensor
    return padded

In [None]:
def BHMM(sequences, entities, lengths, args, include_prior=True):
    """Supervised Bayesian HMM for the NER data: tags are hidden states, words are emissions.

    Both the tag sequence (``entities``) and the word sequence (``sequences``)
    are passed as ``obs``, so every x_t / y_t site is observed. Only the global
    ``probs_x`` (tag transitions) and ``probs_y`` (word emissions) are latent.

    Args:
        sequences: padded word ids, indexed as ``sequences[batch, t]``
            (presumably a (num_sequences, max_len) long tensor from pad_sequence).
        entities: padded tag ids, indexed the same way as ``sequences``.
        lengths: 1-D tensor of true sentence lengths; steps t >= length are masked.
        args: config object; reads ``hidden_dim``, ``data_dim`` and ``batch_size``.
        include_prior: when False, the log-prob of the probs_x / probs_y prior
            sites is masked out.

    Returns:
        None; everything is registered as pyro.sample sites.

    NOTE(review): two things make this model's log density differ between calls
    with identical arguments: the ``torch.rand`` concentration of the probs_y
    prior, and the random subsample of sentences. The first affects SVI and
    NUTS; the second is standard for SVI but breaks NUTS's assumption of a
    deterministic potential.
    """
    with ignore_jit_warnings():
        num_sequences = len(sequences)
        # NOTE(review): Python max() iterates the tensor and returns a 0-d tensor;
        # the second assert below is therefore always true.
        max_length = max(lengths)
        assert lengths.shape == (num_sequences, )
        assert lengths.max() <= max_length

    # to_event(1): the last dim of each Dirichlet draw is its event, i.e. one
    # probability row per hidden state.
    with poutine.mask(mask=include_prior):
        # transition matrix over tags: row i is p(x_t | x_{t-1} = i)
        probs_x = pyro.sample("probs_x", dist.Dirichlet(0.6 * torch.eye(args.hidden_dim) + 0.4).to_event(1))
        # emission matrix: probs_y[i, w] = p(word w | tag i)
        # NOTE(review): torch.rand draws a new prior concentration on every call;
        # a fixed concentration tensor is probably intended.
        probs_y = pyro.sample("probs_y", dist.Dirichlet(torch.rand([args.hidden_dim, args.data_dim]) + 0.1).to_event(1))

    with pyro.plate("sequences", size=num_sequences, subsample_size=args.batch_size, dim=-2) as batch:
        # batch: 1-D tensor of subsampled sentence indices; values derived from
        # it get .unsqueeze(-1) below so they line up with dim=-2.
        batch_lengths = lengths[batch]

        # x = 0: the first step uses the transition row of tag 0 (no separate start distribution)
        x = 0
        for t in pyro.markov(range(max_length)):
            # ignore padded positions beyond each sentence's length
            with poutine.mask(mask=(t < batch_lengths).unsqueeze(-1)):
                hidden_states = probs_x[x]
                x = pyro.sample("x_{}".format(t),
                                dist.Categorical(hidden_states),
                                infer={"enumerate": "parallel"},
                                obs=entities[batch, t].unsqueeze(-1))
                # x_t is the observed tag id at step t, shape (batch, 1) via .unsqueeze(-1).
                # The enumerate hint presumably has no effect on an observed site.
                probs_y_given_st = probs_y[x]
                # y_t: observed word id at step t, Categorical over the data_dim
                # vocabulary given tag x_t.
                y = pyro.sample("y_{}".format(t),
                                dist.Categorical(probs_y_given_st),
                                infer={"enumerate": "parallel"},
                                obs=sequences[batch, t].unsqueeze(-1))

In [None]:
# MAP guide over the global "probs_*" sites of BHMM (its x_t / y_t sites are all observed).
guide = AutoDelta(poutine.block(BHMM, expose_fn=lambda msg: msg["name"].startswith("probs_")))
# Plain Trace_ELBO is enough: BHMM passes obs= for every x_t and y_t site, so
# nothing needs to be enumerated.
elbo = Trace_ELBO(max_plate_nesting=2, strict_enumeration_warning=True)
svi = SVI(BHMM, guide, Adam({'lr': 1e-4}), elbo)

In [None]:
# Stack the per-sentence id tensors into zero-padded 2-D long tensors.
sequences = pad_sequence(idx_sents)
entities = pad_sequence(idx_ents)

In [None]:
# Optionally cap sentence lengths (in place) and the time axis of both tensors.
if args.truncate:
    lengths.clamp_(max=args.truncate)
    sequences = sequences[:, :args.truncate]
    entities = entities[:, :args.truncate]
# Total number of observed tokens (used to normalise the loss printout).
num_observations = float(lengths.sum())

In [None]:
pyro.clear_param_store()

# NOTE(review): BHMM subsamples args.batch_size sentences per call and draws a
# random emission-prior concentration per call, so the NUTS potential is not a
# fixed function of the parameters. Confirm this is intended before trusting
# these samples.
nuts = NUTS(BHMM)
sampler = MCMC(nuts,
               num_samples=5000,
               num_chains=1,
               # burn-in
               warmup_steps=200)
# Presumably the old Pyro API, where run() returns the posterior object.
traces = sampler.run(sequences, entities, lengths, args=args, include_prior=True)

In [None]:
# pyro.clear_param_store()
# for step in range(args.num_steps * 10):
#     loss = svi.step(sequences, entities, lengths, args=args, include_prior=True)
#     if not step % 50:
#         print('{: >5d}\t{}'.format(step, loss / num_observations))

In [None]:
def plot_posterior(posterior, true_transition_prob=None):
    """Plot a KDE of every entry of the sampled ``transition_prob`` matrix.

    Adapted from a Pyro example (the source link was missing in the original).

    Args:
        posterior: MCMC result exposing ``.marginal([...]).support()`` (the old
            Pyro TracePosterior API, as returned by ``mcmc.run`` in this notebook).
        true_transition_prob: ground-truth matrix whose entries are shown in the
            legend labels. Defaults to the global ``transition_prob`` used to
            simulate the data.
    """
    if true_transition_prob is None:
        true_transition_prob = transition_prob

    # marginal distribution of `transition_prob` from the posterior
    marginal = posterior.marginal(["transition_prob"])
    # stacked samples, shape (num_samples, K, K) for K categories
    samples = marginal.support()["transition_prob"]
    if torch.is_tensor(samples):
        # seaborn needs host memory; samples may live on the GPU when the
        # default tensor type is CUDA
        samples = samples.detach().cpu().numpy()
    # derive K from the samples instead of relying on a global num_categories
    num_states = samples.shape[-1]

    plt.figure(figsize=(10, 6))
    for i in range(num_states):
        for j in range(num_states):
            label = "transition_prob[{}, {}], true value = {:.2f}".format(
                i, j, float(true_transition_prob[i, j]))
            sns.distplot(samples[:, i, j], hist=False, kde_kws={"lw": 2}, label=label)
    plt.xlabel("Probability", fontsize=13)
    plt.ylabel("Frequency", fontsize=13)
    plt.title("Transition probability posterior", fontsize=15)

In [None]:
num_categories = 3
num_words = 10
num_supervised_data = 100
num_data = 600

# Dirichlet concentrations: flat (all 1.0) over next-category transitions,
# sparse (all 0.1) over emitted words.
transition_prior = torch.ones(num_categories)
emission_prior = torch.full((num_words,), 0.1)

# Ground-truth parameters: one Dirichlet draw (row) per category.
transition_prob = dist.Dirichlet(transition_prior).sample(torch.Size([num_categories]))
emission_prob = dist.Dirichlet(emission_prior).sample(torch.Size([num_categories]))

In [None]:
def equilibrium(mc_matrix):
    """Return the stationary distribution of a row-stochastic transition matrix.

    The stationary row vector pi satisfies pi P = pi and sum(pi) = 1, which
    is equivalent to (I - P^T + 1 1^T) pi = 1. The helper tensors are built
    with ``mc_matrix``'s dtype and device, so float64 or GPU inputs work too.

    Args:
        mc_matrix: (n, n) transition matrix whose rows sum to 1.

    Returns:
        1-D tensor of length n with the stationary probabilities.
    """
    n = mc_matrix.size(0)
    identity = torch.eye(n, dtype=mc_matrix.dtype, device=mc_matrix.device)
    ones = torch.ones(n, dtype=mc_matrix.dtype, device=mc_matrix.device)
    # "+ 1" adds the all-ones matrix, which enforces the sum-to-one constraint
    return (identity - mc_matrix.t() + 1).inverse().matmul(ones)

start_prob = equilibrium(transition_prob)

# Simulate data: one chain for the supervised part, then a fresh chain for the
# unsupervised part; both start from the stationary distribution.
categories, words = [], []
for t in range(num_data):
    if t == 0 or t == num_supervised_data:
        category = dist.Categorical(start_prob).sample()
    else:
        category = dist.Categorical(transition_prob[category]).sample()
    word = dist.Categorical(emission_prob[category]).sample()
    categories.append(category)
    words.append(word)
categories, words = torch.stack(categories), torch.stack(words)

# Split into supervised data and unsupervised data.
# (The word splits previously sliced `categories` by mistake.)
supervised_categories = categories[:num_supervised_data]
supervised_words = words[:num_supervised_data]
unsupervised_words = words[num_supervised_data:]

In [None]:
def supervised_hmm(categories, words):
    """Fully observed HMM over (category, word) pairs, used with NUTS.

    Reads the globals num_categories, transition_prior and emission_prior.
    Inside "prob_plate", one Dirichlet row per category is drawn for both the
    transition and the emission matrix. word_0 is then scored under the
    emission row of categories[0], whose own probability is not modelled.
    Every later step t scores category_t given category_{t-1} and word_t given
    category_t, all as observed sites.
    """
    with pyro.plate("prob_plate", num_categories):
        trans = pyro.sample("transition_prob", dist.Dirichlet(transition_prior))
        emit = pyro.sample("emission_prob", dist.Dirichlet(emission_prior))

    state = categories[0]  # the chain starts at the first observed category
    for step in range(len(words)):
        if step != 0:
            state = pyro.sample("category_{}".format(step),
                                dist.Categorical(trans[state]),
                                obs=categories[step])
        pyro.sample("word_{}".format(step),
                    dist.Categorical(emit[state]),
                    obs=words[step])

In [None]:
# enable jit_compile to improve the sampling speed
nuts_kernel = NUTS(supervised_hmm, jit_compile=True, ignore_jit_warnings=True)
mcmc = MCMC(nuts_kernel, num_samples=100)
# we run MCMC to get posterior
supervised_posterior = mcmc.run(supervised_categories, supervised_words)
# after that, we plot the posterior
plot_posterior(supervised_posterior)