In [14]:
import random
import string
import os
import itertools
import torch.nn as nn
import torch
import torch.nn.functional as F
import pyro
import numpy as np
import pyro.optim as optim
import pyro.distributions as dist
import pyro.infer
import pyro.optim
import time
from pyro.infer import SVI, Trace_ELBO, TraceGraph_ELBO
from PIL import Image
from torch.distributions import constraints
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils

In [15]:
# "Astronomers" probabilistic context-free grammar.
# Each non-terminal maps to its list of productions; a production is a list
# of symbols (terminals and/or non-terminals).
pcfg = dict(
    name="astronomers",
    terminals=["astronomers", "ears", "saw", "stars", "telescopes", "with"],
    non_terminals=["S", "NP", "VP", "PP", "P", "V"],
    productions=dict(
        S=[["NP", "VP"]],
        NP=[
            ["NP", "PP"],
            ["astronomers"],
            ["ears"],
            ["saw"],
            ["stars"],
            ["telescopes"],
        ],
        VP=[["V", "NP"], ["VP", "PP"]],
        PP=[["P", "NP"]],
        P=[["with"]],
        V=[["saw"]],
    ),
    start_symbol="S",
)

# Probability of each production, listed in the same order as the
# corresponding entry of pcfg["productions"].
true_production_probs = dict(
    S=[1.0],
    NP=[0.4, 0.1, 0.18, 0.04, 0.18, 0.1],
    VP=[0.7, 0.3],
    PP=[1.0],
    P=[1.0],
    V=[1.0],
)

# Depth beyond which tree sampling stops expanding non-terminals.
max_depth_parse_tree = 30

In [43]:
# Helper functions taken from the code of the paper "Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow".
def word_to_index(word, terminals):
    """Map a word to its position among the sorted terminals.
    Args:
        word: string
        terminals: collection of terminal strings; it is sorted to fix the
            index order
    Returns: int position of word in sorted(terminals); -1 if word is not
        in terminals
    """
    ordered_terminals = sorted(terminals)
    return ordered_terminals.index(word) if word in ordered_terminals else -1

def sentence_to_indices(sentence, terminals):
    """Map every word of a sentence to its terminal index.
    Args:
        sentence: list of strings
        terminals: collection of terminal strings
    Returns: list with one index per word, in sentence order; the index is
        -1 for a word that is not in terminals
    """
    indices = []
    for token in sentence:
        indices.append(word_to_index(token, terminals))
    return indices


def _indices_to_string(indices):
    return ''.join([string.printable[index] for index in indices])


def _sentence_to_string(sentence, terminals):
    """Encode a sentence as a string with one character per word.
    Args:
        sentence: list of strings
        terminals: collection of terminal strings
    Returns: string obtained by converting the sentence to terminal indices
        and then encoding those indices with _indices_to_string
    """
    word_indices = sentence_to_indices(sentence, terminals)
    return _indices_to_string(word_indices)

def word_to_one_hot(word, terminals):
    """One-hot encode a word over the sorted terminals.
    Args:
        word: string
        terminals: collection of terminal strings
    Returns: tensor of shape [len(terminals)] with a 1 at the position of
        word in sorted(terminals), or all zeros if word is not in terminals
    """
    vocabulary_size = len(terminals)
    try:
        position = torch.tensor([sorted(terminals).index(word)])
        encoded = one_hot(position, vocabulary_size)
    except ValueError:
        # Unknown word: fall back to the all-zero vector.
        return torch.zeros((vocabulary_size,))
    return encoded[0]

def sentence_to_one_hots(sentence, terminals):
    """Convert sentence to one-hots.
    Args:
        sentence: iterable of strings (may be empty)
        terminals: collection of terminal strings
    Returns: matrix where ith row corresponds to a one-hot of ith word, shape
        [num_words, len(terminals)]; an empty sentence yields a tensor of
        shape [0, len(terminals)]
    """
    rows = [word_to_one_hot(word, terminals) for word in sentence]
    if not rows:
        # torch.cat / torch.stack reject an empty list, so build the
        # zero-row matrix explicitly to honour the documented shape.
        return torch.zeros((0, len(terminals)))
    return torch.stack(rows)

