In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Make the parent of the current working directory importable so the `vqvaes`
# imports below can be resolved without installing the package (assumes the
# notebook is run from a sub-directory of the repository -- confirm).
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate "../" entry to sys.path each time.
if "../" not in sys.path:
    sys.path.append("../")

from vqvaes.models import build_vqvae
from vqvaes.models.vq import VQ
from vqvaes.models.layers import Encoder, Decoder
import torch

In [3]:
# Fix the global PyTorch RNG seed so the random tensors drawn later in the
# notebook are reproducible; the call's return value is what the cell displays.
torch.manual_seed(seed=0)

<torch._C.Generator at 0x150e397b2ad0>

In [4]:
# Stand-alone vector-quantisation layer with a deliberately tiny codebook:
# 4 entries of dimension 12 (shown as Embedding(4, 12) in the module repr).
vq = VQ(codebook_size=4, codebook_dim=12)

# Bare expression so the notebook renders the module repr.
vq

VQ(
  (codebook): Embedding(4, 12)
)

In [11]:
# Random stand-in for an encoder output. The trailing axis (12) matches the
# codebook_dim the `vq` layer was built with; gradients are enabled on it.
# NOTE(review): presumably laid out as (batch, height, width, dim) -- confirm
# against VQ.forward.
encoding = torch.randn(1, 3, 3, 12, requires_grad=True)

# Display whatever the quantiser returns for this input.
vq(encoding)

{'quantize': tensor([[[[ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463,
            -0.8437, -0.6136,  0.0316, -0.4927,  0.2484],
           [ 0.4397,  0.1124,  0.6408,  0.4412, -0.1023,  0.7924, -0.2897,
             0.0525,  0.5229,  2.3022, -1.4689, -1.5867],
           [ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463,
            -0.8437, -0.6136,  0.0316, -0.4927,  0.2484]],
 
          [[-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160,
            -2.1152,  0.3223, -1.2633,  0.3500,  0.3081],
           [-1.1258, -1.1524, -0.2506, -0.4339,  0.8487,  0.6920, -0.3160,
            -2.1152,  0.3223, -1.2633,  0.3500,  0.3081],
           [ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463,
            -0.8437, -0.6136,  0.0316, -0.4927,  0.2484]],
 
          [[ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959,  0.5667,
             0.7935,  0.5988, -1.5551, -0.3414,  1.8530],
           [ 0.7502, -0.5855, -0.1734,  0.1835,  1.

In [12]:
# Two small integer tensors for a broadcasting experiment:
#   a -> [[1, 2]]            shape (1, 2)
#   b -> [[3, 4], [5, 6]]    shape (2, 2)
a = torch.tensor([1, 2]).reshape(1, 2)
b = torch.tensor([3, 4, 5, 6]).reshape(2, 2)

# Show both shapes side by side.
(a.shape, b.shape)

(torch.Size([1, 2]), torch.Size([2, 2]))

In [13]:
# Element-wise difference; the single row of `a` is broadcast against each
# row of `b`, so the result has b's (2, 2) shape.
c = torch.sub(a, b)
c

tensor([[-2, -2],
        [-4, -4]])

In [14]:
# Small encoder for 3-channel inputs: 16 feature channels and a stack of two
# residual blocks with 8 hidden channels each.
encoder = Encoder(
    in_channels=3, num_channels=16,
    num_residual_blocks=2, num_residual_channels=8,
)

# Bare expression so the notebook renders the layer-by-layer repr.
encoder

Encoder(
  (encoder): Sequential(
    (0): Conv2d(3, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ResidualStack(
      (stack): Sequential(
        (0): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (1): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
  )
)

In [15]:
# Push one random 3-channel 128x128 batch through the encoder as a shape check.
# The tensor is named `encoder_input` rather than `input` so that Python's
# built-in `input()` function is not shadowed for the rest of the kernel session.
encoder_input = torch.randn((1, 3, 128, 128))
encoder_output = encoder(encoder_input)

# Display the shape of the resulting feature map.
encoder_output.shape

torch.Size([1, 16, 32, 32])

In [16]:
# Decoder mirroring the encoder above: consumes 16-channel feature maps and
# emits 3 output channels, with two residual blocks of 8 hidden channels.
decoder = Decoder(
    in_channels=16, out_channels=3,
    num_residual_blocks=2, num_residual_channels=8,
)

# Bare expression so the notebook renders the layer-by-layer repr.
decoder

Decoder(
  (decoder): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ResidualStack(
      (stack): Sequential(
        (0): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (1): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
    (2): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(8, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)

In [17]:
# Decode the encoder's feature map; the displayed shape should match the
# (1, 3, 128, 128) tensor that was originally fed to the encoder.
decoder_output = decoder(encoder_output)
decoder_output.shape

torch.Size([1, 3, 128, 128])

In [18]:
# Assemble the full VQ-VAE using the same small hyper-parameters that were
# used for the stand-alone VQ, Encoder and Decoder earlier in the notebook.
vqvae = build_vqvae(
    codebook_size=4, codebook_dim=12,
    in_channels=3, num_channels=16,
    num_residual_blocks=2, num_residual_channels=8,
)

# Bare expression so the notebook renders the complete module tree.
vqvae

VQVAE(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(3, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ResidualStack(
        (stack): Sequential(
          (0): ResidualBlock(
            (conv): Sequential(
              (0): ReLU()
              (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): ReLU()
              (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (1): ResidualBlock(
            (conv): Sequential(
              (0): ReLU()
              (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): ReLU()
              (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
        )
      )
    )
  )
  (pre_vq): C

In [19]:
# Smoke-test the assembled model end to end on one random 3-channel
# 128x128 batch and display everything it returns.
inputs = torch.randn(1, 3, 128, 128)
vqvae(inputs)

(tensor([[[[-0.0875, -0.0195, -0.1159,  ..., -0.0513, -0.0977, -0.0666],
           [-0.0410, -0.0677, -0.0594,  ..., -0.0722, -0.1221, -0.0596],
           [-0.0769, -0.1109, -0.1247,  ..., -0.0046, -0.0660, -0.0894],
           ...,
           [-0.0882, -0.0533,  0.0177,  ..., -0.1384, -0.0649, -0.0333],
           [-0.0996, -0.1146, -0.1462,  ..., -0.0664, -0.1311, -0.1002],
           [-0.0599, -0.0733, -0.0410,  ..., -0.0884, -0.0584, -0.0508]],
 
          [[-0.0022, -0.0190, -0.0107,  ...,  0.0084, -0.0137,  0.0125],
           [-0.0052,  0.0245, -0.0473,  ...,  0.0631,  0.0003,  0.0259],
           [-0.0060,  0.0080, -0.0332,  ..., -0.0217, -0.0351,  0.0238],
           ...,
           [ 0.0284,  0.0316, -0.0884,  ..., -0.0796,  0.0082,  0.0566],
           [ 0.0233,  0.0420, -0.0359,  ...,  0.0490, -0.0082,  0.0540],
           [ 0.0110,  0.0054,  0.0051,  ..., -0.0223,  0.0143,  0.0194]],
 
          [[ 0.0590,  0.0154,  0.0851,  ...,  0.0267,  0.0908,  0.0626],
           [ 