In [1]:
import pytorch_lightning as pl
from matplotlib import pyplot
import os
import numpy as np
import random
import torchvision
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import urllib.parse
from torch.nn import functional as F
from torch import nn
from pytorch_lightning import LightningModule, Trainer, seed_everything
import torch
import matplotlib as plt
from pytorch_lightning.loggers import TensorBoardLogger



In [2]:
class View(nn.Module):
    """Reshape layer: exposes ``Tensor.view`` as a module usable in ``nn.Sequential``."""

    def __init__(self, shape):
        """Store the target ``shape`` that ``forward`` will unpack into ``Tensor.view``."""
        super().__init__()
        self.shape = shape

    def forward(self, x):
        """Return ``x`` viewed with the dimensions stored at construction time."""
        target_dims = self.shape
        return x.view(*target_dims)

In [3]:
def create_dataset(path_to_data, batch_size=32, crop_size=224, num_of_channels=1):
    """Build an ``ImageFolder`` dataset and a shuffling ``DataLoader`` over it.

    Each image is randomly rotated by 0-360 degrees, resized and centre-cropped
    to ``crop_size``, converted to a tensor and then to grayscale.

    Args:
        path_to_data: Root directory handed to ``ImageFolder``.
        batch_size: Number of images per batch.
        crop_size: Size passed to both ``Resize`` and ``CenterCrop``.
        num_of_channels: Number of output channels of the grayscale transform.

    Returns:
        A ``(dataset, dataloader)`` tuple.
    """
    preprocessing_steps = [
        transforms.RandomRotation(degrees=(0, 360)),
        transforms.Resize(crop_size),
        transforms.CenterCrop(crop_size),
        transforms.ToTensor(),
        transforms.Grayscale(num_output_channels=num_of_channels),
    ]
    image_folder = ImageFolder(
        root=path_to_data,
        transform=transforms.Compose(preprocessing_steps),
    )
    # Incomplete trailing batches are dropped so every batch has batch_size items.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        pin_memory=True,
        num_workers=2,
    )
    batches = DataLoader(image_folder, **loader_options)
    return image_folder, batches

