In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import torch.utils.data as td
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
import cv2
import torchvision

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Data loader
path_image_bottom = '/root/userspace/project/image_bottom/'
path_image_top = '/root/userspace/project/image_top/'
# os.listdir returns entries in arbitrary order, so both listings are sorted to
# obtain a deterministic order. This assumes that matching bottom/top halves
# sort to the same position (e.g. share a file name) -- TODO confirm against
# the dataset layout.
img_bottom_list = sorted(os.listdir(path_image_bottom))
img_top_list = sorted(os.listdir(path_image_top))

# Maximum number of images loaded from each directory
max_size = 100000


def _load_images(directory, file_names, limit):
    """Read at most ``limit`` images from ``directory`` with ``cv2.imread``.

    Args:
        directory: directory that contains the image files.
        file_names: file names (relative to ``directory``) to read, in order.
        limit: maximum number of files to read.

    Returns:
        List with one array per image, as returned by ``cv2.imread``.

    Raises:
        IOError: if ``cv2.imread`` cannot decode a file (it returns ``None``).
    """
    images = []
    for file_name in file_names[:limit]:
        image_path = os.path.join(directory, file_name)
        image = cv2.imread(image_path)
        if image is None:
            # Fail early with the offending path instead of a confusing
            # error later when the list is converted to an array.
            raise IOError('Could not read image: ' + image_path)
        images.append(image)
    return images


# bottom / top images to lists of nparray
img_bottom = _load_images(path_image_bottom, img_bottom_list, max_size)
img_top = _load_images(path_image_top, img_top_list, max_size)

# list to torch tensor
# each element is divided by 255 to float
# axis change from (n_image, height, width, channel) to (n_image, channel, height, width)
x_tensor = torch.from_numpy((np.array(img_bottom) / 255.).astype(np.float32)).to(device)
label_tensor = torch.from_numpy((np.array(img_top) / 255.).astype(np.float32)).to(device)
x_tensor = x_tensor.permute((0, 3, 1, 2))
label_tensor = label_tensor.permute((0, 3, 1, 2))

# Define Data Loader
batch_size = 100 # for test
trainset = td.TensorDataset(x_tensor, label_tensor) # train data and the label
dataloader = td.DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

    Found GPU0 GRID K520 which is of cuda capability 3.0.
    PyTorch no longer supports this GPU because it is too old.
    


