In [None]:
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
%pip install torchinfo torch-summary matplotlib scipy

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary


class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut branch.

    The main branch is ``conv3x3 -> BN -> ReLU -> conv3x3 -> BN``. Its result
    is added to the shortcut and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed). Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # Main branch. Only the first convolution may change the spatial size.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut branch. An empty Sequential acts as the identity; otherwise a
        # 1x1 convolution matches the shape of the main branch.
        # Note: the projection is a bare convolution (no BatchNorm follows it).
        shape_is_preserved = stride == 1 and in_channels == out_channels
        projection_layers = []
        if not shape_is_preserved:
            projection_layers.append(
                nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=stride, bias=False
                )
            )
        self.shortcut = nn.Sequential(*projection_layers)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))

        residual += self.shortcut(x)
        return F.relu(residual)


# Sanity-check a channel-widening block (16 -> 32 channels, stride 1) on a
# batch of 37 feature maps of size 128x128.
summary(
    BasicBlock(in_channels=16, out_channels=32, stride=1),
    input_size=(37, 16, 128, 128),
    col_names=[
        "input_size", "kernel_size", "mult_adds",
        "num_params", "output_size", "trainable",
    ],
)


Layer (type:depth-idx)                   Input Shape               Kernel Shape              Mult-Adds                 Param #                   Output Shape              Trainable
BasicBlock                               [37, 16, 128, 128]        --                        --                        --                        [37, 32, 128, 128]        True
├─Conv2d: 1-1                            [37, 16, 128, 128]        [3, 3]                    2,793,406,464             4,608                     [37, 32, 128, 128]        True
├─BatchNorm2d: 1-2                       [37, 32, 128, 128]        --                        2,368                     64                        [37, 32, 128, 128]        True
├─Conv2d: 1-3                            [37, 32, 128, 128]        [3, 3]                    5,586,812,928             9,216                     [37, 32, 128, 128]        True
├─BatchNorm2d: 1-4                       [37, 32, 128, 128]        --                        2,368                 

In [19]:
from functools import reduce
import torchsummary


class Encoder(nn.Module):
    """Convolutional encoder that maps an image to Gaussian latent parameters.

    The image goes through a strided 7x7 convolution with max-pooling, then
    eight ``BasicBlock`` stages (three of them with stride 2) ending in 512
    channels. The flattened features feed two linear heads.

    Args:
        input_size: Shape of one input sample *without* the batch dimension,
            e.g. ``(channels, height, width)``. ``input_size[0]`` is used as
            the number of input channels.
        latent_dim: Size of the latent vector produced by each head.

    Attributes:
        input_size_to_fc: List with the shape (without batch dimension) of the
            feature map that is flattened into the linear heads.
        inputs_to_fc: Product of ``input_size_to_fc``, i.e. the number of
            input features of each linear head.
    """

    def __init__(self, input_size, latent_dim):
        super(Encoder, self).__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(
                input_size[0], 64, kernel_size=7, stride=2, padding=3, bias=False
            ),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        self.residual_blocks = nn.Sequential(
            BasicBlock(64, 64),
            BasicBlock(64, 64),
            BasicBlock(64, 128, stride=2),
            BasicBlock(128, 128),
            BasicBlock(128, 256, stride=2),
            BasicBlock(256, 256),
            BasicBlock(256, 512, stride=2),
            BasicBlock(512, 512),
        )

        # Measure the feature-map shape with a dry forward pass instead of
        # reading it out of a model-summary tool's result object.
        self.input_size_to_fc = self._infer_feature_shape(
            nn.Sequential(self.conv, self.residual_blocks), input_size
        )
        self.inputs_to_fc = reduce(lambda x, y: x * y, self.input_size_to_fc)

        self.fc_mu = nn.Linear(self.inputs_to_fc, latent_dim)
        self.fc_log_var = nn.Linear(self.inputs_to_fc, latent_dim)

    @staticmethod
    def _infer_feature_shape(feature_extractor, input_size):
        """Return the output shape (without batch dim) of ``feature_extractor``.

        A single all-zeros sample of shape ``input_size`` is pushed through the
        module. The pass runs in eval mode and without gradients so BatchNorm
        running statistics are left untouched; the previous training mode is
        restored afterwards.

        Args:
            feature_extractor: Module to probe.
            input_size: Shape of one sample without the batch dimension.

        Returns:
            A list of ints describing the per-sample output shape.
        """
        was_training = feature_extractor.training
        feature_extractor.eval()
        try:
            with torch.no_grad():
                features = feature_extractor(torch.zeros(1, *input_size))
        finally:
            feature_extractor.train(was_training)
        return list(features.shape[1:])

    def forward(self, x):
        """Encode a batch ``x`` and return the tuple ``(mu, log_var)``."""
        x = self.conv(x)
        x = self.residual_blocks(x)
        # Flatten everything except the batch dimension for the linear heads.
        x = x.view(x.size(0), -1)
        mu = self.fc_mu(x)
        log_var = self.fc_log_var(x)
        return mu, log_var


