# Variational Autoencoder (VAE)

Check out the sister of the GAN: VAE. In this lab, you'll explore the components of a basic VAE to understand how it works. 

The "AE" in VAE stands for autoencoder. As an autoencoder, the VAE has two parts: an encoder and a decoder. 

Instead of mapping each image to a single point in $z$-space, the encoder outputs the mean vector and covariance matrix of a multivariate normal distribution whose dimensions are all independent, which means the covariance matrix is diagonal.

You should have had a chance to read more about multivariate normal distributions in last week's assignment, but you can think of the output of the encoder of a VAE this way: the means and standard deviations of a set of independent normal distributions, with one normal distribution (one mean and standard deviation) for each latent dimension. 
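To make the "independent normals" view concrete, here is a minimal sketch using `torch.distributions`; the latent size and values are arbitrary and chosen only for illustration. It checks that one normal distribution per latent dimension assigns the same log-probability as a multivariate normal whose covariance matrix is diagonal, with the variances on the diagonal.

```python
import torch
from torch.distributions import Independent, MultivariateNormal, Normal

torch.manual_seed(0)
z_dim = 4  # illustrative latent size, not the lab's z_dim
mean = torch.randn(z_dim)
stddev = torch.rand(z_dim) + 0.5  # any positive standard deviations

# One normal distribution per latent dimension, treated as a single joint distribution
independent = Independent(Normal(mean, stddev), 1)

# The same distribution written as a multivariate normal with a diagonal covariance matrix
diagonal_mvn = MultivariateNormal(mean, covariance_matrix=torch.diag(stddev ** 2))

z = torch.randn(z_dim)
print(independent.log_prob(z).item(), diagonal_mvn.log_prob(z).item())  # equal up to float error
```

Working with independent per-dimension normals avoids ever building or inverting a full `z_dim × z_dim` covariance matrix.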

*VAE Architecture Drawing: The encoder outputs a distribution in $z$-space, and to generate an image you sample from the distribution and pass the $z$-space sample to the decoder, which returns an image. VAE latent space visualization from [Hyperspherical Variational Auto-Encoders](https://arxiv.org/abs/1804.00891)*, by Davidson et al. in UAI 2018

## Encoder and Decoder

For your encoder and decoder, you can reuse architectures you've seen before, with some tweaks. For example, the decoder can use the DCGAN generator architecture. For the encoder, you can use a classifier you built before, but instead of producing one classification output (say, whether an image is a cat or not), have it produce two outputs: one for the means and one for the standard deviations. Each of those outputs has dimensionality $z$ (the latent size `z_dim` in the code below), one value for each dimension of the multivariate normal distribution.
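As an illustration of the "two outputs" idea, here is a hypothetical fully connected encoder (`TwoHeadEncoder` is not part of this lab) for 28×28 single-channel images. A shared trunk, like the body of a classifier, feeds two linear heads of size `z_dim`: one for the means and one for the log standard deviations. The convolutional encoder below takes a different route and produces both halves from a single layer with `2 * z_dim` channels, then splits them.

```python
import torch
import torch.nn as nn

class TwoHeadEncoder(nn.Module):
    '''Hypothetical MLP encoder with separate mean and stddev heads (illustration only).'''

    def __init__(self, z_dim=32, hidden_dim=256):
        super().__init__()
        # Shared feature extractor, like the body of a classifier
        self.trunk = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, hidden_dim),
            nn.LeakyReLU(0.2),
        )
        # Instead of one classification logit, two outputs of size z_dim
        self.mean_head = nn.Linear(hidden_dim, z_dim)
        self.log_stddev_head = nn.Linear(hidden_dim, z_dim)

    def forward(self, images):
        features = self.trunk(images)
        mean = self.mean_head(features)
        stddev = self.log_stddev_head(features).exp()  # exp keeps the stddev positive
        return mean, stddev

encoder = TwoHeadEncoder(z_dim=32)
mean, stddev = encoder(torch.rand(8, 1, 28, 28))
print(mean.shape, stddev.shape)  # torch.Size([8, 32]) torch.Size([8, 32])
```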

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    '''
    Encoder Class
    Values:
    im_chan: the number of channels of the input image, a scalar
            MNIST is black-and-white (1 channel), so that's our default.
    output_chan: the dimension of the latent space (z_dim), a scalar;
            the final layer outputs 2 * output_chan channels (means and stddevs)
    hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, im_chan=1, output_chan=32, hidden_dim=16):
        super(Encoder, self).__init__()
        self.z_dim = output_chan
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, output_chan * 2, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to an encoder block of the VAE,
        corresponding to a convolution, a batchnorm, and an activation (the last layer is only a convolution)
        Parameters:
        input_channels: how many channels the input feature representation has
        output_channels: how many channels the output feature representation should have
        kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
        stride: the stride of the convolution
        final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the Encoder: Given an image tensor,
        returns the means and standard deviations of the latent normal distributions.
        Parameters:
        image: an image tensor with dimensions (batch_size, im_chan, im_height, im_width), 28x28 for MNIST
        Returns:
        a (mean, stddev) pair, each with dimensions (batch_size, z_dim)
        '''
        disc_pred = self.disc(image)  # e.g. torch.Size([1024, 64, 1, 1]) for a batch of 1024 MNIST images
        encoding = disc_pred.view(len(disc_pred), -1)  # torch.Size([1024, 64])

        # The first z_dim outputs are the means. By convention and for numerical stability,
        # the remaining z_dim outputs are treated as the log of the standard deviation and
        # exponentiated, which also keeps the stddev positive.
        # (Under the log-variance convention, the stddev would be (0.5 * x).exp() instead.)
        return encoding[:, :self.z_dim], encoding[:, self.z_dim:].exp()
```

```python
class Decoder(nn.Module):
    '''
    Decoder Class
    Values:
    z_dim: the dimension of the latent (noise) vector, a scalar
    im_chan: the number of channels of the output image, a scalar
            MNIST is black-and-white, so that's our default
    hidden_dim: the inner dimension, a scalar
    '''

    def __init__(self, z_dim=32, im_chan=1, hidden_dim=64):
        super(Decoder, self).__init__()
        self.z_dim = z_dim
        self.gen = nn.Sequential(
            self.make_gen_block(z_dim, hidden_dim * 4),
            self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=4, stride=1),
            self.make_gen_block(hidden_dim * 2, hidden_dim),
            self.make_gen_block(hidden_dim, im_chan, kernel_size=4, final_layer=True),
        )

    def make_gen_block(self, input_channels, output_channels, kernel_size=3, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a Decoder block of the VAE,
        corresponding to a transposed convolution, a batchnorm (except for in the last layer), and an activation
        (ReLU, or Sigmoid in the last layer so that pixel values lie in [0, 1])
        Parameters:
        input_channels: how many channels the input feature representation has
        output_channels: how many channels the output feature representation should have
        kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
        stride: the stride of the transposed convolution
        final_layer: whether we're on the final layer (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Sigmoid(),
            )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the Decoder: Given a latent vector,
        returns a generated image.
        Parameters:
        noise: a latent tensor with dimensions (batch_size, z_dim)
        Returns:
        an image tensor with dimensions (batch_size, im_chan, 28, 28)
        '''
        # Treat each latent vector as a z_dim-channel, 1x1 "image"
        x = noise.view(len(noise), self.z_dim, 1, 1)
        return self.gen(x)
```
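The encoder and decoder only fit together if the spatial sizes work out: three stride-2 convolutions take a 28×28 image down to 1×1, and four transposed convolutions take a 1×1 latent back up to 28×28. This short check traces those sizes with the output-size formulas from the PyTorch `nn.Conv2d` and `nn.ConvTranspose2d` documentation (no padding, as in the layers above), then cross-checks the last decoder layer against PyTorch itself.

```python
import torch
import torch.nn as nn

def conv_out(size, kernel_size, stride, padding=0, dilation=1):
    # Output height/width of nn.Conv2d
    return (size + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def conv_transpose_out(size, kernel_size, stride, padding=0, output_padding=0, dilation=1):
    # Output height/width of nn.ConvTranspose2d
    return (size - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1

# Encoder: three convolutions with kernel_size=4, stride=2
enc_sizes = [28]
for _ in range(3):
    enc_sizes.append(conv_out(enc_sizes[-1], kernel_size=4, stride=2))
print(enc_sizes)  # [28, 13, 5, 1]

# Decoder: (kernel_size, stride) for each of the four make_gen_block calls
dec_sizes = [1]
for k, s in [(3, 2), (4, 1), (3, 2), (4, 2)]:
    dec_sizes.append(conv_transpose_out(dec_sizes[-1], kernel_size=k, stride=s))
print(dec_sizes)  # [1, 3, 6, 13, 28]

# Cross-check the final decoder layer against PyTorch
upsampled = nn.ConvTranspose2d(1, 1, kernel_size=4, stride=2)(torch.zeros(1, 1, 13, 13))
print(upsampled.shape)  # torch.Size([1, 1, 28, 28])
```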

## VAE

You can define the VAE using the encoder and decoder as follows. In the forward pass, the VAE samples from the encoder's output distribution before passing a value to the decoder. 

A common mistake is to pass the mean to the decoder instead of a sample: this leads to blurrier images and is not how VAEs are designed to be used. So, the steps you'll take are:

1.   Pass a real image into the encoder
2.   The encoder outputs a mean and standard deviation
3.   Sample from the distribution with the output mean and standard deviation
4.   Use the sampled value (vector/latent) as the input to the decoder
5.   Get the fake sample (the reconstruction) from the decoder
6.   Compute the reconstruction loss between the fake output of the decoder and the original real input to the encoder (more about this later - keep reading!)
7.   Backpropagate the loss through the decoder and the encoder
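The seven steps above can be sketched as a single training step. This is a minimal illustration, not the lab's model: it assumes tiny hypothetical fully connected layers, a random stand-in batch of "images", and it already adds the KL divergence term that the ELBO section below introduces.

```python
import torch
import torch.nn as nn
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

torch.manual_seed(0)
z_dim, image_dim = 8, 28 * 28

# Hypothetical toy encoder/decoder (fully connected, for illustration only)
encoder = nn.Linear(image_dim, 2 * z_dim)
decoder = nn.Sequential(nn.Linear(z_dim, image_dim), nn.Sigmoid())
optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=1e-3)

real = torch.rand(16, image_dim)  # 1. stand-in batch of "real" images in [0, 1]
optimizer.zero_grad()
out = encoder(real)               # 2. encoder outputs means and log-stddevs
mean, stddev = out[:, :z_dim], out[:, z_dim:].exp()
q_dist = Normal(mean, stddev)
z = q_dist.rsample()              # 3. reparameterized sample, not just the mean
fake = decoder(z)                 # 4-5. decode the sample into a fake image
recon = nn.functional.binary_cross_entropy(fake, real, reduction='sum')  # 6. reconstruction loss
prior = Normal(torch.zeros_like(mean), torch.ones_like(stddev))
kl = kl_divergence(q_dist, prior).sum()  # KL term from the ELBO
loss = recon + kl
loss.backward()                   # 7. backpropagate through decoder and encoder
optimizer.step()
print(encoder.weight.grad is not None)  # True: gradients reached the encoder through the sample
```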



![VAE Architecture](vae_sampling.png)

## Quick Note on Implementation Notation ("Reparameterization Trick")
Most machine learning frameworks will not backpropagate through a random sample: steps 3-4 above work in the forward pass, but with the usual sampling functions no gradient flows back to the mean and standard deviation in the backward pass. In PyTorch, you can get around this by sampling with the `rsample` method, as in `Normal(mean, stddev).rsample()`. This is equivalent to `torch.randn(z_dim) * stddev + mean`. **Do not use** `torch.normal(mean, stddev)` (or `Normal(mean, stddev).sample()`), because gradients will not flow from that sample back to `mean` and `stddev`. This is known as the reparameterization trick: you move the parameters of the random sample outside of the sampling function (draw standard normal noise, then scale and shift it) to make explicit that the gradient should be calculated through these parameters.
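Here is a small sketch of the difference, using toy `mean` and `stddev` tensors (not the lab's encoder outputs). With `rsample()` and with the hand-written `torch.randn(...) * stddev + mean`, gradients reach both parameters; `sample()` returns a tensor that is disconnected from them.

```python
import torch
from torch.distributions.normal import Normal

mean = torch.zeros(4, requires_grad=True)
stddev = torch.ones(4, requires_grad=True)

# rsample(): draws eps ~ N(0, I) and returns mean + eps * stddev,
# so the sample is a differentiable function of mean and stddev.
z = Normal(mean, stddev).rsample()
z.sum().backward()
print(mean.grad)    # d(sum z)/d(mean) is 1 for every entry
print(stddev.grad)  # d(sum z)/d(stddev) is the eps that was drawn

# The same trick written by hand with torch.randn
mean.grad, stddev.grad = None, None
eps = torch.randn(4)
z_manual = eps * stddev + mean
z_manual.sum().backward()
print(torch.equal(stddev.grad, eps))  # True

# sample() draws without tracking gradients, so nothing links it back to mean/stddev
z_detached = Normal(mean, stddev).sample()
print(z_detached.requires_grad)  # False
```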

```python
from torch.distributions.normal import Normal

class VAE(nn.Module):
    '''
    VAE Class
    Values:
    z_dim: the dimension of the latent vector, a scalar
    im_chan: the number of channels of the input and output images, a scalar
            MNIST is black-and-white, so that's our default
    hidden_dim: the inner dimension of the decoder, a scalar
    '''

    def __init__(self, z_dim=32, im_chan=1, hidden_dim=64):
        super(VAE, self).__init__()
        self.z_dim = z_dim
        self.encode = Encoder(im_chan, z_dim)
        self.decode = Decoder(z_dim, im_chan, hidden_dim)

    def forward(self, images):
        '''
        Function for completing a forward pass of the VAE: Given a batch of images,
        encodes them, samples from the resulting distributions, and decodes the samples.
        Parameters:
        images: an image tensor with dimensions (batch_size, im_chan, im_height, im_width)
        Returns:
        decoding: the autoencoded image
        q_dist: the z-distribution of the encoding
        '''
        q_mean, q_stddev = self.encode(images)
        q_dist = Normal(q_mean, q_stddev)
        z_sample = q_dist.rsample()  # Sample once from each distribution, using the `rsample` notation
        decoding = self.decode(z_sample)
        return decoding, q_dist
```

## Evidence Lower Bound (ELBO)

When training a VAE, you're trying to maximize the likelihood of the real images. This means you want the learned probability distribution to consider real images (and the features in them) likely, as opposed to, say, random noise or weird-looking images. You also want each real image to be appropriately associated with a point in the latent space, whose distribution is governed by the prior $p(z)$ (more on this below); this is where your learned latent noise vectors will live. However, computing this likelihood exactly requires integrating over every possible $z$, which is mathematically intractable. Instead, you can maximize a lower bound on the likelihood: a worst-case value that *is* mathematically tractable. If you push the lower bound up, you are probably improving the likelihood too. This technique is known as maximizing the Evidence Lower Bound (ELBO).
#### ùëù(ùëß)  is the prior probability | q(z) refers to the posterior probability

Some notation before explaining the ELBO. First, $p(z)$ is the prior probability distribution on the latent space $z$. It gives the probability density of each point in the latent space, and you know exactly what it is because you choose it at the start: a multivariate normal distribution. Second, $q(z)$ (more precisely $q(z|x)$, since it depends on the input image $x$) refers to the (approximate) posterior probability, the distribution of the encoded images. Keep in mind that when an image is passed through the encoder, its encoding is a probability distribution, not a single point.

With that notation, the ELBO of a VAE, the lower bound you want to maximize, is $\mathbb{E}\left(\log p(x|z)\right) + \mathbb{E}\left(\log \frac{p(z)}{q(z|x)}\right)$, which is equivalent to $\mathbb{E}\left(\log p(x|z)\right) - \mathrm{D_{KL}}(q(z|x)\Vert p(z))$, where both expectations are taken over $z \sim q(z|x)$.

ELBO can be broken down into two parts: the reconstruction loss $\mathbb{E}\left(\log p(x|z)\right)$ and the KL divergence term $\mathrm{D_{KL}}(q(z|x)\Vert p(z))$. You'll explore each of these two terms in the next code and text sections.
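For reference, here is the standard argument for why this is a lower bound, writing the expectation over $q(z|x)$ explicitly as a subscript. It multiplies and divides inside the integral by $q(z|x)$, then applies Jensen's inequality, which holds because $\log$ is concave:

$$
\begin{aligned}
\log p(x) &= \log \int p(x|z)\,p(z)\,dz = \log \mathbb{E}_{q(z|x)}\left(\frac{p(x|z)\,p(z)}{q(z|x)}\right) \\
&\ge \mathbb{E}_{q(z|x)}\left(\log p(x|z)\right) + \mathbb{E}_{q(z|x)}\left(\log \frac{p(z)}{q(z|x)}\right) \\
&= \mathbb{E}_{q(z|x)}\left(\log p(x|z)\right) - \mathrm{D_{KL}}(q(z|x)\Vert p(z))
\end{aligned}
$$

The gap between $\log p(x)$ and the ELBO is $\mathrm{D_{KL}}(q(z|x)\Vert p(z|x))$, the divergence between the encoder's distribution and the true posterior. In training you minimize the negative ELBO, which is the reconstruction loss plus the KL divergence term.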

### Reconstruction Loss 

Reconstruction loss measures the distance between the real input image (that you put into the encoder) and the generated image (that comes out of the decoder). Formally, the reconstruction term of the ELBO is $\mathbb{E}\left(\log p(x|z)\right)$, the log probability of the true image given the latent value; the reconstruction loss you minimize is its negative.

For MNIST, you can treat each pixel as a binary random variable (a Bernoulli distribution) whose probability of being "on" is the decoder's predicted brightness, a value between 0 and 1. The negative log-likelihood of the real image under these distributions is the binary cross-entropy between the real input image and the generated image, so you can use the binary cross-entropy loss to represent the reconstruction loss term.
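A quick sketch of that equivalence on random stand-in tensors (not real MNIST data): summing the negative Bernoulli log-likelihood formula over every pixel gives the same value as `nn.BCELoss(reduction='sum')`. The formula is applied here to brightness targets in [0, 1], which is how BCE is used on grayscale MNIST.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
real = torch.rand(2, 1, 28, 28)                      # stand-in pixel brightnesses in [0, 1]
predicted = torch.rand(2, 1, 28, 28) * 0.98 + 0.01  # stand-in decoder outputs, kept away from 0 and 1

# Negative Bernoulli log-likelihood, summed over every pixel in the batch
bernoulli_nll = -(real * predicted.log() + (1 - real) * (1 - predicted).log()).sum()

bce = nn.BCELoss(reduction='sum')(predicted, real)
print(torch.allclose(bce, bernoulli_nll, rtol=1e-4))  # True
```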

#### Different assumptions about the "distribution" of the pixel brightnesses in an image lead to different loss functions

In general, different assumptions about the "distribution" of the pixel brightnesses in an image will lead to different loss functions. For example, if you assume that the brightnesses of the pixels follow a normal distribution instead of a Bernoulli distribution, this corresponds to a mean squared error (MSE) reconstruction loss.

Why the mean squared error? As a point moves away from the center, $\mu$, of a normal distribution, its negative log-likelihood increases quadratically: for $x \sim \mathcal{N}(\mu,\sigma^2)$, $\mathrm{NLL}(x) = \frac{(x-\mu)^2}{2\sigma^2} + \log\left(\sigma\sqrt{2\pi}\right)$. With a fixed $\sigma$, this is the squared error up to a scale factor and an additive constant. As a result, assuming the pixel brightnesses are normally distributed implies an MSE reconstruction loss.
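The formula can be checked numerically with `torch.distributions.Normal`; the values of `mu`, `sigma`, and `x` below are arbitrary. The negative log-likelihood matches a scaled squared error plus a constant that does not depend on `x`, so minimizing one minimizes the other.

```python
import math
import torch
from torch.distributions.normal import Normal

mu, sigma = torch.tensor(0.3), 1.0
x = torch.linspace(-2.0, 2.0, steps=9)

nll = -Normal(mu, sigma).log_prob(x)
squared_error = (x - mu) ** 2

# NLL = (x - mu)^2 / (2 sigma^2) + log(sigma * sqrt(2 pi))
constant = math.log(sigma * math.sqrt(2 * math.pi))
print(torch.allclose(nll, squared_error / (2 * sigma ** 2) + constant))  # True
```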

```python
# Sums the per-pixel binary cross-entropy over every image in the batch
reconstruction_loss = nn.BCELoss(reduction='sum')
```

### KL Divergence 

KL divergence, mentioned last week in the video on Inception Score, measures how different one probability distribution is from another. If two distributions are exactly the same, their KL divergence is 0. KL divergence is close to the notion of a distance between distributions, but it is called a divergence, not a distance, because it is not symmetric: $\mathrm{KL}(X\Vert Y)$ is usually not equal to $\mathrm{KL}(Y\Vert X)$ with the terms flipped. In contrast, a true distance function, like the Euclidean distance, which is built from squared differences between two points, is symmetric because $(A-B)^2 = (B-A)^2$.
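A two-line numerical illustration with arbitrary one-dimensional normals: the KL divergence of a distribution from itself is 0, and swapping the arguments changes the value.

```python
import torch
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

p = Normal(torch.tensor(0.0), torch.tensor(1.0))
q = Normal(torch.tensor(1.0), torch.tensor(2.0))

kl_same = kl_divergence(p, p).item()
kl_pq = kl_divergence(p, q).item()
kl_qp = kl_divergence(q, p).item()
print(kl_same)       # 0.0: identical distributions
print(kl_pq, kl_qp)  # about 0.443 and 1.307: not symmetric
```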

Here you care about how different two distributions are: (1) the learned latent distribution $q(z|x)$ that your encoder is trying to model and (2) your prior on the latent space $p(z)$, which you want the learned latent distribution to be as close as possible to. If both distributions are normal, you can calculate the KL divergence, $\mathrm{D_{KL}}(q(z|x)\Vert p(z))$, with a simple closed-form formula. This makes KL divergence an attractive measure to use here, and it also makes the normal distribution an attractive distribution to assume for your model and data.

Your encoder is learning $q(z|x)$, but what is your latent prior $p(z)$? It is a simple distribution over the latent space with a mean of zero and a standard deviation of one in each dimension, written $\mathcal{N}(0, I)$. You might also come across it as the *spherical normal distribution*: the $I$ in $\mathcal{N}(0, I)$ stands for the identity matrix, so its covariance is 1 along the entire diagonal and 0 everywhere else. If you like geometry, its contours of equal density are hyperspheres, that is, spheres in many (here, $z$) dimensions.
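For a diagonal normal $q(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$ and the prior $\mathcal{N}(0, I)$, the closed form is $\mathrm{D_{KL}} = \frac{1}{2}\sum_j \left(\sigma_j^2 + \mu_j^2 - 1 - \log \sigma_j^2\right)$. This sketch, using random stand-in means and standard deviations, checks that formula against `torch.distributions.kl.kl_divergence` summed over the latent dimensions, the same computation `kl_divergence_loss` performs.

```python
import torch
from torch.distributions.normal import Normal
from torch.distributions.kl import kl_divergence

torch.manual_seed(0)
batch_size, z_dim = 3, 32
mean = torch.randn(batch_size, z_dim)
stddev = torch.rand(batch_size, z_dim) + 0.1  # positive standard deviations

# Closed form of KL(N(mean, diag(stddev^2)) || N(0, I)), summed over latent dimensions
closed_form = 0.5 * (stddev ** 2 + mean ** 2 - 1 - torch.log(stddev ** 2)).sum(-1)

# The same quantity via torch.distributions
q_dist = Normal(mean, stddev)
prior = Normal(torch.zeros_like(mean), torch.ones_like(stddev))
library = kl_divergence(q_dist, prior).sum(-1)

print(closed_form.shape)  # torch.Size([3]): one KL value per image
print(torch.allclose(closed_form, library, rtol=1e-4, atol=1e-5))  # True
```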

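```python
from torch.distributions.kl import kl_divergence

def kl_divergence_loss(q_dist):
    # KL(q(z|x) || N(0, I)) per latent dimension, summed over the z_dim
    # dimensions to give one value per image: shape (batch_size,)
    return kl_divergence(
        q_dist, Normal(torch.zeros_like(q_dist.mean), torch.ones_like(q_dist.stddev))
    ).sum(-1)
```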

### Further Resources

An accessible but complete discussion and derivation of the evidence lower bound (ELBO) and the theory behind it can be found [at this link](https://deepgenerativemodels.github.io/notes/vae/) and [this lecture](http://www.cs.toronto.edu/~rgrosse/courses/csc421_2019/slides/lec17.pdf).

## Training a VAE

Here you can train a VAE, once again using MNIST! First, define the dataloader:

```python
from torch.utils.data.dataloader import DataLoader
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),  # scales pixels to [0, 1], as BCELoss expects
])
# download=True fetches MNIST into the current directory if it isn't there yet
mnist_dataset = datasets.MNIST('.', train=True, download=True, transform=transform)
train_dataloader = DataLoader(mnist_dataset, shuffle=True, batch_size=1024)
```

```python
import matplotlib.pyplot as plt
plt.rcParams["figure.figsize"] = (16, 8)

from torchvision.utils import make_grid
from tqdm import tqdm
import time

def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots the images in a uniform grid.
    '''
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.axis('off')
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())

# Use the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae = VAE().to(device)
vae_opt = torch.optim.Adam(vae.parameters(), lr=0.002)
for epoch in range(10):
    print(f"Epoch {epoch}")
    time.sleep(0.5)
    for images, _ in tqdm(train_dataloader):
        images = images.to(device)
        vae_opt.zero_grad()  # Clear out the gradients
        # recon_images: (batch_size, 1, 28, 28)
        # encoding: Normal whose mean and stddev each have shape (batch_size, z_dim)
        recon_images, encoding = vae(images)
        # Negative ELBO for the batch: summed BCE plus the per-image KL terms summed over the batch
        loss = reconstruction_loss(recon_images, images) + kl_divergence_loss(encoding).sum()
        loss.backward()
        vae_opt.step()
    plt.subplot(1, 2, 1)
    show_tensor_images(images)
    plt.title("True")
    plt.subplot(1, 2, 2)
    show_tensor_images(recon_images)
    plt.title("Reconstructed")
    plt.show()
```