In [4]:
class VAE(LightningModule):
    """Convolutional variational autoencoder for 1-channel, 224x224 images.

    The encoder reduces a ``(B, 1, 224, 224)`` batch to a flat vector of
    ``7 * 28 * 28 = 5488`` features. Two linear heads turn that vector into
    the mean and log-variance of the approximate posterior. The decoder
    reshapes a latent sample to ``(B, 7, 28, 28)`` and upsamples it back to
    ``(B, 1, 224, 224)``, ending in a sigmoid.

    NOTE(review): the decoder starts with ``View([-1, 7, 28, 28])``, so a
    ``latent_dim`` other than 5488 would change the batch dimension or fail
    to reshape; likewise ``enc_out_dim`` must match the encoder's 5488
    outputs. Keep the defaults unless the architecture is changed as well.
    """

    # Registry mapping a checkpoint name to its path/URL. It is consulted by
    # ``pretrained_weights_available`` and ``from_pretrained``; it is empty
    # until checkpoints are registered.
    pretrained_urls = {}

    def __init__(
        self,
        input_height: int,
        enc_type: str = "default",
        first_conv: bool = False,
        maxpool1: bool = False,
        enc_out_dim: int = 5488,
        kl_coeff: float = 0.01,
        latent_dim: int = 5488,
        lr: float = 1e-3,
        **kwargs,
    ):
        """Build the encoder, decoder and posterior heads.

        Args:
            input_height: Height of the input images; only 224 is supported.
            enc_type: Encoder variant; only ``"default"`` is supported.
            first_conv: Accepted for interface compatibility; not used.
            maxpool1: Accepted for interface compatibility; not used.
            enc_out_dim: Size of the flattened encoder output.
            kl_coeff: Weight applied to the KL term of the loss.
            latent_dim: Size of the latent vector.
            lr: Learning rate handed to Adam.
            **kwargs: Accepted for interface compatibility; not used.

        Raises:
            ValueError: If the ``enc_type``/``input_height`` combination has
                no architecture defined.
        """
        super().__init__()

        self.save_hyperparameters()
        self.lr = lr
        self.kl_coeff = kl_coeff
        self.enc_out_dim = enc_out_dim
        self.latent_dim = latent_dim
        self.input_height = input_height

        # Fail early: without this check the module would be created without
        # an encoder/decoder and only break later inside forward().
        if enc_type != "default" or input_height != 224:
            raise ValueError(
                "VAE only defines an architecture for enc_type='default' and "
                "input_height=224, got enc_type={!r}, input_height={!r}.".format(
                    enc_type, input_height
                )
            )

        self.encoder = nn.Sequential(
            nn.Conv2d(1, 224, (3, 3), stride=(1, 1), padding=(1, 1)),  # 224 x 224 x 224
            nn.ReLU(True),
            nn.MaxPool2d(2),  # 224 x 112 x 112

            nn.Conv2d(224, 112, (3, 3), stride=(1, 1), padding=(1, 1)),  # 112 x 112 x 112
            nn.ReLU(True),
            nn.MaxPool2d(2),  # 112 x 56 x 56

            nn.Conv2d(112, 56, (3, 3), stride=(1, 1), padding=(1, 1)),  # 56 x 56 x 56
            nn.ReLU(True),
            nn.MaxPool2d(2),  # 56 x 28 x 28

            nn.Conv2d(56, 7, (3, 3), stride=(1, 1), padding=(1, 1)),  # 7 x 28 x 28
            nn.Flatten(),  # 5488 features per sample
        )

        self.decoder = nn.Sequential(
            View([-1, 7, 28, 28]),  # 7 x 28 x 28
            nn.ConvTranspose2d(7, 56, (3, 3), stride=(1, 1), padding=(1, 1)),  # 56 x 28 x 28
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),  # 56 x 56 x 56

            nn.ConvTranspose2d(56, 112, (3, 3), stride=(1, 1), padding=(1, 1)),  # 112 x 56 x 56
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),  # 112 x 112 x 112

            nn.ConvTranspose2d(112, 224, (3, 3), stride=(1, 1), padding=(1, 1)),  # 224 x 112 x 112
            nn.ReLU(True),
            nn.Upsample(scale_factor=2),  # 224 x 224 x 224

            nn.ConvTranspose2d(224, 1, (3, 3), stride=(1, 1), padding=(1, 1)),  # 1 x 224 x 224
            nn.Sigmoid(),
        )

        self.fc_mu = nn.Linear(self.enc_out_dim, self.latent_dim)
        self.fc_var = nn.Linear(self.enc_out_dim, self.latent_dim)

    @staticmethod
    def pretrained_weights_available():
        """Return the names of all registered pretrained checkpoints."""
        return list(VAE.pretrained_urls.keys())

    def from_pretrained(self, checkpoint_name):
        """Load the registered checkpoint ``checkpoint_name`` and return the model.

        Raises:
            KeyError: If ``checkpoint_name`` is not in ``pretrained_urls``.
        """
        if checkpoint_name not in VAE.pretrained_urls:
            raise KeyError(str(checkpoint_name) + " not present in pretrained weights.")

        # load_from_checkpoint is a classmethod; recent Lightning releases
        # refuse to run it through an instance, so go through the class.
        return type(self).load_from_checkpoint(
            VAE.pretrained_urls[checkpoint_name], strict=False
        )

    def forward(self, x):
        """Encode ``x``, draw one latent sample and return its reconstruction."""
        _, x_hat, _, _ = self._run_step(x)
        return x_hat

    def _run_step(self, x):
        """Return ``(z, reconstruction, prior, posterior)`` for the batch ``x``."""
        x = self.encoder(x)
        mu = self.fc_mu(x)
        log_var = self.fc_var(x)
        p, q, z = self.sample(mu, log_var)
        return z, self.decoder(z), p, q

    def sample(self, mu, log_var):
        """Build the prior/posterior and draw a reparameterised latent sample.

        Returns:
            ``(p, q, z)``: the standard-normal prior, the posterior
            ``Normal(mu, exp(log_var / 2))`` and a sample ``z`` from ``q``.
        """
        std = torch.exp(log_var / 2)
        p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(std))
        q = torch.distributions.Normal(mu, std)
        z = q.rsample()  # rsample keeps the draw differentiable w.r.t. mu/std
        return p, q, z

    def step(self, batch, batch_idx):
        """Compute the loss for one batch.

        Returns:
            ``(loss, logs)`` where ``loss`` is the mean binary cross-entropy
            reconstruction loss plus the ``kl_coeff``-weighted mean KL
            divergence, and ``logs`` holds ``recon_loss``, ``kl`` (already
            weighted) and ``loss``.
        """
        x, y = batch  # labels are not used by the autoencoder
        z, x_hat, p, q = self._run_step(x)
        # Functional form avoids building a new BCELoss module on every step.
        recon_loss = F.binary_cross_entropy(x_hat, x)
        kl = torch.distributions.kl_divergence(q, p).mean() * self.kl_coeff

        loss = recon_loss + kl

        logs = {
            "recon_loss": recon_loss,
            "kl": kl,
            "loss": loss,
        }
        return loss, logs

    def training_step(self, batch, batch_idx):
        """Run ``step`` and log its values with a ``train_`` prefix on every step."""
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False)
        return loss

    def validation_step(self, batch, batch_idx):
        """Run ``step`` and log its values with a ``val_`` prefix."""
        loss, logs = self.step(batch, batch_idx)
        self.log_dict({f"val_{k}": v for k, v in logs.items()})
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [5]:
# Run configuration shared by the cells below.
batch_size = 8  # images per training batch
crop_size = 224  # size images are resized/cropped to; also passed to VAE as input_height
num_of_channels = 1  # grayscale output channels requested from create_dataset
path_to_data = "images/"  # root folder handed to ImageFolder (relative path)

