In [1]:
import os
import numpy as np
import random
from sklearn.utils import shuffle
from torch.autograd import Variable
import pickle
import torchvision.models as models
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Model hyperparameters shared by the Encoder, Decoder and vae_loss below.
# NOTE(review): the original comment here described a "2-d latent space" with a
# parameter count comparable to the original VAE paper; that does not match the
# 512-d latent used below and looks like a leftover from a template -- confirm.

# Size of the latent vector: output width of the encoder's mu/logvar heads and
# input width of the decoder's first linear layer.
latent_dims = 512
# Base channel count for the decoder; its layers use capacity, 2*capacity and
# 4*capacity channels.
capacity = 64
# Weight that vae_loss multiplies into the KL-divergence term.
variational_beta = 1

In [3]:
def globalAvgPooling(x):
    """Average ``x`` over its last two dimensions.

    The reduced dimensions are kept as size-1 axes, so the result has the same
    number of dimensions as the input.
    """
    trailing_axes = (-2, -1)
    return torch.mean(x, dim=trailing_axes, keepdim=True)

In [4]:
class Encoder(nn.Module):
    """Encode a channel-stacked group of RGB images into Gaussian parameters.

    The input is split along the channel axis into 3-channel images. Each image
    goes through the same ResNet-152, whose pretrained weights are frozen and
    whose final ``fc`` layer is replaced by a trainable 2048 -> 1024 projection
    followed by ReLU. The per-image features are concatenated, passed through
    ``fc1`` and mapped to ``latent_dims``-sized mean and log-variance vectors.

    ``fc1`` takes 1024 * 6 input features, so ``forward`` only works for inputs
    with 18 channels (6 RGB images).
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.resnet152 = models.resnet152(pretrained=True, progress=False)
        # Freeze the pretrained ResNet-152 weights.
        for param in self.resnet152.parameters():
            param.requires_grad = False
        # requires_grad=False does not stop BatchNorm layers from updating their
        # running statistics in training mode, so also pin the backbone to eval
        # mode (see train() below).
        self.resnet152.eval()
        # The replacement head is created after the freeze loop, so its
        # parameters keep requires_grad=True and are trained.
        self.resnet152.fc = torch.nn.Sequential(
            nn.Linear(in_features=2048, out_features=1024, bias=True),
            nn.ReLU(),
        )
        self.fc1 = nn.Linear(in_features=1024*6, out_features=1024, bias=True)
        self.fc_mu = nn.Linear(in_features=1024, out_features=latent_dims)
        self.fc_logvar = nn.Linear(in_features=1024, out_features=latent_dims)

    def train(self, mode=True):
        """Switch train/eval mode while keeping the frozen backbone in eval mode.

        Without this override, ``model.train()`` would put the backbone's
        normalisation layers back into training mode: they would normalise with
        batch statistics and overwrite the pretrained running statistics even
        though the weights themselves are frozen.
        """
        super(Encoder, self).train(mode)
        self.resnet152.eval()
        # The replaced head is trainable, so it follows the requested mode.
        self.resnet152.fc.train(mode)
        return self

    def forward(self, x):
        """Return ``(mu, logvar)`` for a batch of channel-stacked images.

        Args:
            x: tensor of shape batch_size x 18 x H x W (6 RGB images stacked on
                the channel axis). The original author's note gives
                H = W = 280 -- TODO confirm against the data loader.

        Returns:
            Tuple ``(x_mu, x_logvar)``, each of shape batch_size x latent_dims.
        """
        # Each chunk is one RGB image: batch_size x 3 x H x W.
        # The backbone maps it to batch_size x 1024.
        features = [self.resnet152(img) for img in torch.split(x, 3, dim=1)]

        x = torch.cat(features, dim=1)    # batch_size x 1024*6
        # NOTE(review): there is no activation between fc1 and the two heads, so
        # fc1 composes linearly with fc_mu / fc_logvar. Left unchanged so the
        # model's outputs stay the same -- confirm whether this is intended.
        x = self.fc1(x)
        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)
        return x_mu, x_logvar

In [5]:
class Decoder(nn.Module):
    """Decode a latent vector into an 18-channel feature map.

    A linear layer expands the latent code to a (4*capacity) x 35 x 35 map.
    Three stages follow, each made of a 3x3 stride-1 transposed convolution
    with batch norm and ReLU (spatial size unchanged) and a 4x4 stride-2
    transposed convolution that doubles the spatial size: 35 -> 70 -> 140 -> 280.
    The first two upsampling layers are followed by ReLU; the last one is not.
    """

    def __init__(self):
        super().__init__()
        base = capacity
        wide = base * 4
        mid = base * 2

        # Latent vector -> flattened (wide x 35 x 35) feature map.
        self.fc = nn.Linear(latent_dims, wide * 35 * 35)

        # Stage 1: wide -> mid channels, 35 -> 70.
        self.conv_d1 = nn.ConvTranspose2d(wide, wide, 3, stride=1, padding=1)
        self.bn_d1 = nn.BatchNorm2d(wide)
        self.conv3 = nn.ConvTranspose2d(wide, mid, 4, stride=2, padding=1)

        # Stage 2: mid -> base channels, 70 -> 140.
        self.conv_d2 = nn.ConvTranspose2d(mid, mid, 3, stride=1, padding=1)
        self.bn_d2 = nn.BatchNorm2d(mid)
        self.conv2 = nn.ConvTranspose2d(mid, base, 4, stride=2, padding=1)

        # Stage 3: base -> 18 output channels, 140 -> 280.
        self.conv_d3 = nn.ConvTranspose2d(base, base, 3, stride=1, padding=1)
        self.bn_d3 = nn.BatchNorm2d(base)
        self.conv1 = nn.ConvTranspose2d(base, 18, 4, stride=2, padding=1)

    def forward(self, x):
        """Map latent vectors (batch x latent_dims) to batch x 18 x 280 x 280."""
        flat = self.fc(x)
        # Unflatten each feature vector into a multi-channel 35 x 35 map.
        h = flat.view(flat.size(0), capacity * 4, 35, 35)

        stages = (
            (self.conv_d1, self.bn_d1, self.conv3),
            (self.conv_d2, self.bn_d2, self.conv2),
            (self.conv_d3, self.bn_d3, self.conv1),
        )
        for refine, norm, upsample in stages:
            h = upsample(F.relu(norm(refine(h))))
            # The final upsampling layer produces the raw output (no ReLU).
            if upsample is not self.conv1:
                h = F.relu(h)
        return h

In [6]:
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder built from the ``Encoder`` and ``Decoder``."""

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        """Encode ``x``, draw a latent code and decode it.

        Returns:
            Tuple ``(x_recon, latent_mu, latent_logvar)``.
        """
        latent_mu, latent_logvar = self.encoder(x)
        z = self.latent_sample(latent_mu, latent_logvar)
        return self.decoder(z), latent_mu, latent_logvar

    def latent_sample(self, mu, logvar):
        """Return a latent code for the given mean and log-variance.

        In eval mode the mean is returned unchanged. In training mode a sample
        ``mu + eps * std`` is drawn with ``eps ~ N(0, 1)`` (the
        reparameterization trick), where ``std = exp(0.5 * logvar)``.
        """
        if not self.training:
            return mu

        sigma = torch.exp(0.5 * logvar)
        noise = torch.empty_like(sigma)
        noise.normal_()
        return noise * sigma + mu

In [7]:
def vae_loss(recon_x, x, mu, logvar, beta=None):
    """Compute the reconstruction and KL terms of the VAE objective.

    Args:
        recon_x: reconstruction produced by the decoder.
        x: target tensor the reconstruction is compared against.
        mu: latent means produced by the encoder.
        logvar: latent log-variances produced by the encoder.
        beta: weight applied to the KL term. ``None`` (the default) uses the
            module-level ``variational_beta``; pass a value to override it per
            call, e.g. for KL warm-up.

    Returns:
        Tuple ``(recon_loss, weighted_kl)``. ``recon_loss`` is the mean squared
        error averaged over all elements. ``weighted_kl`` is ``beta`` times the
        KL divergence between N(mu, exp(logvar)) and N(0, I), summed over every
        element of ``mu`` / ``logvar`` (batch included).

    NOTE(review): the reconstruction term is a mean while the KL term is a sum,
    so their scales differ by the number of averaged elements. The reductions
    are left as they were so existing training runs keep the same numbers;
    confirm that the caller weights the two terms accordingly.
    """
    if beta is None:
        beta = variational_beta

    # F.mse_loss takes (input, target). The value is symmetric in the two.
    recon_loss = F.mse_loss(recon_x, x)

    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss, beta * kldivergence