In [1]:
# Source https://github.com/ml-postech/GM-VAE

In [3]:
import torch
from torch import nn
from math import log
from torch.nn import functional as F

## Encoder

In [168]:
# not from gm-vae, just want to uinderstand it better since every task has different encoder and decoder layers
class Encoder(nn.Module):
    def __init__(self, kernel:int=3, stride:int=1, padding:int=0) -> None:
        super(Encoder, self).__init__()

        self.encoder = nn.Sequential(
            nn.BatchNorm2d(3),
            nn.Conv2d(3, 32, kernel, stride, padding), # input image will be bx3x512x512, output bx32x508x508
            nn.AvgPool2d(3, 2), # 2x2 avg pooling, TODO: decide if want to use max pooling output bx32x253x253
            nn.ReLU(), # output bx32x253x253
            nn.Conv2d(32, 64, kernel, stride, padding), # input bx32x253x253, output bx64x251x251
            nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx64x125x125
            nn.ReLU(), # output bx64x125x125
            nn.Conv2d(64, 128, kernel, stride, padding), # input bx64x125x125, output bx128x123x123
            nn.AvgPool2d(3, 2), # 2x2 avg pooling, output bx128x61x61
            nn.ReLU(), # output bx128x61x61
            nn.Flatten(), # output bx128*61*61
            nn.Linear(128*61*61, 2* 128*3*3) # output bx3*128*61*61 
        )


        """
        Calculating the output size of the encoder
        [(W-F+2P)/S + 1]. [(512-3+0)/1 + 1] = 510 conv1
        [(W-F)/S] + 1]. [(510-3)/2 + 1] = 254 pool1
        [(W-F+2P)/S + 1]. [(253-3+)/1 + 1] = 252 conv2
        [(W-F)/S] + 1]. [(251-3)/2 + 1] = 125 pool2
        [(W-F+2P)/S + 1]. [(125-3+0)/1 + 1] = 123 conv3
        [(W-F)/S] + 1]. [(123-3)/2 + 1] = 61 pool3
        [(W-F+2P)/S + 1]. [(61-3+0)/1 + 1] = 59 conv4
        [(W-F)/S] + 1]. [(59-3)/2 + 1] = 29 pool4
        """

    def resize(self, x: torch.Tensor) -> torch.Tensor:
        # resize image to 3x512x512
        x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=True)
        return x


    def encode(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resize(x)
        x = self.encoder(x)

        # split into mean and logvar
        mean, logvar = torch.chunk(x, 2, dim=1)

        return mean, logvar, x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encode(x)




In [169]:
# create false image
image = torch.randint(0, 255, (1, 3, 512, 512), dtype=torch.float32)

In [170]:
encoder = Encoder()

In [171]:
import numpy as np
np.argmax(image)

tensor(69)

In [172]:
image.shape

torch.Size([1, 3, 512, 512])

In [173]:
image[0][:, 0, 0]

tensor([93., 46., 31.])

In [174]:
output_image = encoder.encode(image)
# [(W-F+2P)/S + 1]. [(512-3+0)/1 + 1] = 510 conv1
# [(W-F)/S] + 1]. [(510-3)/2 + 1] = 254 pool1
# [(W-F+2P)/S + 1]. [(253-3+)/1 + 1] = 252 conv2
# [(W-F)/S] + 1]. [(251-3)/2 + 1] = 125 pool2
# [(W-F+2P)/S + 1]. [(125-3+0)/1 + 1] = 123 conv3
# [(W-F)/S] + 1]. [(123-3)/2 + 1] = 61 pool3
# [(W-F+2P)/S + 1]. [(61-3+0)/1 + 1] = 59 conv4
# [(W-F)/S] + 1]. [(59-3)/2 + 1] = 29 pool4


In [176]:
output_image[2].shape

torch.Size([1, 2304])

In [177]:
output_image[1].shape

torch.Size([1, 1152])

In [178]:
output_image[0].shape

torch.Size([1, 1152])

In [158]:
output_image

(tensor([[ 0.0070, -0.0009, -0.0333,  ..., -0.0028,  0.0082, -0.0074]],
        grad_fn=<SplitBackward0>),
 tensor([[-0.0104,  0.0211, -0.0100,  ..., -0.0046,  0.0060,  0.0032]],
        grad_fn=<SplitBackward0>))

In [96]:
np.argmax(output_image.detach().numpy())

30200

In [99]:
output_image[0, 30200]

tensor(0.0947, grad_fn=<SelectBackward0>)

In [100]:
output_image

tensor([[0.0538, 0.0477, 0.0459,  ..., 0.0050, 0.0070, 0.0060]],
       grad_fn=<ViewBackward0>)

## Decoder

In [179]:
class Decoder(nn.Module):
    def __init__(self, kernel:int=3, stride:int=1, padding:int=0) -> None:
        super(Decoder, self).__init__()

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel, stride, padding),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel, stride, padding),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel, stride, padding),
            nn.ReLU()
        )
        
    
    def decode(self, x: torch.Tensor) -> torch.Tensor:
        x = self.decoder(x)
        return x
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.decode(x)


In [180]:
decoder = Decoder()

In [188]:
mean, logvar, _ = encoder.encode(image)
print("Mean shape: ", mean.shape)
print("Logvar shape: ", logvar.shape)

Mean shape:  torch.Size([1, 1152])
Logvar shape:  torch.Size([1, 1152])


In [None]:
z = variational.rsample(n_samples)

In [183]:
decoder.decode(output_image[0].view(1, 128, 3, 3)).shape

torch.Size([1, 3, 9, 9])