In [1]:
import torch.nn as nn

class EncoderRNN(nn.Module):
    """GRU encoder that embeds and encodes a single token per ``forward`` call.

    Args:
        input_size: Number of rows in the embedding table (vocabulary size).
        hidden_size: Width of both the embedding vectors and the GRU state.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input, hidden):
        """Run one GRU step.

        Args:
            input: Token index tensor; its embedding is reshaped to
                ``(1, 1, hidden_size)``, so exactly one token is expected.
            hidden: Previous GRU state of shape ``(1, 1, hidden_size)``.

        Returns:
            Tuple ``(output, hidden)`` as produced by the GRU.
        """
        # (seq_len=1, batch=1, hidden_size) is the layout nn.GRU expects by default.
        embedded = self.embedding(input).view(1, 1, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape ``(1, 1, hidden_size)``.

        The state is allocated from the embedding weights so that it always
        lives on the same device (and has the same dtype) as the module,
        instead of depending on a notebook-global ``device`` variable.
        """
        return self.embedding.weight.new_zeros(1, 1, self.hidden_size)

In [2]:
class DecoderRNN(nn.Module):
    """GRU decoder that predicts a distribution over the next token.

    Args:
        hidden_size: Width of both the embedding vectors and the GRU state.
        output_size: Size of the output vocabulary.
    """

    def __init__(self, hidden_size, output_size):
        super().__init__()
        self.hidden_size = hidden_size

        self.embedding = nn.Embedding(output_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, input, hidden):
        """Run one decoding step.

        Args:
            input: Token index tensor; its embedding is reshaped to
                ``(1, 1, hidden_size)``, so exactly one token is expected.
            hidden: Previous GRU state of shape ``(1, 1, hidden_size)``.

        Returns:
            Tuple ``(log_probs, hidden)`` where ``log_probs`` has shape
            ``(1, output_size)`` and holds log-softmax scores.
        """
        output = self.embedding(input).view(1, 1, -1)
        # Use nn.functional explicitly: the notebook's ``F`` alias is bound to
        # ``torch.functional``, which does not provide ``relu``.
        output = nn.functional.relu(output)
        output, hidden = self.gru(output, hidden)
        # output[0] drops the length-1 sequence axis before the projection.
        output = self.softmax(self.out(output[0]))
        return output, hidden

    def initHidden(self):
        """Return an all-zero initial hidden state of shape ``(1, 1, hidden_size)``.

        Allocated from the embedding weights so it matches the module's device
        and dtype without relying on a notebook-global ``device`` variable.
        """
        return self.embedding.weight.new_zeros(1, 1, self.hidden_size)

In [4]:
class DiscriminatorEncoder(EncoderRNN):
    """Encoder half of the discriminator; inherits EncoderRNN unchanged."""

class DiscriminatorDecoder(DecoderRNN):
    """Decoder half of the discriminator, built on ``DecoderRNN``.

    Args:
        hidden_size: Forwarded to ``DecoderRNN``; also the input width of
            ``fc_out``.
        output_size: Forwarded to ``DecoderRNN``.
    """

    def __init__(self, hidden_size, output_size):
        # DecoderRNN requires both sizes; calling super().__init__() without
        # them always raised TypeError.
        super().__init__(hidden_size, output_size)
        # Scalar head over hidden-sized features.
        # NOTE(review): fc_out is created but not applied in forward() below —
        # confirm where the per-token score is supposed to be computed.
        self.fc_out = nn.Linear(self.hidden_size, 1)

    def forward(self, prev_output_tokens, encoder_out_dict):
        """Delegate one decoding step to ``DecoderRNN.forward``.

        Args:
            prev_output_tokens: Passed to the parent as its ``input``.
            encoder_out_dict: Passed to the parent as its ``hidden`` state.

        Returns:
            The parent's ``(output, hidden)`` pair, unchanged.
        """
        x, hidden = super().forward(prev_output_tokens, encoder_out_dict)
        return x, hidden
    
