# Week 8 - Discrete Latent Variable Models and Hybrid Models Notebook

In this notebook, we will solve questions on discrete latent variable models and hybrid generative models.

 - This notebook is prepared using PyTorch. However, you can use any Python package you want to implement the necessary functions in the questions.
 - If the question asks you to implement a specific function, please do not use its readily available version from a package and implement it yourself.

## Question 1

Please answer the questions below:

1. Please give some examples of discrete data modalities.
1. Can we use GANs to generate discrete data points?
1. What is REINFORCE and why do we use it?
1. Please briefly explain Gumbel-Softmax: why do we need it, and how do we use it in practice?
1. Please conceptually explain how PixelVAE works.
1. What is the novelty of $\beta$-VAE over the classical variational auto-encoder? Please briefly explain.

You can write your answer for each question in the markdown cell below:

**Please write your answer for each question here**

## Question 2

Implement the Gumbel-Softmax function. The function is characterized as below:

\begin{equation}
\hat{z} = \text{softmax}_i \left(\frac{g_i + \log \pi_i}{\tau}\right)
\end{equation}

where $\pi$ are the class probabilities, $g_i$ are i.i.d. samples from the Gumbel distribution, and $\tau > 0$ is the temperature parameter (typically chosen in $(0, 1]$).

You can write additional function or functions to sample from the gumbel distribution.
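
As a hint, one standard way to draw such samples is inverse transform sampling. Assuming the standard Gumbel distribution with location 0 and scale 1, and $u \sim \text{Uniform}(0, 1)$:

\begin{equation}
g = -\log(-\log u)
\end{equation}

This follows from inverting the Gumbel CDF $F(g) = \exp(-\exp(-g))$. In practice, a small constant is usually added inside the logarithms (or $u$ is clamped away from zero) to avoid evaluating $\log 0$.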

In [None]:
import torch
torch.manual_seed(0)

batch_size = 16

# Let's assume four discrete outputs
num_classes = 4
logits = torch.randn(batch_size, num_classes)

In [None]:
# Alternative solution: https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
def gumbel_softmax(logits, temperature):
  """Applies gumbel softmax operation to the provided logits

    Args:
        logits: (N x num_classes)
        temperature: A scalar constant that determines the bias-variance tradeoff
    Returns:
        the resulting tensor from the operation
  """
  #######################
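  # Gumbel(0, 1) noise via the inverse CDF; the clamp avoids log(0)
  gumbel_noise = -torch.log(-torch.log(torch.rand(logits.shape).clamp(min=1e-20)))
  # Relaxed (soft) sample from the equation above
  soft_sample = torch.softmax((logits + gumbel_noise) / temperature, dim=1)

  # Discretize to one-hot rows; num_classes keeps the width fixed even if the last class is never drawn
  return torch.nn.functional.one_hot(torch.argmax(soft_sample, dim=1), num_classes=logits.shape[1]).float()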
  #######################

In [None]:
print(gumbel_softmax(logits, temperature=0.5))

**Expected Output:**

```
tensor([[0., 0., 0., 1.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.]])
```

**Bonus:** You are encouraged to tinker with the temperature parameter and see how the results change.
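
The effect of the temperature is easiest to see on the relaxed (soft) sample, before any argmax is taken. The sketch below is only an illustration: `example_logits`, `sample_gumbel`, and the chosen temperatures are hypothetical names and values, not part of the exercise. It reuses the same Gumbel noise for every temperature so that only $\tau$ changes.

```python
import torch

torch.manual_seed(0)

# Hypothetical logits for a single data point with four classes.
example_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])


def sample_gumbel(shape, eps=1e-20):
    """Draw Gumbel(0, 1) noise via the inverse CDF; eps guards against log(0)."""
    u = torch.rand(shape)
    return -torch.log(-torch.log(u + eps) + eps)


# Share one noise draw across temperatures so that only tau varies.
noise = sample_gumbel(example_logits.shape)

samples = {}
for tau in (0.1, 0.5, 1.0, 5.0):
    # Soft sample: a point on the probability simplex, not a one-hot vector.
    samples[tau] = torch.softmax((example_logits + noise) / tau, dim=1)
    print(f"tau={tau}: {samples[tau].squeeze(0).tolist()}")
```

With a small temperature the soft sample concentrates almost all of its mass on one class, while a large temperature pushes it toward the uniform distribution.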

## Question 3

Implement the loss function of VAE-GAN. You can refer to the [paper](https://arxiv.org/pdf/1512.09300.pdf) to see the motivation behind the loss function and the related equations.

The loss function of VAE-GAN consists of three parts, the first being the KL divergence loss:

\begin{equation}
\mathcal{L}_{prior} = D_{KL}(q(z|x)||p(z))
\end{equation}

where $z$ is the latent vector, $p(z)$ is the prior over the latent space, and $x$ is the data point to be reconstructed. Typically, the prior is the standard normal $\mathcal N(0, 1)$. This term acts as a regularizer and encourages the encoder's output distribution $q(z|x)$ to stay close to $\mathcal N(0, 1)$.
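
For reference, assuming the encoder outputs a diagonal Gaussian $q(z|x) = \mathcal N(\mu, \text{diag}(\sigma^2))$ over $D$ latent dimensions and the prior is the standard normal, the KL term has a closed form:

\begin{equation}
D_{KL}(q(z|x)||p(z)) = -\frac{1}{2} \sum_{j=1}^{D} \left(1 + \log \sigma_j^2 - \mu_j^2 - \sigma_j^2\right)
\end{equation}

Here $\mu_j$ and $\log \sigma_j^2$ correspond to the mean and log-variance outputs of the encoder for the $j^{th}$ latent dimension.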

The second term is the reconstruction loss, but with a small twist:

\begin{equation}
\mathcal{L}^{\text{Dis}_l}_{\text{llike}} = -\mathbb{E}_{q(z|x)}[\log p(\text{Dis}_l(x)|z)]
\end{equation}

The equation above is the log-likelihood based reconstruction loss of the original VAE, except that $x$ is replaced by $\text{Dis}_l(x)$. Here $\text{Dis}_l(\cdot)$ denotes the intermediate representation at the $l^{th}$ layer of the discriminator, and the likelihood compares the features of $x$ with those of its reconstruction $\tilde{x} \sim \text{Dec}(z)$. This ensures that the image is reconstructed at the feature level rather than at the pixel level.
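
To see why a mean squared error on features is a natural implementation of this term, note that the paper models the feature likelihood as a Gaussian with identity covariance centered on the features of the reconstruction $\tilde{x} \sim \text{Dec}(z)$:

\begin{equation}
p(\text{Dis}_l(x)|z) = \mathcal N(\text{Dis}_l(x) \,|\, \text{Dis}_l(\tilde{x}), I)
\end{equation}

Taking the negative logarithm gives, up to an additive constant that does not depend on the model parameters,

\begin{equation}
-\log p(\text{Dis}_l(x)|z) = \frac{1}{2} \left\lVert \text{Dis}_l(x) - \text{Dis}_l(\tilde{x}) \right\rVert^2 + \text{const}
\end{equation}

so minimizing the term amounts to minimizing the squared distance between the two feature vectors.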

Finally, the third part of the loss is our good old GAN loss:

\begin{equation}
\mathcal{L}_{\text{GAN}} = \log(\text{Dis}(x)) + \log(1 - \text{Dis}(\text{Gen}(z)))
\end{equation}

The final loss of the VAE-GAN is the sum of all these losses:

\begin{equation}
\mathcal{L} = \mathcal{L}_{prior} + \mathcal{L}^{\text{Dis}_l}_{\text{llike}} + \mathcal{L}_{\text{GAN}}
\end{equation}

Implement all three losses as separate functions in the code cells below:


In [None]:
mean = torch.randn(batch_size, 20)
logvar = torch.randn(batch_size, 20)

In [None]:
def kl_loss(mean, logvar):
  """Calculates the KL loss based on the mean and logvar outputs of the Encoder network
  w.r.t. the Gaussian with zero mean and unit variance
  
    Args:
      mean: Tensor of mean values coming from the Encoder (N x D)
      logvar: Tensor of log-variance values coming from the Encoder (N x D)
    Returns:
      The resulting KL loss
  """
  #######################
  return (-0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()))/mean.shape[0]
  #######################


In [None]:
print(kl_loss(mean, logvar))
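
One way to sanity-check a KL implementation is to compare it with `torch.distributions.kl_divergence`. The sketch below assumes a diagonal Gaussian posterior and a standard normal prior, with the KL summed over latent dimensions and averaged over the batch. The names `example_mean` and `example_logvar` are hypothetical stand-ins for encoder outputs.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)

# Hypothetical encoder outputs: 16 data points, 20 latent dimensions.
example_mean = torch.randn(16, 20)
example_logvar = torch.randn(16, 20)

# Closed-form KL, summed over all entries and divided by the batch size.
closed_form = (
    -0.5 * torch.sum(1 + example_logvar - example_mean.pow(2) - example_logvar.exp())
) / example_mean.shape[0]

# The same quantity from torch.distributions (the standard deviation is exp(0.5 * logvar)).
q = Normal(example_mean, torch.exp(0.5 * example_logvar))
p = Normal(torch.zeros_like(example_mean), torch.ones_like(example_mean))
reference = kl_divergence(q, p).sum(dim=1).mean()

print(closed_form, reference)
```

The two printed values should agree up to floating-point error.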

In [None]:
features_org = torch.randn(batch_size, 100)
features_recon = torch.randn(batch_size, 100)

# Uncomment the line below and run the function again to see a higher reconstruction error 
# features_recon = torch.normal(3.0, 20.0, (batch_size, 100))

In [None]:
def reconstruction_loss(features_org, features_recon):
  """Calculates the reconstruction loss with mean squared error

    Args:
      features_org: Features of the original image obtained from an intermediate layer of the discriminator
      features_recon: Features of the reconstructed image obtained from an intermediate layer of the discriminator
    Returns:
      M.S.E based reconstruction error of the features
  """
  #######################
  return ((features_org - features_recon)**2).mean()
  #######################

In [None]:
print(reconstruction_loss(features_org, features_recon))

In [None]:
outputs_real = torch.randn(batch_size, 32).clip(0, 1)
outputs_fake = torch.randn(batch_size, 32).clip(0, 1)

In [None]:
def gan_loss(d_real_outputs, d_fake_outputs):
  """Our good old GAN loss, doesn't need much of an explanation :)

    Args:
      d_real_outputs: Discriminator sigmoid outputs for the real data points
      d_fake_outputs: Discriminator sigmoid outputs for the fake data points
    Returns:
      The calculated GAN loss
  """
  #######################
  real = torch.log(d_real_outputs + 1e-7)
  fake = torch.log(1-d_fake_outputs + 1e-7)

  return -(real + fake).mean()
  #######################

In [None]:
print(gan_loss(outputs_real, outputs_fake))

tensor(11.0524)
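
The negated GAN objective can also be written as a sum of two binary cross-entropy terms, with real samples labeled 1 and generated samples labeled 0. The sketch below is an illustration under assumptions: `d_real` and `d_fake` are hypothetical discriminator outputs kept strictly inside $(0, 1)$, so no stabilizing constant such as `1e-7` is needed.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sigmoid outputs of a discriminator, clamped away from 0 and 1.
d_real = torch.rand(16, 32).clamp(0.01, 0.99)
d_fake = torch.rand(16, 32).clamp(0.01, 0.99)

# Negative of log(Dis(x)) + log(1 - Dis(Gen(z))), averaged over all entries.
manual = -(torch.log(d_real) + torch.log(1 - d_fake)).mean()

# Equivalent form: BCE against all-ones targets for real, all-zeros targets for fake.
bce = F.binary_cross_entropy(d_real, torch.ones_like(d_real)) + F.binary_cross_entropy(
    d_fake, torch.zeros_like(d_fake)
)

print(manual, bce)
```

Both expressions should print the same value up to floating-point error, which is why many GAN implementations use a binary cross-entropy loss directly.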


## Bonus

My master's thesis was on a hybrid generative model, and it was published in Pattern Recognition. I would like to briefly talk about it during the notebook session.

For anyone who is interested, kindly read or skim through the paper before coming to the discussion session. I leave the link to the paper [here](https://faculty.ozyegin.edu.tr/ethemalpaydin/files/2021/01/Uras_bigan_PatRec.pdf).