In [1]:
import torch
import numpy as np
from models import encoder, decoder, quantizer, residual, vqvae

In [2]:
from torch import nn

In [4]:
# VQ-VAE encoder: 3 input channels -> 64-channel latent (see z_e.shape below).
enc = encoder.Encoder(
    in_dim=3,
    h_dim=64,
    n_res_layers=2,
    res_h_dim=64,
)

In [5]:
# 1x1 conv between encoder and quantizer; keeps 64 channels and spatial size.
pre_quantization_conv = nn.Conv2d(
    in_channels=64,
    out_channels=64,
    kernel_size=1,
    stride=1,
)

In [6]:
# NOTE(review): argument meanings below are assumed from common VQ-VAE
# conventions (n_e, e_dim, beta) — confirm against models/quantizer.py.
vector_quantization = quantizer.VectorQuantizer(
    64,    # presumably codebook size
    64,    # presumably embedding dim; must agree with z_e's 64 channels
    0.25,  # presumably commitment-loss weight
)

In [7]:
# Decoder consuming the 64-channel quantized latent z_q.
dec = decoder.Decoder(
    in_dim=64,
    h_dim=64,
    n_res_layers=2,
    res_h_dim=64,
)

In [8]:
# Move every VQ-VAE component to the GPU (in place). Module.cuda() returns the
# module itself, so ending on `dec` displays the decoder's repr as before.
for _module in (enc, pre_quantization_conv, vector_quantization, dec):
    _module.cuda()
dec