def one_hot(indices, num_bins):
    """One-hot encode a 1-D tensor of indices.
    Args:
        indices: 1-D tensor of bin indices
        num_bins: number of bins (width of every one-hot row)
    Returns: float matrix of shape [len(indices), num_bins] whose ith row is
        the one-hot encoding of indices[i]
    """
    # One column index per row, shaped [len(indices), 1] for scatter_.
    column_ids = indices.long().unsqueeze(-1)
    encoded = torch.zeros(len(indices), num_bins)
    encoded.scatter_(1, column_ids, 1)
    return encoded

def get_sample_address_embedding(non_terminal, non_terminals):
    """Embed the sample address of a production as a one-hot vector.
    Args:
        non_terminal: string; must be one of non_terminals
        non_terminals: collection of non-terminal symbols
    Returns: one-hot tensor of shape [len(non_terminals)] with a 1 at the
        position of non_terminal in sorted(non_terminals)
    """
    ordered_symbols = sorted(non_terminals)
    position = torch.tensor([ordered_symbols.index(non_terminal)])
    return one_hot(position, len(non_terminals))[0]

In [28]:
# example:
# tree = ['S', ['NP', ['NP', 'saw'], ['PP', ['P', 'with'], ['NP', ['NP', 'telescopes'], ['PP', ['P', 'with'], ['NP', 'astronomers']]]]], ['VP', ['V', 'saw'], ['NP', 'astronomers']]]
# sentence = ['saw', 'with', 'telescopes', 'with', 'astronomers', 'saw', 'astronomers']
def model(observations=None):
    """Generative model: sample a parse tree from the PCFG, top-down.

    Every non-terminal expansion is a Pyro sample site drawing the index of
    the production to apply from a Categorical distribution.

    Args:
        observations: currently unused (conditioning on an observed sentence
            is not implemented yet, see the TODO at the end of the body)
    Returns: None; the sampled tree and its sentence are only computed
    """
    # NOTE(review): the logits are re-drawn at random on every call, so the
    # prior over trees differs between calls, and true_production_probs is
    # not used here -- confirm this is intended.
    production_logits = {
        k: torch.randn((len(v),))
        for k, v in pcfg['productions'].items()
    }

    # Sample-site names already handed out during this call.
    used_addresses = set()

    def site_address(depth, suffix, path):
        """Return a sample-site name that is unique within this call.

        (depth, suffix) alone is ambiguous: two nodes at the same depth that
        are the same-numbered child of different parents share it, and Pyro
        raises on duplicate site names when the model is traced. The first
        node keeps the plain name; later duplicates are qualified with their
        path of child indices from the root, which is unique per node.
        """
        address = f"production_index_{depth}_{suffix}"
        if address in used_addresses:
            address = f"{address}__{path}"
        used_addresses.add(address)
        return address

    def get_leaves(tree):
        """Return the leaves of a nested-list tree, left to right."""
        if isinstance(tree, list):
            # tree[0] is the node's own symbol; the rest are its children.
            return list(itertools.chain.from_iterable(
                get_leaves(subtree) for subtree in tree[1:]))
        return [tree]

    def sample_parse_tree(symbol=None, depth=0, suffix=0, path="0"):
        """Recursively expand symbol into [symbol, child, child, ...].

        Terminals are returned as-is; so are non-terminals once depth
        exceeds max_depth_parse_tree, which bounds the recursion.
        """
        if symbol is None:
            symbol = pcfg['start_symbol']
        if symbol in pcfg['terminals']:
            return symbol
        if depth > max_depth_parse_tree:
            return symbol
        distribution = dist.Categorical(logits=production_logits[symbol])
        production_index = pyro.sample(
            site_address(depth, suffix, path), distribution)
        production = pcfg['productions'][symbol][production_index]
        return [symbol] + [
            sample_parse_tree(s, depth=depth + 1, suffix=i,
                              path=f"{path}.{i}")
            for i, s in enumerate(production)]

    tree = sample_parse_tree()
    sentence = get_leaves(tree)
    # TODO: condition on observations["obs_sentence"]. Open question: how do
    # we state the observation? Candidate schemes:
    #   1. pad the generated and observed sentences (as one-hots) and observe
    #      with a delta distribution
    #   2. pad and observe with a normal distribution with a very low sigma
    #   3. compute the Levenshtein distance between the generated and
    #      observed sentences and observe it as 0

