# VAE for text generation

Just a little experiment I'm running to see whether I can build a text VAE myself. I also have an agentic version of this repo, but I want to make sure I can code it from scratch as well.

My goal is a network that can generate synthetic tweets after being trained on an appropriately sized dataset, or one that at least could in principle. I'll need to learn more about setting up recurrent architectures than I currently know, so this is a good exercise for me!

In [None]:
# imports
import torch
import torch.nn as nn

In [None]:
class TextVAE(nn.Module):
    """A recurrent VAE that recreates text. Specifically, intended for tweets.

    The constructor only builds the layers: a shared token embedding, an
    encoder GRU, two linear heads (mu and log variance) into the latent
    space, a linear map from the latent space back to ``hidden_dim``, and a
    decoder GRU.
    """

    def __init__(
        self,
        vocab_size: int,
        embedding_dim: int,
        hidden_dim: int,
        num_layers: int,
        dropout: float,
        latent_dim: int,
    ) -> None:
        """
        Initialize the VAE.

        Inputs:
        ----
        vocab_size: Size of the model vocabulary. Recommend to set equal
                    to the number of unique tokens in the training data.
                    Token id 0 is reserved for padding.
        embedding_dim: Dimension of the space in which words are embedded.
        hidden_dim: Hidden-state size of both GRUs. It is also the input
                    width of the mu / log-variance layers and the output
                    width of the latent-to-decoder layer.
        num_layers: Number of layers in each GRU.
        dropout: Dropout probability for the GRUs. Only used when
                 num_layers > 1, because GRU dropout acts between stacked
                 layers.
        latent_dim: The dimension of the latent space.
        """
        super().__init__()

        # nn.GRU applies dropout to the output of every layer except the
        # last one, so a single-layer GRU never uses it. PyTorch emits a
        # UserWarning when given non-zero dropout with num_layers == 1;
        # passing 0.0 in that case is numerically identical and keeps
        # construction quiet.
        rnn_dropout = dropout if num_layers > 1 else 0.0

        # Encoder
        # ----

        # The embedding layer. Here, token ids are translated to vectors for
        # use in the model. padding_idx=0 marks id 0 as padding: its vector
        # is initialised to zeros and receives no gradient updates.
        # Input: integer tensor (..., seq_length)
        # Output: tensor (..., seq_length, embedding_dim)
        self.embedding = nn.Embedding(
            vocab_size,
            embedding_dim,
            padding_idx=0,
        )

        # GRU, the recurrent part of the encoder. This allows us to use text
        # data with the VAE, since recurrent architectures are well-suited to
        # such things. batch_first=True puts the batch dimension first for
        # batched input.
        # Input: tensor (batch, seq_length, embedding_dim)
        # Output: tensor (batch, seq_length, hidden_dim), plus the final
        #         hidden state (num_layers, batch, hidden_dim)
        self.encoder_rnn = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )

        # Latent space
        # ----

        # Fully connected layers for mu and log variance. In both cases:
        # Input: tensor (..., hidden_dim)
        # Output: tensor (..., latent_dim)
        # NOTE(review): which encoder tensor is fed in (per-step outputs or
        # the final hidden state) is decided by the forward pass, which is
        # not part of this constructor -- confirm there.
        self.fc_mu = nn.Linear(
            in_features=hidden_dim,
            out_features=latent_dim,
        )
        self.fc_logvar = nn.Linear(
            in_features=hidden_dim,
            out_features=latent_dim,
        )

        # Decoder
        # ----

        # Fully connected layer to go from latent space to decoder GRU.
        # Input: tensor (..., latent_dim)
        # Output: tensor (..., hidden_dim)
        # NOTE(review): presumably used to build the decoder's initial hidden
        # state -- confirm against the forward pass.
        self.fc_decoder_init = nn.Linear(
            in_features=latent_dim,
            out_features=hidden_dim,
        )

        # GRU, the recurrent part of the decoder. Same configuration as the
        # encoder GRU.
        # Input: tensor (batch, seq_length, embedding_dim)
        # Output: tensor (batch, seq_length, hidden_dim), plus the final
        #         hidden state (num_layers, batch, hidden_dim)
        self.decoder_rnn = nn.GRU(
            input_size=embedding_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            dropout=rnn_dropout,
            batch_first=True,
            bidirectional=False,
        )


