In [4]:
!pip install pytorch-lightning --quiet

In [1]:
from google.colab import drive
# Mount Google Drive into the Colab filesystem at /gdrive.
drive.mount('/gdrive')

Mounted at /gdrive


In [39]:
import glob

import cv2
import torch
from torch.utils.data import DataLoader, Dataset

class ImageDataset(Dataset):
    """Dataset that loads image files from disk as fixed-size float tensors.

    Every item is read with OpenCV, resized to ``image_size`` and returned as
    a channel-first ``float32`` tensor scaled to ``[-1, 1]`` so that it is on
    the same scale as the output of a ``Tanh``-terminated decoder.

    :param data_path: sequence of image file paths
    :param image_size: ``(width, height)`` passed to ``cv2.resize``
    """

    def __init__(self, data_path, image_size=(64, 64)):
        self.data_path = data_path
        self.image_size = image_size

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.data_path)

    @staticmethod
    def _to_tensor(image):
        """Convert an ``H x W x C`` uint8 array to a ``C x H x W`` float tensor in [-1, 1]."""
        tensor = torch.as_tensor(image, dtype=torch.float32).moveaxis(-1, 0)
        # 0 -> -1.0 and 255 -> 1.0; raw 0..255 values would never be reachable
        # by a Tanh output and would dominate an MSE reconstruction loss.
        return tensor / 127.5 - 1.0

    def __getitem__(self, idx):
        """Load, resize and normalise the image at position ``idx``.

        :raises OSError: if OpenCV cannot read the file
        """
        path = self.data_path[idx]
        # cv2.imread signals failure by returning None instead of raising.
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            raise OSError(f"cv2.imread could not read image: {path!r}")
        image = cv2.resize(image, self.image_size)

        # NOTE: channel order is left exactly as OpenCV returns it.
        return self._to_tensor(image)

# Google Drive is mounted at /gdrive by the drive.mount call, so the images
# live under /gdrive/MyDrive (not ./drive/MyDrive, which is never created).
# sorted() keeps the train/validation split reproducible, because glob returns
# files in arbitrary order.
flist = sorted(glob.glob("/gdrive/MyDrive/img_align_celeba/*.jpg"))
if not flist:
    # Fail here with a clear message instead of training on empty loaders.
    raise FileNotFoundError(
        "No .jpg files found under /gdrive/MyDrive/img_align_celeba")

# First 800 files are used for training, every remaining file for validation.
train_flist = flist[:800]
valid_flist = flist[800:]

train_image_dataset = ImageDataset(train_flist)
valid_image_dataset = ImageDataset(valid_flist)

train_data_loader = DataLoader(train_image_dataset, batch_size=16,
                               num_workers=2, pin_memory=True, shuffle=True)

valid_data_loader = DataLoader(valid_image_dataset, batch_size=16,
                               num_workers=2, pin_memory=True, shuffle=False)

In [None]:
import torch
from torch import nn
from torch.nn import functional as F
from abc import abstractmethod

# https://github.com/AntixK/PyTorch-VAE/blob/master/run.py
class VanillaVAE(nn.Module):
    """Convolutional variational autoencoder.

    The encoder is a stack of stride-2 convolutions (one per entry of
    ``hidden_dims``) followed by two linear heads producing the mean and
    log-variance of the latent Gaussian. The decoder mirrors it with
    transposed convolutions and ends in ``Tanh``, so reconstructions lie in
    ``[-1, 1]`` and always have 3 channels.

    The linear layers assume the encoder output is a 2x2 feature map, which
    is the case for e.g. 64x64 inputs with five stride-2 stages.

    :param in_channels: number of channels of the input images
    :param latent_dim: dimensionality of the latent code
    :param hidden_dims: channel widths of the encoder stages; defaults to
        ``[32, 64, 128, 256, 512]``. The list is not modified.
    :param kwargs: accepted for call-site compatibility and ignored
    """

    def __init__(self,
                 in_channels,
                 latent_dim,
                 hidden_dims=None,
                 **kwargs):
        super().__init__()

        self.latent_dim = latent_dim

        if hidden_dims is None:
            hidden_dims = [32, 64, 128, 256, 512]
        # Work on a private copy so the caller's list is never reversed in place.
        hidden_dims = list(hidden_dims)

        # Channel count of the encoder's last feature map; decode() needs it
        # to reshape the flat vector back into a 2x2 feature map.
        self.bottleneck_channels = hidden_dims[-1]
        flat_features = self.bottleneck_channels * 4  # channels * (2 * 2)

        # Build Encoder
        encoder_blocks = []
        for h_dim in hidden_dims:
            encoder_blocks.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, out_channels=h_dim,
                              kernel_size=3, stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU())
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*encoder_blocks)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder
        self.decoder_input = nn.Linear(latent_dim, flat_features)

        decoder_dims = hidden_dims[::-1]

        decoder_blocks = []
        for dim_in, dim_out in zip(decoder_dims[:-1], decoder_dims[1:]):
            decoder_blocks.append(
                nn.Sequential(
                    nn.ConvTranspose2d(dim_in,
                                       dim_out,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=1),
                    nn.BatchNorm2d(dim_out),
                    nn.LeakyReLU())
            )

        self.decoder = nn.Sequential(*decoder_blocks)

        # One more x2 upsampling, then project to 3 channels squashed by Tanh.
        self.final_layer = nn.Sequential(
                            nn.ConvTranspose2d(decoder_dims[-1],
                                               decoder_dims[-1],
                                               kernel_size=3,
                                               stride=2,
                                               padding=1,
                                               output_padding=1),
                            nn.BatchNorm2d(decoder_dims[-1]),
                            nn.LeakyReLU(),
                            nn.Conv2d(decoder_dims[-1], out_channels=3,
                                      kernel_size=3, padding=1),
                            nn.Tanh())

    def encode(self, input):
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (list) ``[mu, log_var]``, each of shape [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Split the result into mu and var components
        # of the latent Gaussian distribution
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z):
        """
        Maps the given latent codes
        onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x C x H x W]
        """
        result = self.decoder_input(z)
        # Use the stored channel count instead of a literal 512 so that
        # non-default hidden_dims reshape correctly.
        result = result.view(-1, self.bottleneck_channels, 2, 2)
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu, log_var):
        """
        Draws a latent sample as ``mu + eps * std`` with ``eps ~ N(0, I)``.
        :param mu: (Tensor) mean of the latent Gaussian [B x D]
        :param log_var: (Tensor) log-variance of the latent Gaussian [B x D]
        :return: (Tensor) sampled latent code [B x D]
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)

        return eps * std + mu

    def forward(self, input, **kwargs):
        """
        Encodes ``input``, samples a latent code and decodes it.
        :param input: (Tensor) [N x C x H x W]
        :return: tuple ``(z, reconstruction, mu, log_var)``
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)

        return z, self.decode(z), mu, log_var