model()

['S', ['NP', ['NP', 'saw'], ['PP', ['P', 'with'], ['NP', ['NP', 'telescopes'], ['PP', ['P', 'with'], ['NP', 'astronomers']]]]], ['VP', ['V', 'saw'], ['NP', 'astronomers']]]
['saw', 'with', 'telescopes', 'with', 'astronomers', 'saw', 'astronomers']
['astronomers', 'ears', 'saw', 'stars', 'telescopes', 'with']
2545020
tensor([[0., 0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 1., 0.],
        [0., 0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0.]])


In [47]:
# Sizes of the inference-network components.
obs_embedding_dim = 100
inference_hidden_dim = 100
# Wide enough to one-hot encode a production index of any non-terminal.
sample_embedding_dim = max(
    len(rules) for rules in pcfg['productions'].values())
# One slot per non-terminal (one-hot sample address).
sample_address_embedding_dim = len(pcfg['non_terminals'])

# Embeds the observed sentence, given as a sequence of one-hot terminals.
sentence_embedder_gru = nn.GRU(input_size=len(pcfg['terminals']),
                               hidden_size=obs_embedding_dim,
                               num_layers=1)

# Recurrent cell fed with the concatenation of the observation embedding,
# the previous sample embedding and the sample address embedding.
inference_gru = nn.GRUCell(
    input_size=(obs_embedding_dim
                + sample_embedding_dim
                + sample_address_embedding_dim),
    hidden_size=inference_hidden_dim)

# One MLP head per non-terminal, producing one logit per production.
proposal_layers = nn.ModuleDict({
    non_terminal: nn.Sequential(
        nn.Linear(inference_hidden_dim, 50),
        nn.ReLU(),
        nn.Linear(50, 25),
        nn.ReLU(),
        nn.Linear(25, len(rules)),
    )
    for non_terminal, rules in pcfg['productions'].items()
})

def get_inference_gru_output(obs_embedding,
                             previous_sample_embedding,
                             sample_address_embedding,
                             inference_hidden):
    """Run one step of the inference GRU cell.
    Args:
        obs_embedding: tensor, embedding of the observed sentence
        previous_sample_embedding: tensor, embedding of the previous sample
        sample_address_embedding: tensor, embedding of the current address
        inference_hidden: tensor, previous hidden state of the cell
    Returns: new hidden state of inference_gru with the batch axis removed
    """
    # The three embeddings are concatenated along dim 0, then a batch axis
    # of size 1 is added for the GRU cell.
    features = torch.cat((obs_embedding,
                          previous_sample_embedding,
                          sample_address_embedding))
    batched_features = features.unsqueeze(0)
    batched_hidden = inference_hidden.unsqueeze(0)
    next_hidden = inference_gru(batched_features, batched_hidden)
    return next_hidden.squeeze(0)
    
def get_logits_from_inference_gru_output(inference_gru_output,
                                         non_terminal):
    """Map an inference GRU output to production logits.
    Args:
        inference_gru_output: tensor of shape [inference_hidden_dim]
        non_terminal: string; selects the head in proposal_layers
    Returns: logits for the Categorical distribution over the productions
        of non_terminal
    """
    # Add a batch axis of size 1 for the head, then drop it again.
    batched_output = inference_gru_output.unsqueeze(0)
    proposal_head = proposal_layers[non_terminal]
    return proposal_head(batched_output).squeeze(0)
    
