In [3]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F

In [4]:
class VAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    The encoder produces ``2 * latent_dim`` numbers per sample; ``forward``
    splits them into a mean and a log-variance, draws a latent sample with the
    reparameterization trick and decodes it back to an image whose values are
    passed through a sigmoid.
    """

    def __init__(self, input_dim, latent_dim, hidden_dim=200):
        """
        Build the encoder and decoder stacks.

        Args:
            input_dim: accepted for interface compatibility; not used by the
                layers below (spatial sizes are fixed for 28x28 inputs).
            latent_dim: number of latent dimensions.
            hidden_dim: accepted for interface compatibility; not used.

        Input shape: (bs, 1, 28, 28)
        """
        super().__init__()

        # Spatial size per layer for a 28x28 input is noted on the right.
        encoder_layers = [
            nn.Conv2d(1, 32, kernel_size=3, stride=2),   # 28 -> 13
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2),  # 13 -> 6
            nn.ReLU(),
            nn.Flatten(),
            # First half of the output is the mean, second half the log-variance.
            nn.Linear(64 * 6 * 6, 2 * latent_dim),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Linear(latent_dim, 7 * 7 * 32),
            nn.ReLU(),
            nn.Unflatten(1, (32, 7, 7)),
            nn.ConvTranspose2d(32, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # 7 -> 14
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # 14 -> 28
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=3, stride=1, padding=1),  # 28 -> 28
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def reparameterize(self, mu, log_var):
        """Return ``mu + eps * exp(0.5 * log_var)`` with ``eps`` drawn from N(0, I)."""
        sigma = (0.5 * log_var).exp()
        noise = th.randn_like(sigma)
        return mu + noise * sigma

    def forward(self, x):
        """Encode ``x``, sample a latent vector and decode it.

        Returns:
            Tuple ``(reconstruction, mu, log_var)``.
        """
        stats = self.encoder(x)
        # Split the last dimension into two equal halves: mean and log-variance.
        mu, log_var = stats.chunk(2, dim=-1)
        latent = self.reparameterize(mu, log_var)
        reconstruction = self.decoder(latent)
        return reconstruction, mu, log_var

In [5]:
# Use the GPU when CUDA is usable, otherwise fall back to the CPU.
# (To force CPU execution, assign th.device('cpu') to `device` instead.)
if th.cuda.is_available():
    device = th.device('cuda')
else:
    device = th.device('cpu')

# VAE with a 2-dimensional latent space, moved to the selected device.
model = VAE(input_dim=28 * 28, latent_dim=2).to(device)


In [1]:
! pip install torchvision

Defaulting to user installation because normal site-packages is not writeable
Collecting torchvision
  Using cached torchvision-0.15.2-cp39-cp39-win_amd64.whl (1.2 MB)
Installing collected packages: torchvision
Successfully installed torchvision-0.15.2


In [6]:
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# MNIST training split stored under ./dataset/ (downloaded when requested via
# download=True); transforms.ToTensor() is applied to every image.
dataset = datasets.MNIST(
    root='./dataset/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# Shuffled mini-batches of 128 samples, loaded by 4 worker processes into
# pinned host memory.
train_loader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

In [7]:
import matplotlib.pyplot as plt

def show_image(img):
    """Display ``img`` as a grayscale image with a colorbar.

    Args:
        img: tensor supporting ``.squeeze().numpy()``; size-1 dimensions are
            dropped before plotting.
    """
    pixels = img.squeeze().numpy()
    plt.imshow(pixels, cmap='gray')
    plt.colorbar()
    plt.show()

# Preview the image tensor of the first (image, label) pair in the dataset.
show_image(img=dataset[0][0])

: 

: 

In [7]:
from tqdm import tqdm

def train(model: nn.Module, optimizer: th.optim.Optimizer, reconst_loss: nn.Module, train_loader: DataLoader, epochs: int, device: th.device) -> float:
    """Train a VAE-style model by minimising reconstruction loss plus KL divergence.

    Args:
        model: module whose forward pass returns ``(output, mu, log_var)``.
        optimizer: optimizer updating ``model``'s parameters.
        reconst_loss: callable ``reconst_loss(output, data)`` giving the
            reconstruction term.
        train_loader: iterable of ``(data, label)`` batches; labels are ignored.
        epochs: number of passes over ``train_loader``.
        device: device each batch is moved to before the forward pass.

    Returns:
        Mean per-batch loss of the last epoch that processed at least one
        batch, or ``0.0`` if no batch was ever processed (``epochs == 0`` or an
        empty loader). Earlier epochs no longer inflate the value.
    """
    model.train()
    mean_batch_loss = 0.0
    for epoch in range(epochs):
        desc = f'Epoch {epoch + 1}/{epochs}'
        # Wrap the loader itself (not enumerate(...)) so tqdm can read its
        # length and show a proper total / ETA.
        progress_bar = tqdm(train_loader, desc=desc, leave=False)
        epoch_loss = 0.0
        num_batches = 0
        for data, _ in progress_bar:
            data = data.to(device)
            optimizer.zero_grad()
            output, mu, log_var = model(data)
            rec_loss = reconst_loss(output, data)
            # KL divergence between N(mu, exp(log_var)) and N(0, I), summed over
            # every element in the batch.
            kl_div = -0.5 * th.sum(1 + log_var - mu.pow(2) - log_var.exp())
            loss = rec_loss + kl_div
            loss.backward()
            optimizer.step()
            # Read the scalar once; each .item() call forces a device sync.
            batch_loss = loss.item()
            epoch_loss += batch_loss
            num_batches += 1
            progress_bar.set_postfix({'loss': batch_loss / len(data)})
        # Counting batches here (instead of dividing by len(train_loader))
        # keeps an empty loader from raising ZeroDivisionError.
        if num_batches:
            mean_batch_loss = epoch_loss / num_batches
    return mean_batch_loss

In [8]:
# Fresh VAE with a 2-dimensional latent space on the selected device.
model = VAE(input_dim=28 * 28, latent_dim=2).to(device)

# Adam over all model parameters.
optimizer = th.optim.Adam(params=model.parameters(), lr=1e-3)

# Binary cross-entropy summed (not averaged) over every element of the batch.
loss_fn = nn.BCELoss(reduction='sum')

In [9]:
# Ten passes over the training loader; the scalar returned by `train` is kept
# in `final_loss`.
final_loss = train(
    model=model,
    optimizer=optimizer,
    reconst_loss=loss_fn,
    train_loader=train_loader,
    epochs=10,
    device=device,
)

                                               

In [1]:
def plot_latent_images(model, n, digit_size=28):
    """Plot an ``n`` x ``n`` grid of images decoded from a 2-D latent space.

    Latent coordinates along each axis are the standard-normal quantiles at
    ``n`` evenly spaced probabilities from 0.05 to 0.95. The tile at grid row
    ``i`` and column ``j`` shows ``model.decoder([[grid[j], grid[i]]])``.

    Args:
        model: ``nn.Module`` exposing a ``decoder`` that maps a ``(batch, 2)``
            tensor to ``batch`` images of ``digit_size * digit_size`` values each.
        n: number of grid points per latent axis.
        digit_size: side length, in pixels, of one decoded image.

    Returns:
        None. The assembled grid is shown with matplotlib.
    """
    # Decode on the device the model's parameters live on instead of relying on
    # a notebook-level `device` variable that may be stale or undefined.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else th.device('cpu')

    quantiles = th.linspace(0.05, 0.95, n)
    grid = th.distributions.Normal(0, 1).icdf(quantiles)

    # Row-major list of latent points: entry i * n + j is (grid[j], grid[i]),
    # i.e. the first coordinate varies along columns, the second along rows.
    z = th.stack([grid.repeat(n), grid.repeat_interleave(n)], dim=1).to(target_device)

    # One batched decoder call; plotting needs no autograd graph.
    with th.no_grad():
        decoded = model.decoder(z)

    # (n*n, ...) -> (row, col, h, w) -> (row, h, col, w) -> single 2-D image.
    tiles = decoded.reshape(n, n, digit_size, digit_size)
    image = tiles.permute(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)

    plt.imshow(image.detach().cpu().numpy(), cmap='gray')
    plt.axis('off')
    plt.show()

# Needs the trained `model` from the cells above; running this cell in a fresh
# kernel fails with NameError.
plot_latent_images(model, n=20)

NameError: name 'model' is not defined