In [1]:
def generate_toy_data(num_symbols=5, num_segments=3, max_segment_len=5):
    """Generate toy data sample with repetition of symbols (EOS symbol: 0)."""
    seq = []
    symbols = np.random.choice(
        np.arange(1, num_symbols + 1), num_segments, replace=False)
    for seg_id in range(num_segments):
        segment_len = np.random.choice(np.arange(1, max_segment_len))
        seq += [symbols[seg_id]] * segment_len
    seq += [0]
    return torch.tensor(seq, dtype=torch.int64)

In [2]:
import torch
import torch.nn.functional as F
import numpy as np

EPS = 1e-17
NEG_INF = -1e30

In [3]:
data = []
for _ in range(512):
    data.append(generate_toy_data(
        num_symbols=5,
        num_segments=3))
lengths = torch.tensor(list(map(len, data)))
inputs = torch.nn.utils.rnn.pad_sequence(data, batch_first=True)
print(inputs)

tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 4, 4,  ..., 0, 0, 0],
        [3, 3, 3,  ..., 0, 0, 0],
        ...,
        [2, 2, 5,  ..., 0, 0, 0],
        [4, 4, 5,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])


In [4]:
"""Utility functions."""

import torch
import torch.nn.functional as F
import numpy as np

EPS = 1e-17
NEG_INF = -1e30


def to_one_hot(indices, max_index):
    """Get one-hot encoding of index tensors."""
    zeros = torch.zeros(
        indices.size()[0], max_index, dtype=torch.float32,
        device=indices.device)
    return zeros.scatter_(1, indices.unsqueeze(1), 1)


def gumbel_sample(shape):
    """Sample Gumbel noise."""
    uniform = torch.rand(shape).float()
    return - torch.log(EPS - torch.log(uniform + EPS))


def gumbel_softmax_sample(logits, temp=1.):
    """Sample from the Gumbel softmax / concrete distribution."""
    gumbel_noise = gumbel_sample(logits.size())
    if logits.is_cuda:
        gumbel_noise = gumbel_noise.cuda()
    return F.softmax((logits + gumbel_noise) / temp, dim=-1)


def gaussian_sample(mu, log_var):
    """Sample from Gaussian distribution."""
    gaussian_noise = torch.randn(mu.size())
    if mu.is_cuda:
        gaussian_noise = gaussian_noise.cuda()
    return mu + torch.exp(log_var * 0.5) * gaussian_noise


def kl_gaussian(mu, log_var):
    """KL divergence between Gaussian posterior and standard normal prior."""
    return -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp(), dim=1)


def kl_categorical_uniform(preds):
    """KL divergence between categorical distribution and uniform prior."""
    kl_div = preds * torch.log(preds + EPS)  # Constant term omitted.
    return kl_div.sum(1)


def kl_categorical(preds, log_prior):
    """KL divergence between two categorical distributions."""
    kl_div = preds * (torch.log(preds + EPS) - log_prior)
    return kl_div.sum(1)


def poisson_categorical_log_prior(length, rate, device):
    """Categorical prior populated with log probabilities of Poisson dist."""
    rate = torch.tensor(rate, dtype=torch.float32, device=device)
    values = torch.arange(
        1, length + 1, dtype=torch.float32, device=device).unsqueeze(0)
    log_prob_unnormalized = torch.log(
        rate) * values - rate - (values + 1).lgamma()
    # TODO(tkipf): Length-sensitive normalization.
    return F.log_softmax(log_prob_unnormalized, dim=1)  # Normalize.


def log_cumsum(probs, dim=1):
    """Calculate log of inclusive cumsum."""
    return torch.log(torch.cumsum(probs, dim=dim) + EPS)


def generate_toy_data(num_symbols=5, num_segments=3, max_segment_len=5):
    """Generate toy data sample with repetition of symbols (EOS symbol: 0)."""
    seq = []
    symbols = np.random.choice(
        np.arange(1, num_symbols + 1), num_segments, replace=False)
    for seg_id in range(num_segments):
        segment_len = np.random.choice(np.arange(1, max_segment_len))
        seq += [symbols[seg_id]] * segment_len
    seq += [0]
    return torch.tensor(seq, dtype=torch.int64)


