# Text VAE

## Variational Autoencoder (VAE)

VAEs are probabilistic generative models that learn a latent representation of the input data.
By sampling from the learned distribution, we can generate new data.
In this notebook, we will build a VAE that can generate new text after being trained on a text corpus.

## Encoder Network

The encoder network takes the input data and encodes it into a latent space.
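In a VAE, the encoder network does not encode the input into a fixed code.
Instead, it produces the parameters of a probability distribution: a mean and a variance for each latent dimension.
The loss functions below treat the second encoder output as a *log*-variance. A linear layer can output any real number, whereas a variance must be positive, so predicting its logarithm avoids that constraint.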
The latent representation of the input is then sampled from this distribution.

In [None]:
# This is the first code cell, and every cell below uses torch, nn and F.
# None of them were imported before use, so bring them into scope here.
import torch
import torch.nn.functional as F
from torch import nn


class Encoder(nn.Module):
    """Map a flat input vector to the parameters of a diagonal Gaussian over z.

    Args:
        input_dim: size of each (flat) input vector.
        hidden_dim: width of the single hidden layer.
        z_dim: dimensionality of the latent space.

    ``forward`` returns ``(z_mu, z_log_var)``. Both heads are plain linear
    layers, so the second output is unconstrained and can be negative.
    It must therefore be interpreted as a *log*-variance (as the loss
    functions in this notebook do), not as a variance.
    """

    def __init__(self, input_dim, hidden_dim, z_dim):
        super(Encoder, self).__init__()
        self.linear = nn.Linear(input_dim, hidden_dim)
        self.mu = nn.Linear(hidden_dim, z_dim)
        # The attribute keeps the name ``var`` so state_dict keys are unchanged;
        # its output is the log-variance (see class docstring).
        self.var = nn.Linear(hidden_dim, z_dim)

    def forward(self, x):
        hidden = torch.relu(self.linear(x))
        z_mu = self.mu(hidden)
        z_log_var = self.var(hidden)
        return z_mu, z_log_var

## Decoder Network

The decoder network takes a sample from the latent space and decodes it back into the original input space.
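In a VAE, the decoder network is also a probabilistic model: it outputs the parameters of a probability distribution over the input space.
Here, the sigmoid output gives one Bernoulli probability per feature, which pairs with the binary cross-entropy reconstruction loss used below.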

In [None]:
class Decoder(nn.Module):
    """Map latent codes back to per-feature probabilities in (0, 1).

    Args:
        z_dim: dimensionality of the latent input.
        hidden_dim: width of the single hidden layer.
        output_dim: number of output features (one probability each).
    """

    def __init__(self, z_dim, hidden_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(z_dim, hidden_dim)    # latent -> hidden
        self.out = nn.Linear(hidden_dim, output_dim)  # hidden -> output logits

    def forward(self, x):
        logits = self.out(self.linear(x).relu())
        # Sigmoid turns logits into Bernoulli probabilities for the BCE loss.
        return logits.sigmoid()

## Variational Autoencoder (VAE) Network

The VAE network is composed of the encoder and the decoder.
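The encoder maps the input data to a latent distribution, a sample is drawn from this distribution, and the decoder maps the sample back into the original input space.
The sample is drawn with the reparameterization trick, $z = \mu + \sigma \odot \epsilon$ with $\epsilon \sim \mathcal{N}(0, I)$, so that gradients can flow through $\mu$ and $\sigma$.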

In [None]:
class VAE(nn.Module):
    """Dense VAE: encode, draw a reparameterized latent sample, then decode.

    Args:
        input_dim: size of each flat input vector (also the output size).
        hidden_dim: hidden width of the encoder and the decoder.
        z_dim: dimensionality of the latent space.

    ``forward`` returns ``(reconstruction, z_mu, z_log_var)``. The encoder's
    second output is unconstrained, so it is interpreted as a log-variance.
    This matches how ``calculate_loss`` uses it in its KL term.
    """

    def __init__(self, input_dim, hidden_dim, z_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(input_dim, hidden_dim, z_dim)
        self.decoder = Decoder(z_dim, hidden_dim, input_dim)

    def forward(self, x):
        z_mu, z_log_var = self.encoder(x)
        # exp(0.5 * log_var) is always positive. The previous sqrt() of the raw
        # linear output produced NaN whenever that output was negative.
        std = torch.exp(0.5 * z_log_var)
        eps = torch.randn_like(std)
        # Reparameterization trick: z = mu + sigma * eps keeps z differentiable.
        z = z_mu + eps * std
        predicted = self.decoder(z)
        return predicted, z_mu, z_log_var

## Training

To train a VAE, we need to define a loss function that takes into account both the reconstruction loss and the KL divergence.
The reconstruction loss measures how well the VAE can reconstruct the original input from the latent sample,
and the KL divergence measures how closely the learned latent distribution matches the prior distribution (a standard normal distribution in this case).

In [None]:
def calculate_loss(x, reconstructed_x, mean, log_var):
    """Negative ELBO for a Bernoulli-decoder VAE, summed over the batch.

    Args:
        x: target tensor with values in [0, 1].
        reconstructed_x: predicted probabilities in [0, 1], same shape as ``x``.
        mean: latent means.
        log_var: latent log-variances, same shape as ``mean``.

    Returns:
        Scalar tensor: summed binary cross-entropy plus the closed-form
        KL(N(mean, exp(log_var)) || N(0, I)).
    """
    # Neither torch nor F is imported before this cell in the notebook.
    import torch
    import torch.nn.functional as F

    # Reconstruction loss
    RCL = F.binary_cross_entropy(reconstructed_x, x, reduction='sum')
    # KL divergence loss
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return RCL + KLD

In [None]:
def train(model, dataloader, optimizer, num_epochs, input_dim=784):
    """Train the dense VAE using ``calculate_loss`` (BCE + KL).

    Args:
        model: a VAE whose forward returns ``(reconstruction, mean, log_var)``.
        dataloader: yields ``(inputs, labels)`` pairs; labels are ignored.
        optimizer: optimizer that updates ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        input_dim: number of features each input is flattened to. The default
            of 784 presumably corresponds to 28x28 images; pass the model's
            ``input_dim`` for other data.
    """
    model.train()
    for epoch in range(num_epochs):
        for i, (x, _) in enumerate(dataloader):
            # reshape (unlike view) also accepts non-contiguous batches.
            x = x.reshape(-1, input_dim)
            x_reconstructed, mean, log_var = model(x)
            loss = calculate_loss(x, x_reconstructed, mean, log_var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i+1) % 100 == 0:
                print(f'Epoch[{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')
## Data Preparation

### Loading the Text Corpus

In [None]:
import pandas as pd

# Load the Quora dataset
quora_data = pd.read_csv('test.csv')

# Combine the question1 and question2 columns into a single list of documents.
# dropna() removes missing cells (NaN / NA) directly, instead of comparing their
# string form against 'nan'. astype(str) guarantees the tokenizer only ever sees
# strings, because a large CSV can be read with mixed value types in an object column.
documents = (
    quora_data['question1'].dropna().astype(str).tolist()
    + quora_data['question2'].dropna().astype(str).tolist()
)

### Text Preprocessing

In [None]:
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
import string

# Download the necessary NLTK resources.
# Recent NLTK releases load word_tokenize's model from 'punkt_tab', while older
# releases use 'punkt', so fetch both. An unknown package name only makes
# nltk.download report an error and return False; it does not raise.
nltk.download('punkt')
nltk.download('punkt_tab')
nltk.download('wordnet')
nltk.download('stopwords')

stop_words = set(stopwords.words('english'))
lemmatizer = WordNetLemmatizer()

def preprocess_text(text):
    """Turn raw text into a list of normalized content words.

    Steps: NLTK word tokenization, then keep only purely alphabetic tokens
    (lowercased), then drop English stopwords, then WordNet-lemmatize.
    """
    # Punctuation and numbers fail isalpha(); lowercase before the stopword check.
    lowered = (token.lower() for token in word_tokenize(text) if token.isalpha())
    return [lemmatizer.lemmatize(token) for token in lowered if token not in stop_words]

# Run every document through the preprocessing pipeline (one token list each)
documents = list(map(preprocess_text, documents))

### Vocabulary and Word Embeddings

In [None]:
import torch
from torch.nn import Embedding
from collections import Counter

# Count the number of occurrences of each word
word_counts = Counter(word for doc in documents for word in doc)

# Map each word to a unique index, most frequent words first
word_to_index = {word: i for i, (word, _) in enumerate(word_counts.most_common())}

# Add special tokens for unknown and padding (they take the last two indices)
word_to_index['<unk>'] = len(word_to_index)
word_to_index['<pad>'] = len(word_to_index)

# Create the inverse mapping, from indices to words
index_to_word = {i: word for word, i in word_to_index.items()}

# Create the embedding layer. padding_idx keeps the <pad> vector at zero and
# excludes it from gradient updates, so padding does not learn a "meaning".
EMBEDDING_DIM = 100
embedding = Embedding(len(word_to_index), EMBEDDING_DIM, padding_idx=word_to_index['<pad>'])

# Convert the documents to sequences of word indices. Documents that
# preprocessing emptied (e.g. all stopwords) are dropped: they would become
# all-<pad> rows with nothing to reconstruct.
documents = [
    [word_to_index.get(word, word_to_index['<unk>']) for word in doc]
    for doc in documents
    if doc
]

### Data Loader

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

class TextDataset(Dataset):
    """Map-style dataset over documents given as lists of integer token ids."""

    def __init__(self, documents):
        self.documents = documents

    def __len__(self):
        return len(self.documents)

    def __getitem__(self, idx):
        # Force integer ids. torch.tensor([]) would otherwise be float32, so an
        # empty document would mix dtypes within a batch and could not be used
        # as embedding indices.
        return torch.tensor(self.documents[idx], dtype=torch.long)

# Create a function to pad sequences
def pad_collate(batch, padding_value=None):
    """Pad a list of 1-D token-id tensors into one ``(batch, max_len)`` tensor.

    Args:
        batch: list of 1-D tensors of varying length.
        padding_value: id used for padding. It defaults to the vocabulary's
            ``'<pad>'`` index.
    """
    if padding_value is None:
        padding_value = word_to_index['<pad>']
    # The VAE's GRUs are built with batch_first=True. pad_sequence's default
    # layout is (max_len, batch), which would swap the batch and time axes.
    return pad_sequence(batch, batch_first=True, padding_value=padding_value)

# Create the data loader. shuffle=True is needed because ``documents`` holds all
# question1 texts followed by all question2 texts; without shuffling, every
# epoch would see the same fixed batch order.
dataset = TextDataset(documents)
data_loader = DataLoader(dataset, batch_size=64, shuffle=True, collate_fn=pad_collate)

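## Recurrent Variational Autoencoder for Text

This cell redefines the `VAE` class from above as a sequence model. A GRU encoder summarizes an embedded sequence into a latent distribution, and a GRU decoder maps a latent sample back into the embedding space.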

In [None]:
class VAE(nn.Module):
    """Sequence VAE with a GRU encoder and a GRU decoder.

    Both GRUs are built with ``batch_first=True``. ``forward`` therefore expects a
    float tensor of shape ``(batch, seq_len, embedding_dim)`` and returns
    ``(reconstruction, mu, log_var)``. The reconstruction has the same shape
    as the input, and ``mu`` / ``log_var`` are ``(batch, latent_dim)``.

    Args:
        embedding_dim: size of each input vector and of each reconstructed vector.
        hidden_dim: hidden size of both GRUs.
        latent_dim: dimensionality of the latent space.
    """

    def __init__(self, embedding_dim, hidden_dim, latent_dim):
        super(VAE, self).__init__()

        # Encoder
        self.encoder_rnn = nn.GRU(embedding_dim, hidden_dim, batch_first=True)
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)  # predicts log-variance

        # Decoder
        self.decoder_rnn = nn.GRU(latent_dim, hidden_dim, batch_first=True)
        self.decoder_out = nn.Linear(hidden_dim, embedding_dim)

    def encode(self, x):
        """Return ``(mu, log_var)`` computed from the encoder GRU's final hidden state."""
        _, h = self.encoder_rnn(x)
        # h is (num_layers, batch, hidden_dim); use the last layer's state.
        h = h[-1]
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``z = mu + sigma * eps`` with ``eps ~ N(0, I)``, keeping z differentiable."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, seq_len=1):
        """Decode ``(batch, latent_dim)`` codes into ``(batch, seq_len, embedding_dim)``.

        The same latent code is fed to the decoder GRU at every time step.
        """
        z = z.unsqueeze(1).repeat(1, seq_len, 1)
        # nn.GRU returns (output, h_n). Only the per-step outputs are projected;
        # passing the whole tuple to the Linear layer raised a TypeError.
        output, _ = self.decoder_rnn(z)
        return self.decoder_out(output)

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        # Decode as many steps as the input has, so the reconstruction lines up with x.
        return self.decode(z, seq_len=x.size(1)), mu, log_var

## Loss Function and Training Loop

In [None]:
def calculate_loss(x, reconstructed_x, mean, log_var):
    """Negative ELBO for the sequence VAE, summed over the whole batch.

    Args:
        x: target tensor (e.g. embedded tokens).
        reconstructed_x: model output; should have the same shape as ``x``.
        mean: latent means.
        log_var: latent log-variances, same shape as ``mean``.

    Returns:
        Scalar tensor: summed squared error plus the closed-form
        KL(N(mean, exp(log_var)) || N(0, I)).
    """
    # F is never imported in the notebook; import it locally so this cell runs.
    import torch.nn.functional as F

    # Reconstruction loss
    RCL = F.mse_loss(reconstructed_x, x, reduction='sum')
    # KL divergence
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    return RCL + KLD

def train(model, dataloader, optimizer, num_epochs, embedding=None):
    """Train the sequence VAE using ``calculate_loss`` (MSE + KL).

    Args:
        model: a VAE mapping ``(batch, seq_len, embedding_dim)`` floats to
            ``(reconstruction, mean, log_var)``.
        dataloader: yields one tensor per batch (no labels).
        optimizer: optimizer that updates ``model``'s parameters.
        num_epochs: number of passes over ``dataloader``.
        embedding: optional ``nn.Embedding``. When given, each batch is treated
            as token ids and replaced by its embedding vectors before it reaches
            the model, which is what the GRU encoder expects. When ``None`` (the
            default), batches are only cast to float, as before, so they must
            already be embedded.
    """
    model.train()
    for epoch in range(num_epochs):
        for i, x in enumerate(dataloader):
            if embedding is not None:
                # The lookup runs without gradients, so the embeddings act as fixed
                # reconstruction targets and no .grad accumulates on their weights.
                with torch.no_grad():
                    x = embedding(x.long())
            else:
                x = x.float()
            x_reconstructed, mean, log_var = model(x)
            loss = calculate_loss(x, x_reconstructed, mean, log_var)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (i+1) % 100 == 0:
                print(f'Epoch[{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item()}')