In [4]:
# CVAE Model components
# Conv2d arguments: kernel=1 means a 1x1 filter; pad=0 means no padding; pad=1 adds one pixel of zero padding on every edge.
# No pooling layers are used: spatial down-sampling is done by the strided convolutions in Encoder (m2, m3).
class EncoderModule(nn.Module):
    """Convolution -> batch normalisation -> ReLU building block.

    Args:
        input_channels: number of channels of the incoming feature map.
        output_channels: number of channels produced by the convolution.
        stride: stride of the convolution.
        kernel: kernel size of the convolution.
        pad: padding of the convolution.
    """

    def __init__(self, input_channels, output_channels, stride, kernel, pad):
        super(EncoderModule, self).__init__()
        self.conv = nn.Conv2d(
            in_channels=input_channels,
            out_channels=output_channels,
            kernel_size=kernel,
            stride=stride,
            padding=pad,
        )
        self.bn = nn.BatchNorm2d(num_features=output_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the convolution, then batch norm, then ReLU to ``x``."""
        convolved = self.conv(x)
        normalised = self.bn(convolved)
        return self.relu(normalised)
    
class Encoder(nn.Module):
    """Convolutional encoder that flattens an image into a feature vector.

    The shapes in the comments below hold for a (N, color_channels, 50, 100)
    input; down-sampling is performed by the strided convolutions m2 and m3.

    Args:
        color_channels: number of channels of the input image.
        pooling_kernels: accepted for interface compatibility; not used.
        n_neurons_in_middle_layer: length of the flattened output per sample
            (256*4*4 for a 50x100 input).
    """

    def __init__(self, color_channels, pooling_kernels, n_neurons_in_middle_layer):
        # Initialise nn.Module before assigning any attribute on self.
        super().__init__()
        self.n_neurons_in_middle_layer = n_neurons_in_middle_layer
        # (N,3,50,100)->(N,32,50,100)
        self.bottle = EncoderModule(color_channels, 32, stride=1, kernel=1, pad=0)
        # (N,32,50,100)->(N,64,50,100)
        self.m1 = EncoderModule(32, 64, stride=1, kernel=3, pad=1)
        # (N,64,50,100)->(N,128,12,24) with stride size 4. TODO(review): confirm pad=0
        self.m2 = EncoderModule(64, 128, stride=4, kernel=[3,5], pad=0)
        # (N,128,12,24)->(N,256,4,4) with stride size [2,4]. TODO(review): confirm pad=0
        self.m3 = EncoderModule(128, 256, stride=[2,4], kernel=[5,9], pad=0)

    def forward(self, x):
        """Encode ``x`` and return a (N, n_neurons_in_middle_layer) tensor.

        Raises:
            RuntimeError: if the flattened feature size per sample differs
                from ``n_neurons_in_middle_layer``.
        """
        out = self.m3(self.m2(self.m1(self.bottle(x))))
        # Pin the batch dimension explicitly: view(-1, n) would silently move
        # elements between samples whenever the feature size does not match.
        return out.view(out.size(0), self.n_neurons_in_middle_layer)

class DecoderModule(nn.Module):
    """Transposed convolution -> batch normalisation -> activation block.

    Args:
        input_channels: number of channels of the incoming feature map.
        output_channels: number of channels produced by the block.
        stride: stride of the transposed convolution.
        activation: ``"relu"`` (default) or ``"sigmoid"``.
        kernel: kernel size of the transposed convolution.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """

    def __init__(self, input_channels, output_channels, stride, activation="relu", kernel=1):
        super().__init__()
        # Reject unknown names here; otherwise self.activation would never be
        # set and the failure would only surface on the first forward pass.
        if activation not in ("relu", "sigmoid"):
            raise ValueError(
                "Unsupported activation {!r}; expected 'relu' or 'sigmoid'".format(activation)
            )
        self.convt = nn.ConvTranspose2d(input_channels, output_channels, kernel_size=kernel, stride=stride)
        self.bn = nn.BatchNorm2d(output_channels)
        if activation == "relu":
            self.activation = nn.ReLU(inplace=True)
        else:
            self.activation = nn.Sigmoid()

    def forward(self, x):
        """Apply the transposed convolution, batch norm and activation to ``x``."""
        return self.activation(self.bn(self.convt(x)))

class Decoder(nn.Module):
    """Transposed-convolution decoder that maps a flat feature vector to an image.

    The shapes in the comments below hold for decoder_input_size=4, which
    yields a (N, color_channels, 50, 100) output.

    Args:
        in_channels: number of channels the flat input is reshaped to
            (256 for VAE, 512 for CVAE).
        color_channels: number of channels of the generated image.
        pooling_kernels: accepted for interface compatibility; not used.
        decoder_input_size: height and width the flat input is reshaped to.
    """

    def __init__(self, in_channels, color_channels, pooling_kernels, decoder_input_size):
        # Initialise nn.Module before assigning any attribute on self.
        super().__init__()
        self.decoder_input_size = decoder_input_size
        self.in_channels = in_channels
        # (N,128,12,24)<-(N,256,4,4) (in_channels=256 for VAE, 512 for CVAE)
        self.m1 = DecoderModule(self.in_channels, 128, stride=[3,6], kernel=[3,6])
        # (N,64,50,100)<-(N,128,12,24)
        self.m2 = DecoderModule(128, 64, stride=[4,4], kernel=[6,8])
        # (N,32,50,100)<-(N,64,50,100)
        self.m3 = DecoderModule(64, 32, stride=1, kernel=1)
        # (N,3,50,100)<-(N,32,50,100)
        self.bottle = DecoderModule(32, color_channels, stride=1, activation="sigmoid", kernel=1)

    def forward(self, x):
        """Decode a flat (N, in_channels * size * size) tensor into an image batch.

        Supports two cases:
          VAE:  input size (N,256*4*4)   -> (N,256,4,4)
          CVAE: input size (N,2*256*4*4) -> (N,512,4,4)

        Raises:
            RuntimeError: if the per-sample input length does not equal
                ``in_channels * decoder_input_size ** 2``.
        """
        # Pin the batch dimension explicitly: view(-1, ...) would silently move
        # elements between samples whenever the feature size does not match.
        out = x.view(x.size(0), self.in_channels, self.decoder_input_size, self.decoder_input_size)
        out = self.m3(self.m2(self.m1(out)))
        return self.bottle(out)

In [5]:
# CVAE model
class VAE(nn.Module):
    """Variational auto-encoder with an optional conditional (CVAE) branch.

    ``x`` is encoded by ``encoder_x``. When ``conditional`` is true, a second
    encoder/decoder pair handles the conditioning image ``y``: the encoded
    ``y`` features are concatenated to the ``x`` features before the
    bottleneck and again before ``decoder_x``, and the reconstructions of
    ``y`` and ``x`` are concatenated along dim 2.

    The target device is taken from the module-level ``device`` variable.

    Args:
        conditional: if true build the CVAE variant, otherwise a plain VAE.
    """

    def __init__(self, conditional):
        # Initialise nn.Module before assigning any attribute on self.
        super().__init__()
        self.device = device
        self.conditional = conditional # if true, CVAE

        pooling_kernel = [1, 1]
        # TODO(review): should be derived from the Encoder architecture
        # instead of being hard-coded here.
        encoder_output_size = 4
        color_channels = 3

        # Middle
        # output size [N,256,4,4]
        self.n_latent_features = 64 # dimension of z
        encoder_1d_output_size = 256 * encoder_output_size * encoder_output_size
        n_neurons_middle_layer = (2 if self.conditional else 1) * encoder_1d_output_size # we double if CVAE

        # Encoder for X (lower-half face)
        self.encoder_x = Encoder(color_channels, pooling_kernel, encoder_1d_output_size)
        # Encoder for Y (Upper-half face, "label")
        if self.conditional:
            self.encoder_y = Encoder(color_channels, pooling_kernel, encoder_1d_output_size) # TODO(review): confirm this setting

        self.fc1 = nn.Linear(n_neurons_middle_layer, self.n_latent_features) # mu for encoder X
        self.fc2 = nn.Linear(n_neurons_middle_layer, self.n_latent_features) # logvar for encoder X
        self.fc3 = nn.Linear(self.n_latent_features, encoder_1d_output_size) # for decoder X

        if self.conditional:
            self.fc4 = nn.Linear(encoder_1d_output_size, self.n_latent_features) # mu for encoder Y
            self.fc5 = nn.Linear(encoder_1d_output_size, self.n_latent_features) # logvar for encoder Y
            self.fc6 = nn.Linear(self.n_latent_features, encoder_1d_output_size) # for decoder Y

        # Decoder for X (the encoded label is concatenated to its input when conditional)
        self.decoder_x = Decoder(256*(2 if self.conditional else 1), color_channels, pooling_kernel, encoder_output_size)
        # Decoder for Y (Upper-half face, "label")
        if self.conditional:
            self.decoder_y = Decoder(256, color_channels, pooling_kernel, encoder_output_size)

    def _reparameterize(self, mu, logvar):
        """Draw ``z = mu + exp(0.5 * logvar) * eps`` with ``eps ~ N(0, 1)``."""
        std = torch.exp(0.5 * logvar)
        # randn_like creates the noise with the dtype and device of std, so no
        # host-to-device copy is needed and the two can never mismatch.
        eps = torch.randn_like(std)
        return mu + std * eps

    def _bottleneck_x(self, h):
        """Return ``(z, mu, logvar)`` for the X branch from features ``h``."""
        mu, logvar = self.fc1(h), self.fc2(h)
        z = self._reparameterize(mu, logvar)
        return z, mu, logvar

    def _bottleneck_y(self, h):
        """Return ``(z, mu, logvar)`` for the Y branch from features ``h``."""
        mu, logvar = self.fc4(h), self.fc5(h)
        z = self._reparameterize(mu, logvar)
        return z, mu, logvar

    def sampling(self, y=None, r=None):
        """Generate images from a latent code.

        Args:
            y: conditioning image batch; required when the model is
                conditional, ignored otherwise.
            r: if ``None`` the latent code is drawn from N(0, 1); otherwise
                every latent component is set to this constant.

        Returns:
            Conditional model: ``y`` concatenated with the generated X image
            along dim 2. Plain VAE: a batch of 64 generated X images.

        Raises:
            ValueError: if the model is conditional and ``y`` is ``None``.
        """
        if self.conditional and y is None:
            raise ValueError("y is required when the model is conditional")
        # assume latent features space ~ N(0, 1)
        n = y.size(0) if self.conditional else 64 # number of faces
        # Create the latent code where the weights live.
        latent_device = self.fc3.weight.device
        if r is None:
            zx = torch.randn(n, self.n_latent_features, device=latent_device)
        else:
            # float(r) keeps the tensor floating point even for an integer r,
            # so it can be fed to the linear layer.
            zx = torch.full((n, self.n_latent_features), float(r), device=latent_device)
        zx = self.fc3(zx) # 256 * encode_out_size(4) * encode_out_size(4)

        if not self.conditional:
            # Plain VAE: there is no label to prepend to the generated image.
            return self.decoder_x(zx)

        hy = self.encoder_y(y) # 256 * encode_out_size(4) * encode_out_size(4)
        z = torch.cat((zx, hy), dim=-1) # 2 * 256 * encode_out_size(4) * encode_out_size(4)
        dx = self.decoder_x(z)
        return torch.cat((y, dx), dim=2)

    def forward(self, x, y=None):
        """Encode and reconstruct ``x`` (and ``y`` when conditional).

        Args:
            x: input image batch.
            y: conditioning image batch; required when the model is
                conditional.

        Returns:
            Tuple ``(d, mu, logvar)``. ``d`` is the reconstruction (``dy`` and
            ``dx`` concatenated along dim 2 when conditional, ``dx``
            otherwise); ``mu`` and ``logvar`` have size [N, n_latent_features].

        Raises:
            ValueError: if the model is conditional and ``y`` is ``None``.
        """
        if self.conditional and y is None:
            raise ValueError("y is required when the model is conditional")

        # Encoder for X and for Y
        # output size [N,256*4*4]
        hx = self.encoder_x(x)
        if self.conditional:
            hy = self.encoder_y(y)
            # concat if CVAE
            h = torch.cat((hx, hy), dim=-1)
        else:
            h = hx

        # Bottle-neck
        # z,mu,logvar size [N,256*4*4]->[N,64]
        zx, mu_x, logvar_x = self._bottleneck_x(h)

        # decoder
        # output size [N,64]->[N,256*4*4]
        zx = self.fc3(zx)
        if not self.conditional:
            return self.decoder_x(zx), mu_x, logvar_x

        zy, mu_y, logvar_y = self._bottleneck_y(hy)
        zx = torch.cat((zx, hy), dim=-1) # 2 * 256 * encode_out_size(4) * encode_out_size(4)
        zy = self.fc6(zy)

        dx = self.decoder_x(zx)
        dy = self.decoder_y(zy)
        d = torch.cat((dy, dx), dim=2)
        # NOTE(review): merging the X and Y posteriors with an element-wise
        # Euclidean norm makes the returned mu and logvar non-negative, which
        # is not the same as summing the KL terms of the two branches. Kept
        # unchanged to preserve the returned shapes; consider returning the
        # concatenated mu/logvar of both branches instead.
        mu = torch.sqrt(mu_x.pow(2) + mu_y.pow(2))
        logvar = torch.sqrt(logvar_x.pow(2) + logvar_y.pow(2))
        return d, mu, logvar

    # Model
    def loss_function(self, recon_x, x, mu, logvar):
        """Return summed binary cross-entropy plus the KL divergence term.

        See https://arxiv.org/abs/1312.6114 (Appendix B).
        """
        # reduction='sum' is the replacement for the deprecated size_average=False.
        BCE = F.binary_cross_entropy(recon_x, x, reduction='sum')
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return BCE + KLD

    def init_model(self):
        """Move the model to ``self.device`` and create ``self.optimizer`` (Adam)."""
        # Move the parameters first so the optimizer is built on the tensors
        # that are actually trained.
        self.to(self.device)
        # A torch.device never compares equal to the string "cuda", so inspect
        # its type instead.
        if torch.device(self.device).type == "cuda":
            torch.backends.cudnn.benchmark = True
        self.optimizer = optim.Adam(self.parameters(), lr=1e-3)

In [3]:
# Init CVAE model
conditional = True
cvae = VAE(conditional)
cvae.init_model()

# Train
cvae.train()
nepoch = 20
nReportFreq = 20
for epoch in range(nepoch):
    train_loss = 0
    samples_cnt = 0
    # Each batch holds the input (lower-half face) and the label (upper-half face).
    # Tensors track gradients directly, so no Variable wrapper is needed.
    for i, (inputs, labels) in enumerate(dataloader):
        # Gradients accumulate across backward() calls, so clear them before
        # every optimisation step.
        cvae.optimizer.zero_grad()
        # forward (call the module rather than .forward so hooks are honoured)
        recon_batch, mu, logvar = cvae(inputs, labels)
        # evaluate loss against the full face (label on top of the input, along dim 2)
        inputs_full = torch.cat((labels, inputs), dim=2)
        loss = cvae.loss_function(recon_batch, inputs_full, mu, logvar)
        # back prop and update the weights
        loss.backward()
        cvae.optimizer.step()

        train_loss += loss.item()
        samples_cnt += inputs.size(0)

        if i % nReportFreq == 0:
            # Report the number of pictures actually processed so far in this epoch.
            print('#Epoch: {:}, #Batch: {:}, #Pics:{:} Loss: {:.2f}'.format(epoch, i, samples_cnt, train_loss / samples_cnt))

# save model
state_dict = cvae.state_dict()
torch.save(state_dict, '/root/userspace/project/model2.pth')

NameError: name 'VAE' is not defined

In [6]:
# Testing
# Load the trained parameters
conditional = True
cvae = VAE(conditional)
cvae.init_model()
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
state_dict = torch.load('/root/userspace/project/model2.pth', map_location=device)
cvae.load_state_dict(state_dict)

cvae.eval()
val_loss = 0
samples_cnt = 0
nReportFreq = 20
recon_dir = '/root/userspace/project/reconstruction2/'
# save_image does not create missing directories.
os.makedirs(recon_dir, exist_ok=True)
# Channel permutation BGR -> RGB; loop-invariant, so built once.
perm = torch.tensor([2, 1, 0], dtype=torch.long, device=device)
# NOTE(review): this iterates the same (shuffled) training dataloader, so the
# reported loss is measured on the training data.
with torch.no_grad():
    # Each batch holds the input (lower-half face) and the label (upper-half face).
    for i, (inputs, labels) in enumerate(dataloader):
        # Call the module rather than .forward so hooks are honoured.
        recon_batch, mu, logvar = cvae(inputs, labels)
        inputs_full = torch.cat((labels, inputs), dim=2)
        val_loss += cvae.loss_function(recon_batch, inputs_full, mu, logvar).item()
        samples_cnt += inputs.size(0)

        # From BGR -> RGB
        recon_batch = recon_batch[:, perm, :, :]
        torchvision.utils.save_image(recon_batch, recon_dir + 'test_' + str(i) + '.png', nrow=10)

        if i % nReportFreq == 0:
            print('#Batch: {:}, Loss: {:.2f}'.format(i, val_loss / samples_cnt))



#Batch: 0, Loss: 18919.40
#Batch: 20, Loss: 18858.75
#Batch: 40, Loss: 18862.76
#Batch: 60, Loss: 18865.66
#Batch: 80, Loss: 18872.17
#Batch: 100, Loss: 18867.90
#Batch: 120, Loss: 18865.80


In [6]:
# Random sampling
gen_dir = '/root/userspace/project/generation2/'
# save_image does not create missing directories.
os.makedirs(gen_dir, exist_ok=True)
# Make sure batch norm uses its running statistics even when this cell is run
# directly after training.
cvae.eval()
with torch.no_grad():
    for i, (_, labels) in enumerate(dataloader):
        gen_batch = cvae.sampling(y=labels, r=None)

        # BGR -> RGB
        perm = torch.tensor([2, 1, 0], dtype=torch.long, device=gen_batch.device)
        full_face_batch = gen_batch[:, perm, :, :]
        # The former range=(0,255) argument had no effect with normalize=False
        # and is rejected by newer torchvision releases, so it is omitted.
        torchvision.utils.save_image(full_face_batch, gen_dir + 'test_' + str(i) + '.png', nrow=10, normalize=False)
        # Only the first batch is needed; stop before another one is fetched.
        break