def get_lstm_initial_state(batch_size, hidden_dim, device):
    """Get empty (zero) initial states for LSTM."""
    hidden_state = torch.zeros(batch_size, hidden_dim, device=device)
    cell_state = torch.zeros(batch_size, hidden_dim, device=device)
    return hidden_state, cell_state


def get_segment_probs(all_b_samples, all_masks, segment_id):
    """Get segment probabilities for a particular segment ID."""
    neg_cumsum = 1 - torch.cumsum(all_b_samples[segment_id], dim=1)
    if segment_id > 0:
        return neg_cumsum * all_masks[segment_id - 1]
    else:
        return neg_cumsum


def get_losses(inputs, outputs, args, beta_b=.1, beta_z=.1, prior_rate=3.,):
    """Get losses (NLL, KL divergences and neg. ELBO).

    Args:
        inputs: Padded input sequences.
        outputs: CompILE model output tuple.
        args: Argument dict from `ArgumentParser`.
        beta_b: Scaling factor for KL term of boundary variables (b).
        beta_z: Scaling factor for KL term of latents (z).
        prior_rate: Rate (lambda) for Poisson prior.
    """

    targets = inputs.view(-1)
    all_encs, all_recs, all_masks, all_b, all_z = outputs
    input_dim = args.num_symbols + 1

    nll = 0.
    kl_z = 0.
    for seg_id in range(args.num_segments):
        seg_prob = get_segment_probs(
            all_b['samples'], all_masks, seg_id)
        preds = all_recs[seg_id].view(-1, input_dim)
        seg_loss = F.cross_entropy(
            preds, targets, reduction='none').view(-1, inputs.size(1))

        # Ignore EOS token (last sequence element) in loss.
        nll += (seg_loss[:, :-1] * seg_prob[:, :-1]).sum(1).mean(0)

        # KL divergence on z.
        if args.latent_dist == 'gaussian':
            mu, log_var = torch.split(
                all_z['logits'][seg_id], args.latent_dim, dim=1)
            kl_z += kl_gaussian(mu, log_var).mean(0)
        elif args.latent_dist == 'concrete':
            kl_z += kl_categorical_uniform(
                F.softmax(all_z['logits'][seg_id], dim=-1)).mean(0)
        else:
            raise ValueError('Invalid argument for `latent_dist`.')

    # KL divergence on b (first segment only, ignore first time step).
    # TODO(tkipf): Implement alternative prior on soft segment length.
    probs_b = F.softmax(all_b['logits'][0], dim=-1)
    log_prior_b = poisson_categorical_log_prior(
        probs_b.size(1), prior_rate, device=inputs.device)
    kl_b = args.num_segments * kl_categorical(
        probs_b[:, 1:], log_prior_b[:, 1:]).mean(0)

    loss = nll + beta_z * kl_z + beta_b * kl_b
    return loss, nll, kl_z, kl_b


def get_reconstruction_accuracy(inputs, outputs, args):
    """Calculate reconstruction accuracy (averaged over sequence length)."""

    all_encs, all_recs, all_masks, all_b, all_z = outputs

    batch_size = inputs.size(0)

    rec_seq = []
    rec_acc = 0.
    for sample_idx in range(batch_size):
        prev_boundary_pos = 0
        rec_seq_parts = []
        for seg_id in range(args.num_segments):
            boundary_pos = torch.argmax(
                all_b['samples'][seg_id], dim=-1)[sample_idx]
            if prev_boundary_pos > boundary_pos:
                boundary_pos = prev_boundary_pos
            seg_rec_seq = torch.argmax(all_recs[seg_id], dim=-1)
            rec_seq_parts.append(
                seg_rec_seq[sample_idx, prev_boundary_pos:boundary_pos])
            prev_boundary_pos = boundary_pos
        rec_seq.append(torch.cat(rec_seq_parts))
        cur_length = rec_seq[sample_idx].size(0)
        matches = rec_seq[sample_idx] == inputs[sample_idx, :cur_length]
        rec_acc += matches.float().mean()
    rec_acc /= batch_size
    return rec_acc, rec_seq


