# Model explorations

In [10]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# Show up to 400 characters per cell so long text columns are not truncated
# when DataFrames are displayed. The fully qualified option name is used:
# abbreviated names are resolved by pattern matching and stop working if
# pandas ever adds another option whose name contains "max_colwidth".
pd.set_option('display.max_colwidth', 400)
# Switch matplotlib to its 'fivethirtyeight' style sheet for subsequent plots.
plt.style.use('fivethirtyeight')

## Language model

In [11]:
# Directory holding the pretrained language-model checkpoints (resolved
# relative to the current working directory).
models_path = '../deep-latent-sequence-model/pretrained_lm'
# Checkpoint file of the Yelp "style 0" language model. The path components
# are passed separately so os.path.join supplies the platform's separator.
yelp_path_0 = os.path.join(models_path, 'yelp_style0', 'model.pt')

In [12]:
class LSTM_LM(nn.Module):
  """Single-layer LSTM language model.

  The model embeds each input token, runs the embeddings through one
  batch-first LSTM layer and projects every hidden state onto the vocabulary.

  Args:
    model_init: callable applied in place to every parameter tensor.
    emb_init: callable applied in place to the embedding matrix afterwards.
    hparams: object exposing ``d_model``, ``d_word_vec``, ``src_vocab_size``,
      ``pad_id``, ``dropout_in``, ``dropout_out`` and ``tie_weight``.
  """

  def __init__(self, model_init, emb_init, hparams):
    super().__init__()
    self.nh = hparams.d_model
    self.embed = nn.Embedding(hparams.src_vocab_size,
                              hparams.d_word_vec,
                              padding_idx=hparams.pad_id)

    self.dropout_in = nn.Dropout(hparams.dropout_in)
    self.dropout_out = nn.Dropout(hparams.dropout_out)

    self.lstm = nn.LSTM(input_size=hparams.d_word_vec,
                        hidden_size=hparams.d_model,
                        num_layers=1,
                        batch_first=True)

    # prediction layer
    self.pred_linear = nn.Linear(self.nh, hparams.src_vocab_size, bias=True)

    if hparams.tie_weight:
      # Share one matrix between input embedding and output projection; the
      # shapes only agree when d_word_vec == d_model.
      self.pred_linear.weight = self.embed.weight

    # Per-token losses (reduction="none") so they can be summed per sentence;
    # targets equal to pad_id contribute zero.
    self.loss = nn.CrossEntropyLoss(ignore_index=hparams.pad_id, reduction="none")

    self.reset_parameters(model_init, emb_init)

  def reset_parameters(self, model_init, emb_init):
    """Re-initialise all parameters.

    ``model_init`` is applied to every parameter, then ``emb_init`` overrides
    the embedding matrix, and finally the output bias is set to zero.
    """
    for param in self.parameters():
      model_init(param)
    emb_init(self.embed.weight)

    self.pred_linear.bias.data.zero_()

  def decode(self, x, x_len, gumbel_softmax=False):
    """Run the LSTM over a padded batch and return vocabulary logits.

    Args:
      x: (batch_size, seq_len) token ids, or, when ``gumbel_softmax`` is
        True, (batch_size, seq_len, vocab_size) weights over the vocabulary.
      x_len: list of x lengths. They are handed to ``pack_padded_sequence``
        with its default settings, which expect decreasing order.
      gumbel_softmax: if True, embed ``x`` as a weighted sum of embedding
        rows instead of an index lookup.
    Returns:
      output_logits: (batch_size, seq_len, vocab_size)
    """
    rnn_utils = nn.utils.rnn

    if gumbel_softmax:
      batch_size, seq_len, _ = x.size()
      # (batch_size, seq_len, vocab_size) @ (vocab_size, ni)
      word_embed = x @ self.embed.weight
    else:
      batch_size, seq_len = x.size()

      # (batch_size, seq_len, ni)
      word_embed = self.embed(x)

    word_embed = self.dropout_in(word_embed)
    packed_embed = rnn_utils.pack_padded_sequence(word_embed, x_len,
                                                  batch_first=True)

    h_init = word_embed.new_zeros((1, batch_size, self.nh))
    c_init = word_embed.new_zeros((1, batch_size, self.nh))
    packed_output, _ = self.lstm(packed_embed, (h_init, c_init))

    # total_length keeps the time dimension equal to seq_len even when the
    # batch is padded beyond its longest sequence, so the logits always line
    # up with the targets built from x.
    output, _ = rnn_utils.pad_packed_sequence(packed_output,
                                              batch_first=True,
                                              total_length=seq_len)

    output = self.dropout_out(output)

    # (batch_size, seq_len, vocab_size)
    return self.pred_linear(output)

  def reconstruct_error(self, x, x_len, gumbel_softmax=False, x_mask=None):
    """Cross entropy of each sentence under the language model.

    Position t of ``x`` (end symbol dropped) is used to predict position
    t + 1 (start symbol dropped).

    Args:
      x: (batch_size, seq_len) token ids, or (batch_size, seq_len, vocab_size)
        weights over the vocabulary when ``gumbel_softmax`` is True.
      x_len: list of x lengths
      gumbel_softmax: selects the soft-input code path.
      x_mask: required if gumbel_softmax is True, 1 denotes mask,
              size (batch_size, seq_len)
    Returns:
      loss: (batch_size). Loss across different sentences
    Raises:
      ValueError: if ``gumbel_softmax`` is True and ``x_mask`` is None.
    """
    if gumbel_softmax and x_mask is None:
      raise ValueError("x_mask is required when gumbel_softmax is True")

    # remove end symbol
    src = x[:, :-1]

    # remove start symbol
    tgt = x[:, 1:]

    batch_size = src.size(0)

    # one position fewer per sentence, since the last symbol is never an input
    x_len = [s - 1 for s in x_len]

    # (batch_size, seq_len, vocab_size)
    output_logits = self.decode(src, x_len, gumbel_softmax)

    if gumbel_softmax:
      log_p = output_logits.log_softmax(dim=2)
      x_mask = x_mask[:, 1:]
      # expected negative log-likelihood under the soft targets, with masked
      # positions zeroed out
      loss = -((log_p * tgt).sum(dim=2) * (1. - x_mask)).sum(dim=1)
    else:
      flat_logits = output_logits.reshape(-1, output_logits.size(2))
      # (batch_size * seq_len)
      loss = self.loss(flat_logits, tgt.reshape(-1))
      loss = loss.view(batch_size, -1).sum(-1)

    # (batch_size)
    return loss

  def compute_gumbel_logits(self, x, x_len):
    """Logits for soft (Gumbel-softmax) inputs, with the end symbol dropped.

    Args:
      x: (batch_size, seq_len, vocab_size)
      x_len: list of x lengths
    Returns:
      output_logits: (batch_size, seq_len - 1, vocab_size)
    """
    # remove end symbol
    src = x[:, :-1]

    x_len = [s - 1 for s in x_len]

    return self.decode(src, x_len, True)

  def log_probability(self, x, x_len, gumbel_softmax=False, x_mask=None):
    """Log-probability of each sentence: the negated reconstruction error.

    Args:
      x: (batch_size, seq_len)
      x_len: list of x lengths
      gumbel_softmax: forwarded to ``reconstruct_error``.
      x_mask: forwarded to ``reconstruct_error``.
    Returns:
      log_p: (batch_size).
    """
    return -self.reconstruct_error(x, x_len, gumbel_softmax, x_mask)

In [14]:
# Deserialize the pretrained Yelp style-0 language model from disk.
# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source. Presumably the checkpoint stores the whole
# LSTM_LM module rather than a state_dict (the repr printed below suggests
# so) — confirm; recent PyTorch releases may need weights_only=False for that.
lm = torch.load(yelp_path_0)

In [15]:
# Evaluate the bare name so the notebook prints the module's layer summary.
lm

LSTM_LM(
  (embed): Embedding(9653, 128, padding_idx=0)
  (dropout_in): Dropout(p=0.3, inplace=False)
  (dropout_out): Dropout(p=0.3, inplace=False)
  (lstm): LSTM(128, 512, batch_first=True)
  (pred_linear): Linear(in_features=512, out_features=9653, bias=True)
  (loss): CrossEntropyLoss()
)

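The language model loaded above is a simple single-layer LSTM network: a 128-dimensional embedding feeds a 512-unit LSTM, whose outputs are projected by a linear layer onto the 9,653-word vocabulary.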