In [4]:
# Source https://github.com/ml-postech/GM-VAE

In [2]:
import torch
from torch import nn
from math import log
from torch.nn import functional as F

In [3]:
torch.cuda.is_available()

True

In [4]:
# create false image
image = torch.randint(0, 255, (3, 3, 512, 512), dtype=torch.float32)

## Encoder

In [9]:
# not from gm-vae, just want to uinderstand it better since every task has different encoder and decoder layers
class Encoder(nn.Module):
    def __init__(self, kernel:int=3, stride:int=1, padding:int=0) -> None:
        super(Encoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.BatchNorm2d(3),
            nn.Conv2d(3, 32, kernel, stride, padding), # input image will be bx3x512x512, output bx32x508x508
            nn.AvgPool2d(3, 2), # 2x2 avg pooling, TODO: decide if want to use max pooling output bx32x253x253
            nn.ReLU(), # output bx32x253x253
            nn.Conv2d(32, 64, kernel, stride, padding), # input bx32x253x253, output bx64x251x251
            nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx64x125x125
            nn.ReLU(), # output bx64x125x125
            nn.Conv2d(64, 64, kernel, stride, padding), # input bx64x125x125, output bx64x123x123
            nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx64x61x61
            nn.ReLU(), # output bx64x61x61
            nn.Flatten(), # output bx64*61*61 
        )

        self.layer_mean = nn.Linear(64*61*61, 2048) 
        self.layer_logvar = nn.Linear(64*61*61, 2048)   


        """
        Calculating the output size of the encoder
        [(W-F+2P)/S + 1]. [(512-3+0)/1 + 1] = 510 conv1
        [(W-F)/S] + 1]. [(510-3)/2 + 1] = 254 pool1
        [(W-F+2P)/S + 1]. [(253-3+)/1 + 1] = 252 conv2
        [(W-F)/S] + 1]. [(251-3)/2 + 1] = 125 pool2
        [(W-F+2P)/S + 1]. [(125-3+)/1 + 1] = 123 conv3
        [(W-F)/S] + 1]. [(123-3)/2 + 1] = 61 pool3
        """

    def resize(self, x: torch.Tensor) -> torch.Tensor:
        # resize image to 3x512x512
        x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=True)
        return x

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resize(x)
        x = self.encoder(x)
        mean = self.layer_mean(x)
        logvar = self.layer_logvar(x)

        return x, mean.unsqueeze(1), logvar.unsqueeze(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encode(x)




In [5]:
device = torch.device('cuda')

In [6]:
device

device(type='cuda')

In [9]:
encoder = Encoder()

In [None]:
import numpy as np
np.argmax(image)

In [None]:
image.shape

In [None]:
image[0][:, 0, 0]

In [None]:
output_image = encoder.encode(image)
# [(W-F+2P)/S + 1]. [(512-3+0)/1 + 1] = 510 conv1
# [(W-F)/S] + 1]. [(510-3)/2 + 1] = 254 pool1
# [(W-F+2P)/S + 1]. [(253-3+)/1 + 1] = 252 conv2
# [(W-F)/S] + 1]. [(251-3)/2 + 1] = 125 pool2
# [(W-F+2P)/S + 1]. [(125-3+0)/1 + 1] = 123 conv3
# [(W-F)/S] + 1]. [(123-3)/2 + 1] = 61 pool3
# [(W-F+2P)/S + 1]. [(61-3+0)/1 + 1] = 59 conv4
# [(W-F)/S] + 1]. [(59-3)/2 + 1] = 29 pool4


In [11]:
encoder

Encoder(
  (encoder): Sequential(
    (0): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1))
    (2): AvgPool2d(kernel_size=3, stride=2, padding=0)
    (3): ReLU()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): AvgPool2d(kernel_size=3, stride=2, padding=0)
    (6): ReLU()
    (7): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (8): AvgPool2d(kernel_size=3, stride=2, padding=0)
    (9): ReLU()
    (10): Flatten(start_dim=1, end_dim=-1)
  )
  (layer_mean): Linear(in_features=968256, out_features=2048, bias=True)
  (layer_logvar): Linear(in_features=968256, out_features=2048, bias=True)
)

