In [1]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
import os
from skimage import io, transform
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
from skimage.io import imread
from torchsummary import summary
import pandas as pd
import time
from models import *
NAME = 'KL0005_64_model_novo'
KL_PAR = 0.0005

TRAIN_ROOT = '../data/data-celeba/Train'
VAL_ROOT = '../data/data-celeba/Validation'
VAL_ROOT_FIX = '../data/data-celeba/Fix_sample'
BATCH_SIZE = 64
EPOCHS = 10
LOG_INTERVAL = 50
BOTTLENECK_SIZE = 512
SAVE_MODEL = '../models/' + NAME
SAVE_RESULTS = '../results/' + NAME
%mkdir {SAVE_RESULTS}

In [2]:
no_cuda = False
seed = 1
cuda = not no_cuda and torch.cuda.is_available()
torch.manual_seed(seed)
device = torch.device("cuda" if cuda else "cpu")
kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}
torch.manual_seed(seed)


transform_seq = [ transforms.Resize((200,163)), transforms.Pad(( 19, 0, 18, 0)),
                  transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))]

train_loader_celeba = torch.utils.data.DataLoader(
    datasets.ImageFolder(TRAIN_ROOT, transform=transforms.Compose(transform_seq)),
    batch_size = BATCH_SIZE, shuffle=True, **kwargs)

val_loader_celeba = torch.utils.data.DataLoader(
    datasets.ImageFolder(VAL_ROOT, transform=transforms.Compose(transform_seq)),
    batch_size = BATCH_SIZE, shuffle=True, **kwargs)

val_loader_celeba_fix = torch.utils.data.DataLoader(
    datasets.ImageFolder(VAL_ROOT_FIX, transform=transforms.Compose(transform_seq)),
    batch_size = BATCH_SIZE, shuffle=False, **kwargs)

In [3]:
class customLoss(nn.Module):
    def __init__(self):
        super(customLoss, self).__init__()
        self.mse_loss = nn.MSELoss(reduction="mean")

    def forward(self, x_recon, x, mu, logvar):
        loss_MSE = self.mse_loss(x_recon, x)
        loss_KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim = -1)
        loss_KLD = torch.sum(loss_KLD*KL_PAR)/BATCH_SIZE
        return loss_MSE, loss_KLD

In [4]:
model = VAE_CNN(BOTTLENECK_SIZE).to(device)
optimizer = optim.RMSprop(model.parameters(), lr=0.001, alpha= 0.9)
loss_custom = customLoss()

In [5]:
#summary(model, (3,200,200))
model

VAE_CNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
  (mxp1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
  (mxp2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
  (mxp3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(256, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)
  (mxp4): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation

In [6]:
val_losses = []
train_losses = []

def train(epoch):
    epoch_start = 0
    model.train(True)
    train_loss = 0
    start = time.time()
    for batch_idx, (data, _) in enumerate(train_loader_celeba): 
        data = data.to(device)
        optimizer.zero_grad()
        recon_batch, mu, logvar = model(data)
        loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
        loss = loss_mse + loss_kl
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            interval = time.time() - start
            start = time.time()
            epoch_start = epoch_start + interval
            
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss MSE: {:.6f} \tLoss KL: {:.6f} \tTime Interv: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader_celeba.dataset),
                       100. * batch_idx / len(train_loader_celeba),
                       loss_mse.item(), loss_kl.item(), interval))

    print('====> Epoch: {} Average loss: {:.6f} Elapsed Time: {:.6f}'.format(
        epoch, train_loss * BATCH_SIZE / len(train_loader_celeba.dataset), epoch_start))
    train_losses.append(train_loss*BATCH_SIZE/ len(train_loader_celeba.dataset))


In [7]:
def test(epoch, sufix, train):
    model.train(train)
    test_loss_mse = 0
    test_loss_kl = 0
    with torch.no_grad():
        
        for i, (data, _) in enumerate(val_loader_celeba):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
            loss = loss_mse + loss_kl
            test_loss_mse += loss_mse.item()
            test_loss_kl += loss_kl.item()
            
        for i, (data, _) in enumerate(val_loader_celeba_fix):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n],
                                        recon_batch.view(7, 3, 200, 200)[:n]])
                save_image(comparison.cpu(),
                           SAVE_RESULTS + '/reconstruction_' + str(epoch) + sufix + '.png', nrow=n, normalize=True)

    test_loss_mse = test_loss_mse * BATCH_SIZE / len(val_loader_celeba.dataset)
    test_loss_kl = test_loss_kl * BATCH_SIZE / len(val_loader_celeba.dataset)
    print('====> Test set loss-mse: {:.6f}, loss-kl: {:.6f}'.format(test_loss_mse, test_loss_kl))
    val_losses.append((test_loss_mse, test_loss_kl))

In [8]:
len(val_loader_celeba.dataset)

40520

In [None]:
for epoch in range(1, EPOCHS + 1):
    train(epoch)
    test(epoch, '_on', True)
    test(epoch, '_off', False)
    model.train(False)
    with torch.no_grad():
        sample = torch.randn(BATCH_SIZE, BOTTLENECK_SIZE).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(BATCH_SIZE, 3, 200, 200),
                   SAVE_RESULTS + '/sample_off_' + str(epoch) + '.png', normalize=True)
    model.train(True)
    with torch.no_grad():
        sample = torch.randn(BATCH_SIZE, BOTTLENECK_SIZE).to(device)
        sample = model.decode(sample).cpu()
        save_image(sample.view(BATCH_SIZE, 3, 200, 200),
                   SAVE_RESULTS + '/sample_on' + str(epoch) + '.png', normalize=True)