# Build an encoder for single-channel 128x128 images with a 128-dimensional
# latent space, then print its layer table for a batch of 37 samples.
encoder = Encoder(input_size=(1, 128, 128), latent_dim=128)
summary(
    encoder,
    input_size=(37, 1, 128, 128),
    col_names=[
        "input_size", "kernel_size", "mult_adds",
        "num_params", "output_size", "trainable",
    ],
)


Layer (type:depth-idx)                   Input Shape               Kernel Shape              Mult-Adds                 Param #                   Output Shape              Trainable
Encoder                                  [37, 1, 128, 128]         --                        --                        --                        [37, 128]                 True
├─Sequential: 1-1                        [37, 1, 128, 128]         --                        --                        --                        [37, 64, 32, 32]          True
│    └─Conv2d: 2-1                       [37, 1, 128, 128]         [7, 7]                    475,267,072               3,136                     [37, 64, 64, 64]          True
│    └─BatchNorm2d: 2-2                  [37, 64, 64, 64]          --                        4,736                     128                       [37, 64, 64, 64]          True
│    └─ReLU: 2-3                         [37, 64, 64, 64]          --                        --                    

In [20]:
class ReverseBasicBlock(nn.Module):
    """Residual block built from transposed convolutions.

    The main branch is ``convT3x3 -> BN -> ReLU -> convT3x3 -> BN``. Its result
    is added to the shortcut and passed through a final ReLU. The first
    transposed convolution and the projection shortcut both use
    ``output_padding=stride - 1``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels produced by the block.
        stride: Stride of the first transposed convolution and of the
            projection shortcut. Defaults to 1.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        extra_padding = stride - 1

        # Main branch. Only the first layer may change the spatial size.
        self.conv1 = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            output_padding=extra_padding,
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.ConvTranspose2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Shortcut branch. An empty Sequential acts as the identity; otherwise a
        # 1x1 transposed convolution followed by BatchNorm matches the shape
        # of the main branch.
        shape_is_preserved = stride == 1 and in_channels == out_channels
        projection_layers = []
        if not shape_is_preserved:
            projection_layers.append(
                nn.ConvTranspose2d(
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    output_padding=extra_padding,
                    bias=False,
                )
            )
            projection_layers.append(nn.BatchNorm2d(out_channels))
        self.shortcut = nn.Sequential(*projection_layers)

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(hidden))

        residual += self.shortcut(x)
        return F.relu(residual)


# Sanity-check a channel-widening transposed block (16 -> 32 channels,
# stride 1) on a batch of 37 feature maps of size 128x128.
summary(
    ReverseBasicBlock(in_channels=16, out_channels=32, stride=1),
    input_size=(37, 16, 128, 128),
    col_names=[
        "input_size", "kernel_size", "mult_adds",
        "num_params", "output_size", "trainable",
    ],
)


Layer (type:depth-idx)                   Input Shape               Kernel Shape              Mult-Adds                 Param #                   Output Shape              Trainable
ReverseBasicBlock                        [37, 16, 128, 128]        --                        --                        --                        [37, 32, 128, 128]        True
├─ConvTranspose2d: 1-1                   [37, 16, 128, 128]        [3, 3]                    2,793,406,464             4,608                     [37, 32, 128, 128]        True
├─BatchNorm2d: 1-2                       [37, 32, 128, 128]        --                        2,368                     64                        [37, 32, 128, 128]        True
├─ConvTranspose2d: 1-3                   [37, 32, 128, 128]        [3, 3]                    5,586,812,928             9,216                     [37, 32, 128, 128]        True
├─BatchNorm2d: 1-4                       [37, 32, 128, 128]        --                        2,368                 

In [21]:
class Decoder(nn.Module):
    """Transposed-convolution decoder that maps a latent vector to an image.

    A linear layer expands the latent vector to ``prod(size_from_fc)``
    features, which are reshaped to a feature map and passed through eight
    ``ReverseBasicBlock`` stages (four with stride 2). A final strided 7x7
    transposed convolution, BatchNorm and Sigmoid produce the output.

    Args:
        latent_dim: Size of the latent vector.
        size_from_fc: Shape ``(channels, height, width)`` of the feature map
            the linear layer is reshaped to. The first residual stage expects
            512 input channels, so ``size_from_fc[0]`` has to be 512.
        out_channels: Number of channels of the generated image.
    """

    def __init__(self, latent_dim, size_from_fc, out_channels):
        super().__init__()

        self.size_from_fc = size_from_fc

        flat_features = reduce(lambda acc, dim: acc * dim, size_from_fc)
        self.fc = nn.Linear(latent_dim, flat_features)

        # (in_channels, out_channels, stride) of every residual stage, in order.
        stage_specs = (
            (512, 256, 2),
            (256, 256, 1),
            (256, 128, 2),
            (128, 128, 1),
            (128, 64, 2),
            (64, 64, 1),
            (64, 64, 2),
            (64, 64, 1),
        )
        stages = []
        for stage_in, stage_out, stage_stride in stage_specs:
            stages.append(ReverseBasicBlock(stage_in, stage_out, stride=stage_stride))
        self.residual_blocks = nn.Sequential(*stages)

        # Output head: a stride-2 transposed convolution, then BatchNorm and a
        # Sigmoid that squashes values into (0, 1).
        self.conv_transpose = nn.Sequential(
            nn.ConvTranspose2d(
                64,
                out_channels,
                kernel_size=7,
                stride=2,
                padding=3,
                output_padding=1,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode a batch of latent vectors ``x`` into images."""
        flat = self.fc(x)

        # Restore the (channels, height, width) layout expected by the blocks.
        channels = self.size_from_fc[0]
        height = self.size_from_fc[1]
        width = self.size_from_fc[2]
        feature_map = flat.view(flat.size(0), channels, height, width)

        return self.conv_transpose(self.residual_blocks(feature_map))