In [None]:
len(output_image)

In [None]:
output_image[1].shape

In [None]:
output_image[0].shape

In [None]:
output_image

In [None]:
np.argmax(output_image.detach().numpy())

In [None]:
output_image[0, 30200]

In [None]:
output_image

# Start of autoencoder class

In [11]:
class VAE(nn.Module):
    def __init__(self) -> None:
        super(VAE, self).__init__()
        self.encoder = Encoder()
        self.decoder = None

    def reparameterize(self, mean: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        # log var = (log var^2) / 2
        # e ^ log var = e^(log var^2 / 2)
        # e ^ log var = e^ (2 * (log var) / 2)
        # e ^ log var = e^ (log var)
        # e ^ log var = var
        # var = var (this is why we use logvar instead of variation, it makes it possible for back propagation to know if a number is negative or positive)

        eps = 0.5
        z = mean + eps * torch.exp(logvar)
        return z

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean, logvar = self.encoder(x)
        z = self.reparameterize(mean, logvar)
        return self.decoder(z), mean, logvar

In [13]:
vae = VAE()

In [15]:
x, mean, logvar = vae.encoder.encode(image)

In [16]:
print(f"x shape: {x.shape}, mean shape: {mean.shape}, logvar shape: {logvar.shape}")
print(f"type of x: {type(x)}, type of mean: {type(mean)}, type of logvar: {type(logvar)}")

x shape: torch.Size([3, 238144]), mean shape: torch.Size([3, 1, 2048]), logvar shape: torch.Size([3, 1, 2048])
type of x: <class 'torch.Tensor'>, type of mean: <class 'torch.Tensor'>, type of logvar: <class 'torch.Tensor'>


In [17]:
z = vae.reparameterize(mean, logvar)

In [18]:
print(f"z shape: {z.shape}")

z shape: torch.Size([3, 1, 2048])


## Decoder

In [31]:
class Decoder(nn.Module):
    def __init__(self, kernel:int=3, stride:int=1, padding:int=0) -> None:
        super(Decoder, self).__init__()
        self.kernel = kernel
        self.stride = stride
        self.padding = padding

        self.decoder = nn.Sequential(
            nn.Linear(2048, 64*61*61),
            nn.ReLU(),  # Activation function
            nn.Unflatten(dim=1, unflattened_size=(64, 61, 61)),  # Reshape to [batch_size, 128, 3, 3]
            nn.ConvTranspose2d(64, 64, kernel, stride, padding),  # Reverse the convolutions of the encoder
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel, stride, padding),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel, stride, padding),  # Adjusted to match the original image channels
            nn.Sigmoid() 
        )

            # nn.Conv2d(3, 32, kernel, stride, padding), # input image will be bx3x512x512, output bx32x508x508
            # nn.AvgPool2d(3, 2), # 2x2 avg pooling, TODO: decide if want to use max pooling output bx32x253x253
            # nn.ReLU(), # output bx32x253x253
            # nn.Conv2d(32, 64, kernel, stride, padding), # input bx32x253x253, output bx64x251x251
            # nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx64x125x125
            # nn.ReLU(), # output bx64x125x125
            # nn.Conv2d(64, 64, kernel, stride, padding), # input bx64x125x125, output bx64x123x123
            # nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx64x61x61
            # nn.ReLU(), # output bx64x61x61
            # nn.Flatten(), # output bx64*61*61 

    def decode(self, r_samples: torch.Tensor) -> torch.Tensor:
        linear = nn.Linear(2048, 64*61*61)
        relu1 = nn.ReLU()  # Activation function
        unflatten = nn.Unflatten(dim=1, unflattened_size=(64, 3, 3))  # Reshape to [batch_size, 128, 3, 3
        conv1 = nn.ConvTranspose2d(64, 64, self.kernel, self.stride, self.padding)  # Reverse the convolutions of the encode
        relu2 = nn.ReLU()
        conv2 = nn.ConvTranspose2d(64, 32, self.kernel, self.stride, self.padding)
        relu3 = nn.ReLU()
        conv3 = nn.ConvTranspose2d(32, 3, self.kernel, self.stride, self.padding)  # Adjusted to match the original image channel
        sigmoid = nn.Sigmoid()

        print(f"r_samples shape: {r_samples.shape}")
        x = linear(r_samples)
        print(f"x shape: {x.shape}")
        x = relu1(x)
        print(f"x shape: {x.shape}")
        x = unflatten(x)
        print(f"x shape: {x.shape}")
        x = conv1(x)
        print(f"x shape: {x.shape}")
        x = relu2(x)
        print(f"x shape: {x.shape}")
        x = conv2(x)
        print(f"x shape: {x.shape}")
        x = relu3(x)
        print(f"x shape: {x.shape}")
        x = conv3(x)
        print(f"x shape: {x.shape}")
        x = sigmoid(x)
        print(f"x shape: {x.shape}")
        
        reconstructed_images = []
        for sample in r_samples:
            x = self.decoder(sample.squeeze())  # Remove the singleton dimensions
            reconstructed_images.append(x.unsqueeze(0))  # Add batch dimension back
        return torch.cat(reconstructed_images, dim=0)
    
    def forward(self, r_samples: torch.Tensor) -> torch.Tensor:
        return self.decode(r_samples)