In [5]:
import torch
import torch.nn.functional as F
from torch import nn

import utils


class CompILE(nn.Module):
    """CompILE example implementation.

    Args:
        input_dim: Dictionary size of embeddings.
        hidden_dim: Number of hidden units.
        latent_dim: Dimensionality of latent variables (z).
        max_num_segments: Maximum number of segments to predict.
        temp_b: Gumbel softmax temperature for boundary variables (b).
        temp_z: Temperature for latents (z), only if latent_dist='concrete'.
        latent_dist: Whether to use Gaussian latents ('gaussian') or concrete /
            Gumbel softmax latents ('concrete').
    """
    def __init__(self, input_dim, hidden_dim, latent_dim, max_num_segments,
                 temp_b=1., temp_z=1., latent_dist='gaussian'):
        super(CompILE, self).__init__()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        self.max_num_segments = max_num_segments
        self.temp_b = temp_b
        self.temp_z = temp_z
        self.latent_dist = latent_dist
        print('input_dim:', input_dim)
        print('hidden_dim:', hidden_dim)
        print('latent_dim', latent_dim)
        print('max_num_segments', max_num_segments)
        self.embed = nn.Embedding(input_dim, hidden_dim)
        self.lstm_cell = nn.LSTMCell(hidden_dim, hidden_dim)

        # LSTM output heads.
        self.head_z_1 = nn.Linear(hidden_dim, hidden_dim)  # Latents (z).

        if latent_dist == 'gaussian':
            self.head_z_2 = nn.Linear(hidden_dim, latent_dim * 2)
        elif latent_dist == 'concrete':
            self.head_z_2 = nn.Linear(hidden_dim, latent_dim)
        else:
            raise ValueError('Invalid argument for `latent_dist`.')

        self.head_b_1 = nn.Linear(hidden_dim, hidden_dim)  # Boundaries (b).
        self.head_b_2 = nn.Linear(hidden_dim, 1)

        # Decoder MLP.
        self.decode_1 = nn.Linear(latent_dim, hidden_dim)
        self.decode_2 = nn.Linear(hidden_dim, input_dim)

    def masked_encode(self, inputs, mask):
        """Run masked RNN encoder on input sequence."""
        hidden = get_lstm_initial_state(
            inputs.size(0), self.hidden_dim, device=inputs.device)
        outputs = []
        for step in range(inputs.size(1)):
            hidden = self.lstm_cell(inputs[:, step], hidden)
            hidden = (mask[:, step, None] * hidden[0],
                      mask[:, step, None] * hidden[1])  # Apply mask.
            outputs.append(hidden[0])
        return torch.stack(outputs, dim=1)

    def get_boundaries(self, encodings, segment_id, lengths):
        """Get boundaries (b) for a single segment in batch."""
        if segment_id == self.max_num_segments - 1:
            # Last boundary is always placed on last sequence element.
            logits_b = None
            sample_b = torch.zeros_like(encodings[:, :, 0]).scatter_(
                1, lengths.unsqueeze(1) - 1, 1)
        else:
            hidden = F.relu(self.head_b_1(encodings))
            logits_b = self.head_b_2(hidden).squeeze(-1)
            # Mask out first position with large neg. value.
            neg_inf = torch.ones(
                encodings.size(0), 1, device=encodings.device) * NEG_INF
            # TODO(tkipf): Mask out padded positions with large neg. value.
            logits_b = torch.cat([neg_inf, logits_b[:, 1:]], dim=1)
            if self.training:
                sample_b = gumbel_softmax_sample(
                    logits_b, temp=self.temp_b)
            else:
                sample_b_idx = torch.argmax(logits_b, dim=1)
                sample_b = to_one_hot(sample_b_idx, logits_b.size(1))

        return logits_b, sample_b

    def get_latents(self, encodings, probs_b):
        """Read out latents (z) form input encodings for a single segment."""
        readout_mask = probs_b[:, 1:, None]  # Offset readout by 1 to left.
        readout = (encodings[:, :-1] * readout_mask).sum(1)
        hidden = F.relu(self.head_z_1(readout))
        logits_z = self.head_z_2(hidden)

        # Gaussian latents.
        if self.latent_dist == 'gaussian':
            if self.training:
                mu, log_var = torch.split(logits_z, self.latent_dim, dim=1)
                sample_z = gaussian_sample(mu, log_var)
            else:
                sample_z = logits_z[:, :self.latent_dim]

        # Concrete / Gumbel softmax latents.
        elif self.latent_dist == 'concrete':
            if self.training:
                sample_z = gumbel_softmax_sample(
                    logits_z, temp=self.temp_z)
            else:
                sample_z_idx = torch.argmax(logits_z, dim=1)
                sample_z = utils.to_one_hot(sample_z_idx, logits_z.size(1))
        else:
            raise ValueError('Invalid argument for `latent_dist`.')

        return logits_z, sample_z

    def decode(self, sample_z, length):
        """Decode single time step from latents and repeat over full seq."""
        hidden = F.relu(self.decode_1(sample_z))
        pred = self.decode_2(hidden)
        return pred.unsqueeze(1).repeat(1, length, 1)

    def get_next_masks(self, all_b_samples):
        """Get RNN hidden state masks for next segment."""
        if len(all_b_samples) < self.max_num_segments:
            # Product over cumsums (via log->sum->exp).
            log_cumsums = list(
                map(lambda x:log_cumsum(x, dim=1), all_b_samples))
            mask = torch.exp(sum(log_cumsums))
            return mask
        else:
            return None

    def forward(self, inputs, lengths):

        # Embed inputs.
        print(inputs.size())
        embeddings = self.embed(inputs)
        print(embeddings.size())
        # Create initial mask.
        mask = torch.ones(
            inputs.size(0), inputs.size(1), device=inputs.device)
        print(mask.size())
        all_b = {'logits': [], 'samples': []}
        all_z = {'logits': [], 'samples': []}
        all_encs = []
        all_recs = []
        all_masks = []
        for seg_id in range(self.max_num_segments):

            # Get masked LSTM encodings of inputs.
            #print('mask:', mask[:1])
            encodings = self.masked_encode(embeddings, mask)
            print(encodings.size())
            all_encs.append(encodings)

            # Get boundaries (b) for current segment.
            logits_b, sample_b = self.get_boundaries(
                encodings, seg_id, lengths)
            all_b['logits'].append(logits_b)
            all_b['samples'].append(sample_b)
            print(logits_b.size())
            #print(logits_b[:1])
            print(sample_b.size())
            #print(sample_b[:1])
            # Get latents (z) for current segment.
            logits_z, sample_z = self.get_latents(
                encodings, sample_b)
            all_z['logits'].append(logits_z)
            all_z['samples'].append(sample_z)

            # Get masks for next segment.
            mask = self.get_next_masks(all_b['samples'])
            print(mask.size())
            all_masks.append(mask)

            # Decode current segment from latents (z).
            reconstructions = self.decode(sample_z, length=inputs.size(1))
            all_recs.append(reconstructions)

        return all_encs, all_recs, all_masks, all_b, all_z


In [6]:
model = CompILE(
    input_dim=6,  # +1 for EOS/Padding symbol.
    hidden_dim=64,
    latent_dim=32,
    max_num_segments=3,
    latent_dist='gaussian')

input_dim: 6
hidden_dim: 64
latent_dim 32
max_num_segments 3


In [7]:
model.train()
outputs = model.forward(inputs, lengths)

torch.Size([512, 13])
torch.Size([512, 13, 64])
torch.Size([512, 13])
torch.Size([512, 13, 64])
torch.Size([512, 13])
torch.Size([512, 13])
torch.Size([512, 13])
torch.Size([512, 13, 64])
torch.Size([512, 13])
torch.Size([512, 13])
torch.Size([512, 13])
torch.Size([512, 13, 64])


AttributeError: 'NoneType' object has no attribute 'size'