# Build the decoder that mirrors `encoder` (same latent size, same feature-map
# shape at the linear layer) and print its layer table for 37 latent vectors.
decoder = Decoder(
    latent_dim=128, size_from_fc=encoder.input_size_to_fc, out_channels=1
)
summary(
    decoder,
    input_size=(37, 128),
    col_names=[
        "input_size", "kernel_size", "mult_adds",
        "num_params", "output_size", "trainable",
    ],
)


Layer (type:depth-idx)                        Input Shape               Kernel Shape              Mult-Adds                 Param #                   Output Shape              Trainable
Decoder                                       [37, 128]                 --                        --                        --                        [37, 1, 128, 128]         True
├─Linear: 1-1                                 [37, 128]                 --                        39,100,416                1,056,768                 [37, 8192]                True
├─Sequential: 1-2                             [37, 512, 4, 4]           --                        --                        --                        [37, 64, 64, 64]          True
│    └─ReverseBasicBlock: 2-1                 [37, 512, 4, 4]           --                        --                        --                        [37, 256, 8, 8]           True
│    │    └─ConvTranspose2d: 3-1              [37, 512, 4, 4]           [3, 3]            

In [22]:
class VAE(nn.Module):
    """Variational autoencoder combining ``Encoder`` and ``Decoder``.

    Args:
        input_size: Shape of one input sample without the batch dimension;
            ``input_size[0]`` is also used as the decoder's output channels.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, input_size, latent_dim):
        super().__init__()

        self.input_size = input_size
        self.latent_dim = latent_dim

        self.encoder = Encoder(input_size, latent_dim)
        # The decoder starts from the same feature-map shape the encoder ends with.
        self.decoder = Decoder(
            latent_dim, self.encoder.input_size_to_fc, input_size[0]
        )

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = (0.5 * log_var).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Return ``(reconstruction, mu, log_var)`` for the batch ``x``."""
        mu, log_var = self.encoder(x)
        latent = self.reparameterize(mu, log_var)
        return self.decoder(latent), mu, log_var