In [32]:
decoder = Decoder()

In [35]:
decoder.decode(z)

r_samples shape: torch.Size([3, 1, 2048])
x shape: torch.Size([3, 1, 238144])
x shape: torch.Size([3, 1, 238144])


RuntimeError: unflatten: Provided sizes [64, 3, 3] don't multiply up to the size of dim 1 (1) in the input tensor

: 

In [None]:
        # self.decoder = nn.Sequential(
        #     nn.ConvTranspose2d(128, 64, kernel, stride, padding),
        #     nn.ReLU(),
        #     nn.ConvTranspose2d(64, 32, kernel, stride, padding),
        #     nn.ReLU(),
        #     nn.ConvTranspose2d(32, 3, kernel, stride, padding),
        #     nn.ReLU()
        # )

In [None]:
mean, dist = output_image

In [None]:
mean.shape

In [None]:
dist.shape

In [None]:
distribution_cal = Distribution(mean, dist)

In [None]:
r_samples = distribution_cal.rsample(1000)

In [None]:
r_samples.shape

In [None]:
decoder = Decoder()

In [None]:
mean, logvar, _ = encoder.encode(image)
print("Mean shape: ", mean.shape)
print("Logvar shape: ", logvar.shape)

In [None]:
print("RSample shape: ", r_samples.shape)

In [None]:
output_decoder = decoder.decode(r_samples)

In [None]:
output_decoder.shape

In [None]:
            # nn.BatchNorm2d(3),
            # nn.Conv2d(3, 32, kernel, stride, padding), # input image will be bx3x512x512, output bx32x508x508
            # nn.AvgPool2d(3, 2), # 2x2 avg pooling, TODO: decide if want to use max pooling output bx32x253x253
            # nn.ReLU(), # output bx32x253x253
            # nn.Conv2d(32, 64, kernel, stride, padding), # input bx32x253x253, output bx64x251x251
            # nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx64x125x125
            # nn.ReLU(), # output bx64x125x125
            # nn.Conv2d(64, 128, kernel, stride, padding), # input bx64x125x125, output bx128x123x123
            # nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx128x61x61
            # nn.ReLU(), # output bx128x61x61
            # nn.Flatten(), # output bx128*61*61
            # nn.Linear(128*61*61, 2* 128*3*3) # output bx3*128*61*61 

# pytorch transpose conv2d
# https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
# https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html



In [None]:
output_image.shape

In [None]:
(1, 3, 512, 512)