Decoder(
  (inverse_conv_stack): Sequential(
    (0): Upsample(size=3, mode='bilinear')
    (1): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (2): ReLU(inplace=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Upsample(size=8, mode='bilinear')
    (5): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (6): ReLU(inplace=True)
    (7): Dropout(p=0.2, inplace=False)
    (8): Upsample(size=15, mode='bilinear')
    (9): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (10): ReLU(inplace=True)
    (11): Dropout(p=0.2, inplace=False)
    (12): Upsample(size=32, mode='bilinear')
    (13): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (14): ReLU(inplace=True)
    (15): Dropout(p=0.2, inplace=False)
    (16): Upsample(size=63, mode='bilinear')
    (17): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
    (18): ReLU(inplace=True)
    (19): Dropout(p=0.2, inplace=False)
    (20): Upsampl

In [9]:
# Dummy (1, 3, 256, 256) input batch for shape checks.
# A seeded Generator makes reruns reproducible, and sampling float32 directly
# avoids materialising a float64 array first. Values are uniform in [0, 1).
# Uses CUDA when available and falls back to CPU otherwise.
x = torch.from_numpy(
    np.random.default_rng(seed=0).random((1, 3, 256, 256), dtype=np.float32)
).to("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
x.size()  # same torch.Size as x.shape

torch.Size([1, 3, 256, 256])

In [11]:
# Encode the dummy batch; the recorded output is torch.Size([1, 64, 16, 16]).
z_e = enc(x)
z_e.size()

torch.Size([1, 64, 16, 16])

In [12]:
# Derive z_e from the encoder output instead of `z_e = f(z_e)`: with the
# self-referential form, re-running this cell applied the 1x1 conv a second
# time. Recomputing enc(x) here keeps the cell idempotent.
z_e = pre_quantization_conv(enc(x))
z_e.shape

torch.Size([1, 64, 16, 16])

In [13]:
# The quantizer returns five values; only the second one (bound to z_q,
# presumably the quantized latent) is used here.
_, z_q, _, _, _ = vector_quantization(z_e)
z_q.size()

torch.Size([1, 64, 16, 16])

In [14]:
# NOTE(review): despite the name, this is not an image reconstruction of x —
# the recorded shape is (1, 384, 56, 56), the same as pdn(x) further down.
x_hat = dec(z_q)

In [15]:
x_hat.size()  # same torch.Size as x_hat.shape

torch.Size([1, 384, 56, 56])

# EfficientAD autoencoder

In [16]:
from common import get_autoencoder, get_pdn_small, get_pdn_medium

In [17]:
# Single nn.Sequential holding both the strided-conv encoder and the
# Upsample/conv decoder (see the repr printed by the next cell).
autoencoder = get_autoencoder()

In [18]:
autoencoder.to("cuda")  # in-place move; returns the module, so its repr is shown

Sequential(
  (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): Conv2d(64, 64, kernel_size=(8, 8), stride=(1, 1))
  (11): Upsample(size=3, mode='bilinear')
  (12): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
  (13): ReLU(inplace=True)
  (14): Dropout(p=0.2, inplace=False)
  (15): Upsample(size=8, mode='bilinear')
  (16): Conv2d(64, 64, kernel_size=(4, 4), stride=(1, 1), padding=(2, 2))
  (17): ReLU(inplace=True)
  (18): Dropout(p=0.2, inplace=False)
  (19): Upsample(size=15, mode='bilinear')
  (20): Conv2d(64, 64, kernel_s

In [19]:
autoencoder(x).size()  # same torch.Size as .shape

torch.Size([1, 384, 56, 56])

In [43]:
# Encoder half of the autoencoder, rebuilt by hand to inspect the bottleneck.
# Five 4x4 / stride-2 / pad-1 convs each halve the resolution (256 -> 8 for a
# 256x256 input); the final 8x8 unpadded conv then collapses 8x8 to 1x1.
# Stages are generated in creation order, so module indices 0..10 match.
test_model = nn.Sequential(
    *[
        layer
        for c_in, c_out in ((3, 32), (32, 32), (32, 64), (64, 64), (64, 64))
        for layer in (
            nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=4,
                      stride=2, padding=1),
            nn.ReLU(inplace=True),
        )
    ],
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=8),
)

In [80]:
test_model.to("cuda")  # in-place move; returns the module, so its repr is shown

Sequential(
  (0): Conv2d(3, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(32, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (5): ReLU(inplace=True)
  (6): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(64, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (9): ReLU(inplace=True)
  (10): Conv2d(64, 64, kernel_size=(8, 8), stride=(1, 1))
)

In [20]:
# Medium-size PDN; per the repr printed below it ends in a 384-channel 1x1 conv.
pdn = get_pdn_medium()

In [21]:
pdn.to("cuda")  # in-place move; returns the module, so its repr is shown

Sequential(
  (0): Conv2d(3, 256, kernel_size=(4, 4), stride=(1, 1))
  (1): ReLU(inplace=True)
  (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (3): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1))
  (4): ReLU(inplace=True)
  (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
  (6): Conv2d(512, 512, kernel_size=(1, 1), stride=(1, 1))
  (7): ReLU(inplace=True)
  (8): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1))
  (9): ReLU(inplace=True)
  (10): Conv2d(512, 384, kernel_size=(4, 4), stride=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(384, 384, kernel_size=(1, 1), stride=(1, 1))
)

In [22]:
pdn(x).size()  # same torch.Size as .shape

torch.Size([1, 384, 56, 56])

In [48]:
x.size()  # re-check of the dummy input's shape

torch.Size([1, 3, 256, 256])

In [96]:
# Random stand-in with z_q's shape (1, 64, 16, 16), fed to the decoder prototype.
# NOTE(review): this makes `a` 64-channel; the 3-in-channel ConvTranspose2d
# probe in a later cell cannot consume it.
a = torch.randn(1, 64, 16, 16).to("cuda")

In [87]:
a.size()  # same torch.Size as a.shape

torch.Size([1, 3, 16, 16])

In [88]:
# Probe layer: with these settings a 16x16 input becomes
# (16 - 1) * 2 - 2 * 1 + (3 - 1) + 1 = 31, i.e. not an exact 2x upsample.
test_net = nn.ConvTranspose2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=2,
    padding=1,
).cuda()

In [89]:
# Feed a dedicated 3-channel probe on test_net's own device. Using `a` here
# breaks on Restart & Run All: the cell that binds `a` earlier in this notebook
# makes it 64-channel, which this 3-input-channel layer rejects.
test_net(
    torch.randn(1, 3, 16, 16, device=next(test_net.parameters()).device)
).shape

torch.Size([1, 64, 31, 31])

In [92]:
# Standalone decoder prototype used to check output shapes: six
# (Upsample -> 4x4 conv -> ReLU -> Dropout) stages, then a resize to 56x56 and
# two 3x3 convs that lift 64 channels to 384.
# The repeated stages are generated in a loop rather than copy-pasted. Modules
# are created in the same order as the hand-written version, so indices 0..27
# are unchanged.
# align_corners=False is what bilinear resizing uses when the flag is left
# unset; it is spelled out to make the behaviour explicit.
# Uses CUDA when available and falls back to CPU otherwise.
# NOTE(review): `aaaa` is a throwaway name, kept because a later cell calls it.
aaaa = nn.Sequential(
    *[
        layer
        for size in (3, 8, 15, 32, 63, 127)
        for layer in (
            nn.Upsample(size=size, mode='bilinear', align_corners=False),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=4,
                      stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.2),
        )
    ],
    nn.Upsample(size=56, mode='bilinear', align_corners=False),
    nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1,
              padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(in_channels=64, out_channels=384, kernel_size=3,
              stride=1, padding=1),
).to("cuda" if torch.cuda.is_available() else "cpu")

In [97]:
aaaa(a).size()  # decoder prototype output; same torch.Size as .shape

torch.Size([1, 384, 56, 56])