# Build a VAE for single-channel 224x224 images with a 128-dimensional latent
# space and print the fully expanded layer table for a batch of 37 samples.
vae = VAE(input_size=(1, 224, 224), latent_dim=128)
summary(
    vae,
    input_size=(37, 1, 224, 224),
    depth=10,
    col_names=[
        "input_size", "kernel_size", "mult_adds",
        "num_params", "output_size", "trainable",
    ],
)


Layer (type:depth-idx)                             Input Shape               Kernel Shape              Mult-Adds                 Param #                   Output Shape              Trainable
VAE                                                [37, 1, 224, 224]         --                        --                        --                        [37, 1, 224, 224]         True
├─Encoder: 1-1                                     [37, 1, 224, 224]         --                        --                        --                        [37, 128]                 True
│    └─Sequential: 2-1                             [37, 1, 224, 224]         --                        --                        --                        [37, 64, 56, 56]          True
│    │    └─Conv2d: 3-1                            [37, 1, 224, 224]         [7, 7]                    1,455,505,408             3,136                     [37, 64, 112, 112]        True
│    │    └─BatchNorm2d: 3-2                       [37, 64, 112, 

In [23]:
def vae_loss(recon_x, x, mu, log_var, beta=1.0):
    """Compute the VAE training loss.

    Args:
        recon_x: Reconstructed batch; values must lie in [0, 1] for the binary
            cross-entropy term.
        x: Target batch with the same shape as ``recon_x``.
        mu: Latent means.
        log_var: Latent log-variances.
        beta: Weight of the KL term in the total loss. Defaults to 1.0.

    Returns:
        Tuple ``(recon_loss, kld_loss, total)`` where ``recon_loss`` is the
        mean binary cross-entropy plus the mean squared error, ``kld_loss`` is
        ``-0.5 * mean(1 + log_var - mu^2 - exp(log_var))`` and ``total`` is
        ``recon_loss + beta * kld_loss``.
    """
    bce_term = F.binary_cross_entropy(recon_x, x, reduction="mean")
    mse_term = F.mse_loss(recon_x, x)
    recon_loss = bce_term + mse_term

    # KL divergence to a standard normal, averaged over all latent elements.
    kld_elements = 1 + log_var - mu.pow(2) - log_var.exp()
    kld_loss = -0.5 * torch.mean(kld_elements)

    total = recon_loss + beta * kld_loss
    return recon_loss, kld_loss, total


In [24]:
from pathlib import Path
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.transforms.functional as TF

# Which dataset to train on: "cxr8", "cifar10" or "imagenet".
dataset_name = "cxr8"

# NOTE(review): absolute, machine-specific data root; adjust for your setup.
base_data_path = Path("D:\\") / "data"

imagenet_data_path = (
    base_data_path
    / "imagenet-object-localization-challenge"
    / "ILSVRC"
    / "Data"
    / "CLS-LOC"
)

cxr8_data_path = base_data_path / "cxr8"


if dataset_name == "cxr8":
    # Images are resized to 448x448, converted to single-channel grayscale and
    # histogram-equalized (p=1.0, i.e. always) before becoming tensors.
    train_dataset = datasets.ImageFolder(
        cxr8_data_path,
        transform=transforms.Compose(
            [
                transforms.Resize((448, 448)),
                transforms.Grayscale(),
                transforms.RandomEqualize(1.0),
                transforms.ToTensor(),
            ]
        ),
    )
elif dataset_name == "cifar10":
    train_dataset = datasets.CIFAR10(
        root="./data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor()]
        ),
    )
