In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../")

from vqvaes.models import build_vqvae
from vqvaes.models.vq import VQ
from vqvaes.models.layers import Encoder, Decoder
import torch

In [3]:
torch.manual_seed(0)

<torch._C.Generator at 0x14d8963cb030>

In [4]:
vq = VQ(
    codebook_size=4,
    codebook_dim=12
)

vq

VQ(
  (codebook): Embedding(4, 12)
)

In [5]:
encoding = torch.randn((1, 3, 3, 12), requires_grad=True)
vq(encoding)



{'quantized': tensor([[[[ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959,  0.5667,
             0.7935,  0.5988, -1.5551, -0.3414,  1.8530],
           [ 0.4397,  0.1124,  0.6408,  0.4412, -0.1023,  0.7924, -0.2897,
             0.0525,  0.5229,  2.3022, -1.4689, -1.5867],
           [ 0.4397,  0.1124,  0.6408,  0.4412, -0.1023,  0.7924, -0.2897,
             0.0525,  0.5229,  2.3022, -1.4689, -1.5867]],
 
          [[ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463,
            -0.8437, -0.6136,  0.0316, -0.4927,  0.2484],
           [ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959,  0.5667,
             0.7935,  0.5988, -1.5551, -0.3414,  1.8530],
           [ 0.1198,  1.2377,  1.1168, -0.2473, -1.3527, -1.6959,  0.5667,
             0.7935,  0.5988, -1.5551, -0.3414,  1.8530]],
 
          [[ 0.7502, -0.5855, -0.1734,  0.1835,  1.3894,  1.5863,  0.9463,
            -0.8437, -0.6136,  0.0316, -0.4927,  0.2484],
           [ 0.7502, -0.5855, -0.1734,  0.1835,  1

In [6]:
a = torch.tensor([[1, 2]])
b = torch.tensor([[3, 4],
                  [5, 6]])
a.shape, b.shape

(torch.Size([1, 2]), torch.Size([2, 2]))

In [7]:
c = a - b
c

tensor([[-2, -2],
        [-4, -4]])

In [8]:
encoder = Encoder(
    in_channels=3,
    num_channels=16,
    num_residual_blocks=2,
    num_residual_channels=8
)

encoder

Encoder(
  (encoder): Sequential(
    (0): Conv2d(3, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (5): ResidualStack(
      (stack): Sequential(
        (0): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (1): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
  )
)

In [9]:
input = torch.randn((1, 3, 128, 128))
encoder_output = encoder(input)
encoder_output.shape

torch.Size([1, 16, 32, 32])

In [10]:
decoder = Decoder(
    in_channels=16,
    out_channels=3,
    num_residual_blocks=2,
    num_residual_channels=8
)

decoder

Decoder(
  (decoder): Sequential(
    (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ResidualStack(
      (stack): Sequential(
        (0): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
        (1): ResidualBlock(
          (conv): Sequential(
            (0): ReLU()
            (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (2): ReLU()
            (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          )
        )
      )
    )
    (2): ConvTranspose2d(16, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU()
    (4): ConvTranspose2d(8, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)

In [11]:
decoder_output = decoder(encoder_output)
decoder_output.shape

torch.Size([1, 3, 128, 128])

In [21]:
vqvae = build_vqvae(
    codebook_size=4,
    codebook_dim=12,
    in_channels=3,
    num_channels=16,
    num_residual_blocks=2,
    num_residual_channels=8
)

vqvae

VQVAE(
  (encoder): Encoder(
    (encoder): Sequential(
      (0): Conv2d(3, 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(8, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): ResidualStack(
        (stack): Sequential(
          (0): ResidualBlock(
            (conv): Sequential(
              (0): ReLU()
              (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): ReLU()
              (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
          (1): ResidualBlock(
            (conv): Sequential(
              (0): ReLU()
              (1): Conv2d(16, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
              (2): ReLU()
              (3): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
            )
          )
        )
      )
    )
  )
  (pre_vq): C

In [24]:
inputs = torch.randn((1, 3, 128, 128))
vqvae(inputs)

(tensor([[[[ 0.0088,  0.0196,  0.0280,  ...,  0.0313,  0.0230,  0.0392],
           [-0.0635,  0.0367,  0.0474,  ...,  0.0488,  0.0540,  0.0407],
           [ 0.0390, -0.0199,  0.0262,  ..., -0.0291,  0.0483, -0.0090],
           ...,
           [-0.0139,  0.0566,  0.0363,  ...,  0.0970,  0.0497,  0.0377],
           [ 0.0442, -0.0829, -0.0013,  ..., -0.0636, -0.0416, -0.0318],
           [ 0.0586,  0.0840,  0.0377,  ...,  0.0921,  0.0769,  0.0546]],
 
          [[-0.0115, -0.0079,  0.0020,  ...,  0.0034,  0.0427, -0.0096],
           [-0.0412, -0.0266,  0.1051,  ...,  0.0210,  0.0308,  0.0538],
           [-0.0256,  0.0207,  0.0064,  ...,  0.0083,  0.1088, -0.0390],
           ...,
           [-0.0058, -0.0018,  0.1042,  ...,  0.0018,  0.0679,  0.0535],
           [ 0.0080,  0.0234,  0.0161,  ..., -0.0272,  0.0124, -0.0415],
           [ 0.0075, -0.0017,  0.0280,  ..., -0.0023,  0.0379,  0.0142]],
 
          [[-0.0679, -0.0811, -0.0203,  ..., -0.0677,  0.0083, -0.0826],
           [-