In [6]:
# Build the image dataset and its data loader from the run configuration.
dataset, dataloader = create_dataset(
    path_to_data=path_to_data,
    batch_size=batch_size,
    crop_size=crop_size,
    num_of_channels=num_of_channels,
)

In [7]:
class GenerateCallback(pl.Callback):
    """Log and save reconstructions of a fixed image batch while training.

    Every ``every_n_epochs`` epochs the stored images are passed through the
    model. An input/reconstruction grid is sent to the trainer's logger, and
    one side-by-side PDF per image is written to
    ``Results_VAE/<run_id>/epoch_<n>/Images/``.
    """

    def __init__(self, input_imgs, run_id="", every_n_epochs=1):
        """
        Args:
            input_imgs: Batch of images to reconstruct during training.
            run_id: Sub-directory of ``Results_VAE`` that receives the PDFs.
            every_n_epochs: Only log/save every N epochs (otherwise the
                TensorBoard log gets quite large).
        """
        super().__init__()
        self.input_imgs = input_imgs
        self.every_n_epochs = every_n_epochs
        # os.path.join picks the platform's separator instead of hard-coded
        # backslashes; the trailing "" keeps a trailing separator on the path.
        self.save_path = os.path.join("Results_VAE", run_id, "")

    def on_train_epoch_end(self, trainer, pl_module, save_to="Results_VAE/"):
        """Reconstruct the stored images, log the grid and save per-image PDFs.

        ``save_to`` is kept for signature compatibility only: the output
        directory is always derived from ``self.save_path`` and the current
        epoch, so the argument has no effect.
        """
        if trainer.current_epoch % self.every_n_epochs != 0:
            return

        images_dir = os.path.join(
            self.save_path, "epoch_{e}".format(e=trainer.current_epoch), "Images"
        )

        # Reconstruct images in eval mode, then restore whatever mode the
        # module was in before instead of unconditionally forcing train mode.
        input_imgs = self.input_imgs.to(pl_module.device)
        was_training = pl_module.training
        pl_module.eval()
        with torch.no_grad():
            reconst_imgs = pl_module(input_imgs)
        pl_module.train(was_training)

        # Interleave originals and reconstructions so each grid row is a pair.
        imgs = torch.stack([input_imgs, reconst_imgs], dim=1).flatten(0, 1)
        # `value_range` replaces the `range` keyword that newer torchvision
        # releases no longer accept. (0, 1) matches the data here: the
        # callback is fed ToTensor outputs and the VAE decoder ends in a
        # sigmoid. NOTE(review): adjust if inputs are normalised differently.
        grid = torchvision.utils.make_grid(
            imgs, nrow=2, normalize=True, value_range=(0, 1)
        )
        trainer.logger.experiment.add_image(
            "Reconstructions", grid, global_step=trainer.global_step
        )

        os.makedirs(images_dir, exist_ok=True)
        for i in range(len(input_imgs)):
            fig, (ax_original, ax_reconst) = pyplot.subplots(1, 2)
            ax_original.imshow(input_imgs[i].cpu().detach().numpy()[0], cmap="gray")
            ax_original.set_title("original")
            ax_reconst.imshow(reconst_imgs[i].cpu().detach().numpy()[0], cmap="gray")
            ax_reconst.set_title("reconstruction")
            fig.savefig(os.path.join(images_dir, "{}.pdf".format(i)))
            # Close each figure right away so open figures do not pile up.
            pyplot.close(fig)

