In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import nn
from torch.autograd import Variable

In [2]:
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        # Encoder
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1)
        self.bn1 = nn.BatchNorm2d(16)
        
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=2)
        self.bn2 = nn.BatchNorm2d(32)
        
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1)
        self.bn3 = nn.BatchNorm2d(64)

        self.fc1 = nn.Linear(8 * 8 * 16, 512)
        self.fc_bn1 = nn.BatchNorm1d(512)
        self.fc21 = nn.Linear(512, 512)
        self.fc22 = nn.Linear(512, 512)

        # Decoder
        self.fc3 = nn.Linear(512, 512)
        self.fc_bn3 = nn.BatchNorm1d(512)
        self.fc4 = nn.Linear(512, 8 * 8 * 16)
        self.fc_bn4 = nn.BatchNorm1d(8 * 8 * 16)

        self.conv5 = nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2, padding=1, output_padding=1,)
        self.bn5 = nn.BatchNorm2d(32)
        
        self.conv6 = nn.ConvTranspose2d(32, 32, kernel_size=3, stride=1, padding=1)
        self.bn6 = nn.BatchNorm2d(32)
        
        self.conv7 = nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.bn7 = nn.BatchNorm2d(16)
        
        self.conv8 = nn.ConvTranspose2d(16, 3 * 256, kernel_size=3, stride=1, padding=1)

        self.relu = nn.ReLU()
        
    def encode(self, x):
        conv1 = self.relu(self.bn1(self.conv1(x)))
        conv2 = self.relu(self.bn2(self.conv2(conv1)))
        conv3 = self.relu(self.bn3(self.conv3(conv2))).view(-1, 8 * 8 * 16)

        fc1 = self.relu(self.fc_bn1(self.fc1(conv3)))
        return self.fc21(fc1), self.fc22(fc1)

    def reparameterize(self, mu, logvar):
        if self.training:
            std = logvar.mul(0.5).exp_()
            eps = Variable(std.data.new(std.size()).normal_())
            return eps.mul(std).add_(mu)
        else:
            return mu

    def decode(self, z):
        fc3 = self.relu(self.fc_bn3(self.fc3(z)))
        fc4 = self.relu(self.fc_bn4(self.fc4(fc3))).view(-1, 16, 8, 8)

        conv5 = self.relu(self.bn5(self.conv5(fc4)))
        conv6 = self.relu(self.bn6(self.conv6(conv5)))
        conv7 = self.relu(self.bn7(self.conv7(conv6)))
        return self.conv8(conv7).view(-1, 256, 3, 512, 512)

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

In [3]:
transform_test = transforms.Compose([transforms.ToTensor()])

test_data = datasets.ImageFolder('../data/patches/data/test', transform = transform_test)

In [4]:
test_data[0]

(tensor([[[0.5686, 0.5765, 0.5647,  ..., 0.6902, 0.6745, 0.6706],
          [0.5569, 0.5608, 0.5647,  ..., 0.6902, 0.6902, 0.6941],
          [0.5490, 0.5647, 0.5725,  ..., 0.6941, 0.7098, 0.7333],
          ...,
          [0.8471, 0.8706, 0.8745,  ..., 0.5608, 0.5686, 0.5725],
          [0.8549, 0.8784, 0.8824,  ..., 0.5922, 0.6078, 0.5804],
          [0.8510, 0.8784, 0.8745,  ..., 0.5961, 0.6157, 0.5922]],
 
         [[0.4118, 0.4118, 0.4078,  ..., 0.5255, 0.5059, 0.5216],
          [0.4039, 0.4000, 0.4039,  ..., 0.5176, 0.5176, 0.5216],
          [0.3961, 0.3922, 0.4000,  ..., 0.5020, 0.5059, 0.5137],
          ...,
          [0.7176, 0.7176, 0.7059,  ..., 0.4039, 0.4157, 0.4196],
          [0.7059, 0.7020, 0.6941,  ..., 0.4157, 0.4157, 0.4235],
          [0.7098, 0.7098, 0.7059,  ..., 0.4039, 0.4078, 0.4157]],
 
         [[0.6471, 0.6588, 0.6510,  ..., 0.7216, 0.7137, 0.7059],
          [0.6275, 0.6353, 0.6392,  ..., 0.7098, 0.7098, 0.7137],
          [0.6196, 0.6314, 0.6392,  ...,

In [5]:
test_loader = DataLoader(test_data, batch_size=2)

In [10]:
model = VAE()
model.train()

VAE(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc_bn1): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc21): Linear(in_features=512, out_features=512, bias=True)
  (fc22): Linear(in_features=512, out_features=512, bias=True)
  (fc3): Linear(in_features=512, out_features=512, bias=True)
  (fc_bn3): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc4): Linear(in_features=512, out_features=1024, bias=True)
  (fc_bn4): BatchNorm1d(1024, eps=1e-05, momen

In [11]:
for img, label in test_loader:
    outputs = model.encode(img)
    break

In [12]:
outputs[0].shape

torch.Size([7938, 512])

In [13]:
z = model.reparameterize(outputs[0], outputs[1])

In [14]:
z.shape

torch.Size([7938, 512])

In [None]:
model.decode(z)