def guide(observations=None):
    """Amortized proposal over parse trees for an observed sentence.

    The observed sentence is embedded with sentence_embedder_gru. The tree
    is then expanded top-down; at every non-terminal the production index
    is a Pyro sample site drawn from a Categorical whose logits come from
    inference_gru followed by the proposal_layers head of that non-terminal.

    Args:
        observations: dict with key "obs_sentence" holding the observed
            sentence as a list of terminal strings; must be provided
    Returns: the sampled parse tree as a nested list
        [symbol, child, child, ...] whose leaves are strings
    Raises:
        TypeError: if observations is None
    """
    if observations is None:
        raise TypeError(
            'guide() requires observations, a dict with key "obs_sentence"')

    # Register the networks with Pyro; without this their parameters are not
    # in the param store and SVI would not optimize them.
    pyro.module("sentence_embedder_gru", sentence_embedder_gru)
    pyro.module("inference_gru", inference_gru)
    pyro.module("proposal_layers", proposal_layers)

    # nn.GRU expects [seq_len, batch, features]: add a batch axis of size 1.
    obs_one_hots = sentence_to_one_hots(
        observations["obs_sentence"], pcfg["terminals"])
    gru_outputs, _ = sentence_embedder_gru(obs_one_hots.unsqueeze(1))
    # Output at the last time step for the single batch element.
    obs_embedding = gru_outputs[-1][0]

    # Sample-site names already handed out during this call.
    used_addresses = set()

    def site_address(depth, suffix, path):
        """Return a sample-site name that is unique within this call.

        (depth, suffix) alone is ambiguous: two nodes at the same depth that
        are the same-numbered child of different parents share it, and Pyro
        raises on duplicate site names when the guide is traced. The first
        node keeps the plain name; later duplicates are qualified with their
        path of child indices from the root, which is unique per node.
        """
        address = f"production_index_{depth}_{suffix}"
        if address in used_addresses:
            address = f"{address}__{path}"
        used_addresses.add(address)
        return address

    def sample_parse_tree(symbol=None, previous_sample_embedding=None,
                          inference_hidden=None, depth=0, suffix=0,
                          path="0"):
        """Recursively expand symbol into [symbol, child, child, ...].

        Terminals are returned as-is; so are non-terminals once depth
        exceeds max_depth_parse_tree, which bounds the recursion.
        """
        if symbol is None:
            symbol = pcfg['start_symbol']

        if symbol in pcfg['terminals']:
            return symbol
        if depth > max_depth_parse_tree:
            return symbol

        # The root has no parent: start from all-zero embeddings.
        if previous_sample_embedding is None:
            previous_sample_embedding = torch.zeros(
                (sample_embedding_dim,))
        if inference_hidden is None:
            inference_hidden = torch.zeros((inference_hidden_dim,))

        # one-hot representation of non-terminal
        sample_address_embedding = get_sample_address_embedding(
            symbol, pcfg['non_terminals'])
        # get the output from inference gru
        inference_gru_output = get_inference_gru_output(
            obs_embedding, previous_sample_embedding,
            sample_address_embedding, inference_hidden)
        # compute the logits from proposal layers
        logits = get_logits_from_inference_gru_output(
            inference_gru_output, symbol)
        distribution = dist.Categorical(logits=logits)
        production_index = pyro.sample(
            site_address(depth, suffix, path), distribution)
        production = pcfg['productions'][symbol][production_index]
        sample_embedding = one_hot(
            production_index.unsqueeze(0), sample_embedding_dim)[0]
        # Every child receives this node's sample embedding and GRU output.
        return [symbol] + [
            sample_parse_tree(s, sample_embedding, inference_gru_output,
                              depth=depth + 1, suffix=i,
                              path=f"{path}.{i}")
            for i, s in enumerate(production)]

    return sample_parse_tree()
print(guide({"obs_sentence" : ['saw', 'with', 'telescopes', 'with', 'astronomers', 'saw', 'astronomers']}))

['S', ['NP', 'ears'], ['VP', ['V', 'saw'], ['NP', 'stars']]]