====> Epoch: 1 Average loss: 0.126012 Elapsed Time: 959.574267
====> Test set loss-mse: 0.075084, loss-kl: 0.026208
====> Test set loss-mse: 0.128408, loss-kl: 0.029993


In [None]:
torch.randn(BATCH_SIZE, BOTTLENECK_SIZE)

In [None]:
x = torch.Tensor([[1,2,2,4]])
x.exp().sum()

In [None]:
x_v = Variable(x.data.n)

In [None]:
x.new(x.size()).normal_()

In [None]:
x.normal_()

In [None]:
x_v

In [None]:
data = [train_loader_celeba.dataset[i][0].view(1,3,200,200) for i in range(100, 132)]
'''img = img.view(1,3,200,200)
data = img
data.shape'''

In [None]:
a = 34
a

In [None]:
data = torch.cat(data)

In [None]:
data.shape

In [None]:
model.train(True)
train_loss = 0

data = data.to(device)
optimizer.zero_grad()
recon_batch, mu, logvar = model(data)
loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
loss = loss_mse + loss_kl
loss.backward()
train_loss += loss.item()
optimizer.step()

In [None]:
len(data)

In [None]:
recon_batch.shape

In [None]:
loss_mse/32

In [None]:
dif = (recon_batch - data)
dif.pow(2).sum()/(200*200*3*32)

In [None]:
loss_kl/32

In [None]:
torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            }, SAVE_MODEL)

In [12]:
torch.cuda.empty_cache()

In [13]:
sample = 4
sample_out1 = 5
sample_out2 = 6
sample = torch.randn(4, BOTTLENECK_SIZE).to(device)

In [50]:
sample_out1 = model.decode(sample).cpu()


In [51]:
sample_out1.shape

torch.Size([4, 3, 200, 200])

In [52]:
model.train(False)

VAE_CNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mxp1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mxp2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mxp3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (mxp4): MaxPool2d(kernel_size=2, stride=2, padding=1, dilation=1, 

In [53]:
sample_out2 = model.decode(sample).cpu()


In [54]:
(sample_out1-sample_out2).detach().numpy().sum()

0.0

In [12]:

model.train(True)
sample = model.decode(sample).cpu()
save_image(sample.view(BATCH_SIZE, 3, 200, 200),
                   SAVE_RESULTS + '/sample_on_semnograd_' + str(epoch) + '.png', normalize=True)

In [63]:
for i in model.bn1.parameters():
    print(i)

Parameter containing:
tensor([0.9292, 0.8448, 0.9831, 1.0190, 0.9913, 0.9115, 0.9428, 1.0651, 0.8819,
        0.9167, 1.1364, 1.1123, 1.1135, 1.0974, 0.9329, 1.0686, 0.9492, 0.9299,
        1.1258, 1.1505, 1.1051, 0.9718, 1.0167, 0.9452, 1.1193, 1.0964, 0.8607,
        0.9682, 1.1332, 1.0189, 0.9768, 1.1475], device='cuda:0',
       requires_grad=True)
Parameter containing:
tensor([-0.0426, -0.1471, -0.0447,  0.0336,  0.0034, -0.1570,  0.0412,  0.1586,
        -0.0386, -0.1243,  0.0757,  0.0099,  0.0115,  0.2145, -0.0811,  0.0070,
        -0.0220, -0.0144,  0.2599,  0.2122,  0.0909, -0.0056,  0.0440,  0.0563,
         0.1553,  0.2559,  0.0633,  0.0504,  0.0123,  0.2103, -0.0083,  0.1816],
       device='cuda:0', requires_grad=True)


In [None]:
'''from torchvision.utils import make_grid
import matplotlib.pyplot as plt'''

In [None]:
'''%matplotlib inline
def show(img):
    npimg = img.detach().numpy()
    plt.figure(figsize=(30, 10))
    plt.imshow(np.transpose(npimg, (1,2,0)), interpolation='nearest')'''

In [None]:
'''model.train(True)
test_loss = 0
for i, (data, _) in enumerate(val_loader_celeba):
    data = data.to(device)
    recon_batch, mu, logvar = model(data)
    loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
    loss = loss_mse + loss_kl
    test_loss += loss.item()
    if i == 0:
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n],
                                recon_batch.view(7, 3, 200, 200)[:n]])
        show(make_grid(comparison.cpu(), nrow=n, normalize=True))'''

In [None]:
'''checkpoint = torch.load('model_save')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']
'''


In [None]:
'''model.train(True)
test_loss = 0
for i, (data, _) in enumerate(val_loader_celeba):
    data = data.to(device)
    recon_batch, mu, logvar = model(data)
    loss_mse, loss_kl = loss_custom(recon_batch, data, mu, logvar)
    loss = loss_mse + loss_kl
    test_loss += loss.item()
    if i == 0:
        n = min(data.size(0), 8)
        comparison = torch.cat([data[:n],
                                recon_batch.view(7, 3, 200, 200)[:n]])
        show(make_grid(comparison.cpu(), nrow=n, normalize=True))'''

In [None]:
'''import random
x_test = []
plt.figure(figsize=(30,10))
num_figs = 1
plt.figure(figsize=(15,100))
plt.axis('off')
plt.imshow(make_grid(comparison.cpu(), nrow=n, normalize=True).detach().numpy())

for i in range(num_figs):
    figure_Decoded = vae_2.predict(np.array([x_test[i].astype('float32')/127.5 -1]), batch_size = b_size)
    plt.axis('off')
    plt.subplot(num_figs,2,1+i*2)
    plt.imshow(x_test[i])
    plt.axis('off')
    plt.subplot(num_figs,2,2 + i*2)
    plt.imshow((figure_Decoded[0]+1)/2)
    plt.axis('off')

plt.show()'''