In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torch import Tensor
from torchvision.utils import save_image
from tqdm import tqdm
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.utils.data import random_split
from CustomDataset import CustomDataset
from PIL import Image
import wandb


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
INPUT_DIM = 3          # input image channels (passed as VanillaVAE's in_channels)
# NOTE(review): "38804" was noted here without explanation — possibly an earlier Z_DIM value.
Z_DIM = 1000           # latent dimensionality (VanillaVAE's latent_dim)
NUM_EPOCHS = 1000
BATCH_SIZE = 32
LR_RATE = 3e-4
KL_COEFF = 0.0000025   # weight of the KL term in VanillaVAE.loss_function
PATH = "model.pt"      # checkpoint file written by train() and read on resume

# Weight init, the reparameterization noise and DataLoader shuffling are all
# random; seed once so a Restart & Run All is reproducible.
SEED = 0
torch.manual_seed(SEED)  # seeds the RNG on all devices (CPU and CUDA)

In [22]:
wandb.init(
    # set the wandb project where this run will be logged
    project="dotreniranje",

    # track hyperparameters and run metadata: record every tunable constant
    # from the config cell so runs can be compared on the wandb dashboard
    config={
        "learning_rate": LR_RATE,
        "architecture": "VAE",
        "dataset": "CELEBA",
        "epochs": NUM_EPOCHS,
        "batch_size": BATCH_SIZE,
        "latent_dim": Z_DIM,
        "kl_coeff": KL_COEFF,
    },
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mivanjevtic501[0m ([33mracunarski-fakultet[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [23]:
class VanillaVAE(nn.Module):
    """Convolutional variational auto-encoder.

    Encoder: one stride-2 ``Conv2d -> BatchNorm2d -> LeakyReLU`` stage per entry
    of ``hidden_dims``. The final feature map ``[hidden_dims[-1], *feature_hw]``
    is flattened and projected to the mean and log-variance of a diagonal
    Gaussian over a ``latent_dim``-dimensional latent space.

    Decoder: a linear projection back to ``[hidden_dims[-1], *feature_hw]``,
    then stride-2 ``ConvTranspose2d`` stages mirroring the encoder, and a final
    3-channel ``Tanh`` output.
    """

    def __init__(self,
                 in_channels: int,
                 latent_dim: int,
                 hidden_dims: list = None,
                 feature_hw: tuple = (7, 6),
                 **kwargs) -> None:
        """
        :param in_channels: (int) Number of channels of the input images.
        :param latent_dim: (int) Dimensionality of the latent space.
        :param hidden_dims: (list) Channel count of each encoder stage. Defaults
            to ``[32, 64, 128, 256, 512]``. The caller's list is copied and
            never mutated.
        :param feature_hw: (tuple) ``(height, width)`` of the encoder's last
            feature map. Each encoder stage maps a side ``s`` to ``ceil(s / 2)``,
            so the default ``(7, 6)`` matches five stages on 218x178 or 224x192
            inputs. The decoder output side is ``feature_hw`` times
            ``2 ** len(hidden_dims)``, e.g. 224x192 for the defaults.
        :param kwargs: Ignored; accepted for interface compatibility.
        """
        super().__init__()

        self.latent_dim = latent_dim

        # Copy so a caller-supplied list is never reversed in place below.
        hidden_dims = [32, 64, 128, 256, 512] if hidden_dims is None else list(hidden_dims)
        self.feature_hw = tuple(feature_hw)
        self.last_hidden_dim = hidden_dims[-1]
        flat_dim = self.last_hidden_dim * self.feature_hw[0] * self.feature_hw[1]

        # NOTE: submodules are registered in the same order as before
        # (encoder, fc_mu, fc_var, decoder_input, decoder, final_layer) so that
        # parameter order — and therefore saved optimizer state — is unchanged.

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_dim, latent_dim)
        self.fc_var = nn.Linear(flat_dim, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, flat_dim)

        decoder_dims = hidden_dims[::-1]
        modules = []
        for i in range(len(decoder_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(decoder_dims[i],
                                       decoder_dims[i + 1],
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),  # exactly doubles H and W
                    nn.BatchNorm2d(decoder_dims[i + 1]),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(decoder_dims[-1],
                               decoder_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(decoder_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(decoder_dims[-1], out_channels=3,
                      kernel_size=3, padding=1),
            # NOTE(review): Tanh outputs lie in [-1, 1]; if inputs come from
            # ToTensor they lie in [0, 1] — confirm this target range is intended.
            nn.Tanh())

    def encode(self, input: Tensor) -> list[Tensor]:
        """
        Encodes the input by passing it through the encoder network and
        returns the parameters of the latent Gaussian.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (list[Tensor]) ``[mu, log_var]``, each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and log-variance components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x 3 x H x W], H x W = feature_hw * 2 ** len(hidden_dims)
        """
        result = self.decoder_input(z)
        # Un-flatten to the encoder's last feature-map shape.
        result = result.view(-1, self.last_hidden_dim, *self.feature_hw)
        result = self.decoder(result)
        return self.final_layer(result)

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick: sample from N(mu, exp(logvar)) using noise
        drawn from N(0, 1), so gradients flow through mu and logvar.
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> list[Tensor]:
        """
        :param input: (Tensor) [B x C x H x W]
        :return: ``[reconstruction, input, mu, log_var]``
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self,
                      *args,
                      **kwargs) -> dict:
        r"""
        Computes the VAE loss terms.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: ``(recons, input, mu, log_var)`` as returned by ``forward``.
        :param kwargs: optional ``kld_weight`` (float) overriding the global ``KL_COEFF``.
        :return: dict with ``'loss'`` = ``(recons_loss, weighted_kld)`` (the caller
            sums them), ``'Reconstruction_Loss'`` (detached) and ``'KLD'`` (the
            weighted KL term, negated and detached).
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        log_var = args[3]

        # Look up the global only when no override is given.
        kld_weight = kwargs['kld_weight'] if 'kld_weight' in kwargs else KL_COEFF
        recons_loss = F.mse_loss(recons, input)

        # Closed-form KL(N(mu, exp(log_var)) || N(0, I)): summed over latent
        # dims, averaged over the batch, then weighted.
        kld_loss = kld_weight * torch.mean(
            -0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)

        return {'loss': (recons_loss, kld_loss),
                'Reconstruction_Loss': recons_loss.detach(),
                'KLD': -kld_loss.detach()}

    def sample(self,
               num_samples: int,
               current_device: int, **kwargs) -> Tensor:
        """
        Samples latent codes from N(0, I) and returns the corresponding
        image-space outputs.
        :param num_samples: (Int) Number of samples
        :param current_device: device (index or torch.device) to run the model on
        :return: (Tensor) [num_samples x 3 x H x W]
        """
        z = torch.randn(num_samples, self.latent_dim)
        z = z.to(current_device)
        return self.decode(z)

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image.
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x 3 x H' x W']
        """
        return self.forward(x)[0]

In [24]:
# Single source of truth for the batch size: the config cell.
batch_size = BATCH_SIZE

# Debug subset: only the first `data_length` images are loaded
# (202599 was noted as the full-run value).
data_length = 10
image_names = [f"{i:06d}.jpg" for i in range(1, data_length + 1)]  # 000001.jpg, 000002.jpg, ...
dataset = CustomDataset("data/img_align_celeba", image_names, transform=transforms.ToTensor())

# NOTE: train and validation currently use the SAME images. For a real split use
# random_split(dataset, [n_train, data_length - n_train]).
dataset_train, dataset_val = dataset, dataset

train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True)
# Evaluation order does not matter, so keep validation deterministic.
validation_loader = DataLoader(dataset_val, batch_size=batch_size, shuffle=False)

0


In [25]:
def train(num_epochs, model, optimizer, loss_fn):
    """Train ``model`` on the global ``train_loader`` for ``num_epochs`` epochs.

    Each step runs a forward pass and calls
    ``loss_fn(recons, x, mu, log_var)['loss']``, which must return
    ``(reconstruction_loss, weighted_kl)``. It backpropagates their sum and
    takes an optimizer step.

    After every epoch the per-batch means of the total, reconstruction and KL
    losses are logged to wandb. A checkpoint (epoch index, model/optimizer
    state dicts, summed epoch loss) is written to ``PATH`` on every even epoch
    index and after the final epoch.

    :param num_epochs: number of epochs to run (the epoch counter starts at 0).
    :param model: module returning ``(recons, input, mu, log_var)``.
    :param optimizer: optimizer over ``model``'s parameters.
    :param loss_fn: loss callable as described above.
    """
    model.train()  # make sure BatchNorm uses batch statistics during training
    for epoch in range(num_epochs):
        print(f"Epoch {epoch+1} of {num_epochs}")
        loop = tqdm(train_loader)
        epoch_loss = 0.0
        epoch_reconst_loss = 0.0
        epoch_kl_div = 0.0
        num_batches = 0
        for x in loop:
            # Forward pass
            x = x.to(device)
            x_reconst, _, mu, log_var = model(x)

            # loss, formulas from https://www.youtube.com/watch?v=igP03FXZqgo&t=2182s
            reconst_loss, kl_div = loss_fn(x_reconst, x, mu, log_var)['loss']

            # Backprop and optimize; kl_div is already weighted inside loss_fn.
            loss = reconst_loss + kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches from the graph so accumulators hold plain floats.
            epoch_loss += loss.item()
            epoch_reconst_loss += reconst_loss.item()
            epoch_kl_div += kl_div.item()
            num_batches += 1
            loop.set_postfix(loss=loss.item())

        # Log epoch means (not just the last batch); guard against an empty loader.
        denom = max(num_batches, 1)
        wandb.log({"total_loss": epoch_loss / denom,
                   "reconst_loss": epoch_reconst_loss / denom,
                   "kl_div": epoch_kl_div / denom})

        # Checkpoint periodically, and always after the last epoch so its
        # progress is not lost when that epoch's index is odd.
        if epoch % 2 == 0 or epoch == num_epochs - 1:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
                }, PATH)


# Move the model to the training device BEFORE building the optimizer, so the
# optimizer (and any state loaded into it) refers to the on-device parameters.
# train() moves each batch to `device`, so the model must live there too.
model = VanillaVAE(INPUT_DIM, Z_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)

# Resume from PATH if a checkpoint exists; otherwise start from scratch so the
# notebook also runs on a fresh checkout. map_location lets a checkpoint saved
# on GPU load on a CPU-only machine (and vice versa).
try:
    checkpoint = torch.load(PATH, map_location=device)
except FileNotFoundError:
    checkpoint = None

if checkpoint is not None:
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    loss = checkpoint['loss']
else:
    epoch, loss = None, None  # no checkpoint: fresh run

# NOTE(review): train() restarts its epoch counter at 0; the resumed `epoch`
# is informational only.
train(NUM_EPOCHS, model, optimizer, model.loss_function)

Epoch 1 of 200


1it [00:01,  1.77s/it, loss=0.00562]


Epoch 2 of 200


1it [00:01,  1.45s/it, loss=0.00568]


Epoch 3 of 200


1it [00:01,  1.08s/it, loss=0.00568]


Epoch 4 of 200


1it [00:01,  1.07s/it, loss=0.00544]


Epoch 5 of 200


1it [00:01,  1.12s/it, loss=0.00549]


Epoch 6 of 200


1it [00:01,  1.12s/it, loss=0.00557]


Epoch 7 of 200


1it [00:01,  1.07s/it, loss=0.00548]


Epoch 8 of 200


1it [00:01,  1.08s/it, loss=0.00556]


Epoch 9 of 200


1it [00:01,  1.14s/it, loss=0.00543]


Epoch 10 of 200


1it [00:01,  1.14s/it, loss=0.0054]


Epoch 11 of 200


1it [00:01,  1.10s/it, loss=0.00529]


Epoch 12 of 200


1it [00:01,  1.15s/it, loss=0.00538]


Epoch 13 of 200


1it [00:01,  1.16s/it, loss=0.00544]


Epoch 14 of 200


1it [00:01,  1.03s/it, loss=0.00531]


Epoch 15 of 200


1it [00:01,  1.06s/it, loss=0.00531]


Epoch 16 of 200


1it [00:01,  1.51s/it, loss=0.00524]


Epoch 17 of 200


1it [00:01,  1.14s/it, loss=0.00528]


Epoch 18 of 200


1it [00:01,  1.13s/it, loss=0.00531]


Epoch 19 of 200


1it [00:01,  1.51s/it, loss=0.00525]


Epoch 20 of 200


1it [00:01,  1.20s/it, loss=0.00525]


Epoch 21 of 200


1it [00:01,  1.53s/it, loss=0.00528]


Epoch 22 of 200


1it [00:01,  1.43s/it, loss=0.00511]


Epoch 23 of 200


1it [00:01,  1.27s/it, loss=0.00526]


Epoch 24 of 200


1it [00:01,  1.56s/it, loss=0.00517]


Epoch 25 of 200


1it [00:01,  1.40s/it, loss=0.00532]


Epoch 26 of 200


1it [00:01,  1.32s/it, loss=0.00504]


Epoch 27 of 200


1it [00:01,  1.45s/it, loss=0.00514]


Epoch 28 of 200


1it [00:01,  1.17s/it, loss=0.00512]


Epoch 29 of 200


1it [00:01,  1.06s/it, loss=0.00518]


Epoch 30 of 200


1it [00:01,  1.02s/it, loss=0.00538]


Epoch 31 of 200


1it [00:01,  1.02s/it, loss=0.00512]


Epoch 32 of 200


1it [00:01,  1.06s/it, loss=0.00501]


Epoch 33 of 200


1it [00:01,  1.02s/it, loss=0.00511]


Epoch 34 of 200


1it [00:01,  1.07s/it, loss=0.00496]


Epoch 35 of 200


1it [00:01,  1.07s/it, loss=0.00504]


Epoch 36 of 200


1it [00:01,  1.16s/it, loss=0.00504]


Epoch 37 of 200


1it [00:01,  1.14s/it, loss=0.00502]


Epoch 38 of 200


1it [00:01,  1.13s/it, loss=0.00501]


Epoch 39 of 200


1it [00:01,  1.63s/it, loss=0.00495]


Epoch 40 of 200


1it [00:01,  1.13s/it, loss=0.00501]


Epoch 41 of 200


1it [00:01,  1.15s/it, loss=0.00496]


Epoch 42 of 200


1it [00:01,  1.34s/it, loss=0.005]


Epoch 43 of 200


1it [00:01,  1.37s/it, loss=0.00504]


Epoch 44 of 200


1it [00:01,  1.17s/it, loss=0.00489]


Epoch 45 of 200


1it [00:01,  1.11s/it, loss=0.0049]


Epoch 46 of 200


1it [00:01,  1.21s/it, loss=0.00482]


Epoch 47 of 200


1it [00:01,  1.52s/it, loss=0.00483]


Epoch 48 of 200


1it [00:01,  1.14s/it, loss=0.00497]


Epoch 49 of 200


1it [00:01,  1.06s/it, loss=0.00481]


Epoch 50 of 200


1it [00:01,  1.00s/it, loss=0.0048]


Epoch 51 of 200


1it [00:00,  1.02it/s, loss=0.00489]


Epoch 52 of 200


1it [00:00,  1.04it/s, loss=0.00474]


Epoch 53 of 200


1it [00:00,  1.04it/s, loss=0.0049]


Epoch 54 of 200


1it [00:00,  1.06it/s, loss=0.00477]


Epoch 55 of 200


1it [00:00,  1.03it/s, loss=0.00477]


Epoch 56 of 200


1it [00:00,  1.06it/s, loss=0.00487]


Epoch 57 of 200


1it [00:01,  1.17s/it, loss=0.0048]


Epoch 58 of 200


1it [00:00,  1.01it/s, loss=0.00478]


Epoch 59 of 200


1it [00:01,  1.00s/it, loss=0.00464]


Epoch 60 of 200


1it [00:00,  1.02it/s, loss=0.00478]


Epoch 61 of 200


1it [00:01,  1.02s/it, loss=0.00471]


Epoch 62 of 200


1it [00:01,  1.12s/it, loss=0.00463]


Epoch 63 of 200


1it [00:01,  1.02s/it, loss=0.00464]


Epoch 64 of 200


1it [00:01,  1.04s/it, loss=0.00479]


Epoch 65 of 200


1it [00:00,  1.03it/s, loss=0.00466]


Epoch 66 of 200


1it [00:01,  1.10s/it, loss=0.00473]


Epoch 67 of 200


1it [00:01,  1.00s/it, loss=0.00469]


Epoch 68 of 200


1it [00:00,  1.04it/s, loss=0.00467]


Epoch 69 of 200


1it [00:00,  1.03it/s, loss=0.00462]


Epoch 70 of 200


1it [00:00,  1.04it/s, loss=0.00472]


Epoch 71 of 200


1it [00:00,  1.05it/s, loss=0.0046]


Epoch 72 of 200


1it [00:00,  1.05it/s, loss=0.00463]


Epoch 73 of 200


1it [00:00,  1.02it/s, loss=0.00466]


Epoch 74 of 200


1it [00:00,  1.02it/s, loss=0.0046]


Epoch 75 of 200


1it [00:00,  1.02it/s, loss=0.00463]


Epoch 76 of 200


1it [00:00,  1.01it/s, loss=0.00448]


Epoch 77 of 200


1it [00:00,  1.02it/s, loss=0.00464]


Epoch 78 of 200


1it [00:00,  1.03it/s, loss=0.00457]


Epoch 79 of 200


1it [00:00,  1.02it/s, loss=0.00453]


Epoch 80 of 200


1it [00:01,  1.10s/it, loss=0.00452]


Epoch 81 of 200


1it [00:00,  1.05it/s, loss=0.00451]


Epoch 82 of 200


1it [00:00,  1.05it/s, loss=0.00444]


Epoch 83 of 200


1it [00:00,  1.03it/s, loss=0.00451]


Epoch 84 of 200


1it [00:00,  1.02it/s, loss=0.00444]


Epoch 85 of 200


1it [00:00,  1.02it/s, loss=0.00444]


Epoch 86 of 200


1it [00:00,  1.04it/s, loss=0.00436]


Epoch 87 of 200


1it [00:00,  1.05it/s, loss=0.00438]


Epoch 88 of 200


1it [00:00,  1.02it/s, loss=0.00438]


Epoch 89 of 200


1it [00:00,  1.04it/s, loss=0.00442]


Epoch 90 of 200


1it [00:00,  1.02it/s, loss=0.0044]


Epoch 91 of 200


1it [00:00,  1.02it/s, loss=0.0044]


Epoch 92 of 200


1it [00:00,  1.02it/s, loss=0.00431]


Epoch 93 of 200


1it [00:00,  1.02it/s, loss=0.00427]


Epoch 94 of 200


1it [00:00,  1.02it/s, loss=0.00444]


Epoch 95 of 200


1it [00:00,  1.03it/s, loss=0.00432]


Epoch 96 of 200


1it [00:00,  1.05it/s, loss=0.00444]


Epoch 97 of 200


1it [00:00,  1.06it/s, loss=0.00438]


Epoch 98 of 200


1it [00:00,  1.07it/s, loss=0.00441]


Epoch 99 of 200


1it [00:00,  1.04it/s, loss=0.0043]


Epoch 100 of 200


1it [00:01,  1.97s/it, loss=0.0043]


Epoch 101 of 200


1it [00:01,  1.36s/it, loss=0.00437]


Epoch 102 of 200


1it [00:01,  1.42s/it, loss=0.00426]


Epoch 103 of 200


1it [00:01,  1.21s/it, loss=0.00447]


Epoch 104 of 200


1it [00:01,  1.24s/it, loss=0.00432]


Epoch 105 of 200


1it [00:01,  1.13s/it, loss=0.00433]


Epoch 106 of 200


1it [00:01,  1.09s/it, loss=0.00424]


Epoch 107 of 200


1it [00:01,  1.07s/it, loss=0.0043]


Epoch 108 of 200


1it [00:01,  1.01s/it, loss=0.00435]


Epoch 109 of 200


1it [00:01,  1.04s/it, loss=0.00425]


Epoch 110 of 200


1it [00:01,  1.35s/it, loss=0.00424]


Epoch 111 of 200


1it [00:01,  1.30s/it, loss=0.00428]


Epoch 112 of 200


1it [00:01,  1.13s/it, loss=0.00431]


Epoch 113 of 200


1it [00:01,  1.25s/it, loss=0.00427]


Epoch 114 of 200


1it [00:01,  1.15s/it, loss=0.00425]


Epoch 115 of 200


1it [00:01,  1.09s/it, loss=0.00415]


Epoch 116 of 200


1it [00:01,  1.06s/it, loss=0.00416]


Epoch 117 of 200


1it [00:01,  1.05s/it, loss=0.00426]


Epoch 118 of 200


1it [00:01,  1.04s/it, loss=0.00413]


Epoch 119 of 200


1it [00:01,  1.10s/it, loss=0.00412]


Epoch 120 of 200


1it [00:01,  1.11s/it, loss=0.00413]


Epoch 121 of 200


1it [00:01,  1.13s/it, loss=0.00419]


Epoch 122 of 200


1it [00:01,  1.08s/it, loss=0.00422]


Epoch 123 of 200


1it [00:01,  1.07s/it, loss=0.00409]


Epoch 124 of 200


1it [00:01,  1.01s/it, loss=0.0042]


Epoch 125 of 200


1it [00:01,  1.12s/it, loss=0.00413]


Epoch 126 of 200


1it [00:01,  1.10s/it, loss=0.00411]


Epoch 127 of 200


1it [00:01,  1.10s/it, loss=0.00411]


Epoch 128 of 200


1it [00:01,  1.03s/it, loss=0.00413]


Epoch 129 of 200


1it [00:01,  1.06s/it, loss=0.00413]


Epoch 130 of 200


1it [00:00,  1.02it/s, loss=0.00411]


Epoch 131 of 200


1it [00:00,  1.04it/s, loss=0.00435]


Epoch 132 of 200


1it [00:01,  1.20s/it, loss=0.0041]


Epoch 133 of 200


1it [00:01,  1.07s/it, loss=0.00419]


Epoch 134 of 200


1it [00:01,  1.03s/it, loss=0.00409]


Epoch 135 of 200


1it [00:01,  1.03s/it, loss=0.00412]


Epoch 136 of 200


1it [00:01,  1.26s/it, loss=0.00405]


Epoch 137 of 200


1it [00:01,  1.27s/it, loss=0.00405]


Epoch 138 of 200


1it [00:01,  1.18s/it, loss=0.00409]


Epoch 139 of 200


1it [00:01,  1.12s/it, loss=0.00397]


Epoch 140 of 200


1it [00:01,  1.12s/it, loss=0.00396]


Epoch 141 of 200


1it [00:01,  1.11s/it, loss=0.00404]


Epoch 142 of 200


1it [00:01,  1.11s/it, loss=0.00406]


Epoch 143 of 200


1it [00:01,  1.12s/it, loss=0.00395]


Epoch 144 of 200


1it [00:01,  1.07s/it, loss=0.00399]


Epoch 145 of 200


1it [00:01,  1.26s/it, loss=0.00405]


Epoch 146 of 200


1it [00:01,  1.05s/it, loss=0.00399]


Epoch 147 of 200


1it [00:01,  1.16s/it, loss=0.00388]


Epoch 148 of 200


1it [00:01,  1.09s/it, loss=0.00395]


Epoch 149 of 200


1it [00:01,  1.20s/it, loss=0.00388]


Epoch 150 of 200


1it [00:01,  1.10s/it, loss=0.00403]


Epoch 151 of 200


1it [00:01,  1.09s/it, loss=0.00402]


Epoch 152 of 200


1it [00:01,  1.10s/it, loss=0.00392]


Epoch 153 of 200


1it [00:01,  1.13s/it, loss=0.00393]


Epoch 154 of 200


1it [00:01,  1.05s/it, loss=0.00393]


Epoch 155 of 200


1it [00:01,  1.06s/it, loss=0.00388]


Epoch 156 of 200


1it [00:01,  1.07s/it, loss=0.00403]


Epoch 157 of 200


1it [00:01,  1.07s/it, loss=0.004]


Epoch 158 of 200


1it [00:01,  1.04s/it, loss=0.00388]


Epoch 159 of 200


1it [00:01,  1.07s/it, loss=0.00389]


Epoch 160 of 200


1it [00:01,  1.06s/it, loss=0.00401]


Epoch 161 of 200


1it [00:01,  1.08s/it, loss=0.00391]


Epoch 162 of 200


1it [00:01,  1.10s/it, loss=0.00403]


Epoch 163 of 200


1it [00:01,  1.19s/it, loss=0.0039]


Epoch 164 of 200


1it [00:01,  1.11s/it, loss=0.00404]


Epoch 165 of 200


1it [00:01,  1.23s/it, loss=0.00386]


Epoch 166 of 200


1it [00:01,  1.12s/it, loss=0.00386]


Epoch 167 of 200


1it [00:01,  1.08s/it, loss=0.00379]


Epoch 168 of 200


1it [00:01,  1.06s/it, loss=0.00378]


Epoch 169 of 200


1it [00:01,  1.03s/it, loss=0.00379]


Epoch 170 of 200


1it [00:01,  1.14s/it, loss=0.00372]


Epoch 171 of 200


1it [00:01,  1.08s/it, loss=0.00369]


Epoch 172 of 200


1it [00:01,  1.06s/it, loss=0.0038]


Epoch 173 of 200


1it [00:00,  1.05it/s, loss=0.00372]


Epoch 174 of 200


1it [00:00,  1.01it/s, loss=0.00371]


Epoch 175 of 200


1it [00:00,  1.01it/s, loss=0.00365]


Epoch 176 of 200


1it [00:00,  1.07it/s, loss=0.0037]


Epoch 177 of 200


1it [00:00,  1.06it/s, loss=0.00378]


Epoch 178 of 200


1it [00:00,  1.01it/s, loss=0.00377]


Epoch 179 of 200


1it [00:00,  1.05it/s, loss=0.00378]


Epoch 180 of 200


1it [00:00,  1.04it/s, loss=0.00378]


Epoch 181 of 200


1it [00:00,  1.04it/s, loss=0.00363]


Epoch 182 of 200


1it [00:00,  1.04it/s, loss=0.00382]


Epoch 183 of 200


1it [00:00,  1.04it/s, loss=0.00367]


Epoch 184 of 200


1it [00:00,  1.03it/s, loss=0.00372]


Epoch 185 of 200


1it [00:00,  1.01it/s, loss=0.00375]


Epoch 186 of 200


1it [00:01,  1.61s/it, loss=0.00369]


Epoch 187 of 200


1it [00:01,  1.12s/it, loss=0.00373]


Epoch 188 of 200


1it [00:01,  1.08s/it, loss=0.00361]


Epoch 189 of 200


1it [00:01,  1.12s/it, loss=0.00377]


Epoch 190 of 200


1it [00:01,  1.01s/it, loss=0.00359]


Epoch 191 of 200


1it [00:01,  1.07s/it, loss=0.00371]


Epoch 192 of 200


1it [00:01,  1.13s/it, loss=0.00365]


Epoch 193 of 200


1it [00:01,  1.14s/it, loss=0.00368]


Epoch 194 of 200


1it [00:01,  1.03s/it, loss=0.00363]


Epoch 195 of 200


1it [00:01,  1.01s/it, loss=0.00365]


Epoch 196 of 200


1it [00:01,  1.05s/it, loss=0.00363]


Epoch 197 of 200


1it [00:01,  1.17s/it, loss=0.00352]


Epoch 198 of 200


1it [00:01,  1.22s/it, loss=0.00368]


Epoch 199 of 200


1it [00:01,  1.27s/it, loss=0.00368]


Epoch 200 of 200


1it [00:01,  1.30s/it, loss=0.00359]


In [26]:
def test_inference(image_path="data/img_align_celeba/000001.jpg",
                   out_path="generated_ex.png"):
    """Encode one image with the global ``model``, decode a latent sample, save it.

    The latent code is drawn from ``N(mu, exp(log_var))`` via
    ``model.reparameterize``. ``encode`` returns a LOG-VARIANCE, not a standard
    deviation. The model is put in eval mode for the pass so BatchNorm uses its
    running statistics instead of single-image batch statistics, and its
    previous train/eval mode is restored afterwards.

    :param image_path: image to encode.
    :param out_path: where the decoded image is written (via ``save_image``).
    """
    transform = transforms.ToTensor()
    # Context manager closes the underlying file; convert guarantees 3 channels.
    with Image.open(image_path) as image:
        x = transform(image.convert("RGB")).unsqueeze(0)

    was_training = model.training
    model.eval()
    try:
        model_device = next(model.parameters()).device
        with torch.no_grad():
            mu, log_var = model.encode(x.to(model_device))
            z = model.reparameterize(mu, log_var)
            out = model.decode(z)
    finally:
        model.train(was_training)

    # NOTE(review): the decoder ends in Tanh ([-1, 1]); save_image expects
    # [0, 1] — confirm whether outputs should be rescaled before saving.
    save_image(out, out_path)

In [27]:
# Encode data/img_align_celeba/000001.jpg, decode a latent sample, and save the
# result to generated_ex.png.
test_inference()