# Instantiate the VAE for 3-channel images with a 128-dimensional latent code;
# hidden_dims=None selects the default widths defined inside VanillaVAE.
model = VanillaVAE(in_channels=3, latent_dim=128, hidden_dims=None)
# Bare expression as the cell's last line so the notebook shows the module repr.
model

In [50]:
import pytorch_lightning as pl

def compute_loss(recon_image, orig_image, mu, log_var, kld_weight=0.00025):
    """Return the VAE objective: reconstruction MSE plus weighted KL term.

    :param recon_image: (Tensor) reconstructed images
    :param orig_image: (Tensor) target images, same shape as ``recon_image``
    :param mu: (Tensor) latent means, summed over dim 1
    :param log_var: (Tensor) latent log-variances, same shape as ``mu``
    :param kld_weight: scalar multiplier applied to the KL term
    :return: (Tensor) scalar loss
    """
    # Mean squared error over every element of the batch.
    reconstruction_term = F.mse_loss(recon_image, orig_image)

    # -0.5 * sum(1 + log_var - mu^2 - exp(log_var)) per sample,
    # then averaged over the batch dimension.
    per_sample_kld = -0.5 * torch.sum(
        1 + log_var - mu ** 2 - log_var.exp(), dim=1)
    kld_term = torch.mean(per_sample_kld, dim=0)

    return reconstruction_term + kld_term * kld_weight


class VAEXperiment(pl.LightningModule):
    """Lightning wrapper that trains a VAE with ``compute_loss``.

    Each batch is expected to be the image tensor itself (no labels). The
    training and validation losses are logged as ``train_loss`` and
    ``valid_loss`` so they appear in the progress bar and the logger.

    :param vae_model: module whose forward returns
        ``(z, reconstruction, mu, log_var)``
    :param params: accepted for call-site compatibility; currently unused
    """

    def __init__(self, vae_model, params=None):
        super().__init__()

        self.model = vae_model

    def forward(self, input):
        """Delegate to the wrapped model."""
        return self.model(input)

    def _shared_step(self, batch):
        """Run the model on ``batch`` and return the scalar VAE loss."""
        real_img = batch
        z, recon_image, mu, log_var = self(real_img)
        return compute_loss(recon_image, real_img, mu, log_var)

    def training_step(self, batch, batch_idx):
        """Compute, log and return the loss Lightning back-propagates."""
        train_loss = self._shared_step(batch)
        self.log("train_loss", train_loss, prog_bar=True)

        return train_loss

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss."""
        valid_loss = self._shared_step(batch)
        # Without logging, the validation loop produces no observable metric.
        self.log("valid_loss", valid_loss, prog_bar=True)

        return valid_loss

    def configure_optimizers(self):
        """Optimise the wrapped model's parameters with Adam (lr=0.005)."""
        return torch.optim.Adam(self.model.parameters(), lr=0.005)

In [51]:
# Wrap the VAE in the Lightning module and train for at most 10 epochs,
# passing the training loader first and the validation loader second.
experiment = VAEXperiment(model)
runner = pl.Trainer(max_epochs=10)
runner.fit(experiment, train_data_loader, valid_data_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: /content/lightning_logs

  | Name  | Type       | Params
-------------------------------------
0 | model | VanillaVAE | 3.9 M 
-------------------------------------
3.9 M     Trainable params
0         Non-trainable params
3.9 M     Total params
15.751    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
