# Gumbel Softmax
## How to use a GAN for sequences?

In the GAN notebook we explored how GANs can be used to approximate a continuous distribution. At its core there was a generator function $G$ (a neural network) that took random Gaussian noise as input and produced a vector of numbers (possibly of size one, as in our simple one-dimensional example) resembling samples from the real data distribution. It is easy to see how this applies to generating random images: we let the network output a vector the size of the image and train it adversarially against a discriminator as usual.

But what if we were interested in using this approach to generate text? Let's think for a moment. To generate these sequences we would need a generator. It can also take some random noise and produce a sequence of numbers. We can borrow some ideas from machine translation here. The random noise could be fed to a recurrent neural network as its first hidden state - it would capture (possibly after some transformation by a feed-forward neural network) the meaning we would like to encode in a sentence. The network could then produce output words one by one until it generates the stop token. We would have to find a way to keep these sequences from growing too long at the beginning of training (capping the length is one option), but in general this seems like a plausible approach.
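
The idea in the paragraph above can be sketched roughly as follows. This is only an illustration under assumptions, not the generator this notebook ends up with: the class name `SketchGenerator`, the GRU cell, the extra start symbol, the choice of index `0` as the stop token, and the `max_len` cap are all hypothetical.

```python
import torch
import torch.nn as nn


class SketchGenerator(nn.Module):
    """Hypothetical noise-to-sequence generator (illustration only)."""

    def __init__(self, noise_size, hidden_size, vocab_size, stop_token=0, max_len=20):
        super().__init__()
        # Feed-forward transformation of the noise into the first hidden state.
        self.noise_to_hidden = nn.Linear(noise_size, hidden_size)
        # One extra embedding row serves as an assumed "start" symbol.
        self.embed = nn.Embedding(vocab_size + 1, hidden_size)
        self.cell = nn.GRUCell(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, vocab_size)
        self.start_token = vocab_size
        self.stop_token = stop_token
        self.max_len = max_len  # hard cap so early, untrained samples cannot run forever

    @torch.no_grad()
    def generate(self, z):
        """Sample one sequence of token ids from noise z of shape (1, noise_size)."""
        h = torch.tanh(self.noise_to_hidden(z))
        token = torch.tensor([self.start_token])
        tokens = []
        for _ in range(self.max_len):
            h = self.cell(self.embed(token), h)
            probs = torch.softmax(self.out(h), dim=-1)
            # Draw a discrete token from the predicted distribution.
            token = torch.multinomial(probs, num_samples=1).squeeze(1)
            tokens.append(token.item())
            if tokens[-1] == self.stop_token:
                break
        return tokens


torch.manual_seed(0)
gen = SketchGenerator(noise_size=4, hidden_size=8, vocab_size=10)
sample = gen.generate(torch.randn(1, 4))
```

Here `sample` is a plain Python list of integer token ids, at most `max_len` long. Note that the sampling step returns discrete indices, which is convenient for generating text but matters a great deal once we try to train this model.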

What about the discriminator? It takes as input what the generator produces as output. In this case that is a sequence of words, so the natural candidate for a discriminator is a recurrent neural network as well. The discriminator would then produce a number from 0 to 1, with 0 meaning it is certain the sample is real and 1 meaning it is certain the sample is fake. We can also observe that on samples coming from the real-world distribution (real sentences) the discriminator acts much like a language model - it assigns each sentence a score of how plausible it is. So this is it, right?

Wrong. What we didn't take into account is how we are going to train this network. Well, this is obvious - with gradient descent. But let's keep in mind that the loss for training the generator is the value of the discriminator (or a function of it, such as its logarithm, which doesn't change the point we are making): $L(\theta_G) = D(G(z; \theta_G))$. So in order to compute the gradient we have to apply the chain rule: $\frac{\partial}{\partial \theta_G} D(G(z; \theta_G)) = \frac{\partial}{\partial G}D(G) \frac{\partial}{\partial \theta_G} G$. This looks a little weird, since $G$ is now a sequence of discrete elements from a vocabulary: what would it mean to differentiate with respect to it? When the generator's output was continuous this wasn't a problem; we could differentiate with respect to the discriminator's input just as we would with respect to its weights. Here it is completely different - we cannot differentiate a function with respect to a discrete input. That's like differentiating a sequence with respect to its integer index: it doesn't make sense by definition.
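
A small experiment makes the problem concrete. The sketch below is a toy setup under assumptions: `logits` stands in for the generator's per-position outputs, and `toy_d_cont`, `embed`, and `toy_d_disc` are hypothetical stand-ins for a discriminator, not the models of this notebook. It compares a continuous path (the discriminator sees probabilities) with a discrete path (the discriminator sees token ids picked with `argmax`).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, seq_len = 10, 5

# Stand-in for the generator's output: one row of logits per position.
logits = torch.randn(seq_len, vocab_size, requires_grad=True)

# Continuous path: the toy discriminator consumes the probabilities directly.
toy_d_cont = nn.Linear(vocab_size, 1)
cont_score = torch.sigmoid(toy_d_cont(torch.softmax(logits, dim=-1)).mean())
cont_score.backward()
cont_grad_exists = logits.grad is not None  # gradient reached the logits

# Discrete path: pick one token id per position, then score the ids.
logits.grad = None
tokens = torch.softmax(logits, dim=-1).argmax(dim=-1)
tokens_track_grad = tokens.requires_grad  # integer ids carry no gradient
embed = nn.Embedding(vocab_size, 4)
toy_d_disc = nn.Linear(4, 1)
disc_score = torch.sigmoid(toy_d_disc(embed(tokens)).mean())
disc_score.backward()  # updates only reach the discriminator's own parameters
disc_grad_exists = logits.grad is not None  # nothing flowed back to the logits

print(cont_grad_exists, tokens_track_grad, disc_grad_exists)
```

In the discrete path the backward pass still runs, because the toy discriminator has trainable parameters, but `logits.grad` stays empty: the step that turns probabilities into token ids cuts the graph, so the generator would receive no training signal.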

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import numpy as np
import matplotlib.pyplot as plt; plt.style.use("fivethirtyeight")

In [2]:
def generate_sample_sequence(n: int, k: int, max_num: int):
    "Generates a sequence of random discrete numbers from some weird distribution."
    assert n > 2
    assert 2*k < max_num  # we want these numbers to be somehow dependent
    x = np.zeros([n], dtype=int)
    x[0] = np.random.randint(-k, k + 1) % max_num
    for i in range(1, n - 1):
        x[i] = np.random.randint(x[i - 1] - k, x[i - 1] + k + 1) % max_num
    x[n - 1] = np.sum(x) % max_num
    return x

In [3]:
n_sequences = 100
sequences = [generate_sample_sequence(10, 2, 10) for _ in range(n_sequences)]
sequences[:5]

[array([9, 1, 9, 8, 8, 0, 0, 8, 0, 3]),
 array([0, 0, 8, 6, 6, 5, 6, 7, 5, 3]),
 array([8, 0, 9, 0, 0, 9, 8, 0, 9, 3]),
 array([2, 4, 4, 4, 3, 4, 2, 0, 1, 4]),
 array([1, 2, 1, 2, 1, 1, 2, 0, 1, 1])]

Let's suppose that we want to build a GAN to draw from the above distribution. This is a simple distribution in which most elements of a sequence depend only on the previous one, but the last one depends on all of them. Obviously the distribution of words in sentences is much more complicated, with many more dependencies, but this example will hopefully serve as a good simplification for reasoning about more complicated cases.
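
To make these dependencies explicit, here is a small checker that tests whether a sequence obeys the rules of the sampling function above. The helper name `follows_rules` is hypothetical and the function is only a sketch: it assumes `2*k < max_num`, as the sampler asserts, and treats the step between neighbours as circular (modulo `max_num`).

```python
import numpy as np


def follows_rules(x, k, max_num):
    """Check the local (step) and global (checksum) structure of a sequence."""
    x = np.asarray(x)
    body = x[:-1]  # every element except the final checksum
    # The first element is a step of at most k away from zero (circularly).
    first_ok = (body[0] + k) % max_num <= 2 * k
    # Each later element is a step of at most k from its predecessor.
    # Shifting by k maps allowed steps -k..k onto 0..2k after the modulo.
    steps_ok = np.all((np.diff(body) + k) % max_num <= 2 * k)
    # The last element is the sum of all previous ones, modulo max_num.
    last_ok = x[-1] == body.sum() % max_num
    return bool(first_ok and steps_ok and last_ok)


real = [9, 1, 9, 8, 8, 0, 0, 8, 0, 3]        # first sample printed above
bad_last = [9, 1, 9, 8, 8, 0, 0, 8, 0, 4]    # wrong checksum
bad_step = [0, 5, 5, 5, 5, 5, 5, 5, 5, 0]    # step of 5 exceeds k = 2
print(follows_rules(real, 2, 10), follows_rules(bad_last, 2, 10), follows_rules(bad_step, 2, 10))
```

A checker like this could also serve later as a rough sanity test for generated samples, since a generator that has learned the distribution should mostly produce sequences that pass it.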

Now, let's build a generator and a discriminator that distinguishes fake examples from real ones.

In [None]:
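class Generator(nn.Module):
    def __init__(self, input_size: int, hidden_size: int):
        # nn.Module must be initialised before any layers are registered.
        super().__init__()
        # Layers are not defined yet.

    def forward(self, x, h):
        # Placeholder: the forward pass is not implemented yet.
        raise NotImplementedError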