elif dataset_name == "imagenet":
    train_dataset = datasets.ImageFolder(
        imagenet_data_path,
        transform=transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor()]
        ),
    )
else:
    # Fail here with a clear message instead of a NameError on `train_dataset`
    # further down.
    raise ValueError(
        f"Unknown dataset_name {dataset_name!r}; "
        "expected 'cxr8', 'cifar10' or 'imagenet'."
    )

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)


In [25]:
import torch.optim as optim

# `is_available` must be *called*: the bare function object is always truthy,
# which would select "cuda" even on machines without a usable GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

latent_dim = 64

# The cxr8 pipeline yields single-channel images; the other datasets are RGB.
image_channels = 1 if dataset_name == "cxr8" else 3

# NOTE(review): the model is built for 448x448 inputs, while the cifar10 and
# imagenet transforms resize to 224x224 -- confirm before using those datasets.
model = VAE((image_channels, 448, 448), latent_dim).to(device)
print(
    summary(
        model,
        # Use the same channel count the model was built with.
        input_size=(37, image_channels, 448, 448),
        depth=10,
        col_names=[
            "input_size",
            "kernel_size",
            "mult_adds",
            "num_params",
            "output_size",
            "trainable",
        ],
    )
)

optimizer = optim.Adam(model.parameters(), lr=1e-4)


Layer (type:depth-idx)                             Input Shape               Kernel Shape              Mult-Adds                 Param #                   Output Shape              Trainable
VAE                                                [37, 1, 448, 448]         --                        --                        --                        [37, 1, 448, 448]         True
├─Encoder: 1-1                                     [37, 1, 448, 448]         --                        --                        --                        [37, 64]                  True
│    └─Sequential: 2-1                             [37, 1, 448, 448]         --                        --                        --                        [37, 64, 112, 112]        True
│    │    └─Conv2d: 3-1                            [37, 1, 448, 448]         [7, 7]                    5,822,021,632             3,136                     [37, 64, 224, 224]        True
│    │    └─BatchNorm2d: 3-2                       [37, 64, 224, 

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
from IPython.display import clear_output, display

num_epochs = 10

# Preview grid: top row shows input images, bottom row their reconstructions.
fig, ax = plt.subplots(2, 10, figsize=(20, 5))
n_preview_slots = ax.shape[1]

for epoch in range(num_epochs):
    model.train()
    train_loss = 0
    train_count = 0
    for batch_idx, (data, _) in enumerate(train_loader):

        data = data.to(device)
        optimizer.zero_grad()

        recon_batch, mu, log_var = model(data)
        recon_loss, kldiv_loss, loss = vae_loss(
            recon_batch, data, mu, log_var, beta=0.01
        )

        # Refresh the preview every 10th batch of the epoch.
        if train_count % 10 == 0:
            orig_data = data.detach().cpu()
            recon_data = recon_batch.detach().cpu()

            # The last batch of an epoch may hold fewer samples than slots.
            n_shown = min(n_preview_slots, orig_data.size(0))

            for n in range(n_preview_slots):
                # Drop the previously drawn image so artists do not pile up.
                ax[0][n].clear()
                ax[1][n].clear()
                if n >= n_shown:
                    continue
                # imshow expects channels last: (C, H, W) -> (H, W, C).
                ax[0][n].imshow(
                    torch.movedim(orig_data[n], 0, -1).numpy(), cmap="bone"
                )
                ax[1][n].imshow(
                    torch.movedim(recon_data[n], 0, -1).numpy(), cmap="bone"
                )

            clear_output(wait=True)

            display(fig)

        loss.backward()
        train_loss += loss.item()
        train_count += 1
        optimizer.step()

        # Running mean of the total loss, plus this batch's two components.
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Batch: {batch_idx}, Loss: {train_loss / train_count:.6f} ({recon_loss:.4f}/{kldiv_loss:.4f})"
        )
        