class Discriminator(nn.Module):
    """Encoder/decoder pair wrapped as a single module.

    Args:
        encoder: Module called as ``encoder(src_tokens, src_lengths)``.
        decoder: Module called as ``decoder(prev_output_tokens, encoder_out)``.
    """

    def __init__(self, encoder, decoder):
        # ``super`` must be *called*; ``super.__init__()`` raises TypeError and
        # leaves the nn.Module machinery uninitialised.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Encode the source, then decode conditioned on the encoder output.

        Returns:
            Whatever ``decoder`` returns.
        """
        # NOTE(review): EncoderRNN.forward returns an (output, hidden) tuple and
        # takes a hidden state as its second argument — confirm the
        # encoder/decoder passed in here follow this call convention.
        encoder_out = self.encoder(src_tokens, src_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out)
        return decoder_out

In [5]:
from base import Model, Encoder, Decoder

In [6]:
class DiscriminatorEncoder(Encoder):
    """Encoder half of the discriminator; inherits ``base.Encoder`` unchanged."""

class DiscriminatorDecoder(Decoder):
    """Decoder half of the discriminator, built on ``base.Decoder``.

    NOTE(review): relies on the parent constructor setting ``self.hidden_size``
    and accepting no arguments — confirm against ``base.Decoder``.
    """

    def __init__(self):
        super().__init__()
        # Scalar projection head; not applied anywhere in this class.
        self.fc_out = nn.Linear(self.hidden_size, 1)

    def forward(self, prev_output_tokens, encoder_out_dict):
        """Delegate to the parent decoder and return its two results as a pair."""
        decoded, scores = super().forward(prev_output_tokens, encoder_out_dict)
        return decoded, scores
    
class Discriminator(nn.Module):
    """Encoder/decoder pair wrapped as a single module.

    Args:
        encoder: Module called as ``encoder(src_tokens, src_lengths)``.
        decoder: Module called as ``decoder(prev_output_tokens, encoder_out)``.
    """

    def __init__(self, encoder, decoder):
        # ``super`` must be *called*; ``super.__init__()`` raises TypeError and
        # leaves the nn.Module machinery uninitialised.
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, src_tokens, src_lengths, prev_output_tokens):
        """Encode the source, then decode conditioned on the encoder output.

        Returns:
            Whatever ``decoder`` returns.
        """
        encoder_out = self.encoder(src_tokens, src_lengths)
        decoder_out = self.decoder(prev_output_tokens, encoder_out)
        return decoder_out

In [2]:
import torch
import torch.functional as F

In [None]:

def discriminator_loss(predictions, labels, missing_tokens):
    """Discriminator loss based on predictions and labels.

    PyTorch port of a masked sigmoid cross-entropy: the per-token loss is
    weighted by ``missing_tokens`` and the weighted sum is divided by the
    number of non-zero weights.

    Args:
        predictions: Discriminator linear predictions (logits), Tensor of
            shape [batch_size, sequence_length].
        labels: Labels for predictions, Tensor of shape [batch_size,
            sequence_length].
        missing_tokens: Indicator for the missing tokens. The loss is
            evaluated only on the tokens that were missing.

    Returns:
        loss: Scalar torch Tensor. It is 0 when no token is marked missing.
    """
    per_token = torch.nn.functional.binary_cross_entropy_with_logits(
        predictions, labels.to(predictions.dtype), reduction="none")
    weights = missing_tokens.to(per_token.dtype)
    # Clamp so an all-zero mask yields 0 (numerator is 0) instead of 0/0 = NaN.
    num_weighted = (weights != 0).sum().clamp(min=1)
    return (per_token * weights).sum() / num_weighted



def cross_entropy_loss_matrix(gen_labels, gen_logits):
    """Computes the mean cross entropy loss for G.

    Args:
        gen_labels: Labels for the correct token. Either a dense distribution
            (e.g. one-hot) with the same shape as ``gen_logits``, or class
            indices with one dimension fewer than ``gen_logits``.
        gen_logits: Generator logits; the last axis indexes the classes.

    Returns:
        mean_loss: Scalar Tensor, the per-token cross entropy averaged over
        all tokens.
    """
    log_probs = torch.nn.functional.log_softmax(gen_logits, dim=-1)
    if gen_labels.dim() == gen_logits.dim() - 1:
        # Index labels: pick the log-probability of the labelled class.
        picked = log_probs.gather(-1, gen_labels.long().unsqueeze(-1))
        per_token = -picked.squeeze(-1)
    else:
        # Dense labels: expectation of -log p under the label distribution.
        per_token = torch.sum(-gen_labels * log_probs, dim=-1)
    return per_token.mean()


def GAN_loss_matrix(dis_predictions, eps=1e-7):
    """Computes the element-wise GAN loss ``-log(p + eps)`` for G.

    Args:
        dis_predictions: Discriminator predictions (torch Tensor).
        eps: Small constant added before the log so that a prediction of
            exactly 0 gives a finite loss.

    Returns:
        loss_matrix: Tensor with the same shape as ``dis_predictions``.
    """
    return -torch.log(dis_predictions + eps)


def generator_GAN_loss(predictions):
    """Negative log of the mean Discriminator prediction (generator GAN loss)."""
    mean_prediction = torch.mean(predictions)
    log_mean = torch.log(mean_prediction)
    return -log_mean


def generator_blended_forward_loss(gen_logits, gen_labels, dis_predictions,
                                   is_real_input):
    """Computes the masked-loss for G.

    This is a blend of cross-entropy loss where the true label is known and
    GAN loss where the true label has been masked.

    Args:
        gen_logits: Generator logits; the last axis indexes the classes.
        gen_labels: Class indices of the correct tokens, one per position.
        dis_predictions: Discriminator predictions, same shape as
            ``gen_labels``.
        is_real_input: Tensor indicating whether the label is present.

    Returns:
        loss: Scalar torch Tensor, the mean of the blended loss matrix.
    """
    vocab_size = gen_logits.size(-1)
    # cross_entropy wants (N, C) logits and (N,) targets; restore the token
    # layout afterwards so the result lines up with dis_predictions.
    cross_entropy_loss = torch.nn.functional.cross_entropy(
        gen_logits.reshape(-1, vocab_size),
        gen_labels.reshape(-1).long(),
        reduction="none",
    ).reshape(gen_labels.shape)
    gan_loss = -torch.log(dis_predictions)
    # torch.where requires a boolean condition.
    loss_matrix = torch.where(is_real_input.bool(), cross_entropy_loss, gan_loss)
    return loss_matrix.mean()


def wasserstein_generator_loss(gen_logits, gen_labels, dis_values,
                               is_real_input):
    """Computes the masked-loss for G.

    This is a blend of cross-entropy loss where the true label is known and
    Wasserstein GAN loss where the true label is missing.

    Args:
        gen_logits: Generator logits; the last axis indexes the classes.
        gen_labels: Class indices of the correct tokens, one per position.
        dis_values: Discriminator values Tensor of shape [batch_size,
            sequence_length].
        is_real_input: Tensor indicating whether the label is present.

    Returns:
        loss: Scalar torch Tensor, the mean of the blended loss matrix.
    """
    vocab_size = gen_logits.size(-1)
    # cross_entropy wants (N, C) logits and (N,) targets; restore the token
    # layout afterwards so the result lines up with dis_values.
    cross_entropy_loss = torch.nn.functional.cross_entropy(
        gen_logits.reshape(-1, vocab_size),
        gen_labels.reshape(-1).long(),
        reduction="none",
    ).reshape(gen_labels.shape)
    # Maximize the dis_values (minimize the negative).
    gan_loss = -dis_values
    # torch.where requires a boolean condition.
    loss_matrix = torch.where(is_real_input.bool(), cross_entropy_loss, gan_loss)
    return loss_matrix.mean()


def wasserstein_discriminator_loss(real_values, fake_values):
    """Wasserstein discriminator loss.

    Args:
        real_values: Value given by the Wasserstein Discriminator to real data.
        fake_values: Value given by the Wasserstein Discriminator to fake data.

    Returns:
        loss: Scalar torch Tensor, ``mean(real_values) - mean(fake_values)``.
    """
    # NOTE(review): wasserstein_generator_loss in this file minimizes
    # -dis_values; confirm the sign convention here is the intended one.
    real_avg = torch.mean(real_values)
    fake_avg = torch.mean(fake_values)
    return real_avg - fake_avg


def wasserstein_discriminator_loss_intrabatch(values, is_real_input):
    """Wasserstein discriminator loss within a single batch.

    This is an odd variant where the value difference is between the real
    tokens and the fake tokens within a single batch.

    Args:
        values: Value given by the Wasserstein Discriminator of shape
            [batch_size, sequence_length] to an imputed batch (real and fake).
        is_real_input: Boolean Tensor of shape [batch_size, sequence_length].
            If true, it indicates that the label is known.

    Returns:
        wasserstein_loss: Scalar torch Tensor, the average value of real
        tokens minus the average value of fake tokens. A group with no
        entries contributes an average of zero.
    """
    present = is_real_input.to(values.dtype)
    missing = 1 - present

    # Counts for real and fake tokens.
    real_count = present.sum()
    fake_count = missing.sum()

    # Averages for real and fake token values. Clamping the count to at least
    # 1 gives 0 / 1 = 0 for an empty group, which is the intended fallback,
    # without ever dividing by zero.
    real_avg = (values * present).sum() / real_count.clamp(min=1)
    fake_avg = (values * missing).sum() / fake_count.clamp(min=1)

    return real_avg - fake_avg

In [3]:
import torchnlp.datasets as ds

In [7]:
from torchnlp.datasets import imdb_dataset
from torchnlp.datasets import penn_treebank_dataset

In [5]:
# Load the IMDB training split via torchnlp. The recorded output below shows
# this downloading aclImdb_v1.tar.gz (~84 MB), so a fresh run may be slow.
train = imdb_dataset(train=True)

aclImdb_v1.tar.gz: 84.1MB [00:56, 1.49MB/s]                            


In [6]:
# Peek at the first training example; the recorded output shows a dict with a
# 'text' field holding the raw review.
train[0]

{'text': 'The story of the boy thief of Bagdad (as it was once spelled) has attracted filmmakers from Raoul Walsh in 1924, who starred Douglas Fairbanks in the first, silent, rendering of "Thief of Bagdad," to less imposing, more recent attempts. The best, however, remains 1940\'s version which for its time was a startling, magical panoply of top quality special effects. Those effects still work their charm.<br /><br />No less than six directors are listed for the technicolor movie which starred Sabu as the boy thief, Abu, John Justin as the dreamily in love deposed monarch, Ahmad and June Duprez as the lovely princess sought by Ahmad and pursued by the evil vizier, Jaffar, played by a sinister Conrad Veidt. The giant genie is ably acted by Rex Ingram.<br /><br />Ahmad is treacherously deposed by Jaffar and when later arrested by that traitorous serpent, he and the boy, Abu, suffer what are clearly incapacitating fates. Ahmad is rendered blind and Abu becomes a lovable mutt. Their adve