In [8]:
def get_train_images(num, rand=True, data=None):
    """Stack ``num`` images from a dataset into one tensor.

    Args:
        num: Number of images to return.
        rand: If truthy, draw ``num`` distinct random indices with
            ``random.sample``; otherwise take the first ``num`` items.
        data: Indexable collection of ``(image, label)`` pairs. Defaults to
            the notebook-level ``dataset``.

    Returns:
        Tensor with the selected images stacked along a new first dimension.

    Raises:
        ValueError: From ``random.sample`` when ``rand`` is set and ``num``
            exceeds the number of items.
    """
    if data is None:
        data = dataset
    if rand:
        indices = random.sample(range(len(data)), num)
    else:
        indices = range(num)
    return torch.stack([data[i][0] for i in indices], dim=0)

In [9]:
def Train_AE(
    run_id,
    max_epochs=300,
    lr=1e-3,
    kl_coeff=0.05,
    num_preview_images=8,
    log_every_n_steps=5,
):
    """Create a VAE and a configured Trainer for it.

    Despite the name, this function does not train; call
    ``trainer.fit(vae, dataloader)`` on the returned pair.

    Args:
        run_id: Name of the TensorBoard run under ``Results_VAE``; also
            passed to ``GenerateCallback`` as its ``run_id``.
        max_epochs: Maximum number of training epochs.
        lr: Learning rate passed to the VAE.
        kl_coeff: Weight of the KL term passed to the VAE.
        num_preview_images: Number of randomly drawn images that are
            reconstructed after every epoch.
        log_every_n_steps: Logging interval passed to the Trainer.

    Returns:
        A ``(trainer, vae)`` tuple.
    """
    vae = VAE(crop_size, lr=lr, kl_coeff=kl_coeff)
    logger = TensorBoardLogger("Results_VAE", name=run_id)
    preview_callback = GenerateCallback(
        get_train_images(num_preview_images), every_n_epochs=1, run_id=run_id
    )
    trainer = pl.Trainer(
        logger=logger,
        max_epochs=max_epochs,
        callbacks=[preview_callback],
        log_every_n_steps=log_every_n_steps,
    )
    return trainer, vae

In [10]:
# Build the trainer and the (still untrained) model; "Results_VAE" is used as
# the run name inside the "Results_VAE" log directory.
trainer, ae = Train_AE("Results_VAE")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
# Start training; only the training dataloader is passed, no validation loader.
trainer.fit(ae, dataloader)

  rank_zero_warn(
You are using a CUDA device ('NVIDIA GeForce RTX 3050 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 288 K 
1 | decoder | Sequential | 288 K 
2 | fc_mu   | Linear     | 30.1 M
3 | fc_var  | Linear     | 30.1 M
---------------------------------------
60.8 M    Trainable params
0         Non-trainable params
60.8 M    Total params
243.294   Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

: 

: 