# VAE

> In this module, we train a variational autoencoder (VAE).

In [1]:
#| default_exp fashion_unet
# NOTE(review): this notebook is about the VAE, yet its `#|export` cells are
# directed to the `fashion_unet` module. Confirm this is intentional and not a
# copy-paste leftover that could clash with another notebook exporting there.

In [11]:
# |export
import torch.nn.functional as F
from torch import nn

from slowai.attention import ConditionalTAUnet, conditional_train
from slowai.cos_revisited import aesthetics
from slowai.super_rez import KaimingMixin
from slowai.tinyimagenet_a import denorm, get_imagenet_dls
from slowai.utils import show_images

In [5]:
# NOTE(review): `aesthetics` is imported from slowai.cos_revisited; presumably it
# applies notebook display/plotting defaults — confirm. This cell is not exported.
aesthetics()

<function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>

In [15]:
class Lin(nn.Module):
    """Fully connected layer: ``nn.Linear`` -> optional ReLU -> batch norm.

    Batch normalization runs over dim 1, so inputs are expected to be
    2-D ``(batch, c_in)`` tensors.

    Args:
        c_in: number of input features.
        c_out: number of output features (also the batch-norm width).
        bias: whether the linear projection has a bias term.
        activation: if True, apply ReLU between the linear projection and
            the batch norm; if False, the projection feeds the norm directly.
    """

    def __init__(self, c_in, c_out, bias=True, activation=True):
        super().__init__()
        self.lin = nn.Linear(c_in, c_out, bias=bias)
        self.activation = activation
        # An owned BatchNorm1d module (instead of a bare F.batch_norm call,
        # which requires explicit running statistics) keeps affine parameters
        # and running mean/var, so the layer also works in eval mode.
        self.norm = nn.BatchNorm1d(c_out)

    def forward(self, x):
        x = self.lin(x)
        # `activation` is a boolean flag, not a callable.
        if self.activation:
            x = F.relu(x)
        return self.norm(x)

In [16]:
class AE(nn.Module, KaimingMixin):
    """Fully connected autoencoder assembled from ``Lin`` layers.

    The encoder maps ``c_in -> c_hidden -> c_hidden -> c_bottleneck`` and the
    decoder mirrors it back to ``c_in``; the final decoder layer is built with
    ``activation=False``.

    Args:
        c_in: number of input features.
        c_hidden: width of the hidden layers.
        c_bottleneck: width of the latent code.
    """

    def __init__(self, c_in, c_hidden, c_bottleneck):
        super().__init__()
        enc_dims = (c_in, c_hidden, c_hidden, c_bottleneck)
        self.encoder = nn.Sequential(
            *[Lin(n_in, n_out) for n_in, n_out in zip(enc_dims, enc_dims[1:])]
        )
        # The decoder walks the same widths in reverse; its last layer skips
        # the activation.
        dec_dims = enc_dims[::-1]
        *inner_pairs, (last_in, last_out) = zip(dec_dims, dec_dims[1:])
        self.decoder = nn.Sequential(
            *[Lin(n_in, n_out) for n_in, n_out in inner_pairs],
            Lin(last_in, last_out, activation=False),
        )

    def forward(self, x):
        """Return ``(latent, reconstruction)`` for input ``x``."""
        latent = self.encoder(x)
        reconstruction = self.decoder(latent)
        return latent, reconstruction

In [None]:
#| hide
# Run nbdev's export step so cells marked `#|export` are written into the
# library's Python modules.
import nbdev

nbdev.nbdev_export()