# VAE for text generation

Just a little experiment I'm running to see if I can do it. I have an agentic version of this repo too, but I want to make sure I can code it from scratch as well.

My goal is to get a network that can generate synthetic tweets after being trained on an appropriately sized dataset... or at least one that could, in principle. I'll need to learn more than I currently know about setting up recurrent architectures, so this is a good exercise for me!

In [None]:
# imports
import torch
import torch.nn as nn

In [None]:
class TextVAE(nn.Module):
    # A VAE that recreates text. Specifically, I'll use it for tweets.

    def __init__(
        self,
        vocab_size,
        embedding_dim,
        hidden_dim,        
        num_layers,
        dropout,
        latent_dim
    ):
        """
        Initialize the VAE.

        Inputs:
        ----
        vocab_size: Size of the model vocabulary. Recommend to set equal 
                    to the number of unique tokens in the training data.
        embedding_dim: Dimension of the space in which tokens are embedded.
        hidden_dim: Dimension of the hidden layer between the GRU and the
                    latent space.
        num_layers: Number of layers in each GRU.
        dropout: Dropout probability for the GRU.
        latent_dim: Dimension of the latent space
        """
        super().__init__()

        # Embedding layer: Convert tokens into vectors. Used in both encoder and decoder.
        # ----
        # Input: tensor (batch_size, seq_length)
        # Output: tensor (batch_size, seq_length, embedding_dim)
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=0,
        )

        # Encoder
        # ----

        # Encoder GRU: Reads tokens into GRU, a recurrent module which can handle the
        # sequence of tokens. Since we're using the hidden states, be mindful to transpose
        # things so that the tensor has the right shape afterwards!
        # Input: tensor (batch_size, seq_length, embedding_dim)
        # Output: output tensor (batch_size, seq_length, hidden_dim)
        #         hidden state tensor (num_layers, batch_size, hidden_dim)
        self.encoder_gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Latent space
        # ----

        # Fully connected layers: One each for mu and log variance. Note that 
        # this only uses the hidden state of the last token in the sequence! 
        # The relevant dimensional shuffling is done in the encoding function. 
        # In both cases:
        # Input: tensor (batch_size, hidden_dim)
        # Output: tensor (batch_size, latent_dim)
        self.fc_mu = nn.Linear(
            in_features=hidden_dim, 
            out_features=latent_dim,
        )
        self.fc_logvar = nn.Linear(
            in_features=hidden_dim, 
            out_features=latent_dim,
        )

        # Decoder
        # ----
        
        # Fully connected layer: Going from the latent space to the decoder's GRU.
        # Input: tensor (batch_size, seq_length, latent_dim)
        # Output: tensor (batch_size, seq_length, hidden_dim)
        self.fc_decoder_init = nn.Linear(
            in_features=latent_dim,
            out_features=hidden_dim
        )

        # # Decoder embedding layer: Convert max hidden state indices (i.e. tokens) into vectors
        # # Input: tensor (batch_size, seq_length)
        # # Output: tensor (batch_size, seq_length, embedding_dim)
        # self.decoder_embedding = nn.Embedding(
        #     vocab_size,
        #     embedding_dim,
        #     padding_idx=0,
        # )

        # Decoder GRU: The recurrent part of the decoder. Just like the encoder!
        # Input: tensor (batch_size, seq_length, embedding_dim)
        # Output: output tensor (batch_size, seq_length, hidden_dim)
        #         hidden state tensor (num_layers, batch_size, hidden_dim)
        self.decoder_gru = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Output layer: Convert GRU hidden states into token logits. Note that 
        # this only uses the hidden state of the last token in the sequence! 
        # The relevant dimensional shuffling is done in the decoding function.
        # Input: tensor (batch_size, hidden_dim)
        # Output: tensor (batch_size, vocab_size)
        self.fc_decoder_output = nn.Linear(
            in_features=hidden_dim,
            out_features=vocab_size
        )
    
    def encode(self, x, lengths):
        """
        Encode input sequences into a latent distribution, i.e. mu and log variance.

        Inputs:
        ----
        x: Input tensor of token indices. Shape: (batch_size, seq_length)
        lengths: Actual lengths of the sequences before padding. Shape: (batch_size)

        Outputs:
        ----
        Mu, logvar: Tensors representing the parameters of the latent distribution. Shape: (batch_size, latent_dim) 
        """

        # Embed the sequences.
        embedded = self.embedding(x)

        # Pack the sequences for GRU processing. This is a PyTorch data structure that
        # helps the GRU ignore padding tokens and process things faster.
        packed = nn.utils.rnn.pack_padded_sequence(
            input=embedded,
            lengths=lengths.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )

        # Encode the packed sequence through the GRU.
        _, hidden = self.encoder_gru(packed)

        # Get the hidden state of the last layer.
        last_hidden = hidden[-1:].squeeze(0).contiguous()
        last_hidden = last_hidden.view(last_hidden.size(0), -1)

        # Get the mean and variance.
        mu = self.fc_mu(last_hidden)
        logvar = self.fc_logvar(last_hidden)

        return mu, logvar

    def latent_sample(self, mu, logvar):
        """
        Sample from the latent distribution using the reparameterization trick.

        Inputs:
        ----
        mu: Mean tensor of the latent distribution. Shape: (batch_size, latent_dim)
        logvar: Log variance tensor of the latent distribution. Shape: (batch_size, latent_dim)

        Outputs:
        ----
        z: Sampled latent tensor. Shape: (batch_size, latent_dim)
        """

        # Get the standard deviation.
        std = torch.exp(0.5 * logvar)

        # Add Gaussian noise.
        eps = torch.randn_like(std)

        # Now return the latent vector.
        z = mu + eps * std

        return z